SpringBoot 集成 Redis 实现分布式 WebSocket:跨实例消息推送实战

在分布式系统中,单机 WebSocket 无法实现跨服务实例的消息推送(如用户连接在实例 A,而消息触发在实例 B)。本文基于前文 SpringBoot+Redis 缓存实践,结合 Redis 发布订阅(Pub/Sub)机制,实现分布式 WebSocket,解决跨实例消息推送问题,核心是通过 Redis 作为消息中转站,让所有服务实例共享消息通道。

一、核心设计思路

分布式 WebSocket 的核心是 "Redis 发布订阅 + 本地 Session 管理":

  1. 客户端连接:用户 WebSocket 连接到任意服务实例,实例将连接 Session 存入本地缓存;
  2. 消息发布:任意实例产生推送消息时,通过 Redis 发布(Publish)到指定频道;
  3. 消息订阅:所有服务实例订阅 Redis 频道,接收到消息后,检查本地是否有目标用户的 Session;
  4. 消息推送:有目标用户 Session 的实例,将消息推送给对应的客户端。

整个流程复用了 Redis 缓存实践中的RedisTemplate配置,无需额外引入依赖,保证技术栈统一。

二、环境准备

复用前文SpringBoot 集成 Redis 缓存实践

  • 依赖spring-boot-starter-websocket

三、核心代码实现

1. 枚举定义:WebSocket 消息类型

统一管理 WebSocket 消息类型,便于后续扩展不同业务的消息推送:

java 复制代码
package com.demo.redis.enums;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
 * WebSocket消息类型枚举
 */
@Getter
@AllArgsConstructor
public enum DWebSocketType {
    DEMO("示例消息"),
    NOTICE("系统通知"),
    ORDER("订单消息"),
    ;
    private final String desc;
}

2. DTO 定义:WebSocket 消息体

封装推送消息的通用结构,包含用户 ID、消息类型、消息内容:

java 复制代码
package com.demo.redis.dto;

import com.demo.redis.enums.DWebSocketType;
import lombok.Data;
import lombok.experimental.Accessors;

import java.io.Serializable;

/**
 * WebSocket推送请求DTO
 * 实现Serializable保证Redis传输时可序列化
 */
@Accessors(chain = true)
@Data
public class DWebSocketReqDTO<T extends Serializable> implements Serializable {
    private static final long serialVersionUID = 1292707161671865097L;

    /**
     * 消息类型 @see DWebSocketType
     */
    private String type;
    /**
     * 目标用户ID(推送的核心标识)
     */
    private Long userId;
    /**
     * 消息体(支持任意序列化类型)
     */
    private T body;
}

3. 消息生产者:发布 Redis 消息

封装 Redis 发布逻辑,统一消息推送入口,所有需要推送的场景都通过该类发布消息:

java 复制代码
package com.demo.redis.producer;  
  
import com.demo.redis.dto.DWebSocketReqDTO;  
import lombok.AllArgsConstructor;  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.data.redis.core.RedisTemplate;  
import org.springframework.stereotype.Component;  
  
import javax.annotation.Resource;  
import java.io.Serializable;  
  
  
/**  
* WebSocket消息生产者  
* 负责将消息发布到Redis指定频道  
*/  
@Slf4j  
public class DWebSocketProducer {  
  
private RedisTemplate<String, Object> redisTemplate;  
private String channel;  
  
public DWebSocketProducer(RedisTemplate<String, Object> redisTemplate, String channel) {  
this.redisTemplate = redisTemplate;  
this.channel = channel;  
}  
  
/**  
* 发布WebSocket消息  
*  
* @param reqDTO 消息体  
* @param <T> 消息内容类型  
*/  
public <T extends Serializable> void send(DWebSocketReqDTO<T> reqDTO) {  
log.info("#websocket::send# 准备推送消息,userId={},type={}", reqDTO.getUserId(), reqDTO.getType());  
// 参数校验:用户ID和消息类型不能为空  
if (reqDTO.getUserId() == null || reqDTO.getType() == null) {  
log.warn("#websocket::send::return# 参数缺失,userId={},type={}", reqDTO.getUserId(), reqDTO.getType());  
return;  
}  
// 核心:Redis发布消息到指定频道  
redisTemplate.convertAndSend(channel, reqDTO);  
log.info("#websocket::send::success# 消息发布成功,userId={}", reqDTO.getUserId());  
}  
  
}

