Spark SQL Shuffle 分区数生成机制详解

前言

在 Spark SQL 物理计划已经生成之后、真正执行之前,还会通过一系列 rule(如 AQE、EnsureRequirements、WSCG 等),对整棵物理计划进行全局整合或优化。

普通 Join、Group By 等算子所需的隐式 Shuffle,通常是在 EnsureRequirements 中补充的。

例如,一个聚合 SQL 生成的初始物理计划是:

text 复制代码
HashAggregateExec(final)
  +- HashAggregateExec(partial)
      +- FileSourceScanExec

经过 EnsureRequirements 后:

text 复制代码
HashAggregateExec(final)
  +- ShuffleExchangeExec(HashPartitioning(user_id, 200))
      +- HashAggregateExec(partial)
          +- FileSourceScanExec

Spark SQL 的 Shuffle 分区数,通常是在执行 EnsureRequirements 并插入 ShuffleExchangeExec 时确定的。

接下来,本文基于 Spark 3.1.2 源码视角,详细剖析 Spark SQL 发生 Shuffle 时的分区数决定机制。

EnsureRequirements 介绍

org.apache.spark.sql.execution.exchange.EnsureRequirements

scala 复制代码
object EnsureRequirements extends Rule[SparkPlan] {

  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
    
    // 1、获取当前算子要求 child 需要满足的数据分布和排序要求
    val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
    val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
    var children: Seq[SparkPlan] = operator.children
    assert(requiredChildDistributions.length == children.length)
    assert(requiredChildOrderings.length == children.length)
	
    // 2、判断当前算子 child 的数据分布是否满足当前算子的输入要求,并根据不同结果作处理
    // Ensure that the operator's children satisfy their output distribution requirements.
    children = children.zip(requiredChildDistributions).map {
      case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
        child
      case (child, BroadcastDistribution(mode)) =>
        BroadcastExchangeExec(mode, child)
      case (child, distribution) =>
        val numPartitions = distribution.requiredNumPartitions
          .getOrElse(conf.numShufflePartitions)
        ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
    }
      
    //....
  }

这段代码的作用是:对当前物理算子 operator 的每个 child 进行检查,判断 child 的输出分布是否满足父算子的输入要求。如果不满足,就在已有 child 上方插入 BroadcastExchangeExecShuffleExchangeExec

operator.children

scala 复制代码
var children: Seq[SparkPlan] = operator.children

operator.children 表示当前物理算子的输入子计划。

单 child 场景:

text 复制代码
HashAggregateExec
  +- FileSourceScanExec

HashAggregateExec 来说:

text 复制代码
children = Seq(FileSourceScanExec)

多 child 场景:

text 复制代码
SortMergeJoinExec
  :- leftPlan
  +- rightPlan

SortMergeJoinExec 来说:

text 复制代码
children = Seq(leftPlan, rightPlan)

叶子节点没有 child,例如 FileSourceScanExec,其自身的 children = Seq.empty。而 ProjectExecFilterExec 这类一元算子通常会有一个 child。

operator.requiredChildDistribution

scala 复制代码
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution

requiredChildDistribution 定义了当前物理算子要求所有 child 需要满足的数据分布 Distribution

requiredChildDistributionSparkPlan 中定义,所有物理算子都会包含该方法。

org.apache.spark.sql.execution.SparkPlan

scala 复制代码
  def requiredChildDistribution: Seq[Distribution] =
    Seq.fill(children.size)(UnspecifiedDistribution)

默认含义是:有几个 child,就返回几个 UnspecifiedDistribution(对 child 的数据分区不关心)。

如果算子没有重写 requiredChildDistribution,通常表示它对 child 数据分区没有特殊要求,比如 ProjectExecFilterExec 这类一元算子。而常见的会发生 Shuffle 的算子,如 HashAggregateExecSortMergeJoinExec,会要求 child 满足一定的数据分布,因此会重写 requiredChildDistribution

常见 Shuffle 算子要求 child 满足的数据分布 Distribution

不同算子对 child 的数据分布要求是不一样的,下面列举几种常见 Shuffle 算子所要求的数据分布。

① HashAggregateExec

HashAggregateExecrequiredChildDistribution 方法在 BaseAggregateExec 中实现。

org/apache/spark/sql/execution/aggregate/BaseAggregateExec

