手语识别及翻译项目实战系列(五)整体架构代码详细代码实现

整体架构代码详细解读

好的,我将对这段手语识别与翻译的代码进行逐行详细解释。由于代码较长,我将按功能模块进行分组解释。

一、导入模块部分

python 复制代码
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from Tokenizer import GlossTokenizer_S2G  # 手语Gloss的tokenizer
from model import SignLanguageModel  # 手语识别翻译模型
import utils as utils  # 工具函数
from datasets import S2T_Dataset  # 手语到文本的数据集
import os
import time
import argparse, json, datetime
import numpy as np
from collections import defaultdict
import yaml  # 配置文件解析
import random
import wandb  # 实验跟踪工具
from pathlib import Path
import math
import sys
from typing import Iterable
from loguru import logger  # 日志记录

# *metric - 评估指标
from metrics import wer_list, bleu, rouge  # WER(词错误率), BLEU, ROUGE
import torch.distributed as dist  # 分布式训练

# global definition
from optimizer import build_optimizer, build_scheduler  # 优化器和学习率调度器
from phoenix_cleanup import clean_phoenix_2014_trans, clean_phoenix_2014  # Phoenix数据集清洗函数

关键解释

  • GlossTokenizer_S2G: 专门处理手语gloss的tokenizer,基于你的gloss2ids.pkl
  • SignLanguageModel: 核心模型,同时进行手语识别和翻译
  • S2T_Dataset: 手语到文本的数据集,加载Phoenix数据集
  • wer_list: 计算词错误率,手语识别的主要评估指标
  • clean_phoenix_2014*: Phoenix数据集专用的清洗函数,处理标注格式

二、参数解析器

python 复制代码
def get_args_parser():
    parser = argparse.ArgumentParser('Visual-Language-Pretraining (VLP) V2 scripts', add_help=False)

    # 批量大小参数
    parser.add_argument('--batch-size', default=2, type=int)  # 设为2,视频数据内存消耗大
    
    # 训练轮数
    parser.add_argument('--epochs', default=100, type=int)

    # 分布式训练参数
    parser.add_argument('--world_size', default=2, type=int,
                        help='number of distributed processes')  # 分布式进程数
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')  # 分布式初始化URL
    parser.add_argument('--local_rank', default=0, type=int)  # 本地GPU rank

    # * 微调参数
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')  # 从检查点微调
    
    # * 基本参数
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')  # 设备
    parser.add_argument('--seed', default=0, type=int)  # 随机种子
    parser.add_argument('--resume', default='', help='resume from checkpoint')  # 恢复训练
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')  # 起始轮数
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')  # 仅评估模式
    parser.add_argument('--num_workers', default=4, type=int)  # 数据加载工作进程数
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')  # 锁页内存
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)  # 默认使用锁页内存
    parser.add_argument('--config', type=str, default='configs/csl-daily_s2g.yaml')  # 配置文件路径

    # * wandb参数
    parser.add_argument("--log_all", action="store_true",
                        help="flag to log in all processes, otherwise only in rank0",)  # 是否所有进程都记录日志
    parser.add_argument("--entity", type=str,
                        help="wandb entity",)  # wandb组织/用户
    parser.add_argument("--project", type=str, default='VLP',
                        help="wandb project",)  # wandb项目名

    return parser

三、分布式初始化函数

python 复制代码
def init_ddp(local_rank):
    torch.cuda.set_device(local_rank)  # 设置当前GPU设备
    os.environ['RANK'] = str(local_rank)  # 设置环境变量RANK
    dist.init_process_group(backend='nccl', init_method='env://')  # 初始化分布式进程组,使用NCCL后端

解释:初始化分布式数据并行(DDP)训练,允许多GPU训练。

四、主函数(main)第一部分:初始化

python 复制代码
def main(args, config):
    # 初始化分布式训练模式
    utils.init_distributed_mode(args)
    print(args)  # 打印参数
    
    # 设置设备
    device = torch.device(args.device)
    
    # 设置随机种子以保证可重复性
    seed = args.seed + utils.get_rank()  # 每个进程使用不同的种子
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = False  # 关闭cuDNN自动优化,保证确定性

