FiveFlower 五类花卉图像分类系统

目录

一、项目概述

二、项目结构

三、核心模块详解

[3.1 配置管理模块 (config.py)](#3.1 配置管理模块 (config.py))

[3.2 数据加载模块 (dataload)](#3.2 数据加载模块 (dataload))

[3.3 模型模块 (classic_models)](#3.3 模型模块 (classic_models))

[3.4 工具模块 (utils)](#3.4 工具模块 (utils))

四、主要脚本使用说明

[4.1 train.py - 训练脚本](#4.1 train.py - 训练脚本)

[4.2 predict.py - 预测脚本](#4.2 predict.py - 预测脚本)

[4.3 gui.py - PyQt5 图形界面](#4.3 gui.py - PyQt5 图形界面)

[4.4 run.py - 快速启动](#4.4 run.py - 快速启动)

五、快速开始

[5.1 环境准备](#5.1 环境准备)

[5.2 数据集准备](#5.2 数据集准备)

[5.3 启动方式](#5.3 启动方式)

六、核心技术特性

七、注意事项


一、项目概述

FiveFlower 是一个基于深度学习的五类花卉图像分类系统。该项目从原始代码重构优化而来,采用模块化设计,新增 PyQt5 图形用户界面,支持模型训练、图像预测等功能。

支持的花卉类别:

• daisy(雏菊)

• dandelion(蒲公英)

• roses(玫瑰)

• sunflowers(向日葵)

• tulips(郁金香)

二、项目结构

|---------------------|-----|-------------------|
| 文件/目录 | 类型 | 说明 |
| _FiveFlower/ | 包目录 | 核心代码包 |
| ├── config.py | 模块 | 统一配置管理 |
| ├── dataload/ | 包 | 数据加载模块 |
| ├── classic_models/ | 包 | 模型定义模块 |
| └── utils/ | 包 | 工具函数模块 |
| train.py | 脚本 | 命令行训练脚本 |
| predict.py | 脚本 | 命令行预测脚本 |
| gui.py | 脚本 | PyQt5 图形界面 |
| run.py | 脚本 | 快速启动脚本 |
| Datasets/ | 目录 | 数据集(train/val子目录) |
| results/ | 目录 | 训练结果保存 |

三、核心模块详解

3.1 配置管理模块 (config.py)

config.py 采用集中式配置管理,将所有超参数、路径、数据集信息统一存放,便于维护和修改。

核心配置项示例:

复制代码
# 路径配置
PROJECT_ROOT = Path(__file__).parent.parent
DATA_ROOT = PROJECT_ROOT / "Datasets"
RESULTS_ROOT = PROJECT_ROOT / "results"

# 数据集配置
DATASET_CONFIG = {
    "num_classes": 5,
    "image_size": 224,
    "class_names": ["daisy", "dandelion", "roses",
                    "sunflowers", "tulips"],
}

# 训练默认参数
TRAIN_CONFIG = {
    "epochs": 50,
    "batch_size": 32,
    "lr": 0.001,
    "optimizer": "adamw",
    "scheduler": "cosine",
    "early_stop_patience": 10,
    "use_amp": True,  # 混合精度训练
}
  • 设计优点:

• 修改参数只需编辑一处,无需遍历所有脚本

• 新增功能时可直接引用已有配置

• get_transforms() 函数返回训练/验证的数据增强变换

3.2 数据加载模块 (dataload)

(1)FiveFlowersDataset 数据集类

继承自 torch.utils.data.Dataset,负责加载花卉图像:

复制代码
class FiveFlowersDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.samples = []
        
        # 遍历文件夹,按目录名自动识别类别
        for class_idx, class_name in enumerate(self.class_names):
            class_dir = self.root_dir / class_name
            for img_path in class_dir.glob("*"):
                if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
                    self.samples.append((str(img_path), class_idx))
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

关键特性:

  • • 自动扫描目录结构,无需手动标注文件
  • • 支持 jpg/jpeg/png/bmp/webp 多种格式
  • • transform 参数支持自定义数据增强

(2)FiveFlowersDataLoader 数据加载器

封装数据加载逻辑,返回可直接使用的 DataLoader:

复制代码
class FiveFlowersDataLoader:
    def get_loaders(self):
        # 训练集使用数据增强
        train_dataset = FiveFlowersDataset(
            self.train_path,
            transform=get_transforms("train")
        )
        # 验证集仅做基本预处理
        val_dataset = FiveFlowersDataset(
            self.val_path,
            transform=get_transforms("val")
        )
        
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True, num_workers=self.num_workers)
        val_loader = DataLoader(val_dataset,
                                batch_size=self.batch_size,
                                shuffle=False, num_workers=self.num_workers)
        return train_loader, val_loader, self.class_names

3.3 模型模块 (classic_models)

ModelFactory 工厂类统一管理模型创建和加载:

复制代码
class ModelFactory:
    @staticmethod
    def create_model(model_name, num_classes=5, pretrained=True,
                     freeze_layers=0, dropout=0.5):
        """创建预训练模型并修改分类头"""
        if model_name == 'resnet18':
            model = models.resnet18(pretrained=pretrained)
            # 冻结指定层数
            if freeze_layers > 0:
                for param in list(model.parameters())[:-freeze_layers]:
                    param.requires_grad = False
            # 修改最后一层
            model.fc = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(model.fc.in_features, num_classes)
            )
        elif model_name == 'efficientnet_b0':
            model = models.efficientnet_b0(pretrained=pretrained)
            model.classifier[1] = nn.Linear(
                model.classifier[1].in_features, num_classes
            )
        # ... 其他模型
        return model

支持的模型:ResNet18/34、VGG16、EfficientNet-B0、MobileNet-V3

  • 关键方法:

• create_model(): 创建模型,支持预训练权重、层冻结、dropout

• load_checkpoint(): 加载已训练的权重文件

• get_model_summary(): 获取模型参数量统计

3.4 工具模块 (utils)

(1)Trainer 训练器类

复制代码
class Trainer:
    def train_epoch(self, model, train_loader, criterion,
                    optimizer, device, use_amp=False):
        """训练一个 epoch,支持混合精度"""
        model.train()
        running_loss = 0.0
        scaler = torch.cuda.amp.GradScaler() if use_amp else None
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            
            if use_amp:
                with torch.cuda.amp.autocast():
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
            
            running_loss += loss.item()
        return running_loss / len(train_loader)

混合精度训练 (AMP):利用 Tensor Core 加速,减少显存占用,训练速度提升约 2 倍。

(2)EarlyStopping 早停机制

复制代码
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        
    def __call__(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False  # 继续训练
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True  # 停止训练
            return False

当验证损失连续 patience 次未改善时,自动停止训练,防止过拟合。

四、主要脚本使用说明

4.1 train.py - 训练脚本

复制代码
# 使用示例
python train.py --model resnet18 --epochs 50 --batch_size 32 --lr 0.001

# 完整参数
python train.py \
    --model resnet18 \          # 模型选择
    --data_path Datasets \       # 数据集路径
    --epochs 50 \                # 训练轮数
    --batch_size 32 \            # 批大小
    --lr 0.001 \                 # 学习率
    --optimizer adamw \          # 优化器
    --scheduler cosine \         # 学习率调度
    --early_stop 10 \            # 早停轮数
    --use_amp \                  # 启用混合精度
    --label_smoothing 0.1         # 标签平滑

训练流程:

  1. 加载数据集,划分训练集和验证集

  2. 创建模型,加载预训练权重

  3. 配置优化器和学习率调度器

  4. 循环训练,每个 epoch 后验证

  5. 保存最优模型到 results/weights/

4.2 predict.py - 预测脚本

单张图像预测

python predict.py --image flower.jpg --model resnet18 --weight best_model.pth

输出示例

预测结果: roses (玫瑰)

置信度: 95.3%

各类别概率:

daisy: 2.1%

dandelion: 0.8%

roses: 95.3%

sunflowers: 1.2%

tulips: 0.6%

4.3 gui.py - PyQt5 图形界面

图形界面提供三个功能标签页:

(1)图像预测页

复制代码
# 核心预测逻辑
def predict_image(self, image_path):
    # 加载图像
    image = Image.open(image_path).convert('RGB')
    input_tensor = self.transform(image).unsqueeze(0).to(self.device)
    
    # 模型推理
    self.model.eval()
    with torch.no_grad():
        outputs = self.model(input_tensor)
        probs = torch.nn.functional.softmax(outputs, dim=1)
    
    # 显示结果
    pred_class = self.class_names[probs.argmax()]
    confidence = probs.max().item() * 100

(2)模型训练页

• 可视化参数配置(模型、轮数、学习率等)

• 训练进度条实时显示

• 日志输出窗口

(3)模型管理页

• 查看已训练模型列表

• 加载/删除模型权重

4.4 run.py - 快速启动

一键启动 GUI

python run.py

脚本会自动:

1. 检查依赖是否安装

2. 检查数据集目录

3. 启动 PyQt5 界面

五、快速开始

5.1 环境准备

复制代码
# 安装依赖
pip install -r requirements.txt

# requirements.txt 内容
torch>=1.8.0
torchvision>=0.9.0
PyQt5>=5.10
matplotlib
numpy
Pillow

5.2 数据集准备

Datasets/

├── train/

│ ├── daisy/

│ │ ├── img001.jpg

│ │ └── ...

│ ├── dandelion/

│ ├── roses/

│ ├── sunflowers/

│ └── tulips/

└── val/

├── daisy/

├── dandelion/

├── roses/

├── sunflowers/

└── tulips/

5.3 启动方式

命令行训练:

python train.py --model resnet18 --epochs 50

命令行预测:

python predict.py --image test.jpg --model resnet18

图形界面:

python run.py

六、核心技术特性

|--------|-----------|---------------------------------------|
| 技术 | 作用 | 实现方式 |
| 混合精度训练 | 加速训练、减少显存 | torch.cuda.amp.autocast |
| 学习率调度 | 自动调整学习率 | CosineAnnealingLR + Warmup |
| 早停机制 | 防止过拟合 | EarlyStopping 类 |
| 数据增强 | 提升泛化能力 | RandomCrop/Flip/Rotation/ColorJitter |
| 标签平滑 | 提高模型鲁棒性 | CrossEntropyLoss(label_smoothing=0.1) |
| 层冻结 | 加速收敛 | param.requires_grad = False |
| 后台训练 | 界面不卡顿 | QThread + pyqtSignal |

七、注意事项

  • 首次运行会自动下载预训练权重(约 100MB),需要网络连接
  • 建议使用 GPU 训练,CPU 训练速度较慢
  • PyQt5 版本问题:如遇 setNotation 报错,请升级 PyQt5:pip install --upgrade PyQt5
  • 训练结果默认保存在 results/weights/ 目录
  • 支持断点续训:加载已有权重继续训练
相关推荐
vx_biyesheji00015 小时前
计算机毕业设计:Python股价预测与可视化系统 Flask框架 数据分析 可视化 机器学习 随机森林 大数据(建议收藏)✅
python·机器学习·信息可视化·数据分析·flask·课程设计
大龄程序员狗哥10 小时前
第25篇:Q-Learning算法解析——强化学习中的经典“价值”学习(原理解析)
人工智能·学习·算法
陶陶然Yay10 小时前
神经网络常见层Numpy封装参考(5):其他层
人工智能·神经网络·numpy
极客老王说Agent10 小时前
2026实战指南:如何用智能体实现药品不良反应报告的自动录入?
人工智能·ai·chatgpt
imbackneverdie10 小时前
本科毕业论文怎么写?需要用到什么工具?
人工智能·考研·aigc·ai写作·学术·毕业论文·ai工具
lulu121654407810 小时前
Claude Code项目大了响应慢怎么办?Subagents、Agent Teams、Git Worktree、工作流编排四种方案深度解析
java·人工智能·python·ai编程
大橙子打游戏10 小时前
talkcozy像聊微信一样多项目同时开发
人工智能·vibecoding
deephub10 小时前
LangChain 还是 LangGraph?一个是编排一个是工具包
人工智能·langchain·大语言模型·langgraph
OidEncoder11 小时前
编码器分辨率与机械精度的关系
人工智能·算法·机器人·自动化
Championship.23.2411 小时前
Harness工程深度解析:从理论到实践的完整指南
人工智能·harness