# 1 测试结果 (test results):
# 2 模型实现 (model implementation):
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm import tqdm
# Residual block
class ResNetBlock(nn.Module):
    """Residual conv block with optional 2x up-sampling before and down-sampling after.

    Data flow: ``up -> [conv3x3 -> GroupNorm -> SiLU] x 2 (+ shortcut) -> down``.

    Args:
        c1: input channels.
        c2: output channels.
        c: channels between the two 3x3 convs; defaults to ``c2``.
        up: if True, start with a stride-2 ``ConvTranspose2d`` (doubles H and W).
        down: if True, end with a stride-2 3x3 ``Conv2d`` (halves H and W).

    Note: ``GroupNorm(32, ...)`` requires the hidden width and ``c2`` to be
    multiples of 32.
    """

    def __init__(self, c1: int, c2: int, c: int = None,
                 up: bool = False, down: bool = False):
        super().__init__()
        hidden = c2 if c is None else c
        # Attribute names, Sequential indices and creation order are kept
        # stable: they define the state_dict keys of saved checkpoints.
        if up:
            self.up = nn.ConvTranspose2d(c1, c1, kernel_size=2, stride=2)
        else:
            self.up = nn.Identity()
        # The shortcut needs a 1x1 projection only when the channel count changes.
        if c1 == c2:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(c1, c2, kernel_size=1)
        self.layers = nn.Sequential(
            nn.Conv2d(c1, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, hidden),  # normalization choice follows Stable Diffusion
            nn.SiLU(inplace=True),
            nn.Conv2d(hidden, c2, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, c2),
            nn.SiLU(inplace=True),
        )
        if down:
            self.down = nn.Conv2d(c2, c2, kernel_size=3, stride=2, padding=1)
        else:
            self.down = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``down(layers(up(x)) + shortcut(up(x)))`` of shape (batch, c2, H', W')."""
        upsampled = self.up(x)
        skip = self.shortcut(upsampled)
        merged = self.layers(upsampled) + skip
        return self.down(merged)
# Self-attention
class Attention(nn.Module):
    """Spatial self-attention over all H*W positions, added back residually.

    A GroupNorm + 3x3 conv produces Q, K and V stacked on the channel axis.
    They are fed to an 8-head ``nn.MultiheadAttention``, projected by a 3x3
    conv, and added to the input.

    Args:
        c: channel count; must be divisible by 32 (GroupNorm) and by 8 (heads).
    """

    def __init__(self, c: int):
        super().__init__()
        # Attribute names/indices are kept stable for checkpoint compatibility.
        self.layers = nn.Sequential(
            nn.GroupNorm(32, c),
            nn.Conv2d(c, 3 * c, kernel_size=3, padding=1, bias=False),
        )
        self.attn = nn.MultiheadAttention(c, 8, batch_first=True)
        self.conv_out = nn.Conv2d(c, c, kernel_size=3, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x + conv_out(attention(x))`` with the same shape as ``x``."""
        b, ch, h, w = x.shape
        # (b, 3c, h, w) -> (b, h*w, 3c): one token per spatial position.
        tokens = self.layers(x).view(b, 3 * ch, -1).transpose(1, 2)
        q, k, v = tokens.chunk(3, dim=-1)
        attended = self.attn(q, k, v)[0]
        # (b, h*w, c) -> (b, c, h, w)
        attended = attended.transpose(1, 2).view(b, ch, h, w)
        return x + self.conv_out(attended)
