[netty5: WebSocketFrameEncoder & WebSocketFrameDecoder]-源码解析

WebSocketFrameMaskGenerator

WebSocketFrameMaskGenerator 是客户端用于生成 WebSocket 帧掩码的接口。它通过 nextMask() 返回一个 4 字节(32 位)整数掩码,用于对帧的 payload 做掩码处理(按字节异或)。注意,掩码只是防止中间代理缓存污染的混淆手段,并非加密。

java 复制代码
/**
 * Supplies the 32-bit masking keys used when a WebSocket endpoint masks outgoing frame payloads.
 *
 * <p>Each call to {@link #nextMask()} provides the key for one frame; the key's four bytes are XOR-ed
 * over the payload bytes in turn.
 */
@FunctionalInterface
public interface WebSocketFrameMaskGenerator {

    /**
     * Returns the masking key for the next frame.
     *
     * @return a 32-bit masking key; any {@code int} value (including 0) is permitted
     */
    int nextMask();
}

RandomWebSocketFrameMaskGenerator

RandomWebSocketFrameMaskGenerator 实现了掩码生成接口,使用 ThreadLocalRandom 生成随机的 4 字节整数作为 WebSocket 帧的掩码。

java 复制代码
/**
 * {@link WebSocketFrameMaskGenerator} that draws every masking key from {@link ThreadLocalRandom}.
 *
 * <p>This class has no state and cannot be instantiated from outside; use {@link #INSTANCE}.
 *
 * <p>NOTE(review): RFC 6455 section 5.3 requires the masking key to be derived from a strong source of
 * entropy, and {@link ThreadLocalRandom} is not a cryptographically secure generator. Confirm that this
 * speed-over-unpredictability trade-off is intended.
 */
public final class RandomWebSocketFrameMaskGenerator implements WebSocketFrameMaskGenerator {

    /** The single shared instance. */
    public static final RandomWebSocketFrameMaskGenerator INSTANCE = new RandomWebSocketFrameMaskGenerator();

    /** Hidden constructor: callers must use {@link #INSTANCE}. */
    private RandomWebSocketFrameMaskGenerator() {
        // nothing to initialise
    }

    /**
     * Produces a fresh pseudo-random masking key from the calling thread's {@link ThreadLocalRandom}.
     *
     * @return a pseudo-random {@code int}
     */
    @Override
    public int nextMask() {
        final ThreadLocalRandom random = ThreadLocalRandom.current();
        return random.nextInt();
    }
}

WebSocketFrameEncoder

WebSocketFrameEncoder 负责将 WebSocketFrame 编码为符合 WebSocket 协议格式的二进制数据帧,处理帧头构造、负载长度扩展、掩码生成与数据异或。

java 复制代码
/**
 * Marker type for {@link ChannelHandler}s that turn WebSocket frames into their wire encoding.
 * It declares no members of its own.
 */
public interface WebSocketFrameEncoder
        extends ChannelHandler {
    // intentionally empty
}

WebSocket13FrameEncoder

WebSocket13FrameEncoder 将 WebSocket 帧编码成符合 RFC 6455 的二进制格式。它根据负载长度选用 7 位、16 位或 64 位的长度字段,并按帧自身的 FIN 标志写出分片帧。在配置了掩码生成器时(客户端),它还会按规范对负载进行掩码处理。

java 复制代码
/**
 * Encodes {@link WebSocketFrame}s into the RFC 6455 (WebSocket version 13) wire format.
 *
 * <p>When a {@link WebSocketFrameMaskGenerator} is configured, every frame carries a masking key
 * obtained from it and the payload is XOR-ed with that key. When the generator is {@code null},
 * payloads are written unmasked. Masked payloads and small unmasked payloads are copied into the
 * header buffer. Larger unmasked payloads are emitted as a separate buffer so that the transport can
 * perform a gathering write.
 */
