Pytorch-MLP-CIFAR10

文章目录

model.py

py 复制代码
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

class MLP_cls(nn.Module):
    def __init__(self,in_dim=3*32*32):
        super(MLP_cls,self).__init__()
        self.lin1 = nn.Linear(in_dim,128)
        self.lin2 = nn.Linear(128,64)
        self.lin3 = nn.Linear(64,10)
        self.relu = nn.ReLU()
        init.xavier_uniform_(self.lin1.weight)
        init.xavier_uniform_(self.lin2.weight)
        init.xavier_uniform_(self.lin3.weight)

    def forward(self,x):
        x = x.view(-1,3*32*32)
        x = self.lin1(x)
        x = self.relu(x)
        x = self.lin2(x)
        x = self.relu(x)
        x = self.lin3(x)
        x = self.relu(x)
        return x

main.py

py 复制代码
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls,CNN_cls


seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test  = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
net = MLP_cls()

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data/', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10('./data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.5,), (0.5,))
                               ])),
    batch_size=batch_size_test, shuffle=True)

optimizer = optim.SGD(net.parameters(), lr=learning_rate,momentum=momentum)
criterion = nn.CrossEntropyLoss()

print("****************Begin Training****************")
net.train()
for epoch in range(epochs):
    run_loss = 0
    correct_num = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        out = net(data)
        _,pred = torch.max(out,dim=1)
        optimizer.zero_grad()
        loss = criterion(out,target)
        loss.backward()
        run_loss += loss
        optimizer.step()
        correct_num  += torch.sum(pred==target)
    print('epoch',epoch,'loss {:.2f}'.format(run_loss.item()/len(train_loader)),'accuracy {:.2f}'.format(correct_num.item()/(len(train_loader)*batch_size_train)))



print("****************Begin Testing****************")
net.eval()
test_loss = 0
test_correct_num = 0
for batch_idx, (data, target) in enumerate(test_loader):
    out = net(data)
    _,pred = torch.max(out,dim=1)
    test_loss += criterion(out,target)
    test_correct_num  += torch.sum(pred==target)
print('loss {:.2f}'.format(test_loss.item()/len(test_loader)),'accuracy {:.2f}'.format(test_correct_num.item()/(len(test_loader)*batch_size_test)))

参数设置

bash 复制代码
'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test  = 64
epochs = 10

optim --> SGD
learning_rate = 0.01
momentum = 0.5

注意事项

CIFAR10是彩色图像,单个大小为3*32*32。所以view的时候后面展平。

运行图

相关推荐
段一凡-华北理工大学几秒前
工业领域的Hadoop架构学习~系列文章03:MapReduce编程模型深度解读
大数据·人工智能·hadoop·学习·架构·高炉炼铁·高炉智能化
GitCode官方1 分钟前
开源鸿蒙跨平台直播|15场·10大框架|首期:跨平台不是“权衡之选“,而是基础设施
人工智能·华为·开源·harmonyos·atomgit
蓝速科技2 分钟前
3D 数字人全息舱算力部署方案对比:本地 X86 独显架构与云端 RK 架构怎么选才好
数据结构·人工智能·算法·架构·排序算法
没完没了没日没夜783 分钟前
告别Excel表格!全星研发项目管理APQP软件系统:高端制造研发合规与效率的“破局者”
人工智能
狒狒热知识4 分钟前
软文营销媒体发稿行业规范化发展与企业品牌传播安全保障
大数据·人工智能
小程故事多_804 分钟前
从想法到落地零返工,AI Agent六阶段自动化开发全流水线实践
运维·人工智能·自动化
2601_957888565 分钟前
短视频矩阵获客系统的设计与实践:提升企业数字营销效率的路径
大数据·人工智能·矩阵·企业增长
嵌入式-老费5 分钟前
esp开发与应用(按键和状态机)
人工智能
JustNow_Man5 分钟前
“失败后自动拉起修复 Agent”的闭环流水线
前端·人工智能·chrome·python
2601_957879337 分钟前
企业矩阵系统建设实践:从账号管理到AI内容协同
大数据·人工智能·矩阵系统·数字化运营