YOLO 11 图像分类推理 Web 服务

YOLO 11 图像分类推理 Web 服务

flyfish

python 复制代码
import os
import io
import uuid
import base64
import time
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from flask import Flask, request, jsonify
import datetime

# ==================== 配置参数 - 在此处修改配置 ====================
MODEL_PATH = "/home/user/yolo/runs/classify/train/weights/best.pt"  # 模型路径
OUTPUT_BASE = "inference_results"  # 结果保存的基础目录
IMGSZ = 320  # 图像尺寸
MAX_PER_FOLDER = 5000  # 每个文件夹最多存放的图片数量
PORT = 5000  # Web服务端口
HOST = "0.0.0.0"  # 监听地址,0.0.0.0表示允许所有网络访问
# ==================================================================

# 初始化Flask应用
app = Flask(__name__)

# 全局变量,用于存储模型(只加载一次)
model = None

# 确保中文正常显示
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['AR PL UMing CN']
plt.rcParams['axes.unicode_minus'] = False

def load_model():
    """加载YOLO模型,只在服务启动时调用一次"""
    global model
    try:
        model = YOLO(MODEL_PATH)
        print(f"成功加载模型: {MODEL_PATH}")
        return True
    except Exception as e:
        print(f"模型加载失败: {e}")
        return False

def preprocess_image(image, imgsz=IMGSZ):
    """图像预处理,与训练时保持一致"""
    # 转换为RGB格式
    img = image.convert('RGB')
    
    # 定义与训练时相同的预处理步骤
    transform = T.Compose([
        T.Resize((imgsz, imgsz)),
        T.ToTensor(),
        T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
    ])
    
    # 预处理图像
    img_tensor = transform(img)
    # 添加批次维度
    img_tensor = img_tensor.unsqueeze(0)
    
    return img_tensor, img

def create_output_directory(base_dir, class_id, max_per_folder=MAX_PER_FOLDER):
    """创建输出目录,当图像数量超过max_per_folder时创建新的子目录"""
    # 基础类别目录
    class_dir = os.path.join(base_dir, f"class_{class_id}")
    
    # 检查是否需要创建子目录,从batch_0开始
    subdir_index = 0
    while True:
        current_dir = os.path.join(class_dir, f"batch_{subdir_index}")
        
        # 确保目录存在
        os.makedirs(current_dir, exist_ok=True)
        
        # 计算当前目录中的文件数量
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif']
        file_count = 0
        for ext in image_extensions:
            file_count += len(os.listdir(current_dir)) if os.path.exists(current_dir) else 0
            
        if file_count < max_per_folder:
            return current_dir
        
        subdir_index += 1

def save_result_image(original_img, result, image_id):
    """保存带有推理结果的图像,文件名包含当前时间、图像id和不重复字符串"""
    # 创建可绘制的图像副本
    draw_img = original_img.copy()
    draw = ImageDraw.Draw(draw_img)
    
    # 获取结果信息
    class_id = result.probs.top1
    confidence = result.probs.top1conf.item()
    
    # 准备要显示的文本(不含class_name)
    text = f"(ID: {class_id}): {confidence:.4f}"
    
    # 设置字体(尝试使用系统字体)
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", 
            16,
            index=0
        )
    except Exception as e:
        print(f"文泉驿字体加载失败: {e},将使用默认字体")
        font = ImageFont.load_default()
    
    # 在图像上绘制文本背景和文本
    text_bbox = draw.textbbox((10, 10), text, font=font)
    draw.rectangle([text_bbox[0]-2, text_bbox[1]-2, text_bbox[2]+2, text_bbox[3]+2], fill="white")
    draw.text((10, 10), text, font=font, fill=(255, 0, 0))  # 红色文本
    
    # 生成文件名:当前时间_图像id_不重复字符串
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_str = str(uuid.uuid4())[:8]  # 取UUID的前8位作为不重复字符串
    file_name = f"{current_time}_{image_id}_{unique_str}.jpg"
    
    # 确定输出目录
    output_dir = create_output_directory(OUTPUT_BASE, class_id)
    
    # 保存图像
    output_path = os.path.join(output_dir, file_name)
    draw_img.save(output_path)
    
    return output_path

