Milvus Source Code Analysis: Vector Search

This article walks through the source code of the search flow in Milvus v2.5.15.

Tips

  1. The proxy forwards requests and aggregates results.
  2. Within the QueryNode layer, a search may involve up to two RPC hops.
  3. A QueryNode maps to many segments; each QueryNode holds only a subset of the segment data.
  4. An L0 segment is a special segment dedicated to managing delete operations: it contains only deletions, no data.

Data Flow

Mind Map

Source Code Walkthrough

  1. Milvus components communicate with each other over gRPC.
  2. milvus.proto is the single entry point defining the services and messages; from it, gRPC generates an implementation stub.
  3. milvus_grpc.pb.go is the functional entry point of the Milvus server.

Proxy Module

The call stack is as follows:

Proxy#search

After passing through a chain of gRPC interceptors, the request lands in internal.distributed.proxy.impl.go#search, which is effectively the real entry point of the search method.

The core step here is to build a searchTask, submit it to the scheduler for a background goroutine to execute, and then wait for the result.

func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool, isRecallEvaluation bool) (*milvuspb.SearchResults, bool, bool, bool, error) {
    ...
    qt := &searchTask{
        ctx:       ctx,
        Condition: NewTaskCondition(ctx),
        SearchRequest: &internalpb.SearchRequest{
           Base: commonpbutil.NewMsgBase(
              commonpbutil.WithMsgType(commonpb.MsgType_Search),
              commonpbutil.WithSourceID(paramtable.GetNodeID()),
           ),
           ReqID:              paramtable.GetNodeID(),
           IsTopkReduce:       optimizedSearch,
           IsRecallEvaluation: isRecallEvaluation,
        },
        request:                request,
        tr:                     timerecord.NewTimeRecorder("search"),
        qc:                     node.queryCoord,
        node:                   node,
        lb:                     node.lbPolicy,
        enableMaterializedView: node.enableMaterializedView,
        mustUsePartitionKey:    Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
    }
    ...
    // submit this task to taskScheduler (task_scheduler.go)
    if err := node.sched.dqQueue.Enqueue(qt); err != nil {
    ...
}

This step enqueues the searchTask and returns, holding on to a reference to the task. So where does the searchTask actually run?

Answer: in the taskScheduler that is started when the service boots.

taskScheduler#queryLoop

Main responsibilities: create a pool, then loop, watching dqQueue for pending tasks; when one arrives, submit it to the pool for execution; when the done signal fires, exit the loop.

func (sched *taskScheduler) queryLoop() {
    defer sched.wg.Done()

    poolSize := paramtable.Get().ProxyCfg.MaxTaskNum.GetAsInt()
    // create the worker pools
    pool := conc.NewPool[struct{}](poolSize, conc.WithExpiryDuration(time.Minute))
    subTaskPool := conc.NewPool[struct{}](poolSize, conc.WithExpiryDuration(time.Minute))
    defer pool.Release()
    defer subTaskPool.Release()
    // the dispatch loop: watch the queue and hand tasks to the pool
    for {
       select {
       case <-sched.ctx.Done():
          return
       case <-sched.dqQueue.utChan():
          // if dqQueue is not empty, pop a task and submit it to the pool
          if !sched.dqQueue.utEmpty() {
             t := sched.scheduleDqTask()
             p := pool
             // if task is sub task spawned by another, use sub task pool in case of deadlock
             if t.IsSubTask() {
                p = subTaskPool
             }
             // submit to the pool and start executing the searchTask
             p.Submit(func() (struct{}, error) {
             // processTask does the real work; the caller observes completion through the task reference
                sched.processTask(t, sched.dqQueue)
                return struct{}{}, nil
             })
          } else {
             log.Ctx(context.TODO()).Debug("query queue is empty ...")
          }
          sched.dqQueue.updateMetrics()
       }
    }
}

taskScheduler#processTask

Execution starts in this method, in three phases:

  1. PreExecute: parameter validation and variable resolution (e.g. partitionName -> partitionID)
  2. Execute: dispatch the request to the shards
  3. PostExecute: decode and aggregate the results

In essence, three calls: pre, exec, post.
func (sched *taskScheduler) processTask(t task, q taskQueue) {
    .....
    // mainly parameter validation, field assignment (consistencyLevel, guaranteeTs, etc.), and deciding whether to push the search down (prefer scalar filtering)
    err := t.PreExecute(ctx)
    
    ......
    // the real execution (detailed below):
    // after load balancing, the request is dispatched to the shards
    err = t.Execute(ctx)
    ......
    // decode, collect, and aggregate the results
    err = t.PostExecute(ctx)
    .....
}

searchTask.Execute goes through loadbalancer.Execute for load balancing and ultimately calls searchTask.searchShard:

func (t *searchTask) Execute(ctx context.Context) error {
    ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-Execute")
    defer sp.End()
    log := log.Ctx(ctx).WithLazy(zap.Int64("nq", t.SearchRequest.GetNq()))

    tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
    defer tr.CtxElapse(ctx, "done")

    err := t.lb.Execute(ctx, CollectionWorkLoad{
       db:             t.request.GetDbName(),
       collectionID:   t.SearchRequest.CollectionID,
       collectionName: t.collectionName,
       nq:             t.Nq,
       // the callback function is defined here
       exec:           t.searchShard,
    })
    if err != nil {
       log.Warn("search execute failed", zap.Error(err))
       return errors.Wrap(err, "failed to search")
    }

    log.Debug("Search Execute done.",
       zap.Int64("collection", t.GetCollectionID()),
       zap.Int64s("partitionIDs", t.GetPartitionIDs()))
    return nil
}

LBPolicyImpl#Execute

  1. Get the leader of each shard.
  2. Run the searchShard function passed in (the real work).
  3. Put the search results into the context (decoding happens in PostExecute).
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
    // get the leader of each shard
    dml2leaders, err := lb.GetShardLeaders(ctx, workload.db, workload.collectionName, workload.collectionID, true)
    if err != nil {
       log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
       return err
    }

    // let every request could retry at least twice, which could retry after update shard leader cache
    wg, ctx := errgroup.WithContext(ctx)
    // iterate over all shards
    for k, v := range dml2leaders {
       channel := k
       nodes := v
       channelRetryTimes := lb.retryOnReplica
       if len(nodes) > 0 {
          channelRetryTimes *= len(nodes)
       }
       wg.Go(func() error {
          // execute with retry
          return lb.ExecuteWithRetry(ctx, ChannelWorkload{
             db:             workload.db,
             collectionName: workload.collectionName,
             collectionID:   workload.collectionID,
             channel:        channel,
             shardLeaders:   nodes,
             nq:             workload.nq,
             exec:           workload.exec,
             retryTimes:     uint(channelRetryTimes),
          })
       })
    }

    return wg.Wait()
}

searchTask#searchShard

Here the RPC request to the QueryNode is finally issued:

func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
    ......
    // qn here is the qnServerWrapper from qn_wrapper.go
    result, err = qn.Search(ctx, req)
     ......
}

QueryNode

Functional flow of the QueryNode module

After the RPC call, the request first goes through querynode/service.go and then flows into querynodev2/services.go, where the actual logic lives.

QueryNode#Search

internal.querynodev2.services.go. One thing to note: DmlChannels is obtained from QueryCoord and passed down step by step; it is a plain string (e.g. by-dev-rootcoord-dml_0_xxxv0).

func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
   ......
    collection := node.manager.Collection.Get(req.GetReq().GetCollectionID())
   ......
    // run the search on the channel and get the results
    ret, err := node.searchChannel(ctx, channelReq, ch)
    if err != nil {
       resp.Status = merr.Status(err)
       return resp, nil
    }
......
    return ret, nil
}

QueryNode#searchChannel

  1. Fetch the data via the shardDelegator bound to this DmlChannel.
  2. A DmlChannel and a shardDelegator map one to one.
  3. Run the search.
  4. (When one DmlChannel covers multiple segments) decode the per-segment protobuf results, deduplicate, sort, keep the top-K, and re-encode to protobuf.
func (node *QueryNode) searchChannel(ctx context.Context, req *querypb.SearchRequest, channel string) (*internalpb.SearchResults, error) {
    ......
    // get delegator
    sd, ok := node.delegators.Get(channel)
  
    ......
    // do search
    results, err := sd.Search(searchCtx, req)
   
   ......
   // reduce result
    resp, err := segments.ReduceSearchOnQueryNode(ctx, results,
       reduce.NewReduceSearchResultInfo(req.GetReq().GetNq(),
         
   ......
   
    return resp, nil
}
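The reduce step described above (deduplicate, sort, top-K) can be sketched with a toy hit type standing in for the protobuf-encoded results. The field names and the higher-is-better score convention (e.g. an IP metric) are assumptions for illustration:

```go
package main

import (
	"fmt"
	"sort"
)

// hit is a toy stand-in for one entry of a per-segment search result.
type hit struct {
	pk    int64
	score float32 // assume higher is better (e.g. IP metric)
}

// reduceTopK merges per-segment hits, deduplicates by primary key
// (keeping the better score), sorts by score, and keeps the top-K.
func reduceTopK(perSegment [][]hit, k int) []hit {
	best := map[int64]float32{}
	for _, seg := range perSegment {
		for _, h := range seg {
			if s, ok := best[h.pk]; !ok || h.score > s {
				best[h.pk] = h.score // dedup: keep the best score per pk
			}
		}
	}
	merged := make([]hit, 0, len(best))
	for pk, s := range best {
		merged = append(merged, hit{pk, s})
	}
	sort.Slice(merged, func(i, j int) bool { return merged[i].score > merged[j].score })
	if len(merged) > k {
		merged = merged[:k]
	}
	return merged
}

func main() {
	res := reduceTopK([][]hit{
		{{pk: 1, score: 0.9}, {pk: 2, score: 0.5}},
		{{pk: 1, score: 0.7}, {pk: 3, score: 0.8}}, // pk 1 is a duplicate
	}, 2)
	fmt.Println(res) // top-2 after dedup: pk 1 (0.9) and pk 3 (0.8)
}
```

The real ReduceSearchOnQueryNode works per query (nq queries at once) and on serialized blobs, but the merge/dedup/top-K core is the same idea.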

shardDelegator#search

  1. Intelligent segment pruning.
  2. BM25 field search.
  3. Split the search into sub-tasks and assign each to the QueryNode that holds the corresponding segments.
  4. Issue the actual RPC requests and collect the results.
  5. Aggregate the results (deduplicate, sort).
func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest, sealed []SnapshotItem, growing []SegmentEntry) ([]*internalpb.SearchResults, error) {
   ......
    searchAgainstBM25Field := sd.isBM25Field[req.GetReq().GetFieldId()]
    
   ......
   
    req, err := optimizers.OptimizeSearchParams(ctx, req, sd.queryHook, sealedNum)
    if err != nil {
       log.Warn("failed to optimize search params", zap.Error(err))
       return nil, err
    }
    
    
    tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest)
    if err != nil {
       log.Warn("Search organizeSubTask failed", zap.Error(err))
       return nil, err
    }
    
    // execute the sub-tasks
    results, err := executeSubTasks(ctx, tasks, func(ctx context.Context, req *querypb.SearchRequest, worker cluster.Worker) (*internalpb.SearchResults, error) {
       return worker.SearchSegments(ctx, req)
    }, "Search", log)
    if err != nil {
       log.Warn("Delegator search failed", zap.Error(err))
       return nil, err
    }

    log.Debug("Delegator search done")

    return results, nil
}
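Conceptually, organizeSubTask groups the target segments by the worker (QueryNode) that hosts them, so each worker receives a single SearchSegments sub-request. A minimal sketch, with an invented SegmentEntry shape:

```go
package main

import (
	"fmt"
	"sort"
)

// SegmentEntry is a made-up shape for illustration: which node hosts which segment.
type SegmentEntry struct {
	NodeID    int64
	SegmentID int64
}

// groupByNode builds one sub-task (segment ID list) per hosting node.
func groupByNode(entries []SegmentEntry) map[int64][]int64 {
	tasks := map[int64][]int64{}
	for _, e := range entries {
		tasks[e.NodeID] = append(tasks[e.NodeID], e.SegmentID)
	}
	for _, segs := range tasks {
		sort.Slice(segs, func(i, j int) bool { return segs[i] < segs[j] })
	}
	return tasks
}

func main() {
	tasks := groupByNode([]SegmentEntry{
		{NodeID: 1, SegmentID: 100},
		{NodeID: 2, SegmentID: 200},
		{NodeID: 1, SegmentID: 101},
	})
	fmt.Println(tasks[1], tasks[2]) // node 1 searches [100 101], node 2 searches [200]
}
```

Batching per node keeps the RPC count proportional to the number of workers rather than the number of segments.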

QueryNode#SearchSegments

Its main job is to put the search task onto a queue for asynchronous execution.

func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
   channel := req.GetDmlChannels()[0]
   
   ......
   
   if err := node.lifetime.Add(merr.IsHealthy); err != nil {
      resp.Status = merr.Status(err)
      return resp, nil
   }
   defer node.lifetime.Done()

  ......
  
   log.Debug("start to search segments on worker",
      zap.Int64s("segmentIDs", req.GetSegmentIDs()),
   )
   searchCtx, cancel := context.WithCancel(ctx)
   defer cancel()

   tr := timerecord.NewTimeRecorder("searchSegments")
   log.Debug("search segments...")

   if !node.manager.Collection.Ref(req.Req.GetCollectionID(), 1) {
      err := merr.WrapErrCollectionNotLoaded(req.GetReq().GetCollectionID())
      log.Warn("failed to search segments", zap.Error(err))
      resp.Status = merr.Status(err)
      return resp, nil
   }
   collection := node.manager.Collection.Get(req.Req.GetCollectionID())
   defer func() {
      node.manager.Collection.Unref(req.GetReq().GetCollectionID(), 1)
   }()

   var task scheduler.Task
   if paramtable.Get().QueryNodeCfg.UseStreamComputing.GetAsBool() {
      task = tasks.NewStreamingSearchTask(searchCtx, collection, node.manager, req, node.serverID)
   } else {
      task = tasks.NewSearchTask(searchCtx, collection, node.manager, req, node.serverID)
   }

   if err := node.scheduler.Add(task); err != nil {
      log.Warn("failed to search channel", zap.Error(err))
      resp.Status = merr.Status(err)
      return resp, nil
   }
......
}

querynodev2/SearchTask#Execute

  1. Run the search on the specified segments.
  2. Depending on the data scope, search either streaming or historical segments.
  3. Finally, reduce the results once more.
func (t *SearchTask) Execute() error {
    log := log.Ctx(t.ctx).With(
       zap.Int64("collectionID", t.collection.ID()),
       zap.String("shard", t.req.GetDmlChannels()[0]),
    )

    if t.scheduleSpan != nil {
       t.scheduleSpan.End()
    }
    tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "SearchTask")

    req := t.req
    err := t.combinePlaceHolderGroups()
    if err != nil {
       return err
    }
    searchReq, err := segcore.NewSearchRequest(t.collection.GetCCollection(), req, t.placeholderGroup)
    if err != nil {
       return err
    }
    defer searchReq.Delete()

    var (
       results          []*segments.SearchResult
       searchedSegments []segments.Segment
    )
    if req.GetScope() == querypb.DataScope_Historical {
       results, searchedSegments, err = segments.SearchHistorical(
          t.ctx,
          t.segmentManager,
          searchReq,
          req.GetReq().GetCollectionID(),
          req.GetReq().GetPartitionIDs(),
          req.GetSegmentIDs(),
       )
    } else if req.GetScope() == querypb.DataScope_Streaming {
       results, searchedSegments, err = segments.SearchStreaming(
          t.ctx,
          t.segmentManager,
          searchReq,
          req.GetReq().GetCollectionID(),
          req.GetReq().GetPartitionIDs(),
          req.GetSegmentIDs(),
       )
    }
    defer t.segmentManager.Segment.Unpin(searchedSegments)
    if err != nil {
       return err
    }
    defer segments.DeleteSearchResults(results)

    // plan.MetricType is accurate, though req.MetricType may be empty
    metricType := searchReq.Plan().GetMetricType()

    if len(results) == 0 {
       for i := range t.originNqs {
          var task *SearchTask
          if i == 0 {
             task = t
          } else {
             task = t.others[i-1]
          }

          task.result = &internalpb.SearchResults{
             Base: &commonpb.MsgBase{
                SourceID: t.GetNodeID(),
             },
             Status:         merr.Success(),
             MetricType:     metricType,
             NumQueries:     t.originNqs[i],
             TopK:           t.originTopks[i],
             SlicedOffset:   1,
             SlicedNumCount: 1,
             CostAggregation: &internalpb.CostAggregation{
                ServiceTime: tr.ElapseSpan().Milliseconds(),
             },
          }
       }
       return nil
    }

    relatedDataSize := lo.Reduce(searchedSegments, func(acc int64, seg segments.Segment, _ int) int64 {
       return acc + segments.GetSegmentRelatedDataSize(seg)
    }, 0)

    tr.RecordSpan()
    // reduce: aggregate, sort, deduplicate, fill fields, and serialize to protobuf
    blobs, err := segcore.ReduceSearchResultsAndFillData(
       t.ctx,
       searchReq.Plan(),
       results,
       int64(len(results)),
       t.originNqs,
       t.originTopks,
    )
    if err != nil {
       log.Warn("failed to reduce search results", zap.Error(err))
       return err
    }
    defer segcore.DeleteSearchResultDataBlobs(blobs)
    metrics.QueryNodeReduceLatency.WithLabelValues(
       fmt.Sprint(t.GetNodeID()),
       metrics.SearchLabel,
       metrics.ReduceSegments,
       metrics.BatchReduce).
       Observe(float64(tr.RecordSpan().Milliseconds()))
    for i := range t.originNqs {
       blob, err := segcore.GetSearchResultDataBlob(t.ctx, blobs, i)
       if err != nil {
          return err
       }

       var task *SearchTask
       if i == 0 {
          task = t
       } else {
          task = t.others[i-1]
       }

       // Note: blob is unsafe because get from C
       bs := make([]byte, len(blob))
       copy(bs, blob)

       task.result = &internalpb.SearchResults{
          Base: &commonpb.MsgBase{
             SourceID: t.GetNodeID(),
          },
          Status:         merr.Success(),
          MetricType:     metricType,
          NumQueries:     t.originNqs[i],
          TopK:           t.originTopks[i],
          SlicedBlob:     bs,
          SlicedOffset:   1,
          SlicedNumCount: 1,
          CostAggregation: &internalpb.CostAggregation{
             ServiceTime:          tr.ElapseSpan().Milliseconds(),
             TotalRelatedDataSize: relatedDataSize,
          },
       }
    }

    return nil
}