Spring Boot WebSocket 实现深度解析

本文将深度解析一个基于 Spring Boot 的 WebSocket 功能实现,该实现具备用户认证、会话管理、事件驱动等特性,结构清晰且易于扩展。

核心组件概览

该 WebSocket 功能主要由以下几个核心组件构成:

    1. **WebSocketConfig**: WebSocket 的主配置类,负责注册处理器和拦截器。
    1. **WebSocketAuthInterceptor**: 握手阶段的认证拦截器,用于验证用户身份。
    1. **WebSocketEventHandler**: 核心事件处理器,处理连接建立、消息接收和连接关闭等生命周期事件。
    1. **WebSocketSessionManager**: 会话管理中心,用于跟踪和管理所有活跃的 WebSocket 连接。
    1. **WebSocketEvent**: 自定义事件模型,用于在 WebSocket 的不同生命周期阶段发布事件,实现业务逻辑的解耦。
    1. 业务层监听器 ( **UserServiceImpl** ) : 监听并处理 WebSocketEvent,执行具体的业务逻辑。

1. 配置入口 (WebSocketConfig)

这是 WebSocket 功能的起点。通过 @EnableWebSocket 注解开启支持,并实现 WebSocketConfigurer 接口来配置处理器和拦截器。

  • **registerWebSocketHandlers**:

    • • 注册了 WebSocketEventHandler 作为核心处理器,并映射到路径 /websocket
    • 【核心】 添加了 WebSocketAuthInterceptor 拦截器,确保所有到 /websocket 的连接请求都先经过认证。
    • • 设置了 setAllowedOriginPatterns("*") 来允许跨域连接。
kotlin 复制代码
package com.sf.springtemplate.common.config;

import com.sf.springtemplate.common.interceptor.WebSocketAuthInterceptor;
import com.sf.springtemplate.common.handler.WebSocketEventHandler;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

@Configuration
@EnableWebSocket // 开启Spring对WebSocket的支持
public class WebSocketConfig implements WebSocketConfigurer {

    private final WebSocketSessionManager sessionManager;
    private final ApplicationEventPublisher eventPublisher;
    private final WebSocketAuthInterceptor authInterceptor; // 注入我们自定义的认证拦截器

    public WebSocketConfig(WebSocketSessionManager sessionManager, ApplicationEventPublisher eventPublisher, WebSocketAuthInterceptor authInterceptor) {
        this.sessionManager = sessionManager;
        this.eventPublisher = eventPublisher;
        this.authInterceptor = authInterceptor;
    }

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        registry
        // 1. 注册我们的核心事件处理器,并指定处理的路径为"/websocket"
        .addHandler(new WebSocketEventHandler(sessionManager, eventPublisher), "/websocket")
        // 2. 【核心配置】为这个路径添加认证拦截器,所有/websocket的连接请求都会先经过它处理
        .addInterceptors(authInterceptor)
        // 3. 设置允许的跨域来源,"*"表示允许所有来源,在生产环境中应配置为具体的前端域名
        .setAllowedOriginPatterns("*");
    }
}

2. 连接认证 (WebSocketAuthInterceptor)

在 WebSocket 握手阶段进行拦截,实现用户身份认证,只有认证通过的连接才会被建立。

  • **beforeHandshake**:

      1. 从请求 URL 的参数中获取 token
      1. 使用 JwtUtilstoken 进行解析和验证。
      1. token 中获取 userId,并查询数据库以确认用户存在且状态正常。
      1. 【核心】 认证成功后,将 userId 存入 attributes 中。这个 attributes 会被传递给后续的 WebSocketEventHandler,使其能在连接建立时获取到已认证的用户信息。
      1. 如果认证失败,返回 false,中断连接。
java 复制代码
package com.sf.springtemplate.common.interceptor;

import com.sf.springtemplate.common.util.JwtUtils;
import com.sf.springtemplate.entity.User;
import com.sf.springtemplate.mapper.UserMapper;
import io.jsonwebtoken.Claims;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.util.UriComponentsBuilder;

import java.util.Map;
import java.util.Objects;

@Component
@Slf4j
public class WebSocketAuthInterceptor implements HandshakeInterceptor {

    private final JwtUtils jwtUtils;
    private final UserMapper userMapper;

    public WebSocketAuthInterceptor(JwtUtils jwtUtils, UserMapper userMapper) {
        this.jwtUtils = jwtUtils;
        this.userMapper = userMapper;
    }

    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        String token = UriComponentsBuilder.fromUri(request.getURI()).build().getQueryParams().getFirst("token");

        if (token == null || token.trim().isEmpty()) {
            log.warn("WebSocket握手失败: URL中缺少token参数。");
            return false;
        }

