昇思 25 天学习打卡营第 25 天 | MindSpore Diffusion 扩散模型

1. 背景:

使用 MindSpore 学习神经网络,打卡第 25 天;主要内容也依据 mindspore 的学习记录。

2. Diffusion 介绍:

与其他生成模型类似,Diffusion 也是把简单分布中的噪声转换为数据样本:从纯噪声开始,通过一个神经网络学习逐步去噪,最终得到一张实际图像。

3. 具体实现:

mindspore 实验中使用离散时间;

3.1 数据下载与处理:

python 复制代码
import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import download

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScaler

# Seed MindSpore's random number generation so the notebook is repeatable.
ms.set_seed(0)

# Download the cat image archive (zip) into the current directory.
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)

from mindspore.dataset import ImageFolderDataset

# Forward preprocessing: resize, center-crop, convert to a tensor, then
# linearly rescale the values with t * 2 - 1 (i.e. [0, 1] -> [-1, 1],
# assuming ToTensor yields values in [0, 1] -- TODO confirm).
image_size = 128
transforms = [
    Resize(image_size, Inter.BILINEAR),
    CenterCrop(image_size),
    ToTensor(),
    lambda t: (t * 2) - 1
]


# Read the extracted images from disk; `path` is rebound from the download
# result to the image folder.
path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
                             extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False,  decode=True)
# Work with the 'image' column only.
dataset = dataset.project('image')
# Insert a random horizontal flip at index 1, i.e. right after Resize.
# NOTE: this mutates the `transforms` list in place.
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
# Take the first batch (batch size 1) as the reference image x_start.
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)

import numpy as np

# Inverse pipeline: takes a single CHW tensor with values in [-1, 1] and
# turns it back into a PIL image. Apply the steps in order (see `compose`).
reverse_transform = [
    lambda img: (img + 1) / 2,                    # [-1, 1] -> [0, 1]
    lambda img: ops.permute(img, (1, 2, 0)),      # CHW -> HWC
    lambda img: img * 255.,                       # [0, 1] -> [0, 255]
    lambda img: img.asnumpy().astype(np.uint8),   # tensor -> uint8 ndarray
    ToPIL()
]

def compose(transform, x):
    """Apply every callable in ``transform`` to ``x`` in order.

    Args:
        transform: iterable of one-argument callables.
        x: initial value fed to the first callable.

    Returns:
        The output of the last callable, or ``x`` unchanged when
        ``transform`` is empty.
    """
    result = x
    for step in transform:
        result = step(result)
    return result

# Round-trip check: convert the first image of the batch back to PIL and
# display it.
reverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()

3.2 构造 Diffusion 网络

  • 定义模型的辅助方法:
python 复制代码
def rearrange(head, inputs):
    """Split the channel axis into heads and flatten the two spatial axes.

    Args:
        head: number of attention heads.
        inputs: 4-D array/tensor whose shape unpacks as (b, head * c, x, y).

    Returns:
        ``inputs`` reshaped to (b, head, c, x * y), with c = channels // head.
    """
    batch, channels, height, width = inputs.shape
    per_head = channels // head
    target_shape = (batch, head, per_head, height * width)
    return inputs.reshape(target_shape)

def rsqrt(x):
    """Return the element-wise reciprocal square root, ``ops.inv(ops.sqrt(x))``."""
    return ops.inv(ops.sqrt(x))

def randn_like(x, dtype=None):
    """Draw standard-normal noise with the same shape as ``x``.

    Args:
        x: tensor providing the output shape (and the dtype when ``dtype``
            is None).
        dtype: optional dtype of the result.

    Returns:
        A tensor from ``ops.standard_normal`` cast to the chosen dtype.
    """
    target_dtype = x.dtype if dtype is None else dtype
    return ops.standard_normal(x.shape).astype(target_dtype)

def randn(shape, dtype=None):
    """Draw standard-normal noise of the given shape.

    Args:
        shape: shape passed to ``ops.standard_normal``.
        dtype: dtype of the result; ``ms.float32`` when None.

    Returns:
        A tensor from ``ops.standard_normal`` cast to the chosen dtype.
    """
    target_dtype = ms.float32 if dtype is None else dtype
    return ops.standard_normal(shape).astype(target_dtype)

