第五章:计算机视觉-项目实战之生成式算法实战:扩散模型
第三部分:生成式算法实战:扩散模型
第四节:在新数据集上微调现有扩散模型
一、章节导读
前一节中,我们从零开始训练了一个 DDPM 扩散模型 ,理解了其核心原理与训练过程。
但在真实项目中,我们往往不希望从头训练一个模型------
原因包括:
-
数据集有限,难以支撑从零开始训练;
-
训练一个高质量的扩散模型需要极高的算力成本;
-
我们更希望在已有模型的基础上微调,让它快速适应特定领域。
本节将系统讲解如何在已有扩散模型(如 Stable Diffusion、DDPM、DreamBooth 模型等)上进行微调(Fine-tuning)。
二、为什么需要微调扩散模型
扩散模型(Diffusion Models)如 Stable Diffusion、Imagen、DALLE 2 等,通常在大规模通用数据上训练。
这些模型能生成"漂亮的图像",但它们并不了解特定领域。
比如:
-
你希望模型学会生成 企业专属产品风格图像;
-
你想让它生成 特定人物/角色/艺术风格;
-
或者你在做科研项目,只需要它理解一个 小领域数据集。
这时,微调(Fine-tuning)就是最经济、高效的选择。
三、微调思路与方法分类
(1)全模型微调(Full Fine-tuning)
直接对整个扩散模型的参数进行重新训练。
优点是灵活、无约束,缺点是显存与时间消耗极高。
(2)LoRA 微调(Low-Rank Adaptation)
仅在模型的部分层中引入低秩矩阵进行学习。
它的核心思想是:
在不修改原始模型参数的前提下,用小规模附加权重实现新的特征学习。
优点:
-
显存占用低
-
微调速度快
-
可与多个任务组合(可"叠加"多个LoRA)
(3)DreamBooth 微调
由 Google 提出,用于生成特定人物或物体。
DreamBooth 的做法是:
-
在已有的 Stable Diffusion 模型上;
-
使用少量目标样本(如一个人10张照片);
-
学习一个与该目标绑定的"特殊标识词",例如
a photo of <my_person>
; -
从而在生成时"召唤出"该对象。
四、微调准备
1.环境依赖
以 Hugging Face Diffusers + PyTorch 为例:
python
pip install diffusers transformers accelerate safetensors datasets
2.加载预训练模型
python
from diffusers import DDPMPipeline, UNet2DModel
model_id = "google/ddpm-cifar10-32"
pipeline = DDPMPipeline.from_pretrained(model_id)
或使用 Stable Diffusion:
python
from diffusers import StableDiffusionPipeline
model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
3.准备新数据集
数据集要求:
-
与模型原始训练输入尺寸一致(如 512×512);
-
保证多样性;
-
可以通过 DataLoader 加载:
python
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob
class MyCustomDataset(Dataset):
def __init__(self, path, transform):
self.files = glob.glob(path + "/*.jpg")
self.transform = transform
def __len__(self):
return len(self.files)
def __getitem__(self, i):
img = Image.open(self.files[i]).convert("RGB")
return self.transform(img)
dataset = MyCustomDataset("./my_data", transform=your_transforms)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
五、LoRA微调实现流程(推荐)
我们重点介绍 Stable Diffusion + LoRA 微调流程。
Step 1:定义 LoRA 模型
python
from diffusers import UNet2DConditionModel
from peft import LoraConfig, get_peft_model
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_config = LoraConfig(
r=4, # rank
lora_alpha=16,
target_modules=["to_q", "to_k", "to_v"], # self-attention部分
lora_dropout=0.1,
)
unet = get_peft_model(unet, lora_config)
Step 2:冻结主干,启用LoRA参数训练
python
for name, param in unet.named_parameters():
if "lora" not in name:
param.requires_grad = False
Step 3:训练循环
python
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
for epoch in range(num_epochs):
for imgs in dataloader:
imgs = imgs.to("cuda")
noise = torch.randn_like(imgs)
timesteps = torch.randint(0, 1000, (imgs.shape[0],), device="cuda")
noisy_imgs = pipeline.scheduler.add_noise(imgs, noise, timesteps)
noise_pred = unet(noisy_imgs, timesteps, encoder_hidden_states=None).sample
loss = torch.nn.functional.mse_loss(noise_pred, noise)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch [{epoch}] Loss: {loss.item():.4f}")
Step 4:保存权重
python
unet.save_pretrained("./lora_diffusion_finetuned")
六、模型推理与效果对比
微调完成后,我们可以加载并生成图像:
python
from diffusers import StableDiffusionPipeline
from peft import PeftModel
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.unet = PeftModel.from_pretrained(pipe.unet, "./lora_diffusion_finetuned")
prompt = "A photo of a futuristic electric car in neon city"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("finetuned_output.png")
效果对比表:
模型 | 数据集 | 样本数量 | 训练时间 | 生成效果 |
---|---|---|---|---|
Stable Diffusion 原模型 | 通用大数据集 | 无 | 无 | 泛化好,但风格偏通用 |
微调后模型(LoRA) | 特定产品图 | 500 | 2小时 | 明显贴合目标风格 |
DreamBooth 微调模型 | 特定人物 | 20 | 1小时 | 能复现人物特征 |
七、训练曲线与指标
微调过程中常用指标:
-
Loss 曲线:反映收敛性;
-
FID(Fréchet Inception Distance):衡量生成图像与真实样本的差异;
-
PSNR/SSIM(超分辨率任务)
训练完成后,你可以绘制如下曲线图以分析效果:
python
import matplotlib.pyplot as plt
plt.plot(loss_list)
plt.title("LoRA Fine-tuning Loss Curve")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.show()
八、微调中的优化技巧
技巧 | 说明 |
---|---|
EMA 平滑参数 | 提升生成稳定性 |
渐进式学习率衰减 | 训练早期快速收敛 |
Prompt Engineering | 微调中精心设计提示词能引导模型更好学习目标风格 |
混合精度训练(AMP) | 降低显存占用 |
数据增强 | 对小数据集尤为关键(水平翻转、亮度调整等) |
九、小结
内容 | 核心要点 |
---|---|
微调目标 | 让通用扩散模型适应特定任务或风格 |
微调方法 | 全参数、LoRA、DreamBooth |
实践流程 | 加载预训练模型 → 构建数据集 → 选择微调策略 → 保存权重 |
优化建议 | 控制学习率、EMA、Prompt优化、数据增强 |
本节收获
通过本节,你已经掌握了:
-
如何在新数据集上微调现有扩散模型;
-
如何使用LoRA等轻量化方法快速适配领域任务;
-
如何评估与验证微调效果。