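0. Video walkthrough of the project
A video tutorial is available on Bilibili: https://www.bilibili.com/video/BV1e8411a7mz
1. Diffusion model theory (deriving the loss function)
1.1 Background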
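Generative techniques for images, text, and multimodal data have accumulated steadily in recent years. The main families include generative adversarial networks (GAN), variational autoencoders (VAE), normalizing flow models, autoregressive models (AR), energy-based models, and the diffusion models that have become so popular recently.
The success of diffusion models did not come out of nowhere. A similar idea was proposed as early as 2015, and in 2020 it took the form we know today in "Denoising Diffusion Probabilistic Models" (DDPM).
The recent NovelAI image generator is also based on a diffusion model, and its results, shown below, are impressive. You can follow the link here to try it yourself.
The result this project can reach is shown below: the output for the input "sunflower" with cfg=7. The quality is already quite good.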
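1.2 Algorithms for training and sampling
The figure comes first; sections 1.3 and 1.4 walk through the procedure and derive the formulas. Our goal is to derive the loss function used during training.
1.3 Deriving the forward noising formula
The forward process of a diffusion model gradually adds Gaussian noise to the original image until the image approaches a Gaussian distribution. Because the image already contains more and more noise, the strength of the noise added at each step is also increased over time. See the figure below: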
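- The image at each step is obtained by adding noise to the image from the previous step.
- The final image becomes pure noise.
- The strength of the added noise differs at every step; common choices include the linear scheduler and the cosine scheduler.
- This process builds the labels used in training, as we will see later.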
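To make the two schedulers mentioned in the list concrete, here is a minimal sketch in plain Python. The linear schedule uses the same defaults as the `Diffusion` class later in this project (`beta_start=1e-4`, `beta_end=0.02`, 500 steps). The cosine schedule is not used by this project; the function name `cosine_beta_schedule` and its parameters `s` and `max_beta` are assumptions chosen for illustration.

```python
import math


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Evenly spaced betas from beta_start to beta_end (inclusive)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cosine_beta_schedule(num_steps, s=0.008, max_beta=0.999):
    """Define the cumulative alpha by a squared cosine, then derive the betas from it."""
    def alpha_bar(u):
        return math.cos((u + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for i in range(num_steps):
        t1, t2 = i / num_steps, (i + 1) / num_steps
        # beta_t = 1 - alpha_bar(t) / alpha_bar(t-1), clipped to avoid beta = 1
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return betas


def cumulative_alpha(betas):
    """Running product of (1 - beta); this is what the project calls alpha_hat."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


linear = linear_beta_schedule(500)
cosine = cosine_beta_schedule(500)
# Fraction of the original signal variance that survives at the last step
print(cumulative_alpha(linear)[-1], cumulative_alpha(cosine)[-1])
```

Both schedules drive the cumulative product toward zero, which is what "the final image becomes pure noise" means; they differ in how quickly the signal is destroyed along the way.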
The derivation below shows how to obtain the image at step t directly from the initial image.
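The closed form can be reconstructed from the quantities the code uses later: `beta`, `alpha = 1 - beta`, and `alpha_hat`, the cumulative product of `alpha`. What follows is a sketch of the standard DDPM derivation written with those symbols, where $\bar{\alpha}_t$ plays the role of `alpha_hat[t]`. It is offered as a reference and is not a copy of the original figure.

```latex
\begin{aligned}
q(x_t \mid x_{t-1}) &= \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right),
\qquad \alpha_t = 1-\beta_t,\quad \bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s \\
x_t &= \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} \\
    &= \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon}_{t-2} \\
    &= \cdots \\
    &= \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I)
\end{aligned}
```

The step from the second to the third line merges two independent zero-mean Gaussians: their variances $\alpha_t(1-\alpha_{t-1})$ and $1-\alpha_t$ add up to $1-\alpha_t\alpha_{t-1}$. The last line is exactly what `noise_images` computes in the training code.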
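This formula sets the stage for what follows: the next section is the key derivation of the loss function.
1.4 Optimization objective: deriving the loss function
The forward diffusion above is not hard. Next we derive the reverse diffusion process, that is, going from Xt to Xt-1.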
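As a reference for this step, here is a sketch of the standard DDPM result, with $\alpha_t = 1-\beta_t$, $\bar{\alpha}_t = \prod_{s \le t}\alpha_s$ (the code's `alpha_hat`), and $\epsilon_\theta$ the noise predicted by the network. The choice of $\sqrt{\beta_t}$ as the sampling noise scale is the one the `sample` method in this project uses.

```latex
\begin{aligned}
q(x_{t-1} \mid x_t, x_0) &= \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\right),
\qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \\
\tilde{\mu}_t &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)
\quad\text{after substituting } x_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}} \\
L_{\text{simple}} &= \mathbb{E}_{x_0,\,t,\,\epsilon}\left[\left\lVert \epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\rVert^2\right] \\
x_{t-1} &= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\,z,
\qquad z \sim \mathcal{N}(0, I)
\end{aligned}
```

The third line is the training loss (an MSE between the true and the predicted noise), and the fourth line is the update applied at every sampling step, with $z = 0$ at the final step.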
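2. Unconditional generation (random image generation)
We use the Stanford Cars images as an example; no class labels are used.
2.1 Training process
We sample from the forward process to obtain the labels and train a UNet, embedding an encoding of the time step into the model's input. This is similar to the positional encoding in Transformer models and makes the model easier to train. See the figure below: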
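The UNet lives in `modules` and is not shown in this section, so how it actually encodes the time step cannot be confirmed from here. A common choice, and the one this sketch assumes, is the sinusoidal encoding borrowed from Transformer positional encodings. The function name `timestep_embedding` and the `max_period` parameter are illustrative and not taken from the project.

```python
import math


def timestep_embedding(t, dim, max_period=10000.0):
    """Encode an integer time step t as a list of `dim` floats (sines first, then cosines)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to roughly 1 / max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


# Nearby time steps get similar vectors, distant ones get clearly different vectors
print(timestep_embedding(1, 8))
print(timestep_embedding(250, 8))
```

In a network, a vector like this is typically passed through a small MLP and added to the feature maps of each block, so every layer knows how noisy its input is.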
2.2 Extracting the data
Extract the dataset. This only needs to be done the first time the project is run!
In [13]
import os
if not os.path.exists("work/cars"):
    !mkdir work/cars
    !unzip -oq data/data173302/stanford_cars.zip -d work/cars
In [14]
# Remove the files we do not need
!rm -rf work/cars/cars_test
!rm -rf work/cars/devkit
!rm -rf work/cars/car_devkit.tgz
!rm -rf work/cars/cars_train.tgz
!rm -rf work/cars/cars_test.tgz
!rm -rf work/cars/cars_test_annos_withlabels.mat
2.3 Viewing the data
Take a look at our car images.
In [1]
import paddle
import paddle.vision
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline
# Helper that shows a list of images in a grid
def show_images(imgs_paths=(), cols=4):
    num_samples = len(imgs_paths)
    plt.figure(figsize=(15, 15))
    for i, img_path in enumerate(imgs_paths):
        img = Image.open(img_path)
        plt.subplot(int(num_samples / cols + 1), cols, i + 1)
        plt.imshow(img)

imgs_paths = [
    "work/cars/cars_train/05930.jpg", "work/cars/cars_train/06816.jpg", "work/cars/cars_train/02885.jpg", "work/cars/cars_train/07471.jpg",
    "work/cars/cars_train/06600.jpg", "work/cars/cars_train/06020.jpg", "work/cars/cars_train/04818.jpg", "work/cars/cars_train/06088.jpg"
]
show_images(imgs_paths)
<Figure size 1500x1500 with 8 Axes>
2.4 Building the dataset
The dataset interface in paddle.vision is all we need.
In [2]
import os
import paddle
import paddle.nn as nn
import paddle.vision as V
from PIL import Image
from matplotlib import pyplot as plt
from paddle.io import DataLoader
# We do not need image labels here, so the dataset interface provided by paddle.vision is enough
def get_data(args):
    transforms = V.transforms.Compose([
        V.transforms.Resize(80),  # args.image_size + 1/4 * args.image_size
        V.transforms.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        V.transforms.ToTensor(),
        V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = V.datasets.ImageFolder(args.dataset_path, transform=transforms)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return dataloader
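2.5 Training
Hyperparameters are set by editing the ARGS class. Once you know that the loss is the mean squared error between the noise that was added and the noise the model predicts (both have the shape of an image), the code becomes fairly simple. Compared with GANs, diffusion models have hyperparameters that are easier to tune, and they are easier to train.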
In [3]
"""ddpm"""
import os
import paddle
import paddle.nn as nn
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm
from paddle import optimizer
# from utils import *
from modules import UNet  # the model
import logging
import numpy as np

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        # Linear noise schedule
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # Closed-form forward process: x_t = sqrt(alpha_hat) * x_0 + sqrt(1 - alpha_hat) * noise
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            # Start from pure noise and denoise step by step
            x = paddle.randn((n, 3, self.img_size, self.img_size))
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                # print(x.shape, t.shape)
                # print(f"finished step {i}")
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    # No noise is added at the final step
                    noise = paddle.zeros_like(x)
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] back to [0, 255]
        x = (x.clip(-1, 1) + 1) / 2
        x = (x * 255)
        return x


def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    image = next(iter(dataloader))[0]
    model = UNet()
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    # logger = SummaryWriter(os.path.join("runs", args.run_name))
    l = len(dataloader)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, images in enumerate(pbar):
            # print(images)
            t = diffusion.sample_timesteps(images[0].shape[0])
            x_t, noise = diffusion.noise_images(images[0], t)
            predicted_noise = model(x_t, t)
            loss = mse(noise, predicted_noise)  # loss: MSE between the true and the predicted noise
            opt.clear_grad()
            loss.backward()
            opt.step()
            pbar.set_postfix(MSE=loss.item())
            # print(("MSE", loss.item(), "global_step", epoch * l + i))
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        if epoch % 20 == 0:
            # Save a checkpoint and show a few samples
            paddle.save(model.state_dict(), f"car_models/ddpm_uncond{epoch}.pdparams")
            sampled_images = diffusion.sample(model, n=8)
            for i in range(8):
                img = sampled_images[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2, 4, i + 1)
                plt.imshow(img)
            plt.show()


def launch():
    import argparse
    # Hyperparameters
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Unconditional"
            self.epochs = 150
            self.batch_size = 24
            self.image_size = 64
            self.dataset_path = r"/home/aistudio/work/cars"
            self.device = "cuda"
            self.lr = 1.5e-4
    args = ARGS()
    train(args)


if __name__ == '__main__':
    launch()
    pass
W1024 11:03:25.091079 573 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1024 11:03:25.094197 573 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
11:03:25 - INFO: Starting epoch 0:
100%|██████████| 340/340 [02:13<00:00, 3.70it/s, MSE=0.15]
11:05:39 - INFO: Sampling 8 new images....
499it [00:20, 23.93it/s]
<Figure size 640x480 with 8 Axes>
11:06:00 - INFO: Starting epoch 1:
100%|██████████| 340/340 [02:13<00:00, 3.11it/s, MSE=0.0725]
11:08:14 - INFO: Starting epoch 2:
100%|██████████| 340/340 [02:12<00:00, 3.37it/s, MSE=0.0777]
11:10:26 - INFO: Starting epoch 3:
100%|██████████| 340/340 [02:12<00:00, 3.44it/s, MSE=0.0814]
11:12:38 - INFO: Starting epoch 4:
100%|██████████| 340/340 [02:12<00:00, 3.30it/s, MSE=0.0579]
11:14:51 - INFO: Starting epoch 5:
100%|██████████| 340/340 [02:13<00:00, 3.40it/s, MSE=0.107]
11:17:05 - INFO: Starting epoch 6:
100%|██████████| 340/340 [02:14<00:00, 3.49it/s, MSE=0.0742]
11:19:19 - INFO: Starting epoch 7:
100%|██████████| 340/340 [02:14<00:00, 3.20it/s, MSE=0.0422]
11:21:34 - INFO: Starting epoch 8:
100%|██████████| 340/340 [02:13<00:00, 3.26it/s, MSE=0.0527]
11:23:47 - INFO: Starting epoch 9:
100%|██████████| 340/340 [02:13<00:00, 3.45it/s, MSE=0.064]
11:26:01 - INFO: Starting epoch 10:
100%|██████████| 340/340 [02:15<00:00, 2.91it/s, MSE=0.043]
11:28:17 - INFO: Starting epoch 11:
100%|██████████| 340/340 [02:14<00:00, 2.60it/s, MSE=0.0712]
11:30:31 - INFO: Starting epoch 12:
100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0674]
11:32:44 - INFO: Starting epoch 13:
100%|██████████| 340/340 [02:14<00:00, 3.00it/s, MSE=0.0464]
11:34:59 - INFO: Starting epoch 14:
100%|██████████| 340/340 [02:14<00:00, 2.93it/s, MSE=0.0349]
11:37:13 - INFO: Starting epoch 15:
100%|██████████| 340/340 [02:13<00:00, 3.58it/s, MSE=0.0279]
11:39:26 - INFO: Starting epoch 16:
100%|██████████| 340/340 [02:14<00:00, 2.62it/s, MSE=0.0436]
11:41:40 - INFO: Starting epoch 17:
100%|██████████| 340/340 [02:15<00:00, 3.06it/s, MSE=0.0278]
11:43:55 - INFO: Starting epoch 18:
100%|██████████| 340/340 [02:13<00:00, 3.03it/s, MSE=0.0318]
11:46:09 - INFO: Starting epoch 19:
100%|██████████| 340/340 [02:13<00:00, 3.01it/s, MSE=0.0743]
11:48:22 - INFO: Starting epoch 20:
100%|██████████| 340/340 [02:12<00:00, 3.26it/s, MSE=0.0721]
11:50:36 - INFO: Sampling 8 new images....
499it [00:20, 24.05it/s]
<Figure size 640x480 with 8 Axes>
11:50:57 - INFO: Starting epoch 21:
100%|██████████| 340/340 [02:13<00:00, 3.32it/s, MSE=0.0275]
11:53:10 - INFO: Starting epoch 22:
100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.028]
11:55:24 - INFO: Starting epoch 23:
100%|██████████| 340/340 [02:13<00:00, 2.89it/s, MSE=0.0155]
11:57:37 - INFO: Starting epoch 24:
100%|██████████| 340/340 [02:13<00:00, 3.17it/s, MSE=0.0386]
11:59:51 - INFO: Starting epoch 25:
100%|██████████| 340/340 [02:13<00:00, 3.16it/s, MSE=0.0189]
12:02:04 - INFO: Starting epoch 26:
100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0285]
12:04:18 - INFO: Starting epoch 27:
100%|██████████| 340/340 [02:13<00:00, 3.47it/s, MSE=0.0593]
12:06:31 - INFO: Starting epoch 28:
100%|██████████| 340/340 [02:14<00:00, 2.98it/s, MSE=0.0151]
12:08:45 - INFO: Starting epoch 29:
100%|██████████| 340/340 [02:12<00:00, 3.40it/s, MSE=0.0552]
12:10:57 - INFO: Starting epoch 30:
100%|██████████| 340/340 [02:14<00:00, 3.53it/s, MSE=0.0335]
12:13:12 - INFO: Starting epoch 31:
100%|██████████| 340/340 [02:13<00:00, 3.01it/s, MSE=0.00773]
12:15:25 - INFO: Starting epoch 32:
100%|██████████| 340/340 [02:13<00:00, 3.03it/s, MSE=0.0907]
12:17:39 - INFO: Starting epoch 33:
100%|██████████| 340/340 [02:15<00:00, 3.65it/s, MSE=0.0412]
12:19:54 - INFO: Starting epoch 34:
100%|██████████| 340/340 [02:13<00:00, 3.55it/s, MSE=0.0359]
12:22:08 - INFO: Starting epoch 35:
100%|██████████| 340/340 [02:13<00:00, 3.30it/s, MSE=0.0563]
12:24:21 - INFO: Starting epoch 36:
100%|██████████| 340/340 [02:13<00:00, 3.34it/s, MSE=0.0299]
12:26:35 - INFO: Starting epoch 37:
100%|██████████| 340/340 [02:13<00:00, 3.24it/s, MSE=0.0315]
12:28:49 - INFO: Starting epoch 38:
100%|██████████| 340/340 [02:13<00:00, 3.08it/s, MSE=0.0455]
12:31:02 - INFO: Starting epoch 39:
100%|██████████| 340/340 [02:12<00:00, 3.23it/s, MSE=0.024]
12:33:15 - INFO: Starting epoch 40:
100%|██████████| 340/340 [02:13<00:00, 3.32it/s, MSE=0.0416]
12:35:29 - INFO: Sampling 8 new images....
499it [00:20, 23.89it/s]
<Figure size 640x480 with 8 Axes>
12:35:50 - INFO: Starting epoch 41:
100%|██████████| 340/340 [02:13<00:00, 3.18it/s, MSE=0.0134]
12:38:03 - INFO: Starting epoch 42:
100%|██████████| 340/340 [02:12<00:00, 3.77it/s, MSE=0.0948]
12:40:16 - INFO: Starting epoch 43:
100%|██████████| 340/340 [02:13<00:00, 3.16it/s, MSE=0.0208]
12:42:30 - INFO: Starting epoch 44:
100%|██████████| 340/340 [02:13<00:00, 3.29it/s, MSE=0.0421]
12:44:44 - INFO: Starting epoch 45:
100%|██████████| 340/340 [02:13<00:00, 2.88it/s, MSE=0.0296]
12:46:57 - INFO: Starting epoch 46:
100%|██████████| 340/340 [02:12<00:00, 3.00it/s, MSE=0.0398]
12:49:10 - INFO: Starting epoch 47:
100%|██████████| 340/340 [02:13<00:00, 3.06it/s, MSE=0.0269]
12:51:24 - INFO: Starting epoch 48:
100%|██████████| 340/340 [02:12<00:00, 3.34it/s, MSE=0.0635]
12:53:37 - INFO: Starting epoch 49:
100%|██████████| 340/340 [02:12<00:00, 3.58it/s, MSE=0.0687]
12:55:49 - INFO: Starting epoch 50:
100%|██████████| 340/340 [02:12<00:00, 3.08it/s, MSE=0.0253]
12:58:01 - INFO: Starting epoch 51:
100%|██████████| 340/340 [02:12<00:00, 3.33it/s, MSE=0.0219]
01:00:14 - INFO: Starting epoch 52:
100%|██████████| 340/340 [02:12<00:00, 3.13it/s, MSE=0.0422]
01:02:27 - INFO: Starting epoch 53:
100%|██████████| 340/340 [02:12<00:00, 3.26it/s, MSE=0.0187]
01:04:39 - INFO: Starting epoch 54:
100%|██████████| 340/340 [02:14<00:00, 3.39it/s, MSE=0.0453]
01:06:54 - INFO: Starting epoch 55:
100%|██████████| 340/340 [02:14<00:00, 3.45it/s, MSE=0.101]
01:09:08 - INFO: Starting epoch 56:
100%|██████████| 340/340 [02:15<00:00, 3.22it/s, MSE=0.016]
01:11:23 - INFO: Starting epoch 57:
100%|██████████| 340/340 [02:14<00:00, 3.21it/s, MSE=0.0173]
01:13:38 - INFO: Starting epoch 58:
100%|██████████| 340/340 [02:13<00:00, 2.65it/s, MSE=0.0127]
01:15:52 - INFO: Starting epoch 59:
100%|██████████| 340/340 [02:14<00:00, 3.56it/s, MSE=0.112]
01:18:06 - INFO: Starting epoch 60:
100%|██████████| 340/340 [02:14<00:00, 3.01it/s, MSE=0.0155]
01:20:21 - INFO: Sampling 8 new images....
499it [00:21, 23.74it/s]
<Figure size 640x480 with 8 Axes>
01:20:42 - INFO: Starting epoch 61:
100%|██████████| 340/340 [02:15<00:00, 3.17it/s, MSE=0.0143]
01:22:58 - INFO: Starting epoch 62:
100%|██████████| 340/340 [02:15<00:00, 3.26it/s, MSE=0.0731]
01:25:14 - INFO: Starting epoch 63:
100%|██████████| 340/340 [02:14<00:00, 3.38it/s, MSE=0.0484]
01:27:28 - INFO: Starting epoch 64:
100%|██████████| 340/340 [02:16<00:00, 3.30it/s, MSE=0.0154]
01:29:45 - INFO: Starting epoch 65:
100%|██████████| 340/340 [02:15<00:00, 3.31it/s, MSE=0.0224]
01:32:00 - INFO: Starting epoch 66:
100%|██████████| 340/340 [02:15<00:00, 3.14it/s, MSE=0.0265]
01:34:16 - INFO: Starting epoch 67:
100%|██████████| 340/340 [02:14<00:00, 3.10it/s, MSE=0.0326]
01:36:30 - INFO: Starting epoch 68:
100%|██████████| 340/340 [02:14<00:00, 3.35it/s, MSE=0.0656]
01:38:44 - INFO: Starting epoch 69:
100%|██████████| 340/340 [02:14<00:00, 3.20it/s, MSE=0.0591]
01:40:58 - INFO: Starting epoch 70:
100%|██████████| 340/340 [02:13<00:00, 3.34it/s, MSE=0.0196]
01:43:12 - INFO: Starting epoch 71:
100%|██████████| 340/340 [02:15<00:00, 2.64it/s, MSE=0.021]
01:45:28 - INFO: Starting epoch 72:
100%|██████████| 340/340 [02:14<00:00, 2.85it/s, MSE=0.0166]
01:47:42 - INFO: Starting epoch 73:
100%|██████████| 340/340 [02:15<00:00, 3.31it/s, MSE=0.0408]
01:49:57 - INFO: Starting epoch 74:
100%|██████████| 340/340 [02:14<00:00, 3.06it/s, MSE=0.0705]
01:52:12 - INFO: Starting epoch 75:
100%|██████████| 340/340 [02:14<00:00, 3.06it/s, MSE=0.0326]
01:54:26 - INFO: Starting epoch 76:
100%|██████████| 340/340 [02:13<00:00, 3.55it/s, MSE=0.016]
01:56:39 - INFO: Starting epoch 77:
100%|██████████| 340/340 [02:13<00:00, 2.98it/s, MSE=0.0122]
01:58:53 - INFO: Starting epoch 78:
100%|██████████| 340/340 [02:13<00:00, 3.57it/s, MSE=0.0304]
02:01:06 - INFO: Starting epoch 79:
100%|██████████| 340/340 [02:14<00:00, 3.17it/s, MSE=0.0186]
02:03:21 - INFO: Starting epoch 80:
100%|██████████| 340/340 [02:14<00:00, 3.37it/s, MSE=0.0248]
02:05:35 - INFO: Sampling 8 new images....
499it [00:21, 22.82it/s]
<Figure size 640x480 with 8 Axes>
02:05:57 - INFO: Starting epoch 81:
100%|██████████| 340/340 [02:13<00:00, 2.93it/s, MSE=0.0321]
02:08:11 - INFO: Starting epoch 82:
100%|██████████| 340/340 [02:15<00:00, 2.76it/s, MSE=0.0274]
02:10:26 - INFO: Starting epoch 83:
100%|██████████| 340/340 [02:16<00:00, 3.49it/s, MSE=0.0069]
02:12:42 - INFO: Starting epoch 84:
100%|██████████| 340/340 [02:13<00:00, 3.05it/s, MSE=0.0847]
02:14:56 - INFO: Starting epoch 85:
100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0237]
02:17:09 - INFO: Starting epoch 86:
100%|██████████| 340/340 [02:13<00:00, 2.71it/s, MSE=0.0124]
02:19:23 - INFO: Starting epoch 87:
100%|██████████| 340/340 [02:14<00:00, 3.69it/s, MSE=0.0537]
02:21:37 - INFO: Starting epoch 88:
100%|██████████| 340/340 [02:13<00:00, 3.13it/s, MSE=0.0463]
02:23:51 - INFO: Starting epoch 89:
100%|██████████| 340/340 [02:13<00:00, 2.85it/s, MSE=0.0137]
02:26:04 - INFO: Starting epoch 90:
100%|██████████| 340/340 [02:12<00:00, 3.05it/s, MSE=0.0198]
02:28:17 - INFO: Starting epoch 91:
100%|██████████| 340/340 [02:12<00:00, 3.31it/s, MSE=0.0205]
02:30:30 - INFO: Starting epoch 92:
100%|██████████| 340/340 [02:12<00:00, 2.79it/s, MSE=0.0146]
02:32:43 - INFO: Starting epoch 93:
100%|██████████| 340/340 [02:12<00:00, 2.94it/s, MSE=0.00888]
02:34:56 - INFO: Starting epoch 94:
100%|██████████| 340/340 [02:12<00:00, 3.20it/s, MSE=0.0572]
02:37:08 - INFO: Starting epoch 95:
100%|██████████| 340/340 [02:13<00:00, 3.11it/s, MSE=0.021]
02:39:22 - INFO: Starting epoch 96:
100%|██████████| 340/340 [02:13<00:00, 3.24it/s, MSE=0.0392]
02:41:35 - INFO: Starting epoch 97:
100%|██████████| 340/340 [02:12<00:00, 2.66it/s, MSE=0.0166]
02:43:48 - INFO: Starting epoch 98:
100%|██████████| 340/340 [02:14<00:00, 2.51it/s, MSE=0.0591]
02:46:03 - INFO: Starting epoch 99:
100%|██████████| 340/340 [02:16<00:00, 3.14it/s, MSE=0.0283]
02:48:19 - INFO: Starting epoch 100:
100%|██████████| 340/340 [02:13<00:00, 3.19it/s, MSE=0.0276]
02:50:33 - INFO: Sampling 8 new images....
499it [00:21, 23.23it/s]
<Figure size 640x480 with 8 Axes>
02:50:55 - INFO: Starting epoch 101:
100%|██████████| 340/340 [02:14<00:00, 3.48it/s, MSE=0.0293]
02:53:10 - INFO: Starting epoch 102:
100%|██████████| 340/340 [02:16<00:00, 3.12it/s, MSE=0.0518]
02:55:27 - INFO: Starting epoch 103:
100%|██████████| 340/340 [02:14<00:00, 3.46it/s, MSE=0.0133]
02:57:42 - INFO: Starting epoch 104:
100%|██████████| 340/340 [02:15<00:00, 3.32it/s, MSE=0.0207]
02:59:58 - INFO: Starting epoch 105:
100%|██████████| 340/340 [02:14<00:00, 3.26it/s, MSE=0.00727]
03:02:12 - INFO: Starting epoch 106:
100%|██████████| 340/340 [02:15<00:00, 3.81it/s, MSE=0.0319]
03:04:28 - INFO: Starting epoch 107:
100%|██████████| 340/340 [02:15<00:00, 3.11it/s, MSE=0.0348]
03:06:44 - INFO: Starting epoch 108:
100%|██████████| 340/340 [02:15<00:00, 3.34it/s, MSE=0.0245]
03:08:59 - INFO: Starting epoch 109:
100%|██████████| 340/340 [02:15<00:00, 3.24it/s, MSE=0.0139]
03:11:14 - INFO: Starting epoch 110:
100%|██████████| 340/340 [02:15<00:00, 3.23it/s, MSE=0.0311]
03:13:29 - INFO: Starting epoch 111:
100%|██████████| 340/340 [02:15<00:00, 3.53it/s, MSE=0.0234]
03:15:45 - INFO: Starting epoch 112:
100%|██████████| 340/340 [02:16<00:00, 3.13it/s, MSE=0.0158]
03:18:01 - INFO: Starting epoch 113:
100%|██████████| 340/340 [02:15<00:00, 3.44it/s, MSE=0.0315]
03:20:17 - INFO: Starting epoch 114:
100%|██████████| 340/340 [02:13<00:00, 3.16it/s, MSE=0.0187]
03:22:30 - INFO: Starting epoch 115:
100%|██████████| 340/340 [02:13<00:00, 3.23it/s, MSE=0.0228]
03:24:43 - INFO: Starting epoch 116:
100%|██████████| 340/340 [02:14<00:00, 3.04it/s, MSE=0.0607]
03:26:57 - INFO: Starting epoch 117:
100%|██████████| 340/340 [02:13<00:00, 3.34it/s, MSE=0.0217]
03:29:10 - INFO: Starting epoch 118:
100%|██████████| 340/340 [02:13<00:00, 3.28it/s, MSE=0.0131]
03:31:24 - INFO: Starting epoch 119:
100%|██████████| 340/340 [02:15<00:00, 3.54it/s, MSE=0.0618]
03:33:39 - INFO: Starting epoch 120:
100%|██████████| 340/340 [02:15<00:00, 3.08it/s, MSE=0.0388]
03:35:55 - INFO: Sampling 8 new images....
499it [00:21, 23.36it/s]
<Figure size 640x480 with 8 Axes>
03:36:16 - INFO: Starting epoch 121:
100%|██████████| 340/340 [02:19<00:00, 3.14it/s, MSE=0.0142]
03:38:36 - INFO: Starting epoch 122:
100%|██████████| 340/340 [02:19<00:00, 2.97it/s, MSE=0.0112]
03:40:56 - INFO: Starting epoch 123:
100%|██████████| 340/340 [02:19<00:00, 2.84it/s, MSE=0.0243]
03:43:15 - INFO: Starting epoch 124:
100%|██████████| 340/340 [02:19<00:00, 3.11it/s, MSE=0.0312]
03:45:35 - INFO: Starting epoch 125:
100%|██████████| 340/340 [02:19<00:00, 3.26it/s, MSE=0.0513]
03:47:54 - INFO: Starting epoch 126:
100%|██████████| 340/340 [02:18<00:00, 3.10it/s, MSE=0.0254]
03:50:13 - INFO: Starting epoch 127:
100%|██████████| 340/340 [02:17<00:00, 3.18it/s, MSE=0.00965]
03:52:30 - INFO: Starting epoch 128:
100%|██████████| 340/340 [02:17<00:00, 3.35it/s, MSE=0.0183]
03:54:47 - INFO: Starting epoch 129:
100%|██████████| 340/340 [02:17<00:00, 3.36it/s, MSE=0.0158]
03:57:05 - INFO: Starting epoch 130:
100%|██████████| 340/340 [02:18<00:00, 3.29it/s, MSE=0.0326]
03:59:24 - INFO: Starting epoch 131:
100%|██████████| 340/340 [02:17<00:00, 3.18it/s, MSE=0.0224]
04:01:42 - INFO: Starting epoch 132:
100%|██████████| 340/340 [02:16<00:00, 3.11it/s, MSE=0.0367]
04:03:58 - INFO: Starting epoch 133:
100%|██████████| 340/340 [02:18<00:00, 2.95it/s, MSE=0.0231]
04:06:16 - INFO: Starting epoch 134:
100%|██████████| 340/340 [02:19<00:00, 3.34it/s, MSE=0.0195]
04:08:35 - INFO: Starting epoch 135:
100%|██████████| 340/340 [02:18<00:00, 3.30it/s, MSE=0.00914]
04:10:54 - INFO: Starting epoch 136:
100%|██████████| 340/340 [02:19<00:00, 2.76it/s, MSE=0.0355]
04:13:13 - INFO: Starting epoch 137:
100%|██████████| 340/340 [02:19<00:00, 3.14it/s, MSE=0.0365]
04:15:33 - INFO: Starting epoch 138:
100%|██████████| 340/340 [02:20<00:00, 3.38it/s, MSE=0.0182]
04:17:53 - INFO: Starting epoch 139:
100%|██████████| 340/340 [02:18<00:00, 3.19it/s, MSE=0.057]
04:20:11 - INFO: Starting epoch 140:
100%|██████████| 340/340 [02:16<00:00, 3.27it/s, MSE=0.0156]
04:22:28 - INFO: Sampling 8 new images....
499it [00:21, 22.81it/s]
<Figure size 640x480 with 8 Axes>
04:22:51 - INFO: Starting epoch 141:
100%|██████████| 340/340 [02:17<00:00, 3.11it/s, MSE=0.0256]
04:25:09 - INFO: Starting epoch 142:
100%|██████████| 340/340 [02:16<00:00, 2.82it/s, MSE=0.0271]
04:27:26 - INFO: Starting epoch 143:
100%|██████████| 340/340 [02:16<00:00, 3.35it/s, MSE=0.041]
04:29:42 - INFO: Starting epoch 144:
100%|██████████| 340/340 [02:16<00:00, 3.04it/s, MSE=0.0126]
04:31:59 - INFO: Starting epoch 145:
100%|██████████| 340/340 [02:16<00:00, 3.38it/s, MSE=0.0186]
04:34:16 - INFO: Starting epoch 146:
100%|██████████| 340/340 [02:19<00:00, 3.21it/s, MSE=0.0195]
04:36:36 - INFO: Starting epoch 147:
100%|██████████| 340/340 [02:19<00:00, 2.58it/s, MSE=0.00809]
04:38:55 - INFO: Starting epoch 148:
100%|██████████| 340/340 [02:20<00:00, 3.04it/s, MSE=0.0113]
04:41:15 - INFO: Starting epoch 149:
100%|██████████| 340/340 [02:19<00:00, 3.17it/s, MSE=0.013]
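2.6 Sampling with the trained model
We can load whichever checkpoint looked good during training and sample from it. This project is only a demonstration, and generating cars may not have much value in itself. But the latest NovelAI can already produce very high-quality anime-style paintings, so understanding the underlying principles of diffusion models through this project should make it easier to work with improved diffusion models in the future.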
In [6]
import paddle
model = UNet()
model.set_state_dict(paddle.load("car_models/ddpm_uncond140.pdparams"))  # load the checkpoint
diffusion = Diffusion(img_size=64, device="cuda")
sampled_images = diffusion.sample(model, n=8)
# Show the sampled images
for i in range(8):
    img = sampled_images[i].transpose([1, 2, 0])
    img = np.array(img).astype("uint8")
    plt.subplot(2, 4, i + 1)
    plt.imshow(img)
plt.show()
05:37:15 - INFO: Sampling 8 new images....
499it [00:22, 22.61it/s]
<Figure size 640x480 with 8 Axes>
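3. Conditional generation (label-guided image generation)
3.1 Training process
As in the unconditional case, we sample from the forward process to obtain the labels and train a UNet, embedding an encoding of the time step into the model's input. This is similar to the positional encoding in Transformer models and makes the model easier to train. Here we additionally encode the class label and feed it to the model as another input. cfg controls the mix between the conditional and the unconditional prediction: the larger cfg is, the larger the share of the conditional prediction (generated image = (1-alpha) * conditional + alpha * unconditional, where alpha depends on cfg). In the sampling code of section 3.4 the two noise predictions are combined with paddle.lerp(unconditional, conditional, cfg_scale), that is, unconditional + cfg_scale * (conditional - unconditional), which corresponds to alpha = 1 - cfg_scale.
- cfg: classifier-free guidance (label guidance)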
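The blending rule can be tried out on plain lists. This is a minimal sketch: `lerp` mirrors the definition of linear interpolation, `start + weight * (end - start)`, and `guided_noise` is a hypothetical helper name used only for illustration. The real code applies the same rule to the two noise tensors predicted by the UNet.

```python
def lerp(start, end, weight):
    """Linear interpolation, element by element: start + weight * (end - start)."""
    return [s + weight * (e - s) for s, e in zip(start, end)]


def guided_noise(uncond_noise, cond_noise, cfg_scale):
    """Classifier-free guidance: move from the unconditional prediction toward
    (and, for cfg_scale > 1, past) the conditional prediction."""
    return lerp(uncond_noise, cond_noise, cfg_scale)


uncond = [0.0, 1.0]  # toy "noise" predicted without a label
cond = [1.0, 3.0]    # toy "noise" predicted with a label

print(guided_noise(uncond, cond, 0.0))  # the unconditional prediction
print(guided_noise(uncond, cond, 1.0))  # the conditional prediction
print(guided_noise(uncond, cond, 3.0))  # extrapolated beyond the conditional prediction
```

With a scale above 1 the result is no longer a convex mix of the two predictions; it amplifies whatever the label changes about the prediction, which is why larger cfg values follow the label more strongly.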
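In addition, the training below keeps an exponential average of the previous model parameters and the current ones. This reduces the influence of outliers on the parameter updates and gives more stable updates.
- ema: exponential moving average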
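The `EMA` class is imported from `modules` and is not shown in this section, so the following is only a sketch of the usual update rule, under the assumption that the project follows it. The helper name `ema_update` is hypothetical; the decay value 0.995 matches `EMA(0.995)` in the training code.

```python
def ema_update(ema_params, model_params, beta=0.995):
    """One EMA step: keep a fraction beta of the old average, blend in (1 - beta) of the new value."""
    return [beta * e + (1.0 - beta) * p for e, p in zip(ema_params, model_params)]


# Toy example: the "model parameter" jumps around 1.0, the EMA copy moves slowly and smoothly
ema = [0.0]
for step, value in enumerate([1.2, 0.8, 1.1, 0.9, 5.0, 1.0]):  # 5.0 is an outlier
    ema = ema_update(ema, [value])
    print(step, value, ema[0])
```

Because each step only moves the average by 0.5% of the gap, a single outlier update barely shifts the EMA copy, which is the stabilizing effect described above.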
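Restart the kernel before running the code below to free the GPU memory!
3.2 Extracting the dataset
We use a flower dataset with 5 classes, so that later we can pick one of the classes to generate when sampling.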
In [1]
# Extract the flower dataset
import os
if not os.path.exists("work/flowers"):
    !mkdir work/flowers
    !unzip -oq data/data173680/flowers.zip -d work/flowers
In [2]
# Load the dataset
"""Conditional generation also needs the label of every image, so we define our own dataset here."""
# 1. Write the image paths and labels to a txt file. flowers is originally a classification dataset;
#    here we take both its training and validation splits and use them together as the training set of our generative model.
import os

# Class name -> integer label: 0 sunflower, 1 rose, 2 tulip, 3 dandelion, 4 daisy
classes = {"sunflower": 0, "rose": 1, "tulip": 2, "dandelion": 3, "daisy": 4}

with open("flowers_data.txt", 'w') as f:
    for name, label in classes.items():
        for split in ("train", "validation"):
            folder = f"work/flowers/pic/{split}/{name}"
            for image in os.listdir(folder):
                # One record per line: <image path>;<label>
                f.write(f"{folder}/{image};{label}\n")
3.3 Building the dataset
Here the data iterator has to return both the image and its label, so we build the dataset with the basic API.
In [3]
# 2. Build the dataset
# Apply the transforms and return the image together with its label
import paddle.vision as V
from PIL import Image
from paddle.io import Dataset, DataLoader
from tqdm import tqdm

# Data transforms
transforms = V.transforms.Compose([
    V.transforms.Resize(80),  # args.image_size + 1/4 * args.image_size
    V.transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
    V.transforms.ToTensor(),
    V.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


class TrainDataFlowers(Dataset):
    def __init__(self, txt_path="flowers_data.txt"):
        with open(txt_path, "r") as f:
            data = f.readlines()
        # readlines() returns no trailing empty entry, so slicing off the last
        # element would drop a real sample; keep every non-blank line instead
        self.image_paths = [line.strip() for line in data if line.strip()]

    def __getitem__(self, index):
        image_path, label = self.image_paths[index].split(";")
        image = Image.open(image_path)
        image = transforms(image)
        label = int(label)
        return image, label

    def __len__(self):
        return len(self.image_paths)


dataset = TrainDataFlowers()
dataloader = DataLoader(dataset, batch_size=24, shuffle=True)

if __name__ == "__main__":  # check that the dataset can be iterated
    pbar = tqdm(dataloader)
    for i, (images, labels) in enumerate(pbar):
        pass
    print("ok")
0%| | 0/181 [00:00<?, ?it/s]W1023 15:49:27.184664 3398 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W1023 15:49:27.188580 3398 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
100%|██████████| 181/181 [00:15<00:00, 11.37it/s]
ok
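3.4 Training
Hyperparameters are set by editing the ARGS class. Once you know that the loss is the mean squared error between the noise that was added and the noise the model predicts (both have the shape of an image), the code becomes fairly simple. Compared with GANs, diffusion models have hyperparameters that are easier to tune, and they are easier to train.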
In [4]
import os
import paddle
import copy
import paddle.nn as nn
from matplotlib import pyplot as plt
%matplotlib inline
from tqdm import tqdm
from paddle import optimizer
from modules import UNet_conditional, EMA
import logging
import numpy as np

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")


class Diffusion:
    def __init__(self, noise_steps=500, beta_start=1e-4, beta_end=0.02, img_size=256, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule()
        self.alpha = 1. - self.beta
        self.alpha_hat = paddle.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        # Linear noise schedule
        return paddle.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        # Closed-form forward process: x_t = sqrt(alpha_hat) * x_0 + sqrt(1 - alpha_hat) * noise
        sqrt_alpha_hat = paddle.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = paddle.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = paddle.randn(shape=x.shape)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return paddle.randint(low=1, high=self.noise_steps, shape=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with paddle.no_grad():
            x = paddle.randn((n, 3, self.img_size, self.img_size))
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = paddle.to_tensor([i] * x.shape[0]).astype("int64")
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    # Classifier-free guidance: blend the unconditional and conditional predictions
                    uncond_predicted_noise = model(x, t, None)
                    cfg_scale = paddle.to_tensor(cfg_scale).astype("float32")
                    predicted_noise = paddle.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = paddle.randn(shape=x.shape)
                else:
                    # No noise is added at the final step
                    noise = paddle.zeros_like(x)
                x = 1 / paddle.sqrt(alpha) * (x - ((1 - alpha) / (paddle.sqrt(1 - alpha_hat))) * predicted_noise) + paddle.sqrt(beta) * noise
        model.train()
        # Map from [-1, 1] back to [0, 255]
        x = (x.clip(-1, 1) + 1) / 2
        x = (x * 255)
        return x


def train(args):
    # setup_logging(args.run_name)
    device = args.device
    dataloader = args.dataloader
    model = UNet_conditional(num_classes=args.num_classes)
    opt = optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
    mse = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    l = len(dataloader)
    ema = EMA(0.995)
    ema_model = copy.deepcopy(model)
    ema_model.eval()
    # print("ema_model", ema_model)

    for epoch in range(args.epochs):
        logging.info(f"Starting epoch {epoch}:")
        pbar = tqdm(dataloader)
        for i, (images, labels) in enumerate(pbar):
            t = diffusion.sample_timesteps(images.shape[0])
            x_t, noise = diffusion.noise_images(images, t)
            # Drop the labels for about 10% of the batches so the model also learns the unconditional case
            if np.random.random() < 0.1:
                labels = None
            predicted_noise = model(x_t, t, labels)
            loss = mse(noise, predicted_noise)  # loss: MSE between the true and the predicted noise
            opt.clear_grad()
            loss.backward()
            opt.step()
            ema.step_ema(ema_model, model)
            pbar.set_postfix(MSE=loss.item())
            # logger.add_scalar("MSE", loss.item(), global_step=epoch * l + i)
        if epoch % 30 == 0:  # save the model and visualize the training results
            paddle.save(model.state_dict(), f"models/ddpm_cond{epoch}.pdparams")
            labels = paddle.arange(5).astype("int64")
            # 10 images are sampled in total
            # From left to right: sunflower, rose, tulip, dandelion, daisy
            sampled_images1 = diffusion.sample(model, n=len(labels), labels=labels)
            sampled_images2 = diffusion.sample(model, n=len(labels), labels=labels)
            # ema_sampled_images = diffusion.sample(ema_model, n=len(labels), labels=labels)
            for i in range(5):
                img = sampled_images1[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2, 5, i + 1)
                plt.imshow(img)
            for i in range(5):
                img = sampled_images2[i].transpose([1, 2, 0])
                img = np.array(img).astype("uint8")
                plt.subplot(2, 5, i + 1 + 5)
                plt.imshow(img)
            plt.show()


def launch():
    import argparse
    # Hyperparameters
    class ARGS:
        def __init__(self):
            self.run_name = "DDPM_Conditional"
            self.epochs = 300
            self.batch_size = 48  # not used here: the DataLoader above was built with batch_size=24
            self.image_size = 64
            self.device = "cuda"
            self.lr = 1.5e-4
            self.num_classes = 5
            self.dataloader = dataloader
    args = ARGS()
    train(args)


if __name__ == '__main__':
    # Train
    launch()
    pass
03:56:58 - INFO: Starting epoch 0:
100%|██████████| 181/181 [01:04<00:00, 3.76it/s, MSE=0.172]
03:58:03 - INFO: Sampling 5 new images....
499it [00:44, 11.13it/s]
03:58:48 - INFO: Sampling 5 new images....
499it [00:44, 11.28it/s]
<Figure size 640x480 with 10 Axes>
03:59:33 - INFO: Starting epoch 1:
100%|██████████| 181/181 [01:02<00:00, 3.81it/s, MSE=0.104]
04:00:36 - INFO: Starting epoch 2:
100%|██████████| 181/181 [01:02<00:00, 3.78it/s, MSE=0.103]
04:01:38 - INFO: Starting epoch 3:
100%|██████████| 181/181 [01:02<00:00, 3.75it/s, MSE=0.0912]
04:02:41 - INFO: Starting epoch 4:
100%|██████████| 181/181 [01:02<00:00, 3.80it/s, MSE=0.0649]
04:03:43 - INFO: Starting epoch 5:
100%|██████████| 181/181 [00:59<00:00, 4.66it/s, MSE=0.0631]
04:04:43 - INFO: Starting epoch 6:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.179]
04:05:38 - INFO: Starting epoch 7:
100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0908]
04:06:33 - INFO: Starting epoch 8:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.158]
04:07:29 - INFO: Starting epoch 9:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.171]
04:08:24 - INFO: Starting epoch 10:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0362]
04:09:20 - INFO: Starting epoch 11:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0444]
04:10:16 - INFO: Starting epoch 12:
100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.0393]
04:11:12 - INFO: Starting epoch 13:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.064]
04:12:07 - INFO: Starting epoch 14:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.035]
04:13:03 - INFO: Starting epoch 15:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.063]
04:13:58 - INFO: Starting epoch 16:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0157]
04:14:54 - INFO: Starting epoch 17:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0159]
04:15:49 - INFO: Starting epoch 18:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0212]
04:16:45 - INFO: Starting epoch 19:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0252]
04:17:40 - INFO: Starting epoch 20:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0192]
04:18:35 - INFO: Starting epoch 21:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0361]
04:19:31 - INFO: Starting epoch 22:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0177]
04:20:26 - INFO: Starting epoch 23:
100%|██████████| 181/181 [00:55<00:00, 4.54it/s, MSE=0.0527]
04:21:22 - INFO: Starting epoch 24:
100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0458]
04:22:17 - INFO: Starting epoch 25:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0539]
04:23:13 - INFO: Starting epoch 26:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.205]
04:24:09 - INFO: Starting epoch 27:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0463]
04:25:04 - INFO: Starting epoch 28:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.152]
04:26:00 - INFO: Starting epoch 29:
100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.284]
04:26:55 - INFO: Starting epoch 30:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0896]
04:27:51 - INFO: Sampling 5 new images....
499it [00:44, 11.27it/s]
04:28:36 - INFO: Sampling 5 new images....
499it [00:45, 11.01it/s]
<Figure size 640x480 with 10 Axes>
04:29:21 - INFO: Starting epoch 31:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.299]
04:30:17 - INFO: Starting epoch 32:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0226]
04:31:12 - INFO: Starting epoch 33:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.00727]
04:32:08 - INFO: Starting epoch 34:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.132]
04:33:03 - INFO: Starting epoch 35:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0498]
04:33:59 - INFO: Starting epoch 36:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0107]
04:34:55 - INFO: Starting epoch 37:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0116]
04:35:50 - INFO: Starting epoch 38:
100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.044]
04:36:46 - INFO: Starting epoch 39:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.167]
04:37:41 - INFO: Starting epoch 40:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0359]
04:38:37 - INFO: Starting epoch 41:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0064]
04:39:33 - INFO: Starting epoch 42:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0107]
04:40:28 - INFO: Starting epoch 43:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0216]
04:41:24 - INFO: Starting epoch 44:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0361]
04:42:20 - INFO: Starting epoch 45:
100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0368]
04:43:15 - INFO: Starting epoch 46:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0283]
04:44:10 - INFO: Starting epoch 47:
100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.0352]
04:45:06 - INFO: Starting epoch 48:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0499]
04:46:01 - INFO: Starting epoch 49:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0359]
04:46:56 - INFO: Starting epoch 50:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0555]
04:47:52 - INFO: Starting epoch 51:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.14]
04:48:47 - INFO: Starting epoch 52:
100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.0136]
04:49:42 - INFO: Starting epoch 53:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0242]
04:50:38 - INFO: Starting epoch 54:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0252]
04:51:33 - INFO: Starting epoch 55:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0274]
04:52:29 - INFO: Starting epoch 56:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0727]
04:53:24 - INFO: Starting epoch 57:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.023]
04:54:20 - INFO: Starting epoch 58:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0457]
04:55:15 - INFO: Starting epoch 59:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0123]
04:56:10 - INFO: Starting epoch 60:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0125]
04:57:07 - INFO: Sampling 5 new images....
499it [00:45, 10.91it/s]
04:57:52 - INFO: Sampling 5 new images....
499it [00:45, 10.90it/s]
<Figure size 640x480 with 10 Axes>
04:58:39 - INFO: Starting epoch 61:
100%|██████████| 181/181 [00:56<00:00, 4.53it/s, MSE=0.00765]
04:59:35 - INFO: Starting epoch 62:
100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.00355]
05:00:30 - INFO: Starting epoch 63:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0256]
05:01:26 - INFO: Starting epoch 64:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0413]
05:02:22 - INFO: Starting epoch 65:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0146]
05:03:17 - INFO: Starting epoch 66:
100%|██████████| 181/181 [00:56<00:00, 4.57it/s, MSE=0.00737]
05:04:13 - INFO: Starting epoch 67:
100%|██████████| 181/181 [00:56<00:00, 4.63it/s, MSE=0.00363]
05:05:09 - INFO: Starting epoch 68:
100%|██████████| 181/181 [00:56<00:00, 4.58it/s, MSE=0.121]
05:06:06 - INFO: Starting epoch 69:
100%|██████████| 181/181 [00:56<00:00, 4.53it/s, MSE=0.0124]
05:07:02 - INFO: Starting epoch 70:
100%|██████████| 181/181 [00:56<00:00, 4.53it/s, MSE=0.0235]
05:07:59 - INFO: Starting epoch 71:
100%|██████████| 181/181 [00:55<00:00, 4.52it/s, MSE=0.084]
05:08:55 - INFO: Starting epoch 72:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.022]
05:09:50 - INFO: Starting epoch 73:
100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.00922]
05:10:47 - INFO: Starting epoch 74:
100%|██████████| 181/181 [00:56<00:00, 4.27it/s, MSE=0.0059]
05:11:43 - INFO: Starting epoch 75:
100%|██████████| 181/181 [00:56<00:00, 4.60it/s, MSE=0.00901]
05:12:40 - INFO: Starting epoch 76:
100%|██████████| 181/181 [00:56<00:00, 4.60it/s, MSE=0.0261]
05:13:36 - INFO: Starting epoch 77:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.0317]
05:14:32 - INFO: Starting epoch 78:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0379]
05:15:27 - INFO: Starting epoch 79:
100%|██████████| 181/181 [00:54<00:00, 4.62it/s, MSE=0.0126]
05:16:22 - INFO: Starting epoch 80:
100%|██████████| 181/181 [00:55<00:00, 4.57it/s, MSE=0.0129]
05:17:17 - INFO: Starting epoch 81:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0174]
05:18:13 - INFO: Starting epoch 82:
100%|██████████| 181/181 [00:57<00:00, 4.59it/s, MSE=0.00267]
05:19:11 - INFO: Starting epoch 83:
100%|██████████| 181/181 [00:57<00:00, 4.61it/s, MSE=0.00863]
05:20:08 - INFO: Starting epoch 84:
100%|██████████| 181/181 [00:55<00:00, 4.59it/s, MSE=0.0928]
05:21:04 - INFO: Starting epoch 85:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0151]
05:21:59 - INFO: Starting epoch 86:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0231]
05:22:55 - INFO: Starting epoch 87:
100%|██████████| 181/181 [00:55<00:00, 4.50it/s, MSE=0.0442]
05:23:51 - INFO: Starting epoch 88:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00999]
05:24:47 - INFO: Starting epoch 89:
100%|██████████| 181/181 [00:55<00:00, 4.57it/s, MSE=0.00467]
05:25:42 - INFO: Starting epoch 90:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0219]
05:26:38 - INFO: Sampling 5 new images....
499it [00:44, 11.12it/s]
05:27:23 - INFO: Sampling 5 new images....
499it [00:45, 11.06it/s]
<Figure size 640x480 with 10 Axes>
05:28:09 - INFO: Starting epoch 91:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.00285]
05:29:05 - INFO: Starting epoch 92:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.112]
05:30:00 - INFO: Starting epoch 93:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0108]
05:30:56 - INFO: Starting epoch 94:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0281]
05:31:51 - INFO: Starting epoch 95:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0355]
05:32:47 - INFO: Starting epoch 96:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.133]
05:33:42 - INFO: Starting epoch 97:
100%|██████████| 181/181 [00:56<00:00, 4.56it/s, MSE=0.0138]
05:34:39 - INFO: Starting epoch 98:
100%|██████████| 181/181 [00:56<00:00, 4.66it/s, MSE=0.00963]
05:35:35 - INFO: Starting epoch 99:
100%|██████████| 181/181 [00:56<00:00, 4.59it/s, MSE=0.0298]
05:36:31 - INFO: Starting epoch 100:
100%|██████████| 181/181 [00:56<00:00, 4.65it/s, MSE=0.00709]
05:37:27 - INFO: Starting epoch 101:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0737]
05:38:23 - INFO: Starting epoch 102:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0105]
05:39:18 - INFO: Starting epoch 103:
100%|██████████| 181/181 [00:56<00:00, 4.56it/s, MSE=0.00631]
05:40:14 - INFO: Starting epoch 104:
100%|██████████| 181/181 [00:55<00:00, 4.55it/s, MSE=0.00662]
05:41:10 - INFO: Starting epoch 105:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.262]
05:42:05 - INFO: Starting epoch 106:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0206]
05:43:00 - INFO: Starting epoch 107:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.00979]
05:43:56 - INFO: Starting epoch 108:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0121]
05:44:52 - INFO: Starting epoch 109:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.00493]
05:45:48 - INFO: Starting epoch 110:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0158]
05:46:43 - INFO: Starting epoch 111:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.00567]
05:47:39 - INFO: Starting epoch 112:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.00994]
05:48:34 - INFO: Starting epoch 113:
100%|██████████| 181/181 [00:56<00:00, 4.57it/s, MSE=0.00712]
05:49:31 - INFO: Starting epoch 114:
100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.0414]
05:50:27 - INFO: Starting epoch 115:
100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.00445]
05:51:23 - INFO: Starting epoch 116:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0967]
05:52:19 - INFO: Starting epoch 117:
100%|██████████| 181/181 [00:55<00:00, 4.55it/s, MSE=0.0384]
05:53:15 - INFO: Starting epoch 118:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0122]
05:54:10 - INFO: Starting epoch 119:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0342]
05:55:06 - INFO: Starting epoch 120:
100%|██████████| 181/181 [00:56<00:00, 4.63it/s, MSE=0.0257]
05:56:02 - INFO: Sampling 5 new images....
499it [00:44, 11.29it/s]
05:56:46 - INFO: Sampling 5 new images....
499it [00:45, 11.03it/s]
<Figure size 640x480 with 10 Axes>
05:57:32 - INFO: Starting epoch 121:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.00285]
05:58:28 - INFO: Starting epoch 122:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0274]
05:59:24 - INFO: Starting epoch 123:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0629]
06:00:19 - INFO: Starting epoch 124:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0203]
06:01:15 - INFO: Starting epoch 125:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0619]
06:02:10 - INFO: Starting epoch 126:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0456]
06:03:05 - INFO: Starting epoch 127:
100%|██████████| 181/181 [00:55<00:00, 4.54it/s, MSE=0.0157]
06:04:01 - INFO: Starting epoch 128:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0136]
06:04:57 - INFO: Starting epoch 129:
100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.115]
06:05:52 - INFO: Starting epoch 130:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0519]
06:06:48 - INFO: Starting epoch 131:
100%|██████████| 181/181 [00:56<00:00, 4.61it/s, MSE=0.0179]
06:07:44 - INFO: Starting epoch 132:
100%|██████████| 181/181 [00:56<00:00, 4.72it/s, MSE=0.0211]
06:08:40 - INFO: Starting epoch 133:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0172]
06:09:36 - INFO: Starting epoch 134:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.134]
06:10:31 - INFO: Starting epoch 135:
100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.201]
06:11:26 - INFO: Starting epoch 136:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0325]
06:12:22 - INFO: Starting epoch 137:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0203]
06:13:17 - INFO: Starting epoch 138:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.00265]
06:14:13 - INFO: Starting epoch 139:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.00424]
06:15:08 - INFO: Starting epoch 140:
100%|██████████| 181/181 [00:55<00:00, 4.73it/s, MSE=0.00383]
06:16:03 - INFO: Starting epoch 141:
100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0153]
06:16:58 - INFO: Starting epoch 142:
100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.0284]
06:17:53 - INFO: Starting epoch 143:
100%|██████████| 181/181 [00:55<00:00, 4.55it/s, MSE=0.00366]
06:18:48 - INFO: Starting epoch 144:
100%|██████████| 181/181 [00:56<00:00, 4.52it/s, MSE=0.0912]
06:19:44 - INFO: Starting epoch 145:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.00704]
06:20:40 - INFO: Starting epoch 146:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.0042]
06:21:36 - INFO: Starting epoch 147:
100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.011]
06:22:31 - INFO: Starting epoch 148:
100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0379]
06:23:26 - INFO: Starting epoch 149:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.11]
06:24:21 - INFO: Starting epoch 150:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.00667]
06:25:17 - INFO: Sampling 5 new images....
499it [00:43, 11.47it/s]
06:26:01 - INFO: Sampling 5 new images....
499it [00:42, 11.65it/s]
<Figure size 640x480 with 10 Axes>
06:26:44 - INFO: Starting epoch 151:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.00426]
06:27:39 - INFO: Starting epoch 152:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0859]
06:28:34 - INFO: Starting epoch 153:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.0238]
06:29:29 - INFO: Starting epoch 154:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0261]
06:30:24 - INFO: Starting epoch 155:
100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.049]
06:31:19 - INFO: Starting epoch 156:
100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.00625]
06:32:14 - INFO: Starting epoch 157:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0107]
06:33:09 - INFO: Starting epoch 158:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.13]
06:34:04 - INFO: Starting epoch 159:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0495]
06:34:59 - INFO: Starting epoch 160:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0112]
06:35:54 - INFO: Starting epoch 161:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.00525]
06:36:49 - INFO: Starting epoch 162:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.00437]
06:37:44 - INFO: Starting epoch 163:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00408]
06:38:39 - INFO: Starting epoch 164:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0177]
06:39:35 - INFO: Starting epoch 165:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.00417]
06:40:30 - INFO: Starting epoch 166:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0786]
06:41:25 - INFO: Starting epoch 167:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0205]
06:42:20 - INFO: Starting epoch 168:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0952]
06:43:15 - INFO: Starting epoch 169:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0118]
06:44:10 - INFO: Starting epoch 170:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.253]
06:45:06 - INFO: Starting epoch 171:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.00373]
06:46:01 - INFO: Starting epoch 172:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.00618]
06:46:56 - INFO: Starting epoch 173:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0083]
06:47:50 - INFO: Starting epoch 174:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0171]
06:48:45 - INFO: Starting epoch 175:
100%|██████████| 181/181 [00:54<00:00, 4.60it/s, MSE=0.0216]
06:49:40 - INFO: Starting epoch 176:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.00168]
06:50:35 - INFO: Starting epoch 177:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0166]
06:51:30 - INFO: Starting epoch 178:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00656]
06:52:25 - INFO: Starting epoch 179:
100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.114]
06:53:20 - INFO: Starting epoch 180:
100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.00226]
06:54:15 - INFO: Sampling 5 new images....
499it [00:42, 11.66it/s]
06:54:58 - INFO: Sampling 5 new images....
499it [00:43, 11.56it/s]
<Figure size 640x480 with 10 Axes>
06:55:42 - INFO: Starting epoch 181:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.0484]
06:56:37 - INFO: Starting epoch 182:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0196]
06:57:32 - INFO: Starting epoch 183:
100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.00695]
06:58:27 - INFO: Starting epoch 184:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0515]
06:59:21 - INFO: Starting epoch 185:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.00296]
07:00:16 - INFO: Starting epoch 186:
100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.0878]
07:01:11 - INFO: Starting epoch 187:
100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0574]
07:02:06 - INFO: Starting epoch 188:
100%|██████████| 181/181 [00:54<00:00, 4.75it/s, MSE=0.00468]
07:03:00 - INFO: Starting epoch 189:
100%|██████████| 181/181 [00:54<00:00, 4.61it/s, MSE=0.0289]
07:03:55 - INFO: Starting epoch 190:
100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0167]
07:04:50 - INFO: Starting epoch 191:
100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.0505]
07:05:45 - INFO: Starting epoch 192:
100%|██████████| 181/181 [00:54<00:00, 4.78it/s, MSE=0.00374]
07:06:39 - INFO: Starting epoch 193:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0176]
07:07:34 - INFO: Starting epoch 194:
100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.0161]
07:08:29 - INFO: Starting epoch 195:
100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0161]
07:09:24 - INFO: Starting epoch 196:
100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0358]
07:10:18 - INFO: Starting epoch 197:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0694]
07:11:13 - INFO: Starting epoch 198:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.107]
07:12:09 - INFO: Starting epoch 199:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0383]
07:13:04 - INFO: Starting epoch 200:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0169]
07:13:59 - INFO: Starting epoch 201:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0855]
07:14:54 - INFO: Starting epoch 202:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.00749]
07:15:49 - INFO: Starting epoch 203:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.00324]
07:16:45 - INFO: Starting epoch 204:
100%|██████████| 181/181 [00:54<00:00, 4.74it/s, MSE=0.0965]
07:17:40 - INFO: Starting epoch 205:
100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.0277]
07:18:34 - INFO: Starting epoch 206:
100%|██████████| 181/181 [00:54<00:00, 4.71it/s, MSE=0.0146]
07:19:29 - INFO: Starting epoch 207:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.00659]
07:20:24 - INFO: Starting epoch 208:
100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0176]
07:21:19 - INFO: Starting epoch 209:
100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.12]
07:22:14 - INFO: Starting epoch 210:
100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.0688]
07:23:10 - INFO: Sampling 5 new images....
499it [00:43, 11.17it/s]
07:23:53 - INFO: Sampling 5 new images....
499it [00:43, 11.55it/s]
<Figure size 640x480 with 10 Axes>
07:24:37 - INFO: Starting epoch 211:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00553]
07:25:31 - INFO: Starting epoch 212:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0851]
07:26:26 - INFO: Starting epoch 213:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0147]
07:27:21 - INFO: Starting epoch 214:
100%|██████████| 181/181 [00:55<00:00, 4.75it/s, MSE=0.0669]
07:28:16 - INFO: Starting epoch 215:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.00531]
07:29:11 - INFO: Starting epoch 216:
100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0315]
07:30:06 - INFO: Starting epoch 217:
100%|██████████| 181/181 [00:54<00:00, 4.76it/s, MSE=0.147]
07:31:01 - INFO: Starting epoch 218:
100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0547]
07:31:56 - INFO: Starting epoch 219:
100%|██████████| 181/181 [00:54<00:00, 4.74it/s, MSE=0.036]
07:32:50 - INFO: Starting epoch 220:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.00479]
07:33:45 - INFO: Starting epoch 221:
100%|██████████| 181/181 [00:55<00:00, 4.71it/s, MSE=0.0225]
07:34:40 - INFO: Starting epoch 222:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.0192]
07:35:35 - INFO: Starting epoch 223:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.00701]
07:36:30 - INFO: Starting epoch 224:
100%|██████████| 181/181 [00:54<00:00, 4.56it/s, MSE=0.036]
07:37:25 - INFO: Starting epoch 225:
100%|██████████| 181/181 [00:54<00:00, 4.79it/s, MSE=0.0908]
07:38:19 - INFO: Starting epoch 226:
100%|██████████| 181/181 [00:55<00:00, 4.72it/s, MSE=0.00345]
07:39:14 - INFO: Starting epoch 227:
100%|██████████| 181/181 [00:54<00:00, 4.73it/s, MSE=0.0657]
07:40:09 - INFO: Starting epoch 228:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0841]
07:41:04 - INFO: Starting epoch 229:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00407]
07:41:59 - INFO: Starting epoch 230:
100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.0249]
07:42:54 - INFO: Starting epoch 231:
100%|██████████| 181/181 [00:54<00:00, 4.71it/s, MSE=0.0563]
07:43:48 - INFO: Starting epoch 232:
100%|██████████| 181/181 [00:54<00:00, 4.67it/s, MSE=0.052]
07:44:43 - INFO: Starting epoch 233:
100%|██████████| 181/181 [00:54<00:00, 4.74it/s, MSE=0.0698]
07:45:38 - INFO: Starting epoch 234:
100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.0553]
07:46:33 - INFO: Starting epoch 235:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00646]
07:47:27 - INFO: Starting epoch 236:
100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.00258]
07:48:22 - INFO: Starting epoch 237:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0236]
07:49:17 - INFO: Starting epoch 238:
100%|██████████| 181/181 [00:54<00:00, 4.68it/s, MSE=0.0339]
07:50:12 - INFO: Starting epoch 239:
100%|██████████| 181/181 [00:54<00:00, 4.65it/s, MSE=0.00555]
07:51:07 - INFO: Starting epoch 240:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0571]
07:52:03 - INFO: Sampling 5 new images....
499it [00:42, 11.79it/s]
07:52:45 - INFO: Sampling 5 new images....
499it [00:42, 11.62it/s]
<Figure size 640x480 with 10 Axes>
07:53:28 - INFO: Starting epoch 241:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.00577]
07:54:24 - INFO: Starting epoch 242:
100%|██████████| 181/181 [00:54<00:00, 4.69it/s, MSE=0.0155]
07:55:18 - INFO: Starting epoch 243:
100%|██████████| 181/181 [00:54<00:00, 4.64it/s, MSE=0.0322]
07:56:13 - INFO: Starting epoch 244:
100%|██████████| 181/181 [00:55<00:00, 4.76it/s, MSE=0.00787]
07:57:08 - INFO: Starting epoch 245:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.116]
07:58:03 - INFO: Starting epoch 246:
100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.0187]
07:58:58 - INFO: Starting epoch 247:
100%|██████████| 181/181 [00:54<00:00, 4.70it/s, MSE=0.059]
07:59:52 - INFO: Starting epoch 248:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0248]
08:00:48 - INFO: Starting epoch 249:
100%|██████████| 181/181 [00:54<00:00, 4.60it/s, MSE=0.0254]
08:01:42 - INFO: Starting epoch 250:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.133]
08:02:38 - INFO: Starting epoch 251:
100%|██████████| 181/181 [00:55<00:00, 4.66it/s, MSE=0.0752]
08:03:33 - INFO: Starting epoch 252:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.00802]
08:04:28 - INFO: Starting epoch 253:
100%|██████████| 181/181 [00:54<00:00, 4.66it/s, MSE=0.254]
08:05:23 - INFO: Starting epoch 254:
100%|██████████| 181/181 [00:54<00:00, 4.60it/s, MSE=0.0261]
08:06:18 - INFO: Starting epoch 255:
100%|██████████| 181/181 [00:54<00:00, 4.62it/s, MSE=0.0514]
08:07:13 - INFO: Starting epoch 256:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.00751]
08:08:08 - INFO: Starting epoch 257:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0209]
08:09:03 - INFO: Starting epoch 258:
100%|██████████| 181/181 [00:54<00:00, 4.63it/s, MSE=0.0484]
08:09:58 - INFO: Starting epoch 259:
100%|██████████| 181/181 [00:55<00:00, 4.69it/s, MSE=0.0255]
08:10:53 - INFO: Starting epoch 260:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00507]
08:11:49 - INFO: Starting epoch 261:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0218]
08:12:45 - INFO: Starting epoch 262:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0203]
08:13:40 - INFO: Starting epoch 263:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.036]
08:14:36 - INFO: Starting epoch 264:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0266]
08:15:32 - INFO: Starting epoch 265:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0145]
08:16:27 - INFO: Starting epoch 266:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00483]
08:17:23 - INFO: Starting epoch 267:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.0604]
08:18:18 - INFO: Starting epoch 268:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0466]
08:19:14 - INFO: Starting epoch 269:
100%|██████████| 181/181 [00:56<00:00, 4.59it/s, MSE=0.00358]
08:20:10 - INFO: Starting epoch 270:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.0104]
08:21:06 - INFO: Sampling 5 new images....
499it [00:45, 11.06it/s]
08:21:51 - INFO: Sampling 5 new images....
499it [00:44, 11.33it/s]
<Figure size 640x480 with 10 Axes>
08:22:36 - INFO: Starting epoch 271:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0111]
08:23:31 - INFO: Starting epoch 272:
100%|██████████| 181/181 [00:55<00:00, 4.68it/s, MSE=0.0474]
08:24:26 - INFO: Starting epoch 273:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.106]
08:25:22 - INFO: Starting epoch 274:
100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.00758]
08:26:18 - INFO: Starting epoch 275:
100%|██████████| 181/181 [00:55<00:00, 4.62it/s, MSE=0.00715]
08:27:13 - INFO: Starting epoch 276:
100%|██████████| 181/181 [00:55<00:00, 4.60it/s, MSE=0.0412]
08:28:08 - INFO: Starting epoch 277:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.00941]
08:29:04 - INFO: Starting epoch 278:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0251]
08:29:59 - INFO: Starting epoch 279:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0385]
08:30:55 - INFO: Starting epoch 280:
100%|██████████| 181/181 [00:55<00:00, 4.64it/s, MSE=0.0308]
08:31:50 - INFO: Starting epoch 281:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.108]
08:32:45 - INFO: Starting epoch 282:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.049]
08:33:41 - INFO: Starting epoch 283:
100%|██████████| 181/181 [00:55<00:00, 4.70it/s, MSE=0.0046]
08:34:37 - INFO: Starting epoch 284:
100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.00371]
08:35:32 - INFO: Starting epoch 285:
100%|██████████| 181/181 [00:55<00:00, 4.61it/s, MSE=0.00421]
08:36:27 - INFO: Starting epoch 286:
100%|██████████| 181/181 [00:54<00:00, 4.72it/s, MSE=0.00429]
08:37:22 - INFO: Starting epoch 287:
100%|██████████| 181/181 [00:55<00:00, 4.67it/s, MSE=0.0258]
08:38:17 - INFO: Starting epoch 288:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0109]
08:39:13 - INFO: Starting epoch 289:
100%|██████████| 181/181 [00:55<00:00, 4.63it/s, MSE=0.152]
08:40:08 - INFO: Starting epoch 290:
100%|██████████| 181/181 [00:55<00:00, 4.56it/s, MSE=0.0362]
08:41:04 - INFO: Starting epoch 291:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0161]
08:41:59 - INFO: Starting epoch 292:
100%|██████████| 181/181 [00:56<00:00, 4.52it/s, MSE=0.167]
08:42:56 - INFO: Starting epoch 293:
100%|██████████| 181/181 [00:56<00:00, 4.46it/s, MSE=0.00373]
08:43:52 - INFO: Starting epoch 294:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0298]
08:44:48 - INFO: Starting epoch 295:
100%|██████████| 181/181 [00:55<00:00, 4.58it/s, MSE=0.0306]
08:45:43 - INFO: Starting epoch 296:
100%|██████████| 181/181 [00:56<00:00, 4.52it/s, MSE=0.0111]
08:46:40 - INFO: Starting epoch 297:
100%|██████████| 181/181 [00:55<00:00, 4.65it/s, MSE=0.0232]
08:47:36 - INFO: Starting epoch 298:
100%|██████████| 181/181 [00:56<00:00, 4.03it/s, MSE=0.0217]
08:48:32 - INFO: Starting epoch 299:
100%|██████████| 181/181 [00:56<00:00, 4.63it/s, MSE=0.05]
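The training loop above keeps a second copy of the network, `ema_model`, and calls `ema.step_ema(ema_model, model)` after every optimizer step. The `EMA` class itself is defined elsewhere in the notebook, so the following is only a minimal sketch of what an exponential moving average with decay 0.995 typically does. Plain Python floats stand in for Paddle parameters, and the function name `ema_update` and the dictionary layout are hypothetical.

```python
def ema_update(ema_params, model_params, beta=0.995):
    """Blend the current weights into the running average, in place.

    ema_params / model_params: dicts mapping a parameter name to a float
    (a stand-in for a tensor). beta is the decay passed to EMA(0.995) above.
    """
    for name, weight in model_params.items():
        ema_params[name] = beta * ema_params[name] + (1.0 - beta) * weight
    return ema_params


# Toy usage: the "model" weight sits at 1.0 while the EMA copy,
# starting from 0.0, drifts toward it slowly.
ema = {"w": 0.0}
model = {"w": 1.0}
for _ in range(181):  # 181 iterations per epoch, as in the log above
    ema_update(ema, model)
print(round(ema["w"], 3))
```

Note that in the run above the EMA copy is updated but not used for the preview images: the line that samples from `ema_model` is commented out, so the figures come from `model`.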
3.5、Sampling various flowers with the trained model
In [12]
import numpy as np
import matplotlib.pyplot as plt
import paddle
model = UNet_conditional(num_classes=5)
model.set_state_dict(paddle.load("models/ddpm_cond270.pdparams"))  # load the saved checkpoint
diffusion = Diffusion(img_size=64, device="cuda")
# Labels 0, 1, 2, 3, 4 correspond to sunflower, rose, tulip, dandelion, daisy
labels = paddle.to_tensor([0, 0, 0, 0, 0]).astype("int64")
# Label guidance strength
cfg_scale = 7
sampled_images = diffusion.sample(model, n=len(labels), labels=labels, cfg_scale=cfg_scale)
for i in range(5):
    img = sampled_images[i].transpose([1, 2, 0])
    img = np.array(img).astype("uint8")
    plt.subplot(1, 5, i + 1)
    plt.imshow(img)
plt.show()
09:01:17 - INFO: Sampling 5 new images....
499it [00:43, 11.60it/s]
<Figure size 640x480 with 5 Axes>
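`cfg_scale` controls how strongly the class label steers sampling. The body of `diffusion.sample` is not shown in this section, so the following is a sketch of the usual classifier-free guidance rule rather than the notebook's exact code: the model is evaluated once with the label and once without, and the two noise predictions are linearly extrapolated. This works because training drops the label for roughly 10% of the batches (`labels = None`). The function name `guided_noise` is hypothetical, and plain lists stand in for tensors.

```python
def guided_noise(uncond_eps, cond_eps, cfg_scale):
    """Classifier-free guidance: uncond + cfg_scale * (cond - uncond).

    cfg_scale = 0 ignores the label, 1 gives the plain conditional
    prediction, and values above 1 push further in the label's direction.
    """
    return [u + cfg_scale * (c - u) for u, c in zip(uncond_eps, cond_eps)]


uncond = [0.10, -0.20, 0.30]   # toy noise prediction with labels=None
cond = [0.20, -0.10, 0.10]     # toy noise prediction with the class label

for scale in (0, 1, 7):
    print(scale, guided_noise(uncond, cond, scale))
```

With a large scale such as 7, the guided prediction moves well past the conditional one, which tends to give images that match the label more clearly at some cost in diversity.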
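4、Summary
- We derived the loss function of the diffusion model: starting from minimizing the negative log-likelihood, moving to optimizing the variational bound, simplifying that bound, and arriving at the final objective of predicting the noise.
- Two versions of the code are provided. The conditional version works on a principle similar to today's popular text2image models, except that text2image does not use just a single class label as the conditioning; see novelai.
- As a newer generation of generative model, diffusion trains very stably, and hyperparameter tuning is considerably simpler than for a GAN.
- For better results, the first things to try are a larger T and more training epochs.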
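The noise-prediction objective from the first point can be sketched in a few lines. This is a dependency-free illustration, not the notebook's `Diffusion` class. It assumes a linear beta schedule from 1e-4 to 0.02 (commonly used values, not confirmed by this section) and T = 500, which is inferred from the 499 sampling iterations in the logs. All helper names are hypothetical.

```python
import math
import random


def linear_betas(T=500, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (assumed schedule)."""
    step = (beta_end - beta_start) / (T - 1)
    return [beta_start + i * step for i in range(T)]


def alpha_hats(betas):
    """Cumulative products of alpha_t = 1 - beta_t."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out


def noise_image(x0, t, a_hat, rng):
    """Jump straight from x_0 to x_t: sqrt(a_hat)*x0 + sqrt(1 - a_hat)*eps."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    scale_x, scale_eps = math.sqrt(a_hat[t]), math.sqrt(1.0 - a_hat[t])
    x_t = [scale_x * x + scale_eps * e for x, e in zip(x0, eps)]
    return x_t, eps


def mse(a, b):
    """Mean squared error between the true and the predicted noise."""
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


rng = random.Random(0)
a_hat = alpha_hats(linear_betas())
x0 = [0.5, -0.25, 0.0, 1.0]        # a tiny "image" scaled to [-1, 1]
t = rng.randrange(1, len(a_hat))   # random timestep, as in training
x_t, eps = noise_image(x0, t, a_hat, rng)

# A perfect noise predictor has zero loss; predicting all zeros does not.
print(mse(eps, eps), mse(eps, [0.0] * len(eps)))
```

In the real training loop the second argument of `mse` is the UNet output `model(x_t, t, labels)`, and the sampled noise `eps` is the label that the forward process provides for free.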