[PyTorch][Chapter 28][Hung-yi Lee Deep Learning][Diffusion Model-2]

Preface:

This post gives a brief introduction to Stable Diffusion. The noise-predictor model in Stable Diffusion

is built on the U-Net architecture, and the corresponding PyTorch code is provided:

https://github.com/nickd16/Diffusion-Models-from-Scratch


Contents:

  1. Training Process
  2. Sampling Process
  3. U-Net
  4. References

1 Training Process (Forward Process)

1.1 The paper: Denoising Diffusion Probabilistic Models (DDPM, Ho et al., 2020)

1.2 Algorithm and training objective

  1. Randomly sample an image from our training dataset
  2. Pick a random timestep t on our noise (variance) schedule
  3. Add the noise for that timestep to our data, simulating the forward diffusion process via the "diffusion kernel"
  4. Feed the diffused image into the model to predict the noise we added
  5. Compute the mean squared error between the predicted noise and the actual noise, and optimize the model's parameters against this objective
  6. Repeat!

Repeating this yields the trained noise predictor; a minimal sketch of one training step is shown below.
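This sketch is my own illustration, not the linked repository's exact code. It assumes the DDPM_Scheduler defined in section 1.3 below and a hypothetical noise-prediction network model(x_t, t):

import torch
import torch.nn.functional as F

def train_step(model, optimizer, scheduler, x0):
    # One DDPM training step (sketch). scheduler.alpha holds the
    # cumulative product alpha_bar_t (see section 1.3).
    b = x0.shape[0]
    t = torch.randint(0, len(scheduler.beta), (b,))   # random timestep per sample
    alpha_bar = scheduler.alpha[t].view(b, 1, 1, 1)
    eps = torch.randn_like(x0)                        # the noise the model must predict
    # Closed-form forward diffusion: x_t = sqrt(a_bar)*x0 + sqrt(1 - a_bar)*eps
    x_t = alpha_bar.sqrt() * x0 + (1 - alpha_bar).sqrt() * eps
    loss = F.mse_loss(model(x_t, t), eps)             # MSE between predicted and true noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()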

1.3 How the schedule hyperparameters are computed

The schedule is computed as in the code below: β_t grows linearly from 1e-4 to 0.02, so the cumulative product ᾱ_t = ∏_{s≤t}(1 − β_s) decreases steadily toward 0, i.e. the proportion of noise grows with t. (In the closed-form forward process, x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε.)

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 10 17:03:16 2024

@author: chengxf2
"""
import torch
import torch.nn as nn

class DDPM_Scheduler(nn.Module):
    def __init__(self, num_time_steps: int = 1000):
        super().__init__()
        # Linear beta schedule: beta_t grows from 1e-4 to 0.02.
        self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)
        alpha = 1 - self.beta
        # Note: despite the name, self.alpha stores the cumulative product
        # alpha_bar_t = prod_{s<=t} (1 - beta_s), which decays toward 0.
        self.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)


net = DDPM_Scheduler(20)

 
print(net.alpha)
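To see the effect of the decreasing ᾱ_t, here is a small illustrative snippet of my own, continuing from the scheduler above: the signal coefficient sqrt(ᾱ_t) shrinks as t grows, while the noise coefficient sqrt(1 − ᾱ_t) grows.

x0 = torch.randn(1, 3, 8, 8)    # stand-in for a normalized image
eps = torch.randn_like(x0)
for t in [0, 9, 19]:
    a_bar = net.alpha[t]
    # Closed-form forward process: x_t = sqrt(a_bar)*x0 + sqrt(1 - a_bar)*eps
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps
    print(f"t={t:2d}  signal={a_bar.sqrt().item():.3f}  noise={(1 - a_bar).sqrt().item():.3f}")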

2 Sampling Algorithm (Reverse Process)

2.1 The paper

2.2 The sampling algorithm, summarized:

  1. Generate random noise from the standard normal distribution

for t = T, ..., 1:

  1. Update Z (image + noise) by estimating the reverse-process distribution, with the mean parameterized by Z from the previous step and the variance parameterized by the noise our model estimates at that timestep

  2. Add a small amount of noise back for stability (explained below)

  3. Repeat until timestep 0 is reached: the recovered image!

2.3 Why a little noise is added back for stability

Intuitively, the reverse process boils down to an iterative procedure: start from pure noise, estimate the noise that was theoretically added at timestep t, and subtract it; repeat until we obtain the generated sample. The one small detail to note is that after subtracting the estimated noise, we add a small fraction of noise back to keep the process stable. Estimating and subtracting the total amount of noise in a single shot at the start of the iteration produces very incoherent samples, so in practice it has been shown empirically that adding back a little noise at each timestep and iterating generates better samples.
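A minimal sketch of this loop (my own, following the standard DDPM update with sigma_t^2 = beta_t; it assumes the DDPM_Scheduler from section 1.3 and a trained noise predictor model(z, t), and is not necessarily the linked repository's exact code):

import torch

@torch.no_grad()
def sample(model, scheduler, shape=(1, 3, 32, 32)):
    T = len(scheduler.beta)
    z = torch.randn(shape)                        # 1. start from pure noise
    for t in reversed(range(T)):
        beta_t = scheduler.beta[t]
        alpha_t = 1 - beta_t
        alpha_bar_t = scheduler.alpha[t]
        eps = model(z, torch.tensor([t]))         # estimated noise at step t
        # 2. subtract the scaled estimated noise (posterior mean)
        z = (z - beta_t / (1 - alpha_bar_t).sqrt() * eps) / alpha_t.sqrt()
        if t > 0:
            # 3. add a little noise back for stability (skip at the final step)
            z = z + beta_t.sqrt() * torch.randn_like(z)
    return z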


3 U-Net

Reference: 一文搞定UNet——图像分割(语义分割) (Jianshu)

The authors of the DDPM paper built their model for predicting the noise of the reverse diffusion process on the UNet architecture, originally designed for medical image segmentation. Here is a brief introduction to UNet.

UNet is a convolutional neural network (CNN) architecture specialized for image segmentation tasks, first proposed by Olaf Ronneberger et al. in 2015. In more detail:

3.1 The model

The arrows in the standard UNet architecture diagram:

Gray arrows (copy and crop): at the top level, a 568×568 feature map is cropped down to 392×392 and then concatenated with the 64-channel feature map coming back up after the contracting path (64 encoder channels plus 64 decoder channels gives 128 channels).

Red arrows: 2×2 max pooling; each pooling halves the spatial size of the feature map.

Green arrows: upsampling, usually implemented as a transposed convolution. (Note: a transposed convolution only restores the spatial shape; the output values differ from those of the original feature map.)

Cyan arrow: a 1×1 convolution with 64 input channels and 2 output channels, with padding 0 and stride 1.

Convolution output-size formula: o = ⌊(i + 2p − k) / s⌋ + 1, where i is the input size, k the kernel size, p the padding and s the stride. For the 1×1 output convolution, k = 1, p = 0, s = 1 gives o = i, so the spatial size is unchanged.

UNet likewise uses an encoder-decoder structure: the left half of the diagram is the encoder (contracting path) and the right half is the decoder (expanding path).
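The green-arrow note can be checked quickly (a hypothetical example of mine, not from the referenced article): a kernel-2, stride-2 transposed convolution doubles the spatial size, while the output values come from learned weights rather than the original feature map.

import torch
import torch.nn as nn

up = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=2, stride=2)
x = torch.randn(1, 64, 28, 28)
print(up(x).shape)   # torch.Size([1, 32, 56, 56]): shape doubled, values are new

The full training script below assembles these building blocks into a UNet for the Kaggle Carvana segmentation challenge (paths and dataset layout follow the Kaggle kernel environment).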

# -*- coding: utf-8 -*-
"""
Created on Thu Jan  2 10:10:16 2025

@author: chengxf2
"""
import os
import random
import zipfile

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image
from torch import optim
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

class CarvanaDataset(Dataset):
    def __init__(self, root_path, limit=None):
        self.root_path = root_path
        self.limit = limit
        self.images = sorted([root_path + "/train/" + i for i in os.listdir(root_path + "/train/")])[:self.limit]
        self.masks = sorted([root_path + "/train_masks/" + i for i in os.listdir(root_path + "/train_masks/")])[:self.limit]

        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor()])
        
        if self.limit is None:
            self.limit = len(self.images)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert("RGB")
        mask = Image.open(self.masks[index]).convert("L")

        # Also return the image path so random_images_inference can report file names.
        return self.transform(img), self.transform(mask), self.images[index]

    def __len__(self):
        return min(len(self.images), self.limit)

class DoubleConv(nn.Module):
    '''
    The double convolution repeated at every stage (blue arrows in the
    diagram): two 3x3 convolutions, each followed by a ReLU activation.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
            )

    def forward(self, x):
        return self.conv_op(x)

class DownSample(nn.Module):
    '''
    Downsampling: the left (encoding) path of the diagram, where a double
    convolution is followed by 2x2 max pooling (red arrows). Returns the
    pre-pooling feature map (kept for the skip connection) together with
    the pooled output.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)
        return down, p
    
class UpSample(nn.Module):
    '''
    Upsampling: the right (decoding) path of the diagram, a transposed
    convolution (green arrows) followed by a double convolution. Each
    decoder stage concatenates the upsampled feature with the copied
    encoder feature (gray arrows), four times in total. (No cropping is
    needed here because the convolutions are padded.)
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, inputs, x2):
        x1 = self.up(inputs)
        x = torch.cat([x1, x2], 1)
        return self.conv(x)
    
class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, out_channels=64)
        self.down_convolution_2 = DownSample(in_channels=64, out_channels=128)
        self.down_convolution_3 = DownSample(in_channels=128, out_channels=256)
        self.down_convolution_4 = DownSample(in_channels=256, out_channels=512)
        self.bottle_neck = DoubleConv(512, 1024)
        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)
        self.out = nn.Conv2d(64, out_channels=num_classes, kernel_size=1)
    
    def forward(self, x):
        down_1,p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)
        bott = self.bottle_neck(p4)
        up_1 = self.up_convolution_1(bott, down_4)
        up_2 = self.up_convolution_2(up_1, down_3)
        up_3 = self.up_convolution_3(up_2, down_2)
        up_4 = self.up_convolution_4(up_3, down_1)
        out =  self.out(up_4)
        return out

def dice_coefficient(prediction, target, epsilon=1e-07):
    # Binarize the logits at 0 (equivalent to a 0.5 sigmoid threshold),
    # then compute Dice = 2|A∩B| / (|A| + |B|).
    prediction_copy = prediction.clone()
    prediction_copy[prediction_copy < 0] = 0
    prediction_copy[prediction_copy > 0] = 1
    intersection = abs(torch.sum(prediction_copy * target))
    union = abs(torch.sum(prediction_copy) + torch.sum(target))
    dice = (2. * intersection + epsilon) / (union + epsilon)
    return dice

def drawloss():
    epochs_list = list(range(1, EPOCHS + 1))

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_list, train_losses, label='Training Loss')
    plt.plot(epochs_list, val_losses, label='Validation Loss')
    plt.xticks(ticks=list(range(1, EPOCHS + 1, 1))) 
    plt.title('Loss over epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.grid()
    plt.tight_layout()
    
    plt.legend()
    
    
    plt.subplot(1, 2, 2)
    plt.plot(epochs_list, train_dcs, label='Training DICE')
    plt.plot(epochs_list, val_dcs, label='Validation DICE')
    plt.xticks(ticks=list(range(1, EPOCHS + 1, 1)))  
    plt.title('DICE Coefficient over epochs')
    plt.xlabel('Epochs')
    plt.ylabel('DICE')
    plt.grid()
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    epochs_list = list(range(1, EPOCHS + 1))

    plt.figure(figsize=(12, 5))
    plt.plot(epochs_list, train_losses, label='Training Loss')
    plt.plot(epochs_list, val_losses, label='Validation Loss')
    plt.xticks(ticks=list(range(1, EPOCHS + 1, 1))) 
    plt.ylim(0, 0.05)
    plt.title('Loss over epochs (zoomed)')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.grid()
    plt.tight_layout()
    
    plt.legend()
    plt.show()

def random_images_inference(image_tensors, mask_tensors, image_paths, model_pth, device):
    model = UNet(in_channels=3, num_classes=1).to(device)
    model.load_state_dict(torch.load(model_pth, map_location=torch.device(device)))
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((512, 512))
    ])

    # Iterate over the images, masks and paths
    for image_tensor, mask_tensor, image_path in zip(image_tensors, mask_tensors, image_paths):
        # Load the image
        img = transform(image_tensor).to(device)

        # Predict the mask with the model
        with torch.no_grad():
            pred_mask = model(img.unsqueeze(0))
        pred_mask = pred_mask.squeeze(0).permute(1, 2, 0)

        # Load the ground-truth mask for comparison
        mask = transform(mask_tensor).permute(1, 2, 0).to(device)

        print(f"Image: {os.path.basename(image_path)}, DICE coefficient: {round(float(dice_coefficient(pred_mask, mask)), 5)}")

        # Show the images
        img = img.cpu().detach().permute(1, 2, 0)
        pred_mask = pred_mask.cpu().detach()
        pred_mask[pred_mask < 0] = 0
        pred_mask[pred_mask > 0] = 1

        plt.figure(figsize=(15, 16))
        plt.subplot(131), plt.imshow(img), plt.title("original")
        plt.subplot(132), plt.imshow(pred_mask, cmap="gray"), plt.title("predicted")
        plt.subplot(133), plt.imshow(mask, cmap="gray"), plt.title("mask")
        plt.show()
    
def test(trained_model):
    trained_model.eval()
    test_running_loss = 0
    test_running_dc = 0

    with torch.no_grad():
        for idx, img_mask in enumerate(tqdm(test_dataloader, position=0, leave=True)):
            img = img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)
            y_pred = trained_model(img)
            loss = criterion(y_pred, mask)
            dc = dice_coefficient(y_pred, mask)

            test_running_loss += loss.item()
            test_running_dc += dc.item()

        test_loss = test_running_loss / (idx + 1)
        test_dc = test_running_dc / (idx + 1)
    return test_loss, test_dc

if __name__ == "__main__":
    print(os.listdir("../input/carvana-image-masking-challenge/"))
    DATASET_DIR = '../input/carvana-image-masking-challenge/'
    WORKING_DIR = '/kaggle/working/'
    if len(os.listdir(WORKING_DIR)) <= 1:
        with zipfile.ZipFile(DATASET_DIR + 'train.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)
    
        with zipfile.ZipFile(DATASET_DIR + 'train_masks.zip', 'r') as zip_file:
            zip_file.extractall(WORKING_DIR)
    
    print(
        len(os.listdir(WORKING_DIR + 'train')),
        len(os.listdir(WORKING_DIR + 'train_masks'))
    )
    train_dataset = CarvanaDataset(WORKING_DIR)
    generator = torch.Generator().manual_seed(25)
    train_dataset, test_dataset = random_split(train_dataset, [0.8, 0.2], generator=generator)
    test_dataset, val_dataset =   random_split(test_dataset,  [0.5, 0.5], generator=generator)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        num_workers = torch.cuda.device_count() * 4
    LEARNING_RATE = 3e-4
    BATCH_SIZE = 8
    POCHS = 10
    train_losses = []
    train_dcs = []
    val_losses = []
    val_dcs = []
    LEARNING_RATE = 3e-4
    BATCH_SIZE = 8
    train_dataloader = DataLoader(dataset=train_dataset,
                              num_workers=num_workers, pin_memory=False,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                            num_workers=num_workers, pin_memory=False,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    test_dataloader = DataLoader(dataset=test_dataset,
                            num_workers=num_workers, pin_memory=False,
                            batch_size=BATCH_SIZE,
                            shuffle=True)

    model = UNet(in_channels=3, num_classes=1).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()
   
    for epoch in tqdm(range(EPOCHS)):
        model.train()
        train_running_loss = 0
        train_running_dc = 0
        
        for idx, img_mask in enumerate(tqdm(train_dataloader, position=0, leave=True)):
            img =  img_mask[0].float().to(device)
            mask = img_mask[1].float().to(device)
            
            y_pred = model(img)
            optimizer.zero_grad()
            
            dc = dice_coefficient(y_pred, mask)
            loss = criterion(y_pred, mask)
            
            train_running_loss += loss.item()
            train_running_dc += dc.item()
    
            loss.backward()
            optimizer.step()
    
        train_loss = train_running_loss / (idx + 1)
        train_dc = train_running_dc / (idx + 1)
        
        train_losses.append(train_loss)
        train_dcs.append(train_dc)
    
        model.eval()
        val_running_loss = 0
        val_running_dc = 0
        
        with torch.no_grad():
            for idx, img_mask in enumerate(tqdm(val_dataloader, position=0, leave=True)):
                img = img_mask[0].float().to(device)
                mask = img_mask[1].float().to(device)
    
                y_pred = model(img)
                loss = criterion(y_pred, mask)
                dc = dice_coefficient(y_pred, mask)
                
                val_running_loss += loss.item()
                val_running_dc += dc.item()
    
            val_loss = val_running_loss / (idx + 1)
            val_dc = val_running_dc / (idx + 1)
        
        val_losses.append(val_loss)
        val_dcs.append(val_dc)
    
        print("-" * 30)
        print(f"Training Loss EPOCH {epoch + 1}: {train_loss:.4f}")
        print(f"Training DICE EPOCH {epoch + 1}: {train_dc:.4f}")
        print("\n")
        print(f"Validation Loss EPOCH {epoch + 1}: {val_loss:.4f}")
        print(f"Validation DICE EPOCH {epoch + 1}: {val_dc:.4f}")
        print("-" * 30)
    
    # Saving the model
    torch.save(model.state_dict(), 'my_checkpoint.pth')
    
    n = 10

    image_tensors = []
    mask_tensors = []
    image_paths = []
    
    for _ in range(n):
        random_index = random.randint(0, len(test_dataloader.dataset) - 1)
        random_sample = test_dataloader.dataset[random_index]
    
        image_tensors.append(random_sample[0])  
        mask_tensors.append(random_sample[1]) 
        image_paths.append(random_sample[2]) 
    model_path = '/kaggle/working/my_checkpoint.pth'
    random_images_inference(image_tensors, mask_tensors, image_paths, model_path, device="cpu")
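As a quick standalone smoke test of the architecture (a sketch of mine, independent of the Kaggle data): with four pooling stages, the network maps any 3-channel input whose sides are multiples of 16 to a same-sized 1-channel logit map.

import torch

model = UNet(in_channels=3, num_classes=1)
x = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    y = model(x)
print(y.shape)   # torch.Size([1, 1, 512, 512])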

4 References

1. 【生成式AI】Diffusion Model 原理剖析 (1/4), Hung-yi Lee, Bilibili
2. Diffusion Model from Scratch in PyTorch (DDPM): https://towardsdatascience.com/diffusion-model-from-scratch-in-pytorch-ddpm-9d9760528946
3. Denoising Diffusion Model from Scratch using PyTorch: https://medium.com/@mickael.boillaud/denoising-diffusion-model-from-scratch-using-pytorch-658805d293b4
