【第五章:计算机视觉-项目实战之生成式算法实战:扩散模型】3.生成式算法实战:扩散模型-(4)在新数据集上微调现有扩散模型

第五章:计算机视觉-项目实战之生成式算法实战:扩散模型

第三部分:生成式算法实战:扩散模型

第四节:在新数据集上微调现有扩散模型


一、章节导读

前一节中,我们从零开始训练了一个 DDPM 扩散模型 ,理解了其核心原理与训练过程。

但在真实项目中,我们往往不希望从头训练一个模型------

原因包括:

  • 数据集有限,难以支撑从零开始训练;

  • 训练一个高质量的扩散模型需要极高的算力成本;

  • 我们更希望在已有模型的基础上微调,让它快速适应特定领域。

本节将系统讲解如何在已有扩散模型(如 Stable Diffusion、DDPM、DreamBooth 模型等)上进行微调(Fine-tuning)


二、为什么需要微调扩散模型

扩散模型(Diffusion Models)如 Stable Diffusion、Imagen、DALLE 2 等,通常在大规模通用数据上训练。

这些模型能生成"漂亮的图像",但它们并不了解特定领域。

比如:

  • 你希望模型学会生成 企业专属产品风格图像

  • 你想让它生成 特定人物/角色/艺术风格

  • 或者你在做科研项目,只需要它理解一个 小领域数据集

这时,微调(Fine-tuning)就是最经济、高效的选择。


三、微调思路与方法分类

(1)全模型微调(Full Fine-tuning)

直接对整个扩散模型的参数进行重新训练。

优点是灵活、无约束,缺点是显存与时间消耗极高。

(2)LoRA 微调(Low-Rank Adaptation)

仅在模型的部分层中引入低秩矩阵进行学习。

它的核心思想是:

在不修改原始模型参数的前提下,用小规模附加权重实现新的特征学习。

优点:

  • 显存占用低

  • 微调速度快

  • 可与多个任务组合(可"叠加"多个LoRA)

(3)DreamBooth 微调

由 Google 提出,用于生成特定人物或物体。

DreamBooth 的做法是:

  • 在已有的 Stable Diffusion 模型上;

  • 使用少量目标样本(如一个人10张照片);

  • 学习一个与该目标绑定的"特殊标识词",例如 a photo of <my_person>

  • 从而在生成时"召唤出"该对象。


四、微调准备

1.环境依赖

Hugging Face Diffusers + PyTorch 为例:

python 复制代码
pip install diffusers transformers accelerate safetensors datasets

2.加载预训练模型

python 复制代码
from diffusers import DDPMPipeline, UNet2DModel

model_id = "google/ddpm-cifar10-32"
pipeline = DDPMPipeline.from_pretrained(model_id)

或使用 Stable Diffusion:

python 复制代码
from diffusers import StableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")

3.准备新数据集

数据集要求:

  • 与模型原始训练输入尺寸一致(如 512×512);

  • 保证多样性;

  • 可以通过 DataLoader 加载:

python 复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob

class MyCustomDataset(Dataset):
    def __init__(self, path, transform):
        self.files = glob.glob(path + "/*.jpg")
        self.transform = transform
    def __len__(self):
        return len(self.files)
    def __getitem__(self, i):
        img = Image.open(self.files[i]).convert("RGB")
        return self.transform(img)

dataset = MyCustomDataset("./my_data", transform=your_transforms)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

五、LoRA微调实现流程(推荐)

我们重点介绍 Stable Diffusion + LoRA 微调流程

Step 1:定义 LoRA 模型

python 复制代码
from diffusers import UNet2DConditionModel
from peft import LoraConfig, get_peft_model

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
lora_config = LoraConfig(
    r=4,               # rank
    lora_alpha=16,
    target_modules=["to_q", "to_k", "to_v"],  # self-attention部分
    lora_dropout=0.1,
)
unet = get_peft_model(unet, lora_config)

Step 2:冻结主干,启用LoRA参数训练

python 复制代码
for name, param in unet.named_parameters():
    if "lora" not in name:
        param.requires_grad = False

