平台相关常用接口、函数

1、接口主程序

1)接口参数数据校验:使用 jsonschema 对请求体进行校验。

2)异常处理。

复制代码
# 方法一:
try:
    pass
except Exception as err:
    pass
复制代码
# 方法二:可将其写在其它文件的子函数里,异常向上抛出后由接口层进一步统一处理,并返回给前端
raise ValueError('....')

3)单/多进程 + 协程

python 复制代码
# -*- coding:utf-8 -*-
import json
import os
import copy
import time
import torch

from gevent import pywsgi, monkey
# 多线程,非阻塞
monkey.patch_all()

from flask import Flask, request, jsonify
from flask_cors import CORS
from multiprocessing import cpu_count, Process
from jsonschema import validate, ValidationError

from detect import inference_main
from train import train_main
from utils.job_manager import kill_process_by_port, kill_process_by_name
from utils.logger import get_logger

# Logging: a shared logger that writes to ./logs/yolov5.log.
log_file = './logs/yolov5.log'
logger = get_logger(name='yolov5', log_file=log_file)

# Flask service.
app = Flask(__name__)
# Enable CORS for every route of the app.
CORS(app, resources=r'/*')
# Presumably set so that non-ASCII (e.g. Chinese) text in JSON responses is
# not escaped. NOTE(review): confirm this config key is still honoured by the
# installed Flask version.
app.config['JSON_AS_ASCII'] = False


# HTTP request-body validation.
# JSON Schema for the body accepted by http://ip:port/train, assembled from
# the two payload sub-schemas below.
_TRAIN_DATA_CONFIG_SCHEMA = {
    "type": "object",
    "required": ["train", "val", "nc", "names", "result_path"],
    "properties": {
        "train": {"type": "string"},
        "val": {"type": "string"},
        "nc": {"type": "integer", "minimum": 1},
        "names": {"type": "array"},
        "result_path": {"type": "string"},
    },
}

_TRAIN_BASIC_HYP_SCHEMA = {
    "type": "object",
    "required": ["epochs", "batch-size", "workers", "img-size", "device"],
    "properties": {
        "epochs": {"type": "integer", "minimum": 1},
        "batch-size": {"type": "integer", "minimum": 1},
        "workers": {"type": "integer", "minimum": 0},
        "img-size": {"type": "array"},
        "device": {"type": "string"},
    },
}

schema_train = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {"type": "integer"},
        "event_type": {"type": "string"},
        "payload": {
            "type": "object",
            "required": ["data_config", "basic_hyp"],
            "properties": {
                "data_config": _TRAIN_DATA_CONFIG_SCHEMA,
                "basic_hyp": _TRAIN_BASIC_HYP_SCHEMA,
            },
        },
    },
}

# JSON Schema for the body accepted by http://ip:port/inference, assembled
# from the two payload sub-schemas below.
_INFERENCE_DATA_CONFIG_SCHEMA = {
    "type": "object",
    "required": ["test", "result_path"],
    "properties": {
        "test": {"type": "string"},
        "result_path": {"type": "string"},
    },
}

_INFERENCE_BASIC_HYP_SCHEMA = {
    "type": "object",
    "required": ["weights", "device"],
    "properties": {
        "weights": {"type": "string"},
        "device": {"type": "string"},
    },
}

schema_inference = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {"type": "integer"},
        "event_type": {"type": "string"},
        "payload": {
            "type": "object",
            "required": ["data_config", "basic_hyp"],
            "properties": {
                "data_config": _INFERENCE_DATA_CONFIG_SCHEMA,
                "basic_hyp": _INFERENCE_BASIC_HYP_SCHEMA,
            },
        },
    },
}


# data参数校验装饰器,可指定不同的校验schema
def json_validate(schema):
    """Build a decorator that validates a function's first argument.

    Args:
        schema: JSON Schema (dict) passed to ``jsonschema.validate``.

    Returns:
        A decorator. The decorated function is called as
        ``func(data, *args, **kwargs)`` only when ``data`` satisfies
        ``schema``; otherwise the call returns
        ``{'error': True, 'msg': <validation message>}`` without invoking
        ``func``.
    """
    # Local import keeps this block usable without touching the file's imports.
    import functools

    def wrapper(func):
        # Preserve the wrapped function's name/docstring for logs and debugging.
        @functools.wraps(func)
        def inner(data, *args, **kwargs):
            try:
                validate(data, schema)
            except ValidationError as e:
                logger.error("接口参数校验失败:{}!".format(e.message))
                return {'error': True, 'msg': e.message}
            # Called outside the try so errors raised by func itself are
            # never mistaken for a schema failure.
            logger.info("接口参数校验通过!")
            return func(data, *args, **kwargs)
        return inner
    return wrapper


def api_result(event_id, state_code, msg_type, msg, result):
    """Assemble, log and return the uniform JSON reply of the HTTP API.

    Args:
        event_id: Identifier echoed back under ``event_id``.
        state_code: Status value stored under ``state_code``.
        msg_type: Value stored under ``feed_type``.
        msg: Value stored under ``feed_msg``.
        result: Value stored under ``feed_data``.

    Returns:
        The response object produced by ``jsonify`` for the assembled dict.
    """
    response_body = dict(
        event_id=event_id,
        state_code=state_code,
        feed_type=msg_type,
        feed_msg=msg,
        feed_data=result,
    )
    logger.info("feed_msg: {}".format(response_body))
    # Visual separator between requests in the log.
    logger.info("=====================================================\n")
    return jsonify(response_body)


@app.route('/train', methods=['POST'])
def train_post():
    """Handle POST /train: validate the JSON body and run model training.

    Returns:
        The ``api_result`` response. ``state_code`` is 200 on success, 501
        when the body fails schema validation, 500 when parsing or training
        raises, and 400 for a non-POST request.
    """
    if request.method != "POST":
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "train", feed_msg, None)

    # Resolved as soon as the body is parsed. Stays None when the body is not
    # valid JSON, is not an object, or lacks "event_id", so the error paths
    # below can always build a reply instead of raising a second exception.
    event_id = None
    try:
        # Validation wrapper: returns the body unchanged when it is valid.
        @json_validate(schema=schema_train)
        def api_parameters(msg_dict_copy):
            logger.info("启动模型训练.......")
            return msg_dict_copy

        request_data = request.get_data().decode()
        msg_dict = json.loads(request_data)
        logger.info("request msg: {}".format(msg_dict))

        if isinstance(msg_dict, dict):
            event_id = msg_dict.get('event_id')
            # Marker read back below; a failed validation returns a dict
            # whose 'error' is True instead.
            msg_dict['error'] = False
        validate_msg = api_parameters(msg_dict)

        if validate_msg['error']:
            return api_result(event_id, 501, 'train', '参数设置有误,请核查,错误信息:{}'.format(validate_msg['msg']), None)

        # Training
        result_path = train_main(msg_dict)
        result = {"result_path": result_path}
        return api_result(event_id, 200, 'train', 'success', result)

    except Exception as e:
        logger.error(e)
        return api_result(event_id, 500, "train", str(e), None)


@app.route('/inference', methods=['POST'])
def inference_post():
    """Handle POST /inference: validate the JSON body and run inference.

    Returns:
        The ``api_result`` response. ``state_code`` is 200 on success, 501
        when the body fails schema validation, 500 when parsing or inference
        raises, and 400 for a non-POST request.
    """
    if request.method != "POST":
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "inference", feed_msg, None)

    # Resolved as soon as the body is parsed. Stays None when the body is not
    # valid JSON, is not an object, or lacks "event_id", so the error paths
    # below can always build a reply instead of raising a second exception.
    event_id = None
    try:
        # Validation wrapper: returns the body unchanged when it is valid.
        @json_validate(schema=schema_inference)
        def api_parameters(msg_dict_copy):
            logger.info("启动模型推理.......")
            return msg_dict_copy

        request_data = request.get_data().decode()
        msg_dict = json.loads(request_data)
        logger.info("request msg: {}".format(msg_dict))

        if isinstance(msg_dict, dict):
            event_id = msg_dict.get('event_id')
            # Marker read back below; a failed validation returns a dict
            # whose 'error' is True instead.
            msg_dict['error'] = False
        validate_msg = api_parameters(msg_dict)

        if validate_msg['error']:
            return api_result(event_id, 501, 'inference', '参数设置有误,请核查,错误信息:{}'.format(validate_msg['msg']), None)

        # Inference
        result_path = inference_main(msg_dict)
        result = {"result_path": result_path}
        return api_result(event_id, 200, 'inference', 'success', result)

    except Exception as e:
        logger.error(e)
        # Report the failure under this endpoint's own feed type.
        return api_result(event_id, 500, "inference", str(e), None)


