ResNet (Residual Network) - 深度网络的新篇章:解决退化问题

背景 随着网络深度的增加,出现了退化问题:更深的网络并不能带来更好的性能,反而可能由于梯度消失或梯度爆炸导致模型训练困难。ResNet通过残差连接有效解决了这一问题。

网络结构 ResNet的核心是残差模块(Residual Block),其通过"跳跃连接"(skip connection)使梯度能够顺利传递,缓解了梯度消失问题。

性能与影响

实际应用案例

  • 基本残差模块:

    python 复制代码
    class BasicBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(out_channels)
            self.downsample = downsample
    
        def forward(self, x):
            identity = x
            if self.downsample is not None:
                identity = self.downsample(x)
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out += identity
            out = self.relu(out)
    
            return out

    ResNet-18的代码实现:

    python 复制代码
    class ResNet(nn.Module):
        def __init__(self, block, layers, num_classes=1000):
            super(ResNet, self).__init__()
            self.in_channels = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512, num_classes)
    
        def _make_layer(self, block, out_channels, blocks, stride=1):
            downsample = None
            if stride != 1 or self.in_channels != out_channels:
                downsample = nn.Sequential(
                    nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels)
                )
    
            layers = []
            layers.append(block(self.in_channels, out_channels, stride, downsample))
            self.in_channels = out_channels
            for _ in range(1, blocks):
                layers.append(block(out_channels, out_channels))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
    
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
    
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
    
            return x
    
    # 实例化ResNet-18
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    print(model)
  • 关键创新

  • 残差连接:通过引入恒等映射,解决深度网络的退化问题。

  • 瓶颈结构:在较深的ResNet中(如ResNet-50、ResNet-101),采用1x1卷积层减少计算量。

  • 模块化设计:基于残差模块的层次堆叠,增强了网络的可扩展性。

  • 多层深度优化:更深的网络(如ResNet-152)在图像分类任务上获得更高精度。

  • 性能:ResNet在ImageNet上以3.57%的错误率刷新了记录。

  • 迁移学习:ResNet的预训练模型被广泛应用于目标检测、分割和其他任务中。

  • 深远影响:残差学习成为现代深度学习的核心思想,被广泛应用于各种网络设计。

  • 医学图像分析:用于病变检测和分割。

  • 自动驾驶:作为感知模块的一部分,用于目标检测和语义分割。

  • 自然语言处理:结合Transformer模型,用于多模态任务。

相关推荐
仙人掌_lz几秒前
用PyTorch在超大规模下训练深度学习模型:并行策略全解析
人工智能·pytorch·深度学习
商业讯1 分钟前
深圳无人机展览即将开始,无人机舵机为什么选择伟创动力
人工智能
视觉语言导航8 分钟前
AAAI-2025 | 中科院无人机导航新突破!FELA:基于细粒度对齐的无人机视觉对话导航
人工智能·深度学习·机器人·无人机·具身智能
孚为智能科技13 分钟前
无人机箱号识别系统结合5G技术的应用实践
图像处理·人工智能·5g·目标检测·计算机视觉·视觉检测·无人机
灏瀚星空17 分钟前
地磁-惯性-视觉融合制导系统设计:现代空战导航的抗干扰解决方案
图像处理·人工智能·python·深度学习·算法·机器学习·信息与通信
Livan.Tang20 分钟前
LIO-SAM框架理解
人工智能·机器学习·slam
-曾牛26 分钟前
Spring AI 集成 Mistral AI:构建高效多语言对话助手的实战指南
java·人工智能·后端·spring·microsoft·spring ai
迅易科技37 分钟前
当数控编程“联姻”AI:制造工厂的“智能大脑”如何炼成?
人工智能·ai·知识图谱·ai编程·deepseek
英英_38 分钟前
MATLAB中矩阵和数组的区别
机器学习·matlab·矩阵
沫儿笙1 小时前
KUKA库卡焊接机器人智能气阀
人工智能·物联网·机器人