# @浙大疏锦行 Python day51 -- review day: DDPM
#
# NOTE(review): no import statements were visible in this file; the imports
# here cover the names used by the definitions in this file.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn


class DenoiseDiffusion:
    """
    Forward (noising) process, single reverse (denoising) step and training
    loss of a DDPM, built around a noise-prediction model ``eps_model``.

    The beta schedule is linear from 1e-4 to 0.02 over ``n_steps`` steps.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        """
        Params:
            eps_model: noise-prediction network, called as ``eps_model(xt, t)``
            n_steps:   number of diffusion steps
            device:    device on which the schedule tensors are placed
        """
        self.eps_model = eps_model
        self.n_steps = n_steps
        self.device = device
        # beta_t: linear noise schedule
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        # alpha_t = 1 - beta_t
        self.alpha = 1. - self.beta
        # alpha_bar_t = prod_{s <= t} alpha_s
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        # sigma_t^2 used when sampling; here simply beta_t
        self.sigma2 = self.beta
        # NOTE(review): `Tools` is not defined in the visible part of this file.
        # Its `gather(consts, t)` presumably picks `consts[t]` per sample and
        # reshapes it to broadcast against image tensors -- confirm.
        self.tools = Tools()

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward diffusion: mean and variance of the Gaussian q(x_t | x_0).

        Returns:
            (mean, var) with mean = sqrt(alpha_bar_t) * x0 and var = 1 - alpha_bar_t
        """
        alpha_bar_t = self.tools.gather(self.alpha_bar, t)
        mean = alpha_bar_t ** 0.5 * x0
        var = 1 - alpha_bar_t
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None):
        """
        Forward diffusion: draw x_t from q(x_t | x_0).

        Params:
            x0:  clean input
            t:   time-step indices
            eps: noise to use; standard Gaussian noise is drawn when None
        Returns:
            x_t, the input after noise has been added for step t
        """
        if eps is None:
            eps = torch.randn_like(x0)
        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor):
        """
        One reverse (denoising) step: derive x_{t-1} from x_t and t.

        Only used at sampling time; meant to be called in a loop over the steps.
        """
        eps_theta = self.eps_model(xt, t)
        alpha_bar = self.tools.gather(self.alpha_bar, t)
        alpha = self.tools.gather(self.alpha, t)
        # Coefficient of the predicted noise: (1 - alpha_t) / sqrt(1 - alpha_bar_t)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** 0.5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = self.tools.gather(self.sigma2, t)
        # NOTE(review): noise is added on every call, including t == 0, whereas
        # Algorithm 2 of the DDPM paper uses zero noise on the last step --
        # confirm whether this is intended.
        eps = torch.randn(xt.shape, device=xt.device)
        # mean + sigma_t * eps
        return mean + (var ** 0.5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """
        Training loss: MSE between the injected noise and the model's prediction.

        A time step is drawn uniformly from [0, n_steps) for every sample in
        the batch. The schedule tensors are plain tensors, so only the
        parameters of ``eps_model`` receive gradients from this loss.

        Params:
            x0:    batch of clean inputs; the first dimension is the batch size
            noise: noise to inject; standard Gaussian noise is drawn when None
        """
        batch_size = x0.shape[0]
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, eps=noise)    # noised input at step t
        eps_theta = self.eps_model(xt, t)       # model's noise prediction
        return F.mse_loss(noise, eps_theta)
