图像自回归生成(Auto-regressive image generation)实战学习(三)

相关项目下载链接

本节内容基于ae.py文件,实现了一个Patch 级自动编码器(Patch AutoEncoder),核心功能是将图像按固定尺寸分块(Patch),通过编码 - 解码流程实现图像重。具体流程如下:

导入依赖库

python 复制代码
import abc
import torch

定义模型加载函数

python 复制代码
def load() -> torch.nn.Module:
    from pathlib import Path

    model_name = "PatchAutoEncoder"
    model_path = Path(__file__).parent / f"{model_name}.pth"
    print(f"Loading {model_name} from {model_path}")
    return torch.load(model_path, weights_only=False)

定义转换通道函数,实现 "通道最后"(HWC)与 "通道第一"(CHW)格式转换

python 复制代码
def hwc_to_chw(x: torch.Tensor) -> torch.Tensor:
    """
    Convert an arbitrary tensor from (H, W, C) to (C, H, W) format.
    This allows us to switch from transformer-style channel-last to pytorch-style channel-first
    images. Works with or without the batch dimension.
    """
    dims = list(range(x.dim()))
    dims = dims[:-3] + [dims[-1]] + [dims[-3]] + [dims[-2]]
    return x.permute(*dims)

def chw_to_hwc(x: torch.Tensor) -> torch.Tensor:
    """
    The opposite of hwc_to_chw. Works with or without the batch dimension.
    """
    dims = list(range(x.dim()))
    dims = dims[:-3] + [dims[-2]] + [dims[-1]] + [dims[-3]]
    return x.permute(*dims)

