大量数据相似度加速计算

背景

在实际工作中,有100万的数据,需要将100万条数据中,语义相似的聚合一起作为list,由于数据量过大,计算相似性耗时较久

例如:
合并后的数据

复制代码
[[你好,你好啊,您好,hello],
[深圳明天的天气],[明天深圳天气是怎么样的?]]

思路

① 先计算100万条数据的 embedding,假设每条长度768维,100万*768,计算量还是相当大的,耗时久

② 通过余弦相似度计算两两之间的相似度的概率,每一条数据要对其他所有的数据计算相似性

③拿取两两相似的概率,判断概率大于阈值的,合并到一起

优化思路

① 将100万*768转为numpy的数组计算

② 计算过的不需要计算了,比如计算了1,2。之后就不用计算2,1了,例如:

下标 待求相似的下标范围
0 [1-100万]
1 [2-100万]
2 [3-100万]
- -

就是求矩阵的下三角

③ 计算使用并行的方式处理

④ 使用GPU的方式计算,如果使用cpu,容易oom,50w的数据耗时大概2h左右,如果使用GPU的方式时,需要5min左右

代码

python 复制代码
def judgeSimilar(df, chunk_size: int = 5000, device='cuda'):
    print("----------计算相似性的数据----------")
    embeddings = np.array(df['embedding'].tolist())
    
    # 将embeddings转换为PyTorch tensor并迁移到GPU
    embeddings_tensor = torch.tensor(embeddings, device=device, dtype=torch.float32)
    
    # 归一化向量
    norm_embeddings = embeddings_tensor / torch.norm(embeddings_tensor, dim=1, keepdim=True)
    
    # 初始化相似度字典
    similarity_dict = {}

    def compute_similarity_chunk(start_idx, end_idx):
        local_similarity_dict = {}
        for i in range(start_idx, end_idx):
            # 计算当前向量与之后所有向量的余弦相似性
            similarities = torch.matmul(norm_embeddings[i], norm_embeddings[i+1:].T)
            similar_indices = torch.where(similarities > 0.93)[0]
            if similar_indices.size(0) > 0:
                local_similarity_dict[i] = (similar_indices + i + 1).tolist()
        return local_similarity_dict

    # 创建任务列表
    tasks = [(i, min(i + chunk_size, len(norm_embeddings))) for i in range(0, len(norm_embeddings), chunk_size)]

    # 使用 tqdm 进度条
    with tqdm(total=len(tasks)) as pbar:
        def update(*a):
            pbar.update()

        # 使用 Parallel 和 delayed 进行并行计算
        results = Parallel(n_jobs=-1, backend='loky', verbose=5)(
            delayed(compute_similarity_chunk)(start, end) for start, end in tasks)

        # 在每次计算完成后更新进度条
        for _ in results:
            update()

    # 汇总结果
    for local_dict in results:
        similarity_dict.update(local_dict)

    # 对结果进行去重
    keys_to_remove = set()
    all_keys = similarity_dict.keys()
    for key, values in similarity_dict.items():
        for v in values:
            if v in all_keys:
                keys_to_remove.add(v)
    for key in keys_to_remove:
        del similarity_dict[key]
    similarity_dict = {int(key): value if not isinstance(value, list) else [int(x) for x in value] for key, value in
                       similarity_dict.items()}
    # 保存相似性的结果
    with open('similarity_dict.json', 'w') as file:
        json.dump(similarity_dict, file)
    print("数据写入文件:similarity_dict.json")
    return similarity_dict
相关推荐
算法鑫探2 小时前
闰年判断:C语言实战解析
c语言·数据结构·算法·新人首发
yaoxin5211232 小时前
384. Java IO API - Java 文件复制工具:Copy 示例完整解析
java·开发语言·python
WBluuue2 小时前
数据结构与算法:康托展开、约瑟夫环、完美洗牌
c++·算法
NotFound4862 小时前
实战指南如何实现Java Web 拦截机制:Filter 与 Interceptor 深度分享
java·开发语言·前端
Elastic 中国社区官方博客2 小时前
Elasticsearch:快速近似 ES|QL - 第一部分
大数据·运维·数据库·elasticsearch·搜索引擎·全文检索
木子墨5162 小时前
LeetCode 热题 100 精讲 | 并查集篇:最长连续序列 · 岛屿数量 · 省份数量 · 冗余连接 · 等式方程的可满足性
数据结构·c++·算法·leetcode
王老师青少年编程3 小时前
csp信奥赛C++高频考点专项训练之贪心算法 --【线性扫描贪心】:均分纸牌
c++·算法·编程·贪心·csp·信奥赛·均分纸牌
EQUINOX13 小时前
2026年码蹄杯 本科院校赛道&青少年挑战赛道提高组初赛(省赛)第一场,个人题解
算法
萝卜小白4 小时前
算法实习Day04-MinerU2.5-pro
人工智能·算法·机器学习
Liangwei Lin4 小时前
洛谷 P3133 [USACO16JAN] Radio Contact G
数据结构·算法