从秒级到毫秒级:一次Redis限流脚本的深度优化实战

最近接到一个任务,要给我们的OPEN API接口加个限流防护。现有的限流方案只是基于Redis的简单计数器,面对突发流量和时间窗口边界问题基本是摆设。

其实我一直知道Spring Cloud Gateway有个基于Redis LUA脚本实现的令牌桶滑动窗口限流组件,看着挺不错的,就直接拿来用了。但用着用着就发现问题了------它传入的请求时间戳是以秒为单位的,在高并发场景下,同一秒内的多次请求根本拿不到新令牌,导致限流过早触发。

从下面限流器中的java代码可以看出,传入的当前时间是此刻时间戳的秒数

java 复制代码
            // org.springframework.cloud.gateway.filter.ratelimit.RedisRateLimiter#isAllowed方法的代码
			List<String> keys = getKeys(id);

			// The arguments to the LUA script. time() returns unixtime in seconds.
                    // Instant.now().getEpochSecond()是此刻的时间戳的秒数
			List<String> scriptArgs = Arrays.asList(replenishRate + "",
					burstCapacity + "", Instant.now().getEpochSecond() + "", "1");
			// allowed, tokens_left = redis.eval(SCRIPT, keys, args)
			Flux<List<Long>> flux = this.redisTemplate.execute(this.script, keys,
					scriptArgs);

而限流器中的redis lua脚本,计算的是时间窗口内已过去的时间,

所以当同一秒内的请求过来,计算出来的delta是0,那么就拿不到本应该存在桶中的新增令牌。

lua 复制代码
local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens

我修改后的java代码传入的是时间戳的毫秒数

lua 复制代码
-- 这里除以`1000`,转为时间单位'秒',因为`rate`的单位是'秒'
local delta = math.max(0, (now-last_refreshed)/1000)
-- 这里使用`math.floor`,向下取整
local filled_tokens = math.min(capacity, last_tokens+math.floor(delta*rate))

这样改了后,同一秒内的请求拿到的令牌理论上应该是正常的。

后面我闲来没事,我加了日志redis.log(redis.LOG_WARNING, "delta " .. delta),再使用JMeter压测发现

打印出来的delta始终是0,这一夜回到解放前,和修改之前的效果一样啊。

我查了资料我才知道:"redis lua处理浮点数时会将其小数部分截断,只取整数部分"

所以我不应该在计算时间差时除以1000,而应该在计算新增令牌数后再除以1000。

这是我再次修改后的LUA脚本

lua 复制代码
local delta_ms = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta_ms*rate)/1000)

完整的java代码

java 复制代码
@Configuration(proxyBeanMethods = false)
public class FilterConfig {

    @Bean
    public FilterRegistrationBean<RequestLimitFilter> requestLimitFilter(StringRedisTemplate stringRedisTemplate) {
        RequestLimitFilter filter = new RequestLimitFilter(stringRedisTemplate);
        FilterRegistrationBean<RequestLimitFilter> bean = new FilterRegistrationBean<>(filter);
        bean.setOrder(1);
        bean.addUrlPatterns("/*");
        return bean;
    }

    private  static class RequestLimitFilter extends OncePerRequestFilter {

        final RedisScript<List> limitScript;
        final StringRedisTemplate stringRedisTemplate;

        public RequestLimitFilter(StringRedisTemplate stringRedisTemplate) {
            DefaultRedisScript<List> script = new DefaultRedisScript<>();
            script.setLocation(new ClassPathResource("lua/request_rate_limiter.lua"));
            script.setResultType(List.class);
            this.limitScript = script;
            this.stringRedisTemplate = stringRedisTemplate;
        }

