Scale Decoupled Distillation 论文中SPP发生了什么

SPP

python 复制代码
import torch
import torch.nn as nn

# 定义SPP类
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))

        self.M = M
        print(f"初始化 M:{self.M}")

    def forward(self, x):
        print(f"输入 x 的形状: {x.shape}")
        
        # 进行4x4池化
        x_4x4 = self.pooling_4x4(x)
        print(f"4x4池化后 x_4x4 的形状: {x_4x4.shape}")
        
        # 进行2x2池化
        x_2x2 = self.pooling_2x2(x_4x4)
        print(f"2x2池化后 x_2x2 的形状: {x_2x2.shape}")
        
        # 进行1x1池化
        x_1x1 = self.pooling_1x1(x_4x4)
        print(f"1x1池化后 x_1x1 的形状: {x_1x1.shape}")

        # 展平特征图
        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        print(f"4x4池化展平后 x_4x4_flatten 的形状: {x_4x4_flatten.shape}")
        
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        print(f"2x2池化展平后 x_2x2_flatten 的形状: {x_2x2_flatten.shape}")
        
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)
        print(f"1x1池化展平后 x_1x1_flatten 的形状: {x_1x1_flatten.shape}")

        # 根据 M 值拼接特征
        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
            print(f"特征拼接后 (M=[1,2,4]) x_feature 的形状: {x_feature.shape}")
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
            print(f"特征拼接后 (M=[1,2]) x_feature 的形状: {x_feature.shape}")
        else:
            raise NotImplementedError('ERROR M')

        # 计算特征强度
        x_strength = x_feature.permute((2, 0, 1))
        print(f"特征强度计算前 x_strength 形状: {x_strength.shape}")
        
        x_strength = torch.mean(x_strength, dim=2)
        print(f"特征强度计算后 x_strength 的形状: {x_strength.shape}")

        return x_feature, x_strength


# 创建一个SPP模块实例
M = '[1,2,4]'  # 设置M为'[1,2,4]',拼接所有三个尺度
spp = SPP(M=M)

# 输入一个示例张量,形状为 [batch_size, channels, height, width]
input_tensor = torch.randn(2, 3, 16, 16)  # 假设 batch_size=2, channels=3, height=16, width=16

# 前向传播并打印每一步结果
x_feature, x_strength = spp(input_tensor)

运行结果

python 复制代码
输入 x 的形状: torch.Size([2, 3, 16, 16])
4x4池化后 x_4x4 的形状: torch.Size([2, 3, 4, 4])
2x2池化后 x_2x2 的形状: torch.Size([2, 3, 2, 2])
1x1池化后 x_1x1 的形状: torch.Size([2, 3, 1, 1])
4x4池化展平后 x_4x4_flatten 的形状: torch.Size([2, 3, 16])
2x2池化展平后 x_2x2_flatten 的形状: torch.Size([2, 3, 4])
1x1池化展平后 x_1x1_flatten 的形状: torch.Size([2, 3, 1])
特征拼接后 (M=[1,2,4]) x_feature 的形状: torch.Size([2, 3, 21])
特征强度计算前 x_strength 形状: torch.Size([21, 2, 3])
特征强度计算后 x_strength 的形状: torch.Size([21, 2])

x_strength = torch.mean(x_strength, dim=2)

x_strength = torch.mean(x_strength, dim=2) 这行代码的作用是对张量 x_strength 的第三个维度(即 dim=2)进行平均操作。具体来说,它是在指定的维度上计算每个元素的平均值,从而减少该维度的大小。

