实现扩散模型 Stable Diffusion - MNIST 数据集

1 测试结果:

前置文章(本文代码依赖其中的实现):

- 实现变分自编码器 VAE - MNIST 数据集(提供 `vae.py` 中的 `VAEWithLoss`)
- 实现时间步调度器(提供 `scheduler.py` 中的 `TimeStepScheduler`)

2 模型实现:

Python 代码:
import math
import os
import time

import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel

from scheduler import TimeStepScheduler
from vae import VAEWithLoss


# Text encoder: a frozen, pretrained CLIP tokenizer + text model.
class TextEncoder(nn.Module):
    """Encode a batch of strings into CLIP token embeddings.

    The CLIP text model is put in eval mode at construction and always run
    under ``torch.no_grad()``, so it acts as a fixed feature extractor.
    """

    def __init__(self, path: str = None):
        """
        Args:
            path: Local directory or Hugging Face hub id of a CLIP checkpoint.
                ``None`` loads ``openai/clip-vit-base-patch32`` from the hub.
        """
        super(TextEncoder, self).__init__()
        if path is None:  # load from the hub
            path = r"openai/clip-vit-base-patch32"
        # Tokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained(path)
        # Text model, frozen in eval mode
        self.encoder = CLIPTextModel.from_pretrained(path).eval()

    def forward(self, texts: list[str] | tuple[str]) -> Tensor:
        """Return the CLIP ``last_hidden_state`` for ``texts``.

        Args:
            texts: Batch of raw strings. Special tokens are NOT added and the
                batch is padded to its longest member.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size); hidden_size is 512
            for clip-vit-base-patch32.
        """
        # CLIP's position-embedding table only covers
        # ``tokenizer.model_max_length`` tokens (77 for CLIP); truncating at 512
        # would let long prompts index past it and crash. For tokenizers that
        # do not set a limit this stays at 512.
        max_length = min(512, self.tokenizer.model_max_length)
        # Keep the token ids on the encoder's device in case it was moved.
        device = next(self.encoder.parameters()).device
        with torch.no_grad():
            inputs = self.tokenizer(
                texts,
                add_special_tokens=False,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            outputs = self.encoder(return_dict=True, **inputs)

        return outputs["last_hidden_state"]  # (batch, seq_len, hidden_size)


# Timestep embedding: fixed sinusoidal table followed by a learnable MLP.
class TimestepEmbedding(nn.Module):
    """Map integer diffusion timesteps to ``d_model``-dimensional vectors.

    Args:
        max_step: number of rows in the sinusoidal table; valid timesteps
            are ``0 .. max_step - 1``.
        d_model: embedding size (odd sizes are supported).
    """

    def __init__(self,
                 max_step: int = 1000,
                 d_model: int = 512):
        super(TimestepEmbedding, self).__init__()
        # Sinusoidal table, one row per timestep
        pe = torch.zeros((max_step, d_model))
        position = torch.arange(0, max_step).unsqueeze(1)  # (max_step, 1)
        div_term = torch.exp(
            # 1 / 10000 ** (2 * k / d)
            torch.arange(0, d_model, 2) * -(math.log(10000) / d_model)
        )  # (ceil(d_model / 2),)
        # Even columns: sin
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns: cos. There are only floor(d_model / 2) of them, so trim
        # div_term; otherwise an odd d_model raises a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: follows .to(device) and is saved in the
        # state_dict, but is not a trainable parameter. (max_step, d_model)
        self.register_buffer("pe", pe)
        # MLP on top of the fixed encoding
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.SiLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, t: Tensor) -> Tensor:
        """
        Args:
            t: timestep indices of shape (batch,) or (batch, 1); cast to long
                and used to index the table, so values must be in
                [0, max_step).

        Returns:
            Tensor of shape (batch, d_model).
        """
        if t.dim() > 1:  # (batch, 1) -> (batch,)
            t = t.squeeze(-1)
        t = t.long()

        return self.mlp(self.pe[t])  # (batch, d_model)


