AIGC笔记--Stable Diffusion源码剖析之DDIM

1--前言

以论文《High-Resolution Image Synthesis with Latent Diffusion Models》开源的项目为例,剖析 Stable Diffusion 的经典组成部分,以巩固学习、加深印象。

2--DDIM

一个可以debug的小demo:SD_DDIM

以文生图为例,剖析 SD 中 DDIM 的核心组成模块。本质上,SD 的 DDIM 遵循论文《Denoising Diffusion Implicit Models》(https://arxiv.org/abs/2010.02502)中的核心公式。

3--核心模块剖析

见SD_DDIM

4--完整代码

python 复制代码
import torch
import pytorch_lightning as pl

import numpy as np
from tqdm import tqdm
from functools import partial

# From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose = True):
    """Gather the DDIM alphas at the chosen timesteps and derive sigma_t.

    sigma_t follows Eq. (16) of the DDIM paper (https://arxiv.org/abs/2010.02502):
        sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))

    Args:
        alphacums: cumulative products of alphas from the DDPM schedule, indexed by
            DDPM timestep (numpy array or torch tensor).
        ddim_timesteps: integer array of DDPM timesteps visited by DDIM.
        eta: DDIM eta; with eta == 0 every sigma is 0 (deterministic sampling).
        verbose: print the selected alphas and the resulting sigmas.

    Returns:
        (sigmas, alphas, alphas_prev) where alphas = alphacums[ddim_timesteps] and
        alphas_prev is that sequence shifted right by one with alphacums[0] in front.
    """
    selected = alphacums[ddim_timesteps]

    # previous-step alphas: alphacums[0] followed by every selected alpha except the last
    shifted = [alphacums[0]]
    shifted.extend(alphacums[ddim_timesteps[:-1]].tolist())
    previous = np.asarray(shifted)

    variance_ratio = (1 - previous) / (1 - selected)
    sigmas = eta * np.sqrt(variance_ratio * (1 - selected / previous))

    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {selected}; a_(t-1): {previous}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, selected, previous

# From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
# 获取 ddim 的timesteps
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose = True):
    """Select the DDPM timesteps visited by the DDIM sampler.

    Args:
        ddim_discr_method: "uniform" (evenly spaced with stride num_ddpm // num_ddim) or
            "quad" (evenly spaced in sqrt-space up to 80% of the DDPM chain, then squared,
            so steps are denser near t = 0).
        num_ddim_timesteps: number of DDIM steps to return; must be >= 1, and for
            "uniform" must not exceed num_ddpm_timesteps.
        num_ddpm_timesteps: length of the underlying DDPM chain (e.g. 1000).
        verbose: print the selected timesteps.

    Returns:
        1-D integer array of exactly ``num_ddim_timesteps`` ascending timesteps, each
        shifted by +1, e.g. [1, 21, 41, ..., 981] for 50 uniform steps out of 1000.

    Raises:
        ValueError: if ``num_ddim_timesteps`` is out of range.
        NotImplementedError: for an unknown discretization method.
    """
    if num_ddim_timesteps < 1:
        raise ValueError(f'num_ddim_timesteps must be >= 1, got {num_ddim_timesteps}')

    if ddim_discr_method == 'uniform':
        if num_ddim_timesteps > num_ddpm_timesteps:
            raise ValueError(f'num_ddim_timesteps ({num_ddim_timesteps}) must not exceed '
                             f'num_ddpm_timesteps ({num_ddpm_timesteps}) for uniform discretization')
        c = num_ddpm_timesteps // num_ddim_timesteps  # stride, e.g. 1000 // 50 = 20
        # Take exactly num_ddim_timesteps multiples of the stride. range(0, num_ddpm, c) would
        # return more steps than requested when num_ddim does not divide num_ddpm (334 for 300).
        ddim_timesteps = np.arange(num_ddim_timesteps) * c
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out

# From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
def noise_like(shape, device, repeat = False):
    """Draw standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn and
    tiled along the first dimension, so every entry along that axis is identical.
    """
    if not repeat:
        return torch.randn(shape, device=device)
    single = torch.randn((1, *shape[1:]), device=device)
    tile = [1] * (len(shape) - 1)
    return single.repeat(shape[0], *tile)

