目录
[3.1 配置管理模块 (config.py)](#3.1 配置管理模块 (config.py))
[3.2 数据加载模块 (dataload)](#3.2 数据加载模块 (dataload))
[3.3 模型模块 (classic_models)](#3.3 模型模块 (classic_models))
[3.4 工具模块 (utils)](#3.4 工具模块 (utils))
[4.1 train.py - 训练脚本](#4.1 train.py - 训练脚本)
[4.2 predict.py - 预测脚本](#4.2 predict.py - 预测脚本)
[4.3 gui.py - PyQt5 图形界面](#4.3 gui.py - PyQt5 图形界面)
[4.4 run.py - 快速启动](#4.4 run.py - 快速启动)
[5.1 环境准备](#5.1 环境准备)
[5.2 数据集准备](#5.2 数据集准备)
[5.3 启动方式](#5.3 启动方式)
一、项目概述
FiveFlower 是一个基于深度学习的五类花卉图像分类系统。该项目从原始代码重构优化而来,采用模块化设计,新增 PyQt5 图形用户界面,支持模型训练、图像预测等功能。
支持的花卉类别:
• daisy(雏菊)
• dandelion(蒲公英)
• roses(玫瑰)
• sunflowers(向日葵)
• tulips(郁金香)
二、项目结构
|---------------------|-----|-------------------|
| 文件/目录 | 类型 | 说明 |
| _FiveFlower/ | 包目录 | 核心代码包 |
| ├── config.py | 模块 | 统一配置管理 |
| ├── dataload/ | 包 | 数据加载模块 |
| ├── classic_models/ | 包 | 模型定义模块 |
| └── utils/ | 包 | 工具函数模块 |
| train.py | 脚本 | 命令行训练脚本 |
| predict.py | 脚本 | 命令行预测脚本 |
| gui.py | 脚本 | PyQt5 图形界面 |
| run.py | 脚本 | 快速启动脚本 |
| Datasets/ | 目录 | 数据集(train/val子目录) |
| results/ | 目录 | 训练结果保存 |
三、核心模块详解
3.1 配置管理模块 (config.py)
config.py 采用集中式配置管理,将所有超参数、路径、数据集信息统一存放,便于维护和修改。
核心配置项示例:
# 路径配置
PROJECT_ROOT = Path(__file__).parent.parent
DATA_ROOT = PROJECT_ROOT / "Datasets"
RESULTS_ROOT = PROJECT_ROOT / "results"
# 数据集配置
DATASET_CONFIG = {
"num_classes": 5,
"image_size": 224,
"class_names": ["daisy", "dandelion", "roses",
"sunflowers", "tulips"],
}
# 训练默认参数
TRAIN_CONFIG = {
"epochs": 50,
"batch_size": 32,
"lr": 0.001,
"optimizer": "adamw",
"scheduler": "cosine",
"early_stop_patience": 10,
"use_amp": True, # 混合精度训练
}
- 设计优点:
• 修改参数只需编辑一处,无需遍历所有脚本
• 新增功能时可直接引用已有配置
• get_transforms() 函数返回训练/验证的数据增强变换
3.2 数据加载模块 (dataload)
(1)FiveFlowersDataset 数据集类
继承自 torch.utils.data.Dataset,负责加载花卉图像:
class FiveFlowersDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = Path(root_dir)
self.transform = transform
self.samples = []
# 遍历文件夹,按目录名自动识别类别
for class_idx, class_name in enumerate(self.class_names):
class_dir = self.root_dir / class_name
for img_path in class_dir.glob("*"):
if img_path.suffix.lower() in ['.jpg', '.jpeg', '.png']:
self.samples.append((str(img_path), class_idx))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, label
关键特性:
- • 自动扫描目录结构,无需手动标注文件
- • 支持 jpg/jpeg/png/bmp/webp 多种格式
- • transform 参数支持自定义数据增强
(2)FiveFlowersDataLoader 数据加载器
封装数据加载逻辑,返回可直接使用的 DataLoader:
class FiveFlowersDataLoader:
def get_loaders(self):
# 训练集使用数据增强
train_dataset = FiveFlowersDataset(
self.train_path,
transform=get_transforms("train")
)
# 验证集仅做基本预处理
val_dataset = FiveFlowersDataset(
self.val_path,
transform=get_transforms("val")
)
train_loader = DataLoader(train_dataset,
batch_size=self.batch_size,
shuffle=True, num_workers=self.num_workers)
val_loader = DataLoader(val_dataset,
batch_size=self.batch_size,
shuffle=False, num_workers=self.num_workers)
return train_loader, val_loader, self.class_names
3.3 模型模块 (classic_models)
ModelFactory 工厂类统一管理模型创建和加载:
class ModelFactory:
@staticmethod
def create_model(model_name, num_classes=5, pretrained=True,
freeze_layers=0, dropout=0.5):
"""创建预训练模型并修改分类头"""
if model_name == 'resnet18':
model = models.resnet18(pretrained=pretrained)
# 冻结指定层数
if freeze_layers > 0:
for param in list(model.parameters())[:-freeze_layers]:
param.requires_grad = False
# 修改最后一层
model.fc = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(model.fc.in_features, num_classes)
)
elif model_name == 'efficientnet_b0':
model = models.efficientnet_b0(pretrained=pretrained)
model.classifier[1] = nn.Linear(
model.classifier[1].in_features, num_classes
)
# ... 其他模型
return model
支持的模型:ResNet18/34、VGG16、EfficientNet-B0、MobileNet-V3
- 关键方法:
• create_model(): 创建模型,支持预训练权重、层冻结、dropout
• load_checkpoint(): 加载已训练的权重文件
• get_model_summary(): 获取模型参数量统计
3.4 工具模块 (utils)
(1)Trainer 训练器类
class Trainer:
def train_epoch(self, model, train_loader, criterion,
optimizer, device, use_amp=False):
"""训练一个 epoch,支持混合精度"""
model.train()
running_loss = 0.0
scaler = torch.cuda.amp.GradScaler() if use_amp else None
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
if use_amp:
with torch.cuda.amp.autocast():
outputs = model(images)
loss = criterion(outputs, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
return running_loss / len(train_loader)
混合精度训练 (AMP):利用 Tensor Core 加速,减少显存占用,训练速度提升约 2 倍。
(2)EarlyStopping 早停机制
class EarlyStopping:
def __init__(self, patience=10, min_delta=0.001):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = None
def __call__(self, val_loss):
if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
self.best_loss = val_loss
self.counter = 0
return False # 继续训练
else:
self.counter += 1
if self.counter >= self.patience:
return True # 停止训练
return False
当验证损失连续 patience 次未改善时,自动停止训练,防止过拟合。
四、主要脚本使用说明
4.1 train.py - 训练脚本
# 使用示例
python train.py --model resnet18 --epochs 50 --batch_size 32 --lr 0.001
# 完整参数
python train.py \
--model resnet18 \ # 模型选择
--data_path Datasets \ # 数据集路径
--epochs 50 \ # 训练轮数
--batch_size 32 \ # 批大小
--lr 0.001 \ # 学习率
--optimizer adamw \ # 优化器
--scheduler cosine \ # 学习率调度
--early_stop 10 \ # 早停轮数
--use_amp \ # 启用混合精度
--label_smoothing 0.1 # 标签平滑
训练流程:
-
加载数据集,划分训练集和验证集
-
创建模型,加载预训练权重
-
配置优化器和学习率调度器
-
循环训练,每个 epoch 后验证
-
保存最优模型到 results/weights/
4.2 predict.py - 预测脚本
单张图像预测
python predict.py --image flower.jpg --model resnet18 --weight best_model.pth
输出示例
预测结果: roses (玫瑰)
置信度: 95.3%
各类别概率:
daisy: 2.1%
dandelion: 0.8%
roses: 95.3%
sunflowers: 1.2%
tulips: 0.6%
4.3 gui.py - PyQt5 图形界面
图形界面提供三个功能标签页:
(1)图像预测页
# 核心预测逻辑
def predict_image(self, image_path):
# 加载图像
image = Image.open(image_path).convert('RGB')
input_tensor = self.transform(image).unsqueeze(0).to(self.device)
# 模型推理
self.model.eval()
with torch.no_grad():
outputs = self.model(input_tensor)
probs = torch.nn.functional.softmax(outputs, dim=1)
# 显示结果
pred_class = self.class_names[probs.argmax()]
confidence = probs.max().item() * 100
(2)模型训练页
• 可视化参数配置(模型、轮数、学习率等)
• 训练进度条实时显示
• 日志输出窗口
(3)模型管理页
• 查看已训练模型列表
• 加载/删除模型权重
4.4 run.py - 快速启动
一键启动 GUI
python run.py
脚本会自动:
1. 检查依赖是否安装
2. 检查数据集目录
3. 启动 PyQt5 界面
五、快速开始
5.1 环境准备
# 安装依赖
pip install -r requirements.txt
# requirements.txt 内容
torch>=1.8.0
torchvision>=0.9.0
PyQt5>=5.10
matplotlib
numpy
Pillow
5.2 数据集准备
Datasets/
├── train/
│ ├── daisy/
│ │ ├── img001.jpg
│ │ └── ...
│ ├── dandelion/
│ ├── roses/
│ ├── sunflowers/
│ └── tulips/
└── val/
├── daisy/
├── dandelion/
├── roses/
├── sunflowers/
└── tulips/
5.3 启动方式
命令行训练:
python train.py --model resnet18 --epochs 50
命令行预测:
python predict.py --image test.jpg --model resnet18
图形界面:
python run.py
六、核心技术特性
|--------|-----------|---------------------------------------|
| 技术 | 作用 | 实现方式 |
| 混合精度训练 | 加速训练、减少显存 | torch.cuda.amp.autocast |
| 学习率调度 | 自动调整学习率 | CosineAnnealingLR + Warmup |
| 早停机制 | 防止过拟合 | EarlyStopping 类 |
| 数据增强 | 提升泛化能力 | RandomCrop/Flip/Rotation/ColorJitter |
| 标签平滑 | 提高模型鲁棒性 | CrossEntropyLoss(label_smoothing=0.1) |
| 层冻结 | 加速收敛 | param.requires_grad = False |
| 后台训练 | 界面不卡顿 | QThread + pyqtSignal |
七、注意事项
- 首次运行会自动下载预训练权重(约 100MB),需要网络连接
- 建议使用 GPU 训练,CPU 训练速度较慢
- PyQt5 版本问题:如遇 setNotation 报错,请升级 PyQt5:pip install --upgrade PyQt5
- 训练结果默认保存在 results/weights/ 目录
- 支持断点续训:加载已有权重继续训练