FastGPT源码深度剖析:如何使用LLM和向量模型进行训练

引言

在构建高效的知识库系统时,数据集的训练流程是至关重要的一环。FastGPT 提供了两种主要的训练模式------问答拆分(QA)和直接分段(chunk),旨在优化文本处理和知识提取的效率。本文将详细解析FastGPT的训练逻辑、任务处理机制以及如何通过数据集质量提升和上下文连贯性增强来优化训练效果。通过深入探讨FastGPT的内部工作机制,我们能够更好地理解如何利用这一工具来构建更加智能和响应迅速的知识库系统。

训练模式

FastGPT 知识库数据集训练模式包含问答拆分qa,直接分段chunk 两种模式,训练的基本逻辑如下 :

  • 根据训练模式判断是否需要与 agentModel 进行 QA 问答;
  • 使用 vertorModel 对数据集的每个 chunk 进行 embedding;

在训练时需要用到vertorModelagentModel两种模型,配置在 data/config.json 文件中。

在实际训练时,chunk 有长度限制,chunkSize 影响因素包含:

  1. agentModel 的 maxContext;
  2. 训练模式中设定的分块长度;

数据训练

创建对应任务之后,系统采用定时任务去扫描QA任务队列,Embedding任务队列进行任务处理,定时任务扫描频率1分钟。代码路径service/common/system/cron.ts,核心代码如下:

ts 复制代码
export const startCron = () => {
  setUpdateSystemConfigCron();
  setTrainingQueueCron();
};

export const setUpdateSystemConfigCron = () => {
  setCron('*/5 * * * *', () => {
    initSystemConfig();
    console.log('refresh system config');
  });
};

export const setTrainingQueueCron = () => {
  setCron('*/1 * * * *', () => {
    generateVector();
    generateQA();
  });
};

startCron 是定时任务的启动入口,其具体调用时机在 mongodb 连接的 afterHook 里面。

ts 复制代码
export function connectToDatabase(): Promise<void> {
  return connectMongo({
    beforeHook: () => {},
    afterHook: async () => {
      initVectorStore();
      // start queue
      startQueue();
      // init system config
      getInitConfig();

      // cron
      startCron();

      initRootUser();
    }
  });
}

QA任务训练流程

QA 任务训练的核心代码在 generateQA里面,主要流程如下:

  1. 判断QA任务是否达到并发上限;
  2. 从QA队列里找任务,如未找到则说明无训练任务结束执行;
  3. 检测团队的余额,不足则发送消息并锁定团队的训练任务;
  4. 与LLM交互,将文本转为QA形式,并将QA对推入 Embedding 任务队列;
  5. 记录QA训练日志、团队账单;
  6. 任务完成,继续下一轮任务训练;

以下是 generateQA的源码,除了以上说的操作外,里面还有一些细节,具体可以查看代码。

