🌿 基于深度学习的药用草本植物识别系统
使用 ResNet50 + 迁移学习实现 98 类 药用植物的高精度自动识别,Top-1 准确率 95.07%,Top-5 准确率 99.26%
如果需要全部的代码以及相关过程图与结果图,可以私信与我联系
📑 目录
一、项目背景与意义
1.1 问题背景
药用植物的精确识别在传统医学(如阿育吠陀、中医药)、药物研发、植物学研究和农业管理中具有重要价值。传统识别方法依赖植物学专家的经验和形态学比对,存在以下痛点:
- 效率低:人工逐一比对耗时耗力
- 主观性强:不同专家的判断可能存在差异
- 覆盖范围有限:一名专家通常只能精通有限的植物种类
- 教育成本高:培养合格的植物分类学家需要数年时间
深度学习技术的快速发展使得构建自动化的植物识别系统成为可能。通过计算机视觉模型,我们可以在秒级时间内完成一张植物图像的识别和分类。
1.2 项目目标
构建一个基于深度学习的药用植物自动识别系统:
- 给定一张植物图像,系统自动判断该植物属于 98 种药用植物中的哪一种
- 以置信度分数呈现 Top-5 预测结果,辅助专家决策
- 通过 Grad-CAM 热力图展示模型的决策依据,提供可解释性
1.3 技术栈
| 组件 | 技术选型 | 说明 |
|---|---|---|
| 深度学习框架 | PyTorch 2.11 | 灵活性最高,学术和工业界广泛使用 |
| 骨干网络 | ResNet50 | ImageNet 预训练,2460万参数 |
| 数据增强 | RandAugment + MixUp + CutMix | 三层级增强体系 |
| 优化器 | AdamW | 解耦权重衰减的自适应优化器 |
| 学习率调度 | Cosine Warmup | 先预热后余弦退火 |
| 损失函数 | Label Smoothing CrossEntropy | 平滑标签正则化 ε=0.1 |
| 混合精度 | torch.amp (FP16) | 显存节省40%,训练加速26% |
二、整体工作流程
整个项目按照以下流程执行,每一步都有对应的脚本和配置文件:
┌──────────────────────────────────────────────────────────────────────┐
│ Step 1: 环境配置 │
│ ├── 安装 PyTorch 2.11 (CUDA 12.8) + 依赖包 │
│ └── 验证 CUDA 可用性 │
├──────────────────────────────────────────────────────────────────────┤
│ Step 2: 数据集获取 │
│ ├── 使用 kagglehub 自动下载 Indian Medicinal Leaves Dataset │
│ ├── 合并两个子集 → 10,813 张图像 │
│ └── 自动按 70%/15%/15% 分层划分训练/验证/测试集 │
├──────────────────────────────────────────────────────────────────────┤
│ Step 3: 模型构建 │
│ ├── 加载 ResNet50 (ImageNet 预训练权重) │
│ ├── 移除原始 1000 类分类头 → 构建自定义 98 类分类头 │
│ └── 统计: 总参数 24,608,418,模型大小 93.9MB │
├──────────────────────────────────────────────────────────────────────┤
│ Step 4: 训练 (20 epochs, 全量微调) │
│ ├── 数据增强: RandAugment + MixUp(α=0.2) + CutMix(α=1.0) │
│ ├── 损失函数: Label Smoothing CrossEntropy (ε=0.1) │
│ ├── 优化器: AdamW (lr=1e-4, weight_decay=1e-4) │
│ ├── 学习率: Cosine Warmup (5 epoch 预热 → 余弦退火至 1e-6) │
│ ├── 混合精度: torch.amp (FP16 训练,显存节省约 40%) │
├──────────────────────────────────────────────────────────────────────┤
│ Step 5: 模型评估 │
│ ├── 测试集 1,623 张图像评估 │
│ ├── 计算 Top-1/Top-5 准确率、精确率、召回率、F1-Score │
│ ├── 生成 98 类分类报告 + 混淆矩阵 + 错误案例分析 │
│ └── 常见混淆对统计 │
├──────────────────────────────────────────────────────────────────────┤
│ Step 6: 可视化图表生成 │
│ ├── Fig 1: 训练曲线 (Loss & Accuracy & LR Schedule) │
│ ├── Fig 2: 混淆矩阵 (8 类随机子集) │
│ ├── Fig 3: Grad-CAM 热力图 (模型注意力可视化) │
│ ├── Fig 4: t-SNE 特征嵌入 (聚类可视化 + Silhouette Score) │
│ ├── Fig 5: 数据增强效果展示 │
│ ├── Fig 6: 预测样例对比 (正确 vs 错误) │
│ └── Fig 7: 类别分布统计 │
└──────────────────────────────────────────────────────────────────────┘
三、数据集选择与预处理
3.1 数据来源
本项目使用 Kaggle 上的 Indian Medicinal Leaves Dataset ,通过 kagglehub 库自动下载。
该数据集由两个子集合并而成:
| 子集 | 说明 |
|---|---|
| Medicinal Leaf dataset | 原始药用植物叶片数据集 |
| Medicinal plant dataset | 补充的药用植物数据集 |
3.2 数据集统计
| 指标 | 数值 |
|---|---|
| 类别总数 | 98 类 |
| 图像总数 | 10,813 张 |
| 训练集 | 7,569 张 (70%) |
| 验证集 | 1,621 张 (15%) |
| 测试集 | 1,623 张 (15%) |
| 每类最少 | 8 张 |
| 每类最多 | 187 张 |
| 每类平均 | 110.3 张 |
⚠️ 数据集特点:存在明显的类别不平衡问题(8~187 张/类),这对模型训练提出了挑战,也是后续使用 MixUp/CutMix 增强和小类过采样的原因之一。
3.3 数据集划分策略
我们使用 分层抽样(Stratified Sampling) 按 70%/15%/15% 的比例将每个类别的样本分配到训练集、验证集和测试集:
python
from sklearn.model_selection import train_test_split
# Step 1: 划分训练集 (70%) 和临时集 (30%)
train_paths, temp_paths, train_labels, temp_labels = train_test_split(
image_paths, labels, train_size=0.7, stratify=labels, random_state=42
)
# Step 2: 划分验证集 (15%) 和测试集 (15%)
val_paths, test_paths, val_labels, test_labels = train_test_split(
temp_paths, temp_labels, train_size=0.5, stratify=temp_labels, random_state=42
)
分层抽样的优势:确保每个类别在三个子集中都有等比例的代表,避免某一类别在测试集中完全没有样本的情况。
3.4 数据组织格式
数据集按照 PyTorch 的 ImageFolder 格式组织,每个类别一个文件夹:
data/
├── Aloevera/ # 芦荟 (282 张)
│ ├── Aloevera_0000.jpg
│ ├── Aloevera_0001.jpg
│ └── ...
├── Neem/ # 印度楝 (278 张)
│ ├── Neem_0000.jpg
│ └── ...
├── Tulsi/ # 圣罗勒 (323 张)
├── Mint/ # 薄荷 (288 张)
├── Hibiscus/ # 木槿 (283 张)
└── ... (共 98 个类别文件夹)
3.5 部分类别展示
包含的药用植物涵盖印度传统医学(阿育吠陀)常用草药:
| 类别 (英文名) | 中文参考 | 图像数量 | 药用价值 |
|---|---|---|---|
| Aloevera | 芦荟 | 282 | 皮肤修复、抗炎 |
| Tulsi / Tulasi | 圣罗勒 | 323 | 免疫增强、抗菌 |
| Neem | 印度楝/苦楝 | 278 | 抗菌、抗寄生虫 |
| Mint | 薄荷 | 288 | 消化促进、清凉止痒 |
| Hibiscus | 木槿 | 283 | 抗氧化、护发 |
| Ashwagandha | 南非醉茄/印度人参 | 191 | 抗压力、增强精力 |
| Turmeric | 姜黄 | 45 | 抗炎、抗氧化 |
| Brahmi | 婆罗米/假马齿苋 | 121 | 改善记忆、镇静 |
四、模型架构设计
4.1 整体架构
我们采用迁移学习的范式:以 ImageNet 预训练的 ResNet50 为骨干网络,替换其 1000 类分类头为自定义的 98 类分类头。
输入图像 (224×224×3)
│
▼
┌──────────────────────────────────────────┐
│ ResNet50 Backbone │
│ ┌─────────────────────────────────────┐ │
│ │ Conv7×7 + BN + ReLU + MaxPool │ │ ← 浅层: 通用边缘/纹理特征
│ ├─────────────────────────────────────┤ │
│ │ Layer1: 3× Bottleneck (256-d) │ │ ← 中浅层: 局部形状
│ ├─────────────────────────────────────┤ │
│ │ Layer2: 4× Bottleneck (512-d) │ │ ← 中层: 部件级语义
│ ├─────────────────────────────────────┤ │
│ │ Layer3: 6× Bottleneck (1024-d) │ │ ← 中深层: 目标语义
│ ├─────────────────────────────────────┤ │
│ │ Layer4: 3× Bottleneck (2048-d) │ │ ← 深层: 高级语义特征
│ ├─────────────────────────────────────┤ │
│ │ AdaptiveAvgPool2d → Flatten │ │
│ └─────────────────────────────────────┘ │
│ 输出: 2048-d 特征向量 │
└──────────────────────────────────────────┘
│
▼
┌──────────────────────────────────────────┐
│ Custom Classification Head │
│ ┌─────────────────────────────────────┐ │
│ │ Dropout (p=0.3) │ │ ← 正则化
│ │ Linear(2048 → 512) │ │ ← 降维
│ │ BatchNorm1d(512) │ │ ← 稳定训练
│ │ ReLU │ │ ← 非线性
│ │ Dropout (p=0.3) │ │ ← 正则化
│ │ Linear(512 → 98) │ │ ← 分类输出
│ └─────────────────────────────────────┘ │
└──────────────────────────────────────────┘
│
▼
Softmax → 98 类概率分布
4.2 参数统计
| 模块 | 参数量 | 占比 |
|---|---|---|
| ResNet50 Backbone | 23,508,032 | 95.5% |
| Classification Head | 1,100,386 | 4.5% |
| 总计 | 24,608,418 | 100% |
| 模型大小 (float32) | 93.9 MB | - |
4.3 为什么选择 ResNet50?
| 优势 | 说明 |
|---|---|
| 成熟的预训练权重 | 在 ImageNet-1K(120万张、1000类)上训练充分,低级特征(边缘、纹理、形状)泛化能力极强 |
| 适中的参数量 | 24.6M 参数在 8GB 显存下可从容训练,batch_size 可达 64 |
| 残差连接机制 | 跳跃连接(Skip Connection)有效解决了深层网络的梯度消失/爆炸问题 |
| 优秀的迁移学习基础 | 大量研究表明 ResNet50 在细粒度分类任务上表现优异 |
| 广泛的框架支持 | PyTorch、TensorFlow 等主流框架均提供官方实现和预训练权重 |
4.4 关键代码实现
模型构建的核心代码非常简洁:
python
import torch.nn as nn
import torchvision.models as tv_models
class PlantRecognitionModel(nn.Module):
def __init__(self, backbone_name="resnet50", num_classes=98, pretrained=True):
super().__init__()
# 加载 ImageNet 预训练的 ResNet50
self.backbone = tv_models.resnet50(
weights="IMAGENET1K_V1" if pretrained else None
)
# 移除原始分类层
self.backbone.fc = nn.Identity()
# 自定义分类头
self.classifier = nn.Sequential(
nn.Dropout(0.3),
nn.Linear(2048, 512),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Linear(512, num_classes),
)
def forward(self, x):
features = self.backbone(x) # (B, 2048)
output = self.classifier(features) # (B, num_classes)
return output
五、训练策略与优化方法
本项目的训练策略经过精心设计,包含 6 大核心优化技术,这些技术的组合是实现 95%+ 准确率的关键。
5.1 三层级数据增强体系
数据增强是提升模型泛化能力的最有效手段。我们构建了三个层级的增强体系:
第一层:基础几何增强
python
transforms.RandomResizedCrop(224, scale=(0.08, 1.0)) # 随机裁剪缩放
transforms.RandomHorizontalFlip(p=0.5) # 水平翻转
transforms.RandomRotation(30) # 随机旋转±30°
第二层:颜色增强
python
transforms.ColorJitter(
brightness=0.2, # 亮度 ±20%
contrast=0.2, # 对比度 ±20%
saturation=0.2, # 饱和度 ±20%
hue=0.1 # 色调 ±10%
)
transforms.RandAugment(num_ops=2, magnitude=9) # 自动搜索最优增强策略
第三层:混合增强
| 方法 | 原理 | 优势 |
|---|---|---|
| MixUp (α=0.2) | 将两张图像按比例λ混合: x_mix = λ·x_i + (1-λ)·x_j,标签同时混合 |
软标签正则化,平滑决策边界 |
| CutMix (α=1.0) | 将图像A的一块区域替换为图像B的对应区域 | 鼓励模型关注局部判别性区域 |
| 切换策略 | MixUp 40% + CutMix 40% + 标准训练 20% | 平衡增强多样性和训练稳定性 |
5.2 优化器:AdamW
与标准 Adam 的关键区别在于解耦了权重衰减和梯度更新:
- 标准 Adam:L2 正则化与自适应学习率耦合,导致实际权重衰减效果弱于预期
- AdamW:将权重衰减独立于梯度更新,使正则化效果更纯粹
参数配置:lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-4
5.3 学习率调度:Cosine Warmup
lr
│
│ Warmup 阶段 (epoch 0-5) Cosine 退火阶段 (epoch 5-20)
│ ┌────────┐ ┌────────────────────────────────┐
│ │ 线性增长│ │ ╲ │
│ │ │ │ ╲ │
│ │ ╱ │ │ ╲ │
│ │ ╱ │ │ ╲ │
│ │ ╱ │ │ ╲ │
│ │ ╱ │ │ ╲────────│
│ └─────────┘ └────────────────────────────────┘
1e-6 → 1e-4 1e-4 → 1e-6 (余弦曲线)
为什么需要 Warmup?
训练初期,分类头参数是随机初始化的,而骨干网络是预训练的。两者梯度的方向和大小差异很大。Warmup 阶段使用从小到大的学习率逐步"预热"训练,避免初期梯度的大幅震荡。
5.4 损失函数:Label Smoothing CrossEntropy
标准交叉熵使用 one-hot 标签 (100% 概率给正确类别),可能导致模型过度自信。Label Smoothing 将标签软化:
one-hot: [0, 0, 1, 0, ..., 0]
↓ 平滑 ε=0.1
Label Smooth: [0.001, 0.001, 0.9, 0.001, ..., 0.001]
↑______↑______↑______↑________↑
错误类别各分 0.1/97 正确类别保留 0.9
效果:防止模型对训练标签过度自信,在 98 类细粒度分类任务中通常提升 0.5-1% 准确率。
5.5 混合精度训练 (AMP)
使用 PyTorch 的 torch.amp 自动将前向和反向传播中的部分计算转为 FP16:
python
from torch.amp import GradScaler, autocast
scaler = GradScaler("cuda")
with autocast("cuda"):
outputs = model(images)
loss = criterion(outputs, targets)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
| 指标 | FP32 | FP16 (AMP) | 提升 |
|---|---|---|---|
| 显存占用 | ~6.5 GB | ~4.0 GB | ↓ 38% |
| 单 epoch 耗时 | ~120s | ~95s | ↑ 26% |
| 最终精度 | 95.07% | 95.07% | 无损失 |
5.6 梯度裁剪 + 早停 + 梯度累积
- 梯度裁剪:梯度范数超过 1.0 时自动缩放,防止梯度爆炸
- 早停机制 (Early Stopping):验证损失连续 15 个 epoch 不下降时停止训练
- 梯度累积:支持在显存有限时模拟更大的 batch_size
六、实验结果与可视化分析
6.1 训练动态
| Epoch | Train Loss | Val Loss | Val Acc@1 | Val Acc@5 | LR |
|---|---|---|---|---|---|
| 1 | 3.9278 | 2.4282 | 56.01% | 82.23% | 1.00e-4 |
| 3 | 2.8591 | 1.6907 | 78.41% | 94.63% | 4.06e-5 |
| 5 | 2.5460 | 1.4408 | 84.33% | 96.42% | 8.02e-5 |
| 8 | 2.2905 | 1.2915 | 86.43% | 98.15% | 9.57e-5 |
| 10 | 2.2294 | 1.1972 | 89.57% | 98.46% | 8.36e-5 |
| 12 | 2.1728 | 1.0981 | 92.54% | 99.38% | 6.58e-5 |
| 16 | 2.0799 | 1.0291 | 94.88% | 99.38% | 2.58e-5 |
| 20 | 1.9401 | 1.0063 | 95.50% | 99.32% | 2.08e-6 |
训练损失偏高是因为 MixUp/CutMix 混合了不同样本的标签。验证集不使用增强,从 Epoch 1 的 56.01% 稳定提升到 Epoch 20 的 95.50%。
6.2 最终测试集结果
| 指标 | 数值 | 说明 |
|---|---|---|
| Top-1 准确率 | 95.07% | 1,543/1,623 张正确分类 |
| Top-5 准确率 | 99.26% | 99%+ 的图像真实类别在 Top-5 内 |
| Precision (Macro) | 0.9527 | 各类别精确率的算术平均 |
| Recall (Macro) | 0.9435 | 各类别召回率的算术平均 |
| F1-Score (Macro) | 0.9421 | 精确率与召回率的调和平均 |
| 错误预测数 | 80 | 仅占总测试集的 4.93% |
6.3 最佳与最差类别分析
准确率 100% 的类别(部分) :
Aloevera, Arali, Ashwagandha, Bamboo, Betel, Coriender, Jasmine, Lemon, Mint, Neem, Pepper, Pomegranate, Tamarind...
准确率较低的类别及原因分析:
| 类别 | 准确率 | 测试样本数 | 错误原因分析 |
|---|---|---|---|
| Seethaashoka | 28.57% | 7 | 样本极少,且与 Nerale/Catharanthus 叶片形态高度相似 |
| Seethapala | 52.94% | 17 | 与 Camphor/Jackfruit 叶片纹理和颜色相似 |
| Nagadali | 60.87% | 23 | 与 Ganike 高度混淆(9次错误),两类视觉差异极小 |
| Malabar_Nut | 71.43% | 7 | 样本不足,叶片形态变异大 |
6.4 最常见混淆对 (Top-5)
从错误分析报告中提取的最常见混淆对:
| 排名 | 真实类别 | 误判为 | 频次 | 分析 |
|---|---|---|---|---|
| 1 | Nagadali | Ganike | 9 | 两类植物叶片形状和叶脉排布极为相似 |
| 2 | Seethapala | Camphor | 4 | 叶片颜色和光泽度相近 |
| 3 | Tulsi | Bhrami | 2 | 圣罗勒与婆罗米幼叶形态接近 |
| 4 | Tomato | Chilly | 2 | 番茄与辣椒幼叶边缘锯齿相似 |
| 5 | Rose | Wood_sorel | 2 | 两者复叶的排列方式在图像中难以区分 |
6.5 硬件与训练时间
| 项目 | 详情 |
|---|---|
| GPU | NVIDIA GeForce RTX 5060 Laptop (8GB GDDR7) |
| GPU 架构 | Blackwell (sm_120) |
| CUDA 版本 | 12.8 |
| PyTorch 版本 | 2.11.0+cu128 |
| 总训练时间 | 31 分 44 秒 |
| 每 epoch 平均耗时 | ~95 秒 |
| 单张图像推理延迟 | ~8 ms (batch_size=1) |
| 推理吞吐量 | ~125 img/s (batch_size=1) |
七、可视化图表详解
运行 python fig.py 后,所有可视化图表生成至 outputs/可视化/ 目录。以下逐一详细说明每张图表的含义和解读方法。
📈 Fig 1: 训练曲线

