AE——重构数字(Pytorch+mnist)

1、简介

  • **AE(自编码器)**由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。
  • 【注】输出的10张数字图像是输入的测试图像的第一批次。

2、代码

python 复制代码
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision


# 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
class AE(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )  # 编码器隐藏层
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )  # 编码器输出层
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )  # 解码器隐藏层
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )  # 解码器输出层

    # 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成
    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值
        code = self.encoder_output_layer(activation)
        code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.sigmoid(activation)
        return reconstructed


if __name__ == '__main__':
    # 设置批大小、学习周期和学习率
    batch_size = 512
    epochs = 30
    learning_rate = 1e-3

    # 载入 MNIST 数据集中的图片进行训练
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量

    train_dataset = torchvision.datasets.MNIST(
        root="~/torch_datasets", train=True, transform=transform, download=True
    )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据

    # 在使用定义的 AE 类之前,有以下事情要做:
    # 配置要在哪个设备上运行
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 建立 AE 模型并载入到 CPU 设备
    model = AE(input_shape=784).to(device)

    # Adam 优化器,学习率 10e-3
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 使用均方误差(MSE)损失函数
    criterion = nn.MSELoss()

    # 在GPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
    # 训练:
    for epoch in range(epochs):
        loss = 0
        for batch_features, _ in train_loader:
            # 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备
            batch_features = batch_features.view(-1, 784).to(device)

            # 梯度设置为 0,因为 torch 会累加梯度
            optimizer.zero_grad()

            # 计算重构
            outputs = model(batch_features)

            # 计算训练重建损失
            train_loss = criterion(outputs, batch_features)

            # 计算累积梯度
            train_loss.backward()

            # 根据当前梯度更新参数
            optimizer.step()

            # 将小批量训练损失加到周期损失中
            loss += train_loss.item()

        # 计算每个周期的训练损失
        loss = loss / len(train_loader)

        # 显示每个周期的训练损失
        print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))

    # 用训练过的自编码器提取一些测试用例来重构
    test_dataset = torchvision.datasets.MNIST(
        root="~/torch_datasets", train=False, transform=transform, download=True
    )  # 加载 MNIST 测试数据集

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=10, shuffle=False
    )  # 创建一个测试数据加载器

    test_examples = None

    # 通过循环遍历测试数据加载器,获取一个批次的图像数据
    with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算
        for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据
            batch_features = batch_features[0]  # 获取当前批次的图像数据
            test_examples = batch_features.view(-1, 784).to(
                device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上
            reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像
            break

    # 试着用训练过的自编码器重建一些测试图像
    with torch.no_grad():
        number = 10  # 设置要显示的图像数量
        plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)
        for index in range(number):  # 遍历要显示的图像数量
            # 显示原始图
            ax = plt.subplot(2, number, index + 1)
            plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # 显示重构图
            ax = plt.subplot(2, number, index + 1 + number)
            plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.savefig('reconstruction_results.png')  # 保存图像
        plt.show()
相关推荐
NAGNIP12 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab13 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab13 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP17 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年17 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼17 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS17 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区18 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈18 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang19 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx