基于Tensorflow2.15的图像分类系统

下图所示的是一个图像分类系统,理论上也支持其他场景的图像分类需求,以花卉分类为例,可在界面上选择数据集,自动化划分数据集,配置训练时的迭代次数 学习率等超参数,即可进行训练,训练完成后可对模型进行测试,输出混淆矩阵,支持单站图片预测和批量预测,核心代码如下所示:

目前界面还在完善中,请各位看官敬请谅解



python 复制代码
class TrainingWidget(QWidget):
    """模型训练界面"""

    def __init__(self):
        super().__init__()
        self.init_ui()
        self.model = None
        self.history = None

    def init_ui(self):
        layout = QVBoxLayout()

        # 模型选择
        model_group = QGroupBox("模型配置")
        model_layout = QVBoxLayout()

        # 模型类型选择
        type_layout = QHBoxLayout()
        self.model_type = QComboBox()
        self.model_type.addItems(["CNN模型", "MobileNetV2迁移学习"])
        type_layout.addWidget(QLabel("模型类型:"))
        type_layout.addWidget(self.model_type)
        type_layout.addStretch()
        model_layout.addLayout(type_layout)

        # 训练参数设置
        param_layout = QGridLayout()

        self.epochs_spin = QSpinBox()
        self.epochs_spin.setRange(1, 100)
        self.epochs_spin.setValue(10)

        self.batch_size_spin = QSpinBox()
        self.batch_size_spin.setRange(1, 128)
        self.batch_size_spin.setValue(4)

        self.learning_rate = QDoubleSpinBox()
        self.learning_rate.setRange(0.0001, 0.1)
        self.learning_rate.setValue(0.001)
        self.learning_rate.setSingleStep(0.0001)
        self.learning_rate.setDecimals(4)

        param_layout.addWidget(QLabel("训练轮数:"), 0, 0)
        param_layout.addWidget(self.epochs_spin, 0, 1)
        param_layout.addWidget(QLabel("批次大小:"), 0, 2)
        param_layout.addWidget(self.batch_size_spin, 0, 3)
        param_layout.addWidget(QLabel("学习率:"), 0, 4)
        param_layout.addWidget(self.learning_rate, 0, 5)

        model_layout.addLayout(param_layout)

        # 数据集路径
        data_layout = QHBoxLayout()
        self.train_data_path = QLineEdit("./data/flower_photos")
        data_layout.addWidget(QLabel("训练数据:"))
        data_layout.addWidget(self.train_data_path)
        model_layout.addLayout(data_layout)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # 训练控制
        control_group = QGroupBox("训练控制")
        control_layout = QHBoxLayout()

        self.train_btn = QPushButton("开始训练")
        self.train_btn.clicked.connect(self.start_training)
        self.stop_btn = QPushButton("停止训练")
        self.stop_btn.setEnabled(False)
        self.save_btn = QPushButton("保存模型")
        self.save_btn.clicked.connect(self.save_model)

        control_layout.addWidget(self.train_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addWidget(self.save_btn)

        control_group.setLayout(control_layout)
        layout.addWidget(control_group)

        # 训练进度
        progress_group = QGroupBox("训练进度")
        progress_layout = QVBoxLayout()

        self.progress_bar = QProgressBar()
        self.status_label = QLabel("准备就绪")

        progress_layout.addWidget(self.progress_bar)
        progress_layout.addWidget(self.status_label)

        progress_group.setLayout(progress_layout)
        layout.addWidget(progress_group)

        # 训练曲线
        curve_group = QGroupBox("训练曲线")
        curve_layout = QVBoxLayout()

        self.figure = Figure(figsize=(10, 6))
        self.canvas = FigureCanvas(self.figure)
        curve_layout.addWidget(self.canvas)

        curve_group.setLayout(curve_layout)
        layout.addWidget(curve_group)

        # 训练日志
        log_group = QGroupBox("训练日志")
        log_layout = QVBoxLayout()

        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        self.log_text.setMaximumHeight(150)
        log_layout.addWidget(self.log_text)

        log_group.setLayout(log_layout)
        layout.addWidget(log_group)

        self.setLayout(layout)

    def start_training(self):
        self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] 开始训练...")
        self.train_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)

        # 加载数据
        data_dir = self.train_data_path.text()
        if not os.path.exists(data_dir):
            QMessageBox.warning(self, "错误", "训练数据路径不存在!")
            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)
            return

        batch_size = self.batch_size_spin.value()
        epochs = self.epochs_spin.value()
        is_transfer = self.model_type.currentIndex() == 1

        try:
            # 加载数据
            train_ds, val_ds, class_names = self.data_load(data_dir, 224, 224, batch_size)

            # 加载模型
            self.model = self.model_load(is_transfer=is_transfer)

            # 训练模型
            self.log_text.append(
                f"[{datetime.now().strftime('%H:%M:%S')}] 模型类型: {'MobileNetV2迁移学习' if is_transfer else 'CNN模型'}")
            self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] 训练轮数: {epochs}, 批次大小: {batch_size}")

            # 创建回调来更新进度
            class ProgressCallback(tf.keras.callbacks.Callback):
                def __init__(self, widget, epochs):
                    super().__init__()
                    self.widget = widget
                    self.total_epochs = epochs

                def on_epoch_end(self, epoch, logs=None):
                    progress = int((epoch + 1) / self.total_epochs * 100)
                    self.widget.progress_bar.setValue(progress)
                    self.widget.status_label.setText(
                        f"Epoch {epoch + 1}/{self.total_epochs} - Loss: {logs['loss']:.4f} - Accuracy: {logs['accuracy']:.4f}")
                    self.widget.log_text.append(
                        f"[{datetime.now().strftime('%H:%M:%S')}] Epoch {epoch + 1}/{self.total_epochs} - Loss: {logs['loss']:.4f} - Acc: {logs['accuracy']:.4f}")
                    QApplication.processEvents()

            callback = ProgressCallback(self, epochs)

            # 训练
            self.history = self.model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=epochs,
                callbacks=[callback]
            )

            # 显示训练曲线
            self.show_training_curves()

            self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] 训练完成!")
            self.status_label.setText("训练完成")

        except Exception as e:
            self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] 训练出错: {str(e)}")
            QMessageBox.critical(self, "错误", f"训练失败: {str(e)}")

        finally:
            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)

    def data_load(self, data_dir, img_height, img_width, batch_size):
        """加载数据"""
        train_ds = tf.keras.preprocessing.image_dataset_from_directory(
            data_dir,
            label_mode='categorical',
            validation_split=0.2,
            subset="training",
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size)

        val_ds = tf.keras.preprocessing.image_dataset_from_directory(
            data_dir,
            label_mode='categorical',
            validation_split=0.2,
            subset="validation",
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size)

        class_names = train_ds.class_names
        return train_ds, val_ds, class_names

    def model_load(self, IMG_SHAPE=(224, 224, 3), is_transfer=False):
        """加载模型"""
        if is_transfer:
            base_model = tf.keras.applications.MobileNetV2(
                input_shape=IMG_SHAPE,
                include_top=False,
                weights='imagenet'
            )
            base_model.trainable = False
            model = tf.keras.models.Sequential([
                tf.keras.layers.experimental.preprocessing.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
                base_model,
                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Dense(5, activation='softmax')
            ])
        else:
            model = tf.keras.models.Sequential([
                tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=IMG_SHAPE),
                tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2, 2),
                tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2, 2),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(128, activation='relu'),
                tf.keras.layers.Dense(5, activation='softmax')
            ])

        model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
        return model

    def show_training_curves(self):
        """显示训练曲线"""
        if self.history is None:
            return

        self.figure.clear()

        # 准确率曲线
        ax1 = self.figure.add_subplot(1, 2, 1)
        ax1.plot(self.history.history['accuracy'], label='训练准确率')
        if 'val_accuracy' in self.history.history:
            ax1.plot(self.history.history['val_accuracy'], label='验证准确率')
        ax1.set_title('模型准确率')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('准确率')
        ax1.legend()
        ax1.grid(True)

        # 损失曲线
        ax2 = self.figure.add_subplot(1, 2, 2)
        ax2.plot(self.history.history['loss'], label='训练损失')
        if 'val_loss' in self.history.history:
            ax2.plot(self.history.history['val_loss'], label='验证损失')
        ax2.set_title('模型损失')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('损失')
        ax2.legend()
        ax2.grid(True)

        self.figure.tight_layout()
        self.canvas.draw()

    def save_model(self):
        if self.model is None:
            QMessageBox.warning(self, "警告", "没有可保存的模型!")
            return

        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存模型", "./models", "H5 Files (*.h5)"
        )

        if file_path:
            try:
                self.model.save(file_path)
                self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] 模型已保存到: {file_path}")
                QMessageBox.information(self, "成功", "模型保存成功!")
            except Exception as e:
                QMessageBox.critical(self, "错误", f"模型保存失败: {str(e)}")

