大模型之二十八-语音识别Whisper进阶

在上一篇博客大模型之二十七-语音识别Whisper实例浅析中遗留了几个问题,这里来看一下前两个问题。

1.如果不是Huggingface上可以下载的数据该怎么办?

2.上面的代码是可以训练了,但是训练的时候loss真的会和我们预期一致吗?比如如下怎么办?

进阶内容

在Whisper语音识别fine-tune的例子中,我们使用的是Huggingface封装好的数据加载以及Transformer工具,这将很多底层细节对开发人员屏蔽了,但是对于技术人员而言,这还远远不够,本篇通过一个要解决两个问题:

1.数据集是私有的,并不是Huggingface开源的数据集

2.不使用Huggingface封装好的Training pipeline,在Whisper开源的源代码基础之上fine-tune模型,并验证准确性。

整个框架代码使用pytorch-lightning来实现,目前很多优秀的比较大的开源都是实用pytorch-lightning来实现的。

安装一些python库

首先下载Whisper源代码,并且

shell 复制代码
! pip install git+https://github.com/openai/whisper.git
! pip install jiwer 
! pip install pytorch-lightning==2.4.0
! pip install -qqq evaluate==0.2.2

导入必要的python包

python 复制代码
import os
import glob
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as at

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from tqdm.notebook import tqdm
import evaluate

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

遗留的第一个问题--数据集

这里的数据集基于清华大学开源的30小时中文照着文本读而录的音频,原下载地址

为了减小资源的开销,在有限的资源下,多迭代epoch,这里对数据集做了处理:

  • 将数据集缩到了10个小时/30小时,
  • 去掉了txt里音素的标注,只留文本,因为在这个数据集开源的时候,那时语音识别系统还是基于音素的。

可以关注私信我,联系索取处理之后的语料。

数据集处理

python 复制代码
import glob

DATASET_DIR = "/kaggle/input/th30-all"
SAMPLE_RATE = 16000
BATCH_SIZE = 4
TRAIN_RATE = 0.85

#whipser的输入是30s,16kHz采样率,最长480000 sample
AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120

DEVICE = "gpu" if torch.cuda.is_available() else "cpu"

###################### 读取数据信息并分离出train和val
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))
shell 复制代码
13388

读取数据信息并分离出train和val

python 复制代码
dataset_dir = DATASET_DIR
transcripts_path_list = glob.glob(os.path.join(dataset_dir, "*.txt"))
print(len(transcripts_path_list))

def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

def get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):
    audio_transcript_pair_list = []
    for transcripts_path in tqdm(transcripts_path_list):
        # audio文件目录确认
        audio_dir = os.path.dirname(transcripts_path)

        # 从翻译文本获取音频和文本
        with open(transcripts_path, "r") as f:
            text_list = f.readlines()
        for text in text_list:
            audio_id, text = text.replace("\n", "").split(":")
            #print(audio_id, text)

            audio_path = os.path.join(audio_dir, f"{audio_id}.wav")
            if os.path.exists(audio_path):
                # 检查数据
                audio = load_wave(audio_path, sample_rate=sample_rate)[0]
                if len(text) > text_max_length or len(audio) > audio_max_sample_length:
                    print(len(text), len(audio))
                    continue
                audio_transcript_pair_list.append((audio_id, str(audio_path), text))
    return audio_transcript_pair_list

train_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))
shell 复制代码
13388
  0%|          | 0/11379 [00:00<?, ?it/s]  
  0%|          | 0/2009 [00:00<?, ?it/s]
TRAIN AUDIO DATASET NUM:  11379
EVAL AUDIO DATASET NUM:  2009

Data loader

python 复制代码
woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
wmodel = whisper.load_model(name="small",download_root="./whisper-small")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=woptions.task)

class Th30Dataset(torch.utils.data.Dataset):
    def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:
        super().__init__()

        self.audio_info_list = audio_info_list
        self.sample_rate = sample_rate
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.audio_info_list)

    def __getitem__(self, index):
        audio_id, audio_path, text = self.audio_info_list[index]

        #aduio mono
        audio = load_wave(audio_path, sample_rate=self.sample_rate)
        audio = whisper.pad_or_trim(audio.flatten(), AUDIO_MAX_LENGTH)
        mel = whisper.log_mel_spectrogram(audio)

        #text
        text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
        labels = text[1:] + [self.tokenizer.eot]

        return {
            "input_ids": mel,
            "labels": labels,
            "dec_input_ids": text
        }
