springboot+redis+lua脚本实现滑动窗口限流

1 限流

为了维护系统稳定性和防止DDoS攻击,需要对系统请求量进行限制。

2 滑动窗口

限流方式有:固定窗口,滑动窗口,令牌桶和漏斗。滑动窗口的意思是:维护一个长度固定的窗口,动态统计窗口内请求次数,如果窗口内请求次数超过阈值则不允许访问。

3 实现

参考https://www.jianshu.com/p/cb11e552505b。采用Redis的zset数据结构,将当前请求的时间戳作为score字段,统计窗口时间内请求次数是否超过限制。

完整代码在https://gitcode.com/zsss1/ratelimit/overview

java 复制代码
// 限流类型
public enum LimitType {
    DEFAULT,
    IP
}
java 复制代码
// 限流注解
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {
    String key() default "rate:limiter:";
    long limit() default 1;
    long expire() default 1;
    String message() default "访问频繁";
    LimitType limitType() default LimitType.IP;
}
java 复制代码
// 限流切面
@Component
@Aspect
public class RateLimiterHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiterHandler.class);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    @Qualifier("sliding_window")
    private RedisScript<Long> redisScript;

	// AOP动态代理com.example包下所有@annotation注解的方法
    @Around("execution(* com.example..*.*(..)) && @annotation(rateLimiter)")
    public Object around(ProceedingJoinPoint proceedingJoinPoint, RateLimiter rateLimiter) throws Throwable {
        Object[] args = proceedingJoinPoint.getArgs();
        long currentTime = Long.parseLong((String) args[0]);
        MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();
        Method method = signature.getMethod();
        StringBuilder limitKey = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            limitKey.append("127.0.0.1");
        }
        String className = method.getDeclaringClass().getName();
        String methodName = method.getName();
        limitKey.append("_").append(className).append("_").append(methodName);
        long limitCount = rateLimiter.limit();
        long windowTime = rateLimiter.expire();

        List<String> keyList = new ArrayList<>();
        keyList.add(limitKey.toString());
        Long result = redisTemplate.execute(redisScript, keyList, windowTime, currentTime, limitCount);

        if (result != null && result != 1) {
            throw new RuntimeException(rateLimiter.message());
        }
        return proceedingJoinPoint.proceed();
    }
}
lua 复制代码
-- 如果允许本次请求,返回1;如果不允许本次请求,返回0
--获取KEY

local key = KEYS[1]

--获取ARGV内的参数

-- 缓存时间

local expire = tonumber(ARGV[1])

-- 当前时间

local currentMs = tonumber(ARGV[2])

-- 最大次数

local limit_count = tonumber(ARGV[3])

--窗口开始时间

local windowStartMs = currentMs - tonumber(expire * 1000)

--获取key的次数

local current = redis.call('zcount', key, windowStartMs, currentMs)

--如果key的次数存在且大于预设值直接返回当前key的次数

if current and tonumber(current) >= limit_count then
    return 0;
end

-- 清除所有过期成员

redis.call("ZREMRANGEBYSCORE", key, 0, windowStartMs);

-- 添加当前成员

redis.call("zadd", key, currentMs, currentMs);

redis.call("expire", key, expire);

--返回key的次数

return 1
java 复制代码
// 测试类
// 为了方便统计当前时间,将时间作为请求参数传入接口
@SpringBootTest(classes = DemoApplication.class)
@AutoConfigureMockMvc
public class RateLimitControllerTest {
    @Autowired
    private WebApplicationContext webApplicationContext;

    private MockMvc mockMvc;
    @BeforeEach
    public void setUp() throws Exception {
        mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).build();
    }

    @Test
    public void test_rate_limit() throws Exception {
        String url = "/rate/test";
        Map<Long, Integer> timeStatusMap = new LinkedHashMap<>();
        
        for (int i = 0; i < 20; i++) {
            Thread.sleep(800);
            long currentTime = System.currentTimeMillis();
            MockHttpServletRequestBuilder builder = MockMvcRequestBuilders.get(url)
                    .param("currentTime", String.valueOf(currentTime)).accept(MediaType.APPLICATION_JSON);
            int status = mockMvc.perform(builder).andReturn().getResponse().getStatus();
            timeStatusMap.put(currentTime, status);
        }

        for (Map.Entry<Long, Integer> entry : timeStatusMap.entrySet()) {
            Long currentTime = entry.getKey();
            int status = entry.getValue();
            int spectedStatus = getStatusOfCurrentTime(currentTime, timeStatusMap.entrySet());
            System.out.println(status + ", " + spectedStatus + ", " + currentTime);
            // assertEquals(status, spectedStatus);
        }
    }

    private int getStatusOfCurrentTime(Long currentTime, Set<Map.Entry<Long, Integer>> set) {
        long startTime = currentTime - 5000;
        int count = 0;
        for (Map.Entry<Long, Integer> entry : set) {
            if (entry.getKey() >= startTime && entry.getKey() < currentTime && entry.getValue() == 200) {
                count++;
            }
        }
        if (count < 5) {
            return 200;
        }
        return 400;
    }
}
java 复制代码
// 接口
@RestController
@RequestMapping("/rate")
public class RateLimitController {
    @GetMapping("/test")
    @RateLimiter(limit = 5, expire = 5, limitType = LimitType.IP)
    public String test(String currentTime) {
        return "h";
    }
}
相关推荐
qq_12498707532 小时前
基于SSM的动物保护系统的设计与实现(源码+论文+部署+安装)
java·数据库·spring boot·毕业设计·ssm·计算机毕业设计
Coder_Boy_2 小时前
基于SpringAI的在线考试系统-考试系统开发流程案例
java·数据库·人工智能·spring boot·后端
2301_818732063 小时前
前端调用控制层接口,进不去,报错415,类型不匹配
java·spring boot·spring·tomcat·intellij-idea
此生只爱蛋3 小时前
【Redis】主从复制
数据库·redis
汤姆yu6 小时前
基于springboot的尿毒症健康管理系统
java·spring boot·后端
暮色妖娆丶6 小时前
Spring 源码分析 单例 Bean 的创建过程
spring boot·后端·spring
biyezuopinvip7 小时前
基于Spring Boot的企业网盘的设计与实现(任务书)
java·spring boot·后端·vue·ssm·任务书·企业网盘的设计与实现
惊讶的猫8 小时前
redis分片集群
数据库·redis·缓存·分片集群·海量数据存储·高并发写
JavaGuide8 小时前
一款悄然崛起的国产规则引擎,让业务编排效率提升 10 倍!
java·spring boot
期待のcode8 小时前
Redis的主从复制与集群
运维·服务器·redis