python
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')
text
Using device: cuda
python
# Load MNIST as tensors in [0, 1]; downloads to `root` on first run.
dataset = torchvision.datasets.MNIST(root=r"D:\Pycharm_Project\data\MNIST", train=True, download=True,
transform=torchvision.transforms.ToTensor())
# Small batch size here — this loader is only for visualization.
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
# make_grid tiles the batch; channel 0 suffices for grayscale images.
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');
text
Input shape: torch.Size([8, 1, 28, 28])
Labels: tensor([3, 5, 1, 7, 0, 6, 8, 6])

扩散模型的退化过程
python
def corrupt(x, amount):
    """Degrade `x` by blending it with uniform noise (the forward process).

    Args:
        x: image batch of shape (B, C, H, W).
        amount: corruption level(s) in [0, 1] — a tensor of shape (B,), or any
            scalar/sequence `torch.as_tensor` accepts (backward-compatible
            generalization). 0 returns `x` unchanged; 1 returns pure noise.

    Returns:
        Tensor with the same shape as `x`: `x*(1-amount) + noise*amount`.
    """
    # rand_like -> uniform noise in [0, 1); randn_like would give Gaussian noise.
    noise = torch.rand_like(x)
    # Reshape to (B, 1, 1, 1) (or (1, 1, 1, 1) for a scalar) so broadcasting
    # across the channel/height/width dimensions is unambiguous.
    amount = torch.as_tensor(amount, dtype=x.dtype, device=x.device).view(-1, 1, 1, 1)
    return x * (1 - amount) + noise * amount
python
# Visualize the degradation process.
# Top row: the clean input batch.
fig, axs = plt.subplots(2, 1, figsize=(7, 5))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
# Corruption amount ramps from 0 to 1 across the batch.
amount = torch.linspace(0, 1, x.shape[0]) # Left to right -> more corruption
noised_x = corrupt(x, amount)
# Bottom row: the corrupted versions of the same images.
axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys');

