Spark全栈指南:从入门到精通

Apache Spark 完整知识总结与使用教程

目录

  1. [Spark 概述与生态系统](#Spark 概述与生态系统)
  2. [Spark 架构深解](#Spark 架构深解)
  3. 核心数据抽象:RDD、DataFrame、Dataset
  4. [SparkSession 与 SparkContext](#SparkSession 与 SparkContext)
  5. [RDD 编程详解](#RDD 编程详解)
  6. [Spark SQL 与 DataFrame 编程](#Spark SQL 与 DataFrame 编程)
  7. [Catalyst 优化器与 Tungsten 执行引擎](#Catalyst 优化器与 Tungsten 执行引擎)
  8. [Shuffle 机制与 Join 策略](#Shuffle 机制与 Join 策略)
  9. [Structured Streaming 流处理](#Structured Streaming 流处理)
  10. [Spark MLlib 机器学习](#Spark MLlib 机器学习)
  11. [Spark GraphX 图计算](#Spark GraphX 图计算)
  12. 内存管理与持久化策略
  13. 分区策略与数据倾斜处理
  14. 性能调优全面指南
  15. 集群部署与资源管理
  16. 监控与调试
  17. [Spark 配置参数速查表](#Spark 配置参数速查表)
  18. 最佳实践与常见坑
  19. 参考资源

1. Spark 概述与生态系统

1.1 什么是 Apache Spark?

Apache Spark 是一个统一的大规模数据处理分析引擎,由加州大学伯克利分校 AMPLab 在 2009 年开发,并于 2010 年开源,2014 年成为 Apache 顶级项目。

核心特性:

特性 说明
速度 基于内存计算,比 MapReduce 快 10~100 倍
易用性 提供 Java、Scala、Python、R 高层 API
通用性 批处理、流处理、机器学习、图计算统一平台
兼容性 支持 HDFS、S3、Hive、HBase 等多种数据源
容错性 基于 RDD 血统(Lineage)自动重建丢失数据

1.2 Spark vs Hadoop MapReduce

复制代码
                MapReduce           Apache Spark
计算模式:       磁盘迭代             内存迭代
迭代速度:       慢(每次写磁盘)     快(常驻内存)
编程模型:       Map + Reduce        丰富的 API(map/filter/join/groupBy...)
实时流处理:     不支持              Structured Streaming
机器学习:       需要额外工具         内置 MLlib
延迟:           分钟级               秒级(甚至毫秒级)
适合场景:       超大规模批处理        批处理 + 流处理 + ML

1.3 Spark 生态系统全景

复制代码
┌─────────────────────────────────────────────────────────────────────┐
│                        Apache Spark 生态                             │
├──────────────┬────────────────┬───────────────┬───────────────────-─┤
│   Spark SQL  │ Spark Streaming│    MLlib       │      GraphX         │
│  结构化查询  │  流式处理       │ 机器学习       │     图计算           │
├──────────────┴────────────────┴───────────────┴────────────────────-┤
│                         Spark Core(核心引擎)                        │
│              RDD、DAG、任务调度、内存管理、容错                        │
├──────────────────────────────────────────────────────────────────────┤
│                         集群资源管理层                                │
│          Standalone  |  YARN  |  Mesos  |  Kubernetes               │
├──────────────────────────────────────────────────────────────────────┤
│                            存储层                                     │
│        HDFS  |  S3  |  HBase  |  Kafka  |  Cassandra  |  JDBC       │
└──────────────────────────────────────────────────────────────────────┘

六大核心组件:

组件 功能
Spark Core 基础引擎,提供 RDD、任务调度、内存管理、故障恢复
Spark SQL 结构化数据处理,支持 SQL/DataFrame/Dataset API
Structured Streaming 基于 Spark SQL 的流处理引擎(推荐使用)
MLlib 分布式机器学习库,覆盖常用算法
GraphX 图计算框架,支持图并行计算
SparkR R 语言接口,提供 DataFrame 操作

2. Spark 架构深解

2.1 整体架构

Spark 采用主从(Master-Worker)架构,由以下核心组件构成:

复制代码
用户程序(Driver Program)
         │
         │ 创建 SparkContext/SparkSession
         ▼
┌─────────────────────┐
│    Driver(驱动节点) │  ← 负责将用户代码转化为任务并协调执行
│  · SparkContext      │
│  · DAGScheduler      │
│  · TaskScheduler     │
└──────────┬──────────┘
           │ 请求资源 & 分配任务
           ▼
┌─────────────────────┐
│   Cluster Manager   │  ← 资源管理器(Standalone/YARN/K8s/Mesos)
│  (资源分配与调度)    │
└──────────┬──────────┘
           │ 启动 Executors
     ┌─────┴─────┐
     ▼           ▼
┌─────────┐ ┌─────────┐
│Executor │ │Executor │  ← 工作节点,真正执行计算任务
│ · Task  │ │ · Task  │
│ · Cache │ │ · Cache │
└─────────┘ └─────────┘

2.2 Driver(驱动节点)

Driver 是 Spark 应用程序的"大脑",运行用户的 main 函数,主要职责:

  • 创建 SparkContext,初始化 Spark 运行环境
  • 将用户的 RDD/DataFrame 操作转化为 DAG(有向无环图)
  • 通过 DAGScheduler 将 DAG 切分为 Stage
  • 通过 TaskScheduler 将 Task 分配给 Executor 执行
  • 跟踪各 Task 的执行状态,处理失败重试

2.3 Executor(执行节点)

Executor 是运行在工作节点(Worker Node)上的 JVM 进程:

  • 每个 Spark 应用有一组独立的 Executor
  • 负责执行 Driver 分配的 Task
  • 在内存中缓存 RDD 的分区数据(用于 cache()/persist()
  • 执行结束后,Executor 进程退出并释放资源

2.4 DAG、Stage、Task 的关系

复制代码
用户代码(Transformations + Action)
         │
         ▼
    DAG(有向无环图)
    代表整个作业的计算逻辑
         │
         │ DAGScheduler 按 Shuffle 边界切分
         ▼
    Stage(阶段)         ← 以 Shuffle 为边界划分
    ├── Stage 1(窄依赖操作的 Pipeline)
    │   ├── Task 1-1(处理 Partition 1)
    │   ├── Task 1-2(处理 Partition 2)
    │   └── Task 1-N(处理 Partition N)
    └── Stage 2(Shuffle 之后)
        ├── Task 2-1
        └── Task 2-N
         │
         ▼
    Action 触发执行,结果返回 Driver 或写入存储

关键概念:

概念 说明
Job 由一个 Action 触发的完整计算,包含一或多个 Stage
Stage 由窄依赖操作组成的任务集合,Stage 之间以 Shuffle 分隔
Task Stage 中处理一个数据分区(Partition)的最小执行单元
Shuffle 数据在 Executor 之间重新分配,是最昂贵的操作

2.5 宽依赖 vs 窄依赖

复制代码
窄依赖(Narrow Dependency):
  父 RDD 的每个分区最多被子 RDD 的一个分区使用
  示例:map、filter、union、mapPartitions

  Partition 1 ──→ Partition 1'
  Partition 2 ──→ Partition 2'     ← 一对一,可在单 Stage 内 Pipeline 执行
  Partition 3 ──→ Partition 3'

宽依赖(Wide Dependency / Shuffle Dependency):
  父 RDD 的每个分区可能被子 RDD 的多个分区使用
  示例:groupByKey、reduceByKey、join、sortBy

  Partition 1 ─┬──→ Partition 1'
               └──→ Partition 2'   ← 需要 Shuffle,触发新 Stage
  Partition 2 ─┬──→ Partition 1'
               └──→ Partition 3'

3. 核心数据抽象:RDD、DataFrame、Dataset

3.1 三种抽象对比

维度 RDD DataFrame Dataset
版本 Spark 1.x Spark 1.3+ Spark 1.6+
数据类型 任意 JVM 对象 Row(无类型) 强类型对象
类型安全 编译时安全 运行时检查 编译时安全
Schema 有(列名+类型)
Catalyst 优化 ❌ 无 ✅ 有 ✅ 有
Tungsten ❌ 无 ✅ 有 ✅ 有
Python/R 支持 ❌(仅 Scala/Java)
适用场景 非结构化数据、精细控制 结构化/SQL 查询 类型安全的结构化处理

3.2 如何选择?

复制代码
数据是否有结构/Schema?
  ├── 否(如媒体流、文本流、自定义对象)→ 使用 RDD
  └── 是(有列名和类型)
       ├── 使用 Scala/Java 且需要编译时类型安全 → 使用 Dataset
       └── 使用 Python/R,或侧重 SQL/聚合 → 使用 DataFrame

黄金法则 :现代 Spark 开发优先使用 DataFrame/Dataset,它们通过 Catalyst + Tungsten 获得自动优化;只有在需要底层控制或处理非结构化数据时才用 RDD。


4. SparkSession 与 SparkContext

4.1 SparkSession(Spark 2.0+ 统一入口)

python 复制代码
from pyspark.sql import SparkSession

# 创建 SparkSession(推荐方式)
spark = SparkSession.builder \
    .appName("MySparkApp") \
    .master("local[*]") \           # local[*] 使用所有本地核心
    .config("spark.executor.memory", "4g") \
    .config("spark.executor.cores", "2") \
    .config("spark.sql.shuffle.partitions", "200") \
    .enableHiveSupport() \          # 可选:启用 Hive 元数据支持
    .getOrCreate()                  # 已存在则复用

# 获取 SparkContext(低层 API)
sc = spark.sparkContext

# 关闭 Session(程序结束时调用)
spark.stop()
scala 复制代码
// Scala 版本
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("MySparkApp")
  .master("local[*]")
  .config("spark.sql.shuffle.partitions", "200")
  .getOrCreate()

import spark.implicits._  // 引入隐式转换,支持 DataFrame 操作符

4.2 SparkContext 的核心功能

python 复制代码
sc = spark.sparkContext

# 读取文本文件为 RDD
rdd = sc.textFile("hdfs://path/to/file.txt")

# 并行化本地集合
data = sc.parallelize([1, 2, 3, 4, 5], numSlices=4)

# 设置日志级别
sc.setLogLevel("WARN")  # ERROR / WARN / INFO / DEBUG

# 广播变量
lookup_table = {"a": 1, "b": 2}
broadcast_var = sc.broadcast(lookup_table)

# 累加器
counter = sc.accumulator(0)

5. RDD 编程详解

5.1 创建 RDD

python 复制代码
# 方法 1:并行化本地集合
rdd1 = sc.parallelize([1, 2, 3, 4, 5])
rdd1 = sc.parallelize(range(100), numSlices=10)  # 指定 10 个分区

# 方法 2:从文件系统读取
rdd2 = sc.textFile("file:///local/path/file.txt")        # 本地文件
rdd2 = sc.textFile("hdfs://namenode:8020/path/file.txt") # HDFS
rdd2 = sc.textFile("s3a://bucket/prefix/*.csv")          # S3

# 方法 3:从其他 RDD 转换
rdd3 = rdd1.map(lambda x: x * 2)

# 方法 4:DataFrame 转 RDD
df = spark.read.parquet("/path/to/parquet")
rdd4 = df.rdd

5.2 Transformations(转换操作)--- 懒执行

操作 说明 示例
map(f) 对每个元素应用函数 rdd.map(lambda x: x*2)
flatMap(f) 展平结果(返回 0~N 个元素) rdd.flatMap(lambda x: x.split())
filter(f) 过滤满足条件的元素 rdd.filter(lambda x: x > 0)
mapPartitions(f) 对每个分区应用函数(高效) rdd.mapPartitions(process_partition)
mapPartitionsWithIndex(f) 带分区索引的 mapPartitions -
union(other) 合并两个 RDD rdd1.union(rdd2)
intersection(other) 取交集 rdd1.intersection(rdd2)
distinct() 去重 rdd.distinct()
sample(withReplacement, fraction) 随机采样 rdd.sample(False, 0.1)
groupByKey() 按 Key 分组(慢,慎用) pairs.groupByKey()
reduceByKey(f) 按 Key 聚合(比 groupByKey 快) pairs.reduceByKey(lambda a,b: a+b)
aggregateByKey(zeroValue, seqOp, combOp) 分区内和分区间分别聚合 -
sortByKey(ascending=True) 按 Key 排序 pairs.sortByKey()
join(other) 内连接 rdd1.join(rdd2)
leftOuterJoin(other) 左外连接 rdd1.leftOuterJoin(rdd2)
cogroup(other) 按 Key 合并多个 RDD rdd1.cogroup(rdd2)
repartition(n) 重新分区(完整 Shuffle) rdd.repartition(100)
coalesce(n) 减少分区(无 Shuffle,更高效) rdd.coalesce(10)
persist(level) 持久化到内存/磁盘 rdd.persist(StorageLevel.MEMORY_AND_DISK)
cache() 等价于 persist(MEMORY_ONLY) rdd.cache()

5.3 Actions(动作操作)--- 触发执行

操作 说明 示例
collect() 收集所有元素到 Driver(小数据量) rdd.collect()
count() 统计元素个数 rdd.count()
first() 返回第一个元素 rdd.first()
take(n) 返回前 n 个元素 rdd.take(10)
top(n) 返回最大的 n 个元素 rdd.top(5)
reduce(f) 对元素做聚合操作 rdd.reduce(lambda a,b: a+b)
fold(zeroValue, f) 带初始值的聚合 rdd.fold(0, lambda a,b: a+b)
aggregate(zeroValue, seqOp, combOp) 分阶段聚合 -
foreach(f) 对每个元素执行副作用 rdd.foreach(print)
saveAsTextFile(path) 保存为文本文件 rdd.saveAsTextFile("hdfs://...")
saveAsPickleFile(path) 保存为 Pickle 格式 -
countByKey() 统计每个 Key 的数量 pairs.countByKey()
collectAsMap() 收集为字典 pairs.collectAsMap()
takeSample(withRep, n) 随机采样 n 个元素 rdd.takeSample(False, 100)

5.4 完整 RDD 示例:WordCount

python 复制代码
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("WordCount").master("local[*]").getOrCreate()
sc = spark.sparkContext

# 读取文本
text_rdd = sc.textFile("hdfs://path/to/text.txt")

# 处理流程
word_counts = (
    text_rdd
    .flatMap(lambda line: line.lower().strip().split())   # 分词
    .filter(lambda word: len(word) > 0)                   # 过滤空词
    .map(lambda word: (word, 1))                          # 映射为 (word, 1)
    .reduceByKey(lambda a, b: a + b)                      # 按词累加
    .sortBy(lambda kv: kv[1], ascending=False)            # 按频次降序
)

# 输出 Top 10
for word, count in word_counts.take(10):
    print(f"{word}: {count}")

# 保存到 HDFS
word_counts.saveAsTextFile("hdfs://path/to/output/wordcount")

5.5 共享变量

广播变量(Broadcast Variables)

广播变量将大型只读数据高效分发到各 Executor,避免在每个 Task 中都传输一份

python 复制代码
# 比如:大型词典、模型参数等
country_code = {"CN": "China", "US": "United States", "JP": "Japan"}
bc_country = sc.broadcast(country_code)

def lookup_country(code):
    return bc_country.value.get(code, "Unknown")  # 通过 .value 访问

result = rdd.map(lambda x: lookup_country(x))

# 手动释放广播变量(节省内存)
bc_country.unpersist()
bc_country.destroy()
累加器(Accumulators)

累加器是一种只能从 Executor 端累加、只能在 Driver 端读取的共享变量:

python 复制代码
error_count = sc.accumulator(0)
valid_count = sc.accumulator(0)

def parse_record(line):
    try:
        data = json.loads(line)
        valid_count.add(1)
        return data
    except:
        error_count.add(1)
        return None

result = rdd.map(parse_record).filter(lambda x: x is not None)
result.count()  # 触发执行

print(f"有效记录:{valid_count.value}")
print(f"错误记录:{error_count.value}")

⚠️ 注意:在 Transformation 中使用 Accumulator 时,因为 Task 重试机制,累加器值可能被多次累加。只有在 Action 之后读取累加器值才是准确的。


6. Spark SQL 与 DataFrame 编程

6.1 读取数据

python 复制代码
# ── CSV ──
df = spark.read \
    .option("header", "true") \
    .option("inferSchema", "true") \
    .option("sep", ",") \
    .option("nullValue", "NA") \
    .csv("hdfs://path/to/data.csv")

# ── JSON ──
df = spark.read \
    .option("multiLine", "true") \
    .json("hdfs://path/to/data.json")

# ── Parquet(推荐格式,列式存储+压缩+Schema) ──
df = spark.read.parquet("hdfs://path/to/data.parquet")

# ── ORC ──
df = spark.read.orc("hdfs://path/to/data.orc")

# ── JDBC/数据库 ──
df = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:mysql://host:3306/mydb") \
    .option("dbtable", "mytable") \
    .option("user", "user") \
    .option("password", "password") \
    .option("numPartitions", 10) \
    .option("partitionColumn", "id") \
    .option("lowerBound", 1) \
    .option("upperBound", 1000000) \
    .load()

# ── Kafka(流式读取,详见 Structured Streaming 章节) ──
df = spark.read \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "host:9092") \
    .option("subscribe", "topic_name") \
    .load()

# ── Hive 表 ──
df = spark.sql("SELECT * FROM hive_db.my_table")

6.2 Schema 定义

python 复制代码
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    LongType, DoubleType, BooleanType, TimestampType, ArrayType, MapType
)

# 手动定义 Schema(推荐用于生产环境,避免 inferSchema 的性能开销)
schema = StructType([
    StructField("user_id",    LongType(),      nullable=False),
    StructField("name",       StringType(),     nullable=True),
    StructField("age",        IntegerType(),    nullable=True),
    StructField("salary",     DoubleType(),     nullable=True),
    StructField("is_active",  BooleanType(),    nullable=True),
    StructField("created_at", TimestampType(),  nullable=True),
    StructField("tags",       ArrayType(StringType()), nullable=True),
    StructField("metadata",   MapType(StringType(), StringType()), nullable=True),
])

df = spark.read.schema(schema).csv("path/to/data.csv")

# 也可使用 DDL 字符串定义 Schema(Spark 2.3+)
schema_ddl = "user_id BIGINT NOT NULL, name STRING, age INT, salary DOUBLE"
df = spark.read.schema(schema_ddl).csv("path/to/data.csv")

6.3 DataFrame 常用操作

python 复制代码
# ── 查看数据 ──
df.show(20, truncate=False)   # 显示 20 行,不截断
df.printSchema()               # 打印 Schema 树形结构
df.dtypes                      # 返回列名和数据类型列表
df.columns                     # 返回列名列表
df.count()                     # 统计行数
df.describe().show()           # 数值列的统计摘要(min/max/mean/stddev)

# ── 列操作 ──
from pyspark.sql import functions as F

df.select("name", "age")                                  # 选择列
df.select(F.col("name"), F.col("age") + 1)               # 列运算
df.selectExpr("name", "age * 2 as double_age")           # SQL 表达式
df.withColumn("adult", F.col("age") >= 18)               # 添加新列
df.withColumn("age", F.col("age").cast("integer"))       # 类型转换
df.withColumnRenamed("name", "full_name")                 # 重命名列
df.drop("unwanted_col1", "unwanted_col2")                 # 删除列

# ── 过滤 ──
df.filter(F.col("age") > 25)
df.filter("age > 25 AND salary > 50000")          # SQL 字符串条件
df.where(F.col("name").isNotNull())               # where 是 filter 的别名
df.filter(F.col("status").isin("active", "pending"))
df.filter(F.col("name").like("%John%"))
df.filter(F.col("created_at").between("2024-01-01", "2024-12-31"))

# ── 聚合 ──
df.groupBy("department") \
  .agg(
      F.count("*").alias("total"),
      F.avg("salary").alias("avg_salary"),
      F.max("salary").alias("max_salary"),
      F.min("salary").alias("min_salary"),
      F.sum("salary").alias("total_salary"),
      F.countDistinct("user_id").alias("unique_users"),
      F.collect_list("name").alias("names"),
      F.stddev("salary").alias("salary_stddev")
  )

# ── 排序 ──
df.orderBy("age")
df.orderBy(F.col("salary").desc(), F.col("name").asc())
df.sort(F.col("age").desc_nulls_last())

# ── 去重 ──
df.distinct()                              # 全行去重
df.dropDuplicates(["user_id", "email"])    # 按指定列去重

# ── 缺失值处理 ──
df.dropna()                                 # 删除含任何 null 的行
df.dropna(subset=["name", "age"])           # 删除指定列含 null 的行
df.dropna(how="all")                        # 仅删除所有列都是 null 的行
df.fillna(0)                                # 所有 null 填充为 0
df.fillna({"age": 0, "name": "unknown"})    # 按列填充
df.na.replace([None, ""], "N/A", subset=["name"])

# ── Join 操作 ──
# 支持类型:inner, left, right, full, left_semi, left_anti, cross
result = df1.join(df2, on="user_id", how="inner")
result = df1.join(df2, df1["id"] == df2["user_id"], "left")
result = df1.join(df2, ["id", "name"], "inner")          # 多列 Join

# Broadcast Join(小表广播)
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")

# ── 窗口函数 ──
from pyspark.sql.window import Window

window_spec = Window.partitionBy("department").orderBy(F.col("salary").desc())
df.withColumn("rank", F.rank().over(window_spec))
df.withColumn("dense_rank", F.dense_rank().over(window_spec))
df.withColumn("row_number", F.row_number().over(window_spec))
df.withColumn("lag_salary", F.lag("salary", 1).over(window_spec))
df.withColumn("lead_salary", F.lead("salary", 1).over(window_spec))

# 窗口帧(rows between / range between)
window_rolling = Window.partitionBy("user_id") \
    .orderBy("date") \
    .rowsBetween(-6, 0)  # 过去 7 行滚动窗口
df.withColumn("7day_avg", F.avg("sales").over(window_rolling))

6.4 Spark SQL 使用

python 复制代码
# 注册为临时视图
df.createOrReplaceTempView("users")

# 注册为全局临时视图(跨 SparkSession)
df.createOrReplaceGlobalTempView("global_users")

# 执行 SQL 查询
result = spark.sql("""
    SELECT
        department,
        COUNT(*) as employee_count,
        AVG(salary) as avg_salary,
        PERCENTILE_APPROX(salary, 0.5) as median_salary
    FROM users
    WHERE age BETWEEN 25 AND 55
    GROUP BY department
    HAVING COUNT(*) >= 10
    ORDER BY avg_salary DESC
""")

# 全局视图访问需要加前缀
result = spark.sql("SELECT * FROM global_temp.global_users")

# 在 Hive 表上执行 SQL
spark.sql("CREATE TABLE IF NOT EXISTS output_db.result AS SELECT * FROM users WHERE age > 30")
spark.sql("INSERT INTO output_db.result SELECT * FROM users WHERE age <= 30")

6.5 保存数据

python 复制代码
# ── 保存为 Parquet(推荐) ──
df.write.mode("overwrite").parquet("hdfs://path/to/output.parquet")

# ── 分区写入(Hive 风格分区) ──
df.write \
  .partitionBy("year", "month") \
  .mode("overwrite") \
  .parquet("hdfs://path/to/partitioned_output")

# ── 控制输出文件数 ──
df.repartition(10).write.mode("overwrite").parquet("path/to/output")
df.coalesce(1).write.mode("overwrite").csv("path/single_file.csv")   # 输出单文件

# ── 保存为 ORC ──
df.write.mode("append").orc("hdfs://path/to/output.orc")

# ── 写入 Hive 表 ──
df.write.insertInto("hive_db.my_table", overwrite=True)
df.write.saveAsTable("hive_db.new_table")  # 创建新 Hive 表

# ── 写入 JDBC ──
df.write \
  .format("jdbc") \
  .option("url", "jdbc:postgresql://host:5432/mydb") \
  .option("dbtable", "output_table") \
  .option("user", "user") \
  .option("password", "password") \
  .mode("append") \
  .save()

# ── 写入模式 ──
# "overwrite"  - 覆盖已有数据
# "append"     - 追加到已有数据
# "ignore"     - 目标存在则跳过(不报错)
# "error"      - 目标存在则抛出错误(默认)

7. Catalyst 优化器与 Tungsten 执行引擎

7.1 Catalyst 优化器

Catalyst 是 Spark SQL 的核心查询优化器,基于 Scala 函数式编程实现,对所有 DataFrame/Dataset/SQL 操作自动生效。

优化流程(四阶段):

复制代码
用户代码(DataFrame API / SQL)
         │
         ▼
1. 解析(Parsing)
   └─ 生成 Unresolved Logical Plan(未解析逻辑计划)
         │
         ▼
2. 分析(Analysis)
   └─ 结合 Catalog 验证列名、数据类型
   └─ 生成 Resolved Logical Plan(已解析逻辑计划)
         │
         ▼
3. 逻辑优化(Logical Optimization)[Catalyst 的核心]
   ├─ 谓词下推(Predicate Pushdown):Filter 尽量靠近数据源
   ├─ 列裁剪(Column Pruning):只读取需要的列
   ├─ 常量折叠(Constant Folding):提前计算常量表达式
   ├─ Join 重排序(Join Reordering):调整 Join 顺序减少 Shuffle
   ├─ 子查询消除(Subquery Elimination)
   └─ 生成 Optimized Logical Plan(优化后逻辑计划)
         │
         ▼
4. 物理规划(Physical Planning)
   ├─ 生成多个物理执行计划
   ├─ 基于代价模型(CBO)选择最优执行计划
   └─ 生成 Physical Plan(物理执行计划)
         │
         ▼
5. 代码生成(Code Generation - Tungsten)
   └─ 通过 Whole-Stage CodeGen 生成优化的 JVM 字节码执行

查看执行计划:

python 复制代码
# 查看完整执行计划(推荐)
df.explain(mode="extended")   # 显示 Parsed / Analyzed / Optimized / Physical 四个阶段

# 其他模式
df.explain()                   # 仅物理计划(默认)
df.explain(True)               # 等价于 extended
df.explain(mode="formatted")   # 格式化输出(Spark 3.0+)
df.explain(mode="cost")        # 显示代价信息(需 CBO 统计)
df.explain(mode="codegen")     # 显示生成的 Java 代码

# SQL 查看执行计划
spark.sql("EXPLAIN EXTENDED SELECT * FROM users WHERE age > 30").show(truncate=False)

7.2 Catalyst 关键优化规则

谓词下推(Predicate Pushdown)

将 Filter 操作推到尽可能早的阶段,减少数据量:

python 复制代码
# 写法等价,但 Catalyst 会自动将 filter 推入扫描阶段
df.join(df2, "id").filter(F.col("age") > 30)
# 等价于(Catalyst 会优化为此):
df.filter(F.col("age") > 30).join(df2, "id")

# 对于 Parquet/ORC 文件,谓词甚至会 pushdown 到文件层面
# 只读取满足条件的 row group,大幅减少 IO
列裁剪(Column Pruning)
python 复制代码
# 只选择需要的列,Catalyst 自动避免读取不用的列
df.select("name", "age").filter(F.col("age") > 30)
# Catalyst 会告诉 Parquet 读取器:只读 name 和 age 两列
AQE(自适应查询执行,Spark 3.0+)

AQE 在运行时根据实际数据统计动态调整查询计划:

python 复制代码
# 启用 AQE(Spark 3.2+ 默认开启)
spark.conf.set("spark.sql.adaptive.enabled", "true")

# AQE 的三大核心功能:
# 1. 动态合并小 Shuffle 分区
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

# 2. 动态切换 Join 策略(如将 Sort-Merge Join 切换为 Broadcast Join)
spark.conf.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "30MB")

# 3. 动态处理数据倾斜
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")         # 超过中位数 5 倍认为倾斜
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

7.3 Tungsten 执行引擎

Tungsten 专注于 CPU 和内存效率的极限优化:

技术 说明
堆外内存管理(Off-Heap Memory) 使用 sun.misc.Unsafe 直接管理内存,绕过 GC
缓存友好数据结构 UnsafeRow 格式,紧凑存储,提高 CPU 缓存命中率
全阶段代码生成(Whole-Stage CodeGen) 将多个算子融合为单个 JVM 方法,消除虚函数调用开销
向量化执行(Vectorized Execution) 批量处理列式数据,充分利用 SIMD 指令

8. Shuffle 机制与 Join 策略

8.1 Shuffle 工作原理

Shuffle 是 Spark 中最昂贵的操作,发生在宽依赖(如 groupBy、join、sortBy)时:

复制代码
Shuffle 执行流程:

Stage 1(Map 端):
  Executor A:Task 1 → 将数据按目标分区 Hash 写入 shuffle 文件
  Executor B:Task 2 → 将数据按目标分区 Hash 写入 shuffle 文件

             ↓ 网络传输(最耗时)

Stage 2(Reduce 端):
  Executor C:Task 1 → 从各 Executor 拉取属于自己的分区数据
  Executor D:Task 2 → 从各 Executor 拉取属于自己的分区数据

Shuffle 的代价:
  - 磁盘 I/O(写入 shuffle 文件)
  - 序列化/反序列化
  - 网络传输(最主要的瓶颈)

8.2 Join 策略详解

Spark 支持 5 种 Join 策略,自动或手动选择:

策略 1:Broadcast Hash Join(BHJ)--- 最快
复制代码
原理:将小表广播到每个 Executor,每个 Executor 本地执行 Hash Join
适用:小表 ≤ spark.sql.autoBroadcastJoinThreshold(默认 10MB)
优点:无 Shuffle,速度最快
缺点:小表必须能装入 Executor 内存;不支持某些 join 类型(如 full outer)

      Driver 收集小表
           │
     广播到所有 Executor
     ┌─────┼─────┐
  Exec-A Exec-B Exec-C   ← 每个 Executor 持有完整小表
     │     │     │
     └──本地 Hash Join──┘   ← 大表分区与本地小表直接 Join
python 复制代码
from pyspark.sql.functions import broadcast

# 方法 1:Python API
result = large_df.join(broadcast(small_df), "key")

# 方法 2:SQL hint
result = spark.sql("""
    SELECT /*+ BROADCAST(small_table) */ *
    FROM large_table
    JOIN small_table ON large_table.key = small_table.key
""")

# 调整广播阈值(字节)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB")  # 调大到 50MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")    # 禁用自动广播
策略 2:Sort-Merge Join(SMJ)--- 大表默认选择
复制代码
原理:对两张表按 Join Key shuffle + sort,再 merge
适用:两张大表无法广播的等值 Join
优点:可处理超大表,支持所有 Join 类型,可 spill 到磁盘
缺点:需要全量 Shuffle + Sort,开销最大

步骤:
  1. 两张表各自按 Join Key 进行 Shuffle
  2. 每个分区内按 Join Key 排序
  3. 像归并排序一样合并两个有序分区
python 复制代码
# SQL hint 强制使用 Sort-Merge Join
result = spark.sql("""
    SELECT /*+ MERGE(t1) */ *
    FROM t1 JOIN t2 ON t1.key = t2.key
""")

# 配置
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true")  # 默认 true
策略 3:Shuffle Hash Join(SHJ)
复制代码
原理:Shuffle 后,在每个分区的较小侧建立 Hash Table,与大侧 probe
适用:两表均较大但有一侧明显更小,且 Key 分布均匀
优点:比 Sort-Merge 快(无需排序)
缺点:内存不足可能 OOM(不能 spill 排序阶段)

AQE 可在运行时将 SMJ 动态切换为 SHJ:
  当 shuffle 后分区数据量小于 spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold 时自动切换
python 复制代码
# 强制 Shuffle Hash Join
result = df1.hint("shuffle_hash").join(df2, "key")
Join 策略选择决策树
复制代码
两表的 Join 操作
│
├── 有一侧 ≤ autoBroadcastJoinThreshold (默认 10MB)?
│   └── YES → Broadcast Hash Join(最快,无 Shuffle)
│
├── 都 > 广播阈值,但有一侧明显更小?
│   └── YES(小侧能放入单个分区内存)→ Shuffle Hash Join
│
└── 两侧都很大?
    └── Sort-Merge Join(稳健,支持 spill)
        └── AQE 开启时:若 shuffle 后小侧显著缩小 → 动态切换为 SHJ 或 BHJ

8.3 数据倾斜(Skew)处理

数据倾斜是 Spark 最常见的性能杀手,某些 Key 的数据量远大于其他 Key:

方法 1:AQE 自动处理(推荐,Spark 3.0+)
python 复制代码
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# AQE 会自动将大分区切分并复制小侧数据来处理
方法 2:手动加盐(Salting)
python 复制代码
import random
from pyspark.sql import functions as F

# 加盐系数
NUM_SALT = 20

# 对倾斜的大表加盐(随机前缀)
skewed_df = large_df.withColumn(
    "salted_key",
    F.concat(F.col("join_key"), F.lit("_"), (F.rand() * NUM_SALT).cast("int").cast("string"))
)

# 对小表扩展(每个 Key 复制 NUM_SALT 份)
small_expanded = small_df.withColumn(
    "salt",
    F.explode(F.array([F.lit(i) for i in range(NUM_SALT)]))
).withColumn(
    "salted_key",
    F.concat(F.col("join_key"), F.lit("_"), F.col("salt").cast("string"))
)

# Join
result = skewed_df.join(small_expanded, "salted_key").drop("salted_key", "salt")
方法 3:分离倾斜 Key 单独处理
python 复制代码
# 找出倾斜 Key
skew_keys = ["hotkey1", "hotkey2"]

# 拆分 DataFrame
skewed_df = df.filter(F.col("key").isin(skew_keys))
normal_df = df.filter(~F.col("key").isin(skew_keys))

# 分别 Join(倾斜侧用 broadcast)
skewed_result = skewed_df.join(broadcast(small_df), "key")
normal_result = normal_df.join(small_df, "key")

# 合并结果
final_result = normal_result.union(skewed_result)

9. Structured Streaming 流处理

9.1 流处理核心概念

Structured Streaming 将流数据视为一张持续增长的无界表(Unbounded Table),用与批处理相同的 DataFrame/SQL API 处理:

复制代码
输入数据流(Kafka/文件/Socket)
    │
    ▼ 每个 trigger 时间点追加新行
┌─────────────────────────────┐
│       Input Table           │   ← 概念上的无界输入表
│  (只追加,从不修改旧行)      │
└─────────────────────────────┘
    │
    ▼ 用户定义的查询(与批处理 API 完全一致)
┌─────────────────────────────┐
│       Result Table          │   ← 查询结果表(持续更新)
└─────────────────────────────┘
    │
    ▼ Output Sink(按 Output Mode 写出)
  文件系统 / Kafka / Console / Memory / 自定义 Sink

9.2 输入源(Source)

python 复制代码
# ── Kafka Source(最常用)──
kafka_df = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "host1:9092,host2:9092") \
    .option("subscribe", "topic1,topic2") \         # 订阅多个 topic
    .option("startingOffsets", "latest") \           # earliest / latest / JSON 指定
    .option("maxOffsetsPerTrigger", 100000) \        # 每次触发最多消费数量
    .option("failOnDataLoss", "false") \
    .load()

# Kafka value 是 binary,需要 cast
from pyspark.sql.functions import col, from_json

schema = StructType([...])
parsed_df = kafka_df \
    .select(
        col("key").cast("string"),
        from_json(col("value").cast("string"), schema).alias("data"),
        col("timestamp"),
        col("partition"),
        col("offset")
    ) \
    .select("key", "data.*", "timestamp")

# ── 文件 Source(监控目录)──
file_df = spark \
    .readStream \
    .format("csv") \
    .schema(schema) \
    .option("path", "hdfs://path/to/incoming/") \
    .option("maxFilesPerTrigger", 10) \   # 每次最多处理 10 个文件
    .load()

# ── Socket Source(调试用)──
socket_df = spark \
    .readStream \
    .format("socket") \
    .option("host", "localhost") \
    .option("port", 9999) \
    .load()

9.3 流处理操作

python 复制代码
# 与批处理完全一致的 API
result_df = parsed_df \
    .filter(F.col("amount") > 0) \
    .groupBy(
        F.window("event_time", "10 minutes", "5 minutes"),  # 10分钟窗口,5分钟滑动
        F.col("user_id")
    ) \
    .agg(
        F.sum("amount").alias("total_amount"),
        F.count("*").alias("tx_count")
    )

9.4 水位线(Watermark)处理延迟数据

python 复制代码
# 设置水位线:允许最多 10 分钟的延迟数据
windowed_df = parsed_df \
    .withWatermark("event_time", "10 minutes") \   # event_time 是事件时间列
    .groupBy(
        F.window("event_time", "5 minutes"),
        "category"
    ) \
    .agg(F.sum("amount").alias("total"))

9.5 输出 Sink 与 Output Mode

python 复制代码
# ── 三种 Output Mode ──
# complete: 输出所有结果(适合 aggregation,但内存压力大)
# append:   只输出新增行(无聚合或有水位线的聚合)
# update:   只输出变化的行(聚合时常用)

# ── Console Sink(调试)──
query = result_df.writeStream \
    .outputMode("update") \
    .format("console") \
    .option("truncate", False) \
    .trigger(processingTime="10 seconds") \
    .start()

# ── Kafka Sink ──
query = result_df \
    .selectExpr("CAST(user_id AS STRING) AS key", "to_json(struct(*)) AS value") \
    .writeStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "host:9092") \
    .option("topic", "output_topic") \
    .option("checkpointLocation", "hdfs://path/to/checkpoint/") \
    .outputMode("append") \
    .start()

# ── 文件 Sink(Parquet)──
query = result_df.writeStream \
    .format("parquet") \
    .option("path", "hdfs://path/to/output/") \
    .option("checkpointLocation", "hdfs://path/to/checkpoint/") \
    .partitionBy("date") \
    .outputMode("append") \
    .trigger(processingTime="1 minute") \
    .start()

# ── 自定义 ForeachBatch Sink(最灵活,Spark 2.4+)──
def write_to_db(batch_df, batch_id):
    """每个 micro-batch 会调用此函数"""
    batch_df.persist()
    # 写到数据库
    batch_df.write.format("jdbc") \
        .option("url", "jdbc:...") \
        .option("dbtable", "result_table") \
        .mode("append") \
        .save()
    # 同时发送到 Kafka
    batch_df.selectExpr(...).write.format("kafka")...
    batch_df.unpersist()

query = result_df.writeStream \
    .foreachBatch(write_to_db) \
    .option("checkpointLocation", "hdfs://path/to/checkpoint/") \
    .outputMode("update") \
    .start()

# ── 等待查询结束 ──
query.awaitTermination()    # 阻塞
query.stop()                # 手动停止

9.6 Trigger(触发器)类型

python 复制代码
from pyspark.sql.streaming import Trigger

# 固定时间间隔(默认 0,尽快处理)
.trigger(processingTime="30 seconds")

# 只处理一次 micro-batch,然后停止(Spark 2.2+)
.trigger(once=True)

# 持续处理(Spark 2.3+,低延迟实验性功能)
.trigger(continuous="1 second")

# 按可用数据处理(每次处理所有可用数据)
.trigger(availableNow=True)    # Spark 3.3+

9.7 容错语义

Structured Streaming 提供**端到端精确一次(Exactly-Once)**保证:

  • Checkpointing:Spark 将 offset 和状态存入 checkpoint 目录(HDFS/S3)
  • Idempotent Sink:幂等写入(如 Parquet 文件、有事务支持的数据库)
  • 可重放的 Source:如 Kafka 支持从指定 offset 重新消费

10. Spark MLlib 机器学习

10.1 MLlib 概述

MLlib 是 Spark 内置的分布式机器学习库,基于 DataFrame API(spark.ml 包)。

复制代码
MLlib 主要功能:
├── 特征工程
│   ├── 特征提取(TF-IDF、Word2Vec、HashingTF)
│   ├── 特征转换(StandardScaler、MinMaxScaler、PCA)
│   └── 特征选择(ChiSqSelector、VarianceThresholdSelector)
├── 分类算法
│   ├── 逻辑回归(LogisticRegression)
│   ├── 随机森林(RandomForestClassifier)
│   ├── 梯度提升树(GBTClassifier)
│   ├── 支持向量机(LinearSVC)
│   └── 多层感知机(MultilayerPerceptronClassifier)
├── 回归算法
│   ├── 线性回归(LinearRegression)
│   ├── 随机森林回归(RandomForestRegressor)
│   └── 梯度提升树回归(GBTRegressor)
├── 聚类算法
│   ├── K-Means(KMeans)
│   ├── 高斯混合(GaussianMixture)
│   └── 层次聚类(BisectingKMeans)
├── 协同过滤(ALS 推荐系统)
└── 模型评估与选择(CrossValidator、TrainValidationSplit)

10.2 完整 ML Pipeline 示例

python 复制代码
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
    VectorAssembler, StringIndexer, OneHotEncoder,
    StandardScaler, Imputer
)
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# 加载数据
data = spark.read.parquet("hdfs://path/to/ml_data.parquet")

# 划分训练集和测试集
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)

# ── 特征工程 ──
# 1. 填充缺失值
imputer = Imputer(
    inputCols=["age", "salary", "years_exp"],
    outputCols=["age_imp", "salary_imp", "years_exp_imp"],
    strategy="median"
)

# 2. 类别型变量编码
indexer = StringIndexer(inputCol="department", outputCol="dept_idx")
encoder = OneHotEncoder(inputCol="dept_idx", outputCol="dept_ohe")

# 3. 组合特征向量
assembler = VectorAssembler(
    inputCols=["age_imp", "salary_imp", "years_exp_imp", "dept_ohe"],
    outputCol="features_raw"
)

# 4. 标准化
scaler = StandardScaler(
    inputCol="features_raw",
    outputCol="features",
    withMean=True,
    withStd=True
)

# 5. 目标变量编码
label_indexer = StringIndexer(inputCol="label", outputCol="label_idx")

# ── 模型 ──
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label_idx",
    numTrees=100,
    maxDepth=10,
    seed=42
)

# ── Pipeline ──
pipeline = Pipeline(stages=[
    imputer, indexer, encoder, assembler, scaler, label_indexer, rf
])

# ── 超参数搜索 ──
param_grid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [50, 100, 200]) \
    .addGrid(rf.maxDepth, [5, 10, 15]) \
    .build()

evaluator = BinaryClassificationEvaluator(
    labelCol="label_idx",
    rawPredictionCol="rawPrediction",
    metricName="areaUnderROC"
)

cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5,
    seed=42
)

# ── 训练 ──
cv_model = cv.fit(train_data)

# ── 评估 ──
predictions = cv_model.transform(test_data)
auc = evaluator.evaluate(predictions)
print(f"AUC-ROC: {auc:.4f}")

# ── 保存模型 ──
cv_model.bestModel.write().overwrite().save("hdfs://path/to/model/")

# ── 加载模型 ──
from pyspark.ml import PipelineModel
loaded_model = PipelineModel.load("hdfs://path/to/model/")

10.3 ALS 推荐系统

python 复制代码
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator

# 数据格式:userId, itemId, rating
ratings = spark.read.csv("ratings.csv", header=True, inferSchema=True)
train, test = ratings.randomSplit([0.8, 0.2])

# 训练 ALS 模型
als = ALS(
    maxIter=10,
    regParam=0.01,
    rank=10,             # 隐因子维度
    userCol="userId",
    itemCol="movieId",
    ratingCol="rating",
    coldStartStrategy="drop"   # 处理冷启动问题
)
model = als.fit(train)

# 评估 RMSE
predictions = model.transform(test)
evaluator = RegressionEvaluator(
    metricName="rmse", labelCol="rating", predictionCol="prediction"
)
rmse = evaluator.evaluate(predictions)
print(f"Root-mean-square error = {rmse:.4f}")

# 为每个用户生成 Top 10 推荐
user_recs = model.recommendForAllUsers(10)
# 为每个物品找到 Top 10 潜在用户
item_recs = model.recommendForAllItems(10)

11. Spark GraphX 图计算

11.1 GraphX 基础

GraphX 在 Spark RDD 基础上构建图计算框架,提供 Graph 抽象:

scala 复制代码
// Scala 示例(GraphX 主要以 Scala 使用)
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

// 创建顶点 RDD:(VertexId, 属性)
val vertices: RDD[(VertexId, String)] = sc.parallelize(Array(
    (1L, "Alice"), (2L, "Bob"), (3L, "Charlie"), (4L, "David")
))

// 创建边 RDD:Edge(srcId, dstId, 属性)
val edges: RDD[Edge[String]] = sc.parallelize(Array(
    Edge(1L, 2L, "friend"),
    Edge(2L, 3L, "colleague"),
    Edge(3L, 4L, "friend"),
    Edge(4L, 1L, "family")
))

// 构建图
val graph = Graph(vertices, edges, defaultVertexAttr = "unknown")

// 图的基本属性
println(s"顶点数: ${graph.numVertices}")
println(s"边数: ${graph.numEdges}")

// 查看顶点
graph.vertices.collect().foreach(println)

// 查看边
graph.edges.collect().foreach(println)

// 图算法:PageRank
val pageRankGraph = graph.pageRank(0.001)  // 容差值

// 三角形计数
val triangleCount = graph.triangleCount()

// 连通分量
val cc = graph.connectedComponents()

// 最短路径(Pregel API)
import org.apache.spark.graphx.lib.ShortestPaths
val shortestPaths = ShortestPaths.run(graph, Seq(1L))

12. 内存管理与持久化策略

12.1 Spark 内存模型

Spark 的 Executor JVM 内存划分(Spark 1.6+ 统一内存管理):

复制代码
Executor JVM 总内存
├── Reserved Memory(系统保留):300 MB(固定)
├── User Memory(1 - spark.memory.fraction = 40%)
│   └── 用户数据结构、UDF、Spark 内部对象
└── Spark Memory(spark.memory.fraction = 60%)
    ├── Execution Memory(动态共享)
    │   └── Shuffle、Sort、Join 的中间数据
    └── Storage Memory(动态共享)
        └── RDD/DataFrame Cache、广播变量

注:Execution 和 Storage 内存是动态共享的(spark.memory.storageFraction 控制初始比例)
    当执行内存不足时,可以驱逐 Storage 内存

关键配置:

python 复制代码
# 调整 Spark 内存占比(默认 0.6,即 60% JVM 堆用于 Spark)
spark.conf.set("spark.memory.fraction", "0.7")

# 在 Spark 内存中,Storage 的初始比例(默认 0.5)
spark.conf.set("spark.memory.storageFraction", "0.4")

# Executor 堆内存(在 spark-submit 或 config 中设置)
# --executor-memory 4g
# spark.executor.memory = 4g

# 堆外内存(Tungsten 使用,默认 0.1 * executor memory)
spark.conf.set("spark.memory.offHeap.enabled", "true")
spark.conf.set("spark.memory.offHeap.size", "2g")

12.2 持久化级别

python 复制代码
from pyspark import StorageLevel

# 常用持久化级别
df.cache()                                          # 等价于 MEMORY_AND_DISK
df.persist()                                        # 默认:MEMORY_AND_DISK
df.persist(StorageLevel.MEMORY_ONLY)                # 仅内存,不够则重算
df.persist(StorageLevel.MEMORY_AND_DISK)            # 优先内存,溢出到磁盘
df.persist(StorageLevel.MEMORY_ONLY_SER)            # 序列化存内存(节省空间)
df.persist(StorageLevel.MEMORY_AND_DISK_SER)        # 序列化,内存+磁盘
df.persist(StorageLevel.DISK_ONLY)                  # 仅磁盘
df.persist(StorageLevel.OFF_HEAP)                   # 堆外内存(需开启)

# 释放缓存
df.unpersist()

# 清除所有缓存
spark.catalog.clearCache()

选择持久化级别的原则:

复制代码
内存充足?
├── YES → MEMORY_ONLY(速度最快)
│   如果仍然 OOM → MEMORY_ONLY_SER(序列化节省约 50% 空间)
└── NO  → MEMORY_AND_DISK(内存放不下就 spill 到磁盘)
    需要极致压缩 → MEMORY_AND_DISK_SER
    内存极其紧张 → DISK_ONLY

何时使用 Cache?

python 复制代码
# ✅ 应该缓存:
# 1. DataFrame 被多次复用
base_df = spark.read.parquet("...").cache()
result1 = base_df.filter(...)    # 第一次用
result2 = base_df.groupBy(...)   # 第二次用(直接从 cache 读取)

# 2. 迭代计算(如 ML 训练中的特征 DataFrame)
train_features.cache()
for epoch in range(100):
    model.train(train_features)

# ❌ 不应该缓存:
# 1. 只用一次的 DataFrame
# 2. 非常大无法放入内存的 DataFrame(溢出反而更慢)
# 3. 来源读取很快(如本地 SSD Parquet)

13. 分区策略与数据倾斜处理

13.1 分区基础

python 复制代码
# 查看分区数
df.rdd.getNumPartitions()

# 增加分区(触发 Shuffle)
df_repartitioned = df.repartition(200)

# 增加并按列分区(相同 key 会到同一分区,有助于 Join 优化)
df_repartitioned = df.repartition(200, "user_id")

# 减少分区(不触发 Shuffle,比 repartition 更高效)
df_coalesced = df.coalesce(50)

# 推荐分区数
# 一般原则:分区数 ≈ 2~4 × CPU 核心总数
# shuffle 分区数(default 200)
spark.conf.set("spark.sql.shuffle.partitions", "400")

# AQE 开启时,让 Spark 自动调整 shuffle 分区数
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

13.2 分区大小建议

复制代码
分区大小:理想在 100MB ~ 200MB 之间
  太大 → 内存压力大,Task 运行时间长,容错重试代价高
  太小 → 任务调度开销大,Shuffle 文件过多

检查当前分区情况:
df.rdd.glom().map(len).collect()  # 每个分区的元素数量
df.groupBy(spark_partition_id()).count()  # 每个分区的行数

14. 性能调优全面指南

14.1 优化层次总览

复制代码
性能优化金字塔(重要性从高到低):

① 代码层(影响最大)
   └─ 使用 DataFrame/Dataset 而非 RDD
   └─ 避免不必要的 Shuffle(尽早 filter 和 select)
   └─ 使用内置函数而非 UDF
   └─ 合理使用 Cache

② 数据层
   └─ 使用 Parquet/ORC 列式存储格式
   └─ 合理分区和 Bucket(Bucketing)
   └─ 数据倾斜处理

③ 配置层
   └─ 合理的 Executor 配置
   └─ AQE 配置
   └─ Shuffle 参数调整

④ 资源层(影响最小但最直接)
   └─ 增加 Executor 数量和内存
   └─ 使用更快的存储(SSD)
   └─ 升级网络带宽

14.2 代码优化技巧

1. 尽早过滤和列裁剪
python 复制代码
# ❌ 不好:先 join 再过滤(join 了大量不需要的数据)
result = df1.join(df2, "id").filter(F.col("date") >= "2024-01-01")

# ✅ 好:先过滤再 join
result = df1.filter(F.col("date") >= "2024-01-01") \
            .select("id", "date", "amount") \          # 只选需要的列
            .join(df2.select("id", "name"), "id")      # 同样只选需要的列
2. 避免使用 Python UDF(改用内置函数)
python 复制代码
from pyspark.sql import functions as F

# ❌ 慢:Python UDF(序列化开销,Catalyst 无法优化)
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

@udf(returnType=StringType())
def get_domain(email):
    return email.split("@")[1] if "@" in email else None

df.withColumn("domain", get_domain(F.col("email")))  # 慢!

# ✅ 快:使用内置函数
df.withColumn("domain", F.split(F.col("email"), "@").getItem(1))

# 如果必须用 UDF,考虑 Pandas UDF(向量化执行,快 10-100 倍)
from pyspark.sql.functions import pandas_udf
import pandas as pd

@pandas_udf(StringType())
def get_domain_pandas(emails: pd.Series) -> pd.Series:
    return emails.str.split("@").str.get(1)

df.withColumn("domain", get_domain_pandas(F.col("email")))  # 比 UDF 快得多
3. reduceByKey vs groupByKey
python 复制代码
# ❌ 慢:groupByKey 在 shuffle 前不做局部聚合,传输大量数据
pairs.groupByKey().mapValues(sum)

# ✅ 快:reduceByKey 在 shuffle 前先做局部聚合(combiner 优化)
pairs.reduceByKey(lambda a, b: a + b)

# 对于复杂聚合,使用 aggregateByKey
pairs.aggregateByKey(
    (0, 0),                           # zeroValue: (sum, count)
    lambda acc, v: (acc[0]+v, acc[1]+1),  # seqOp: 分区内聚合
    lambda acc1, acc2: (acc1[0]+acc2[0], acc1[1]+acc2[1])  # combOp: 分区间合并
).mapValues(lambda x: x[0]/x[1])     # 计算平均值
4. 合理使用 repartition 和 coalesce
python 复制代码
# 在大型操作之前适当增加分区
df.repartition(400, "join_key").join(other_df, "join_key")

# 写入文件前合并小分区(避免生成大量小文件)
result.coalesce(10).write.parquet("path/to/output/")

# 注意:repartition 触发 Shuffle,coalesce 不触发
# 如果需要减少分区数,优先用 coalesce

14.3 存储格式优化

python 复制代码
# Parquet 是最推荐的存储格式
# ✅ 优点:列式存储、压缩率高、支持 Schema、谓词下推到文件层
df.write \
    .option("compression", "snappy") \   # gzip(高压缩率)/ snappy(快速)/ zstd(平衡)
    .partitionBy("year", "month") \       # 按时间分区,方便过滤
    .parquet("hdfs://path/to/output/")

# 读取时推断 Schema 代价大,生产环境显式指定 Schema
df = spark.read.schema(predefined_schema).parquet("path/to/parquet/")

14.4 Executor 资源配置

bash 复制代码
# spark-submit 资源配置示例
spark-submit \
  --master yarn \
  --deploy-mode cluster \
  --driver-memory 4g \
  --executor-memory 8g \
  --executor-cores 4 \
  --num-executors 20 \
  --conf "spark.sql.shuffle.partitions=400" \
  --conf "spark.sql.adaptive.enabled=true" \
  --conf "spark.serializer=org.apache.spark.serializer.KryoSerializer" \
  --conf "spark.default.parallelism=400" \
  --conf "spark.dynamicAllocation.enabled=true" \
  --conf "spark.dynamicAllocation.minExecutors=5" \
  --conf "spark.dynamicAllocation.maxExecutors=50" \
  my_app.py

Executor 配置黄金公式(YARN 环境):

复制代码
假设:节点 RAM = 128 GB,节点 CPU = 32 核

预留给 OS 和 YARN:RAM 留 10%,CPU 留 1 核
可用:RAM ≈ 115 GB,CPU ≈ 31 核

每个 Executor 建议配置 5 个 cores(减少 GC 压力,提高 HDFS 吞吐量)
每个节点 Executor 数 = 31 / 5 ≈ 6 个
每个 Executor 内存 = 115 / 6 ≈ 19 GB

减去 Executor 开销(overhead = max(384MB, 0.1 * executor-memory))
executor-memory = 19 - 2 ≈ 17 GB

最终配置:
  --executor-cores 5
  --executor-memory 17g
  --num-executors (节点数 × 6 - 1)  # -1 留给 driver

14.5 常见性能问题速查

问题现象 可能原因 解决方案
Stage 中个别 Task 特别慢 数据倾斜 AQE skewJoin、加盐、分离倾斜 Key
OOM(OutOfMemoryError) 内存不足、分区太少 增加内存、增加分区数、使用 coalesce
Shuffle 时间占比高 宽依赖太多 减少 groupBy/join、使用 broadcast join
读取速度慢 文件格式不佳、分区扫描 改用 Parquet/ORC、合理 partition pruning
GC 时间过长 JVM 对象过多 使用 Kryo 序列化、开启堆外内存
任务调度延迟高 分区太多(小分区) 使用 AQE 自动合并、调大分区大小
CPU 利用率低 IO 密集 / 数据加载慢 压缩存储、预取、增加并行度

15. 集群部署与资源管理

15.1 部署模式对比

模式 特点 适用场景
Local 单机模拟,local[N] 指定线程数 开发测试
Standalone Spark 自带集群管理器 纯 Spark 集群
YARN 与 Hadoop 生态集成 Hadoop 环境
Kubernetes 容器化部署,弹性伸缩 云原生/容器环境
Mesos 通用资源管理(逐渐被淘汰) -

15.2 spark-submit 提交应用

bash 复制代码
# 提交 PySpark 应用到 YARN
spark-submit \
  --master yarn \
  --deploy-mode cluster \           # client(driver 在本机)/ cluster(driver 在集群)
  --name "My Spark Job" \
  --driver-memory 4g \
  --executor-memory 8g \
  --executor-cores 4 \
  --num-executors 20 \
  --files /path/to/config.json \    # 附带文件
  --py-files /path/to/utils.zip \   # Python 模块
  --jars /path/to/extra.jar \       # 额外 JAR 包
  --conf "spark.sql.shuffle.partitions=400" \
  --conf "spark.sql.adaptive.enabled=true" \
  /path/to/your_script.py \
  arg1 arg2                         # 脚本参数

15.3 动态资源分配

bash 复制代码
# 开启动态分配(自动根据负载增减 Executor)
--conf "spark.dynamicAllocation.enabled=true"
--conf "spark.dynamicAllocation.minExecutors=2"
--conf "spark.dynamicAllocation.maxExecutors=100"
--conf "spark.dynamicAllocation.initialExecutors=10"
--conf "spark.dynamicAllocation.executorIdleTimeout=60s"
--conf "spark.shuffle.service.enabled=true"   # 动态分配需要外部 Shuffle Service

15.4 Kubernetes 部署

bash 复制代码
# 提交到 K8s 集群
spark-submit \
  --master k8s://https://k8s-api-server:6443 \
  --deploy-mode cluster \
  --name spark-example \
  --conf "spark.executor.instances=5" \
  --conf "spark.kubernetes.container.image=apache/spark:3.5.0" \
  --conf "spark.kubernetes.namespace=spark" \
  --conf "spark.kubernetes.authenticate.driver.serviceAccountName=spark" \
  local:///path/to/your_script.py

16. 监控与调试

16.1 Spark UI 使用指南

Spark UI 默认运行在 Driver 节点的 4040 端口(历史 Server 在 18080):

复制代码
Spark UI 核心页面:

Jobs 页面
└─ 查看所有 Job 的执行状态、时间、进度条
   点击 Job → 查看组成的 Stages

Stages 页面
└─ 最重要的调优页面
   查看每个 Stage 的:
   ├── 输入/输出/Shuffle 数据量
   ├── Task 分布(看是否有长尾 Task = 数据倾斜)
   ├── GC 时间(过高说明内存压力大)
   └── Spill 量(Disk 和 Memory Spill 说明内存不足)

Storage 页面
└─ 查看被 cache 的 RDD/DataFrame 及其内存占用

Environment 页面
└─ 查看 Spark 配置参数

Executors 页面
└─ 查看各 Executor 的内存使用、GC 时间、Task 数量

SQL 页面(DataFrame/SQL)
└─ 可视化执行计划 DAG
   查看每个节点处理的行数和时间
   识别性能瓶颈节点

16.2 常用调试命令

python 复制代码
# ── 查看执行计划 ──
df.explain(mode="formatted")

# ── 查看分区信息 ──
print(f"分区数: {df.rdd.getNumPartitions()}")

# 查看每个分区的数据量
df.groupBy(F.spark_partition_id().alias("partition_id")).count().orderBy("partition_id").show()

# ── 查看数据分布 ──
df.groupBy("key_column").count() \
  .orderBy(F.col("count").desc()) \
  .show(20)

# ── 统计基本信息 ──
df.summary().show()  # 比 describe 更详细

# ── 查看 Cache 状态 ──
spark.catalog.listTables()  # 列出所有注册的表

# ── 查看 Job 信息(程序内)──
sc.statusTracker().getJobIDs()
sc.statusTracker().getStageInfo(stageId)

16.3 日志配置

properties 复制代码
# log4j2.properties(Spark 3.3+ 使用 log4j2)
rootLogger.level = WARN
appender.console.type = Console
appender.console.name = console
appender.console.layout.type = PatternLayout
appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n%ex

# 调整特定包的日志级别
logger.spark.name = org.apache.spark
logger.spark.level = WARN
logger.hadoop.name = org.apache.hadoop
logger.hadoop.level = WARN

# 自己的应用调试
logger.myapp.name = com.mycompany.myapp
logger.myapp.level = DEBUG
python 复制代码
# 在代码中动态设置日志级别
spark.sparkContext.setLogLevel("WARN")  # ERROR/WARN/INFO/DEBUG

17. Spark 配置参数速查表

17.1 核心配置

参数 默认值 说明
spark.executor.memory 1g Executor JVM 堆内存
spark.executor.cores 1(YARN)/ 所有(Standalone) 每个 Executor 的 CPU 核心数
spark.driver.memory 1g Driver 内存
spark.default.parallelism 集群 CPU 核心总数 × 2 RDD 默认并行度
spark.sql.shuffle.partitions 200 Shuffle 后分区数
spark.memory.fraction 0.6 JVM 堆分配给 Spark 的比例
spark.memory.storageFraction 0.5 Spark 内存中 Storage 的初始比例
spark.serializer JavaSerializer 推荐改为 KryoSerializer

17.2 AQE 配置

参数 默认值 说明
spark.sql.adaptive.enabled true(3.2+) 启用 AQE
spark.sql.adaptive.coalescePartitions.enabled true 动态合并小分区
spark.sql.adaptive.advisoryPartitionSizeInBytes 64MB 目标分区大小
spark.sql.adaptive.skewJoin.enabled true 自动处理倾斜 Join
spark.sql.adaptive.skewJoin.skewedPartitionFactor 5 倾斜判断因子
spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes 256MB 倾斜分区大小阈值

17.3 Join 与广播配置

参数 默认值 说明
spark.sql.autoBroadcastJoinThreshold 10MB 广播 Join 的表大小阈值
spark.sql.join.preferSortMergeJoin true 优先 Sort-Merge Join
spark.broadcast.compress true 广播时压缩数据

17.4 Shuffle 配置

参数 默认值 说明
spark.shuffle.compress true 压缩 Shuffle 数据
spark.shuffle.spill.compress true 压缩 Spill 到磁盘的数据
spark.reducer.maxSizeInFlight 48MB Reduce 端每次拉取的最大数据量
spark.shuffle.file.buffer 32KB Shuffle 写文件时的缓冲区大小

17.5 Streaming 配置

参数 默认值 说明
spark.streaming.kafka.maxRatePerPartition 不限 Kafka 每分区每秒最大消费速率
spark.sql.streaming.checkpointLocation - 默认 checkpoint 路径
spark.sql.streaming.minBatchesToRetain 100 保留的最小 batch 元数据数量

18. 最佳实践与常见坑

18.1 黄金法则

复制代码
✅ 优先使用 DataFrame/Dataset API,而非 RDD
✅ 优先使用内置函数(spark.sql.functions),而非 Python/Scala UDF
✅ 尽早执行 filter 和 select(减少数据量)
✅ 使用 Parquet/ORC 格式,并开启分区
✅ 对小表使用 broadcast join
✅ 开启 AQE(Spark 3.0+,自动优化大量场景)
✅ 使用 explain() 检查执行计划
✅ 合理设置 shuffle 分区数(非默认 200)
✅ 写入前使用 coalesce 合并小分区
✅ 敏感配置不要硬编码,用外部配置文件
✅ 生产环境明确指定 Schema,不要用 inferSchema

18.2 常见陷阱

复制代码
❌ 在 Driver 上执行 collect() 获取大数据集 → OOM
   解决:改用 write() 或 take(n) 获取小样本

❌ 在 foreach/map 中修改 Driver 变量 → 变量不会同步回 Driver
   解决:使用 Accumulator

❌ 对大 DataFrame 频繁调用 count() → 每次都触发全量计算
   解决:cache() 后再多次 count(),或只统计一次

❌ 在循环中创建大量 RDD/DataFrame 而不释放 → 内存泄漏
   解决:及时 unpersist(),或重构避免循环创建

❌ groupByKey 后做 sum → 全量 shuffle 再计算
   解决:改用 reduceByKey(有 map-side combine)

❌ SQL 写了 SELECT * → 读取所有列,破坏列裁剪优化
   解决:显式列出需要的列

❌ 使用 Python lambda 函数并序列化整个模块 → 序列化失败
   解决:确保 lambda 捕获的变量可序列化,避免捕获 SparkContext

❌ Shuffle 分区默认 200,不随数据量调整 → 分区过大或过小
   解决:根据数据量调整,或开启 AQE 自动调整

❌ 重复读取相同数据源 → 多次 IO
   解决:读取一次,cache,然后多次使用

❌ 在 Streaming 中使用 collect() → 阻塞流处理
   解决:改用 foreachBatch

18.3 代码质量规范

python 复制代码
# ✅ 推荐的代码结构
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, LongType

def create_spark_session(app_name: str) -> SparkSession:
    """创建 SparkSession,生产环境使用 getOrCreate 保证幂等"""
    return SparkSession.builder \
        .appName(app_name) \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.shuffle.partitions", "400") \
        .getOrCreate()

def read_data(spark: SparkSession, path: str, schema: StructType):
    """显式指定 Schema 读取数据"""
    return spark.read.schema(schema).parquet(path)

def process_data(df):
    """业务处理逻辑,链式调用保持清晰"""
    return (
        df
        .filter(F.col("date") >= "2024-01-01")                  # 尽早过滤
        .select("user_id", "amount", "category", "date")         # 尽早列裁剪
        .withColumn("year", F.year(F.col("date")))
        .groupBy("category", "year")
        .agg(
            F.sum("amount").alias("total_amount"),
            F.count("user_id").alias("user_count")
        )
        .orderBy(F.col("total_amount").desc())
    )

def write_data(df, output_path: str):
    """分区写入,控制输出文件数量"""
    df.repartition(20) \
      .write \
      .mode("overwrite") \
      .partitionBy("year") \
      .parquet(output_path)

def main():
    spark = create_spark_session("DataProcessingJob")

    # 定义 Schema
    schema = StructType([
        StructField("user_id", LongType(), False),
        StructField("amount", DoubleType(), True),
        StructField("category", StringType(), True),
        StructField("date", StringType(), True),
    ])

    # 读取
    raw_df = read_data(spark, "hdfs://input/path/", schema)

    # 处理(缓存复用的 DataFrame)
    processed_df = process_data(raw_df)
    processed_df.cache()

    # 输出
    write_data(processed_df, "hdfs://output/path/")

    # 验证
    print(f"输出行数: {processed_df.count()}")

    spark.stop()

if __name__ == "__main__":
    main()

19. 参考资源

官方文档

资源 链接
Apache Spark 官方文档 https://spark.apache.org/docs/latest/
Spark SQL 编程指南 https://spark.apache.org/docs/latest/sql-programming-guide.html
RDD 编程指南 https://spark.apache.org/docs/latest/rdd-programming-guide.html
Structured Streaming 指南 https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html
MLlib 指南 https://spark.apache.org/docs/latest/ml-guide.html
GraphX 指南 https://spark.apache.org/docs/latest/graphx-programming-guide.html
性能调优官方指南 https://spark.apache.org/docs/latest/sql-performance-tuning.html
配置参数说明 https://spark.apache.org/docs/latest/configuration.html
Databricks 官方博客 https://www.databricks.com/blog

学习路径建议

复制代码
初学者(0~3 个月):
  1. 理解 Spark 架构(Driver、Executor、DAG)
  2. 学习 DataFrame API(本文第 6 章)
  3. 理解 RDD 基础和 Transformation/Action
  4. 实践:本地 WordCount → 数据清洗 ETL
  推荐:Spark 官方 Quick Start + 本文 1~6 章

进阶(3~9 个月):
  1. 深入 Catalyst 和 Tungsten(第 7 章)
  2. 掌握 Join 策略和数据倾斜处理(第 8 章)
  3. 学习 Structured Streaming(第 9 章)
  4. 性能调优实践(第 14 章)
  5. 实践:构建完整 ETL Pipeline + 实时流处理
  推荐:Spark: The Definitive Guide(O'Reilly)

专家(9 个月+):
  1. MLlib 分布式机器学习
  2. Spark 源码阅读(Catalyst、AQE)
  3. 自定义 Data Source / Catalyst Rule
  4. 生产集群运维与调优
  推荐:Learning Spark(O'Reilly)第二版

国内生态与工具

工具/平台 说明
阿里云 EMR 托管 Spark 集群,支持 HDFS/OSS
腾讯云 EMR Spark 与 Hadoop 集成服务
华为 MRS MapReduce Service,支持 Spark
Databricks Spark 商业化平台,性能最优
DolphinScheduler 国产开源分布式工作流调度器
Apache Kyuubi 统一 SQL 网关,支持 Spark Thrift Server
Delta Lake 数据湖存储层,支持 ACID 事务
Apache Iceberg 开放表格式,支持 Schema 演化
相关推荐
T06205142 小时前
【数据集】地市城市等级城市类型划分城市经纬度数据
大数据
大大大大晴天2 小时前
Flink技术实践-90%都会踩的状态坑
大数据·flink
Abcdzzr2 小时前
2026/4/6 Windows安装Kafka
分布式·kafka
lifallen2 小时前
Flink Agent 与 Checkpoint:主循环闭环与 Mailbox 事件驱动模型
java·大数据·人工智能·python·语言模型·flink
zxfBdd2 小时前
Spark Map算子异常处理方法
大数据·分布式·spark
白眼黑刺猬3 小时前
如何构建 Flink SQL 任务的血缘分析
大数据·面试·职场和发展·flink
lifallen3 小时前
Flink Agent:ActionTask 与可续跑状态机 (Coroutine/Continuation)
java·大数据·人工智能·语言模型·flink
一个有温度的技术博主3 小时前
告别单点瓶颈:Redis主从架构与读写分离实战
redis·分布式·缓存·架构
白眼黑刺猬3 小时前
字节二面:订单状态回撤: 支付回调延迟导致的“先退单后下单”乱序,Flink如何利用Watermark和状态处理?
大数据·面试·职场和发展·flink