代码环境:
JDK 8, Spring Boot 1.5.22.RELEASE
依赖:
XML
<!-- Inherit from the Spring Boot 1.5.22 starter parent. -->
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>1.5.22.RELEASE</version>
<relativePath /> <!-- lookup parent from repository -->
</parent>
<dependencies>
<!-- WebSocket starter; the Java code below uses Spring's WebSocket API (WebSocketConfigurer, TextWebSocketHandler, SockJS). -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- Spring Boot web starter -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
<!--
  Embedded Tomcat, Jackson and Logback are excluded from the web starter. Presumably the
  project supplies its own servlet container and logging backend (the Java code below uses
  fastjson for JSON). TODO confirm.
  NOTE(review): in Spring Boot 1.5, spring-boot-starter-websocket (declared above) itself
  depends on spring-boot-starter-web, so these exclusions only cut this one dependency path;
  the same artifacts may still arrive transitively through the websocket starter. Verify with
  "mvn dependency:tree" and, if needed, repeat the exclusions on the websocket starter.
-->
<exclusions>
<exclusion>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-tomcat</artifactId>
</exclusion>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
</exclusion>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</exclusion>
<exclusion>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
</exclusion>
<exclusion>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-core</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
代码:
Java
package cn.com.trinitygo.scm.common.util.websocket;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
/**
 * Spring WebSocket configuration.
 *
 * <p>Exposes the {@link #customWebSocketHandler()} bean on two endpoints: a plain WebSocket
 * endpoint and a SockJS endpoint. Both run {@code CustomWebSocketInterceptor} during the
 * handshake and accept connections from any origin.
 *
 * @ClassName: CustomWebSocketConfig
 * @Author: peilei
 * @Date: 2021/9/13 14:59
 */
@Configuration
@EnableWebSocket
public class CustomWebSocketConfig implements WebSocketConfigurer {

    /** Path of the plain (native) WebSocket endpoint. */
    private static final String PLAIN_ENDPOINT = "/webSocketBySpring/customWebSocketHandler";

    /** Path of the SockJS endpoint. */
    private static final String SOCKJS_ENDPOINT = "/sockjs/webSocketBySpring/customWebSocketHandler";

    /**
     * Origin pattern accepted by both endpoints.
     *
     * <p>NOTE(review): "*" lets pages served from any site open a connection; since the handshake
     * interceptor uses the HTTP session, consider restricting this to known front-end origins.
     */
    private static final String ANY_ORIGIN = "*";

    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        // In a @Configuration class, calling the @Bean method yields the container-managed singleton,
        // so both endpoints share the same handler instance.
        WebSocketHandler handler = customWebSocketHandler();

        registry.addHandler(handler, PLAIN_ENDPOINT)
                .addInterceptors(new CustomWebSocketInterceptor())
                .setAllowedOrigins(ANY_ORIGIN);

        registry.addHandler(handler, SOCKJS_ENDPOINT)
                .addInterceptors(new CustomWebSocketInterceptor())
                .setAllowedOrigins(ANY_ORIGIN)
                .withSockJS();
    }

    /**
     * Declares the handler bean served by both endpoints.
     *
     * @return a new {@code CustomWebSocketHandler}
     */
    @Bean
    public WebSocketHandler customWebSocketHandler() {
        return new CustomWebSocketHandler();
    }
}
java
package cn.com.trinitygo.scm.common.util.websocket;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpSession;
import java.util.Map;
/**
 * Handshake interceptor that copies identifying data from the HTTP upgrade request into the
 * attributes of the WebSocket session being created.
 *
 * <p>For servlet requests it stores {@code "sessionId"} (id of the {@code HttpSession}, which is
 * created if absent) and {@code "userId"} (the {@code userId} query parameter, possibly
 * {@code null}). Other request types pass through without attributes. The handshake is never
 * rejected.
 *
 * <p>NOTE(review): {@code userId} is taken verbatim from the client-supplied query string and is
 * not authenticated here, so a client can claim any user id. Confirm that authentication happens
 * upstream before relying on this value.
 *
 * @ClassName: CustomWebSocketInterceptor
 * @Author: peilei
 * @Date: 2021/9/13 15:00
 */
public class CustomWebSocketInterceptor implements HandshakeInterceptor {

    /** One logger per class rather than one per interceptor instance. */
    private static final Logger logger = LoggerFactory.getLogger(CustomWebSocketInterceptor.class);