五、数据准备部分

python 复制代码
    print(f"Creating dataset:")
    # 创建Gloss tokenizer,用于将gloss转换为ID
    tokenizer = GlossTokenizer_S2G(config['gloss'])  # config['gloss']指向gloss2ids.pkl
    
    # 创建训练数据集
    train_data = S2T_Dataset(path=config['data']['train_label_path'],  # 训练标注路径
                             tokenizer=tokenizer,  # gloss tokenizer
                             config=config,  # 配置文件
                             args=args,  # 命令行参数
                             phase='train',  # 训练阶段
                             training_refurbish=True)  # 是否进行数据增强
    print(train_data)  # 打印数据集信息
    
    # 创建训练数据加载器
    train_dataloader = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  num_workers=args.num_workers,
                                  collate_fn=train_data.collate_fn,  # 自定义批处理函数
                                  shuffle=True,  # 训练时打乱数据
                                  pin_memory=args.pin_mem,  # 使用锁页内存
                                  drop_last=True)  # 丢弃最后不完整的batch

    # 创建验证数据集和加载器(与训练集类似,但不shuffle)
    dev_data = S2T_Dataset(path=config['data']['dev_label_path'], 
                           tokenizer=tokenizer, config=config, args=args,
                           phase='val', training_refurbish=True)
    print(dev_data)
    dev_dataloader = DataLoader(dev_data,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers,
                                collate_fn=dev_data.collate_fn,
                                pin_memory=args.pin_mem)

    # 创建测试数据集和加载器
    test_data = S2T_Dataset(path=config['data']['test_label_path'], 
                            tokenizer=tokenizer, config=config, args=args,
                            phase='test', training_refurbish=True)
    print(test_data)
    test_dataloader = DataLoader(test_data,
                                 batch_size=args.batch_size,
                                 num_workers=args.num_workers,
                                 collate_fn=test_data.collate_fn,
                                 pin_memory=args.pin_mem)

数据流程Phoenix标注文件S2T_DatasetDataLoader → 批数据

六、模型和优化器初始化

python 复制代码
    print(f"Creating model:")
    # 创建手语识别翻译模型
    model = SignLanguageModel(cfg=config, args=args)
    model.to(device)  # 移动到指定设备
    print(model)  # 打印模型结构

    # 如果有微调检查点,加载预训练权重
    if args.finetune:
        checkpoint = torch.load(args.finetune, map_location='cpu')  # 加载检查点
        ret = model.load_state_dict(checkpoint['model'], strict=False)  # 加载模型权重
        print('Missing keys: \n', '\n'.join(ret.missing_keys))  # 打印缺失的键
        print('Unexpected keys: \n', '\n'.join(ret.unexpected_keys))  # 打印意外的键

    # 计算模型参数数量(以MB为单位)
    n_parameters = utils.count_parameters_in_MB(model)
    print(f'number of params: {n_parameters}M')
    
    # 构建优化器
    optimizer = build_optimizer(config=config['training']['optimization'], model=model)
    
    # 构建学习率调度器
    scheduler, scheduler_type = build_scheduler(config=config['training']['optimization'], optimizer=optimizer)
    
    # 设置模型保存目录
    output_dir = Path(config['training']['model_dir'])

七、恢复训练检查点

python 复制代码
    # 如果有恢复检查点,加载训练状态
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')  # 加载检查点
        model.load_state_dict(checkpoint['model'], strict=True)  # 加载模型权重
        
        # 如果不是仅评估模式,加载优化器和调度器状态
        if not args.eval and 'optimizer' in checkpoint and 'scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器状态
            scheduler.load_state_dict(checkpoint['scheduler'])  # 加载调度器状态
            args.start_epoch = checkpoint['epoch'] + 1  # 设置起始轮数

八、仅评估模式

