YOLO 11 图像分类推理 Web 服务

YOLO 11 图像分类推理 Web 服务

flyfish

python 复制代码
import os
import io
import uuid
import base64
import time
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from flask import Flask, request, jsonify
import datetime

# ==================== 配置参数 - 在此处修改配置 ====================
MODEL_PATH = "/home/user/yolo/runs/classify/train/weights/best.pt"  # 模型路径
OUTPUT_BASE = "inference_results"  # 结果保存的基础目录
IMGSZ = 320  # 图像尺寸
MAX_PER_FOLDER = 5000  # 每个文件夹最多存放的图片数量
PORT = 5000  # Web服务端口
HOST = "0.0.0.0"  # 监听地址,0.0.0.0表示允许所有网络访问
# ==================================================================

# 初始化Flask应用
app = Flask(__name__)

# 全局变量,用于存储模型(只加载一次)
model = None

# 确保中文正常显示
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['AR PL UMing CN']
plt.rcParams['axes.unicode_minus'] = False

def load_model():
    """加载YOLO模型,只在服务启动时调用一次"""
    global model
    try:
        model = YOLO(MODEL_PATH)
        print(f"成功加载模型: {MODEL_PATH}")
        return True
    except Exception as e:
        print(f"模型加载失败: {e}")
        return False

def preprocess_image(image, imgsz=IMGSZ):
    """图像预处理,与训练时保持一致"""
    # 转换为RGB格式
    img = image.convert('RGB')
    
    # 定义与训练时相同的预处理步骤
    transform = T.Compose([
        T.Resize((imgsz, imgsz)),
        T.ToTensor(),
        T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
    ])
    
    # 预处理图像
    img_tensor = transform(img)
    # 添加批次维度
    img_tensor = img_tensor.unsqueeze(0)
    
    return img_tensor, img

def create_output_directory(base_dir, class_id, max_per_folder=MAX_PER_FOLDER):
    """创建输出目录,当图像数量超过max_per_folder时创建新的子目录"""
    # 基础类别目录
    class_dir = os.path.join(base_dir, f"class_{class_id}")
    
    # 检查是否需要创建子目录,从batch_0开始
    subdir_index = 0
    while True:
        current_dir = os.path.join(class_dir, f"batch_{subdir_index}")
        
        # 确保目录存在
        os.makedirs(current_dir, exist_ok=True)
        
        # 计算当前目录中的文件数量
        image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif']
        file_count = 0
        for ext in image_extensions:
            file_count += len(os.listdir(current_dir)) if os.path.exists(current_dir) else 0
            
        if file_count < max_per_folder:
            return current_dir
        
        subdir_index += 1

def save_result_image(original_img, result, image_id):
    """保存带有推理结果的图像,文件名包含当前时间、图像id和不重复字符串"""
    # 创建可绘制的图像副本
    draw_img = original_img.copy()
    draw = ImageDraw.Draw(draw_img)
    
    # 获取结果信息
    class_id = result.probs.top1
    confidence = result.probs.top1conf.item()
    
    # 准备要显示的文本(不含class_name)
    text = f"(ID: {class_id}): {confidence:.4f}"
    
    # 设置字体(尝试使用系统字体)
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", 
            16,
            index=0
        )
    except Exception as e:
        print(f"文泉驿字体加载失败: {e},将使用默认字体")
        font = ImageFont.load_default()
    
    # 在图像上绘制文本背景和文本
    text_bbox = draw.textbbox((10, 10), text, font=font)
    draw.rectangle([text_bbox[0]-2, text_bbox[1]-2, text_bbox[2]+2, text_bbox[3]+2], fill="white")
    draw.text((10, 10), text, font=font, fill=(255, 0, 0))  # 红色文本
    
    # 生成文件名:当前时间_图像id_不重复字符串
    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_str = str(uuid.uuid4())[:8]  # 取UUID的前8位作为不重复字符串
    file_name = f"{current_time}_{image_id}_{unique_str}.jpg"
    
    # 确定输出目录
    output_dir = create_output_directory(OUTPUT_BASE, class_id)
    
    # 保存图像
    output_path = os.path.join(output_dir, file_name)
    draw_img.save(output_path)
    
    return output_path

def process_image(image, image_id):
    """处理单张图像并返回推理结果"""
    try:
        # 预处理图像
        img_tensor, original_img = preprocess_image(image)
        
        # 推理
        results = model(img_tensor)
        
        # 解析结果
        result = results[0]
        class_id = result.probs.top1  # 最可能的类别ID
        confidence = result.probs.top1conf.item()  # 对应的置信度
        
        # 保存结果图像
        output_path = save_result_image(original_img, result, image_id)
        
        # 记录结果到文本文件
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        result_file = os.path.join(OUTPUT_BASE, "inference_results.txt")
        with open(result_file, 'a', encoding='utf-8') as f:
            result_str = f"[{current_time}] 图像ID: {image_id}\n"
            result_str += f"  预测ID: {class_id}\n"
            result_str += f"  置信度: {confidence:.4f}\n"
            result_str += f"  保存路径: {output_path}\n"
            result_str += "-"*50 + "\n"
            f.write(result_str)
        
        return {
            "success": True,
            "image_id": image_id,
            "current_time": current_time,
            "class_id": int(class_id),
            "confidence": float(confidence)
        }
        
    except Exception as e:
        error_msg = f"处理图像ID {image_id} 时出错: {str(e)}\n"
        print(error_msg)
        
        # 记录错误信息
        result_file = os.path.join(OUTPUT_BASE, "inference_results.txt")
        with open(result_file, 'a', encoding='utf-8') as f:
            f.write(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {error_msg}")
        
        return {
            "success": False,
            "image_id": image_id,
            "error": str(e)
        }

@app.route('/infer', methods=['POST'])
def infer():
    """API接口:接收图像ID和base64编码的图像,返回推理结果"""
    # 检查模型是否已加载
    if model is None:
        return jsonify({
            "success": False,
            "error": "模型未加载,请检查服务状态"
        }), 500
    
    # 获取请求数据
    data = request.json
    
    # 验证请求数据
    if not data or 'image_id' not in data or 'image_base64' not in data:
        return jsonify({
            "success": False,
            "error": "请求数据缺少image_id或image_base64字段"
        }), 400
    
    try:
        # 解码base64图像
        image_data = base64.b64decode(data['image_base64'])
        image = Image.open(io.BytesIO(image_data))
        
        # 处理图像
        result = process_image(image, data['image_id'])
        
        return jsonify(result)
        
    except Exception as e:
        return jsonify({
            "success": False,
            "image_id": data.get('image_id'),
            "error": f"处理图像时出错: {str(e)}"
        }), 500

@app.route('/health', methods=['GET'])
def health_check():
    """健康检查接口"""
    if model is not None:
        return jsonify({
            "status": "healthy",
            "model_loaded": True,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        })
    else:
        return jsonify({
            "status": "unhealthy",
            "model_loaded": False,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }), 500

if __name__ == "__main__":
    # 创建输出目录
    os.makedirs(OUTPUT_BASE, exist_ok=True)
    
    # 加载模型
    model_loaded = load_model()
    
    # 启动Web服务
    if model_loaded:
        print(f"服务启动,监听 {HOST}:{PORT}")
        app.run(host=HOST, port=PORT, threaded=True)  # threaded=True支持多线程处理请求
    else:
        print("模型加载失败,无法启动服务")
    

使用方式

  1. 启动服务:
bash 复制代码
python yolo11_web_service.py
  1. 客户端发送请求示例(使用Python):
python 复制代码
import requests
import base64

# 读取图片并转换为base64
with open("test.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode('utf-8')

# 准备请求数据
data = {
    "image_id": "test_001",
    "image_base64": image_base64
}

# 发送请求
response = requests.post("http://localhost:5000/infer", json=data)
print(response.json())
相关推荐
newxtc1 小时前
【山西政务服务网-注册_登录安全分析报告】
selenium·安全·yolo·政务·安全爆破
Coovally AI模型快速验证1 小时前
IDEA研究院发布Rex-Omni:3B参数MLLM重塑目标检测,零样本性能超越DINO
人工智能·深度学习·yolo·目标检测·计算机视觉·目标跟踪
来酱何人11 小时前
实时NLP数据处理:流数据的清洗、特征提取与模型推理适配
人工智能·深度学习·分类·nlp·bert
XIAO·宝14 小时前
深度学习------YOLOV3
人工智能·深度学习·yolo
newxtc16 小时前
【广州公共资源交易-注册安全分析报告-无验证方式导致安全隐患】
开发语言·selenium·安全·yolo
总有刁民想爱朕ha17 小时前
YOLO目标检测:一种用于无人机的新型轻量级目标检测网络
yolo·目标检测·无人机
baole96318 小时前
YOLOv4简单基础学习
学习·yolo·目标跟踪
我叫侯小科20 小时前
YOLOv4:目标检测界的 “集大成者”
人工智能·yolo·目标检测
麒羽76021 小时前
YOLOv4:目标检测领域的 “速度与精度平衡大师”
yolo·目标检测·目标跟踪
前网易架构师-高司机1 天前
鸡蛋质量识别数据集,可识别染血的鸡蛋,棕色鸡蛋,钙沉积鸡蛋,污垢染色的鸡蛋,白鸡蛋,平均正确识别率可达89%,支持yolo, json, xml格式的标注
yolo·分类·数据集·缺陷·鸡蛋