# Standard residual block conditioned on the timestep embedding.
class ResNetBlock(nn.Module):
    """conv-GN-SiLU -> add projected time embedding -> conv-GN-SiLU, plus shortcut.

    Args:
        c1: input channels.
        c2: output channels.
        c: hidden channels between the two convolutions (defaults to ``c2``).
        d: size of the timestep embedding ``e``.

    ``c`` and ``c2`` must be divisible by 32 because of ``GroupNorm(32, ...)``.
    """

    def __init__(self, c1: int, c2: int,
                 c: int = None, d: int = 512):
        super(ResNetBlock, self).__init__()
        hidden = c2 if c is None else c

        def conv_gn_silu(c_in: int, c_out: int) -> nn.Sequential:
            # 3x3 "same" conv (no bias; GroupNorm follows) -> GroupNorm -> SiLU
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.GroupNorm(32, c_out),
                nn.SiLU(),
            )

        self.layer1 = conv_gn_silu(c1, hidden)
        self.layer2 = conv_gn_silu(hidden, c2)
        # Projects the time embedding to one additive bias per hidden channel
        self.fc = nn.Sequential(nn.SiLU(), nn.Linear(d, hidden))
        # 1x1 projection only when the channel count changes
        if c1 != c2:
            self.shortcut = nn.Conv2d(c1, c2, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: Tensor, e: Tensor) -> Tensor:
        """x: (batch, c1, H, W); e: (batch, d). Returns (batch, c2, H, W)."""
        projected = self.fc(e)
        # Broadcast over the spatial dims: (batch, c, 1, 1)
        time_bias = projected.view(projected.size(0), -1, 1, 1)
        hidden = self.layer1(x) + time_bias
        return self.layer2(hidden) + self.shortcut(x)


# Cross-attention: queries from one sequence, keys/values from another.
class CrossAttention(nn.Module):
    """Attend from ``x1`` (queries) to ``x2`` (keys and values).

    Args:
        d1: feature size of the query sequence and of the output; must be
            divisible by the 8 attention heads.
        d2: feature size of the key/value (context) sequence.
    """

    def __init__(self, d1: int, d2: int):
        super(CrossAttention, self).__init__()
        # Explicit Q/K/V/output projections; K and V map d2 -> d1
        self.fc_q, self.fc_k, self.fc_v, self.fc_out = (
            nn.Linear(d1, d1),
            nn.Linear(d2, d1),
            nn.Linear(d2, d1),
            nn.Linear(d1, d1),
        )
        # Note: nn.MultiheadAttention applies its own in/out projections too
        self.attn = nn.MultiheadAttention(d1, num_heads=8, batch_first=True)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """x1: (batch, n1, d1); x2: (batch, n2, d2). Returns (batch, n1, d1)."""
        attended, _weights = self.attn(self.fc_q(x1), self.fc_k(x2), self.fc_v(x2))
        return self.fc_out(attended)


# Residual block followed by a cross-attention sub-layer over the text.
class CrossAttnResNetBlock(nn.Module):
    """ResNetBlock, then pre-norm cross-attention to ``text`` with a residual add.

    Args:
        c1, c2, c: channels forwarded to ``ResNetBlock`` (in, out, hidden).
        d1: timestep-embedding size.
        d2: feature size of the text tokens.

    ``c2`` must be divisible by 32 (GroupNorm) and by the 8 attention heads.
    """

    def __init__(self, c1: int, c2: int, c: int = None,
                 d1: int = 512, d2: int = 512):
        super(CrossAttnResNetBlock, self).__init__()
        self.res = ResNetBlock(c1, c2, c, d1)
        self.norm = nn.GroupNorm(32, c2)
        self.attn = CrossAttention(c2, d2)  # queries: pixels, context: text

    def forward(self, x: Tensor, e: Tensor, text: Tensor) -> Tensor:
        """x: (batch, c1, H, W); e: time embedding; text: (batch, seq, d2).

        Returns (batch, c2, H, W).
        """
        feat = self.res(x, e)
        b, ch, h, w = feat.shape
        # Every pixel becomes a token: (batch, H * W, c2)
        tokens = self.norm(feat).flatten(2).transpose(1, 2)
        attended = self.attn(tokens, text)
        return attended.transpose(1, 2).reshape(b, ch, h, w) + feat


