Java在SpringCloud中自定义Gateway负载均衡策略

Java在SpringCloud中自定义Gateway负载均衡策略

一、前言

spring-cloud-starter-netflix-ribbon已经不再更新了,最新版本是2.2.10.RELEASE,最后更新时间是2021年11月18日,详细信息可以看maven官方仓库:org.springframework.cloud/spring-cloud-starter-netflix-ribbon,SpringCloud官方推荐使用spring-cloud-starter-loadbalancer进行负载均衡。

背景:大文件上传做切片文件上传;

流程:将切片文件上传到服务器,然后进行合并任务,合并完成之后上传到对象存储;现在服务搞成多节点以后,网关默认走轮循,但是相同的服务在不同的机器上,这样就会导致切片文件散落在不同的服务器上,会导致文件合并失败;所以根据一个标识去自定义gateway对应服务的负载均衡策略,可以解决这个问题;

我的版本如下:

<spring-boot.version>2.7.3</spring-boot.version>

<spring-cloud.version>2021.0.4</spring-cloud.version>

<spring-cloud-alibaba.version>2021.0.4.0</spring-cloud-alibaba.version>

二、参考默认实现

springCloud原生默认的负载均衡策略是这个类:

org.springframework.cloud.loadbalancer.core.RoundRobinLoadBalancer

我们参考这个类实现自己的负载均衡策略即可,RoundRobinLoadBalancer实现了ReactorServiceInstanceLoadBalancer这个接口,实现了choose这个方法,如下图:

在choose方法中调用了processInstanceResponse方法,processInstanceResponse方法中调用了getInstanceResponse方法,所以我们我们可以复制RoundRobinLoadBalancer整个类,只修改getInstanceResponse这个方法里的内容就可以实现自定义负载均衡策略。

三、实现代码

原理:根据请求头当中设备的唯一标识传递到下游,唯一标识做哈希取余,可以指定对应的服务器节点,需要的服务设置自定义负载策略,不需要的服务设置默认的轮循机制即可.我这里是根据单独的接口请求地址去自定义,也可以根据服务名称自定义

复制代码
package com.wondertek.gateway.loadBalancer;

import cn.hutool.core.util.ObjectUtil;
import com.wondertek.web.exception.enums.HttpRequestHeaderEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
@Slf4j
@Component
public class RequestFilter implements GlobalFilter, Ordered {
    @Override
    public int getOrder() {
        // 应该小于LoadBalancerClientFilter的顺序值
        return Ordered.HIGHEST_PRECEDENCE;
    }
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String clientDeviceUniqueCode = request.getHeaders().getFirst(HttpRequestHeaderEnum.CLIENT_DEVICE_UNIQUE_CODE.getCode());
        // 存入Reactor上下文
        String resultCode = clientDeviceUniqueCode;

        //路径
        String pathUrl = request.getURI().getPath();
        /**
         * ^ 锚点匹配输入字符串的开始位置。
         *  /(oms-api|unity-api|cloud-api) 匹配以 /oms-api 或 /unity-api 或 /cloud-api 开始的任何字符串。
         *  replaceFirst() 方法用空字符串替换第一次匹配的内容,也就是我们想要去掉的服务名称。
         */
        String resultPathUrl = pathUrl.replaceFirst("^/(oms-api|unity-api|cloud-api)", "");
        return chain.filter(exchange)
                .contextWrite(context -> {
                    if (ObjectUtil.isNotEmpty(resultCode) && ObjectUtil.isNotEmpty(resultPathUrl)) {
                        log.info("开始将request中的唯一标识封装到上下游中:{},请求path是:{}", resultCode, resultPathUrl);
                        return context.put("identification", resultCode).put("pathUrl", resultPathUrl);
                    } else {
                        //根据需求进行其他处理
                        return context;
                    }
                });
    }
}

package com.wondertek.gateway.loadBalancer;

import cn.hutool.core.util.ObjectUtil;
import com.wondertek.center.constants.BusinessCenterApi;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
public class ClientDeviceUniqueCodeInstanceLoadBalancer implements ReactorServiceInstanceLoadBalancer {

    private final String serviceId;
    final AtomicInteger position;
    private ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;


    public ClientDeviceUniqueCodeInstanceLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, String serviceId, AtomicInteger position) {
        this.serviceId = serviceId;
        this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
        this.position = position;
    }

    @Override
    public Mono<Response<ServiceInstance>> choose(Request request) {
        //在 choose 方法中,使用 deferContextual 方法来访问上下文并提取客户端标识。这里的 getOrDefault 方法尝试从上下文中获取一个键为 "identification" 的值,如果不存在则返回 "default-identification"
        return Mono.deferContextual(contextView -> {
            String identification = contextView.getOrDefault("identification", "");
            log.info("上下游获取到的identification的值为:{}", identification);

            String pathUrl = contextView.getOrDefault("pathUrl", "");
            log.info("上下游获取到的pathUrl的值为:{}", pathUrl);

            ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider
                    .getIfAvailable(NoopServiceInstanceListSupplier::new);
            return supplier.get(request).next()
                    .map(serviceInstances -> processInstanceResponse(supplier, serviceInstances, identification, pathUrl));
        });
    }

    private Response<ServiceInstance> processInstanceResponse(ServiceInstanceListSupplier supplier, List<ServiceInstance> serviceInstances, String identification, String pathUrl) {
        Response<ServiceInstance> serviceInstanceResponse;
        //特定接口走自定义负载策略
        Boolean status = ObjectUtil.isNotEmpty(identification) && ObjectUtil.isNotEmpty(pathUrl) &&
                (pathUrl.contains(BusinessCenterApi.WEB_UPLOAD_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.WEB_MERGE_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.UNITY_UPLOAD_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.UNITY_MERGE_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.CLOUD_UPLOAD_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.CLOUD_MERGE_SLICE_FILE));

        if (status) {
            serviceInstanceResponse = this.getIpInstanceResponse(serviceInstances, identification);
        } else {
            serviceInstanceResponse = this.getInstanceResponse(serviceInstances);
        }
        if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
            ((SelectedInstanceCallback) supplier).selectedServiceInstance((ServiceInstance) serviceInstanceResponse.getServer());
        }
        return serviceInstanceResponse;
    }

    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances) {
        if (instances.isEmpty()) {
            if (log.isWarnEnabled()) {
                log.warn("No servers available for service: " + this.serviceId);
            }
            return new EmptyResponse();
        } else {
            //创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题
            List<ServiceInstance> sortedInstances = new ArrayList<>(instances);
            // 现在对新列表进行排序,保持原始列表的顺序不变
            Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));
            //log.info("获取到的实例个数的值为:{}", sortedInstances.size());
            sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));
            int pos = Math.abs(this.position.incrementAndGet());
            //log.info("默认轮循机制,pos递加后的值为:{}", pos);
            int positionIndex = pos % instances.size();
            //log.info("取余后的positionIndex的值为:{}", positionIndex);
            ServiceInstance instance = instances.get(positionIndex);
            //log.info("instance.getUri()的值为:{}", instance.getUri());
            log.info("特殊服务,默认轮循机制,routed to instance: {}:{}", instance.getHost(), instance.getPort());
            return new DefaultResponse(instance);
        }
    }

    private Response<ServiceInstance> getIpInstanceResponse(List<ServiceInstance> instances, String identification) {
        if (instances.isEmpty()) {
            log.warn("No servers available for service: " + this.serviceId);
            return new EmptyResponse();
        } else {
            //创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题
            List<ServiceInstance> sortedInstances = new ArrayList<>(instances);
            // 现在对新列表进行排序,保持原始列表的顺序不变
            Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));
            //log.info("获取到的实例个数的值为:{}", sortedInstances.size());
            sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));
            //log.info("多个服务实例,使用客户端 identification 地址的哈希值来选择服务实例");
            // 使用排序后的列表来找到实例
            int ipHashCode = Math.abs(identification.hashCode());
            //log.info("identificationHashCode的值为:{}", ipHashCode);
            int instanceIndex = ipHashCode % sortedInstances.size();
            //log.info("instanceIndex的值为:{}", instanceIndex);
            ServiceInstance instanceToReturn = sortedInstances.get(instanceIndex);
            //log.info("instanceToReturn.getUri()的值为:{}", instanceToReturn.getUri());
            log.info("特殊服务,自定义identification负载机制,Client identification: {} is routed to instance: {}:{}", identification, instanceToReturn.getHost(), instanceToReturn.getPort());
            return new DefaultResponse(instanceToReturn);
        }
    }

}

package com.wondertek.gateway.loadBalancer;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
public class DefaultInstanceLoadBalancer implements ReactorServiceInstanceLoadBalancer {

    private final String serviceId;
    private ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;
    final AtomicInteger position;

