本文将深度解析一个基于 Spring Boot 的 WebSocket 功能实现,该实现具备用户认证、会话管理、事件驱动等特性,结构清晰且易于扩展。
核心组件概览
该 WebSocket 功能主要由以下几个核心组件构成:
-
**WebSocketConfig**
: WebSocket 的主配置类,负责注册处理器和拦截器。
-
**WebSocketAuthInterceptor**
: 握手阶段的认证拦截器,用于验证用户身份。
-
**WebSocketEventHandler**
: 核心事件处理器,处理连接建立、消息接收和连接关闭等生命周期事件。
-
**WebSocketSessionManager**
: 会话管理中心,用于跟踪和管理所有活跃的 WebSocket 连接。
-
**WebSocketEvent**
: 自定义事件模型,用于在 WebSocket 的不同生命周期阶段发布事件,实现业务逻辑的解耦。
-
- 业务层监听器 (
**UserServiceImpl**
) : 监听并处理WebSocketEvent
,执行具体的业务逻辑。
- 业务层监听器 (
1. 配置入口 (WebSocketConfig
)
这是 WebSocket 功能的起点。通过 @EnableWebSocket
注解开启支持,并实现 WebSocketConfigurer
接口来配置处理器和拦截器。
-
•
**registerWebSocketHandlers**
:- • 注册了
WebSocketEventHandler
作为核心处理器,并映射到路径/websocket
。 - • 【核心】 添加了
WebSocketAuthInterceptor
拦截器,确保所有到/websocket
的连接请求都先经过认证。 - • 设置了
setAllowedOriginPatterns("*")
来允许跨域连接。
- • 注册了
kotlin
package com.sf.springtemplate.common.config;
import com.sf.springtemplate.common.interceptor.WebSocketAuthInterceptor;
import com.sf.springtemplate.common.handler.WebSocketEventHandler;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket // 开启Spring对WebSocket的支持
public class WebSocketConfig implements WebSocketConfigurer {
private final WebSocketSessionManager sessionManager;
private final ApplicationEventPublisher eventPublisher;
private final WebSocketAuthInterceptor authInterceptor; // 注入我们自定义的认证拦截器
public WebSocketConfig(WebSocketSessionManager sessionManager, ApplicationEventPublisher eventPublisher, WebSocketAuthInterceptor authInterceptor) {
this.sessionManager = sessionManager;
this.eventPublisher = eventPublisher;
this.authInterceptor = authInterceptor;
}
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry
// 1. 注册我们的核心事件处理器,并指定处理的路径为"/websocket"
.addHandler(new WebSocketEventHandler(sessionManager, eventPublisher), "/websocket")
// 2. 【核心配置】为这个路径添加认证拦截器,所有/websocket的连接请求都会先经过它处理
.addInterceptors(authInterceptor)
// 3. 设置允许的跨域来源,"*"表示允许所有来源,在生产环境中应配置为具体的前端域名
.setAllowedOriginPatterns("*");
}
}
2. 连接认证 (WebSocketAuthInterceptor
)
在 WebSocket 握手阶段进行拦截,实现用户身份认证,只有认证通过的连接才会被建立。
-
•
**beforeHandshake**
:-
- 从请求 URL 的参数中获取
token
。
- 从请求 URL 的参数中获取
-
- 使用
JwtUtils
对token
进行解析和验证。
- 使用
-
- 从
token
中获取userId
,并查询数据库以确认用户存在且状态正常。
- 从
-
- 【核心】 认证成功后,将
userId
存入attributes
中。这个attributes
会被传递给后续的WebSocketEventHandler
,使其能在连接建立时获取到已认证的用户信息。
- 【核心】 认证成功后,将
-
- 如果认证失败,返回
false
,中断连接。
- 如果认证失败,返回
-
java
package com.sf.springtemplate.common.interceptor;
import com.sf.springtemplate.common.util.JwtUtils;
import com.sf.springtemplate.entity.User;
import com.sf.springtemplate.mapper.UserMapper;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;
import java.util.Map;
import java.util.Objects;
@Component
@Slf4j
public class WebSocketAuthInterceptor implements HandshakeInterceptor {
private final JwtUtils jwtUtils;
private final UserMapper userMapper;
public WebSocketAuthInterceptor(JwtUtils jwtUtils, UserMapper userMapper) {
this.jwtUtils = jwtUtils;
this.userMapper = userMapper;
}
@Override
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
String token = UriComponentsBuilder.fromUri(request.getURI()).build().getQueryParams().getFirst("token");
if (token == null || token.trim().isEmpty()) {
log.warn("WebSocket握手失败: URL中缺少token参数。");
return false;
}
try {
Claims claims = jwtUtils.parseToken(token);
if (Objects.isNull(claims)) {
log.warn("WebSocket握手失败: Token无效。");
return false;
}
Integer userId = claims.get("userId", Integer.class);
if (userId == null) {
log.warn("WebSocket握手失败: Token中缺少userId。");
return false;
}
User user = userMapper.selectById(userId);
if (user == null || !user.getStatus()) {
log.warn("WebSocket握手失败: 用户不存在或已被禁用, userId: {}", userId);
return false;
}
attributes.put("userId", String.valueOf(user.getId()));
log.info("WebSocket认证成功,用户ID: {}", userId);
return true;
} catch (Exception e) {
log.error("WebSocket握手认证时发生异常: {}", e.getMessage());
return false;
}
}
@Override
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
// 握手后不做任何处理
}
}
3. 会话管理 (WebSocketSessionManager
)
这是一个单例组件 (@Component
),负责在内存中统一管理所有活跃的 WebSocket 连接。
- • 使用了三个
ConcurrentHashMap
来分别存储sessionId -> session
、userId -> sessionId
和sessionId -> userId
的映射关系,确保线程安全和高效查找。 - •
**addSession**
: 添加新连接。 - •
**removeSession**
: 移除连接。 - •
**sendMessageToUser**
: 向指定用户发送消息。 - •
**broadcastMessage**
: 向所有在线用户广播消息。
java
package com.sf.springtemplate.service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
@Component
public class WebSocketSessionManager {
private static final Logger log = LoggerFactory.getLogger(WebSocketSessionManager.class);
private static final ConcurrentHashMap<String, WebSocketSession> SESSIONS = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, String> USER_SESSIONS = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, String> SESSION_USERS = new ConcurrentHashMap<>();
public void addSession(String userId, WebSocketSession session) {
SESSIONS.put(session.getId(), session);
USER_SESSIONS.put(userId, session.getId());
SESSION_USERS.put(session.getId(), userId);
}
public String removeSession(WebSocketSession session) {
String userId = SESSION_USERS.remove(session.getId());
if (userId != null) {
USER_SESSIONS.remove(userId);
}
SESSIONS.remove(session.getId());
return userId;
}
public void sendMessageToUser(String userId, String message) {
String sessionId = USER_SESSIONS.get(userId);
if (sessionId != null) {
WebSocketSession session = SESSIONS.get(sessionId);
if (session != null && session.isOpen()) {
try {
session.sendMessage(new TextMessage(message));
} catch (IOException e) {
log.error("向用户 {} 发送消息失败: {}", userId, e.getMessage());
}
}
}
}
public void broadcastMessage(String message) {
log.info("开始广播消息: {}", message);
int successCount = 0;
for (WebSocketSession session : SESSIONS.values()) {
if (session.isOpen()) {
try {
session.sendMessage(new TextMessage(message));
successCount++;
} catch (IOException e) {
log.error("向会话 {} 广播消息失败: {}", session.getId(), e.getMessage());
}
}
}
log.info("消息广播完成,成功发送给 {} 个客户端", successCount);
}
}
4. 事件处理与发布 (WebSocketEventHandler
& WebSocketEvent
)
WebSocketEventHandler
继承自 TextWebSocketHandler
,负责处理 WebSocket 的核心生命周期,并通过 ApplicationEventPublisher
将这些活动发布为 Spring 事件。
- •
**afterConnectionEstablished**
: 连接成功后,从session.getAttributes()
中获取userId
(由WebSocketAuthInterceptor
存入),将会话添加到sessionManager
中,并发布USER_ONLINE
事件。 - •
**handleTextMessage**
: 收到消息后,发布MESSAGE_RECEIVED
事件。 - •
**afterConnectionClosed**
: 连接关闭后,从sessionManager
中移除会话,并发布USER_OFFLINE
事件。
WebSocketEvent
是一个自定义的 ApplicationEvent
,用于封装事件信息(如事件类型、用户ID、会话等),实现了业务逻辑与 WebSocket 底层处理的解耦。
java
package com.sf.springtemplate.common.handler;
import com.sf.springtemplate.common.model.WebSocketEvent;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
public class WebSocketEventHandler extends TextWebSocketHandler {
private static final Logger log = LoggerFactory.getLogger(WebSocketEventHandler.class);
private final WebSocketSessionManager sessionManager;
private final ApplicationEventPublisher eventPublisher;
public WebSocketEventHandler(WebSocketSessionManager sessionManager, ApplicationEventPublisher eventPublisher) {
this.sessionManager = sessionManager;
this.eventPublisher = eventPublisher;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) {
String userId = (String) session.getAttributes().get("userId");
if (userId != null) {
sessionManager.addSession(userId, session);
log.info("用户 {} 连接成功, 会话ID: {}, 当前总连接数: {}", userId, session.getId(), sessionManager.getActiveConnectionCount());
eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.USER_ONLINE, userId, session));
} else {
// ...
}
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
String userId = sessionManager.getUserId(session.getId());
if (userId != null) {
log.info("收到来自用户 {} 的消息: {}", userId, message.getPayload());
eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.MESSAGE_RECEIVED, userId, session, message.getPayload()));
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
String userId = sessionManager.removeSession(session);
if (userId != null) {
log.info("用户 {} 连接关闭, 原因: {}, 当前总连接数: {}", userId, status.getReason(), sessionManager.getActiveConnectionCount());
eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.USER_OFFLINE, userId, session));
}
}
}
scala
package com.sf.springtemplate.common.model;
import lombok.Getter;
import org.springframework.context.ApplicationEvent;
import org.springframework.web.socket.WebSocketSession;
@Getter
public class WebSocketEvent extends ApplicationEvent {
private final Type type;
private final String userId;
private final WebSocketSession session;
private final String message;
public enum Type {
USER_ONLINE,
USER_OFFLINE,
MESSAGE_RECEIVED
}
public WebSocketEvent(Object source, Type type, String userId, WebSocketSession session) {
this(source, type, userId, session, null);
}
public WebSocketEvent(Object source, Type type, String userId, WebSocketSession session, String message) {
super(source);
this.type = type;
this.userId = userId;
this.session = session;
this.message = message;
}
}
5. 业务逻辑处理 (UserServiceImpl
)
在业务层(例如 UserServiceImpl
)中,可以非常方便地通过 @EventListener
注解来监听并处理前面发布的 WebSocketEvent
。
-
•
**handleWebSocketEvents**
:- • 监听
WebSocketEvent
。 - • 判断事件类型是
USER_ONLINE
还是USER_OFFLINE
。 - • 执行相应的业务逻辑,例如,当用户上线时,调用
webSocketSessionManager.sendMessageToUser
发送一条欢迎消息。
- • 监听
scala
// 在 UserServiceImpl.java 中
import com.sf.springtemplate.common.model.WebSocketEvent;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.EventListener;
// ...
@Service
@Slf4j
public class UserServiceImpl extends ServiceImpl<UserMapper, User> implements UserService {
@Autowired
private WebSocketSessionManager webSocketSessionManager;
//... 其他代码
/**
* 监听WebSocket事件的示例方法。
*/
@EventListener
public void handleWebSocketEvents(WebSocketEvent event) {
if (event.getType() == WebSocketEvent.Type.USER_ONLINE) {
String userId = event.getUserId();
log.info("【业务层】监听到用户 {} 上线了!", userId);
String welcomeMessage = "欢迎回来!您已成功连接到实时通知服务。";
webSocketSessionManager.sendMessageToUser(userId, welcomeMessage);
} else if (event.getType() == WebSocketEvent.Type.USER_OFFLINE) {
log.info("【业务层】监听到用户 {} 离线了。", event.getUserId());
// 可在此处添加用户离线后的业务处理
}
}
}
总结
这个 WebSocket 实现方案通过责任链模式(拦截器处理认证)和观察者模式(事件发布/监听机制)实现了高度的模块化和解耦。开发者可以轻松地在业务层监听 WebSocket 事件,而无需关心底层的连接管理和生命周期,从而实现干净、可维护的代码结构。