导入Python包
python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
SiLU激活函数
python
class SiLU(nn.Module):
    """Sigmoid Linear Unit activation: ``x * sigmoid(x)`` applied element-wise."""

    @staticmethod
    def forward(x):
        """Return ``x`` multiplied element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate
归一化设置
python
def get_norm(norm, num_channels, num_groups):
    """Build the 2-D normalization layer selected by ``norm``.

    Args:
        norm: ``"in"`` for affine instance norm, ``"bn"`` for batch norm,
            ``"gn"`` for group norm, or ``None`` for a pass-through layer.
        num_channels: Channel count handed to the normalization layer.
        num_groups: Group count; only read when ``norm == "gn"``.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``norm`` is none of the values listed above.
    """
    if norm == "in":
        return nn.InstanceNorm2d(num_channels, affine=True)
    if norm == "bn":
        return nn.BatchNorm2d(num_channels)
    if norm == "gn":
        return nn.GroupNorm(num_groups, num_channels)
    if norm is None:
        return nn.Identity()
    raise ValueError("unknown normalization type")
计算时间步的位置嵌入:前一半维度为sin,后一半维度为cos
python
class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding whose first half is sines and second half cosines.

    Args:
        dim: Width of the produced embedding; must be even.
        scale: Factor the input is multiplied by before it is embedded.
    """

    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        """Embed ``x`` (a 1-D tensor, as ``torch.outer`` requires) into ``(len(x), dim)``."""
        half_dim = self.dim // 2
        # Frequencies decay geometrically: exp(-i * log(10000) / half_dim).
        log_step = math.log(10000) / half_dim
        indices = torch.arange(half_dim, device=x.device)
        freqs = torch.exp(indices * -log_step)
        # Outer product pairs every scaled input value with every frequency.
        angles = torch.outer(x * self.scale, freqs)
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
上下采样层设置
python
class Downsample(nn.Module):
    """Halve the spatial size of a feature map with a stride-2 3x3 convolution.

    The channel count is left unchanged.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.downsample = nn.Conv2d(in_channels, in_channels, 3, stride=2, padding=1)

    def forward(self, x, time_emb=None, y=None):
        """Downsample ``x``.

        Args:
            x: 4-D input; dimensions 2 and 3 (height, width) must be even.
            time_emb: Ignored. Accepted so this layer can be called with the
                same arguments as ``ResidualBlock``.
            y: Ignored, for the same reason as ``time_emb``.

        Returns:
            The output of the stride-2 convolution.

        Raises:
            ValueError: If the height or width of ``x`` is odd.
        """
        if x.shape[2] % 2 == 1:
            raise ValueError("downsampling tensor height should be even")
        if x.shape[3] % 2 == 1:
            raise ValueError("downsampling tensor width should be even")
        return self.downsample(x)
class Upsample(nn.Module):
    """Double the spatial size of a feature map.

    Nearest-neighbour upsampling by a factor of 2 is followed by a 3x3
    convolution; the channel count is left unchanged.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
        )

    def forward(self, x, time_emb=None, y=None):
        """Upsample ``x``.

        Args:
            x: Input feature map.
            time_emb: Ignored. Accepted so this layer can be called with the
                same arguments as ``ResidualBlock``.
            y: Ignored, for the same reason as ``time_emb``.

        Returns:
            The upsampled and convolved feature map.
        """
        return self.upsample(x)
