Netty 整合 WebSocket 通过拆包来发送超大数据包的方案

引言

Netty 是一个基于 NIO 的网络通信框架,提供了高性能的异步事件驱动机制。WebSocket 是一种在单个 TCP 连接上进行全双工通信的协议,常用于实现实时通信功能。通过 Netty 实现 WebSocket,开发者能够轻松构建出强大的实时应用程序。

在整合过程中,通常通过 WebSocketServerProtocolHandler 来实现 WebSocket 协议,其中该类中设置了最大帧长度为 65536,当然,可以通过构造函数中的 maxFramePayloadLength 参数来指定最大帧长度。关于这个参数,netty 源码中如下介绍:

maxFramePayloadLength -- Maximum length of a frame's payload. Setting this to an appropriate value for you application helps check for denial of services attacks.

可以看到,该值的主要作用是为了防止异常情况或者恶意攻击引发问题,而不是想要限制实际应用中合理的消息大小。

但在实际需求中,可能需要发送比较长的数据包,又不能确定具体需要多大的限制,而设置太大,又失去了这个属性的初衷。

这里,可以用拆分数据包的方式,将大数据包拆分为多个小数据包,分批发送,并在接收端重新组装这些小数据包,从而解决这个问题。

实现拆分数据包发送

拆分数据包(发送端)

拆包发送

在发送端,通过 Netty 的 ChannelHandlerContextwrite 方法,将大数据包拆分为多个小数据包,并逐个发送。简单实现如下:

java 复制代码
// byteBuf 为原有数据包
while(content.isReadable()) {
    // CHUNK_SIZE : 每个数据包的最大长度
    int chunkSize = Math.min(content.readableBytes(), CHUNK_SIZE);
    // 拆分数据包
    ByteBuf chunk = content.readSlice(chunkSize).retain();
    // 发送
    ctx.writeAndFlush(new BinaryWebSocketFrame(chunk));
}

解决半包

这里,并没有告知接收端,接收到的数据包,该什么时候聚合。这里可以通过 Netty 解决半包常用的两种方式:

  • 发送长度
  • 包尾增加分隔符

这里先以发送长度的方式为例来实现,在消息体最开始的位置,添加长度。

java 复制代码
// 这里声明一个新的 ByteBuf
ByteBuf byteBuf = ctx.alloc().buffer();
// header
byteBuf.writeInt(content.readableBytes());
// body
byteBuf.writeBytes(content);

// 原来的消息体用不到了,手动释放掉
ReferenceCountUtil.release(content);

while (byteBuf.isReadable()) {
    int chunkSize = Math.min(byteBuf.readableBytes(), CHUNK_SIZE);
    System.out.println("分片数据长度 >>>>>>>>>> " + chunkSize);
    ByteBuf chunk = byteBuf.readSlice(chunkSize).retain();
    ctx.writeAndFlush(new BinaryWebSocketFrame(chunk));
}

组装数据包(接收端)

在接收端,通过 Netty 的 ChannelHandlerContextchannelRead 方法,累积接收到的小数据包,并判断是否已经接收到一个完整的数据包。如果是,则进行相应的处理。

组装数据

java 复制代码
private ByteBuf cumulativeBuffer = Unpooled.buffer();

public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
     // 这里因为是 WebSocket 协议,接收到的数据包为 WebSocketFrame
     if (msg instanceof WebSocketFrame) {
         WebSocketFrame frame = (WebSocketFrame) msg;
         ByteBuf content = frame.content();
         
         // 累积数据包
         cumulativeBuffer.writeBytes(content);
         
         // 根据判断,判断是否已经接收到完整数据包
         if (isCompletePacket(cumulativeBuffer)) {
             try {
                 // 处理完整数据包 cumulativeBuffer
             } finally {
                 // 最终释放 cumulativeBuffer 数据包的内存
                ReferenceCountUtil.release(cumulativeBuffer);
                cumulativeBuffer = Unpooled.buffer();
             }
         }
     } else {
         super.channelRead(ctx, msg);
     }
}