    /**
     * Links the HTTP session with the WebSocket session.
     *
     * <p>The {@code map} argument becomes the attribute map of the resulting WebSocket session,
     * so the values put here are readable by the handler through {@code session.getAttributes()}.
     *
     * @return always {@code true}, i.e. the handshake proceeds
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler, Map<String, Object> map) throws Exception {
        if (request instanceof ServletServerHttpRequest) {
            logger.info("*****beforeHandshake******");
            HttpServletRequest httpServletRequest = ((ServletServerHttpRequest) request).getServletRequest();
            HttpSession session = httpServletRequest.getSession(true);
            // Read the parameter once; it is both logged and stored.
            String userId = httpServletRequest.getParameter("userId");
            logger.info("userId:{}", userId);
            // getSession(true) creates a session when none exists; the null check is a cheap guard.
            if (session != null) {
                map.put("sessionId", session.getId());
                map.put("userId", userId);
            }
        }
        return true;
    }

    /** Logs completion of the handshake; takes no other action. */
    @Override
    public void afterHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Exception e) {
        logger.info("******afterHandshake******");
    }
}
java
package cn.com.trinitygo.scm.common.util.websocket;
import cn.com.trinitygo.scm.service.SysMsgService;
import com.alibaba.fastjson.JSONObject;
import httl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
 * WebSocket server handler: keeps a registry of connected users and routes text messages
 * between them.
 *
 * <p>A connection is registered under the {@code "userId"} session attribute (filled in during the
 * handshake). Connections without a user id stay open but are not registered, so they cannot be
 * addressed by {@link #sendMessageToUser}.
 *
 * <p>Thread-safety: container callbacks for different connections and application calls such as
 * {@link #sendMessageToUsers} can run concurrently, so the registry is a {@code ConcurrentHashMap}
 * and every send on a session goes through {@link #sendSafely}, which serializes sends per session
 * (a WebSocket session does not support concurrent sending).
 *
 * @ClassName: CustomWebSocketHandler
 * @Author: peilei
 * @Date: 2021/9/13 14:48
 */
@Service
public class CustomWebSocketHandler extends TextWebSocketHandler implements WebSocketHandler {

    @Autowired
    SysMsgService sysMsgService;

    private static final Logger logger = LoggerFactory.getLogger(CustomWebSocketHandler.class);

    /**
     * Online users: user id to that user's most recently opened session. Static, so it is shared
     * by every instance of this class. Null keys are not allowed.
     */
    private static final Map<String, WebSocketSession> users = new ConcurrentHashMap<>();

    /** Session attribute key and JSON field name carrying the user id. */
    private static final String CLIENT_ID = "userId";

    /** Value of the {@code "to"} field that addresses every online user (case-insensitive). */
    private static final String BROADCAST_TARGET = "all";

    /**
     * Registers the new connection under its user id (if any) and greets the client.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        logger.info("成功建立websocket-spring连接");
        String userId = getUserId(session);
        if (StringUtils.isNotEmpty(userId)) {
            // NOTE(review): a reconnect under the same user id replaces the previous session
            // without closing it; confirm that is intended.
            users.put(userId, session);
            sendSafely(session, new TextMessage("成功建立websocket-spring连接"));
            logger.info("用户标识:{},Session:{}", userId, session.toString());
        }
    }

    /**
     * Handles a text frame from a client.
     *
     * <p>Expects a JSON object with {@code "to"} (a user id, or {@code "all"} to broadcast) and
     * {@code "msg"}. The payload is echoed back to the sender prefixed with {@code "server:"}, then
     * {@code "<senderId>:<msg>"} is forwarded to the addressee(s). If the echo fails, nothing is
     * forwarded.
     *
     * <p>NOTE(review): the payload is untrusted client input parsed with fastjson; older fastjson
     * releases have autoType deserialization vulnerabilities ({@code "@type"} keys). Confirm the
     * fastjson version in use is patched.
     */
    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) {
        String payload = message.getPayload();
        logger.info("收到客户端消息:{}", payload);
        JSONObject msgJson = JSONObject.parseObject(payload);
        if (msgJson == null) {
            // Presumably an empty/blank frame: nothing to route.
            logger.warn("handleTextMessage ignored empty payload from {}", getUserId(session));
            return;
        }
        String to = msgJson.getString("to");
        String msg = msgJson.getString("msg");
        try {
            // Echo the text itself; concatenating the TextMessage object would send its
            // toString() debug representation instead.
            sendSafely(session, new TextMessage("server:" + payload));
        } catch (IOException e) {
            logger.error("handleTextMessage method error", e);
            return;
        }
        TextMessage forwarded = new TextMessage(getUserId(session) + ":" + msg);
        // equalsIgnoreCase is null-safe: a missing "to" falls through to a no-op unicast.
        if (BROADCAST_TARGET.equalsIgnoreCase(to)) {
            sendMessageToAllUsers(forwarded);
        } else {
            sendMessageToUser(to, forwarded);
        }
    }

    /**
     * Closes the session if still open and unregisters it.
     */
    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        if (session.isOpen()) {
            session.close();
        }
        logger.error("连接出错", exception);
        unregister(session);
    }

    /**
     * Unregisters the closed session.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        logger.info("连接已关闭:{}", status);
        unregister(session);
    }

    /** Partial (fragmented) messages are not supported; each frame is a complete message. */
    @Override
    public boolean supportsPartialMessages() {
        return false;
    }

    /**
     * Routes a JSON message given as a string.
     *
     * <p>The sender is the {@code "userId"} field, or {@code "陌生人"} ("stranger") when absent or
     * empty; {@code "to"} and {@code "msg"} are interpreted as in {@link #handleTextMessage}.
     *
     * @param jsonData JSON object text; an empty payload is logged and ignored
     */
    public void sendMessage(String jsonData) {
        logger.info("收到客户端消息sendMessage:{}", jsonData);
        JSONObject msgJson = JSONObject.parseObject(jsonData);
        if (msgJson == null) {
            logger.warn("sendMessage ignored empty payload");
            return;
        }
        String sender = msgJson.getString(CLIENT_ID);
        String userId = StringUtils.isEmpty(sender) ? "陌生人" : sender;
        String to = msgJson.getString("to");
        TextMessage outgoing = new TextMessage(userId + ":" + msgJson.getString("msg"));
        if (BROADCAST_TARGET.equalsIgnoreCase(to)) {
            sendMessageToAllUsers(outgoing);
        } else {
            sendMessageToUser(to, outgoing);
        }
    }

    /**
     * Sends a message to one user.
     *
     * @param userId  target user id; {@code null} or unknown ids are a no-op
     * @param message message to send
     * @return {@code true} if the message was written to an open session, {@code false} if the
     *         user is offline, the session is closed, or the write failed
     */
    public boolean sendMessageToUser(String userId, TextMessage message) {
        // ConcurrentHashMap rejects null keys, so guard before the lookup.
        WebSocketSession session = userId == null ? null : users.get(userId);
        if (session == null) {
            return false;
        }
        logger.info("sendMessage:{} ,msg:{}", session, message.getPayload());
        if (!session.isOpen()) {
            logger.info("客户端:{},已断开连接,发送消息失败", userId);
            return false;
        }
        try {
            sendSafely(session, message);
            return true;
        } catch (IOException e) {
            logger.error("sendMessageToUser method error", e);
            return false;
        }
    }

    /**
     * Persists a system message via {@code SysMsgService} and pushes {@code cnContent} to every
     * listed user that is currently online.
     *
     * <p>Per-user delivery failures are logged and skipped.
     *
     * @param ids recipient user ids; when {@code null} or empty nothing is persisted or sent,
     *            and {@code null} elements are skipped
     * @return always {@code true}
     */
    public boolean sendMessageToUsers(Long businessId, Integer type, String cnContent, String enContent, List<Long> ids, Integer businessType, Integer sendType, String tokenUserId) {
        if (ids == null || ids.isEmpty()) {
            return true;
        }
        sysMsgService.sendMessage(businessId, type, cnContent, enContent, ids, businessType, sendType, tokenUserId);
        if (cnContent == null) {
            // TextMessage rejects a null payload; nothing can be pushed.
            logger.warn("sendMessageToUsers: cnContent is null, skipping WebSocket push");
            return true;
        }
        // TextMessage is immutable, so one instance can be sent to every recipient.
        TextMessage message = new TextMessage(cnContent);
        for (Long id : ids) {
            if (id != null) {
                sendMessageToUser(id.toString(), message);
            }
        }
        return true;
    }

    /**
     * Broadcasts a message to every registered user.
     *
     * @param message message to send
     * @return {@code false} if any write to an open session failed; closed sessions are logged
     *         and do not count as failures
     */
    public boolean sendMessageToAllUsers(TextMessage message) {
        boolean allSendSuccess = true;
        // ConcurrentHashMap iteration is weakly consistent: no ConcurrentModificationException
        // while users connect or disconnect during the broadcast.
        for (Map.Entry<String, WebSocketSession> entry : users.entrySet()) {
            WebSocketSession session = entry.getValue();
            if (!session.isOpen()) {
                logger.info("客户端:{},已断开连接,发送消息失败", entry.getKey());
                continue;
            }
            try {
                sendSafely(session, message);
            } catch (IOException e) {
                logger.error("sendMessageToAllUsers method error", e);
                allSendSuccess = false;
            }
        }
        return allSendSuccess;
    }

    /**
     * Removes the session from the registry, but only if it is still the one registered for its
     * user; a newer session from a reconnect under the same id is left in place.
     */
    private void unregister(WebSocketSession session) {
        String userId = getUserId(session);
        if (userId != null) {
            users.remove(userId, session);
        }
    }

    /**
     * Writes a message while holding the session's monitor, so that sends to one session from
     * different threads never overlap.
     */
    private static void sendSafely(WebSocketSession session, WebSocketMessage<?> message) throws IOException {
        synchronized (session) {
            session.sendMessage(message);
        }
    }

    /**
     * Returns the user id stored in the session attributes during the handshake.
     *
     * @return the id as a string, or {@code null} if the attribute is absent or null
     */
    private String getUserId(WebSocketSession session) {
        Object userId = session.getAttributes().get(CLIENT_ID);
        return userId == null ? null : userId.toString();
    }
}