Spark-累加器源码分析

一、累加器使用

源码中给的例子是:org.apache.spark.examples.AccumulatorMetricsTest

此示例显示了如何针对累加器源注册累加器,创建一个简单的RDD,在Task中对累加器递增,结果与累加器的值一起输出到Driver中的stdout。为了可以看到效果,我们对累加过程做了些调整

其中我们关心的代码如下,即创建、累加和使用

Scala 复制代码
val accLong = sc.longAccumulator("my-long-metric")
val accDouble = sc.doubleAccumulator("my-double-metric")
val accCollection = sc.collectionAccumulator[String]("my-collection-metric")

val num = if (args.length > 0) args(0).toInt else 1000
val accumulatorTest = sc.parallelize(1 to num).foreach(thisNum=> {
  accLong.add(3)
  accDouble.add(1.1)
  accDouble.add(2.1)
  accCollection.add("num:"+thisNum)
})

println("*** Long accumulator (my-long-metric): " + accLong.value)
println("*** Long accumulator (my-long-metric): count:" + accLong.count)
println("*** Long accumulator (my-long-metric): sum:" + accLong.sum)
println("*** Long accumulator (my-long-metric): avg:" + accLong.avg)
println("*** Double accumulator (my-double-metric): " + accDouble.value)
println("*** Double accumulator (my-double-metric): count:" + accDouble.count)
println("*** Double accumulator (my-double-metric): sum:" + accDouble.sum)
println("*** Double accumulator (my-double-metric): avg:" + accDouble.avg)
println("*** Collection accumulator (my-collection-metric): " + accCollection.value)

输出结果为:

*** Long accumulator (my-long-metric): 3000

*** Long accumulator (my-long-metric): count:1000

*** Long accumulator (my-long-metric): sum:3000

*** Long accumulator (my-long-metric): avg:3.0

*** Double accumulator (my-double-metric): 3199.9999999998868

*** Double accumulator (my-double-metric): count:2000

*** Double accumulator (my-double-metric): sum:3199.9999999998868

*** Double accumulator (my-double-metric): avg:1.5999999999999435

*** Collection accumulator (my-collection-metric): [num:1, num:2, num:3, num:4, num:5, num:6, ......,num:999, num:1000]

二、创建累加器

我们拿最常用的longAccumulator来看下:

1、SparkContext

创建并注册一个Long累加器,它从0开始,通过"add"累加输入

Scala 复制代码
  def longAccumulator(name: String): LongAccumulator = {
    val acc = new LongAccumulator
    register(acc, name)
    acc
  }

  def register(acc: AccumulatorV2[_, _], name: String): Unit = {
    //调用AccumulatorV2的register
    acc.register(this, name = Option(name))
  }

2、AccumulatorV2

累加器的基类,可以累加"IN"类型的输入,并产生"OUT"类型的输出

它是LongAccumulator、DoubleAccumulator、CollectionAccumulator的父类

Scala 复制代码
abstract class AccumulatorV2[IN, OUT] extends Serializable {

  private[spark] def register(
      sc: SparkContext,
      name: Option[String] = None,
      countFailedValues: Boolean = false): Unit = {
    if (this.metadata != null) {
      throw new IllegalStateException("Cannot register an Accumulator twice.")
    }
    this.metadata = AccumulatorMetadata(AccumulatorContext.newId(), name, countFailedValues)
    AccumulatorContext.register(this)
    sc.cleaner.foreach(_.registerAccumulatorForCleanup(this))
  }

    //...............

//内部类
private[spark] object AccumulatorContext extends Logging {

  //此全局映射保存在Driver上创建的原始累加器对象。它保留了对这些对象的弱引用,这样一旦RDD和引用它们的用户代码被清理干净,累加器就可以被垃圾回收。
  private val originals = new ConcurrentHashMap[Long, jl.ref.WeakReference[AccumulatorV2[_, _]]]

  //注册在Driver上创建的[[AcumulatorV2]],以便在Executor上使用。
  //此处注册的所有累加器稍后都可以用作跨多个Task累加部分值的容器。这就是org.apache.spark.scheduler。DAGScheduler上来做的。
  //注意:如果在此处注册了累加器,则还应将其注册到活动上下文清理器中进行清理,以避免内存泄漏。
  //如果已经注册了具有相同ID的[[AcumulatorV2]],这只会覆盖它,而不会做任何事情。我们永远不会重复注册同一个累加器
  def register(a: AccumulatorV2[_, _]): Unit = {
    originals.putIfAbsent(a.id, new jl.ref.WeakReference[AccumulatorV2[_, _]](a))
  }

  //...............
}

//用于计算64位整数的求和、计数和平均值的[[AcumulatorV2累加器]]
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {

  private var _sum = 0L
  private var _count = 0L

  override def add(v: jl.Long): Unit = {
    _sum += v
    _count += 1
  }

  def count: Long = _count

  def sum: Long = _sum

  def avg: Double = _sum.toDouble / _count

  override def value: jl.Long = _sum

  override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other match {
    case o: LongAccumulator =>
      _sum += o.sum
      _count += o.count
    case _ =>
     //.....抛异常....
  }

  //...............
}

//用于计算双精度浮点数的求和、计数和平均值的累加器
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {

  private var _sum = 0.0
  private var _count = 0L

  override def add(v: jl.Double): Unit = {
    _sum += v
    _count += 1
  }

  def count: Long = _count

  def sum: Double = _sum

  def avg: Double = _sum / _count

  override def value: jl.Double = _sum

  override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other match {
    case o: DoubleAccumulator =>
      _sum += o.sum
      _count += o.count
    case _ =>
     //.....抛异常....
  }

  //...............
}

//用于收集元素列表的[[AcumulatorV2累加器]]
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]] {

  private var _list: java.util.List[T] = _

  override def merge(other: AccumulatorV2[T, java.util.List[T]]): Unit = other match {
    case o: CollectionAccumulator[T] => this.synchronized(getOrCreate.addAll(o.value))
    case _ => //.....抛异常....
  }

  private def getOrCreate = {
    _list = Option(_list).getOrElse(new java.util.ArrayList[T]())
    _list
  }

  override def value: java.util.List[T] = this.synchronized {
    java.util.Collections.unmodifiableList(new ArrayList[T](getOrCreate))
  }

  //...............
}

}