def randint(low, high, size, dtype=ms.int32):
    """Sample values of type ``dtype`` via ``ops.uniform`` between ``low`` and ``high``.

    Args:
        low: lower bound, wrapped in a ``Tensor`` of ``dtype``.
        high: upper bound, wrapped in a ``Tensor`` of ``dtype``.
        size: shape of the output.
        dtype: dtype used for both bounds and the result.

    Returns:
        The tensor returned by ``ops.uniform``.
        NOTE(review): whether ``high`` is inclusive is determined by
        ``ops.uniform`` -- confirm against the MindSpore documentation.
    """
    lower = Tensor(low, dtype)
    upper = Tensor(high, dtype)
    return ops.uniform(size, lower, upper, dtype=dtype)

def exists(x):
    """Return True unless ``x`` is None (falsy values such as 0 still count)."""
    return not (x is None)

def default(val, d):
    """Return ``val`` when it is not None, otherwise the fallback ``d``.

    If ``d`` is callable it is called with no arguments and its result is
    returned, so the fallback can be computed lazily.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val

def _check_dtype(d1, d2):
    """Pick a common dtype for two operands.

    Returns ``ms.float32`` if either dtype is float32, the shared dtype if
    both are equal, and raises otherwise.

    Raises:
        ValueError: if the dtypes differ and neither is float32.
    """
    if ms.float32 in (d1, d2):
        result = ms.float32
    elif d1 == d2:
        result = d1
    else:
        raise ValueError('dtype is not supported.')
    return result

class Residual(nn.Cell):
    """Wrap a callable so its output is added back to its first input."""

    def __init__(self, fn):
        super().__init__()
        # The wrapped cell/callable; it must return something addable to x.
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        # Skip connection: fn(x, ...) + x. Extra arguments are forwarded as-is.
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim):
    """Build a transposed convolution (kernel 4, stride 2, padding 1) that keeps ``dim`` channels."""
    return nn.Conv2dTranspose(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        pad_mode="pad",
        padding=1,
    )

def Downsample(dim):
    """Build a strided convolution (kernel 4, stride 2, padding 1) that keeps ``dim`` channels."""
    return nn.Conv2d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        pad_mode="pad",
        padding=1,
    )
  • 位置向量
  • 由于神经网络的参数在时间(噪声水平)上共享,作者使用正弦位置嵌入来编码 t,灵感来自 Transformer(Vaswani et al., 2017)。这样,对于批处理中的每一张图像,神经网络都"知道"它在哪个特定时间步长(噪声水平)上运行。
python 复制代码
class SinusoidalPositionEmbeddings(nn.Cell):
    """Sinusoidal embedding of a 1-D batch of values (used here for timesteps)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        # Frequencies exp(-i * log(10000) / (half_dim - 1)) for i in [0, half_dim).
        # NOTE: half_dim == 1 (dim of 2 or 3) divides by zero here.
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        # Precomputed once and stored as a constant float32 tensor.
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        # Outer product: (n,) x (half_dim,) -> (n, half_dim).
        emb = x[:, None] * self.emb[None, :]
        # Sine and cosine halves are concatenated on the last axis
        # -> (n, 2 * half_dim).
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb
  • ResNet/ConvNeXT块

原作者使用 Wide ResNet 块,Phil Wang决定添加 ConvNeXT 替换 ResNet block.

