【DriveGen 文件详解】04——evaluate.py

bash 复制代码
DriveGen/
├── configs/
│   └── default.yaml              # 训练配置文件
├── DriveGen/
│   ├── __init__.py               # 包初始化
│   ├── models/
│   │   ├── __init__.py
│   │   ├── embedding.py          # Patch嵌入、时间步编码、位置编码
│   │   ├── attention.py          # 空间注意力、时间注意力
│   │   ├── dit_block.py          # AdaLN-Zero DiT Block
│   │   └── stdit.py              # STDiT 完整模型
│   ├── data/
│   │   ├── __init__.py
│   │   └── dataset.py            # 合成数据集 + nuScenes 适配器
│   ├── schedules/
│   │   ├── __init__.py
│   │   └── noise_schedule.py     # 线性/余弦噪声调度
│   └── utils/
│       ├── __init__.py
│       ├── visualization.py      # 视频保存、对比图、损失曲线
│       └── logger.py             # 日志工具
├── train.py                      # 训练脚本
├── inference.py                  # 推理脚本(DDPM 采样 + CFG)
├── evaluate.py                   # 评估脚本(FID 计算)
├── requirements.txt              # 依赖清单
├── setup.py                      # 安装配置
└── README.md                     # 本文件

LQY-hh/DriveGen-Transformer-: 自动驾驶技术的发展离不开海量数据的支撑,但稀有场景(如极端天气、突发事故)的数据采集成本极高。**DriveGen** 旨在通过扩散模型生成高质量的驾驶场景视频,为自动驾驶算法提供无限的虚拟训练数据。 ### 核心价值https://github.com/LQY-hh/DriveGen-Transformer-

DriveGen 评估脚本说明文档

概述

evaluate.py 是 DriveGen 项目的视频生成质量评估脚本 ,主要用于计算生成视频与真实视频之间的 FID(Frechet Inception Distance)分数,以量化评估生成模型的性能。


核心功能

功能模块 说明
FID 计算 衡量生成分布与真实分布的距离,值越小表示生成质量越好
特征提取 使用简化的 CNN 网络从视频帧中提取特征向量
指标评估 自动生成评估报告,保存结果到文件

评估原理

FID(Frechet Inception Distance)

FID 是衡量生成模型质量的标准指标,其核心思想是:

  1. 特征提取:从真实图像和生成图像中提取高维特征

  2. 分布建模:假设特征服从多元高斯分布,计算均值和协方差

  3. 距离计算:计算两个高斯分布之间的 Frechet 距离

数学公式

复制代码
d² = ||μ₁ - μ₂||² + Tr(σ₁ + σ₂ - 2√(σ₁σ₂))

其中:

  • μ₁, μ₂:两个分布的均值向量

  • σ₁, σ₂:两个分布的协方差矩阵

  • Tr:矩阵的迹


代码结构

1. 特征提取器

