文章目录
-
- 一、概述
- 二、核心区别总览
- 三、详细对比分析
-
- [1. 状态管理机制](#1. 状态管理机制)
-
- [1.1 训练模式:无状态处理](#1.1 训练模式:无状态处理)
- [1.2 流式推理:有状态处理](#1.2 流式推理:有状态处理)
- [2. 网络行为差异](#2. 网络行为差异)
-
- [2.1 Layer Dropout机制](#2.1 Layer Dropout机制)
- [2.2 RandomCombine机制](#2.2 RandomCombine机制)
- [3. 数据处理流程](#3. 数据处理流程)
-
- [3.1 训练模式的完整数据流](#3.1 训练模式的完整数据流)
- [3.2 流式推理的完整数据流](#3.2 流式推理的完整数据流)
- [4. 适用场景详解](#4. 适用场景详解)
-
- [4.1 训练模式的应用场景](#4.1 训练模式的应用场景)
- [4.2 流式推理的应用场景](#4.2 流式推理的应用场景)
- 四、代码示例
- 五、性能对比
- 六、最佳实践
- 七、常见问题
-
- [Q1: 流式推理的结果和训练时不一致?](#Q1: 流式推理的结果和训练时不一致?)
- [Q2: 流式推理时chunk边界有断裂感?](#Q2: 流式推理时chunk边界有断裂感?)
- [Q3: 多会话时显存不足?](#Q3: 多会话时显存不足?)
- [Q4: 如何加速流式推理?](#Q4: 如何加速流式推理?)
- 八、总结
一、概述
在LSTM-based RNN编码器中,训练模式(Training Mode) 和流式推理模式(Streaming Inference Mode) 是两种完全不同的工作方式。理解它们的区别对于正确使用模型至关重要。
什么是训练模式?
训练模式用于学习模型参数,处理完整的音频序列,通过反向传播优化网络权重。
特点:
- 批量处理多个完整样本
- 需要计算梯度
- 使用随机性(Dropout、RandomCombine)提高泛化能力
- 高吞吐量,高延迟
什么是流式推理模式?
流式推理模式用于实时应用,将音频流分成小chunk逐段处理,通过维护LSTM状态保持连续性。
特点:
- 单样本分chunk处理
- 不需要梯度
- 无随机性,结果确定
- 低延迟,适合实时场景
二、核心区别总览
快速对比表
维度 | 训练模式 | 流式推理模式 |
---|---|---|
主要目标 | 学习模型参数 | 实时输出结果 |
数据形式 | 完整序列一次性处理 | 音频流分chunk逐段处理 |
批次大小 | 较大 (5-32+) | 通常为1 |
序列长度 | 长 (数百到数千帧) | 短 (16-32帧/chunk) |
状态管理 | ❌ 不需要 | ✅ 必须维护LSTM状态 |
模式标志 | model.train() |
model.eval() |
梯度计算 | ✅ 需要 (requires_grad=True ) |
❌ 不需要 (torch.no_grad() ) |
RandomCombine | ✅ 启用,随机组合层输出 | ❌ 禁用,只用最后一层 |
Layer Dropout | ✅ 启用 (alpha可能<1) | ❌ 禁用 (alpha=1.0) |
Warmup参数 | ✅ 使用,控制层bypass | ❌ 固定为1.0 |
内存占用 | 高 (~65MB/batch) | 低 (~125KB/chunk) |
延迟 | 高 (秒级) | 低 (毫秒级) |
吞吐量 | 高 (50,000帧/秒) | 中等 (16,000帧/秒) |
GPU利用率 | 高 (批量并行) | 低 (单样本) |
确定性 | ❌ 非确定性 | ✅ 确定性 |
工作流程对比图
训练模式流程:
┌─────────────────────────────────────────┐
│ 输入: Batch样本 (N, T_long, F) │
│ 例如: (32, 1000, 80) │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 卷积下采样 (4倍) │
│ (32, 1000, 80) → (32, 247, 512) │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 12层LSTM编码器 │
│ - 从零状态开始 │
│ - Layer Dropout: 随机bypass一些层 │
│ - RandomCombine: 随机组合多层输出 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 输出: (32, 247, 512) │
│ 计算损失 → 反向传播 → 更新参数 │
└─────────────────────────────────────────┘
流式推理流程:
┌─────────────────────────────────────────┐
│ 初始化状态: states = get_init_states() │
└─────────────────────────────────────────┘
↓
┌─────────────────┐
│ 音频流循环 │
└─────────────────┘
↓
┌─────────────────────────────────────────┐
│ 输入: Chunk (1, T_short, F) │
│ 例如: (1, 16, 80) │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 卷积下采样 (4倍) │
│ (1, 16, 80) → (1, 1, 512) │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 12层LSTM编码器 │
│ - 使用前一chunk的状态 │
│ - 无Layer Dropout │
│ - 无RandomCombine,只用最后一层 │
│ - 输出新状态 │
└─────────────────────────────────────────┘
↓
┌─────────────────────────────────────────┐
│ 输出: (1, 1, 512) + new_states │
│ states ← new_states (用于下一chunk) │
└─────────────────────────────────────────┘
↓
(回到音频流循环)
三、详细对比分析
1. 状态管理机制
1.1 训练模式:无状态处理
python
# RNN.forward() - 训练模式代码片段
if states is None: # 训练时states为None
# 每个样本从零状态开始,样本间完全独立
x = self.encoder(x, warmup=warmup)[0]
# 返回空状态(仅为满足接口要求)
new_states = (torch.empty(0), torch.empty(0))
原理说明:
- LSTM的hidden和cell状态初始化为零向量
- 每个训练样本是完整的独立utterance
- 样本之间没有时序关系,可以shuffle
- 不需要记忆之前的信息
适用场景:
- 离线训练:每个样本是完整录音
- 批量评估:处理录音文件集合
- 不关心样本间的连续性
1.2 流式推理:有状态处理
python
# RNN.forward() - 流式推理代码片段
if states is not None: # 流式时必须提供states
# 确保在评估模式
assert not self.training
# 验证状态的形状
assert len(states) == 2
assert states[0].shape == (num_layers, batch_size, d_model)
assert states[1].shape == (num_layers, batch_size, rnn_hidden_size)
# 使用之前的状态处理当前chunk
x, new_states = self.encoder(x, states)
状态内容:
python
states = (hidden_states, cell_states)
# hidden_states: (12, 1, 512)
# - 12层,每层的隐藏状态
# - 用于LSTM的输出
#
# cell_states: (12, 1, 1024)
# - 12层,每层的细胞状态
# - LSTM的内部记忆
状态初始化:
python
# 第一个chunk开始前
states = model.get_init_states(batch_size=1, device=device)
# 内部实现
def get_init_states(self, batch_size=1, device=torch.device("cpu")):
hidden_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.d_model),
device=device
)
cell_states = torch.zeros(
(self.num_encoder_layers, batch_size, self.rnn_hidden_size),
device=device
)
return (hidden_states, cell_states)
状态传递流程:
python
# 流式推理主循环
states = model.get_init_states(batch_size=1, device=device)
for chunk in audio_stream:
# 1. 使用当前states处理chunk
embeddings, lengths, new_states = model(
chunk, chunk_lens, states=states
)
# 2. 更新states,传递给下一个chunk
states = new_states
# 3. 使用embeddings做后续处理
process(embeddings)
为什么需要状态?
- 保持连续性:音频流是连续的,LSTM需要记住之前的信息
- 上下文依赖:当前chunk的理解依赖之前的context
- 避免边界效应:chunk边界不会导致信息丢失
状态管理注意事项:
- ⚠️ 必须正确传递states,否则每个chunk独立处理
- ⚠️ 新对话/新音频流需要重置states
- ⚠️ 多线程场景需要为每个流维护独立的states
2. 网络行为差异
2.1 Layer Dropout机制
训练模式 - 有Layer Dropout:
python
# RNNEncoderLayer.forward()
def forward(self, src, states=None, warmup=1.0):
src_orig = src # 保存原始输入
# 计算warmup缩放
warmup_scale = min(0.1 + warmup, 1.0)
if self.training:
# 训练时:随机决定是否bypass该层
if torch.rand(()).item() <= (1.0 - self.layer_dropout):
alpha = warmup_scale # 使用该层
else:
alpha = 0.1 # bypass该层
else:
alpha = 1.0 # 推理时完全使用
# ... LSTM和FeedForward处理 ...
# 应用layer dropout
if alpha != 1.0:
# 混合原始输入和处理后的输出
src = alpha * src + (1 - alpha) * src_orig
return src, new_states
Alpha值的含义:
alpha = 1.0
: 完全使用该层的输出alpha = 0.1
: 基本bypass该层(90%使用原始输入)0.1 < alpha < 1.0
: 部分使用该层
Layer Dropout的作用:
- 渐进式训练:训练初期(warmup小)更频繁bypass层,减少训练难度
- 正则化:随机bypass增强模型鲁棒性
- 加速收敛:避免深层网络训练初期梯度问题
Warmup调度示例:
python
# 训练循环
total_steps = 100000
warmup_steps = 10000
for step in range(total_steps):
# 前10000步warmup从0增长到1
warmup = min(1.0, step / warmup_steps)
# warmup对layer dropout的影响:
# step=0: warmup=0, warmup_scale=0.1
# step=5000: warmup=0.5, warmup_scale=0.6
# step>=10000: warmup=1.0, warmup_scale=1.0
output = model(x, x_lens, warmup=warmup)
流式推理 - 无Layer Dropout:
python
# 推理模式
if self.training:
# 训练逻辑(上面的代码)
else:
alpha = 1.0 # 始终完全使用每一层
# 结果:
# src = 1.0 * src + (1-1.0) * src_orig = src
# 不会混合原始输入,完全使用处理后的输出
为什么推理不用Layer Dropout?
- 确定性:推理结果需要可复现
- 最优性能:使用全部层获得最佳效果
- 无正则化需求:推理不需要防止过拟合
2.2 RandomCombine机制
训练模式 - 启用RandomCombine:
python
# RNNEncoder.forward()
def forward(self, src, states=None, warmup=1.0):
output = src
outputs = [] # 存储辅助层输出
# 逐层处理
for i, layer in enumerate(self.layers):
output = layer(output, warmup=warmup)[0]
# 收集辅助层输出
if self.combiner is not None and i in self.aux_layers:
outputs.append(output)
# 训练时:随机组合多层输出
if self.combiner is not None:
output = self.combiner(outputs)
return output, new_states
RandomCombine的实现:
python
# RandomCombine.forward()
def forward(self, inputs): # inputs是多层的输出列表
# 推理时:直接返回最后一层
if not self.training:
return inputs[-1]
# 训练时:随机组合
# 例如:inputs = [layer4_out, layer7_out, layer10_out, layer11_out]
# 生成随机权重
weights = self._get_random_weights(...)
# weights: (num_frames, 4),每帧的权重不同
# 加权组合
output = weighted_sum(inputs, weights)
return output
随机权重生成策略:
python
# 以pure_prob=0.333的概率:选择单一层(one-hot)
if rand() < 0.333:
# 以final_weight=0.5的概率选择最后一层
if rand() < 0.5:
weights = [0, 0, 0, 1] # 最后一层
else:
weights = [1, 0, 0, 0] # 随机非最后层
# 或 [0, 1, 0, 0], [0, 0, 1, 0]
# 以(1-pure_prob)=0.667的概率:加权组合
else:
# 生成连续权重,给最后一层更高权重
log_weights = randn(4) * stddev
log_weights[3] += final_log_weight
weights = softmax(log_weights)
# 例如: [0.1, 0.2, 0.15, 0.55]
RandomCombine的作用:
- 类似Iterated Loss:让中间层也参与最终输出
- 改善梯度流:中间层获得更直接的监督信号
- 提高鲁棒性:测试时只用最后一层也能工作
辅助层配置示例:
python
# 12层网络,aux_layer_period=3
aux_layers = list(range(12//3, 12-1, 3))
# aux_layers = [4, 7, 10]
# 加上最后一层: [4, 7, 10, 11]
# RandomCombine会随机组合这4层的输出
流式推理 - 禁用RandomCombine:
python
# RandomCombine.forward()
def forward(self, inputs):
if not self.training:
# 推理时:只返回最后一层
return inputs[-1]
# (训练逻辑被跳过)
为什么推理只用最后一层?
- 效率:不需要计算随机权重
- 最优性能:最后一层通常表现最好
- 确定性:避免随机性
3. 数据处理流程
3.1 训练模式的完整数据流
数据准备:
python
# DataLoader批次
batch = {
'features': torch.randn(32, 1000, 80), # 32个样本,最长1000帧
'feature_lens': torch.tensor([1000, 980, 950, ..., 600]), # 实际长度
'targets': ..., # 目标标签
}
# 特点:
# 1. 批量处理:32个样本并行
# 2. 变长序列:使用padding统一长度
# 3. 完整utterance:每个样本是完整的录音
前向传播:
python
# 设置训练模式
model.train()
# 准备数据
x = batch['features'] # (32, 1000, 80)
x_lens = batch['feature_lens'] # (32,)
targets = batch['targets']
# 前向传播
with torch.enable_grad(): # 需要梯度
embeddings, lengths, _ = model(
x,
x_lens,
states=None, # 不传递状态
warmup=current_warmup # 当前warmup值
)
# embeddings: (32, 247, 512)
# lengths: (32,) - [247, 242, 234, ..., 147]
# 计算损失(例如CTC Loss或Transducer Loss)
loss = criterion(embeddings, targets, lengths)
反向传播:
python
# 梯度清零
optimizer.zero_grad()
# 反向传播
loss.backward()
# 梯度裁剪(可选)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
# 参数更新
optimizer.step()
# 学习率调度(可选)
scheduler.step()
内存占用分析:
python
# 前向传播需要保存的张量:
# 1. 输入: (32, 1000, 80) = 32 * 1000 * 80 * 4bytes ≈ 10MB
# 2. Conv输出: (32, 247, 512) ≈ 16MB
# 3. 每层LSTM输出: (247, 32, 512) ≈ 16MB * 12层 = 192MB
# 4. 梯度: 约等于参数量(96.5M参数 * 4bytes ≈ 386MB)
#
# 总计:约 600MB (单个batch)
# 实际GPU显存占用:1-2GB(包括优化器状态等)
3.2 流式推理的完整数据流
初始化:
python
# 设置评估模式
model.eval()
# 移动到设备
device = torch.device('cuda')
model = model.to(device)
# 初始化状态
states = model.get_init_states(batch_size=1, device=device)
# states[0]: (12, 1, 512) - hidden states
# states[1]: (12, 1, 1024) - cell states
音频流处理:
python
# 模拟音频流(实际应用中从麦克风/网络获取)
def audio_stream_generator(audio_file, chunk_size=16):
"""
从音频文件生成chunk流
Args:
audio_file: 音频文件路径
chunk_size: 每个chunk的帧数
Yields:
chunk: (1, chunk_size, 80)
"""
# 加载音频
features = load_audio_features(audio_file) # (T, 80)
# 分chunk
for i in range(0, len(features), chunk_size):
chunk = features[i:i+chunk_size]
# 填充到chunk_size(最后一个chunk可能不足)
if len(chunk) < chunk_size:
chunk = F.pad(chunk, (0, 0, 0, chunk_size - len(chunk)))
# 添加batch维度
chunk = chunk.unsqueeze(0) # (1, chunk_size, 80)
yield chunk, min(chunk_size, len(features) - i)
# 主处理循环
all_embeddings = []
with torch.no_grad(): # 推理不需要梯度
for chunk, chunk_len in audio_stream_generator(audio_file):
# 移动到设备
chunk = chunk.to(device)
chunk_lens = torch.tensor([chunk_len], device=device)
# 处理当前chunk
embeddings, lengths, new_states = model(
chunk, # (1, 16, 80)
chunk_lens, # (1,)
states=states, # 使用上一chunk的状态
warmup=1.0 # 推理不使用warmup
)
# 保存结果
all_embeddings.append(embeddings)
# 更新状态
states = new_states
# 实时处理(例如关键词检测)
if detect_keyword(embeddings):
print("检测到关键词!")
# 拼接所有输出
final_embeddings = torch.cat(all_embeddings, dim=1)
内存占用分析:
python
# 流式推理需要保存的张量:
# 1. 当前chunk: (1, 16, 80) ≈ 5KB
# 2. Conv输出: (1, 1, 512) ≈ 2KB
# 3. 每层输出: (1, 1, 512) ≈ 2KB * 12层 = 24KB
# 4. LSTM状态: (12, 1, 1024) * 2 ≈ 96KB
#
# 总计:约 127KB (单个chunk)
# 实际GPU显存占用:模型参数(~386MB) + 运行时(~1MB) ≈ 400MB
延迟分析:
python
# 假设音频采样率16kHz,帧率100Hz(10ms per frame)
chunk_size = 16 # 帧
# 音频延迟
audio_latency = chunk_size * 10ms = 160ms
# 计算延迟(GPU推理)
compute_latency ≈ 5-10ms
# 总延迟
total_latency = 160ms + 10ms = 170ms
# 实时因子 (RTF)
RTF = compute_latency / audio_latency = 10ms / 160ms ≈ 0.06
# 结论:可以实时处理(RTF < 1)
4. 适用场景详解
4.1 训练模式的应用场景
✅ 场景1:模型训练
python
# 离线训练脚本
import torch
from torch.utils.data import DataLoader
from lstm import RNN
# 数据集
train_dataset = AudioDataset(data_dir='train')
train_loader = DataLoader(
train_dataset,
batch_size=32,
shuffle=True, # 打乱样本顺序
num_workers=4,
collate_fn=collate_fn # 处理变长序列
)
# 模型
model = RNN(num_features=80, d_model=512, num_encoder_layers=12)
model.train()
# 训练循环
for epoch in range(num_epochs):
for batch in train_loader:
x, x_lens, targets = batch
# 前向传播
embeddings, lengths, _ = model(x, x_lens, warmup=epoch/100)
# 计算损失
loss = criterion(embeddings, targets, lengths)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
适用条件:
- ✅ 有大量标注数据
- ✅ 有GPU资源
- ✅ 可以批量处理
- ✅ 无实时要求
✅ 场景2:离线批量评估
python
# 批量评估脚本
model.eval()
test_loader = DataLoader(test_dataset, batch_size=16)
all_predictions = []
with torch.no_grad():
for batch in test_loader:
x, x_lens = batch
embeddings, lengths, _ = model(x, x_lens)
# 后续处理(如解码)
predictions = decoder(embeddings, lengths)
all_predictions.extend(predictions)
# 计算指标
accuracy = compute_accuracy(all_predictions, ground_truth)
适用条件:
- ✅ 处理录音文件集合
- ✅ 无实时要求
- ✅ 可以批量处理提高效率
✅ 场景3:研究实验
python
# 对比不同配置
configs = [
{'num_layers': 6, 'd_model': 256},
{'num_layers': 12, 'd_model': 512},
{'num_layers': 18, 'd_model': 768},
]
for config in configs:
model = RNN(**config)
train_and_evaluate(model)
适用条件:
- ✅ 需要快速迭代实验
- ✅ 对比不同超参数
- ✅ 分析模型行为
4.2 流式推理的应用场景
✅ 场景1:语音助手
python
# 智能音箱/手机语音助手
class VoiceAssistant:
def __init__(self):
self.model = load_model()
self.model.eval()
self.states = self.model.get_init_states(1, device)
def process_audio_stream(self):
"""处理实时音频流"""
mic = Microphone()
while True:
# 从麦克风获取chunk(例如160ms音频)
chunk = mic.read_chunk()
# 特征提取
features = extract_features(chunk) # (1, 16, 80)
# 模型推理
with torch.no_grad():
embeddings, _, new_states = self.model(
features,
torch.tensor([16]),
states=self.states
)
# 更新状态
self.states = new_states
# 关键词检测
if keyword_detector(embeddings) == "小爱同学":
self.wake_up()
self.reset_states() # 唤醒后重置
关键要求:
- ⚡ 低延迟 (< 200ms)
- 📱 边缘设备(手机、音箱)
- 🔄 连续处理音频流
- 💾 内存受限
✅ 场景2:实时字幕系统
python
# 视频会议/直播实时字幕
class RealtimeTranscriber:
def __init__(self):
self.encoder = RNN(...)
self.decoder = Decoder(...)
self.states = self.encoder.get_init_states(1, device)
def transcribe_stream(self, audio_stream):
"""实时转录音频流"""
for chunk in audio_stream:
# 编码
embeddings, _, new_states = self.encoder(
chunk, chunk_lens, states=self.states
)
self.states = new_states
# 解码
text = self.decoder(embeddings)
# 实时显示
display_subtitle(text)
yield text
关键要求:
- ⚡ 实时响应
- 📺 流媒体场景
- 🔄 连续输出文本
✅ 场景3:电话客服系统
python
# 智能客服语音识别
class CallCenterASR:
def __init__(self):
self.model = RNN(...)
self.sessions = {} # 每个通话维护独立状态
def handle_call(self, call_id, audio_stream):
"""处理电话音频流"""
# 为新通话初始化状态
if call_id not in self.sessions:
self.sessions[call_id] = {
'states': self.model.get_init_states(1, device),
'transcript': []
}
session = self.sessions[call_id]
for chunk in audio_stream:
# 处理音频chunk
embeddings, _, new_states = self.model(
chunk, chunk_lens, states=session['states']
)
# 更新状态
session['states'] = new_states
# 识别文本
text = recognize(embeddings)
session['transcript'].append(text)
# 意图理解
intent = understand_intent(text)
response = generate_response(intent)
yield response
def end_call(self, call_id):
"""通话结束,清理状态"""
del self.sessions[call_id]
关键要求:
- 📞 多路并发(多个通话同时进行)
- 💾 每个通话独立状态
- ⚡ 低延迟响应
✅ 场景4:边缘设备部署
python
# 嵌入式设备(如树莓派)
class EdgeKWS:
"""边缘设备关键词识别"""
def __init__(self, model_path):
# 加载量化/压缩的模型
self.model = load_quantized_model(model_path)
self.model.eval()
self.states = self.model.get_init_states(1, 'cpu')
def detect_keyword(self, audio_stream):
"""在边缘设备上运行"""
for chunk in audio_stream:
# CPU推理
with torch.no_grad():
embeddings, _, new_states = self.model(
chunk, chunk_lens, states=self.states
)
self.states = new_states
# 关键词检测
if is_keyword(embeddings):
return True
return False
关键要求:
- 💾 内存极度受限 (< 100MB)
- 🔋 功耗受限
- 🚫 无网络连接(离线工作)
- 📱 CPU推理
四、代码示例
训练模式完整示例
python
"""
完整的训练脚本示例
包含数据加载、训练循环、验证、保存模型等
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from lstm import RNN
# ============================================================================
# 1. 数据准备
# ============================================================================
class AudioDataset(torch.utils.data.Dataset):
"""音频数据集"""
def __init__(self, data_dir, manifest_file):
self.data = self.load_manifest(manifest_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
# 加载音频特征
features = load_features(self.data[idx]['audio_path']) # (T, 80)
targets = self.data[idx]['targets']
return features, targets
def collate_fn(batch):
"""处理变长序列"""
features_list, targets_list = zip(*batch)
# 获取最大长度
max_len = max(f.size(0) for f in features_list)
batch_size = len(features_list)
# Padding
features_padded = torch.zeros(batch_size, max_len, 80)
feature_lens = torch.zeros(batch_size, dtype=torch.long)
for i, feat in enumerate(features_list):
length = feat.size(0)
features_padded[i, :length] = feat
feature_lens[i] = length
return features_padded, feature_lens, targets_list
# 创建数据加载器
train_dataset = AudioDataset('data/train', 'train.json')
train_loader = DataLoader(
train_dataset,
batch_size=32,
shuffle=True,
num_workers=4,
collate_fn=collate_fn,
pin_memory=True # 加速GPU传输
)
val_dataset = AudioDataset('data/val', 'val.json')
val_loader = DataLoader(
val_dataset,
batch_size=16,
shuffle=False,
collate_fn=collate_fn
)
# ============================================================================
# 2. 模型创建
# ============================================================================
model = RNN(
num_features=80,
subsampling_factor=4,
d_model=512,
dim_feedforward=2048,
rnn_hidden_size=1024,
num_encoder_layers=12,
dropout=0.1,
layer_dropout=0.075,
aux_layer_period=3, # 启用RandomCombine
)
# 移动到GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
print(f"模型参数量: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
# ============================================================================
# 3. 优化器和损失函数
# ============================================================================
# 优化器
optimizer = torch.optim.Adam(
model.parameters(),
lr=1e-3,
betas=(0.9, 0.98),
eps=1e-9
)
# 学习率调度器
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=3,
verbose=True
)
# 损失函数(示例:CTC Loss)
criterion = nn.CTCLoss(blank=0, reduction='mean')
# ============================================================================
# 4. 训练函数
# ============================================================================
def train_epoch(model, data_loader, optimizer, criterion, epoch, total_epochs):
"""训练一个epoch"""
model.train()
total_loss = 0
num_batches = len(data_loader)
for batch_idx, (features, feature_lens, targets) in enumerate(data_loader):
# 移动到设备
features = features.to(device)
feature_lens = feature_lens.to(device)
# 计算warmup
# 前10个epoch从0增长到1
warmup = min(1.0, epoch / 10.0)
# 前向传播
embeddings, lengths, _ = model(
features,
feature_lens,
states=None, # 训练不需要状态
warmup=warmup
)
# 准备CTC Loss的输入
# embeddings: (N, T, d_model) -> (T, N, d_model)
log_probs = torch.log_softmax(embeddings.transpose(0, 1), dim=-1)
# 计算损失
loss = criterion(log_probs, targets, lengths, target_lengths)
# 反向传播
optimizer.zero_grad()
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
# 更新参数
optimizer.step()
# 统计
total_loss += loss.item()
# 打印进度
if (batch_idx + 1) % 10 == 0:
avg_loss = total_loss / (batch_idx + 1)
print(f"Epoch [{epoch}/{total_epochs}] "
f"Batch [{batch_idx+1}/{num_batches}] "
f"Loss: {loss.item():.4f} "
f"Avg Loss: {avg_loss:.4f} "
f"Warmup: {warmup:.2f}")
return total_loss / num_batches
# ============================================================================
# 5. 验证函数
# ============================================================================
def validate(model, data_loader, criterion):
"""验证模型"""
model.eval()
total_loss = 0
num_batches = len(data_loader)
with torch.no_grad():
for features, feature_lens, targets in data_loader:
features = features.to(device)
feature_lens = feature_lens.to(device)
# 前向传播(推理模式)
embeddings, lengths, _ = model(
features,
feature_lens,
states=None,
warmup=1.0 # 验证时warmup=1.0
)
# 计算损失
log_probs = torch.log_softmax(embeddings.transpose(0, 1), dim=-1)
loss = criterion(log_probs, targets, lengths, target_lengths)
total_loss += loss.item()
avg_loss = total_loss / num_batches
return avg_loss
# ============================================================================
# 6. 主训练循环
# ============================================================================
def main():
num_epochs = 100
best_val_loss = float('inf')
for epoch in range(1, num_epochs + 1):
print(f"\n{'='*60}")
print(f"Epoch {epoch}/{num_epochs}")
print(f"{'='*60}")
# 训练
train_loss = train_epoch(
model, train_loader, optimizer, criterion, epoch, num_epochs
)
# 验证
val_loss = validate(model, val_loader, criterion)
print(f"\nEpoch {epoch} Summary:")
print(f" Train Loss: {train_loss:.4f}")
print(f" Val Loss: {val_loss:.4f}")
# 学习率调度
scheduler.step(val_loss)
# 保存最佳模型
if val_loss < best_val_loss:
best_val_loss = val_loss
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'val_loss': val_loss,
}, 'best_model.pt')
print(f" ✓ 保存最佳模型 (val_loss={val_loss:.4f})")
# 定期保存checkpoint
if epoch % 10 == 0:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
}, f'checkpoint_epoch_{epoch}.pt')
if __name__ == '__main__':
main()
流式推理完整示例
python
"""
完整的流式推理脚本示例
包含音频流处理、状态管理、实时关键词检测等
"""
import torch
import numpy as np
from lstm import RNN
# ============================================================================
# 1. 模型加载
# ============================================================================
def load_model(checkpoint_path, device):
"""加载训练好的模型"""
# 创建模型
model = RNN(
num_features=80,
d_model=512,
rnn_hidden_size=1024,
num_encoder_layers=12,
)
# 加载权重
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# 设置评估模式
model.eval()
model = model.to(device)
print(f"✓ 模型加载成功")
return model
# ============================================================================
# 2. 音频流处理器
# ============================================================================
class AudioStreamProcessor:
"""音频流处理器"""
def __init__(self, model, device, chunk_size=16):
"""
Args:
model: RNN模型
device: 设备(CPU或GPU)
chunk_size: 每个chunk的帧数
"""
self.model = model
self.device = device
self.chunk_size = chunk_size
# 初始化状态
self.reset_states()
# 统计信息
self.total_chunks = 0
self.total_time = 0
def reset_states(self):
"""重置LSTM状态(新对话/新音频流时调用)"""
self.states = self.model.get_init_states(
batch_size=1,
device=self.device
)
print("✓ 状态已重置")
def process_chunk(self, chunk):
"""
处理单个音频chunk
Args:
chunk: 音频特征,形状 (chunk_size, 80) 或 (1, chunk_size, 80)
Returns:
embeddings: 编码后的特征 (1, T', 512)
lengths: 输出长度
"""
# 确保形状正确
if chunk.dim() == 2:
chunk = chunk.unsqueeze(0) # (chunk_size, 80) -> (1, chunk_size, 80)
# 获取实际长度
chunk_len = chunk.size(1)
chunk_lens = torch.tensor([chunk_len], device=self.device)
# 移动到设备
chunk = chunk.to(self.device)
# 推理
import time
start_time = time.time()
with torch.no_grad():
embeddings, lengths, new_states = self.model(
chunk,
chunk_lens,
states=self.states,
warmup=1.0
)
# 更新状态
self.states = new_states
# 统计
elapsed = time.time() - start_time
self.total_chunks += 1
self.total_time += elapsed
return embeddings, lengths
def get_stats(self):
"""获取统计信息"""
avg_time = self.total_time / self.total_chunks if self.total_chunks > 0 else 0
# 计算实时因子
# chunk_size帧 @ 100fps = chunk_size * 10ms
audio_duration = self.chunk_size * 0.01 # 秒
rtf = avg_time / audio_duration if audio_duration > 0 else 0
return {
'total_chunks': self.total_chunks,
'total_time': self.total_time,
'avg_time_per_chunk': avg_time,
'rtf': rtf
}
# ============================================================================
# 3. 音频流生成器
# ============================================================================
def audio_stream_from_file(audio_file, chunk_size=16):
"""
从音频文件生成chunk流(模拟实时流)
Args:
audio_file: 音频文件路径
chunk_size: chunk大小(帧数)
Yields:
chunk: (chunk_size, 80)
"""
# 加载音频特征(假设已经提取好)
# 实际应用中需要实时提取特征
features = np.load(audio_file) # (T, 80)
print(f"音频总长度: {len(features)} 帧 ({len(features)*0.01:.2f} 秒)")
print(f"Chunk大小: {chunk_size} 帧 ({chunk_size*0.01:.2f} 秒)")
print(f"总chunk数: {len(features) // chunk_size}")
print()
# 分chunk
for i in range(0, len(features), chunk_size):
chunk = features[i:i+chunk_size]
# 最后一个chunk可能不足,需要padding
if len(chunk) < chunk_size:
chunk = np.pad(
chunk,
((0, chunk_size - len(chunk)), (0, 0)),
mode='constant'
)
# 转换为tensor
chunk = torch.from_numpy(chunk).float()
yield chunk
# 模拟实时延迟(可选)
# import time
# time.sleep(chunk_size * 0.01)
# ============================================================================
# 4. 关键词检测器(示例)
# ============================================================================
class KeywordDetector:
"""简单的关键词检测器"""
def __init__(self, keywords, threshold=0.5):
self.keywords = keywords
self.threshold = threshold
self.keyword_classifier = self.load_classifier()
def load_classifier(self):
"""加载关键词分类器(示例)"""
# 实际应用中这里是一个分类器网络
# 这里简化为随机检测
return lambda x: np.random.rand() > 0.95
def detect(self, embeddings):
"""
检测关键词
Args:
embeddings: 编码特征 (1, T', 512)
Returns:
detected: 是否检测到关键词
keyword: 检测到的关键词(如果有)
"""
# 简化的检测逻辑
score = self.keyword_classifier(embeddings)
if score:
return True, "小爱同学"
return False, None
# ============================================================================
# 5. 主流式推理流程
# ============================================================================
def main_streaming():
"""主流式推理函数"""
# 设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}\n")
# 加载模型
model = load_model('best_model.pt', device)
# 创建处理器
processor = AudioStreamProcessor(
model=model,
device=device,
chunk_size=16
)
# 创建关键词检测器
detector = KeywordDetector(keywords=["小爱同学", "你好"])
# 处理音频流
print("开始处理音频流...")
print("="*60)
audio_file = 'test_audio_features.npy'
all_embeddings = []
for chunk_idx, chunk in enumerate(audio_stream_from_file(audio_file)):
# 处理chunk
embeddings, lengths = processor.process_chunk(chunk)
# 保存结果
all_embeddings.append(embeddings)
# 关键词检测
detected, keyword = detector.detect(embeddings)
# 打印信息
if detected:
print(f"Chunk {chunk_idx:3d}: ✓ 检测到关键词 [{keyword}]")
else:
print(f"Chunk {chunk_idx:3d}: - 处理完成", end='\r')
print("\n" + "="*60)
print("处理完成!\n")
# 打印统计信息
stats = processor.get_stats()
print("统计信息:")
print(f" 总chunk数: {stats['total_chunks']}")
print(f" 总耗时: {stats['total_time']:.3f} 秒")
print(f" 平均每chunk: {stats['avg_time_per_chunk']*1000:.2f} ms")
print(f" 实时因子 (RTF): {stats['rtf']:.3f}")
if stats['rtf'] < 1.0:
print(f" ✓ 可以实时处理 (RTF < 1.0)")
else:
print(f" ✗ 无法实时处理 (RTF >= 1.0)")
# 拼接所有输出
final_embeddings = torch.cat(all_embeddings, dim=1)
print(f"\n最终输出形状: {final_embeddings.shape}")
# ============================================================================
# 6. 多会话管理示例(电话客服场景)
# ============================================================================
class MultiSessionManager:
"""多会话管理器(用于电话客服等场景)"""
def __init__(self, model, device):
self.model = model
self.device = device
self.sessions = {}
def create_session(self, session_id):
"""创建新会话"""
if session_id in self.sessions:
print(f"警告: 会话 {session_id} 已存在")
return
self.sessions[session_id] = {
'states': self.model.get_init_states(1, self.device),
'created_at': time.time(),
'chunk_count': 0
}
print(f"✓ 创建会话: {session_id}")
def process_chunk(self, session_id, chunk):
"""处理指定会话的chunk"""
if session_id not in self.sessions:
raise ValueError(f"会话 {session_id} 不存在")
session = self.sessions[session_id]
# 处理
chunk = chunk.to(self.device)
chunk_lens = torch.tensor([chunk.size(1)], device=self.device)
with torch.no_grad():
embeddings, lengths, new_states = self.model(
chunk,
chunk_lens,
states=session['states'],
warmup=1.0
)
# 更新会话状态
session['states'] = new_states
session['chunk_count'] += 1
return embeddings, lengths
def close_session(self, session_id):
"""关闭会话,释放资源"""
if session_id in self.sessions:
del self.sessions[session_id]
print(f"✓ 关闭会话: {session_id}")
def get_active_sessions(self):
"""获取活跃会话列表"""
return list(self.sessions.keys())
# 使用示例
def demo_multi_session():
device = torch.device('cuda')
model = load_model('best_model.pt', device)
manager = MultiSessionManager(model, device)
# 模拟3个并发通话
call_ids = ['call_001', 'call_002', 'call_003']
# 创建会话
for call_id in call_ids:
manager.create_session(call_id)
# 交替处理各个会话的音频
for i in range(100): # 模拟100个chunk
# 轮流处理各个会话
call_id = call_ids[i % 3]
chunk = torch.randn(1, 16, 80) # 模拟音频chunk
embeddings, lengths = manager.process_chunk(call_id, chunk)
# 后续处理...
# 关闭会话
for call_id in call_ids:
manager.close_session(call_id)
# ============================================================================
# 7. 主入口
# ============================================================================
if __name__ == '__main__':
# 单会话流式推理
main_streaming()
# 多会话示例
# demo_multi_session()
五、性能对比
1. 吞吐量对比
训练模式
配置:
- Batch size: 32
- Sequence length: 1000帧
- GPU: NVIDIA V100
性能指标:
处理速度: ~50 utterances/second
吞吐量: 50 * 1000 = 50,000 帧/秒
GPU利用率: 85-95%
显存占用: ~4GB
优势:
- ✅ 批量并行处理,GPU利用率高
- ✅ 吞吐量大,适合大规模数据处理
劣势:
- ❌ 延迟高(必须等待完整序列)
- ❌ 显存占用大
流式推理
配置:
- Batch size: 1
- Chunk size: 16帧
- GPU: NVIDIA V100
性能指标:
处理速度: ~1000 chunks/second
吞吐量: 1000 * 16 = 16,000 帧/秒
GPU利用率: 15-25%
显存占用: ~500MB
优势:
- ✅ 低延迟(实时处理)
- ✅ 显存占用小
劣势:
- ❌ GPU利用率低(单样本)
- ❌ 吞吐量较小
💡 结论:
- 训练模式适合离线批量处理
- 流式推理适合实时单样本处理
2. 延迟对比
训练模式延迟
假设音频帧率 = 100 fps (10ms/frame)
序列长度: 1000帧
音频时长: 1000 / 100 = 10秒
处理时间: ~10秒 (取决于GPU性能)
端到端延迟 = 10秒 (必须等待完整序列)
实时因子 (RTF) = 10秒 / 10秒 = 1.0
特点:
- 延迟 = 整个序列的时长
- 不适合实时应用
- 适合离线处理
流式推理延迟
Chunk大小: 16帧
音频时长: 16 / 100 = 0.16秒 = 160ms
处理时间: ~10ms (GPU推理)
端到端延迟 = 160ms + 10ms = 170ms
实时因子 (RTF) = 10ms / 160ms = 0.0625
延迟分解:
1. 音频采集延迟: 160ms (chunk时长)
2. 特征提取延迟: ~5ms
3. 模型推理延迟: ~10ms
4. 后处理延迟: ~5ms
总延迟: 180ms
💡 结论:
- 流式推理延迟低(< 200ms)
- RTF << 1,可以实时处理
- 适合实时应用
3. 内存占用对比
训练模式内存
GPU显存占用:
1. 模型参数:
- 96.5M参数 × 4 bytes = 386 MB
2. 单个batch:
- 输入: (32, 1000, 80) × 4 bytes ≈ 10 MB
- Conv输出: (32, 247, 512) × 4 bytes ≈ 16 MB
- 12层LSTM输出: (247, 32, 512) × 4 bytes × 12 ≈ 192 MB
3. 梯度:
- 约等于参数量 ≈ 386 MB
4. 优化器状态 (Adam):
- 2倍参数量 ≈ 772 MB
总计: 386 + 218 + 386 + 772 ≈ 1762 MB ≈ 1.7 GB
实际显存占用: 2-4 GB (包括PyTorch overhead)
流式推理内存
GPU显存占用:
1. 模型参数:
- 96.5M参数 × 4 bytes = 386 MB
2. 单个chunk:
- 输入: (1, 16, 80) × 4 bytes ≈ 5 KB
- Conv输出: (1, 1, 512) × 4 bytes ≈ 2 KB
- 12层输出: (1, 1, 512) × 4 bytes × 12 ≈ 24 KB
3. LSTM状态:
- Hidden: (12, 1, 512) × 4 bytes ≈ 24 KB
- Cell: (12, 1, 1024) × 4 bytes ≈ 48 KB
总计: 386 + 0.1 ≈ 386 MB
实际显存占用: 400-500 MB
💡 结论:
- 训练模式显存占用大(~3GB)
- 流式推理显存占用小(~400MB)
- 流式推理可以在低端GPU甚至CPU上运行
4. 计算效率对比
批量处理效率
Batch Size | 吞吐量 (帧/秒) | GPU利用率 | 单样本延迟 |
---|---|---|---|
1 | 1,000 | 15% | 1s |
8 | 7,500 | 45% | 8s |
16 | 14,000 | 70% | 16s |
32 | 25,000 | 90% | 32s |
64 | 38,000 | 95% | 64s |
观察:
- Batch size越大,吞吐量越高
- 但延迟也线性增加
- GPU利用率饱和点约在batch_size=32
Chunk大小影响
Chunk Size | 延迟 | RTF | 下采样输出 |
---|---|---|---|
8 | 80ms | 0.125 | 可能为0 ⚠️ |
16 | 160ms | 0.0625 | 1-2帧 ✓ |
32 | 320ms | 0.031 | 3-4帧 ✓ |
64 | 640ms | 0.016 | 7-8帧 ✓ |
建议:
- 推荐chunk_size=16-32
- 太小: 下采样后可能为0
- 太大: 延迟增加
六、最佳实践
训练模式最佳实践
1. Warmup调度
python
# ✅ 推荐: 线性warmup
def get_warmup(step, warmup_steps=10000):
return min(1.0, step / warmup_steps)
# 使用
for step in range(total_steps):
warmup = get_warmup(step)
output = model(x, x_lens, warmup=warmup)
# ❌ 不推荐: 固定warmup
warmup = 0.5 # 不随训练变化
2. Batch Size选择
python
# ✅ 推荐: 根据GPU显存动态调整
def find_optimal_batch_size(model, device):
batch_size = 64
while batch_size > 1:
try:
x = torch.randn(batch_size, 1000, 80, device=device)
_ = model(x, torch.full((batch_size,), 1000))
return batch_size
except RuntimeError: # OOM
batch_size //= 2
return 1
# ❌ 不推荐: 固定batch size可能OOM或浪费显存
batch_size = 128 # 可能OOM
3. 梯度累积
python
# ✅ 推荐: 显存不足时使用梯度累积
accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
loss = compute_loss(model, batch)
loss = loss / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
4. 混合精度训练
python
# ✅ 推荐: 使用混合精度加速训练
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch in dataloader:
optimizer.zero_grad()
with autocast():
output = model(x, x_lens)
loss = criterion(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
流式推理最佳实践
1. 状态管理
python
# ✅ 推荐: 正确管理状态
class StreamingASR:
def __init__(self, model):
self.model = model
self.states = None
def start_utterance(self):
"""开始新utterance时重置状态"""
self.states = self.model.get_init_states(1, device)
def process_chunk(self, chunk):
if self.states is None:
self.start_utterance()
output, _, new_states = self.model(
chunk, chunk_lens, states=self.states
)
self.states = new_states
return output
def end_utterance(self):
"""结束utterance时清理状态"""
self.states = None
# ❌ 不推荐: 忘记管理状态
def process_stream(chunks):
# 错误: 每个chunk都从零状态开始
for chunk in chunks:
output = model(chunk, chunk_lens, states=None)
2. Chunk大小选择
python
# ✅ 推荐: 根据延迟要求选择chunk大小
def choose_chunk_size(latency_requirement_ms, frame_rate_fps=100):
"""
根据延迟要求选择chunk大小
Args:
latency_requirement_ms: 延迟要求(毫秒)
frame_rate_fps: 帧率
Returns:
chunk_size: chunk大小(帧数)
"""
# 考虑下采样因子=4,需要至少9帧输入
min_chunk_size = 16 # 确保下采样后有输出
# 根据延迟计算最大chunk大小
max_chunk_size = int(latency_requirement_ms / (1000 / frame_rate_fps))
# 选择16的倍数(方便硬件优化)
chunk_size = min(max_chunk_size, 64)
chunk_size = max(chunk_size, min_chunk_size)
chunk_size = (chunk_size // 16) * 16
return chunk_size
# 示例
chunk_size = choose_chunk_size(latency_requirement_ms=200)
print(f"Chunk大小: {chunk_size} 帧")
3. 内存优化
python
# ✅ 推荐: 流式推理时禁用梯度
model.eval()
for param in model.parameters():
param.requires_grad = False
with torch.no_grad():
for chunk in stream:
output = model.process(chunk)
# ✅ 推荐: 使用inplace操作
torch.backends.cudnn.benchmark = True
# ✅ 推荐: 量化模型(如果精度允许)
quantized_model = torch.quantization.quantize_dynamic(
model, {nn.Linear, nn.LSTM}, dtype=torch.qint8
)
4. 多线程/多进程
python
# ✅ 推荐: 音频处理和模型推理分离
import queue
from threading import Thread
def audio_capture_thread(audio_queue):
"""音频采集线程"""
while True:
chunk = capture_audio()
audio_queue.put(chunk)
def inference_thread(audio_queue, result_queue):
"""推理线程"""
states = model.get_init_states(1, device)
while True:
chunk = audio_queue.get()
output, _, new_states = model(chunk, states=states)
states = new_states
result_queue.put(output)
# 启动
audio_q = queue.Queue(maxsize=10)
result_q = queue.Queue(maxsize=10)
Thread(target=audio_capture_thread, args=(audio_q,)).start()
Thread(target=inference_thread, args=(audio_q, result_q)).start()
七、常见问题
Q1: 流式推理的结果和训练时不一致?
原因:
- RandomCombine在训练和推理时行为不同
- Layer Dropout在训练时有随机性
- Dropout层的影响
解决:
python
# 确保设置为评估模式
model.eval()
# 或者在训练时也测试流式推理
model.eval()
with torch.no_grad():
# 流式推理测试
...
model.train()
Q2: 流式推理时chunk边界有断裂感?
原因 :
卷积下采样在chunk边界可能损失信息
解决:使用重叠chunk
python
# ✅ 使用重叠
chunk_size = 16
overlap = 4 # 重叠4帧
for i in range(0, len(audio), chunk_size - overlap):
chunk = audio[i:i+chunk_size]
output = process(chunk)
# 只使用中间部分,丢弃边界
valid_output = output[:, overlap//2:-overlap//2, :]
Q3: 多会话时显存不足?
解决:
python
# 1. 限制并发会话数
MAX_SESSIONS = 100
# 2. 自动清理长时间未活动的会话
def cleanup_inactive_sessions(sessions, timeout=300):
now = time.time()
for sid, session in list(sessions.items()):
if now - session['last_active'] > timeout:
del sessions[sid]
# 3. 使用CPU推理
model = model.cpu()
Q4: 如何加速流式推理?
方法:
- 模型量化
python
quantized_model = torch.quantization.quantize_dynamic(
model, {nn.LSTM, nn.Linear}, dtype=torch.qint8
)
- 模型剪枝
python
# 减少层数
model_small = RNN(num_encoder_layers=6) # 从12减少到6
- 使用TorchScript
python
traced = torch.jit.trace(model, (example_input, example_lens))
traced.save('model_traced.pt')
- ONNX导出
python
torch.onnx.export(model, (x, x_lens), 'model.onnx')
八、总结
核心差异
特性 | 训练模式 | 流式推理 |
---|---|---|
目标 | 学习参数 | 实时输出 |
状态 | 不需要 | 必须维护 |
随机性 | 有 | 无 |
延迟 | 高 | 低 |
吞吐量 | 高 | 中 |
内存 | 大 | 小 |
选择建议
使用训练模式:
- ✅ 模型训练
- ✅ 离线批量评估
- ✅ 研究实验
- ✅ 数据分析
使用流式推理:
- ✅ 实时应用(语音助手)
- ✅ 边缘设备
- ✅ 低延迟要求
- ✅ 内存受限场景
关键要点
-
状态管理是流式推理的核心
- 必须正确维护和传递LSTM状态
- 新对话需要重置状态
-
训练和推理的网络行为不同
- Layer Dropout只在训练时有效
- RandomCombine只在训练时启用
-
性能权衡
- 训练模式: 高吞吐、高延迟、高内存
- 流式推理: 低延迟、低内存、中等吞吐
-
正确设置模式
- 训练:
model.train()
- 推理:
model.eval()
+torch.no_grad()
- 训练: