项目整合星火认知大模型(后端)

准备工作

  1. 在讯飞星火大模型平台上按照官方提示申请tokens。申请成功后可以获得对应的APISecret、APIKey,以及之前创建的应用的APPID,这些就是后面要用到的信息

  2. 搭建项目

整体思路

考虑到敏感信息等安全性问题,这里和大模型的交互都放到后端去做。

客户端、服务端、星火大模型之间均通过WebSocket建立连接:用户提问时,客户端向SpringBoot服务端发送消息;服务端收到消息后,创建与星火大模型的连接并发起请求,获取到响应结果后再发送给客户端

==如果想实现根据上下文问答,就要把历史问题和历史回答结果全部传回大模型服务端==

请求参数的构建和响应参数的解析请参照官方的Web API文档

接入大模型

服务端和大模型连接

java 复制代码
/**
 * 与大模型建立Socket连接
 *
 * @author gwj
 */
@Slf4j
public class BigModelNew extends WebSocketListener {

    public static final String appid = "appid";

    // 对话历史存储集合
    public static Map<Long,List<RoleContent>> hisMap = new ConcurrentHashMap<>();

    public static String totalAnswer = ""; // 大模型的答案汇总

    private static String newAsk = "";

    public static synchronized void ask(String question) {
        newAsk = question;
    }

    public static final Gson gson = new Gson();

    // 项目中需要用到的参数
    private Long userId;
    private Boolean wsCloseFlag;


    // 构造函数
    public BigModelNew(Long userId, Boolean wsCloseFlag) {
        this.userId = userId;
        this.wsCloseFlag = wsCloseFlag;
    }

    // 由于历史记录最大上线1.2W左右,需要判断是能能加入历史
    public boolean canAddHistory() {
        int len = 0;
        List<RoleContent> list = hisMap.get(userId);
        for (RoleContent temp : list) {
            len = len + temp.getContent().length();
        }
        if (len > 12000) {
            list.remove(0);
            list.remove(1);
            list.remove(2);
            list.remove(3);
            list.remove(4);
            return false;
        } else {
            return true;
        }
    }

    // 线程来发送参数
    class ModelThread extends Thread {
        private WebSocket webSocket;
        private Long userId;

        public ModelThread(WebSocket webSocket, Long userId) {
            this.webSocket = webSocket;
            this.userId = userId;
        }

        public void run() {
            try {
                JSONObject requestJson = new JSONObject();

                JSONObject header = new JSONObject();  // header参数
                header.put("app_id", appid);
                header.put("uid", userId+UUID.randomUUID().toString().substring(0,16));

                JSONObject parameter = new JSONObject(); // parameter参数
                JSONObject chat = new JSONObject();
                chat.put("domain", "4.0Ultra");
                chat.put("temperature", 0.5);
                chat.put("max_tokens", 4096);
                parameter.put("chat", chat);

                JSONObject payload = new JSONObject(); // payload参数
                JSONObject message = new JSONObject();
                JSONArray text = new JSONArray();

                // 历史问题获取
                List<RoleContent> list = hisMap.get(userId);
                if (list != null && !list.isEmpty()) {
                    //log.info("his:{}",list);
                    for (RoleContent tempRoleContent : list) {
                        text.add(JSON.toJSON(tempRoleContent));
                    }
                }

                // 最新问题
                RoleContent roleContent = new RoleContent();
                roleContent.setRole("user");
                roleContent.setContent(newAsk);
                text.add(JSON.toJSON(roleContent));
                hisMap.computeIfAbsent(userId, k -> new ArrayList<>());
                hisMap.get(userId).add(roleContent);

                message.put("text", text);
                payload.put("message", message);

                requestJson.put("header", header);
                requestJson.put("parameter", parameter);
                requestJson.put("payload", payload);
                // System.out.println(requestJson);

                webSocket.send(requestJson.toString());
                // 等待服务端返回完毕后关闭
                while (true) {
                    // System.err.println(wsCloseFlag + "---");
                    Thread.sleep(200);
                    if (wsCloseFlag) {
                        break;
                    }
                }
                webSocket.close(1000, "");
            } catch (Exception e) {
                log.error("【大模型】发送消息错误,{}",e.getMessage());
            }
        }
    }

