FastGPT源码深度剖析:如何使用LLM和向量模型进行训练

引言

在构建高效的知识库系统时,数据集的训练流程是至关重要的一环。FastGPT 提供了两种主要的训练模式------问答拆分(QA)和直接分段(chunk),旨在优化文本处理和知识提取的效率。本文将详细解析FastGPT的训练逻辑、任务处理机制以及如何通过数据集质量提升和上下文连贯性增强来优化训练效果。通过深入探讨FastGPT的内部工作机制,我们能够更好地理解如何利用这一工具来构建更加智能和响应迅速的知识库系统。

训练模式

FastGPT 知识库数据集训练模式包含问答拆分qa,直接分段chunk 两种模式,训练的基本逻辑如下 :

  • 根据训练模式判断是否需要与 agentModel 进行 QA 问答;
  • 使用 vertorModel 对数据集的每个 chunk 进行 embedding;

在训练时需要用到vertorModelagentModel两种模型,配置在 data/config.json 文件中。

在实际训练时,chunk 有长度限制,chunkSize 影响因素包含:

  1. agentModel 的 maxContext;
  2. 训练模式中设定的分块长度;

数据训练

创建对应任务之后,系统采用定时任务去扫描QA任务队列,Embedding任务队列进行任务处理,定时任务扫描频率1分钟。代码路径service/common/system/cron.ts,核心代码如下:

ts 复制代码
export const startCron = () => {
  setUpdateSystemConfigCron();
  setTrainingQueueCron();
};

export const setUpdateSystemConfigCron = () => {
  setCron('*/5 * * * *', () => {
    initSystemConfig();
    console.log('refresh system config');
  });
};

export const setTrainingQueueCron = () => {
  setCron('*/1 * * * *', () => {
    generateVector();
    generateQA();
  });
};

startCron 是定时任务的启动入口,其具体调用时机在 mongodb 连接的 afterHook 里面。

ts 复制代码
export function connectToDatabase(): Promise<void> {
  return connectMongo({
    beforeHook: () => {},
    afterHook: async () => {
      initVectorStore();
      // start queue
      startQueue();
      // init system config
      getInitConfig();

      // cron
      startCron();

      initRootUser();
    }
  });
}

QA任务训练流程

QA 任务训练的核心代码在 generateQA里面,主要流程如下:

  1. 判断QA任务是否达到并发上限;
  2. 从QA队列里找任务,如未找到则说明无训练任务结束执行;
  3. 检测团队的余额,不足则发送消息并锁定团队的训练任务;
  4. 与LLM交互,将文本转为QA形式,并将QA对推入 Embedding 任务队列;
  5. 记录QA训练日志、团队账单;
  6. 任务完成,继续下一轮任务训练;

以下是 generateQA的源码,除了以上说的操作外,里面还有一些细节,具体可以查看代码。

