实现websocket的方式
1.springboot中有两种方式实现websocket:一种是基于JSR-356(Java WebSocket API)原生注解的方式,另一种是基于spring封装后的WebSocketHandler接口的方式。
基于原生注解实现websocket
1)先引入websocket的starter坐标
xml
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
2)编写websocket的Endpoint端点类
java
/**
 * JSR-356 (annotation based) WebSocket endpoint mapped to {@code /ws/{token}}.
 *
 * <p>Each opened session is tracked in {@link #SESSIONS} and gets a heartbeat task that sends a
 * ping every 5 seconds, while the session's max idle timeout is set to 6 seconds in
 * {@link #onOpen}; the periodic ping keeps the connection from going idle.
 *
 * <p>By default the WebSocket container creates a new instance of this class per connection, so
 * all state shared between connections is kept in static fields.
 *
 * <p>Thread-safety: replies are sent from container threads while pings are sent from
 * {@link #scheduledExecutor}; all writes to a session are therefore serialized on the session
 * object, because a JSR-356 basic remote endpoint may reject concurrent sends with an
 * {@link IllegalStateException}.
 */
@ServerEndpoint(value = "/ws/{token}")
@Component
public class WebsocketHandler2 {
    private final static Logger log = LoggerFactory.getLogger(WebsocketHandler2.class);
    /** Currently registered sessions; a thread-safe set ordered by session id. */
    private static final Set<Session> SESSIONS = new ConcurrentSkipListSet<>(Comparator.comparing(Session::getId));
    /** Runs the periodic heartbeat pings of all sessions. */
    private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);
    /** Heartbeat task per session id, kept so the task can be cancelled when the session ends. */
    private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();

    /**
     * Registers the new session, starts its heartbeat and greets the client.
     *
     * @param session the newly opened session
     * @param token   the {@code {token}} path segment of the request URI
     * @param config  the endpoint configuration (unused)
     * @throws IOException if the greeting cannot be sent
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("token") String token, EndpointConfig config) throws IOException {
        // PongHandler is an alternative to the @OnMessage onPong method below. Register only one
        // of them: a session accepts a single handler per message type.
        // session.addMessageHandler(new PongHandler());
        ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);
        String queryString = session.getQueryString();
        futures.put(session.getId(), future);
        // Must be longer than the 5 s ping period, otherwise the session is closed as idle
        // before the first ping is sent.
        session.setMaxIdleTimeout(6 * 1000);
        SESSIONS.add(session);
        log.info("open connect sessionId={}, token={}, queryParam={}", session.getId(), token, queryString);
        String s = String.format("ws client(id=%s) has connected", session.getId());
        sendText(session, s);
    }

    /** Programmatic pong handler that only logs the pong payload (see the note in onOpen). */
    static class PongHandler implements MessageHandler.Whole<PongMessage> {
        @Override
        public void onMessage(PongMessage message) {
            log.info("receive pong msg=> {}", decodeUtf8(message.getApplicationData()));
        }
    }

    /**
     * Logs the close reason and releases the session's bookkeeping and heartbeat task.
     *
     * @param session the closed session
     * @param reason  why the session was closed
     */
    @OnClose
    public void onClose(Session session, CloseReason reason) {
        log.info("session(id={}) close ,closeCode={},closeParse={}", session.getId(), reason.getCloseCode(), reason.getReasonPhrase());
        unregister(session);
    }

    /**
     * Echoes a text message back to the sender.
     *
     * @param message the received text
     * @param session the sender's session
     * @throws IOException if the reply cannot be sent
     */
    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        log.info("receive client(id={}) msg=>{}", session.getId(), message);
        String s = String.format("reply your(id=%s) msg=>【%s】", session.getId(), message);
        sendText(session, s);
    }

    /**
     * Logs the payload of a pong frame (the client's answer to {@link #sendPing}).
     *
     * @param message the received pong
     * @param session the sender's session
     */
    @OnMessage
    public void onPong(PongMessage message, Session session) throws IOException {
        log.info("receive client(id={}) pong msg=> {}", session.getId(), decodeUtf8(message.getApplicationData()));
    }

    /**
     * Logs a transport/processing error of a session.
     *
     * @param session the affected session
     * @param error   the error that occurred
     */
    @OnError
    public void onError(Session session, Throwable error) {
        log.error("Session(id={}) error occur ", session.getId(), error);
    }

    /**
     * Heartbeat task: pings an open session, or unregisters the session once it is closed.
     * Send failures are logged rather than thrown, because an exception escaping a
     * {@code scheduleWithFixedDelay} task would silently cancel all further pings.
     */
    private void sendPing(Session session) {
        if (!session.isOpen()) {
            unregister(session);
            return;
        }
        String replyContent = String.format("Hello,client(id=%s)", session.getId());
        ByteBuffer payload = ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8));
        try {
            // Serialize with replies sent from container threads (see class Javadoc).
            synchronized (session) {
                session.getBasicRemote().sendPing(payload);
            }
        } catch (IOException | IllegalStateException e) {
            log.error("ping client(id={}) error", session.getId(), e);
        }
    }

    /** Sends a text frame, serialized with all other writes to the same session. */
    private static void sendText(Session session, String text) throws IOException {
        synchronized (session) {
            session.getBasicRemote().sendText(text);
        }
    }

    /** Removes the session from the registry and cancels (and forgets) its heartbeat task. */
    private static void unregister(Session session) {
        SESSIONS.remove(session);
        // remove (not get), otherwise every closed session leaves an entry behind.
        ScheduledFuture<?> future = futures.remove(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }

    /**
     * Decodes the readable bytes (position..limit) of a buffer as UTF-8 without consuming it.
     * Unlike {@code new String(buf.array(), ...)}, this also works for buffers that are not
     * array-backed and ignores unused capacity of the backing array.
     */
    private static String decodeUtf8(ByteBuffer data) {
        return StandardCharsets.UTF_8.decode(data.duplicate()).toString();
    }
}
注解说明
@ServerEndpoint
标记这个是一个服务端的端点类
@OnOpen
标记此方法是建立websocket连接时的回调方法
@OnMessage
标记此方法是接收到客户端消息时的回调方法
@OnClose
标记此方法是断开websocket连接时的回调方法
@OnError
标记此方法是websocket发生异常时的回调方法
@PathParam
可以获取@ServerEndpoint
注解中绑定的路径模板参数
方法参数说明
1) onOpen方法参数
onOpen的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnOpenArgs
可以看到
java
// Excerpt from Tomcat's PojoMethodMapping: builds the argument array used to invoke the
// @OnOpen method. Only path parameters, the Session and the EndpointConfig can be supplied;
// the Throwable and CloseReason slots of buildArgs are passed as null.
public Object[] getOnOpenArgs(Map<String,String> pathParameters,
Session session, EndpointConfig config) throws DecodeException {
return buildArgs(onOpenParams, pathParameters, session, config, null,
null);
}
因此可以看出@OnOpen
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)当前endpoint的配置详情EndpointConfig参数
2) onClose方法参数
onClose的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnCloseArgs
可以看到
java
// Excerpt from Tomcat's PojoMethodMapping: builds the argument array used to invoke the
// @OnClose method. Path parameters, the Session and the CloseReason can be supplied;
// the EndpointConfig and Throwable slots of buildArgs are passed as null.
public Object[] getOnCloseArgs(Map<String,String> pathParameters,
Session session, CloseReason closeReason) throws DecodeException {
return buildArgs(onCloseParams, pathParameters, session, null, null,
closeReason);
}
因此可以看出@OnClose
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)当前连接关闭的原因CloseReason参数
3) onError方法参数
onError的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnErrorArgs
可以看到
java
// Excerpt from Tomcat's PojoMethodMapping: builds the argument array used to invoke the
// @OnError method. Path parameters, the Session and the Throwable can be supplied;
// the EndpointConfig and CloseReason slots of buildArgs are passed as null.
public Object[] getOnErrorArgs(Map<String,String> pathParameters,
Session session, Throwable throwable) throws DecodeException {
return buildArgs(onErrorParams, pathParameters, session, null,
throwable, null);
}
因此可以看出@OnError
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)发生异常的异常对象Throwable参数
4) onMessage方法参数
onMessage的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping.MessageHandlerInfo#getMessageHandlers
可以看到
java
// Excerpt from Tomcat's PojoMethodMapping.MessageHandlerInfo: creates the MessageHandler(s)
// that invoke one @OnMessage method. The fields m, indexString, indexReader, indexByteArray,
// indexByteBuffer, indexInputStream, indexPong, indexBoolean, indexSession, decoderMatch, ...
// are set elsewhere (not shown); an index of -1 means the method has no parameter of that kind.
public Set<MessageHandler> getMessageHandlers(Object pojo,
Map<String,String> pathParameters, Session session,
EndpointConfig config) {
// Pre-fill the @PathParam arguments at their parameter positions. If a value cannot be
// converted, the whole array is replaced by a single DecodeException and the loop stops.
Object[] params = new Object[m.getParameterTypes().length];
for (Map.Entry<Integer,PojoPathParam> entry :
indexPathParams.entrySet()) {
PojoPathParam pathParam = entry.getValue();
String valueString = pathParameters.get(pathParam.getName());
Object value = null;
try {
value = Util.coerceToType(pathParam.getType(), valueString);
} catch (Exception e) {
DecodeException de = new DecodeException(valueString,
sm.getString(
"pojoMethodMapping.decodePathParamFail",
valueString, pathParam.getType()), e);
params = new Object[] { de };
break;
}
params[entry.getKey().intValue()] = value;
}
// At most two handlers are created (a binary and a text decoder handler).
Set<MessageHandler> results = new HashSet<>(2);
if (indexBoolean == -1) {
// Basic (whole messages): the method has no boolean "last" parameter.
// Exactly one branch below is taken, so only one payload parameter kind is used.
if (indexString != -1 || indexPrimitive != -1) {
// Text payload as String or as a primitive type.
MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
session, config, null, params, indexPayload, false,
indexSession, maxMessageSize);
results.add(mh);
} else if (indexReader != -1) {
// Text payload as a Reader.
MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
session, config, null, params, indexReader, true,
indexSession, maxMessageSize);
results.add(mh);
} else if (indexByteArray != -1) {
// Binary payload as byte[].
MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
m, session, config, null, params, indexByteArray,
true, indexSession, false, maxMessageSize);
results.add(mh);
} else if (indexByteBuffer != -1) {
// Binary payload as ByteBuffer.
MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
m, session, config, null, params, indexByteBuffer,
false, indexSession, false, maxMessageSize);
results.add(mh);
} else if (indexInputStream != -1) {
// Binary payload as InputStream.
MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
m, session, config, null, params, indexInputStream,
true, indexSession, true, maxMessageSize);
results.add(mh);
} else if (decoderMatch != null && decoderMatch.hasMatches()) {
// Payload type handled by configured decoders: one handler per decoder kind.
if (decoderMatch.getBinaryDecoders().size() > 0) {
MessageHandler mh = new PojoMessageHandlerWholeBinary(
pojo, m, session, config,
decoderMatch.getBinaryDecoders(), params,
indexPayload, true, indexSession, true,
maxMessageSize);
results.add(mh);
}
if (decoderMatch.getTextDecoders().size() > 0) {
MessageHandler mh = new PojoMessageHandlerWholeText(
pojo, m, session, config,
decoderMatch.getTextDecoders(), params,
indexPayload, true, indexSession, maxMessageSize);
results.add(mh);
}
} else {
// No text/binary payload parameter: the method handles PongMessage frames.
MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m,
session, params, indexPong, false, indexSession);
results.add(mh);
}
} else {
// ASync (partial messages): the method has a boolean "last" parameter.
if (indexString != -1) {
// Partial text as String.
MessageHandler mh = new PojoMessageHandlerPartialText(pojo,
m, session, params, indexString, false,
indexBoolean, indexSession, maxMessageSize);
results.add(mh);
} else if (indexByteArray != -1) {
// Partial binary as byte[].
MessageHandler mh = new PojoMessageHandlerPartialBinary(
pojo, m, session, params, indexByteArray, true,
indexBoolean, indexSession, maxMessageSize);
results.add(mh);
} else {
// Otherwise partial binary as ByteBuffer.
MessageHandler mh = new PojoMessageHandlerPartialBinary(
pojo, m, session, params, indexByteBuffer, false,
indexBoolean, indexSession, maxMessageSize);
results.add(mh);
}
}
return results;
}
// End of the enclosing MessageHandlerInfo class (its header is not part of this excerpt).
}
因此可以看出@OnMessage
所标记方法的合法参数有
(1)@PathParam标记的路径参数
(2)当前会话Session参数
(3)当数据是分块(partial)传输时,表示当前分块是否是最后一块数据的boolean参数
(4)字符输入流Reader参数
(5)二进制输入流InputStream参数
(6)原始的ByteBuffer参数
(7)字节数组byte[]参数
(8)字符串String参数
(9) Pong响应PongMessage参数
注意:接收消息内容的参数(4)~(9)在同一个方法中只能使用其中一个(从上面的getMessageHandlers也可以看出,每个方法只会按其中一种参数类型创建处理器);同时声明多个时,Tomcat在部署端点时就会报错。
ping和pong
上面的代码中我额外给websocket会话增加了一个PongMessage
的处理方法onPong
,它的作用是接收客户端返回的pong回执消息。客户端通常只会在收到服务端发送的Ping之后才回复Pong,因此服务端需要主动发送Ping请求才能收到Pong响应。这里的ping和pong类似于其他系统中的心跳机制,用来检测客户端、服务端双方是否还在线:如果连接在限定时间内没有任何数据报文(包括ping和pong消息)往来,就会因空闲超时被服务端断开。
因此我在建立websocket连接的时候给当前会话设置了最大空闲时间(超过这个时间没有数据报文传输,此连接就会自动断开),同时绑定了一个定时任务,这个定时任务会定时发送ping消息来保活。
这里的onPong方法不是必须的,没有它也能保活,onPong只是用来得到一个ping结果的通知。
3)注册暴露端点
java
/**
 * Spring configuration that publishes the annotated (JSR-356) WebSocket endpoints.
 */
@Configuration
@EnableWebSocket
public class WebsocketConfig {

    /**
     * Exporter that registers {@code @ServerEndpoint} classes with the WebSocket container.
     *
     * <p>Endpoint classes that are Spring beans (e.g. annotated with {@code @Component}) are
     * detected automatically, so nothing else needs to be configured here. An endpoint that is a
     * plain {@code @ServerEndpoint} class (not a Spring bean) would instead have to be listed
     * explicitly, e.g. {@code exporter.setAnnotatedEndpointClasses(WebsocketHandler2.class)}.
     *
     * @return the exporter bean
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        return new ServerEndpointExporter();
    }
}
配置类添加@EnableWebSocket
,启用spring websocket功能.
另外还需配置一个Bean ServerEndpointExporter
;如果Endpoint类是一个spring bean(即有@Component
),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测含有@ServerEndpoint
注解的Bean;如果Endpoint类只是一个包含@ServerEndpoint
注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型。
注意:即使Endpoint类是spring bean,WebSocket容器(ServerContainer)也会另外创建并使用这个类的新实例,这些实例不受Spring管理,也就是说这个Endpoint中不能使用spring相关的功能,典型的就是不能使用@Autowired等注解自动注入Bean。其原因是websocket的默认端点配置器org.apache.tomcat.websocket.server.DefaultServerEndpointConfigurator获取endpoint实例的逻辑是通过反射调用无参构造方法去创建一个新对象:
java
// Excerpt from Tomcat: the default configurator obtains endpoint instances by calling the
// public no-arg constructor via reflection, so the object it returns is always a new
// instance rather than a Spring-managed bean (hence no dependency injection).
public class DefaultServerEndpointConfigurator
extends ServerEndpointConfig.Configurator {
@Override
public <T> T getEndpointInstance(Class<T> clazz)
throws InstantiationException {
try {
return clazz.getConstructor().newInstance();
} catch (InstantiationException e) {
throw e;
} catch (ReflectiveOperationException e) {
// Any other reflection failure (e.g. no public no-arg constructor) is wrapped in the
// InstantiationException declared by the API, keeping the original as the cause.
InstantiationException ie = new InstantiationException();
ie.initCause(e);
throw ie;
}
}
// (The remaining members and the closing brace of the class are omitted in this excerpt.)
当然你可以通过注入静态属性的方式来绕过这个限制:例如把依赖声明为静态字段,再在一个非静态的setter方法上加@Autowired,由Spring管理的那个bean实例在注入时给静态字段赋值,之后容器创建的新实例也能通过静态字段使用这个依赖。
理论上说也可在@ServerEndpoint
注解的configurator属性指定为spring的org.springframework.web.socket.server.standard.SpringConfigurator
也可以自动注入Bean依赖.
java
/**
 * Sample endpoint whose instances are requested from Spring through
 * {@link SpringConfigurator} instead of being constructed by the container, so that the
 * endpoint can be a Spring bean with injected dependencies.
 *
 * <p>Caveat: SpringConfigurator needs a root WebApplicationContext. In a Spring Boot
 * application with an embedded container that context may not be available, and the
 * instance lookup then fails.
 */
@ServerEndpoint(configurator = SpringConfigurator.class, value = "/echo")
public class EchoEndpoint {
    // Lifecycle callbacks (@OnOpen, @OnMessage, @OnClose, @OnError) go here.
}
SpringConfigurator它重写了获取Endpoint实例的方法逻辑getEndpointInstance
,它是直接到spring容器中去取这个bean,而不是创建一个新实例.
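但实际在spring boot项目中,上面的getEndpointInstance方法获取到的WebApplicationContext是null:SpringConfigurator是通过ContextLoader.getCurrentWebApplicationContext()获取容器的,而Spring Boot使用内嵌容器启动,并不通过ContextLoaderListener初始化根上下文,因此这里取不到WebApplicationContext,也就没法从spring容器中获取这个Endpoint bean。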
基于spring WebSocketHandler实现websocket
了解WebSocketHandler
提前引入前面提到的websocket的starter 依赖
WebSocketHandler接口定义了5个方法,
afterConnectionEstablished
:建立连接后的回调方法
handleMessage
:接收到客户端消息后的回调方法
handleTransportError
: 数据传输异常时的回调方法
afterConnectionClosed
: 连接关闭后的回调方法
supportsPartialMessages
: 是否支持数据分块传输(分块传输时,可通过消息的isLast()判断当前分块是否为最后一块,最后一块为true)
它有两个主要的实现类, 一个是处理纯文本数据的TextWebSocketHandler
,另一个是处理二进制数据的BinaryWebSocketHandler
。
我们实现websocket一般是继承这两个类,并重写相应的方法。一般都需要重写afterConnectionEstablished、handleTransportError、afterConnectionClosed这三个方法,除此之外,处理文本还要重写接收客户端消息后的回调方法handleTextMessage
,处理二进制数据需要重写接收客户端消息后的回调方法handleBinaryMessage
。如果有需要得到ping结果回调,还可以重写handlePongMessage
方法
代码
java
@Component
public class WebsocketHandler1 extends TextWebSocketHandler {
private final Logger log = LoggerFactory.getLogger(getClass());
private static final Set<WebSocketSession> sessions = new ConcurrentSkipListSet<>(Comparator.comparing(WebSocketSession::getId));
private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);
private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
@SuppressWarnings("unchecked")
AbstractWebSocketSession<Session> standardSession = (AbstractWebSocketSession) session;
Session nativeSession = standardSession.getNativeSession();
nativeSession.setMaxIdleTimeout(1000*4);
ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);
futures.put(session.getId(), future);
log.info("open connect sessionId={}", session.getId());
sessions.add(session);
TextMessage msg = new TextMessage(String.format("ws client(id=%s) has connected", session.getId()));
session.sendMessage(msg);
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
log.info("receive client(id={}) msg=>{}", session.getId(), message.getPayload());
TextMessage msg = new TextMessage(String.format("reply your(id=%s) msg=>%s", session.getId(), message.getPayload()));
session.sendMessage(msg);
}
@Override
protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
ByteBuffer payload = message.getPayload();
String s = new String(payload.array());
log.info("receive client(id={}) pong msg=>{}", session.getId(),s);
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
log.error("client(id={}) error occur ", session.getId(), exception);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
log.info("close ,status={}", status);
sessions.remove(session);
ScheduledFuture<?> future = futures.get(session.getId());
if (future != null) {
future.cancel(true);
}
}
@Override
public boolean supportsPartialMessages() {
return true;
}
private void sendPing(WebSocketSession session) {
if (session.isOpen()) {
String replyContent = String.format("Hello,client(id=%s)", session.getId());
PingMessage msg = new PingMessage(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));
try {
session.sendMessage(msg);
} catch (IOException e) {
log.error("ping client(id={}) error", session.getId(), e);
}
}
}
}
springboot内置的WebSocketHandler
前端html代码
html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>WebSocket Chat</title>
<style>
body {
font-family: Arial, sans-serif;
}
#chat-box {
width: 100%;
height: 300px;
border: 1px solid #ccc;
overflow-y: auto;
padding: 10px;
margin-bottom: 10px;
background-color: #f9f9f9;
white-space: pre-wrap;
}
#input-box {
width: calc(50% - 90px);
padding: 10px;
margin-right: 10px;
display: flex;
justify-content: center
}
.btn {
padding: 10px;
}
#btn-container {
margin: 1px;
display: flex;
justify-content: center;
gap: 5px;
}
#input-container {
margin: 1px;
display: flex;
justify-content: center;
gap: 5px;
}
</style>
</head>
<body>
<div id="chat-box"></div>
<div id="input-container">
<input type="text" id="input-box" placeholder="Enter your message"/>
</div>
<div id="btn-container">
<button id="connect-button" class="btn">Connect</button>
<button id="close-button" class="btn">Close</button>
<button id="clear-button" class="btn">Clear</button>
<button id="send-button" class="btn">Send</button>
</div>
<script>
// Chat page logic: manages a single WebSocket connection and renders the traffic into #chat-box.
// Lines are appended as text nodes (never via innerHTML), so text echoed back by the server
// cannot inject markup or scripts into the page; #chat-box uses white-space: pre-wrap,
// so the '\n' separators still render as line breaks.
const WS_URL = 'ws://localhost:7001/ws/Hews2df?id=323&color=red';
const chatBox = document.getElementById('chat-box');
const inputBox = document.getElementById('input-box');
const sendButton = document.getElementById('send-button');
const connectBtn = document.getElementById('connect-button');
const closeBtn = document.getElementById('close-button');
const clearBtn = document.getElementById('clear-button');
// Current socket: set as soon as a connection attempt starts, reset to null when it closes.
let ws = null;

// Appends one line of plain text to the chat box and scrolls to the bottom.
function appendLine(text) {
    chatBox.append(text + '\n');
    chatBox.scrollTop = chatBox.scrollHeight;
}

sendButton.addEventListener('click', () => {
    // send() throws while the socket is still CONNECTING, so require OPEN.
    if (ws === null || ws.readyState !== WebSocket.OPEN) {
        alert("no connect");
        return;
    }
    const message = inputBox.value;
    if (message) {
        ws.send(message);
        appendLine('You: ' + message);
        inputBox.value = '';
    }
});

clearBtn.addEventListener('click', () => {
    chatBox.textContent = '';
});

closeBtn.addEventListener('click', () => {
    if (ws === null) {
        alert("no connect");
        return;
    }
    console.log("prepare close ws");
    ws.close(1000, 'Normal closure');
});

connectBtn.addEventListener('click', () => {
    if (ws !== null) {
        alert("already connected!");
        return;
    }
    const curWs = new WebSocket(WS_URL);
    // Track the socket immediately (not in onopen) so a second click while the handshake
    // is still in progress does not open a duplicate connection.
    ws = curWs;
    curWs.onopen = event => {
        console.log('Connected to WebSocket server, event=>%s', JSON.stringify(event));
    };
    curWs.onmessage = event => {
        appendLine('Server: ' + event.data);
    };
    curWs.onclose = event => {
        if (ws === curWs) {
            ws = null;
        }
        console.log('Disconnected from WebSocket server, close code=%s,close reason=%s', event.code, event.reason);
    };
    curWs.onerror = event => {
        console.log("error occur, event=>%s", JSON.stringify(event));
    };
});
</script>
</body>
</html>
演示效果