三、实现累加

从AccumulatorV2中的方法我们可以知道最终是在DAGScheduler调用它的merge方法来实现的累加,下面我们详细看下在什么位置:

Scala 复制代码
  private def doOnReceive(event: DAGSchedulerEvent): Unit = event match {
    //当任务完成后会调用handleTaskCompletion
    case completion: CompletionEvent =>
      dagScheduler.handleTaskCompletion(completion)

  }


  private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
    val task = event.task
    val stageId = task.stageId

    //........省略........

    //确保在任何其他处理发生之前更新任务的累加器,以便我们可以在更新任何作业或阶段之前发布任务结束事件。
    event.reason match {
      case Success =>
        task match {
          case rt: ResultTask[_, _] =>
            val resultStage = stage.asInstanceOf[ResultStage]
            resultStage.activeJob match {
              case Some(job) =>
                // 对于每个结果任务,只更新一次累加器
                if (!job.finished(rt.outputId)) {
                  updateAccumulators(event)
                }
              case None => // 如果任务的作业已完成,则忽略更新
            }
          case _ =>
            updateAccumulators(event)
        }
      case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
      case _ =>
    }

    //........省略........

  }

  private def updateAccumulators(event: CompletionEvent): Unit = {
    val task = event.task
    val stage = stageIdToStage(task.stageId)

    event.accumUpdates.foreach { updates =>
      val id = updates.id
      try {
        // 在Driver上找到相应的累加器并更新
        val acc: AccumulatorV2[Any, Any] = AccumulatorContext.get(id) match {
          case Some(accum) => accum.asInstanceOf[AccumulatorV2[Any, Any]]
          case None =>
            throw new SparkException(s"attempted to access non-existent accumulator $id")
        }
        acc.merge(updates.asInstanceOf[AccumulatorV2[Any, Any]])

      } catch {
        //......异常处理........
      }
    }
  }

四、总结

1、通过SparkContext创建累加器(LongAccumulator、DoubleAccumulator、CollectionAccumulator)

2、在Driver端注册累加器(累加器必须先注册再使用)(其实就是向全局Map中放入了该累加器)

3、累加器从0开始计数,在每层Stage对应的Task结束时通过merge方法更新Driver端的累计器

4、当一个Job跑完后我们就可以使用累加器变量了,如果是数值型可以拿到总和、累加次数、平均值,如果时集合型可以拿到一个数据序列

相关推荐
Aloudata26 分钟前
从Apache Atlas到Aloudata BIG,数据血缘解析有何改变?
大数据·apache·数据血缘·主动元数据·数据链路
不能再留遗憾了26 分钟前
RabbitMQ 高级特性——消息分发
分布式·rabbitmq·ruby
水豚AI课代表32 分钟前
分析报告、调研报告、工作方案等的提示词
大数据·人工智能·学习·chatgpt·aigc
茶馆大橘35 分钟前
微服务系列六:分布式事务与seata
分布式·docker·微服务·nacos·seata·springcloud
材料苦逼不会梦到计算机白富美3 小时前
golang分布式缓存项目 Day 1
分布式·缓存·golang
拓端研究室TRL3 小时前
【梯度提升专题】XGBoost、Adaboost、CatBoost预测合集:抗乳腺癌药物优化、信贷风控、比特币应用|附数据代码...
大数据
黄焖鸡能干四碗4 小时前
信息化运维方案,实施方案,开发方案,信息中心安全运维资料(软件资料word)
大数据·人工智能·软件需求·设计规范·规格说明书
想进大厂的小王4 小时前
项目架构介绍以及Spring cloud、redis、mq 等组件的基本认识
redis·分布式·后端·spring cloud·微服务·架构
Java 第一深情4 小时前
高性能分布式缓存Redis-数据管理与性能提升之道
redis·分布式·缓存
编码小袁4 小时前
探索数据科学与大数据技术专业本科生的广阔就业前景
大数据