系列文章目录
- 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
- 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
- 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。
- 【扩散模型(五)】IP-Adapter 源码详解3-推理代码 详细介绍 IP-Adapter 推理过程代码。
- 【可控图像生成系列论文(四)】IP-Adapter 具体是如何训练的?1公式篇
- 本文则以 IP-Adapter Plus 训练代码为例,进行详细介绍。
文章目录
- 系列文章目录
- 整体训练框架
- 一、训了哪些部分?
-
-
- [第一块 - image_proj_model](#第一块 - image_proj_model)
- [第二块 - adapter_modules](#第二块 - adapter_modules)
-
- 二、训练目标
整体训练框架
一、训了哪些部分?
本文以原仓库 ^1^ 的 /path/IP-Adapter/tutorial_train_plus.py 为例,该文件为 SD1.5 IP-Adapter Plus 的训练代码。
从以下代码可以看出,IPAdapter 主要由 unet, image_proj_model, adapter_modules 3 个部分组成,而权重需要被优化的(训练到的)只有 ip_adapter.image_proj_model
.parameters(), 和 ip_adapter.adapter_modules
.parameters() 。
python
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
# optimizer
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
...
# Prepare everything with our `accelerator`.
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
第一块 - image_proj_model
在 IP-Adapter Plus 中,采用的是 Resampler 作为img embedding 到 ip_tokens 的映射网络,对图像(image prompt)中信息的抽取更加细粒度。其他模块都不需要梯度下降,如下代码所示。
python
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
#ip-adapter-plus
image_proj_model = Resampler(
dim=unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=args.num_tokens,
embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim,
ff_mult=4
)
...
第二块 - adapter_modules
Decoupled cross-attention 则在以下代码中进行初始化,关键是在特定的 unet 层中进行替换,详细位置可以参考前文中的图片,本文的重点是后续训练的实现。
python
# init adapter modules
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens)
attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
二、训练目标
每个 epoch 是遍历完一整个 dataset,我们直接从每个训练步的循环中来看:
- latents 是通过 vae 将输入的 image prompt 压到了隐空间(latent space)中。
- 准备相应的 noise 和 timesteps ,再通过 noise_scheduler 来制作出 noisy_latents。
python
for step, batch in enumerate(train_dataloader):
load_data_time = time.perf_counter() - begin
with accelerator.accumulate(ip_adapter):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
- clip_images 和 drop_image_embed 是在准备数据的过程中,做了一个随机 drop 的方式进行数据增强,提升模型鲁棒性。
- 数据增强:通过随机丢弃一些图像,模型被迫学习从剩余的图像中提取信息,这可以增加模型的泛化能力。
- 模型鲁棒性:训练模型以处理不完整的数据,使其在实际应用中对缺失数据更加鲁棒。
python
clip_images = []
for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
if drop_image_embed == 1:
clip_images.append(torch.zeros_like(clip_image))
else:
clip_images.append(clip_image)
clip_images = torch.stack(clip_images, dim=0)
with torch.no_grad():
image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]
with torch.no_grad():
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]