组装数据简单可以分为两步:

  1. 累积数据包
  2. 判断数据包是否完整,如果完整则处理该包,最终释放该包

这里的判断数据包是否完整,并没有实现,这里需要与发送端保持一致,也可以通过判断长度或者判断包尾分隔符来判断。这里以判断长度为例。

首先,需要解析长度:

java 复制代码
private int expectedLength = -1;

public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
     // 这里因为是 WebSocket 协议,接收到的数据包为 WebSocketFrame
     if (msg instanceof WebSocketFrame) {
         // ... ...
         
         if (expectedLength == -1) {
             expectedLength = content.readInt();
         }
         
        // ... ...
     } else {
         super.channelRead(ctx, msg);
     }
}

接下来,判断数据包是否完整方法如下:

java 复制代码
private boolean isCompletePacket(ByteBuf cumulativeBuffer) {
    return cumulativeBuffer.readableBytes() == expectedLength;
}

数据包完成,并处理完后, 同时需要将 expectedLength 改为 -1

最后,防止内存泄露,还需要在关闭时,将内存释放。

完整代码

完整代码如下,完整代码中增加了部分异常场景的判断。

java 复制代码
public class PacketSplitHandler extends ChannelDuplexHandler {

    private static final Log LOG = LogFactory.get(PacketSplitHandler.class);

    private static final int CHUNK_SIZE = 65536;

    private static final int HEADER_SIZE = 4;

    private ByteBuf cumulativeBuffer = Unpooled.buffer();

    private int expectedLength = -1;

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        if (msg instanceof ByteBuf) {
            ByteBuf content = (ByteBuf) msg;

            System.out.println("发送总数据长度 >>>>>>> " + content.readableBytes());
            ByteBuf byteBuf = ctx.alloc().buffer();
            // header
            byteBuf.writeInt(content.readableBytes());
            // body
            byteBuf.writeBytes(content);

            ReferenceCountUtil.release(content);

            while (byteBuf.isReadable()) {
                int chunkSize = Math.min(byteBuf.readableBytes(), CHUNK_SIZE);
                System.out.println("分片数据长度 >>>>>>>>>> " + chunkSize);
                ByteBuf chunk = byteBuf.readSlice(chunkSize).retain();
                ctx.writeAndFlush(new BinaryWebSocketFrame(chunk));
            }
        } else {
            super.write(ctx, msg, promise);
        }
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof WebSocketFrame) {
            WebSocketFrame frame = (WebSocketFrame) msg;
            ByteBuf content = frame.content();

            System.out.println("接收到数据包长度 >>>>>>>>>> " + content.readableBytes());

            if (expectedLength == -1 && content.readableBytes() > HEADER_SIZE) {
                expectedLength = content.readInt();
                System.out.println("期望长度 >>>>>>>>>> " + expectedLength);
                if (expectedLength <= -1) {
                    releaseCumulativeBuffer();
                    super.channelRead(ctx, msg);
                    return;
                }
            }

            cumulativeBuffer.writeBytes(content);
            System.out.println("当前汇总数据包长度 >>>>>>>>>>" + cumulativeBuffer.readableBytes());

            if (isCompletePacket(cumulativeBuffer)) {
                ByteBuf buffer = ctx.alloc().buffer();
                buffer.writeBytes(cumulativeBuffer);
                releaseCumulativeBuffer();
                super.channelRead(ctx, buffer);
            } else if (cumulativeBuffer.readableBytes() > expectedLength) {
                LOG.error("data package length error, exceeding expected length");
                releaseCumulativeBuffer();
                super.channelRead(ctx, msg);
            } else {
                super.channelRead(ctx, msg);
            }
        } else {
            super.channelRead(ctx, msg);
        }
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        try {
            ReferenceCountUtil.release(cumulativeBuffer);
        } catch (Throwable e) {
            // ignore
        }
        super.channelInactive(ctx);
    }

    private void releaseCumulativeBuffer() {
        try {
            expectedLength = -1;
            ReferenceCountUtil.release(cumulativeBuffer);
            cumulativeBuffer = Unpooled.buffer();
        } catch (Throwable e) {
            // ignore
        }
    }

    private boolean isCompletePacket(ByteBuf cumulativeBuffer) {
        return cumulativeBuffer.readableBytes() == expectedLength;
    }

}