详细解释:

  1. x_strength 的形状

    • 在这行代码之前,x_strength 的形状为 [feature_num, batch_size, channels],也就是通过 permute 操作后将原来的 [batch_size, channels, feature_num] 重新排列为 [feature_num, batch_size, channels]
    • 这意味着,x_strength 的第0个维度表示特征数量(来自多尺度池化的特征块),第1个维度表示批量大小,第2个维度表示通道数。
  2. torch.mean 操作

    • torch.mean(x_strength, dim=2) 表示在第2个维度(即 channels)上计算均值,这意味着对每个特征块在不同通道上的值取平均。
    • 计算后的结果不再包含通道维度,因此它会将形状从 [feature_num, batch_size, channels] 变为 [feature_num, batch_size]
    • 换句话说,对于每个特征块(feature_num),在 batch_size 上的每个样本中,通道上的值被压缩为一个平均值。

举例:

假设 x_strength 的形状为 [3, 2, 4],其中:

  • feature_num = 3(表示有3个特征块),
  • batch_size = 2(表示有2个样本),
  • channels = 4(表示每个特征块有4个通道)。
python 复制代码
x_strength = torch.tensor([
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],  # 第一个特征块
    [[2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0, 9.0]],  # 第二个特征块
    [[3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0]]  # 第三个特征块
])

# 对dim=2(通道)取平均
x_strength_mean = torch.mean(x_strength, dim=2)

在这个例子中,x_strength_mean 将对每个特征块在4个通道上的值取平均:

python 复制代码
x_strength_mean = [
    [2.5, 6.5],  # 第一个特征块的通道均值
    [3.5, 7.5],  # 第二个特征块的通道均值
    [4.5, 8.5]   # 第三个特征块的通道均值
]

最终结果:

  • 经过 torch.mean(x_strength, dim=2) 后,x_strength 的形状从 [feature_num, batch_size, channels] 变为 [feature_num, batch_size],即每个特征块对每个样本生成了一个强度值,这个值是该特征块在所有通道上的平均强度。

可不可以说特征强度计算后的x_strength的3通道变成了1通道

是的,可以这样理解。

在计算特征强度的过程中,torch.mean(x_strength, dim=2) 对第3个维度(即通道维度)进行求平均操作,将原本具有多个通道(如3通道或其他通道数)的特征信息压缩为1个值。这个过程相当于将多个通道的信息整合成一个值,表示这些通道的平均特征强度。

因此,可以说计算后的 x_strength 从多通道(例如3通道)变成了1通道,因为它只剩下一个平均强度值来代表原来的多个通道信息。

解释:

  1. 计算前 :原始的 x_strength 形状为 [feature_num, batch_size, channels],表示有 channels 个通道。
  2. 计算后 :经过 torch.mean(x_strength, dim=2) 操作后,channels 维度被压缩,输出形状为 [feature_num, batch_size],即每个特征块只有一个通道,这个通道的值是原来多个通道的平均值。

因此,通道数从多个变为1个,这个1个通道的值是对所有通道的强度的平均表示。

SPP后

代码

python 复制代码
import torch
import torch.nn as nn

# 定义SPP类
class SPP(nn.Module):
    def __init__(self, M=None):
        super(SPP, self).__init__()
        self.pooling_4x4 = nn.AdaptiveAvgPool2d((4, 4))
        self.pooling_2x2 = nn.AdaptiveAvgPool2d((2, 2))
        self.pooling_1x1 = nn.AdaptiveAvgPool2d((1, 1))

        self.M = M
        print(f"初始化 M:{self.M}")

    def forward(self, x):
        print(f"输入 x 的形状: {x.shape}")
        
        x_4x4 = self.pooling_4x4(x)
        x_2x2 = self.pooling_2x2(x_4x4)
        x_1x1 = self.pooling_1x1(x_4x4)

        x_4x4_flatten = torch.flatten(x_4x4, start_dim=2, end_dim=3)
        x_2x2_flatten = torch.flatten(x_2x2, start_dim=2, end_dim=3)
        x_1x1_flatten = torch.flatten(x_1x1, start_dim=2, end_dim=3)

        if self.M == '[1,2,4]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten, x_4x4_flatten), dim=2)
        elif self.M == '[1,2]':
            x_feature = torch.cat((x_1x1_flatten, x_2x2_flatten), dim=2)
        else:
            raise NotImplementedError('ERROR M')

        x_strength = x_feature.permute((2, 0, 1))
        x_strength = torch.mean(x_strength, dim=2)

        return x_feature, x_strength


