@浙大疏锦行 Python day51
Review day: DDPM
python
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
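
# The original notes rely on an external `Tools` helper (not shown) for `gather`.
# A minimal sketch of what that helper is assumed to do: pick the schedule value
# for each sample's timestep t and reshape it so it broadcasts over
# (batch, channel, height, width) image tensors.
class Tools:
    def gather(self, consts: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # consts: 1-D schedule tensor of length n_steps; t: (batch_size,) long tensor
        c = consts.gather(-1, t)
        return c.reshape(-1, 1, 1, 1)
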
class DenoiseDiffusion():
def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
super().__init__()
self.eps_model = eps_model
self.n_steps = n_steps
self.device = device
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)  # beta schedule
        self.alpha = 1. - self.beta  # alpha values
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)  # alpha_bar: cumulative product of the alphas
        self.sigma2 = self.beta  # sigma_t^2 (variance) used during sampling
self.tools = Tools()
    # Forward diffusion: mean and variance of the Gaussian q(x_t | x_0) that x_t follows
def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
mean = self.tools.gather(self.alpha_bar, t) ** 0.5 * x0
var = 1 - self.tools.gather(self.alpha_bar, t)
return mean, var
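    # Worked equation (standard DDPM, matches the code above):
    #   q(x_t | x_0) = N( sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I )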
    # Forward diffusion: sample x_t from q(x_t | x_0)
def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
if eps is None:
eps = torch.randn_like(x0)
mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps  # x_t: the image after noise has been added up to step t
    # Used only at sampling time: performs one step of the denoising (reverse) process.
    # Sampling: derive x_{t-1} from x_t and t. This single abstracted step is meant to be
    # looped n_steps times (see the usage sketch at the end of the file).
def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
eps_theta = self.eps_model(xt, t)
alpha_bar = self.tools.gather(self.alpha_bar, t)
alpha = self.tools.gather(self.alpha, t)
eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5
mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
var = self.tools.gather(self.sigma2, t)
eps = torch.randn(xt.shape, device=xt.device)
return mean + (var ** 0.5) * eps # sigma_t * eps + mean
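    # Worked equations (DDPM reverse step, matching the code above):
    #   mu_theta(x_t, t) = 1/sqrt(alpha_t) * ( x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta(x_t, t) )
    #   x_{t-1} = mu_theta(x_t, t) + sigma_t * eps, with sigma_t^2 = beta_t and eps ~ N(0, I)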
    # Which parameters does this loss update? Only those of eps_model; the beta / alpha /
    # alpha_bar schedule tensors are fixed constants and are never trained.
# loss function
def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
batch_size = x0.shape[0]
t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
if noise is None:
noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)  # add the Gaussian noise to x0 to obtain x_t
        eps_theta = self.eps_model(xt, t)  # the model's predicted noise
return F.mse_loss(noise, eps_theta) # mse loss
# Activation function
class Swish(nn.Module):
def forward(self, x):
        return x * torch.sigmoid(x)
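# Note: x * sigmoid(x) is the same function PyTorch provides as nn.SiLU; it is written
# out by hand here for clarity.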
class ResidualBlock(nn.Module):
"""
    Each residual block uses two convolutional layers for feature extraction.
"""
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
n_groups: int = 32, dropout: float = 0.1):
"""
Params:
            in_channels: number of channels of the input feature map
            out_channels: number of channels of the feature map produced by this residual block
            time_channels: dimension of the time-embedding vector. For example, t starts out as an
                integer (e.g. 1, meaning timestep 1) and is turned into a vector of shape (1, time_channels).
            n_groups: hyperparameter for GroupNorm
            dropout: dropout rate
"""
super().__init__()
        # First conv layer = GroupNorm + Conv
self.norm1 = nn.GroupNorm(n_groups, in_channels)
self.act1 = Swish()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        # Second conv layer = GroupNorm + Conv
self.norm2 = nn.GroupNorm(n_groups, out_channels)
self.act2 = Swish()
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
        # If in_channels == out_channels, the residual connection simply adds input and output;
        # otherwise, a convolution first maps the input to out_channels before the addition.
if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))  # 1x1 conv to match the channel count
else:
            self.shortcut = nn.Identity()  # no-op placeholder
        # The time-embedding dimension time_channels may differ from out_channels,
        # so apply a linear projection to it first
self.time_emb = nn.Linear(time_channels, out_channels)
self.time_act = Swish()
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
Params:
            x: input x_t, shape (batch_size, in_channels, height, width)
            t: time embedding, shape (batch_size, time_channels)
        (Read together with the architecture diagram.)
"""
        # 1. Pass the input through the first conv layer
h = self.conv1(self.act1(self.norm1(x)))
        # 2. Project the time embedding from time_channels to out_channels with a linear layer,
        #    then add it to the feature map (broadcast over height and width)
h += self.time_emb(self.time_act(t))[:, :, None, None]
        # 3. Second conv layer
h = self.conv2(self.dropout(self.act2(self.norm2(h))))
        # 4. Return the result after the residual connection
return h + self.shortcut(x)
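
# A quick shape check (illustrative values, not from the original notes): a (2, 64, 16, 16)
# input plus a (2, 256) time embedding should come out as (2, 128, 16, 16).
_res_block = ResidualBlock(in_channels=64, out_channels=128, time_channels=256)
_h = _res_block(torch.randn(2, 64, 16, 16), torch.randn(2, 256))
assert _h.shape == (2, 128, 16, 16)
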
# Attention Block
# Self-attention over spatial positions (each pixel position is a token; its channels form the embedding)
class AttentionBlock(nn.Module):
"""
    Attention module.
    Same principle and implementation as multi-head attention in the Transformer.
"""
def __init__(self, n_channels: int, n_heads: int = 1, d_k: int = None, n_groups: int = 32):
"""
Params:
            n_channels: number of channels of the feature map the attention operates on
            n_heads: number of attention heads
            d_k: dimension handled by each attention head
            n_groups: hyperparameter for GroupNorm
"""
super().__init__()
        # Commonly d_k = n_channels // n_heads (n_channels must then be divisible by n_heads);
        # here the default is d_k = n_channels, paired with the default n_heads = 1
if d_k is None:
d_k = n_channels
        # GroupNorm
self.norm = nn.GroupNorm(n_groups, n_channels)
        # Multi-head attention: a single projection produces the q, k and v results for every token at once
self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Output projection (MLP)
self.output = nn.Linear(n_heads * d_k, n_channels)
self.scale = d_k ** -0.5
self.n_heads = n_heads
self.d_k = d_k
def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
"""
Params:
            x: input x_t, shape (batch_size, n_channels, height, width)
            t: time embedding, shape (batch_size, time_channels)
        (Read together with the architecture diagram.)
"""
        # t is unused; it is accepted only so the signature matches ResidualBlock
_ = t
        # Get the shape
batch_size, n_channels, height, width = x.shape
        # Reshape the input to (batch_size, height*width, n_channels); these dimensions play the
        # same roles as the Transformer's (batch_size, seq_length, token_embedding)
x = x.view(batch_size, n_channels, -1).permute(0, 2, 1)
        # One matmul through self.projection produces q, k and v together, i.e. the qkv tensor
        # is the concatenation of the three results;
        # its shape is (batch_size, height*width, n_heads, 3 * d_k)
qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split the concatenation; each of q, k, v has shape (batch_size, height*width, n_heads, d_k)
q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Standard attention-score computation
attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
attn = attn.softmax(dim=2)
res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Reshape the result to (batch_size, height*width, n_heads * d_k)
        # Recall: with the defaults above, n_heads * d_k = n_channels
res = res.view(batch_size, -1, self.n_heads * self.d_k)
        # Output projection; result shape (batch_size, height*width, n_channels)
res = self.output(res)
        # Residual connection
res += x
        # Convert the sequence back to image form,
        # shape (batch_size, n_channels, height, width)
res = res.permute(0, 2, 1).view(batch_size, n_channels, height, width)
return res
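
# A quick shape check (illustrative values, not from the original notes): with the default
# n_heads=1 and d_k=n_channels, the block preserves the input shape.
_attn_block = AttentionBlock(n_channels=128)
_h = _attn_block(torch.randn(2, 128, 8, 8))
assert _h.shape == (2, 128, 8, 8)
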
class DownBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
use_attention: bool = False):
super().__init__()
self.res_block = ResidualBlock(in_channels, out_channels, time_channels)
if use_attention:
self.attn_block = AttentionBlock(out_channels)
else:
self.attn_block = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res_block(x, t)
x = self.attn_block(x)
return x
class UpBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, time_channels: int,
use_attention: bool = False):
        super().__init__()
        # in_channels + out_channels because the Decoder concatenates the Encoder's skip
        # connection onto the input before this block (see UNet.forward)
        self.res_block = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
if use_attention:
self.attn = AttentionBlock(out_channels)
else:
self.attn = nn.Identity()
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res_block(x, t)
x = self.attn(x)
return x
class TimeEmbedding(nn.Module):
def __init__(self, n_channels: int):
"""
Params:
            n_channels: i.e. time_channels, the dimension of the produced time embedding
"""
super().__init__()
self.n_channels = n_channels
self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
self.act = Swish()
self.lin2 = nn.Linear(self.n_channels, self.n_channels)
def forward(self, t: torch.Tensor):
"""
Params:
            t: integer timesteps, shape (batch_size,)
"""
        # The transformation below is the same sinusoidal encoding the Transformer uses for positions
        # (strongly recommended: run it once and print each intermediate result and its shape)
half_dim = self.n_channels // 8
emb = math.log(10_000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
emb = t[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=1)
# Transform with the MLP
emb = self.act(self.lin1(emb))
emb = self.lin2(emb)
        # Output shape: (batch_size, time_channels)
return emb
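
# A quick shape check (illustrative values, not from the original notes): integer timesteps of
# shape (batch_size,) become vectors of shape (batch_size, n_channels).
_time_emb = TimeEmbedding(n_channels=256)
_e = _time_emb(torch.tensor([1, 100, 500]))
assert _e.shape == (3, 256)
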
class Upsample(nn.Module):
"""
    Upsampling
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
_ = t
return self.conv(x)
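# With kernel 4, stride 2 and padding 1, the ConvTranspose2d above exactly doubles the
# spatial size: output = (H - 1) * 2 - 2 * 1 + 4 = 2 * H.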
class Downsample(nn.Module):
"""
    Downsampling
"""
def __init__(self, n_channels):
super().__init__()
self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
_ = t
return self.conv(x)
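# With kernel 3, stride 2 and padding 1, the Conv2d above halves the spatial size
# (for even H): output = floor((H + 2 * 1 - 3) / 2) + 1 = H // 2.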
class MiddleBlock(nn.Module):
def __init__(self, n_channels: int, time_channels: int):
        super().__init__()
self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
self.attn = AttentionBlock(n_channels)
self.res2 = ResidualBlock(n_channels, n_channels, time_channels)
def forward(self, x: torch.Tensor, t: torch.Tensor):
x = self.res1(x, t)
x = self.attn(x)
x = self.res2(x, t)
return x
class UNet(nn.Module):
"""
    Main architecture of the DDPM UNet denoising model
"""
def __init__(self, image_channels: int = 3, n_channels: int = 64,
ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
n_blocks: int = 2):
"""
Params:
            image_channels: number of channels of the raw input image; 3 for RGB
            n_channels: before entering the UNet proper, the raw image goes through an initial
                convolution; this is that convolution's out_channels (the first dark-green arrow
                at the top left of the diagram)
            ch_mults: out_channels multiplier for each Encoder downsampling level. For example,
                ch_mults[i] = 2 means level i has twice the out_channels of level i-1. The
                Decoder's upsampling path works the same way with ch_mults reversed.
            is_attn: for each Encoder / Decoder level, whether to apply attention after the
                convolutional feature extraction (this structure is explained in detail below)
            n_blocks: number of DownBlocks / UpBlocks per Encoder / Decoder level (see diagram);
                each Decoder level ends up using n_blocks + 1 UpBlocks
"""
super().__init__()
        # As the Encoder downsamples / the Decoder upsamples, the image shrinks / grows,
        # and every change produces a new image resolution.
        # n_resolutions is the number of distinct resolutions, i.e. the number of Encoder/Decoder levels
n_resolutions = len(ch_mults)
        # Initial convolution on the raw image, e.g. in the diagram: 32*32*3 -> 32*32*64
self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
        # Time embedding. TimeEmbedding is an nn.Module subclass, defined earlier in this file
self.time_emb = TimeEmbedding(n_channels * 4)
# --------------------------
        # Define the Encoder
# --------------------------
        # The `down` list holds the modules (DownBlocks and Downsamples) that make up the Encoder
down = []
        # Initialise out_channels and in_channels
out_channels = in_channels = n_channels
        # Iterate over the levels
for i in range(n_resolutions):
            # Compute this level's out_channels according to ch_mults
out_channels = in_channels * ch_mults[i]
            # Each level has n_blocks DownBlocks
for _ in range(n_blocks):
down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
            # After each Encoder level we downsample once, except after the last level
if i < n_resolutions - 1:
down.append(Downsample(in_channels))
        # self.down is the complete Encoder
self.down = nn.ModuleList(down)
# --------------------------
        # Define the Middle part
# --------------------------
self.middle = MiddleBlock(out_channels, n_channels * 4, )
# --------------------------
        # Define the Decoder
# --------------------------
        # Largely mirrors the Encoder; read alongside the architecture diagram
up = []
in_channels = out_channels
for i in reversed(range(n_resolutions)):
# `n_blocks` at the same resolution
out_channels = in_channels
for _ in range(n_blocks):
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
out_channels = in_channels // ch_mults[i]
up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_attn[i]))
in_channels = out_channels
if i > 0:
up.append(Upsample(in_channels))
        # self.up is the complete Decoder
self.up = nn.ModuleList(up)
        # Final GroupNorm, activation and conv layer (restores the Decoder's topmost feature map
        # to the original number of image channels)
self.norm = nn.GroupNorm(8, n_channels)
self.act = Swish()
self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))
def forward(self, x: torch.Tensor, t: torch.Tensor):
"""
Params:
            x: input x_t, shape (batch_size, image_channels, height, width)
            t: integer timesteps, shape (batch_size,)
"""
        # Compute the time embedding
t = self.time_emb(t)
        # Initial convolution on the raw image
x = self.image_proj(x)
# -----------------------
# Encoder
# -----------------------
h = [x]
# First half of U-Net
for m in self.down:
x = m(x, t)
h.append(x)
# -----------------------
# Middle
# -----------------------
x = self.middle(x, t)
# -----------------------
# Decoder
# -----------------------
for m in self.up:
if isinstance(m, Upsample):
x = m(x, t)
else:
s = h.pop()
# skip_connection
x = torch.cat((x, s), dim=1)
x = m(x, t)
return self.final(self.act(self.norm(x)))
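

# ------------------------------------------------------------------
# End-to-end usage sketch (added for illustration; not part of the original notes).
# Deliberately small settings so it runs as a quick smoke test; the DDPM paper
# uses n_steps = 1000 and a larger network.
# ------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cpu")
    n_steps = 10  # tiny schedule, just for the smoke test
    eps_model = UNet(image_channels=3, n_channels=32)
    diffusion = DenoiseDiffusion(eps_model, n_steps=n_steps, device=device)

    # One training step: pick a random t per image, add noise, predict it, regress with MSE
    x0 = torch.randn(2, 3, 32, 32, device=device)  # stand-in for a real image batch
    loss = diffusion.loss(x0)
    loss.backward()  # gradients flow only into eps_model's parameters
    print("loss:", loss.item())

    # Reverse process: start from pure noise and apply p_sample for t = n_steps-1, ..., 0
    with torch.no_grad():
        x = torch.randn(2, 3, 32, 32, device=device)
        for step in reversed(range(n_steps)):
            t = x.new_full((x.shape[0],), step, dtype=torch.long)
            x = diffusion.p_sample(x, t)
    print("sampled batch shape:", tuple(x.shape))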