【Datawhale】扩散模型学习笔记 第一次打卡

文章目录

  • 扩散模型学习笔记
    • [1. 扩散模型库Diffusers](#1. 扩散模型库Diffusers)
      • [1.1 安装](#1.1 安装)
      • [1.2 使用](#1.2 使用)
    • [2. 从零开始搭建扩散模型](#2. 从零开始搭建扩散模型)
      • [2.1 数据准备](#2.1 数据准备)
      • [2.2 损坏过程](#2.2 损坏过程)
      • [2.3 模型构建](#2.3 模型构建)
      • [2.4 模型训练](#2.4 模型训练)
      • [2.5 采样](#2.5 采样)
    • [3. webui](#3. webui)

扩散模型学习笔记

1. 扩散模型库Diffusers

1.1 安装

由于diffusers库更新较快,所以建议时常upgrade

python 复制代码
# pip
pip install --upgrade diffusers[torch]
# conda
conda install -c conda-forge diffusers

1.2 使用

python 复制代码
from diffusers import DiffusionPipeline

generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
generator.to("cuda")
image = generator("An image of a squirrel in Picasso style").images[0]
image.save("image_of_squirrel_painting.png")

2. 从零开始搭建扩散模型

2.1 数据准备

在这个示例中,我们将使用经典的MNIST数据集作为示范。MNIST数据集包含28x28像素的手写数字图像,每个像素值的范围从0到1。

2.2 损坏过程

我们希望能够控制输入数据的损坏程度,因此引入了一个参数 amount,该参数控制了噪声的程度。你可以使用以下方法来添加噪声:

python 复制代码
noise = torch.rand_like(x)
noisy_x = (1 - amount) * x + amount * noise

如果 amount 为0,则输入数据保持不变。如果 amount 为1,输入数据将变为纯粹的噪声。通过混合输入数据和噪声,我们可以确保输出数据的范围仍在0到1之间。

2.3 模型构建

我们将使用UNet模型来处理噪声图像。UNet是一种用于图像分割的常见架构,由压缩路径和扩展路径组成。在这个示范中,我们将构建一个简化版本的UNet,它接收单通道图像,并通过卷积层在下行路径(down_layers)和上行路径(up_layers)之间具有残差连接。我们将使用最大池化进行下采样和 nn.Upsample 进行上采样。

2.4 模型训练

在模型训练过程中,模型的任务是将损坏的输入 noisy_x 转换为对原始图像 x 的最佳估计。我们使用均方误差(MSE)来比较模型的预测与真实值,然后使用反向传播算法来更新模型的参数。

2.5 采样

如果模型在高噪声水平下的预测不够理想,可以进行采样以生成更好的图像。你可以从完全随机的噪声图像开始,然后逐渐接近模型的预测。这意味着你可以检查模型的预测结果,然后只向预测的方向移动一小步,比如向预测值移动20%。这将生成一个具有较少噪声的图像,其中可能包含一些关于输入数据的结构提示。将这个新图像输入模型,希望得到比第一个预测更好的结果。这个过程可以迭代多次,以逐渐减小噪声并生成更好的图像。

这是一个简化的扩散模型搭建和训练的概述。你可以根据具体的问题和数据进行修改和优化,以获得更好的结果。希望这些步骤能帮助你理解如何搭建扩散模型并训练它。

python 复制代码
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import torch
import numpy as np

scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
scheduler.set_timesteps(50)

sample_size = model.config.sample_size
noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
input = noise

for t in scheduler.timesteps:
    with torch.no_grad():
        noisy_residual = model(input, t).sample
        prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
        input = prev_noisy_sample

image = (input / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
image = Image.fromarray((image * 255).round().astype("uint8"))
image

3. webui

参考我的另一篇博客:https://blog.csdn.net/qq_44824148/article/details/130389357

相关推荐
电子云与长程纠缠2 分钟前
Blender入门学习01
学习·blender
qiuiuiu4131 小时前
正点原子RK3568学习日志12-注册字符设备
linux·开发语言·单片机·学习·ubuntu
聪明的笨猪猪2 小时前
Java JVM “内存(1)”面试清单(含超通俗生活案例与深度理解)
java·经验分享·笔记·面试
_dindong2 小时前
Linux网络编程:Socket编程TCP
linux·服务器·网络·笔记·学习·tcp/ip
金士顿2 小时前
ethercat网络拓扑详细学习
学习
知识分享小能手2 小时前
uni-app 入门学习教程,从入门到精通,uni-app组件 —— 知识点详解与实战案例(4)
前端·javascript·学习·微信小程序·小程序·前端框架·uni-app
wahkim2 小时前
Flutter 学习资源及视频
学习
摇滚侠3 小时前
Spring Boot 3零基础教程,WEB 开发 Thymeleaf 属性优先级 行内写法 变量选择 笔记42
java·spring boot·笔记
摇滚侠3 小时前
Spring Boot 3零基础教程,WEB 开发 Thymeleaf 总结 热部署 常用配置 笔记44
java·spring boot·笔记
小白要努力sgy3 小时前
待学习--中间件
学习·中间件