使用 Rectified Flow 和 Diffusion Transformer实现 MNIST 手写数字图像生成

1. 作者介绍

程锡贵,男,西安工程大学电子信息学院,2025级研究生,张宏伟人工智能课题组

研究方向:机器视觉与人工智能

电子邮件:15327178796@163.com

2. Rectified Flow 与 Diffusion Transformer 理论介绍

2.1 实验任务与总体框架

本实验的目标是生成 MNIST 风格的手写数字图像。模型训练完成后,不需要输入某张待识别图片,而是从随机高斯噪声出发,经过多步速度场更新,逐步得到具有手写数字轮廓的灰度图像。

图 1 Rectified Flow + DiT 手写数字图像生成流程

2.2 Rectified Flow 基本原理

Rectified Flow 将图像生成视为一个连续分布运输问题:从标准高斯噪声分布采样初始状态 x₀,从真实 MNIST 数据分布采样目标图像 x₁,并在两者之间构造直线路径。

xₜ = (1 - t)x₀ + tx₁ , t ∈ 0, 1

当 t=0 时,xₜ 等于纯噪声 x₀;当 t=1 时,xₜ 等于真实图像 x₁。由于该路径是线性的,对时间 t 求导后可得到目标速度:

v*(xₜ, t) = dxₜ / dt = x₁ - x₀

训练阶段使用神经网络 vθ(xₜ, t) 拟合上述目标速度,采用均方误差作为损失函数:

L = E || vθ(xₜ, t) - (x₁ - x₀) ||²

2.3 Diffusion Transformer (DiT)

DiT 使用 Transformer 替代常见扩散模型中的 U-Net 作为核心网络。在本实验中,DiT 的作用是接收中间状态图像 xₜ 与连续时间 t,并预测与输入图像尺寸相同的速度图 vθ(xₜ, t)。

图 2 DiT 速度预测网络结构

2.4 DiT 中的关键模块

模块 作用
Patch Embedding 将 1×28×28 的图像切分为 4×4 patch,并映射为 token。默认参数下得到 7×7=49 个 token。
DiTBlock 由多头自注意力、MLP、残差连接与条件调制组成,用于建模不同 patch 之间的关系。
AdaLN Adaptive Layer Normalization。根据时间条件动态生成 scale、shift 与 gate,使网络适应不同生成阶段。
unpatchify 将 Transformer 输出的 patch 重新排列并拼接为 1×28×28 的速度预测图。

数据在 DiT 内部的主要形状变化为:

B, 1, 28, 28B, 256, 7, 7B, 49, 256B, 49, 16B, 1, 28, 28

2.5 Euler ODE 采样

训练完成后,模型从随机噪声开始生成图像。将时间区间 0, 1 划分为若干步,在每一步调用 DiT 预测速度,并使用 Euler 方法更新当前图像:

xₜ₊Δₜ = xₜ + vθ(xₜ, t) · Δt

默认 sample_steps=100,对应 Δt=0.01。采样步数增加时通常更新更细致,但推理时间也会增加。

3. Rectified Flow + DiT 手写数字生成实验

3.1 实验环境与软件包

本项目使用 Python 与 PyTorch 实现。建议优先使用带 NVIDIA GPU 的环境训练;若没有 GPU,程序会自动切换到 CPU,但训练速度较慢。

软件或包 用途 安装说明
Python 运行实验代码 建议使用 Python 3.10 或更高版本
torch 张量计算、模型搭建、反向传播和 GPU 加速 pip install torch
torchvision 下载 MNIST、图像预处理、保存生成图片 pip install torchvision
tqdm 显示训练进度条 pip install tqdm

安装项目依赖:

pip install -r requirements.txt

检查 CUDA 是否可用:

python -c "import torch; print(torch.cuda.is_available())"

3.2 MNIST 数据集介绍

