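Spring Security 7 OAuth2: A Redis Storage Scheme for Distributed Authorization Codes
Background
In a distributed system, the OAuth2 authorization code grant involves several service nodes working together. When a client exchanges an authorization code for an access token, any node must be able to handle the request correctly. The authorization code and its associated authorization data therefore have to live in shared storage rather than in one node's local memory.
Spring Security OAuth2 Authorization Server provides InMemoryOAuth2AuthorizationService by default, which is only suitable for single-node development environments. Based on a reading of the source code, this article explores how to implement a Redis-backed store that meets the needs of distributed authentication.
Authorization code flow recap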
┌─────────┐     ┌──────────────┐     ┌─────────┐     ┌──────────────────────┐
│ Browser │────▶│ Auth server  │────▶│  User   │────▶│ Auth server          │
│         │     │ /authorize   │     │  login  │     │ (issues the code)    │
└─────────┘     └──────────────┘     └─────────┘     └──────────────────────┘
     │                                                          │
     │            authorization code + redirect_uri             │
     └────────────────────────────┬─────────────────────────────┘
                                  │
                                  ▼
                       ┌──────────────────────┐
                       │  Client application  │
                       │ (exchanges the code  │
                       │  for a token)        │
                       └──────────────────────┘
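Key point: once generated, the authorization code is stored on the server side. When the client presents the code to obtain a token, the server must look up the complete authorization data by that code.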
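To make the point concrete, here is a minimal sketch in plain Java with no Spring or Redis involved. `AuthNode` and `SharedStoreDemo` are hypothetical names invented for this illustration, and a `ConcurrentHashMap` stands in for the store. The code is issued on one node and redeemed on another, once with separate local maps and once with a shared map.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical model of one authorization server node; not a Spring class. */
final class AuthNode {
    // authorization code -> principal name
    private final Map<String, String> store;

    AuthNode(Map<String, String> store) {
        this.store = store;
    }

    /** The /authorize step: remember who the code was issued to. */
    void issueCode(String code, String principal) {
        store.put(code, principal);
    }

    /** The /token step: the code is single-use, so it is removed on lookup. */
    Optional<String> redeemCode(String code) {
        return Optional.ofNullable(store.remove(code));
    }
}

final class SharedStoreDemo {
    /** Issues a code on node A and tries to redeem it on node B. */
    static boolean redeemOnOtherNode(boolean sharedStore) {
        Map<String, String> storeA = new ConcurrentHashMap<>();
        Map<String, String> storeB = sharedStore ? storeA : new ConcurrentHashMap<>();
        AuthNode nodeA = new AuthNode(storeA);
        AuthNode nodeB = new AuthNode(storeB);
        nodeA.issueCode("code-123", "alice");
        return nodeB.redeemCode("code-123").isPresent();
    }

    public static void main(String[] args) {
        System.out.println("separate local stores: " + redeemOnOtherNode(false));
        System.out.println("shared store:          " + redeemOnOtherNode(true));
    }
}
```

With separate maps node B never sees the code that node A issued; with the shared map it does. Redis plays the role of the shared map in the rest of this article.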
Core interfaces
The OAuth2AuthorizationService interface
java
public interface OAuth2AuthorizationService {
    // Save the authorization (including any authorization code it holds)
    void save(OAuth2Authorization authorization);
    // Remove the given authorization
    void remove(OAuth2Authorization authorization);
    // Look up an authorization by its id
    OAuth2Authorization findById(String id);
    // Look up by token value (authorization code, access token, and so on)
    OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType);
}
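The OAuth2Authorization data structure
OAuth2Authorization is the central entity of the whole authorization flow. It contains:
| Field | Description |
|---|---|
| id | Unique identifier of the authorization |
| registeredClientId | ID of the registered client |
| principalName | User name of the resource owner |
| authorizationGrantType | Grant type (authorization_code) |
| tokens | Token collection (OAuth2AuthorizationCode, OAuth2AccessToken, and so on) |
| attributes | Additional attributes (state, redirect_uri, and so on) |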
java
// Simplified view of the class; method bodies are omitted
public class OAuth2Authorization implements Serializable {
    private String id;
    private String registeredClientId;
    private String principalName;
    private AuthorizationGrantType authorizationGrantType;
    private Set<String> authorizedScopes;
    private Map<Class<? extends OAuth2Token>, Token<?>> tokens;
    private Map<String, Object> attributes;
    // Get the token of the given type
    public <T extends OAuth2Token> Token<T> getToken(Class<T> tokenType);
    // Look up a token by its value
    public <T extends OAuth2Token> Token<T> getToken(String tokenValue);
}
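The tokens map is keyed by token class, which is what makes the two getToken overloads possible. The following is a rough sketch of that idea only; `TokenHolderSketch` and its nested types are hypothetical stand-ins and do not reproduce the real OAuth2Authorization implementation.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical, simplified holder that keys tokens by their class. */
final class TokenHolderSketch {
    interface Token {
        String value();
    }

    record Code(String value) implements Token {}

    record AccessToken(String value) implements Token {}

    private final Map<Class<? extends Token>, Token> tokens = new LinkedHashMap<>();

    /** Stores at most one token per token class. */
    <T extends Token> TokenHolderSketch put(T token) {
        tokens.put(token.getClass(), token);
        return this;
    }

    /** Lookup by type: a single map access, or null when absent. */
    <T extends Token> T getToken(Class<T> type) {
        return type.cast(tokens.get(type));
    }

    /** Lookup by value: walks the (small) set of tokens held by this authorization. */
    Token getToken(String value) {
        for (Token token : tokens.values()) {
            if (token.value().equals(value)) {
                return token;
            }
        }
        return null;
    }
}
```

Lookup by type is what the token endpoint uses once it already holds the authorization; lookup by value only searches inside one authorization, which is why the storage layer still needs its own index from token value to authorization id.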
The OAuth2AuthorizationCode structure
java
public class OAuth2AuthorizationCode extends AbstractOAuth2Token {
    // Inherited from AbstractOAuth2Token:
    // - tokenValue: the authorization code string
    // - issuedAt: time of issue
    // - expiresAt: expiry time (5 minutes after issue by default)
}
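The expiry rule can be sketched without any framework types. `CodeLifetime` below is a hypothetical stand-in for the issuedAt/expiresAt pair, and the 5-minute lifetime is passed in explicitly rather than read from any real configuration.

```java
import java.time.Duration;
import java.time.Instant;

/** Hypothetical stand-in for the issuedAt/expiresAt pair of an authorization code. */
record CodeLifetime(String tokenValue, Instant issuedAt, Instant expiresAt) {

    /** Creates a code whose expiry is issuedAt + timeToLive. */
    static CodeLifetime issue(String tokenValue, Instant issuedAt, Duration timeToLive) {
        return new CodeLifetime(tokenValue, issuedAt, issuedAt.plus(timeToLive));
    }

    /** A code counts as expired once "now" is strictly after expiresAt. */
    boolean isExpiredAt(Instant now) {
        return now.isAfter(expiresAt);
    }
}
```

Passing `now` as a parameter instead of calling `Instant.now()` inside the method keeps the rule easy to unit test with fixed timestamps.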
When the authorization code is stored
In OAuth2AuthorizationCodeRequestAuthenticationProvider.authenticate() (lines 314-318 of the source):
java
// Generate the authorization code
OAuth2AuthorizationCode authorizationCode = this.authorizationCodeGenerator.generate(tokenContext);
// Build the authorization object and save it
OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest)
        .authorizedScopes(authorizationRequest.getScopes())
        .token(authorizationCode) // bind the code to the authorization
        .build();
this.authorizationService.save(authorization); // key step: persist through the storage service
Redis storage design
Storage layout
┌──────────────────────────────────────────────────────────────────────┐
│ Redis key design                                                     │
├──────────────────────────────────────────────────────────────────────┤
│ Key:   oauth2:authorization:{uuid}                                   │
│ Value: JSON-serialized OAuth2Authorization                           │
│ TTL:   5 minutes (pending) / 24 hours (completed, has access token)  │
├──────────────────────────────────────────────────────────────────────┤
│ Key:   oauth2:token:{authorization_code_value}                       │
│ Value: authorization_id                                              │
│ TTL:   same as the authorization code's expiry                       │
├──────────────────────────────────────────────────────────────────────┤
│ Key:   oauth2:state:{state_value}                                    │
│ Value: authorization_id                                              │
│ TTL:   10 minutes                                                    │
└──────────────────────────────────────────────────────────────────────┘
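One way to keep this layout consistent across the code base is to build every key in a single place. The helper below is a small sketch under that assumption; `AuthorizationKeys` is a hypothetical name, and the prefixes and durations are copied from the layout above.

```java
import java.time.Duration;

/** Hypothetical helper that centralizes the key layout and the fixed TTLs. */
final class AuthorizationKeys {
    static final Duration PENDING_TTL = Duration.ofMinutes(5);
    static final Duration COMPLETED_TTL = Duration.ofHours(24);
    static final Duration STATE_TTL = Duration.ofMinutes(10);

    private AuthorizationKeys() {
    }

    /** Key of the main record holding the serialized authorization. */
    static String authorization(String id) {
        return "oauth2:authorization:" + id;
    }

    /** Key of the index entry that maps an authorization code to an authorization id. */
    static String token(String codeValue) {
        return "oauth2:token:" + codeValue;
    }

    /** Key of the index entry that maps a state value to an authorization id. */
    static String state(String stateValue) {
        return "oauth2:state:" + stateValue;
    }

    /** TTL of the main record, depending on whether an access token was issued. */
    static Duration authorizationTtl(boolean hasAccessToken) {
        return hasAccessToken ? COMPLETED_TTL : PENDING_TTL;
    }
}
```

For example, `AuthorizationKeys.token("abc")` yields `oauth2:token:abc`, and the value stored under it is the authorization id used to build the main record's key.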
Core implementation
java
// Imports use the Spring Authorization Server 1.x package names; verify them against your version
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.jackson2.OAuth2AuthorizationServerJackson2Module;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

// No @Component here: the bean is registered by the @Bean method in the configuration class
public class RedisOAuth2AuthorizationService implements OAuth2AuthorizationService {
    private static final String AUTHORIZATION_KEY_PREFIX = "oauth2:authorization:";
    private static final String TOKEN_KEY_PREFIX = "oauth2:token:";
    private static final String STATE_KEY_PREFIX = "oauth2:state:";
    private static final Duration DEFAULT_CODE_TTL = Duration.ofMinutes(5);
    private static final Duration COMPLETED_TTL = Duration.ofHours(24);
    private static final Duration STATE_TTL = Duration.ofMinutes(10);
    private static final Duration TTL_BUFFER = Duration.ofMinutes(1);
    // Holds the serialized OAuth2Authorization objects
    private final RedisTemplate<String, OAuth2Authorization> redisTemplate;
    // Holds the plain-string indexes (token value or state -> authorization id)
    private final StringRedisTemplate indexTemplate;
    private final ObjectMapper objectMapper;

    public RedisOAuth2AuthorizationService(
            RedisTemplate<String, OAuth2Authorization> redisTemplate) {
        Assert.notNull(redisTemplate, "redisTemplate cannot be null");
        this.redisTemplate = redisTemplate;
        // Index values are Strings, so they get their own template on the same connection factory
        this.indexTemplate = new StringRedisTemplate(redisTemplate.getRequiredConnectionFactory());
        // Configure the Jackson ObjectMapper
        this.objectMapper = new ObjectMapper();
        this.objectMapper.registerModule(new OAuth2AuthorizationServerJackson2Module());
        this.objectMapper.registerModule(new Jdk8Module());
        this.objectMapper.registerModule(new JavaTimeModule()); // java.time support (Instant)
        this.objectMapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS);
        this.objectMapper.activateDefaultTyping(
                objectMapper.getPolymorphicTypeValidator(),
                ObjectMapper.DefaultTyping.NON_FINAL
        );
        // Configure the value serializers of the RedisTemplate.
        // Caveat: test that a whole OAuth2Authorization round-trips through this mapper in your version.
        Jackson2JsonRedisSerializer<OAuth2Authorization> serializer =
                new Jackson2JsonRedisSerializer<>(objectMapper, OAuth2Authorization.class);
        this.redisTemplate.setValueSerializer(serializer);
        this.redisTemplate.setHashValueSerializer(serializer);
    }

    @Override
    public void save(OAuth2Authorization authorization) {
        Assert.notNull(authorization, "authorization cannot be null");
        String key = AUTHORIZATION_KEY_PREFIX + authorization.getId();
        // Store the main record with a stage-dependent expiry
        redisTemplate.opsForValue().set(key, authorization, calculateTtl(authorization));
        // Index every token value so that findByToken never has to scan the keyspace
        saveTokenIndex(authorization.getToken(OAuth2AuthorizationCode.class), authorization.getId());
        saveTokenIndex(authorization.getAccessToken(), authorization.getId());
        saveTokenIndex(authorization.getRefreshToken(), authorization.getId());
        // Index the state parameter
        String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
        if (StringUtils.hasText(state)) {
            indexTemplate.opsForValue().set(STATE_KEY_PREFIX + state, authorization.getId(), STATE_TTL);
        }
    }

    @Override
    public void remove(OAuth2Authorization authorization) {
        Assert.notNull(authorization, "authorization cannot be null");
        List<String> keys = new ArrayList<>();
        keys.add(AUTHORIZATION_KEY_PREFIX + authorization.getId());
        // Clean up the token indexes
        addTokenKey(keys, authorization.getToken(OAuth2AuthorizationCode.class));
        addTokenKey(keys, authorization.getAccessToken());
        addTokenKey(keys, authorization.getRefreshToken());
        // Clean up the state index
        String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
        if (StringUtils.hasText(state)) {
            keys.add(STATE_KEY_PREFIX + state);
        }
        indexTemplate.delete(keys);
    }

    @Override
    public OAuth2Authorization findById(String id) {
        Assert.hasText(id, "id cannot be empty");
        return redisTemplate.opsForValue().get(AUTHORIZATION_KEY_PREFIX + id);
    }

    @Override
    public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
        Assert.hasText(token, "token cannot be empty");
        // The state parameter has its own index
        if (tokenType != null && OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
            return findByIndex(STATE_KEY_PREFIX + token);
        }
        OAuth2Authorization authorization = findByIndex(TOKEN_KEY_PREFIX + token);
        if (authorization == null && tokenType == null) {
            // Untyped lookup: the value may also be a state parameter
            return findByIndex(STATE_KEY_PREFIX + token);
        }
        // Reject a hit whose token is of a different type than the one requested
        return (authorization != null && matchesType(authorization, token, tokenType)) ? authorization : null;
    }

    // Two-step lookup: index key -> authorization id -> main record
    private OAuth2Authorization findByIndex(String indexKey) {
        String authId = indexTemplate.opsForValue().get(indexKey);
        return (authId != null) ? findById(authId) : null;
    }

    private void saveTokenIndex(OAuth2Authorization.Token<? extends OAuth2Token> token, String authId) {
        if (token == null) {
            return;
        }
        OAuth2Token value = token.getToken();
        Duration ttl = (value.getExpiresAt() != null)
                ? positive(Duration.between(Instant.now(), value.getExpiresAt()).plus(TTL_BUFFER))
                : COMPLETED_TTL;
        indexTemplate.opsForValue().set(TOKEN_KEY_PREFIX + value.getTokenValue(), authId, ttl);
    }

    private static void addTokenKey(List<String> keys, OAuth2Authorization.Token<? extends OAuth2Token> token) {
        if (token != null) {
            keys.add(TOKEN_KEY_PREFIX + token.getToken().getTokenValue());
        }
    }

    private Duration calculateTtl(OAuth2Authorization authorization) {
        // Completed authorization (an access token has been issued)
        if (authorization.getAccessToken() != null) {
            return COMPLETED_TTL;
        }
        // Pending authorization: follow the authorization code's expiry, plus a buffer
        OAuth2Authorization.Token<OAuth2AuthorizationCode> codeToken =
                authorization.getToken(OAuth2AuthorizationCode.class);
        if (codeToken != null && codeToken.getToken().getExpiresAt() != null) {
            return positive(Duration.between(Instant.now(), codeToken.getToken().getExpiresAt())
                    .plus(TTL_BUFFER));
        }
        return DEFAULT_CODE_TTL;
    }

    // Redis rejects a non-positive expiry, so clamp to one second
    private static Duration positive(Duration ttl) {
        return (ttl.isNegative() || ttl.isZero()) ? Duration.ofSeconds(1) : ttl;
    }

    private static boolean matchesType(OAuth2Authorization authorization, String token,
            OAuth2TokenType tokenType) {
        if (tokenType == null) {
            return true;
        }
        OAuth2Authorization.Token<? extends OAuth2Token> candidate;
        if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
            candidate = authorization.getToken(OAuth2AuthorizationCode.class);
        } else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
            candidate = authorization.getAccessToken();
        } else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
            candidate = authorization.getRefreshToken();
        } else {
            // Other token types (for example id_token) are not indexed in this simplified version
            return false;
        }
        return candidate != null && candidate.getToken().getTokenValue().equals(token);
    }
}
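One detail of the service worth isolating is the TTL calculation. Duration.between goes negative once the authorization code has already expired, and Redis rejects a non-positive EX value. The sketch below restates the rule in plain Java with an injectable "now" so it can be tested without Redis; `TtlRules` and the one-second floor are assumptions made for this illustration.

```java
import java.time.Duration;
import java.time.Instant;

/** Hypothetical helper that restates the TTL rules with an explicit clock value. */
final class TtlRules {
    static final Duration DEFAULT_CODE_TTL = Duration.ofMinutes(5);
    static final Duration COMPLETED_TTL = Duration.ofHours(24);
    static final Duration BUFFER = Duration.ofMinutes(1);
    static final Duration MINIMUM = Duration.ofSeconds(1);

    private TtlRules() {
    }

    /**
     * @param hasAccessToken true once the token exchange has completed
     * @param codeExpiresAt  expiry of the authorization code, or null if there is none
     * @param now            the current time, passed in for testability
     */
    static Duration ttlFor(boolean hasAccessToken, Instant codeExpiresAt, Instant now) {
        if (hasAccessToken) {
            return COMPLETED_TTL;
        }
        if (codeExpiresAt == null) {
            return DEFAULT_CODE_TTL;
        }
        Duration remaining = Duration.between(now, codeExpiresAt).plus(BUFFER);
        // Never hand Redis a zero or negative expiry
        return remaining.compareTo(MINIMUM) < 0 ? MINIMUM : remaining;
    }
}
```

A code that expires in two minutes gets a three-minute TTL (remaining time plus the buffer), while a code that expired ten minutes ago gets the one-second floor instead of a negative value.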
Configuration class
java
@Configuration
@EnableConfigurationProperties(Oauth2Properties.class)
public class OAuth2RedisAuthorizationConfig {
    @Value("${spring.data.redis.host:localhost}")
    private String redisHost;
    @Value("${spring.data.redis.port:6379}")
    private int redisPort;
    @Bean
    public RedisConnectionFactory redisConnectionFactory() {
        RedisStandaloneConfiguration config = new RedisStandaloneConfiguration();
        config.setHostName(redisHost);
        config.setPort(redisPort);
        return new LettuceConnectionFactory(config);
    }
    @Bean
    public RedisTemplate<String, OAuth2Authorization> authorizationRedisTemplate(
            RedisConnectionFactory connectionFactory) {
        RedisTemplate<String, OAuth2Authorization> template = new RedisTemplate<>();
        template.setConnectionFactory(connectionFactory);
        // Use StringRedisSerializer for keys
        StringRedisSerializer stringSerializer = new StringRedisSerializer();
        template.setKeySerializer(stringSerializer);
        template.setHashKeySerializer(stringSerializer);
        // Use Jackson2JsonRedisSerializer for values
        ObjectMapper objectMapper = new ObjectMapper();
        objectMapper.registerModule(new OAuth2AuthorizationServerJackson2Module());
        objectMapper.registerModule(new Jdk8Module());
        objectMapper.activateDefaultTyping(
                objectMapper.getPolymorphicTypeValidator(),
                ObjectMapper.DefaultTyping.NON_FINAL
        );
        Jackson2JsonRedisSerializer<OAuth2Authorization> jsonSerializer =
                new Jackson2JsonRedisSerializer<>(objectMapper, OAuth2Authorization.class);
        template.setValueSerializer(jsonSerializer);
        template.setHashValueSerializer(jsonSerializer);
        template.afterPropertiesSet();
        return template;
    }
    @Bean
    public OAuth2AuthorizationService authorizationService(
            RedisTemplate<String, OAuth2Authorization> authorizationRedisTemplate) {
        return new RedisOAuth2AuthorizationService(authorizationRedisTemplate);
    }
}
Wiring into the authorization server configuration
java
@Configuration
@EnableConfigurationProperties(Oauth2Properties.class)
public class Oauth2AuthorizationServerHttpSecurityConfig
        implements ICustomHttpSecurityConfig {
    private final OAuth2AuthorizationService authorizationService;
    public Oauth2AuthorizationServerHttpSecurityConfig(
            OAuth2AuthorizationService authorizationService) {
        this.authorizationService = authorizationService;
    }
    @Override
    public void config(HttpSecurity http) {
        // authorizationService(...) is set on the authorization server configurer itself,
        // not on the authorization endpoint configurer
        http.oauth2AuthorizationServer(oauthServer -> oauthServer
                .authorizationService(authorizationService)
                .oidc(Customizer.withDefaults()));
    }
}
Notes on serialization
Spring Security OAuth2 Authorization Server provides a dedicated Jackson module for serializing OAuth2-related objects:
java
// Maven dependencies:
// spring-security-oauth2-authorization-server
// spring-data-redis
// jackson-datatype-jsr310
@Bean
public Module oauth2AuthorizationServerJackson2Module() {
    return new OAuth2AuthorizationServerJackson2Module();
}
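Types this setup needs to serialize:
- OAuth2Authorization
- OAuth2AuthorizationCode
- OAuth2AuthorizationConsent
- OAuth2AuthorizationRequest
- RegisteredClient
Check this list against the version you use. In JdbcOAuth2AuthorizationService the module is applied to the attributes and token metadata maps (for example an OAuth2AuthorizationRequest stored as an attribute), so round-tripping a whole OAuth2Authorization through Jackson should be covered by a test before relying on it.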
Performance tuning suggestions
1. Use a Lua script to guarantee atomicity
lua
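-- Redis Lua script: save the authorization and its code index atomically
local authKey = KEYS[1]
local codeKey = KEYS[2]
local authData = ARGV[1]
local codeData = ARGV[2]
local authTtl = tonumber(ARGV[3]) -- seconds: 300 while pending, 86400 once completed
local codeTtl = tonumber(ARGV[4]) -- seconds
redis.call('SET', authKey, authData, 'EX', authTtl)
redis.call('SET', codeKey, codeData, 'EX', codeTtl)
return 'OK'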
2. Connection pool configuration
yaml
spring:
  data:
    redis:
      host: localhost
      port: 6379
      lettuce:
        pool:
          max-active: 50
          max-idle: 20
          min-idle: 5
          max-wait: 3000ms
3. Monitoring metrics
java
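@Bean
public MeterRegistry meterRegistry() {
    return new SimpleMeterRegistry();
}
// RedisOAuth2AuthorizationServiceMetrics is a custom class that is not shown in this article
@Bean
public RedisOAuth2AuthorizationServiceMetrics metrics(MeterRegistry meterRegistry) {
    return new RedisOAuth2AuthorizationServiceMetrics(meterRegistry);
}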
Complete flow diagrams
┌──────────────────────────────────────────────────────────────────────────┐
│                     Authorization code request flow                      │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. User visits /oauth2/authorize?client_id=xxx&response_type=code       │
│     │                                                                    │
│     ▼                                                                    │
│  2. OAuth2AuthorizationCodeRequestAuthenticationProvider                 │
│     │                                                                    │
│     ▼                                                                    │
│  3. Consent required? ──YES──▶ save the authorization request            │
│     │                          ──▶ return the consent page               │
│     NO                                                                   │
│     │                                                                    │
│     ▼                                                                    │
│  4. Generate the code ──▶ build OAuth2Authorization ──▶ save to Redis    │
│     │                          key: oauth2:authorization:{uuid}          │
│     │                          key: oauth2:token:{code} -> {uuid}        │
│     │                                                                    │
│     └──▶ return the authorization code + redirect_uri                    │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
┌──────────────────────────────────────────────────────────────────────────┐
│                           Token exchange flow                            │
├──────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  1. Client sends POST /oauth2/token                                      │
│     grant_type=authorization_code&code=xxx&client_id=xxx                 │
│     │                                                                    │
│     ▼                                                                    │
│  2. OAuth2AuthorizationCodeAuthenticationProvider                        │
│     │                                                                    │
│     ▼                                                                    │
│  3. findByToken(code, OAuth2TokenType("code"))                           │
│     │                                                                    │
│     │  3.1 look up oauth2:token:{code} to get the authorization_id       │
│     │                                                                    │
│     ▼                                                                    │
│  4. Load oauth2:authorization:{id} to get the OAuth2Authorization        │
│     │                                                                    │
│     ▼                                                                    │
│  5. Validate the code (not expired, not already used)                    │
│     │                                                                    │
│     ▼                                                                    │
│  6. Issue access token ──▶ update OAuth2Authorization ──▶ save to Redis  │
│     │                          TTL becomes 24 hours                      │
│     │                                                                    │
│     └──▶ return the access token                                         │
│                                                                          │
└──────────────────────────────────────────────────────────────────────────┘
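Steps 3 to 6 of the token exchange can be walked through with two in-memory maps standing in for the two Redis key families. This is a rough model only: `TokenExchangeSketch`, its `Auth` record type, and the generated token string are all hypothetical, and the check-then-mark sequence here is not atomic the way a Lua script or transaction would be.

```java
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/** Hypothetical in-memory model of the token exchange steps. */
final class TokenExchangeSketch {
    /** Simplified authorization record (not the real OAuth2Authorization). */
    private static final class Auth {
        final String id;
        final Instant codeExpiresAt;
        boolean codeUsed;

        Auth(String id, Instant codeExpiresAt) {
            this.id = id;
            this.codeExpiresAt = codeExpiresAt;
        }
    }

    // Stands in for oauth2:token:{code} -> authorization id
    private final Map<String, String> codeIndex = new HashMap<>();
    // Stands in for oauth2:authorization:{id} -> authorization
    private final Map<String, Auth> authorizations = new HashMap<>();

    /** What save() does at the end of the authorization request flow. */
    void save(String id, String code, Instant codeExpiresAt) {
        authorizations.put(id, new Auth(id, codeExpiresAt));
        codeIndex.put(code, id);
    }

    /** Returns an access token value, or empty if the code cannot be redeemed. */
    Optional<String> exchange(String code, Instant now) {
        String id = codeIndex.get(code);                 // step 3.1: index lookup
        if (id == null) {
            return Optional.empty();
        }
        Auth auth = authorizations.get(id);              // step 4: load the main record
        if (auth == null) {
            return Optional.empty();                     // index outlived the record
        }
        if (auth.codeUsed || now.isAfter(auth.codeExpiresAt)) {
            return Optional.empty();                     // step 5: expired or already used
        }
        auth.codeUsed = true;                            // step 6: mark used, issue token
        return Optional.of("access-token-for-" + auth.id);
    }
}
```

A first exchange with a valid code succeeds, a replay of the same code fails, and so does an exchange after the code's expiry, which matches the validation in step 5.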
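Summary
By implementing the OAuth2AuthorizationService interface and storing authorization data in Redis, we can:
- Support distributed deployment: multiple authorization server instances share authorization state
- Keep sessions consistent: after the user logs in, any service node can complete the token exchange
- Scale out: a Redis cluster can provide higher availability and throughput
- Track state: Redis expiry automatically cleans up stale authorization data
Key implementation points:
- Set the right TTL for each stage (5 minutes while pending, 24 hours once completed)
- Build a fast index from the authorization code to the authorization ID
- Use the official Jackson module to keep serialization correct
- Consider a Lua script to make the writes atomic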