4. Session 管理器:本地连接管理

管理当前服务实例的 WebSocket 连接 Session,提供增删查、自动清理等能力,是分布式推送的本地核心:

java 复制代码
package com.demo.redis.manage;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.demo.redis.dto.DWebSocketReqDTO;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.WsSession;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;

import javax.websocket.Session;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocket Session管理器
 * 管理当前实例的所有WebSocket连接,线程安全
 */
@Slf4j
public class DWebSocketSessionManage {
    // 核心参数:Session池容量控制
    public static final int CORE_SIZE = 30000; // 核心容量
    public static final int MAX_SIZE = 50000; // 最大容量

    // 存储Session的最后访问时间(用于清理过期连接)
    public static Map<String, Long> KEY_TIME_POOL = new HashMap<>(200);
    // 核心:存储用户ID与WebSocketSession的映射(线程安全)
    public static Map<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>(200);

    /**
     * 添加Session到本地缓存
     * @param key 唯一标识(格式:sk::userId)
     * @param session WebSocket连接
     */
    public static void add(String key, WebSocketSession session) {
        if (StrUtil.isBlank(key)) {
            return;
        }
        // 先清理过期/超量连接,防止内存溢出
        clear();
        // 更新最后访问时间
        KEY_TIME_POOL.put(key, System.currentTimeMillis());
        // 存储Session
        SESSION_POOL.put(key, session);
        log.info("#websocket::add# Session添加成功,key={},当前连接数={}", key, SESSION_POOL.size());
    }

    /**
     * 清理超量连接(保留核心容量)
     */
    private static void clear() {
        if (KEY_TIME_POOL.size() > MAX_SIZE) {
            log.warn("#websocket::clear# 连接数超过最大值{},开始清理", MAX_SIZE);
            // 将连接按最后访问时间升序排序,优先清理最久未使用的
            List<Map.Entry<String, Long>> list = new ArrayList<>(KEY_TIME_POOL.entrySet());
            list.sort(Map.Entry.comparingByValue());

            int clearCount = MAX_SIZE - CORE_SIZE; // 需要清理的数量
            int count = 0;
            for (Map.Entry<String, Long> entry : list) {
                if (count < clearCount) {
                    remove(entry.getKey()); // 移除连接并关闭Session
                    count++;
                } else {
                    break;
                }
            }
            log.info("#websocket::clear::success# 清理完成,共清理{}个连接", count);
        }
    }

    /**
     * 获取用户的Session
     * @param key 唯一标识
     * @return WebSocketSession
     */
    public static WebSocketSession get(String key) {
        if (StrUtil.isBlank(key)) {
            return null;
        }
        // 更新最后访问时间(防止被清理)
        KEY_TIME_POOL.put(key, System.currentTimeMillis());
        return SESSION_POOL.get(key);
    }

    /**
     * 移除并关闭Session
     * @param key 唯一标识
     */
    public static void remove(String key) {
        if (StrUtil.isBlank(key)) {
            return;
        }
        // 清理时间记录
        KEY_TIME_POOL.remove(key);
        // 移除并关闭Session
        WebSocketSession session = SESSION_POOL.remove(key);
        if (session != null) {
            try {
                session.close();
                log.info("#websocket::remove# Session关闭成功,key={}", key);
            } catch (IOException e) {
                log.error("#websocket::remove::error# Session关闭失败,key={}", key, e);
            }
        }
    }

    /**
     * 根据消息体生成SessionKey
     */
    public static String getSessionKey(DWebSocketReqDTO reqDTO) {
        if (reqDTO.getUserId() == null) {
            return null;
        }
        return getSessionKey(reqDTO.getUserId());
    }

    /**
     * 从WebSocketSession中解析用户ID,生成SessionKey
     */
    public static String getSessionKey(WebSocketSession session) {
        try {
            // 解析请求参数中的userId
            StandardWebSocketSession standardSession = (StandardWebSocketSession) session;
            WsSession wsSession = (WsSession) standardSession.getNativeSession();
            List<String> userIdList = wsSession.getRequestParameterMap().get("userId");
            if (CollectionUtil.isEmpty(userIdList)) {
                log.warn("#websocket::getSessionKey# Session中未获取到userId");
                return null;
            }
            return getSessionKey(userIdList.get(0));
        } catch (Exception e) {
            log.error("#websocket::getSessionKey::error# 解析SessionKey失败", e);
            return null;
        }
    }

