【DriveGen 文件详解】02——train.py

cpp 复制代码
DriveGen/
├── configs/
│   └── default.yaml              # 训练配置文件
├── DriveGen/
│   ├── __init__.py               # 包初始化
│   ├── models/
│   │   ├── __init__.py
│   │   ├── embedding.py          # Patch嵌入、时间步编码、位置编码
│   │   ├── attention.py          # 空间注意力、时间注意力
│   │   ├── dit_block.py          # AdaLN-Zero DiT Block
│   │   └── stdit.py              # STDiT 完整模型
│   ├── data/
│   │   ├── __init__.py
│   │   └── dataset.py            # 合成数据集 + nuScenes 适配器
│   ├── schedules/
│   │   ├── __init__.py
│   │   └── noise_schedule.py     # 线性/余弦噪声调度
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py      # 视频保存、对比图、损失曲线
│       └── logger.py             # 日志工具
├── train.py                      # 训练脚本
├── inference.py                  # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py                   # 评估脚本(FID 计算)
├── requirements.txt              # 依赖清单
├── setup.py                      # 安装配置
└── README.md                     # 本文件

LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值https://github.com/LQY-hh/DriveGen-Transformer-DriveGen 训练脚本技术文档

1. 文件概述

train.pyDriveGen 项目的核心训练入口,实现了一个完整的条件视频扩散模型训练流程。该脚本基于扩散模型(Diffusion Model)原理,从驾驶场景的条件帧出发,学习预测未来帧序列,可应用于自动驾驶仿真、数据增强等场景。

核心功能

模块 功能描述
参数解析 支持命令行参数与 YAML 配置文件混合使用
噪声调度 管理扩散过程的加噪/去噪时序
模型构建 支持 STDiT(时空扩散Transformer)和 SimpleUNet
训练循环 实现完整的扩散训练流程
检查点管理 支持保存/恢复训练进度
推理测试 训练结束后自动生成测试样本

2. 技术架构

2.1 整体架构图

复制代码
┌─────────────────────────────────────────────────────────────────────────┐
│                         train.py 训练流程                               │
├─────────────────────────────────────────────────────────────────────────┤
│  命令行参数 ──┐                                                         │
│              │                                                         │
│  YAML 配置 ──┼──→ 配置合并 ──→ 训练配置                                 │
│              │                                                         │
│              ↓                                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                       初始化阶段                                │    │
│  │  随机种子 → 设备检测 → 日志设置 → 噪声调度器 → 数据加载器        │    │
│  │                          → 模型创建 → 优化器                   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                              ↓                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                       训练循环                                  │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 前向扩散: 随机时间步 → 生成噪声 → 加噪到未来帧            │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 模型预测: noisy_future + t + condition → noise_pred     │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 损失计算: MSE(noise_pred, noise)                        │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 梯度累积 → 梯度裁剪 → 优化器更新                         │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  │                              ↓                                 │   │    │
│  │  ┌─────────────────────────────────────────────────────────┐   │    │
│  │  │ 检查点保存 / 最佳模型保存 / 损失曲线绘制                 │   │    │
│  │  └─────────────────────────────────────────────────────────┘   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                              ↓                                         │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    训练结束处理                                  │    │
│  │  保存最终检查点 → 绘制损失曲线 → 快速推理测试                    │    │
│  └─────────────────────────────────────────────────────────────────┘    │
└─────────────────────────────────────────────────────────────────────────┘

2.2 文件依赖关系

复制代码
train.py
├── schedules/noise_schedule.py     # 噪声调度器
│   └── NoiseScheduler 类
│   └── get_noise_schedule() 函数
├── data/dataset.py                 # 数据集模块
│   └── get_dataloader() 函数
│   └── SyntheticDrivingDataset 类
├── utils/logger.py                 # 日志工具
│   └── setup_logger() 函数
├── utils/visualization.py          # 可视化工具
│   ├── save_video()
│   ├── save_comparison()
│   ├── plot_training_loss()
│   ├── save_frames_as_images()
│   └── create_gif()
└── DriveGen/models/                # 模型模块(可选)
    └── STDiT 类

