目录
项目概述
WEKWS (Wake-up Word Keyword Spotting;https://arxiv.org/pdf/2210.16743.pdf) 是一个基于PyTorch的端到端关键词识别系统,支持多种深度学习模型架构,包括TCN、FSMN、GRU、MDTC等。该项目提供了完整的训练、推理、部署和测试流程。
核心特性
- 多模型支持: TCN、DS_TCN、FSMN、GRU、MDTC、TDNN-LSTM等;可魔改
- 多损失函数: Max-Pooling Loss、CTC Loss
- 流式/非流式推理: 支持实时流式关键词检测
- 量化部署: 支持PyTorch JIT和ONNX导出
- 数据增强: SpecAugment、速度扰动、混响/噪声添加
系统架构总览
┌─────────────────────────────────────────────────────────────┐
│ WEKWS KWS System │
├─────────────────────────────────────────────────────────────┤
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Data │ │ Model │ │ Training │ │
│ │ Pipeline │ │ Architecture│ │ Pipeline │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │ │ │ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Feature │ │ KWS │ │ Inference │ │
│ │ Extraction │ │ Model │ │ Engine │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │ │ │ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Dataset │ │ Classifier │ │ Deployment │ │
│ │ Processor │ │ & Loss │ │ Tools │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
└─────────────────────────────────────────────────────────────┘
核心组件
- 数据管道 (
wekws/dataset/): 数据加载、预处理、增强 - 模型架构 (
wekws/model/): 各种KWS模型实现 - 训练管道 (
wekws/bin/): 训练、评估、推理脚本 - 工具集 (
wekws/utils/): 辅助功能模块
模型架构详解
1. TCN (Temporal Convolutional Network)
架构特点
- 因果卷积: 确保模型只能看到过去的信息
- 膨胀卷积: 扩大感受野而不增加参数量
- 残差连接: 缓解梯度消失问题
- 深度可分离卷积: 减少参数量,提高效率
核心实现
python
class TCN(nn.Module):
def __init__(self, num_layers, channel, kernel_size, dropout, block_class):
super().__init__()
self.network = nn.ModuleList()
for i in range(num_layers):
dilation = 2**i # 指数增长的膨胀率
self.network.append(block_class(channel, kernel_size, dilation, dropout))
TCN配置 (tcn.yaml)
yaml
model:
input_dim: 32
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: false # 是否使用深度可分离卷积
num_layers: 6
kernel_size: 8
dropout: 0.1
2. DS_TCN (Depthwise Separable TCN)
架构特点
- 在TCN基础上使用深度可分离卷积
- 显著减少参数量,提高计算效率
- 保持相同的感受野和性能
核心实现
python
class DsCnnBlock(Block):
def __init__(self, channel, kernel_size, dilation, dropout):
super().__init__(channel, kernel_size, dilation, dropout)
self.cnn = nn.Sequential(
nn.Conv1d(channel, channel, kernel_size, dilation=dilation, groups=channel),
nn.BatchNorm1d(channel),
nn.ReLU(),
nn.Conv1d(channel, channel, kernel_size=1),
nn.BatchNorm1d(channel),
nn.ReLU(),
nn.Dropout(dropout),
)
3. FSMN (Feedforward Sequential Memory Network)
架构特点
- 记忆块: 捕获长距离依赖关系
- 前馈结构: 避免RNN的顺序计算
- 参数共享: 减少模型参数量
FSMN配置 (fsmn.yaml)
yaml
model:
input_dim: 32
hidden_dim: 128
backbone:
type: fsmn
input_affine_dim: 140
num_layers: 4
linear_dim: 250
proj_dim: 128
left_order: 10
right_order: 2
4. GRU (Gated Recurrent Unit)
架构特点
- 循环结构: 处理序列数据
- 门控机制: 控制信息流动
- 双向支持: 可配置双向GRU
GRU配置 (gru.yaml)
yaml
model:
hidden_dim: 128
backbone:
type: gru
num_layers: 2
5. MDTC (Multi-scale Dilated Temporal Convolution)
架构特点
- 多尺度膨胀: 不同尺度的特征提取
- 因果卷积: 支持流式推理
- 堆叠结构: 多层堆叠增强表达能力
MDTC配置 (mdtc.yaml)
yaml
model:
hidden_dim: 64
backbone:
type: mdtc
num_stack: 4
stack_size: 4
kernel_size: 5
hidden_dim: 64
causal: True
6. TDNN-LSTM
架构特点
- TDNN: 时延神经网络,捕获局部特征
- LSTM: 长短期记忆网络,处理长距离依赖
- 混合架构: 结合CNN和RNN的优势
数据加载器架构
数据处理流程
┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Raw Audio │───▶│ Parse & │───▶│ Feature │───▶│ Data │
│ Files │ │ Filter │ │ Extraction │ │ Augmentation│
└─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘
│ │ │
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Resample │ │ MFCC/FBank │ │ SpecAug │
│ Speed │ │ Signal │ │ Shuffle │
│ Perturb │ │ Features │ │ Batch │
└─────────────┘ └─────────────┘ └─────────────┘
核心数据处理器
1. 数据解析 (parse_raw)
python
def parse_raw(data):
"""解析原始数据格式"""
for item in data:
# 解析音频路径和标签
yield {
'wav': item['wav'],
'label': item['label'],
'duration': item['duration']
}
2. 特征提取
- MFCC: 梅尔频率倒谱系数
- FBank: 梅尔滤波器组特征
- Signal: 原始信号特征
3. 数据增强
- SpecAugment: 频谱增强
- Speed Perturb: 速度扰动
- Reverb/Noise: 混响和噪声添加
数据集配置
TCN数据集配置
yaml
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
feature_extraction_conf:
feature_type: 'signal'
frame_shift: 0.001
frame_length: 0.002
num_mel_bins: 32
spec_aug: false
batch_conf:
batch_size: 128
TCN模型训练推理流程
1. 数据准备阶段 (Stage -1 & 0)
数据下载
bash
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Download and extract all datasets"
local/mobvoi_data_download.sh --dl_dir $download_dir
fi
数据预处理
bash
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Preparing datasets..."
# 创建词典
echo "<filler> -1" > dict/words.txt
echo "clean 0" >> dict/words.txt
echo "nice 1" >> dict/words.txt
# 处理训练、开发、测试集
for folder in train dev test; do
mkdir -p data_lol/$folder
# 合并正负样本
cat data_lol/p_$folder/wav.scp data_lol/n_$folder/wav.scp > data_lol/$folder/wav.scp
cat data_lol/p_$folder/text data_lol/n_$folder/text > data_lol/$folder/text
done
fi
2. 特征计算阶段 (Stage 1)
CMVN计算
bash
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data_lol/train/wav.scp \
--out_cmvn data_lol/train/global_cmvn
数据列表生成
bash
# 生成数据列表
for x in train ; do
tools/wav_to_duration.sh --nj 8 data_lol/$x/wav.scp data_lol/$x/wav.dur
tools/make_list.py data_lol/$x/wav.scp data_lol/$x/text.label \
data_lol/$x/wav.dur data_lol/$x/data.list
tools/shuffle_list.py --input data_lol/$x/data.list --output data_lol/$x/data.rand.list
done
fi
3. 模型训练阶段 (Stage 2)
训练配置
bash
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Start training ..."
mkdir -p $dir
# 配置CMVN选项
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data_lol/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
# 启动分布式训练
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
--train_data data_lol/train/data.rand.list \
--cv_data data_lol/train/cross.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 10 \
--seed 666 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi
训练流程详解
- 模型初始化: 根据配置文件初始化KWSModel
- 数据加载: 使用DataLoader加载训练和验证数据
- 优化器配置: Adam优化器,学习率调度
- 训练循环 :
- 前向传播计算损失
- 反向传播更新参数
- 验证集评估
- 模型保存
4. 模型评估阶段 (Stage 3)
模型平均
bash
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
性能评估
bash
# 评分计算
python wekws/bin/score.py \
--config $dir/config.yaml \
--test_data data_lol/train/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8
# DET曲线计算
for keyword in 0 1; do
python wekws/bin/compute_det.py \
--keyword $keyword \
--test_data data_lol/train/data.list \
--window_shift $window_shift \
--score_file $result_dir/score.txt \
--stats_file $result_dir/stats.${keyword}.txt
done
# 绘制DET曲线
python wekws/bin/plot_det_curve.py \
--keywords_dict dict/words.txt \
--stats_dir $result_dir \
--figure_file $result_dir/det.png \
--ylim 5 \
--xlim 10
fi
5. 模型导出阶段 (Stage 4)
JIT导出
bash
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
result_dir=$dir/test_$(basename $score_checkpoint)
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
python wekws/bin/export_jit.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--jit_model $result_dir/$jit_model
ONNX导出
bash
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--onnx_model $result_dir/$onnx_model
fi
部署与测试
1. 模型量化
静态量化
python
# wekws/bin/static_quantize.py
def static_quantize(config, test_data, checkpoint, quant_model):
# 加载模型
model = init_model(configs['model'])
load_checkpoint(model, checkpoint)
# 准备量化
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
# 校准
with torch.no_grad():
for data in calib_data:
model(data)
# 转换
torch.quantization.convert(model, inplace=True)
2. 流式推理
流式KWS检测
python
# wekws/bin/stream_kws_ctc.py
class StreamingKWS:
def __init__(self, model, config):
self.model = model
self.config = config
self.cache = torch.zeros(0, 0, 0)
def process_frame(self, frame):
# 处理单帧音频
output, self.cache = self.model(frame, self.cache)
return output
3. 性能测试
DET曲线分析
- FRR (False Rejection Rate): 错误拒绝率
- FAR (False Acceptance Rate): 错误接受率
- EER (Equal Error Rate): 等错误率
性能对比
Max-Pooling Loss 性能对比
| 模型 | 参数量(K) | 训练轮数 | hi_xiaowen FRR | nihao_wenwen FRR |
|---|---|---|---|---|
| GRU | 203 | 80(avg30) | 0.088901 | 0.083827 |
| TCN | 134 | 80(avg30) | 0.023494 | 0.029884 |
| DS_TCN | 287 | 80(avg30) | 0.005357 | 0.006390 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
CTC Loss 性能对比 (FAR固定为12小时一次)
| 模型 | 损失函数 | hi_xiaowen FRR | nihao_wenwen FRR |
|---|---|---|---|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 |
| FSMN(spec_aug) | CTC | 0.031012 | 0.022460 |
流式vs非流式性能对比
| 模型 | 流式 | hi_xiaowen FRR | nihao_wenwen FRR |
|---|---|---|---|
| DS_TCN(spec_aug) | 否 | 0.056574 | 0.056856 |
| DS_TCN(spec_aug) | 是 | 0.132694 | 0.057044 |
| FSMN(spec_aug) | 否 | 0.031012 | 0.022460 |
| FSMN(spec_aug) | 是 | 0.115215 | 0.020205 |
总结
WEKWS系统提供了完整的关键词识别解决方案,具有以下优势:
- 多模型支持: 提供多种先进的KWS模型架构
- 端到端流程: 从数据准备到模型部署的完整流程
- 高性能: 在多个数据集上达到业界领先水平
- 易用性: 清晰的代码结构和详细的文档
- 可扩展性: 支持自定义模型和数据集
推荐使用FSMN-CTC模型作为基础模型进行二次开发,该模型在小数据集KWS任务上表现最为稳健。