xml
复制代码
<!-- OkHttp core client: provides okhttp3.OkHttpClient, Request, RequestBody, MediaType and
     ConnectionPool used by ExecuteSSEUtil and OkHttpUtil. -->
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>4.2.0</version>
</dependency>
<!-- OkHttp SSE extension: provides okhttp3.sse.EventSource, EventSources and EventSourceListener
     used by ExecuteSSEUtil and SSEListener. Keep this version in lockstep with okhttp above. -->
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp-sse</artifactId>
<version>4.2.0</version>
</dependency>
APITestController
java
复制代码
package com.demo.controller;
import com.alibaba.fastjson.JSON;
import com.demo.listener.SSEListener;
import com.demo.params.req.ChatGlmDto;
import com.demo.utils.ExecuteSSEUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletResponse;
/**
 * Relays a chat request to an upstream SSE endpoint and streams the upstream events back to the
 * HTTP caller.
 *
 * <p>NOTE(review): {@code API_KEY} and {@code URL} are placeholders hard-coded in source; real
 * values should be supplied through configuration rather than committed.
 */
@RestController
@Slf4j
public class APITestController {
    private static final String API_KEY = "xxx";
    private static final String URL = "xxx";

    /**
     * Streams the upstream SSE response for {@code chatGlmDto} into {@code rp}.
     *
     * <p>The request thread blocks inside {@link ExecuteSSEUtil#executeSSE} until the listener
     * signals completion; the events themselves are written to {@code rp} by {@link SSEListener}.
     * Errors are logged and not rethrown.
     *
     * @param chatGlmDto request payload; also serialized to JSON as the upstream request body
     * @param rp         servlet response the events are relayed into
     */
    @PostMapping(value = "/sse-invoke", produces = "text/event-stream;charset=UTF-8")
    public void sse(@RequestBody ChatGlmDto chatGlmDto, HttpServletResponse rp) {
        try {
            SSEListener sseListener = new SSEListener(chatGlmDto, rp);
            ExecuteSSEUtil.executeSSE(URL, API_KEY, sseListener, JSON.toJSONString(chatGlmDto));
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of silently clearing it, so the container or
            // executor that owns this thread can still observe the interruption.
            Thread.currentThread().interrupt();
            log.error("请求SSE错误处理", e);
        } catch (Exception e) {
            log.error("请求SSE错误处理", e);
        }
    }
}
ChatGlmDto
java
复制代码
package com.demo.params.req;
import lombok.Data;
/**
* Created by WeiRan on 2023.03.20 19:19
*
* Request payload accepted by the /sse-invoke endpoint. The same object is serialized with
* fastjson and sent unchanged as the upstream request body.
*/
@Data
public class ChatGlmDto {
/** Echoed to the client as the SSE "id:" field and included in log lines. */
private String messageId;
/** Prompt forwarded upstream; typed as Object, so its shape is caller-defined. NOTE(review): confirm the expected structure. */
private Object prompt;
/** Forwarded upstream as-is; presumably an upstream task/request number — TODO confirm. */
private String requestTaskNo;
/** Defaults to true; presumably requests incremental (delta) output from upstream — confirm against the upstream API. */
private boolean incremental = true;
/** Defaults to true; semantics are defined by the upstream API — confirm. */
private boolean notSensitive = true;
}
SSEListener
java
复制代码
package com.demo.listener;
import com.alibaba.fastjson.JSON;
import com.demo.params.req.ChatGlmDto;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import javax.servlet.http.HttpServletResponse;
import java.util.concurrent.CountDownLatch;
@Slf4j
@Data
public class SSEListener extends EventSourceListener {
private CountDownLatch countDownLatch = new CountDownLatch(1);
private ChatGlmDto chatGlmDto;
private HttpServletResponse rp;
private StringBuffer output = new StringBuffer();
public SSEListener(ChatGlmDto chatGlmDto, HttpServletResponse response) {
this.chatGlmDto = chatGlmDto;
this.rp = response;
}
/**
* {@inheritDoc}
* 建立sse连接
*/
@Override
public void onOpen(final EventSource eventSource, final Response
response) {
if (rp != null) {
rp.setContentType("text/event-stream");
rp.setCharacterEncoding("UTF-8");
rp.setStatus(200);
log.info("建立sse连接..." + JSON.toJSONString(chatGlmDto));
} else {
log.info("客户端非sse推送" + JSON.toJSONString(chatGlmDto));
}
}
/**
* 事件
*
* @param eventSource
* @param id
* @param type
* @param data
*/
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
try {
output.append(data);
if ("finish".equals(type)) {
log.info("请求结束{} {}", chatGlmDto.getMessageId(), output.toString());
}
if ("error".equals(type)) {
log.info("{}: {}source {}", chatGlmDto.getMessageId(), data, JSON.toJSONString(chatGlmDto));
}
if (rp != null) {
if ("\n".equals(data)) {
rp.getWriter().write("event:" + type + "\n");
rp.getWriter().write("id:" + chatGlmDto.getMessageId() + "\n");
rp.getWriter().write("data:\n\n");
rp.getWriter().flush();
} else {
String[] dataArr = data.split("\\n");
for (int i = 0; i < dataArr.length; i++) {
if (i == 0) {
rp.getWriter().write("event:" + type + "\n");
rp.getWriter().write("id:" + chatGlmDto.getMessageId() + "\n");
}
if (i == dataArr.length - 1) {
rp.getWriter().write("data:" + dataArr[i] + "\n\n");
rp.getWriter().flush();
} else {
rp.getWriter().write("data:" + dataArr[i] + "\n");
rp.getWriter().flush();
}
}
}
}
} catch (Exception e) {
log.error("消息错误[" + JSON.toJSONString(chatGlmDto) + "]", e);
countDownLatch.countDown();
throw new RuntimeException(e);
}
}
/**
* {@inheritDoc}
*/
@Override
public void onClosed(final EventSource eventSource) {
log.info("sse连接关闭:{}", chatGlmDto.getMessageId());
log.info("结果输出:{}" + output.toString());
countDownLatch.countDown();
}
/**
* {@inheritDoc}
*/
@Override
public void onFailure(final EventSource eventSource, final Throwable t, final Response response) {
log.error("使用事件源时出现异常... [响应:{}]...", chatGlmDto.getMessageId());
countDownLatch.countDown();
}
public CountDownLatch getCountDownLatch() {
return this.countDownLatch;
}
}
ExecuteSSEUtil
java
复制代码
package com.demo.utils;
import com.demo.listener.SSEListener;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
/**
 * Opens an upstream SSE connection and blocks the calling thread until the supplied listener
 * reports completion.
 */
