【PyTorch Lightning】

python 复制代码
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)
python 复制代码
tensorboard --logdir .

Module中要定义forward函数;

Lightning Module中除了forward函数,还要定义configure_optimizers, training_step和validation_step函数

相关推荐
小龙报16 小时前
《赋能AI解锁Coze智能体搭建核心技能(1)--- 初识coze》
人工智能·语言模型·数据分析·交互·文心一言·机器翻译·coze
love530love16 小时前
【笔记】Podman Desktop 部署 开源数字人 HeyGem.ai
人工智能·windows·笔记·python·容器·开源·podman
CoookeCola16 小时前
开源图像与视频过曝检测工具:HSV色彩空间分析与时序平滑处理技术详解
人工智能·深度学习·算法·目标检测·计算机视觉·开源·音视频
董厂长17 小时前
综述:deepSeek-OCR,paddle-OCR,VLM
人工智能·计算机视觉
gfdgd xi17 小时前
deepin 终端,但是版本是 deepin 15 的
linux·python·架构·ssh·bash·shell·deepin
禁默17 小时前
基于金仓KFS工具,破解多数据并存,浙人医改造实战医疗信创
数据库·人工智能·金仓数据库
云卓SKYDROID17 小时前
无人机动力学模块技术要点与难点
人工智能·无人机·材质·高科技·云卓科技
王六岁17 小时前
🐍 前端开发 0 基础学 Python 入门指南:条件语句篇
前端·python
java1234_小锋17 小时前
PyTorch2 Python深度学习 - 初识PyTorch2,实现一个简单的线性神经网络
开发语言·python·深度学习·pytorch2
胡萝卜3.017 小时前
C++面向对象继承全面解析:不能被继承的类、多继承、菱形虚拟继承与设计模式实践
开发语言·c++·人工智能·stl·继承·菱形继承·组合vs继承