MNIST 是经典的手写数字数据集,图像为 28×28 灰度图,数字类别为 0 至 9。程序使用 torchvision.datasets.MNIST 自动下载训练数据,并通过 Normalize((0.5,), (0.5,)) 将像素范围从 0, 1 转换到 -1, 1

3.3 项目文件结构

rf_dit_mnist/

├── train.py # 训练入口

├── sample.py # 单独采样入口

├── utils.py # 时间嵌入和 AdaLN 调制函数

├── models/

│ ├── init.py

│ └── dit.py # Diffusion Transformer

├── flow/

│ ├── init.py

│ └── rectified_flow.py # Rectified Flow 训练和采样

├── requirements.txt

└── README.md

3.4 训练与测试步骤

步骤 操作说明
步骤 1 解压项目压缩包,并进入 rf_dit_mnist 目录。
步骤 2 安装依赖:pip install -r requirements.txt。
步骤 3 执行训练命令。程序会自动下载 MNIST,并定期保存生成图片和 latest.pt
步骤 4 训练结束后使用 sample.py 加载 latest.pt,单独生成 final_sample.png。

推荐训练命令:

python train.py --epochs 30 --batch_size 128 --sample_steps 100

显存不足时可降低 batch size 和模型规模:

python train.py --epochs 30 --batch_size 64 --dim 128 --depth 4 --heads 4

训练完成后单独采样:

python sample.py --ckpt ./runs_rf_dit_mnist/latest.pt

3.5 训练过程说明

阶段 程序操作 对应代码位置
数据加载 读取 MNIST 图像并归一化到 -1, 1 train.py
构造路径 采样 x₀、x₁ 与 t,计算 xₜ=(1-t)x₀+tx₁ flow/rectified_flow.py
速度预测 DiT 输入 xₜ 与 t,输出预测速度 models/dit.py
损失计算 MSE(pred_v, x₁-x₀) flow/rectified_flow.py
反向传播 更新模型参数并裁剪梯度 flow/rectified_flow.py
生成样例 从噪声出发,利用 Euler 方法积分 flow/rectified_flow.py

3.6 完整实验代码

以下代码与项目文件结构一一对应。为了便于阅读,代码按照文件拆分展示,并加入了必要注释。

文件:requirements.txt

torch torchvision tqdm

文件:models/init.py

导出 DiT 模型,方便其他文件使用 from models import DiT

from .dit import DiT

文件:flow/init.py

导出 Rectified Flow 的训练和采样函数

from .rectified_flow import train_one_epoch, sample

文件:utils.py

python 复制代码
import math
import torch
 
 
def timestep_embedding(t, dim, max_period=10000):
    """将连续时间 t 编码为正弦余弦向量。
 
    参数:
        t: [B],范围为 [0, 1] 的连续时间。
        dim: 时间嵌入的维度。
    返回:
        [B, dim] 的时间特征。
    """
    half = dim // 2
    t = t * 1000.0
 
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(0, half, dtype=torch.float32, device=t.device)
        / half
    )
 
    args = t[:, None] * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
 
    if dim % 2 == 1:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
 
    return emb
 
 
def modulate(x, shift, scale):
    """AdaLN 调制:根据时间条件对标准化后的特征缩放和平移。"""
    return x * (1 + scale[:, None, :]) + shift[:, None, :]

文件:models/dit.py

python 复制代码
import torch
import torch.nn as nn
 
from utils import timestep_embedding, modulate
 
 
class DiTBlock(nn.Module):
    """一个 DiTBlock:自注意力、MLP、残差连接和 AdaLN 条件调制。"""
 
    def __init__(self, dim, heads, mlp_ratio=4.0):
        super().__init__()
 
        # 不使用 LayerNorm 自带仿射参数,由 AdaLN 动态生成 scale 和 shift。
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=heads,
            batch_first=True
        )
 
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
 
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
 
        # 输出 Attention 和 MLP 各自所需的 shift、scale 和 gate。
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 6)
        )
 
    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN(c).chunk(6, dim=1)
 
        # 多头自注意力分支。
        x_norm = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + gate_msa[:, None, :] * attn_out
 
        # MLP 分支。
        x_norm = modulate(self.norm2(x), shift_mlp, scale_mlp)
        mlp_out = self.mlp(x_norm)
        x = x + gate_mlp[:, None, :] * mlp_out
 
        return x
 
 
