【java】实现sse调用websocket接口,忽略wss证书并控制sse吐字速度

maven

        <dependency>
            <groupId>org.java-websocket</groupId>
            <artifactId>Java-WebSocket</artifactId>
            <version>1.5.3</version>
        </dependency>

AsyncConfig

package com.test.demo.sse;

import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;

/**
 * Asynchronous execution configuration.
 *
 * <p>Switches on Spring's {@code @Async} processing and registers the two thread pools used by
 * the classes in this package.
 */
@EnableAsync
@Configuration
public class AsyncConfig {

    /** Core pool size, read from {@code sync.corePoolSize} (falls back to 50). */
    @Value("${sync.corePoolSize:50}")
    private int corePoolSize;

    /** Maximum pool size, read from {@code sync.maxPoolSize} (falls back to 200). */
    @Value("${sync.maxPoolSize:200}")
    private int maxPoolSize;

    /** Capacity of the task queue, read from {@code sync.queueCapacity} (falls back to 10000000). */
    @Value("${sync.queueCapacity:10000000}")
    private int queueCapacity;

    /**
     * Builds the general purpose task executor.
     *
     * <p>Keep-alive time and the rejection policy are deliberately left untouched, so the
     * framework defaults apply.
     *
     * @return an initialized {@link ThreadPoolTaskExecutor} whose threads are named
     *         {@code async-executor-N}
     */
    @Bean
    public Executor executor() {
        ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
        pool.setThreadNamePrefix("async-executor-");
        pool.setCorePoolSize(corePoolSize);
        pool.setMaxPoolSize(maxPoolSize);
        pool.setQueueCapacity(queueCapacity);
        pool.initialize();
        return pool;
    }

    /**
     * Builds the scheduler used for periodic tasks.
     *
     * @return a scheduled pool with {@code corePoolSize} threads named {@code schedule-executor-N}
     */
    @Bean
    public ScheduledExecutorService scheduledExecutorService() {
        CustomizableThreadFactory namedThreads = new CustomizableThreadFactory("schedule-executor-");
        return Executors.newScheduledThreadPool(corePoolSize, namedThreads);
    }
}

SpringContextUtils

package com.test.demo.sse;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationEvent;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;

/**
 * Static helper around the Spring {@link ApplicationContext} and the HTTP request bound to the
 * current thread.
 */
@Component
public class SpringContextUtils implements ApplicationContextAware {

    /** Context instance stored by {@link #setApplicationContext(ApplicationContext)}. */
    private static ApplicationContext applicationContext;

    /**
     * Stores the context in a static field so the static helpers below can reach it.
     *
     * @param applicationContext the context supplied by Spring
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringContextUtils.applicationContext = applicationContext;
    }

    /**
     * @return the stored application context, or {@code null} if none has been set yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /**
     * Looks up the servlet request through {@link RequestContextHolder}.
     *
     * @return the request attached to the current thread
     */
    public static HttpServletRequest getHttpServletRequest() {
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        return attributes.getRequest();
    }

    /**
     * Derives the part of the request URL that precedes the request URI.
     *
     * @return the request URL with the trailing request URI cut off
     */
    public static String getDomain() {
        HttpServletRequest current = getHttpServletRequest();
        StringBuffer fullUrl = current.getRequestURL();
        int prefixLength = fullUrl.length() - current.getRequestURI().length();
        return fullUrl.substring(0, prefixLength);
    }

    /**
     * @return the value of the {@code Origin} header of the current request
     */
    public static String getOrigin() {
        return getHttpServletRequest().getHeader("Origin");
    }

    /**
     * Fetches a bean by name.
     *
     * @param name bean name
     * @return the bean instance
     */
    public static Object getBean(String name) {
        ApplicationContext context = getApplicationContext();
        return context.getBean(name);
    }

    /**
     * Fetches a bean by type.
     *
     * @param clazz required bean type
     * @param <T>   bean type
     * @return the bean instance
     */
    public static <T> T getBean(Class<T> clazz) {
        ApplicationContext context = getApplicationContext();
        return context.getBean(clazz);
    }

