1 Test results:

2 Model implementation:
```python
import math
import os
import time
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn, Tensor
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import transforms as T
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel
from scheduler import TimeStepScheduler
from vae import VAEWithLoss
# Text encoder
class TextEncoder(nn.Module):
    def __init__(self, path: str = None):
        super(TextEncoder, self).__init__()
        if path is None:  # load from the Hugging Face Hub
            path = r"openai/clip-vit-base-patch32"
        # Tokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained(path)
        # Text model
        self.encoder = CLIPTextModel.from_pretrained(path).eval()

    def forward(self, texts: list[str] | tuple[str, ...]) -> Tensor:
        with torch.no_grad():
            inputs = self.tokenizer(
                texts,
                add_special_tokens=False,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            )
            outputs = self.encoder(return_dict=True, **inputs)
        return outputs["last_hidden_state"]  # (batch, seq_len, 512)
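
# Illustrative usage (hypothetical prompts; the CLIP ViT-B/32 text encoder
# has a hidden width of 512, matching text_dim below):
#   encoder = TextEncoder()
#   emb = encoder(["a handwritten 3", "a handwritten 7"])
#   emb.shape  # -> (2, seq_len, 512)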
# Timestep embedding
class TimestepEmbedding(nn.Module):
    def __init__(self,
                 max_step: int = 1000,
                 d_model: int = 512):
        super(TimestepEmbedding, self).__init__()
        # Initialize the sinusoidal position table
        pe = torch.zeros((max_step, d_model))
        position = torch.arange(0, max_step).unsqueeze(1)
        div_term = torch.exp(
            # 1 / 10000 ** (2 * k / d)
            torch.arange(0, d_model, 2) * -(math.log(10000) / d_model)
        )
        # Even positions
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd positions
        pe[:, 1::2] = torch.cos(position * div_term)
        # Cached buffer, (max_step, d_model)
        self.register_buffer("pe", pe)
        # MLP head
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.SiLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, t: Tensor) -> Tensor:
        if t.dim() > 1:  # (batch, 1)
            t = t.squeeze(-1)
        t = t.long()
        return self.mlp(self.pe[t])  # (batch, d_model)
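
# Illustrative usage: integer timesteps index the sinusoidal table, and the
# result passes through the MLP:
#   embed = TimestepEmbedding(max_step=1000, d_model=512)
#   embed(torch.tensor([0, 999])).shape  # -> (2, 512)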
# Standard residual block
class ResNetBlock(nn.Module):
    def __init__(self, c1: int, c2: int,
                 c: int = None, d: int = 512):
        super(ResNetBlock, self).__init__()
        if c is None:  # hidden channel count
            c = c2
        # Convolution layers
        self.layer1 = nn.Sequential(
            nn.Conv2d(c1, c, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, c),
            nn.SiLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(c, c2, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(32, c2),
            nn.SiLU(),
        )
        # Projection for the time embedding
        self.fc = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d, c)  # time_dim -> hidden channels
        )
        # Residual connection
        self.shortcut = nn.Conv2d(c1, c2, kernel_size=1) \
            if c1 != c2 else nn.Identity()

    def forward(self, x: Tensor, e: Tensor) -> Tensor:
        y = self.layer1(x)
        # Add the timestep embedding
        e = self.fc(e)
        e = e.view(e.size(0), -1, 1, 1)  # (batch, c, 1, 1)
        y = self.layer2(y + e)
        return y + self.shortcut(x)  # (batch, c2, height, width)
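
# Shape example (illustrative): with x of shape (B, 64, 7, 7) and a time
# embedding e of shape (B, 512), ResNetBlock(64, 128)(x, e) returns
# (B, 128, 7, 7); the 1x1 shortcut reconciles the channel change.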
# Cross-attention
class CrossAttention(nn.Module):
    def __init__(self, d1: int, d2: int):
        super(CrossAttention, self).__init__()
        # Linear projections
        self.fc_q = nn.Linear(d1, d1)
        self.fc_k = nn.Linear(d2, d1)
        self.fc_v = nn.Linear(d2, d1)
        self.fc_out = nn.Linear(d1, d1)
        # Attention layer
        self.attn = nn.MultiheadAttention(d1, num_heads=8, batch_first=True)

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        q = self.fc_q(x1)
        k = self.fc_k(x2)
        v = self.fc_v(x2)
        y, _ = self.attn(q, k, v)
        return self.fc_out(y)  # (batch, seq_len, d1)
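
# In this model, x1 holds flattened image tokens (B, H*W, d1) acting as
# queries, while x2 is the CLIP text sequence (B, seq_len, d2) supplying
# keys and values, so the output keeps the image-token length: (B, H*W, d1).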
# Residual block with cross-attention
class CrossAttnResNetBlock(nn.Module):
    def __init__(self, c1: int, c2: int, c: int = None,
                 d1: int = 512, d2: int = 512):
        super(CrossAttnResNetBlock, self).__init__()
        # Standard residual block
        self.res = ResNetBlock(c1, c2, c, d1)
        # Normalization layer
        self.norm = nn.GroupNorm(32, c2)
        # Cross-attention layer
        self.attn = CrossAttention(c2, d2)  # text_dim

    def forward(self, x: Tensor, e: Tensor, text: Tensor) -> Tensor:
        y = self.res(x, e)
        residual = y
        batch, channel, height, width = y.shape
        y = self.norm(y)
        # (batch, height * width, channel)
        y = y.view(batch, channel, -1).transpose(1, 2)
        y = self.attn(y, text)
        y = y.transpose(1, 2).view(batch, channel, height, width)
        return y + residual  # (batch, c2, height, width)
# Encoder stage
class Down(nn.Module):
    def __init__(self, c1: int, c2: int,
                 down: bool = False, **kwargs):
        super(Down, self).__init__()
        # Residual block with cross-attention
        self.res = CrossAttnResNetBlock(c1, c2, **kwargs)
        # Optional downsampling layer
        self.down = nn.Conv2d(c2, c2, kernel_size=3, stride=2, padding=1) \
            if down else nn.Identity()

    def forward(self, x: Tensor, e: Tensor, text: Tensor) -> Tensor:
        return self.down(self.res(x, e, text))
# Middle stage
class Mid(nn.Module):
    def __init__(self, c: int, **kwargs):
        super(Mid, self).__init__()
        # Residual blocks with cross-attention
        self.res1 = CrossAttnResNetBlock(c, c, **kwargs)
        self.res2 = CrossAttnResNetBlock(c, c, **kwargs)
        # Joint Q/K/V projection
        self.layers = nn.Sequential(
            nn.GroupNorm(32, c),
            nn.Conv2d(c, 3 * c, kernel_size=3, padding=1, bias=False),
        )
        # Self-attention layer
        self.attn = nn.MultiheadAttention(c, num_heads=8, batch_first=True)

    def forward(self, x: Tensor, e: Tensor, text: Tensor) -> Tensor:
        y = self.res1(x, e, text)
        # Self-attention over the image's own features
        batch, channel, height, width = y.shape
        # Compute Q, K, V in one shot from the res1 output
        qkv = self.layers(y)
        qkv = qkv.view(batch, 3 * channel, -1).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=-1)
        y, _ = self.attn(q, k, v)
        y = y.transpose(1, 2).view(batch, channel, height, width)
        return self.res2(y, e, text)
# Decoder stage
class Up(nn.Module):
    def __init__(self, c1: int, c2: int,
                 up: bool = False, **kwargs):
        super(Up, self).__init__()
        # Halve the channels; the transposed convolution also doubles the
        # spatial size, the plain convolution keeps it
        self.up = nn.ConvTranspose2d(c1, c1 // 2, kernel_size=2, stride=2) \
            if up else nn.Conv2d(c1, c1 // 2, kernel_size=3, padding=1)
        # Residual block with cross-attention
        self.res = CrossAttnResNetBlock(c1, c2, **kwargs)

    def forward(self, x1: Tensor, x2: Tensor, e: Tensor, text: Tensor) -> Tensor:
        x1 = self.up(x1)
        # Skip connection: concatenation restores c1 channels
        x = torch.cat([x1, x2], dim=1)
        return self.res(x, e, text)
# U-Net
class UNet(nn.Module):
    def __init__(self,
                 max_step: int = 1000,
                 beta_mode: str = "linear",
                 time_dim: int = 512,
                 text_dim: int = 512):
        super(UNet, self).__init__()
        # Diffusion timestep scheduler (relies on the module-level `device`
        # defined in the main block)
        self.scheduler = TimeStepScheduler(
            max_step,
            beta_mode=beta_mode,
            device=device,
        )
        # Timestep embedding
        self.embed = TimestepEmbedding(max_step, time_dim)
        # Input layer
        self.conv_in = nn.Conv2d(4, 64, kernel_size=3, padding=1)
        # Encoder stages
        param = {"d1": time_dim, "d2": text_dim}
        self.down1 = Down(64, 128, **param)
        self.down2 = Down(128, 256, **param)
        # Middle stage
        self.mid = Mid(256, **param)
        # Decoder stages
        self.up1 = Up(256, 128, **param)
        self.up2 = Up(128, 64, **param)
        # Output layer
        self.conv_out = nn.Conv2d(64, 4, kernel_size=3, padding=1)

    def forward(self, x: Tensor, text: Tensor) -> tuple[Tensor, ...]:
        ret = self.scheduler(x)
        # Predict the noise
        pred_noise = self._steps(ret["x_t"], ret["t"], text)
        return pred_noise, ret["noise"]

    # Prediction (sampling)
    def predict(self, **kwargs) -> tuple[Tensor, list[Tensor]]:
        self.eval()  # evaluation mode
        with torch.no_grad():
            # See TimeStepScheduler for the sampling arguments
            return self.scheduler.predict(self._steps, **kwargs)

    # One denoising pass through the encoder-decoder
    def _steps(self, x: Tensor, t: Tensor, text: Tensor) -> Tensor:
        e = self.embed(t)
        x1 = self.conv_in(x)
        x2 = self.down1(x1, e, text)
        x3 = self.down2(x2, e, text)
        y = self.mid(x3, e, text)
        y = self.up1(y, x2, e, text)
        y = self.up2(y, x1, e, text)
        return self.conv_out(y)
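
# Latent shape flow (illustrative; neither Down nor Up is constructed with
# down=True / up=True above, so the spatial size of the VAE latents, e.g.
# 7 x 7 for MNIST, is preserved throughout):
# (B, 4, 7, 7) -conv_in-> (B, 64, 7, 7) -down1-> (B, 128, 7, 7)
# -down2-> (B, 256, 7, 7) -mid-> (B, 256, 7, 7) -up1-> (B, 128, 7, 7)
# -up2-> (B, 64, 7, 7) -conv_out-> (B, 4, 7, 7)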
# Stable Diffusion
class StableDiffusion(nn.Module):
    def __init__(self, path_vae: str, **kwargs):
        super(StableDiffusion, self).__init__()
        # VAE
        self.vae = self._load_vae(path_vae)
        # U-Net
        self.u_net = UNet(**kwargs)

    def forward(self, x: Tensor, text: Tensor) -> tuple[Tensor, ...]:
        x = self.vae.encode(x)
        return self.u_net(x, text)

    # Prediction (sampling)
    def predict(self, **kwargs) -> tuple[Tensor, list[Tensor]]:
        self.eval()
        with torch.no_grad():
            x, lst = self.u_net.predict(**kwargs)
            return self.vae.decode(x), lst

    # Load the VAE model
    @staticmethod
    def _load_vae(path: str) -> nn.Module:
        info = torch.load(path, map_location=device)
        model = VAEWithLoss(**info["param"]).to(device)
        model.load_state_dict(info["weight"])
        model.eval()
        # Freeze the parameters
        for param in model.parameters():
            param.requires_grad = False
        return model
# Stable Diffusion with MSE loss
class SDWithLoss(nn.Module):
    def __init__(self, **kwargs):
        super(SDWithLoss, self).__init__()
        self.model = StableDiffusion(**kwargs)
        # Mean-squared-error loss
        self.criterion = nn.MSELoss()

    def forward(self, x: Tensor, text: Tensor) -> Tensor:
        pred_noise, noise = self.model(x, text)
        return self.criterion(pred_noise, noise)

    # Prediction (sampling)
    def predict(self, **kwargs) -> tuple[Tensor, list[Tensor]]:
        self.eval()
        with torch.no_grad():
            return self.model.predict(**kwargs)
```
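
Before wiring up the full training loop, the loss path can be smoke-tested with a minimal sketch like the one below. It is illustrative only: it assumes the local `scheduler.py`/`vae.py` modules are importable, a trained VAE checkpoint exists at the hypothetical path shown, and a module-level `device` is defined before `UNet` is constructed (as in the main block of the next section).

```python
# Minimal smoke-test sketch; not part of the original script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SDWithLoss(path_vae="result/final/model_vae.pt").to(device)  # hypothetical checkpoint path
encoder = TextEncoder()  # downloads openai/clip-vit-base-patch32 if no local path is given

x = torch.randn(2, 1, 28, 28, device=device)  # stand-in MNIST batch
text = encoder(["3", "7"]).to(device)         # (2, seq_len, 512)

loss = model(x, text)  # MSE between predicted and true noise
loss.backward()
print(loss.item())
```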
3 Training and validation:
```python
# Main routine
class TrainValTest:
    def __init__(self):
        self.loader_train, self.loader_val = None, None  # DataLoaders
        self.model = None  # model
        self.optimizer = None  # optimizer
        self.scheduler = None  # learning-rate scheduler
        self._init_score()  # metrics
        self.round = 0  # current training round
    def __call__(self):
        print("Loading dataset...")
        self._load_dataset()
        print("Creating model...")
        self._create_model()
        print("Starting training...")
        for i in range(config["epoch"]):
            # Reset per-round state
            self.round = i + 1
            self._init_score()
            self.optimizer.zero_grad()
            time.sleep(1)
            # Train
            self._train()
            time.sleep(1)
            # Validate
            if self.round % config["val_step"] == 0:
                self._val()
                time.sleep(1)
            if torch.cuda.is_available():  # free cached memory
                torch.cuda.empty_cache()
            # Update the learning rate
            self.scheduler.step()
            # Save the training metrics
            self._save_loss()
            # Save the model
            if self.round > config["model_round"]:
                self._save_model()
    # Reset the metrics
    def _init_score(self):
        self.loss_train, self.loss_val = 0, 0

    # Load the dataset
    def _load_dataset(self):
        # Download the MNIST handwritten-digit dataset
        train_dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=True, download=True, transform=tf,
        )
        test_dataset = torchvision.datasets.MNIST(
            root=os.path.join(config["root"], "data"),
            train=False, download=True, transform=tf,
        )
        self.loader_train = DataLoader(
            train_dataset,
            batch_size=config["batch_size"][0],
            shuffle=True,
            collate_fn=self._fun,
        )
        self.loader_val = DataLoader(
            test_dataset,
            batch_size=config["batch_size"][1],
            shuffle=False,
            collate_fn=self._fun,
        )
        print(
            f"Training batches: {len(self.loader_train)}",
            f"Validation batches: {len(self.loader_val)}",
            sep=","
        )
    # Create the model
    def _create_model(self):
        self.model = SDWithLoss(
            path_vae=config["path_vae"],
            max_step=config["max_step"],
            beta_mode=config["beta_mode"],
            time_dim=config["time_dim"],
            text_dim=config["text_dim"],
        ).to(device)
        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config["lr"],
        )
        self.scheduler = CosineAnnealingLR(
            self.optimizer,
            T_max=config["epoch"],
            eta_min=config["lr"] * 0.1,
        )
        num_params = sum(p.numel() for p in self.model.parameters())
        print(f"Total parameters: {num_params}")
    # Train
    def _train(self):
        self.model.train()  # training mode
        length = len(self.loader_train)
        for i, data in enumerate(tqdm(self.loader_train)):
            # Forward pass
            loss = self.model(*data)
            # Accumulate the epoch loss
            self.loss_train += loss.item() / length
            # Backward pass with gradient accumulation
            loss /= config["acc_step"]
            loss.backward()
            # Update the parameters every acc_step batches
            if (i + 1) % config["acc_step"] == 0:
                self._update_grad()
            if torch.cuda.is_available():  # free cached memory
                torch.cuda.empty_cache()
        # Flush the last, incomplete accumulation batch
        if length % config["acc_step"] != 0:
            self._update_grad()
        # Log
        print(
            f"Round {self.round}",
            "training loss: {:.4f}".format(self.loss_train),
            sep=","
        )
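
    # Note: with batch_size 50 and acc_step 2 (see config), each optimizer
    # step effectively covers 50 * 2 = 100 samples; dividing the loss by
    # acc_step keeps the accumulated gradient on the same scale as one
    # large batch.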
    # Validate
    def _val(self):
        self.model.eval()  # evaluation mode
        length = len(self.loader_val)
        with torch.no_grad():
            for data in tqdm(self.loader_val):
                loss = self.model(*data)
                self.loss_val += loss.item() / length
        print(
            f"Round {self.round}",
            "validation loss: {:.4f}".format(self.loss_val),
            sep=","
        )
    # Save the training metrics
    def _save_loss(self):
        with open(
                os.path.join(
                    config["root"],
                    "result", "temp",
                    "loss.csv"
                ), "a+", encoding="utf-8"
        ) as f:
            f.write("{:.4f}".format(self.loss_train) + "," +
                    "{:.4f}".format(self.loss_val) + "\n")
        print(
            f"Round {self.round}",
            "training metrics saved...\n",
            sep=","
        )
    # Save the model
    def _save_model(self):
        info = {
            "weight": self.model.state_dict(),
            "param": {
                "max_step": config["max_step"],
                "beta_mode": config["beta_mode"],
                "time_dim": config["time_dim"],
                "text_dim": config["text_dim"],
            },
            "result": {
                "loss_train": self.loss_train,
                "loss_val": self.loss_val,
            }
        }
        name = f"model_{self.round}.pt"
        torch.save(
            info,
            os.path.join(
                config["root"],
                "result", "temp",
                name
            )
        )
        print(
            f"Round {self.round}",
            f"model saved: {name}...\n",
            sep=","
        )
    # Apply the accumulated gradients
    def _update_grad(self):
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
    # collate_fn: encode the digit labels as text prompts
    @staticmethod
    def _fun(batch):
        x, text = zip(*batch)
        x = torch.stack(x, dim=0).to(device)
        text = [str(item) for item in text]
        text = text_encoder(text).to(device)
        return x, text
    # Test
    @staticmethod
    def test():
        # Load the model
        info = torch.load(
            os.path.join(
                config["root"],
                "result/final/model_sd.pt",
            ),
            map_location=device,
        )
        print("Model training metrics:", info["result"])
        model = SDWithLoss(
            path_vae=config["path_vae"],
            **info["param"],
        ).to(device)
        model.load_state_dict(info["weight"])
        model.eval()
        # Sample from the model
        out, lst = model.predict(
            shape=(10, 4, 7, 7),
            text=text_encoder([
                "0", "1", "2", "3", "4",
                "5", "6", "7", "8", "9",
            ]).to(device),
            sampling_mode="ddim",
            step=100,
        )
        out_np = out.squeeze(1).detach().cpu().numpy()  # (10, 28, 28)
        plt.figure(figsize=(15, 10))
        for i in range(10):
            ax = plt.subplot(2, 5, i + 1)
            ax.imshow(out_np[i])
            ax.set_title(f"Number '{i}'\n(28 × 28)")
            ax.axis("off")
        plt.show()
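
    # Note: sampling_mode="ddim" with step=100 above asks the local
    # TimeStepScheduler to denoise in 100 steps rather than the full
    # max_step=1000 chain; the exact keyword names are defined in
    # scheduler.py and are assumed here.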
if __name__ == '__main__':
    # Configuration
    config = {
        # "root": r"/root/autodl-tmp/stable_diffusion",
        "root": r"D:\Project\Transformer\stable_diffusion",
        # "path_text_encoder": r"/root/autodl-tmp/stable_diffusion/clip-vit-base-patch32",
        "path_text_encoder": r"D:\Project\Transformer\ztool\tokenizer\clip-vit-base-patch32",
        "ratio": (0.8, 0.2),
        "batch_size": (50, 100),
        # "path_vae": r"/root/autodl-tmp/stable_diffusion/result/final/model.pt",
        "path_vae": r"D:\Project\Transformer\stable_diffusion\result\final\model_vae.pt",
        "max_step": 1000,
        "beta_mode": "linear",
        "time_dim": 512,
        "text_dim": 512,
        "epoch": 50,
        "lr": 5e-4,
        "acc_step": 2,
        "val_step": 1,
        "model_round": 30,
    }
    # Default device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Image preprocessing
    tf = T.Compose([
        T.Resize((28, 28)),
        T.ToTensor(),
        T.Normalize((0.1307,), (0.3081,))
    ])
    # Text encoder
    text_encoder = TextEncoder(config["path_text_encoder"])
    train_val_test = TrainValTest()
    # train_val_test()
    TrainValTest.test()
```
4 Training results (final row of loss.csv):
loss_train,loss_val
0.0483,0.0486