【第五章:计算机视觉-项目实战之生成式算法实战:扩散模型】3.生成式算法实战:扩散模型-(4)在新数据集上微调现有扩散模型

第五章:计算机视觉-项目实战之生成式算法实战:扩散模型

第三部分:生成式算法实战:扩散模型

第四节:在新数据集上微调现有扩散模型


一、章节导读

前一节中,我们从零开始训练了一个 DDPM 扩散模型 ,理解了其核心原理与训练过程。

但在真实项目中,我们往往不希望从头训练一个模型------

原因包括:

  • 数据集有限,难以支撑从零开始训练;

  • 训练一个高质量的扩散模型需要极高的算力成本;

  • 我们更希望在已有模型的基础上微调,让它快速适应特定领域。

本节将系统讲解如何在已有扩散模型(如 Stable Diffusion、DDPM、DreamBooth 模型等)上进行微调(Fine-tuning)


二、为什么需要微调扩散模型

扩散模型(Diffusion Models)如 Stable Diffusion、Imagen、DALLE 2 等,通常在大规模通用数据上训练。

这些模型能生成"漂亮的图像",但它们并不了解特定领域。

比如:

  • 你希望模型学会生成 企业专属产品风格图像

  • 你想让它生成 特定人物/角色/艺术风格

  • 或者你在做科研项目,只需要它理解一个 小领域数据集

这时,微调(Fine-tuning)就是最经济、高效的选择。


三、微调思路与方法分类

(1)全模型微调(Full Fine-tuning)

直接对整个扩散模型的参数进行重新训练。

优点是灵活、无约束,缺点是显存与时间消耗极高。

(2)LoRA 微调(Low-Rank Adaptation)

仅在模型的部分层中引入低秩矩阵进行学习。

它的核心思想是:

在不修改原始模型参数的前提下,用小规模附加权重实现新的特征学习。

优点:

  • 显存占用低

  • 微调速度快

  • 可与多个任务组合(可"叠加"多个LoRA)

(3)DreamBooth 微调

由 Google 提出,用于生成特定人物或物体。

DreamBooth 的做法是:

  • 在已有的 Stable Diffusion 模型上;

  • 使用少量目标样本(如一个人10张照片);

  • 学习一个与该目标绑定的"特殊标识词",例如 a photo of <my_person>

  • 从而在生成时"召唤出"该对象。


四、微调准备

1.环境依赖

Hugging Face Diffusers + PyTorch 为例:

python 复制代码
pip install diffusers transformers accelerate safetensors datasets

2.加载预训练模型

python 复制代码
from diffusers import DDPMPipeline, UNet2DModel

model_id = "google/ddpm-cifar10-32"
pipeline = DDPMPipeline.from_pretrained(model_id)

或使用 Stable Diffusion:

python 复制代码
from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

3.准备新数据集

数据集要求:

  • 与模型原始训练输入尺寸一致(如 512×512);

  • 保证多样性;

  • 可以通过 DataLoader 加载:

python 复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob

class MyCustomDataset(Dataset):
    def __init__(self, path, transform):
        self.files = glob.glob(path + "/*.jpg")
        self.transform = transform
    def __len__(self):
        return len(self.files)
    def __getitem__(self, i):
        img = Image.open(self.files[i]).convert("RGB")
        return self.transform(img)

dataset = MyCustomDataset("./my_data", transform=your_transforms)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

五、LoRA微调实现流程(推荐)

我们重点介绍 Stable Diffusion + LoRA 微调流程

Step 1:定义 LoRA 模型

python 复制代码
from diffusers import UNet2DConditionModel
from peft import LoraConfig, get_peft_model

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_config = LoraConfig(
    r=4,               # rank
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v"],  # self-attention部分
    lora_dropout=0.1,
)
unet = get_peft_model(unet, lora_config)

Step 2:冻结主干,启用LoRA参数训练

python 复制代码
for name, param in unet.named_parameters():
    if "lora" not in name:
        param.requires_grad = False

