[深度学习网络从入门到入土] 残差网络ResNet

[深度学习网络从入门到入土] 残差网络ResNet

个人导航

知乎:https://www.zhihu.com/people/byzh_rc

CSDN:https://blog.csdn.net/qq_54636039

注:本文仅对所述内容做了框架性引导,具体细节可查询其余相关资料or源码

参考文章:各方资料

文章目录

  • [[深度学习网络从入门到入土] 残差网络ResNet](#[深度学习网络从入门到入土] 残差网络ResNet)
  • 个人导航
  • 参考资料
  • 背景
  • 架构(公式)
        • [1. ==BasicBlock(ResNet18/34)==](#1. ==BasicBlock(ResNet18/34)==)
        • [2. ==Bottleneck(ResNet50/101/152)==](#2. ==Bottleneck(ResNet50/101/152)==)
        • [3. ==Shortcut 类型==](#3. ==Shortcut 类型==)
  • 创新点
        • [1. ==残差连接==(Skip Connection)](#1. ==残差连接==(Skip Connection))
        • [2. 可训练超深网络](#2. 可训练超深网络)
        • [3. 结构简洁但极强](#3. 结构简洁但极强)
  • [为什么 ResNet 能训练 152 层](#为什么 ResNet 能训练 152 层)
  • 代码实现
  • 项目实例

参考资料

Deep Residual Learning for Image Recognition.

背景

在 2014--2015 年,深度 CNN 进入"越深越好"的阶段:

  • AlexNet:8 层
  • VGG:16--19 层
  • GoogLeNet:22 层

问题来了:当网络超过 20 层后,训练误差反而上升

这不是过拟合,而是优化困难(degradation problem)

resnet横空出世: 让网络学习"残差",而不是直接学习映射

传统网络:
H ( x ) = F ( x ) H(x)=F(x) H(x)=F(x)

ResNet:
H ( x ) = F ( x ) + x H(x) = F(x) + x H(x)=F(x)+x

架构(公式)

1. BasicBlock(ResNet18/34)
复制代码
Conv → BN → ReLU
Conv → BN
+
Shortcut
→ ReLU

y = ReLU ( F ( x ) + x ) F ( x ) = W 2 σ ( W 1 x ) y = \text{ReLU}(F(x) + x) \\ F(x) = W_2 \sigma(W_1 x) y=ReLU(F(x)+x)F(x)=W2σ(W1x)

2. Bottleneck(ResNet50/101/152)

当网络变得非常深时,使用瓶颈结构:

复制代码
1×1(降维) → 3×3(提取特征) → 1×1(升维)

F ( x ) = W 3 σ ( W 2 σ ( W 1 x ) ) F(x) = W_3 \sigma(W_2 \sigma(W_1 x)) F(x)=W3σ(W2σ(W1x))

3. Shortcut 类型

情况1:尺寸相同
y = F ( x ) + x y = F(x) + x y=F(x)+x

情况2:尺寸不同(下采样)
y = F ( x ) + W s x W s = 1 × 1 Conv y = F(x) + W_s x \\ \color{purple}{W_s = 1\times1 \text{ Conv}} y=F(x)+WsxWs=1×1 Conv

创新点

1. 残差连接(Skip Connection)

允许梯度直接传播

2. 可训练超深网络

152 层首次成功训练

3. 结构简洁但极强

成为后续几乎所有视觉网络的基础(DenseNet, U-Net)

为什么 ResNet 能训练 152 层

残差网络的理论基础:
∂ y ∂ x = ∂ F ( x ) ∂ x + 1 \frac{\partial y}{\partial x} = \frac{\partial F(x)}{\partial x} + 1 ∂x∂y=∂x∂F(x)+1

即梯度中始终存在 "+1" 项:

  • 梯度不会消失
  • 网络可以直接传递恒等映射

代码实现

py 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

from byzh.ai.Butils import b_get_params

class BasicBlock(nn.Module):
    """
    给 ResNet18/34 用
    """
    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.shortcut = nn.Sequential()

        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class Bottleneck(nn.Module):
    """
    给 ResNet50/101/152 用
    """
    expansion = 4  # 输出通道 = out_ch * 4

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        # 1x1 降维
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)

        # 3x3 特征提取(这里做 stride 下采样)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # 1x1 升维
        self.conv3 = nn.Conv2d(out_ch, out_ch * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch * self.expansion)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch * self.expansion)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += self.shortcut(x)
        out = torch.relu(out)
        return out


class ResNet(nn.Module):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_ch = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_ch, blocks, stride=1):
        layers = []
        layers.append(block(self.in_ch, out_ch, stride))
        self.in_ch = out_ch * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.in_ch, out_ch))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class B_ResNet18_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = BasicBlock
        layers = [2, 2, 2, 2]
        super().__init__(block=block, layers=layers, num_classes=num_classes)


class B_ResNet34_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = BasicBlock
        layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet50_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet101_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 4, 23, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet152_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 8, 36, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

if __name__ == '__main__':
    # ResNet18
    net = B_ResNet18_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 11_689_512

    # ResNet34
    net = B_ResNet34_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 21_797_672

    # ResNet50
    net = B_ResNet50_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 25_557_032

    # ResNet101
    net = B_ResNet101_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 44_549_160

    # ResNet152
    net = B_ResNet152_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 60_192_808

项目实例

库环境:

复制代码
numpy==1.26.4
torch==2.2.2cu121
byzh-core==0.0.9.21
byzh-ai==0.0.9.53
byzh-extra==0.0.9.12
...

ResNet18训练MNIST数据集:

py 复制代码
# copy all the codes from here to run

import torch
import torch.nn.functional as F
from uploadToPypi_ai.byzh.ai.Bdata import b_stratified_indices

from byzh.ai.Btrainer import B_Classification_Trainer
from byzh.ai.Bdata import B_Download_MNIST, b_get_dataloader_from_tensor
# from uploadToPypi_ai.byzh.ai.Bmodel.study_cnn import B_ResNet18_Paper
from byzh.ai.Bmodel.study_cnn import B_ResNet18_Paper
from byzh.ai.Butils import b_get_device
  
##### hyper params #####
epochs = 10
lr = 1e-3
batch_size = 32
device = b_get_device(use_idle_gpu=True)

##### data #####
downloader = B_Download_MNIST(save_dir='D:/study_cnn/datasets/MNIST')
data_dict = downloader.get_data()
X_train = data_dict['X_train_standard']
y_train = data_dict['y_train']
X_test = data_dict['X_test_standard']
y_test = data_dict['y_test']
num_classes = data_dict['num_classes']
num_samples = data_dict['num_samples']

indices = b_stratified_indices(y_train, num_samples//5)
X_train = X_train[indices]
X_train = F.interpolate(X_train, size=(224, 224), mode='bilinear')
X_train = X_train.repeat(1, 3, 1, 1)
y_train = y_train[indices]

indices = b_stratified_indices(y_test, num_samples//5)
X_test = X_test[indices]
X_test = F.interpolate(X_test, size=(224, 224), mode='bilinear')
X_test = X_test.repeat(1, 3, 1, 1)
y_test = y_test[indices]

train_dataloader, val_dataloader = b_get_dataloader_from_tensor(
    X_train, y_train, X_test, y_test,
    batch_size=batch_size
)

##### model #####
model = B_ResNet18_Paper(num_classes=num_classes)

##### else #####
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

##### trainer #####
trainer = B_Classification_Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    device=device
)
trainer.set_writer1('./runs/resnet18/log.txt')

##### run #####
trainer.train_eval_s(epochs=epochs)

##### calculate #####
trainer.draw_loss_acc('./runs/resnet18/loss_acc.png', y_lim=False)
trainer.save_best_checkpoint('./runs/resnet18/best_checkpoint.pth')
trainer.calculate_model()
相关推荐
Net_Walke几秒前
【网络协议】ECC非对称加密算法介绍
网络·网络协议
测试人社区—6679几秒前
当代码面临道德选择:VR如何为AI伦理决策注入“人性压力”
网络·人工智能·python·microsoft·vr·azure
飞Link9 分钟前
深度解析 TSAD:时序数据异常分类与检测技术的全景指南
大数据·人工智能·机器学习·数据挖掘
独行soc11 分钟前
2026年渗透测试面试题总结-36(题目+回答)
网络·python·安全·web安全·网络安全·渗透测试·安全狮
L***一12 分钟前
网络安全专业入门级认证体系分析与路径规划
网络·安全·web安全
昨夜见军贴061616 分钟前
IACheck:AI报告文档审核助力汽车零部件振动测试报告精准无误
人工智能·汽车
witAI19 分钟前
**Kimi小说灵感2025推荐,从零到一的创意激发指南**
人工智能·python
咚咚王者21 分钟前
人工智能之语言领域 自然语言处理 第五章 文本分类
人工智能·自然语言处理·分类
研究点啥好呢26 分钟前
3月10日GitHub热门项目推荐|自动化的浪潮
运维·人工智能·ai·自动化·github