# Activation function
class Swish(nn.Module):
    """Swish activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
    
class ResidualBlock(nn.Module):
    """
    Residual block made of two (GroupNorm -> Swish -> 3x3 conv) stages, with
    the time embedding added between them and a skip connection from the input.
    """
    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 n_groups: int = 32, dropout: float = 0.1):
        """
        Params:
            in_channels:   number of channels of the input feature map
            out_channels:  number of channels of the output feature map
            time_channels: size of the time-embedding vector fed to forward()
            n_groups:      number of groups for both GroupNorm layers
            dropout:       dropout rate applied before the second convolution
        """
        super().__init__()

        # Stage 1: GroupNorm + activation + 3x3 convolution
        self.norm1 = nn.GroupNorm(n_groups, in_channels)
        self.act1 = Swish()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Stage 2: GroupNorm + activation + 3x3 convolution
        self.norm2 = nn.GroupNorm(n_groups, out_channels)
        self.act2 = Swish()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # Skip path: pass the input through unchanged when the channel counts
        # already agree, otherwise a 1x1 convolution maps it to out_channels
        channels_match = in_channels == out_channels
        self.shortcut = (
            nn.Identity() if channels_match
            else nn.Conv2d(in_channels, out_channels, kernel_size=1)
        )
        # The time embedding is projected to out_channels so it can be added
        # to the feature map
        self.time_emb = nn.Linear(time_channels, out_channels)
        self.time_act = Swish()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        Params:
            x: input feature map, (batch_size, in_channels, height, width)
            t: time embedding, (batch_size, time_channels)
        """
        # 1. First stage
        feat = self.norm1(x)
        feat = self.act1(feat)
        feat = self.conv1(feat)
        # 2. Project the time embedding and add it at every spatial position
        time_bias = self.time_emb(self.time_act(t))
        feat += time_bias[:, :, None, None]
        # 3. Second stage (dropout sits right before the convolution)
        feat = self.norm2(feat)
        feat = self.act2(feat)
        feat = self.dropout(feat)
        feat = self.conv2(feat)
        # 4. Residual connection
        return feat + self.shortcut(x)
# Attention Block
# Self-attention across spatial positions: every (h, w) location is one token
class AttentionBlock(nn.Module):
    """
    Multi-head self-attention over the spatial positions of a feature map,
    followed by a linear output layer and a residual connection.

    NOTE(review): ``self.norm`` is created but forward() never applies it.
    It is kept so parameter names stay unchanged -- confirm whether the
    normalization was meant to be used.
    """
    def __init__(self, n_channels: int, n_heads: int = 1, d_k: Optional[int] = None, n_groups: int = 32):
        """
        Params:
            n_channels: number of channels of the feature map to attend over
            n_heads:    number of attention heads
            d_k:        per-head dimension; defaults to n_channels when None
            n_groups:   number of groups for GroupNorm
        """
        super().__init__()
        if d_k is None:
            d_k = n_channels
        self.norm = nn.GroupNorm(n_groups, n_channels)
        # One linear layer produces q, k and v for all heads at once
        self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
        # Output layer mapping the concatenated heads back to n_channels
        self.output = nn.Linear(n_heads * d_k, n_channels)

        self.scale = d_k ** -0.5
        self.n_heads = n_heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
        """
        Params:
            x: input feature map, (batch_size, n_channels, height, width)
            t: unused; accepted so the signature matches ResidualBlock.forward
        Returns:
            tensor with the same shape as x
        """
        _ = t
        batch_size, n_channels, height, width = x.shape
        # Flatten the spatial dims into a sequence: (batch, height*width, n_channels).
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(batch_size, n_channels, -1).permute(0, 2, 1)
        # Concatenated q, k, v: (batch, height*width, n_heads, 3 * d_k)
        qkv = self.projection(x).view(batch_size, -1, self.n_heads, 3 * self.d_k)
        # Split into q, k, v, each (batch, height*width, n_heads, d_k)
        q, k, v = torch.chunk(qkv, 3, dim=-1)
        # Scaled dot-product scores, indexed (batch, query i, key j, head)
        attn = torch.einsum('bihd,bjhd->bijh', q, k) * self.scale
        # Normalize over the key axis j
        attn = attn.softmax(dim=2)
        res = torch.einsum('bijh,bjhd->bihd', attn, v)
        # Merge heads: (batch, height*width, n_heads * d_k)
        res = res.reshape(batch_size, -1, self.n_heads * self.d_k)
        # Back to (batch, height*width, n_channels)
        res = self.output(res)
        # Residual connection
        res += x
        # Sequence form back to image form: (batch, n_channels, height, width)
        res = res.permute(0, 2, 1).reshape(batch_size, n_channels, height, width)
        return res
