超越扩散模型,图像生成新方法

✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:传知代码论文复现

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

图像生成模型概述

VAE

扩散模型

VAR

演示效果

VAR模型原理

自回归学习的重新定义

多尺度生成策略

Transformer架构的应用

幂律缩放定律

下游任务的泛化能力

VAR模型效果

核心逻辑

模型加载

label标签更改

部署方式

安装相关依赖

下载权重

使用步骤

模型训练

模型预测


本文所有资源均可在该地址处获取。

图像生成模型概述

在图像生成领域,研究者们已经开发了多种类型的生成模型,每种都有其独特的优势和应用场景。这些模型可以大致分为三大类:变分自编码器(VAE)、扩散模型(Diffusion Models)以及最近提出的Visual AutoRegressive (VAR) 模型。下面将逐一介绍这三类模型的基本概念、特点及其在实际应用中的表现。

VAE

变分自编码器(Variational AutoEncoder, VAE)是一种基于概率图模型的生成模型,它由编码器和解码器两部分组成。编码器负责将输入数据映射到一个潜在空间中,而解码器则尝试从这个潜在空间中重建原始输入。与传统的自编码器不同,VAE引入了贝叶斯推断的思想,使得它可以学习到数据的概率分布,并能够根据这个分布来生成新的样本。VAE的一个重要特点是它可以在没有标签的情况下进行无监督学习,但它的生成质量通常不如其他一些更先进的模型,尤其是在处理高分辨率图像时。

扩散模型

扩散模型是一类相对较新的生成模型,它们通过一系列小的噪声添加步骤将随机噪声逐渐转换为真实的图像。这个过程可以看作是"反向"热力学过程,其中初始的完全随机状态(即纯噪声)被逐步引导至目标分布。扩散模型的关键在于设计有效的去噪步骤,以确保最终生成的图像是逼真的。这类模型近年来受到了广泛关注,因为它们能够在多个基准测试上取得优异的成绩,并且生成的图像具有较高的质量和多样性。然而,扩散模型的一个主要缺点是推理速度较慢,因为需要执行多次迭代才能完成一张图片的生成。

VAR

Visual AutoRegressive (VAR) 模型是图像生成领域的最新进展之一,它采用了不同于传统AR模型的新方法来进行视觉自回归学习。VAR模型专注于从粗到细的"下一尺度预测",而不是常见的光栅扫描"下一个token预测"。这种方法不仅提高了图像生成的质量,还显著加快了推理速度。VAR模型利用了改进的多尺度量化自动编码器(VQVAE)对图像进行编码,并结合了GPT-2风格的Transformer架构作为解码器,从而实现了高效且高质量的图像生成。此外,VAR展示了与大语言模型相似的幂律缩放定律,表明其具备良好的可扩展性,并在多个下游任务中表现出强大的泛化能力。

演示效果

VAR模型原理

自回归学习的重新定义

VAR模型的核心思想是对自回归学习进行了重新定义,从传统的逐像素或逐块预测转向了"下一尺度预测"。这意味着模型不是简单地预测下一个像素或token,而是预测下一分辨率级别的图像内容。这种转变带来了几个重要的优势:

保留空间结构:通过多尺度生成,VAR能够更好地保持图像的空间结构信息,避免了由于逐像素生成可能导致的细节丢失。

降低计算复杂度:相比逐像素预测,下一尺度预测大大减少了需要预测的元素数量,从而降低了整体计算复杂度。

提高生成效率:由于减少了预测步骤,VAR在生成同样大小的图像时所需的计算资源更少,因此推理速度更快。

多尺度生成策略

VAR采用了一种多尺度生成策略,该策略允许模型从低分辨率开始,逐步增加图像的分辨率,直到达到所需的最终尺寸。这一过程涉及以下几个关键步骤:

初始编码:使用改进的多尺度量化自动编码器(VQVAE)对输入图像进行编码,得到一个低分辨率的潜在表示。

逐步解码:然后,模型使用GPT-2风格的Transformer架构作为解码器,逐步解码并生成更高分辨率的图像。每次迭代中,模型都会预测当前尺度下的图像内容,并将其与之前生成的内容相结合,以形成更高一级的表示。

融合机制:为了确保不同尺度之间的平滑过渡,VAR引入了一种特殊的融合机制,可以在不同的尺度之间传递信息,保证图像的一致性和连贯性。

Transformer架构的应用

VAR模型选择了Transformer架构作为其核心组件,主要是因为它在处理长序列数据方面的能力。具体来说,Transformer的自注意力机制可以帮助模型捕捉图像中远距离像素之间的依赖关系,这对于生成高质量的图像至关重要。此外,Transformer架构还支持并行计算,进一步提高了模型的训练和推理效率。

幂律缩放定律

研究人员发现,VAR模型遵循一种类似于大语言模型的幂律缩放定律。这意味着随着模型参数量的增加,其性能会按照一定的比例提升。这种特性使得VAR模型在扩大规模时能够保持良好的性能增长,同时也为未来的模型优化提供了理论依据。

下游任务的泛化能力

除了基本的图像生成任务,VAR模型还在多个下游任务中展现了出色的零样本泛化能力。例如,在图像修复、修补和编辑等任务中,VAR能够根据上下文信息准确地填补缺失的部分或修改特定区域,而无需额外的训练。这得益于其强大的表征学习能力和灵活的生成机制,使得VAR成为了一个多功能的图像生成工具。

VAR模型效果

越靠左表示生成的时间越快,越靠下表示生成的FID越低,生成效果越好。

核心逻辑

模型加载

import torch, torchvision
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')

label标签更改

部署方式

安装相关依赖

pip install -r requirements.txt

下载权重

https://hf-mirror.com/FoundationVision/var/tree/main

使用步骤

模型训练

修改类别个数

python train.py --data_path='' --final_reso=image_size

模型预测

python predict.py

​​

相关推荐
掐死你滴温柔3 小时前
SQLALchemy如何将SQL语句编译为特定数据库方言
数据结构·数据库·python·sql
西猫雷婶3 小时前
python学opencv|读取图像(二十三)使用cv2.putText()绘制文字
开发语言·python·opencv
三掌柜6664 小时前
2025三掌柜赠书活动第一期:动手学深度学习(PyTorch版)
人工智能·pytorch·深度学习
唯创知音5 小时前
基于W2605C语音识别合成芯片的智能语音交互闹钟方案-AI对话享受智能生活
人工智能·单片机·物联网·生活·智能家居·语音识别
说私域5 小时前
数字化供应链创新解决方案在零售行业的应用研究——以开源AI智能名片S2B2C商城小程序为例
人工智能·开源·零售
yvestine6 小时前
数据挖掘——支持向量机分类器
人工智能·算法·机器学习·支持向量机·分类·数据挖掘·svm
阿正的梦工坊6 小时前
PyTorch到C++再到 CUDA 的调用链(C++ ATen 层) :以torch._amp_update_scale_调用为例
c++·人工智能·pytorch
三万棵雪松6 小时前
5.系统学习-PyTorch与多层感知机
人工智能·pytorch·学习
AIGC大时代6 小时前
不只是工具:ChatGPT写作在学术中的创新思维与深度思考
人工智能·chatgpt·prompt·aigc·ai写作
陈序缘7 小时前
PyTorch快速入门
人工智能·pytorch·python·深度学习·算法·机器学习