使用Self-Attention注意力机制,做一个全局的Self-Attention
python
class AttentionBlock(nn.Module):
    """Single-head self-attention across all spatial positions, plus a residual.

    Args:
        in_channels: Number of channels of the input (and output).
        norm: Normalization type forwarded to ``get_norm``.
        num_groups: Group count forwarded to ``get_norm``.
    """

    def __init__(self, in_channels, norm="gn", num_groups=32):
        super().__init__()
        self.in_channels = in_channels
        self.norm = get_norm(norm, in_channels, num_groups)
        # One 1x1 convolution produces queries, keys and values stacked on the channel axis.
        self.to_qkv = nn.Conv2d(in_channels, in_channels * 3, 1)
        self.to_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        """Attend over the ``h * w`` positions of ``x`` and add the result to ``x``."""
        b, c, h, w = x.shape
        n = h * w

        qkv = self.to_qkv(self.norm(x))
        q, k, v = qkv.split(self.in_channels, dim=1)

        # Collapse the spatial grid; queries and values become (b, n, c), keys stay (b, c, n).
        q = q.view(b, c, n).transpose(1, 2)
        k = k.view(b, c, n)
        v = v.view(b, c, n).transpose(1, 2)

        # Scaled dot-product scores between every pair of positions.
        scores = torch.bmm(q, k) * (c ** (-0.5))
        assert scores.shape == (b, n, n)
        weights = scores.softmax(dim=-1)

        mixed = torch.bmm(weights, v)
        assert mixed.shape == (b, n, c)

        # Restore the (b, c, h, w) layout before the output projection.
        mixed = mixed.transpose(1, 2).view(b, c, h, w)
        return self.to_out(mixed) + x
用于特征提取的残差结构
python
class ResidualBlock(nn.Module):
    """Two-convolution residual block with optional time/class bias and attention.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        dropout: Dropout probability applied before the second convolution.
        time_emb_dim: Width of the time embedding; ``None`` disables time conditioning.
        num_classes: Number of class labels; ``None`` disables class conditioning.
        activation: Callable applied after each normalization and to the time embedding.
        norm: Normalization type forwarded to ``get_norm``.
        num_groups: Group count forwarded to ``get_norm``.
        use_attention: Whether the block output passes through an ``AttentionBlock``.
    """

    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=F.relu,
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()
        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        # Conditioning projections exist only when the matching size was supplied.
        if time_emb_dim is not None:
            self.time_bias = nn.Linear(time_emb_dim, out_channels)
        else:
            self.time_bias = None

        if num_classes is not None:
            self.class_bias = nn.Embedding(num_classes, out_channels)
        else:
            self.class_bias = None

        # A 1x1 convolution matches channel counts on the skip path when they differ.
        if in_channels != out_channels:
            self.residual_connection = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.residual_connection = nn.Identity()

        if use_attention:
            self.attention = AttentionBlock(out_channels, norm, num_groups)
        else:
            self.attention = nn.Identity()

    def forward(self, x, time_emb=None, y=None):
        """Run the block on ``x``.

        Raises:
            ValueError: If the block was built with time (or class) conditioning
                but ``time_emb`` (or ``y``) is ``None``.
        """
        # First convolution.
        hidden = self.conv_1(self.activation(self.norm_1(x)))

        # Project the time embedding and add it as a per-channel bias.
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            time_term = self.time_bias(self.activation(time_emb))
            hidden += time_term[:, :, None, None]

        # Look up the class embedding and add it as a per-channel bias.
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")
            class_term = self.class_bias(y)
            hidden += class_term[:, :, None, None]

        # Second convolution plus the skip path.
        hidden = self.activation(self.norm_2(hidden))
        hidden = self.conv_2(hidden) + self.residual_connection(x)

        # Attention last (an identity when disabled).
        return self.attention(hidden)