        try {
            Claims claims = jwtUtils.parseToken(token);
            if (Objects.isNull(claims)) {
                log.warn("WebSocket握手失败: Token无效。");
                return false;
            }

            Integer userId = claims.get("userId", Integer.class);
            if (userId == null) {
                log.warn("WebSocket握手失败: Token中缺少userId。");
                return false;
            }

            User user = userMapper.selectById(userId);
            if (user == null || !user.getStatus()) {
                log.warn("WebSocket握手失败: 用户不存在或已被禁用, userId: {}", userId);
                return false;
            }

            attributes.put("userId", String.valueOf(user.getId()));
            log.info("WebSocket认证成功,用户ID: {}", userId);
            return true;

        } catch (Exception e) {
            log.error("WebSocket握手认证时发生异常: {}", e.getMessage());
            return false;
        }
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
        // 握手后不做任何处理
    }
}

3. 会话管理 (WebSocketSessionManager)

这是一个单例组件 (@Component),负责在内存中统一管理所有活跃的 WebSocket 连接。

  • • 使用了三个 ConcurrentHashMap 来分别存储 sessionId -> sessionuserId -> sessionIdsessionId -> userId 的映射关系,确保线程安全和高效查找。
  • **addSession**: 添加新连接。
  • **removeSession**: 移除连接。
  • **sendMessageToUser**: 向指定用户发送消息。
  • **broadcastMessage**: 向所有在线用户广播消息。
java 复制代码
package com.sf.springtemplate.service;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

@Component
public class WebSocketSessionManager {

    private static final Logger log = LoggerFactory.getLogger(WebSocketSessionManager.class);

    private static final ConcurrentHashMap<String, WebSocketSession> SESSIONS = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, String> USER_SESSIONS = new ConcurrentHashMap<>();
    private static final ConcurrentHashMap<String, String> SESSION_USERS = new ConcurrentHashMap<>();

    public void addSession(String userId, WebSocketSession session) {
        SESSIONS.put(session.getId(), session);
        USER_SESSIONS.put(userId, session.getId());
        SESSION_USERS.put(session.getId(), userId);
    }

    public String removeSession(WebSocketSession session) {
        String userId = SESSION_USERS.remove(session.getId());
        if (userId != null) {
            USER_SESSIONS.remove(userId);
        }
        SESSIONS.remove(session.getId());
        return userId;
    }
    
    public void sendMessageToUser(String userId, String message) {
        String sessionId = USER_SESSIONS.get(userId);
        if (sessionId != null) {
            WebSocketSession session = SESSIONS.get(sessionId);
            if (session != null && session.isOpen()) {
                try {
                    session.sendMessage(new TextMessage(message));
                } catch (IOException e) {
                    log.error("向用户 {} 发送消息失败: {}", userId, e.getMessage());
                }
            }
        }
    }

    public void broadcastMessage(String message) {
        log.info("开始广播消息: {}", message);
        int successCount = 0;
        for (WebSocketSession session : SESSIONS.values()) {
            if (session.isOpen()) {
                try {
                    session.sendMessage(new TextMessage(message));
                    successCount++;
                } catch (IOException e) {
                    log.error("向会话 {} 广播消息失败: {}", session.getId(), e.getMessage());
                }
            }
        }
        log.info("消息广播完成,成功发送给 {} 个客户端", successCount);
    }
}

4. 事件处理与发布 (WebSocketEventHandler & WebSocketEvent)

WebSocketEventHandler 继承自 TextWebSocketHandler,负责处理 WebSocket 的核心生命周期,并通过 ApplicationEventPublisher 将这些活动发布为 Spring 事件。

  • **afterConnectionEstablished**: 连接成功后,从 session.getAttributes() 中获取 userId(由 WebSocketAuthInterceptor 存入),将会话添加到 sessionManager 中,并发布 USER_ONLINE 事件。
  • **handleTextMessage**: 收到消息后,发布 MESSAGE_RECEIVED 事件。
  • **afterConnectionClosed**: 连接关闭后,从 sessionManager 中移除会话,并发布 USER_OFFLINE 事件。

WebSocketEvent 是一个自定义的 ApplicationEvent,用于封装事件信息(如事件类型、用户ID、会话等),实现了业务逻辑与 WebSocket 底层处理的解耦。

java 复制代码
package com.sf.springtemplate.common.handler;

import com.sf.springtemplate.common.model.WebSocketEvent;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

public class WebSocketEventHandler extends TextWebSocketHandler {

    private static final Logger log = LoggerFactory.getLogger(WebSocketEventHandler.class);
    private final WebSocketSessionManager sessionManager;
    private final ApplicationEventPublisher eventPublisher;

