由浅入深逐步理解spring boot中如何实现websocket

实现websocket的方式

1.springboot中有两种方式实现websocket,一种是基于原生的基于注解的websocket,另一种是基于spring封装后的WebSocketHandler

基于原生注解实现websocket

1)先引入websocket的starter坐标

xml 复制代码
 	   <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

2)编写websocket的Endpoint端点类

java 复制代码
@ServerEndpoint(value = "/ws/{token}")
@Component
public class WebsocketHandler2 {
    private final static Logger log = LoggerFactory.getLogger(WebsocketHandler2.class);

    private static final Set<Session> SESSIONS = new ConcurrentSkipListSet<>(Comparator.comparing(Session::getId));

    private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);
    private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();

    @OnOpen
    public void onOpen(Session session, @PathParam("token") String token, EndpointConfig config) throws IOException {
//        session.addMessageHandler(new PongHandler());
        ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);
        String queryString = session.getQueryString();

        futures.put(session.getId(), future);
        session.setMaxIdleTimeout(6 * 1000);

        SESSIONS.add(session);

        log.info("open connect sessionId={}, token={}, queryParam={}", session.getId(), token, queryString);
        String s = String.format("ws client(id=%s) has connected", session.getId());
        session.getBasicRemote().sendText(s);

    }

    static class PongHandler implements MessageHandler.Whole<PongMessage> {

        @Override
        public void onMessage(PongMessage message) {
            ByteBuffer data = message.getApplicationData();
            String s = new String(data.array(), StandardCharsets.UTF_8);
            log.info("receive pong msg=> {}", s);

        }
    }

    @OnClose
    public void onClose(Session session, CloseReason reason) {
        log.info("session(id={}) close ,closeCode={},closeParse={}", session.getId(), reason.getCloseCode(), reason.getReasonPhrase());
        SESSIONS.remove(session);
        ScheduledFuture<?> future = futures.get(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }

    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        log.info("receive client(id={}) msg=>{}", session.getId(), message);
        String s = String.format("reply your(id=%s) msg=>【%s】", session.getId(), message);
        session.getBasicRemote().sendText(s);
    }


    @OnMessage
    public void onPong(PongMessage message, Session session) throws IOException {
        ByteBuffer data = message.getApplicationData();
        String s = new String(data.array(), StandardCharsets.UTF_8);
        log.info("receive client(id={}) pong msg=> {}", session.getId(), s);
    }

    @OnError
    public void onError(Session session, Throwable error) {
        log.error("Session(id={}) error occur ", session.getId(), error);
    }

    private void sendPing(Session session) {
        if (session.isOpen()) {
            String replyContent = String.format("Hello,client(id=%s)", session.getId());
            try {
                session.getBasicRemote().sendPing(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));
            } catch (IOException e) {
                log.error("ping client(id={}) error", session.getId(), e);
            }
            return;
        }

        SESSIONS.remove(session);
        ScheduledFuture<?> future = futures.remove(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }
}

注解说明

@ServerEndpoint标记这个是一个服务端的端点类
@OnOpen 标记此方法是建立websocket连接时的回调方法
@OnMessage 标记此方法是接收到客户端消息时的回调方法
@OnClose标记此方法是断开websocke连接时的回调方法
@OnError标记此方法是websocke发生异常时的回调方法
@PathParam可以获取@ServerEndpoint注解中绑定的路径模板参数

方法参数说明

1) onOpen方法参数

onOpen的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnOpenArgs可以看到

java 复制代码
    public Object[] getOnOpenArgs(Map<String,String> pathParameters,
            Session session, EndpointConfig config) throws DecodeException {
        return buildArgs(onOpenParams, pathParameters, session, config, null,
                null);
    }

因此可以看出@OnOpen所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)当前endpoint的配置详情EndpointConfig参数

2) onClose方法参数

onClose的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnCloseArgs可以看到

java 复制代码
    public Object[] getOnCloseArgs(Map<String,String> pathParameters,
            Session session, CloseReason closeReason) throws DecodeException {
        return buildArgs(onCloseParams, pathParameters, session, null, null,
                closeReason);
    }

因此可以看出@OnClose所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)当前连接关闭的原因CloseReason参数

3) onError方法参数

onError的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnErrorArgs可以看到

java 复制代码
    public Object[] getOnErrorArgs(Map<String,String> pathParameters,
            Session session, Throwable throwable) throws DecodeException {
        return buildArgs(onErrorParams, pathParameters, session, null,
                throwable, null);
    }

因此可以看出@OnError所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)发生异常的异常对象Throwable参数

4) onMessage方法参数

onMessage的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping.MessageHandlerInfo#getMessageHandlers可以看到

java 复制代码
        public Set<MessageHandler> getMessageHandlers(Object pojo,
                Map<String,String> pathParameters, Session session,
                EndpointConfig config) {
            Object[] params = new Object[m.getParameterTypes().length];

            for (Map.Entry<Integer,PojoPathParam> entry :
                    indexPathParams.entrySet()) {
                PojoPathParam pathParam = entry.getValue();
                String valueString = pathParameters.get(pathParam.getName());
                Object value = null;
                try {
                    value = Util.coerceToType(pathParam.getType(), valueString);
                } catch (Exception e) {
                    DecodeException de =  new DecodeException(valueString,
                            sm.getString(
                                    "pojoMethodMapping.decodePathParamFail",
                                    valueString, pathParam.getType()), e);
                    params = new Object[] { de };
                    break;
                }
                params[entry.getKey().intValue()] = value;
            }

            Set<MessageHandler> results = new HashSet<>(2);
            if (indexBoolean == -1) {
                // Basic
                if (indexString != -1 || indexPrimitive != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
                            session, config, null, params, indexPayload, false,
                            indexSession, maxMessageSize);
                    results.add(mh);
                } else if (indexReader != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
                            session, config, null, params, indexReader, true,
                            indexSession, maxMessageSize);
                    results.add(mh);
                } else if (indexByteArray != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
                            m, session, config, null, params, indexByteArray,
                            true, indexSession, false, maxMessageSize);
                    results.add(mh);
                } else if (indexByteBuffer != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
                            m, session, config, null, params, indexByteBuffer,
                            false, indexSession, false, maxMessageSize);
                    results.add(mh);
                } else if (indexInputStream != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
                            m, session, config, null, params, indexInputStream,
                            true, indexSession, true, maxMessageSize);
                    results.add(mh);
                } else if (decoderMatch != null && decoderMatch.hasMatches()) {
                    if (decoderMatch.getBinaryDecoders().size() > 0) {
                        MessageHandler mh = new PojoMessageHandlerWholeBinary(
                                pojo, m, session, config,
                                decoderMatch.getBinaryDecoders(), params,
                                indexPayload, true, indexSession, true,
                                maxMessageSize);
                        results.add(mh);
                    }
                    if (decoderMatch.getTextDecoders().size() > 0) {
                        MessageHandler mh = new PojoMessageHandlerWholeText(
                                pojo, m, session, config,
                                decoderMatch.getTextDecoders(), params,
                                indexPayload, true, indexSession, maxMessageSize);
                        results.add(mh);
                    }
                } else {
                    MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m,
                            session, params, indexPong, false, indexSession);
                    results.add(mh);
                }
            } else {
                // ASync
                if (indexString != -1) {
                    MessageHandler mh = new PojoMessageHandlerPartialText(pojo,
                            m, session, params, indexString, false,
                            indexBoolean, indexSession, maxMessageSize);
                    results.add(mh);
                } else if (indexByteArray != -1) {
                    MessageHandler mh = new PojoMessageHandlerPartialBinary(
                            pojo, m, session, params, indexByteArray, true,
                            indexBoolean, indexSession, maxMessageSize);
                    results.add(mh);
                } else {
                    MessageHandler mh = new PojoMessageHandlerPartialBinary(
                            pojo, m, session, params, indexByteBuffer, false,
                            indexBoolean, indexSession, maxMessageSize);
                    results.add(mh);
                }
            }
            return results;
        }
    }

因此可以看出@OnMessage所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)当数据是分块传输时,表示当前消息时是否是最后一块数据的boolean Boolean参数

(4)字符输入流Reader参数

(5)二进制输入流InputStream参数

(6)原始的ByteBuffer参数

(7)字节数组byte[]参数

(8)字符串string参数

(9) Pong响应PongMessage参数

注意:接收数据报文的参数(4)~(5),只能使用其中的一个,否则可能导致IO异常(IO流只能读取一次)

ping和pong

上面的代码中我额外给websocket会话增加了一个PongMessage的处理方法onPong,它的作用是接收客户端的pong回执消息。只有在服务端向客户端发送Ping请求时,服务端才能接收到Pong响应。这里的ping和pong就是类型于其他系统中的心跳机制,用来检测客户端、服务端双方是否还在线,如果超过了限定时间没有收到pingpong消息,服务端就会主动断开连接。

因此我在建立websocke连接的时候给当前回话设置了最大空闲时间(超过这个时间没有数据报文传输,此连接就会自动断开),同时绑定了一个定时任务,这个定时任务会定时发送ping消息来保活。

这里的onPong方法不是必须的,没有它能保活,onPong只是用来得到一个ping结果的通知。

3)注册暴露端点

java 复制代码
@Configuration
@EnableWebSocket
public class WebsocketConfig  {
    