python 复制代码
    # 如果只是评估模式,不训练,直接评估
    if args.eval:
        if not args.resume:  # 评估需要加载训练好的模型
            logger.warning('Please specify the trained model: --resume /path/to/best_checkpoint.pth')
        
        # 评估验证集
        dev_stats = evaluate(args, config, dev_dataloader, model, tokenizer, epoch=0, beam_size=5,
                              generate_cfg=config['training']['validation']['translation'],
                              do_translation=config['do_translation'], do_recognition=config['do_recognition'])
        print(f"Dev loss of the network on the {len(dev_dataloader)} test videos: {dev_stats['loss']:.3f}")

        # 评估测试集
        test_stats = evaluate(args, config, test_dataloader, model, tokenizer, epoch=0, beam_size=5,
                              generate_cfg=config['testing']['translation'],
                              do_translation=config['do_translation'], do_recognition=config['do_recognition'])
        print(f"Test loss of the network on the {len(test_dataloader)} test videos: {test_stats['loss']:.3f}")
        return  # 评估完成,直接返回

九、训练循环

python 复制代码
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()  # 记录开始时间
    
    # 初始化最佳指标
    min_loss = 200  # 最小WER(手语识别指标,越低越好)
    bleu_4 = 0  # 最佳BLEU-4(手语翻译指标,越高越好)
    
    # 开始训练循环
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step()  # 更新学习率
        
        # 1. 训练一个epoch
        train_stats = train_one_epoch(args, model, tokenizer, train_dataloader, optimizer, device, epoch)
        
        # 2. 保存当前检查点
        checkpoint_paths = [output_dir / f'checkpoint.pth']  # 检查点保存路径
        for checkpoint_path in checkpoint_paths:
            utils.save_on_master({  # 只在主进程保存
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'epoch': epoch,
            }, checkpoint_path)
        
        # 3. 在验证集上评估
        test_stats = evaluate(args, config, dev_dataloader, model, tokenizer, epoch,
                              beam_size=config['training']['validation']['recognition']['beam_size'],
                              generate_cfg=config['training']['validation']['translation'],
                              do_translation=config['do_translation'], do_recognition=config['do_recognition'])

十、保存最佳模型逻辑

python 复制代码
        # 根据任务类型保存最佳模型
        if config['task'] == "S2T":  # 如果是手语翻译任务
            if bleu_4 < test_stats["bleu4"]:  # 如果当前BLEU-4更好
                bleu_4 = test_stats["bleu4"]  # 更新最佳BLEU-4
                checkpoint_paths = [output_dir / 'best_checkpoint.pth']  # 最佳模型路径
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({  # 保存最佳模型
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch,
                    }, checkpoint_path)

            print(f"* DEV BLEU-4 {test_stats['bleu4']:.3f} Max DEV BLEU-4 {bleu_4}")
        else:  # 如果是手语识别任务
            if min_loss > test_stats["wer"]:  # 如果当前WER更低(更好)
                min_loss = test_stats["wer"]  # 更新最佳WER
                checkpoint_paths = [output_dir / 'best_checkpoint.pth']  # 最佳模型路径
                for checkpoint_path in checkpoint_paths:
                    utils.save_on_master({  # 保存最佳模型
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch,
                    }, checkpoint_path)
            print(f"* DEV wer {test_stats['wer']:.3f} Min DEV WER {min_loss}")

十一、日志记录

python 复制代码
        # 记录到wandb(如果启用)
        if args.run:
            args.run.log(
                {'epoch': epoch + 1, 
                 'training/train_loss': train_stats['loss'], 
                 'dev/dev_loss': test_stats['loss'],
                 'dev/min_loss': min_loss})

        # 合并训练和测试统计信息
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        # 写入日志文件
        with (output_dir / "log.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")

十二、训练结束后评估最佳模型