    /**
     * 生成统一的SessionKey格式
     */
    private static String getSessionKey(Object userId) {
        return "sk::" + userId; // 格式:sk::1001
    }
}

5. Redis 消息监听器:订阅并推送消息

订阅 Redis 频道的消息,接收到消息后,检查本地是否有目标用户的 Session,有则推送:

java 复制代码
package com.demo.redis.listener;

import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.demo.redis.dto.DWebSocketReqDTO;
import com.demo.redis.manage.DWebSocketSessionManage;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;

/**
 * Redis消息监听器
 * 订阅Redis频道,接收消息并推送给本地连接的客户端
 */
@AllArgsConstructor
@Slf4j
public class DWebSocketRedisMessageListener implements MessageListener {

    private RedisTemplate<String, Object> redisTemplate;

    /**
     * 核心方法:接收到Redis消息后的处理逻辑
     */
    @Override
    public void onMessage(Message message, byte[] pattern) {
        // 1. 反序列化Redis消息
        Object value = redisTemplate.getValueSerializer().deserialize(message.getBody());
        if (ObjectUtil.isNull(value)) {
            log.warn("#websocket::onMessage::return# 消息体为空");
            return;
        }
        DWebSocketReqDTO reqDTO = (DWebSocketReqDTO) value;
        log.info("#websocket::onMessage# 接收到推送消息,userId={},type={}", reqDTO.getUserId(), reqDTO.getType());

        // 2. 获取目标用户的SessionKey
        String sessionKey = DWebSocketSessionManage.getSessionKey(reqDTO);
        if (StrUtil.isBlank(sessionKey)) {
            log.warn("#websocket::onMessage::return# 未获取到SessionKey,userId={}", reqDTO.getUserId());
            return;
        }

        // 3. 获取本地Session(当前实例是否连接该用户)
        WebSocketSession userSession = DWebSocketSessionManage.get(sessionKey);
        if (userSession == null) {
            log.info("#websocket::onMessage::return# 当前实例无该用户连接,userId={}", reqDTO.getUserId());
            return;
        }

        // 4. 推送消息给客户端
        try {
            if (userSession.isOpen()) {
                String jsonStr = JSONUtil.toJsonStr(reqDTO);
                userSession.sendMessage(new TextMessage(jsonStr));
                log.info("#websocket::onMessage::success# 消息推送成功,userId={}", reqDTO.getUserId());
            } else {
                // Session已关闭,清理缓存
                DWebSocketSessionManage.remove(sessionKey);
                log.warn("#websocket::onMessage::close# Session已关闭,清理缓存,userId={}", reqDTO.getUserId());
            }
        } catch (Exception e) {
            log.error("#websocket::onMessage::error# 消息推送失败,userId={}", reqDTO.getUserId(), e);
            // 推送失败,清理Session
            DWebSocketSessionManage.remove(sessionKey);
        }
    }
}

6. WebSocket 处理器:管理连接生命周期

处理 WebSocket 连接的建立、消息接收、异常、关闭等生命周期事件:

java 复制代码
package com.demo.redis.handler;

import com.demo.redis.manage.DWebSocketSessionManage;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

/**
 * WebSocket处理器
 * 处理连接的建立、消息、异常、关闭等事件
 */
@Slf4j
public class DWebSocketHandler extends TextWebSocketHandler {

