Auto-regressive Image Generation: Hands-on Study (Part 2)

Related project download link

Training Framework

Before implementing the individual modules, let's first get familiar with the training framework in train.py.

1. Imports and Model Dictionary Construction

python
import inspect
import math
from datetime import datetime
from pathlib import Path

import torch
import ae, autoregressive, bsq  # project-local model modules (AE / BSQ / autoregressive)

# Collect every nn.Module subclass defined in the ae/bsq modules (patch-level models)
patch_models = {
    n: m for M in [ae, bsq] for n, m in inspect.getmembers(M) if inspect.isclass(m) and issubclass(m, torch.nn.Module)
}
# Collect every nn.Module subclass defined in the autoregressive module (AR models)
ar_models = {
    n: m for M in [autoregressive] for n, m in inspect.getmembers(M) if inspect.isclass(m) and issubclass(m, torch.nn.Module)
}
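The dictionary-comprehension pattern above can be sketched in isolation. This torch-free toy uses a `Model` base class as a stand-in for `torch.nn.Module` (the class names and the `collect_models` helper are invented for illustration): `inspect.getmembers` enumerates a module's attributes, and the filter keeps only classes that subclass the base.

```python
import inspect
import sys


class Model:  # stand-in for torch.nn.Module
    pass


class PatchAutoEncoder(Model):
    pass


class BSQPatchAutoEncoder(Model):
    pass


def collect_models(module):
    """name -> class, for every Model subclass defined in `module`."""
    return {
        n: m
        for n, m in inspect.getmembers(module)
        if inspect.isclass(m) and issubclass(m, Model) and m is not Model
    }


models = collect_models(sys.modules[__name__])
print(sorted(models))  # ['BSQPatchAutoEncoder', 'PatchAutoEncoder']
```

train.py then instantiates a model by its command-line name, e.g. `patch_models[model_name]()`.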

2. The Core Training Function train()

It consists of three parts: the patch-model trainer PatchTrainer, the autoregressive-model trainer AutoregressiveTrainer, and the checkpoint callback CheckPointer.

Specifically:

  1. The patch-model trainer PatchTrainer handles the AE/BSQ models. Note the preprocessing: images are normalized with (/255.0 - 0.5), mapping pixel values to [-0.5, 0.5] rather than [0, 1]. The loss is MSE (mean squared error), suited to image reconstruction; the optimizer is AdamW with learning rate 1e-3; raw images are loaded through ImageDataset.
  2. The autoregressive trainer AutoregressiveTrainer handles the AR model. It uses cross-entropy loss for the token-sequence classification task, loads tokenized image sequences through TokenDataset, and also uses AdamW with learning rate 1e-3.
  3. The checkpoint callback CheckPointer saves the model at the end of every training epoch, in two ways: a timestamped copy at checkpoints/{timestamp}_{model_name}.pth, and a "latest" copy at {model_name}.pth next to the script.
In addition, train() implements the logic for loading an existing model from disk or creating a new one by name.

python
def train(model_name_or_path: str, epochs: int = 5, batch_size: int = 64):
    import lightning as L
    from lightning.pytorch.loggers import TensorBoardLogger

    from data import ImageDataset, TokenDataset

    class PatchTrainer(L.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def training_step(self, x, batch_idx):
            x = x.float() / 255.0 - 0.5

            x_hat, additional_losses = self.model(x)
            loss = torch.nn.functional.mse_loss(x_hat, x)
            self.log("train/loss", loss, prog_bar=True)
            for k, v in additional_losses.items():
                self.log(f"train/{k}", v)
            return loss + sum(additional_losses.values())

        def validation_step(self, x, batch_idx):
            x = x.float() / 255.0 - 0.5

            with torch.no_grad():
                x_hat, additional_losses = self.model(x)
                loss = torch.nn.functional.mse_loss(x_hat, x)
            self.log("validation/loss", loss, prog_bar=True)
            for k, v in additional_losses.items():
                self.log(f"validation/{k}", v)
            if batch_idx == 0:
                self.logger.experiment.add_images(
                    "input", (x[:64] + 0.5).clamp(min=0, max=1).permute(0, 3, 1, 2), self.global_step
                )
                self.logger.experiment.add_images(
                    "prediction", (x_hat[:64] + 0.5).clamp(min=0, max=1).permute(0, 3, 1, 2), self.global_step
                )
            return loss

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            dataset = ImageDataset("train")
            return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)

        def val_dataloader(self):
            dataset = ImageDataset("valid")
            return torch.utils.data.DataLoader(dataset, batch_size=4096, num_workers=4, shuffle=True)

    class AutoregressiveTrainer(L.LightningModule):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def training_step(self, x, batch_idx):
            x_hat, additional_losses = self.model(x)
            loss = (
                torch.nn.functional.cross_entropy(x_hat.view(-1, x_hat.shape[-1]), x.view(-1), reduction="sum")
                / math.log(2)
                / x.shape[0]
            )
            self.log("train/loss", loss, prog_bar=True)
            for k, v in additional_losses.items():
                self.log(f"train/{k}", v)
            return loss + sum(additional_losses.values())

        def validation_step(self, x, batch_idx):
            with torch.no_grad():
                x_hat, additional_losses = self.model(x)
                loss = (
                    torch.nn.functional.cross_entropy(x_hat.view(-1, x_hat.shape[-1]), x.view(-1), reduction="sum")
                    / math.log(2)
                    / x.shape[0]
                )
            self.log("validation/loss", loss, prog_bar=True)
            for k, v in additional_losses.items():
                self.log(f"validation/{k}", v)
            return loss

        def configure_optimizers(self):
            return torch.optim.AdamW(self.parameters(), lr=1e-3)

        def train_dataloader(self):
            dataset = TokenDataset("train")
            return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)

        def val_dataloader(self):
            dataset = TokenDataset("valid")
            return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)

    class CheckPointer(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            fn = Path(f"checkpoints/{timestamp}_{model_name}.pth")
            fn.parent.mkdir(exist_ok=True, parents=True)
            torch.save(model, fn)
            torch.save(model, Path(__file__).parent / f"{model_name}.pth")

    # Load or create the model
    if Path(model_name_or_path).exists():
        model = torch.load(model_name_or_path, weights_only=False)
        model_name = model.__class__.__name__
    else:
        model_name = model_name_or_path
        if model_name in patch_models:
            model = patch_models[model_name]()
        elif model_name in ar_models:
            model = ar_models[model_name]()
        else:
            raise ValueError(f"Unknown model: {model_name}")

    # Create the lightning model
    if isinstance(model, (autoregressive.Autoregressive)):
        l_model = AutoregressiveTrainer(model)
    else:
        l_model = PatchTrainer(model)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    logger = TensorBoardLogger("logs", name=f"{timestamp}_{model_name}")
    trainer = L.Trainer(max_epochs=epochs, logger=logger, callbacks=[CheckPointer()])
    trainer.fit(
        model=l_model,
    )
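One detail worth unpacking from both trainers: cross_entropy is computed with reduction="sum", then divided by math.log(2) and by the batch size, so the logged loss reads as bits per image rather than nats per token. A self-contained sketch of the nats-to-bits conversion for a single token (the helper `cross_entropy_nats` is written here purely for illustration):

```python
import math


def cross_entropy_nats(logits, target):
    """Negative log-softmax probability of the target class, in nats."""
    z = max(logits)  # subtract the max for numerical stability
    log_sum_exp = z + math.log(sum(math.exp(v - z) for v in logits))
    return log_sum_exp - logits[target]


# A token with 4 equally likely values costs ln(4) nats = 2 bits
loss_nats = cross_entropy_nats([0.0, 0.0, 0.0, 0.0], target=2)
loss_bits = loss_nats / math.log(2)
print(round(loss_bits, 6))  # 2.0
```

Reporting in bits per image makes the number directly interpretable as a compression cost for the token sequence.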

3. Command-line Entry Point

The project uses the fire library for command-line argument parsing: there is no need to hand-parse flags such as --epochs/--batch_size; training is launched directly with python train.py {model_name} --epochs 10.

python
if __name__ == "__main__":
    from fire import Fire
    Fire(train)

Typical usage of train.py:

bash
# Train the patch autoencoder
python train.py PatchAutoEncoder --epochs 5 --batch_size 64

# Train the autoregressive model
python train.py AutoregressiveModel --epochs 10 --batch_size 32

# Resume training from an existing checkpoint
python train.py checkpoints/2025-10-20_PatchAutoEncoder.pth --epochs 10

Loading Data

Next, let's look at how the project loads data. The data.py module defines two PyTorch-compatible dataset classes.

Specifically:

  • ImageDataset: loads raw JPG images, with an optional cache to speed up repeated reads;
  • TokenDataset: loads tokenized image tensors (produced by tokenize.py) for autoregressive training.

1. Imports

python
from pathlib import Path

import numpy as np  # needed for np.array() in ImageDataset below
import torch
from PIL import Image

# Locate the dataset root: the "data" folder two levels above this file
DATASET_PATH = Path(__file__).parent.parent / "data"

2. ImageDataset (Raw Image Dataset)

python
class ImageDataset:
    def __init__(self, split: str, cache_images: bool = True):
        # Collect the paths of all .jpg files under the split (train/valid) directory
        self.image_paths = list((DATASET_PATH / split).rglob("*.jpg"))
        # Per-index image cache, to avoid re-reading from disk
        self._image_cache = [None] * len(self.image_paths)
        self._cache_images = cache_images  # whether caching is enabled

    def __len__(self) -> int:
        return len(self.image_paths)  # total dataset size

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Serve from the cache when possible; otherwise load from disk
        if self._image_cache[idx] is not None:
            return self._image_cache[idx]
        # Load: open with PIL -> numpy array -> torch.uint8 tensor (raw pixel values)
        img = torch.tensor(np.array(Image.open(self.image_paths[idx])), dtype=torch.uint8)
        # Store in the cache for reuse if caching is enabled
        if self._cache_images:
            self._image_cache[idx] = img
        return img
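The lazy per-index cache can be isolated into a torch-free sketch. `CachedDataset`, its string "decoding", and the `disk_reads` counter are invented here purely to show how often the disk-read branch actually runs:

```python
class CachedDataset:
    """Toy version of ImageDataset's lazy per-index cache."""

    def __init__(self, paths, cache=True):
        self.paths = paths
        self._cache = [None] * len(paths)
        self._use_cache = cache
        self.disk_reads = 0  # instrumentation for this sketch

    def __getitem__(self, idx):
        if self._cache[idx] is not None:
            return self._cache[idx]
        self.disk_reads += 1  # stands in for Image.open(...)
        item = f"decoded:{self.paths[idx]}"
        if self._use_cache:
            self._cache[idx] = item
        return item


ds = CachedDataset(["a.jpg", "b.jpg"])
ds[0]; ds[0]; ds[1]
print(ds.disk_reads)  # 2  (each item decoded exactly once)
```

With `cache=False` every access would hit the decode branch again, which is the trade-off `cache_images` controls: memory for speed.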

3. TokenDataset (Tokenized Dataset)

python
class TokenDataset(torch.utils.data.TensorDataset):
    def __init__(self, split: str):
        # Load the tokenized tensor file (produced by tokenize.py)
        tensor_path = DATASET_PATH / f"tokenized_{split}.pth"
        if not tensor_path.exists():
            # Fail with a clear message, as the assignment workflow expects
            raise FileNotFoundError(f"Tokenized dataset not found at {tensor_path}...")
        self.data = torch.load(tensor_path, weights_only=False)

    def __getitem__(self, idx: int) -> torch.Tensor:
        # Return a long tensor (discrete token input for the autoregressive model)
        return torch.tensor(self.data[idx], dtype=torch.long)

    def __len__(self) -> int:
        return len(self.data)

The two dataset classes are used as follows:

python
# Load the raw training images (for AE/BSQ training)
from data import ImageDataset, TokenDataset

train_img_ds = ImageDataset("train", cache_images=True)
img_tensor = train_img_ds[0]  # first image, shape: (H, W, 3)

# Load the tokenized training data (for autoregressive training)
train_token_ds = TokenDataset("train")
token_tensor = train_token_ds[0]  # first token sequence, shape: (sequence_length,)

# Use with a DataLoader
from torch.utils.data import DataLoader
train_loader = DataLoader(train_token_ds, batch_size=64, shuffle=True)