LCM Distillation Training Code Walkthrough (SD 1.5)

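This post is mainly a personal study note and write-up. If you spot anything incorrect, corrections and discussion are welcome!

It reads the theory and the code side by side, so it is best read after the paper.

GitHub project: luosiallen/latent-consistency-model: Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference (github.com)

Related papers

The code analyzed here is located at latent-consistency-model/LCM_Training_Script/consistency_distillation/train_lcm_distill_sd_wds.py

The official training settings are shown below:

0. Preparation before training (logging setup, accelerator, etc.)
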
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            create_repo(
                repo_id=args.hub_model_id or Path(args.output_dir).name,
                exist_ok=True,
                token=args.hub_token,
                private=True,
            ).repo_id

1. Noise scheduler

# 1. Create the noise scheduler and the desired noise schedule.
noise_scheduler = DDPMScheduler.from_pretrained(
	args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# pretrained_teacher_model:"runwayml/stable-diffusion-v1-5"

# The scheduler calculates the alpha and sigma schedule for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
solver = DDIMSolver(
	noise_scheduler.alphas_cumprod.numpy(),
	timesteps=noise_scheduler.config.num_train_timesteps,
	ddim_timesteps=args.num_ddim_timesteps,
)

Calling DDPMScheduler in a Kaggle notebook:

The scaled_linear schedule it uses is as follows:

Unlike the linear schedule in DDPM, this schedule generates 1000 (num_train_timesteps) evenly spaced values between $\sqrt{\beta_{start}}$ and $\sqrt{\beta_{end}}$ and then squares each of them. The first and last values are the same as in the linear schedule, but the values in between differ.

The alphas_cumprod we need is defined as follows:

$\alpha_t=1-\beta_t$, $\overline\alpha_t=\prod_{i=0}^t\alpha_i$
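
As a quick sanity check, the scaled_linear construction and the cumulative product can be reproduced in a few lines of plain Python. The values beta_start=0.00085 and beta_end=0.012 are what I believe the SD 1.5 scheduler config uses; treat them as assumptions and read the real ones from noise_scheduler.config.

```python
import math

# Assumed values (believed to match the SD 1.5 scheduler config); verify against noise_scheduler.config.
beta_start, beta_end, num_train_timesteps = 0.00085, 0.012, 1000


def scaled_linear_betas(beta_start, beta_end, n):
    """Evenly space n values between sqrt(beta_start) and sqrt(beta_end), then square them."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    return [(lo + (hi - lo) * i / (n - 1)) ** 2 for i in range(n)]


def linear_betas(beta_start, beta_end, n):
    """Plain DDPM-style linear schedule, for comparison only."""
    return [beta_start + (beta_end - beta_start) * i / (n - 1) for i in range(n)]


betas = scaled_linear_betas(beta_start, beta_end, num_train_timesteps)
linear = linear_betas(beta_start, beta_end, num_train_timesteps)

# alphas_cumprod[t] is the running product of (1 - beta_i) for all i <= t.
alphas_cumprod = []
running = 1.0
for beta in betas:
    running *= 1.0 - beta
    alphas_cumprod.append(running)

print(betas[0], betas[-1])       # endpoints agree with the linear schedule
print(betas[500], linear[500])   # interior values differ
print(alphas_cumprod[0], alphas_cumprod[-1])
```

Because squaring is convex, every interior value of scaled_linear lies below the corresponding linear value, so noise is added more gently in the middle of the schedule.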

Visualizing the relevant variables:

import torch
import matplotlib.pyplot as plt

# `noise_scheduler` is the DDPMScheduler instance created in section 1.
betas = noise_scheduler.betas
alphas = noise_scheduler.alphas
alphas_cumprod = noise_scheduler.alphas_cumprod
num_train_timesteps = noise_scheduler.config.num_train_timesteps

# x-axis: timesteps 1..num_train_timesteps
x = torch.arange(1, num_train_timesteps + 1)
# Create the subplots
fig, axs = plt.subplots(1, 3, figsize=(16, 6))
# Plot betas
axs[0].plot(x, betas.cpu().numpy())
axs[0].set_xlabel('Timestep')
axs[0].set_ylabel('Beta Value')
axs[0].set_title('Beta Values over Timesteps')
axs[0].grid(True)
# Plot alphas
axs[1].plot(x, alphas.cpu().numpy())
axs[1].set_xlabel('Timestep')
axs[1].set_ylabel('Alpha Value')
axs[1].set_title('Alpha Values over Timesteps')
axs[1].grid(True)
# Plot alphas_cumprod
axs[2].plot(x, alphas_cumprod.cpu().numpy())
axs[2].set_xlabel('Timestep')
axs[2].set_ylabel('Alpha Cumulative Product')
axs[2].set_title('Alpha Cumulative Product over Timesteps')
axs[2].grid(True)
# Adjust the spacing between subplots and show the figure
plt.tight_layout()
plt.show()

Next, let's look at DDIMSolver. This code corresponds to Eq. (12) of the Denoising Diffusion Implicit Models paper.

Because DDIM specifically refers to the case $\sigma_t=0$, the equation should be read with that condition substituted in.

Note also that $\alpha_t$ in Eq. (12) of the DDIM paper corresponds to $\overline\alpha_t$ in DDPM and in the text above!

Therefore alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) corresponds to $\sqrt{\alpha_t}$ in the DDIM notation, i.e. $\sqrt{\overline\alpha_t}$ in the DDPM notation.
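
For reference, Eq. (12) of the DDIM paper can be written out as below, in that paper's notation (so $\alpha_t$ is the cumulative product). Setting $\sigma_t=0$ drops the noise term and leaves the deterministic update that ddim_step implements, with $\hat{x}_0$ standing for the bracketed "predicted $x_0$" term.

```latex
x_{t-1} = \sqrt{\alpha_{t-1}}
          \underbrace{\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right)}_{\text{predicted } x_0}
        + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta^{(t)}(x_t)}_{\text{direction pointing to } x_t}
        + \underbrace{\sigma_t\,\epsilon_t}_{\text{random noise}}

% With \sigma_t = 0:
x_{t-1} = \sqrt{\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\alpha_{t-1}}\;\epsilon_\theta^{(t)}(x_t)
```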

class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        # timesteps picked by the skip-step (strided) sampling
        #array([ 19,  39,  59,  79,  99, 119, 139, 159, 179, 199, 219, 239, 259,
        #279, 299, 319, 339, 359, 379, 399, 419, 439, 459, 479, 499, 519,
        #539, 559, 579, 599, 619, 639, 659, 679, 699, 719, 739, 759, 779,
        #799, 819, 839, 859, 879, 899, 919, 939, 959, 979, 999])

        # pick the corresponding alpha_cumprods and alpha_cumprods_prev
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise  # the "direction pointing to x_t" term
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt  # alpha_cumprod_prev.sqrt() * pred_x0 is the first term of the equation
        return x_prev
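
To see which timesteps and which "previous" alphas the solver ends up with, here is a small plain-Python sketch of the index arithmetic in `__init__`. The alpha_cumprods list is a made-up stand-in (a simple decreasing sequence), not the real schedule.

```python
timesteps, ddim_timesteps = 1000, 50
step_ratio = timesteps // ddim_timesteps  # 20

# Same arithmetic as DDIMSolver.__init__: 19, 39, ..., 999
ddim_ts = [i * step_ratio - 1 for i in range(1, ddim_timesteps + 1)]

# Hypothetical stand-in for alphas_cumprod: any decreasing sequence works for this demo.
alpha_cumprods = [1.0 - (t + 1) / (timesteps + 1) for t in range(timesteps)]

ddim_alphas = [alpha_cumprods[t] for t in ddim_ts]
# The "previous" alpha of entry i is the alpha of entry i - 1;
# the first entry falls back to alpha_cumprods[0], i.e. timestep 0.
ddim_alphas_prev = [alpha_cumprods[0]] + [alpha_cumprods[t] for t in ddim_ts[:-1]]

print(ddim_ts[:3], ddim_ts[-1])                  # [19, 39, 59] 999
print(len(ddim_alphas), len(ddim_alphas_prev))   # 50 50
```

So ddim_step with timestep_index i moves a sample from ddim_ts[i] to ddim_ts[i - 1], or to timestep 0 when i is 0.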

2-6. Load the required modules and freeze those that are not trained

# 2. Load tokenizers from SD-XL checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
)

# 3. Load text encoders from SD-1.5 checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
)

# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_teacher_model,
    subfolder="vae",
    revision=args.teacher_revision,
)

# 5. Load teacher U-Net from SD-XL checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)

# 6. Freeze teacher vae, text_encoder, and teacher_unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)

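The source comments repeatedly mention loading modules from an SD-XL checkpoint, which I found odd: shouldn't it be SD-1.5? The code itself loads every module from args.pretrained_teacher_model, so with the SD 1.5 teacher these are SD 1.5 weights; the SD-XL wording looks like a leftover in the comments.

8-9. Create the online student U-Net and the target student U-Net

There is no step 7 in the source code (I did not skip it) (doge)

This part of the code corresponds to the target network definition in the Consistency Models paper; the target U-Net exists so that it can be updated by EMA.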
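
As far as I can tell, the definition in question is the EMA update of the target parameters $\theta^-$ from the online parameters $\theta$, with decay rate $\mu$ (args.ema_decay in this script):

```latex
\theta^- \leftarrow \operatorname{stopgrad}\!\left(\mu\,\theta^- + (1-\mu)\,\theta\right)
```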

# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
    teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
unet = UNet2DConditionModel(**teacher_unet.config)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()

# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
# Initialize from unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)

# Check that all trainable models are in full precision
low_precision_error_string = (
    " Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training, copy of the weights should still be float32."
)

if accelerator.unwrap_model(unet).dtype != torch.float32:
    raise ValueError(
        #            f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
    )

10-11. Precision and device alignment; saving and loading checkpoints

    # 10. Handle mixed precision and device placement
    # For mixed precision training we cast all non-trainable weights to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    # The VAE is in float32 to avoid NaN losses.
    vae.to(accelerator.device)
    if args.pretrained_vae_model_name_or_path is not None:
        vae.to(dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Move teacher_unet to device, optionally cast to weight_dtype
    target_unet.to(accelerator.device)
    teacher_unet.to(accelerator.device)
    if args.cast_teacher_unet:
        teacher_unet.to(dtype=weight_dtype)

    # Also move the alpha and sigma noise schedules to accelerator.device.
    alpha_schedule = alpha_schedule.to(accelerator.device)
    sigma_schedule = sigma_schedule.to(accelerator.device)
    solver = solver.to(accelerator.device)

    # 11. Handle saving and loading of checkpoints
    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if accelerator.is_main_process:
                target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))

                for i, model in enumerate(models):
                    model.save_pretrained(os.path.join(output_dir, "unet"))

                    # make sure to pop weight so that corresponding model is not saved again
                    weights.pop()

        def load_model_hook(models, input_dir):
            load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
            target_unet.load_state_dict(load_model.state_dict())
            target_unet.to(accelerator.device)
            del load_model

            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()

                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)

                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

12. Optional optimizations

    # 12. Enable optimizations
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
            teacher_unet.enable_xformers_memory_efficient_attention()
            target_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

12. Create the optimizer and dataset (so this is where the numbering went off)

    optimizer = optimizer_class(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Here, we compute not just the text embeddings but also the additional embeddings
    # needed for the SD XL UNet to operate.
    # Convert the text prompts into embeddings
    def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
        prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
        return {"prompt_embeds": prompt_embeds}

    dataset = Text2ImageDataset(
        train_shards_path_or_url=args.train_shards_path_or_url,
        num_train_examples=args.max_train_samples,
        per_gpu_batch_size=args.train_batch_size,
        global_batch_size=args.train_batch_size * accelerator.num_processes,
        num_workers=args.dataloader_num_workers,
        resolution=args.resolution,
        shuffle_buffer_size=1000,
        pin_memory=True,
        persistent_workers=True,
    )
    train_dataloader = dataset.train_dataloader

    compute_embeddings_fn = functools.partial(
        compute_embeddings,
        proportion_empty_prompts=0,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    # Unconditional embedding used for classifier-free guidance
    uncond_input_ids = tokenizer(
        [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
    ).input_ids.to(accelerator.device)
    uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]

Train!!!

Logger output during training, checkpoint saving and loading, and creation of the training progress bar progress_bar

    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num batches each epoch = {train_dataloader.num_batches}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch
    else:
        initial_global_step = 0

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

Core training code

This code is fairly long, so let's go through it in pieces

    for epoch in range(first_epoch, args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                image, text, _, _ = batch

                image = image.to(accelerator.device, non_blocking=True) # [b,c,h,w]
                encoded_text = compute_embeddings_fn(text) # [b,n1,n2]

                pixel_values = image.to(dtype=weight_dtype)
                if vae.dtype != weight_dtype:
                    vae.to(dtype=weight_dtype)

                # encode pixel values with batch size of at most 32
                latents = []
                for i in range(0, pixel_values.shape[0], 32):
                    latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
                latents = torch.cat(latents, dim=0)

                latents = latents * vae.config.scaling_factor
                latents = latents.to(weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]

                # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
                topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
                index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
                start_timesteps = solver.ddim_timesteps[index]
                timesteps = start_timesteps - topk
                timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)

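The first half of the code passes the input images (shape assumed to be [b,c,h,w]) through the VAE encoder to obtain latents in the latent space, multiplied by the scaling factor scaling_factor. The corresponding text prompts go through CLIP to become text embeddings (shape assumed to be [b,n1,n2], where n1 is the maximum number of tokens and n2 is the embedding dimension).

The second half randomly samples batch-size many timesteps from the skip-step timesteps defined by the DDIMSolver (listed inside the DDIMSolver class in section 1 above). These are the start timesteps $t_{n+k}$; subtracting the skip interval topk and clamping at 0 gives the end timesteps $t_n$.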
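
The sampling logic can be mimicked in plain Python as a rough sketch. The sizes below (1000 training timesteps, 50 DDIM timesteps, batch size 4) are assumed for illustration; the script reads them from the scheduler config and from args.

```python
import random

num_train_timesteps, num_ddim_timesteps, batch_size = 1000, 50, 4
topk = num_train_timesteps // num_ddim_timesteps  # skip interval k = 20
ddim_timesteps = [i * topk - 1 for i in range(1, num_ddim_timesteps + 1)]

random.seed(0)
index = [random.randrange(num_ddim_timesteps) for _ in range(batch_size)]
start_timesteps = [ddim_timesteps[i] for i in index]      # t_{n+k}
timesteps = [max(t - topk, 0) for t in start_timesteps]   # t_n, clamped at 0

for i, s, e in zip(index, start_timesteps, timesteps):
    print(f"index={i:2d}  start={s:3d}  end={e:3d}")
```

Every pair is exactly 20 steps apart, except for index 0, where the start timestep 19 maps to the end timestep 0.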


# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]

Functions involved:

# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
    return c_skip, c_out

This corresponds to the formulas on page 25 of the Consistency Models paper. LCM modifies them because the smallest timestep is set to 0, but the goal is the same: to satisfy the boundary conditions $c_{skip}(0)=1$ and $c_{out}(0)=0$. Note that the function body divides the timestep by 0.1 directly (i.e. multiplies it by 10) and never uses the timestep_scaling argument.
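
To check the boundary conditions numerically, here is the same arithmetic on plain Python floats (a sketch for illustration; the training code applies it to a tensor of timesteps):

```python
def scalings_for_boundary_conditions(timestep, sigma_data=0.5):
    # Same arithmetic as the function above, written for plain Python numbers.
    scaled = timestep / 0.1
    c_skip = sigma_data**2 / (scaled**2 + sigma_data**2)
    c_out = scaled / (scaled**2 + sigma_data**2) ** 0.5
    return c_skip, c_out


for t in (0, 19, 499, 999):
    c_skip, c_out = scalings_for_boundary_conditions(t)
    print(f"t={t:3d}  c_skip={c_skip:.6f}  c_out={c_out:.6f}")
```

At t = 0 this gives exactly c_skip = 1 and c_out = 0. Already at t = 19, the smallest start timestep, c_skip is close to 0 and c_out close to 1, so the skip connection effectively only matters at the boundary.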

def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]

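append_dims aligns dimensions: it appends trailing singleton axes so that the per-sample scalings broadcast against the latents.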
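
A small example of that broadcasting, using the append_dims function from above. The latent shape and the scaling values are made up for illustration.

```python
import torch


def append_dims(x, target_dims):
    """Append trailing singleton dimensions until x has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


latents = torch.randn(4, 4, 64, 64)           # hypothetical batch of latents [b,c,h,w]
c_skip = torch.tensor([1.0, 0.5, 0.1, 0.0])   # one made-up scaling per sample, shape [4]

c_skip_b = append_dims(c_skip, latents.ndim)  # shape [4, 1, 1, 1]
scaled = c_skip_b * latents                   # broadcasts over c, h, w

print(c_skip_b.shape, scaled.shape)
```

Without the extra axes, a shape [4] tensor would be broadcast against the last axis of the latents rather than the batch axis.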


# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)

# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min   # (w_max-w_min)*[0,1)+w_min = [w_min,w_max)
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

# 20.4.8. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")

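20.4.5 is the forward diffusion (noising) process; the resulting noisy_model_input holds noisy latents with different noise levels.

20.4.6 randomly samples the guidance scale and embeds it as w_embedding.

20.4.8 assigns the prompt embeddings computed earlier to the variable prompt_embeds.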
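
guidance_scale_embedding is not shown in this post. Below is a minimal sketch of what such a function may look like, assuming it follows the usual sinusoidal timestep-embedding recipe. The function name with the `_sketch` suffix, the scaling by 1000, the embedding_dim of 256 and the w_min/w_max values are my assumptions for illustration, so check the script's own definition and arguments.

```python
import math

import torch


def guidance_scale_embedding_sketch(w, embedding_dim=256, dtype=torch.float32):
    """Sinusoidal embedding of a 1-D tensor of guidance scales (illustrative sketch)."""
    assert w.ndim == 1
    w = w.to(dtype) * 1000.0                       # assumed scaling, similar to timesteps
    half_dim = embedding_dim // 2
    freq = math.log(10000.0) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=dtype) * -freq)
    emb = w[:, None] * freqs[None, :]              # [b, half_dim]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:                     # zero-pad odd dimensions
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb                                     # [b, embedding_dim]


w_min, w_max, bsz = 5.0, 15.0, 4                   # hypothetical values
w = (w_max - w_min) * torch.rand((bsz,)) + w_min   # one guidance scale per sample
w_embedding = guidance_scale_embedding_sketch(w, embedding_dim=256)
print(w_embedding.shape)
```

The embedding is what gets passed to the student U-Net as timestep_cond, which is why the student needs time_cond_proj_dim set in its config.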


# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
    noisy_model_input,
    start_timesteps,
    timestep_cond=w_embedding,
    encoder_hidden_states=prompt_embeds.float(),
    added_cond_kwargs=encoded_text,
).sample

pred_x_0 = predicted_origin(
    noise_pred,
    start_timesteps,
    noisy_model_input,
    noise_scheduler.config.prediction_type,
    alpha_schedule,
    sigma_schedule,
)

model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0  # corresponds to (27)

Functions involved:

# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    # Gather the per-sample schedule values once so that both branches broadcast over [b,c,h,w]
    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
    if prediction_type == "epsilon":
        pred_x_0 = (sample - sigmas * model_output) / alphas  # corresponds to (28)
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")

    return pred_x_0

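The figures below may make this easier to follow; both come from the LCM paper

The final model_pred corresponds to the online LCM output $f_\theta(z_{t_{n+k}}, w, c, t_{n+k})$.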
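
Reading the two code blocks above together, model_pred for the epsilon-prediction case can be written as follows. Here $\alpha_t$ and $\sigma_t$ denote the entries of alpha_schedule and sigma_schedule (so $\alpha_t$ is already the square root of the cumulative product), and $\hat{\epsilon}_\theta$ is the output of the online U-Net:

```latex
\hat{z}_0 = \frac{z_{t_{n+k}} - \sigma_{t_{n+k}}\,\hat{\epsilon}_\theta(z_{t_{n+k}}, w, c, t_{n+k})}{\alpha_{t_{n+k}}}

f_\theta(z_{t_{n+k}}, w, c, t_{n+k}) = c_{skip}(t_{n+k})\, z_{t_{n+k}} + c_{out}(t_{n+k})\, \hat{z}_0
```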


# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
    with torch.autocast("cuda"):
        cond_teacher_output = teacher_unet(
            noisy_model_input.to(weight_dtype),
            start_timesteps,
            encoder_hidden_states=prompt_embeds.to(weight_dtype),
        ).sample
        cond_pred_x0 = predicted_origin(
            cond_teacher_output,
            start_timesteps,
            noisy_model_input,
            noise_scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        )

        # Get teacher model prediction on noisy_latents and unconditional embedding
        uncond_teacher_output = teacher_unet(
            noisy_model_input.to(weight_dtype),
            start_timesteps,
            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
        ).sample
        uncond_pred_x0 = predicted_origin(
            uncond_teacher_output,
            start_timesteps,
            noisy_model_input,
            noise_scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        )

        # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
        pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
        x_prev = solver.ddim_step(pred_x0, pred_noise, index)        

The augmented PF-ODE solver estimates $\hat{z}^{\psi,w}_{t_n}$, the point k steps earlier on the trajectory. The code uses the DDIMSolver defined above; the formula may look different, but the idea is the same: both aim to obtain $\hat{z}^{\psi,w}_{t_n}$.
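
In fact, because the DDIM step is linear in the predicted x0 and the predicted noise, the code's "CFG first, then one DDIM step" gives the same result as the form I read in Algorithm 1 of the LCM paper, z + (1 + w) * Psi(cond) - w * Psi(uncond). The scalar check below is a sketch with made-up numbers standing in for one latent entry.

```python
import math


def ddim_step(pred_x0, pred_noise, alpha_cumprod_prev):
    # Deterministic DDIM update (sigma_t = 0), as in DDIMSolver.ddim_step.
    return math.sqrt(alpha_cumprod_prev) * pred_x0 + math.sqrt(1.0 - alpha_cumprod_prev) * pred_noise


# Made-up scalar stand-ins.
z = 0.7                              # z_{t_{n+k}}
a_t, a_prev = 0.30, 0.45             # alphas_cumprod at t_{n+k} and at t_n
w = 7.5                              # guidance scale
eps_cond, eps_uncond = 0.20, -0.10   # teacher outputs with / without the prompt


def pred_x0(eps):
    # Epsilon-prediction branch of predicted_origin.
    return (z - math.sqrt(1.0 - a_t) * eps) / math.sqrt(a_t)


# What the training code does: CFG on x0 and on the noise, then a single DDIM step.
x0 = pred_x0(eps_cond) + w * (pred_x0(eps_cond) - pred_x0(eps_uncond))
eps = eps_cond + w * (eps_cond - eps_uncond)
x_prev_code = ddim_step(x0, eps, a_prev)

# Augmented PF-ODE form, with Psi defined as (DDIM step result - z).
psi_cond = ddim_step(pred_x0(eps_cond), eps_cond, a_prev) - z
psi_uncond = ddim_step(pred_x0(eps_uncond), eps_uncond, a_prev) - z
x_prev_paper = z + (1.0 + w) * psi_cond - w * psi_uncond

print(x_prev_code, x_prev_paper)
```

The two numbers agree up to floating-point rounding.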


                # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
                with torch.no_grad():
                    with torch.autocast("cuda", dtype=weight_dtype):
                        target_noise_pred = target_unet(
                            x_prev.float(),
                            timesteps,
                            timestep_cond=w_embedding,
                            encoder_hidden_states=prompt_embeds.float(),
                        ).sample
                    pred_x_0 = predicted_origin(
                        target_noise_pred,
                        timesteps,
                        x_prev,
                        noise_scheduler.config.prediction_type,
                        alpha_schedule,
                        sigma_schedule,
                    )
                    target = c_skip * x_prev + c_out * pred_x_0

                # 20.4.13. Calculate loss
                if args.loss_type == "l2":
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                elif args.loss_type == "huber":
                    loss = torch.mean(
                        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
                    )

                # 20.4.14. Backpropagate on the online student model (`unet`)
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

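target in 20.4.12 corresponds to the target LCM output $f_{\theta^-}(\hat{z}^{\psi,w}_{t_n}, w, c, t_n)$, computed by target_unet on x_prev.

After that come the usual loss computation and backpropagation, which update the parameters of the online student model.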
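
The huber branch in 20.4.13 is a pseudo-Huber-style loss, sqrt(d^2 + c^2) - c, rather than the piecewise Huber loss. Here is a plain-Python sketch of how it behaves per element; the value of c is a made-up stand-in for args.huber_c.

```python
import math


def pseudo_huber(diff, c):
    # Elementwise form of the "huber" branch: sqrt(d^2 + c^2) - c
    return math.sqrt(diff * diff + c * c) - c


c = 0.001  # hypothetical value for args.huber_c
for d in (0.0, 1e-4, 1e-2, 1.0):
    print(f"d={d:g}  l2={d * d:.3e}  huber={pseudo_huber(d, c):.3e}")
```

For |d| much smaller than c it behaves like d^2 / (2c), and for |d| much larger than c like |d| - c, so large residuals are penalized linearly instead of quadratically.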

The target student model parameters are updated on the fourth line of the code below (the update_ema call); the rest of that code handles checkpoint saving and logger output.

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                # 20.4.15. Make EMA update to target student model parameters
                update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    if global_step % args.checkpointing_steps == 0:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

                    if global_step % args.validation_steps == 0:
                        log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
                        log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        unet.save_pretrained(os.path.join(args.output_dir, "unet"))

        target_unet = accelerator.unwrap_model(target_unet)
        target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))

    accelerator.end_training()
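
update_ema itself is not shown in this post. A minimal sketch of an in-place EMA update of this kind is given below; the function name update_ema_sketch and the tiny Linear layers standing in for the two U-Nets are hypothetical, and the real helper in the script may differ in details.

```python
import torch


@torch.no_grad()
def update_ema_sketch(target_params, source_params, rate=0.99):
    """In-place EMA: target <- rate * target + (1 - rate) * source."""
    for targ, src in zip(target_params, source_params):
        targ.mul_(rate).add_(src, alpha=1.0 - rate)


# Tiny stand-ins for the online and target U-Nets.
online = torch.nn.Linear(2, 2)
target = torch.nn.Linear(2, 2)
target.load_state_dict(online.state_dict())
target.requires_grad_(False)

# Pretend the optimizer moved every online weight by +1.
with torch.no_grad():
    for p in online.parameters():
        p.add_(1.0)

before = [p.clone() for p in target.parameters()]
update_ema_sketch(target.parameters(), online.parameters(), rate=0.95)
```

With rate = 0.95, each target weight moves 5% of the way toward the online weight, so here every entry of the target shifts by 0.05.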

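Since this code needs your own dataset to run (the paper's authors seem to have already been in touch with HF about creating a dataset it can run on), this post can only connect the theory and the code at the level of logic. If a usable dataset becomes available later, I may make a video to go through it again.

I am also considering writing some posts on diffusion model basics (DDPM, DDIM, score-based generative models, etc.), depending on my free time and on how readers respond.

Finally, thanks to the authors of the related papers for their excellent work!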
