基于DiT+DDPM的MNIST数字生成:模型推理实战教程
一、前言
扩散模型(Diffusion Model)在图像生成领域展现出卓越的性能,本文基于DiT(Diffusion Transformer)+DDPM(Denoising Diffusion Probabilistic Models)实现MNIST手写数字的指定生成,手把手教大家如何加载预训练的扩散模型权重,实现指定单个数字(0-9)的图像生成推理,并保存生成结果。
本文代码可直接运行,核心实现:
- 加载预训练的DiT+DDPM模型权重
- 指定任意数字(0-9)生成对应手写体图像
- 自定义生成样本数量、引导权重等参数
- 保存生成结果和采样过程GIF
二、环境准备
1. 依赖安装
确保安装以下核心依赖包:
bash
pip install torch torchvision timm matplotlib numpy tqdm pillow
2. 预训练模型准备
提前准备好训练完成的DiT+DDPM模型权重文件(.pth格式),本文以model_2.pth为例,路径为./data/diffusion_dit_mnist/model_2.pth。
三、核心代码解析
1. 模型结构复用
首先完整保留DiT扩散模型的核心结构定义(包括DiTBlock、TimestepEmbedder、LabelEmbedder等),这部分代码与训练阶段完全一致,确保模型结构和权重匹配:
python
# 以下为模型核心定义(完整代码见文末)
class SimpleHead(nn.Module):... # 简单头层
class TimestepEmbedder(nn.Module):... # 时间步嵌入
class LabelEmbedder(nn.Module):... # 类别标签嵌入
class DiTBlock(nn.Module):... # DiT核心块
class FinalLayer(nn.Module):... # 最终输出层
class DiT(nn.Module):... # 完整DiT模型
2. DDPM采样逻辑改造(核心)
原DDPM的sample方法默认生成0-9全数字,我们新增sample_specific_digit方法,实现指定数字生成:
python
class DDPM(nn.Module):
def __init__(self, nn_model, betas=(1e-4, 0.02), n_T=400, device="cpu", drop_prob=0.1):
super().__init__()
self.nn_model = nn_model.to(device)
# 注册扩散过程的关键参数(beta/alpha等)
for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
self.register_buffer(k, v)
self.n_T = n_T
self.device = device
self.drop_prob = drop_prob
self.loss_mse = nn.MSELoss()
# 新增:指定数字采样方法
def sample_specific_digit(self, target_digit, n_sample=4, size=(1,28,28), guide_w=2.0):
"""
生成指定数字的图像
:param target_digit: 要生成的数字(0-9)
:param n_sample: 生成该数字的样本数量
:param size: 图像尺寸 (1,28,28)
:param guide_w: CFG引导权重(越大生成越精准)
:return: 生成的图像张量,采样过程存储
"""
# 1. 初始化噪声(从标准正态分布采样)
x_i = torch.randn(n_sample, *size).to(self.device)
# 2. 关键:类别标签仅包含指定数字(而非0-9全类别)
c_i = torch.tensor([target_digit]*n_sample).to(self.device)
context_mask = torch.zeros_like(c_i).to(self.device)
# 3. CFG(Classifier-Free Guidance)策略:提升生成精准度
c_i = c_i.repeat(2)
context_mask = context_mask.repeat(2)
context_mask[n_sample:] = 1. # 后半部分为无条件采样
x_i_store = [] # 保存采样过程,用于生成GIF
for i in range(self.n_T, 0, -1):
print(f'采样进度:{i}/{self.n_T} (生成数字{target_digit})', end='\r')
t_is = torch.tensor([i]*n_sample).to(self.device).repeat(2)
z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0.
# 复制张量用于CFG计算
x_i = x_i.repeat(2,1,1,1)
eps = self.nn_model(x_i, t_is, c_i, context_mask)
# CFG核心计算:有条件预测 - 无条件预测
eps_cond, eps_uncond = eps[:n_sample], eps[n_sample:]
eps = eps_uncond + guide_w * (eps_cond - eps_uncond)
# 反向扩散步骤:去噪
x_i = self.oneover_sqrta[i] * (x_i[:n_sample] - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
# 保存采样过程(可选)
if i%20==0 or i==self.n_T or i<8:
x_i_store.append(x_i.detach().cpu().numpy())
print() # 换行
return x_i, np.array(x_i_store)
3. 封装推理函数(一键调用)
将模型加载、采样、结果保存封装为generate_specific_digit函数,降低使用门槛:
python
def generate_specific_digit(
pretrained_pth, # 预训练模型路径
target_digit, # 要生成的数字(0-9)
n_sample=4, # 生成该数字的样本数
guide_w=2.0, # CFG引导权重(推荐2.0)
save_dir='./generated_digits/',
device="cpu"
):
"""
生成指定数字的MNIST图像
:param pretrained_pth: 预训练.pth文件路径
:param target_digit: 目标数字(0-9)
:param n_sample: 生成样本数量
:param guide_w: CFG引导权重
:param save_dir: 保存目录
:param device: 运行设备(cpu/cuda)
"""
# 1. 参数校验
if not 0 <= target_digit <=9:
raise ValueError("target_digit必须是0-9之间的整数!")
if not os.path.exists(pretrained_pth):
raise FileNotFoundError(f"预训练文件不存在:{pretrained_pth}")
# 2. 创建保存目录
os.makedirs(save_dir, exist_ok=True)
# 3. 初始化DiT模型(与训练时参数完全一致)
dit_model = DiT(
input_size=28,
patch_size=4,
in_channels=1,
hidden_size=384,
depth=12,
num_heads=6,
class_dropout_prob=0.1,
num_classes=10,
learn_sigma=False
)
# 4. 加载预训练权重
ddpm = DDPM(nn_model=dit_model, n_T=400, device=device, drop_prob=0.1)
checkpoint = torch.load(pretrained_pth, map_location=device, weights_only=True)
ddpm.load_state_dict(checkpoint)
ddpm.to(device)
ddpm.eval() # 必须设置为评估模式(禁用Dropout等训练层)
print(f"✅ 加载模型完成,开始生成数字{target_digit}...")
# 5. 生成指定数字(禁用梯度计算,提升速度)
with torch.no_grad():
x_gen, x_gen_store = ddpm.sample_specific_digit(
target_digit=target_digit,
n_sample=n_sample,
guide_w=guide_w
)
# 6. 保存生成结果
# 调整图像范围(MNIST为黑底白字,反转后更易查看)
grid = make_grid(x_gen * -1 + 1, nrow=int(math.sqrt(n_sample)))
save_path = os.path.join(save_dir, f"digit_{target_digit}_samples_{n_sample}_w{guide_w}.png")
save_image(grid, save_path)
print(f"💾 生成结果已保存:{save_path}")
# 7. 生成采样过程GIF(可选,直观展示扩散去噪过程)
fig, axs = plt.subplots(nrows=1, ncols=n_sample, figsize=(n_sample*2, 2))
if n_sample == 1:
axs = [axs] # 处理单样本情况
def animate(i):
for idx in range(n_sample):
axs[idx].clear()
axs[idx].set_xticks([])
axs[idx].set_yticks([])
axs[idx].imshow(-x_gen_store[i, idx, 0], cmap='gray')
return axs
ani = FuncAnimation(fig, animate, frames=len(x_gen_store), interval=200, blit=False)
gif_path = os.path.join(save_dir, f"digit_{target_digit}_process_w{guide_w}.gif")
ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
plt.close()
print(f"💾 采样过程GIF已保存:{gif_path}")
return x_gen
四、推理实战
1. 基础调用(生成单个数字)
修改if __name__ == "__main__"中的参数,一键生成指定数字:
python
if __name__ == "__main__":
# 配置参数
PRETRAINED_PTH = "./data/diffusion_dit_mnist/model_2.pth" # 你的模型路径
TARGET_DIGIT = 5 # 要生成的数字(0-9任选)
N_SAMPLE = 4 # 生成4个该数字的样本
GUIDE_W = 2.0 # CFG引导权重(推荐1.0-3.0)
DEVICE = "cpu" # 有GPU改为"cuda:0"
# 调用生成函数
generate_specific_digit(
pretrained_pth=PRETRAINED_PTH,
target_digit=TARGET_DIGIT,
n_sample=N_SAMPLE,
guide_w=GUIDE_W,
device=DEVICE
)
2. 多数字批量生成(扩展)
循环调用函数,批量生成0-9所有数字:
python
if __name__ == "__main__":
PRETRAINED_PTH = "./data/diffusion_dit_mnist/model_2.pth"
DEVICE = "cpu"
# 循环生成0-9所有数字
for digit in range(10):
generate_specific_digit(
pretrained_pth=PRETRAINED_PTH,
target_digit=digit,
n_sample=4,
guide_w=2.0,
device=DEVICE
)
五、关键参数调优
| 参数 | 作用 | 推荐值 |
|---|---|---|
target_digit |
指定生成的数字 | 0-9 |
n_sample |
生成该数字的样本数量 | 1-16(根据显存调整) |
guide_w(CFG权重) |
生成精准度控制 | 1.0-3.0(越大越精准,过小易生成错误数字) |
device |
运行设备 | 有GPU用"cuda:0"(速度提升10倍+) |
六、结果展示
1. 生成的数字图像
生成的图像默认保存至./generated_digits/目录,文件名格式:digit_5_samples_4_w2.0.png,示例效果:
- 5个样本的数字5图像拼接成2x2网格
- 黑底白字(MNIST标准格式),边缘清晰,数字特征明显
2. 采样过程GIF
自动生成扩散去噪过程的GIF(digit_5_process_w2.0.gif),可直观看到从噪声逐步变为清晰数字的过程。
七、常见问题解决
1. 模型加载报错
- 原因:模型结构与权重不匹配、路径错误
- 解决:确保推理时的DiT模型参数(hidden_size、depth等)与训练时完全一致;检查
PRETRAINED_PTH路径是否正确。
2. 生成图像模糊/错误
- 原因:CFG权重过小、模型训练轮数不足
- 解决:调大
guide_w至2.0-3.0;增加模型训练轮数。
3. 显存不足
- 原因:
n_sample过大、使用CPU运行 - 解决:减小
n_sample至1-2;切换到GPU运行。
八、总结
本文实现了基于DiT+DDPM的MNIST指定数字生成推理,核心要点:
- 复用训练阶段的模型结构,确保权重匹配;
- 改造采样逻辑,让类别标签仅包含指定数字;
- 封装一键调用函数,支持参数自定义;
- 通过CFG权重控制生成精准度,提升效果。
该方法可扩展到其他分类生成任务(如CIFAR-10、FashionMNIST),只需调整模型输入尺寸、类别数等参数即可。
扩展思考
- 可增加图像后处理(如二值化、边缘增强),提升生成数字的清晰度;
- 结合GradIO制作可视化界面,实现交互式数字生成;
- 调整DiT模型参数(如depth、hidden_size),平衡生成质量和速度。