def start_app(MULTI_PROCESS=False, USE_CORES=1):
    """Start the gevent WSGI server on 0.0.0.0:8080.

    Args:
        MULTI_PROCESS: When falsy, serve from this process only (blocks until
            the server stops). When truthy, bind the socket here and start
            worker processes that accept on it.
        USE_CORES: Requested number of worker processes in multi-process
            mode; clamped to the range ``1 .. cpu_count()``.

    Returns:
        None. Any exception during start-up is logged, after which the
        process occupying port 8080 is killed via ``kill_process_by_port``.
    """
    # Release cached GPU memory before serving.
    torch.cuda.empty_cache()
    try:
        logger.info("\n===============================================================================")
        logger.info("deeplearn server starting...")
        if not MULTI_PROCESS:
            server = pywsgi.WSGIServer(("0.0.0.0", 8080), app)
            # Bind first so "start success" is only reported once the port is
            # really held; serve_forever() then blocks until shutdown.
            server.start()
            logger.info("deeplearn server start success.")
            print('单进程 + 协程')
            server.serve_forever()
            return

        mulserver = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
        mulserver.start()

        def server_forever():
            # Runs in each worker process: accept on the socket bound above
            # and block until the server's stop event is set.
            mulserver.start_accepting()
            mulserver._stop_event.wait()

        # At least one worker (otherwise nothing would accept requests) and
        # never more than the machine has cores.
        use_cores = max(1, min(USE_CORES, cpu_count()))
        for _ in range(use_cores):
            Process(target=server_forever).start()
        print('多进程 + 协程,进程数:{}+1'.format(use_cores))
        return

    except Exception as err:
        logger.error("exception in server: {}".format(err))
        logger.error("a same service port has been started. please shut down before operation.")
        try:
            logger.error("{}".format(kill_process_by_port(8080)))
        except Exception as kill_err:
            logger.error("exception in server: {}".format(kill_err))


def stop_app():
    """Stop the service by killing whatever holds port 8080.

    If killing by port raises, the error is logged and every process named
    ``python.exe`` is killed as a fallback.

    Returns:
        None.
    """
    logger.info("\n===============================================================================")
    logger.warning("deeplearn server stopping...")
    try:
        logger.info("stop info: {}".format(kill_process_by_port(8080)))
    except Exception as err:
        logger.error("stop err: {}".format(err))
        # Explicit fallback. Previously this call was hidden as an unused
        # extra argument of format(), so its result never reached the log.
        logger.info("stop info: {}".format(kill_process_by_name("python.exe")))


if __name__ == "__main__":
    # Alternative for local debugging: app.run(port=8080, host="0.0.0.0")
    MULTI_PROCESS = True
    # Worker count comes from the USE_CORES environment variable; when it is
    # unset or empty, 2 workers are started (2 + 1 processes).
    USE_CORES = int(os.getenv('USE_CORES') or 2)
    start_app(MULTI_PROCESS=MULTI_PROCESS, USE_CORES=USE_CORES)
    # To shut the service down instead: stop_app()

2、进程相关常用函数

python 复制代码
# job_manager.py
# -*- coding:utf-8 -*-

import os
import psutil


def get_all_process():
    """Return a mapping of pid -> process name for the running processes.

    Processes that disappear or cannot be inspected while scanning are
    skipped individually, so one failure no longer truncates the result.

    Returns:
        dict: ``{pid: name}``.
    """
    pid_dict = {}
    for pid in psutil.pids():
        try:
            pid_dict[pid] = psutil.Process(pid).name()
        except psutil.Error:
            # The process exited or is inaccessible: skip just this one.
            continue
    return pid_dict

def find_pid_by_name(name: str):
    """Find the pids of all processes whose name equals ``name``.

    Each matching pid is also printed to stdout.

    Args:
        name: Exact process name to look for.

    Returns:
        list: Matching pids (empty when nothing matches).
    """
    print("[" + name + "]'s pid is:")
    pids = []
    for pro in psutil.process_iter():
        try:
            matched = pro.name() == name
        except psutil.Error:
            # The process exited or is inaccessible: it cannot be a match.
            continue
        if matched:
            print(pro.pid)
            pids.append(pro.pid)
    return pids

def find_port_by_pid(pid: int):
    """Return the local ports of the network connections owned by ``pid``.

    Args:
        pid: Process id to look up.

    Returns:
        list: One ``{pid: port}`` dict per matching connection.
    """
    # Scan the current network connections and keep those owned by pid.
    return [
        {pid: conn.laddr.port}
        for conn in psutil.net_connections()
        if conn.pid == pid
    ]