    /**
     * 连接建立成功后触发
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // 生成SessionKey并添加到本地缓存
        String sessionKey = DWebSocketSessionManage.getSessionKey(session);
        log.info("#websocket::afterConnectionEstablished# 连接建立成功,sessionKey={},sessionId={}", sessionKey, session.getId());
        DWebSocketSessionManage.add(sessionKey, session);
    }

    /**
     * 接收客户端发送的消息
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String sessionKey = DWebSocketSessionManage.getSessionKey(session);
        log.info("#websocket::handleTextMessage# 接收客户端消息,sessionKey={},sessionId={},payload={}", 
                 sessionKey, session.getId(), message.getPayload());
        // 示例:回复心跳包
        session.sendMessage(new TextMessage("pong"));
    }

    /**
     * 传输异常时触发
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        String sessionKey = DWebSocketSessionManage.getSessionKey(session);
        log.error("#websocket::handleTransportError# 连接传输异常,sessionKey={},sessionId={}", 
                  sessionKey, session.getId(), exception);
        // 清理异常连接
        DWebSocketSessionManage.remove(sessionKey);
    }

    /**
     * 连接关闭后触发
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        String sessionKey = DWebSocketSessionManage.getSessionKey(session);
        log.info("#websocket::afterConnectionClosed# 连接关闭,sessionKey={},sessionId={},status={}", 
                 sessionKey, session.getId(), status);
        // 清理关闭的连接
        DWebSocketSessionManage.remove(sessionKey);
    }
}

7. 核心配置类:整合所有组件

配置 WebSocket 处理器、Redis 监听器、消息生产者等核心组件,完成分布式 WebSocket 的整体装配:

java 复制代码
package com.demo.redis.config;  
  
  
import com.demo.redis.handler.DWebSocketHandler;  
import com.demo.redis.listener.DWebSocketRedisMessageListener;  
import com.demo.redis.producer.DWebSocketProducer;  
import lombok.extern.slf4j.Slf4j;  
import org.springframework.context.annotation.Bean;  
import org.springframework.context.annotation.Configuration;  
import org.springframework.data.redis.connection.RedisConnectionFactory;  
import org.springframework.data.redis.core.RedisTemplate;  
import org.springframework.data.redis.listener.PatternTopic;  
import org.springframework.data.redis.listener.RedisMessageListenerContainer;  
import org.springframework.web.socket.config.annotation.EnableWebSocket;  
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;  
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;  
  
/**  
* 分布式WebSocket核心配置  
*/  
@Slf4j  
@EnableWebSocket // 开启WebSocket支持  
@Configuration  
public class DWebSocketConfig implements WebSocketConfigurer {  
  
// Redis发布订阅频道  
private static final String D_WEB_SOCKET_TOPIC = "d_web_socket_topic";  
// WebSocket连接路径  
private static final String WS_CONNECT_PATH = "/ws/connect";  
  
/**  
* 注册WebSocket处理器  
*/  
@Override  
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {  
// 配置WebSocket连接路径,允许跨域  
registry.addHandler(dWebSocketHandler(), WS_CONNECT_PATH)  
.setAllowedOrigins("*");  
log.info("#websocket::registerWebSocketHandlers# WebSocket注册成功,路径={}", WS_CONNECT_PATH);  
}  
  
/**  
* 初始化WebSocket处理器  
*/  
@Bean  
public DWebSocketHandler dWebSocketHandler() {  
return new DWebSocketHandler();  
}  
  
/**  
* 初始化Redis消息监听器  
*/  
@Bean  
public DWebSocketRedisMessageListener dWebSocketRedisMessageListener(RedisTemplate<String, Object> redisTemplate) {  
return new DWebSocketRedisMessageListener(redisTemplate);  
}  
  
/**  
* 配置Redis消息监听容器  
*/  
@Bean(name = "dWebSocketRedisMessageListenerContainer")  
public RedisMessageListenerContainer redisMessageListenerContainer(  
RedisConnectionFactory redisConnectionFactory,  
DWebSocketRedisMessageListener dWebSocketRedisMessageListener) {  
RedisMessageListenerContainer container = new RedisMessageListenerContainer();  
container.setConnectionFactory(redisConnectionFactory);  
// 订阅指定频道  
container.addMessageListener(dWebSocketRedisMessageListener, new PatternTopic(D_WEB_SOCKET_TOPIC));  
log.info("#websocket::redisMessageListenerContainer# Redis监听器注册成功,频道={}", D_WEB_SOCKET_TOPIC);  
return container;  
}  
  
/**  
* 初始化WebSocket消息生产者  
*/  
@Bean  
public DWebSocketProducer dWebSocketProducer(RedisTemplate<String, Object> redisTemplate) {  
return new DWebSocketProducer(redisTemplate, D_WEB_SOCKET_TOPIC);  
}  
  
}

四、实战场景:推送用户消息

1. 测试接口:触发消息推送

java 复制代码
package com.demo.redis;  
  
import com.demo.redis.dto.DWebSocketReqDTO;  
import com.demo.redis.dto.User;  
import com.demo.redis.enums.DWebSocketType;  
import com.demo.redis.producer.DWebSocketProducer;  
import lombok.extern.slf4j.Slf4j;  
import org.junit.jupiter.api.Test;  
import org.springframework.beans.factory.annotation.Autowired;  
import org.springframework.beans.factory.annotation.Value;  
import org.springframework.boot.test.context.SpringBootTest;  
  
@Slf4j  
@SpringBootTest(classes = DemoRedisApplication.class)  
public class DWebSocketProducerTest {  
  
@Autowired  
private DWebSocketProducer dWebSocketProducer;  
@Value("${server.port}")  
private int port;  
  
  
@Test  
public void testSend() {  
log.info("port:{}",port);  
User user = new User();  
user.setId(1L);  
user.setName("dd");  
DWebSocketReqDTO<User> req = new DWebSocketReqDTO<User>()  
.setType(DWebSocketType.DEMO.name())  
.setUserId(1L)  
.setBody(user);  
dWebSocketProducer.send(req);  
}  
  
  
}

2. 客户端测试代码

使用 HTML+JS 编写简单的 WebSocket 客户端,连接服务端并接收消息:

html 复制代码
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <title>分布式WebSocket测试</title>
</head>
<body>
<h3>WebSocket测试(userId=1)</h3>
<div id="message"></div>

<script>
    // 连接WebSocket(替换为实际服务端地址)
    const ws = new WebSocket('ws://localhost:10088/demo/ws/connect?userId=1');
    
