Code file:
finetune_and_save.py
```python
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import soundfile as sf
from torch.utils.data import DataLoader, Dataset
DEVICE = torch.device("cpu")
class SileroFinetuneDataset(Dataset):
    def __init__(self, csv_file, labels):
        if not os.path.exists(csv_file):
            print(f"Creating demo index file: {csv_file}")
            pd.DataFrame([["data/test1.wav", "welcome to the ai speech recognition demo"]]).to_csv(csv_file, index=False, header=False)
        self.df = pd.read_csv(csv_file, header=None, names=['path', 'text']).dropna()
        self.char_to_idx = {char: i for i, char in enumerate(labels)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = self.df.iloc[idx]['path'].strip()
        text = str(self.df.iloc[idx]['text']).lower()
        try:
            speech, sample_rate = sf.read(path)
            if len(speech.shape) > 1:
                speech = np.mean(speech, axis=1)  # down-mix multi-channel audio to mono
            waveform = torch.from_numpy(speech).float()
        except Exception as e:
            print(f"Error reading {path}: {e}")
            waveform = torch.zeros(16000)  # fall back to one second of silence at 16 kHz
        # Map each transcript character to its label index, skipping characters not in the label set
        target = torch.tensor([self.char_to_idx[c] for c in text if c in self.char_to_idx], dtype=torch.long)
        return waveform, target
def collate_fn(batch):
    waveforms, targets = zip(*batch)
    input_lengths = torch.tensor([w.shape[0] for w in waveforms], dtype=torch.long)
    target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long)
    # Pad every waveform and target in the batch to the longest one
    waveforms_padded = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    targets_padded = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
    return waveforms_padded, targets_padded, input_lengths, target_lengths
def run_finetuning():
    LANG = 'en'
    LR = 1e-6
    EPOCHS = 10
    SAVE_PATH = "silero_stt_finetuned.pt"

    print(f"📦 Loading the Silero {LANG} pretrained model (CPU mode)...")
    model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                           model='silero_stt',
                                           language=LANG,
                                           device=DEVICE)
    model.train()
    for param in model.parameters():
        param.requires_grad = True  # unfreeze every parameter for full fine-tuning

    dataset = SileroFinetuneDataset('metadata.csv', model.labels)
    dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

    print("🏗️ Starting the fine-tuning run...")
    for epoch in range(EPOCHS):
        epoch_loss = 0
        for batch in dataloader:
            waveforms, targets, input_lengths, target_lengths = batch
            model.train()
            torch.set_grad_enabled(True)
            optimizer.zero_grad()
            log_probs = model(waveforms)
            if not log_probs.requires_grad:
                log_probs.requires_grad_(True)
            # CTCLoss expects (time, batch, classes), so move the time axis to the front
            log_probs_trans = log_probs.transpose(0, 1)
            output_lengths = torch.full(size=(waveforms.size(0),),
                                        fill_value=log_probs_trans.size(0),
                                        dtype=torch.long)
            loss = criterion(log_probs_trans, targets, output_lengths, target_lengths)
            if loss.requires_grad:
                loss.backward()   # compute gradients
                optimizer.step()  # update weights
                epoch_loss += loss.item()
            else:
                print("⚠️ Warning: this batch produced no gradients; check whether the model is frozen.")
        print(f"🔹 Epoch {epoch+1}/{EPOCHS} | average loss: {epoch_loss/len(dataloader):.6f}")
        # if (epoch_loss / len(dataloader)) < 0.05:
        #     print("🎯 Target loss reached; stopping early to avoid overfitting.")
        #     break

    torch.save(model.state_dict(), SAVE_PATH)
    print(f"\n✅ Weights saved to: {SAVE_PATH}")
if __name__ == "__main__":
    os.makedirs("data", exist_ok=True)
    run_finetuning()
```
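Once the script has produced `silero_stt_finetuned.pt`, the weights can be loaded back into a freshly downloaded Silero model for a quick sanity check. The sketch below is not part of the original script and rests on two assumptions: that the hub-loaded model supports `load_state_dict()`, and that the `utils` tuple follows the silero-models README layout (`read_batch`, `split_into_batches`, `read_audio`, `prepare_model_input`).

```python
# inference_check.py - minimal sketch for sanity-checking the fine-tuned weights.
# Assumes the hub-loaded Silero model accepts load_state_dict() and that utils
# unpacks as in the silero-models README.
import torch

DEVICE = torch.device("cpu")

model, decoder, utils = torch.hub.load(repo_or_dir='snakers4/silero-models',
                                       model='silero_stt',
                                       language='en',
                                       device=DEVICE)
(read_batch, split_into_batches, read_audio, prepare_model_input) = utils

# Load the fine-tuned weights produced by finetune_and_save.py
state_dict = torch.load("silero_stt_finetuned.pt", map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

# Transcribe one of the training clips and print the decoded text
batches = split_into_batches(['data/test1.wav'], batch_size=1)
inputs = prepare_model_input(read_batch(batches[0]), device=DEVICE)
with torch.no_grad():
    output = model(inputs)
for example in output:
    print(decoder(example.cpu()))
```

If the transcription of the training clips has drifted toward their reference transcripts, the saved weights were applied correctly.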
Directory structure:
-- finetune_and_save.py
-- metadata.csv
-- silero_stt_finetuned.pt
-- data
---- data/test1.wav
---- data/test2.wav
---- data/test3.wav
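For reference, the `SileroFinetuneDataset` above expects `metadata.csv` to be a headerless two-column index: audio path, then a lowercase transcript. A small helper like the following could generate it; the transcripts for `test2.wav` and `test3.wav` are placeholders of my own and must be replaced with the real ones.

```python
# make_metadata.py - hedged example of the index layout the Dataset expects.
import pandas as pd

rows = [
    ["data/test1.wav", "welcome to the ai speech recognition demo"],   # demo row from the script
    ["data/test2.wav", "placeholder transcript for the second clip"],  # hypothetical
    ["data/test3.wav", "placeholder transcript for the third clip"],   # hypothetical
]
# Write without header or index so pd.read_csv(..., header=None) in the Dataset parses it as-is
pd.DataFrame(rows).to_csv("metadata.csv", index=False, header=False)
```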