PyTorch 模型部署实战:用 Flask 搭图像分类 API

在 AI 项目开发中,训练好的模型只有部署成可调用的服务,才能真正落地产生价值。本文将手把手教你用 ResNet18 预训练模型,结合 Flask 框架搭建图像分类 API,并编写客户端程序实现图片上传与结果接收,全程代码可直接复用。

一、项目整体架构

整个项目分为 "服务端" 和 "客户端" 两部分,核心是通过 HTTP 请求实现数据交互。

  • 服务端:基于 Flask 搭建 API 接口,加载 ResNet18 图像分类模型,接收客户端上传的图片,完成预测后返回结果。
  • 客户端:读取本地图片,通过 POST 请求将图片发送到服务端,接收并解析预测结果,最终打印展示。

两者的交互流程非常简洁:客户端传图→服务端预测→服务端返结果→客户端显结果。

二、服务端开发:搭建图像分类 API

服务端是整个系统的核心,需要完成模型加载、图片预处理、预测逻辑和 API 接口定义四个关键步骤。

1. 依赖库安装

首先确保安装所需的 Python 库,直接用 pip 安装即可:

python 复制代码
pip install flask torch torchvision pillow requests

2. 核心代码解析

服务端代码(命名为image_classification_server.py)分为 5 个模块,每个模块功能清晰,可直接复制使用。

(1)导入依赖库

先引入所有需要的工具包,涵盖 Web 服务、模型框架、图像处理等领域:

python 复制代码
import io
import flask
import torch
import torch.nn.functional as F
from torch import nn
from PIL import Image
from torchvision import transforms, models
(2)初始化 Flask 应用与模型变量

创建 Flask 实例,并定义模型和 GPU 使用标志(默认用 CPU,避免环境依赖):

python 复制代码
app = flask.Flask(__name__)
model = None  # 全局模型变量,加载后赋值
use_gpu = False  # 可根据硬件情况改为True
(3)模型加载函数

加载预训练的 ResNet18 模型,替换全连接层适配 102 类分类任务(如花卉分类),并加载训练好的权重文件(best.pth):

python 复制代码
def load_model():
    global model
    # 1. 加载ResNet18基础模型
    model = models.resnet18()
    # 2. 替换全连接层:ResNet18默认输出1000类,改为102类
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
    # 3. 加载训练好的权重(需确保best.pth在当前目录)
    checkpoint = torch.load('best.pth', map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['state_dict'])
    # 4. 设置为评估模式(禁用 dropout 等训练特有的层)
    model.eval()
    # 5. 若使用GPU且设备可用,将模型移到GPU
    if use_gpu and torch.cuda.is_available():
        model.cuda()
(4)图像预处理函数

将客户端上传的图片转换成模型要求的输入格式(ResNet 系列默认输入为 224×224,且需用 ImageNet 数据集的均值和标准差归一化):

python 复制代码
def prepare_image(image, target_size=(224, 224)):
    # 定义图像转换流程
    transform = transforms.Compose([
        transforms.Resize(target_size),  # 缩放图片
        transforms.ToTensor(),  # 转为Tensor(维度:C×H×W)
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet均值
            std=[0.229, 0.224, 0.225]    # ImageNet标准差
        )
    ])
    # 应用转换并添加批次维度(模型要求输入为[batch_size, C, H, W])
    image = transform(image).unsqueeze(0)
    # 若使用GPU,将图片移到GPU
    if use_gpu and torch.cuda.is_available():
        image = image.cuda()
    return image
(5)API 接口定义

定义两个核心接口:/predict(处理图片预测)和/health(服务健康检查)。

  • /predict 接口:接收 POST 请求,处理图片上传、预测和结果返回:
python 复制代码
@app.route('/predict', methods=['POST'])
def predict():
    # 初始化响应字典(默认失败)
    data = {"success": False}
    # 检查请求方法和是否包含图片
    if flask.request.method == 'POST' and flask.request.files.get('image'):
        # 1. 读取图片:从请求中获取二进制图片数据,转为PIL图像
        image_bytes = flask.request.files['image'].read()
        image = Image.open(io.BytesIO(image_bytes)).convert('RGB')  # 确保为RGB格式
        # 2. 预处理图片
        image = prepare_image(image)
        # 3. 模型预测(禁用梯度计算,提升速度)
        with torch.no_grad():
            outputs = model(image)  # 模型输出(logits)
            probabilities = F.softmax(outputs, dim=1)  # 转为概率
            top_k_prob, top_k_indices = torch.topk(probabilities, 5)  # 获取前5个预测结果
        # 4. 处理结果:转为Python原生类型,构造返回格式
        results = []
        for i in range(top_k_prob.size(1)):
            results.append({
                "label": int(top_k_indices[0][i]),  # 类别编号
                "probability": float(top_k_prob[0][i])  # 对应概率
            })
        # 5. 更新响应数据(标记成功,添加预测结果)
        data["predictions"] = results
        data["success"] = True
    # 返回JSON格式响应
    return flask.jsonify(data)
  • /health 接口:用于检查服务是否正常运行、模型是否加载成功:
python 复制代码
@app.route('/health', methods=['GET'])
def health_check():
    return flask.jsonify({"status": "healthy", "model_loaded": model is not None})
(6)启动服务

在主函数中加载模型,并启动 Flask 服务(默认端口 5012,允许外部访问):

python 复制代码
if __name__ == '__main__':
    print('Loading model... Please wait...')
    load_model()  # 启动前加载模型
    print('Model loaded successfully!')
    # 启动服务:host设为0.0.0.0可允许同一局域网内其他设备访问
    app.run(port=5012, host='0.0.0.0')

三、客户端开发:实现图片上传与结果接收

客户端代码(命名为image_classification_client.py)功能简单:读取本地图片,发送到服务端,解析并打印预测结果。

1. 核心代码解析

python 复制代码
import requests

# 服务端API地址(本地测试用127.0.0.1,局域网访问需替换为服务端IP)
flask_url = 'http://127.0.0.1:5012/predict'

def predict_result(image_path):
    # 1. 读取本地图片(二进制模式)
    with open(image_path, 'rb') as f:
        image_data = f.read()
    # 2. 构造请求参数(key为'image',与服务端接收的字段一致)
    payload = {'image': image_data}
    # 3. 发送POST请求,解析JSON响应
    response = requests.post(flask_url, files=payload).json()
    # 4. 处理响应结果
    if response['success']:
        # 打印前5个预测结果(类别编号+概率)
        print("预测结果:")
        for i, result in enumerate(response['predictions']):
            print(f"{i+1}. 类别编号:{result['label']},概率:{result['probability']:.4f}")
    else:
        print("请求失败,请检查服务端或图片路径!")

# 主函数:调用预测函数(替换为你的图片路径)
if __name__ == '__main__':
    predict_result('./flower_data/val_filelist/image_00059.jpg')

四、项目运行步骤

按照以下步骤操作,即可快速跑通整个流程:

  1. 准备模型权重 :将训练好的best.pth文件放到服务端代码同一目录(若没有,可先训练一个 102 类分类模型,或修改代码适配你的类别数)。

  2. 启动服务端 :运行服务端代码,看到 "Model loaded successfully!" 表示启动成功:

    bash

    python 复制代码
    python image_classification_server.py
  3. 运行客户端 :在另一个终端运行客户端代码,确保图片路径正确,即可看到预测结果:

    bash

    python 复制代码
    python image_classification_client.py

五、项目扩展方向

本项目是一个基础的图像分类 API 框架,可根据需求进行扩展:

  • 添加类别名称映射 :目前返回的是类别编号,可添加一个字典(如label2name = {0: '玫瑰', 1: '百合'...}),将编号转为具体名称。
  • 增加图片格式校验:在服务端添加对图片格式(如 JPG、PNG)和大小的校验,提升鲁棒性。
  • 部署到云服务器:将服务端部署到阿里云、腾讯云等平台,配置域名和 HTTPS,实现公网访问。
  • 添加请求限流:使用 Flask-Limiter 等库限制接口调用频率,防止恶意请求。

总结

本文通过 ResNet18+Flask 实现了图像分类 API 的快速搭建,从代码解析到运行步骤都做了详细说明,即使是初学者也能轻松上手。这个框架不仅适用于图像分类,稍作修改后还可用于目标检测、图像分割等其他计算机视觉任务,具有很强的通用性

相关推荐
西柚小萌新6 小时前
【深入浅出PyTorch】--6.2.PyTorch进阶训练技巧2
人工智能·pytorch·python
Kaydeon7 小时前
【AIGC】50倍加速!NVIDIA蒸馏算法rCM:分数正则化连续时间一致性模型的大规模扩散蒸馏
人工智能·pytorch·python·深度学习·计算机视觉·aigc
B站_计算机毕业设计之家8 小时前
大数据实战:Python+Flask 汽车数据分析可视化系统(爬虫+线性回归预测+推荐 源码+文档)✅
大数据·python·数据分析·flask·汽车·线性回归·预测
mortimer9 小时前
从 Python+venv+pip 迁移到 uv 全过程 及 处理 torch + cuda 的跨平台指南
pytorch·python·macos
Access开发易登软件12 小时前
Access调用Azure翻译:轻松实现系统多语言切换
后端·python·低代码·flask·vba·access·access开发
羊羊小栈13 小时前
基于「多模态大模型 + BGE向量检索增强RAG」的新能源汽车故障诊断智能问答系统(vue+flask+AI算法)
vue.js·人工智能·算法·flask·汽车·毕业设计·大作业
my烂笔头13 小时前
计算机视觉 图像分类 → 目标检测 → 实例分割
目标检测·计算机视觉·分类
麦麦大数据15 小时前
F024 vue+flask电影知识图谱推荐系统vue+neo4j +python实现
vue.js·python·flask·知识图谱·推荐算法·电影推荐
zzZ656516 小时前
PyTorch 实现 MNIST 手写数字识别全流程
pytorch·深度学习