手把手带你0到1摸透RPC框架轮子项目-day4路由层、协议层

原视频参考地址:[全网最全的手写RPC教程] 手摸手教你写一个RPC-架构-设计-落地实现_哔哩哔哩_bilibili

项目地址:kkoneone11/kkoneoneRPC-master (github.com)

觉得对你有帮助的帮忙文章给个like和项目给个stars呀!!!

因为我的设计是从顶到底去写代码,因此这一部分的代码量会比往后的多,大家可以分几天去编写

路由层

负载均衡算法 根据算法选择合适的服务。(本项目选择的是轮询和一致性哈希)

静态负载均衡算法 轮询(Round Robin):服务器按照顺序循环接受请求。服务器性能通常都是一致的。 随机(Random):随机选择一台服务器接受请求。 权重(Weight):给每个服务器分配一个权重值,处理能力强的通常权重值大,权重值大的通常请求分配的概率也会大,然后根据权重来分发请求到不同的机器中。 IP哈希(IP Hash):根据客户端IP计算Hash值取模访问对应服务器。 URL哈希(URL Hash):根据请求的URL地址计算Hash值取模访问对应服务器。 一致性哈希(Consistent Hash):采用一致性Hash算法,相同IP或URL请求总是发送到同一服务器。

取余:

当请求取余完了之后来到这个哈希环的时候,如果是在AB服务器中间则给B,如果是在BC服务器中间则给C,就是只给下一个服务器。

当B服务挂掉的时候,AB之间的请求就会给到C,这样就不会影响到A服务的处理

而当出现数据倾斜的情况,即请求落在AB之间的范围更大,可以通过设置多个虚拟节点(ABC都是真实的物理节点)到范围小的地方来均匀其中的范围

动态负载均衡算法 最少连接数(Least Connection):将请求分配给最少连接处理的服务器。最快响应(Fastest Response):将请求分配给响应时间最快的服务器。 观察(Observed):以连接数和响应时间的平衡为依据请求服务器。 预测(Predictive):收集分析当前服务器性能指标,预测下个时间段内性能最佳服务器。 动态性能分配(Dynamic Ratio-APM):收集服务器各项性能参数,动态调整流量分配。 服务质量(QoS):根据服务质量选择服务器。服务类型(ToS):根据服务类型选择服务器

代码部分

1.接口LoadBalancer。要实现其他负载均衡算法都要实现这个接口且实现里面的select方法。

2.ServiceMetaRes类。服务节点类,里面包含当前服务节点和其他服务节点。

3.LoadBalancerFactory类。负载均衡工厂,里面的init()初始化方法用作初始化加载对应算法并生成一个注册中心实例以便获取服务节点。

核心类

RoundRobinLoadBalancer轮询算法

java 复制代码
package org.kkoneone.rpc.router;

import org.kkoneone.rpc.common.ServiceMeta;
import org.kkoneone.rpc.config.RpcProperties;
import org.kkoneone.rpc.registry.RegistryService;
import org.kkoneone.rpc.spi.ExtensionLoader;

import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * 轮询算法
 * @Author:kkoneone11
 * @name:RoundRobinLoadBalancer
 * @Date:2023/12/6 21:40
 */
public class RoundRobinLoadBalancer implements LoadBalancer{
    //原子类 基于内存即可,因为是服务调用方所调用的,而服务提供方可能会有多个
    private static AtomicInteger roundRobinId = new AtomicInteger(0);
    @Override
    public ServiceMetaRes select(Object[] params, String serviceName) {
        //传入注册中心实现类型来获取注册中心
        RegistryService registryService = ExtensionLoader.getInstance().get(RpcProperties.getInstance().getRegisterType());

        // 1.获取所有服务
        List<ServiceMeta> discoveries = registryService.discoveries(serviceName);
        int size = discoveries.size();
        // 2.根据当前轮询ID取余服务长度得到具体服务 每次都加一
        roundRobinId.addAndGet(1);
        //处理Integer最大值的问题
        if(roundRobinId.get() == Integer.MAX_VALUE){
            roundRobinId.set(0);
        }
        //取余然后获取当前的服务
        return ServiceMetaRes.build(discoveries.get(roundRobinId.get() % size),discoveries);
    }
}

ConsistentHashLoadBalancer一致性哈希

java 复制代码
package org.kkoneone.rpc.router;

import org.kkoneone.rpc.common.ServiceMeta;
import org.kkoneone.rpc.config.RpcProperties;
import org.kkoneone.rpc.registry.RegistryService;
import org.kkoneone.rpc.spi.ExtensionLoader;

