[In-Depth Walkthrough] A Complete DiT-Based MNIST Diffusion Model (DDPM)
Introduction
Diffusion models are one of the hottest research directions in generative AI, and DiT (Diffusion Transformer) is the classic recipe for combining the Transformer architecture with diffusion models. This post walks through a complete DDPM (Denoising Diffusion Probabilistic Models) implementation that uses DiT to generate MNIST handwritten digits, from the core architecture to training and inference, explaining the purpose and design intent of each piece of code.
1. Overall Code Architecture
The code implements a complete DiT-DDPM training and generation pipeline. Its core modules are:
- Environment setup and hyperparameter definitions
- The core DiT network (Transformer + condition embeddings)
- Positional-embedding helper functions
- The DDPM diffusion process
- Data loading and the training loop
- Sampling and visualization
2. Module-by-Module Code Walkthrough
2.1 Imports and Global Configuration
```python
from typing import Tuple, Dict
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
import os
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
import warnings
warnings.filterwarnings("ignore")  # suppress irrelevant warnings
```
- Purpose: imports every required dependency, covering the deep-learning framework (PyTorch), data handling (torchvision), visualization (matplotlib), progress display (tqdm), and more
- Key dependencies:
  - `timm.models.vision_transformer`: reuses mature Transformer components (PatchEmbed/Attention/Mlp) instead of reinventing the wheel
  - `warnings.filterwarnings`: silences irrelevant warnings during training, keeping the logs clean
```python
# ===================== Global configuration (edit as needed) =====================
# Training hyperparameters
TRAIN_EPOCHS = 10                # total training epochs
BATCH_SIZE = 128                 # batch size (raise to e.g. 256 if GPU memory allows)
LEARNING_RATE = 1e-4             # initial learning rate
N_T = 400                        # number of DDPM timesteps
GUIDE_WEIGHTS = [0.0, 0.5, 2.0]  # guidance weights used at sampling time

# Path configuration (the script lives in the dit folder; all paths are based on it)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))   # absolute path of the dit folder
DATA_DIR = os.path.join(BASE_DIR, "mnist_data")         # dataset directory
OUTPUT_DIR = os.path.join(BASE_DIR, "dit_ddpm_output")  # output directory (images/GIFs/models)
MODEL_DIR = os.path.join(OUTPUT_DIR, "models")          # model checkpoint directory
IMAGE_DIR = os.path.join(OUTPUT_DIR, "images")          # generated image directory
GIF_DIR = os.path.join(OUTPUT_DIR, "gifs")              # GIF directory

# Device configuration (auto-detect CUDA)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"📌 Using device: {DEVICE}")
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True  # speed up CUDA kernels for fixed input sizes
    print(f"✅ GPU acceleration enabled, available memory: {torch.cuda.get_device_properties(0).total_memory/1024**3:.1f}GB")
```
- Training hyperparameters: defines the core hyperparameters; N_T=400 DDPM timesteps is a common trade-off between training efficiency and sample quality
- Path configuration: builds cross-platform paths with `os.path`, avoiding the compatibility problems caused by hard-coded paths
- Device configuration:
  - automatically detects a CUDA device and prefers GPU acceleration
  - `cudnn.benchmark = True`: tunes CUDA convolution kernels for fixed input sizes, speeding up training
2.2 Core DiT Layers
2.2.1 Auxiliary Head and the Modulation Function
```python
class SimpleHead(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(SimpleHead, self).__init__()
        self.linear1 = nn.Linear(in_dim, in_dim + out_dim)
        self.linear2 = nn.Linear(in_dim + out_dim, out_dim)
        self.act = nn.SiLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(self.act(x))
        return x


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```
- SimpleHead: a simple two-layer fully connected head used for feature mapping; SiLU is a common activation choice in Transformer architectures
- modulate: the core modulation function implementing the heart of AdaLN (Adaptive Layer Normalization); it adaptively shifts and scales features via the shift and scale parameters, and is the key mechanism by which DiT injects timestep/class conditioning (a minimal shape check follows below)
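To make the broadcasting in `modulate` concrete, here is a minimal, self-contained shape check (the tensor sizes are illustrative, not taken from the script):

```python
import torch

x = torch.randn(8, 49, 384)   # (batch, tokens, hidden) — token features
shift = torch.randn(8, 384)   # one shift vector per sample
scale = torch.randn(8, 384)   # one scale vector per sample

# unsqueeze(1) turns (8, 384) into (8, 1, 384), broadcasting the
# per-sample modulation across all 49 tokens
out = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
print(out.shape)  # torch.Size([8, 49, 384])
```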
2.2.2 Timestep Embedding Module
```python
class TimestepEmbedder(nn.Module):
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
```
- Core function: converts a discrete timestep t into a high-dimensional continuous embedding vector (see the snippet below)
- The timestep_embedding static method:
  - uses sinusoidal (sine/cosine) position encoding, the classic positional-encoding scheme from the Transformer
  - maps timesteps onto sine/cosine values at different frequencies, letting the model learn the continuous relationship between timesteps
- The MLP: applies a nonlinear transform to the frequency embedding, increasing its expressiveness
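A quick illustration of what the frequency embedding produces (a sketch assuming the `TimestepEmbedder` class above is in scope; the sizes are illustrative):

```python
import torch

t = torch.tensor([0, 100, 399])  # three example timesteps
emb = TimestepEmbedder.timestep_embedding(t, dim=256)  # static method, no instance needed
print(emb.shape)  # torch.Size([3, 256]) — 128 cosine channels followed by 128 sine channels
# nearby timesteps yield similar vectors, giving the model a smooth,
# continuous notion of how far along the diffusion process the input is
```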
2.2.3 Label Embedding Module
```python
class LabelEmbedder(nn.Module):
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings, labels
```
- Core function: class-label embedding plus the token-drop mechanism needed for CFG (Classifier-Free Guidance); a small demo follows below
- The token_drop method:
  - during training, replaces class labels with a special token (index num_classes) at a given probability
  - provides the unconditional branch that CFG sampling needs, a key trick for improving sample quality
- Embedding table: allocates one extra row to store the unconditional ("null") class token
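The token-drop behavior is easy to see in isolation (a sketch assuming the `LabelEmbedder` class above; the 0.5 drop probability is exaggerated for demonstration):

```python
import torch

emb = LabelEmbedder(num_classes=10, hidden_size=384, dropout_prob=0.5)
labels = torch.arange(10)
dropped = emb.token_drop(labels)  # randomly replaces roughly half the labels with 10
print(dropped)  # e.g. tensor([ 0, 10,  2, 10,  4,  5, 10,  7, 10,  9])
# index 10 is the extra row of the embedding table — the "null" class token
# that the unconditional branch of classifier-free guidance relies on
```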
2.2.4 The DiT Block
```python
class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True,
                              qk_norm=block_kwargs.get("qk_norm", False))
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim,
                       act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
```
- Core architecture: a Transformer block augmented with AdaLN modulation (a shape check follows below)
- Key design choices:
  - `elementwise_affine=False`: disables LayerNorm's learnable affine parameters, which are replaced by dynamic AdaLN modulation
  - `adaLN_modulation`: maps the condition vector c (timestep + class) to six modulation parameters (shift/scale/gate for the MSA branch and for the MLP branch)
  - gating (gate_msa/gate_mlp): adaptively weights the Attention and MLP outputs
- Residual connections: keep the standard Transformer structure and guarantee healthy gradient flow
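A shape check for a single block (a sketch that assumes the definitions above plus a recent `timm` for `Attention`/`Mlp`):

```python
import torch

block = DiTBlock(hidden_size=384, num_heads=6)
x = torch.randn(2, 49, 384)  # (batch, tokens, hidden)
c = torch.randn(2, 384)      # fused timestep + class condition vector
out = block(x, c)
print(out.shape)  # torch.Size([2, 49, 384]) — the block is shape-preserving
```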
2.2.5 Final Layer and the Full DiT Model
```python
class FinalLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    def __init__(
        self,
        input_size=28,
        patch_size=4,
        in_channels=1,
        hidden_size=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=10,
        learn_sigma=False,
        **block_kwargs
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, **block_kwargs) for _ in range(depth)
        ])
        self.ap_head = SimpleHead(hidden_size, hidden_size)
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # weight initialization scheme that keeps the model stable from the start
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y, context_mask=None, ad=4):
        force_drop_ids = context_mask if context_mask is not None else None
        x = self.x_embedder(x) + self.pos_embed
        t = self.t_embedder(t)
        y_emb, _ = self.y_embedder(y, self.training, force_drop_ids)
        c = t + y_emb
        for i, block in enumerate(self.blocks):
            x = block(x, c)
            if (i + 1) == ad:
                # auxiliary projection after block `ad`; note that `xr` is never
                # used afterwards, so this branch has no effect on the output
                xr = self.ap_head(x) if self.training else x
        x = self.final_layer(x, c)
        x = self.unpatchify(x)
        return x
```
- FinalLayer: DiT's output layer; it maps Transformer features back to image space, again using AdaLN modulation
- The DiT model, end to end (a forward-pass shape check follows below):
  1. PatchEmbed: splits the 28x28 image into 4x4 patches and projects each to hidden_size dimensions
  2. Positional embedding: loads a precomputed 2D sinusoidal embedding (not learnable)
  3. Condition fusion: the timestep embedding and class embedding are summed into the condition vector c (the code computes c = t + y_emb, not a concatenation)
  4. DiT blocks: stacks 12 DiT blocks, each modulated by the condition vector c
  5. Unpatchify: reassembles patch features into a full-size image
- Weight initialization: a careful scheme keeps the model stable at the start; in particular, the modulation layers are zero-initialized so the blocks begin as near-identity mappings, avoiding over-modulation early in training
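Putting it together, a forward-pass shape check (a sketch that assumes every definition in this post is in scope, including the positional-embedding helpers from the next subsection, since `initialize_weights` calls them):

```python
import torch

model = DiT(input_size=28, patch_size=4, in_channels=1, hidden_size=384,
            depth=12, num_heads=6, num_classes=10)
x = torch.randn(4, 1, 28, 28)    # a batch of (noisy) MNIST-sized images
t = torch.randint(1, 401, (4,))  # timesteps in [1, 400]
y = torch.randint(0, 10, (4,))   # class labels
eps = model(x, t, y)
print(eps.shape)  # torch.Size([4, 1, 28, 28]) — predicted noise, same shape as the input
# 28 / 4 = 7 patches per side, so the Transformer runs on 7 * 7 = 49 tokens
```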
2.3 Positional Embedding Helpers
```python
# ===================== Positional embedding helpers =====================
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    emb = np.concatenate([emb_h, emb_w], axis=1)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega
    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb
```
- Core function: generates a 2D sine-cosine positional embedding (usage shown below)
- How it works:
  - build the grid coordinates (grid_h/grid_w)
  - compute a 1D positional embedding for each spatial axis
  - concatenate them into the 2D embedding
- Advantage: a precomputed embedding needs no training, costs no parameters, and works well for image tasks
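For this model's geometry (a 7x7 patch grid, hidden_size 384), the helpers produce one embedding per patch:

```python
pos = get_2d_sincos_pos_embed(embed_dim=384, grid_size=7)
print(pos.shape)  # (49, 384) — one 384-dim vector per patch of the 7x7 grid
# half of the 384 dims encode one grid axis, the other half the other axis
```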
2.4 The DDPM Diffusion Framework
```python
# ===================== DDPM diffusion framework =====================
def ddpm_schedules(beta1, beta2, T):
    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()
    sqrt_beta_t = torch.sqrt(beta_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)
    sqrtab = torch.sqrt(alphabar_t)
    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab = (1 - alpha_t) / sqrtmab
    return {
        "alpha_t": alpha_t, "oneover_sqrta": oneover_sqrta, "sqrt_beta_t": sqrt_beta_t,
        "alphabar_t": alphabar_t, "sqrtab": sqrtab, "sqrtmab": sqrtmab,
        "mab_over_sqrtmab": mab_over_sqrtmab,
    }


class DDPM(nn.Module):
    def __init__(self, nn_model, betas=(1e-4, 0.02), n_T=400, device="cpu", drop_prob=0.1):
        super().__init__()
        self.nn_model = nn_model.to(device)
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v.to(device))  # move the schedule buffers to the target device
        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        _ts = torch.randint(1, self.n_T + 1, (x.shape[0],)).to(self.device)
        noise = torch.randn_like(x)
        x_t = self.sqrtab[_ts, None, None, None] * x + self.sqrtmab[_ts, None, None, None] * noise
        context_mask = torch.bernoulli(torch.zeros_like(c) + self.drop_prob).to(self.device)
        eps_pred = self.nn_model(x_t, _ts, c, context_mask)
        return self.loss_mse(noise, eps_pred)

    @torch.no_grad()  # disable gradients at inference time to save memory
    def sample(self, n_sample, size=(1, 28, 28), guide_w=0.0):
        x_i = torch.randn(n_sample, *size).to(self.device)
        c_i = torch.arange(0, 10).to(self.device).repeat(int(n_sample / 10))
        context_mask = torch.zeros_like(c_i).to(self.device)
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.
        x_i_store = []
        for i in range(self.n_T, 0, -1):
            print(f'Sampling step: {i}/{self.n_T}', end='\r')
            t_is = torch.tensor([i] * n_sample).to(self.device).repeat(2)
            z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0.
            x_i = x_i.repeat(2, 1, 1, 1)
            eps = self.nn_model(x_i, t_is, c_i, context_mask)
            eps_cond, eps_uncond = eps[:n_sample], eps[n_sample:]
            eps = eps_uncond + guide_w * (eps_cond - eps_uncond)
            x_i = self.oneover_sqrta[i] * (x_i[:n_sample] - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
            if i % 20 == 0 or i == self.n_T or i < 8:
                x_i_store.append(x_i.detach().cpu().numpy())
        return x_i, np.array(x_i_store)
```
- ddpm_schedules: precomputes every coefficient DDPM needs (α, β, their cumulative products, etc.) so they are not recomputed at every training step (a sanity check follows below)
- The DDPM class:
  - Training (forward):
    - sample a random timestep _ts for each image
    - forward diffusion: add noise to the clean image x to obtain x_t
    - call the DiT model to predict the noise eps_pred
    - compute the MSE loss between the true and predicted noise
  - Sampling (sample):
    - start from pure noise and denoise step by step through the reverse diffusion chain
    - CFG: computes the conditional (cond) and unconditional (uncond) noise predictions in a single doubled batch and blends them
    - guide_w: the guidance weight; larger values make samples match the class condition more closely, at some cost in diversity
    - intermediate states are stored for the GIF animation
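A small sanity check of the precomputed schedule, independent of training (a sketch assuming `ddpm_schedules` above; the exp-of-cumsum-of-logs trick is just a numerically friendlier cumulative product):

```python
import torch

sched = ddpm_schedules(1e-4, 0.02, 400)
ab = sched["alphabar_t"]
print(ab[0].item(), ab[400].item())  # alpha-bar starts near 1 and decays toward 0

# the forward process in a single jump: x_t = sqrt(ab_t) * x_0 + sqrt(1 - ab_t) * eps
x0 = torch.randn(1, 1, 28, 28)   # stand-in for a clean image
eps = torch.randn_like(x0)
t = 200
x_t = sched["sqrtab"][t] * x0 + sched["sqrtmab"][t] * eps  # matches DDPM.forward
```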
2.5 Utilities and the Training Loop
```python
# ===================== Utility: create directories =====================
def create_dirs():
    """Create every directory the script needs."""
    dirs = [DATA_DIR, MODEL_DIR, IMAGE_DIR, GIF_DIR]
    for dir_path in dirs:
        os.makedirs(dir_path, exist_ok=True)
        print(f"📁 Created directory: {dir_path}")
    print("✅ All directories created")


# ===================== Core training function =====================
def train_dit_mnist():
    # 1. Create directories
    create_dirs()
    # 2. Initialize the DiT model
    print("\n🚀 Initializing DiT model...")
    dit_model = DiT(
        input_size=28,
        patch_size=4,
        in_channels=1,
        hidden_size=384,
        depth=12,
        num_heads=6,
        class_dropout_prob=0.1,
        num_classes=10,
        learn_sigma=False
    ).to(DEVICE)
    # 3. Wrap it in the DDPM framework
    ddpm = DDPM(nn_model=dit_model, n_T=N_T, device=DEVICE, drop_prob=0.1)
    ddpm.train()
    # 4. Load the MNIST dataset (downloads automatically)
    print("\n📥 Loading MNIST dataset (auto-download)...")
    tf = transforms.Compose([transforms.ToTensor()])
    train_dataset = MNIST(DATA_DIR, train=True, download=True, transform=tf)
    # DataLoader tuning: pin_memory=True speeds up host-to-GPU transfers,
    # num_workers=4 uses multiple worker processes
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4 if os.name != "nt" else 0,  # 0 on Windows to avoid worker errors
        pin_memory=True,
        drop_last=True  # drop the last incomplete batch to avoid shape errors
    )
    print(f"✅ Dataset loaded: {len(train_loader)} batches, {len(train_dataset)} samples")
    # 5. Optimizer (with learning-rate decay)
    optimizer = torch.optim.AdamW(ddpm.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TRAIN_EPOCHS)  # cosine annealing
    # 6. Training loop
    print(f"\n============= Starting training for {TRAIN_EPOCHS} epochs =============")
    for epoch in range(TRAIN_EPOCHS):
        print(f"\n===== Epoch: {epoch+1}/{TRAIN_EPOCHS} =====")
        ddpm.train()
        loss_ema = None
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} Loss: ---")
        for x, c in pbar:
            # move the batch to the GPU
            x = x.to(DEVICE, non_blocking=True)  # non_blocking speeds up the copy
            c = c.to(DEVICE, non_blocking=True)
            # forward + backward pass
            optimizer.zero_grad()
            loss = ddpm(x, c)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=1.0)  # clip gradients to prevent explosion
            optimizer.step()
            # exponentially smoothed loss for display
            loss_ema = loss.item() if loss_ema is None else 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"Epoch {epoch+1} Loss: {loss_ema:.4f} (LR: {optimizer.param_groups[0]['lr']:.6f})")
        # learning-rate decay
        scheduler.step()
        # 7. Sample validation images after every epoch
        ddpm.eval()
        with torch.no_grad():
            n_sample = 40
            # grab real samples for side-by-side comparison
            real_x, real_c = next(iter(train_loader))
            real_x = real_x.to(DEVICE)
            real_c = real_c.to(DEVICE)
            for w_idx, w in enumerate(GUIDE_WEIGHTS):
                # generate samples
                x_gen, x_gen_store = ddpm.sample(n_sample, size=(1, 28, 28), guide_w=w)
                # assemble matching real images for comparison
                x_real = torch.Tensor(x_gen.shape).to(DEVICE)
                for k in range(10):
                    for j in range(int(n_sample / 10)):
                        idx_mask = (real_c == k)
                        idx_list = torch.nonzero(idx_mask).squeeze()
                        if len(idx_list.shape) == 0:
                            idx = 0
                        elif j >= len(idx_list):
                            idx = idx_list[0].item() if idx_list.numel() > 0 else 0
                        else:
                            idx = idx_list[j].item()
                        x_real[k + (j * 10)] = real_x[idx]
                # build and save an image grid
                x_all = torch.cat([x_gen, x_real])
                grid = make_grid(x_all * -1 + 1, nrow=10, padding=2)
                img_path = os.path.join(IMAGE_DIR, f"epoch_{epoch+1}_guide_w_{w:.1f}.png")
                save_image(grid, img_path)
                print(f"\n💾 Saved generated images: {img_path}")
                # render a GIF (every 2 epochs, to save time)
                if (epoch + 1) % 2 == 0 or epoch == TRAIN_EPOCHS - 1:
                    fig, axs = plt.subplots(nrows=int(n_sample / 10), ncols=10,
                                            sharex=True, sharey=True, figsize=(8, 3))
                    def animate_diff(i, x_gen_store):
                        for row in range(int(n_sample / 10)):
                            for col in range(10):
                                axs[row, col].clear()
                                axs[row, col].set_xticks([])
                                axs[row, col].set_yticks([])
                                axs[row, col].imshow(-x_gen_store[i, (row * 10) + col, 0], cmap='gray')
                        return axs.flatten()
                    ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store],
                                        interval=200, blit=False, frames=x_gen_store.shape[0])
                    gif_path = os.path.join(GIF_DIR, f"epoch_{epoch+1}_guide_w_{w:.1f}.gif")
                    ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
                    print(f"💾 Saved sampling GIF: {gif_path}")
                    plt.close(fig)  # free figure memory
        # 8. Save model weights
        model_path = os.path.join(MODEL_DIR, f"dit_ddpm_epoch_{epoch+1}.pth")
        torch.save(ddpm.state_dict(), model_path)
        print(f"💾 Saved model weights: {model_path}")
    print(f"\n🎉 Training finished! All results saved to: {OUTPUT_DIR}")


# ===================== Entry point =====================
if __name__ == "__main__":
    train_dit_mnist()
```
- create_dirs: creates all required directories in one go; `exist_ok=True` avoids errors when they already exist
- train_dit_mnist, the core pipeline:
  - Model setup: builds the DiT model and wraps it in the DDPM framework
  - Data loading: a tuned DataLoader (pin_memory/num_workers) improves throughput
  - Optimizer: AdamW plus cosine-annealing learning-rate decay, the standard recipe for training Transformers
  - Training loop:
    - gradient clipping (max_norm=1.0) guards against exploding gradients
    - an EMA-smoothed loss makes the training trend easier to read
  - Validation and visualization:
    - generates images after every epoch and places real samples next to them for comparison
    - renders results at several guidance weights to show the effect of CFG
    - saves a GIF of the reverse-diffusion trajectory for an intuitive view of the denoising process
  - Checkpointing: saves the weights after every epoch for later inference or resumed training (a loading sketch follows below)
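To reuse a saved checkpoint for sampling without retraining, a minimal inference sketch might look like the following (it assumes the definitions above are importable; the epoch-10 filename matches what a 10-epoch run saves last):

```python
dit_model = DiT(input_size=28, patch_size=4, in_channels=1, hidden_size=384,
                depth=12, num_heads=6, num_classes=10).to(DEVICE)
ddpm = DDPM(nn_model=dit_model, n_T=N_T, device=DEVICE)
ckpt = os.path.join(MODEL_DIR, "dit_ddpm_epoch_10.pth")
ddpm.load_state_dict(torch.load(ckpt, map_location=DEVICE))
ddpm.eval()

x_gen, _ = ddpm.sample(40, size=(1, 28, 28), guide_w=2.0)  # 4 rows of digits 0-9
save_image(make_grid(x_gen * -1 + 1, nrow=10), "samples.png")
```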
3. Key Takeaways
- DiT's core idea: weaving AdaLN modulation into the Transformer block for efficient fusion of timestep/class conditions
- CFG (Classifier-Free Guidance): the token-drop mechanism provides an unconditional branch that improves sample quality
- DDPM engineering: precomputed diffusion coefficients, gradient clipping, and cosine annealing keep training stable
- Engineering practice: cross-platform path handling, GPU configuration, and data-loading optimizations make the code robust
4. Running and Tuning Tips
- Environment: PyTorch 2.0+ and CUDA 11.7+ are recommended; install the dependencies with `pip install torch torchvision timm matplotlib tqdm`
- Memory: if GPU memory is tight, reduce BATCH_SIZE (e.g. 64) or hidden_size (e.g. 256)
- Tuning directions (a config sketch follows below):
  - train longer (e.g. 50 epochs) for better sample quality
  - adjust guide_w (1.0-3.0 recommended) to balance sample quality against diversity
  - try a larger patch_size or depth (e.g. 16); note that for 28x28 inputs the patch size must divide 28 (so 7 rather than 8), otherwise unpatchify cannot rebuild the full image
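As one illustration of the scaling direction, a larger variant could be configured as below (these values are not from the original script, just one plausible combination):

```python
# hypothetical larger configuration — expect higher memory use and slower epochs
big_model = DiT(
    input_size=28,
    patch_size=7,    # 28 / 7 = 4, so a 4x4 token grid (8 would not divide 28)
    in_channels=1,
    hidden_size=512,
    depth=16,
    num_heads=8,
    num_classes=10,
).to(DEVICE)
```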
5. Conclusion
This post walked through a complete DiT-based MNIST diffusion model, from core architecture to engineering details, covering the DiT block design, condition embeddings, the DDPM diffusion process, CFG sampling, and other key techniques. Beyond basic generation, the code includes plenty of engineering optimizations and visualization tricks, making it an excellent hands-on case study of combining diffusion models with Transformers. Once you understand this code, you can quickly get productive with the DiT architecture and extend it to other image-generation tasks.