【stomp 实战】spring websocket源码分析之握手请求的处理

上一节【搭建一套websocket推送平台】我们通过一个项目,实现了一套推送平台。由于spring框架对于websocket的支持和stomp协议的良好封装,我们很容易地就实现了websocket的消息推送功能。虽然搭建这么一套推送系统不难,但是如果不了解其底层原理,当出现问题时,我们就比较痛苦了。这次我们就来分析一下这块的源码。

一、WebSocket 握手过程

1.1 客户端握手请求

客户端发起 WebSocket 握手流程。客户端发送带有如下请求头的标准 HTTP 请求(HTTP 版本必须是 1.1 或更高,并且请求方法必须是 GET):

javascript 复制代码
GET /chat HTTP/1.1
Host: example.com:8000
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13

1.2 服务端握手响应

当服务端收到握手请求时,将发送一个特殊响应,该响应表明协议将从 HTTP 变更为 WebSocket。

该响应头大致如下(记住,每个响应头行以 \r\n 结尾,在最后一行的后面添加额外的 \r\n,以说明响应头结束):

javascript 复制代码
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=

二、源码分析

上面就是握手的大概请求和响应报文。通过握手,客户端和服务端就可以建立连接了。

我们来看一下源码中是如何实现的。

整个过程我总结成了一个流程图,对照着这个流程图,我们再来一步步分析代码,避免在源码中迷路

2.1 流程图

2.2 把请求交给对应的处理器

如果你看过Spring-MVC的代码,你一定对DispatcherServlet有一定印象。这个Servlet是所有http请求的入口,所有的http请求都会经过它。

java 复制代码
	protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
				// 通过请求找到handlerAdapter
				HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());
		。。。略
				//用这个handlerAdapter来执行请求处理
				mv = ha.handle(processedRequest, response, mappedHandler.getHandler());
	...略

省略大量代码后,实际上主要做了两件事

  • 通过请求找到handlerAdapter。握手报文进来后,找到的是HttpRequestHandlerAdapter
  • 这个handlerAdapter来执行请求处理

HttpRequestHandlerAdapter又将处理SockJsHttpRequestHandler处理。最终是DefaultSockJsService来处理请求。

java 复制代码
//SockJsHttpRequestHandler代码
	@Override
	public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
			throws ServletException, IOException {

		ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
		ServerHttpResponse response = new ServletServerHttpResponse(servletResponse);

		try {
		//DefaultSockJsService.handleRequest
			this.sockJsService.handleRequest(request, response, getSockJsPath(servletRequest), this.webSocketHandler);
		}
		catch (Exception ex) {
			throw new SockJsException("Uncaught failure in SockJS request, uri=" + request.getURI(), ex);
		}
	}

进入this.sockJsService.handleRequest,会由(DefaultSockJsService的父类)AbstractSockJsService.handleRequest来处理请求

下面的代码有点长,完全理解有难度。还是抓重点

