基于DiT+DDPM的MNIST数字生成:模型推理实战教程

基于DiT+DDPM的MNIST数字生成:模型推理实战教程

一、前言

扩散模型(Diffusion Model)在图像生成领域展现出卓越的性能,本文基于DiT(Diffusion Transformer)+DDPM(Denoising Diffusion Probabilistic Models)实现MNIST手写数字的指定生成,手把手教大家如何加载预训练的扩散模型权重,实现指定单个数字(0-9)的图像生成推理,并保存生成结果。

本文代码可直接运行,核心实现:

  • 加载预训练的DiT+DDPM模型权重
  • 指定任意数字(0-9)生成对应手写体图像
  • 自定义生成样本数量、引导权重等参数
  • 保存生成结果和采样过程GIF

二、环境准备

1. 依赖安装

确保安装以下核心依赖包:

bash 复制代码
pip install torch torchvision timm matplotlib numpy tqdm pillow

2. 预训练模型准备

提前准备好训练完成的DiT+DDPM模型权重文件(.pth格式),本文以model_2.pth为例,路径为./data/diffusion_dit_mnist/model_2.pth

三、核心代码解析

1. 模型结构复用

首先完整保留DiT扩散模型的核心结构定义(包括DiTBlock、TimestepEmbedder、LabelEmbedder等),这部分代码与训练阶段完全一致,确保模型结构和权重匹配:

python 复制代码
# 以下为模型核心定义(完整代码见文末)
class SimpleHead(nn.Module):...  # 简单头层
class TimestepEmbedder(nn.Module):...  # 时间步嵌入
class LabelEmbedder(nn.Module):...  # 类别标签嵌入
class DiTBlock(nn.Module):...  # DiT核心块
class FinalLayer(nn.Module):...  # 最终输出层
class DiT(nn.Module):...  # 完整DiT模型

2. DDPM采样逻辑改造(核心)

原DDPM的sample方法默认生成0-9全数字,我们新增sample_specific_digit方法,实现指定数字生成:

python 复制代码
class DDPM(nn.Module):
    def __init__(self, nn_model, betas=(1e-4, 0.02), n_T=400, device="cpu", drop_prob=0.1):
        super().__init__()
        self.nn_model = nn_model.to(device)
        # 注册扩散过程的关键参数(beta/alpha等)
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    # 新增:指定数字采样方法
    def sample_specific_digit(self, target_digit, n_sample=4, size=(1,28,28), guide_w=2.0):
        """
        生成指定数字的图像
        :param target_digit: 要生成的数字(0-9)
        :param n_sample: 生成该数字的样本数量
        :param size: 图像尺寸 (1,28,28)
        :param guide_w: CFG引导权重(越大生成越精准)
        :return: 生成的图像张量,采样过程存储
        """
        # 1. 初始化噪声(从标准正态分布采样)
        x_i = torch.randn(n_sample, *size).to(self.device)
        
        # 2. 关键:类别标签仅包含指定数字(而非0-9全类别)
        c_i = torch.tensor([target_digit]*n_sample).to(self.device)
        context_mask = torch.zeros_like(c_i).to(self.device)
        
        # 3. CFG(Classifier-Free Guidance)策略:提升生成精准度
        c_i = c_i.repeat(2)
        context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1.  # 后半部分为无条件采样

        x_i_store = []  # 保存采样过程,用于生成GIF
        for i in range(self.n_T, 0, -1):
            print(f'采样进度:{i}/{self.n_T} (生成数字{target_digit})', end='\r')
            t_is = torch.tensor([i]*n_sample).to(self.device).repeat(2)
            z = torch.randn(n_sample, *size).to(self.device) if i > 1 else 0.

            # 复制张量用于CFG计算
            x_i = x_i.repeat(2,1,1,1)
            eps = self.nn_model(x_i, t_is, c_i, context_mask)
            # CFG核心计算:有条件预测 - 无条件预测
            eps_cond, eps_uncond = eps[:n_sample], eps[n_sample:]
            eps = eps_uncond + guide_w * (eps_cond - eps_uncond)

            # 反向扩散步骤:去噪
            x_i = self.oneover_sqrta[i] * (x_i[:n_sample] - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
            # 保存采样过程(可选)
            if i%20==0 or i==self.n_T or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())
        print()  # 换行
        return x_i, np.array(x_i_store)