# Encoder
class VAEEncoder(nn.Module):
    """Convolutional VAE encoder: image -> (sampled latent, mean, log-variance).

    Two stride-2 residual blocks halve H and W twice (e.g. 28x28 -> 7x7).

    Args:
        c1: image channels.
        c2: latent channels; the output conv emits ``2 * c2`` channels that
            are split into mean and log-variance.
    """

    def __init__(self, c1: int = 3, c2: int = 4):
        super(VAEEncoder, self).__init__()
        self.layers = nn.Sequential(
            # Input projection
            nn.Conv2d(c1, 64, kernel_size=3, padding=1),
            # Residual blocks (two 2x down-samplings)
            ResNetBlock(64, 128),
            ResNetBlock(128, 128, down=True),
            ResNetBlock(128, 256),
            ResNetBlock(256, 256, down=True),
            # Self-attention at the lowest resolution
            Attention(256),
            # Output: mean and log-variance stacked on the channel axis
            nn.Conv2d(256, 2 * c2, kernel_size=3, padding=1),
        )

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        """Encode ``x`` and draw one reparameterized sample.

        Returns:
            ``(z, mean, log_var)``, each of shape (batch, c2, H', W').
            ``log_var`` is clamped to [-30, 20].
        """
        y = self.layers(x)
        # Split into mean and log-variance
        mean, log_var = torch.chunk(y, 2, dim=1)
        # Latent scale factor 0.18215 (taken from Stable Diffusion).
        # NOTE(review): SD applies this factor to the *sampled* latent outside
        # the VAE and never scales log_var; kept as-is because existing trained
        # checkpoints were produced with this parameterization.
        mean = mean * 0.18215
        log_var = log_var * 0.18215
        # Bound log_var (same range as SD's diagonal Gaussian posterior) so the
        # exp() below and the KL term cannot overflow to inf/NaN.
        log_var = torch.clamp(log_var, -30.0, 20.0)
        # Reparameterization trick: z = mean + eps * std, eps ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mean + eps * std
        return z, mean, log_var  # (batch, c2, height, width)
# Decoder
class VAEDecoder(nn.Module):
    """Convolutional VAE decoder: latent -> image.

    Two up-sampling residual blocks double H and W twice (e.g. 7x7 -> 28x28).

    Args:
        c1: latent channels.
        c2: image channels.
    """

    def __init__(self, c1: int = 4, c2: int = 3):
        super().__init__()
        # Listed in the same order as the Sequential indices stored in checkpoints.
        stages = [
            nn.Conv2d(c1, 256, kernel_size=3, padding=1),  # latent -> features
            Attention(256),                                # self-attention at latent resolution
            ResNetBlock(256, 256, up=True),                # 2x up-sampling
            ResNetBlock(256, 128),
            ResNetBlock(128, 128, up=True),                # 2x up-sampling
            ResNetBlock(128, 64),
            nn.Conv2d(64, c2, kernel_size=3, padding=1),   # features -> image
        ]
        self.layers = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Divide the latent by 0.18215 and decode; returns (batch, c2, height, width)."""
        # Inverse of the 0.18215 factor the encoder multiplies its mean by.
        return self.layers(x / 0.18215)
# VAE
class VAE(nn.Module):
    """Variational auto-encoder composed of ``VAEEncoder`` and ``VAEDecoder``.

    Args:
        img_channels: image channels (encoder input, decoder output).
        latent_channels: latent channels.
    """

    def __init__(self,
                 img_channels: int = 3,
                 latent_channels: int = 4):
        super().__init__()
        self.encoder = VAEEncoder(img_channels, latent_channels)
        self.decoder = VAEDecoder(latent_channels, img_channels)

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        """Return ``(reconstruction, z, mean, log_var)``; the decoder consumes the sampled ``z``."""
        latent, mu, logvar = self.encoder(x)
        recon = self.decoder(latent)
        return recon, latent, mu, logvar
# VAE together with its training loss
class VAEWithLoss(nn.Module):
    """Wrap ``VAE`` and compute reconstruction + weighted KL loss.

    Args:
        **kwargs: forwarded to ``VAE`` (``img_channels``, ``latent_channels``).
    """

    def __init__(self, **kwargs):
        super(VAEWithLoss, self).__init__()
        self.model = VAE(**kwargs)
        # Mean-squared-error reconstruction loss
        self.criterion = nn.MSELoss()

    def forward(self, x: Tensor, kl_weight: float = 1.) -> dict[str, Tensor]:
        """Run the VAE on ``x`` and compute its losses.

        Args:
            x: input images (also the reconstruction target).
            kl_weight: multiplier for the KL term in ``total_loss``.

        Returns:
            dict with ``out``, ``z``, ``mean``, ``log_var``, ``recon_loss``,
            ``kl_loss`` and ``total_loss = recon_loss + kl_weight * kl_loss``.
        """
        out, z, mean, log_var = self.model(x)
        # Reconstruction loss
        recon_loss = self.criterion(out, x)
        # Closed-form KL(N(mean, exp(log_var)) || N(0, 1)), averaged over all latent elements
        kl_loss = -0.5 * torch.mean(1 + log_var - mean.pow(2) - log_var.exp())
        total_loss = recon_loss + kl_weight * kl_loss
        return {
            "out": out,
            "z": z,
            "mean": mean,
            "log_var": log_var,
            "recon_loss": recon_loss,
            "kl_loss": kl_loss,
            "total_loss": total_loss,
        }

    def _run_eval(self, fn, x: Tensor):
        """Call ``fn(x)`` in eval mode without gradients, then restore the previous mode.

        Restoring matters because calling encode/decode from inside a training
        loop would otherwise silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return fn(x)
        finally:
            self.train(was_training)

    def encode(self, x: Tensor) -> Tensor:
        """Return the (scaled) posterior mean for ``x``; no sampling, no gradients."""
        return self._run_eval(lambda t: self.model.encoder(t)[1], x)  # mean

    def decode(self, x: Tensor) -> Tensor:
        """Decode latent ``x`` into images without gradients."""
        return self._run_eval(self.model.decoder, x)
