图像自回归生成(Auto-regressive image generation)实战学习(六)

相关项目下载链接

基于 Transformer 的自回归图像生成模型完整的链路是:1、先用 Patch AutoEncoder + BSQ 量化器,把原始图像压缩为离散的 token 序列(每个整数 token 对应原图的一个小图像 patch)2、训练这个自回归 Transformer 模型,学习 token 之间的空间共现规律;3、通过generate方法生成全新的token序列;4、用 BSQ 量化器把 token 序列解码回可保存的 png 图片。

本节内容主要介绍如何通过generate方法生成全新的 token 序列。

定义主模型

主模型对应的代码在autoregressive.py,在上一节中我们并没有定义generation方法的具体实现,本节对其逻辑进行补全。为了兼容补全后的generation方法,还需要对前向传播算法进行维度匹配调整。

补全generation方法

python 复制代码
 @torch.no_grad()
    def generate(self, B: int = 1, h: int = 20, w: int = 30, device=None) -> torch.Tensor:
        if device is None:
            device = self.embedding.weight.device

        gen_seq = torch.zeros((B, h, w), dtype=torch.long, device=device)
        total_len = h * w

        for k in range(total_len):
            # 把 1D 索引 k 转回 2D 坐标 (i,j)
            i = k // w  # 行号
            j = k % w   # 列号

            logits, _ = self.forward(gen_seq)

            next_token_logits = logits[:, i, j, :] / 0.9
            
            next_token = torch.multinomial(
                F.softmax(next_token_logits, dim=-1), 
                num_samples=1
            ).squeeze(1)

            gen_seq[:, i, j] = next_token

        return gen_seq

调整前向传播算法

python 复制代码
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    	# 对训练和推理进行维度匹配
        if x.dim() == 4:
            x = x.squeeze(1)
        B, h, w = x.shape
        L = h * w

        # 展平成序列
        x_flat = x.reshape(B, L)

        # 嵌入 + 位置编码
        token_emb = self.embedding(x_flat)
        pos_idx = torch.arange(L, device=x.device)
        pos_emb = self.pos_emb(pos_idx)
        x_emb = token_emb + pos_emb

        # 自回归右移(关键)
        x_emb = F.pad(x_emb, (0,0,1,0))[:, :-1]

        # 因果掩码
        mask = self._generate_causal_mask(L, x.device)
        trans_out = self.transformer(x_emb, mask=mask)

        # 输出
        logits = self.fc_out(trans_out)
        logits_2d = logits.reshape(B, h, w, self.n_tokens)

        return logits_2d, {}

模块测评

下面进行图像生成的功能测试:

python 复制代码
mkdir test
python -m homework.generation checkpoints/BSQPatchAutoEncoder.pth checkpoints/AutoregressiveModel.pth 8 test 

所得的解码后的PNG图片如下所示:

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| | | | |
| | | | |

将代码打包为压缩文件

python 复制代码
python bundle.py homework 20260412

进行评分自测:

python 复制代码
python -m grader 20260412.zip

最终测试得分如下:

可选的优化方向

更优的量化器(更小的图像块、更高的码率)

  1. 缩小 patch 尺寸:你当前patch_size=5,可改为 3或2。更小的图像块意味着更细的图像粒度,大幅减少单 patch 的信息损失,生成的图像细节更丰富、块效应更少。
  2. 提升码本码率:你当前codebook_bits=10(仅 1024 个码本),可提升到 12或14。码本容量越大,量化的精度越高,单个 token 能表达的图像信息越丰富,生成的画面连贯性更强。
  3. 辅助优化:提升 Patch AutoEncoder 的重建能力(比如增加卷积层、调整 latent_dim),降低量化器的基础重建 MSE,从根源上提升 token 的质量。

更大的 Transformer 模型参数量

  1. 增加 Transformer 深度:把 Encoder 层数从 2 层提升到 4/6 层,更深的网络能拟合更复杂的 token 序列分布。
  2. 提升隐层维度:把d_latent从 128 提升到 256/512(注意nhead必须能整除d_latent),更高的维度能承载更丰富的图像语义信息。

更优的训练策略

  1. 增加训练轮次:可提升到 10/20/50 轮,配合学习率衰减策略,让模型充分学习 patch 的空间分布和长距离依赖关系。
  2. 优化学习率策略:在 AdamW 优化器中加入「warmup 预热 + 余弦退火衰减」,避免训练初期梯度爆炸,同时让模型在训练后期更精细地拟合分布,大幅提升生成效果。
相关推荐
努力努力再努力FFF2 小时前
医生对AI辅助诊断感兴趣,作为临床人员该怎么了解和学习?
人工智能·学习
sakiko_3 小时前
UIKit学习笔记5-使用UITableView制作聊天页面
笔记·学习·swift·uikit
AI科技星3 小时前
全域数学·72分册:场计算机卷【乖乖数学】
算法·机器学习·数学建模·数据挖掘·量子计算
Alice-YUE4 小时前
【js高频八股】防抖与节流
开发语言·前端·javascript·笔记·学习·ecmascript
北山有鸟5 小时前
修改源码法和插件法
嵌入式硬件·学习
richxu202510015 小时前
嵌入式学习之路->stm32篇->(14)通用定时器(上)
stm32·单片机·嵌入式硬件·学习
qeen875 小时前
【数据结构】建堆的时间复杂度讨论与TOP-K问题
c语言·数据结构·c++·学习·
lizhihai_996 小时前
股市学习心得-六张分时保命图
大数据·人工智能·学习
AI科技星6 小时前
全域数学·数术本源·高维代数卷(72分册)【乖乖数学】
人工智能·算法·数学建模·数据挖掘·量子计算
nashane6 小时前
HarmonyOS 6学习:应用签名文件丢失处理与更新完全指南
学习·华为·harmonyos·harmonyos 5