基于深度学习的药用草本植物识别系统

🌿 基于深度学习的药用草本植物识别系统

使用 ResNet50 + 迁移学习实现 98 类 药用植物的高精度自动识别,Top-1 准确率 95.07%,Top-5 准确率 99.26%

如果需要全部的代码以及相关过程图与结果图,可以私信与我联系


📑 目录

  1. 项目背景与意义
  2. 整体工作流程
  3. 数据集选择与预处理
  4. 模型架构设计
  5. 训练策略与优化方法
  6. 实验结果与可视化分析
  7. 可视化图表详解
  8. 快速开始与复现指南
  9. 项目结构
  10. 总结与展望

一、项目背景与意义

1.1 问题背景

药用植物的精确识别在传统医学(如阿育吠陀、中医药)、药物研发、植物学研究和农业管理中具有重要价值。传统识别方法依赖植物学专家的经验和形态学比对,存在以下痛点:

  • 效率低:人工逐一比对耗时耗力
  • 主观性强:不同专家的判断可能存在差异
  • 覆盖范围有限:一名专家通常只能精通有限的植物种类
  • 教育成本高:培养合格的植物分类学家需要数年时间

深度学习技术的快速发展使得构建自动化的植物识别系统成为可能。通过计算机视觉模型,我们可以在秒级时间内完成一张植物图像的识别和分类。

1.2 项目目标

构建一个基于深度学习的药用植物自动识别系统:

  • 给定一张植物图像,系统自动判断该植物属于 98 种药用植物中的哪一种
  • 以置信度分数呈现 Top-5 预测结果,辅助专家决策
  • 通过 Grad-CAM 热力图展示模型的决策依据,提供可解释性

1.3 技术栈

组件 技术选型 说明
深度学习框架 PyTorch 2.11 灵活性最高,学术和工业界广泛使用
骨干网络 ResNet50 ImageNet 预训练,2460万参数
数据增强 RandAugment + MixUp + CutMix 三层级增强体系
优化器 AdamW 解耦权重衰减的自适应优化器
学习率调度 Cosine Warmup 先预热后余弦退火
损失函数 Label Smoothing CrossEntropy 平滑标签正则化 ε=0.1
混合精度 torch.amp (FP16) 显存节省40%,训练加速26%

二、整体工作流程

整个项目按照以下流程执行,每一步都有对应的脚本和配置文件:

复制代码
┌──────────────────────────────────────────────────────────────────────┐
│  Step 1: 环境配置                                                     │
│  ├── 安装 PyTorch 2.11 (CUDA 12.8) + 依赖包                           │
│  └── 验证 CUDA 可用性                          │
├──────────────────────────────────────────────────────────────────────┤
│  Step 2: 数据集获取                                                    │
│  ├── 使用 kagglehub 自动下载 Indian Medicinal Leaves Dataset           │
│  ├── 合并两个子集 → 10,813 张图像                                     │
│  └── 自动按 70%/15%/15% 分层划分训练/验证/测试集                       │
├──────────────────────────────────────────────────────────────────────┤
│  Step 3: 模型构建                                                      │
│  ├── 加载 ResNet50 (ImageNet 预训练权重)                               │
│  ├── 移除原始 1000 类分类头 → 构建自定义 98 类分类头                    │
│  └── 统计: 总参数 24,608,418,模型大小 93.9MB                         │
├──────────────────────────────────────────────────────────────────────┤
│  Step 4: 训练 (20 epochs, 全量微调)                                   │
│  ├── 数据增强: RandAugment + MixUp(α=0.2) + CutMix(α=1.0)            │
│  ├── 损失函数: Label Smoothing CrossEntropy (ε=0.1)                   │
│  ├── 优化器: AdamW (lr=1e-4, weight_decay=1e-4)                      │
│  ├── 学习率: Cosine Warmup (5 epoch 预热 → 余弦退火至 1e-6)           │
│  ├── 混合精度: torch.amp (FP16 训练,显存节省约 40%)                  │
├──────────────────────────────────────────────────────────────────────┤
│  Step 5: 模型评估                                                      │
│  ├── 测试集 1,623 张图像评估                                           │
│  ├── 计算 Top-1/Top-5 准确率、精确率、召回率、F1-Score                 │
│  ├── 生成 98 类分类报告 + 混淆矩阵 + 错误案例分析                       │
│  └── 常见混淆对统计                                                    │
├──────────────────────────────────────────────────────────────────────┤
│  Step 6: 可视化图表生成                                                 │
│  ├── Fig 1: 训练曲线 (Loss & Accuracy & LR Schedule)                   │
│  ├── Fig 2: 混淆矩阵 (8 类随机子集)                                   │
│  ├── Fig 3: Grad-CAM 热力图 (模型注意力可视化)                         │
│  ├── Fig 4: t-SNE 特征嵌入 (聚类可视化 + Silhouette Score)            │
│  ├── Fig 5: 数据增强效果展示                                           │
│  ├── Fig 6: 预测样例对比 (正确 vs 错误)                                │
│  └── Fig 7: 类别分布统计                                              │
└──────────────────────────────────────────────────────────────────────┘