# 3 训练及验证 (training and validation):
# Training / validation / test driver
class TrainValTest:
    """Train, validate and test ``VAEWithLoss`` on MNIST.

    NOTE(review): reads the module-level globals ``config``, ``device`` and
    ``tf``, which this file only defines under ``if __name__ == "__main__"``;
    they must exist before any method other than ``__init__``/``_init_score``/
    ``_update_kl`` is called.
    """

    def __init__(self):
        self.loader_train, self.loader_val = None, None  # DataLoaders
        self.model = None  # VAEWithLoss
        self.optimizer = None  # optimizer
        self.scheduler = None  # learning-rate scheduler
        self._init_score()  # per-epoch metrics
        self.round = 0  # current epoch (1-based once training starts)

    def __call__(self):
        """Run the whole training loop for ``config["epoch"]`` epochs."""
        print("加载数据集...")
        self._load_dataset()
        print("创建模型...")
        self._create_model()
        print("开始训练...")
        for i in range(config["epoch"]):
            # Per-epoch initialisation
            self.round = i + 1
            self._init_score()
            self.optimizer.zero_grad()
            time.sleep(1)  # presumably keeps tqdm/print output from interleaving
            # Train
            self._train()
            time.sleep(1)
            # Validate every val_step epochs
            if self.round % config["val_step"] == 0:
                self._val()
                time.sleep(1)
            if torch.cuda.is_available():  # release cached GPU memory
                torch.cuda.empty_cache()
            # One scheduler step per epoch
            self.scheduler.step()
            # Persist this epoch's metrics
            self._save_loss()
            # Checkpoint only after model_round epochs
            if self.round > config["model_round"]:
                self._save_model()

    def _init_score(self):
        """Reset all running train/val losses to zero."""
        self.recon_loss_train, self.recon_loss_val = 0, 0  # reconstruction loss
        self.kl_loss_train, self.kl_loss_val = 0, 0  # KL divergence loss
        self.total_loss_train, self.total_loss_val = 0, 0  # total loss

    def _load_dataset(self):
        """Build MNIST train/test DataLoaders, downloading into ``<root>/data``."""
        data_root = os.path.join(config["root"], "data")
        train_dataset = torchvision.datasets.MNIST(
            root=data_root, train=True, download=True, transform=tf,
        )
        test_dataset = torchvision.datasets.MNIST(
            root=data_root, train=False, download=True, transform=tf,
        )
        self.loader_train = DataLoader(
            train_dataset,
            batch_size=config["batch_size"][0],
            shuffle=True,
        )
        self.loader_val = DataLoader(
            test_dataset,
            batch_size=config["batch_size"][1],
            shuffle=False,
        )
        print(
            f"训练集加载器:{len(self.loader_train)}",
            f"验证集加载器:{len(self.loader_val)}",
            sep=","
        )

    def _create_model(self):
        """Create the model on ``device``, an AdamW optimizer and a cosine LR schedule."""
        self.model = VAEWithLoss(
            img_channels=config["img_channels"],
            latent_channels=config["latent_channels"],
        ).to(device)
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config["lr"],
        )
        # Cosine decay from lr to 0.1 * lr over the full run
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config["epoch"],
            eta_min=config["lr"] * 0.1,
        )
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"总参数量:{num_params}")

    def _train(self):
        """Run one training epoch, accumulating gradients over ``acc_step`` batches."""
        self.model.train()
        length = len(self.loader_train)
        acc_step = config["acc_step"]
        kl_weight = self._update_kl()  # depends only on self.round, constant within an epoch
        for i, data in enumerate(tqdm(self.loader_train)):
            # Forward pass
            ret = self.model(data[0].to(device), kl_weight)
            # Running epoch means (each batch contributes 1/length)
            self.recon_loss_train += ret["recon_loss"].item() / length
            self.kl_loss_train += ret["kl_loss"].item() / length
            loss = ret["total_loss"]
            self.total_loss_train += loss.item() / length
            # Scale so the accumulated gradient averages over acc_step batches;
            # out-of-place so ret["total_loss"] is not mutated.
            loss = loss / acc_step
            loss.backward()
            # Update parameters once every acc_step batches
            if (i + 1) % acc_step == 0:
                self._update_grad()
                if torch.cuda.is_available():  # release cached GPU memory
                    torch.cuda.empty_cache()
        # Flush the trailing, incomplete accumulation group
        if length % acc_step != 0:
            self._update_grad()
        print(
            f"第 {self.round} 轮",
            "训练重构损失:{:.4f}".format(self.recon_loss_train),
            "训练 KL 散度损失:{:.4f}".format(self.kl_loss_train),
            "训练总损失:{:.4f}".format(self.total_loss_train),
            sep=","
        )

    def _val(self):
        """Evaluate on the validation loader without gradients."""
        self.model.eval()
        length = len(self.loader_val)
        kl_weight = self._update_kl()
        with torch.no_grad():
            for data in tqdm(self.loader_val):
                ret = self.model(data[0].to(device), kl_weight)
                self.recon_loss_val += ret["recon_loss"].item() / length
                self.kl_loss_val += ret["kl_loss"].item() / length
                self.total_loss_val += ret["total_loss"].item() / length
        print(
            f"第 {self.round} 轮",
            "验证重构损失:{:.4f}".format(self.recon_loss_val),
            "验证 KL 散度损失:{:.4f}".format(self.kl_loss_val),
            "验证总损失:{:.4f}".format(self.total_loss_val),
            sep=","
        )

    @staticmethod
    def _temp_dir() -> str:
        """Return ``<root>/result/temp``, creating it first if it is missing."""
        path = os.path.join(config["root"], "result", "temp")
        # Without this, the first open()/torch.save() raises FileNotFoundError on a fresh run.
        os.makedirs(path, exist_ok=True)
        return path

    def _save_loss(self):
        """Append this epoch's six losses as one CSV row to ``result/temp/loss.csv``.

        Validation columns are 0 on epochs where ``_val`` was skipped
        (``val_step`` > 1), because ``_init_score`` resets them every epoch.
        """
        values = (
            self.recon_loss_train, self.kl_loss_train, self.total_loss_train,
            self.recon_loss_val, self.kl_loss_val, self.total_loss_val,
        )
        path = os.path.join(self._temp_dir(), "loss.csv")
        with open(path, "a+", encoding="utf-8") as f:
            f.write(",".join("{:.4f}".format(v) for v in values) + "\n")
        print(
            f"第 {self.round} 轮",
            "已保存训练指标...\n",
            sep=","
        )

    def _save_model(self):
        """Save weights, constructor params and current metrics to ``result/temp/model_<round>.pt``."""
        info = {
            "weight": self.model.state_dict(),
            "param": {
                "img_channels": config["img_channels"],
                "latent_channels": config["latent_channels"],
            },
            "result": {
                "recon_loss_train": self.recon_loss_train,
                "kl_loss_train": self.kl_loss_train,
                "total_loss_train": self.total_loss_train,
                "recon_loss_val": self.recon_loss_val,
                "kl_loss_val": self.kl_loss_val,
                "total_loss_val": self.total_loss_val,
            }
        }
        name = f"model_{self.round}.pt"
        torch.save(info, os.path.join(self._temp_dir(), name))
        print(
            f"第 {self.round} 轮",
            f"已保存模型:{name}...\n",
            sep=","
        )

    def _update_kl(self) -> float:
        """Return the KL weight for the current epoch from a step-wise schedule.

        VAE training needs many epochs; the weight is raised in steps
        (0.001 -> 0.3) to balance reconstruction loss against KL loss.
        """
        # (exclusive end epoch, weight), checked in ascending order
        schedule = ((10, 0.001), (20, 0.005), (30, 0.01), (40, 0.05), (50, 0.1))
        for end, weight in schedule:
            if self.round < end:
                return weight
        return 0.3

    def _update_grad(self):
        """Clip the gradient norm to 5.0, step the optimizer and clear gradients."""
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
        self.optimizer.step()
        self.optimizer.zero_grad()

    @staticmethod
    def test():
        """Load ``result/final/model_vae.pt`` and plot a reconstruction, a prior sample and the latent."""
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        info = torch.load(
            os.path.join(config["root"], "result", "final", "model_vae.pt"),
            map_location=device,
        )
        print("模型训练指标:", info["result"])
        model = VAEWithLoss(**info["param"]).to(device)
        model.load_state_dict(info["weight"])
        model.eval()
        # Input: first MNIST test image
        dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=False, download=True, transform=tf,
        )
        x = dataset[0][0].unsqueeze(0).to(device)  # (1, 1, 28, 28)
        # Encode to the posterior mean; draw a same-shaped sample from N(0, I)
        # (randn_like follows latent_channels instead of a hard-coded 4).
        z = model.encode(x)  # (1, latent_channels, 7, 7)
        z_ = torch.randn_like(z)
        # Decode
        y = model.decode(z)  # (1, 1, 28, 28)
        y_ = model.decode(z_)
        # Visualisation
        x_np = x.squeeze(0).squeeze(0).detach().cpu().numpy()  # (28, 28)
        y_np = y.squeeze(0).squeeze(0).detach().cpu().numpy()  # (28, 28)
        y_np_ = y_.squeeze(0).squeeze(0).detach().cpu().numpy()  # (28, 28)
        z_np = z.squeeze(0).detach().cpu().numpy()  # (latent_channels, 7, 7)
        plt.figure(figsize=(15, 10))
        # x
        ax1 = plt.subplot(2, 4, 1)
        ax1.imshow(x_np, cmap="gray")
        ax1.set_title("Input x\n(28 × 28)")
        ax1.axis("off")
        # y
        ax2 = plt.subplot(2, 4, 2)
        ax2.imshow(y_np, cmap="gray")
        ax2.set_title("Output y\n(28 × 28)")
        ax2.axis("off")
        # Reconstruction error
        ax3 = plt.subplot(2, 4, 3)
        diff = np.abs(x_np - y_np)
        ax3.imshow(diff, cmap="hot")
        ax3.set_title("Recon Error\n|x - y|")
        ax3.axis("off")
        # Random sample from the prior
        ax4 = plt.subplot(2, 4, 4)
        ax4.imshow(y_np_, cmap="gray")
        ax4.set_title("Sample y\n(28 × 28)")
        ax4.axis("off")
        # Latent channels: the second row has four slots, so at most four are shown
        cmaps = ["Blues", "Greens", "Reds", "Purples"]
        for i, (channel, cmap) in enumerate(zip(z_np, cmaps)):
            ax = plt.subplot(2, 4, i + 5)
            ax.imshow(channel, cmap=cmap)
            ax.set_title(f"Latent z Channel {i}\n(7 × 7)")
            ax.axis("off")
        plt.show()
