相关项目下载链接
基于 Transformer 的自回归图像生成模型完整的链路是:1、先用 Patch AutoEncoder + BSQ 量化器,把原始图像压缩为离散的 token 序列(每个整数 token 对应原图的一个小图像 patch)2、训练这个自回归 Transformer 模型,学习 token 之间的空间共现规律;3、通过generate方法生成全新的token序列;4、用 BSQ 量化器把 token 序列解码回可保存的 png 图片。
本节内容主要介绍如何通过generate方法生成全新的 token 序列。
定义主模型
主模型对应的代码在autoregressive.py,在上一节中我们并没有定义generation方法的具体实现,本节对其逻辑进行补全。为了兼容补全后的generation方法,还需要对前向传播算法进行维度匹配调整。
补全generation方法
python
@torch.no_grad()
def generate(self, B: int = 1, h: int = 20, w: int = 30, device=None) -> torch.Tensor:
if device is None:
device = self.embedding.weight.device
gen_seq = torch.zeros((B, h, w), dtype=torch.long, device=device)
total_len = h * w
for k in range(total_len):
# 把 1D 索引 k 转回 2D 坐标 (i,j)
i = k // w # 行号
j = k % w # 列号
logits, _ = self.forward(gen_seq)
next_token_logits = logits[:, i, j, :] / 0.9
next_token = torch.multinomial(
F.softmax(next_token_logits, dim=-1),
num_samples=1
).squeeze(1)
gen_seq[:, i, j] = next_token
return gen_seq
调整前向传播算法
python
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
# 对训练和推理进行维度匹配
if x.dim() == 4:
x = x.squeeze(1)
B, h, w = x.shape
L = h * w
# 展平成序列
x_flat = x.reshape(B, L)
# 嵌入 + 位置编码
token_emb = self.embedding(x_flat)
pos_idx = torch.arange(L, device=x.device)
pos_emb = self.pos_emb(pos_idx)
x_emb = token_emb + pos_emb
# 自回归右移(关键)
x_emb = F.pad(x_emb, (0,0,1,0))[:, :-1]
# 因果掩码
mask = self._generate_causal_mask(L, x.device)
trans_out = self.transformer(x_emb, mask=mask)
# 输出
logits = self.fc_out(trans_out)
logits_2d = logits.reshape(B, h, w, self.n_tokens)
return logits_2d, {}
模块测评
下面进行图像生成的功能测试:
python
mkdir test
python -m homework.generation checkpoints/BSQPatchAutoEncoder.pth checkpoints/AutoregressiveModel.pth 8 test
所得的解码后的PNG图片如下所示:
|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
|
|
|
|
|
|
|
|
|
|
将代码打包为压缩文件
python
python bundle.py homework 20260412
进行评分自测:
python
python -m grader 20260412.zip
最终测试得分如下:

可选的优化方向
更优的量化器(更小的图像块、更高的码率)
- 缩小 patch 尺寸:你当前patch_size=5,可改为 3或2。更小的图像块意味着更细的图像粒度,大幅减少单 patch 的信息损失,生成的图像细节更丰富、块效应更少。
- 提升码本码率:你当前codebook_bits=10(仅 1024 个码本),可提升到 12或14。码本容量越大,量化的精度越高,单个 token 能表达的图像信息越丰富,生成的画面连贯性更强。
- 辅助优化:提升 Patch AutoEncoder 的重建能力(比如增加卷积层、调整 latent_dim),降低量化器的基础重建 MSE,从根源上提升 token 的质量。
更大的 Transformer 模型参数量
- 增加 Transformer 深度:把 Encoder 层数从 2 层提升到 4/6 层,更深的网络能拟合更复杂的 token 序列分布。
- 提升隐层维度:把d_latent从 128 提升到 256/512(注意nhead必须能整除d_latent),更高的维度能承载更丰富的图像语义信息。
更优的训练策略
- 增加训练轮次:可提升到 10/20/50 轮,配合学习率衰减策略,让模型充分学习 patch 的空间分布和长距离依赖关系。
- 优化学习率策略:在 AdamW 优化器中加入「warmup 预热 + 余弦退火衰减」,避免训练初期梯度爆炸,同时让模型在训练后期更精细地拟合分布,大幅提升生成效果。