python 复制代码
class Block(nn.Cell):
    """3x3 convolution -> GroupNorm -> optional scale/shift -> SiLU.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        groups: number of groups for ``nn.GroupNorm``.
    """

    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        # The original code followed this with a second assignment through an
        # undefined name `c`, which raised NameError on instantiation; only
        # this equivalent, valid projection is kept.
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        """Run the block.

        Args:
            x: input feature map.
            scale_shift: optional ``(scale, shift)`` pair applied after
                normalisation as ``x * (scale + 1) + shift``.

        Returns:
            The activated feature map.
        """
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ConvNextBlock(nn.Cell):
    """ConvNeXT-style residual block with optional time-embedding conditioning.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        time_emb_dim: width of the time embedding; when None no conditioning
            MLP is created.
        mult: channel expansion factor of the hidden convolution.
        norm: whether the first layer of ``net`` is a GroupNorm (else Identity).
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        # Projects the time embedding to `dim` channels so it can be added to
        # the depthwise-conv output.
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # 7x7 convolution with group=dim, i.e. one filter group per channel.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        # 1x1 conv on the skip path only when the channel count changes.
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            # NOTE(review): this assert can never fail because the enclosing
            # condition already requires time_emb to exist; a missing
            # time_emb silently skips conditioning instead.
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            # Append two trailing singleton axes so the condition broadcasts
            # over the spatial dimensions.
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)
  • Attention 模块:
    DDPM 作者将注意力模块添加到卷积块之间。Attention 是著名的 Transformer 架构(Vaswani et al., 2017)的核心组件,在人工智能的各个领域都取得了巨大的成功,从 NLP 到蛋白质折叠。Phil Wang 使用了两种注意力变体:一种是常规的 multi-head self-attention(如 Transformer 中使用的),另一种是 LinearAttention(Shen et al., 2018),其时间和内存开销随序列长度线性增长,而不是像常规注意力那样呈二次方增长。如需深入了解 Attention 机制,可参阅 Jay Alammar 的相关博文。
python 复制代码
class Attention(nn.Cell):
    """Multi-head self-attention over the flattened spatial positions of a 4-D input.

    Args:
        dim: number of input/output channels.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        # Query scaling factor, 1 / sqrt(dim_head).
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        # 1x1 convs: to_qkv yields q, k, v stacked on the channel axis;
        # to_out projects the attended features back to `dim` channels.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        # Split channels into three equal chunks (q, k, v).
        qkv = self.to_qkv(x).chunk(3, 1)
        # Apply rearrange(self.heads, .) to each chunk -> (b, heads, c, h*w).
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = q * self.scale

        # 'b h d i, b h d j -> b h i j'
        sim = ops.bmm(q.swapaxes(2, 3), k)
        # Normalise over the last (key position) axis.
        attn = ops.softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        out = ops.bmm(attn, v.swapaxes(2, 3))
        # Back to channels-first and restore the (h, w) spatial layout.
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))

        return self.to_out(out)


class LayerNorm(nn.Cell):
    """Normalise over axis 1 with a learnable per-channel gain and no bias."""

    def __init__(self, dim):
        super().__init__()
        # Gain initialised to ones, shaped (1, dim, 1, 1) to broadcast over a
        # 4-D input.
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')

    def construct(self, x):
        eps = 1e-5
        # NOTE(review): `var` is called with `keepdims` and `mean` with
        # `keep_dims`; confirm both spellings against the MindSpore version
        # in use.
        var = x.var(1, keepdims=True)
        mean = x.mean(1, keep_dims=True)
        return (x - mean) * rsqrt((var + eps)) * self.g


class LinearAttention(nn.Cell):
    """Linear attention variant: softmax is applied to q and k separately, then
    k is contracted with v before being combined with q.

    Args:
        dim: number of input/output channels.
        heads: number of attention heads.
        dim_head: channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # 1x1 conv producing q, k, v stacked on the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)

        # Output projection followed by the channel-wise LayerNorm defined in
        # this file.
        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim)
        )

        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        # Each of q, k, v becomes (b, heads, c, h*w).
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        # q is normalised over the per-head channel axis, k over positions.
        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)

        q = q * self.scale
        # Scale v by the number of spatial positions.
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)

        # Merge heads back into the channel axis and restore (h, w).
        out = out.reshape((b, -1, h, w))
        return self.to_out(out)
  • 组归一化
    DDPM 作者将 U-Net 的卷积/注意层与组归一化(Group Normalization,Wu et al., 2018)交错使用。下面,我们定义一个 PreNorm 类,用于在注意层之前应用 groupnorm。
python 复制代码
class PreNorm(nn.Cell):
    """Apply GroupNorm (a single group over ``dim`` channels) before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def construct(self, x):
        # Normalise first, then hand the result to the wrapped callable.
        x = self.norm(x)
        return self.fn(x)
  • 条件 U-Net:

网络构建过程如下:

a. 将卷积层应用于噪声图像批上,并计算噪声水平的位置

b. 接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成

c. 在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织

d. 应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成

e. 应用ResNet/ConvNeXT块,然后应用卷积层

