Java在SpringCloud中自定义Gateway负载均衡策略

Java在SpringCloud中自定义Gateway负载均衡策略

一、前言

spring-cloud-starter-netflix-ribbon已经不再更新了,最新版本是2.2.10.RELEASE,最后更新时间是2021年11月18日,详细信息可以看maven官方仓库:org.springframework.cloud/spring-cloud-starter-netflix-ribbon,SpringCloud官方推荐使用spring-cloud-starter-loadbalancer进行负载均衡。

背景:大文件上传做切片文件上传;

流程:将切片文件上传到服务器,然后进行合并任务,合并完成之后上传到对象存储;现在服务搞成多节点以后,网关默认走轮循,但是相同的服务在不同的机器上,这样就会导致切片文件散落在不同的服务器上,会导致文件合并失败;所以根据一个标识去自定义gateway对应服务的负载均衡策略,可以解决这个问题;

我的版本如下:

<spring-boot.version>2.7.3</spring-boot.version>

<spring-cloud.version>2021.0.4</spring-cloud.version>

<spring-cloud-alibaba.version>2021.0.4.0</spring-cloud-alibaba.version>

二、参考默认实现

springCloud原生默认的负载均衡策略是这个类:

org.springframework.cloud.loadbalancer.core.RoundRobinLoadBalancer

我们参考这个类实现自己的负载均衡策略即可,RoundRobinLoadBalancer实现了ReactorServiceInstanceLoadBalancer这个接口,实现了choose这个方法,如下图:

在choose方法中调用了processInstanceResponse方法,processInstanceResponse方法中调用了getInstanceResponse方法,所以我们我们可以复制RoundRobinLoadBalancer整个类,只修改getInstanceResponse这个方法里的内容就可以实现自定义负载均衡策略。

三、实现代码

原理:根据请求头当中设备的唯一标识传递到下游,唯一标识做哈希取余,可以指定对应的服务器节点,需要的服务设置自定义负载策略,不需要的服务设置默认的轮循机制即可.我这里是根据单独的接口请求地址去自定义,也可以根据服务名称自定义

package com.wondertek.gateway.loadBalancer;

import cn.hutool.core.util.ObjectUtil;
import com.wondertek.web.exception.enums.HttpRequestHeaderEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
@Slf4j
@Component
public class RequestFilter implements GlobalFilter, Ordered {
    @Override
    public int getOrder() {
        // 应该小于LoadBalancerClientFilter的顺序值
        return Ordered.HIGHEST_PRECEDENCE;
    }
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        String clientDeviceUniqueCode = request.getHeaders().getFirst(HttpRequestHeaderEnum.CLIENT_DEVICE_UNIQUE_CODE.getCode());
        // 存入Reactor上下文
        String resultCode = clientDeviceUniqueCode;

        //路径
        String pathUrl = request.getURI().getPath();
        /**
         * ^ 锚点匹配输入字符串的开始位置。
         *  /(oms-api|unity-api|cloud-api) 匹配以 /oms-api 或 /unity-api 或 /cloud-api 开始的任何字符串。
         *  replaceFirst() 方法用空字符串替换第一次匹配的内容,也就是我们想要去掉的服务名称。
         */
        String resultPathUrl = pathUrl.replaceFirst("^/(oms-api|unity-api|cloud-api)", "");
        return chain.filter(exchange)
                .contextWrite(context -> {
                    if (ObjectUtil.isNotEmpty(resultCode) && ObjectUtil.isNotEmpty(resultPathUrl)) {
                        log.info("开始将request中的唯一标识封装到上下游中:{},请求path是:{}", resultCode, resultPathUrl);
                        return context.put("identification", resultCode).put("pathUrl", resultPathUrl);
                    } else {
                        //根据需求进行其他处理
                        return context;
                    }
                });
    }
}

package com.wondertek.gateway.loadBalancer;

import cn.hutool.core.util.ObjectUtil;
import com.wondertek.center.constants.BusinessCenterApi;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
public class ClientDeviceUniqueCodeInstanceLoadBalancer implements ReactorServiceInstanceLoadBalancer {

    private final String serviceId;
    final AtomicInteger position;
    private ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;


    public ClientDeviceUniqueCodeInstanceLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, String serviceId, AtomicInteger position) {
        this.serviceId = serviceId;
        this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
        this.position = position;
    }

    @Override
    public Mono<Response<ServiceInstance>> choose(Request request) {
        //在 choose 方法中,使用 deferContextual 方法来访问上下文并提取客户端标识。这里的 getOrDefault 方法尝试从上下文中获取一个键为 "identification" 的值,如果不存在则返回 "default-identification"
        return Mono.deferContextual(contextView -> {
            String identification = contextView.getOrDefault("identification", "");
            log.info("上下游获取到的identification的值为:{}", identification);

            String pathUrl = contextView.getOrDefault("pathUrl", "");
            log.info("上下游获取到的pathUrl的值为:{}", pathUrl);

            ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider
                    .getIfAvailable(NoopServiceInstanceListSupplier::new);
            return supplier.get(request).next()
                    .map(serviceInstances -> processInstanceResponse(supplier, serviceInstances, identification, pathUrl));
        });
    }

    private Response<ServiceInstance> processInstanceResponse(ServiceInstanceListSupplier supplier, List<ServiceInstance> serviceInstances, String identification, String pathUrl) {
        Response<ServiceInstance> serviceInstanceResponse;
        //特定接口走自定义负载策略
        Boolean status = ObjectUtil.isNotEmpty(identification) && ObjectUtil.isNotEmpty(pathUrl) &&
                (pathUrl.contains(BusinessCenterApi.WEB_UPLOAD_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.WEB_MERGE_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.UNITY_UPLOAD_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.UNITY_MERGE_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.CLOUD_UPLOAD_SLICE_FILE) ||
                        pathUrl.contains(BusinessCenterApi.CLOUD_MERGE_SLICE_FILE));

        if (status) {
            serviceInstanceResponse = this.getIpInstanceResponse(serviceInstances, identification);
        } else {
            serviceInstanceResponse = this.getInstanceResponse(serviceInstances);
        }
        if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
            ((SelectedInstanceCallback) supplier).selectedServiceInstance((ServiceInstance) serviceInstanceResponse.getServer());
        }
        return serviceInstanceResponse;
    }

    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances) {
        if (instances.isEmpty()) {
            if (log.isWarnEnabled()) {
                log.warn("No servers available for service: " + this.serviceId);
            }
            return new EmptyResponse();
        } else {
            //创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题
            List<ServiceInstance> sortedInstances = new ArrayList<>(instances);
            // 现在对新列表进行排序,保持原始列表的顺序不变
            Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));
            //log.info("获取到的实例个数的值为:{}", sortedInstances.size());
            sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));
            int pos = Math.abs(this.position.incrementAndGet());
            //log.info("默认轮循机制,pos递加后的值为:{}", pos);
            int positionIndex = pos % instances.size();
            //log.info("取余后的positionIndex的值为:{}", positionIndex);
            ServiceInstance instance = instances.get(positionIndex);
            //log.info("instance.getUri()的值为:{}", instance.getUri());
            log.info("特殊服务,默认轮循机制,routed to instance: {}:{}", instance.getHost(), instance.getPort());
            return new DefaultResponse(instance);
        }
    }

    private Response<ServiceInstance> getIpInstanceResponse(List<ServiceInstance> instances, String identification) {
        if (instances.isEmpty()) {
            log.warn("No servers available for service: " + this.serviceId);
            return new EmptyResponse();
        } else {
            //创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题
            List<ServiceInstance> sortedInstances = new ArrayList<>(instances);
            // 现在对新列表进行排序,保持原始列表的顺序不变
            Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));
            //log.info("获取到的实例个数的值为:{}", sortedInstances.size());
            sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));
            //log.info("多个服务实例,使用客户端 identification 地址的哈希值来选择服务实例");
            // 使用排序后的列表来找到实例
            int ipHashCode = Math.abs(identification.hashCode());
            //log.info("identificationHashCode的值为:{}", ipHashCode);
            int instanceIndex = ipHashCode % sortedInstances.size();
            //log.info("instanceIndex的值为:{}", instanceIndex);
            ServiceInstance instanceToReturn = sortedInstances.get(instanceIndex);
            //log.info("instanceToReturn.getUri()的值为:{}", instanceToReturn.getUri());
            log.info("特殊服务,自定义identification负载机制,Client identification: {} is routed to instance: {}:{}", identification, instanceToReturn.getHost(), instanceToReturn.getPort());
            return new DefaultResponse(instanceToReturn);
        }
    }

}

package com.wondertek.gateway.loadBalancer;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.SelectedInstanceCallback;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import reactor.core.publisher.Mono;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
public class DefaultInstanceLoadBalancer implements ReactorServiceInstanceLoadBalancer {

    private final String serviceId;
    private ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;
    final AtomicInteger position;

