J7学习打卡笔记

问题:如果conv shortcut=False,

那么在执行"x=Add() ... "语句时,

通道数不一致的,为什么不会报错

python 复制代码
import torch
import torch.nn as nn


# 定义分组卷积模块
class GroupedConvBlock(nn.Module):
    def __init__(self, in_channels, groups, g_channels, stride):
        super(GroupedConvBlock, self).__init__()
        self.groups = groups
        self.group_conv = nn.ModuleList([
            nn.Conv2d(g_channels, g_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(groups)
        ])
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # 分组数据
        split_x = torch.split(x, x.size(1) // self.groups, dim=1)
        group_out = [conv(g) for g, conv in zip(split_x, self.group_conv)]
        # 合并数据
        x = torch.cat(group_out, dim=1)
        x = self.bn(x)
        x = self.relu(x)
        return x


# 定义残差模块
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, filters, groups=32, stride=1, conv_shortcut=False):
        super(ResNeXtBlock, self).__init__()
        self.conv_shortcut = conv_shortcut
        self.groups = groups
        self.g_channels = filters // groups

        # Shortcut分支
        if conv_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters * 2),
            )
        else:
            self.shortcut = nn.Identity()

        # 主分支
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv = GroupedConvBlock(filters, groups, self.g_channels, stride)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters * 2),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.grouped_conv(x)
        x = self.conv3(x)
        x += shortcut
        x = self.relu(x)
        return x


# 定义 ResNeXt-50 模型
class ResNeXt50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNeXt50, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 堆叠ResNeXt模块
        self.layer1 = self.make_layer(64, 128, 3, stride=1)
        self.layer2 = self.make_layer(256, 256, 4, stride=2)
        self.layer3 = self.make_layer(512, 512, 6, stride=2)
        self.layer4 = self.make_layer(1024, 1024, 3, stride=2)

        # 全局平均池化和分类层
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def make_layer(in_channels, filters, blocks, stride):
        layers = [ResNeXtBlock(in_channels, filters, stride=stride)]
        for _ in range(1, blocks):
            layers.append(ResNeXtBlock(filters * 2, filters, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

conv_shortcut=True:使用 1x1 卷积层调整输入特征图的尺寸和通道数,适用于需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的第一个残差单元。

conv_shortcut=False:直接使用输入特征图作为残差连接,适用于不需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的后续残差单元。

为什么不会报错:

  • 输入特征图的通道数已经匹配:
    在调用 block 函数之前,输入特征图 x 的通道数已经是 filters * 2。
    例如,在 stack 函数的第一个 block 调用后,x 的通道数变为 filters * 2,后续 block 函数调用时 conv_shortcut=False,因此通道数保持一致。

个人总结

  • 不报错可能是通道数已经匹配
  • 实际上采用本人代码时,修改conv_shortcut=False确实出现了错误,报错显示张量形状不匹配
相关推荐
取个名字真难呐39 分钟前
随机置矩阵列为0[矩阵乘法pytorch版]
pytorch·python·矩阵
山山而川粤1 小时前
共享充电宝系统|Java|SSM|VUE| 前后端分离
java·开发语言·后端·学习·mysql
Jackilina_Stone3 小时前
【HUAWEI】HCIP-AI-MindSpore Developer V1.0 | 第五章 自然语言处理原理与应用(2 自然语言处理关键技术) | 学习笔记
人工智能·笔记·学习·自然语言处理·hcip·huawei
垂杨有暮鸦⊙_⊙3 小时前
2024年6月英语六级CET6听力原文与解析
笔记·学习·六级
济南小草根3 小时前
JavaScript学习记录10
开发语言·javascript·学习
每天题库3 小时前
特种设备安全管理人员免费题库限时练习(判断题)
学习·安全·考试·题库·考证
安全方案4 小时前
网络安全的学习与实践经验(附资料合集)
学习·安全·web安全
油炸自行车4 小时前
【阅读】认知觉醒
笔记·阅读·读书
田梓燊4 小时前
人机交互复习笔记
笔记·人机交互