Flutter+SpringBoot实现ChatGPT流式输出、上下文连续对话
最终实现Flutter的流式输出+上下文连续对话。
这里提供一个简单版的工具类和使用案例,页面部分仅供参考。
服务端
这里直接封装提供工具类,修改自己的apiKey即可使用,支持连续对话
工具类及使用
http依赖这里使用okHttp
xml
<!-- The utility class below imports com.squareup.okhttp.* and configures the client with
     OkHttpClient.setConnectTimeout(...)/setReadTimeout(...). That is the OkHttp 2.x API, which is
     published under these coordinates. The com.squareup.okhttp3 artifact uses the okhttp3 package
     and an immutable, builder-configured client, so the code below would not compile against it. -->
<dependency>
    <groupId>com.squareup.okhttp</groupId>
    <artifactId>okhttp</artifactId>
    <version>2.7.5</version>
</dependency>
java
import com.alibaba.fastjson2.JSON;
import com.squareup.okhttp.Call;
import com.squareup.okhttp.MediaType;
import com.squareup.okhttp.OkHttpClient;
import com.squareup.okhttp.Request;
import com.squareup.okhttp.RequestBody;
import com.squareup.okhttp.Response;
import com.squareup.okhttp.ResponseBody;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import vip.ailtw.common.utils.StringUtil;
import javax.annotation.PostConstruct;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Slf4j
@Component
public class ChatGptStreamUtil {
/**
* 修改为自己的密钥
*/
private final String apiKey = "xxxxxxxxxxxxxx";
public final String gptCompletionsUrl = "https://api.openai.com/v1/chat/completions";
private static final OkHttpClient client = new OkHttpClient();
private static MediaType mediaType;
private static Request.Builder requestBuilder;
public final static Pattern contentPattern = Pattern.compile("\"content\":\"(.*?)\"}");
/**
* 对话符号
*/
public final static String EVENT_DATA = "d";
/**
* 错误结束符号
*/
public final static String EVENT_ERROR = "e";
/**
* 响应结束符号
*/
public final static String END = "<<END>>";
@PostConstruct
public void init() {
client.setConnectTimeout(60, TimeUnit.SECONDS);
client.setReadTimeout(60, TimeUnit.SECONDS);
mediaType = MediaType.parse("application/json; charset=utf-8");
requestBuilder = new Request.Builder()
.url(gptCompletionsUrl)
.header("Content-Type", "application/json")
.header("Authorization", "Bearer " + apiKey);
}
/**
* 流式对话
*
* @param talkList 上下文对话,最早的对话放在首位
* @param callable 消费者,流式对话每次响应的内容
*/
public GptChatResultDTO chatStream(List<ChatGptDTO> talkList, Consumer<String> callable) throws Exception {
long start = System.currentTimeMillis();
StringBuilder resp = new StringBuilder();
Response response = chatStream(talkList);
//解析对话内容
try (ResponseBody responseBody = response.body();
InputStream inputStream = responseBody.byteStream();
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream))) {
String line;
while ((line = bufferedReader.readLine()) != null) {
if (!StringUtils.hasLength(line)) {
continue;
}
Matcher matcher = contentPattern.matcher(line);
if (matcher.find()) {
String content = matcher.group(1);
resp.append(content);
callable.accept(content);
}
}
}
int wordSize = 0;
for (ChatGptDTO dto : talkList) {
String content = dto.getContent();
wordSize += content.toCharArray().length;
}
wordSize += resp.toString().toCharArray().length;
long end = System.currentTimeMillis();
return GptChatResultDTO.builder().resContent(resp.toString()).time(end - start).wordSize(wordSize).build();
}
/**
* 流式对话
*
* @param talkList 上下文对话
* @return 接口请求响应
*/
private Response chatStream(List<ChatGptDTO> talkList) throws Exception {
ChatStreamDTO chatStreamDTO = new ChatStreamDTO(talkList);
RequestBody bodyOk = RequestBody.create(mediaType, chatStreamDTO.toString());
Request requestOk = requestBuilder.post(bodyOk).build();
Call call = client.newCall(requestOk);
Response response;
try {
response = call.execute();
} catch (IOException e) {
throw new IOException("请求时IO异常: " + e.getMessage());
}
if (response.isSuccessful()) {
return response;
}
try (ResponseBody body = response.body()) {
if (429 == response.code()) {
String msg = "Open Api key 已过期,msg: " + body.string();
log.error(msg);
}
throw new RuntimeException("chat api 请求异常, code: " + response.code() + "body: " + body.string());
}
}
private boolean sendToClient(String event, String data, SseEmitter emitter) {
try {
emitter.send(SseEmitter.event().name(event).data("{" + data + "}"));
return true;
} catch (IOException e) {
log.error("向客户端发送消息时出现异常", e);
}
return false;
}
/**
* 发送事件给客户端
*/
public boolean sendData(String data, SseEmitter emitter) {
if (StringUtil.isBlank(data)) {
return true;
}
return sendToClient(EVENT_DATA, data, emitter);
}
/**
* 发送结束事件,会关闭emitter
*/
public void sendEnd(SseEmitter emitter) {
try {
sendToClient(EVENT_DATA, END, emitter);
} finally {
emitter.complete();
}
}
/**
* 发送异常事件,会关闭emitter
*/
public void sendError(SseEmitter emitter) {
try {
sendToClient(EVENT_ERROR, "我累垮了", emitter);
} finally {
emitter.complete();
}
}
/**
* gpt请求结果
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public static class GptChatResultDTO implements Serializable {
/**
* gpt请求返回的全部内容
*/
private String resContent;
/**
* 上下文消耗的字数
*/
private int wordSize;
/**
* 耗时
*/
private long time;
}
/**
* 连续对话DTO
*/
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public static class ChatGptDTO implements Serializable {
/**
* 对话内容
*/
private String content;
/**
* 角色 {@link GptRoleEnum}
*/
private String role;
}
/**
* gpt连续对话角色
*/
@Getter
public static enum GptRoleEnum {
USER_ROLE("user", "用户"),
GPT_ROLE("assistant", "ChatGPT本身"),
/**
* message里role为system,是为了让ChatGPT在对话过程中设定自己的行为
* 可以理解为对话的设定,如你是谁,要什么语气、等级
*/
SYSTEM_ROLE("system", "对话设定"),
;
private final String value;
private final String desc;
GptRoleEnum(String value, String desc) {
this.value = value;
this.desc = desc;
}
}
/**
* gpt请求body
*/
@Data
public static class ChatStreamDTO {
private static final String model = "gpt-3.5-turbo";
private static final boolean stream = true;
private List<ChatGptDTO> messages;
public ChatStreamDTO(List<ChatGptDTO> messages) {
this.messages = messages;
}
@Override
public String toString() {
return "{\"model\":\"" + model + "\"," +
"\"messages\":" + JSON.toJSONString(messages) + "," +
"\"stream\":" + stream + "}";
}
}
}
使用案例:
java
/**
 * Minimal console demo: sends a two-message conversation and prints every streamed chunk.
 */
