下图所示的是一个图像分类系统,理论上也支持其他场景的图像分类需求。以花卉分类为例,可在界面上选择数据集、自动划分数据集,并配置训练时的迭代次数、学习率等超参数,即可进行训练;训练完成后可对模型进行测试,输出混淆矩阵,并支持单张图片预测和批量预测。核心代码如下所示:
目前界面还在完善中,敬请各位看官谅解。
python
class TrainingWidget(QWidget):
    """Model-training panel: configure, run, monitor and save a Keras classifier.

    The panel stacks five sections: model configuration (model type, epochs,
    batch size, learning rate, dataset directory), training control buttons,
    a progress bar with a status line, a matplotlib canvas for the
    accuracy/loss curves, and a read-only log.

    Training runs synchronously on the GUI thread; the Keras callback pumps
    the Qt event loop so the window stays responsive and the stop button
    can be honoured.
    """

    def __init__(self):
        super().__init__()
        # State is created before the UI so that no slot can ever observe
        # a widget whose attributes are missing.
        self.model = None
        self.history = None
        self._stop_requested = False
        self.init_ui()

    # ------------------------------------------------------------------
    # UI construction
    # ------------------------------------------------------------------
    def init_ui(self):
        """Build the panel by stacking its five sections vertically."""
        layout = QVBoxLayout()
        layout.addWidget(self._build_model_group())
        layout.addWidget(self._build_control_group())
        layout.addWidget(self._build_progress_group())
        layout.addWidget(self._build_curve_group())
        layout.addWidget(self._build_log_group())
        self.setLayout(layout)

    def _build_model_group(self):
        """Create the model-configuration section and its input widgets."""
        group = QGroupBox("模型配置")
        group_layout = QVBoxLayout()

        # Model type; index 1 selects the transfer-learning model.
        self.model_type = QComboBox()
        self.model_type.addItems(["CNN模型", "MobileNetV2迁移学习"])
        type_row = QHBoxLayout()
        type_row.addWidget(QLabel("模型类型:"))
        type_row.addWidget(self.model_type)
        type_row.addStretch()
        group_layout.addLayout(type_row)

        # Hyper-parameters
        self.epochs_spin = QSpinBox()
        self.epochs_spin.setRange(1, 100)
        self.epochs_spin.setValue(10)

        self.batch_size_spin = QSpinBox()
        self.batch_size_spin.setRange(1, 128)
        self.batch_size_spin.setValue(4)

        self.learning_rate = QDoubleSpinBox()
        # The precision must be configured first: QDoubleSpinBox rounds the
        # range and the value to the current number of decimals (2 by
        # default), which would turn 0.001 into 0.00.
        self.learning_rate.setDecimals(4)
        self.learning_rate.setRange(0.0001, 0.1)
        self.learning_rate.setSingleStep(0.0001)
        self.learning_rate.setValue(0.001)

        # One grid row of alternating caption / editor cells.
        param_grid = QGridLayout()
        fields = (
            ("训练轮数:", self.epochs_spin),
            ("批次大小:", self.batch_size_spin),
            ("学习率:", self.learning_rate),
        )
        for index, (caption, editor) in enumerate(fields):
            param_grid.addWidget(QLabel(caption), 0, 2 * index)
            param_grid.addWidget(editor, 0, 2 * index + 1)
        group_layout.addLayout(param_grid)

        # Dataset directory
        self.train_data_path = QLineEdit("./data/flower_photos")
        data_row = QHBoxLayout()
        data_row.addWidget(QLabel("训练数据:"))
        data_row.addWidget(self.train_data_path)
        group_layout.addLayout(data_row)

        group.setLayout(group_layout)
        return group

    def _build_control_group(self):
        """Create the start / stop / save buttons and wire their slots."""
        group = QGroupBox("训练控制")
        row = QHBoxLayout()

        self.train_btn = QPushButton("开始训练")
        self.train_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton("停止训练")
        self.stop_btn.setEnabled(False)
        # Without this connection the stop button would have no effect.
        self.stop_btn.clicked.connect(self.stop_training)

        self.save_btn = QPushButton("保存模型")
        self.save_btn.clicked.connect(self.save_model)

        for button in (self.train_btn, self.stop_btn, self.save_btn):
            row.addWidget(button)
        group.setLayout(row)
        return group

    def _build_progress_group(self):
        """Create the progress bar and the one-line status label."""
        group = QGroupBox("训练进度")
        column = QVBoxLayout()
        self.progress_bar = QProgressBar()
        self.status_label = QLabel("准备就绪")
        column.addWidget(self.progress_bar)
        column.addWidget(self.status_label)
        group.setLayout(column)
        return group

    def _build_curve_group(self):
        """Create the matplotlib canvas that shows the training curves."""
        group = QGroupBox("训练曲线")
        column = QVBoxLayout()
        self.figure = Figure(figsize=(10, 6))
        self.canvas = FigureCanvas(self.figure)
        column.addWidget(self.canvas)
        group.setLayout(column)
        return group

    def _build_log_group(self):
        """Create the read-only, height-limited training log."""
        group = QGroupBox("训练日志")
        column = QVBoxLayout()
        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        self.log_text.setMaximumHeight(150)
        column.addWidget(self.log_text)
        group.setLayout(column)
        return group

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _log(self, message):
        """Append ``message`` to the log view, prefixed with ``[HH:MM:SS]``."""
        self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] {message}")

    def _make_progress_callback(self, epochs):
        """Return a Keras callback that mirrors training progress into the UI.

        Args:
            epochs: Total number of epochs requested, used to compute the
                progress percentage.
        """
        widget = self

        class ProgressCallback(tf.keras.callbacks.Callback):
            def _pump_events(self):
                # Let Qt handle pending events (repaints, button clicks),
                # then forward a pending stop request to Keras through the
                # model's ``stop_training`` flag.
                QApplication.processEvents()
                if widget._stop_requested:
                    self.model.stop_training = True

            def on_train_batch_end(self, batch, logs=None):
                self._pump_events()

            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}  # Keras may pass None
                loss = logs.get('loss', float('nan'))
                accuracy = logs.get('accuracy', float('nan'))
                finished = epoch + 1  # ``epoch`` is zero-based
                widget.progress_bar.setValue(int(finished / epochs * 100))
                widget.status_label.setText(
                    f"Epoch {finished}/{epochs} - Loss: {loss:.4f} - Accuracy: {accuracy:.4f}")
                widget._log(
                    f"Epoch {finished}/{epochs} - Loss: {loss:.4f} - Acc: {accuracy:.4f}")
                self._pump_events()

        return ProgressCallback()

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def start_training(self):
        """Read the settings from the UI, then build and train the model.

        Shows a warning and returns if the dataset path does not exist.
        Any exception raised while loading data, building the model or
        training is logged and reported in a message box; the buttons are
        restored in every case.
        """
        self._log("开始训练...")

        data_dir = self.train_data_path.text()
        if not os.path.exists(data_dir):
            QMessageBox.warning(self, "错误", "训练数据路径不存在!")
            return

        batch_size = self.batch_size_spin.value()
        epochs = self.epochs_spin.value()
        learning_rate = self.learning_rate.value()
        is_transfer = self.model_type.currentIndex() == 1

        self._stop_requested = False
        self.train_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.progress_bar.setValue(0)

        try:
            train_ds, val_ds, class_names = self.data_load(data_dir, 224, 224, batch_size)

            # Size the classifier head from the dataset and apply the
            # learning rate chosen in the UI.
            self.model = self.model_load(
                is_transfer=is_transfer,
                num_classes=len(class_names),
                learning_rate=learning_rate,
            )

            self._log(f"模型类型: {'MobileNetV2迁移学习' if is_transfer else 'CNN模型'}")
            self._log(f"训练轮数: {epochs}, 批次大小: {batch_size}")
            self._log(f"学习率: {learning_rate}, 类别数: {len(class_names)}")

            self.history = self.model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=epochs,
                callbacks=[self._make_progress_callback(epochs)],
            )

            self.show_training_curves()
            if self._stop_requested:
                self._log("训练已停止")
                self.status_label.setText("训练已停止")
            else:
                self._log("训练完成!")
                self.status_label.setText("训练完成")
        except Exception as e:
            # Top-level UI boundary: report the failure instead of crashing.
            self._log(f"训练出错: {str(e)}")
            QMessageBox.critical(self, "错误", f"训练失败: {str(e)}")
        finally:
            self._stop_requested = False
            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)

    def stop_training(self):
        """Request that the running training loop stop.

        Only a flag is set here; the progress callback picks it up the next
        time it runs and asks Keras to end ``fit``.
        """
        self._stop_requested = True
        self.stop_btn.setEnabled(False)
        self._log("正在停止训练...")

    def data_load(self, data_dir, img_height, img_width, batch_size):
        """Load the training and validation datasets from ``data_dir``.

        Both subsets are built with the same seed and an 80/20 split, with
        one-hot (``categorical``) labels.

        Args:
            data_dir: Root directory passed to ``image_dataset_from_directory``.
            img_height: Height the images are resized to.
            img_width: Width the images are resized to.
            batch_size: Number of images per batch.

        Returns:
            ``(train_ds, val_ds, class_names)``.
        """
        shared = dict(
            label_mode='categorical',
            validation_split=0.2,
            seed=123,  # same seed for both subsets so they use the same split
            image_size=(img_height, img_width),
            batch_size=batch_size,
        )
        loader = tf.keras.preprocessing.image_dataset_from_directory
        train_ds = loader(data_dir, subset="training", **shared)
        val_ds = loader(data_dir, subset="validation", **shared)
        class_names = train_ds.class_names
        return train_ds, val_ds, class_names

    def model_load(self, IMG_SHAPE=(224, 224, 3), is_transfer=False,
                   num_classes=5, learning_rate=0.001):
        """Build and compile the classifier.

        Args:
            IMG_SHAPE: Input shape as ``(height, width, channels)``.
            is_transfer: If true, use a frozen ImageNet MobileNetV2 backbone
                with a new classification head; otherwise build a small CNN.
            num_classes: Number of units in the softmax output layer.
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            The compiled ``tf.keras`` model.
        """
        layers = tf.keras.layers
        rescaling = layers.experimental.preprocessing.Rescaling

        if is_transfer:
            base_model = tf.keras.applications.MobileNetV2(
                input_shape=IMG_SHAPE,
                include_top=False,
                weights='imagenet'
            )
            # Freeze the backbone; only the new head is trained.
            base_model.trainable = False
            model = tf.keras.models.Sequential([
                # x / 127.5 - 1: maps 0..255 pixel values to [-1, 1]
                rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
                base_model,
                layers.GlobalAveragePooling2D(),
                layers.Dense(num_classes, activation='softmax')
            ])
        else:
            model = tf.keras.models.Sequential([
                # x / 255: maps 0..255 pixel values to [0, 1]
                rescaling(1. / 255, input_shape=IMG_SHAPE),
                layers.Conv2D(32, (3, 3), activation='relu'),
                layers.MaxPooling2D(2, 2),
                layers.Conv2D(64, (3, 3), activation='relu'),
                layers.MaxPooling2D(2, 2),
                layers.Flatten(),
                layers.Dense(128, activation='relu'),
                layers.Dense(num_classes, activation='softmax')
            ])

        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        return model

    def show_training_curves(self):
        """Plot accuracy and loss per epoch, side by side, on the canvas.

        Does nothing when there is no training history or when no epoch
        has been recorded yet.
        """
        if self.history is None:
            return
        records = self.history.history
        if not records.get('accuracy') or not records.get('loss'):
            return

        self.figure.clear()
        # (subplot position, train key, train label, val key, val label,
        #  title, y-axis label)
        panels = (
            (1, 'accuracy', '训练准确率', 'val_accuracy', '验证准确率', '模型准确率', '准确率'),
            (2, 'loss', '训练损失', 'val_loss', '验证损失', '模型损失', '损失'),
        )
        for position, key, label, val_key, val_label, title, y_label in panels:
            axes = self.figure.add_subplot(1, 2, position)
            axes.plot(records[key], label=label)
            if val_key in records:
                axes.plot(records[val_key], label=val_label)
            axes.set_title(title)
            axes.set_xlabel('Epoch')
            axes.set_ylabel(y_label)
            axes.legend()
            axes.grid(True)

        self.figure.tight_layout()
        self.canvas.draw()

    def save_model(self):
        """Ask for a target path and save the current model as an H5 file.

        Warns and returns when no model exists yet; does nothing when the
        dialog is cancelled.
        """
        if self.model is None:
            QMessageBox.warning(self, "警告", "没有可保存的模型!")
            return

        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存模型", "./models", "H5 Files (*.h5)"
        )
        if not file_path:
            return

        # The dialog only offers H5, but the user may type a bare name;
        # enforce the suffix so Keras actually writes the H5 format.
        if not file_path.lower().endswith(".h5"):
            file_path += ".h5"

        try:
            self.model.save(file_path)
            self._log(f"模型已保存到: {file_path}")
            QMessageBox.information(self, "成功", "模型保存成功!")
        except Exception as e:
            self._log(f"模型保存失败: {str(e)}")
            QMessageBox.critical(self, "错误", f"模型保存失败: {str(e)}")
测试时输出的log:
测试结果:
损失值: 0.3617
准确率: 0.8965 (89.65%)
测试样本数: 734
各类别准确率:
daisy: 93.80% (正确: 121/129)
dandelion: 96.02% (正确: 169/176)
roses: 74.17% (正确: 89/120)
sunflowers: 88.82% (正确: 135/152)
tulips: 91.72% (正确: 144/157)
模型架构摘要:
总参数量: 2264389
层数: 4