改进系列(3):基于ResNet网络与CBAM模块融合实现的生活垃圾分类

目录

[1. ResNet介绍](#1. ResNet介绍)

[2. CBAM 模块](#2. CBAM 模块)

[3. resnet + cbam](#3. resnet + cbam)

[3.1 添加在每个layer层后](#3.1 添加在每个layer层后)

[3.2 关于训练的建议](#3.2 关于训练的建议)

[4. 垃圾分类实战](#4. 垃圾分类实战)

[4.1 数据集](#4.1 数据集)

[4.2 训练](#4.2 训练)

[4.3 最好的权重](#4.3 最好的权重)

[4.4 推理](#4.4 推理)

[5. 其它](#5. 其它)


1. ResNet介绍

ResNet(残差网络)是一种深度卷积神经网络模型,由Kaiming He等人于2015年提出。它的提出解决了深度神经网络的梯度消失和梯度爆炸问题,使得深层网络的训练变得更加容易和有效。

在深度神经网络中,随着网络层数的增加,梯度在反向传播过程中逐渐变小,导致网络的训练变得困难。这是因为在传统的网络结构中,每个网络层都是通过直接逐层堆叠来进行信息的传递。当网络层数增加时,信息的传递路径变得更长,导致梯度逐渐消失。为了解决这个问题,ResNet提出了"残差学习"的概念。

ResNet引入了"残差块"(residual block)的概念,其中每个残差块包含一个跳跃连接(skip connection),将输入直接添加到输出中。这个跳跃连接允许梯度直接通过残差块传递,避免了梯度的消失问题。通过残差块的堆叠,ResNet可以构建非常深的网络,如ResNet-50、ResNet-101等。

ResNet的提出极大地促进了深度神经网络的发展。它在多个视觉任务上取得了非常好的性能,成为了目标检测、图像分类、图像分割等领域的重要基准模型。同时,ResNet的思想也影响了后续的深度神经网络架构设计,被广泛应用于各种深度学习任务中。

Resnet 太经典了,以至于后面基本上都采用了'残差块'的结构,关于resnet网络的详细介绍,可以百度,或者参考本人之前的博文

2. CBAM 模块

CBAM代表卷积块注意力模块。它是一种神经网络架构,增强了卷积神经网络(CNN)中的注意力机制。CBAM模块由两个注意力机制组成,即通道注意力模块(CAM)和空间注意力模块(SAM)。

通道注意力模块侧重于通过计算通道统计数据并将这些统计数据的加权和应用于每个通道来捕获通道之间的全局依赖关系。这使得网络能够强调重要渠道,抑制无关渠道。

另一方面,空间注意力模块捕捉图像中空间位置之间的依赖关系。它通过考虑每个空间位置相对于其相邻位置的重要性来计算空间注意力图。这有助于网络关注重要的空间区域,忽略信息量较小的区域。

通过结合信道注意力模块和空间注意力模块,CBAM模块使网络能够动态调整CNN内的注意力。这提高了图像分类、对象检测和语义分割等任务的性能。

python的代码实现如下:

python 复制代码
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // reduction, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // reduction, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, in_planes, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, reduction)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        out = x * self.ca(x)
        out = out * self.sa(out)
        return out

3. resnet + cbam

resnet不同版本的结构如下

其实都是5层的结构,不过每层的相同模块重复多少次的区别

对于浅层(18,34)的残差块,可以看到没有1*1的卷积,因为网络并不大。而后面的50、101、152因为通道数很多,导致网络参数量很大,所以使用了1*1进行降维

添加模块很简单,因为CBAM模块的输入维度和输出维度是一样的,这样就可以在resnet网络的任意位置进行添加。

就比如vgg都是3*3连续两次的卷积,就是因为通道一样,那我想要连续卷积3次,甚至4次五次也只是重复叠加就行了,就类似于这个CMAB模块

事实上,通道不一样也可以添加,不过要更改维度稍微麻烦一点罢了

3.1 添加在每个layer层后

关键代码如下

python 复制代码
    net.layer1.add_module('cbam',CBAM(net.layer1[-1].conv2.out_channels))
    net.layer2.add_module('cbam',CBAM(net.layer2[-1].conv2.out_channels))
    net.layer3.add_module('cbam',CBAM(net.layer3[-1].conv2.out_channels))
    net.layer4.add_module('cbam',CBAM(net.layer4[-1].conv2.out_channels))

这里的net就是resnet18 ~ resnet152网络,如果想要在单独某个层后添加,将对应的layerx注释即可,网络仍然可以正常运行

因为CBAM模块需要知道上一层传入的通道数是多少,这里经过测试发现,不同版本resnet的不同layer的最后模块都是下面这样,那么将它取出,传给resnet就行了

(1): BasicBlock(

(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

(relu): ReLU(inplace=True)

(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

添加后的网络结构如下:

3.2 关于训练的建议

添加完模块后,网络的结构就被破坏了,这样官方的预训练模型和改进后的模型无法匹配,就没法使用,这里有两种解决办法

  1. 载入模型的时候有个strick参数,设置为False采用不完全覆盖即可

  2. 先载入网络的预训练权重,然后在更改网络,这里本人使用的是第二种

不过这不是重点,迁移学习的好处谁都知道。针对于不同任务,虽然特征不同,但我们可以肯定在底层的语义特征其实是类似的(边缘、亮度啥的),不同的仅仅是高级的语义特征罢了,所以没必要从头进行学习

而我们添加的模块是没有权重的,所以CBAM模块必须要重头训练,这也是为啥有人认为加了模块为什么精度还降低的原因。要么模块不适用某个数据集,要么没有训练到位

python 复制代码
    # 是否冻结权重
    if args.freeze_layers:
        for name, para in net.named_parameters():
            para.requires_grad_(False)
            if "fc" in name:
                para.requires_grad_(True)
            if "cbam" in name:
                para.requires_grad_(True)

这里的cbam就是之前添加CBAM模块的名字

可以发现,冻结后的训练参数明显少了很多

4. 垃圾分类实战

下载:Resnet改进(resnet18、resnet50等)在每个layer后加入CBAM模块实战:生活垃圾识别资源-CSDN文库

Tips:项目已经封装好,参考readme的数据集摆放好就可以训练了

4.1 数据集

生活常见垃圾分类,类别如下(代码自动生成

复制代码
{
    "0": "battery",
    "1": "biological",
    "2": "brown_glass",
    "3": "cardboard",
    "4": "clothes",
    "5": "green_glass",
    "6": "metal",
    "7": "paper",
    "8": "plastic",
    "9": "shoes",
    "10": "trash",
    "11": "white_glass"
}

可视化结果:

样本量:

4.2 训练

这里把CBAM加在了每一个layer后面,想要作对比实验可以自己注释掉特定层,或者全部去掉就是最初的resnet网络,这里利用resnet18+CBAM训练了100轮。

生成的结果如下:

python 复制代码
Namespace(model='resnet18', pretrained=True, freeze_layers=True, batch_size=8, epochs=100, optim='SGD', lr=0.001, lrf=0.0001)
Using device is:  cuda
Using dataloader workers is : 8
trainSet number is : 12415 valSet number is : 3100
model output is : 12
resnet version :  resnet18
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (cbam): CBAM(
      (ca): ChannelAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (max_pool): AdaptiveMaxPool2d(output_size=1)
        (fc1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu1): ReLU()
        (fc2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (sigmoid): Sigmoid()
      )
      (sa): SpatialAttention(
        (conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (sigmoid): Sigmoid()
      )
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (cbam): CBAM(
      (ca): ChannelAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (max_pool): AdaptiveMaxPool2d(output_size=1)
        (fc1): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu1): ReLU()
        (fc2): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (sigmoid): Sigmoid()
      )
      (sa): SpatialAttention(
        (conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (sigmoid): Sigmoid()
      )
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (cbam): CBAM(
      (ca): ChannelAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (max_pool): AdaptiveMaxPool2d(output_size=1)
        (fc1): Conv2d(256, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu1): ReLU()
        (fc2): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (sigmoid): Sigmoid()
      )
      (sa): SpatialAttention(
        (conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (sigmoid): Sigmoid()
      )
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (cbam): CBAM(
      (ca): ChannelAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (max_pool): AdaptiveMaxPool2d(output_size=1)
        (fc1): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu1): ReLU()
        (fc2): Conv2d(32, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (sigmoid): Sigmoid()
      )
      (sa): SpatialAttention(
        (conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (sigmoid): Sigmoid()
      )
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=12, bias=True)
)
Total parameters is:11.23 M
Train parameters is:50068 
Flops:1824.40 M 
use optim is :  SGD

训练日志:

python 复制代码
    "train parameters": {
        "model": "resnet18",
        "pretrained": true,
        "freeze_layers": true,
        "batch_size": 16,
        "epochs": 100,
        "optim": "SGD",
        "lr": 0.01,
        "lrf": 0.001
    },
    "Datasets": {
        "trainSets number": 12415,
        "validSets number": 3100
    },
    "model": {
        "total parameters": 11226580.0,
        "train parameters": 49676,
        "flops": 1824400426.0
    },
python 复制代码
    "epoch:99": {
        "train info": {
            "accuracy": 0.9484494563020954,
            "battery": {
                "Precision": 0.9493,
                "Recall": 0.9656,
                "Specificity": 0.9967,
                "F1 score": 0.9574
            },
            "biological": {
                "Precision": 0.9662,
                "Recall": 0.9797,
                "Specificity": 0.9977,
                "F1 score": 0.9729
            },
            "brown_glass": {
                "Precision": 0.9312,
                "Recall": 0.9465,
                "Specificity": 0.9971,
                "F1 score": 0.9388
            },
            "cardboard": {
                "Precision": 0.9516,
                "Recall": 0.9369,
                "Specificity": 0.9971,
                "F1 score": 0.9442
            },
            "clothes": {
                "Precision": 0.9852,
                "Recall": 0.9843,
                "Specificity": 0.9923,
                "F1 score": 0.9847
            },
            "green_glass": {
                "Precision": 0.9704,
                "Recall": 0.9762,
                "Specificity": 0.9987,
                "F1 score": 0.9733
            },
            "metal": {
                "Precision": 0.8776,
                "Recall": 0.8961,
                "Specificity": 0.9935,
                "F1 score": 0.8868
            },
            "paper": {
                "Precision": 0.9167,
                "Recall": 0.9298,
                "Specificity": 0.9939,
                "F1 score": 0.9232
            },
            "plastic": {
                "Precision": 0.8656,
                "Recall": 0.8194,
                "Specificity": 0.9925,
                "F1 score": 0.8419
            },
            "shoes": {
                "Precision": 0.9525,
                "Recall": 0.9507,
                "Specificity": 0.9931,
                "F1 score": 0.9516
            },
            "trash": {
                "Precision": 0.9177,
                "Recall": 0.9391,
                "Specificity": 0.996,
                "F1 score": 0.9283
            },
            "white_glass": {
                "Precision": 0.8837,
                "Recall": 0.8581,
                "Specificity": 0.9941,
                "F1 score": 0.8707
            },
            "mean precision": 0.9306416666666667,
            "mean recall": 0.9318666666666666,
            "mean specificity": 0.995225,
            "mean f1 score": 0.93115
        },
        "valid info": {
            "accuracy": 0.6038709677399875,
            "battery": {
                "Precision": 0.5551,
                "Recall": 0.6667,
                "Specificity": 0.9653,
                "F1 score": 0.6058
            },
            "biological": {
                "Precision": 0.7278,
                "Recall": 0.665,
                "Specificity": 0.9831,
                "F1 score": 0.695
            },
            "brown_glass": {
                "Precision": 0.5816,
                "Recall": 0.6777,
                "Specificity": 0.9802,
                "F1 score": 0.626
            },
            "cardboard": {
                "Precision": 0.5762,
                "Recall": 0.6798,
                "Specificity": 0.9695,
                "F1 score": 0.6237
            },
            "clothes": {
                "Precision": 0.7115,
                "Recall": 0.5859,
                "Specificity": 0.8757,
                "F1 score": 0.6426
            },
            "green_glass": {
                "Precision": 0.725,
                "Recall": 0.696,
                "Specificity": 0.9889,
                "F1 score": 0.7102
            },
            "metal": {
                "Precision": 0.4346,
                "Recall": 0.5425,
                "Specificity": 0.9634,
                "F1 score": 0.4826
            },
            "paper": {
                "Precision": 0.5596,
                "Recall": 0.581,
                "Specificity": 0.9668,
                "F1 score": 0.5701
            },
            "plastic": {
                "Precision": 0.4469,
                "Recall": 0.4624,
                "Specificity": 0.9662,
                "F1 score": 0.4545
            },
            "shoes": {
                "Precision": 0.5217,
                "Recall": 0.5772,
                "Specificity": 0.9227,
                "F1 score": 0.548
            },
            "trash": {
                "Precision": 0.6185,
                "Recall": 0.7698,
                "Specificity": 0.9777,
                "F1 score": 0.6859
            },
            "white_glass": {
                "Precision": 0.551,
                "Recall": 0.5226,
                "Specificity": 0.9776,
                "F1 score": 0.5364
            },
            "mean precision": 0.584125,
            "mean recall": 0.6188833333333333,
            "mean specificity": 0.9614250000000002,
            "mean f1 score": 0.5983999999999999
        }
    }

如果想要绘制pr曲线、auc啥的,根据上述的json文件绘制即可

各类曲线:

混淆矩阵:

4.3 最好的权重

训练结束后,会打印在验证集上最好的一轮结果

python 复制代码
best epoch: 42
train performance: {
    "accuracy": 0.9386226339098359,
    "battery": {
        "Precision": 0.9358,
        "Recall": 0.9444,
        "Specificity": 0.9958,
        "F1 score": 0.9401
    },
    "biological": {
        "Precision": 0.9722,
        "Recall": 0.9746,
        "Specificity": 0.9981,
        "F1 score": 0.9734
    },
    "brown_glass": {
        "Precision": 0.9242,
        "Recall": 0.928,
        "Specificity": 0.9969,
        "F1 score": 0.9261
    },
    "cardboard": {
        "Precision": 0.9363,
        "Recall": 0.9271,
        "Specificity": 0.9962,
        "F1 score": 0.9317
    },
    "clothes": {
        "Precision": 0.9822,
        "Recall": 0.9847,
        "Specificity": 0.9907,
        "F1 score": 0.9834
    },
    "green_glass": {
        "Precision": 0.9473,
        "Recall": 0.9623,
        "Specificity": 0.9977,
        "F1 score": 0.9547
    },
    "metal": {
        "Precision": 0.8468,
        "Recall": 0.8701,
        "Specificity": 0.9918,
        "F1 score": 0.8583
    },
    "paper": {
        "Precision": 0.9035,
        "Recall": 0.9143,
        "Specificity": 0.9929,
        "F1 score": 0.9089
    },
    "plastic": {
        "Precision": 0.8464,
        "Recall": 0.7962,
        "Specificity": 0.9915,
        "F1 score": 0.8205
    },
    "shoes": {
        "Precision": 0.9454,
        "Recall": 0.9418,
        "Specificity": 0.9921,
        "F1 score": 0.9436
    },
    "trash": {
        "Precision": 0.9279,
        "Recall": 0.9229,
        "Specificity": 0.9966,
        "F1 score": 0.9254
    },
    "white_glass": {
        "Precision": 0.8371,
        "Recall": 0.8371,
        "Specificity": 0.9914,
        "F1 score": 0.8371
    },
    "mean precision": 0.9170916666666665,
    "mean recall": 0.9169583333333334,
    "mean specificity": 0.9943083333333335,
    "mean f1 score": 0.9169333333333333
}
valid performance: {
    "accuracy": 0.6090322580625516,
    "battery": {
        "Precision": 0.5484,
        "Recall": 0.7196,
        "Specificity": 0.9615,
        "F1 score": 0.6224
    },
    "biological": {
        "Precision": 0.7471,
        "Recall": 0.6599,
        "Specificity": 0.9848,
        "F1 score": 0.7008
    },
    "brown_glass": {
        "Precision": 0.6,
        "Recall": 0.6942,
        "Specificity": 0.9812,
        "F1 score": 0.6437
    },
    "cardboard": {
        "Precision": 0.581,
        "Recall": 0.5843,
        "Specificity": 0.9743,
        "F1 score": 0.5826
    },
    "clothes": {
        "Precision": 0.6558,
        "Recall": 0.6638,
        "Specificity": 0.8177,
        "F1 score": 0.6598
    },
    "green_glass": {
        "Precision": 0.7222,
        "Recall": 0.728,
        "Specificity": 0.9882,
        "F1 score": 0.7251
    },
    "metal": {
        "Precision": 0.4602,
        "Recall": 0.5294,
        "Specificity": 0.9678,
        "F1 score": 0.4924
    },
    "paper": {
        "Precision": 0.5678,
        "Recall": 0.5381,
        "Specificity": 0.9702,
        "F1 score": 0.5526
    },
    "plastic": {
        "Precision": 0.4768,
        "Recall": 0.4162,
        "Specificity": 0.973,
        "F1 score": 0.4444
    },
    "shoes": {
        "Precision": 0.5871,
        "Recall": 0.4608,
        "Specificity": 0.9527,
        "F1 score": 0.5163
    },
    "trash": {
        "Precision": 0.6243,
        "Recall": 0.777,
        "Specificity": 0.978,
        "F1 score": 0.6923
    },
    "white_glass": {
        "Precision": 0.5479,
        "Recall": 0.5161,
        "Specificity": 0.9776,
        "F1 score": 0.5315
    },
    "mean precision": 0.5932166666666667,
    "mean recall": 0.6072833333333333,
    "mean specificity": 0.9605833333333335,
    "mean f1 score": 0.5969916666666667
}

4.4 推理

推理的时候,放在指定目录即可对图片进行批推理

5. 其它

更多的CNN图像分类、语义分割关注本专栏,将持续更新

融合其他模块应该和CBAM模块差不多,思路一样的。对于不同网络,利用vgg、densenet、transformer之类的也差不多。同时语义分割的unet也是连续的3*3卷积,通道数也不变,加入CBAM模块也是类似的。

关于本项目,CBAM也可以在layer内部进行添加,不过我搜了资料,加在layer后是效果最棒的。如果想要发文章,最好多试试其他模块,然后多炼丹才好

相关推荐
红色的山茶花19 分钟前
YOLOv9-0.1部分代码阅读笔记-loss_tal.py
笔记·深度学习·yolo
小蜗牛慢慢爬行33 分钟前
有关异步场景的 10 大 Spring Boot 面试问题
java·开发语言·网络·spring boot·后端·spring·面试
MARIN_shen39 分钟前
Marin说PCB之POC电路layout设计仿真案例---06
网络·单片机·嵌入式硬件·硬件工程·pcb工艺
一位小说男主1 小时前
编码器与解码器:从‘乱码’到‘通话’
人工智能·深度学习
m0_748240021 小时前
Chromium 中chrome.webRequest扩展接口定义c++
网络·c++·chrome
終不似少年遊*1 小时前
华为云计算HCIE笔记05
网络·华为云·云计算·学习笔记·hcie·认证·hcs
蜜獾云2 小时前
docker 安装雷池WAF防火墙 守护Web服务器
linux·运维·服务器·网络·网络安全·docker·容器
qq_529025292 小时前
Torch.gather
python·深度学习·机器学习