python 复制代码
        # 最后一个epoch:用最佳模型在测试集上评估
        test_on_last_epoch = True
        if test_on_last_epoch:
            # 加载最佳检查点
            checkpoint = torch.load(str(output_dir) + '/best_checkpoint.pth', map_location='cpu')
            model.load_state_dict(checkpoint['model'], strict=True)  # 加载最佳模型
            
            # 在验证集上评估
            dev_stats = evaluate(args, config, dev_dataloader, model, tokenizer, epoch=0, 
                                 beam_size=config['testing']['recognition']['beam_size'],
                                 generate_cfg=config['training']['validation']['translation'],
                                 do_translation=config['do_translation'], do_recognition=config['do_recognition'])
            print(f"Dev loss of the network on the {len(dev_dataloader)} test videos: {dev_stats['loss']:.3f}")
            
            # 在测试集上评估
            test_stats = evaluate(args, config, test_dataloader, model, tokenizer, epoch=0, 
                                  beam_size=config['testing']['recognition']['beam_size'],
                                  generate_cfg=config['testing']['translation'],
                                  do_translation=config['do_translation'], do_recognition=config['do_recognition'])
            print(f"Test loss of the network on the {len(test_dataloader)} test videos: {test_stats['loss']:.3f}")
            
            # 记录最终评估结果
            if config['do_recognition']:  # 如果做了手语识别
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps({'Dev WER:': dev_stats['wer'],
                                        'Test WER:': test_stats['wer']}) + "\n")
            if config['do_translation']:  # 如果做了手语翻译
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps({'Dev Bleu-4:': dev_stats['bleu4'],
                                        'Test Bleu-4:': test_stats['bleu4']}) + "\n")
    
    # 计算总训练时间
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

十三、单epoch训练函数

python 复制代码
def train_one_epoch(args, model: torch.nn.Module, criterion,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int):
    model.train()  # 设置模型为训练模式
    metric_logger = utils.MetricLogger(delimiter="  ")  # 创建指标记录器
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))  # 添加学习率记录
    header = 'Epoch: [{}/{}]'.format(epoch, args.epochs)  # 进度条头
    print_freq = 10  # 每10个batch打印一次
    
    # 遍历数据加载器
    for step, (src_input) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        optimizer.zero_grad()  # 清空梯度
        
        output = model(src_input)  # 前向传播
        
        # 使用异常检测(调试用)
        with torch.autograd.set_detect_anomaly(True):
            output['total_loss'].backward()  # 反向传播计算梯度
        
        optimizer.step()  # 更新参数
        model.zero_grad()  # 再次清空梯度(安全)
        
        loss_value = output['total_loss'].item()  # 获取损失值
        
        # 检查损失是否有效
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)  # 损失无效,停止训练
        
        # 更新指标记录器
        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])  # 记录学习率
    
    # 记录到wandb
    if args.run:
        args.run.log({'epoch': epoch + 1, 'epoch/train_loss': loss_value})
    
    # 收集所有进程的统计信息
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}  # 返回平均指标

十四、评估函数(重点)

python 复制代码
def evaluate(args, config, dev_dataloader, model, tokenizer, epoch, beam_size=1, generate_cfg={}, do_translation=True,
             do_recognition=True):
    model.eval()  # 设置模型为评估模式
    metric_logger = utils.MetricLogger(delimiter="  ")  # 指标记录器
    header = 'Test:'  # 进度条头
    print_freq = 10  # 打印频率
    results = defaultdict(dict)  # 存储每个样本的结果

    with torch.no_grad():  # 关闭梯度计算
        # 遍历验证/测试集
        for step, (src_input) in enumerate(metric_logger.log_every(dev_dataloader, print_freq, header)):
            output = model(src_input)  # 前向传播
            
            # 如果做手语识别
            if do_recognition:
                # 遍历所有输出,找到gloss_logits
                for k, gls_logits in output.items():
                    if not 'gloss_logits' in k:  # 如果不是gloss_logits,跳过
                        continue
                    
                    logits_name = k.replace('gloss_logits', '')  # 提取logits名称
                    
                    # CTC解码获取预测的gloss序列
                    ctc_decode_output = model.recognition_network.decode(
                        gloss_logits=gls_logits,  # gloss logits
                        beam_size=beam_size,  # beam search宽度
                        input_lengths=output['input_lengths'])  # 输入长度
                    
                    # 将ID转换为gloss token
                    batch_pred_gls = tokenizer.convert_ids_to_tokens(ctc_decode_output)
                    
                    # 存储每个样本的结果
                    for name, gls_hyp, gls_ref in zip(src_input['name'], batch_pred_gls, src_input['gloss']):
                        results[name][f'{logits_name}gls_hyp'] = \
                            ' '.join(gls_hyp).upper() if tokenizer.lower_case \
                                else ' '.join(gls_hyp)  # 预测的gloss序列
                        results[name]['gls_ref'] = gls_ref.upper() if tokenizer.lower_case \
                            else gls_ref  # 参考gloss序列
            
            # 如果做手语翻译
            if do_translation:
                # 生成德语文本
                generate_output = model.generate_txt(
                    transformer_inputs=output['transformer_inputs'],  # 翻译器输入
                    generate_cfg=generate_cfg)  # 生成配置
                
                # 存储每个样本的翻译结果
                for name, txt_hyp, txt_ref in zip(src_input['name'], generate_output['decoded_sequences'],
                                                  src_input['text']):
                    results[name]['txt_hyp'], results[name]['txt_ref'] = txt_hyp, txt_ref
            
            # 更新损失指标
            metric_logger.update(loss=output['total_loss'].item())

