改进系列(6):基于DenseNet网络添加TripletAttention注意力层实现的番茄病害图像分类

目录

[1. DenseNet 介绍](#1. DenseNet 介绍)

[2. TripletAttention](#2. TripletAttention)

[3. DenseNet + TripletAttention](#3. DenseNet + TripletAttention)

[4. 番茄场景病害病虫识别](#4. 番茄场景病害病虫识别)

[4.1 数据集情况](#4.1 数据集情况)

[4.2 训练](#4.2 训练)

[4.3 训练结果](#4.3 训练结果)

[4.4 推理](#4.4 推理)


1. DenseNet 介绍

DenseNet是一种深度学习架构,卷积神经网络(CNN)的一种变体,旨在解决梯度消失的问题并提高网络连接性。

在传统的CNN中,信息流是顺序的,每一层只连接到下一层。这可能会导致梯度在网络中传播时减小,从而难以训练深度网络。DenseNet旨在通过引入密集连接来缓解这一问题,密集连接允许从网络中的任何层直接连接到任何其他层。

DenseNet由多个密集块组成,每个密集块包含多个层。密集块内的每一层都连接到同一块内的其他每一层。这种密集的连接促进了特征重用和信息流,使梯度更容易在整个网络中传播。此外,DenseNet在每个密集块后都加入了一个过渡层,以降低特征图的维度并控制网络的增长。

DenseNet的主要优势包括:

改进的梯度流:层之间的直接连接有助于克服梯度消失问题,并实现深度网络的高效训练。

强大的特征重用:密集的连接促进了特征重用,从而实现了更紧凑的网络和更好的参数效率。

参数数量减少:与传统的CNN架构相比,DenseNet通常需要更少的参数,从而使模型更容易训练,计算效率更高。

提高精度:DenseNet已被证明在各种计算机视觉任务上达到了最先进的性能,如图像分类和物体检测。

总体而言,DenseNet是一个强大的深度学习架构,可以解决训练深度网络的挑战。其密集的连接性和高效的参数共享使其成为各种计算机视觉任务的有效选择。

其中,denseNet不同版本的架构如下

2. TripletAttention

TripletPattention是一种用于自然语言处理任务的注意力机制,特别是在问答和机器翻译中。它是传统自我注意机制的延伸,也称为变压器注意。

TripletAttention机制旨在通过结合有关输入序列的额外信息来增强注意力机制。它不仅使用查询、键和值向量,还包括三种不同类型的向量:锚向量、正向量和负向量。

在问答的上下文中,锚向量表示问题,正向量表示包含答案的段落,负向量表示不包含答案的随机段落。通过包含正向量和负向量,TripletAttention机制学会关注文章中有助于回答问题的相关信息,并忽略负面文章中的无关信息。

TripletPattention机制可用于计算文章中每个单词的注意力得分,表明其在答案预测中的重要性。然后,这些注意力得分用于对值向量进行加权,并生成一个上下文向量,即值向量的加权和。

总的来说,TripletAttention是一种强大的注意力机制,它利用额外的信息来提高神经网络在问答和其他自然语言处理任务中的性能。

python实现的代码如下:

python 复制代码
class TripletAttention(nn.Module):
    def __init__(self, in_channels):
        super(TripletAttention, self).__init__()
        self.conv_query = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.conv_key = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.conv_value = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        query = self.conv_query(x)
        key = self.conv_key(x)
        value = self.conv_value(x)

        query = query.view(query.size(0), -1, query.size(2) * query.size(3))
        key = key.view(key.size(0), -1, key.size(2) * key.size(3))
        value = value.view(value.size(0), -1, value.size(2) * value.size(3))

        attention_map = torch.bmm(query.permute(0, 2, 1), key)
        attention_map = F.softmax(attention_map, dim=-1)

        attended_features = torch.bmm(value, attention_map.permute(0, 2, 1))
        attended_features = attended_features.view(x.size())

        return attended_features

3. DenseNet + TripletAttention

本人添加的位置在DenseNet每个transition后

如下:其中红色区域就是添加的部分

4. 番茄场景病害病虫识别

完整项目下载:DenseNet图像分类改进【添加TripletAttention三重注意力机制模块】:番茄病害识别资源-CSDN文库

4.1 数据集情况

总共有8类别,分别放在不同的目录下,训练集有496张图片,验证集有121张图片

标签类别如下:

python 复制代码
{
    "0": "BA",
    "1": "HA",
    "2": "MP",
    "3": "SE",
    "4": "SL",
    "5": "TP",
    "6": "TU",
    "7": "ZC"
}

可视化结果:

4.2 训练

这里训练了100个epoch,参数如下:

python 复制代码
{
    "train parameters": {
        "model": "densenet121",
        "pretrained": true,
        "freeze_layers": true,
        "batch_size": 16,
        "epochs": 100,
        "optim": "SGD",
        "lr": 0.001,
        "lrf": 0.001
    },
    "model": {
        "total parameters": 7393256.0,
        "train parameters": 439400,
        "flops": 2944160256.0

想要更改训练超参数的可以在train脚本更改

4.3 训练结果

这里最后一轮的指标如下:

python 复制代码
    "epoch:99": {
        "train info": {
            "accuracy": 0.8938775510021657,
            "BA": {
                "Precision": 0.8148,
                "Recall": 0.9167,
                "Specificity": 0.9774,
                "F1 score": 0.8628
            },
            "HA": {
                "Precision": 0.8947,
                "Recall": 0.9659,
                "Specificity": 0.9751,
                "F1 score": 0.9289
            },
            "MP": {
                "Precision": 0.902,
                "Recall": 0.8846,
                "Specificity": 0.9741,
                "F1 score": 0.8932
            },
            "SE": {
                "Precision": 0.8596,
                "Recall": 0.8167,
                "Specificity": 0.9814,
                "F1 score": 0.8376
            },
            "SL": {
                "Precision": 0.9718,
                "Recall": 0.8846,
                "Specificity": 0.9951,
                "F1 score": 0.9262
            },
            "TP": {
                "Precision": 0.8333,
                "Recall": 0.7895,
                "Specificity": 0.9936,
                "F1 score": 0.8108
            },
            "TU": {
                "Precision": 0.8814,
                "Recall": 0.8814,
                "Specificity": 0.9838,
                "F1 score": 0.8814
            },
            "ZC": {
                "Precision": 0.9412,
                "Recall": 0.9412,
                "Specificity": 0.9956,
                "F1 score": 0.9412
            },
            "mean precision": 0.8873500000000001,
            "mean recall": 0.8850750000000001,
            "mean specificity": 0.9845124999999999,
            "mean f1 score": 0.8852625
        },
        "valid info": {
            "accuracy": 0.529411764661394,
            "BA": {
                "Precision": 0.4211,
                "Recall": 0.6667,
                "Specificity": 0.8972,
                "F1 score": 0.5162
            },
            "HA": {
                "Precision": 0.4545,
                "Recall": 0.5,
                "Specificity": 0.8788,
                "F1 score": 0.4762
            },
            "MP": {
                "Precision": 0.5385,
                "Recall": 0.5385,
                "Specificity": 0.871,
                "F1 score": 0.5385
            },
            "SE": {
                "Precision": 0.25,
                "Recall": 0.2,
                "Specificity": 0.9135,
                "F1 score": 0.2222
            },
            "SL": {
                "Precision": 0.7059,
                "Recall": 0.6667,
                "Specificity": 0.9505,
                "F1 score": 0.6857
            },
            "TP": {
                "Precision": 0.3333,
                "Recall": 0.2,
                "Specificity": 0.9825,
                "F1 score": 0.25
            },
            "TU": {
                "Precision": 0.6923,
                "Recall": 0.6,
                "Specificity": 0.9615,
                "F1 score": 0.6429
            },
            "ZC": {
                "Precision": 0.8571,
                "Recall": 0.75,
                "Specificity": 0.991,
                "F1 score": 0.8
            },
            "mean precision": 0.5315875,
            "mean recall": 0.5152375,
            "mean specificity": 0.93075,
            "mean f1 score": 0.5164624999999999
        }
    }

曲线图:

混淆矩阵:

4.4 推理

推理结果如下:

想要更换数据集训练的话,参考readme文件即可

相关推荐
AI科研技术派18 分钟前
颠覆LSTM!贝叶斯优化+LSTM+时序预测=Nature子刊!
人工智能·rnn·lstm
陌繁23 分钟前
LeetCode1.两数之和(超简单讲解)
数据结构·算法·leetcode
熬夜学编程的小王32 分钟前
【优选算法篇】前缀和与哈希表的完美结合:掌握子数组问题的关键(下篇)
数据结构·c++·算法·前缀和·蓝桥杯
X在敲AI代码33 分钟前
力扣刷题D1
算法·leetcode·职场和发展
m0_6949380139 分钟前
Leetcode打卡:最近的房间
算法·leetcode·职场和发展
cloud___fly39 分钟前
力扣hot100——子串
算法·leetcode·哈希算法
m0_6760995842 分钟前
OpenCV图片矫正
人工智能·opencv·计算机视觉
无水先生1 小时前
掌握特征提取:机器学习中的 PCA、t-SNE 和 LDA模型
人工智能·机器学习
菜鸟起航ing1 小时前
数据结构---图(Graph)
java·数据结构·算法·深度优先
qwe3526331 小时前
open cv学习之图片添加水印
人工智能·学习·计算机视觉