[Deep Learning Networks: From Getting Started to the Grave] Residual Network ResNet


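Personal links

Zhihu: https://www.zhihu.com/people/byzh_rc

CSDN: https://blog.csdn.net/qq_54636039

Note: this article only provides a framework-level walkthrough; for the finer details, consult other materials or the source code.

Reference articles: various sources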

Table of contents

  • [Deep Learning Networks: From Getting Started to the Grave] Residual Network ResNet
  • Personal links
  • References
  • Background
  • Architecture (formulas)
        • 1. BasicBlock (ResNet18/34)
        • 2. Bottleneck (ResNet50/101/152)
        • 3. Shortcut types
  • Key innovations
        • 1. Residual connection (skip connection)
        • 2. Trainable very deep networks
        • 3. Simple but very strong structure
  • Why ResNet can train 152 layers
  • Code implementation
  • Project example

References

Deep Residual Learning for Image Recognition.

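Background

Around 2014-2015, deep CNNs entered a "deeper is better" phase:

  • AlexNet: 8 layers
  • VGG: 16-19 layers
  • GoogLeNet: 22 layers

Then a problem appeared: once a plain network grows beyond roughly 20 layers, stacking more layers makes the training error go up instead of down.

This is not overfitting (the error rises on the training set itself); it is an optimization difficulty, known as the degradation problem.

ResNet's answer: let the layers learn a "residual" instead of learning the target mapping directly.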

Plain network:
$$H(x) = F(x)$$

ResNet:
$$H(x) = F(x) + x$$
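
A minimal pure-Python sketch of the two formulations above, using plain lists instead of tensors. The helper names `plain_block`, `residual_block` and `zero_branch` are illustrative and do not come from the article. It shows why the residual form makes an identity mapping easy to represent: if the learned branch F outputs zeros, the residual block returns x unchanged, while the plain block loses the input.

```python
from typing import Callable, List

Vector = List[float]


def plain_block(x: Vector, f: Callable[[Vector], Vector]) -> Vector:
    """Plain layer: the output is whatever the learned branch F produces."""
    return f(x)


def residual_block(x: Vector, f: Callable[[Vector], Vector]) -> Vector:
    """Residual layer: F only has to model the difference H(x) - x."""
    return [fi + xi for fi, xi in zip(f(x), x)]


def zero_branch(x: Vector) -> Vector:
    """A branch whose weights are all zero, so F(x) = 0 for every input."""
    return [0.0 for _ in x]


x = [1.0, -2.0, 3.0]
plain_out = plain_block(x, zero_branch)        # the input signal is lost
residual_out = residual_block(x, zero_branch)  # the input passes through unchanged
print(plain_out, residual_out)
```

In other words, a deeper residual network can always fall back to behaving like a shallower one by driving some branches toward zero, which is much harder to arrange with stacked nonlinear plain layers.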

Architecture (formulas)

1. BasicBlock (ResNet18/34)

Conv → BN → ReLU
Conv → BN
+
Shortcut
→ ReLU

$$
y = \text{ReLU}(F(x) + x) \\
F(x) = W_2 \sigma(W_1 x)
$$
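
To make the BasicBlock formula concrete, here is a small worked example in pure Python. It is only a sketch under simplifying assumptions: the weights W1 and W2 are tiny dense matrices rather than 3×3 convolutions, σ is taken to be ReLU, and batch normalization is left out. The numbers are made up for illustration.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Matrix-vector product, standing in for a convolution."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def relu(v: Vector) -> Vector:
    return [max(0.0, vi) for vi in v]


def basic_block(x: Vector, w1: Matrix, w2: Matrix) -> Vector:
    f = matvec(w2, relu(matvec(w1, x)))              # F(x) = W2 σ(W1 x)
    return relu([fi + xi for fi, xi in zip(f, x)])   # y = ReLU(F(x) + x)


W1 = [[1.0, -1.0], [0.5, 0.5]]
W2 = [[1.0, 0.0], [0.0, -2.0]]
x = [1.0, 2.0]

# W1 x = [-1.0, 1.5] -> ReLU -> [0.0, 1.5] -> W2 -> F(x) = [0.0, -3.0]
# F(x) + x = [1.0, -1.0] -> ReLU -> [1.0, 0.0]
y = basic_block(x, W1, W2)
print(y)
```

Note that the last ReLU is applied after the addition, which matches the order Conv → BN → + Shortcut → ReLU in the block diagram.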

2. Bottleneck (ResNet50/101/152)

When the network becomes very deep, a bottleneck structure is used:

1×1 (reduce channels) → 3×3 (extract features) → 1×1 (restore channels)

$$F(x) = W_3 \sigma(W_2 \sigma(W_1 x))$$
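
A rough way to see why the bottleneck helps is to count convolution weights. The sketch below compares a residual branch of two 3×3 convolutions at 256 channels with a 1×1 → 3×3 → 1×1 bottleneck that squeezes to 64 channels in the middle (the 4× ratio matches `expansion = 4` in the code later in this article). Biases and BatchNorm parameters are ignored, and the function names are illustrative.

```python
def conv_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Number of weights in a bias-free kernel x kernel convolution."""
    return in_ch * out_ch * kernel * kernel


def basic_branch_params(ch: int) -> int:
    """Two 3x3 convolutions that keep the channel count at `ch`."""
    return 2 * conv_params(ch, ch, 3)


def bottleneck_branch_params(ch: int, reduction: int = 4) -> int:
    """1x1 reduce -> 3x3 at the reduced width -> 1x1 expand back to `ch`."""
    mid = ch // reduction
    return (
        conv_params(ch, mid, 1)
        + conv_params(mid, mid, 3)
        + conv_params(mid, ch, 1)
    )


basic_256 = basic_branch_params(256)
bottleneck_256 = bottleneck_branch_params(256)
print(basic_256, bottleneck_256, round(basic_256 / bottleneck_256, 1))
```

At 256 channels the bottleneck branch needs roughly 17 times fewer weights than two full-width 3×3 convolutions, which is what makes 50+ layer networks affordable.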

3. Shortcut types

Case 1: same dimensions
$$y = F(x) + x$$

Case 2: different dimensions (downsampling or a change in channel count)
$$
y = F(x) + W_s x \\
\color{purple}{W_s = 1\times1 \text{ Conv}}
$$
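
The two shortcut cases come down to whether F(x) and x still have the same shape. The sketch below uses the standard convolution output-size formula to check that a stride-2 residual branch and a stride-2 1×1 projection produce matching spatial sizes. The 56×56, 64 → 128 channel numbers correspond to the first block of the second stage for a 224×224 input; the helper names are illustrative.

```python
def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a convolution: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def needs_projection(in_ch: int, out_ch: int, stride: int) -> bool:
    """A 1x1 conv shortcut is needed when the shape of x and F(x) would differ."""
    return stride != 1 or in_ch != out_ch


# Residual branch: 3x3 conv (stride 2, padding 1) followed by 3x3 conv (stride 1, padding 1)
main_size = conv_out_size(conv_out_size(56, 3, 2, 1), 3, 1, 1)

# Projection shortcut: 1x1 conv with stride 2 and no padding
shortcut_size = conv_out_size(56, 1, 2, 0)

print(main_size, shortcut_size)
print(needs_projection(64, 128, 2), needs_projection(64, 64, 1))
```

Both paths end at 28×28, so the element-wise addition is well defined; when neither the stride nor the channel count changes, the plain identity shortcut is enough.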

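Key innovations

1. Residual connection (skip connection)

Gives gradients a direct path back to earlier layers

2. Trainable very deep networks

The paper trains networks up to 152 layers deep on ImageNet

3. Simple but very strong structure

Residual connections became a standard building block in many later vision networks, such as DenseNet. (U-Net also uses skip connections, but it was published a few months before ResNet.)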

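Why ResNet can train 152 layers

The intuition behind residual networks: for y = F(x) + x (ignoring the final ReLU),
$$\frac{\partial y}{\partial x} = \frac{\partial F(x)}{\partial x} + 1$$

That is, the gradient always contains a "+1" term:

  • Gradients are much less prone to vanishing, because the identity path passes them through unchanged
  • A block can easily represent the identity mapping (F(x) = 0)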
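
The effect of the "+1" term can be illustrated with a scalar toy model. This is a sketch, not a proof: each layer is reduced to a single local derivative F'(x), whereas real layers have Jacobian matrices and their derivatives can be negative. The value 0.01 and the depth of 20 are arbitrary choices for illustration.

```python
import math
from typing import List


def plain_chain_grad(local_grads: List[float]) -> float:
    """Gradient through stacked plain layers: the product of each F'(x)."""
    g = 1.0
    for a in local_grads:
        g *= a
    return g


def residual_chain_grad(local_grads: List[float]) -> float:
    """Gradient through stacked residual layers: the product of each (F'(x) + 1)."""
    g = 1.0
    for a in local_grads:
        g *= a + 1.0
    return g


local = [0.01] * 20                      # 20 layers, each with a small local derivative
plain = plain_chain_grad(local)          # shrinks to about 1e-40
residual = residual_chain_grad(local)    # stays close to 1 (about 1.22)
print(plain, residual)


# Finite-difference check of dy/dx = F'(x) + 1 for y = F(x) + x with F(x) = tanh(0.5 x)
def y(x: float) -> float:
    return math.tanh(0.5 * x) + x


x0, h = 0.3, 1e-6
numeric = (y(x0 + h) - y(x0 - h)) / (2 * h)
analytic = 0.5 * (1.0 - math.tanh(0.5 * x0) ** 2) + 1.0
print(numeric, analytic)
```

With plain layers the signal reaching the first layer is practically zero, while the residual version keeps it at order one because every factor contains the identity contribution.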

Code implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

from byzh.ai.Butils import b_get_params

class BasicBlock(nn.Module):
    """
    Used by ResNet18/34
    """
    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.shortcut = nn.Sequential()

        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class Bottleneck(nn.Module):
    """
    Used by ResNet50/101/152
    """
    expansion = 4  # output channels = out_ch * 4

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        # 1x1 conv: reduce the channel count
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)

        # 3x3 conv: feature extraction (the stride-based downsampling happens here)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # 1x1 conv: expand the channel count back up
        self.conv3 = nn.Conv2d(out_ch, out_ch * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch * self.expansion)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch * self.expansion)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += self.shortcut(x)
        out = torch.relu(out)
        return out


class ResNet(nn.Module):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_ch = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_ch, blocks, stride=1):
        layers = []
        layers.append(block(self.in_ch, out_ch, stride))
        self.in_ch = out_ch * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.in_ch, out_ch))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class B_ResNet18_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = BasicBlock
        layers = [2, 2, 2, 2]
        super().__init__(block=block, layers=layers, num_classes=num_classes)


class B_ResNet34_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = BasicBlock
        layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet50_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet101_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 4, 23, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet152_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 8, 36, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

if __name__ == '__main__':
    # ResNet18
    net = B_ResNet18_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"Parameter count: {b_get_params(net)}")  # 11_689_512

    # ResNet34
    net = B_ResNet34_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"Parameter count: {b_get_params(net)}")  # 21_797_672

    # ResNet50
    net = B_ResNet50_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"Parameter count: {b_get_params(net)}")  # 25_557_032

    # ResNet101
    net = B_ResNet101_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"Parameter count: {b_get_params(net)}")  # 44_549_160

    # ResNet152
    net = B_ResNet152_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"Parameter count: {b_get_params(net)}")  # 60_192_808
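
The `layers` lists passed to each variant explain the numbers in the model names. A small sketch, assuming the usual counting convention (the stem convolution, every convolution inside the residual branches, and the final fully connected layer; projection shortcuts are not counted):

```python
from typing import Dict, List, Tuple


def resnet_depth(layers: List[int], convs_per_block: int) -> int:
    """Stem conv + convs inside all residual blocks + final fully connected layer."""
    return 1 + convs_per_block * sum(layers) + 1


# (blocks per stage, convolutions per block): BasicBlock has 2, Bottleneck has 3
configs: Dict[str, Tuple[List[int], int]] = {
    "ResNet18": ([2, 2, 2, 2], 2),
    "ResNet34": ([3, 4, 6, 3], 2),
    "ResNet50": ([3, 4, 6, 3], 3),
    "ResNet101": ([3, 4, 23, 3], 3),
    "ResNet152": ([3, 8, 36, 3], 3),
}

depths = {name: resnet_depth(layers, n) for name, (layers, n) in configs.items()}
for name, depth in depths.items():
    print(name, depth)
```

ResNet34 and ResNet50 share the same stage layout `[3, 4, 6, 3]`; the extra depth of ResNet50 comes entirely from swapping BasicBlock for Bottleneck.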

Project example

Library environment:

numpy==1.26.4
torch==2.2.2+cu121
byzh-core==0.0.9.21
byzh-ai==0.0.9.53
byzh-extra==0.0.9.12
...

Training ResNet18 on the MNIST dataset:

# copy all of the code from here to run

import torch
import torch.nn.functional as F
# assumes the installed byzh-ai package exports b_stratified_indices from byzh.ai.Bdata
from byzh.ai.Bdata import b_stratified_indices

from byzh.ai.Btrainer import B_Classification_Trainer
from byzh.ai.Bdata import B_Download_MNIST, b_get_dataloader_from_tensor
# from uploadToPypi_ai.byzh.ai.Bmodel.study_cnn import B_ResNet18_Paper
from byzh.ai.Bmodel.study_cnn import B_ResNet18_Paper
from byzh.ai.Butils import b_get_device
  
##### hyper params #####
epochs = 10
lr = 1e-3
batch_size = 32
device = b_get_device(use_idle_gpu=True)

##### data #####
downloader = B_Download_MNIST(save_dir='D:/study_cnn/datasets/MNIST')
data_dict = downloader.get_data()
X_train = data_dict['X_train_standard']
y_train = data_dict['y_train']
X_test = data_dict['X_test_standard']
y_test = data_dict['y_test']
num_classes = data_dict['num_classes']
num_samples = data_dict['num_samples']

indices = b_stratified_indices(y_train, num_samples//5)
X_train = X_train[indices]
X_train = F.interpolate(X_train, size=(224, 224), mode='bilinear')
X_train = X_train.repeat(1, 3, 1, 1)
y_train = y_train[indices]

indices = b_stratified_indices(y_test, num_samples//5)
X_test = X_test[indices]
X_test = F.interpolate(X_test, size=(224, 224), mode='bilinear')
X_test = X_test.repeat(1, 3, 1, 1)
y_test = y_test[indices]

train_dataloader, val_dataloader = b_get_dataloader_from_tensor(
    X_train, y_train, X_test, y_test,
    batch_size=batch_size
)

##### model #####
model = B_ResNet18_Paper(num_classes=num_classes)

##### else #####
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

##### trainer #####
trainer = B_Classification_Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    device=device
)
trainer.set_writer1('./runs/resnet18/log.txt')

##### run #####
trainer.train_eval_s(epochs=epochs)

##### calculate #####
trainer.draw_loss_acc('./runs/resnet18/loss_acc.png', y_lim=False)
trainer.save_best_checkpoint('./runs/resnet18/best_checkpoint.pth')
trainer.calculate_model()
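
One practical point about the script above: it upsamples every selected MNIST image to 3×224×224 before building the dataloader, so the whole resized set sits in memory at once. A rough back-of-the-envelope estimate, assuming float32 tensors and assuming `num_samples` is 60000 (the size of the MNIST training split; the article does not state the value returned by the downloader):

```python
def tensor_bytes(n: int, c: int, h: int, w: int, bytes_per_elem: int = 4) -> int:
    """Memory taken by a dense (n, c, h, w) tensor; float32 uses 4 bytes per element."""
    return n * c * h * w * bytes_per_elem


num_samples = 60_000          # assumption: size of the MNIST training split
subset = num_samples // 5     # the script keeps one fifth of the samples

raw_bytes = tensor_bytes(subset, 1, 28, 28)        # original 1x28x28 images
resized_bytes = tensor_bytes(subset, 3, 224, 224)  # after interpolate + repeat

print(f"raw:     {raw_bytes / 2**20:.1f} MiB")
print(f"resized: {resized_bytes / 2**30:.2f} GiB")
print(f"growth:  {resized_bytes // raw_bytes}x")
```

Under these assumptions the resized training subset alone is several gigabytes. If that is too much for the machine at hand, one option is to keep the 28×28 tensors and resize each batch just before the forward pass.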