J7学习打卡笔记

问题:如果conv shortcut=False,

那么在执行"x=Add() ... "语句时,

通道数不一致的,为什么不会报错

python 复制代码
import torch
import torch.nn as nn


# 定义分组卷积模块
class GroupedConvBlock(nn.Module):
    def __init__(self, in_channels, groups, g_channels, stride):
        super(GroupedConvBlock, self).__init__()
        self.groups = groups
        self.group_conv = nn.ModuleList([
            nn.Conv2d(g_channels, g_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(groups)
        ])
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # 分组数据
        split_x = torch.split(x, x.size(1) // self.groups, dim=1)
        group_out = [conv(g) for g, conv in zip(split_x, self.group_conv)]
        # 合并数据
        x = torch.cat(group_out, dim=1)
        x = self.bn(x)
        x = self.relu(x)
        return x


# 定义残差模块
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, filters, groups=32, stride=1, conv_shortcut=False):
        super(ResNeXtBlock, self).__init__()
        self.conv_shortcut = conv_shortcut
        self.groups = groups
        self.g_channels = filters // groups

        # Shortcut分支
        if conv_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters * 2),
            )
        else:
            self.shortcut = nn.Identity()

        # 主分支
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv = GroupedConvBlock(filters, groups, self.g_channels, stride)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters * 2),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.grouped_conv(x)
        x = self.conv3(x)
        x += shortcut
        x = self.relu(x)
        return x


# 定义 ResNeXt-50 模型
class ResNeXt50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNeXt50, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 堆叠ResNeXt模块
        self.layer1 = self.make_layer(64, 128, 3, stride=1)
        self.layer2 = self.make_layer(256, 256, 4, stride=2)
        self.layer3 = self.make_layer(512, 512, 6, stride=2)
        self.layer4 = self.make_layer(1024, 1024, 3, stride=2)

        # 全局平均池化和分类层
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def make_layer(in_channels, filters, blocks, stride):
        layers = [ResNeXtBlock(in_channels, filters, stride=stride)]
        for _ in range(1, blocks):
            layers.append(ResNeXtBlock(filters * 2, filters, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

conv_shortcut=True:使用 1x1 卷积层调整输入特征图的尺寸和通道数,适用于需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的第一个残差单元。

conv_shortcut=False:直接使用输入特征图作为残差连接,适用于不需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的后续残差单元。

为什么不会报错:

  • 输入特征图的通道数已经匹配:
    在调用 block 函数之前,输入特征图 x 的通道数已经是 filters * 2。
    例如,在 stack 函数的第一个 block 调用后,x 的通道数变为 filters * 2,后续 block 函数调用时 conv_shortcut=False,因此通道数保持一致。

个人总结

  • 不报错可能是通道数已经匹配
  • 实际上采用本人代码时,修改conv_shortcut=False确实出现了错误,报错显示张量形状不匹配
相关推荐
呱呱巨基16 分钟前
Linux 第一个系统程序 进度条
linux·c++·笔记·学习
香芋Yu19 分钟前
【深度学习教程——01_深度基石(Foundation)】03_计算图是什么?PyTorch动态图机制解密
人工智能·pytorch·深度学习
氵文大师29 分钟前
PyTorch 性能分析实战:像手术刀一样精准控制 Nsys Timeline(附自定义颜色教程)
人工智能·pytorch·python
好奇龙猫35 分钟前
【人工智能学习-AI入试相关题目练习-第十七次】
人工智能·学习
林深现海42 分钟前
【刘二大人】PyTorch深度学习实践笔记 —— 第二集:线性模型(凝练版)
pytorch·笔记·深度学习
历程里程碑1 小时前
Linux 16 环境变量
linux·运维·服务器·开发语言·数据库·c++·笔记
我材不敲代码1 小时前
机器学习入门02——新手学习的第一个回归算法:线性回归
学习·机器学习·回归
●VON1 小时前
React Native for OpenHarmony:构建高性能、高体验的 TextInput 输入表单
javascript·学习·react native·react.js·von
横木沉1 小时前
Opencode启动时内置Bun段错误的解决笔记
人工智能·笔记·bun·vibecoding·opencode
●VON1 小时前
React Native for OpenHarmony:ActivityIndicator 动画实现详解
javascript·学习·react native·react.js·性能优化·openharmony