class WhisperDataCollatorWhithPadding:
    def __call__(self, features):
        input_ids, labels, dec_input_ids = [], [], []
        for f in features:
            input_ids.append(f["input_ids"])
            labels.append(f["labels"])
            dec_input_ids.append(f["dec_input_ids"])

        input_ids = torch.concat([input_id[None, :] for input_id in input_ids])


        label_lengths = [len(lab) for lab in labels]
        dec_input_ids_length = [len(e) for e in dec_input_ids]
        max_label_len = max(label_lengths + dec_input_ids_length)

        labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
        dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id

        batch = {
            "labels": labels,
            "dec_input_ids": dec_input_ids
        }

        batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
        batch["input_ids"] = input_ids

        return batch

dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

这是典型的Pytorch而不是前篇中Huggingface的数据加载方法,需要实现datasetDataLoader,详细参考Pytorch Lightning官方文档。至此,遗留的第一个问题解决。

验证数据集加载

python 复制代码
DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
for b in loader:
    print(b["labels"].shape)
    print(b["input_ids"].shape)
    print(b["dec_input_ids"].shape)

    for token, dec in zip(b["labels"], b["dec_input_ids"]):
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token)
        print(text)

        dec[dec == -100] = wtokenizer.eot
        text = wtokenizer.decode(dec)
        print(text)
    break
shell 复制代码
torch.Size([2, 50])
torch.Size([2, 80, 3000])
torch.Size([2, 50])
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着

验证解码器

python 复制代码
with torch.no_grad():
    audio_features = wmodel.encoder(b["input_ids"].cuda())
    input_ids = b["input_ids"]
    labels = b["labels"].long()
    dec_input_ids = b["dec_input_ids"].long()

        
    audio_features = wmodel.encoder(input_ids.cuda())
    print(dec_input_ids)
    print(input_ids.shape, dec_input_ids.shape, audio_features.shape)
    print(audio_features.shape)
    print()

# 计算解码器的输出
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)

print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)
shell 复制代码
tensor([[50258, 50260, 50359, 50363, 45161, 11386, 47446,  5708,  5266,   104,
          5823, 35825, 20708, 17682,  3023,   222,  5975,  1787,   106, 44365,
          3919,    95, 31906,  1593,   247, 11160, 31382,  1881,   237, 31382,
         39861,  5155, 39152,  5648,   247, 10928,  3330,   123, 18937,  2289,
          5881,   249,  5419,   122, 50257, 50257, 50257, 50257, 50257, 50257],
        [50258, 50260, 50359, 50363, 12744, 25281, 22694,  6734, 42503, 12088,
          3308,   111, 36257,  4479,   223,  4035, 32045, 41111,   236,  3308,
           116,  7391,   102,  4479,   245, 11808,   246,  3416,   105, 33597,
         45506, 37960,  8501,   244, 45506, 37960,  3316,   238, 45506, 17819,
            99, 15789,  7732,   100, 44059, 10928, 48839, 16337,   234, 20708]])
torch.Size([2, 80, 3000]) torch.Size([2, 50]) torch.Size([2, 1500, 768])
torch.Size([2, 1500, 768])

torch.Size([2, 50, 51865])
torch.Size([100, 51865])
torch.Size([100])

token转文本输出

tokens = torch.argmax(out, dim=2)
for token in tokens:
    token[token == -100] = wtokenizer.eot
    text = wtokenizer.decode(token)
    print(text)
shell 复制代码
<|zh|><|translate|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙的避开了矛盾<|endoftext|><|endoftext|> <|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|zh|><|transcribe|><|notimestamps|>放眼望去,定河兩旁人身顎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>

构造trainer

python 复制代码
class Config:
    learning_rate = 0.0001
    weight_decay = 0.01
    adam_epsilon = 1e-8
    warmup_steps = 2
    batch_size = 16
    num_worker = 2
    num_train_epochs = 1000
    gradient_accumulation_steps = 1
    sample_rate = SAMPLE_RATE