scala 复制代码
override def requiredChildDistribution: List[Distribution] = {
    requiredChildDistributionExpressions match {
        case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
        case Some(exprs) => ClusteredDistribution(exprs) :: Nil
        case None => UnspecifiedDistribution :: Nil
    }
}
  • 有指定表达式且为空

    这种场景常见于全局聚合 ,如 SELECT COUNT(*) FROM t,要求所有数据进入同一个分区,Distribution 数据分布要求为 AllTuples

  • 指定了表达式且非空

    这种场景常见于普通分组聚合 ,如 SELECT key, COUNT(*) FROM t GROUP BY key,要求 child 节点输出的数据按照这些表达式聚集或分区,数据分布要求为 ClusteredDistribution

  • 没有指定任何分布要求

对 child 节点的数据分布没有要求。

② SortMergeJoinExec 或 ShuffledHashJoinExec

对于 Shuffle Join,例如 SortMergeJoinExecShuffledHashJoinExec,它们有两个 child,并且对左右 child 都有 ClusteredDistribution 数据分布要求(按照 Join key 做相同的分布)。

org/apache/spark/sql/execution/joins/ShuffledJoin

scala 复制代码
override def requiredChildDistribution: Seq[Distribution] = {
  ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
}

上面频繁出现 Distribution 这个概念,它到底是什么?下面一起看看。

Distribution

org/apache/spark/sql/catalyst/plans/physical/Distribution

scala 复制代码
sealed trait Distribution {
  // 该 Distribution 要求的分区数量。如果是 None,表示任意分区数量都可以满足这个 Distribution。
  def requiredNumPartitions: Option[Int]

  // 为该 Distribution 创建一个默认的 Partitioning。这个 Partitioning 既能满足当前 Distribution 的要求,又能匹配给定的分区数量
  def createPartitioning(numPartitions: Int): Partitioning
}

Distribution 表示"数据应该怎么分布":当一个查询在多台机器上并行执行时,具有相同表达式值的元组应该如何分布。

前面涉及了不同算子所要求的 Distribution,下面将介绍常见的 Distribution 类型:

  • UnspecifiedDistribution

不对 child 输入数据的分区方式提出要求,如普通投影、过滤。

scala 复制代码
case object UnspecifiedDistribution extends Distribution {
  override def requiredNumPartitions: Option[Int] = None

  override def createPartitioning(numPartitions: Int): Partitioning = {
    throw new IllegalStateException("UnspecifiedDistribution does not have default partitioning.")
  }
}
  • AllTuples

只有一个分区的分布,表示整个数据集必须放在一个分区里,通常用于需要全局处理的操作。

scala 复制代码
case object AllTuples extends Distribution {
  override def requiredNumPartitions: Option[Int] = Some(1)

  override def createPartitioning(numPartitions: Int): Partitioning = {
    assert(numPartitions == 1, "The default partitioning of AllTuples can only have 1 partition.")
    SinglePartition
  }
}
  • ClusteredDistribution

要求相同 key 的数据落在同一分区,通常用于 Group By、窗口函数。

scala 复制代码
case class ClusteredDistribution(
    clustering: Seq[Expression],
    requiredNumPartitions: Option[Int] = None) extends Distribution {
  require(
    clustering != Nil, // 参与数据分布的表达式不能为空
    "The clustering expressions of a ClusteredDistribution should not be Nil.")

  override def createPartitioning(numPartitions: Int): Partitioning = {
    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
      s"This ClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
        s"the actual number of partitions is $numPartitions.")
    HashPartitioning(clustering, numPartitions)
  }
}
  • HashClusteredDistribution

    不仅要求相同 key 的数据在一起,还要求按 hash 规则进入指定分区,常用于 join 算子。相比 ClusteredDistribution,它提供了更强的保障。

scala 复制代码
case class HashClusteredDistribution(
    expressions: Seq[Expression],
    requiredNumPartitions: Option[Int] = None) extends Distribution {
  require(
    expressions != Nil, // 参与数据分布的表达式不能为空
    "The expressions of a HashClusteredDistribution should not be Nil.")

  override def createPartitioning(numPartitions: Int): Partitioning = {
    assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
      s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
        s"the actual number of partitions is $numPartitions.")
    HashPartitioning(expressions, numPartitions)
  }
}
  • BroadcastDistribution

是一种广播分布要求,数据会被广播到每个执行节点上。

scala 复制代码
case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
  override def requiredNumPartitions: Option[Int] = Some(1) // 广播数据在被广播前通常会先收集成一个整体,然后作为一个广播变量发到各个 executor。所以从 Spark 物理分区语义上看,默认分区数要求是 1。

  override def createPartitioning(numPartitions: Int): Partitioning = {
    assert(numPartitions == 1,
      "The default partitioning of BroadcastDistribution can only have 1 partition.")
    BroadcastPartitioning(mode)
  }
}