    @Bean
    public ServerEndpointExporter serverEndpointExporter(){
        ServerEndpointExporter serverEndpointExporter = new ServerEndpointExporter();
        //WebsocketHandler2如果是一个spring bean(即有@Component),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测有@ServerEndpoint注解的bean
        //WebsocketHandler2如果只是一个包含@ServerEndpoint注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型
//        serverEndpointExporter.setAnnotatedEndpointClasses(WebsocketHandler2.class );
        return serverEndpointExporter;
    }
}    

配置类添加@EnableWebSocket,启用spring websocket功能.

另外还需配置一个Bean ServerEndpointExporter ;如果Endpoint类是一个spring bean(即有@Component),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测含有@ServerEndpoint注解的Bean;如果Endpoint类只是一个包含@ServerEndpoint注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型。

注意:即使Endpoint类是spring bean ,WebsocketContainer也会再创建并使用这个类的一个新实例,也就是说这个Endpoint中不能使用spring相关的功能,典型的就是不能使用@Autowire等注解自动注入Bean。其原因是websocket的默认端点配置org.apache.tomcat.websocket.server.DefaultServerEndpointConfigurator获取endpoint实例的逻辑是反射调用构造方法去创建一个新对象

java 复制代码
public class DefaultServerEndpointConfigurator
        extends ServerEndpointConfig.Configurator {

    @Override
    public <T> T getEndpointInstance(Class<T> clazz)
            throws InstantiationException {
        try {
            return clazz.getConstructor().newInstance();
        } catch (InstantiationException e) {
            throw e;
        } catch (ReflectiveOperationException e) {
            InstantiationException ie = new InstantiationException();
            ie.initCause(e);
            throw ie;
        }
    }

当然你可以通过注入静态属性的方式来绕过这个限制。

理论上说也可在@ServerEndpoint注解的configurator属性指定为spring的org.springframework.web.socket.server.standard.SpringConfigurator也可以自动注入Bean依赖.

java 复制代码
 @ServerEndpoint(value = "/echo", configurator = SpringConfigurator.class)
 public class EchoEndpoint {
       // ...
   }

SpringConfigurator它重写了获取Endpoint实例的方法逻辑getEndpointInstance,它是直接到spring容器中去取这个bean,而不是创建一个新实例.

但实际在spring boot项目中,上面的getEndpointInstance方法获取到的WebApplicationContextnull,也就没法从spring容器中获取这个Endpoint bean

基于spring WebSocketHandler实现websocket

了解WebSocketHandler

提前引入前面提到的websocket的starter 依赖

WebSocketHandler接口定义了5个方法,
afterConnectionEstablished:建立连接后的回调方法
handleMessage:接收到客户端消息后的回调方法
handleTransportError: 数据传输异常时的回调方法
afterConnectionClosed: 连接关闭后的回调方法
supportsPartialMessages: 是否支持数据分块传输(最后一个分块传输,isLast是true)

它有两个主要的子类, 一个是处理纯文本数据的TextWebSocketHandler ,另一个是处理二进制数据的BinaryWebSocketHandler

我们实现websocket一般是继承这两个类,并重写相应的方法。一般都需要重写afterConnectionEstablished handleTransportError
handleTransportError afterConnectionClosed 这三个方法,除此之外,处理文本还要重写接收客户端消息后的回调方法handleTextMessage,处理二进制数据需要重写接收客户端消息后的回调方法handleBinaryMessage。如果有需要得到ping结果回调,还可以重写handlePongMessage方法

代码

java 复制代码
@Component
public class WebsocketHandler1 extends TextWebSocketHandler {
    private final Logger log = LoggerFactory.getLogger(getClass());
    private static final Set<WebSocketSession> sessions = new ConcurrentSkipListSet<>(Comparator.comparing(WebSocketSession::getId));
    private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);
    private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();


    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        @SuppressWarnings("unchecked")
        AbstractWebSocketSession<Session> standardSession = (AbstractWebSocketSession) session;
        Session nativeSession = standardSession.getNativeSession();
        nativeSession.setMaxIdleTimeout(1000*4);

        ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);
        futures.put(session.getId(), future);



        log.info("open connect sessionId={}", session.getId());
        sessions.add(session);
        TextMessage msg = new TextMessage(String.format("ws client(id=%s) has connected", session.getId()));
        session.sendMessage(msg);
    }


    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        log.info("receive client(id={}) msg=>{}", session.getId(), message.getPayload());
        TextMessage msg = new TextMessage(String.format("reply your(id=%s) msg=>%s", session.getId(), message.getPayload()));
        session.sendMessage(msg);
    }

    @Override
    protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
        ByteBuffer payload = message.getPayload();
        String s = new String(payload.array());
        log.info("receive client(id={}) pong msg=>{}", session.getId(),s);
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        log.error("client(id={}) error occur ", session.getId(), exception);

    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        log.info("close ,status={}", status);
        sessions.remove(session);
        ScheduledFuture<?> future = futures.get(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return true;
    }

    private void sendPing(WebSocketSession session) {
        if (session.isOpen()) {
            String replyContent = String.format("Hello,client(id=%s)", session.getId());

            PingMessage msg = new PingMessage(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));
            try {
                session.sendMessage(msg);
            } catch (IOException e) {
                log.error("ping client(id={}) error", session.getId(), e);
            }

        }
    }
}