        @Override
        protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
            //todo 请求id,这里只是简单用请求路径,实际上需要根据业务来确定,比如ip,用户id等,这里只是简单用请求路径做演示
            String reqId = request.getRequestURI();
            List<String> keys = getKeys(reqId);
            int burstCapacity = 820;
            int replenishRate = 400;
            //spring cloud gateway中传入的时间是 'Instant.now().getEpochSecond()'
            List<String> scriptArgs = Arrays.asList(String.valueOf(replenishRate), String.valueOf(burstCapacity),
                    String.valueOf(System.currentTimeMillis()), String.valueOf(1));
            @SuppressWarnings("unchecked")
            List<Long> results = stringRedisTemplate.execute(limitScript, keys, scriptArgs.toArray());
            assert results != null;

            response.setHeader("X-RateLimit-Remaining", String.valueOf(results.get(1)));
            response.setHeader("X-RateLimit-Replenish-Rate", String.valueOf(replenishRate));
            response.setHeader("X-RateLimit-Burst-Capacity", String.valueOf(burstCapacity));

            boolean allowed = Objects.equals(results.get(0), 1L);
            if (allowed) {
                filterChain.doFilter(request, response);
                return;
            }

            response.setContentType(MediaType.APPLICATION_JSON_VALUE);
            response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
            response.getWriter().write("{\"code\":1429,\"msg\":\"too many request\"}");
        }

        static List<String> getKeys(String id) {
            String prefix = "request_rate_limiter.{" + id;
            String tokenKey = prefix + "}.tokens";
            String timestampKey = prefix + "}.timestamp";
            return Arrays.asList(tokenKey, timestampKey);
        }
    }
}

改造后的lua脚本

lua 复制代码
-- request_rate_limiter.lua
local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
--redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
-- now is ms timestamp
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity/rate
local ttl = math.floor(fill_time*2)

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. ARGV[3])
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)
local delta_ms = math.max(0, now-last_refreshed)
local logText=string.format("limit delta = %d, last_tokens = %d, last_refreshed = %d", delta_ms, last_tokens, last_refreshed)
redis.log(redis.LOG_NOTICE, logText)
-- 这里的时间单位需要注意:传入的是毫秒时间戳,但rate是按秒计算的,所以在计算令牌补充量时要除以1000进行单位转换
-- 为什么不直接在计算时间差时就除以1000呢?因为在Redis Lua中,浮点运算会被截断为整数,如果先转换时间单位再计算,会导致精度损失
-- 比如:假设delta_ms=1500ms,rate=10,正确计算应该是 1500*10/1000=15个令牌
-- 如果先转换时间单位:1500/1000=1(被截断),然后1*10=10个令牌,就少了5个
local filled_tokens = math.min(capacity, last_tokens+(delta_ms*rate)/1000)
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)
--  原始LUA脚本,使用'setnx'指令,过期时间单位是秒数,我这里使用了`psetex`,过期时间单位是毫秒数,让过期时间更精确
local ttl_ms= ttl * 1000
redis.call("psetex", tokens_key, ttl_ms, new_tokens)
redis.call("psetex", timestamp_key, ttl_ms, now)

return { allowed_num, new_tokens }
相关推荐
lbb 小魔仙2 小时前
【Java】Spring Cloud 核心组件详解:Eureka、Ribbon、Feign 与 Hystrix
java·spring cloud·eureka
K_Men2 小时前
【Redis】根据key模糊匹配批量删除
redis
啊吧怪不啊吧2 小时前
从单主机到多主机——分布式系统的不断推进
网络·数据库·redis·分布式·架构
予枫的编程笔记13 小时前
Redis 核心数据结构深度解密:从基础命令到源码架构
java·数据结构·数据库·redis·缓存·架构
CodeAmaz15 小时前
一致性哈希与Redis哈希槽详解
redis·算法·哈希算法
一条大祥脚17 小时前
25.12.30
数据库·redis·缓存
清晓粼溪18 小时前
RestTemplate
java·spring cloud
程可爱18 小时前
详解Redis消息队列的三种实现方案
redis
Wang's Blog18 小时前
Lua: Web应用开发之OpenResty与Lapis框架深度指南
lua·openresty