    @Override
    public void onOpen(WebSocket webSocket, Response response) {
        super.onOpen(webSocket, response);
        log.info("上线");
        ModelThread modelThread = new ModelThread(webSocket,userId);
        modelThread.start();
    }

    @Override
    public void onMessage(WebSocket webSocket, String text) {
        JsonParse json = gson.fromJson(text, JsonParse.class);
        if (json.getHeader().getCode() != 0) {
            log.error("发生错误,错误码为:{} sid为:{}", json.getHeader().getCode(),json.getHeader().getSid());
            //System.out.println(json);
            webSocket.close(1000, "");
        }
        List<Text> textList = json.getPayload().getChoices().getText();
        for (Text temp : textList) {
            // 向客户端发送回答信息,如有存储问答需求,在此处存储
            ModelChatEndpoint.sendMsgByUserId(userId,temp.getContent());

            totalAnswer = totalAnswer + temp.getContent();
        }
        if (json.getHeader().getStatus() == 2) {
            // 可以关闭连接,释放资源
            if (canAddHistory()) {
                RoleContent roleContent = new RoleContent();
                roleContent.setRole("assistant");
                roleContent.setContent(totalAnswer);
                hisMap.get(userId).add(roleContent);
            } else {
                hisMap.get(userId).remove(0);
                RoleContent roleContent = new RoleContent();
                roleContent.setRole("assistant");
                roleContent.setContent(totalAnswer);
                hisMap.get(userId).add(roleContent);
            }
            //收到响应后让等待的线程停止等待
            wsCloseFlag = true;
        }
    }

    @Override
    public void onFailure(WebSocket webSocket, Throwable t, Response response) {
        super.onFailure(webSocket, t, response);
        try {
            if (null != response) {
                int code = response.code();
                System.out.println("onFailure code:" + code);
                System.out.println("onFailure body:" + response.body().string());
                if (101 != code) {
                    System.out.println("connection failed");
                    System.exit(0);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }


    // 鉴权方法
    public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
        URL url = new URL(hostUrl);
        // 时间
        SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
        format.setTimeZone(TimeZone.getTimeZone("GMT"));
        String date = format.format(new Date());
        // 拼接
        String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1";
        // System.err.println(preStr);
        // SHA256加密
        Mac mac = Mac.getInstance("hmacsha256");
        SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
        mac.init(spec);

        byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
        // Base64加密
        String sha = Base64.getEncoder().encodeToString(hexDigits);
        // System.err.println(sha);
        // 拼接
        String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
        // 拼接地址
        HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().//
                addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).//
                addQueryParameter("date", date).//
                addQueryParameter("host", url.getHost()).//
                build();

        return httpUrl.toString();
    }
}

其中,用来接收响应参数的相关实体类如下

java 复制代码
/**
 * Top-level shape of one response frame from the model, as deserialized by
 * {@code gson.fromJson(text, JsonParse.class)} in {@code BigModelNew.onMessage}.
 */
@Data
public class JsonParse {
    // Response header: result code, status and sid.
    Header header;
    // Response payload that carries the answer choices.
    Payload payload;
}
/**
 * Header section of a response frame.
 */
@Data
public class Header {
    // Result code; BigModelNew.onMessage treats any non-zero value as an error.
    int code;
    // Frame status; BigModelNew.onMessage treats 2 as the final frame of an answer.
    int status;
    // Identifier logged together with the error code (presumably the session id; confirm against the API docs).
    String sid;
}
/**
 * Payload section of a response frame.
 */
@Data
public class Payload {
    // Container of the answer text fragments.
    Choices choices;
}
/**
 * The {@code choices} element of a response payload.
 */
@Data
public class Choices {
    // Text fragments of this frame; BigModelNew.onMessage forwards each one to the client.
    List<Text> text;
}
/**
 * One text fragment inside a response frame.
 */
@Data
public class Text {
    // Role attached to the fragment (not read by the code shown here).
    String role;
    // Fragment content; sent to the client and appended to the accumulated answer.
    String content;
}
/**
 * One turn of the conversation history; serialized into the request's
 * {@code text} array by {@code BigModelNew.ModelThread}.
 */
@Data
public class RoleContent {
    // Speaker of the turn: "user" for questions, "assistant" for model answers.
    String role;
    // Text of the question or answer.
    String content;
}

客户端和服务端的连接

java 复制代码
/**
 * 接收客户端请求
 *
 * @author gwj
 * @date 2024/10/29 16:51
 */
@ServerEndpoint(value = "/ws/model", configurator = GetUserConfigurator.class)
@Component
@Slf4j
public class ModelChatEndpoint {
    private static AtomicInteger online = new AtomicInteger(0);
    private static final ConcurrentHashMap<Long,ModelChatEndpoint> wsMap = new ConcurrentHashMap<>();