    /**
     * Fetches a bean by name and type.
     *
     * @param name  bean name
     * @param clazz required bean type
     * @param <T>   bean type
     * @return the bean instance
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        ApplicationContext context = getApplicationContext();
        return context.getBean(name, clazz);
    }

    /**
     * Publishes an event through the stored context; does nothing while no context is set.
     *
     * @param event event to publish
     */
    public static void publishEvent(ApplicationEvent event) {
        if (applicationContext != null) {
            applicationContext.publishEvent(event);
        }
    }
}

MySseEmitter

package com.test.demo.sse;

import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

import java.nio.charset.StandardCharsets;
import java.util.UUID;
import java.util.concurrent.*;

/**
 * <p>
 * <code>MySseEmitter</code>
 * </p>
 * Description:  解决SseEmitter浏览器下中文乱码问题
 */
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MySseEmitter extends SseEmitter {
    /**
     * websocket返回的所有信息,只用于将消息发送到前端
     */
    private StringBuilder totalAnswer = new StringBuilder();
    /**
     * websocket返回的所有信息,用于最终存储的消息内容
     */
    private StringBuilder totalAnswerStorage = new StringBuilder();
    /**
     * 链接是否已主动断开,,true:已主动断开,false:未断开
     */
    private boolean disconnected = false;
    /**
     * 本条消息的唯一id
     */
    private String messageUuid = UUID.randomUUID().toString();
    /**
     * 本次会话的唯一id
     */
    private String conversationUuid = UUID.randomUUID().toString();
    /**
     * 是否匀速返回,true:需要匀速,false:不需要匀速
     */
    private boolean speedControl;
    /**
     * 是否已经开始匀速返回信息,true:已经开始,false:还没有开始
     */
    private boolean startSendMsgWithSpeedControl = false;
    /**
     * 已经发送的消息的长度
     */
    private int sendLength = 0;
    /**
     * 所有消息是否已经全部匀速返回,true:已经全部返回,false:还没有全部返回
     */
    private boolean endSendMsgWithSpeedControl = false;
    /**
     * 匀速吐字间隔时间,单位:毫秒
     */
    private long sleepTime = 20L;
    /**
     * 匀速发送消息时每次返回多少个字符
     */
    private int sendMsgSpeed = 1;
    /**
     * 超时时间,单位:毫秒
     */
    private long timeout;

    /**
     * 当前登录人
     */
    private String userUid;
    /**
     * 当前登录人的问题
     */
    private String userQuestion;


    /**
     * 解决中文乱码
     *
     * @param outputMessage
     */
    @Override
    protected void extendResponse(ServerHttpResponse outputMessage) {
        super.extendResponse(outputMessage);
        HttpHeaders headers = outputMessage.getHeaders();
        headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8));
    }

    /**
     * 创建SSE对象
     *
     * @param speedControl 是否开启匀速,true:开启,false:关闭
     * @param timeout      超时时间,单位:毫秒
     * @param userUid      当前登录人
     * @param userQuestion 当前登录人的问题
     */
    public MySseEmitter(boolean speedControl, long timeout, String userUid, String userQuestion) {
        // 设置超时时间,单位:毫秒
        super(timeout);
        this.speedControl = speedControl;
        this.timeout = timeout;
        this.userUid = userUid;
        this.userQuestion = userQuestion;
    }

    /**
     * 自定义发送消息方法
     *
     * @param message    具体消息
     * @param msgStorage 本次发送的消息内容是否需要进行存储,true:需要,false:不需要
     * @return 是否需要关闭链接,true:是,false:否
     */
    public boolean mySend(String message, boolean msgStorage) {
        try {
            if (StringUtils.isNotEmpty(message)) {
                // 处理换行,PC换行\r、\n、、\r\n都行,移动只能\r\n
                message = message.replaceAll("\r", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n\n", "\n")
                        .replaceAll("\n", "\r\n");

                this.totalAnswer.append(message);
                if (msgStorage) {
                    this.totalAnswerStorage.append(message);
                }
                if (!this.speedControl) {
                    super.send(message);
                } else if (!this.startSendMsgWithSpeedControl && !this.disconnected) {
                    // 异步发送
                    SpringContextUtils.getBean(AsyncService.class).sendMsgWithSpeedControl(this);
                }
            }
        } catch (Exception e) {
            log.error("==>MySseEmitter send error,conversationUuid:{}", this.conversationUuid, e);
            this.disconnected = true;
        }
        return this.disconnected;
    }

    /**
     * 断开连接
     *
     * @param msgStorageType 消息处理类型,0:不存储,1:本地数据库存储
     */
    public void myComplete(String msgStorageType) {
        try {
            if (!this.speedControl || this.disconnected || this.endSendMsgWithSpeedControl) {
                super.complete();
            } else {
                Future<?> future = SpringContextUtils.getBean(ScheduledExecutorService.class).scheduleAtFixedRate(() -> {
                    if (this.endSendMsgWithSpeedControl) {
                        log.info("==>当前消息已全部返回完成,主动断开与端上链接,conversationUuid:{}", conversationUuid);
                        throw new RuntimeException("==>当前消息已全部返回完成,主动断开与端上链接,conversationUuid:" + conversationUuid);
                    }
                }, this.sleepTime, this.sleepTime, TimeUnit.MILLISECONDS);
                try {
                    // 超时时间,单位:毫秒
                    future.get(timeout, TimeUnit.MILLISECONDS);
                } catch (TimeoutException | ExecutionException e) {
                    log.info("==>等待断开链接任务执行结束,conversationUuid:{}", conversationUuid);
                    // 取消任务
                    future.cancel(true);
                } catch (Exception e) {
                    log.error("==>等待断开链接任务执行异常,conversationUuid:{}", conversationUuid, e);
                    // 取消任务
                    future.cancel(true);
                }

                super.complete();
            }
        } catch (Exception ignore) {
        }

        if ("1".equals(msgStorageType)) {
            // 本地数据库存储消息,异步保存数据
            SpringContextUtils.getBean(AsyncService.class).saveMsg(this.messageUuid, this.conversationUuid,
                    this.userUid, this.userQuestion, this.totalAnswerStorage.toString());
        }
    }
}