public class WebSocket13FrameEncoder extends MessageToMessageEncoder<WebSocketFrame> implements WebSocketFrameEncoder {

    private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameEncoder.class);

    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    /** Maximum payload length of any control frame (RFC 6455, section 5.5). */
    private static final int MAX_CONTROL_FRAME_PAYLOAD = 125;

    /**
     * Unmasked payloads up to this many bytes are copied into the header buffer.
     * Larger ones are emitted as a separate buffer.
     */
    private static final int GATHERING_WRITE_THRESHOLD = 1024;

    /** Supplies masking keys; {@code null} means frames are written unmasked. */
    private final WebSocketFrameMaskGenerator maskGenerator;

    /**
     * Creates an encoder that either masks with {@link RandomWebSocketFrameMaskGenerator#INSTANCE} or
     * does not mask at all.
     *
     * @param maskPayload {@code true} to mask every outgoing payload, {@code false} to send it unmasked
     */
    public WebSocket13FrameEncoder(boolean maskPayload) {
        this(maskPayload ? RandomWebSocketFrameMaskGenerator.INSTANCE : null);
    }

    /**
     * Creates an encoder that masks with the given generator.
     *
     * @param maskGenerator source of masking keys, or {@code null} to send payloads unmasked
     */
    public WebSocket13FrameEncoder(WebSocketFrameMaskGenerator maskGenerator) {
        this.maskGenerator = maskGenerator;
    }

    //  0               1               2               3
    //  0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
    // |I|S|S|S|  (4)  |A|     (7)     |     (16/64 bits if needed)    |
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |     Masking key (32 bits, only if MASK set to 1)              |
    // +---------------------------------------------------------------+
    // |     Masked/unmasked payload data                              |
    // +---------------------------------------------------------------+

    /**
     * Writes the frame header (and, where applicable, the payload) for {@code msg} into {@code out}.
     *
     * @throws TooLongFrameException if {@code msg} is a control frame whose payload exceeds 125 bytes
     * @throws UnsupportedOperationException if {@code msg} is not one of the known frame types
     */
    @Override
    protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List<Object> out) throws Exception {
        final Buffer data = msg.binaryData();
        final byte opcode = getOpCode(msg);
        final int length = data.readableBytes();

        if (logger.isTraceEnabled()) {
            logger.trace("Encoding WebSocket Frame opCode={} length={}", opcode, length);
        }

        // First header byte: FIN (bit 7), RSV1-3 (bits 6-4), opcode (bits 3-0).
        int b0 = 0;
        if (msg.isFinalFragment()) {
            b0 |= 1 << 7;
        }
        b0 |= (msg.rsv() & 0x07) << 4;
        b0 |= opcode & 0x0F;

        // RFC 6455 section 5.5: every control frame (CLOSE, PING, PONG) must carry at most 125 payload
        // bytes. A compliant peer (WebSocket13FrameDecoder included) treats a larger one as a protocol
        // violation, so refuse to produce it.
        if (isControlOpcode(opcode) && length > MAX_CONTROL_FRAME_PAYLOAD) {
            throw new TooLongFrameException("invalid payload for " + controlFrameName(opcode)
                    + " (payload length must be <= " + MAX_CONTROL_FRAME_PAYLOAD + ", was " + length + ')');
        }

        final boolean masked = maskGenerator != null;
        final int maskLength = masked ? 4 : 0;

        Buffer buf = null;
        try {
            if (length <= 125) {
                // 7-bit length: b0 + (MASK | len) + optional 4-byte key + payload.
                int size = 2 + maskLength + length;
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (masked ? 0x80 | length : length));
            } else if (length <= 0xFFFF) {
                // 16-bit extended length: b0 + (MASK | 126) + 2-byte length + optional key
                // (+ payload when it is copied into this buffer).
                int size = 4 + maskLength;
                if (masked || length <= GATHERING_WRITE_THRESHOLD) {
                    size += length;
                }
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (masked ? 0xFE : 126));
                buf.writeByte((byte) (length >>> 8 & 0xFF));
                buf.writeByte((byte) (length & 0xFF));
            } else {
                // 64-bit extended length: b0 + (MASK | 127), i.e. 0xFF when masked, + 8-byte length
                // + optional key (+ payload when masked). Here length > 0xFFFF, which is above
                // GATHERING_WRITE_THRESHOLD, so an unmasked payload is never copied.
                int size = 10 + maskLength;
                if (masked) {
                    size += length;
                }
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (masked ? 0xFF : 127));
                buf.writeLong(length);
            }

            if (masked) {
                int mask = maskGenerator.nextMask();
                buf.writeInt(mask);

                if (mask != 0) {
                    // Each payload byte is XOR-ed with key byte (i % 4), i.e. payload[i] ^= mask[i % 4].
                    int maskIndex = 0;
                    for (int i = data.readerOffset(), end = data.writerOffset(); i < end; i++) {
                        buf.writeByte((byte) (data.getByte(i) ^ WebSocketUtil.byteAtIndex(mask, maskIndex++ & 3)));
                    }
                    out.add(buf);
                } else {
                    // XOR with an all-zero key leaves the bytes unchanged, so copy/split the payload as-is.
                    addBuffers(buf, data, out);
                }
            } else {
                addBuffers(buf, data, out);
            }
        } catch (Throwable t) {
            if (buf != null) {
                buf.close();
            }
            throw t;
        }
    }

    /** Returns {@code true} for control-frame opcodes (0x8-0xF, i.e. the high opcode bit is set). */
    private static boolean isControlOpcode(byte opcode) {
        return (opcode & 0x08) != 0;
    }

    /** Human-readable name of a control-frame opcode, used in error messages. */
    private static String controlFrameName(byte opcode) {
        switch (opcode) {
        case OPCODE_CLOSE:
            return "CLOSE";
        case OPCODE_PING:
            return "PING";
        case OPCODE_PONG:
            return "PONG";
        default:
            return "control frame (opcode " + opcode + ')';
        }
    }

    /**
     * Maps a frame object to its RFC 6455 opcode.
     *
     * @throws UnsupportedOperationException for frame types this encoder does not know
     */
    private static byte getOpCode(WebSocketFrame msg) {
        if (msg instanceof TextWebSocketFrame) {
            return OPCODE_TEXT;
        }
        if (msg instanceof BinaryWebSocketFrame) {
            return OPCODE_BINARY;
        }
        if (msg instanceof PingWebSocketFrame) {
            return OPCODE_PING;
        }
        if (msg instanceof PongWebSocketFrame) {
            return OPCODE_PONG;
        }
        if (msg instanceof CloseWebSocketFrame) {
            return OPCODE_CLOSE;
        }
        if (msg instanceof ContinuationWebSocketFrame) {
            return OPCODE_CONT;
        }
        throw new UnsupportedOperationException("Cannot encode frame of type: " + msg.getClass().getName());
    }

    /**
     * Emits the header buffer followed by the unmasked payload.
     *
     * <p>If the payload fits in the remaining capacity of {@code buf}, it is copied there, which is cheaper
     * than a gathering write for small payloads. Otherwise the readable part of {@code data} is split off
     * and emitted as a second buffer.
     */
    private static void addBuffers(Buffer buf, Buffer data, List<Object> out) {
        int readableBytes = data.readableBytes();

        if (buf.writableBytes() >= readableBytes) {
            buf.writeBytes(data);
            out.add(buf);
        } else {
            // Here readableBytes > writableBytes >= 0, so there is always a payload to split. Split
            // before publishing buf: if split() fails, buf is not yet in `out` and the caller's error
            // handler can close it safely.
            Buffer payload = data.split();
            out.add(buf);
            out.add(payload);
        }
    }
}

