基于 Redis + Lua,实现“多维度原子限流”(令牌桶 + 滑动窗口)

基于 Redis + Lua,实现"多维度原子限流"(令牌桶 + 滑动窗口)

为啥要用redis+lua脚本来实现

实在是囊中羞涩,服务器内存不够,sentinel部署不了,redis+lua脚本不需要额外的配置.

实现原理

固定窗口算法

将时间划分为多个窗口,窗口时间跨度固定,假设为1秒。

每个窗口都有一个计数器,来记录请求数量,有一个请求,计数器值+1,当计数器值超过设置的阈值时,丢弃请求。保证每一个时间窗口内的请求数量都不超过阈值。

但是有一个缺点:

假设我们设置每一个时间窗口内请求阈值为3,

0.5秒到1秒请求数量为3,1秒到1.5秒请求数量也为3,这样QPS就达到了6。

因此就使用滑动窗口来解决。

滑动窗口算法

与固定窗口算法不同,滑动窗口算法只有一个固定跨度 窗口,和时间区间跨度相关。通过移动窗口,来关联前后时间的数据。

  • 窗口时间跨度(Interval)固定
  • 时间区间跨度为 Interval / n , n 越小,划分的区间越小,越精细
    • 窗口会随着当前请求所在时间currentTime 移动,窗口范围从currentTime-Interval 时刻之后的第一个时区开始,到currentTime所在时区结束。

具体流程:

假设阈值为3,跨度为1s, n 为2 ,时区就是500ms.

  1. 在1300ms内进入一个请求,其所在时区是1000 ~ 1500ms。
  2. 当前窗口就是 1300-1000 也就是300ms之后的第一个时区,即500-1000ms, 和1000-1500ms两个时区组成,如果这两个时区内已经有过3次请求了,则丢弃这次请求,如果没有,1000-1500ms时区内请求数+1。

什么是令牌桶算法

  • 以固定速率生成令牌,存入令牌桶中,如果令牌桶满了以后,多余令牌丢弃
  • 请求进入以后,必须尝试从桶中获取令牌,获取到令牌之后才可以被处理
  • 如果令牌桶中没有令牌,则请求等待或丢弃

所以每秒产生的令牌数量基本就是QPS上限。

多维度优势

维度 作用
用户 防刷
IP 防攻击
API 防热点

为什么选用ZSET

因为需要按照时间排序 +过期清理

执行流程

interval = 10秒
max_tokens = 5
permits = 1

  1. t = 0 秒,请求 A

    zset = []

    value = 5

    扣 1 个 token

    写入 zset

    zset:

    0 → reqA:1

    value = 4

  2. t = 3 秒,请求 B

    当前时间 = 3

    窗口范围 = [-7, 3]

bash 复制代码
 zset:
0 → reqA:1
3 → reqB:1 
value = 3
  1. t = 8 秒,请求 C
    窗口范围 = [-2, 8]
    仍然没有过期

    zset:
    0 → reqA:1
    3 → reqB:1
    8 → reqC:1
    value = 2

  2. t = 12 秒,请求 D(关键点来了)
    窗口范围 = [2, 12]

过期数据:

0 → reqA:1 (过期!)

第一步:回收 token

expired_values = zrangebyscore(0, 2)

得到:

reqA:1

👉 回收:

value = 2 + 1 = 3

🔥 第二步:删除过期记录

zremrangebyscore(0, 2)

👉 剩下:

3 → reqB:1

8 → reqC:1

🔥 第三步:处理当前请求

zset:

3 → reqB:1

8 → reqC:1

12 → reqD:1

value = 2

执行流程

  1. 第一阶段:预检查(不扣令牌)
lua 复制代码
-- 检查所有维度是否满足
if current_val < permits then
    return 0
end

只要有一个纬度不满足 ,直接失败。

  1. 第二阶段 : 同一扣减
lua 复制代码
redis.call("zadd", ...)
redis.call("set", ...)

所有维度都 OK 才一起扣

具体实现

lua脚本