MyWebSocketClient

package com.test.demo.sse;

import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;

import javax.net.ssl.*;
import java.net.Socket;
import java.net.URI;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;

/**
 * {@link WebSocketClient} that skips certificate validation for {@code wss://} endpoints.
 *
 * <p><b>Security note:</b> the installed trust manager accepts every certificate, so the
 * connection is open to man-in-the-middle attacks. Use it only against endpoints you control
 * (for example test systems with self-signed certificates).
 */
@Slf4j
public abstract class MyWebSocketClient extends WebSocketClient {

    /**
     * Creates the client and installs a socket factory that trusts every certificate.
     *
     * <p>If the TLS context cannot be created the failure is logged and the client keeps its
     * default socket handling.
     *
     * @param serverUri      WebSocket address
     * @param connectTimeout connect timeout in milliseconds
     */
    public MyWebSocketClient(URI serverUri, int connectTimeout) {
        super(serverUri, new Draft_6455(), null, connectTimeout);
        try {
            // "TLS" is the standard context name; the former "SSL" is a legacy name.
            SSLContext context = SSLContext.getInstance("TLS");
            context.init(null, trustAllManagers(), new java.security.SecureRandom());
            this.setSocketFactory(context.getSocketFactory());
        } catch (Exception e) {
            log.error("==>初始化SSLContext失败", e);
        }
    }

