Implementing RPC with Netty

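Here is a complete RPC framework built on Netty. It is organized the way production RPC frameworks usually are: a protocol codec, a connection pool, a client-side proxy and a server-side dispatcher.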

A Complete Netty RPC Framework Implementation

1. Project Structure

rpc-framework/
├── rpc-core/           # core module
├── rpc-client/         # client module
├── rpc-server/         # server module
├── rpc-common/         # shared module
└── example/            # usage examples

2. Common Module (rpc-common)

2.1 Message Model
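// RpcRequest.java
import java.io.Serializable;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

// RPC request
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class RpcRequest implements Serializable {
    private static final long serialVersionUID = 1L;

    private String requestId;
    private String className;          // fully qualified interface name
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;
    private String version;
    private boolean oneway;            // one-way call: no response is sent back
}

// RpcResponse.java (same imports)
// RPC response; code 0 means success
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class RpcResponse implements Serializable {
    private static final long serialVersionUID = 1L;

    private String requestId;
    private Object result;
    private Throwable error;
    private int code;
    private String message;

    public boolean isSuccess() {
        return code == 0;
    }
}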
2.2 RPC Annotations
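import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import org.springframework.stereotype.Component;

// Service provider annotation; @Component lets Spring's component scan pick up the implementation
@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
@Component
public @interface RpcService {
    String version() default "1.0";
    String group() default "";
}

// Service consumer annotation.
// Not meta-annotated with @Autowired: that would make Spring look for a local bean of the
// interface type and fail at startup. The proxy has to be injected by a dedicated step,
// e.g. a BeanPostProcessor that calls RpcClientProxy.createProxy.
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
public @interface RpcReference {
    String version() default "1.0";
    String group() default "";
    long timeout() default 3000;        // milliseconds
    boolean async() default false;
    String loadBalance() default "random";
}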
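Spring only injects beans it already knows about, so @RpcReference fields need a dedicated injection step that builds the remote proxy, typically a BeanPostProcessor. The JDK-only sketch below shows the reflective core of that step. RemoteRef (a stand-in for @RpcReference), ReferenceInjector, EchoService and EchoClient are hypothetical names, and the InvocationHandler is supplied by the caller instead of coming from RpcClientProxy:

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.function.Function;

// Hypothetical stand-in for @RpcReference, declared locally so the sketch compiles on its own
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@interface RemoteRef {
    String version() default "1.0";
}

// Hypothetical injector: the per-bean work a BeanPostProcessor would do
final class ReferenceInjector {
    // Builds the InvocationHandler (the remote-call logic) for a given interface
    private final Function<Class<?>, InvocationHandler> handlerFactory;

    ReferenceInjector(Function<Class<?>, InvocationHandler> handlerFactory) {
        this.handlerFactory = handlerFactory;
    }

    // Assigns a dynamic proxy to every @RemoteRef field, including inherited ones
    void inject(Object bean) {
        for (Class<?> c = bean.getClass(); c != null && c != Object.class; c = c.getSuperclass()) {
            for (Field field : c.getDeclaredFields()) {
                if (!field.isAnnotationPresent(RemoteRef.class)) {
                    continue;
                }
                Class<?> type = field.getType();
                Object proxy = Proxy.newProxyInstance(
                    type.getClassLoader(), new Class<?>[]{type}, handlerFactory.apply(type));
                try {
                    field.setAccessible(true);
                    field.set(bean, proxy);
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException("Cannot inject " + field, e);
                }
            }
        }
    }
}

// Example types: a remote interface and a bean that depends on it
interface EchoService {
    String echo(String text);
}

class EchoClient {
    @RemoteRef
    EchoService echoService;
}
```

Inside Spring, the same field loop would run in BeanPostProcessor.postProcessBeforeInitialization, building each proxy with RpcClientProxy.createProxy(field.getType(), field.getAnnotation(RpcReference.class)) instead of Proxy.newProxyInstance.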
2.3 Serialization Interface
public interface Serializer {
    byte[] serialize(Object obj);
    <T> T deserialize(byte[] bytes, Class<T> clazz);
    
    enum Algorithm {
        JSON,
        HESSIAN,
        PROTOSTUFF,
        KRYO
    }
}

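// ProtostuffSerializer.java: Protostuff implementation (protostuff-core + protostuff-runtime)
import io.protostuff.LinkedBuffer;
import io.protostuff.ProtostuffIOUtil;
import io.protostuff.Schema;
import io.protostuff.runtime.RuntimeSchema;

public class ProtostuffSerializer implements Serializer {
    // LinkedBuffer is not thread-safe, so each thread gets its own buffer
    // instead of all callers sharing a single static one
    private static final ThreadLocal<LinkedBuffer> BUFFER =
        ThreadLocal.withInitial(() -> LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE));

    @Override
    @SuppressWarnings("unchecked")
    public byte[] serialize(Object obj) {
        // RuntimeSchema derives (and caches) a schema from the class's fields
        Schema<Object> schema = (Schema<Object>) RuntimeSchema.getSchema(obj.getClass());
        LinkedBuffer buffer = BUFFER.get();
        try {
            return ProtostuffIOUtil.toByteArray(obj, schema, buffer);
        } finally {
            buffer.clear(); // reset so the next call on this thread can reuse it
        }
    }

    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        Schema<T> schema = RuntimeSchema.getSchema(clazz);
        T obj = schema.newMessage();
        ProtostuffIOUtil.mergeFrom(bytes, obj, schema);
        return obj;
    }
}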
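Because the Serializer interface hides the wire format, other implementations can be dropped in. As an illustration, here is a dependency-free implementation based on built-in Java serialization. JdkSerializer is a hypothetical name (it is not one of the Algorithm values), and the interface is repeated so the block compiles on its own. Java deserialization of untrusted input is a known security risk, so this variant only suits trusted peers or tests:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;

// Same shape as the Serializer interface above (repeated so this block is self-contained)
interface Serializer {
    byte[] serialize(Object obj);
    <T> T deserialize(byte[] bytes, Class<T> clazz);
}

// Hypothetical fallback using java.io serialization; transported types must implement Serializable
final class JdkSerializer implements Serializer {
    @Override
    public byte[] serialize(Object obj) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // read the bytes after close(), which flushes the object stream
        return bytes.toByteArray();
    }

    @Override
    public <T> T deserialize(byte[] bytes, Class<T> clazz) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
            return clazz.cast(in.readObject());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Unknown class in payload", e);
        }
    }
}
```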

3. Core Module (rpc-core)

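3.1 Protocol Codec
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import io.netty.handler.codec.CorruptedFrameException;
import java.util.List;

// Frame: magic(4) | version(1) | messageType(1) | serializer(1) | bodyLength(4) | body
// ByteToMessageCodec accumulates incoming bytes, so a frame split across TCP reads (half packet)
// and several frames in one read (sticky packets) are both handled. It keeps per-channel state
// and must not be shared: add a new RpcCodec() to each pipeline.
public class RpcCodec extends ByteToMessageCodec<Object> {
    private static final int MAGIC_NUMBER = 0xCAFEBABE;
    private static final int HEADER_LENGTH = 11;                 // 4 + 1 + 1 + 1 + 4
    private static final int MAX_BODY_LENGTH = 8 * 1024 * 1024;  // reject absurd lengths

    @Override
    protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) {
        // Serialize (SerializerHolder is not shown in this article; it returns the configured Serializer)
        byte[] body = SerializerHolder.getSerializer().serialize(msg);

        // Header
        out.writeInt(MAGIC_NUMBER);
        out.writeByte(1);                                  // protocol version
        out.writeByte(msg instanceof RpcResponse ? 1 : 0); // message type: 0 = request, 1 = response
        out.writeByte(0);                                  // serializer id
        out.writeInt(body.length);

        // Body
        out.writeBytes(body);
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
        // Wait until the whole header has arrived
        if (in.readableBytes() < HEADER_LENGTH) {
            return;
        }
        in.markReaderIndex();

        // Check the magic number
        int magic = in.readInt();
        if (magic != MAGIC_NUMBER) {
            throw new CorruptedFrameException("Invalid magic number: " + Integer.toHexString(magic));
        }

        // Rest of the header (version and serializer type are read but not used yet)
        byte version = in.readByte();
        byte messageType = in.readByte();
        byte serializerType = in.readByte();
        int length = in.readInt();
        if (length < 0 || length > MAX_BODY_LENGTH) {
            throw new CorruptedFrameException("Invalid body length: " + length);
        }

        // Body incomplete: rewind to the start of the frame and wait for more bytes
        if (in.readableBytes() < length) {
            in.resetReaderIndex();
            return;
        }
        byte[] body = new byte[length];
        in.readBytes(body);

        // Deserialize
        Class<?> clazz = messageType == 0 ? RpcRequest.class : RpcResponse.class;
        out.add(SerializerHolder.getSerializer().deserialize(body, clazz));
    }
}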
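TCP delivers a byte stream, so one read can contain half a frame or several frames at once; a decoder for this header format has to confirm that a complete frame is buffered before consuming it. The JDK-only sketch below applies that check to the same layout with java.nio.ByteBuffer. FrameCodec is a hypothetical name; inside a Netty pipeline the same rule is enforced by ByteToMessageDecoder subclasses such as LengthFieldBasedFrameDecoder:

```java
import java.nio.ByteBuffer;

// Same header layout as RpcCodec: magic(4) | version(1) | type(1) | serializer(1) | length(4)
final class FrameCodec {
    static final int MAGIC = 0xCAFEBABE;
    static final int HEADER_LENGTH = 11;

    static byte[] encode(byte messageType, byte[] body) {
        // ByteBuffer is big-endian by default, like Netty's ByteBuf
        ByteBuffer buf = ByteBuffer.allocate(HEADER_LENGTH + body.length);
        buf.putInt(MAGIC).put((byte) 1).put(messageType).put((byte) 0).putInt(body.length).put(body);
        return buf.array();
    }

    // Returns the body of one complete frame, or null if more bytes are needed.
    // On success the position moves past the frame; on null it is left unchanged.
    static byte[] tryDecode(ByteBuffer in) {
        if (in.remaining() < HEADER_LENGTH) {
            return null;                      // header not complete yet
        }
        int start = in.position();
        if (in.getInt(start) != MAGIC) {
            throw new IllegalArgumentException("Invalid magic number");
        }
        int length = in.getInt(start + 7);    // skip magic(4) + version + type + serializer
        if (in.remaining() < HEADER_LENGTH + length) {
            return null;                      // half packet: wait for the rest of the body
        }
        byte[] body = new byte[length];
        in.position(start + HEADER_LENGTH);
        in.get(body);
        return body;
    }
}
```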
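3.2 Connection Pool
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ConnectionPool {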
    private static final Logger logger = LoggerFactory.getLogger(ConnectionPool.class);
    
    private final String host;
    private final int port;
    private final int maxConnections;
    private final int minConnections;
    private final long maxWaitTime;
    
    private final Queue<Channel> idleConnections = new ConcurrentLinkedQueue<>();
    private final Set<Channel> activeConnections = ConcurrentHashMap.newKeySet();
    private final AtomicInteger connectionCount = new AtomicInteger(0);
    
    private final Bootstrap bootstrap;
    private final EventLoopGroup group;
    
    public ConnectionPool(String host, int port, int maxConnections, int minConnections) {
        this.host = host;
        this.port = port;
        this.maxConnections = maxConnections;
        this.minConnections = minConnections;
        this.maxWaitTime = 5000;
        
        this.group = new NioEventLoopGroup();
        this.bootstrap = new Bootstrap()
            .group(group)
            .channel(NioSocketChannel.class)
            .option(ChannelOption.SO_KEEPALIVE, true)
            .option(ChannelOption.TCP_NODELAY, true)
            .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 3000)
            .handler(new ChannelInitializer<SocketChannel>() {
                @Override
                protected void initChannel(SocketChannel ch) {
                    ChannelPipeline p = ch.pipeline();
                    // WRITER_IDLE after 30 s: RpcClientHandler reacts to WRITER_IDLE by sending a heartbeat
                    p.addLast(new IdleStateHandler(0, 30, 0, TimeUnit.SECONDS));
                    p.addLast(new RpcCodec());
                    p.addLast(new RpcClientHandler());
                }
            });
        
        // Pre-create minConnections connections
        initMinConnections();
    }
    
    private void initMinConnections() {
        for (int i = 0; i < minConnections; i++) {
            try {
                Channel channel = createConnection();
                if (channel != null) {
                    idleConnections.offer(channel);
                }
            } catch (Exception e) {
                logger.warn("Failed to create connection", e);
            }
        }
    }
    
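    // Opens a new connection, or returns null if the pool is already at maxConnections
    private Channel createConnection() throws InterruptedException {
        // Reserve a slot atomically so concurrent callers cannot exceed maxConnections
        int current;
        do {
            current = connectionCount.get();
            if (current >= maxConnections) {
                return null;
            }
        } while (!connectionCount.compareAndSet(current, current + 1));

        ChannelFuture future = bootstrap.connect(host, port).await();
        if (!future.isSuccess()) {
            connectionCount.decrementAndGet(); // connect failed: give the slot back
            throw new IllegalStateException("Failed to connect to " + host + ":" + port, future.cause());
        }

        Channel channel = future.channel();
        // When the channel closes, free its slot and drop it from both collections
        channel.closeFuture().addListener(f -> {
            connectionCount.decrementAndGet();
            activeConnections.remove(channel);
            idleConnections.remove(channel);
        });
        return channel;
    }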
    
    public Channel getConnection() throws Exception {
        return getConnection(maxWaitTime, TimeUnit.MILLISECONDS);
    }
    
    public Channel getConnection(long timeout, TimeUnit unit) throws Exception {
        long startTime = System.currentTimeMillis();
        long timeoutMillis = unit.toMillis(timeout);
        
        while (System.currentTimeMillis() - startTime < timeoutMillis) {
            // 1. Try an idle connection
            Channel channel = idleConnections.poll();
            if (channel != null && channel.isActive()) {
                activeConnections.add(channel);
                return channel;
            }
            
            // 2. Open a new connection while below maxConnections
            if (connectionCount.get() < maxConnections) {
                channel = createConnection();
                if (channel != null) {
                    activeConnections.add(channel);
                    return channel;
                }
            }
            
            // 3. Wait briefly for a connection to be released
            Thread.sleep(10);
        }
        
        throw new TimeoutException("Get connection timeout");
    }
    
    public void releaseConnection(Channel channel) {
        if (channel == null || !channel.isActive()) {
            return;
        }
        
        activeConnections.remove(channel);
        
        if (idleConnections.size() < maxConnections) {
            idleConnections.offer(channel);
        } else {
            channel.close();
        }
    }
    
    public void close() {
        idleConnections.forEach(Channel::close);
        activeConnections.forEach(Channel::close);
        group.shutdownGracefully();
    }
}
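getConnection above polls every 10 ms until a connection frees up. A java.util.concurrent.Semaphore expresses the same "at most N checked out, wait up to a timeout" rule without polling. The sketch below is a hypothetical BoundedPool, generic over the pooled type so it runs without Netty; for channels, the factory would open a connection and isHealthy would be Channel::isActive:

```java
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;
import java.util.function.Supplier;

// Hypothetical pool: the semaphore caps how many resources are checked out at once,
// and callers block with a timeout instead of sleeping in a loop.
// Callers must call release() exactly once per successful acquire().
final class BoundedPool<T> {
    private final Semaphore permits;
    private final Queue<T> idle = new ConcurrentLinkedQueue<>();
    private final Supplier<T> factory;
    private final Predicate<T> isHealthy;

    BoundedPool(int maxSize, Supplier<T> factory, Predicate<T> isHealthy) {
        this.permits = new Semaphore(maxSize, true); // fair: waiters are served in arrival order
        this.factory = factory;
        this.isHealthy = isHealthy;
    }

    T acquire(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException {
        if (!permits.tryAcquire(timeout, unit)) {
            throw new TimeoutException("No resource available within " + timeout + " " + unit);
        }
        try {
            T item;
            while ((item = idle.poll()) != null) {
                if (isHealthy.test(item)) {
                    return item;      // reuse a healthy idle resource
                }
                // unhealthy: drop it and keep looking
            }
            return factory.get();     // nothing reusable: create a new one
        } catch (RuntimeException e) {
            permits.release();        // creation failed: give the permit back
            throw e;
        }
    }

    // Return a resource; unhealthy ones are discarded instead of pooled
    void release(T item) {
        if (isHealthy.test(item)) {
            idle.offer(item);
        }
        permits.release();
    }
}
```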

4. Client Module (rpc-client)

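4.1 Client Handler (Using Promise)
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Promise;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ChannelHandler.Sharable
public class RpcClientHandler extends SimpleChannelInboundHandler<RpcResponse> {
    private static final Logger logger = LoggerFactory.getLogger(RpcClientHandler.class);

    // Static, i.e. shared by every instance: ConnectionPool adds its own RpcClientHandler to each
    // pipeline, while RpcClientProxy calls sendRequest() on a separate instance. With a per-instance
    // map, responses would be looked up in a different map and every call would time out.
    private static final Map<String, RpcFuture<?>> PENDING_REQUESTS = new ConcurrentHashMap<>();
    private static final AtomicLong REQUEST_ID_GENERATOR = new AtomicLong(0);

    // Send a request; the returned Promise completes when the matching response arrives
    public <T> Promise<T> sendRequest(RpcRequest request, Channel channel) {
        String requestId = generateRequestId();
        request.setRequestId(requestId);

        Promise<T> promise = new DefaultPromise<>(channel.eventLoop());
        PENDING_REQUESTS.put(requestId, new RpcFuture<>(promise, request, channel));
        // However the promise completes, the request is no longer pending
        promise.addListener(f -> PENDING_REQUESTS.remove(requestId));

        channel.writeAndFlush(request).addListener(f -> {
            if (!f.isSuccess()) {
                promise.tryFailure(f.cause());
            }
        });
        return promise;
    }

    // Match each response to its pending request by requestId
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcResponse response) {
        RpcFuture<?> future = PENDING_REQUESTS.remove(response.getRequestId());
        if (future != null) {
            future.complete(response);
        } else {
            logger.warn("Received response for unknown request: {}", response.getRequestId());
        }
    }

    // Fail only the requests sent on this channel; other pooled channels are unaffected
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        logger.error("Client exception", cause);
        failPending(ctx.channel(), cause);
        ctx.close();
    }

    // A closed connection will never deliver its responses
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        failPending(ctx.channel(), new IllegalStateException("Connection closed"));
        super.channelInactive(ctx);
    }

    // Heartbeat: sent when IdleStateHandler reports WRITER_IDLE
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent
                && ((IdleStateEvent) evt).state() == IdleState.WRITER_IDLE) {
            RpcRequest heartbeat = RpcRequest.builder()
                .requestId("heartbeat-" + System.currentTimeMillis())
                .className("heartbeat")
                .oneway(true)
                .build();
            ctx.writeAndFlush(heartbeat);
        } else {
            super.userEventTriggered(ctx, evt); // pass other events down the pipeline
        }
    }

    private void failPending(Channel channel, Throwable cause) {
        for (RpcFuture<?> future : PENDING_REQUESTS.values()) {
            if (future.channel == channel) {
                future.promise.tryFailure(cause); // its listener removes the map entry
            }
        }
    }

    private String generateRequestId() {
        return System.currentTimeMillis() + "-" + REQUEST_ID_GENERATOR.incrementAndGet();
    }

    // A pending call: the promise plus the request and the channel it was sent on
    private static class RpcFuture<T> {
        private final Promise<T> promise;
        private final RpcRequest request;
        private final Channel channel;
        private final long timestamp;   // send time, e.g. for a periodic timeout sweep

        RpcFuture(Promise<T> promise, RpcRequest request, Channel channel) {
            this.promise = promise;
            this.request = request;
            this.channel = channel;
            this.timestamp = System.currentTimeMillis();
        }

        @SuppressWarnings("unchecked")
        void complete(RpcResponse response) {
            // try* instead of set*: the promise may already have failed (e.g. the channel closed)
            if (response.isSuccess()) {
                promise.trySuccess((T) response.getResult());
            } else {
                Throwable error = response.getError() != null
                    ? response.getError()
                    : new RuntimeException(response.getMessage());
                promise.tryFailure(error);
            }
        }
    }
}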
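The core of RpcClientHandler is the map from requestId to a pending result. The same correlation pattern is shown below with plain CompletableFuture, plus one thing the handler does not do: a timeout that removes the entry even when no response ever arrives. PendingCalls is a hypothetical name, and orTimeout requires Java 9 or later:

```java
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

// Hypothetical correlation table: each outgoing request gets a unique id and a future;
// the I/O thread completes that future when a response carrying the same id arrives.
final class PendingCalls {
    private final Map<Long, CompletableFuture<Object>> pending = new ConcurrentHashMap<>();
    private final AtomicLong ids = new AtomicLong();

    // Registers a call and returns its id
    long register(CompletableFuture<Object> future, long timeoutMillis) {
        long id = ids.incrementAndGet();
        pending.put(id, future);
        // Fail with TimeoutException if no response arrives in time,
        // and drop the entry however the future ends up completing
        future.orTimeout(timeoutMillis, TimeUnit.MILLISECONDS)
              .whenComplete((value, error) -> pending.remove(id));
        return id;
    }

    // Called for each decoded response; false for unknown or already timed-out ids
    boolean complete(long id, Object result, Throwable error) {
        CompletableFuture<Object> future = pending.remove(id);
        if (future == null) {
            return false;
        }
        return error == null ? future.complete(result) : future.completeExceptionally(error);
    }
}
```

The whenComplete hook is what keeps the map from growing when servers never reply; a late response then finds no entry (or an already completed future) and is ignored.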
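4.2 Client Proxy Factory
import io.netty.channel.Channel;
import io.netty.util.concurrent.Promise;
import java.lang.reflect.Proxy;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public class RpcClientProxy {
    private final ConnectionPool connectionPool;
    private final RpcClientHandler clientHandler;
    // LoadBalancer and RandomLoadBalancer are not defined in this article; with a single
    // host/port there is nothing to choose between, so the balancer is not used below
    private final LoadBalancer loadBalancer;

    public RpcClientProxy(String host, int port) {
        this.connectionPool = new ConnectionPool(host, port, 10, 2); // max 10, min 2 connections
        this.clientHandler = new RpcClientHandler();
        this.loadBalancer = new RandomLoadBalancer();
    }

    @SuppressWarnings("unchecked")
    public <T> T createProxy(Class<T> interfaceClass, RpcReference reference) {
        return (T) Proxy.newProxyInstance(
            interfaceClass.getClassLoader(),
            new Class<?>[]{interfaceClass},
            (proxy, method, args) -> {
                // equals/hashCode/toString are answered locally, not sent to the server
                if (method.getDeclaringClass() == Object.class) {
                    switch (method.getName()) {
                        case "equals":   return proxy == args[0];
                        case "hashCode": return System.identityHashCode(proxy);
                        default:         return "RpcProxy(" + interfaceClass.getName() + ")";
                    }
                }

                // Build the request
                RpcRequest request = RpcRequest.builder()
                    .className(interfaceClass.getName())
                    .methodName(method.getName())
                    .parameterTypes(method.getParameterTypes())
                    .parameters(args)
                    .version(reference.version())
                    .build();

                // Borrow a connection
                Channel channel = connectionPool.getConnection(
                    reference.timeout(), TimeUnit.MILLISECONDS);

                try {
                    // Send the request
                    Promise<Object> promise = clientHandler.sendRequest(request, channel);

                    // Async call: the method returns CompletableFuture (async = true assumes that
                    // return type too). A listener adapts the Promise, so no thread blocks waiting.
                    if (reference.async() || method.getReturnType() == CompletableFuture.class) {
                        CompletableFuture<Object> result = new CompletableFuture<>();
                        promise.addListener(f -> {
                            if (f.isSuccess()) {
                                result.complete(f.getNow());
                            } else {
                                result.completeExceptionally(f.cause());
                            }
                        });
                        return result.orTimeout(reference.timeout(), TimeUnit.MILLISECONDS); // Java 9+
                    }

                    // Sync call: block up to the configured timeout
                    if (!promise.await(reference.timeout())) {
                        // unchecked, so callers do not receive UndeclaredThrowableException
                        throw new IllegalStateException("RPC timeout after " + reference.timeout() + " ms");
                    }
                    if (promise.isSuccess()) {
                        return promise.getNow();
                    }
                    throw promise.cause();
                } finally {
                    // Return the connection right away; responses are matched by requestId,
                    // so an in-flight call does not need exclusive use of the channel
                    connectionPool.releaseConnection(channel);
                }
            }
        );
    }
}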
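Stripped of Netty, the proxy's job is to turn a method call into a request and decide whether the caller waits for the result. The sketch below takes the transport as a function returning a CompletableFuture, which makes that dispatch logic testable on its own: methods declared to return CompletableFuture get the future back, all others block on it. StubFactory, Call and the sample Calculator interface are hypothetical names:

```java
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Function;

// Hypothetical description of one call (mirrors the main RpcRequest fields)
final class Call {
    final String className;
    final String methodName;
    final Class<?>[] parameterTypes;
    final Object[] args;

    Call(Method method, Object[] args) {
        this.className = method.getDeclaringClass().getName();
        this.methodName = method.getName();
        this.parameterTypes = method.getParameterTypes();
        this.args = args == null ? new Object[0] : args; // Proxy passes null for no-arg methods
    }
}

// Hypothetical stub factory: every call goes through a transport that returns a future
final class StubFactory {
    static <T> T create(Class<T> iface, Function<Call, CompletableFuture<Object>> transport) {
        Object stub = Proxy.newProxyInstance(iface.getClassLoader(), new Class<?>[]{iface},
            (proxy, method, args) -> {
                // equals/hashCode/toString are answered locally
                if (method.getDeclaringClass() == Object.class) {
                    switch (method.getName()) {
                        case "equals":   return proxy == args[0];
                        case "hashCode": return System.identityHashCode(proxy);
                        default:         return "stub(" + iface.getName() + ")";
                    }
                }
                CompletableFuture<Object> future = transport.apply(new Call(method, args));
                if (method.getReturnType() == CompletableFuture.class) {
                    return future;              // async method: hand the future back
                }
                try {
                    return future.join();       // sync method: block for the result
                } catch (CompletionException e) {
                    throw e.getCause();         // rethrow the remote failure itself
                }
            });
        return iface.cast(stub);
    }
}

// Sample remote interface
interface Calculator {
    int add(int a, int b);
    CompletableFuture<Integer> addAsync(int a, int b);
}
```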

5. Server Module (rpc-server)

5.1 Service Registry
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// In-process map from "interfaceName:version" to the implementation
public class ServiceRegistry {
    private final Map<String, Object> serviceMap = new ConcurrentHashMap<>();
    private final Map<String, Class<?>> serviceClassMap = new ConcurrentHashMap<>();

    public void registerService(String serviceName, Object service, Class<?> serviceClass) {
        // Use the version declared in @RpcService so the key matches RpcRequest.version
        RpcService annotation = service.getClass().getAnnotation(RpcService.class);
        String version = annotation != null ? annotation.version() : "1.0";
        String key = buildServiceKey(serviceName, version);
        serviceMap.put(key, service);
        serviceClassMap.put(key, serviceClass);
    }

    public Object getService(String serviceName, String version) {
        return serviceMap.get(buildServiceKey(serviceName, version));
    }

    public Class<?> getServiceClass(String serviceName, String version) {
        return serviceClassMap.get(buildServiceKey(serviceName, version));
    }

    private String buildServiceKey(String serviceName, String version) {
        return serviceName + ":" + version;
    }
}
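Because services are keyed by interface name and version, the version used at registration has to be exactly the one clients send in RpcRequest.version. A minimal sketch of a registry that takes the version explicitly and rejects duplicate registrations; VersionedRegistry is a hypothetical name:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical version-aware registry: one implementation per (interface, version) pair
final class VersionedRegistry {
    private final Map<String, Object> services = new ConcurrentHashMap<>();

    <T> void register(Class<T> iface, String version, T impl) {
        String key = key(iface.getName(), version);
        // putIfAbsent makes a second registration for the same pair an error, not a silent overwrite
        if (services.putIfAbsent(key, impl) != null) {
            throw new IllegalStateException("Already registered: " + key);
        }
    }

    // Looks up by the interface name and version carried in the request
    Object lookup(String interfaceName, String version) {
        Object service = services.get(key(interfaceName, version));
        if (service == null) {
            throw new IllegalArgumentException("Service not found: " + key(interfaceName, version));
        }
        return service;
    }

    private static String key(String interfaceName, String version) {
        return interfaceName + ":" + version;
    }
}
```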
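5.2 Server Handler
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import io.netty.util.concurrent.DefaultThreadFactory;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// @Sharable: create one instance and add it to every channel,
// otherwise each connection gets its own thread pool
@ChannelHandler.Sharable
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcRequest> {
    private static final Logger logger = LoggerFactory.getLogger(RpcServerHandler.class);

    private final ServiceRegistry serviceRegistry;
    private final ThreadPoolExecutor executor;

    public RpcServerHandler(ServiceRegistry serviceRegistry) {
        this.serviceRegistry = serviceRegistry;
        // Service methods run here, so slow business code does not block Netty I/O threads
        this.executor = new ThreadPoolExecutor(
            10, 100, 60, TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(1000),
            new DefaultThreadFactory("rpc-server"),      // Netty's named thread factory
            new ThreadPoolExecutor.CallerRunsPolicy()    // when saturated, runs on the I/O thread
        );
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcRequest request) {
        // Heartbeats from RpcClientHandler only keep the connection alive; nothing to invoke
        if ("heartbeat".equals(request.getClassName())) {
            return;
        }

        executor.execute(() -> {
            try {
                Object result = processRequest(request);
                // One-way calls are executed but get no response
                if (!request.isOneway()) {
                    ctx.writeAndFlush(RpcResponse.builder()
                        .requestId(request.getRequestId())
                        .result(result)
                        .code(0)
                        .message("success")
                        .build());
                }
            } catch (Throwable t) {
                logger.error("Process request error", t);
                if (!request.isOneway()) {
                    ctx.writeAndFlush(RpcResponse.builder()
                        .requestId(request.getRequestId())
                        .error(t)
                        .code(500)
                        .message(t.getMessage())
                        .build());
                }
            }
        });
    }

    private Object processRequest(RpcRequest request) throws Throwable {
        Object service = serviceRegistry.getService(request.getClassName(), request.getVersion());
        if (service == null) {
            throw new IllegalStateException("Service not found: "
                + request.getClassName() + ":" + request.getVersion());
        }

        Class<?> serviceClass = serviceRegistry.getServiceClass(
            request.getClassName(), request.getVersion());
        Method method = serviceClass.getMethod(
            request.getMethodName(), request.getParameterTypes());

        try {
            Object result = method.invoke(service, request.getParameters());
            // A CompletableFuture cannot be sent over the wire: wait for it and send its value
            if (result instanceof CompletableFuture) {
                result = ((CompletableFuture<?>) result).get();
            }
            return result;
        } catch (InvocationTargetException e) {
            throw e.getTargetException();   // the service's own exception, not the reflection wrapper
        } catch (ExecutionException e) {
            throw e.getCause();             // the exception the async service method failed with
        }
    }

    // Close connections that sent nothing, not even a heartbeat, within the reader-idle time
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent
                && ((IdleStateEvent) evt).state() == IdleState.READER_IDLE) {
            ctx.close();
        } else {
            super.userEventTriggered(ctx, evt);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        logger.error("Server exception", cause);
        ctx.close();
    }
}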
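Server-side dispatch boils down to reflection. The sketch below isolates three details that are easy to miss: caching Method lookups, unwrapping InvocationTargetException so the caller sees the service's own exception, and waiting for a CompletableFuture result such as the one UserService.getUserAsync returns. Dispatcher is a hypothetical name, and the cache-key format is an arbitrary choice:

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;

// Hypothetical reflective dispatcher
final class Dispatcher {
    // 1. getMethod is comparatively slow, so resolved methods are cached
    private final Map<String, Method> methodCache = new ConcurrentHashMap<>();

    Object invoke(Object service, Class<?> iface, String methodName,
                  Class<?>[] parameterTypes, Object[] args) throws Throwable {
        String key = iface.getName() + "#" + methodName + Arrays.toString(parameterTypes);
        Method method = methodCache.computeIfAbsent(key, k -> {
            try {
                return iface.getMethod(methodName, parameterTypes);
            } catch (NoSuchMethodException e) {
                throw new IllegalArgumentException("No such method: " + k, e);
            }
        });
        try {
            Object result = method.invoke(service, args);
            // 2. An async service method returns a future: send its value, not the future
            if (result instanceof CompletableFuture) {
                result = ((CompletableFuture<?>) result).get();
            }
            return result;
        } catch (InvocationTargetException e) {
            // 3. Report the service's own exception rather than the reflection wrapper
            throw e.getTargetException();
        } catch (ExecutionException e) {
            throw e.getCause();
        }
    }
}
```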
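5.3 RPC Server
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RpcServer {
    private static final Logger logger = LoggerFactory.getLogger(RpcServer.class);
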
    private final int port;
    private final ServiceRegistry serviceRegistry;
    private EventLoopGroup bossGroup;
    private EventLoopGroup workerGroup;
    
    public RpcServer(int port) {
        this.port = port;
        this.serviceRegistry = new ServiceRegistry();
    }
    
    public void registerService(Object service) {
        Class<?>[] interfaces = service.getClass().getInterfaces();
        RpcService annotation = service.getClass().getAnnotation(RpcService.class);
        
        for (Class<?> interfaceClass : interfaces) {
            String version = annotation != null ? annotation.version() : "1.0";
            serviceRegistry.registerService(
                interfaceClass.getName(), 
                service, 
                interfaceClass
            );
            logger.info("Register service: {} version: {}", 
                interfaceClass.getName(), version);
        }
    }
    
    public void start() throws InterruptedException {
        bossGroup = new NioEventLoopGroup(1);
        workerGroup = new NioEventLoopGroup();
        
        try {
            // One shared handler, and therefore one business thread pool, for all connections
            RpcServerHandler serverHandler = new RpcServerHandler(serviceRegistry);
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                .channel(NioServerSocketChannel.class)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline p = ch.pipeline();
                        p.addLast(new IdleStateHandler(60, 0, 0, TimeUnit.SECONDS)); // READER_IDLE after 60 s
                        p.addLast(new RpcCodec());
                        p.addLast(serverHandler);
                    }
                })
                .option(ChannelOption.SO_BACKLOG, 128)
                .childOption(ChannelOption.SO_KEEPALIVE, true);
            
            ChannelFuture f = b.bind(port).sync();
            logger.info("RPC Server started on port {}", port);
            f.channel().closeFuture().sync();
        } finally {
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }
    
    public void stop() {
        if (bossGroup != null) {
            bossGroup.shutdownGracefully();
        }
        if (workerGroup != null) {
            workerGroup.shutdownGracefully();
        }
    }
}
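stop() shuts down the two event loop groups, but the ThreadPoolExecutor created inside RpcServerHandler is never shut down, and its core threads do not time out by default, so they can keep the JVM running. The helper below follows the two-phase shutdown sequence recommended in the ExecutorService Javadoc; ExecutorShutdown is a hypothetical name, and calling it from stop() would additionally need some way to reach the handler's executor:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;

// Hypothetical helper: stop accepting work, give running tasks a grace period, then interrupt
final class ExecutorShutdown {
    static boolean shutdownGracefully(ExecutorService executor, long timeout, TimeUnit unit) {
        executor.shutdown();                        // no new tasks; queued tasks still run
        try {
            if (executor.awaitTermination(timeout, unit)) {
                return true;
            }
            executor.shutdownNow();                 // interrupt whatever is still running
            return executor.awaitTermination(timeout, unit);
        } catch (InterruptedException e) {
            executor.shutdownNow();
            Thread.currentThread().interrupt();     // preserve the caller's interrupt status
            return false;
        }
    }
}
```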

6. Usage Example

6.1 Define the Service Interface
public interface UserService {
    User getUserById(Long id);
    List<User> findUsers(String name);
    CompletableFuture<User> getUserAsync(Long id);
}

@Data
public class User {
    private Long id;
    private String name;
    private String email;
}
6.2 Server-Side Implementation
@RpcService(version = "1.0")
public class UserServiceImpl implements UserService {
    @Override
    public User getUserById(Long id) {
        User user = new User();
        user.setId(id);
        user.setName("User-" + id);
        user.setEmail(id + "@example.com");
        return user;
    }
    
    @Override
    public List<User> findUsers(String name) {
        List<User> users = new ArrayList<>();
        for (int i = 1; i <= 3; i++) {
            User user = new User();
            user.setId((long) i);
            user.setName(name + "-" + i);
            users.add(user);
        }
        return users;
    }
    
    @Override
    public CompletableFuture<User> getUserAsync(Long id) {
        return CompletableFuture.supplyAsync(() -> getUserById(id));
    }
}

// Start the server
public class ServerApplication {
    public static void main(String[] args) throws InterruptedException {
        RpcServer server = new RpcServer(8080);
        server.registerService(new UserServiceImpl());
        server.start();
    }
}
6.3 Client Usage
@Service
public class UserController {
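    // Filled in by an injection step (e.g. a BeanPostProcessor) that builds the proxy; not shown in this article
    @RpcReference(version = "1.0", timeout = 3000)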
    private UserService userService;
    
    // Synchronous call
    public User getUserSync(Long id) {
        return userService.getUserById(id);
    }
    
    // Asynchronous call
    public CompletableFuture<User> getUserAsync(Long id) {
        return userService.getUserAsync(id);
    }
    
    // Creating a proxy manually, without Spring
    public static void main(String[] args) throws Exception {
        RpcClientProxy proxy = new RpcClientProxy("127.0.0.1", 8080);
        UserService userService = proxy.createProxy(
            UserService.class, 
            new RpcReference() {
                @Override
                public String version() { return "1.0"; }
                @Override
                public String group() { return ""; }
                @Override
                public long timeout() { return 3000; }
                @Override
                public boolean async() { return false; }
                @Override
                public String loadBalance() { return "random"; }
                @Override
                public Class<? extends Annotation> annotationType() { return RpcReference.class; }
            }
        );
        
        User user = userService.getUserById(1L);
        System.out.println("User: " + user);
    }
}

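Key Features

  1. Connection pool management: configurable minimum and maximum connection counts; closed connections are dropped from the pool and replaced on demand
  2. Promise-based async handling: Netty's DefaultPromise carries each call's result
  3. Annotation-driven: @RpcService and @RpcReference annotations
  4. Sync and async calls: callers can block for the result or receive a CompletableFuture
  5. Heartbeat: keeps idle connections alive and detects broken ones
  6. Pluggable serialization: the Serializer interface lists JSON, Hessian, Protostuff and Kryo; Protostuff is the implementation shown here
  7. Load balancing: strategies such as random and round-robin (the example references a random balancer)
  8. Timeout control: a per-reference timeout set through @RpcReference(timeout = ...)
  9. One-way calls: for calls that need no return value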
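RpcClientProxy refers to LoadBalancer and RandomLoadBalancer without defining them. One possible shape is sketched below, assuming the balancer picks a "host:port" string from a list of provider addresses (for example, one obtained from a service-discovery client); the select method and RoundRobinLoadBalancer are hypothetical:

```java
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical shape of the LoadBalancer type referenced by RpcClientProxy
interface LoadBalancer {
    String select(List<String> addresses); // returns one "host:port"
}

// Random: every provider has the same chance on every call
final class RandomLoadBalancer implements LoadBalancer {
    @Override
    public String select(List<String> addresses) {
        if (addresses.isEmpty()) {
            throw new IllegalStateException("No provider available");
        }
        return addresses.get(ThreadLocalRandom.current().nextInt(addresses.size()));
    }
}

// Round-robin: providers are used in turn
final class RoundRobinLoadBalancer implements LoadBalancer {
    private final AtomicInteger next = new AtomicInteger();

    @Override
    public String select(List<String> addresses) {
        if (addresses.isEmpty()) {
            throw new IllegalStateException("No provider available");
        }
        // floorMod keeps the index non-negative after the counter overflows
        return addresses.get(Math.floorMod(next.getAndIncrement(), addresses.size()));
    }
}
```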

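Suggested Improvements

  1. Service discovery: integrate a registry such as ZooKeeper or Nacos
  2. Circuit breaking and degradation: integrate Hystrix or Sentinel
  3. Monitoring: collect metrics
  4. Configuration center: support dynamic configuration updates
  5. Distributed tracing: integrate SkyWalking or Zipkin

This framework covers the basic features an RPC framework needs in production and can be extended and optimized for specific requirements.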