十五、评估指标计算

python 复制代码
        # 计算手语识别指标(WER)
        if do_recognition:
            evaluation_results = {}
            evaluation_results['wer'] = 200  # 初始化WER为较大值
            
            # 遍历所有预测结果
            for hyp_name in results[name].keys():
                if not 'gls_hyp' in hyp_name:  # 如果不是预测结果,跳过
                    continue
                
                k = hyp_name.replace('gls_hyp', '')  # 提取结果名称
                
                # 根据数据集类型进行数据清洗
                if config['data']['dataset_name'].lower() == 'phoenix-2014t':
                    # 清洗Phoenix-2014T数据
                    gls_ref = [clean_phoenix_2014_trans(results[n]['gls_ref']) for n in results]
                    gls_hyp = [clean_phoenix_2014_trans(results[n][hyp_name]) for n in results]
                elif config['data']['dataset_name'].lower() == 'phoenix-2014':
                    # 清洗Phoenix-2014数据
                    gls_ref = [clean_phoenix_2014(results[n]['gls_ref']) for n in results]
                    gls_hyp = [clean_phoenix_2014(results[n][hyp_name]) for n in results]
                elif config['data']['dataset_name'].lower() == 'csl-daily':
                    # CSL-Daily数据集
                    gls_ref = [results[n]['gls_ref'] for n in results]
                    gls_hyp = [results[n][hyp_name] for n in results]
                
                # 计算WER(词错误率)
                wer_results = wer_list(hypotheses=gls_hyp, references=gls_ref)
                evaluation_results[k + 'wer_list'] = wer_results  # 存储详细结果
                evaluation_results['wer'] = min(wer_results['wer'], evaluation_results['wer'])  # 取最小WER
            
            metric_logger.update(wer=evaluation_results['wer'])  # 更新WER指标

        # 计算手语翻译指标(BLEU, ROUGE)
        if do_translation:
            txt_ref = [results[n]['txt_ref'] for n in results]  # 参考翻译
            txt_hyp = [results[n]['txt_hyp'] for n in results]  # 预测翻译
            
            # 计算BLEU分数
            bleu_dict = bleu(references=txt_ref, hypotheses=txt_hyp, level=config['data']['level'])
            
            # 计算ROUGE分数
            rouge_score = rouge(references=txt_ref, hypotheses=txt_hyp, level=config['data']['level'])
            
            # 打印指标
            for k, v in bleu_dict.items():
                print('{} {:.2f}'.format(k, v))
            print('ROUGE: {:.2f}'.format(rouge_score))
            
            # 存储评估结果
            evaluation_results['rouge'], evaluation_results['bleu'] = rouge_score, bleu_dict
            
            # 记录到wandb
            wandb.log({'eval/BLEU4': bleu_dict['bleu4']})
            wandb.log({'eval/ROUGE': rouge_score})
            
            # 更新指标记录器
            metric_logger.update(bleu1=bleu_dict['bleu1'])
            metric_logger.update(bleu2=bleu_dict['bleu2'])
            metric_logger.update(bleu3=bleu_dict['bleu3'])
            metric_logger.update(bleu4=bleu_dict['bleu4'])
            metric_logger.update(rouge=rouge_score)

