whisper 模型源码解读

whisper官方源码

whisper 模型官方代码:https://github.com/openai/whisper/blob/main/whisper/model.py ;注释如下

python 复制代码
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

# 从其他模块导入必要的函数
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function

@dataclass
class ModelDimensions:
    """
    该类用于存储模型的各项参数
    """
    n_mels: int  # Mel谱图的频带数量
    n_audio_ctx: int  # 音频上下文窗口大小
    n_audio_state: int  # 音频状态维度
    n_audio_head: int  # 音频注意力头数量
    n_audio_layer: int  # 音频层数量
    n_vocab: int  # 词汇表大小
    n_text_ctx: int  # 文本上下文窗口大小
    n_text_state: int  # 文本状态维度
    n_text_head: int  # 文本注意力头数量
    n_text_layer: int  # 文本层数量

class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        """
        重写 forward 方法,确保输入张量的类型在归一化前后保持一致
        """
        return super().forward(x.float()).type(x.dtype)

class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        """
        重写 forward 方法,确保权重和偏置与输入张量的类型一致
        """
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )

class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        """
        重写 _conv_forward 方法,确保卷积操作中的权重和偏置与输入张量的类型一致
        """
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )

def sinusoids(length, channels, max_timescale=10000):
    """
    生成用于位置嵌入的正弦曲线
    """
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        """
        初始化多头注意力层
        """
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        多头注意力的前向传播
        """
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # 如果没有缓存键和值,则正常计算
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # 如果有缓存,则使用缓存的键和值
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        """
        计算 QKV 注意力
        """
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()

class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        """
        初始化残差注意力块
        """
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        """
        残差注意力块的前向传播
        """
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x

class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        初始化音频编码器
        """
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        前向传播,处理音频输入

        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            音频的Mel谱图
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "音频形状不正确"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x

class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        """
        初始化文本解码器
        """
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        前向传播,处理文本输入并结合音频特征

        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            文本的标

记序列
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            编码后的音频特征
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits

class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        """
        初始化 Whisper 模型
        """
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # 默认情况下,使用解码器层的后一半进行时间对齐;
        # 若要使用特定的注意力头,可以使用 `set_alignment_heads()` 方法。
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        """
        设置对齐的注意力头
        """
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        """
        编码音频特征
        """
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        """
        获取预测的logits
        """
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        前向传播
        """
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        """
        获取模型所在的设备
        """
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        """
        判断模型是否支持多语言
        """
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        """
        获取模型支持的语言数量
        """
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        为键和值的投影模块安装缓存钩子

        返回
        -------
        cache : Dict[nn.Module, torch.Tensor]
            映射键/值投影模块到其缓存的字典对象
        hooks : List[RemovableHandle]
            用于停止调用钩子的 PyTorch RemovableHandle 对象列表
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # 第一次标记或交叉注意时保存原始值
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function  # 语言检测函数
    transcribe = transcribe_function  # 转录函数
    decode = decode_function  # 解码函数

语音识别自回归解码过程分析和举例说明

分析

语音识别自回归解码过程通常涉及以下步骤:

  1. 音频预处理:首先将输入的音频信号转换为Mel谱图。这一步骤在实际应用中通常由音频前端处理模块完成。

  2. 音频编码:将预处理后的Mel谱图输入到音频编码器中,生成音频特征表示。这些特征表示将作为后续文本解码器的输入。

  3. 文本解码:文本解码器通过自回归方式生成文本序列。具体来说,文本解码器在每个时间步上根据前一步生成的文本标记以及音频特征生成下一个文本标记。

  4. 语言检测和转录:在生成的文本序列基础上,可以进行语言检测,确认文本所使用的语言。此外,转录过程将生成的文本序列转换为最终的文本输出。

具体步骤

以下代码展示了上述过程的具体实现:

python 复制代码
import torch

# 初始化模型参数
dims = ModelDimensions(
    n_mels=80,
    n_audio_ctx=1500,
    n_audio_state=512,
    n_audio_head=8,
    n_audio_layer=6,
    n_vocab=51865,
    n_text_ctx=448,
    n_text_state=512,
    n_text_head=8,
    n_text_layer=6,
)

# 创建模型实例
model = Whisper(dims)

# 假设我们有一个Mel谱图输入
mel_spectrogram = torch.randn(1, 80, 1500)  # (batch_size, n_mels, n_audio_ctx)

# 编码音频特征
audio_features = model.embed_audio(mel_spectrogram)

# 假设我们有一个初始的文本标记序列
initial_tokens = torch.tensor([[1, 2, 3]])  # (batch_size, seq_len)

# 自回归解码过程
for _ in range(10):  # 假设生成长度为10的序列
    logits = model.logits(initial_tokens, audio_features)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

# 最终生成的文本标记序列
final_tokens = initial_tokens

# 打印生成的文本标记序列
print("Generated tokens:", final_tokens)

举例说明

假设我们有一段音频,其Mel谱图表示如下:

mel_spectrogram = torch.randn(1, 80, 1500)

我们希望通过自回归解码生成对应的文本表示。首先,我们将Mel谱图输入到音频编码器中,得到音频特征表示:

python 复制代码
audio_features = model.embed_audio(mel_spectrogram)

然后,我们使用一个初始的文本标记序列(例如,序列开始标记)开始自回归解码过程:

python 复制代码
initial_tokens = torch.tensor([[1]])  # 序列开始标记

在每个时间步,我们根据当前的文本标记序列和音频特征生成下一个文本标记:

python 复制代码
logits = model.logits(initial_tokens, audio_features)
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

这个过程重复若干次(例如10次)直到生成完整的文本序列:

python 复制代码
for _ in range(10):
    logits = model.logits(initial_tokens, audio_features)
    next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
    initial_tokens = torch.cat([initial_tokens, next_token], dim=-1)

最终得到的文本标记序列为:

python 复制代码
final_tokens = initial_tokens
print("Generated tokens:", final_tokens)

以上示例展示了从音频输入到文本输出的完整自回归解码过程。

相关推荐
linzhiji13 小时前
whisper get_writer srt_writer 参数说明
前端·python·whisper
花晓木14 小时前
Linux系统上部署Whisper。
linux·运维·whisper
爱看书的小沐3 天前
【小沐学AI】Python实现语音识别(Whisper-Web)
人工智能·python·ai·nlp·whisper·openai·语音识别
MonkeyKing_sunyuhua4 天前
whisper 实现语音转文字
whisper
只恨天高4 天前
最新AI智能聊天对话问答系统源码(图文搭建部署教程)+AI绘画,文生图,TTS语音识别输入,文档分析
人工智能·ai作画·whisper·语音识别
路人与大师12 天前
深入了解 Whisper 的架构、用法以及在语音识别领域的应用和性能特征
人工智能·whisper·语音识别
Ephemeroptera16 天前
导出 Whisper 模型到 ONNX
whisper·openai·语音识别·onnx·int8
平底斜16 天前
优化你的WordPress网站:内链建设与Link Whisper Pro插件的利用
whisper
STONE_KKK23 天前
本地部署Whisper实现语言转文字
人工智能·whisper