【推荐系统】深度学习训练框架(十五):特征工程——PySpark DataFrame数据处理核心指南

📚 一、DataFrame核心概念

1.1 什么是DataFrame?

本质 :PySpark DataFrame是一个分布式、不可变、基于命名列的数据集合,类似于关系型数据库表或Pandas DataFrame,但底层是RDD的封装。

关键特性

  • 分布式处理:数据自动分区在集群节点上并行处理
  • 惰性求值:转换操作不立即执行,而是构建执行计划
  • Schema感知:每个DataFrame都有明确的数据类型定义
  • 优化执行:通过Catalyst优化器和Tungsten引擎高效执行

1.2 与RDD、Pandas DataFrame对比

维度 PySpark DataFrame RDD Pandas DataFrame
数据处理 结构化,列式操作 非结构化,面向记录 单机结构化
API风格 声明式(SQL风格) 命令式(函数式) 命令式/声明式混合
性能 Catalyst优化,最高效 手动优化,中等 单机最快
使用场景 大规模ETL、分析 复杂非结构化处理 中小规模数据分析

🔧 二、DataFrame完整操作体系

2.1 创建DataFrame的7种方法

python 复制代码
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
import pyspark.sql.functions as F

spark = SparkSession.builder.appName("CompleteGuide").getOrCreate()

# 方法1:从列表创建(最常用)
data = [("Alice", 34, 55000.0), ("Bob", 45, 72000.0), ("Cathy", 28, 48000.0)]
df1 = spark.createDataFrame(data, ["name", "age", "salary"])

# 方法2:从RDD转换
rdd = spark.sparkContext.parallelize(data)
df2 = rdd.toDF(["name", "age", "salary"])

# 方法3:指定Schema创建(精确控制数据类型)
schema = StructType([
    StructField("name", StringType(), nullable=False),
    StructField("age", IntegerType(), nullable=True),
    StructField("salary", DoubleType(), nullable=True)
])
df3 = spark.createDataFrame(data, schema=schema)

# 方法4:从Pandas DataFrame创建(交互调试)
import pandas as pd
pdf = pd.DataFrame(data, columns=["name", "age", "salary"])
df4 = spark.createDataFrame(pdf)

# 方法5:从字典列表创建
dict_list = [{"name": "Alice", "age": 34, "salary": 55000.0},
             {"name": "Bob", "age": 45, "salary": 72000.0}]
df5 = spark.createDataFrame(dict_list)

# 方法6:通过SQL查询创建
df1.createOrReplaceTempView("people")
df6 = spark.sql("SELECT name, age FROM people WHERE age > 30")

# 方法7:从空Schema创建(用于流处理等场景)
empty_df = spark.createDataFrame([], schema)

2.2 数据读取与写入(支持20+格式)

python 复制代码
# ============ 读取数据 ============
# 1. CSV文件(最常用)
df_csv = spark.read.csv("path/to/file.csv", 
                       header=True,          # 第一行作为列名
                       inferSchema=True,     # 自动推断类型
                       sep=",",             # 分隔符
                       quote='"',           # 引号字符
                       escape='"',          # 转义字符
                       nullValue="NA")      # 空值表示

# 2. JSON文件(半结构化数据)
df_json = spark.read.json("path/to/file.json",
                          multiLine=True)   # 多行JSON

# 3. Parquet文件(生产环境首选)
df_parquet = spark.read.parquet("path/to/file.parquet")

# 4. ORC文件
df_orc = spark.read.orc("path/to/file.orc")

# 5. 文本文件(每行作为字符串)
df_text = spark.read.text("path/to/file.txt")

# 6. 从JDBC数据库读取
df_jdbc = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:mysql://localhost:3306/db") \
    .option("dbtable", "table_name") \
    .option("user", "username") \
    .option("password", "password") \
    .load()

# 7. 读取多个文件/目录
df_multi = spark.read.csv(["file1.csv", "file2.csv", "file3.csv"], header=True)

# 8. 读取Hive表
df_hive = spark.sql("SELECT * FROM hive_table")

# ============ 写入数据 ============
# 1. 基本写入
df.write.csv("output_path.csv", header=True, mode="overwrite")

# 2. 写入模式详解
# mode="overwrite"   # 覆盖(最常用)
# mode="append"      # 追加
# mode="ignore"      # 存在则跳过
# mode="error"       # 存在则报错(默认)

# 3. 分区写入(性能优化关键!)
df.write \
  .partitionBy("year", "month", "day") \  # 按多级分区
  .parquet("output_path/")

# 4. 分桶写入(优化JOIN和查询)
df.write \
  .bucketBy(100, "user_id") \  # 分成100个桶
  .sortBy("timestamp") \       # 每个桶内排序
  .mode("overwrite") \
  .saveAsTable("bucketed_table")