# Encoder stage.
class Down(nn.Module):
    """CrossAttnResNetBlock followed by an optional stride-2 downsampling conv.

    Args:
        c1: input channels.
        c2: output channels.
        down: if True, halve the spatial size with a 3x3 stride-2 conv;
            otherwise the spatial size is unchanged.
        **kwargs: forwarded to ``CrossAttnResNetBlock`` (e.g. ``d1``, ``d2``).
    """

    def __init__(self, c1: int, c2: int,
                 down: bool = False, **kwargs):
        super(Down, self).__init__()
        self.res = CrossAttnResNetBlock(c1, c2, **kwargs)
        if down:
            self.down = nn.Conv2d(c2, c2, kernel_size=3, stride=2, padding=1)
        else:
            self.down = nn.Identity()

    def forward(self, x: Tensor, e: Tensor, text: Tensor) -> Tensor:
        features = self.res(x, e, text)
        return self.down(features)


# Middle (bottleneck) stage.
class Mid(nn.Module):
    """CrossAttnResNetBlock -> spatial self-attention -> CrossAttnResNetBlock.

    Args:
        c: channel count (input == output); must be divisible by 32
            (GroupNorm) and by the 8 attention heads.
        **kwargs: forwarded to both ``CrossAttnResNetBlock`` layers.
    """

    def __init__(self, c: int, **kwargs):
        super(Mid, self).__init__()
        # Cross-attention residual blocks before and after self-attention
        self.res1 = CrossAttnResNetBlock(c, c, **kwargs)
        self.res2 = CrossAttnResNetBlock(c, c, **kwargs)
        # Normalize, then a single 3x3 conv produces Q, K and V stacked on channels
        self.layers = nn.Sequential(
            nn.GroupNorm(32, c),
            nn.Conv2d(c, 3 * c, kernel_size=3, padding=1, bias=False),
        )
        # Self-attention over the height * width positions
        self.attn = nn.MultiheadAttention(c, num_heads=8, batch_first=True)

    def forward(self, x: Tensor, e: Tensor, text: Tensor) -> Tensor:
        """x: (batch, c, H, W). Returns (batch, c, H, W)."""
        y = self.res1(x, e, text)
        batch, channel, height, width = y.shape
        # Project res1's output, not the raw input ``x``: projecting ``x`` left
        # res1's result unused, so res1 never received gradients. Checkpoints
        # trained with that behavior hold untrained res1 weights.
        qkv = self.layers(y)
        # (batch, 3c, H, W) -> (batch, H * W, 3c)
        qkv = qkv.view(batch, 3 * channel, -1).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=-1)  # each (batch, H * W, c)
        y, _ = self.attn(q, k, v)
        # (batch, H * W, c) -> (batch, c, H, W)
        y = y.transpose(1, 2).reshape(batch, channel, height, width)

        return self.res2(y, e, text)


# Decoder stage.
class Up(nn.Module):
    """Halve the channels of ``x1``, concatenate the skip ``x2``, then a
    CrossAttnResNetBlock.

    Args:
        c1: channels of ``x1``; ``x2`` must supply ``c1 - c1 // 2`` channels at
            the (possibly upsampled) spatial size so the concat has ``c1``.
        c2: output channels.
        up: if True, upsample ``x1`` 2x with a transposed conv; otherwise use
            a 3x3 "same" conv that keeps the spatial size.
        **kwargs: forwarded to ``CrossAttnResNetBlock``.
    """

    def __init__(self, c1: int, c2: int,
                 up: bool = False, **kwargs):
        super(Up, self).__init__()
        half = c1 // 2
        if up:
            self.up = nn.ConvTranspose2d(c1, half, kernel_size=2, stride=2)
        else:
            self.up = nn.Conv2d(c1, half, kernel_size=3, padding=1)
        self.res = CrossAttnResNetBlock(c1, c2, **kwargs)

    def forward(self, x1: Tensor, x2: Tensor, e: Tensor, text: Tensor) -> Tensor:
        # Skip connection: concatenate along the channel axis
        merged = torch.cat([self.up(x1), x2], dim=1)
        return self.res(merged, e, text)


