通用唤醒词识别模型 - Wav2Vec2

1 测试结果:

python
# 前 6 个为合成的负样本,后 4 个为人声录制的正样本
# 测试结果表明:基本满足唤醒词识别需求,但特征样本对模型效果影响较大

pool_sim: 0.9486, dtw_sim: 0.1332, similarity: 0.4594
False
pool_sim: 0.9341, dtw_sim: 0.1379, similarity: 0.4564
False
pool_sim: 0.9414, dtw_sim: 0.1686, similarity: 0.4777
False
pool_sim: 0.9615, dtw_sim: 0.2154, similarity: 0.5138
False
pool_sim: 0.9378, dtw_sim: 0.1925, similarity: 0.4906
False
pool_sim: 0.8796, dtw_sim: 0.1584, similarity: 0.4469
False
pool_sim: 0.9676, dtw_sim: 0.2625, similarity: 0.5446
False
pool_sim: 0.9780, dtw_sim: 1.0000, similarity: 0.9912
True
pool_sim: 0.9615, dtw_sim: 0.2212, similarity: 0.5173
False
pool_sim: 0.9862, dtw_sim: 0.5523, similarity: 0.7259
True

2 模型实现

python
import json
import os

import librosa
import numpy as np
import torch
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
from torch import nn, Tensor
from transformers import Wav2Vec2Processor, Wav2Vec2Model


