
📚 一、DataFrame核心概念
1.1 什么是DataFrame?
本质 :PySpark DataFrame是一个分布式、不可变、基于命名列的数据集合,类似于关系型数据库表或Pandas DataFrame,但底层是RDD的封装。
关键特性:
- 分布式处理:数据自动分区在集群节点上并行处理
- 惰性求值:转换操作不立即执行,而是构建执行计划
- Schema感知:每个DataFrame都有明确的数据类型定义
- 优化执行:通过Catalyst优化器和Tungsten引擎高效执行
1.2 与RDD、Pandas DataFrame对比
| 维度 | PySpark DataFrame | RDD | Pandas DataFrame |
|---|---|---|---|
| 数据处理 | 结构化,列式操作 | 非结构化,面向记录 | 单机结构化 |
| API风格 | 声明式(SQL风格) | 命令式(函数式) | 命令式/声明式混合 |
| 性能 | Catalyst优化,最高效 | 手动优化,中等 | 单机最快 |
| 使用场景 | 大规模ETL、分析 | 复杂非结构化处理 | 中小规模数据分析 |
🔧 二、DataFrame完整操作体系
2.1 创建DataFrame的7种方法
python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
import pyspark.sql.functions as F
spark = SparkSession.builder.appName("CompleteGuide").getOrCreate()
# 方法1:从列表创建(最常用)
data = [("Alice", 34, 55000.0), ("Bob", 45, 72000.0), ("Cathy", 28, 48000.0)]
df1 = spark.createDataFrame(data, ["name", "age", "salary"])
# 方法2:从RDD转换
rdd = spark.sparkContext.parallelize(data)
df2 = rdd.toDF(["name", "age", "salary"])
# 方法3:指定Schema创建(精确控制数据类型)
schema = StructType([
StructField("name", StringType(), nullable=False),
StructField("age", IntegerType(), nullable=True),
StructField("salary", DoubleType(), nullable=True)
])
df3 = spark.createDataFrame(data, schema=schema)
# 方法4:从Pandas DataFrame创建(交互调试)
import pandas as pd
pdf = pd.DataFrame(data, columns=["name", "age", "salary"])
df4 = spark.createDataFrame(pdf)
# 方法5:从字典列表创建
dict_list = [{"name": "Alice", "age": 34, "salary": 55000.0},
{"name": "Bob", "age": 45, "salary": 72000.0}]
df5 = spark.createDataFrame(dict_list)
# 方法6:通过SQL查询创建
df1.createOrReplaceTempView("people")
df6 = spark.sql("SELECT name, age FROM people WHERE age > 30")
# 方法7:从空Schema创建(用于流处理等场景)
empty_df = spark.createDataFrame([], schema)
2.2 数据读取与写入(支持20+格式)
python
# ============ 读取数据 ============
# 1. CSV文件(最常用)
df_csv = spark.read.csv("path/to/file.csv",
header=True, # 第一行作为列名
inferSchema=True, # 自动推断类型
sep=",", # 分隔符
quote='"', # 引号字符
escape='"', # 转义字符
nullValue="NA") # 空值表示
# 2. JSON文件(半结构化数据)
df_json = spark.read.json("path/to/file.json",
multiLine=True) # 多行JSON
# 3. Parquet文件(生产环境首选)
df_parquet = spark.read.parquet("path/to/file.parquet")
# 4. ORC文件
df_orc = spark.read.orc("path/to/file.orc")
# 5. 文本文件(每行作为字符串)
df_text = spark.read.text("path/to/file.txt")
# 6. 从JDBC数据库读取
df_jdbc = spark.read \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/db") \
.option("dbtable", "table_name") \
.option("user", "username") \
.option("password", "password") \
.load()
# 7. 读取多个文件/目录
df_multi = spark.read.csv(["file1.csv", "file2.csv", "file3.csv"], header=True)
# 8. 读取Hive表
df_hive = spark.sql("SELECT * FROM hive_table")
# ============ 写入数据 ============
# 1. 基本写入
df.write.csv("output_path.csv", header=True, mode="overwrite")
# 2. 写入模式详解
# mode="overwrite" # 覆盖(最常用)
# mode="append" # 追加
# mode="ignore" # 存在则跳过
# mode="error" # 存在则报错(默认)
# 3. 分区写入(性能优化关键!)
df.write \
.partitionBy("year", "month", "day") \ # 按多级分区
.parquet("output_path/")
# 4. 分桶写入(优化JOIN和查询)
df.write \
.bucketBy(100, "user_id") \ # 分成100个桶
.sortBy("timestamp") \ # 每个桶内排序
.mode("overwrite") \
.saveAsTable("bucketed_table")
# 5. 写入到JDBC
df.write \
.format("jdbc") \
.option("url", "jdbc:mysql://localhost:3306/db") \
.option("dbtable", "new_table") \
.option("user", "username") \
.option("password", "password") \
.mode("append") \
.save()
2.3 数据查看与探查
python
# ============ 基础查看 ============
df.show(10, truncate=False) # 显示10行,不截断长文本
df.show(vertical=True) # 垂直格式显示,适合宽表
df.printSchema() # 打印Schema树状图
df.columns # 获取列名列表
df.dtypes # 获取(列名, 类型)列表
df.schema # 获取完整Schema对象
df.count() # 行数统计
# ============ 数据抽样 ============
df.sample(withReplacement=False, fraction=0.1, seed=42).show() # 无放回抽样10%
df.sampleBy("category", fractions={"A": 0.1, "B": 0.2}, seed=42) # 分层抽样
# ============ 统计信息 ============
df.describe().show() # 数值列基本统计
df.describe("age", "salary").show() # 指定列统计
df.summary().show() # 更详细统计(包括百分位数)
# 自定义统计
df.select(
F.mean("salary").alias("avg_salary"),
F.stddev("salary").alias("std_salary"),
F.skewness("salary").alias("skewness"),
F.kurtosis("salary").alias("kurtosis")
).show()
# ============ 空值检查 ============
from pyspark.sql.functions import isnan, isnull
# 统计每列空值数量
null_counts = df.select([F.sum(F.col(c).isNull().cast("int")).alias(c) for c in df.columns])
null_counts.show()
# 检查NaN值(浮点数特有)
nan_counts = df.select([F.sum(isnan(F.col(c)).cast("int")).alias(c) for c in df.columns])
nan_counts.show()
# ============ 唯一值与频次 ============
df.select("category").distinct().count() # 唯一值数量
df.select("category").distinct().show() # 显示所有唯一值
df.groupBy("category").count().orderBy("count", ascending=False).show() # 频次统计
2.4 列操作完整指南
python
# ============ 选择列 ============
# 多种选择方式
df.select("name", "age").show()
df.select(df.name, df["age"]).show()
df.select(F.col("name"), F.col("age") + 10).show()
df.select("*").show() # 所有列
df.select(df.columns[:3]) # 前3列
# 使用列表推导选择列
numeric_cols = [c for c, t in df.dtypes if t in ('int', 'double', 'float', 'bigint', 'smallint', 'tinyint')]
df.select(numeric_cols).show()
# ============ 新增/修改列 ============
# 1. 基本新增列
df = df.withColumn("salary_k", df.salary / 1000)
# 2. 条件赋值(SQL CASE WHEN)
df = df.withColumn("age_group",
F.when(df.age < 20, "少年")
.when(df.age < 40, "青年")
.when(df.age < 60, "中年")
.otherwise("老年"))
# 3. 多列操作
df = df.withColumn("full_name", F.concat(df.first_name, F.lit(" "), df.last_name))
# 4. 使用UDF(用户定义函数)
from pyspark.sql.types import StringType
# 注册UDF
@F.udf(returnType=StringType())
def categorize_salary(salary):
if salary < 50000: return "低"
elif salary < 100000: return "中"
else: return "高"
df = df.withColumn("salary_level", categorize_salary(df.salary))
# 5. 批量添加列
new_columns = {
"age_squared": df.age * df.age,
"log_salary": F.log(df.salary + 1),
"normalized_age": (df.age - F.mean(df.age)) / F.stddev(df.age)
}
for col_name, expr in new_columns.items():
df = df.withColumn(col_name, expr)
# ============ 重命名列 ============
df = df.withColumnRenamed("old_name", "new_name")
# 批量重命名(使用字典)
rename_dict = {"old1": "new1", "old2": "new2"}
for old, new in rename_dict.items():
df = df.withColumnRenamed(old, new)
# ============ 删除列 ============
df = df.drop("column_to_remove")
df = df.drop("col1", "col2", "col3") # 删除多列
# 删除多个(使用列表)
columns_to_drop = ["col1", "col2", "col3"]
df = df.drop(*columns_to_drop)
# 删除空值列
df = df.dropna(how="all", subset=["column_name"]) # 该列全为空则删除行
2.5 行操作与过滤
python
# ============ 条件过滤 ============
# 基本过滤
df.filter(df.age > 30).show()
df.where(df.salary >= 50000).show()
# 多条件组合
df.filter((df.age > 25) & (df.salary < 70000)).show() # AND
df.filter((df.age < 25) | (df.age > 60)).show() # OR
df.filter(~(df.age > 30)).show() # NOT
# 字符串匹配
df.filter(df.name.startswith("A")).show()
df.filter(df.name.endswith("son")).show()
df.filter(df.name.contains("li")).show()
df.filter(df.name.like("A%")).show() # SQL LIKE模式
# 空值过滤
df.filter(df.column.isNotNull()).show()
df.filter(df.column.isNull()).show()
# 列表匹配
df.filter(df.category.isin(["A", "B", "C"])).show()
# 正则表达式匹配
df.filter(df.email.rlike(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")).show()
# ============ 去重操作 ============
df.distinct().show() # 全行去重
df.dropDuplicates().show() # 同distinct()
# 基于子集去重
df.dropDuplicates(["name", "age"]).show()
# 保留第一条或最后一条(使用窗口函数)
from pyspark.sql.window import Window
window_spec = Window.partitionBy("user_id").orderBy(F.desc("timestamp"))
df.withColumn("row_num", F.row_number().over(window_spec)) \
.filter(F.col("row_num") == 1) \
.drop("row_num")
# ============ 行限制与抽样 ============
df.limit(100).show() # 前100行
# 随机抽样(带权重)
df.sample(withReplacement=False, fraction=0.01, seed=42) # 1%无放回抽样
df.sample(withReplacement=True, fraction=0.1) # 10%有放回抽样
# ============ 删除空行 ============
df.dropna().show() # 删除任何列包含空值的行
df.dropna(how="all").show() # 删除所有列都为空的行
df.dropna(subset=["name", "age"]).show() # 删除指定列为空的行
# ============ 填充空值 ============
# 统一填充
df.fillna({"age": 0, "salary": 0}).show()
# 使用统计值填充
mean_age = df.select(F.mean(df.age)).collect()[0][0]
df.fillna({"age": mean_age}).show()
# 前向/后向填充(需要窗口函数)
window_spec = Window.orderBy("timestamp").rowsBetween(-1, -1)
df.withColumn("filled_value", F.last(df.value, ignorenulls=True).over(window_spec))
2.6 排序与分页
python
# ============ 基本排序 ============
df.orderBy("age").show() # 升序(默认)
df.orderBy(df.age.desc()).show() # 降序
df.sort("salary", ascending=False).show() # sort是orderBy的别名
# 多列排序
df.orderBy(F.col("dept").asc(), F.col("salary").desc()).show()
# 控制NULL值位置
df.orderBy(F.col("age").asc_nulls_first()).show() # NULL在最前
df.orderBy(F.col("age").desc_nulls_last()).show() # NULL在最后
# ============ 高级排序技巧 ============
# 1. 自定义排序规则
from pyspark.sql.types import IntegerType
# 定义优先级映射
priority_map = {"High": 1, "Medium": 2, "Low": 3}
priority_udf = F.udf(lambda x: priority_map.get(x, 999), IntegerType())
df.withColumn("priority_num", priority_udf(df.priority)) \
.orderBy("priority_num") \
.drop("priority_num")
# 2. 随机排序(洗牌)
df.orderBy(F.rand(seed=42)).show()
# 3. 按字符串长度排序
df.orderBy(F.length(df.name).desc()).show()
# ============ 分页查询 ============
# 使用limit和offset模拟
page_size = 10
page_number = 3 # 第4页(0-based)
# 方法1:使用row_number窗口函数
window_spec = Window.orderBy("user_id")
df.withColumn("row_num", F.row_number().over(window_spec)) \
.filter((F.col("row_num") > page_number * page_size) &
(F.col("row_num") <= (page_number + 1) * page_size)) \
.drop("row_num")
# 方法2:使用排序+limit+offset(Spark 3.4+)
# df.orderBy("id").limit(page_size).offset(page_number * page_size)
2.7 聚合与分组操作
python
# ============ 基础聚合 ============
# 全局聚合(无分组)
df.agg(F.sum("salary").alias("total_salary"),
F.avg("age").alias("average_age"),
F.count("*").alias("total_rows")).show()
# ============ 分组聚合 ============
# 1. 基本分组
df.groupBy("department").agg(
F.count("*").alias("emp_count"),
F.avg("salary").alias("avg_salary"),
F.max("salary").alias("max_salary"),
F.min("salary").alias("min_salary"),
F.stddev("salary").alias("salary_std")
).show()
# 2. 多列分组
df.groupBy("department", "gender").agg(
F.count("*").alias("count"),
F.avg("salary").alias("avg_salary")
).orderBy("department", "gender").show()
# 3. 多个聚合函数
agg_exprs = {
"salary": ["mean", "stddev", "min", "max"],
"age": ["mean", "count"]
}
df.groupBy("department").agg(agg_exprs).show()
# ============ 高级聚合函数 ============
# 1. 收集列表
df.groupBy("department").agg(
F.collect_list("name").alias("all_names"),
F.collect_set("name").alias("unique_names")
).show()
# 2. 拼接字符串
df.groupBy("department").agg(
F.concat_ws(", ", F.collect_list("name")).alias("name_list")
).show()
# 3. 百分位数
df.groupBy("department").agg(
F.expr("percentile(salary, array(0.25, 0.5, 0.75))").alias("salary_percentiles")
).show()
# 4. 协方差与相关性
df.groupBy("department").agg(
F.corr("age", "salary").alias("age_salary_corr"),
F.covar_pop("age", "salary").alias("age_salary_covar")
).show()
# ============ 条件聚合 ============
# 1. 统计满足条件的行数
df.groupBy("department").agg(
F.sum((df.salary > 50000).cast("int")).alias("high_salary_count"),
F.avg(F.when(df.age > 30, df.salary).otherwise(None)).alias("avg_salary_over_30")
).show()
# 2. 多条件统计
df.agg(
F.count(F.when(df.salary > 50000, 1)).alias("count_high_salary"),
F.count(F.when((df.salary > 50000) & (df.age < 40), 1)).alias("count_young_high_salary")
).show()
# ============ 滚动窗口聚合 ============
# 按时间窗口聚合(需要时间戳列)
window_spec = Window.partitionBy("user_id").orderBy("timestamp").rowsBetween(-6, 0)
df.withColumn("last_7_actions",
F.collect_list("action").over(window_spec)) \
.withColumn("action_count",
F.count("*").over(window_spec))
2.8 数据连接(JOIN)操作
python
# ============ JOIN类型详解 ============
df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")], ["id", "value1"])
df2 = spark.createDataFrame([(1, "X"), (2, "Y"), (4, "Z")], ["id", "value2"])
# 1. INNER JOIN(默认)
df1.join(df2, "id").show() # 使用同名列
df1.join(df2, df1.id == df2.id, "inner").show() # 指定条件
df1.join(df2, ["id"], "inner").show() # 多列连接
# 2. LEFT JOIN
df1.join(df2, "id", "left").show()
df1.join(df2, "id", "leftouter").show() # leftouter是left的别名
# 3. RIGHT JOIN
df1.join(df2, "id", "right").show()
# 4. FULL OUTER JOIN
df1.join(df2, "id", "full").show()
df1.join(df2, "id", "fullouter").show() # 别名
# 5. LEFT SEMI JOIN(只保留左表数据,相当于WHERE EXISTS)
df1.join(df2, "id", "leftsemi").show()
# 6. LEFT ANTI JOIN(只保留左表数据,相当于WHERE NOT EXISTS)
df1.join(df2, "id", "leftanti").show()
# 7. CROSS JOIN
df1.crossJoin(df2).show()
# ============ JOIN条件高级用法 ============
# 1. 多条件JOIN
df1.join(df2, (df1.id == df2.id) & (df1.category == df2.category))
# 2. 不等值JOIN
df1.join(df2, df1.start_time <= df2.end_time)
# 3. 广播小表(优化JOIN性能)
from pyspark.sql.functions import broadcast
df1.join(broadcast(df2), "id") # 将df2广播到所有节点
# 4. 处理重复列名
result = df1.join(df2, "id") # 同名列只会保留一个
# 不同列名时的重命名
df1.join(df2, df1.id == df2.user_id) \
.select(df1["*"], df2.value.alias("user_value"))
# ============ 复杂JOIN模式 ============
# 1. 自连接
df.alias("a").join(df.alias("b"), F.col("a.manager_id") == F.col("b.employee_id"))
# 2. 多个表连接
df1.join(df2, "id") \
.join(df3, "id") \
.join(df4, "id")
# 3. 使用SQL进行复杂连接
df1.createOrReplaceTempView("t1")
df2.createOrReplaceTempView("t2")
df3.createOrReplaceTempView("t3")
complex_join = spark.sql("""
SELECT t1.*, t2.value2, t3.value3
FROM t1
LEFT JOIN t2 ON t1.id = t2.id
INNER JOIN t3 ON t1.id = t3.id AND t2.category = t3.category
""")
🚀 三、高级功能与性能优化
3.1 窗口函数(Window Functions)
python
# ============ 窗口函数基础 ============
from pyspark.sql.window import Window
# 1. 定义窗口规范
window_spec = Window.partitionBy("department").orderBy("salary")
# 2. 排名函数
df.withColumn("row_number", F.row_number().over(window_spec)) \
.withColumn("rank", F.rank().over(window_spec)) \
.withColumn("dense_rank", F.dense_rank().over(window_spec)) \
.withColumn("percent_rank", F.percent_rank().over(window_spec)) \
.withColumn("ntile", F.ntile(4).over(window_spec)) # 分成4组
# 3. 分析函数
window_spec_agg = Window.partitionBy("department")
df.withColumn("dept_avg_salary", F.avg("salary").over(window_spec_agg)) \
.withColumn("dept_max_salary", F.max("salary").over(window_spec_agg)) \
.withColumn("dept_total", F.sum("salary").over(window_spec_agg)) \
.withColumn("salary_rank_in_dept",
F.row_number().over(window_spec.orderBy(F.desc("salary"))))
# ============ 滚动窗口 ============
# 1. 基于行数的窗口
rows_window = Window.partitionBy("user_id") \
.orderBy("timestamp") \
.rowsBetween(-2, 0) # 当前行和前2行
df.withColumn("last_3_avg", F.avg("value").over(rows_window))
# 2. 基于范围的窗口(时间间隔)
range_window = Window.partitionBy("user_id") \
.orderBy("timestamp") \
.rangeBetween(-3600, 0) # 过去1小时
# ============ 高级窗口技巧 ============
# 1. 累计计算
window_cumulative = Window.partitionBy("user_id") \
.orderBy("date") \
.rowsBetween(Window.unboundedPreceding, 0)
df.withColumn("cumulative_sum", F.sum("amount").over(window_cumulative)) \
.withColumn("cumulative_count", F.count("*").over(window_cumulative))
# 2. 向前/向后引用
window_lag_lead = Window.partitionBy("user_id").orderBy("timestamp")
df.withColumn("prev_value", F.lag("value", 1).over(window_lag_lead)) \
.withColumn("next_value", F.lead("value", 1).over(window_lag_lead))
# 3. 第一个/最后一个值
df.withColumn("first_in_group",
F.first("value", ignorenulls=True).over(window_spec)) \
.withColumn("last_in_group",
F.last("value", ignorenulls=True).over(window_spec))
3.2 性能优化策略
python
# ============ 配置优化参数 ============
spark.conf.set("spark.sql.shuffle.partitions", "200") # 调整shuffle分区数
spark.conf.set("spark.sql.adaptive.enabled", "true") # 启用自适应查询
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
# ============ 数据缓存策略 ============
# 1. 缓存级别
df.persist() # 默认 MEMORY_AND_DISK
df.cache() # 等同于 persist()
# 明确指定存储级别
from pyspark import StorageLevel
df.persist(StorageLevel.MEMORY_ONLY) # 只存内存
df.persist(StorageLevel.MEMORY_AND_DISK) # 内存存不下时存磁盘
df.persist(StorageLevel.DISK_ONLY) # 只存磁盘
df.persist(StorageLevel.MEMORY_ONLY_SER) # 序列化后存内存
df.persist(StorageLevel.OFF_HEAP) # 堆外内存
# 2. 检查缓存状态
spark.catalog.isCached("table_name")
spark.catalog.cacheTable("table_name")
spark.catalog.clearCache()
# 3. 何时缓存/释放
# 缓存:会被多次使用的中间结果
# 释放:不再需要时调用 df.unpersist()
# ============ 分区与分桶优化 ============
# 1. 重新分区(触发shuffle)
df_repartitioned = df.repartition(100) # 指定分区数
df_repartitioned = df.repartition("date") # 按列分区
df_repartitioned = df.repartition(100, "date", "category") # 组合
# 2. 合并分区(无shuffle,减少分区数)
df_coalesced = df.coalesce(10) # 合并到10个分区
# 3. 最佳分区大小:每个分区128MB左右最佳
optimal_partitions = df_size_in_mb / 128
df_repartitioned = df.repartition(int(optimal_partitions))
# ============ 广播JOIN ============
# 自动广播(小表<10MB默认广播)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760") # 10MB
# 手动广播
small_df = df2.filter(df2.size < 1000)
df1.join(broadcast(small_df), "id")
# ============ 数据倾斜处理 ============
# 1. 识别倾斜键
df.groupBy("join_key").count().orderBy(F.desc("count")).show(10)
# 2. 解决方案:加盐(salting)
salt_df = df.withColumn("salt", (F.rand() * 100).cast("int"))
# 然后进行JOIN操作
# 3. 拆分大key
large_keys = df.groupBy("key").count().filter("count > 10000").select("key")
# 对大key单独处理
3.3 调试与问题排查
python
# ============ 执行计划分析 ============
# 1. 查看逻辑计划
df.explain() # 简单计划
df.explain(extended=True) # 详细计划
df.explain(mode="cost") # 成本信息(如果可用)
df.explain(mode="formatted") # 格式化输出
# 2. 查看物理计划
df._jdf.queryExecution().executedPlan()
# ============ 监控与日志 ============
# 1. 获取任务ID
query = df.write.saveAsTable("output_table")
query.id # 查询ID
# 2. 查看Spark UI
# 本地访问:http://localhost:4040
# 查看Stage详情、任务时间、数据倾斜等
# ============ 性能瓶颈识别 ============
# 1. 检查数据倾斜
df.rdd.glom().map(len).collect() # 查看每个分区数据量
# 2. 检查序列化/反序列化成本
# 使用Kryo序列化
spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
# 3. 内存使用分析
df.storageLevel # 查看存储级别
📊 四、实战应用模式
4.1 数据清洗管道
python
def data_cleaning_pipeline(df):
"""完整的数据清洗管道"""
# 阶段1:Schema验证与修复
expected_schema = StructType([
StructField("user_id", IntegerType(), nullable=False),
StructField("name", StringType(), nullable=True),
StructField("age", IntegerType(), nullable=True),
StructField("email", StringType(), nullable=True),
StructField("signup_date", DateType(), nullable=True)
])
# 验证并修复Schema
for field in expected_schema:
if field.name not in df.columns:
df = df.withColumn(field.name, F.lit(None).cast(field.dataType))
# 阶段2:处理缺失值
# 数值列:用中位数填充
numeric_cols = [c for c, t in df.dtypes if t in ('int', 'double', 'float')]
for col in numeric_cols:
median_val = df.approxQuantile(col, [0.5], 0.01)[0]
df = df.fillna({col: median_val})
# 字符串列:用众数填充
string_cols = [c for c, t in df.dtypes if t == 'string']
for col in string_cols:
mode_row = df.groupBy(col).count().orderBy(F.desc("count")).first()
if mode_row:
df = df.fillna({col: mode_row[0]})
# 阶段3:异常值处理(IQR方法)
for col in numeric_cols:
q1, q3 = df.approxQuantile(col, [0.25, 0.75], 0.01)
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
df = df.withColumn(
col,
F.when(F.col(col) < lower_bound, lower_bound)
.when(F.col(col) > upper_bound, upper_bound)
.otherwise(F.col(col))
)
# 阶段4:数据标准化
from pyspark.ml.feature import StandardScaler, VectorAssembler
assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features")
df_vector = assembler.transform(df)
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
scaler_model = scaler.fit(df_vector)
df_scaled = scaler_model.transform(df_vector)
# 阶段5:重复数据删除
df_final = df_scaled.dropDuplicates(["user_id"])
# 阶段6:数据质量报告
total_rows = df_final.count()
report = {
"original_rows": df.count(),
"final_rows": total_rows,
"duplicates_removed": df.count() - total_rows,
"columns_processed": len(df.columns),
"missing_values_filled": sum([df.filter(F.col(c).isNull()).count()
for c in df.columns])
}
return df_final, report
4.2 特征工程模板
python
def create_features(df, user_col="user_id", timestamp_col="timestamp"):
"""创建推荐系统特征"""
from pyspark.ml.feature import Bucketizer
features_df = df
# 1. 时间特征
features_df = features_df.withColumn("hour", F.hour(timestamp_col))
features_df = features_df.withColumn("day_of_week", F.dayofweek(timestamp_col))
features_df = features_df.withColumn("is_weekend",
F.when(F.dayofweek(timestamp_col).isin([1, 7]), 1).otherwise(0))
# 2. 历史行为统计特征
window_7d = Window.partitionBy(user_col) \
.orderBy(timestamp_col) \
.rangeBetween(-7*86400, 0) # 7天窗口
features_df = features_df.withColumn(
"cnt_7d",
F.count("*").over(window_7d)
).withColumn(
"avg_amount_7d",
F.avg("amount").over(window_7d)
)
# 3. 序列特征(最近N次行为)
window_seq = Window.partitionBy(user_col) \
.orderBy(F.desc(timestamp_col))
features_df = features_df.withColumn(
"recent_actions",
F.collect_list("action").over(window_seq.rowsBetween(0, 9)) # 最近10次
).withColumn(
"recent_categories",
F.collect_set("category").over(window_seq.rowsBetween(0, 49)) # 最近50次
)
# 4. 交叉特征
features_df = features_df.withColumn(
"hour_category",
F.concat(F.col("hour").cast("string"), F.lit("_"), F.col("category"))
)
# 5. 分桶特征
bucketizer = Bucketizer(splits=[0, 18, 25, 35, 50, 65, 100],
inputCol="age",
outputCol="age_bucket")
features_df = bucketizer.transform(features_df)
# 6. 目标编码(以用户为例)
user_stats = df.groupBy(user_col).agg(
F.mean("target").alias("user_target_mean"),
F.stddev("target").alias("user_target_std"),
F.count("*").alias("user_total_actions")
)
features_df = features_df.join(user_stats, user_col, "left")
# 7. 特征选择:计算特征重要性(与目标的相关性)
numeric_features = [c for c, t in features_df.dtypes
if t in ('int', 'double', 'float') and c != "target"]
correlations = []
for feature in numeric_features:
corr = features_df.stat.corr(feature, "target")
correlations.append((feature, corr))
# 选择相关性最高的特征
significant_features = [f for f, c in sorted(correlations, key=lambda x: abs(x[1]), reverse=True)[:50]]
return features_df.select([user_col, timestamp_col] + significant_features)
🎯 五、最佳实践总结
5.1 Do's and Don'ts
| 最佳实践 | 反模式 |
|---|---|
✅ 使用.select()明确指定所需列 |
❌ 使用SELECT *查询所有列 |
✅ 尽早使用.filter()减少数据量 |
❌ 在转换链的最后进行过滤 |
✅ 对常用中间结果使用.cache() |
❌ 缓存所有中间结果(浪费内存) |
✅ 使用.repartition()优化大表 |
❌ 过度分区(小文件问题) |
| ✅ 使用Parquet/ORC格式存储 | ❌ 使用CSV/JSON存储大规模数据 |
| ✅ 利用窗口函数替代自连接 | ❌ 使用复杂子查询和自连接 |
| ✅ 广播小表优化JOIN性能 | ❌ 让Spark自动处理所有JOIN |
5.2 性能检查清单
-
分区检查:
- 分区大小是否接近128MB?
- 是否有数据倾斜(某些分区特别大)?
- 分区键是否合理(高基数,均匀分布)?
-
缓存策略:
- 是否缓存了会被多次使用的DataFrame?
- 缓存级别是否合适(MEMORY_AND_DISK vs MEMORY_ONLY)?
- 是否及时清理了不再需要的缓存?
-
JOIN优化:
- 是否使用了广播JOIN处理小表?
- JOIN键是否有数据倾斜?
- 是否可以考虑使用广播变量?
-
配置调优:
spark.sql.shuffle.partitions设置是否合理?- 是否启用了自适应查询执行?
- 序列化器是否配置为Kryo?
5.3 故障排除指南
python
# 常见问题与解决方案
problems_solutions = {
"内存不足": [
"增加executor内存: --executor-memory 8g",
"使用序列化缓存: df.persist(StorageLevel.MEMORY_ONLY_SER)",
"减少分区数: df.coalesce(100)",
"增加堆外内存: spark.memory.offHeap.enabled=true"
],
"数据倾斜": [
"识别倾斜键: df.groupBy(key).count().orderBy(desc('count'))",
"加盐处理: df.withColumn('salt', rand()*N)",
"拆分倾斜键: 单独处理大key",
"使用广播JOIN避免shuffle"
],
"小文件过多": [
"写入前合并: df.coalesce(N).write.parquet(...)",
"使用repartition: df.repartition('date').write...",
"设置合并参数: spark.conf.set('spark.sql.adaptive.enabled', true)"
],
"JOIN性能慢": [
"检查是否有小表可以广播",
"确保JOIN键类型一致",
"考虑预分区: 按JOIN键分区",
"使用分桶表: .bucketBy(100, 'join_key')"
]
}
📈 六、进阶资源与扩展
6.1 性能监控指标
| 指标 | 说明 | 理想范围 |
|---|---|---|
| 任务时间 | Stage执行时间 | 无统一标准,越短越好 |
| Shuffle读写 | Shuffle数据量 | 尽量减少 |
| GC时间 | 垃圾回收时间占比 | < 10% |
| 序列化时间 | 序列化/反序列化时间 | < 5% |
| 数据倾斜度 | 最大/最小分区数据量比 | < 2:1 |
6.2 推荐学习路径
-
初学者阶段(1-2周):
- 掌握DataFrame创建和基本操作
- 学习常用转换和行动操作
- 理解惰性求值原理
-
中级阶段(1个月):
- 精通窗口函数和复杂JOIN
- 掌握性能调优技巧
- 学习结构化流处理
-
高级阶段(2-3个月):
- 深入理解Catalyst优化器
- 掌握自定义优化规则
- 学习Delta Lake等高级特性
6.3 生产环境代码模板
python
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import logging
class SparkDataFrameProcessor:
"""生产环境DataFrame处理器模板"""
def __init__(self, app_name="DataProcessor", master="yarn"):
self.spark = self._create_spark_session(app_name, master)
self.logger = self._setup_logging()
def _create_spark_session(self, app_name, master):
"""创建配置优化的SparkSession"""
return SparkSession.builder \
.appName(app_name) \
.master(master) \
.config("spark.sql.shuffle.partitions", "200") \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.autoBroadcastJoinThreshold", "10485760") \
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
.enableHiveSupport() \
.getOrCreate()
def _setup_logging(self):
"""配置日志"""
logging.basicConfig(level=logging.INFO)
return logging.getLogger(__name__)
def process_data(self, input_path, output_path):
"""数据处理主流程"""
try:
# 1. 读取数据
self.logger.info(f"读取数据: {input_path}")
df = self.spark.read.parquet(input_path)
# 2. 数据质量检查
self._data_quality_check(df)
# 3. 数据处理
processed_df = self._transform_data(df)
# 4. 结果验证
self._validate_results(processed_df)
# 5. 写入输出
self.logger.info(f"写入结果: {output_path}")
processed_df.write \
.mode("overwrite") \
.partitionBy("date") \
.parquet(output_path)
self.logger.info("处理完成!")
return True
except Exception as e:
self.logger.error(f"处理失败: {str(e)}")
raise
def _data_quality_check(self, df):
"""数据质量检查"""
total_rows = df.count()
if total_rows == 0:
raise ValueError("输入数据为空")
# 检查关键字段
required_columns = ["user_id", "timestamp", "event_type"]
for col in required_columns:
if col not in df.columns:
raise ValueError(f"缺少必要字段: {col}")
null_counts = {col: df.filter(col(col).isNull()).count()
for col in df.columns}
self.logger.info(f"数据质量报告: 总行数={total_rows}, 空值统计={null_counts}")
def _transform_data(self, df):
"""数据转换逻辑"""
# 这里添加具体的业务逻辑
return df
def _validate_results(self, df):
"""结果验证"""
if df.count() == 0:
raise ValueError("处理后数据为空")
# 检查Schema
expected_columns = ["user_id", "timestamp", "processed_value"]
for col in expected_columns:
if col not in df.columns:
raise ValueError(f"结果缺少字段: {col}")
def close(self):
"""清理资源"""
self.spark.stop()
🎓 总结
这份指南涵盖了PySpark DataFrame的核心概念、完整操作、性能优化和实战应用。关键要点:
- 理解本质:DataFrame是分布式的、惰性求值的、有Schema的数据集合
- 掌握操作:从基础的select/filter到高级的窗口函数和复杂JOIN
- 性能优先:合理分区、缓存、广播JOIN、避免数据倾斜
- 生产就绪:使用监控、日志、异常处理和最佳实践