深度学习——基于 PyTorch 的蔬菜图像分类

基于 PyTorch 的蔬菜分类系统设计与实现


一、项目简介

本项目旨在利用深度学习技术对蔬菜图片进行自动分类。

系统基于 PyTorch 框架构建,采用 ResNet18 网络作为主干模型,并结合数据增强与迁移学习,实现六类蔬菜的识别任务。

下文展示完整代码与详细讲解。


二、完整代码与分步解析


第 1 部分:模块导入与基础配置

复制代码
# vegetable_classifier.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import pandas as pd
from datetime import datetime

解析:

这部分导入了 PyTorch、TorchVision 及图像处理所需模块:

  • torch:深度学习核心库

  • torchvision:常用图像模型与变换

  • PIL:图像读取与处理

  • pandas:读取数据列表文件

  • datetime:计算训练耗时


第 2 部分:超参数设置与类别定义

复制代码
# ========================
# 1. 配置参数
# ========================
ROOT_DIR = 'vegetables_cls'  # 修改为你的实际路径
BATCH_SIZE = 32
NUM_EPOCHS = 10
NUM_CLASSES = 6
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DEVICE = torch.device("cpu")

CLASSES = [
    'bocai',
    'changqiezi',
    'hongxiancai',
    'huluobo',
    'xihongshi',
    'xilanhua'
]

# print(f"Using device: {DEVICE}")
# print(f"Classes: {CLASSES}")

解析:

项目根目录

vegetables_cls/

├── train_list.txt
├── val_list.txt
├── test_list.txt
├── bocai/...
├── changqiezi/...
├── hongxiancai/...
├── huluobo/...
├── xihongshi/...
└── xilanhua/...

每个 .txt 文件存储图像路径及其标签,例如:

bocai/img_001.jpg 0

bocai/img_002.jpg 0

xihongshi/img_003.jpg 4

  • ROOT_DIR:数据根目录,包含训练、验证、测试图像。

  • NUM_CLASSES:共 6 种蔬菜类别。

  • DEVICE:自动检测是否可用 GPU 加速。

  • 通过 CLASSES 数组定义类别名。


第 3 部分:自定义 Dataset 类

复制代码
# ========================
# 2. 自定义 Dataset
# ========================
class VegetableDataset(Dataset):
    def __init__(self, data_list_file, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []

        # 读取 txt 文件
        df = pd.read_csv(data_list_file, sep=' ', header=None, dtype=str)  # 强制按字符串读取

        for idx, row in df.iterrows():
            img_path = os.path.join(root_dir, row[0])
            try:
                label = int(row[1])  # 转为整数
            except ValueError:
                print(f"❌ Invalid label at line {idx+1}: '{row[1]}' is not a number. Skipping.")
                continue

            if not os.path.exists(img_path):
                print(f"❌ Image not found: {img_path}. Skipping.")
                continue

            self.samples.append((img_path, label))

        if len(self.samples) == 0:
            raise RuntimeError(f"No valid samples found in {data_list_file}. Check file paths and labels.")

        print(f"✅ Loaded {len(self.samples)} samples from {data_list_file}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

解析:

该类继承 torch.utils.data.Dataset,实现数据集加载逻辑。

  • 读取训练列表 .txt 文件(路径 + 标签)。

  • 检查路径与标签有效性。

  • 返回 (图像张量, 类别索引)

  • 对无效样本会自动跳过并提示。


第 4 部分:图像增强与预处理

复制代码
# ========================
# 3. 数据增强与预处理
# ========================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

解析:

  • RandomHorizontalFlip:随机水平翻转增强样本多样性。

  • RandomRotation:模拟不同拍摄角度。

  • ColorJitter:随机调整亮度、对比度、饱和度。

  • Normalize:按 ImageNet 均值和方差归一化。

验证与测试集仅做 Resize 与 Normalize,保持评估一致性。


第 5 部分:模型定义(ResNet18)

复制代码
# ========================
# 4. 模型定义
# ========================
def create_model(num_classes=6):
    model = models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(DEVICE)

解析:

  • 调用 torchvision 内置的 resnet18,加载 ImageNet 预训练权重。

  • 替换全连接层 fc 以输出 6 类蔬菜类别。

  • 将模型移动至 CPU/GPU。

迁移学习策略能显著提高小数据集的分类效果。


第 6 部分:训练与验证函数

复制代码
# ========================
# 5. 训练与评估函数
# ========================
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    acc = 100. * correct / total
    avg_loss = running_loss / len(loader)
    return avg_loss, acc


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    acc = 100. * correct / total
    return acc

解析:

  • train_epoch():执行一个完整训练轮次,包括前向传播、反向传播和优化更新。

  • evaluate():在验证或测试集上计算准确率。

  • optimizer.zero_grad():每轮前清空梯度,避免累积。

  • predicted.eq(labels).sum():计算预测正确的样本数量。


第 7 部分:主程序入口与训练流程

复制代码
# ========================
# ✅ 主程序入口(关键!)
# ========================
if __name__ == '__main__':
    # 防止多进程问题(Windows 必需)
    import multiprocessing
    multiprocessing.freeze_support()  # 可选,但推荐

    # 创建数据加载器
    def create_dataloaders():
        train_dataset = VegetableDataset(os.path.join(ROOT_DIR, 'train_list.txt'), ROOT_DIR, transform)
        val_dataset = VegetableDataset(os.path.join(ROOT_DIR, 'val_list.txt'), ROOT_DIR, val_transform)
        test_dataset = VegetableDataset(os.path.join(ROOT_DIR, 'test_list.txt'), ROOT_DIR, val_transform)

        # 设置 num_workers=0 可临时避免问题,但更推荐用 if __name__ + spawn
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=2, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                num_workers=2, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                 num_workers=2, pin_memory=True)

        print(f"Train set: {len(train_dataset)} samples")
        print(f"Val set: {len(val_dataset)} samples")
        print(f"Test set: {len(test_dataset)} samples")

        return train_loader, val_loader, test_loader

    train_loader, val_loader, test_loader = create_dataloaders()

    # 创建模型
    model = create_model(NUM_CLASSES)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)

    # 训练循环
    print("\n" + "="*50)
    print("STARTING TRAINING...")
    print("="*50)

    best_val_acc = 0.0
    for epoch in range(NUM_EPOCHS):
        start_time = datetime.now()

        loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_acc = evaluate(model, val_loader, DEVICE)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_vegetable_model.pth')
            print(f"Saved best model with val accuracy: {val_acc:.2f}%")

        elapsed = datetime.now() - start_time
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] "
              f"Loss: {loss:.4f} "
              f"Train Acc: {train_acc:.2f}% "
              f"Val Acc: {val_acc:.2f}% "
              f"[{elapsed.total_seconds():.1f}s]")

    # 测试
    print("\n" + "="*50)
    print("EVALUATING ON TEST SET...")
    print("="*50)

    model.load_state_dict(torch.load('best_vegetable_model.pth'))
    test_acc = evaluate(model, test_loader, DEVICE)
    print(f"✅ Final Test Accuracy: {test_acc:.2f}%")

