相关项目下载链接
训练框架
在开始实现相应模块功能之前,首先熟悉训练框架·train.py。
1. 导入与模型字典构建
python
import inspect
import math
from datetime import datetime
from pathlib import Path
import torch
import ae, autoregressive, bsq # 自定义模型模块(AE/BSQ/自回归)
# 收集ae/bsq模块中所有继承nn.Module的块级模型类
patch_models = {
n: m for M in [ae, bsq] for n, m in inspect.getmembers(M) if inspect.isclass(m) and issubclass(m, torch.nn.Module)
}
# 收集autoregressive模块中所有继承nn.Module的自回归模型类
ar_models = {
n: m for M in [autoregressive] for n, m in inspect.getmembers(M) if inspect.isclass(m) and issubclass(m, torch.nn.Module)
}
2. 核心训练函数 train()
共包含三个部分:块级模型训练器 PatchTrainer、自回归模型训练器 AutoregressiveTrainer、模型保存回调 CheckPointer。
其中,
- 块级模型训练器 PatchTrainer 专用于 AE/BSQ 模型。值得注意 的是数据预处理过程,图像归一化的方式是(/255.0 - 0.5),将像素值映射到[-0.5, 0.5]而不是[0,1];损失函数采用MSE(均方误差),适配图像重构任务;优化器为AdamW,学习率1e-3;基于
ImageDataset加载原始图像数据集。 - 自回归模型训练器 AutoregressiveTrainer 专用于 AR 模型。使用交叉熵损失,适配令牌序列的分类预测任务;基于
TokenDataset加载令牌化后的图像序列;优化器为AdamW,学习率1e-3。 - 模型保存回调 CheckPointer 。模型保存的触发时机是在每个训练 epoch 结束后;有两种保存方式:一种是带时间戳的模型,保存方式为
checkpoints/{时间戳}_{模型名}.pth;另一种是最新的模型,保存路径为当前目录下的{模型名}.pth。
此外,还实现了模型加载 / 创建的逻辑。
python
def train(model_name_or_path: str, epochs: int = 5, batch_size: int = 64):
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from data import ImageDataset, TokenDataset
class PatchTrainer(L.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, x, batch_idx):
x = x.float() / 255.0 - 0.5
x_hat, additional_losses = self.model(x)
loss = torch.nn.functional.mse_loss(x_hat, x)
self.log("train/loss", loss, prog_bar=True)
for k, v in additional_losses.items():
self.log(f"train/{k}", v)
return loss + sum(additional_losses.values())
def validation_step(self, x, batch_idx):
x = x.float() / 255.0 - 0.5
with torch.no_grad():
x_hat, additional_losses = self.model(x)
loss = torch.nn.functional.mse_loss(x_hat, x)
self.log("validation/loss", loss, prog_bar=True)
for k, v in additional_losses.items():
self.log(f"validation/{k}", v)
if batch_idx == 0:
self.logger.experiment.add_images(
"input", (x[:64] + 0.5).clamp(min=0, max=1).permute(0, 3, 1, 2), self.global_step
)
self.logger.experiment.add_images(
"prediction", (x_hat[:64] + 0.5).clamp(min=0, max=1).permute(0, 3, 1, 2), self.global_step
)
return loss
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=1e-3)
def train_dataloader(self):
dataset = ImageDataset("train")
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)
def val_dataloader(self):
dataset = ImageDataset("valid")
return torch.utils.data.DataLoader(dataset, batch_size=4096, num_workers=4, shuffle=True)
class AutoregressiveTrainer(L.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, x, batch_idx):
x_hat, additional_losses = self.model(x)
loss = (
torch.nn.functional.cross_entropy(x_hat.view(-1, x_hat.shape[-1]), x.view(-1), reduction="sum")
/ math.log(2)
/ x.shape[0]
)
self.log("train/loss", loss, prog_bar=True)
for k, v in additional_losses.items():
self.log(f"train/{k}", v)
return loss + sum(additional_losses.values())
def validation_step(self, x, batch_idx):
with torch.no_grad():
x_hat, additional_losses = self.model(x)
loss = (
torch.nn.functional.cross_entropy(x_hat.view(-1, x_hat.shape[-1]), x.view(-1), reduction="sum")
/ math.log(2)
/ x.shape[0]
)
self.log("validation/loss", loss, prog_bar=True)
for k, v in additional_losses.items():
self.log(f"validation/{k}", v)
return loss
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=1e-3)
def train_dataloader(self):
dataset = TokenDataset("train")
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)
def val_dataloader(self):
dataset = TokenDataset("valid")
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=4, shuffle=True)
class CheckPointer(L.Callback):
def on_train_epoch_end(self, trainer, pl_module):
fn = Path(f"checkpoints/{timestamp}_{model_name}.pth")
fn.parent.mkdir(exist_ok=True, parents=True)
torch.save(model, fn)
torch.save(model, Path(__file__).parent / f"{model_name}.pth")
# Load or create the model
if Path(model_name_or_path).exists():
model = torch.load(model_name_or_path, weights_only=False)
model_name = model.__class__.__name__
else:
model_name = model_name_or_path
if model_name in patch_models:
model = patch_models[model_name]()
elif model_name in ar_models:
model = ar_models[model_name]()
else:
raise ValueError(f"Unknown model: {model_name}")
# Create the lightning model
if isinstance(model, (autoregressive.Autoregressive)):
l_model = AutoregressiveTrainer(model)
else:
l_model = PatchTrainer(model)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logger = TensorBoardLogger("logs", name=f"{timestamp}_{model_name}")
trainer = L.Trainer(max_epochs=epochs, logger=logger, callbacks=[CheckPointer()])
trainer.fit(
model=l_model,
)
3. 命令行启动
本项目借助fire库实现命令行参数解析,无需手动解析--epochs/--batch_size等参数,直接通过python train.py {模型名} --epochs 10启动训练。
python
if __name__ == "__main__":
from fire import Fire
Fire(train)
train.py核心使用方法如下:
python
# 训练块级自编码器
python train.py PatchAutoEncoder --epochs 5 --batch_size 64
# 训练自回归模型
python train.py AutoregressiveModel --epochs 10 --batch_size 32
# 加载已有模型续训
python train.py checkpoints/2025-10-20_PatchAutoEncoder.pth --epochs 10
加载数据
接下来熟悉这个项目是如何进行数据加载的,data.py模块定义两类 PyTorch 兼容的数据集类。
其中:
ImageDataset:加载原始 JPG 图像,提供缓存机制提升读取效率;TokenDataset:加载令牌化后的图像张量(由tokenize.py生成),供自回归模型训练使用。。
1. 导入依赖库
python
from pathlib import Path
import torch
from PIL import Image
# 自动定位数据集根目录:当前文件的父父目录下的data文件夹
DATASET_PATH = Path(__file__).parent.parent / "data"
2. ImageDataset(原始图像数据集)
python
class ImageDataset:
def __init__(self, split: str, cache_images: bool = True):
# 收集split(train/valid)目录下所有.jpg文件路径
self.image_paths = list((DATASET_PATH / split).rglob("*.jpg"))
# 初始化图像缓存列表,避免重复读取磁盘
self._image_cache = [None] * len(self.image_paths)
self._cache_images = cache_images # 是否开启缓存
def __len__(self) -> int:
return len(self.image_paths) # 数据集总长度
def __getitem__(self, idx: int) -> torch.Tensor:
# 优先读取缓存,无缓存则加载图像
if self._image_cache[idx] is not None:
return self._image_cache[idx]
# 图像加载:PIL打开→转numpy数组→转torch.uint8张量(保持原始像素值)
img = torch.tensor(np.array(Image.open(self.image_paths[idx])), dtype=torch.uint8)
# 开启缓存则存入,后续复用
if self._cache_images:
self._image_cache[idx] = img
return img
3. TokenDataset(令牌化数据集)
python
class TokenDataset(torch.utils.data.TensorDataset):
def __init__(self, split: str):
# 加载令牌化后的张量文件(由tokenize.py生成)
tensor_path = DATASET_PATH / f"tokenized_{split}.pth"
if not tensor_path.exists():
# 文件不存在时给出明确提示,符合作业流程指引
raise FileNotFoundError(f"Tokenized dataset not found at {tensor_path}...")
self.data = torch.load(tensor_path, weights_only=False)
def __getitem__(self, idx: int) -> torch.Tensor:
# 返回长整型张量(适配自回归模型的离散令牌输入)
return torch.tensor(self.data[idx], dtype=torch.long)
def __len__(self) -> int:
return len(self.data)
这两个数据集加载对象的使用方法如下所示:
python
# 加载训练集原始图像(用于AE/BSQ训练)
from data import ImageDataset, TokenDataset
train_img_ds = ImageDataset("train", cache_images=True)
img_tensor = train_img_ds[0] # 取第0张图像,shape: (H, W, 3)
# 加载训练集令牌数据(用于自回归模型训练)
train_token_ds = TokenDataset("train")
token_tensor = train_token_ds[0] # 取第0个令牌序列,shape: (序列长度,)
# 配合DataLoader使用
from torch.utils.data import DataLoader
train_loader = DataLoader(train_token_ds, batch_size=64, shuffle=True)