class WhisperModelModule(LightningModule):
    def __init__(self, cfg: Config, model_name="small", lang="zh", train_dataset=[], eval_dataset=[]) -> None:
        super().__init__()
        self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)
        self.model = whisper.load_model(model_name)
        self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="zh", task=self.options.task)

        # only decoder training
        for p in self.model.encoder.parameters():
            p.requires_grad = False

        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics_wer = evaluate.load("wer")
        self.metrics_cer = evaluate.load("cer")

        self.cfg = cfg
        self.__train_dataset = train_dataset
        self.__eval_dataset = eval_dataset

    def forward(self, x):
        return self.model(x)
    def training_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()

        with torch.no_grad():
            audio_features = self.model.encoder(input_ids)

        out = self.model.decoder(dec_input_ids, audio_features)
        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
        self.log("train/loss", loss, on_step=False, on_epoch=True,  prog_bar=True, logger=True)
        return loss
    
    def on_train_epoch_end(self):
        avg_loss = self.trainer.callback_metrics.get("train/loss")
        
        # 获取当前的 epoch 数量
        epoch = self.current_epoch
        
        print(f"Epoch: {epoch}, Training - Loss: {avg_loss:.4f}")

    def validation_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()


        audio_features = self.model.encoder(input_ids)
        out = self.model.decoder(dec_input_ids, audio_features)

        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))

        out[out == -100] = self.tokenizer.eot
        labels[labels == -100] = self.tokenizer.eot

        o_list, l_list = [], []
        for o, l in zip(out, labels):
            o = torch.argmax(o, dim=1)
            o_list.append(self.tokenizer.decode(o))
            l_list.append(self.tokenizer.decode(l))
        wer = self.metrics_wer.compute(references=l_list, predictions=o_list)

        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("val/wer", wer, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        # 打印到终端
        #print(f"Validation - Loss: {loss:.4f}, WER: {wer:.4f}")

        return {
            "wer": wer,
            "loss": loss
        }
    
    def on_validation_epoch_end(self):
        avg_loss = self.trainer.callback_metrics.get("val/loss")
        avg_wer = self.trainer.callback_metrics.get("val/wer")
        
        # 获取当前的 epoch 数量
        epoch = self.current_epoch

        print(f"Epoch: {epoch}, Validation - Loss: {avg_loss:.4f}, WER: {avg_wer:.4f}")
    

    def configure_optimizers(self):
        """创建优化程序和调度器 """
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters()
                            if not any(nd in n for nd in no_decay)],
                "weight_decay": self.cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters()
                            if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.cfg.learning_rate,
                          eps=self.cfg.adam_epsilon)
        self.optimizer = optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.cfg.warmup_steps,
            num_training_steps=self.t_total
        )
        self.scheduler = scheduler

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def setup(self, stage=None):
        """初始设置(读取数据集)"""

        if stage == 'fit' or stage is None:
            self.t_total = (
                    (len(self.__train_dataset) // (self.cfg.batch_size))
                    // self.cfg.gradient_accumulation_steps
                    * float(self.cfg.num_train_epochs)
            )

    def train_dataloader(self):
        """ 创建训练数据加载程序 """
        dataset = Th30Dataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.cfg.batch_size,
                                           drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,
                                           collate_fn=WhisperDataCollatorWhithPadding()
                                           )

    def val_dataloader(self):
        """ 创建验证数据加载程序 """
        dataset = Th30Dataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=self.cfg.batch_size,
                                           num_workers=self.cfg.num_worker,
                                           collate_fn=WhisperDataCollatorWhithPadding()
                                           )  

主要是对LightningModule类相关方法的重载,定义了train、validate以及optimizer的行为,以及在训练过程中日志和相关信息、checkpoint的保存。

启动训练

python 复制代码
log_output_dir = "./logs"
check_output_dir = "./artifacts"

train_name = "whisper"
train_id = "00001"

model_name = "small"
lang = "zh"


cfg = Config()

# os.mkdir(log_output_dir)
# os.mkdir(check_output_dir)