public static void main(String[] args) throws Exception {
    ChatGptStreamUtil chatGptStreamUtil = new ChatGptStreamUtil();
    // Outside Spring nobody triggers @PostConstruct, so initialise manually.
    chatGptStreamUtil.init();
    // Build the conversation context, earliest message first.
    List<ChatGptDTO> talkList = new ArrayList<>();
    // Persona / behaviour instructions belong in a "system" message; an "assistant" message
    // would be treated as something the model itself said earlier.
    talkList.add(ChatGptDTO.builder().content("你是chatgpt助手,能过帮助我查阅资料,编写教学报告。").role(GptRoleEnum.SYSTEM_ROLE.getValue()).build());
    // The actual question.
    talkList.add(ChatGptDTO.builder().content("请帮我写一篇小学数学加法运算教案").role(GptRoleEnum.USER_ROLE.getValue()).build());
    chatGptStreamUtil.chatStream(talkList, (respContent) -> {
        // Called once per streamed chunk.
        System.out.println("gpt返回:" + respContent);
    });
}
SpringBoot接口
基于SpringBoot工程,提供接口,供Flutter端使用。
通过上面的工具类的使用,可以知道gpt返回给我们的内容是一段一段的,因此如果我们服务端也要提供类似的效果,提供两个思路和实现:
- WebSocket,服务端接收gpt返回的内容时推送内容给flutter
- 使用HTTP长连接,也就是 SseEmitter(Server-Sent Events),本文采用这种方式。
代码:
java
@RestController
@RequestMapping("/chat")
@Slf4j
public class ChatController {
@Autowired
private ChatGptStreamUtil chatGptStreamUtil;
@PostMapping(value = "/chatStream")
@ApiOperation("流式对话")
public SseEmitter chatStream() {
SseEmitter emitter = new SseEmitter(80000L);
//构建一个上下文对话情景
List<ChatGptDTO> talkList = new ArrayList<>();
//设定gpt
talkList.add(ChatGptDTO.builder().content("你是chatgpt助手,能过帮助我查阅资料,编写教学报告。").role(GptRoleEnum.GPT_ROLE.getValue()).build());
//开始提问
talkList.add(ChatGptDTO.builder().content("请帮我写一篇小学数学加法运算教案").role(GptRoleEnum.USER_ROLE.getValue()).build());
GptChatResultDTO gptChatResultDTO = chatGptStreamUtil.chatStream(talkList, (content) -> {
//这里服务端接收到消息就发送给Flutter
chatGptStreamUtil.sendData(content, emitter);
});
return emitter;
}
}
Flutter端
这里使用dio作为网络请求的工具
依赖
yml
dio: ^5.2.1+1
工具类
dart
import 'dart:async';
import 'dart:convert';
import 'package:dio/dio.dart';
import 'package:flutter/cupertino.dart';
import 'package:flutter/foundation.dart';
import 'package:get/get.dart' hide Response;
///http工具类
class HttpUtil {
Dio? client;
static HttpUtil of() {
return HttpUtil.init();
}
//初始化http工具
HttpUtil.init() {
if (client == null) {
var options = BaseOptions(
baseUrl: Config.baseUrl,
connectTimeout: const Duration(seconds: 100),
receiveTimeout: const Duration(seconds: 100));
client = Dio(options);
// 请求与响应拦截器/异常拦截器
client?.interceptors.add(OnReqResInterceptors());
}
}
Future<Stream<String>?> postStream(String path,
[Map<String, dynamic>? params]) async {
Response<ResponseBody> rs =
await Dio().post<ResponseBody>(Config.baseUrl + path,
options: Options(headers: {
"Accept": "text/event-stream",
"Cache-Control": "no-cache"
}, responseType: ResponseType.stream),
data: params
);
StreamTransformer<Uint8List, List<int>> unit8Transformer =
StreamTransformer.fromHandlers(
handleData: (data, sink) {
sink.add(List<int>.from(data));
},
);
var resp = rs.data?.stream
.transform(unit8Transformer)
.transform(const Utf8Decoder())
.transform(const LineSplitter());
return resp;
}
/// dio interceptor for requests, responses and errors.
///
/// Adds the `Authorization` header to every outgoing request; responses and
/// errors are passed on unchanged.
class OnReqResInterceptors extends InterceptorsWrapper {
  @override
  void onRequest(RequestOptions options, RequestInterceptorHandler handler) {
    // Attach the token to every request.
    options.headers['Authorization'] = '请求头token';
    super.onRequest(options, handler);
  }