    /**
     * Builds a trust manager whose checks never reject a certificate chain.
     *
     * <p>The extended variant is implemented so that the socket- and engine-aware checks are
     * no-ops as well.
     *
     * @return a single-element array holding the accept-all trust manager
     */
    private static TrustManager[] trustAllManagers() {
        return new TrustManager[]{new X509ExtendedTrustManager() {
            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket)
                    throws CertificateException {
                // accept everything
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket)
                    throws CertificateException {
                // accept everything
            }

            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
                    throws CertificateException {
                // accept everything
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine)
                    throws CertificateException {
                // accept everything
            }

            @Override
            public X509Certificate[] getAcceptedIssuers() {
                // The X509TrustManager contract requires a non-null (possibly empty) array.
                return new X509Certificate[0];
            }

            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
                // accept everything
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType)
                    throws CertificateException {
                // accept everything
            }
        }};
    }
}

MyWebSocketClientHelper

package com.test.demo.sse;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;

import java.net.URI;

/**
 * <p>
 * <code>MyWebSocketClientHelper</code>
 * </p>
 * Description:
 */
@Slf4j
public class MyWebSocketClientHelper {

    /**
     * WebSocketClient连接并发送消息
     *
     * @param sseEmitter     sse链接
     * @param msgStorageType 消息处理类型,0:不存储,1:本地数据库存储
     */
    public static void connectAndSend(MySseEmitter sseEmitter, String msgStorageType) {
        String commonErrorMsg = "通用错误信息,报错啦";
        String messageUuid = sseEmitter.getMessageUuid();

        WebSocketClient client = null;
        try {
            client = new MyWebSocketClient(new URI("wss://xxxxx"), Integer.parseInt(Long.toString(sseEmitter.getTimeout()))) {
                @Override
                public void onOpen(ServerHandshake serverHandshake) {
                    log.info("==>connect success,messageUuid:{}", messageUuid);
                    try {
                        String requestParam = "xxxxxxx";
                        this.send(requestParam);
                    } catch (Exception e) {
                        log.error("==>sendRequest error,messageUuid:{}", messageUuid, e);
                        throw e;
                    }
                }

                @Override
                public void onMessage(String result) {
                    log.info("==>messageUuid:{},onMessage:{}", messageUuid, result);
                    try {
                        sseEmitter.mySend(result, true);
                    } catch (Exception e) {
                        log.error("==>onMessage error,messageUuid:{}", messageUuid, e);
                    }
                }

                @Override
                public void onClose(int code, String reason, boolean remote) {
                    // 1. code(int类型):表示关闭连接的原因,通常是一个整数。例如,如果连接正常关闭,code的值可能是1000(表示正常关闭);如果连接因为服务器主动关闭而关闭,code的值可能是1006(表示服务器端强制关闭)。
                    // 2. reason(String类型):表示关闭连接的原因,通常是一段文本描述。这个参数是可选的,如果没有提供原因,可以传递一个空字符串或者null。
                    // 3. remote(boolean类型):表示连接是否被清理。如果为true,表示连接正常关闭;如果为false,表示连接异常关闭。
                    log.info("==>onClose,messageUuid:{},code:{},reason:{},remote:{}", messageUuid, code, reason, remote);
                    try {
                        if (!sseEmitter.isDisconnected() && StringUtils.isBlank(sseEmitter.getTotalAnswer().toString())) {
                            sseEmitter.mySend(commonErrorMsg, true);
                        }
                        sseEmitter.myComplete(msgStorageType);
                    } catch (Exception ignored) {
                    }
                }

                @Override
                public void onError(Exception e) {
                    log.error("==>onError,messageUuid:{}", messageUuid, e);
                    try {
                        this.close();
                    } catch (Exception ignored) {
                    }
                }
            };
            client.connect();
        } catch (Exception e) {
            log.error("==>WebSocketClientConnectAndSend error,messageUuid:{}", messageUuid, e);
            try {
                if (client != null) {
                    client.close();
                }
            } catch (Exception ignored) {
            }
        }
    }
}

AsyncService

package com.test.demo.sse;

/**
 * Operations that are meant to run off the caller's thread: paced delivery of SSE output and
 * persistence of a finished conversation.
 */
public interface AsyncService {

    /**
     * Sends the text accumulated in the emitter to the client at an even pace.
     *
     * @param sseEmitter emitter holding the text to deliver and the pacing settings
     */
    void sendMsgWithSpeedControl(MySseEmitter sseEmitter);

