# Training a Yes/No Keyword Spotting Model with TensorFlow

Based on the official TensorFlow Speech Commands tutorial, the training script described here produces a C file that can directly replace the model data of the original micro_speech example. The resulting model is around 20 KB.

## Using the training script `speech_trainer.py`

The full source of `speech_trainer.py` is included at the end of this article.
### Script overview

`speech_trainer.py` provides an end-to-end pipeline:

- Automatically prepares the dataset (downloads and extracts Speech Commands v0.02 on first run or when missing)
- Clones the `tensorflow` repository and runs the official `train.py` for training
- Freezes the model into a SavedModel (v1-style, using the official `freeze.py` internally)
- Generates float and quantized TFLite models and verifies their accuracy
- Generates an MCU-ready C source file, `micro_speech_quantized_model_data.c`
Directories and outputs:

- `dataset/`: the dataset (managed automatically, never cleaned up)
- `train/`: training output (checkpoints etc.), supports resuming
- `logs/`: logs and event files
- `models/`:
  - `saved_model/`: the frozen SavedModel
  - `micro_speech_quantized.tflite`: quantized TFLite model
  - `micro_speech_float.tflite`: float TFLite model
  - `micro_speech_quantized_model_data.c`: TFLite Micro C array file
### Directory structure

```
train/
├── dataset/                                 # dataset (managed automatically)
├── logs/                                    # training logs and event files
├── models/                                  # exported models
│   ├── saved_model/                         # frozen SavedModel
│   ├── micro_speech_quantized.tflite        # quantized TFLite model
│   ├── micro_speech_float.tflite            # float TFLite model
│   └── micro_speech_quantized_model_data.c  # MCU-ready C array
├── tensorflow/                              # cloned TF repository (official scripts)
└── speech_trainer.py                        # this script
```
### Environment requirements

- Python 3.10 (on Windows, use `py -3.10`)
- Install the dependencies:

Windows (PowerShell):

```powershell
# Install Python 3.10
winget install -e --id Python.Python.3.10
# Create a virtual environment
py -3.10 -m venv .venvpy310_win
# Activate the virtual environment
.venvpy310_win\Scripts\Activate.ps1
# Upgrade pip
python -m pip install --upgrade pip
# Install the dependency packages
pip install -r requirements.txt
```
Linux/macOS (bash):

```bash
python3.10 -m venv .venvpy310
source .venvpy310/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
```
### Quick start (3 steps)

- Prepare the environment (see above) and activate the venv
- Start training with a single command:

```bash
python train/speech_trainer.py
```

- When it finishes, collect `micro_speech_quantized_model_data.c`, `micro_speech_quantized.tflite`, and the other artifacts from `train/models/`
### Dataset management (automatic)

- If `dataset/` does not exist: it is created, `speech_commands_v0.02.tar.gz` is downloaded into it, and the archive is extracted.
- If `dataset/` already exists:
  - the archive is downloaded first if it is missing
  - if not yet extracted (detected via the presence of `yes/`, `no/`, and `_background_noise_/`), it is extracted

No manual intervention is needed; the script ensures the dataset is ready at the start of the pipeline.
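The extraction check can be sketched as follows (a minimal illustration mirroring the heuristic the script uses; the `dataset` path here is hypothetical):

```python
from pathlib import Path

# The dataset counts as extracted when any of the well-known
# Speech Commands subdirectories is already present.
def is_extracted(dataset_dir: Path) -> bool:
    common_subdirs = ['yes', 'no', '_background_noise_']
    return any((dataset_dir / sub).exists() for sub in common_subdirs)

print(is_extracted(Path('dataset')))  # False unless a dataset is present there
```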
Resource and disk recommendations:

- Disk space: ≥ 3 GB (including the dataset and intermediate files)
- Memory: ≥ 4 GB (more memory is more stable)
- Network: access to `storage.googleapis.com` (configure a proxy or download manually if your network is restricted)
### Command-line arguments

- `--wanted_words`: words to train on (comma-separated). Default: `yes,no`
- `--training_steps`: training steps as a comma-separated string. Default: `12000,3000`
- `--learning_rate`: learning rates as a comma-separated string. Default: `0.001,0.0001`
- `--model_architecture`: model architecture. Choices: `single_fc`, `conv`, `low_latency_conv`, `low_latency_svdf`, `tiny_embedding_conv`, `tiny_conv` (default: `tiny_conv`)
- `--skip_training`: skip training, download the official pretrained model, and proceed directly to conversion/export
- `--resume`: resume the previous run (keeps `train/` and automatically uses the most recent checkpoint as `--start_checkpoint`)
- `--test_env`: only test the environment (dependency/path checks), without training
Notes:

- The total number of training steps is the sum of the segments (e.g. `12000,3000` => `15000` total steps).
- When resuming, if no checkpoint is found, training starts from scratch (the script prints a notice).
### Example runs

- Basic training (recommended):

```bash
python train/speech_trainer.py
```

- Specifying the word list and step counts:

```bash
python train/speech_trainer.py \
    --wanted_words yes,no \
    --training_steps 12000,3000 \
    --learning_rate 0.001,0.0001 \
    --model_architecture tiny_conv
```

- Resuming the previous run:

```bash
python train/speech_trainer.py --resume
```

- Using the pretrained model (skip training):

```bash
python train/speech_trainer.py --skip_training
```

- Environment test only:

```bash
python train/speech_trainer.py --test_env
```
### Resuming vs. retraining from scratch

- Resume: use `--resume`; the script automatically picks the `*.ckpt-*.index` file with the highest step number under `train/`
- Retrain from scratch: delete `train/` and `logs/` and rerun; or keep the directories and simply omit `--resume`
- Re-export only: delete `models/` and rerun (training is skipped; freezing and export run directly)
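The checkpoint selection rule can be sketched like this (a minimal illustration of the same filename parsing the script performs; the filenames below are hypothetical examples):

```python
# Pick the checkpoint with the highest step number from TF-style
# index files named like "tiny_conv.ckpt-<step>.index".
index_files = [
    "tiny_conv.ckpt-1000.index",
    "tiny_conv.ckpt-12000.index",
    "tiny_conv.ckpt-15000.index",
]

def step_of(name: str) -> int:
    # "tiny_conv.ckpt-15000.index" -> 15000
    return int(name.split(".ckpt-")[-1].split(".index")[0])

latest = max(index_files, key=step_of)
checkpoint = latest[: -len(".index")]  # strip the .index extension

print(checkpoint)  # tiny_conv.ckpt-15000
```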
### Pipeline stages

- Ensure the dataset: download/extract Speech Commands v0.02 automatically
- Clone the `tensorflow` repository (skipped if it already exists)
- Run the official `train.py` for training (resumable)
- Run the official `freeze.py` to produce `saved_model/`
- Generate `micro_speech_float.tflite` and the quantized `micro_speech_quantized.tflite`, then evaluate their accuracy
- Generate `micro_speech_quantized_model_data.c` (TFLite Micro C array)
- Print the path and size of each output file
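The final C-array step can be sketched as follows (a simplified standalone version of the conversion; the array name matches the one the script emits, while the byte layout of this sketch is illustrative):

```python
# Convert a TFLite flatbuffer into a C array, 12 bytes per line.
def to_c_array(model_bytes: bytes, name: str = "micro_speech_quantized_tflite") -> str:
    lines = []
    for i in range(0, len(model_bytes), 12):
        chunk = model_bytes[i:i + 12]
        lines.append("  " + ", ".join(f"0x{b:02x}" for b in chunk) + ",")
    return (
        f"const unsigned char {name}[] = {{\n"
        + "\n".join(lines)
        + f"\n}};\nconst unsigned int {name}_len = {len(model_bytes)};\n"
    )

# First bytes of a TFLite flatbuffer, just for demonstration
print(to_c_array(b"\x1c\x00\x00\x00TFL3"))
```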
### Integrating into an MCU project

- After training, `models/micro_speech_quantized_model_data.c` is the model data file ready for direct integration.
- Copy it into the matching directory of your project (e.g. `kws/`), replace the old model file, and rebuild.
### FAQ

- Export directory already exists: the script no longer pre-creates the `saved_model` subdirectory; if the error still occurs, delete `models/saved_model/` manually and retry.
- Resume finds no checkpoint: make sure `train/` contains files named like `tiny_conv.ckpt-*.index`; otherwise training starts from scratch.
- Dataset download fails: check the network, or manually download `speech_commands_v0.02.tar.gz` into `dataset/` and rerun.
- Noticeable accuracy drop after quantization: increase the number of representative-dataset samples, or adjust the training steps and learning rates.
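For context on the quantized accuracy check: float features must be mapped into int8 using the model's input quantization parameters before inference. A minimal sketch (the script casts with `astype`, while this sketch rounds explicitly; the scale and zero point below are hypothetical values for illustration):

```python
import numpy as np

# TFLite full-integer convention: q = round(x / scale) + zero_point
scale, zero_point = 0.1, -128  # hypothetical quantization parameters

x = np.array([0.0, 1.0, 2.5], dtype=np.float32)
q = np.round(x / scale + zero_point).astype(np.int8)
print(q)
```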
Other tips:

- Windows execution policy: if activating the venv fails with a policy error, run this in an administrator PowerShell:

```powershell
Set-ExecutionPolicy -Scope CurrentUser RemoteSigned
```

- Slow or failing downloads: pre-download `speech_commands_v0.02.tar.gz` into `train/dataset/` before running.
- TF CPU instruction-set messages (SSE/AVX etc.): performance hints only; safe to ignore.
## speech_trainer.py source

```python
#!/usr/bin/env python3
"""
Speech recognition model trainer.
A simple audio recognition training script based on TensorFlow.
Generates TensorFlow Lite models for microcontroller deployment.
"""
import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path

# Force UTF-8 console streams to avoid encoding errors with
# non-ASCII paths on Windows
if sys.platform == 'win32':
    import codecs
    if sys.stdout.encoding != 'utf-8':
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout.buffer, 'strict')
    if sys.stderr.encoding != 'utf-8':
        sys.stderr = codecs.getwriter('utf-8')(sys.stderr.buffer, 'strict')

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Deferred imports (resolved after the dependency check)
numpy = None
tensorflow = None


class SpeechRecognitionTrainer:
    def __init__(self, config=None):
        """Initialize the trainer configuration."""
        # Check and install dependencies first
        self._check_and_install_dependencies()
        # Import globally now that dependencies are available
        global numpy, tensorflow
        import numpy
        import tensorflow
        # Default configuration
        self.config = {
            # Training parameters
            'wanted_words': 'yes,no',
            'training_steps': '1000,1000',
            'learning_rate': '0.001,0.0001',
            # Model parameters
            'preprocess': 'micro',
            'window_stride': 20,
            'model_architecture': 'tiny_conv',
            # Training control parameters
            'verbosity': 'INFO',
            'eval_step_interval': '1000',
            'save_step_interval': '1000',
            # Data parameters
            'sample_rate': 16000,
            'clip_duration_ms': 1000,
            'window_size_ms': 30.0,
            'feature_bin_count': 40,
            'background_frequency': 0.8,
            'background_volume_range': 0.1,
            'time_shift_ms': 100.0,
            'validation_percentage': 10,
            'testing_percentage': 10,
            # Quantization parameters
            'quant_input_min': 0.0,
            'quant_input_max': 26.0,
        }
        # Apply user overrides
        if config:
            self.config.update(config)
        # Compute derived parameters
        self._calculate_derived_params()
        # Set up directory paths
        self._setup_directories()
        # Runtime state
        self.resume = bool(self.config.get('resume', False))

    def _find_latest_checkpoint(self):
        """Return the newest checkpoint path (without extension) under the
        training directory, or None when no checkpoint is found."""
        train_dir = Path(self.directories['train'])
        candidates = []
        for index_file in train_dir.glob(f"{self.config['model_architecture']}.ckpt-*.index"):
            name = index_file.name  # e.g. tiny_conv.ckpt-12345.index
            try:
                step_str = name.split('.ckpt-')[-1].split('.index')[0]
                candidates.append((int(step_str), index_file))
            except Exception:
                continue
        if not candidates:
            return None
        candidates.sort(key=lambda x: x[0], reverse=True)
        latest_index = candidates[0][1]
        # Drop the .index extension
        return str(latest_index.with_suffix(''))

    def _check_and_install_dependencies(self):
        """Check for and install the required dependencies."""
        logger.info("Checking and installing dependencies...")
        required_packages = [
            'numpy',
            'tensorflow',
            'matplotlib',
            'scipy',
        ]
        missing_packages = []
        for package in required_packages:
            try:
                __import__(package)
                logger.info(f"✓ {package} is installed")
            except ImportError:
                missing_packages.append(package)
                logger.warning(f"✗ {package} is not installed")
        if missing_packages:
            logger.info(f"Installing missing packages: {', '.join(missing_packages)}")
            for package in missing_packages:
                try:
                    subprocess.check_call([
                        sys.executable, '-m', 'pip', 'install', package
                    ])
                    logger.info(f"✓ {package} installed successfully")
                except subprocess.CalledProcessError as e:
                    logger.error(f"✗ failed to install {package}: {e}")
                    raise
        logger.info("Dependency check finished")

    def _calculate_derived_params(self):
        """Compute derived parameters."""
        steps = self.config['training_steps'].split(',')
        self.config['total_steps'] = str(sum(int(step) for step in steps))
        number_of_labels = self.config['wanted_words'].count(',') + 1
        # The label set also contains "silence" and "unknown"
        number_of_total_labels = number_of_labels + 2
        equal_percentage = int(100.0 / number_of_total_labels)
        self.config['silent_percentage'] = equal_percentage
        self.config['unknown_percentage'] = equal_percentage
        self.config['quant_input_range'] = self.config['quant_input_max'] - self.config['quant_input_min']

    def _setup_directories(self):
        """Set up directory paths, using absolute paths to avoid
        relative-path issues."""
        base_dir = os.path.abspath(os.getcwd())
        self.directories = {
            'dataset': os.path.join(base_dir, 'dataset'),
            'logs': os.path.join(base_dir, 'logs'),
            'train': os.path.join(base_dir, 'train'),
            'models': os.path.join(base_dir, 'models'),
            'tensorflow_repo': os.path.join(base_dir, 'tensorflow')
        }
        # Model file paths
        models_dir = self.directories['models']
        self.model_paths = {
            'saved_model': os.path.join(models_dir, 'saved_model'),
            'model_tflite': os.path.join(models_dir, 'micro_speech_quantized.tflite'),
            'float_model_tflite': os.path.join(models_dir, 'micro_speech_float.tflite'),
            'model_tflite_micro': os.path.join(models_dir, 'micro_speech_quantized_model_data.c'),
        }

    def clean_previous_data(self):
        """Clean up data from previous runs."""
        logger.info("Cleaning up previous training data...")
        import shutil
        for name, directory in self.directories.items():
            if name in ('tensorflow_repo', 'dataset'):
                logger.info(f"Skipping directory: {name} -> {directory}")
                continue
            if self.resume and name == 'train':
                logger.info(f"--resume given, keeping training directory: {directory}")
                continue
            if os.path.exists(directory):
                try:
                    shutil.rmtree(directory)
                    logger.info(f"Removed directory: {directory}")
                except Exception as e:
                    logger.warning(f"Could not remove {directory}: {e}")
        # Create the directories that are always needed
        os.makedirs(self.directories['models'], exist_ok=True)
        os.makedirs(self.directories['dataset'], exist_ok=True)

    def ensure_dataset(self):
        """Ensure the dataset directory exists and contains speech_commands_v0.02.

        Rules:
        1) If dataset/ does not exist, create it, download the archive into it,
           then extract it.
        2) If dataset/ exists: download the archive if missing, then extract it
           if not yet extracted.
        """
        import urllib.request
        import tarfile
        dataset_dir = Path(self.directories['dataset'])
        dataset_dir.mkdir(parents=True, exist_ok=True)
        archive_name = 'speech_commands_v0.02.tar.gz'
        archive_path = dataset_dir / archive_name
        data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'

        # Roughly detect extraction via well-known subdirectories
        def is_extracted(path: Path) -> bool:
            common_subdirs = ['yes', 'no', '_background_noise_']
            return any((path / sub).exists() for sub in common_subdirs)

        # Download the archive when missing
        if not archive_path.exists():
            logger.info(f"Downloading dataset to: {archive_path}")
            try:
                urllib.request.urlretrieve(data_url, str(archive_path))
            except Exception as e:
                logger.error(f"Dataset download failed: {e}")
                raise
        # Extract when not yet extracted
        if not is_extracted(dataset_dir):
            logger.info("Extracting dataset...")
            try:
                with tarfile.open(str(archive_path), 'r:gz') as tar:
                    tar.extractall(str(dataset_dir))
            except Exception as e:
                logger.error(f"Extraction failed: {e}")
                raise
            logger.info("Dataset extraction finished")

    def setup_tensorflow_repo(self):
        """Clone the TensorFlow repository."""
        if not os.path.exists(self.directories['tensorflow_repo']):
            logger.info("Cloning the TensorFlow repository...")
            subprocess.run([
                'git', 'clone', '-q', '--depth', '1',
                'https://github.com/tensorflow/tensorflow'
            ], check=True)
            logger.info("TensorFlow repository cloned")
        else:
            logger.info("TensorFlow repository already exists")

    def train_model(self):
        """Train the model."""
        logger.info("Starting model training...")
        logger.info(f"Wanted words: {self.config['wanted_words']}")
        logger.info(f"Training steps: {self.config['training_steps']}")
        logger.info(f"Learning rates: {self.config['learning_rate']}")
        logger.info(f"Total steps: {self.config['total_steps']}")
        train_script = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands' / 'train.py'
        train_cmd = [
            sys.executable,
            str(train_script),
            f"--data_dir={self.directories['dataset']}",
            f"--wanted_words={self.config['wanted_words']}",
            f"--silence_percentage={self.config['silent_percentage']}",
            f"--unknown_percentage={self.config['unknown_percentage']}",
            f"--preprocess={self.config['preprocess']}",
            f"--window_stride={self.config['window_stride']}",
            f"--model_architecture={self.config['model_architecture']}",
            f"--how_many_training_steps={self.config['training_steps']}",
            f"--learning_rate={self.config['learning_rate']}",
            f"--train_dir={self.directories['train']}",
            f"--summaries_dir={self.directories['logs']}",
            f"--verbosity={self.config['verbosity']}",
            f"--eval_step_interval={self.config['eval_step_interval']}",
            f"--save_step_interval={self.config['save_step_interval']}"
        ]
        # When resuming, pass the most recent checkpoint automatically
        if self.resume:
            latest_ckpt = self._find_latest_checkpoint()
            if latest_ckpt:
                train_cmd.append(f"--start_checkpoint={latest_ckpt}")
                logger.info(f"Resuming from the most recent checkpoint: {latest_ckpt}")
            else:
                logger.info("--resume given but no checkpoint found, training from scratch")
        # Environment: make the official speech_commands modules importable
        env = os.environ.copy()
        speech_commands_path = str(Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands')
        env['PYTHONPATH'] = speech_commands_path + os.pathsep + env.get('PYTHONPATH', '')
        # Windows-specific: force UTF-8
        if sys.platform == 'win32':
            env['PYTHONIOENCODING'] = 'utf-8'
        try:
            logger.info("Running the training command...")
            subprocess.run(
                train_cmd,
                check=True,
                env=env,
                capture_output=False,
                text=True,
                encoding='utf-8' if sys.platform == 'win32' else None
            )
            logger.info("Model training finished")
        except subprocess.CalledProcessError as e:
            logger.error(f"Training failed: {e}")
            raise

    def freeze_model(self):
        """Freeze the model into a SavedModel."""
        logger.info("Freezing the model...")
        # Make sure the parent directory exists
        saved_model_dir = Path(self.model_paths['saved_model'])
        saved_model_dir.parent.mkdir(parents=True, exist_ok=True)
        # Remove a stale saved_model directory if present
        if saved_model_dir.exists():
            import shutil
            shutil.rmtree(saved_model_dir)
        # Note: do NOT pre-create the saved_model directory or its
        # subdirectories; SavedModelBuilder creates it itself, and a
        # pre-existing non-empty directory triggers
        # "Export directory already exists, and isn't empty".
        freeze_script = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands' / 'freeze.py'
        # Prefer the latest checkpoint; fall back to the expected step count
        latest_ckpt = self._find_latest_checkpoint()
        if latest_ckpt:
            checkpoint_path = Path(latest_ckpt)
            logger.info(f"Freezing from the most recent checkpoint: {checkpoint_path}")
        else:
            checkpoint_path = Path(self.directories['train']) / f"{self.config['model_architecture']}.ckpt-{self.config['total_steps']}"
            logger.info(f"No recent checkpoint found, trying the expected path: {checkpoint_path}")
        freeze_cmd = [
            sys.executable,
            str(freeze_script),
            f"--wanted_words={self.config['wanted_words']}",
            f"--window_stride_ms={self.config['window_stride']}",
            f"--preprocess={self.config['preprocess']}",
            f"--model_architecture={self.config['model_architecture']}",
            f"--start_checkpoint={str(checkpoint_path)}",
            "--save_format=saved_model",
            f"--output_file={str(saved_model_dir)}"  # absolute path
        ]
        # Environment: make the official speech_commands modules importable
        env = os.environ.copy()
        speech_commands_path = str(Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands')
        env['PYTHONPATH'] = speech_commands_path + os.pathsep + env.get('PYTHONPATH', '')
        # Windows-specific settings
        if sys.platform == 'win32':
            env['PYTHONIOENCODING'] = 'utf-8'
            env['PYTHONUTF8'] = '1'
        try:
            logger.info("Running the freeze command...")
            logger.info(f"Output path: {saved_model_dir}")
            # Use Popen for better control over output decoding
            process = subprocess.Popen(
                freeze_cmd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                encoding='utf-8',
                errors='replace'  # replace undecodable characters
            )
            stdout, stderr = process.communicate()
            if process.returncode != 0:
                logger.error(f"Freezing failed, return code: {process.returncode}")
                if stdout:
                    logger.error(f"stdout: {stdout}")
                if stderr:
                    logger.error(f"stderr: {stderr}")
                raise subprocess.CalledProcessError(process.returncode, freeze_cmd)
            logger.info("Model frozen")
        except subprocess.CalledProcessError as e:
            logger.error(f"Model freezing failed: {e}")
            # Fall back to building a saved_model directly
            logger.info("Trying the fallback saved_model builder...")
            self._create_saved_model_fallback()

    def _create_saved_model_fallback(self):
        """Fallback: build a saved_model directly from the checkpoint."""
        try:
            import tensorflow as tf
            logger.info("Building saved_model via the fallback path...")
            # Make the speech_commands modules importable
            speech_commands_path = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands'
            if str(speech_commands_path) not in sys.path:
                sys.path.insert(0, str(speech_commands_path))
            import models
            import input_data
            # Prepare the model settings
            model_settings = models.prepare_model_settings(
                len(input_data.prepare_words_list(self.config['wanted_words'].split(','))),
                self.config['sample_rate'],
                self.config['clip_duration_ms'],
                self.config['window_size_ms'],
                self.config['window_stride'],
                self.config['feature_bin_count'],
                self.config['preprocess']
            )
            # Reset the default graph
            tf.compat.v1.reset_default_graph()
            with tf.compat.v1.Session() as sess:
                # Create the input placeholder
                fingerprint_size = model_settings['fingerprint_size']
                fingerprint_input = tf.compat.v1.placeholder(
                    tf.float32, [None, fingerprint_size], name='fingerprint_input'
                )
                # Build the model (returns only the logits tensor when is_training=False)
                logits = models.create_model(
                    fingerprint_input,
                    model_settings,
                    self.config['model_architecture'],
                    is_training=False
                )
                # Add and name the softmax output
                labels_softmax = tf.nn.softmax(logits, name='labels_softmax')
                # Restore the weights, again preferring the latest checkpoint
                latest_ckpt = self._find_latest_checkpoint()
                if latest_ckpt:
                    checkpoint_path = Path(latest_ckpt)
                else:
                    checkpoint_path = Path(self.directories['train']) / f"{self.config['model_architecture']}.ckpt-{self.config['total_steps']}"
                saver = tf.compat.v1.train.Saver()
                saver.restore(sess, str(checkpoint_path))
                # Save the model using the TensorFlow 1.x SavedModelBuilder
                saved_model_path = Path(self.model_paths['saved_model'])
                builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(str(saved_model_path))
                # Define the serving signature
                inputs = {'fingerprint_input': tf.compat.v1.saved_model.utils.build_tensor_info(fingerprint_input)}
                outputs = {'labels_softmax': tf.compat.v1.saved_model.utils.build_tensor_info(labels_softmax)}
                signature = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                    inputs=inputs,
                    outputs=outputs,
                    method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME
                )
                builder.add_meta_graph_and_variables(
                    sess,
                    [tf.compat.v1.saved_model.tag_constants.SERVING],
                    signature_def_map={
                        tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
                    }
                )
                builder.save()
                logger.info(f"Fallback: saved_model created at {saved_model_path}")
        except Exception as e:
            logger.error(f"The fallback also failed: {e}")
            logger.error("Suggested remedies:")
            logger.error("1. Downgrade TensorFlow to 2.13 or earlier")
            logger.error("2. Use --skip_training to download the pretrained model")
            raise

    def convert_to_tflite(self):
        """Convert the model to TensorFlow Lite."""
        logger.info("Converting to TensorFlow Lite models...")
        # Make the speech_commands modules importable
        speech_commands_path = Path(self.directories['tensorflow_repo']) / 'tensorflow' / 'examples' / 'speech_commands'
        if str(speech_commands_path) not in sys.path:
            sys.path.insert(0, str(speech_commands_path))
        try:
            import input_data
            import models
            import numpy as np
            import tensorflow as tf
            # Prepare the model settings
            model_settings = models.prepare_model_settings(
                len(input_data.prepare_words_list(self.config['wanted_words'].split(','))),
                self.config['sample_rate'],
                self.config['clip_duration_ms'],
                self.config['window_size_ms'],
                self.config['window_stride'],
                self.config['feature_bin_count'],
                self.config['preprocess']
            )
            # Create the audio processor
            data_url = 'https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz'
            audio_processor = input_data.AudioProcessor(
                data_url, self.directories['dataset'],
                self.config['silent_percentage'],
                self.config['unknown_percentage'],
                self.config['wanted_words'].split(','),
                self.config['validation_percentage'],
                self.config['testing_percentage'],
                model_settings,
                self.directories['logs']
            )
            with tf.compat.v1.Session() as sess:
                # Generate the float model
                logger.info("Generating the float TensorFlow Lite model...")
                float_converter = tf.lite.TFLiteConverter.from_saved_model(
                    str(Path(self.model_paths['saved_model']))
                )
                float_tflite_model = float_converter.convert()
                with open(self.model_paths['float_model_tflite'], 'wb') as f:
                    float_model_size = f.write(float_tflite_model)
                logger.info(f"Float model size: {float_model_size} bytes")
                # Generate the quantized model
                logger.info("Generating the quantized TensorFlow Lite model...")
                converter = tf.lite.TFLiteConverter.from_saved_model(
                    str(Path(self.model_paths['saved_model']))
                )
                converter.optimizations = [tf.lite.Optimize.DEFAULT]
                converter.inference_input_type = tf.int8
                converter.inference_output_type = tf.int8

                # Representative dataset generator for full-integer quantization
                def representative_dataset_gen():
                    for i in range(100):
                        data, _ = audio_processor.get_data(
                            1, i, model_settings,
                            self.config['background_frequency'],
                            self.config['background_volume_range'],
                            self.config['time_shift_ms'],
                            'testing', sess
                        )
                        flattened_data = np.array(
                            data.flatten(), dtype=np.float32
                        ).reshape(1, 1960)  # 49 frames x 40 features
                        yield [flattened_data]

                converter.representative_dataset = representative_dataset_gen
                tflite_model = converter.convert()
                with open(self.model_paths['model_tflite'], 'wb') as f:
                    quantized_model_size = f.write(tflite_model)
                logger.info(f"Quantized model size: {quantized_model_size} bytes")
            return audio_processor, model_settings
        except Exception as e:
            logger.error(f"TensorFlow Lite conversion failed: {e}")
            raise

    def test_tflite_accuracy(self, audio_processor, model_settings):
        """Measure the accuracy of the TensorFlow Lite models."""
        logger.info("Testing model accuracy...")
        import numpy as np
        import tensorflow as tf

        def run_tflite_inference(tflite_model_path, model_type="Float"):
            # Load the test data
            np.random.seed(0)
            with tf.compat.v1.Session() as sess:
                test_data, test_labels = audio_processor.get_data(
                    -1, 0, model_settings,
                    self.config['background_frequency'],
                    self.config['background_volume_range'],
                    self.config['time_shift_ms'],
                    'testing', sess
                )
            test_data = np.expand_dims(test_data, axis=1).astype(np.float32)
            # Initialize the interpreter
            interpreter = tf.lite.Interpreter(
                tflite_model_path,
                experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF
            )
            interpreter.allocate_tensors()
            input_details = interpreter.get_input_details()[0]
            output_details = interpreter.get_output_details()[0]
            # For the quantized model, manually convert float inputs to integers
            if model_type == "Quantized":
                input_scale, input_zero_point = input_details["quantization"]
                test_data = test_data / input_scale + input_zero_point
                test_data = test_data.astype(input_details["dtype"])
            correct_predictions = 0
            for i in range(len(test_data)):
                interpreter.set_tensor(input_details["index"], test_data[i])
                interpreter.invoke()
                output = interpreter.get_tensor(output_details["index"])[0]
                top_prediction = output.argmax()
                correct_predictions += (top_prediction == test_labels[i])
            accuracy = (correct_predictions * 100) / len(test_data)
            logger.info(f'{model_type} model accuracy: {accuracy:.2f}% (test samples={len(test_data)})')
            return accuracy

        # Evaluate both models
        try:
            float_accuracy = run_tflite_inference(self.model_paths['float_model_tflite'])
        except Exception as e:
            logger.error(f"Float model evaluation failed: {e}")
            float_accuracy = None
        try:
            quantized_accuracy = run_tflite_inference(
                self.model_paths['model_tflite'],
                model_type='Quantized'
            )
        except Exception as e:
            logger.error(f"Quantized model evaluation failed: {e}")
            quantized_accuracy = None
        return float_accuracy, quantized_accuracy

    def generate_micro_model(self):
        """Generate the microcontroller C source file."""
        logger.info("Generating the microcontroller C source file...")
        try:
            # Generate the file directly in Python
            self._generate_c_file_python()
            logger.info(f"C source file generated: {self.model_paths['model_tflite_micro']}")
        except Exception as e:
            logger.error(f"Failed to generate the C source file: {e}")
            raise

    def _generate_c_file_python(self):
        """Generate the C file in Python."""
        with open(self.model_paths['model_tflite'], 'rb') as f:
            model_data = f.read()
        # Build the C array
        c_content = []
        c_content.append('/* Automatically generated by speech_trainer.py */')
        c_content.append('#include "micro_speech_quantized_model_data.h"')
        c_content.append('')
        c_content.append('const unsigned char micro_speech_quantized_tflite[] = {')
        # Format the bytes as a C array, 16 values per line
        hex_values = []
        for i, byte in enumerate(model_data):
            if i % 16 == 0:
                hex_values.append('\n ')
            hex_values.append(f'0x{byte:02x}')
            if i < len(model_data) - 1:
                hex_values.append(',')
                if (i + 1) % 16 != 0:
                    hex_values.append(' ')
        c_content.append(''.join(hex_values))
        c_content.append('\n};')
        c_content.append(f'const unsigned int micro_speech_quantized_tflite_len = {len(model_data)};')
        c_content.append('')
        # Write the file
        with open(self.model_paths['model_tflite_micro'], 'w', encoding='utf-8') as f:
            f.write('\n'.join(c_content))

    def download_pretrained_model(self):
        """Download the pretrained model."""
        logger.info("Downloading the pretrained model...")
        try:
            import urllib.request
            import tarfile
            model_url = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_micro_train_2020_05_10.tgz"
            model_file = "speech_micro_train_2020_05_10.tgz"
            logger.info(f"Downloading model from {model_url}...")
            urllib.request.urlretrieve(model_url, model_file)
            logger.info("Extracting the model archive...")
            with tarfile.open(model_file, 'r:gz') as tar:
                tar.extractall('.')
            os.remove(model_file)
            logger.info("Pretrained model downloaded")
            return True
        except Exception as e:
            logger.error(f"Failed to download the pretrained model: {e}")
            return False

    def print_model_info(self):
        """Print information about the generated models."""
        logger.info("\n" + "=" * 50)
        logger.info("Model training finished!")
        logger.info("=" * 50)
        for name, path in self.model_paths.items():
            if os.path.exists(path):
                if os.path.isfile(path):
                    size = os.path.getsize(path)
                    logger.info(f"{name}: {path} ({size} bytes)")
                else:
                    logger.info(f"{name}: {path} (directory)")
            else:
                logger.info(f"{name}: {path} (not generated)")
        logger.info("\nDeploying to a microcontroller:")
        logger.info("1. See the TensorFlow Lite Micro documentation")
        logger.info("2. Update kCategoryCount and kCategoryLabels in micro_model_settings.h")
        logger.info("3. Replace the old model file with the generated micro_speech_quantized_model_data.c")

    def run_full_pipeline(self, skip_training=False):
        """Run the full training pipeline."""
        try:
            logger.info("Starting the full training pipeline...")
            # Ensure the dataset is ready
            self.ensure_dataset()
            # Clean up previous data
            self.clean_previous_data()
            if skip_training:
                success = self.download_pretrained_model()
                if not success:
                    logger.error("Could not download the pretrained model, falling back to full training")
                    skip_training = False
            if not skip_training:
                # Set up the TensorFlow repository
                self.setup_tensorflow_repo()
                # Train the model
                self.train_model()
                # Freeze the model
                self.freeze_model()
            else:
                logger.info("Using the pretrained model, skipping training")
            # Convert to TensorFlow Lite
            audio_processor, model_settings = self.convert_to_tflite()
            # Evaluate model accuracy
            self.test_tflite_accuracy(audio_processor, model_settings)
            # Generate the microcontroller model
            self.generate_micro_model()
            # Print the model info
            self.print_model_info()
            logger.info("Training pipeline finished!")
        except KeyboardInterrupt:
            logger.info("Training interrupted by user")
            raise
        except Exception as e:
            logger.error(f"Training pipeline failed: {e}")
            logger.error("Possible remedies:")
            logger.error("1. Check the network connection (needed for the dataset download)")
            logger.error("2. Make sure there is enough disk space")
            logger.error("3. Check the Python environment and dependencies")
            logger.error("4. Try --skip_training to download the pretrained model")
            raise


def main():
    """Entry point."""
    parser = argparse.ArgumentParser(description='Speech recognition model trainer')
    parser.add_argument('--wanted_words', default='yes,no',
                        help='words to train on, comma-separated (default: yes,no)')
    parser.add_argument('--training_steps', default='12000,3000',
                        help='training steps, comma-separated (default: 12000,3000)')
    parser.add_argument('--learning_rate', default='0.001,0.0001',
                        help='learning rates, comma-separated (default: 0.001,0.0001)')
    parser.add_argument('--model_architecture', default='tiny_conv',
                        choices=['single_fc', 'conv', 'low_latency_conv',
                                 'low_latency_svdf', 'tiny_embedding_conv', 'tiny_conv'],
                        help='model architecture (default: tiny_conv)')
    parser.add_argument('--skip_training', action='store_true',
                        help='skip training and use the pretrained model')
    parser.add_argument('--resume', action='store_true',
                        help='resume the previous run: keep train/ and restore the most recent checkpoint')
    parser.add_argument('--test_env', action='store_true',
                        help='only test the environment, do not train')
    args = parser.parse_args()
    if args.test_env:
        # Environment test only
        logger.info("Testing the environment configuration...")
        try:
            SpeechRecognitionTrainer()
            logger.info("✓ Environment test passed")
            logger.info("✓ All dependencies are installed")
            logger.info("Ready to start training")
        except Exception as e:
            logger.error(f"✗ Environment test failed: {e}")
            sys.exit(1)
        return
    # Training configuration
    config = {
        'wanted_words': args.wanted_words,
        'training_steps': args.training_steps,
        'learning_rate': args.learning_rate,
        'model_architecture': args.model_architecture,
        'resume': args.resume,
    }
    # Create the trainer
    trainer = SpeechRecognitionTrainer(config)
    # Run the full training pipeline
    trainer.run_full_pipeline(skip_training=args.skip_training)


if __name__ == '__main__':
    main()
```