J7学习打卡笔记

问题:如果conv shortcut=False,

那么在执行"x=Add() ... "语句时,

通道数不一致的,为什么不会报错

python 复制代码
import torch
import torch.nn as nn


# 定义分组卷积模块
class GroupedConvBlock(nn.Module):
    def __init__(self, in_channels, groups, g_channels, stride):
        super(GroupedConvBlock, self).__init__()
        self.groups = groups
        self.group_conv = nn.ModuleList([
            nn.Conv2d(g_channels, g_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(groups)
        ])
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # 分组数据
        split_x = torch.split(x, x.size(1) // self.groups, dim=1)
        group_out = [conv(g) for g, conv in zip(split_x, self.group_conv)]
        # 合并数据
        x = torch.cat(group_out, dim=1)
        x = self.bn(x)
        x = self.relu(x)
        return x


# 定义残差模块
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, filters, groups=32, stride=1, conv_shortcut=False):
        super(ResNeXtBlock, self).__init__()
        self.conv_shortcut = conv_shortcut
        self.groups = groups
        self.g_channels = filters // groups

        # Shortcut分支
        if conv_shortcut:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters * 2),
            )
        else:
            self.shortcut = nn.Identity()

        # 主分支
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv = GroupedConvBlock(filters, groups, self.g_channels, stride)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters * 2),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.grouped_conv(x)
        x = self.conv3(x)
        x += shortcut
        x = self.relu(x)
        return x


# 定义 ResNeXt-50 模型
class ResNeXt50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNeXt50, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # 堆叠ResNeXt模块
        self.layer1 = self.make_layer(64, 128, 3, stride=1)
        self.layer2 = self.make_layer(256, 256, 4, stride=2)
        self.layer3 = self.make_layer(512, 512, 6, stride=2)
        self.layer4 = self.make_layer(1024, 1024, 3, stride=2)

        # 全局平均池化和分类层
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def make_layer(in_channels, filters, blocks, stride):
        layers = [ResNeXtBlock(in_channels, filters, stride=stride)]
        for _ in range(1, blocks):
            layers.append(ResNeXtBlock(filters * 2, filters, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

conv_shortcut=True:使用 1x1 卷积层调整输入特征图的尺寸和通道数,适用于需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的第一个残差单元。

conv_shortcut=False:直接使用输入特征图作为残差连接,适用于不需要改变特征图尺寸和通道数的场景,例如每个残差堆栈的后续残差单元。

为什么不会报错:

  • 输入特征图的通道数已经匹配:
    在调用 block 函数之前,输入特征图 x 的通道数已经是 filters * 2。
    例如,在 stack 函数的第一个 block 调用后,x 的通道数变为 filters * 2,后续 block 函数调用时 conv_shortcut=False,因此通道数保持一致。

个人总结

  • 不报错可能是通道数已经匹配
  • 实际上采用本人代码时,修改conv_shortcut=False确实出现了错误,报错显示张量形状不匹配
相关推荐
Slow菜鸟1 小时前
AI学习篇(五) | awesome-design-md 使用说明
人工智能·学习
ZC跨境爬虫2 小时前
跟着 MDN 学 HTML day_9:(信件语义标记)
前端·css·笔记·ui·html
狐狐生风2 小时前
LangChain 向量存储:Chroma、FAISS
人工智能·python·学习·langchain·faiss·agentai
狐狐生风2 小时前
LangChain RAG 基础
人工智能·python·学习·langchain·rag·agentai
努力努力再努力FFF5 小时前
医生对AI辅助诊断感兴趣,作为临床人员该怎么了解和学习?
人工智能·学习
OBiO20135 小时前
Cell | 突破AAV载体容量限制!路中华/姜玉武/刘太安团队开发AAVLINK系统实现大基因递送
笔记
智者知已应修善业6 小时前
【51单片机2个按键控制流水灯运行与暂停】2023-9-6
c++·经验分享·笔记·算法·51单片机
sakiko_6 小时前
UIKit学习笔记5-使用UITableView制作聊天页面
笔记·学习·swift·uikit
Alice-YUE7 小时前
【js高频八股】防抖与节流
开发语言·前端·javascript·笔记·学习·ecmascript
北山有鸟8 小时前
修改源码法和插件法
嵌入式硬件·学习