3. 封装推理函数(一键调用)

将模型加载、采样、结果保存封装为generate_specific_digit函数,降低使用门槛:

python 复制代码
def generate_specific_digit(
    pretrained_pth,          # 预训练模型路径
    target_digit,            # 要生成的数字(0-9)
    n_sample=4,              # 生成该数字的样本数
    guide_w=2.0,             # CFG引导权重(推荐2.0)
    save_dir='./generated_digits/',
    device="cpu"
):
    """
    生成指定数字的MNIST图像
    :param pretrained_pth: 预训练.pth文件路径
    :param target_digit: 目标数字(0-9)
    :param n_sample: 生成样本数量
    :param guide_w: CFG引导权重
    :param save_dir: 保存目录
    :param device: 运行设备(cpu/cuda)
    """
    # 1. 参数校验
    if not 0 <= target_digit <=9:
        raise ValueError("target_digit必须是0-9之间的整数!")
    if not os.path.exists(pretrained_pth):
        raise FileNotFoundError(f"预训练文件不存在:{pretrained_pth}")
    
    # 2. 创建保存目录
    os.makedirs(save_dir, exist_ok=True)
    
    # 3. 初始化DiT模型(与训练时参数完全一致)
    dit_model = DiT(
        input_size=28,
        patch_size=4,
        in_channels=1,
        hidden_size=384,
        depth=12,
        num_heads=6,
        class_dropout_prob=0.1,
        num_classes=10,
        learn_sigma=False
    )
    
    # 4. 加载预训练权重
    ddpm = DDPM(nn_model=dit_model, n_T=400, device=device, drop_prob=0.1)
    checkpoint = torch.load(pretrained_pth, map_location=device, weights_only=True)
    ddpm.load_state_dict(checkpoint)
    ddpm.to(device)
    ddpm.eval()  # 必须设置为评估模式(禁用Dropout等训练层)
    print(f"✅ 加载模型完成,开始生成数字{target_digit}...")
    
    # 5. 生成指定数字(禁用梯度计算,提升速度)
    with torch.no_grad():
        x_gen, x_gen_store = ddpm.sample_specific_digit(
            target_digit=target_digit,
            n_sample=n_sample,
            guide_w=guide_w
        )
    
    # 6. 保存生成结果
    # 调整图像范围(MNIST为黑底白字,反转后更易查看)
    grid = make_grid(x_gen * -1 + 1, nrow=int(math.sqrt(n_sample)))
    save_path = os.path.join(save_dir, f"digit_{target_digit}_samples_{n_sample}_w{guide_w}.png")
    save_image(grid, save_path)
    print(f"💾 生成结果已保存:{save_path}")
    
    # 7. 生成采样过程GIF(可选,直观展示扩散去噪过程)
    fig, axs = plt.subplots(nrows=1, ncols=n_sample, figsize=(n_sample*2, 2))
    if n_sample == 1:
        axs = [axs]  # 处理单样本情况
    
    def animate(i):
        for idx in range(n_sample):
            axs[idx].clear()
            axs[idx].set_xticks([])
            axs[idx].set_yticks([])
            axs[idx].imshow(-x_gen_store[i, idx, 0], cmap='gray')
        return axs
    
    ani = FuncAnimation(fig, animate, frames=len(x_gen_store), interval=200, blit=False)
    gif_path = os.path.join(save_dir, f"digit_{target_digit}_process_w{guide_w}.gif")
    ani.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
    plt.close()
    print(f"💾 采样过程GIF已保存:{gif_path}")
    
    return x_gen

四、推理实战

1. 基础调用(生成单个数字)

修改if __name__ == "__main__"中的参数,一键生成指定数字:

python 复制代码
if __name__ == "__main__":
    # 配置参数
    PRETRAINED_PTH = "./data/diffusion_dit_mnist/model_2.pth"  # 你的模型路径
    TARGET_DIGIT = 5                                           # 要生成的数字(0-9任选)
    N_SAMPLE = 4                                               # 生成4个该数字的样本
    GUIDE_W = 2.0                                              # CFG引导权重(推荐1.0-3.0)
    DEVICE = "cpu"                                             # 有GPU改为"cuda:0"
    
    # 调用生成函数
    generate_specific_digit(
        pretrained_pth=PRETRAINED_PTH,
        target_digit=TARGET_DIGIT,
        n_sample=N_SAMPLE,
        guide_w=GUIDE_W,
        device=DEVICE
    )

