flask_restx 创建restful api

文章目录

介绍

  • flask_restx 快速创建restful api的flask扩展包;
  • pip install flask_restx;
  • 自动创建Swagger UI的接口文档;

代码案例

bash 复制代码
"""
Flask-RESTX 主应用配置
"""
import os
import hashlib
from functools import wraps
from flask import Flask, jsonify
from flask_restx import Api, Resource, fields
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager, jwt_required, get_jwt_identity, verify_jwt_in_request

# 1.在.env文件中存储敏感信息,如API_KEY,数据库密码等
# 2.在.gitignore中设置.env不上传仓库,防止敏感信息泄露
from dotenv import load_dotenv  # pip install dotenv

# websocket 双通道通信
from flask_sock import Sock

import pymysql
pymysql.install_as_MySQLdb()  # ModuleNotFoundError: No module named 'MySQLdb'

# 只在开发环境加载 .env
if os.getenv("ENVIRONMENT") != "production":
    # 从.env加载环境变量,并写入os.environ中
    load_dotenv()

# 初始化 Flask 应用
app = Flask(__name__)

# 基础配置
app.config.update({
    'DEBUG': os.getenv('FLASK_ENV') == 'development',
    'TESTING': os.getenv('FLASK_ENV') == 'testing',
    'SECRET_KEY': os.getenv('SECRET_KEY', '123'),

    # 数据库连接
    'SQLALCHEMY_DATABASE_URI': os.getenv('DATABASE_URL', 'sqlite:///app.db'),
    'SQLALCHEMY_TRACE_MODIFICATIONS': False,

    # JWT 配置
    'JWT_SECRET_KEY': os.getenv('JWT_SECRET_KEY', '456'),
    'JWT_ACCESS_TOKEN_EXPIRES': 3600,  # 1小时
    'JWT_REFRESH_TOKEN_EXPIRES': 86400,  # 24小时
})

# 初始化扩展
db = SQLAlchemy(app)
migrate = Migrate(app, db)
jwt_manager = JWTManager(app)
sock = Sock(app)

# ============================================
# Flask-RESTX 配置, 构建restful api
# ============================================
# 自定义 API 配置
authorizations = {
    'Bearer Auth': {
        'type': 'apiKey',
        'in': 'header',
        'name': 'Authorization',
        'description': '输入: Bearer <JWT token>'
    }
}

api = Api(
    app,
    version='1.0',
    title='用户管理系统 API',
    description='用户管理系统 REST API',
    doc='/api/docs',  # Swagger UI 访问路径
    authorizations=authorizations,
    security='Bearer Auth',  # 默认安全方案
    contact='dev@qq.com',
    contact_url='https://example.com',
    license='MIT',
    license_url='https://opensource.org/licenses/MIT'
)

# ============================================
# 数据schema 定义, 用于数据验证、数据过滤、序列化、反序列化
# 这些模型用于 API 文档和请求/响应验证
# ============================================
# 用户模型schema
user_schema = api.model('UserSchema', {
    'id': fields.Integer(readOnly=True, description='用户ID'),
    'username': fields.String(required=True, description='用户名'),
    'email': fields.String(required=True, description='邮箱地址'),
    'created_at': fields.DateTime(readOnly=True, description='创建时间'),
    'updated_at': fields.DateTime(readOnly=True, description='更新时间')
})