tflogger = TensorBoardLogger(
    save_dir=log_output_dir,
    name=train_name,
    version=train_id
)

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath=f"{check_output_dir}/checkpoint",
    filename="checkpoint-{epoch:04d}",
    save_top_k=2, # all model save
    save_on_train_epoch_end=False,
    monitor='val/wer',  # 需要监控的验证损失
    mode='min',  # 最小化 val_loss
    verbose=True  # 打印更多的信息到控制台
)
callback_list = [checkpoint_callback, LearningRateMonitor(logging_interval="epoch")]
model = WhisperModelModule(cfg, model_name, lang, train_audio_transcript_pair_list, eval_audio_transcript_pair_list)

trainer = Trainer(
    precision=16,
    accelerator="gpu",
    max_epochs=cfg.num_train_epochs,
    check_val_every_n_epoch=2,
    accumulate_grad_batches=cfg.gradient_accumulation_steps,
    logger=tflogger,
    callbacks=callback_list
)

trainer.fit(model)
```shell
对于10小时数据集,你可能看到如下输出:

Epoch: 0, Training - Loss: 1.0872

Epoch: 1, Validation - Loss: 0.4207, WER: 0.9847

Epoch: 1, Training - Loss: 0.2955

Epoch: 2, Training - Loss: 0.5555

Epoch: 3, Validation - Loss: 0.2505, WER: 0.9006

Epoch: 3, Training - Loss: 0.0979

Epoch: 4, Training - Loss: 0.0602

Epoch: 5, Validation - Loss: 0.2889, WER: 0.8764

Epoch: 5, Training - Loss: 0.0721

Epoch: 6, Training - Loss: 0.0947

Epoch: 7, Validation - Loss: 0.3839, WER: 0.9809

Epoch: 7, Training - Loss: 0.1379

对于30小时数据集,即使用完整的th30数据,其中85%用于Training,而15%用于validation你可能看到如下输出:
```shell
3051.3s	121	Epoch: 1, Validation - Loss: 0.0588, WER: 0.3499
3061.1s	122	Epoch: 1, Training - Loss: 0.0340
4279.7s	123	Epoch: 2, Training - Loss: 0.0268
5691.4s	124	Epoch: 3, Validation - Loss: 0.0676, WER: 0.8318
5701.1s	125	Epoch: 3, Training - Loss: 0.0201
6919.2s	126	Epoch: 4, Training - Loss: 0.0257
8329.5s	127	Epoch: 5, Validation - Loss: 0.0484, WER: 0.8472
8329.5s	128	Epoch: 5, Training - Loss: 0.0144
9547.5s	129	Epoch: 6, Training - Loss: 0.1127
10959.3s	130	Epoch: 7, Validation - Loss: 0.0422, WER: 0.3982
10969.4s	131	Epoch: 7, Training - Loss: 0.0053
12188.2s	132	Epoch: 8, Training - Loss: 0.0076
13600.0s	133	Epoch: 9, Validation - Loss: 0.0482, WER: 0.8158
13600.0s	134	Epoch: 9, Training - Loss: 0.0126
14819.0s	135	Epoch: 10, Training - Loss: 0.0152
16230.8s	136	Epoch: 11, Validation - Loss: 0.0544, WER: 0.6829
16230.8s	137	Epoch: 11, Training - Loss: 0.0114
17450.1s	138	Epoch: 12, Training - Loss: 0.0174
18862.4s	139	Epoch: 13, Validation - Loss: 0.0523, WER: 0.3225
18872.1s	140	Epoch: 13, Training - Loss: 0.0117
20091.5s	141	Epoch: 14, Training - Loss: 0.0075
21503.2s	142	Epoch: 15, Validation - Loss: 0.0567, WER: 0.5187
21503.2s	143	Epoch: 15, Training - Loss: 0.0137
22722.5s	144	Epoch: 16, Training - Loss: 0.0150
24134.2s	145	Epoch: 17, Validation - Loss: 0.0631, WER: 0.4559
24134.2s	146	Epoch: 17, Training - Loss: 0.0122
25352.9s	147	Epoch: 18, Training - Loss: 0.0120
26765.0s	148	Epoch: 19, Validation - Loss: 0.0523, WER: 0.7387
26765.0s	149	Epoch: 19, Training - Loss: 0.0060
27983.9s	150	Epoch: 20, Training - Loss: 0.0154
29395.1s	151	Epoch: 21, Validation - Loss: 0.0520, WER: 0.4749
29395.1s	152	Epoch: 21, Training - Loss: 0.0073
30612.5s	153	Epoch: 22, Training - Loss: 0.6361
32022.4s	154	Epoch: 23, Validation - Loss: 0.0396, WER: 0.2912
32033.0s	155	Epoch: 23, Training - Loss: 0.0029
33250.8s	156	Epoch: 24, Training - Loss: 0.0036
34662.0s	157	Epoch: 25, Validation - Loss: 0.0461, WER: 0.6043
34662.0s	158	Epoch: 25, Training - Loss: 0.0094
35880.1s	159	Epoch: 26, Training - Loss: 0.0082
37291.0s	160	Epoch: 27, Validation - Loss: 0.0428, WER: 0.7481
37291.0s	161	Epoch: 27, Training - Loss: 0.0051
38509.7s	162	Epoch: 28, Training - Loss: 0.0075
39920.4s	163	Epoch: 29, Validation - Loss: 0.0447, WER: 0.8736
39920.4s	164	Epoch: 29, Training - Loss: 0.0091
41138.9s	165	Epoch: 30, Training - Loss: 0.0088
42549.9s	166	Epoch: 31, Validation - Loss: 0.0530, WER: 0.4500
42549.9s	167	Epoch: 31, Training - Loss: 0.0072

