ResNet (Residual Network) - 深度网络的新篇章:解决退化问题

背景 随着网络深度的增加,出现了退化问题:更深的网络并不能带来更好的性能,反而可能由于梯度消失或梯度爆炸导致模型训练困难。ResNet通过残差连接有效解决了这一问题。

网络结构 ResNet的核心是残差模块(Residual Block),其通过"跳跃连接"(skip connection)使梯度能够顺利传递,缓解了梯度消失问题。

性能与影响

实际应用案例

  • 基本残差模块:

    python 复制代码
    class BasicBlock(nn.Module):
        def __init__(self, in_channels, out_channels, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(out_channels)
            self.downsample = downsample
    
        def forward(self, x):
            identity = x
            if self.downsample is not None:
                identity = self.downsample(x)
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out += identity
            out = self.relu(out)
    
            return out

    ResNet-18的代码实现:

    python 复制代码
    class ResNet(nn.Module):
        def __init__(self, block, layers, num_classes=1000):
            super(ResNet, self).__init__()
            self.in_channels = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512, num_classes)
    
        def _make_layer(self, block, out_channels, blocks, stride=1):
            downsample = None
            if stride != 1 or self.in_channels != out_channels:
                downsample = nn.Sequential(
                    nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(out_channels)
                )
    
            layers = []
            layers.append(block(self.in_channels, out_channels, stride, downsample))
            self.in_channels = out_channels
            for _ in range(1, blocks):
                layers.append(block(out_channels, out_channels))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
    
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
    
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
    
            return x
    
    # 实例化ResNet-18
    model = ResNet(BasicBlock, [2, 2, 2, 2])
    print(model)
  • 关键创新

  • 残差连接:通过引入恒等映射,解决深度网络的退化问题。

  • 瓶颈结构:在较深的ResNet中(如ResNet-50、ResNet-101),采用1x1卷积层减少计算量。

  • 模块化设计:基于残差模块的层次堆叠,增强了网络的可扩展性。

  • 多层深度优化:更深的网络(如ResNet-152)在图像分类任务上获得更高精度。

  • 性能:ResNet在ImageNet上以3.57%的错误率刷新了记录。

  • 迁移学习:ResNet的预训练模型被广泛应用于目标检测、分割和其他任务中。

  • 深远影响:残差学习成为现代深度学习的核心思想,被广泛应用于各种网络设计。

  • 医学图像分析:用于病变检测和分割。

  • 自动驾驶:作为感知模块的一部分,用于目标检测和语义分割。

  • 自然语言处理:结合Transformer模型,用于多模态任务。

相关推荐
冬奇Lab2 小时前
Workflow 系列(06):安全——跨步骤注入传播与四层防御
人工智能·工作流引擎
冬奇Lab2 小时前
每日一个开源项目(第149篇):RAG-Anything - 把图片、表格、公式当成一等公民的多模态 RAG 框架
人工智能·开源
米小虾2 小时前
AI Agent 安全实战指南:当智能体开始"不听话",开发者该如何应对?
人工智能·安全·agent
IT_陈寒4 小时前
Vite的热更新突然不香了,排查三小时差点砸键盘
前端·人工智能·后端
阿里云大数据AI技术6 小时前
构建高转化海外电商搜索:阿里云OpenSearch行业算法版的全链路智能优化策略实战
人工智能·搜索引擎
Awu12276 小时前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude
字节跳动视频云技术团队6 小时前
让 Agent 成为音视频工作台:AI MediaKit CLI + Skill 发布
人工智能·音视频开发
魏祖潇6 小时前
framework 整合实战——DDD/TDD/SDD 三件套在 framework 仓的真实落地
人工智能·后端
Token炼金师7 小时前
去噪扩散:从随机噪声到高保真图像的数学之路
人工智能·aigc