由浅入深逐步理解spring boot中如何实现websocket

实现websocket的方式

1.springboot中有两种方式实现websocket,一种是基于原生的基于注解的websocket,另一种是基于spring封装后的WebSocketHandler

基于原生注解实现websocket

1)先引入websocket的starter坐标

xml 复制代码
 	   <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-websocket</artifactId>
        </dependency>

2)编写websocket的Endpoint端点类

java 复制代码
@ServerEndpoint(value = "/ws/{token}")
@Component
public class WebsocketHandler2 {
    private final static Logger log = LoggerFactory.getLogger(WebsocketHandler2.class);

    private static final Set<Session> SESSIONS = new ConcurrentSkipListSet<>(Comparator.comparing(Session::getId));

    private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);
    private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();

    @OnOpen
    public void onOpen(Session session, @PathParam("token") String token, EndpointConfig config) throws IOException {
//        session.addMessageHandler(new PongHandler());
        ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);
        String queryString = session.getQueryString();

        futures.put(session.getId(), future);
        session.setMaxIdleTimeout(6 * 1000);

        SESSIONS.add(session);

        log.info("open connect sessionId={}, token={}, queryParam={}", session.getId(), token, queryString);
        String s = String.format("ws client(id=%s) has connected", session.getId());
        session.getBasicRemote().sendText(s);

    }

    static class PongHandler implements MessageHandler.Whole<PongMessage> {

        @Override
        public void onMessage(PongMessage message) {
            ByteBuffer data = message.getApplicationData();
            String s = new String(data.array(), StandardCharsets.UTF_8);
            log.info("receive pong msg=> {}", s);

        }
    }

    @OnClose
    public void onClose(Session session, CloseReason reason) {
        log.info("session(id={}) close ,closeCode={},closeParse={}", session.getId(), reason.getCloseCode(), reason.getReasonPhrase());
        SESSIONS.remove(session);
        ScheduledFuture<?> future = futures.get(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }

    @OnMessage
    public void onMessage(String message, Session session) throws IOException {
        log.info("receive client(id={}) msg=>{}", session.getId(), message);
        String s = String.format("reply your(id=%s) msg=>【%s】", session.getId(), message);
        session.getBasicRemote().sendText(s);
    }


    @OnMessage
    public void onPong(PongMessage message, Session session) throws IOException {
        ByteBuffer data = message.getApplicationData();
        String s = new String(data.array(), StandardCharsets.UTF_8);
        log.info("receive client(id={}) pong msg=> {}", session.getId(), s);
    }

    @OnError
    public void onError(Session session, Throwable error) {
        log.error("Session(id={}) error occur ", session.getId(), error);
    }

    private void sendPing(Session session) {
        if (session.isOpen()) {
            String replyContent = String.format("Hello,client(id=%s)", session.getId());
            try {
                session.getBasicRemote().sendPing(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));
            } catch (IOException e) {
                log.error("ping client(id={}) error", session.getId(), e);
            }
            return;
        }

        SESSIONS.remove(session);
        ScheduledFuture<?> future = futures.remove(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }
}

注解说明

@ServerEndpoint标记这个是一个服务端的端点类
@OnOpen 标记此方法是建立websocket连接时的回调方法
@OnMessage 标记此方法是接收到客户端消息时的回调方法
@OnClose标记此方法是断开websocke连接时的回调方法
@OnError标记此方法是websocke发生异常时的回调方法
@PathParam可以获取@ServerEndpoint注解中绑定的路径模板参数

方法参数说明

1) onOpen方法参数

onOpen的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnOpenArgs可以看到

java 复制代码
    public Object[] getOnOpenArgs(Map<String,String> pathParameters,
            Session session, EndpointConfig config) throws DecodeException {
        return buildArgs(onOpenParams, pathParameters, session, config, null,
                null);
    }

因此可以看出@OnOpen所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)当前endpoint的配置详情EndpointConfig参数

2) onClose方法参数

onClose的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnCloseArgs可以看到

