在大数据处理领域,Apache Spark 凭借内存计算、多场景兼容、高性能优化三大核心优势,已成为离线批处理、实时流计算、机器学习等场景的首选框架。相比 Hadoop MapReduce,Spark 计算速度提升 10-100 倍,且支持 Python、Java、Scala 等多语言开发,降低了大数据技术的使用门槛。
一、Spark 环境安装与配置(Linux/Windows 双平台)
Spark 运行依赖 JDK,支持本地模式、Standalone 集群模式、YARN 集群模式。本节提供 Linux 生产环境与 Windows 开发环境的完整配置方案,新手建议先从本地模式入手。
1.1 前置依赖准备
1.1.1 JDK 安装(必选)
Spark 3.x 推荐使用 JDK 8(兼容性最佳),避免 JDK 11+ 导致的组件兼容问题:
bash
# Linux 安装 JDK 8
yum install java-1.8.0-openjdk-devel -y
# 配置环境变量(/etc/profile)
echo "export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk" >> /etc/profile
echo "export PATH=$JAVA_HOME/bin:$PATH" >> /etc/profile
source /etc/profile
# 验证安装
java -version # 预期输出:java version "1.8.0_391"
1.1.2 Scala 安装(可选,Scala 开发者必备)
Spark 底层由 Scala 开发,若需编写 Scala 代码,需安装对应版本(Spark 3.5.1 对应 Scala 2.12):
bash
# 下载 Scala 2.12.18
wget https://downloads.lightbend.com/scala/2.12.18/scala-2.12.18.tgz
tar -zxvf scala-2.12.18.tgz -C /usr/local/
mv /usr/local/scala-2.12.18 /usr/local/scala
# 配置环境变量
echo "export SCALA_HOME=/usr/local/scala" >> /etc/profile
echo "export PATH=$SCALA_HOME/bin:$PATH" >> /etc/profile
source /etc/profile
# 验证安装
scala -version # 预期输出:Scala code runner version 2.12.18
1.1.3 Windows 开发环境配置(额外步骤)
- 下载 Spark 3.5.1 安装包:Spark 官网,选择
spark-3.5.1-bin-hadoop3.tgz - 解压到本地目录(如
D:\spark-3.5.1) - 配置环境变量:
SPARK_HOME=D:\spark-3.5.1PATH=%SPARK_HOME%\bin;%SPARK_HOME%\sbin
- 下载 winutils.exe(Hadoop Windows 兼容工具),放入
%SPARK_HOME%\bin,解决本地模式文件读写权限问题
1.2 本地模式安装(快速验证)
bash
# Linux 下载 Spark 3.5.1
wget https://archive.apache.org/dist/spark/spark-3.5.1/spark-3.5.1-bin-hadoop3.tgz
tar -zxvf spark-3.5.1-bin-hadoop3.tgz -C /usr/local/
mv /usr/local/spark-3.5.1-bin-hadoop3 /usr/local/spark
# 配置环境变量
echo "export SPARK_HOME=/usr/local/spark" >> /etc/profile
echo "export PATH=$SPARK_HOME/bin:$PATH" >> /etc/profile
source /etc/profile
# 验证安装(运行 Pi 计算示例)
spark-submit --class org.apache.spark.examples.SparkPi \
--master local[*] \
$SPARK_HOME/examples/jars/spark-examples_2.12-3.5.1.jar 10
# 预期输出:Pi is roughly 3.141592653589793
1.3 Standalone 集群模式配置(生产测试)
适合小规模集群(3-10 节点),无需依赖 Hadoop YARN,部署简单:
1.3.1 主节点(Master)配置
bash
# 复制配置模板
cp $SPARK_HOME/conf/spark-env.sh.template $SPARK_HOME/conf/spark-env.sh
cp $SPARK_HOME/conf/workers.template $SPARK_HOME/conf/workers
# 修改 spark-env.sh(添加 JDK 路径与 Master 地址)
echo "export JAVA_HOME=/usr/lib/jvm/java-1.8.0-openjdk" >> $SPARK_HOME/conf/spark-env.sh
echo "export SPARK_MASTER_HOST=192.168.1.100" >> $SPARK_HOME/conf/spark-env.sh
echo "export SPARK_MASTER_PORT=7077" >> $SPARK_HOME/conf/spark-env.sh # 默认端口
# 修改 workers 文件(添加从节点 IP/主机名)
echo "192.168.1.101" >> $SPARK_HOME/conf/workers
echo "192.168.1.102" >> $SPARK_HOME/conf/workers
1.3.2 从节点(Worker)配置
-
复制主节点的 Spark 目录到所有从节点:
bashscp -r /usr/local/spark root@192.168.1.101:/usr/local/ scp -r /usr/local/spark root@192.168.1.102:/usr/local/ -
配置从节点环境变量(同主节点)
1.3.3 启动集群
bash
# 主节点启动集群
$SPARK_HOME/sbin/start-all.sh
# 验证集群状态
jps # 主节点显示 Master,从节点显示 Worker
curl http://192.168.1.100:8080 # 访问 Web UI,查看集群节点状态
二、核心数据结构实战(RDD/DataFrame/Dataset 超详细案例)
Spark 提供三大核心数据结构,各有适用场景:
- RDD:底层抽象,灵活度高,适合非结构化数据处理
- DataFrame:带 Schema 的分布式表,支持 SQL/DSL 操作,性能最优
- Dataset:DataFrame + 类型安全,兼顾易用性与性能(Scala/Java 推荐)
2.1 初始化 Spark 上下文(统一入口)
2.1.1 Python(PySpark)初始化
python
from pyspark.sql import SparkSession
from pyspark import SparkContext
# 初始化 SparkSession(Spark 2.x+ 推荐)
spark = SparkSession.builder \
.appName("SparkCoreDemo") \
.master("local[*]") # 本地模式,集群模式删除此参数
.config("spark.driver.memory", "2g") # Driver 内存
.config("spark.executor.memory", "4g") # Executor 内存
.getOrCreate()
# 获取 SparkContext(操作 RDD 需用到)
sc = spark.sparkContext
# 设置日志级别(减少冗余输出)
sc.setLogLevel("WARN")
2.1.2 Scala 初始化
scala
import org.apache.spark.sql.SparkSession
object SparkInitDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("SparkCoreDemo")
.master("local[*]")
.config("spark.driver.memory", "2g")
.config("spark.executor.memory", "4g")
.getOrCreate()
val sc = spark.sparkContext
sc.setLogLevel("WARN")
// 关闭资源(避免内存泄漏)
spark.stop()
}
}
2.2 RDD 实战(底层操作,灵活度拉满)
RDD 支持转化操作(Transformation) 和行动操作(Action),转化操作惰性执行,行动操作触发计算。
2.2.1 RDD 创建方式(5 种常用)
python
# Python 示例
# 1. 从本地集合创建
rdd1 = sc.parallelize([(1, "a"), (2, "b"), (3, "c")], numSlices=2) # 指定分区数
# 2. 从文件创建(本地文件/DFS/HDFS)
rdd2 = sc.textFile("file:///root/data.txt") # 本地文件
rdd3 = sc.textFile("hdfs://192.168.1.100:9000/data/users.txt") # HDFS 文件
# 3. 从目录创建(读取目录下所有文件)
rdd4 = sc.wholeTextFiles("file:///root/data_dir") # (文件名, 文件内容) 键值对
# 4. 从其他 RDD 衍生
rdd5 = rdd1.map(lambda x: (x[0]*2, x[1].upper()))
# 5. 从数据库读取(通过 JDBC 间接创建)
jdbc_rdd = sc.parallelize(
spark.read.jdbc(
url="jdbc:mysql://localhost:3306/test",
table="user",
properties={"user": "root", "password": "123456"}
).collect()
)
2.2.2 RDD 核心转化操作(15+ 案例)
python
# 1. 映射操作:map(一对一)
rdd = sc.parallelize([1,2,3,4,5])
map_rdd = rdd.map(lambda x: x * 3) # 输出:[3,6,9,12,15]
# 2. 扁平化映射:flatMap(一对多)
text_rdd = sc.parallelize(["hello spark", "hello python"])
flatmap_rdd = text_rdd.flatMap(lambda x: x.split(" ")) # 输出:["hello", "spark", "hello", "python"]
# 3. 过滤操作:filter
filter_rdd = rdd.filter(lambda x: x % 2 == 0) # 输出:[2,4]
# 4. 去重操作:distinct
distinct_rdd = sc.parallelize([1,2,2,3,3,3]).distinct() # 输出:[1,2,3]
# 5. 排序操作:sortBy
sort_rdd = sc.parallelize([(2, "b"), (1, "a"), (3, "c")]).sortBy(lambda x: x[0]) # 按第一个元素升序
# 6. 聚合操作:groupByKey + reduceByKey
pair_rdd = sc.parallelize([("a", 1), ("a", 2), ("b", 3), ("b", 4)])
group_rdd = pair_rdd.groupByKey() # 分组:("a", [1,2]), ("b", [3,4])
reduce_rdd = pair_rdd.reduceByKey(lambda x, y: x + y) # 聚合:("a", 3), ("b", 7)
# 7. 连接操作:join(内连接)、leftOuterJoin(左外连接)
rdd_a = sc.parallelize([(1, "Alice"), (2, "Bob")])
rdd_b = sc.parallelize([(1, 23), (2, 25), (3, 22)])
join_rdd = rdd_a.join(rdd_b) # 输出:(1, ("Alice",23)), (2, ("Bob",25))
left_join_rdd = rdd_a.leftOuterJoin(rdd_b) # 输出:(1, ("Alice",23)), (2, ("Bob",25))
# 8. 分区操作:repartition(重分区,会 shuffle)、coalesce(合并分区,不 shuffle)
repartition_rdd = rdd.repartition(4) # 重分区为 4 个
coalesce_rdd = repartition_rdd.coalesce(2) # 合并为 2 个分区
2.2.3 RDD 核心行动操作(10+ 案例)
python
# 1. 收集结果:collect(小数据量使用)
result = map_rdd.collect() # 输出:[3,6,9,12,15]
# 2. 统计数量:count
count = rdd.count() # 输出:5
# 3. 统计求和:sum
sum_val = rdd.sum() # 输出:15
# 4. 取前 N 条:take
top3 = rdd.take(3) # 输出:[1,2,3]
# 5. 查找最值:max、min
max_val = rdd.max() # 输出:5
min_val = rdd.min() # 输出:1
# 6. 保存结果:saveAsTextFile、saveAsSequenceFile
map_rdd.saveAsTextFile("hdfs://192.168.1.100:9000/output/map_result")
pair_rdd.saveAsSequenceFile("hdfs://192.168.1.100:9000/output/sequence_result")
# 7. 遍历元素:foreach(分布式遍历,不回传 Driver)
map_rdd.foreach(lambda x: print(x)) # 每个元素在 Executor 端打印
# 8. 聚合计算:reduce
total = rdd.reduce(lambda x, y: x + y) # 输出:15
# 9. 统计信息:stats(返回计数、均值、标准差、最大值、最小值)
stats = rdd.stats()
print(f"计数:{stats.count}, 均值:{stats.mean}, 最大值:{stats.max}")
2.3 DataFrame 实战(结构化数据首选)
DataFrame 是带 Schema 的分布式表,支持 SQL 和 DSL 两种操作风格,Spark 优化器(Catalyst)会自动优化执行计划,性能优于 RDD。
2.3.1 DataFrame 创建方式(6 种常用)
python
# Python 示例
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
# 1. 从 RDD 转换(指定 Schema)
rdd = sc.parallelize([(1, "Alice", 23, 5000.0), (2, "Bob", 25, 6000.0)])
schema = StructType([
StructField("id", IntegerType(), nullable=False),
StructField("name", StringType(), nullable=True),
StructField("age", IntegerType(), nullable=True),
StructField("salary", DoubleType(), nullable=True)
])
df1 = spark.createDataFrame(rdd, schema=schema)
# 2. 从本地集合创建(自动推断 Schema)
data = [("Charlie", 22, "Engineer"), ("David", 24, "Designer")]
df2 = spark.createDataFrame(data, schema=["name", "age", "job"])
# 3. 从文件创建(CSV/JSON/Parquet/Excel)
# CSV 文件(带表头)
df_csv = spark.read \
.option("header", "true") \
.option("inferSchema", "true") \
.csv("file:///root/data/users.csv")
# JSON 文件
df_json = spark.read.json("hdfs://192.168.1.100:9000/data/employees.json")
# Parquet 文件(Spark 原生格式,压缩比高)
df_parquet = spark.read.parquet("file:///root/data/orders.parquet")
# 4. 从数据库读取(MySQL)
df_mysql = spark.read \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC") \
.option("dbtable", "employee") \
.option("user", "root") \
.option("password", "123456") \
.load()
# 5. 从 Hive 表读取
df_hive = spark.sql("SELECT * FROM test_db.user_info")
# 6. 手动创建(使用 toDF 方法)
df6 = rdd.toDF(["id", "name", "age", "salary"])
2.3.2 DataFrame DSL 风格操作(20+ 案例)
python
# 1. 查看数据:show(默认前 20 行)
df1.show() # 显示所有列
df1.show(10, truncate=False) # 显示前 10 行,不截断长文本
# 2. 查看 Schema:printSchema
df1.printSchema() # 输出字段名、类型、是否可为空
# 3. 选择列:select
df_select = df1.select("name", "age", "salary") # 选择指定列
df_select_expr = df1.selectExpr("name", "age + 1 as age_plus_1") # 表达式选择
# 4. 过滤数据:filter/where(两者等价)
df_filter1 = df1.filter(df1.age > 23) # 年龄大于 23
df_filter2 = df1.where("salary > 5500") # SQL 语法字符串
# 5. 排序数据:orderBy/sort
df_sort1 = df1.orderBy("age", ascending=False) # 年龄降序
df_sort2 = df1.sort(["salary", "age"], ascending=[True, False]) # 薪资升序,年龄降序
# 6. 分组聚合:groupBy + 聚合函数
from pyspark.sql.functions import avg, sum, count, max
df_group = df1.groupBy("job") \
.agg(
avg("salary").alias("avg_salary"),
sum("age").alias("total_age"),
count("id").alias("user_count")
)
# 7. 去重:distinct/dropDuplicates
df_distinct = df1.distinct() # 全列去重
df_drop_dup = df1.dropDuplicates(["name", "age"]) # 按指定列去重
# 8. 缺失值处理:na.fill/na.drop/na.replace
df_fill = df1.na.fill({"salary": 4000.0, "job": "Unknown"}) # 填充缺失值
df_drop = df1.na.drop(how="any", subset=["id", "name"]) # 删除指定列有缺失值的行
df_replace = df1.na.replace({"Engineer": "Eng"}, subset=["job"]) # 替换值
# 9. 新增列:withColumn
df_with_col = df1.withColumn("bonus", df1.salary * 0.1) # 新增奖金列(薪资 10%)
df_with_col2 = df1.withColumn("adult", df1.age >= 18) # 新增成年标识列
# 10. 删除列:drop
df_drop_col = df1.drop("salary") # 删除薪资列
# 11. 连接操作:join(内连接)、leftJoin(左连接)、rightJoin(右连接)
df_a = spark.createDataFrame([(1, "Alice"), (2, "Bob")], ["id", "name"])
df_b = spark.createDataFrame([(1, 23), (2, 25), (3, 22)], ["id", "age"])
df_join = df_a.join(df_b, on="id", how="inner") # 内连接
df_left_join = df_a.join(df_b, on="id", how="left") # 左连接
# 12. 窗口函数:row_number/rank/dense_rank(Top N 场景常用)
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
window_spec = Window.partitionBy("job").orderBy(df1.salary.desc())
df_window = df1.withColumn("rank", row_number().over(window_spec)) # 按职位分组,薪资降序排名
2.3.3 DataFrame SQL 风格操作(10+ 案例)
python
# 1. 创建临时视图(createOrReplaceTempView:会话级,仅当前 SparkSession 可用)
df1.createOrReplaceTempView("employee")
# 2. 创建全局临时视图(createGlobalTempView:应用级,所有 SparkSession 可用)
df1.createGlobalTempView("global_employee")
# 3. 基本查询
df_sql1 = spark.sql("SELECT name, age, salary FROM employee WHERE age > 23")
# 4. 聚合查询
df_sql2 = spark.sql("""
SELECT job, AVG(salary) as avg_salary, COUNT(id) as user_count
FROM employee
GROUP BY job
HAVING COUNT(id) > 1
""")
# 5. 排序查询
df_sql3 = spark.sql("SELECT * FROM employee ORDER BY salary DESC LIMIT 10")
# 6. 连接查询
df_sql4 = spark.sql("""
SELECT e.name, e.age, d.department
FROM employee e
LEFT JOIN department d ON e.dept_id = d.id
""")
# 7. 窗口函数查询
df_sql5 = spark.sql("""
SELECT name, job, salary,
ROW_NUMBER() OVER (PARTITION BY job ORDER BY salary DESC) as rank
FROM employee
""")
# 8. 全局临时视图查询(需指定 global_temp 库)
df_sql6 = spark.sql("SELECT * FROM global_temp.global_employee WHERE salary > 5000")
# 9. 插入数据(覆盖/追加)
spark.sql("INSERT OVERWRITE TABLE employee_backup SELECT * FROM employee") # 覆盖
spark.sql("INSERT INTO TABLE employee_backup SELECT * FROM employee") # 追加
# 10. 创建表并插入数据
spark.sql("""
CREATE TABLE IF NOT EXISTS high_salary_employee (
id INT,
name STRING,
age INT,
salary DOUBLE
)
USING PARQUET
""")
spark.sql("INSERT INTO high_salary_employee SELECT * FROM employee WHERE salary > 6000")
2.4 Dataset 实战(Scala 类型安全示例)
Dataset 是 Scala/Java 特有的数据结构,结合了 DataFrame 的优化和 RDD 的类型安全,适合生产环境开发:
scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
// 定义样例类(对应 Schema)
case class Employee(id: Int, name: String, age: Int, salary: Double, job: String)
object DatasetDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("DatasetDemo")
.master("local[*]")
.getOrCreate()
import spark.implicits._ // 隐式转换,用于 RDD 转 Dataset
// 1. 创建 Dataset(从 RDD 转换)
val rdd = spark.sparkContext.parallelize(Seq(
Employee(1, "Alice", 23, 5000.0, "Engineer"),
Employee(2, "Bob", 25, 6000.0, "Designer"),
Employee(3, "Charlie", 22, 5500.0, "Engineer")
))
val ds: org.apache.spark.sql.Dataset[Employee] = rdd.toDS()
// 2. 类型安全操作(编译时检查字段名)
val filteredDs = ds.filter(_.age > 22) // 年龄大于 22
val selectedDs = ds.select($"name", $"salary" * 0.1 as "bonus") // 新增奖金列
val groupedDs = ds.groupByKey(_.job).agg(avg($"salary").as[Double]) // 按职位分组求平均薪资
// 3. 转换为 DataFrame(如需 SQL 操作)
val df = ds.toDF()
df.createOrReplaceTempView("employee")
val sqlDs = spark.sql("SELECT * FROM employee WHERE salary > 5500").as[Employee]
// 4. 保存 Dataset
ds.write.mode("overwrite").parquet("file:///root/data/employee_ds.parquet")
spark.stop()
}
}
三、多数据源联动实战(MySQL/Hive/HDFS/Redis)
Spark 支持与多种数据源交互,实现数据导入导出,本节提供生产环境常用的 4 种数据源联动案例。
3.1 与 MySQL 交互(读写完整案例)
需提前下载 MySQL JDBC 驱动(mysql-connector-java-8.0.33.jar),放入 $SPARK_HOME/jars 目录。
3.1.1 Python 读写 MySQL
python
# 1. 从 MySQL 读取数据
df_read = spark.read \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC&allowPublicKeyRetrieval=true") \
.option("dbtable", "employee") # 表名,可写 SQL 如 "(SELECT * FROM employee WHERE age > 20) t"
.option("user", "root") \
.option("password", "123456") \
.option("driver", "com.mysql.cj.jdbc.Driver") \
.option("numPartitions", 4) # 读取分区数(并行读取,提升效率)
.option("partitionColumn", "id") # 分区字段(整数类型)
.option("lowerBound", 1) # 分区下界
.option("upperBound", 1000) # 分区上界
.load()
# 2. 向 MySQL 写入数据(4 种模式)
# 模式 1:overwrite(覆盖原有数据)
df_read.write \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC") \
.option("dbtable", "employee_backup") \
.option("user", "root") \
.option("password", "123456") \
.mode("overwrite") \
.save()
# 模式 2:append(追加数据)
df_read.write \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC") \
.option("dbtable", "employee_backup") \
.option("user", "root") \
.option("password", "123456") \
.mode("append") \
.save()
# 模式 3:ignore(数据存在则忽略)
df_read.write \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC") \
.option("dbtable", "employee_backup") \
.option("user", "root") \
.option("password", "123456") \
.mode("ignore") \
.save()
# 模式 4:errorifexists(数据存在则报错)
df_read.write \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC") \
.option("dbtable", "employee_backup") \
.option("user", "root") \
.option("password", "123456") \
.mode("errorifexists") \
.save()
3.1.2 Scala 读写 MySQL
scala
import org.apache.spark.sql.SparkSession
object MysqlDemo {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("MysqlDemo")
.master("local[*]")
.getOrCreate()
// 1. 从 MySQL 读取
val dfRead = spark.read
.format("jdbc")
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC")
.option("dbtable", "employee")
.option("user", "root")
.option("password", "123456")
.option("driver", "com.mysql.cj.jdbc.Driver")
.load()
// 2. 向 MySQL 写入(覆盖模式)
dfRead.write
.format("jdbc")
.option("url", "jdbc:mysql://localhost:3306/test?useSSL=false&serverTimezone=UTC")
.option("dbtable", "employee_backup")
.option("user", "root")
.option("password", "123456")
.mode("overwrite")
.save()
spark.stop()
}
}
3.2 与 Hive 集成(生产环境核心)
Spark 集成 Hive 需将 Hive 的 hive-site.xml 配置文件复制到 $SPARK_HOME/conf 目录,确保 Spark 能读取 Hive 元数据。
3.2.1 Python 集成 Hive
python
# 初始化 SparkSession 时启用 Hive 支持
spark = SparkSession.builder()
.appName("HiveDemo")
.master("local[*]")
.enableHiveSupport() # 启用 Hive 支持
.config("hive.metastore.uris", "thrift://192.168.1.100:9083") # Hive Metastore 地址
.getOrCreate()
# 1. 读取 Hive 表(3 种方式)
# 方式 1:SQL 语句
df_hive1 = spark.sql("SELECT * FROM test_db.employee WHERE age > 20")
# 方式 2:指定数据库和表名
df_hive2 = spark.table("test_db.employee")
# 方式 3:读取分区表(过滤分区字段,避免全表扫描)
df_hive3 = spark.sql("SELECT * FROM test_db.order_info WHERE dt = '2026-01-01'")
# 2. 向 Hive 表写入数据(3 种模式)
# 模式 1:保存为新表(管理表)
df_hive1.write \
.mode("overwrite") \
.saveAsTable("test_db.employee_backup") # 自动创建表
# 模式 2:追加到现有表
df_hive1.write \
.mode("append") \
.insertInto("test_db.employee_backup") # 需表结构一致
# 模式 3:保存为外部表(指定路径)
df_hive1.write \
.mode("overwrite") \
.option("path", "hdfs://192.168.1.100:9000/hive/employee_external") \
.saveAsTable("test_db.employee_external")
# 3. 创建 Hive 分区表并写入
spark.sql("""
CREATE TABLE IF NOT EXISTS test_db.user_partition (
id INT,
name STRING,
age INT
)
PARTITIONED BY (dt STRING) # 按日期分区
STORED AS PARQUET
""")
# 写入分区数据
df_hive1.write \
.mode("append") \
.partitionBy("dt") \
.saveAsTable("test_db.user_partition")
3.3 与 HDFS 交互(文件读写)
python
# 1. 读取 HDFS 文件(CSV/JSON/Parquet)
# 读取 CSV 文件
df_hdfs_csv = spark.read \
.option("header", "true") \
.csv("hdfs://192.168.1.100:9000/data/users.csv")
# 读取 Parquet 文件(推荐,压缩比高、速度快)
df_hdfs_parquet = spark.read.parquet("hdfs://192.168.1.100:9000/data/orders.parquet")
# 2. 写入 HDFS 文件(多种格式)
# 写入 CSV 文件
df_hdfs_csv.write \
.mode("overwrite") \
.option("header", "true") \
.csv("hdfs://192.168.1.100:9000/output/users_output")
# 写入 Parquet 文件
df_hdfs_parquet.write \
.mode("overwrite") \
.parquet("hdfs://192.168.1.100:9000/output/orders_output")
# 写入 JSON 文件
df_hdfs_parquet.write \
.mode("overwrite") \
.json("hdfs://192.168.1.100:9000/output/orders_json")
# 3. 读取 HDFS 目录下所有文件(包含子目录)
df_hdfs_dir = spark.read \
.option("recursiveFileLookup", "true") \
.parquet("hdfs://192.168.1.100:9000/data/")
3.4 与 Redis 交互(缓存读写)
需引入 Redis 依赖(spark-redis_2.12-3.1.0.jar),放入 $SPARK_HOME/jars 目录。
python
# 1. 从 Redis 读取数据(Key-Value 形式)
df_redis = spark.read \
.format("redis") \
.option("redis.host", "192.168.1.101") \
.option("redis.port", "6379") \
.option("redis.password", "123456") \
.option("redis.key.pattern", "user:*") # 匹配 Key 模式
.load()
# 2. 向 Redis 写入数据
df_redis.write \
.format("redis") \
.option("redis.host", "192.168.1.101") \
.option("redis.port", "6379") \
.option("redis.password", "123456") \
.option("redis.key.column", "id") # 以 id 列为 Key
.option("redis.value.columns", "name,age") # 存储的 Value 列
.mode("overwrite") \
.save()
四、生产环境性能优化(10 大核心策略 + 代码案例)
Spark 作业性能瓶颈多源于资源配置不合理、数据倾斜、序列化低效、Shuffle 过多等问题,本节提供可落地的优化方案与代码示例。
4.1 资源配置优化(最基础也最重要)
通过 spark-submit 参数或配置文件调整资源分配,避免资源浪费或不足:
bash
# 生产环境 spark-submit 最优配置示例(YARN 集群模式)
spark-submit \
--class com.example.SparkProductionJob \
--master yarn \
--deploy-mode cluster \
--executor-memory 8g \ # 每个 Executor 内存(建议 4-8g,避免 GC 耗时过长)
--executor-cores 4 \ # 每个 Executor 核心数(建议 2-4 核)
--num-executors 20 \ # Executor 数量(根据集群规模调整)
--driver-memory 4g \ # Driver 内存(生产环境建议 2-4g)
--driver-cores 2 \ # Driver 核心数
--conf spark.default.parallelism=160 \ # 默认并行度(总核心数的 2-3 倍:20*4*2=160)
--conf spark.sql.shuffle.partitions=160 \ # SQL Shuffle 分区数(与并行度一致)
--conf spark.executor.memoryOverhead=2g \ # 额外内存(堆外内存,避免 OOM)
--conf spark.driver.memoryOverhead=1g \
--jars /path/to/mysql-connector-java-8.0.33.jar \ # 依赖包
production-job.jar
4.2 数据倾斜优化(生产环境高频问题)
数据倾斜表现为部分任务执行时间极长(如几小时),其余任务快速完成,核心原因是 Key 分布不均。
4.2.1 高频 Key 拆分优化(Python 案例)
python
from pyspark.sql.functions import col, concat, lit, rand
# 场景:user_id=0 是高频 Key,导致数据倾斜
df = spark.read.parquet("hdfs:///data/user_behavior.parquet")
# 步骤 1:拆分高频 Key(给高频 Key 添加随机前缀 0-9)
def split_high_freq_key(df, key_col, high_freq_values):
# 标记高频 Key
df_marked = df.withColumn(
"is_high_freq",
col(key_col).isin(high_freq_values)
)
# 高频 Key 添加随机前缀
df_high_freq = df_marked.filter(col("is_high_freq")) \
.withColumn(key_col, concat(col(key_col), lit("_"), rand().cast("int").mod(10)))
# 普通 Key 保持不变
df_normal = df_marked.filter(~col("is_high_freq"))
# 合并
return df_high_freq.union(df_normal)
# 拆分高频 Key(假设 user_id=0 是高频)
df_split = split_high_freq_key(df, "user_id", [0])
# 步骤 2:进行聚合操作(此时数据已均匀分布)
df_agg = df_split.groupBy("user_id").count()
# 步骤 3:合并高频 Key 结果(去除随机前缀)
def merge_high_freq_key(df, key_col, high_freq_values):
# 提取高频 Key 并去除前缀
df_high_freq = df.filter(col(key_col).startswith(tuple([f"{v}_" for v in high_freq_values]))) \
.withColumn(key_col, split(col(key_col), "_").getItem(0)) \
.groupBy(key_col).sum("count").withColumnRenamed("sum(count)", "count")
# 普通 Key 结果
df_normal = df.filter(~col(key_col).startswith(tuple([f"{v}_" for v in high_freq_values])))
# 合并
return df_high_freq.union(df_normal)
# 最终结果
df_final = merge_high_freq_key(df_agg, "user_id", [0])
4.2.2 Broadcast Join 优化(小表广播)
当 Join 的两张表大小差异较大(小表 < 100MB),使用 Broadcast Join 避免大表 Shuffle:
python
from pyspark.sql.functions import broadcast
# 大表(10GB)
big_df = spark.read.parquet("hdfs:///data/big_table.parquet")
# 小表(50MB)
small_df = spark.read.parquet("hdfs:///data/small_table.parquet")
# Broadcast Join(将小表广播到所有 Executor)
join_df = big_df.join(broadcast(small_df), on="id", how="inner")
4.3 序列化优化(提升网络传输效率)
默认使用 Java 序列化,效率较低,生产环境建议改用 Kryo 序列化:
bash
# 1. 在 spark-env.sh 中配置 Kryo 序列化
echo "export SPARK_SERIALIZER=org.apache.spark.serializer.KryoSerializer" >> $SPARK_HOME/conf/spark-env.sh
echo "export SPARK_KRYO_CLASSES_TO_REGISTER=com.example.Employee,com.example.Order" >> $SPARK_HOME/conf/spark-env.sh
# 2. 或在 spark-submit 中指定
spark-submit \
--conf spark.serializer=org.apache.spark.serializer.KryoSerializer \
--conf spark.kryo.registrationRequired=true \
--conf spark.kryo.classesToRegister=com.example.Employee,com.example.Order \
job.jar
4.4 缓存优化(重复使用数据)
对重复使用的 RDD/DataFrame 进行缓存,避免重复计算:
python
# 1. RDD 缓存(优先使用 persist,灵活配置存储级别)
rdd = sc.parallelize([1,2,3,4,5])
rdd.persist(StorageLevel.MEMORY_AND_DISK) # 内存不足时写入磁盘(推荐)
# rdd.cache() # 等价于 MEMORY_ONLY,不推荐(易 OOM)
# 2. DataFrame 缓存
df = spark.read.parquet("hdfs:///data/orders.parquet")
df.cache() # DataFrame 缓存默认 MEMORY_AND_DISK
# 3. 使用后释放资源(避免内存泄漏)
rdd.unpersist()
df.unpersist()
4.5 SQL 优化(Catalyst 优化器最大化利用)
python
# 1. 开启谓词下推(默认开启,确保生效)
spark.conf.set("spark.sql.pushDownPredicate", "true")
# 2. 开启列剪枝(只读取需要的列)
df_sql = spark.sql("SELECT id, name FROM employee WHERE age > 20") # 避免 SELECT *
# 3. 分区表查询(过滤分区字段)
df_partition = spark.sql("SELECT * FROM order_partition WHERE dt = '2026-01-01'")
# 4. 避免 Cartesian Product(笛卡尔积)
# 错误示例:未指定 Join 条件
df_cartesian = df1.join(df2) # 笛卡尔积,数据量暴增
# 正确示例:指定 Join 条件
df_join = df1.join(df2, on="id")
# 5. 使用 CTE 或临时表简化复杂 SQL(提升可读性与优化效率)
df_cte = spark.sql("""
WITH temp AS (
SELECT id, name, age FROM employee WHERE age > 20
)
SELECT * FROM temp JOIN department ON temp.id = department.emp_id
""")
五、实战场景:用户行为分析(完整案例)
结合前文知识点,实现一个完整的用户行为分析场景,包含数据读取、清洗、分析、存储全流程。
5.1 需求说明
分析用户行为数据,统计:
- 每日活跃用户数(DAU)
- 各页面访问 Top 10
- 新用户注册转化率(注册后 24 小时内有访问行为)
5.2 Python 完整代码
python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct, count, rank, from_unixtime, to_date, datediff
from pyspark.sql.window import Window
# 1. 初始化 SparkSession
spark = SparkSession.builder()
.appName("UserBehaviorAnalysis")
.master("local[*]")
.enableHiveSupport()
.config("spark.default.parallelism", 100)
.config("spark.sql.shuffle.partitions", 100)
.getOrCreate()
# 2. 读取数据(HDFS 上的用户行为数据和注册数据)
# 用户行为数据(click_log:user_id, page_id, action_time, action_type)
click_df = spark.read \
.option("header", "true") \
.option("inferSchema", "true") \
.csv("hdfs://192.168.1.100:9000/data/click_log.csv")
# 用户注册数据(user_register:user_id, register_time)
register_df = spark.read \
.option("header", "true") \
.option("inferSchema", "true") \
.csv("hdfs://192.168.1.100:9000/data/user_register.csv")
# 3. 数据清洗(处理缺失值、格式转换)
# 转换时间格式(Unix 时间戳 → 日期)
click_df_clean = click_df \
.dropna(subset=["user_id", "page_id", "action_time"]) # 删除关键字段缺失值
.withColumn("action_date", to_date(from_unixtime(col("action_time")))) # 转换为日期
register_df_clean = register_df \
.dropna(subset=["user_id", "register_time"])
.withColumn("register_date", to_date(from_unixtime(col("register_time"))))
# 4. 业务分析
# 4.1 每日活跃用户数(DAU)
dau_df = click_df_clean \
.groupBy("action_date") \
.agg(countDistinct("user_id").alias("dau")) \
.orderBy("action_date")
# 4.2 各页面访问 Top 10
page_top10_df = click_df_clean \
.groupBy("page_id") \
.agg(count("user_id").alias("visit_count")) \
.orderBy(col("visit_count").desc()) \
.limit(10)
# 4.3 新用户注册转化率(注册后 24 小时内有访问)
# 步骤 1:关联注册数据和行为数据
user_behavior_register_df = click_df_clean.join(
register_df_clean,
on="user_id",
how="inner"
)
# 步骤 2:计算注册后访问时间差
conversion_df = user_behavior_register_df \
.withColumn("days_diff", datediff(col("action_date"), col("register_date"))) \
.filter(col("days_diff") == 0) # 注册当天有访问
# 步骤 3:统计转化率
conversion_rate_df = conversion_df \
.agg(
countDistinct("user_id").alias("converted_user"),
countDistinct(col("user_id").alias("total_register_user")) # 总注册用户
) \
.withColumn("conversion_rate", col("converted_user") / col("total_register_user"))
# 5. 结果存储(写入 Hive 表和 HDFS)
# 写入 Hive 表
dau_df.write.mode("overwrite").saveAsTable("test_db.dau_daily")
page_top10_df.write.mode("overwrite").saveAsTable("test_db.page_visit_top10")
conversion_rate_df.write.mode("overwrite").saveAsTable("test_db.user_conversion_rate")
# 写入 HDFS(Parquet 格式)
dau_df.write.mode("overwrite").parquet("hdfs://192.168.1.100:9000/output/dau_daily")
# 6. 显示结果
print("每日活跃用户数(DAU):")
dau_df.show()
print("各页面访问 Top 10:")
page_top10_df.show()
print("新用户注册转化率:")
conversion_rate_df.show()
# 7. 关闭资源
spark.stop()
六、避坑指南(10 大常见问题与解决方案)
-
本地模式文件读写权限问题(Windows):
- 解决方案:下载 winutils.exe 放入
%SPARK_HOME%\bin,并以管理员身份运行命令行。
- 解决方案:下载 winutils.exe 放入
-
MySQL 连接报错"Public Key Retrieval is not allowed":
- 解决方案:URL 中添加
allowPublicKeyRetrieval=true,如jdbc:mysql://localhost:3306/test?allowPublicKeyRetrieval=true。
- 解决方案:URL 中添加
-
数据倾斜导致任务超时:
- 解决方案:过滤无效 Key、拆分高频 Key、使用 Broadcast Join、调整分区数。
-
OOM(内存溢出):
- 解决方案:增加 Executor/Driver 内存、使用
persist(MEMORY_AND_DISK)替代cache()、避免collect()收集大量数据。
- 解决方案:增加 Executor/Driver 内存、使用
-
版本兼容性问题(Spark 与 Hadoop/Hive):
- 解决方案:Spark 3.x 对应 Hadoop 3.x、Hive 3.x,MySQL JDBC 驱动 8.x 对应 MySQL 8.x。
-
Shuffle 过程缓慢:
- 解决方案:调整
spark.sql.shuffle.partitions为合理值(100-200)、开启 Kryo 序列化、优化 Join 策略。
- 解决方案:调整
-
Hive 表无法读取:
- 解决方案:确保
hive-site.xml已复制到$SPARK_HOME/conf,Hive Metastore 服务已启动。
- 解决方案:确保
-
数据类型不匹配:
- 解决方案:读取数据时显式指定 Schema,避免
inferSchema=true导致的类型推断错误。
- 解决方案:读取数据时显式指定 Schema,避免
-
任务重试次数耗尽:
- 解决方案:调整
spark.task.maxFailures=8(默认 4 次),检查数据是否有脏数据,优化资源配置。
- 解决方案:调整
-
Driver 端压力过大:
- 解决方案:避免在 Driver 端执行大量计算,使用
take(n)或sample()替代collect(),增加 Driver 内存。
- 解决方案:避免在 Driver 端执行大量计算,使用
七、总结
本文基于 Spark 3.5.1 版本,提供了从环境搭建到生产优化的全流程实战指南,涵盖 RDD/DataFrame/Dataset 三大核心数据结构、多数据源联动、性能优化、实战场景等模块,新增了大量可直接运行的代码案例。
Spark 学习的核心是"理解抽象、熟练 API、优化实战":
- 入门阶段:掌握环境搭建和基础 API 操作,通过本地模式快速验证代码;
- 进阶阶段:深入理解数据倾斜、Shuffle 等核心概念,掌握性能优化技巧;
- 生产阶段:结合业务场景灵活选择数据结构和优化策略,确保作业稳定高效运行。