测试时输出的log:

测试结果:

损失值: 0.3617

准确率: 0.8965 (89.65%)

测试样本数: 734

各类别准确率:

daisy: 93.80% (正确: 121/129)

dandelion: 96.02% (正确: 169/176)

roses: 74.17% (正确: 89/120)

sunflowers: 88.82% (正确: 135/152)

tulips: 91.72% (正确: 144/157)

模型架构摘要:

总参数量: 2264389

层数: 4


相关推荐
2401_823868225 分钟前
织构表面MATLAB仿真
人工智能·机器学习·matlab·信号处理
霖008 分钟前
高级项目——基于FPGA的串行FIR滤波器
人工智能·经验分享·matlab·fpga开发·信息与通信·信号处理
Debroon24 分钟前
CV 医学影像分类、分割、目标检测,之【腹腔多器官语义分割】项目拆解
目标检测·分类·数据挖掘
掘金一周31 分钟前
我开源了一款 Canvas “瑞士军刀”,十几种“特效与工具”开箱即用 | 掘金一周 8.14
前端·人工智能·后端
程序员半支烟40 分钟前
选择gpt-5还是claude-4-sonnect
人工智能·chatgpt·个人开发
拉一次撑死狗40 分钟前
机器学习实战·第三章 分类(2)
人工智能·机器学习·分类
没事学AI1 小时前
美团搜索推荐统一Agent之交互协议与多Agent协同
人工智能·agent·美团·多agent
霖002 小时前
FPGA的PS基础1
数据结构·人工智能·windows·git·算法·fpga开发
在钱塘江3 小时前
LangGraph构建Ai智能体-12-高级RAG之自适应RAG
人工智能·python