在阅读这篇文章前,推荐先阅读以下内容:
WebSocketServerHandshakerFactory
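WebSocketServerHandshakerFactory 用于根据客户端请求头 Sec-WebSocket-Version 中的 WebSocket 版本构造对应的 WebSocketServerHandshaker 实例,完成握手协议版本的协商与支持判断。目前只支持版本 13;对于不支持的版本会返回 null,此时可调用 sendUnsupportedVersionResponse 告知客户端服务端支持的版本。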
java
/**
 * Chooses the {@link WebSocketServerHandshaker} implementation matching the {@code Sec-WebSocket-Version}
 * header of a client's upgrade request. Only version 13 is recognised; any other value, or a missing
 * header, yields {@code null}.
 */
public class WebSocketServerHandshakerFactory {
    private final String webSocketURL;
    private final String subprotocols;
    private final WebSocketDecoderConfig decoderConfig;
    // ... (further members omitted in this excerpt)

    /**
     * @param webSocketURL  URL handed to every handshaker this factory creates
     * @param subprotocols  comma-separated list of supported subprotocols (may be {@code null})
     * @param decoderConfig frame-decoder settings; must not be {@code null}
     * @throws NullPointerException if {@code decoderConfig} is {@code null}
     */
    public WebSocketServerHandshakerFactory(String webSocketURL, String subprotocols,
                                            WebSocketDecoderConfig decoderConfig) {
        this.decoderConfig = Objects.requireNonNull(decoderConfig, "decoderConfig");
        this.subprotocols = subprotocols;
        this.webSocketURL = webSocketURL;
    }

    /**
     * Creates a handshaker for {@code req} using this factory's settings.
     *
     * @return a version-13 handshaker, or {@code null} if the requested version is not supported
     */
    public WebSocketServerHandshaker newHandshaker(HttpRequest req) {
        return resolveHandshaker0(req, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * Static variant of {@link #newHandshaker(HttpRequest)} that takes the settings explicitly.
     *
     * @return a version-13 handshaker, or {@code null} if the requested version is not supported
     * @throws NullPointerException if {@code decoderConfig} is {@code null}
     */
    public static WebSocketServerHandshaker resolveHandshaker(HttpRequest req, String webSocketURL,
                                                              String subprotocols,
                                                              WebSocketDecoderConfig decoderConfig) {
        WebSocketDecoderConfig config = Objects.requireNonNull(decoderConfig, "decoderConfig");
        return resolveHandshaker0(req, webSocketURL, subprotocols, config);
    }

    // Shared implementation: compares the request's Sec-WebSocket-Version header (case-insensitively)
    // against version 13 and builds the matching handshaker, or returns null.
    private static WebSocketServerHandshaker resolveHandshaker0(HttpRequest req, String webSocketURL,
                                                                String subprotocols,
                                                                WebSocketDecoderConfig decoderConfig) {
        CharSequence requestedVersion = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
        boolean wantsV13 = requestedVersion != null
                && AsciiString.contentEqualsIgnoreCase(requestedVersion, WebSocketVersion.V13.toAsciiString());
        // Version 13 of the wire protocol - RFC 6455 (version 17 of the draft hybi specification).
        return wantsV13 ? new WebSocketServerHandshaker13(webSocketURL, subprotocols, decoderConfig) : null;
    }

    /**
     * Writes an empty-bodied {@code UPGRADE_REQUIRED} response advertising version 13 in the
     * {@code Sec-WebSocket-Version} header, so the client knows which version the server accepts.
     *
     * @return the future of the write-and-flush
     */
    public static Future<Void> sendUnsupportedVersionResponse(Channel channel) {
        HttpResponse upgradeRequired = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                HttpResponseStatus.UPGRADE_REQUIRED, channel.bufferAllocator().allocate(0));
        upgradeRequired.headers()
                .set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue());
        HttpUtil.setContentLength(upgradeRequired, 0);
        return channel.writeAndFlush(upgradeRequired);
    }
}
WebSocketServerHandshaker13
WebSocketServerHandshaker13 基于 RFC 6455 实现 WebSocket 版本 13 的服务端握手处理,包括请求校验、响应生成、子协议协商,以及创建帧编解码器(编解码器在 pipeline 中的安装由基类 WebSocketServerHandshaker 完成)。
java
/**
 * Server-side handshaker for WebSocket protocol version 13 (RFC 6455). It validates the upgrade request,
 * builds the {@code 101 Switching Protocols} response and supplies the version-13 frame codec.
 */
public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {

    /**
     * @param webSocketURL  URL of the WebSocket endpoint
     * @param subprotocols  comma-separated list of supported subprotocols, or {@code null}
     * @param decoderConfig configuration for the frame decoder
     */
    public WebSocketServerHandshaker13(String webSocketURL, String subprotocols,
                                       WebSocketDecoderConfig decoderConfig) {
        super(WebSocketVersion.V13, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * <p>
     * Handle the web socket handshake for the web socket specification <a href=
     * "https://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17">HyBi versions 13-17</a>. Versions 13-17
     * share the same wire protocol.
     * </p>
     *
     * <p>
     * Browser request to the server:
     * </p>
     *
     * <pre>
     * GET /chat HTTP/1.1
     * Host: server.example.com
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
     * Origin: http://example.com
     * Sec-WebSocket-Protocol: chat, superchat
     * Sec-WebSocket-Version: 13
     * </pre>
     *
     * <p>
     * Server response:
     * </p>
     *
     * <pre>
     * HTTP/1.1 101 Switching Protocols
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
     * Sec-WebSocket-Protocol: chat
     * </pre>
     *
     * @param allocator allocator for the (empty) response body
     * @param req       the client's upgrade request
     * @param headers   extra response headers to include, or {@code null}
     * @return the {@code 101 Switching Protocols} response
     * @throws WebSocketServerHandshakeException if the request is not a valid version-13 upgrade request
     */
    @Override
    protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req,
                                                    HttpHeaders headers) {
        CharSequence key = validateUpgradeRequest(req);

        FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS,
                allocator.allocate(0));
        if (headers != null) {
            // Caller-supplied headers are copied first; the mandatory upgrade headers below are set afterwards.
            res.headers().add(headers);
        }

        String accept = WebSocketUtil.calculateV13Accept(key.toString());
        HttpHeaders resHeaders = res.headers();
        resHeaders.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
        resHeaders.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
        resHeaders.set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);

        CharSequence requested = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        if (requested == null) {
            // Client asked for no subprotocol: nothing to negotiate.
            return res;
        }
        String selected = selectSubprotocol(requested.toString());
        if (selected != null) {
            resHeaders.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selected);
        } else if (logger.isDebugEnabled()) {
            // No overlap: the handshake still succeeds, just without a Sec-WebSocket-Protocol header.
            logger.debug("Requested subprotocol(s) not supported: {}", requested);
        }
        return res;
    }

    /**
     * Checks that {@code req} is a GET request carrying {@code Connection: Upgrade},
     * {@code Upgrade: websocket} and a {@code Sec-WebSocket-Key}.
     *
     * @return the value of the {@code Sec-WebSocket-Key} header
     * @throws WebSocketServerHandshakeException on the first failed check
     */
    private static CharSequence validateUpgradeRequest(FullHttpRequest req) {
        HttpMethod method = req.method();
        if (!HttpMethod.GET.equals(method)) {
            throw new WebSocketServerHandshakeException("Invalid WebSocket handshake method: " + method, req);
        }

        HttpHeaders reqHeaders = req.headers();
        boolean connectionUpgrade = reqHeaders.contains(HttpHeaderNames.CONNECTION)
                && reqHeaders.containsIgnoreCase(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
        if (!connectionUpgrade) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: a |Connection| header must includes a token 'Upgrade'", req);
        }
        if (!reqHeaders.containsIgnoreCase(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: a |Upgrade| header must containing the value 'websocket'", req);
        }

        CharSequence key = reqHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
        if (key == null) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
        }
        return key;
    }

    /** Creates the version-13 frame decoder using this handshaker's decoder configuration. */
    @Override
    protected WebSocketFrameDecoder newWebsocketDecoder() {
        return new WebSocket13FrameDecoder(decoderConfig());
    }

    /** Creates the version-13 frame encoder; the {@code false} argument is passed through as in the original. */
    @Override
    protected WebSocketFrameEncoder newWebSocketEncoder() {
        return new WebSocket13FrameEncoder(false);
    }
}
WebSocketServerHandshaker
WebSocketServerHandshaker 是 WebSocket 服务端握手处理的抽象基类,定义了握手响应发送、子协议选择和编解码器安装等通用逻辑;具体的握手响应内容与帧编解码器由各版本子类(如 V13)实现。
java
/**
 * Base class for server-side WebSocket handshakers. It sends the version-specific handshake response
 * built by {@link #newHandshakeResponse}, negotiates the subprotocol, and replaces the HTTP codec in the
 * channel pipeline with the frame codec returned by {@link #newWebsocketDecoder()} and
 * {@link #newWebSocketEncoder()}.
 */
public abstract class WebSocketServerHandshaker {
    protected static final Logger logger = LoggerFactory.getLogger(WebSocketServerHandshaker.class);

    /** Entry in the supported-subprotocol list that accepts the first subprotocol the client requests. */
    public static final String SUB_PROTOCOL_WILDCARD = "*";

    private final String uri;
    private final String[] subprotocols;
    private final WebSocketVersion version;
    private final WebSocketDecoderConfig decoderConfig;
    private String selectedSubprotocol;

    /**
     * @param version       protocol version implemented by the subclass
     * @param uri           URL of the WebSocket endpoint
     * @param subprotocols  comma-separated supported subprotocols (entries are trimmed), or {@code null} for none
     * @param decoderConfig frame decoder configuration; must not be {@code null}
     * @throws NullPointerException if {@code decoderConfig} is {@code null}
     */
    protected WebSocketServerHandshaker(WebSocketVersion version, String uri, String subprotocols,
                                        WebSocketDecoderConfig decoderConfig) {
        this.version = version;
        this.uri = uri;
        if (subprotocols != null) {
            String[] subprotocolArray = subprotocols.split(",");
            for (int i = 0; i < subprotocolArray.length; i++) {
                subprotocolArray[i] = subprotocolArray[i].trim();
            }
            this.subprotocols = subprotocolArray;
        } else {
            this.subprotocols = EmptyArrays.EMPTY_STRINGS;
        }
        this.decoderConfig = requireNonNull(decoderConfig, "decoderConfig");
    }

    /** Returns the endpoint URL passed to the constructor. */
    public String uri() {
        return uri;
    }

    /**
     * Returns a fresh, insertion-ordered copy of the supported subprotocols (duplicates removed);
     * changes to the returned set do not affect this handshaker.
     */
    public Set<String> subprotocols() {
        Set<String> ret = new LinkedHashSet<>();
        Collections.addAll(ret, subprotocols);
        return ret;
    }

    /** Returns the protocol version implemented by this handshaker. */
    public WebSocketVersion version() {
        return version;
    }

    /** Returns the frame decoder configuration (never {@code null}). */
    public WebSocketDecoderConfig decoderConfig() {
        return decoderConfig;
    }

    /**
     * Returns the subprotocol chosen by the most recent successful {@link #selectSubprotocol} call,
     * or {@code null} if none has been selected.
     */
    public String selectedSubprotocol() {
        return selectedSubprotocol;
    }

    /**
     * Performs the handshake for an aggregated request without extra response headers.
     * Equivalent to {@code handshake(channel, req, null)}.
     */
    public Future<Void> handshake(Channel channel, FullHttpRequest req) {
        return handshake(channel, req, null);
    }

    /**
     * Writes the handshake response, swaps the HTTP codec for the WebSocket frame codec, and removes
     * HTTP-only handlers ({@link HttpObjectAggregator}, {@link HttpContentCompressor}).
     *
     * @param channel         the channel the upgrade request arrived on
     * @param req             the aggregated upgrade request
     * @param responseHeaders extra headers for the handshake response, or {@code null}
     * @return a future completed when the response has been written; it is failed with
     *         {@link IllegalStateException} if the pipeline lacks the required HTTP codec handlers
     */
    public final Future<Void> handshake(Channel channel, FullHttpRequest req, HttpHeaders responseHeaders) {
        if (logger.isDebugEnabled()) {
            logger.debug("{} WebSocket version {} server handshake", channel, version());
        }
        // Version-specific response; a subclass may throw here if the request is not a valid upgrade.
        FullHttpResponse response = newHandshakeResponse(channel.bufferAllocator(), req, responseHeaders);
        ChannelPipeline p = channel.pipeline();

        // Resolve the HTTP codec handlers before modifying the pipeline, so that a misconfigured pipeline
        // is reported through the returned future instead of being left half-rewritten (and the response leaked).
        ChannelHandlerContext decoderCtx = p.context(HttpRequestDecoder.class);
        final boolean combinedCodec = decoderCtx == null;
        if (combinedCodec) {
            // this means the user use an HttpServerCodec
            decoderCtx = p.context(HttpServerCodec.class);
            if (decoderCtx == null) {
                response.close();
                return channel.newFailedFuture(
                        new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
            }
        }
        ChannelHandlerContext encoderCtx = combinedCodec ? decoderCtx : p.context(HttpResponseEncoder.class);
        if (encoderCtx == null) {
            // Previously this path threw a NullPointerException after the decoder had already been replaced.
            response.close();
            return channel.newFailedFuture(
                    new IllegalStateException("No HttpResponseEncoder in the pipeline"));
        }
        final String decoderName = decoderCtx.name();
        final String encoderName = encoderCtx.name();

        try {
            // These handlers work on HTTP messages and must not see WebSocket frames.
            if (p.get(HttpObjectAggregator.class) != null) {
                p.remove(HttpObjectAggregator.class);
            }
            if (p.get(HttpContentCompressor.class) != null) {
                p.remove(HttpContentCompressor.class);
            }
            if (combinedCodec) {
                // Put the frame codec in front of the HttpServerCodec; the codec itself is removed
                // once the handshake response has been written through it (see listener below).
                p.addBefore(decoderName, "wsencoder", newWebSocketEncoder());
                p.addBefore(decoderName, "wsdecoder", newWebsocketDecoder());
            } else {
                p.replace(decoderName, "wsdecoder", newWebsocketDecoder());
                p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
            }
        } catch (RuntimeException | Error e) {
            // e.g. a duplicate handler name: do not leak the response buffer.
            response.close();
            throw e;
        }

        return channel.writeAndFlush(response).addListener(channel, (ch, future) -> {
            if (future.isSuccess()) {
                // The response has gone out through the HTTP encoder, which is no longer needed.
                ch.pipeline().remove(encoderName);
            }
        });
    }

    /**
     * Performs the handshake for a possibly non-aggregated request without extra response headers.
     * Equivalent to {@code handshake(channel, req, null)}.
     */
    public Future<Void> handshake(Channel channel, HttpRequest req) {
        return handshake(channel, req, null);
    }

    /**
     * Performs the handshake when the pipeline may not contain an {@link HttpObjectAggregator}.
     *
     * <p>If {@code req} is already a {@link FullHttpRequest} this delegates directly. Otherwise a temporary
     * handler named {@code "handshaker"} is inserted after the HTTP decoder. It assembles the request head
     * and the terminating {@link LastHttpContent} into a {@link FullHttpRequest}, removes itself, and runs the
     * aggregated handshake, cascading its result to the returned future. Content of intermediate chunks
     * is discarded.</p>
     *
     * @param channel         the channel the upgrade request arrived on
     * @param req             the request head (or a full request)
     * @param responseHeaders extra headers for the handshake response, or {@code null}
     * @return a future completed when the handshake finishes; failed if no HTTP decoder is present,
     *         the channel closes first, or an exception reaches the temporary handler
     */
    public final Future<Void> handshake(final Channel channel, HttpRequest req, final HttpHeaders responseHeaders) {
        if (req instanceof FullHttpRequest) {
            return handshake(channel, (FullHttpRequest) req, responseHeaders);
        }
        ChannelPipeline pipeline = channel.pipeline();
        // Locate the HTTP decoder (standalone HttpRequestDecoder, else HttpServerCodec).
        ChannelHandlerContext ctx = pipeline.context(HttpRequestDecoder.class);
        if (ctx == null) {
            // This means the user use a HttpServerCodec
            ctx = pipeline.context(HttpServerCodec.class);
            if (ctx == null) {
                return channel.newFailedFuture(new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
            }
        }

        final Promise<Void> promise = channel.newPromise();
        // Temporary aggregating handler; once the request is complete it removes itself and runs the
        // aggregated handshake, cascading the result to `promise`.
        pipeline.addAfter(ctx.name(), "handshaker", new ChannelHandlerAdapter() {
            private FullHttpRequest fullHttpRequest;

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof HttpObject) {
                    try {
                        handleHandshakeRequest(ctx, (HttpObject) msg);
                    } finally {
                        // Needed parts have been copied into fullHttpRequest; the message itself is consumed here.
                        Resource.dispose(msg);
                    }
                } else {
                    super.channelRead(ctx, msg);
                }
            }

            @Override
            public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                ctx.pipeline().remove(this);
                promise.tryFailure(cause);
                super.channelExceptionCaught(ctx, cause);
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                try {
                    // Fail promise if Channel was closed
                    if (!promise.isDone()) {
                        promise.tryFailure(new ClosedChannelException());
                    }
                    ctx.fireChannelInactive();
                } finally {
                    releaseFullHttpRequest();
                }
            }

            @Override
            public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
                releaseFullHttpRequest();
            }

            private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
                // FullHttpRequest must be checked first: it is also a LastHttpContent and an HttpRequest.
                if (httpObject instanceof FullHttpRequest) {
                    ctx.pipeline().remove(this);
                    handshake(channel, (FullHttpRequest) httpObject, responseHeaders).cascadeTo(promise);
                    return;
                }
                if (httpObject instanceof LastHttpContent) {
                    assert fullHttpRequest != null;
                    try (FullHttpRequest handshakeRequest = fullHttpRequest) {
                        // Clear the field first so handlerRemoved() does not close it a second time.
                        fullHttpRequest = null;
                        ctx.pipeline().remove(this);
                        handshake(channel, handshakeRequest, responseHeaders).cascadeTo(promise);
                    }
                    return;
                }
                if (httpObject instanceof HttpRequest) {
                    HttpRequest httpRequest = (HttpRequest) httpObject;
                    // Close any previously assembled request instead of leaking it when overwriting.
                    releaseFullHttpRequest();
                    fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
                            httpRequest.uri(), ctx.bufferAllocator().allocate(0),
                            httpRequest.headers(), HttpHeaders.emptyHeaders());
                    if (httpRequest.decoderResult().isFailure()) {
                        fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
                    }
                }
                // Other (intermediate) HttpContent is ignored.
            }

            private void releaseFullHttpRequest() {
                if (fullHttpRequest != null) {
                    fullHttpRequest.close();
                    fullHttpRequest = null;
                }
            }
        });
        try {
            // Re-inject the request head from the decoder's context so the new handler sees it first.
            ctx.fireChannelRead(ReferenceCountUtil.retain(req));
        } catch (Throwable cause) {
            promise.setFailure(cause);
        }
        return promise.asFuture();
    }

    /**
     * Writes {@code frame} to the channel and closes the channel once the write completes.
     *
     * @throws NullPointerException if {@code channel} is {@code null}
     */
    public Future<Void> close(Channel channel, CloseWebSocketFrame frame) {
        requireNonNull(channel, "channel");
        return close0(channel, frame);
    }

    /**
     * Writes {@code frame} via {@code ctx} and closes once the write completes.
     *
     * @throws NullPointerException if {@code ctx} is {@code null}
     */
    public Future<Void> close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
        requireNonNull(ctx, "ctx");
        return close0(ctx, frame);
    }

    private static Future<Void> close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame) {
        return invoker.writeAndFlush(frame).addListener(invoker, ChannelFutureListeners.CLOSE);
    }

    /**
     * Picks the first client-requested subprotocol (in request order) that this handshaker supports,
     * treating a supported entry of {@link #SUB_PROTOCOL_WILDCARD} as matching anything. Matching is
     * exact (case-sensitive) after trimming. A match is also stored for {@link #selectedSubprotocol()}.
     *
     * @param requestedSubprotocols comma-separated value of the client's {@code Sec-WebSocket-Protocol} header
     * @return the selected subprotocol, or {@code null} if there is no match or nothing is supported
     */
    protected String selectSubprotocol(String requestedSubprotocols) {
        if (requestedSubprotocols == null || subprotocols.length == 0) {
            return null;
        }
        String[] requestedSubprotocolArray = requestedSubprotocols.split(",");
        for (String p : requestedSubprotocolArray) {
            String requestedSubprotocol = p.trim();
            for (String supportedSubprotocol : subprotocols) {
                if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol) || requestedSubprotocol.equals(supportedSubprotocol)) {
                    selectedSubprotocol = requestedSubprotocol;
                    return requestedSubprotocol;
                }
            }
        }
        // No match found
        return null;
    }

    /** Builds the version-specific handshake response for {@code req}, including {@code responseHeaders} if non-null. */
    protected abstract FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req, HttpHeaders responseHeaders);

    /** Creates the frame decoder installed in place of the HTTP request decoder. */
    protected abstract WebSocketFrameDecoder newWebsocketDecoder();

    /** Creates the frame encoder installed ahead of the HTTP response encoder. */
    protected abstract WebSocketFrameEncoder newWebSocketEncoder();
}