SimpleFeatureExtractor (evaluate.py#L126-L210):

一个轻量级 CNN 网络,用于从图像帧中提取 256 维特征向量。

复制代码
# 网络结构
Conv2d(3, 32) → ReLU → MaxPool → Conv2d(32, 64) → ReLU → MaxPool → Conv2d(64, 128) → AdaptiveAvgPool → Linear(256)

设计说明

  • 输入:(B, 3, H, W),值域 0, 1

  • 输出:(B, 256)

  • 使用自适应平均池化处理不同尺寸输入

2. FID 计算函数

compute_frechet_distance (evaluate.py#L247-L320):

计算两个多元高斯分布之间的 Frechet 距离。

关键处理

  • 使用 scipy.linalg.sqrtm 计算矩阵平方根

  • 处理数值不稳定情况(复数结果)

  • 确保结果非负

compute_fid (evaluate.py#L346-L378):

封装完整的 FID 计算流程:

  1. 计算真实特征的均值和协方差

  2. 计算生成特征的均值和协方差

  3. 调用 compute_frechet_distance 计算距离

3. 特征提取流程

extract_features_from_dataset (evaluate.py#L381-L428):

从数据集中提取真实视频帧的特征。

extract_features_from_generated (evaluate.py#L431-L503):

使用模型生成视频并提取特征。

4. 主函数

main (evaluate.py#L506-L648):

执行完整的评估流程:

复制代码
加载配置 → 创建组件 → 提取真实特征 → 生成并提取特征 → 计算 FID → 保存结果

使用方法

命令行参数

参数 简写 类型 默认值 说明
--checkpoint -c str 必须 模型检查点路径
--config - str configs/default.yaml 配置文件路径
--num_samples -n int 配置文件值 评估样本数
--output_dir -o str eval_results/ 输出目录
--device - str 自动检测 计算设备
--seed - int 42 随机种子
--batch_size - int 8 特征提取批量大小

使用示例

复制代码
# 基本用法
python evaluate.py --checkpoint checkpoints/best.pth

# 指定样本数
python evaluate.py --checkpoint checkpoints/best.pth --num_samples 200

# 指定输出目录和设备
python evaluate.py --checkpoint checkpoints/best.pth --output_dir eval_results/ --device cuda

输出结果

评估完成后,会在输出目录生成 evaluation_results.txt 文件:

复制代码
DriveGen 评估结果
========================================

FID 分数: 12.3456
评估样本数: 100
真实帧数: 400
生成帧数: 400
特征维度: 256
检查点: checkpoints/best.pth

说明:
  FID (Frechet Inception Distance) 衡量生成分布与真实分布的距离。
  值越小越好,0 表示完美匹配。
  注意: 此评估使用简化的特征提取器,结果仅供参考。
  实际应用中建议使用 InceptionV3 获取更准确的 FID。

关键设计特点

1. 简化特征提取

使用自定义 CNN 而非预训练的 InceptionV3,便于学习和部署,同时保持 FID 计算流程的完整性。

2. 数值稳定性

compute_frechet_distance 中:

  • 处理协方差矩阵平方根可能出现的复数问题

  • 使用 SVD 分解作为备选方法

  • 确保最终结果非负

3. 批处理优化

特征提取采用批处理方式,支持大样本评估,提高计算效率。

4. 结果可追溯

自动保存评估参数和结果,便于实验复现和对比分析。


扩展建议

使用 InceptionV3(推荐)

为获得更准确的 FID 分数,建议使用预训练的 InceptionV3:

复制代码
from torchvision.models import inception_v3

class InceptionFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = inception_v3(pretrained=True, transform_input=False)
        self.model.fc = nn.Identity()  # 移除分类层
    
    def forward(self, x):
        return self.model(x)

增加更多评估指标

可以扩展支持以下指标:

  • IS(Inception Score):衡量生成样本的多样性和质量

  • LPIPS:感知相似度指标

  • SSIM/PSNR:像素级相似度指标


注意事项

  1. 特征提取器差异:本脚本使用简化的 CNN,与标准 InceptionV3 的 FID 结果不可直接比较

  2. 样本数量:FID 计算需要足够的样本数才能稳定,建议至少 100 个样本

  3. 计算资源:大规模评估可能需要较长时间和较多显存

  4. 结果解读:FID 只是评估指标之一,还需结合主观视觉评估

相关推荐
老高学长1 小时前
金融机构文档加密软件哪个好|合规与安全兼顾|2026新测评
网络·人工智能·安全
闻道参看1 小时前
生成式智能搜索下的流量卡位攻略:初创个体如何甄选高兼容性的 GEO 优化 服务商
人工智能
Herlie1 小时前
6款可编辑AI海报工具深度横测(2026)
大数据·人工智能
轻刀快马1 小时前
跨越“拟人”的最后一道天堑:大模型强化学习(RLHF/RLAIF)底层原理解析
人工智能·深度学习·机器学习
hsg771 小时前
简述:小数据集照片分类的模型训练
人工智能·分类·数据挖掘
清 晨1 小时前
YouTube自动AI标签上线后跨境内容团队如何调整素材审核流程
大数据·人工智能·新媒体运营·内容营销·跨境
qq_283720051 小时前
2026 最新 Python+AI 零基础入门全教程 :从零搭建人工智能完整项目
开发语言·人工智能·python
拓朗工控1 小时前
具身智能的“小空间大算力”难题:边缘AI主机如何落地机器人
人工智能·机器人