3. 核心函数详解

3.1 参数解析:parse_args()

文件位置:train.py#L47-L116

功能:解析命令行参数,支持灵活的训练配置。

参数 类型 默认值 说明
--config str configs/default.yaml 配置文件路径
--resume str None 从检查点恢复训练
--epochs int None 训练轮数(覆盖配置)
--lr float None 学习率(覆盖配置)
--batch_size int None 批量大小(覆盖配置)
--seed int 42 随机种子
--device str None 计算设备
--output_dir str None 输出目录

使用示例

复制代码
# 使用默认配置
python train.py

# 指定训练参数
python train.py --epochs 100 --lr 0.0005 --batch_size 8

# 从检查点恢复
python train.py --resume checkpoints/latest.pth

# 使用自定义配置文件
python train.py --config configs/custom.yaml

3.2 配置加载:load_config()

文件位置:train.py#L119-L140

功能:加载 YAML 配置文件并返回配置字典。

复制代码
def load_config(config_path: str) -> Dict[str, Any]:
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")
    
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    
    return config

配置文件结构(示例):

复制代码
model:
  hidden_dim: 192
  depth: 6
  num_heads: 6
  patch_size: 4
  num_frames: 4
  condition_frames: 1

training:
  num_epochs: 50
  learning_rate: 0.0001
  batch_size: 8
  gradient_accumulation_steps: 4

noise:
  num_timesteps: 1000
  schedule: cosine
  beta_start: 0.0001
  beta_end: 0.02

3.3 设备管理:get_device()

文件位置:train.py#L166-L191

功能:自动检测并返回最佳计算设备。

设备优先级

  1. 用户指定设备(--device 参数)

  2. CUDA(NVIDIA GPU)

  3. MPS(Apple Silicon GPU)

  4. CPU(兜底方案)

    def get_device(device_str: Optional[str] = None) -> torch.device:
    if device_str is not None:
    return torch.device(device_str)

    复制代码
     if torch.cuda.is_available():
         device = torch.device('cuda')
         print(f"[get_device] 使用 CUDA: {torch.cuda.get_device_name(0)}")
     elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
         device = torch.device('mps')
         print("[get_device] 使用 Apple MPS")
     else:
         device = torch.device('cpu')
     
     return device

3.4 模型创建:create_model()

文件位置:train.py#L194-L236

功能:创建扩散模型,优先使用 STDiT,失败则回退到 SimpleUNet。

模型类型 输入格式 适用场景
STDiT (B, C, T, H, W) 正式训练,时空Transformer架构
SimpleUNet (B, T, C, H, W) 测试/验证,轻量级CNN

模型配置参数

参数 默认值 说明
hidden_dim 192 隐藏层维度
depth 6 Transformer层数
num_heads 6 注意力头数
patch_size 4 图像分块大小
num_frames 4 总帧数(条件帧+未来帧)
condition_frames 1 条件帧数
max_timestep 1000 最大时间步数
dropout 0.1 Dropout比例

3.5 SimpleUNet 模型实现

文件位置:train.py#L239-L372

架构设计

复制代码
输入: [noisy_frames (B,F,C,H,W) + condition_frames (B,1,C,H,W)]
         ↓
    [展平 + 拼接] → (B, (F+1)*C, H, W)
         ↓
    ┌─────────────────────────────────────────────────────────┐
    │                    编码器                               │
    │  enc1 → pool → enc2 → pool → enc3 → pool → bottleneck │
    └─────────────────────────────────────────────────────────┘
         ↓
    [时间步嵌入注入]
         ↓
    ┌─────────────────────────────────────────────────────────┐
    │                    解码器(带跳跃连接)                  │
    │  up3 → dec3 → up2 → dec2 → up1 → dec1 → output        │
    └─────────────────────────────────────────────────────────┘
         ↓
    输出: noise_pred (B, F, C, H, W)

