【WEKWS】论文解读 && 语音唤醒系统架构详解 && 实战

目录

  1. 项目概述
  2. 系统架构总览
  3. 模型架构详解
  4. 数据加载器架构
  5. TCN模型训练推理流程
  6. 部署与测试
  7. 性能对比

项目概述

WEKWS (Wake-up Word Keyword Spotting;https://arxiv.org/pdf/2210.16743.pdf) 是一个基于PyTorch的端到端关键词识别系统,支持多种深度学习模型架构,包括TCN、FSMN、GRU、MDTC等。该项目提供了完整的训练、推理、部署和测试流程。

核心特性

  • 多模型支持: TCN、DS_TCN、FSMN、GRU、MDTC、TDNN-LSTM等;可魔改
  • 多损失函数: Max-Pooling Loss、CTC Loss
  • 流式/非流式推理: 支持实时流式关键词检测
  • 量化部署: 支持PyTorch JIT和ONNX导出
  • 数据增强: SpecAugment、速度扰动、混响/噪声添加

系统架构总览

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    WEKWS KWS System                         │
├─────────────────────────────────────────────────────────────┤
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │   Data      │  │   Model     │  │  Training   │         │
│  │  Pipeline   │  │  Architecture│  │   Pipeline  │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│         │               │               │                   │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │ Feature     │  │   KWS       │  │  Inference  │         │
│  │ Extraction  │  │   Model     │  │   Engine    │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
│         │               │               │                   │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐         │
│  │  Dataset    │  │ Classifier  │  │ Deployment  │         │
│  │  Processor  │  │   & Loss    │  │   Tools     │         │
│  └─────────────┘  └─────────────┘  └─────────────┘         │
└─────────────────────────────────────────────────────────────┘

核心组件

  1. 数据管道 (wekws/dataset/): 数据加载、预处理、增强
  2. 模型架构 (wekws/model/): 各种KWS模型实现
  3. 训练管道 (wekws/bin/): 训练、评估、推理脚本
  4. 工具集 (wekws/utils/): 辅助功能模块

模型架构详解

1. TCN (Temporal Convolutional Network)

架构特点
  • 因果卷积: 确保模型只能看到过去的信息
  • 膨胀卷积: 扩大感受野而不增加参数量
  • 残差连接: 缓解梯度消失问题
  • 深度可分离卷积: 减少参数量,提高效率
核心实现
python 复制代码
class TCN(nn.Module):
    def __init__(self, num_layers, channel, kernel_size, dropout, block_class):
        super().__init__()
        self.network = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2**i  # 指数增长的膨胀率
            self.network.append(block_class(channel, kernel_size, dilation, dropout))
TCN配置 (tcn.yaml)
yaml 复制代码
model:
    input_dim: 32
    hidden_dim: 256
    preprocessing:
        type: linear
    backbone:
        type: tcn
        ds: false  # 是否使用深度可分离卷积
        num_layers: 6
        kernel_size: 8
        dropout: 0.1

2. DS_TCN (Depthwise Separable TCN)

架构特点
  • 在TCN基础上使用深度可分离卷积
  • 显著减少参数量,提高计算效率
  • 保持相同的感受野和性能
核心实现
python 复制代码
class DsCnnBlock(Block):
    def __init__(self, channel, kernel_size, dilation, dropout):
        super().__init__(channel, kernel_size, dilation, dropout)
        self.cnn = nn.Sequential(
            nn.Conv1d(channel, channel, kernel_size, dilation=dilation, groups=channel),
            nn.BatchNorm1d(channel),
            nn.ReLU(),
            nn.Conv1d(channel, channel, kernel_size=1),
            nn.BatchNorm1d(channel),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

3. FSMN (Feedforward Sequential Memory Network)

架构特点
  • 记忆块: 捕获长距离依赖关系
  • 前馈结构: 避免RNN的顺序计算
  • 参数共享: 减少模型参数量
FSMN配置 (fsmn.yaml)
yaml 复制代码
model:
    input_dim: 32
    hidden_dim: 128
    backbone:
        type: fsmn
        input_affine_dim: 140
        num_layers: 4
        linear_dim: 250
        proj_dim: 128
        left_order: 10
        right_order: 2

4. GRU (Gated Recurrent Unit)

架构特点
  • 循环结构: 处理序列数据
  • 门控机制: 控制信息流动
  • 双向支持: 可配置双向GRU
GRU配置 (gru.yaml)
yaml 复制代码
model:
    hidden_dim: 128
    backbone:
        type: gru
        num_layers: 2

5. MDTC (Multi-scale Dilated Temporal Convolution)

架构特点
  • 多尺度膨胀: 不同尺度的特征提取
  • 因果卷积: 支持流式推理
  • 堆叠结构: 多层堆叠增强表达能力
MDTC配置 (mdtc.yaml)
yaml 复制代码
model:
    hidden_dim: 64
    backbone:
        type: mdtc
        num_stack: 4
        stack_size: 4
        kernel_size: 5
        hidden_dim: 64
        causal: True

6. TDNN-LSTM

架构特点
  • TDNN: 时延神经网络,捕获局部特征
  • LSTM: 长短期记忆网络,处理长距离依赖
  • 混合架构: 结合CNN和RNN的优势

数据加载器架构

数据处理流程

复制代码
┌─────────────┐    ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
│  Raw Audio  │───▶│  Parse &    │───▶│  Feature    │───▶│  Data       │
│  Files      │    │  Filter     │    │  Extraction │    │  Augmentation│
└─────────────┘    └─────────────┘    └─────────────┘    └─────────────┘
                           │                   │                   │
                    ┌─────────────┐    ┌─────────────┐    ┌─────────────┐
                    │  Resample   │    │  MFCC/FBank │    │  SpecAug    │
                    │  Speed      │    │  Signal     │    │  Shuffle    │
                    │  Perturb    │    │  Features   │    │  Batch      │
                    └─────────────┘    └─────────────┘    └─────────────┘

核心数据处理器

1. 数据解析 (parse_raw)
python 复制代码
def parse_raw(data):
    """解析原始数据格式"""
    for item in data:
        # 解析音频路径和标签
        yield {
            'wav': item['wav'],
            'label': item['label'],
            'duration': item['duration']
        }
2. 特征提取
  • MFCC: 梅尔频率倒谱系数
  • FBank: 梅尔滤波器组特征
  • Signal: 原始信号特征
3. 数据增强
  • SpecAugment: 频谱增强
  • Speed Perturb: 速度扰动
  • Reverb/Noise: 混响和噪声添加

数据集配置

TCN数据集配置
yaml 复制代码
dataset_conf:
    filter_conf:
        max_length: 2048
        min_length: 0
    resample_conf:
        resample_rate: 16000
    feature_extraction_conf:
        feature_type: 'signal'
        frame_shift: 0.001
        frame_length: 0.002
        num_mel_bins: 32
    spec_aug: false
    batch_conf:
        batch_size: 128

TCN模型训练推理流程

1. 数据准备阶段 (Stage -1 & 0)

数据下载
bash 复制代码
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "Download and extract all datasets"
  local/mobvoi_data_download.sh --dl_dir $download_dir
fi
数据预处理
bash 复制代码
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "Preparing datasets..."
  # 创建词典
  echo "<filler> -1" > dict/words.txt
  echo "clean 0" >> dict/words.txt
  echo "nice 1" >> dict/words.txt
  
  # 处理训练、开发、测试集
  for folder in train dev test; do
    mkdir -p data_lol/$folder
    # 合并正负样本
    cat data_lol/p_$folder/wav.scp data_lol/n_$folder/wav.scp > data_lol/$folder/wav.scp
    cat data_lol/p_$folder/text data_lol/n_$folder/text > data_lol/$folder/text
  done
fi

2. 特征计算阶段 (Stage 1)

CMVN计算
bash 复制代码
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "Compute CMVN and Format datasets"
  tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
    --in_scp data_lol/train/wav.scp \
    --out_cmvn data_lol/train/global_cmvn
数据列表生成
bash 复制代码
  # 生成数据列表
  for x in train ; do
    tools/wav_to_duration.sh --nj 8 data_lol/$x/wav.scp data_lol/$x/wav.dur
    tools/make_list.py data_lol/$x/wav.scp data_lol/$x/text.label \
      data_lol/$x/wav.dur data_lol/$x/data.list
    tools/shuffle_list.py --input data_lol/$x/data.list --output data_lol/$x/data.rand.list
  done
fi

3. 模型训练阶段 (Stage 2)

训练配置
bash 复制代码
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "Start training ..."
  mkdir -p $dir
  
  # 配置CMVN选项
  cmvn_opts=
  $norm_mean && cmvn_opts="--cmvn_file data_lol/train/global_cmvn"
  $norm_var && cmvn_opts="$cmvn_opts --norm_var"
  
  # 启动分布式训练
  num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
  torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
    wekws/bin/train.py --gpus $gpus \
      --config $config \
      --train_data data_lol/train/data.rand.list \
      --cv_data data_lol/train/cross.list \
      --model_dir $dir \
      --num_workers 8 \
      --num_keywords $num_keywords \
      --min_duration 10 \
      --seed 666 \
      $cmvn_opts \
      ${checkpoint:+--checkpoint $checkpoint}
fi
训练流程详解
  1. 模型初始化: 根据配置文件初始化KWSModel
  2. 数据加载: 使用DataLoader加载训练和验证数据
  3. 优化器配置: Adam优化器,学习率调度
  4. 训练循环 :
    • 前向传播计算损失
    • 反向传播更新参数
    • 验证集评估
    • 模型保存

4. 模型评估阶段 (Stage 3)

模型平均
bash 复制代码
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "Do model average, Compute FRR/FAR ..."
  python wekws/bin/average_model.py \
    --dst_model $score_checkpoint \
    --src_path $dir \
    --num ${num_average} \
    --val_best
性能评估
bash 复制代码
  # 评分计算
  python wekws/bin/score.py \
    --config $dir/config.yaml \
    --test_data data_lol/train/data.list \
    --gpu 0 \
    --batch_size 256 \
    --checkpoint $score_checkpoint \
    --score_file $result_dir/score.txt \
    --num_workers 8

  # DET曲线计算
  for keyword in 0 1; do
    python wekws/bin/compute_det.py \
      --keyword $keyword \
      --test_data data_lol/train/data.list \
      --window_shift $window_shift \
      --score_file $result_dir/score.txt \
      --stats_file $result_dir/stats.${keyword}.txt
  done

  # 绘制DET曲线
  python wekws/bin/plot_det_curve.py \
      --keywords_dict dict/words.txt \
      --stats_dir $result_dir \
      --figure_file $result_dir/det.png \
      --ylim 5 \
      --xlim 10
fi

5. 模型导出阶段 (Stage 4)

JIT导出
bash 复制代码
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  result_dir=$dir/test_$(basename $score_checkpoint)
  jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
  python wekws/bin/export_jit.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --jit_model $result_dir/$jit_model
ONNX导出
bash 复制代码
  onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
  python wekws/bin/export_onnx.py \
    --config $dir/config.yaml \
    --checkpoint $score_checkpoint \
    --onnx_model $result_dir/$onnx_model
fi

部署与测试

1. 模型量化

静态量化
python 复制代码
# wekws/bin/static_quantize.py
def static_quantize(config, test_data, checkpoint, quant_model):
    # 加载模型
    model = init_model(configs['model'])
    load_checkpoint(model, checkpoint)
    
    # 准备量化
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    
    # 校准
    with torch.no_grad():
        for data in calib_data:
            model(data)
    
    # 转换
    torch.quantization.convert(model, inplace=True)

2. 流式推理

流式KWS检测
python 复制代码
# wekws/bin/stream_kws_ctc.py
class StreamingKWS:
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.cache = torch.zeros(0, 0, 0)
        
    def process_frame(self, frame):
        # 处理单帧音频
        output, self.cache = self.model(frame, self.cache)
        return output

3. 性能测试

DET曲线分析
  • FRR (False Rejection Rate): 错误拒绝率
  • FAR (False Acceptance Rate): 错误接受率
  • EER (Equal Error Rate): 等错误率

性能对比

Max-Pooling Loss 性能对比

模型 参数量(K) 训练轮数 hi_xiaowen FRR nihao_wenwen FRR
GRU 203 80(avg30) 0.088901 0.083827
TCN 134 80(avg30) 0.023494 0.029884
DS_TCN 287 80(avg30) 0.005357 0.006390
MDTC 156 80(avg10) 0.007142 0.005920
MDTC_Small 31 80(avg10) 0.005357 0.005920

CTC Loss 性能对比 (FAR固定为12小时一次)

模型 损失函数 hi_xiaowen FRR nihao_wenwen FRR
DS_TCN(spec_aug) Max-pooling 0.051217 0.021896
DS_TCN(spec_aug) CTC 0.056574 0.056856
FSMN(spec_aug) CTC 0.031012 0.022460

流式vs非流式性能对比

模型 流式 hi_xiaowen FRR nihao_wenwen FRR
DS_TCN(spec_aug) 0.056574 0.056856
DS_TCN(spec_aug) 0.132694 0.057044
FSMN(spec_aug) 0.031012 0.022460
FSMN(spec_aug) 0.115215 0.020205

总结

WEKWS系统提供了完整的关键词识别解决方案,具有以下优势:

  1. 多模型支持: 提供多种先进的KWS模型架构
  2. 端到端流程: 从数据准备到模型部署的完整流程
  3. 高性能: 在多个数据集上达到业界领先水平
  4. 易用性: 清晰的代码结构和详细的文档
  5. 可扩展性: 支持自定义模型和数据集

推荐使用FSMN-CTC模型作为基础模型进行二次开发,该模型在小数据集KWS任务上表现最为稳健。

相关推荐
ZHENGZJM2 小时前
项目复杂度评估与系列博客大纲生成
系统架构·ai应用
皓晗11 小时前
Whisper-large-v3参数详解:config.yaml与language自动检测机制解析
语音识别·whisper模型·多语言处理
byte轻骑兵20 小时前
【HFP】规范精讲[23]: 蓝牙超宽频语音革命——LC3-SWB编码深度解析,重塑无线通话体验
人工智能·语音识别·蓝牙·hfp·通话
码云社区20 小时前
上门做饭系统架构设计:基于Spring Cloud的微服务实践与源码解析
spring cloud·微服务·系统架构
2603_9547083121 小时前
多微电网系统架构:集群协同与能量互济的网络设计
网络·人工智能·分布式·物联网·架构·系统架构
爱上珍珠的贝壳1 天前
ESP32-S3-CAM:豆包语音识别文字后控制小车(五)——认识L298N驱动模块
人工智能·语音识别·智能硬件·esp32-s3·l298n·减速电机
特力康小冬1 天前
从“看得见”到“喊得出”:智能驱赶防垂钓告警装置如何筑牢输电安全防线?
人工智能·语音识别
herinspace1 天前
管家婆实用帖-如何使用ping命令检测网络环境
网络·数据库·人工智能·学习·excel·语音识别
齐齐大魔王1 天前
智能语音技术(八)
人工智能·语音识别