# 5. 写入到JDBC
df.write \
  .format("jdbc") \
  .option("url", "jdbc:mysql://localhost:3306/db") \
  .option("dbtable", "new_table") \
  .option("user", "username") \
  .option("password", "password") \
  .mode("append") \
  .save()

2.3 数据查看与探查

python 复制代码
# ============ 基础查看 ============
df.show(10, truncate=False)      # 显示10行,不截断长文本
df.show(vertical=True)           # 垂直格式显示,适合宽表
df.printSchema()                 # 打印Schema树状图
df.columns                       # 获取列名列表
df.dtypes                        # 获取(列名, 类型)列表
df.schema                        # 获取完整Schema对象
df.count()                       # 行数统计

# ============ 数据抽样 ============
df.sample(withReplacement=False, fraction=0.1, seed=42).show()  # 无放回抽样10%
df.sampleBy("category", fractions={"A": 0.1, "B": 0.2}, seed=42)  # 分层抽样

# ============ 统计信息 ============
df.describe().show()                         # 数值列基本统计
df.describe("age", "salary").show()          # 指定列统计
df.summary().show()                          # 更详细统计(包括百分位数)

# 自定义统计
df.select(
    F.mean("salary").alias("avg_salary"),
    F.stddev("salary").alias("std_salary"),
    F.skewness("salary").alias("skewness"),
    F.kurtosis("salary").alias("kurtosis")
).show()

# ============ 空值检查 ============
from pyspark.sql.functions import isnan, isnull

# 统计每列空值数量
null_counts = df.select([F.sum(F.col(c).isNull().cast("int")).alias(c) for c in df.columns])
null_counts.show()

# 检查NaN值(浮点数特有)
nan_counts = df.select([F.sum(isnan(F.col(c)).cast("int")).alias(c) for c in df.columns])
nan_counts.show()

# ============ 唯一值与频次 ============
df.select("category").distinct().count()     # 唯一值数量
df.select("category").distinct().show()      # 显示所有唯一值
df.groupBy("category").count().orderBy("count", ascending=False).show()  # 频次统计

2.4 列操作完整指南

python 复制代码
# ============ 选择列 ============
# 多种选择方式
df.select("name", "age").show()
df.select(df.name, df["age"]).show()
df.select(F.col("name"), F.col("age") + 10).show()
df.select("*").show()  # 所有列
df.select(df.columns[:3])  # 前3列

# 使用列表推导选择列
numeric_cols = [c for c, t in df.dtypes if t in ('int', 'double', 'float', 'bigint', 'smallint', 'tinyint')]
df.select(numeric_cols).show()

# ============ 新增/修改列 ============
# 1. 基本新增列
df = df.withColumn("salary_k", df.salary / 1000)

# 2. 条件赋值(SQL CASE WHEN)
df = df.withColumn("age_group",
    F.when(df.age < 20, "少年")
     .when(df.age < 40, "青年")
     .when(df.age < 60, "中年")
     .otherwise("老年"))

# 3. 多列操作
df = df.withColumn("full_name", F.concat(df.first_name, F.lit(" "), df.last_name))

# 4. 使用UDF(用户定义函数)
from pyspark.sql.types import StringType

# 注册UDF
@F.udf(returnType=StringType())
def categorize_salary(salary):
    if salary < 50000: return "低"
    elif salary < 100000: return "中"
    else: return "高"

df = df.withColumn("salary_level", categorize_salary(df.salary))

# 5. 批量添加列
new_columns = {
    "age_squared": df.age * df.age,
    "log_salary": F.log(df.salary + 1),
    "normalized_age": (df.age - F.mean(df.age)) / F.stddev(df.age)
}

for col_name, expr in new_columns.items():
    df = df.withColumn(col_name, expr)

# ============ 重命名列 ============
df = df.withColumnRenamed("old_name", "new_name")

# 批量重命名(使用字典)
rename_dict = {"old1": "new1", "old2": "new2"}
for old, new in rename_dict.items():
    df = df.withColumnRenamed(old, new)

# ============ 删除列 ============
df = df.drop("column_to_remove")
df = df.drop("col1", "col2", "col3")  # 删除多列

# 删除多个(使用列表)
columns_to_drop = ["col1", "col2", "col3"]
df = df.drop(*columns_to_drop)

# 删除空值列
df = df.dropna(how="all", subset=["column_name"])  # 该列全为空则删除行

2.5 行操作与过滤

python 复制代码
# ============ 条件过滤 ============
# 基本过滤
df.filter(df.age > 30).show()
df.where(df.salary >= 50000).show()

# 多条件组合
df.filter((df.age > 25) & (df.salary < 70000)).show()  # AND
df.filter((df.age < 25) | (df.age > 60)).show()        # OR
df.filter(~(df.age > 30)).show()                       # NOT