WebSocketFrameDecoder

WebSocketFrameDecoder 负责将符合 WebSocket 协议格式的二进制数据帧解码成 WebSocketFrame,处理帧头解析、负载长度读取、掩码还原及分片状态校验。

java 复制代码
/**
 * Marker type for {@link ChannelHandler}s that turn inbound bytes into WebSocket frames.
 * It declares no members of its own.
 */
public interface WebSocketFrameDecoder
        extends ChannelHandler {
    // intentionally empty
}

WebSocket13FrameDecoder

WebSocket13FrameDecoder 负责将接收到的二进制数据解析成符合 RFC 6455 规范的 WebSocket 帧对象,处理掩码还原、扩展负载长度解析、分片序列校验及控制帧校验。它不会合并分片:每个分片都以独立的帧(后续分片为 ContinuationWebSocketFrame)向后传递。

java 复制代码
/**
 * Decodes RFC 6455 (WebSocket version 13) frames from inbound bytes into {@link WebSocketFrame} objects.
 *
 * <p>Decoding is an incremental state machine ({@link State}). Each call consumes as much of the current
 * frame as is available and returns early when more bytes are needed. On a protocol violation the
 * decoder:
 * <ul>
 *   <li>switches to {@link State#CORRUPT};</li>
 *   <li>discards the remaining input;</li>
 *   <li>if the channel is active and {@code closeOnProtocolViolation} is configured, writes a close
 *       message and closes the channel;</li>
 *   <li>throws a {@link CorruptedWebSocketFrameException}.</li>
 * </ul>
 * After a close frame has been received, all further input is discarded.
 *
 * <p>Fragmented messages are not reassembled here. Each fragment is emitted as its own frame, and only
 * the fragment sequence (an initial TEXT/BINARY frame followed by CONTINUATION frames) is validated.
 */
