代码填空任务---自编码器模型

1.自编码器模型填空:

python 复制代码
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm
import os
 
# os.chdir(os.path.dirname(__file__))
 
'模型结构'

 
#损失函数
#交叉熵,衡量各个像素原始数据与重构数据的误差

#均方误差可作为交叉熵替代使用.衡量各个像素原始数据与重构数据的误差

 
'超参数及构造模型'
#模型参数
#压缩后的特征维度
#encoder和decoder中间层的维度
#原始图片和生成图片的维度
 
#训练参数
#训练时期
#每步训练样本数
#学习率
device =torch.device('cuda' if torch.cuda.is_available() else 'cpu')#训练设备
 
#确定模型,导入已训练模型(如有)
modelname = 'ae.pth'
#模型初始化
#优化器

try:
    model.load_state_dict(torch.load(modelname))
    print('[INFO] Load Model complete')
except:
    pass
 
'训练模型'
#准备mnist数据集 (数据会下载到py文件所在的data文件夹下)
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=False)
#此方法获取的数据各像素值范围0-1
 
#训练及测试
loss_history = {'train':[],'eval':[]}
for epoch in range(epochs):   
    #训练

    #每个epoch重置损失,设置进度条
    train_loss = 0
    train_nsample = 0
    t = tqdm(train_loader,desc = f'[train]epoch:{epoch}')
    for imgs, lbls in t: #imgs:(bs,28,28)

        #获取数据
        #imgs:(bs,28*28)
        #模型运算     
        
        #计算损失
        # 重构与原始数据的差距(也可使用loss_MSE)
        #反向传播、参数优化,重置
        
        
        
        #计算平均损失,设置进度条
        
        
        t.set_postfix({'loss':train_loss/train_nsample})
    #每个epoch记录总损失
    loss_history['train'].append(train_loss/train_nsample)
 
    #测试

    #每个epoch重置损失,设置进度条
    test_loss = 0
    test_nsample = 0
    e = tqdm(test_loader,desc = f'[eval]epoch:{epoch}')
    for imgs, label in e:

        #获取数据
        
        #模型运算   
        
        #计算损失
         
        #计算平均损失,设置进度条
        
        
        e.set_postfix({'loss':test_loss/test_nsample})
    #每个epoch记录总损失    
    loss_history['eval'].append(test_loss/test_nsample)
 
 
    #展示效果   
    #将测试步骤中的数据、重构数据绘图
    concat = torch.cat((imgs[0].view(28, 28),
            re_imgs[0].view( 28, 28)), 1)
    plt.matshow(concat.cpu().detach().numpy())
    plt.show()
 
    #显示每个epoch的loss变化
    plt.plot(range(epoch+1),loss_history['train'])
    plt.plot(range(epoch+1),loss_history['eval'])
    plt.show()
    #存储模型
    torch.save(model.state_dict(),modelname)
 
'调用模型'
#对数据集
dataset = datasets.MNIST('./', train=False, transform=transforms.ToTensor())
#取一组手写数据(正常数据)
raw = dataset[0][0].view(1,-1) #raw: bs,28,28->bs,28*28
#对手写数据(正常数据)重构
re_raw = model(raw.to(device))
#取一组随机数据(异常数据)
rand = torch.randn_like(raw)
#对随机数据(异常数据)重构
re_rand = model(rand.to(device))
 
#定义一个衡量标准,按像素平均所有原始数据和重构数据的误差
f = lambda x,y: abs(x-y).mean()
#正常数据 原始数据与重构数据差异
print('正常数据:',f(re_raw.to("cpu"),raw))
#异常数据 原始数据与重构数据差异
print('异常数据:',f(re_rand.to("cpu"),rand))
 
#正常数据,原始数据与重构数据作图
plt.matshow(raw.view(28,28).detach().cpu().numpy())
plt.show()
plt.matshow(re_raw.view(28,28).detach().cpu().numpy())
plt.show()
#异常数据,原始数据与重构数据作图
plt.matshow(rand.view(28,28).detach().cpu().numpy())
plt.show()
plt.matshow(re_rand.view(28,28).detach().cpu().numpy())
plt.show()

参考:

手写系列------AE网络、VAE网络和Condition VAE网络-CSDN博客

相关推荐
HyperAI超神经41 分钟前
IQuest-Coder-V1:基于代码流训练的编程逻辑增强模型;Human Face Emotions:基于多标注维度的人脸情绪识别数据集
人工智能·深度学习·学习·机器学习·ai编程
啊阿狸不会拉杆1 小时前
《机器学习》第 1 章 - 机器学习概述
人工智能·机器学习·ai·ml
52Hz1181 小时前
力扣73.矩阵置零、54.螺旋矩阵、48.旋转图像
python·算法·leetcode·矩阵
咚咚王者1 小时前
人工智能之核心基础 机器学习 第十八章 经典实战项目
人工智能·机器学习
DuHz1 小时前
矩阵束法(Matrix Pencil)用于 FMCW 雷达干扰抑制:论文精读
人工智能·机器学习·矩阵
编程小风筝1 小时前
机器学习和稀疏建模的应用场景和优势
人工智能·机器学习
weixin_462446232 小时前
Python 使用 openpyxl 从 URL 读取 Excel 并获取 Sheet 及单元格样式信息
python·excel·openpyxl
程序员小嬛2 小时前
(TETCI 2024) 从 U-Net 到 Transformer:即插即用注意力模块解析
人工智能·深度学习·机器学习·transformer
毕设源码-钟学长2 小时前
【开题答辩全过程】以 基于Python的健康食谱规划系统的设计与实现为例,包含答辩的问题和答案
开发语言·python
百***78753 小时前
Grok-4.1技术深度解析:双版本架构突破与Python API快速集成指南
大数据·python·架构