文章目录
model.py
py
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class MLP_cls(nn.Module):
    """Three-layer fully connected classifier for flattened images.

    Architecture: ``in_dim -> 128 -> 64 -> 10`` with ReLU after the two
    hidden layers. The weights of all three linear layers use Xavier-uniform
    initialisation; biases keep the default ``nn.Linear`` initialisation.

    Args:
        in_dim: Number of input features per sample (default ``28*28``,
            i.e. a flattened 28x28 image).
    """

    def __init__(self, in_dim=28 * 28):
        super().__init__()
        self.in_dim = in_dim
        self.lin1 = nn.Linear(in_dim, 128)
        self.lin2 = nn.Linear(128, 64)
        self.lin3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        for layer in (self.lin1, self.lin2, self.lin3):
            init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Return class logits of shape ``(N, 10)``.

        ``x`` is reshaped to ``(-1, in_dim)``; e.g. an ``(N, 1, 28, 28)``
        image batch becomes ``(N, 784)``. Its element count must be a
        multiple of ``in_dim``.

        The output is left as raw logits with no final activation.
        ``nn.CrossEntropyLoss`` applies log-softmax itself. A ReLU here would
        clamp every negative logit to 0, so the network could never push a
        wrong class's score below zero.
        """
        # Use in_dim rather than a hard-coded 28*28 so a non-default in_dim works;
        # reshape (unlike view) also tolerates non-contiguous inputs.
        x = x.reshape(-1, self.in_dim)
        x = self.relu(self.lin1(x))
        x = self.relu(self.lin2(x))
        return self.lin3(x)
main.py
py
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torch.optim as optim
from model import MLP_cls
# Train MLP_cls on MNIST with SGD, then report test-set loss and accuracy.
seed = 42
torch.manual_seed(seed)
batch_size_train = 64
batch_size_test = 64
epochs = 10
learning_rate = 0.01
momentum = 0.5
mlp_net = MLP_cls()

# Shared preprocessing: ToTensor yields float tensors in [0, 1], and
# Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])
train_loader = DataLoader(
    torchvision.datasets.MNIST('./data/', train=True, download=True,
                               transform=transform),
    batch_size=batch_size_train, shuffle=True)
# Evaluation metrics do not depend on sample order, so the test set is not shuffled.
test_loader = DataLoader(
    torchvision.datasets.MNIST('./data/', train=False, download=True,
                               transform=transform),
    batch_size=batch_size_test, shuffle=False)

optimizer = optim.SGD(mlp_net.parameters(), lr=learning_rate, momentum=momentum)
criterion = nn.CrossEntropyLoss()

# Divide by the true sample count: len(loader) * batch_size overcounts because
# the last batch is usually smaller than batch_size.
num_train = len(train_loader.dataset)
num_test = len(test_loader.dataset)

print("****************Begin Training****************")
for epoch in range(epochs):
    mlp_net.train()
    run_loss = 0.0
    correct_num = 0
    for data, target in train_loader:
        out = mlp_net(data)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() turns the loss into a plain float so the running sum does not
        # keep every batch's autograd graph alive; weight by batch size so the
        # epoch average is per sample.
        run_loss += loss.item() * target.size(0)
        correct_num += (out.argmax(dim=1) == target).sum().item()
    print('epoch', epoch,
          'loss {:.2f}'.format(run_loss / num_train),
          'accuracy {:.2f}'.format(correct_num / num_train))

print("****************Begin Testing****************")
mlp_net.eval()
test_loss = 0.0
test_correct_num = 0
# No gradients are needed for evaluation; no_grad avoids building graphs.
with torch.no_grad():
    for data, target in test_loader:
        out = mlp_net(data)
        test_loss += criterion(out, target).item() * target.size(0)
        test_correct_num += (out.argmax(dim=1) == target).sum().item()
print('loss {:.2f}'.format(test_loss / num_test),
      'accuracy {:.2f}'.format(test_correct_num / num_test))
参数设置
bash
'./data/' #数据保存路径
seed = 42 #随机种子
batch_size_train = 64
batch_size_test = 64
epochs = 10
optimizer = torch.optim.SGD
learning_rate = 0.01
momentum = 0.5
注意事项
初始化权重
这里对三个全连接层的权重使用 Xavier 均匀分布初始化(init.xavier_uniform_),偏置保持 nn.Linear 的默认初始化:
py
init.xavier_uniform_(self.lin1.weight)
init.xavier_uniform_(self.lin2.weight)
init.xavier_uniform_(self.lin3.weight)
如果发现loss和acc不变
先检查训练循环中是否漏写了 optimizer.step()(以及 optimizer.zero_grad() 和 loss.backward())。
关于数据下载
设置 download=True 时,MNIST 数据会下载到 ./data/ 目录下;如果数据已经存在,则不会重复下载。
关于输出格式
这里用 'xxx {:.2f}'.format(xxx) 保留两位小数,注意 'xxx' 与 {:.2f} 之间的空格。{:.2f} 中的 .2 表示精度(两位小数),不要与 %2f 混淆:%2f 中的 2 是最小字段宽度,用 % 格式保留两位小数应写作 %.2f。