springboot+redis+lua脚本实现滑动窗口限流

1 限流

为了维护系统稳定性和防止DDoS攻击,需要对系统请求量进行限制。

2 滑动窗口

限流方式有:固定窗口,滑动窗口,令牌桶和漏斗。滑动窗口的意思是:维护一个长度固定的窗口,动态统计窗口内请求次数,如果窗口内请求次数超过阈值则不允许访问。

3 实现

参考https://www.jianshu.com/p/cb11e552505b。采用Redis的zset数据结构,将当前请求的时间戳作为score字段,统计窗口时间内请求次数是否超过限制。

完整代码在https://gitcode.com/zsss1/ratelimit/overview

java 复制代码
// 限流类型
public enum LimitType {
    DEFAULT,
    IP
}
java 复制代码
// 限流注解
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {
    String key() default "rate:limiter:";
    long limit() default 1;
    long expire() default 1;
    String message() default "访问频繁";
    LimitType limitType() default LimitType.IP;
}
java 复制代码
// 限流切面
@Component
@Aspect
public class RateLimiterHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(RateLimiterHandler.class);

    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    @Autowired
    @Qualifier("sliding_window")
    private RedisScript<Long> redisScript;

	// AOP动态代理com.example包下所有@annotation注解的方法
    @Around("execution(* com.example..*.*(..)) && @annotation(rateLimiter)")
    public Object around(ProceedingJoinPoint proceedingJoinPoint, RateLimiter rateLimiter) throws Throwable {
        Object[] args = proceedingJoinPoint.getArgs();
        long currentTime = Long.parseLong((String) args[0]);
        MethodSignature signature = (MethodSignature) proceedingJoinPoint.getSignature();
        Method method = signature.getMethod();
        StringBuilder limitKey = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            limitKey.append("127.0.0.1");
        }
        String className = method.getDeclaringClass().getName();
        String methodName = method.getName();
        limitKey.append("_").append(className).append("_").append(methodName);
        long limitCount = rateLimiter.limit();
        long windowTime = rateLimiter.expire();

        List<String> keyList = new ArrayList<>();
        keyList.add(limitKey.toString());
        Long result = redisTemplate.execute(redisScript, keyList, windowTime, currentTime, limitCount);

        if (result != null && result != 1) {
            throw new RuntimeException(rateLimiter.message());
        }
        return proceedingJoinPoint.proceed();
    }
}
lua 复制代码
-- 如果允许本次请求,返回1;如果不允许本次请求,返回0
--获取KEY

local key = KEYS[1]

--获取ARGV内的参数

-- 缓存时间

local expire = tonumber(ARGV[1])

-- 当前时间

local currentMs = tonumber(ARGV[2])

-- 最大次数

local limit_count = tonumber(ARGV[3])

--窗口开始时间

local windowStartMs = currentMs - tonumber(expire * 1000)

--获取key的次数

local current = redis.call('zcount', key, windowStartMs, currentMs)

--如果key的次数存在且大于预设值直接返回当前key的次数

if current and tonumber(current) >= limit_count then
    return 0;
end

-- 清除所有过期成员

redis.call("ZREMRANGEBYSCORE", key, 0, windowStartMs);

-- 添加当前成员

redis.call("zadd", key, currentMs, currentMs);

redis.call("expire", key, expire);

--返回key的次数

return 1
java 复制代码
// 测试类
// 为了方便统计当前时间,将时间作为请求参数传入接口
@SpringBootTest(classes = DemoApplication.class)
@AutoConfigureMockMvc
public class RateLimitControllerTest {
    @Autowired
    private WebApplicationContext webApplicationContext;

    private MockMvc mockMvc;
    @BeforeEach
    public void setUp() throws Exception {
        mockMvc = MockMvcBuilders.webAppContextSetup(webApplicationContext).build();
    }

    @Test
    public void test_rate_limit() throws Exception {
        String url = "/rate/test";
        Map<Long, Integer> timeStatusMap = new LinkedHashMap<>();
        
        for (int i = 0; i < 20; i++) {
            Thread.sleep(800);
            long currentTime = System.currentTimeMillis();
            MockHttpServletRequestBuilder builder = MockMvcRequestBuilders.get(url)
                    .param("currentTime", String.valueOf(currentTime)).accept(MediaType.APPLICATION_JSON);
            int status = mockMvc.perform(builder).andReturn().getResponse().getStatus();
            timeStatusMap.put(currentTime, status);
        }

        for (Map.Entry<Long, Integer> entry : timeStatusMap.entrySet()) {
            Long currentTime = entry.getKey();
            int status = entry.getValue();
            int spectedStatus = getStatusOfCurrentTime(currentTime, timeStatusMap.entrySet());
            System.out.println(status + ", " + spectedStatus + ", " + currentTime);
            // assertEquals(status, spectedStatus);
        }
    }

    private int getStatusOfCurrentTime(Long currentTime, Set<Map.Entry<Long, Integer>> set) {
        long startTime = currentTime - 5000;
        int count = 0;
        for (Map.Entry<Long, Integer> entry : set) {
            if (entry.getKey() >= startTime && entry.getKey() < currentTime && entry.getValue() == 200) {
                count++;
            }
        }
        if (count < 5) {
            return 200;
        }
        return 400;
    }
}
java 复制代码
// 接口
@RestController
@RequestMapping("/rate")
public class RateLimitController {
    @GetMapping("/test")
    @RateLimiter(limit = 5, expire = 5, limitType = LimitType.IP)
    public String test(String currentTime) {
        return "h";
    }
}
相关推荐
杨DaB6 小时前
【SpringBoot】Swagger 接口工具
java·spring boot·后端·restful·swagger
昵称为空C7 小时前
SpringBoot接口限流的常用方案
服务器·spring boot
hrrrrb8 小时前
【Java Web 快速入门】十一、Spring Boot 原理
java·前端·spring boot
创码小奇客10 小时前
架构师私藏:SpringBoot 集成 Hera,让日志查看从 “找罪证” 变 “查答案”
spring boot·spring cloud·trae
Olrookie10 小时前
XXL-JOB GLUE模式动态数据源实践:Spring AOP + MyBatis 解耦多库查询
java·数据库·spring boot
waynaqua11 小时前
SpringBoot:听说你还不知道时区设置
spring boot
又是努力搬砖的一年11 小时前
SpringBoot中,接口加解密
java·spring boot·后端
风象南12 小时前
SpringBoot 自研运行时 SQL 调用树,3 分钟定位慢 SQL!
spring boot·后端
海梨花13 小时前
【从零开始学习Redis】项目实战-黑马点评D2
java·数据库·redis·后端·缓存
Q_Q51100828515 小时前
python的软件工程与项目管理课程组学习系统
spring boot·python·django·flask·node.js·php·软件工程