Patch 分块模块

  • 分块模块:接收 (B, H, W, 3) 格式的图像张量,将其按patch_size大小分块,每个 Patch 通过线性变换映射为latent_dim维的嵌入向量,输出 (B, H//patch_size, W//patch_size, latent_dim) 格式的 Patch 嵌入张量。
  • 逆分块模块:接收 (B, w, h, latent_dim) 格式的 Patch 嵌入张量,将其逆转换为原始尺寸的图像张量 (B, wpatch_size, hpatch_size, 3),是PatchifyLinear的逆操作。
python 复制代码
# 分块模块
class PatchifyLinear(torch.nn.Module):
    """
    Takes an image tensor of the shape (B, H, W, 3) and patchifies it into
    an embedding tensor of the shape (B, H//patch_size, W//patch_size, latent_dim).
    It applies a linear transformation to each input patch
    """
    def __init__(self, patch_size: int = 25, latent_dim: int = 128):
        super().__init__()
        self.patch_conv = torch.nn.Conv2d(3, latent_dim, patch_size, patch_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = hwc_to_chw(x)
        x = self.patch_conv(x)
        return chw_to_hwc(x)

# 逆分块模块
class UnpatchifyLinear(torch.nn.Module):
    """
    Takes an embedding tensor of the shape (B, w, h, latent_dim) and reconstructs
    an image tensor of the shape (B, w * patch_size, h * patch_size, 3).
    """
    def __init__(self, patch_size: int = 25, latent_dim: int = 128):
        super().__init__()
        self.unpatch_conv = torch.nn.ConvTranspose2d(latent_dim, 3, patch_size, patch_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = hwc_to_chw(x)
        x = self.unpatch_conv(x)
        return chw_to_hwc(x)

抽象基类 PatchAutoEncoderBase

定义了 Patch 自动编码器的核心抽象接口,包含encode(编码)和decode(解码)两个抽象方法。

python 复制代码
class PatchAutoEncoderBase(abc.ABC):
    @abc.abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode an input image x (B, H, W, 3) into a tensor (B, h, w, bottleneck)
        """

    @abc.abstractmethod
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Decode a tensor x (B, h, w, bottleneck) into an image (B, H, W, 3)
        """

核心模型 PatchAutoEncoder

python 复制代码
class PatchAutoEncoder(torch.nn.Module, PatchAutoEncoderBase):
    """
    Implement a PatchLevel AutoEncoder

    Hint: Convolutions work well enough, no need to use a transformer unless you really want.
    Hint: See PatchifyLinear and UnpatchifyLinear for how to use convolutions with the input and
          output dimensions given.
    Hint: You can get away with 3 layers or less.
    Hint: Many architectures work here (even a just PatchifyLinear / UnpatchifyLinear).
          However, later parts of the assignment require both non-linearities (i.e. GeLU) and
          interactions (i.e. convolutions) between patches.
    """
	# 将输入图像编码为低维度的瓶颈特征(bottleneck),输出 (B, h, w, bottleneck) 格式的张量。
    class PatchEncoder(torch.nn.Module):
        """
        (Optionally) Use this class to implement an encoder.
                     It can make later parts of the homework easier (reusable components).
        """

        def __init__(self, patch_size: int, latent_dim: int, bottleneck: int):
            super().__init__()

            self.patchify = PatchifyLinear(patch_size, latent_dim)
            
            self.conv1 = torch.nn.Conv2d(latent_dim, latent_dim, kernel_size=3, padding=1)
            self.conv2 = torch.nn.Conv2d(latent_dim, bottleneck, kernel_size=1)
            self.activation = torch.nn.GELU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:

            x = self.patchify(x)  
            
            x = hwc_to_chw(x)  
            x = self.activation(self.conv1(x))  
            x = self.conv2(x)  
            return chw_to_hwc(x)
            
	# 将瓶颈特征解码为原始尺寸的图像,实现decode(encode(x)) ≈ x的重建目标。
    class PatchDecoder(torch.nn.Module):
        def __init__(self, patch_size: int, latent_dim: int, bottleneck: int):
            super().__init__()

            self.conv1 = torch.nn.Conv2d(bottleneck, latent_dim, kernel_size=1)
            self.conv2 = torch.nn.Conv2d(latent_dim, latent_dim, kernel_size=3, padding=1)
            
            self.unpatchify = UnpatchifyLinear(patch_size, latent_dim)
            self.activation = torch.nn.GELU()
            self.output_act = torch.nn.Tanh()
            self.output_scale = 0.5

        def forward(self, x: torch.Tensor) -> torch.Tensor:
 
            x = hwc_to_chw(x) 
            
            x = self.activation(self.conv1(x))
            x = self.activation(self.conv2(x))  
            
            x = chw_to_hwc(x)
            x = self.unpatchify(x) 
            
            return self.output_act(x) * self.output_scale
            
	# 初始化编码器和解码器,实现完整的前向传播流程,返回重建图像和损失字典。
    def __init__(self, patch_size: int = 25, latent_dim: int = 128, bottleneck: int = 128):
        super().__init__()
        
        self.patch_size = patch_size
        self.encoder = self.PatchEncoder(patch_size, latent_dim, bottleneck)
        self.decoder = self.PatchDecoder(patch_size, latent_dim, bottleneck)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """
        Return the reconstructed image and a dictionary of additional loss terms you would like to
        minimize (or even just visualize).
        You can return an empty dictionary if you don't have any additional terms.
        """

        z = self.encode(x)
        x_recon = self.decode(z)
        
        recon_loss = torch.nn.functional.mse_loss(x_recon, x)
        
        return x_recon, {"recon_loss": recon_loss}

    def encode(self, x: torch.Tensor) -> torch.Tensor:

        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:

        return self.decoder(x)

下面进行模型训练:

python 复制代码
python -m homework.train PatchAutoEncoder

模型训练过程如图所示

将代码打包为压缩文件

python 复制代码
python bundle.py homework 20251223.zip

进行评分自测:

python 复制代码
python -m grader 20251223.zip

最终测试得分如下:

相关推荐
那个村的李富贵6 小时前
光影魔术师:CANN加速实时图像风格迁移,让每张照片秒变大师画作
人工智能·aigc·cann
腾讯云开发者8 小时前
“痛点”到“通点”!一份让 AI 真正落地产生真金白银的实战指南
人工智能
CareyWYR8 小时前
每周AI论文速递(260202-260206)
人工智能
hopsky9 小时前
大模型生成PPT的技术原理
人工智能
禁默9 小时前
打通 AI 与信号处理的“任督二脉”:Ascend SIP Boost 加速库深度实战
人工智能·信号处理·cann
心疼你的一切10 小时前
昇腾CANN实战落地:从智慧城市到AIGC,解锁五大行业AI应用的算力密码
数据仓库·人工智能·深度学习·aigc·智慧城市·cann
阿蒙Amon10 小时前
TypeScript学习-第10章:模块与命名空间
学习·ubuntu·typescript
AI绘画哇哒哒10 小时前
【干货收藏】深度解析AI Agent框架:设计原理+主流选型+项目实操,一站式学习指南
人工智能·学习·ai·程序员·大模型·产品经理·转行
数据分析能量站10 小时前
Clawdbot(现名Moltbot)-现状分析
人工智能
那个村的李富贵10 小时前
CANN加速下的AIGC“即时翻译”:AI语音克隆与实时变声实战
人工智能·算法·aigc·cann