J7学习打卡笔记

问题:如果conv shortcut=False,

那么在执行"x=Add() ... "语句时,

通道数不一致的,为什么不会报错

python 复制代码
import torch
import torch.nn as nn


# 定义分组卷积模块
class GroupedConvBlock(nn.Module):
    def __init__(self, in_channels, groups, g_channels, stride):
        super(GroupedConvBlock, self).__init__()
        self.groups = groups
        self.group_conv = nn.ModuleList([
            nn.Conv2d(g_channels, g_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(groups)
        ])
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # 分组数据
        split_x = torch.split(x, x.size(1) // self.groups, dim=1)
        group_out = [conv(g) for g, conv in zip(split_x, self.group_conv)]
        # 合并数据
        x = torch.cat(group_out, dim=1)
        x = self.bn(x)
        x = self.relu(x)
        return x


# 定义残差模块
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, filters, groups=32, stride=1, conv_shortcut=False):
        super(ResNeXtBlock, self).__init__()
        self.conv_shortcut = conv_shortcut
        self.groups = groups
        self.g_channels = filters // groups

        # Shortcut分支
        if conv_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters * 2),
            )
        else:
            self.shortcut = nn.Identity()

        # 主分支
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv = GroupedConvBlock(filters, groups, self.g_channels, stride)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters * 2),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.grouped_conv(x)
        x = self.conv3(x)
        x += shortcut
        x = self.relu(x)
        return x


# 定义 ResNeXt-50 模型
class ResNeXt50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNeXt50, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 堆叠ResNeXt模块
        self.layer1 = self.make_layer(64, 128, 3, stride=1)
        self.layer2 = self.make_layer(256, 256, 4, stride=2)
        self.layer3 = self.make_layer(512, 512, 6, stride=2)
        self.layer4 = self.make_layer(1024, 1024, 3, stride=2)

        # 全局平均池化和分类层
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def make_layer(in_channels, filters, blocks, stride):
        layers = [ResNeXtBlock(in_channels, filters, stride=stride)]
        for _ in range(1, blocks):
            layers.append(ResNeXtBlock(filters * 2, filters, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

conv_shortcut=True:使用 1x1 卷积层调整输入特征图的尺寸和通道数,适用于需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的第一个残差单元。

conv_shortcut=False:直接使用输入特征图作为残差连接,适用于不需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的后续残差单元。

为什么不会报错:

  • 输入特征图的通道数已经匹配:
    在调用 block 函数之前,输入特征图 x 的通道数已经是 filters * 2。
    例如,在 stack 函数的第一个 block 调用后,x 的通道数变为 filters * 2,后续 block 函数调用时 conv_shortcut=False,因此通道数保持一致。

个人总结

  • 不报错可能是通道数已经匹配
  • 实际上采用本人代码时,修改conv_shortcut=False确实出现了错误,报错显示张量形状不匹配
相关推荐
汇能感知2 小时前
摄像头模块在运动相机中的特殊应用
经验分享·笔记·科技
阿巴Jun2 小时前
【数学】线性代数知识点总结
笔记·线性代数·矩阵
茯苓gao2 小时前
STM32G4 速度环开环,电流环闭环 IF模式建模
笔记·stm32·单片机·嵌入式硬件·学习
是誰萆微了承諾3 小时前
【golang学习笔记 gin 】1.2 redis 的使用
笔记·学习·golang
DKPT3 小时前
Java内存区域与内存溢出
java·开发语言·jvm·笔记·学习
aaaweiaaaaaa3 小时前
HTML和CSS学习
前端·css·学习·html
ST.J4 小时前
前端笔记2025
前端·javascript·css·vue.js·笔记
Suckerbin4 小时前
LAMPSecurity: CTF5靶场渗透
笔记·安全·web安全·网络安全
看海天一色听风起雨落4 小时前
Python学习之装饰器
开发语言·python·学习
小憩-4 小时前
【机器学习】吴恩达机器学习笔记
人工智能·笔记·机器学习