生成模型实战 | GLOW详解与实现

生成模型实战 | GLOW详解与实现

    • [0. 前言](#0. 前言)
    • [1. 归一化流模型](#1. 归一化流模型)
      • [1.1 归一化流与变换公式](#1.1 归一化流与变换公式)
      • [1.2 RealNVP 的通道翻转](#1.2 RealNVP 的通道翻转)
    • [2. GLOW 架构](#2. GLOW 架构)
      • [2.1 ActNorm](#2.1 ActNorm)
      • [2.2 可逆 1×1 卷积](#2.2 可逆 1×1 卷积)
      • [2.3 仿射耦合层](#2.3 仿射耦合层)
      • [2.4 多尺度架构](#2.4 多尺度架构)
    • [3. 使用 PyTorch 实现 GLOW](#3. 使用 PyTorch 实现 GLOW)
      • [3.1 数据处理](#3.1 数据处理)
      • [3.2 模型构建](#3.2 模型构建)
      • [3.3 模型训练](#3.3 模型训练)

0. 前言

GLOW (Generative Flow) 是一种基于归一化流的生成模型,通过在每个流步骤中引入可逆的 1 × 1 卷积层,替代了RealNVP中通道翻转或固定置换的策略,从而使通道重排更具表达力,同时保持雅可比行列式和逆变换的高效计算能力。本文首先回顾归一化流与 RealNVP 的基本原理,接着剖析 GLOW 的四大核心模块:ActNorm、可逆 1×1 卷积、仿射耦合层和多尺度架构,随后基于 PyTorch 实现 GLOW 模型,并在 CIFAR-10 数据集上进行训练。

1. 归一化流模型

1.1 归一化流与变换公式

在本节中,我们首先简要回顾归一化流模型的核心原理,归一化流利用可逆映射 f f f 将简单分布 p Z ( z ) p_Z(z) pZ(z) 转换到样本分布 p X ( x ) p_X(x) pX(x),并通过以下变换公式实现实现精确对数似然计算和采样:
p X ( x ) = p Z ( f ( x ) )   ∣ ⁡det⁡ ( ⁡ ∂ f ( x ) ∂ x ) ∣ p_X(x)=p_Z(f(x)) |\text{⁡det}⁡ (⁡\frac {∂f(x)}{∂x})| pX(x)=pZ(f(x)) ∣⁡det⁡(⁡∂x∂f(x))∣

1.2 RealNVP 的通道翻转

RealNVP 通过交替使用掩码耦合层 (masking coupling) 和按通道翻转 (reverse channels) 或固定置换,保证每个通道都能被多次变换。

2. GLOW 架构

2.1 ActNorm

ActNorm 是一种专为流模型设计的通道级归一化方法,于 GLOW 中首次提出。该层对输入激活 x x x 执行可学习仿射变换:
y = s ⊙ x + b y=s⊙x+b y=s⊙x+b

其中 s , b ∈ R C s,b∈\mathbb R^C s,b∈RC 分别为每个通道的尺度与偏移参数。这些参数在首次前向传播时通过对一个 minibatch 计算输出通道的均值 μ μ μ 和标准差 σ σ σ 进行数据依赖的初始化,使得初始化后的 y y y 满足 E [ y ] = 0 \mathbb E[y]=0 E[y]=0 和 V a r [ y ] = 1 Var[y]=1 Var[y]=1。

与批归一化不同,ActNorm 仅在初始化时依赖 minibatch,之后无需维护运行时统计量,从而提升了小批数据和解耦训练的稳定性。

python 复制代码
import torch.nn as nn
import torch

def mean_dim(tensor, dim=None, keepdims=False):
    if dim is None:
        return tensor.mean()
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.mean(dim=d, keepdim=True)
        if not keepdims:
            for i, d in enumerate(dim):
                tensor.squeeze_(d-i)
        return tensor

class ActNorm(nn.Module):
    def __init__(self, num_features, scale=1., return_ldj=False):
        super(ActNorm, self).__init__()
        self.register_buffer('is_initialized', torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.logs = nn.Parameter(torch.zeros(1, num_features, 1, 1))

        self.num_features = num_features
        self.scale = float(scale)
        self.eps = 1e-6
        self.return_ldj = return_ldj

    def initialize_parameters(self, x):
        if not self.training:
            return

        with torch.no_grad():
            bias = -mean_dim(x.clone(), dim=[0, 2, 3], keepdims=True)
            v = mean_dim((x.clone() + bias) ** 2, dim=[0, 2, 3], keepdims=True)
            logs = (self.scale / (v.sqrt() + self.eps)).log()
            self.bias.data.copy_(bias.data)
            self.logs.data.copy_(logs.data)
            self.is_initialized += 1.

    def _center(self, x, reverse=False):
        if reverse:
            return x - self.bias
        else:
            return x + self.bias

    def _scale(self, x, sldj, reverse=False):
        logs = self.logs
        if reverse:
            x = x * logs.mul(-1).exp()
        else:
            x = x * logs.exp()

        if sldj is not None:
            ldj = logs.sum() * x.size(2) * x.size(3)
            if reverse:
                sldj = sldj - ldj
            else:
                sldj = sldj + ldj

        return x, sldj

    def forward(self, x, ldj=None, reverse=False):
        if not self.is_initialized:
            self.initialize_parameters(x)

        if reverse:
            x, ldj = self._scale(x, ldj, reverse)
            x = self._center(x, reverse)
        else:
            x = self._center(x, reverse)
            x, ldj = self._scale(x, ldj, reverse)

        if self.return_ldj:
            return x, ldj

        return x

2.2 可逆 1×1 卷积

GLOW 中的可逆 1×1 卷积用一个 C × C C×C C×C 的可学习矩阵 W W W 取代了 RealNVP 中的固定通道翻转或置换操作。在空间位置 ( i , j ) (i,j) (i,j) 上,其映射可写为:
y i , j = W   x i , j y_{i,j}=W x_{i,j} yi,j=W xi,j

对应的对数行列式为:
log ⁡det⁡ ⁣ ∣ ∂ y ∂ x ∣ = H × W × log ∣ ⁡det W ∣ \text {log}\ \text{⁡det}⁡ ⁣|\frac{∂y}{∂x}|=H×W×\text {log}|\text{⁡det}W| log ⁡det⁡ ⁣∣∂x∂y∣=H×W×log∣⁡detW∣

其中 H , W H,W H,W 分别为空间高宽。

为了进一步加速行列式与逆矩阵的计算,通常将 W W W 参数化为 LU 分解形式,即 W = P L U W=PLU W=PLU,只需学习下三角矩阵 L L L 和上三角矩阵 U U U 的非对角元素,行列式则为 ∏ i U i i ∏iU{ii} ∏iUii。

通过这种可学习的通道重排,模型能够自动挖掘最优的特征混合方式,从而在生成质量与训练效率上均取得显著提升。

python 复制代码
import numpy as np

class InvConv(nn.Module):
    def __init__(self, num_channels):
        super(InvConv, self).__init__()
        self.num_channels = num_channels

        # Initialize with a random orthogonal matrix
        w_init = np.random.randn(num_channels, num_channels)
        w_init = np.linalg.qr(w_init)[0].astype(np.float32)
        self.weight = nn.Parameter(torch.from_numpy(w_init))

    def forward(self, x, sldj, reverse=False):
        ldj = torch.slogdet(self.weight)[1] * x.size(2) * x.size(3)

        if reverse:
            weight = torch.inverse(self.weight.double()).float()
            sldj = sldj - ldj
        else:
            weight = self.weight
            sldj = sldj + ldj

        weight = weight.view(self.num_channels, self.num_channels, 1, 1)
        z = F.conv2d(x, weight)

        return z, sldj

2.3 仿射耦合层

仿射耦合层最早在 RealNVP 中提出,是 GLOW 中不可或缺的组成部分。该层将输入 x ∈ R C × H × W x∈\mathbb R^{C×H×W} x∈RC×H×W 沿通道维度划分为两部分 ( x a , x b ) (x_a,x_b) (xa,xb),并通过神经网络生成尺度和平移参数保证了整个变换的可逆性:
( s , t ) = N N ( x b ) , y a = s ( x b ) ⊙ x a + t ( x b ) , y b = x b (s,t)=NN(x_b),ya=s(x_b)⊙x_a+t(x_b),y_b=x_b (s,t)=NN(xb),ya=s(xb)⊙xa+t(xb),yb=xb

其对数雅可比行列式可高效地计算为:
∑ h , w ,    c ∈ a log⁡ s c ( x b [ h , w ] ) ∑_{h,w,  c∈a}\text {log}⁡s_c(x_b[h,w]) h,w,  c∈a∑log⁡sc(xb[h,w])

仅与输出尺度参数 s s s 的元素相加相关,计算复杂度随输入维度线性增长。

python 复制代码
import torch.nn.functional as F

class Coupling(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(Coupling, self).__init__()
        self.nn = NN(in_channels, mid_channels, 2 * in_channels)
        self.scale = nn.Parameter(torch.ones(in_channels, 1, 1))

    def forward(self, x, ldj, reverse=False):
        x_change, x_id = x.chunk(2, dim=1)

        st = self.nn(x_id)
        s, t = st[:, 0::2, ...], st[:, 1::2, ...]
        s = self.scale * torch.tanh(s)

        # Scale and translate
        if reverse:
            x_change = x_change * s.mul(-1).exp() - t
            ldj = ldj - s.flatten(1).sum(-1)
        else:
            x_change = (x_change + t) * s.exp()
            ldj = ldj + s.flatten(1).sum(-1)

        x = torch.cat((x_change, x_id), dim=1)

        return x, ldj


class NN(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels,
                 use_act_norm=False):
        super(NN, self).__init__()
        norm_fn = ActNorm if use_act_norm else nn.BatchNorm2d

        self.in_norm = norm_fn(in_channels)
        self.in_conv = nn.Conv2d(in_channels, mid_channels,
                                 kernel_size=3, padding=1, bias=False)
        nn.init.normal_(self.in_conv.weight, 0., 0.05)

        self.mid_norm = norm_fn(mid_channels)
        self.mid_conv = nn.Conv2d(mid_channels, mid_channels,
                                  kernel_size=1, padding=0, bias=False)
        nn.init.normal_(self.mid_conv.weight, 0., 0.05)

        self.out_norm = norm_fn(mid_channels)
        self.out_conv = nn.Conv2d(mid_channels, out_channels,
                                  kernel_size=3, padding=1, bias=True)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x):
        x = self.in_norm(x)
        x = F.relu(x)
        x = self.in_conv(x)

        x = self.mid_norm(x)
        x = F.relu(x)
        x = self.mid_conv(x)

        x = self.out_norm(x)
        x = F.relu(x)
        x = self.out_conv(x)

        return x

2.4 多尺度架构

GLOW 延续了 RealNVP 的多尺度架构思想,通过分层的流步骤和因子化操作将中间表示逐级分解。整体模型由 L L L 个尺度 (level) 组成,每个尺度内部包含 K K K 次完整的流步骤 (step),每步依次执行 ActNorm、可逆 1×1 卷积和仿射耦合层。

在每个尺度结束时,先通过 squeeze 操作将特征图空间大小减少至原像素的四分之一(同时通道数扩大四倍),然后使用 split 操作将部分通道因子化为潜变量 z z z,余下通道继续进入下一级流。

这种多尺度分解在保持对数似然精度的同时,有效降低了计算与存储开销,并在不同尺度上捕捉图像的全局与局部结构信息。

3. 使用 PyTorch 实现 GLOW

在本节中,使用 PyTorch 实现 GLOW,并在 CIFAR-10 数据集上进行训练。

3.1 数据处理

torchvision.transformsCIFAR-10 图像进行处理:

python 复制代码
transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

transform_test = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
trainloader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

testset = torchvision.datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
testloader = data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

3.2 模型构建

基于 ActNorm、可逆 1×1 卷积、仿射耦合层和多尺度架构实现 GLOW 模型:

python 复制代码
class Glow(nn.Module):
    def __init__(self, num_channels, num_levels, num_steps):
        super(Glow, self).__init__()

        # Use bounds to rescale images before converting to logits, not learned
        self.register_buffer('bounds', torch.tensor([0.9], dtype=torch.float32))
        self.flows = _Glow(in_channels=4 * 3,  # RGB image after squeeze
                           mid_channels=num_channels,
                           num_levels=num_levels,
                           num_steps=num_steps)

    def forward(self, x, reverse=False):
        if reverse:
            sldj = torch.zeros(x.size(0), device=x.device)
        else:
            # Expect inputs in [0, 1]
            if x.min() < 0 or x.max() > 1:
                raise ValueError('Expected x in [0, 1], got min/max {}/{}'
                                 .format(x.min(), x.max()))

            # De-quantize and convert to logits
            x, sldj = self._pre_process(x)

        x = squeeze(x)
        x, sldj = self.flows(x, sldj, reverse)
        x = squeeze(x, reverse=True)

        return x, sldj

    def _pre_process(self, x):
        y = (x * 255. + torch.rand_like(x)) / 256.
        y = (2 * y - 1) * self.bounds
        y = (y + 1) / 2
        y = y.log() - (1. - y).log()

        # Save log-determinant of Jacobian of initial transform
        ldj = F.softplus(y) + F.softplus(-y) \
            - F.softplus((1. - self.bounds).log() - self.bounds.log())
        sldj = ldj.flatten(1).sum(-1)

        return y, sldj

class _Glow(nn.Module):
    def __init__(self, in_channels, mid_channels, num_levels, num_steps):
        super(_Glow, self).__init__()
        self.steps = nn.ModuleList([_FlowStep(in_channels=in_channels,
                                              mid_channels=mid_channels)
                                    for _ in range(num_steps)])

        if num_levels > 1:
            self.next = _Glow(in_channels=2 * in_channels,
                              mid_channels=mid_channels,
                              num_levels=num_levels - 1,
                              num_steps=num_steps)
        else:
            self.next = None

    def forward(self, x, sldj, reverse=False):
        if not reverse:
            for step in self.steps:
                x, sldj = step(x, sldj, reverse)

        if self.next is not None:
            x = squeeze(x)
            x, x_split = x.chunk(2, dim=1)
            x, sldj = self.next(x, sldj, reverse)
            x = torch.cat((x, x_split), dim=1)
            x = squeeze(x, reverse=True)

        if reverse:
            for step in reversed(self.steps):
                x, sldj = step(x, sldj, reverse)

        return x, sldj


class _FlowStep(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(_FlowStep, self).__init__()
        # Activation normalization, invertible 1x1 convolution, affine coupling
        self.norm = ActNorm(in_channels, return_ldj=True)
        self.conv = InvConv(in_channels)
        self.coup = Coupling(in_channels // 2, mid_channels)

    def forward(self, x, sldj=None, reverse=False):
        if reverse:
            x, sldj = self.coup(x, sldj, reverse)
            x, sldj = self.conv(x, sldj, reverse)
            x, sldj = self.norm(x, sldj, reverse)
        else:
            x, sldj = self.norm(x, sldj, reverse)
            x, sldj = self.conv(x, sldj, reverse)
            x, sldj = self.coup(x, sldj, reverse)

        return x, sldj


def squeeze(x, reverse=False):
    b, c, h, w = x.size()
    if reverse:
        # Unsqueeze
        x = x.view(b, c // 4, 2, 2, h, w)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        x = x.view(b, c // 4, h * 2, w * 2)
    else:
        # Squeeze
        x = x.view(b, c, h // 2, 2, w // 2, 2)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(b, c * 2 * 2, h // 2, w // 2)

    return x

3.3 模型训练

实例化模型、损失函数和优化器,并进行训练:

python 复制代码
net = Glow(num_channels=num_channels,
            num_levels=num_levels,
            num_steps=num_steps)
net = net.to(device)

loss_fn = NLLLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=lr)
scheduler = sched.LambdaLR(optimizer, lambda s: min(1., s / warm_up))

@torch.enable_grad()
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn, max_grad_norm):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)

            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)

@torch.no_grad()
def sample(net, batch_size, device):
    z = torch.randn((batch_size, 3, 32, 32), dtype=torch.float32, device=device)
    x, _ = net(z, reverse=True)
    x = torch.sigmoid(x)

    return x

@torch.no_grad()
def test(epoch, net, testloader, device, loss_fn, num_samples):
    global best_loss
    net.eval()
    loss_meter = AverageMeter()
    with tqdm(total=len(testloader.dataset)) as progress_bar:
        for x, _ in testloader:
            x = x.to(device)
            z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))

    # Save checkpoint
    if loss_meter.avg < best_loss:
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'test_loss': loss_meter.avg,
            'epoch': epoch,
        }
        os.makedirs('ckpts', exist_ok=True)
        torch.save(state, 'ckpts/best.pth.tar')
        best_loss = loss_meter.avg

    # Save samples and data
    images = sample(net, num_samples, device)
    os.makedirs('samples', exist_ok=True)
    images_concat = torchvision.utils.make_grid(images, nrow=int(num_samples ** 0.5), padding=2, pad_value=255)
    torchvision.utils.save_image(images_concat, 'samples/epoch_{}.png'.format(epoch))

start_epoch = 0
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch, net, trainloader, device, optimizer, scheduler,
            loss_fn, max_grad_norm)
    test(epoch, net, testloader, device, loss_fn, num_samples)

100epoch 后,模型可生成逼真度较高的 32 × 32 彩色图像,样本在多通道细节和整体结构上均有良好效果,下图展示了训练过程中,不同 epoch 生成的图像对比:

相关推荐
MUTA️4 分钟前
视觉语言模型在视觉任务上的研究综述
人工智能·深度学习·语言模型·多模态
文火冰糖的硅基工坊1 小时前
[人工智能-综述-17]:AI革命:重塑职业版图,开启文明新篇
人工智能·深度学习·神经网络·架构·信号处理·跨学科融合
CoovallyAIHub1 小时前
数据集分享 | 稻田识别分割数据集、水稻虫害数据集
深度学习·算法·计算机视觉
金井PRATHAMA1 小时前
分布内侧内嗅皮层的层Ⅱ或层Ⅲ的网格细胞(grid cells)对NLP中的深层语义分析的积极影响和启示
人工智能·深度学习·神经网络·机器学习·自然语言处理·知识图谱
盼小辉丶2 小时前
TensorFlow深度学习实战——DeepDream
人工智能·深度学习·tensorflow
FF-Studio3 小时前
25年电赛C题 发挥部分 YOLOv8方案&数据集
python·深度学习·yolo
EdisonZhou3 小时前
多Agent协作入门:移交编排模式
llm·aigc·.net core
Blossom.1187 小时前
基于深度学习的图像分割:使用DeepLabv3实现高效分割
人工智能·python·深度学习·机器学习·分类·机器人·transformer
zzywxc78712 小时前
AI 驱动的软件测试革新:框架、检测与优化实践
人工智能·深度学习·机器学习·数据挖掘·数据分析
墨风如雪12 小时前
别再迷信闭源模型,你桌面的AI推理之王已经诞生
aigc