示例

服务端

java 复制代码
public class PacketSplitServer {

    public static void main(String[] args) throws InterruptedException {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();

        try {
            ServerBootstrap bootstrap = new ServerBootstrap();
            bootstrap.group(bossGroup, workerGroup)
                .channel(NioServerSocketChannel.class)
                .childHandler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline pipeline = ch.pipeline();
                        pipeline.addLast(new HttpServerCodec());
                        pipeline.addLast(new HttpObjectAggregator(1024 * 1024 * 10));
                        pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true));
                        pipeline.addLast(new NettyServerHandler());
                    }
                });

            ChannelFuture channelFuture = bootstrap.bind(8999).sync();
            channelFuture.channel().closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

    static class NettyServerHandler extends ChannelInboundHandlerAdapter {
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof WebSocketFrame) {
                WebSocketFrame webSocketFrame = (WebSocketFrame) msg;
                System.out.println("接收到的数据包大小 : " + ((WebSocketFrame) msg).content().readableBytes());
            }
            super.channelRead(ctx, msg);
        }
    }

}

客户端

java 复制代码
public class PacketSplitClient {


    public static void main(String[] args) throws URISyntaxException, InterruptedException {
        URI uri = new URI("ws://localhost:8999/ws");
        EventLoopGroup group = new NioEventLoopGroup();

        WebSocketClientProtocolConfig webSocketClientProtocolConfig =
            WebSocketClientProtocolConfig.newBuilder().webSocketUri(uri).subprotocol(null).allowExtensions(true)
                .version(WebSocketVersion.V13).customHeaders(new DefaultHttpHeaders()).build();

        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group)
                .channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    protected void initChannel(SocketChannel ch) {
                        ChannelPipeline pipeline = ch.pipeline();
                        pipeline.addLast(new HttpClientCodec());
                        pipeline.addLast(new HttpObjectAggregator(1024 * 1024 * 10));
                        pipeline.addLast(new WebSocketClientProtocolHandler(webSocketClientProtocolConfig));
                        pipeline.addLast(new NettyClientHandler());
                    }
                });

            Channel channel = bootstrap.connect(uri.getHost(), uri.getPort()).sync().channel();
            channel.closeFuture().sync();
        } finally {
            group.shutdownGracefully();
        }
    }

    static class NettyClientHandler extends ChannelInboundHandlerAdapter {
        @Override
        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
            if (WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE.equals(evt)) {
                new ScheduledThreadPoolExecutor(1)
                    .scheduleWithFixedDelay(() -> {
                        System.out.println("ws连接成功");
                        int length = 65537;
                        byte[] bytes = new byte[length];
                        Arrays.fill(bytes, (byte) 127);

                        ByteBuf byteBuf = ctx.alloc().buffer();
                        byteBuf.writeBytes(bytes);

                        ctx.writeAndFlush(new BinaryWebSocketFrame(byteBuf));
                    }, 2, 30, TimeUnit.SECONDS);
            }
            super.userEventTriggered(ctx, evt);
        }
    }

}

这里服务端在握手成功后,发送一个 65537 长度的数据包,由于服务端默认配置数据帧最大为 65536,故服务端控制台抛出异常:

lua 复制代码
io.netty.handler.codec.http.websocketx.CorruptedWebSocketFrameException: Max frame length of 65536 has been exceeded.

