说的神马?基于 Wav2Vec2 的端到端中文语音识别系统

说的神马?基于 Wav2Vec2 的端到端中文语音识别系统

代码详见:https://github.com/xiaozhou-alt/Chinese_Speech_Recognition


文章目录

  • [说的神马?基于 Wav2Vec2 的端到端中文语音识别系统](#说的神马?基于 Wav2Vec2 的端到端中文语音识别系统)
  • 一、项目介绍
  • 二、文件夹结构
  • 三、数据集介绍
  • [四、Wav2Vec2 模型介绍](#四、Wav2Vec2 模型介绍)
    • [1. 特征编码器:从原始波形到声学特征](#1. 特征编码器:从原始波形到声学特征)
    • [2. 上下文网络:捕捉语音序列依赖关系](#2. 上下文网络:捕捉语音序列依赖关系)
    • [3. CTC 头与损失函数:解决时序对齐难题](#3. CTC 头与损失函数:解决时序对齐难题)
    • [4. 自监督预训练机制:无标注数据的知识积累](#4. 自监督预训练机制:无标注数据的知识积累)
  • 五、项目实现
    • [1. 评估指标初始化](#1. 评估指标初始化)
    • [2. 数据集准备​](#2. 数据集准备)
      • [2.1 数据路径与文件列表](#2.1 数据路径与文件列表)
      • [2.2 训练集与验证集划分](#2.2 训练集与验证集划分)
    • [3. 自定义数据集类](#3. 自定义数据集类)
      • [3.1 初始化方法 (`init`)](#3.1 初始化方法 (__init__))
      • [3.2 长度方法 (`len`)](#3.2 长度方法 (__len__))
      • [3.3 数据获取方法 (`getitem`)](#3.3 数据获取方法 (__getitem__))
    • [4. 模型加载与配置​](#4. 模型加载与配置)
      • [4.1 Wav2Vec2 模型结构​](#4.1 Wav2Vec2 模型结构)
      • [4.2 关键参数解析​](#4.2 关键参数解析)
      • [4.3 模型初始化​](#4.3 模型初始化)
    • [5. 数据加载器配置​](#5. 数据加载器配置)
      • [5.1 自定义数据整理函数 (data_collator)​](#5.1 自定义数据整理函数 (data_collator))
      • [5.2 数据加载器创建​](#5.2 数据加载器创建)
    • [6. 优化器与学习率调度器](#6. 优化器与学习率调度器)
      • [6.1 优化器选择​](#6.1 优化器选择)
      • [6.2 学习率调度器​](#6.2 学习率调度器)
    • [7. 验证集评估函数](#7. 验证集评估函数)
      • [7.1 评估流程​](#7.1 评估流程)
      • [7.2 CTC 解码​](#7.2 CTC 解码)
    • [8. 开始训练!](#8. 开始训练!)
      • [8.1 训练阶段​](#8.1 训练阶段)
      • [8.2 验证阶段​](#8.2 验证阶段)
      • [8.3 训练记录与早停检查​](#8.3 训练记录与早停检查)
    • [9. 样本可视化测试​](#9. 样本可视化测试)
      • [9.1 样本选择与处理​](#9.1 样本选择与处理)
      • [9.2 模型预测​](#9.2 模型预测)
  • 六、结果展示

一、项目介绍

本项目是一个基于深度学习的中文语音识别系统,旨在将中文语音信号准确转换为对应的文本内容。该系统利用 Hugging Face 的 Transformers 库中预训练的 Wav2Vec2 模型进行微调,专门针对中文语音识别任务进行优化。

Wav2Vec2 是 Facebook AI 研究院提出的一种先进的自监督学习语音模型,能够从大量未标注的语音数据中学习语音表示。本项目基于预训练的 "wav2vec2-large-xlsr-53-chinese-zh-cn" 模型进行微调,该模型已经在大量中文语音数据上进行了预训练,非常适合中文语音识别任务。

系统采用了 CTC(Connectionist Temporal Classification)损失函数,这是一种常用于序列到序列学习的损失函数,特别适合语音识别这类输入和输出序列长度不一致的任务。

项目实现了完整的语音识别流程,包括:

  • 音频数据预处理(格式转换、重采样、声道处理)
  • 自定义数据集加载和处理
  • 模型训练与微调
  • 模型评估(使用 WER 指标)

二、文件夹结构

c 复制代码
Chinese_Speech_Recognition\
├── README.md
├── data.ipynb                  # 用于数据分析和预处理
├── data\
    ├── data_thchs30\           # THCHS-30中文语音数据集
        ├── README-data.md      # 数据集说明文档
        └── data\               # 实际音频数据存放目录
├── log\
├── output\                     # 输出结果目录
    ├── model\                  # 模型存储目录
        ├── sound.pth           # 训练好的模型权重文件
        └── wav2vec2-large-xlsr-53-chinese-zh-cn\  # 预训练模型目录
    ├── pic\                    # 图片和可视化结果目录
    ├── test_results.xlsx       # 测试结果数据表
    └── training_history.xlsx   # 训练历史记录数据表
├── requirements.txt 
└── train.py

三、数据集介绍

数据集下载可前往:THCHS-30

本项目使用 THCHS-30 中文语音数据集进行模型训练和评估。THCHS-30 是一个常用的中文语音识别数据集,包含 30 小时的语音数据。

数据集的组织结构如下:

  • 音频文件:以 WAV 格式存储,采样率可能有所不同
  • 标签文件:与音频文件同名的.trn 文件,包含对应的中文文本标签

数据集的音频时长与索引分布如下所示:

四、Wav2Vec2 模型介绍

1. 特征编码器:从原始波形到声学特征

特征编码器是 Wav2Vec2 的 "耳朵",负责将原始音频波形转换为有意义的声学特征表示。它采用 7 层卷积神经网络结构,通过逐步下采样捕捉不同尺度的音频特征:

x _ l + 1 = GELU ( Conv ( x _ l , k = 3 , s = 2 ) ) x\_{l+1} = \text{GELU}(\text{Conv}(x\_l, k=3, s=2)) x_l+1=GELU(Conv(x_l,k=3,s=2))

其中 x _ l x\_l x_l 表示第 l l l 层的输入特征, k k k 为卷积核大小, s s s 为步长。前 6 6 6 层卷积使用步长为 2 2 2 的 3×3 \textbf{3×3} 3×3 卷积核,最后一层使用步长为 1 1 1 的卷积核,将原始 16 k H z 16kHz 16kHz 音频(每秒 16000 16000 16000 个采样点)压缩为每秒 50 50 50 个特征帧,实现了 320 320 320 倍的时间维度压缩。

在代码实现中,我们通过librosa.resample将所有音频统一处理为 16 k H z 16kHz 16kHz 采样率,为特征编码器提供标准化输入,这一步对应着让 所有 "说话人" 使用相同的 "语速" 发言,便于后续处理。

🤓🤓🤓小周有话说

想象你在听一段 中文新闻播报 ,原始音频就像一条连续不断的声波河流。特征编码器就像一系列精密滤网,第一层滤网捕捉 "你好""你" 的基础声波振动,经过多层过滤后,最终提取出 "你""好" 等音节的核心声学特征,过滤掉背景噪音等无关信息。这就像我们听人说话时,自动忽略环境杂音,只关注关键语音信息 的过程。

2. 上下文网络:捕捉语音序列依赖关系

上下文网络是 Wav2Vec2 的 "大脑",由多个 Transformer 层组成,通过自注意力机制捕捉语音序列中的长距离依赖关系。其核心是多头自注意力计算:

Attention ( Q , K , V ) = softmax ( Q K T d k ) V MultiHead ( Q , K , V ) = Concat ( head 1 , ... , head h ) W O \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \\ \ \\ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O Attention(Q,K,V)=softmax(dk QKT)V MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中 Q Q Q、 K K K、 V V V 分别为查询、键和值矩阵, d k d_k dk 为特征维度, h h h 为注意力头数。这种机制允许模型在处理 "北京天安门" 这样的短语时,同时关注 "北京""天安门" 之间的语义关联。

这种机制特别适合中文语音识别,因为中文词语之间没有明显分隔,需要通过上下文理解 "下雨天留客天留我不留" 这样的歧义句子

🤓🤓🤓小周有话说

假设我们识别 "我爱北京天安门" 这句话,当处理到 "北京" 时,上下文网络会像人类理解语言一样,自动联想到后面可能出现的 "天安门"。自注意力权重就像句子中词语间的 "关联强度表""北京" "天安门" 之间的权重会显著高于与其他词的关联。代码中的attention_dropout=0.1参数就像给这个联想过程增加一点 "遗忘概率",防止模型过度依赖某些固定关联,增强泛化能力。

3. CTC 头与损失函数:解决时序对齐难题

CTC(Connectionist Temporal Classification)头是 Wav2Vec2 的 "翻译官",负责将声学特征序列转换为文本序列,并通过 CTC 损失函数优化模型参数:

L = − log ⁡ ( ∑ π ∈ B ( l ) P ( π ∣ x ) ) L = -\log\left(\sum_{\pi \in \mathcal{B}(l)} P(\pi | x)\right) L=−log π∈B(l)∑P(π∣x)

其中 B ( l ) \mathcal{B}(l) B(l) 表示所有与标签序列 l l l 对齐的路径集合, P ( π ∣ x ) P(\pi | x) P(π∣x) 为路径 π \pi π 的概率。CTC 引入空白标签 ∅ 解决语音与文本的时序不对齐问题,允许模型学习发音 - 文本的多对多映射关系

这种特性让模型特别适合处理中文语音中常见的 连读吞音 现象,比如将 "不知道" 的实际发音正确转换为标准文本,而不会被发音的模糊边界干扰

🤓🤓🤓小周有话说

当我们说 "中国" 这个词时,实际发音可能是 "中 ------ 国"(中间有自然停顿),也可能是快速连读 "中国"。CTC 头就像一位 智能校对员 ,能自动忽略多余的停顿(对应空白标签 ),无论你说得快还是慢,都能正确转换为 "中国" 这两个字。代码中ctc_loss_reduction="mean"参数表示采用平均损失策略,就像老师批改作业时,综合评估所有错题的平均错误程度来打分。

4. 自监督预训练机制:无标注数据的知识积累

Wav2Vec2 的成功很大程度上归功于其创新的自监督预训练机制,通过对比损失和掩码预测任务学习通用语音表示:

L contrastive = − log ⁡ ( e s ( z i , c i ) / τ ∑ j = 1 N e s ( z j , c i ) / τ ) L_{\text{contrastive}} = -\log\left(\frac{e^{s(z_i, c_i)/\tau}}{\sum_{j=1}^N e^{s(z_j, c_i)/\tau}}\right) Lcontrastive=−log(∑j=1Nes(zj,ci)/τes(zi,ci)/τ)

其中 s ( z i , c i ) s(z_i, c_i) s(zi,ci) 表示特征 z i z_i zi 与上下文向量 c i c_i ci 的相似度, τ \tau τ 为温度参数。预训练时随机掩码 50 % 50\% 50% 的时间步,让模型从上下文预测被掩码的声学特征。

🤓🤓🤓小周有话说

这就像我们学习中文听力时,即使某些词语 没听清被掩码 ),也能 根据上下文推测 出内容。模型在预训练阶段 "收听" 了大量无标注语音数据,就像婴儿在学会说话前先 "聆听" 周围环境的语言,积累了丰富的语音知识。我们使用的jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn模型已经通过这种方式学习了中文语音的基本规律,再通过 THCHS-30 数据集微调,就像给模型 "上中文强化班",使其专门适应中文语音识别任务

五、项目实现

1. 评估指标初始化

语音识别任务中最常用的评估指标是词错误率 (WER):

python 复制代码
# 加载WER评估指标
wer_metric = evaluate.load("wer")

词错误率 (WER) 是衡量语音识别系统性能的核心指标,计算方法为:

W E R = ( 替换错误 + 插入错误 + 删除错误 ) / 参考词总数 WER = (替换错误 + 插入错误 + 删除错误) / 参考词总数 WER=(替换错误+插入错误+删除错误)/参考词总数

值越低表示系统性能越好,0 表示完全正确识别。

2. 数据集准备​

2.1 数据路径与文件列表

python 复制代码
# 数据集路径
data_dir = "/kaggle/input/chinese-speech-to-textthchs30/data/data"  # 修改为你的数据路径

# 获取所有文件
wav_files = [f for f in os.listdir(data_dir) if f.endswith('.wav')]
print(f"Found {len(wav_files)} WAV files")

2.2 训练集与验证集划分

python 复制代码
# 划分训练集和验证集 (80%训练, 20%验证)
train_size = int(0.8 * len(wav_files))
val_size = len(wav_files) - train_size
train_files, val_files = random_split(wav_files, [train_size, val_size])

print(f"Training set: {len(train_files)} samples")
print(f"Validation set: {len(val_files)} samples")

数据集划分策略:​

  • 采用 8 : 2 8:2 8:2 的比例划分训练集和验证集
  • 使用random_split函数进行随机划分
  • 训练集用于模型参数学习,验证集用于监控训练过程中的过拟合情况

3. 自定义数据集类

P y T o r c h PyTorch PyTorch 通过自定义Dataset类来处理数据加载,这里我们创建专门的语音数据集类:

python 复制代码
# 创建自定义数据集类
class SpeechDataset(Dataset):
    def __init__(self, file_list, data_dir):
        self.file_list = file_list
        self.data_dir = data_dir
        self.processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn")
        
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, idx):
        wav_file = self.file_list[idx]
        wav_path = os.path.join(self.data_dir, wav_file)
        trn_path = os.path.join(self.data_dir, wav_file + ".trn")
        
        # 读取音频文件
        try:
            speech_array, sampling_rate = sf.read(wav_path)
            speech_array = speech_array.astype(np.float32)
        except Exception as e:
            print(f"读取音频文件错误 {wav_path}: {e}")
            return None
        
        # 如果音频是双声道,转换为单声道
        if len(speech_array.shape) > 1:
            speech_array = np.mean(speech_array, axis=1)
        
        # 重采样到16kHz
        if sampling_rate != 16000:
            speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16000)
        
        # 读取标签
        try:
            with open(trn_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                # 使用第一行作为标签文本
                text = lines[0].strip() if lines else ""
        except Exception as e:
            print(f"读取标签文件错误 {trn_path}: {e}")
            text = ""
        
        # 处理音频和标签
        inputs = self.processor(
            speech_array, 
            sampling_rate=16000, 
            text=text,
            padding=True,
            return_tensors="pt"
        )
        
        # 移除批次维度
        inputs = {key: inputs[key].squeeze(0) for key in inputs}
        
        return inputs

3.1 初始化方法 (__init__)

  • 接收文件列表和数据目录
  • 加载 W a v 2 V e c 2 Wav2Vec2 Wav2Vec2 处理器 (processor),这是连接原始数据与模型输入的关键组件
  • 使用的预训练模型是专门为中文优化的jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn
    W a v 2 V e c 2 P r o c e s s o r Wav2Vec2Processor Wav2Vec2Processor 整合了特征提取器和 tokenizer,能同时处理音频数据和文本标签:
  • 对音频:将原始波形转换为模型需要的特征表示
  • 对文本:将文字转换为模型可理解的 token ID

3.2 长度方法 (__len__)

简单返回数据集中样本的数量,是 PyTorch Dataset 的必需方法。

3.3 数据获取方法 (__getitem__)

这是数据集类的核心,负责加载和预处理单个样本:

  1. 音频加载与预处理

    l i b r o s a librosa librosa 的重采样使用基于 FFT 的方法,能在 改变采样率的同时保持音频质量,确保不同来源的音频具有一致的时间尺度。

    • 使用soundfile读取音频文件,获取音频数组和采样率
    • 处理双声道音频:通过求平均值转换为单声道
    • 重采样:统一将音频采样率转换为 16 k H z 16kHz 16kHz,这是 Wav2Vec2 模型的期望输入采样率
  2. 标签加载

    • 从对应的.trn文件中读取文本标签
    • 处理可能的文件读取错误,提高代码健壮性
  3. 数据处理

    • 使用 processor 同时处理音频和文本
    • 转换为 PyTorch 张量格式
    • 移除多余的批次维度,确保单个样本的正确格式

4. 模型加载与配置​

Wav2Vec2 是一种基于自监督学习的语音表示模型,在语音识别任务上表现优异:

python 复制代码
# 加载预训练模型
model = Wav2Vec2ForCTC.from_pretrained(
    "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    attention_dropout=0.1,
    hidden_dropout=0.1,
    feat_proj_dropout=0.0,
    mask_time_prob=0.05,
    layerdrop=0.1,
    ctc_loss_reduction="mean",
    pad_token_id=processor.tokenizer.pad_token_id,
    vocab_size=len(processor.tokenizer)
)

# 设置模型为训练模式
model.train()
model.to(device)

4.1 Wav2Vec2 模型结构​

Wav2Vec2ForCTC 是专为 CTC 损失设计的 Wav2Vec2 变体,其结构包括:​

  • 特征编码器:将原始音频转换为高级特征表示
  • 上下文网络:捕获长时依赖关系
  • CTC 头:输出每个时间步的字符概率分布

4.2 关键参数解析​

  • Dropout 参数attention_dropouthidden_dropout等用于防止过拟合
  • mask_time_prob:训练时对时间步进行掩码的概率,是 Wav2Vec2 的自监督学习遗产
  • layerdrop:随机丢弃整个层的概率,增强模型鲁棒性
  • ctc_loss_reduction :CTC 损失的聚合方式,这里使用 "mean"​
  • pad_token_id:指定填充 token 的 ID,用于处理不等长序列

4.3 模型初始化​

  • 加载预训练权重,这是迁移学习的关键
  • 设置为训练模式 (model.train())
  • 将模型移动到之前确定的计算设备 (GPU 或 CPU)

5. 数据加载器配置​

PyTorch 的 DataLoader 负责批处理、打乱数据和并行加载:

python 复制代码
# 定义数据整理函数
def data_collator(batch):
    # 过滤掉None值
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    
    # 找出批次中最长的音频长度
    max_length = max(item['input_values'].shape[0] for item in batch)
    
    # 找出批次中最长的标签长度
    max_label_length = max(item['labels'].shape[0] for item in batch)
    
    # 初始化填充后的张量
    padded_input_values = []
    padded_attention_mask = []
    padded_labels = []
    
    for item in batch:
        # 填充音频输入
        input_length = item['input_values'].shape[0]
        padding_length = max_length - input_length
        
        if padding_length > 0:
            padded_input = torch.cat([
                item['input_values'],
                torch.zeros(padding_length, dtype=item['input_values'].dtype)
            ])
            padded_attention = torch.cat([
                item['attention_mask'],
                torch.zeros(padding_length, dtype=item['attention_mask'].dtype)
            ])
        else:
            padded_input = item['input_values']
            padded_attention = item['attention_mask']
        
        # 填充标签
        label_length = item['labels'].shape[0]
        label_padding_length = max_label_length - label_length
        
        if label_padding_length > 0:
            padded_label = torch.cat([
                item['labels'],
                torch.full((label_padding_length,), -100, dtype=item['labels'].dtype)
            ])
        else:
            padded_label = item['labels']
        
        padded_input_values.append(padded_input)
        padded_attention_mask.append(padded_attention)
        padded_labels.append(padded_label)
    
    return {
        'input_values': torch.stack(padded_input_values),
        'attention_mask': torch.stack(padded_attention_mask),
        'labels': torch.stack(padded_labels)
    }

# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=data_collator)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, collate_fn=data_collator)

5.1 自定义数据整理函数 (data_collator)​

由于语音数据长度不一,需要自定义整理函数来处理批处理:​

  1. 过滤无效样本 :移除加载失败的样本 (None 值)​

  2. 确定填充长度:找出批次中最长的音频和标签长度​

  3. 音频填充:​

    • 对短于最大长度的音频补零
    • 相应调整注意力掩码 (attention mask),指示有效音频部分
  4. 标签填充:​

    • 使用 − 100 -100 −100 填充标签,这是 PyTorch 中 CTC 损失忽略的特殊值
    • 确保损失计算时不考虑填充部分

5.2 数据加载器创建​

  • batch_size=4:训练批次大小,根据 GPU 显存调整
  • shuffle=True:训练集打乱顺序,增强训练随机性
  • collate_fn=data_collator:使用自定义整理函数
  • 验证集使用较小批次,且不打乱顺序

6. 优化器与学习率调度器

python 复制代码
# 设置优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# 学习率调度器
num_epochs = 8  # 适当增加训练轮数以获得更好效果
num_training_steps = num_epochs * len(train_loader)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_training_steps)

6.1 优化器选择​

使用 AdamW 优化器,这是 Adam 优化器的变体,增加了权重衰减 (weight decay) 正则化,有助于防止过拟合。初始学习率设置为 1e-4,适合预训练模型的微调。​

6.2 学习率调度器​

采用 余弦退火 调度器:​

  • 学习率随训练步骤按余弦曲线逐渐降低
  • T_max设置为总训练步数,完整经历一个余弦周期
  • 这种策略通常比固定学习率收敛更快,泛化性能更好

7. 验证集评估函数

python 复制代码
# 定义验证集评估函数(计算Loss和WER)
def evaluate_on_validation_set(model, dataloader, processor, device):
    model.eval()
    total_loss = 0
    all_predictions = []
    all_references = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating on validation set"):
            if batch is None:
                continue
                
            input_values = batch['input_values'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            # 计算损失
            with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
                outputs = model(input_values=input_values, attention_mask=attention_mask, labels=labels)
                total_loss += outputs.loss.item()
            
            # 生成预测
            logits = outputs.logits
            predicted_ids = torch.argmax(logits, dim=-1)
            
            # 解码预测结果和标签(转换为文本)
            predictions = processor.batch_decode(predicted_ids)
            references = processor.batch_decode(labels, group_tokens=False)  # 不解码成组,保持原始标签
            
            all_predictions.extend(predictions)
            all_references.extend(references)
    
    # 计算平均损失和WER
    avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0
    wer = wer_metric.compute(predictions=all_predictions, references=all_references)
    
    return avg_loss, wer, all_predictions, all_references

7.1 评估流程​

模型状态设置:将模型设为评估模式 (model.eval()),关闭 dropout 等训练特有的操作​

梯度禁用:使用torch.no_grad()上下文管理器,减少内存占用并加速计算​

批量处理:​

将数据移至计算设备​

使用混合精度计算损失​

生成模型预测(logits)​

解码过程:​

将模型输出的 logits 通过 argmax 获取预测的 token ID​

使用 processor 将 token ID 转换为文本(解码)​

特别注意标签解码使用group_tokens=False,确保正确计算 WER​

指标计算:计算平均损失和整体 WER​

7.2 CTC 解码​

代码中使用的是贪婪解码(取每个时间步概率最大的 token),这是一种简单高效的解码策略。更复杂的解码策略(如 beam search)可以进一步提高性能,但计算成本更高。

8. 开始训练!

python 复制代码
# 训练循环
for epoch in range(num_epochs):
    # 训练阶段
    model.train()
    train_loss = 0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    
    for batch in train_bar:
        if batch is None:
            continue
            
        # 将数据移动到设备
        input_values = batch['input_values'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # 前向传播 - 使用混合精度
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
            outputs = model(input_values=input_values, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
        
        # 反向传播 - 使用混合精度
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        
        train_loss += loss.item()
        train_bar.set_postfix(loss=loss.item())
    
    avg_train_loss = train_loss / len(train_loader) if len(train_loader) > 0 else 0
    
    # 验证阶段(计算Loss和WER)
    avg_val_loss, val_wer, _, _ = evaluate_on_validation_set(model, val_loader, processor, device)
    
    # 记录训练历史
    train_history.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss,
        'val_wer': val_wer,  # 记录WER
        'learning_rate': lr_scheduler.get_last_lr()[0] if len(train_loader) > 0 else 0
    })
    
    # 打印当前epoch的评估结果
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {avg_val_loss:.4f}")
    print(f"Val WER: {val_wer:.4f} (越低越好,0表示完全正确)")
    
    # 早停检查(基于WER)
    if val_wer < best_val_wer:
        best_val_wer = val_wer
        patience_counter = 0
        # 保存最佳模型
        torch.save(model.state_dict(), "/kaggle/working/best_model.pth")
        print(f"保存最佳模型 (WER: {best_val_wer:.4f})")
    else:
        patience_counter += 1
        print(f"早停计数器: {patience_counter}/{early_stopping_patience}")
        if patience_counter >= early_stopping_patience:
            print(f"早停触发,在第 {epoch+1} 轮停止训练")
            break

8.1 训练阶段​

  1. 模型状态 :设置为训练模式 (model.train())
  2. 进度跟踪:使用 tqdm 显示训练进度和当前损失
  3. 批量训练步骤
    • 数据移至计算设备
    • 梯度清零 (optimizer.zero_grad())
    • 前向传播:使用混合精度 (autocast) 计算输出和损失
    • 反向传播:使用 scaler 处理混合精度梯度
    • 参数更新:scaler.step(optimizer)确保安全更新
    • 学习率调整:调度器更新学习率
  4. 损失累积:计算整个 epoch 的平均训练损失

8.2 验证阶段​

每个 epoch 结束后,在验证集上评估模型性能:​

  • 调用前面定义的evaluate_on_validation_set函数
  • 获取验证损失和 WER

8.3 训练记录与早停检查​

  • 记录每个 epoch 的关键指标:损失、WER、学习率
  • 打印 epoch 总结,可视化训练进展
  • 早停检查:如果 WER 改善则保存模型并重置计数器,否则递增计数器
  • 当计数器达到耐心值时,停止训练

训练示例输出:

Using device: cuda

混合精度训练: 启用(需要CUDA支持)

Found 13388 WAV files

Training set: 10710 samples

Validation set: 2678 samples

Epoch 1/8

Train Loss: 0.8957

Val Loss: 0.3803

Val WER: 0.4495 (越低越好,0表示完全正确)

保存最佳模型 (WER: 0.4495)

...

Epoch 8/8

Train Loss: 0.1346

Val Loss: 0.0532

Val WER: 0.0584 (越低越好,0表示完全正确)

保存最佳模型 (WER: 0.0584)

最终模型在验证集上的表现:

最终验证Loss: 0.0532

最终验证WER: 0.0584 (错误率:5.84%)

9. 样本可视化测试​

为直观展示模型性能,随机选择部分验证样本进行详细测试:

python 复制代码
# 随机选择10个验证样本进行详细测试
val_file_list = [val_files.dataset[i] for i in val_files.indices]
test_samples = random.sample(val_file_list, min(10, len(val_file_list)))

# 创建测试结果列表
test_results = []

# 进行详细测试
for i, wav_file in enumerate(test_samples):
    wav_path = os.path.join(data_dir, wav_file)
    trn_path = os.path.join(data_dir, wav_file + ".trn")
    
    # 读取真实文本
    try:
        with open(trn_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            true_text = lines[0].strip() if lines else ""
    except Exception as e:
        print(f"读取标签文件错误 {trn_path}: {e}")
        true_text = ""
    
    # 读取和处理音频
    try:
        speech_array, sampling_rate = sf.read(wav_path)
        speech_array = speech_array.astype(np.float32)
    except Exception as e:
        print(f"读取音频文件错误 {wav_path}: {e}")
        continue
    
    if len(speech_array.shape) > 1:
        speech_array = np.mean(speech_array, axis=1)
    
    if sampling_rate != 16000:
        speech_array = librosa.resample(speech_array, orig_sr=sampling_rate, target_sr=16000)
    
    # 模型预测
    inputs = processor(speech_array, sampling_rate=16000, return_tensors="pt", padding=True)
    
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
            logits = model(inputs.input_values.to(device)).logits
    
    predicted_ids = torch.argmax(logits, dim=-1)
    prediction = processor.batch_decode(predicted_ids)[0]
    
    # 保存结果
    test_results.append({
        'audio_file': wav_file,
        'true_text': true_text,
        'predicted_text': prediction,
        'audio_array': speech_array,
        'sampling_rate': 16000
    })
    
    print(f"\nSample {i+1}:")
    print(f"True: {true_text}")
    print(f"Predicted: {prediction}")
    print("-" * 50)

9.1 样本选择与处理​

  • 从验证集中随机选择 10 个样本
  • 对每个样本进行完整处理:音频加载、预处理、文本读取

9.2 模型预测​

  • 使用 processor 准备模型输入
  • 禁用梯度计算加速推理
  • 使用混合精度提高效率
  • 模型输出 logits 后,通过 argmax 获取预测的 token ID​
  • 解码 token ID 为文本,得到最终识别结果

六、结果展示

训练的历史记录 Loss 和 WER 如下所示:

在验证集上得到的部分测试结果如下(下面的语音是假的):

这个才是真的 (因为这个是视频):

中文语音转换

如果你喜欢我的文章,不妨给小周一个免费的点赞和关注吧!

相关推荐
YJlio几秒前
Registry Usage (RU) 学习笔记(15.5):注册表内存占用体检与 Hive 体量分析
服务器·windows·笔记·python·学习·tcp/ip·django
奔波霸的伶俐虫2 分钟前
redisTemplate.opsForList()里面方法怎么用
java·开发语言·数据库·python·sql
datamonday3 分钟前
[EAI-037] π0.6* 基于RECAP方法与优势调节的自进化VLA机器人模型
人工智能·深度学习·机器人·具身智能·vla
Toky丶8 分钟前
【文献阅读】Pt2-Llm: Post-Training Ternarization For Large Language Models
人工智能·语言模型·自然语言处理
梵得儿SHI8 分钟前
(第七篇)Spring AI 核心技术攻坚:国内模型深度集成与国产化 AI 应用实战指南
java·人工智能·spring·springai框架·国产化it生态·主流大模型的集成方案·麒麟系统部署调优
longze_79 分钟前
生成式UI与未来AI交互变革
人工智能·python·ai·ai编程·cursor·蓝湖
weixin_4380774911 分钟前
CS336 Assignment 4 (data): Filtering Language Modeling Data 翻译和实现
人工智能·python·语言模型·自然语言处理
合方圆~小文12 分钟前
工业摄像头工作原理与核心特性
数据库·人工智能·模块测试
小郭团队12 分钟前
未来PLC会消失吗?会被嵌入式系统取代吗?
c语言·人工智能·python·嵌入式硬件·架构
yesyesido13 分钟前
智能文件格式转换器:文本/Excel与CSV无缝互转的在线工具
开发语言·python·excel