# 创建用户的schema
user_create_schema = api.model('UserCreateSchema', {
    'username': fields.String(required=True, description='用户名', min_length=3, max_length=50),
    'password': fields.String(required=True, description='密码', min_length=6),
    'email': fields.String(required=True, description='邮箱地址',
                           pattern=r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
})

# 用户更新的schema
user_update_schema = api.model('UserUpdateSchema', {
    'username': fields.String(description='用户名', min_length=3, max_length=50),
    'email': fields.String(description='邮箱地址', pattern=r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
})

# 登录模型schema
login_schema = api.model('LoginSchema', {
    'username': fields.String(required=True, description='用户名'),
    'password': fields.String(required=True, description='密码')
})

# 分页响应模型schema
pagination_schema = api.model('PaginationSchema', {
    'page': fields.Integer(description='页码'),
    'per_page': fields.Integer(description='每页数量'),
    'total_pages': fields.Integer(description='总页数'),
    'total_items': fields.Integer(description='总条目数'),
    'has_next': fields.Boolean(description='是否有下一页'),
    'has_prev': fields.Boolean(description='是否有上一页')
})

# 带分页的用户列表schema
user_list_schema = api.model('UserListSchema', {
    'users': fields.List(fields.Nested(user_schema)),  # 内嵌
    'pagination': fields.Nested(pagination_schema)  # 内嵌表示自定义的schema
})

# ============================================
# 创建命名空间,表示关联API的分组
# ============================================
# 用户相关 API
user_ns = api.namespace(
    'users',
    description='用户管理操作',
    path='/api/users'  # 请求路径
)

# 认证相关 API
auth_ns = api.namespace(
    'auth',
    description='用户认证操作',
    path='/api/auth'  # 请求路径
)


# ============================================
# 数据库模型定义, 连接db的ORM 模型
# ============================================
class User(db.Model):
    """用户数据模型"""
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(256), nullable=False)  # 加密存储密码
    created_at = db.Column(db.DateTime, default=db.func.current_timestamp())
    updated_at = db.Column(
        db.DateTime,
        default=db.func.current_timestamp(),
        onupdate=db.func.current_timestamp()
    )

    def to_dict(self):
        """转换为字典"""
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

    @staticmethod
    def hash_password(password):
        """密码哈希(实际项目中使用 bcrypt 等)"""
        return hashlib.sha256(password.encode()).hexdigest()

    def check_password(self, password):
        """验证密码"""
        return self.password_hash == self.hash_password(password)


# ============================================
# 自定义装饰器
# ============================================
def admin_required(fn):
    """管理员权限装饰器"""

    @wraps(fn)  # 恢复fn的函数名
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        identity = get_jwt_identity()
        print("identity:", identity)
        # 查询数据库检查用户角色
        # 假设用户名为 'admin' 的是管理员
        if identity.get('username') != 'admin':
            return {'message': '需要管理员权限'}, 403

        return fn(*args, **kwargs)

    return wrapper


# ============================================
# 用户管理的API 资源
# ============================================
@user_ns.route('/')  # 访问路径 /api/users/
class UserListResource(Resource):
    """用户列表资源的 视图类"""

    @user_ns.doc('list_users')  # 接口的文档描述
    @user_ns.expect(api.parser().add_argument('page', type=int, default=1, help='页码'))
    @user_ns.expect(api.parser().add_argument('per_page', type=int, default=10, help='每页数量'))
    @user_ns.marshal_with(user_list_schema)  # 返回响应时序列化数据
    def get(self):
        """获取用户列表"""

        # 从多个来源解析数据
        # 默认会检查以下位置(按顺序):
        # 1. JSON 请求体 (request.json)
        # 2. 表单数据 (request.form)
        # 3. 查询字符串 (request.args)
        # 4. 文件上传 (request.files)
        parser = api.parser()
        parser.add_argument('page', type=int, default=1, help='页码')
        parser.add_argument('per_page', type=int, default=10, help='每页数量')
        args = parser.parse_args()
        # 获取数据
        page = args['page']
        per_page = args['per_page']

        # 分页查询
        pagination = User.query.paginate(page=page, per_page=per_page, error_out=False)
        # pagination.items 所有的项目

        return {
            'users': [user.to_dict() for user in pagination.items],
            'pagination': {
                'page': pagination.page,  # 当前页码
                'per_page': pagination.per_page,  # 每页有多少
                'total_pages': pagination.pages,  # 总计多少页
                'total_items': pagination.total,  # 总数据量
                'has_next': pagination.has_next,
                'has_prev': pagination.has_prev
            }
        }

    @user_ns.doc('create_user')  # 接口的文档描述
    @user_ns.expect(user_create_schema)  # 期望用户提交的数据结构
    @user_ns.marshal_with(user_schema, code=201)
    @user_ns.response(400, '输入数据无效')
    @user_ns.response(409, '用户名或邮箱已存在')  # 响应码为409时,提示信息
    def post(self):
        """创建新用户"""
        data = api.payload  # 获取表单数据

        # 检查用户名是否已存在
        if User.query.filter_by(username=data['username']).first():
            return {'message': '用户名已存在'}, 409

        # 检查邮箱是否已存在
        if User.query.filter_by(email=data['email']).first():
            return {'message': '邮箱已存在'}, 409

        # 创建用户
        user = User(
            username=data['username'],
            email=data['email'],
            password_hash=User.hash_password(data['password'])
        )

        # 会话提交
        db.session.add(user)
        db.session.commit()
        db.session.refresh(user)
        return user.to_dict(), 201


@user_ns.route('/<int:user_id>')
@user_ns.param('user_id', '用户ID')  # 从url路径参数中获取参数
@user_ns.response(404, '用户不存在')
class UserResource(Resource):
    """单个用户资源"""

    @user_ns.doc('get_user')
    @user_ns.marshal_with(user_schema)
    def get(self, user_id):
        """根据ID获取用户信息"""
        user = User.query.get_or_404(user_id)
        return user.to_dict()

    @user_ns.doc('update_user')
    @user_ns.expect(user_update_schema)
    @user_ns.marshal_with(user_schema)
    @user_ns.response(400, '输入数据无效')
    @user_ns.response(409, '用户名或邮箱已存在')
    def put(self, user_id):
        """更新用户信息"""
        user = User.query.get_or_404(user_id)
        data = api.payload

        # 如果更新用户名,检查是否重复
        if 'username' in data and data['username'] != user.username:
            if User.query.filter_by(username=data['username']).first():
                return {'message': '用户名已存在'}, 409
            user.username = data['username']

        # 如果更新邮箱,检查是否重复
        if 'email' in data and data['email'] != user.email:
            if User.query.filter_by(email=data['email']).first():
                return {'message': '邮箱已存在'}, 409
            user.email = data['email']

        db.session.commit()
        return user.to_dict()

    @user_ns.doc('delete_user')
    @user_ns.response(204, '用户已删除')
    def delete(self, user_id):
        """删除用户"""
        user = User.query.get_or_404(user_id)

        db.session.delete(user)
        db.session.commit()

        return '', 204


@user_ns.route('/search')  # /api/users/search
class UserSearchResource(Resource):
    """用户搜索资源"""

    @user_ns.doc('search_users')
    @user_ns.expect(api.parser().add_argument('q', required=True, help='搜索关键词'))
    @user_ns.marshal_list_with(user_schema)
    def get(self):
        """搜索用户"""
        parser = api.parser()
        parser.add_argument('q', required=True, help='搜索关键词')
        args = parser.parse_args()

        search_term = f"%{args['q']}%"
        # 模糊查询
        users = User.query.filter(
            (User.username.like(search_term)) |
            (User.email.like(search_term))
        ).all()

        return [user.to_dict() for user in users]


# ============================================
# API 资源 - 用户认证
# ============================================
@auth_ns.route('/register')
class RegisterResource(Resource):
    """用户注册"""

    @auth_ns.doc('register_user')
    @auth_ns.expect(user_create_schema)
    @auth_ns.marshal_with(user_schema, code=201)
    @auth_ns.response(400, '输入数据无效')
    def post(self):
        """用户注册"""
        data = api.payload

        # 验证邮箱格式
        import re
        email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        if not re.match(email_pattern, data['email']):
            return {'message': '邮箱格式无效'}, 400

        # 创建用户
        user = User(
            username=data['username'],
            email=data['email'],
            password_hash=User.hash_password(data['password'])
        )

        db.session.add(user)
        db.session.commit()

        return user.to_dict(), 201


@auth_ns.route('/login')
class LoginResource(Resource):
    """用户登录"""

    @auth_ns.doc('user_login')
    @auth_ns.expect(login_schema)
    @auth_ns.response(200, '登录成功')
    @auth_ns.response(401, '用户名或密码错误')
    def post(self):
        """用户登录"""
        data = api.payload

        # 查找用户
        user = User.query.filter_by(username=data['username']).first()

        if not user or not user.check_password(data['password']):
            return {'message': '用户名或密码错误'}, 401

        # 生成 JWT token
        from flask_jwt_extended import create_access_token, create_refresh_token

        access_token = create_access_token(identity={
            'id': user.id,
            'username': user.username
        })

        refresh_token = create_refresh_token(identity={
            'id': user.id,
            'username': user.username
        })

        return {
            'access_token': access_token,
            'refresh_token': refresh_token,
            'token_type': 'Bearer',
            'expires_in': 3600,
            'user': user.to_dict()
        }, 200


@auth_ns.route('/refresh')
class RefreshTokenResource(Resource):
    """刷新访问令牌"""

    @auth_ns.doc('refresh_token')
    @auth_ns.response(200, '令牌刷新成功')
    @auth_ns.response(401, '刷新令牌无效')
    def post(self):
        """刷新访问令牌"""
        from flask_jwt_extended import jwt_required, get_jwt_identity, create_access_token

        identity = get_jwt_identity()
        access_token = create_access_token(identity=identity)

        return {
            'access_token': access_token,
            'token_type': 'Bearer',
            'expires_in': 3600
        }, 200


@auth_ns.route('/me')
class CurrentUserResource(Resource):
    """当前用户信息"""

    @auth_ns.doc('get_current_user')
    @auth_ns.response(200, '成功获取用户信息')
    @auth_ns.response(401, '未授权访问')
    def get(self):
        """获取当前登录用户信息"""

        identity = get_jwt_identity()
        user = User.query.get(identity['id'])

        if not user:
            return {'message': '用户不存在'}, 404

        return user.to_dict(), 200


# ============================================
# 错误处理
# ============================================
@app.errorhandler(404)
def not_found(error):
    """处理 404 错误"""
    return jsonify({
        'error': 'Not Found',
        'message': '请求的资源不存在',
        'status_code': 404
    }), 404


@app.errorhandler(500)
def internal_error(error):
    """处理 500 错误"""
    return jsonify({
        'error': 'Internal Server Error',
        'message': '服务器内部错误',
        'status_code': 500
    }), 500


# ============================================
# 自定义 Swagger 文档
# ============================================

# @api.documentation
# def custom_ui():   # 自定义接口文档格式   /api/docs 访问
#     """自定义 Swagger UI 配置"""
#     return {
#         'apisSorter': 'alpha',
#         'operationsSorter': 'alpha',
#         'docExpansion': 'none',
#         'defaultModelsExpandDepth': 2,
#         'defaultModelExpandDepth': 2,
#         'showCommonExtensions': True,
#         'showExtensions': True
#     }


# ============================================
# 应用启动
# ============================================

if __name__ == '__main__':
    # 创建数据库表
    with app.app_context():
        db.create_all()

    # 运行应用
    app.run(
        host=os.getenv('FLASK_HOST', '0.0.0.0'),
        port=int(os.getenv('FLASK_PORT', 5000)),
        debug=os.getenv('FLASK_ENV') == 'development'
    )

页面效果

相关推荐
毕设源码-郭学长2 小时前
【开题答辩全过程】以 基于python电商商城系统为例,包含答辩的问题和答案
开发语言·python
black0moonlight2 小时前
win11 isaacsim 5.1.0 和lab配置
python
知乎的哥廷根数学学派2 小时前
基于多尺度注意力机制融合连续小波变换与原型网络的滚动轴承小样本故障诊断方法(Pytorch)
网络·人工智能·pytorch·python·深度学习·算法·机器学习
网安CILLE2 小时前
PHP四大输出语句
linux·开发语言·python·web安全·网络安全·系统安全·php
jjjddfvv2 小时前
超级简单启动llamafactory!
windows·python·深度学习·神经网络·微调·audiolm·llamafactory
A先生的AI之旅2 小时前
2025顶会TimeDRT快速解读
人工智能·pytorch·python·深度学习·机器学习
程序员小远2 小时前
完整的项目测试方案流程
自动化测试·软件测试·python·功能测试·测试工具·职场和发展·测试用例
程序猿阿伟2 小时前
《量子算法开发实战手册:Python全栈能力的落地指南》
python·算法·量子计算