def process_image(image, image_id):
    """处理单张图像并返回推理结果"""
    try:
        # 预处理图像
        img_tensor, original_img = preprocess_image(image)
        
        # 推理
        results = model(img_tensor)
        
        # 解析结果
        result = results[0]
        class_id = result.probs.top1  # 最可能的类别ID
        confidence = result.probs.top1conf.item()  # 对应的置信度
        
        # 保存结果图像
        output_path = save_result_image(original_img, result, image_id)
        
        # 记录结果到文本文件
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        result_file = os.path.join(OUTPUT_BASE, "inference_results.txt")
        with open(result_file, 'a', encoding='utf-8') as f:
            result_str = f"[{current_time}] 图像ID: {image_id}\n"
            result_str += f"  预测ID: {class_id}\n"
            result_str += f"  置信度: {confidence:.4f}\n"
            result_str += f"  保存路径: {output_path}\n"
            result_str += "-"*50 + "\n"
            f.write(result_str)
        
        return {
            "success": True,
            "image_id": image_id,
            "current_time": current_time,
            "class_id": int(class_id),
            "confidence": float(confidence)
        }
        
    except Exception as e:
        error_msg = f"处理图像ID {image_id} 时出错: {str(e)}\n"
        print(error_msg)
        
        # 记录错误信息
        result_file = os.path.join(OUTPUT_BASE, "inference_results.txt")
        with open(result_file, 'a', encoding='utf-8') as f:
            f.write(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {error_msg}")
        
        return {
            "success": False,
            "image_id": image_id,
            "error": str(e)
        }

@app.route('/infer', methods=['POST'])
def infer():
    """API接口:接收图像ID和base64编码的图像,返回推理结果"""
    # 检查模型是否已加载
    if model is None:
        return jsonify({
            "success": False,
            "error": "模型未加载,请检查服务状态"
        }), 500
    
    # 获取请求数据
    data = request.json
    
    # 验证请求数据
    if not data or 'image_id' not in data or 'image_base64' not in data:
        return jsonify({
            "success": False,
            "error": "请求数据缺少image_id或image_base64字段"
        }), 400
    
    try:
        # 解码base64图像
        image_data = base64.b64decode(data['image_base64'])
        image = Image.open(io.BytesIO(image_data))
        
        # 处理图像
        result = process_image(image, data['image_id'])
        
        return jsonify(result)
        
    except Exception as e:
        return jsonify({
            "success": False,
            "image_id": data.get('image_id'),
            "error": f"处理图像时出错: {str(e)}"
        }), 500

@app.route('/health', methods=['GET'])
def health_check():
    """健康检查接口"""
    if model is not None:
        return jsonify({
            "status": "healthy",
            "model_loaded": True,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        })
    else:
        return jsonify({
            "status": "unhealthy",
            "model_loaded": False,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }), 500

if __name__ == "__main__":
    # 创建输出目录
    os.makedirs(OUTPUT_BASE, exist_ok=True)
    
    # 加载模型
    model_loaded = load_model()
    
    # 启动Web服务
    if model_loaded:
        print(f"服务启动,监听 {HOST}:{PORT}")
        app.run(host=HOST, port=PORT, threaded=True)  # threaded=True支持多线程处理请求
    else:
        print("模型加载失败,无法启动服务")
    

使用方式

  1. 启动服务:
bash 复制代码
python yolo11_web_service.py
  1. 客户端发送请求示例(使用Python):
python 复制代码
import requests
import base64

# 读取图片并转换为base64
with open("test.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode('utf-8')

# 准备请求数据
data = {
    "image_id": "test_001",
    "image_base64": image_base64
}

# 发送请求
response = requests.post("http://localhost:5000/infer", json=data)
print(response.json())
相关推荐
放羊郎5 小时前
SLAM算法分类对比
人工智能·算法·分类·数据挖掘·slam·视觉·激光
飞翔的佩奇1 天前
【完整源码+数据集+部署教程】鸡只与养殖场环境物品图像分割: yolov8-seg等50+全套改进创新点发刊_一键训练教程_Web前端展示
python·yolo·计算机视觉·数据集·yolov8·yolo11·鸡只与养殖场环境物品图像分割
Mitty_Li1 天前
食品分类的代码复习(无半监督部分,无迁移学习部分)
分类·数据挖掘·迁移学习
夏雨不在低喃1 天前
YOLOv8目标检测融合RFLA提高小目标准确率
人工智能·yolo·目标检测
小狗照亮每一天2 天前
【菜狗学聚类】序列嵌入表示、UMAP降维——20250930
算法·分类·聚类
尤超宇2 天前
基于卷积神经网络的 CIFAR-10 图像分类实验报告
人工智能·分类·cnn
哈基鑫2 天前
深度学习之图像分类笔记
笔记·深度学习·分类
Yolo566Q2 天前
基于PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化实践技术应用
pytorch·深度学习·分类
AAIshangyanxiu2 天前
基于PyTorch深度学习遥感影像地物分类与目标检测、分割及遥感影像问题深度学习优化实践技术应用
pytorch·深度学习·分类·地物分类