ResNetLayer 类

这段代码定义了一个 ResNetLayer 类,是 ResNet 网络结构中的一个"层级模块(stage)",即由多个 ResNetBlock 堆叠而成的层。

我来帮你逐行详细解释代码逻辑和设计思路👇:


🧩 类定义

class ResNetLayer(nn.Module):
"""ResNet layer with multiple ResNet blocks."""

继承自 nn.Module,表示这是一个可训练的 PyTorch 模块。

作用是构建 ResNet 网络中的一层(例如 conv2_x、conv3_x、conv4_x、conv5_x)。


⚙️ 初始化函数

def init(self, c1: int, c2: int, s: int = 1, is_first: bool = False, n: int = 1, e: int = 4):

参数解释:

参数 含义
c1 输入通道数
c2 输出通道数(基本块的输出通道)
s 步幅(stride),用于下采样
is_first 是否是网络的第一层(即 stem 层)
n ResNetBlock 的数量(重复次数)
e expansion ratio(扩展比例,Bottleneck结构中用)

🧠 第一层特判(is_first)

复制代码
if self.is_first:
    self.layer = nn.Sequential(
        Conv(c1, c2, k=7, s=2, p=3, act=True), 
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )
  • is_first=True 时(通常是 ResNet 的第一个卷积层,即 stem):

    • 使用一个 7×7 卷积 + stride=2 的下采样;

    • 接一个 3×3 最大池化 + stride=2

    • 相当于标准 ResNet 的输入部分(对 ImageNet 图像下采样 4 倍)。


🧩 否则(常规层)

else:

blocks = [ResNetBlock(c1, c2, s, e=e)]

blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])

self.layer = nn.Sequential(*blocks)

这部分构建一个由 n 个 ResNetBlock 组成的层:

  • 第一个 block:

    • 输入通道是 c1

    • 输出通道是 c2

    • stride = s(控制是否下采样)

    • expansion ratio = e

  • 后续 block:

    • 输入通道是 e * c2(因为第一个 block 的输出经过 Bottleneck 扩展)

    • stride = 1(保持尺寸不变)

    • 重复 n-1 次。


🧾 前向传播

复制代码
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.layer(x)

🧱 小结(结构对照)

参数示例 对应 ResNet 部分 结构内容
is_first=True Stem(输入层) 7×7 conv + 3×3 maxpool
is_first=False, n=3, s=2 一个 stage(如 conv3_x) 多个 ResNetBlock,首个下采样

🔍 示例用法

复制代码
# 假设 ResNet-50 的前两个阶段
layer1 = ResNetLayer(3, 64, is_first=True)
layer2 = ResNetLayer(64, 64, s=1, n=3, e=4)  # conv2_x
layer3 = ResNetLayer(256, 128, s=2, n=4, e=4) # conv3_x

x = torch.randn(1, 3, 224, 224)
y = layer1(x)
print(y.shape)  # torch.Size([1, 64, 56, 56])

🧩 总体设计逻辑

  • 封装性强:将多个 ResNetBlock 合为一个逻辑层;

  • 可配置性高:可通过参数灵活设置层数、通道数、步幅;

  • 与 ResNet 结构对应

    • ResNetLayerconv2_x ~ conv5_x

    • ResNetBlock ≈ 单个残差单元。

class ResNetLayer(nn.Module):

"""ResNet layer with multiple ResNet blocks.""" # 定义一个包含多个 ResNetBlock 的层(stage)

def init(self, c1: int, c2: int, s: int = 1, is_first: bool = False, n: int = 1, e: int = 4):

"""

初始化 ResNetLayer。

Args:

c1 (int): 输入通道数。

c2 (int): 输出通道数。

s (int): 步幅(stride),用于控制下采样。

is_first (bool): 是否为网络的第一层(stem 层)。

n (int): ResNetBlock 的数量。

e (int): 通道扩展倍率(通常为 4,用于 bottleneck 结构)。

"""

super().init() # 调用父类 nn.Module 的初始化函数

self.is_first = is_first # 记录是否为第一层,用于区别结构

如果是网络的第一层(stem 层)

if self.is_first:

Sequential 顺序容器:依次执行其中的层

self.layer = nn.Sequential(

Conv 是一个自定义的卷积模块(包含卷积、BN、激活)

输入通道 c1,输出通道 c2,卷积核大小 7x7,步幅 2,padding 3,启用激活函数

Conv(c1, c2, k=7, s=2, p=3, act=True),

最大池化层,核大小 3x3,步幅 2,padding 1(进一步下采样)

nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

)

else:

如果不是第一层(常规 ResNet stage)

第一个 block:可能包含下采样(s>1),输入通道 c1,输出通道 c2

blocks = [ResNetBlock(c1, c2, s, e=e)]

后续 n-1 个 block:输入通道是 e*c2(因为 bottleneck 结构输出扩展 e 倍)

步幅为 1(保持特征图大小不变)

blocks.extend([ResNetBlock(e * c2, c2, 1, e=e) for _ in range(n - 1)])

用 nn.Sequential 将多个 block 串联为一个整体层

self.layer = nn.Sequential(*blocks)

def forward(self, x: torch.Tensor) -> torch.Tensor:

"""前向传播函数"""

return self.layer(x) # 将输入 x 依次通过定义好的 layer,并返回结果

  • -1, 1, ResNetLayer, \[3, 64, 1, True, 1\]\] # 0

c1=3, # 输入通道数(RGB图像)

c2=64, # 输出通道数

s=1, # 步幅=1

is_first=True, # 是第一层(stem层)

n=1 # 只有1个block

)

  • -1, 1, ResNetLayer, \[64, 64, 1, False, 3\]\] # 1

c1=64, # 输入通道数

c2=64, # 输出通道数

s=1, # 步幅为1(不下采样)

is_first=False, # 不是第一层

n=3 # 包含3个 ResNetBlock

)

相关推荐
吴佳浩14 小时前
Python入门指南(七) - YOLO检测API进阶实战
人工智能·后端·python
tap.AI14 小时前
RAG系列(二)数据准备与向量索引
开发语言·人工智能
老蒋新思维14 小时前
知识IP的长期主义:当AI成为跨越增长曲线的“第二曲线引擎”|创客匠人
大数据·人工智能·tcp/ip·机器学习·创始人ip·创客匠人·知识变现
货拉拉技术15 小时前
出海技术挑战——Lalamove智能告警降噪
人工智能·后端·监控
wei202315 小时前
汽车智能体Agent:国务院“人工智能+”行动意见 对汽车智能体领域 革命性重塑
人工智能·汽车·agent·智能体
LinkTime_Cloud15 小时前
快手遭遇T0级“黑色闪电”:一场教科书式的“协同打击”,披上了AI“智能外衣”的攻击
人工智能
PPIO派欧云15 小时前
PPIO上线MiniMax-M2.1:聚焦多语言编程与真实世界复杂任务
人工智能
隔壁阿布都15 小时前
使用LangChain4j +Springboot 实现大模型与向量化数据库协同回答
人工智能·spring boot·后端
Coding茶水间15 小时前
基于深度学习的水面垃圾检测系统演示与介绍(YOLOv12/v11/v8/v5模型+Pyqt5界面+训练代码+数据集)
图像处理·人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
乐迪信息16 小时前
乐迪信息:煤矿皮带区域安全管控:人员违规闯入智能识别
大数据·运维·人工智能·物联网·安全