YOLO 11 图像分类推理 Web 服务
flyfish
python
import os
import io
import uuid
import base64
import time
import torch
import torchvision.transforms as T
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from flask import Flask, request, jsonify
import datetime
# ==================== Configuration - edit the values below ====================
MODEL_PATH = "/home/user/yolo/runs/classify/train/weights/best.pt" # Path of the model weights passed to YOLO()
OUTPUT_BASE = "inference_results" # Base directory under which results are saved
IMGSZ = 320 # Side length images are resized to before inference
MAX_PER_FOLDER = 5000 # Maximum number of images stored per batch folder
PORT = 5000 # Port the web service listens on
HOST = "0.0.0.0" # Listen address; 0.0.0.0 accepts connections on all interfaces
# ==================================================================
# Initialise the Flask application
app = Flask(__name__)
# Module-level holder for the model (assigned once by load_model)
model = None
# Configure matplotlib fonts so Chinese text renders correctly.
# NOTE(review): no matplotlib plotting is visible in this file; these
# settings only affect matplotlib output - confirm they are still needed.
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['AR PL UMing CN']
plt.rcParams['axes.unicode_minus'] = False
def load_model():
    """Load the YOLO weights at MODEL_PATH into the module-level ``model``.

    Meant to be called once when the service starts.

    Returns:
        True when the model was loaded, False when loading raised. The
        failure is printed instead of being propagated.
    """
    global model
    loaded = False
    try:
        model = YOLO(MODEL_PATH)
        print(f"成功加载模型: {MODEL_PATH}")
        loaded = True
    except Exception as err:
        print(f"模型加载失败: {err}")
    return loaded
def preprocess_image(image, imgsz=IMGSZ):
    """Convert a PIL image into a batched tensor for the classifier.

    Args:
        image: Image object supporting ``convert('RGB')`` (a PIL image).
        imgsz: Width and height, in pixels, the image is resized to.

    Returns:
        A ``(batch, rgb_image)`` tuple: the transformed tensor with a
        leading batch dimension added by ``unsqueeze(0)``, and the
        RGB-converted image the tensor was built from.
    """
    rgb_image = image.convert('RGB')

    steps = [
        T.Resize((imgsz, imgsz)),
        T.ToTensor(),
        # Subtracting 0 and dividing by 1 leaves the values as they are.
        T.Normalize(mean=torch.tensor(0), std=torch.tensor(1)),
    ]
    pipeline = T.Compose(steps)

    batch = pipeline(rgb_image).unsqueeze(0)
    return batch, rgb_image
def create_output_directory(base_dir, class_id, max_per_folder=MAX_PER_FOLDER):
    """Return a batch directory for ``class_id`` that still has room.

    Directories are laid out as ``<base_dir>/class_<class_id>/batch_<n>``
    with ``n`` starting at 0. The first batch directory that holds fewer
    than ``max_per_folder`` image files is created if needed and returned.

    Args:
        base_dir: Root directory for all results.
        class_id: Class identifier used to name the class directory.
        max_per_folder: Maximum number of image files per batch directory.

    Returns:
        Path of the batch directory to write the next image into.
    """
    image_suffixes = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    class_dir = os.path.join(base_dir, f"class_{class_id}")

    subdir_index = 0
    while True:
        current_dir = os.path.join(class_dir, f"batch_{subdir_index}")
        os.makedirs(current_dir, exist_ok=True)

        # Count every image file exactly once. Summing the full directory
        # listing once per extension would inflate the count five-fold and
        # roll over to a new batch far too early.
        file_count = sum(
            1
            for name in os.listdir(current_dir)
            if name.lower().endswith(image_suffixes)
        )
        if file_count < max_per_folder:
            return current_dir
        subdir_index += 1
def _safe_filename_part(value):
    """Return ``value`` as text that is safe to embed in a file name."""
    # Path separators and other special characters become "_" so that a
    # client-supplied id cannot redirect the write to another directory.
    return "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value)
    )


def save_result_image(original_img, result, image_id):
    """Save a copy of the image annotated with the prediction.

    The file name is ``<timestamp>_<image_id>_<8 uuid chars>.jpg`` and the
    file is written into the batch directory chosen by
    ``create_output_directory`` for the predicted class.

    Args:
        original_img: PIL image to annotate; it is copied, not modified.
        result: Inference result exposing ``probs.top1`` and
            ``probs.top1conf``.
        image_id: Caller-supplied identifier. It is sanitised before being
            used in the file name because it comes from the request.

    Returns:
        Path of the saved image.
    """
    draw_img = original_img.copy()
    draw = ImageDraw.Draw(draw_img)

    class_id = result.probs.top1
    confidence = result.probs.top1conf.item()
    text = f"(ID: {class_id}): {confidence:.4f}"

    # Prefer the WenQuanYi system font; fall back to PIL's built-in font.
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
            16,
            index=0
        )
    except Exception as e:
        print(f"文泉驿字体加载失败: {e},将使用默认字体")
        font = ImageFont.load_default()

    # White box behind the label keeps the red text readable on any image.
    text_bbox = draw.textbbox((10, 10), text, font=font)
    draw.rectangle(
        [text_bbox[0] - 2, text_bbox[1] - 2, text_bbox[2] + 2, text_bbox[3] + 2],
        fill="white",
    )
    draw.text((10, 10), text, font=font, fill=(255, 0, 0))

    current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_str = str(uuid.uuid4())[:8]
    safe_id = _safe_filename_part(image_id)
    file_name = f"{current_time}_{safe_id}_{unique_str}.jpg"

    output_dir = create_output_directory(OUTPUT_BASE, class_id)
    output_path = os.path.join(output_dir, file_name)
    draw_img.save(output_path)
    return output_path
def process_image(image, image_id):
    """Run inference on one image and log the outcome.

    The image is preprocessed, classified with the module-level ``model``,
    saved with its annotation, and a record is appended to
    ``inference_results.txt`` under OUTPUT_BASE.

    Args:
        image: PIL image to classify.
        image_id: Caller-supplied identifier echoed in the log and result.

    Returns:
        On success, a dict with ``success`` True, ``image_id``,
        ``current_time``, ``class_id`` and ``confidence``. If any step
        raises, the error is printed and logged and a dict with
        ``success`` False, ``image_id`` and ``error`` is returned.
    """
    log_path = os.path.join(OUTPUT_BASE, "inference_results.txt")
    try:
        img_tensor, original_img = preprocess_image(image)

        # Only the first result is used.
        first_result = model(img_tensor)[0]
        class_id = first_result.probs.top1
        confidence = first_result.probs.top1conf.item()

        output_path = save_result_image(original_img, first_result, image_id)

        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        record = "".join([
            f"[{current_time}] 图像ID: {image_id}\n",
            f" 预测ID: {class_id}\n",
            f" 置信度: {confidence:.4f}\n",
            f" 保存路径: {output_path}\n",
            "-" * 50 + "\n",
        ])
        with open(log_path, 'a', encoding='utf-8') as log_file:
            log_file.write(record)

        outcome = {
            "success": True,
            "image_id": image_id,
            "current_time": current_time,
            "class_id": int(class_id),
            "confidence": float(confidence)
        }
        return outcome
    except Exception as exc:
        error_msg = f"处理图像ID {image_id} 时出错: {str(exc)}\n"
        print(error_msg)
        # Keep a trace of the failure in the same log file.
        with open(log_path, 'a', encoding='utf-8') as log_file:
            stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            log_file.write(f"[{stamp}] {error_msg}")
        failure = {
            "success": False,
            "image_id": image_id,
            "error": str(exc)
        }
        return failure
@app.route('/infer', methods=['POST'])
def infer():
    """Classify a base64-encoded image sent as JSON.

    Expects a JSON object with ``image_id`` and ``image_base64`` fields.

    Returns:
        The dict produced by ``process_image`` as JSON, or a JSON error
        object with status 400 (missing or malformed request data) or 500
        (model not loaded, or decoding/processing raised).
    """
    if model is None:
        return jsonify({
            "success": False,
            "error": "模型未加载,请检查服务状态"
        }), 500

    # silent=True yields None for a missing or unparsable JSON body, so the
    # client receives the JSON error below rather than Flask's own page.
    data = request.get_json(silent=True)

    # Require a JSON object: a bare string or array could pass the "in"
    # checks yet cannot be indexed by field name.
    if (
        not isinstance(data, dict)
        or 'image_id' not in data
        or 'image_base64' not in data
    ):
        return jsonify({
            "success": False,
            "error": "请求数据缺少image_id或image_base64字段"
        }), 400

    image_id = data['image_id']
    try:
        image_data = base64.b64decode(data['image_base64'])
        image = Image.open(io.BytesIO(image_data))
        result = process_image(image, image_id)
        return jsonify(result)
    except Exception as e:
        return jsonify({
            "success": False,
            "image_id": image_id,
            "error": f"处理图像时出错: {str(e)}"
        }), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the model has been loaded.

    Returns:
        A JSON object with ``status``, ``model_loaded`` and ``timestamp``.
        The response carries status 500 when the model is not loaded.
    """
    is_loaded = model is not None
    payload = {
        "status": "healthy" if is_loaded else "unhealthy",
        "model_loaded": is_loaded,
        "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    if is_loaded:
        return jsonify(payload)
    return jsonify(payload), 500
if __name__ == "__main__":
    # Make sure the results directory exists before anything is written.
    os.makedirs(OUTPUT_BASE, exist_ok=True)
    # The service is only started when the model could be loaded.
    if not load_model():
        print("模型加载失败,无法启动服务")
    else:
        print(f"服务启动,监听 {HOST}:{PORT}")
        # threaded=True lets the server handle requests in multiple threads.
        app.run(host=HOST, port=PORT, threaded=True)
使用方式
- 启动服务:
```bash
python yolo11_web_service.py
```
- 客户端发送请求示例(使用Python):
python
import requests
import base64

# Read the image file and encode it as base64 text
with open("test.jpg", "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode('utf-8')

# Build the request payload expected by the /infer endpoint
data = {
    "image_id": "test_001",
    "image_base64": image_base64
}

# Send the request; the timeout keeps the client from waiting forever
# if the service does not respond
response = requests.post("http://localhost:5000/infer", json=data, timeout=30)
print(response.json())