class DownBlock(nn.Module):
    """Encoder unit: a ResidualBlock, optionally followed by an AttentionBlock."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 use_attention: bool = False):
        """
        Params:
            in_channels:   channels of the incoming feature map
            out_channels:  channels produced by the residual block
            time_channels: size of the time-embedding vector
            use_attention: when True, attention is applied after the residual block
        """
        super().__init__()
        self.res_block = ResidualBlock(in_channels, out_channels, time_channels)
        # Identity stands in when attention is disabled, so forward() needs no branch
        self.attn_block = AttentionBlock(out_channels) if use_attention else nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Apply the residual block (with time embedding t), then attention/identity."""
        features = self.res_block(x, t)
        return self.attn_block(features)
class UpBlock(nn.Module):
    """Decoder unit: a ResidualBlock, optionally followed by an AttentionBlock."""

    def __init__(self, in_channels: int, out_channels: int, time_channels: int,
                 use_attention: bool = False):
        """
        Params:
            in_channels:   channels of the decoder feature map before the skip
                           connection is concatenated
            out_channels:  channels produced by the block
            time_channels: size of the time-embedding vector
            use_attention: when True, attention is applied after the residual block
        """
        # Must be a call on super(); `super.__init__()` raises TypeError and
        # leaves the nn.Module machinery uninitialized.
        super().__init__()
        # The residual block is sized for in_channels + out_channels input
        # channels, i.e. an input that already has extra channels concatenated
        self.res_block = ResidualBlock(in_channels + out_channels, out_channels, time_channels)
        if use_attention:
            self.attn = AttentionBlock(out_channels)
        else:
            self.attn = nn.Identity()

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Apply the residual block (with time embedding t), then attention/identity."""
        x = self.res_block(x, t)
        x = self.attn(x)
        return x
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of the integer time step, followed by a two-layer MLP."""

    def __init__(self, n_channels: int):
        """
        Params:
            n_channels: size of the produced time embedding (time_channels).
                        The sinusoidal features have 2 * (n_channels // 8)
                        entries while the first linear layer expects
                        n_channels // 4, so n_channels should be a multiple of 8.
        """
        super().__init__()
        self.n_channels = n_channels
        self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
        self.act = Swish()
        self.lin2 = nn.Linear(self.n_channels, self.n_channels)

    def forward(self, t: torch.Tensor):
        """
        Params:
            t: integer time steps, shape (batch_size,)
        Returns:
            time embedding of shape (batch_size, n_channels)
        """
        # Same construction as the Transformer sinusoidal position encoding.
        # Printing each intermediate tensor and its shape is a good way to
        # follow what happens here.
        half_dim = self.n_channels // 8
        # Geometric frequency spacing; the max() guard avoids a division by
        # zero when half_dim == 1 and changes nothing for larger sizes
        step = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -step)
        emb = t[:, None] * freqs[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=1)
        # Transform with the MLP
        emb = self.act(self.lin1(emb))
        emb = self.lin2(emb)
        return emb
class Upsample(nn.Module):
    """
    Upsampling layer: a transposed convolution with kernel 4, stride 2 and
    padding 1 that keeps the channel count.
    """
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            n_channels, n_channels, kernel_size=4, stride=2, padding=1,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # t is accepted only so the call signature matches the other blocks
        del t
        upsampled = self.conv(x)
        return upsampled