# 定义主网络,包含SPP和全连接层
class NetWithSPP(nn.Module):
    def __init__(self, M=None, num_classes=1000):
        super(NetWithSPP, self).__init__()
        self.spp = SPP(M)
        self.fc = nn.Linear(3, num_classes)  # 将输入维度修改为3,匹配通道数

    def forward(self, feat4):
        # 从 SPP 获取多尺度特征
        x_spp, x_strength = self.spp(feat4)
        print(f"x_spp 的形状: {x_spp.shape}")
        
        # 调整 x_spp 的维度
        x_spp = x_spp.permute((2, 0, 1))
        print(f"维度转换后 x_spp 的形状: {x_spp.shape}")
        
        # 获取维度大小
        m, b, c = x_spp.shape[0], x_spp.shape[1], x_spp.shape[2]
        print(f"m (feature_num): {m}, b (batch_size): {b}, c (channels): {c}")
        
        # 展平 x_spp 以便输入全连接层
        x_spp = torch.reshape(x_spp, (m * b, c))
        print(f"展平后的 x_spp 的形状: {x_spp.shape}")
        
        # 通过全连接层生成 patch_score
        patch_score = self.fc(x_spp)
        print(f"通过全连接层后的 patch_score 的形状: {patch_score.shape}")
        
        # 将 patch_score 重新调整形状
        patch_score = torch.reshape(patch_score, (m, b, 1000))
        print(f"重新调整形状后的 patch_score 的形状: {patch_score.shape}")
        
        # 最后 permute,恢复到 [batch_size, 1000, feature_num]
        patch_score = patch_score.permute((1, 2, 0))
        print(f"最后 permute 后 patch_score 的形状: {patch_score.shape}")

        return patch_score


# 创建网络实例
M = '[1,2,4]'  # 设置 M 为 '[1,2,4]'
net = NetWithSPP(M=M, num_classes=1000)

# 输入一个示例张量,形状为 [batch_size, channels, height, width]
input_tensor = torch.randn(2, 3, 16, 16)  # 假设 batch_size=2, channels=3, height=16, width=16

# 前向传播并打印每一步结果
patch_score = net(input_tensor)

运行结果

python 复制代码
输入 x 的形状: torch.Size([2, 3, 16, 16])
x_spp 的形状: torch.Size([2, 3, 21])
维度转换后 x_spp 的形状: torch.Size([21, 2, 3])
m (feature_num): 21, b (batch_size): 2, c (channels): 3
展平后的 x_spp 的形状: torch.Size([42, 3])
通过全连接层后的 patch_score 的形状: torch.Size([42, 1000])
重新调整形状后的 patch_score 的形状: torch.Size([21, 2, 1000])
最后 permute 后 patch_score 的形状: torch.Size([2, 1000, 21])
相关推荐
靴子学长2 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
梧桐树04293 小时前
python常用内建模块:collections
python
Dream_Snowar3 小时前
速通Python 第三节
开发语言·python
海棠AI实验室3 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
四口鲸鱼爱吃盐4 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类
蓝天星空4 小时前
Python调用open ai接口
人工智能·python
jasmine s4 小时前
Pandas
开发语言·python
郭wes代码4 小时前
Cmd命令大全(万字详细版)
python·算法·小程序
leaf_leaves_leaf5 小时前
win11用一条命令给anaconda环境安装GPU版本pytorch,并检查是否为GPU版本
人工智能·pytorch·python
夜雨飘零15 小时前
基于Pytorch实现的说话人日志(说话人分离)
人工智能·pytorch·python·声纹识别·说话人分离·说话人日志