TensorFlow的Yes/No 关键词识别模型训练

TensorFlow的Yes/No 关键词识别模型训练

参考 TensorFlow 官方教程的 Yes/No 关键词识别模型训练脚本,可以生成直接替换原有mirco_speech识别模型数据的C文件。

参考来源:https://github.com/tensorflow/tflite-micro/blob/main/tensorflow/lite/micro/examples/micro_speech/train/train_micro_speech_model.ipynb

模型大小20k byte左右

训练脚本speech_trainer.py 使用说明

speech_trainer.py 脚本源码在文章最后

脚本概述

speech_trainer.py 提供一条龙流水线:

  • 自动准备数据集(首次或缺失时自动下载/解压 Speech Commands v0.02)
  • 克隆 tensorflow 仓库并调用官方 train.py 训练
  • 冻结模型为 SavedModel(兼容 v1 风格,内部调用官方 freeze.py
  • 生成 TFLite 浮点与量化模型,并做精度验证
  • 生成可直接用于 MCU 的 C 源文件 micro_speech_quantized_model_data.c

目录与输出位置:

  • dataset/:数据集(自动管理,不会被清理)
  • train/:训练输出(checkpoint 等),支持断点续训
  • logs/:日志与事件文件
  • models/
    • saved_model/:冻结后的 SavedModel
    • micro_speech_quantized.tflite:量化 TFLite 模型
    • micro_speech_float.tflite:浮点 TFLite 模型
    • micro_speech_quantized_model_data.c:TFLite Micro C 数组文件

目录结构

复制代码
train/
├── dataset/                					# 数据集(自动管理)
├── logs/                   					# 训练日志与事件文件
├── models/                 					# 导出模型
│   ├── saved_model/        					# 冻结 SavedModel
│   ├── micro_speech_quantized.tflite       	# 量化 TFLite 模型
│   ├── micro_speech_float.tflite  				# 浮点 TFLite 模型
│   └── micro_speech_quantized_model_data.c     # MCU 可用的 C 数组
├── tensorflow/             					# 克隆的 TF 仓库(含官方 scripts)
└── speech_trainer.py       					# 本脚本

环境要求

  • Python 3.10(Windows 建议 py -3.10
  • 安装依赖:

Windows(PowerShell):

powershell 复制代码
# 安装Python 3.10
winget install -e --id Python.Python.3.10
# 创建Python虚拟环境
py -3.10 -m venv .venvpy310_win
# 进入Python虚拟环境
.venvpy310_win\Scripts\Activate.ps1
# 更新pip
python -m pip install --upgrade pip
# 安装依赖包
pip install -r requirements.txt

Linux/macOS(bash):

bash 复制代码
python3.10 -m venv .venvpy310
source .venvpy310/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt

快速开始(3 步)

  1. 准备环境(见上)并激活 venv
  2. 一条命令启动训练:
bash 复制代码
python train/speech_trainer.py
  1. 结束后在 train/models/ 获取 micro_speech_quantized_model_data.cmicro_speech_quantized.tflite 等产物

数据集管理逻辑(自动)

  • dataset/ 不存在:创建并下载 speech_commands_v0.02.tar.gzdataset/,随后解压。
  • dataset/ 已存在:
    • 若压缩包不存在,则先下载
    • 若未解压(通过是否存在 yes/no/_background_noise_/ 判定)则解压

你无需手动干预,脚本会在流水线开始自动确保数据集就绪。

资源与磁盘建议:

  • 磁盘空间:≥ 3 GB(含数据集与中间文件)
  • 内存:≥ 4 GB(更高内存更稳定)
  • 网络:可访问 storage.googleapis.com(如网络受限请配置代理或手动下载)

命令行参数

  • --wanted_words:要训练的词汇(逗号分隔)。默认:yes,no
  • --training_steps:训练步数字符串(逗号分段)。默认:12000,3000
  • --learning_rate:学习率字符串(逗号分段)。默认:0.001,0.0001
  • --model_architecture:模型架构。可选:single_fcconvlow_latency_convlow_latency_svdftiny_embedding_convtiny_conv(默认 tiny_conv
  • --skip_training:跳过训练,直接下载官方预训练模型并进入后续转换/导出
  • --resume:继续上次训练(保留 train/,自动查找最近 checkpoint 作为 --start_checkpoint
  • --test_env:仅测试环境(依赖/路径检查),不执行训练

说明:

  • 训练时总步数为各分段之和(例如 12000,3000 => 总步数 15000)。
  • 续训时若未找到 checkpoint,将从头开始训练(脚本会提示)。

运行示例

  1. 基础训练(推荐)
bash 复制代码
python train/speech_trainer.py
  1. 指定词表与步数
bash 复制代码
python train/speech_trainer.py \
  --wanted_words yes,no \
  --training_steps 12000,3000 \
  --learning_rate 0.001,0.0001 \
  --model_architecture tiny_conv
  1. 继续上次训练
bash 复制代码
python train/speech_trainer.py --resume
  1. 使用预训练模型(跳过训练)
bash 复制代码
python train/speech_trainer.py --skip_training
  1. 仅测试环境
bash 复制代码
python train/speech_trainer.py --test_env

续训与从零重训

  • 继续训练:使用 --resume,自动寻找 train/ 下步数最大的 *.ckpt-*.index
  • 从零重训:删除 train/logs/ 再运行;或不删目录直接不加 --resume
  • 仅重新导出:删除 models/ 并重跑(会跳过训练,直接冻结与导出)

流水线阶段说明

  1. 确保数据集:自动下载/解压 Speech Commands v0.02
  2. 克隆 tensorflow 仓库(若已存在则跳过)
  3. 调用官方 train.py 进行训练(可续训)
  4. 调用官方 freeze.py 生成 saved_model/
  5. 生成 micro_speech_float.tflitemicro_speech_quantized.tflite(量化),并进行精度评估
  6. 生成 micro_speech_quantized_model_data.c(TFLite Micro C 数组)
  7. 打印各输出文件路径与大小

集成到 MCU 工程

  • 训练完成后,models/micro_speech_quantized_model_data.c 即为可直接集成的模型数据文件。
  • 将其复制到你的工程对应目录(例如 kws/),替换旧模型文件后编译。

常见问题(FAQ)

  • 导出目录已存在:脚本已处理为不预创建 saved_model 子目录,如仍遇到该错误,可手动删除 models/saved_model/ 后重试。
  • 续训未找到 checkpoint:确认 train/ 下存在形如 tiny_conv.ckpt-*.index 文件;否则将从头训练。
  • 数据集下载失败:检查网络或手动下载 speech_commands_v0.02.tar.gzdataset/ 目录后重跑。
  • 量化模型精度下降明显:可适当增大代表性数据采样数量或调整训练步数与学习率。

其它提示:

  • Windows 执行策略:若激活虚拟环境报策略限制,可在管理员 PowerShell 运行:
powershell 复制代码
Set-ExecutionPolicy -Scope CurrentUser RemoteSigned
  • 国内网络下载慢/失败:可预先手动下载 speech_commands_v0.02.tar.gztrain/dataset/ 再运行。
  • TF CPU 指令集提示(SSE/AVX 等):为性能提示,可忽略,不影响功能。

speech_trainer.py代码

python 复制代码
#!/usr/bin/env python3
"""
语音识别模型训练程序
基于 TensorFlow 的简单音频识别模型训练脚本
支持生成 TensorFlow Lite 模型用于微控制器部署
"""

import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path
import locale

# 设置编码,解决Windows中文路径问题
if sys.platform == 'win32':
    import codecs
    # 设置控制台编码
    if sys.stdout.encoding != 'utf-8':
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
    if sys.stderr.encoding != 'utf-8':
        sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# 延迟导入
numpy = None
tensorflow = None

class SpeechRecognitionTrainer:
    def __init__(self, config=None):
        """初始化训练器配置"""
        # 检查并安装依赖
        self._check_and_install_dependencies()
        
        # 全局导入
        global numpy, tensorflow
        import numpy
        import tensorflow
        
        # 默认配置
        self.config = {
            # 训练参数
            'wanted_words': 'yes,no',
            'training_steps': '1000,1000',
            'learning_rate': '0.001,0.0001',
            
            # 模型参数
            'preprocess': 'micro',
            'window_stride': 20,
            'model_architecture': 'tiny_conv',
            
            # 训练控制参数
            'verbosity': 'INFO',
            'eval_step_interval': '1000',
            'save_step_interval': '1000',
            
            # 数据参数
            'sample_rate': 16000,
            'clip_duration_ms': 1000,
            'window_size_ms': 30.0,
            'feature_bin_count': 40,
            'background_frequency': 0.8,
            'background_volume_range': 0.1,
            'time_shift_ms': 100.0,
            'validation_percentage': 10,
            'testing_percentage': 10,
            
            # 量化参数
            'quant_input_min': 0.0,
            'quant_input_max': 26.0,
        }
        
        # 更新配置
        if config:
            self.config.update(config)
            
        # 计算派生参数
        self._calculate_derived_params()
        
        # 设置目录路径
        self._setup_directories()

        # 运行时状态
        self.resume = bool(self.config.get('resume', False))

    def _find_latest_checkpoint(self):
        """返回训练目录下最新的 checkpoint 路径(不含扩展名),找不到则返回 None"""
        from pathlib import Path
        train_dir = Path(self.directories['train'])
        model_prefix = self.config['model_architecture'] + '.ckpt-'
        candidates = []
        for index_file in train_dir.glob(f"{self.config['model_architecture']}.ckpt-*.index"):
            name = index_file.name  # e.g. tiny_conv.ckpt-12345.index
            try:
                step_str = name.split('.ckpt-')[-1].split('.index')[0]
                step = int(step_str)
                candidates.append((step, index_file))
            except Exception:
                continue
        if not candidates:
            return None
        candidates.sort(key=lambda x: x[0], reverse=True)
        latest_index = candidates[0][1]
        # 去掉 .index 扩展名
        return str(latest_index.with_suffix(''))
        
    def _check_and_install_dependencies(self):
        """检查并安装必要的依赖"""
        logger.info("检查并安装依赖...")
        
        required_packages = [
            'numpy',
            'tensorflow',
            'matplotlib',
            'scipy',
        ]
        
        missing_packages = []
        
        for package in required_packages:
            try:
                __import__(package)
                logger.info(f"✓ {package} 已安装")
            except ImportError:
                missing_packages.append(package)
                logger.warning(f"✗ {package} 未安装")
        
        if missing_packages:
            logger.info(f"安装缺失的包: {', '.join(missing_packages)}")
            for package in missing_packages:
                try:
                    subprocess.check_call([
                        sys.executable, '-m', 'pip', 'install', package
                    ])
                    logger.info(f"✓ {package} 安装成功")
                except subprocess.CalledProcessError as e:
                    logger.error(f"✗ {package} 安装失败: {e}")
                    raise
        
        logger.info("依赖检查完成")
        
    def _calculate_derived_params(self):
        """计算派生参数"""
        steps = self.config['training_steps'].split(',')
        self.config['total_steps'] = str(sum(int(step) for step in steps))
        
        number_of_labels = self.config['wanted_words'].count(',') + 1
        number_of_total_labels = number_of_labels + 2
        equal_percentage = int(100.0 / number_of_total_labels)
        
        self.config['silent_percentage'] = equal_percentage
        self.config['unknown_percentage'] = equal_percentage
        self.config['quant_input_range'] = self.config['quant_input_max'] - self.config['quant_input_min']
        
    def _setup_directories(self):
        """设置目录路径 - 使用绝对路径避免相对路径问题"""
        # 获取当前工作目录的绝对路径
        base_dir = os.path.abspath(os.getcwd())
        
        self.directories = {
            'dataset': os.path.join(base_dir, 'dataset'),
            'logs': os.path.join(base_dir, 'logs'),
            'train': os.path.join(base_dir, 'train'),
            'models': os.path.join(base_dir, 'models'),
            'tensorflow_repo': os.path.join(base_dir, 'tensorflow')
        }
        
        # 模型文件路径
        models_dir = self.directories['models']
        self.model_paths = {
            'saved_model': os.path.join(models_dir, 'saved_model'),
            'model_tflite': os.path.join(models_dir, 'micro_speech_quantized.tflite'),
            'float_model_tflite': os.path.join(models_dir, 'micro_speech_float.tflite'),
            'model_tflite_micro': os.path.join(models_dir, 'micro_speech_quantized_model_data.c'),
        }
        
    def clean_previous_data(self):
        """清理之前的训练数据"""
        logger.info("清理之前的训练数据...")
        import shutil

        for name, directory in self.directories.items():
            if name in ('tensorflow_repo', 'dataset'):
                logger.info(f"跳过目录: {name} -> {directory}")
                continue
            if self.resume and name == 'train':
                logger.info(f"检测到 --resume,跳过清理训练目录: {directory}")
                continue

            if os.path.exists(directory):
                try:
                    shutil.rmtree(directory)
                    logger.info(f"已删除目录: {directory}")
                except Exception as e:
                    logger.warning(f"无法删除 {directory}: {e}")

        # 创建必要目录
        os.makedirs(self.directories['models'], exist_ok=True)
        os.makedirs(self.directories['dataset'], exist_ok=True)

    def ensure_dataset(self):
        """确保 dataset 目录存在并包含 speech_commands_v0.02 数据集

        规则:
        1) 若 dataset 不存在,则创建并下载压缩包到其中,然后解压;
        2) 若 dataset 存在:检查是否已有压缩包;若无则下载;随后若未解压则解压。
        """
        import urllib.request
        import tarfile

        dataset_dir = Path(self.directories['dataset'])
        dataset_dir.mkdir(parents=True, exist_ok=True)

        archive_name = 'speech_commands_v0.02.tar.gz'
        archive_path = dataset_dir / archive_name
        data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'

        # 判断是否已解压(通过常见子目录存在性来粗略判断)
        def is_extracted(path: Path) -> bool:
            common_subdirs = ['yes', 'no', '_background_noise_']
            return any((path / sub).exists() for sub in common_subdirs)

        # 如无压缩包则下载
        if not archive_path.exists():
            logger.info(f"下载数据集到: {archive_path}")
            try:
                urllib.request.urlretrieve(data_url, str(archive_path))
            except Exception as e:
                logger.error(f"数据集下载失败: {e}")
                raise

        # 如未解压则解压
        if not is_extracted(dataset_dir):
            logger.info("解压数据集...")
            try:
                with tarfile.open(str(archive_path), 'r:gz') as tar:
                    tar.extractall(str(dataset_dir))
            except Exception as e:
                logger.error(f"解压失败: {e}")
                raise
            logger.info("数据集解压完成")
        
    def setup_tensorflow_repo(self):
        """克隆 TensorFlow 仓库"""
        if not os.path.exists(self.directories['tensorflow_repo']):
            logger.info("克隆 TensorFlow 仓库...")
            subprocess.run([
                'git', 'clone', '-q', '--depth', '1', 
                'https://github.com/tensorflow/tensorflow'
            ], check=True)
            logger.info("TensorFlow 仓库克隆完成")
        else:
            logger.info("TensorFlow 仓库已存在")
            
    def train_model(self):
        """训练模型 - 修复路径问题"""
        logger.info("开始训练模型...")
        logger.info(f"训练词汇: {self.config['wanted_words']}")
        logger.info(f"训练步数: {self.config['training_steps']}")
        logger.info(f"学习率: {self.config['learning_rate']}")
        logger.info(f"总步数: {self.config['total_steps']}")
        
        # 使用Path对象处理路径
        train_script = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands' / 'train.py'
        
        train_cmd = [
            sys.executable,
            str(train_script),  # 转换为字符串
            f"--data_dir={self.directories['dataset']}",
            f"--wanted_words={self.config['wanted_words']}",
            f"--silence_percentage={self.config['silent_percentage']}",
            f"--unknown_percentage={self.config['unknown_percentage']}",
            f"--preprocess={self.config['preprocess']}",
            f"--window_stride={self.config['window_stride']}",
            f"--model_architecture={self.config['model_architecture']}",
            f"--how_many_training_steps={self.config['training_steps']}",
            f"--learning_rate={self.config['learning_rate']}",
            f"--train_dir={self.directories['train']}",
            f"--summaries_dir={self.directories['logs']}",
            f"--verbosity={self.config['verbosity']}",
            f"--eval_step_interval={self.config['eval_step_interval']}",
            f"--save_step_interval={self.config['save_step_interval']}"
        ]

        # 如需继续训练,自动带上最近 checkpoint
        if self.resume:
            latest_ckpt = self._find_latest_checkpoint()
            if latest_ckpt:
                train_cmd.append(f"--start_checkpoint={latest_ckpt}")
                logger.info(f"继续训练:使用最近 checkpoint: {latest_ckpt}")
            else:
                logger.info("--resume 指定但未找到 checkpoint,将从头开始训练")
        
        # 设置环境变量
        env = os.environ.copy()
        speech_commands_path = str(Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands')
        env['PYTHONPATH'] = speech_commands_path + os.pathsep + env.get('PYTHONPATH', '')
        
        # Windows特定:设置编码
        if sys.platform == 'win32':
            env['PYTHONIOENCODING'] = 'utf-8'
        
        try:
            logger.info(f"执行训练命令...")
            result = subprocess.run(
                train_cmd, 
                check=True, 
                env=env,
                capture_output=False,
                text=True,
                encoding='utf-8' if sys.platform == 'win32' else None
            )
            logger.info("模型训练完成")
        except subprocess.CalledProcessError as e:
            logger.error(f"训练失败: {e}")
            raise
            
    def freeze_model(self):
        """冻结模型 - 修复路径和编码问题"""
        logger.info("冻结模型...")
        
        # 确保目标目录存在
        saved_model_dir = Path(self.model_paths['saved_model'])
        saved_model_parent = saved_model_dir.parent
        saved_model_parent.mkdir(parents=True, exist_ok=True)
        
        # 如果saved_model目录已存在,删除它
        if saved_model_dir.exists():
            import shutil
            shutil.rmtree(saved_model_dir)
        # 注意:不要预创建 saved_model 目录或其子目录,
        # 让 SavedModelBuilder 在保存时自行创建,
        # 否则会触发 "Export directory already exists, and isn't empty" 错误。
        
        freeze_script = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands' / 'freeze.py'
        
        # 优先使用最新 checkpoint;若找不到则退回到预期步数
        latest_ckpt = self._find_latest_checkpoint()
        if latest_ckpt:
            checkpoint_path = Path(latest_ckpt)
            logger.info(f"冻结将使用最近 checkpoint: {checkpoint_path}")
        else:
            checkpoint_path = Path(self.directories['train']) / f"{self.config['model_architecture']}.ckpt-{self.config['total_steps']}"
            logger.info(f"未找到最近 checkpoint,尝试使用预期路径: {checkpoint_path}")
        
        freeze_cmd = [
            sys.executable,
            str(freeze_script),
            f"--wanted_words={self.config['wanted_words']}",
            f"--window_stride_ms={self.config['window_stride']}",
            f"--preprocess={self.config['preprocess']}",
            f"--model_architecture={self.config['model_architecture']}",
            f"--start_checkpoint={str(checkpoint_path)}",
            f"--save_format=saved_model",
            f"--output_file={str(saved_model_dir)}"  # 使用绝对路径
        ]
        
        # 设置环境变量
        env = os.environ.copy()
        speech_commands_path = str(Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands')
        env['PYTHONPATH'] = speech_commands_path + os.pathsep + env.get('PYTHONPATH', '')
        
        # Windows特定设置
        if sys.platform == 'win32':
            env['PYTHONIOENCODING'] = 'utf-8'
            env['PYTHONUTF8'] = '1'
        
        try:
            logger.info(f"执行冻结命令...")
            logger.info(f"输出路径: {saved_model_dir}")
            
            # 使用Popen以更好地处理编码
            process = subprocess.Popen(
                freeze_cmd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding='utf-8',
                errors='replace'  # 替换无法解码的字符
            )
            
            stdout, stderr = process.communicate()
            
            if process.returncode != 0:
                logger.error(f"冻结失败,返回码: {process.returncode}")
                if stdout:
                    logger.error(f"标准输出: {stdout}")
                if stderr:
                    logger.error(f"错误输出: {stderr}")
                raise subprocess.CalledProcessError(process.returncode, freeze_cmd)
                
            logger.info("模型冻结完成")
            
        except subprocess.CalledProcessError as e:
            logger.error(f"模型冻结失败: {e}")
            
            # 尝试创建一个简单的saved_model作为备选方案
            logger.info("尝试使用备选方案创建saved_model...")
            self._create_saved_model_fallback()
            
    def _create_saved_model_fallback(self):
        """备选方案:直接从checkpoint创建saved_model"""
        try:
            import tensorflow as tf
            
            logger.info("使用备选方案创建saved_model...")
            
            # 添加speech_commands路径
            speech_commands_path = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands'
            if str(speech_commands_path) not in sys.path:
                sys.path.insert(0, str(speech_commands_path))
            
            import models
            import input_data
            
            # 准备模型设置
            model_settings = models.prepare_model_settings(
                len(input_data.prepare_words_list(self.config['wanted_words'].split(','))),
                self.config['sample_rate'], 
                self.config['clip_duration_ms'], 
                self.config['window_size_ms'],
                self.config['window_stride'], 
                self.config['feature_bin_count'], 
                self.config['preprocess']
            )
            
            # 重置默认图
            tf.compat.v1.reset_default_graph()
            
            with tf.compat.v1.Session() as sess:
                # 创建占位符
                fingerprint_size = model_settings['fingerprint_size']
                fingerprint_input = tf.compat.v1.placeholder(
                    tf.float32, [None, fingerprint_size], name='fingerprint_input'
                )
                
                # 构建模型(is_training=False 时只返回 logits 张量)
                logits = models.create_model(
                    fingerprint_input,
                    model_settings,
                    self.config['model_architecture'],
                    is_training=False
                )
                
                # 添加并命名 softmax 输出
                labels_softmax = tf.nn.softmax(logits, name='labels_softmax')
                
                # 恢复权重
                # 备选方案同样选择最近 checkpoint
                latest_ckpt = self._find_latest_checkpoint()
                if latest_ckpt:
                    checkpoint_path = Path(latest_ckpt)
                else:
                    checkpoint_path = Path(self.directories['train']) / f"{self.config['model_architecture']}.ckpt-{self.config['total_steps']}"
                saver = tf.compat.v1.train.Saver()
                saver.restore(sess, str(checkpoint_path))
                
                # 保存模型
                saved_model_path = Path(self.model_paths['saved_model'])
                
                # 使用TensorFlow 1.x的SavedModelBuilder
                builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(str(saved_model_path))
                
                # 定义签名
                inputs = {'fingerprint_input': tf.compat.v1.saved_model.utils.build_tensor_info(fingerprint_input)}
                outputs = {'labels_softmax': tf.compat.v1.saved_model.utils.build_tensor_info(labels_softmax)}
                
                signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                    inputs=inputs,
                    outputs=outputs,
                    method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME
                )
                
                builder.add_meta_graph_and_variables(
                    sess,
                    [tf.compat.v1.saved_model.tag_constants.SERVING],
                    signature_def_map={
                        tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
                    }
                )
                
                builder.save()
                logger.info(f"备选方案:saved_model已创建在 {saved_model_path}")
                
        except Exception as e:
            logger.error(f"备选方案也失败了: {e}")
            logger.error("建议尝试以下方法:")
            logger.error("1. 降级TensorFlow版本到2.13或更早")
            logger.error("2. 使用--skip_training选项下载预训练模型")
            raise
            
    def convert_to_tflite(self):
        """转换为 TensorFlow Lite 模型"""
        logger.info("转换为 TensorFlow Lite 模型...")
        
        # 添加路径
        speech_commands_path = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands'
        if str(speech_commands_path) not in sys.path:
            sys.path.insert(0, str(speech_commands_path))
        
        try:
            import input_data
            import models
            import numpy as np
            import tensorflow as tf
            
            # 准备模型设置
            model_settings = models.prepare_model_settings(
                len(input_data.prepare_words_list(self.config['wanted_words'].split(','))),
                self.config['sample_rate'], 
                self.config['clip_duration_ms'], 
                self.config['window_size_ms'],
                self.config['window_stride'], 
                self.config['feature_bin_count'], 
                self.config['preprocess']
            )
            
            # 创建音频处理器
            data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'
            audio_processor = input_data.AudioProcessor(
                data_url, self.directories['dataset'],
                self.config['silent_percentage'], 
                self.config['unknown_percentage'],
                self.config['wanted_words'].split(','), 
                self.config['validation_percentage'],
                self.config['testing_percentage'], 
                model_settings, 
                self.directories['logs']
            )
            
            with tf.compat.v1.Session() as sess:
                # 生成浮点模型
                logger.info("生成浮点 TensorFlow Lite 模型...")
                float_converter = tf.lite.TFLiteConverter.from_saved_model(
                    str(Path(self.model_paths['saved_model']))
                )
                float_tflite_model = float_converter.convert()
                
                with open(self.model_paths['float_model_tflite'], 'wb') as f:
                    float_model_size = f.write(float_tflite_model)
                logger.info(f"浮点模型大小: {float_model_size} 字节")
                
                # 生成量化模型
                logger.info("生成量化 TensorFlow Lite 模型...")
                converter = tf.lite.TFLiteConverter.from_saved_model(
                    str(Path(self.model_paths['saved_model']))
                )
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.inference_input_type = tf.int8
                converter.inference_output_type = tf.int8
                
                # 代表性数据集生成器
                def representative_dataset_gen():
                    for i in range(100):
                        data, _ = audio_processor.get_data(
                            1, i*1, model_settings,
                            self.config['background_frequency'],
                            self.config['background_volume_range'],
                            self.config['time_shift_ms'],
                            'testing', sess
                        )
                        flattened_data = np.array(
                            data.flatten(), dtype=np.float32
                        ).reshape(1, 1960)
                        yield [flattened_data]
                        
                converter.representative_dataset = representative_dataset_gen
                tflite_model = converter.convert()
                
                with open(self.model_paths['model_tflite'], 'wb') as f:
                    quantized_model_size = f.write(tflite_model)
                logger.info(f"量化模型大小: {quantized_model_size} 字节")
                
            return audio_processor, model_settings
            
        except Exception as e:
            logger.error(f"TensorFlow Lite 转换失败: {e}")
            raise
            
    def test_tflite_accuracy(self, audio_processor, model_settings):
        """测试 TensorFlow Lite 模型精度"""
        logger.info("测试模型精度...")
        
        import numpy as np
        import tensorflow as tf
        
        def run_tflite_inference(tflite_model_path, model_type="Float"):
            # 加载测试数据
            np.random.seed(0)
            with tf.compat.v1.Session() as sess:
                test_data, test_labels = audio_processor.get_data(
                    -1, 0, model_settings, 
                    self.config['background_frequency'],
                    self.config['background_volume_range'],
                    self.config['time_shift_ms'], 
                    'testing', sess
                )
            test_data = np.expand_dims(test_data, axis=1).astype(np.float32)
            
            # 初始化解释器
            interpreter = tf.lite.Interpreter(
                tflite_model_path,
                experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF
            )
            interpreter.allocate_tensors()
            
            input_details = interpreter.get_input_details()[0]
            output_details = interpreter.get_output_details()[0]
            
            # 对于量化模型,手动将输入数据从浮点转换为整数
            if model_type == "Quantized":
                input_scale, input_zero_point = input_details["quantization"]
                test_data = test_data / input_scale + input_zero_point
                test_data = test_data.astype(input_details["dtype"])
                
            correct_predictions = 0
            for i in range(len(test_data)):
                interpreter.set_tensor(input_details["index"], test_data[i])
                interpreter.invoke()
                output = interpreter.get_tensor(output_details["index"])[0]
                top_prediction = output.argmax()
                correct_predictions += (top_prediction == test_labels[i])
                
            accuracy = (correct_predictions * 100) / len(test_data)
            logger.info(f'{model_type} 模型精度: {accuracy:.2f}% (测试样本数={len(test_data)})')
            return accuracy
            
        # 测试模型
        try:
            float_accuracy = run_tflite_inference(self.model_paths['float_model_tflite'])
        except Exception as e:
            logger.error(f"浮点模型测试失败: {e}")
            float_accuracy = None
        
        try:
            quantized_accuracy = run_tflite_inference(
                self.model_paths['model_tflite'], 
                model_type='Quantized'
            )
        except Exception as e:
            logger.error(f"量化模型测试失败: {e}")
            quantized_accuracy = None
        
        return float_accuracy, quantized_accuracy
        
    def generate_micro_model(self):
        """生成微控制器 C 源文件"""
        logger.info("生成微控制器 C 源文件...")
        
        try:
            # 直接使用Python实现
            self._generate_c_file_python()
            logger.info(f"C 源文件已生成: {self.model_paths['model_tflite_micro']}")
        except Exception as e:
            logger.error(f"生成 C 源文件失败: {e}")
            raise
        
    def _generate_c_file_python(self):
        """使用 Python 生成 C 文件"""
        with open(self.model_paths['model_tflite'], 'rb') as f:
            model_data = f.read()
            
        # 生成 C 数组
        c_content = []
        c_content.append('/* Automatically generated by speech_trainer.py */')
        c_content.append('#include "micro_speech_quantized_model_data.h"')
        c_content.append('')
        c_content.append('const unsigned char micro_speech_quantized_tflite[] = {')
        
        # 将字节数据转换为 C 数组格式
        hex_values = []
        for i, byte in enumerate(model_data):
            if i % 16 == 0:
                hex_values.append('\n  ')
            hex_values.append(f'0x{byte:02x}')
            if i < len(model_data) - 1:
                hex_values.append(',')
            if (i + 1) % 16 != 0 and i < len(model_data) - 1:
                hex_values.append(' ')
        
        c_content.append(''.join(hex_values))
        c_content.append('\n};')
        c_content.append(f'const unsigned int micro_speech_quantized_tflite_len = {len(model_data)};')
        c_content.append('')
        
        # 写入文件
        with open(self.model_paths['model_tflite_micro'], 'w', encoding='utf-8') as f:
            f.write('\n'.join(c_content))
            
    def download_pretrained_model(self):
        """下载预训练模型"""
        logger.info("下载预训练模型...")
        
        try:
            import urllib.request
            import tarfile
            
            model_url = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_micro_train_2020_05_10.tgz"
            model_file = "speech_micro_train_2020_05_10.tgz"
            
            logger.info(f"从 {model_url} 下载模型...")
            urllib.request.urlretrieve(model_url, model_file)
            
            logger.info("解压模型文件...")
            with tarfile.open(model_file, 'r:gz') as tar:
                tar.extractall('.')
                
            os.remove(model_file)
            
            logger.info("预训练模型下载完成")
            return True
            
        except Exception as e:
            logger.error(f"下载预训练模型失败: {e}")
            return False
        
    def print_model_info(self):
        """打印模型信息"""
        logger.info("\n" + "="*50)
        logger.info("模型训练完成!")
        logger.info("="*50)
        
        for name, path in self.model_paths.items():
            if os.path.exists(path):
                if os.path.isfile(path):
                    size = os.path.getsize(path)
                    logger.info(f"{name}: {path} ({size} 字节)")
                else:
                    logger.info(f"{name}: {path} (目录)")
            else:
                logger.info(f"{name}: {path} (未生成)")
                
        logger.info("\n部署到微控制器:")
        logger.info("1. 参考 TensorFlow Lite Micro 文档")
        logger.info("2. 更新 micro_model_settings.h 中的 kCategoryCount 和 kCategoryLabels")
        logger.info("3. 使用生成的 micro_speech_quantized_model_data.c 文件替换原有模型文件")
        
    def run_full_pipeline(self, skip_training=False):
        """运行完整的训练流水线"""
        try:
            logger.info("开始完整的模型训练流水线...")
            
            # 确保数据集就绪
            self.ensure_dataset()

            # 清理之前的数据
            self.clean_previous_data()
            
            if skip_training:
                success = self.download_pretrained_model()
                if not success:
                    logger.error("无法下载预训练模型,将执行完整训练")
                    skip_training = False
            
            if not skip_training:
                # 设置 TensorFlow 仓库
                self.setup_tensorflow_repo()
                
                # 训练模型
                self.train_model()
                
                # 冻结模型
                self.freeze_model()
            else:
                logger.info("使用预训练模型,跳过训练步骤")
            
            # 转换为 TensorFlow Lite
            audio_processor, model_settings = self.convert_to_tflite()
            
            # 测试模型精度
            self.test_tflite_accuracy(audio_processor, model_settings)
            
            # 生成微控制器模型
            self.generate_micro_model()
            
            # 打印模型信息
            self.print_model_info()
            
            logger.info("训练流水线完成!")
            
        except KeyboardInterrupt:
            logger.info("用户中断训练")
            raise
        except Exception as e:
            logger.error(f"训练流水线失败: {e}")
            logger.error("可能的解决方案:")
            logger.error("1. 检查网络连接(下载数据集需要)")
            logger.error("2. 确保有足够的磁盘空间")
            logger.error("3. 检查 Python 环境和依赖")
            logger.error("4. 尝试使用 --skip_training 下载预训练模型")
            raise

def main():
    """主函数"""
    parser = argparse.ArgumentParser(description='语音识别模型训练程序')
    parser.add_argument('--wanted_words', default='yes,no', 
                       help='要训练的词汇,用逗号分隔 (默认: yes,no)')
    parser.add_argument('--training_steps', default='12000,3000',
                       help='训练步数,用逗号分隔 (默认: 12000,3000)')
    parser.add_argument('--learning_rate', default='0.001,0.0001',
                       help='学习率,用逗号分隔 (默认: 0.001,0.0001)')
    parser.add_argument('--model_architecture', default='tiny_conv',
                       choices=['single_fc', 'conv', 'low_latency_conv', 
                               'low_latency_svdf', 'tiny_embedding_conv', 'tiny_conv'],
                       help='模型架构 (默认: tiny_conv)')
    parser.add_argument('--skip_training', action='store_true',
                       help='跳过训练,使用预训练模型')
    parser.add_argument('--resume', action='store_true',
                       help='继续上次训练:保留 train 目录并从最近 checkpoint 恢复')
    parser.add_argument('--test_env', action='store_true',
                       help='仅测试环境,不执行训练')
    
    args = parser.parse_args()
    
    if args.test_env:
        # 仅测试环境
        logger.info("测试环境配置...")
        try:
            trainer = SpeechRecognitionTrainer()
            logger.info("✓ 环境测试通过")
            logger.info("✓ 所有依赖已正确安装")
            logger.info("可以开始训练模型")
        except Exception as e:
            logger.error(f"✗ 环境测试失败: {e}")
            sys.exit(1)
        return
    
    # 配置训练参数
    config = {
        'wanted_words': args.wanted_words,
        'training_steps': args.training_steps,
        'learning_rate': args.learning_rate,
        'model_architecture': args.model_architecture,
        'resume': args.resume,
    }
    
    # 创建训练器
    trainer = SpeechRecognitionTrainer(config)
    
    # 运行完整训练流水线
    trainer.run_full_pipeline(skip_training=args.skip_training)

if __name__ == '__main__':
    main()
相关推荐
哔哩哔哩技术8 小时前
AniSoraV3 正式开源,长视频创作智能体框架AniME技术揭秘
人工智能
云卓SKYDROID8 小时前
无人机GPS悬停模块技术解析
人工智能·目标跟踪·无人机·高科技·航线系统
百锦再8 小时前
Python:AI开发第一语言的全面剖析
java·开发语言·人工智能·python·sql·ai·radis
Virgil1398 小时前
如何正确使用ChatGPT做数学建模比赛——数学建模AI使用技巧
人工智能·数学建模·chatgpt
绿豆_13148 小时前
playwright+python UI自动化测试中实现图片颜色和像素对比
自动化测试·python·opencv·计算机视觉·playwirght
IT_陈寒8 小时前
Python性能优化:这5个隐藏技巧让我的代码提速300%!
前端·人工智能·后端
盗理者8 小时前
Python 工具: Windows 带宽监控工具
开发语言·windows·python
大翻哥哥8 小时前
Python 2025:量子计算、区块链与边缘计算的新前沿
python·区块链·量子计算
NG WING YIN8 小时前
量子電腦組裝之一
人工智能·深度学习·软件工程