ts 复制代码
export async function generateQA(): Promise<any> {
  if (global.qaQueueLen >= global.systemEnv.qaMaxProcess) return;
  global.qaQueueLen++;

  // get training data
  const {data,text,done = false,error = false} = await (async () => {
    try {
      const data = await MongoDatasetTraining.findOneAndUpdate(
        {
          lockTime: { $lte: new Date(Date.now() - 6 * 60 * 1000) },
          mode: TrainingModeEnum.qa
        },
        {
          lockTime: new Date()
        }
      )
        .select({
          _id: 1,
          userId: 1,
          datasetId: 1,
          collectionId: 1,
          q: 1,
          model: 1,
          chunkIndex: 1,
          prompt: 1
          //... 省略部分字段
        })
        .lean();

      // task preemption
      if (!data) {
        return {
          done: true
        };
      }
      return {
        data,
        text: data.q
      };
    } catch (error) {
      console.log(`Get Training Data error`, error);
      return {
        error: true
      };
    }
  })();

  if (done || !data) {
    if (reduceQueue()) {
      console.log(`【QA】Task Done`);
    }
    return;
  }
  if (error) {
    reduceQueue();
    return generateQA();
  }

  // auth balance
  try {
    await authTeamBalance(data.teamId);
  } catch (error: any) {
    if (error?.statusText === UserErrEnum.balanceNotEnough) {
      // ... 省略代码,通知"团队账号余额不足",锁定团队的训练任务。
  }

  try {
    const startTime = Date.now();
    const model = getLLMModel(data.model)?.model;
    const prompt = `${data.prompt || Prompt_AgentQA.description}
${replaceVariable(Prompt_AgentQA.fixedText, { text })}`;

    // request LLM to get QA
    const messages: ChatMessageItemType[] = [
      {
        role: 'user',
        content: prompt
      }
    ];

    const ai = getAIApi({
      timeout: 600000
    });
    const chatResponse = await ai.chat.completions.create({
      model,
      temperature: 0.3,
      messages,
      stream: false
    });
    const answer = chatResponse.choices?.[0].message?.content || '';

    const qaArr = formatSplitText(answer, text); // 格式化后的QA对

    // get vector and insert
    const { insertLen } = await pushDataToTrainingQueue({
      teamId: data.teamId,
      tmbId: data.tmbId,
      collectionId: data.collectionId,
      trainingMode: TrainingModeEnum.chunk,
      data: qaArr.map((item) => ({
        ...item,
        chunkIndex: data.chunkIndex
      })),
      billId: data.billId
    });

    // delete data from training
    await MongoDatasetTraining.findByIdAndDelete(data._id);

    addLog.info(`QA Training Finish`, {
      time: `${(Date.now() - startTime) / 1000}s`,
      splitLength: qaArr.length,
      usage: chatResponse.usage
    });

    // add bill
    if (insertLen > 0) {
      pushQABill({teamId: data.teamId, tmbId: data.tmbId, charsLength: `${prompt}${answer}`.length,  billId: data.billId, model});
    } else {
      addLog.info(`QA result 0:`, { answer });
    }

    reduceQueue();
    generateQA();
  } catch (err: any) {
    // ... 省略代码,容错处理
  }
}

Embedding 任务训练流程

Embedding 任务训练的核心代码在 generateVector里面,其核心逻辑与 QA 训练基本一致,主要流程如下:

  1. 判断Embedding任务是否达到并发上限;
  2. 从Embedding队列里找任务,如未找到则说明无训练任务结束执行;
  3. 检测团队的余额,不足则发送消息并锁定团队的训练任务;
  4. 检测训练任务的文本是否存在,不存在说明 chunk 无作用,删除该训练任务;存在则与 vertorModel 交互获取向量并存入向量数据库。
  5. 记录QA训练日志、团队账单;
  6. 任务完成,继续下一轮任务训练;

generateVector 源码如下:

ts 复制代码
export async function generateVector(): Promise<any> {
  if (global.vectorQueueLen >= global.systemEnv.vectorMaxProcess) return;
  global.vectorQueueLen++;

  const start = Date.now();

  // get training data
  const { data, dataItem, done = false, error = false } = await (async () => {
    try {
      const data = await MongoDatasetTraining.findOneAndUpdate(
        {
          lockTime: { $lte: new Date(Date.now() - 1 * 60 * 1000) },
          mode: TrainingModeEnum.chunk
        },
        {
          lockTime: new Date()
        }
      )
        .sort({weight: -1})
        .select({
          _id: 1,
          userId: 1,
          teamId: 1,
          tmbId: 1,
          datasetId: 1,
          collectionId: 1,
          q: 1,
          a: 1,
          chunkIndex: 1,
          indexes: 1,
          model: 1,
          billId: 1
        })
        .lean();

      // task preemption
      if (!data) {
        return {
          done: true
        };
      }
      return {
        data,
        dataItem: {
          q: data.q,
          a: data.a || '',
          indexes: data.indexes
        }
      };
    } catch (error) {
      console.log(`Get Training Data error`, error);
      return {
        error: true
      };
    }
  })();

  if (done || !data) {
    if (reduceQueue()) {
      console.log(`【index】Task done`);
    }
    return;
  }
  if (error) {
    reduceQueue();
    return generateVector();
  }

  // auth balance
  try {
    await authTeamBalance(data.teamId);
  } catch (error: any) {
    if (error?.statusText === UserErrEnum.balanceNotEnough) {
      // ... 省略代码,通知"团队账号余额不足",锁定团队的训练任务。
  }

  // create vector and insert
  try {
    // invalid data
    if (!data.q.trim()) {
      await MongoDatasetTraining.findByIdAndDelete(data._id);
      reduceQueue();
      generateVector();
      return;
    }

    // insert data to pg
    const { charsLength } = await insertData2Dataset({
      teamId: data.teamId,
      tmbId: data.tmbId,
      datasetId: data.datasetId,
      collectionId: data.collectionId,
      q: dataItem.q,
      a: dataItem.a,
      chunkIndex: data.chunkIndex,
      indexes: dataItem.indexes,
      model: data.model
    });

    // push bill
    pushGenerateVectorBill({
      teamId: data.teamId,
      tmbId: data.tmbId,
      charsLength,
      model: data.model,
      billId: data.billId
    });

    // delete data from training
    await MongoDatasetTraining.findByIdAndDelete(data._id);
    reduceQueue();
    generateVector();

    console.log(`embedding finished, time: ${Date.now() - start}ms`);
  } catch (err: any) {
    reduceQueue();
    // log
    if (err?.response) {
      addLog.info('openai error: 生成向量错误', {
        status: err.response?.status,
        stateusText: err.response?.statusText,
        data: err.response?.data
      });
    } else {
      console.log(err);
      addLog.error(getErrText(err, '生成向量错误'));
    }

    // message error or openai account error
    if (
      err?.message === 'invalid message format' ||
      err.response?.data?.error?.type === 'invalid_request_error' ||
      err?.code === 500
    ) {
      addLog.info('Lock training data');
      console.log(err?.code);
      console.log(err.response?.data?.error?.type);
      console.log(err?.message);

      try {
        await MongoDatasetTraining.findByIdAndUpdate(data._id, {
          lockTime: new Date('2998/5/5')
        });
      } catch (error) {}
      return generateVector();
    }

    setTimeout(() => {
      generateVector();
    }, 1000);
  }
}

总结

FastGPT 提供的文本分割、知识库问答的文本训练功能能够基本满足RAG的需求,如果在实际落地时觉得效果不好可以从以下几个方面考虑优化:

  1. 数据集质量提升:数据集的准确性和完整性直接影响模型的表现。对数据集进行彻底的预处理,包括清洗、格式化和去噪,以确保输入数据的质量。

  2. 上下文连贯性增强:FastGPT在进行文本分割时,应考虑将相关内容的上下文信息一并纳入处理范围。这样可以提高数据召回的准确性,确保信息的完整性。

  3. 尝试QA问答: 如果使用的文本切割模式,可以尝试使用QA问答模式可能在召回时更加准确。

相关推荐
bobostudio19954 小时前
TypeScript 设计模式之【策略模式】
前端·javascript·设计模式·typescript·策略模式
Tiffany_Ho12 小时前
【TypeScript】知识点梳理(三)
前端·typescript
爱喝白开水a14 小时前
关于大模型在企业生产环境中的独立部署问题
人工智能·深度学习·llm·大语言模型·ai大模型·计算机技术·本地部署大模型
Langchain15 小时前
不可错过!CMU最新《生成式人工智能大模型》课程:从文本、图像到多模态大模型
人工智能·自然语言处理·langchain·大模型·llm·大语言模型·多模态大模型
幽影相随16 小时前
构建llama.cpp并在linux上使用gpu
llm·llama.cpp
看到请催我学习16 小时前
如何实现两个标签页之间的通信
javascript·css·typescript·node.js·html5
AAI机器之心17 小时前
LLM大模型:开源RAG框架汇总
人工智能·chatgpt·开源·大模型·llm·大语言模型·rag
网安-搬运工1 天前
RAG再总结之如何使大模型更好使用外部数据:四个不同层级及查询-文档对齐策略
人工智能·自然语言处理·大模型·llm·大语言模型·ai大模型·rag
大模型八哥1 天前
大模型扫盲系列——大模型实用技术介绍(上)
人工智能·程序人生·ai·大模型·llm·llama·ai大模型
天涯学馆1 天前
Deno与Secure TypeScript:安全的后端开发
前端·typescript·deno