展示了模型在 20 个 epoch 训练过程中的四个关键指标变化:
- (a) 训练损失 (Training Loss):蓝色细线为每个 batch 的实时损失值,深蓝色粗线为平滑后的趋势线。从初始约 4.8 下降至约 1.9,模型持续收敛。损失偏高是因为 MixUp/CutMix 混合了不同样本的标签,导致交叉熵无法降至接近零。
- (b) 验证损失 (Validation Loss) :从 2.43 单调下降至 1.01,始终保持下降趋势,未出现过拟合迹象。
- © 学习率调度 (LR Schedule):完整展示了 Cosine Warmup 策略。前 5 个 epoch 学习率从 1e-6 线性预热至 1e-4,随后沿余弦曲线平滑衰减至接近零。虚线标注了 Warmup→Cosine 的过渡点。
- (d) 验证准确率 (Validation Accuracy):Acc@1(蓝色方块线)和 Acc@5(红色菱形线)。Acc@5 在 Epoch 12 即达到 99.38%,说明 Top-5 几乎完美;Top-1 最终达 95.50%。
💡 关键洞察:验证损失始终下降,验证准确率始终上升,表明模型的泛化能力持续改善,训练策略有效。
🎯 Fig 2: 混淆矩阵

从 98 类中随机选取 8 类展示归一化混淆矩阵,便于在论文和报告中清晰呈现:
- 对角线(红色边框):正确分类的样本比例。对角线数值越高越好
- 颜色深浅:白色 = 0%,深蓝色 = 100%
- 数字标注规则:高置信度(>50%)用白色大字,低置信度用灰色小字
- 平均准确率:标题中显示 8 类的平均对角线准确率
💡 关键洞察:对角线值普遍在 85%-100%,说明这 8 类区分度良好。非对角线值很低,表明类别间混淆较少。红色边框突出对角线元素,让读者一眼就能判断模型的分类质量。
🔥 Fig 3: Grad-CAM 可视化

