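Building custom gateway filters on spring-cloud-gateway
spring cloud gateway custom starter
A custom non-blocking, reactive gateway service that integrates authentication, rate limiting, enhanced response handling, and more.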

Requirements
```xml
<properties>
    <java.version>17</java.version>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <maven.compiler.target>17</maven.compiler.target>
    <maven.compiler.source>17</maven.compiler.source>
    <java.source.version>17</java.source.version>
    <java.target.version>17</java.target.version>
    <spring-boot.version>3.1.12</spring-boot.version>
    <spring-cloud.version>2022.0.5</spring-cloud.version>
    <commons.pool2.version>2.12.0</commons.pool2.version>
    <redisson.version>3.34.1</redisson.version>
    <com.fastjson.jackson.version>2.17.2</com.fastjson.jackson.version>
    <commons.lang3.version>3.16.0</commons.lang3.version>
    <lombok.version>1.18.32</lombok.version>
</properties>
```

GatewayFilter route filters

- TokenFilterGatewayFilterFactory: authentication handling

Example configuration:

```yaml
spring:
  cloud:
    gateway:
      routes:
        - id: testRoute
          uri: http://127.0.0.1:8080
          predicates:
            - Path=/test/**
          filters:
            - name: TokenFilter
              args:
                requestHeaderKey: auth
```
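
- RLimterGatewayFilterFactory: rate limiting with a custom key. For example, to apply different limits to different users, our own filter factory can set limterKey to a user identifier taken from a request header. This is somewhat easier than overriding spring-cloud-gateway's own KeyResolver and RedisRateLimiter.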
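
Example configuration:

```yaml
spring:
  cloud:
    gateway:
      routes:
        - id: route_test
          uri: http://192.168.1.1:8091
          predicates:
            - Path=/from/requestIds/to/appNames
          filters:
            - name: RLimter
              args:
                limterKey: v1_10 # key that identifies the rate-limit bucket
                rate: 5          # tokens added to the bucket per second (sustained requests per second)
                crust: 0         # bucket capacity, i.e. the burst size; 0 rejects every request
```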
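
- More route filters are in progress...

Implementing the custom gateway filter factories (TokenFilterGatewayFilterFactory and RLimterGatewayFilterFactory), using the custom rate limiter RLimterGatewayFilterFactory as the example:

- First, define the limiter interface used by the filter factory. It takes the rate-limit key, the rate, and the bucket size. The design is adapted from spring-cloud-gateway's built-in rate limiter: it runs a Lua script in Redis that implements the token-bucket algorithm.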
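```java
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import reactor.core.publisher.Mono;

public interface RLimter {

    // limitKey: bucket id; rate: tokens added per second; crust: bucket capacity (burst size)
    Mono<Response> isAllowed(String limitKey, String rate, String crust);

    // Classes nested in an interface are implicitly public and static
    @Setter
    @Getter
    @NoArgsConstructor
    @AllArgsConstructor
    class Response {
        private boolean allowed;
    }
}
```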
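- Next, declare the rate-limiting components as beans: a RedisScript that loads the custom Lua script, and the limiter itself, which uses StringRedisTemplate to run it. Because this whole gateway is packaged as a starter, the limiter implementation RRedisRateLimiter is also registered as a Spring bean here.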
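```java
import java.util.List;

import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.scripting.support.ResourceScriptSource;

@Configuration
public class RLimterAutoConfiguration {

    // Loads script/request_rate_limiter.lua from the classpath; the script returns a list of integers
    @Bean(name = "rredisRequestRateLimiterScript")
    @SuppressWarnings({"rawtypes", "unchecked"})
    public RedisScript<List<Long>> redisRequestRateLimiterScript() {
        DefaultRedisScript redisScript = new DefaultRedisScript();
        redisScript.setScriptSource(new ResourceScriptSource(new ClassPathResource("script/request_rate_limiter.lua")));
        redisScript.setResultType(List.class);
        return redisScript;
    }

    @Bean(name = "rredisRateLimiter")
    public RRedisRateLimiter redisRateLimiter(@Qualifier("rredisRequestRateLimiterScript") RedisScript<List<Long>> redisRequestRateLimiterScript,
                                              StringRedisTemplate redisTemplate) {
        return new RRedisRateLimiter(redisTemplate, redisRequestRateLimiterScript);
    }

    @Bean
    public RLimterGatewayFilterFactory rLimterGatewayFilterFactory(@Qualifier("rredisRateLimiter") RRedisRateLimiter redisRateLimiter) {
        return new RLimterGatewayFilterFactory(redisRateLimiter);
    }
}
```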
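- The script referenced above by new ClassPathResource("script/request_rate_limiter.lua") just goes under the project's resources directory (src/main/resources/script/request_rate_limiter.lua). It is copied verbatim from spring-cloud-gateway's built-in rate-limiting script: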
```lua
redis.replicate_commands()

local tokens_key = KEYS[1]
local timestamp_key = KEYS[2]
-- redis.log(redis.LOG_WARNING, "tokens_key " .. tokens_key)

local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local requested = tonumber(ARGV[4])

local fill_time = capacity / rate
local ttl = math.floor(fill_time * 2)

-- for testing, it should use redis system time in production
if now == nil then
  now = redis.call('TIME')[1]
end

--redis.log(redis.LOG_WARNING, "rate " .. ARGV[1])
--redis.log(redis.LOG_WARNING, "capacity " .. ARGV[2])
--redis.log(redis.LOG_WARNING, "now " .. now)
--redis.log(redis.LOG_WARNING, "requested " .. ARGV[4])
--redis.log(redis.LOG_WARNING, "filltime " .. fill_time)
--redis.log(redis.LOG_WARNING, "ttl " .. ttl)

local last_tokens = tonumber(redis.call("get", tokens_key))
if last_tokens == nil then
  last_tokens = capacity
end
--redis.log(redis.LOG_WARNING, "last_tokens " .. last_tokens)

local last_refreshed = tonumber(redis.call("get", timestamp_key))
if last_refreshed == nil then
  last_refreshed = 0
end
--redis.log(redis.LOG_WARNING, "last_refreshed " .. last_refreshed)

local delta = math.max(0, now-last_refreshed)
local filled_tokens = math.min(capacity, last_tokens+(delta*rate))
local allowed = filled_tokens >= requested
local new_tokens = filled_tokens
local allowed_num = 0
if allowed then
  new_tokens = filled_tokens - requested
  allowed_num = 1
end

--redis.log(redis.LOG_WARNING, "delta " .. delta)
--redis.log(redis.LOG_WARNING, "filled_tokens " .. filled_tokens)
--redis.log(redis.LOG_WARNING, "allowed_num " .. allowed_num)
--redis.log(redis.LOG_WARNING, "new_tokens " .. new_tokens)

if ttl > 0 then
  redis.call("setex", tokens_key, ttl, new_tokens)
  redis.call("setex", timestamp_key, ttl, now)
end

-- return { allowed_num, new_tokens, capacity, filled_tokens, requested, new_tokens }
return { allowed_num, new_tokens }
```
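To make the script's arithmetic easier to follow, here is a rough single-process Java port of the same refill logic. It is only an illustration under simplifying assumptions: state lives in two HashMaps instead of Redis, the current time is passed in as whole seconds (as the script gets it from `TIME`), key expiry is not modeled, and nothing is atomic across processes. The class name `TokenBucketSketch` is hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical in-memory port of request_rate_limiter.lua, for illustration only
class TokenBucketSketch {
    private final Map<String, Double> tokens = new HashMap<>();   // plays the role of KEYS[1]
    private final Map<String, Long> timestamps = new HashMap<>(); // plays the role of KEYS[2]

    // rate = ARGV[1] (tokens/second), capacity = ARGV[2] (the "crust" value),
    // now = ARGV[3] (seconds), requested = ARGV[4] (1 in RRedisRateLimiter)
    boolean isAllowed(String key, double rate, double capacity, long now, int requested) {
        double fillTime = capacity / rate;
        long ttl = (long) Math.floor(fillTime * 2);

        // Missing state means a full bucket and a last refresh at time 0
        double lastTokens = tokens.getOrDefault(key, capacity);
        long lastRefreshed = timestamps.getOrDefault(key, 0L);

        // Refill for the elapsed time, never beyond the bucket capacity
        long delta = Math.max(0, now - lastRefreshed);
        double filledTokens = Math.min(capacity, lastTokens + delta * rate);
        boolean allowed = filledTokens >= requested;
        double newTokens = allowed ? filledTokens - requested : filledTokens;

        // Like the script, persist state only when the computed TTL is positive
        if (ttl > 0) {
            tokens.put(key, newTokens);
            timestamps.put(key, now);
        }
        return allowed;
    }
}
```

Note the `ttl > 0` guard, which mirrors the script: when `floor(2 * capacity / rate)` is 0 (for example rate 5 with crust 1), nothing is stored, so every call starts from a full bucket and is allowed, while crust 0 rejects everything.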
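- Then comes the limiter implementation, which only has to pass the four arguments the Lua script expects: rate (ARGV[1]), capacity (ARGV[2], our crust), now (ARGV[3]) and requested (ARGV[4]). requested is the number of tokens taken from the bucket per request; it is fixed at 1 here and can be left at that default. now is passed as an empty string, so tonumber returns nil and the script falls back to the Redis server time from TIME.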
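```java
import java.util.Arrays;
import java.util.List;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

public class RRedisRateLimiter implements RLimter {

    private static final Logger logger = LoggerFactory.getLogger(RRedisRateLimiter.class);

    private final StringRedisTemplate redisTemplate;
    private final RedisScript<List<Long>> script;

    public RRedisRateLimiter(StringRedisTemplate redisTemplate, RedisScript<List<Long>> script) {
        this.redisTemplate = redisTemplate;
        this.script = script;
    }

    // Same key layout as spring-cloud-gateway: KEYS[1] = tokens, KEYS[2] = timestamp
    static List<String> getKeys(String id) {
        String prefix = "request_rate_limiter.{" + id;
        String tokenKey = prefix + "}.tokens";
        String timestampKey = prefix + "}.timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    @Override
    public Mono<Response> isAllowed(String key, String rate, String crust) {
        List<String> keys = getKeys(key);
        return Mono.fromCallable(() -> {
                    // ARGV: rate, capacity (crust), now ("" -> Redis TIME), requested tokens (1)
                    List<Long> exec = this.redisTemplate.execute(this.script, keys, rate, crust, "", "1");
                    if (exec == null || exec.isEmpty()) {
                        throw new IllegalStateException("rate limiter script returned no result");
                    }
                    return exec.get(0) == 1L;
                })
                // StringRedisTemplate is blocking, so keep it off the Netty event loop
                .subscribeOn(Schedulers.boundedElastic())
                .onErrorResume(throwable -> {
                    // Fail open, as before: if Redis is unavailable, let the request through
                    logger.error("Error calling rate limiter lua", throwable);
                    return Mono.just(true);
                })
                .map(Response::new);
    }
}
```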
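`getKeys` wraps the id in `{...}` because the script touches two keys. In Redis Cluster, a script may only use keys that hash to the same slot, and when a key contains a non-empty `{...}` section (a hash tag) only that section is hashed. The stdlib sketch below computes the slot as the Redis Cluster specification describes (CRC16/XMODEM of the key or hash tag, modulo 16384) to show that both generated keys land in the same slot. `HashSlotSketch` is a hypothetical name; the key format is copied from `getKeys`.

```java
import java.nio.charset.StandardCharsets;

class HashSlotSketch {
    // CRC16-CCITT (XMODEM): polynomial 0x1021, initial value 0
    static int crc16(byte[] bytes) {
        int crc = 0;
        for (byte b : bytes) {
            crc ^= (b & 0xFF) << 8;
            for (int i = 0; i < 8; i++) {
                crc = ((crc & 0x8000) != 0) ? (crc << 1) ^ 0x1021 : crc << 1;
                crc &= 0xFFFF;
            }
        }
        return crc;
    }

    // Hash only the text between the first '{' and the next '}' if it is non-empty
    static int hashSlot(String key) {
        int start = key.indexOf('{');
        if (start >= 0) {
            int end = key.indexOf('}', start + 1);
            if (end > start + 1) {
                key = key.substring(start + 1, end);
            }
        }
        return crc16(key.getBytes(StandardCharsets.UTF_8)) % 16384;
    }

    // Same layout as RRedisRateLimiter.getKeys
    static String[] keysFor(String id) {
        String prefix = "request_rate_limiter.{" + id;
        return new String[] { prefix + "}.tokens", prefix + "}.timestamp" };
    }
}
```

Without the braces, the tokens and timestamp keys would usually hash to different slots and the script would fail with a cross-slot error on a cluster.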
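- Finally, define the custom rate-limiting gateway filter factory itself. It extends spring-cloud-gateway's AbstractGatewayFilterFactory, which lets the gateway load it as a filter factory. The parent class binds the route's args into a configuration class: declare that class (Config2 here) as the generic parameter and pass it to the super constructor.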
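```java
import lombok.Getter;
import lombok.Setter;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.HasRouteId;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.HttpStatus;

public class RLimterGatewayFilterFactory extends AbstractGatewayFilterFactory<RLimterGatewayFilterFactory.Config2> {

    private final RRedisRateLimiter redisRateLimiter;

    public RLimterGatewayFilterFactory(RRedisRateLimiter redisRateLimiter) {
        // Tells the parent which class the route "args" are bound to
        super(Config2.class);
        this.redisRateLimiter = redisRateLimiter;
    }

    @Override
    public GatewayFilter apply(Config2 config) {
        return (exchange, chain) -> redisRateLimiter
                .isAllowed(config.getLimterKey(), config.getRate(), config.getCrust())
                .flatMap(response -> {
                    if (response.isAllowed()) {
                        return chain.filter(exchange);
                    }
                    // Rejected: set the configured status (429 by default) and end the response
                    ServerWebExchangeUtils.setResponseStatus(exchange, config.getStatusCode());
                    return exchange.getResponse().setComplete();
                });
    }

    @Setter
    @Getter
    public static class Config2 implements HasRouteId {
        private String routeId;
        private String limterKey = "default";
        private String rate = "1";   // tokens added per second
        private String crust = "1";  // bucket capacity (burst size)
        private HttpStatus statusCode = HttpStatus.TOO_MANY_REQUESTS;

        @Override
        public String getRouteId() {
            return routeId;
        }

        @Override
        public void setRouteId(String routeId) {
            this.routeId = routeId;
        }
    }
}
```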
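The YAML refers to these factories as `RLimter` and `TokenFilter` rather than by their class names because spring-cloud-gateway derives a filter's configuration name from the factory's simple class name with the `GatewayFilterFactory` suffix removed (the rule lives in its `NameUtils.normalizeFilterFactoryName`). The helper below is a simplified, hypothetical stand-in for that rule, handy for checking which `name:` a new factory will answer to; it ignores any edge cases the real utility may handle.

```java
class FilterNameSketch {
    private static final String SUFFIX = "GatewayFilterFactory";

    // Simplified mimic of the naming convention: drop the "GatewayFilterFactory" suffix
    static String filterName(String factorySimpleName) {
        return factorySimpleName.replace(SUFFIX, "");
    }
}
```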
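- This is only a simple example. For truly custom limiting, the filter factory still needs to extract the relevant key from the request through exchange and combine it with your business rules. For configuration, see the YAML example near the top. As a quick test, set the bucket capacity (crust) to 0: every request is then rejected with status 429 Too Many Requests.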
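As a rough illustration of the per-user key idea, the sketch below builds the limiter key from the configured `limterKey` plus a user identifier read from a request header, falling back to one shared bucket when the header is missing. Inside the real factory the value would come from something like `exchange.getRequest().getHeaders().getFirst(...)`; here a plain `Map` stands in for the headers so the logic runs on its own. The header name `X-User-Id`, the `:` separator and the class name `PerUserKeySketch` are all hypothetical choices.

```java
import java.util.Map;

class PerUserKeySketch {
    static final String ANONYMOUS = "anonymous";

    // Hypothetical: combine the route-level key with a user id so each user gets their own bucket
    static String limiterKey(String baseKey, Map<String, String> headers, String userHeader) {
        String userId = headers.get(userHeader);
        if (userId == null || userId.isBlank()) {
            userId = ANONYMOUS; // unidentified callers share one bucket
        }
        return baseKey + ":" + userId.trim();
    }
}
```

Passing the result to `redisRateLimiter.isAllowed(...)` instead of `config.limterKey` would give each user separate `request_rate_limiter.{...}` keys. HTTP header names are case-insensitive: Spring's `HttpHeaders` handles that, a plain `Map` does not.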

Global filter example
- First, define a context object for the request:
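```java
import lombok.Getter;
import lombok.Setter;

// Per-request data shared through the Reactor Context
@Setter
@Getter
public class RequestContext {
    private String requestId;
    private long requestStartTime;
    private String requestIp;
}
```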
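- The filter itself. It implements WebFlux's WebFilter (not spring-cloud-gateway's GlobalFilter), so it runs for every request the application handles. It logs the start and end of each request, returns a requestId response header, and stores the RequestContext in the Reactor Context so that everything inside chain.filter(...) can read it.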
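```java
import java.util.UUID;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;

// SystemClock, OrderConstant and IpUtils are project-specific helpers (not shown)
@Component
public class RequestContextFilter implements WebFilter, Ordered {

    private static final Logger logger = LoggerFactory.getLogger(RequestContextFilter.class);

    @Override
    public int getOrder() {
        return OrderConstant.REQUEST_CONTEXT_ORDER;
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        long requestStartTime = SystemClock.now();
        String requestId = generateRequestId();
        ServerHttpRequest request = exchange.getRequest();
        String uri = request.getURI().getRawPath();
        String requestIp = IpUtils.tryGetRealIp(request);

        // Echo the id to the client; response headers are still writable before the chain runs
        exchange.getResponse().getHeaders().add("requestId", requestId);
        logger.info("request start, requestId:{}, requestUri:{}, ip:{}, requestStartTime:{}", requestId, uri, requestIp, requestStartTime);

        RequestContext context = new RequestContext();
        context.setRequestId(requestId);
        context.setRequestStartTime(requestStartTime);
        context.setRequestIp(requestIp);

        return chain.filter(exchange)
                // Visible to every filter and handler inside chain.filter(...)
                .contextWrite(Context.of(RequestContext.class, context))
                .doOnEach(signal -> {
                    long requestEndTime = SystemClock.now();
                    if (signal.isOnComplete()) {
                        logger.info("request end, requestId:{}, response:{}, requestEndTime:{}, elapsedMs:{}", requestId, exchange.getResponse().getStatusCode(), requestEndTime, requestEndTime - requestStartTime);
                    } else if (signal.isOnError()) {
                        logger.error("request end, requestId:{}, error:{}, requestEndTime:{}, elapsedMs:{}", requestId, signal.getThrowable(), requestEndTime, requestEndTime - requestStartTime);
                    }
                });
    }

    private String generateRequestId() {
        // 32-character hex id; String.replace avoids compiling a regex
        return UUID.randomUUID().toString().replace("-", "");
    }
}
```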
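`IpUtils.tryGetRealIp` is a project helper that is not shown here. A common approach, sketched below as an assumption rather than the project's actual code, is to prefer the first entry of `X-Forwarded-For`, then `X-Real-IP`, then the socket's remote address. Headers are passed as a plain `Map` to keep the sketch self-contained, and `ClientIpSketch` is a hypothetical name. Clients can set these headers themselves, so they should only be trusted when the gateway sits behind a proxy that overwrites them.

```java
import java.util.Map;

class ClientIpSketch {
    // Hypothetical helper: X-Forwarded-For, then X-Real-IP, then the remote address
    static String clientIp(Map<String, String> headers, String remoteAddress) {
        String forwarded = headers.get("X-Forwarded-For");
        if (isUsable(forwarded)) {
            // Format is "client, proxy1, proxy2": the left-most entry is the original client
            return forwarded.split(",")[0].trim();
        }
        String realIp = headers.get("X-Real-IP");
        if (isUsable(realIp)) {
            return realIp.trim();
        }
        return remoteAddress;
    }

    private static boolean isUsable(String value) {
        return value != null && !value.isBlank() && !"unknown".equalsIgnoreCase(value.trim());
    }
}
```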