spark基于HNSW向量检索

参考文档:https://talks.anghami.com/blazing-fast-approximate-nearest-neighbour-search-on-apache-spark-using-hnsw/

HNSW参数调优文档:https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md

spark 运行HNSW向量检索分为以下三步

1 创建HNSW索引,并存储到磁盘

2 将存储的索引分发到每个executor

3 进行向量检索

使用HNSW构建索引,并使用spark进行分布式向量检索:1200万向量构建索引40分钟,向量检索10分钟完成(耗时取决于m和ef的大小,本人使用m=30,ef=1000,否则总是报错提示m或者ef太小)。如m=30,ef=1000时,1200万向量构建索引20分钟,向量检索还是10分钟。

1 创建HNSW索引

输入为spark dataset格式数据,由id和features组成,features为Array[Float]形式的向量

import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag
class annUtilsHnsw {


  /**
   * Builds an HNSW index (cosine space) over the given feature vectors.
   *
   * The HNSW index requires integer based object ids, so the builder re-indexes the original
   * object keys into integer keys (`index_id`) and keeps the `index_id -> id` mapping as a
   * Dataset that is handed to the returned [[HnswIndex]].
   *
   * All feature vectors are collected to the driver and the index is built there, so the whole
   * data set has to fit into driver memory.
   *
   * For information on HNSW parameter tuning, [[https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md]]
   *
   * @param spark          active Spark session; only its implicits are used here
   * @param vectorSize     features vector size
   * @param features       objects features to build an index for, as (id, vector) pairs
   * @param m              a parameter for construction HNSW index
   * @param efConstruction a parameter for construction HNSW index
   * @tparam Key type of the object id in features objects
   * @return the built index wrapped together with the `index_id -> id` mapping
   * @throws ArithmeticException (inside the Spark job) if a row position does not fit into an Int
   */
  def buildHnswIndex[Key : TypeTag : Encoder](spark:SparkSession,vectorSize: Int,
                                              features: Dataset[(Key, Array[Float])],
                                              m: Int = 100,
                                              efConstruction: Int = 200): HnswIndex[Key] = {

    import spark.implicits._

    // Map object keys to an integer based index, as the HNSW index only accepts integer keys.
    // zipWithIndex yields Long positions; Math.toIntExact fails loudly instead of silently
    // wrapping around (which would produce duplicate or negative ids) when a position exceeds
    // Int.MaxValue.
    val featuresReindexed = features.rdd
      .zipWithIndex()
      .map { case ((key, vector), position) => (key, vector, Math.toIntExact(position)) }
      .toDF("id", "features", "index_id")
      .select("index_id", "id", "features")
      // Cached because it is read twice: once for the vectors and (lazily, later) for the id map.
      // NOTE(review): the id mapping is only consistent with the built index as long as
      // zipWithIndex assigns the same positions on any recomputation -- confirm for
      // non-deterministic inputs.
      .cache()

    // collect feature vectors to the driver
    val featuresList = featuresReindexed
      .select($"index_id", $"features".cast("array<float>"))
      .as[(Int, Array[Float])]
      .collect()

    // index_id -> original id mapping; stays distributed and is joined back at query time
    val objectIDsMap = featuresReindexed
      .select("index_id", "id")
      .as[(Int, Key)]
      .repartition(100)

    // build index; the current time in seconds is used as the random seed
    val index = new Index(SpaceName.COSINE, vectorSize)
    index.initialize(featuresList.length, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
    // explicit tuple keeps the printed output identical to the former auto-tupled call
    println(("featuresList length", featuresList.length))

    // add vectors in parallel using .par
    // NOTE(review): assumes Index.addItem may be called concurrently -- confirm for the
    // hnswlib-jna version in use.
    featuresList.par.foreach {
      case (id: Int, vector: Array[Float]) =>
        index.addItem(vector, id)
    }

    // return wrapped index
    new HnswIndex(vectorSize, index, objectIDsMap)
  }


}

2 索引存储及查找

存储索引,加载索引并分发到每个executor.然后进行ANN查找

import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag

/**
 * Wraps a built HNSW index together with the mapping from the integer ids used inside the
 * index back to the original object keys.
 *
 * @param vectorSize   dimension of the indexed vectors
 * @param index        the HNSW index holding the vectors (lives on the driver)
 * @param objectIDsMap (index_id, original key) pairs for every indexed item
 * @tparam DstKey type of the original keys of the indexed items
 */
class HnswIndex[DstKey : TypeTag : Encoder](vectorSize: Int,
                                            index: Index,
                                            objectIDsMap: Dataset[(Int, DstKey)]) {


  /**
   * Executes a KNN query using an HNSW index.
   *
   * The index is saved to a local file named "index", registered with the Spark context via
   * `addFile`, and re-loaded inside every query partition. Scores are computed as
   * `1.0 - coefficient` from the coefficients returned by the index.
   *
   * @param spark              active Spark session
   * @param queryFeatures      features to generates recs for, as (id, vector) pairs
   * @param minScoreThreshold  Minimum similarity/distance; results scoring lower are dropped.
   * @param topK               number of top recommendations to generate per instance
   * @param ef                 HNSW search time parameter
   * @param queryNumPartitions number of partitions for query vectors
   * @param indexSavePath      currently unused by this method; kept for interface compatibility
   * @param m                  HNSW parameter passed to `initialize` when re-creating the index
   *                           in each partition
   * @param efConstruction     HNSW parameter passed to `initialize` when re-creating the index
   *                           in each partition
   * @tparam SrcKey type of the query object ids
   * @return (query id, matched object id, score) triples, at most `topK` per query
   */
  def knnQuery[SrcKey: TypeTag : Encoder](spark: SparkSession, queryFeatures: Dataset[(SrcKey, Array[Float])],
                                          minScoreThreshold: Double,
                                          topK: Int,
                                          ef: Int,
                                          queryNumPartitions: Int = 200, indexSavePath: String, m: Int, efConstruction: Int): Dataset[(SrcKey, DstKey, Double)] = {


    import spark.implicits._

    val indexLength = index.getLength

    // The index is written to a relative path, i.e. into the driver's working directory.
    val saveLocalPath = "index"
    val indexLocalLocation = Paths.get(saveLocalPath)
    val indexFileName = indexLocalLocation.getFileName.toString
    println(("indexFileName", indexFileName))
    // saving index locally
    index.save(indexLocalLocation)

    // Debug output of the first stored vector. getData returns an Optional, so guard it rather
    // than calling get() blindly: an index without an item 0 must not abort the query.
    val firstVector = index.getData(0)
    if (firstVector.isPresent) {
      println(firstVector.get().mkString(","))
    }
    println(("local path", indexLocalLocation.toAbsolutePath.toString))
    println(("absolute path: ", saveLocalPath))
    // add file to spark context to be sent to running nodes
    spark.sparkContext.addFile(indexFileName, true)

    println(("context path: ", SparkFiles.getRootDirectory + "/" + indexFileName))

    // The current interface to HNSW misses the functionality of setting the ef query time
    // parameter, but it's lower bounded by topK as per https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md#search-parameters,
    // so set a large value of k as max(ef, topK) to get high recall, then cut off after getting the nearest neighbor.
    val k = math.max(topK, ef)

    // Local copy so the closure below does not capture `this` (and with it the driver-side index).
    val vectorSizeLocal = vectorSize

    // execute querying
    queryFeatures
      .repartition(queryNumPartitions)
      .toDF("id", "features")
      .withColumn("features", $"features".cast("array<float>"))
      .as[(SrcKey, Array[Float])]
      .mapPartitions((it: Iterator[(SrcKey, Array[Float])]) => {
        // Re-create the index once per partition from the file distributed via addFile.
        val executorIndex = new Index(SpaceName.COSINE, vectorSizeLocal)
        executorIndex.initialize(indexLength, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
        executorIndex.load(Paths.get(SparkFiles.getRootDirectory + "/" + indexFileName), indexLength)

        it.flatMap { case (srcId, vector) =>
          val queryTuple = executorIndex.knnQuery(vector, k)
          queryTuple.getIds
            .zip(queryTuple.getCoefficients)
            .map { case (indexId, coefficient) => (srcId, indexId, 1.0 - coefficient.toDouble) }
            .filter(_._3 >= minScoreThreshold)
            // highest score first, then keep the best topK
            .sortBy(_._3)
            .reverse
            .slice(0, topK)
        }
      })
      // The result is already typed (SrcKey, Int, Double). It is deliberately NOT re-cast to
      // (Int, Int, Double): that forced src_id to Int and broke queries for any other SrcKey.
      .toDF("src_id", "index_id", "score")
      // translate the integer index ids back to the original object keys
      .join(objectIDsMap.toDF("index_id", "dst_id"), Seq("index_id"))
      .select("src_id", "dst_id", "score")
      .repartition(400)
      .as[(SrcKey, DstKey, Double)]

  }
}

3 word2vec向量检索实例

  • 训练word2vec模型

  • 将模型的向量取出,调用上面buildHnswIndex 构建索引

  • 分布式进行knnQuery 向量检索

    import org.apache.spark.ml.feature.Word2VecModel
    import org.apache.spark.ml.linalg.DenseVector

    /**
     * Example: loads a trained Word2Vec model, builds an HNSW index over its vectors and runs a
     * distributed KNN query for the first 20 items.
     */
    object exampleWord2Vec {

      /**
       * Entry point of the example.
       *
       * @param args command line arguments (not used)
       */
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().getOrCreate()

        val GraphInputModel = "graph/model/word2vecmodel"
        val indexPath = "graph/model/index"

        // HNSW construction parameters. They are shared between building the index and querying
        // it, so that the index re-created on the executors is initialised with the same values.
        val m = 20
        val efConstruction = 200

        spark.udf.register("denseVec2Array", (vec: DenseVector) => vec.toArray.map(_.toFloat))
        spark.udf.register("vectorSplit", (a: String) => (a.split(',').map(_.toFloat)))
        import spark.implicits._

        val word2vec = Word2VecModel.load(GraphInputModel)
        val vectors = word2vec.getVectors
        println(vectors.schema)
        vectors.show(10)
        println(vectors.count())

        // NOTE(review): words that are not numeric become null under `cast(word as Int)` --
        // confirm the vocabulary consists of integer ids only.
        val itemEmbeddings = vectors
          .selectExpr("cast(word as Int) as word", "denseVec2Array(vector) features")
          .as[(Int, Array[Float])]
        itemEmbeddings.show()
        println(itemEmbeddings.schema)

        // Derive the vector dimension from the first row; fail with a clear message instead of
        // an ArrayIndexOutOfBoundsException when the model holds no vectors.
        val vectorsize = itemEmbeddings.take(1).headOption
          .map(_._2.length)
          .getOrElse(throw new IllegalStateException(
            s"Word2Vec model at $GraphInputModel contains no vectors"))

        val hnswIndex = new annUtilsHnsw()
          .buildHnswIndex(spark, vectorsize, itemEmbeddings, m = m, efConstruction = efConstruction)

        val queryDF = hnswIndex.knnQuery[Int](
          spark,
          itemEmbeddings.limit(20),
          minScoreThreshold = 0.3,
          topK = 20,
          ef = 200,
          queryNumPartitions = 160,
          indexSavePath = indexPath,
          m = m,
          efConstruction = efConstruction)

        // NOTE(review): savePathMl is not defined anywhere in this snippet; it has to be
        // supplied by the surrounding project.
        queryDF.write.mode("overwrite").save(savePathMl + "graph/muiscEmbedding")
      }

    }

4 HNSW pom依赖文件

hnswlib-jna