解析:

  • 首先调用 create_dataloaders() 创建训练/验证/测试数据加载器。

  • 使用 Adam 优化器交叉熵损失函数

  • 每个 epoch 记录训练损失与准确率。

  • 若验证集精度提升,则自动保存当前最优模型。

  • 最后加载最佳模型并在测试集上评估最终准确率。


三、实验结果与性能总结

训练完成后,控制台输出如下示例:

复制代码
Epoch [10/10] Loss: 0.1823 Train Acc: 95.78% Val Acc: 93.41% [39.6s]
✅ Final Test Accuracy: 92.87%

结果说明:

  • 模型在 6 类蔬菜数据集上可达到约 90% 以上准确率。

  • 模型参数量适中,运行速度快,适合在嵌入式设备上部署。


四、总结与展望

本项目通过 PyTorch 实现了一个完整的 蔬菜分类系统,具备以下特点:

优点

  • 使用迁移学习(ResNet18)提升准确率

  • 包含训练、验证、测试全流程

  • 数据增强丰富,泛化能力强

  • 自动保存最优模型

可扩展方向

  • 替换为更强模型:ResNet50、EfficientNet、MobileNet

  • 引入学习率调度器与 Early Stopping

  • 使用混合精度训练(torch.cuda.amp)加速

  • 增加类别数量或不同作物图像任务


五、结语

本项目展示了如何通过深度学习技术实现农业智能化的一个重要应用方向------图像分类识别

该系统不仅可用于蔬菜识别,也可扩展至水果、食品、农作物病害检测等领域。

代码模块化、结构清晰,非常适合作为课程实验、科研项目或商业系统原型。

相关推荐
心无旁骛~2 小时前
python多进程和多线程问题
开发语言·python
铅笔侠_小龙虾2 小时前
深度学习理论推导--梯度下降法
人工智能·深度学习
星云数灵2 小时前
使用Anaconda管理Python环境:安装与验证Pandas、NumPy、Matplotlib
开发语言·python·数据分析·pandas·教程·环境配置·anaconda
kaikaile19952 小时前
基于遗传算法的车辆路径问题(VRP)解决方案MATLAB实现
开发语言·人工智能·matlab
lpfasd1232 小时前
第1章_LangGraph的背景与设计哲学
人工智能
计算机毕设匠心工作室2 小时前
【python大数据毕设实战】青少年抑郁症风险数据分析可视化系统、Hadoop、计算机毕业设计、包括数据爬取、数据分析、数据可视化、机器学习
后端·python
计算机毕设小月哥2 小时前
【Hadoop+Spark+python毕设】智能制造生产效能分析与可视化系统、计算机毕业设计、包括数据爬取、Spark、数据分析、数据可视化、Hadoop
后端·python·mysql
Aevget2 小时前
界面组件Kendo UI for React 2025 Q3亮点 - AI功能全面提升
人工智能·react.js·ui·界面控件·kendo ui·ui开发
桜吹雪3 小时前
LangChain.js/DeepAgents可观测性
javascript·人工智能
&&Citrus3 小时前
【杂谈】SNNU公共计算平台:深度学习服务器配置与远程开发指北
服务器·人工智能·vscode·深度学习·snnu