ts 复制代码
export async function generateQA(): Promise<any> {
  if (global.qaQueueLen >= global.systemEnv.qaMaxProcess) return;
  global.qaQueueLen++;

  // get training data
  const {data,text,done = false,error = false} = await (async () => {
    try {
      const data = await MongoDatasetTraining.findOneAndUpdate(
        {
          lockTime: { $lte: new Date(Date.now() - 6 * 60 * 1000) },
          mode: TrainingModeEnum.qa
        },
        {
          lockTime: new Date()
        }
      )
        .select({
          _id: 1,
          userId: 1,
          datasetId: 1,
          collectionId: 1,
          q: 1,
          model: 1,
          chunkIndex: 1,
          prompt: 1
          //... 省略部分字段
        })
        .lean();

      // task preemption
      if (!data) {
        return {
          done: true
        };
      }
      return {
        data,
        text: data.q
      };
    } catch (error) {
      console.log(`Get Training Data error`, error);
      return {
        error: true
      };
    }
  })();

  if (done || !data) {
    if (reduceQueue()) {
      console.log(`【QA】Task Done`);
    }
    return;
  }
  if (error) {
    reduceQueue();
    return generateQA();
  }

  // auth balance
  try {
    await authTeamBalance(data.teamId);
  } catch (error: any) {
    if (error?.statusText === UserErrEnum.balanceNotEnough) {
      // ... 省略代码,通知"团队账号余额不足",锁定团队的训练任务。
  }

  try {
    const startTime = Date.now();
    const model = getLLMModel(data.model)?.model;
    const prompt = `${data.prompt || Prompt_AgentQA.description}
${replaceVariable(Prompt_AgentQA.fixedText, { text })}`;

    // request LLM to get QA
    const messages: ChatMessageItemType[] = [
      {
        role: 'user',
        content: prompt
      }
    ];

    const ai = getAIApi({
      timeout: 600000
    });
    const chatResponse = await ai.chat.completions.create({
      model,
      temperature: 0.3,
      messages,
      stream: false
    });
    const answer = chatResponse.choices?.[0].message?.content || '';

    const qaArr = formatSplitText(answer, text); // 格式化后的QA对

    // get vector and insert
    const { insertLen } = await pushDataToTrainingQueue({
      teamId: data.teamId,
      tmbId: data.tmbId,
      collectionId: data.collectionId,
      trainingMode: TrainingModeEnum.chunk,
      data: qaArr.map((item) => ({
        ...item,
        chunkIndex: data.chunkIndex
      })),
      billId: data.billId
    });

    // delete data from training
    await MongoDatasetTraining.findByIdAndDelete(data._id);

    addLog.info(`QA Training Finish`, {
      time: `${(Date.now() - startTime) / 1000}s`,
      splitLength: qaArr.length,
      usage: chatResponse.usage
    });

    // add bill
    if (insertLen > 0) {
      pushQABill({teamId: data.teamId, tmbId: data.tmbId, charsLength: `${prompt}${answer}`.length,  billId: data.billId, model});
    } else {
      addLog.info(`QA result 0:`, { answer });
    }

    reduceQueue();
    generateQA();
  } catch (err: any) {
    // ... 省略代码,容错处理
  }
}

Embedding 任务训练流程

Embedding 任务训练的核心代码在 generateVector里面,其核心逻辑与 QA 训练基本一致,主要流程如下:

  1. 判断Embedding任务是否达到并发上限;
  2. 从Embedding队列里找任务,如未找到则说明无训练任务结束执行;
  3. 检测团队的余额,不足则发送消息并锁定团队的训练任务;
  4. 检测训练任务的文本是否存在,不存在说明 chunk 无作用,删除该训练任务;存在则与 vertorModel 交互获取向量并存入向量数据库。
  5. 记录QA训练日志、团队账单;
  6. 任务完成,继续下一轮任务训练;

generateVector 源码如下:

ts 复制代码
export async function generateVector(): Promise<any> {
  if (global.vectorQueueLen >= global.systemEnv.vectorMaxProcess) return;
  global.vectorQueueLen++;

  const start = Date.now();

  // get training data
  const { data, dataItem, done = false, error = false } = await (async () => {
    try {
      const data = await MongoDatasetTraining.findOneAndUpdate(
        {
          lockTime: { $lte: new Date(Date.now() - 1 * 60 * 1000) },
          mode: TrainingModeEnum.chunk
        },
        {
          lockTime: new Date()
        }
      )
        .sort({weight: -1})
        .select({
          _id: 1,
          userId: 1,
          teamId: 1,
          tmbId: 1,
          datasetId: 1,
          collectionId: 1,
          q: 1,
          a: 1,
          chunkIndex: 1,
          indexes: 1,
          model: 1,
          billId: 1
        })
        .lean();

      // task preemption
      if (!data) {
        return {
          done: true
        };
      }
      return {
        data,
        dataItem: {
          q: data.q,
          a: data.a || '',
          indexes: data.indexes
        }
      };
    } catch (error) {
      console.log(`Get Training Data error`, error);
      return {
        error: true
      };
    }
  })();

  if (done || !data) {
    if (reduceQueue()) {
      console.log(`【index】Task done`);
    }
    return;
  }
  if (error) {
    reduceQueue();
    return generateVector();
  }

  // auth balance
  try {
    await authTeamBalance(data.teamId);
  } catch (error: any) {
    if (error?.statusText === UserErrEnum.balanceNotEnough) {
      // ... 省略代码,通知"团队账号余额不足",锁定团队的训练任务。
  }

  // create vector and insert
  try {
    // invalid data
    if (!data.q.trim()) {
      await MongoDatasetTraining.findByIdAndDelete(data._id);
      reduceQueue();
      generateVector();
      return;
    }

    // insert data to pg
    const { charsLength } = await insertData2Dataset({
      teamId: data.teamId,
      tmbId: data.tmbId,
      datasetId: data.datasetId,
      collectionId: data.collectionId,
      q: dataItem.q,
      a: dataItem.a,
      chunkIndex: data.chunkIndex,
      indexes: dataItem.indexes,
      model: data.model
    });

    // push bill
    pushGenerateVectorBill({
      teamId: data.teamId,
      tmbId: data.tmbId,
      charsLength,
      model: data.model,
      billId: data.billId
    });

    // delete data from training
    await MongoDatasetTraining.findByIdAndDelete(data._id);
    reduceQueue();
    generateVector();

    console.log(`embedding finished, time: ${Date.now() - start}ms`);
  } catch (err: any) {
    reduceQueue();
    // log
    if (err?.response) {
      addLog.info('openai error: 生成向量错误', {
        status: err.response?.status,
        stateusText: err.response?.statusText,
        data: err.response?.data
      });
    } else {
      console.log(err);
      addLog.error(getErrText(err, '生成向量错误'));
    }

    // message error or openai account error
    if (
      err?.message === 'invalid message format' ||
      err.response?.data?.error?.type === 'invalid_request_error' ||
      err?.code === 500
    ) {
      addLog.info('Lock training data');
      console.log(err?.code);
      console.log(err.response?.data?.error?.type);
      console.log(err?.message);

      try {
        await MongoDatasetTraining.findByIdAndUpdate(data._id, {
          lockTime: new Date('2998/5/5')
        });
      } catch (error) {}
      return generateVector();
    }

    setTimeout(() => {
      generateVector();
    }, 1000);
  }
}

总结

FastGPT 提供的文本分割、知识库问答的文本训练功能能够基本满足RAG的需求,如果在实际落地时觉得效果不好可以从以下几个方面考虑优化:

  1. 数据集质量提升:数据集的准确性和完整性直接影响模型的表现。对数据集进行彻底的预处理,包括清洗、格式化和去噪,以确保输入数据的质量。

  2. 上下文连贯性增强:FastGPT在进行文本分割时,应考虑将相关内容的上下文信息一并纳入处理范围。这样可以提高数据召回的准确性,确保信息的完整性。

  3. 尝试QA问答: 如果使用的文本切割模式,可以尝试使用QA问答模式可能在召回时更加准确。

相关推荐
岁月月宝贝2 小时前
Datawhale冬令营第二期!Task2🌼
llm·ai辅助编程
冰红茶-Tea7 小时前
typescript数据类型(二)
前端·typescript
zaim18 小时前
计算机的错误计算(一百九十二)
人工智能·ai·大模型·llm·错误·误差/error·余割/csc
Hoper.J12 小时前
微调 BERT:实现抽取式问答
人工智能·深度学习·自然语言处理·llm·bert
知来者逆1 天前
Binoculars——分析证实大语言模型生成文本的检测和引用量按学科和国家明确显示了使用偏差的多样性和对内容类型的影响
人工智能·深度学习·语言模型·自然语言处理·llm·大语言模型
几米哥1 天前
如何构建高效的AI代理系统:LLM应用实践与最佳方案的深度解析
llm·aigc
测试者家园1 天前
ChatGPT生成接口文档实践案例(二)
软件测试·chatgpt·llm·测试用例·测试图书·质量效能·用chatgpt做测试
Web阿成2 天前
3.学习webpack配置 尝试打包ts文件
前端·学习·webpack·typescript
j喬乔2 天前
Node导入不了命名函数?记一次Bug的探索
typescript·node.js
yg_小小程序员3 天前
vue3中使用vuedraggable实现拖拽
typescript·vue