    /**
     * Saves a message.
     *
     * @param messageUuid        id of the message
     * @param conversationUuid   id of the conversation
     * @param userUid            id of the current user
     * @param userQuestion       question entered by the user
     * @param totalAnswerStorage message content returned by the WebSocket
     */
    void saveMsg(String messageUuid, String conversationUuid, String userUid, String userQuestion, String totalAnswerStorage);
}

AsyncServiceImpl

package com.test.demo.sse;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;

import java.util.concurrent.*;

/**
 * Default {@link AsyncService}; every public method runs asynchronously ({@code @Async} on the
 * class).
 */
@Slf4j
@Async
@Service
public class AsyncServiceImpl implements AsyncService {

    @Lazy
    @Autowired
    ScheduledExecutorService scheduledExecutorService;

    /**
     * Delivers the emitter's accumulated text chunk by chunk at a fixed rate.
     *
     * <p>A periodic task sends one chunk per tick. The task stops itself by throwing once
     * everything received so far has been sent or a send fails (a throwing execution suppresses
     * all later executions of a fixed-rate task). This method blocks until that happens or the
     * emitter's timeout elapses, and always resets the emitter's start/end flags afterwards.
     *
     * @param sseEmitter emitter holding the text and the pacing settings
     */
    @Override
    public void sendMsgWithSpeedControl(MySseEmitter sseEmitter) {
        sseEmitter.setStartSendMsgWithSpeedControl(true);
        sseEmitter.setEndSendMsgWithSpeedControl(false);

        final String messageUuid = sseEmitter.getMessageUuid();
        // scheduleAtFixedRate rejects a non-positive period.
        final long periodMillis = Math.max(1L, sseEmitter.getSleepTime());

        Future<?> future = null;
        try {
            future = scheduledExecutorService.scheduleAtFixedRate(
                    () -> sendNextChunk(sseEmitter, messageUuid),
                    periodMillis, periodMillis, TimeUnit.MILLISECONDS);
            // A periodic task never completes normally: get() ends with ExecutionException when
            // the task stopped itself, or with TimeoutException when the time budget is used up.
            future.get(sseEmitter.getTimeout(), TimeUnit.MILLISECONDS);
        } catch (TimeoutException | ExecutionException e) {
            log.info("==>任务执行结束,messageUuid:{}", messageUuid);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("==>任务执行异常,messageUuid:{}", messageUuid, e);
        } catch (Exception e) {
            log.error("==>任务执行异常,messageUuid:{}", messageUuid, e);
        } finally {
            if (future != null) {
                future.cancel(true);
            }
            // Reset in finally so a failure to schedule cannot leave the emitter marked as
            // "sending" forever.
            sseEmitter.setStartSendMsgWithSpeedControl(false);
            sseEmitter.setEndSendMsgWithSpeedControl(true);
        }
    }

    /**
     * One tick of the paced sender: sends the next chunk and advances the sent length.
     *
     * @throws RuntimeException to stop the periodic task, either because nothing is left to
     *                          send or because sending failed
     */
    private static void sendNextChunk(MySseEmitter sseEmitter, String messageUuid) {
        int sentLength = sseEmitter.getSendLength();
        // Everything received up to this moment.
        String allText = sseEmitter.getTotalAnswer().toString();
        if (sentLength >= allText.length()) {
            log.info("==>当前时刻所有消息已全部发送完成,任务执行结束,messageUuid:{}", messageUuid);
            throw new RuntimeException("==>当前时刻所有消息已全部发送完成,任务执行结束,messageUuid:" + messageUuid);
        }

        String chunk = nextChunk(allText, sentLength, sseEmitter.getSendMsgSpeed());
        try {
            sseEmitter.send(chunk);
        } catch (Exception e) {
            sseEmitter.setDisconnected(true);
            log.info("==>发送消息失败,视为端上主动断开链接,任务执行结束,messageUuid:{}", messageUuid);
            throw new RuntimeException("==>发送消息失败,视为端上主动断开链接,任务执行结束,messageUuid:" + messageUuid, e);
        }
        sseEmitter.setSendLength(sentLength + chunk.length());
    }