三、数据集选择与预处理

3.1 数据来源

本项目使用 Kaggle 上的 Indian Medicinal Leaves Dataset ,通过 kagglehub 库自动下载。

该数据集由两个子集合并而成:

子集 说明
Medicinal Leaf dataset 原始药用植物叶片数据集
Medicinal plant dataset 补充的药用植物数据集

3.2 数据集统计

指标 数值
类别总数 98 类
图像总数 10,813 张
训练集 7,569 张 (70%)
验证集 1,621 张 (15%)
测试集 1,623 张 (15%)
每类最少 8 张
每类最多 187 张
每类平均 110.3 张

⚠️ 数据集特点:存在明显的类别不平衡问题(8~187 张/类),这对模型训练提出了挑战,也是后续使用 MixUp/CutMix 增强和小类过采样的原因之一。

3.3 数据集划分策略

我们使用 分层抽样(Stratified Sampling) 按 70%/15%/15% 的比例将每个类别的样本分配到训练集、验证集和测试集:

python 复制代码
from sklearn.model_selection import train_test_split

# Step 1: 划分训练集 (70%) 和临时集 (30%)
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
    image_paths, labels, train_size=0.7, stratify=labels, random_state=42
)

# Step 2: 划分验证集 (15%) 和测试集 (15%)
val_paths, test_paths, val_labels, test_labels = train_test_split(
    temp_paths, temp_labels, train_size=0.5, stratify=temp_labels, random_state=42
)

分层抽样的优势:确保每个类别在三个子集中都有等比例的代表,避免某一类别在测试集中完全没有样本的情况。

3.4 数据组织格式

数据集按照 PyTorch 的 ImageFolder 格式组织,每个类别一个文件夹:

复制代码
data/
├── Aloevera/          # 芦荟 (282 张)
│   ├── Aloevera_0000.jpg
│   ├── Aloevera_0001.jpg
│   └── ...
├── Neem/              # 印度楝 (278 张)
│   ├── Neem_0000.jpg
│   └── ...
├── Tulsi/             # 圣罗勒 (323 张)
├── Mint/              # 薄荷 (288 张)
├── Hibiscus/          # 木槿 (283 张)
└── ... (共 98 个类别文件夹)

3.5 部分类别展示

包含的药用植物涵盖印度传统医学(阿育吠陀)常用草药:

类别 (英文名) 中文参考 图像数量 药用价值
Aloevera 芦荟 282 皮肤修复、抗炎
Tulsi / Tulasi 圣罗勒 323 免疫增强、抗菌
Neem 印度楝/苦楝 278 抗菌、抗寄生虫
Mint 薄荷 288 消化促进、清凉止痒
Hibiscus 木槿 283 抗氧化、护发
Ashwagandha 南非醉茄/印度人参 191 抗压力、增强精力
Turmeric 姜黄 45 抗炎、抗氧化
Brahmi 婆罗米/假马齿苋 121 改善记忆、镇静

四、模型架构设计

4.1 整体架构

我们采用迁移学习的范式:以 ImageNet 预训练的 ResNet50 为骨干网络,替换其 1000 类分类头为自定义的 98 类分类头。

复制代码
输入图像 (224×224×3)
    │
    ▼
