This post is mainly a personal study note that I am sharing; if anything here is incorrect, corrections and discussion are welcome!
It walks through the theory and the code side by side, so it is best read after you have read the papers.
Related papers:
- Consistency Models
- Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference
- Denoising Diffusion Implicit Models
The code analyzed in this post is latent-consistency-model/LCM_Training_Script/consistency_distillation/train_lcm_distill_sd_wds.py.

The training settings given by the authors are as follows:

0. Pre-training setup (logging, accelerator, etc.)
python
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name,
exist_ok=True,
token=args.hub_token,
private=True,
).repo_id
1. Noise scheduler
python
# 1. Create the noise scheduler and the desired noise schedule.
noise_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
# pretrained_teacher_model:"runwayml/stable-diffusion-v1-5"
# The scheduler calculates the alpha and sigma schedule for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
solver = DDIMSolver(
noise_scheduler.alphas_cumprod.numpy(),
timesteps=noise_scheduler.config.num_train_timesteps,
ddim_timesteps=args.num_ddim_timesteps,
)
Calling DDPMScheduler in a Kaggle notebook:

The schedule used is scaled_linear. Unlike the linear schedule in DDPM, it generates 1000 (num_train_timesteps) evenly spaced values between $\sqrt{\beta_{start}}$ and $\sqrt{\beta_{end}}$ and then squares each value; the first and last values match the linear schedule, but the values in between differ.

The alphas_cumprod we need is defined by

$$\alpha_t = 1-\beta_t, \qquad \bar\alpha_t=\prod_{i=0}^{t}\alpha_i$$
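A minimal runnable sketch of both schedules and of alphas_cumprod (beta_start=0.00085 and beta_end=0.012 are, to the best of my knowledge, the SD-1.5 scheduler values; treat them as an assumption here):

```python
import torch

# scaled_linear vs. DDPM's linear schedule (assumed SD-1.5 beta range)
beta_start, beta_end, N = 0.00085, 0.012, 1000
linear_betas = torch.linspace(beta_start, beta_end, N)
scaled_linear_betas = torch.linspace(beta_start**0.5, beta_end**0.5, N) ** 2
print(linear_betas[[0, 499, 999]])          # endpoints match...
print(scaled_linear_betas[[0, 499, 999]])   # ...but the middle values differ

# alphas_cumprod is just the running product of alpha_t = 1 - beta_t
alphas = 1.0 - scaled_linear_betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
```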
Visualize the relevant variables:
python
import numpy
import torch
import matplotlib.pyplot as plt
betas = noise_scheduler.betas
alphas = noise_scheduler.alphas
alphas_cumprod = noise_scheduler.alphas_cumprod
num_train_timesteps = noise_scheduler.config.num_train_timesteps
# x-axis: timestep indices
x = torch.arange(1, num_train_timesteps + 1)
# create the subplots
fig, axs = plt.subplots(1, 3, figsize=(16, 6))
# plot betas
axs[0].plot(x, betas.cpu().numpy())
axs[0].set_xlabel('Timestep')
axs[0].set_ylabel('Beta Value')
axs[0].set_title('Beta Values over Timesteps')
axs[0].grid(True)
# plot alphas
axs[1].plot(x, alphas.cpu().numpy())
axs[1].set_xlabel('Timestep')
axs[1].set_ylabel('Alpha Value')
axs[1].set_title('Alpha Values over Timesteps')
axs[1].grid(True)
# plot alphas_cumprod
axs[2].plot(x, alphas_cumprod.cpu().numpy())
axs[2].set_xlabel('Timestep')
axs[2].set_ylabel('Alpha Cumulative Product')
axs[2].set_title('Alpha Cumulative Product over Timesteps')
axs[2].grid(True)
# adjust the spacing between subplots and show the figure
plt.tight_layout()
plt.show()

Next, look at DDIMSolver. This piece of code corresponds to equation (12) of the Denoising Diffusion Implicit Models paper.

Because DDIM is specifically the case $\sigma_t = 0$, read the formula with that condition substituted in.
Also note that the $\alpha_t$ in equation (12) of the DDIM paper corresponds to $\bar\alpha_t$ in DDPM (and in this post)!
Therefore alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) corresponds to the paper's $\sqrt{\alpha_t}$, i.e. $\sqrt{\bar\alpha_t}$.
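For reference, equation (12) of the DDIM paper reads (my transcription, in the paper's notation):

$$x_{t-1}=\sqrt{\alpha_{t-1}}\underbrace{\left(\frac{x_t-\sqrt{1-\alpha_t}\,\epsilon_\theta^{(t)}(x_t)}{\sqrt{\alpha_t}}\right)}_{\text{predicted } x_0}+\underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta^{(t)}(x_t)}_{\text{direction pointing to } x_t}+\sigma_t\epsilon_t$$

With $\sigma_t=0$ the last term vanishes, and ddim_step below implements exactly the two remaining terms, with the predicted $x_0$ and predicted noise supplied by the caller.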
python
class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        # the timesteps used for skip-step sampling:
        # array([ 19,  39,  59,  79,  99, 119, 139, 159, 179, 199, 219, 239, 259,
        #        279, 299, 319, 339, 359, 379, 399, 419, 439, 459, 479, 499, 519,
        #        539, 559, 579, 599, 619, 639, 659, 679, 699, 719, 739, 759, 779,
        #        799, 819, 839, 859, 879, 899, 919, 939, 959, 979, 999])
        # gather the corresponding alpha_cumprods and alpha_cumprods_prev
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise  # the "direction pointing to x_t" term
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt    # alpha_cumprod_prev.sqrt() * pred_x0 is the first term of eq. (12)
        return x_prev
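ddim_step relies on an extract_into_tensor helper that is not shown in this excerpt. As far as I recall, in the script it is the usual gather-and-reshape helper (treat this as a sketch):

```python
import torch

def extract_into_tensor(a, t, x_shape):
    # Gather a[t] for each sample in the batch and reshape to [b, 1, 1, ...]
    # so it broadcasts against a tensor of shape x_shape.
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
```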
2-6. Load the required modules and freeze the ones that do not need training
python
# 2. Load tokenizers from SD-XL checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
)
# 3. Load text encoders from SD-1.5 checkpoint.
# import correct text encoder classes
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
)
# 4. Load VAE from SD-XL checkpoint (or more stable VAE)
vae = AutoencoderKL.from_pretrained(
args.pretrained_teacher_model,
subfolder="vae",
revision=args.teacher_revision,
)
# 5. Load teacher U-Net from SD-XL checkpoint
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
# 6. Freeze teacher vae, text_encoder, and teacher_unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
teacher_unet.requires_grad_(False)
The source comments repeatedly say modules are loaded "from SD-XL checkpoint", which seems odd to me. Shouldn't it be SD-1.5? (They look like comments carried over from the SDXL version of the script.)
8-9. Create the online student U-Net and the target student U-Net
There is no section 7 in the source code (I did not skip it) (doge).

This part of the code corresponds to the target-network definition in the Consistency Models paper, and exists to support the EMA update.
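The relevant definition (my transcription of the EMA/target-network update from the Consistency Models paper):

$$\theta^{-}\leftarrow \operatorname{stopgrad}\big(\mu\theta^{-}+(1-\mu)\theta\big)$$

Here $\theta$ are the online student's weights (unet) and $\theta^{-}$ the target student's weights (target_unet); the script performs this update with update_ema, with $\mu$ being args.ema_decay.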
python
# 8. Create online (`unet`) student U-Nets. This will be updated by the optimizer (e.g. via backpropagation.)
# Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
if teacher_unet.config.time_cond_proj_dim is None:
    teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
unet = UNet2DConditionModel(**teacher_unet.config)
# load teacher_unet weights into unet
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()

# 9. Create target (`ema_unet`) student U-Net parameters. This will be updated via EMA updates (polyak averaging).
# Initialize from unet
target_unet = UNet2DConditionModel(**teacher_unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)

# Check that all trainable models are in full precision
low_precision_error_string = (
    " Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training, copy of the weights should still be float32."
)
if accelerator.unwrap_model(unet).dtype != torch.float32:
    raise ValueError(
        # f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
    )
10-11. Align precision and devices; handle saving and loading of checkpoints
python
# 10. Handle mixed precision and device placement
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
vae.to(accelerator.device)
if args.pretrained_vae_model_name_or_path is not None:
vae.to(dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Move teacher_unet to device, optionally cast to weight_dtype
target_unet.to(accelerator.device)
teacher_unet.to(accelerator.device)
if args.cast_teacher_unet:
teacher_unet.to(dtype=weight_dtype)
# Also move the alpha and sigma noise schedules to accelerator.device.
alpha_schedule = alpha_schedule.to(accelerator.device)
sigma_schedule = sigma_schedule.to(accelerator.device)
solver = solver.to(accelerator.device)
# 11. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
for i, model in enumerate(models):
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
target_unet.load_state_dict(load_model.state_dict())
target_unet.to(accelerator.device)
del load_model
for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
12. Optional optimizations
python
# 12. Enable optimizations
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
teacher_unet.enable_xformers_memory_efficient_attention()
target_unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
12. Create the optimizer and the dataset (so this is where the numbering goes wrong)
python
optimizer = optimizer_class(
unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
# convert the text prompts into embeddings
def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
return {"prompt_embeds": prompt_embeds}
dataset = Text2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
global_batch_size=args.train_batch_size * accelerator.num_processes,
num_workers=args.dataloader_num_workers,
resolution=args.resolution,
shuffle_buffer_size=1000,
pin_memory=True,
persistent_workers=True,
)
train_dataloader = dataset.train_dataloader
compute_embeddings_fn = functools.partial(
compute_embeddings,
proportion_empty_prompts=0,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps,
)
# Prepare everything with our `accelerator`.
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = dict(vars(args))
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
# the unconditional embeddings used for classifier-free guidance
uncond_input_ids = tokenizer(
[""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
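The encode_prompt helper called inside compute_embeddings above is not shown in this excerpt. Roughly, it tokenizes the captions (dropping a proportion_empty_prompts fraction of them to "" so the model also sees the unconditional branch) and runs the CLIP text encoder. A simplified sketch, not the script's exact code:

```python
import random
import torch

def encode_prompt_sketch(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    # Randomly blank out some captions for classifier-free guidance training.
    captions = ["" if random.random() < proportion_empty_prompts else c for c in prompt_batch]
    text_inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        prompt_embeds = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
    return prompt_embeds  # [b, 77, 768] for SD-1.5's CLIP text encoder
```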
Train!!!
Logger output during training, checkpoint saving/loading, and creation of the progress_bar.
python
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
The core training code
This section is fairly long, so we go through it piece by piece.
python
for epoch in range(first_epoch, args.num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            image, text, _, _ = batch
            image = image.to(accelerator.device, non_blocking=True)  # [b,c,h,w]
            encoded_text = compute_embeddings_fn(text)  # [b,n1,n2]
            pixel_values = image.to(dtype=weight_dtype)
            if vae.dtype != weight_dtype:
                vae.to(dtype=weight_dtype)

            # encode pixel values with batch size of at most 32
            # (encode at most 32 images at a time)
            latents = []
            for i in range(0, pixel_values.shape[0], 32):
                latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
            latents = torch.cat(latents, dim=0)
            latents = latents * vae.config.scaling_factor
            latents = latents.to(weight_dtype)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
            topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
            index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
            start_timesteps = solver.ddim_timesteps[index]
            timesteps = start_timesteps - topk
            timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
The first half encodes the incoming images (shape [b,c,h,w]) into latents with the VAE encoder and multiplies them by the scaling factor scaling_factor; at the same time, the corresponding text prompts are passed through CLIP to become text embeddings (shape [b,n1,n2], where n1 is the maximum number of tokens and n2 the embedding dimension).
The second half randomly samples a batch of timesteps from the skip-step timesteps defined by the DDIMSolver above (listed inside the DDIMSolver class in section 1):

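A quick numerical check of the timestep sampling (assuming the defaults num_train_timesteps=1000 and num_ddim_timesteps=50, so topk = 20):

```python
import numpy as np
import torch

ddim_timesteps = (np.arange(1, 51) * 20).round().astype(np.int64) - 1   # [19, 39, ..., 999]
index = torch.tensor([0, 2, 49])
start_timesteps = torch.from_numpy(ddim_timesteps)[index]   # tensor([ 19,  59, 999])  -> t_{n+k}
timesteps = start_timesteps - 20                             # tensor([ -1,  39, 979])
timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)  # -> t_n, clamped at 0
print(start_timesteps, timesteps)
```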
python
# 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
The functions involved:
python
# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
    c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
    return c_skip, c_out

This corresponds to the boundary-condition scalings on page 25 of the Consistency Models paper (reproduced below). LCM modifies them slightly because its smallest timestep is set to 0, but the purpose is the same: to satisfy the boundary conditions $c_{skip}(0)=1$ and $c_{out}(0)=0$.
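As I recall, the formula in the Consistency Models paper (Appendix C) is

$$c_{\text{skip}}(t)=\frac{\sigma_{\text{data}}^2}{(t-\epsilon)^2+\sigma_{\text{data}}^2},\qquad c_{\text{out}}(t)=\frac{\sigma_{\text{data}}\,(t-\epsilon)}{\sqrt{\sigma_{\text{data}}^2+t^2}}$$

with $\sigma_{\text{data}}=0.5$. LCM sets $\epsilon=0$ and plugs in a rescaled discrete timestep (timestep / 0.1 in the code), which still gives $c_{\text{skip}}(0)=1$ and $c_{\text{out}}(0)=0$.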
python
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]
This simply aligns dimensions so the per-timestep scalars broadcast against the latents.

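A quick usage example (using the append_dims defined above):

```python
import torch

# c_skip has one value per sample; append_dims makes it broadcast against [b, c, h, w] latents.
c_skip = torch.rand(4)               # shape [4]
c_skip = append_dims(c_skip, 4)      # shape [4, 1, 1, 1]
latents = torch.randn(4, 4, 64, 64)
print((c_skip * latents).shape)      # torch.Size([4, 4, 64, 64])
```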
python
# 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min # (w_max - w_min) * [0,1) + w_min = [w_min, w_max)
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
# 20.4.8. Prepare prompt embeds and unet_added_conditions
prompt_embeds = encoded_text.pop("prompt_embeds")

20.4.5 is the forward diffusion (noising) process: noisy_model_input holds latents noised to different levels, i.e. $\sqrt{\bar\alpha_{t_{n+k}}}\,\text{latents} + \sqrt{1-\bar\alpha_{t_{n+k}}}\,\text{noise}$.
20.4.6 randomly samples the guidance scale $w$ from $[w_{min}, w_{max})$ and embeds it into w_embedding (a sketch of guidance_scale_embedding follows below).
20.4.8 assigns the prompt embeddings computed earlier to the variable prompt_embeds.
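guidance_scale_embedding is not shown in the excerpt; from what I remember of the script, it is the standard sinusoidal embedding applied to 1000·w. Treat the following as a sketch rather than the exact implementation:

```python
import torch

def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Sinusoidal embedding of the guidance scale (from memory; may differ in details).
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad if embedding_dim is odd
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb  # shape [len(w), embedding_dim]
```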
python
# 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
noise_pred = unet(
    noisy_model_input,
    start_timesteps,
    timestep_cond=w_embedding,
    encoder_hidden_states=prompt_embeds.float(),
    added_cond_kwargs=encoded_text,
).sample
pred_x_0 = predicted_origin(
    noise_pred,
    start_timesteps,
    noisy_model_input,
    noise_scheduler.config.prediction_type,
    alpha_schedule,
    sigma_schedule,
)
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0  # corresponds to eq. (27)
The function involved:
python
# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = (sample - sigmas * model_output) / alphas  # corresponds to eq. (28)
    elif prediction_type == "v_prediction":
        pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
    else:
        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
    return pred_x_0

The figures in the LCM paper illustrating this parameterization may make it easier to follow.

The final model_pred corresponds to the student's consistency-function output, written out below.

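In the LCM paper's $\epsilon$-prediction parameterization (my transcription), the consistency function is

$$f_\theta(z,w,c,t)=c_{\text{skip}}(t)\,z+c_{\text{out}}(t)\,\underbrace{\frac{z-\sigma_t\,\hat\epsilon_\theta(z,w,c,t)}{\alpha_t}}_{\texttt{pred\_x\_0}}$$

so model_pred above is $f_\theta\!\left(z_{t_{n+k}},w,c,t_{n+k}\right)$, with $\alpha_t=\sqrt{\bar\alpha_t}$ and $\sigma_t=\sqrt{1-\bar\alpha_t}$ (alpha_schedule and sigma_schedule).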
python
# 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
# noisy_latents with both the conditioning embedding c and unconditional embedding 0
# Get teacher model prediction on noisy_latents and conditional embedding
with torch.no_grad():
    with torch.autocast("cuda"):
        cond_teacher_output = teacher_unet(
            noisy_model_input.to(weight_dtype),
            start_timesteps,
            encoder_hidden_states=prompt_embeds.to(weight_dtype),
        ).sample
        cond_pred_x0 = predicted_origin(
            cond_teacher_output,
            start_timesteps,
            noisy_model_input,
            noise_scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        )

        # Get teacher model prediction on noisy_latents and unconditional embedding
        uncond_teacher_output = teacher_unet(
            noisy_model_input.to(weight_dtype),
            start_timesteps,
            encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
        ).sample
        uncond_pred_x0 = predicted_origin(
            uncond_teacher_output,
            start_timesteps,
            noisy_model_input,
            noise_scheduler.config.prediction_type,
            alpha_schedule,
            sigma_schedule,
        )

        # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
        pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
        pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
        x_prev = solver.ddim_step(pred_x0, pred_noise, index)

The augmented PF-ODE solver is used to estimate $\hat{z}^{\psi,w}_{t_n}$, the latent $k$ steps earlier on the trajectory. The code uses the DDIMSolver defined above; the exact formula may differ from the paper's, but the idea is the same: obtain $\hat{z}^{\psi,w}_{t_n}$. Note that pred_x0 = cond_pred_x0 + w·(cond_pred_x0 − uncond_pred_x0) is just the CFG combination $(1+w)\,\hat x_0^{c}-w\,\hat x_0^{\varnothing}$, and likewise for pred_noise.
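Concretely, solver.ddim_step computes (in this post's notation, with $\sigma=0$):

$$\hat z^{\psi,w}_{t_n}=\sqrt{\bar\alpha_{t_n}}\,\hat x_0+\sqrt{1-\bar\alpha_{t_n}}\,\hat\epsilon$$

where $\hat x_0$ is pred_x0 and $\hat\epsilon$ is pred_noise, the CFG-combined teacher predictions above.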
python
# 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
    with torch.autocast("cuda", dtype=weight_dtype):
        target_noise_pred = target_unet(
            x_prev.float(),
            timesteps,
            timestep_cond=w_embedding,
            encoder_hidden_states=prompt_embeds.float(),
        ).sample
    pred_x_0 = predicted_origin(
        target_noise_pred,
        timesteps,
        x_prev,
        noise_scheduler.config.prediction_type,
        alpha_schedule,
        sigma_schedule,
    )
    target = c_skip * x_prev + c_out * pred_x_0

# 20.4.13. Calculate loss
if args.loss_type == "l2":
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
    loss = torch.mean(
        torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
    )

# 20.4.14. Backpropagate on the online student model (`unet`)
accelerator.backward(loss)
if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

The target computed in 20.4.12 is the target network's consistency output $f_{\theta^-}(\hat z^{\psi,w}_{t_n}, w, c, t_n)$.

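Putting it together, the loss in 20.4.13 is (my paraphrase of the LCM consistency-distillation objective):

$$\mathcal{L}=d\Big(\underbrace{f_\theta(z_{t_{n+k}},w,c,t_{n+k})}_{\texttt{model\_pred}},\;\underbrace{f_{\theta^-}(\hat z^{\psi,w}_{t_n},w,c,t_n)}_{\texttt{target}}\Big)$$

where $d$ is either the mean squared error or the pseudo-Huber distance $\sqrt{\lVert\cdot\rVert^2+c^2}-c$, selected by args.loss_type.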
After that comes the standard loss computation and backpropagation to update the online student model's parameters.
The target student model's parameters are updated on the fourth line of the code below (the EMA update; see the update_ema sketch after that block). Everything else in that block handles checkpoint saving and logger output.
python
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
# 20.4.15. Make EMA update to target student model parameters
update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if global_step % args.validation_steps == 0:
log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
unet.save_pretrained(os.path.join(args.output_dir, "unet"))
target_unet = accelerator.unwrap_model(target_unet)
target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
accelerator.end_training()
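For completeness, the update_ema used in step 20.4.15 is, as far as I recall, a plain Polyak average over the two parameter lists (sketch, not the script's exact code):

```python
import torch

@torch.no_grad()
def update_ema(target_params, source_params, rate):
    # theta_target <- rate * theta_target + (1 - rate) * theta_online
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
```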
Since this code needs your own dataset to run (the paper's authors appear to be working with the Hugging Face side on a dataset that runs out of the box), this post can only connect the theory and the code at the logical level; if a usable dataset becomes available later, I may make a video walking through it again.
I am also considering writing some articles on diffusion model fundamentals (DDPM, DDIM, score-based generative models, etc.), depending on my free time and on the response to this post.
Finally, thanks to the authors of these papers for their excellent work!