【stomp 实战】spring websocket源码分析之握手请求的处理

上一节【搭建一套websocket推送平台】我们通过一个项目,实现了一套推送平台。由于spring框架对于websocket的支持和stomp协议的良好封装,我们很容易地就实现了websocket的消息推送功能。虽然搭建这么一套推送系统不难,但是如果不了解其底层原理,当出现问题时,我们就比较痛苦了。这次我们就来分析一下这块的源码。

一、WebSocket 握手过程

1.1 客户端握手请求

客户端发起 WebSocket 握手流程。客户端发送带有如下请求头的标准 HTTP 请求(HTTP 版本必须是 1.1 或更高,并且请求方法必须是 GET):

javascript 复制代码
GET /chat HTTP/1.1
Host: example.com:8000
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Sec-WebSocket-Version: 13

1.2 服务端握手响应

当服务端收到握手请求时,将发送一个特殊响应,该响应表明协议将从 HTTP 变更为 WebSocket。

该响应头大致如下(记住,每个响应头行以 \r\n 结尾,在最后一行的后面添加额外的 \r\n,以说明响应头结束):

javascript 复制代码
HTTP/1.1 101 Switching Protocols
Upgrade: websocket
Connection: Upgrade
Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=

二、源码分析

上面就是握手的大概请求和响应报文。通过握手,客户端和服务端就可以建立连接了。

我们来看一下源码中是如何实现的。

整个过程我总结成了一个流程图,对照着这个流程图,我们再来一步步分析代码,避免在源码中迷路

2.1 流程图

2.2 把请求交给对应的处理器

如果你看过Spring-MVC的代码,你一定对DispatcherServlet有一定印象。这个Servlet是所有http请求的入口,所有的http请求都会经过它。

java 复制代码
	protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
				// 通过请求找到handlerAdapter
				HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());
		。。。略
				//用这个handlerAdapter来执行请求处理
				mv = ha.handle(processedRequest, response, mappedHandler.getHandler());
	...略

省略大量代码后,实际上主要做了两件事

  • 通过请求找到handlerAdapter。握手报文进来后,找到的是HttpRequestHandlerAdapter
  • 这个handlerAdapter来执行请求处理

HttpRequestHandlerAdapter又将处理SockJsHttpRequestHandler处理。最终是DefaultSockJsService来处理请求。

java 复制代码
//SockJsHttpRequestHandler代码
	@Override
	public void handleRequest(HttpServletRequest servletRequest, HttpServletResponse servletResponse)
			throws ServletException, IOException {

		ServerHttpRequest request = new ServletServerHttpRequest(servletRequest);
		ServerHttpResponse response = new ServletServerHttpResponse(servletResponse);

		try {
		//DefaultSockJsService.handleRequest
			this.sockJsService.handleRequest(request, response, getSockJsPath(servletRequest), this.webSocketHandler);
		}
		catch (Exception ex) {
			throw new SockJsException("Uncaught failure in SockJS request, uri=" + request.getURI(), ex);
		}
	}

进入this.sockJsService.handleRequest,会由(DefaultSockJsService的父类)AbstractSockJsService.handleRequest来处理请求

下面的代码有点长,完全理解有难度。还是抓重点

