Illustrated Diffusion Models + Code

0. Project video walkthrough

A video tutorial is available on Bilibili: https://www.bilibili.com/video/BV1e8411a7mz

1. Diffusion model theory (deriving the loss function)

1.1 Background

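As AI techniques for image, text and multimodal generation have steadily matured, a whole family of generative models has emerged: generative adversarial networks (GANs), variational autoencoders (VAEs), normalizing flow models, autoregressive (AR) models, energy-based models, and, most recently, the very popular diffusion models.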

Diffusion models did not appear out of nowhere. A similar idea was proposed as early as 2015, and in 2020 it led to the now well-known denoising diffusion probabilistic models (DDPM).

NovelAI's recent image generator is also based on diffusion models; the examples below show how capable it is.

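The results this project can achieve are shown below: samples generated for the label "sunflower" with cfg=7. As you can see, the quality is already fairly good.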

1.2 Training and sampling algorithms

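The figure gives an overview first; sections 1.3 and 1.4 then walk through the process and derive the formulas. Our goal is to derive the loss function used during training.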
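Because the algorithm figure is not reproduced here, the following NumPy sketch outlines the two procedures it summarizes, following the standard DDPM recipe that the Paddle code later in this article implements: Algorithm 1 builds one training example (x_t, t, ε), and Algorithm 2 denoises pure noise step by step. `eps_model` is a hypothetical stand-in for the UNet, and the schedule (500 steps, β from 1e-4 to 0.02) is copied from the Diffusion class below; treat it as an illustration, not the project's code.

```python
import numpy as np

# Schedule copied from the Diffusion class later in the article.
T = 500                                  # number of diffusion steps
beta = np.linspace(1e-4, 0.02, T)        # linear noise schedule beta_t
alpha = 1.0 - beta                       # alpha_t = 1 - beta_t
alpha_hat = np.cumprod(alpha)            # alpha_hat_t = product of alpha_s for s <= t


def training_pair(x0, rng):
    """Algorithm 1: pick a timestep, noise x0 in closed form, return (x_t, t, eps).

    The training loss is mean((eps - eps_model(x_t, t)) ** 2).
    """
    t = rng.integers(1, T)               # same range as sample_timesteps(): [1, T-1]
    eps = rng.standard_normal(x0.shape)  # the noise the network must predict
    x_t = np.sqrt(alpha_hat[t]) * x0 + np.sqrt(1 - alpha_hat[t]) * eps
    return x_t, t, eps


def sample(eps_model, shape, rng):
    """Algorithm 2: start from pure noise and denoise from t = T-1 down to t = 1."""
    x = rng.standard_normal(shape)
    for t in range(T - 1, 0, -1):
        # No fresh noise is injected at the last step.
        z = rng.standard_normal(shape) if t > 1 else np.zeros(shape)
        eps_hat = eps_model(x, t)
        mean = (x - (1 - alpha[t]) / np.sqrt(1 - alpha_hat[t]) * eps_hat) / np.sqrt(alpha[t])
        x = mean + np.sqrt(beta[t]) * z
    return x
```

For example, `sample(lambda x, t: np.zeros_like(x), (2, 3, 4, 4), np.random.default_rng(0))` runs the whole loop with a dummy model; with a trained network in place of the lambda, the result would then be clipped to [-1, 1] and rescaled to pixels as in `Diffusion.sample()`.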

1.3 Deriving the forward noising formula

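The forward process of a diffusion model adds Gaussian noise to the original image step by step until the image is close to a pure Gaussian distribution. The noise variance β_t added at each step also grows with t (in the code below it increases linearly from 1e-4 to 0.02), so later steps add stronger noise. As shown in the figure below: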

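  • The image at each step is obtained by adding noise to the image from the previous step

  • The final image becomes pure noise

  • The noise strength differs from step to step; common schedules include linear and cosine schedules

  • This process produces the targets (labels) we train on, as we will see later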
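The two schedules mentioned in the list can be sketched in NumPy as follows. The linear one matches `prepare_noise_schedule()` in the code below; the cosine one is not used in this project and follows Nichol & Dhariwal (2021), with their offset s = 0.008 and the cap of 0.999 on β_t taken as assumptions. The function names are illustrative.

```python
import numpy as np


def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linear schedule: beta_t grows evenly from beta_start to beta_end."""
    return np.linspace(beta_start, beta_end, T)


def cosine_beta_schedule(T, s=0.008, max_beta=0.999):
    """Cosine schedule: define alpha_hat(t) with a squared cosine, then
    recover beta_t = 1 - alpha_hat(t) / alpha_hat(t - 1)."""
    steps = np.arange(T + 1) / T                      # t / T for t = 0..T
    f = np.cos((steps + s) / (1 + s) * np.pi / 2) ** 2
    alpha_hat = f / f[0]                              # normalized so alpha_hat(0) = 1
    betas = 1 - alpha_hat[1:] / alpha_hat[:-1]
    return np.clip(betas, 0, max_beta)                # avoid beta = 1 at the very end


lin = linear_beta_schedule(500)
cos = cosine_beta_schedule(500)
```

Both return increasing sequences of β_t; the cosine schedule destroys information more slowly at the start, which is its main motivation.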

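The derivation below shows how to obtain the image at step t directly from the initial image, without iterating through every intermediate step.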
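The derivation itself was given as images that did not survive. The standard DDPM version, written with the quantities the code uses (alpha = 1 - beta, alpha_hat = cumprod(alpha)), is reproduced below; the key step is that the sum of two independent zero-mean Gaussians with variances σ₁² and σ₂² is a zero-mean Gaussian with variance σ₁² + σ₂². The code indexes steps from 0, so its alpha_hat[t] is the product up to and including index t.

```latex
\begin{aligned}
q(x_t \mid x_{t-1}) &= \mathcal{N}\!\left(x_t;\ \sqrt{\alpha_t}\,x_{t-1},\ (1-\alpha_t)\,I\right),
\qquad \alpha_t = 1-\beta_t,\quad \bar\alpha_t = \prod_{s=1}^{t}\alpha_s \\
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\epsilon_{t-2} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar\epsilon_{t-2} \\
    &\;\;\vdots \\
    &= \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
```

The third line uses α_t(1 − α_{t−1}) + (1 − α_t) = 1 − α_t α_{t−1}; repeating the merge down to x_0 gives the last line, which is exactly what `noise_images()` computes.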

This formula lays the groundwork for what follows; the next section derives the key result, the loss function.
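A quick numerical check of the closed form (an addition, not part of the original project): track how the coefficient on x_0 and the accumulated noise variance evolve under the step-by-step recursion, and compare them with √ᾱ_t and 1 − ᾱ_t. The schedule is the linear one used in the code below.

```python
import numpy as np

# Linear schedule from the Diffusion class: 500 steps, beta from 1e-4 to 0.02.
T = 500
beta = np.linspace(1e-4, 0.02, T)
alpha = 1.0 - beta
alpha_hat = np.cumprod(alpha)

# Write x_t = coef * x0 + (Gaussian noise with variance var) and apply
# x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * eps one step at a time.
coef, var = 1.0, 0.0                       # before the first step: x is exactly x0
coefs, variances = [], []
for t in range(T):
    coef = np.sqrt(alpha[t]) * coef        # the signal is scaled down
    var = alpha[t] * var + (1.0 - alpha[t])  # old noise is scaled, new noise is added
    coefs.append(coef)
    variances.append(var)

print(np.allclose(coefs, np.sqrt(alpha_hat)))   # True: coefficient equals sqrt(alpha_hat_t)
print(np.allclose(variances, 1.0 - alpha_hat))  # True: variance equals 1 - alpha_hat_t
```

At the last step alpha_hat is well under 0.01, i.e. almost none of the original image remains, which is why sampling can start from pure Gaussian noise.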

1.4 Optimization objective and loss function derivation

The forward diffusion above is not difficult. Next we derive the reverse diffusion process, i.e. going from x_t back to x_{t-1}.
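The original derivation here is also an image. The block below states the standard DDPM result (Ho et al., 2020) in the notation of the previous section, and it agrees with the code: the mean in the fourth line is exactly the update in `Diffusion.sample()` (where 1 − α_t = β_t and the added noise is scaled by √β_t, i.e. σ_t² = β_t), and L_simple is the `mse(noise, predicted_noise)` loss in `train()`. The full variational bound weights each timestep differently; DDPM drops those weights, which gives L_simple.

```latex
\begin{aligned}
q(x_{t-1} \mid x_t, x_0) &= \mathcal{N}\!\left(x_{t-1};\ \tilde\mu_t(x_t, x_0),\ \tilde\beta_t I\right),
\qquad \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t \\
\tilde\mu_t(x_t, x_0) &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t \\
&= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon\right)
\qquad \text{using } x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon}{\sqrt{\bar\alpha_t}} \\
p_\theta(x_{t-1} \mid x_t) &= \mathcal{N}\!\left(x_{t-1};\ \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right),\ \sigma_t^2 I\right),
\qquad \sigma_t^2 = \beta_t \\
L_{\text{simple}} &= \mathbb{E}_{t,\,x_0,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\right)\right\rVert^2\right]
\end{aligned}
```

Since the true posterior mean depends on x_0 only through the noise ε, the network only needs to predict ε; minimizing L_simple is therefore just a regression onto the noise that was added in the forward process.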

2. Unconditional generation (generating random images)

We use images from the Stanford Cars dataset as the example; no class labels are used.

2.1 Training process

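We sample the training targets from the forward process and train a UNet, feeding it an encoding of the timestep as an extra input. This is similar to positional encoding in Transformers and makes the model easier to train. As shown in the figure below: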
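The UNet lives in the project's modules.py, which is not shown, so its exact time encoding is unknown. Assuming it uses the common Transformer-style sinusoidal encoding, a NumPy version looks like this; `timestep_embedding`, the 256-dimensional size and `max_period=10000` are illustrative assumptions. Inside a network such a vector is typically projected and added to the feature maps of each block.

```python
import numpy as np


def timestep_embedding(t, dim=256, max_period=10000.0):
    """Sinusoidal encoding of integer timesteps.

    t: 1-D array of timesteps, shape (batch,). Returns shape (batch, dim):
    the first dim/2 columns are sin(t * w_k), the rest cos(t * w_k), with
    frequencies w_k = max_period ** (-k / (dim/2)) for k = 0 .. dim/2 - 1.
    """
    half = dim // 2
    freqs = max_period ** (-np.arange(half) / half)           # from 1 down to ~1/max_period
    angles = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)


emb = timestep_embedding(np.array([1, 250, 499]))
print(emb.shape)  # (3, 256)
```

Each timestep gets a distinct, smoothly varying vector, so nearby timesteps receive similar encodings, which is what lets one network handle all noise levels.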

2.2 Unzipping the data

Unzip the dataset. This is only needed the first time you run the project!

In [13]

import os
if not os.path.exists("work/cars"):
    !mkdir work/cars
!unzip -oq data/data173302/stanford_cars.zip -d work/cars

In [14]

# Remove files we don't need (test split, devkit, archives, annotations)
!rm -rf work/cars/cars_test
!rm -rf work/cars/devkit
!rm -rf work/cars/car_devkit.tgz
!rm -rf work/cars/cars_train.tgz
!rm -rf work/cars/cars_test.tgz
!rm -rf work/cars/cars_test_annos_withlabels.mat

2.3 Viewing the data

Let's look at some of the car images.

In [1]

import math
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

# Show a list of images in a grid with `cols` columns
def show_images(imgs_paths=None, cols=4):
    imgs_paths = imgs_paths or []
    rows = math.ceil(len(imgs_paths) / cols)   # just enough rows for all images
    plt.figure(figsize=(15, 15))
    for i, img_path in enumerate(imgs_paths):
        img = Image.open(img_path)
        plt.subplot(rows, cols, i + 1)
        plt.imshow(img)

imgs_paths = [
    "work/cars/cars_train/05930.jpg", "work/cars/cars_train/06816.jpg", "work/cars/cars_train/02885.jpg", "work/cars/cars_train/07471.jpg",
    "work/cars/cars_train/06600.jpg", "work/cars/cars_train/06020.jpg", "work/cars/cars_train/04818.jpg", "work/cars/cars_train/06088.jpg"
]
show_images(imgs_paths)
<Figure size 1500x1500 with 8 Axes>

2.4 Building the dataset

Since no labels are needed, the dataset interface in paddle.vision is all we need.

In [2]

import os
import paddle
import paddle.nn as nn
import paddle.vision as V
from PIL import Image
from matplotlib import pyplot as plt
from paddle.io import DataLoader

# We don't need image labels here, so paddle.vision's ImageFolder is enough.
# ImageFolder returns each sample wrapped in a list, [image], which is why
# the training loop indexes the batch with [0].
def get_data(args):
    transforms = V.transforms.Compose([
        V.transforms.Resize(int(args.image_size * 1.25)),             # 80 for image_size = 64
        V.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        V.transforms.ToTensor(),                                      # HWC uint8 -> CHW float in [0, 1]
        V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),     # [0, 1] -> [-1, 1]
    ])
    dataset = V.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader
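The ToTensor + Normalize pair above maps pixels into [-1, 1], the range in which the Gaussian noise is added, and `Diffusion.sample()` later maps model outputs back with `(x.clip(-1, 1) + 1) / 2 * 255`. The NumPy sketch below mirrors both directions on a few pixel values; the helper names are made up for illustration, and it rounds before casting to uint8, whereas the notebook's `astype("uint8")` truncates, which can shift some values down by one.

```python
import numpy as np


def to_model_range(pixels_uint8):
    """uint8 pixels -> floats in [-1, 1], like ToTensor() followed by Normalize(0.5, 0.5)."""
    x = pixels_uint8.astype(np.float32) / 255.0   # ToTensor(): [0, 255] -> [0, 1]
    return (x - 0.5) / 0.5                        # Normalize: [0, 1] -> [-1, 1]


def to_pixels(x):
    """Model-range floats -> uint8 pixels, mirroring the end of Diffusion.sample()."""
    x = (np.clip(x, -1, 1) + 1) / 2               # [-1, 1] -> [0, 1]
    return np.round(x * 255).astype(np.uint8)     # round instead of truncating


img = np.array([[0, 128, 255]], dtype=np.uint8)
restored = to_pixels(to_model_range(img))        # same values as img
```

Clipping matters because a sampled image can overshoot [-1, 1] slightly; without it the uint8 cast would wrap around.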

2.5 Training

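Hyperparameters are set by editing the ARGS class. Once you know that the loss is simply the mean squared error between the noise added in the forward process and the noise predicted by the network, the code is fairly simple. Compared with GANs, diffusion models are generally easier to tune and more stable to train.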

In [3]

"""ddpm"""

import os
import paddle
import paddle.nn as nn
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm
from paddle import optimizer
# from utils import *
from modules import UNet    # 模型
import logging
import numpy as np

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device    # kept for reference only; Paddle uses the current device (see paddle.set_device)

        self.beta = self.prepare_noise_schedule()            # beta_t
        self.alpha = 1. - self.beta                           # alpha_t = 1 - beta_t
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)    # alpha_hat_t = product of alpha_s for s <= t

    def prepare_noise_schedule(self):
        # Linear schedule from beta_start to beta_end
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # Closed-form forward process: x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        eps = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * eps, eps   # eps is the training target

    def sample_timesteps(self, n):
        # One random timestep per image, drawn uniformly from [1, noise_steps - 1]
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))    # start from pure Gaussian noise
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is added at the final step
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    noise = paddle.zeros_like(x)
                # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t) / sqrt(1 - alpha_hat_t) * eps_theta) + sqrt(beta_t) * z
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        x = (x.clip(-1, 1) + 1) / 2    # [-1, 1] -> [0, 1]
        x = (x * 255)                  # [0, 1] -> [0, 255]
        return x

def train(args):
    dataloader = get_data(args)

    model = UNet()
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=args.device)
    os.makedirs("car_models", exist_ok=True)    # checkpoint directory

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, images in enumerate(pbar):
            images = images[0]                            # ImageFolder yields [image], no label
            t = diffusion.sample_timesteps(images.shape[0])
            x_t, noise = diffusion.noise_images(images, t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)            # loss: MSE between true and predicted noise

            opt.clear_grad()
            loss.backward()
            opt.step()

            pbar.set_postfix(MSE=loss.item())

        if epoch % 20 == 0:    # every 20 epochs: save a checkpoint and show 8 samples
            paddle.save(model.state_dict(), f"car_models/ddpm_uncond{epoch}.pdparams")
            sampled_images = diffusion.sample(model, n=8)

            for i in range(8):
                img = sampled_images[i].transpose([1, 2, 0])    # CHW -> HWC for matplotlib
                img = np.array(img).astype("uint8")
                plt.subplot(2, 4, i + 1)
                plt.imshow(img)
            plt.show()

def launch():
    # Hyperparameters
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Unconditional"
            self.epochs = 150
            self.batch_size = 24
            self.image_size = 64
            self.dataset_path = r"/home/aistudio/work/cars"
            self.device = "cuda"
            self.lr = 1.5e-4

    args = ARGS()
    train(args)


if __name__ == '__main__':
    launch()
W1024 11:03:25.091079   573 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1024 11:03:25.094197   573 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
11:03:25 - INFO: Starting epoch 0:
100%|██████████| 340/340 [02:13<00:00,  3.70it/s, MSE=0.15] 
11:05:39 - INFO: Sampling 8 new images....
499it [00:20, 23.93it/s]
<Figure size 640x480 with 8 Axes>
11:06:00 - INFO: Starting epoch 1:
100%|██████████| 340/340 [02:13<00:00,  3.11it/s, MSE=0.0725]
11:08:14 - INFO: Starting epoch 2:
100%|██████████| 340/340 [02:12<00:00,  3.37it/s, MSE=0.0777]
11:10:26 - INFO: Starting epoch 3:
100%|██████████| 340/340 [02:12<00:00,  3.44it/s, MSE=0.0814]
11:12:38 - INFO: Starting epoch 4:
100%|██████████| 340/340 [02:12<00:00,  3.30it/s, MSE=0.0579]
11:14:51 - INFO: Starting epoch 5:
100%|██████████| 340/340 [02:13<00:00,  3.40it/s, MSE=0.107] 
11:17:05 - INFO: Starting epoch 6:
100%|██████████| 340/340 [02:14<00:00,  3.49it/s, MSE=0.0742]
11:19:19 - INFO: Starting epoch 7:
100%|██████████| 340/340 [02:14<00:00,  3.20it/s, MSE=0.0422]
11:21:34 - INFO: Starting epoch 8:
100%|██████████| 340/340 [02:13<00:00,  3.26it/s, MSE=0.0527]
11:23:47 - INFO: Starting epoch 9:
100%|██████████| 340/340 [02:13<00:00,  3.45it/s, MSE=0.064] 
11:26:01 - INFO: Starting epoch 10:
100%|██████████| 340/340 [02:15<00:00,  2.91it/s, MSE=0.043] 
11:28:17 - INFO: Starting epoch 11:
100%|██████████| 340/340 [02:14<00:00,  2.60it/s, MSE=0.0712]
11:30:31 - INFO: Starting epoch 12:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0674]
11:32:44 - INFO: Starting epoch 13:
100%|██████████| 340/340 [02:14<00:00,  3.00it/s, MSE=0.0464]
11:34:59 - INFO: Starting epoch 14:
100%|██████████| 340/340 [02:14<00:00,  2.93it/s, MSE=0.0349]
11:37:13 - INFO: Starting epoch 15:
100%|██████████| 340/340 [02:13<00:00,  3.58it/s, MSE=0.0279]
11:39:26 - INFO: Starting epoch 16:
100%|██████████| 340/340 [02:14<00:00,  2.62it/s, MSE=0.0436]
11:41:40 - INFO: Starting epoch 17:
100%|██████████| 340/340 [02:15<00:00,  3.06it/s, MSE=0.0278]
11:43:55 - INFO: Starting epoch 18:
100%|██████████| 340/340 [02:13<00:00,  3.03it/s, MSE=0.0318]
11:46:09 - INFO: Starting epoch 19:
100%|██████████| 340/340 [02:13<00:00,  3.01it/s, MSE=0.0743]
11:48:22 - INFO: Starting epoch 20:
100%|██████████| 340/340 [02:12<00:00,  3.26it/s, MSE=0.0721]
11:50:36 - INFO: Sampling 8 new images....
499it [00:20, 24.05it/s]
<Figure size 640x480 with 8 Axes>
11:50:57 - INFO: Starting epoch 21:
100%|██████████| 340/340 [02:13<00:00,  3.32it/s, MSE=0.0275]
11:53:10 - INFO: Starting epoch 22:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.028] 
11:55:24 - INFO: Starting epoch 23:
100%|██████████| 340/340 [02:13<00:00,  2.89it/s, MSE=0.0155]
11:57:37 - INFO: Starting epoch 24:
100%|██████████| 340/340 [02:13<00:00,  3.17it/s, MSE=0.0386]
11:59:51 - INFO: Starting epoch 25:
100%|██████████| 340/340 [02:13<00:00,  3.16it/s, MSE=0.0189]
12:02:04 - INFO: Starting epoch 26:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0285]
12:04:18 - INFO: Starting epoch 27:
100%|██████████| 340/340 [02:13<00:00,  3.47it/s, MSE=0.0593]
12:06:31 - INFO: Starting epoch 28:
100%|██████████| 340/340 [02:14<00:00,  2.98it/s, MSE=0.0151]
12:08:45 - INFO: Starting epoch 29:
100%|██████████| 340/340 [02:12<00:00,  3.40it/s, MSE=0.0552]
12:10:57 - INFO: Starting epoch 30:
100%|██████████| 340/340 [02:14<00:00,  3.53it/s, MSE=0.0335]
12:13:12 - INFO: Starting epoch 31:
100%|██████████| 340/340 [02:13<00:00,  3.01it/s, MSE=0.00773]
12:15:25 - INFO: Starting epoch 32:
100%|██████████| 340/340 [02:13<00:00,  3.03it/s, MSE=0.0907]
12:17:39 - INFO: Starting epoch 33:
100%|██████████| 340/340 [02:15<00:00,  3.65it/s, MSE=0.0412]
12:19:54 - INFO: Starting epoch 34:
100%|██████████| 340/340 [02:13<00:00,  3.55it/s, MSE=0.0359]
12:22:08 - INFO: Starting epoch 35:
100%|██████████| 340/340 [02:13<00:00,  3.30it/s, MSE=0.0563]
12:24:21 - INFO: Starting epoch 36:
100%|██████████| 340/340 [02:13<00:00,  3.34it/s, MSE=0.0299]
12:26:35 - INFO: Starting epoch 37:
100%|██████████| 340/340 [02:13<00:00,  3.24it/s, MSE=0.0315] 
12:28:49 - INFO: Starting epoch 38:
100%|██████████| 340/340 [02:13<00:00,  3.08it/s, MSE=0.0455]
12:31:02 - INFO: Starting epoch 39:
100%|██████████| 340/340 [02:12<00:00,  3.23it/s, MSE=0.024] 
12:33:15 - INFO: Starting epoch 40:
100%|██████████| 340/340 [02:13<00:00,  3.32it/s, MSE=0.0416]
12:35:29 - INFO: Sampling 8 new images....
499it [00:20, 23.89it/s]
<Figure size 640x480 with 8 Axes>
12:35:50 - INFO: Starting epoch 41:
100%|██████████| 340/340 [02:13<00:00,  3.18it/s, MSE=0.0134]
12:38:03 - INFO: Starting epoch 42:
100%|██████████| 340/340 [02:12<00:00,  3.77it/s, MSE=0.0948]
12:40:16 - INFO: Starting epoch 43:
100%|██████████| 340/340 [02:13<00:00,  3.16it/s, MSE=0.0208]
12:42:30 - INFO: Starting epoch 44:
100%|██████████| 340/340 [02:13<00:00,  3.29it/s, MSE=0.0421]
12:44:44 - INFO: Starting epoch 45:
100%|██████████| 340/340 [02:13<00:00,  2.88it/s, MSE=0.0296]
12:46:57 - INFO: Starting epoch 46:
100%|██████████| 340/340 [02:12<00:00,  3.00it/s, MSE=0.0398]
12:49:10 - INFO: Starting epoch 47:
100%|██████████| 340/340 [02:13<00:00,  3.06it/s, MSE=0.0269]
12:51:24 - INFO: Starting epoch 48:
100%|██████████| 340/340 [02:12<00:00,  3.34it/s, MSE=0.0635]
12:53:37 - INFO: Starting epoch 49:
100%|██████████| 340/340 [02:12<00:00,  3.58it/s, MSE=0.0687]
12:55:49 - INFO: Starting epoch 50:
100%|██████████| 340/340 [02:12<00:00,  3.08it/s, MSE=0.0253]
12:58:01 - INFO: Starting epoch 51:
100%|██████████| 340/340 [02:12<00:00,  3.33it/s, MSE=0.0219]
01:00:14 - INFO: Starting epoch 52:
100%|██████████| 340/340 [02:12<00:00,  3.13it/s, MSE=0.0422]
01:02:27 - INFO: Starting epoch 53:
100%|██████████| 340/340 [02:12<00:00,  3.26it/s, MSE=0.0187]
01:04:39 - INFO: Starting epoch 54:
100%|██████████| 340/340 [02:14<00:00,  3.39it/s, MSE=0.0453]
01:06:54 - INFO: Starting epoch 55:
100%|██████████| 340/340 [02:14<00:00,  3.45it/s, MSE=0.101] 
01:09:08 - INFO: Starting epoch 56:
100%|██████████| 340/340 [02:15<00:00,  3.22it/s, MSE=0.016] 
01:11:23 - INFO: Starting epoch 57:
100%|██████████| 340/340 [02:14<00:00,  3.21it/s, MSE=0.0173]
01:13:38 - INFO: Starting epoch 58:
100%|██████████| 340/340 [02:13<00:00,  2.65it/s, MSE=0.0127]
01:15:52 - INFO: Starting epoch 59:
100%|██████████| 340/340 [02:14<00:00,  3.56it/s, MSE=0.112] 
01:18:06 - INFO: Starting epoch 60:
100%|██████████| 340/340 [02:14<00:00,  3.01it/s, MSE=0.0155]
01:20:21 - INFO: Sampling 8 new images....
499it [00:21, 23.74it/s]
<Figure size 640x480 with 8 Axes>
01:20:42 - INFO: Starting epoch 61:
100%|██████████| 340/340 [02:15<00:00,  3.17it/s, MSE=0.0143]
01:22:58 - INFO: Starting epoch 62:
100%|██████████| 340/340 [02:15<00:00,  3.26it/s, MSE=0.0731]
01:25:14 - INFO: Starting epoch 63:
100%|██████████| 340/340 [02:14<00:00,  3.38it/s, MSE=0.0484]
01:27:28 - INFO: Starting epoch 64:
100%|██████████| 340/340 [02:16<00:00,  3.30it/s, MSE=0.0154]
01:29:45 - INFO: Starting epoch 65:
100%|██████████| 340/340 [02:15<00:00,  3.31it/s, MSE=0.0224]
01:32:00 - INFO: Starting epoch 66:
100%|██████████| 340/340 [02:15<00:00,  3.14it/s, MSE=0.0265]
01:34:16 - INFO: Starting epoch 67:
100%|██████████| 340/340 [02:14<00:00,  3.10it/s, MSE=0.0326]
01:36:30 - INFO: Starting epoch 68:
100%|██████████| 340/340 [02:14<00:00,  3.35it/s, MSE=0.0656]
01:38:44 - INFO: Starting epoch 69:
100%|██████████| 340/340 [02:14<00:00,  3.20it/s, MSE=0.0591]
01:40:58 - INFO: Starting epoch 70:
100%|██████████| 340/340 [02:13<00:00,  3.34it/s, MSE=0.0196] 
01:43:12 - INFO: Starting epoch 71:
100%|██████████| 340/340 [02:15<00:00,  2.64it/s, MSE=0.021] 
01:45:28 - INFO: Starting epoch 72:
100%|██████████| 340/340 [02:14<00:00,  2.85it/s, MSE=0.0166]
01:47:42 - INFO: Starting epoch 73:
100%|██████████| 340/340 [02:15<00:00,  3.31it/s, MSE=0.0408]
01:49:57 - INFO: Starting epoch 74:
100%|██████████| 340/340 [02:14<00:00,  3.06it/s, MSE=0.0705] 
01:52:12 - INFO: Starting epoch 75:
100%|██████████| 340/340 [02:14<00:00,  3.06it/s, MSE=0.0326]
01:54:26 - INFO: Starting epoch 76:
100%|██████████| 340/340 [02:13<00:00,  3.55it/s, MSE=0.016] 
01:56:39 - INFO: Starting epoch 77:
100%|██████████| 340/340 [02:13<00:00,  2.98it/s, MSE=0.0122]
01:58:53 - INFO: Starting epoch 78:
100%|██████████| 340/340 [02:13<00:00,  3.57it/s, MSE=0.0304]
02:01:06 - INFO: Starting epoch 79:
100%|██████████| 340/340 [02:14<00:00,  3.17it/s, MSE=0.0186]
02:03:21 - INFO: Starting epoch 80:
100%|██████████| 340/340 [02:14<00:00,  3.37it/s, MSE=0.0248]
02:05:35 - INFO: Sampling 8 new images....
499it [00:21, 22.82it/s]
<Figure size 640x480 with 8 Axes>
02:05:57 - INFO: Starting epoch 81:
100%|██████████| 340/340 [02:13<00:00,  2.93it/s, MSE=0.0321]
02:08:11 - INFO: Starting epoch 82:
100%|██████████| 340/340 [02:15<00:00,  2.76it/s, MSE=0.0274]
02:10:26 - INFO: Starting epoch 83:
100%|██████████| 340/340 [02:16<00:00,  3.49it/s, MSE=0.0069]
02:12:42 - INFO: Starting epoch 84:
100%|██████████| 340/340 [02:13<00:00,  3.05it/s, MSE=0.0847]
02:14:56 - INFO: Starting epoch 85:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0237]
02:17:09 - INFO: Starting epoch 86:
100%|██████████| 340/340 [02:13<00:00,  2.71it/s, MSE=0.0124]
02:19:23 - INFO: Starting epoch 87:
100%|██████████| 340/340 [02:14<00:00,  3.69it/s, MSE=0.0537]
02:21:37 - INFO: Starting epoch 88:
100%|██████████| 340/340 [02:13<00:00,  3.13it/s, MSE=0.0463]
02:23:51 - INFO: Starting epoch 89:
100%|██████████| 340/340 [02:13<00:00,  2.85it/s, MSE=0.0137]
02:26:04 - INFO: Starting epoch 90:
100%|██████████| 340/340 [02:12<00:00,  3.05it/s, MSE=0.0198]
02:28:17 - INFO: Starting epoch 91:
100%|██████████| 340/340 [02:12<00:00,  3.31it/s, MSE=0.0205] 
02:30:30 - INFO: Starting epoch 92:
100%|██████████| 340/340 [02:12<00:00,  2.79it/s, MSE=0.0146]
02:32:43 - INFO: Starting epoch 93:
100%|██████████| 340/340 [02:12<00:00,  2.94it/s, MSE=0.00888]
02:34:56 - INFO: Starting epoch 94:
100%|██████████| 340/340 [02:12<00:00,  3.20it/s, MSE=0.0572]
02:37:08 - INFO: Starting epoch 95:
100%|██████████| 340/340 [02:13<00:00,  3.11it/s, MSE=0.021] 
02:39:22 - INFO: Starting epoch 96:
100%|██████████| 340/340 [02:13<00:00,  3.24it/s, MSE=0.0392]
02:41:35 - INFO: Starting epoch 97:
100%|██████████| 340/340 [02:12<00:00,  2.66it/s, MSE=0.0166]
02:43:48 - INFO: Starting epoch 98:
100%|██████████| 340/340 [02:14<00:00,  2.51it/s, MSE=0.0591]
02:46:03 - INFO: Starting epoch 99:
100%|██████████| 340/340 [02:16<00:00,  3.14it/s, MSE=0.0283]
02:48:19 - INFO: Starting epoch 100:
100%|██████████| 340/340 [02:13<00:00,  3.19it/s, MSE=0.0276]
02:50:33 - INFO: Sampling 8 new images....
499it [00:21, 23.23it/s]
<Figure size 640x480 with 8 Axes>
02:50:55 - INFO: Starting epoch 101:
100%|██████████| 340/340 [02:14<00:00,  3.48it/s, MSE=0.0293] 
02:53:10 - INFO: Starting epoch 102:
100%|██████████| 340/340 [02:16<00:00,  3.12it/s, MSE=0.0518]
02:55:27 - INFO: Starting epoch 103:
100%|██████████| 340/340 [02:14<00:00,  3.46it/s, MSE=0.0133]
02:57:42 - INFO: Starting epoch 104:
100%|██████████| 340/340 [02:15<00:00,  3.32it/s, MSE=0.0207] 
02:59:58 - INFO: Starting epoch 105:
100%|██████████| 340/340 [02:14<00:00,  3.26it/s, MSE=0.00727]
03:02:12 - INFO: Starting epoch 106:
100%|██████████| 340/340 [02:15<00:00,  3.81it/s, MSE=0.0319]
03:04:28 - INFO: Starting epoch 107:
100%|██████████| 340/340 [02:15<00:00,  3.11it/s, MSE=0.0348]
03:06:44 - INFO: Starting epoch 108:
100%|██████████| 340/340 [02:15<00:00,  3.34it/s, MSE=0.0245]
03:08:59 - INFO: Starting epoch 109:
100%|██████████| 340/340 [02:15<00:00,  3.24it/s, MSE=0.0139]
03:11:14 - INFO: Starting epoch 110:
100%|██████████| 340/340 [02:15<00:00,  3.23it/s, MSE=0.0311]
03:13:29 - INFO: Starting epoch 111:
100%|██████████| 340/340 [02:15<00:00,  3.53it/s, MSE=0.0234]
03:15:45 - INFO: Starting epoch 112:
100%|██████████| 340/340 [02:16<00:00,  3.13it/s, MSE=0.0158]
03:18:01 - INFO: Starting epoch 113:
100%|██████████| 340/340 [02:15<00:00,  3.44it/s, MSE=0.0315]
03:20:17 - INFO: Starting epoch 114:
100%|██████████| 340/340 [02:13<00:00,  3.16it/s, MSE=0.0187]
03:22:30 - INFO: Starting epoch 115:
100%|██████████| 340/340 [02:13<00:00,  3.23it/s, MSE=0.0228]
03:24:43 - INFO: Starting epoch 116:
100%|██████████| 340/340 [02:14<00:00,  3.04it/s, MSE=0.0607]
03:26:57 - INFO: Starting epoch 117:
100%|██████████| 340/340 [02:13<00:00,  3.34it/s, MSE=0.0217]
03:29:10 - INFO: Starting epoch 118:
100%|██████████| 340/340 [02:13<00:00,  3.28it/s, MSE=0.0131]
03:31:24 - INFO: Starting epoch 119:
100%|██████████| 340/340 [02:15<00:00,  3.54it/s, MSE=0.0618]
03:33:39 - INFO: Starting epoch 120:
100%|██████████| 340/340 [02:15<00:00,  3.08it/s, MSE=0.0388]
03:35:55 - INFO: Sampling 8 new images....
499it [00:21, 23.36it/s]
<Figure size 640x480 with 8 Axes>
03:36:16 - INFO: Starting epoch 121:
100%|██████████| 340/340 [02:19<00:00,  3.14it/s, MSE=0.0142] 
03:38:36 - INFO: Starting epoch 122:
100%|██████████| 340/340 [02:19<00:00,  2.97it/s, MSE=0.0112]
03:40:56 - INFO: Starting epoch 123:
100%|██████████| 340/340 [02:19<00:00,  2.84it/s, MSE=0.0243]
03:43:15 - INFO: Starting epoch 124:
100%|██████████| 340/340 [02:19<00:00,  3.11it/s, MSE=0.0312]
03:45:35 - INFO: Starting epoch 125:
100%|██████████| 340/340 [02:19<00:00,  3.26it/s, MSE=0.0513] 
03:47:54 - INFO: Starting epoch 126:
100%|██████████| 340/340 [02:18<00:00,  3.10it/s, MSE=0.0254]
03:50:13 - INFO: Starting epoch 127:
100%|██████████| 340/340 [02:17<00:00,  3.18it/s, MSE=0.00965]
03:52:30 - INFO: Starting epoch 128:
100%|██████████| 340/340 [02:17<00:00,  3.35it/s, MSE=0.0183]
03:54:47 - INFO: Starting epoch 129:
100%|██████████| 340/340 [02:17<00:00,  3.36it/s, MSE=0.0158]
03:57:05 - INFO: Starting epoch 130:
100%|██████████| 340/340 [02:18<00:00,  3.29it/s, MSE=0.0326]
03:59:24 - INFO: Starting epoch 131:
100%|██████████| 340/340 [02:17<00:00,  3.18it/s, MSE=0.0224]
04:01:42 - INFO: Starting epoch 132:
100%|██████████| 340/340 [02:16<00:00,  3.11it/s, MSE=0.0367]
04:03:58 - INFO: Starting epoch 133:
100%|██████████| 340/340 [02:18<00:00,  2.95it/s, MSE=0.0231]
04:06:16 - INFO: Starting epoch 134:
100%|██████████| 340/340 [02:19<00:00,  3.34it/s, MSE=0.0195]
04:08:35 - INFO: Starting epoch 135:
100%|██████████| 340/340 [02:18<00:00,  3.30it/s, MSE=0.00914]
04:10:54 - INFO: Starting epoch 136:
100%|██████████| 340/340 [02:19<00:00,  2.76it/s, MSE=0.0355]
04:13:13 - INFO: Starting epoch 137:
100%|██████████| 340/340 [02:19<00:00,  3.14it/s, MSE=0.0365]
04:15:33 - INFO: Starting epoch 138:
100%|██████████| 340/340 [02:20<00:00,  3.38it/s, MSE=0.0182] 
04:17:53 - INFO: Starting epoch 139:
100%|██████████| 340/340 [02:18<00:00,  3.19it/s, MSE=0.057]  
04:20:11 - INFO: Starting epoch 140:
100%|██████████| 340/340 [02:16<00:00,  3.27it/s, MSE=0.0156] 
04:22:28 - INFO: Sampling 8 new images....
499it [00:21, 22.81it/s]
<Figure size 640x480 with 8 Axes>
04:22:51 - INFO: Starting epoch 141:
100%|██████████| 340/340 [02:17<00:00,  3.11it/s, MSE=0.0256]
04:25:09 - INFO: Starting epoch 142:
100%|██████████| 340/340 [02:16<00:00,  2.82it/s, MSE=0.0271]
04:27:26 - INFO: Starting epoch 143:
100%|██████████| 340/340 [02:16<00:00,  3.35it/s, MSE=0.041] 
04:29:42 - INFO: Starting epoch 144:
100%|██████████| 340/340 [02:16<00:00,  3.04it/s, MSE=0.0126] 
04:31:59 - INFO: Starting epoch 145:
100%|██████████| 340/340 [02:16<00:00,  3.38it/s, MSE=0.0186]
04:34:16 - INFO: Starting epoch 146:
100%|██████████| 340/340 [02:19<00:00,  3.21it/s, MSE=0.0195]
04:36:36 - INFO: Starting epoch 147:
100%|██████████| 340/340 [02:19<00:00,  2.58it/s, MSE=0.00809]
04:38:55 - INFO: Starting epoch 148:
100%|██████████| 340/340 [02:20<00:00,  3.04it/s, MSE=0.0113]
04:41:15 - INFO: Starting epoch 149:
100%|██████████| 340/340 [02:19<00:00,  3.17it/s, MSE=0.013] 

2.6 Sampling with the trained model

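We can load any checkpoint that looked good during training and use it for sampling. This project is only a demonstration, and generating cars may not be especially useful in itself. But NovelAI can already produce high-quality anime-style illustrations, so understanding the underlying principles of diffusion models through this project should make the many improved diffusion variants easier to approach.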

In [6]

import paddle

# UNet, Diffusion, np and plt come from the training cell above
model = UNet()
model.set_state_dict(paddle.load("car_models/ddpm_uncond140.pdparams"))   # load the epoch-140 checkpoint
diffusion = Diffusion(img_size=64, device="cuda")

sampled_images = diffusion.sample(model, n=8)

# Show the sampled images in a 2 x 4 grid
for i in range(8):
    img = sampled_images[i].transpose([1, 2, 0])    # CHW -> HWC
    img = np.array(img).astype("uint8")
    plt.subplot(2, 4, i + 1)
    plt.imshow(img)
plt.show()
05:37:15 - INFO: Sampling 8 new images....
499it [00:22, 22.61it/s]
<Figure size 640x480 with 8 Axes>

3. Conditional generation (guiding image generation with labels)

3.1 Training process

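As in unconditional generation, we sample the training targets from the forward process, train a UNet, and feed it an encoding of the timestep (similar to positional encoding in Transformers), which makes the model easier to train. In addition, the class label is embedded and given to the model as a further input. cfg is the guidance scale that trades off conditional against unconditional generation: at every sampling step the model predicts the noise twice, once with the label and once without, and combines the two as predicted_noise = uncond + cfg × (cond − uncond). The larger cfg is, the more strongly the generated image follows the label; values above 1 push past the purely conditional prediction. The combination is applied to the predicted noise, not to finished images. To make the unconditional prediction available, the training loop replaces the labels with None for roughly 10% of the batches.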

  • cfg: classifier-free guidance (label guidance)
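A minimal NumPy sketch of how the two noise predictions are combined; the small arrays stand in for UNet outputs, and `cfg_combine` is a hypothetical helper name. It computes the same thing as the `paddle.lerp(uncond, cond, cfg_scale)` call in `sample()` below, since lerp(a, b, w) = a + w·(b − a). Note that `sample()` only applies guidance when `cfg_scale > 0`; with `cfg_scale = 0` it uses the conditional prediction unchanged.

```python
import numpy as np


def cfg_combine(eps_uncond, eps_cond, scale):
    """Classifier-free guidance on predicted noise:
    eps = eps_uncond + scale * (eps_cond - eps_uncond).
    scale = 1 gives the conditional prediction; scale > 1 extrapolates
    further in the direction the label points."""
    return eps_uncond + scale * (eps_cond - eps_uncond)


u = np.array([0.0, 1.0])   # unconditional prediction (toy values)
c = np.array([1.0, 1.0])   # conditional prediction (toy values)
print(cfg_combine(u, c, 3))  # [3. 1.]
```

Where the two predictions agree (second element), guidance changes nothing; where they differ, the difference is amplified by the scale.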

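In addition, the training below keeps an exponential moving average (EMA) of the model weights: a separate copy whose weights are updated after every optimizer step as a weighted average of their previous value and the current trained weights (the helper is constructed with a decay of 0.995). The averaged weights are less sensitive to individual noisy updates and outliers, so the EMA model usually gives more stable results; the gradient updates of the model being trained are unchanged.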

  • ema: exponential moving average
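The `EMA` class comes from the project's modules.py, which is not shown. As a hypothetical illustration of what a standard EMA step does with decay 0.995 (the value passed to `EMA(0.995)` below), here is a version over a dict of NumPy arrays; `ema_update` is an invented name, and the real class may differ, for example by copying the weights outright during an initial warm-up.

```python
import numpy as np


def ema_update(ema_params, params, beta=0.995):
    """One EMA step per parameter: ema <- beta * ema + (1 - beta) * current."""
    return {name: beta * ema_params[name] + (1 - beta) * params[name] for name in params}


ema = {"w": np.array([0.0])}
for _ in range(3):
    # The trained weight jumps to 1.0; the EMA copy moves toward it slowly.
    ema = ema_update(ema, {"w": np.array([1.0])})
```

After three steps the EMA weight is 1 − 0.995³ ≈ 0.015: a single sudden change in the trained weights only nudges the averaged copy, which is why it smooths out noisy updates.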

Restart the kernel before running the code below to free the GPU memory used by the previous section!

3.2 Unzipping the dataset

We use a flower dataset with five classes, so at sampling time we can ask for any one of them.

In [1]

# Unzip the flower dataset
import os
if not os.path.exists("work/flowers"):
    !mkdir work/flowers
!unzip -oq data/data173680/flowers.zip -d work/flowers

In [2]

# Load the dataset
"""Conditional generation also needs the image labels, so we build a custom dataset here."""

# 1. Write image paths and labels to a txt file. flowers is a classification dataset;
#    we take both its train and validation splits as the training set for our generative model.
import os

# Class folder name -> label
classes = {
    "sunflower": 0,
    "rose": 1,
    "tulip": 2,
    "dandelion": 3,
    "daisy": 4,
}

# One line per image: "<path>;<label>"
with open("flowers_data.txt", "w") as f:
    for name, label in classes.items():
        for split in ("train", "validation"):
            folder = f"work/flowers/pic/{split}/{name}"
            for image in os.listdir(folder):
                f.write(f"{folder}/{image};{label}\n")

3.3 Building the dataset

Here the data iterator has to return both images and labels, so we build the dataset with the basic paddle.io API.

In [3]

# 2. Build the dataset
# Data pipeline that returns (image, label) pairs
import paddle.vision as V
from PIL import Image
from paddle.io import Dataset, DataLoader
from tqdm import tqdm

# Data transforms
transforms = V.transforms.Compose([
        V.transforms.Resize(80),  # 1.25 * image_size (image_size = 64)
        V.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
        V.transforms.ToTensor(),
        V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

class TrainDataFlowers(Dataset):
    def __init__(self, txt_path="flowers_data.txt"):
        with open(txt_path, "r") as f:
            # Each line is "path;label"; skip blank lines instead of dropping the last entry
            self.image_paths = [line.strip() for line in f if line.strip()]

    def __getitem__(self, index):
        image_path, label = self.image_paths[index].split(";")
        image = Image.open(image_path).convert("RGB")   # guard against grayscale/RGBA files
        image = transforms(image)
        label = int(label)
        return image, label

    def __len__(self):
        return len(self.image_paths)

dataset = TrainDataFlowers()
dataloader = DataLoader(dataset, batch_size=24, shuffle=True)

if __name__ == "__main__":  # check that the whole dataset can be loaded
    pbar = tqdm(dataloader)
    for i, (images, labels) in enumerate(pbar):
        pass
    print("ok")
  0%|          | 0/181 [00:00<?, ?it/s]W1023 15:49:27.184664  3398 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1023 15:49:27.188580  3398 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
100%|██████████| 181/181 [00:15<00:00, 11.37it/s]
ok

3.4 Training

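As in the unconditional case, hyperparameters are set by editing the ARGS class, and the loss is again the mean squared error between the added noise and the predicted noise, which keeps the code simple. Compared with GANs, diffusion models are generally easier to tune and more stable to train.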

In [4]

import os
import paddle
import copy
import paddle.nn as nn
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm
from paddle import optimizer
from modules import UNet_conditional, EMA
import logging
import numpy as np
logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))    # start from pure Gaussian noise
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                predicted_noise = model(x, t, labels)                # conditional prediction
                if cfg_scale > 0:
                    uncond_predicted_noise = model(x, t, None)       # unconditional prediction
                    # Classifier-free guidance: uncond + cfg_scale * (cond - uncond).
                    # cfg_scale stays a Python number; the tensor weight is built separately.
                    weight = paddle.to_tensor(cfg_scale, dtype="float32")
                    predicted_noise = paddle.lerp(uncond_predicted_noise, predicted_noise, weight)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    noise = paddle.zeros_like(x)
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        x = (x.clip(-1, 1) + 1) / 2
        x = (x * 255)
        return x


def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = args.dataloader
    model = UNet_conditional(num_classes=args.num_classes)
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    l = len(dataloader)
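    ema = EMA(0.995)                   # EMA helper from modules.py, beta = 0.995
    ema_model = copy.deepcopy(model)   # shadow copy whose weights track a moving average of model
    ema_model.eval()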
    # print("ema_model", ema_model)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            t = diffusion.sample_timesteps(images.shape[0])
            x_t, noise = diffusion.noise_images(images, t)
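            if np.random.random() < 0.1:
                # Drop the labels for ~10% of batches so the same UNet also learns the
                # unconditional prediction that classifier-free guidance needs at sampling time
                labels = None
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)  # loss: MSE between the added noise and the predicted noise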

            opt.clear_grad()
            loss.backward()
            opt.step()

            ema.step_ema(ema_model, model)
            pbar.set_postfix(MSE=loss.item())
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)

        if epoch % 30 == 0:     # every 30 epochs: save a checkpoint and visualize samples
            paddle.save(model.state_dict(), f"models/ddpm_cond{epoch}.pdparams")

            labels = paddle.arange(5).astype("int64")
            # Sample 10 images in total: two runs of one image per class
            # Left to right: sunflower, rose, tulip, dandelion, daisy
            sampled_images1 = diffusion.sample(model, n=len(labels), labels=labels)
            sampled_images2 = diffusion.sample(model, n=len(labels), labels=labels)
            # ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            for i in range(5):
                img = sampled_images1[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2,5,i+1)
                plt.imshow(img)
            for i in range(5):
                img = sampled_images2[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2,5,i+1+5)
                plt.imshow(img)
            plt.show()


def launch():
    # Hyperparameters
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Conditional"
            self.epochs = 300
            self.batch_size = 48   # not read by train(); the prebuilt `dataloader` fixes the batch size
            self.image_size = 64
            self.device = "cuda"
            self.lr = 1.5e-4
            self.num_classes = 5
            self.dataloader = dataloader  # global defined earlier in the notebook


    args = ARGS()
    train(args)


if __name__ == '__main__':
    # Train
    launch()
    pass
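The `noise_images` method relies on the closed form of the forward process, x_t = √ᾱ_t·x_0 + √(1−ᾱ_t)·ε, so a training sample at any timestep is produced in one step instead of t sequential ones. Below is a minimal NumPy re-creation of the same linear schedule with `noise_steps=500`. NumPy stands in for Paddle so the sketch runs without a GPU, and the random `x0` batch is a placeholder for real images. It also shows a consequence of using 500 steps rather than the 1000 of the original DDPM setup: ᾱ at the last step is only about 0.006, so x_T is close to, but not exactly, pure noise.

```python
import numpy as np

# Same schedule as Diffusion(noise_steps=500, beta_start=1e-4, beta_end=0.02), in NumPy.
noise_steps, beta_start, beta_end = 500, 1e-4, 0.02
beta = np.linspace(beta_start, beta_end, noise_steps)
alpha = 1.0 - beta
alpha_hat = np.cumprod(alpha)  # alpha_hat[t] = alpha[0] * ... * alpha[t]

def noise_images(x0, t, rng):
    """Closed-form forward process: x_t = sqrt(alpha_hat_t) * x0 + sqrt(1 - alpha_hat_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    a = alpha_hat[t][:, None, None, None]  # per-sample coefficient broadcast over C, H, W
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps, eps

rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 3, 16, 16))  # placeholder batch, not real images
t = rng.integers(1, noise_steps, size=8)  # same range as sample_timesteps: [1, 499]
x_t, eps = noise_images(x0, t, rng)

# Knowing eps, the formula inverts exactly, which is why predicting eps is enough.
a = alpha_hat[t][:, None, None, None]
x0_rec = (x_t - np.sqrt(1.0 - a) * eps) / np.sqrt(a)

print(f"alpha_hat at the last step: {alpha_hat[-1]:.4f}")
print(f"signal coefficient sqrt(alpha_hat_T): {np.sqrt(alpha_hat[-1]):.3f}")
```

Because x_0 and ε are independent with unit variance here, x_t also has roughly unit variance at every t, which is the sense in which this schedule is variance-preserving.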
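In `sample`, `paddle.lerp(x, y, w)` evaluates x + w·(y − x). With x the unconditional and y the conditional noise prediction, that is the classifier-free guidance combination ε_uncond + w·(ε_cond − ε_uncond). The unconditional prediction comes from the same UNet, which `train` teaches by setting the labels to `None` for about 10% of batches. The NumPy sketch below mirrors that branch with a hypothetical helper `guided_noise`. With w = 1 it returns the conditional prediction, and a larger w extrapolates beyond it (the default `cfg_scale=3`, or cfg=7 in the introduction's example). It also exposes one quirk of the `if cfg_scale > 0` test: `cfg_scale=0` returns the conditional prediction, not the unconditional one that a weight of 0 would give.

```python
import numpy as np

def lerp(start, end, weight):
    """Same formula as paddle.lerp: start + weight * (end - start)."""
    return start + weight * (end - start)

def guided_noise(eps_cond, eps_uncond, cfg_scale):
    """Hypothetical helper mirroring the guidance branch in Diffusion.sample."""
    if cfg_scale > 0:
        return lerp(eps_uncond, eps_cond, cfg_scale)
    return eps_cond  # cfg_scale == 0 skips guidance and keeps the conditional prediction

# Toy predictions: they differ in the first component and agree in the second.
eps_uncond = np.array([0.0, 0.5])
eps_cond = np.array([1.0, 0.5])
for w in (0, 1, 3, 7):
    print(w, guided_noise(eps_cond, eps_uncond, w))
```

Components where the two predictions agree are left unchanged. Guidance only amplifies the direction in which the class label changes the prediction.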
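`ema.step_ema(ema_model, model)` comes from the project's `modules` file, which this section does not show. Assume it implements the usual exponential moving average θ_ema ← β·θ_ema + (1−β)·θ with β = 0.995, the value passed to `EMA(0.995)`. Under that assumption, the toy NumPy sketch below shows how the shadow weights trail the live ones, averaging over roughly the last 1/(1−β) = 200 updates. Whether the real class adds a warm-up period before averaging is not visible here. This cell also never samples from or saves `ema_model` (the `ema_sampled_images` line is commented out), so the preview grids and the `.pdparams` checkpoints come from the raw `model`.

```python
import numpy as np

def step_ema(ema_params, params, beta=0.995):
    """Assumed update rule (modules.EMA is not shown): ema <- beta * ema + (1 - beta) * current."""
    for name, value in params.items():
        ema_params[name] = beta * ema_params[name] + (1.0 - beta) * value
    return ema_params

# Toy "model" with one weight vector. The live weights are held fixed at 1.0 and the
# EMA copy starts at 0.0 (unlike the real deepcopy) so the convergence is easy to see.
ema = {"w": np.zeros(3)}
current = {"w": np.ones(3)}
for step in range(1000):
    step_ema(ema, current)

# For this toy case the recurrence has the closed form ema = 1 - beta**steps.
print(ema["w"], 1 - 0.995 ** 1000)
```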
03:56:58 - INFO: Starting epoch 0:
100%|██████████| 181/181 [01:04<00:00,  3.76it/s, MSE=0.172]
03:58:03 - INFO: Sampling 5 new images....
499it [00:44, 11.13it/s]
03:58:48 - INFO: Sampling 5 new images....
499it [00:44, 11.28it/s]
[Figure: samples after epoch 0, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
03:59:33 - INFO: Starting epoch 1:
100%|██████████| 181/181 [01:02<00:00,  3.81it/s, MSE=0.104]
04:00:36 - INFO: Starting epoch 2:
100%|██████████| 181/181 [01:02<00:00,  3.78it/s, MSE=0.103] 
04:01:38 - INFO: Starting epoch 3:
100%|██████████| 181/181 [01:02<00:00,  3.75it/s, MSE=0.0912]
04:02:41 - INFO: Starting epoch 4:
100%|██████████| 181/181 [01:02<00:00,  3.80it/s, MSE=0.0649]
04:03:43 - INFO: Starting epoch 5:
100%|██████████| 181/181 [00:59<00:00,  4.66it/s, MSE=0.0631]
04:04:43 - INFO: Starting epoch 6:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.179] 
04:05:38 - INFO: Starting epoch 7:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0908]
04:06:33 - INFO: Starting epoch 8:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.158] 
04:07:29 - INFO: Starting epoch 9:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.171] 
04:08:24 - INFO: Starting epoch 10:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0362]
04:09:20 - INFO: Starting epoch 11:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0444]
04:10:16 - INFO: Starting epoch 12:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.0393]
04:11:12 - INFO: Starting epoch 13:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.064] 
04:12:07 - INFO: Starting epoch 14:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.035] 
04:13:03 - INFO: Starting epoch 15:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.063] 
04:13:58 - INFO: Starting epoch 16:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0157]
04:14:54 - INFO: Starting epoch 17:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0159]
04:15:49 - INFO: Starting epoch 18:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0212]
04:16:45 - INFO: Starting epoch 19:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0252]
04:17:40 - INFO: Starting epoch 20:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0192]
04:18:35 - INFO: Starting epoch 21:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0361]
04:19:31 - INFO: Starting epoch 22:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0177]
04:20:26 - INFO: Starting epoch 23:
100%|██████████| 181/181 [00:55<00:00,  4.54it/s, MSE=0.0527]
04:21:22 - INFO: Starting epoch 24:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0458]
04:22:17 - INFO: Starting epoch 25:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0539]
04:23:13 - INFO: Starting epoch 26:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.205] 
04:24:09 - INFO: Starting epoch 27:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0463]
04:25:04 - INFO: Starting epoch 28:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.152] 
04:26:00 - INFO: Starting epoch 29:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.284] 
04:26:55 - INFO: Starting epoch 30:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0896]
04:27:51 - INFO: Sampling 5 new images....
499it [00:44, 11.27it/s]
04:28:36 - INFO: Sampling 5 new images....
499it [00:45, 11.01it/s]
[Figure: samples after epoch 30, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
04:29:21 - INFO: Starting epoch 31:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.299] 
04:30:17 - INFO: Starting epoch 32:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0226]
04:31:12 - INFO: Starting epoch 33:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.00727]
04:32:08 - INFO: Starting epoch 34:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.132] 
04:33:03 - INFO: Starting epoch 35:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0498]
04:33:59 - INFO: Starting epoch 36:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0107]
04:34:55 - INFO: Starting epoch 37:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0116]
04:35:50 - INFO: Starting epoch 38:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.044] 
04:36:46 - INFO: Starting epoch 39:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.167] 
04:37:41 - INFO: Starting epoch 40:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0359]
04:38:37 - INFO: Starting epoch 41:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0064]
04:39:33 - INFO: Starting epoch 42:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0107]
04:40:28 - INFO: Starting epoch 43:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0216]
04:41:24 - INFO: Starting epoch 44:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0361]
04:42:20 - INFO: Starting epoch 45:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0368]
04:43:15 - INFO: Starting epoch 46:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0283]
04:44:10 - INFO: Starting epoch 47:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.0352]
04:45:06 - INFO: Starting epoch 48:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0499]
04:46:01 - INFO: Starting epoch 49:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0359]
04:46:56 - INFO: Starting epoch 50:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0555]
04:47:52 - INFO: Starting epoch 51:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.14]  
04:48:47 - INFO: Starting epoch 52:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.0136]
04:49:42 - INFO: Starting epoch 53:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0242]
04:50:38 - INFO: Starting epoch 54:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0252]
04:51:33 - INFO: Starting epoch 55:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0274]
04:52:29 - INFO: Starting epoch 56:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0727]
04:53:24 - INFO: Starting epoch 57:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.023] 
04:54:20 - INFO: Starting epoch 58:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0457]
04:55:15 - INFO: Starting epoch 59:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0123]
04:56:10 - INFO: Starting epoch 60:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0125]
04:57:07 - INFO: Sampling 5 new images....
499it [00:45, 10.91it/s]
04:57:52 - INFO: Sampling 5 new images....
499it [00:45, 10.90it/s]
[Figure: samples after epoch 60, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
04:58:39 - INFO: Starting epoch 61:
100%|██████████| 181/181 [00:56<00:00,  4.53it/s, MSE=0.00765]
04:59:35 - INFO: Starting epoch 62:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.00355]
05:00:30 - INFO: Starting epoch 63:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0256]
05:01:26 - INFO: Starting epoch 64:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0413]
05:02:22 - INFO: Starting epoch 65:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0146]
05:03:17 - INFO: Starting epoch 66:
100%|██████████| 181/181 [00:56<00:00,  4.57it/s, MSE=0.00737]
05:04:13 - INFO: Starting epoch 67:
100%|██████████| 181/181 [00:56<00:00,  4.63it/s, MSE=0.00363]
05:05:09 - INFO: Starting epoch 68:
100%|██████████| 181/181 [00:56<00:00,  4.58it/s, MSE=0.121] 
05:06:06 - INFO: Starting epoch 69:
100%|██████████| 181/181 [00:56<00:00,  4.53it/s, MSE=0.0124]
05:07:02 - INFO: Starting epoch 70:
100%|██████████| 181/181 [00:56<00:00,  4.53it/s, MSE=0.0235]
05:07:59 - INFO: Starting epoch 71:
100%|██████████| 181/181 [00:55<00:00,  4.52it/s, MSE=0.084] 
05:08:55 - INFO: Starting epoch 72:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.022] 
05:09:50 - INFO: Starting epoch 73:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.00922]
05:10:47 - INFO: Starting epoch 74:
100%|██████████| 181/181 [00:56<00:00,  4.27it/s, MSE=0.0059]
05:11:43 - INFO: Starting epoch 75:
100%|██████████| 181/181 [00:56<00:00,  4.60it/s, MSE=0.00901]
05:12:40 - INFO: Starting epoch 76:
100%|██████████| 181/181 [00:56<00:00,  4.60it/s, MSE=0.0261]
05:13:36 - INFO: Starting epoch 77:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.0317]
05:14:32 - INFO: Starting epoch 78:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0379]
05:15:27 - INFO: Starting epoch 79:
100%|██████████| 181/181 [00:54<00:00,  4.62it/s, MSE=0.0126]
05:16:22 - INFO: Starting epoch 80:
100%|██████████| 181/181 [00:55<00:00,  4.57it/s, MSE=0.0129]
05:17:17 - INFO: Starting epoch 81:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0174]
05:18:13 - INFO: Starting epoch 82:
100%|██████████| 181/181 [00:57<00:00,  4.59it/s, MSE=0.00267]
05:19:11 - INFO: Starting epoch 83:
100%|██████████| 181/181 [00:57<00:00,  4.61it/s, MSE=0.00863]
05:20:08 - INFO: Starting epoch 84:
100%|██████████| 181/181 [00:55<00:00,  4.59it/s, MSE=0.0928]
05:21:04 - INFO: Starting epoch 85:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0151]
05:21:59 - INFO: Starting epoch 86:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0231]
05:22:55 - INFO: Starting epoch 87:
100%|██████████| 181/181 [00:55<00:00,  4.50it/s, MSE=0.0442]
05:23:51 - INFO: Starting epoch 88:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00999]
05:24:47 - INFO: Starting epoch 89:
100%|██████████| 181/181 [00:55<00:00,  4.57it/s, MSE=0.00467]
05:25:42 - INFO: Starting epoch 90:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0219]
05:26:38 - INFO: Sampling 5 new images....
499it [00:44, 11.12it/s]
05:27:23 - INFO: Sampling 5 new images....
499it [00:45, 11.06it/s]
[Figure: samples after epoch 90, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
05:28:09 - INFO: Starting epoch 91:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.00285]
05:29:05 - INFO: Starting epoch 92:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.112] 
05:30:00 - INFO: Starting epoch 93:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0108]
05:30:56 - INFO: Starting epoch 94:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0281]
05:31:51 - INFO: Starting epoch 95:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0355]
05:32:47 - INFO: Starting epoch 96:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.133] 
05:33:42 - INFO: Starting epoch 97:
100%|██████████| 181/181 [00:56<00:00,  4.56it/s, MSE=0.0138]
05:34:39 - INFO: Starting epoch 98:
100%|██████████| 181/181 [00:56<00:00,  4.66it/s, MSE=0.00963]
05:35:35 - INFO: Starting epoch 99:
100%|██████████| 181/181 [00:56<00:00,  4.59it/s, MSE=0.0298]
05:36:31 - INFO: Starting epoch 100:
100%|██████████| 181/181 [00:56<00:00,  4.65it/s, MSE=0.00709]
05:37:27 - INFO: Starting epoch 101:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0737]
05:38:23 - INFO: Starting epoch 102:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0105]
05:39:18 - INFO: Starting epoch 103:
100%|██████████| 181/181 [00:56<00:00,  4.56it/s, MSE=0.00631]
05:40:14 - INFO: Starting epoch 104:
100%|██████████| 181/181 [00:55<00:00,  4.55it/s, MSE=0.00662]
05:41:10 - INFO: Starting epoch 105:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.262] 
05:42:05 - INFO: Starting epoch 106:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0206]
05:43:00 - INFO: Starting epoch 107:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.00979]
05:43:56 - INFO: Starting epoch 108:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0121]
05:44:52 - INFO: Starting epoch 109:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.00493]
05:45:48 - INFO: Starting epoch 110:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0158]
05:46:43 - INFO: Starting epoch 111:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.00567]
05:47:39 - INFO: Starting epoch 112:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.00994]
05:48:34 - INFO: Starting epoch 113:
100%|██████████| 181/181 [00:56<00:00,  4.57it/s, MSE=0.00712]
05:49:31 - INFO: Starting epoch 114:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.0414]
05:50:27 - INFO: Starting epoch 115:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.00445]
05:51:23 - INFO: Starting epoch 116:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0967]
05:52:19 - INFO: Starting epoch 117:
100%|██████████| 181/181 [00:55<00:00,  4.55it/s, MSE=0.0384]
05:53:15 - INFO: Starting epoch 118:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0122]
05:54:10 - INFO: Starting epoch 119:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0342]
05:55:06 - INFO: Starting epoch 120:
100%|██████████| 181/181 [00:56<00:00,  4.63it/s, MSE=0.0257]
05:56:02 - INFO: Sampling 5 new images....
499it [00:44, 11.29it/s]
05:56:46 - INFO: Sampling 5 new images....
499it [00:45, 11.03it/s]
[Figure: samples after epoch 120, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
05:57:32 - INFO: Starting epoch 121:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.00285]
05:58:28 - INFO: Starting epoch 122:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0274]
05:59:24 - INFO: Starting epoch 123:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0629]
06:00:19 - INFO: Starting epoch 124:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0203]
06:01:15 - INFO: Starting epoch 125:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0619]
06:02:10 - INFO: Starting epoch 126:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0456]
06:03:05 - INFO: Starting epoch 127:
100%|██████████| 181/181 [00:55<00:00,  4.54it/s, MSE=0.0157]
06:04:01 - INFO: Starting epoch 128:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0136]
06:04:57 - INFO: Starting epoch 129:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.115] 
06:05:52 - INFO: Starting epoch 130:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0519]
06:06:48 - INFO: Starting epoch 131:
100%|██████████| 181/181 [00:56<00:00,  4.61it/s, MSE=0.0179]
06:07:44 - INFO: Starting epoch 132:
100%|██████████| 181/181 [00:56<00:00,  4.72it/s, MSE=0.0211]
06:08:40 - INFO: Starting epoch 133:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0172]
06:09:36 - INFO: Starting epoch 134:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.134] 
06:10:31 - INFO: Starting epoch 135:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.201] 
06:11:26 - INFO: Starting epoch 136:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0325]
06:12:22 - INFO: Starting epoch 137:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0203]
06:13:17 - INFO: Starting epoch 138:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.00265]
06:14:13 - INFO: Starting epoch 139:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.00424]
06:15:08 - INFO: Starting epoch 140:
100%|██████████| 181/181 [00:55<00:00,  4.73it/s, MSE=0.00383]
06:16:03 - INFO: Starting epoch 141:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0153]
06:16:58 - INFO: Starting epoch 142:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.0284]
06:17:53 - INFO: Starting epoch 143:
100%|██████████| 181/181 [00:55<00:00,  4.55it/s, MSE=0.00366]
06:18:48 - INFO: Starting epoch 144:
100%|██████████| 181/181 [00:56<00:00,  4.52it/s, MSE=0.0912]
06:19:44 - INFO: Starting epoch 145:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.00704]
06:20:40 - INFO: Starting epoch 146:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.0042]
06:21:36 - INFO: Starting epoch 147:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.011] 
06:22:31 - INFO: Starting epoch 148:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0379]
06:23:26 - INFO: Starting epoch 149:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.11]  
06:24:21 - INFO: Starting epoch 150:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.00667]
06:25:17 - INFO: Sampling 5 new images....
499it [00:43, 11.47it/s]
06:26:01 - INFO: Sampling 5 new images....
499it [00:42, 11.65it/s]
[Figure: samples after epoch 150, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
06:26:44 - INFO: Starting epoch 151:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.00426]
06:27:39 - INFO: Starting epoch 152:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0859]
06:28:34 - INFO: Starting epoch 153:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.0238]
06:29:29 - INFO: Starting epoch 154:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0261]
06:30:24 - INFO: Starting epoch 155:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.049] 
06:31:19 - INFO: Starting epoch 156:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.00625]
06:32:14 - INFO: Starting epoch 157:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0107]
06:33:09 - INFO: Starting epoch 158:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.13]  
06:34:04 - INFO: Starting epoch 159:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0495]
06:34:59 - INFO: Starting epoch 160:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0112]
06:35:54 - INFO: Starting epoch 161:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.00525]
06:36:49 - INFO: Starting epoch 162:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.00437]
06:37:44 - INFO: Starting epoch 163:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00408]
06:38:39 - INFO: Starting epoch 164:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0177]
06:39:35 - INFO: Starting epoch 165:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.00417]
06:40:30 - INFO: Starting epoch 166:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0786]
06:41:25 - INFO: Starting epoch 167:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0205]
06:42:20 - INFO: Starting epoch 168:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0952]
06:43:15 - INFO: Starting epoch 169:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0118]
06:44:10 - INFO: Starting epoch 170:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.253] 
06:45:06 - INFO: Starting epoch 171:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.00373]
06:46:01 - INFO: Starting epoch 172:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.00618]
06:46:56 - INFO: Starting epoch 173:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0083]
06:47:50 - INFO: Starting epoch 174:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0171]
06:48:45 - INFO: Starting epoch 175:
100%|██████████| 181/181 [00:54<00:00,  4.60it/s, MSE=0.0216]
06:49:40 - INFO: Starting epoch 176:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.00168]
06:50:35 - INFO: Starting epoch 177:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0166]
06:51:30 - INFO: Starting epoch 178:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00656]
06:52:25 - INFO: Starting epoch 179:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.114] 
06:53:20 - INFO: Starting epoch 180:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.00226]
06:54:15 - INFO: Sampling 5 new images....
499it [00:42, 11.66it/s]
06:54:58 - INFO: Sampling 5 new images....
499it [00:43, 11.56it/s]
[Figure: samples after epoch 180, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
06:55:42 - INFO: Starting epoch 181:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.0484]
06:56:37 - INFO: Starting epoch 182:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0196]
06:57:32 - INFO: Starting epoch 183:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.00695]
06:58:27 - INFO: Starting epoch 184:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0515]
06:59:21 - INFO: Starting epoch 185:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.00296]
07:00:16 - INFO: Starting epoch 186:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.0878]
07:01:11 - INFO: Starting epoch 187:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0574]
07:02:06 - INFO: Starting epoch 188:
100%|██████████| 181/181 [00:54<00:00,  4.75it/s, MSE=0.00468]
07:03:00 - INFO: Starting epoch 189:
100%|██████████| 181/181 [00:54<00:00,  4.61it/s, MSE=0.0289]
07:03:55 - INFO: Starting epoch 190:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0167]
07:04:50 - INFO: Starting epoch 191:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.0505]
07:05:45 - INFO: Starting epoch 192:
100%|██████████| 181/181 [00:54<00:00,  4.78it/s, MSE=0.00374]
07:06:39 - INFO: Starting epoch 193:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0176]
07:07:34 - INFO: Starting epoch 194:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.0161]
07:08:29 - INFO: Starting epoch 195:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0161]
07:09:24 - INFO: Starting epoch 196:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0358]
07:10:18 - INFO: Starting epoch 197:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0694]
07:11:13 - INFO: Starting epoch 198:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.107] 
07:12:09 - INFO: Starting epoch 199:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0383]
07:13:04 - INFO: Starting epoch 200:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0169]
07:13:59 - INFO: Starting epoch 201:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0855]
07:14:54 - INFO: Starting epoch 202:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.00749]
07:15:49 - INFO: Starting epoch 203:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.00324]
07:16:45 - INFO: Starting epoch 204:
100%|██████████| 181/181 [00:54<00:00,  4.74it/s, MSE=0.0965]
07:17:40 - INFO: Starting epoch 205:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.0277]
07:18:34 - INFO: Starting epoch 206:
100%|██████████| 181/181 [00:54<00:00,  4.71it/s, MSE=0.0146]
07:19:29 - INFO: Starting epoch 207:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.00659]
07:20:24 - INFO: Starting epoch 208:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0176]
07:21:19 - INFO: Starting epoch 209:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.12]  
07:22:14 - INFO: Starting epoch 210:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.0688]
07:23:10 - INFO: Sampling 5 new images....
499it [00:43, 11.17it/s]
07:23:53 - INFO: Sampling 5 new images....
499it [00:43, 11.55it/s]
[Figure: samples after epoch 210, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
07:24:37 - INFO: Starting epoch 211:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00553]
07:25:31 - INFO: Starting epoch 212:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0851]
07:26:26 - INFO: Starting epoch 213:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0147]
07:27:21 - INFO: Starting epoch 214:
100%|██████████| 181/181 [00:55<00:00,  4.75it/s, MSE=0.0669]
07:28:16 - INFO: Starting epoch 215:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.00531]
07:29:11 - INFO: Starting epoch 216:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0315]
07:30:06 - INFO: Starting epoch 217:
100%|██████████| 181/181 [00:54<00:00,  4.76it/s, MSE=0.147] 
07:31:01 - INFO: Starting epoch 218:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0547]
07:31:56 - INFO: Starting epoch 219:
100%|██████████| 181/181 [00:54<00:00,  4.74it/s, MSE=0.036] 
07:32:50 - INFO: Starting epoch 220:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.00479]
07:33:45 - INFO: Starting epoch 221:
100%|██████████| 181/181 [00:55<00:00,  4.71it/s, MSE=0.0225]
07:34:40 - INFO: Starting epoch 222:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.0192]
07:35:35 - INFO: Starting epoch 223:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.00701]
07:36:30 - INFO: Starting epoch 224:
100%|██████████| 181/181 [00:54<00:00,  4.56it/s, MSE=0.036] 
07:37:25 - INFO: Starting epoch 225:
100%|██████████| 181/181 [00:54<00:00,  4.79it/s, MSE=0.0908]
07:38:19 - INFO: Starting epoch 226:
100%|██████████| 181/181 [00:55<00:00,  4.72it/s, MSE=0.00345]
07:39:14 - INFO: Starting epoch 227:
100%|██████████| 181/181 [00:54<00:00,  4.73it/s, MSE=0.0657]
07:40:09 - INFO: Starting epoch 228:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0841]
07:41:04 - INFO: Starting epoch 229:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00407]
07:41:59 - INFO: Starting epoch 230:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.0249]
07:42:54 - INFO: Starting epoch 231:
100%|██████████| 181/181 [00:54<00:00,  4.71it/s, MSE=0.0563]
07:43:48 - INFO: Starting epoch 232:
100%|██████████| 181/181 [00:54<00:00,  4.67it/s, MSE=0.052] 
07:44:43 - INFO: Starting epoch 233:
100%|██████████| 181/181 [00:54<00:00,  4.74it/s, MSE=0.0698]
07:45:38 - INFO: Starting epoch 234:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.0553] 
07:46:33 - INFO: Starting epoch 235:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00646]
07:47:27 - INFO: Starting epoch 236:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.00258]
07:48:22 - INFO: Starting epoch 237:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0236]
07:49:17 - INFO: Starting epoch 238:
100%|██████████| 181/181 [00:54<00:00,  4.68it/s, MSE=0.0339]
07:50:12 - INFO: Starting epoch 239:
100%|██████████| 181/181 [00:54<00:00,  4.65it/s, MSE=0.00555]
07:51:07 - INFO: Starting epoch 240:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0571]
07:52:03 - INFO: Sampling 5 new images....
499it [00:42, 11.79it/s]
07:52:45 - INFO: Sampling 5 new images....
499it [00:42, 11.62it/s]
[Figure: samples after epoch 240, 2 × 5 grid; columns left to right: sunflower, rose, tulip, dandelion, daisy]
07:53:28 - INFO: Starting epoch 241:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.00577]
07:54:24 - INFO: Starting epoch 242:
100%|██████████| 181/181 [00:54<00:00,  4.69it/s, MSE=0.0155]
07:55:18 - INFO: Starting epoch 243:
100%|██████████| 181/181 [00:54<00:00,  4.64it/s, MSE=0.0322]
07:56:13 - INFO: Starting epoch 244:
100%|██████████| 181/181 [00:55<00:00,  4.76it/s, MSE=0.00787]
07:57:08 - INFO: Starting epoch 245:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.116] 
07:58:03 - INFO: Starting epoch 246:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.0187]
07:58:58 - INFO: Starting epoch 247:
100%|██████████| 181/181 [00:54<00:00,  4.70it/s, MSE=0.059] 
07:59:52 - INFO: Starting epoch 248:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0248]
08:00:48 - INFO: Starting epoch 249:
100%|██████████| 181/181 [00:54<00:00,  4.60it/s, MSE=0.0254]
08:01:42 - INFO: Starting epoch 250:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.133] 
08:02:38 - INFO: Starting epoch 251:
100%|██████████| 181/181 [00:55<00:00,  4.66it/s, MSE=0.0752]
08:03:33 - INFO: Starting epoch 252:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.00802]
08:04:28 - INFO: Starting epoch 253:
100%|██████████| 181/181 [00:54<00:00,  4.66it/s, MSE=0.254] 
08:05:23 - INFO: Starting epoch 254:
100%|██████████| 181/181 [00:54<00:00,  4.60it/s, MSE=0.0261]
08:06:18 - INFO: Starting epoch 255:
100%|██████████| 181/181 [00:54<00:00,  4.62it/s, MSE=0.0514]
08:07:13 - INFO: Starting epoch 256:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.00751]
08:08:08 - INFO: Starting epoch 257:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0209]
08:09:03 - INFO: Starting epoch 258:
100%|██████████| 181/181 [00:54<00:00,  4.63it/s, MSE=0.0484]
08:09:58 - INFO: Starting epoch 259:
100%|██████████| 181/181 [00:55<00:00,  4.69it/s, MSE=0.0255]
08:10:53 - INFO: Starting epoch 260:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00507]
08:11:49 - INFO: Starting epoch 261:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0218]
08:12:45 - INFO: Starting epoch 262:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0203]
08:13:40 - INFO: Starting epoch 263:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.036] 
08:14:36 - INFO: Starting epoch 264:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0266]
08:15:32 - INFO: Starting epoch 265:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0145] 
08:16:27 - INFO: Starting epoch 266:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00483]
08:17:23 - INFO: Starting epoch 267:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.0604]
08:18:18 - INFO: Starting epoch 268:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0466]
08:19:14 - INFO: Starting epoch 269:
100%|██████████| 181/181 [00:56<00:00,  4.59it/s, MSE=0.00358]
08:20:10 - INFO: Starting epoch 270:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.0104]
08:21:06 - INFO: Sampling 5 new images....
499it [00:45, 11.06it/s]
08:21:51 - INFO: Sampling 5 new images....
499it [00:44, 11.33it/s]
<Figure size 640x480 with 10 Axes>
08:22:36 - INFO: Starting epoch 271:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0111]
08:23:31 - INFO: Starting epoch 272:
100%|██████████| 181/181 [00:55<00:00,  4.68it/s, MSE=0.0474]
08:24:26 - INFO: Starting epoch 273:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.106] 
08:25:22 - INFO: Starting epoch 274:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.00758]
08:26:18 - INFO: Starting epoch 275:
100%|██████████| 181/181 [00:55<00:00,  4.62it/s, MSE=0.00715]
08:27:13 - INFO: Starting epoch 276:
100%|██████████| 181/181 [00:55<00:00,  4.60it/s, MSE=0.0412]
08:28:08 - INFO: Starting epoch 277:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.00941]
08:29:04 - INFO: Starting epoch 278:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0251]
08:29:59 - INFO: Starting epoch 279:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0385]
08:30:55 - INFO: Starting epoch 280:
100%|██████████| 181/181 [00:55<00:00,  4.64it/s, MSE=0.0308]
08:31:50 - INFO: Starting epoch 281:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.108] 
08:32:45 - INFO: Starting epoch 282:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.049] 
08:33:41 - INFO: Starting epoch 283:
100%|██████████| 181/181 [00:55<00:00,  4.70it/s, MSE=0.0046]
08:34:37 - INFO: Starting epoch 284:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.00371]
08:35:32 - INFO: Starting epoch 285:
100%|██████████| 181/181 [00:55<00:00,  4.61it/s, MSE=0.00421]
08:36:27 - INFO: Starting epoch 286:
100%|██████████| 181/181 [00:54<00:00,  4.72it/s, MSE=0.00429]
08:37:22 - INFO: Starting epoch 287:
100%|██████████| 181/181 [00:55<00:00,  4.67it/s, MSE=0.0258]
08:38:17 - INFO: Starting epoch 288:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0109]
08:39:13 - INFO: Starting epoch 289:
100%|██████████| 181/181 [00:55<00:00,  4.63it/s, MSE=0.152] 
08:40:08 - INFO: Starting epoch 290:
100%|██████████| 181/181 [00:55<00:00,  4.56it/s, MSE=0.0362]
08:41:04 - INFO: Starting epoch 291:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0161]
08:41:59 - INFO: Starting epoch 292:
100%|██████████| 181/181 [00:56<00:00,  4.52it/s, MSE=0.167] 
08:42:56 - INFO: Starting epoch 293:
100%|██████████| 181/181 [00:56<00:00,  4.46it/s, MSE=0.00373]
08:43:52 - INFO: Starting epoch 294:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0298]
08:44:48 - INFO: Starting epoch 295:
100%|██████████| 181/181 [00:55<00:00,  4.58it/s, MSE=0.0306]
08:45:43 - INFO: Starting epoch 296:
100%|██████████| 181/181 [00:56<00:00,  4.52it/s, MSE=0.0111]
08:46:40 - INFO: Starting epoch 297:
100%|██████████| 181/181 [00:55<00:00,  4.65it/s, MSE=0.0232]
08:47:36 - INFO: Starting epoch 298:
100%|██████████| 181/181 [00:56<00:00,  4.03it/s, MSE=0.0217]
08:48:32 - INFO: Starting epoch 299:
100%|██████████| 181/181 [00:56<00:00,  4.63it/s, MSE=0.05]  
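Each `MSE=` value in the progress bars above appears to be the noise-prediction loss reported for the latest batch. Every batch draws its own random timesteps t, so the value jumps around from epoch to epoch (0.00358 at epoch 269 versus 0.254 at epoch 253) even though training is progressing normally. The NumPy sketch below reproduces one such loss computation under stated assumptions. It uses a linear β schedule from 1e-4 to 0.02 (the common DDPM default, assumed here) and T = 500, which is inferred from the 499 sampling iterations in the log and is also an assumption. `predict_noise` is a hypothetical stand-in for the UNet, not the notebook's model.

```python
import numpy as np

T = 500                                  # assumed number of diffusion steps
betas = np.linspace(1e-4, 0.02, T)       # assumed linear noise schedule
alphas = 1.0 - betas
alpha_hat = np.cumprod(alphas)           # \bar{alpha}_t = prod_{s<=t} alpha_s


def noise_images(x0, t, rng):
    """Sample x_t ~ q(x_t | x_0) in closed form; return (x_t, eps)."""
    eps = rng.standard_normal(x0.shape)
    a = alpha_hat[t][:, None, None, None]            # broadcast over (C, H, W)
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps
    return x_t, eps


def predict_noise(x_t, t):
    """Hypothetical stand-in for the UNet: always predicts zero noise."""
    return np.zeros_like(x_t)


def training_loss(x0, rng):
    """Loss of one DDPM training step: MSE between the true and predicted noise."""
    t = rng.integers(0, T, size=x0.shape[0])         # one random timestep per image
    x_t, eps = noise_images(x0, t, rng)
    eps_pred = predict_noise(x_t, t)
    return float(np.mean((eps - eps_pred) ** 2))


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(8, 3, 64, 64))     # fake batch scaled to [-1, 1]
print(training_loss(x0, rng))   # ≈ 1.0, since the all-zero predictor ignores the noise
```

A real UNet drives this value far below 1.0, as in the log. How low it goes for a given batch depends heavily on which timesteps were drawn. That is why a single progress-bar value says little on its own.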

3.5、Sampling different kinds of flowers with the trained model

In [12]

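import numpy as np
import paddle
import matplotlib.pyplot as plt

# UNet_conditional and Diffusion are the classes defined in the earlier cells
model = UNet_conditional(num_classes=5)
model.set_state_dict(paddle.load("models/ddpm_cond270.pdparams"))   # load the model weights
diffusion = Diffusion(img_size=64, device="cuda")

# Labels 0, 1, 2, 3, 4 correspond to sunflower, rose, tulip, dandelion and daisy
labels = paddle.to_tensor([0, 0, 0, 0, 0]).astype("int64")   # five sunflowers
# Guidance scale: how strongly sampling is pushed toward the given label
cfg_scale = 7
sampled_images = diffusion.sample(model, n=len(labels), labels=labels, cfg_scale=cfg_scale)
for i in range(len(labels)):
    img = sampled_images[i].transpose([1, 2, 0])      # CHW -> HWC for matplotlib
    img = np.array(img).astype("uint8")
    plt.subplot(1, len(labels), i + 1)
    plt.imshow(img)
plt.show()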
09:01:17 - INFO: Sampling 5 new images....
499it [00:43, 11.60it/s]
<Figure size 640x480 with 5 Axes>
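`cfg_scale` sets the strength of classifier-free guidance. In its common formulation, the model predicts the noise twice at each sampling step, once with the class label and once without it. The two predictions are then extrapolated: ε̂ = ε_uncond + s·(ε_cond − ε_uncond). With s = 1 this reduces to the plain conditional prediction. Larger values, such as the 7 used above, push samples harder toward the chosen class.

The sketch below assumes this formulation and pairs it with one standard DDPM reverse step, which is the update that each of the 499 sampling iterations performs. The `Diffusion.sample` method defined earlier may differ in detail. `cfg_noise` and `reverse_step` are hypothetical names, and the schedule is the same assumed linear one with T = 500.

```python
import numpy as np

T = 500                                  # assumed number of diffusion steps
betas = np.linspace(1e-4, 0.02, T)       # assumed linear noise schedule
alphas = 1.0 - betas
alpha_hat = np.cumprod(alphas)           # cumulative product \bar{alpha}_t


def cfg_noise(eps_cond, eps_uncond, cfg_scale):
    """Classifier-free guidance: move from the unconditional prediction toward
    the conditional one, and past it when cfg_scale > 1."""
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)


def reverse_step(x_t, t, eps, rng):
    """One DDPM reverse step x_t -> x_{t-1}, given a (guided) noise estimate eps."""
    noise = rng.standard_normal(x_t.shape) if t > 0 else np.zeros_like(x_t)
    coef = (1.0 - alphas[t]) / np.sqrt(1.0 - alpha_hat[t])
    mean = (x_t - coef * eps) / np.sqrt(alphas[t])
    return mean + np.sqrt(betas[t]) * noise


print(cfg_noise(0.75, 0.25, 7))   # 3.75: well past the conditional prediction 0.75
```

If `sample` follows this formulation, guidance needs two UNet passes per step. It therefore roughly doubles sampling time compared with unguided sampling.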

4、Summary

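  • We derived the loss function of the diffusion model. Starting from minimizing the negative log-likelihood, we moved to optimizing the variational lower bound, simplified that bound, and arrived at the final objective: predicting the added noise.

  • Two versions of the code are provided. The conditional version works on much the same principle as today's popular text2image models. The difference is that text2image conditions on an encoding of the text prompt rather than on a single class label. See NovelAI for an example.

  • As a newer generation of generative models, diffusion models train very stably. The objective is a plain MSE regression on the noise, with no generator/discriminator game to balance, so hyperparameter tuning is considerably easier than for GANs.

  • The most direct way to get better results is to increase T (the number of diffusion steps) and train for more epochs. Both raise the cost. Sampling time grows linearly with T: the 499-step runs above take roughly 42 to 45 s for 5 images. Each training epoch here takes about 55 s.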
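For reference, the chain in the first bullet ends at the objective below. It is written in standard DDPM notation (β_t is the noise schedule, α_t = 1 − β_t, ᾱ_t = ∏_{s=1}^{t} α_s), which is assumed to match the notation of sections 1.3 and 1.4. The first line is the variational bound on the negative log-likelihood. The second line is the closed-form forward process. The third line is the simplified loss obtained by dropping the bound's time-dependent weights. This is the MSE reported during training.

```latex
\begin{aligned}
-\log p_\theta(x_0)
  &\le \mathbb{E}_{q(x_{1:T}\mid x_0)}\!\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T}\mid x_0)}\right] = L_{\mathrm{VLB}}, \\
x_t &= \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \\
L_{\mathrm{simple}}(\theta)
  &= \mathbb{E}_{t \sim \mathcal{U}\{1,\dots,T\},\,x_0,\,\epsilon}\!\left[\left\lVert \epsilon - \epsilon_\theta(x_t, t)\right\rVert^2\right].
\end{aligned}
```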
