HNSW参数调优文档:https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md
spark 运行HNSW向量检索分为以下三步
1 创建HNSW索引,并存储到磁盘
2 将存储的索引分发到每个executor
3 进行向量检索
使用HNSW构建索引,并使用Spark进行分布式向量检索:1200万向量构建索引耗时40分钟,向量检索10分钟完成(耗时取决于m和ef的大小;本人使用m=30,ef=1000,参数更小时总是报错"m或者ef太小")。如m=30,ef=1000时,1200万向量构建索引20分钟,向量检索仍为10分钟。
1 创建HNSW索引
输入为Spark Dataset格式数据,由id和features组成,features为Array[Float]形式的向量
import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag
class annUtilsHnsw {
  /**
   * Builds an HNSW index over the given feature vectors.
   *
   * Default HNSW parameters are found to be good enough.
   *
   * HNSW index requires integer based object ids, so the builder re-indexes the original object
   * keys into dense integer keys (`index_id`, assigned with `zipWithIndex`). The
   * `index_id -> original key` mapping is kept in the returned [[HnswIndex]] so query results can
   * be translated back to the original keys.
   *
   * All vectors are collected to the driver and inserted into a single in-memory index there, so
   * the driver needs enough memory for the full feature matrix plus the index itself.
   *
   * If the dataset has more rows than fit in an `Int`, the Spark job fails (the task error is an
   * `ArithmeticException`) instead of silently producing colliding ids.
   *
   * For information on HNSW parameter tuning, [[https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md]]
   *
   * @param spark          active Spark session, used for its implicit encoders
   * @param vectorSize     features vector size
   * @param features       objects features to build an index for
   * @param m              HNSW `M` construction parameter (see ALGO_PARAMS.md)
   * @param efConstruction HNSW `ef_construction` construction parameter (see ALGO_PARAMS.md)
   * @tparam Key type of the object id in features objects
   * @return an [[HnswIndex]] wrapping the built COSINE-space index and the id mapping
   */
  def buildHnswIndex[Key : TypeTag : Encoder](spark:SparkSession,vectorSize: Int,
                                               features: Dataset[(Key, Array[Float])],
                                               m: Int = 100,
                                               efConstruction: Int = 200): HnswIndex[Key] = {
    import spark.implicits._

    // Map object keys to dense integer ids, as the HNSW index only accepts Int labels.
    // Math.toIntExact fails fast rather than wrapping ids above Int.MaxValue into duplicates.
    val featuresReindexed = features
      .rdd
      .zipWithIndex()
      .map { case ((key, vector), rowIdx) => (key, vector, Math.toIntExact(rowIdx)) }
      .toDF("id", "features", "index_id")
      .select("index_id", "id", "features")
      // cached: read by the collect below and, lazily, through objectIDsMap at query time
      .cache()

    // Collect all feature vectors to the driver; the index is built in this single JVM.
    val featuresList = featuresReindexed
      .select($"index_id", $"features".cast("array<float>"))
      .as[(Int, Array[Float])]
      .collect()

    // index_id -> original key mapping, joined back onto query results by HnswIndex.knnQuery.
    val objectIDsMap = featuresReindexed
      .select("index_id", "id")
      .as[(Int, Key)]
      .repartition(100)

    // Capacity equals the number of vectors; the random seed is the current time in seconds.
    val index = new Index(SpaceName.COSINE, vectorSize)
    index.initialize(featuresList.length, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
    println(s"featuresList length: ${featuresList.length}")

    // Insert vectors in parallel.
    // NOTE(review): this relies on Index.addItem being safe to call concurrently — confirm for the
    // hnswlib-jna version in use.
    featuresList.par.foreach { case (id, vector) => index.addItem(vector, id) }

    new HnswIndex(vectorSize, index, objectIDsMap)
  }
}
2 索引存储及查找
在driver上将索引存储到本地文件,通过addFile分发到每个executor,每个executor加载索引后进行ANN查找
import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag
class HnswIndex[DstKey : TypeTag : Encoder](vectorSize: Int,
                                            index: Index,
                                            objectIDsMap: Dataset[(Int, DstKey)]) {
  /**
   * Executes a distributed KNN query using this HNSW index.
   *
   * The index is saved on the driver to the local path `index`, shipped to every executor with
   * `SparkContext.addFile`, and loaded once per query partition. Each query vector is searched
   * independently and the resulting integer ids are joined back to the original destination keys.
   *
   * @param spark              active Spark session
   * @param queryFeatures      (key, vector) pairs to generate recommendations for
   * @param minScoreThreshold  minimum score kept, where score = 1 - distance returned by the
   *                           COSINE-space index
   * @param topK               maximum number of results kept per query row
   * @param ef                 HNSW search-time parameter, emulated by requesting
   *                           max(topK, ef) neighbours (capped at the index size)
   * @param queryNumPartitions number of partitions for query vectors
   * @param indexSavePath      currently unused; the index is always written to the local path `index`
   * @param m                  HNSW `M` passed to `initialize` on each executor before loading
   * @param efConstruction     HNSW `ef_construction` passed to `initialize` on each executor before loading
   * @return (src key, dst key, score) rows with score >= minScoreThreshold
   * @throws IllegalArgumentException if the index is empty
   */
  def knnQuery[SrcKey: TypeTag : Encoder](spark: SparkSession, queryFeatures: Dataset[(SrcKey, Array[Float])],
                                          minScoreThreshold: Double,
                                          topK: Int,
                                          ef: Int,
                                          queryNumPartitions: Int = 200, indexSavePath: String, m: Int, efConstruction: Int): Dataset[(SrcKey, DstKey, Double)] = {
    import spark.implicits._

    val indexLength = index.getLength
    require(indexLength > 0, "cannot query an empty HNSW index")

    // Save the index on the driver and register it so every executor receives a copy.
    val indexLocalLocation = Paths.get("index")
    val indexFileName = indexLocalLocation.getFileName.toString
    index.save(indexLocalLocation)
    println(s"index saved to: ${indexLocalLocation.toAbsolutePath}")
    spark.sparkContext.addFile(indexFileName, recursive = true)

    // The ef query-time parameter is not set on the index here, but hnswlib lower-bounds the
    // effective ef by k (https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md#search-parameters),
    // so request max(topK, ef) neighbours for recall and trim to topK afterwards. k is capped at
    // the index size because hnswlib cannot return more neighbours than there are elements.
    val k = math.min(math.max(topK, ef), indexLength)
    // Copy the field into a local so the executor closure does not capture `this`.
    val vectorSizeLocal = vectorSize

    queryFeatures
      .repartition(queryNumPartitions)
      .toDF("id", "features")
      .withColumn("features", $"features".cast("array<float>"))
      .as[(SrcKey, Array[Float])]
      .mapPartitions((rows: Iterator[(SrcKey, Array[Float])]) => {
        // Load the shipped index once per partition.
        // NOTE(review): initialize before load may be redundant if load allocates its own index,
        // and the native index is never cleared here — confirm against hnswlib-jna.
        val partitionIndex = new Index(SpaceName.COSINE, vectorSizeLocal)
        partitionIndex.initialize(indexLength, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
        partitionIndex.load(Paths.get(SparkFiles.get(indexFileName)), indexLength)
        rows.flatMap { case (srcKey, vector) =>
          val neighbours = partitionIndex.knnQuery(vector, k)
          neighbours.getIds
            .zip(neighbours.getCoefficients)
            .map { case (indexId, distance) => (srcKey, indexId, 1.0 - distance.toDouble) }
            .filter(_._3 >= minScoreThreshold)
            .sortBy(_._3)
            .reverse
            .take(topK)
        }
      })
      // src_id keeps SrcKey's own type (no cast to Int), so non-Int query keys also work.
      .toDF("src_id", "index_id", "score")
      .join(objectIDsMap.toDF("index_id", "dst_id"), Seq("index_id"))
      .select("src_id", "dst_id", "score")
      .repartition(400)
      .as[(SrcKey, DstKey, Double)]
  }
}
3 word2vec向量检索实例
- 训练word2vec模型
- 将模型的向量取出,调用上面的buildHnswIndex构建索引
- 分布式地调用knnQuery进行向量检索
import org.apache.spark.ml.feature.Word2VecModel
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.SparkSession

/**
 * Example job: loads a trained Word2Vec model, builds an HNSW index over its word vectors, runs a
 * distributed KNN query for 20 of those vectors and writes the (src, dst, score) results.
 */
object exampleWord2Vec {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    // Output path prefix. NOTE(review): `savePathMl` was referenced but never defined in the
    // original snippet; here it comes from the first program argument (default "") — confirm.
    val savePathMl = args.headOption.getOrElse("")
    val graphInputModel = "graph/model/word2vecmodel"
    val indexPath = "graph/model/index"
    try {
      spark.udf.register("denseVec2Array", (vec: DenseVector) => vec.toArray.map(_.toFloat))
      spark.udf.register("vectorSplit", (a: String) => a.split(',').map(_.toFloat))
      import spark.implicits._

      val word2vec = Word2VecModel.load(graphInputModel)
      println(word2vec.getVectors.schema)
      word2vec.getVectors.show(10)
      println(word2vec.getVectors.count())

      // Words are expected to be integer ids; tokens that do not cast to Int become null and are
      // dropped, because a null cannot be decoded into the Int field of the tuple.
      val itemEmbeddings = word2vec.getVectors
        .selectExpr("cast(word as Int) as word", "denseVec2Array(vector) features")
        .where("word is not null")
        .as[(Int, Array[Float])]
      itemEmbeddings.show()
      println(itemEmbeddings.schema)

      // Vector length comes from the model parameter instead of running a job on the first row.
      val vectorSize = word2vec.getVectorSize
      val hnswIndex = new annUtilsHnsw().buildHnswIndex(spark, vectorSize, itemEmbeddings, 20)
      val queryDF = hnswIndex.knnQuery[Int](
        spark,
        itemEmbeddings.limit(20),
        minScoreThreshold = 0.3,
        topK = 20,
        ef = 200,
        queryNumPartitions = 160,
        indexSavePath = indexPath,
        m = 20,
        efConstruction = 200)
      queryDF.write.mode("overwrite").save(savePathMl + "graph/muiscEmbedding")
    } finally {
      spark.stop()
    }
  }
}
4 HNSW pom依赖文件
hnswlib-jna
<dependency>
<groupId>com.stepstone.search.hnswlib.jna</groupId>
<artifactId>hnswlib-jna</artifactId>
<version>1.4.2</version>
</dependency>