class Downsample(nn.Module):
    """
    Downsampling layer: a convolution with kernel 3, stride 2 and padding 1
    that keeps the channel count.
    """
    def __init__(self, n_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            n_channels, n_channels, kernel_size=3, stride=2, padding=1,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        # t is accepted only so the call signature matches the other blocks
        del t
        downsampled = self.conv(x)
        return downsampled
class MiddleBlock(nn.Module):
    """Bottleneck of the U-Net: ResidualBlock -> AttentionBlock -> ResidualBlock."""

    def __init__(self, n_channels: int, time_channels: int):
        """
        Params:
            n_channels:    channels of the feature map (unchanged by this block)
            time_channels: size of the time-embedding vector
        """
        # Must be a call on super(); `super.__init__()` raises TypeError and
        # leaves the nn.Module machinery uninitialized.
        super().__init__()
        self.res1 = ResidualBlock(n_channels, n_channels, time_channels)
        self.attn = AttentionBlock(n_channels)
        self.res2 = ResidualBlock(n_channels, n_channels, time_channels)

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """Run the three sub-blocks in order; t is the time embedding."""
        x = self.res1(x, t)
        x = self.attn(x)
        x = self.res2(x, t)
        return x
class UNet(nn.Module):
    """
    U-Net backbone of the DDPM denoising (noise-prediction) model.
    """
    def __init__(self, image_channels: int = 3, n_channels: int = 64,
                 ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
                 is_attn: Union[Tuple[bool, ...], List[bool]] = (False, False, True, True),
                 n_blocks: int = 2):
        """
        Params:
            image_channels: channels of the input image (3 for RGB)

            n_channels:     output channels of the initial convolution applied
                            to the image before it enters the encoder

            ch_mults:       per-level channel multipliers; encoder level i has
                            ch_mults[i] times the channels of the level before
                            it. The decoder divides by the same factors in
                            reverse order.

            is_attn:        per-level flags; when True, the blocks of that
                            level apply attention after the residual block

            n_blocks:       number of DownBlocks per encoder level; each
                            decoder level uses n_blocks + 1 UpBlocks
        """
        super().__init__()
        # Number of distinct resolutions, i.e. the number of encoder/decoder levels
        n_resolutions = len(ch_mults)
        # Size of the time embedding shared by all blocks
        time_channels = n_channels * 4
        # Initial projection of the image to n_channels feature maps
        self.image_proj = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
        self.time_emb = TimeEmbedding(time_channels)
        # --------------------------
        # Encoder
        # --------------------------
        down = []
        out_channels = in_channels = n_channels
        for i in range(n_resolutions):
            # Channel count of this level
            out_channels = in_channels * ch_mults[i]
            # n_blocks DownBlocks per level
            for _ in range(n_blocks):
                down.append(DownBlock(in_channels, out_channels, time_channels, is_attn[i]))
                in_channels = out_channels
            # Downsample after every level except the last one
            if i < n_resolutions - 1:
                down.append(Downsample(in_channels))
        self.down = nn.ModuleList(down)
        # --------------------------
        # Middle
        # --------------------------
        self.middle = MiddleBlock(out_channels, time_channels)
        # --------------------------
        # Decoder
        # --------------------------
        up = []
        in_channels = out_channels
        for i in reversed(range(n_resolutions)):
            # n_blocks UpBlocks that keep the channel count of this level
            out_channels = in_channels
            for _ in range(n_blocks):
                up.append(UpBlock(in_channels, out_channels, time_channels, is_attn[i]))
            # One more UpBlock that reduces the channels to those of the level above
            out_channels = in_channels // ch_mults[i]
            up.append(UpBlock(in_channels, out_channels, time_channels, is_attn[i]))
            in_channels = out_channels
            # Upsample after every level except the topmost one
            if i > 0:
                up.append(Upsample(in_channels))
        self.up = nn.ModuleList(up)
        # Final normalization, activation and convolution back to image_channels
        self.norm = nn.GroupNorm(8, n_channels)
        self.act = Swish()
        self.final = nn.Conv2d(in_channels, image_channels, kernel_size=(3, 3), padding=(1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
        Params:
            x: noisy input x_t, (batch_size, image_channels, height, width)
            t: time steps, (batch_size,)
        Returns:
            tensor with image_channels channels produced by the final convolution
        """
        # Time embedding
        t = self.time_emb(t)
        # Initial image projection
        x = self.image_proj(x)
        # -----------------------
        # Encoder
        # -----------------------
        # Stack of every encoder output, consumed in reverse as skip connections
        h = [x]
        for m in self.down:
            x = m(x, t)
            h.append(x)
        # -----------------------
        # Middle
        # -----------------------
        x = self.middle(x, t)
        # -----------------------
        # Decoder
        # -----------------------
        for m in self.up:
            if isinstance(m, Upsample):
                x = m(x, t)
            else:
                # Skip connection: concatenate the matching encoder output on the channel axis
                s = h.pop()
                x = torch.cat((x, s), dim=1)
                x = m(x, t)
        return self.final(self.act(self.norm(x)))