java 复制代码
    public Object[] getOnCloseArgs(Map<String,String> pathParameters,
            Session session, CloseReason closeReason) throws DecodeException {
        return buildArgs(onCloseParams, pathParameters, session, null, null,
                closeReason);
    }

因此可以看出@OnClose所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)当前连接关闭的原因CloseReason参数

3) onError方法参数

onError的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping#getOnErrorArgs可以看到

java 复制代码
    public Object[] getOnErrorArgs(Map<String,String> pathParameters,
            Session session, Throwable throwable) throws DecodeException {
        return buildArgs(onErrorParams, pathParameters, session, null,
                throwable, null);
    }

因此可以看出@OnError所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)发生异常的异常对象Throwable参数

4) onMessage方法参数

onMessage的可用参数在tomcat源码 org.apache.tomcat.websocket.pojo.PojoMethodMapping.MessageHandlerInfo#getMessageHandlers可以看到

java 复制代码
        public Set<MessageHandler> getMessageHandlers(Object pojo,
                Map<String,String> pathParameters, Session session,
                EndpointConfig config) {
            Object[] params = new Object[m.getParameterTypes().length];

            for (Map.Entry<Integer,PojoPathParam> entry :
                    indexPathParams.entrySet()) {
                PojoPathParam pathParam = entry.getValue();
                String valueString = pathParameters.get(pathParam.getName());
                Object value = null;
                try {
                    value = Util.coerceToType(pathParam.getType(), valueString);
                } catch (Exception e) {
                    DecodeException de =  new DecodeException(valueString,
                            sm.getString(
                                    "pojoMethodMapping.decodePathParamFail",
                                    valueString, pathParam.getType()), e);
                    params = new Object[] { de };
                    break;
                }
                params[entry.getKey().intValue()] = value;
            }

            Set<MessageHandler> results = new HashSet<>(2);
            if (indexBoolean == -1) {
                // Basic
                if (indexString != -1 || indexPrimitive != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
                            session, config, null, params, indexPayload, false,
                            indexSession, maxMessageSize);
                    results.add(mh);
                } else if (indexReader != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
                            session, config, null, params, indexReader, true,
                            indexSession, maxMessageSize);
                    results.add(mh);
                } else if (indexByteArray != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
                            m, session, config, null, params, indexByteArray,
                            true, indexSession, false, maxMessageSize);
                    results.add(mh);
                } else if (indexByteBuffer != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
                            m, session, config, null, params, indexByteBuffer,
                            false, indexSession, false, maxMessageSize);
                    results.add(mh);
                } else if (indexInputStream != -1) {
                    MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
                            m, session, config, null, params, indexInputStream,
                            true, indexSession, true, maxMessageSize);
                    results.add(mh);
                } else if (decoderMatch != null && decoderMatch.hasMatches()) {
                    if (decoderMatch.getBinaryDecoders().size() > 0) {
                        MessageHandler mh = new PojoMessageHandlerWholeBinary(
                                pojo, m, session, config,
                                decoderMatch.getBinaryDecoders(), params,
                                indexPayload, true, indexSession, true,
                                maxMessageSize);
                        results.add(mh);
                    }
                    if (decoderMatch.getTextDecoders().size() > 0) {
                        MessageHandler mh = new PojoMessageHandlerWholeText(
                                pojo, m, session, config,
                                decoderMatch.getTextDecoders(), params,
                                indexPayload, true, indexSession, maxMessageSize);
                        results.add(mh);
                    }
                } else {
                    MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m,
                            session, params, indexPong, false, indexSession);
                    results.add(mh);
                }
            } else {
                // ASync
                if (indexString != -1) {
                    MessageHandler mh = new PojoMessageHandlerPartialText(pojo,
                            m, session, params, indexString, false,
                            indexBoolean, indexSession, maxMessageSize);
                    results.add(mh);
                } else if (indexByteArray != -1) {
                    MessageHandler mh = new PojoMessageHandlerPartialBinary(
                            pojo, m, session, params, indexByteArray, true,
                            indexBoolean, indexSession, maxMessageSize);
                    results.add(mh);
                } else {
                    MessageHandler mh = new PojoMessageHandlerPartialBinary(
                            pojo, m, session, params, indexByteBuffer, false,
                            indexBoolean, indexSession, maxMessageSize);
                    results.add(mh);
                }
            }
            return results;
        }
    }

