文章目录
- 一、前言
- 二、项目背景
- 三、解决方案
-
- [1. CountDownLatch](#1. CountDownLatch)
- [2. CompletableFuture](#2. CompletableFuture)
- 三、思路延展
- 四、参考内容
一、前言
本系列用来记录一些在实际项目中的小东西,并记录在过程中想到一些小东西,因为是随笔记录,所以内容不会过于详细。
二、项目背景
在 A 项目中,有一个场景需要从 B 项目中获取数据,正常来说 B 项目提供一个接口供 A 项目调用就好,但是 B 项目因为某些原因只能通过 websocket 方式调用。
基于上述,整个逻辑简单来说就是 :

java
@startuml
title WebSocket 转发流程
actor 用户
participant "服务 A" as A
participant "服务 B" as B
用户 -> A: 请求操作
activate A
A -> B: 发送 WebSocket 消息
note right of A: 挂起 用户请求,等待 websocket 消息返回或超时
activate B
B -> B: 处理请求
B --> A: 将结果通过 WebSocket 消息返回
deactivate B
A --> 用户
deactivate A
@enduml
该场景并不复杂,简单来说就是一个异步请求转同步的场景,实现方案有很多种,包括但不限于 CountDownLatch、Lock+Condition、CompletableFuture 等等,下面来看具体实现。
三、解决方案
在引入具体实现前,我们需要先搭建一个 websocket 服务端,如下:
-
引入 web-socket 依赖
xml<dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-websocket</artifactId> </dependency> -
构建websocket服务端
java/** * Spring整合WebSocket的核心配置类(必须加!) * 作用:让Spring扫描并注册@ServerEndpoint注解的WebSocket端点 */ @Configuration public class WebSocketConfig { // 关键Bean:自动注册所有@ServerEndpoint注解的实例 @Bean public ServerEndpointExporter serverEndpointExporter() { return new ServerEndpointExporter(); } } @Data @NoArgsConstructor @AllArgsConstructor public class WsMessageDTO { /** * 全局唯一请求ID,用于HTTP请求与WebSocket响应配对 */ private String requestId; /** * 命令字,用于区分不同的业务场景 */ private String cmd; /** * 实际业务请求/响应数据(JSON字符串/普通文本) */ private String data; } /** * WebSocket服务端(模拟第三方服务) * 功能:接收客户端消息 → 模拟业务处理 → 透传requestId返回响应 * 访问地址:ws://{ip}:{port}/{context-path}/ws/server */ @Slf4j @Component @ServerEndpoint("/ws/server") public class WsServer { private final ObjectMapper objectMapper = new ObjectMapper(); /** * 接收客户端消息并响应 */ @OnMessage public void onMessage(Session session, String requestJson) throws Exception { log.info("[ws server][收到客户端消息 | 数据={}]", requestJson); // 1. 解析客户端消息 WsMessageDTO requestMsg = objectMapper.readValue(requestJson, WsMessageDTO.class); String requestId = requestMsg.getRequestId(); String cmd = requestMsg.getCmd(); String bizData = requestMsg.getData(); // TODO : 根据 cmd 不同,路由到不同的处理逻辑 // 2. 模拟第三方服务业务处理(延时5秒,模拟真实处理耗时) Thread.sleep(5000); // 3. 组装响应消息 // 响应数据(模拟业务处理结果) String responseBizData = "response_" + bizData; WsMessageDTO responseMsg = new WsMessageDTO(requestId, cmd, responseBizData); String responseJson = objectMapper.writeValueAsString(responseMsg); // 4. 返回响应给客户端 session.getBasicRemote().sendText(responseJson); log.info("[ws server][响应客户端消息 | 数据={}]", responseJson); } }
上面的代码会启动一个 websocket 服务器,用于接受客户端的请求,并耗时5s(模拟业务处理耗时)后将结果返回给客户端。
基于上面的服务端,这里列举出下面 CountDownLatch 和 CompletableFuture 的解决方案示例:
本篇解决方案未考虑各种异常情况,如断线重连、并发效率问题,仅为了论证问题的解决方案。
1. CountDownLatch
基于 CountDownLatch 的逻辑比较简单,如下:
-
初始化时设置计数器 = 1 ,HTTP 请求线程执行
latch.await()后进入阻塞等待状态; -
当 WebSocket 回调线程收到第三方响应后,执行
latch.countDown(),计数器变为 0,阻塞线程被立即唤醒; -
支持超时配置,避免 HTTP 请求因第三方服务异常而无限阻塞。
具体代码实现如下, 我们通过调用 WsClient#sendAndSyncWait 方法可以实现基于 websocket 的同步请求调用:
java
/**
* 请求上下文(存储单次请求的等待器+响应数据)
*/
@Data
public class RequestContext {
private CountDownLatch latch;
private int code;
private String message;
private String response;
public void success(String response) {
this.code = 0;
this.response = response;
this.latch.countDown();
}
public void error(String response, String message) {
this.code = 1;
this.response = response;
this.message = message;
this.latch.countDown();
}
}
/**
* WebSocket客户端核心类:负责与第三方服务提供方建立Socket连接、收发消息、管理请求上下文
*/
@Slf4j
@Component
public class WsClient implements ApplicationRunner {
/**
* 第三方服务提供方的WebSocket连接地址
*/
private static final String WS_SERVER_URL = "ws://127.0.0.1:10000/temp-demo/ws/server";
/**
* 同步等待超时时间(单位:毫秒),避免HTTP请求无限阻塞
*/
private static final long WAIT_TIMEOUT = 30000L;
/**
* 存储请求上下文:key=唯一请求ID,value=请求对应的「等待器+响应数据」
*/
private final Map<String, RequestContext> requestContextMap = new ConcurrentHashMap<>();
/**
* WebSocket会话对象(全局单例,维护与服务提供方的长连接)
*/
private Session wsSession;
/**
* 对外提供:发送Socket请求 + 同步等待响应结果
*
* @param requestMsg 向第三方服务发送的请求消息
* @return 第三方服务返回的响应消息
*/
public String sendAndSyncWait(String cmd, String requestMsg) throws Exception {
// 1. 生成唯一请求ID(用于请求-响应精准配对)
String requestId = UUID.randomUUID().toString().replace("-", "");
// 2. 创建CountDownLatch(计数器=1,响应回来后countDown(),线程唤醒)
CountDownLatch latch = new CountDownLatch(1);
// 3. 初始化请求上下文,存入容器
RequestContext context = new RequestContext();
context.setLatch(latch);
requestContextMap.put(requestId, context);
try {
// 4. 组装消息体(必须包含requestId,第三方服务需透传该ID返回)
String sendMsg = buildWsMsg(requestId, cmd, requestMsg);
// 5. 向第三方服务发送WebSocket消息
if (wsSession != null && wsSession.isOpen()) {
wsSession.getBasicRemote().sendText(sendMsg);
log.info("[ws client][向第三方服务发送Socket消息,requestId={},消息体={}]", requestId, sendMsg);
} else {
throw new RuntimeException("WebSocket连接已断开,无法发送消息");
}
// 6. 阻塞等待响应(超时时间WAIT_TIMEOUT,超时抛出异常)
boolean awaitResult = latch.await(WAIT_TIMEOUT, TimeUnit.MILLISECONDS);
if (!awaitResult) {
throw new RuntimeException("等待第三方服务Socket响应超时,超时时间:" + WAIT_TIMEOUT + "ms");
}
// 7. 获取响应结果并返回
String responseMsg = context.getResponse();
if (responseMsg == null || responseMsg.isEmpty()) {
throw new RuntimeException("[ws client][第三方服务返回空响应]");
}
log.info("[ws client][收到第三方服务Socket响应,requestId={},响应体={}]", requestId, responseMsg);
return responseMsg;
} finally {
// 8. 无论成败,移除上下文(避免内存泄漏)
requestContextMap.remove(requestId);
}
}
/**
* 组装WebSocket消息体(自定义格式,第三方服务需按此格式解析、透传requestId)
* 格式示例:{"requestId":"xxx","data":"实际请求内容"}
*/
private String buildWsMsg(String requestId, String cmd, String data) {
return JSON.toJSONString(new WsMessageDTO(requestId, cmd, data));
}
@Override
public void run(ApplicationArguments args) throws Exception {
// 创建WebSocket客户端端点,绑定消息处理器
WebSocketContainer container = ContainerProvider.getWebSocketContainer();
container.connectToServer(new WsMessageHandler(), URI.create(WS_SERVER_URL));
log.info("[ws client][WebSocket客户端连接第三方服务成功,地址:{}]", WS_SERVER_URL);
}
/**
* WebSocket消息处理器(内部类):处理服务提供方的响应消息、连接状态变更
*/
@ClientEndpoint
public class WsMessageHandler {
/**
* 建立Socket连接成功回调:初始化全局会话对象
*/
@OnOpen
public void onOpen(Session session) {
WsClient.this.wsSession = session;
log.info("[ws client][WebSocket连接已建立,sessionId={}]", session.getId());
}
/**
* 接收第三方服务的Socket响应消息(核心回调方法)
*/
@OnMessage
public void onMessage(String responseMsg) {
log.info("[ws client][接收到Socket原始响应消息:{}]", responseMsg);
try {
WsMessageDTO wsMessageDTO = JSON.parseObject(responseMsg, WsMessageDTO.class);
String requestId = wsMessageDTO.getRequestId();
if (requestId == null || !requestContextMap.containsKey(requestId)) {
log.warn("[ws client][收到无效响应:requestId不存在,响应体={}]", responseMsg);
RequestContext context = requestContextMap.get(requestId);
context.error(responseMsg, "requestId不存在");
return;
}
RequestContext context = requestContextMap.get(requestId);
context.success(responseMsg);
} catch (Exception e) {
log.error("[ws client][解析Socket响应消息失败] {}", responseMsg, e);
}
}
/**
* Socket连接关闭回调
*/
@OnClose
public void onClose(Session session, CloseReason reason) {
log.info("[ws client][WebSocket连接关闭] sessionId={},原因={}", session.getId(), reason.getReasonPhrase());
WsClient.this.wsSession = null;
}
/**
* Socket连接异常回调
*/
@OnError
public void onError(Session session, Throwable throwable) {
log.error("[ws client][WebSocket连接发生异常] sessionId={}", session.getId(), throwable);
WsClient.this.wsSession = null;
}
}
}
2. CompletableFuture
CompletableFuture 的实现思路与 CountDownLatch 类似,这里就不再赘述。
关于 CompletableFuture 的介绍可参考 Java基础 : CompletableFuture① 基础使用
代码具体实现如下:
java
/**
* WebSocket客户端核心类:负责与第三方服务提供方建立Socket连接、收发消息、管理请求上下文
*/
@Slf4j
@Component
public class WsClientCf implements ApplicationRunner {
/**
* 第三方服务提供方的WebSocket连接地址
*/
private static final String WS_SERVER_URL = "ws://127.0.0.1:10000/temp-demo/ws/server";
/**
* 同步等待超时时间(单位:毫秒),避免HTTP请求无限阻塞
*/
private static final long WAIT_TIMEOUT = 30000L;
/**
* 首次重连延迟(毫秒),指数退避重连
*/
private static final long RECONNECT_INIT_DELAY = 3000L;
/**
* 最大重连延迟(毫秒)
*/
private static final long RECONNECT_MAX_DELAY = 60000L;
/**
* WebSocket会话对象(全局单例,维护与服务提供方的长连接)
*/
private Session wsSession;
/**
* 连接状态标记(原子类,保证线程安全)
*/
private final AtomicBoolean isConnected = new AtomicBoolean(false);
/**
* 重连线程池(单线程,避免多线程重连冲突)
*/
private final ScheduledExecutorService reconnectExecutor = Executors.newSingleThreadScheduledExecutor();
/**
* 存储请求上下文:key=唯一请求ID,value=请求对应的「等待器+响应数据」
*/
private final Map<String, CompletableFuture<WsMessageDTO>> requestContextMap = new ConcurrentHashMap<>();
/**
* 对外提供:发送Socket请求 + 同步等待响应结果
*
* @param requestMsg 向第三方服务发送的请求消息
* @return 第三方服务返回的响应消息
*/
public String sendAndSyncWait(String cmd, String requestMsg) throws Exception {
// 1. 生成唯一请求ID(用于请求-响应精准配对)
String requestId = UUID.randomUUID().toString().replace("-", "");
// 2. 创建CountDownLatch(计数器=1,响应回来后countDown(),线程唤醒)
// 3. 初始化请求上下文,存入容器
CompletableFuture<WsMessageDTO> completableFuture = new CompletableFuture<>();
requestContextMap.put(requestId, completableFuture);
try {
// 4. 组装消息体(必须包含requestId,第三方服务需透传该ID返回)
String sendMsg = buildWsMsg(requestId, cmd, requestMsg);
// 5. 向第三方服务发送WebSocket消息
if (wsSession != null && wsSession.isOpen()) {
wsSession.getBasicRemote().sendText(sendMsg);
log.info("[ws client][向第三方服务发送Socket消息,requestId={},消息体={}]", requestId, sendMsg);
} else {
throw new RuntimeException("WebSocket连接已断开,无法发送消息");
}
// 6. 阻塞等待响应(超时时间WAIT_TIMEOUT,超时抛出异常)
WsMessageDTO context = completableFuture.get(WAIT_TIMEOUT, TimeUnit.MILLISECONDS);
log.info("[ws client][收到第三方服务Socket响应,requestId={},响应体={}]", requestId, context);
return context.getData();
} finally {
// 8. 无论成败,移除上下文(避免内存泄漏)
requestContextMap.remove(requestId);
}
}
/**
* 组装WebSocket消息体(自定义格式,第三方服务需按此格式解析、透传requestId)
* 格式示例:{"requestId":"xxx","data":"实际请求内容"}
*/
private String buildWsMsg(String requestId, String cmd, String data) {
return JSON.toJSONString(new WsMessageDTO(requestId, cmd, data));
}
@Override
public void run(ApplicationArguments args) {
doConnect(RECONNECT_INIT_DELAY);
}
/**
* 建立WebSocket连接(初始化+重连通用)
*/
private void doConnect(long delay) {
reconnectExecutor.schedule(() -> {
try {
WebSocketContainer container = ContainerProvider.getWebSocketContainer();
container.connectToServer(new WsMessageHandler(), URI.create(WS_SERVER_URL));
isConnected.set(true);
} catch (Exception e) {
isConnected.set(false);
// 指数退避重连:延迟翻倍,最大不超过60s
long nextDelay = Math.min(delay * 2, RECONNECT_MAX_DELAY);
doConnect(nextDelay);
}
}, delay, TimeUnit.MILLISECONDS);
}
/**
* WebSocket消息处理器(内部类):处理服务提供方的响应消息、连接状态变更
*/
@ClientEndpoint
public class WsMessageHandler {
/**
* 建立Socket连接成功回调:初始化全局会话对象
*/
@OnOpen
public void onOpen(Session session) {
WsClientCf.this.wsSession = session;
log.info("[ws client][WebSocket连接已建立,sessionId={}]", session.getId());
}
/**
* 接收第三方服务的Socket响应消息(核心回调方法)
*/
@OnMessage
public void onMessage(String responseMsg) {
log.info("[ws client][接收到Socket原始响应消息:{}]", responseMsg);
WsMessageDTO wsMessageDTO = JSON.parseObject(responseMsg, WsMessageDTO.class);
String requestId = wsMessageDTO.getRequestId();
try {
if (requestId == null || !requestContextMap.containsKey(requestId)) {
log.warn("[ws client][收到无效响应:requestId不存在,响应体={}]", responseMsg);
CompletableFuture<WsMessageDTO> completableFuture = requestContextMap.get(requestId);
completableFuture.completeExceptionally(new RuntimeException("requestId不存在"));
return;
}
CompletableFuture<WsMessageDTO> completableFuture = requestContextMap.get(requestId);
completableFuture.complete(wsMessageDTO);
} catch (Exception e) {
log.error("[ws client][解析Socket响应消息失败] {}", responseMsg, e);
CompletableFuture<WsMessageDTO> completableFuture = requestContextMap.get(requestId);
completableFuture.completeExceptionally(e);
}
}
/**
* Socket连接关闭回调
*/
@OnClose
public void onClose(Session session, CloseReason reason) {
log.info("[ws client][WebSocket连接关闭] sessionId={},原因={}", session.getId(), reason.getReasonPhrase());
isConnected.set(false);
doConnect(RECONNECT_INIT_DELAY);
}
/**
* Socket连接异常回调
*/
@OnError
public void onError(Session session, Throwable throwable) {
log.error("[ws client][WebSocket连接发生异常] sessionId={}", session.getId(), throwable);
isConnected.set(false);
doConnect(RECONNECT_INIT_DELAY);
}
}
}
三、思路延展
SynchronousQueue 的相关内容在 Java基础 : BlockingQueue浅析 一文中有过简单介绍,当时写这篇文章时还没有具体业务使用场景。而本场景下也适合使用 SynchronousQueue ,因此做个记录。
其实本篇就是为了 SynchronousQueue 这盘醋包的饺子,因为之前使用 SynchronousQueue 的时候较少,该场景突然想到了该方案,因此写下此篇加深一下理解。
SynchronousQueue 是 JUC 提供的特殊无界阻塞队列 ,也是实现「一对一生产者 - 消费者」模型的最优工具,核心特性与普通队列(ArrayBlockingQueue/LinkedBlockingQueue)有本质区别。核心特性包括:
- 无容量特性 :队列不存储任何元素,是「传值通道」而非「存储容器」;
- 生产消费强耦合 :调用
put(T)生产元素的线程,会阻塞等待 直到有线程调用take()取走元素;反之,调用take()消费的线程,会阻塞等待 直到有线程调用put(T)放入元素; - 天然阻塞 + 唤醒 :生产 / 消费动作完成后,阻塞线程会自动唤醒 ,无需额外调用
notify()/countDown()/complete(); - 支持超时操作 :提供
poll(long timeout, TimeUnit unit)超时获取方法,完美解决「HTTP 线程无限阻塞」问题; - 线程安全:JUC 原生线程安全,高并发下无锁竞争风险,性能优异。
简单来说 :当我们调用 入队方法(put、add、 offer) 时并不会立刻返回,而是阻塞等待,直到有其他操作(一般是其他线程)调用了该队列的出队方法 (remove、poll、take) 后,入队方法才会返回结果。同理当我们调用出队方法时如果之前没有其他操作调用了入队方法则会挂起等待,直至其他操作调用入队方法。
SynchronousQueue 的实现本质与 CountDownLatch、CompletableFuture 没什么区别,如下:
java
/**
* WebSocket客户端核心类:负责与第三方服务提供方建立Socket连接、收发消息、管理请求上下文
*/
@Slf4j
@Component
public class WsClientSq implements ApplicationRunner {
/**
* 第三方服务提供方的WebSocket连接地址
*/
private static final String WS_SERVER_URL = "ws://127.0.0.1:10000/temp-demo/ws/server";
/**
* 同步等待超时时间(单位:毫秒),避免HTTP请求无限阻塞
*/
private static final long WAIT_TIMEOUT = 30000L;
/**
* 首次重连延迟(毫秒),指数退避重连
*/
private static final long RECONNECT_INIT_DELAY = 3000L;
/**
* 最大重连延迟(毫秒)
*/
private static final long RECONNECT_MAX_DELAY = 60000L;
/**
* WebSocket会话对象(全局单例,维护与服务提供方的长连接)
*/
private Session wsSession;
/**
* 连接状态标记(原子类,保证线程安全)
*/
private final AtomicBoolean isConnected = new AtomicBoolean(false);
/**
* 重连线程池(单线程,避免多线程重连冲突)
*/
private final ScheduledExecutorService reconnectExecutor = Executors.newSingleThreadScheduledExecutor();
/**
* 存储请求上下文:key=唯一请求ID,value=请求对应的「等待器+响应数据」
*/
private final Map<String, SynchronousQueue<WsMessageDTO>> requestContextMap = new ConcurrentHashMap<>();
/**
* 对外提供:发送Socket请求 + 同步等待响应结果
*
* @param requestMsg 向第三方服务发送的请求消息
* @return 第三方服务返回的响应消息
*/
public String sendAndSyncWait(String cmd, String requestMsg) throws Exception {
// 1. 生成唯一请求ID(用于请求-响应精准配对)
String requestId = UUID.randomUUID().toString().replace("-", "");
// 2. 创建CountDownLatch(计数器=1,响应回来后countDown(),线程唤醒)
// 3. 初始化请求上下文,存入容器
SynchronousQueue<WsMessageDTO> resultQueue = new SynchronousQueue<>();
requestContextMap.put(requestId, resultQueue);
try {
// 4. 组装消息体(必须包含requestId,第三方服务需透传该ID返回)
String sendMsg = buildWsMsg(requestId, cmd, requestMsg);
// 5. 向第三方服务发送WebSocket消息
if (wsSession != null && wsSession.isOpen()) {
wsSession.getBasicRemote().sendText(sendMsg);
log.info("[ws client][向第三方服务发送Socket消息,requestId={},消息体={}]", requestId, sendMsg);
} else {
throw new RuntimeException("WebSocket连接已断开,无法发送消息");
}
// 6. 阻塞等待响应(超时时间WAIT_TIMEOUT,超时抛出异常)
WsMessageDTO context = resultQueue.poll(WAIT_TIMEOUT, TimeUnit.MILLISECONDS);
log.info("[ws client][收到第三方服务Socket响应,requestId={},响应体={}]", requestId, context);
return context.getData();
} finally {
// 8. 无论成败,移除上下文(避免内存泄漏)
requestContextMap.remove(requestId);
}
}
/**
* 组装WebSocket消息体(自定义格式,第三方服务需按此格式解析、透传requestId)
* 格式示例:{"requestId":"xxx","data":"实际请求内容"}
*/
private String buildWsMsg(String requestId, String cmd, String data) {
return JSON.toJSONString(new WsMessageDTO(requestId, cmd, data));
}
@Override
public void run(ApplicationArguments args) {
doConnect(RECONNECT_INIT_DELAY);
}
/**
* 建立WebSocket连接(初始化+重连通用)
*/
private void doConnect(long delay) {
reconnectExecutor.schedule(() -> {
try {
WebSocketContainer container = ContainerProvider.getWebSocketContainer();
container.connectToServer(new WsMessageHandler(), URI.create(WS_SERVER_URL));
isConnected.set(true);
} catch (Exception e) {
isConnected.set(false);
// 指数退避重连:延迟翻倍,最大不超过60s
long nextDelay = Math.min(delay * 2, RECONNECT_MAX_DELAY);
doConnect(nextDelay);
}
}, delay, TimeUnit.MILLISECONDS);
}
/**
* WebSocket消息处理器(内部类):处理服务提供方的响应消息、连接状态变更
*/
@ClientEndpoint
public class WsMessageHandler {
/**
* 建立Socket连接成功回调:初始化全局会话对象
*/
@OnOpen
public void onOpen(Session session) {
WsClientSq.this.wsSession = session;
log.info("[ws client][WebSocket连接已建立,sessionId={}]", session.getId());
}
/**
* 接收第三方服务的Socket响应消息(核心回调方法)
*/
@OnMessage
public void onMessage(String responseMsg) {
log.info("[ws client][接收到Socket原始响应消息:{}]", responseMsg);
WsMessageDTO wsMessageDTO = JSON.parseObject(responseMsg, WsMessageDTO.class);
String requestId = wsMessageDTO.getRequestId();
try {
if (requestId == null || !requestContextMap.containsKey(requestId)) {
log.warn("[ws client][收到无效响应:requestId不存在,响应体={}]", responseMsg);
SynchronousQueue<WsMessageDTO> resultQueue = requestContextMap.get(requestId);
resultQueue.put(WsMessageDTO.error("requestId不存在"));
return;
}
SynchronousQueue<WsMessageDTO> resultQueue = requestContextMap.get(requestId);
resultQueue.put(wsMessageDTO);
} catch (Exception e) {
log.error("[ws client][解析Socket响应消息失败] {}", responseMsg, e);
}
}
/**
* Socket连接关闭回调
*/
@OnClose
public void onClose(Session session, CloseReason reason) {
log.info("[ws client][WebSocket连接关闭] sessionId={},原因={}", session.getId(), reason.getReasonPhrase());
isConnected.set(false);
doConnect(RECONNECT_INIT_DELAY);
}
/**
* Socket连接异常回调
*/
@OnError
public void onError(Session session, Throwable throwable) {
log.error("[ws client][WebSocket连接发生异常] sessionId={}", session.getId(), throwable);
isConnected.set(false);
doConnect(RECONNECT_INIT_DELAY);
}
}
}
四、参考内容
- 豆包