【编码】自定义通信协议——实现零拷贝文件传输

前言

上一篇随笔,介绍了如何扩展自定义协议的请求类型。本篇随笔我将介绍如何基于这个自定义协议来实现文件传输,其中将涉及数据分片零拷贝

在设计自定义协议之前,我们首先了解一下HTTP协议是如何处理文件传输的。

HTTP协议的实现方式

在这里,我们主要讨论应用最广泛的HTTP/1.1协议

关于数据分片

HTTP 协议本身是一个纯文本协议,其中的 Content-Length 头部字段用于指定响应体(body)的内容长度。Content-Length 是纯文本格式,理论上没有长度限制,因此在大多数情况下,HTTP 协议可以一次性传输整个文件。

对于较大的文件,通常情况下,可以通过一个请求下载整个文件,这也是许多网站和服务的常见做法。但如果文件特别大,或者为了提高下载效率(例如支持断点续传、并行下载等),就需要在应用层处理文件的分片。例如,服务端可以先返回文件的分段信息,然后客户端逐个请求文件的不同部分。

关于零拷贝

HTTP 协议的客户端库通常不暴露底层的 socket 连接,导致上层应用无法直接操作 socket 进行零拷贝传输。

大多数情况下,数据需要先被拷贝到进程的内存中,再传输给 HTTP 客户端。

由于 HTTP 客户端库的限制,零拷贝技术在 HTTP 协议的应用中并不直接适用。

自定义协议

关于数据分片

在自定义协议中,我们可以更灵活地控制传输过程。例如,我们只使用 3 个字节来表示消息体的长度,因此协议的最大传输内容为 16MB(2^24 - 1 字节)。

对于超出该限制的内容,我们必须进行分块处理,确保每个数据块都符合协议的长度限制。

关于零拷贝

自定义协议可以引用到Socket,所以可以使用零拷贝,避免数据在内存和磁盘之间的多次拷贝,从而提高传输效率,减少 CPU 的负载。

初步设计

如何构建数据包?

消息体是一个完整的ProtoBuf BaseResponse消息

  • msgId:请求ID
  • headers:文件名+文件大小+分块数量+分块号
  • bytes:文件分块数据
复制代码
message BaseResponse {
    required int32 msgId = 1;
    repeated Header headers = 2;
    optional bytes data = 3;
}

消息体分两部分发送

1.先发送元数据(BaseResponse的msgId+headers)

2.后发送文件数据

服务端:

1.截取文件范围得到chunkSize

2.构建BaseResponse(仅包含msgId和headers)

3.计算得到消息体Length = BaseResponse的大小+chunkSize

4.发出消息头

5.发出BaseResponse

6.零拷贝发出文件chunk

客户端:

1.将消息体作为一个完整的BaseResponse进行解析。

冲突?ProtoBuf与零拷贝

在处理过程中,我们会遇到一个问题:ProtoBuf 的解析过程需要特定的编码格式,拼接进去的文件内容无法直接作为 ProtoBuf 消息的一部分。

如果需要ProtoBuf能识别这个文件内容,则文件数据必须参与编码,要参与编码就得载入到进程内存中。这跟零拷贝是相悖的。

如何处理这个问题?

再加一个length!消息体分为三部分:

  • 2字节,作为proto消息的长度信息。(元数据字节数有限,2字节足够表示)
  • n字节,proto消息(msgId+headers)
  • n字节,文件chunk数据

处理逻辑

1)服务端代码

Java的零拷贝API是FileChannel.transferTo(long position, long count, WritableByteChannel)。

不过Netty的Channel不是WritableByteChannel的子类。要使用零拷贝,得用Netty提供的FileRegion。底层也是调用FileChannel的transferTo。

