Scale Decoupled Distillation 论文中SPP发生了什么

SPP

python 复制代码
import torch
import torch.nn as nn

# 定义SPP类
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))

        self.M = M
        print(f"初始化 M:{self.M}")

    def forward(self, x):
        print(f"输入 x 的形状: {x.shape}")
        
        # 进行4x4池化
        x_4x4 = self.pooling_4x4(x)
        print(f"4x4池化后 x_4x4 的形状: {x_4x4.shape}")
        
        # 进行2x2池化
        x_2x2 = self.pooling_2x2(x_4x4)
        print(f"2x2池化后 x_2x2 的形状: {x_2x2.shape}")
        
        # 进行1x1池化
        x_1x1 = self.pooling_1x1(x_4x4)
        print(f"1x1池化后 x_1x1 的形状: {x_1x1.shape}")

        # 展平特征图
        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        print(f"4x4池化展平后 x_4x4_flatten 的形状: {x_4x4_flatten.shape}")
        
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        print(f"2x2池化展平后 x_2x2_flatten 的形状: {x_2x2_flatten.shape}")
        
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)
        print(f"1x1池化展平后 x_1x1_flatten 的形状: {x_1x1_flatten.shape}")

        # 根据 M 值拼接特征
        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
            print(f"特征拼接后 (M=[1,2,4]) x_feature 的形状: {x_feature.shape}")
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
            print(f"特征拼接后 (M=[1,2]) x_feature 的形状: {x_feature.shape}")
        else:
            raise NotImplementedError('ERROR M')

        # 计算特征强度
        x_strength = x_feature.permute((2, 0, 1))
        print(f"特征强度计算前 x_strength 形状: {x_strength.shape}")
        
        x_strength = torch.mean(x_strength, dim=2)
        print(f"特征强度计算后 x_strength 的形状: {x_strength.shape}")

        return x_feature, x_strength


# 创建一个SPP模块实例
M = '[1,2,4]'  # 设置M为'[1,2,4]',拼接所有三个尺度
spp = SPP(M=M)

# 输入一个示例张量,形状为 [batch_size, channels, height, width]
input_tensor = torch.randn(2, 3, 16, 16)  # 假设 batch_size=2, channels=3, height=16, width=16

# 前向传播并打印每一步结果
x_feature, x_strength = spp(input_tensor)

运行结果

python 复制代码
输入 x 的形状: torch.Size([2, 3, 16, 16])
4x4池化后 x_4x4 的形状: torch.Size([2, 3, 4, 4])
2x2池化后 x_2x2 的形状: torch.Size([2, 3, 2, 2])
1x1池化后 x_1x1 的形状: torch.Size([2, 3, 1, 1])
4x4池化展平后 x_4x4_flatten 的形状: torch.Size([2, 3, 16])
2x2池化展平后 x_2x2_flatten 的形状: torch.Size([2, 3, 4])
1x1池化展平后 x_1x1_flatten 的形状: torch.Size([2, 3, 1])
特征拼接后 (M=[1,2,4]) x_feature 的形状: torch.Size([2, 3, 21])
特征强度计算前 x_strength 形状: torch.Size([21, 2, 3])
特征强度计算后 x_strength 的形状: torch.Size([21, 2])

x_strength = torch.mean(x_strength, dim=2)

x_strength = torch.mean(x_strength, dim=2) 这行代码的作用是对张量 x_strength 的第三个维度(即 dim=2)进行平均操作。具体来说,它是在指定的维度上计算每个元素的平均值,从而减少该维度的大小。

详细解释:

  1. x_strength 的形状

    • 在这行代码之前,x_strength 的形状为 [feature_num, batch_size, channels],也就是通过 permute 操作后将原来的 [batch_size, channels, feature_num] 重新排列为 [feature_num, batch_size, channels]
    • 这意味着,x_strength 的第0个维度表示特征数量(来自多尺度池化的特征块),第1个维度表示批量大小,第2个维度表示通道数。
  2. torch.mean 操作

    • torch.mean(x_strength, dim=2) 表示在第2个维度(即 channels)上计算均值,这意味着对每个特征块在不同通道上的值取平均。
    • 计算后的结果不再包含通道维度,因此它会将形状从 [feature_num, batch_size, channels] 变为 [feature_num, batch_size]
    • 换句话说,对于每个特征块(feature_num),在 batch_size 上的每个样本中,通道上的值被压缩为一个平均值。

举例:

假设 x_strength 的形状为 [3, 2, 4],其中:

  • feature_num = 3(表示有3个特征块),
  • batch_size = 2(表示有2个样本),
  • channels = 4(表示每个特征块有4个通道)。
python 复制代码
x_strength = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],  # 第一个特征块
    [[2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],  # 第二个特征块
    [[3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]  # 第三个特征块
])

# 对dim=2(通道)取平均
x_strength_mean = torch.mean(x_strength, dim=2)

在这个例子中,x_strength_mean 将对每个特征块在4个通道上的值取平均:

python 复制代码
x_strength_mean = [
    [2.5, 6.5],  # 第一个特征块的通道均值
    [3.5, 7.5],  # 第二个特征块的通道均值
    [4.5, 8.5]   # 第三个特征块的通道均值
]

最终结果:

  • 经过 torch.mean(x_strength, dim=2) 后,x_strength 的形状从 [feature_num, batch_size, channels] 变为 [feature_num, batch_size],即每个特征块对每个样本生成了一个强度值,这个值是该特征块在所有通道上的平均强度。

可不可以说特征强度计算后的x_strength的3通道变成了1通道

是的,可以这样理解。