核心组件

组件 功能
time_mlp 将时间步编码为嵌入向量
enc1/enc2/enc3 编码器卷积块
bottleneck 瓶颈层,融合时间信息
up1/up2/up3 反卷积上采样
dec1/dec2/dec3 解码器卷积块(带跳跃连接)

前向传播

复制代码
def forward(self, noisy_frames, t, condition_frames):
    # 展平帧维度并拼接条件帧
    x = noisy_frames.reshape(B, F * C, H, W)
    cond = condition_frames.reshape(B, self.condition_frames * C, H, W)
    x = torch.cat([x, cond], dim=1)
    
    # 时间步嵌入
    t_emb = self.time_mlp(t.float().unsqueeze(-1) / 1000.0)
    
    # 编码器
    e1 = self.enc1(x)
    e2 = self.enc2(self.pool(e1))
    e3 = self.enc3(self.pool(e2))
    
    # 瓶颈层(注入时间信息)
    b = self.bottleneck(self.pool(e3))
    b = b + self.time_inject3(t_emb).unsqueeze(-1).unsqueeze(-1)
    
    # 解码器(带跳跃连接)
    d3 = self.up3(b)
    d3 = torch.cat([d3, e3], dim=1)  # 跳跃连接
    # ... 后续解码层
    
    # 输出预测噪声
    out = self.output(d1)
    out = out.reshape(B, F, C, H, W)
    
    return out

3.6 检查点管理

3.6.1 保存检查点:save_checkpoint()

文件位置:train.py#L390-L426

检查点包含内容

字段 类型 说明
epoch int 当前训练轮数
step int 全局训练步数
best_loss float 最佳损失值
losses list 损失历史记录
config dict 训练配置
model_state_dict dict 模型权重
optimizer_state_dict dict 优化器状态
3.6.2 加载检查点:load_checkpoint()

文件位置:train.py#L429-L462

功能:从检查点恢复模型和优化器状态,支持断点续训。


3.7 快速推理测试:quick_inference_test()

文件位置:train.py#L466-L590

功能:训练结束后自动执行推理测试,生成样本视频验证模型效果。

采样流程

复制代码
# 1. 创建条件帧(渐变图案测试)
condition = torch.zeros(1, condition_frames, in_channels, image_size, image_size)

# 2. 从纯噪声开始
x = torch.randn(1, future_frames, in_channels, image_size, image_size)

# 3. DDPM 采样循环
for i in range(len(step_indices) - 1, -1, -1):
    t = step_indices[i]
    
    # 预测噪声
    noise_pred = model(x, t, condition)
    
    # DDPM 更新公式
    # x_{t-1} = (1/sqrt(alpha_t)) * (x_t - noise_pred * (1-alpha_t)/sqrt(1-alpha_bar_t))
    #          + sigma_t * z
    x = x_0_pred if i == 0 else torch.sqrt(alpha_t_prev) * x_0_pred + sigma_t * noise

输出格式

格式 保存路径 说明
MP4 视频 outputs/test_samples/sample_X/generated.mp4 生成的视频
PNG 帧 outputs/test_samples/sample_X/frame_XX.png 单独帧图片
GIF 动图 outputs/test_samples/sample_X/generated.gif 动画预览

3.8 主训练函数:train()

文件位置:train.py#L593-L851

训练流程详解

阶段一:初始化
复制代码
# 设置随机种子
torch.manual_seed(seed)
np.random.seed(seed)

# 获取计算设备
device = get_device(args.device)

# 创建噪声调度器
noise_scheduler = NoiseScheduler(
    num_timesteps=noise_config['num_timesteps'],
    schedule_type=noise_config['schedule'],
)

# 创建数据加载器
dataloader = get_dataloader(config)