遗留的第二个问题

首先是数据集的问题,因为可以看到随着时长的增加,看到模型训练过程在符合预期方向走,

  1. 最低数据量 :起步来说,至少需要几个小时的音频数据来进行有效的fine-tuning。例如,从10小时开始,这是一个相对较小的数据集,可以用来调试模型和流程。

  2. 中等数据量 :为了获得更佳的效果,推荐使用20至50小时的音频数据。这可以帮助模型更好地学习到特定语言的特性。

  3. 理想数据量 :如果资源允许,使用超过100小时的音频数据将更有助于模型性能的提升。更多的数据可以显著提高模型的泛化能力和准确性。

当然对于大模型,数据质量越高越好,数据多样性越多越好。

进一步通过tensorboard图可以看到:

在运行12个小时之后可以看到WER比一开始的确实下降了不少,但是还没有达到20%左右,最低的WER在0.2912,但是这里可以观察到一个非常有趣的现象:

在观察到训练损失(Training Loss)持续下降而验证损失(Validation Loss)和字错误率 (WER, Word Error Rate) 没有持续改善或波动较大的情况时,这通常是过拟合的一个迹象。在这种情况下,模型在训练数据上表现得越来越好,但在未见过的验证数据上的表现却没有相对应的提升,甚至出现恶化。

由于callback回调中会保持前两个在验证集上WER最小的两个checkpoint,接下来有几个思路:

1.分析模型在验证集上的错误,看是否存在特定模式或类型的错误,这可能帮助诊断问题并指导进一步模型调整,因为我们是在whipser开源的基础上fine-tune的,所以不可能简化模型结构的本身,如减少层数或神经元数目以改善过拟合。

2.可以考虑正则化技术(L2正则化、Dropout)等以有助于缓解过拟现象,增强模型的泛化能力

3.调整训练策略,调整学习率或者使用不同的优化器,以评估模型在验证集上的表现;

4.增加更多数据,帮助模型学习到更多特征,从而提高模型泛化能力

观察验证集识别效果

由于输出缩略或视觉上的相似性,一些小的差异(如标点、空白或特殊字符)可能不容易觉察。这些微小的差异在计算WER时会被考虑进去,但在人眼检查时可能会被忽略。

可以看到基本上是一致的,但是个别词是有出入的,这是因为th30是人工读的,准确性比较高,并不意味着通话、会议、游戏场景的识别率也能如此。

这里再留几个尾巴给读者自己实现:

CER

1.中文是基于字符的语言,通常我们会使用CER(Character Error Rate,字符错误率)来进行更精确的评估。然而,如果你使用的是WER来评估中文语音识别的质量,这里有几点可能需要注意:

  • 在处理中文时,如果WER是基于词的,就必须先进行准确的分词。中文没有明显的词与词之间的分隔,因此分词的准确性对于WER的计算非常关键。错误的分词可能导致高WER,即使识别的字符完全正确。
  • 中文中的一些微小差异,如同音字错误、词序变化或者是语气词的使用,都可以在视觉上看起来非常相似,但在WER的计算中会被视为错误。
  • 你查看的样本可能并不代表整体数据集的平均表现。此外,中文语音识别可能特别擅长处理某些特定的语句或者在某些领域表现更好。