if __name__ == "__main__":
    # Configuration; TrainValTest reads it as the module-level global `config`.
    config = dict(
        # root=r"/root/autodl-tmp/stable_diffusion",
        root=r"D:\Project\Transformer\stable_diffusion",
        ratio=(0.8, 0.2),  # NOTE(review): not read anywhere in this file
        batch_size=(50, 100),  # (train, validation)
        img_channels=1,
        latent_channels=4,
        epoch=500,
        lr=1e-4,
        acc_step=2,  # gradient-accumulation steps
        val_step=1,  # validate every N epochs
        model_round=60,  # checkpoints are saved only after this epoch
    )
    # Default device (module-level global `device`)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    # Image preprocessing (module-level global `tf`); 0.1307 / 0.3081 are
    # presumably the MNIST pixel mean / std.
    tf = T.Compose([
        T.Resize((28, 28)),
        T.ToTensor(),
        T.Normalize((0.1307,), (0.3081,)),
    ])
    train_val_test = TrainValTest()
    # train_val_test()  # uncomment to train
    TrainValTest.test()
# 4 训练结果 (training results; columns in the order written by _save_loss):
# recon_loss_train,kl_loss_train,total_loss_train,recon_loss_val,kl_loss_val,total_loss_val
# 0.0474,0.2545,0.1238,0.0471,0.2540,0.1233