        <dependency>
            <groupId>com.stepstone.search.hnswlib.jna</groupId>
            <artifactId>hnswlib-jna</artifactId>
            <version>1.4.2</version>
        </dependency>
相关推荐
云云32135 分钟前
怎么通过亚矩阵云手机实现营销?
大数据·服务器·安全·智能手机·矩阵
新加坡内哥谈技术1 小时前
苏黎世联邦理工学院与加州大学伯克利分校推出MaxInfoRL:平衡内在与外在探索的全新强化学习框架
大数据·人工智能·语言模型
Data-Miner1 小时前
经典案例PPT | 大型水果连锁集团新零售数字化建设方案
大数据·big data
lovelin+v175030409661 小时前
安全性升级:API接口在零信任架构下的安全防护策略
大数据·数据库·人工智能·爬虫·数据分析
道一云黑板报2 小时前
Flink集群批作业实践:七析BI批作业执行
大数据·分布式·数据分析·flink·kubernetes
节点。csn2 小时前
flink集群搭建 详细教程
大数据·服务器·flink
数据爬坡ing3 小时前
小白考研历程:跌跌撞撞,起起伏伏,五个月备战历程!!!
大数据·笔记·考研·数据分析
云云3213 小时前
云手机方案全解析
大数据·服务器·安全·智能手机·矩阵
武子康4 小时前
大数据-257 离线数仓 - 数据质量监控 监控方法 Griffin架构
java·大数据·数据仓库·hive·hadoop·后端
碳学长4 小时前
2025系统架构师(一考就过):案例题之一:嵌入式架构、大数据架构、ISA
大数据·架构·系统架构