ResNet Explained, with a PyTorch Implementation
An Overview of ResNet
The Residual Neural Network (ResNet) is one of the most important architectures in deep learning. It was proposed in 2015 by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun at Microsoft Research. ResNet's main contribution is to alleviate problems such as vanishing and exploding gradients that arise when training deep neural networks, making it practical to train very deep networks and thereby greatly improving model performance.
Paper: https://arxiv.org/abs/1512.03385
Paper translation: an annotated translation of "Deep Residual Learning for Image Recognition"
Background
Figure 1: Training error (left) and test error (right) on CIFAR-10 for 20-layer and 56-layer "plain" networks. The deeper network has higher training error, and consequently higher test error as well.
Many studies have shown that network depth is crucial for model accuracy, so it is natural to ask whether accuracy keeps improving as more layers are added. In practice it does not. As depth grows, the first obstacle is vanishing and exploding gradients. This can largely be addressed by batch normalization and properly normalized weight initialization, which work well for networks of up to a few dozen layers.
However, as the network gets even deeper, a degradation problem is exposed: accuracy saturates and then degrades rapidly as depth increases. This degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error, as shown in Figure 1.
Residual Learning
The drop in training accuracy in Figure 1 indicates that not all systems are equally easy to optimize. Regarding the degradation problem, the authors reason that if identity-mapping layers are added on top of a shallow network, the resulting deeper model should produce a training error no higher than that of its shallower counterpart. To realize this, they construct the residual learning block shown below.
Figure 2: Residual learning: a building block. The paper introduces a deep residual learning framework to address the degradation problem.
Formally, denote the desired underlying mapping as H(x) and let the stacked nonlinear layers fit another mapping F(x) = H(x) - x. The original mapping is then recast as F(x) + x. The hypothesis is that the residual mapping is easier to optimize than the original mapping. In the extreme case, if an identity mapping were optimal, it would be easier to push the residual toward zero than to fit an identity mapping with a stack of nonlinear layers. Moreover, this identity shortcut introduces neither extra parameters nor extra computational complexity.
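The building block in Figure 2 maps directly onto code: a few weight layers compute F(x), and the shortcut adds the input back before the final activation. Below is a minimal sketch of such a block in PyTorch; the class name SimpleResidualBlock is illustrative only, and torchvision's own BasicBlock (visible in the printout later in this post) follows the same pattern.
python
import torch
import torch.nn as nn

class SimpleResidualBlock(nn.Module):
    """Minimal residual block: out = ReLU(F(x) + x), where F is two 3x3 convolutions."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))  # F(x)
        return self.relu(out + x)  # F(x) + x: the identity shortcut adds no parameters

x = torch.randn(1, 64, 56, 56)
print(SimpleResidualBlock(64)(x).shape)  # torch.Size([1, 64, 56, 56])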
The experiments ultimately show that:
- Extremely deep residual networks are easy to optimize, whereas the corresponding "plain" networks (which simply stack layers) exhibit higher training error as depth increases
- Deep residual networks readily gain accuracy from greatly increased depth, producing better results than previous networks
ResNet Architecture
Figure 3: ResNet architectures
Figure 4: Left: the VGG-19 model (19.6 billion FLOPs) as a reference. Middle: a plain network with 34 parameter layers (3.6 billion FLOPs). Right: a residual network with 34 parameter layers (3.6 billion FLOPs)
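The variants in Figure 3 share the same five-stage layout and differ only in the block type and in how many blocks each stage contains. As a reference, the standard configurations from the paper are summarized below; the dictionary is just an illustrative summary, not an API.
python
# Blocks per stage (conv2_x .. conv5_x) for the standard ResNet variants.
# ResNet-18/34 use the two-conv BasicBlock; ResNet-50/101/152 use the 1x1-3x3-1x1 Bottleneck block.
resnet_configs = {
    "resnet18":  ("BasicBlock", [2, 2, 2, 2]),
    "resnet34":  ("BasicBlock", [3, 4, 6, 3]),
    "resnet50":  ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
    "resnet152": ("Bottleneck", [3, 8, 36, 3]),
}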
ResNet in PyTorch
python
import torchvision
# Instantiate a (randomly initialized) ResNet-18 from torchvision and inspect its layer structure
net = torchvision.models.resnet18()
print(net)
The printed structure of ResNet-18:
python
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
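Two details are worth noting in this printout. First, layer2 through layer4 each start with a stride-2 BasicBlock whose downsample branch (a 1x1 convolution followed by batch normalization) projects the shortcut so that it matches the doubled channel count and halved spatial resolution of F(x); only then can the two be added. Second, the overall shape flow can be verified with a dummy forward pass. The following is a small sketch that calls the model's submodules stage by stage:
python
import torch
import torchvision

net = torchvision.models.resnet18()
x = torch.randn(1, 3, 224, 224)                        # dummy ImageNet-sized input
with torch.no_grad():
    x = net.maxpool(net.relu(net.bn1(net.conv1(x))))   # 224 -> 112 (conv1) -> 56 (maxpool)
    print(x.shape)                                     # torch.Size([1, 64, 56, 56])
    x = net.layer1(x); print(x.shape)                  # torch.Size([1, 64, 56, 56])
    x = net.layer2(x); print(x.shape)                  # torch.Size([1, 128, 28, 28])
    x = net.layer3(x); print(x.shape)                  # torch.Size([1, 256, 14, 14])
    x = net.layer4(x); print(x.shape)                  # torch.Size([1, 512, 7, 7])
    x = net.avgpool(x); print(x.shape)                 # torch.Size([1, 512, 1, 1])
    print(net.fc(torch.flatten(x, 1)).shape)           # torch.Size([1, 1000])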
CIFAR-10 Classification with ResNet-50
python
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
myWriter = SummaryWriter('./tensorboard/log/')
# Preprocessing: resize to the 224x224 input size used during ImageNet pretraining,
# apply random horizontal flips for augmentation, and normalize with the ImageNet mean/std
myTransforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# Load the CIFAR-10 training and test sets
train_dataset = torchvision.datasets.CIFAR10(root='./data/', train=True, download=True,
                                              transform=myTransforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=0)
test_dataset = torchvision.datasets.CIFAR10(root='./data/', train=False, download=True,
                                             transform=myTransforms)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=0)
# Define the model: ResNet-50 pretrained on ImageNet
myModel = torchvision.models.resnet50(pretrained=True)
# Replace ResNet-50's final fully connected layer with one that outputs 10 units, one per CIFAR-10 class
inchannel = myModel.fc.in_features
myModel.fc = nn.Linear(inchannel, 10)
# Move the model to the GPU if one is available
myDevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
myModel = myModel.to(myDevice)
# Loss function and optimizer
learning_rate = 0.001
myOptimzier = optim.SGD(myModel.parameters(), lr=learning_rate, momentum=0.9)
myLoss = torch.nn.CrossEntropyLoss()
for _epoch in range(10):
    myModel.train()  # switch back to training mode (the evaluation below puts BatchNorm into eval mode)
    training_loss = 0.0
    for _step, input_data in enumerate(train_loader):
        image, label = input_data[0].to(myDevice), input_data[1].to(myDevice)  # move the batch to the GPU
        predict_label = myModel(image)
        loss = myLoss(predict_label, label)
        myWriter.add_scalar('training loss', loss.item(), global_step=_epoch * len(train_loader) + _step)
        myOptimzier.zero_grad()
        loss.backward()
        myOptimzier.step()
        training_loss = training_loss + loss.item()
        if _step % 10 == 0:
            print('[iteration - %3d] training loss: %.3f' % (_epoch * len(train_loader) + _step, training_loss / 10))
            training_loss = 0.0
            print()
    correct = 0
    total = 0
    # torch.save(myModel, 'Resnet50_Own.pkl')  # saves the whole model; saving myModel.state_dict() is usually preferred
    myModel.eval()  # required at evaluation time: fixes BatchNorm statistics so layer outputs do not drift
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in test_loader:
            images = images.to(myDevice)
            labels = labels.to(myDevice)
            outputs = myModel(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('Testing Accuracy : %.3f %%' % (100 * correct / total))
    myWriter.add_scalar('test_Accuracy', 100 * correct / total, global_step=_epoch)
The test accuracy after each epoch is as follows:
python
Testing Accuracy : 94.780 %
Testing Accuracy : 91.790 %
Testing Accuracy : 93.320 %
Testing Accuracy : 95.000 %
Testing Accuracy : 95.660 %
Testing Accuracy : 96.040 %
Testing Accuracy : 94.860 %
Testing Accuracy : 96.300 %
Testing Accuracy : 96.250 %
Testing Accuracy : 95.800 %
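A note on the model construction above: torchvision 0.13 and later deprecate the pretrained=True flag in favor of an explicit weights argument. If you see a deprecation warning, the equivalent call looks roughly like this (same weights, newer API):
python
import torchvision
from torchvision.models import ResNet50_Weights

# Equivalent to torchvision.models.resnet50(pretrained=True) on torchvision >= 0.13
myModel = torchvision.models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)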