
最近接到一个任务,要给我们的OPEN API接口加个限流防护。现有的限流方案只是基于Redis的简单计数器,面对突发流量和时间窗口边界问题基本是摆设。
其实我一直知道Spring Cloud Gateway有个基于Redis Lua脚本实现的令牌桶滑动窗口限流组件,看着挺不错的,就直接拿来用了。但用着用着就发现问题了——它传入的请求时间戳是以秒为单位的,在高并发场景下,同一秒内的多次请求根本拿不到新令牌,导致限流过早触发。
从下面限流器中的java代码可以看出,传入的当前时间是此刻时间戳的秒数
java
// Code from org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter#isAllowed
List<String> keys = getKeys(id);
// The arguments to the LUA script. time() returns unixtime in seconds.
// Instant.now().getEpochSecond() is the current timestamp in whole seconds
List<String> scriptArgs = Arrays.asList(replenishRate + "",
burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
// allowed, tokens_left = redis.eval(SCRIPT, keys, args)
Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys,
scriptArgs);
而限流器中的redis lua脚本,计算的是时间窗口内已过去的时间,
所以当同一秒内的请求过来,计算出来的delta是0,那么就拿不到本应该存在桶中的新增令牌。
lua
-- Time elapsed since the last refresh, clamped so it is never negative
-- (whole seconds when `now` is an epoch-second value)
local delta = math.max(0, now-last_refreshed)
-- Add `rate` tokens per unit of elapsed time, capped at the bucket capacity
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
-- The request passes only when the bucket holds at least `requested` tokens
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
我修改后的java代码传入的是时间戳的毫秒数,Lua脚本也相应改成了下面这样
lua
-- Divide by `1000` to convert the difference to seconds, because `rate` is expressed per second
local delta = math.max(0, (now-last_refreshed)/1000)
-- `math.floor` rounds the number of added tokens down
local filled_tokens = math.min(capacity, last_tokens+math.floor(delta*rate))
这样改了后,同一秒内的请求拿到的令牌理论上应该是正常的。
后来我闲来无事,加了一行日志redis.log(redis.LOG_WARNING, "delta " .. delta),再使用JMeter压测,发现
打印出来的delta始终是0——这下一夜回到解放前,效果和修改之前一模一样。
我查了资料才知道:"redis lua处理浮点数时会将其小数部分截断,只取整数部分"
所以我不应该在计算时间差时除以1000,而应该在计算新增令牌数后再除以1000。
这是我再次修改后的LUA脚本
lua
-- Milliseconds elapsed since the last refresh, clamped so it is never negative
local delta_ms = math.max(0, now-last_refreshed)
-- Multiply by `rate` first and divide by 1000 last to turn the ms-based product into tokens
local filled_tokens = math.min(capacity, last_tokens+(delta_ms*rate)/1000)
完整的java代码
java
/**
 * Registers a servlet filter that rate-limits incoming requests with a token
 * bucket evaluated by the {@code lua/request_rate_limiter.lua} Redis script.
 */
@Configuration(proxyBeanMethods = false)
public class FilterConfig {

    /**
     * Registers {@link RequestLimitFilter} for every URL pattern with order 1.
     *
     * @param stringRedisTemplate template used to execute the limiter script
     * @return the filter registration
     */
    @Bean
    public FilterRegistrationBean<RequestLimitFilter> requestLimitFilter(StringRedisTemplate stringRedisTemplate) {
        RequestLimitFilter filter = new RequestLimitFilter(stringRedisTemplate);
        FilterRegistrationBean<RequestLimitFilter> bean = new FilterRegistrationBean<>(filter);
        bean.setOrder(1);
        bean.addUrlPatterns("/*");
        return bean;
    }

    /**
     * Filter that asks Redis for one token per request and answers with
     * HTTP 429 when the script does not grant it.
     */
    private static class RequestLimitFilter extends OncePerRequestFilter {

        /** Maximum number of tokens the bucket can hold. */
        private static final int BURST_CAPACITY = 820;

        /** Tokens added to the bucket per second. */
        private static final int REPLENISH_RATE = 400;

        /** Tokens consumed by a single request. */
        private static final int REQUESTED_TOKENS = 1;

        final RedisScript<List> limitScript;
        final StringRedisTemplate stringRedisTemplate;

        public RequestLimitFilter(StringRedisTemplate stringRedisTemplate) {
            DefaultRedisScript<List> script = new DefaultRedisScript<>();
            script.setLocation(new ClassPathResource("lua/request_rate_limiter.lua"));
            script.setResultType(List.class);
            this.limitScript = script;
            this.stringRedisTemplate = stringRedisTemplate;
        }