扩散模型的训练
python
class BasicUNet(nn.Module):
    """A minimal UNet: three conv "down" stages and three conv "up" stages,
    with skip connections between matching resolutions.

    Maps (B, in_channels, H, W) -> (B, out_channels, H, W); H and W must be
    divisible by 4 (two 2x downsamplings along the way).
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.down_layers = torch.nn.ModuleList([
            nn.Conv2d(in_channels, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
        ])
        self.up_layers = torch.nn.ModuleList([
            nn.Conv2d(64, 64, kernel_size=5, padding=2),
            nn.Conv2d(64, 32, kernel_size=5, padding=2),
            nn.Conv2d(32, out_channels, kernel_size=5, padding=2),
        ])
        # Activation applied after every conv (including the final one).
        self.act = nn.SiLU()
        self.downscale = nn.MaxPool2d(2)             # halves H and W
        self.upscale = nn.Upsample(scale_factor=2)   # doubles H and W

    def forward(self, x):
        skips = []  # activations saved for the skip connections
        for i, layer in enumerate(self.down_layers):
            x = self.act(layer(x))
            if i < 2:  # no downsampling after the last down layer
                skips.append(x)      # stash for the matching up layer
                x = self.downscale(x)
        for i, layer in enumerate(self.up_layers):
            if i > 0:  # no upsampling before the first up layer
                x = self.upscale(x)
                x = x + skips.pop()  # skip connection from the down path
            x = self.act(layer(x))
        return x
python
# Sanity check: an untrained net maps (8, 1, 28, 28) -> (8, 1, 28, 28).
net = BasicUNet()
x = torch.rand(8, 1, 28, 28)
net(x).shape
text
torch.Size([8, 1, 28, 28])
python
# Count the network's parameters (BasicUNet has 309,057; see output below).
sum(p.numel() for p in net.parameters())
309057
python
# --- Training hyper-parameters ---
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
n_epochs = 3
# --- Network, loss and optimizer ---
net = BasicUNet()
net.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []  # per-step loss history (plotted later)
# --- Training loop: the net learns to recover the clean image x from a
# --- randomly corrupted version of it.
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Labels y are unused — this is unconditional denoising.
        x = x.to(device)
        # One random corruption level per sample, then degrade the batch.
        noise_amount = torch.rand(x.shape[0], device=device)
        noised_x = corrupt(x, noise_amount)
        # Predict the clean image from the corrupted one.
        pred = net(noised_x)
        loss = loss_fn(pred, x)
        # Standard backprop + parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Average loss over the epoch that just finished.
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Epoch {epoch+1}/{n_epochs}: average loss {avg_loss:.5f}')
text
Epoch 1/3: average loss 0.02758
Epoch 2/3: average loss 0.02110
Epoch 3/3: average loss 0.01915
python
# Plot the per-step training loss curve.
plt.plot(losses)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()

对比输入数据、退化数据、预测数据
python
# Grab a batch, corrupt it with increasing amounts, and feed it to the model
# to see how well it denoises at each corruption level.
x, y = next(iter(train_dataloader))
x = x[:8].to(device)  # keep 8 samples so the image grid stays readable
# Corruption amount ramps from 0 (clean) to 1 (pure noise), left to right.
amount = torch.linspace(0, 1, x.shape[0], device=device)
noised_x = corrupt(x, amount)
# Inference only — no gradients needed.
with torch.no_grad():
    pred = net(noised_x).detach().cpu()
# Three rows: clean input / corrupted input / network output.
fig, axs = plt.subplots(3, 1, figsize=(12, 7))
axs[0].set_title('Input data')
axs[0].imshow(torchvision.utils.make_grid(x.cpu())[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Corrupted data')
axs[1].imshow(torchvision.utils.make_grid(noised_x.cpu())[0].clip(0, 1), cmap='Greys')
axs[2].set_title('Network output')
axs[2].imshow(torchvision.utils.make_grid(pred)[0].clip(0, 1), cmap='Greys')
text
<matplotlib.image.AxesImage at 0x206ffb066a0>

扩散模型的采样过程
采样过程的方案:从完全随机的噪声开始,先检查一下模型的预测结果,然后只朝着预测方向移动一小部分。
可以理解为:将带有噪声的图像输入到模型中,得到一个预测输出,如果当前输出结果稍微好一点,那么将这次的预测输出重新作为输入再次输入到模型
python
def sample_with_step(x, n_steps, model=None):
    """Iteratively denoise `x` by moving a fraction of the way toward the
    model's prediction at each of `n_steps` steps.

    Args:
        x: starting tensor (typically pure random noise), shape (B, C, H, W).
        n_steps: number of denoising steps.
        model: denoising network to use; defaults to the module-level `net`
            (keeps the original call sites working unchanged).

    Returns:
        Tuple of (final x,
                  list of x after each step — including the starting x,
                  list of the raw model prediction at each step).
    """
    if model is None:
        model = net  # fall back to the globally trained network
    step_history = [x.detach().cpu()]
    pred_output_history = []
    for i in range(n_steps):
        with torch.no_grad():  # inference only
            pred = model(x)  # model's guess at the fully denoised image
        pred_output_history.append(pred.detach().cpu())
        # Move only 1/(steps remaining) of the way toward the prediction,
        # so the final step lands exactly on it.
        mix_factor = 1 / (n_steps - i)
        x = x * (1 - mix_factor) + pred * mix_factor
        step_history.append(x.detach().cpu())
    return x, step_history, pred_output_history
python
n_steps = 5
# Start from completely random values (pure noise).
x = torch.rand(8, 1, 28, 28).to(device)
x, step_history, pred_output_history = sample_with_step(x, n_steps)
python
# Visualize sampling: left column shows the model input at each step,
# right column shows the model's raw prediction at that step.
fig, axs = plt.subplots(n_steps, 2, figsize=(9, 4), sharex=True)
axs[0, 0].set_title('x (model input)')
axs[0, 1].set_title('model prediction')
for i in range(n_steps):
    axs[i, 0].imshow(torchvision.utils.make_grid(step_history[i])[0].clip(0, 1), cmap='Greys')
    axs[i, 1].imshow(torchvision.utils.make_grid(pred_output_history[i])[0].clip(0, 1), cmap='Greys')

也可以将采样过程拆解成更多步,以获得质量更高的图像
python
# More (smaller) steps generally give higher-quality samples.
n_steps = 40
# Start from pure random noise.
x = torch.rand(64, 1, 28, 28).to(device)
for i in range(n_steps):
    with torch.no_grad():
        pred = net(x)  # predicted fully-denoised image
    # Step 1/(steps remaining) of the way toward the prediction.
    # (The original also built a `noise_amount` tensor here, but it was never
    # used, so it has been removed.)
    mix_factor = 1 / (n_steps - i)
    x = x * (1 - mix_factor) + pred * mix_factor
# Show the 64 generated samples in an 8x8 grid.
fig, axs = plt.subplots(1, 1, figsize=(12, 12))
axs.imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
text
<matplotlib.image.AxesImage at 0x2069bb03e50>

UNet2DModel
UNet2DModel与DDPM对比:
- UNet2DModel比BasicUNet更先进。
- 退化过程的处理方式不同。
- 训练目标不同,包括预测噪声而不是去噪图像。
- UNet2DModel模型通过调节时间步来调节噪声量, 其中t作为一个额外参数传入前向过程中。
Diffusers库中的UNet2DModel模型比BasicUNet模型有如下改进:
- GroupNorm层对每个模块的输入进行了组标准化(group normalization)。
- Dropout层能使训练更平滑。
- 每个块有多个ResNet层(如果layers_per_block未设置为1)。
- 引入了注意力机制(通常仅用于输入分辨率较低的blocks)。
- 可以对时间步进行调节。
- 具有可学习参数的上采样模块和下采样模块。
python
# Build a diffusers UNet sized for 28x28 single-channel MNIST images.
net = UNet2DModel(
    sample_size=28,        # target image resolution
    in_channels=1,         # grayscale input
    out_channels=1,        # grayscale output
    layers_per_block=2,    # number of ResNet layers per UNet block
    block_out_channels=(32, 64, 64),
    down_block_types=(
        "DownBlock2D",     # plain ResNet downsampling block
        "AttnDownBlock2D", # downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",   # upsampling block with spatial self-attention
        "UpBlock2D",       # plain ResNet upsampling block
    ),
)
python
sum([p.numel() for p in net.parameters()])
# UNet2DModel has roughly 1.7M parameters, versus ~0.3M for BasicUNet.
1707009
python
# --- Train the diffusers UNet2DModel on the same denoising task ---
batch_size = 128
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
n_epochs = 3
net.to(device)
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []
for epoch in range(n_epochs):
    for x, y in train_dataloader:
        # Labels y are unused — unconditional denoising again.
        x = x.to(device)
        # One random corruption level per sample, then degrade the batch.
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)
        # UNet2DModel also takes a timestep argument; a constant 0 is passed
        # because here the corruption level is carried by the input itself.
        pred = net(noisy_x, 0).sample
        loss = loss_fn(pred, x)
        # Backprop and update the parameters.
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    # Mean loss over this epoch's batches.
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'Finished epoch {epoch}. Average loss for this epoch: {avg_loss:05f}')
text
Finished epoch 0. Average loss for this epoch: 0.018955
Finished epoch 1. Average loss for this epoch: 0.012771
Finished epoch 2. Average loss for this epoch: 0.011652
python
# Left panel: training loss curve. Right panel: samples generated in 100 steps.
fig, axs = plt.subplots(1, 2, figsize=(8, 3))
axs[0].plot(losses)
axs[0].set_ylim(0, 0.1)
axs[0].set_title('Loss over time')
n_steps = 100
x = torch.rand(64, 1, 28, 28).to(device)  # start from pure noise
for i in range(n_steps):
    with torch.no_grad():
        pred = net(x, 0).sample  # constant timestep 0, matching training
    # Step 1/(steps remaining) of the way toward the prediction.
    # (An unused `noise_amount` tensor from the original loop was removed.)
    mix_factor = 1 / (n_steps - i)
    x = x * (1 - mix_factor) + pred * mix_factor
axs[1].imshow(torchvision.utils.make_grid(x.detach().cpu(), nrow=8)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('Generated Samples');

扩散模型的退化过程示例
退化过程:
Given $x_{t-1}$ at some timestep, the forward (degradation) process produces a slightly noisier $x_t$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right), \qquad q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
python
# Demonstrate the DDPM forward (degradation) process via the diffusers scheduler.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
fig, axs = plt.subplots(3, 1, figsize=(10, 6))
xb, yb = next(iter(train_dataloader))
xb = xb.to(device)[:8]
xb = xb * 2. - 1.  # map pixel range [0, 1] -> [-1, 1], as DDPM expects
print('X shape', xb.shape)
# Row 0: the clean inputs.
axs[0].imshow(torchvision.utils.make_grid(xb[:8])[0].detach().cpu(), cmap='Greys')
axs[0].set_title('Clean X')
# Add noise at 8 evenly spaced timesteps (0 .. 999), one timestep per image.
timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)  # Gaussian noise, per the DDPM formulation
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print('Noisy X shape', noisy_xb.shape)
# Row 1: noisy versions clipped to the displayable range; row 2: unclipped.
axs[1].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu().clip(-1, 1), cmap='Greys')
axs[1].set_title('Noisy X (clipped to (-1, 1))')  # fixed: title was missing its closing ')'
axs[2].imshow(torchvision.utils.make_grid(noisy_xb[:8])[0].detach().cpu(), cmap='Greys')
axs[2].set_title('Noisy X');
text
X shape torch.Size([8, 1, 28, 28])
Noisy X shape torch.Size([8, 1, 28, 28])

- 模型会预测退化过程中使用的噪声。
- 预测噪声这个目标会使权重更倾向于预测得到更低的噪声量。
扩展知识
时间步可以转换为embedding,在多个地方被输入模型。
输入纯噪声,在模型预测的基础上使用足够多的小步,不断迭代,每次去除一点点噪声。