目录
- JWT认证与OAuth2集成:构建安全的现代API
-
- 引言
- [1. JWT认证基础](#1. JWT认证基础)
-
- [1.1 JWT结构解析](#1.1 JWT结构解析)
-
- [1.1.1 Header(头部)](#1.1.1 Header(头部))
- [1.1.2 Payload(负载)](#1.1.2 Payload(负载))
- [1.1.3 Signature(签名)](#1.1.3 Signature(签名))
- [1.2 JWT工作流程](#1.2 JWT工作流程)
- [2. 完整的JWT认证系统实现](#2. 完整的JWT认证系统实现)
-
- [2.1 项目结构](#2.1 项目结构)
- [2.2 配置管理](#2.2 配置管理)
- [2.3 密码安全处理](#2.3 密码安全处理)
- [2.4 JWT处理器](#2.4 JWT处理器)
- [2.5 数据模型](#2.5 数据模型)
- [2.6 Pydantic模型](#2.6 Pydantic模型)
- [2.7 认证依赖项](#2.7 认证依赖项)
- [3. OAuth2集成实现](#3. OAuth2集成实现)
-
- [3.1 OAuth2工作流程](#3.1 OAuth2工作流程)
- [3.2 OAuth2配置管理](#3.2 OAuth2配置管理)
- [3.3 OAuth数据库服务](#3.3 OAuth数据库服务)
- [4. API端点实现](#4. API端点实现)
-
- [4.1 认证端点](#4.1 认证端点)
- [4.2 用户管理端点](#4.2 用户管理端点)
- [5. 高级安全特性](#5. 高级安全特性)
-
- [5.1 双因素认证实现](#5.1 双因素认证实现)
- [5.2 会话管理](#5.2 会话管理)
- [5.3 安全中间件](#5.3 安全中间件)
- [6. 完整的用户服务实现](#6. 完整的用户服务实现)
- [7. 测试用例](#7. 测试用例)
- [8. 部署配置](#8. 部署配置)
-
- [8.1 环境变量配置](#8.1 环境变量配置)
- [8.2 Docker配置](#8.2 Docker配置)
- [9. 性能优化和安全建议](#9. 性能优化和安全建议)
-
- [9.1 性能优化](#9.1 性能优化)
- [9.2 安全建议](#9.2 安全建议)
- [10. 总结](#10. 总结)
『宝藏代码胶囊开张啦!』------ 我的 CodeCapsule 来咯!✨写代码不再头疼!我的新站点 CodeCapsule 主打一个 "白菜价"+"量身定制 "!无论是卡脖子的毕设/课设/文献复现 ,需要灵光一现的算法改进 ,还是想给项目加个"外挂",这里都有便宜又好用的代码方案等你发现!低成本,高适配,助你轻松通关!速来围观 👉 CodeCapsule官网
JWT认证与OAuth2集成:构建安全的现代API
引言
在现代Web应用和API开发中,身份认证和授权是确保系统安全的核心组件。JSON Web Token(JWT)和OAuth2.0已经成为实现这些安全机制的标准协议。本文将深入探讨JWT的工作原理、OAuth2.0的授权流程,以及如何在FastAPI中实现完整的认证授权系统。
1. JWT认证基础
1.1 JWT结构解析
JWT由三部分组成,格式为:header.payload.signature
1.1.1 Header(头部)
json
{
"alg": "HS256", // 签名算法
"typ": "JWT" // 令牌类型
}
1.1.2 Payload(负载)
包含声明(claims),有三种类型:
- Registered claims:预定义的声明,如iss(签发者)、exp(过期时间)
- Public claims:公开的声明,可以自定义
- Private claims:私有声明,用于在双方之间共享信息
1.1.3 Signature(签名)
使用Base64Url编码的头部、负载和密钥计算签名:
Signature = HMAC256 ( base64UrlEncode ( h e a d e r ) + " . " + base64UrlEncode ( p a y l o a d ) , secret ) \text{Signature} = \text{HMAC256}(\text{base64UrlEncode}(header) + "." + \text{base64UrlEncode}(payload), \text{secret}) Signature=HMAC256(base64UrlEncode(header)+"."+base64UrlEncode(payload),secret)
1.2 JWT工作流程
Client API Server Database 1. 发送登录请求(username/password) 2. 验证用户凭证 3. 返回用户信息 4. 生成JWT令牌 5. 返回JWT令牌 6. 后续请求携带JWT(Authorization: Bearer <token>) 7. 验证JWT签名和过期时间 8. 返回请求的资源 Client API Server Database
2. 完整的JWT认证系统实现
2.1 项目结构
auth_system/
├── src/
│ ├── __init__.py
│ ├── main.py # FastAPI应用入口
│ ├── config.py # 配置文件
│ ├── database.py # 数据库配置
│ ├── models.py # 数据模型
│ ├── schemas.py # Pydantic模型
│ ├── auth/
│ │ ├── __init__.py
│ │ ├── dependencies.py # 认证依赖项
│ │ ├── jwt_handler.py # JWT处理器
│ │ ├── password.py # 密码处理
│ │ └── security.py # 安全配置
│ ├── api/
│ │ ├── __init__.py
│ │ ├── v1/
│ │ │ ├── __init__.py
│ │ │ ├── endpoints/
│ │ │ │ ├── auth.py
│ │ │ │ ├── users.py
│ │ │ │ └── admin.py
│ │ │ └── routers.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── exceptions.py # 自定义异常
│ │ └── security.py # 安全工具
│ └── utils/
│ ├── __init__.py
│ └── email.py # 邮件工具
├── tests/
├── alembic/ # 数据库迁移
├── .env
└── requirements.txt
2.2 配置管理
python
# src/config.py
from pydantic import BaseSettings, Field
from typing import Optional, List
import secrets
class Settings(BaseSettings):
"""应用配置"""
# 应用信息
APP_NAME: str = "FastAPI Authentication System"
APP_VERSION: str = "1.0.0"
API_V1_STR: str = "/api/v1"
# 安全配置
SECRET_KEY: str = Field(
default_factory=lambda: secrets.token_urlsafe(32)
)
ALGORITHM: str = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
REFRESH_TOKEN_EXPIRE_DAYS: int = 7
# CORS配置
BACKEND_CORS_ORIGINS: List[str] = [
"http://localhost:3000",
"https://localhost:3000",
]
# 数据库配置
DATABASE_URL: str = "postgresql+asyncpg://user:password@localhost/dbname"
DATABASE_POOL_SIZE: int = 20
DATABASE_MAX_OVERFLOW: int = 10
# Redis配置(用于令牌黑名单/刷新令牌)
REDIS_URL: str = "redis://localhost:6379/0"
# 邮件配置
SMTP_HOST: Optional[str] = None
SMTP_PORT: Optional[int] = 587
SMTP_USER: Optional[str] = None
SMTP_PASSWORD: Optional[str] = None
EMAILS_FROM_EMAIL: Optional[str] = None
# 安全增强
PASSWORD_RESET_TOKEN_EXPIRE_HOURS: int = 24
ACCOUNT_VERIFICATION_TOKEN_EXPIRE_HOURS: int = 24
MAX_LOGIN_ATTEMPTS: int = 5
LOCKOUT_TIME_MINUTES: int = 15
# 速率限制
RATE_LIMIT_PER_MINUTE: int = 60
RATE_LIMIT_PER_HOUR: int = 1000
# 环境
ENVIRONMENT: str = "development"
DEBUG: bool = False
class Config:
env_file = ".env"
case_sensitive = True
settings = Settings()
2.3 密码安全处理
python
# src/auth/password.py
from passlib.context import CryptContext
from datetime import datetime, timedelta
import secrets
import hashlib
from typing import Tuple, Optional
# 密码哈希上下文
pwd_context = CryptContext(
schemes=["argon2", "bcrypt"], # 优先使用argon2
deprecated="auto",
argon2__time_cost=3, # 时间成本
argon2__memory_cost=65536, # 内存成本(64MB)
argon2__parallelism=4, # 并行度
bcrypt__rounds=12, # bcrypt轮数
)
class PasswordManager:
"""密码管理器"""
@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
@staticmethod
def get_password_hash(password: str) -> str:
"""获取密码哈希值"""
return pwd_context.hash(password)
@staticmethod
def validate_password_strength(password: str) -> Tuple[bool, str]:
"""
验证密码强度
返回: (是否有效, 错误消息)
"""
if len(password) < 8:
return False, "密码必须至少8个字符"
if len(password) > 128:
return False, "密码不能超过128个字符"
# 检查字符类型
has_upper = any(c.isupper() for c in password)
has_lower = any(c.islower() for c in password)
has_digit = any(c.isdigit() for c in password)
has_special = any(not c.isalnum() for c in password)
if not (has_upper and has_lower):
return False, "密码必须包含大小写字母"
if not has_digit:
return False, "密码必须包含数字"
if not has_special:
return False, "密码必须包含特殊字符"
# 检查常见弱密码
weak_passwords = {
"password", "123456", "qwerty", "admin", "welcome",
"password123", "123456789", "letmein", "monkey"
}
if password.lower() in weak_passwords:
return False, "密码太弱,请使用更强的密码"
return True, "密码强度足够"
@staticmethod
def generate_reset_token() -> str:
"""生成密码重置令牌"""
return secrets.token_urlsafe(32)
@staticmethod
def create_reset_token_hash(token: str) -> str:
"""创建重置令牌的哈希值(用于存储)"""
return hashlib.sha256(token.encode()).hexdigest()
@staticmethod
def generate_totp_secret() -> str:
"""生成TOTP密钥(用于双因素认证)"""
return secrets.token_hex(20) # 40个字符的十六进制字符串
# 密码策略常量
PASSWORD_POLICY = {
"min_length": 8,
"max_length": 128,
"require_upper": True,
"require_lower": True,
"require_digit": True,
"require_special": True,
"prevent_common": True,
}
2.4 JWT处理器
python
# src/auth/jwt_handler.py
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Union
import jwt
from jwt.exceptions import InvalidTokenError
import redis.asyncio as redis
from src.config import settings
from src.core.exceptions import AuthException
import uuid
class JWTManager:
"""JWT令牌管理器"""
def __init__(self):
self.secret_key = settings.SECRET_KEY
self.algorithm = settings.ALGORITHM
self.access_token_expire = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
self.refresh_token_expire = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
# Redis连接(用于令牌黑名单和刷新令牌)
self.redis_client = redis.from_url(
settings.REDIS_URL,
encoding="utf-8",
decode_responses=False,
)
def create_access_token(
self,
subject: Union[str, Dict],
token_type: str = "access",
additional_claims: Optional[Dict[str, Any]] = None
) -> str:
"""
创建访问令牌
Args:
subject: 主题(通常是用户ID)
token_type: 令牌类型(access/refresh)
additional_claims: 额外声明
Returns:
JWT令牌字符串
"""
now = datetime.utcnow()
if token_type == "access":
expire = now + self.access_token_expire
else:
expire = now + self.refresh_token_expire
# 基本声明
claims = {
"sub": str(subject),
"iat": now,
"exp": expire,
"type": token_type,
"jti": str(uuid.uuid4()), # JWT ID,用于撤销
}
# 添加额外声明
if additional_claims:
claims.update(additional_claims)
# 生成令牌
encoded_jwt = jwt.encode(
claims,
self.secret_key,
algorithm=self.algorithm
)
return encoded_jwt
def decode_token(self, token: str) -> Dict[str, Any]:
"""
解码并验证JWT令牌
Args:
token: JWT令牌字符串
Returns:
解码后的声明字典
Raises:
AuthException: 如果令牌无效
"""
try:
payload = jwt.decode(
token,
self.secret_key,
algorithms=[self.algorithm]
)
# 检查令牌类型
if payload.get("type") != "access":
raise AuthException("令牌类型不正确")
return payload
except jwt.ExpiredSignatureError:
raise AuthException("令牌已过期")
except jwt.InvalidTokenError as e:
raise AuthException(f"无效令牌: {str(e)}")
async def revoke_token(self, token: str) -> None:
"""
撤销令牌(加入黑名单)
Args:
token: 要撤销的令牌
"""
try:
payload = self.decode_token(token)
# 计算剩余过期时间
exp_timestamp = payload.get("exp")
if exp_timestamp:
exp_datetime = datetime.fromtimestamp(exp_timestamp)
ttl = (exp_datetime - datetime.utcnow()).total_seconds()
# 将令牌ID加入黑名单,有效期至令牌过期
jti = payload.get("jti")
if jti and ttl > 0:
await self.redis_client.setex(
f"blacklist:{jti}",
int(ttl),
"revoked"
)
except AuthException:
pass # 无效令牌无需加入黑名单
async def is_token_revoked(self, token: str) -> bool:
"""
检查令牌是否已被撤销
Args:
token: 要检查的令牌
Returns:
是否已被撤销
"""
try:
payload = self.decode_token(token)
jti = payload.get("jti")
if not jti:
return True
# 检查是否在黑名单中
result = await self.redis_client.exists(f"blacklist:{jti}")
return bool(result)
except AuthException:
return True
async def create_token_pair(self, user_id: str) -> Dict[str, str]:
"""
创建访问令牌和刷新令牌对
Args:
user_id: 用户ID
Returns:
包含access_token和refresh_token的字典
"""
# 创建访问令牌
access_token = self.create_access_token(
subject=user_id,
token_type="access",
additional_claims={
"scope": "read write",
}
)
# 创建刷新令牌
refresh_token = self.create_access_token(
subject=user_id,
token_type="refresh",
additional_claims={
"scope": "refresh",
}
)
# 存储刷新令牌(用于刷新时验证)
refresh_token_payload = jwt.decode(
refresh_token,
self.secret_key,
algorithms=[self.algorithm]
)
refresh_token_exp = refresh_token_payload.get("exp")
if refresh_token_exp:
ttl = refresh_token_exp - int(datetime.utcnow().timestamp())
if ttl > 0:
await self.redis_client.setex(
f"refresh_token:{user_id}",
ttl,
refresh_token_payload.get("jti", "")
)
return {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "bearer",
"expires_in": int(self.access_token_expire.total_seconds()),
}
async def refresh_access_token(self, refresh_token: str) -> Dict[str, str]:
"""
使用刷新令牌获取新的访问令牌
Args:
refresh_token: 刷新令牌
Returns:
新的访问令牌
Raises:
AuthException: 如果刷新令牌无效
"""
try:
# 解码刷新令牌
payload = jwt.decode(
refresh_token,
self.secret_key,
algorithms=[self.algorithm]
)
# 验证令牌类型
if payload.get("type") != "refresh":
raise AuthException("无效的刷新令牌")
user_id = payload.get("sub")
jti = payload.get("jti")
if not user_id or not jti:
raise AuthException("无效的刷新令牌")
# 验证刷新令牌是否有效
stored_jti = await self.redis_client.get(f"refresh_token:{user_id}")
if not stored_jti or stored_jti.decode() != jti:
raise AuthException("刷新令牌无效或已过期")
# 创建新的访问令牌
new_access_token = self.create_access_token(
subject=user_id,
token_type="access",
additional_claims={
"scope": "read write",
}
)
return {
"access_token": new_access_token,
"token_type": "bearer",
"expires_in": int(self.access_token_expire.total_seconds()),
}
except jwt.ExpiredSignatureError:
raise AuthException("刷新令牌已过期")
except jwt.InvalidTokenError as e:
raise AuthException(f"无效的刷新令牌: {str(e)}")
# 全局JWT管理器实例
jwt_manager = JWTManager()
2.5 数据模型
python
# src/models.py
from sqlalchemy import Boolean, Column, Integer, String, DateTime, Text, Enum
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
import enum
import uuid
Base = declarative_base()
class UserRole(str, enum.Enum):
"""用户角色枚举"""
USER = "user"
MODERATOR = "moderator"
ADMIN = "admin"
SUPER_ADMIN = "super_admin"
class UserStatus(str, enum.Enum):
"""用户状态枚举"""
ACTIVE = "active"
INACTIVE = "inactive"
SUSPENDED = "suspended"
DELETED = "deleted"
class User(Base):
"""用户模型"""
__tablename__ = "users"
id = Column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
index=True
)
email = Column(String(255), unique=True, index=True, nullable=False)
username = Column(String(50), unique=True, index=True, nullable=False)
full_name = Column(String(100))
# 密码相关
hashed_password = Column(String(255), nullable=False)
password_changed_at = Column(DateTime(timezone=True))
# 角色和状态
role = Column(Enum(UserRole), default=UserRole.USER, nullable=False)
status = Column(Enum(UserStatus), default=UserStatus.INACTIVE, nullable=False)
# 账户安全
is_email_verified = Column(Boolean, default=False)
is_2fa_enabled = Column(Boolean, default=False)
two_factor_secret = Column(String(100))
# 登录相关
last_login_at = Column(DateTime(timezone=True))
last_login_ip = Column(String(45)) # 支持IPv6
failed_login_attempts = Column(Integer, default=0)
locked_until = Column(DateTime(timezone=True))
# 元数据
created_at = Column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False
)
updated_at = Column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False
)
# 关系
sessions = relationship("UserSession", back_populates="user", cascade="all, delete-orphan")
oauth_accounts = relationship("OAuthAccount", back_populates="user", cascade="all, delete-orphan")
def __repr__(self):
return f"<User(id={self.id}, email={self.email}, username={self.username})>"
class UserSession(Base):
"""用户会话模型"""
__tablename__ = "user_sessions"
id = Column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
index=True
)
user_id = Column(UUID(as_uuid=True), index=True, nullable=False)
session_token = Column(String(255), unique=True, index=True, nullable=False)
# 设备信息
user_agent = Column(Text)
ip_address = Column(String(45))
device_type = Column(String(50))
device_name = Column(String(100))
browser = Column(String(50))
platform = Column(String(50))
# 会话状态
is_active = Column(Boolean, default=True)
last_activity_at = Column(DateTime(timezone=True), server_default=func.now())
# 过期时间
expires_at = Column(DateTime(timezone=True), nullable=False)
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
# 关系
user = relationship("User", back_populates="sessions")
class OAuthProvider(str, enum.Enum):
"""OAuth提供者枚举"""
GOOGLE = "google"
GITHUB = "github"
FACEBOOK = "facebook"
MICROSOFT = "microsoft"
APPLE = "apple"
class OAuthAccount(Base):
"""OAuth账户模型"""
__tablename__ = "oauth_accounts"
id = Column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
index=True
)
user_id = Column(
UUID(as_uuid=True),
index=True,
nullable=False
)
# OAuth提供者信息
provider = Column(Enum(OAuthProvider), nullable=False)
provider_user_id = Column(String(255), nullable=False)
provider_user_email = Column(String(255), nullable=False)
# OAuth令牌信息
access_token = Column(Text, nullable=False)
refresh_token = Column(Text)
access_token_expires_at = Column(DateTime(timezone=True))
# 提供者返回的额外数据
provider_data = Column(JSONB)
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now())
# 唯一约束
__table_args__ = (
# 每个用户在同一个提供者下只能有一个账户
(UniqueConstraint('user_id', 'provider', name='uq_user_provider')),
# 同一个提供者的用户ID必须唯一
(UniqueConstraint('provider', 'provider_user_id', name='uq_provider_user')),
)
# 关系
user = relationship("User", back_populates="oauth_accounts")
def __repr__(self):
return f"<OAuthAccount(provider={self.provider}, email={self.provider_user_email})>"
class PasswordResetToken(Base):
"""密码重置令牌模型"""
__tablename__ = "password_reset_tokens"
id = Column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
index=True
)
email = Column(String(255), nullable=False, index=True)
token_hash = Column(String(255), nullable=False, unique=True, index=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
is_used = Column(Boolean, default=False)
used_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
class EmailVerificationToken(Base):
"""邮箱验证令牌模型"""
__tablename__ = "email_verification_tokens"
id = Column(
UUID(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
index=True
)
user_id = Column(UUID(as_uuid=True), nullable=False, index=True)
token_hash = Column(String(255), nullable=False, unique=True, index=True)
expires_at = Column(DateTime(timezone=True), nullable=False)
is_used = Column(Boolean, default=False)
used_at = Column(DateTime(timezone=True))
created_at = Column(DateTime(timezone=True), server_default=func.now())
2.6 Pydantic模型
python
# src/schemas.py
from pydantic import BaseModel, EmailStr, Field, validator
from typing import Optional, List
from datetime import datetime
from enum import Enum
import re
class Token(BaseModel):
"""令牌响应模型"""
access_token: str
refresh_token: str
token_type: str = "bearer"
expires_in: int
class Config:
schema_extra = {
"example": {
"access_token": "eyJhbGciOiJIUzI1NiIs...",
"refresh_token": "eyJhbGciOiJIUzI1NiIs...",
"token_type": "bearer",
"expires_in": 1800
}
}
class TokenData(BaseModel):
"""令牌数据模型"""
sub: Optional[str] = None
exp: Optional[datetime] = None
type: Optional[str] = None
scope: Optional[str] = None
jti: Optional[str] = None
class UserBase(BaseModel):
"""用户基础模型"""
email: EmailStr
username: str = Field(..., min_length=3, max_length=50, regex="^[a-zA-Z0-9_]+$")
full_name: Optional[str] = Field(None, max_length=100)
class UserCreate(UserBase):
"""用户创建模型"""
password: str
@validator('password')
def validate_password(cls, v):
from src.auth.password import PasswordManager
is_valid, message = PasswordManager.validate_password_strength(v)
if not is_valid:
raise ValueError(message)
return v
@validator('username')
def validate_username(cls, v):
# 检查用户名是否包含敏感词
sensitive_words = ['admin', 'root', 'system', 'moderator']
if v.lower() in sensitive_words:
raise ValueError('用户名包含敏感词汇')
return v
class UserLogin(BaseModel):
"""用户登录模型"""
email: Optional[EmailStr] = None
username: Optional[str] = None
password: str
@validator('email', 'username')
def check_identifier(cls, v, values, **kwargs):
# 确保至少提供邮箱或用户名之一
if not v and not values.get('username') and not values.get('email'):
raise ValueError('必须提供邮箱或用户名')
return v
class UserInDB(UserBase):
"""数据库中的用户模型"""
id: str
role: str
status: str
is_email_verified: bool
is_2fa_enabled: bool
created_at: datetime
updated_at: datetime
class Config:
orm_mode = True
class UserResponse(BaseModel):
"""用户响应模型"""
id: str
email: EmailStr
username: str
full_name: Optional[str]
role: str
status: str
is_email_verified: bool
is_2fa_enabled: bool
created_at: datetime
class Config:
orm_mode = True
class PasswordChange(BaseModel):
"""密码修改模型"""
current_password: str
new_password: str
@validator('new_password')
def validate_new_password(cls, v):
from src.auth.password import PasswordManager
is_valid, message = PasswordManager.validate_password_strength(v)
if not is_valid:
raise ValueError(message)
return v
class PasswordResetRequest(BaseModel):
"""密码重置请求模型"""
email: EmailStr
class PasswordResetConfirm(BaseModel):
"""密码重置确认模型"""
token: str
new_password: str
@validator('new_password')
def validate_new_password(cls, v):
from src.auth.password import PasswordManager
is_valid, message = PasswordManager.validate_password_strength(v)
if not is_valid:
raise ValueError(message)
return v
class TwoFactorEnable(BaseModel):
"""双因素认证启用模型"""
code: str = Field(..., min_length=6, max_length=6, regex="^[0-9]+$")
class TwoFactorVerify(BaseModel):
"""双因素认证验证模型"""
code: str = Field(..., min_length=6, max_length=6, regex="^[0-9]+$")
class OAuthLoginRequest(BaseModel):
"""OAuth登录请求模型"""
provider: str
redirect_uri: str
class OAuthCallback(BaseModel):
"""OAuth回调模型"""
code: str
state: str
class UserSessionResponse(BaseModel):
"""用户会话响应模型"""
id: str
user_agent: Optional[str]
ip_address: Optional[str]
device_type: Optional[str]
device_name: Optional[str]
browser: Optional[str]
platform: Optional[str]
last_activity_at: datetime
created_at: datetime
class Config:
orm_mode = True
2.7 认证依赖项
python
# src/auth/dependencies.py
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Optional, List
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from src.database import get_db
from src.auth.jwt_handler import jwt_manager
from src.models import User, UserRole, UserStatus
from src.schemas import TokenData
security = HTTPBearer(
scheme_name="JWT",
description="JWT Bearer令牌",
auto_error=False
)
class PermissionChecker:
"""权限检查器"""
def __init__(self, required_roles: List[UserRole] = None):
self.required_roles = required_roles or []
async def __call__(self, current_user: User = Depends(get_current_user)):
"""检查用户是否有足够权限"""
if self.required_roles:
user_role = UserRole(current_user.role)
if user_role not in self.required_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="权限不足"
)
return current_user
async def get_current_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
db: AsyncSession = Depends(get_db)
) -> User:
"""
获取当前认证用户
Args:
credentials: HTTP授权凭证
db: 数据库会话
Returns:
当前用户
Raises:
HTTPException: 如果认证失败
"""
if not credentials:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="需要认证",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
try:
# 验证令牌是否已被撤销
if await jwt_manager.is_token_revoked(token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="令牌已失效",
headers={"WWW-Authenticate": "Bearer"},
)
# 解码令牌
payload = jwt_manager.decode_token(token)
token_data = TokenData(**payload)
if token_data.sub is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭证",
headers={"WWW-Authenticate": "Bearer"},
)
# 获取用户
user_id = token_data.sub
result = await db.execute(
select(User).where(User.id == user_id)
)
user = result.scalar_one_or_none()
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在",
headers={"WWW-Authenticate": "Bearer"},
)
# 检查用户状态
if user.status == UserStatus.SUSPENDED:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账户已被封禁"
)
if user.status == UserStatus.DELETED:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="账户已删除"
)
return user
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"认证失败: {str(e)}",
headers={"WWW-Authenticate": "Bearer"},
)
async def get_current_active_user(
current_user: User = Depends(get_current_user)
) -> User:
"""获取当前活跃用户"""
if current_user.status != UserStatus.ACTIVE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="账户未激活"
)
return current_user
async def get_current_superuser(
current_user: User = Depends(get_current_user)
) -> User:
"""获取当前超级管理员用户"""
if current_user.role != UserRole.SUPER_ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要超级管理员权限"
)
return current_user
# 权限依赖项
require_admin = PermissionChecker(required_roles=[UserRole.ADMIN, UserRole.SUPER_ADMIN])
require_moderator = PermissionChecker(required_roles=[UserRole.MODERATOR, UserRole.ADMIN, UserRole.SUPER_ADMIN])
require_super_admin = PermissionChecker(required_roles=[UserRole.SUPER_ADMIN])
# 可选认证依赖项
async def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
db: AsyncSession = Depends(get_db)
) -> Optional[User]:
"""
可选的当前用户获取(如果没有认证,返回None)
"""
if not credentials:
return None
try:
user = await get_current_user(credentials, db)
return user
except HTTPException:
return None
3. OAuth2集成实现
3.1 OAuth2工作流程
- 发起OAuth登录 2. 重定向到授权页面 3. 用户授权 4. 返回授权码 5. 用授权码交换令牌 6. 返回访问令牌 7. 获取用户信息 8. 返回用户信息 9. 创建/获取本地用户 10. 生成JWT令牌 客户端 FastAPI应用 OAuth提供者
如Google/GitHub 用户 数据库
3.2 OAuth2配置管理
python
# src/auth/oauth.py
from typing import Dict, Optional, Any
from authlib.integrations.starlette_client import OAuth
from authlib.oauth2.rfc6749.wrappers import OAuth2Token
from starlette.config import Config
from src.config import settings
from src.models import OAuthProvider
import secrets
class OAuthConfig:
"""OAuth配置管理器"""
def __init__(self):
self.oauth = OAuth()
# 加载配置
config_dict = {
# Google配置
"GOOGLE_CLIENT_ID": getattr(settings, "GOOGLE_CLIENT_ID", ""),
"GOOGLE_CLIENT_SECRET": getattr(settings, "GOOGLE_CLIENT_SECRET", ""),
# GitHub配置
"GITHUB_CLIENT_ID": getattr(settings, "GITHUB_CLIENT_ID", ""),
"GITHUB_CLIENT_SECRET": getattr(settings, "GITHUB_CLIENT_SECRET", ""),
# Facebook配置
"FACEBOOK_CLIENT_ID": getattr(settings, "FACEBOOK_CLIENT_ID", ""),
"FACEBOOK_CLIENT_SECRET": getattr(settings, "FACEBOOK_CLIENT_SECRET", ""),
# Microsoft配置
"MICROSOFT_CLIENT_ID": getattr(settings, "MICROSOFT_CLIENT_ID", ""),
"MICROSOFT_CLIENT_SECRET": getattr(settings, "MICROSOFT_CLIENT_SECRET", ""),
# Apple配置
"APPLE_CLIENT_ID": getattr(settings, "APPLE_CLIENT_ID", ""),
"APPLE_CLIENT_SECRET": getattr(settings, "APPLE_CLIENT_SECRET", ""),
"APPLE_KEY_ID": getattr(settings, "APPLE_KEY_ID", ""),
"APPLE_TEAM_ID": getattr(settings, "APPLE_TEAM_ID", ""),
"APPLE_PRIVATE_KEY": getattr(settings, "APPLE_PRIVATE_KEY", ""),
}
config = Config(environ=config_dict)
# 注册Google OAuth
if config_dict["GOOGLE_CLIENT_ID"]:
self.oauth.register(
name='google',
client_id=config('GOOGLE_CLIENT_ID'),
client_secret=config('GOOGLE_CLIENT_SECRET'),
server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
client_kwargs={
'scope': 'openid email profile',
'prompt': 'select_account',
}
)
# 注册GitHub OAuth
if config_dict["GITHUB_CLIENT_ID"]:
self.oauth.register(
name='github',
client_id=config('GITHUB_CLIENT_ID'),
client_secret=config('GITHUB_CLIENT_SECRET'),
access_token_url='https://github.com/login/oauth/access_token',
authorize_url='https://github.com/login/oauth/authorize',
api_base_url='https://api.github.com/',
client_kwargs={'scope': 'user:email'}
)
# 注册Facebook OAuth
if config_dict["FACEBOOK_CLIENT_ID"]:
self.oauth.register(
name='facebook',
client_id=config('FACEBOOK_CLIENT_ID'),
client_secret=config('FACEBOOK_CLIENT_SECRET'),
access_token_url='https://graph.facebook.com/oauth/access_token',
authorize_url='https://www.facebook.com/dialog/oauth',
api_base_url='https://graph.facebook.com/',
client_kwargs={'scope': 'email'}
)
# 注册Microsoft OAuth
if config_dict["MICROSOFT_CLIENT_ID"]:
self.oauth.register(
name='microsoft',
client_id=config('MICROSOFT_CLIENT_ID'),
client_secret=config('MICROSOFT_CLIENT_SECRET'),
server_metadata_url='https://login.microsoftonline.com/common/v2.0/.well-known/openid-configuration',
client_kwargs={
'scope': 'openid email profile',
}
)
def get_client(self, provider: str):
"""获取OAuth客户端"""
return self.oauth.create_client(provider)
def get_provider_scopes(self, provider: str) -> Dict[str, str]:
"""获取提供者的scope配置"""
scopes = {
'google': 'openid email profile',
'github': 'user:email',
'facebook': 'email',
'microsoft': 'openid email profile',
'apple': 'email name',
}
return scopes.get(provider, '')
class OAuthService:
"""OAuth服务"""
def __init__(self, oauth_config: OAuthConfig):
self.oauth_config = oauth_config
self.state_store = {} # 临时存储state,生产环境应该使用Redis
async def create_authorization_url(
self,
provider: str,
redirect_uri: str,
state: Optional[str] = None
) -> Dict[str, str]:
"""
创建授权URL
Args:
provider: OAuth提供者
redirect_uri: 回调URL
state: 状态参数(防止CSRF攻击)
Returns:
包含授权URL和state的字典
"""
client = self.oauth_config.get_client(provider)
if not client:
raise ValueError(f"不支持的OAuth提供者: {provider}")
# 生成state(如果未提供)
if not state:
state = secrets.token_urlsafe(16)
# 存储state(生产环境应该使用Redis)
self.state_store[state] = {
'provider': provider,
'redirect_uri': redirect_uri,
'created_at': datetime.utcnow().isoformat()
}
# 创建授权URL
authorization_url = await client.create_authorization_url(
redirect_uri=redirect_uri,
state=state
)
return {
'authorization_url': authorization_url['url'],
'state': state
}
async def validate_state(self, state: str) -> bool:
"""验证state参数"""
if state not in self.state_store:
return False
state_data = self.state_store[state]
created_at = datetime.fromisoformat(state_data['created_at'])
# State应该在10分钟内使用
if datetime.utcnow() - created_at > timedelta(minutes=10):
del self.state_store[state]
return False
return True
async def get_access_token(
self,
provider: str,
code: str,
redirect_uri: str
) -> OAuth2Token:
"""
使用授权码获取访问令牌
Args:
provider: OAuth提供者
code: 授权码
redirect_uri: 回调URL
Returns:
OAuth2令牌
"""
client = self.oauth_config.get_client(provider)
if not client:
raise ValueError(f"不支持的OAuth提供者: {provider}")
# 交换令牌
token = await client.fetch_access_token(
code=code,
redirect_uri=redirect_uri
)
return token
async def get_user_info(self, provider: str, token: OAuth2Token) -> Dict[str, Any]:
"""
获取用户信息
Args:
provider: OAuth提供者
token: OAuth2令牌
Returns:
用户信息字典
"""
client = self.oauth_config.get_client(provider)
if not client:
raise ValueError(f"不支持的OAuth提供者: {provider}")
# 设置令牌
client.token = token
# 根据提供者获取用户信息
if provider == 'google':
userinfo_endpoint = 'https://openidconnect.googleapis.com/v1/userinfo'
resp = await client.get(userinfo_endpoint)
user_info = resp.json()
return {
'id': user_info.get('sub'),
'email': user_info.get('email'),
'verified_email': user_info.get('email_verified', False),
'name': user_info.get('name'),
'given_name': user_info.get('given_name'),
'family_name': user_info.get('family_name'),
'picture': user_info.get('picture'),
'locale': user_info.get('locale'),
}
elif provider == 'github':
# GitHub需要额外请求获取邮箱
resp = await client.get('user')
user_info = resp.json()
# 获取主要邮箱
emails_resp = await client.get('user/emails')
emails = emails_resp.json()
primary_email = None
verified_email = None
for email in emails:
if email.get('primary'):
primary_email = email.get('email')
if email.get('verified') and not primary_email:
verified_email = email.get('email')
email = primary_email or verified_email or user_info.get('email')
return {
'id': str(user_info.get('id')),
'email': email,
'verified_email': bool(email),
'name': user_info.get('name'),
'login': user_info.get('login'),
'avatar_url': user_info.get('avatar_url'),
'bio': user_info.get('bio'),
'location': user_info.get('location'),
}
elif provider == 'facebook':
# Facebook需要指定fields
resp = await client.get('me', params={
'fields': 'id,name,email,picture'
})
user_info = resp.json()
return {
'id': user_info.get('id'),
'email': user_info.get('email'),
'verified_email': True, # Facebook邮箱默认已验证
'name': user_info.get('name'),
'picture': user_info.get('picture', {}).get('data', {}).get('url'),
}
elif provider == 'microsoft':
resp = await client.get('https://graph.microsoft.com/v1.0/me')
user_info = resp.json()
return {
'id': user_info.get('id'),
'email': user_info.get('mail') or user_info.get('userPrincipalName'),
'verified_email': True, # Microsoft邮箱默认已验证
'name': user_info.get('displayName'),
'given_name': user_info.get('givenName'),
'family_name': user_info.get('surname'),
}
else:
raise ValueError(f"不支持的提供者: {provider}")
async def cleanup_state(self, state: str):
"""清理state"""
if state in self.state_store:
del self.state_store[state]
# 全局OAuth服务实例
oauth_config = OAuthConfig()
oauth_service = OAuthService(oauth_config)
3.3 OAuth数据库服务
python
# src/services/oauth_service.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
from src.models import User, OAuthAccount, OAuthProvider, UserStatus
from src.auth.password import PasswordManager
from src.auth.jwt_handler import jwt_manager
from src.core.exceptions import AuthException
class OAuthUserService:
"""OAuth用户服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_or_create_user_from_oauth(
self,
provider: OAuthProvider,
provider_user_id: str,
provider_user_email: str,
oauth_token: Dict[str, Any],
user_info: Dict[str, Any]
) -> User:
"""
从OAuth信息获取或创建用户
Args:
provider: OAuth提供者
provider_user_id: 提供者用户ID
provider_user_email: 提供者用户邮箱
oauth_token: OAuth令牌信息
user_info: 用户信息
Returns:
用户对象
"""
# 1. 查找现有的OAuth账户
oauth_account = await self.find_oauth_account(provider, provider_user_id)
if oauth_account:
# 更新OAuth令牌信息
await self.update_oauth_account(oauth_account, oauth_token)
return oauth_account.user
# 2. 查找现有用户(通过邮箱)
user = await self.find_user_by_email(provider_user_email)
if user:
# 为用户创建新的OAuth账户
await self.create_oauth_account(
user=user,
provider=provider,
provider_user_id=provider_user_id,
provider_user_email=provider_user_email,
oauth_token=oauth_token,
user_info=user_info
)
return user
# 3. 创建新用户
user = await self.create_user_from_oauth(
provider=provider,
provider_user_id=provider_user_id,
provider_user_email=provider_user_email,
oauth_token=oauth_token,
user_info=user_info
)
return user
async def find_oauth_account(
self,
provider: OAuthProvider,
provider_user_id: str
) -> Optional[OAuthAccount]:
"""查找OAuth账户"""
result = await self.db.execute(
select(OAuthAccount).where(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == provider_user_id
)
)
return result.scalar_one_or_none()
async def find_user_by_email(self, email: str) -> Optional[User]:
"""通过邮箱查找用户"""
result = await self.db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
async def update_oauth_account(
self,
oauth_account: OAuthAccount,
oauth_token: Dict[str, Any]
):
"""更新OAuth账户信息"""
oauth_account.access_token = oauth_token.get('access_token')
oauth_account.refresh_token = oauth_token.get('refresh_token')
# 更新访问令牌过期时间
if expires_at := oauth_token.get('expires_at'):
oauth_account.access_token_expires_at = expires_at
oauth_account.updated_at = datetime.utcnow()
await self.db.commit()
async def create_oauth_account(
self,
user: User,
provider: OAuthProvider,
provider_user_id: str,
provider_user_email: str,
oauth_token: Dict[str, Any],
user_info: Dict[str, Any]
):
"""创建OAuth账户"""
oauth_account = OAuthAccount(
user_id=user.id,
provider=provider,
provider_user_id=provider_user_id,
provider_user_email=provider_user_email,
access_token=oauth_token.get('access_token'),
refresh_token=oauth_token.get('refresh_token'),
access_token_expires_at=oauth_token.get('expires_at'),
provider_data=user_info,
)
self.db.add(oauth_account)
await self.db.commit()
async def create_user_from_oauth(
self,
provider: OAuthProvider,
provider_user_id: str,
provider_user_email: str,
oauth_token: Dict[str, Any],
user_info: Dict[str, Any]
) -> User:
"""从OAuth信息创建用户"""
# 生成用户名(如果没有提供)
username = self.generate_username(user_info)
# 生成随机密码(用户可以通过邮箱重置)
random_password = PasswordManager.get_password_hash(
secrets.token_urlsafe(32)
)
# 创建用户
user = User(
email=provider_user_email,
username=username,
full_name=user_info.get('name'),
hashed_password=random_password,
is_email_verified=user_info.get('verified_email', True),
status=UserStatus.ACTIVE,
)
self.db.add(user)
await self.db.commit()
await self.db.refresh(user)
# 创建OAuth账户
await self.create_oauth_account(
user=user,
provider=provider,
provider_user_id=provider_user_id,
provider_user_email=provider_user_email,
oauth_token=oauth_token,
user_info=user_info
)
return user
def generate_username(self, user_info: Dict[str, Any]) -> str:
"""生成用户名"""
# 尝试从用户信息获取用户名
username = None
if 'login' in user_info: # GitHub
username = user_info['login']
elif 'given_name' in user_info: # Google/Microsoft
given_name = user_info.get('given_name', '').lower()
family_name = user_info.get('family_name', '').lower()
if given_name and family_name:
username = f"{given_name}.{family_name}"
elif given_name:
username = given_name
# 如果没有合适的用户名,生成一个
if not username:
email = user_info.get('email', 'user')
base_username = email.split('@')[0]
username = base_username
# 确保用户名合法
username = re.sub(r'[^a-zA-Z0-9_]', '', username)
# 添加随机后缀以防重复
if len(username) < 3:
username += str(random.randint(100, 999))
else:
username += str(random.randint(10, 99))
return username[:50] # 限制长度
4. API端点实现
4.1 认证端点
python
# src/api/v1/endpoints/auth.py
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any
import pyotp
import qrcode
import io
import base64
from src.database import get_db
from src.auth.dependencies import get_current_user
from src.auth.jwt_handler import jwt_manager
from src.auth.password import PasswordManager
from src.schemas import (
UserCreate, UserLogin, Token,
PasswordChange, PasswordResetRequest,
PasswordResetConfirm, TwoFactorEnable,
TwoFactorVerify, OAuthLoginRequest,
OAuthCallback, UserResponse
)
from src.models import User, UserSession, UserStatus
from src.services.user_service import UserService
from src.services.oauth_service import OAuthUserService
from src.auth.oauth import oauth_service
from src.utils.email import send_password_reset_email, send_verification_email
router = APIRouter()
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(
user_in: UserCreate,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db)
):
"""
用户注册
"""
user_service = UserService(db)
# 检查邮箱是否已存在
if await user_service.get_user_by_email(user_in.email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="邮箱已被注册"
)
# 检查用户名是否已存在
if await user_service.get_user_by_username(user_in.username):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="用户名已被使用"
)
# 创建用户
user = await user_service.create_user(user_in)
# 发送验证邮件
if settings.EMAILS_ENABLED:
background_tasks.add_task(
send_verification_email,
email_to=user.email,
username=user.username,
user_id=user.id
)
return user
@router.post("/login", response_model=Token)
async def login(
form_data: OAuth2PasswordRequestForm = Depends(),
db: AsyncSession = Depends(get_db)
):
"""
用户登录(使用用户名/邮箱和密码)
"""
user_service = UserService(db)
# 获取用户
user = await user_service.authenticate_user(
identifier=form_data.username, # username字段实际可以是邮箱或用户名
password=form_data.password
)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 检查账户状态
if user.status != UserStatus.ACTIVE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="账户未激活"
)
# 检查是否被封禁
if user.locked_until and user.locked_until > datetime.utcnow():
raise HTTPException(
status_code=status.HTTP_423_LOCKED,
detail="账户已被锁定,请稍后重试"
)
# 如果需要双因素认证
if user.is_2fa_enabled:
# 这里应该返回需要2FA验证的响应
return {
"requires_2fa": True,
"message": "需要双因素认证"
}
# 更新登录信息
await user_service.update_login_info(
user_id=user.id,
login_success=True,
ip_address=None # 实际应该从请求中获取
)
# 创建会话
session = await user_service.create_session(
user_id=user.id,
user_agent="", # 从请求头获取
ip_address=None
)
# 生成令牌
token_data = await jwt_manager.create_token_pair(str(user.id))
return token_data
@router.post("/login/2fa", response_model=Token)
async def verify_2fa(
verify_data: TwoFactorVerify,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
验证双因素认证
"""
user_service = UserService(db)
# 验证TOTP代码
if not user_service.verify_totp_code(
secret=current_user.two_factor_secret,
code=verify_data.code
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="验证码错误"
)
# 生成令牌
token_data = await jwt_manager.create_token_pair(str(current_user.id))
return token_data
@router.post("/refresh", response_model=Token)
async def refresh_token(
refresh_token: str,
db: AsyncSession = Depends(get_db)
):
"""
刷新访问令牌
"""
try:
token_data = await jwt_manager.refresh_access_token(refresh_token)
return token_data
except Exception as e:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=str(e)
)
@router.post("/logout")
async def logout(
current_user: User = Depends(get_current_user),
token: str = Depends(get_current_user_token)
):
"""
用户登出
"""
await jwt_manager.revoke_token(token)
return {"message": "登出成功"}
@router.post("/password/change")
async def change_password(
password_data: PasswordChange,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
修改密码
"""
user_service = UserService(db)
# 验证当前密码
if not PasswordManager.verify_password(
password_data.current_password,
current_user.hashed_password
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="当前密码错误"
)
# 更新密码
await user_service.update_password(
user_id=current_user.id,
new_password=password_data.new_password
)
# 使所有会话失效
await user_service.invalidate_all_sessions(current_user.id)
return {"message": "密码修改成功"}
@router.post("/password/reset/request")
async def request_password_reset(
reset_request: PasswordResetRequest,
background_tasks: BackgroundTasks,
db: AsyncSession = Depends(get_db)
):
"""
请求密码重置
"""
user_service = UserService(db)
# 获取用户
user = await user_service.get_user_by_email(reset_request.email)
if not user:
# 为了防止枚举攻击,即使用户不存在也返回成功
return {"message": "如果邮箱存在,重置链接已发送"}
# 生成重置令牌
reset_token = await user_service.create_password_reset_token(user.email)
# 发送重置邮件
if settings.EMAILS_ENABLED:
background_tasks.add_task(
send_password_reset_email,
email_to=user.email,
username=user.username,
token=reset_token
)
return {"message": "如果邮箱存在,重置链接已发送"}
@router.post("/password/reset/confirm")
async def confirm_password_reset(
reset_data: PasswordResetConfirm,
db: AsyncSession = Depends(get_db)
):
"""
确认密码重置
"""
user_service = UserService(db)
# 验证重置令牌
user = await user_service.verify_password_reset_token(
token=reset_data.token
)
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效或过期的重置令牌"
)
# 更新密码
await user_service.update_password(
user_id=user.id,
new_password=reset_data.new_password
)
# 使所有会话失效
await user_service.invalidate_all_sessions(user.id)
return {"message": "密码重置成功"}
@router.post("/2fa/enable", response_model=Dict[str, Any])
async def enable_2fa(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
启用双因素认证
"""
user_service = UserService(db)
# 生成TOTP密钥
totp_secret = PasswordManager.generate_totp_secret()
# 生成二维码
totp = pyotp.TOTP(totp_secret)
provisioning_uri = totp.provisioning_uri(
name=current_user.email,
issuer_name=settings.APP_NAME
)
# 生成二维码图片
qr = qrcode.QRCode(
version=1,
error_correction=qrcode.constants.ERROR_CORRECT_L,
box_size=10,
border=4,
)
qr.add_data(provisioning_uri)
qr.make(fit=True)
img = qr.make_image(fill_color="black", back_color="white")
# 转换为base64
buffered = io.BytesIO()
img.save(buffered, format="PNG")
qr_code_base64 = base64.b64encode(buffered.getvalue()).decode()
# 暂时保存密钥(等待验证)
await user_service.set_temporary_2fa_secret(
user_id=current_user.id,
secret=totp_secret
)
return {
"secret": totp_secret,
"provisioning_uri": provisioning_uri,
"qr_code": f"data:image/png;base64,{qr_code_base64}",
"message": "扫描二维码并输入验证码以完成设置"
}
@router.post("/2fa/confirm")
async def confirm_2fa(
verify_data: TwoFactorEnable,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
确认启用双因素认证
"""
user_service = UserService(db)
# 获取临时保存的密钥
temp_secret = await user_service.get_temporary_2fa_secret(current_user.id)
if not temp_secret:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="请先请求启用2FA"
)
# 验证TOTP代码
if not user_service.verify_totp_code(
secret=temp_secret,
code=verify_data.code
):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="验证码错误"
)
# 启用2FA
await user_service.enable_2fa(
user_id=current_user.id,
secret=temp_secret
)
return {"message": "双因素认证已启用"}
@router.post("/2fa/disable")
async def disable_2fa(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
禁用双因素认证
"""
user_service = UserService(db)
await user_service.disable_2fa(current_user.id)
return {"message": "双因素认证已禁用"}
@router.get("/oauth/{provider}/login")
async def oauth_login(
provider: str,
redirect_uri: str,
request: Request
):
"""
OAuth登录
"""
# 创建授权URL
auth_data = await oauth_service.create_authorization_url(
provider=provider,
redirect_uri=redirect_uri
)
return {
"authorization_url": auth_data["authorization_url"],
"state": auth_data["state"]
}
@router.get("/oauth/{provider}/callback")
async def oauth_callback(
provider: str,
code: str,
state: str,
db: AsyncSession = Depends(get_db)
):
"""
OAuth回调
"""
# 验证state
if not await oauth_service.validate_state(state):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的state参数"
)
try:
# 获取state数据
state_data = oauth_service.state_store.get(state)
if not state_data:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="无效的state参数"
)
redirect_uri = state_data['redirect_uri']
# 获取访问令牌
oauth_token = await oauth_service.get_access_token(
provider=provider,
code=code,
redirect_uri=redirect_uri
)
# 获取用户信息
user_info = await oauth_service.get_user_info(provider, oauth_token)
# 获取或创建用户
oauth_user_service = OAuthUserService(db)
user = await oauth_user_service.get_or_create_user_from_oauth(
provider=OAuthProvider(provider),
provider_user_id=user_info['id'],
provider_user_email=user_info['email'],
oauth_token=oauth_token,
user_info=user_info
)
# 生成JWT令牌
token_data = await jwt_manager.create_token_pair(str(user.id))
# 清理state
await oauth_service.cleanup_state(state)
return token_data
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"OAuth认证失败: {str(e)}"
)
@router.get("/me", response_model=UserResponse)
async def get_current_user_info(
current_user: User = Depends(get_current_user)
):
"""
获取当前用户信息
"""
return current_user
@router.get("/sessions")
async def get_user_sessions(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取用户的所有会话
"""
user_service = UserService(db)
sessions = await user_service.get_user_sessions(current_user.id)
return sessions
@router.post("/sessions/{session_id}/revoke")
async def revoke_session(
session_id: str,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
撤销特定会话
"""
user_service = UserService(db)
success = await user_service.revoke_session(
user_id=current_user.id,
session_id=session_id
)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="会话不存在"
)
return {"message": "会话已撤销"}
@router.post("/sessions/revoke-all")
async def revoke_all_sessions(
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
撤销所有会话(除了当前会话)
"""
user_service = UserService(db)
await user_service.invalidate_all_sessions(
user_id=current_user.id,
exclude_current=True
)
return {"message": "所有其他会话已撤销"}
4.2 用户管理端点
python
# src/api/v1/endpoints/users.py
from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from src.database import get_db
from src.auth.dependencies import (
get_current_user, require_admin,
require_super_admin, PermissionChecker
)
from src.models import User, UserRole, UserStatus
from src.schemas import UserResponse
from src.services.user_service import UserService
router = APIRouter()
@router.get("/", response_model=List[UserResponse])
async def list_users(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000),
role: Optional[UserRole] = None,
status: Optional[UserStatus] = None,
search: Optional[str] = None,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db)
):
"""
获取用户列表(需要管理员权限)
"""
user_service = UserService(db)
users = await user_service.get_users(
skip=skip,
limit=limit,
role=role,
status=status,
search=search
)
return users
@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
user_id: str,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db)
):
"""
获取用户信息(需要管理员权限)
"""
user_service = UserService(db)
user = await user_service.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return user
@router.patch("/{user_id}/role")
async def update_user_role(
user_id: str,
new_role: UserRole,
current_user: User = Depends(require_super_admin),
db: AsyncSession = Depends(get_db)
):
"""
更新用户角色(需要超级管理员权限)
"""
user_service = UserService(db)
# 不能修改自己的角色
if str(current_user.id) == user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能修改自己的角色"
)
success = await user_service.update_user_role(user_id, new_role)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return {"message": "用户角色更新成功"}
@router.patch("/{user_id}/status")
async def update_user_status(
user_id: str,
new_status: UserStatus,
current_user: User = Depends(require_admin),
db: AsyncSession = Depends(get_db)
):
"""
更新用户状态(需要管理员权限)
"""
user_service = UserService(db)
# 不能修改自己的状态
if str(current_user.id) == user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能修改自己的状态"
)
success = await user_service.update_user_status(user_id, new_status)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return {"message": "用户状态更新成功"}
@router.delete("/{user_id}")
async def delete_user(
user_id: str,
current_user: User = Depends(require_super_admin),
db: AsyncSession = Depends(get_db)
):
"""
删除用户(需要超级管理员权限)
"""
user_service = UserService(db)
# 不能删除自己
if str(current_user.id) == user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="不能删除自己"
)
success = await user_service.delete_user(user_id)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="用户不存在"
)
return {"message": "用户删除成功"}
5. 高级安全特性
5.1 双因素认证实现
python
# src/services/two_factor_service.py
import pyotp
import time
from typing import Optional
from datetime import datetime, timedelta
class TwoFactorService:
"""双因素认证服务"""
@staticmethod
def generate_secret() -> str:
"""生成TOTP密钥"""
return pyotp.random_base32()
@staticmethod
def generate_provisioning_uri(
secret: str,
email: str,
issuer: str
) -> str:
"""生成配置URI(用于二维码)"""
totp = pyotp.TOTP(secret)
return totp.provisioning_uri(
name=email,
issuer_name=issuer
)
@staticmethod
def verify_code(secret: str, code: str) -> bool:
"""验证TOTP代码"""
totp = pyotp.TOTP(secret)
# 允许时间偏差(前后30秒)
return totp.verify(code, valid_window=1)
@staticmethod
def generate_backup_codes(count: int = 10) -> list:
"""生成备份代码"""
import secrets
codes = []
for _ in range(count):
# 生成8位数字代码
code = ''.join(secrets.choice('0123456789') for _ in range(8))
codes.append(code)
return codes
@staticmethod
def hash_backup_code(code: str) -> str:
"""哈希备份代码(用于安全存储)"""
import hashlib
return hashlib.sha256(code.encode()).hexdigest()
class TwoFactorManager:
"""双因素认证管理器"""
def __init__(self, db):
self.db = db
self.two_factor_service = TwoFactorService()
async def enable_2fa(self, user_id: str, secret: str) -> bool:
"""启用双因素认证"""
from src.models import User
user = await self.db.get(User, user_id)
if not user:
return False
user.is_2fa_enabled = True
user.two_factor_secret = secret
user.two_factor_enabled_at = datetime.utcnow()
# 生成备份代码
backup_codes = self.two_factor_service.generate_backup_codes()
user.backup_codes = [
self.two_factor_service.hash_backup_code(code)
for code in backup_codes
]
await self.db.commit()
# 返回备份代码(只显示一次)
return backup_codes
async def disable_2fa(self, user_id: str) -> bool:
"""禁用双因素认证"""
from src.models import User
user = await self.db.get(User, user_id)
if not user:
return False
user.is_2fa_enabled = False
user.two_factor_secret = None
user.two_factor_enabled_at = None
user.backup_codes = []
await self.db.commit()
return True
async def verify_2fa(
self,
user_id: str,
code: str,
use_backup_code: bool = False
) -> bool:
"""验证双因素认证"""
from src.models import User
user = await self.db.get(User, user_id)
if not user or not user.is_2fa_enabled:
return False
if use_backup_code:
# 验证备份代码
code_hash = self.two_factor_service.hash_backup_code(code)
if code_hash in user.backup_codes:
# 移除已使用的备份代码
user.backup_codes.remove(code_hash)
await self.db.commit()
return True
return False
else:
# 验证TOTP代码
return self.two_factor_service.verify_code(
user.two_factor_secret,
code
)
async def regenerate_backup_codes(self, user_id: str) -> list:
"""重新生成备份代码"""
from src.models import User
user = await self.db.get(User, user_id)
if not user or not user.is_2fa_enabled:
return []
backup_codes = self.two_factor_service.generate_backup_codes()
user.backup_codes = [
self.two_factor_service.hash_backup_code(code)
for code in backup_codes
]
await self.db.commit()
return backup_codes
5.2 会话管理
python
# src/services/session_service.py
from datetime import datetime, timedelta
from typing import Optional, List
import secrets
import ua_parser.user_agent_parser as ua_parser
class SessionService:
"""会话管理服务"""
def __init__(self, db):
self.db = db
async def create_session(
self,
user_id: str,
user_agent: Optional[str] = None,
ip_address: Optional[str] = None,
expires_in_hours: int = 24 * 7 # 默认7天
) -> str:
"""创建用户会话"""
from src.models import UserSession
# 解析User-Agent
device_info = self.parse_user_agent(user_agent)
# 生成会话令牌
session_token = secrets.token_urlsafe(32)
# 创建会话
session = UserSession(
user_id=user_id,
session_token=session_token,
user_agent=user_agent,
ip_address=ip_address,
device_type=device_info.get('device_type'),
device_name=device_info.get('device_name'),
browser=device_info.get('browser'),
platform=device_info.get('platform'),
expires_at=datetime.utcnow() + timedelta(hours=expires_in_hours)
)
self.db.add(session)
await self.db.commit()
return session_token
def parse_user_agent(self, user_agent: Optional[str]) -> dict:
"""解析User-Agent字符串"""
if not user_agent:
return {}
try:
parsed = ua_parser.Parse(user_agent)
device_type = "desktop"
device_name = ""
# 确定设备类型
device_family = parsed['device']['family']
if device_family == 'Spider':
device_type = 'bot'
elif device_family != 'Other':
device_type = 'mobile'
device_name = device_family
return {
'device_type': device_type,
'device_name': device_name,
'browser': parsed['user_agent']['family'],
'platform': parsed['os']['family']
}
except:
return {}
async def validate_session(
self,
session_token: str,
update_last_activity: bool = True
) -> Optional[str]:
"""验证会话令牌,返回用户ID"""
from src.models import UserSession
result = await self.db.execute(
select(UserSession).where(
UserSession.session_token == session_token,
UserSession.is_active == True,
UserSession.expires_at > datetime.utcnow()
)
)
session = result.scalar_one_or_none()
if not session:
return None
# 更新最后活动时间
if update_last_activity:
session.last_activity_at = datetime.utcnow()
await self.db.commit()
return str(session.user_id)
async def revoke_session(
self,
user_id: str,
session_id: str
) -> bool:
"""撤销特定会话"""
from src.models import UserSession
result = await self.db.execute(
select(UserSession).where(
UserSession.id == session_id,
UserSession.user_id == user_id,
UserSession.is_active == True
)
)
session = result.scalar_one_or_none()
if not session:
return False
session.is_active = False
await self.db.commit()
return True
async def revoke_all_sessions(
self,
user_id: str,
exclude_current: Optional[str] = None
) -> int:
"""撤销用户的所有会话(可排除当前会话)"""
from src.models import UserSession
query = select(UserSession).where(
UserSession.user_id == user_id,
UserSession.is_active == True
)
if exclude_current:
query = query.where(UserSession.session_token != exclude_current)
result = await self.db.execute(query)
sessions = result.scalars().all()
count = 0
for session in sessions:
session.is_active = False
count += 1
await self.db.commit()
return count
async def get_user_sessions(
self,
user_id: str,
active_only: bool = True
) -> List[dict]:
"""获取用户的所有会话"""
from src.models import UserSession
query = select(UserSession).where(
UserSession.user_id == user_id
)
if active_only:
query = query.where(
UserSession.is_active == True,
UserSession.expires_at > datetime.utcnow()
)
result = await self.db.execute(query)
sessions = result.scalars().all()
return [
{
'id': str(session.id),
'user_agent': session.user_agent,
'ip_address': session.ip_address,
'device_type': session.device_type,
'device_name': session.device_name,
'browser': session.browser,
'platform': session.platform,
'is_active': session.is_active,
'last_activity_at': session.last_activity_at,
'created_at': session.created_at,
'expires_at': session.expires_at
}
for session in sessions
]
async def cleanup_expired_sessions(self) -> int:
"""清理过期会话"""
from src.models import UserSession
result = await self.db.execute(
select(UserSession).where(
UserSession.expires_at <= datetime.utcnow(),
UserSession.is_active == True
)
)
sessions = result.scalars().all()
count = 0
for session in sessions:
session.is_active = False
count += 1
await self.db.commit()
return count
5.3 安全中间件
python
# src/middleware/security.py
from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
import time
from typing import Dict, Tuple
import redis.asyncio as redis
class SecurityMiddleware(BaseHTTPMiddleware):
"""安全中间件"""
def __init__(
self,
app,
redis_client: redis.Redis,
rate_limit_per_minute: int = 60,
rate_limit_per_hour: int = 1000
):
super().__init__(app)
self.redis_client = redis_client
self.rate_limit_per_minute = rate_limit_per_minute
self.rate_limit_per_hour = rate_limit_per_hour
async def dispatch(self, request: Request, call_next):
# 安全检查
await self.check_rate_limit(request)
await self.check_suspicious_activity(request)
# 添加安全头
response = await call_next(request)
response = self.add_security_headers(response)
return response
async def check_rate_limit(self, request: Request):
"""检查速率限制"""
client_ip = request.client.host if request.client else "unknown"
path = request.url.path
# 每分钟限制
minute_key = f"rate_limit:{client_ip}:{path}:minute"
minute_count = await self.redis_client.get(minute_key)
if minute_count and int(minute_count) >= self.rate_limit_per_minute:
raise HTTPException(
status_code=429,
detail="请求过于频繁,请稍后再试",
headers={"Retry-After": "60"}
)
# 每小时限制
hour_key = f"rate_limit:{client_ip}:{path}:hour"
hour_count = await self.redis_client.get(hour_key)
if hour_count and int(hour_count) >= self.rate_limit_per_hour:
raise HTTPException(
status_code=429,
detail="每小时请求次数超限",
headers={"Retry-After": "3600"}
)
# 更新计数器
pipe = self.redis_client.pipeline()
pipe.incr(minute_key)
pipe.expire(minute_key, 60)
pipe.incr(hour_key)
pipe.expire(hour_key, 3600)
await pipe.execute()
async def check_suspicious_activity(self, request: Request):
"""检查可疑活动"""
client_ip = request.client.host if request.client else "unknown"
# 检查是否有可疑User-Agent
user_agent = request.headers.get("user-agent", "")
suspicious_agents = ["curl", "wget", "python-requests", "scan"]
if any(agent in user_agent.lower() for agent in suspicious_agents):
# 记录可疑活动
await self.redis_client.incr(f"suspicious:{client_ip}")
def add_security_headers(self, response: Response) -> Response:
"""添加安全头"""
security_headers = {
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains",
"Content-Security-Policy": "default-src 'self'",
"Referrer-Policy": "strict-origin-when-cross-origin",
"Permissions-Policy": "geolocation=(), microphone=(), camera=()",
}
for header, value in security_headers.items():
response.headers[header] = value
return response
class LoginAttemptTracker:
"""登录尝试跟踪器"""
def __init__(self, redis_client: redis.Redis):
self.redis_client = redis_client
async def record_failed_attempt(self, identifier: str) -> Tuple[int, bool]:
"""
记录失败登录尝试
Returns:
(失败次数, 是否应该锁定)
"""
key = f"login_failures:{identifier}"
# 获取当前失败次数
failures = await self.redis_client.get(key)
failure_count = int(failures) if failures else 0
# 增加失败次数
failure_count += 1
await self.redis_client.setex(key, 900, failure_count) # 15分钟过期
# 检查是否应该锁定
should_lock = failure_count >= 5
return failure_count, should_lock
async def reset_failed_attempts(self, identifier: str):
"""重置失败登录尝试"""
key = f"login_failures:{identifier}"
await self.redis_client.delete(key)
async def lock_account(self, user_id: str, minutes: int = 15):
"""锁定账户"""
key = f"account_lock:{user_id}"
await self.redis_client.setex(key, minutes * 60, "locked")
async def is_account_locked(self, user_id: str) -> bool:
"""检查账户是否被锁定"""
key = f"account_lock:{user_id}"
return bool(await self.redis_client.exists(key))
6. 完整的用户服务实现
python
# src/services/user_service.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, or_
from typing import Optional, List
from datetime import datetime, timedelta
import secrets
from src.models import User, UserSession, PasswordResetToken, EmailVerificationToken, UserRole, UserStatus
from src.auth.password import PasswordManager
from src.schemas import UserCreate, UserLogin
class UserService:
"""用户服务"""
def __init__(self, db: AsyncSession):
self.db = db
async def create_user(self, user_in: UserCreate) -> User:
"""创建用户"""
hashed_password = PasswordManager.get_password_hash(
user_in.password
)
user = User(
email=user_in.email,
username=user_in.username,
full_name=user_in.full_name,
hashed_password=hashed_password,
status=UserStatus.INACTIVE, # 需要邮箱验证
)
self.db.add(user)
await self.db.commit()
await self.db.refresh(user)
# 创建邮箱验证令牌
await self.create_email_verification_token(user.id)
return user
async def get_user_by_id(self, user_id: str) -> Optional[User]:
"""通过ID获取用户"""
result = await self.db.execute(
select(User).where(User.id == user_id)
)
return result.scalar_one_or_none()
async def get_user_by_email(self, email: str) -> Optional[User]:
"""通过邮箱获取用户"""
result = await self.db.execute(
select(User).where(User.email == email)
)
return result.scalar_one_or_none()
async def get_user_by_username(self, username: str) -> Optional[User]:
"""通过用户名获取用户"""
result = await self.db.execute(
select(User).where(User.username == username)
)
return result.scalar_one_or_none()
async def authenticate_user(
self,
identifier: str,
password: str
) -> Optional[User]:
"""验证用户凭证"""
# 尝试通过邮箱或用户名查找用户
result = await self.db.execute(
select(User).where(
or_(
User.email == identifier,
User.username == identifier
)
)
)
user = result.scalar_one_or_none()
if not user:
return None
# 验证密码
if not PasswordManager.verify_password(password, user.hashed_password):
return None
return user
async def update_login_info(
self,
user_id: str,
login_success: bool,
ip_address: Optional[str] = None
):
"""更新登录信息"""
if login_success:
# 成功登录
await self.db.execute(
update(User).where(User.id == user_id).values(
last_login_at=datetime.utcnow(),
last_login_ip=ip_address,
failed_login_attempts=0,
locked_until=None
)
)
else:
# 失败登录
await self.db.execute(
update(User)
.where(User.id == user_id)
.values(
failed_login_attempts=User.failed_login_attempts + 1
)
)
# 检查是否应该锁定账户
result = await self.db.execute(
select(User.failed_login_attempts)
.where(User.id == user_id)
)
attempts = result.scalar_one()
if attempts >= 5:
lock_until = datetime.utcnow() + timedelta(minutes=15)
await self.db.execute(
update(User).where(User.id == user_id).values(
locked_until=lock_until
)
)
await self.db.commit()
async def update_password(self, user_id: str, new_password: str):
"""更新用户密码"""
hashed_password = PasswordManager.get_password_hash(new_password)
await self.db.execute(
update(User).where(User.id == user_id).values(
hashed_password=hashed_password,
password_changed_at=datetime.utcnow()
)
)
await self.db.commit()
async def create_password_reset_token(self, email: str) -> str:
"""创建密码重置令牌"""
# 生成令牌
token = PasswordManager.generate_reset_token()
token_hash = PasswordManager.create_reset_token_hash(token)
# 创建重置令牌记录
reset_token = PasswordResetToken(
email=email,
token_hash=token_hash,
expires_at=datetime.utcnow() + timedelta(hours=24)
)
self.db.add(reset_token)
await self.db.commit()
return token
async def verify_password_reset_token(self, token: str) -> Optional[User]:
"""验证密码重置令牌"""
token_hash = PasswordManager.create_reset_token_hash(token)
# 查找有效的重置令牌
result = await self.db.execute(
select(PasswordResetToken).where(
PasswordResetToken.token_hash == token_hash,
PasswordResetToken.expires_at > datetime.utcnow(),
PasswordResetToken.is_used == False
)
)
reset_token = result.scalar_one_or_none()
if not reset_token:
return None
# 标记为已使用
reset_token.is_used = True
reset_token.used_at = datetime.utcnow()
# 获取用户
user = await self.get_user_by_email(reset_token.email)
await self.db.commit()
return user
async def create_email_verification_token(self, user_id: str) -> str:
"""创建邮箱验证令牌"""
token = secrets.token_urlsafe(32)
token_hash = PasswordManager.create_reset_token_hash(token)
# 创建验证令牌记录
verification_token = EmailVerificationToken(
user_id=user_id,
token_hash=token_hash,
expires_at=datetime.utcnow() + timedelta(hours=24)
)
self.db.add(verification_token)
await self.db.commit()
return token
async def verify_email_token(self, token: str) -> Optional[User]:
"""验证邮箱验证令牌"""
token_hash = PasswordManager.create_reset_token_hash(token)
# 查找有效的验证令牌
result = await self.db.execute(
select(EmailVerificationToken).where(
EmailVerificationToken.token_hash == token_hash,
EmailVerificationToken.expires_at > datetime.utcnow(),
EmailVerificationToken.is_used == False
)
)
verification_token = result.scalar_one_or_none()
if not verification_token:
return None
# 标记为已使用
verification_token.is_used = True
verification_token.used_at = datetime.utcnow()
# 激活用户
await self.db.execute(
update(User).where(User.id == verification_token.user_id).values(
is_email_verified=True,
status=UserStatus.ACTIVE
)
)
# 获取用户
user = await self.get_user_by_id(verification_token.user_id)
await self.db.commit()
return user
async def enable_2fa(self, user_id: str, secret: str):
"""启用双因素认证"""
await self.db.execute(
update(User).where(User.id == user_id).values(
is_2fa_enabled=True,
two_factor_secret=secret
)
)
await self.db.commit()
async def disable_2fa(self, user_id: str):
"""禁用双因素认证"""
await self.db.execute(
update(User).where(User.id == user_id).values(
is_2fa_enabled=False,
two_factor_secret=None
)
)
await self.db.commit()
async def create_session(
self,
user_id: str,
user_agent: Optional[str] = None,
ip_address: Optional[str] = None
) -> str:
"""创建用户会话"""
from src.services.session_service import SessionService
session_service = SessionService(self.db)
return await session_service.create_session(
user_id=user_id,
user_agent=user_agent,
ip_address=ip_address
)
async def invalidate_all_sessions(
self,
user_id: str,
exclude_current: bool = False
):
"""使用户的所有会话失效"""
from src.services.session_service import SessionService
session_service = SessionService(self.db)
return await session_service.revoke_all_sessions(
user_id=user_id,
exclude_current=exclude_current
)
async def get_users(
self,
skip: int = 0,
limit: int = 100,
role: Optional[UserRole] = None,
status: Optional[UserStatus] = None,
search: Optional[str] = None
) -> List[User]:
"""获取用户列表"""
query = select(User)
# 应用过滤器
if role:
query = query.where(User.role == role)
if status:
query = query.where(User.status == status)
if search:
search_filter = or_(
User.email.ilike(f"%{search}%"),
User.username.ilike(f"%{search}%"),
User.full_name.ilike(f"%{search}%")
)
query = query.where(search_filter)
# 应用分页
query = query.offset(skip).limit(limit).order_by(User.created_at.desc())
result = await self.db.execute(query)
return result.scalars().all()
async def update_user_role(self, user_id: str, new_role: UserRole) -> bool:
"""更新用户角色"""
result = await self.db.execute(
update(User)
.where(User.id == user_id)
.values(role=new_role)
)
await self.db.commit()
return result.rowcount > 0
async def update_user_status(self, user_id: str, new_status: UserStatus) -> bool:
"""更新用户状态"""
result = await self.db.execute(
update(User)
.where(User.id == user_id)
.values(status=new_status)
)
await self.db.commit()
return result.rowcount > 0
async def delete_user(self, user_id: str) -> bool:
"""删除用户(软删除)"""
result = await self.db.execute(
update(User)
.where(User.id == user_id)
.values(
status=UserStatus.DELETED,
email=f"deleted_{user_id}@deleted.com", # 修改邮箱防止重复
username=f"deleted_{user_id}"
)
)
await self.db.commit()
return result.rowcount > 0
7. 测试用例
python
# tests/test_auth.py
import pytest
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession
from datetime import datetime, timedelta
import jwt
from src.main import app
from src.models import User, UserStatus
from src.auth.jwt_handler import jwt_manager
from src.auth.password import PasswordManager
@pytest.mark.asyncio
class TestAuthentication:
"""认证测试"""
async def test_register_success(self, client: AsyncClient, db: AsyncSession):
"""测试成功注册"""
user_data = {
"email": "test@example.com",
"username": "testuser",
"password": "SecurePass123!",
"full_name": "Test User"
}
response = await client.post("/api/v1/auth/register", json=user_data)
assert response.status_code == 201
data = response.json()
assert data["email"] == user_data["email"]
assert data["username"] == user_data["username"]
assert data["status"] == UserStatus.INACTIVE.value
assert "id" in data
async def test_register_duplicate_email(self, client: AsyncClient, db: AsyncSession):
"""测试重复邮箱注册"""
# 先注册一个用户
user_data = {
"email": "duplicate@example.com",
"username": "user1",
"password": "SecurePass123!",
}
await client.post("/api/v1/auth/register", json=user_data)
# 尝试用相同邮箱注册
duplicate_data = {
"email": "duplicate@example.com",
"username": "user2",
"password": "AnotherPass123!",
}
response = await client.post("/api/v1/auth/register", json=duplicate_data)
assert response.status_code == 400
assert "邮箱已被注册" in response.json()["detail"]
async def test_login_success(self, client: AsyncClient, db: AsyncSession):
"""测试成功登录"""
# 先注册用户
user_data = {
"email": "login@example.com",
"username": "loginuser",
"password": "SecurePass123!",
}
await client.post("/api/v1/auth/register", json=user_data)
# 激活用户
user_service = UserService(db)
user = await user_service.get_user_by_email("login@example.com")
user.status = UserStatus.ACTIVE
user.is_email_verified = True
await db.commit()
# 登录
login_data = {
"username": "loginuser",
"password": "SecurePass123!"
}
response = await client.post(
"/api/v1/auth/login",
data=login_data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
async def test_login_invalid_credentials(self, client: AsyncClient, db: AsyncSession):
"""测试无效凭证登录"""
login_data = {
"username": "nonexistent",
"password": "wrongpassword"
}
response = await client.post(
"/api/v1/auth/login",
data=login_data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
assert response.status_code == 401
assert "用户名或密码错误" in response.json()["detail"]
async def test_refresh_token(self, client: AsyncClient, db: AsyncSession):
"""测试刷新令牌"""
# 先登录获取刷新令牌
user_data = {
"email": "refresh@example.com",
"username": "refreshuser",
"password": "SecurePass123!",
}
await client.post("/api/v1/auth/register", json=user_data)
# 激活用户
user_service = UserService(db)
user = await user_service.get_user_by_email("refresh@example.com")
user.status = UserStatus.ACTIVE
user.is_email_verified = True
await db.commit()
# 登录获取令牌
login_data = {
"username": "refreshuser",
"password": "SecurePass123!"
}
login_response = await client.post(
"/api/v1/auth/login",
data=login_data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
refresh_token = login_response.json()["refresh_token"]
# 刷新令牌
response = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token}
)
assert response.status_code == 200
data = response.json()
assert "access_token" in data
assert data["token_type"] == "bearer"
async def test_protected_endpoint(self, client: AsyncClient, db: AsyncSession):
"""测试受保护端点"""
# 先注册并激活用户
user_data = {
"email": "protected@example.com",
"username": "protecteduser",
"password": "SecurePass123!",
}
await client.post("/api/v1/auth/register", json=user_data)
# 激活用户
user_service = UserService(db)
user = await user_service.get_user_by_email("protected@example.com")
user.status = UserStatus.ACTIVE
user.is_email_verified = True
await db.commit()
# 登录获取令牌
login_data = {
"username": "protecteduser",
"password": "SecurePass123!"
}
login_response = await client.post(
"/api/v1/auth/login",
data=login_data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
access_token = login_response.json()["access_token"]
# 访问受保护端点
response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == 200
data = response.json()
assert data["email"] == user_data["email"]
assert data["username"] == user_data["username"]
async def test_protected_endpoint_no_token(self, client: AsyncClient):
"""测试无令牌访问受保护端点"""
response = await client.get("/api/v1/auth/me")
assert response.status_code == 401
assert "需要认证" in response.json()["detail"]
async def test_logout(self, client: AsyncClient, db: AsyncSession):
"""测试登出"""
# 先注册并激活用户
user_data = {
"email": "logout@example.com",
"username": "logoutuser",
"password": "SecurePass123!",
}
await client.post("/api/v1/auth/register", json=user_data)
# 激活用户
user_service = UserService(db)
user = await user_service.get_user_by_email("logout@example.com")
user.status = UserStatus.ACTIVE
user.is_email_verified = True
await db.commit()
# 登录获取令牌
login_data = {
"username": "logoutuser",
"password": "SecurePass123!"
}
login_response = await client.post(
"/api/v1/auth/login",
data=login_data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
access_token = login_response.json()["access_token"]
# 登出
response = await client.post(
"/api/v1/auth/logout",
headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == 200
assert response.json()["message"] == "登出成功"
# 尝试使用已撤销的令牌访问
response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {access_token}"}
)
assert response.status_code == 401
@pytest.mark.asyncio
class TestOAuth:
"""OAuth测试"""
async def test_oauth_login_url(self, client: AsyncClient):
"""测试获取OAuth登录URL"""
response = await client.get(
"/api/v1/auth/oauth/google/login",
params={"redirect_uri": "http://localhost:3000/callback"}
)
assert response.status_code == 200
data = response.json()
assert "authorization_url" in data
assert "state" in data
assert "google.com" in data["authorization_url"]
@pytest.mark.asyncio
class TestPasswordReset:
"""密码重置测试"""
async def test_password_reset_flow(self, client: AsyncClient, db: AsyncSession):
"""测试完整密码重置流程"""
# 1. 注册用户
user_data = {
"email": "reset@example.com",
"username": "resetuser",
"password": "OldPass123!",
}
await client.post("/api/v1/auth/register", json=user_data)
# 2. 请求密码重置
reset_request = {"email": "reset@example.com"}
response = await client.post(
"/api/v1/auth/password/reset/request",
json=reset_request
)
assert response.status_code == 200
# 3. 确认密码重置(这里简化,实际应该从邮件获取令牌)
# 获取测试令牌
user_service = UserService(db)
user = await user_service.get_user_by_email("reset@example.com")
# 创建重置令牌
reset_token = await user_service.create_password_reset_token(user.email)
# 4. 使用令牌重置密码
reset_confirm = {
"token": reset_token,
"new_password": "NewPass456!"
}
response = await client.post(
"/api/v1/auth/password/reset/confirm",
json=reset_confirm
)
assert response.status_code == 200
# 5. 使用新密码登录
login_data = {
"username": "resetuser",
"password": "NewPass456!"
}
response = await client.post(
"/api/v1/auth/login",
data=login_data,
headers={"Content-Type": "application/x-www-form-urlencoded"}
)
assert response.status_code == 200
8. 部署配置
8.1 环境变量配置
bash
# .env.example
# 应用配置
APP_NAME="FastAPI Authentication"
APP_VERSION="1.0.0"
ENVIRONMENT="production"
DEBUG=false
# 安全配置
SECRET_KEY="your-secret-key-here-change-in-production"
ALGORITHM="HS256"
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7
# 数据库
DATABASE_URL="postgresql+asyncpg://user:password@localhost/dbname"
# Redis
REDIS_URL="redis://localhost:6379/0"
# OAuth配置
GOOGLE_CLIENT_ID=""
GOOGLE_CLIENT_SECRET=""
GITHUB_CLIENT_ID=""
GITHUB_CLIENT_SECRET=""
FACEBOOK_CLIENT_ID=""
FACEBOOK_CLIENT_SECRET=""
# 邮件配置
SMTP_HOST="smtp.gmail.com"
SMTP_PORT=587
SMTP_USER=""
SMTP_PASSWORD=""
EMAILS_FROM_EMAIL="noreply@example.com"
# 前端URL(用于CORS)
FRONTEND_URL="https://yourfrontend.com"
8.2 Docker配置
dockerfile
# Dockerfile
FROM python:3.9-slim
WORKDIR /app
# 安装系统依赖
RUN apt-get update && apt-get install -y \
gcc \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
# 复制依赖文件
COPY requirements.txt .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt
# 复制应用代码
COPY . .
# 创建非root用户
RUN useradd -m -u 1000 fastapi && chown -R fastapi:fastapi /app
USER fastapi
# 运行应用
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
9. 性能优化和安全建议
9.1 性能优化
- 令牌验证缓存:验证JWT令牌时缓存结果,避免重复解码
- 数据库索引:确保用户表的关键字段有索引
- Redis连接池:使用连接池管理Redis连接
- 异步处理:对于邮件发送等耗时操作使用后台任务
- CDN缓存:对于静态资源使用CDN缓存
9.2 安全建议
- 使用HTTPS:生产环境必须使用HTTPS
- 定期轮换密钥:定期轮换JWT密钥
- 监控异常登录:监控异常的登录尝试
- 实现审计日志:记录重要的安全事件
- 定期安全扫描:定期进行安全漏洞扫描
10. 总结
本文详细介绍了如何在FastAPI中实现完整的JWT认证和OAuth2集成系统。通过这个系统,您可以:
- 实现安全的用户认证:使用JWT进行无状态认证
- 支持多种登录方式:支持密码登录和OAuth登录
- 提供完善的安全特性:包括双因素认证、会话管理、密码策略
- 实现细粒度的权限控制:基于角色的访问控制
- 确保系统安全性:通过多种安全措施保护用户数据
这个系统为现代Web应用提供了一个安全、可扩展的认证授权基础架构,可以根据具体需求进行扩展和定制。