FiveFlower 五类花卉图像分类系统

目录

一、项目概述

二、项目结构

三、核心模块详解

[3.1 配置管理模块 (config.py)](#3.1 配置管理模块 (config.py))

[3.2 数据加载模块 (dataload)](#3.2 数据加载模块 (dataload))

[3.3 模型模块 (classic_models)](#3.3 模型模块 (classic_models))

[3.4 工具模块 (utils)](#3.4 工具模块 (utils))

四、主要脚本使用说明

[4.1 train.py - 训练脚本](#4.1 train.py - 训练脚本)

[4.2 predict.py - 预测脚本](#4.2 predict.py - 预测脚本)

[4.3 gui.py - PyQt5 图形界面](#4.3 gui.py - PyQt5 图形界面)

[4.4 run.py - 快速启动](#4.4 run.py - 快速启动)

五、快速开始

[5.1 环境准备](#5.1 环境准备)

[5.2 数据集准备](#5.2 数据集准备)

[5.3 启动方式](#5.3 启动方式)

六、核心技术特性

七、注意事项


一、项目概述

FiveFlower 是一个基于深度学习的五类花卉图像分类系统。该项目从原始代码重构优化而来,采用模块化设计,新增 PyQt5 图形用户界面,支持模型训练、图像预测等功能。

支持的花卉类别:

• daisy(雏菊)

• dandelion(蒲公英)

• roses(玫瑰)

• sunflowers(向日葵)

• tulips(郁金香)

二、项目结构

|---------------------|-----|-------------------|
| 文件/目录 | 类型 | 说明 |
| _FiveFlower/ | 包目录 | 核心代码包 |
| ├── config.py | 模块 | 统一配置管理 |
| ├── dataload/ | 包 | 数据加载模块 |
| ├── classic_models/ | 包 | 模型定义模块 |
| └── utils/ | 包 | 工具函数模块 |
| train.py | 脚本 | 命令行训练脚本 |
| predict.py | 脚本 | 命令行预测脚本 |
| gui.py | 脚本 | PyQt5 图形界面 |
| run.py | 脚本 | 快速启动脚本 |
| Datasets/ | 目录 | 数据集(train/val子目录) |
| results/ | 目录 | 训练结果保存 |

三、核心模块详解

3.1 配置管理模块 (config.py)

config.py 采用集中式配置管理,将所有超参数、路径、数据集信息统一存放,便于维护和修改。

核心配置项示例:

复制代码
# 路径配置
PROJECT_ROOT = Path(__file__).parent.parent
DATA_ROOT = PROJECT_ROOT / "Datasets"
RESULTS_ROOT = PROJECT_ROOT / "results"

# 数据集配置
DATASET_CONFIG = {
    "num_classes": 5,
    "image_size": 224,
    "class_names": ["daisy", "dandelion", "roses",
                    "sunflowers", "tulips"],
}

# 训练默认参数
TRAIN_CONFIG = {
    "epochs": 50,
    "batch_size": 32,
    "lr": 0.001,
    "optimizer": "adamw",
    "scheduler": "cosine",
    "early_stop_patience": 10,
    "use_amp": True,  # 混合精度训练
}
  • 设计优点:

• 修改参数只需编辑一处,无需遍历所有脚本

• 新增功能时可直接引用已有配置

• get_transforms() 函数返回训练/验证的数据增强变换

3.2 数据加载模块 (dataload)

(1)FiveFlowersDataset 数据集类

继承自 torch.utils.data.Dataset,负责加载花卉图像:

复制代码
class FiveFlowersDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.samples = []
        
        # 遍历文件夹,按目录名自动识别类别
        for class_idx, class_name in enumerate(self.class_names):
            class_dir = self.root_dir / class_name
            for img_path in class_dir.glob("*"):
                if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
                    self.samples.append((str(img_path), class_idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

关键特性:

  • • 自动扫描目录结构,无需手动标注文件
  • • 支持 jpg/jpeg/png/bmp/webp 多种格式
  • • transform 参数支持自定义数据增强

(2)FiveFlowersDataLoader 数据加载器

封装数据加载逻辑,返回可直接使用的 DataLoader:

复制代码
class FiveFlowersDataLoader:
    def get_loaders(self):
        # 训练集使用数据增强
        train_dataset = FiveFlowersDataset(
            self.train_path,
            transform=get_transforms("train")
        )
        # 验证集仅做基本预处理
        val_dataset = FiveFlowersDataset(
            self.val_path,
            transform=get_transforms("val")
        )
        
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)
        val_loader = DataLoader(val_dataset,
                                batch_size=self.batch_size,
                                shuffle=False, num_workers=self.num_workers)
        return train_loader, val_loader, self.class_names

3.3 模型模块 (classic_models)

ModelFactory 工厂类统一管理模型创建和加载:

复制代码
class ModelFactory:
    @staticmethod
    def create_model(model_name, num_classes=5, pretrained=True,
                     freeze_layers=0, dropout=0.5):
        """创建预训练模型并修改分类头"""
        if model_name == 'resnet18':
            model = models.resnet18(pretrained=pretrained)
            # 冻结指定层数
            if freeze_layers > 0:
                for param in list(model.parameters())[:-freeze_layers]:
                    param.requires_grad = False
            # 修改最后一层
            model.fc = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(model.fc.in_features, num_classes)
            )
        elif model_name == 'efficientnet_b0':
            model = models.efficientnet_b0(pretrained=pretrained)
            model.classifier[1] = nn.Linear(
                model.classifier[1].in_features, num_classes
            )
        # ... 其他模型
        return model

支持的模型:ResNet18/34、VGG16、EfficientNet-B0、MobileNet-V3

  • 关键方法:

• create_model(): 创建模型,支持预训练权重、层冻结、dropout

• load_checkpoint(): 加载已训练的权重文件

• get_model_summary(): 获取模型参数量统计

3.4 工具模块 (utils)

(1)Trainer 训练器类

复制代码
class Trainer:
    def train_epoch(self, model, train_loader, criterion,
                    optimizer, device, use_amp=False):
        """训练一个 epoch,支持混合精度"""
        model.train()
        running_loss = 0.0
        scaler = torch.cuda.amp.GradScaler() if use_amp else None
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            
            running_loss += loss.item()
        return running_loss / len(train_loader)

混合精度训练 (AMP):利用 Tensor Core 加速,减少显存占用,训练速度提升约 2 倍。

(2)EarlyStopping 早停机制

复制代码
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        
    def __call__(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False  # 继续训练
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True  # 停止训练
            return False

当验证损失连续 patience 次未改善时,自动停止训练,防止过拟合。

四、主要脚本使用说明

4.1 train.py - 训练脚本

复制代码
# 使用示例
python train.py --model resnet18 --epochs 50 --batch_size 32 --lr 0.001

# 完整参数
python train.py \
    --model resnet18 \          # 模型选择
    --data_path Datasets \       # 数据集路径
    --epochs 50 \                # 训练轮数
    --batch_size 32 \            # 批大小
    --lr 0.001 \                 # 学习率
    --optimizer adamw \          # 优化器
    --scheduler cosine \         # 学习率调度
    --early_stop 10 \            # 早停轮数
    --use_amp \                  # 启用混合精度
    --label_smoothing 0.1         # 标签平滑

训练流程:

  1. 加载数据集,划分训练集和验证集

  2. 创建模型,加载预训练权重

  3. 配置优化器和学习率调度器

  4. 循环训练,每个 epoch 后验证

  5. 保存最优模型到 results/weights/

4.2 predict.py - 预测脚本

单张图像预测

python predict.py --image flower.jpg --model resnet18 --weight best_model.pth

输出示例

预测结果: roses (玫瑰)

置信度: 95.3%

各类别概率:

daisy: 2.1%

dandelion: 0.8%

roses: 95.3%

sunflowers: 1.2%

tulips: 0.6%

4.3 gui.py - PyQt5 图形界面

图形界面提供三个功能标签页:

(1)图像预测页

复制代码
# 核心预测逻辑
def predict_image(self, image_path):
    # 加载图像
    image = Image.open(image_path).convert('RGB')
    input_tensor = self.transform(image).unsqueeze(0).to(self.device)
    
    # 模型推理
    self.model.eval()
    with torch.no_grad():
        outputs = self.model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
    
    # 显示结果
    pred_class = self.class_names[probs.argmax()]
    confidence = probs.max().item() * 100

(2)模型训练页

• 可视化参数配置(模型、轮数、学习率等)

• 训练进度条实时显示

• 日志输出窗口

(3)模型管理页

• 查看已训练模型列表

• 加载/删除模型权重

4.4 run.py - 快速启动

一键启动 GUI

python run.py

脚本会自动:

1. 检查依赖是否安装

2. 检查数据集目录

3. 启动 PyQt5 界面

五、快速开始

5.1 环境准备

复制代码
# 安装依赖
pip install -r requirements.txt

# requirements.txt 内容
torch>=1.8.0
torchvision>=0.9.0
PyQt5>=5.10
matplotlib
numpy
Pillow

5.2 数据集准备

Datasets/

├── train/

│ ├── daisy/

│ │ ├── img001.jpg

│ │ └── ...

│ ├── dandelion/

│ ├── roses/

│ ├── sunflowers/

│ └── tulips/

└── val/

├── daisy/

├── dandelion/

├── roses/

├── sunflowers/

└── tulips/

5.3 启动方式

命令行训练:

python train.py --model resnet18 --epochs 50

命令行预测:

python predict.py --image test.jpg --model resnet18

图形界面:

python run.py

六、核心技术特性

|--------|-----------|---------------------------------------|
| 技术 | 作用 | 实现方式 |
| 混合精度训练 | 加速训练、减少显存 | torch.cuda.amp.autocast |
| 学习率调度 | 自动调整学习率 | CosineAnnealingLR + Warmup |
| 早停机制 | 防止过拟合 | EarlyStopping 类 |
| 数据增强 | 提升泛化能力 | RandomCrop/Flip/Rotation/ColorJitter |
| 标签平滑 | 提高模型鲁棒性 | CrossEntropyLoss(label_smoothing=0.1) |
| 层冻结 | 加速收敛 | param.requires_grad = False |
| 后台训练 | 界面不卡顿 | QThread + pyqtSignal |

七、注意事项

  • 首次运行会自动下载预训练权重(约 100MB),需要网络连接
  • 建议使用 GPU 训练,CPU 训练速度较慢
  • PyQt5 版本问题:如遇 setNotation 报错,请升级 PyQt5:pip install --upgrade PyQt5
  • 训练结果默认保存在 results/weights/ 目录
  • 支持断点续训:加载已有权重继续训练
相关推荐
XM_jhxx11 分钟前
±0.03mm的精度怎么保证?翌东塑胶用AI赋能质量管控升级
人工智能
阿正的梦工坊43 分钟前
深入理解 PyTorch 中的 unsqueeze 操作
人工智能·pytorch·python
秦歌6662 小时前
DeepAgents框架详解和文件后端
人工智能·langchain
测试员周周3 小时前
【Appium 系列】第06节-页面对象实现 — LoginPage 实战
开发语言·前端·人工智能·python·功能测试·appium·测试用例
霸道流氓气质3 小时前
基于 Milvus Lite 的 Spring AI RAG 向量库实践方案与示例
人工智能·spring·milvus
ar01233 小时前
AR巡检平台:构筑智能巡检新模式的数字化引擎
人工智能·ar
语音之家3 小时前
【预讲会征集】ACL 2026 论文预讲会
人工智能·论文·acl
碳基硅坊3 小时前
电商场景下的商品自动识别与辅助上架
人工智能
熊猫钓鱼>_>4 小时前
强化学习与决策优化:从理论到工程落地的完整指南
人工智能·llm·强化学习·rl·马尔可夫·mdp·决策过程