import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
 * 一致性哈希
 * @Author:kkoneone11
 * @name:ConsistentHashLoadBalancer
 * @Date:2023/12/6 23:10
 */
public class ConsistentHashLoadBalancer implements LoadBalancer{

    //物理节点映射的虚拟节点,为了解决哈希倾斜
    private final static int VIRTUAL_NODE_SIZE = 10;
    private final static String VIRTUAL_NODE_SPLIT = "$";

    @Override
    public ServiceMetaRes select(Object[] params, String serviceName) {
        //获取注册中心
        RegistryService registryService = ExtensionLoader.getInstance().get(RpcProperties.getInstance().getRegisterType());
        //根据注册中心获取服务
        List<ServiceMeta> discoveries = registryService.discoveries(serviceName);
        final ServiceMeta curServiceMeta = allocateNode(makeConsistentHashRing(discoveries), params[0].hashCode());
        return ServiceMetaRes.build(curServiceMeta,discoveries);
    }


    /**
     * 选择节点
     * @param ring
     * @param hashCode
     * @return
     */
    private ServiceMeta allocateNode(TreeMap<Integer, ServiceMeta> ring, int hashCode){
        //获取最近哈希环上节点位置
        Map.Entry<Integer, ServiceMeta> entry = ring.ceilingEntry(hashCode);
        if(entry == null){
            //如果没有则找最小节点
            entry = ring.firstEntry();
        }
        return entry.getValue();
    }

    /**
     * 将所有服务实例添加到一致性哈希环上,并生成虚拟节点
     * 这里每次调用都需要构建哈希环是为了扩展(服务提供方)
     * @param servers 服务实例列表
     * @return 一致性哈希环
     */
    private TreeMap<Integer, ServiceMeta> makeConsistentHashRing(List<ServiceMeta> servers){
        TreeMap<Integer, ServiceMeta> ring = new TreeMap<>();
        for(ServiceMeta instance : servers){
            for(int i = 0; i < VIRTUAL_NODE_SIZE; i++){
                ring.put((buildServiceInstanceKey(instance) + VIRTUAL_NODE_SPLIT + i).hashCode(), instance);
            }
        }
        return ring;
    }

    /**
     * 根据服务实例信息构建缓存键
     * @param serviceMeta
     * @return
     */
    private String buildServiceInstanceKey(ServiceMeta serviceMeta){
        return String.join(":",serviceMeta.getServiceAddr(),String.valueOf(serviceMeta.getServicePort()));
    }
}

协议层

服务与服务之间进行通信需要制定好双方协议,是为了安全校验状态显示解决粘包半包 等问题。 为了跨平台存储以及网络传输因此需要序列化。在传输时也会出现粘包半包问题,也需要解决。 半包: 因窗口大小原因只发送了一半的数据,因此需要等待发送方继续发送再进行先接受

解决:拿到请求体数据后根据协议中请求体数据长度去判断拿到的请求体数据长度还不够长则表示还没发完,那就不要这次数据,等他再发一次即可

粘包: 假设发送放发了2条数据,并且这2条数据在一个窗口下,因此会出现粘包的问题

解决:同样也是根据协议中请求体数据长度且只需要拿这个长度的数据即可,之后的数据不需要管

因序列化方式都是以工具类的形式存在,并且只需要对比每种序列化的性能即可选择,因此这一小节主要讲数据格式,粘包半包,编码解码,代码层面

  • 魔数:用于标识数据格式或协议的开始或结束。它可以帮助接收方正确解析数据,确保数据的完整性和正确性
  • 版本:协议有不同版本方便更新迭代
  • 消息类型:请求或者响应
  • 状态:成功与否
  • 请求id:唯一识别一次请求。提供方也可以从请求中获取响应数据
  • 序列化方式的长度:因为序列化方式不是int long这种基本类型所以需要一个长度标识
  • 序列化方式:根据一个key从SPI这个IOC去选择对应的序列化方式
  • 请求体数据长度:解决粘包半包
  • 请求体:

和之前的代理层invoke()方法知识结合来看服务调用方和服务提供方交流的流程的话就是

RpcInvokerProxy

RpcInvokerProxy的invoke()方法其中服务调用者调用sendRequest()方法发送请求

RpcConsumer

sendRequest()里用客户端bootstrap来连接服务调用方地址和端口,然后通过这段代码future.channel().writeAndFlush(protocol);把请求写给ProviderPostProcessor服务提供方

ProviderPostProcessor

服务提供方通过建立客户端来接收请求。通过channel管道连接(责任链模式),把请求按照顺序挨个进行一系列处理(如解码或者编码),上一个过滤链处理完之后再交给下一个

RpcRequestHandler

当对请求处理到这的时候这个方法就会将对应请求的body取出然后通过Map找到对应的服务,然后由服务中再交给线程去进行处理

再创建一个RpcResponse将对应的结果处理情况写入,并封装成protocol

最后将protocol写回给RpcConsumer

RpcConsumer

RpcConsumer将返回的结果接收然后过滤到最后一个的时候

就读取到返回的msg消息并把对应的请求给清理然后返回结果

最后又回到invoke()方法此处,把结果复制给rpcResponse

代码部分

序列化方式

JsonSerialization

易读,但传输空间占用比较多

java 复制代码
package org.kkoneone.rpc.protocol.serialization;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.kkoneone.rpc.common.RpcRequest;
import org.kkoneone.rpc.common.RpcResponse;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.LinkedHashMap;

/**
 * @Author:kkoneone11
 * @name:JsonSerialization
 * @Date:2023/12/7 16:35
 */
public class JsonSerialization implements RpcSerialization{

    private final static ObjectMapper MAPPER;

    static{
        //JsonInclude.Include 是 Jackson 库中的一个枚举类型,用于指定在序列化 Java 对象为 JSON 时,哪些属性应该包含在输出的 JSON 中
        //- ALWAYS: 总是包含该属性
        //- NON_NULL: 只有当属性值不为 null 时才包含
        //- NON_ABSENT: 只有当属性值不为 Optional.empty() 时才包含
        //- NON_EMPTY: 只有当属性值不为 null 且不为空时才包含
        //- NON_DEFAULT: 只有当属性值不等于默认值时才包含
        //- USE_DEFAULTS: 使用默认的包含规则
        MAPPER = generateMapper(JsonInclude.Include.ALWAYS);
    }

    private static ObjectMapper generateMapper(JsonInclude.Include include){
        com.fasterxml.jackson.databind.ObjectMapper customMapper = new com.fasterxml.jackson.databind.ObjectMapper();
        customMapper.setSerializationInclusion(include);
        customMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        customMapper.configure(DeserializationFeature.FAIL_ON_NUMBERS_FOR_ENUMS, true);
        customMapper.setDateFormat(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
        return customMapper;
    }


    /**
     * 序列化
     * @param obj
     * @param <T>
     * @return
     * @throws IOException
     */
    @Override
    public <T> byte[] serialize(T obj) throws IOException {
        //如果是string类型则直接转化为byte流
        //如果不是则按照utf-8的格式转化成byte流
        return obj instanceof String ? ((String) obj).getBytes() : MAPPER.writeValueAsString(obj).getBytes(StandardCharsets.UTF_8);
    }

    /**
     * 反序列化 解决jackson在反序列化对象时为LinkedHashMap
     * @param data
     * @param clz
     * @param <T>
     * @return
     * @throws IOException
     */
    @Override
    public <T> T deserialize(byte[] data, Class<T> clz) throws IOException {
        //将byte数组反序列化为指定类型的对象
        final T t = MAPPER.readValue(data, clz);
        //先判断类型
        //为RpcRequest
        if(clz.equals(RpcRequest.class)){
            //将t转化
            RpcRequest rpcRequest = (RpcRequest) t;
            rpcRequest.setData(convertRes(rpcRequest.getData(),rpcRequest.getDataClass()));
            return (T) rpcRequest;
         //否则为RpcResponse
        }else{
            RpcResponse rpcResponse = (RpcResponse) t;
            rpcResponse.setData(convertRes(rpcResponse.getData(),rpcResponse.getDataClass()));
            return (T) rpcResponse;
        }

    }

    public Object convertReq(Object data,Class clazz){
        final LinkedHashMap map = (LinkedHashMap)((ArrayList) data).get(0);
        return convert(clazz,map);
    }

    public Object convertRes(Object data,Class clazz){
        final  LinkedHashMap map = (LinkedHashMap) ((ArrayList)data).get(0);
        return convert(clazz,map);
    }

    public Object convert(Class clazz,LinkedHashMap map){
        //额外处理对象
        final Class dataClass = clazz;
        try{
            //用类建一个实例
            Object o = dataClass.newInstance();
            //循环遍历map
            map.forEach((k,v)->{
                //动态地设置对象 o 的字段值
                try {
                    final Field field = dataClass.getDeclaredField(String.valueOf(k));
                    if (v!=null && v.getClass().equals(LinkedHashMap.class)){
                        v = convert(field.getType(),(LinkedHashMap) v);
                    }
                    field.setAccessible(true);
                    field.set(o,v);
                    field.setAccessible(false);
                } catch (IllegalAccessException e) {
                    e.printStackTrace();
                } catch (NoSuchFieldException e) {
                    e.printStackTrace();
                }
            });
            return o;
        }catch (InstantiationException | IllegalAccessException e){
            e.printStackTrace();
        }
        return null;
    }
}

HessianSerialization

java 复制代码
package org.kkoneone.rpc.protocol.serialization;

import com.caucho.hessian.io.HessianSerializerInput;
import com.caucho.hessian.io.HessianSerializerOutput;
import org.springframework.data.redis.serializer.SerializationException;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;

/**
 * @Author:kkoneone11
 * @name:HessianSerialization
 * @Date:2023/12/11 21:15
 */
public class HessianSerialization implements RpcSerialization{
    @Override
    public <T> byte[] serialize(T obj) throws IOException {
        if(obj == null){
            throw new NullPointerException();
        }

        //创建一个byte数组
        byte[] results;
        HessianSerializerOutput hessianOutput;
        //创建ByteArrayOutputStream和Hessian 序列化来将对象写入字节数组中
        try(ByteArrayOutputStream os = new ByteArrayOutputStream()){
            //创建序列化输出
            hessianOutput = new HessianSerializerOutput(os);
            //写入数据
            hessianOutput.writeObject(obj);
            //将缓冲区中的数据刷新到 os 中。
            hessianOutput.flush();
            results = os.toByteArray();
        }catch (Exception e){
            throw new SerializationException(e.toString());
        }

        return results;
    }

