J7学习打卡笔记

问题:如果conv shortcut=False,

那么在执行"x=Add() ... "语句时,

通道数不一致的,为什么不会报错

python 复制代码
import torch
import torch.nn as nn


# 定义分组卷积模块
class GroupedConvBlock(nn.Module):
    def __init__(self, in_channels, groups, g_channels, stride):
        super(GroupedConvBlock, self).__init__()
        self.groups = groups
        self.group_conv = nn.ModuleList([
            nn.Conv2d(g_channels, g_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(groups)
        ])
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # 分组数据
        split_x = torch.split(x, x.size(1) // self.groups, dim=1)
        group_out = [conv(g) for g, conv in zip(split_x, self.group_conv)]
        # 合并数据
        x = torch.cat(group_out, dim=1)
        x = self.bn(x)
        x = self.relu(x)
        return x


# 定义残差模块
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, filters, groups=32, stride=1, conv_shortcut=False):
        super(ResNeXtBlock, self).__init__()
        self.conv_shortcut = conv_shortcut
        self.groups = groups
        self.g_channels = filters // groups

        # Shortcut分支
        if conv_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters * 2),
            )
        else:
            self.shortcut = nn.Identity()

        # 主分支
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv = GroupedConvBlock(filters, groups, self.g_channels, stride)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters * 2),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.grouped_conv(x)
        x = self.conv3(x)
        x += shortcut
        x = self.relu(x)
        return x


# 定义 ResNeXt-50 模型
class ResNeXt50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNeXt50, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 堆叠ResNeXt模块
        self.layer1 = self.make_layer(64, 128, 3, stride=1)
        self.layer2 = self.make_layer(256, 256, 4, stride=2)
        self.layer3 = self.make_layer(512, 512, 6, stride=2)
        self.layer4 = self.make_layer(1024, 1024, 3, stride=2)

        # 全局平均池化和分类层
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def make_layer(in_channels, filters, blocks, stride):
        layers = [ResNeXtBlock(in_channels, filters, stride=stride)]
        for _ in range(1, blocks):
            layers.append(ResNeXtBlock(filters * 2, filters, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

conv_shortcut=True:使用 1x1 卷积层调整输入特征图的尺寸和通道数,适用于需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的第一个残差单元。

conv_shortcut=False:直接使用输入特征图作为残差连接,适用于不需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的后续残差单元。

为什么不会报错:

  • 输入特征图的通道数已经匹配:
    在调用 block 函数之前,输入特征图 x 的通道数已经是 filters * 2。
    例如,在 stack 函数的第一个 block 调用后,x 的通道数变为 filters * 2,后续 block 函数调用时 conv_shortcut=False,因此通道数保持一致。

个人总结

  • 不报错可能是通道数已经匹配
  • 实际上采用本人代码时,修改conv_shortcut=False确实出现了错误,报错显示张量形状不匹配
相关推荐
极市平台1 小时前
骁龙大赛-技术分享第5期(上)
人工智能·经验分享·笔记·后端·个人开发
全栈陈序员1 小时前
【Python】基础语法入门(十七)——文件操作与数据持久化:安全读写本地数据
开发语言·人工智能·python·学习
啄缘之间1 小时前
11. UVM Test [uvm_test]
经验分享·笔记·学习·uvm·总结
RisunJan2 小时前
【行测】类比推理-自称他称全同
学习
wan55cn@126.com2 小时前
人类文明可通过技术手段(如加强航天器防护、改进电网设计)缓解地球两极反转带来的影响
人工智能·笔记·搜索引擎·百度·微信
石像鬼₧魂石2 小时前
Termux ↔ Windows 靶机 反向连接实操命令清单
linux·windows·学习
非凡ghost2 小时前
JRiver Media Center(媒体管理软件)
android·学习·智能手机·媒体·软件需求
会飞的土拨鼠呀2 小时前
docker部署 outline(栗子云笔记)
笔记·docker·容器
_Minato_3 小时前
数据库知识整理——数据库设计的步骤
数据库·经验分享·笔记·软考
hssfscv3 小时前
Mysql学习笔记——事务
笔记·学习·mysql