# From https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
def make_beta_schedule(schedule, n_timestep, linear_start = 1e-4, linear_end = 2e-2, cosine_s = 8e-3):
    """Return the diffusion betas as a float64 numpy array of length ``n_timestep``.

    Schedules:
        "linear":      linear in sqrt(beta) from linear_start to linear_end, then squared.
        "cosine":      betas derived from alpha_bar(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2),
                       clamped to [0, 0.999].
        "sqrt_linear": betas linear from linear_start to linear_end.
        "sqrt":        square root of the "sqrt_linear" betas.

    Raises:
        ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype = torch.float64) ** 2

    elif schedule == "cosine":
        timesteps = torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]  # normalize so alpha_bar(0) == 1
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch: np.clip on a torch tensor only works through NumPy's duck-typing
        # fallbacks. The upper bound avoids beta == 1 at the final step.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype = torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype = torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()

# origin from https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py, modified by ljf
class DDPM(pl.LightningModule):
    """Minimal DDPM for the DDIM demo: holds the diffusion schedule as buffers.

    Only the noise schedule is real. ``apply_model`` is a random stand-in for the UNet
    noise predictor, so the sampler can be exercised without model weights.
    """

    def __init__(self, given_betas = None, beta_schedule = "linear", timesteps = 1000, linear_start = 0.00085, linear_end = 0.012, cosine_s = 8e-3):
        super().__init__()
        self.v_posterior = 0.0  # mixing weight between the true posterior variance and beta_t
        self.parameterization = "eps"  # model output is interpreted as the predicted noise
        self.register_schedule(given_betas = given_betas, beta_schedule = beta_schedule, timesteps = timesteps,
                        linear_start = linear_start, linear_end = linear_end, cosine_s = cosine_s)

    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Compute the beta/alpha schedule and register all derived quantities as buffers.

        Args:
            given_betas: optional explicit 1-D beta schedule. When provided it is used as-is
                and ``beta_schedule``/``timesteps``/``linear_*``/``cosine_s`` are ignored.
            beta_schedule, timesteps, linear_start, linear_end, cosine_s: forwarded to
                ``make_beta_schedule`` when ``given_betas`` is None.
        """
        if given_betas is not None:
            betas = np.asarray(given_betas, dtype=np.float64)
        else:
            # e.g. [0.00085, 0.0008547, ..., 0.012] (1000 values) with the __init__ defaults
            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                        cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)  # alpha_bar_t = prod_{s <= t} alpha_s
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])  # alpha_bar_{t-1}, with 1 prepended

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                    1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        if self.parameterization == "eps":
            lvlb_weights = self.betas ** 2 / (
                        2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        elif self.parameterization == "x0":
            # NOTE(review): denominator copied from upstream ldm as `2. * 1 - alphas_cumprod`; it may
            # have been intended as `2. * (1 - alphas_cumprod)`. Unreachable here (parameterization is "eps").
            lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
        else:
            raise NotImplementedError("mu not supported")
        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1 with v_posterior == 0), so the
        # t = 0 weight is non-finite; reuse the t = 1 weight instead.
        # TODO how to choose this term
        lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        assert not torch.isnan(self.lvlb_weights).all()

    def apply_model(self, x_noisy, t, cond, return_ids=False):
        """Stand-in for the UNet: return uniform random values shaped like ``x_noisy``.

        The output is created on ``x_noisy``'s device so it can be combined with ``x_noisy``
        on CPU or GPU. ``t``, ``cond`` and ``return_ids`` are ignored.
        """
        return torch.rand(x_noisy.shape, device=x_noisy.device)

