bash
DriveGen/
├── configs/
│ └── default.yaml # 训练配置文件
├── DriveGen/
│ ├── __init__.py # 包初始化
│ ├── models/
│ │ ├── __init__.py
│ │ ├── embedding.py # Patch嵌入、时间步编码、位置编码
│ │ ├── attention.py # 空间注意力、时间注意力
│ │ ├── dit_block.py # AdaLN-Zero DiT Block
│ │ └── stdit.py # STDiT 完整模型
│ ├── data/
│ │ ├── __init__.py
│ │ └── dataset.py # 合成数据集 + nuScenes 适配器
│ ├── schedules/
│ │ ├── __init__.py
│ │ └── noise_schedule.py # 线性/余弦噪声调度
│ └── utils/
│ ├── __init__.py
│ ├── visualization.py # 视频保存、对比图、损失曲线
│ └── logger.py # 日志工具
├── train.py # 训练脚本
├── inference.py # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py # 评估脚本(FID 计算)
├── requirements.txt # 依赖清单
├── setup.py # 安装配置
└── README.md # 本文件
DriveGen 评估脚本说明文档
概述
evaluate.py 是 DriveGen 项目的视频生成质量评估脚本 ,主要用于计算生成视频与真实视频之间的 FID(Frechet Inception Distance)分数,以量化评估生成模型的性能。
核心功能
| 功能模块 | 说明 |
|---|---|
| FID 计算 | 衡量生成分布与真实分布的距离,值越小表示生成质量越好 |
| 特征提取 | 使用简化的 CNN 网络从视频帧中提取特征向量 |
| 指标评估 | 自动生成评估报告,保存结果到文件 |
评估原理
FID(Frechet Inception Distance)
FID 是衡量生成模型质量的标准指标,其核心思想是:
-
特征提取:从真实图像和生成图像中提取高维特征
-
分布建模:假设特征服从多元高斯分布,计算均值和协方差
-
距离计算:计算两个高斯分布之间的 Frechet 距离
数学公式:
d² = ||μ₁ - μ₂||² + Tr(σ₁ + σ₂ - 2√(σ₁σ₂))
其中:
-
μ₁, μ₂:两个分布的均值向量
-
σ₁, σ₂:两个分布的协方差矩阵
-
Tr:矩阵的迹
代码结构
1. 特征提取器
SimpleFeatureExtractor (evaluate.py#L126-L210):
一个轻量级 CNN 网络,用于从图像帧中提取 256 维特征向量。
# 网络结构
Conv2d(3, 32) → ReLU → MaxPool → Conv2d(32, 64) → ReLU → MaxPool → Conv2d(64, 128) → AdaptiveAvgPool → Linear(256)
设计说明:
-
输入:(B, 3, H, W),值域 0, 1
-
输出:(B, 256)
-
使用自适应平均池化处理不同尺寸输入
2. FID 计算函数
compute_frechet_distance (evaluate.py#L247-L320):
计算两个多元高斯分布之间的 Frechet 距离。
关键处理:
-
使用
scipy.linalg.sqrtm计算矩阵平方根 -
处理数值不稳定情况(复数结果)
-
确保结果非负
compute_fid (evaluate.py#L346-L378):
封装完整的 FID 计算流程:
-
计算真实特征的均值和协方差
-
计算生成特征的均值和协方差
-
调用
compute_frechet_distance计算距离
3. 特征提取流程
extract_features_from_dataset (evaluate.py#L381-L428):
从数据集中提取真实视频帧的特征。
extract_features_from_generated (evaluate.py#L431-L503):
使用模型生成视频并提取特征。
4. 主函数
main (evaluate.py#L506-L648):
执行完整的评估流程:
加载配置 → 创建组件 → 提取真实特征 → 生成并提取特征 → 计算 FID → 保存结果
使用方法
命令行参数
| 参数 | 简写 | 类型 | 默认值 | 说明 |
|---|---|---|---|---|
--checkpoint |
-c |
str | 必须 | 模型检查点路径 |
--config |
- | str | configs/default.yaml |
配置文件路径 |
--num_samples |
-n |
int | 配置文件值 | 评估样本数 |
--output_dir |
-o |
str | eval_results/ |
输出目录 |
--device |
- | str | 自动检测 | 计算设备 |
--seed |
- | int | 42 | 随机种子 |
--batch_size |
- | int | 8 | 特征提取批量大小 |
使用示例
# 基本用法
python evaluate.py --checkpoint checkpoints/best.pth
# 指定样本数
python evaluate.py --checkpoint checkpoints/best.pth --num_samples 200
# 指定输出目录和设备
python evaluate.py --checkpoint checkpoints/best.pth --output_dir eval_results/ --device cuda
输出结果
评估完成后,会在输出目录生成 evaluation_results.txt 文件:
DriveGen 评估结果
========================================
FID 分数: 12.3456
评估样本数: 100
真实帧数: 400
生成帧数: 400
特征维度: 256
检查点: checkpoints/best.pth
说明:
FID (Frechet Inception Distance) 衡量生成分布与真实分布的距离。
值越小越好,0 表示完美匹配。
注意: 此评估使用简化的特征提取器,结果仅供参考。
实际应用中建议使用 InceptionV3 获取更准确的 FID。
关键设计特点
1. 简化特征提取
使用自定义 CNN 而非预训练的 InceptionV3,便于学习和部署,同时保持 FID 计算流程的完整性。
2. 数值稳定性
在 compute_frechet_distance 中:
-
处理协方差矩阵平方根可能出现的复数问题
-
使用 SVD 分解作为备选方法
-
确保最终结果非负
3. 批处理优化
特征提取采用批处理方式,支持大样本评估,提高计算效率。
4. 结果可追溯
自动保存评估参数和结果,便于实验复现和对比分析。
扩展建议
使用 InceptionV3(推荐)
为获得更准确的 FID 分数,建议使用预训练的 InceptionV3:
from torchvision.models import inception_v3
class InceptionFeatureExtractor(nn.Module):
def __init__(self):
super().__init__()
self.model = inception_v3(pretrained=True, transform_input=False)
self.model.fc = nn.Identity() # 移除分类层
def forward(self, x):
return self.model(x)
增加更多评估指标
可以扩展支持以下指标:
-
IS(Inception Score):衡量生成样本的多样性和质量
-
LPIPS:感知相似度指标
-
SSIM/PSNR:像素级相似度指标
注意事项
-
特征提取器差异:本脚本使用简化的 CNN,与标准 InceptionV3 的 FID 结果不可直接比较
-
样本数量:FID 计算需要足够的样本数才能稳定,建议至少 100 个样本
-
计算资源:大规模评估可能需要较长时间和较多显存
-
结果解读:FID 只是评估指标之一,还需结合主观视觉评估