public class WebSocket13FrameDecoder extends ByteToMessageDecoder implements WebSocketFrameDecoder {

    private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameDecoder.class);

    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    private final WebSocketDecoderConfig config;
    /** Non-final data fragments seen so far in the current fragmented message; 0 when not inside one. */
    private int fragmentedFramesCount;
    // Header fields of the frame currently being decoded.
    private boolean frameFinalFlag;
    private boolean frameMasked;
    private int frameRsv;
    private int frameOpcode;
    private long framePayloadLength;
    /** Masking key of the current frame; only read when {@link #frameMasked} is set. */
    private int mask;
    /** The 7-bit length field of the second header byte; 126 and 127 select the 16/64-bit extended length. */
    private int framePayloadLen1;
    /** Set once a close frame has been decoded; later input is discarded. */
    private boolean receivedClosingHandshake;

    private State state = State.READING_FIRST;

    /**
     * Creates a decoder that does not tolerate a mask mismatch.
     *
     * @param expectMaskedFrames    whether inbound frames are required to be masked
     * @param allowExtensions       whether non-zero RSV bits are accepted
     * @param maxFramePayloadLength largest accepted payload, in bytes
     */
    public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
        this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
    }

    /**
     * Creates a decoder from individual settings.
     *
     * @param expectMaskedFrames    whether inbound frames are required to be masked
     * @param allowExtensions       whether non-zero RSV bits are accepted
     * @param maxFramePayloadLength largest accepted payload, in bytes
     * @param allowMaskMismatch     if {@code true}, frames whose mask bit differs from
     *                              {@code expectMaskedFrames} are accepted instead of rejected
     */
    public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength, boolean allowMaskMismatch) {
        this(WebSocketDecoderConfig.newBuilder()
            .expectMaskedFrames(expectMaskedFrames)
            .allowExtensions(allowExtensions)
            .maxFramePayloadLength(maxFramePayloadLength)
            .allowMaskMismatch(allowMaskMismatch)
            .build());
    }

    /**
     * Creates a decoder from a complete configuration.
     *
     * @param decoderConfig decoder settings; must not be {@code null}
     */
    public WebSocket13FrameDecoder(WebSocketDecoderConfig decoderConfig) {
        config = Objects.requireNonNull(decoderConfig, "decoderConfig");
    }

    /** Narrows a payload length to {@code int}, failing if it does not fit. */
    private static int toFrameLength(long length) {
        if (length > Integer.MAX_VALUE) {
            throw new TooLongFrameException("frame length exceeds " + Integer.MAX_VALUE + ": " + length);
        } else {
            return (int) length;
        }
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
        // Discard all data received after a closing handshake.
        if (receivedClosingHandshake) {
            in.skipReadableBytes(actualReadableBytes());
            return;
        }

        switch (state) {
        case READING_FIRST: {
            if (in.readableBytes() == 0) {
                return;
            }

            framePayloadLength = 0;

            // First header byte: FIN (bit 7), RSV1-3 (bits 6-4), opcode (bits 3-0).
            byte b = in.readByte();
            frameFinalFlag = (b & 0x80) != 0;
            frameRsv = (b & 0x70) >> 4;
            frameOpcode = b & 0x0F;

            if (logger.isTraceEnabled()) {
                logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
            }

            state = State.READING_SECOND;
            // fall through
        }
        case READING_SECOND: {
            if (in.readableBytes() == 0) {
                return;
            }
            // Second header byte: MASK (bit 7), 7-bit payload length.
            byte b = in.readByte();
            frameMasked = (b & 0x80) != 0;
            framePayloadLen1 = b & 0x7F;

            if (frameRsv != 0 && !config.allowExtensions()) {
                protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
                return;
            }

            if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
                protocolViolation(ctx, in, "received a frame that is not masked as expected");
                return;
            }

            if (frameOpcode > 7) { // control frame (high bit of the opcode is set)

                // Control frames MUST NOT be fragmented.
                if (!frameFinalFlag) {
                    protocolViolation(ctx, in, "fragmented control frame");
                    return;
                }

                // Control frames MUST have a payload of 125 octets or less.
                if (framePayloadLen1 > 125) {
                    protocolViolation(ctx, in, "control frame with payload length > 125 octets");
                    return;
                }

                // Reject reserved control opcodes.
                if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
                      || frameOpcode == OPCODE_PONG)) {
                    protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
                    return;
                }

                // A close frame body, if present, starts with a 2-byte status code in network byte
                // order, so a 1-byte body can never be valid.
                if (frameOpcode == OPCODE_CLOSE && framePayloadLen1 == 1) {
                    protocolViolation(ctx, in, "received close control frame with payload len 1");
                    return;
                }
            } else { // data frame
                // Reject reserved data opcodes.
                if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
                      || frameOpcode == OPCODE_BINARY)) {
                    protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
                    return;
                }

                // Fragmentation state check 1/2: CONTINUATION is only valid inside a fragmented message.
                if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
                    protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
                    return;
                }

                // Fragmentation state check 2/2: inside a fragmented message only CONTINUATION may follow.
                if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
                    protocolViolation(ctx, in,
                                      "received non-continuation data frame while inside fragmented message");
                    return;
                }
            }

            state = State.READING_SIZE;
            // fall through
        }
        case READING_SIZE: {
            // Resolve the real payload length, enforcing minimal encoding (RFC 6455, section 5.2).
            if (framePayloadLen1 == 126) {
                if (in.readableBytes() < 2) {
                    return;
                }
                framePayloadLength = in.readUnsignedShort();
                if (framePayloadLength < 126) {
                    protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
                    return;
                }
            } else if (framePayloadLen1 == 127) {
                if (in.readableBytes() < 8) {
                    return;
                }
                framePayloadLength = in.readLong();

                // The most significant bit of the 64-bit length MUST be 0; when read as a signed long,
                // a set MSB shows up as a negative value.
                if (framePayloadLength < 0) {
                    protocolViolation(ctx, in, "invalid data frame length (most significant bit must be 0)");
                    return;
                }

                if (framePayloadLength < 65536) {
                    protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
                    return;
                }
            } else {
                framePayloadLength = framePayloadLen1;
            }

            if (framePayloadLength > config.maxFramePayloadLength()) {
                protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
                                  "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
                return;
            }

            if (logger.isTraceEnabled()) {
                logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
            }

            state = State.MASKING_KEY;
            // fall through
        }
        case MASKING_KEY: {
            if (frameMasked) {
                if (in.readableBytes() < 4) {
                    return;
                }
                mask = in.readInt();
            }
            state = State.PAYLOAD;
            // fall through
        }
        case PAYLOAD: {
            if (in.readableBytes() < framePayloadLength) {
                return;
            }

            Buffer payloadBuffer = null;
            try {
                payloadBuffer = in.readSplit(toFrameLength(framePayloadLength));

                // The whole frame is available; the next call starts a new frame.
                state = State.READING_FIRST;

                if (frameMasked) {
                    unmask(payloadBuffer);
                }

                // Control frames (ping/pong/close) cannot be fragmented, so emit them directly.
                // Ownership of payloadBuffer moves to the frame; null it so the finally block does not close it.
                if (frameOpcode == OPCODE_PING) {
                    WebSocketFrame frame = new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                }
                if (frameOpcode == OPCODE_PONG) {
                    WebSocketFrame frame = new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                }
                if (frameOpcode == OPCODE_CLOSE) {
                    receivedClosingHandshake = true;
                    checkCloseFrameBody(ctx, payloadBuffer);
                    WebSocketFrame frame = new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                }

                // Data frames: track fragmentation state. Control frames may be interleaved with the
                // fragments of a message; they returned above without touching this counter.
                if (frameFinalFlag) {
                    fragmentedFramesCount = 0;
                } else {
                    fragmentedFramesCount++;
                }

                if (frameOpcode == OPCODE_TEXT) {
                    WebSocketFrame frame = new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                } else if (frameOpcode == OPCODE_BINARY) {
                    WebSocketFrame frame = new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                } else if (frameOpcode == OPCODE_CONT) {
                    WebSocketFrame frame = new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                } else {
                    throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
                                                            + frameOpcode);
                }
            } finally {
                if (payloadBuffer != null) {
                    payloadBuffer.close();
                }
            }
        }
        case CORRUPT: {
            if (in.readableBytes() > 0) {
                // Keep consuming input. Per the original note, Netty otherwise complains that no bytes
                // were read and the state did not change.
                in.readByte();
            }
            return;
        }
        default:
            throw new Error("Shouldn't reach here.");
        }
    }

    /**
     * XORs the readable bytes of {@code frame} in place with the current {@link #mask}.
     * Whole 4-byte words are processed first, then the 0-3 byte tail.
     */
    private void unmask(Buffer frame) {
        int base = frame.readerOffset();
        int len = frame.readableBytes();
        int index = 0;

        int intMask = mask;
        if (intMask == 0) {
            // XOR with 0 is the identity, so there is nothing to do.
            return;
        }
        for (; index + 3 < len; index += Integer.BYTES) {
            int off = base + index;
            frame.setInt(off, frame.getInt(off) ^ intMask);
        }
        // index is a multiple of 4 here, so the tail restarts at mask byte 0.
        int maskOffset = 0;
        for (; index < len; index++) {
            int off = base + index;
            frame.setByte(off, (byte) (frame.getByte(off) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3)));
        }
    }

    /** Reports a violation with close status {@link WebSocketCloseStatus#PROTOCOL_ERROR}; always throws. */
    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, String reason) {
        protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
    }

    /** Reports a violation with the given close status; always throws. */
    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, WebSocketCloseStatus status, String reason) {
        protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
    }

    /**
     * Switches to {@link State#CORRUPT} and discards the readable bytes of {@code in}. If the channel is
     * active and {@code closeOnProtocolViolation} is configured, it also writes a close message (an empty
     * buffer if a close frame was already received, otherwise a close frame carrying the exception's
     * status) and closes the channel. Always throws {@code ex}.
     */
    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, CorruptedWebSocketFrameException ex) {
        state = State.CORRUPT;
        int readableBytes = in.readableBytes();
        if (readableBytes > 0) {
            // Memory-leak guard: ByteToMessageDecoder#channelRead releases the cumulation buffer only
            // once it has no readable bytes left.
            in.skipReadableBytes(readableBytes);
        }
        if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
            Object closeMessage;
            if (receivedClosingHandshake) {
                closeMessage = ctx.bufferAllocator().allocate(0);
            } else {
                WebSocketCloseStatus closeStatus = ex.closeStatus();
                String reasonText = ex.getMessage();
                if (reasonText == null) {
                    reasonText = closeStatus.reasonText();
                }
                closeMessage = new CloseWebSocketFrame(ctx.bufferAllocator(), closeStatus, reasonText);
            }
            ctx.writeAndFlush(closeMessage).addListener(ctx, ChannelFutureListeners.CLOSE);
        }
        throw ex;
    }

    /**
     * Validates the body of a received close frame.
     *
     * <p>A valid body is either empty, or a 2-byte status code accepted by
     * {@link WebSocketCloseStatus#isValidStatusCode} optionally followed by reason text that passes UTF-8
     * validation. Once the status code has been read, the buffer's reader offset is restored in a
     * {@code finally} block.
     *
     * @param ctx    channel context, used to report protocol violations
     * @param buffer close-frame payload (already unmasked when called from {@code decode}); may be {@code null}
     * @throws CorruptedWebSocketFrameException if the body is invalid
     */
    protected void checkCloseFrameBody(ChannelHandlerContext ctx, Buffer buffer) {
        if (buffer == null || buffer.readableBytes() <= 0) {
            return;
        }
        if (buffer.readableBytes() == 1) {
            protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
            return; // not reached: protocolViolation always throws
        }

        // Save reader offset.
        int offset = buffer.readerOffset();
        try {
            // The status code is an unsigned 16-bit integer in network byte order; reading it unsigned
            // keeps codes >= 32768 from being reported as negative numbers.
            int statusCode = buffer.readUnsignedShort();
            if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
                protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode);
            }

            // Any remaining bytes are the UTF-8 reason text.
            if (buffer.readableBytes() > 0) {
                try {
                    new Utf8Validator().check(buffer);
                } catch (CorruptedWebSocketFrameException ex) {
                    protocolViolation(ctx, buffer, ex);
                }
            }
        } finally {
            // Restore reader offset.
            buffer.readerOffset(offset);
        }
    }

    /** Decoder states, in the order a frame is read. */
    enum State {
        READING_FIRST,
        READING_SECOND,
        READING_SIZE,
        MASKING_KEY,
        PAYLOAD,
        CORRUPT
    }
}
相关推荐
晓牛开发者1 天前
Netty4 TLS单向安全加密传输案例
netty
hanxiaozhang20183 天前
Netty面试重点-2
面试·netty
9527出列4 天前
Netty源码分析--客户端连接接入流程解析
网络协议·netty
马尚来5 天前
【韩顺平】尚硅谷Netty视频教程
后端·netty
马尚道7 天前
【韩顺平】尚硅谷Netty视频教程
netty
马尚道7 天前
Netty核心技术及源码剖析
源码·netty
moxiaoran57537 天前
java接收小程序发送的protobuf消息
websocket·netty·protobuf
马尚来8 天前
尚硅谷 Netty核心技术及源码剖析 Netty模型 详细版
源码·netty
马尚来8 天前
Netty核心技术及源码剖析
后端·netty
失散1314 天前
分布式专题——35 Netty的使用和常用组件辨析
java·分布式·架构·netty