从零开始学习深度学习—水果分类之PyQt5App

一、项目背景⭐:

本项目是"从零开始学习深度学习"系列中的第二个实战项目,旨在实现第一个简易App(图像分类任务------水果分类),进一步地落地AI模型应用,帮助初学者初步了解模型落地。

基于PyQt5图形界面的水果图像分类系统,用户可以通过加载模型、选择图像并一键完成图像识别。

二、项目目标🚀:

基于PyQt5图形界面实现以下功能:

  • 加载本地 .pth 训练好的模型;

  • 加载本地图像进行展示;

  • 自动完成图像预处理(Resize、ToTensor、Normalize);

  • 使用模型完成预测并展示结果;

  • 界面美观,交互友好。

三、适合人群🫵:

  • 深度学习零基础或刚入门的学习者
  • 希望通过项目实战学习BP神经网络、卷积神经网络模型搭建的开发者
  • 对图像识别、分类应用感兴趣的童鞋
  • 适用于想学习通过界面实现AI模型推理,

四、项目实战✊:

1.主界面构建

python 复制代码
    def initUI(self):
        # 主窗口设置
        self.setWindowTitle("水果分类应用")
        self.setGeometry(100, 100, 800, 600)

        # 创建主窗口部件
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        # 创建主布局
        main_layout = QVBoxLayout()

        # 模型选择部分
        model_layout = QHBoxLayout()
        model_label = QLabel("模型路径:")
        self.model_path_edit = QtWidgets.QLineEdit()
        model_button = QPushButton("选择模型")
        model_button.clicked.connect(self.select_model_path)
        self.load_model_button = QPushButton("加载模型")
        self.load_model_button.clicked.connect(self.load_model)
        self.load_model_button.setEnabled(False)

        model_layout.addWidget(model_label)
        model_layout.addWidget(self.model_path_edit)
        model_layout.addWidget(model_button)
        model_layout.addWidget(self.load_model_button)
        main_layout.addLayout(model_layout)

        # 图像显示部分
        self.image_label = QLabel()
        self.image_label.setAlignment(QtCore.Qt.AlignCenter)
        self.image_label.setMinimumSize(600, 400)
        main_layout.addWidget(self.image_label)

        # 图像选择部分
        image_layout = QHBoxLayout()
        image_path_label = QLabel("图像路径:")
        self.image_path_edit = QtWidgets.QLineEdit()
        image_select_button = QPushButton("选择图像")
        image_select_button.clicked.connect(self.select_image_path)
        self.predict_button = QPushButton("分类预测")
        self.predict_button.clicked.connect(self.predict_image)
        self.predict_button.setEnabled(False)

        image_layout.addWidget(image_path_label)
        image_layout.addWidget(self.image_path_edit)
        image_layout.addWidget(image_select_button)
        image_layout.addWidget(self.predict_button)
        main_layout.addLayout(image_layout)

        # 结果显示部分
        self.result_label = QLabel("请先加载模型并选择图像")
        self.result_label.setAlignment(QtCore.Qt.AlignCenter)
        self.result_label.setStyleSheet("font-size: 20px")
        main_layout.addWidget(self.result_label)

        central_widget.setLayout(main_layout)

2.功能辅助函数

