Wenet代码分析:混合CTC-Attention的端到端语音识别模型`ASRModel`

Wenet代码分析:混合CTC-Attention的端到端语音识别模型ASRModel

代码文件位置:wenet/transformer/asr_model.py

导入必要的库

python 复制代码
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.transformer.search import (ctc_greedy_search,
                                      ctc_prefix_beam_search,
                                      attention_beam_search,
                                      attention_rescoring, DecodeResult)
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy,
                                reverse_pad_list)
from wenet.utils.context_graph import ContextGraph

ASRModel类定义

1. 初始化模型

这个部分初始化了模型的各个组件,包括编码器、解码器、CTC模块和损失函数。还设置了一些重要的超参数,如ctc_weightreverse_weightlsm_weight

python 复制代码
class ASRModel(torch.nn.Module):
    """CTC-attention hybrid Encoder-Decoder model
    这是一个CTC-注意力混合编码器-解码器模型,用于语音识别。
    """

    def __init__(
        self,
        vocab_size: int,                     # 词汇大小,即输出词汇的总数
        encoder: BaseEncoder,                # 编码器模型
        decoder: TransformerDecoder,         # 解码器模型
        ctc: CTC,                            # CTC模块
        ctc_weight: float = 0.5,             # CTC损失的权重
        ignore_id: int = IGNORE_ID,          # 忽略标识符,用于填充
        reverse_weight: float = 0.0,         # 逆序解码器的权重
        lsm_weight: float = 0.0,             # 标签平滑损失的权重
        length_normalized_loss: bool = False,# 是否对损失进行长度归一化
        special_tokens: Optional[dict] = None, # 特殊标记的字典,例如<sos>和<eos>
        apply_non_blank_embedding: bool = False, # 是否使用非空白嵌入
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__()
        # 设置起始和结束符号的索引。如果没有提供特殊标记,则默认设置为词汇表的最后一个索引
        self.sos = (vocab_size - 1 if special_tokens is None else
                    special_tokens.get("<sos>", vocab_size - 1))
        self.eos = (vocab_size - 1 if special_tokens is None else
                    special_tokens.get("<eos>", vocab_size - 1))
        self.vocab_size = vocab_size
        self.special_tokens = special_tokens
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.reverse_weight = reverse_weight
        self.apply_non_blank_embedding = apply_non_blank_embedding

        # 初始化编码器、解码器和CTC模块
        self.encoder = encoder
        self.decoder = decoder
        self.ctc = ctc

        # 初始化标签平滑损失函数
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,                 # 词汇大小
            padding_idx=ignore_id,           # 忽略标识符
            smoothing=lsm_weight,            # 标签平滑的权重
            normalize_length=length_normalized_loss, # 是否进行长度归一化
        )
2. 前向传播

这个部分实现了前向传播函数forward,包含了以下步骤:

  1. 将输入的语音数据和目标文本数据移到指定的设备上。
  2. 通过编码器处理语音数据,生成编码器输出和掩码。
  3. 如果ctc_weight不为0,则计算CTC损失和CTC概率。
  4. 如果apply_non_blank_embedding为真,则过滤空白嵌入。
  5. 计算注意力解码器的损失和准确率。
  6. 根据CTC损失和注意力解码器的损失,计算总损失。
python 复制代码
@torch.jit.unused
def forward(
    self,
    batch: dict,               # 输入的批次数据,包括特征、特征长度、目标和目标长度
    device: torch.device,      # 运行设备
) -> Dict[str, Optional[torch.Tensor]]:
    """Frontend + Encoder + Decoder + Calc loss
    前端 + 编码器 + 解码器 + 计算损失
    """
    
    # 将输入数据转移到指定设备上
    speech = batch['feats'].to(device)
    speech_lengths = batch['feats_lengths'].to(device)
    text = batch['target'].to(device)
    text_lengths = batch['target_lengths'].to(device)

    # 确保目标长度的维度为1(即每个样本对应一个长度)
    assert text_lengths.dim() == 1, text_lengths.shape
    
    # 检查批次大小是否一致
    assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
            text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                     text.shape, text_lengths.shape)
    
    # 1. 编码器:将语音特征编码为隐藏状态
    encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
    encoder_out_lens = encoder_mask.squeeze(1).sum(1)  # 计算每个样本的有效长度

    # 2a. CTC分支:计算CTC损失
    if self.ctc_weight != 0.0:
        loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
                                       text_lengths)
    else:
        loss_ctc, ctc_probs = None, None

    # 2b. 注意力解码器分支
    # 如果应用非空白嵌入,将CTC概率作为解码器输入
    if self.apply_non_blank_embedding:
        assert self.ctc_weight != 0
        assert ctc_probs is not None
        encoder_out, encoder_mask = self.filter_blank_embedding(
            ctc_probs, encoder_out)
    
    # 如果CTC权重不为1.0,计算注意力损失
    if self.ctc_weight != 1.0:
        loss_att, acc_att = self._calc_att_loss(
            encoder_out, encoder_mask, text, text_lengths, {
                "langs": batch["langs"],
                "tasks": batch["tasks"]
            })
    else:
        loss_att = None
        acc_att = None

    # 组合CTC损失和注意力损失
    if loss_ctc is None:
        loss = loss_att
    elif loss_att is None:
        loss = loss_ctc
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
    
    # 返回损失和准确度
    return {
        "loss": loss,
        "loss_att": loss_att,
        "loss_ctc": loss_ctc,
        "th_accuracy": acc_att,
    }