U-Net模型设计
python
class UNet(nn.Module):
    """U-Net with optional time and class conditioning.

    Args:
        img_channels: Channels of the input image and of the output.
        base_channels: Channels produced by the first convolution.
        channel_mults: Per-level multipliers of ``base_channels``.
        num_res_blocks: Residual blocks per level on the down path (the up
            path uses one more per level).
        time_emb_dim: Width of the time embedding; ``None`` disables time
            conditioning.
        time_emb_scale: Scale forwarded to ``PositionalEmbedding``.
        num_classes: Number of class labels; ``None`` disables class
            conditioning.
        activation: Callable used in the residual blocks and before the
            output convolution.
        dropout: Dropout probability forwarded to every residual block.
        attention_resolutions: Indices into ``channel_mults`` of the levels
            whose residual blocks use attention.
        norm: Normalization type forwarded to ``get_norm``.
        num_groups: Group count forwarded to ``get_norm``.
        initial_pad: Padding added to every side of the input in ``forward``
            and cropped from the output again.
    """

    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 2, 2),
        num_res_blocks=2, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=F.silu,
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):
        super().__init__()
        self.activation = activation
        self.initial_pad = initial_pad
        self.num_classes = num_classes

        # MLP that turns raw timesteps into the time embedding.
        self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

        # First convolution applied to the input image.
        self.init_conv = nn.Conv2d(img_channels, base_channels, 3, padding=1)

        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # Conditioning/normalization settings shared by every residual block.
        block_kwargs = dict(
            time_emb_dim=time_emb_dim, num_classes=num_classes, activation=activation,
            norm=norm, num_groups=num_groups,
        )

        # `channels` records the channel count of every tensor that forward()
        # pushes onto its skip stack; `now_channels` tracks the running count.
        channels = [base_channels]
        now_channels = base_channels

        # Down path: residual blocks, then a Downsample on all but the last level.
        for i, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(
                    ResidualBlock(
                        now_channels, out_channels, dropout,
                        use_attention=i in attention_resolutions, **block_kwargs,
                    )
                )
                now_channels = out_channels
                channels.append(now_channels)
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(now_channels))
                channels.append(now_channels)

        # Middle: two residual blocks at the lowest resolution, the first with attention.
        self.mid = nn.ModuleList(
            [
                ResidualBlock(now_channels, now_channels, dropout, use_attention=True, **block_kwargs),
                ResidualBlock(now_channels, now_channels, dropout, use_attention=False, **block_kwargs),
            ]
        )

        # Up path: each residual block consumes one skip tensor (concatenated
        # on the channel axis), then an Upsample on all but the first level.
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult
            for _ in range(num_res_blocks + 1):
                self.ups.append(
                    ResidualBlock(
                        channels.pop() + now_channels, out_channels, dropout,
                        use_attention=i in attention_resolutions, **block_kwargs,
                    )
                )
                now_channels = out_channels
            if i != 0:
                self.ups.append(Upsample(now_channels))

        # Every recorded skip must have been matched by an up-path block.
        assert len(channels) == 0

        self.out_norm = get_norm(norm, base_channels, num_groups)
        self.out_conv = nn.Conv2d(base_channels, img_channels, 3, padding=1)

    def forward(self, x, time=None, y=None):
        """Run the network on ``x``.

        Args:
            x: Input image batch.
            time: Timesteps fed to the time MLP; required when the model was
                built with a ``time_emb_dim``.
            y: Class labels; required when the model was built with
                ``num_classes``.

        Returns:
            A tensor produced by the output convolution, cropped by
            ``initial_pad`` on each spatial side when padding was applied.

        Raises:
            ValueError: If required conditioning (``time`` or ``y``) is missing.
        """
        # Optionally pad the input; the same margin is cropped from the output.
        ip = self.initial_pad
        if ip != 0:
            x = F.pad(x, (ip,) * 4)

        if self.time_mlp is not None:
            if time is None:
                raise ValueError("time conditioning was specified but time is not passed")
            time_emb = self.time_mlp(time)
        else:
            time_emb = None

        if self.num_classes is not None and y is None:
            raise ValueError("class conditioning was specified but y is not passed")

        x = self.init_conv(x)

        # Stack of down-path activations, consumed in reverse on the up path.
        skips = [x]
        for layer in self.downs:
            x = layer(x, time_emb, y)
            skips.append(x)

        for layer in self.mid:
            x = layer(x, time_emb, y)

        for layer in self.ups:
            # Only residual blocks take a skip; Upsample layers do not.
            if isinstance(layer, ResidualBlock):
                x = torch.cat([x, skips.pop()], dim=1)
            x = layer(x, time_emb, y)

        # Output head: normalize, activate, project back to image channels.
        x = self.activation(self.out_norm(x))
        x = self.out_conv(x)

        if ip != 0:
            return x[:, :, ip:-ip, ip:-ip]
        return x