Grad-CAM (Gradient-weighted Class Activation Mapping) 是目前最主流的深度学习可解释性方法。其核心原理为:
- 计算梯度:目标类别对最后一层卷积特征的梯度 → 反映每个通道的重要性
- 全局平均池化:对梯度进行空间平均得到通道权重
- 加权求和:用权重对特征图加权求和 → 得到热力图
图中:
- 上排:原始图像 + 类别名称
- 下排:Grad-CAM 热力图叠加显示,红色 = 高关注度(判别性区域),蓝色 = 低关注度(背景/无关区域)
- 置信度标注:蓝色置信度条显示模型对该预测的确信程度
💡 关键洞察 :这是验证模型是否学到了正确特征的关键证据。理想的 Grad-CAM 热力图应集中覆盖植物的叶片主体区域。如果热力图聚焦在背景(如土壤、花盆)或边缘无关区域,说明模型存在**捷径学习(shortcut learning)**问题。
🌌 Fig 4: t-SNE 特征嵌入

t-SNE 是目前最常用的高维特征可视化方法。其处理流程为:
- 从 ResNet50 最终池化层提取 2048 维特征向量
- 使用 PCA 预降维至 100 维(保留主要信息)
- 使用 t-SNE 进一步降至 2 维空间进行可视化
图中:
- 每个点:一张测试集图像
- 不同颜色:不同类别
- Silhouette Score:聚类质量指标,范围 -1, 1,越大表示类内紧凑、类间分离越好
💡 关键洞察:相同颜色的点聚集成紧密的簇,不同颜色的簇之间有清晰的边界分离,表明 ResNet50 学到的特征具有良好的类别可分性。高 Silhouette Score(通常 >0.5)反映了模型的优秀表现。
🎨 Fig 5: 数据增强演示