数据质量

除了数据量之外,数据的质量也非常重要:

  • 多样性:数据应该涵盖多种口音、语速和语调,以及不同的背景噪声环境,这将帮助模型在各种输入条件下都能保持稳定的表现。
  • 标注准确性:确保你的数据标注尽可能准确,错误的标注会直接影响模型学习的结果。

预处理和增强

  • 预处理:对音频进行预处理,如采样率转换(确保和模型训练时使用的采样率一致),音量标准化等。
  • 数据增强:可以考虑使用音频数据增强技术,如添加背景噪声、改变语速和音高等,以增加模型的鲁棒性。

资源和迭代

  • 计算资源:fine-tuning一个语音识别模型可能需要大量的计算资源,特别是当使用大量数据时。确保你有足够的GPU资源进行训练。
  • 迭代和评估:在fine-tuning过程中需要多次迭代和评估,以找到最优的模型参数和设置。

总结来说,训练的过程如炼丹,有些训练的经验是不能从小模型直接用到大模型上的。比如small和对于large-v3两种。

在模型相对较小的时候,learning rate的设置可以比较激进,但是对非常大模型的时候,较大的lr可能导致模型一开始loss就无法收敛,是发散的,但如果设置的lr比较小,那可能使得训练的时长成倍增加,怎么办呢?针对很大的模型,warm-up策略是很多时候会使用的。

load weight and inference

python 复制代码
checkpoint_path = "whisper-checkpoint/checkpoint-epoch0023.ckpt"
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']
shell 复制代码
/tmp/ipykernel_36/4099222220.py:2: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  state_dict = torch.load(checkpoint_path)
dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'MixedPrecision'])

加载模型参数

python 复制代码
cfg = Config()
whisper_model = WhisperModelModule(cfg)
whisper_model.load_state_dict(state_dict)
shell 复制代码
100%|███████████████████████████████████████| 461M/461M [00:05<00:00, 87.6MiB/s]
Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]
Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]
<All keys matched successfully>

前向推理

python 复制代码
woptions = whisper.DecodingOptions(language="zh", without_timestamps=True)
dataset = Th30Dataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

refs = []
res = []
for b in tqdm(loader):
    input_ids = b["input_ids"].half().cuda()
    labels = b["labels"].long().cuda()
    with torch.no_grad():
        #audio_features = whisper_model.model.encoder(input_ids)
        #out = whisper_model.model.decoder(enc_input_ids, audio_features)
        results = whisper_model.model.decode(input_ids, woptions)
        for r in results:
            res.append(r.text)
        
        for l in labels:
            l[l == -100] = wtokenizer.eot
            ref = wtokenizer.decode(l)
            refs.append(ref)
     ```
     打印推理结果
     ```
     for k, v in zip(refs, res):
    print("-"*10)
    print(k)
    print(v)

部分输出结果

shell 复制代码
  ----------
