Implementing a Variational Autoencoder (VAE) on the MNIST Dataset

1 Test results:

2 Model implementation:

python
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm import tqdm


# Residual block
class ResNetBlock(nn.Module):
    def __init__(self, c1: int, c2: int, c: int = None,
                 up: bool = False, down: bool = False):
        super(ResNetBlock, self).__init__()
        if c is None:  # number of intermediate channels
            c = c2
        # Up-sampling layer
        self.up = nn.ConvTranspose2d(c1, c1, kernel_size=2, stride=2) \
            if up else nn.Identity()
        # Residual (shortcut) connection
        self.shortcut = nn.Conv2d(c1, c2, kernel_size=1) \
            if c1 != c2 else nn.Identity()
        # Convolutional layers
        self.layers = nn.Sequential(
            nn.Conv2d(c1, c, kernel_size=3, padding=1, bias=False),
            # GroupNorm + SiLU, following Stable Diffusion
            nn.GroupNorm(32, c),
            nn.SiLU(True),
            nn.Conv2d(c, c2, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, c2),
            nn.SiLU(True),
        )
        # Down-sampling layer
        self.down = nn.Conv2d(c2, c2, kernel_size=3, stride=2, padding=1) \
            if down else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        y = self.up(x)
        res = self.shortcut(y)
        y = self.layers(y)

        return self.down(y + res)  # (batch, c2, height, width)


# Self-attention block over spatial positions
class Attention(nn.Module):
    def __init__(self, c: int):
        super(Attention, self).__init__()
        # Projection to Q, K, V
        self.layers = nn.Sequential(
            nn.GroupNorm(32, c),
            nn.Conv2d(c, 3 * c, kernel_size=3, padding=1, bias=False),
        )
        # Multi-head attention layer
        self.attn = nn.MultiheadAttention(c, 8, batch_first=True)
        # Output projection
        self.conv_out = nn.Conv2d(c, c, kernel_size=3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        batch, channel, height, width = x.shape
        # Compute Q, K, V with a single projection, then split
        qkv = self.layers(x)
        qkv = qkv.view(batch, 3 * channel, -1).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=-1)
        y, _ = self.attn(q, k, v)
        y = y.transpose(1, 2).view(batch, channel, height, width)

        return x + self.conv_out(y)


# Encoder
class VAEEncoder(nn.Module):
    def __init__(self, c1: int = 3, c2: int = 4):
        super(VAEEncoder, self).__init__()
        self.layers = nn.Sequential(
            # Input layer
            nn.Conv2d(c1, 64, kernel_size=3, padding=1),
            # Residual blocks
            ResNetBlock(64, 128),
            ResNetBlock(128, 128, down=True),
            ResNetBlock(128, 256),
            ResNetBlock(256, 256, down=True),
            # Attention block
            Attention(256),
            # Output layer: 2 * c2 channels, for mean and log-variance
            nn.Conv2d(256, 2 * c2, kernel_size=3, padding=1),
        )

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        y = self.layers(x)
        # Split into mean and log-variance
        mean, log_var = torch.chunk(y, 2, dim=1)
        # Scale factor 0.18215, following Stable Diffusion
        mean = mean * 0.18215
        log_var = log_var * 0.18215
        # Reparameterization trick: z = mean + std * eps
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mean + eps * std

        return z, mean, log_var  # (batch, c2, height, width)


# Decoder
class VAEDecoder(nn.Module):
    def __init__(self, c1: int = 4, c2: int = 3):
        super(VAEDecoder, self).__init__()
        self.layers = nn.Sequential(
            # Input layer
            nn.Conv2d(c1, 256, kernel_size=3, padding=1),
            # Attention block
            Attention(256),
            # Residual blocks
            ResNetBlock(256, 256, up=True),
            ResNetBlock(256, 128),
            ResNetBlock(128, 128, up=True),
            ResNetBlock(128, 64),
            # Output layer
            nn.Conv2d(64, c2, kernel_size=3, padding=1),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = x / 0.18215
        return self.layers(x)  # (batch_size, c2, height, width)


# VAE
class VAE(nn.Module):
    def __init__(self,
                 img_channels: int = 3,
                 latent_channels: int = 4):
        super(VAE, self).__init__()
        # Encoder
        self.encoder = VAEEncoder(img_channels, latent_channels)
        # Decoder
        self.decoder = VAEDecoder(latent_channels, img_channels)

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        z, mean, log_var = self.encoder(x)
        out = self.decoder(z)

        return out, z, mean, log_var


# VAEWithLoss
class VAEWithLoss(nn.Module):
    def __init__(self, **kwargs):
        super(VAEWithLoss, self).__init__()
        self.model = VAE(**kwargs)
        # Mean-squared-error reconstruction loss
        self.criterion = nn.MSELoss()

    def forward(self, x: Tensor, kl_weight: float = 1.) -> dict[str, Tensor]:
        out, z, mean, log_var = self.model(x)
        # Reconstruction loss
        recon_loss = self.criterion(out, x)
        # KL-divergence loss (closed form for a diagonal Gaussian vs. N(0, I))
        kl_loss = -0.5 * torch.mean(1 + log_var - mean.pow(2) - log_var.exp())
        # Total loss
        total_loss = recon_loss + kl_weight * kl_loss

        return {
            "out": out,
            "z": z,
            "mean": mean,
            "log_var": log_var,
            "recon_loss": recon_loss,
            "kl_loss": kl_loss,
            "total_loss": total_loss,
        }

    def encode(self, x: Tensor) -> Tensor:
        self.eval()  # evaluation mode
        with torch.no_grad():
            return self.model.encoder(x)[1]  # mean

    def decode(self, x: Tensor) -> Tensor:
        self.eval()
        with torch.no_grad():
            return self.model.decoder(x)

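The sampling step z = mean + eps * std in VAEEncoder is the reparameterization trick, which keeps sampling differentiable with respect to mean and log_var. The KL term computed in VAEWithLoss is the closed-form KL divergence between the diagonal Gaussian posterior q(z|x) = N(μ, σ²) produced by the encoder and the standard normal prior p(z) = N(0, I):

$$
D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,p(z)\big) = -\frac{1}{2}\sum_{i}\left(1 + \log\sigma_i^{2} - \mu_i^{2} - \sigma_i^{2}\right)
$$

With log_var = log σ², this is exactly the -0.5 * (1 + log_var - mean.pow(2) - log_var.exp()) expression in the code; the only difference is that the code averages over all elements (including the batch) rather than summing over latent dimensions, which simply rescales the effective KL weight.
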
3 Training and validation:

python
# Training / validation / testing driver
class TrainValTest:
    def __init__(self):
        self.loader_train, self.loader_val = None, None  # DataLoader
        self.model = None  # model
        self.optimizer = None  # optimizer
        self.scheduler = None  # learning-rate scheduler
        self._init_score()  # metrics
        self.round = 0  # current training epoch

    def __call__(self):
        print("Loading dataset...")
        self._load_dataset()

        print("Creating model...")
        self._create_model()

        print("Starting training...")
        for i in range(config["epoch"]):
            # Initialization
            self.round = i + 1
            self._init_score()
            self.optimizer.zero_grad()
            time.sleep(1)

            # Training
            self._train()
            time.sleep(1)
            # Validation
            if self.round % config["val_step"] == 0:
                self._val()
                time.sleep(1)
            if torch.cuda.is_available():  # free cached GPU memory
                torch.cuda.empty_cache()

            # Update the learning rate
            self.scheduler.step()
            # Save the training metrics
            self._save_loss()
            # Save a checkpoint
            if self.round > config["model_round"]:
                self._save_model()

    # Reset the metrics
    def _init_score(self):
        self.recon_loss_train, self.recon_loss_val = 0, 0  # reconstruction loss
        self.kl_loss_train, self.kl_loss_val = 0, 0  # KL-divergence loss
        self.total_loss_train, self.total_loss_val = 0, 0  # total loss

    # Load the dataset
    def _load_dataset(self):
        # Download the MNIST handwritten-digit dataset
        train_dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=True, download=True, transform=tf,
        )
        test_dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=False, download=True, transform=tf,
        )
        self.loader_train = DataLoader(
            train_dataset,
            batch_size=config["batch_size"][0],
            shuffle=True,
        )
        self.loader_val = DataLoader(
            test_dataset,
            batch_size=config["batch_size"][1],
            shuffle=False,
        )
        print(
            f"train loader: {len(self.loader_train)} batches",
            f"val loader: {len(self.loader_val)} batches",
            sep=", "
        )

    # Create the model, optimizer and scheduler
    def _create_model(self):
        self.model = VAEWithLoss(
            img_channels=config["img_channels"],
            latent_channels=config["latent_channels"],
        ).to(device)
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config["lr"],
        )
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config["epoch"],
            eta_min=config["lr"] * 0.1,
        )
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total number of parameters: {num_params}")

    # Training
    def _train(self):
        self.model.train()  # training mode
        length = len(self.loader_train)
        for i, data in enumerate(tqdm(self.loader_train)):
            # Forward pass
            ret = self.model(data[0].to(device), self._update_kl())
            # Accumulate losses (averaged over batches)
            self.recon_loss_train += ret["recon_loss"].item() / length
            self.kl_loss_train += ret["kl_loss"].item() / length
            loss = ret["total_loss"]
            self.total_loss_train += loss.item() / length
            # Backward pass with gradient accumulation
            loss /= config["acc_step"]
            loss.backward()
            # Update the parameters every acc_step steps
            if (i + 1) % config["acc_step"] == 0:
                self._update_grad()
                if torch.cuda.is_available():  # free cached GPU memory
                    torch.cuda.empty_cache()
        # Handle the last, incomplete accumulation batch
        if length % config["acc_step"] != 0:
            self._update_grad()
        # Print the log
        print(
            f"Epoch {self.round}",
            "train recon loss: {:.4f}".format(self.recon_loss_train),
            "train KL loss: {:.4f}".format(self.kl_loss_train),
            "train total loss: {:.4f}".format(self.total_loss_train),
            sep=", "
        )

    # Validation
    def _val(self):
        self.model.eval()  # evaluation mode
        length = len(self.loader_val)
        with torch.no_grad():
            for data in tqdm(self.loader_val):
                ret = self.model(data[0].to(device), self._update_kl())
                self.recon_loss_val += ret["recon_loss"].item() / length
                self.kl_loss_val += ret["kl_loss"].item() / length
                self.total_loss_val += ret["total_loss"].item() / length
        print(
            f"Epoch {self.round}",
            "val recon loss: {:.4f}".format(self.recon_loss_val),
            "val KL loss: {:.4f}".format(self.kl_loss_val),
            "val total loss: {:.4f}".format(self.total_loss_val),
            sep=", "
        )

    # Save the training metrics
    def _save_loss(self):
        with open(
                os.path.join(
                    config["root"],
                    "result", "temp",
                    "loss.csv"
                ), "a+", encoding="utf-8"
        ) as f:
            f.write("{:.4f}".format(self.recon_loss_train) + "," +
                    "{:.4f}".format(self.kl_loss_train) + "," +
                    "{:.4f}".format(self.total_loss_train) + "," +
                    "{:.4f}".format(self.recon_loss_val) + "," +
                    "{:.4f}".format(self.kl_loss_val) + "," +
                    "{:.4f}".format(self.total_loss_val) + "\n")
        print(
            f"Epoch {self.round}",
            "training metrics saved...\n",
            sep=", "
        )

    # Save a checkpoint
    def _save_model(self):
        info = {
            "weight": self.model.state_dict(),
            "param": {
                "img_channels": config["img_channels"],
                "latent_channels": config["latent_channels"],
            },
            "result": {
                "recon_loss_train": self.recon_loss_train,
                "kl_loss_train": self.kl_loss_train,
                "total_loss_train": self.total_loss_train,
                "recon_loss_val": self.recon_loss_val,
                "kl_loss_val": self.kl_loss_val,
                "total_loss_val": self.total_loss_val,
            }
        }
        name = f"model_{self.round}.pt"
        torch.save(
            info,
            os.path.join(
                config["root"],
                "result", "temp",
                name
            )
        )
        print(
            f"Epoch {self.round}",
            f"model saved: {name}...\n",
            sep=", "
        )

    # Update the KL weight (annealing schedule)
    def _update_kl(self):
        # A VAE needs many epochs of training; the KL weight is raised gradually
        # to balance the reconstruction loss against the KL-divergence loss
        dct = {
            (0, 10): 0.001,
            (10, 20): 0.005,
            (20, 30): 0.01,
            (30, 40): 0.05,
            (40, 50): 0.1,
        }
        for (start, end), weight in dct.items():
            if self.round < end:
                return weight

        return 0.3

    # Apply the accumulated gradients
    def _update_grad(self):
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

    # Test
    @staticmethod
    def test():
        # Load the trained model
        info = torch.load(
            os.path.join(
                config["root"],
                "result/final/model_vae.pt",
            ),
            map_location=device,
        )
        print("Model training metrics:", info["result"])
        model = VAEWithLoss(**info["param"]).to(device)
        model.load_state_dict(info["weight"])
        model.eval()
        # Load an input image
        dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=False, download=True, transform=tf,
        )
        x = dataset[0][0].unsqueeze(0).to(device)  # (1, 1, 28, 28)
        # Encode
        z = model.encode(x)  # (1, 4, 7, 7)
        z_ = torch.randn(1, 4, 7, 7).to(device)
        # Decode
        y = model.decode(z)  # (1, 1, 28, 28)
        y_ = model.decode(z_)

        # Visualization
        x_np = x.squeeze(0).squeeze(0).detach().cpu().numpy()  # (28, 28)
        y_np = y.squeeze(0).squeeze(0).detach().cpu().numpy()  # (28, 28)
        y_np_ = y_.squeeze(0).squeeze(0).detach().cpu().numpy()  # (28, 28)
        z_np = z.squeeze(0).detach().cpu().numpy()  # (4, 7, 7)
        plt.figure(figsize=(15, 10))
        # x
        ax1 = plt.subplot(2, 4, 1)
        ax1.imshow(x_np, cmap="gray")
        ax1.set_title("Input x\n(28 × 28)")
        ax1.axis("off")
        # y
        ax2 = plt.subplot(2, 4, 2)
        ax2.imshow(y_np, cmap="gray")
        ax2.set_title("Output y\n(28 × 28)")
        ax2.axis("off")
        # Reconstruction error
        ax3 = plt.subplot(2, 4, 3)
        diff = np.abs(x_np - y_np)
        ax3.imshow(diff, cmap="hot")
        ax3.set_title("Recon Error\n|x - y|")
        ax3.axis("off")
        # Sample decoded from a random latent
        ax4 = plt.subplot(2, 4, 4)
        ax4.imshow(y_np_, cmap="gray")
        ax4.set_title("Sample y\n(28 × 28)")
        ax4.axis("off")
        # z
        for i in range(4):
            ax = plt.subplot(2, 4, i + 5)
            cmaps = ["Blues", "Greens", "Reds", "Purples"]
            ax.imshow(z_np[i], cmap=cmaps[i])
            ax.set_title(f"Latent z Channel {i}\n(7 × 7)")
            ax.axis("off")
        plt.show()


if __name__ == "__main__":
    # Configuration
    config = {
        # "root": r"/root/autodl-tmp/stable_diffusion",
        "root": r"D:\Project\Transformer\stable_diffusion",
        "ratio": (0.8, 0.2),
        "batch_size": (50, 100),  # (train, val) batch sizes
        "img_channels": 1,  # MNIST is grayscale
        "latent_channels": 4,
        "epoch": 500,
        "lr": 1e-4,
        "acc_step": 2,  # gradient-accumulation steps
        "val_step": 1,  # validate every val_step epochs
        "model_round": 60,  # start saving checkpoints after this epoch
    }
    # Default device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Image preprocessing
    tf = T.Compose([
        T.Resize((28, 28)),
        T.ToTensor(),
        T.Normalize((0.1307,), (0.3081,)),
    ])

    train_val_test = TrainValTest()
    # train_val_test()  # uncomment to run training
    TrainValTest.test()  # test / visualize with a saved checkpoint
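
Before launching a full training run, a quick smoke test of the forward pass confirms the tensor shapes. This is a minimal sketch, not part of the original script; it assumes the model classes from section 2 are defined and a 28 × 28 grayscale input, so the two stride-2 down-sampling blocks in the encoder give a 7 × 7 latent map:

python
import torch

# Hypothetical smoke test: check the latent and output shapes on a fake batch.
model = VAEWithLoss(img_channels=1, latent_channels=4)
dummy = torch.randn(2, 1, 28, 28)  # two MNIST-sized grayscale images
ret = model(dummy, kl_weight=0.001)
print(ret["z"].shape)    # expected: torch.Size([2, 4, 7, 7])
print(ret["out"].shape)  # expected: torch.Size([2, 1, 28, 28])
print(ret["total_loss"].item())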

4 Training results:

recon_loss_train,kl_loss_train,total_loss_train,recon_loss_val,kl_loss_val,total_loss_val
0.0474,0.2545,0.1238,0.0471,0.2540,0.1233
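
The loss.csv written by _save_loss holds one row per epoch in the column order shown above. A minimal sketch for plotting the curves (an assumption: the file contains only numeric rows, with no header line; adjust the path to match config["root"]):

python
import os

import matplotlib.pyplot as plt
import numpy as np

# Plot the six loss curves saved by _save_loss (hypothetical helper, not in the
# original script); np.atleast_2d handles the single-epoch case.
path = os.path.join("result", "temp", "loss.csv")  # prepend config["root"] as needed
data = np.atleast_2d(np.loadtxt(path, delimiter=","))
labels = ["recon_train", "kl_train", "total_train",
          "recon_val", "kl_val", "total_val"]
for i, label in enumerate(labels):
    plt.plot(data[:, i], label=label)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()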