java 复制代码
	public final void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
			@Nullable String sockJsPath, WebSocketHandler wsHandler) throws SockJsException {

		if (sockJsPath == null) {
			if (logger.isWarnEnabled()) {
				logger.warn(LogFormatUtils.formatValue(
						"Expected SockJS path. Failing request: " + request.getURI(), -1, true));
			}
			response.setStatusCode(HttpStatus.NOT_FOUND);
			return;
		}

		try {
			request.getHeaders();
		}
		catch (InvalidMediaTypeException ex) {
			// As per SockJS protocol content-type can be ignored (it's always json)
		}

		String requestInfo = (logger.isDebugEnabled() ? request.getMethod() + " " + request.getURI() : null);

		try {
			if (sockJsPath.isEmpty() || sockJsPath.equals("/")) {
				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				if ("websocket".equalsIgnoreCase(request.getHeaders().getUpgrade())) {
					response.setStatusCode(HttpStatus.BAD_REQUEST);
					return;
				}
				response.getHeaders().setContentType(new MediaType("text", "plain", StandardCharsets.UTF_8));
				response.getBody().write("Welcome to SockJS!\n".getBytes(StandardCharsets.UTF_8));
			}

			else if (sockJsPath.equals("/info")) {
				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				this.infoHandler.handle(request, response);
			}

			else if (sockJsPath.matches("/iframe[0-9-.a-z_]*.html")) {
				if (!getAllowedOrigins().isEmpty() && !getAllowedOrigins().contains("*") ||
						!getAllowedOriginPatterns().isEmpty()) {
					if (requestInfo != null) {
						logger.debug("Iframe support is disabled when an origin check is required. " +
								"Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
				if (getAllowedOrigins().isEmpty()) {
					response.getHeaders().add(XFRAME_OPTIONS_HEADER, "SAMEORIGIN");
				}
				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				this.iframeHandler.handle(request, response);
			}

			else if (sockJsPath.equals("/websocket")) {
				if (isWebSocketEnabled()) {
					if (requestInfo != null) {
						logger.debug("Processing transport request: " + requestInfo);
					}
					handleRawWebSocketRequest(request, response, wsHandler);
				}
				else if (requestInfo != null) {
					logger.debug("WebSocket disabled. Ignoring transport request: " + requestInfo);
				}
			}

			else {
				String[] pathSegments = StringUtils.tokenizeToStringArray(sockJsPath.substring(1), "/");
				if (pathSegments.length != 3) {
					if (logger.isWarnEnabled()) {
						logger.warn(LogFormatUtils.formatValue("Invalid SockJS path '" + sockJsPath + "' - " +
								"required to have 3 path segments", -1, true));
					}
					if (requestInfo != null) {
						logger.debug("Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}

				String serverId = pathSegments[0];
				String sessionId = pathSegments[1];
				String transport = pathSegments[2];

				if (!isWebSocketEnabled() && transport.equals("websocket")) {
					if (requestInfo != null) {
						logger.debug("WebSocket disabled. Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
				else if (!validateRequest(serverId, sessionId, transport) || !validatePath(request)) {
					if (requestInfo != null) {
						logger.debug("Ignoring transport request: " + requestInfo);
					}
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}

				if (requestInfo != null) {
					logger.debug("Processing transport request: " + requestInfo);
				}
				handleTransportRequest(request, response, wsHandler, sessionId, transport);
			}
			response.close();
		}
		catch (IOException ex) {
			throw new SockJsException("Failed to write to the response", null, ex);
		}
	}

上面的代码很长,阅读起来不容易,可以通过debug的方式看下整个流程。这里就不细讲了,重点的一个方法

TransportHandlingSockJsService.handleTransportRequest(request, response, wsHandler, sessionId, transport);

代码如下,主要流程见代码注释

java 复制代码
	@Override
	protected void handleTransportRequest(ServerHttpRequest request, ServerHttpResponse response,
			WebSocketHandler handler, String sessionId, String transport) throws SockJsException {
		//这个值是Websocket
		TransportType transportType = TransportType.fromValue(transport);
		if (transportType == null) {
			if (logger.isWarnEnabled()) {
				logger.warn(LogFormatUtils.formatValue("Unknown transport type for " + request.getURI(), -1, true));
			}
			response.setStatusCode(HttpStatus.NOT_FOUND);
			return;
		}
		//这里取到的是WebSocketTransportHandler
		TransportHandler transportHandler = this.handlers.get(transportType);
		if (transportHandler == null) {
			if (logger.isWarnEnabled()) {
				logger.warn(LogFormatUtils.formatValue("No TransportHandler for " + request.getURI(), -1, true));
			}
			response.setStatusCode(HttpStatus.NOT_FOUND);
			return;
		}

		SockJsException failure = null;
		//构造一个拦截链,我们可以注册自己的拦截器,这样就可以在握手阶段来注入我们自己的业务逻辑,比如报文校验等
		HandshakeInterceptorChain chain = new HandshakeInterceptorChain(this.interceptors, handler);

		try {
			HttpMethod supportedMethod = transportType.getHttpMethod();
			if (supportedMethod != request.getMethod()) {
				if (request.getMethod() == HttpMethod.OPTIONS && transportType.supportsCors()) {
					if (checkOrigin(request, response, HttpMethod.OPTIONS, supportedMethod)) {
						response.setStatusCode(HttpStatus.NO_CONTENT);
						addCacheHeaders(response);
					}
				}
				else if (transportType.supportsCors()) {
					sendMethodNotAllowed(response, supportedMethod, HttpMethod.OPTIONS);
				}
				else {
					sendMethodNotAllowed(response, supportedMethod);
				}
				return;
			}
		//会话的创建
			SockJsSession session = this.sessions.get(sessionId);
			boolean isNewSession = false;
			if (session == null) {
				if (transportHandler instanceof SockJsSessionFactory) {
					Map<String, Object> attributes = new HashMap<>();
					//拦截链的前置处理
					if (!chain.applyBeforeHandshake(request, response, attributes)) {
						return;
					}
					SockJsSessionFactory sessionFactory = (SockJsSessionFactory) transportHandler;
					session = createSockJsSession(sessionId, sessionFactory, handler, attributes);
					isNewSession = true;
				}
				else {
					response.setStatusCode(HttpStatus.NOT_FOUND);
					if (logger.isDebugEnabled()) {
						logger.debug("Session not found, sessionId=" + sessionId +
								". The session may have been closed " +
								"(e.g. missed heart-beat) while a message was coming in.");
					}
					return;
				}
			}
			else {
				Principal principal = session.getPrincipal();
				if (principal != null && !principal.equals(request.getPrincipal())) {
					logger.debug("The user for the session does not match the user for the request.");
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
				if (!transportHandler.checkSessionType(session)) {
					logger.debug("Session type does not match the transport type for the request.");
					response.setStatusCode(HttpStatus.NOT_FOUND);
					return;
				}
			}

			if (transportType.sendsNoCacheInstruction()) {
				addNoCacheHeaders(response);
			}
			if (transportType.supportsCors() && !checkOrigin(request, response)) {
				return;
			}
			//这里是核心的处理逻辑
			transportHandler.handleRequest(request, response, handler, session);

			if (isNewSession && (response instanceof ServletServerHttpResponse)) {
				int status = ((ServletServerHttpResponse) response).getServletResponse().getStatus();
				if (HttpStatus.valueOf(status).is4xxClientError()) {
					this.sessions.remove(sessionId);
				}
			}
			//拦截链的后置处理
			chain.applyAfterHandshake(request, response, null);
		}
		catch (SockJsException ex) {
			failure = ex;
		}
		catch (Exception ex) {
			failure = new SockJsException("Uncaught failure for request " + request.getURI(), sessionId, ex);
		}
		finally {
			if (failure != null) {
				chain.applyAfterHandshake(request, response, failure);
				throw failure;
			}
		}
	}

总结起来有几个过程

  • 先取到一个hander,这里取到的是WebSocketTransportHandler
  • 构造一个拦截链,我们可以注册自己的拦截器,这样就可以在握手阶段来注入我们自己的业务逻辑,比如报文校验等
  • 拦截链的前置处理
  • 创建一个用户会话,握手时,肯定会话是空的,得建一个会话session
  • 处理器hander处理核心逻辑:transportHandler.handleRequest(request, response, handler, session);。这里后面详细写
  • 拦截链的后置处理

transportHandler.handleRequest(request, response, handler, session); 到底做了啥

java 复制代码
	public void handleRequest(ServerHttpRequest request, ServerHttpResponse response,
			WebSocketHandler wsHandler, SockJsSession wsSession) throws SockJsException {

		WebSocketServerSockJsSession sockJsSession = (WebSocketServerSockJsSession) wsSession;
		try {
			wsHandler = new SockJsWebSocketHandler(getServiceConfig(), wsHandler, sockJsSession);
			//握手处理器握手
			this.handshakeHandler.doHandshake(request, response, wsHandler, sockJsSession.getAttributes());
		}
		catch (Exception ex) {
			sockJsSession.tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR);
			throw new SockJsTransportFailureException("WebSocket handshake failure", wsSession.getId(), ex);
		}
	}

终于看到了一个握手处理器,handshakeHandler。先不看代码,猜测一下,这里的作用应该是,构造一个握手返回报文,然后通过response写回给客户端。然后告知web容器tomcat,当前请求升级为websocket了。这样,浏览器后面就可以发送websockdet消息了。

握手的逻辑代码,这里不是我们主要研究的点,就不再细讲了。

升级成功后,tomcat有个回调方法,然后再进行一系列的初始化动作

上面的代码是tomcat的代码,可以看到红框中的StandardWebSocketHandlerAdapter。这里进行一第列的初始化动作

2.3 websocket初始化

StandardWebSocketHandlerAdapter.onOpen是入口,由Tomcat回调。

java 复制代码
	public void onOpen(final javax.websocket.Session session, EndpointConfig config) {
		this.wsSession.initializeNativeSession(session);

		// The following inner classes need to remain since lambdas would not retain their
		// declared generic types (which need to be seen by the underlying WebSocket engine)

		if (this.handler.supportsPartialMessages()) {
			session.addMessageHandler(new MessageHandler.Partial<String>() {
				@Override
				public void onMessage(String message, boolean isLast) {
					handleTextMessage(session, message, isLast);
				}
			});
			session.addMessageHandler(new MessageHandler.Partial<ByteBuffer>() {
				@Override
				public void onMessage(ByteBuffer message, boolean isLast) {
					handleBinaryMessage(session, message, isLast);
				}
			});
		}
		else {
			session.addMessageHandler(new MessageHandler.Whole<String>() {
				@Override
				public void onMessage(String message) {
					handleTextMessage(session, message, true);
				}
			});
			session.addMessageHandler(new MessageHandler.Whole<ByteBuffer>() {
				@Override
				public void onMessage(ByteBuffer message) {
					handleBinaryMessage(session, message, true);
				}
			});
		}

		session.addMessageHandler(new MessageHandler.Whole<javax.websocket.PongMessage>() {
			@Override
			public void onMessage(javax.websocket.PongMessage message) {
				handlePongMessage(session, message.getApplicationData());
			}
		});

		try {
			this.handler.afterConnectionEstablished(this.wsSession);
		}
		catch (Exception ex) {
			ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger);
		}
	}

代码总结:

  • 这里入参传了一个javax.websocket.Session。这个可以理解为当前Websocket连接。
  • 原来这个Session可以给自己添加messageHandler,那当有消息来的时候,就会经过这些handler来进行处理。
  • 那这个hander就是处理业务消息的重点了
    看一下这个hander是怎么处理消息的
java 复制代码
private void handleTextMessage(javax.websocket.Session session, String payload, boolean isLast) {
		TextMessage textMessage = new TextMessage(payload, isLast);
		try {
			this.handler.handleMessage(this.wsSession, textMessage);
		}
		catch (Exception ex) {
			ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.wsSession, ex, logger);
		}
	}

这个handler,对应的实现是:SockJsWebSocketHandler

进入handleMessage看一下处理逻辑,原来是将消息分为三类

  • 文本消息
  • 二进制消息
  • 心跳消息
    这三种消息,分别进行处理
java 复制代码
@Override
	public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
		if (message instanceof TextMessage) {
			handleTextMessage(session, (TextMessage) message);
		}
		else if (message instanceof BinaryMessage) {
			handleBinaryMessage(session, (BinaryMessage) message);
		}
		else if (message instanceof PongMessage) {
			handlePongMessage(session, (PongMessage) message);
		}
		else {
			throw new IllegalStateException("Unexpected WebSocket message type: " + message);
		}
	}

我们一般处理的是文本消息

java 复制代码
	@Override
	public void handleTextMessage(WebSocketSession wsSession, TextMessage message) throws Exception {
		this.sockJsSession.handleMessage(message, wsSession);
	}

又交给sockJsSession来处理消息。

消息的处理过程,我们暂且不表。下节再来分析。

this.handler.afterConnectionEstablished(this.wsSession);

java 复制代码
	public void initializeDelegateSession(WebSocketSession session) {
		synchronized (this.initSessionLock) {
			this.webSocketSession = session;
			try {
				// Let "our" handler know before sending the open frame to the remote handler
				delegateConnectionEstablished();
				this.webSocketSession.sendMessage(new TextMessage(SockJsFrame.openFrame().getContent()));

				// Flush any messages cached in the meantime
				while (!this.initSessionCache.isEmpty()) {
					writeFrame(SockJsFrame.messageFrame(getMessageCodec(), this.initSessionCache.poll()));
				}
				scheduleHeartbeat();
				this.openFrameSent = true;
			}
			catch (Exception ex) {
				tryCloseWithSockJsTransportError(ex, CloseStatus.SERVER_ERROR);
			}
		}
	}

这里的逻辑如下

  • delegateConnectionEstablished。让我们的hander知晓,当前Websocket连接已经建立了,这是个回调方法
  • 发送一个websocket open报文给客户端
  • 开启websocket心跳线程

代码就分析完毕了,结合最开始的流程图,可以自己再debug一下加深印象。

三、总结

整个握手过程包含以下关键步骤

  • 通过http请求,找到对应的握手的处理器
  • 握手处理器将websocket握手成功的返回报文发送给客户端
  • web容器回调自身,告知协议升级
  • 注册消息处理器,当有websocket消息来时,就会回调处理器进行消息的逻辑处理
  • 初始化事件,包括发送一个open报文给客户端,开启Websocket心跳线程等
相关推荐
潘多编程12 分钟前
Java中的状态机实现:使用Spring State Machine管理复杂状态流转
java·开发语言·spring
_阿伟_31 分钟前
SpringMVC
java·spring
代码在改了38 分钟前
springboot厨房达人美食分享平台(源码+文档+调试+答疑)
java·spring boot
wclass-zhengge1 小时前
数据结构篇(绪论)
java·数据结构·算法
何事驚慌1 小时前
2024/10/5 数据结构打卡
java·数据结构·算法
结衣结衣.1 小时前
C++ 类和对象的初步介绍
java·开发语言·数据结构·c++·笔记·学习·算法
TJKFYY1 小时前
Java.数据结构.HashSet
java·开发语言·数据结构
kylinxjd1 小时前
spring boot发送邮件
java·spring boot·后端·发送email邮件
OLDERHARD1 小时前
Java - MyBatis(上)
java·oracle·mybatis
杨荧1 小时前
【JAVA开源】基于Vue和SpringBoot的旅游管理系统
java·vue.js·spring boot·spring cloud·开源·旅游