因此可以看出@OnMessage所标记方法的合法参数有

(1)@PathParam标记的路径参数

(2)当前会话Session参数

(3)当数据是分块传输时,表示当前消息时是否是最后一块数据的boolean Boolean参数

(4)字符输入流Reader参数

(5)二进制输入流InputStream参数

(6)原始的ByteBuffer参数

(7)字节数组byte[]参数

(8)字符串string参数

(9) Pong响应PongMessage参数

注意:接收数据报文的参数(4)~(5),只能使用其中的一个,否则可能导致IO异常(IO流只能读取一次)

ping和pong

上面的代码中我额外给websocket会话增加了一个PongMessage的处理方法onPong,它的作用是接收客户端的pong回执消息。只有在服务端向客户端发送Ping请求时,服务端才能接收到Pong响应。这里的ping和pong就是类型于其他系统中的心跳机制,用来检测客户端、服务端双方是否还在线,如果超过了限定时间没有收到pingpong消息,服务端就会主动断开连接。

因此我在建立websocke连接的时候给当前回话设置了最大空闲时间(超过这个时间没有数据报文传输,此连接就会自动断开),同时绑定了一个定时任务,这个定时任务会定时发送ping消息来保活。

这里的onPong方法不是必须的,没有它能保活,onPong只是用来得到一个ping结果的通知。

3)注册暴露端点

java 复制代码
@Configuration
@EnableWebSocket
public class WebsocketConfig  {
    

    @Bean
    public ServerEndpointExporter serverEndpointExporter(){
        ServerEndpointExporter serverEndpointExporter = new ServerEndpointExporter();
        //WebsocketHandler2如果是一个spring bean(即有@Component),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测有@ServerEndpoint注解的bean
        //WebsocketHandler2如果只是一个包含@ServerEndpoint注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型
//        serverEndpointExporter.setAnnotatedEndpointClasses(WebsocketHandler2.class );
        return serverEndpointExporter;
    }
}    

配置类添加@EnableWebSocket,启用spring websocket功能.

另外还需配置一个Bean ServerEndpointExporter ;如果Endpoint类是一个spring bean(即有@Component),则不需要调用setAnnotatedEndpointClasses方法,spring会自动探测含有@ServerEndpoint注解的Bean;如果Endpoint类只是一个包含@ServerEndpoint注解的普通类(不是 spring bean),则需要在此调用setAnnotatedEndpointClasses方法,手动注册Endpoint类型。

注意:即使Endpoint类是spring bean ,WebsocketContainer也会再创建并使用这个类的一个新实例,也就是说这个Endpoint中不能使用spring相关的功能,典型的就是不能使用@Autowire等注解自动注入Bean。其原因是websocket的默认端点配置org.apache.tomcat.websocket.server.DefaultServerEndpointConfigurator获取endpoint实例的逻辑是反射调用构造方法去创建一个新对象