  @override
  void onError(DioException err, ErrorInterceptorHandler handler) {
    if (err.type == DioExceptionType.unknown) {
      // Placeholder: network unavailable, show a "try again later" hint here.
    }
    super.onError(err, handler);
  }

  @override
  void onResponse(
      Response<dynamic> response, ResponseInterceptorHandler handler) {
    super.onResponse(response, handler);
  }
}
使用
dart
// Requests an article through the streaming chat endpoint and accumulates the answer.
//
// The server sends SSE frames made of an `event:<name>` line followed by a
// `data:{<payload>}` line. Event name "d" carries content (or the end
// sentinel "<<END>>"), event name "e" reports a server-side failure.
chatStream() async {
  final stream = await HttpUtil.of().postStream("/api/chat/chatStream");
  String respContent = "";
  // Name announced by the most recent `event:` line of the current frame.
  String eventName = "";
  stream?.listen((content) {
    debugPrint(content);
    if (content.isEmpty) {
      // A blank line terminates the current SSE frame.
      eventName = "";
      return;
    }
    if (content.startsWith("event:")) {
      eventName = content.substring("event:".length).trim();
      return;
    }
    if (!content.startsWith("data:")) {
      return;
    }
    // The payload is wrapped in one pair of braces by the server. Use the first
    // "{" and the LAST "}" so that braces inside the content are preserved.
    var start = content.indexOf("{") + 1;
    var end = content.lastIndexOf("}");
    if (start == 0 || end < start) {
      // Not a wrapped payload; ignore instead of throwing a RangeError.
      return;
    }
    content = content.substring(start, end);
    if (eventName == "e") {
      debugPrint("服务端返回错误:$content");
      return;
    }
    if (content == "<<END>>") {
      // End sentinel, not part of the answer.
      return;
    }
    respContent += content;
    print("返回的内容:$content");
  });
}