3. 计算注意力损失

这个方法计算注意力损失:

  1. 添加起始符和结束符到目标序列。
  2. 反转目标序列,用于从右到左的解码器。
  3. 前向传播解码器,得到解码器输出。
  4. 计算注意力损失和准确率。
python 复制代码
 def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,              # 编码器的输出 (B, Tmax, D),B是批次大小,Tmax 是时间步数,D是特征维度。
    encoder_mask: torch.Tensor,             # 编码器的掩码 (B, 1, Tmax)
    ys_pad: torch.Tensor,                   # 目标序列,填充后的 (B, Lmax),Lmax 是目标序列的最大长度。
    ys_pad_lens: torch.Tensor,              # 目标序列的长度 (B)
    infos: Dict[str, List[str]] = None,     # 额外的信息,用于多任务学习 (可选)
) -> Tuple[torch.Tensor, torch.Tensor]:     # 返回注意力损失和准确度

    # 添加起始和结束标记,并调整目标序列和目标序列长度
    ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
    ys_in_lens = ys_pad_lens + 1  # 增加1以包括起始标记

    # 将目标序列进行反转,用于右到左解码器
    r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
    r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, self.ignore_id)

    # 1. 前向解码器
    # 通过解码器生成输出和反向解码器输出
    decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                 ys_in_pad, ys_in_lens,
                                                 r_ys_in_pad, self.reverse_weight)

    # 2. 计算注意力损失
    # 使用标签平滑损失函数计算正向解码器的损失
    loss_att = self.criterion_att(decoder_out, ys_out_pad)
    
    # 初始化反向解码器的损失为0
    r_loss_att = torch.tensor(0.0)
    if self.reverse_weight > 0.0:
        # 如果反向权重大于0,计算反向解码器的损失
        r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
    
    # 根据正向和反向权重组合最终的注意力损失
    loss_att = loss_att * (1 - self.reverse_weight) + r_loss_att * self.reverse_weight
    
    # 计算准确度
    acc_att = th_accuracy(
        decoder_out.view(-1, self.vocab_size),  # 将解码器输出调整为二维 (B*Lmax, vocab_size)
        ys_out_pad,                            # 目标输出 (B, Lmax)
        ignore_label=self.ignore_id,           # 忽略的填充标识符
    )
    
    # 返回注意力损失和准确度
    return loss_att, acc_att
4. 过滤空白嵌入

过滤掉CTC解码过程中生成的空白标记:

  1. 获取CTC概率的最大索引。
  2. 为每个批次选择非空白标记对应的编码器输出。
  3. 填充序列并生成相应的掩码。
python 复制代码
   def filter_blank_embedding(
        self, ctc_probs: torch.Tensor,           # CTC概率 (B, Tmax, vocab_size)
        encoder_out: torch.Tensor                # 编码器输出 (B, Tmax, D)
    ) -> Tuple[torch.Tensor, torch.Tensor]:      # 返回过滤后的编码器输出和编码器掩码
    """
    过滤CTC解码过程中生成的空白标记,保留非空白标记的嵌入。
    Args:
        ctc_probs: CTC模型输出的概率分布,形状为 (B, Tmax, vocab_size)
        encoder_out: 编码器的输出,形状为 (B, Tmax, D)
    Returns:
        encoder_out: 过滤后的编码器输出,仅保留非空白标记的嵌入
        encoder_mask: 对应的编码器掩码
    """
    batch_size = encoder_out.size(0)              # 获取批次大小
    maxlen = encoder_out.size(1)                  # 获取时间步的最大长度
    top1_index = torch.argmax(ctc_probs, dim=2)   # 获取每个时间步上CTC概率最大的索引 (B, Tmax)
    
    indices = []
    for j in range(batch_size):
        # 对于每个样本,选择非空白标记的索引
        indices.append(
            torch.tensor(
                [i for i in range(maxlen) if top1_index[j][i] != 0]
            )
        )

    # 根据非空白标记的索引,选择对应的编码器输出
    select_encoder_out = [
        torch.index_select(encoder_out[i, :, :], 0,
                           indices[i].to(encoder_out.device))
        for i in range(batch_size)
    ]
    # 将选择的编码器输出序列填充为相同长度
    select_encoder_out = pad_sequence(select_encoder_out,
                                      batch_first=True,
                                      padding_value=0).to(encoder_out.device)
    
    # 计算选择后的序列长度
    xs_lens = torch.tensor([len(indices[i]) for i in range(batch_size)]).to(encoder_out.device)
    T = select_encoder_out.size(1)                # 获取填充后的时间步最大长度
    # 创建编码器掩码
    encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
    encoder_out = select_encoder_out                # 更新编码器输出为选择后的结果
    
    return encoder_out, encoder_mask                # 返回编码器输出和掩码