# 字符串匹配
df.filter(df.name.startswith("A")).show()
df.filter(df.name.endswith("son")).show()
df.filter(df.name.contains("li")).show()
df.filter(df.name.like("A%")).show()  # SQL LIKE模式

# 空值过滤
df.filter(df.column.isNotNull()).show()
df.filter(df.column.isNull()).show()

# 列表匹配
df.filter(df.category.isin(["A", "B", "C"])).show()

# 正则表达式匹配
df.filter(df.email.rlike(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")).show()

# ============ 去重操作 ============
df.distinct().show()  # 全行去重
df.dropDuplicates().show()  # 同distinct()

# 基于子集去重
df.dropDuplicates(["name", "age"]).show()

# 保留第一条或最后一条(使用窗口函数)
from pyspark.sql.window import Window
window_spec = Window.partitionBy("user_id").orderBy(F.desc("timestamp"))
df.withColumn("row_num", F.row_number().over(window_spec)) \
  .filter(F.col("row_num") == 1) \
  .drop("row_num")

# ============ 行限制与抽样 ============
df.limit(100).show()  # 前100行

# 随机抽样(带权重)
df.sample(withReplacement=False, fraction=0.01, seed=42)  # 1%无放回抽样
df.sample(withReplacement=True, fraction=0.1)  # 10%有放回抽样

# ============ 删除空行 ============
df.dropna().show()  # 删除任何列包含空值的行
df.dropna(how="all").show()  # 删除所有列都为空的行
df.dropna(subset=["name", "age"]).show()  # 删除指定列为空的行

# ============ 填充空值 ============
# 统一填充
df.fillna({"age": 0, "salary": 0}).show()

# 使用统计值填充
mean_age = df.select(F.mean(df.age)).collect()[0][0]
df.fillna({"age": mean_age}).show()

# 前向/后向填充(需要窗口函数)
window_spec = Window.orderBy("timestamp").rowsBetween(-1, -1)
df.withColumn("filled_value", F.last(df.value, ignorenulls=True).over(window_spec))

2.6 排序与分页

python 复制代码
# ============ 基本排序 ============
df.orderBy("age").show()  # 升序(默认)
df.orderBy(df.age.desc()).show()  # 降序
df.sort("salary", ascending=False).show()  # sort是orderBy的别名

# 多列排序
df.orderBy(F.col("dept").asc(), F.col("salary").desc()).show()

# 控制NULL值位置
df.orderBy(F.col("age").asc_nulls_first()).show()  # NULL在最前
df.orderBy(F.col("age").desc_nulls_last()).show()   # NULL在最后

# ============ 高级排序技巧 ============
# 1. 自定义排序规则
from pyspark.sql.types import IntegerType

# 定义优先级映射
priority_map = {"High": 1, "Medium": 2, "Low": 3}
priority_udf = F.udf(lambda x: priority_map.get(x, 999), IntegerType())

df.withColumn("priority_num", priority_udf(df.priority)) \
  .orderBy("priority_num") \
  .drop("priority_num")

# 2. 随机排序(洗牌)
df.orderBy(F.rand(seed=42)).show()

# 3. 按字符串长度排序
df.orderBy(F.length(df.name).desc()).show()

# ============ 分页查询 ============
# 使用limit和offset模拟
page_size = 10
page_number = 3  # 第4页(0-based)

# 方法1:使用row_number窗口函数
window_spec = Window.orderBy("user_id")
df.withColumn("row_num", F.row_number().over(window_spec)) \
  .filter((F.col("row_num") > page_number * page_size) & 
          (F.col("row_num") <= (page_number + 1) * page_size)) \
  .drop("row_num")

# 方法2:使用排序+limit+offset(Spark 3.4+)
# df.orderBy("id").limit(page_size).offset(page_number * page_size)

2.7 聚合与分组操作

python 复制代码
# ============ 基础聚合 ============
# 全局聚合(无分组)
df.agg(F.sum("salary").alias("total_salary"),
       F.avg("age").alias("average_age"),
       F.count("*").alias("total_rows")).show()

# ============ 分组聚合 ============
# 1. 基本分组
df.groupBy("department").agg(
    F.count("*").alias("emp_count"),
    F.avg("salary").alias("avg_salary"),
    F.max("salary").alias("max_salary"),
    F.min("salary").alias("min_salary"),
    F.stddev("salary").alias("salary_std")
).show()

# 2. 多列分组
df.groupBy("department", "gender").agg(
    F.count("*").alias("count"),
    F.avg("salary").alias("avg_salary")
).orderBy("department", "gender").show()

# 3. 多个聚合函数
agg_exprs = {
    "salary": ["mean", "stddev", "min", "max"],
    "age": ["mean", "count"]
}
df.groupBy("department").agg(agg_exprs).show()

# ============ 高级聚合函数 ============
# 1. 收集列表
df.groupBy("department").agg(
    F.collect_list("name").alias("all_names"),
    F.collect_set("name").alias("unique_names")
).show()

# 2. 拼接字符串
df.groupBy("department").agg(
    F.concat_ws(", ", F.collect_list("name")).alias("name_list")
).show()

# 3. 百分位数
df.groupBy("department").agg(
    F.expr("percentile(salary, array(0.25, 0.5, 0.75))").alias("salary_percentiles")
).show()

# 4. 协方差与相关性
df.groupBy("department").agg(
    F.corr("age", "salary").alias("age_salary_corr"),
    F.covar_pop("age", "salary").alias("age_salary_covar")
).show()

# ============ 条件聚合 ============
# 1. 统计满足条件的行数
df.groupBy("department").agg(
    F.sum((df.salary > 50000).cast("int")).alias("high_salary_count"),
    F.avg(F.when(df.age > 30, df.salary).otherwise(None)).alias("avg_salary_over_30")
).show()

# 2. 多条件统计
df.agg(
    F.count(F.when(df.salary > 50000, 1)).alias("count_high_salary"),
    F.count(F.when((df.salary > 50000) & (df.age < 40), 1)).alias("count_young_high_salary")
).show()

# ============ 滚动窗口聚合 ============
# 按时间窗口聚合(需要时间戳列)
window_spec = Window.partitionBy("user_id").orderBy("timestamp").rowsBetween(-6, 0)
df.withColumn("last_7_actions", 
              F.collect_list("action").over(window_spec)) \
  .withColumn("action_count", 
              F.count("*").over(window_spec))

2.8 数据连接(JOIN)操作

python 复制代码
# ============ JOIN类型详解 ============
df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")], ["id", "value1"])
df2 = spark.createDataFrame([(1, "X"), (2, "Y"), (4, "Z")], ["id", "value2"])

# 1. INNER JOIN(默认)
df1.join(df2, "id").show()  # 使用同名列
df1.join(df2, df1.id == df2.id, "inner").show()  # 指定条件
df1.join(df2, ["id"], "inner").show()  # 多列连接

# 2. LEFT JOIN
df1.join(df2, "id", "left").show()
df1.join(df2, "id", "leftouter").show()  # leftouter是left的别名

# 3. RIGHT JOIN
df1.join(df2, "id", "right").show()

# 4. FULL OUTER JOIN
df1.join(df2, "id", "full").show()
df1.join(df2, "id", "fullouter").show()  # 别名

# 5. LEFT SEMI JOIN(只保留左表数据,相当于WHERE EXISTS)
df1.join(df2, "id", "leftsemi").show()

# 6. LEFT ANTI JOIN(只保留左表数据,相当于WHERE NOT EXISTS)
df1.join(df2, "id", "leftanti").show()

# 7. CROSS JOIN
df1.crossJoin(df2).show()

# ============ JOIN条件高级用法 ============
# 1. 多条件JOIN
df1.join(df2, (df1.id == df2.id) & (df1.category == df2.category))

# 2. 不等值JOIN
df1.join(df2, df1.start_time <= df2.end_time)

# 3. 广播小表(优化JOIN性能)
from pyspark.sql.functions import broadcast
df1.join(broadcast(df2), "id")  # 将df2广播到所有节点

# 4. 处理重复列名
result = df1.join(df2, "id")  # 同名列只会保留一个

# 不同列名时的重命名
df1.join(df2, df1.id == df2.user_id) \
   .select(df1["*"], df2.value.alias("user_value"))

# ============ 复杂JOIN模式 ============
# 1. 自连接
df.alias("a").join(df.alias("b"), F.col("a.manager_id") == F.col("b.employee_id"))

# 2. 多个表连接
df1.join(df2, "id") \
   .join(df3, "id") \
   .join(df4, "id")

# 3. 使用SQL进行复杂连接
df1.createOrReplaceTempView("t1")
df2.createOrReplaceTempView("t2")
df3.createOrReplaceTempView("t3")

complex_join = spark.sql("""
    SELECT t1.*, t2.value2, t3.value3
    FROM t1
    LEFT JOIN t2 ON t1.id = t2.id
    INNER JOIN t3 ON t1.id = t3.id AND t2.category = t3.category
""")

🚀 三、高级功能与性能优化

3.1 窗口函数(Window Functions)

python 复制代码
# ============ 窗口函数基础 ============
from pyspark.sql.window import Window

# 1. 定义窗口规范
window_spec = Window.partitionBy("department").orderBy("salary")

# 2. 排名函数
df.withColumn("row_number", F.row_number().over(window_spec)) \
  .withColumn("rank", F.rank().over(window_spec)) \
  .withColumn("dense_rank", F.dense_rank().over(window_spec)) \
  .withColumn("percent_rank", F.percent_rank().over(window_spec)) \
  .withColumn("ntile", F.ntile(4).over(window_spec))  # 分成4组

# 3. 分析函数
window_spec_agg = Window.partitionBy("department")

df.withColumn("dept_avg_salary", F.avg("salary").over(window_spec_agg)) \
  .withColumn("dept_max_salary", F.max("salary").over(window_spec_agg)) \
  .withColumn("dept_total", F.sum("salary").over(window_spec_agg)) \
  .withColumn("salary_rank_in_dept", 
              F.row_number().over(window_spec.orderBy(F.desc("salary"))))

# ============ 滚动窗口 ============
# 1. 基于行数的窗口
rows_window = Window.partitionBy("user_id") \
                   .orderBy("timestamp") \
                   .rowsBetween(-2, 0)  # 当前行和前2行

df.withColumn("last_3_avg", F.avg("value").over(rows_window))

# 2. 基于范围的窗口(时间间隔)
range_window = Window.partitionBy("user_id") \
                    .orderBy("timestamp") \
                    .rangeBetween(-3600, 0)  # 过去1小时

# ============ 高级窗口技巧 ============
# 1. 累计计算
window_cumulative = Window.partitionBy("user_id") \
                         .orderBy("date") \
                         .rowsBetween(Window.unboundedPreceding, 0)

df.withColumn("cumulative_sum", F.sum("amount").over(window_cumulative)) \
  .withColumn("cumulative_count", F.count("*").over(window_cumulative))

# 2. 向前/向后引用
window_lag_lead = Window.partitionBy("user_id").orderBy("timestamp")
df.withColumn("prev_value", F.lag("value", 1).over(window_lag_lead)) \
  .withColumn("next_value", F.lead("value", 1).over(window_lag_lead))

# 3. 第一个/最后一个值
df.withColumn("first_in_group", 
              F.first("value", ignorenulls=True).over(window_spec)) \
  .withColumn("last_in_group", 
              F.last("value", ignorenulls=True).over(window_spec))

3.2 性能优化策略

python 复制代码
# ============ 配置优化参数 ============
spark.conf.set("spark.sql.shuffle.partitions", "200")  # 调整shuffle分区数
spark.conf.set("spark.sql.adaptive.enabled", "true")   # 启用自适应查询
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# ============ 数据缓存策略 ============
# 1. 缓存级别
df.persist()  # 默认 MEMORY_AND_DISK
df.cache()    # 等同于 persist()

# 明确指定存储级别
from pyspark import StorageLevel
df.persist(StorageLevel.MEMORY_ONLY)           # 只存内存
df.persist(StorageLevel.MEMORY_AND_DISK)       # 内存存不下时存磁盘
df.persist(StorageLevel.DISK_ONLY)             # 只存磁盘
df.persist(StorageLevel.MEMORY_ONLY_SER)       # 序列化后存内存
df.persist(StorageLevel.OFF_HEAP)              # 堆外内存

# 2. 检查缓存状态
spark.catalog.isCached("table_name")
spark.catalog.cacheTable("table_name")
spark.catalog.clearCache()

# 3. 何时缓存/释放
# 缓存:会被多次使用的中间结果
# 释放:不再需要时调用 df.unpersist()

# ============ 分区与分桶优化 ============
# 1. 重新分区(触发shuffle)
df_repartitioned = df.repartition(100)  # 指定分区数
df_repartitioned = df.repartition("date")  # 按列分区
df_repartitioned = df.repartition(100, "date", "category")  # 组合

# 2. 合并分区(无shuffle,减少分区数)
df_coalesced = df.coalesce(10)  # 合并到10个分区

# 3. 最佳分区大小:每个分区128MB左右最佳
optimal_partitions = df_size_in_mb / 128
df_repartitioned = df.repartition(int(optimal_partitions))

# ============ 广播JOIN ============
# 自动广播(小表<10MB默认广播)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")  # 10MB

# 手动广播
small_df = df2.filter(df2.size < 1000)
df1.join(broadcast(small_df), "id")

# ============ 数据倾斜处理 ============
# 1. 识别倾斜键
df.groupBy("join_key").count().orderBy(F.desc("count")).show(10)

# 2. 解决方案:加盐(salting)
salt_df = df.withColumn("salt", (F.rand() * 100).cast("int"))
# 然后进行JOIN操作

# 3. 拆分大key
large_keys = df.groupBy("key").count().filter("count > 10000").select("key")
# 对大key单独处理

3.3 调试与问题排查

python 复制代码
# ============ 执行计划分析 ============
# 1. 查看逻辑计划
df.explain()  # 简单计划
df.explain(extended=True)  # 详细计划
df.explain(mode="cost")  # 成本信息(如果可用)
df.explain(mode="formatted")  # 格式化输出

# 2. 查看物理计划
df._jdf.queryExecution().executedPlan()

# ============ 监控与日志 ============
# 1. 获取任务ID
query = df.write.saveAsTable("output_table")
query.id  # 查询ID

# 2. 查看Spark UI
# 本地访问:http://localhost:4040
# 查看Stage详情、任务时间、数据倾斜等

# ============ 性能瓶颈识别 ============
# 1. 检查数据倾斜
df.rdd.glom().map(len).collect()  # 查看每个分区数据量

# 2. 检查序列化/反序列化成本
# 使用Kryo序列化
spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

# 3. 内存使用分析
df.storageLevel  # 查看存储级别

📊 四、实战应用模式

4.1 数据清洗管道

python 复制代码
def data_cleaning_pipeline(df):
    """完整的数据清洗管道"""
    
    # 阶段1:Schema验证与修复
    expected_schema = StructType([
        StructField("user_id", IntegerType(), nullable=False),
        StructField("name", StringType(), nullable=True),
        StructField("age", IntegerType(), nullable=True),
        StructField("email", StringType(), nullable=True),
        StructField("signup_date", DateType(), nullable=True)
    ])
    
    # 验证并修复Schema
    for field in expected_schema:
        if field.name not in df.columns:
            df = df.withColumn(field.name, F.lit(None).cast(field.dataType))
    
    # 阶段2:处理缺失值
    # 数值列:用中位数填充
    numeric_cols = [c for c, t in df.dtypes if t in ('int', 'double', 'float')]
    for col in numeric_cols:
        median_val = df.approxQuantile(col, [0.5], 0.01)[0]
        df = df.fillna({col: median_val})
    
    # 字符串列:用众数填充
    string_cols = [c for c, t in df.dtypes if t == 'string']
    for col in string_cols:
        mode_row = df.groupBy(col).count().orderBy(F.desc("count")).first()
        if mode_row:
            df = df.fillna({col: mode_row[0]})
    
    # 阶段3:异常值处理(IQR方法)
    for col in numeric_cols:
        q1, q3 = df.approxQuantile(col, [0.25, 0.75], 0.01)
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        
        df = df.withColumn(
            col,
            F.when(F.col(col) < lower_bound, lower_bound)
             .when(F.col(col) > upper_bound, upper_bound)
             .otherwise(F.col(col))
        )
    
    # 阶段4:数据标准化
    from pyspark.ml.feature import StandardScaler, VectorAssembler
    
    assembler = VectorAssembler(inputCols=numeric_cols, outputCol="features")
    df_vector = assembler.transform(df)
    
    scaler = StandardScaler(inputCol="features", outputCol="scaled_features")
    scaler_model = scaler.fit(df_vector)
    df_scaled = scaler_model.transform(df_vector)
    
    # 阶段5:重复数据删除
    df_final = df_scaled.dropDuplicates(["user_id"])
    
    # 阶段6:数据质量报告
    total_rows = df_final.count()
    report = {
        "original_rows": df.count(),
        "final_rows": total_rows,
        "duplicates_removed": df.count() - total_rows,
        "columns_processed": len(df.columns),
        "missing_values_filled": sum([df.filter(F.col(c).isNull()).count() 
                                     for c in df.columns])
    }
    
    return df_final, report

4.2 特征工程模板

python 复制代码
def create_features(df, user_col="user_id", timestamp_col="timestamp"):
    """创建推荐系统特征"""
    
    from pyspark.ml.feature import Bucketizer
    
    features_df = df
    
    # 1. 时间特征
    features_df = features_df.withColumn("hour", F.hour(timestamp_col))
    features_df = features_df.withColumn("day_of_week", F.dayofweek(timestamp_col))
    features_df = features_df.withColumn("is_weekend", 
                                         F.when(F.dayofweek(timestamp_col).isin([1, 7]), 1).otherwise(0))
    
    # 2. 历史行为统计特征
    window_7d = Window.partitionBy(user_col) \
                     .orderBy(timestamp_col) \
                     .rangeBetween(-7*86400, 0)  # 7天窗口
    
    features_df = features_df.withColumn(
        "cnt_7d",
        F.count("*").over(window_7d)
    ).withColumn(
        "avg_amount_7d",
        F.avg("amount").over(window_7d)
    )
    
    # 3. 序列特征(最近N次行为)
    window_seq = Window.partitionBy(user_col) \
                      .orderBy(F.desc(timestamp_col))
    
    features_df = features_df.withColumn(
        "recent_actions",
        F.collect_list("action").over(window_seq.rowsBetween(0, 9))  # 最近10次
    ).withColumn(
        "recent_categories",
        F.collect_set("category").over(window_seq.rowsBetween(0, 49))  # 最近50次
    )
    
    # 4. 交叉特征
    features_df = features_df.withColumn(
        "hour_category",
        F.concat(F.col("hour").cast("string"), F.lit("_"), F.col("category"))
    )
    
    # 5. 分桶特征
    bucketizer = Bucketizer(splits=[0, 18, 25, 35, 50, 65, 100],
                           inputCol="age",
                           outputCol="age_bucket")
    features_df = bucketizer.transform(features_df)
    
    # 6. 目标编码(以用户为例)
    user_stats = df.groupBy(user_col).agg(
        F.mean("target").alias("user_target_mean"),
        F.stddev("target").alias("user_target_std"),
        F.count("*").alias("user_total_actions")
    )
    
    features_df = features_df.join(user_stats, user_col, "left")
    
    # 7. 特征选择:计算特征重要性(与目标的相关性)
    numeric_features = [c for c, t in features_df.dtypes 
                       if t in ('int', 'double', 'float') and c != "target"]
    
    correlations = []
    for feature in numeric_features:
        corr = features_df.stat.corr(feature, "target")
        correlations.append((feature, corr))
    
    # 选择相关性最高的特征
    significant_features = [f for f, c in sorted(correlations, key=lambda x: abs(x[1]), reverse=True)[:50]]
    
    return features_df.select([user_col, timestamp_col] + significant_features)

🎯 五、最佳实践总结

5.1 Do's and Don'ts

最佳实践 反模式
✅ 使用.select()明确指定所需列 ❌ 使用SELECT *查询所有列
✅ 尽早使用.filter()减少数据量 ❌ 在转换链的最后进行过滤
✅ 对常用中间结果使用.cache() ❌ 缓存所有中间结果(浪费内存)
✅ 使用.repartition()优化大表 ❌ 过度分区(小文件问题)
✅ 使用Parquet/ORC格式存储 ❌ 使用CSV/JSON存储大规模数据
✅ 利用窗口函数替代自连接 ❌ 使用复杂子查询和自连接
✅ 广播小表优化JOIN性能 ❌ 让Spark自动处理所有JOIN

5.2 性能检查清单

  1. 分区检查

    • 分区大小是否接近128MB?
    • 是否有数据倾斜(某些分区特别大)?
    • 分区键是否合理(高基数,均匀分布)?
  2. 缓存策略

    • 是否缓存了会被多次使用的DataFrame?
    • 缓存级别是否合适(MEMORY_AND_DISK vs MEMORY_ONLY)?
    • 是否及时清理了不再需要的缓存?
  3. JOIN优化

    • 是否使用了广播JOIN处理小表?
    • JOIN键是否有数据倾斜?
    • 是否可以考虑使用广播变量?
  4. 配置调优

    • spark.sql.shuffle.partitions设置是否合理?
    • 是否启用了自适应查询执行?
    • 序列化器是否配置为Kryo?

5.3 故障排除指南

python 复制代码
# 常见问题与解决方案
problems_solutions = {
    "内存不足": [
        "增加executor内存: --executor-memory 8g",
        "使用序列化缓存: df.persist(StorageLevel.MEMORY_ONLY_SER)",
        "减少分区数: df.coalesce(100)",
        "增加堆外内存: spark.memory.offHeap.enabled=true"
    ],
    
    "数据倾斜": [
        "识别倾斜键: df.groupBy(key).count().orderBy(desc('count'))",
        "加盐处理: df.withColumn('salt', rand()*N)",
        "拆分倾斜键: 单独处理大key",
        "使用广播JOIN避免shuffle"
    ],
    
    "小文件过多": [
        "写入前合并: df.coalesce(N).write.parquet(...)",
        "使用repartition: df.repartition('date').write...",
        "设置合并参数: spark.conf.set('spark.sql.adaptive.enabled', true)"
    ],
    
    "JOIN性能慢": [
        "检查是否有小表可以广播",
        "确保JOIN键类型一致",
        "考虑预分区: 按JOIN键分区",
        "使用分桶表: .bucketBy(100, 'join_key')"
    ]
}

📈 六、进阶资源与扩展

6.1 性能监控指标

指标 说明 理想范围
任务时间 Stage执行时间 无统一标准,越短越好
Shuffle读写 Shuffle数据量 尽量减少
GC时间 垃圾回收时间占比 < 10%
序列化时间 序列化/反序列化时间 < 5%
数据倾斜度 最大/最小分区数据量比 < 2:1

6.2 推荐学习路径

  1. 初学者阶段(1-2周):

    • 掌握DataFrame创建和基本操作
    • 学习常用转换和行动操作
    • 理解惰性求值原理
  2. 中级阶段(1个月):

    • 精通窗口函数和复杂JOIN
    • 掌握性能调优技巧
    • 学习结构化流处理
  3. 高级阶段(2-3个月):

    • 深入理解Catalyst优化器
    • 掌握自定义优化规则
    • 学习Delta Lake等高级特性

6.3 生产环境代码模板

python 复制代码
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import logging

class SparkDataFrameProcessor:
    """生产环境DataFrame处理器模板"""
    
    def __init__(self, app_name="DataProcessor", master="yarn"):
        self.spark = self._create_spark_session(app_name, master)
        self.logger = self._setup_logging()
        
    def _create_spark_session(self, app_name, master):
        """创建配置优化的SparkSession"""
        return SparkSession.builder \
            .appName(app_name) \
            .master(master) \
            .config("spark.sql.shuffle.partitions", "200") \
            .config("spark.sql.adaptive.enabled", "true") \
            .config("spark.sql.autoBroadcastJoinThreshold", "10485760") \
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
            .enableHiveSupport() \
            .getOrCreate()
    
    def _setup_logging(self):
        """配置日志"""
        logging.basicConfig(level=logging.INFO)
        return logging.getLogger(__name__)
    
    def process_data(self, input_path, output_path):
        """数据处理主流程"""
        try:
            # 1. 读取数据
            self.logger.info(f"读取数据: {input_path}")
            df = self.spark.read.parquet(input_path)
            
            # 2. 数据质量检查
            self._data_quality_check(df)
            
            # 3. 数据处理
            processed_df = self._transform_data(df)
            
            # 4. 结果验证
            self._validate_results(processed_df)
            
            # 5. 写入输出
            self.logger.info(f"写入结果: {output_path}")
            processed_df.write \
                .mode("overwrite") \
                .partitionBy("date") \
                .parquet(output_path)
                
            self.logger.info("处理完成!")
            return True
            
        except Exception as e:
            self.logger.error(f"处理失败: {str(e)}")
            raise
    
    def _data_quality_check(self, df):
        """数据质量检查"""
        total_rows = df.count()
        if total_rows == 0:
            raise ValueError("输入数据为空")
        
        # 检查关键字段
        required_columns = ["user_id", "timestamp", "event_type"]
        for col in required_columns:
            if col not in df.columns:
                raise ValueError(f"缺少必要字段: {col}")
        
        null_counts = {col: df.filter(col(col).isNull()).count() 
                      for col in df.columns}
        
        self.logger.info(f"数据质量报告: 总行数={total_rows}, 空值统计={null_counts}")
    
    def _transform_data(self, df):
        """数据转换逻辑"""
        # 这里添加具体的业务逻辑
        return df
    
    def _validate_results(self, df):
        """结果验证"""
        if df.count() == 0:
            raise ValueError("处理后数据为空")
        
        # 检查Schema
        expected_columns = ["user_id", "timestamp", "processed_value"]
        for col in expected_columns:
            if col not in df.columns:
                raise ValueError(f"结果缺少字段: {col}")
    
    def close(self):
        """清理资源"""
        self.spark.stop()

🎓 总结

这份指南涵盖了PySpark DataFrame的核心概念、完整操作、性能优化和实战应用。关键要点:

  1. 理解本质:DataFrame是分布式的、惰性求值的、有Schema的数据集合
  2. 掌握操作:从基础的select/filter到高级的窗口函数和复杂JOIN
  3. 性能优先:合理分区、缓存、广播JOIN、避免数据倾斜
  4. 生产就绪:使用监控、日志、异常处理和最佳实践
相关推荐
TOWE technology1 小时前
PDU、工业连接器与数据中心机柜电力系统
大数据·人工智能·数据中心·idc·pdu·智能pdu·定制电源管理
小魔女千千鱼1 小时前
openEuler 常用开发工具性能实测:Python、Node.js、Git 运行效率对比
人工智能
用户377833043491 小时前
( 教学 )Agent 构建 Prompt(提示词)4. JsonOutputParser
人工智能·后端
YuSun_WK1 小时前
检索增强VS知识蒸馏VS伪标签扩展
人工智能·python
五度易链-区域产业数字化管理平台1 小时前
行业研究+大数据+AI:“五度易链”如何构建高质量产业数据库?
大数据·人工智能
通义灵码1 小时前
如何调教一名合格的“编程搭子”
人工智能·智能体·qoder
aitoolhub1 小时前
AI 生图技术解析:从训练到输出的全流程机制
大数据·人工智能·深度学习
smilejingwei1 小时前
Text2SQL 破局技术解析之三:NLQ 词典与准确性
人工智能·text2sql·bi·spl
图欧学习资源库1 小时前
人工智能领域、图欧科技、IMYAI智能助手2025年11月更新月报
人工智能·科技