WebSocketFrameMaskGenerator
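WebSocketFrameMaskGenerator is the interface a client-side encoder uses to obtain WebSocket frame masks. Each call to nextMask() returns a 32-bit (4-byte) int that serves as the masking key for one frame's payload. Masking is not encryption: the key travels in clear text in the frame header, so it provides no confidentiality.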
```java
public interface WebSocketFrameMaskGenerator {
    /** Returns the 32-bit masking key to use for the next frame. */
    int nextMask();
}
```
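This standalone sketch makes the mask semantics concrete. It is not part of Netty, and the class and method names are hypothetical. Following RFC 6455 §5.3, it applies a 32-bit mask to a byte array: payload byte `i` is XORed with byte `i % 4` of the masking key, and the key bytes are taken in network (big-endian) order. XOR is its own inverse, so the same call both masks and unmasks. The test values are the masked "Hello" example from RFC 6455 §5.7.

```java
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/** Hypothetical helper illustrating RFC 6455 payload masking; not a Netty API. */
final class MaskingSketch {

    /** Returns byte {@code index} (0..3) of the mask, most significant byte first. */
    static byte maskByte(int mask, int index) {
        return (byte) (mask >>> (8 * (3 - index)));
    }

    /** XORs payload[i] with mask byte (i % 4), in place; applying it twice restores the input. */
    static void applyMask(byte[] payload, int mask) {
        for (int i = 0; i < payload.length; i++) {
            payload[i] ^= maskByte(mask, i & 3);
        }
    }

    public static void main(String[] args) {
        // Masking key and payload from the "Hello" example in RFC 6455, section 5.7.
        int mask = 0x37FA213D;
        byte[] payload = "Hello".getBytes(StandardCharsets.US_ASCII);
        applyMask(payload, mask);
        System.out.println(Arrays.toString(payload)); // prints [127, -97, 77, 81, 88], i.e. 7f 9f 4d 51 58
        applyMask(payload, mask);
        System.out.println(new String(payload, StandardCharsets.US_ASCII)); // prints Hello
    }
}
```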
RandomWebSocketFrameMaskGenerator
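RandomWebSocketFrameMaskGenerator implements the mask generator interface as a stateless singleton (INSTANCE). It uses ThreadLocalRandom to produce a random 32-bit int as each frame's masking key. ThreadLocalRandom is fast, but its Javadoc states that it is not cryptographically secure. RFC 6455 §5.3, by contrast, asks for an unpredictable key derived from a strong source of entropy.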
```java
public final class RandomWebSocketFrameMaskGenerator implements WebSocketFrameMaskGenerator {
    // Stateless, so one shared instance is enough.
    public static final RandomWebSocketFrameMaskGenerator INSTANCE = new RandomWebSocketFrameMaskGenerator();

    private RandomWebSocketFrameMaskGenerator() {
    }

    @Override
    public int nextMask() {
        // Uniformly distributed 32-bit value from the calling thread's generator.
        return ThreadLocalRandom.current().nextInt();
    }
}
```
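Some deployments may value an unpredictable key more than raw speed. For those, a generator backed by `java.security.SecureRandom` could be passed to the `WebSocket13FrameEncoder(WebSocketFrameMaskGenerator)` constructor. This is a hypothetical sketch. To stay self-contained, it declares its own `MaskGenerator` interface as a stand-in for `WebSocketFrameMaskGenerator`; real code would implement the Netty interface directly.

```java
import java.security.SecureRandom;

/** Stand-in for WebSocketFrameMaskGenerator so the sketch compiles on its own (assumption). */
interface MaskGenerator {
    int nextMask();
}

/** Hypothetical generator that draws masking keys from SecureRandom. */
final class SecureRandomMaskGenerator implements MaskGenerator {
    private final SecureRandom random = new SecureRandom();

    @Override
    public int nextMask() {
        // SecureRandom is documented as safe for concurrent use, so one instance can be shared.
        return random.nextInt();
    }
}
```

The trade-off is throughput. SecureRandom is generally slower than ThreadLocalRandom and can contend when many threads share one instance.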
WebSocketFrameEncoder
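WebSocketFrameEncoder is the marker interface for handlers that encode a WebSocketFrame into binary data in the WebSocket wire format. This involves building the frame header, choosing the extended payload length, generating the mask, and XOR-ing the payload with it.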
```java
public interface WebSocketFrameEncoder extends ChannelHandler {}
```
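Header construction comes down to a few bit operations. The hypothetical sketch below uses plain Java, not Netty API, and follows the bit layout of the RFC 6455 frame header:

- The first byte packs FIN, the three RSV bits and the 4-bit opcode.
- The second byte packs the MASK bit and the 7-bit payload length field.

```java
/** Hypothetical helpers for the first two bytes of a WebSocket frame header. */
final class HeaderBytes {

    /** FIN (bit 7) | RSV1..RSV3 (bits 6..4) | opcode (bits 3..0). */
    static int firstByte(boolean fin, int rsv, int opcode) {
        int b0 = fin ? 0x80 : 0;
        b0 |= (rsv & 0x07) << 4;
        b0 |= opcode & 0x0F;
        return b0;
    }

    /** MASK (bit 7) | 7-bit length field (0..125, or 126/127 to announce an extended length). */
    static int secondByte(boolean masked, int lengthField) {
        if (lengthField < 0 || lengthField > 127) {
            throw new IllegalArgumentException("length field must fit in 7 bits: " + lengthField);
        }
        return (masked ? 0x80 : 0) | lengthField;
    }

    public static void main(String[] args) {
        // Final text frame (opcode 0x1), no extensions, masked, 5-byte payload.
        System.out.printf("%02x %02x%n", firstByte(true, 0, 0x1), secondByte(true, 5)); // prints 81 85
    }
}
```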
WebSocket13FrameEncoder
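WebSocket13FrameEncoder encodes WebSocket frames into the RFC 6455 binary format. It handles extended payload lengths and optional masking, and it supports fragmented messages by writing each frame's FIN bit and opcode, including the continuation opcode. When it is configured with a mask generator, as a client must be, it masks every payload as the specification requires.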
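Fragmentation works at the level of the frames handed to the encoder:

- The first fragment carries the real opcode (text or binary) with FIN cleared.
- The following fragments use the continuation opcode `0x0`.
- Only the last fragment has FIN set.

The encoder only writes each frame's FIN bit and opcode; splitting a message is up to the caller. The hypothetical sketch below (the `FragmentSketch` class, its `Fragment` type and `fragment` method are all illustrative) shows that sequencing on a plain byte array.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch of splitting one message into RFC 6455 fragments. */
final class FragmentSketch {
    static final int OPCODE_CONT = 0x0;

    static final class Fragment {
        final boolean fin;
        final int opcode;
        final byte[] payload;

        Fragment(boolean fin, int opcode, byte[] payload) {
            this.fin = fin;
            this.opcode = opcode;
            this.payload = payload;
        }
    }

    /** Splits {@code message} into fragments carrying at most {@code maxChunk} bytes each. */
    static List<Fragment> fragment(int opcode, byte[] message, int maxChunk) {
        if (maxChunk <= 0) {
            throw new IllegalArgumentException("maxChunk must be positive");
        }
        List<Fragment> frames = new ArrayList<>();
        int offset = 0;
        do { // do-while so an empty message still yields one (final) frame
            int end = offset + Math.min(maxChunk, message.length - offset);
            boolean fin = end == message.length;
            // Only the first fragment carries the message opcode; the rest are continuations.
            int frameOpcode = frames.isEmpty() ? opcode : OPCODE_CONT;
            frames.add(new Fragment(fin, frameOpcode, Arrays.copyOfRange(message, offset, end)));
            offset = end;
        } while (offset < message.length);
        return frames;
    }
}
```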
```java
public class WebSocket13FrameEncoder extends MessageToMessageEncoder<WebSocketFrame> implements WebSocketFrameEncoder {
    private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameEncoder.class);

    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    // Unmasked payloads larger than this are emitted as a separate buffer (gathering write)
    // instead of being copied behind the header.
    private static final int GATHERING_WRITE_THRESHOLD = 1024;

    private final WebSocketFrameMaskGenerator maskGenerator;

    public WebSocket13FrameEncoder(boolean maskPayload) {
        this(maskPayload ? RandomWebSocketFrameMaskGenerator.INSTANCE : null);
    }

    public WebSocket13FrameEncoder(WebSocketFrameMaskGenerator maskGenerator) {
        this.maskGenerator = maskGenerator;
    }

    //  0               1               2               3
    //  0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
    // |I|S|S|S|  (4)  |A|     (7)     |    (16/64 bits if needed)     |
    // +-+-+-+-+-------+-+-------------+-------------------------------+
    // |         Masking key (32 bits, only if MASK set to 1)          |
    // +---------------------------------------------------------------+
    // |                 Masked/unmasked payload data                  |
    // +---------------------------------------------------------------+
    @Override
    protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List<Object> out) throws Exception {
        final Buffer data = msg.binaryData();
        byte opcode = getOpCode(msg);
        int length = data.readableBytes();

        if (logger.isTraceEnabled()) {
            logger.trace("Encoding WebSocket Frame opCode={} length={}", opcode, length);
        }

        // First header byte (b0): FIN | RSV1 | RSV2 | RSV3 | opcode
        int b0 = 0;
        if (msg.isFinalFragment()) {
            b0 |= 1 << 7; // FIN bit
        }
        b0 |= (msg.rsv() & 0x07) << 4; // RSV1, RSV2, RSV3
        b0 |= opcode & 0x0F;           // opcode occupies the low 4 bits

        // RFC 6455: a PING payload must not exceed 125 bytes
        if (opcode == OPCODE_PING && length > 125) {
            throw new TooLongFrameException("invalid payload for PING (payload length must be <= 125, was "
                    + length + ')');
        }

        // Allocate the output buffer: header + masking key (if any) + payload (when merged)
        Buffer buf = null;
        try {
            int maskLength = maskGenerator != null ? 4 : 0;
            if (length <= 125) {
                // 1 byte b0 + 1 byte (MASK bit 0x80 | 7-bit length 0xxxxxxx) + 4-byte mask + payload
                int size = 2 + maskLength + length;
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                byte b = (byte) (maskGenerator != null ? 0x80 | length : length);
                buf.writeByte(b);
            } else if (length <= 0xFFFF) {
                // 1 byte b0 + 1 byte (0xFE = 1111 1110 if masked, else 126) + 16-bit extended length
                int size = 4 + maskLength;
                if (maskGenerator != null || length <= GATHERING_WRITE_THRESHOLD) {
                    size += length;
                }
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (maskGenerator != null ? 0xFE : 126));
                buf.writeByte((byte) (length >>> 8 & 0xFF));
                buf.writeByte((byte) (length & 0xFF));
            } else {
                // 1 byte b0 + 1 byte (0xFF = 1111 1111 if masked, else 127) + 64-bit extended length
                int size = 10 + maskLength;
                if (maskGenerator != null || length <= GATHERING_WRITE_THRESHOLD) {
                    size += length;
                }
                buf = ctx.bufferAllocator().allocate(size);
                buf.writeByte((byte) b0);
                buf.writeByte((byte) (maskGenerator != null ? 0xFF : 127));
                buf.writeLong(length);
            }

            // Write the payload.
            // The masking key is 4 bytes; RFC 6455 requires clients to mask every payload byte:
            // payload[i] ^= mask[i % 4].
            if (maskGenerator != null) {
                int mask = maskGenerator.nextMask();
                buf.writeInt(mask);
                if (mask != 0) {
                    int i = data.readerOffset();
                    int end = data.writerOffset();
                    int maskOffset = 0;
                    for (; i < end; i++) {
                        byte byteData = data.getByte(i);
                        buf.writeByte((byte) (byteData ^ WebSocketUtil.byteAtIndex(mask, maskOffset++ & 3)));
                    }
                    out.add(buf);
                } else {
                    // A zero mask leaves the payload unchanged, so it can be copied as-is.
                    addBuffers(buf, data, out);
                }
            } else {
                addBuffers(buf, data, out);
            }
        } catch (Throwable t) {
            if (buf != null) {
                buf.close();
            }
            throw t;
        }
    }

    private static byte getOpCode(WebSocketFrame msg) {
        if (msg instanceof TextWebSocketFrame) {
            return OPCODE_TEXT;
        }
        if (msg instanceof BinaryWebSocketFrame) {
            return OPCODE_BINARY;
        }
        if (msg instanceof PingWebSocketFrame) {
            return OPCODE_PING;
        }
        if (msg instanceof PongWebSocketFrame) {
            return OPCODE_PONG;
        }
        if (msg instanceof CloseWebSocketFrame) {
            return OPCODE_CLOSE;
        }
        if (msg instanceof ContinuationWebSocketFrame) {
            return OPCODE_CONT;
        }
        throw new UnsupportedOperationException("Cannot encode frame of type: " + msg.getClass().getName());
    }

    private static void addBuffers(Buffer buf, Buffer data, List<Object> out) {
        int readableBytes = data.readableBytes();
        if (buf.writableBytes() >= readableBytes) {
            // Merge the buffers: this is cheaper than a gathering write when the payload is small enough.
            buf.writeBytes(data);
            out.add(buf);
        } else {
            out.add(buf);
            if (readableBytes > 0) {
                out.add(data.split());
            }
        }
    }
}
```
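The three length branches of `encode` can be checked in isolation. The sketch below is a hypothetical byte-array version of the same header logic. It is not Netty code, and it uses `java.nio.ByteBuffer` (big-endian by default) in place of Netty's `Buffer`. It does the following:

- picks the 7-bit, 16-bit or 64-bit length form;
- sets the MASK bit and appends the masking key when a mask is supplied;
- XORs the payload with the key.

With the masking key from RFC 6455 §5.7, it reproduces the RFC's masked "Hello" frame.

```java
import java.nio.ByteBuffer;

/** Hypothetical byte-array encoder mirroring the header logic of WebSocket13FrameEncoder. */
final class FrameEncodeSketch {

    /** Encodes one frame; {@code mask == null} means an unmasked (server-to-client) frame. */
    static byte[] encode(boolean fin, int opcode, byte[] payload, Integer mask) {
        int length = payload.length;
        int lengthBytes = length <= 125 ? 0 : length <= 0xFFFF ? 2 : 8;
        int maskBytes = mask != null ? 4 : 0;
        ByteBuffer out = ByteBuffer.allocate(2 + lengthBytes + maskBytes + length); // big-endian

        out.put((byte) ((fin ? 0x80 : 0) | (opcode & 0x0F)));
        int maskBit = mask != null ? 0x80 : 0;
        if (lengthBytes == 0) {
            out.put((byte) (maskBit | length));  // length fits in the 7-bit field
        } else if (lengthBytes == 2) {
            out.put((byte) (maskBit | 126));     // 16-bit extended length follows
            out.putShort((short) length);
        } else {
            out.put((byte) (maskBit | 127));     // 64-bit extended length follows
            out.putLong(length);
        }

        if (mask != null) {
            int key = mask;
            out.putInt(key);
            for (int i = 0; i < length; i++) {
                // XOR with mask byte i % 4, most significant byte first.
                out.put((byte) (payload[i] ^ (key >>> (8 * (3 - (i & 3))))));
            }
        } else {
            out.put(payload);
        }
        return out.array();
    }
}
```

Unlike the Netty encoder, this sketch always copies the payload. The real encoder avoids copying unmasked payloads above `GATHERING_WRITE_THRESHOLD` by emitting the header and payload as two buffers.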
WebSocketFrameDecoder
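WebSocketFrameDecoder is the marker interface for handlers that decode binary data in the WebSocket wire format into WebSocketFrame objects. This involves parsing the frame header, reading the payload length, removing the mask and tracking fragmented messages.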
```java
public interface WebSocketFrameDecoder extends ChannelHandler {}
```
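On the decoding side, header parsing is the mirror image. This hypothetical sketch (plain Java, not Netty API) extracts the fields the decoder stores from the first two bytes: `frameFinalFlag`, `frameRsv`, `frameOpcode`, `frameMasked` and `framePayloadLen1`. It then applies RFC 6455's control-frame rules:

- control frames must not be fragmented;
- the 7-bit length field must be at most 125;
- the opcode must be 0x8, 0x9 or 0xA;
- a close frame must not have a 1-byte body, because a non-empty close body starts with a 2-byte status code.

```java
/** Hypothetical parse of the first two header bytes, with RFC 6455 control-frame checks. */
final class HeaderParseSketch {
    final boolean fin;
    final int rsv;
    final int opcode;
    final boolean masked;
    final int payloadLen1;

    HeaderParseSketch(byte b0, byte b1) {
        fin = (b0 & 0x80) != 0;
        rsv = (b0 & 0x70) >> 4;
        opcode = b0 & 0x0F;
        masked = (b1 & 0x80) != 0;
        payloadLen1 = b1 & 0x7F;
    }

    /** Returns null if the header passes the control-frame rules, otherwise a reason string. */
    String controlFrameViolation() {
        if (opcode <= 7) {
            return null; // data frame: these rules do not apply
        }
        if (!fin) {
            return "fragmented control frame";
        }
        if (payloadLen1 > 125) {
            return "control frame with payload length > 125 octets";
        }
        if (opcode != 0x8 && opcode != 0x9 && opcode != 0xA) {
            return "control frame using reserved opcode " + opcode;
        }
        if (opcode == 0x8 && payloadLen1 == 1) {
            return "close frame with payload length 1";
        }
        return null;
    }
}
```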
WebSocket13FrameDecoder
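WebSocket13FrameDecoder parses incoming bytes into WebSocket frame objects according to RFC 6455. It reads extended payload lengths, unmasks payloads and validates control frames. It also checks the ordering of fragmented (continuation) frames. Each fragment is emitted as its own frame; the decoder does not merge fragments into one message. Decoding is a resumable state machine (State), so a single frame may arrive split across several reads.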
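The fragmentation checks in `READING_SECOND` depend only on the opcode and on a running count of unfinished fragments (`fragmentedFramesCount`). The hypothetical tracker below applies the same two rules to data frames:

- a continuation frame is legal only inside a fragmented message;
- a new text or binary frame is illegal while a fragmented message is still open.

Control frames are validated separately and never touch the count, which is why a ping may arrive between fragments.

```java
/** Hypothetical tracker mirroring the decoder's fragmentedFramesCount checks for data frames. */
final class FragmentTracker {
    private int fragmentedFramesCount;

    /**
     * Validates one data frame (opcode 0x0, 0x1 or 0x2) and updates the state.
     * Throws IllegalStateException on a sequencing violation.
     */
    void onDataFrame(int opcode, boolean fin) {
        if (fragmentedFramesCount == 0 && opcode == 0x0) {
            throw new IllegalStateException("continuation frame outside fragmented message");
        }
        if (fragmentedFramesCount != 0 && opcode != 0x0) {
            throw new IllegalStateException("non-continuation data frame inside fragmented message");
        }
        // A final frame closes the message; otherwise one more fragment has been seen.
        fragmentedFramesCount = fin ? 0 : fragmentedFramesCount + 1;
    }

    boolean insideFragmentedMessage() {
        return fragmentedFramesCount != 0;
    }
}
```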
```java
public class WebSocket13FrameDecoder extends ByteToMessageDecoder implements WebSocketFrameDecoder {
    private static final Logger logger = LoggerFactory.getLogger(WebSocket13FrameDecoder.class);

    private static final byte OPCODE_CONT = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    private final WebSocketDecoderConfig config;

    // Decoding state, kept across decode() calls until the current frame is complete.
    private int fragmentedFramesCount;
    private boolean frameFinalFlag;
    private boolean frameMasked;
    private int frameRsv;
    private int frameOpcode;
    private long framePayloadLength;
    private int mask;
    private int framePayloadLen1;
    private boolean receivedClosingHandshake;
    private State state = State.READING_FIRST;

    public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
        this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
    }

    public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
                                   boolean allowMaskMismatch) {
        this(WebSocketDecoderConfig.newBuilder()
                .expectMaskedFrames(expectMaskedFrames)
                .allowExtensions(allowExtensions)
                .maxFramePayloadLength(maxFramePayloadLength)
                .allowMaskMismatch(allowMaskMismatch)
                .build());
    }

    public WebSocket13FrameDecoder(WebSocketDecoderConfig decoderConfig) {
        config = Objects.requireNonNull(decoderConfig, "decoderConfig");
    }

    private static int toFrameLength(long length) {
        if (length > Integer.MAX_VALUE) {
            throw new TooLongFrameException("frame length exceeds " + Integer.MAX_VALUE + ": " + length);
        } else {
            return (int) length;
        }
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, Buffer in) throws Exception {
        // Discard all data received if closing handshake was received before.
        if (receivedClosingHandshake) {
            in.skipReadableBytes(actualReadableBytes());
            return;
        }

        // The cases intentionally fall through: each state consumes its part of the frame and
        // advances, returning early (to resume on the next read) when too few bytes are available.
        switch (state) {
            case READING_FIRST: {
                if (in.readableBytes() == 0) {
                    return;
                }

                framePayloadLength = 0;

                // FIN, RSV, OPCODE
                byte b = in.readByte();
                frameFinalFlag = (b & 0x80) != 0;
                frameRsv = (b & 0x70) >> 4;
                frameOpcode = b & 0x0F;

                if (logger.isTraceEnabled()) {
                    logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode);
                }

                state = State.READING_SECOND;
            }
            // fall through
            case READING_SECOND: {
                if (in.readableBytes() == 0) {
                    return;
                }
                // MASK, PAYLOAD LEN 1
                byte b = in.readByte();
                frameMasked = (b & 0x80) != 0;
                framePayloadLen1 = b & 0x7F;

                if (frameRsv != 0 && !config.allowExtensions()) {
                    protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv);
                    return;
                }

                if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) {
                    protocolViolation(ctx, in, "received a frame that is not masked as expected");
                    return;
                }

                if (frameOpcode > 7) { // control frame (have MSB in opcode set)

                    // control frames MUST NOT be fragmented
                    if (!frameFinalFlag) {
                        protocolViolation(ctx, in, "fragmented control frame");
                        return;
                    }

                    // control frames MUST have payload 125 octets or less
                    if (framePayloadLen1 > 125) {
                        protocolViolation(ctx, in, "control frame with payload length > 125 octets");
                        return;
                    }

                    // check for reserved control frame opcodes
                    if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING
                            || frameOpcode == OPCODE_PONG)) {
                        protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode);
                        return;
                    }

                    // close frame: if there is a body, the first two bytes of the
                    // body MUST be a 2-byte unsigned integer (in network byte
                    // order) representing a status code
                    if (frameOpcode == OPCODE_CLOSE && framePayloadLen1 == 1) {
                        protocolViolation(ctx, in, "received close control frame with payload len 1");
                        return;
                    }
                } else { // data frame
                    // check for reserved data frame opcodes
                    if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT
                            || frameOpcode == OPCODE_BINARY)) {
                        protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode);
                        return;
                    }

                    // check opcode vs message fragmentation state 1/2
                    if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) {
                        protocolViolation(ctx, in, "received continuation data frame outside fragmented message");
                        return;
                    }

                    // check opcode vs message fragmentation state 2/2
                    if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) {
                        protocolViolation(ctx, in,
                                "received non-continuation data frame while inside fragmented message");
                        return;
                    }
                }

                state = State.READING_SIZE;
            }
            // fall through
            case READING_SIZE: {
                // Read frame payload length
                if (framePayloadLen1 == 126) {
                    if (in.readableBytes() < 2) {
                        return;
                    }
                    framePayloadLength = in.readUnsignedShort();
                    if (framePayloadLength < 126) {
                        protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
                        return;
                    }
                } else if (framePayloadLen1 == 127) {
                    if (in.readableBytes() < 8) {
                        return;
                    }
                    framePayloadLength = in.readLong();
                    // RFC 6455 requires the most significant bit to be 0. A value with that bit set
                    // reads as negative and is therefore also rejected by the check below.
                    if (framePayloadLength < 65536) {
                        protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)");
                        return;
                    }
                } else {
                    framePayloadLength = framePayloadLen1;
                }

                if (framePayloadLength > config.maxFramePayloadLength()) {
                    protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG,
                            "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded.");
                    return;
                }

                if (logger.isTraceEnabled()) {
                    logger.trace("Decoding WebSocket Frame length={}", framePayloadLength);
                }

                state = State.MASKING_KEY;
            }
            // fall through
            case MASKING_KEY: {
                if (frameMasked) {
                    if (in.readableBytes() < 4) {
                        return;
                    }
                    mask = in.readInt();
                }
                state = State.PAYLOAD;
            }
            // fall through
            case PAYLOAD: {
                if (in.readableBytes() < framePayloadLength) {
                    return;
                }

                Buffer payloadBuffer = null;
                try {
                    payloadBuffer = in.readSplit(toFrameLength(framePayloadLength));

                    // Now we have all the data, the next checkpoint must be the next
                    // frame
                    state = State.READING_FIRST;

                    // Unmask data if needed
                    if (frameMasked) {
                        unmask(payloadBuffer);
                    }

                    // Processing ping/pong/close frames because they cannot be
                    // fragmented
                    if (frameOpcode == OPCODE_PING) {
                        WebSocketFrame frame = new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                        payloadBuffer = null;
                        ctx.fireChannelRead(frame);
                        return;
                    }
                    if (frameOpcode == OPCODE_PONG) {
                        WebSocketFrame frame = new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                        payloadBuffer = null;
                        ctx.fireChannelRead(frame);
                        return;
                    }
                    if (frameOpcode == OPCODE_CLOSE) {
                        receivedClosingHandshake = true;
                        checkCloseFrameBody(ctx, payloadBuffer);
                        WebSocketFrame frame = new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                        payloadBuffer = null;
                        ctx.fireChannelRead(frame);
                        return;
                    }

                    // Processing for possible fragmented messages for text and binary
                    // frames
                    if (frameFinalFlag) {
                        // Final frame of the sequence. Apparently ping frames are
                        // allowed in the middle of a fragmented message
                        fragmentedFramesCount = 0;
                    } else {
                        // Increment counter
                        fragmentedFramesCount++;
                    }

                    // Return the frame
                    if (frameOpcode == OPCODE_TEXT) {
                        WebSocketFrame frame = new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                        payloadBuffer = null;
                        ctx.fireChannelRead(frame);
                        return;
                    } else if (frameOpcode == OPCODE_BINARY) {
                        WebSocketFrame frame = new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer);
                        payloadBuffer = null;
                        ctx.fireChannelRead(frame);
                        return;
                    } else if (frameOpcode == OPCODE_CONT) {
                        WebSocketFrame frame = new ContinuationWebSocketFrame(frameFinalFlag, frameRsv,
                                payloadBuffer);
                        payloadBuffer = null;
                        ctx.fireChannelRead(frame);
                        return;
                    } else {
                        throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: "
                                + frameOpcode);
                    }
                } finally {
                    if (payloadBuffer != null) {
                        payloadBuffer.close();
                    }
                }
            }
            case CORRUPT: {
                if (in.readableBytes() > 0) {
                    // If we don't keep reading Netty will throw an exception saying
                    // we can't return null if no bytes read and state not changed.
                    in.readByte();
                }
                return;
            }
            default:
                throw new Error("Shouldn't reach here.");
        }
    }

    private void unmask(Buffer frame) {
        int base = frame.readerOffset();
        int len = frame.readableBytes();
        int index = 0;
        int intMask = mask;
        if (intMask == 0) {
            // If the mask is 0 we can just return directly as the XOR operations will just produce the same value.
            return;
        }
        // XOR 4 bytes at a time; the mask was read with readInt(), so getInt()/setInt() apply it
        // in the same byte order.
        for (; index + 3 < len; index += Integer.BYTES) {
            int off = base + index;
            frame.setInt(off, frame.getInt(off) ^ intMask);
        }
        // Unmask the remaining 0-3 bytes one by one.
        int maskOffset = 0;
        for (; index < len; index++) {
            int off = base + index;
            frame.setByte(off, (byte) (frame.getByte(off) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3)));
        }
    }

    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, String reason) {
        protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason);
    }

    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, WebSocketCloseStatus status, String reason) {
        protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason));
    }

    private void protocolViolation(ChannelHandlerContext ctx, Buffer in, CorruptedWebSocketFrameException ex) {
        state = State.CORRUPT;
        int readableBytes = in.readableBytes();
        if (readableBytes > 0) {
            // Fix for memory leak, caused by ByteToMessageDecoder#channelRead:
            // buffer 'cumulation' is released ONLY when no more readable bytes available.
            in.skipReadableBytes(readableBytes);
        }
        if (ctx.channel().isActive() && config.closeOnProtocolViolation()) {
            Object closeMessage;
            if (receivedClosingHandshake) {
                closeMessage = ctx.bufferAllocator().allocate(0);
            } else {
                WebSocketCloseStatus closeStatus = ex.closeStatus();
                String reasonText = ex.getMessage();
                if (reasonText == null) {
                    reasonText = closeStatus.reasonText();
                }
                closeMessage = new CloseWebSocketFrame(ctx.bufferAllocator(), closeStatus, reasonText);
            }
            ctx.writeAndFlush(closeMessage).addListener(ctx, ChannelFutureListeners.CLOSE);
        }
        throw ex;
    }

    /** Validates a close frame body: an optional 2-byte status code, optionally followed by a UTF-8 reason. */
    protected void checkCloseFrameBody(ChannelHandlerContext ctx, Buffer buffer) {
        if (buffer == null || buffer.readableBytes() <= 0) {
            return;
        }
        if (buffer.readableBytes() == 1) {
            protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body");
        }

        // Save reader offset.
        int offset = buffer.readerOffset();
        try {
            // Must have 2 byte integer within the valid range.
            int statusCode = buffer.readUnsignedShort();
            if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) {
                protocolViolation(ctx, buffer, "Invalid close frame status code: " + statusCode);
            }

            // May have UTF-8 message.
            if (buffer.readableBytes() > 0) {
                try {
                    new Utf8Validator().check(buffer);
                } catch (CorruptedWebSocketFrameException ex) {
                    protocolViolation(ctx, buffer, ex);
                }
            }
        } finally {
            // Restore reader offset.
            buffer.readerOffset(offset);
        }
    }

    enum State {
        READING_FIRST,
        READING_SECOND,
        READING_SIZE,
        MASKING_KEY,
        PAYLOAD,
        CORRUPT
    }
}
```
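The hypothetical sketch below walks the `READING_SIZE`, `MASKING_KEY` and `PAYLOAD` states end to end by decoding one complete frame from a byte array. It differs from Netty's decoder in several ways:

- it uses `java.nio.ByteBuffer` (big-endian by default) instead of Netty's `Buffer`;
- it assumes the whole frame is already available, rather than resuming across reads;
- it throws `IllegalArgumentException` where the decoder would call `protocolViolation`.

It keeps the decoder's minimal-length-encoding checks and its word-at-a-time unmasking with a byte-wise tail.

```java
import java.nio.ByteBuffer;

/** Hypothetical single-frame decoder mirroring the states of WebSocket13FrameDecoder. */
final class FrameDecodeSketch {

    /** Decodes one complete frame and returns its (unmasked) payload. */
    static byte[] decodePayload(byte[] frame, long maxPayload) {
        ByteBuffer in = ByteBuffer.wrap(frame); // big-endian, i.e. network byte order
        in.get();                               // b0: FIN/RSV/opcode, not needed here
        byte b1 = in.get();
        boolean masked = (b1 & 0x80) != 0;
        int len1 = b1 & 0x7F;

        long length;
        if (len1 == 126) {
            length = in.getShort() & 0xFFFF;
            if (length < 126) {
                throw new IllegalArgumentException("not using minimal length encoding");
            }
        } else if (len1 == 127) {
            length = in.getLong();
            // A negative value (most significant bit set) is rejected by this check too.
            if (length < 65536) {
                throw new IllegalArgumentException("not using minimal length encoding");
            }
        } else {
            length = len1;
        }
        if (length > maxPayload || length > in.remaining() - (masked ? 4 : 0)) {
            throw new IllegalArgumentException("payload too long or truncated: " + length);
        }

        int mask = masked ? in.getInt() : 0;
        byte[] payload = new byte[(int) length];
        in.get(payload);
        unmask(payload, mask);
        return payload;
    }

    /** XORs 4 bytes at a time, then the remaining 0..3 bytes one by one. */
    static void unmask(byte[] data, int mask) {
        if (mask == 0) {
            return; // XOR with zero changes nothing
        }
        ByteBuffer buf = ByteBuffer.wrap(data); // writes go through to the backing array
        int index = 0;
        for (; index + 3 < data.length; index += Integer.BYTES) {
            buf.putInt(index, buf.getInt(index) ^ mask);
        }
        // index is now a multiple of 4, so the tail starts at mask byte 0.
        for (int maskOffset = 0; index < data.length; index++, maskOffset++) {
            data[index] ^= (byte) (mask >>> (8 * (3 - maskOffset)));
        }
    }
}
```

Word-at-a-time XOR gives the same result as the byte-wise loop. This holds because the mask and the payload words are read in the same (big-endian) order. It simply performs a quarter as many operations on long payloads.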