def find_pid_by_port(port: int):
    """Return the pids of the processes with a connection on local ``port``.

    Args:
        port: Local port number; a numeric string is accepted as well.

    Returns:
        list: Unique pids in first-seen order. Connections whose owning pid
        is not available (``None``) are ignored.
    """
    port = int(port)
    pid_list = []
    # Scan the current network connections.
    for con_info in psutil.net_connections():
        if con_info.laddr.port != port:
            continue
        pid = con_info.pid
        # One process usually owns several sockets on the same port (listener
        # plus established connections); report it once, and never report an
        # unknown owner.
        if pid is not None and pid not in pid_list:
            pid_list.append(pid)
    return pid_list


def kill_process_by_pid(pid):
    """Forcefully kill the process with the given pid.

    Uses ``psutil.Process.kill`` instead of a shell command built by string
    concatenation: no shell is involved, it is not limited to POSIX ``kill``,
    and failures are actually reported.

    Args:
        pid: Process id as an int or a numeric string.

    Returns:
        None. Failures (bad pid value, no such process, access denied) are
        printed rather than raised.
    """
    try:
        psutil.Process(int(pid)).kill()
    except (psutil.Error, TypeError, ValueError) as e:
        print(e)


def kill_process_by_name(set_name):
    """Kill every process whose name equals ``set_name``.

    Args:
        set_name: Exact process name to kill.

    Returns:
        str: A message naming the process name that was targeted.
    """
    matching_pids = [
        pid
        for pid, proc_name in get_all_process().items()
        if proc_name == set_name
    ]
    for pid in matching_pids:
        kill_process_by_pid(str(pid))
    return "kill process in name: {}".format(set_name)


def kill_process_by_port(port):
    """Kill every process found by ``find_pid_by_port`` for ``port``.

    Args:
        port: Local port number.

    Returns:
        str: A message naming the port that was targeted.
    """
    for owner_pid in find_pid_by_port(port):
        kill_process_by_pid(str(owner_pid))
    return "kill process in port: {}".format(port)


def clean_cmd():
    """Kill all processes named ``cmd.exe`` and then all named ``bash.exe``."""
    for shell_name in ("cmd.exe", "bash.exe"):
        kill_process_by_name(shell_name)


if __name__ == "__main__":
    # Manual check: look up the pids of a process by name and print them.
    # Other helpers can be tried the same way, e.g.
    #   kill_process_by_port(8010)
    #   kill_process_by_name("python.exe")
    matched_pids = find_pid_by_name('myProcess')
    print(matched_pids)

3、日志模块

python 复制代码
# logger.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import functools

# Names of the loggers that already had their handlers attached.
logger_initialized = {}


@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    On the first call for a given ``name`` a stdout ``StreamHandler`` is
    attached and, if ``log_file`` is given, a ``RotatingFileHandler``
    (10 MiB per file, 15 backups, UTF-8) as well; the logger level is then
    set to ``log_level``. Later calls for the same name return the existing
    logger unchanged, so handlers are never attached twice.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a rotating
            file handler is added and the parent directory is created when
            the path has one.
        log_level (int): Level applied to the logger when it is initialized.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d] %(name)s %(levelname)s: %(message)s', datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None:
        log_file_folder = os.path.split(log_file)[0]
        # A bare filename has no directory part; os.makedirs('') would raise.
        if log_file_folder:
            os.makedirs(log_file_folder, exist_ok=True)
        file_handler = RotatingFileHandler(filename=log_file, maxBytes=10 * 1024 * 1024, backupCount=15, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # Apply the level whether or not a file handler was requested; before,
    # a stream-only logger silently kept the default level.
    logger.setLevel(log_level)

    logger_initialized[name] = True
    return logger

if __name__ == "__main__":
    # Demo: create the project logger backed by a file under ./logs.
    log_file = './logs/yolov5.log'
    logger = get_logger(log_file=log_file, name='yolov5')

4、Dockerfile 构建镜像

docker build -t wood_detect:test .

python 复制代码
# Dockerfile

# Base image. Alternatives that were considered:
#   nvcr.io/nvidia/pytorch:21.05-py3
#     (https://ngc.nvidia.com/catalog/containers/nvidia:pytorch)
#   pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
FROM deploy.hello.com/2020-public/yolov5_base:1.0.1

# Optional linux packages (not installed):
#RUN apt update && apt install -y zip htop screen libgl1-mesa-glx

# Working directory; WORKDIR creates it when missing, so no mkdir is needed.
WORKDIR /usr/src/app

# Install python dependencies BEFORE copying the sources: this layer is then
# rebuilt only when requirements.txt changes, not on every code edit.
COPY requirements.txt .
# Optional steps (not run):
# RUN python -m pip install --upgrade pip
#RUN pip uninstall -y nvidia-tensorboard nvidia-tensorboard-plugin-dlprof
#RUN pip install --no-cache -r requirements.txt coremltools onnx gsutil -i https://pypi.douban.com/simple/
# RUN pip install --no-cache -U torch torchvision
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.douban.com/simple/

# Copy the application sources
COPY . /usr/src/app

# Port the service listens on
EXPOSE 8080

## Environment variables (not set):
#ENV HOME=/usr/src/app

ENTRYPOINT ["python","main_app.py"]
# ---------------------------------------------------  Extras Below  ---------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest && sudo docker build -t $t . && sudo docker push $t
# for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t . && sudo docker push $t; done

# Pull and Run
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t

# Pull and Run with local directory access
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t

# Kill all
# sudo docker kill $(sudo docker ps -q)

# Kill all image-based
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)