Partitioning

org/apache/spark/sql/catalyst/plans/physical/partitioning

scala 复制代码
trait Partitioning {

   // 该 SparkPlan 输出 RDD 的分区数目
  val numPartitions: Int

    // 当前 child 的 Partitioning 是否能够满足下游要求的 Distribution,不满足时返回 false。
    // 要满足需要两个条件:① 分区数 numPartitions 要相等;
    // ② satisfies0 方法返回 true,包括两种情况:1. 对子节点的分布没有要求;2. 全局处理。
  final def satisfies(required: Distribution): Boolean = {
    required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
  }

   // 1、如果 requiredChildDistribution 为 UnspecifiedDistribution,则说明对子节点的分布没有要求,返回 true;
   // 2、如果 requiredChildDistribution 为 AllTuples,则只要 numPartitions == 1,返回 true;
   // 3、其他情况,返回 false。
   // 具体的 Partitioning 类型会对 satisfies0 进行重写。
  protected def satisfies0(required: Distribution): Boolean = required match {
    case UnspecifiedDistribution => true
    case AllTuples => numPartitions == 1
    case _ => false
  }
}

Partitioning 定义了一个物理算子输出数据的分区方式。

常见 Partitioning 类型有:

UnknownPartitioning:未知分区方式;

SinglePartition:单分区;

BroadcastPartitioning:广播分区;

HashPartitioning:基于哈希的分区方式;

RangePartitioning:基于范围的分区方式,通过确定分区键是否在某个范围内来选择分区;

Distribution 和 Partitioning 的关系

Distribution 是"要求",Partitioning 是"实现"。

比如某个物理算子要求 child 中相同 key 的数据必须在同一个分区,这是一种 Distribution 要求。Spark 为了满足它,可能会创建 HashPartitioning(keys, numPartitions),这就是具体的 Partitioning 实现,也就是通过 hash(key) 的方式把数据分到指定数量的分区中。

简单总结:

复制代码
Distribution  =  算子需要什么样的数据分布
Partitioning  =  当前数据实际是什么样的分区方式
Shuffle       =  当现有 Partitioning 不能满足 Distribution 时,Spark 插入 Exchange 来重新分区

children.zip(requiredChildDistributions)

scala 复制代码
children = children.zip(requiredChildDistributions).map {
  case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
    child
  case (child, BroadcastDistribution(mode)) =>
    BroadcastExchangeExec(mode, child)
  case (child, distribution) =>
    val numPartitions = distribution.requiredNumPartitions
      .getOrElse(conf.numShufflePartitions)
    ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
}

这一步会把当前算子的每个 child 和它要求的 Distribution 配对。

例如 SortMergeJoinExec 算子:

text 复制代码
children:
  leftPlan
  rightPlan

requiredChildDistributions:
  ClusteredDistribution(leftJoinKeys)
  ClusteredDistribution(rightJoinKeys)

配对后变成:

text 复制代码
(leftPlan, ClusteredDistribution(leftJoinKeys))
(rightPlan, ClusteredDistribution(rightJoinKeys))

后续 map 会逐个检查每个 child 是否满足对应的 Distribution

第一分支:child 已经满足分布要求

scala 复制代码
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
  child

如果 child 的输出分区方式已经满足父算子的分布要求,就直接返回原 child。例如,ProjectExecFilterExec 等物理算子并不要求 child 具备特定的输入分布。

第二分支:需要 BroadcastExchangeExec

scala 复制代码
case (child, BroadcastDistribution(mode)) =>
  BroadcastExchangeExec(mode, child)

如果父算子要求 child 满足 BroadcastDistribution,说明该 child 不需要普通 Shuffle,而是需要被广播。

典型场景是 Broadcast Hash Join:

sql 复制代码
SELECT /*+ BROADCAST(s) */ *
FROM big_table b
JOIN small_table s
ON b.id = s.id;

对于小表 s,父算子会要求:

scala 复制代码
BroadcastDistribution(HashedRelationBroadcastMode(keys))

第三分支:需要 ShuffleExchangeExec

scala 复制代码
case (child, distribution) =>
  val numPartitions = distribution.requiredNumPartitions
    .getOrElse(conf.numShufflePartitions)
  ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)

如果 child 输出分区不满足当前算子的 Distribution,且当前算子要求的不是 broadcast,就进入普通 Shuffle 分支。

1 确定 Shuffle 分区数

