flask_restx 创建restful api

文章目录

介绍

  • flask_restx 快速创建restful api的flask扩展包;
  • pip install flask_restx;
  • 自动创建Swagger UI的接口文档;

代码案例

bash 复制代码
"""
Flask-RESTX 主应用配置
"""
import os
import hashlib
from functools import wraps
from flask import Flask, jsonify
from flask_restx import Api, Resource, fields
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from flask_jwt_extended import JWTManager, jwt_required, get_jwt_identity, verify_jwt_in_request

# 1.在.env文件中存储敏感信息,如API_KEY,数据库密码等
# 2.在.gitignore中设置.env不上传仓库,防止敏感信息泄露
from dotenv import load_dotenv  # pip install dotenv

# websocket 双通道通信
from flask_sock import Sock

import pymysql
pymysql.install_as_MySQLdb()  # ModuleNotFoundError: No module named 'MySQLdb'

# 只在开发环境加载 .env
if os.getenv("ENVIRONMENT") != "production":
    # 从.env加载环境变量,并写入os.environ中
    load_dotenv()

# 初始化 Flask 应用
app = Flask(__name__)

# 基础配置
app.config.update({
    'DEBUG': os.getenv('FLASK_ENV') == 'development',
    'TESTING': os.getenv('FLASK_ENV') == 'testing',
    'SECRET_KEY': os.getenv('SECRET_KEY', '123'),

    # 数据库连接
    'SQLALCHEMY_DATABASE_URI': os.getenv('DATABASE_URL', 'sqlite:///app.db'),
    'SQLALCHEMY_TRACE_MODIFICATIONS': False,

    # JWT 配置
    'JWT_SECRET_KEY': os.getenv('JWT_SECRET_KEY', '456'),
    'JWT_ACCESS_TOKEN_EXPIRES': 3600,  # 1小时
    'JWT_REFRESH_TOKEN_EXPIRES': 86400,  # 24小时
})

# 初始化扩展
db = SQLAlchemy(app)
migrate = Migrate(app, db)
jwt_manager = JWTManager(app)
sock = Sock(app)

# ============================================
# Flask-RESTX 配置, 构建restful api
# ============================================
# 自定义 API 配置
authorizations = {
    'Bearer Auth': {
        'type': 'apiKey',
        'in': 'header',
        'name': 'Authorization',
        'description': '输入: Bearer <JWT token>'
    }
}

api = Api(
    app,
    version='1.0',
    title='用户管理系统 API',
    description='用户管理系统 REST API',
    doc='/api/docs',  # Swagger UI 访问路径
    authorizations=authorizations,
    security='Bearer Auth',  # 默认安全方案
    contact='dev@qq.com',
    contact_url='https://example.com',
    license='MIT',
    license_url='https://opensource.org/licenses/MIT'
)

# ============================================
# 数据schema 定义, 用于数据验证、数据过滤、序列化、反序列化
# 这些模型用于 API 文档和请求/响应验证
# ============================================
# 用户模型schema
user_schema = api.model('UserSchema', {
    'id': fields.Integer(readOnly=True, description='用户ID'),
    'username': fields.String(required=True, description='用户名'),
    'email': fields.String(required=True, description='邮箱地址'),
    'created_at': fields.DateTime(readOnly=True, description='创建时间'),
    'updated_at': fields.DateTime(readOnly=True, description='更新时间')
})