# U-Net noise predictor operating on 4-channel latents.
class UNet(nn.Module):
    """Timestep- and text-conditioned U-Net that predicts the added noise.

    Channel path: 4 -> 64 -> 128 -> 256 (mid) -> 128 -> 64 -> 4. No stage is
    built with ``down=True``/``up=True``, so the spatial size stays constant.

    Args:
        max_step: number of diffusion timesteps.
        beta_mode: noise-schedule name forwarded to ``TimeStepScheduler``.
        time_dim: timestep-embedding size.
        text_dim: feature size of the text tokens passed as ``text``.
    """

    def __init__(self,
                 max_step: int = 1000,
                 beta_mode: str = "linear",
                 time_dim: int = 512,
                 text_dim: int = 512):
        super(UNet, self).__init__()
        # Noising / sampling scheduler; reads the module-level ``device``
        self.scheduler = TimeStepScheduler(
            max_step,
            beta_mode=beta_mode,
            device=device,
        )
        self.embed = TimestepEmbedding(max_step, time_dim)
        self.conv_in = nn.Conv2d(4, 64, kernel_size=3, padding=1)
        shared = dict(d1=time_dim, d2=text_dim)
        # Encoder
        self.down1 = Down(64, 128, **shared)
        self.down2 = Down(128, 256, **shared)
        # Bottleneck
        self.mid = Mid(256, **shared)
        # Decoder
        self.up1 = Up(256, 128, **shared)
        self.up2 = Up(128, 64, **shared)
        self.conv_out = nn.Conv2d(64, 4, kernel_size=3, padding=1)

    def forward(self, x: Tensor, text: Tensor) -> tuple[Tensor, ...]:
        """Noise ``x`` via the scheduler and predict that noise.

        Returns ``(pred_noise, noise)``. The scheduler's result is read through
        the keys ``"x_t"``, ``"t"`` and ``"noise"``.
        """
        noised = self.scheduler(x)
        prediction = self._steps(noised["x_t"], noised["t"], text)
        return prediction, noised["noise"]

    # Sampling
    def predict(self, **kwargs) -> tuple[Tensor, list[Tensor]]:
        """Delegate sampling to ``TimeStepScheduler.predict`` with ``_steps``
        as the noise predictor; ``kwargs`` are passed through unchanged."""
        self.eval()
        with torch.no_grad():
            return self.scheduler.predict(self._steps, **kwargs)

    # Encoder/decoder pass
    def _steps(self, x: Tensor, t: Tensor, text: Tensor) -> Tensor:
        emb = self.embed(t)
        skip1 = self.conv_in(x)
        skip2 = self.down1(skip1, emb, text)
        h = self.down2(skip2, emb, text)
        h = self.mid(h, emb, text)
        h = self.up1(h, skip2, emb, text)
        h = self.up2(h, skip1, emb, text)
        return self.conv_out(h)


# Latent diffusion model: frozen pretrained VAE + trainable U-Net.
class StableDiffusion(nn.Module):
    """Encode images with a frozen VAE and train/sample the U-Net on the latents.

    Args:
        path_vae: checkpoint dict with ``"param"`` (kwargs for ``VAEWithLoss``)
            and ``"weight"`` (its state_dict).
        **kwargs: forwarded to ``UNet``.
    """

    def __init__(self, path_vae: str, **kwargs):
        super(StableDiffusion, self).__init__()
        # Frozen VAE (eval mode, requires_grad=False)
        self.vae = self._load_vae(path_vae)
        # Trainable U-Net
        self.u_net = UNet(**kwargs)

    def train(self, mode: bool = True):
        """Set train/eval mode on the U-Net while keeping the VAE in eval mode.

        ``nn.Module.train`` recurses into every child, so without this
        override ``model.train()`` would switch the frozen VAE back into
        training mode, undoing the ``eval()`` done in ``_load_vae``.
        ``eval()`` routes through here as ``train(False)``.
        """
        super(StableDiffusion, self).train(mode)
        self.vae.eval()
        return self

    def forward(self, x: Tensor, text: Tensor) -> tuple[Tensor, ...]:
        """Return ``(pred_noise, noise)`` for images ``x`` conditioned on ``text``."""
        # The VAE is frozen, so no autograd graph is needed through it.
        with torch.no_grad():
            latent = self.vae.encode(x)
        return self.u_net(latent, text)

    # Sampling
    def predict(self, **kwargs) -> tuple[Tensor, list[Tensor]]:
        """Sample latents with the U-Net and decode them with the VAE.

        Returns ``(decoded_images, lst)``, where ``lst`` is whatever list the
        U-Net's sampler returns alongside the final latents.
        """
        self.eval()
        with torch.no_grad():
            x, lst = self.u_net.predict(**kwargs)
            return self.vae.decode(x), lst

    # Load the VAE checkpoint and freeze it
    @staticmethod
    def _load_vae(path: str) -> nn.Module:
        info = torch.load(path, map_location=device)
        model = VAEWithLoss(**info["param"]).to(device)
        model.load_state_dict(info["weight"])
        model.eval()
        # Freeze all VAE parameters
        for param in model.parameters():
            param.requires_grad = False

        return model