python 复制代码
    def select_model_path(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "选择模型文件",
            "",
            "Pytorch模型 (*.pth);;所有文件(*)")
        if file_path:
            self.model_path_edit.setText(file_path)
            self.load_model_button.setEnabled(True)

    def load_model(self):
        model_path = self.model_path_edit.text()
        if not model_path:
            return
        try:
            # 模型类型(根据你的模型的时间需求进行修改)
            self.model = FruitClassificationModelResnet18(4)
            self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
            self.model = self.model.to(self.device)
            self.model.eval()

            self.result_label.setText("模型加载成功!请选择图像进行预测.")
            self.predict_button.setEnabled(True)
        except Exception as e:
            self.result_label.setText(f"模型加载失败: {str(e)}")
            self.model = None
            self.predict_button.setEnabled(False)

    def select_image_path(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "选择图像文件",
            "",
            "图像文件 (*bmp *.png *.jpg *.jpeg);;所有文件(*)"
        )
        if file_path:
            self.image_path_edit.setText(file_path)
            self.display_image(file_path)

    def display_image(self, file_path):
        pixmap = QtGui.QPixmap(file_path)
        if not pixmap.isNull():
            scaled_pixmap = pixmap.scaled(
                self.image_label.size(),
                QtCore.Qt.KeepAspectRatio,
                QtCore.Qt.SmoothTransformation
            )
            self.image_label.setPixmap(scaled_pixmap)
        else:
            self.image_label.setText("无法加载图像")

    def preprocess_image(self, image_path):
        try:
            # 定义图像预处理流程
            transform = transforms.Compose([
                transforms.Resize((224, 224)),  # 调整图像大小为224x224
                transforms.ToTensor(),  # 转换为Tensor格式
                transforms.Normalize([0.485, 0.456, 0.406],  # 标准化均值(ImageNet数据集)
                                     [0.229, 0.224, 0.225])  # 标准化标准差
            ])

            # 打开图像文件
            image = Image.open(image_path)
            # 如果图像不是RGB模式,转换为RGB
            if image.mode != "RGB":
                image = image.convert("RGB")
            # 应用预处理变换并添加batch维度(unsqueeze(0)),然后移动到指定设备
            image = transform(image).unsqueeze(0).to(self.device)
            return image
        except Exception as e:
            self.result_label.setText(f"图像预处理失败: {str(e)}")
            return None

3.加载模型

python 复制代码
    def load_model(self):
        model_path = self.model_path_edit.text()
        if not model_path:
            return
        try:
            # 模型类型(根据你的模型的时间需求进行修改)
            self.model = FruitClassificationModelResnet18(4)
            self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
            self.model = self.model.to(self.device)
            self.model.eval()

            self.result_label.setText("模型加载成功!请选择图像进行预测.")
            self.predict_button.setEnabled(True)
        except Exception as e:
            self.result_label.setText(f"模型加载失败: {str(e)}")
            self.model = None
            self.predict_button.setEnabled(False)

4.预测函数

python 复制代码
    def predict_image(self):
        if not self.model:
            self.result_label.setText("请先加载模型")
            return

        image_path = self.image_path_edit.text()
        if not image_path:
            self.result_label.setText("请选择图像")
            return

        input_tensor = self.preprocess_image(image_path)
        if input_tensor is None:
            return

        # 预测
        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            outputs = self.model(input_tensor)
            _, predicted = torch.max(outputs.data, 1)
            class_id = predicted.item()

        # 显示结果
        class_names = ['Apple', 'Banana', 'Orange', 'Pinenapple']  # 示例类别  根据你的模型进行修改
        if class_id < len(class_names):
            self.result_label.setText(f"预测结果: {class_names[class_id]}")
        else:
            self.result_label.setText(f"预测结果: 未知类别 ({class_id})")

        QtWidgets.QApplication.processEvents()

6.完整实现代码

python 复制代码
import cv2
import sys
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from PyQt5 import QtWidgets, QtCore, QtGui
from PyQt5.QtWidgets import QFileDialog, QLabel, QPushButton, QVBoxLayout, QWidget, QHBoxLayout
from model import FruitClassificationModelResnet18