# Audio encoder: wraps a pretrained Wav2Vec2 model and turns raw audio clips
# into frame-level features and a mean-pooled global vector per clip.
class AudioEncoder(nn.Module):
    def __init__(self,
                 path: str | None = None,
                 sample_rate: int = 16000,
                 max_length: int = 10):
        """Load the pretrained Wav2Vec2 processor and encoder.

        Args:
            path: Local checkpoint path; None downloads the default hub model.
            sample_rate: Sampling rate (Hz) all inputs are assumed to use.
            max_length: Maximum clip length in seconds; longer clips are truncated.
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.max_length = max_length
        if path is None:  # fall back to the public Hugging Face hub checkpoint
            path = r"facebook/wav2vec2-large-960h-lv60-self"
        # Pretrained feature extractor + encoder pair.
        self.processor = Wav2Vec2Processor.from_pretrained(path)
        self.model = Wav2Vec2Model.from_pretrained(path)

    def forward(self,
                audios: list[str | np.ndarray | Tensor],
                return_type: str = "np") -> dict[str, np.ndarray | Tensor]:
        """Encode a batch of audio clips.

        Args:
            audios: A mix of file paths, waveforms (np.ndarray), or tensors.
            return_type: "np" converts outputs to numpy arrays; any other
                value keeps torch tensors.

        Returns:
            dict with "last_hidden_state" (batch, seq_len, hidden) and
            "pooler_output" (batch, hidden) mean-pooled global features.
        """
        processed_audios = []
        for audio in audios:
            if isinstance(audio, str):  # load audio file from disk
                waveform, _ = librosa.load(audio, sr=self.sample_rate)
            else:
                waveform = audio
            processed_audios.append(waveform)
        # Extract features without tracking gradients (inference only).
        with torch.no_grad():
            inputs = self.processor(
                # BUG FIX: previously passed the raw `audios` list here, so
                # file-path strings reached the processor unloaded.
                processed_audios,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
                padding=True,
                max_length=self.sample_rate * self.max_length,
                truncation=True,
            )
            # BUG FIX: move inputs to the model's device so non-CPU DEVICE
            # configurations (e.g. "cuda") do not crash on a device mismatch.
            device = next(self.model.parameters()).device
            inputs = {key: value.to(device) for key, value in inputs.items()}
            outputs = self.model(**inputs)
            last_hidden_state = outputs["last_hidden_state"]  # (batch, seq_len, 1024)
            # Mean pooling over the time axis gives one global vector per clip.
            pooler_output = last_hidden_state.mean(dim=1)  # (batch, 1024)
            if return_type == "np":  # convert to ndarray on request
                last_hidden_state = last_hidden_state.cpu().numpy()
                pooler_output = pooler_output.cpu().numpy()

        return {
            "last_hidden_state": last_hidden_state,
            "pooler_output": pooler_output,
        }


# 唤醒词模型
class WakeWordModel:
    def __init__(self, config: dict = None):
        self.root = os.path.dirname(os.path.abspath(__file__))  # 工作路径
        self.config_path = os.path.join(self.root, "config.json")
        self.config = config or self.load_config()  # 配置项
        self.device = torch.device(self.config["DEVICE"])
        # 音频编码器
        self.audio_encoder = AudioEncoder(
            self.config["AUDIO_ENCODER"],
            sample_rate=self.config["SAMPLE_RATE"],
            max_length=self.config["MAX_LENGTH"],
        ).to(self.device)
        self.audio_encoder.eval()  # 测试模式

    # 检测唤醒词
    def __call__(self,
                 name: str,
                 audio: str | np.ndarray) -> bool:
        info = self.config["WAKE_WORD_INFO"]
        if name not in info:
            raise ValueError(f"未注册的唤醒词:{name}")
        feature = np.load(info[name]["FEATURE"])
        pool_prototype, dtw_prototypes = feature["pool"], feature["dtw"]
        # 提取被检测样本的特征向量
        ret = self.extract_features([audio])
        pool_sim = self.cosine_similarity(ret["pooler_output"][0], pool_prototype)
        # 计算 DTW 距离
        dtw_feat = ret["last_hidden_state"][0]  # (seq_len, 1024)
        min_dtw_dist, prot_len = float("inf"), 0
        for dtw_prototype in dtw_prototypes:  # 逐样本计算对比
            # 使用欧氏距离作为点距离
            distance, _ = fastdtw(dtw_feat, dtw_prototype, dist=euclidean)
            if distance < min_dtw_dist:
                min_dtw_dist = distance
                prot_len = len(dtw_prototype)  # (seq_len,)
        # 计算平均帧距离
        avg_dist = min_dtw_dist / max(len(dtw_feat), prot_len)
        dtw_sim = 1.0 / (1.0 + avg_dist)
        # 加权相似度
        similarity = 0.4 * pool_sim + 0.6 * dtw_sim
        print("pool_sim: {:.4f}, dtw_sim: {:.4f}, similarity: {:.4f}".format(
            pool_sim, dtw_sim, similarity
        ))

        return similarity >= self.config["THRESHOLD"]

    # 更新唤醒词
    def update_wake_word(self,
                         name: str,
                         samples: list[str]):
        min_samples, max_samples = self.config["MIN_SAMPLES"], self.config["MAX_SAMPLES"]
        if len(samples) < min_samples:  # 限制样本数量
            raise ValueError(f"注册唤醒词至少需要 {min_samples} 个音频样本")
        samples = samples[:max_samples]
        # 提取样本的特征向量
        ret = self.extract_features(samples)
        # 计算特征向量的均值
        pool_prototype = np.mean(ret["pooler_output"], axis=0)  # (1024,)
        dtw_prototypes = ret["last_hidden_state"]  # (batch, seq_len, 1024)
        # 更新配置项
        feature_path = os.path.join(self.root, "feature", f"{name}.npz")
        self.config["WAKE_WORD_INFO"][name] = {
            "SAMPLE": samples,
            "FEATURE": feature_path,
        }
        with open(self.config_path, "w+", encoding="utf-8") as file:
            json.dump(self.config, file, ensure_ascii=False)
        # 保存特征文件
        np.savez(feature_path, pool=pool_prototype, dtw=dtw_prototypes)

    # 提取音频的特征向量
    def extract_features(self,
                         audios: list[str | np.ndarray],
                         top_db: float = 40.0) -> dict[str, np.ndarray]:
        processed_audios = []
        for audio in audios:
            if isinstance(audio, str):  # 加载音频文件
                waveform, _ = librosa.load(audio, sr=self.config["SAMPLE_RATE"])
            else:
                waveform = audio
            intervals = librosa.effects.split(waveform, top_db=top_db)
            waveform_trimmed = []  # 消除静音后的音频
            for start, end in intervals:
                waveform_trimmed.extend(waveform[start:end])
            processed_audios.append(np.array(waveform_trimmed))

        return self.audio_encoder(processed_audios)

    # 加载配置项
    def load_config(self) -> dict:
        if not os.path.exists(self.config_path):
            raise FileNotFoundError(f"配置文件不存在:{self.config_path}")
        with open(self.config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # 校验
        required_keys = ["DEVICE", "AUDIO_ENCODER", "SAMPLE_RATE",
                         "MAX_LENGTH", "THRESHOLD", "MIN_SAMPLES",
                         "MAX_SAMPLES", "FEATURE_DIM", "WAKE_WORD_INFO"]
        for key in required_keys:
            if key not in config:
                raise ValueError(f"配置文件中缺少必要的配置项:{key}")

        return config

    # 计算两个向量的余弦相似度
    @staticmethod
    def cosine_similarity(vec1, vec2):
        vec1 = vec1 / np.linalg.norm(vec1)
        vec2 = vec2 / np.linalg.norm(vec2)

        return np.dot(vec1, vec2)

3 训练及验证:

python
if __name__ == '__main__':
    # Register the wake word from seven reference recordings, then run
    # detection over the ten held-out test clips.
    model = WakeWordModel()
    sample_paths = []
    for idx in range(1, 8):
        sample_paths.append(
            r"D:\Project\Transformer\wake_word\sample\你好坤坤"
            + "\\" + str(idx) + ".mp3"
        )
    model.update_wake_word("你好坤坤", sample_paths)
    for idx in range(1, 11):
        detected = model(
            "你好坤坤",
            r"D:\Project\Transformer\wake_word\audio\test" + str(idx) + ".mp3",
        )
        print(detected)

4 训练结果:

同 1 测试结果

相关推荐
z442475326几秒前
CSS如何实现元素悬浮在页面底部_利用fixed定位与底部间距
jvm·数据库·python
m0_596406371 分钟前
mysql数据库用户密码加固策略_实施强密码策略与定期轮换
jvm·数据库·python
m0_676544381 分钟前
CSS如何实现语义化样式编写_使用BEM规范提升命名直观性
jvm·数据库·python
KivenMitnick2 分钟前
CialloVOL 1.2:便捷好用的轻量化内存取证分析平台
windows·python·安全·网络安全·flask·系统安全·安全威胁分析
weixin_669545204 分钟前
支持 18W 快充的 2 节/3 节串联锂电池高效同步升压充电芯片 SW7306
人工智能·单片机·嵌入式硬件·硬件工程
wayz114 分钟前
Day 16:PCA主成分分析与降维
人工智能·算法·机器学习
昇腾CANN5 分钟前
4月28日直播丨基于TorchTitan的DeepSeek-V4昇腾续训练优化实践
人工智能·昇腾·cann·deepseek
他是龙5517 分钟前
70:Python安全 & SSTI模板注入 & Jinja2引擎 & 利用绕过 & 工具实战
开发语言·python·安全
jackyrongvip7 分钟前
快速理解本体论
人工智能·本体论
m0_676544389 分钟前
MySQL数据库迁移后如何测试数据可读性_进行简单查询验证.txt
jvm·数据库·python