lua 复制代码
---@diagnostic disable: undefined-global
-- 原子化多维度限流脚本
-- 基于令牌桶算法实现,支持多维度组合限流
-- 只有所有维度都满足条件时才扣减令牌,确保原子性

-- 参数说明:
-- KEYS[1..N]: 限流维度键列表
-- ARGV[1]: 当前时间戳(毫秒)
-- ARGV[2]: 申请令牌数
-- ARGV[3]: 时间窗口(毫秒)
-- ARGV[4]: 最大令牌数(窗口内允许的总数)
-- ARGV[5]: 请求唯一标识

local now_ms = tonumber(ARGV[1])
local permits = tonumber(ARGV[2])
local interval = tonumber(ARGV[3])
local max_tokens = tonumber(ARGV[4])
local request_id = ARGV[5]

-- 第一阶段:预检查阶段 - 检查所有维度是否有足够令牌
for i, key in ipairs(KEYS) do
    local value_key = key .. ":value"
    local permits_key = key .. ":permits"

    -- 初始化 value_key(如果不存在)
    if redis.call("exists", value_key) == 0 then
        redis.call("set", value_key, max_tokens)
    end

    -- 回收过期令牌
    -- 清理过期的 permit 记录,并回收配额到 value_key
    local expired_values = redis.call("zrangebyscore", permits_key, 0, now_ms - interval)
    if #expired_values > 0 then
        local expired_count = 0
        for _, v in ipairs(expired_values) do
            -- 优化解析逻辑:使用更高效的模式匹配
            local p = tonumber(string.match(v, ":(%d+)$"))
            if p then
                expired_count = expired_count + p
            end
        end

        -- 删除过期记录
        redis.call("zremrangebyscore", permits_key, 0, now_ms - interval)

        -- 回收配额
        if expired_count > 0 then
            local curr_v = tonumber(redis.call("get", value_key) or max_tokens)
            local next_v = math.min(max_tokens, curr_v + expired_count)
            redis.call("set", value_key, next_v)
        end
    end

    -- 核心检查:当前可用令牌是否足够
    local current_val = tonumber(redis.call("get", value_key) or max_tokens)
    if current_val < permits then
        -- 任何一个维度配额不足,直接返回失败
        return 0
    end
end

-- 第二阶段:扣减阶段 - 只有所有维度都通过后才执行
for i, key in ipairs(KEYS) do
    local value_key = key .. ":value"
    local permits_key = key .. ":permits"

    -- 记录本次令牌分配(格式:request_id:permits)
    local permit_record = request_id .. ":" .. permits
    redis.call("zadd", permits_key, now_ms, permit_record)

    -- 扣减令牌
    local current_v = tonumber(redis.call("get", value_key) or max_tokens)
    redis.call("set", value_key, current_v - permits)

    -- 设置过期时间,确保过期令牌能被正常回收 (窗口的2倍,至少1秒)
    local expire_time = math.ceil(interval * 2 / 1000)
    if expire_time < 1 then expire_time = 1 end
    redis.call("expire", value_key, expire_time)
    redis.call("expire", permits_key, expire_time)
end

-- 成功获取所有维度的令牌
return 1

。。待续

相关推荐
上海合宙LuatOS2 小时前
LuatOS扩展库API——【exgnss】GNSS定位
物联网·lua·luatos
一定要AK2 小时前
Spring 核心容器从入门到精通
java·后端·spring
RInk7oBjo2 小时前
spring boot3--自动配置与手动配置
java·spring boot·后端
最初的↘那颗心2 小时前
LangChain4j核心能力:AiService、Prompt注解与结构化输出实战
java·大模型·结构化输出·langchain4j·aiservice
lixia0417mul22 小时前
简单的RAG知识库问答
java
MacroZheng2 小时前
又一款企业级文件管理系统诞生了!支持万能文件在线预览,太香了!
java·spring boot·后端
云烟成雨TD2 小时前
Spring AI 1.x 系列【25】结构化输出案例演示
java·人工智能·spring
鱼鳞_2 小时前
Java学习笔记_Day23(HashMap)
java·笔记·学习