5. 解码

这个方法实现了解码功能,支持多种解码方法,如CTC贪婪搜索、CTC前缀束搜索、注意力解码和注意力重评分:

  1. 根据指定的解码方法,选择相应的解码策略。
  2. 通过编码器处理输入的语音数据。
  3. 根据CTC概率和注意力机制生成解码结果。
python 复制代码
def decode(
    self,
    methods: List[str],                     # 解码方法列表
    speech: torch.Tensor,                   # 输入语音特征 (batch, max_len, feat_dim)
    speech_lengths: torch.Tensor,           # 输入语音长度 (batch, )
    beam_size: int,                         # 集束搜索的束宽
    decoding_chunk_size: int = -1,          # 解码块的大小
    num_decoding_left_chunks: int = -1,     # 剩余解码块的数量
    ctc_weight: float = 0.0,                # CTC得分的权重
    simulate_streaming: bool = False,       # 是否模拟流式解码
    reverse_weight: float = 0.0,            # 反向解码器的权重
    context_graph: ContextGraph = None,     # 上下文图
    blank_id: int = 0,                      # 空白标记的ID
    blank_penalty: float = 0.0,             # 空白标记的惩罚
    length_penalty: float = 0.0,            # 长度惩罚
    infos: Dict[str, List[str]] = None,     # 额外的信息
) -> Dict[str, List[DecodeResult]]:         # 返回解码结果的字典
    """
    解码输入语音
    Args:
        methods: 使用的解码方法列表,包括以下方法:
            * ctc_greedy_search
            * ctc_prefix_beam_search
            * attention
            * attention_rescoring
        speech: 输入语音特征,形状为 (batch, max_len, feat_dim)
        speech_lengths: 输入语音长度,形状为 (batch, )
        beam_size: 集束搜索的束宽
        decoding_chunk_size: 动态块训练模型的解码块大小
            <0: 使用完整块进行解码
            >0: 使用固定块大小进行解码
            0: 训练中使用,不允许在此处使用
        simulate_streaming: 是否以流式方式进行编码器前向计算
        reverse_weight: 反向解码器的权重
        ctc_weight: CTC得分的权重
    Returns:
        各种解码方法的结果字典
    """
    
    # 确保输入的批次大小和长度大小一致
    assert speech.shape[0] == speech_lengths.shape[0]
    # 确保解码块大小不为0
    assert decoding_chunk_size != 0
    
    # 前向编码器,获取编码器输出和掩码
    encoder_out, encoder_mask = self._forward_encoder(
        speech, speech_lengths, decoding_chunk_size,
        num_decoding_left_chunks, simulate_streaming)
    
    # 计算编码器输出的有效长度
    encoder_lens = encoder_mask.squeeze(1).sum(1)
    
    # 计算CTC概率
    ctc_probs = self.ctc_logprobs(encoder_out, blank_penalty, blank_id)
    
    results = {}  # 存储解码结果
    
    # 使用注意力机制进行解码
    if 'attention' in methods:
        results['attention'] = attention_beam_search(
            self, encoder_out, encoder_mask, beam_size, length_penalty, infos)
    
    # 使用CTC贪婪搜索进行解码
    if 'ctc_greedy_search' in methods:
        results['ctc_greedy_search'] = ctc_greedy_search(
            ctc_probs, encoder_lens, blank_id)
    
    # 使用CTC前缀集束搜索进行解码
    if 'ctc_prefix_beam_search' in methods:
        ctc_prefix_result = ctc_prefix_beam_search(
            ctc_probs, encoder_lens, beam_size, context_graph, blank_id)
        results['ctc_prefix_beam_search'] = ctc_prefix_result
    
    # 使用注意力重评分进行解码
    if 'attention_rescoring' in methods:
        # 确保CTC前缀集束搜索结果存在
        if 'ctc_prefix_beam_search' in results:
            ctc_prefix_result = results['ctc_prefix_beam_search']
        else:
            ctc_prefix_result = ctc_prefix_beam_search(
                ctc_probs, encoder_lens, beam_size, context_graph, blank_id)
        
        # 过滤空白标记嵌入
        if self.apply_non_blank_embedding:
            encoder_out, _ = self.filter_blank_embedding(ctc_probs, encoder_out)
        
        # 使用注意力重评分进行解码
        results['attention_rescoring'] = attention_rescoring(
            self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
            reverse_weight, infos)
    
    return results  # 返回解码结果字典