class FruitClassificationApp(QtWidgets.QMainWindow):
    def __init__(self):
        super().__init__()
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.initUI()

    def initUI(self):
        # 主窗口设置
        self.setWindowTitle("水果分类应用")
        self.setGeometry(100, 100, 800, 600)

        # 创建主窗口部件
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        # 创建主布局
        main_layout = QVBoxLayout()

        # 模型选择部分
        model_layout = QHBoxLayout()
        model_label = QLabel("模型路径:")
        self.model_path_edit = QtWidgets.QLineEdit()
        model_button = QPushButton("选择模型")
        model_button.clicked.connect(self.select_model_path)
        self.load_model_button = QPushButton("加载模型")
        self.load_model_button.clicked.connect(self.load_model)
        self.load_model_button.setEnabled(False)

        model_layout.addWidget(model_label)
        model_layout.addWidget(self.model_path_edit)
        model_layout.addWidget(model_button)
        model_layout.addWidget(self.load_model_button)
        main_layout.addLayout(model_layout)

        # 图像显示部分
        self.image_label = QLabel()
        self.image_label.setAlignment(QtCore.Qt.AlignCenter)
        self.image_label.setMinimumSize(600, 400)
        main_layout.addWidget(self.image_label)

        # 图像选择部分
        image_layout = QHBoxLayout()
        image_path_label = QLabel("图像路径:")
        self.image_path_edit = QtWidgets.QLineEdit()
        image_select_button = QPushButton("选择图像")
        image_select_button.clicked.connect(self.select_image_path)
        self.predict_button = QPushButton("分类预测")
        self.predict_button.clicked.connect(self.predict_image)
        self.predict_button.setEnabled(False)

        image_layout.addWidget(image_path_label)
        image_layout.addWidget(self.image_path_edit)
        image_layout.addWidget(image_select_button)
        image_layout.addWidget(self.predict_button)
        main_layout.addLayout(image_layout)

        # 结果显示部分
        self.result_label = QLabel("请先加载模型并选择图像")
        self.result_label.setAlignment(QtCore.Qt.AlignCenter)
        self.result_label.setStyleSheet("font-size: 20px")
        main_layout.addWidget(self.result_label)

        central_widget.setLayout(main_layout)

    def select_model_path(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "选择模型文件",
            "",
            "Pytorch模型 (*.pth);;所有文件(*)")
        if file_path:
            self.model_path_edit.setText(file_path)
            self.load_model_button.setEnabled(True)

    def load_model(self):
        model_path = self.model_path_edit.text()
        if not model_path:
            return
        try:
            # 模型类型(根据你的模型的时间需求进行修改)
            self.model = FruitClassificationModelResnet18(4)
            self.model.load_state_dict(torch.load(model_path, map_location=self.device, weights_only=False))
            self.model = self.model.to(self.device)
            self.model.eval()

            self.result_label.setText("模型加载成功!请选择图像进行预测.")
            self.predict_button.setEnabled(True)
        except Exception as e:
            self.result_label.setText(f"模型加载失败: {str(e)}")
            self.model = None
            self.predict_button.setEnabled(False)

    def select_image_path(self):
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            "选择图像文件",
            "",
            "图像文件 (*bmp *.png *.jpg *.jpeg);;所有文件(*)"
        )
        if file_path:
            self.image_path_edit.setText(file_path)
            self.display_image(file_path)

    def display_image(self, file_path):
        pixmap = QtGui.QPixmap(file_path)
        if not pixmap.isNull():
            scaled_pixmap = pixmap.scaled(
                self.image_label.size(),
                QtCore.Qt.KeepAspectRatio,
                QtCore.Qt.SmoothTransformation
            )
            self.image_label.setPixmap(scaled_pixmap)
        else:
            self.image_label.setText("无法加载图像")

    def preprocess_image(self, image_path):
        try:
            # 定义图像预处理流程
            transform = transforms.Compose([
                transforms.Resize((224, 224)),  # 调整图像大小为224x224
                transforms.ToTensor(),  # 转换为Tensor格式
                transforms.Normalize([0.485, 0.456, 0.406],  # 标准化均值(ImageNet数据集)
                                     [0.229, 0.224, 0.225])  # 标准化标准差
            ])

            # 打开图像文件
            image = Image.open(image_path)
            # 如果图像不是RGB模式,转换为RGB
            if image.mode != "RGB":
                image = image.convert("RGB")
            # 应用预处理变换并添加batch维度(unsqueeze(0)),然后移动到指定设备
            image = transform(image).unsqueeze(0).to(self.device)
            return image
        except Exception as e:
            self.result_label.setText(f"图像预处理失败: {str(e)}")
            return None

    def predict_image(self):
        if not self.model:
            self.result_label.setText("请先加载模型")
            return

        image_path = self.image_path_edit.text()
        if not image_path:
            self.result_label.setText("请选择图像")
            return

        input_tensor = self.preprocess_image(image_path)
        if input_tensor is None:
            return

        # 预测
        with torch.no_grad():
            input_tensor = input_tensor.to(self.device)
            outputs = self.model(input_tensor)
            _, predicted = torch.max(outputs.data, 1)
            class_id = predicted.item()

        # 显示结果
        class_names = ['Apple', 'Banana', 'Orange', 'Pinenapple']  # 示例类别  根据你的模型进行修改
        if class_id < len(class_names):
            self.result_label.setText(f"预测结果: {class_names[class_id]}")
        else:
            self.result_label.setText(f"预测结果: 未知类别 ({class_id})")

        QtWidgets.QApplication.processEvents()


