[PyTorch][chapter 28][Hung-yi Lee Deep Learning][Diffusion Model-2]

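Preface:

This post gives a brief introduction to Stable Diffusion. The Noise Predictor model in Stable Diffusion is built mainly on the U-Net architecture, and the corresponding PyTorch code is provided.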

https://github.com/nickd16/Diffusion-Models-from-Scratch


Contents:

  1. Training process
  2. Sampling process
  3. U-Net
  4. References

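I. Training Process (Forward Process)

1.1 Paper

1.2 Algorithm and training objective

  1. Randomly draw an image from the training dataset.
  2. Pick a random timestep t on the noise (variance) schedule.
  3. Add the noise for that timestep to the data, simulating the forward diffusion process through the "diffusion kernel".
  4. Pass the noised image into the model, which predicts the noise that was added.
  5. Compute the mean squared error between the predicted noise and the actual noise, and optimize the model's parameters with this objective.
  6. Repeat!

The result is the trained Noise Predictor.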
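The six steps above can be sketched as a single training step. This is only a minimal sketch under stated assumptions, not the code from the repository linked in the preface: `TinyNoisePredictor` is a hypothetical stand-in for the U-Net (it ignores the timestep, which a real noise predictor takes as input), and the linear schedule from 1e-4 to 0.02 over 1000 steps is assumed to match the scheduler shown in section 1.3.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

NUM_STEPS = 1000  # assumed schedule length
beta = torch.linspace(1e-4, 0.02, NUM_STEPS)   # assumed linear noise schedule
alpha_bar = torch.cumprod(1.0 - beta, dim=0)   # cumulative product of (1 - beta)


class TinyNoisePredictor(nn.Module):
    """Hypothetical stand-in for the U-Net; it ignores the timestep for brevity."""

    def __init__(self, channels: int = 1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(channels, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, channels, kernel_size=3, padding=1),
        )

    def forward(self, x, t):
        return self.net(x)


def training_step(model, optimizer, x0):
    """One DDPM-style training step on a batch of clean images x0."""
    batch = x0.shape[0]
    # Step 2: pick a random timestep for every image in the batch.
    t = torch.randint(0, NUM_STEPS, (batch,))
    # Step 3: add noise in closed form (the "diffusion kernel").
    eps = torch.randn_like(x0)
    a = alpha_bar[t].view(batch, 1, 1, 1)
    x_t = a.sqrt() * x0 + (1.0 - a).sqrt() * eps
    # Steps 4-5: predict the noise and minimise the mean squared error.
    loss = F.mse_loss(model(x_t, t), eps)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


model = TinyNoisePredictor()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x0 = torch.randn(4, 1, 16, 16)  # step 1: stand-in for a batch of training images
losses = [training_step(model, optimizer, x0) for _ in range(3)]  # step 6: repeat
print(losses)
```

In a real run the random tensor `x0` would be replaced by batches from a data loader, and the model would also receive an embedding of `t`.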

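1.3 How the hyperparameters are computed

The cumulative product of alpha_t = 1 - beta_t (stored as `self.alpha` in the code below) is a hyperparameter that decreases gradually with the timestep, so the proportion of noise in the image keeps growing.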

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 17:03:16 2024

@author: chengxf2
"""
import torch.nn as nn
import torch
class DDPM_Scheduler(nn.Module):
    def __init__(self, num_time_steps: int=1000):
        super().__init__()
        self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)
        alpha = 1 - self.beta
        self.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)


net = DDPM_Scheduler(20)

 
print(net.alpha)
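To see why the proportion of noise grows, it can help to print the two coefficients that the forward process applies at each timestep: the square root of the cumulative product scales the clean image, and the square root of one minus it scales the noise. A small sketch, assuming the same linear beta schedule and the same 20 steps as the demo above; the names `signal_scale` and `noise_scale` are illustrative only.

```python
import torch

num_time_steps = 20  # same small value as the demo above
beta = torch.linspace(1e-4, 0.02, num_time_steps)
alpha_bar = torch.cumprod(1.0 - beta, dim=0)

signal_scale = alpha_bar.sqrt()          # weight on the clean image
noise_scale = (1.0 - alpha_bar).sqrt()   # weight on the Gaussian noise

for t in range(0, num_time_steps, 5):
    print(f"t={t:2d}  signal={signal_scale[t].item():.4f}  noise={noise_scale[t].item():.4f}")
```

With only 20 steps the image is still far from pure noise at the last step; the default of 1000 steps pushes the cumulative product much closer to zero.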

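II. Sampling Algorithm (Reverse Process)

2.1 Paper

2.2 The sampling algorithm can be summarized as follows:

  1. Generate random noise from a standard normal distribution.

for t = T, ..., 1:

  1. Update Z (image + noise) by estimating the reverse-process distribution: its mean is parameterized by the Z from the previous step together with the noise the model estimates at this timestep, while its variance is set by the noise schedule.

  2. Add a small amount of noise to increase stability (explained below).

  3. Repeat until timestep 0 is reached, which gives the recovered image!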

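2.3 Adding a small amount of noise to increase stability

Intuitively, sampling boils down to an iterative process: we start from pure noise, estimate the noise that was theoretically added at timestep t, and subtract it. We repeat this until we obtain the generated sample. The only small detail to keep in mind is that after subtracting the estimated noise, we add a small part of it back to keep the process stable. For example, estimating and subtracting the total amount of noise in one shot at the start of the iteration produces very incoherent samples, so in practice it has been found empirically that adding a little noise back at every timestep and iterating produces better samples.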
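The loop described above can be sketched as follows. This is a minimal sketch under stated assumptions rather than the repository's sampler: `DummyNoisePredictor` is a hypothetical, untrained stand-in for the trained model, the schedule is shortened to 50 steps so the example runs quickly, and the amount of noise added back uses a variance equal to beta_t, which is one common choice.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

NUM_STEPS = 50  # assumed; kept small so the sketch runs quickly
beta = torch.linspace(1e-4, 0.02, NUM_STEPS)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)


class DummyNoisePredictor(nn.Module):
    """Hypothetical untrained stand-in for the trained noise predictor."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)


@torch.no_grad()
def sample(model, shape):
    # Step 1: start from pure Gaussian noise.
    z = torch.randn(shape)
    for t in reversed(range(NUM_STEPS)):
        eps_hat = model(z, t)
        # Subtract the estimated noise to get the mean of the reverse step.
        mean = (z - beta[t] / (1.0 - alpha_bar[t]).sqrt() * eps_hat) / alpha[t].sqrt()
        if t > 0:
            # Add a little fresh noise back (variance beta_t, an assumed choice).
            z = mean + beta[t].sqrt() * torch.randn(shape)
        else:
            # No noise is added at the final step.
            z = mean
    return z


image = sample(DummyNoisePredictor(), (1, 1, 16, 16))
print(image.shape)
```

Because the stand-in model is untrained, the output is still noise; the point is only the structure of the loop.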


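III. U-Net

Reference: 一文搞定UNet: 图像分割(语义分割) (Jianshu)

The authors of the DDPM paper built their model on the U-Net architecture, originally designed for medical image segmentation, to predict the noise in the reverse diffusion process. This section gives a brief introduction to the U-Net architecture.

U-Net is a convolutional neural network (CNN) architecture designed for image segmentation, first proposed by Olaf Ronneberger et al. in 2015. A closer look follows: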

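3.1 Model

Gray arrows:

Copy and crop. For the topmost arrow, a 568×568 feature map is cropped to 392×392 and then concatenated with the upsampled feature map coming from the expansive path (the cropped map has 64 channels and the upsampled map has 64 channels, giving 128 channels after concatenation).
Red arrows:

2×2 max pooling. After a max pooling layer the spatial size of the feature map is halved.
Green arrows:

Upsampling, usually implemented with a transposed convolution. (Note: a transposed convolution only restores the shape of the matrix; the output values are not the same as the original ones.)
Teal arrow:

A 1×1 convolution with 64 input channels and 2 output channels. It follows that the padding is 0 and the stride is 1.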

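Convolution output size formula:

output size = floor((input size + 2 × padding - kernel size) / stride) + 1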
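A quick numeric check of the output size rule, floor((n + 2p - k) / s) + 1, applied to the layer types described above. The helper `conv_out_size` is a hypothetical name introduced for illustration, and the 572 and 388 sizes are taken from the original U-Net figure rather than from this post.

```python
import torch
import torch.nn as nn


def conv_out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


# Two unpadded 3x3 convolutions, as in the original U-Net figure: 572 -> 570 -> 568.
after_convs = conv_out_size(conv_out_size(572, kernel=3), kernel=3)
# 2x2 max pooling with stride 2 halves the size: 568 -> 284.
after_pool = conv_out_size(after_convs, kernel=2, stride=2)
# A 1x1 convolution with stride 1 and padding 0 keeps the size unchanged.
after_1x1 = conv_out_size(388, kernel=1)
print(after_convs, after_pool, after_1x1)  # 568 284 388

# Cross-check the helper against PyTorch on a small input.
x = torch.zeros(1, 1, 20, 20)
y = nn.Conv2d(1, 1, kernel_size=3)(x)
assert y.shape[-1] == conv_out_size(20, kernel=3)
```

With `kernel=3, padding=1`, the setting used by `DoubleConv` in the code below, the same rule gives an output size equal to the input size.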

U-Net also uses an encoder-decoder structure:

the left side is the encoder and the right side is the decoder.
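The "copy and crop" link between encoder and decoder can be sketched with plain tensors. This is a minimal sketch using the 568×568 and 392×392 sizes quoted for the topmost gray arrow; `center_crop` is a hypothetical helper written for this example, not a function from the code below.

```python
import torch


def center_crop(feature: torch.Tensor, target_size: int) -> torch.Tensor:
    """Crop the last two dimensions of a (N, C, H, W) tensor around the centre."""
    _, _, h, w = feature.shape
    top = (h - target_size) // 2
    left = (w - target_size) // 2
    return feature[:, :, top:top + target_size, left:left + target_size]


# Encoder feature map at the top level: 64 channels, 568x568.
encoder_feature = torch.zeros(1, 64, 568, 568)
# Decoder feature map after upsampling: 64 channels, 392x392.
decoder_feature = torch.zeros(1, 64, 392, 392)

cropped = center_crop(encoder_feature, 392)
# Concatenate along the channel dimension: 64 + 64 = 128 channels.
merged = torch.cat([decoder_feature, cropped], dim=1)
print(merged.shape)  # torch.Size([1, 128, 392, 392])
```

The implementation below uses `padding=1` in its 3x3 convolutions, so the encoder and decoder maps already have the same size and no crop is needed there; only the concatenation remains.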

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  2 10:10:16 2025

@author: chengxf2
"""
import os
import random
import zipfile

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm

class CarvanaDataset(Dataset):
    def __init__(self, root_path, limit=None):
        self.root_path = root_path
        self.limit = limit
        self.images = sorted([root_path + "/train/" + i for i in os.listdir(root_path + "/train/")])[:self.limit]
        self.masks = sorted([root_path + "/train_masks/" + i for i in os.listdir(root_path + "/train_masks/")])[:self.limit]

        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor()])
        
        if self.limit is None:
            self.limit = len(self.images)

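    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index]).convert("L")

        # Also return the image path: the inference code at the end reads it as sample[2].
        return self.transform(img), self.transform(mask), self.images[index]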

    def __len__(self):
        return min(len(self.images), self.limit)

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
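        '''
        As shown in the figure:
        the double convolution repeated at every step (blue arrows).
        It consists of two 3x3 convolutions, each followed by a ReLU activation.
        '''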
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
            )
    def forward(self, x):
        output = self.conv_op(x)
        return output

class DownSample(nn.Module):
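    '''
    Downsampling:
        This corresponds to the left side of the figure (the encoding path),
        where we apply a double convolution followed by max pooling (red arrows).
    '''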
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)
        return down, p
    
class UpSample(nn.Module):
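    '''
    Upsampling:
    This corresponds to the right side of the figure (the decoding path).
    It is done with a transposed convolution (green arrows) followed by a double convolution.
    The feature map saved before each max pooling step is copied over (gray arrows),
    four times in total, and concatenated with the upsampled map.
    '''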
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, inputs, x2):
        x1 = self.up(inputs)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)
    
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, out_channels=64)
        self.down_convolution_2 = DownSample(in_channels=64, out_channels=128)
        self.down_convolution_3 = DownSample(in_channels=128, out_channels=256)
        self.down_convolution_4 = DownSample(in_channels=256, out_channels=512)
        self.bottle_neck = DoubleConv(512, 1024)
        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)
        self.out = nn.Conv2d(64, out_channels=num_classes, kernel_size=1)
    
    def forward(self, x):
        down_1,p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)
        bott = self.bottle_neck(p4)
        up_1 = self.up_convolution_1(bott, down_4)
        up_2 = self.up_convolution_2(up_1, down_3)
        up_3 = self.up_convolution_3(up_2, down_2)
        up_4 = self.up_convolution_4(up_3, down_1)
        out =  self.out(up_4)
        return out
def dice_coefficient(prediction, target, epsilon=1e-07):
    prediction_copy = prediction.clone()
    prediction_copy[prediction_copy < 0] = 0
    prediction_copy[prediction_copy > 0] = 1
    intersection = abs(torch.sum(prediction_copy * target))
    union = abs(torch.sum(prediction_copy) + torch.sum(target))
    dice = (2. * intersection + epsilon) / (union + epsilon)
    return dice
def drawloss():
    epochs_list = list(range(1, EPOCHS + 1))

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_list, train_losses, label='Training Loss')
    plt.plot(epochs_list, val_losses, label='Validation Loss')
    plt.xticks(ticks=list(range(1, EPOCHS + 1, 1))) 
    plt.title('Loss over epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.grid()
    plt.tight_layout()
    
    plt.legend()
    
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs_list, train_dcs, label='Training DICE')
    plt.plot(epochs_list, val_dcs, label='Validation DICE')
    plt.xticks(ticks=list(range(1, EPOCHS + 1, 1)))  
    plt.title('DICE Coefficient over epochs')
    plt.xlabel('Epochs')
    plt.ylabel('DICE')
    plt.grid()
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    epochs_list = list(range(1, EPOCHS + 1))

    plt.figure(figsize=(12, 5))
    plt.plot(epochs_list, train_losses, label='Training Loss')
    plt.plot(epochs_list, val_losses, label='Validation Loss')
    plt.xticks(ticks=list(range(1, EPOCHS + 1, 1))) 
    plt.ylim(0, 0.05)
    plt.title('Loss over epochs (zoomed)')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.grid()
    plt.tight_layout()
    
    plt.legend()
    plt.show()

def random_images_inference(image_tensors, mask_tensors, image_paths, model_pth, device):
    model = UNet(in_channels=3, num_classes=1).to(device)
    model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))

    transform = transforms.Compose([
        transforms.Resize((512, 512))
    ])

    # Iterate over the images, masks and paths
    for image_tensor, mask_tensor, image_path in zip(image_tensors, mask_tensors, image_paths):
        # Resize the image and move it to the same device as the model
        img = transform(image_tensor).to(device)

        # Predict the mask with the model (no gradients are needed at inference)
        with torch.no_grad():
            pred_mask = model(img.unsqueeze(0))
        pred_mask = pred_mask.squeeze(0).permute(1,2,0)

        # Load the mask to compare
        mask = transform(mask_tensor).permute(1, 2, 0).to(device)

        print(f"Image: {os.path.basename(image_path)}, DICE coefficient: {round(float(dice_coefficient(pred_mask, mask)),5)}")

        # Show the images (matplotlib needs CPU tensors)
        img = img.cpu().detach().permute(1, 2, 0)
        pred_mask = pred_mask.cpu().detach()
        pred_mask[pred_mask < 0] = 0
        pred_mask[pred_mask > 0] = 1

        plt.figure(figsize=(15, 16))
        plt.subplot(131), plt.imshow(img), plt.title("original")
        plt.subplot(132), plt.imshow(pred_mask, cmap="gray"), plt.title("predicted")
        plt.subplot(133), plt.imshow(mask.cpu(), cmap="gray"), plt.title("mask")
        plt.show()
        plt.show()
    
def test(trained_model):
    # Relies on the module-level test_dataloader, device and criterion defined below.
    trained_model.eval()
    test_running_loss = 0
    test_running_dc = 0

    with torch.no_grad():
        for idx, img_mask in enumerate(tqdm(test_dataloader, position=0, leave=True)):
            img = img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)
            y_pred = trained_model(img)
            loss = criterion(y_pred, mask)
            dc = dice_coefficient(y_pred, mask)

            test_running_loss += loss.item()
            test_running_dc += dc.item()

        test_loss = test_running_loss / (idx + 1)
        test_dc = test_running_dc / (idx + 1)
    # Return the averages so the caller can report them.
    return test_loss, test_dc

if __name__ == "__main__":
    print(os.listdir("../input/carvana-image-masking-challenge/"))
    DATASET_DIR = '../input/carvana-image-masking-challenge/'
    WORKING_DIR = '/kaggle/working/'
    if len(os.listdir(WORKING_DIR)) <= 1:
        with zipfile.ZipFile(DATASET_DIR + 'train.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)
    
        with zipfile.ZipFile(DATASET_DIR + 'train_masks.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)
    
    print(
        len(os.listdir(WORKING_DIR + 'train')),
        len(os.listdir(WORKING_DIR + 'train_masks'))
    )
    train_dataset = CarvanaDataset(WORKING_DIR)
    generator = torch.Generator().manual_seed(25)
    train_dataset, test_dataset = random_split(train_dataset, [0.8, 0.2], generator=generator)
    test_dataset, val_dataset =   random_split(test_dataset,  [0.5, 0.5], generator=generator)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Use worker processes only when a GPU is available; 0 loads data in the main process.
    num_workers = torch.cuda.device_count() * 4 if device == "cuda" else 0
    LEARNING_RATE = 3e-4
    BATCH_SIZE = 8
    EPOCHS = 10
    train_losses = []
    train_dcs = []
    val_losses = []
    val_dcs = []
    train_dataloader = DataLoader(dataset=train_dataset,
                              num_workers=num_workers, pin_memory=False,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                            num_workers=num_workers, pin_memory=False,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    test_dataloader = DataLoader(dataset=test_dataset,
                            num_workers=num_workers, pin_memory=False,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    model = UNet(in_channels=3, num_classes=1).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()
   
    for epoch in tqdm(range(EPOCHS)):
        model.train()
        train_running_loss = 0
        train_running_dc = 0
        
        for idx, img_mask in enumerate(tqdm(train_dataloader, position=0, leave=True)):
            img =  img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)
            
            y_pred = model(img)
            optimizer.zero_grad()
            
            dc = dice_coefficient(y_pred, mask)
            loss = criterion(y_pred, mask)
            
            train_running_loss += loss.item()
            train_running_dc += dc.item()
    
            loss.backward()
            optimizer.step()
    
        train_loss = train_running_loss / (idx + 1)
        train_dc = train_running_dc / (idx + 1)
        
        train_losses.append(train_loss)
        train_dcs.append(train_dc)
    
        model.eval()
        val_running_loss = 0
        val_running_dc = 0
        
        with torch.no_grad():
            for idx, img_mask in enumerate(tqdm(val_dataloader, position=0, leave=True)):
                img = img_mask[0].float().to(device)
                mask = img_mask[1].float().to(device)
    
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                dc = dice_coefficient(y_pred, mask)
                
                val_running_loss += loss.item()
                val_running_dc += dc.item()
    
            val_loss = val_running_loss / (idx + 1)
            val_dc = val_running_dc / (idx + 1)
        
        val_losses.append(val_loss)
        val_dcs.append(val_dc)
    
        print("-" * 30)
        print(f"Training Loss EPOCH {epoch + 1}: {train_loss:.4f}")
        print(f"Training DICE EPOCH {epoch + 1}: {train_dc:.4f}")
        print("\n")
        print(f"Validation Loss EPOCH {epoch + 1}: {val_loss:.4f}")
        print(f"Validation DICE EPOCH {epoch + 1}: {val_dc:.4f}")
        print("-" * 30)
    
    # Saving the model
    torch.save(model.state_dict(), 'my_checkpoint.pth')
    
    n = 10

    image_tensors = []
    mask_tensors = []
    image_paths = []
    
    for _ in range(n):
        random_index = random.randint(0, len(test_dataloader.dataset) - 1)
        random_sample = test_dataloader.dataset[random_index]
    
        image_tensors.append(random_sample[0])  
        mask_tensors.append(random_sample[1]) 
        image_paths.append(random_sample[2]) 
    model_path = '/kaggle/working/my_checkpoint.pth'
    random_images_inference(image_tensors, mask_tensors, image_paths, model_path, device="cpu")

IV. References:

【生成式AI】Diffusion Model 原理剖析 (1/4), Hung-yi Lee, bilibili

https://towardsdatascience.com/diffusion-model-from-scratch-in-pytorch-ddpm-9d9760528946

https://medium.com/@mickael.boillaud/denoising-diffusion-model-from-scratch-using-pytorch-658805d293b4
