Wenet代码分析:混合CTC-Attention的端到端语音识别模型ASRModel
代码文件位置:wenet/transformer/asr_model.py
导入必要的库
python
from typing import Dict, List, Optional, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
from wenet.transformer.label_smoothing_loss import LabelSmoothingLoss
from wenet.transformer.search import (ctc_greedy_search,
ctc_prefix_beam_search,
attention_beam_search,
attention_rescoring, DecodeResult)
from wenet.utils.mask import make_pad_mask
from wenet.utils.common import (IGNORE_ID, add_sos_eos, th_accuracy,
reverse_pad_list)
from wenet.utils.context_graph import ContextGraph
ASRModel
类定义
1. 初始化模型
这个部分初始化了模型的各个组件,包括编码器、解码器、CTC模块和损失函数。还设置了一些重要的超参数,如ctc_weight
、reverse_weight
和lsm_weight
。
python
class ASRModel(torch.nn.Module):
"""CTC-attention hybrid Encoder-Decoder model
这是一个CTC-注意力混合编码器-解码器模型,用于语音识别。
"""
def __init__(
self,
vocab_size: int, # 词汇大小,即输出词汇的总数
encoder: BaseEncoder, # 编码器模型
decoder: TransformerDecoder, # 解码器模型
ctc: CTC, # CTC模块
ctc_weight: float = 0.5, # CTC损失的权重
ignore_id: int = IGNORE_ID, # 忽略标识符,用于填充
reverse_weight: float = 0.0, # 逆序解码器的权重
lsm_weight: float = 0.0, # 标签平滑损失的权重
length_normalized_loss: bool = False,# 是否对损失进行长度归一化
special_tokens: Optional[dict] = None, # 特殊标记的字典,例如<sos>和<eos>
apply_non_blank_embedding: bool = False, # 是否使用非空白嵌入
):
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
super().__init__()
# 设置起始和结束符号的索引。如果没有提供特殊标记,则默认设置为词汇表的最后一个索引
self.sos = (vocab_size - 1 if special_tokens is None else
special_tokens.get("<sos>", vocab_size - 1))
self.eos = (vocab_size - 1 if special_tokens is None else
special_tokens.get("<eos>", vocab_size - 1))
self.vocab_size = vocab_size
self.special_tokens = special_tokens
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.reverse_weight = reverse_weight
self.apply_non_blank_embedding = apply_non_blank_embedding
# 初始化编码器、解码器和CTC模块
self.encoder = encoder
self.decoder = decoder
self.ctc = ctc
# 初始化标签平滑损失函数
self.criterion_att = LabelSmoothingLoss(
size=vocab_size, # 词汇大小
padding_idx=ignore_id, # 忽略标识符
smoothing=lsm_weight, # 标签平滑的权重
normalize_length=length_normalized_loss, # 是否进行长度归一化
)
2. 前向传播
这个部分实现了前向传播函数forward
,包含了以下步骤:
- 将输入的语音数据和目标文本数据移到指定的设备上。
- 通过编码器处理语音数据,生成编码器输出和掩码。
- 如果
ctc_weight
不为0,则计算CTC损失和CTC概率。 - 如果
apply_non_blank_embedding
为真,则过滤空白嵌入。 - 计算注意力解码器的损失和准确率。
- 根据CTC损失和注意力解码器的损失,计算总损失。
python
@torch.jit.unused
def forward(
self,
batch: dict, # 输入的批次数据,包括特征、特征长度、目标和目标长度
device: torch.device, # 运行设备
) -> Dict[str, Optional[torch.Tensor]]:
"""Frontend + Encoder + Decoder + Calc loss
前端 + 编码器 + 解码器 + 计算损失
"""
# 将输入数据转移到指定设备上
speech = batch['feats'].to(device)
speech_lengths = batch['feats_lengths'].to(device)
text = batch['target'].to(device)
text_lengths = batch['target_lengths'].to(device)
# 确保目标长度的维度为1(即每个样本对应一个长度)
assert text_lengths.dim() == 1, text_lengths.shape
# 检查批次大小是否一致
assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
text.shape, text_lengths.shape)
# 1. 编码器:将语音特征编码为隐藏状态
encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
encoder_out_lens = encoder_mask.squeeze(1).sum(1) # 计算每个样本的有效长度
# 2a. CTC分支:计算CTC损失
if self.ctc_weight != 0.0:
loss_ctc, ctc_probs = self.ctc(encoder_out, encoder_out_lens, text,
text_lengths)
else:
loss_ctc, ctc_probs = None, None
# 2b. 注意力解码器分支
# 如果应用非空白嵌入,将CTC概率作为解码器输入
if self.apply_non_blank_embedding:
assert self.ctc_weight != 0
assert ctc_probs is not None
encoder_out, encoder_mask = self.filter_blank_embedding(
ctc_probs, encoder_out)
# 如果CTC权重不为1.0,计算注意力损失
if self.ctc_weight != 1.0:
loss_att, acc_att = self._calc_att_loss(
encoder_out, encoder_mask, text, text_lengths, {
"langs": batch["langs"],
"tasks": batch["tasks"]
})
else:
loss_att = None
acc_att = None
# 组合CTC损失和注意力损失
if loss_ctc is None:
loss = loss_att
elif loss_att is None:
loss = loss_ctc
else:
loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
# 返回损失和准确度
return {
"loss": loss,
"loss_att": loss_att,
"loss_ctc": loss_ctc,
"th_accuracy": acc_att,
}
3. 计算注意力损失
这个方法计算注意力损失:
- 添加起始符和结束符到目标序列。
- 反转目标序列,用于从右到左的解码器。
- 前向传播解码器,得到解码器输出。
- 计算注意力损失和准确率。
python
def _calc_att_loss(
self,
encoder_out: torch.Tensor, # 编码器的输出 (B, Tmax, D),B是批次大小,Tmax 是时间步数,D是特征维度。
encoder_mask: torch.Tensor, # 编码器的掩码 (B, 1, Tmax)
ys_pad: torch.Tensor, # 目标序列,填充后的 (B, Lmax),Lmax 是目标序列的最大长度。
ys_pad_lens: torch.Tensor, # 目标序列的长度 (B)
infos: Dict[str, List[str]] = None, # 额外的信息,用于多任务学习 (可选)
) -> Tuple[torch.Tensor, torch.Tensor]: # 返回注意力损失和准确度
# 添加起始和结束标记,并调整目标序列和目标序列长度
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1 # 增加1以包括起始标记
# 将目标序列进行反转,用于右到左解码器
r_ys_pad = reverse_pad_list(ys_pad, ys_pad_lens, float(self.ignore_id))
r_ys_in_pad, r_ys_out_pad = add_sos_eos(r_ys_pad, self.sos, self.eos, self.ignore_id)
# 1. 前向解码器
# 通过解码器生成输出和反向解码器输出
decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
ys_in_pad, ys_in_lens,
r_ys_in_pad, self.reverse_weight)
# 2. 计算注意力损失
# 使用标签平滑损失函数计算正向解码器的损失
loss_att = self.criterion_att(decoder_out, ys_out_pad)
# 初始化反向解码器的损失为0
r_loss_att = torch.tensor(0.0)
if self.reverse_weight > 0.0:
# 如果反向权重大于0,计算反向解码器的损失
r_loss_att = self.criterion_att(r_decoder_out, r_ys_out_pad)
# 根据正向和反向权重组合最终的注意力损失
loss_att = loss_att * (1 - self.reverse_weight) + r_loss_att * self.reverse_weight
# 计算准确度
acc_att = th_accuracy(
decoder_out.view(-1, self.vocab_size), # 将解码器输出调整为二维 (B*Lmax, vocab_size)
ys_out_pad, # 目标输出 (B, Lmax)
ignore_label=self.ignore_id, # 忽略的填充标识符
)
# 返回注意力损失和准确度
return loss_att, acc_att
4. 过滤空白嵌入
过滤掉CTC解码过程中生成的空白标记:
- 获取CTC概率的最大索引。
- 为每个批次选择非空白标记对应的编码器输出。
- 填充序列并生成相应的掩码。
python
def filter_blank_embedding(
self, ctc_probs: torch.Tensor, # CTC概率 (B, Tmax, vocab_size)
encoder_out: torch.Tensor # 编码器输出 (B, Tmax, D)
) -> Tuple[torch.Tensor, torch.Tensor]: # 返回过滤后的编码器输出和编码器掩码
"""
过滤CTC解码过程中生成的空白标记,保留非空白标记的嵌入。
Args:
ctc_probs: CTC模型输出的概率分布,形状为 (B, Tmax, vocab_size)
encoder_out: 编码器的输出,形状为 (B, Tmax, D)
Returns:
encoder_out: 过滤后的编码器输出,仅保留非空白标记的嵌入
encoder_mask: 对应的编码器掩码
"""
batch_size = encoder_out.size(0) # 获取批次大小
maxlen = encoder_out.size(1) # 获取时间步的最大长度
top1_index = torch.argmax(ctc_probs, dim=2) # 获取每个时间步上CTC概率最大的索引 (B, Tmax)
indices = []
for j in range(batch_size):
# 对于每个样本,选择非空白标记的索引
indices.append(
torch.tensor(
[i for i in range(maxlen) if top1_index[j][i] != 0]
)
)
# 根据非空白标记的索引,选择对应的编码器输出
select_encoder_out = [
torch.index_select(encoder_out[i, :, :], 0,
indices[i].to(encoder_out.device))
for i in range(batch_size)
]
# 将选择的编码器输出序列填充为相同长度
select_encoder_out = pad_sequence(select_encoder_out,
batch_first=True,
padding_value=0).to(encoder_out.device)
# 计算选择后的序列长度
xs_lens = torch.tensor([len(indices[i]) for i in range(batch_size)]).to(encoder_out.device)
T = select_encoder_out.size(1) # 获取填充后的时间步最大长度
# 创建编码器掩码
encoder_mask = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
encoder_out = select_encoder_out # 更新编码器输出为选择后的结果
return encoder_out, encoder_mask # 返回编码器输出和掩码
5. 解码
这个方法实现了解码功能,支持多种解码方法,如CTC贪婪搜索、CTC前缀束搜索、注意力解码和注意力重评分:
- 根据指定的解码方法,选择相应的解码策略。
- 通过编码器处理输入的语音数据。
- 根据CTC概率和注意力机制生成解码结果。
python
def decode(
self,
methods: List[str], # 解码方法列表
speech: torch.Tensor, # 输入语音特征 (batch, max_len, feat_dim)
speech_lengths: torch.Tensor, # 输入语音长度 (batch, )
beam_size: int, # 集束搜索的束宽
decoding_chunk_size: int = -1, # 解码块的大小
num_decoding_left_chunks: int = -1, # 剩余解码块的数量
ctc_weight: float = 0.0, # CTC得分的权重
simulate_streaming: bool = False, # 是否模拟流式解码
reverse_weight: float = 0.0, # 反向解码器的权重
context_graph: ContextGraph = None, # 上下文图
blank_id: int = 0, # 空白标记的ID
blank_penalty: float = 0.0, # 空白标记的惩罚
length_penalty: float = 0.0, # 长度惩罚
infos: Dict[str, List[str]] = None, # 额外的信息
) -> Dict[str, List[DecodeResult]]: # 返回解码结果的字典
"""
解码输入语音
Args:
methods: 使用的解码方法列表,包括以下方法:
* ctc_greedy_search
* ctc_prefix_beam_search
* attention
* attention_rescoring
speech: 输入语音特征,形状为 (batch, max_len, feat_dim)
speech_lengths: 输入语音长度,形状为 (batch, )
beam_size: 集束搜索的束宽
decoding_chunk_size: 动态块训练模型的解码块大小
<0: 使用完整块进行解码
>0: 使用固定块大小进行解码
0: 训练中使用,不允许在此处使用
simulate_streaming: 是否以流式方式进行编码器前向计算
reverse_weight: 反向解码器的权重
ctc_weight: CTC得分的权重
Returns:
各种解码方法的结果字典
"""
# 确保输入的批次大小和长度大小一致
assert speech.shape[0] == speech_lengths.shape[0]
# 确保解码块大小不为0
assert decoding_chunk_size != 0
# 前向编码器,获取编码器输出和掩码
encoder_out, encoder_mask = self._forward_encoder(
speech, speech_lengths, decoding_chunk_size,
num_decoding_left_chunks, simulate_streaming)
# 计算编码器输出的有效长度
encoder_lens = encoder_mask.squeeze(1).sum(1)
# 计算CTC概率
ctc_probs = self.ctc_logprobs(encoder_out, blank_penalty, blank_id)
results = {} # 存储解码结果
# 使用注意力机制进行解码
if 'attention' in methods:
results['attention'] = attention_beam_search(
self, encoder_out, encoder_mask, beam_size, length_penalty, infos)
# 使用CTC贪婪搜索进行解码
if 'ctc_greedy_search' in methods:
results['ctc_greedy_search'] = ctc_greedy_search(
ctc_probs, encoder_lens, blank_id)
# 使用CTC前缀集束搜索进行解码
if 'ctc_prefix_beam_search' in methods:
ctc_prefix_result = ctc_prefix_beam_search(
ctc_probs, encoder_lens, beam_size, context_graph, blank_id)
results['ctc_prefix_beam_search'] = ctc_prefix_result
# 使用注意力重评分进行解码
if 'attention_rescoring' in methods:
# 确保CTC前缀集束搜索结果存在
if 'ctc_prefix_beam_search' in results:
ctc_prefix_result = results['ctc_prefix_beam_search']
else:
ctc_prefix_result = ctc_prefix_beam_search(
ctc_probs, encoder_lens, beam_size, context_graph, blank_id)
# 过滤空白标记嵌入
if self.apply_non_blank_embedding:
encoder_out, _ = self.filter_blank_embedding(ctc_probs, encoder_out)
# 使用注意力重评分进行解码
results['attention_rescoring'] = attention_rescoring(
self, ctc_prefix_result, encoder_out, encoder_lens, ctc_weight,
reverse_weight, infos)
return results # 返回解码结果字典
其他辅助方法
1.获取CTC激活
这个方法返回编码器输出经过log_softmax变换后的CTC激活值。
python
@torch.jit.export
def ctc_activation(self, xs: torch.Tensor) -> torch.Tensor:
""" Export interface for c++ call, apply linear transform and log
softmax before ctc
Args:
xs (torch.Tensor): encoder output
Returns:
torch.Tensor: activation before ctc
"""
return self.ctc.log_softmax(xs)
2.其他导出接口
这些方法提供了模型的相关信息,如子采样率、右上下文、起始符号和结束符号。
python
@torch.jit.export
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
@torch.jit.export
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
@torch.jit.export
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
@torch.jit.export
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
3.编码器前向传播
这个方法提供了一个接口,用于分块进行编码器前向传播。它返回当前块的输出、注意力缓存和CNN缓存。
python
@torch.jit.export
def forward_encoder_chunk(
self,
xs: torch.Tensor, # chunk 输入,形状为 (b=1, time, mel-dim)
offset: int, # 编码器输出时间戳的当前偏移量
required_cache_size: int, # 下一个chunk计算所需的缓存大小
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), # Transformer/Conformer注意力中的KEY和VALUE的缓存张量
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0), # Conformer中cnn模块的缓存张量
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # 返回当前输入xs的输出、新的注意力缓存和新的cnn缓存
"""
导出接口供C++调用,给定输入chunk xs,并返回从时间0到当前chunk的输出。
Args:
xs (torch.Tensor): chunk输入,形状为 (b=1, time, mel-dim),
其中 `time == (chunk_size - 1) * subsample_rate + \
subsample.right_context + 1`
offset (int): 编码器输出时间戳的当前偏移量
required_cache_size (int): 下一个chunk计算所需的缓存大小
>=0: 实际缓存大小
<0: 表示需要所有历史缓存
att_cache (torch.Tensor): Transformer/Conformer注意力中的KEY和VALUE的缓存张量,形状为
(elayers, head, cache_t1, d_k * 2),其中 `head * d_k == hidden-dim` 且
`cache_t1 == chunk_size * num_decoding_left_chunks`。
cnn_cache (torch.Tensor): Conformer中cnn模块的缓存张量,形状为
(elayers, b=1, hidden-dim, cache_t2),其中 `cache_t2 == cnn.lorder - 1`
Returns:
torch.Tensor: 当前输入xs的输出,形状为 (b=1, chunk_size, hidden-dim)。
torch.Tensor: 下一个chunk所需的新注意力缓存,形状为
(elayers, head, ?, d_k * 2),具体取决于 required_cache_size。
torch.Tensor: 下一个chunk所需的新Conformer cnn缓存,形状与原cnn_cache相同。
"""
return self.encoder.forward_chunk(xs, offset, required_cache_size,
att_cache, cnn_cache)
4.双向解码器检查
这个方法检查解码器是否为双向解码器。
python
@torch.jit.export
def is_bidirectional_decoder(self) -> bool:
"""
Returns:
torch.Tensor: decoder output
"""
if hasattr(self.decoder, 'right_decoder'):
return True
else:
return False
前向解码器
这个方法实现了对多假设的解码操作。它将多个假设和一个编码器输出进行解码,并返回解码器输出。
python
@torch.jit.export
def forward_attention_decoder(
self,
hyps: torch.Tensor, # 从CTC前缀集束搜索中获得的假设,已经在开头填充了<sos>
hyps_lens: torch.Tensor, # 每个假设的长度
encoder_out: torch.Tensor, # 编码器输出
reverse_weight: float = 0, # 用于验证是否使用从右到左的解码器,> 0 将使用
) -> Tuple[torch.Tensor, torch.Tensor]: # 返回解码器输出
"""
供C++调用的导出接口,使用多个CTC前缀集束搜索的假设和一个编码器输出进行前向解码。
Args:
hyps (torch.Tensor): 从CTC前缀集束搜索中获得的假设,已经在开头填充了<sos>
hyps_lens (torch.Tensor): 每个假设的长度
encoder_out (torch.Tensor): 编码器输出
r_hyps (torch.Tensor): 从CTC前缀集束搜索中获得的假设,已经在开头填充了<eos>,用于从右到左解码器
reverse_weight: 用于验证是否使用从右到左解码器,> 0 将使用
Returns:
torch.Tensor: 解码器输出
"""
assert encoder_out.size(0) == 1 # 确保编码器输出的批次大小为1
num_hyps = hyps.size(0) # 获取假设的数量
assert hyps_lens.size(0) == num_hyps # 确保假设长度的数量与假设数量相同
# 将编码器输出重复num_hyps次,以匹配假设的数量
encoder_out = encoder_out.repeat(num_hyps, 1, 1)
encoder_mask = torch.ones(num_hyps, 1, encoder_out.size(1),
dtype=torch.bool, device=encoder_out.device) # 创建编码器掩码
# 处理从右到左的解码器输入
r_hyps_lens = hyps_lens - 1 # 获取从右到左的假设长度
r_hyps = hyps[:, 1:] # 移除开头的<sos>标记
max_len = torch.max(r_hyps_lens) # 获取最大假设长度
index_range = torch.arange(0, max_len, 1).to(encoder_out.device) # 创建索引范围
seq_len_expand = r_hyps_lens.unsqueeze(1) # 扩展假设长度以匹配索引范围
seq_mask = seq_len_expand > index_range # 创建序列掩码 (beam, max_len)
index = (seq_len_expand - 1) - index_range # 计算索引 (beam, max_len)
index = index * seq_mask # 应用序列掩码
r_hyps = torch.gather(r_hyps, 1, index) # 根据索引选择假设
r_hyps = torch.where(seq_mask, r_hyps, self.eos) # 替换无效的假设为<eos>
r_hyps = torch.cat([hyps[:, 0:1], r_hyps], dim=1) # 在开头添加<sos>
# 前向解码器
decoder_out, r_decoder_out, _ = self.decoder(
encoder_out, encoder_mask, hyps, hyps_lens, r_hyps, reverse_weight) # (num_hyps, max_hyps_len, vocab_size)
# 应用log_softmax
decoder_out = torch.nn.functional.log_softmax(decoder_out, dim=-1)
r_decoder_out = torch.nn.functional.log_softmax(r_decoder_out, dim=-1)
return decoder_out, r_decoder_out # 返回解码器输出