测试拆包发送

在服务端 WebSocketServerProtocolHandler 后面添加 PacketSplitHandler

java 复制代码
// ... ...
pipeline.addLast(new WebSocketServerProtocolHandler("/ws", null, true));
pipeline.addLast(new PacketSplitHandler());
pipeline.addLast(new NettyServerHandler());
// ... ... 

因为在 PacketSplitHandlerchannelRead 方法中定义聚合后的数据包类型为 ByteBuf,故 NettyServerHandler 修改如下:

java 复制代码
static class NettyServerHandler extends ChannelInboundHandlerAdapter {
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof ByteBuf) {
            SSystem.out.println(">>>>>>>> 接收到的数据包大小 : " + ((ByteBuf) msg).readableBytes());
        }
        super.channelRead(ctx, msg);
    }
}

同时,在客户端 WebSocketClientProtocolHandler 后面添加 PacketSplitHandler

java 复制代码
// ... ...
pipeline.addLast(new WebSocketClientProtocolHandler(webSocketClientProtocolConfig));
pipeline.addLast(new PacketSplitHandler());
pipeline.addLast(new NettyClientHandler());
// ... ...

由于在 PacketSplitHandlerwrite 方法中定义只处理 ByteBuf 类型的数据,故发送数据时,修改为:

java 复制代码
ctx.writeAndFlush(byteBuf);

重新测试

  • 客户端打印日志如下:

    发送总数据长度 >>>>>>> 65537
    分片数据长度 >>>>>>>>>> 65536
    分片数据长度 >>>>>>>>>> 5

  • 服务端打印日志如下:

shell 复制代码
接收到数据包长度 >>>>>>>>>> 65536
期望长度 >>>>>>>>>> 65537
当前汇总数据包长度 >>>>>>>>>>65532
接收到数据包长度 >>>>>>>>>> 5
当前汇总数据包长度 >>>>>>>>>>65537
>>>>>>>> 接收到的数据包大小 : 65537

可以看到,65537 长度的数据包,分为两个包发送, 第一个包前 4 位为该包总长度(65537),第一包长度为 65536(4 + 65532),第一个包实际发送数据长度为 65532, 第二个不需要包含长度,故第二个包的长度为 5(65537 - 65532)。

我是「代码笔耕 」,致力于打造高效简洁、稳定可靠代码的后端开发。 本文可能存在纰漏或错误,如有问题欢迎指正,感谢您阅读这篇文章,如果觉得还行的话,不要忘记点赞、评论、收藏喔!

最后欢迎大家关注我的公众号「代码笔耕」和开源项目:easii (easii) - Gitee.com

相关推荐
qq_441996052 小时前
Mybatis官方生成器使用示例
java·mybatis
巨大八爪鱼2 小时前
XP系统下用mod_jk 1.2.40整合apache2.2.16和tomcat 6.0.29,让apache可以同时访问php和jsp页面
java·tomcat·apache·mod_jk
码上一元4 小时前
SpringBoot自动装配原理解析
java·spring boot·后端
计算机-秋大田4 小时前
基于微信小程序的养老院管理系统的设计与实现,LW+源码+讲解
java·spring boot·微信小程序·小程序·vue
魔道不误砍柴功6 小时前
简单叙述 Spring Boot 启动过程
java·数据库·spring boot
失落的香蕉6 小时前
C语言串讲-2之指针和结构体
java·c语言·开发语言
枫叶_v6 小时前
【SpringBoot】22 Txt、Csv文件的读取和写入
java·spring boot·后端
wclass-zhengge6 小时前
SpringCloud篇(配置中心 - Nacos)
java·spring·spring cloud
路在脚下@6 小时前
Springboot 的Servlet Web 应用、响应式 Web 应用(Reactive)以及非 Web 应用(None)的特点和适用场景
java·spring boot·servlet
黑马师兄6 小时前
SpringBoot
java·spring