    private static BigModelConfig config;
    @Resource
    private BigModelConfig modelConfig;

    @PostConstruct
    public void init() {
        config = modelConfig;
    }

    private Session session;
    private Long userId;

    @OnOpen
    public void onOpen(EndpointConfig config, Session session) {
        String s = config.getUserProperties().get("id").toString();
        userId = Long.parseLong(s);
        this.session = session;
        wsMap.put(userId,this);
        online.incrementAndGet();
        log.info("用户{},连接成功,在线人数:{}",userId,online);
    }

    @OnClose
    public void onClose() {
        wsMap.remove(userId);
        online.incrementAndGet();
        log.info("{},退出,在线人数:{}",userId,online);
    }

    @OnError
    public void onError(Session session, Throwable error) {
        log.error("连接出错,{}", error.getMessage());
    }

    @OnMessage
    public void onMessage(String message,Session session) throws Exception {
        BigModelNew.ask(message);
        //构建鉴权url
        String authUrl = BigModelNew.getAuthUrl(config.getHostUrl(), config.getApiKey(), config.getApiSecret());
        OkHttpClient client = new OkHttpClient.Builder().build();
        String url = authUrl.replace("http://", "ws://").replace("https://", "wss://");
        Request request = new Request.Builder().url(url).build();
        WebSocket webSocket = client.newWebSocket(request,
                new BigModelNew(this.userId, false));
        log.info("收到客户端{}的消息:{}", userId, message);
    }


    private void sendMsg(String message) {
        try {
            this.session.getBasicRemote().sendText(message);
        } catch (IOException e) {
            log.error("客户端{}发送{}失败",userId,message);
        }
    }


    /**
     * 根据userId向用户发送消息
     *
     * @param userId 用户id
     * @param message 消息
     */
    public static void sendMsgByUserId(Long userId,String message) {
        if (userId != null && wsMap.containsKey(userId)) {
            wsMap.get(userId).sendMsg(message);
        }
    }

    
}

测试

这样就简单实现了一个ai问答功能

相关推荐
hummhumm17 分钟前
第 25 章 - Golang 项目结构
java·开发语言·前端·后端·python·elasticsearch·golang
deephub18 分钟前
优化注意力层提升 Transformer 模型效率:通过改进注意力机制降低机器学习成本
人工智能·深度学习·transformer·大语言模型·注意力机制
J老熊27 分钟前
JavaFX:简介、使用场景、常见问题及对比其他框架分析
java·开发语言·后端·面试·系统架构·软件工程
搏博30 分钟前
神经网络问题之二:梯度爆炸(Gradient Explosion)
人工智能·深度学习·神经网络
AuroraI'ncoding33 分钟前
时间请求参数、响应
java·后端·spring
KGback35 分钟前
【论文解析】HAQ: Hardware-Aware Automated Quantization With Mixed Precision
人工智能
电子手信43 分钟前
知识中台在多语言客户中的应用
大数据·人工智能·自然语言处理·数据挖掘·知识图谱
不高明的骗子44 分钟前
【深度学习之一】2024最新pytorch+cuda+cudnn下载安装搭建开发环境
人工智能·pytorch·深度学习·cuda
好奇的菜鸟1 小时前
Go语言中的引用类型:指针与传递机制
开发语言·后端·golang
Alive~o.01 小时前
Go语言进阶&依赖管理
开发语言·后端·golang