复制代码
    public void handleDownloadRequest(BaseRequest baseRequest, ChannelHandlerContext ctx) throws Exception {
        File file = new File("F:\\redis.log");
        RandomAccessFile raf = new RandomAccessFile(file, "r");
        FileChannel fileChannel = null;

        long fileLength = raf.length();
        System.out.println("file length" + fileLength);
        long offset = 0;


        int chunkIndex = 0;
        int totalChunks = (int) Math.ceil((double) fileLength / MAX_CHUNK_SIZE);
        boolean firstPackage = true;

        while(offset < fileLength) {
            raf = new RandomAccessFile(file, "r");
            fileChannel = raf.getChannel();
            System.out.println("open:"+fileChannel.isOpen());
            //文件块大小
            long chunkSize = Math.min(MAX_CHUNK_SIZE, fileLength - offset);
            System.out.println("chunkSize:"+chunkSize);

            // 创建 FileRegion 来传输当前文件块
            FileRegion fileRegion = new DefaultFileRegion(fileChannel, offset, chunkSize);


            List<Header> headers = new ArrayList<>();
            if(firstPackage) {
                headers.add(Header.newBuilder().setKey("fileName").setValue(file.getName()).build());
                headers.add(Header.newBuilder().setKey("fileSize").setValue(String.valueOf(fileLength)).build());
                headers.add(Header.newBuilder().setKey("totalChunks").setValue(String.valueOf(totalChunks)).build());
            }
            headers.add(Header.newBuilder().setKey("chunkIndex").setValue(String.valueOf(chunkIndex)).build());

            //发送消息体的上半部分(msgId+headers)
            BaseResponse response = BaseResponse.newBuilder()
                    .setMsgId(baseRequest.getMsgId())
                    .addAllHeaders(headers)
                    .build();
            byte[] payloadHeadBytes = response.toByteArray();
            long bodyLength = 2 + payloadHeadBytes.length + chunkSize; //两个字节

            byte[] lengthBytes = new byte[3];
            lengthBytes[0] = (byte) (bodyLength >> 16);
            lengthBytes[1] = (byte) (bodyLength >> 8);
            lengthBytes[2] = (byte) bodyLength;

            //protobuf长度
            long length2 = payloadHeadBytes.length;
            byte[] lengthBytes2 = new byte[2];
            lengthBytes2[0] = (byte) (length2 >> 8);
            lengthBytes2[1] = (byte) (length2);

            //发送消息头+消息体的上半部分
            ByteBuf byteBuf = Unpooled.copiedBuffer(new byte[]{5}, lengthBytes, lengthBytes2, payloadHeadBytes);
            ChannelFuture f1 = ctx.channel().writeAndFlush(byteBuf);
            f1.sync();
//            System.out.println("f1:"+f1.isSuccess());
            //零拷贝写出文件数据(文件内容无需进入用户区内存,直接拷贝到socket发送缓冲区)
            ChannelFuture f2 = ctx.writeAndFlush(fileRegion);
            f2.sync();
//            System.out.println("f2:"+f2.isSuccess());

            firstPackage = false;
            // 更新偏移量
            offset += chunkSize;
            System.out.println("写出:"+bodyLength);

            raf.close();
        }

    }

2)客户端代码

复制代码
public class DownloadManager {
    private Map<Integer, FileDownContext> waitingMap = new ConcurrentHashMap<>();

    public void addToMap(Integer msgId, CompletableFuture<String> waiter) {
        waitingMap.put(msgId, new FileDownContext(null, null, 0L, 0.0d, waiter));
    }

    public void onResponse(BaseResponse response) {
//        System.out.println("收到:"+response.getMsgId());
        Integer msgId = response.getMsgId();
        FileDownContext context = waitingMap.get(msgId);
        if(Objects.isNull(context)) {
            return;
        }
        //首包带有这两个信息
        for (Header header : response.getHeadersList()) {
            if(StrUtil.equals(header.getKey(), "fileName")) {
                context.setFileName(header.getValue());
            }
            if(StrUtil.equals(header.getKey(), "totalChunks")) {
                context.setTotalChunks(Long.parseLong(header.getValue()));
            }
        }
        //更新接收情况
        context.receivedChunks++;
        context.progress = (double)context.receivedChunks/context.totalChunks;

        try {
            //文件如果不存在,则创建
            Path filePath = Paths.get("F:\\clientDownload\\" + context.fileName);
            if(!Files.exists(filePath)) {
                Files.createFile(filePath);
            }
            //追加写入文件
            Files.write(filePath, response.getData().toByteArray(), StandardOpenOption.APPEND);
        } catch (IOException e) {
            e.printStackTrace();
        }
        //完成请求,释放Context
        if(Objects.equals(context.receivedChunks, context.totalChunks)) {
            context.waiter.complete(context.fileName);
            waitingMap.remove(msgId);
        }
    }

    @Data
    @AllArgsConstructor
    class FileDownContext {
        String fileName;
        Long totalChunks;
        Long receivedChunks;
        Double progress;
        CompletableFuture<String> waiter;
    }
}