┌──────────────────────────────────────────┐
│              ResNet50 Backbone            │
│  ┌─────────────────────────────────────┐  │
│  │ Conv7×7 + BN + ReLU + MaxPool       │  │  ← 浅层: 通用边缘/纹理特征
│  ├─────────────────────────────────────┤  │
│  │ Layer1: 3× Bottleneck (256-d)       │  │  ← 中浅层: 局部形状
│  ├─────────────────────────────────────┤  │
│  │ Layer2: 4× Bottleneck (512-d)       │  │  ← 中层: 部件级语义
│  ├─────────────────────────────────────┤  │
│  │ Layer3: 6× Bottleneck (1024-d)      │  │  ← 中深层: 目标语义
│  ├─────────────────────────────────────┤  │
│  │ Layer4: 3× Bottleneck (2048-d)      │  │  ← 深层: 高级语义特征
│  ├─────────────────────────────────────┤  │
│  │ AdaptiveAvgPool2d → Flatten         │  │
│  └─────────────────────────────────────┘  │
│          输出: 2048-d 特征向量              │
└──────────────────────────────────────────┘
    │
    ▼
┌──────────────────────────────────────────┐
│         Custom Classification Head        │
│  ┌─────────────────────────────────────┐  │
│  │ Dropout (p=0.3)                     │  │  ← 正则化
│  │ Linear(2048 → 512)                  │  │  ← 降维
│  │ BatchNorm1d(512)                    │  │  ← 稳定训练
│  │ ReLU                                │  │  ← 非线性
│  │ Dropout (p=0.3)                     │  │  ← 正则化
│  │ Linear(512 → 98)                    │  │  ← 分类输出
│  └─────────────────────────────────────┘  │
└──────────────────────────────────────────┘
    │
    ▼
Softmax → 98 类概率分布

4.2 参数统计

模块 参数量 占比
ResNet50 Backbone 23,508,032 95.5%
Classification Head 1,100,386 4.5%
总计 24,608,418 100%
模型大小 (float32) 93.9 MB -

4.3 为什么选择 ResNet50?

优势 说明
成熟的预训练权重 在 ImageNet-1K(120万张、1000类)上训练充分,低级特征(边缘、纹理、形状)泛化能力极强
适中的参数量 24.6M 参数在 8GB 显存下可从容训练,batch_size 可达 64
残差连接机制 跳跃连接(Skip Connection)有效解决了深层网络的梯度消失/爆炸问题
优秀的迁移学习基础 大量研究表明 ResNet50 在细粒度分类任务上表现优异
广泛的框架支持 PyTorch、TensorFlow 等主流框架均提供官方实现和预训练权重

4.4 关键代码实现

模型构建的核心代码非常简洁:

python 复制代码
import torch.nn as nn
import torchvision.models as tv_models

class PlantRecognitionModel(nn.Module):
    def __init__(self, backbone_name="resnet50", num_classes=98, pretrained=True):
        super().__init__()
        # 加载 ImageNet 预训练的 ResNet50
        self.backbone = tv_models.resnet50(
            weights="IMAGENET1K_V1" if pretrained else None
        )
        # 移除原始分类层
        self.backbone.fc = nn.Identity()

        # 自定义分类头
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        features = self.backbone(x)          # (B, 2048)
        output = self.classifier(features)   # (B, num_classes)
        return output

五、训练策略与优化方法

本项目的训练策略经过精心设计,包含 6 大核心优化技术,这些技术的组合是实现 95%+ 准确率的关键。

5.1 三层级数据增强体系

数据增强是提升模型泛化能力的最有效手段。我们构建了三个层级的增强体系:

第一层:基础几何增强
python 复制代码
transforms.RandomResizedCrop(224, scale=(0.08, 1.0))   # 随机裁剪缩放
transforms.RandomHorizontalFlip(p=0.5)                   # 水平翻转
transforms.RandomRotation(30)                            # 随机旋转±30°
第二层:颜色增强
python 复制代码
transforms.ColorJitter(
    brightness=0.2,   # 亮度 ±20%
    contrast=0.2,     # 对比度 ±20%
    saturation=0.2,   # 饱和度 ±20%
    hue=0.1            # 色调 ±10%
)
transforms.RandAugment(num_ops=2, magnitude=9)  # 自动搜索最优增强策略
第三层:混合增强
方法 原理 优势
MixUp (α=0.2) 将两张图像按比例λ混合: x_mix = λ·x_i + (1-λ)·x_j,标签同时混合 软标签正则化,平滑决策边界
CutMix (α=1.0) 将图像A的一块区域替换为图像B的对应区域 鼓励模型关注局部判别性区域
切换策略 MixUp 40% + CutMix 40% + 标准训练 20% 平衡增强多样性和训练稳定性

5.2 优化器:AdamW