scala 复制代码
val numPartitions = distribution.requiredNumPartitions
  .getOrElse(conf.numShufflePartitions)

如果 distribution.requiredNumPartitions 有值,则优先使用它。例如 AllTuples.requiredNumPartitions = Some(1),表示强制单分区。

如果没有强制分区数,则使用 conf.numShufflePartitions

org/apache/spark/sql/internal/SQLConf

scala 复制代码
def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)

def numShufflePartitions: Int = {
  if (adaptiveExecutionEnabled && coalesceShufflePartitionsEnabled) {
    getConf(COALESCE_PARTITIONS_INITIAL_PARTITION_NUM).getOrElse(defaultNumShufflePartitions)
  } else {
    defaultNumShufflePartitions
  }
}

因此:

  • 未开启 AQE 时,shuffle 分区数等于 spark.sql.shuffle.partitions
  • 开启 AQE 且开启 Shuffle 分区合并时,初始分区数可能来自 spark.sql.adaptive.coalescePartitions.initialPartitionNum;如果没有设置 initialPartitionNum,仍然回退到 spark.sql.shuffle.partitions

相关参数:

spark.sql.shuffle.partitions

该参数表示在进行 join 或聚合操作且需要 shuffle 数据时,默认使用的分区数量。默认值是 200。

scala 复制代码
val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions")
  .doc("The default number of partitions to use when shuffling data for joins or aggregations.")
  .version("1.1.0")
  .intConf
  .checkValue(_ > 0, "The value of spark.sql.shuffle.partitions must be positive")
  .createWithDefault(200)

spark.sql.adaptive.coalescePartitions.initialPartitionNum

该参数表示在进行分区合并之前,shuffle 分区的初始数量。

如果没有设置该值,则它等于 spark.sql.shuffle.partitions

这个配置只有在 spark.sql.adaptive.enabledspark.sql.adaptive.coalescePartitions.enabled 都为 true 时才会生效。

scala 复制代码
  val COALESCE_PARTITIONS_INITIAL_PARTITION_NUM =
    buildConf("spark.sql.adaptive.coalescePartitions.initialPartitionNum")
      .doc("The initial number of shuffle partitions before coalescing. If not set, it equals to " +
        s"${SHUFFLE_PARTITIONS.key}. This configuration only has an effect when " +
        s"'${ADAPTIVE_EXECUTION_ENABLED.key}' and '${COALESCE_PARTITIONS_ENABLED.key}' " +
        "are both true.")
      .version("3.0.0")
      .intConf
      .checkValue(_ > 0, "The initial number of partitions must be positive.")
      .createOptional

spark.sql.adaptive.enabled

该参数表示是否开启 Spark SQL 的 AQE,即 Adaptive Query Execution,自适应查询执行(默认关闭)。

开启后,Spark 不只依赖执行前的静态优化结果,而是会在 SQL 运行过程中,根据真实的运行时数据量、分区大小、统计信息等,动态调整执行计划。

scala 复制代码
  val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled")
    .doc("When true, enable adaptive query execution, which re-optimizes the query plan in the " +
      "middle of query execution, based on accurate runtime statistics.")
    .version("1.6.0")
    .booleanConf
    .createWithDefault(false)

spark.sql.adaptive.coalescePartitions.enabled

该参数表示是否开启 AQE 下的 shuffle 分区合并功能(默认开启)。

当该配置为 true,并且 spark.sql.adaptive.enabled 也为 true 时,Spark 会根据目标分区大小,也就是 spark.sql.adaptive.advisoryPartitionSizeInBytes 指定的大小,合并连续的 shuffle 分区,以避免产生过多的小任务。

scala 复制代码
  val COALESCE_PARTITIONS_ENABLED =
    buildConf("spark.sql.adaptive.coalescePartitions.enabled")
      .doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, Spark will coalesce " +
        "contiguous shuffle partitions according to the target size (specified by " +
        s"'${ADVISORY_PARTITION_SIZE_IN_BYTES.key}'), to avoid too many small tasks.")
      .version("3.0.0")
      .booleanConf
      .createWithDefault(true)

spark.sql.adaptive.advisoryPartitionSizeInBytes

该参数表示 AQE 过程中 shuffle 分区的"建议目标大小"。

该配置在 spark.sql.adaptive.enabledtrue 时使用。当 Spark 合并较小的 shuffle 分区,或者拆分发生数据倾斜的 shuffle 分区时,该配置会生效。

它主要影响两类场景:

  • 合并小 shuffle 分区:多个小分区会被合并,目标是接近该大小。
  • 拆分倾斜 shuffle 分区:过大的倾斜分区会被拆开,拆分后的大小也会参考该值。