2. 多数字批量生成(扩展)

循环调用函数,批量生成0-9所有数字:

python 复制代码
if __name__ == "__main__":
    PRETRAINED_PTH = "./data/diffusion_dit_mnist/model_2.pth"
    DEVICE = "cpu"
    # 循环生成0-9所有数字
    for digit in range(10):
        generate_specific_digit(
            pretrained_pth=PRETRAINED_PTH,
            target_digit=digit,
            n_sample=4,
            guide_w=2.0,
            device=DEVICE
        )

五、关键参数调优

参数 作用 推荐值
target_digit 指定生成的数字 0-9
n_sample 生成该数字的样本数量 1-16(根据显存调整)
guide_w(CFG权重) 生成精准度控制 1.0-3.0(越大越精准,过小易生成错误数字)
device 运行设备 有GPU用"cuda:0"(速度提升10倍+)

六、结果展示

1. 生成的数字图像

生成的图像默认保存至./generated_digits/目录,文件名格式:digit_5_samples_4_w2.0.png,示例效果:

  • 5个样本的数字5图像拼接成2x2网格
  • 黑底白字(MNIST标准格式),边缘清晰,数字特征明显

2. 采样过程GIF

自动生成扩散去噪过程的GIF(digit_5_process_w2.0.gif),可直观看到从噪声逐步变为清晰数字的过程。

七、常见问题解决

1. 模型加载报错

  • 原因:模型结构与权重不匹配、路径错误
  • 解决:确保推理时的DiT模型参数(hidden_size、depth等)与训练时完全一致;检查PRETRAINED_PTH路径是否正确。

2. 生成图像模糊/错误

  • 原因:CFG权重过小、模型训练轮数不足
  • 解决:调大guide_w至2.0-3.0;增加模型训练轮数。

3. 显存不足

  • 原因:n_sample过大、使用CPU运行
  • 解决:减小n_sample至1-2;切换到GPU运行。

八、总结

本文实现了基于DiT+DDPM的MNIST指定数字生成推理,核心要点:

  1. 复用训练阶段的模型结构,确保权重匹配;
  2. 改造采样逻辑,让类别标签仅包含指定数字;
  3. 封装一键调用函数,支持参数自定义;
  4. 通过CFG权重控制生成精准度,提升效果。

该方法可扩展到其他分类生成任务(如CIFAR-10、FashionMNIST),只需调整模型输入尺寸、类别数等参数即可。

扩展思考

  1. 可增加图像后处理(如二值化、边缘增强),提升生成数字的清晰度;
  2. 结合GradIO制作可视化界面,实现交互式数字生成;
  3. 调整DiT模型参数(如depth、hidden_size),平衡生成质量和速度。
相关推荐
天使Di María1 小时前
脑电大模型系列——第一弹:BENDR
人工智能·大模型·脑机接口·精准解码
AI智能观察1 小时前
2026交通数字人智能体Top5 :厂商深度解析,赋能智慧交通新生态
人工智能·智慧城市·数字人·智慧交通·智能体
我的xiaodoujiao2 小时前
使用 Python 语言 从 0 到 1 搭建完整 Web UI自动化测试学习系列 48--本地环境部署Jenkins服务
python·学习·测试工具·pytest
冰西瓜6002 小时前
深度学习的数学原理(五)—— 非线性与激活函数
人工智能·深度学习
田里的水稻2 小时前
FA_规划和控制(PC)-D*规划
人工智能·算法·数学建模·机器人·自动驾驶
love530love2 小时前
【OpenClaw 本地实战 Ep.2】零代码对接:使用交互式向导快速连接本地 LM Studio 用 CUDA GPU 推理
人工智能·windows·gpu·cuda·ollama·lm studio·openclaw
喵手2 小时前
Python爬虫实战:爬取得到App电子书畅销榜 - 从零到交付的完整实战!
爬虫·python·爬虫实战·零基础python爬虫教学·爬取app电子书畅销榜·app电子书畅销榜单数据获取
2401_828890642 小时前
实现变分自编码器 VAE- MNIST 数据集
人工智能·python·深度学习·cnn·transformer
PD我是你的真爱粉2 小时前
RabbitMQ架构实战
python·架构·rabbitmq