在计算特征强度的过程中,torch.mean(x_strength, dim=2) 对第3个维度(即通道维度)进行求平均操作,将原本具有多个通道(如3通道或其他通道数)的特征信息压缩为1个值。这个过程相当于将多个通道的信息整合成一个值,表示这些通道的平均特征强度。

因此,可以说计算后的 x_strength 从多通道(例如3通道)变成了1通道,因为它只剩下一个平均强度值来代表原来的多个通道信息。

解释:

  1. 计算前 :原始的 x_strength 形状为 [feature_num, batch_size, channels],表示有 channels 个通道。
  2. 计算后 :经过 torch.mean(x_strength, dim=2) 操作后,channels 维度被压缩,输出形状为 [feature_num, batch_size],即每个特征块只有一个通道,这个通道的值是原来多个通道的平均值。

因此,通道数从多个变为1个,这个1个通道的值是对所有通道的强度的平均表示。

SPP后

代码

python 复制代码
import torch
import torch.nn as nn

# 定义SPP类
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))

        self.M = M
        print(f"初始化 M:{self.M}")

    def forward(self, x):
        print(f"输入 x 的形状: {x.shape}")
        
        x_4x4 = self.pooling_4x4(x)
        x_2x2 = self.pooling_2x2(x_4x4)
        x_1x1 = self.pooling_1x1(x_4x4)

        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)

        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
        else:
            raise NotImplementedError('ERROR M')

        x_strength = x_feature.permute((2, 0, 1))
        x_strength = torch.mean(x_strength, dim=2)

        return x_feature, x_strength


# 定义主网络,包含SPP和全连接层
class NetWithSPP(nn.Module):
    def __init__(self, M=None, num_classes=1000):
        super(NetWithSPP, self).__init__()
        self.spp = SPP(M)
        self.fc = nn.Linear(3, num_classes)  # 将输入维度修改为3,匹配通道数

    def forward(self, feat4):
        # 从 SPP 获取多尺度特征
        x_spp, x_strength = self.spp(feat4)
        print(f"x_spp 的形状: {x_spp.shape}")
        
        # 调整 x_spp 的维度
        x_spp = x_spp.permute((2, 0, 1))
        print(f"维度转换后 x_spp 的形状: {x_spp.shape}")
        
        # 获取维度大小
        m, b, c = x_spp.shape[0], x_spp.shape[1], x_spp.shape[2]
        print(f"m (feature_num): {m}, b (batch_size): {b}, c (channels): {c}")
        
        # 展平 x_spp 以便输入全连接层
        x_spp = torch.reshape(x_spp, (m * b, c))
        print(f"展平后的 x_spp 的形状: {x_spp.shape}")
        
        # 通过全连接层生成 patch_score
        patch_score = self.fc(x_spp)
        print(f"通过全连接层后的 patch_score 的形状: {patch_score.shape}")
        
        # 将 patch_score 重新调整形状
        patch_score = torch.reshape(patch_score, (m, b, 1000))
        print(f"重新调整形状后的 patch_score 的形状: {patch_score.shape}")
        
        # 最后 permute,恢复到 [batch_size, 1000, feature_num]
        patch_score = patch_score.permute((1, 2, 0))
        print(f"最后 permute 后 patch_score 的形状: {patch_score.shape}")

        return patch_score


# 创建网络实例
M = '[1,2,4]'  # 设置 M 为 '[1,2,4]'
net = NetWithSPP(M=M, num_classes=1000)

# 输入一个示例张量,形状为 [batch_size, channels, height, width]
input_tensor = torch.randn(2, 3, 16, 16)  # 假设 batch_size=2, channels=3, height=16, width=16

# 前向传播并打印每一步结果
patch_score = net(input_tensor)

运行结果

python 复制代码
输入 x 的形状: torch.Size([2, 3, 16, 16])
x_spp 的形状: torch.Size([2, 3, 21])
维度转换后 x_spp 的形状: torch.Size([21, 2, 3])
m (feature_num): 21, b (batch_size): 2, c (channels): 3
展平后的 x_spp 的形状: torch.Size([42, 3])
通过全连接层后的 patch_score 的形状: torch.Size([42, 1000])
重新调整形状后的 patch_score 的形状: torch.Size([21, 2, 1000])
最后 permute 后 patch_score 的形状: torch.Size([2, 1000, 21])
相关推荐
weixin_30777913几秒前
PyTorch基本功能与实现代码
人工智能·pytorch
通信.萌新1 小时前
OpenCV边沿检测(Python版)
人工智能·python·opencv
ARM+FPGA+AI工业主板定制专家1 小时前
基于RK3576/RK3588+FPGA+AI深度学习的轨道异物检测技术研究
人工智能·深度学习
Bran_Liu1 小时前
【LeetCode 刷题】字符串-字符串匹配(KMP)
python·算法·leetcode
weixin_307779131 小时前
分析一个深度学习项目并设计算法和用PyTorch实现的方法和步骤
人工智能·pytorch·python
Channing Lewis2 小时前
flask实现重启后需要重新输入用户名而避免浏览器使用之前已经记录的用户名
后端·python·flask
Channing Lewis2 小时前
如何在 Flask 中实现用户认证?
后端·python·flask
水银嘻嘻2 小时前
【Mac】Python相关知识经验
开发语言·python·macos
小猪咪piggy2 小时前
【深度学习入门】深度学习知识点总结
人工智能·深度学习
汤姆和佩琦2 小时前
2025-1-20-sklearn学习(42) 使用scikit-learn计算 钿车罗帕,相逢处,自有暗尘随马。
人工智能·python·学习·机器学习·scikit-learn·sklearn