    public WebSocketEventHandler(WebSocketSessionManager sessionManager, ApplicationEventPublisher eventPublisher) {
        this.sessionManager = sessionManager;
        this.eventPublisher = eventPublisher;
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        String userId = (String) session.getAttributes().get("userId");
        if (userId != null) {
            sessionManager.addSession(userId, session);
            log.info("用户 {} 连接成功, 会话ID: {}, 当前总连接数: {}", userId, session.getId(), sessionManager.getActiveConnectionCount());
            eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.USER_ONLINE, userId, session));
        } else {
            // ...
        }
    }

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) {
        String userId = sessionManager.getUserId(session.getId());
        if (userId != null) {
            log.info("收到来自用户 {} 的消息: {}", userId, message.getPayload());
            eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.MESSAGE_RECEIVED, userId, session, message.getPayload()));
        }
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        String userId = sessionManager.removeSession(session);
        if (userId != null) {
            log.info("用户 {} 连接关闭, 原因: {}, 当前总连接数: {}", userId, status.getReason(), sessionManager.getActiveConnectionCount());
            eventPublisher.publishEvent(new WebSocketEvent(this, WebSocketEvent.Type.USER_OFFLINE, userId, session));
        }
    }
}
scala 复制代码
package com.sf.springtemplate.common.model;

import lombok.Getter;
import org.springframework.context.ApplicationEvent;
import org.springframework.web.socket.WebSocketSession;

@Getter
public class WebSocketEvent extends ApplicationEvent {

    private final Type type;
    private final String userId;
    private final WebSocketSession session;
    private final String message;

    public enum Type {
        USER_ONLINE,
        USER_OFFLINE,
        MESSAGE_RECEIVED
    }

    public WebSocketEvent(Object source, Type type, String userId, WebSocketSession session) {
        this(source, type, userId, session, null);
    }

    public WebSocketEvent(Object source, Type type, String userId, WebSocketSession session, String message) {
        super(source);
        this.type = type;
        this.userId = userId;
        this.session = session;
        this.message = message;
    }
}

5. 业务逻辑处理 (UserServiceImpl)

在业务层(例如 UserServiceImpl)中,可以非常方便地通过 @EventListener 注解来监听并处理前面发布的 WebSocketEvent

  • **handleWebSocketEvents**:

    • • 监听 WebSocketEvent
    • • 判断事件类型是 USER_ONLINE还是 USER_OFFLINE
    • • 执行相应的业务逻辑,例如,当用户上线时,调用 webSocketSessionManager.sendMessageToUser 发送一条欢迎消息。
scala 复制代码
// 在 UserServiceImpl.java 中
import com.sf.springtemplate.common.model.WebSocketEvent;
import com.sf.springtemplate.service.WebSocketSessionManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.event.EventListener;
// ...

@Service
@Slf4j
public class UserServiceImpl extends ServiceImpl<UserMapper, User> implements UserService {
    
    @Autowired
    private WebSocketSessionManager webSocketSessionManager;

    //... 其他代码

    /**
     * 监听WebSocket事件的示例方法。
     */
    @EventListener
    public void handleWebSocketEvents(WebSocketEvent event) {
        if (event.getType() == WebSocketEvent.Type.USER_ONLINE) {
            String userId = event.getUserId();
            log.info("【业务层】监听到用户 {} 上线了!", userId);

            String welcomeMessage = "欢迎回来!您已成功连接到实时通知服务。";
            webSocketSessionManager.sendMessageToUser(userId, welcomeMessage);

        } else if (event.getType() == WebSocketEvent.Type.USER_OFFLINE) {
            log.info("【业务层】监听到用户 {} 离线了。", event.getUserId());
            // 可在此处添加用户离线后的业务处理
        }
    }
}

总结

这个 WebSocket 实现方案通过责任链模式(拦截器处理认证)和观察者模式(事件发布/监听机制)实现了高度的模块化和解耦。开发者可以轻松地在业务层监听 WebSocket 事件,而无需关心底层的连接管理和生命周期,从而实现干净、可维护的代码结构。

相关推荐
一只爱撸猫的程序猿6 小时前
做一个「运维知识库 + 多模态检索问答」的案例
spring boot·aigc·ai编程
sheji34168 小时前
【开题答辩全过程】以 信达纸巾公司生产管理系统为例,包含答辩的问题和答案
spring boot
泉城老铁11 小时前
Spring Boot项目开发中,JPA 和mybatisplus哪个更哇塞呢
java·spring boot·后端
编啊编程啊程11 小时前
响应式编程框架Reactor【5】
java·jvm·spring boot·spring cloud·java-ee·maven
一 乐15 小时前
医院排班|医护人员排班系统|基于springboot医护人员排班系统设计与实现(源码+数据库+文档)
java·数据库·spring boot·后端·论文·毕设·医护人员排班系统
曾令胜16 小时前
Spring和mybatis整合后事务拦截器TransactionInterceptor开启提交事务流程
数据库·spring boot·mybatis
小猪咪piggy18 小时前
【JavaEE】(20) Spring Boot 统一功能处理
java·spring boot·后端
makerjack00119 小时前
Java中使用Spring Boot+Ollama实现本地AI的MCP接入
java·人工智能·spring boot
就叫飞六吧1 天前
基于Spring Boot的短信平台平滑切换设计方案
java·spring boot·后端