<|zh|><|transcribe|><|notimestamps|>节目单上赫然印着特邀中央乐团百余位演奏演唱家微妙地避开了矛盾<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
节目单上赫然印着特邀中央乐团百余位演奏歌唱家微妙地避开了矛盾
----------
<|zh|><|transcribe|><|notimestamps|>放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着<|endoftext|>
放眼望去永定河两旁人声鼎沸彩旗飘扬推土机挖土机运土车正紧张地忙碌着
----------
<|zh|><|transcribe|><|notimestamps|>旅与游的时间比往往旅长游短与游客的愿望相悖<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
旅与游的时间比往往旅长游短与游客的愿望相悖
----------
<|zh|><|transcribe|><|notimestamps|>中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职<|endoftext|>
中学毕业后他考入尼恩罗德商业学院毕业后曾服兵役并在一家贸易公司任职
----------
<|zh|><|transcribe|><|notimestamps|>该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等<|endoftext|>
该片导演为虞石束杨军主要演员有韩夫一李进慕爱秋周桂云金鑫冯云魁等
----------
<|zh|><|transcribe|><|notimestamps|>何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品<|endoftext|><|endoftext|>
何勰二话没说立即交了一千五百元的押金又为洛桑卓玛买来了全套新衣服和住院用品
----------
<|zh|><|transcribe|><|notimestamps|>印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明<|endoftext|>
印加人所创造的文明与玛雅文明阿兹特克文明一起被誉为美洲印第安三大文明
----------
<|zh|><|transcribe|><|notimestamps|>今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
今天陪以萌找冯邦找得又累又饿但看见以萌那副着急样我一点也吃不下
----------
<|zh|><|transcribe|><|notimestamps|>亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军<|endoftext|>
亲英的北爱尔兰新教派武装十二日晚发表声明威胁要报复爱尔兰共和军
----------
<|zh|><|transcribe|><|notimestamps|>小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤<|endoftext|><|endoftext|><|endoftext|>
小仲不顾闲言碎语一天几趟往我家跑为我洗衣做饭熬药煎汤
----------
<|zh|><|transcribe|><|notimestamps|>这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
这位病人因贲门下胃底大弯静脉曲张伴血管瘤破裂胃内大量喷血
----------
<|zh|><|transcribe|><|notimestamps|>驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声<|endoftext|>
驻藏边防某部二连战士赵金站岗时隐隐约约听见营区外的不远处有哭泣声
----------
<|zh|><|transcribe|><|notimestamps|>其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮运而享誉海内外<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
其种植的红富士苹果以色泽艳丽果质细脆汁多味美和极耐贮存而享誉海内外
----------
<|zh|><|transcribe|><|notimestamps|>一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人<|endoftext|>
一九四一年十一月陕甘宁边区根据三三制原则举行参议会议员竞选一位名叫森健的学员被推为候选人
----------
<|zh|><|transcribe|><|notimestamps|>当船往下漂时白唇鹿扬起四蹄在岸边追随好像是送行一直跑了十几里太亲切了<|endoftext|>
当船往下漂时白唇鹿扬起四蹄在岸边追赶好像是送行一直跑了十几里太亲切了
----------
<|zh|><|transcribe|><|notimestamps|>有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
有的单位按年人均月收入减去费用八百元后的余额为应纳税所得额
----------
<|zh|><|transcribe|><|notimestamps|>女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍<|endoftext|>
女性腰部以上特别肥胖者易患乳腺癌腰围与臀围差别不大者患癌率比一般妇女高六倍
----------
<|zh|><|transcribe|><|notimestamps|>日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌<|endoftext|><|endoftext|><|endoftext|>
日本队在男子团体赛中获银牌队员岩井哲贤在个人全能赛也夺得一枚银牌
----------
<|zh|><|transcribe|><|notimestamps|>如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉<|endoftext|><|endoftext|><|endoftext|>
如此举措源于杭州娃哈哈食品集团公司总经理宗庆后对市场特质的洞悉
  ```
接下来还有三个问题对于应用更需要细致考虑:
1.Whisper除了识别,还有直接翻译功能,在以前要先识别成中文,再汉译英等,这个好处是显而易见的,首先只要一个模型,节约部分人力、机器以及服务端GPU,业务场景上可以是会议的实时翻译、看英文视频实时翻译成中文,这会减少latency,用户体验也更好;
2.如何在实时的流式场景中使用?
3.kv-caching是个什么技术?12倍是如何做到的?这在工程部署商用价值非常大。

欢迎点赞、收藏、关注,以便及时收到下一篇推送。
相关推荐
井底哇哇4 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证4 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩4 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控4 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
一水鉴天5 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
倔强的石头1065 小时前
解锁辅助驾驶新境界:基于昇腾 AI 异构计算架构 CANN 的应用探秘
人工智能·架构
佛州小李哥6 小时前
Agent群舞,在亚马逊云科技搭建数字营销多代理(Multi-Agent)(下篇)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
说私域6 小时前
社群裂变+2+1链动新纪元:S2B2C小程序如何重塑企业客户管理版图?
大数据·人工智能·小程序·开源
程序猿阿伟7 小时前
《探秘鸿蒙Next:如何保障AI模型轻量化后多设备协同功能一致》
人工智能·华为·harmonyos
2401_897579657 小时前
AI赋能Flutter开发:ScriptEcho助你高效构建跨端应用
前端·人工智能·flutter