# 创建用户的schema
user_create_schema = api.model('UserCreateSchema', {
    'username': fields.String(required=True, description='用户名', min_length=3, max_length=50),
    'password': fields.String(required=True, description='密码', min_length=6),
    'email': fields.String(required=True, description='邮箱地址',
                           pattern=r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
})

# 用户更新的schema
user_update_schema = api.model('UserUpdateSchema', {
    'username': fields.String(description='用户名', min_length=3, max_length=50),
    'email': fields.String(description='邮箱地址', pattern=r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
})

# 登录模型schema
login_schema = api.model('LoginSchema', {
    'username': fields.String(required=True, description='用户名'),
    'password': fields.String(required=True, description='密码')
})

# 分页响应模型schema
pagination_schema = api.model('PaginationSchema', {
    'page': fields.Integer(description='页码'),
    'per_page': fields.Integer(description='每页数量'),
    'total_pages': fields.Integer(description='总页数'),
    'total_items': fields.Integer(description='总条目数'),
    'has_next': fields.Boolean(description='是否有下一页'),
    'has_prev': fields.Boolean(description='是否有上一页')
})

# 带分页的用户列表schema
user_list_schema = api.model('UserListSchema', {
    'users': fields.List(fields.Nested(user_schema)),  # 内嵌
    'pagination': fields.Nested(pagination_schema)  # 内嵌表示自定义的schema
})

# ============================================
# 创建命名空间,表示关联API的分组
# ============================================
# 用户相关 API
user_ns = api.namespace(
    'users',
    description='用户管理操作',
    path='/api/users'  # 请求路径
)

# 认证相关 API
auth_ns = api.namespace(
    'auth',
    description='用户认证操作',
    path='/api/auth'  # 请求路径
)


# ============================================
# 数据库模型定义, 连接db的ORM 模型
# ============================================
class User(db.Model):
    """用户数据模型"""
    __tablename__ = 'users'

    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(80), unique=True, nullable=False, index=True)
    email = db.Column(db.String(120), unique=True, nullable=False, index=True)
    password_hash = db.Column(db.String(256), nullable=False)  # 加密存储密码
    created_at = db.Column(db.DateTime, default=db.func.current_timestamp())
    updated_at = db.Column(
        db.DateTime,
        default=db.func.current_timestamp(),
        onupdate=db.func.current_timestamp()
    )

    def to_dict(self):
        """转换为字典"""
        return {
            'id': self.id,
            'username': self.username,
            'email': self.email,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }

    @staticmethod
    def hash_password(password):
        """密码哈希(实际项目中使用 bcrypt 等)"""
        return hashlib.sha256(password.encode()).hexdigest()

    def check_password(self, password):
        """验证密码"""
        return self.password_hash == self.hash_password(password)


# ============================================
# 自定义装饰器
# ============================================
def admin_required(fn):
    """管理员权限装饰器"""

    @wraps(fn)  # 恢复fn的函数名
    def wrapper(*args, **kwargs):
        verify_jwt_in_request()
        identity = get_jwt_identity()
        print("identity:", identity)
        # 查询数据库检查用户角色
        # 假设用户名为 'admin' 的是管理员
        if identity.get('username') != 'admin':
            return {'message': '需要管理员权限'}, 403

        return fn(*args, **kwargs)

    return wrapper


# ============================================
# 用户管理的API 资源
# ============================================
@user_ns.route('/')  # 访问路径 /api/users/
class UserListResource(Resource):
    """用户列表资源的 视图类"""

    @user_ns.doc('list_users')  # 接口的文档描述
    @user_ns.expect(api.parser().add_argument('page', type=int, default=1, help='页码'))
    @user_ns.expect(api.parser().add_argument('per_page', type=int, default=10, help='每页数量'))
    @user_ns.marshal_with(user_list_schema)  # 返回响应时序列化数据
    def get(self):
        """获取用户列表"""

        # 从多个来源解析数据
        # 默认会检查以下位置(按顺序):
        # 1. JSON 请求体 (request.json)
        # 2. 表单数据 (request.form)
        # 3. 查询字符串 (request.args)
        # 4. 文件上传 (request.files)
        parser = api.parser()
        parser.add_argument('page', type=int, default=1, help='页码')
        parser.add_argument('per_page', type=int, default=10, help='每页数量')
        args = parser.parse_args()
        # 获取数据
        page = args['page']
        per_page = args['per_page']

        # 分页查询
        pagination = User.query.paginate(page=page, per_page=per_page, error_out=False)
        # pagination.items 所有的项目

        return {
            'users': [user.to_dict() for user in pagination.items],
            'pagination': {
                'page': pagination.page,  # 当前页码
                'per_page': pagination.per_page,  # 每页有多少
                'total_pages': pagination.pages,  # 总计多少页
                'total_items': pagination.total,  # 总数据量
                'has_next': pagination.has_next,
                'has_prev': pagination.has_prev
            }
        }

    @user_ns.doc('create_user')  # 接口的文档描述
    @user_ns.expect(user_create_schema)  # 期望用户提交的数据结构
    @user_ns.marshal_with(user_schema, code=201)
    @user_ns.response(400, '输入数据无效')
    @user_ns.response(409, '用户名或邮箱已存在')  # 响应码为409时,提示信息
    def post(self):
        """创建新用户"""
        data = api.payload  # 获取表单数据

        # 检查用户名是否已存在
        if User.query.filter_by(username=data['username']).first():
            return {'message': '用户名已存在'}, 409

        # 检查邮箱是否已存在
        if User.query.filter_by(email=data['email']).first():
            return {'message': '邮箱已存在'}, 409

        # 创建用户
        user = User(
            username=data['username'],
            email=data['email'],
            password_hash=User.hash_password(data['password'])
        )

        # 会话提交
        db.session.add(user)
        db.session.commit()
        db.session.refresh(user)
        return user.to_dict(), 201


@user_ns.route('/<int:user_id>')
@user_ns.param('user_id', '用户ID')  # 从url路径参数中获取参数
@user_ns.response(404, '用户不存在')
class UserResource(Resource):
    """单个用户资源"""

    @user_ns.doc('get_user')
    @user_ns.marshal_with(user_schema)
    def get(self, user_id):
        """根据ID获取用户信息"""
        user = User.query.get_or_404(user_id)
        return user.to_dict()

    @user_ns.doc('update_user')
    @user_ns.expect(user_update_schema)
    @user_ns.marshal_with(user_schema)
    @user_ns.response(400, '输入数据无效')
    @user_ns.response(409, '用户名或邮箱已存在')
    def put(self, user_id):
        """更新用户信息"""
        user = User.query.get_or_404(user_id)
        data = api.payload

        # 如果更新用户名,检查是否重复
        if 'username' in data and data['username'] != user.username:
            if User.query.filter_by(username=data['username']).first():
                return {'message': '用户名已存在'}, 409
            user.username = data['username']

        # 如果更新邮箱,检查是否重复
        if 'email' in data and data['email'] != user.email:
            if User.query.filter_by(email=data['email']).first():
                return {'message': '邮箱已存在'}, 409
            user.email = data['email']

        db.session.commit()
        return user.to_dict()

    @user_ns.doc('delete_user')
    @user_ns.response(204, '用户已删除')
    def delete(self, user_id):
        """删除用户"""
        user = User.query.get_or_404(user_id)

        db.session.delete(user)
        db.session.commit()

        return '', 204


@user_ns.route('/search')  # /api/users/search
class UserSearchResource(Resource):
    """用户搜索资源"""

    @user_ns.doc('search_users')
    @user_ns.expect(api.parser().add_argument('q', required=True, help='搜索关键词'))
    @user_ns.marshal_list_with(user_schema)
    def get(self):
        """搜索用户"""
        parser = api.parser()
        parser.add_argument('q', required=True, help='搜索关键词')
        args = parser.parse_args()

        search_term = f"%{args['q']}%"
        # 模糊查询
        users = User.query.filter(
            (User.username.like(search_term)) |
            (User.email.like(search_term))
        ).all()

        return [user.to_dict() for user in users]


# ============================================
# API 资源 - 用户认证
# ============================================
@auth_ns.route('/register')
class RegisterResource(Resource):
    """用户注册"""

    @auth_ns.doc('register_user')
    @auth_ns.expect(user_create_schema)
    @auth_ns.marshal_with(user_schema, code=201)
    @auth_ns.response(400, '输入数据无效')
    def post(self):
        """用户注册"""
        data = api.payload

        # 验证邮箱格式
        import re
        email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
        if not re.match(email_pattern, data['email']):
            return {'message': '邮箱格式无效'}, 400

        # 创建用户
        user = User(
            username=data['username'],
            email=data['email'],
            password_hash=User.hash_password(data['password'])
        )

        db.session.add(user)
        db.session.commit()

        return user.to_dict(), 201


@auth_ns.route('/login')
class LoginResource(Resource):
    """用户登录"""

    @auth_ns.doc('user_login')
    @auth_ns.expect(login_schema)
    @auth_ns.response(200, '登录成功')
    @auth_ns.response(401, '用户名或密码错误')
    def post(self):
        """用户登录"""
        data = api.payload

        # 查找用户
        user = User.query.filter_by(username=data['username']).first()

        if not user or not user.check_password(data['password']):
            return {'message': '用户名或密码错误'}, 401

        # 生成 JWT token
        from flask_jwt_extended import create_access_token, create_refresh_token

        access_token = create_access_token(identity={
            'id': user.id,
            'username': user.username
        })

        refresh_token = create_refresh_token(identity={
            'id': user.id,
            'username': user.username
        })

        return {
            'access_token': access_token,
            'refresh_token': refresh_token,
            'token_type': 'Bearer',
            'expires_in': 3600,
            'user': user.to_dict()
        }, 200


@auth_ns.route('/refresh')
class RefreshTokenResource(Resource):
    """刷新访问令牌"""

    @auth_ns.doc('refresh_token')
    @auth_ns.response(200, '令牌刷新成功')
    @auth_ns.response(401, '刷新令牌无效')
    def post(self):
        """刷新访问令牌"""
        from flask_jwt_extended import jwt_required, get_jwt_identity, create_access_token

        identity = get_jwt_identity()
        access_token = create_access_token(identity=identity)

        return {
            'access_token': access_token,
            'token_type': 'Bearer',
            'expires_in': 3600
        }, 200


@auth_ns.route('/me')
class CurrentUserResource(Resource):
    """当前用户信息"""

    @auth_ns.doc('get_current_user')
    @auth_ns.response(200, '成功获取用户信息')
    @auth_ns.response(401, '未授权访问')
    def get(self):
        """获取当前登录用户信息"""

        identity = get_jwt_identity()
        user = User.query.get(identity['id'])

        if not user:
            return {'message': '用户不存在'}, 404

        return user.to_dict(), 200


# ============================================
# 错误处理
# ============================================
@app.errorhandler(404)
def not_found(error):
    """处理 404 错误"""
    return jsonify({
        'error': 'Not Found',
        'message': '请求的资源不存在',
        'status_code': 404
    }), 404


@app.errorhandler(500)
def internal_error(error):
    """处理 500 错误"""
    return jsonify({
        'error': 'Internal Server Error',
        'message': '服务器内部错误',
        'status_code': 500
    }), 500


# ============================================
# 自定义 Swagger 文档
# ============================================

# @api.documentation
# def custom_ui():   # 自定义接口文档格式   /api/docs 访问
#     """自定义 Swagger UI 配置"""
#     return {
#         'apisSorter': 'alpha',
#         'operationsSorter': 'alpha',
#         'docExpansion': 'none',
#         'defaultModelsExpandDepth': 2,
#         'defaultModelExpandDepth': 2,
#         'showCommonExtensions': True,
#         'showExtensions': True
#     }


# ============================================
# 应用启动
# ============================================

if __name__ == '__main__':
    # 创建数据库表
    with app.app_context():
        db.create_all()

    # 运行应用
    app.run(
        host=os.getenv('FLASK_HOST', '0.0.0.0'),
        port=int(os.getenv('FLASK_PORT', 5000)),
        debug=os.getenv('FLASK_ENV') == 'development'
    )

页面效果

相关推荐
人工智能训练4 小时前
【极速部署】Ubuntu24.04+CUDA13.0 玩转 VLLM 0.15.0:预编译 Wheel 包 GPU 版安装全攻略
运维·前端·人工智能·python·ai编程·cuda·vllm
yaoming1685 小时前
python性能优化方案研究
python·性能优化
码云数智-大飞6 小时前
使用 Python 高效提取 PDF 中的表格数据并导出为 TXT 或 Excel
python
biuyyyxxx7 小时前
Python自动化办公学习笔记(一) 工具安装&教程
笔记·python·学习·自动化
极客数模7 小时前
【2026美赛赛题初步翻译F题】2026_ICM_Problem_F
大数据·c语言·python·数学建模·matlab
小鸡吃米…9 小时前
机器学习中的代价函数
人工智能·python·机器学习
Li emily10 小时前
如何通过外汇API平台快速实现实时数据接入?
开发语言·python·api·fastapi·美股
m0_5613596710 小时前
掌握Python魔法方法(Magic Methods)
jvm·数据库·python
Ulyanov10 小时前
顶层设计——单脉冲雷达仿真器的灵魂蓝图
python·算法·pyside·仿真系统·单脉冲
2401_8384725111 小时前
使用Python进行图像识别:CNN卷积神经网络实战
jvm·数据库·python