# Origin from https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddim.py, modified by ljf
class DDIMSampler(object):
    """DDIM sampler that reuses the noise schedule and denoiser of a DDPM model.

    The model must expose ``num_timesteps``, ``betas``, ``alphas_cumprod``,
    ``alphas_cumprod_prev``, ``sqrt_one_minus_alphas_cumprod``, ``device`` and
    ``apply_model(x, t, cond)``.
    """

    def __init__(self, model, schedule = "linear", **kwargs):
        super().__init__()
        self.model = model  # the DDPM whose schedule and noise predictor are reused
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        """Store ``attr`` as attribute ``name``, moving tensors to the model's device.

        Upstream moved every tensor to CUDA unconditionally. That crashes on CPU-only
        machines and leaves buffers on a different device than a CPU model.
        """
        if isinstance(attr, torch.Tensor) and attr.device != self.model.device:
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize = "uniform", ddim_eta = 0., verbose = True):
        """Precompute the DDPM- and DDIM-level coefficients used by ``p_sample_ddim``.

        Args:
            ddim_num_steps: number of DDIM sampling steps.
            ddim_discretize: "uniform" or "quad" timestep spacing (see make_ddim_timesteps).
            ddim_eta: DDIM eta; 0 gives deterministic sampling (all sigmas are 0).
            verbose: print the selected timesteps, alphas and sigmas.
        """
        # DDPM timesteps visited by DDIM, e.g. [1, 21, 41, ..., 981] for 50 of 1000 steps
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method = ddim_discretize, num_ddim_timesteps = ddim_num_steps,
                                                  num_ddpm_timesteps = self.ddpm_num_timesteps, verbose = verbose)

        alphas_cumprod = self.model.alphas_cumprod  # reuse the DDPM cumulative alphas
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(torch.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(torch.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(torch.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(torch.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(torch.sqrt(1. / alphas_cumprod - 1)))

        # ddim sampling parameters (alphas/sigmas at the selected DDIM timesteps)
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums = alphas_cumprod.cpu(),
                                                                                   ddim_timesteps = self.ddim_timesteps,
                                                                                   eta = ddim_eta, verbose = verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        # sigmas for sampling over every DDPM step (used when use_original_steps=True)
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self, S, batch_size, shape, conditioning = None, callback = None,
               img_callback = None, quantize_x0 = False, eta = 0., mask = None, x0 = None,
               temperature = 1., noise_dropout = 0., score_corrector = None, corrector_kwargs = None,
               verbose = True, x_T = None, log_every_t = 100, unconditional_guidance_scale = 1.,
               unconditional_conditioning = None
    ):
        """Build the DDIM schedule for ``S`` steps and sample a batch of latents.

        Args:
            S: number of DDIM steps.
            batch_size: number of samples.
            shape: per-sample (C, H, W), e.g. [4, 64, 64] in the demo.
            conditioning / unconditional_conditioning: conditional and unconditional
                embeddings for classifier-free guidance.
            unconditional_guidance_scale: guidance scale; 1.0 disables guidance.
            eta: DDIM eta (0.0 -> deterministic).

        Returns:
            (samples, intermediates) as returned by ``ddim_sampling``.
        """
        self.make_schedule(ddim_num_steps = S, ddim_eta = eta, verbose = verbose)
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback = callback,
                                                    img_callback = img_callback,
                                                    quantize_denoised = quantize_x0,
                                                    mask = mask, x0 = x0,
                                                    ddim_use_original_steps = False,
                                                    noise_dropout = noise_dropout,
                                                    temperature = temperature,
                                                    score_corrector = score_corrector,
                                                    corrector_kwargs = corrector_kwargs,
                                                    x_T = x_T,
                                                    log_every_t = log_every_t,
                                                    unconditional_guidance_scale = unconditional_guidance_scale,
                                                    unconditional_conditioning = unconditional_conditioning,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T = None, ddim_use_original_steps = False,
                      callback = None, timesteps = None, quantize_denoised = False,
                      mask = None, x0 = None, img_callback = None, log_every_t = 100,
                      temperature = 1., noise_dropout = 0., score_corrector = None, corrector_kwargs = None,
                      unconditional_guidance_scale = 1., unconditional_conditioning = None):
        """Run the reverse chain from x_T (random if None) and return (x_0, intermediates).

        With ``ddim_use_original_steps`` every DDPM step T-1, ..., 0 is visited, so the
        step index lines up with the full-length DDPM tables. Otherwise the DDIM
        timesteps from ``make_schedule`` are visited in reverse.

        ``timesteps``, ``mask``, ``x0``, ``quantize_denoised``, ``score_corrector`` and
        ``corrector_kwargs`` are accepted for compatibility with upstream ldm but are not
        used by this simplified version.
        """
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device) if x_T is None else x_T

        if ddim_use_original_steps:
            total_steps = self.ddpm_num_timesteps
            time_range = list(reversed(range(total_steps)))
        else:
            total_steps = self.ddim_timesteps.shape[0]
            time_range = np.flip(self.ddim_timesteps)  # e.g. 981, 961, ..., 1
        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1  # position of `step` in the coefficient tables
            ts = torch.full((b,), step, device=device, dtype=torch.long)  # same timestep for the whole batch
            img, pred_x0 = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                              quantize_denoised=quantize_denoised, temperature=temperature,
                                              noise_dropout=noise_dropout, score_corrector=score_corrector,
                                              corrector_kwargs=corrector_kwargs,
                                              unconditional_guidance_scale=unconditional_guidance_scale,
                                              unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """One DDIM update x_t -> x_{t-1} (Eq. 12 of https://arxiv.org/abs/2010.02502).

        Applies classifier-free guidance when ``unconditional_conditioning`` is given and
        the scale differs from 1.

        Returns:
            (x_prev, pred_x0)
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            # evaluate unconditional and conditional branches in one batched forward pass
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            # classifier-free guidance: eps = eps_uncond + s * (eps_cond - eps_uncond)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        # full-length DDPM tables, or the DDIM tables built by make_schedule()
        if use_original_steps:
            alphas = self.model.alphas_cumprod
            alphas_prev = self.model.alphas_cumprod_prev
            sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
            sigmas = self.ddim_sigmas_for_original_num_steps  # stored on the sampler, not the DDPM
        else:
            alphas = self.ddim_alphas
            alphas_prev = self.ddim_alphas_prev
            sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
            sigmas = self.ddim_sigmas

        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device = device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device = device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device = device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device = device)

        # current prediction for x_0 (first term of Eq. 12)
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

        # direction pointing to x_t (second term of Eq. 12)
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        # random noise (third term of Eq. 12); zero when eta == 0 because sigma_t == 0
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature

        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
    
if __name__ == "__main__":

    model = DDPM()  # DDPM providing the noise schedule (UNet is mocked)
    sampler = DDIMSampler(model)

    # Mock the output of FrozenCLIPEmbedder
    batchsize = 3
    c = torch.rand(batchsize, 77, 768)  # mock embedding of the prompt
    uc = torch.rand(batchsize, 77, 768)  # mock embedding of the empty prompt

    # Denoise with DDIM
    shape = [4, 64, 64]
    scale = 7.5 # unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
    ddim_eta = 0.0 # ddim eta (eta=0.0 corresponds to deterministic sampling)
    samples_ddim, _ = sampler.sample(S = 50,  # 50 DDIM steps
                                    conditioning = c,  # conditional embedding
                                    batch_size = batchsize,
                                    shape = shape,
                                    verbose = False,
                                    unconditional_guidance_scale = scale,
                                    unconditional_conditioning = uc,  # unconditional embedding
                                    eta = ddim_eta,
                                    x_T = None)

    assert samples_ddim.shape[0] == batchsize
    assert list(samples_ddim.shape[1:]) == shape
    print("samples_ddim.shape: ", samples_ddim.shape)
    print("All Done!")
相关推荐
好评笔记18 小时前
AIGC视频扩散模型新星:Video 版本的SD模型
论文阅读·深度学习·机器学习·计算机视觉·面试·aigc·transformer
AIGC大时代20 小时前
方法建议ChatGPT提示词分享
人工智能·深度学习·chatgpt·aigc·ai写作
正在走向自律1 天前
AI 写作(六):核心技术与多元应用(6/10)
人工智能·aigc·ai写作
寻道码路1 天前
探秘 Docling:多格式文档解析转换大揭秘,赋能 AI 应用新生态
人工智能·aigc·ai编程
好评笔记1 天前
AIGC视频生成模型:Stability AI的SVD(Stable Video Diffusion)模型
论文阅读·人工智能·深度学习·机器学习·计算机视觉·面试·aigc
算家云1 天前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
五月君1 天前
Windsurf 发布Wave 2,Web实时搜索、URL上下文、自动化记忆等一大波新功能来袭!
aigc
多森2 天前
Cursor太贵?字节Trae可免费用Claude,10分钟带你实现全栈开发
aigc
是店小二呀2 天前
【2024年CSDN平台总结:新生与成长之路】
数据库·人工智能·程序人生·aigc·音视频
杀生丸学AI2 天前
【三维分割】Gaga:通过3D感知的 Memory Bank 分组任意高斯
aigc·三维重建·nerf·视觉大模型·3dgs·三维高斯溅射·分割一切sam