Silero-VAD模型自定义微调

代码文件:

finetune_and_save.py

python 复制代码
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import soundfile as sf
from torch.utils.data import DataLoader, Dataset

DEVICE = torch.device("cpu")
class SileroFinetuneDataset(Dataset):
    def __init__(self, csv_file, labels):
        if not os.path.exists(csv_file):
            print(f"创建演示索引文件: {csv_file}")
            pd.DataFrame([["data/test1.wav", "welcome to the ai speech recognition demo"]]).to_csv(csv_file, index=False, header=False)

        self.df = pd.read_csv(csv_file, header=None, names=['path', 'text']).dropna()
        self.char_to_idx = {char: i for i, char in enumerate(labels)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = self.df.iloc[idx]['path'].strip()
        text = str(self.df.iloc[idx]['text']).lower()

        try:
            speech, sample_rate = sf.read(path)
            if len(speech.shape) > 1:
                speech = np.mean(speech, axis=1)
            waveform = torch.from_numpy(speech).float()
        except Exception as e:
            print(f"读取 {path} 出错: {e}")
            waveform = torch.zeros(16000)

        target = torch.tensor([self.char_to_idx[c] for c in text if c in self.char_to_idx], dtype=torch.long)
        return waveform, target

def collate_fn(batch):
    waveforms, targets = zip(*batch)
    input_lengths = torch.tensor([w.shape[0] for w in waveforms], dtype=torch.long)
    target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)
    waveforms_padded = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    targets_padded = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
    return waveforms_padded, targets_padded, input_lengths, target_lengths

def run_finetuning():
    LANG = 'en'
    LR = 1e-6
    EPOCHS = 10
    SAVE_PATH = "silero_stt_finetuned.pt"
    print(f"📦 正在加载 Silero {LANG} 预训练模型 (CPU 模式)...")
    model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                           model='silero_stt',
                                           language=LANG,
                                           device=DEVICE)

    model.train()
    for param in model.parameters():
        param.requires_grad = True
    dataset = SileroFinetuneDataset('metadata.csv', model.labels)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    print(f"🏗️ 开始微调任务...")
    for epoch in range(EPOCHS):
        epoch_loss = 0
        for batch in dataloader:
            waveforms, targets, input_lengths, target_lengths = batch
            model.train()
            torch.set_grad_enabled(True)
            optimizer.zero_grad()
            log_probs = model(waveforms)
            if not log_probs.requires_grad:
                log_probs.requires_grad_(True)
            log_probs_trans = log_probs.transpose(0, 1)
            output_lengths = torch.full(size=(waveforms.size(0),),
                                        fill_value=log_probs_trans.size(0),
                                        dtype=torch.long)
            loss = criterion(log_probs_trans, targets, output_lengths, target_lengths)
            if loss.requires_grad:
                loss.backward()      # 计算梯度
                optimizer.step()     # 更新权重
                epoch_loss += loss.item()
            else:
                print("⚠️ 警告:当前 Batch 无法计算梯度,请检查模型是否被锁定。")

        print(f"🔹 轮次 {epoch+1}/{EPOCHS} | 平均误差: {epoch_loss/len(dataloader):.6f}")
        # if (epoch_loss / len(dataloader)) < 0.05:
        #     print(f"🎯 达到理想误差值,提前停止训练以防止过拟合。")
        #     break

    torch.save(model.state_dict(), SAVE_PATH)
    print(f"\n✅ 权重已生成并保存至: {SAVE_PATH}")

if __name__ == "__main__":
    os.makedirs("data", exist_ok=True)
    run_finetuning()

目录结构:

复制代码
-- finetune_and_save.py
-- metadata.csv
-- silero_stt_finetuned.pt
-- data
---- data/test1.wav
---- data/test2.wav
---- data/test3.wav
相关推荐
风吹夏回7 小时前
Python 全局异常处理:从“满屏 try-except”到优雅兜底
开发语言·python
小熊Coding8 小时前
Python爬取当当网二手图书项目实战!
开发语言·爬虫·python·beautifulsoup·requests·二手图书
秋98 小时前
Java项目运行5天左右自动宕机:系统性定位与解决方案
java·开发语言·python
小江的记录本8 小时前
【JVM虚拟机】垃圾回收GC:垃圾收集器:CMS:核心原理、回收流程、优缺点、废弃原因(附《思维导图》+《面试高频考点清单》)
java·jvm·后端·python·spring·面试·maven
EasyCVR8 小时前
国标GB28181视频监控平台EasyCVR行业解决方案深度解读——雪亮工程、智慧城市与智慧交通
人工智能·音视频·智慧城市
田里的水稻9 小时前
OE_ubuntu26.04与宿主机之间复制粘贴内容
人工智能·python·机器人
jiayong2310 小时前
02 创建虚拟环境
python
旺仔来了10 小时前
不联网的Linux下部署python环境
linux·开发语言·python
小江的记录本10 小时前
【JVM虚拟机】垃圾回收GC:垃圾回收算法:标记-清除、标记-复制、标记-整理、分代收集(附《思维导图》+《面试高频考点清单》)
java·jvm·后端·python·算法·安全·面试
IP搭子来一个10 小时前
爬虫采集大量返回 403、429,到底卡在哪一环?
网络·爬虫·python