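@ServerEndpoint 注解的端点由 WebSocket 容器管理,而不是由 Spring 管理。
Spring 中的 ServerEndpointExporter 只负责把端点类注册到 WebSocket 容器;每个连接都会由 WebSocket 容器新建一个端点实例(即使类上标注了 @Component,Spring 创建的那个单例也不会用来处理连接)。
连接请求及其 Session 的生命周期由 WebSocket 容器本身管理。
因此在端点类中无法通过 @Autowired 正常注入 Bean,需要通过 Spring 上下文工具类(如 SpringUtil.getBean)手动获取 Bean。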
基于端点的实现方式
java
/**
 * Spring configuration for the annotated-endpoint approach.
 *
 * <p>Publishes a {@link ServerEndpointExporter} bean so that {@code @ServerEndpoint} classes such as
 * {@code WebSocketServerEndpoint} get registered with the WebSocket container.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig {

    /**
     * Creates the exporter that hands {@code @ServerEndpoint} classes over to the WebSocket container.
     *
     * @return a new {@link ServerEndpointExporter}
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
java
/**
 * Annotated WebSocket endpoint mounted at {@code /websocket/{ai-podcast}}.
 *
 * <p>Text frames are parsed as JSON. Only messages whose {@code type} is {@code "qa"} are handled,
 * and their {@code status} field selects the action:
 * <ul>
 *   <li>{@code qa_done}: passes {@code message_id} and {@code answer} to
 *       {@code AiPodcastService#setAnswer};</li>
 *   <li>{@code stream_start}: opens {@code <PCM_DIR><message_id>.pcm} for this session;</li>
 *   <li>{@code stream_end}: flushes and closes that file.</li>
 * </ul>
 * Binary frames are appended to the file opened by the preceding {@code stream_start}.
 *
 * <p>Spring beans are looked up through {@code SpringUtil} rather than injected.
 */
@Slf4j
@Component
@ServerEndpoint(value = "/websocket/{ai-podcast}", encoders = StringEncoder.class)
public class WebSocketServerEndpoint {

    /** Directory that receives the streamed PCM files. */
    private static final String PCM_DIR = "E:/pcm/";

    /** Currently open PCM output file, keyed by WebSocket session id. */
    private final Map<String, FileOutputStream> receive = new ConcurrentHashMap<>();

    /**
     * Registers the new session under its path flag.
     *
     * @param session the newly opened session
     * @param flag    value of the {@code ai-podcast} path segment
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("ai-podcast") String flag) {
        log.info("websocket上线:{}", flag);
        WebSocketUtil.addSession(flag, session);
    }

    /**
     * Handles a JSON control message (see class documentation for the protocol).
     *
     * @param session the sending session
     * @param message raw JSON text
     * @throws IOException if the PCM file cannot be opened, flushed or closed
     */
    @OnMessage
    public void onMessageText(Session session, String message) throws IOException {
        log.info("websocket收到文本消息:{}", message);
        JSONObject jsonObject = JSONUtil.parseObj(message);
        // Constant-first equals: a missing or null "type"/"status" is ignored instead of throwing NPE.
        if (!"qa".equals(jsonObject.getStr("type"))) {
            return;
        }
        String status = jsonObject.getStr("status");
        if ("qa_done".equals(status)) {
            this.getAiPodcastService().setAnswer(jsonObject.getStr("message_id"), jsonObject.getStr("answer"));
        } else if ("stream_start".equals(status)) {
            startStream(session.getId(), jsonObject.getStr("message_id"));
        } else if ("stream_end".equals(status)) {
            endStream(session.getId());
        }
    }

    /**
     * Appends a binary audio chunk to the session's open PCM file.
     * Chunks that arrive without a preceding {@code stream_start} are dropped.
     *
     * @param session the sending session
     * @param message raw audio bytes
     * @throws IOException if writing to the file fails
     */
    @OnMessage
    public void onMessageByte(Session session, byte[] message) throws IOException {
        log.info("websocket收到数组消息:{}", message.length);
        FileOutputStream fileOutputStream = receive.get(session.getId());
        if (fileOutputStream == null) {
            log.warn("未收到stream_start,丢弃音频数据:{}", message.length);
            return;
        }
        fileOutputStream.write(message);
    }

    /**
     * Unregisters the session and releases any PCM file it still had open.
     *
     * @param session the closed session
     * @param flag    value of the {@code ai-podcast} path segment
     */
    @OnClose
    public void onClose(Session session, @PathParam("ai-podcast") String flag) {
        log.info("websocket下线:{}", flag);
        WebSocketUtil.removeSession(flag);
        closeStream(session.getId());
    }

    /**
     * Logs the error, unregisters the session and releases any PCM file it still had open.
     *
     * @param session   the failing session
     * @param throwable the reported error
     * @param flag      value of the {@code ai-podcast} path segment
     */
    @OnError
    public void onError(Session session, Throwable throwable, @PathParam("ai-podcast") String flag) {
        // log.error already records the stack trace; printStackTrace() would only duplicate it on stderr.
        log.error("websocket发生错误:{}", throwable.getMessage(), throwable);
        WebSocketUtil.removeSession(flag);
        closeStream(session.getId());
    }

    /** Opens the PCM file for {@code messageId}, first closing any stream this session left open. */
    private void startStream(String sessionId, String messageId) throws IOException {
        if (!isSafeFileName(messageId)) {
            log.warn("非法的message_id,忽略stream_start:{}", messageId);
            return;
        }
        // A stream that never got its stream_end would otherwise be overwritten and its handle leaked.
        closeStream(sessionId);
        receive.put(sessionId, new FileOutputStream(PCM_DIR + messageId + ".pcm"));
    }

    /** Flushes and closes the session's PCM file; the map entry is removed even if flushing fails. */
    private void endStream(String sessionId) throws IOException {
        FileOutputStream fileOutputStream = receive.remove(sessionId);
        if (fileOutputStream == null) {
            log.warn("收到stream_end但没有打开的PCM文件:{}", sessionId);
            return;
        }
        try {
            fileOutputStream.flush();
        } finally {
            fileOutputStream.close();
        }
    }

    /** Best-effort release of the session's PCM file, if any; close failures are logged, not thrown. */
    private void closeStream(String sessionId) {
        FileOutputStream fileOutputStream = receive.remove(sessionId);
        if (fileOutputStream == null) {
            return;
        }
        try {
            fileOutputStream.close();
        } catch (IOException e) {
            log.warn("关闭PCM文件失败:{}", sessionId, e);
        }
    }

    /**
     * Checks that a client-supplied id can be used as a file name inside {@link #PCM_DIR}.
     * Rejects missing ids, path separators, {@code ..} and drive-letter colons.
     */
    private static boolean isSafeFileName(String messageId) {
        return messageId != null
                && !messageId.isEmpty()
                && !messageId.contains("/")
                && !messageId.contains("\\")
                && !messageId.contains("..")
                && !messageId.contains(":");
    }

    private AiPodcastService getAiPodcastService() {
        return SpringUtil.getBean(AiPodcastService.class);
    }
}
java
/**
 * Static registry of endpoint sessions keyed by their path flag, plus a helper to push text to them.
 */
@Slf4j
public class WebSocketUtil {

    /** Registered sessions by flag; a concurrent map because this static state is shared across callers. */
    private static final Map<String, Session> sessions = new ConcurrentHashMap<>();

    /**
     * Registers (or replaces) the session for {@code flag}.
     *
     * @param flag    session key
     * @param session session to register
     */
    public static void addSession(String flag, Session session) {
        sessions.put(flag, session);
    }

    /**
     * Removes whatever session is registered for {@code flag}.
     *
     * @param flag session key
     */
    public static void removeSession(String flag) {
        sessions.remove(flag);
    }

    /**
     * Sends {@code message} through the session registered for {@code flag}.
     *
     * <p>The message is dropped with a warning when no open session is registered for the flag.
     * Send failures are logged rather than thrown.
     *
     * @param flag    session key
     * @param message text to send
     */
    public static void sendMessage(String flag, String message) {
        // ConcurrentHashMap.get(null) throws, so treat a null flag as "no session".
        Session session = flag == null ? null : sessions.get(flag);
        if (session == null || !session.isOpen()) {
            log.warn("websocket会话不存在或已关闭,消息未发送:{},{}", flag, message);
            return;
        }
        try {
            log.info("websocket发送消息:{},{}", flag, message);
            session.getBasicRemote().sendObject(message);
        } catch (IOException | EncodeException e) {
            log.error("发送消息失败:{}", e.getMessage(), e);
        }
    }
}
java
/**
 * Text encoder for {@code String} payloads: the value is already text, so it is returned unchanged.
 */
public class StringEncoder implements Encoder.Text<String> {

    @Override
    public String encode(String text) throws EncodeException {
        // Nothing to transform; the returned string is used as-is as the encoded text.
        final String encoded = text;
        return encoded;
    }
}
基于 WebSocketHandler 的实现方式
java
/**
 * Spring configuration for the {@code WebSocketHandler} approach.
 *
 * <p>Maps {@link BinaryWebSocketHandler} to {@code /websocket/upload/{userId}}.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    /** URL pattern served by the upload handler; {@code userId} is a URI template variable. */
    private static final String UPLOAD_PATH = "/websocket/upload/{userId}";

    @Autowired
    private BinaryWebSocketHandler binaryWebSocketHandler;

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // NOTE(review): "*" accepts handshakes from any origin; consider restricting it to known origins.
        registry.addHandler(binaryWebSocketHandler, UPLOAD_PATH).setAllowedOrigins("*");
    }
}
java
/**
 * Spring {@link WebSocketHandler} for uploads.
 *
 * <p>Text frames carry metadata and binary frames carry file data. The calls into {@code MyService}
 * are placeholders and are left commented out.
 */
@Component
public class BinaryWebSocketHandler implements WebSocketHandler {

    @Resource
    private MyService myService;

    /**
     * Resolves the {@code userId} path variable and stores it in the session attributes under
     * {@code "userId"}. Closes the session with {@link CloseStatus#POLICY_VIOLATION} if it cannot be resolved.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // Connection established
        // myService.logConnection(session.getId());
        String userId = resolveUserId(session);
        if (userId == null) {
            // Refuse the connection up front instead of failing later with a NullPointerException.
            session.close(CloseStatus.POLICY_VIOLATION);
            return;
        }
        // Keep it in the session attributes so later callbacks can read it.
        session.getAttributes().put("userId", userId);
    }

    /**
     * Returns the {@code userId} URI template variable, or {@code null} if it cannot be determined.
     *
     * <p>WebSocket session attributes contain only what handshake interceptors put there. No
     * interceptor copying the URI template variables is registered in the visible configuration, so
     * the attribute may be absent. In that case the last segment of the request path is used.
     */
    private static String resolveUserId(WebSocketSession session) {
        Object templateVariables = session.getAttributes().get(URI_TEMPLATE_VARIABLES_ATTRIBUTE);
        if (templateVariables instanceof Map<?, ?> pathVariables
                && pathVariables.get("userId") instanceof String userId) {
            return userId;
        }
        var uri = session.getUri();
        String path = uri == null ? null : uri.getPath();
        if (path == null) {
            return null;
        }
        String lastSegment = path.substring(path.lastIndexOf('/') + 1);
        return lastSegment.isEmpty() ? null : lastSegment;
    }

    /** Splits incoming frames into metadata (text) and file data (binary). */
    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        if (message instanceof TextMessage textMessage) {
            // Text message: metadata
            String metadata = textMessage.getPayload();
            // myService.processMetadata(metadata, session);
        } else if (message instanceof BinaryMessage binaryMessage) {
            // Binary message: file data, copied out of the payload buffer
            ByteBuffer buffer = binaryMessage.getPayload();
            byte[] data = new byte[buffer.remaining()];
            buffer.get(data);
            // myService.processBinaryData(data, session);
        }
    }

    /** Transport-error hook; intentionally left empty in this template. */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        // Handle transport errors here
    }

    /** Connection-closed hook; cleanup is a commented-out placeholder. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
        // Connection closed
        // myService.cleanup(session.getId());
    }

    /** Partial messages are not supported: each frame is delivered as a complete message. */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }
}