Apache Spark 完整知识总结与使用教程
目录
- [Spark 概述与生态系统](#Spark 概述与生态系统)
- [Spark 架构深解](#Spark 架构深解)
- 核心数据抽象:RDD、DataFrame、Dataset
- [SparkSession 与 SparkContext](#SparkSession 与 SparkContext)
- [RDD 编程详解](#RDD 编程详解)
- [Spark SQL 与 DataFrame 编程](#Spark SQL 与 DataFrame 编程)
- [Catalyst 优化器与 Tungsten 执行引擎](#Catalyst 优化器与 Tungsten 执行引擎)
- [Shuffle 机制与 Join 策略](#Shuffle 机制与 Join 策略)
- [Structured Streaming 流处理](#Structured Streaming 流处理)
- [Spark MLlib 机器学习](#Spark MLlib 机器学习)
- [Spark GraphX 图计算](#Spark GraphX 图计算)
- 内存管理与持久化策略
- 分区策略与数据倾斜处理
- 性能调优全面指南
- 集群部署与资源管理
- 监控与调试
- [Spark 配置参数速查表](#Spark 配置参数速查表)
- 最佳实践与常见坑
- 参考资源
1. Spark 概述与生态系统
1.1 什么是 Apache Spark?
Apache Spark 是一个统一的大规模数据处理分析引擎,由加州大学伯克利分校 AMPLab 在 2009 年开发,并于 2010 年开源,2014 年成为 Apache 顶级项目。
核心特性:
| 特性 | 说明 |
|---|---|
| 速度 | 基于内存计算,比 MapReduce 快 10~100 倍 |
| 易用性 | 提供 Java、Scala、Python、R 高层 API |
| 通用性 | 批处理、流处理、机器学习、图计算统一平台 |
| 兼容性 | 支持 HDFS、S3、Hive、HBase 等多种数据源 |
| 容错性 | 基于 RDD 血统(Lineage)自动重建丢失数据 |
1.2 Spark vs Hadoop MapReduce
MapReduce Apache Spark
计算模式: 磁盘迭代 内存迭代
迭代速度: 慢(每次写磁盘) 快(常驻内存)
编程模型: Map + Reduce 丰富的 API(map/filter/join/groupBy...)
实时流处理: 不支持 Structured Streaming
机器学习: 需要额外工具 内置 MLlib
延迟: 分钟级 秒级(甚至毫秒级)
适合场景: 超大规模批处理 批处理 + 流处理 + ML
1.3 Spark 生态系统全景
┌─────────────────────────────────────────────────────────────────────┐
│ Apache Spark 生态 │
├──────────────┬────────────────┬───────────────┬───────────────────-─┤
│ Spark SQL │ Spark Streaming│ MLlib │ GraphX │
│ 结构化查询 │ 流式处理 │ 机器学习 │ 图计算 │
├──────────────┴────────────────┴───────────────┴────────────────────-┤
│ Spark Core(核心引擎) │
│ RDD、DAG、任务调度、内存管理、容错 │
├──────────────────────────────────────────────────────────────────────┤
│ 集群资源管理层 │
│ Standalone | YARN | Mesos | Kubernetes │
├──────────────────────────────────────────────────────────────────────┤
│ 存储层 │
│ HDFS | S3 | HBase | Kafka | Cassandra | JDBC │
└──────────────────────────────────────────────────────────────────────┘
六大核心组件:
| 组件 | 功能 |
|---|---|
| Spark Core | 基础引擎,提供 RDD、任务调度、内存管理、故障恢复 |
| Spark SQL | 结构化数据处理,支持 SQL/DataFrame/Dataset API |
| Structured Streaming | 基于 Spark SQL 的流处理引擎(推荐使用) |
| MLlib | 分布式机器学习库,覆盖常用算法 |
| GraphX | 图计算框架,支持图并行计算 |
| SparkR | R 语言接口,提供 DataFrame 操作 |
2. Spark 架构深解
2.1 整体架构
Spark 采用主从(Master-Worker)架构,由以下核心组件构成:
用户程序(Driver Program)
│
│ 创建 SparkContext/SparkSession
▼
┌─────────────────────┐
│ Driver(驱动节点) │ ← 负责将用户代码转化为任务并协调执行
│ · SparkContext │
│ · DAGScheduler │
│ · TaskScheduler │
└──────────┬──────────┘
│ 请求资源 & 分配任务
▼
┌─────────────────────┐
│ Cluster Manager │ ← 资源管理器(Standalone/YARN/K8s/Mesos)
│ (资源分配与调度) │
└──────────┬──────────┘
│ 启动 Executors
┌─────┴─────┐
▼ ▼
┌─────────┐ ┌─────────┐
│Executor │ │Executor │ ← 工作节点,真正执行计算任务
│ · Task │ │ · Task │
│ · Cache │ │ · Cache │
└─────────┘ └─────────┘
2.2 Driver(驱动节点)
Driver 是 Spark 应用程序的"大脑",运行用户的 main 函数,主要职责:
- 创建
SparkContext,初始化 Spark 运行环境 - 将用户的 RDD/DataFrame 操作转化为 DAG(有向无环图)
- 通过 DAGScheduler 将 DAG 切分为 Stage
- 通过 TaskScheduler 将 Task 分配给 Executor 执行
- 跟踪各 Task 的执行状态,处理失败重试
2.3 Executor(执行节点)
Executor 是运行在工作节点(Worker Node)上的 JVM 进程:
- 每个 Spark 应用有一组独立的 Executor
- 负责执行 Driver 分配的 Task
- 在内存中缓存 RDD 的分区数据(用于
cache()/persist()) - 执行结束后,Executor 进程退出并释放资源
2.4 DAG、Stage、Task 的关系
用户代码(Transformations + Action)
│
▼
DAG(有向无环图)
代表整个作业的计算逻辑
│
│ DAGScheduler 按 Shuffle 边界切分
▼
Stage(阶段) ← 以 Shuffle 为边界划分
├── Stage 1(窄依赖操作的 Pipeline)
│ ├── Task 1-1(处理 Partition 1)
│ ├── Task 1-2(处理 Partition 2)
│ └── Task 1-N(处理 Partition N)
└── Stage 2(Shuffle 之后)
├── Task 2-1
└── Task 2-N
│
▼
Action 触发执行,结果返回 Driver 或写入存储
关键概念:
| 概念 | 说明 |
|---|---|
| Job | 由一个 Action 触发的完整计算,包含一或多个 Stage |
| Stage | 由窄依赖操作组成的任务集合,Stage 之间以 Shuffle 分隔 |
| Task | Stage 中处理一个数据分区(Partition)的最小执行单元 |
| Shuffle | 数据在 Executor 之间重新分配,是最昂贵的操作 |
2.5 宽依赖 vs 窄依赖
窄依赖(Narrow Dependency):
父 RDD 的每个分区最多被子 RDD 的一个分区使用
示例:map、filter、union、mapPartitions
Partition 1 ──→ Partition 1'
Partition 2 ──→ Partition 2' ← 一对一,可在单 Stage 内 Pipeline 执行
Partition 3 ──→ Partition 3'
宽依赖(Wide Dependency / Shuffle Dependency):
父 RDD 的每个分区可能被子 RDD 的多个分区使用
示例:groupByKey、reduceByKey、join、sortBy
Partition 1 ─┬──→ Partition 1'
└──→ Partition 2' ← 需要 Shuffle,触发新 Stage
Partition 2 ─┬──→ Partition 1'
└──→ Partition 3'
3. 核心数据抽象:RDD、DataFrame、Dataset
3.1 三种抽象对比
| 维度 | RDD | DataFrame | Dataset |
|---|---|---|---|
| 版本 | Spark 1.x | Spark 1.3+ | Spark 1.6+ |
| 数据类型 | 任意 JVM 对象 | Row(无类型) | 强类型对象 |
| 类型安全 | 编译时安全 | 运行时检查 | 编译时安全 |
| Schema | 无 | 有(列名+类型) | 有 |
| Catalyst 优化 | ❌ 无 | ✅ 有 | ✅ 有 |
| Tungsten | ❌ 无 | ✅ 有 | ✅ 有 |
| Python/R 支持 | ✅ | ✅ | ❌(仅 Scala/Java) |
| 适用场景 | 非结构化数据、精细控制 | 结构化/SQL 查询 | 类型安全的结构化处理 |
3.2 如何选择?
数据是否有结构/Schema?
├── 否(如媒体流、文本流、自定义对象)→ 使用 RDD
└── 是(有列名和类型)
├── 使用 Scala/Java 且需要编译时类型安全 → 使用 Dataset
└── 使用 Python/R,或侧重 SQL/聚合 → 使用 DataFrame
黄金法则 :现代 Spark 开发优先使用 DataFrame/Dataset,它们通过 Catalyst + Tungsten 获得自动优化;只有在需要底层控制或处理非结构化数据时才用 RDD。
4. SparkSession 与 SparkContext
4.1 SparkSession(Spark 2.0+ 统一入口)
python
from pyspark.sql import SparkSession
# 创建 SparkSession(推荐方式)
spark = SparkSession.builder \
.appName("MySparkApp") \
.master("local[*]") \ # local[*] 使用所有本地核心
.config("spark.executor.memory", "4g") \
.config("spark.executor.cores", "2") \
.config("spark.sql.shuffle.partitions", "200") \
.enableHiveSupport() \ # 可选:启用 Hive 元数据支持
.getOrCreate() # 已存在则复用
# 获取 SparkContext(低层 API)
sc = spark.sparkContext
# 关闭 Session(程序结束时调用)
spark.stop()
scala
// Scala 版本
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder()
.appName("MySparkApp")
.master("local[*]")
.config("spark.sql.shuffle.partitions", "200")
.getOrCreate()
import spark.implicits._ // 引入隐式转换,支持 DataFrame 操作符
4.2 SparkContext 的核心功能
python
sc = spark.sparkContext
# 读取文本文件为 RDD
rdd = sc.textFile("hdfs://path/to/file.txt")
# 并行化本地集合
data = sc.parallelize([1, 2, 3, 4, 5], numSlices=4)
# 设置日志级别
sc.setLogLevel("WARN") # ERROR / WARN / INFO / DEBUG
# 广播变量
lookup_table = {"a": 1, "b": 2}
broadcast_var = sc.broadcast(lookup_table)
# 累加器
counter = sc.accumulator(0)
5. RDD 编程详解
5.1 创建 RDD
python
# 方法 1:并行化本地集合
rdd1 = sc.parallelize([1, 2, 3, 4, 5])
rdd1 = sc.parallelize(range(100), numSlices=10) # 指定 10 个分区
# 方法 2:从文件系统读取
rdd2 = sc.textFile("file:///local/path/file.txt") # 本地文件
rdd2 = sc.textFile("hdfs://namenode:8020/path/file.txt") # HDFS
rdd2 = sc.textFile("s3a://bucket/prefix/*.csv") # S3
# 方法 3:从其他 RDD 转换
rdd3 = rdd1.map(lambda x: x * 2)
# 方法 4:DataFrame 转 RDD
df = spark.read.parquet("/path/to/parquet")
rdd4 = df.rdd
5.2 Transformations(转换操作)--- 懒执行
| 操作 | 说明 | 示例 |
|---|---|---|
map(f) |
对每个元素应用函数 | rdd.map(lambda x: x*2) |
flatMap(f) |
展平结果(返回 0~N 个元素) | rdd.flatMap(lambda x: x.split()) |
filter(f) |
过滤满足条件的元素 | rdd.filter(lambda x: x > 0) |
mapPartitions(f) |
对每个分区应用函数(高效) | rdd.mapPartitions(process_partition) |
mapPartitionsWithIndex(f) |
带分区索引的 mapPartitions | - |
union(other) |
合并两个 RDD | rdd1.union(rdd2) |
intersection(other) |
取交集 | rdd1.intersection(rdd2) |
distinct() |
去重 | rdd.distinct() |
sample(withReplacement, fraction) |
随机采样 | rdd.sample(False, 0.1) |
groupByKey() |
按 Key 分组(慢,慎用) | pairs.groupByKey() |
reduceByKey(f) |
按 Key 聚合(比 groupByKey 快) | pairs.reduceByKey(lambda a,b: a+b) |
aggregateByKey(zeroValue, seqOp, combOp) |
分区内和分区间分别聚合 | - |
sortByKey(ascending=True) |
按 Key 排序 | pairs.sortByKey() |
join(other) |
内连接 | rdd1.join(rdd2) |
leftOuterJoin(other) |
左外连接 | rdd1.leftOuterJoin(rdd2) |
cogroup(other) |
按 Key 合并多个 RDD | rdd1.cogroup(rdd2) |
repartition(n) |
重新分区(完整 Shuffle) | rdd.repartition(100) |
coalesce(n) |
减少分区(无 Shuffle,更高效) | rdd.coalesce(10) |
persist(level) |
持久化到内存/磁盘 | rdd.persist(StorageLevel.MEMORY_AND_DISK) |
cache() |
等价于 persist(MEMORY_ONLY) |
rdd.cache() |
5.3 Actions(动作操作)--- 触发执行
| 操作 | 说明 | 示例 |
|---|---|---|
collect() |
收集所有元素到 Driver(小数据量) | rdd.collect() |
count() |
统计元素个数 | rdd.count() |
first() |
返回第一个元素 | rdd.first() |
take(n) |
返回前 n 个元素 | rdd.take(10) |
top(n) |
返回最大的 n 个元素 | rdd.top(5) |
reduce(f) |
对元素做聚合操作 | rdd.reduce(lambda a,b: a+b) |
fold(zeroValue, f) |
带初始值的聚合 | rdd.fold(0, lambda a,b: a+b) |
aggregate(zeroValue, seqOp, combOp) |
分阶段聚合 | - |
foreach(f) |
对每个元素执行副作用 | rdd.foreach(print) |
saveAsTextFile(path) |
保存为文本文件 | rdd.saveAsTextFile("hdfs://...") |
saveAsPickleFile(path) |
保存为 Pickle 格式 | - |
countByKey() |
统计每个 Key 的数量 | pairs.countByKey() |
collectAsMap() |
收集为字典 | pairs.collectAsMap() |
takeSample(withRep, n) |
随机采样 n 个元素 | rdd.takeSample(False, 100) |
5.4 完整 RDD 示例:WordCount
python
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("WordCount").master("local[*]").getOrCreate()
sc = spark.sparkContext
# 读取文本
text_rdd = sc.textFile("hdfs://path/to/text.txt")
# 处理流程
word_counts = (
text_rdd
.flatMap(lambda line: line.lower().strip().split()) # 分词
.filter(lambda word: len(word) > 0) # 过滤空词
.map(lambda word: (word, 1)) # 映射为 (word, 1)
.reduceByKey(lambda a, b: a + b) # 按词累加
.sortBy(lambda kv: kv[1], ascending=False) # 按频次降序
)
# 输出 Top 10
for word, count in word_counts.take(10):
print(f"{word}: {count}")
# 保存到 HDFS
word_counts.saveAsTextFile("hdfs://path/to/output/wordcount")
5.5 共享变量
广播变量(Broadcast Variables)
广播变量将大型只读数据高效分发到各 Executor,避免在每个 Task 中都传输一份:
python
# 比如:大型词典、模型参数等
country_code = {"CN": "China", "US": "United States", "JP": "Japan"}
bc_country = sc.broadcast(country_code)
def lookup_country(code):
return bc_country.value.get(code, "Unknown") # 通过 .value 访问
result = rdd.map(lambda x: lookup_country(x))
# 手动释放广播变量(节省内存)
bc_country.unpersist()
bc_country.destroy()
累加器(Accumulators)
累加器是一种只能从 Executor 端累加、只能在 Driver 端读取的共享变量:
python
error_count = sc.accumulator(0)
valid_count = sc.accumulator(0)
def parse_record(line):
try:
data = json.loads(line)
valid_count.add(1)
return data
except:
error_count.add(1)
return None
result = rdd.map(parse_record).filter(lambda x: x is not None)
result.count() # 触发执行
print(f"有效记录:{valid_count.value}")
print(f"错误记录:{error_count.value}")
⚠️ 注意:在 Transformation 中使用 Accumulator 时,因为 Task 重试机制,累加器值可能被多次累加。只有在 Action 之后读取累加器值才是准确的。
6. Spark SQL 与 DataFrame 编程
6.1 读取数据
python
# ── CSV ──
df = spark.read \
.option("header", "true") \
.option("inferSchema", "true") \
.option("sep", ",") \
.option("nullValue", "NA") \
.csv("hdfs://path/to/data.csv")
# ── JSON ──
df = spark.read \
.option("multiLine", "true") \
.json("hdfs://path/to/data.json")
# ── Parquet(推荐格式,列式存储+压缩+Schema) ──
df = spark.read.parquet("hdfs://path/to/data.parquet")
# ── ORC ──
df = spark.read.orc("hdfs://path/to/data.orc")
# ── JDBC/数据库 ──
df = spark.read \
.format("jdbc") \
.option("url", "jdbc:mysql://host:3306/mydb") \
.option("dbtable", "mytable") \
.option("user", "user") \
.option("password", "password") \
.option("numPartitions", 10) \
.option("partitionColumn", "id") \
.option("lowerBound", 1) \
.option("upperBound", 1000000) \
.load()
# ── Kafka(流式读取,详见 Structured Streaming 章节) ──
df = spark.read \
.format("kafka") \
.option("kafka.bootstrap.servers", "host:9092") \
.option("subscribe", "topic_name") \
.load()
# ── Hive 表 ──
df = spark.sql("SELECT * FROM hive_db.my_table")
6.2 Schema 定义
python
from pyspark.sql.types import (
StructType, StructField, StringType, IntegerType,
LongType, DoubleType, BooleanType, TimestampType, ArrayType, MapType
)
# 手动定义 Schema(推荐用于生产环境,避免 inferSchema 的性能开销)
schema = StructType([
StructField("user_id", LongType(), nullable=False),
StructField("name", StringType(), nullable=True),
StructField("age", IntegerType(), nullable=True),
StructField("salary", DoubleType(), nullable=True),
StructField("is_active", BooleanType(), nullable=True),
StructField("created_at", TimestampType(), nullable=True),
StructField("tags", ArrayType(StringType()), nullable=True),
StructField("metadata", MapType(StringType(), StringType()), nullable=True),
])
df = spark.read.schema(schema).csv("path/to/data.csv")
# 也可使用 DDL 字符串定义 Schema(Spark 2.3+)
schema_ddl = "user_id BIGINT NOT NULL, name STRING, age INT, salary DOUBLE"
df = spark.read.schema(schema_ddl).csv("path/to/data.csv")
6.3 DataFrame 常用操作
python
# ── 查看数据 ──
df.show(20, truncate=False) # 显示 20 行,不截断
df.printSchema() # 打印 Schema 树形结构
df.dtypes # 返回列名和数据类型列表
df.columns # 返回列名列表
df.count() # 统计行数
df.describe().show() # 数值列的统计摘要(min/max/mean/stddev)
# ── 列操作 ──
from pyspark.sql import functions as F
df.select("name", "age") # 选择列
df.select(F.col("name"), F.col("age") + 1) # 列运算
df.selectExpr("name", "age * 2 as double_age") # SQL 表达式
df.withColumn("adult", F.col("age") >= 18) # 添加新列
df.withColumn("age", F.col("age").cast("integer")) # 类型转换
df.withColumnRenamed("name", "full_name") # 重命名列
df.drop("unwanted_col1", "unwanted_col2") # 删除列
# ── 过滤 ──
df.filter(F.col("age") > 25)
df.filter("age > 25 AND salary > 50000") # SQL 字符串条件
df.where(F.col("name").isNotNull()) # where 是 filter 的别名
df.filter(F.col("status").isin("active", "pending"))
df.filter(F.col("name").like("%John%"))
df.filter(F.col("created_at").between("2024-01-01", "2024-12-31"))
# ── 聚合 ──
df.groupBy("department") \
.agg(
F.count("*").alias("total"),
F.avg("salary").alias("avg_salary"),
F.max("salary").alias("max_salary"),
F.min("salary").alias("min_salary"),
F.sum("salary").alias("total_salary"),
F.countDistinct("user_id").alias("unique_users"),
F.collect_list("name").alias("names"),
F.stddev("salary").alias("salary_stddev")
)
# ── 排序 ──
df.orderBy("age")
df.orderBy(F.col("salary").desc(), F.col("name").asc())
df.sort(F.col("age").desc_nulls_last())
# ── 去重 ──
df.distinct() # 全行去重
df.dropDuplicates(["user_id", "email"]) # 按指定列去重
# ── 缺失值处理 ──
df.dropna() # 删除含任何 null 的行
df.dropna(subset=["name", "age"]) # 删除指定列含 null 的行
df.dropna(how="all") # 仅删除所有列都是 null 的行
df.fillna(0) # 所有 null 填充为 0
df.fillna({"age": 0, "name": "unknown"}) # 按列填充
df.na.replace([None, ""], "N/A", subset=["name"])
# ── Join 操作 ──
# 支持类型:inner, left, right, full, left_semi, left_anti, cross
result = df1.join(df2, on="user_id", how="inner")
result = df1.join(df2, df1["id"] == df2["user_id"], "left")
result = df1.join(df2, ["id", "name"], "inner") # 多列 Join
# Broadcast Join(小表广播)
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")
# ── 窗口函数 ──
from pyspark.sql.window import Window
window_spec = Window.partitionBy("department").orderBy(F.col("salary").desc())
df.withColumn("rank", F.rank().over(window_spec))
df.withColumn("dense_rank", F.dense_rank().over(window_spec))
df.withColumn("row_number", F.row_number().over(window_spec))
df.withColumn("lag_salary", F.lag("salary", 1).over(window_spec))
df.withColumn("lead_salary", F.lead("salary", 1).over(window_spec))
# 窗口帧(rows between / range between)
window_rolling = Window.partitionBy("user_id") \
.orderBy("date") \
.rowsBetween(-6, 0) # 过去 7 行滚动窗口
df.withColumn("7day_avg", F.avg("sales").over(window_rolling))
6.4 Spark SQL 使用
python
# 注册为临时视图
df.createOrReplaceTempView("users")
# 注册为全局临时视图(跨 SparkSession)
df.createOrReplaceGlobalTempView("global_users")
# 执行 SQL 查询
result = spark.sql("""
SELECT
department,
COUNT(*) as employee_count,
AVG(salary) as avg_salary,
PERCENTILE_APPROX(salary, 0.5) as median_salary
FROM users
WHERE age BETWEEN 25 AND 55
GROUP BY department
HAVING COUNT(*) >= 10
ORDER BY avg_salary DESC
""")
# 全局视图访问需要加前缀
result = spark.sql("SELECT * FROM global_temp.global_users")
# 在 Hive 表上执行 SQL
spark.sql("CREATE TABLE IF NOT EXISTS output_db.result AS SELECT * FROM users WHERE age > 30")
spark.sql("INSERT INTO output_db.result SELECT * FROM users WHERE age <= 30")
6.5 保存数据
python
# ── 保存为 Parquet(推荐) ──
df.write.mode("overwrite").parquet("hdfs://path/to/output.parquet")
# ── 分区写入(Hive 风格分区) ──
df.write \
.partitionBy("year", "month") \
.mode("overwrite") \
.parquet("hdfs://path/to/partitioned_output")
# ── 控制输出文件数 ──
df.repartition(10).write.mode("overwrite").parquet("path/to/output")
df.coalesce(1).write.mode("overwrite").csv("path/single_file.csv") # 输出单文件
# ── 保存为 ORC ──
df.write.mode("append").orc("hdfs://path/to/output.orc")
# ── 写入 Hive 表 ──
df.write.insertInto("hive_db.my_table", overwrite=True)
df.write.saveAsTable("hive_db.new_table") # 创建新 Hive 表
# ── 写入 JDBC ──
df.write \
.format("jdbc") \
.option("url", "jdbc:postgresql://host:5432/mydb") \
.option("dbtable", "output_table") \
.option("user", "user") \
.option("password", "password") \
.mode("append") \
.save()
# ── 写入模式 ──
# "overwrite" - 覆盖已有数据
# "append" - 追加到已有数据
# "ignore" - 目标存在则跳过(不报错)
# "error" - 目标存在则抛出错误(默认)
7. Catalyst 优化器与 Tungsten 执行引擎
7.1 Catalyst 优化器
Catalyst 是 Spark SQL 的核心查询优化器,基于 Scala 函数式编程实现,对所有 DataFrame/Dataset/SQL 操作自动生效。
优化流程(四阶段):
用户代码(DataFrame API / SQL)
│
▼
1. 解析(Parsing)
└─ 生成 Unresolved Logical Plan(未解析逻辑计划)
│
▼
2. 分析(Analysis)
└─ 结合 Catalog 验证列名、数据类型
└─ 生成 Resolved Logical Plan(已解析逻辑计划)
│
▼
3. 逻辑优化(Logical Optimization)[Catalyst 的核心]
├─ 谓词下推(Predicate Pushdown):Filter 尽量靠近数据源
├─ 列裁剪(Column Pruning):只读取需要的列
├─ 常量折叠(Constant Folding):提前计算常量表达式
├─ Join 重排序(Join Reordering):调整 Join 顺序减少 Shuffle
├─ 子查询消除(Subquery Elimination)
└─ 生成 Optimized Logical Plan(优化后逻辑计划)
│
▼
4. 物理规划(Physical Planning)
├─ 生成多个物理执行计划
├─ 基于代价模型(CBO)选择最优执行计划
└─ 生成 Physical Plan(物理执行计划)
│
▼
5. 代码生成(Code Generation - Tungsten)
└─ 通过 Whole-Stage CodeGen 生成优化的 JVM 字节码执行
查看执行计划:
python
# 查看完整执行计划(推荐)
df.explain(mode="extended") # 显示 Parsed / Analyzed / Optimized / Physical 四个阶段
# 其他模式
df.explain() # 仅物理计划(默认)
df.explain(True) # 等价于 extended
df.explain(mode="formatted") # 格式化输出(Spark 3.0+)
df.explain(mode="cost") # 显示代价信息(需 CBO 统计)
df.explain(mode="codegen") # 显示生成的 Java 代码
# SQL 查看执行计划
spark.sql("EXPLAIN EXTENDED SELECT * FROM users WHERE age > 30").show(truncate=False)
7.2 Catalyst 关键优化规则
谓词下推(Predicate Pushdown)
将 Filter 操作推到尽可能早的阶段,减少数据量:
python
# 写法等价,但 Catalyst 会自动将 filter 推入扫描阶段
df.join(df2, "id").filter(F.col("age") > 30)
# 等价于(Catalyst 会优化为此):
df.filter(F.col("age") > 30).join(df2, "id")
# 对于 Parquet/ORC 文件,谓词甚至会 pushdown 到文件层面
# 只读取满足条件的 row group,大幅减少 IO
列裁剪(Column Pruning)
python
# 只选择需要的列,Catalyst 自动避免读取不用的列
df.select("name", "age").filter(F.col("age") > 30)
# Catalyst 会告诉 Parquet 读取器:只读 name 和 age 两列
AQE(自适应查询执行,Spark 3.0+)
AQE 在运行时根据实际数据统计动态调整查询计划:
python
# 启用 AQE(Spark 3.2+ 默认开启)
spark.conf.set("spark.sql.adaptive.enabled", "true")
# AQE 的三大核心功能:
# 1. 动态合并小 Shuffle 分区
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")
# 2. 动态切换 Join 策略(如将 Sort-Merge Join 切换为 Broadcast Join)
spark.conf.set("spark.sql.adaptive.autoBroadcastJoinThreshold", "30MB")
# 3. 动态处理数据倾斜
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5") # 超过中位数 5 倍认为倾斜
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")
7.3 Tungsten 执行引擎
Tungsten 专注于 CPU 和内存效率的极限优化:
| 技术 | 说明 |
|---|---|
| 堆外内存管理(Off-Heap Memory) | 使用 sun.misc.Unsafe 直接管理内存,绕过 GC |
| 缓存友好数据结构 | UnsafeRow 格式,紧凑存储,提高 CPU 缓存命中率 |
| 全阶段代码生成(Whole-Stage CodeGen) | 将多个算子融合为单个 JVM 方法,消除虚函数调用开销 |
| 向量化执行(Vectorized Execution) | 批量处理列式数据,充分利用 SIMD 指令 |
8. Shuffle 机制与 Join 策略
8.1 Shuffle 工作原理
Shuffle 是 Spark 中最昂贵的操作,发生在宽依赖(如 groupBy、join、sortBy)时:
Shuffle 执行流程:
Stage 1(Map 端):
Executor A:Task 1 → 将数据按目标分区 Hash 写入 shuffle 文件
Executor B:Task 2 → 将数据按目标分区 Hash 写入 shuffle 文件
↓ 网络传输(最耗时)
Stage 2(Reduce 端):
Executor C:Task 1 → 从各 Executor 拉取属于自己的分区数据
Executor D:Task 2 → 从各 Executor 拉取属于自己的分区数据
Shuffle 的代价:
- 磁盘 I/O(写入 shuffle 文件)
- 序列化/反序列化
- 网络传输(最主要的瓶颈)
8.2 Join 策略详解
Spark 支持 5 种 Join 策略,自动或手动选择:
策略 1:Broadcast Hash Join(BHJ)--- 最快
原理:将小表广播到每个 Executor,每个 Executor 本地执行 Hash Join
适用:小表 ≤ spark.sql.autoBroadcastJoinThreshold(默认 10MB)
优点:无 Shuffle,速度最快
缺点:小表必须能装入 Executor 内存;不支持某些 join 类型(如 full outer)
Driver 收集小表
│
广播到所有 Executor
┌─────┼─────┐
Exec-A Exec-B Exec-C ← 每个 Executor 持有完整小表
│ │ │
└──本地 Hash Join──┘ ← 大表分区与本地小表直接 Join
python
from pyspark.sql.functions import broadcast
# 方法 1:Python API
result = large_df.join(broadcast(small_df), "key")
# 方法 2:SQL hint
result = spark.sql("""
SELECT /*+ BROADCAST(small_table) */ *
FROM large_table
JOIN small_table ON large_table.key = small_table.key
""")
# 调整广播阈值(字节)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "50MB") # 调大到 50MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") # 禁用自动广播
策略 2:Sort-Merge Join(SMJ)--- 大表默认选择
原理:对两张表按 Join Key shuffle + sort,再 merge
适用:两张大表无法广播的等值 Join
优点:可处理超大表,支持所有 Join 类型,可 spill 到磁盘
缺点:需要全量 Shuffle + Sort,开销最大
步骤:
1. 两张表各自按 Join Key 进行 Shuffle
2. 每个分区内按 Join Key 排序
3. 像归并排序一样合并两个有序分区
python
# SQL hint 强制使用 Sort-Merge Join
result = spark.sql("""
SELECT /*+ MERGE(t1) */ *
FROM t1 JOIN t2 ON t1.key = t2.key
""")
# 配置
spark.conf.set("spark.sql.join.preferSortMergeJoin", "true") # 默认 true
策略 3:Shuffle Hash Join(SHJ)
原理:Shuffle 后,在每个分区的较小侧建立 Hash Table,与大侧 probe
适用:两表均较大但有一侧明显更小,且 Key 分布均匀
优点:比 Sort-Merge 快(无需排序)
缺点:内存不足可能 OOM(不能 spill 排序阶段)
AQE 可在运行时将 SMJ 动态切换为 SHJ:
当 shuffle 后分区数据量小于 spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold 时自动切换
python
# 强制 Shuffle Hash Join
result = df1.hint("shuffle_hash").join(df2, "key")
Join 策略选择决策树
两表的 Join 操作
│
├── 有一侧 ≤ autoBroadcastJoinThreshold (默认 10MB)?
│ └── YES → Broadcast Hash Join(最快,无 Shuffle)
│
├── 都 > 广播阈值,但有一侧明显更小?
│ └── YES(小侧能放入单个分区内存)→ Shuffle Hash Join
│
└── 两侧都很大?
└── Sort-Merge Join(稳健,支持 spill)
└── AQE 开启时:若 shuffle 后小侧显著缩小 → 动态切换为 SHJ 或 BHJ
8.3 数据倾斜(Skew)处理
数据倾斜是 Spark 最常见的性能杀手,某些 Key 的数据量远大于其他 Key:
方法 1:AQE 自动处理(推荐,Spark 3.0+)
python
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
# AQE 会自动将大分区切分并复制小侧数据来处理
方法 2:手动加盐(Salting)
python
import random
from pyspark.sql import functions as F
# 加盐系数
NUM_SALT = 20
# 对倾斜的大表加盐(随机前缀)
skewed_df = large_df.withColumn(
"salted_key",
F.concat(F.col("join_key"), F.lit("_"), (F.rand() * NUM_SALT).cast("int").cast("string"))
)
# 对小表扩展(每个 Key 复制 NUM_SALT 份)
small_expanded = small_df.withColumn(
"salt",
F.explode(F.array([F.lit(i) for i in range(NUM_SALT)]))
).withColumn(
"salted_key",
F.concat(F.col("join_key"), F.lit("_"), F.col("salt").cast("string"))
)
# Join
result = skewed_df.join(small_expanded, "salted_key").drop("salted_key", "salt")
方法 3:分离倾斜 Key 单独处理
python
# 找出倾斜 Key
skew_keys = ["hotkey1", "hotkey2"]
# 拆分 DataFrame
skewed_df = df.filter(F.col("key").isin(skew_keys))
normal_df = df.filter(~F.col("key").isin(skew_keys))
# 分别 Join(倾斜侧用 broadcast)
skewed_result = skewed_df.join(broadcast(small_df), "key")
normal_result = normal_df.join(small_df, "key")
# 合并结果
final_result = normal_result.union(skewed_result)
9. Structured Streaming 流处理
9.1 流处理核心概念
Structured Streaming 将流数据视为一张持续增长的无界表(Unbounded Table),用与批处理相同的 DataFrame/SQL API 处理:
输入数据流(Kafka/文件/Socket)
│
▼ 每个 trigger 时间点追加新行
┌─────────────────────────────┐
│ Input Table │ ← 概念上的无界输入表
│ (只追加,从不修改旧行) │
└─────────────────────────────┘
│
▼ 用户定义的查询(与批处理 API 完全一致)
┌─────────────────────────────┐
│ Result Table │ ← 查询结果表(持续更新)
└─────────────────────────────┘
│
▼ Output Sink(按 Output Mode 写出)
文件系统 / Kafka / Console / Memory / 自定义 Sink
9.2 输入源(Source)
python
# ── Kafka Source(最常用)──
kafka_df = spark \
.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "host1:9092,host2:9092") \
.option("subscribe", "topic1,topic2") \ # 订阅多个 topic
.option("startingOffsets", "latest") \ # earliest / latest / JSON 指定
.option("maxOffsetsPerTrigger", 100000) \ # 每次触发最多消费数量
.option("failOnDataLoss", "false") \
.load()
# Kafka value 是 binary,需要 cast
from pyspark.sql.functions import col, from_json
schema = StructType([...])
parsed_df = kafka_df \
.select(
col("key").cast("string"),
from_json(col("value").cast("string"), schema).alias("data"),
col("timestamp"),
col("partition"),
col("offset")
) \
.select("key", "data.*", "timestamp")
# ── 文件 Source(监控目录)──
file_df = spark \
.readStream \
.format("csv") \
.schema(schema) \
.option("path", "hdfs://path/to/incoming/") \
.option("maxFilesPerTrigger", 10) \ # 每次最多处理 10 个文件
.load()
# ── Socket Source(调试用)──
socket_df = spark \
.readStream \
.format("socket") \
.option("host", "localhost") \
.option("port", 9999) \
.load()
9.3 流处理操作
python
# 与批处理完全一致的 API
result_df = parsed_df \
.filter(F.col("amount") > 0) \
.groupBy(
F.window("event_time", "10 minutes", "5 minutes"), # 10分钟窗口,5分钟滑动
F.col("user_id")
) \
.agg(
F.sum("amount").alias("total_amount"),
F.count("*").alias("tx_count")
)
9.4 水位线(Watermark)处理延迟数据
python
# 设置水位线:允许最多 10 分钟的延迟数据
windowed_df = parsed_df \
.withWatermark("event_time", "10 minutes") \ # event_time 是事件时间列
.groupBy(
F.window("event_time", "5 minutes"),
"category"
) \
.agg(F.sum("amount").alias("total"))
9.5 输出 Sink 与 Output Mode
python
# ── 三种 Output Mode ──
# complete: 输出所有结果(适合 aggregation,但内存压力大)
# append: 只输出新增行(无聚合或有水位线的聚合)
# update: 只输出变化的行(聚合时常用)
# ── Console Sink(调试)──
query = result_df.writeStream \
.outputMode("update") \
.format("console") \
.option("truncate", False) \
.trigger(processingTime="10 seconds") \
.start()
# ── Kafka Sink ──
query = result_df \
.selectExpr("CAST(user_id AS STRING) AS key", "to_json(struct(*)) AS value") \
.writeStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "host:9092") \
.option("topic", "output_topic") \
.option("checkpointLocation", "hdfs://path/to/checkpoint/") \
.outputMode("append") \
.start()
# ── 文件 Sink(Parquet)──
query = result_df.writeStream \
.format("parquet") \
.option("path", "hdfs://path/to/output/") \
.option("checkpointLocation", "hdfs://path/to/checkpoint/") \
.partitionBy("date") \
.outputMode("append") \
.trigger(processingTime="1 minute") \
.start()
# ── 自定义 ForeachBatch Sink(最灵活,Spark 2.4+)──
def write_to_db(batch_df, batch_id):
"""每个 micro-batch 会调用此函数"""
batch_df.persist()
# 写到数据库
batch_df.write.format("jdbc") \
.option("url", "jdbc:...") \
.option("dbtable", "result_table") \
.mode("append") \
.save()
# 同时发送到 Kafka
batch_df.selectExpr(...).write.format("kafka")...
batch_df.unpersist()
query = result_df.writeStream \
.foreachBatch(write_to_db) \
.option("checkpointLocation", "hdfs://path/to/checkpoint/") \
.outputMode("update") \
.start()
# ── 等待查询结束 ──
query.awaitTermination() # 阻塞
query.stop() # 手动停止
9.6 Trigger(触发器)类型
python
from pyspark.sql.streaming import Trigger
# 固定时间间隔(默认 0,尽快处理)
.trigger(processingTime="30 seconds")
# 只处理一次 micro-batch,然后停止(Spark 2.2+)
.trigger(once=True)
# 持续处理(Spark 2.3+,低延迟实验性功能)
.trigger(continuous="1 second")
# 按可用数据处理(每次处理所有可用数据)
.trigger(availableNow=True) # Spark 3.3+
9.7 容错语义
Structured Streaming 提供**端到端精确一次(Exactly-Once)**保证:
- Checkpointing:Spark 将 offset 和状态存入 checkpoint 目录(HDFS/S3)
- Idempotent Sink:幂等写入(如 Parquet 文件、有事务支持的数据库)
- 可重放的 Source:如 Kafka 支持从指定 offset 重新消费
10. Spark MLlib 机器学习
10.1 MLlib 概述
MLlib 是 Spark 内置的分布式机器学习库,基于 DataFrame API(spark.ml 包)。
MLlib 主要功能:
├── 特征工程
│ ├── 特征提取(TF-IDF、Word2Vec、HashingTF)
│ ├── 特征转换(StandardScaler、MinMaxScaler、PCA)
│ └── 特征选择(ChiSqSelector、VarianceThresholdSelector)
├── 分类算法
│ ├── 逻辑回归(LogisticRegression)
│ ├── 随机森林(RandomForestClassifier)
│ ├── 梯度提升树(GBTClassifier)
│ ├── 支持向量机(LinearSVC)
│ └── 多层感知机(MultilayerPerceptronClassifier)
├── 回归算法
│ ├── 线性回归(LinearRegression)
│ ├── 随机森林回归(RandomForestRegressor)
│ └── 梯度提升树回归(GBTRegressor)
├── 聚类算法
│ ├── K-Means(KMeans)
│ ├── 高斯混合(GaussianMixture)
│ └── 层次聚类(BisectingKMeans)
├── 协同过滤(ALS 推荐系统)
└── 模型评估与选择(CrossValidator、TrainValidationSplit)
10.2 完整 ML Pipeline 示例
python
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
VectorAssembler, StringIndexer, OneHotEncoder,
StandardScaler, Imputer
)
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# 加载数据
data = spark.read.parquet("hdfs://path/to/ml_data.parquet")
# 划分训练集和测试集
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
# ── 特征工程 ──
# 1. 填充缺失值
imputer = Imputer(
inputCols=["age", "salary", "years_exp"],
outputCols=["age_imp", "salary_imp", "years_exp_imp"],
strategy="median"
)
# 2. 类别型变量编码
indexer = StringIndexer(inputCol="department", outputCol="dept_idx")
encoder = OneHotEncoder(inputCol="dept_idx", outputCol="dept_ohe")
# 3. 组合特征向量
assembler = VectorAssembler(
inputCols=["age_imp", "salary_imp", "years_exp_imp", "dept_ohe"],
outputCol="features_raw"
)
# 4. 标准化
scaler = StandardScaler(
inputCol="features_raw",
outputCol="features",
withMean=True,
withStd=True
)
# 5. 目标变量编码
label_indexer = StringIndexer(inputCol="label", outputCol="label_idx")
# ── 模型 ──
rf = RandomForestClassifier(
featuresCol="features",
labelCol="label_idx",
numTrees=100,
maxDepth=10,
seed=42
)
# ── Pipeline ──
pipeline = Pipeline(stages=[
imputer, indexer, encoder, assembler, scaler, label_indexer, rf
])
# ── 超参数搜索 ──
param_grid = ParamGridBuilder() \
.addGrid(rf.numTrees, [50, 100, 200]) \
.addGrid(rf.maxDepth, [5, 10, 15]) \
.build()
evaluator = BinaryClassificationEvaluator(
labelCol="label_idx",
rawPredictionCol="rawPrediction",
metricName="areaUnderROC"
)
cv = CrossValidator(
estimator=pipeline,
estimatorParamMaps=param_grid,
evaluator=evaluator,
numFolds=5,
seed=42
)
# ── 训练 ──
cv_model = cv.fit(train_data)
# ── 评估 ──
predictions = cv_model.transform(test_data)
auc = evaluator.evaluate(predictions)
print(f"AUC-ROC: {auc:.4f}")
# ── 保存模型 ──
cv_model.bestModel.write().overwrite().save("hdfs://path/to/model/")
# ── 加载模型 ──
from pyspark.ml import PipelineModel
loaded_model = PipelineModel.load("hdfs://path/to/model/")
10.3 ALS 推荐系统
python
from pyspark.ml.recommendation import ALS
from pyspark.ml.evaluation import RegressionEvaluator
# 数据格式:userId, itemId, rating
ratings = spark.read.csv("ratings.csv", header=True, inferSchema=True)
train, test = ratings.randomSplit([0.8, 0.2])
# 训练 ALS 模型
als = ALS(
maxIter=10,
regParam=0.01,
rank=10, # 隐因子维度
userCol="userId",
itemCol="movieId",
ratingCol="rating",
coldStartStrategy="drop" # 处理冷启动问题
)
model = als.fit(train)
# 评估 RMSE
predictions = model.transform(test)
evaluator = RegressionEvaluator(
metricName="rmse", labelCol="rating", predictionCol="prediction"
)
rmse = evaluator.evaluate(predictions)
print(f"Root-mean-square error = {rmse:.4f}")
# 为每个用户生成 Top 10 推荐
user_recs = model.recommendForAllUsers(10)
# 为每个物品找到 Top 10 潜在用户
item_recs = model.recommendForAllItems(10)
11. Spark GraphX 图计算
11.1 GraphX 基础
GraphX 在 Spark RDD 基础上构建图计算框架,提供 Graph 抽象:
scala
// Scala 示例(GraphX 主要以 Scala 使用)
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
// 创建顶点 RDD:(VertexId, 属性)
val vertices: RDD[(VertexId, String)] = sc.parallelize(Array(
(1L, "Alice"), (2L, "Bob"), (3L, "Charlie"), (4L, "David")
))
// 创建边 RDD:Edge(srcId, dstId, 属性)
val edges: RDD[Edge[String]] = sc.parallelize(Array(
Edge(1L, 2L, "friend"),
Edge(2L, 3L, "colleague"),
Edge(3L, 4L, "friend"),
Edge(4L, 1L, "family")
))
// 构建图
val graph = Graph(vertices, edges, defaultVertexAttr = "unknown")
// 图的基本属性
println(s"顶点数: ${graph.numVertices}")
println(s"边数: ${graph.numEdges}")
// 查看顶点
graph.vertices.collect().foreach(println)
// 查看边
graph.edges.collect().foreach(println)
// 图算法:PageRank
val pageRankGraph = graph.pageRank(0.001) // 容差值
// 三角形计数
val triangleCount = graph.triangleCount()
// 连通分量
val cc = graph.connectedComponents()
// 最短路径(Pregel API)
import org.apache.spark.graphx.lib.ShortestPaths
val shortestPaths = ShortestPaths.run(graph, Seq(1L))
12. 内存管理与持久化策略
12.1 Spark 内存模型
Spark 的 Executor JVM 内存划分(Spark 1.6+ 统一内存管理):
Executor JVM 总内存
├── Reserved Memory(系统保留):300 MB(固定)
├── User Memory(1 - spark.memory.fraction = 40%)
│ └── 用户数据结构、UDF、Spark 内部对象
└── Spark Memory(spark.memory.fraction = 60%)
├── Execution Memory(动态共享)
│ └── Shuffle、Sort、Join 的中间数据
└── Storage Memory(动态共享)
└── RDD/DataFrame Cache、广播变量
注:Execution 和 Storage 内存是动态共享的(spark.memory.storageFraction 控制初始比例)
当执行内存不足时,可以驱逐 Storage 内存
关键配置:
python
# 调整 Spark 内存占比(默认 0.6,即 60% JVM 堆用于 Spark)
spark.conf.set("spark.memory.fraction", "0.7")
# 在 Spark 内存中,Storage 的初始比例(默认 0.5)
spark.conf.set("spark.memory.storageFraction", "0.4")
# Executor 堆内存(在 spark-submit 或 config 中设置)
# --executor-memory 4g
# spark.executor.memory = 4g
# 堆外内存(Tungsten 使用,默认 0.1 * executor memory)
spark.conf.set("spark.memory.offHeap.enabled", "true")
spark.conf.set("spark.memory.offHeap.size", "2g")
12.2 持久化级别
python
from pyspark import StorageLevel
# 常用持久化级别
df.cache() # 等价于 MEMORY_AND_DISK
df.persist() # 默认:MEMORY_AND_DISK
df.persist(StorageLevel.MEMORY_ONLY) # 仅内存,不够则重算
df.persist(StorageLevel.MEMORY_AND_DISK) # 优先内存,溢出到磁盘
df.persist(StorageLevel.MEMORY_ONLY_SER) # 序列化存内存(节省空间)
df.persist(StorageLevel.MEMORY_AND_DISK_SER) # 序列化,内存+磁盘
df.persist(StorageLevel.DISK_ONLY) # 仅磁盘
df.persist(StorageLevel.OFF_HEAP) # 堆外内存(需开启)
# 释放缓存
df.unpersist()
# 清除所有缓存
spark.catalog.clearCache()
选择持久化级别的原则:
内存充足?
├── YES → MEMORY_ONLY(速度最快)
│ 如果仍然 OOM → MEMORY_ONLY_SER(序列化节省约 50% 空间)
└── NO → MEMORY_AND_DISK(内存放不下就 spill 到磁盘)
需要极致压缩 → MEMORY_AND_DISK_SER
内存极其紧张 → DISK_ONLY
何时使用 Cache?
python
# ✅ 应该缓存:
# 1. DataFrame 被多次复用
base_df = spark.read.parquet("...").cache()
result1 = base_df.filter(...) # 第一次用
result2 = base_df.groupBy(...) # 第二次用(直接从 cache 读取)
# 2. 迭代计算(如 ML 训练中的特征 DataFrame)
train_features.cache()
for epoch in range(100):
model.train(train_features)
# ❌ 不应该缓存:
# 1. 只用一次的 DataFrame
# 2. 非常大无法放入内存的 DataFrame(溢出反而更慢)
# 3. 来源读取很快(如本地 SSD Parquet)
13. 分区策略与数据倾斜处理
13.1 分区基础
python
# 查看分区数
df.rdd.getNumPartitions()
# 增加分区(触发 Shuffle)
df_repartitioned = df.repartition(200)
# 增加并按列分区(相同 key 会到同一分区,有助于 Join 优化)
df_repartitioned = df.repartition(200, "user_id")
# 减少分区(不触发 Shuffle,比 repartition 更高效)
df_coalesced = df.coalesce(50)
# 推荐分区数
# 一般原则:分区数 ≈ 2~4 × CPU 核心总数
# shuffle 分区数(default 200)
spark.conf.set("spark.sql.shuffle.partitions", "400")
# AQE 开启时,让 Spark 自动调整 shuffle 分区数
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")
13.2 分区大小建议
分区大小:理想在 100MB ~ 200MB 之间
太大 → 内存压力大,Task 运行时间长,容错重试代价高
太小 → 任务调度开销大,Shuffle 文件过多
检查当前分区情况:
df.rdd.glom().map(len).collect() # 每个分区的元素数量
df.groupBy(spark_partition_id()).count() # 每个分区的行数
14. 性能调优全面指南
14.1 优化层次总览
性能优化金字塔(重要性从高到低):
① 代码层(影响最大)
└─ 使用 DataFrame/Dataset 而非 RDD
└─ 避免不必要的 Shuffle(尽早 filter 和 select)
└─ 使用内置函数而非 UDF
└─ 合理使用 Cache
② 数据层
└─ 使用 Parquet/ORC 列式存储格式
└─ 合理分区和 Bucket(Bucketing)
└─ 数据倾斜处理
③ 配置层
└─ 合理的 Executor 配置
└─ AQE 配置
└─ Shuffle 参数调整
④ 资源层(影响最小但最直接)
└─ 增加 Executor 数量和内存
└─ 使用更快的存储(SSD)
└─ 升级网络带宽
14.2 代码优化技巧
1. 尽早过滤和列裁剪
python
# ❌ 不好:先 join 再过滤(join 了大量不需要的数据)
result = df1.join(df2, "id").filter(F.col("date") >= "2024-01-01")
# ✅ 好:先过滤再 join
result = df1.filter(F.col("date") >= "2024-01-01") \
.select("id", "date", "amount") \ # 只选需要的列
.join(df2.select("id", "name"), "id") # 同样只选需要的列
2. 避免使用 Python UDF(改用内置函数)
python
from pyspark.sql import functions as F
# ❌ 慢:Python UDF(序列化开销,Catalyst 无法优化)
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
@udf(returnType=StringType())
def get_domain(email):
return email.split("@")[1] if "@" in email else None
df.withColumn("domain", get_domain(F.col("email"))) # 慢!
# ✅ 快:使用内置函数
df.withColumn("domain", F.split(F.col("email"), "@").getItem(1))
# 如果必须用 UDF,考虑 Pandas UDF(向量化执行,快 10-100 倍)
from pyspark.sql.functions import pandas_udf
import pandas as pd
@pandas_udf(StringType())
def get_domain_pandas(emails: pd.Series) -> pd.Series:
return emails.str.split("@").str.get(1)
df.withColumn("domain", get_domain_pandas(F.col("email"))) # 比 UDF 快得多
3. reduceByKey vs groupByKey
python
# ❌ 慢:groupByKey 在 shuffle 前不做局部聚合,传输大量数据
pairs.groupByKey().mapValues(sum)
# ✅ 快:reduceByKey 在 shuffle 前先做局部聚合(combiner 优化)
pairs.reduceByKey(lambda a, b: a + b)
# 对于复杂聚合,使用 aggregateByKey
pairs.aggregateByKey(
(0, 0), # zeroValue: (sum, count)
lambda acc, v: (acc[0]+v, acc[1]+1), # seqOp: 分区内聚合
lambda acc1, acc2: (acc1[0]+acc2[0], acc1[1]+acc2[1]) # combOp: 分区间合并
).mapValues(lambda x: x[0]/x[1]) # 计算平均值
4. 合理使用 repartition 和 coalesce
python
# 在大型操作之前适当增加分区
df.repartition(400, "join_key").join(other_df, "join_key")
# 写入文件前合并小分区(避免生成大量小文件)
result.coalesce(10).write.parquet("path/to/output/")
# 注意:repartition 触发 Shuffle,coalesce 不触发
# 如果需要减少分区数,优先用 coalesce
14.3 存储格式优化
python
# Parquet 是最推荐的存储格式
# ✅ 优点:列式存储、压缩率高、支持 Schema、谓词下推到文件层
df.write \
.option("compression", "snappy") \ # gzip(高压缩率)/ snappy(快速)/ zstd(平衡)
.partitionBy("year", "month") \ # 按时间分区,方便过滤
.parquet("hdfs://path/to/output/")
# 读取时推断 Schema 代价大,生产环境显式指定 Schema
df = spark.read.schema(predefined_schema).parquet("path/to/parquet/")
14.4 Executor 资源配置
bash
# spark-submit 资源配置示例
spark-submit \
--master yarn \
--deploy-mode cluster \
--driver-memory 4g \
--executor-memory 8g \
--executor-cores 4 \
--num-executors 20 \
--conf "spark.sql.shuffle.partitions=400" \
--conf "spark.sql.adaptive.enabled=true" \
--conf "spark.serializer=org.apache.spark.serializer.KryoSerializer" \
--conf "spark.default.parallelism=400" \
--conf "spark.dynamicAllocation.enabled=true" \
--conf "spark.dynamicAllocation.minExecutors=5" \
--conf "spark.dynamicAllocation.maxExecutors=50" \
my_app.py
Executor 配置黄金公式(YARN 环境):
假设:节点 RAM = 128 GB,节点 CPU = 32 核
预留给 OS 和 YARN:RAM 留 10%,CPU 留 1 核
可用:RAM ≈ 115 GB,CPU ≈ 31 核
每个 Executor 建议配置 5 个 cores(减少 GC 压力,提高 HDFS 吞吐量)
每个节点 Executor 数 = 31 / 5 ≈ 6 个
每个 Executor 内存 = 115 / 6 ≈ 19 GB
减去 Executor 开销(overhead = max(384MB, 0.1 * executor-memory))
executor-memory = 19 - 2 ≈ 17 GB
最终配置:
--executor-cores 5
--executor-memory 17g
--num-executors (节点数 × 6 - 1) # -1 留给 driver
14.5 常见性能问题速查
| 问题现象 | 可能原因 | 解决方案 |
|---|---|---|
| Stage 中个别 Task 特别慢 | 数据倾斜 | AQE skewJoin、加盐、分离倾斜 Key |
| OOM(OutOfMemoryError) | 内存不足、分区太少 | 增加内存、增加分区数、使用 coalesce |
| Shuffle 时间占比高 | 宽依赖太多 | 减少 groupBy/join、使用 broadcast join |
| 读取速度慢 | 文件格式不佳、分区扫描 | 改用 Parquet/ORC、合理 partition pruning |
| GC 时间过长 | JVM 对象过多 | 使用 Kryo 序列化、开启堆外内存 |
| 任务调度延迟高 | 分区太多(小分区) | 使用 AQE 自动合并、调大分区大小 |
| CPU 利用率低 | IO 密集 / 数据加载慢 | 压缩存储、预取、增加并行度 |
15. 集群部署与资源管理
15.1 部署模式对比
| 模式 | 特点 | 适用场景 |
|---|---|---|
| Local | 单机模拟,local[N] 指定线程数 |
开发测试 |
| Standalone | Spark 自带集群管理器 | 纯 Spark 集群 |
| YARN | 与 Hadoop 生态集成 | Hadoop 环境 |
| Kubernetes | 容器化部署,弹性伸缩 | 云原生/容器环境 |
| Mesos | 通用资源管理(逐渐被淘汰) | - |
15.2 spark-submit 提交应用
bash
# 提交 PySpark 应用到 YARN
spark-submit \
--master yarn \
--deploy-mode cluster \ # client(driver 在本机)/ cluster(driver 在集群)
--name "My Spark Job" \
--driver-memory 4g \
--executor-memory 8g \
--executor-cores 4 \
--num-executors 20 \
--files /path/to/config.json \ # 附带文件
--py-files /path/to/utils.zip \ # Python 模块
--jars /path/to/extra.jar \ # 额外 JAR 包
--conf "spark.sql.shuffle.partitions=400" \
--conf "spark.sql.adaptive.enabled=true" \
/path/to/your_script.py \
arg1 arg2 # 脚本参数
15.3 动态资源分配
bash
# 开启动态分配(自动根据负载增减 Executor)
--conf "spark.dynamicAllocation.enabled=true"
--conf "spark.dynamicAllocation.minExecutors=2"
--conf "spark.dynamicAllocation.maxExecutors=100"
--conf "spark.dynamicAllocation.initialExecutors=10"
--conf "spark.dynamicAllocation.executorIdleTimeout=60s"
--conf "spark.shuffle.service.enabled=true" # 动态分配需要外部 Shuffle Service
15.4 Kubernetes 部署
bash
# 提交到 K8s 集群
spark-submit \
--master k8s://https://k8s-api-server:6443 \
--deploy-mode cluster \
--name spark-example \
--conf "spark.executor.instances=5" \
--conf "spark.kubernetes.container.image=apache/spark:3.5.0" \
--conf "spark.kubernetes.namespace=spark" \
--conf "spark.kubernetes.authenticate.driver.serviceAccountName=spark" \
local:///path/to/your_script.py
16. 监控与调试
16.1 Spark UI 使用指南
Spark UI 默认运行在 Driver 节点的 4040 端口(历史 Server 在 18080):
Spark UI 核心页面:
Jobs 页面
└─ 查看所有 Job 的执行状态、时间、进度条
点击 Job → 查看组成的 Stages
Stages 页面
└─ 最重要的调优页面
查看每个 Stage 的:
├── 输入/输出/Shuffle 数据量
├── Task 分布(看是否有长尾 Task = 数据倾斜)
├── GC 时间(过高说明内存压力大)
└── Spill 量(Disk 和 Memory Spill 说明内存不足)
Storage 页面
└─ 查看被 cache 的 RDD/DataFrame 及其内存占用
Environment 页面
└─ 查看 Spark 配置参数
Executors 页面
└─ 查看各 Executor 的内存使用、GC 时间、Task 数量
SQL 页面(DataFrame/SQL)
└─ 可视化执行计划 DAG
查看每个节点处理的行数和时间
识别性能瓶颈节点
16.2 常用调试命令
python
# ── 查看执行计划 ──
df.explain(mode="formatted")
# ── 查看分区信息 ──
print(f"分区数: {df.rdd.getNumPartitions()}")
# 查看每个分区的数据量
df.groupBy(F.spark_partition_id().alias("partition_id")).count().orderBy("partition_id").show()
# ── 查看数据分布 ──
df.groupBy("key_column").count() \
.orderBy(F.col("count").desc()) \
.show(20)
# ── 统计基本信息 ──
df.summary().show() # 比 describe 更详细
# ── 查看 Cache 状态 ──
spark.catalog.listTables() # 列出所有注册的表
# ── 查看 Job 信息(程序内)──
sc.statusTracker().getJobIDs()
sc.statusTracker().getStageInfo(stageId)
16.3 日志配置
properties
# log4j2.properties(Spark 3.3+ 使用 log4j2)
rootLogger.level = WARN
appender.console.type = Console
appender.console.name = console
appender.console.layout.type = PatternLayout
appender.console.layout.pattern = %d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n%ex
# 调整特定包的日志级别
logger.spark.name = org.apache.spark
logger.spark.level = WARN
logger.hadoop.name = org.apache.hadoop
logger.hadoop.level = WARN
# 自己的应用调试
logger.myapp.name = com.mycompany.myapp
logger.myapp.level = DEBUG
python
# 在代码中动态设置日志级别
spark.sparkContext.setLogLevel("WARN") # ERROR/WARN/INFO/DEBUG
17. Spark 配置参数速查表
17.1 核心配置
| 参数 | 默认值 | 说明 |
|---|---|---|
spark.executor.memory |
1g | Executor JVM 堆内存 |
spark.executor.cores |
1(YARN)/ 所有(Standalone) | 每个 Executor 的 CPU 核心数 |
spark.driver.memory |
1g | Driver 内存 |
spark.default.parallelism |
集群 CPU 核心总数 × 2 | RDD 默认并行度 |
spark.sql.shuffle.partitions |
200 | Shuffle 后分区数 |
spark.memory.fraction |
0.6 | JVM 堆分配给 Spark 的比例 |
spark.memory.storageFraction |
0.5 | Spark 内存中 Storage 的初始比例 |
spark.serializer |
JavaSerializer | 推荐改为 KryoSerializer |
17.2 AQE 配置
| 参数 | 默认值 | 说明 |
|---|---|---|
spark.sql.adaptive.enabled |
true(3.2+) | 启用 AQE |
spark.sql.adaptive.coalescePartitions.enabled |
true | 动态合并小分区 |
spark.sql.adaptive.advisoryPartitionSizeInBytes |
64MB | 目标分区大小 |
spark.sql.adaptive.skewJoin.enabled |
true | 自动处理倾斜 Join |
spark.sql.adaptive.skewJoin.skewedPartitionFactor |
5 | 倾斜判断因子 |
spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes |
256MB | 倾斜分区大小阈值 |
17.3 Join 与广播配置
| 参数 | 默认值 | 说明 |
|---|---|---|
spark.sql.autoBroadcastJoinThreshold |
10MB | 广播 Join 的表大小阈值 |
spark.sql.join.preferSortMergeJoin |
true | 优先 Sort-Merge Join |
spark.broadcast.compress |
true | 广播时压缩数据 |
17.4 Shuffle 配置
| 参数 | 默认值 | 说明 |
|---|---|---|
spark.shuffle.compress |
true | 压缩 Shuffle 数据 |
spark.shuffle.spill.compress |
true | 压缩 Spill 到磁盘的数据 |
spark.reducer.maxSizeInFlight |
48MB | Reduce 端每次拉取的最大数据量 |
spark.shuffle.file.buffer |
32KB | Shuffle 写文件时的缓冲区大小 |
17.5 Streaming 配置
| 参数 | 默认值 | 说明 |
|---|---|---|
spark.streaming.kafka.maxRatePerPartition |
不限 | Kafka 每分区每秒最大消费速率 |
spark.sql.streaming.checkpointLocation |
- | 默认 checkpoint 路径 |
spark.sql.streaming.minBatchesToRetain |
100 | 保留的最小 batch 元数据数量 |
18. 最佳实践与常见坑
18.1 黄金法则
✅ 优先使用 DataFrame/Dataset API,而非 RDD
✅ 优先使用内置函数(spark.sql.functions),而非 Python/Scala UDF
✅ 尽早执行 filter 和 select(减少数据量)
✅ 使用 Parquet/ORC 格式,并开启分区
✅ 对小表使用 broadcast join
✅ 开启 AQE(Spark 3.0+,自动优化大量场景)
✅ 使用 explain() 检查执行计划
✅ 合理设置 shuffle 分区数(非默认 200)
✅ 写入前使用 coalesce 合并小分区
✅ 敏感配置不要硬编码,用外部配置文件
✅ 生产环境明确指定 Schema,不要用 inferSchema
18.2 常见陷阱
❌ 在 Driver 上执行 collect() 获取大数据集 → OOM
解决:改用 write() 或 take(n) 获取小样本
❌ 在 foreach/map 中修改 Driver 变量 → 变量不会同步回 Driver
解决:使用 Accumulator
❌ 对大 DataFrame 频繁调用 count() → 每次都触发全量计算
解决:cache() 后再多次 count(),或只统计一次
❌ 在循环中创建大量 RDD/DataFrame 而不释放 → 内存泄漏
解决:及时 unpersist(),或重构避免循环创建
❌ groupByKey 后做 sum → 全量 shuffle 再计算
解决:改用 reduceByKey(有 map-side combine)
❌ SQL 写了 SELECT * → 读取所有列,破坏列裁剪优化
解决:显式列出需要的列
❌ 使用 Python lambda 函数并序列化整个模块 → 序列化失败
解决:确保 lambda 捕获的变量可序列化,避免捕获 SparkContext
❌ Shuffle 分区默认 200,不随数据量调整 → 分区过大或过小
解决:根据数据量调整,或开启 AQE 自动调整
❌ 重复读取相同数据源 → 多次 IO
解决:读取一次,cache,然后多次使用
❌ 在 Streaming 中使用 collect() → 阻塞流处理
解决:改用 foreachBatch
18.3 代码质量规范
python
# ✅ 推荐的代码结构
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, LongType
def create_spark_session(app_name: str) -> SparkSession:
"""创建 SparkSession,生产环境使用 getOrCreate 保证幂等"""
return SparkSession.builder \
.appName(app_name) \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.shuffle.partitions", "400") \
.getOrCreate()
def read_data(spark: SparkSession, path: str, schema: StructType):
"""显式指定 Schema 读取数据"""
return spark.read.schema(schema).parquet(path)
def process_data(df):
"""业务处理逻辑,链式调用保持清晰"""
return (
df
.filter(F.col("date") >= "2024-01-01") # 尽早过滤
.select("user_id", "amount", "category", "date") # 尽早列裁剪
.withColumn("year", F.year(F.col("date")))
.groupBy("category", "year")
.agg(
F.sum("amount").alias("total_amount"),
F.count("user_id").alias("user_count")
)
.orderBy(F.col("total_amount").desc())
)
def write_data(df, output_path: str):
"""分区写入,控制输出文件数量"""
df.repartition(20) \
.write \
.mode("overwrite") \
.partitionBy("year") \
.parquet(output_path)
def main():
spark = create_spark_session("DataProcessingJob")
# 定义 Schema
schema = StructType([
StructField("user_id", LongType(), False),
StructField("amount", DoubleType(), True),
StructField("category", StringType(), True),
StructField("date", StringType(), True),
])
# 读取
raw_df = read_data(spark, "hdfs://input/path/", schema)
# 处理(缓存复用的 DataFrame)
processed_df = process_data(raw_df)
processed_df.cache()
# 输出
write_data(processed_df, "hdfs://output/path/")
# 验证
print(f"输出行数: {processed_df.count()}")
spark.stop()
if __name__ == "__main__":
main()
19. 参考资源
官方文档
| 资源 | 链接 |
|---|---|
| Apache Spark 官方文档 | https://spark.apache.org/docs/latest/ |
| Spark SQL 编程指南 | https://spark.apache.org/docs/latest/sql-programming-guide.html |
| RDD 编程指南 | https://spark.apache.org/docs/latest/rdd-programming-guide.html |
| Structured Streaming 指南 | https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html |
| MLlib 指南 | https://spark.apache.org/docs/latest/ml-guide.html |
| GraphX 指南 | https://spark.apache.org/docs/latest/graphx-programming-guide.html |
| 性能调优官方指南 | https://spark.apache.org/docs/latest/sql-performance-tuning.html |
| 配置参数说明 | https://spark.apache.org/docs/latest/configuration.html |
| Databricks 官方博客 | https://www.databricks.com/blog |
学习路径建议
初学者(0~3 个月):
1. 理解 Spark 架构(Driver、Executor、DAG)
2. 学习 DataFrame API(本文第 6 章)
3. 理解 RDD 基础和 Transformation/Action
4. 实践:本地 WordCount → 数据清洗 ETL
推荐:Spark 官方 Quick Start + 本文 1~6 章
进阶(3~9 个月):
1. 深入 Catalyst 和 Tungsten(第 7 章)
2. 掌握 Join 策略和数据倾斜处理(第 8 章)
3. 学习 Structured Streaming(第 9 章)
4. 性能调优实践(第 14 章)
5. 实践:构建完整 ETL Pipeline + 实时流处理
推荐:Spark: The Definitive Guide(O'Reilly)
专家(9 个月+):
1. MLlib 分布式机器学习
2. Spark 源码阅读(Catalyst、AQE)
3. 自定义 Data Source / Catalyst Rule
4. 生产集群运维与调优
推荐:Learning Spark(O'Reilly)第二版
国内生态与工具
| 工具/平台 | 说明 |
|---|---|
| 阿里云 EMR | 托管 Spark 集群,支持 HDFS/OSS |
| 腾讯云 EMR | Spark 与 Hadoop 集成服务 |
| 华为 MRS | MapReduce Service,支持 Spark |
| Databricks | Spark 商业化平台,性能最优 |
| DolphinScheduler | 国产开源分布式工作流调度器 |
| Apache Kyuubi | 统一 SQL 网关,支持 Spark Thrift Server |
| Delta Lake | 数据湖存储层,支持 ACID 事务 |
| Apache Iceberg | 开放表格式,支持 Schema 演化 |