从模型到 API:Flask+PyTorch 快速搭建图像分类

在深度学习项目落地时,将训练好的模型封装成可远程调用的 API 接口是核心环节。本文将完整讲解如何基于 Flask(轻量级 Web 框架)和 PyTorch(深度学习框架),实现图像分类服务端 API 开发 + 客户端调用 的全流程,让训练好的模型通过 HTTP 接口对外提供预测服务。

一、服务端开发:搭建图像分类 API 接口

服务端的核心职责是:加载预训练模型、接收客户端上传的图片、预处理图片、执行预测、返回 JSON 格式的预测结果。

1.完整服务端代码

python 复制代码
import io
import flask
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from torchvision import transforms, models

# 初始化Flask应用
app = flask.Flask(__name__)

# 全局变量:模型实例、是否使用GPU
model = None
use_gpu = False  # 测试阶段建议关闭GPU,避免环境问题

def load_model():
    """加载预训练的ResNet18模型(可替换为自定义模型)"""
    global model
    # 1. 加载ResNet18主干网络
    model = models.resnet18(pretrained=False)  # 不加载默认预训练权重
    # 2. 修改全连接层,适配自定义分类任务(示例为102分类)
    num_ftrs = model.fc.in_features
    model.fc = nn.Sequential(nn.Linear(num_ftrs, 102))
    # 3. 加载训练好的权重文件
    checkpoint = torch.load('best.pth', map_location='cpu')  # 强制使用CPU,避免GPU/CPU不匹配
    model.load_state_dict(checkpoint['state_dict'])
    # 4. 设置模型为评估模式(禁用Dropout/BatchNorm更新)
    model.eval()
    # 5. 可选:使用GPU(需确保环境有CUDA)
    if use_gpu and torch.cuda.is_available():
        model.cuda()

def prepare_image(image, target_size=(224, 224)):
    """预处理图片:转为RGB、resize、归一化、增加batch维度"""
    # 1. 统一转为RGB格式(避免灰度图等格式异常)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # 2. 图像变换:resize -> 转Tensor -> 归一化
    transform = transforms.Compose([
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet均值/方差
    ])
    image = transform(image)
    # 3. 增加batch维度(模型要求输入为[batch, channel, h, w])
    image = image.unsqueeze(0)
    # 4. 可选:转到GPU
    if use_gpu and torch.cuda.is_available():
        image = image.cuda()
    return image

@app.route("/predict", methods=["POST"])
def predict():
    """API核心接口:接收图片,返回预测结果"""
    # 初始化返回数据
    data = {"success": False}
    
    # 仅处理POST请求
    if flask.request.method == "POST":
        # 检查是否有图片上传
        if flask.request.files.get("image"):
            try:
                # 1. 读取并解析图片
                image_bytes = flask.request.files["image"].read()
                image = Image.open(io.BytesIO(image_bytes))
                # 2. 预处理图片
                image = prepare_image(image)
                # 3. 模型预测(禁用梯度计算,提升速度)
                with torch.no_grad():
                    preds = F.softmax(model(image), dim=1)  # 计算类别概率
                # 4. 获取概率最高的前3个结果
                top3_probs, top3_labels = torch.topk(preds.cpu(), k=3, dim=1)
                # 5. 格式化结果
                data['predictions'] = []
                for prob, label in zip(top3_probs.numpy()[0], top3_labels.numpy()[0]):
                    data['predictions'].append({
                        "label": str(label),
                        "probability": float(prob)
                    })
                # 6. 标记请求成功
                data["success"] = True
            except Exception as e:
                print(f"预测出错:{str(e)}")
    
    # 返回JSON格式结果(HTTP响应)
    return flask.jsonify(data)

if __name__ == '__main__':
    print("加载PyTorch模型中...")
    load_model()  # 启动前先加载模型
    print("模型加载完成,启动Flask服务...")
    # 启动服务:host=0.0.0.0允许局域网访问,port为自定义端口
    app.run(host='0.0.0.0', port=5012, debug=False)

2.服务端核心模块解析

(1)模型加载模块 load_model()

  • 加载 ResNet18 主干网络,替换全连接层适配自定义分类数(示例为 102 类)

  • 加载训练好的权重文件best.pth,设置map_location='cpu'避免 GPU/CPU 环境不匹配

  • 调用model.eval()将模型设为评估模式(禁用 Dropout、BatchNorm 等训练层)。

(2)图片预处理模块 prepare_image()

  • 统一转为 RGB 格式,避免灰度图、RGBA 图等格式异常

  • 使用 ImageNet 均值 / 方差归一化(与 ResNet 预训练时的预处理一致)

  • 增加 batch 维度(模型要求输入为 4 维张量:[batch, c, h, w])。