展示同一张植物图像经过数据增强 Pipeline 后的多样性效果:
- 原始图像:中心区域标注 "原始图像"
- 训练增强(×6):每次经过完整的训练增强管道(RandAugment + ColorJitter + RandomResizedCrop),展示随机增强的多样性
- 水平翻转:确定性水平翻转
- 旋转:确定性旋转 +30°
💡 关键洞察:增强后的图像在颜色、构图、角度、尺度上呈现丰富多样性。这种**"同一植物、多种视角"**的训练数据是模型获得强泛化能力的基础------模型学到的不是某张特定图像,而是该植物的本质视觉特征。
✅❌ Fig 6: 预测样例对比

对比展示正确分类和错误分类的预测样例:
- 上排 :原始图像
- 🟢 绿色标题 = 正确预测
- 🔴 红色标题 = 错误预测
- 下排 :Top-5 置信度条形图
- 🟢 绿色条 = 真实类别
- 🔵 蓝色条 = 其他预测类别
- 右下角标注:真实类别名称
💡 关键洞察:正确的预测通常有高置信度(80-99%)且真实类别高居第一位。错误预测的分析可以揭示数据标注问题、类别歧义性或模型的系统性弱点------例如难区分的相似物种对。
📊 Fig 7: 类别分布统计

全面展示 98 类数据集的分布情况:
- (a) 各类别图像分布:按样本数降序排列的柱状图,红色虚线标注均值(约 110 张/类)
- (b) 分布直方图:展示每类样本数的分布形态,红色线标注均值,绿色线标注中位数
- ©Top-20 样本最多类别:数据量最充足的 20 个类别(如 Aloevera 282 张、Tulsi 323 张)
- (d) Bottom-20 样本最少类别:数据量最少的 20 个类别(最少仅 8 张)
💡 关键洞察:长尾分布------头部类别有 300+ 张图像,尾部类别不到 50 张。这直接解释了 Fig 6 中某些小类(如 Seethaashoka 仅 7 张测试样本)表现较差的原因。解决方向包括:过采样、加权损失函数、或数据增强策略。
八、快速开始与复现指南
8.1 环境要求
| 组件 | 最低要求 | 推荐 |
|---|---|---|
| Python | 3.10+ | 3.13 |
| PyTorch | 2.0+ | 2.11 |
| CUDA | 11.8+ | 12.8 |
| GPU 显存 | 4 GB+ | 8 GB+ |
8.2 安装依赖
bash
# 安装所有依赖
pip install -r requirements.txt
# 验证 CUDA 环境
python -c "import torch; print(f'PyTorch {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"
8.3 下载数据集并训练
bash
# 方式一:自动从 Kaggle 下载(推荐)
python -c "
import kagglehub, os, shutil
# 下载 Indian Medicinal Leaves Dataset
path = kagglehub.dataset_download('aryashah2k/indian-medicinal-leaves-dataset')
src = os.path.join(path, 'Indian Medicinal Leaves Image Datasets')
# 合并两个子集到 data/ 目录
os.makedirs('data', exist_ok=True)
for sub in ['Medicinal Leaf dataset', 'Medicinal plant dataset']:
sub_path = os.path.join(src, sub)
if not os.path.exists(sub_path): continue
for cls in os.listdir(sub_path):
cls_path = os.path.join(sub_path, cls)
if not os.path.isdir(cls_path): continue
clean = cls.strip().replace(' ', '_').replace('(','').replace(')','')
dst = os.path.join('data', clean)
os.makedirs(dst, exist_ok=True)
for i, f in enumerate(os.listdir(cls_path)):
if f.lower().endswith(('.jpg','.jpeg','.png','.bmp')):
ext = os.path.splitext(f)[1]
shutil.copy2(os.path.join(cls_path, f), os.path.join(dst, f'{clean}_{i:04d}{ext}'))
print('✅ 数据集准备完毕!')
"
# 开始训练
python train.py --epochs 20 --batch_size 64 --strategy full
8.4 评估与预测
bash
# 在测试集上全面评估模型
python evaluate.py --checkpoint checkpoints/best_model.pth
# 生成所有可视化图表(7张)
python fig.py
# 单张图像预测
python predict.py --image data/Aloevera/Aloevera_0000.jpg --topk 5
# 批量预测
python predict.py --image_dir data/Aloevera/ --topk 5
# 使用 TTA 提高预测准确率
python predict.py --image your_plant.jpg --tta
# 推理性能基准测试
python predict.py --benchmark --batch_sizes 1,4,8,16,32
8.5 高级训练选项
bash
# 使用 ResNet101 骨干网络(更大容量)
python train.py --backbone resnet101 --epochs 30
# 使用渐进式微调(推荐大数据集)
python train.py --strategy progressive
# 降低 batch_size 以适应较小显存
python train.py --batch_size 16
# 使用自己的数据集
python train.py --data_root /path/to/your/dataset
# 从检查点恢复训练
python train.py --resume checkpoints/best_model.pth
# 训练完成后导出 TorchScript 模型
python train.py --export
九、项目结构
草本植物识别/
└── plant_recognition/
│
├── config.py # 🌐 全局配置 (模型/数据/训练超参数)
├── dataset.py # 📦 数据加载、增强Pipeline、数据集划分
├── models.py # 🏗️ ResNet50模型定义 + 多策略微调管理器
├── utils.py # 🔧 优化器/学习率调度器/损失函数/评估指标
│
├── train.py # 🚂 训练脚本 (全量微调 / 渐进式微调)
├── evaluate.py # 📊 评估脚本 (混淆矩阵 / 错误分析 / Grad-CAM)
├── predict.py # 🔮 推理脚本 (单张/批量/TTA/性能基准)
├── fig.py # 📈 可视化脚本 (生成7张论文级图表)
│
├── requirements.txt # 📋 Python 依赖清单
├── README.md # 📖 项目英文文档
│
├── data/ # 🗂️ 数据集 (98类 / 10,813张图像)
│ ├── Aloevera/
│ ├── Neem/
│ ├── Tulsi/
│ └── ... (共98个类别文件夹)
│
├── checkpoints/ # 💾 模型检查点
│ └── best_model.pth # 最佳模型 (Val Acc@1 = 95.50%)
│
├── logs/ # 📝 TensorBoard 训练日志
│ └── resnet50_YYYYMMDD_HHMMSS/
│
└── outputs/ # 📤 输出文件
├── 可视化/ # 📈 论文级图表 (7张)
│ ├── fig1_训练曲线.png # 四合一训练动态
│ ├── fig2_混淆矩阵.png # 8类混淆矩阵
│ ├── fig3_GradCAM.png # Grad-CAM注意力可视化
│ ├── fig4_tSNE.png # t-SNE特征嵌入
│ ├── fig5_数据增强.png # 数据增强效果展示
│ ├── fig6_预测样例.png # 正确vs错误预测对比
│ └── fig7_类别分布.png # 数据集统计分析
├── evaluation_report.txt # 完整评估报告
├── error_analysis.txt # 错误案例分析
└── results_*.txt # 训练结果摘要
十、总结与展望
🔑 关键技术要点回顾
| 技术 | 作用 | 贡献 |
|---|---|---|
| 迁移学习 | ImageNet 预训练提供通用视觉特征 | 大幅降低对标注数据的需求,从零训练到 95% 仅需 20 epochs |
| RandAugment + MixUp + CutMix | 三层级数据增强 | 显著提升泛化能力,防止过拟合 |
| Label Smoothing | 软化标签,防止过度自信 | 98 类细粒度分类中提升 0.5-1% |
| Cosine Warmup | 预热+平滑衰减 | 避免训练初期震荡,实现精细收敛 |
| 混合精度 (AMP) | FP16 训练 | 显存节省 40%,加速 26%,精度无损 |
| Grad-CAM 可解释性 | 热力图展示模型关注区域 | 验证模型学到了正确的判别性特征 |
📈 模型优点
- ✅ 高准确率:Top-1 95.07%,Top-5 99.26%
- ✅ 高效率:30 分钟训练完成,单张推理仅 8ms
- ✅ 可解释:Grad-CAM 提供决策依据可视化
- ✅ 可扩展:支持 ResNet/ResNeXt/EfficientNet/ViT/Swin 多种骨干网络
- ✅ 工程化:完整的训练-评估-推理-可视化流水线
🔮 改进方向
- 处理类别不平衡:使用过采样(Oversampling)或加权 Focal Loss 改善尾部类别(< 50 样本)的识别率
- 集成学习:训练多个不同骨干网络(ResNet101 + EfficientNet + ViT)并集成,预计可提升 1-2%
- 更大的数据集:迁移至 163 类或 880 类的中草药大规模数据集
- 细粒度特征增强:引入注意力机制(CBAM、SE-Net),聚焦叶脉、边缘等判别性微特征
- 移动端部署:使用知识蒸馏或模型量化(INT8),将模型压缩至 20MB 以内部署到手机端
- 多模态融合:结合叶片形状 + 花朵 + 果实等多种器官信息,提升准确率
📚 参考文献
- He, K., et al. "Deep Residual Learning for Image Recognition." CVPR 2016 . [arXiv:1512.03385](https://arxiv.org/abs/1512.03385)
- Cubuk, E. D., et al. "RandAugment: Practical automated data augmentation with a reduced search space." NeurIPS 2020 . [arXiv:1909.13719](https://arxiv.org/abs/1909.13719)
- Zhang, H., et al. "mixup: Beyond Empirical Risk Minimization." ICLR 2018 . [arXiv:1710.09412](https://arxiv.org/abs/1710.09412)
- Yun, S., et al. "CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features." ICCV 2019 . [arXiv:1905.04899](https://arxiv.org/abs/1905.04899)
- Müller, R., et al. "When Does Label Smoothing Help?" NeurIPS 2019 . [arXiv:1906.02629](https://arxiv.org/abs/1906.02629)
- Loshchilov, I. & Hutter, F. "Decoupled Weight Decay Regularization." ICLR 2019 . [arXiv:1711.05101](https://arxiv.org/abs/1711.05101)
- Selvaraju, R. R., et al. "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization." ICCV 2017 . [arXiv:1610.02391](https://arxiv.org/abs/1610.02391)
🌿 基于深度学习的药用植物识别系统
Deep Learning-Based Medicinal Plant Recognition System
~ResNet50 · 迁移学习 · 98 类 · 95.07% Top-1 · Python 3.13 · PyTorch 2.11~
📧 如果需要全部的代码以及相关过程图与结果图,可以私信与我联系