在分布式系统中,单机 WebSocket 无法实现跨服务实例的消息推送(如用户连接在实例 A,而消息触发在实例 B)。本文基于前文 SpringBoot+Redis 缓存实践,结合 Redis 发布订阅(Pub/Sub)机制,实现分布式 WebSocket,解决跨实例消息推送问题,核心是通过 Redis 作为消息中转站,让所有服务实例共享消息通道。
一、核心设计思路
分布式 WebSocket 的核心是 "Redis 发布订阅 + 本地 Session 管理":
- 客户端连接:用户 WebSocket 连接到任意服务实例,实例将连接 Session 存入本地缓存;
- 消息发布:任意实例产生推送消息时,通过 Redis 发布(Publish)到指定频道;
- 消息订阅:所有服务实例订阅 Redis 频道,接收到消息后,检查本地是否有目标用户的 Session;
- 消息推送:有目标用户 Session 的实例,将消息推送给对应的客户端。
整个流程复用了 Redis 缓存实践中的RedisTemplate配置,无需额外引入依赖,保证技术栈统一。
二、环境准备
- 依赖 :
spring-boot-starter-websocket;
三、核心代码实现
1. 枚举定义:WebSocket 消息类型
统一管理 WebSocket 消息类型,便于后续扩展不同业务的消息推送:
java
package com.demo.redis.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* WebSocket消息类型枚举
*/
@Getter
@AllArgsConstructor
public enum DWebSocketType {
DEMO("示例消息"),
NOTICE("系统通知"),
ORDER("订单消息"),
;
private final String desc;
}
2. DTO 定义:WebSocket 消息体
封装推送消息的通用结构,包含用户 ID、消息类型、消息内容:
java
package com.demo.redis.dto;
import com.demo.redis.enums.DWebSocketType;
import lombok.Data;
import lombok.experimental.Accessors;
import java.io.Serializable;
/**
* WebSocket推送请求DTO
* 实现Serializable保证Redis传输时可序列化
*/
@Accessors(chain = true)
@Data
public class DWebSocketReqDTO<T extends Serializable> implements Serializable {
private static final long serialVersionUID = 1292707161671865097L;
/**
* 消息类型 @see DWebSocketType
*/
private String type;
/**
* 目标用户ID(推送的核心标识)
*/
private Long userId;
/**
* 消息体(支持任意序列化类型)
*/
private T body;
}
3. 消息生产者:发布 Redis 消息
封装 Redis 发布逻辑,统一消息推送入口,所有需要推送的场景都通过该类发布消息:
java
package com.demo.redis.producer;
import com.demo.redis.dto.DWebSocketReqDTO;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import javax.annotation.Resource;
import java.io.Serializable;
/**
* WebSocket消息生产者
* 负责将消息发布到Redis指定频道
*/
@Slf4j
public class DWebSocketProducer {
private RedisTemplate<String, Object> redisTemplate;
private String channel;
public DWebSocketProducer(RedisTemplate<String, Object> redisTemplate, String channel) {
this.redisTemplate = redisTemplate;
this.channel = channel;
}
/**
* 发布WebSocket消息
*
* @param reqDTO 消息体
* @param <T> 消息内容类型
*/
public <T extends Serializable> void send(DWebSocketReqDTO<T> reqDTO) {
log.info("#websocket::send# 准备推送消息,userId={},type={}", reqDTO.getUserId(), reqDTO.getType());
// 参数校验:用户ID和消息类型不能为空
if (reqDTO.getUserId() == null || reqDTO.getType() == null) {
log.warn("#websocket::send::return# 参数缺失,userId={},type={}", reqDTO.getUserId(), reqDTO.getType());
return;
}
// 核心:Redis发布消息到指定频道
redisTemplate.convertAndSend(channel, reqDTO);
log.info("#websocket::send::success# 消息发布成功,userId={}", reqDTO.getUserId());
}
}
4. Session 管理器:本地连接管理
管理当前服务实例的 WebSocket 连接 Session,提供增删查、自动清理等能力,是分布式推送的本地核心:
java
package com.demo.redis.manage;
import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import com.demo.redis.dto.DWebSocketReqDTO;
import lombok.extern.slf4j.Slf4j;
import org.apache.tomcat.websocket.WsSession;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
import javax.websocket.Session;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* WebSocket Session管理器
* 管理当前实例的所有WebSocket连接,线程安全
*/
@Slf4j
public class DWebSocketSessionManage {
// 核心参数:Session池容量控制
public static final int CORE_SIZE = 30000; // 核心容量
public static final int MAX_SIZE = 50000; // 最大容量
// 存储Session的最后访问时间(用于清理过期连接)
public static Map<String, Long> KEY_TIME_POOL = new HashMap<>(200);
// 核心:存储用户ID与WebSocketSession的映射(线程安全)
public static Map<String, WebSocketSession> SESSION_POOL = new ConcurrentHashMap<>(200);
/**
* 添加Session到本地缓存
* @param key 唯一标识(格式:sk::userId)
* @param session WebSocket连接
*/
public static void add(String key, WebSocketSession session) {
if (StrUtil.isBlank(key)) {
return;
}
// 先清理过期/超量连接,防止内存溢出
clear();
// 更新最后访问时间
KEY_TIME_POOL.put(key, System.currentTimeMillis());
// 存储Session
SESSION_POOL.put(key, session);
log.info("#websocket::add# Session添加成功,key={},当前连接数={}", key, SESSION_POOL.size());
}
/**
* 清理超量连接(保留核心容量)
*/
private static void clear() {
if (KEY_TIME_POOL.size() > MAX_SIZE) {
log.warn("#websocket::clear# 连接数超过最大值{},开始清理", MAX_SIZE);
// 将连接按最后访问时间升序排序,优先清理最久未使用的
List<Map.Entry<String, Long>> list = new ArrayList<>(KEY_TIME_POOL.entrySet());
list.sort(Map.Entry.comparingByValue());
int clearCount = MAX_SIZE - CORE_SIZE; // 需要清理的数量
int count = 0;
for (Map.Entry<String, Long> entry : list) {
if (count < clearCount) {
remove(entry.getKey()); // 移除连接并关闭Session
count++;
} else {
break;
}
}
log.info("#websocket::clear::success# 清理完成,共清理{}个连接", count);
}
}
/**
* 获取用户的Session
* @param key 唯一标识
* @return WebSocketSession
*/
public static WebSocketSession get(String key) {
if (StrUtil.isBlank(key)) {
return null;
}
// 更新最后访问时间(防止被清理)
KEY_TIME_POOL.put(key, System.currentTimeMillis());
return SESSION_POOL.get(key);
}
/**
* 移除并关闭Session
* @param key 唯一标识
*/
public static void remove(String key) {
if (StrUtil.isBlank(key)) {
return;
}
// 清理时间记录
KEY_TIME_POOL.remove(key);
// 移除并关闭Session
WebSocketSession session = SESSION_POOL.remove(key);
if (session != null) {
try {
session.close();
log.info("#websocket::remove# Session关闭成功,key={}", key);
} catch (IOException e) {
log.error("#websocket::remove::error# Session关闭失败,key={}", key, e);
}
}
}
/**
* 根据消息体生成SessionKey
*/
public static String getSessionKey(DWebSocketReqDTO reqDTO) {
if (reqDTO.getUserId() == null) {
return null;
}
return getSessionKey(reqDTO.getUserId());
}
/**
* 从WebSocketSession中解析用户ID,生成SessionKey
*/
public static String getSessionKey(WebSocketSession session) {
try {
// 解析请求参数中的userId
StandardWebSocketSession standardSession = (StandardWebSocketSession) session;
WsSession wsSession = (WsSession) standardSession.getNativeSession();
List<String> userIdList = wsSession.getRequestParameterMap().get("userId");
if (CollectionUtil.isEmpty(userIdList)) {
log.warn("#websocket::getSessionKey# Session中未获取到userId");
return null;
}
return getSessionKey(userIdList.get(0));
} catch (Exception e) {
log.error("#websocket::getSessionKey::error# 解析SessionKey失败", e);
return null;
}
}
/**
* 生成统一的SessionKey格式
*/
private static String getSessionKey(Object userId) {
return "sk::" + userId; // 格式:sk::1001
}
}
5. Redis 消息监听器:订阅并推送消息
订阅 Redis 频道的消息,接收到消息后,检查本地是否有目标用户的 Session,有则推送:
java
package com.demo.redis.listener;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.demo.redis.dto.DWebSocketReqDTO;
import com.demo.redis.manage.DWebSocketSessionManage;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
/**
* Redis消息监听器
* 订阅Redis频道,接收消息并推送给本地连接的客户端
*/
@AllArgsConstructor
@Slf4j
public class DWebSocketRedisMessageListener implements MessageListener {
private RedisTemplate<String, Object> redisTemplate;
/**
* 核心方法:接收到Redis消息后的处理逻辑
*/
@Override
public void onMessage(Message message, byte[] pattern) {
// 1. 反序列化Redis消息
Object value = redisTemplate.getValueSerializer().deserialize(message.getBody());
if (ObjectUtil.isNull(value)) {
log.warn("#websocket::onMessage::return# 消息体为空");
return;
}
DWebSocketReqDTO reqDTO = (DWebSocketReqDTO) value;
log.info("#websocket::onMessage# 接收到推送消息,userId={},type={}", reqDTO.getUserId(), reqDTO.getType());
// 2. 获取目标用户的SessionKey
String sessionKey = DWebSocketSessionManage.getSessionKey(reqDTO);
if (StrUtil.isBlank(sessionKey)) {
log.warn("#websocket::onMessage::return# 未获取到SessionKey,userId={}", reqDTO.getUserId());
return;
}
// 3. 获取本地Session(当前实例是否连接该用户)
WebSocketSession userSession = DWebSocketSessionManage.get(sessionKey);
if (userSession == null) {
log.info("#websocket::onMessage::return# 当前实例无该用户连接,userId={}", reqDTO.getUserId());
return;
}
// 4. 推送消息给客户端
try {
if (userSession.isOpen()) {
String jsonStr = JSONUtil.toJsonStr(reqDTO);
userSession.sendMessage(new TextMessage(jsonStr));
log.info("#websocket::onMessage::success# 消息推送成功,userId={}", reqDTO.getUserId());
} else {
// Session已关闭,清理缓存
DWebSocketSessionManage.remove(sessionKey);
log.warn("#websocket::onMessage::close# Session已关闭,清理缓存,userId={}", reqDTO.getUserId());
}
} catch (Exception e) {
log.error("#websocket::onMessage::error# 消息推送失败,userId={}", reqDTO.getUserId(), e);
// 推送失败,清理Session
DWebSocketSessionManage.remove(sessionKey);
}
}
}
6. WebSocket 处理器:管理连接生命周期
处理 WebSocket 连接的建立、消息接收、异常、关闭等生命周期事件:
java
package com.demo.redis.handler;
import com.demo.redis.manage.DWebSocketSessionManage;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
/**
* WebSocket处理器
* 处理连接的建立、消息、异常、关闭等事件
*/
@Slf4j
public class DWebSocketHandler extends TextWebSocketHandler {
/**
* 连接建立成功后触发
*/
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
// 生成SessionKey并添加到本地缓存
String sessionKey = DWebSocketSessionManage.getSessionKey(session);
log.info("#websocket::afterConnectionEstablished# 连接建立成功,sessionKey={},sessionId={}", sessionKey, session.getId());
DWebSocketSessionManage.add(sessionKey, session);
}
/**
* 接收客户端发送的消息
*/
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
String sessionKey = DWebSocketSessionManage.getSessionKey(session);
log.info("#websocket::handleTextMessage# 接收客户端消息,sessionKey={},sessionId={},payload={}",
sessionKey, session.getId(), message.getPayload());
// 示例:回复心跳包
session.sendMessage(new TextMessage("pong"));
}
/**
* 传输异常时触发
*/
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
String sessionKey = DWebSocketSessionManage.getSessionKey(session);
log.error("#websocket::handleTransportError# 连接传输异常,sessionKey={},sessionId={}",
sessionKey, session.getId(), exception);
// 清理异常连接
DWebSocketSessionManage.remove(sessionKey);
}
/**
* 连接关闭后触发
*/
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
String sessionKey = DWebSocketSessionManage.getSessionKey(session);
log.info("#websocket::afterConnectionClosed# 连接关闭,sessionKey={},sessionId={},status={}",
sessionKey, session.getId(), status);
// 清理关闭的连接
DWebSocketSessionManage.remove(sessionKey);
}
}
7. 核心配置类:整合所有组件
配置 WebSocket 处理器、Redis 监听器、消息生产者等核心组件,完成分布式 WebSocket 的整体装配:
java
package com.demo.redis.config;
import com.demo.redis.handler.DWebSocketHandler;
import com.demo.redis.listener.DWebSocketRedisMessageListener;
import com.demo.redis.producer.DWebSocketProducer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
* 分布式WebSocket核心配置
*/
@Slf4j
@EnableWebSocket // 开启WebSocket支持
@Configuration
public class DWebSocketConfig implements WebSocketConfigurer {
// Redis发布订阅频道
private static final String D_WEB_SOCKET_TOPIC = "d_web_socket_topic";
// WebSocket连接路径
private static final String WS_CONNECT_PATH = "/ws/connect";
/**
* 注册WebSocket处理器
*/
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
// 配置WebSocket连接路径,允许跨域
registry.addHandler(dWebSocketHandler(), WS_CONNECT_PATH)
.setAllowedOrigins("*");
log.info("#websocket::registerWebSocketHandlers# WebSocket注册成功,路径={}", WS_CONNECT_PATH);
}
/**
* 初始化WebSocket处理器
*/
@Bean
public DWebSocketHandler dWebSocketHandler() {
return new DWebSocketHandler();
}
/**
* 初始化Redis消息监听器
*/
@Bean
public DWebSocketRedisMessageListener dWebSocketRedisMessageListener(RedisTemplate<String, Object> redisTemplate) {
return new DWebSocketRedisMessageListener(redisTemplate);
}
/**
* 配置Redis消息监听容器
*/
@Bean(name = "dWebSocketRedisMessageListenerContainer")
public RedisMessageListenerContainer redisMessageListenerContainer(
RedisConnectionFactory redisConnectionFactory,
DWebSocketRedisMessageListener dWebSocketRedisMessageListener) {
RedisMessageListenerContainer container = new RedisMessageListenerContainer();
container.setConnectionFactory(redisConnectionFactory);
// 订阅指定频道
container.addMessageListener(dWebSocketRedisMessageListener, new PatternTopic(D_WEB_SOCKET_TOPIC));
log.info("#websocket::redisMessageListenerContainer# Redis监听器注册成功,频道={}", D_WEB_SOCKET_TOPIC);
return container;
}
/**
* 初始化WebSocket消息生产者
*/
@Bean
public DWebSocketProducer dWebSocketProducer(RedisTemplate<String, Object> redisTemplate) {
return new DWebSocketProducer(redisTemplate, D_WEB_SOCKET_TOPIC);
}
}
四、实战场景:推送用户消息
1. 测试接口:触发消息推送
java
package com.demo.redis;
import com.demo.redis.dto.DWebSocketReqDTO;
import com.demo.redis.dto.User;
import com.demo.redis.enums.DWebSocketType;
import com.demo.redis.producer.DWebSocketProducer;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
@Slf4j
@SpringBootTest(classes = DemoRedisApplication.class)
public class DWebSocketProducerTest {
@Autowired
private DWebSocketProducer dWebSocketProducer;
@Value("${server.port}")
private int port;
@Test
public void testSend() {
log.info("port:{}",port);
User user = new User();
user.setId(1L);
user.setName("dd");
DWebSocketReqDTO<User> req = new DWebSocketReqDTO<User>()
.setType(DWebSocketType.DEMO.name())
.setUserId(1L)
.setBody(user);
dWebSocketProducer.send(req);
}
}
2. 客户端测试代码
使用 HTML+JS 编写简单的 WebSocket 客户端,连接服务端并接收消息:
html
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<title>分布式WebSocket测试</title>
</head>
<body>
<h3>WebSocket测试(userId=1)</h3>
<div id="message"></div>
<script>
// 连接WebSocket(替换为实际服务端地址)
const ws = new WebSocket('ws://localhost:10088/demo/ws/connect?userId=1');
// 连接成功
ws.onopen = function() {
console.log('WebSocket连接成功');
appendMessage('连接成功');
};
// 接收消息
ws.onmessage = function(event) {
console.log('接收消息:', event.data);
appendMessage('接收消息:' + event.data);
};
// 连接关闭
ws.onclose = function() {
console.log('WebSocket连接关闭');
appendMessage('连接关闭');
};
// 连接异常
ws.onerror = function(error) {
console.error('WebSocket异常:', error);
appendMessage('连接异常:' + error);
};
// 向页面追加消息
function appendMessage(msg) {
const div = document.createElement('div');
div.innerText = new Date().toLocaleString() + ' - ' + msg;
document.getElementById('message').appendChild(div);
}
</script>
</body>
</html>
3. 分布式测试验证
- 启动两个服务实例:分别启动端口 10088、10089 的服务;
- 客户端连接:用 HTML 客户端连接 10088 端口(userId=1);
- 触发推送:调用 10089 端口 单元测试;
- 验证结果:连接 10088 的客户端能接收到消息,证明跨实例推送成功。


五、总结
本文基于 SpringBoot+Redis 缓存实践,实现了分布式 WebSocket,核心要点:
- 核心原理:Redis Pub/Sub 作为消息中转站,跨实例同步推送指令,本地 Session 管理器负责实际推送;
- 核心组件:生产者发布消息、监听器订阅消息、Session 管理器管理本地连接、处理器管理连接生命周期;
- 实战验证:跨实例推送测试验证了方案的有效性,解决了单机 WebSocket 的局限性;