# Stable Diffusion wrapped with its training objective.
class SDWithLoss(nn.Module):
    """Return the MSE between the predicted and the true noise.

    All keyword arguments are forwarded to ``StableDiffusion``.
    """

    def __init__(self, **kwargs):
        super(SDWithLoss, self).__init__()
        self.model = StableDiffusion(**kwargs)
        self.criterion = nn.MSELoss()

    def forward(self, x: Tensor, text: Tensor) -> Tensor:
        # The model yields (pred_noise, noise) -> criterion(input, target)
        return self.criterion(*self.model(x, text))

    # Sampling
    def predict(self, **kwargs) -> tuple[Tensor, list[Tensor]]:
        """Delegate to ``StableDiffusion.predict`` in eval mode without gradients."""
        self.eval()
        with torch.no_grad():
            result = self.model.predict(**kwargs)
        return result

3 训练及验证:

Python 代码:
# Main driver: data loading, training loop, validation, checkpointing, sampling.
class TrainValTest:
    """Train and validate ``SDWithLoss`` on MNIST; ``test`` samples one digit each.

    Reads the module-level globals ``config``, ``device``, ``tf`` and
    ``text_encoder`` that are set up under ``__main__``.
    """

    def __init__(self):
        self.loader_train, self.loader_val = None, None  # DataLoaders
        self.model = None  # SDWithLoss
        self.optimizer = None  # AdamW
        self.scheduler = None  # CosineAnnealingLR learning-rate schedule
        self._init_score()  # loss accumulators
        self.round = 0  # current epoch (1-based while training)

    def __call__(self):
        """Load data, build the model and train for ``config["epoch"]`` epochs."""
        print("加载数据集...")
        self._load_dataset()

        print("创建模型...")
        self._create_model()

        print("开始训练...")
        for i in range(config["epoch"]):
            # Per-epoch reset
            self.round = i + 1
            self._init_score()
            self.optimizer.zero_grad()
            time.sleep(1)

            # Train
            self._train()
            time.sleep(1)
            # Validate every val_step epochs
            if self.round % config["val_step"] == 0:
                self._val()
                time.sleep(1)
            if torch.cuda.is_available():  # release cached GPU memory
                torch.cuda.empty_cache()

            # Step the learning-rate schedule once per epoch
            self.scheduler.step()
            # Append this epoch's losses to loss.csv
            self._save_loss()
            # Checkpoint every epoch after model_round
            if self.round > config["model_round"]:
                self._save_model()

    # Reset the per-epoch mean losses
    def _init_score(self):
        self.loss_train, self.loss_val = 0, 0

    def _load_dataset(self):
        """Build train/validation DataLoaders over MNIST (downloaded if missing)."""
        # MNIST handwritten digits, preprocessed by the global transform ``tf``
        train_dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=True, download=True, transform=tf,
        )
        test_dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=False, download=True, transform=tf,
        )
        self.loader_train = DataLoader(
            train_dataset,
            batch_size=config["batch_size"][0],
            shuffle=True,
            collate_fn=self._fun,
        )
        self.loader_val = DataLoader(
            test_dataset,
            batch_size=config["batch_size"][1],
            shuffle=False,
            collate_fn=self._fun,
        )
        print(
            f"训练集加载器:{len(self.loader_train)}",
            f"验证集加载器:{len(self.loader_val)}",
            sep=","
        )

    def _create_model(self):
        """Create the model, AdamW optimizer and cosine LR schedule."""
        self.model = SDWithLoss(
            path_vae=config["path_vae"],
            max_step=config["max_step"],
            beta_mode=config["beta_mode"],
            time_dim=config["time_dim"],
            text_dim=config["text_dim"],
        ).to(device)
        # Frozen VAE parameters never receive gradients, so AdamW skips them
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config["lr"],
        )
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config["epoch"],
            eta_min=config["lr"] * 0.1,
        )
        # Includes the frozen VAE parameters
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"总参数量:{num_params}")

    def _train(self):
        """Train one epoch with gradient accumulation over ``acc_step`` batches."""
        self.model.train()
        length = len(self.loader_train)
        acc_step = config["acc_step"]
        for i, data in enumerate(tqdm(self.loader_train)):
            # Forward pass returns the scalar MSE loss
            loss = self.model(*data)
            # Epoch mean of the unscaled loss
            self.loss_train += loss.item() / length
            # Scale for accumulation (out-of-place, leaves the graph untouched)
            loss = loss / acc_step
            loss.backward()
            # Update parameters once every acc_step batches
            if (i + 1) % acc_step == 0:
                self._update_grad()
                if torch.cuda.is_available():  # release cached GPU memory
                    torch.cuda.empty_cache()
        # Apply the gradients of a trailing, incomplete accumulation window
        if length % acc_step != 0:
            self._update_grad()
        print(
            f"第 {self.round} 轮",
            "训练损失:{:.4f}".format(self.loss_train),
            sep=","
        )

    def _val(self):
        """Compute the mean validation loss without gradients."""
        self.model.eval()
        length = len(self.loader_val)
        with torch.no_grad():
            for data in tqdm(self.loader_val):
                loss = self.model(*data)
                self.loss_val += loss.item() / length
        print(
            f"第 {self.round} 轮",
            "验证损失:{:.4f}".format(self.loss_val),
            sep=","
        )

    @staticmethod
    def _temp_path(name: str) -> str:
        """Return ``<root>/result/temp/<name>``, creating the directory if needed."""
        folder = os.path.join(config["root"], "result", "temp")
        # Without this, the first save on a fresh checkout raises FileNotFoundError
        os.makedirs(folder, exist_ok=True)
        return os.path.join(folder, name)

    def _save_loss(self):
        """Append "loss_train,loss_val" (4 decimals) as one line of loss.csv."""
        with open(self._temp_path("loss.csv"), "a+", encoding="utf-8") as f:
            f.write("{:.4f}".format(self.loss_train) + "," +
                    "{:.4f}".format(self.loss_val) + "\n")
        print(
            f"第 {self.round} 轮",
            "已保存训练指标...\n",
            sep=","
        )

    def _save_model(self):
        """Save weights, constructor params and current losses to model_<round>.pt."""
        info = {
            "weight": self.model.state_dict(),
            "param": {
                "max_step": config["max_step"],
                "beta_mode": config["beta_mode"],
                "time_dim": config["time_dim"],
                "text_dim": config["text_dim"],
            },
            "result": {
                "loss_train": self.loss_train,
                "loss_val": self.loss_val,
            }
        }
        name = f"model_{self.round}.pt"
        torch.save(info, self._temp_path(name))
        print(
            f"第 {self.round} 轮",
            f"已保存模型:{name}...\n",
            sep=","
        )

    # Clip, step and reset gradients
    def _update_grad(self):
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

    @staticmethod
    def _fun(batch):
        """collate_fn: stack images on ``device`` and CLIP-encode labels as text."""
        x, text = zip(*batch)
        x = torch.stack(x, dim=0).to(device)
        # Labels are turned into their string form, e.g. 7 -> "7"
        text = [str(item) for item in text]
        text = text_encoder(text).to(device)

        return x, text

    @staticmethod
    def test():
        """Load result/final/model_sd.pt and plot one DDIM sample per digit 0-9."""
        info = torch.load(
            os.path.join(
                config["root"],
                "result/final/model_sd.pt",
            ),
            map_location=device,
        )
        print("模型训练指标:", info["result"])
        model = SDWithLoss(
            path_vae=config["path_vae"],
            **info["param"],
        ).to(device)
        model.load_state_dict(info["weight"])
        model.eval()
        # Sample 10 latents of shape (4, 7, 7), one per digit prompt
        out, lst = model.predict(
            shape=(10, 4, 7, 7),
            text=text_encoder([
                "0", "1", "2", "3", "4",
                "5", "6", "7", "8", "9",
            ]).to(device),
            sampling_mode="ddim",
            step=100,
        )
        # Expected (10, 28, 28) after dropping the channel axis; assumes the
        # VAE decoder returns (10, 1, 28, 28) — confirm against vae.py
        out_np = out.squeeze(1).detach().cpu().numpy()
        plt.figure(figsize=(15, 10))
        for i in range(10):
            ax = plt.subplot(2, 5, i + 1)
            ax.imshow(out_np[i])
            ax.set_title(f"Number '{i}'\n(28 × 28)")
            ax.axis("off")
        plt.show()