java 复制代码
	public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
			@Nullable String sockJsPath, WebSocketHandler wsHandler) throws SockJsException {

		if (sockJsPath == null) {
			if (logger.isWarnEnabled()) {
				logger.warn(LogFormatUtils.formatValue(
						"Expected SockJS path. Failing request: " + request.getURI(), -1, true));
			}
			response.setStatusCode(HttpStatus.NOT_FOUND);
			return;
		}

		try {
			request.getHeaders();
		}
		catch (InvalidMediaTypeException ex) {
			// As per SockJS protocol content-type can be ignored (it's always json)
		}

		String requestInfo = (logger.isDebugEnabled() ? request.getMethod() + " " + request.getURI() : null);

		try {
			if (sockJsPath.isEmpty() || sockJsPath.equals("/")) {
				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				if ("websocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
					response.setStatusCode(HttpStatus.BAD_REQUEST);
					return;
				}
				response.getHeaders().setContentType(new MediaType("text", "plain", StandardCharsets.UTF_8));
				response.getBody().write("Welcome to SockJS!\n".getBytes(StandardCharsets.UTF_8));
			}

			else if (sockJsPath.equals("/info")) {
				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				this.infoHandler.handle(request, response);
			}

			else if (sockJsPath.matches("/iframe[0-9-.a-z_]*.html")) {
				if (!getAllowedOrigins().isEmpty() && !getAllowedOrigins().contains("*") ||
						!getAllowedOriginPatterns().isEmpty()) {
					if (requestInfo != null) {
						logger.debug("Iframe support is disabled when an origin check is required. " +
								"Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
				if (getAllowedOrigins().isEmpty()) {
					response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN");
				}
				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				this.iframeHandler.handle(request, response);
			}

			else if (sockJsPath.equals("/websocket")) {
				if (isWebSocketEnabled()) {
					if (requestInfo != null) {
						logger.debug("Processing transport request: " + requestInfo);
					}
					handleRawWebSocketRequest(request, response, wsHandler);
				}
				else if (requestInfo != null) {
					logger.debug("WebSocket disabled. Ignoring transport request: " + requestInfo);
				}
			}

			else {
				String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/");
				if (pathSegments.length != 3) {
					if (logger.isWarnEnabled()) {
						logger.warn(LogFormatUtils.formatValue("Invalid SockJS path '" + sockJsPath + "' - " +
								"required to have 3 path segments", -1, true));
					}
					if (requestInfo != null) {
						logger.debug("Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}

				String serverId = pathSegments[0];
				String sessionId = pathSegments[1];
				String transport = pathSegments[2];

				if (!isWebSocketEnabled() && transport.equals("websocket")) {
					if (requestInfo != null) {
						logger.debug("WebSocket disabled. Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
				else if (!validateRequest(serverId, sessionId, transport) || !validatePath(request)) {
					if (requestInfo != null) {
						logger.debug("Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}

				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				handleTransportRequest(request, response, wsHandler, sessionId, transport);
			}
			response.close();
		}
		catch (IOException ex) {
			throw new SockJsException("Failed to write to the response", null, ex);
		}
	}

上面的代码很长,阅读起来不容易,可以通过debug的方式看下整个流程。这里就不细讲了,重点的一个方法

TransportHandlingSockJsService.handleTransportRequest(request, response, wsHandler, sessionId, transport);

代码如下,主要流程见代码注释

java 复制代码
	@Override
	protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
			WebSocketHandler handler, String sessionId, String transport) throws SockJsException {
		//这个值是Websocket
		TransportType transportType = TransportType.fromValue(transport);
		if (transportType == null) {
			if (logger.isWarnEnabled()) {
				logger.warn(LogFormatUtils.formatValue("Unknown transport type for " + request.getURI(), -1, true));
			}
			response.setStatusCode(HttpStatus.NOT_FOUND);
			return;
		}
		//这里取到的是WebSocketTransportHandler
		TransportHandler transportHandler = this.handlers.get(transportType);
		if (transportHandler == null) {
			if (logger.isWarnEnabled()) {
				logger.warn(LogFormatUtils.formatValue("No TransportHandler for " + request.getURI(), -1, true));
			}
			response.setStatusCode(HttpStatus.NOT_FOUND);
			return;
		}

		SockJsException failure = null;
		//构造一个拦截链,我们可以注册自己的拦截器,这样就可以在握手阶段来注入我们自己的业务逻辑,比如报文校验等
		HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);

		try {
			HttpMethod supportedMethod = transportType.getHttpMethod();
			if (supportedMethod != request.getMethod()) {
				if (request.getMethod() == HttpMethod.OPTIONS && transportType.supportsCors()) {
					if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) {
						response.setStatusCode(HttpStatus.NO_CONTENT);
						addCacheHeaders(response);
					}
				}
				else if (transportType.supportsCors()) {
					sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
				}
				else {
					sendMethodNotAllowed(response, supportedMethod);
				}
				return;
			}
		//会话的创建
			SockJsSession session = this.sessions.get(sessionId);
			boolean isNewSession = false;
			if (session == null) {
				if (transportHandler instanceof SockJsSessionFactory) {
					Map<String, Object> attributes = new HashMap<>();
					//拦截链的前置处理
					if (!chain.applyBeforeHandshake(request, response, attributes)) {
						return;
					}
					SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler;
					session = createSockJsSession(sessionId, sessionFactory, handler, attributes);
					isNewSession = true;
				}
				else {
					response.setStatusCode(HttpStatus.NOT_FOUND);
					if (logger.isDebugEnabled()) {
						logger.debug("Session not found, sessionId=" + sessionId +
								". The session may have been closed " +
								"(e.g. missed heart-beat) while a message was coming in.");
					}
					return;
				}
			}
			else {
				Principal principal = session.getPrincipal();
				if (principal != null && !principal.equals(request.getPrincipal())) {
					logger.debug("The user for the session does not match the user for the request.");
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
				if (!transportHandler.checkSessionType(session)) {
					logger.debug("Session type does not match the transport type for the request.");
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
			}

			if (transportType.sendsNoCacheInstruction()) {
				addNoCacheHeaders(response);
			}
			if (transportType.supportsCors() && !checkOrigin(request, response)) {
				return;
			}
			//这里是核心的处理逻辑
			transportHandler.handleRequest(request, response, handler, session);

			if (isNewSession && (response instanceof ServletServerHttpResponse)) {
				int status = ((ServletServerHttpResponse) response).getServletResponse().getStatus();
				if (HttpStatus.valueOf(status).is4xxClientError()) {
					this.sessions.remove(sessionId);
				}
			}
			//拦截链的后置处理
			chain.applyAfterHandshake(request, response, null);
		}
		catch (SockJsException ex) {
			failure = ex;
		}
		catch (Exception ex) {
			failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, ex);
		}
		finally {
			if (failure != null) {
				chain.applyAfterHandshake(request, response, failure);
				throw failure;
			}
		}
	}

总结起来有几个过程

  • 先取到一个hander,这里取到的是WebSocketTransportHandler
  • 构造一个拦截链,我们可以注册自己的拦截器,这样就可以在握手阶段来注入我们自己的业务逻辑,比如报文校验等
  • 拦截链的前置处理
  • 创建一个用户会话,握手时,肯定会话是空的,得建一个会话session
  • 处理器hander处理核心逻辑:transportHandler.handleRequest(request, response, handler, session);。这里后面详细写
  • 拦截链的后置处理

transportHandler.handleRequest(request, response, handler, session); 到底做了啥

java 复制代码
	public void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
			WebSocketHandler wsHandler, SockJsSession wsSession) throws SockJsException {

		WebSocketServerSockJsSession sockJsSession = (WebSocketServerSockJsSession) wsSession;
		try {
			wsHandler = new SockJsWebSocketHandler(getServiceConfig(), wsHandler, sockJsSession);
			//握手处理器握手
			this.handshakeHandler.doHandshake(request, response, wsHandler, sockJsSession.getAttributes());
		}
		catch (Exception ex) {
			sockJsSession.tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR);
			throw new SockJsTransportFailureException("WebSocket handshake failure", wsSession.getId(), ex);
		}
	}

终于看到了一个握手处理器,handshakeHandler。先不看代码,猜测一下,这里的作用应该是,构造一个握手返回报文,然后通过response写回给客户端。然后告知web容器tomcat,当前请求升级为websocket了。这样,浏览器后面就可以发送websockdet消息了。

握手的逻辑代码,这里不是我们主要研究的点,就不再细讲了。

升级成功后,tomcat有个回调方法,然后再进行一系列的初始化动作

上面的代码是tomcat的代码,可以看到红框中的StandardWebSocketHandlerAdapter。这里进行一第列的初始化动作

2.3 websocket初始化

StandardWebSocketHandlerAdapter.onOpen是入口,由Tomcat回调。

java 复制代码
	public void onOpen(final javax.websocket.Session session, EndpointConfig config) {
		this.wsSession.initializeNativeSession(session);

		// The following inner classes need to remain since lambdas would not retain their
		// declared generic types (which need to be seen by the underlying WebSocket engine)

		if (this.handler.supportsPartialMessages()) {
			session.addMessageHandler(new MessageHandler.Partial<String>() {
				@Override
				public void onMessage(String message, boolean isLast) {
					handleTextMessage(session, message, isLast);
				}
			});
			session.addMessageHandler(new MessageHandler.Partial<ByteBuffer>() {
				@Override
				public void onMessage(ByteBuffer message, boolean isLast) {
					handleBinaryMessage(session, message, isLast);
				}
			});
		}
		else {
			session.addMessageHandler(new MessageHandler.Whole<String>() {
				@Override
				public void onMessage(String message) {
					handleTextMessage(session, message, true);
				}
			});
			session.addMessageHandler(new MessageHandler.Whole<ByteBuffer>() {
				@Override
				public void onMessage(ByteBuffer message) {
					handleBinaryMessage(session, message, true);
				}
			});
		}

		session.addMessageHandler(new MessageHandler.Whole<javax.websocket.PongMessage>() {
			@Override
			public void onMessage(javax.websocket.PongMessage message) {
				handlePongMessage(session, message.getApplicationData());
			}
		});

		try {
			this.handler.afterConnectionEstablished(this.wsSession);
		}
		catch (Exception ex) {
			ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger);
		}
	}

代码总结:

  • 这里入参传了一个javax.websocket.Session。这个可以理解为当前Websocket连接。
  • 原来这个Session可以给自己添加messageHandler,那当有消息来的时候,就会经过这些handler来进行处理。
  • 那这个hander就是处理业务消息的重点了
    看一下这个hander是怎么处理消息的
java 复制代码
private void handleTextMessage(javax.websocket.Session session, String payload, boolean isLast) {
		TextMessage textMessage = new TextMessage(payload, isLast);
		try {
			this.handler.handleMessage(this.wsSession, textMessage);
		}
		catch (Exception ex) {
			ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger);
		}
	}

这个handler,对应的实现是:SockJsWebSocketHandler

进入handleMessage看一下处理逻辑,原来是将消息分为三类

  • 文本消息
  • 二进制消息
  • 心跳消息
    这三种消息,分别进行处理
java 复制代码
@Override
	public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
		if (message instanceof TextMessage) {
			handleTextMessage(session, (TextMessage) message);
		}
		else if (message instanceof BinaryMessage) {
			handleBinaryMessage(session, (BinaryMessage) message);
		}
		else if (message instanceof PongMessage) {
			handlePongMessage(session, (PongMessage) message);
		}
		else {
			throw new IllegalStateException("Unexpected WebSocket message type: " + message);
		}
	}

我们一般处理的是文本消息

java 复制代码
	@Override
	public void handleTextMessage(WebSocketSession wsSession, TextMessage message) throws Exception {
		this.sockJsSession.handleMessage(message, wsSession);
	}

又交给sockJsSession来处理消息。

消息的处理过程,我们暂且不表。下节再来分析。

this.handler.afterConnectionEstablished(this.wsSession);

java 复制代码
	public void initializeDelegateSession(WebSocketSession session) {
		synchronized (this.initSessionLock) {
			this.webSocketSession = session;
			try {
				// Let "our" handler know before sending the open frame to the remote handler
				delegateConnectionEstablished();
				this.webSocketSession.sendMessage(new TextMessage(SockJsFrame.openFrame().getContent()));

				// Flush any messages cached in the meantime
				while (!this.initSessionCache.isEmpty()) {
					writeFrame(SockJsFrame.messageFrame(getMessageCodec(), this.initSessionCache.poll()));
				}
				scheduleHeartbeat();
				this.openFrameSent = true;
			}
			catch (Exception ex) {
				tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR);
			}
		}
	}

这里的逻辑如下

  • delegateConnectionEstablished。让我们的hander知晓,当前Websocket连接已经建立了,这是个回调方法
  • 发送一个websocket open报文给客户端
  • 开启websocket心跳线程

代码就分析完毕了,结合最开始的流程图,可以自己再debug一下加深印象。

三、总结

整个握手过程包含以下关键步骤

  • 通过http请求,找到对应的握手的处理器
  • 握手处理器将websocket握手成功的返回报文发送给客户端
  • web容器回调自身,告知协议升级
  • 注册消息处理器,当有websocket消息来时,就会回调处理器进行消息的逻辑处理
  • 初始化事件,包括发送一个open报文给客户端,开启Websocket心跳线程等
相关推荐
小马爱打代码8 小时前
Spring源码 第九篇:Spring 5 源码深度拆解 - Spring 事件驱动模型
java·后端·spring
ForgeAI码匠8 小时前
ForgeAdmin|Spring Boot 3 后台框架的自动配置设计:少写配置,多做组合
java·spring boot·后端
tongluowan0079 小时前
Redisson的参数及工作原理
java·redis·lua·分布式锁
仙俊红9 小时前
Integer\int对比,equals()\hashcode面试
java·面试·职场和发展
WiChP9 小时前
【V0.1B10】从零开始的2D游戏引擎开发之路
java·数据库·游戏引擎
云烟成雨TD10 小时前
Spring AI Alibaba 1.x 系列【60】检查点机制原理与全流程剖析
java·人工智能·spring
ForgeAI码匠10 小时前
Maven 多模块项目如何避免越写越乱?Forge Admin 的模块边界实践
java·人工智能·开源·maven
z落落10 小时前
C# 数组 最终完整版全套笔记(一维+多维+交错+引用类型+对象数组)
java·笔记·c#
Access开发易登软件10 小时前
Access 和 SQLite,根本不在一个赛道上
java·jvm·数据库·sqlite·excel·vba·access开发
小马爱打代码10 小时前
Spring源码 第十篇:Spring 5 源码深度拆解 - Spring 类型转换与校验体系
java·spring