Pytorch-MLP-Mnist

文章目录

model.py

py 复制代码
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class MLP_cls(nn.Module):
    def __init__(self,in_dim=28*28):
        super(MLP_cls,self).__init__()
        self.lin1 = nn.Linear(in_dim,128)
        self.lin2 = nn.Linear(128,64)
        self.lin3 = nn.Linear(64,10)
        self.relu = nn.ReLU()
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        x = self.relu(x)
        return x

main.py

py 复制代码
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls


seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
mlp_net = MLP_cls()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(mlp_net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
mlp_net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = mlp_net(data)
        _,pred = torch.max(out,dim=1)
        optimizer.zero_grad()
        loss = criterion(out,target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num  += torch.sum(pred==target)
    print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))



print("****************Begin Testing****************")
mlp_net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = mlp_net(data)
    _,pred = torch.max(out,dim=1)
    test_loss += criterion(out,target)
    test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

bash 复制代码
'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10

optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

初始化权重

这里使用这种方式

py 复制代码
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

如果发现loss和acc不变

检查一下是不是忘记写optimizer.step()了

关于数据下载

数据在download=True时,会下载在./data文件夹下

关于输出格式

这里用'xxx {:.2f}'.format(xxx),保留两位小数。注意中间的空格,区分:.2f和%2f

运行图

相关推荐
liwulin05067 分钟前
【PYTHON-YOLOV8N】yoloface+pytorch+cnn进行面部表情识别
python·yolo·cnn
(●—●)橘子……23 分钟前
记力扣1471.数组中的k个最强值 练习理解
数据结构·python·学习·算法·leetcode
会挠头但不秃24 分钟前
深度学习(5)循环神经网络
人工智能·rnn·深度学习
_OP_CHEN26 分钟前
用极狐 CodeRider-Kilo 开发俄罗斯方块:AI 辅助编程的沉浸式体验
人工智能·vscode·python·ai编程·ai编程插件·coderider-kilo
这张生成的图像能检测吗26 分钟前
(论文速读)LCT:用于RGB-D突出物体检测的轻型跨模态变压器
图像处理·目标检测·计算机视觉·深度估计·轻量化模型·跨模态融合·rgb-d
Wpa.wk29 分钟前
自动化测试 - 文件上传 和 弹窗处理
开发语言·javascript·自动化测试·经验分享·爬虫·python·selenium
_OP_CHEN30 分钟前
【Python基础】(二)从 0 到 1 入门 Python 语法基础:从表达式到运算符的全面指南
开发语言·python
brave and determined35 分钟前
CANN训练营 学习(day7)昇腾AI训练全流程实战:从模型迁移到性能优化的深度指南
pytorch·ai·ai训练·昇腾ai·msprobe·模型性能调优·训练配置
我命由我1234539 分钟前
Python Flask 开发:在 Flask 中返回字符串时,浏览器将其作为 HTML 解析
服务器·开发语言·后端·python·flask·html·学习方法
拾忆,想起42 分钟前
设计模式:软件开发的可复用武功秘籍
开发语言·python·算法·微服务·设计模式·性能优化·服务发现