redis+lua实现分布式限流
文章目录
为什么使用redis+lua实现分布式限流
- 原子性:通过Lua脚本执行限流逻辑,所有操作在一个原子上下文中完成,避免了多步操作导致的并发问题。
- 灵活性:Lua脚本可以编写复杂的逻辑,比如滑动窗口限流,易于扩展和定制化。
- 性能:由于所有逻辑在Redis服务器端执行,减少了网络往返,提高了执行效率。
使用ZSET也可以实现限流,为什么选择lua的方式
使用zset需要额度解决这些问题
- 并发控制:需要额外的逻辑来保证操作的原子性和准确性,可能需要配合Lua脚本或Lua脚本+WATCH/MULTI/EXEC模式来实现。
- 资源消耗:长期存储请求记录可能导致Redis占用更多的内存资源。
为什么redis+zset不能保证原子性和准确性
- 多步骤操作:滑动窗口限流通常需要执行多个步骤,比如检查当前窗口的请求次数、添加新的请求记录、可能还需要删除过期的请求记录等。这些操作如果分开执行,就有可能在多线程或多进程环境下出现不一致的情况。
- 非原子性复合操作:虽然单个Redis命令是原子的,但当你需要执行一系列操作来维持限流状态时(例如,先检查计数、再增加计数、最后可能还要删除旧记录),没有一个单一的Redis命令能完成这些复合操作。如果在这系列操作之间有其他客户端修改了数据,就会导致限流不准确。
- 竞争条件:在高并发环境下,多个客户端可能几乎同时执行限流检查和增加请求的操作,如果没有适当的同步机制,可能会导致请求计数错误。
实现
依赖
xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.2.6.RELEASE</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.kang</groupId>
<artifactId>rate-limiter-project</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>rate-limiter-project</name>
<description>rate-limiter-project</description>
<properties>
<java.version>8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
<version>2.6.2</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>31.0.1-jre</version> <!-- 请检查最新版本 -->
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
lua脚本
lua
-- KEYS[1] 是Redis中存储计数的key,,,
local key = KEYS[1]
-- ARGV[1]是当前时间戳-[当前时间戳]
local now = tonumber(ARGV[1])
-- ARGV[2]是最大请求次数-[最大请求次数]
local maxRequests = tonumber(ARGV[2])
-- ARGV[3]是时间窗口长度-[时间窗口长度]
local windowSize = tonumber(ARGV[3])
-- 获取当前时间窗口的起始时间
local windowStart = math.floor(now / windowSize) * windowSize
-- 构建时间窗口内的key,用于区分不同窗口的计数
local windowKey = key .. ':' .. tostring(windowStart)
-- 获取当前窗口的计数
local currentCount = tonumber(redis.call('get', windowKey) or '0')
-- 如果当前时间不在窗口内,重置计数
if now > windowStart + windowSize then
redis.call('del', windowKey)
currentCount = 0
end
-- 检查是否超过限制
if currentCount + 1 <= maxRequests then
-- 未超过,增加计数并返回成功,并设置键的过期时间为窗口剩余时间,以自动清理过期数据。如果超过最大请求次数,则拒绝请求
redis.call('set', windowKey, currentCount + 1, 'EX', windowSize - (now - windowStart))
return 1 -- 成功
else
return 0 -- 失败
end
yaml
yaml
server:
port: 10086
spring:
redis:
host: 127.0.0.1
port: 6379
database: 0
lettuce:
pool:
max-active: 20
max-idle: 10
min-idle: 5
代码实现
启动类
java
package com.kang.limter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@Slf4j
@SpringBootApplication
public class RateLimiterProjectApplication {
public static void main(String[] args) {
SpringApplication.run(RateLimiterProjectApplication.class, args);
log.info("RateLimiterProjectApplication start success");
}
}
CacheConfig
java
package com.kang.limter.cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.kang.limter.utils.LuaScriptUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static com.kang.limter.constant.SystemConstant.REDIS_RATE_LIMITER_LUA_SCRIPT_PATH;
/**
* @Author Emperor Kang
* @ClassName CacheConfig
* @Description 缓存配置
* @Date 2024/6/13 10:07
* @Version 1.0
* @Motto 让营地比你来时更干净
*/
@Slf4j
@Configuration
public class CacheConfig {
/**
* 缓存配置,加载lua脚本
* @return
*/
@Bean(name = "rateLimiterLuaCache")
public LoadingCache<String, String> rateLimiterLuaCache() {
LoadingCache<String, String> cache = CacheBuilder.newBuilder()
// 设置缓存的最大容量,最多100个键值对
.maximumSize(100)
// 设置缓存项过期策略:写入后2小时过期
.expireAfterWrite(2, TimeUnit.HOURS)
// 缓存统计信息记录
.recordStats()
// 构建缓存加载器,用于加载缓存项的值
.build(new CacheLoader<String, String>() {
@Override
public String load(String scriptPath) throws Exception {
try {
return LuaScriptUtils.loadLuaScript(scriptPath);
} catch (Exception e) {
log.error("加载lua脚本失败:{}", e.getMessage());
return null;
}
}
});
// 预热缓存
warmUpCache(cache);
return cache;
}
/**
* 预热缓存
*/
private void warmUpCache(LoadingCache<String, String> cache) {
try {
// 假设我们有一个已知的脚本列表需要预热
List<String> knownScripts = Collections.singletonList(REDIS_RATE_LIMITER_LUA_SCRIPT_PATH);
for (String script : knownScripts) {
String luaScript = LuaScriptUtils.loadLuaScript(script);
// 手动初始化缓存
cache.put(script, luaScript);
log.info("预加载Lua脚本成功: {}, length: {}", script, luaScript.length());
}
} catch (Exception e) {
log.error("预加载Lua脚本失败: {}", e.getMessage(), e);
}
}
}
- 这里使用缓存预热加快lua脚本的加载速度,基于JVM内存操作,所以很快
SystemConstant
java
package com.kang.limter.constant;
/**
* @Author Emperor Kang
* @ClassName SystemConstant
* @Description 系统常量
* @Date 2024/6/12 19:25
* @Version 1.0
* @Motto 让营地比你来时更干净
*/
public class SystemConstant {
/**
* 限流配置缓存key前缀
*/
public static final String REDIS_RATE_LIMITER_KEY_PREFIX = "outreach:config:limiter:%s";
/**
* 限流lua脚本路径
*/
public static final String REDIS_RATE_LIMITER_LUA_SCRIPT_PATH = "classpath:lua/rate_limiter.lua";
}
RateLimiterController
java
package com.kang.limter.controller;
import com.kang.limter.dto.RateLimiterRequestDto;
import com.kang.limter.utils.RateLimiterUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import static java.lang.Thread.sleep;
/**
* @Author Emperor Kang
* @ClassName RateLimiterController
* @Description TODO
* @Date 2024/6/12 19:33
* @Version 1.0
* @Motto 让营地比你来时更干净
*/
@Slf4j
@RestController
@RequestMapping("/rate/limiter")
public class RateLimiterController {
@Autowired
private RateLimiterUtil rateLimiterUtil;
@PostMapping("/test")
public String test(@RequestBody RateLimiterRequestDto rateLimiterRequestDto) {
// 是否限流
if (!rateLimiterUtil.tryAcquire(rateLimiterRequestDto.getInterfaceCode(), 5, 1000)) {
log.info("触发限流策略,InterfaceCode:{}", rateLimiterRequestDto.getInterfaceCode());
return "我被限流了InterfaceCode:" + rateLimiterRequestDto.getInterfaceCode();
}
log.info("请求参数:{}", rateLimiterRequestDto);
try {
log.info("开始加工逻辑");
sleep(1000);
} catch (InterruptedException e) {
log.error("休眠异常");
Thread.currentThread().interrupt();
return "加工异常";
}
return "加工成功,成功返回";
}
}
RateLimiterRequestDto
java
package com.kang.limter.dto;
import lombok.Data;
/**
* @Author Emperor Kang
* @ClassName RateLimiterRequestDto
* @Description TODO
* @Date 2024/6/12 19:39
* @Version 1.0
* @Motto 让营地比你来时更干净
*/
@Data
public class RateLimiterRequestDto {
/**
* 接口编码
*/
private String interfaceCode;
}
ResourceLoaderException
java
package com.kang.limter.exception;
/**
* @Author Emperor Kang
* @ClassName ResourceLoaderException
* @Description 自定义资源加载异常
* @Date 2024/6/12 18:10
* @Version 1.0
* @Motto 让营地比你来时更干净
*/
public class ResourceLoaderException extends Exception{
public ResourceLoaderException() {
super();
}
public ResourceLoaderException(String message) {
super(message);
}
public ResourceLoaderException(String message, Throwable cause) {
super(message, cause);
}
public ResourceLoaderException(Throwable cause) {
super(cause);
}
protected ResourceLoaderException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
super(message, cause, enableSuppression, writableStackTrace);
}
}
LuaScriptUtils
java
package com.kang.limter.utils;
import com.kang.limter.exception.ResourceLoaderException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.util.Assert;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
@Slf4j
public class LuaScriptUtils {
/**
* 从类路径下读取Lua脚本内容。
* @param scriptPath 类路径下的Lua脚本文件路径
* @return Lua脚本的文本内容
*/
public static String loadLuaScript(String scriptPath) throws ResourceLoaderException {
Assert.notNull(scriptPath, "script path must not be null");
try {
// 读取lua脚本
ResourceLoader resourceLoader = new DefaultResourceLoader();
Resource resource = resourceLoader.getResource(scriptPath);
try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream(), StandardCharsets.UTF_8))) {
StringBuilder scriptBuilder = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
scriptBuilder.append(line).append("\n");
}
String lua = scriptBuilder.toString();
log.debug("读取的lua脚本为: {}", lua);
return lua;
}
} catch (Exception e) {
log.error("Failed to load Lua script from path: {}", scriptPath, e);
throw new ResourceLoaderException("Failed to load Lua script from path: " + scriptPath, e);
}
}
}
RateLimiterUtil
java
package com.kang.limter.utils;
import com.google.common.cache.LoadingCache;
import com.kang.limter.exception.ResourceLoaderException;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.data.redis.connection.ReturnType;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import java.nio.charset.StandardCharsets;
import static com.kang.limter.constant.SystemConstant.REDIS_RATE_LIMITER_KEY_PREFIX;
import static com.kang.limter.constant.SystemConstant.REDIS_RATE_LIMITER_LUA_SCRIPT_PATH;
/**
* @Author Emperor Kang
* @ClassName RateLimiterUtil
* @Description 限流工具类
* @Date 2024/6/12 17:56
* @Version 1.0
* @Motto 让营地比你来时更干净
*/
@Slf4j
@Component
public class RateLimiterUtil {
@Autowired
private StringRedisTemplate redisTemplate;
@Autowired
@Qualifier("rateLimiterLuaCache")
private LoadingCache<String, String> rateLimiterLuaCache;
/**
* @param interfaceCode 接口标识
* @param maxRequests 最大请求数
* @param windowSizeMs 窗口大小
* @return boolean
* @Description 尝试获取令牌
* @Author Emperor Kang
* @Date 2024/6/12 17:57
* @Version 1.0
*/
public boolean tryAcquire(String interfaceCode, int maxRequests, long windowSizeMs) {
try {
long currentTimeMillis = System.currentTimeMillis();
String luaScript = rateLimiterLuaCache.get(REDIS_RATE_LIMITER_LUA_SCRIPT_PATH);
log.info("缓存查询lua,length={}", luaScript.length());
if(StringUtils.isBlank(luaScript)){
log.info("从缓存中未获取到lua脚本,尝试手动读取");
luaScript = LuaScriptUtils.loadLuaScript(REDIS_RATE_LIMITER_LUA_SCRIPT_PATH);
}
// 二次确认
if(StringUtils.isBlank(luaScript)){
log.info("lua脚本加载失败,暂时放弃获取许可,不再限流");
return true;
}
// 限流核心逻辑
String finalLuaScript = luaScript;
Long result = redisTemplate.execute((RedisCallback<Long>) connection -> {
// 用于存储的key
byte[] key = String.format(REDIS_RATE_LIMITER_KEY_PREFIX, interfaceCode).getBytes(StandardCharsets.UTF_8);
// 当前时间(毫秒)
byte[] now = String.valueOf(currentTimeMillis).getBytes(StandardCharsets.UTF_8);
// 最大请求数
byte[] maxRequestsBytes = String.valueOf(maxRequests).getBytes(StandardCharsets.UTF_8);
// 窗口大小
byte[] windowSizeBytes = String.valueOf(windowSizeMs).getBytes(StandardCharsets.UTF_8);
// 执行lua脚本
return connection.eval(finalLuaScript.getBytes(StandardCharsets.UTF_8), ReturnType.INTEGER, 1, key, now, maxRequestsBytes, windowSizeBytes);
});
Assert.notNull(result, "执行lua脚本响应结果为null");
// 获取结果
return result == 1L;
} catch (ResourceLoaderException e) {
log.error("加载lua脚本失败", e);
} catch (Exception e){
log.error("执行限流逻辑异常", e);
}
return true;
}
}
lua脚本
lua
-- KEYS[1] 是Redis中存储计数的key,,,
local key = KEYS[1]
-- ARGV[1]是当前时间戳-[当前时间戳]
local now = tonumber(ARGV[1])
-- ARGV[2]是最大请求次数-[最大请求次数]
local maxRequests = tonumber(ARGV[2])
-- ARGV[3]是时间窗口长度-[时间窗口长度]
local windowSize = tonumber(ARGV[3])
-- 获取当前时间窗口的起始时间
local windowStart = math.floor(now / windowSize) * windowSize
-- 构建时间窗口内的key,用于区分不同窗口的计数
local windowKey = key .. ':' .. tostring(windowStart)
-- 获取当前窗口的计数
local currentCount = tonumber(redis.call('get', windowKey) or '0')
-- 如果当前时间不在窗口内,重置计数
if now > windowStart + windowSize then
redis.call('del', windowKey)
currentCount = 0
end
-- 检查是否超过限制
if currentCount + 1 <= maxRequests then
-- 未超过,增加计数并返回成功,并设置键的过期时间为窗口剩余时间,以自动清理过期数据。如果超过最大请求次数,则拒绝请求
redis.call('set', windowKey, currentCount + 1, 'EX', windowSize - (now - windowStart))
return 1 -- 成功
else
return 0 -- 失败
end
Jmeter压测
- 200次请求/s,限流了195,而我们设置的最大令牌数就是5