scala 复制代码
  val ADVISORY_PARTITION_SIZE_IN_BYTES =
    buildConf("spark.sql.adaptive.advisoryPartitionSizeInBytes")
      .doc("The advisory size in bytes of the shuffle partition during adaptive optimization " +
        s"(when ${ADAPTIVE_EXECUTION_ENABLED.key} is true). It takes effect when Spark " +
        "coalesces small shuffle partitions or splits skewed shuffle partition.")
      .version("3.0.0")
      .fallbackConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE)

2 创建具体的分区

scala 复制代码
distribution.createPartitioning(numPartitions)

创建满足当前算子要求的 Distribution 分布。

常见映射关系:

text 复制代码
ClusteredDistribution(keys)
  -> HashPartitioning(keys, numPartitions)

OrderedDistribution(ordering)
  -> RangePartitioning(ordering, numPartitions)

AllTuples
  -> SinglePartition

3 创建 ShuffleExchangeExec

scala 复制代码
ShuffleExchangeExec(partitioning, child)

在当前物理算子的子节点 child 外面包一层 ShuffleExchangeExec,强制对 child 的输出数据做一次 shuffle。在生成新的物理计划树时,当前物理算子的 child 会被替换为这个新的 ShuffleExchangeExec,从而在当前算子和原计划之间增加一层 shuffle。

例如,一个聚合 SQL 生成的初始物理计划是:

text 复制代码
HashAggregateExec(final)
  +- HashAggregateExec(partial)
      +- FileSourceScanExec

经过 EnsureRequirements 后:

text 复制代码
HashAggregateExec(final)
  +- ShuffleExchangeExec(HashPartitioning(user_id, 200))
      +- HashAggregateExec(partial)
          +- FileSourceScanExec

ShuffleExchangeExec 将分区数带入 Spark Core

ShuffleExchangeExec 持有最终的输出分区方式:

org/apache/spark/sql/execution/exchange/ShuffleExchangeExec

scala 复制代码
case class ShuffleExchangeExec(
    override val outputPartitioning: Partitioning,
    child: SparkPlan,
    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS)
  extends ShuffleExchangeLike

其中 outputPartitioning.numPartitions 就是 SQL 物理计划层面已经确定的 Shuffle 分区数。

执行时会构造 Shuffle 依赖:

scala 复制代码
val shuffleDependency = ShuffleExchangeExec.prepareShuffleDependency(
  inputRDD,
  child.output,
  outputPartitioning,
  serializer,
  writeMetrics)

进入 Spark Core 后,分区数会体现在 ShuffleDependencypartitioner 上。

core/src/main/scala/org/apache/spark/Dependency.scala

scala 复制代码
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
    _rdd: RDD[_ <: Product2[K, V]],
    val partitioner: Partitioner,
    ...)

其中 partitioner.numPartitions 决定了 Shuffle reduce 端的分区数,也通常对应 reduce task 的数量。

核心结论

Spark SQL Shuffle 分区数决定机制主流程可以概括为:

text 复制代码
父算子声明 requiredChildDistribution
  -> EnsureRequirements 检查 child.outputPartitioning 是否满足要求
  -> 满足则复用 child
  -> 要求 BroadcastDistribution 则插入 BroadcastExchangeExec
  -> 否则插入 ShuffleExchangeExec
  -> Shuffle 分区数优先取 distribution.requiredNumPartitions
  -> 没有强制要求时取 conf.numShufflePartitions
  -> conf.numShufflePartitions 通常来自 spark.sql.shuffle.partitions,默认 200

Spark SQL 的 Shuffle 分区数不是由单一位置决定的,而是由逻辑计划、物理计划、分布需求、用户显式设置和 AQE 共同决定。

分区数的优先级可以按下面理解:

  1. 显式重分区优先,例如 df.repartition(n)repartitionByRange(n, ...),分区数直接使用用户指定的 n
  2. 算子强制分区数优先,例如 AllTuples.requiredNumPartitions = Some(1),会生成单分区 Shuffle。
  3. 普通 Join、Group By、Distinct、Window 等场景通常使用 conf.numShufflePartitions,也就是 spark.sql.shuffle.partitions
  4. 开启 AQE 后,spark.sql.shuffle.partitions 更多表示初始 Shuffle 分区数,最终 Shuffle Read task 数可能被合并或因倾斜拆分而变化。

参考

Spark SQL 源码研读系列 06:Executed Plan