- 🍨 This post is a learning-log entry for the 🔗365天深度学习训练营
- 🍖 Original author: K同学啊
Question: if conv_shortcut=False, why does the `x = Add() ...` statement not raise an error, even though the channel counts are supposedly inconsistent? (The `Add()` call comes from the reference Keras implementation; the PyTorch port below expresses the same addition as `x += shortcut`.)
```python
import torch
import torch.nn as nn


# Grouped convolution block: one 3x3 convolution per group of channels
class GroupedConvBlock(nn.Module):
    def __init__(self, in_channels, groups, g_channels, stride):
        super(GroupedConvBlock, self).__init__()
        self.groups = groups
        self.group_conv = nn.ModuleList([
            nn.Conv2d(g_channels, g_channels, kernel_size=3, stride=stride, padding=1, bias=False)
            for _ in range(groups)
        ])
        # groups * g_channels == in_channels, so one BatchNorm covers the concatenated output
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Split the channels into equal groups
        split_x = torch.split(x, x.size(1) // self.groups, dim=1)
        group_out = [conv(g) for g, conv in zip(split_x, self.group_conv)]
        # Concatenate the per-group outputs back into one tensor
        x = torch.cat(group_out, dim=1)
        x = self.bn(x)
        x = self.relu(x)
        return x


# ResNeXt residual block
class ResNeXtBlock(nn.Module):
    def __init__(self, in_channels, filters, groups=32, stride=1, conv_shortcut=False):
        super(ResNeXtBlock, self).__init__()
        self.conv_shortcut = conv_shortcut
        self.groups = groups
        self.g_channels = filters // groups
        # Shortcut branch
        if conv_shortcut:
            # 1x1 convolution adjusts both the channel count and the spatial size
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, filters * 2, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(filters * 2),
            )
        else:
            # Identity shortcut: only valid when in_channels == filters * 2 and stride == 1
            self.shortcut = nn.Identity()
        # Main branch
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, filters, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters),
            nn.ReLU(inplace=True)
        )
        self.grouped_conv = GroupedConvBlock(filters, groups, self.g_channels, stride)
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters * 2, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(filters * 2),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.conv1(x)
        x = self.grouped_conv(x)
        x = self.conv3(x)
        # The residual addition requires shortcut and x to have identical shapes
        x += shortcut
        x = self.relu(x)
        return x


# ResNeXt-50 model
class ResNeXt50(nn.Module):
    def __init__(self, num_classes=1000):
        super(ResNeXt50, self).__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Stacked ResNeXt stages
        self.layer1 = self.make_layer(64, 128, 3, stride=1)
        self.layer2 = self.make_layer(256, 256, 4, stride=2)
        self.layer3 = self.make_layer(512, 512, 6, stride=2)
        self.layer4 = self.make_layer(1024, 1024, 3, stride=2)
        # Global average pooling and classification head
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(2048, num_classes)

    @staticmethod
    def make_layer(in_channels, filters, blocks, stride):
        # The first block of each stage changes the channel count (and, except in
        # layer1, the spatial size), so it needs a convolutional shortcut
        layers = [ResNeXtBlock(in_channels, filters, stride=stride, conv_shortcut=True)]
        for _ in range(1, blocks):
            # Later blocks keep filters * 2 channels, so the identity shortcut is safe
            layers.append(ResNeXtBlock(filters * 2, filters, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.global_avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```
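A quick sanity check of the full model (a minimal sketch; the batch size and 224x224 input resolution are assumptions, matching typical ImageNet usage):

```python
# Hypothetical smoke test: push a dummy batch through the model
model = ResNeXt50(num_classes=1000)
dummy = torch.randn(2, 3, 224, 224)  # assumed ImageNet-style input
out = model(dummy)
print(out.shape)  # torch.Size([2, 1000])
```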
- conv_shortcut=True: a 1x1 convolution adjusts the spatial size and channel count of the input feature map. Use it where both change, e.g., the first residual unit of each residual stack.
- conv_shortcut=False: the input feature map is used directly as the residual connection. Use it where neither the spatial size nor the channel count changes, e.g., the remaining residual units of each stack.
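To make the two cases concrete, here is a small sketch using the ResNeXtBlock defined above (the channel counts and the 56x56 spatial size are illustrative assumptions):

```python
# First unit of a stack: input has 64 channels, output has 128 * 2 = 256,
# so the shortcut must be a 1x1 convolution (conv_shortcut=True)
first = ResNeXtBlock(64, 128, stride=1, conv_shortcut=True)
print(first(torch.randn(1, 64, 56, 56)).shape)   # torch.Size([1, 256, 56, 56])

# Later units: input already has 256 channels and the block also outputs 256,
# so the identity shortcut is safe (conv_shortcut=False)
later = ResNeXtBlock(256, 128, stride=1, conv_shortcut=False)
print(later(torch.randn(1, 256, 56, 56)).shape)  # torch.Size([1, 256, 56, 56])
```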
Why no error is raised:
- The channel count of the input feature map already matches:
  by the time such a block call runs, the input x already has filters * 2 channels.
  For example, after the first block call inside the stack function, x carries filters * 2 channels; every subsequent block call uses conv_shortcut=False, so the channel counts on both sides of the addition stay consistent.
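The same reasoning can be checked on the PyTorch port by tracing the channel flow through one stage (a minimal sketch; the input shape is an assumption):

```python
# After the first block the tensor has filters * 2 = 256 channels, which every
# later identity-shortcut block both expects and produces, so the addition
# always lines up
stage = ResNeXt50.make_layer(64, 128, blocks=3, stride=1)
x = torch.randn(1, 64, 56, 56)
for i, block in enumerate(stage):
    x = block(x)
    print(i, x.shape)  # every block prints torch.Size([1, 256, 56, 56])
```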
Personal summary
- No error is raised because the channel counts already match at that point.
- In practice, when I flipped conv_shortcut=False in my own code, an error did occur: the reported tensor shapes did not match. This is exactly the first-block case, where the main branch changes the channel count while the identity shortcut does not.
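That shape mismatch is easy to reproduce with the ResNeXtBlock above (a minimal sketch; the shapes are illustrative):

```python
# First unit of a stage, but with conv_shortcut=False: the identity shortcut
# keeps 64 channels while the main branch outputs 256, so the addition fails
bad = ResNeXtBlock(64, 128, stride=1, conv_shortcut=False)
try:
    bad(torch.randn(1, 64, 56, 56))
except RuntimeError as e:
    print("tensor shape mismatch:", e)
```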