class DiT(nn.Module):
    """面向 28×28 MNIST 灰度图的轻量级 Diffusion Transformer。"""
 
    def __init__(
        self,
        img_size=28,
        patch_size=4,
        in_channels=1,
        dim=256,
        depth=6,
        heads=8,
        mlp_ratio=4.0
    ):
        super().__init__()
        assert img_size % patch_size == 0
 
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
 
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_dim = in_channels * patch_size * patch_size
 
        # 使用卷积完成 Patch Embedding。
        self.patch_embed = nn.Conv2d(
            in_channels,
            dim,
            kernel_size=patch_size,
            stride=patch_size
        )
 
        # 可学习位置编码:告诉 Transformer 每个 patch 位于图像的什么位置。
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, dim)
        )
 
        # 把连续时间嵌入映射为条件向量 c。
        self.time_mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
 
        self.blocks = nn.ModuleList([
            DiTBlock(dim, heads, mlp_ratio)
            for _ in range(depth)
        ])
 
        self.final_norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, dim * 2)
        )
        self.final_linear = nn.Linear(dim, self.patch_dim)
 
        self.initialize_weights()
 
    def initialize_weights(self):
        """初始化位置编码、线性层和卷积层。"""
        nn.init.normal_(self.pos_embed, std=0.02)
 
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
 
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
 
    def unpatchify(self, x):
        """将 patch 形式的预测结果重新拼接成二维图像。"""
        B = x.shape[0]
        p = self.patch_size
        h = w = self.img_size // p
        c = self.in_channels
 
        x = x.reshape(B, h, w, c, p, p)
        x = torch.einsum("bhwcpq->bchpwq", x)
        x = x.reshape(B, c, h * p, w * p)
 
        return x
 
    def forward(self, x, t):
        """输入中间图像 x_t 与时间 t,输出同尺寸的速度预测图。"""
        # [B, 1, 28, 28] -> [B, dim, 7, 7] -> [B, 49, dim]
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)
 
        x = x + self.pos_embed
 
        t_emb = timestep_embedding(t, x.shape[-1])
        c = self.time_mlp(t_emb)
 
        for block in self.blocks:
            x = block(x, c)
 
        shift, scale = self.final_adaLN(c).chunk(2, dim=1)
        x = modulate(self.final_norm(x), shift, scale)
 
        # 每个 token 输出一个 patch,再拼回速度图。
        x = self.final_linear(x)
        x = self.unpatchify(x)
 
        return x

文件:flow/rectified_flow.py

python 复制代码
import torch
import torch.nn.functional as F
from tqdm import tqdm
 
 
def train_one_epoch(model, loader, optimizer, device):
    """训练一个 epoch:学习从噪声到真实图像的速度场。"""
    model.train()
    total_loss = 0.0
 
    for imgs, _ in tqdm(loader, desc="Training", leave=False):
        imgs = imgs.to(device)
 
        # x0:高斯噪声;x1:真实 MNIST 图像。
        x0 = torch.randn_like(imgs)
        x1 = imgs
 
        B = imgs.shape[0]
 
        # 为 batch 中每个样本随机采样连续时间 t。
        t = torch.rand(B, device=device)
        t_view = t.view(B, 1, 1, 1)
 
        # Rectified Flow 直线路径:x_t = (1-t)x0 + tx1。
        xt = (1 - t_view) * x0 + t_view * x1
 
        # 对直线路径求导,可得目标速度 x1 - x0。
        target_v = x1 - x0
 
        pred_v = model(xt, t)
        loss = F.mse_loss(pred_v, target_v)
 
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
 
        total_loss += loss.item() * B
 
    return total_loss / len(loader.dataset)
 
 
