采用 Redis + Lua 脚本 实现分布式全局令牌桶限流 ,核心目的是:在多实例集群部署环境 下,实现统一、精准、原子化的接口限流控制,解决单机限流无法控制全局总流量、并发超卖、流量失控等问题,覆盖接口、用户多维度限流场景,保障系统在高并发场景下的稳定性与可用性。
一:具体解决的五大问题
1、解决多实例部署下,总流量无法精准控制的问题
-
单机限流:每台机器各自计数,总流量 = 单台限额 × 实例数,远超预期阈值 。
例如:你部署了3 台服务器实例 ,想做「接口每秒最多 10 次请求」的限流。
如果用单机限流 (比如每台实例自己记请求数):- 每台实例自己都放 10 次请求
- 3 台实例加起来,系统总 QPS 变成了
10×3=30 - 你本来想限制 10,结果变成了 30,流量不可控,这就是问题所在!
-
Redis 全局限流:所有服务实例共用同一个 Redis 计数器,全局总流量严格等于配置阈值。
2、解决并发请求下的超卖 / 超限制问题
- 传统 Java 级限流在高并发下会出现竞态条件,导致实际请求数超过限额。
- Lua 脚本在 Redis 中单线程原子执行 ,判断、扣减、回收一气呵成,绝对不会超发令牌。
3、解决分布式环境下的限流一致性问题
- 无论请求打到哪一台服务器,都统一去 Redis 中获取令牌。
- 保证全集群限流规则一致,不会出现某台机器松、某台机器严的情况。
4、解决固定窗口限流的 "边界突刺" 问题(令牌桶优势)
- 令牌桶支持平滑流量,令牌匀速生成、自动回收。
- 避免传统计数器在窗口切换瞬间出现流量突刺,压垮服务。
5、保障高并发下接口的可用性与自我保护
- 超出限流规则的请求直接拦截,防止大量无效请求压垮数据库 / 第三方接口。
- 配合降级机制,返回友好提示,提升用户体验。
二:令牌桶算法介绍
原理:系统以恒定的速度向令牌桶中放入令牌,当请求到达时,需要从桶中获取一个令牌才能被处理。如果桶中没有令牌,则拒绝请求。令牌桶的容量是固定的,当令牌放满时,多余的令牌会被丢弃。
核心优势:
- 允许突发流量:令牌桶可以积累令牌,当系统空闲时,令牌会逐渐填满桶,此时如果有突发流量到来,可以一次性获取多个令牌进行处理,充分利用系统资源。
- 平滑限流:令牌以恒定速度放入桶中,避免了固定窗口的临界问题,使请求处理速度更加平滑。
- 易于实现多维度限流:每个限流维度(如接口、IP、用户)可以维护独立的令牌桶,互不干扰。
令牌桶算法的数学模型:
- 设令牌桶容量为max_tokens(即最大突发请求数)
- 令牌生成速率为rate(即每秒生成的令牌数,等于限流 QPS)
- 当前令牌数为current_tokens
- 当请求到达时,若current_tokens >= 1,则current_tokens -= 1,请求被处理;否则拒绝请求。
- 每隔 1/rate 秒,current_tokens += 1,但不超过max_tokens。
代码实例:
java
/**
* 令牌桶限流算法(工业界标准)
* 优点:允许突发流量,同时能平滑限流,兼顾性能和灵活性
* 缺点:实现稍复杂
*/
public class TokenBucketRateLimiter {
// 令牌桶容量(最大突发请求数)
private final int capacity;
// 令牌生成速度(每秒生成的令牌数,即限流QPS)
private final double tokenRate;
// 当前令牌数
private double currentTokens;
// 上次生成令牌的时间戳
private long lastTokenTime;
public TokenBucketRateLimiter(int capacity, double tokenRate) {
this.capacity = capacity;
this.tokenRate = tokenRate;
// 初始时桶是满的
this.currentTokens = capacity;
this.lastTokenTime = System.currentTimeMillis();
}
/**
* 尝试获取令牌
* @return true-获取成功(允许请求),false-获取失败(限流)
*/
public synchronized boolean tryAcquire() {
return tryAcquire(1);
}
/**
* 尝试获取指定数量的令牌
* @param permits 需要获取的令牌数
* @return true-获取成功,false-获取失败
*/
public synchronized boolean tryAcquire(int permits) {
if (permits <= 0 || permits > capacity) {
return false;
}
long currentTime = System.currentTimeMillis();
// 1. 计算从上次生成令牌到现在应该生成的令牌数
double generatedTokens = (currentTime - lastTokenTime) / 1000.0 * tokenRate;
// 2. 更新当前令牌数(不能超过桶的容量)
currentTokens = Math.min(capacity, currentTokens + generatedTokens);
lastTokenTime = currentTime;
// 3. 判断是否有足够的令牌
if (currentTokens >= permits) {
currentTokens -= permits;
return true;
}
return false;
}
// 测试用例
public static void main(String[] args) throws InterruptedException {
// 令牌桶容量10,每秒生成5个令牌(即限流QPS=5,最大突发10个请求)
TokenBucketRateLimiter limiter = new TokenBucketRateLimiter(10, 5);
// 模拟15个突发请求
System.out.println("===== 突发15个请求 =====");
for (int i = 0; i < 15; i++) {
final int requestId = i;
new Thread(() -> {
if (limiter.tryAcquire()) {
System.out.println("请求" + requestId + ":成功");
} else {
System.out.println("请求" + requestId + ":被限流");
}
}).start();
}
// 等待1秒,令牌桶会补充5个令牌
Thread.sleep(1000);
System.out.println("===== 1秒后 =====");
// 再发10个请求
for (int i = 15; i < 25; i++) {
final int requestId = i;
new Thread(() -> {
if (limiter.tryAcquire()) {
System.out.println("请求" + requestId + ":成功");
} else {
System.out.println("请求" + requestId + ":被限流");
}
}).start();
}
}
}
三:分布式限流代码实现
核心实现:
@RateLimit注解:给方法打标签,配置限流规则RateLimitAspectAOP 切面:拦截注解方法,执行限流逻辑rate_limit.lua脚本:Redis 中原子执行的令牌桶算法,保证并发安全
整个流程:
java
用户请求 → 进入被@RateLimit标记的方法 → AOP切面拦截 → 调用Lua脚本执行限流判断
→ 通过:执行原方法
→ 不通过:执行降级方法 / 抛出异常
1、@RateLimit 注解(限流规则定义)
这是给方法打标签的注解,用来配置限流的维度、次数、时间窗口等。
java
/**
* 限流注解
* 用于方法级别的限流控制,支持多维度组合限流
*
* @see RateLimitAspect
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimit {
/**
* 限流维度枚举
*/
enum Dimension {
/**
* 全局限流:对所有请求统一限流
*/
GLOBAL,
/**
* IP限流:按客户端IP地址限流
*/
IP,
/**
* 用户限流:按用户ID限流
*/
USER
}
/**
* 限流维度配置
* 支持多维度组合,只有所有维度都满足条件时才允许请求通过
* 例如:{Dimension.GLOBAL, Dimension.USER} 表示同时进行全局限流和用户级限流
*
* @return 限流维度数组
*/
Dimension[] dimensions() default {Dimension.GLOBAL};
/**
* 在指定时间窗口内允许的最大请求数
* 例如:count = 10, interval = 1, timeUnit = MINUTES 表示每分钟最多 10 次
*
* @return 令牌总数
*/
double count();
/**
* 时间窗口大小
* 默认 1
*
* @return 时间窗口
*/
long interval() default 1;
/**
* 时间单位
* 默认为秒,即默认"每秒 count 次"
*
* @return 时间单位
*/
TimeUnit timeUnit() default TimeUnit.SECONDS;
/**
* 等待令牌的超时时间
* 如果设置为0,表示不等待,直接获取令牌,失败则拒绝
* 如果大于0,会尝试等待指定时间获取令牌
*
* @return 超时时间
*/
long timeout() default 0;
/**
* 降级方法名
* 当限流触发时,调用指定方法进行降级处理
* 降级方法支持:
* 1. 无参方法
* 2. 与原方法参数列表完全一致的方法
* 降级方法必须在同一个类中,返回值类型与原方法兼容
* 如果为空字符串,则抛出 RateLimitExceededException 异常
*
* @return 降级方法名
*/
String fallback() default "";
/**
* 时间单位枚举
*/
enum TimeUnit {
MILLISECONDS, SECONDS, MINUTES, HOURS, DAYS
}
}
dimensions:多维度组合限流,比如{GLOBAL, USER}表示同时做全局限流和用户级限流,只有两个都通过才放行count + interval + timeUnit:组成限流规则,比如count=10, interval=1, timeUnit=MINUTES表示 "每分钟最多 10 次请求"timeout:请求拿不到令牌时的等待时间,默认 0 是直接拒绝fallback:限流触发时调用的降级方法名,支持无参或和原方法参数一致的方法
2、Lua 脚本(Redis 端原子限流逻辑)
这部分是限流的核心,用 Lua 脚本在 Redis 中原子执行,解决并发竞争问题。
java
---@diagnostic disable: undefined-global
-- 原子化多维度限流脚本
-- 基于令牌桶算法实现,支持多维度组合限流
-- 只有所有维度都满足条件时才扣减令牌,确保原子性
-- 参数说明:
-- KEYS[1..N]: 限流维度键列表
-- ARGV[1]: 当前时间戳(毫秒)
-- ARGV[2]: 申请令牌数
-- ARGV[3]: 时间窗口(毫秒)
-- ARGV[4]: 最大令牌数(窗口内允许的总数)
-- ARGV[5]: 请求唯一标识
local now_ms = tonumber(ARGV[1])
local permits = tonumber(ARGV[2])
local interval = tonumber(ARGV[3])
local max_tokens = tonumber(ARGV[4])
local request_id = ARGV[5]
-- 第一阶段:预检查阶段 - 检查所有维度是否有足够令牌
for i, key in ipairs(KEYS) do
local value_key = key .. ":value"
local permits_key = key .. ":permits"
-- 初始化 value_key(如果不存在)
if redis.call("exists", value_key) == 0 then
redis.call("set", value_key, max_tokens)
end
-- 回收过期令牌
-- 清理过期的 permit 记录,并回收配额到 value_key
local expired_values = redis.call("zrangebyscore", permits_key, 0, now_ms - interval)
if #expired_values > 0 then
local expired_count = 0
for _, v in ipairs(expired_values) do
-- 优化解析逻辑:使用更高效的模式匹配
local p = tonumber(string.match(v, ":(%d+)$"))
if p then
expired_count = expired_count + p
end
end
-- 删除过期记录
redis.call("zremrangebyscore", permits_key, 0, now_ms - interval)
-- 回收配额
if expired_count > 0 then
local curr_v = tonumber(redis.call("get", value_key) or max_tokens)
local next_v = math.min(max_tokens, curr_v + expired_count)
redis.call("set", value_key, next_v)
end
end
-- 核心检查:当前可用令牌是否足够
local current_val = tonumber(redis.call("get", value_key) or max_tokens)
if current_val < permits then
-- 任何一个维度配额不足,直接返回失败
return 0
end
end
-- 第二阶段:扣减阶段 - 只有所有维度都通过后才执行
for i, key in ipairs(KEYS) do
local value_key = key .. ":value"
local permits_key = key .. ":permits"
-- 记录本次令牌分配(格式:request_id:permits)
local permit_record = request_id .. ":" .. permits
redis.call("zadd", permits_key, now_ms, permit_record)
-- 扣减令牌
local current_v = tonumber(redis.call("get", value_key) or max_tokens)
redis.call("set", value_key, current_v - permits)
-- 设置过期时间,确保过期令牌能被正常回收 (窗口的2倍,至少1秒)
local expire_time = math.ceil(interval * 2 / 1000)
if expire_time < 1 then expire_time = 1 end
redis.call("expire", value_key, expire_time)
redis.call("expire", permits_key, expire_time)
end
-- 成功获取所有维度的令牌
return 1
这部分的主要工作:
- 初始化 Redis 中的令牌数(不存在则设置为最大数)
- 回收过期的令牌:清理有序集合里超过时间窗口的发放记录,把配额回收回来
- 检查所有维度的剩余令牌数,只要有一个维度不够,直接返回失败
- 把本次发放的令牌记录存入有序集合(用请求 UUID 做标识,score 为当前时间戳)
- 原子扣减剩余令牌数
- 设置 key 的过期时间,防止 Redis 中存在永久 key
3、RateLimitAspect AOP 切面(Java 端限流流程)
这是连接注解和 Lua 脚本的桥梁,拦截方法、组装参数、调用 Lua 脚本、处理结果。
java
/**
* 限流 AOP 切面
* 实现基于令牌桶算法的多维度限流
*/
@Slf4j
@Aspect
@Component
@RequiredArgsConstructor
public class RateLimitAspect {
private final RedissonClient redissonClient;
//Lua 脚本缓存
private static String LUA_SCRIPT;
private String luaScriptSha;
static {
try {
ClassPathResource resource = new ClassPathResource("rate_limit.lua");
LUA_SCRIPT = new String(resource.getContentAsByteArray(), StandardCharsets.UTF_8);
} catch (IOException e) {
throw new RuntimeException("加载限流 Lua 脚本失败", e);
}
}
//初始化:预加载脚本到 Redis 提高性能
@PostConstruct
public void init() {
this.luaScriptSha = redissonClient.getScript(StringCodec.INSTANCE).scriptLoad(LUA_SCRIPT);
log.info("限流 Lua 脚本加载完成, SHA1: {}", luaScriptSha);
}
//环绕通知:拦截带 @RateLimit 注解的方法
@Around("@annotation(rateLimit)")
public Object around(ProceedingJoinPoint joinPoint, RateLimit rateLimit) throws Throwable {
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Method method = signature.getMethod();
String className = method.getDeclaringClass().getSimpleName();
String methodName = method.getName();
// 1. 计算时间窗口(毫秒)
long intervalMs = calculateIntervalMs(rateLimit.interval(), rateLimit.timeUnit());
// 2. 根据配置维度动态生成 Redis Keys
List<String> keys = generateKeys(className, methodName, rateLimit.dimensions());
// 3. 调用 Lua 脚本执行原子限流
// 使用 StringCodec 确保参数正确传递为字符串
RScript script = redissonClient.getScript(StringCodec.INSTANCE);
// 准备参数
List<Object> keysList = new ArrayList<>(keys);
Object[] args = {
String.valueOf(System.currentTimeMillis()), // ARGV[1]: 当前时间戳
String.valueOf(1), // ARGV[2]: 申请令牌数(默认1个)
String.valueOf(intervalMs), // ARGV[3]: 时间窗口
String.valueOf(rateLimit.count()), // ARGV[4]: 最大令牌数
UUID.randomUUID().toString() // ARGV[5]: 请求唯一标识
};
Object resultObj = script.evalSha(
RScript.Mode.READ_WRITE,
luaScriptSha,
RScript.ReturnType.VALUE,
keysList,
args
);
// 将结果转换为 Long
Long result = convertToLong(resultObj);
// 4. 处理限流结果
if (result == null || result == 0) {
return handleRateLimitExceeded(joinPoint, rateLimit, keys);
}
// 5. 执行原方法
return joinPoint.proceed();
}
//计算时间窗口毫秒数
private long calculateIntervalMs(long interval, RateLimit.TimeUnit unit) {
return switch (unit) {
case MILLISECONDS -> interval;
case SECONDS -> interval * 1000;
case MINUTES -> interval * 60 * 1000;
case HOURS -> interval * 3600 * 1000;
case DAYS -> interval * 86400 * 1000;
};
}
//将结果对象安全转换为 Long
private Long convertToLong(Object obj) {
if (obj == null) {
return null;
}
if (obj instanceof Long) {
return (Long) obj;
} else if (obj instanceof Integer) {
return ((Integer) obj).longValue();
} else if (obj instanceof Short) {
return ((Short) obj).longValue();
} else if (obj instanceof Byte) {
return ((Byte) obj).longValue();
} else if (obj instanceof String) {
try {
return Long.parseLong((String) obj);
} catch (NumberFormatException e) {
log.warn("无法将字符串转换为Long: {}", obj);
return null;
}
}
log.warn("不支持的对象类型转换为Long: {}", obj.getClass().getName());
return null;
}
//生成限流键列表
private List<String> generateKeys(String className, String methodName, RateLimit.Dimension[] dimensions) {
List<String> keys = new ArrayList<>();
// 使用 {} 包含类名和方法名作为 Hash Tag,确保该方法的所有限流 Key 落在同一个 Redis Slot
// 从而适配 Redis Cluster 模式
String hashTag = "{" + className + ":" + methodName + "}";
String keyPrefix = "ratelimit:" + hashTag;
for (RateLimit.Dimension dimension : dimensions) {
switch (dimension) {
case GLOBAL -> keys.add(keyPrefix + ":global");
case IP -> keys.add(keyPrefix + ":ip:" + getClientIp());
case USER -> keys.add(keyPrefix + ":user:" + getCurrentUserId());
}
}
return keys;
}
//处理限流超出情况
private Object handleRateLimitExceeded(ProceedingJoinPoint joinPoint, RateLimit rateLimit, List<String> keys)
throws Throwable {
String methodName = joinPoint.getSignature().getName();
// 如果配置了降级方法,则调用降级方法
if (rateLimit.fallback() != null && !rateLimit.fallback().isEmpty()) {
try {
Method fallbackMethod = findFallbackMethod(joinPoint, rateLimit.fallback());
if (fallbackMethod != null) {
log.debug("限流触发,执行降级方法: {}.{} -> {}",
joinPoint.getTarget().getClass().getSimpleName(),
methodName,
rateLimit.fallback());
// 如果降级方法有参数,传入原方法的参数
if (fallbackMethod.getParameterCount() > 0) {
return fallbackMethod.invoke(joinPoint.getTarget(), joinPoint.getArgs());
} else {
return fallbackMethod.invoke(joinPoint.getTarget());
}
}
} catch (Exception e) {
log.error("降级方法执行失败: {}", rateLimit.fallback(), e);
}
}
// 没有降级方法或降级失败,抛出限流异常
log.debug("限流触发,拒绝请求: keys={}, count={} per {} {}",
keys, rateLimit.count(), rateLimit.interval(), rateLimit.timeUnit());
throw new BusinessException("请求过于频繁,请稍后再试");
}
/**
* 查找降级方法
* 优先查找与原方法参数列表完全一致的方法,找不到则查找无参方法
*/
private Method findFallbackMethod(ProceedingJoinPoint joinPoint, String fallbackName) {
Class<?> targetClass = joinPoint.getTarget().getClass();
MethodSignature signature = (MethodSignature) joinPoint.getSignature();
Class<?>[] parameterTypes = signature.getParameterTypes();
try {
// 1. 尝试查找同参数列表的方法
Method method = targetClass.getDeclaredMethod(fallbackName, parameterTypes);
method.setAccessible(true);
return method;
} catch (NoSuchMethodException e) {
// 2. 尝试查找无参方法
try {
Method method = targetClass.getDeclaredMethod(fallbackName);
method.setAccessible(true);
return method;
} catch (NoSuchMethodException ex) {
log.warn("未找到降级方法: {}.{} (需无参或参数列表一致)",
targetClass.getSimpleName(), fallbackName);
return null;
}
}
}
/**
* 获取客户端真实 IP
* 处理 X-Forwarded-For 头,支持代理服务器场景
*/
private String getClientIp() {
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
if (attributes == null) {
return "unknown";
}
HttpServletRequest request = attributes.getRequest();
String ip = request.getHeader("X-Forwarded-For");
if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("X-Real-IP");
}
if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
// 处理多个 IP 的情况(X-Forwarded-For 可能包含多个 IP)
if (ip != null && ip.contains(",")) {
ip = ip.split(",")[0].trim();
}
return ip != null ? ip : "unknown";
}
/**
* 获取当前用户 ID
* 从请求属性或 Session 中获取
*/
private String getCurrentUserId() {
UserVO user = UserHolder.getUser();
if (user != null) {
return user.getId().toString();
}
return "anonymous";
}
}
这部分的主要工作:
- 静态代码块加载 Lua 脚本内容
@PostConstruct预加载脚本到 Redis,获取 SHA1 值,后续用 SHA1 调用,避免每次都传整个脚本内容- 解析注解配置,把时间窗口转换为毫秒
- 根据限流维度生成 Redis key(全局 / IP / 用户级)
- 调用 Lua 脚本执行原子限流判断
- 根据返回结果处理:限流则执行降级,通过则执行原方法
- 把注解中的时间单位统一转换为毫秒,给 Lua 脚本使用。
- 用
{}包裹类名 + 方法名,作为 HashTag,确保同一个方法的所有限流 key 落在同一个 Redis Slot,避免集群模式下跨 slot 操作报错 - 按不同维度生成 key:全局 key、IP key、用户 key
- 限流触发时,优先调用配置的降级方法,支持无参或和原方法参数一致的方法
- 降级方法执行失败或未配置时,抛出
BusinessException异常,提示用户请求频繁
4、使用示例
java
// 每分钟最多5次请求,同时做用户级限流,触发限流时调用fallback方法
@RateLimit(
dimensions = {RateLimit.Dimension.USER},
count = 5,
interval = 1,
timeUnit = RateLimit.TimeUnit.MINUTES,
fallback = "seckillFallback"
)
public ResponseDTO<String> seckillFosterCoupon(Long couponId) {
//业务逻辑
}
// 降级方法(和原方法参数一致)
public ResponseDTO<String> seckillFallback(Long couponId) {
return ResponseDTO.fail("您的请求过于频繁,请稍后再试");
}
四:代码问题解答
1、@PostConstruct注解的作用是什么
它是 Spring 的注解作用是:在 Bean 对象创建完成、所有依赖注入(@Resource/@Autowired)完成之后,自动执行一次这个方法!
在代码中的作用:
- 等
RateLimitAspect对象创建好 - 等
redissonClient注入完成 - 自动执行一次 init () 方法
- 把 Lua 脚本加载到 Redis
- 拿到脚本的 SHA1 值
- 后面限流直接用 SHA1,性能更高
2、SHA1是什么
SHA1 是一种哈希算法 ,能把一段长文本,算出一串固定 40 位十六进制短字符串 ,这个结果就叫SHA1 摘要。
在代码中的作用:
Lua 脚本内容很长,每次调用都传整段脚本很慢。Redis 支持:
- 先把脚本上传 Redis,Redis 算出SHA1 哈希值存起来
- 后续调用只用传短小的 SHA1 值,Redis 根据值找脚本执行大幅节省网络传输、提升速度。
3、HashTag是什么
String hashTag = "{" + className + ":" + methodName + "}";
{} 中间的内容叫 **HashTag,**Redis 只计算 {} 里面的字符串作为槽位!
**{}**就是 Redis 的 HashTag 语法:
- Redis 计算 slot 时,只会用
{}里的内容来计算 - 如果两个 key 的 HashTag 部分都是
{UserService:login},那么它们计算出来的 slot 号是完全一样的,会落在同一个 Redis 节点上
为什么这么做:
Lua 脚本有一个铁律:**Lua 脚本里用到的所有 key,必须在同一个 slot(同一个抽屉)里!**否则 直接报错!