# 创建模型
model, use_stdit = create_model(config, device)

# 创建优化器
optimizer = AdamW(
    model.parameters(),
    lr=train_config['learning_rate'],
    betas=(0.9, 0.999),
    weight_decay=0.01,
)
阶段二:恢复训练
复制代码
resume_path = args.resume or train_config.get('resume')
if resume_path:
    checkpoint = load_checkpoint(resume_path, model, optimizer)
    start_epoch = checkpoint['epoch'] + 1
    global_step = checkpoint['step']
    best_loss = checkpoint['best_loss']
    losses_history = checkpoint.get('losses', [])
阶段三:训练循环

单次迭代流程

复制代码
for batch in dataloader:
    # 1. 数据预处理
    condition_frames = batch['condition_frames'].to(device)
    future_frames = batch['future_frames'].to(device)
    # 归一化到 [-1, 1]
    condition_frames = condition_frames * 2.0 - 1.0
    future_frames = future_frames * 2.0 - 1.0
    
    # 2. 前向扩散(加噪)
    t = torch.randint(0, max_timesteps, (B,), device=device)
    noise = torch.randn_like(future_frames)
    noisy_future = noise_scheduler.add_noise(future_frames, t, noise)
    
    # 3. 模型预测
    if use_stdit:
        noise_pred = model(noisy_future.permute(0,2,1,3,4), t, condition_frames.permute(0,2,1,3,4))
        noise_pred = noise_pred.permute(0,2,1,3,4)
    else:
        noise_pred = model(noisy_future, t, condition_frames)
    
    # 4. 损失计算
    loss = nn.functional.mse_loss(noise_pred, noise)
    
    # 5. 梯度累积与更新
    loss = loss / grad_accum_steps
    loss.backward()
    
    if (batch_idx + 1) % grad_accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
阶段四:检查点保存
复制代码
# 定期保存检查点
if (epoch + 1) % save_every == 0:
    save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'latest.pth')

# 保存最佳模型
if avg_epoch_loss < best_loss:
    best_loss = avg_epoch_loss
    save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'best.pth')
阶段五:训练结束处理
复制代码
# 保存最终检查点
save_checkpoint(model, optimizer, num_epochs - 1, global_step, best_loss, losses_history, config, 'final.pth')

# 绘制损失曲线
plot_training_loss(losses_history, 'loss_curve_final.png')

# 快速推理测试
quick_inference_test(model, noise_scheduler, config, device, 'outputs', logger)

3.9 主入口:main()

文件位置:train.py#L855-L883

执行流程

  1. 解析命令行参数

  2. 加载 YAML 配置文件

  3. 合并命令行参数与配置(命令行参数优先)

  4. 启动训练

  5. 处理异常(KeyboardInterrupt、其他异常)


4. 扩散模型核心原理

4.1 前向扩散过程

扩散模型通过逐步向数据添加噪声来学习数据分布。对于视频生成任务:

复制代码
x_0 → x_1 → x_2 → ... → x_T

其中 x_0 是原始视频帧,x_T 是纯高斯噪声。

加噪公式

复制代码
x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * epsilon
  • alpha_cumprod_t:时间步 t 的累积噪声系数

  • epsilon:标准高斯噪声

4.2 反向去噪过程

训练时,模型学习预测噪声:

复制代码
model(x_t, t, condition) → epsilon_pred

损失函数

复制代码
loss = MSE(epsilon_pred, epsilon)

4.3 采样过程

推理时,从纯噪声开始逐步去噪:

复制代码
for t in range(T, 0, -1):
    epsilon_pred = model(x_t, t, condition)
    x_{t-1} = (x_t - (1-alpha_t)/sqrt(1-alpha_cumprod_t) * epsilon_pred) / sqrt(alpha_t)
              + sigma_t * noise  # 仅在 t > 0 时添加噪声

5. 关键技术特性

5.1 梯度累积