@Slf4j
public class ExecuteSSEUtil {
    /**
     * POSTs {@code chatGlm} as a JSON body to {@code url} and streams the SSE response into
     * {@code eventSourceListener}. Blocks until the listener's latch is released.
     *
     * @param url                 upstream SSE endpoint
     * @param authToken           value sent verbatim as the {@code Authorization} header
     * @param eventSourceListener receives the events; its latch signals completion
     * @param chatGlm             JSON request body
     * @throws InterruptedException if the calling thread is interrupted while waiting; the
     *                              upstream event source is cancelled before rethrowing
     * @throws Exception            declared for backward compatibility with existing callers
     */
    public static void executeSSE(String url, String authToken, SSEListener eventSourceListener, String chatGlm) throws Exception {
        RequestBody formBody = RequestBody.create(chatGlm, MediaType.parse("application/json; charset=utf-8"));
        Request request = new Request.Builder()
                .url(url)
                .addHeader("Authorization", authToken)
                .post(formBody)
                .build();
        EventSource.Factory factory = EventSources.createFactory(OkHttpUtil.getInstance());
        // Keep the handle so the upstream stream can be stopped if we stop waiting for it.
        EventSource eventSource = factory.newEventSource(request, eventSourceListener);
        try {
            eventSourceListener.getCountDownLatch().await();
        } catch (InterruptedException e) {
            // Otherwise the stream would keep running and the listener would keep writing to a
            // servlet response whose request thread has already given up on it.
            eventSource.cancel();
            throw e;
        }
    }
}
OkHttpUtil
java
复制代码
package com.demo.utils;
import okhttp3.ConnectionPool;
import okhttp3.OkHttpClient;
import java.net.Proxy;
import java.util.concurrent.TimeUnit;
/**
 * Lazily creates and hands out a single shared {@link OkHttpClient}.
 */
public class OkHttpUtil {
    /** Shared client; volatile so the double-checked locking in {@link #getInstance()} is safe. */
    private static volatile OkHttpClient okHttpClient;
    /** Pool given to the client when it is first built: at most 10 idle connections, kept alive for 5 minutes. */
    public static ConnectionPool connectionPool = new ConnectionPool(10, 5, TimeUnit.MINUTES);

    /**
     * Returns the shared client, building it on first use.
     *
     * <p>Thread-safe: the client is fully configured, including its dispatcher limits, before it
     * is published to other threads.
     *
     * @return the shared client; never {@code null}
     */
    public static OkHttpClient getInstance() {
        OkHttpClient client = okHttpClient;
        if (client == null) {
            // Lock on this class rather than OkHttpClient.class, whose monitor other code may use.
            synchronized (OkHttpUtil.class) {
                client = okHttpClient;
                if (client == null) {
                    client = new OkHttpClient.Builder()
                            .proxy(Proxy.NO_PROXY) // bypass any system-configured proxy
                            .connectionPool(connectionPool)
                            // Long timeouts: the client serves long-lived SSE streams.
                            .connectTimeout(600, TimeUnit.SECONDS)
                            .writeTimeout(600, TimeUnit.SECONDS)
                            .readTimeout(600, TimeUnit.SECONDS)
                            .build();
                    client.dispatcher().setMaxRequestsPerHost(200);
                    client.dispatcher().setMaxRequests(200);
                    // Publish only after configuration is complete so no thread sees default limits.
                    okHttpClient = client;
                }
            }
        }
        return client;
    }
}