cpp
DriveGen/
├── configs/
│ └── default.yaml # 训练配置文件
├── DriveGen/
│ ├── __init__.py # 包初始化
│ ├── models/
│ │ ├── __init__.py
│ │ ├── embedding.py # Patch嵌入、时间步编码、位置编码
│ │ ├── attention.py # 空间注意力、时间注意力
│ │ ├── dit_block.py # AdaLN-Zero DiT Block
│ │ └── stdit.py # STDiT 完整模型
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset.py # 合成数据集 + nuScenes 适配器
│ ├── schedules/
│ │ ├── __init__.py
│ │ └── noise_schedule.py # 线性/余弦噪声调度
│ └── utils/
│ ├── __init__.py
│ ├── visualization.py # 视频保存、对比图、损失曲线
│ └── logger.py # 日志工具
├── train.py # 训练脚本
├── inference.py # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py # 评估脚本(FID 计算)
├── requirements.txt # 依赖清单
├── setup.py # 安装配置
└── README.md # 本文件
LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值
https://github.com/LQY-hh/DriveGen-Transformer-DriveGen 训练脚本技术文档
1. 文件概述
train.py 是 DriveGen 项目的核心训练入口,实现了一个完整的条件视频扩散模型训练流程。该脚本基于扩散模型(Diffusion Model)原理,从驾驶场景的条件帧出发,学习预测未来帧序列,可应用于自动驾驶仿真、数据增强等场景。
核心功能
| 模块 | 功能描述 |
|---|---|
| 参数解析 | 支持命令行参数与 YAML 配置文件混合使用 |
| 噪声调度 | 管理扩散过程的加噪/去噪时序 |
| 模型构建 | 支持 STDiT(时空扩散Transformer)和 SimpleUNet |
| 训练循环 | 实现完整的扩散训练流程 |
| 检查点管理 | 支持保存/恢复训练进度 |
| 推理测试 | 训练结束后自动生成测试样本 |
2. 技术架构
2.1 整体架构图
┌─────────────────────────────────────────────────────────────────────────┐
│ train.py 训练流程 │
├─────────────────────────────────────────────────────────────────────────┤
│ 命令行参数 ──┐ │
│ │ │
│ YAML 配置 ──┼──→ 配置合并 ──→ 训练配置 │
│ │ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 初始化阶段 │ │
│ │ 随机种子 → 设备检测 → 日志设置 → 噪声调度器 → 数据加载器 │ │
│ │ → 模型创建 → 优化器 │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 训练循环 │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 前向扩散: 随机时间步 → 生成噪声 → 加噪到未来帧 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 模型预测: noisy_future + t + condition → noise_pred │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 损失计算: MSE(noise_pred, noise) │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 梯度累积 → 梯度裁剪 → 优化器更新 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ │ ↓ │ │ │
│ │ ┌─────────────────────────────────────────────────────────┐ │ │
│ │ │ 检查点保存 / 最佳模型保存 / 损失曲线绘制 │ │ │
│ │ └─────────────────────────────────────────────────────────┘ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ ↓ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 训练结束处理 │ │
│ │ 保存最终检查点 → 绘制损失曲线 → 快速推理测试 │ │
│ └─────────────────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────┘
2.2 文件依赖关系
train.py
├── schedules/noise_schedule.py # 噪声调度器
│ └── NoiseScheduler 类
│ └── get_noise_schedule() 函数
├── data/dataset.py # 数据集模块
│ └── get_dataloader() 函数
│ └── SyntheticDrivingDataset 类
├── utils/logger.py # 日志工具
│ └── setup_logger() 函数
├── utils/visualization.py # 可视化工具
│ ├── save_video()
│ ├── save_comparison()
│ ├── plot_training_loss()
│ ├── save_frames_as_images()
│ └── create_gif()
└── DriveGen/models/ # 模型模块(可选)
└── STDiT 类
3. 核心函数详解
3.1 参数解析:parse_args()
文件位置:train.py#L47-L116
功能:解析命令行参数,支持灵活的训练配置。
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
--config |
str | configs/default.yaml |
配置文件路径 |
--resume |
str | None |
从检查点恢复训练 |
--epochs |
int | None |
训练轮数(覆盖配置) |
--lr |
float | None |
学习率(覆盖配置) |
--batch_size |
int | None |
批量大小(覆盖配置) |
--seed |
int | 42 |
随机种子 |
--device |
str | None |
计算设备 |
--output_dir |
str | None |
输出目录 |
使用示例:
# 使用默认配置
python train.py
# 指定训练参数
python train.py --epochs 100 --lr 0.0005 --batch_size 8
# 从检查点恢复
python train.py --resume checkpoints/latest.pth
# 使用自定义配置文件
python train.py --config configs/custom.yaml
3.2 配置加载:load_config()
文件位置:train.py#L119-L140
功能:加载 YAML 配置文件并返回配置字典。
def load_config(config_path: str) -> Dict[str, Any]:
if not os.path.exists(config_path):
raise FileNotFoundError(f"配置文件不存在: {config_path}")
with open(config_path, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
return config
配置文件结构(示例):
model:
hidden_dim: 192
depth: 6
num_heads: 6
patch_size: 4
num_frames: 4
condition_frames: 1
training:
num_epochs: 50
learning_rate: 0.0001
batch_size: 8
gradient_accumulation_steps: 4
noise:
num_timesteps: 1000
schedule: cosine
beta_start: 0.0001
beta_end: 0.02
3.3 设备管理:get_device()
文件位置:train.py#L166-L191
功能:自动检测并返回最佳计算设备。
设备优先级:
-
用户指定设备(
--device参数) -
CUDA(NVIDIA GPU)
-
MPS(Apple Silicon GPU)
-
CPU(兜底方案)
def get_device(device_str: Optional[str] = None) -> torch.device:
if device_str is not None:
return torch.device(device_str)if torch.cuda.is_available(): device = torch.device('cuda') print(f"[get_device] 使用 CUDA: {torch.cuda.get_device_name(0)}") elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): device = torch.device('mps') print("[get_device] 使用 Apple MPS") else: device = torch.device('cpu') return device
3.4 模型创建:create_model()
文件位置:train.py#L194-L236
功能:创建扩散模型,优先使用 STDiT,失败则回退到 SimpleUNet。
| 模型类型 | 输入格式 | 适用场景 |
|---|---|---|
| STDiT | (B, C, T, H, W) |
正式训练,时空Transformer架构 |
| SimpleUNet | (B, T, C, H, W) |
测试/验证,轻量级CNN |
模型配置参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
hidden_dim |
192 | 隐藏层维度 |
depth |
6 | Transformer层数 |
num_heads |
6 | 注意力头数 |
patch_size |
4 | 图像分块大小 |
num_frames |
4 | 总帧数(条件帧+未来帧) |
condition_frames |
1 | 条件帧数 |
max_timestep |
1000 | 最大时间步数 |
dropout |
0.1 | Dropout比例 |
3.5 SimpleUNet 模型实现
文件位置:train.py#L239-L372
架构设计:
输入: [noisy_frames (B,F,C,H,W) + condition_frames (B,1,C,H,W)]
↓
[展平 + 拼接] → (B, (F+1)*C, H, W)
↓
┌─────────────────────────────────────────────────────────┐
│ 编码器 │
│ enc1 → pool → enc2 → pool → enc3 → pool → bottleneck │
└─────────────────────────────────────────────────────────┘
↓
[时间步嵌入注入]
↓
┌─────────────────────────────────────────────────────────┐
│ 解码器(带跳跃连接) │
│ up3 → dec3 → up2 → dec2 → up1 → dec1 → output │
└─────────────────────────────────────────────────────────┘
↓
输出: noise_pred (B, F, C, H, W)
核心组件:
| 组件 | 功能 |
|---|---|
time_mlp |
将时间步编码为嵌入向量 |
enc1/enc2/enc3 |
编码器卷积块 |
bottleneck |
瓶颈层,融合时间信息 |
up1/up2/up3 |
反卷积上采样 |
dec1/dec2/dec3 |
解码器卷积块(带跳跃连接) |
前向传播:
def forward(self, noisy_frames, t, condition_frames):
# 展平帧维度并拼接条件帧
x = noisy_frames.reshape(B, F * C, H, W)
cond = condition_frames.reshape(B, self.condition_frames * C, H, W)
x = torch.cat([x, cond], dim=1)
# 时间步嵌入
t_emb = self.time_mlp(t.float().unsqueeze(-1) / 1000.0)
# 编码器
e1 = self.enc1(x)
e2 = self.enc2(self.pool(e1))
e3 = self.enc3(self.pool(e2))
# 瓶颈层(注入时间信息)
b = self.bottleneck(self.pool(e3))
b = b + self.time_inject3(t_emb).unsqueeze(-1).unsqueeze(-1)
# 解码器(带跳跃连接)
d3 = self.up3(b)
d3 = torch.cat([d3, e3], dim=1) # 跳跃连接
# ... 后续解码层
# 输出预测噪声
out = self.output(d1)
out = out.reshape(B, F, C, H, W)
return out
3.6 检查点管理
3.6.1 保存检查点:save_checkpoint()
文件位置:train.py#L390-L426
检查点包含内容:
| 字段 | 类型 | 说明 |
|---|---|---|
epoch |
int | 当前训练轮数 |
step |
int | 全局训练步数 |
best_loss |
float | 最佳损失值 |
losses |
list | 损失历史记录 |
config |
dict | 训练配置 |
model_state_dict |
dict | 模型权重 |
optimizer_state_dict |
dict | 优化器状态 |
3.6.2 加载检查点:load_checkpoint()
文件位置:train.py#L429-L462
功能:从检查点恢复模型和优化器状态,支持断点续训。
3.7 快速推理测试:quick_inference_test()
文件位置:train.py#L466-L590
功能:训练结束后自动执行推理测试,生成样本视频验证模型效果。
采样流程:
# 1. 创建条件帧(渐变图案测试)
condition = torch.zeros(1, condition_frames, in_channels, image_size, image_size)
# 2. 从纯噪声开始
x = torch.randn(1, future_frames, in_channels, image_size, image_size)
# 3. DDPM 采样循环
for i in range(len(step_indices) - 1, -1, -1):
t = step_indices[i]
# 预测噪声
noise_pred = model(x, t, condition)
# DDPM 更新公式
# x_{t-1} = (1/sqrt(alpha_t)) * (x_t - noise_pred * (1-alpha_t)/sqrt(1-alpha_bar_t))
# + sigma_t * z
x = x_0_pred if i == 0 else torch.sqrt(alpha_t_prev) * x_0_pred + sigma_t * noise
输出格式:
| 格式 | 保存路径 | 说明 |
|---|---|---|
| MP4 视频 | outputs/test_samples/sample_X/generated.mp4 |
生成的视频 |
| PNG 帧 | outputs/test_samples/sample_X/frame_XX.png |
单独帧图片 |
| GIF 动图 | outputs/test_samples/sample_X/generated.gif |
动画预览 |
3.8 主训练函数:train()
文件位置:train.py#L593-L851
训练流程详解:
阶段一:初始化
# 设置随机种子
torch.manual_seed(seed)
np.random.seed(seed)
# 获取计算设备
device = get_device(args.device)
# 创建噪声调度器
noise_scheduler = NoiseScheduler(
num_timesteps=noise_config['num_timesteps'],
schedule_type=noise_config['schedule'],
)
# 创建数据加载器
dataloader = get_dataloader(config)
# 创建模型
model, use_stdit = create_model(config, device)
# 创建优化器
optimizer = AdamW(
model.parameters(),
lr=train_config['learning_rate'],
betas=(0.9, 0.999),
weight_decay=0.01,
)
阶段二:恢复训练
resume_path = args.resume or train_config.get('resume')
if resume_path:
checkpoint = load_checkpoint(resume_path, model, optimizer)
start_epoch = checkpoint['epoch'] + 1
global_step = checkpoint['step']
best_loss = checkpoint['best_loss']
losses_history = checkpoint.get('losses', [])
阶段三:训练循环
单次迭代流程:
for batch in dataloader:
# 1. 数据预处理
condition_frames = batch['condition_frames'].to(device)
future_frames = batch['future_frames'].to(device)
# 归一化到 [-1, 1]
condition_frames = condition_frames * 2.0 - 1.0
future_frames = future_frames * 2.0 - 1.0
# 2. 前向扩散(加噪)
t = torch.randint(0, max_timesteps, (B,), device=device)
noise = torch.randn_like(future_frames)
noisy_future = noise_scheduler.add_noise(future_frames, t, noise)
# 3. 模型预测
if use_stdit:
noise_pred = model(noisy_future.permute(0,2,1,3,4), t, condition_frames.permute(0,2,1,3,4))
noise_pred = noise_pred.permute(0,2,1,3,4)
else:
noise_pred = model(noisy_future, t, condition_frames)
# 4. 损失计算
loss = nn.functional.mse_loss(noise_pred, noise)
# 5. 梯度累积与更新
loss = loss / grad_accum_steps
loss.backward()
if (batch_idx + 1) % grad_accum_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
optimizer.zero_grad()
global_step += 1
阶段四:检查点保存
# 定期保存检查点
if (epoch + 1) % save_every == 0:
save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'latest.pth')
# 保存最佳模型
if avg_epoch_loss < best_loss:
best_loss = avg_epoch_loss
save_checkpoint(model, optimizer, epoch, global_step, best_loss, losses_history, config, 'best.pth')
阶段五:训练结束处理
# 保存最终检查点
save_checkpoint(model, optimizer, num_epochs - 1, global_step, best_loss, losses_history, config, 'final.pth')
# 绘制损失曲线
plot_training_loss(losses_history, 'loss_curve_final.png')
# 快速推理测试
quick_inference_test(model, noise_scheduler, config, device, 'outputs', logger)
3.9 主入口:main()
文件位置:train.py#L855-L883
执行流程:
-
解析命令行参数
-
加载 YAML 配置文件
-
合并命令行参数与配置(命令行参数优先)
-
启动训练
-
处理异常(KeyboardInterrupt、其他异常)
4. 扩散模型核心原理
4.1 前向扩散过程
扩散模型通过逐步向数据添加噪声来学习数据分布。对于视频生成任务:
x_0 → x_1 → x_2 → ... → x_T
其中 x_0 是原始视频帧,x_T 是纯高斯噪声。
加噪公式:
x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * epsilon
-
alpha_cumprod_t:时间步 t 的累积噪声系数 -
epsilon:标准高斯噪声
4.2 反向去噪过程
训练时,模型学习预测噪声:
model(x_t, t, condition) → epsilon_pred
损失函数:
loss = MSE(epsilon_pred, epsilon)
4.3 采样过程
推理时,从纯噪声开始逐步去噪:
for t in range(T, 0, -1):
epsilon_pred = model(x_t, t, condition)
x_{t-1} = (x_t - (1-alpha_t)/sqrt(1-alpha_cumprod_t) * epsilon_pred) / sqrt(alpha_t)
+ sigma_t * noise # 仅在 t > 0 时添加噪声
5. 关键技术特性
5.1 梯度累积
作用:在显存有限的情况下,模拟更大的批量大小。
grad_accum_steps = train_config.get('gradient_accumulation_steps', 1)
loss = loss / grad_accum_steps # 损失除以累积步数
loss.backward()
if (batch_idx + 1) % grad_accum_steps == 0:
optimizer.step()
optimizer.zero_grad()
等效批量大小 :batch_size * grad_accum_steps
5.2 梯度裁剪
作用:防止梯度爆炸,稳定训练。
grad_clip = train_config.get('grad_clip', 1.0)
if grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
5.3 噪声调度策略
支持的调度类型:
| 类型 | 特点 | 适用场景 |
|---|---|---|
linear |
线性增加噪声 | 简单场景 |
cosine |
余弦调度,前期噪声增长慢 | 复杂数据分布 |
sigmoid |
Sigmoid 调度 | 平滑过渡 |
6. 输出文件结构
outputs/
├── checkpoints/
│ ├── latest.pth # 最新检查点
│ ├── best.pth # 最佳模型
│ └── final.pth # 最终检查点
├── logs/
│ ├── train.log # 训练日志
│ ├── loss_curve.png # 损失曲线
│ └── loss_curve_final.png
└── test_samples/
├── sample_0/
│ ├── generated.mp4
│ ├── generated.gif
│ └── frame_00.png, frame_01.png, ...
└── sample_1/
├── generated.mp4
├── generated.gif
└── frame_00.png, frame_01.png, ...
7. 性能优化建议
7.1 训练效率优化
| 优化项 | 建议值 | 说明 |
|---|---|---|
| 批量大小 | 8-32 | 根据显存调整 |
| 梯度累积 | 4-8 | 模拟大批次 |
| 混合精度 | FP16/FP8 | 加速训练,减少显存 |
| 数据预加载 | 开启 | 使用 pin_memory=True |
7.2 超参数建议
| 参数 | 建议范围 | 说明 |
|---|---|---|
| 学习率 | 1e-4 ~ 5e-4 | AdamW 优化器 |
| 权重衰减 | 1e-2 ~ 1e-4 | 防止过拟合 |
| 时间步数 | 1000 | 标准配置 |
| beta_start | 1e-4 | 初始噪声系数 |
| beta_end | 0.02 | 终止噪声系数 |
8. 常见问题
8.1 STDiT 导入失败
问题 :运行时提示 ModuleNotFoundError: No module named 'DriveGen.models'
解决方案:
-
确保
DriveGen/models/目录存在且包含__init__.py和 STDiT 实现 -
脚本会自动回退到 SimpleUNet,可用于测试
8.2 显存不足
解决方案:
-
减小
batch_size -
增加
gradient_accumulation_steps -
使用更小的模型配置(
hidden_dim,depth) -
启用混合精度训练
8.3 训练损失不下降
排查步骤:
-
检查数据加载是否正确(帧顺序、归一化范围)
-
确认噪声调度器配置正确
-
检查学习率是否合适
-
验证模型输出形状是否与标签匹配