与标准 Adam 的关键区别在于解耦了权重衰减和梯度更新

  • 标准 Adam:L2 正则化与自适应学习率耦合,导致实际权重衰减效果弱于预期
  • AdamW:将权重衰减独立于梯度更新,使正则化效果更纯粹

参数配置:lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-4

5.3 学习率调度:Cosine Warmup

复制代码
lr
│
│  Warmup 阶段 (epoch 0-5)        Cosine 退火阶段 (epoch 5-20)
│  ┌────────┐                    ┌────────────────────────────────┐
│  │ 线性增长│                    │              ╲                  │
│  │        │                    │                ╲                │
│  │       ╱ │                    │                  ╲              │
│  │     ╱   │                    │                    ╲            │
│  │   ╱     │                    │                      ╲          │
│  │ ╱       │                    │                        ╲────────│
│  └─────────┘                    └────────────────────────────────┘
  1e-6 → 1e-4                      1e-4 → 1e-6 (余弦曲线)

为什么需要 Warmup?

训练初期,分类头参数是随机初始化的,而骨干网络是预训练的。两者梯度的方向和大小差异很大。Warmup 阶段使用从小到大的学习率逐步"预热"训练,避免初期梯度的大幅震荡。

5.4 损失函数:Label Smoothing CrossEntropy

标准交叉熵使用 one-hot 标签 (100% 概率给正确类别),可能导致模型过度自信。Label Smoothing 将标签软化:

复制代码
one-hot:              [0,     0,     1,     0,     ..., 0]
                          ↓ 平滑 ε=0.1
Label Smooth:         [0.001, 0.001, 0.9,   0.001, ..., 0.001]
                      ↑______↑______↑______↑________↑
                    错误类别各分 0.1/97    正确类别保留 0.9

效果:防止模型对训练标签过度自信,在 98 类细粒度分类任务中通常提升 0.5-1% 准确率。

5.5 混合精度训练 (AMP)

使用 PyTorch 的 torch.amp 自动将前向和反向传播中的部分计算转为 FP16:

python 复制代码
from torch.amp import GradScaler, autocast

scaler = GradScaler("cuda")
with autocast("cuda"):
    outputs = model(images)
    loss = criterion(outputs, targets)

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
指标 FP32 FP16 (AMP) 提升
显存占用 ~6.5 GB ~4.0 GB ↓ 38%
单 epoch 耗时 ~120s ~95s ↑ 26%
最终精度 95.07% 95.07% 无损失

5.6 梯度裁剪 + 早停 + 梯度累积

  • 梯度裁剪:梯度范数超过 1.0 时自动缩放,防止梯度爆炸
  • 早停机制 (Early Stopping):验证损失连续 15 个 epoch 不下降时停止训练
  • 梯度累积:支持在显存有限时模拟更大的 batch_size

六、实验结果与可视化分析

6.1 训练动态

Epoch Train Loss Val Loss Val Acc@1 Val Acc@5 LR
1 3.9278 2.4282 56.01% 82.23% 1.00e-4
3 2.8591 1.6907 78.41% 94.63% 4.06e-5
5 2.5460 1.4408 84.33% 96.42% 8.02e-5
8 2.2905 1.2915 86.43% 98.15% 9.57e-5
10 2.2294 1.1972 89.57% 98.46% 8.36e-5
12 2.1728 1.0981 92.54% 99.38% 6.58e-5
16 2.0799 1.0291 94.88% 99.38% 2.58e-5
20 1.9401 1.0063 95.50% 99.32% 2.08e-6

训练损失偏高是因为 MixUp/CutMix 混合了不同样本的标签。验证集不使用增强,从 Epoch 1 的 56.01% 稳定提升到 Epoch 20 的 95.50%。

6.2 最终测试集结果

指标 数值 说明
Top-1 准确率 95.07% 1,543/1,623 张正确分类
Top-5 准确率 99.26% 99%+ 的图像真实类别在 Top-5 内
Precision (Macro) 0.9527 各类别精确率的算术平均
Recall (Macro) 0.9435 各类别召回率的算术平均
F1-Score (Macro) 0.9421 精确率与召回率的调和平均
错误预测数 80 仅占总测试集的 4.93%

6.3 最佳与最差类别分析

准确率 100% 的类别(部分)

Aloevera, Arali, Ashwagandha, Bamboo, Betel, Coriender, Jasmine, Lemon, Mint, Neem, Pepper, Pomegranate, Tamarind...

准确率较低的类别及原因分析