@torch.no_grad()
def sample(model, device, n_samples=64, steps=100):
    """从高斯噪声开始,使用 Euler 方法积分得到生成图像。"""
    model.eval()
 
    x = torch.randn(n_samples, 1, 28, 28, device=device)
    dt = 1.0 / steps
 
    for i in range(steps):
        t = torch.full((n_samples,), i / steps, device=device)
        v = model(x, t)
 
        # Euler ODE 更新:x_(t+dt) = x_t + v(x_t, t) * dt。
        x = x + v * dt
 
    # 将图像从训练时的 [-1, 1] 转回保存图片所需的 [0, 1]。
    x = x.clamp(-1, 1)
    x = (x + 1) / 2
 
    return x

文件:train.py

python 复制代码
import os
import argparse
 
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
 
from models import DiT
from flow import train_one_epoch, sample
 
 
def parse_args():
    """读取训练参数。"""
    parser = argparse.ArgumentParser()
 
    parser.add_argument("--data_dir", type=str, default="./data")
    parser.add_argument("--save_dir", type=str, default="./runs_rf_dit_mnist")
 
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=1e-4)
 
    parser.add_argument("--dim", type=int, default=256)
    parser.add_argument("--depth", type=int, default=6)
    parser.add_argument("--heads", type=int, default=8)
    parser.add_argument("--patch_size", type=int, default=4)
 
    parser.add_argument("--sample_steps", type=int, default=100)
    parser.add_argument("--sample_every", type=int, default=1)
 
    return parser.parse_args()
 
 
def main():
    args = parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
 
    # 将图像像素范围由 [0, 1] 映射到 [-1, 1]。
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
 
    # 若本地没有 MNIST,则自动下载。
    dataset = datasets.MNIST(
        root=args.data_dir,
        train=True,
        transform=transform,
        download=True
    )
 
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True
    )
 
    model = DiT(
        img_size=28,
        patch_size=args.patch_size,
        in_channels=1,
        dim=args.dim,
        depth=args.depth,
        heads=args.heads
    ).to(device)
 
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=1e-4
    )
 
    for epoch in range(1, args.epochs + 1):
        loss = train_one_epoch(model, loader, optimizer, device)
        print(f"Epoch [{epoch}/{args.epochs}] Loss: {loss:.6f}")
 
        # 每隔指定 epoch 保存一次生成样例。
        if epoch % args.sample_every == 0:
            samples = sample(
                model,
                device,
                n_samples=64,
                steps=args.sample_steps
            )
 
            save_path = os.path.join(
                args.save_dir,
                f"sample_epoch_{epoch}.png"
            )
 
            save_image(samples, save_path, nrow=8)
            print(f"Saved samples to {save_path}")
 
        # 保存最新模型和训练参数,便于之后单独采样。
        ckpt_path = os.path.join(args.save_dir, "latest.pt")
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "args": vars(args)
            },
            ckpt_path
        )
 
 
if __name__ == "__main__":
    main()

文件:sample.py

python 复制代码
import os
import argparse
 
import torch
from torchvision.utils import save_image
 