Step 3:训练循环

python 复制代码
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for imgs in dataloader:
        imgs = imgs.to("cuda")

        noise = torch.randn_like(imgs)
        timesteps = torch.randint(0, 1000, (imgs.shape[0],), device="cuda")
        noisy_imgs = pipeline.scheduler.add_noise(imgs, noise, timesteps)

        noise_pred = unet(noisy_imgs, timesteps, encoder_hidden_states=None).sample
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch}] Loss: {loss.item():.4f}")

Step 4:保存权重

python 复制代码
unet.save_pretrained("./lora_diffusion_finetuned")

六、模型推理与效果对比

微调完成后,我们可以加载并生成图像:

python 复制代码
from diffusers import StableDiffusionPipeline
from peft import PeftModel

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.unet = PeftModel.from_pretrained(pipe.unet, "./lora_diffusion_finetuned")

prompt = "A photo of a futuristic electric car in neon city"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("finetuned_output.png")

效果对比表:

模型 数据集 样本数量 训练时间 生成效果
Stable Diffusion 原模型 通用大数据集 泛化好,但风格偏通用
微调后模型(LoRA) 特定产品图 500 2小时 明显贴合目标风格
DreamBooth 微调模型 特定人物 20 1小时 能复现人物特征

七、训练曲线与指标

微调过程中常用指标:

  • Loss 曲线:反映收敛性;

  • FID(Fréchet Inception Distance):衡量生成图像与真实样本的差异;

  • PSNR/SSIM(超分辨率任务)

训练完成后,你可以绘制如下曲线图以分析效果:

python 复制代码
import matplotlib.pyplot as plt
plt.plot(loss_list)
plt.title("LoRA Fine-tuning Loss Curve")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.show()

八、微调中的优化技巧

技巧 说明
EMA 平滑参数 提升生成稳定性
渐进式学习率衰减 训练早期快速收敛
Prompt Engineering 微调中精心设计提示词能引导模型更好学习目标风格
混合精度训练(AMP) 降低显存占用
数据增强 对小数据集尤为关键(水平翻转、亮度调整等)

九、小结

内容 核心要点
微调目标 让通用扩散模型适应特定任务或风格
微调方法 全参数、LoRA、DreamBooth
实践流程 加载预训练模型 → 构建数据集 → 选择微调策略 → 保存权重
优化建议 控制学习率、EMA、Prompt优化、数据增强

本节收获

通过本节,你已经掌握了:

  • 如何在新数据集上微调现有扩散模型

  • 如何使用LoRA等轻量化方法快速适配领域任务

  • 如何评估与验证微调效果

相关推荐
嵌入式-老费3 小时前
Easyx图形库使用(潜力无限的图像处理)
图像处理·人工智能
JXY_AI3 小时前
AI问答与搜索引擎:信息获取的现状
人工智能·搜索引擎
B站_计算机毕业设计之家3 小时前
Python+Flask+Prophet 汽车之家二手车系统 逻辑回归 二手车推荐系统 机器学习(逻辑回归+Echarts 源码+文档)✅
大数据·人工智能·python·机器学习·数据分析·汽车·大屏端
XXX-X-XXJ3 小时前
三、从 MinIO 存储到 OCR 提取,再到向量索引生成
人工智能·后端·python·ocr
AI人工智能+3 小时前
行驶证识别技术通过OCR和AI实现信息自动化采集与处理,涵盖图像预处理、文字识别及结构化校验,提升效率与准确性
人工智能·深度学习·ocr·行驶证识别
EkihzniY3 小时前
医疗发票 OCR 识别:打通医疗费用处理 “堵点” 的技术助手
大数据·人工智能·ocr
慷仔4 小时前
游戏编程模式-享元模式(Flyweight)
人工智能·游戏·享元模式
dlraba8024 小时前
Pandas:机器学习数据处理的核心利器
人工智能·机器学习·pandas
m0_677034354 小时前
机器学习-推荐系统(上)
人工智能·机器学习