CIFAR-10 卷积神经网络
下载并加载数据集
python
# These names were used below without ever being imported (the later
# `import torch` cell does not provide them), which raised NameError.
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets

batchsz = 32  # mini-batch size shared by the train and test loaders

# One preprocessing pipeline for both splits (it was duplicated before).
# Resize to 32x32 so the network's hard-coded 16*5*5 flatten size matches;
# ToTensor turns each image into a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])

# Training split: downloaded into ./data if missing, shuffled every epoch.
cifar_train = datasets.CIFAR10('data', train=True, transform=transform,
                               download=True)
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

# Test split: only used for evaluation, so shuffling is unnecessary and
# keeping a fixed order makes runs reproducible.
cifar_test = datasets.CIFAR10('data', train=False, transform=transform,
                              download=True)
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)
构建网络
新建一个 LeNet-5 网络
python
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
    """LeNet-5 style convolutional network with 10 output classes.

    Input:  float tensor of shape [b, 3, 32, 32].
    Output: unnormalised logits of shape [b, 10], suitable for
            ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor, shapes for a [b, 3, 32, 32] input:
        #   conv 5x5  -> [b, 6, 28, 28] -> avg-pool 2x2 -> [b, 6, 14, 14]
        #   conv 5x5  -> [b, 16, 10, 10] -> avg-pool 2x2 -> [b, 16, 5, 5]
        # A ReLU follows each convolution. Without it, conv -> avg-pool ->
        # conv -> avg-pool is a chain of affine maps and collapses to a single
        # affine layer. (Layer indices therefore differ from the old
        # activation-free version, so its state_dicts will not load here.)
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier on the flattened features: 16*5*5 -> 120 -> 84 -> 10.
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Map a batch of images [b, 3, 32, 32] to class logits [b, 10]."""
        batch_size = x.size(0)
        x = self.conv_unit(x)                  # [b, 16, 5, 5]
        x = x.view(batch_size, 16 * 5 * 5)     # flatten per sample
        logits = self.fc_unit(x)               # [b, 10]
        return logits
训练与测试
python
from torch import optim  # optimizer module was used but never imported

# Run on the GPU when one is available, otherwise fall back to the CPU
# (hard-coding 'cuda' crashed on CPU-only machines).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Lenet5().to(device)                          # instantiate the network
criten = nn.CrossEntropyLoss().to(device)            # cross-entropy on logits
optimizer = optim.Adam(model.parameters(), lr=1e-3)  # Adam updates the weights

for epoch in range(1000):
    # Re-enter training mode each epoch: the evaluation pass below calls
    # model.eval(), which would otherwise persist into later epochs.
    model.train()
    for x, label in cifar_train:
        x, label = x.to(device), label.to(device)
        logits = model(x)              # forward pass
        loss = criten(logits, label)   # batch loss
        optimizer.zero_grad()          # clear gradients from the last step
        loss.backward()                # backpropagation
        optimizer.step()               # parameter update
    # Loss of the last mini-batch in this epoch.
    print(epoch, loss.item())

    # Accuracy on the test split.
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():              # no gradients needed for evaluation
        for x, label in cifar_test:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            pred = logits.argmax(dim=1)  # predicted class index per sample
            total_correct += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)
    acc = total_correct / total_num
    print(epoch, acc)
网络图如下: