Common platform interfaces and functions

1. API service main program

1) API request parameter validation with jsonschema; a minimal sketch follows.
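
A minimal, self-contained sketch of how jsonschema validation is used (the schema and payload here are illustrative; the real schemas are defined in the server code below):

# Minimal jsonschema usage example (illustrative schema)
from jsonschema import validate, ValidationError

schema = {
    "type": "object",
    "required": ["event_id"],
    "properties": {"event_id": {"type": "integer"}},
}

try:
    validate({"event_id": "not-an-int"}, schema)  # fails: event_id must be an integer
except ValidationError as e:
    print(e.message)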

2) Exception handling.

# Method 1: catch the exception where it occurs
try:
    pass
except Exception as err:
    pass

# Method 2: raise inside sub-functions in other modules, then handle the exception
# one level up in a single place and return a unified error response to the frontend
raise ValueError('....')
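
One common way to implement the unified handling described in method 2 is a Flask error handler that converts exceptions raised anywhere in a request into a JSON response. A minimal standalone sketch, not the exact api_result format used in the service below (the response fields are illustrative):

from flask import Flask, jsonify

app = Flask(__name__)

@app.errorhandler(ValueError)
def handle_value_error(err):
    # Any ValueError raised in a view function (or a helper it calls) lands here
    return jsonify({"state_code": 501, "feed_msg": str(err)}), 200

@app.errorhandler(Exception)
def handle_unexpected(err):
    # Catch-all for anything not handled by a more specific handler
    return jsonify({"state_code": 500, "feed_msg": str(err)}), 200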

3) Single/multi-process + coroutines (gevent)

# main_app.py
# -*- coding:utf-8 -*-
import json
import os
import copy
import time
import torch

from gevent import pywsgi, monkey
# Patch blocking standard-library calls so gevent greenlets switch cooperatively (non-blocking I/O)
monkey.patch_all()

from flask import Flask, request, jsonify
from flask_cors import CORS
from multiprocessing import cpu_count, Process
from jsonschema import validate, ValidationError

from detect import inference_main
from train import train_main
from utils.job_manager import kill_process_by_port, kill_process_by_name
from utils.logger import get_logger

# Logging
log_file = './logs/yolov5.log'
logger = get_logger(name='yolov5', log_file=log_file)

# Flask service
app = Flask(__name__)
CORS(app, resources=r'/*')
app.config['JSON_AS_ASCII'] = False


# HTTP request parameter validation
# Validation schema for the http://ip:port/train endpoint
schema_train = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {
            "type": "integer",
        },
        "event_type": {
            "type": "string",
        },
        "payload": {
            "type": "object",
            "required": ["data_config", 'basic_hyp'],
            "properties": {
                "data_config": {"type": "object",
                                "required": ["train", "val", "nc", "names", "result_path"],
                                "properties": {"train": {"type": "string"},
                                               "val": {"type": "string"},
                                               "nc": {"type": "integer", "minimum": 1},
                                               "names": {"type": "array"},
                                               "result_path": {"type": "string"}
                                               }},
                "basic_hyp": {"type": "object",
                              "required": ["epochs", "batch-size", "workers", "img-size", "device"],
                              "properties": {"epochs": {"type": "integer", "minimum": 1},
                                             "batch-size": {"type": "integer", "minimum": 1},
                                             "workers": {"type": "integer", "minimum": 0},
                                             "img-size": {"type": "array"},
                                             "device": {"type": "string"}
                                             }},
            }
        }
    }
}

# Validation schema for the http://ip:port/inference endpoint
schema_inference = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {
            "type": "integer",
        },
        "event_type": {
            "type": "string",
        },
        "payload": {
            "type": "object",
            "required": ["data_config", "basic_hyp"],
            "properties": {
                "data_config": {"type": "object",
                                "required": ["test", "result_path"],
                                "properties": {"test": {"type": "string"},
                                               "result_path": {"type": "string"}
                                               }
                                },
                "basic_hyp": {"type": "object",
                              "required": ["weights", "device"],
                              "properties": {"weights": {"type": "string"},
                                             "device": {"type": "string"}
                                             }
                              }
            }
        }
    }
}


# Request data validation decorator; a different schema can be passed per endpoint
def json_validate(schema):
    def wrapper(func):
        def inner(data, *args, **kwargs):
            try:
                validate(data, schema)
            except ValidationError as e:
                logger.error("Request parameter validation failed: {}!".format(e.message))
                return {'error': True, 'msg': e.message}
            else:
                logger.info("Request parameter validation passed.")
                return func(data, *args, **kwargs)
        return inner
    return wrapper


def api_result(event_id, state_code, msg_type, msg, result):
    """
    Build the unified API response and return it as JSON.
    """
    api_res = {
        "event_id": event_id,
        "state_code": state_code,
        "feed_type": msg_type,
        "feed_msg": msg,
        "feed_data": result,
    }
    logger.info("feed_msg: {}".format(api_res))
    logger.info("=====================================================\n")
    return jsonify(api_res)


@app.route('/train', methods=['POST'])
def train_post():
    """
    Model training endpoint.
    """
    if request.method == "POST":
        msg_dict = {}
        try:
            # Parse request parameters
            @json_validate(schema=schema_train)
            def api_parameters(msg_dict_copy):
                logger.info("Starting model training...")
                return msg_dict_copy
            request_data = request.get_data().decode()
            msg_dict = json.loads(request_data)
            # msg_dict_copy = copy.deepcopy(msg_dict)
            # msg_dict_copy = msg_dict
            logger.info("request msg: {}".format(msg_dict))
            # Validate parameters
            msg_dict['error'] = False
            validate_msg = api_parameters(msg_dict)
            # Run training
            if not validate_msg['error']:
                result_path = train_main(msg_dict)
                result = {"result_path": result_path}
                return api_result(msg_dict['event_id'], 200, 'train', 'success', result)
            else:
                return api_result(msg_dict.get('event_id'), 501, 'train', 'Invalid parameters, please check. Error: {}'.format(validate_msg['msg']), None)

        except Exception as e:
            logger.error(e)
            return api_result(msg_dict.get('event_id'), 500, "train", str(e), None)

    else:
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "train", feed_msg, None)


@app.route('/inference', methods=['POST'])
def inference_post():
    """
    Model inference endpoint.
    """
    if request.method == "POST":
        msg_dict = {}
        try:
            # Parse request parameters
            @json_validate(schema=schema_inference)
            def api_parameters(msg_dict_copy):
                logger.info("Starting model inference...")
                return msg_dict_copy

            request_data = request.get_data().decode()
            msg_dict = json.loads(request_data)
            # msg_dict_copy = copy.deepcopy(msg_dict)
            # msg_dict_copy = msg_dict
            logger.info("request msg: {}".format(msg_dict))
            # Validate parameters
            msg_dict['error'] = False
            validate_msg = api_parameters(msg_dict)
            # Run inference
            if not validate_msg['error']:
                result_path = inference_main(msg_dict)
                result = {"result_path": result_path}
                return api_result(msg_dict['event_id'], 200, 'inference', 'success', result)
            else:
                return api_result(msg_dict.get('event_id'), 501, 'inference', 'Invalid parameters, please check. Error: {}'.format(validate_msg['msg']), None)

        except Exception as e:
            logger.error(e)
            return api_result(msg_dict.get('event_id'), 500, "inference", str(e), None)
    else:
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "inference", feed_msg, None)


def start_app(MULTI_PROCESS=False, USE_CORES=1):
    """
    启动服务
    Returns:
    """
    # Release cached GPU memory before starting
    torch.cuda.empty_cache()
    try:
        logger.info("\n===============================================================================")
        logger.info("deeplearn server starting...")
        # Long-running server
        if not MULTI_PROCESS:
            server = pywsgi.WSGIServer(("0.0.0.0", 8080), app)
            server.start()
            logger.info("deeplearn server start success.")
            print('single process + coroutines')
            # serve_forever() blocks here until the server is stopped
            server.serve_forever()
            return
        else:
            mulserver = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
            mulserver.start()

            def server_forever():
                mulserver.start_accepting()
                mulserver._stop_event.wait()

            all_cpu_cores = cpu_count()
            if USE_CORES > all_cpu_cores:
                use_cores = all_cpu_cores
            else:
                use_cores = USE_CORES
            for i in range(use_cores):
                p = Process(target=server_forever)
                p.start()
            print('multi-process + coroutines, number of processes: {}+1'.format(use_cores))
            return

    except Exception as err:
        logger.error("exception in server: {}".format(err))
        logger.error("the service port is already in use; please shut the existing service down first.")
        try:
            logger.error("{}".format(kill_process_by_port(8080)))
        except Exception as err:
            logger.error("exception in server: {}".format(err))



def stop_app():
    """
    结束服务
    Returns:
    """
    logger.info("\n===============================================================================")
    logger.warning("deeplearn server stopping...")
    try:
        logger.info("stop info: {}".format(kill_process_by_port(8080)))
    except Exception as err:
        logger.error("stop err: {}, fallback: {}".format(err, kill_process_by_name("python.exe")))


if __name__ == "__main__":
    # app.run(port=8080, host="0.0.0.0", )
    MULTI_PROCESS = True
    # By default start 2+1 processes (USE_CORES child processes plus the parent)
    USE_CORES = int(os.getenv('USE_CORES')) if os.getenv('USE_CORES') else 2
    start_app(MULTI_PROCESS=MULTI_PROCESS, USE_CORES=USE_CORES)
    # stop_app()
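
Once the service is running, the endpoints can be exercised with any HTTP client. A sketch using requests (the paths and weights file are placeholders that must match your own data layout):

import requests

payload = {
    "event_id": 1,
    "event_type": "inference",
    "payload": {
        "data_config": {"test": "/data/images", "result_path": "/data/results"},
        "basic_hyp": {"weights": "weights/best.pt", "device": "0"},
    },
}
resp = requests.post("http://127.0.0.1:8080/inference", json=payload)
print(resp.json())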

2. Common process-management functions

# job_manager.py
# -*- coding:utf-8 -*-

import os
import psutil


def get_all_process():
    """Return a {pid: process_name} dict for all running processes."""
    pid_dict = {}
    for pid in psutil.pids():
        try:
            pid_dict[pid] = psutil.Process(pid).name()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            # The process may have exited in the meantime or be inaccessible
            continue
    return pid_dict

def find_pid_by_name(name: str):
    """
    Find all pids whose process name matches the given name.
    Args:
        name: process name

    Returns: list of matching pids

    """
    pros = psutil.process_iter()
    print("[" + name + "]'s pid is:")
    pids = []
    for pro in pros:
        if pro.name() == name:
            print(pro.pid)
            pids.append(pro.pid)
    return pids

def find_port_by_pid(pid: int):
    """Find the ports used by the process with the given pid."""
    alist = []
    # Current network connection info
    net_con = psutil.net_connections()
    for con_info in net_con:
        if con_info.pid == pid and con_info.laddr:
            alist.append({pid: con_info.laddr.port})
    return alist


def find_pid_by_port(port: int):
    """Find the pid(s) of the process(es) listening on the given port."""
    pid_list = []
    # Current network connection info
    net_con = psutil.net_connections()
    for con_info in net_con:
        if con_info.laddr and con_info.laddr.port == port:
            pid_list.append(con_info.pid)
    return pid_list


def kill_process_by_pid(pid):
    # On Windows use: cmd = 'taskkill /pid ' + pid + ' /f'
    cmd = 'kill -9 ' + pid
    try:
        os.system(cmd)
    except Exception as e:
        print(e)


def kill_process_by_name(set_name):
    all_pid = get_all_process()
    for pid, name in all_pid.items():
        if name == set_name:
            kill_process_by_pid(str(pid))
    msg_str = "kill process in name: {}".format(set_name)
    return msg_str


def kill_process_by_port(port):
    pids = find_pid_by_port(port)
    for pid in pids:
        kill_process_by_pid(str(pid))

    msg_str = "kill process in port: {}".format(port)
    return msg_str


def clean_cmd():
    kill_process_by_name("cmd.exe")
    kill_process_by_name("bash.exe")


if __name__ == "__main__":
    # kill_process_by_port(8010)
    # kill_process_by_name("python.exe")
    # kill_process_by_name("cmd.exe")
    # kill_process_by_name("bash.exe")
    # kill_process_by_name("myProcess")
    # print(find_pid_by_port('8080'))
    print(find_pid_by_name('myProcess'))
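
Since psutil is already a dependency, killing by pid can also be done without shelling out, which behaves the same on Windows and Linux. A minimal alternative sketch (the function name is illustrative):

import psutil

def kill_process_by_pid_psutil(pid: int):
    """Terminate a process via psutil instead of os.system (cross-platform)."""
    try:
        psutil.Process(pid).kill()  # SIGKILL on POSIX, TerminateProcess on Windows
    except psutil.NoSuchProcess:
        pass  # already gone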

3. Logging module

# logger.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import functools

logger_initialized = {}


@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d] %(name)s %(levelname)s: %(message)s', datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None:
        log_file_folder = os.path.split(log_file)[0]
        os.makedirs(log_file_folder, exist_ok=True)
        # Rotate at 10 MB per file and keep up to 15 backups
        file_handler = RotatingFileHandler(filename=log_file, maxBytes=10 * 1024 * 1024, backupCount=15, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    # Set the level regardless of whether a file handler was added
    logger.setLevel(log_level)

    logger_initialized[name] = True
    return logger

if __name__ == "__main__":
    # Logging
    log_file = './logs/yolov5.log'
    logger = get_logger(name='yolov5', log_file=log_file)
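
Because get_logger caches the logger by name, other modules in the project (e.g. detect.py or train.py) can obtain the same configured logger simply by requesting the same name, provided the main program has already initialized it with a log file. A minimal sketch:

# e.g. inside detect.py (illustrative)
from utils.logger import get_logger

logger = get_logger(name='yolov5')  # returns the already-initialized 'yolov5' logger
logger.info("inference started")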

4. Building the image with a Dockerfile

docker build -t wood_detect:test .

# Dockerfile

# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
#FROM nvcr.io/nvidia/pytorch:21.05-py3
#FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
FROM deploy.hello.com/2020-public/yolov5_base:1.0.1

# Install linux packages
#RUN apt update && apt install -y zip htop screen libgl1-mesa-glx

## Create working directory
#RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app

# Copy contents
COPY . /usr/src/app
EXPOSE 8080
# Install python dependencies
#COPY requirements.txt .
# RUN python -m pip install --upgrade pip
#RUN pip uninstall -y nvidia-tensorboard nvidia-tensorboard-plugin-dlprof
#RUN pip install --no-cache -r requirements.txt coremltools onnx gsutil -i https://pypi.douban.com/simple/
RUN pip install --no-cache -r requirements.txt -i https://pypi.douban.com/simple/

# RUN pip install --no-cache -U torch torchvision

## Set environment variables
#ENV HOME=/usr/src/app
#
ENTRYPOINT ["python","main_app.py"]
# ---------------------------------------------------  Extras Below  ---------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest && sudo docker build -t $t . && sudo docker push $t
# for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t . && sudo docker push $t; done

# Pull and Run
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t

# Pull and Run with local directory access
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t

# Kill all
# sudo docker kill $(sudo docker ps -q)

# Kill all image-based
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)

# Bash into running container
# sudo docker exec -it 5a9b5863d93d bash

# Bash into stopped container
# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash

# Send weights to GCP
# python -c "from utils.general import *; strip_optimizer('runs/train/exp0_*/weights/best.pt', 'tmp.pt')" && gsutil cp tmp.pt gs://*.pt

# Clean up
# docker system prune -a --volumes
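
After building the image with the command above, a container can be started along these lines (the container name is illustrative; port 8080 matches the EXPOSE line, and --gpus all assumes the NVIDIA Container Toolkit is installed on the host):

docker run -d --gpus all -p 8080:8080 --name wood_detect wood_detect:test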

5. Building and deploying with docker-compose.yml

docker-compose up
# Docker-compose.yml
# GPU configuration, see https://docs.docker.com/compose/gpu-support/

version: "3.8"

services:
  yolov5:
    build:
      context: .
    image: deploy.deepexi.com/2048-public/yolov5_server:alpha_v1.0
    restart: always
    container_name: yolov5_server
    ports:
      - 8080:8080
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: [ '0',]
              capabilities: [ gpu ]
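
With the NVIDIA Container Toolkit installed on the host, docker-compose up builds the image if needed and starts the container with GPU '0' reserved. A quick way to check that the GPU is actually visible inside the running container (container_name as defined above, assuming python is on the container's PATH):

docker exec yolov5_server python -c "import torch; print(torch.cuda.is_available())"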

6. Rewriting a YAML config file

# Override values in the --hyp file (data/hyp.scratch.yaml); here yolo_hype is assumed
# to be the user-supplied hyper-parameter dict (e.g. basic_hyp from the request payload)
import yaml

with open(opt.hyp) as f:
    hyp = yaml.safe_load(f)
if 'lr0' in yolo_hype and isinstance(yolo_hype['lr0'], float) and yolo_hype['lr0'] >= 0.0:
    hyp['lr0'] = yolo_hype['lr0']
with open(opt.hyp, mode='w') as f:
    yaml.safe_dump(hyp, f)
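
The same pattern can be generalized into a small helper that merges any user-supplied overrides into a YAML file before training starts. A minimal sketch under that assumption (function and file names are illustrative):

import yaml

def override_yaml(yaml_path, overrides):
    """Load a YAML file, apply type-checked overrides, and write it back."""
    with open(yaml_path) as f:
        cfg = yaml.safe_load(f) or {}
    for key, value in overrides.items():
        # Only accept overrides whose type matches the existing value (or brand-new keys)
        if key not in cfg or isinstance(value, type(cfg[key])):
            cfg[key] = value
    with open(yaml_path, mode='w') as f:
        yaml.safe_dump(cfg, f)
    return cfg

# Example: override the initial learning rate before calling train_main
# override_yaml('data/hyp.scratch.yaml', {'lr0': 0.001})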