python 复制代码
class Unet(nn.Cell):
    """Conditional U-Net built from ConvNextBlock stages with linear attention.

    Args:
        dim: base channel width; stage widths are ``dim * m`` for each ``m``
            in ``dim_mults``.
        init_dim: channels after the initial convolution; defaults to
            ``dim // 3 * 2``.
        out_dim: output channels; defaults to ``channels``.
        dim_mults: per-stage channel multipliers.
        channels: number of input image channels.
        with_time_emb: whether to build the timestep-embedding MLP.
        convnext_mult: ``mult`` passed to every ConvNextBlock.
    """

    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        # Channel plan: [init_dim, dim*m0, dim*m1, ...] and its consecutive
        # (in, out) pairs, one per resolution stage.
        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            # Timestep -> sinusoidal embedding -> 2-layer MLP of width 4 * dim.
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        # Encoder: two blocks + residual linear attention per stage, then a
        # downsample on every stage except the last.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        # Bottleneck: block -> full attention -> block at the widest channel count.
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        # Decoder: mirrors the encoder stages except the first one.
        # NOTE: this loop runs num_resolutions - 1 times, so `ind` never
        # reaches num_resolutions - 1 and `is_last` is always False here;
        # every decoder stage therefore ends with an Upsample.
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        # Input is dim_out * 2 because the skip tensor is
                        # concatenated on the channel axis in construct().
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        # Final block (no time conditioning) followed by a 1x1 projection.
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        x = self.init_conv(x)

        # Time embedding shared by all conditioned blocks (None when disabled).
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # Skip tensors saved by the encoder, one per stage.
        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # Consume the skips from the end of the list; the first encoder
        # stage's skip is never used because there is one fewer decoder stage.
        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        return self.final_conv(x)

3.3 正向扩散

正向扩散