if __name__ == '__main__':
    # Run configuration. The commented-out entries appear to be the matching
    # paths on a remote (Linux) machine; swap them in when running there.
    config = dict(
        # root=r"/root/autodl-tmp/stable_diffusion",
        root=r"D:\Project\Transformer\stable_diffusion",
        # path_text_encoder=r"/root/autodl-tmp/stable_diffusion/clip-vit-base-patch32",
        path_text_encoder=r"D:\Project\Transformer\ztool\tokenizer\clip-vit-base-patch32",
        ratio=(0.8, 0.2),  # not read by the visible code
        batch_size=(50, 100),  # (train, validation)
        # path_vae=r"/root/autodl-tmp/stable_diffusion/result/final/model.pt",
        path_vae=r"D:\Project\Transformer\stable_diffusion\result\final\model_vae.pt",
        max_step=1000,
        beta_mode="linear",
        time_dim=512,
        text_dim=512,
        epoch=50,
        lr=5e-4,
        acc_step=2,  # batches accumulated per optimizer step
        val_step=1,  # validate every N epochs
        model_round=30,  # checkpoint every epoch after this one
    )
    # Default device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Image preprocessing: 28x28, to tensor, normalize with mean 0.1307 / std 0.3081
    tf = T.Compose([
        T.Resize((28, 28)),
        T.ToTensor(),
        T.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    # Shared text encoder used by the collate_fn and by TrainValTest.test
    text_encoder = TextEncoder(path=config["path_text_encoder"])

    train_val_test = TrainValTest()
    run_training = False  # set to True to train instead of sampling
    if run_training:
        train_val_test()
    else:
        TrainValTest.test()

4 训练结果:

loss_train,loss_val

0.0483,0.0486

相关推荐
Asher阿舍技术站1 小时前
【AI基础学习系列】四、Prompt基础知识
人工智能·学习·prompt
SailingCoder1 小时前
【 从“打补丁“到“换思路“ 】一次企业级 AI Agent 的架构拐点
大数据·前端·人工智能·面试·架构·agent
jz_ddk2 小时前
[指南] Python循环语句完全指南
开发语言·python·continue·循环·for·while·break
Evand J2 小时前
【Python代码例程】长短期记忆网络(LSTM)和无迹卡尔曼滤波(UKF)的结合,处理复杂非线性系统和时间序列数据
python·lstm·滤波
hqyjzsb2 小时前
企业培训ROI深度分析:如何将CAIE认证的显性与隐性成本纳入投资回报率模型
人工智能·考研·职场和发展·创业创新·学习方法·业界资讯·改行学it
大模型真好玩2 小时前
最强开源多模态大模型它来啦——一文详解Qwen3.5核心特性
人工智能·agent·vibecoding
是店小二呀2 小时前
CANN Catlass:AI 处理器高性能计算的核心引擎
人工智能
罗技1232 小时前
Docker启动Coco AI Server后,如何访问内置Easysearch?
人工智能·docker·容器
新缸中之脑2 小时前
TinyFish:网站转结构化API
人工智能