整体架构代码详细解读
好的,我将对这段手语识别与翻译的代码进行逐行详细解释。由于代码较长,我将按功能模块进行分组解释。
一、导入模块部分
python
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from Tokenizer import GlossTokenizer_S2G # 手语Gloss的tokenizer
from model import SignLanguageModel # 手语识别翻译模型
import utils as utils # 工具函数
from datasets import S2T_Dataset # 手语到文本的数据集
import os
import time
import argparse, json, datetime
import numpy as np
from collections import defaultdict
import yaml # 配置文件解析
import random
import wandb # 实验跟踪工具
from pathlib import Path
import math
import sys
from typing import Iterable
from loguru import logger # 日志记录
# *metric - 评估指标
from metrics import wer_list, bleu, rouge # WER(词错误率), BLEU, ROUGE
import torch.distributed as dist # 分布式训练
# global definition
from optimizer import build_optimizer, build_scheduler # 优化器和学习率调度器
from phoenix_cleanup import clean_phoenix_2014_trans, clean_phoenix_2014 # Phoenix数据集清洗函数
关键解释:
GlossTokenizer_S2G: 专门处理手语gloss的tokenizer,基于你的gloss2ids.pklSignLanguageModel: 核心模型,同时进行手语识别和翻译S2T_Dataset: 手语到文本的数据集,加载Phoenix数据集wer_list: 计算词错误率,手语识别的主要评估指标clean_phoenix_2014*: Phoenix数据集专用的清洗函数,处理标注格式
二、参数解析器
python
def get_args_parser():
parser = argparse.ArgumentParser('Visual-Language-Pretraining (VLP) V2 scripts', add_help=False)
# 批量大小参数
parser.add_argument('--batch-size', default=2, type=int) # 设为2,视频数据内存消耗大
# 训练轮数
parser.add_argument('--epochs', default=100, type=int)
# 分布式训练参数
parser.add_argument('--world_size', default=2, type=int,
help='number of distributed processes') # 分布式进程数
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') # 分布式初始化URL
parser.add_argument('--local_rank', default=0, type=int) # 本地GPU rank
# * 微调参数
parser.add_argument('--finetune', default='', help='finetune from checkpoint') # 从检查点微调
# * 基本参数
parser.add_argument('--device', default='cuda',
help='device to use for training / testing') # 设备
parser.add_argument('--seed', default=0, type=int) # 随机种子
parser.add_argument('--resume', default='', help='resume from checkpoint') # 恢复训练
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch') # 起始轮数
parser.add_argument('--eval', action='store_true', help='Perform evaluation only') # 仅评估模式
parser.add_argument('--num_workers', default=4, type=int) # 数据加载工作进程数
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') # 锁页内存
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True) # 默认使用锁页内存
parser.add_argument('--config', type=str, default='configs/csl-daily_s2g.yaml') # 配置文件路径
# * wandb参数
parser.add_argument("--log_all", action="store_true",
help="flag to log in all processes, otherwise only in rank0",) # 是否所有进程都记录日志
parser.add_argument("--entity", type=str,
help="wandb entity",) # wandb组织/用户
parser.add_argument("--project", type=str, default='VLP',
help="wandb project",) # wandb项目名
return parser
三、分布式初始化函数
python
def init_ddp(local_rank):
torch.cuda.set_device(local_rank) # 设置当前GPU设备
os.environ['RANK'] = str(local_rank) # 设置环境变量RANK
dist.init_process_group(backend='nccl', init_method='env://') # 初始化分布式进程组,使用NCCL后端
解释:初始化分布式数据并行(DDP)训练,允许多GPU训练。
四、主函数(main)第一部分:初始化
python
def main(args, config):
# 初始化分布式训练模式
utils.init_distributed_mode(args)
print(args) # 打印参数
# 设置设备
device = torch.device(args.device)
# 设置随机种子以保证可重复性
seed = args.seed + utils.get_rank() # 每个进程使用不同的种子
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
cudnn.benchmark = False # 关闭cuDNN自动优化,保证确定性
五、数据准备部分
python
print(f"Creating dataset:")
# 创建Gloss tokenizer,用于将gloss转换为ID
tokenizer = GlossTokenizer_S2G(config['gloss']) # config['gloss']指向gloss2ids.pkl
# 创建训练数据集
train_data = S2T_Dataset(path=config['data']['train_label_path'], # 训练标注路径
tokenizer=tokenizer, # gloss tokenizer
config=config, # 配置文件
args=args, # 命令行参数
phase='train', # 训练阶段
training_refurbish=True) # 是否进行数据增强
print(train_data) # 打印数据集信息
# 创建训练数据加载器
train_dataloader = DataLoader(train_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
collate_fn=train_data.collate_fn, # 自定义批处理函数
shuffle=True, # 训练时打乱数据
pin_memory=args.pin_mem, # 使用锁页内存
drop_last=True) # 丢弃最后不完整的batch
# 创建验证数据集和加载器(与训练集类似,但不shuffle)
dev_data = S2T_Dataset(path=config['data']['dev_label_path'],
tokenizer=tokenizer, config=config, args=args,
phase='val', training_refurbish=True)
print(dev_data)
dev_dataloader = DataLoader(dev_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
collate_fn=dev_data.collate_fn,
pin_memory=args.pin_mem)
# 创建测试数据集和加载器
test_data = S2T_Dataset(path=config['data']['test_label_path'],
tokenizer=tokenizer, config=config, args=args,
phase='test', training_refurbish=True)
print(test_data)
test_dataloader = DataLoader(test_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
collate_fn=test_data.collate_fn,
pin_memory=args.pin_mem)
数据流程 :Phoenix标注文件 → S2T_Dataset → DataLoader → 批数据
六、模型和优化器初始化
python
print(f"Creating model:")
# 创建手语识别翻译模型
model = SignLanguageModel(cfg=config, args=args)
model.to(device) # 移动到指定设备
print(model) # 打印模型结构
# 如果有微调检查点,加载预训练权重
if args.finetune:
checkpoint = torch.load(args.finetune, map_location='cpu') # 加载检查点
ret = model.load_state_dict(checkpoint['model'], strict=False) # 加载模型权重
print('Missing keys: \n', '\n'.join(ret.missing_keys)) # 打印缺失的键
print('Unexpected keys: \n', '\n'.join(ret.unexpected_keys)) # 打印意外的键
# 计算模型参数数量(以MB为单位)
n_parameters = utils.count_parameters_in_MB(model)
print(f'number of params: {n_parameters}M')
# 构建优化器
optimizer = build_optimizer(config=config['training']['optimization'], model=model)
# 构建学习率调度器
scheduler, scheduler_type = build_scheduler(config=config['training']['optimization'], optimizer=optimizer)
# 设置模型保存目录
output_dir = Path(config['training']['model_dir'])
七、恢复训练检查点
python
# 如果有恢复检查点,加载训练状态
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu') # 加载检查点
model.load_state_dict(checkpoint['model'], strict=True) # 加载模型权重
# 如果不是仅评估模式,加载优化器和调度器状态
if not args.eval and 'optimizer' in checkpoint and 'scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器状态
scheduler.load_state_dict(checkpoint['scheduler']) # 加载调度器状态
args.start_epoch = checkpoint['epoch'] + 1 # 设置起始轮数
八、仅评估模式
python
# 如果只是评估模式,不训练,直接评估
if args.eval:
if not args.resume: # 评估需要加载训练好的模型
logger.warning('Please specify the trained model: --resume /path/to/best_checkpoint.pth')
# 评估验证集
dev_stats = evaluate(args, config, dev_dataloader, model, tokenizer, epoch=0, beam_size=5,
generate_cfg=config['training']['validation']['translation'],
do_translation=config['do_translation'], do_recognition=config['do_recognition'])
print(f"Dev loss of the network on the {len(dev_dataloader)} test videos: {dev_stats['loss']:.3f}")
# 评估测试集
test_stats = evaluate(args, config, test_dataloader, model, tokenizer, epoch=0, beam_size=5,
generate_cfg=config['testing']['translation'],
do_translation=config['do_translation'], do_recognition=config['do_recognition'])
print(f"Test loss of the network on the {len(test_dataloader)} test videos: {test_stats['loss']:.3f}")
return # 评估完成,直接返回
九、训练循环
python
print(f"Start training for {args.epochs} epochs")
start_time = time.time() # 记录开始时间
# 初始化最佳指标
min_loss = 200 # 最小WER(手语识别指标,越低越好)
bleu_4 = 0 # 最佳BLEU-4(手语翻译指标,越高越好)
# 开始训练循环
for epoch in range(args.start_epoch, args.epochs):
scheduler.step() # 更新学习率
# 1. 训练一个epoch
train_stats = train_one_epoch(args, model, tokenizer, train_dataloader, optimizer, device, epoch)
# 2. 保存当前检查点
checkpoint_paths = [output_dir / f'checkpoint.pth'] # 检查点保存路径
for checkpoint_path in checkpoint_paths:
utils.save_on_master({ # 只在主进程保存
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
}, checkpoint_path)
# 3. 在验证集上评估
test_stats = evaluate(args, config, dev_dataloader, model, tokenizer, epoch,
beam_size=config['training']['validation']['recognition']['beam_size'],
generate_cfg=config['training']['validation']['translation'],
do_translation=config['do_translation'], do_recognition=config['do_recognition'])
十、保存最佳模型逻辑
python
# 根据任务类型保存最佳模型
if config['task'] == "S2T": # 如果是手语翻译任务
if bleu_4 < test_stats["bleu4"]: # 如果当前BLEU-4更好
bleu_4 = test_stats["bleu4"] # 更新最佳BLEU-4
checkpoint_paths = [output_dir / 'best_checkpoint.pth'] # 最佳模型路径
for checkpoint_path in checkpoint_paths:
utils.save_on_master({ # 保存最佳模型
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
}, checkpoint_path)
print(f"* DEV BLEU-4 {test_stats['bleu4']:.3f} Max DEV BLEU-4 {bleu_4}")
else: # 如果是手语识别任务
if min_loss > test_stats["wer"]: # 如果当前WER更低(更好)
min_loss = test_stats["wer"] # 更新最佳WER
checkpoint_paths = [output_dir / 'best_checkpoint.pth'] # 最佳模型路径
for checkpoint_path in checkpoint_paths:
utils.save_on_master({ # 保存最佳模型
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'epoch': epoch,
}, checkpoint_path)
print(f"* DEV wer {test_stats['wer']:.3f} Min DEV WER {min_loss}")
十一、日志记录
python
# 记录到wandb(如果启用)
if args.run:
args.run.log(
{'epoch': epoch + 1,
'training/train_loss': train_stats['loss'],
'dev/dev_loss': test_stats['loss'],
'dev/min_loss': min_loss})
# 合并训练和测试统计信息
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
# 写入日志文件
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
十二、训练结束后评估最佳模型
python
# 最后一个epoch:用最佳模型在测试集上评估
test_on_last_epoch = True
if test_on_last_epoch:
# 加载最佳检查点
checkpoint = torch.load(str(output_dir) + '/best_checkpoint.pth', map_location='cpu')
model.load_state_dict(checkpoint['model'], strict=True) # 加载最佳模型
# 在验证集上评估
dev_stats = evaluate(args, config, dev_dataloader, model, tokenizer, epoch=0,
beam_size=config['testing']['recognition']['beam_size'],
generate_cfg=config['training']['validation']['translation'],
do_translation=config['do_translation'], do_recognition=config['do_recognition'])
print(f"Dev loss of the network on the {len(dev_dataloader)} test videos: {dev_stats['loss']:.3f}")
# 在测试集上评估
test_stats = evaluate(args, config, test_dataloader, model, tokenizer, epoch=0,
beam_size=config['testing']['recognition']['beam_size'],
generate_cfg=config['testing']['translation'],
do_translation=config['do_translation'], do_recognition=config['do_recognition'])
print(f"Test loss of the network on the {len(test_dataloader)} test videos: {test_stats['loss']:.3f}")
# 记录最终评估结果
if config['do_recognition']: # 如果做了手语识别
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps({'Dev WER:': dev_stats['wer'],
'Test WER:': test_stats['wer']}) + "\n")
if config['do_translation']: # 如果做了手语翻译
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps({'Dev Bleu-4:': dev_stats['bleu4'],
'Test Bleu-4:': test_stats['bleu4']}) + "\n")
# 计算总训练时间
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
十三、单epoch训练函数
python
def train_one_epoch(args, model: torch.nn.Module, criterion,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int):
model.train() # 设置模型为训练模式
metric_logger = utils.MetricLogger(delimiter=" ") # 创建指标记录器
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) # 添加学习率记录
header = 'Epoch: [{}/{}]'.format(epoch, args.epochs) # 进度条头
print_freq = 10 # 每10个batch打印一次
# 遍历数据加载器
for step, (src_input) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
optimizer.zero_grad() # 清空梯度
output = model(src_input) # 前向传播
# 使用异常检测(调试用)
with torch.autograd.set_detect_anomaly(True):
output['total_loss'].backward() # 反向传播计算梯度
optimizer.step() # 更新参数
model.zero_grad() # 再次清空梯度(安全)
loss_value = output['total_loss'].item() # 获取损失值
# 检查损失是否有效
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1) # 损失无效,停止训练
# 更新指标记录器
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"]) # 记录学习率
# 记录到wandb
if args.run:
args.run.log({'epoch': epoch + 1, 'epoch/train_loss': loss_value})
# 收集所有进程的统计信息
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()} # 返回平均指标
十四、评估函数(重点)
python
def evaluate(args, config, dev_dataloader, model, tokenizer, epoch, beam_size=1, generate_cfg={}, do_translation=True,
do_recognition=True):
model.eval() # 设置模型为评估模式
metric_logger = utils.MetricLogger(delimiter=" ") # 指标记录器
header = 'Test:' # 进度条头
print_freq = 10 # 打印频率
results = defaultdict(dict) # 存储每个样本的结果
with torch.no_grad(): # 关闭梯度计算
# 遍历验证/测试集
for step, (src_input) in enumerate(metric_logger.log_every(dev_dataloader, print_freq, header)):
output = model(src_input) # 前向传播
# 如果做手语识别
if do_recognition:
# 遍历所有输出,找到gloss_logits
for k, gls_logits in output.items():
if not 'gloss_logits' in k: # 如果不是gloss_logits,跳过
continue
logits_name = k.replace('gloss_logits', '') # 提取logits名称
# CTC解码获取预测的gloss序列
ctc_decode_output = model.recognition_network.decode(
gloss_logits=gls_logits, # gloss logits
beam_size=beam_size, # beam search宽度
input_lengths=output['input_lengths']) # 输入长度
# 将ID转换为gloss token
batch_pred_gls = tokenizer.convert_ids_to_tokens(ctc_decode_output)
# 存储每个样本的结果
for name, gls_hyp, gls_ref in zip(src_input['name'], batch_pred_gls, src_input['gloss']):
results[name][f'{logits_name}gls_hyp'] = \
' '.join(gls_hyp).upper() if tokenizer.lower_case \
else ' '.join(gls_hyp) # 预测的gloss序列
results[name]['gls_ref'] = gls_ref.upper() if tokenizer.lower_case \
else gls_ref # 参考gloss序列
# 如果做手语翻译
if do_translation:
# 生成德语文本
generate_output = model.generate_txt(
transformer_inputs=output['transformer_inputs'], # 翻译器输入
generate_cfg=generate_cfg) # 生成配置
# 存储每个样本的翻译结果
for name, txt_hyp, txt_ref in zip(src_input['name'], generate_output['decoded_sequences'],
src_input['text']):
results[name]['txt_hyp'], results[name]['txt_ref'] = txt_hyp, txt_ref
# 更新损失指标
metric_logger.update(loss=output['total_loss'].item())
十五、评估指标计算
python
# 计算手语识别指标(WER)
if do_recognition:
evaluation_results = {}
evaluation_results['wer'] = 200 # 初始化WER为较大值
# 遍历所有预测结果
for hyp_name in results[name].keys():
if not 'gls_hyp' in hyp_name: # 如果不是预测结果,跳过
continue
k = hyp_name.replace('gls_hyp', '') # 提取结果名称
# 根据数据集类型进行数据清洗
if config['data']['dataset_name'].lower() == 'phoenix-2014t':
# 清洗Phoenix-2014T数据
gls_ref = [clean_phoenix_2014_trans(results[n]['gls_ref']) for n in results]
gls_hyp = [clean_phoenix_2014_trans(results[n][hyp_name]) for n in results]
elif config['data']['dataset_name'].lower() == 'phoenix-2014':
# 清洗Phoenix-2014数据
gls_ref = [clean_phoenix_2014(results[n]['gls_ref']) for n in results]
gls_hyp = [clean_phoenix_2014(results[n][hyp_name]) for n in results]
elif config['data']['dataset_name'].lower() == 'csl-daily':
# CSL-Daily数据集
gls_ref = [results[n]['gls_ref'] for n in results]
gls_hyp = [results[n][hyp_name] for n in results]
# 计算WER(词错误率)
wer_results = wer_list(hypotheses=gls_hyp, references=gls_ref)
evaluation_results[k + 'wer_list'] = wer_results # 存储详细结果
evaluation_results['wer'] = min(wer_results['wer'], evaluation_results['wer']) # 取最小WER
metric_logger.update(wer=evaluation_results['wer']) # 更新WER指标
# 计算手语翻译指标(BLEU, ROUGE)
if do_translation:
txt_ref = [results[n]['txt_ref'] for n in results] # 参考翻译
txt_hyp = [results[n]['txt_hyp'] for n in results] # 预测翻译
# 计算BLEU分数
bleu_dict = bleu(references=txt_ref, hypotheses=txt_hyp, level=config['data']['level'])
# 计算ROUGE分数
rouge_score = rouge(references=txt_ref, hypotheses=txt_hyp, level=config['data']['level'])
# 打印指标
for k, v in bleu_dict.items():
print('{} {:.2f}'.format(k, v))
print('ROUGE: {:.2f}'.format(rouge_score))
# 存储评估结果
evaluation_results['rouge'], evaluation_results['bleu'] = rouge_score, bleu_dict
# 记录到wandb
wandb.log({'eval/BLEU4': bleu_dict['bleu4']})
wandb.log({'eval/ROUGE': rouge_score})
# 更新指标记录器
metric_logger.update(bleu1=bleu_dict['bleu1'])
metric_logger.update(bleu2=bleu_dict['bleu2'])
metric_logger.update(bleu3=bleu_dict['bleu3'])
metric_logger.update(bleu4=bleu_dict['bleu4'])
metric_logger.update(rouge=rouge_score)
十六、评估结果记录
python
# 记录到wandb
if args.run:
args.run.log(
{'epoch': epoch + 1,
'epoch/dev_loss': output['recognition_loss'].item(),
'wer': evaluation_results['wer']})
# 打印平均统计信息
print("* Averaged stats:", metric_logger)
print('* DEV loss {losses.global_avg:.3f}'.format(losses=metric_logger.loss))
# 返回所有指标的全局平均值
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
十七、WandB实验设置
python
def setup_run(args, config):
# 如果所有进程都记录日志
if args.log_all:
os.environ["WANDB_MODE"] = config['training']['wandb'] if not args.eval else 'disabled'
run = wandb.init(
entity=args.entity, # wandb组织/用户
project=args.project, # 项目名
group=args.output_dir.split('/')[-1], # 实验组名
config=config, # 配置参数
)
# 定义指标
run.define_metric("epoch")
run.define_metric("training/*", step_metric="epoch")
run.define_metric("dev/*", step_metric="epoch")
else:
# 只在主进程记录日志
if utils.is_main_process():
os.environ["WANDB_MODE"] = config['training']['wandb'] if not args.eval else 'disabled'
run = wandb.init(
entity=args.entity,
project=args.project,
config=config,
)
run.define_metric("epoch")
run.define_metric("training/*", step_metric="epoch")
run.define_metric("dev/*", step_metric="epoch")
else:
os.environ["WANDB_MODE"] = 'disabled'
run = False
return run
十八、程序入口点
python
if __name__ == '__main__':
os.environ["TOKENIZERS_PARALLELISM"] = "false" # 禁用tokenizers并行,避免警告
# 创建参数解析器
parser = argparse.ArgumentParser('Visual-Language-Pretraining (VLP) V2 scripts', parents=[get_args_parser()])
args = parser.parse_args() # 解析命令行参数
# 加载YAML配置文件
with open(args.config, 'r+', encoding='utf-8') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
# 初始化wandb运行
args.run = setup_run(args, config)
# 创建模型保存目录
Path(config['training']['model_dir']).mkdir(parents=True, exist_ok=True)
# 运行主函数
main(args, config)
关键总结
- 多任务架构:同时训练手语识别(Gloss序列)和手语翻译(德语文本)
- Phoenix数据集专用:使用专门的清洗函数处理数据集格式
- CTC解码:用于手语识别,处理视频帧与gloss序列的对齐问题
- Beam Search:用于手语翻译的文本生成
- 分布式训练:支持多GPU训练加速
- 完整的实验跟踪:WandB日志记录、模型检查点保存
- 评估指标 :
- 手语识别:WER(词错误率)
- 手语翻译:BLEU-4、ROUGE
这是一个工业级的手语识别与翻译系统实现,代码结构清晰,功能完整,支持训练、评估、分布式训练等多种功能。