十六、评估结果记录

python 复制代码
    # 记录到wandb
    if args.run:
        args.run.log(
            {'epoch': epoch + 1, 
             'epoch/dev_loss': output['recognition_loss'].item(), 
             'wer': evaluation_results['wer']})
    
    # 打印平均统计信息
    print("* Averaged stats:", metric_logger)
    print('* DEV loss {losses.global_avg:.3f}'.format(losses=metric_logger.loss))

    # 返回所有指标的全局平均值
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

十七、WandB实验设置

python 复制代码
def setup_run(args, config):
    # 如果所有进程都记录日志
    if args.log_all:
        os.environ["WANDB_MODE"] = config['training']['wandb'] if not args.eval else 'disabled'
        run = wandb.init(
            entity=args.entity,  # wandb组织/用户
            project=args.project,  # 项目名
            group=args.output_dir.split('/')[-1],  # 实验组名
            config=config,  # 配置参数
        )
        # 定义指标
        run.define_metric("epoch")
        run.define_metric("training/*", step_metric="epoch")
        run.define_metric("dev/*", step_metric="epoch")
    else:
        # 只在主进程记录日志
        if utils.is_main_process():
            os.environ["WANDB_MODE"] = config['training']['wandb'] if not args.eval else 'disabled'
            run = wandb.init(
                entity=args.entity,
                project=args.project,
                config=config,
            )
            run.define_metric("epoch")
            run.define_metric("training/*", step_metric="epoch")
            run.define_metric("dev/*", step_metric="epoch")
        else:
            os.environ["WANDB_MODE"] = 'disabled'
            run = False
    return run

十八、程序入口点

python 复制代码
if __name__ == '__main__':
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # 禁用tokenizers并行,避免警告
    
    # 创建参数解析器
    parser = argparse.ArgumentParser('Visual-Language-Pretraining (VLP) V2 scripts', parents=[get_args_parser()])
    args = parser.parse_args()  # 解析命令行参数
    
    # 加载YAML配置文件
    with open(args.config, 'r+', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    
    # 初始化wandb运行
    args.run = setup_run(args, config)
    
    # 创建模型保存目录
    Path(config['training']['model_dir']).mkdir(parents=True, exist_ok=True)
    
    # 运行主函数
    main(args, config)

关键总结

  1. 多任务架构:同时训练手语识别(Gloss序列)和手语翻译(德语文本)
  2. Phoenix数据集专用:使用专门的清洗函数处理数据集格式
  3. CTC解码:用于手语识别,处理视频帧与gloss序列的对齐问题
  4. Beam Search:用于手语翻译的文本生成
  5. 分布式训练:支持多GPU训练加速
  6. 完整的实验跟踪:WandB日志记录、模型检查点保存
  7. 评估指标
    • 手语识别:WER(词错误率)
    • 手语翻译:BLEU-4、ROUGE

这是一个工业级的手语识别与翻译系统实现,代码结构清晰,功能完整,支持训练、评估、分布式训练等多种功能。

相关推荐
王干脆2 小时前
面向人机协同的AI Agent设计范式:理论框架与架构实践
人工智能·ai·架构
Eugene__Chen2 小时前
Java的SPI机制(曼波版)
java·开发语言·python
程序猿20232 小时前
JVM与JAVA
java·jvm·python
橘子师兄2 小时前
C++AI大模型接入SDK—deepseek接入封装
c++·人工智能·chatgpt
黄小耶@2 小时前
基于 双向RNN网络 的中文文本预测模型
人工智能·rnn·深度学习
独隅2 小时前
本地大模型训练与 API 服务部署全栈方案:基于 Ubuntu 22.04 LTS 的端到端实现指南
服务器·python·语言模型
gdutxiaoxu2 小时前
browser-use - 让AI Agent真正“会“用浏览器
人工智能·ai agent
Fairy要carry2 小时前
面试-OnlyDecoder用于嵌入模型
人工智能
程序员侠客行2 小时前
Spring集成Mybatis原理详解
java·后端·spring·架构·mybatis