        /**
         * Runs the limiter script for the request, exposes the limiter state in
         * {@code X-RateLimit-*} response headers, and either continues the chain
         * or writes a 429 JSON body.
         *
         * @throws IllegalStateException if the script result is null or has
         *                               fewer than two elements
         */
        @Override
        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
            // TODO: request id. The request path is used here for demonstration only;
            // a real deployment should derive it from the business, e.g. client IP or user id.
            String reqId = request.getRequestURI();
            List<String> keys = getKeys(reqId);
            // Spring Cloud Gateway passes 'Instant.now().getEpochSecond()' here;
            // this filter passes the timestamp in milliseconds instead.
            List<String> scriptArgs = Arrays.asList(String.valueOf(REPLENISH_RATE), String.valueOf(BURST_CAPACITY),
                    String.valueOf(System.currentTimeMillis()), String.valueOf(REQUESTED_TOKENS));
            @SuppressWarnings("unchecked")
            List<Long> results = stringRedisTemplate.execute(limitScript, keys, scriptArgs.toArray());
            // An `assert` is skipped unless the JVM runs with -ea, so check explicitly
            // and fail with a message that names the actual problem.
            if (results == null || results.size() < 2) {
                throw new IllegalStateException("rate limiter script returned an unexpected result: " + results);
            }
            response.setHeader("X-RateLimit-Remaining", String.valueOf(results.get(1)));
            response.setHeader("X-RateLimit-Replenish-Rate", String.valueOf(REPLENISH_RATE));
            response.setHeader("X-RateLimit-Burst-Capacity", String.valueOf(BURST_CAPACITY));
            boolean allowed = Objects.equals(results.get(0), 1L);
            if (allowed) {
                filterChain.doFilter(request, response);
                return;
            }
            response.setContentType(MediaType.APPLICATION_JSON_VALUE);
            response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
            response.getWriter().write("{\"code\":1429,\"msg\":\"too many request\"}");
        }

        /**
         * Builds the two Redis keys used by the script for the given id. The id is
         * wrapped in braces so that both keys carry the same {@code {id}} segment.
         *
         * @param id limiter identity
         * @return the tokens key followed by the timestamp key
         */
        static List<String> getKeys(String id) {
            String prefix = "request_rate_limiter.{" + id;
            String tokenKey = prefix + "}.tokens";
            String timestampKey = prefix + "}.timestamp";
            return Arrays.asList(tokenKey, timestampKey);
        }
    }
}
改造后的lua脚本
lua
-- request_rate_limiter.lua
--
-- Token-bucket rate limiter.
--
-- KEYS[1]: key storing the tokens currently left in the bucket
-- KEYS[2]: key storing the timestamp of the last refresh (milliseconds)
-- ARGV[1]: rate, tokens replenished per second
-- ARGV[2]: capacity, maximum number of tokens the bucket can hold
-- ARGV[3]: now, current timestamp in milliseconds (supplied by the caller)
-- ARGV[4]: requested, number of tokens this call wants to consume
--
-- Returns { allowed_num, new_tokens }: allowed_num is 1 when the request is
-- allowed and 0 otherwise; new_tokens is what remains in the bucket.
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
-- now is ms timestamp
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

-- Seconds needed to refill an empty bucket. The state is kept for twice that
-- long; once it expires, a missing tokens key is read as a full bucket below.
local fill_time = capacity/rate
-- TTL computed directly in milliseconds, rounded up and clamped to at least 1.
-- Deriving it from whole seconds yields 0 whenever capacity/rate < 0.5, and
-- PSETEX rejects a non-positive expire time.
local ttl_ms = math.max(1, math.ceil(fill_time*2*1000))

-- A missing tokens key means the bucket is full.
local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end

-- A missing timestamp key is treated as "refreshed at time 0".
local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end

-- Milliseconds elapsed since the last refresh, never negative.
local delta_ms = math.max(0, now-last_refreshed)

-- Diagnostic output only. Logged at DEBUG so it is not written on every request
-- under Redis's default log level.
local logText=string.format("limit delta = %d, last_tokens = %d, last_refreshed = %d", delta_ms, last_tokens, last_refreshed)
redis.log(redis.LOG_DEBUG, logText)

-- Mind the units: `now` is in milliseconds while `rate` is per second, so the
-- product is divided by 1000. Example: delta_ms = 1500, rate = 10 gives
-- 1500*10/1000 = 15 tokens.
-- Lua numbers are doubles, so the result may be fractional. The fraction is
-- deliberately not rounded away: it is stored with the bucket, which lets refill
-- amounts smaller than one token accumulate across closely spaced calls.
local filled_tokens = math.min(capacity, last_tokens+(delta_ms*rate)/1000)
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

-- The original Spring script uses `setex`, whose TTL is in seconds; `psetex`
-- takes the TTL in milliseconds, which makes the expiry more precise.
redis.call("psetex", tokens_key, ttl_ms, new_tokens)
redis.call("psetex", timestamp_key, ttl_ms, now)

-- Redis converts Lua numbers in a reply to integers, so the caller receives the
-- whole-token part of new_tokens.
return { allowed_num, new_tokens }