    // 连接成功
    ws.onopen = function() {
        console.log('WebSocket连接成功');
        appendMessage('连接成功');
    };
    
    // 接收消息
    ws.onmessage = function(event) {
        console.log('接收消息:', event.data);
        appendMessage('接收消息:' + event.data);
    };
    
    // 连接关闭
    ws.onclose = function() {
        console.log('WebSocket连接关闭');
        appendMessage('连接关闭');
    };
    
    // 连接异常
    ws.onerror = function(error) {
        console.error('WebSocket异常:', error);
        appendMessage('连接异常:' + error);
    };
    
    // 向页面追加消息
    function appendMessage(msg) {
        const div = document.createElement('div');
        div.innerText = new Date().toLocaleString() + ' - ' + msg;
        document.getElementById('message').appendChild(div);
    }
</script>
</body>
</html>

3. 分布式测试验证

  1. 启动两个服务实例:分别启动端口 10088、10089 的服务;
  2. 客户端连接:用 HTML 客户端连接 10088 端口(userId=1);
  3. 触发推送:调用 10089 端口 单元测试;
  4. 验证结果:连接 10088 的客户端能接收到消息,证明跨实例推送成功。

五、总结

本文基于 SpringBoot+Redis 缓存实践,实现了分布式 WebSocket,核心要点:

  1. 核心原理:Redis Pub/Sub 作为消息中转站,跨实例同步推送指令,本地 Session 管理器负责实际推送;
  2. 核心组件:生产者发布消息、监听器订阅消息、Session 管理器管理本地连接、处理器管理连接生命周期;
  3. 实战验证:跨实例推送测试验证了方案的有效性,解决了单机 WebSocket 的局限性;
相关推荐
华如锦9 小时前
四:从零搭建一个RAG
java·开发语言·人工智能·python·机器学习·spring cloud·计算机视觉
Tony_yitao9 小时前
22.华为OD机试真题:数组拼接(Java实现,100分通关)
java·算法·华为od·algorithm
JavaGuru_LiuYu9 小时前
Spring Boot 整合 SSE(Server-Sent Events)
java·spring boot·后端·sse
爬山算法10 小时前
Hibernate(26)什么是Hibernate的透明持久化?
java·后端·hibernate
彭于晏Yan10 小时前
Springboot实现数据脱敏
java·spring boot·后端
luming-0210 小时前
java报错解决:sun.net.utils不存
java·经验分享·bug·.net·intellij-idea
北海有初拥10 小时前
Python基础语法万字详解
java·开发语言·python
alonewolf_9910 小时前
Spring IOC容器扩展点全景:深入探索与实践演练
java·后端·spring
super_lzb10 小时前
springboot打war包时将外部配置文件打入到war包内
java·spring boot·后端·maven
毛小茛10 小时前
芋道管理系统学习——项目结构
java·学习