作用:在显存有限的情况下,模拟更大的批量大小。

复制代码
grad_accum_steps = train_config.get('gradient_accumulation_steps', 1)
loss = loss / grad_accum_steps  # 损失除以累积步数
loss.backward()

if (batch_idx + 1) % grad_accum_steps == 0:
    optimizer.step()
    optimizer.zero_grad()

等效批量大小batch_size * grad_accum_steps

5.2 梯度裁剪

作用:防止梯度爆炸,稳定训练。

复制代码
grad_clip = train_config.get('grad_clip', 1.0)
if grad_clip > 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

5.3 噪声调度策略

支持的调度类型

类型 特点 适用场景
linear 线性增加噪声 简单场景
cosine 余弦调度,前期噪声增长慢 复杂数据分布
sigmoid Sigmoid 调度 平滑过渡

6. 输出文件结构

复制代码
outputs/
├── checkpoints/
│   ├── latest.pth        # 最新检查点
│   ├── best.pth          # 最佳模型
│   └── final.pth         # 最终检查点
├── logs/
│   ├── train.log         # 训练日志
│   ├── loss_curve.png    # 损失曲线
│   └── loss_curve_final.png
└── test_samples/
    ├── sample_0/
    │   ├── generated.mp4
    │   ├── generated.gif
    │   └── frame_00.png, frame_01.png, ...
    └── sample_1/
        ├── generated.mp4
        ├── generated.gif
        └── frame_00.png, frame_01.png, ...

7. 性能优化建议

7.1 训练效率优化

优化项 建议值 说明
批量大小 8-32 根据显存调整
梯度累积 4-8 模拟大批次
混合精度 FP16/FP8 加速训练,减少显存
数据预加载 开启 使用 pin_memory=True

7.2 超参数建议

参数 建议范围 说明
学习率 1e-4 ~ 5e-4 AdamW 优化器
权重衰减 1e-2 ~ 1e-4 防止过拟合
时间步数 1000 标准配置
beta_start 1e-4 初始噪声系数
beta_end 0.02 终止噪声系数

8. 常见问题

8.1 STDiT 导入失败

问题 :运行时提示 ModuleNotFoundError: No module named 'DriveGen.models'

解决方案

  1. 确保 DriveGen/models/ 目录存在且包含 __init__.py 和 STDiT 实现

  2. 脚本会自动回退到 SimpleUNet,可用于测试

8.2 显存不足

解决方案

  1. 减小 batch_size

  2. 增加 gradient_accumulation_steps

  3. 使用更小的模型配置(hidden_dim, depth

  4. 启用混合精度训练

8.3 训练损失不下降

排查步骤

  1. 检查数据加载是否正确(帧顺序、归一化范围)

  2. 确认噪声调度器配置正确

  3. 检查学习率是否合适

  4. 验证模型输出形状是否与标签匹配

相关推荐
哥布林学者1 小时前
深度学习进阶(二十三)偏置型 RPE
机器学习·ai
Frank学习路上1 小时前
【AI技能】跟着费曼学轨迹预测
人工智能·自动驾驶
workflower1 小时前
人工智能全球治理
大数据·人工智能·设计模式·机器人·动态规划
workflower1 小时前
AI灵活高效的智慧用能核心场景
大数据·人工智能·设计模式·机器人·动态规划
长桥夜波1 小时前
【第四十周】VLN
人工智能·计算机视觉
爱摸鱼的打工仔1 小时前
【VLLM启动大模型共享内存不足-AI知识点】
人工智能
初心未改HD1 小时前
深度学习之正则化技术详解
人工智能·深度学习
user29876982706541 小时前
三、Skills 进阶:Fork 模式与上下文控制
人工智能
闵孚龙1 小时前
Claude Code CLAUDE.md 用户指令覆盖层全解析:AI Agent 记忆系统、上下文工程、规则分层、团队协作与安全治理
人工智能·安全