[Deep Learning Networks: From Beginner to Buried] Residual Network ResNet

Personal links

Zhihu: https://www.zhihu.com/people/byzh_rc

CSDN: https://blog.csdn.net/qq_54636039

Note: this article only gives a framework-level walkthrough; for details, consult other resources or the source code.

Reference articles: assorted sources

Table of contents

  • [Deep Learning Networks: From Beginner to Buried] Residual Network ResNet
  • Personal links
  • References
  • Background
  • Architecture (formulas)
        • 1. BasicBlock (ResNet18/34)
        • 2. Bottleneck (ResNet50/101/152)
        • 3. Shortcut types
  • Innovations
        • 1. Residual connection (skip connection)
        • 2. Trains very deep networks
        • 3. Simple yet powerful structure
  • Why ResNet can train 152 layers
  • Code implementation
  • Worked example

References

Deep Residual Learning for Image Recognition (He et al., CVPR 2016).

Background

In 2014-2015, deep CNNs entered a "deeper is better" phase:

  • AlexNet: 8 layers
  • VGG: 16-19 layers
  • GoogLeNet: 22 layers

Then a problem appeared: once a network goes past roughly 20 layers, its training error rises instead of falling.

That is not overfitting but an optimization difficulty, known as the degradation problem.

ResNet's answer: have the network learn a residual rather than the mapping itself.

A plain network learns the mapping directly:

H(x) = F(x)

ResNet learns the residual and adds the input back:

H(x) = F(x) + x
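The contrast is easy to see in code. Below is a minimal sketch (the single nn.Linear layer and the sizes are arbitrary stand-ins for F): a plain block must learn H(x) outright, while a residual block only learns the correction F(x) and adds the input back.

```python
import torch
import torch.nn as nn

class PlainBlock(nn.Module):
    """Learns the full mapping H(x) = F(x) directly."""
    def __init__(self, ch):
        super().__init__()
        self.f = nn.Linear(ch, ch)

    def forward(self, x):
        return self.f(x)

class ResidualBlock(nn.Module):
    """Learns only the residual F(x); the input is added back."""
    def __init__(self, ch):
        super().__init__()
        self.f = nn.Linear(ch, ch)

    def forward(self, x):
        return self.f(x) + x  # H(x) = F(x) + x

x = torch.randn(4, 8)
print(ResidualBlock(8)(x).shape)  # torch.Size([4, 8])
```

If F is initialized near zero, the residual block starts out close to the identity, which is why deep stacks of such blocks remain trainable.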

Architecture (formulas)

1. BasicBlock (ResNet18/34)

Conv → BN → ReLU
Conv → BN
  + shortcut
→ ReLU

y = ReLU(F(x) + x), where F(x) = W₂σ(W₁x) and σ is ReLU

2. Bottleneck (ResNet50/101/152)

When the network gets very deep, a bottleneck structure keeps the cost down:

1×1 (reduce channels) → 3×3 (extract features) → 1×1 (restore channels)

F(x) = W₃σ(W₂σ(W₁x))
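A rough weight count shows why the bottleneck matters. The sketch below ignores BN parameters and biases, and uses 256 input channels with a 64-channel bottleneck, as in ResNet50's 256-channel stage:

```python
in_ch, mid = 256, 64

# Bottleneck: 1x1 reduce -> 3x3 at reduced width -> 1x1 restore
bottleneck = in_ch * mid * 1 * 1 + mid * mid * 3 * 3 + mid * in_ch * 1 * 1

# BasicBlock pattern: two 3x3 convs at full width
basic = in_ch * in_ch * 3 * 3 * 2

print(bottleneck)  # 69632
print(basic)       # 1179648
```

Roughly 17× fewer weights per block, which is what makes 101- and 152-layer stacks affordable.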

3. Shortcut types

Case 1: input and output shapes match
y = F(x) + x

Case 2: shapes differ (downsampling)
y = F(x) + Wₛx, where Wₛ is a 1×1 convolution
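The two cases can be checked with toy tensors. This sketch uses the 64→128 transition (as at the layer1→layer2 boundary of ResNet18); the convolution shapes are illustrative:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 56, 56)

# Case 1: shapes match -> identity shortcut, plain addition works
f_x = nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)(x)
y = f_x + x
print(y.shape)  # torch.Size([1, 64, 56, 56])

# Case 2: the block halves the spatial size and doubles the channels,
# so x must first be projected with a strided 1x1 conv (W_s)
f_x = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)(x)
w_s = nn.Conv2d(64, 128, 1, stride=2, bias=False)
y = f_x + w_s(x)
print(y.shape)  # torch.Size([1, 128, 28, 28])
```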

Innovations

1. Residual connection (skip connection)

Lets the gradient propagate directly through the identity path.

2. Trains very deep networks

The first successful training of a 152-layer network.

3. Simple yet powerful structure

Became the foundation of almost every later vision network (DenseNet, U-Net, ...).

Why ResNet can train 152 layers

The theoretical core of residual networks:

∂y/∂x = ∂F(x)/∂x + 1

So the gradient always contains a "+1" term:

  • the gradient cannot vanish entirely
  • the network can pass the identity mapping straight through
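The "+1" term is directly visible in autograd. Here is a scalar toy version, where F(x) = w·x stands in for the residual branch:

```python
import torch

w = torch.tensor(0.5)

# With the shortcut: y = F(x) + x, so dy/dx = F'(x) + 1 = w + 1
x = torch.tensor(2.0, requires_grad=True)
y = w * x + x
y.backward()
print(x.grad)  # tensor(1.5000)

# Without the shortcut: dy/dx = w only; repeated over many layers this shrinks
x2 = torch.tensor(2.0, requires_grad=True)
y2 = w * x2
y2.backward()
print(x2.grad)  # tensor(0.5000)
```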

Code implementation

import torch
import torch.nn as nn
import torch.nn.functional as F

from byzh.ai.Butils import b_get_params

class BasicBlock(nn.Module):
    """
    给 ResNet18/34 用
    """
    expansion = 1

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        self.shortcut = nn.Sequential()

        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

class Bottleneck(nn.Module):
    """
    给 ResNet50/101/152 用
    """
    expansion = 4  # 输出通道 = out_ch * 4

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()

        # 1x1 conv: reduce channels
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)

        # 3x3 conv: feature extraction (the stride here does the downsampling)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        # 1x1 conv: restore channels
        self.conv3 = nn.Conv2d(out_ch, out_ch * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_ch * self.expansion)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch * self.expansion)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += self.shortcut(x)
        out = torch.relu(out)
        return out


class ResNet(nn.Module):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_ch = 64

        self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(3, 2, 1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_ch, blocks, stride=1):
        layers = []
        layers.append(block(self.in_ch, out_ch, stride))
        self.in_ch = out_ch * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.in_ch, out_ch))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class B_ResNet18_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = BasicBlock
        layers = [2, 2, 2, 2]
        super().__init__(block=block, layers=layers, num_classes=num_classes)


class B_ResNet34_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = BasicBlock
        layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet50_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 4, 6, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet101_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 4, 23, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

class B_ResNet152_Paper(ResNet):
    """
    input shape: (N, 3, 224, 224)
    """
    def __init__(self, num_classes=1000):
        block = Bottleneck
        layers = [3, 8, 36, 3]
        super().__init__(block=block, layers=layers, num_classes=num_classes)

if __name__ == '__main__':
    # ResNet18
    net = B_ResNet18_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 11_689_512

    # ResNet34
    net = B_ResNet34_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 21_797_672

    # ResNet50
    net = B_ResNet50_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 25_557_032

    # ResNet101
    net = B_ResNet101_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 44_549_160

    # ResNet152
    net = B_ResNet152_Paper(num_classes=1000)
    a = torch.randn(50, 3, 224, 224)
    result = net(a)
    print(result.shape)
    print(f"参数量: {b_get_params(net)}")  # 60_192_808

Worked example

Library environment:

numpy==1.26.4
torch==2.2.2+cu121
byzh-core==0.0.9.21
byzh-ai==0.0.9.53
byzh-extra==0.0.9.12
...

Training ResNet18 on MNIST:

# copy all the code from here to run

import torch
import torch.nn.functional as F
from byzh.ai.Bdata import b_stratified_indices

from byzh.ai.Btrainer import B_Classification_Trainer
from byzh.ai.Bdata import B_Download_MNIST, b_get_dataloader_from_tensor
from byzh.ai.Bmodel.study_cnn import B_ResNet18_Paper
from byzh.ai.Butils import b_get_device
  
##### hyper params #####
epochs = 10
lr = 1e-3
batch_size = 32
device = b_get_device(use_idle_gpu=True)

##### data #####
downloader = B_Download_MNIST(save_dir='D:/study_cnn/datasets/MNIST')
data_dict = downloader.get_data()
X_train = data_dict['X_train_standard']
y_train = data_dict['y_train']
X_test = data_dict['X_test_standard']
y_test = data_dict['y_test']
num_classes = data_dict['num_classes']
num_samples = data_dict['num_samples']

indices = b_stratified_indices(y_train, num_samples//5)
X_train = X_train[indices]
X_train = F.interpolate(X_train, size=(224, 224), mode='bilinear')
X_train = X_train.repeat(1, 3, 1, 1)
y_train = y_train[indices]

indices = b_stratified_indices(y_test, num_samples//5)
X_test = X_test[indices]
X_test = F.interpolate(X_test, size=(224, 224), mode='bilinear')
X_test = X_test.repeat(1, 3, 1, 1)
y_test = y_test[indices]

train_dataloader, val_dataloader = b_get_dataloader_from_tensor(
    X_train, y_train, X_test, y_test,
    batch_size=batch_size
)

##### model #####
model = B_ResNet18_Paper(num_classes=num_classes)

##### else #####
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()

##### trainer #####
trainer = B_Classification_Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    device=device
)
trainer.set_writer1('./runs/resnet18/log.txt')

##### run #####
trainer.train_eval_s(epochs=epochs)

##### calculate #####
trainer.draw_loss_acc('./runs/resnet18/loss_acc.png', y_lim=False)
trainer.save_best_checkpoint('./runs/resnet18/best_checkpoint.pth')
trainer.calculate_model()