if __name__ == '__main__':
    app = QtWidgets.QApplication(sys.argv)
    window = FruitClassificationApp()
    window.show()
    sys.exit(app.exec_())

五、学习收获🎁:

通过本次 PyTorch 与 PyQt5 的项目实战,不仅巩固了深度学习模型的使用方法,也系统地学习了如何将模型部署到图形界面中。以下是我的一些具体收获:

1️⃣ 深度学习模型部署实践

  • 学会了如何将 .pth 格式的模型加载到推理环境;

  • 熟悉了图像的预处理流程(如Resize、ToTensor、Normalize);

  • 掌握了 torch.no_grad() 推理模式下的使用,避免梯度计算加速推理。

2️⃣ PyQt5 图形界面开发

  • 掌握了 PyQt5 中常用的控件如 QLabelQPushButtonQLineEdit 等;

  • 学会了如何使用 QFileDialog 实现文件选择;

  • 了解了如何通过 QPixmap 加载并展示图像;

  • 熟悉了 QVBoxLayoutQHBoxLayout 进行界面布局。

3️⃣ 端到端流程整合

  • 实现了从模型加载 → 图像读取 → 图像预处理 → 推理 → 展示结果 的完整流程;

  • 初步理解了如何将 AI 模型变成一个用户可交互的软件;

  • 为后续构建更复杂的推理系统(如视频流识别、多模型切换)打下了基础。

注:完整代码,请私聊,免费获取。

相关推荐
北京迅为2 小时前
《【北京迅为】itop-3568开发板NPU使用手册》- 第 7章 使用RKNN-Toolkit-lite2
linux·人工智能·嵌入式·npu
我是一只puppy2 小时前
使用AI进行代码审查
javascript·人工智能·git·安全·源代码管理
阿杰学AI2 小时前
AI核心知识91——大语言模型之 Transformer 架构(简洁且通俗易懂版)
人工智能·深度学习·ai·语言模型·自然语言处理·aigc·transformer
esmap2 小时前
ESMAP 智慧消防解决方案:以数字孪生技术构建全域感知消防体系,赋能消防安全管理智能化升级
人工智能·物联网·3d·编辑器·智慧城市
LaughingZhu2 小时前
Product Hunt 每日热榜 | 2026-02-08
大数据·人工智能·经验分享·搜索引擎·产品运营
芷栀夏2 小时前
CANN ops-math:筑牢 AI 神经网络底层的高性能数学运算算子库核心实现
人工智能·深度学习·神经网络
用户5191495848452 小时前
CVE-2025-47812:Wing FTP Server 高危RCE漏洞分析与利用
人工智能·aigc
阿里云大数据AI技术2 小时前
【AAAI2026】阿里云人工智能平台PAI视频编辑算法论文入选
人工智能
玄同7652 小时前
我的 Trae Skill 实践|使用 UV 工具一键搭建 Python 项目开发环境
开发语言·人工智能·python·langchain·uv·trae·vibe coding
苍何2 小时前
腾讯重磅开源!混元图像 3.0 图生图真香!
人工智能