hello,我是你们的ys指风不买醉。AI 是2025年的大趋势,谁又不想拥抱AI,了解AI 底层实现尼。 下面带大家 实现RAG里面 embedding 模型处理文本向量化过程
首先,embedding 是什么?
在 RAG(Retrieval-Augmented Generation)架构中,embedding 是实现文本向量化的重要组成部分。其核心思想是将自然语言文本转换为高维向量,借助这些向量可以实现基于语义的逻辑搜索。
也就是说,我们会先将资料库中的文本(比如文章标题、分类等)利用 embedding 模型转换成向量,再将用户问题同样转换为向量,通过计算两者之间的相似度,来找到最符合用户意图的文本内容。
想了解更详细RAG:# AI全栈必问的RAG 是什么!
简化版 embedding 实现流程
下面的示例展示了如何快速实现一个简化版的 embedding 应用,包括后端环境搭建、模型封装、文件读写及跨域处理等。
1. 环境初始化与模型封装
首先通过 npm init -y
初始化后端 Node.js 环境。与之前封装 openai 类似,现在我们封装的是 embedding 模型。在此过程中建议使用 dotenv
模块保护你的 API key,防止泄露。
mjs
// openai 实例化
import OpenAI from 'openai';
import dotenv from 'dotenv';
dotenv.config({
path: '.env'
});
export const client = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
baseURL: process.env.OPENAI_API_BASE_URL,
});
2. 读写文件及调用 embedding 模型
利用 fs/promises
模块进行文件的读写操作,避免回调地狱,同时使用 async/await 让代码更加清晰。示例中,我们从 posts.json
中读取待向量化的文章数据,然后调用 embedding 模型生成对应的向量,并将结果存储到新的文件中。
数据,可以自己模拟
json
[
{
"title": "如何使用 Nuxt.js 进行服务器端渲染",
"category": "前端开发"
},
... // 怕篇幅有点多,其他可以自行模拟这种格式
]
文件放的地方,参照文件目录:
mjs
import fs from 'fs/promises';
import { client } from './app.service.mjs';
// 定义输入输出文件路径
const inputFilePath = './data/posts.json';
const outputFilePath = './data/posts_with_embeddings.json';
// 异步读取数据文件并解析 JSON 格式
const data = await fs.readFile(inputFilePath, 'utf8');
const posts = JSON.parse(data);
const postsWithEmbedding = [];
// 遍历每篇文章,生成 embedding 向量
for (const { title, category } of posts) {
const response = await client.embeddings.create({
model: 'text-embedding-ada-002',
input: `标题:${title};分类:${category}`
});
postsWithEmbedding.push({
title,
category,
embedding: response.data[0].embedding
});
}
// 将生成 embedding 的结果写入到新文件中
await fs.writeFile(outputFilePath, JSON.stringify(postsWithEmbedding));
3. 构建后端服务并实现搜索接口
使用 Koa 框架搭建服务,并通过 @koa/cors
处理跨域问题。由于前端传值通常采用 JSON 格式,因此引入 koa-bodyparser
来自动解析请求体。以下代码展示了如何监听 3000 端口,并实现一个 /search
接口,用于接收查询关键字、生成向量并计算余弦相似度,最后返回最匹配的结果。
mjs
import Koa from 'koa';
import cors from '@koa/cors';
import Router from 'koa-router';
import bodyParser from 'koa-bodyparser';
import { client } from './app.service.mjs';
import fs from 'fs/promises';
const inputFilePath = './data/posts_with_embeddings.json';
const data = await fs.readFile(inputFilePath, 'utf8');
const posts = JSON.parse(data);
const app = new Koa();
const router = new Router();
const port = 3000;
app.use(cors());
app.use(bodyParser());
// 使用路由处理请求
app.use(router.routes());
app.use(router.allowedMethods());
// 监听服务启动
app.listen(port, () => {
console.log(`Server is running on port ${port}`);
});
// 计算余弦相似度的函数
function cosineSimilarity(a, b) {
if (a.length !== b.length) {
throw new Error('向量长度不匹配');
}
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
// 定义搜索路由
router.post('/search', async (ctx) => {
const { keword } = ctx.request.body; // 从请求体中获取关键字
console.log(keword);
// 生成查询关键字的 embedding 向量
const response = await client.embeddings.create({
model: 'text-embedding-ada-002',
input: keword,
});
const { embedding } = response.data[0]; // 获取生成的向量
// 计算每篇文章与查询向量的相似度
const results = posts.map(item => ({
...item,
similarity: cosineSimilarity(embedding, item.embedding)
}));
// 按相似度降序排序,并提取最相似的前三条记录
const topResults = results.sort((a, b) => b.similarity - a.similarity)
.slice(0, 3)
.map((item, index) => ({
id: index,
title: `${index + 1}.${item.title}, ${item.category}`
}));
ctx.body = {
status: 200,
data: topResults
};
});
注意sort 返回新数组,不能直接使用data:results
。可以采用在原results 链式调用sort,也可以使用topResults
接收新值传给data。
余弦相似度函数解析
下面这段代码的作用是计算两个向量之间的余弦相似度。余弦相似度是一种衡量两个向量在方向上相似程度的指标,数值范围通常在 -1 到 1 之间(对于正向量,一般在 0 到 1 之间)。值越接近 1,表示两个向量在空间中的方向越接近;值越低,则说明两个向量在语义上越不相关。
js
function cosineSimilarity(a, b) {
if (a.length !== b.length) {
throw new Error('向量长度不匹配');
}
let dotProduct = 0;
let normA = 0;
let normB = 0;
for (let i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
在这个函数中:
- 判断长度是否一致:若两个向量长度不同,则抛出错误。
- 计算点积:遍历向量,逐项乘积相加。
- 计算向量模:分别求出向量 a 和向量 b 的平方和,然后开平方。
- 返回余弦相似度:将点积除以两个向量模的乘积,得到两向量之间的相似度。
网上借了两张图,方便友友理解:
CORS 配置扩展
默认情况下,我们允许所有跨域请求,但如果需要更细粒度的控制,可以配置允许跨域的源。例如,下面的代码展示了如何设置允许跨域请求的来源、方法、请求头以及是否允许携带凭据。
js
// 配置 CORS
app.use(cors({
origin: (ctx) => {
const allowedOrigins = ['http://localhost:3000', 'http://example.com'];
const requestOrigin = ctx.request.header.origin;
if (allowedOrigins.includes(requestOrigin)) {
return requestOrigin; // 允许该来源
}
return ''; // 拒绝跨域请求
},
allowMethods: ['GET', 'POST'], // 允许的 HTTP 方法
allowHeaders: ['Content-Type', 'Authorization'], // 允许的请求头
credentials: true // 允许携带凭据
}));
小结
通过以上代码示例,我们展示了如何利用 embedding 模型实现文本向量化,再结合余弦相似度计算实现基于自然语义的搜索。这种方式不仅提升了搜索的准确性,还能应对复杂的文本匹配场景。