Auto-regressive Image Generation in Practice (Part 3)


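This part, based on the ae.py file, implements a patch-level autoencoder (PatchAutoEncoder). Its core job is to split an image into fixed-size patches and reconstruct the image through an encode-decode pipeline. The implementation proceeds as follows: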

Import dependencies

```python
import abc
import torch
```

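Define the model-loading function

```python
def load() -> torch.nn.Module:
    from pathlib import Path

    model_name = "PatchAutoEncoder"
    model_path = Path(__file__).parent / f"{model_name}.pth"
    print(f"Loading {model_name} from {model_path}")
    # The .pth file holds the whole pickled nn.Module (not just a state_dict),
    # which is why weights_only=False is needed to unpickle it.
    return torch.load(model_path, weights_only=False)
```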
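
Since load() returns the result of torch.load directly as a torch.nn.Module, the checkpoint presumably stores the entire pickled module rather than a state_dict. The sketch below contrasts the two saving styles with a hypothetical stand-in model written to a temporary directory; it is not the homework's own checkpoint-writing code, which is not shown here.

```python
import tempfile
from pathlib import Path

import torch

# Hypothetical stand-in model; any nn.Module behaves the same way
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.GELU(), torch.nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp:
    full_path = Path(tmp) / "full_model.pth"
    state_path = Path(tmp) / "state_dict.pth"

    # Option 1: pickle the entire module (architecture + weights).
    # Loading it unpickles arbitrary classes, hence weights_only=False.
    torch.save(model, full_path)
    restored = torch.load(full_path, weights_only=False)

    # Option 2: save only the tensors; the architecture must be rebuilt in code.
    torch.save(model.state_dict(), state_path)
    rebuilt = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.GELU(), torch.nn.Linear(8, 2))
    rebuilt.load_state_dict(torch.load(state_path, weights_only=True))

x = torch.randn(3, 4)
with torch.no_grad():
    same_full = torch.equal(restored(x), model(x))
    same_state = torch.equal(rebuilt(x), model(x))
print(same_full, same_state)
```

Pickling the whole module keeps everything in one file, but unpickling can execute arbitrary code and requires the model's classes to be importable at load time, so weights_only=False should only be used on trusted files.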

Define channel-conversion functions that switch between the channel-last (HWC) and channel-first (CHW) layouts

```python
def hwc_to_chw(x: torch.Tensor) -> torch.Tensor:
    """
    Convert an arbitrary tensor from (H, W, C) to (C, H, W) format.
    This allows us to switch from transformer-style channel-last to pytorch-style channel-first
    images. Works with or without the batch dimension.
    """
    dims = list(range(x.dim()))
    dims = dims[:-3] + [dims[-1]] + [dims[-3]] + [dims[-2]]
    return x.permute(*dims)

def chw_to_hwc(x: torch.Tensor) -> torch.Tensor:
    """
    The opposite of hwc_to_chw. Works with or without the batch dimension.
    """
    dims = list(range(x.dim()))
    dims = dims[:-3] + [dims[-2]] + [dims[-1]] + [dims[-3]]
    return x.permute(*dims)
```
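
Both helpers only move the channel axis between the last and the third-from-last position, leaving any leading (batch) dimensions untouched. A quick check, using an arbitrary 100 × 150 image size chosen for illustration, confirms the shapes with and without a batch dimension and shows that the same permutation can be written with torch.movedim. The two functions are repeated so the snippet runs on its own.

```python
import torch


def hwc_to_chw(x: torch.Tensor) -> torch.Tensor:
    # Same permutation as above: (..., H, W, C) -> (..., C, H, W)
    dims = list(range(x.dim()))
    return x.permute(*(dims[:-3] + [dims[-1], dims[-3], dims[-2]]))


def chw_to_hwc(x: torch.Tensor) -> torch.Tensor:
    # Same permutation as above: (..., C, H, W) -> (..., H, W, C)
    dims = list(range(x.dim()))
    return x.permute(*(dims[:-3] + [dims[-2], dims[-1], dims[-3]]))


img = torch.randn(100, 150, 3)        # (H, W, C), no batch dimension
batch = torch.randn(8, 100, 150, 3)   # (B, H, W, C)

print(hwc_to_chw(img).shape)    # torch.Size([3, 100, 150])
print(hwc_to_chw(batch).shape)  # torch.Size([8, 3, 100, 150])

# Equivalent one-liners: move the last axis to position -3, and back again
matches_movedim = torch.equal(hwc_to_chw(batch), batch.movedim(-1, -3))
round_trip = torch.equal(chw_to_hwc(hwc_to_chw(batch)), batch)
print(matches_movedim, round_trip)
```

permute (like movedim) returns a view, so no pixel data is copied; the result is simply non-contiguous in memory.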

Patchify modules

  • Patchify module (PatchifyLinear): takes an image tensor of shape (B, H, W, 3), splits it into non-overlapping patch_size × patch_size patches, and maps each patch to a latent_dim-dimensional embedding with a linear transformation (a Conv2d whose kernel size and stride both equal patch_size). The output is a patch-embedding tensor of shape (B, H//patch_size, W//patch_size, latent_dim).
  • Unpatchify module (UnpatchifyLinear): takes a patch-embedding tensor of shape (B, h, w, latent_dim) and maps it back to an image tensor of shape (B, h * patch_size, w * patch_size, 3). It performs the reverse mapping of PatchifyLinear, using a transposed convolution with the same kernel size and stride.

```python
# Patchify module
class PatchifyLinear(torch.nn.Module):
    """
    Takes an image tensor of the shape (B, H, W, 3) and patchifies it into
    an embedding tensor of the shape (B, H//patch_size, W//patch_size, latent_dim).
    It applies a linear transformation to each input patch.
    """
    def __init__(self, patch_size: int = 25, latent_dim: int = 128):
        super().__init__()
        self.patch_conv = torch.nn.Conv2d(3, latent_dim, patch_size, patch_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = hwc_to_chw(x)
        x = self.patch_conv(x)
        return chw_to_hwc(x)

# Unpatchify module
class UnpatchifyLinear(torch.nn.Module):
    """
    Takes an embedding tensor of the shape (B, h, w, latent_dim) and reconstructs
    an image tensor of the shape (B, h * patch_size, w * patch_size, 3).
    """
    def __init__(self, patch_size: int = 25, latent_dim: int = 128):
        super().__init__()
        self.unpatch_conv = torch.nn.ConvTranspose2d(latent_dim, 3, patch_size, patch_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = hwc_to_chw(x)
        x = self.unpatch_conv(x)
        return chw_to_hwc(x)
```
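
The claim that PatchifyLinear "applies a linear transformation to each input patch" can be checked numerically: a Conv2d whose kernel_size and stride both equal patch_size sees each non-overlapping patch exactly once, so it matches flattening every patch and multiplying it by the flattened kernel. The same holds for ConvTranspose2d in the opposite direction. This sketch assumes a 100 × 150 input (any size divisible by patch_size works) and works directly in channel-first layout.

```python
import torch

torch.manual_seed(0)
p, latent_dim = 25, 128
B, H, W = 2, 100, 150          # assumed image size, divisible by p
h, w = H // p, W // p          # 4 x 6 patch grid

patchify = torch.nn.Conv2d(3, latent_dim, p, p, bias=False)
unpatchify = torch.nn.ConvTranspose2d(latent_dim, 3, p, p, bias=False)

x = torch.randn(B, 3, H, W)    # channel-first image

# Patchify by hand: cut the image into (p x p) tiles and flatten each tile
tiles = x.reshape(B, 3, h, p, w, p).permute(0, 2, 4, 1, 3, 5)   # (B, h, w, 3, p, p)
tiles = tiles.reshape(B, h, w, 3 * p * p)
# Conv2d weight has shape (latent_dim, 3, p, p): one flattened row per output feature
z_manual = tiles @ patchify.weight.reshape(latent_dim, -1).T     # (B, h, w, latent_dim)
z_conv = patchify(x).permute(0, 2, 3, 1)                         # (B, h, w, latent_dim)
print(z_conv.shape)  # torch.Size([2, 4, 6, 128])

# Unpatchify by hand: each latent vector is mapped to one (3, p, p) tile.
# ConvTranspose2d weight has shape (latent_dim, 3, p, p)
tiles_out = z_conv @ unpatchify.weight.reshape(latent_dim, -1)   # (B, h, w, 3*p*p)
y_manual = tiles_out.reshape(B, h, w, 3, p, p).permute(0, 3, 1, 4, 2, 5).reshape(B, 3, H, W)
y_conv = unpatchify(z_conv.permute(0, 3, 1, 2))                  # (B, 3, H, W)

patch_ok = torch.allclose(z_conv, z_manual, atol=1e-4)
unpatch_ok = torch.allclose(y_conv, y_manual, atol=1e-4)
print(patch_ok, unpatch_ok)
```

This is why these modules can be called "Linear" even though they are built from convolutions: with stride equal to the kernel size, neighboring patches never interact. Interaction between patches only appears once layers such as the 3x3 convolutions in PatchAutoEncoder are added.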

Abstract base class PatchAutoEncoderBase

Defines the core abstract interface of the patch autoencoder: two abstract methods, encode and decode.

```python
class PatchAutoEncoderBase(abc.ABC):
    @abc.abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode an input image x (B, H, W, 3) into a tensor (B, h, w, bottleneck)
        """

    @abc.abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode a tensor x (B, h, w, bottleneck) into an image (B, H, W, 3)
        """
```
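
Because PatchAutoEncoder inherits from both torch.nn.Module and the abc.ABC-based interface, Python refuses to instantiate it until both encode and decode are implemented. The sketch below shows this with hypothetical classes that are not part of the homework: AutoEncoderInterface (a minimal re-declaration of PatchAutoEncoderBase), an incomplete EncodeOnly, and a trivial IdentityAE.

```python
import abc

import torch


class AutoEncoderInterface(abc.ABC):
    # Minimal re-declaration of the PatchAutoEncoderBase interface
    @abc.abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor: ...

    @abc.abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor: ...


class EncodeOnly(torch.nn.Module, AutoEncoderInterface):
    # Hypothetical subclass that forgets to implement decode
    def encode(self, x):
        return x


class IdentityAE(torch.nn.Module, AutoEncoderInterface):
    # Hypothetical subclass implementing both methods trivially
    def encode(self, x):
        return x

    def decode(self, x):
        return x


try:
    EncodeOnly()
    incomplete_rejected = False
except TypeError as err:
    incomplete_rejected = True
    print("TypeError:", err)

ae = IdentityAE()
x = torch.rand(2, 100, 150, 3)
round_trip_ok = torch.equal(ae.decode(ae.encode(x)), x)
print(round_trip_ok)
```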

Core model PatchAutoEncoder

```python
class PatchAutoEncoder(torch.nn.Module, PatchAutoEncoderBase):
    """
    Implement a PatchLevel AutoEncoder

    Hint: Convolutions work well enough, no need to use a transformer unless you really want.
    Hint: See PatchifyLinear and UnpatchifyLinear for how to use convolutions with the input and
          output dimensions given.
    Hint: You can get away with 3 layers or less.
    Hint: Many architectures work here (even just a PatchifyLinear / UnpatchifyLinear).
          However, later parts of the assignment require both non-linearities (i.e. GeLU) and
          interactions (i.e. convolutions) between patches.
    """

    # Encodes the input image into a compact bottleneck representation,
    # returning a tensor of shape (B, h, w, bottleneck).
    class PatchEncoder(torch.nn.Module):
        """
        (Optionally) Use this class to implement an encoder.
                     It can make later parts of the homework easier (reusable components).
        """

        def __init__(self, patch_size: int, latent_dim: int, bottleneck: int):
            super().__init__()

            self.patchify = PatchifyLinear(patch_size, latent_dim)

            self.conv1 = torch.nn.Conv2d(latent_dim, latent_dim, kernel_size=3, padding=1)
            self.conv2 = torch.nn.Conv2d(latent_dim, bottleneck, kernel_size=1)
            self.activation = torch.nn.GELU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.patchify(x)  # (B, h, w, latent_dim)

            x = hwc_to_chw(x)
            x = self.activation(self.conv1(x))  # 3x3 conv mixes neighboring patches
            x = self.conv2(x)  # 1x1 conv projects each patch to the bottleneck size
            return chw_to_hwc(x)

    # Decodes bottleneck features back into a full-size image,
    # so that decode(encode(x)) ≈ x.
    class PatchDecoder(torch.nn.Module):
        def __init__(self, patch_size: int, latent_dim: int, bottleneck: int):
            super().__init__()

            self.conv1 = torch.nn.Conv2d(bottleneck, latent_dim, kernel_size=1)
            self.conv2 = torch.nn.Conv2d(latent_dim, latent_dim, kernel_size=3, padding=1)

            self.unpatchify = UnpatchifyLinear(patch_size, latent_dim)
            self.activation = torch.nn.GELU()
            self.output_act = torch.nn.Tanh()
            self.output_scale = 0.5  # tanh(.) * 0.5 keeps outputs within [-0.5, 0.5]

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = hwc_to_chw(x)

            x = self.activation(self.conv1(x))
            x = self.activation(self.conv2(x))

            x = chw_to_hwc(x)
            x = self.unpatchify(x)  # (B, H, W, 3)

            return self.output_act(x) * self.output_scale

    # Builds the encoder and decoder; forward runs the full pipeline and
    # returns the reconstructed image together with a dictionary of loss terms.
    def __init__(self, patch_size: int = 25, latent_dim: int = 128, bottleneck: int = 128):
        super().__init__()

        self.patch_size = patch_size
        self.encoder = self.PatchEncoder(patch_size, latent_dim, bottleneck)
        self.decoder = self.PatchDecoder(patch_size, latent_dim, bottleneck)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
        Return the reconstructed image and a dictionary of additional loss terms you would like to
        minimize (or even just visualize).
        You can return an empty dictionary if you don't have any additional terms.
        """
        z = self.encode(x)
        x_recon = self.decode(z)

        recon_loss = torch.nn.functional.mse_loss(x_recon, x)

        return x_recon, {"recon_loss": recon_loss}

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(x)
```
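
With the defaults (patch_size=25, latent_dim=128, bottleneck=128), the encoder divides each spatial side by 25 and keeps 128 numbers per patch. The arithmetic below works this through for an assumed 100 × 150 RGB input (the image size is an assumption, not stated in this section). It also checks the decoder's output range: tanh(·) * 0.5 never leaves [-0.5, 0.5], which suggests, though this section does not show it, that the training pipeline normalizes images to that range. The helper names latent_grid and compression_ratio are hypothetical.

```python
import torch


def latent_grid(H: int, W: int, patch_size: int = 25) -> tuple[int, int]:
    # kernel_size == stride == patch_size, so each side is divided by patch_size
    if H % patch_size or W % patch_size:
        raise ValueError("H and W should be multiples of patch_size")
    return H // patch_size, W // patch_size


def compression_ratio(H: int, W: int, patch_size: int = 25, bottleneck: int = 128) -> float:
    # Number of input values divided by number of latent values
    h, w = latent_grid(H, W, patch_size)
    return (H * W * 3) / (h * w * bottleneck)


print(latent_grid(100, 150))        # (4, 6)
print(compression_ratio(100, 150))  # 14.6484375  (45000 / 3072)

# Hypothetical normalization of uint8 pixels to [-0.5, 0.5]
pixels = torch.randint(0, 256, (2, 100, 150, 3), dtype=torch.uint8)
x = pixels.float() / 255.0 - 0.5

# The decoder's final activation stays within [-0.5, 0.5] for any pre-activation value
pre_activation = torch.randn(10_000) * 10
out = torch.tanh(pre_activation) * 0.5
print(x.min().item() >= -0.5, x.max().item() <= 0.5, out.abs().max().item() <= 0.5)
```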

Next, train the model:

```bash
python -m homework.train PatchAutoEncoder
```

The training process is shown in the figure.
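
The internals of homework.train are not shown here. Purely as a hypothetical sketch (the stand-in model TinyPatchAE, the random data, and the Adam optimizer with its learning rate are all assumptions), a few optimization steps for a model with the same forward contract as PatchAutoEncoder, returning a reconstruction plus a dictionary of loss terms, could look like this:

```python
import torch


class TinyPatchAE(torch.nn.Module):
    # Hypothetical stand-in with the same forward contract as PatchAutoEncoder:
    # input (B, H, W, 3) -> (reconstruction, dict of loss terms)
    def __init__(self, patch_size: int = 25, bottleneck: int = 128):
        super().__init__()
        self.enc = torch.nn.Conv2d(3, bottleneck, patch_size, patch_size)
        self.dec = torch.nn.ConvTranspose2d(bottleneck, 3, patch_size, patch_size)

    def forward(self, x):
        z = self.enc(x.movedim(-1, -3))                              # (B, bottleneck, h, w)
        x_recon = (torch.tanh(self.dec(z)) * 0.5).movedim(-3, -1)    # back to (B, H, W, 3)
        return x_recon, {"recon_loss": torch.nn.functional.mse_loss(x_recon, x)}


torch.manual_seed(0)
model = TinyPatchAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Random "images" in [-0.5, 0.5]; a real run would iterate over the training dataset
batch = torch.rand(4, 100, 150, 3) - 0.5

history = []
for step in range(20):
    _, losses = model(batch)
    loss = sum(losses.values())  # total objective = sum of the returned loss terms
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    history.append(loss.item())

print(f"first loss {history[0]:.4f}, last loss {history[-1]:.4f}")
```

Whether the real training script adds the dictionary terms on top of its own reconstruction loss is not stated in this section; if it does, returning recon_loss from forward would count the MSE twice.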

Package the code into a zip archive:

```bash
python bundle.py homework 20251223.zip
```

Run the grader as a self-check:

```bash
python -m grader 20251223.zip
```

The final test score is shown below:
