[netty5: WebSocketServerHandshaker & WebSocketServerHandshakerFactory]-源码分析

在阅读这篇文章前,推荐先阅读以下内容:

  1. [netty5: WebSocketFrame]-源码分析
  2. [netty5: WebSocketFrameEncoder & WebSocketFrameDecoder]-源码解析

WebSocketServerHandshakerFactory

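WebSocketServerHandshakerFactory 用于根据客户端请求中的 WebSocket 版本(Sec-WebSocket-Version 请求头)构造对应的 WebSocketServerHandshaker 实例,完成握手协议版本的协商与支持判断。目前只识别版本 13(RFC 6455);对于不支持的版本,newHandshaker / resolveHandshaker 返回 null,此时调用方可以使用 sendUnsupportedVersionResponse 回复 426 Upgrade Required,并通过 Sec-WebSocket-Version 响应头告知客户端服务端所支持的版本。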

源码如下(Java):
/**
 * Chooses the server-side handshaker that matches the {@code Sec-WebSocket-Version} of an incoming request.
 *
 * <p>Only version 13 (RFC 6455) is recognised. For any other (or missing) version the lookup methods
 * return {@code null}; {@link #sendUnsupportedVersionResponse(Channel)} can then be used to tell the
 * client which version the server accepts.
 */
public class WebSocketServerHandshakerFactory {

    // Endpoint URL passed to every handshaker created by this factory.
    private final String webSocketURL;

    // Comma-separated list of subprotocols the server supports; may be null.
    private final String subprotocols;

    // Frame-decoder settings forwarded to every handshaker created by this factory.
    private final WebSocketDecoderConfig decoderConfig;

    // ... (the remaining convenience constructors are omitted in this excerpt)

    /**
     * Creates a factory whose handshakers share the given endpoint, subprotocols and decoder settings.
     *
     * @param webSocketURL  URL of the WebSocket endpoint
     * @param subprotocols  comma-separated supported subprotocols, or {@code null} for none
     * @param decoderConfig frame-decoder configuration; must not be {@code null}
     * @throws NullPointerException if {@code decoderConfig} is {@code null}
     */
    public WebSocketServerHandshakerFactory(String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        this.decoderConfig = Objects.requireNonNull(decoderConfig, "decoderConfig");
        this.webSocketURL = webSocketURL;
        this.subprotocols = subprotocols;
    }

    /**
     * Returns a handshaker for {@code req} using this factory's settings.
     *
     * @param req the client's upgrade request
     * @return a handshaker, or {@code null} if the requested WebSocket version is not supported
     */
    public WebSocketServerHandshaker newHandshaker(HttpRequest req) {
        return resolveHandshaker0(req, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * Stateless variant of {@link #newHandshaker(HttpRequest)}.
     *
     * @return a handshaker, or {@code null} if the requested WebSocket version is not supported
     * @throws NullPointerException if {@code decoderConfig} is {@code null}
     */
    public static WebSocketServerHandshaker resolveHandshaker(HttpRequest req, String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        return resolveHandshaker0(req, webSocketURL, subprotocols,
                Objects.requireNonNull(decoderConfig, "decoderConfig"));
    }

    private static WebSocketServerHandshaker resolveHandshaker0(HttpRequest req, String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        CharSequence requestedVersion = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
        boolean isVersion13 = requestedVersion != null
                && AsciiString.contentEqualsIgnoreCase(requestedVersion, WebSocketVersion.V13.toAsciiString());
        // Version 13 of the wire protocol - RFC 6455 (version 17 of the draft hybi specification).
        return isVersion13 ? new WebSocketServerHandshaker13(webSocketURL, subprotocols, decoderConfig) : null;
    }

    /**
     * Writes an empty {@code 426 Upgrade Required} response advertising version 13 in
     * {@code Sec-WebSocket-Version}.
     *
     * @param channel the channel to answer on
     * @return the future of the write-and-flush
     */
    public static Future<Void> sendUnsupportedVersionResponse(Channel channel) {
        HttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UPGRADE_REQUIRED,
                channel.bufferAllocator().allocate(0));
        response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue());
        HttpUtil.setContentLength(response, 0);
        return channel.writeAndFlush(response);
    }
}

WebSocketServerHandshaker13

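WebSocketServerHandshaker13 基于 RFC 6455 实现 WebSocket 版本 13 的服务端握手:校验请求(GET 方法、Connection / Upgrade 请求头、Sec-WebSocket-Key),生成携带 Sec-WebSocket-Accept 的 101 Switching Protocols 响应,协商子协议,并提供版本 13 的帧编解码器。编解码器在 pipeline 中的安装由基类 WebSocketServerHandshaker 的 handshake 方法完成。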

源码如下(Java):
/**
 * Server-side handshaker for WebSocket protocol version 13 (RFC 6455).
 *
 * <p>Validates the upgrade request, builds the {@code 101 Switching Protocols} response and supplies the
 * version-13 frame encoder/decoder. Installing them in the pipeline is done by the base class.
 */
public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker {

    /**
     * @param webSocketURL  URL of the WebSocket endpoint
     * @param subprotocols  comma-separated supported subprotocols, or {@code null} for none
     * @param decoderConfig frame-decoder configuration
     */
    public WebSocketServerHandshaker13(String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        super(WebSocketVersion.V13, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * <p>
     * Handle the web socket handshake for the web socket specification <a href=
     * "https://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17">HyBi versions 13-17</a>. Versions 13-17
     * share the same wire protocol.
     * </p>
     *
     * <p>
     * Browser request to the server:
     * </p>
     *
     * <pre>
     * GET /chat HTTP/1.1
     * Host: server.example.com
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
     * Origin: http://example.com
     * Sec-WebSocket-Protocol: chat, superchat
     * Sec-WebSocket-Version: 13
     * </pre>
     *
     * <p>
     * Server response:
     * </p>
     *
     * <pre>
     * HTTP/1.1 101 Switching Protocols
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
     * Sec-WebSocket-Protocol: chat
     * </pre>
     *
     * @param allocator allocator for the (empty) response body
     * @param req       the aggregated upgrade request
     * @param headers   extra response headers to copy in first, or {@code null}
     * @return the 101 response; the Upgrade, Connection and Sec-WebSocket-Accept headers always override
     *         values supplied through {@code headers}
     * @throws WebSocketServerHandshakeException if the request is not a valid WebSocket upgrade request
     */
    @Override
    protected FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req, HttpHeaders headers) {
        CharSequence key = requireUpgradeRequest(req);

        FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS,
                allocator.allocate(0));
        HttpHeaders resHeaders = res.headers();
        if (headers != null) {
            resHeaders.add(headers);
        }

        String accept = WebSocketUtil.calculateV13Accept(key.toString());
        resHeaders.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
        resHeaders.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
        resHeaders.set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);

        negotiateSubprotocol(req.headers(), resHeaders);
        return res;
    }

    // Checks method, Connection, Upgrade and Sec-WebSocket-Key in that order; returns the key.
    private static CharSequence requireUpgradeRequest(FullHttpRequest req) {
        HttpMethod method = req.method();
        if (!HttpMethod.GET.equals(method)) {
            throw new WebSocketServerHandshakeException("Invalid WebSocket handshake method: " + method, req);
        }

        HttpHeaders reqHeaders = req.headers();
        // A header that is absent cannot contain the value, so this also rejects a missing Connection header.
        if (!reqHeaders.containsIgnoreCase(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: a |Connection| header must includes a token 'Upgrade'", req);
        }
        if (!reqHeaders.containsIgnoreCase(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: a |Upgrade| header must containing the value 'websocket'", req);
        }

        CharSequence key = reqHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
        if (key == null) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
        }
        return key;
    }

    // Echoes the first mutually supported subprotocol, if the client asked for any.
    private void negotiateSubprotocol(HttpHeaders reqHeaders, HttpHeaders resHeaders) {
        CharSequence requested = reqHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        if (requested == null) {
            return;
        }
        String selected = selectSubprotocol(requested.toString());
        if (selected != null) {
            resHeaders.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selected);
        } else if (logger.isDebugEnabled()) {
            logger.debug("Requested subprotocol(s) not supported: {}", requested);
        }
    }

    /** Creates the version-13 frame decoder configured with this handshaker's decoder settings. */
    @Override
    protected WebSocketFrameDecoder newWebsocketDecoder() {
        return new WebSocket13FrameDecoder(decoderConfig());
    }

    /** Creates the version-13 frame encoder used for server-to-client frames. */
    @Override
    protected WebSocketFrameEncoder newWebSocketEncoder() {
        // NOTE(review): the false argument presumably disables payload masking, since RFC 6455 forbids
        // servers from masking frames; confirm against WebSocket13FrameEncoder.
        return new WebSocket13FrameEncoder(false);
    }
}

WebSocketServerHandshaker

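WebSocketServerHandshaker 是服务端 WebSocket 握手的抽象基类。具体的握手响应及帧编解码器由子类(如 V13)通过 newHandshakeResponse / newWebsocketDecoder / newWebSocketEncoder 提供;基类负责通用流程,包括发送握手响应、用 WebSocket 编解码器替换 HTTP 编解码器、移除 HttpObjectAggregator 与 HttpContentCompressor,以及子协议的选择。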

源码如下(Java):
/**
 * Base class for server-side WebSocket handshakers.
 *
 * <p>Subclasses (for example {@code WebSocketServerHandshaker13}) build the version-specific handshake
 * response and supply the frame encoder/decoder. This class drives the shared flow:
 * <ul>
 *   <li>writing the response;</li>
 *   <li>replacing the HTTP codec with the WebSocket codec;</li>
 *   <li>removing {@code HttpObjectAggregator} and {@code HttpContentCompressor};</li>
 *   <li>negotiating the subprotocol.</li>
 * </ul>
 */
public abstract class WebSocketServerHandshaker {
    protected static final Logger logger = LoggerFactory.getLogger(WebSocketServerHandshaker.class);

    /** Configured subprotocol that accepts whichever non-empty subprotocol the client requests first. */
    public static final String SUB_PROTOCOL_WILDCARD = "*";

    private final String uri;

    // Supported subprotocols: trimmed, non-empty, de-duplicated, in configuration order.
    private final String[] subprotocols;

    private final WebSocketVersion version;

    private final WebSocketDecoderConfig decoderConfig;

    // Set by selectSubprotocol when a match is found; null until then.
    private String selectedSubprotocol;

    /**
     * @param version       protocol version implemented by the subclass
     * @param uri           URL of the WebSocket endpoint
     * @param subprotocols  comma-separated supported subprotocols, or {@code null} for none;
     *                      blank entries are ignored
     * @param decoderConfig frame-decoder configuration; must not be {@code null}
     * @throws NullPointerException if {@code decoderConfig} is {@code null}
     */
    protected WebSocketServerHandshaker(WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        this.version = version;
        this.uri = uri;
        this.subprotocols = parseSubprotocols(subprotocols);
        this.decoderConfig = requireNonNull(decoderConfig, "decoderConfig");
    }

    // Splits a comma-separated list into trimmed tokens, dropping empty ones (e.g. from "" or "chat, ").
    private static String[] parseSubprotocols(String subprotocols) {
        if (subprotocols == null) {
            return EmptyArrays.EMPTY_STRINGS;
        }
        Set<String> tokens = new LinkedHashSet<>();
        for (String token : subprotocols.split(",")) {
            String trimmed = token.trim();
            if (!trimmed.isEmpty()) {
                tokens.add(trimmed);
            }
        }
        return tokens.toArray(EmptyArrays.EMPTY_STRINGS);
    }

    /** Returns the URL of the WebSocket endpoint. */
    public String uri() {
        return uri;
    }

    /** Returns the supported subprotocols as a new, insertion-ordered set (a defensive copy). */
    public Set<String> subprotocols() {
        Set<String> ret = new LinkedHashSet<>();
        Collections.addAll(ret, subprotocols);
        return ret;
    }

    /** Returns the protocol version implemented by this handshaker. */
    public WebSocketVersion version() {
        return version;
    }

    /** Returns the frame-decoder configuration used for the decoder installed by the handshake. */
    public WebSocketDecoderConfig decoderConfig() {
        return decoderConfig;
    }

    /** Returns the subprotocol chosen during negotiation, or {@code null} if none was selected. */
    public String selectedSubprotocol() {
        return selectedSubprotocol;
    }

    /**
     * Performs the handshake for an aggregated request without extra response headers.
     *
     * @see #handshake(Channel, FullHttpRequest, HttpHeaders)
     */
    public Future<Void> handshake(Channel channel, FullHttpRequest req) {
        return handshake(channel, req, null);
    }

    /**
     * Writes the handshake response and switches the pipeline from HTTP to WebSocket.
     *
     * <p>The WebSocket codec is installed before the response is written. The HTTP encoder (or
     * {@code HttpServerCodec}) is removed only after the response has been written successfully.
     *
     * @param channel         the channel that received the upgrade request
     * @param req             the aggregated upgrade request
     * @param responseHeaders extra headers for the handshake response, or {@code null}
     * @return the future of the response write, or a failed future if the pipeline has no usable HTTP
     *         decoder/encoder
     * @throws WebSocketServerHandshakeException (from the subclass) if the request is invalid
     */
    public final Future<Void> handshake(Channel channel, FullHttpRequest req, HttpHeaders responseHeaders) {
        if (logger.isDebugEnabled()) {
            logger.debug("{} WebSocket version {} server handshake", channel, version());
        }

        // The subclass validates the request and builds the response; it may throw synchronously.
        FullHttpResponse response = newHandshakeResponse(channel.bufferAllocator(), req, responseHeaders);

        // Drop the HTTP aggregation/compression handlers, if present, before the protocol switch.
        ChannelPipeline p = channel.pipeline();
        if (p.get(HttpObjectAggregator.class) != null) {
            p.remove(HttpObjectAggregator.class);
        }
        if (p.get(HttpContentCompressor.class) != null) {
            p.remove(HttpContentCompressor.class);
        }

        ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class);
        final String encoderName;
        if (ctx == null) {
            // No standalone decoder: this means the user uses an HttpServerCodec.
            ctx = p.context(HttpServerCodec.class);
            if (ctx == null) {
                response.close();
                return channel.newFailedFuture(
                        new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
            }
            // Insert the WebSocket codec in front of the codec; the codec still encodes the response below
            // and is removed once that write succeeds.
            p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder());
            p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder());
            encoderName = ctx.name();
        } else {
            // Resolve the response encoder before mutating the pipeline, so a missing encoder gives a failed
            // future instead of an NPE after the decoder has already been replaced (and the response leaked).
            ChannelHandlerContext encoderCtx = p.context(HttpResponseEncoder.class);
            if (encoderCtx == null) {
                response.close();
                return channel.newFailedFuture(
                        new IllegalStateException("No HttpResponseEncoder in the pipeline"));
            }
            p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder());
            encoderName = encoderCtx.name();
            p.addBefore(encoderName, "wsencoder", newWebSocketEncoder());
        }

        return channel.writeAndFlush(response).addListener(channel, (ch, future) -> {
            if (future.isSuccess()) {
                ch.pipeline().remove(encoderName);
            }
        });
    }

    /**
     * Performs the handshake for a possibly non-aggregated request without extra response headers.
     *
     * @see #handshake(Channel, HttpRequest, HttpHeaders)
     */
    public Future<Void> handshake(Channel channel, HttpRequest req) {
        return handshake(channel, req, null);
    }

    /**
     * Performs the handshake for a request that may arrive in parts (no {@code HttpObjectAggregator}).
     *
     * <p>A {@link FullHttpRequest} is handled directly. Otherwise a temporary handler named
     * {@code "handshaker"} is inserted after the HTTP decoder. It collects the request head until
     * {@link LastHttpContent} arrives, then removes itself and runs
     * {@link #handshake(Channel, FullHttpRequest, HttpHeaders)}. Body content is discarded.
     *
     * @param channel         the channel that received the upgrade request
     * @param req             the (possibly partial) upgrade request
     * @param responseHeaders extra headers for the handshake response, or {@code null}
     * @return a future that completes with the outcome of the handshake
     */
    public final Future<Void> handshake(final Channel channel, HttpRequest req, final HttpHeaders responseHeaders) {
        if (req instanceof FullHttpRequest) {
            return handshake(channel, (FullHttpRequest) req, responseHeaders);
        }

        ChannelPipeline pipeline = channel.pipeline();
        ChannelHandlerContext ctx = pipeline.context(HttpRequestDecoder.class);
        if (ctx == null) {
            // This means the user uses an HttpServerCodec.
            ctx = pipeline.context(HttpServerCodec.class);
            if (ctx == null) {
                return channel.newFailedFuture(new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline"));
            }
        }

        final Promise<Void> promise = channel.newPromise();
        pipeline.addAfter(ctx.name(), "handshaker", new ChannelHandlerAdapter() {

            // Request head collected so far; owned (and closed) by this handler until handed off.
            private FullHttpRequest fullHttpRequest;

            @Override
            public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                if (msg instanceof HttpObject) {
                    try {
                        handleHandshakeRequest(ctx, (HttpObject) msg);
                    } finally {
                        Resource.dispose(msg);
                    }
                } else {
                    super.channelRead(ctx, msg);
                }
            }

            @Override
            public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                ctx.pipeline().remove(this);
                promise.tryFailure(cause);
                super.channelExceptionCaught(ctx, cause);
            }

            @Override
            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                try {
                    // Fail the promise if the channel was closed before the handshake finished.
                    if (!promise.isDone()) {
                        promise.tryFailure(new ClosedChannelException());
                    }
                    ctx.fireChannelInactive();
                } finally {
                    releaseFullHttpRequest();
                }
            }

            @Override
            public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
                releaseFullHttpRequest();
            }

            private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) {
                if (httpObject instanceof FullHttpRequest) {
                    // Disposed by channelRead's finally block after the handshake.
                    completeHandshake(ctx, (FullHttpRequest) httpObject);
                    return;
                }

                if (httpObject instanceof LastHttpContent) {
                    assert fullHttpRequest != null;
                    try (FullHttpRequest handshakeRequest = fullHttpRequest) {
                        // Clear the field first so handlerRemoved does not close the request a second time.
                        fullHttpRequest = null;
                        completeHandshake(ctx, handshakeRequest);
                    }
                    return;
                }

                if (httpObject instanceof HttpRequest) {
                    // Keep only the request head; intermediate HttpContent chunks are ignored.
                    HttpRequest httpRequest = (HttpRequest) httpObject;
                    fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(),
                                                                 httpRequest.uri(), ctx.bufferAllocator().allocate(0),
                                                                 httpRequest.headers(), HttpHeaders.emptyHeaders());
                    if (httpRequest.decoderResult().isFailure()) {
                        fullHttpRequest.setDecoderResult(httpRequest.decoderResult());
                    }
                }
            }

            // Removes this handler, then runs the real handshake and links its outcome to the promise.
            private void completeHandshake(ChannelHandlerContext ctx, FullHttpRequest request) {
                ctx.pipeline().remove(this);
                try {
                    handshake(channel, request, responseHeaders).cascadeTo(promise);
                } catch (RuntimeException | Error cause) {
                    // This handler is already removed, so channelExceptionCaught may never see this failure.
                    // Fail the promise explicitly, then let the exception propagate as before.
                    promise.tryFailure(cause);
                    throw cause;
                }
            }

            private void releaseFullHttpRequest() {
                if (fullHttpRequest != null) {
                    fullHttpRequest.close();
                    fullHttpRequest = null;
                }
            }
        });

        try {
            ctx.fireChannelRead(ReferenceCountUtil.retain(req));
        } catch (Throwable cause) {
            // tryFailure: the handshake triggered by fireChannelRead may already have completed the promise.
            promise.tryFailure(cause);
        }

        return promise.asFuture();
    }

    /**
     * Writes {@code frame} and closes the channel once the write completes.
     *
     * @throws NullPointerException if {@code channel} is {@code null}
     */
    public Future<Void> close(Channel channel, CloseWebSocketFrame frame) {
        requireNonNull(channel, "channel");
        return close0(channel, frame);
    }

    /**
     * Writes {@code frame} through {@code ctx} and closes it once the write completes.
     *
     * @throws NullPointerException if {@code ctx} is {@code null}
     */
    public Future<Void> close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
        requireNonNull(ctx, "ctx");
        return close0(ctx, frame);
    }

    private static Future<Void> close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame) {
        return invoker.writeAndFlush(frame).addListener(invoker, ChannelFutureListeners.CLOSE);
    }

    /**
     * Picks the first subprotocol requested by the client that this handshaker supports.
     *
     * <p>Requested entries are trimmed and empty entries are skipped, so an empty token is never echoed
     * back, even when {@link #SUB_PROTOCOL_WILDCARD} is configured. A match is also recorded for
     * {@link #selectedSubprotocol()}.
     *
     * @param requestedSubprotocols value of the client's {@code Sec-WebSocket-Protocol} header, may be {@code null}
     * @return the selected subprotocol, or {@code null} if there is no match
     */
    protected String selectSubprotocol(String requestedSubprotocols) {
        if (requestedSubprotocols == null || subprotocols.length == 0) {
            return null;
        }

        for (String token : requestedSubprotocols.split(",")) {
            String requestedSubprotocol = token.trim();
            if (requestedSubprotocol.isEmpty()) {
                continue;
            }
            for (String supportedSubprotocol : subprotocols) {
                if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol) || requestedSubprotocol.equals(supportedSubprotocol)) {
                    selectedSubprotocol = requestedSubprotocol;
                    return requestedSubprotocol;
                }
            }
        }

        // No match found
        return null;
    }

    /**
     * Validates {@code req} and builds the version-specific handshake response.
     *
     * @param allocator       allocator for the response body
     * @param req             the aggregated upgrade request
     * @param responseHeaders extra headers to include, or {@code null}
     * @return the handshake response
     */
    protected abstract FullHttpResponse newHandshakeResponse(BufferAllocator allocator, FullHttpRequest req, HttpHeaders responseHeaders);

    /** Creates the frame decoder installed in place of the HTTP request decoder. */
    protected abstract WebSocketFrameDecoder newWebsocketDecoder();

    /** Creates the frame encoder installed in front of the HTTP response encoder. */
    protected abstract WebSocketFrameEncoder newWebSocketEncoder();
}
相关推荐
idolyXyz18 小时前
[netty5: WebSocketFrameEncoder & WebSocketFrameDecoder]-源码解析
netty
idolyXyz2 天前
[netty5: HttpObjectEncoder & HttpObjectDecoder]-源码解析
netty
Derek_Smart15 天前
基于Netty与Spring Integration的高并发工业物联网网关架构设计与实现
spring boot·物联网·netty
迢迢星万里灬16 天前
Java求职者面试指南:微服务技术与源码原理深度解析
java·spring cloud·微服务·dubbo·netty·分布式系统·面试指南
Y_3_718 天前
Netty实战:从核心组件到多协议实现(超详细注释,udp,tcp,websocket,http完整demo)
linux·运维·后端·ubuntu·netty
安徽杰杰25 天前
能源即服务:智慧移动充电桩的供给模式创新
netty
安徽杰杰1 个月前
新基建浪潮下:中国新能源汽车充电桩智慧化建设与管理实践
netty
迢迢星万里灬1 个月前
Java求职者面试:微服务技术与源码原理深度解析
java·spring cloud·微服务·dubbo·netty·分布式系统
触角云科技1 个月前
掌上充电站:基于APP/小程序的新能源汽车智慧充电管理
netty