java 复制代码
public class DefaultServerEndpointConfigurator
        extends ServerEndpointConfig.Configurator {

    @Override
    public <T> T getEndpointInstance(Class<T> clazz)
            throws InstantiationException {
        try {
            return clazz.getConstructor().newInstance();
        } catch (InstantiationException e) {
            throw e;
        } catch (ReflectiveOperationException e) {
            InstantiationException ie = new InstantiationException();
            ie.initCause(e);
            throw ie;
        }
    }

当然你可以通过注入静态属性的方式来绕过这个限制。

理论上说也可在@ServerEndpoint注解的configurator属性指定为spring的org.springframework.web.socket.server.standard.SpringConfigurator也可以自动注入Bean依赖.

java 复制代码
 @ServerEndpoint(value = "/echo", configurator = SpringConfigurator.class)
 public class EchoEndpoint {
       // ...
   }

SpringConfigurator它重写了获取Endpoint实例的方法逻辑getEndpointInstance,它是直接到spring容器中去取这个bean,而不是创建一个新实例.

但实际在spring boot项目中,上面的getEndpointInstance方法获取到的WebApplicationContextnull,也就没法从spring容器中获取这个Endpoint bean

基于spring WebSocketHandler实现websocket

了解WebSocketHandler

提前引入前面提到的websocket的starter 依赖

WebSocketHandler接口定义了5个方法,
afterConnectionEstablished:建立连接后的回调方法
handleMessage:接收到客户端消息后的回调方法
handleTransportError: 数据传输异常时的回调方法
afterConnectionClosed: 连接关闭后的回调方法
supportsPartialMessages: 是否支持数据分块传输(最后一个分块传输,isLast是true)

它有两个主要的子类, 一个是处理纯文本数据的TextWebSocketHandler ,另一个是处理二进制数据的BinaryWebSocketHandler

我们实现websocket一般是继承这两个类,并重写相应的方法。一般都需要重写afterConnectionEstablished handleTransportError
handleTransportError afterConnectionClosed 这三个方法,除此之外,处理文本还要重写接收客户端消息后的回调方法handleTextMessage,处理二进制数据需要重写接收客户端消息后的回调方法handleBinaryMessage。如果有需要得到ping结果回调,还可以重写handlePongMessage方法

代码

java 复制代码
@Component
public class WebsocketHandler1 extends TextWebSocketHandler {
    private final Logger log = LoggerFactory.getLogger(getClass());
    private static final Set<WebSocketSession> sessions = new ConcurrentSkipListSet<>(Comparator.comparing(WebSocketSession::getId));
    private static final ScheduledExecutorService scheduledExecutor = Executors.newScheduledThreadPool(10);
    private static final Map<String, ScheduledFuture<?>> futures = new ConcurrentHashMap<>();


    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        @SuppressWarnings("unchecked")
        AbstractWebSocketSession<Session> standardSession = (AbstractWebSocketSession) session;
        Session nativeSession = standardSession.getNativeSession();
        nativeSession.setMaxIdleTimeout(1000*4);

        ScheduledFuture<?> future = scheduledExecutor.scheduleWithFixedDelay(() -> sendPing(session), 5, 5, TimeUnit.SECONDS);
        futures.put(session.getId(), future);



        log.info("open connect sessionId={}", session.getId());
        sessions.add(session);
        TextMessage msg = new TextMessage(String.format("ws client(id=%s) has connected", session.getId()));
        session.sendMessage(msg);
    }


    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        log.info("receive client(id={}) msg=>{}", session.getId(), message.getPayload());
        TextMessage msg = new TextMessage(String.format("reply your(id=%s) msg=>%s", session.getId(), message.getPayload()));
        session.sendMessage(msg);
    }

    @Override
    protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
        ByteBuffer payload = message.getPayload();
        String s = new String(payload.array());
        log.info("receive client(id={}) pong msg=>{}", session.getId(),s);
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        log.error("client(id={}) error occur ", session.getId(), exception);

    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        log.info("close ,status={}", status);
        sessions.remove(session);
        ScheduledFuture<?> future = futures.get(session.getId());
        if (future != null) {
            future.cancel(true);
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return true;
    }

    private void sendPing(WebSocketSession session) {
        if (session.isOpen()) {
            String replyContent = String.format("Hello,client(id=%s)", session.getId());

            PingMessage msg = new PingMessage(ByteBuffer.wrap(replyContent.getBytes(StandardCharsets.UTF_8)));
            try {
                session.sendMessage(msg);
            } catch (IOException e) {
                log.error("ping client(id={}) error", session.getId(), e);
            }

        }
    }
}

springboot内置的WebSocketHandler

前端html代码

