AE——重构数字(Pytorch+mnist)

1、简介

  • **AE(自编码器)**由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。
  • 【注】输出的10张数字图像是输入的测试图像的第一批次。

2、代码

python 复制代码
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision


# 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
class AE(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.encoder_hidden_layer = nn.Linear(
            in_features=kwargs["input_shape"], out_features=128
        )  # 编码器隐藏层
        self.encoder_output_layer = nn.Linear(
            in_features=128, out_features=128
        )  # 编码器输出层
        self.decoder_hidden_layer = nn.Linear(
            in_features=128, out_features=128
        )  # 解码器隐藏层
        self.decoder_output_layer = nn.Linear(
            in_features=128, out_features=kwargs["input_shape"]
        )  # 解码器输出层

    # 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成
    def forward(self, features):
        activation = self.encoder_hidden_layer(features)
        activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值
        code = self.encoder_output_layer(activation)
        code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内
        activation = self.decoder_hidden_layer(code)
        activation = torch.relu(activation)
        activation = self.decoder_output_layer(activation)
        reconstructed = torch.sigmoid(activation)
        return reconstructed


if __name__ == '__main__':
    # 设置批大小、学习周期和学习率
    batch_size = 512
    epochs = 30
    learning_rate = 1e-3

    # 载入 MNIST 数据集中的图片进行训练
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量

    train_dataset = torchvision.datasets.MNIST(
        root="~/torch_datasets", train=True, transform=transform, download=True
    )  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 True

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据

    # 在使用定义的 AE 类之前,有以下事情要做:
    # 配置要在哪个设备上运行
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 建立 AE 模型并载入到 CPU 设备
    model = AE(input_shape=784).to(device)

    # Adam 优化器,学习率 10e-3
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # 使用均方误差(MSE)损失函数
    criterion = nn.MSELoss()

    # 在GPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数
    # 训练:
    for epoch in range(epochs):
        loss = 0
        for batch_features, _ in train_loader:
            # 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备
            batch_features = batch_features.view(-1, 784).to(device)

            # 梯度设置为 0,因为 torch 会累加梯度
            optimizer.zero_grad()

            # 计算重构
            outputs = model(batch_features)

            # 计算训练重建损失
            train_loss = criterion(outputs, batch_features)

            # 计算累积梯度
            train_loss.backward()

            # 根据当前梯度更新参数
            optimizer.step()

            # 将小批量训练损失加到周期损失中
            loss += train_loss.item()

        # 计算每个周期的训练损失
        loss = loss / len(train_loader)

        # 显示每个周期的训练损失
        print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))

    # 用训练过的自编码器提取一些测试用例来重构
    test_dataset = torchvision.datasets.MNIST(
        root="~/torch_datasets", train=False, transform=transform, download=True
    )  # 加载 MNIST 测试数据集

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=10, shuffle=False
    )  # 创建一个测试数据加载器

    test_examples = None

    # 通过循环遍历测试数据加载器,获取一个批次的图像数据
    with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算
        for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据
            batch_features = batch_features[0]  # 获取当前批次的图像数据
            test_examples = batch_features.view(-1, 784).to(
                device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上
            reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像
            break

    # 试着用训练过的自编码器重建一些测试图像
    with torch.no_grad():
        number = 10  # 设置要显示的图像数量
        plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)
        for index in range(number):  # 遍历要显示的图像数量
            # 显示原始图
            ax = plt.subplot(2, number, index + 1)
            plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # 显示重构图
            ax = plt.subplot(2, number, index + 1 + number)
            plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        plt.savefig('reconstruction_results.png')  # 保存图像
        plt.show()
相关推荐
元岳数字人小元4 分钟前
AI 数字人开发公司浅谈 虚拟数字人打造景区新服务
人工智能·人机交互·交互
哦哦~9218 分钟前
AI赋能生物医学:从临床数据到药物分子性质预测实战培
人工智能·生物医学·药物分子
GIS数据转换器10 分钟前
城市排水生命线安全运行监测平台深度解析
java·运维·人工智能·python·安全·数据挖掘·无人机
虫无涯13 分钟前
本地离线大模型实战:Ollama + Llama 3.1 8B 全流程部署(适配VSCode Continue代码助手)
人工智能
Rocky Ding*28 分钟前
Latent Consistency Models:一篇读懂扩散模型的少步生成核心基础知识
人工智能·深度学习·机器学习·ai作画·stable diffusion·aigc·ai-native
大山佬30 分钟前
AI 边缘部署:MCU 上的轻量级目标检测,从 YOLO 到 TFLite Micro 的全链路优化
人工智能
数睿数据无代码开发31 分钟前
深度解析smardaten数据大屏:六大核心功能重塑可视化开发
人工智能·信息可视化
陈猪的杰咪32 分钟前
GitHub Copilot 2026计费新规:AI Credits消耗解析与节省策略
人工智能·ai·架构·github·copilot
学术头条40 分钟前
清华团队开源SCAIL-2:角色动画告别骨骼依赖,端到端还原视频中动作细节
人工智能·科技·机器学习·ai·开源·音视频·agi
لا معنى له40 分钟前
世界模型的功能分类法——Renderers, Simulators, Planners, and the Loop That Connects Them
人工智能