图像自回归生成(Auto-regressive image generation)实战学习(六)

相关项目下载链接

基于 Transformer 的自回归图像生成模型完整的链路是:1、先用 Patch AutoEncoder + BSQ 量化器,把原始图像压缩为离散的 token 序列(每个整数 token 对应原图的一个小图像 patch)2、训练这个自回归 Transformer 模型,学习 token 之间的空间共现规律;3、通过generate方法生成全新的token序列;4、用 BSQ 量化器把 token 序列解码回可保存的 png 图片。

本节内容主要介绍如何通过generate方法生成全新的 token 序列。

定义主模型

主模型对应的代码在autoregressive.py,在上一节中我们并没有定义generation方法的具体实现,本节对其逻辑进行补全。为了兼容补全后的generation方法,还需要对前向传播算法进行维度匹配调整。

补全generation方法

python 复制代码
 @torch.no_grad()
    def generate(self, B: int = 1, h: int = 20, w: int = 30, device=None) -> torch.Tensor:
        if device is None:
            device = self.embedding.weight.device

        gen_seq = torch.zeros((B, h, w), dtype=torch.long, device=device)
        total_len = h * w

        for k in range(total_len):
            # 把 1D 索引 k 转回 2D 坐标 (i,j)
            i = k // w  # 行号
            j = k % w   # 列号

            logits, _ = self.forward(gen_seq)

            next_token_logits = logits[:, i, j, :] / 0.9
            
            next_token = torch.multinomial(
                F.softmax(next_token_logits, dim=-1), 
                num_samples=1
            ).squeeze(1)

            gen_seq[:, i, j] = next_token

        return gen_seq

调整前向传播算法

python 复制代码
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    	# 对训练和推理进行维度匹配
        if x.dim() == 4:
            x = x.squeeze(1)
        B, h, w = x.shape
        L = h * w

        # 展平成序列
        x_flat = x.reshape(B, L)

        # 嵌入 + 位置编码
        token_emb = self.embedding(x_flat)
        pos_idx = torch.arange(L, device=x.device)
        pos_emb = self.pos_emb(pos_idx)
        x_emb = token_emb + pos_emb

        # 自回归右移(关键)
        x_emb = F.pad(x_emb, (0,0,1,0))[:, :-1]

        # 因果掩码
        mask = self._generate_causal_mask(L, x.device)
        trans_out = self.transformer(x_emb, mask=mask)

        # 输出
        logits = self.fc_out(trans_out)
        logits_2d = logits.reshape(B, h, w, self.n_tokens)

        return logits_2d, {}

模块测评

下面进行图像生成的功能测试:

python 复制代码
mkdir test
python -m homework.generation checkpoints/BSQPatchAutoEncoder.pth checkpoints/AutoregressiveModel.pth 8 test 

所得的解码后的PNG图片如下所示:

|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
| | | | |
| | | | |

将代码打包为压缩文件

python 复制代码
python bundle.py homework 20260412

进行评分自测:

python 复制代码
python -m grader 20260412.zip

最终测试得分如下:

可选的优化方向

更优的量化器(更小的图像块、更高的码率)

  1. 缩小 patch 尺寸:你当前patch_size=5,可改为 3或2。更小的图像块意味着更细的图像粒度,大幅减少单 patch 的信息损失,生成的图像细节更丰富、块效应更少。
  2. 提升码本码率:你当前codebook_bits=10(仅 1024 个码本),可提升到 12或14。码本容量越大,量化的精度越高,单个 token 能表达的图像信息越丰富,生成的画面连贯性更强。
  3. 辅助优化:提升 Patch AutoEncoder 的重建能力(比如增加卷积层、调整 latent_dim),降低量化器的基础重建 MSE,从根源上提升 token 的质量。

更大的 Transformer 模型参数量

  1. 增加 Transformer 深度:把 Encoder 层数从 2 层提升到 4/6 层,更深的网络能拟合更复杂的 token 序列分布。
  2. 提升隐层维度:把d_latent从 128 提升到 256/512(注意nhead必须能整除d_latent),更高的维度能承载更丰富的图像语义信息。

更优的训练策略

  1. 增加训练轮次:可提升到 10/20/50 轮,配合学习率衰减策略,让模型充分学习 patch 的空间分布和长距离依赖关系。
  2. 优化学习率策略:在 AdamW 优化器中加入「warmup 预热 + 余弦退火衰减」,避免训练初期梯度爆炸,同时让模型在训练后期更精细地拟合分布,大幅提升生成效果。
相关推荐
GIS数据转换器2 小时前
延凡低成本低空无人机AI巡检方案
大数据·人工智能·信息可视化·数据挖掘·无人机
weixin_443478512 小时前
Flutter组件学习之图表
学习·flutter·信息可视化
倦王3 小时前
大模型学习2
学习
徒 花3 小时前
HCIP学习05 链路聚合(Eth-Trunk)+ VRRP
服务器·网络·学习·hcip
黑金IT3 小时前
AI Agent “小龙虾终极进化”——自主学习与持久化记忆的架构实现
人工智能·学习·架构
weixin_395772473 小时前
计算机网络学习笔记】初始网络之网络发展和OSI七层模型
笔记·学习·计算机网络
Omics Pro3 小时前
上海AI Lab+复旦大学:双轨协同实现自动化虚拟细胞建模
运维·人工智能·语言模型·自然语言处理·数据挖掘·数据分析·自动化
南境十里·墨染春水3 小时前
linux学习进展 进程的内存管理
linux·服务器·学习
小陈phd3 小时前
多模态大模型学习笔记(三十四)——ChatTTS:新一代中文语音合成工具原理与实战解析
笔记·学习·语音识别