html 复制代码
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>WebSocket Chat</title>
    <style>
        body {
            font-family: Arial, sans-serif;
        }

        #chat-box {
            width: 100%;
            height: 300px;
            border: 1px solid #ccc;
            overflow-y: auto;
            padding: 10px;
            margin-bottom: 10px;
            background-color: #f9f9f9;
            white-space: pre-wrap;
        }

        #input-box {
            width: calc(50% - 90px);
            padding: 10px;
            margin-right: 10px;
            display: flex;
            justify-content: center
        }

        .btn {
            padding: 10px;

        }

        #btn-container {
            margin: 1px;
            display: flex;
            justify-content: center;
            gap: 5px;
        }

        #input-container {
            margin: 1px;
            display: flex;
            justify-content: center;
            gap: 5px;
        }
    </style>
</head>
<body>

<div id="chat-box"></div>
<div id="input-container">
    <input type="text" id="input-box" placeholder="Enter your message"/>
</div>

<div id="btn-container">
    <button id="connect-button" class="btn">Connect</button>
    <button id="close-button" class="btn">Close</button>
    <button id="clear-button" class="btn">Clear</button>
    <button id="send-button" class="btn">Send</button>
</div>


<script>
    const chatBox = document.getElementById('chat-box');
    const inputBox = document.getElementById('input-box');
    const sendButton = document.getElementById('send-button');
    const connectBtn = document.getElementById('connect-button');
    const closeBtn = document.getElementById('close-button');
    const clearBtn = document.getElementById('clear-button');


    let ws = null;

    sendButton.addEventListener('click', () => {
        if (ws === null) {
            alert("no connect")
            return;
        }
        const message = inputBox.value;
        if (message) {
            ws.send(message);

            chatBox.innerHTML += 'You: ' + message + '\n';
            chatBox.scrollTop = chatBox.scrollHeight;
            inputBox.value = '';
        }
    });

    clearBtn.addEventListener('click', () => {
        chatBox.innerHTML = '';
    });

    closeBtn.addEventListener('click', () => {
        if (ws === null) {
            alert("no connect")
            return;
        }
        console.log("prepare close ws");
        ws.close(1000, 'Normal closure');
    });

    connectBtn.addEventListener('click', () => {
        if (ws !== null) {
            alert("already connected!")
            return;
        }
        let curWs = new WebSocket('ws://localhost:7001/ws/Hews2df?id=323&color=red');

        curWs.onopen = event => {
            ws = curWs;
            console.log('Connected to WebSocket server, event=>%s', JSON.stringify(event));
        };

        curWs.onmessage = event => {
            const message = event.data;
            chatBox.innerHTML += 'Server: ' + message + '\n';
            chatBox.scrollTop = chatBox.scrollHeight;
        };

        curWs.onclose = event => {
            ws = null;
            console.log('Disconnected from WebSocket server, close code=%s,close reason=%s', event.code, event.reason);
        };
        curWs.onerror = event => {
            console.log("error occur, event=>%s", JSON.stringify(event))

        };
    });

</script>

</body>
</html>

演示效果

相关推荐
没有羊的王K5 分钟前
SSM框架学习DI入门——day2
java·spring boot·学习
brzhang15 分钟前
别再梭哈 Curosr 了!这 AI 神器直接把需求、架构、任务一条龙全干了!
前端·后端·架构
星释26 分钟前
优雅的Java:01.数据更新如何更优雅
java·开发语言·spring boot
安妮的心动录29 分钟前
安妮的2025 Q2 Review
后端·程序员
程序员爱钓鱼29 分钟前
Go语言数组排序(冒泡排序法)—— 用最直观的方式掌握排序算法
后端·google·go
从int开始31 分钟前
WebApplicationType.REACTIVE 的webSocket 多实例问题处理
websocket
Victor3561 小时前
MySQL(140)如何解决外键约束冲突?
后端
Victor3562 小时前
MySQL(139)如何处理MySQL字符编码问题?
后端
007php0073 小时前
服务器上PHP环境安装与更新版本和扩展(安装PHP、Nginx、Redis、Swoole和OPcache)
运维·服务器·后端·nginx·golang·测试用例·php