    public DefaultInstanceLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, String serviceId, AtomicInteger position) {
        this.serviceId = serviceId;
        this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
        this.position = position;
    }
    @Override
    public Mono<Response<ServiceInstance>> choose(Request request) {
        ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider
                .getIfAvailable(NoopServiceInstanceListSupplier::new);
        return supplier.get(request).next()
                .map(serviceInstances -> processInstanceResponse(supplier, serviceInstances));
    }

    private Response<ServiceInstance> processInstanceResponse(ServiceInstanceListSupplier supplier,
                                                              List<ServiceInstance> serviceInstances) {
        Response<ServiceInstance> serviceInstanceResponse = getInstanceResponse(serviceInstances);
        if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
            ((SelectedInstanceCallback) supplier).selectedServiceInstance(serviceInstanceResponse.getServer());
        }
        return serviceInstanceResponse;
    }

    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances) {
        if (instances.isEmpty()) {
            if (log.isWarnEnabled()) {
                log.warn("No servers available for service: " + serviceId);
            }
            return new EmptyResponse();
        }
        //创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题
        List<ServiceInstance> sortedInstances = new ArrayList<>(instances);
        // 现在对新列表进行排序,保持原始列表的顺序不变
        Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));
        //log.info("获取到的实例个数的值为:{}", sortedInstances.size());
        sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));
        int pos = Math.abs(this.position.incrementAndGet());
        //log.info("默认轮循机制,pos递加后的值为:{}", pos);
        int positionIndex = pos % instances.size();
        //log.info("取余后的positionIndex的值为:{}", positionIndex);
        ServiceInstance instance = instances.get(positionIndex);
        //log.info("instance.getUri()的值为:{}", instance.getUri());
        log.info("默认轮循机制,routed to instance: {}:{}",instance.getHost(), instance.getPort());
        return new DefaultResponse(instance);
    }

}

package com.wondertek.gateway.loadBalancer;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClients;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;

import java.util.concurrent.atomic.AtomicInteger;

@Configuration
//单台服务
//@LoadBalancerClient(name = "oms-api", configuration = CustomLoadBalancerConfig.class)
//多台服务
@LoadBalancerClients({
        @LoadBalancerClient(name = "oms-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "unity-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "cloud-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "open-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "server-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "center-service", configuration = CustomLoadBalancerConfig.class),
})
@Slf4j
public class CustomLoadBalancerConfig {
    // 定义一个Bean来提供AtomicInteger的实例
    @Bean
    public AtomicInteger positionTracker() {
        // 这将在应用上下文中只初始化一次
        return new AtomicInteger(0);
    }

    //自定义优先级负载均衡器
    @Bean
    public ReactorServiceInstanceLoadBalancer customPriorityLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
                                                                     Environment environment,AtomicInteger positionTracker) {
        String serviceId = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        //目的为解决文件上传切片文件分散上传的问题
        if ("oms-api".equals(serviceId)||"unity-api".equals(serviceId)||"cloud-api".equals(serviceId)){
            //log.info("服务名称:serviceId:{},走自定义clientDeviceUniqueCode负载模式", serviceId);
            return new ClientDeviceUniqueCodeInstanceLoadBalancer(serviceInstanceListSupplierProvider, serviceId, positionTracker);
        }
        //log.info("服务名称:serviceId:{},走默认负载模式", serviceId);
        return new DefaultInstanceLoadBalancer(serviceInstanceListSupplierProvider, serviceId,positionTracker);
    }
}

【SpringCloud系列】开发环境下重写Loadbalancer实现自定义负载均衡

相关推荐
创码小奇客4 分钟前
MongoDB 事务:数据世界的守护者联盟全解析
spring boot·mongodb·trae
AI的魔盒5 分钟前
基于Java与MAVLink协议的多无人机(Cube飞控)集群控制与调度方案问题
java·开发语言·无人机
北执南念42 分钟前
项目代码生成工具
java
中国lanwp1 小时前
springboot logback 默认加载配置文件顺序
java·spring boot·logback
cherishSpring1 小时前
在windows使用docker打包springboot项目镜像并上传到阿里云
spring boot·docker·容器
苹果酱05671 小时前
【Azure Redis 缓存】在Azure Redis中,如何限制只允许Azure App Service访问?
java·vue.js·spring boot·mysql·课程设计
Java致死2 小时前
单例设计模式
java·单例模式·设计模式
胡子发芽2 小时前
请详细解释Java中的线程池(ThreadPoolExecutor)的工作原理,并说明如何自定义线程池的拒绝策略
java
沫夕残雪2 小时前
Tomcat的安装与配置
java·tomcat
胡子发芽2 小时前
请解释Java中的NIO(New I/O)与传统I/O的区别,并说明NIO中的关键组件及其作用
java