Django 使用swagger自定义自动生成类

复制代码
完整代码:https://gitee.com/mom925/django-system

之前写的《Django配置swagger》(https://www.cnblogs.com/moon3496694/p/17657283.html)更多还是靠自己手动写代码来编写接口文档。我希望接口文档能更自动化地生成,所以需要自己重写一些函数。
安装所需的包、注册app、注册路由参考之前的文章即可(https://www.cnblogs.com/moon3496694/p/17657283.html),下面是在之前基础上做的改进。

要让swagger使用自定义的自动生成类,需要在配置中指定该类:
复制代码
SWAGGER_SETTINGS = {
    # Turn off drf-yasg's Django session login in the UI; auth is sent via the header below.
    "USE_SESSION_AUTH": False,
    # API-key security scheme carried in the "Authorization" request header.
    "SECURITY_DEFINITIONS": {
        "身份验证": {
            "type": "apiKey",
            "in": "header",
            "name": "Authorization",
        },
    },
    # Dotted path of the custom schema class (utils/swagger.py).
    "DEFAULT_AUTO_SCHEMA_CLASS": "utils.swagger.CustomSwaggerAutoSchema",
    # Optional: "AUTO_SCHEMA_TYPE": <int> -- index into operation_keys used by
    # CustomSwaggerAutoSchema.get_tags when the default tag is "api" (defaults to 2).
}
复制代码
我的swagger.py文件
复制代码
from django.utils.encoding import smart_str
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.inspectors import SwaggerAutoSchema
from drf_yasg.utils import merge_params, get_object_classes
from rest_framework.parsers import FileUploadParser
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema
from rest_framework.utils import formatting

from Wchime.settings import SWAGGER_SETTINGS


def get_consumes(parser_classes):
    """Return the media types of *parser_classes*, ignoring ``FileUploadParser`` subclasses.

    Every other media type is kept (form encodings included), so callers can
    check whether any form media type is accepted by the view.
    """
    accepted = [
        parser_cls
        for parser_cls in get_object_classes(parser_classes)
        if not issubclass(parser_cls, FileUploadParser)
    ]
    return [parser_cls.media_type for parser_cls in accepted]


def get_summary(string):
    """Return the first line of *string* with all space characters removed.

    Outer whitespace is stripped before the first line is taken; only ``" "``
    characters are removed, not tabs. Returns ``None`` when *string* is ``None``.
    """
    if string is None:
        return None
    compact = string.strip().replace(" ", "")
    first_line, _sep, _rest = compact.partition("\n")
    return first_line


class CustomAutoSchema(AutoSchema):
    """AutoSchema whose description lookup targets the ``tags`` section of a view docstring.

    ``CustomSwaggerAutoSchema.get_tags`` uses it to read a tag written as
    ``tags: xxx`` in the view class's docstring.
    """

    def get_description(self, path, method):
        # ``path`` and ``method`` belong to the AutoSchema signature; they are unused here.
        current_view = self.view
        docstring = current_view.get_view_description()
        # NOTE(review): what `_get_description_section` returns when the docstring has
        # no ``tags:`` section is decided by DRF -- confirm against the installed version.
        return self._get_description_section(current_view, 'tags', docstring)


class CustomSwaggerAutoSchema(SwaggerAutoSchema):
    """drf-yasg auto schema customised for this project.

    * ``get_tags``: a tag can be written as ``tags: xxx`` in the view class docstring.
    * ``get_summary_and_description``: the summary falls back to the description.
    * ``add_manual_parameters``: manual form parameters work on views that also accept
      JSON (e.g. ``parser_classes = [MultiPartParser, JSONParser]``).
    """

    def get_tags(self, operation_keys=None):
        """Return the list of tags for the current operation.

        Starts from drf-yasg's default tags. If ``"api"`` is among them, the first tag
        becomes ``operation_keys[SWAGGER_SETTINGS['AUTO_SCHEMA_TYPE']]`` (index 2 by
        default; ``operation_keys`` looks like ``['v1', 'prize_join_log', 'create']``).
        A non-empty value read from the view docstring through ``CustomAutoSchema``
        then takes precedence as the first tag.
        """
        # Work on a copy. NOTE(review): the base class may hand back the user's own
        # ``tags`` override list, which must not be modified in place.
        tags = list(super().get_tags(operation_keys))
        if "api" in tags and operation_keys:
            key_index = SWAGGER_SETTINGS.get('AUTO_SCHEMA_TYPE', 2)
            try:
                tags[0] = operation_keys[key_index]
            except IndexError:
                # Short paths have fewer keys than the configured index: keep the default tag.
                pass
        docstring_schema = CustomAutoSchema()
        docstring_schema.view = self.view
        tag = docstring_schema.get_description(self.path, 'get')
        if tag:
            # Replace the first tag (or add one if the list happens to be empty).
            tags[:1] = [tag]
        return tags

    def get_summary_and_description(self):
        """Return ``(summary, description)`` for the operation.

        Explicit ``operation_description`` / ``operation_summary`` overrides win.
        Otherwise the description comes from the docstring, and the summary is split
        off its first part. If no summary can be determined, it falls back to the
        description, so the returned summary is never ``None``.
        """
        description = self.overrides.get('operation_description', None)
        summary = self.overrides.get('operation_summary', None)
        if description is None:
            description = self._sch.get_description(self.path, self.method) or ''
            description = description.strip().replace('\r', '')

            if description and (summary is None):
                # description from docstring... do summary magic
                summary, description = self.split_summary_from_description(description)
        if summary is None:
            summary = description
        return summary, description

    def get_consumes_form(self):
        """Return every media type accepted by the view's parsers, form types included.

        Used by ``add_manual_parameters`` to decide whether form parameters are allowed.
        """
        return get_consumes(self.get_parser_classes())

    def add_manual_parameters(self, parameters):
        """Merge the ``manual_parameters`` override into the auto-generated *parameters*.

        Overridden so a view that accepts both form data and JSON can still declare
        form fields manually:

        * When manual ``in: formData`` parameters are given, they replace the
          auto-generated body / form parameters. A Swagger 2 operation cannot have
          both a body and form fields. Auto-generated query parameters are kept.
        * The form media-type check uses ``get_consumes_form``, so form parsers
          listed next to ``JSONParser`` count.

        Other manual parameters (query, header, ...) are merged with the
        auto-generated ones, so an explicit ``request_body`` is no longer discarded.

        Raises:
            SwaggerGenerationError: a manual parameter is ``in: body``; form parameters
                are given but no parser accepts a form media type; or form parameters
                are used with a method outside ``body_methods``.
        """
        manual_parameters = self.overrides.get('manual_parameters', None) or []

        if any(param.in_ == openapi.IN_BODY for param in manual_parameters):  # pragma: no cover
            raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
        if any(param.in_ == openapi.IN_FORM for param in manual_parameters):  # pragma: no cover
            # The manual form fields define the whole request payload.
            parameters = [param for param in parameters
                          if param.in_ not in (openapi.IN_BODY, openapi.IN_FORM)]
            if not any(is_form_media_type(encoding) for encoding in self.get_consumes_form()):
                raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
                                             "did you forget to set an appropriate parser class on the view?")
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("form parameters can only be applied to "
                                             "(" + ','.join(self.body_methods) + ") HTTP methods")

        return merge_params(parameters, manual_parameters)


# --------------------------------------------------------------------------------------------------------------

from rest_framework import serializers
from drf_yasg import openapi
from rest_framework.relations import PrimaryKeyRelatedField
from rest_framework.fields import ChoiceField


def serializer_to_swagger(ser_model, get_req=False):
    '''
    Convert a DRF serializer class into drf-yasg ``openapi.Schema`` properties.

    Args:
        ser_model: serializer class (instantiated with no arguments), or ``None``.
        get_req: when true, build a *request* schema. Fields declared with an
            explicit ``source`` are skipped, and the names of required fields
            are collected as well.

    Returns:
        ``{name: openapi.Schema}`` when *get_req* is false, otherwise the tuple
        ``({name: openapi.Schema}, [required field names])``. A ``None``
        *ser_model* yields ``{}`` or ``({}, [])`` respectively.
        ``SerializerMethodField`` entries are always left out.
    '''
    if ser_model is None:
        return ({}, []) if get_req else {}

    # Exact-type lookup: field classes not listed here fall back to TYPE_STRING.
    serializer_field_mapping = {
        ChoiceField: openapi.TYPE_INTEGER,
        PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
        serializers.IntegerField: openapi.TYPE_INTEGER,
        serializers.BooleanField: openapi.TYPE_BOOLEAN,
        serializers.CharField: openapi.TYPE_STRING,
        serializers.DateField: openapi.TYPE_STRING,
        serializers.DateTimeField: openapi.TYPE_STRING,
        serializers.DecimalField: openapi.TYPE_NUMBER,
        serializers.DurationField: openapi.TYPE_STRING,
        serializers.EmailField: openapi.TYPE_STRING,
        serializers.ModelField: openapi.TYPE_OBJECT,
        serializers.FileField: openapi.TYPE_STRING,
        serializers.FloatField: openapi.TYPE_NUMBER,
        serializers.ImageField: openapi.TYPE_STRING,
        serializers.SlugField: openapi.TYPE_STRING,
        serializers.TimeField: openapi.TYPE_STRING,
        serializers.URLField: openapi.TYPE_STRING,
        serializers.UUIDField: openapi.TYPE_STRING,
        serializers.IPAddressField: openapi.TYPE_STRING,
        serializers.FilePathField: openapi.TYPE_STRING,
    }
    properties = {}
    required = []
    for name, field in ser_model().get_fields().items():
        if isinstance(field, serializers.SerializerMethodField):
            continue
        # Request schemas also skip fields declared with an explicit ``source``.
        if get_req and getattr(field, 'source', None):
            continue
        description = getattr(field, 'label', '')
        if isinstance(field, ChoiceField):
            # ``label`` may be None; ``None += str`` used to raise TypeError here.
            description = (description or '') + str(dict(getattr(field, 'choices', {})))
        if get_req and getattr(field, 'required', True) is not False:
            required.append(name)
        field_type = serializer_field_mapping.get(type(field), openapi.TYPE_STRING)
        properties[name] = openapi.Schema(description=description, type=field_type)
    return (properties, required) if get_req else properties


def serializer_to_req_form_swagger(ser_model, filter_fields=None):
    """Convert a DRF serializer class into ``in: formData`` ``openapi.Parameter`` objects.

    Args:
        ser_model: serializer class (instantiated with no arguments); ``None`` yields ``[]``.
        filter_fields: field names to leave out; ``None`` (the default) leaves out none.

    Returns:
        A list with one form parameter per remaining field. ``SerializerMethodField``
        entries and fields declared with an explicit ``source`` are skipped. File and
        image fields map to ``openapi.TYPE_FILE``.
    """
    if ser_model is None:
        return []
    excluded = filter_fields or ()

    # Exact-type lookup: field classes not listed here fall back to TYPE_STRING.
    serializer_field_mapping = {
        ChoiceField: openapi.TYPE_INTEGER,
        PrimaryKeyRelatedField: openapi.TYPE_INTEGER,
        serializers.IntegerField: openapi.TYPE_INTEGER,
        serializers.BooleanField: openapi.TYPE_BOOLEAN,
        serializers.CharField: openapi.TYPE_STRING,
        serializers.DateField: openapi.TYPE_STRING,
        serializers.DateTimeField: openapi.TYPE_STRING,
        serializers.DecimalField: openapi.TYPE_NUMBER,
        serializers.DurationField: openapi.TYPE_STRING,
        serializers.EmailField: openapi.TYPE_STRING,
        serializers.ModelField: openapi.TYPE_OBJECT,
        serializers.FileField: openapi.TYPE_FILE,
        serializers.FloatField: openapi.TYPE_NUMBER,
        serializers.ImageField: openapi.TYPE_FILE,
        serializers.SlugField: openapi.TYPE_STRING,
        serializers.TimeField: openapi.TYPE_STRING,
        serializers.URLField: openapi.TYPE_STRING,
        serializers.UUIDField: openapi.TYPE_STRING,
        serializers.IPAddressField: openapi.TYPE_STRING,
        serializers.FilePathField: openapi.TYPE_STRING,
    }
    form_params = []
    for name, field in ser_model().get_fields().items():
        if name in excluded:
            continue
        if isinstance(field, serializers.SerializerMethodField) or getattr(field, 'source', None):
            continue
        description = getattr(field, 'label', '')
        if isinstance(field, ChoiceField):
            # ``label`` may be None; ``None += str`` used to raise TypeError here.
            description = (description or '') + str(dict(getattr(field, 'choices', {})))
        form_params.append(openapi.Parameter(
            name=name,
            description=description,
            type=serializer_field_mapping.get(type(field), openapi.TYPE_STRING),
            required=getattr(field, 'required', True),
            in_=openapi.IN_FORM,
        ))
    return form_params


class ViewSwagger(object):
    """Class-attribute driven builder for per-HTTP-method schema keyword dicts.

    Subclass it and override the ``<method>_*`` attributes for ``get`` / ``post`` /
    ``put`` / ``delete``. Then, e.g., ``MySwagger.post()`` returns a dict whose keys
    match drf-yasg ``swagger_auto_schema`` keyword arguments (presumably unpacked
    into it by the caller).

    Reassign the list/dict defaults in subclasses instead of mutating them: they
    are shared by every subclass.
    """

    # Per-method settings (same layout for get / post / put / delete):
    #   <m>_req_params                     -> 'manual_parameters'
    #   <m>_req_body                       -> 'request_body'
    #   <m>_res_data                       -> response schema; 'responses' is None while it is falsy
    #   <m>_res_code / _res_description / _res_examples -> the response built from <m>_res_data
    #   <m>_tags / <m>_operation_description           -> forwarded only when not None

    get_req_params = []
    get_req_body = None
    get_res_data = None
    get_res_examples = {'json': {}}
    get_res_description = ' '
    get_res_code = 200
    get_tags = None
    get_operation_description = None

    post_req_params = []
    post_req_body = None
    post_res_data = None
    post_res_examples = {'json': {}}
    post_res_description = ' '
    post_res_code = 200
    post_tags = None
    post_operation_description = None

    put_req_params = []
    put_req_body = None
    put_res_data = None
    put_res_examples = {'json': {}}
    put_res_description = ' '
    put_res_code = 200
    put_tags = None
    put_operation_description = None

    delete_req_params = []
    delete_req_body = None
    delete_res_data = None
    delete_res_examples = {'json': {}}
    delete_res_description = ' '
    delete_res_code = 200
    delete_tags = None
    delete_operation_description = None

    @classmethod
    def req_serialize_schema(cls, serializer):
        """Request schema of *serializer*: ``(properties, required_names)``."""
        return serializer_to_swagger(serializer, get_req=True)

    @classmethod
    def res_serializer_schema(cls, serializer):
        """Response schema properties of *serializer*."""
        return serializer_to_swagger(serializer, get_req=False)

    @classmethod
    def req_serializer_form_schema(cls, serializer, filter_fields=()):
        """Form parameters for *serializer*, leaving out the names in *filter_fields*."""
        return serializer_to_req_form_swagger(serializer, filter_fields)

    @classmethod
    def _operation_kwargs(cls, method):
        """Build the schema keyword dict from the ``<method>_*`` class attributes."""
        def setting(name):
            return getattr(cls, f'{method}_{name}')

        res_data = setting('res_data')
        responses = None
        if res_data:
            responses = {
                setting('res_code'): openapi.Response(
                    description=setting('res_description'),
                    schema=res_data,
                    examples=setting('res_examples'),
                )
            }
        kwargs = {
            'manual_parameters': setting('req_params'),
            'request_body': setting('req_body'),
            'responses': responses,
        }
        # ``<m>_tags`` / ``<m>_operation_description`` were declared but never forwarded;
        # add them only when set so the default output stays unchanged.
        for name in ('tags', 'operation_description'):
            value = setting(name)
            if value is not None:
                kwargs[name] = value
        return kwargs

    @classmethod
    def get(cls):
        """Schema keyword dict for GET, built from the ``get_*`` attributes."""
        return cls._operation_kwargs('get')

    @classmethod
    def post(cls):
        """Schema keyword dict for POST, built from the ``post_*`` attributes."""
        return cls._operation_kwargs('post')

    @classmethod
    def put(cls):
        """Schema keyword dict for PUT, built from the ``put_*`` attributes."""
        return cls._operation_kwargs('put')

    @classmethod
    def delete(cls):
        """Schema keyword dict for DELETE, built from the ``delete_*`` attributes."""
        return cls._operation_kwargs('delete')
复制代码
首先重写了get_tags方法:我希望只要在视图类的注释(docstring)里写上tags:"xxxx",就能自动读取到。
上面写的CustomAutoSchema类就是读取视图类的注释,然后从中取出tags的值。
只需要这样写:
复制代码
然后即可生成:

可以看到,这些接口都归到了“测试图片”标签下。

复制代码
重写get_summary_and_description方法:原来的方法获取到的summary有可能为空,所以改为当summary为None时,令summary=description。
如果需要在视图类注释中写这两个描述,则像下面一样:
复制代码
也可以在方法注释中写,则像下面一样:
复制代码
得到的结果一样:
复制代码
注意:如果两个地方都写了,内层的注释会覆盖外层的,也就是方法中的注释会覆盖视图类下面的注释。



重写add_manual_parameters方法:原来自动生成时只能解析一种数据类型,当视图配置了多种解析器时,默认按JSON类型生成(因为rest_framework默认解析JSON)。
在rest_framework中,不管是表单还是JSON格式,都可以通过request.data获取。比如新增时提交表单,批量删除时提交JSON,但它们一般又写在同一个视图类中,
所以给视图类指定解析器 parser_classes = [MultiPartParser, JSONParser]。
重写以后,当两种格式都支持时,会优先按表单格式生成。
视图类像下面一样:
复制代码
得到的post和delete:

得到了post的表单数据和delete的JSON数据

复制代码