Step 3:训练循环

python 复制代码
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    for imgs in dataloader:
        imgs = imgs.to("cuda")

        noise = torch.randn_like(imgs)
        timesteps = torch.randint(0, 1000, (imgs.shape[0],), device="cuda")
        noisy_imgs = pipeline.scheduler.add_noise(imgs, noise, timesteps)

        noise_pred = unet(noisy_imgs, timesteps, encoder_hidden_states=None).sample
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch}] Loss: {loss.item():.4f}")

Step 4:保存权重

python 复制代码
unet.save_pretrained("./lora_diffusion_finetuned")

六、模型推理与效果对比

微调完成后,我们可以加载并生成图像:

python 复制代码
from diffusers import StableDiffusionPipeline
from peft import PeftModel

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.unet = PeftModel.from_pretrained(pipe.unet, "./lora_diffusion_finetuned")

prompt = "A photo of a futuristic electric car in neon city"
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("finetuned_output.png")

效果对比表:

模型 数据集 样本数量 训练时间 生成效果
Stable Diffusion 原模型 通用大数据集 泛化好,但风格偏通用
微调后模型(LoRA) 特定产品图 500 2小时 明显贴合目标风格
DreamBooth 微调模型 特定人物 20 1小时 能复现人物特征

七、训练曲线与指标

微调过程中常用指标:

  • Loss 曲线:反映收敛性;

  • FID(Fréchet Inception Distance):衡量生成图像与真实样本的差异;

  • PSNR/SSIM(超分辨率任务)

训练完成后,你可以绘制如下曲线图以分析效果:

python 复制代码
import matplotlib.pyplot as plt
plt.plot(loss_list)
plt.title("LoRA Fine-tuning Loss Curve")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.show()

八、微调中的优化技巧

技巧 说明
EMA 平滑参数 提升生成稳定性
渐进式学习率衰减 训练早期快速收敛
Prompt Engineering 微调中精心设计提示词能引导模型更好学习目标风格
混合精度训练(AMP) 降低显存占用
数据增强 对小数据集尤为关键(水平翻转、亮度调整等)

九、小结

内容 核心要点
微调目标 让通用扩散模型适应特定任务或风格
微调方法 全参数、LoRA、DreamBooth
实践流程 加载预训练模型 → 构建数据集 → 选择微调策略 → 保存权重
优化建议 控制学习率、EMA、Prompt优化、数据增强

本节收获

通过本节,你已经掌握了:

  • 如何在新数据集上微调现有扩散模型

  • 如何使用LoRA等轻量化方法快速适配领域任务

  • 如何评估与验证微调效果

相关推荐
DisonTangor6 分钟前
【小米拥抱开源】小米MiMo团队开源309B专家混合模型——MiMo-V2-Flash
人工智能·开源·aigc
hxxjxw21 分钟前
Pytorch分布式训练/多卡训练(六) —— Expert Parallelism (MoE的特殊策略)
人工智能·pytorch·python
Robot侠28 分钟前
视觉语言导航从入门到精通(一)
网络·人工智能·microsoft·llm·vln
掘金一周29 分钟前
【用户行为监控】别只做工具人了!手把手带你写一个前端埋点统计 SDK | 掘金一周 12.18
前端·人工智能·后端
神州问学30 分钟前
世界模型:AI的下一个里程碑
人工智能
zhaodiandiandian32 分钟前
AI深耕产业腹地 新质生产力的实践路径与价值彰显
人工智能
古德new36 分钟前
openFuyao AI大数据场景加速技术实践指南
大数据·人工智能
youcans_44 分钟前
【医学影像 AI】FunBench:评估多模态大语言模型的眼底影像解读能力
论文阅读·人工智能·大语言模型·多模态·眼底图像
dagouaofei1 小时前
PPT AI生成实测报告:哪些工具值得长期使用?
人工智能·python·powerpoint
蓝桉~MLGT1 小时前
Ai-Agent学习历程—— Agent认知框架
人工智能·学习