    /**
     * Cuts the next chunk out of {@code text}.
     *
     * <p>The chunk is {@code size} characters long (at least one, at most the remainder) and is
     * extended by one character when it would otherwise end between {@code \r} and the next
     * character, or in the middle of a surrogate pair. Sending half of a pair would corrupt the
     * character when the event is encoded.
     *
     * @param text full text, {@code from} must be smaller than its length
     * @param from index of the first unsent character
     * @param size requested chunk size in characters
     * @return the next chunk, never empty
     */
    static String nextChunk(String text, int from, int size) {
        int total = text.length();
        int end = (int) Math.min((long) total, (long) from + Math.max(1, size));
        if (end < total && text.charAt(end - 1) == '\r') {
            end++;
        }
        if (end < total && Character.isHighSurrogate(text.charAt(end - 1))) {
            end++;
        }
        return text.substring(from, end);
    }


    /**
     * Persists the conversation and the message inside a transaction.
     *
     * @param messageUuid        id of the message
     * @param conversationUuid   id of the conversation
     * @param userUid            id of the current user
     * @param userQuestion       question entered by the user
     * @param totalAnswerStorage message content returned by the WebSocket
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public void saveMsg(String messageUuid, String conversationUuid, String userUid, String userQuestion, String totalAnswerStorage) {
        log.info("==>保存会话和信息,messageUuid:{},conversationUuid:{},userUid:{},userQuestion:{},totalAnswerStorage:{}",
                messageUuid, conversationUuid, userUid, userQuestion, totalAnswerStorage);
        // Create and save the conversation if it does not exist yet.


        // Save the message.
    }
}

使用方法

package com.test.demo.sse;

import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
 * Example endpoint showing how the SSE emitter and the WebSocket helper are wired together.
 */
@Slf4j
@RestController
@RequestMapping("/sse")
public class SseTestController {

    /**
     * Accepts a user question and streams the answer back as server-sent events.
     *
     * <p>The emitter is created with speed control enabled and a 60 second timeout; storage
     * mode {@code "1"} (local database) is used.
     *
     * @param input request payload carrying the user id and the question
     * @return the emitter the answer is streamed through
     */
    @PostMapping(value = "/ask")
    public SseEmitter ask(@RequestBody @Valid Input input) {
        MySseEmitter sseEmitter = new MySseEmitter(true, 60000L, input.getUserUid(), input.getUserQuestion());
        try {
            MyWebSocketClientHelper.connectAndSend(sseEmitter, "1");
        } catch (Exception e) {
            log.error("ask error", e);
            sseEmitter.myComplete("1");
        }
        return sseEmitter;
    }
}
相关推荐
Theodore_10224 小时前
4 设计模式原则之接口隔离原则
java·开发语言·设计模式·java-ee·接口隔离原则·javaee
冰帝海岸5 小时前
01-spring security认证笔记
java·笔记·spring
世间万物皆对象5 小时前
Spring Boot核心概念:日志管理
java·spring boot·单元测试
没书读了6 小时前
ssm框架-spring-spring声明式事务
java·数据库·spring
小二·6 小时前
java基础面试题笔记(基础篇)
java·笔记·python
开心工作室_kaic6 小时前
ssm161基于web的资源共享平台的共享与开发+jsp(论文+源码)_kaic
java·开发语言·前端
懒洋洋大魔王6 小时前
RocketMQ的使⽤
java·rocketmq·java-rocketmq
武子康6 小时前
Java-06 深入浅出 MyBatis - 一对一模型 SqlMapConfig 与 Mapper 详细讲解测试
java·开发语言·数据仓库·sql·mybatis·springboot·springcloud
转世成为计算机大神7 小时前
易考八股文之Java中的设计模式?
java·开发语言·设计模式
qq_327342737 小时前
Java实现离线身份证号码OCR识别
java·开发语言