    @Override
    public <T> T deserialize(byte[] data, Class<T> clz) throws IOException {
        if (data == null) {
            throw new NullPointerException();
        }
        T result;

        try (ByteArrayInputStream is = new ByteArrayInputStream(data)) {
            HessianSerializerInput hessianInput = new HessianSerializerInput(is);
            result = (T) hessianInput.readObject(clz);
        } catch (Exception e) {
            throw new SerializationException(e.toString());
        }

        return result;
    }
}

核心类

编码类RpcEncoder

java 复制代码
package org.kkoneone.rpc.protocol.codec;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import org.kkoneone.rpc.protocol.MsgHeader;
import org.kkoneone.rpc.protocol.RpcProtocol;
import org.kkoneone.rpc.protocol.serialization.RpcSerialization;
import org.kkoneone.rpc.protocol.serialization.SerializationFactory;

/**
 * 编码器
 * @Author:kkoneone11
 * @name:RpcEncoder
 * @Date:2023/11/28 14:35
 */
public class RpcEncoder extends MessageToByteEncoder<RpcProtocol<Object>> {
    @Override
    protected void encode(ChannelHandlerContext ctx, RpcProtocol<Object> msg, ByteBuf byteBuf) throws Exception {
        // 获取消息头类型
        MsgHeader header = msg.getHeader();
        // 写入魔数(安全校验,可以参考java中的CAFEBABE)
        byteBuf.writeShort(header.getMagic());
        // 写入版本号
        byteBuf.writeByte(header.getVersion());
        // 写入消息类型(接收放根据不同的消息类型进行不同的处理方式)
        byteBuf.writeByte(header.getMsgType());
        // 写入状态
        byteBuf.writeByte(header.getStatus());
        // 写入请求id(请求id可以用于记录异步回调标识,具体需要回调给哪个请求)
        byteBuf.writeLong(header.getRequestId());
        // 写入序列化方式(接收方需要依靠具体哪个序列化进行序列化)
        byteBuf.writeInt(header.getSerializationLen());
        final byte[] ser = header.getSerializations();
        byteBuf.writeBytes(ser);
        final String serialization = new String(ser);
        //获取序列化策略 序列化消息体
        RpcSerialization rpcSerialization = SerializationFactory.get(serialization);
        byte[] data = rpcSerialization.serialize(msg.getBody());
        // 写入数据长度(接收方根据数据长度读取数据内容)
        byteBuf.writeInt(data.length);
        // 写入数据
        byteBuf.writeBytes(data);
    }
}

解码类RpcDecoder

java 复制代码
package org.kkoneone.rpc.protocol.codec;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import org.kkoneone.rpc.common.RpcRequest;
import org.kkoneone.rpc.common.RpcResponse;
import org.kkoneone.rpc.common.constants.MsgType;
import org.kkoneone.rpc.common.constants.ProtocolConstants;
import org.kkoneone.rpc.protocol.MsgHeader;
import org.kkoneone.rpc.protocol.RpcProtocol;
import org.kkoneone.rpc.protocol.serialization.RpcSerialization;
import org.kkoneone.rpc.protocol.serialization.SerializationFactory;

import java.util.List;

/**
 * 解码器
 * @Author:kkoneone11
 * @name:RpcDecoder
 * @Date:2023/11/28 19:01
 */
public class RpcDecoder extends ByteToMessageDecoder {
    @Override
    protected void decode(ChannelHandlerContext channelHandlerContext, ByteBuf in, List<Object> out) throws Exception {
        // 如果可读字节数少于协议头长度,说明还没有接收完整个协议头,直接返回
        if(in.readableBytes() < ProtocolConstants.HEADER_TOTAL_LEN){
            return;
        }
        // 标记当前读取位置,便于后面回退
        in.markReaderIndex();
        // 1.读取魔数字段
        short magic = in.readShort();
        if (magic != ProtocolConstants.MAGIC) {
            throw new IllegalArgumentException("magic number is illegal, " + magic);
        }
        // 2.读取版本字段
        byte version = in.readByte();
        // 3.读取消息类型
        byte msgType = in.readByte();
        // 4.读取响应状态
        byte status = in.readByte();
        // 5.读取请求 ID
        long requestId = in.readLong();
        // 6.获取序列化算法长度
        final int len = in.readInt();
        if(in.readableBytes() < len){
            in.resetReaderIndex();
            return;
        }
        //7.序列化数据
        final byte[] bytes = new byte[len];
        in.readBytes(bytes);
        final String serialization = new String(bytes);
        // 8.读取消息体长度
        int dataLength = in.readInt();
        // 如果可读字节数小于消息体长度,说明还没有接收完整个消息体,回退并返回(半包问题)
        if(in.readableBytes() < dataLength){
            // 回退标记位置
            in.resetReaderIndex();
            return;
        }
        byte[] data = new byte[dataLength];
        // 读取数据
        in.readBytes(data);
        // 处理消息的类型
        MsgType msgTypeEnum = MsgType.findByType(msgType);
        if(msgTypeEnum == null){
            return;
        }
        // 构建消息头
        MsgHeader header = new MsgHeader();
        header.setMagic(magic);
        header.setVersion(version);
        header.setStatus(status);
        header.setRequestId(requestId);
        header.setMsgType(msgType);
        header.setSerializations(bytes);
        header.setSerializationLen(len);
        header.setMsgLen(dataLength);
        // 获取序列化器
        RpcSerialization rpcSerialization = SerializationFactory.get(serialization);
        // 根据消息类型进行处理(如果消息类型过多可以使用策略+工厂模式进行管理)
        switch (msgTypeEnum){
            //请求消息
            case REQUEST:
                RpcRequest request = rpcSerialization.deserialize(data, RpcRequest.class);
                if (request != null) {
                    RpcProtocol<RpcRequest> protocol = new RpcProtocol<>();
                    protocol.setHeader(header);
                    protocol.setBody(request);
                    out.add(protocol);
                }
                break;
            //响应消息
            case RESPONSE:
                RpcResponse response = rpcSerialization.deserialize(data, RpcResponse.class);
                if (response != null) {
                    RpcProtocol<RpcResponse> protocol = new RpcProtocol<>();
                    protocol.setHeader(header);
                    protocol.setBody(response);
                    out.add(protocol);
                }
                break;
        }
    }
}
相关推荐
monkey_meng几秒前
【Rust中的迭代器】
开发语言·后端·rust
余衫马3 分钟前
Rust-Trait 特征编程
开发语言·后端·rust
monkey_meng7 分钟前
【Rust中多线程同步机制】
开发语言·redis·后端·rust
paopaokaka_luck5 小时前
【360】基于springboot的志愿服务管理系统
java·spring boot·后端·spring·毕业设计
码农小旋风6 小时前
详解K8S--声明式API
后端
Peter_chq6 小时前
【操作系统】基于环形队列的生产消费模型
linux·c语言·开发语言·c++·后端
Yaml46 小时前
Spring Boot 与 Vue 共筑二手书籍交易卓越平台
java·spring boot·后端·mysql·spring·vue·二手书籍
小小小妮子~6 小时前
Spring Boot详解:从入门到精通
java·spring boot·后端
hong1616887 小时前
Spring Boot中实现多数据源连接和切换的方案
java·spring boot·后端
睡觉谁叫~~~8 小时前
一文解秘Rust如何与Java互操作
java·开发语言·后端·rust