from models import DiT
from flow import sample
 
 
def parse_args():
    """读取单独采样时所需的参数。"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str,
                        default="./runs_rf_dit_mnist/latest.pt")
    parser.add_argument("--save_path", type=str,
                        default="./runs_rf_dit_mnist/final_sample.png")
    parser.add_argument("--n_samples", type=int, default=64)
    parser.add_argument("--sample_steps", type=int, default=100)
    return parser.parse_args()
 
 
def main():
    args = parse_args()
 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
 
    # 读取训练阶段保存的权重和模型配置。
    ckpt = torch.load(args.ckpt, map_location=device)
    model_args = ckpt["args"]
 
    model = DiT(
        img_size=28,
        patch_size=model_args["patch_size"],
        in_channels=1,
        dim=model_args["dim"],
        depth=model_args["depth"],
        heads=model_args["heads"]
    ).to(device)
 
    model.load_state_dict(ckpt["model"])
 
    samples = sample(
        model,
        device,
        n_samples=args.n_samples,
        steps=args.sample_steps
    )
 
    os.makedirs(os.path.dirname(args.save_path), exist_ok=True)
    save_image(samples, args.save_path, nrow=8)
    print(f"Saved samples to {args.save_path}")
 
 
if __name__ == "__main__":
    main()

3.7 测试结果

训练过程中,程序会在 runs_rf_dit_mnist 文件夹中保存 sample_epoch_1.png、sample_epoch_2.png 等图片。训练完成后,sample.py 默认将新生成的 64 张图片保存为 final_sample.png

3.8 可能遇到的问题与解决办法

问题 可能原因 解决办法
MNIST 下载失败 网络连接异常或数据目录权限不足 检查网络;也可手动下载 MNIST,并放入 data/MNIST/raw。
CUDA 不可用 未安装 GPU 版 PyTorch、显卡驱动异常或无 NVIDIA GPU 运行 torch.cuda.is_available() 检查;必要时重新安装匹配环境的 PyTorch。
显存不足 Out of Memory batch_size、dim 或 depth 过大 减小 batch_size;将 dim 改为 128、depth 改为 4。
Windows 多进程报错 DataLoader 的 num_workers=4 在部分环境中不稳定 train.py 中 num_workers 改为 0。
生成图片模糊 训练轮数不足、模型尚未收敛或采样步数过少 适当增加 epochs;检查 Loss;尝试提高 sample_steps。
模型权重找不到 sample.py 指定的 ckpt 路径不正确 确认 runs_rf_dit_mnist/latest.pt 是否存在,并修正 --ckpt 参数。
生成图无法指定类别 当前代码是无条件生成 如需指定数字类别,需要增加 label embedding,改为条件 DiT。

4. 参考链接

1 Xingchao Liu, Chengyue Gong, Qiang Liu. Learning to Generate and Transfer Data with Rectified Flow. arXiv:2209.03003. https://arxiv.org/abs/2209.03003

2 William Peebles, Saining Xie. Scalable Diffusion Models with Transformers. arXiv:2212.09748. https://arxiv.org/abs/2212.09748

3 Torchvision Documentation: MNIST Dataset. https://docs.pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html

4 PyTorch Documentation. https://pytorch.org/docs/stable/index.html

实验总结

本实验使用 Rectified Flow 建立从噪声到 MNIST 图像的连续直线路径,并使用 DiT 预测路径上的速度场。DiT 将图像转换为 patch token,通过多头注意力建模不同区域之间的关系,并利用 AdaLN 注入时间条件。采样阶段从随机噪声开始,使用 Euler 方法逐步积分,最终得到手写数字图像。该实验结构清晰、代码模块化程度较高,适合作为生成模型与 Transformer 图像建模的入门实践。

相关推荐
z202305081 小时前
RDMA之DCQCN (14)
linux·服务器·网络·人工智能·ai
SimpleLearingAI1 小时前
PyTorch & Numpy 实现线性回归详解
人工智能·算法·多模态大模型
Royzst1 小时前
一、IO 概述
开发语言·python
董董灿是个攻城狮1 小时前
AI 会吃了天涯吗?
人工智能
天风之翼1 小时前
AI 模型部署从入门到生产 —— ONNX 转换、TensorRT 加速、推理服务搭建
人工智能
A15362551 小时前
从 AI 零引用到高转化:GEO 落地价值解析
人工智能
Omics Pro1 小时前
P4医学4大支柱需绑定4大数字技术才可落地
人工智能·python·算法·机器学习·plotly
海鸥-w1 小时前
前端学习python第三天笔记整理(list 列表,str字符串,tuple元组,set集合,dect,函数,类型注解)
前端·python·学习