springboot内置的WebSocketHandler

前端html代码

html 复制代码
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>WebSocket Chat</title>
    <style>
        body {
            font-family: Arial, sans-serif;
        }

        #chat-box {
            width: 100%;
            height: 300px;
            border: 1px solid #ccc;
            overflow-y: auto;
            padding: 10px;
            margin-bottom: 10px;
            background-color: #f9f9f9;
            white-space: pre-wrap;
        }

        #input-box {
            width: calc(50% - 90px);
            padding: 10px;
            margin-right: 10px;
            display: flex;
            justify-content: center
        }

        .btn {
            padding: 10px;

        }

        #btn-container {
            margin: 1px;
            display: flex;
            justify-content: center;
            gap: 5px;
        }

        #input-container {
            margin: 1px;
            display: flex;
            justify-content: center;
            gap: 5px;
        }
    </style>
</head>
<body>

<div id="chat-box"></div>
<div id="input-container">
    <input type="text" id="input-box" placeholder="Enter your message"/>
</div>

<div id="btn-container">
    <button id="connect-button" class="btn">Connect</button>
    <button id="close-button" class="btn">Close</button>
    <button id="clear-button" class="btn">Clear</button>
    <button id="send-button" class="btn">Send</button>
</div>


<script>
    const chatBox = document.getElementById('chat-box');
    const inputBox = document.getElementById('input-box');
    const sendButton = document.getElementById('send-button');
    const connectBtn = document.getElementById('connect-button');
    const closeBtn = document.getElementById('close-button');
    const clearBtn = document.getElementById('clear-button');


    let ws = null;

    sendButton.addEventListener('click', () => {
        if (ws === null) {
            alert("no connect")
            return;
        }
        const message = inputBox.value;
        if (message) {
            ws.send(message);

            chatBox.innerHTML += 'You: ' + message + '\n';
            chatBox.scrollTop = chatBox.scrollHeight;
            inputBox.value = '';
        }
    });

    clearBtn.addEventListener('click', () => {
        chatBox.innerHTML = '';
    });

    closeBtn.addEventListener('click', () => {
        if (ws === null) {
            alert("no connect")
            return;
        }
        console.log("prepare close ws");
        ws.close(1000, 'Normal closure');
    });

    connectBtn.addEventListener('click', () => {
        if (ws !== null) {
            alert("already connected!")
            return;
        }
        let curWs = new WebSocket('ws://localhost:7001/ws/Hews2df?id=323&color=red');

        curWs.onopen = event => {
            ws = curWs;
            console.log('Connected to WebSocket server, event=>%s', JSON.stringify(event));
        };

        curWs.onmessage = event => {
            const message = event.data;
            chatBox.innerHTML += 'Server: ' + message + '\n';
            chatBox.scrollTop = chatBox.scrollHeight;
        };

        curWs.onclose = event => {
            ws = null;
            console.log('Disconnected from WebSocket server, close code=%s,close reason=%s', event.code, event.reason);
        };
        curWs.onerror = event => {
            console.log("error occur, event=>%s", JSON.stringify(event))

        };
    });

</script>

</body>
</html>

演示效果

相关推荐
向前看-3 小时前
验证码机制
前端·后端
黄油饼卷咖喱鸡就味增汤拌孜然羊肉炒饭4 小时前
SpringBoot如何实现缓存预热?
java·spring boot·spring·缓存·程序员
超爱吃士力架4 小时前
邀请逻辑
java·linux·后端
AskHarries6 小时前
Spring Cloud OpenFeign快速入门demo
spring boot·后端
isolusion7 小时前
Springboot的创建方式
java·spring boot·后端
Yvemil78 小时前
《开启微服务之旅:Spring Boot Web开发举例》(一)
前端·spring boot·微服务
zjw_rp8 小时前
Spring-AOP
java·后端·spring·spring-aop
TodoCoder8 小时前
【编程思想】CopyOnWrite是如何解决高并发场景中的读写瓶颈?
java·后端·面试
凌虚9 小时前
Kubernetes APF(API 优先级和公平调度)简介
后端·程序员·kubernetes
星河梦瑾9 小时前
SpringBoot相关漏洞学习资料
java·经验分享·spring boot·安全