    public DefaultInstanceLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider, String serviceId, AtomicInteger position) {
        this.serviceId = serviceId;
        this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
        this.position = position;
    }
    @Override
    public Mono<Response<ServiceInstance>> choose(Request request) {
        ServiceInstanceListSupplier supplier = serviceInstanceListSupplierProvider
                .getIfAvailable(NoopServiceInstanceListSupplier::new);
        return supplier.get(request).next()
                .map(serviceInstances -> processInstanceResponse(supplier, serviceInstances));
    }

    private Response<ServiceInstance> processInstanceResponse(ServiceInstanceListSupplier supplier,
                                                              List<ServiceInstance> serviceInstances) {
        Response<ServiceInstance> serviceInstanceResponse = getInstanceResponse(serviceInstances);
        if (supplier instanceof SelectedInstanceCallback && serviceInstanceResponse.hasServer()) {
            ((SelectedInstanceCallback) supplier).selectedServiceInstance(serviceInstanceResponse.getServer());
        }
        return serviceInstanceResponse;
    }

    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances) {
        if (instances.isEmpty()) {
            if (log.isWarnEnabled()) {
                log.warn("No servers available for service: " + serviceId);
            }
            return new EmptyResponse();
        }
        //创建一个新的列表以避免在原始列表上排序,避免了修改共享状态可能带来的线程安全问题
        List<ServiceInstance> sortedInstances = new ArrayList<>(instances);
        // 现在对新列表进行排序,保持原始列表的顺序不变
        Collections.sort(sortedInstances, Comparator.comparing(ServiceInstance::getHost));
        //log.info("获取到的实例个数的值为:{}", sortedInstances.size());
        sortedInstances.forEach(instance -> log.info("排序后的实例: {},{}", instance.getHost(), instance.getPort()));
        int pos = Math.abs(this.position.incrementAndGet());
        //log.info("默认轮循机制,pos递加后的值为:{}", pos);
        int positionIndex = pos % instances.size();
        //log.info("取余后的positionIndex的值为:{}", positionIndex);
        ServiceInstance instance = instances.get(positionIndex);
        //log.info("instance.getUri()的值为:{}", instance.getUri());
        log.info("默认轮循机制,routed to instance: {}:{}",instance.getHost(), instance.getPort());
        return new DefaultResponse(instance);
    }

}

package com.wondertek.gateway.loadBalancer;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClients;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;

import java.util.concurrent.atomic.AtomicInteger;

@Configuration
//单台服务
//@LoadBalancerClient(name = "oms-api", configuration = CustomLoadBalancerConfig.class)
//多台服务
@LoadBalancerClients({
        @LoadBalancerClient(name = "oms-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "unity-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "cloud-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "open-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "server-api", configuration = CustomLoadBalancerConfig.class),
        @LoadBalancerClient(name = "center-service", configuration = CustomLoadBalancerConfig.class),
})
@Slf4j
public class CustomLoadBalancerConfig {
    // 定义一个Bean来提供AtomicInteger的实例
    @Bean
    public AtomicInteger positionTracker() {
        // 这将在应用上下文中只初始化一次
        return new AtomicInteger(0);
    }

    //自定义优先级负载均衡器
    @Bean
    public ReactorServiceInstanceLoadBalancer customPriorityLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
                                                                     Environment environment,AtomicInteger positionTracker) {
        String serviceId = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        //目的为解决文件上传切片文件分散上传的问题
        if ("oms-api".equals(serviceId)||"unity-api".equals(serviceId)||"cloud-api".equals(serviceId)){
            //log.info("服务名称:serviceId:{},走自定义clientDeviceUniqueCode负载模式", serviceId);
            return new ClientDeviceUniqueCodeInstanceLoadBalancer(serviceInstanceListSupplierProvider, serviceId, positionTracker);
        }
        //log.info("服务名称:serviceId:{},走默认负载模式", serviceId);
        return new DefaultInstanceLoadBalancer(serviceInstanceListSupplierProvider, serviceId,positionTracker);
    }
}

【SpringCloud系列】开发环境下重写Loadbalancer实现自定义负载均衡

相关推荐
喵叔哟15 分钟前
重构代码中引入外部方法和引入本地扩展的区别
java·开发语言·重构
尘浮生21 分钟前
Java项目实战II基于微信小程序的电影院买票选座系统(开发文档+数据库+源码)
java·开发语言·数据库·微信小程序·小程序·maven·intellij-idea
郑祎亦44 分钟前
Spring Boot 项目 myblog 整理
spring boot·后端·java-ee·maven·mybatis
不是二师兄的八戒44 分钟前
本地 PHP 和 Java 开发环境 Docker 化与配置开机自启
java·docker·php
爱编程的小生1 小时前
Easyexcel(2-文件读取)
java·excel
带多刺的玫瑰1 小时前
Leecode刷题C语言之统计不是特殊数字的数字数量
java·c语言·算法
计算机毕设指导62 小时前
基于 SpringBoot 的作业管理系统【附源码】
java·vue.js·spring boot·后端·mysql·spring·intellij-idea
Gu Gu Study2 小时前
枚举与lambda表达式,枚举实现单例模式为什么是安全的,lambda表达式与函数式接口的小九九~
java·开发语言
Chris _data2 小时前
二叉树oj题解析
java·数据结构
牙牙7052 小时前
Centos7安装Jenkins脚本一键部署
java·servlet·jenkins