python 复制代码
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas spaced linearly from 1e-4 to 0.02.

    Args:
        timesteps: number of diffusion steps (length of the schedule).

    Returns:
        A 1-D float32 NumPy array of length ``timesteps``.
    """
    beta_start, beta_end = 0.0001, 0.02
    schedule = np.linspace(beta_start, beta_end, timesteps)
    return schedule.astype(np.float32)

# Number of diffusion steps.
timesteps = 200

# Beta (noise variance) schedule as a float32 NumPy array.
betas = linear_beta_schedule(timesteps=timesteps)

# alpha_t = 1 - beta_t, its running product, and that product shifted right
# by one step with a leading 1.
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

# Precomputed per-step coefficients, stored as tensors.
sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# Variance of the posterior q(x_{t-1} | x_t, x_0); kept as a NumPy array
# (`extract` wraps it in a Tensor when it is used).
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# Per-step loss weights. The exponent is -0., so every weight evaluates to 1.
p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)

def extract(a, t, x_shape):
    """Look up per-sample schedule values and shape them for broadcasting.

    Args:
        a: 1-D table of per-timestep values (anything ``Tensor`` accepts).
        t: 1-D tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        ``a[t]`` reshaped to ``(batch, 1, ..., 1)`` with ``len(x_shape)`` axes.
    """
    batch = t.shape[0]
    gathered = Tensor(a).gather(t, -1)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing_ones)
  • 正向扩散的过程,并在特定时间步长上测试它
python 复制代码
def q_sample(x_start, t, noise=None):
    """Sample x_t from the forward diffusion process q(x_t | x_0).

    This function was missing from the file although both
    ``get_noisy_image`` and ``p_losses`` call it.

    Args:
        x_start: batch of clean images x_0.
        t: 1-D tensor of timestep indices, one per batch element.
        noise: optional noise tensor shaped like ``x_start``; standard-normal
            noise is drawn when None.

    Returns:
        ``sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise``.
    """
    if noise is None:
        noise = randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


def get_noisy_image(x_start, t):
    """Diffuse ``x_start`` to timestep ``t`` and return the first image as PIL.

    Args:
        x_start: batch of clean images with values in [-1, 1].
        t: 1-D tensor of timestep indices, one per batch element.

    Returns:
        The first noised image of the batch, converted via ``reverse_transform``.
    """
    # Add noise
    x_noisy = q_sample(x_start, t=t)

    # Convert to a PIL image
    noisy_image = compose(reverse_transform, x_noisy[0])

    return noisy_image


# Pick a timestep and show the reference image after diffusing it to that step.
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()
  • 可视化
python 复制代码
import matplotlib.pyplot as plt

def plot(imgs, with_orig=False, row_title=None, orig_img=None, **imshow_kwargs):
    """Show a grid of images with matplotlib.

    Args:
        imgs: a list of images (one row) or a list of such lists (several rows).
        with_orig: when True, prepend ``orig_img`` as the first column of
            every row and title the top-left cell 'Original image'.
        row_title: optional sequence of y-axis labels, one per row.
        orig_img: the original image shown when ``with_orig`` is True. The
            earlier code read an undefined global ``image`` here, which
            raised NameError.
        **imshow_kwargs: forwarded to ``Axes.imshow``.

    Raises:
        ValueError: if ``with_orig`` is True but ``orig_img`` is not given.
    """
    if with_orig and orig_img is None:
        raise ValueError("plot(with_orig=True) requires the orig_img argument")

    # Normalise a flat list of images into a single-row grid.
    if not isinstance(imgs[0], list):
        imgs = [imgs]

    num_rows = len(imgs)
    # bool `with_orig` adds one extra column when True.
    num_cols = len(imgs[0]) + with_orig
    # NOTE(review): figsize is given in inches, so (200, 200) is a very large
    # canvas; kept as in the original.
    _, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            # Hide ticks and tick labels.
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
 
# Visualise the reference image at increasing noise levels, up to the last
# valid step index (timesteps - 1 = 199).
plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])
  • 定义模型的损失函数
python 复制代码
def p_losses(unet_model, x_start, t, noise=None):
    """Compute the weighted noise-prediction loss for one batch.

    Args:
        unet_model: callable ``(x_noisy, t) -> predicted noise``.
        x_start: batch of clean images.
        t: 1-D tensor of timestep indices, one per batch element.
        noise: optional noise to inject; drawn with ``randn_like`` when None.

    Returns:
        The mean of the per-element Smooth-L1 loss between the true and
        predicted noise, weighted per sample by ``p2_loss_weight[t]``.
    """
    if noise is None:
        noise = randn_like(x_start)

    noisy_input = q_sample(x_start=x_start, t=t, noise=noise)
    noise_estimate = unet_model(noisy_input, t)

    # The loss must come back unreduced (one value per element) for the
    # reshape below to work -- presumably SmoothL1Loss's default reduction;
    # confirm against the MindSpore docs.
    criterion = nn.SmoothL1Loss()
    elementwise = criterion(noise, noise_estimate)
    per_sample = elementwise.reshape(elementwise.shape[0], -1)
    weights = extract(p2_loss_weight, t, per_sample.shape)
    return (per_sample * weights).mean()

3.4 训练

  • 下载数据并转换

    下载 Fashion-MNIST 数据集(下面的代码通过 FashionMnistDataset 加载)

    # Download the dataset archive (zip) into the current directory.
    url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
    path = download(url, './', kind="zip", replace=True)

    from mindspore.dataset import FashionMnistDataset

    # These rebind the module-level names used earlier for the cat image demo.
    image_size = 28
    channels = 1
    batch_size = 16

    fashion_mnist_dataset_dir = "./dataset"
    dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train", num_parallel_workers=cpu_count(), shuffle=True, num_shards=1, shard_id=0)

    # Random flip, conversion to a tensor, then rescaling with t * 2 - 1.
    transforms = [
    RandomHorizontalFlip(),
    ToTensor(),
    lambda t: (t * 2) - 1
    ]

    # Keep only the 'image' column, shuffle with a buffer of 64, apply the
    # transforms and batch.
    dataset = dataset.project('image')
    dataset = dataset.shuffle(64)
    dataset = dataset.map(transforms, 'image')
    # NOTE(review): the literal 16 duplicates `batch_size` defined above.
    dataset = dataset.batch(16, drop_remainder=True)

    # Sanity check: fetch one batch and print its column names.
    x = next(dataset.create_dict_iterator())
    print(x.keys())

  • 采样:
    训练期间从模型中采样(以便跟踪进度)

python 复制代码
def p_sample(model, x, t, t_index):
    """Run one reverse-diffusion step.

    Args:
        model: callable ``(x, t) -> predicted noise``.
        x: current noisy batch.
        t: 1-D tensor holding the timestep index for each batch element.
        t_index: the same timestep as a Python int; at 0 no noise is added.

    Returns:
        The model mean when ``t_index == 0``; otherwise the mean plus
        ``sqrt(posterior_variance[t])`` times fresh standard-normal noise.
    """
    shape = x.shape
    beta_t = extract(betas, t, shape)
    sqrt_one_minus_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    sqrt_recip_alpha_t = extract(sqrt_recip_alphas, t, shape)

    predicted_noise = model(x, t)
    model_mean = sqrt_recip_alpha_t * (
        x - beta_t * predicted_noise / sqrt_one_minus_cumprod_t
    )

    if t_index != 0:
        variance_t = extract(posterior_variance, t, shape)
        noise = randn_like(x)
        return model_mean + ops.sqrt(variance_t) * noise
    return model_mean

def p_sample_loop(model, shape):
    """Generate samples by running the full reverse process from pure noise.

    Args:
        model: callable ``(x, t) -> predicted noise``.
        shape: shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        A list with one NumPy array per timestep, ordered from step
        ``timesteps - 1`` down to step 0 (the last entry is the final sample).
    """
    batch = shape[0]
    # Start from pure noise.
    img = randn(shape, dtype=None)
    history = []

    steps = reversed(range(0, timesteps))
    for step in tqdm(steps, desc='sampling loop time step', total=timesteps):
        t_batch = ms.numpy.full((batch,), step, dtype=mstype.int32)
        img = p_sample(model, img, t_batch, step)
        history.append(img.asnumpy())
    return history

def sample(model, image_size, batch_size=16, channels=3):
    """Sample a batch of square images via ``p_sample_loop``.

    Args:
        model: callable ``(x, t) -> predicted noise``.
        image_size: height and width of the generated images.
        batch_size: number of images to generate.
        channels: number of image channels.

    Returns:
        Whatever ``p_sample_loop`` returns for shape
        ``(batch_size, channels, image_size, image_size)``.
    """
    target_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=target_shape)
  • 训练
python 复制代码
# Cosine-decay learning-rate schedule over 10 epochs of 3750 steps each
# (presumably 60000 training images / batch size 16 -- TODO confirm).
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

# Build the U-Net; the base width `dim` is set to the image size.
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

# Rename each trainable parameter to the name reported by
# parameters_and_names(), pairing the two lists by position.
# NOTE(review): this assumes both calls enumerate parameters in the same
# order -- confirm.
name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1

# Optimizer.
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
# NOTE: loss_scaler is created but never used in the code shown below.
loss_scaler = DynamicLossScaler(65536, 2, 1000)

# Forward pass: diffusion loss of the global unet_model for one batch.
def forward_fn(data, t, noise=None):
    loss = p_losses(unet_model, data, t, noise)
    return loss

# Function returning both the loss value and its gradients with respect to
# the optimizer's parameters.
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

# One optimisation step: compute loss and gradients, apply the update, and
# return the loss.
def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss

import time

epochs = 10
for epoch in range(epochs):
    begin_time = time.time()
    for step, batch in enumerate(dataset.create_tuple_iterator()):
        # Re-enabled every step because the model is switched to eval mode
        # for sampling at the end of each epoch.
        unet_model.set_train()
        batch_size = batch[0].shape[0]
        # One random timestep and one noise tensor per batch.
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)

        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
    # Show a randomly generated sample after each epoch.
    unet_model.set_train(False)
    samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    # samples[-1] is the final denoising step; index 5 picks one image.
    # NOTE(review): reshape (not transpose) turns (C, H, W) into (H, W, C);
    # that is only equivalent when channels == 1.
    plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")

3.5 推理

要从模型中采样,我们可以只使用上面定义的采样函数

python 复制代码
# Sample 64 images with the model in eval mode.
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)

# Show one of the final samples (the index is fixed, not actually random).
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

# Animate the denoising process of a single sample.
import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
# One frame per stored step, in sampling order (most noisy first).
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
# Write the animation to a GIF file in the working directory.
animate.save('diffusion.gif')
plt.show()

4. 相关链接:

相关推荐
昨日之日20062 小时前
Moonshine - 新型开源ASR(语音识别)模型,体积小,速度快,比OpenAI Whisper快五倍 本地一键整合包下载
人工智能·whisper·语音识别
浮生如梦_2 小时前
Halcon基于laws纹理特征的SVM分类
图像处理·人工智能·算法·支持向量机·计算机视觉·分类·视觉检测
深度学习lover2 小时前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别
热爱跑步的恒川3 小时前
【论文复现】基于图卷积网络的轻量化推荐模型
网络·人工智能·开源·aigc·ai编程
阡之尘埃5 小时前
Python数据分析案例61——信贷风控评分卡模型(A卡)(scorecardpy 全面解析)
人工智能·python·机器学习·数据分析·智能风控·信贷风控
孙同学要努力7 小时前
全连接神经网络案例——手写数字识别
人工智能·深度学习·神经网络
Eric.Lee20217 小时前
yolo v5 开源项目
人工智能·yolo·目标检测·计算机视觉
其实吧38 小时前
基于Matlab的图像融合研究设计
人工智能·计算机视觉·matlab
丕羽8 小时前
【Pytorch】基本语法
人工智能·pytorch·python
ctrey_8 小时前
2024-11-1 学习人工智能的Day20 openCV(2)
人工智能·opencv·学习