[netty5: WebSocketFrameEncoder & WebSocketFrameDecoder] - Source Code Analysis

WebSocketFrameMaskGenerator

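WebSocketFrameMaskGenerator is the interface a client uses to generate WebSocket frame masks. nextMask() returns a 4-byte int masking key, which is XORed with the frame payload. Masking is not encryption: the key is sent in clear in the frame header. Its purpose (RFC 6455, section 10.3) is to keep client-controlled bytes unpredictable on the wire, which protects intermediaries such as proxies against cache-poisoning attacks.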

public interface WebSocketFrameMaskGenerator {
    int nextMask();
}
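To make the masking rule payload[i] ^= mask[i % 4] concrete, here is a minimal standalone sketch in plain Java (no Netty). It assumes, as the encoder does when it writes the key with writeInt, that mask byte 0 is the most significant byte of the int. The class MaskingDemo and its methods maskByte and applyMask are hypothetical helpers, not Netty API. Because XOR is its own inverse, applying the same mask twice restores the original bytes, which is exactly how the receiver unmasks.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

// Hypothetical standalone sketch of RFC 6455 payload masking; not Netty code.
public final class MaskingDemo {

    // Byte `index` (0..3) of the mask in network (big-endian) order:
    // index 0 is the most significant byte, matching writeInt(mask).
    static int maskByte(int mask, int index) {
        return (mask >>> (8 * (3 - index))) & 0xFF;
    }

    // Returns a copy of data with payload[i] ^= mask[i % 4] applied.
    static byte[] applyMask(byte[] data, int mask) {
        byte[] out = new byte[data.length];
        for (int i = 0; i < data.length; i++) {
            out[i] = (byte) (data[i] ^ maskByte(mask, i & 3));
        }
        return out;
    }

    public static void main(String[] args) {
        byte[] payload = "Hello".getBytes(StandardCharsets.UTF_8);
        int mask = 0x37FA213D; // the masking key used in the RFC 6455, section 5.7 example
        byte[] masked = applyMask(payload, mask);
        byte[] unmasked = applyMask(masked, mask); // masking again undoes it
        System.out.println(Arrays.toString(masked));
        System.out.println(new String(unmasked, StandardCharsets.UTF_8)); // prints Hello
    }
}
```

With this key, the masked bytes match the masked "Hello" frame given in RFC 6455, section 5.7 (0x7f 0x9f 0x4d 0x51 0x58).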

RandomWebSocketFrameMaskGenerator

RandomWebSocketFrameMaskGenerator implements this interface as a stateless singleton (INSTANCE) and uses ThreadLocalRandom to produce a random 4-byte int as the masking key of each frame.

public final class RandomWebSocketFrameMaskGenerator implements WebSocketFrameMaskGenerator {

    public static final RandomWebSocketFrameMaskGenerator INSTANCE = new RandomWebSocketFrameMaskGenerator();

    private RandomWebSocketFrameMaskGenerator() {}

    @Override
    public int nextMask() {
        return ThreadLocalRandom.current().nextInt();
    }
}
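Because the interface has a single method, plugging in another mask source is simple. RFC 6455 (section 5.3) requires the masking key to be unpredictable and derived from a strong source of entropy, while ThreadLocalRandom is not a cryptographically secure generator; a fixed mask, on the other hand, is handy for reproducible tests. The sketch below shows both. To stay self-contained it declares a local interface with the same shape as Netty's; in a real project you would implement Netty's WebSocketFrameMaskGenerator instead and pass the instance to new WebSocket13FrameEncoder(maskGenerator). SecureRandomMaskGenerator and FixedMaskGenerator are hypothetical names.

```java
import java.security.SecureRandom;

public final class MaskGenerators {

    // Local stand-in with the same shape as Netty's WebSocketFrameMaskGenerator.
    interface WebSocketFrameMaskGenerator {
        int nextMask();
    }

    // Hypothetical generator backed by a cryptographically strong RNG.
    // SecureRandom is thread-safe, so one instance can be shared.
    static final class SecureRandomMaskGenerator implements WebSocketFrameMaskGenerator {
        private final SecureRandom random = new SecureRandom();

        @Override
        public int nextMask() {
            return random.nextInt();
        }
    }

    // Hypothetical generator that always returns the same mask (for tests only).
    static final class FixedMaskGenerator implements WebSocketFrameMaskGenerator {
        private final int mask;

        FixedMaskGenerator(int mask) {
            this.mask = mask;
        }

        @Override
        public int nextMask() {
            return mask;
        }
    }

    public static void main(String[] args) {
        WebSocketFrameMaskGenerator fixed = new FixedMaskGenerator(0x37FA213D);
        WebSocketFrameMaskGenerator secure = new SecureRandomMaskGenerator();
        System.out.printf("fixed=%08x secure=%08x%n", fixed.nextMask(), secure.nextMask());
    }
}
```

A fixed mask makes encoder output byte-for-byte predictable, which defeats the purpose of masking, so it belongs in tests only. Note also that a mask of 0 makes WebSocket13FrameEncoder skip the XOR loop entirely.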

WebSocketFrameEncoder

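WebSocketFrameEncoder is a marker interface: it extends ChannelHandler without adding any methods. Its implementations encode WebSocketFrame objects into binary frames in the WebSocket wire format, which involves building the header, choosing the extended payload length, generating the mask, and XORing the payload with it.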

public interface WebSocketFrameEncoder extends ChannelHandler {}
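For reference, the header an RFC 6455 encoder has to produce can be reproduced in a few lines of plain Java. The sketch below is not Netty code; FrameHeader.encode is a hypothetical helper that returns only the header bytes. b0 carries FIN, RSV1-3 and the opcode; the second byte carries the MASK bit and a 7-bit length, where the values 126 and 127 announce a 16-bit or 64-bit extended length in network byte order; the 4-byte masking key follows when masking is on. The frame in main is the unmasked "Hello" example from RFC 6455, section 5.7.

```java
import java.io.ByteArrayOutputStream;

public final class FrameHeader {

    // Builds the header of one WebSocket frame (RFC 6455, section 5.2).
    // maskKey is ignored unless masked is true.
    static byte[] encode(boolean fin, int rsv, int opcode, long length, boolean masked, int maskKey) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        int b0 = (fin ? 0x80 : 0) | ((rsv & 0x07) << 4) | (opcode & 0x0F);
        out.write(b0);
        int maskBit = masked ? 0x80 : 0;
        if (length <= 125) {
            out.write(maskBit | (int) length);      // length fits in 7 bits
        } else if (length <= 0xFFFF) {
            out.write(maskBit | 126);               // 16-bit extended length follows
            writeBigEndian(out, length, 2);
        } else {
            out.write(maskBit | 127);               // 64-bit extended length follows
            writeBigEndian(out, length, 8);
        }
        if (masked) {
            writeBigEndian(out, maskKey & 0xFFFFFFFFL, 4);
        }
        return out.toByteArray();
    }

    // Writes the low `bytes` bytes of value, most significant byte first.
    private static void writeBigEndian(ByteArrayOutputStream out, long value, int bytes) {
        for (int shift = 8 * (bytes - 1); shift >= 0; shift -= 8) {
            out.write((int) (value >>> shift) & 0xFF);
        }
    }

    public static void main(String[] args) {
        // Final, unmasked text frame carrying 5 bytes: header 0x81 0x05
        for (byte b : encode(true, 0, 0x1, 5, false, 0)) {
            System.out.printf("%02x ", b);
        }
        System.out.println();
    }
}
```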

WebSocket13FrameEncoder

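WebSocket13FrameEncoder encodes WebSocket frames into the binary format defined by RFC 6455 (protocol version 13). It handles extended payload lengths and optional masking, and supports fragmented messages by setting the FIN bit from isFinalFragment() (splitting a message into fragments is up to the caller). When constructed with a mask generator, as a client must be, it masks every payload byte as the spec requires.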

public class WebSocket13FrameEncoder extends MessageToMessageEncoder<WebSocketFrame> implements WebSocketFrameEncoder {

    private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameEncoder.class);
    
    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    private static final int GATHERING_WRITE_THRESHOLD = 1024;
    
    private final WebSocketFrameMaskGenerator maskGenerator;

    public WebSocket13FrameEncoder(boolean maskPayload) {
        this(maskPayload ? RandomWebSocketFrameMaskGenerator.INSTANCE : null);
    }

    public WebSocket13FrameEncoder(WebSocketFrameMaskGenerator maskGenerator) {
        this.maskGenerator = maskGenerator;
    }

    //  0               1               2               3
    //  0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
    // |I|S|S|S|  (4)  |A|     (7)     |     (16/64 bits if needed)    |
    // |N|V|V|V|       |S|             |                               |
    // | |1|2|3|       |K|             |                               |
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |     Masking key (32 bits, only if MASK set to 1)              |
    // +---------------------------------------------------------------+
    // |     Masked/unmasked payload data                              |
    // +---------------------------------------------------------------+
    @Override
    protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List<Object> out) throws Exception {
        final Buffer data = msg.binaryData();

        byte opcode = getOpCode(msg);

        int length = data.readableBytes();

        if (logger.isTraceEnabled()) {
            logger.trace("Encoding WebSocket Frame opCode={} length={}", opcode, length);
        }

        // Build the first byte (b0): FIN | RSV1 | RSV2 | RSV3 | opcode
        int b0 = 0;
        if (msg.isFinalFragment()) {
            b0 |= 1 << 7; // FIN bit
        }
        b0 |= (msg.rsv() & 0x07) << 4; // RSV1, RSV2, RSV3
        b0 |= opcode & 0x7F; // opcode in the low 4 bits (all defined opcodes are <= 0xA)
        
        // RFC 6455 limits control-frame payloads to 125 bytes; this encoder only enforces it for PING
        if (opcode == OPCODE_PING && length > 125) {
            throw new TooLongFrameException("invalid payload for PING (payload length must be <= 125, was " + length + ")");
        }

        // Allocate the output Buffer: header, optional 4-byte mask, and (when merged) the payload
        Buffer buf = null;
        try {
            int maskLength = maskGenerator != null ? 4 : 0;
            if (length <= 125) {
                // 1 byte b0 + 1 byte (MASK bit 0x80 | 7-bit length) + optional 4-byte mask + payload
                int size = 2 + maskLength + length;
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                byte b = (byte) (maskGenerator != null ? 0x80 | length : length);
                buf.writeByte(b);
            } else if (length <= 0xFFFF) {
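                // 1 byte b0 + 1 byte (126 = 0x7E, or 0xFE with the MASK bit) + 2-byte extended length + optional 4-byte mask.
                // The payload is copied into this buffer only if it must be masked or is at most
                // GATHERING_WRITE_THRESHOLD bytes; otherwise it is emitted as a separate Buffer (gathering write).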
                int size = 4 + maskLength;
                if (maskGenerator != null || length <= GATHERING_WRITE_THRESHOLD) {
                    size += length;
                }
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (maskGenerator != null ? 0xFE : 126));
                buf.writeByte((byte) (length >>> 8 & 0xFF));
                buf.writeByte((byte) (length & 0xFF));
            } else {
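                // 1 byte b0 + 1 byte (127 = 0x7F, or 0xFF with the MASK bit) + 8-byte extended length + optional 4-byte mask.
                // Same merge rule as above; payloads this large exceed the threshold, so they are copied only when masking.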
                int size = 10 + maskLength;
                if (maskGenerator != null || length <= GATHERING_WRITE_THRESHOLD) {
                    size += length;
                }
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (maskGenerator != null ? 0xFF : 127));
                buf.writeLong(length);
            }

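            // Write the mask and the payload.
            // The mask is 4 bytes; RFC 6455 requires a client to mask every payload byte: payload[i] ^= mask[i % 4].
            // A zero mask leaves the bytes unchanged, so the payload is then appended without XORing.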
            if (maskGenerator != null) {
                int mask = maskGenerator.nextMask();
                buf.writeInt(mask);

                if (mask != 0) {
                    int i = data.readerOffset();
                    int end = data.writerOffset();

                    int maskOffset = 0;
                    for (; i < end; i++) {
                        byte byteData = data.getByte(i);
                        buf.writeByte((byte) (byteData ^ WebSocketUtil.byteAtIndex(mask, maskOffset++ & 3)));
                    }
                    out.add(buf);
                } else {
                    addBuffers(buf, data, out);
                }
            } else {
                addBuffers(buf, data, out);
            }
        } catch (Throwable t) {
            if (buf != null) {
                buf.close();
            }
            throw t;
        }
    }

    private static byte getOpCode(WebSocketFrame msg) {
        if (msg instanceof TextWebSocketFrame) {
            return OPCODE_TEXT;
        }
        if (msg instanceof BinaryWebSocketFrame) {
            return OPCODE_BINARY;
        }
        if (msg instanceof PingWebSocketFrame) {
            return OPCODE_PING;
        }
        if (msg instanceof PongWebSocketFrame) {
            return OPCODE_PONG;
        }
        if (msg instanceof CloseWebSocketFrame) {
            return OPCODE_CLOSE;
        }
        if (msg instanceof ContinuationWebSocketFrame) {
            return OPCODE_CONT;
        }
        throw new UnsupportedOperationException("Cannot encode frame of type: " + msg.getClass().getName());
    }

    private static void addBuffers(Buffer buf, Buffer data, List<Object> out) {
        int readableBytes = data.readableBytes();

        if (buf.writableBytes() >= readableBytes) {
            // merge buffers as this is cheaper than a gathering write if the payload is small enough
            buf.writeBytes(data);
            out.add(buf);
        } else {
            out.add(buf);
            if (readableBytes > 0) {
                out.add(data.split());
            }
        }
    }
}
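The sizing branches in encode() decide whether the payload is copied into the header buffer or added to out as a second Buffer (a gathering write). The sketch below mirrors that arithmetic so the decision can be checked for a few lengths. EncoderSizing and its methods are hypothetical; the sketch assumes the allocator returns a buffer with exactly the requested capacity (addBuffers compares writableBytes() with the payload size), and the threshold 1024 is copied from GATHERING_WRITE_THRESHOLD.

```java
public final class EncoderSizing {

    static final int GATHERING_WRITE_THRESHOLD = 1024; // same value as the encoder

    // Header size: 2 bytes, plus 2 or 8 bytes of extended length, plus a 4-byte mask if masking.
    static int headerSize(int length, boolean masked) {
        int size = length <= 125 ? 2 : (length <= 0xFFFF ? 4 : 10);
        return size + (masked ? 4 : 0);
    }

    // True if encode() reserves room for the payload, i.e. header and payload end up in one Buffer.
    static boolean payloadMerged(int length, boolean masked) {
        return length <= 125 || masked || length <= GATHERING_WRITE_THRESHOLD;
    }

    // Capacity requested from the allocator, following the same branches as encode().
    static int allocationSize(int length, boolean masked) {
        return headerSize(length, masked) + (payloadMerged(length, masked) ? length : 0);
    }

    public static void main(String[] args) {
        int[] lengths = {100, 1000, 2000, 70_000};
        for (boolean masked : new boolean[] {false, true}) {
            for (int len : lengths) {
                System.out.printf("masked=%-5b length=%-6d allocate=%-6d merged=%b%n",
                        masked, len, allocationSize(len, masked), payloadMerged(len, masked));
            }
        }
    }
}
```

For an unmasked 2000-byte payload only the 4-byte header is allocated, and the payload Buffer is split off and added to out without copying; masked payloads are always copied, because every byte has to be XORed anyway.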

WebSocketFrameDecoder

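WebSocketFrameDecoder is likewise a marker interface extending ChannelHandler. Its implementations decode binary data in the WebSocket wire format into WebSocketFrame objects: parsing the header, reading the payload length, unmasking the payload, and tracking fragmentation state.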

public interface WebSocketFrameDecoder extends ChannelHandler {}
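On the receiving side, the READING_FIRST, READING_SECOND, READING_SIZE and MASKING_KEY states of the decoder below boil down to the parsing in this sketch. It is simplified and non-incremental: it assumes the whole header is already available in a java.nio.ByteBuffer, whereas the real decoder returns and waits for more bytes. HeaderParser, Header and parse are hypothetical names, only a subset of the checks is shown, and errors are reported with IllegalArgumentException rather than Netty's CorruptedWebSocketFrameException.

```java
import java.nio.ByteBuffer;

public final class HeaderParser {

    static final class Header {
        final boolean fin;
        final int rsv;
        final int opcode;
        final boolean masked;
        final long payloadLength;
        final int mask;

        Header(boolean fin, int rsv, int opcode, boolean masked, long payloadLength, int mask) {
            this.fin = fin;
            this.rsv = rsv;
            this.opcode = opcode;
            this.masked = masked;
            this.payloadLength = payloadLength;
            this.mask = mask;
        }
    }

    // Parses one frame header; ByteBuffer reads are big-endian by default.
    static Header parse(ByteBuffer in) {
        int b0 = in.get() & 0xFF;            // FIN, RSV, OPCODE
        boolean fin = (b0 & 0x80) != 0;
        int rsv = (b0 & 0x70) >> 4;
        int opcode = b0 & 0x0F;

        int b1 = in.get() & 0xFF;            // MASK, PAYLOAD LEN 1
        boolean masked = (b1 & 0x80) != 0;
        int len7 = b1 & 0x7F;

        if (opcode > 7) {                    // control frame
            if (!fin) throw new IllegalArgumentException("fragmented control frame");
            if (len7 > 125) throw new IllegalArgumentException("control frame payload > 125");
        }

        long length;
        if (len7 == 126) {
            length = in.getShort() & 0xFFFF;
            if (length < 126) throw new IllegalArgumentException("non-minimal 16-bit length");
        } else if (len7 == 127) {
            length = in.getLong();
            // A negative value (MSB set) also fails this check, as in the Netty decoder.
            if (length < 65536) throw new IllegalArgumentException("non-minimal 64-bit length");
        } else {
            length = len7;
        }

        int mask = masked ? in.getInt() : 0;
        return new Header(fin, rsv, opcode, masked, length, mask);
    }

    public static void main(String[] args) {
        // Header of the masked "Hello" frame from RFC 6455, section 5.7.
        byte[] bytes = {(byte) 0x81, (byte) 0x85, 0x37, (byte) 0xFA, 0x21, 0x3D};
        Header h = parse(ByteBuffer.wrap(bytes));
        System.out.printf("fin=%b opcode=%d masked=%b length=%d mask=%08x%n",
                h.fin, h.opcode, h.masked, h.payloadLength, h.mask);
    }
}
```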

WebSocket13FrameDecoder

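WebSocket13FrameDecoder parses incoming bytes into WebSocket frame objects as defined by RFC 6455. It unmasks payloads, reads extended payload lengths, validates control frames, and checks that fragmented messages follow the rules (a continuation frame may only appear inside a fragmented message, and no new TEXT or BINARY frame may start before the current message ends). It does not merge fragments: each fragment is forwarded as its own TextWebSocketFrame, BinaryWebSocketFrame, or ContinuationWebSocketFrame, and reassembly is left to a later handler such as WebSocketFrameAggregator.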

public class WebSocket13FrameDecoder extends ByteToMessageDecoder implements WebSocketFrameDecoder {

    private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameDecoder.class);
    
    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;
    
    private final WebSocketDecoderConfig config;
    private int fragmentedFramesCount;
    private boolean frameFinalFlag;
    private boolean frameMasked;
    private int frameRsv;
    private int frameOpcode;
    private long framePayloadLength;
    private int mask;
    private int framePayloadLen1;
    private boolean receivedClosingHandshake;
    
    private State state = State.READING_FIRST;

    public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
        this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
    }

    public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength, boolean allowMaskMismatch) {
        this(WebSocketDecoderConfig.newBuilder()
            .expectMaskedFrames(expectMaskedFrames)
            .allowExtensions(allowExtensions)
            .maxFramePayloadLength(maxFramePayloadLength)
            .allowMaskMismatch(allowMaskMismatch)
            .build());
    }

    public WebSocket13FrameDecoder(WebSocketDecoderConfig decoderConfig) {
        config = Objects.requireNonNull(decoderConfig, "decoderConfig");
    }

    private static int toFrameLength(long length) {
        if (length > Integer.MAX_VALUE) {
            throw new TooLongFrameException("frame length exceeds " + Integer.MAX_VALUE + ": " + length);
        } else {
            return (int) length;
        }
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
        // Discard all data received if closing handshake was received before.
        if (receivedClosingHandshake) {
            in.skipReadableBytes(actualReadableBytes());
            return;
        }

        switch (state) {
        case READING_FIRST: {
            if (in.readableBytes() == 0) {
                return;
            }

            framePayloadLength = 0;

            // FIN, RSV, OPCODE
            byte b = in.readByte();
            frameFinalFlag = (b & 0x80) != 0;
            frameRsv = (b & 0x70) >> 4;
            frameOpcode = b & 0x0F;

            if (logger.isTraceEnabled()) {
                logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
            }

            state = State.READING_SECOND;
        }
        case READING_SECOND: {
            if (in.readableBytes() == 0) {
                return;
            }
            // MASK, PAYLOAD LEN 1
            byte b = in.readByte();
            frameMasked = (b & 0x80) != 0;
            framePayloadLen1 = b & 0x7F;

            if (frameRsv != 0 && !config.allowExtensions()) {
                protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
                return;
            }

            if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
                protocolViolation(ctx, in, "received a frame that is not masked as expected");
                return;
            }

            if (frameOpcode > 7) { // control frame (have MSB in opcode set)

                // control frames MUST NOT be fragmented
                if (!frameFinalFlag) {
                    protocolViolation(ctx, in, "fragmented control frame");
                    return;
                }

                // control frames MUST have payload 125 octets or less
                if (framePayloadLen1 > 125) {
                    protocolViolation(ctx, in, "control frame with payload length > 125 octets");
                    return;
                }

                // check for reserved control frame opcodes
                if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
                      || frameOpcode == OPCODE_PONG)) {
                    protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
                    return;
                }

                // close frame: if there is a body, the first two bytes of the
                // body MUST be a 2-byte unsigned integer (in network byte
                // order) representing a status code
                if (frameOpcode == 8 && framePayloadLen1 == 1) {
                    protocolViolation(ctx, in, "received close control frame with payload len 1");
                    return;
                }
            } else { // data frame
                // check for reserved data frame opcodes
                if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
                      || frameOpcode == OPCODE_BINARY)) {
                    protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
                    return;
                }

                // check opcode vs message fragmentation state 1/2
                if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
                    protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
                    return;
                }

                // check opcode vs message fragmentation state 2/2
                if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
                    protocolViolation(ctx, in,
                                      "received non-continuation data frame while inside fragmented message");
                    return;
                }
            }

            state = State.READING_SIZE;
        }
        case READING_SIZE: {
            // Read frame payload length
            if (framePayloadLen1 == 126) {
                if (in.readableBytes() < 2) {
                    return;
                }
                framePayloadLength = in.readUnsignedShort();
                if (framePayloadLength < 126) {
                    protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
                    return;
                }
            } else if (framePayloadLen1 == 127) {
                if (in.readableBytes() < 8) {
                    return;
                }
                framePayloadLength = in.readLong();
                // RFC 6455 requires the most significant bit to be 0. A negative value (MSB set)
                // is rejected by the minimal-length check below, because it is also < 65536.

                if (framePayloadLength < 65536) {
                    protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
                    return;
                }
            } else {
                framePayloadLength = framePayloadLen1;
            }

            if (framePayloadLength > config.maxFramePayloadLength()) {
                protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
                                  "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
                return;
            }

            if (logger.isTraceEnabled()) {
                logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
            }

            state = State.MASKING_KEY;
        }
        case MASKING_KEY: {
            if (frameMasked) {
                if (in.readableBytes() < 4) {
                    return;
                }
                mask = in.readInt();
            }
            state = State.PAYLOAD;
        }
        case PAYLOAD: {
            if (in.readableBytes() < framePayloadLength) {
                return;
            }

            Buffer payloadBuffer = null;
            try {
                payloadBuffer = in.readSplit(toFrameLength(framePayloadLength));

                // Now we have all the data, the next checkpoint must be the next
                // frame
                state = State.READING_FIRST;

                // Unmask data if needed
                if (frameMasked) {
                    unmask(payloadBuffer);
                }

                // Processing ping/pong/close frames because they cannot be
                // fragmented
                if (frameOpcode == OPCODE_PING) {
                    WebSocketFrame frame = new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                }
                if (frameOpcode == OPCODE_PONG) {
                    WebSocketFrame frame = new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                }
                if (frameOpcode == OPCODE_CLOSE) {
                    receivedClosingHandshake = true;
                    checkCloseFrameBody(ctx, payloadBuffer);
                    WebSocketFrame frame = new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                }

                // Processing for possible fragmented messages for text and binary
                // frames
                if (frameFinalFlag) {
                    // Final frame of the sequence. Apparently ping frames are
                    // allowed in the middle of a fragmented message
                    fragmentedFramesCount = 0;
                } else {
                    // Increment counter
                    fragmentedFramesCount++;
                }

                // Return the frame
                if (frameOpcode == OPCODE_TEXT) {
                    WebSocketFrame frame = new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                } else if (frameOpcode == OPCODE_BINARY) {
                    WebSocketFrame frame = new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                } else if (frameOpcode == OPCODE_CONT) {
                    WebSocketFrame frame = new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                    payloadBuffer = null;
                    ctx.fireChannelRead(frame);
                    return;
                } else {
                    throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
                                                            + frameOpcode);
                }
            } finally {
                if (payloadBuffer != null) {
                    payloadBuffer.close();
                }
            }
        }
        case CORRUPT: {
            if (in.readableBytes() > 0) {
                // If we don't keep reading Netty will throw an exception saying
                // we can't return null if no bytes read and state not changed.
                in.readByte();
            }
            return;
        }
        default:
            throw new Error("Shouldn't reach here.");
        }
    }

    private void unmask(Buffer frame) {
        int base = frame.readerOffset();
        int len = frame.readableBytes();
        int index = 0;

        int intMask = mask;
        if (intMask == 0) {
            // If the mask is 0 we can just return directly as the XOR operations will just produce the same value.
            return;
        }
        for (; index + 3 < len; index += Integer.BYTES) {
            int off = base + index;
            frame.setInt(off, frame.getInt(off) ^ intMask);
        }
        int maskOffset = 0;
        for (; index < len; index++) {
            int off = base + index;
            frame.setByte(off, (byte) (frame.getByte(off) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3)));
        }
    }

    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, String reason) {
        protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
    }

    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, WebSocketCloseStatus status, String reason) {
        protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
    }

    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, CorruptedWebSocketFrameException ex) {
        state = State.CORRUPT;
        int readableBytes = in.readableBytes();
        if (readableBytes > 0) {
            // Fix for memory leak, caused by ByteToMessageDecoder#channelRead:
            // buffer 'cumulation' is released ONLY when no more readable bytes available.
            in.skipReadableBytes(readableBytes);
        }
        if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
            Object closeMessage;
            if (receivedClosingHandshake) {
                closeMessage = ctx.bufferAllocator().allocate(0);
            } else {
                WebSocketCloseStatus closeStatus = ex.closeStatus();
                String reasonText = ex.getMessage();
                if (reasonText == null) {
                    reasonText = closeStatus.reasonText();
                }
                closeMessage = new CloseWebSocketFrame(ctx.bufferAllocator(), closeStatus, reasonText);
            }
            ctx.writeAndFlush(closeMessage).addListener(ctx, ChannelFutureListeners.CLOSE);
        }
        throw ex;
    }

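    /** Validates a close frame body: it must be empty, or a valid 2-byte status code optionally followed by UTF-8 reason text. */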
    protected void checkCloseFrameBody(ChannelHandlerContext ctx, Buffer buffer) {
        if (buffer == null || buffer.readableBytes() <= 0) {
            return;
        }
        if (buffer.readableBytes() == 1) {
            protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
        }

        // Save reader offset.
        int offset = buffer.readerOffset();
        try {
            // Must have 2 byte integer within the valid range.
            int statusCode = buffer.readShort();
            if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
                protocolViolation(ctx, buffer, "Invalid close frame status code: " + statusCode);
            }

            // May have UTF-8 message.
            if (buffer.readableBytes() > 0) {
                try {
                    new Utf8Validator().check(buffer);
                } catch (CorruptedWebSocketFrameException ex) {
                    protocolViolation(ctx, buffer, ex);
                }
            }
        } finally {
            // Restore reader offset.
            buffer.readerOffset(offset);
        }
    }

    enum State {
        READING_FIRST,
        READING_SECOND,
        READING_SIZE,
        MASKING_KEY,
        PAYLOAD,
        CORRUPT
    }
}
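unmask() XORs four bytes at a time with the whole int mask and falls back to per-byte XOR only for the last 0 to 3 bytes. This works because getInt reads big-endian, so the first byte of each word lines up with the most significant mask byte, and every word starts at an index that is a multiple of 4. The tail loop restarts maskOffset at 0, which is correct only because the tail also begins at a multiple of 4. The standalone sketch below (hypothetical UnmaskDemo, using java.nio.ByteBuffer in place of Netty's Buffer) checks that the word-wise version matches the byte-wise definition.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public final class UnmaskDemo {

    // Reference definition: payload[i] ^= mask[i % 4], mask byte 0 = most significant.
    static byte[] unmaskBytewise(byte[] data, int mask) {
        byte[] out = data.clone();
        for (int i = 0; i < out.length; i++) {
            out[i] ^= (byte) (mask >>> (8 * (3 - (i & 3))));
        }
        return out;
    }

    // Word-at-a-time variant in the style of WebSocket13FrameDecoder.unmask().
    static byte[] unmaskWordwise(byte[] data, int mask) {
        ByteBuffer buf = ByteBuffer.wrap(data.clone()); // big-endian by default
        int len = buf.remaining();
        if (mask == 0) {
            return buf.array(); // XOR with 0 changes nothing
        }
        int index = 0;
        for (; index + 3 < len; index += Integer.BYTES) {
            buf.putInt(index, buf.getInt(index) ^ mask);
        }
        // Tail: index is a multiple of 4 here, so mask byte 0 applies first.
        for (int maskOffset = 0; index < len; index++, maskOffset++) {
            byte m = (byte) (mask >>> (8 * (3 - (maskOffset & 3))));
            buf.put(index, (byte) (buf.get(index) ^ m));
        }
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] masked = {(byte) 0x7F, (byte) 0x9F, 0x4D, 0x51, 0x58}; // RFC 6455 masked "Hello"
        int mask = 0x37FA213D;
        byte[] a = unmaskBytewise(masked, mask);
        byte[] b = unmaskWordwise(masked, mask);
        System.out.println(Arrays.equals(a, b) + " " + new String(b, StandardCharsets.US_ASCII)); // prints: true Hello
    }
}
```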