AlexNet 是一个经典的卷积神经网络模型，常用于图像分类任务。下面用 PyTorch 在 CIFAR-10 数据集上实现 AlexNet 的训练与测试。
目录
大纲
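各个文件的作用：
- data：数据集目录（CIFAR-10 会被自动下载到 ./data）
- dataloader.py：数据预处理，以及训练集/测试集和对应 DataLoader 的实例化
- model.py：AlexNet 网络结构的定义
- train.py：模型训练，训练过程中在测试集上验证，最后将权重保存为 alexnet.pth
- test.py：加载 alexnet.pth 中的权重，并在测试集上评估准确率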
dataloader.py
python
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Preprocessing pipeline.
# - Resize(224) scales the shorter image side to 224 pixels.
# - ToTensor produces values in [0, 1].
# - Normalize with mean = std = 0.5 then maps every channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 16

# CIFAR-10 splits; both are downloaded into ./data on first use.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False)

# Human-readable names, indexed by the integer label.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def _show_sample_batch():
    """Print the labels of one shuffled training batch and display its images."""
    batch_images, batch_labels = next(iter(train_loader))
    print(' '.join('%5s' % classes[batch_labels[k]] for k in range(batch_size)))
    grid = torchvision.utils.make_grid(batch_images)
    # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1] so imshow renders correctly.
    grid = grid / 2 + 0.5
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.show()


if __name__ == '__main__':
    _show_sample_batch()
model.py
python
import torch.nn as nn
import torch
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier with BatchNorm after each conv.

    Structure:
    - Five convolutional stages (``conv_1`` .. ``conv_5``).
    - An adaptive average pool that fixes the feature map at 256 x 6 x 6.
    - Three fully connected layers (``fc_1`` .. ``fc_3``); the first two are
      preceded by Dropout(0.5).

    Because of the adaptive pool, ``fc_1`` always receives 9216 features, so
    inputs other than 224 x 224 work too, provided they are large enough to
    survive the three max-pool stages. For a 224 x 224 input the map after
    ``conv_5`` is already 6 x 6, so the pool is an identity there. The pool
    has no parameters, so ``state_dict`` keys (and therefore previously saved
    checkpoints) are unchanged.

    Args:
        num_classes: number of logits produced by the final layer ``fc_3``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.conv_2 = nn.Sequential(
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.conv_3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.conv_4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU())
        self.conv_5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2))
        # Force a 6x6 spatial map so fc_1's in_features (256 * 6 * 6 = 9216)
        # matches regardless of input resolution.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.fc_1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU())
        self.fc_2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU())
        self.fc_3 = nn.Sequential(
            nn.Linear(4096, num_classes))

    def forward(self, x):
        """Map a batch of 3-channel images ``x`` to logits of shape (N, num_classes)."""
        out = self.conv_1(x)
        out = self.conv_2(out)
        out = self.conv_3(out)
        out = self.conv_4(out)
        out = self.conv_5(out)
        out = self.avgpool(out)
        # Keep the batch dimension, flatten (C, H, W) into one feature vector.
        out = torch.flatten(out, 1)
        out = self.fc_1(out)
        out = self.fc_2(out)
        out = self.fc_3(out)
        return out
if __name__ == '__main__':
    # Smoke test: build the network, show its layers, and push one
    # random 224x224 RGB image through it to check the output shape.
    net = AlexNet()
    print(net)
    dummy_batch = torch.randn(1, 3, 224, 224)
    logits = net(dummy_batch)
    print(logits.size())
train.py
python
import torch
import torch.nn as nn
from dataloader import train_loader, test_loader
from model import AlexNet
# Hyperparameters and device selection.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10
num_epochs = 20
learning_rate = 1e-3

# Model, loss and optimizer.
model = AlexNet(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


def _evaluate(net, loader, device):
    """Return ``(correct, total)`` top-1 prediction counts of ``net`` over ``loader``.

    ``net`` is put in eval mode for the pass (Dropout off, BatchNorm uses its
    running statistics) and switched back to train mode afterwards so the
    training loop can continue.
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train()
    return correct, total


# Train the model, validating after every epoch.
total_len = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Move tensors to the configured device.
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch + 1, num_epochs, i + 1, total_len, loss.item()))

    # Report the number of samples actually evaluated instead of a
    # hard-coded 10000, so the message stays correct for any loader.
    correct, total = _evaluate(model, test_loader, device)
    print('Accuracy of the network on the {} validation images: {} %'.format(
        total, 100 * correct / total))

# Save the trained weights; test.py loads this file.
torch.save(model.state_dict(), 'alexnet.pth')
test.py
python
import torch
from dataloader import test_loader, classes
from model import AlexNet
# Load the trained weights saved as 'alexnet.pth' by the training script.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AlexNet().to(device)
model.load_state_dict(torch.load('alexnet.pth', map_location=device))

# Evaluate on the CIFAR-10 test data (test_loader).
# eval() disables Dropout and makes BatchNorm use its running statistics.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        predicted = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of samples actually evaluated rather than a hard-coded 10000.
print('Accuracy of the network on the {} validation images: {} %'.format(
    total, 100 * correct / total))