类别 准确率 测试样本数 错误原因分析
Seethaashoka 28.57% 7 样本极少,且与 Nerale/Catharanthus 叶片形态高度相似
Seethapala 52.94% 17 与 Camphor/Jackfruit 叶片纹理和颜色相似
Nagadali 60.87% 23 与 Ganike 高度混淆(9次错误),两类视觉差异极小
Malabar_Nut 71.43% 7 样本不足,叶片形态变异大

6.4 最常见混淆对 (Top-5)

从错误分析报告中提取的最常见混淆对:

排名 真实类别 误判为 频次 分析
1 Nagadali Ganike 9 两类植物叶片形状和叶脉排布极为相似
2 Seethapala Camphor 4 叶片颜色和光泽度相近
3 Tulsi Bhrami 2 圣罗勒与婆罗米幼叶形态接近
4 Tomato Chilly 2 番茄与辣椒幼叶边缘锯齿相似
5 Rose Wood_sorel 2 两者复叶的排列方式在图像中难以区分

6.5 硬件与训练时间

项目 详情
GPU NVIDIA GeForce RTX 5060 Laptop (8GB GDDR7)
GPU 架构 Blackwell (sm_120)
CUDA 版本 12.8
PyTorch 版本 2.11.0+cu128
总训练时间 31 分 44 秒
每 epoch 平均耗时 ~95 秒
单张图像推理延迟 ~8 ms (batch_size=1)
推理吞吐量 ~125 img/s (batch_size=1)

七、可视化图表详解

运行 python fig.py 后,所有可视化图表生成至 outputs/可视化/ 目录。以下逐一详细说明每张图表的含义和解读方法。

📈 Fig 1: 训练曲线

展示了模型在 20 个 epoch 训练过程中的四个关键指标变化:

  • (a) 训练损失 (Training Loss):蓝色细线为每个 batch 的实时损失值,深蓝色粗线为平滑后的趋势线。从初始约 4.8 下降至约 1.9,模型持续收敛。损失偏高是因为 MixUp/CutMix 混合了不同样本的标签,导致交叉熵无法降至接近零。
  • (b) 验证损失 (Validation Loss) :从 2.43 单调下降至 1.01,始终保持下降趋势,未出现过拟合迹象
  • © 学习率调度 (LR Schedule):完整展示了 Cosine Warmup 策略。前 5 个 epoch 学习率从 1e-6 线性预热至 1e-4,随后沿余弦曲线平滑衰减至接近零。虚线标注了 Warmup→Cosine 的过渡点。
  • (d) 验证准确率 (Validation Accuracy):Acc@1(蓝色方块线)和 Acc@5(红色菱形线)。Acc@5 在 Epoch 12 即达到 99.38%,说明 Top-5 几乎完美;Top-1 最终达 95.50%。

💡 关键洞察:验证损失始终下降,验证准确率始终上升,表明模型的泛化能力持续改善,训练策略有效。

🎯 Fig 2: 混淆矩阵

从 98 类中随机选取 8 类展示归一化混淆矩阵,便于在论文和报告中清晰呈现:

  • 对角线(红色边框):正确分类的样本比例。对角线数值越高越好
  • 颜色深浅:白色 = 0%,深蓝色 = 100%
  • 数字标注规则:高置信度(>50%)用白色大字,低置信度用灰色小字
  • 平均准确率:标题中显示 8 类的平均对角线准确率

💡 关键洞察:对角线值普遍在 85%-100%,说明这 8 类区分度良好。非对角线值很低,表明类别间混淆较少。红色边框突出对角线元素,让读者一眼就能判断模型的分类质量。

🔥 Fig 3: Grad-CAM 可视化

Grad-CAM (Gradient-weighted Class Activation Mapping) 是目前最主流的深度学习可解释性方法。其核心原理为:

  1. 计算梯度:目标类别对最后一层卷积特征的梯度 → 反映每个通道的重要性
  2. 全局平均池化:对梯度进行空间平均得到通道权重
  3. 加权求和:用权重对特征图加权求和 → 得到热力图

图中:

  • 上排:原始图像 + 类别名称
  • 下排:Grad-CAM 热力图叠加显示,红色 = 高关注度(判别性区域),蓝色 = 低关注度(背景/无关区域)
  • 置信度标注:蓝色置信度条显示模型对该预测的确信程度

💡 关键洞察 :这是验证模型是否学到了正确特征的关键证据。理想的 Grad-CAM 热力图应集中覆盖植物的叶片主体区域。如果热力图聚焦在背景(如土壤、花盆)或边缘无关区域,说明模型存在**捷径学习(shortcut learning)**问题。

🌌 Fig 4: t-SNE 特征嵌入

t-SNE 是目前最常用的高维特征可视化方法。其处理流程为:

  1. 从 ResNet50 最终池化层提取 2048 维特征向量
  2. 使用 PCA 预降维至 100 维(保留主要信息)
  3. 使用 t-SNE 进一步降至 2 维空间进行可视化

图中:

  • 每个点:一张测试集图像
  • 不同颜色:不同类别
  • Silhouette Score:聚类质量指标,范围 -1, 1,越大表示类内紧凑、类间分离越好

💡 关键洞察:相同颜色的点聚集成紧密的簇,不同颜色的簇之间有清晰的边界分离,表明 ResNet50 学到的特征具有良好的类别可分性。高 Silhouette Score(通常 >0.5)反映了模型的优秀表现。

🎨 Fig 5: 数据增强演示

展示同一张植物图像经过数据增强 Pipeline 后的多样性效果:

  • 原始图像:中心区域标注 "原始图像"
  • 训练增强(×6):每次经过完整的训练增强管道(RandAugment + ColorJitter + RandomResizedCrop),展示随机增强的多样性
  • 水平翻转:确定性水平翻转
  • 旋转:确定性旋转 +30°

💡 关键洞察:增强后的图像在颜色、构图、角度、尺度上呈现丰富多样性。这种**"同一植物、多种视角"**的训练数据是模型获得强泛化能力的基础------模型学到的不是某张特定图像,而是该植物的本质视觉特征。

✅❌ Fig 6: 预测样例对比

对比展示正确分类和错误分类的预测样例:

  • 上排 :原始图像
    • 🟢 绿色标题 = 正确预测
    • 🔴 红色标题 = 错误预测
  • 下排 :Top-5 置信度条形图
    • 🟢 绿色条 = 真实类别
    • 🔵 蓝色条 = 其他预测类别
  • 右下角标注:真实类别名称

💡 关键洞察:正确的预测通常有高置信度(80-99%)且真实类别高居第一位。错误预测的分析可以揭示数据标注问题、类别歧义性或模型的系统性弱点------例如难区分的相似物种对。

📊 Fig 7: 类别分布统计

全面展示 98 类数据集的分布情况:

  • (a) 各类别图像分布:按样本数降序排列的柱状图,红色虚线标注均值(约 110 张/类)
  • (b) 分布直方图:展示每类样本数的分布形态,红色线标注均值,绿色线标注中位数
  • ©Top-20 样本最多类别:数据量最充足的 20 个类别(如 Aloevera 282 张、Tulsi 323 张)
  • (d) Bottom-20 样本最少类别:数据量最少的 20 个类别(最少仅 8 张)

💡 关键洞察:长尾分布------头部类别有 300+ 张图像,尾部类别不到 50 张。这直接解释了 Fig 6 中某些小类(如 Seethaashoka 仅 7 张测试样本)表现较差的原因。解决方向包括:过采样、加权损失函数、或数据增强策略。


八、快速开始与复现指南

8.1 环境要求

组件 最低要求 推荐
Python 3.10+ 3.13
PyTorch 2.0+ 2.11
CUDA 11.8+ 12.8
GPU 显存 4 GB+ 8 GB+

8.2 安装依赖

bash 复制代码
# 安装所有依赖
pip install -r requirements.txt

# 验证 CUDA 环境
python -c "import torch; print(f'PyTorch {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"

8.3 下载数据集并训练

bash 复制代码
# 方式一:自动从 Kaggle 下载(推荐)
python -c "
import kagglehub, os, shutil

# 下载 Indian Medicinal Leaves Dataset
path = kagglehub.dataset_download('aryashah2k/indian-medicinal-leaves-dataset')
src = os.path.join(path, 'Indian Medicinal Leaves Image Datasets')

# 合并两个子集到 data/ 目录
os.makedirs('data', exist_ok=True)
for sub in ['Medicinal Leaf dataset', 'Medicinal plant dataset']:
    sub_path = os.path.join(src, sub)
    if not os.path.exists(sub_path): continue
    for cls in os.listdir(sub_path):
        cls_path = os.path.join(sub_path, cls)
        if not os.path.isdir(cls_path): continue
        clean = cls.strip().replace(' ', '_').replace('(','').replace(')','')
        dst = os.path.join('data', clean)
        os.makedirs(dst, exist_ok=True)
        for i, f in enumerate(os.listdir(cls_path)):
            if f.lower().endswith(('.jpg','.jpeg','.png','.bmp')):
                ext = os.path.splitext(f)[1]
                shutil.copy2(os.path.join(cls_path, f), os.path.join(dst, f'{clean}_{i:04d}{ext}'))
print('✅ 数据集准备完毕!')
"

# 开始训练
python train.py --epochs 20 --batch_size 64 --strategy full

8.4 评估与预测

bash 复制代码
# 在测试集上全面评估模型
python evaluate.py --checkpoint checkpoints/best_model.pth

# 生成所有可视化图表(7张)
python fig.py

# 单张图像预测
python predict.py --image data/Aloevera/Aloevera_0000.jpg --topk 5

# 批量预测
python predict.py --image_dir data/Aloevera/ --topk 5

# 使用 TTA 提高预测准确率
python predict.py --image your_plant.jpg --tta

# 推理性能基准测试
python predict.py --benchmark --batch_sizes 1,4,8,16,32

8.5 高级训练选项

bash 复制代码
# 使用 ResNet101 骨干网络(更大容量)
python train.py --backbone resnet101 --epochs 30

# 使用渐进式微调(推荐大数据集)
python train.py --strategy progressive

# 降低 batch_size 以适应较小显存
python train.py --batch_size 16

# 使用自己的数据集
python train.py --data_root /path/to/your/dataset

# 从检查点恢复训练
python train.py --resume checkpoints/best_model.pth

# 训练完成后导出 TorchScript 模型
python train.py --export

九、项目结构

复制代码
草本植物识别/
└── plant_recognition/
    │
    ├── config.py              # 🌐 全局配置 (模型/数据/训练超参数)
    ├── dataset.py             # 📦 数据加载、增强Pipeline、数据集划分
    ├── models.py              # 🏗️ ResNet50模型定义 + 多策略微调管理器
    ├── utils.py               # 🔧 优化器/学习率调度器/损失函数/评估指标
    │
    ├── train.py               # 🚂 训练脚本 (全量微调 / 渐进式微调)
    ├── evaluate.py            # 📊 评估脚本 (混淆矩阵 / 错误分析 / Grad-CAM)
    ├── predict.py             # 🔮 推理脚本 (单张/批量/TTA/性能基准)
    ├── fig.py                 # 📈 可视化脚本 (生成7张论文级图表)
    │
    ├── requirements.txt       # 📋 Python 依赖清单
    ├── README.md              # 📖 项目英文文档
    │
    ├── data/                  # 🗂️ 数据集 (98类 / 10,813张图像)
    │   ├── Aloevera/
    │   ├── Neem/
    │   ├── Tulsi/
    │   └── ... (共98个类别文件夹)
    │
    ├── checkpoints/           # 💾 模型检查点
    │   └── best_model.pth     # 最佳模型 (Val Acc@1 = 95.50%)
    │
    ├── logs/                  # 📝 TensorBoard 训练日志
    │   └── resnet50_YYYYMMDD_HHMMSS/
    │
    └── outputs/               # 📤 输出文件
        ├── 可视化/             # 📈 论文级图表 (7张)
        │   ├── fig1_训练曲线.png      # 四合一训练动态
        │   ├── fig2_混淆矩阵.png      # 8类混淆矩阵
        │   ├── fig3_GradCAM.png       # Grad-CAM注意力可视化
        │   ├── fig4_tSNE.png          # t-SNE特征嵌入
        │   ├── fig5_数据增强.png      # 数据增强效果展示
        │   ├── fig6_预测样例.png      # 正确vs错误预测对比
        │   └── fig7_类别分布.png      # 数据集统计分析
        ├── evaluation_report.txt     # 完整评估报告
        ├── error_analysis.txt        # 错误案例分析
        └── results_*.txt             # 训练结果摘要

十、总结与展望

🔑 关键技术要点回顾

技术 作用 贡献
迁移学习 ImageNet 预训练提供通用视觉特征 大幅降低对标注数据的需求,从零训练到 95% 仅需 20 epochs
RandAugment + MixUp + CutMix 三层级数据增强 显著提升泛化能力,防止过拟合
Label Smoothing 软化标签,防止过度自信 98 类细粒度分类中提升 0.5-1%
Cosine Warmup 预热+平滑衰减 避免训练初期震荡,实现精细收敛
混合精度 (AMP) FP16 训练 显存节省 40%,加速 26%,精度无损
Grad-CAM 可解释性 热力图展示模型关注区域 验证模型学到了正确的判别性特征

📈 模型优点

  • 高准确率:Top-1 95.07%,Top-5 99.26%
  • 高效率:30 分钟训练完成,单张推理仅 8ms
  • 可解释:Grad-CAM 提供决策依据可视化
  • 可扩展:支持 ResNet/ResNeXt/EfficientNet/ViT/Swin 多种骨干网络
  • 工程化:完整的训练-评估-推理-可视化流水线

🔮 改进方向

  1. 处理类别不平衡:使用过采样(Oversampling)或加权 Focal Loss 改善尾部类别(< 50 样本)的识别率
  2. 集成学习:训练多个不同骨干网络(ResNet101 + EfficientNet + ViT)并集成,预计可提升 1-2%
  3. 更大的数据集:迁移至 163 类或 880 类的中草药大规模数据集
  4. 细粒度特征增强:引入注意力机制(CBAM、SE-Net),聚焦叶脉、边缘等判别性微特征
  5. 移动端部署:使用知识蒸馏或模型量化(INT8),将模型压缩至 20MB 以内部署到手机端
  6. 多模态融合:结合叶片形状 + 花朵 + 果实等多种器官信息,提升准确率

📚 参考文献

  1. He, K., et al. "Deep Residual Learning for Image Recognition." CVPR 2016 . [arXiv:1512.03385](https://arxiv.org/abs/1512.03385)
  2. Cubuk, E. D., et al. "RandAugment: Practical automated data augmentation with a reduced search space." NeurIPS 2020 . [arXiv:1909.13719](https://arxiv.org/abs/1909.13719)
  3. Zhang, H., et al. "mixup: Beyond Empirical Risk Minimization." ICLR 2018 . [arXiv:1710.09412](https://arxiv.org/abs/1710.09412)
  4. Yun, S., et al. "CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features." ICCV 2019 . [arXiv:1905.04899](https://arxiv.org/abs/1905.04899)
  5. Müller, R., et al. "When Does Label Smoothing Help?" NeurIPS 2019 . [arXiv:1906.02629](https://arxiv.org/abs/1906.02629)
  6. Loshchilov, I. & Hutter, F. "Decoupled Weight Decay Regularization." ICLR 2019 . [arXiv:1711.05101](https://arxiv.org/abs/1711.05101)
  7. Selvaraju, R. R., et al. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." ICCV 2017 . [arXiv:1610.02391](https://arxiv.org/abs/1610.02391)

🌿 基于深度学习的药用植物识别系统

Deep Learning-Based Medicinal Plant Recognition System

~ResNet50 · 迁移学习 · 98 类 · 95.07% Top-1 · Python 3.13 · PyTorch 2.11~


📧 如果需要全部的代码以及相关过程图与结果图,可以私信与我联系

相关推荐
txg6661 小时前
FuzzGPT:用大语言模型生成“极端边界程序”的深度学习框架 Fuzzing 新范式
人工智能·深度学习·安全·网络安全·语言模型
码云骑士1 小时前
05-Python字典底层原理-Hash表与有序性的真相
开发语言·python·哈希算法
Cloud_Shy6181 小时前
解读《Effective Python 3rd Edition》:从练气到老魔(第六章 Item 44 - 47)
开发语言·人工智能·经验分享·笔记·python
mxlwd1681 小时前
movielen 100k lr模型训练过程
开发语言·python·机器学习
探物 AI1 小时前
零基础入门3D点云深度学习:从PointNet开始,理解3D数据处理
人工智能·深度学习·3d
硅谷秋水2 小时前
Nautilus:从单一提示词到即插即用机器人学习
人工智能·深度学习·机器学习·机器人
小森林之主2 小时前
深入正则表达式:核心语法与实战剖析
javascript·python·正则表达式·编程技巧·字符串处理
果丁智能2 小时前
智慧校园一卡通深度融合方案:基于超级SIM卡的手机碰一碰智能开锁技术落地实践
数据结构·人工智能·python·科技·算法·智能家居·信息与通信
eric-sjq2 小时前
Xiaothink-T17-Tiny 模型深度解析:轻量级RNN架构的创新与实战评测
人工智能·深度学习·语言模型·自然语言处理·架构