Common platform-related APIs and utility functions

1. API main program

1) API parameter validation with jsonschema; a minimal sketch follows.
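The core of the validation is a single jsonschema.validate call; a minimal sketch with a toy schema (the real schemas used by this service are defined in the code further below):

python
from jsonschema import validate, ValidationError

# toy_schema is illustrative only; see schema_train / schema_inference below for the real ones
toy_schema = {"type": "object", "required": ["event_id"],
              "properties": {"event_id": {"type": "integer"}}}

try:
    validate({"event_id": "not-an-int"}, toy_schema)
except ValidationError as e:
    print("validation failed:", e.message)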

2) Exception handling.

python
# Method 1: catch the exception locally
try:
    pass
except Exception as err:
    pass

# Method 2: raise in a sub-function (which may live in another module) and handle it
# further up the call chain, so a unified error response is returned to the frontend
# (see the sketch below)
raise ValueError('....')
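A minimal sketch of method 2, assuming a helper such as check_device defined in another module (the name is illustrative): the sub-function raises, and a single except block in the route handler converts the exception into a unified response for the frontend.

python
# Hypothetical sub-function that would live in another file
def check_device(device: str) -> str:
    if device not in ("cpu", "0"):
        raise ValueError("unsupported device: {}".format(device))
    return device

# In the route handler, one except block catches everything raised below it
try:
    check_device("gpu")
except Exception as err:
    response = {"state_code": 500, "feed_msg": str(err), "feed_data": None}
    print(response)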

3) Single/multi-process + coroutines

python
# -*- coding:utf-8 -*-
import json
import os
import copy
import time
import torch

from gevent import pywsgi, monkey
# gevent monkey patching: make blocking I/O cooperative (non-blocking coroutines)
monkey.patch_all()

from flask import Flask, request, jsonify
from flask_cors import CORS
from multiprocessing import cpu_count, Process
from jsonschema import validate, ValidationError

from detect import inference_main
from train import train_main
from utils.job_manager import kill_process_by_port, kill_process_by_name
from utils.logger import get_logger

# Logging
log_file = './logs/yolov5.log'
logger = get_logger(name='yolov5', log_file=log_file)

# Flask service
app = Flask(__name__)
CORS(app, resources=r'/*')
app.config['JSON_AS_ASCII'] = False


# HTTP request parameter validation
# Validation schema for the http://ip:port/train endpoint
schema_train = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {
            "type": "integer",
        },
        "event_type": {
            "type": "string",
        },
        "payload": {
            "type": "object",
            "required": ["data_config", 'basic_hyp'],
            "properties": {
                "data_config": {"type": "object",
                                "required": ["train", "val", "nc", "names", "result_path"],
                                "properties": {"train": {"type": "string"},
                                               "val": {"type": "string"},
                                               "nc": {"type": "integer", "minimum": 1},
                                               "names": {"type": "array"},
                                               "result_path": {"type": "string"}
                                               }},
                "basic_hyp": {"type": "object",
                              "required": ["epochs", "batch-size", "workers", "img-size", "device"],
                              "properties": {"epochs": {"type": "integer", "minimum": 1},
                                             "batch-size": {"type": "integer", "minimum": 1},
                                             "workers": {"type": "integer", "minimum": 0},
                                             "img-size": {"type": "array"},
                                             "device": {"type": "string"}
                                             }},
            }
        }
    }
}

# Validation schema for the http://ip:port/inference endpoint
schema_inference = {
    "type": "object",
    "required": ["event_id", "event_type", "payload"],
    "properties": {
        "event_id": {
            "type": "integer",
        },
        "event_type": {
            "type": "string",
        },
        "payload": {
            "type": "object",
            "required": ["data_config", "basic_hyp"],
            "properties": {
                "data_config": {"type": "object",
                                "required": ["test", "result_path"],
                                "properties": {"test": {"type": "string"},
                                               "result_path": {"type": "string"}
                                               }
                                },
                "basic_hyp": {"type": "object",
                              "required": ["weights", "device"],
                              "properties": {"weights": {"type": "string"},
                                             "device": {"type": "string"}
                                             }
                              }
            }
        }
    }
}


# Decorator that validates the request data against a given schema
def json_validate(schema):
    def wrapper(func):
        def inner(data, *args, **kwargs):
            try:
                validate(data, schema)
            except ValidationError as e:
                logger.error("接口参数校验失败:{}!".format(e.message))
                return {'error': True, 'msg': e.message}
            else:
                logger.info("接口参数校验通过!")
                return func(data, *args, **kwargs)
        return inner
    return wrapper


def api_result(event_id, state_code, msg_type, msg, result):
    """
    Build the unified API response.
    """
    api_res = {
        "event_id": event_id,
        "state_code": state_code,
        "feed_type": msg_type,
        "feed_msg": msg,
        "feed_data": result,
    }
    logger.info("feed_msg: {}".format(api_res))
    logger.info("=====================================================\n")
    return jsonify(api_res)


@app.route('/train', methods=['POST'])
def train_post():
    """
    Model training endpoint.
    """
    if request.method == "POST":
        msg_dict = {}
        try:
            # Parse request parameters
            @json_validate(schema=schema_train)
            def api_parameters(msg_dict_copy):
                logger.info("Starting model training...")
                return msg_dict_copy
            request_data = request.get_data().decode()
            msg_dict = json.loads(request_data)
            # msg_dict_copy = copy.deepcopy(msg_dict)
            # msg_dict_copy = msg_dict
            logger.info("request msg: {}".format(msg_dict))
            # Validate parameters
            msg_dict['error'] = False
            validate_msg = api_parameters(msg_dict)
            # Train
            if not validate_msg['error']:
                result_path = train_main(msg_dict)
                result = {"result_path": result_path}
                return api_result(msg_dict['event_id'], 200, 'train', 'success', result)
            else:
                return api_result(msg_dict['event_id'], 501, 'train', 'Invalid parameters, please check. Error: {}'.format(validate_msg['msg']), None)

        except Exception as e:
            logger.error(e)
            return api_result(msg_dict.get('event_id', "101010"), 500, "train", str(e), None)

    else:
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "train", feed_msg, None)


@app.route('/inference', methods=['POST'])
def inference_post():
    """
    Model inference endpoint.
    """
    if request.method == "POST":
        msg_dict = {}
        try:
            # Parse request parameters
            @json_validate(schema=schema_inference)
            def api_parameters(msg_dict_copy):
                logger.info("Starting model inference...")
                return msg_dict_copy

            request_data = request.get_data().decode()
            msg_dict = json.loads(request_data)
            # msg_dict_copy = copy.deepcopy(msg_dict)
            # msg_dict_copy = msg_dict
            logger.info("request msg: {}".format(msg_dict))
            # Validate parameters
            msg_dict['error'] = False
            validate_msg = api_parameters(msg_dict)
            # Inference
            if not validate_msg['error']:
                result_path = inference_main(msg_dict)
                result = {"result_path": result_path}
                return api_result(msg_dict['event_id'], 200, 'inference', 'success', result)
            else:
                return api_result(msg_dict['event_id'], 501, 'inference', 'Invalid parameters, please check. Error: {}'.format(validate_msg['msg']), None)

        except Exception as e:
            logger.error(e)
            return api_result(msg_dict.get('event_id', "101010"), 500, "inference", str(e), None)
    else:
        feed_msg = "error, request.method != POST"
        return api_result("101010", 400, "inference", feed_msg, None)


def start_app(MULTI_PROCESS=False, USE_CORES=1):
    """
    启动服务
    Returns:
    """
    # 先清空显存占用
    torch.cuda.empty_cache()
    try:
        logger.info("\n===============================================================================")
        logger.info("deeplearn server starting...")
        # 持久化服务
        if MULTI_PROCESS == False:
            server = pywsgi.WSGIServer(("0.0.0.0", 8080), app)
            server.serve_forever()
            logger.info("deeplearn server start success.")
            print('单进程 + 协程')
            return
        else:
            mulserver = pywsgi.WSGIServer(('0.0.0.0', 8080), app)
            mulserver.start()

            def server_forever():
                # each forked worker accepts connections on the shared listening socket
                mulserver.start_accepting()
                mulserver._stop_event.wait()

            all_cpu_cores = cpu_count()
            if USE_CORES > all_cpu_cores:
                use_cores = all_cpu_cores
            else:
                use_cores = USE_CORES
            for i in range(use_cores):
                p = Process(target=server_forever)
                p.start()
            print('multi-process + coroutines, number of processes: {} + 1'.format(use_cores))
            return

    except Exception as err:
        logger.error("exception in server: {}".format(err))
        logger.error("a same service port has been started. please shut down before operation.")
        try:
            logger.error("{}".format(kill_process_by_port(8080)))
        except Exception as err:
            logger.error("exception in server: {}".format(err))



def stop_app():
    """
    结束服务
    Returns:
    """
    logger.info("\n===============================================================================")
    logger.warning("deeplearn server stopping...")
    try:
        logger.info("stop info: {}".format(kill_process_by_port(8080)))
    except Exception as err:
        logger.error("stop err: {}, fallback: {}".format(err, kill_process_by_name("python.exe")))


if __name__ == "__main__":
    # app.run(port=8080, host="0.0.0.0", )
    MULTI_PROCESS = True
    # start 2 worker processes + 1 main process by default
    USE_CORES = int(os.getenv('USE_CORES')) if os.getenv('USE_CORES') else 2
    start_app(MULTI_PROCESS=MULTI_PROCESS, USE_CORES=USE_CORES)
    # stop_app()
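For reference, a hedged client-side sketch of calling the /train endpoint with the requests library; the field values are placeholders that merely satisfy schema_train, not a tested configuration.

python
import requests

payload = {
    "event_id": 1,
    "event_type": "train",
    "payload": {
        "data_config": {"train": "/data/train.txt", "val": "/data/val.txt",
                        "nc": 2, "names": ["cat", "dog"], "result_path": "/data/results"},
        "basic_hyp": {"epochs": 100, "batch-size": 16, "workers": 4,
                      "img-size": [640, 640], "device": "0"},
    },
}
resp = requests.post("http://127.0.0.1:8080/train", json=payload)
print(resp.json())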

2. Common process management functions

python
# job_manager.py
# -*- coding:utf-8 -*-

import os
import psutil


def get_all_process():
    pid_dict = {}
    pids = psutil.pids()
    for pid in pids:
        try:
            p = psutil.Process(pid)
            pid_dict[pid] = p.name()
        except psutil.Error:
            # process may have exited or access may be denied; skip it
            continue
    return pid_dict

def find_pid_by_name(name: str):
    """
    根据进程名获取进程pid
    Args:
        name: process name

    Returns: process pid

    """
    pros = psutil.process_iter()
    print("[" + name + "]'s pid is:")
    pids = []
    for pro in pros:
        if pro.name() == name:
            print(pro.pid)
            pids.append(pro.pid)
    return pids

def find_port_by_pid(pid: int):
    """根据pid寻找该进程对应的端口"""
    alist = []
    # 获取当前的网络连接信息
    net_con = psutil.net_connections()
    for con_info in net_con:
        if con_info.pid == pid:
            alist.append({pid: con_info.laddr.port})
    return alist


def find_pid_by_port(port: int):
    """根据端口寻找该进程对应的pid"""
    pid_list = []
    # 获取当前的网络连接信息
    net_con = psutil.net_connections()
    for con_info in net_con:
        if con_info.laddr.port == port:
            pid_list.append(con_info.pid)
    return pid_list


def kill_process_by_pid(pid):
    # on Windows use: cmd = 'taskkill /pid ' + str(pid) + ' /f'
    cmd = 'kill -9 ' + str(pid)
    try:
        os.system(cmd)
    except Exception as e:
        print(e)


def kill_process_by_name(set_name):
    all_pid = get_all_process()
    for pid, name in all_pid.items():
        if name == set_name:
            kill_process_by_pid(str(pid))
    msg_str = "kill process in name: {}".format(set_name)
    return msg_str


def kill_process_by_port(port):
    pids = find_pid_by_port(port)
    for pid in pids:
        kill_process_by_pid(str(pid))

    msg_str = "kill process in port: {}".format(port)
    return msg_str


def clean_cmd():
    kill_process_by_name("cmd.exe")
    kill_process_by_name("bash.exe")


if __name__ == "__main__":
    # kill_process_by_port(8010)
    # kill_process_by_name("python.exe")
    # kill_process_by_name("cmd.exe")
    # kill_process_by_name("bash.exe")
    # kill_process_by_name("myProcess")
    # print(find_pid_by_port('8080'))
    print(find_pid_by_name('myProcess'))
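As an alternative to shelling out with os.system in kill_process_by_pid, psutil itself can terminate processes in a cross-platform way; a minimal sketch (not a drop-in replacement for the functions above):

python
import psutil

def kill_process_by_pid_psutil(pid: int):
    """Cross-platform variant of kill_process_by_pid, using psutil instead of os.system."""
    try:
        p = psutil.Process(pid)
        p.terminate()        # SIGTERM on Linux, TerminateProcess on Windows
        p.wait(timeout=3)    # wait briefly for a clean exit
    except psutil.NoSuchProcess:
        pass                 # process already gone
    except psutil.TimeoutExpired:
        p.kill()             # force kill if it did not exit in time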

3. Logging module

python
# logger.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import logging
from logging.handlers import RotatingFileHandler
import functools

logger_initialized = {}


@functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    formatter = logging.Formatter(
        '[%(asctime)s.%(msecs)03d] %(name)s %(levelname)s: %(message)s', datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None:
        log_file_folder = os.path.split(log_file)[0]
        os.makedirs(log_file_folder, exist_ok=True)
        # file_handler = logging.FileHandler(log_file, 'a')
        file_handler = RotatingFileHandler(filename=log_file, maxBytes=10 * 1024 * 1024, backupCount=15, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    # set the level whether or not a file handler was added
    logger.setLevel(log_level)

    logger_initialized[name] = True
    return logger

if __name__ == "__main__":
    # Logging
    log_file = './logs/yolov5.log'
    logger = get_logger(name='yolov5', log_file=log_file)
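A short usage note: get_logger is backed by logging.getLogger plus the logger_initialized guard (and functools.lru_cache), so repeated calls with the same name return the already-configured logger instead of attaching duplicate handlers. A minimal sketch:

python
from utils.logger import get_logger

logger = get_logger(name='yolov5', log_file='./logs/yolov5.log')
logger_again = get_logger(name='yolov5')
assert logger_again is logger  # same logger object; handlers are added only once
logger.info("logger ready")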

4. Building the image with a Dockerfile

docker build -t wood_detect:test .
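After the image is built, it can be started with something like the following (a sketch: the port mapping matches EXPOSE 8080 in the Dockerfile, and --gpus all assumes the NVIDIA container toolkit is installed):

docker run -d --gpus all -p 8080:8080 --name wood_detect wood_detect:test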

dockerfile
# Dockerfile

# Start FROM Nvidia PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch
#FROM nvcr.io/nvidia/pytorch:21.05-py3
#FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
FROM deploy.hello.com/2020-public/yolov5_base:1.0.1

# Install linux packages
#RUN apt update && apt install -y zip htop screen libgl1-mesa-glx

## Create working directory
#RUN mkdir -p /usr/src/app
WORKDIR /usr/src/app

# Copy contents
COPY . /usr/src/app
EXPOSE 8080
# Install python dependencies
#COPY requirements.txt .
# RUN python -m pip install --upgrade pip
#RUN pip uninstall -y nvidia-tensorboard nvidia-tensorboard-plugin-dlprof
#RUN pip install --no-cache -r requirements.txt coremltools onnx gsutil -i https://pypi.douban.com/simple/
RUN pip install --no-cache -r requirements.txt -i https://pypi.douban.com/simple/

# RUN pip install --no-cache -U torch torchvision

## Set environment variables
#ENV HOME=/usr/src/app
#
ENTRYPOINT ["python","main_app.py"]
# ---------------------------------------------------  Extras Below  ---------------------------------------------------

# Build and Push
# t=ultralytics/yolov5:latest && sudo docker build -t $t . && sudo docker push $t
# for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t . && sudo docker push $t; done

# Pull and Run
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t

# Pull and Run with local directory access
# t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t

# Kill all
# sudo docker kill $(sudo docker ps -q)

# Kill all image-based
# sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest)

# Bash into running container
# sudo docker exec -it 5a9b5863d93d bash

# Bash into stopped container
# id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash

# Send weights to GCP
# python -c "from utils.general import *; strip_optimizer('runs/train/exp0_*/weights/best.pt', 'tmp.pt')" && gsutil cp tmp.pt gs://*.pt

# Clean up
# docker system prune -a --volumes

5. Building the image and deploying with docker-compose.yml

docker-compose up
yaml
# docker-compose.yml
# GPU configuration; see https://docs.docker.com/compose/gpu-support/

version: "3.8"

services:
  yolov5:
    build:
      context: .
    image: deploy.deepexi.com/2048-public/yolov5_server:alpha_v1.0
    restart: always
    container_name: yolov5_server
    ports:
      - 8080:8080
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: [ '0',]
              capabilities: [ gpu ]
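After docker-compose up, one way to confirm the container actually sees the reserved GPU is a quick torch check inside it (a sketch; yolov5_server is the container_name from the compose file above):

docker exec yolov5_server python -c "import torch; print(torch.cuda.is_available())"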

6. Rewriting a YAML file

python
# Rewrite the --hyp file (data/hyp.scratch.yaml) with values taken from the request
import yaml

# opt.hyp is the --hyp argument; yolo_hype is the hyperparameter dict from the request payload
with open(opt.hyp) as f:
    hyp = yaml.safe_load(f)
if 'lr0' in yolo_hype and isinstance(yolo_hype['lr0'], float) and yolo_hype['lr0'] >= 0.0:
    hyp['lr0'] = yolo_hype['lr0']
with open(opt.hyp, mode='w') as f:
    yaml.safe_dump(hyp, f)
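A similar hedged sketch for generating the data-config yaml consumed by training, using the field names from schema_train above; the values and the output path are illustrative placeholders.

python
import os
import yaml

# In the service, data_config would come from the validated request payload; placeholders here
data_config = {'train': '/data/train.txt', 'val': '/data/val.txt',
               'nc': 2, 'names': ['cat', 'dog'], 'result_path': '/data/results'}

data_yaml = {'train': data_config['train'], 'val': data_config['val'],
             'nc': data_config['nc'], 'names': data_config['names']}
os.makedirs('data', exist_ok=True)              # illustrative output directory
with open('data/custom_data.yaml', 'w') as f:   # illustrative output path
    yaml.safe_dump(data_yaml, f)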