整体架构
一、整体架构
- 系统功能
这是一个手语到文本(Sign-to-Text, S2T)的多任务学习系统,同时完成:
- 手语识别(CSLR):视频 → Gloss序列
- 手语翻译(SLT):视频 → 德语文本
-
核心组件关系
视频数据 → 视觉特征提取 → 多任务模型 → 输出
↓ ↓ ↓
Phoenix-2014 SignLanguageModel Gloss序列/德语文本
↓ ↓ ↓
Gloss标注 Tokenizer 评估指标
二、关键模块详解
- Tokenizer 模块 (
Tokenizer.py)
python
class GlossTokenizer_S2G:
# 专门处理手语Gloss的tokenizer
# 基于你的 gloss2ids.pkl 文件
作用:
- 将手语gloss(如
"REGEN")转换为ID(如4) - 将ID序列转换回gloss文本
- 处理特殊标记:
<si>,<unk>,<pad>,</s>
- 数据集模块 (
datasets.py)
python
class S2T_Dataset:
# 加载Phoenix-2014数据集
# 包含:视频路径、gloss标注、德语翻译
数据格式:
python
# 每个样本包含
{
'name': 'video_001', # 视频文件名
'video_path': 'path/to/video',
'gloss': 'PROGNOSE MORGEN REGEN NORD', # Gloss序列
'text': 'Prognose für morgen: Regen im Norden', # 德语翻译
'gloss_ids': [X, 8, 4, 10] # 通过tokenizer转换
}
- 核心模型 (
model.py)
python
class SignLanguageModel:
# 多任务手语模型
def __init__(self, cfg, args):
# 包含:
# 1. 视觉编码器(CNN/Transformer)
# 2. 识别网络(CTC解码)
# 3. 翻译网络(Transformer解码)
def forward(self, src_input):
# 同时计算识别和翻译损失
return {
'recognition_loss': loss1, # Gloss识别损失
'translation_loss': loss2, # 德语翻译损失
'total_loss': loss1 + loss2
}
三、训练流程详解
- 数据加载和预处理
python
# 使用你的gloss词汇表
tokenizer = GlossTokenizer_S2G(config['gloss'])
# config['gloss'] 指向 gloss2ids.pkl
# 创建数据集
train_data = S2T_Dataset(
path=config['data']['train_label_path'], # Phoenix-2014训练标注
tokenizer=tokenizer,
config=config,
phase='train'
)
- 模型初始化
python
# 创建多任务模型
model = SignLanguageModel(cfg=config, args=args)
# 模型包含:
# - 视觉特征提取器(处理视频帧)
# - 时序建模(LSTM/Transformer)
# - 双输出头:Gloss分类 + 文本生成
- 训练循环
python
for epoch in range(args.start_epoch, args.epochs):
# 1. 训练阶段
train_stats = train_one_epoch(...)
# 2. 验证阶段
test_stats = evaluate(...)
# 3. 保存最佳模型(基于BLEU-4或WER)
if bleu_4 < test_stats["bleu4"]:
save_best_checkpoint()
四、评估和推理流程
- 手语识别评估(CSLR)
python
def evaluate(...):
# CTC解码获取Gloss序列
ctc_decode_output = model.recognition_network.decode(
gloss_logits=gls_logits,
beam_size=beam_size,
input_lengths=output['input_lengths']
)
# 转换为可读文本
batch_pred_gls = tokenizer.convert_ids_to_tokens(ctc_decode_output)
# 计算WER(词错误率)
wer_results = wer_list(hypotheses=gls_hyp, references=gls_ref)
评估指标:
- WER(Word Error Rate):手语识别的主要指标
- 计算预测gloss序列与真实gloss序列的差异
- 手语翻译评估(SLT)
python
# 生成德语文本
generate_output = model.generate_txt(
transformer_inputs=output['transformer_inputs'],
generate_cfg=generate_cfg
)
# 计算BLEU和ROUGE
bleu_dict = bleu(references=txt_ref, hypotheses=txt_hyp)
rouge_score = rouge(references=txt_ref, hypotheses=txt_hyp)
评估指标:
- BLEU-4:n-gram精度,主要翻译指标
- ROUGE:召回率导向的评估
- 数据清洗(重要!)
python
# Phoenix-2014数据需要特殊清洗
if config['data']['dataset_name'].lower() == 'phoenix-2014t':
gls_ref = [clean_phoenix_2014_trans(results[n]['gls_ref']) for n in results]
gls_hyp = [clean_phoenix_2014_trans(results[n][hyp_name]) for n in results]
清洗内容:
- 移除标点符号
- 统一大小写
- 处理特殊字符
- 标准化gloss格式
五、模型架构细节推测
- 视觉编码器
python
# 可能的结构
class VisualEncoder(nn.Module):
def __init__(self):
# 3D-CNN 或 Transformer
# 处理视频序列,提取时空特征
def forward(self, video_frames):
# 输入: [batch, frames, H, W, C]
# 输出: [batch, seq_len, hidden_dim]
- 识别网络(CTC-based)
python
class RecognitionNetwork(nn.Module):
# CTC(Connectionist Temporal Classification)
# 处理变长序列对齐问题
def decode(self, gloss_logits, beam_size):
# Beam Search解码
# 将连续输出映射到离散gloss序列
- 翻译网络(Transformer-based)
python
class TranslationNetwork(nn.Module):
# Transformer编码器-解码器
# 将视觉特征转换为德语文本
def generate_txt(self, transformer_inputs, generate_cfg):
# 自回归生成文本
# 使用beam search或sampling
六、多任务学习策略
- 损失函数组合
python
total_loss = recognition_loss + λ * translation_loss
# λ是平衡两个任务的超参数
-
共享表示学习
视频输入 → 共享视觉编码器 → 共享特征
↓
识别头 → Gloss输出
翻译头 → 德语输出 -
课程学习
可能采用:
- 先训练识别任务(相对简单)
- 然后联合训练识别+翻译
- 最后微调翻译任务
七、配置文件和超参数
- YAML配置文件示例
yaml
# configs/csl-daily_s2g.yaml
data:
dataset_name: 'phoenix-2014t'
gloss: 'path/to/gloss2ids.pkl' # 你的词汇表
train_label_path: 'data/train.gloss'
dev_label_path: 'data/dev.gloss'
test_label_path: 'data/test.gloss'
model:
visual_backbone: 'resnet3d'
hidden_dim: 512
num_layers: 6
num_heads: 8
training:
optimization:
lr: 1e-4
weight_decay: 0.01
scheduler: 'cosine'
validation:
recognition:
beam_size: 5
translation:
beam_size: 5
max_length: 100
- 关键超参数
python
batch_size = 2 # 小批量(视频数据内存大)
epochs = 100 # 训练轮数
beam_size = 5 # Beam Search宽度
hidden_dim = 512 # 模型隐藏维度
八、分布式训练支持
- DDP初始化
python
def init_ddp(local_rank):
# 分布式数据并行
dist.init_process_group(backend='nccl', init_method='env://')
- 数据并行策略
- 每个GPU处理一部分数据
- 梯度同步更新
- 支持多机多卡训练
九、实验日志和监控
- WandB集成
python
# 实验跟踪
wandb.init(project='VLP', config=config)
wandb.log({'epoch': epoch, 'train_loss': loss, 'dev_bleu': bleu4})
- 日志记录
python
# 保存训练日志
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
十、完整推理流程示例
输入视频处理:
python
# 1. 加载视频
video = load_video('weather_forecast.mp4') # [T, H, W, C]
# 2. 提取特征
visual_features = model.visual_encoder(video)
# 3. 识别手语
gloss_logits = model.recognition_head(visual_features)
predicted_gloss = model.decode(gloss_logits) # "REGEN MORGEN NORD"
# 4. 翻译成德语
german_text = model.translate(visual_features) # "Morgen Regen im Norden"
十一、手语识别的特殊挑战处理
-
序列长度不匹配
视频帧数: 300帧
Gloss数量: 5个
通过CTC解决对齐问题 -
视觉特征提取
- 使用3D-CNN捕获时空特征
- 姿态估计作为辅助特征
- 光流信息增强运动理解
- 数据增强策略
python
# 手语特定的数据增强
- 时间扭曲(改变播放速度)
- 空间裁剪(模拟不同视角)
- 颜色抖动(适应不同光照)
总结
这段代码实现了一个完整的手语识别与翻译系统,具有以下特点:
- 多任务架构:同时学习识别(gloss)和翻译(德语)
- 专业数据处理:使用Phoenix-2014数据集和专用清洗函数
- 工业级实现:支持分布式训练、实验跟踪、模型保存
- 领域特定优化:针对手语识别的CTC解码、beam search等
- 可扩展设计:模块化设计,易于扩展到其他手语数据集
这个系统将视觉的手语视频 通过深度学习模型转换为结构化的文本输出,是典型的跨模态人工智能应用,对于促进听障人士交流具有重要意义。