# Bash into running container
# sudo docker exec -it 5a9b5863d93d bash

# Bash into stopped container
# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash

# Send weights to GCP
# python -c "from utils.general import *; strip_optimizer('runs/train/exp0_*/weights/best.pt', 'tmp.pt')" && gsutil cp tmp.pt gs://*.pt

# Clean up
# docker system prune -a --volumes

5、使用 docker-compose.yml 构建镜像并部署

复制代码
docker-compose up
python 复制代码
# Docker-compose.yml
# GPU配置,参考https://docs.docker.com/compose/gpu-support/

version: "3.8"

services:
  yolov5:
    # Build the image from the Dockerfile in the current directory and tag it
    # with the name given under `image`.
    build:
      context: .
    image: deploy.deepexi.com/2048-public/yolov5_server:alpha_v1.0
    restart: always
    container_name: yolov5_server
    ports:
      # HOST:CONTAINER, quoted so YAML always reads the mapping as a string.
      - "8080:8080"
    deploy:
      resources:
        reservations:
          devices:
            # Reserve GPU 0 through the nvidia driver.
            - driver: nvidia
              device_ids: ["0"]
              capabilities: [gpu]

6、重写yaml文件

python 复制代码
# Override values in the --hyp file (e.g. data/hyp.scratch.yaml) and write it back.
# Read first and close the file, so it is never opened for writing while the
# read handle is still open.
with open(opt.hyp) as f:
    hyp = yaml.safe_load(f)

# Only a non-negative float learning rate from the request replaces the default.
if 'lr0' in yolo_hype and isinstance(yolo_hype['lr0'], float) and yolo_hype['lr0'] >= 0.0:
    hyp['lr0'] = yolo_hype['lr0']

# The context manager guarantees the rewritten file is flushed and closed.
with open(opt.hyp, mode='w') as f:
    yaml.safe_dump(hyp, f)
相关推荐
你的人类朋友6 分钟前
认识一下Bcrypt哈希算法
后端·安全·程序员
tangweiguo0305198720 分钟前
基于 Django 与 Bootstrap 构建的现代化设备管理平台
后端·django·bootstrap
闲人编程28 分钟前
Flask 前后端分离架构实现支付宝电脑网站支付功能
python·架构·flask·支付宝·前后端·网站支付·apl
IT果果日记29 分钟前
详解DataX开发达梦数据库插件
大数据·数据库·后端
996终结者36 分钟前
同类软件对比(四):Jupyter vs PyCharm vs VS Code:Python开发工具终极选择指南
vscode·python·jupyter·pycharm·visual studio code
dazhong201237 分钟前
Spring Boot 项目新增 Module 完整指南
java·spring boot·后端
果壳~41 分钟前
【Python】爬虫html提取内容基础,bs4
爬虫·python·html
bobz96544 分钟前
Cilium + Kubevirt 与 Kube-OVN + Kubevirt 在公有云场景下的对比与选择
后端
David爱编程2 小时前
深度解析:synchronized 性能演进史,从 JDK1.6 到 JDK17
java·后端
尝试经历体验2 小时前
pycharm突然不能正常运行
python·深度学习·pycharm