其他辅助方法

1.获取CTC激活

这个方法返回编码器输出经过log_softmax变换后的CTC激活值。

python 复制代码
    @torch.jit.export
    def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
        """ Export interface for c++ call, apply linear transform and log
            softmax before ctc
        Args:
            xs (torch.Tensor): encoder output

        Returns:
            torch.Tensor: activation before ctc

        """
        return self.ctc.log_softmax(xs)
2.其他导出接口

这些方法提供了模型的相关信息,如子采样率、右上下文、起始符号和结束符号。

python 复制代码
    @torch.jit.export
    def subsampling_rate(self) -> int:
        """ Export interface for c++ call, return subsampling_rate of the
            model
        """
        return self.encoder.embed.subsampling_rate

    @torch.jit.export
    def right_context(self) -> int:
        """ Export interface for c++ call, return right_context of the model
        """
        return self.encoder.embed.right_context

    @torch.jit.export
    def sos_symbol(self) -> int:
        """ Export interface for c++ call, return sos symbol id of the model
        """
        return self.sos

    @torch.jit.export
    def eos_symbol(self) -> int:
        """ Export interface for c++ call, return eos symbol id of the model
        """
        return self.eos
3.编码器前向传播

这个方法提供了一个接口,用于分块进行编码器前向传播。它返回当前块的输出、注意力缓存和CNN缓存。

python 复制代码
@torch.jit.export
def forward_encoder_chunk(
    self,
    xs: torch.Tensor,                 # chunk 输入,形状为 (b=1, time, mel-dim)
    offset: int,                      # 编码器输出时间戳的当前偏移量
    required_cache_size: int,         # 下一个chunk计算所需的缓存大小
    att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),  # Transformer/Conformer注意力中的KEY和VALUE的缓存张量
    cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),  # Conformer中cnn模块的缓存张量
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:   # 返回当前输入xs的输出、新的注意力缓存和新的cnn缓存
    """
    导出接口供C++调用,给定输入chunk xs,并返回从时间0到当前chunk的输出。

    Args:
        xs (torch.Tensor): chunk输入,形状为 (b=1, time, mel-dim),
            其中 `time == (chunk_size - 1) * subsample_rate + \
                    subsample.right_context + 1`
        offset (int): 编码器输出时间戳的当前偏移量
        required_cache_size (int): 下一个chunk计算所需的缓存大小
            >=0: 实际缓存大小
            <0: 表示需要所有历史缓存
        att_cache (torch.Tensor): Transformer/Conformer注意力中的KEY和VALUE的缓存张量,形状为
            (elayers, head, cache_t1, d_k * 2),其中 `head * d_k == hidden-dim` 且
            `cache_t1 == chunk_size * num_decoding_left_chunks`。
        cnn_cache (torch.Tensor): Conformer中cnn模块的缓存张量,形状为
            (elayers, b=1, hidden-dim, cache_t2),其中 `cache_t2 == cnn.lorder - 1`

    Returns:
        torch.Tensor: 当前输入xs的输出,形状为 (b=1, chunk_size, hidden-dim)。
        torch.Tensor: 下一个chunk所需的新注意力缓存,形状为
            (elayers, head, ?, d_k * 2),具体取决于 required_cache_size。
        torch.Tensor: 下一个chunk所需的新Conformer cnn缓存,形状与原cnn_cache相同。
    """
    return self.encoder.forward_chunk(xs, offset, required_cache_size,
                                      att_cache, cnn_cache)
4.双向解码器检查

这个方法检查解码器是否为双向解码器。

python 复制代码
    @torch.jit.export
    def is_bidirectional_decoder(self) -> bool:
        """
        Returns:
            torch.Tensor: decoder output
        """
        if hasattr(self.decoder, 'right_decoder'):
            return True
        else:
            return False
前向解码器

这个方法实现了对多假设的解码操作。它将多个假设和一个编码器输出进行解码,并返回解码器输出。

python 复制代码
   @torch.jit.export
def forward_attention_decoder(
    self,
    hyps: torch.Tensor,                 # 从CTC前缀集束搜索中获得的假设,已经在开头填充了<sos>
    hyps_lens: torch.Tensor,            # 每个假设的长度
    encoder_out: torch.Tensor,          # 编码器输出
    reverse_weight: float = 0,          # 用于验证是否使用从右到左的解码器,> 0 将使用
) -> Tuple[torch.Tensor, torch.Tensor]: # 返回解码器输出
    """
    供C++调用的导出接口,使用多个CTC前缀集束搜索的假设和一个编码器输出进行前向解码。
    
    Args:
        hyps (torch.Tensor): 从CTC前缀集束搜索中获得的假设,已经在开头填充了<sos>
        hyps_lens (torch.Tensor): 每个假设的长度
        encoder_out (torch.Tensor): 编码器输出
        r_hyps (torch.Tensor): 从CTC前缀集束搜索中获得的假设,已经在开头填充了<eos>,用于从右到左解码器
        reverse_weight: 用于验证是否使用从右到左解码器,> 0 将使用

    Returns:
        torch.Tensor: 解码器输出
    """
    assert encoder_out.size(0) == 1  # 确保编码器输出的批次大小为1
    num_hyps = hyps.size(0)          # 获取假设的数量
    assert hyps_lens.size(0) == num_hyps  # 确保假设长度的数量与假设数量相同

    # 将编码器输出重复num_hyps次,以匹配假设的数量
    encoder_out = encoder_out.repeat(num_hyps, 1, 1)
    encoder_mask = torch.ones(num_hyps, 1, encoder_out.size(1), 
                              dtype=torch.bool, device=encoder_out.device)  # 创建编码器掩码

    # 处理从右到左的解码器输入
    r_hyps_lens = hyps_lens - 1      # 获取从右到左的假设长度
    r_hyps = hyps[:, 1:]             # 移除开头的<sos>标记
    max_len = torch.max(r_hyps_lens) # 获取最大假设长度
    index_range = torch.arange(0, max_len, 1).to(encoder_out.device)  # 创建索引范围
    seq_len_expand = r_hyps_lens.unsqueeze(1)  # 扩展假设长度以匹配索引范围
    seq_mask = seq_len_expand > index_range    # 创建序列掩码 (beam, max_len)
    index = (seq_len_expand - 1) - index_range # 计算索引 (beam, max_len)
    index = index * seq_mask                   # 应用序列掩码
    r_hyps = torch.gather(r_hyps, 1, index)    # 根据索引选择假设
    r_hyps = torch.where(seq_mask, r_hyps, self.eos)  # 替换无效的假设为<eos>
    r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) # 在开头添加<sos>

    # 前向解码器
    decoder_out, r_decoder_out, _ = self.decoder(
        encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight)  # (num_hyps, max_hyps_len, vocab_size)
    
    # 应用log_softmax
    decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
    r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
    
    return decoder_out, r_decoder_out  # 返回解码器输出
相关推荐
骇客野人9 分钟前
【人工智能】循环神经网络学习
人工智能·rnn·学习
速融云2 小时前
汽车制造行业案例 | 发动机在制造品管理全解析(附解决方案模板)
大数据·人工智能·自动化·汽车·制造
AI明说2 小时前
什么是稀疏 MoE?Doubao-1.5-pro 如何以少胜多?
人工智能·大模型·moe·豆包
XianxinMao2 小时前
重构开源LLM分类:从二分到三分的转变
人工智能·语言模型·开源
Elastic 中国社区官方博客3 小时前
使用 Elasticsearch 导航检索增强生成图表
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
云天徽上3 小时前
【数据可视化】全国星巴克门店可视化
人工智能·机器学习·信息可视化·数据挖掘·数据分析
大嘴吧Lucy3 小时前
大模型 | AI驱动的数据分析:利用自然语言实现数据查询到可视化呈现
人工智能·信息可视化·数据分析
艾思科蓝 AiScholar4 小时前
【连续多届EI稳定收录&出版级别高&高录用快检索】第五届机械设计与仿真国际学术会议(MDS 2025)
人工智能·数学建模·自然语言处理·系统架构·机器人·软件工程·拓扑学
watersink4 小时前
面试题库笔记
大数据·人工智能·机器学习
Yuleave4 小时前
PaSa:基于大语言模型的综合学术论文搜索智能体
人工智能·语言模型·自然语言处理