建立模型
-
python
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    """Declarative base class that the project's ORM models (e.g. ``User``) inherit from."""
-
以用户为例，建立继承基类 Base 的用户模型 User：
import datetime
from typing import Optional

from sqlalchemy import Integer, String, ForeignKey, DateTime, Boolean
from sqlalchemy.orm import mapped_column, Mapped

from src.db.model.base import Base


# User table
class User(Base):
    """ORM model for the ``user`` table.

    Every column passes ``comment=`` and the table passes a table-level
    comment via ``__table_args__``; these are emitted into the generated DDL.
    The ``Mapped[...]`` annotations describe the Python-side value type; the
    SQL type is always given explicitly to ``mapped_column``.
    """

    __tablename__ = 'user'
    # A plain dict is a valid form of __table_args__. The previous
    # ``({...})`` was only a parenthesised dict, not a one-element tuple.
    __table_args__ = {'comment': '用户表'}

    id: Mapped[int] = mapped_column(Integer, primary_key=True, comment='用户id')
    # Explicitly nullable, so the annotation is Optional. NOTE(review): a
    # UNIQUE column that allows NULL permits several NULL usernames in MySQL;
    # confirm NULL is really intended here.
    username: Mapped[Optional[str]] = mapped_column(String(64), unique=True, nullable=True, comment='用户名称')
    # Presumably stores a hash produced by the DAO rather than plain text —
    # String(64) is wide enough for a hex digest. Confirm with callers.
    password: Mapped[str] = mapped_column(String(64), comment='用户密码')
    name: Mapped[str] = mapped_column(String(32), comment='姓名')
    mobile_phone: Mapped[str] = mapped_column(String(32), comment='手机号')
    cloud_role_id: Mapped[int] = mapped_column(Integer, ForeignKey('cloud_role.id'), comment='平台角色id')
    user_group_id: Mapped[int] = mapped_column(Integer, ForeignKey('user_group.id'), comment='所在用户组id')
    # Was annotated Mapped[Boolean] (a SQL type object); the Python value is a bool.
    status: Mapped[bool] = mapped_column(Boolean, comment='状态')
    # NOTE(review): stored in a string column, although the DAO assigns a
    # datetime here. Consider DateTime in a coordinated schema migration.
    register_time: Mapped[str] = mapped_column(String(32), comment='注册时间')
    # A DateTime column yields datetime objects (callers call strftime on it), not str.
    last_login_time: Mapped[datetime.datetime] = mapped_column(DateTime, comment='最后登录时间')
comment 参数表示注释，生成的 MySQL 数据表（表本身及各字段）会带上这些注释。
-
建立异步引擎和 session 工厂，后面接口中使用的 async_session() 都从这里引入。
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker

# Shared async engine. DATABASE_URL is not defined in this snippet; it is
# presumably imported from the project's config — TODO confirm.
# pool_pre_ping checks a pooled connection before handing it out, and
# pool_recycle=3600 replaces connections older than one hour, so connections
# the server may have closed are not reused.
# (``future=True`` was dropped: in SQLAlchemy 2.0 it is the only supported mode.)
engine = create_async_engine(DATABASE_URL, pool_pre_ping=True, pool_recycle=3600)

# Factory for AsyncSession objects. DAOs use ``async with async_session() as session``.
# The original called ``sessionmaker`` without importing it; async_sessionmaker
# is the 2.0 factory for AsyncSession.
# expire_on_commit=False keeps loaded attributes readable after commit,
# avoiding implicit refresh I/O that asyncio sessions cannot perform lazily.
async_session = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
编写异步接口
-
添加用户
class UserDao:
    """Async data-access methods for the ``User`` model."""

    @classmethod
    async def add_user(cls, username: str, password: str, name: str, mobile_phone: Optional[str],
                       cloud_role_id: int, user_group_id: int, status: int) -> bool:
        """Insert a new user and commit it.

        Args:
            username: Login name.
            password: Plain-text password; only its MD5 hex digest is stored.
            name: Display name.
            mobile_phone: Phone number; ``None`` is stored as ``''``.
            cloud_role_id: Value for ``User.cloud_role_id``.
            user_group_id: Value for ``User.user_group_id``.
            status: Account status value for ``User.status``.

        Returns:
            True once the insert has been committed. False if any exception
            occurred (including constraint violations raised at commit); the
            error is logged.
        """
        if mobile_phone is None:
            mobile_phone = ''
        # Hash the password before storing it.
        # SECURITY NOTE(review): unsalted MD5 is not a safe password hash. It
        # is kept because existing rows and login checks presumably compare
        # MD5 digests; migrate to a salted KDF in a coordinated change.
        md_password = hashlib.md5(password.encode('utf-8')).hexdigest()
        now = datetime.datetime.now()
        user = User(username=username, password=md_password, name=name, mobile_phone=mobile_phone,
                    cloud_role_id=cloud_role_id, user_group_id=user_group_id, status=status,
                    register_time=now, last_login_time=now)
        # The try must wrap the whole transaction. session.begin() commits
        # (and flushes) on normal exit and rolls back when an exception
        # propagates out of it. Catching inside the block made it try to
        # commit an already-failed transaction, and that error escaped
        # uncaught.
        try:
            async with async_session() as session:
                async with session.begin():
                    session.add(user)
        except Exception as e:
            log.error(e)
            return False
        return True
-
删除用户
@classmethod
async def delete_user(cls, user_id: int) -> bool:
    """Delete a user together with its ``StationSiteUsers`` link rows.

    Args:
        user_id: Primary key of the user to delete.

    Returns:
        True if the user existed and both deletes were committed. False if no
        such user exists, or if any exception occurred (logged).
    """
    # Wrap the whole transaction so errors raised by the commit in
    # session.begin()'s exit are caught as well.
    try:
        async with async_session() as session:
            async with session.begin():
                # AsyncSession has no legacy query(); select() + scalar() returns the row or None.
                user = await session.scalar(select(User).filter(User.id == user_id))
                if user is None:
                    return False
                # Delete the site-user link rows first (they presumably
                # reference user.id), then the user row itself.
                await session.execute(delete(StationSiteUsers).filter(StationSiteUsers.user_id == user_id))
                await session.execute(delete(User).filter(User.id == user_id))
    except Exception as e:
        log.error(e)
        return False
    return True
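注意：AsyncSession 不提供旧式的 query() 接口，异步写法需要用 select() 构造语句，再用 await session.execute(stmt) 执行。查询单个对象时，对结果调用 scalar()（返回第一行第一列，无结果时为 None），也可以直接写 await session.scalar(stmt)。
一般来说，select() 搭配 scalar() 是 2.0 / 异步风格的写法；query() 搭配 first() 是同步 Session 的旧式（1.x）写法。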
-
编辑用户
@classmethod
async def edit_user(cls, user_id: int, username: str, name: str, mobile_phone: Optional[str],
                    cloud_role_id: int, user_group_id: int, status: int) -> bool:
    """Update the editable fields of an existing user.

    Args:
        user_id: Primary key of the user to edit.
        username, name, cloud_role_id, user_group_id, status: New values.
        mobile_phone: New phone number; ``None`` is stored as ``''``.

    Returns:
        True if the user existed and the changes were committed. False if no
        such user exists, or if any exception occurred (logged) — e.g. a
        constraint violation on ``username``.
    """
    if mobile_phone is None:
        mobile_phone = ''
    # Attribute changes are only flushed when session.begin() commits on
    # exit, so constraint errors surface there. The try must therefore
    # enclose the whole transaction, not just its body.
    try:
        async with async_session() as session:
            async with session.begin():
                # Look up the user.
                user = await session.scalar(select(User).filter(User.id == user_id))
                if user is None:
                    return False
                user.username = username
                user.name = name
                user.mobile_phone = mobile_phone
                user.cloud_role_id = cloud_role_id
                user.user_group_id = user_group_id
                user.status = status
    except Exception as e:
        log.error(e)
        return False
    return True
-
查询用户信息
# Fetch info for all users
@classmethod
async def get_all_user_info(cls) -> List[dict]:
    """Return a summary dict for every user, joined with its role and group.

    The inner joins mean users whose ``cloud_role_id`` or ``user_group_id``
    has no matching row are omitted.

    Returns:
        A list of dicts with keys ``user_id``, ``username``, ``name``,
        ``mobile_phone``, ``cloud_role_id``, ``status``, ``user_group_id``,
        ``user_group_name`` and ``last_login_time`` (formatted as
        ``'%Y-%m-%d %H:%M:%S'``, or None when unset). On any error the error
        is logged and ``[]`` is returned.
    """
    query_sql = (
        select(User.id, User.username, User.name, User.mobile_phone, User.cloud_role_id,
               User.status, User.user_group_id, UserGroup.group_name, User.last_login_time)
        .join(CloudRole, User.cloud_role_id == CloudRole.id)
        .join(UserGroup, User.user_group_id == UserGroup.id)
    )
    # The try encloses the whole transaction so errors raised while the
    # session.begin() context exits are also caught.
    try:
        async with async_session() as session:
            async with session.begin():
                rows = (await session.execute(query_sql)).all()
        user_info_list = []
        for (uid, username, name, phone, role_id, status,
             group_id, group_name, last_login) in rows:
            user_info_list.append({
                'user_id': uid,
                'username': username,
                'name': name,
                'mobile_phone': phone,
                'cloud_role_id': role_id,
                'status': status,
                'user_group_id': group_id,
                'user_group_name': group_name,
                # A NULL timestamp used to raise and discard the whole list.
                'last_login_time': (last_login.strftime('%Y-%m-%d %H:%M:%S')
                                    if last_login is not None else None),
            })
        return user_info_list
    except Exception as e:
        log.error(e)
        return []
使用pytest编写单元测试
-
测试user方法
import contextlib

import pytest

from src.db.async_controller.config import engine
from src.db.async_dao.user_dao import UserDao

user_dao = UserDao()

# Seed users as (username, password, name, cloud_role_id, user_group_id).
# Every seed user is created with mobile_phone="123456" and status=1.
_SEED_USERS = [
    ('superadmin', 'superadmin', 'superadmin', 1, 1),
    ('cdy', '888888', 'cdy', 2, 2),
    # Administrator accounts
    ('admin', '888888', 'admin', 1, 3),
    ('admin1', '888888', 'admin', 1, 3),
    ('admin2', '888888', 'admin', 1, 3),
    # Regular accounts
    ('hongdou1', '888888', 'test', 2, 4),
    ('hongdou2', '123456', 'test', 2, 4),
    ('hongdou3', '123456', 'test', 2, 4),
    ('hongdou4', '123456', 'test', 2, 4),
]


@contextlib.asynccontextmanager
async def _disposing_engine():
    """Await ``engine.dispose()`` when the wrapped test body finishes.

    AsyncEngine.dispose() is a coroutine. The previous sync teardown_method
    called it without awaiting, so the pool was never actually disposed.
    Disposing inside the test presumably also keeps pooled connections from
    outliving the event loop that created them.
    """
    try:
        yield
    finally:
        await engine.dispose()


class TestUserDao:
    # NOTE(review): these tests run against the real database behind `engine`,
    # mutate it, and depend on existing data (e.g. user_id=7). They print
    # results rather than asserting on them.

    @pytest.mark.asyncio
    async def test_add_user(self):
        async with _disposing_engine():
            for username, password, name, cloud_role_id, user_group_id in _SEED_USERS:
                await user_dao.add_user(username=username, password=password, name=name,
                                        mobile_phone="123456", cloud_role_id=cloud_role_id,
                                        user_group_id=user_group_id, status=1)

    @pytest.mark.asyncio
    async def test_delete_user(self):
        async with _disposing_engine():
            result = await user_dao.delete_user(user_id=7)
            print(result)

    @pytest.mark.asyncio
    async def test_get_accessible_user_list(self):
        async with _disposing_engine():
            print(await user_dao.get_accessible_user_list(user_id=1))