(3)API 接口 predict()

  • 仅接收 POST 请求,确保数据传输安全

  • 通过flask.request.files读取客户端上传的二进制图片

  • 使用torch.no_grad()禁用梯度计算,大幅提升预测速度

  • torch.topk获取概率最高的前 3 个类别,格式化后以 JSON 返回。

二、客户端开发:调用 API 获取预测结果

客户端的核心职责是:读取本地图片、以二进制形式上传到服务端 API、解析返回的 JSON 结果并展示。

1.客户端完整代码

python 复制代码
import requests#第三方库,爬虫中的一个库。
'''客户端的程序代码:功能;负责按照指定格式上传图片和接受结果'''
flask_url='http://127.0.0.1:5012/predict'#_url和端口写成自己的本地ip
def predict_result(image_path):
    image = open(image_path, 'rb').read()
    payload= {'image':image}
    r=requests.post(flask_url,files =payload).json()#是因为网络传辑
    if r['success']:
    # 成功的话再返回。
    # 输出结果
        for(i,result) in enumerate(r['predictions']):
            print('{}.预测类别为{}:的概率:{}'.format(i + 1,result['label'],result['probability']))
    # 失败了就打印
    else:
        print('Request failed')
if __name__ == "__main__":
    predict_result('./flower_data/2.jpg')

2.代码解析

2.1 配置信息
复制代码
import requests#第三方库,爬虫中的一个库。
  • 引入requests库:这是 Python 中处理 HTTP 请求的主流第三方库,不仅用于爬虫,更是接口调用的核心工具,能轻松实现 POST/GET 等请求方式。

    flask_url='http://127.0.0.1:5012/predict'#_url和端口写成自己的本地ip

  • 配置服务端接口地址:127.0.0.1是本地回环地址,对应运行 Flask 服务端的本机;5012是服务端设置的端口号;/predict是服务端开放的预测接口路径,必须与服务端的路由地址完全一致。

2.2 核心预测函数predict_result
复制代码
def predict_result(image_path):
    image = open(image_path, 'rb').read()
  • 读取图片文件:open(image_path, 'rb')以二进制只读模式打开图片,read()将文件内容读取为二进制字节流 ------ 网络传输图片必须用二进制格式,不能直接传文件路径或文本格式。

    payload= {'image':image}

  • 构造请求参数:创建字典payload,键名image是服务端约定的接收图片的参数名(需与服务端flask.request.files.get("image")中的image一致),值为读取的二进制图片数据。

    r=requests.post(flask_url,files =payload).json()

  • 发送 POST 请求并解析结果:

(1)requests.post():向指定的flask_url发送 POST 请求,files参数专门用于上传文件 / 二进制数据

(2).json():将服务端返回的 JSON 格式字符串自动解析为 Python 字典,方便后续取值。

复制代码
if r['success']:
    for(i,result) in enumerate(r['predictions']):
        print('{}.预测类别为{}:的概率:{}'.format(i + 1,result['label'],result['probability']))
else:
    print('Request failed')
  • 解析并输出结果:

(1)服务端返回的字典中,success为True表示预测成功,predictions是包含前 3 个高概率类别结果的列表

(2)enumerate为结果添加序号,通过result['label']和result['probability']分别获取类别标签和对应概率,格式化输出

(3)若success为False,则打印请求失败提示。

三、运行效果

服务端正常响应时,客户端控制台会输出:

复制代码
1.预测类别为5:的概率:0.9875643849372864
2.预测类别为8:的概率:0.0089123456789012
3.预测类别为12:的概率:0.0021098765432109
相关推荐
AI浩1 小时前
自适应图像变焦与边界框变换用于无人机目标检测
人工智能·目标检测·无人机
IT_陈寒2 小时前
SpringBoot开发效率提升50%的5个隐藏技巧,官方文档都没告诉你!
前端·人工智能·后端
大报言看2 小时前
2026年主流大模型API中转平台选型指南:稳定性与工程化能力的深度评估
人工智能·api
balmtv2 小时前
国内AI镜像站技术解析:如何实现GPT-4、Claude 3、Gemini的聚合与加速?
人工智能
坚持学习前端日记2 小时前
Agent AI 前端技术架构设计文档
前端·javascript·人工智能·python
智算菩萨2 小时前
GPT-5.4的“慢思考“艺术:详解推理时计算(Inference-Time Compute)如何重塑复杂任务解决能力
人工智能·gpt·ai·chatgpt
工业甲酰苯胺2 小时前
Docker 容器化 OpenClaw
人工智能·docker·openclaw
zadyd2 小时前
为什么GRPO更适合强逻辑内容的强化学习
人工智能
明月醉窗台2 小时前
Torch-TensorRT 相关
人工智能·目标检测·计算机视觉·目标跟踪