在 PySpark 中,DSL(Domain Specific Language)编程主要基于DataFrame API(Python 中 DataFrame 与 Dataset 概念统一),是处理结构化 / 半结构化大数据的核心方式。它通过链式调用 API 方法实现数据加载、清洗、转换、聚合、分析等全流程操作,无需编写 SQL 语句,更贴合 Python 编程习惯。
一、核心基础:环境初始化与 DataFrame 概念
1. 环境准备(初始化 SparkSession)
所有 PySpark DSL 操作的入口是SparkSession
,需先创建会话对象:
from pyspark.sql import SparkSession
# 初始化SparkSession
spark = SparkSession.builder \
.appName("PySpark-DSL-Example") # 应用名称
.master("local[*]") # 本地模式(生产环境无需指定,由集群管理)
.config("spark.sql.shuffle.partitions", "4") # 调整shuffle分区数(默认200,小数据场景可减小)
.getOrCreate()
# 导入内置函数(常用别名F,简化调用)
from pyspark.sql import functions as F
# 导入数据类型(用于定义schema)
from pyspark.sql.types import *
2. DataFrame 核心概念
- DataFrame :分布式数据集合,以命名列(Column)组织,类似关系型数据库的 "表",但底层基于 RDD 实现,包含schema(元数据,描述列名和类型)。
- 不可变性 :DataFrame 是不可变的,所有转换操作(如
withColumn
、filter
)都会生成新的 DataFrame,原对象不变。 - 惰性执行 :转换操作(Transformation)不会立即执行,只有触发行动操作(Action,如
show
、count
、write
)时才会真正计算,优化执行效率。
二、数据加载:创建 DataFrame
PySpark 支持从多种数据源加载数据生成 DataFrame,以下是常见场景:
1. 从文件加载(CSV/JSON/Parquet 等)
(1)CSV 文件(最常用)
# 读取CSV(自动推断schema,适用于测试)
df = spark.read \
.option("header", "true") # 第一行为列名
.option("inferSchema", "true") # 自动推断列类型(生产环境不推荐,效率低)
.option("sep", ",") # 分隔符(默认逗号)
.option("nullValue", "NA") # 指定空值标识
.csv("path/to/data.csv") # 文件路径(支持本地/分布式文件系统如HDFS)
# 生产环境:手动指定schema(高效且避免类型推断错误)
schema = StructType([
StructField("id", IntegerType(), nullable=False), # 非空整数
StructField("name", StringType(), nullable=True),
StructField("birth_date", StringType(), nullable=True), # 先按字符串读,后续转为日期
StructField("salary", DoubleType(), nullable=True)
])
df = spark.read \
.option("header", "true") \
.schema(schema) # 应用手动定义的schema
.csv("path/to/data.csv")
(2)JSON 文件
# 读取JSON(支持单行JSON或多行JSON)
df = spark.read \
.option("multiline", "true") # 多行JSON(默认false,单行JSON)
.json("path/to/data.json")
(3)Parquet 文件(Spark 默认存储格式,列式存储,压缩率高)
df = spark.read.parquet("path/to/data.parquet")
2. 从数据库加载(MySQL/PostgreSQL 等)
需添加对应数据库的 JDBC 驱动(如 MySQL 的mysql-connector-java
):
df = spark.read \
.format("jdbc") \
.option("url", "jdbc:mysql://host:port/db_name") \
.option("dbtable", "table_name") # 表名或SQL查询(如"(select * from t where id>100) tmp")
.option("user", "username") \
.option("password", "password") \
.option("driver", "com.mysql.cj.jdbc.Driver") # 驱动类名
.load()
3. 从 RDD 或集合创建
# 从Python列表创建
data = [("Alice", 25, "F"), ("Bob", 30, "M")]
df = spark.createDataFrame(data, schema=["name", "age", "gender"]) # 手动指定列名
# 从RDD创建
rdd = spark.sparkContext.parallelize(data)
df = rdd.toDF(schema=["name", "age", "gender"])
三、基础操作:查看与验证数据
加载数据后,需先验证数据格式和内容,常用 API:
操作 | 功能说明 | 示例代码 |
---|---|---|
printSchema() |
打印 DataFrame 的 schema(列名 + 类型) | df.printSchema() |
show(n, truncate) |
显示前 n 行数据(truncate=False 不截断长字符串) | df.show(5, truncate=False) |
columns |
返回所有列名列表 | print(df.columns) |
dtypes |
返回(列名,类型)列表 | print(df.dtypes) |
count() |
计算总行数(Action 操作) | print(f"总条数: {df.count()}") |
describe(*cols) |
计算列的统计信息(计数、均值、标准差等) | df.describe("age", "salary").show() |
head(n)/take(n) |
获取前 n 行数据(返回 Row 对象列表) | print(df.head(2)) |
四、核心操作:数据转换与清洗
PySpark DSL 提供了丰富的 API 用于数据转换,覆盖列操作、行过滤、类型转换等场景。
1. 列操作
(1)选择列(select
)
# 选择单列
df.select("name").show()
# 选择多列
df.select("name", "age").show()
# 结合列计算(如年龄+1)
df.select(
F.col("name"), # F.col()用于引用列(推荐,支持链式操作)
F.col("age"),
(F.col("age") + 1).alias("age_plus_1") # alias()重命名列
).show()
# 通配符选择(如所有以"col_"开头的列)
df.select(F.col("`col_*`")).show() # 反引号处理特殊列名
(2)重命名列(withColumnRenamed
)
# 重命名单个列
df_renamed = df.withColumnRenamed("old_col", "new_col")
# 重命名多个列(链式调用)
df_renamed = df \
.withColumnRenamed("a", "col_a") \
.withColumnRenamed("b", "col_b")
(3)删除列(drop
)
# 删除单个列
df_dropped = df.drop("age")
# 删除多个列
df_dropped = df.drop("age", "gender")
# 删除不存在的列(不报错)
df_dropped = df.drop("nonexistent_col")
2. 行操作
(1)过滤行(filter
/where
,二者等价)
# 条件:年龄>25
df.filter(F.col("age") > 25).show()
# 多条件(且:&,或:|,非:~;注意括号)
df.filter(
(F.col("age") > 25) & (F.col("gender") == "F")
).show()
# 字符串条件(不推荐,无类型检查)
df.where("age > 25 and gender = 'F'").show()
# 空值处理(过滤空值)
df.filter(F.col("birth_date").isNotNull()).show()
# 过滤非空且非空字符串
df.filter(
F.col("name").isNotNull() & (F.col("name") != "")
).show()
(2)去重(distinct
/dropDuplicates
)
# 所有列完全重复的行去重
df_distinct = df.distinct()
# 指定列重复的行去重(保留第一条)
df_drop_dup = df.dropDuplicates(subset=["name", "gender"]) # 按name和gender去重
(3)排序(orderBy
/sort
,二者等价)
# 按年龄升序(默认)
df.orderBy("age").show()
# 按年龄降序(F.desc())
df.orderBy(F.desc("age")).show()
# 多列排序(年龄降序,姓名升序)
df.sort(
F.col("age").desc(),
F.col("name").asc()
).show()
3. 类型转换(cast
)
解决数据类型不匹配问题(如字符串转日期、字符串转数字):
# 字符串转整数(若转换失败,结果为null)
df = df.withColumn("age", F.col("age").cast(IntegerType()))
# 字符串转日期(指定格式,如"yyyy-MM-dd")
df = df.withColumn(
"birth_date",
F.to_date(F.col("birth_date"), "yyyy-MM-dd") # 比cast更灵活,支持指定格式
)
# 日期转字符串
df = df.withColumn(
"birth_str",
F.date_format(F.col("birth_date"), "yyyyMMdd") # 转为"20000101"格式
)
# 字符串转浮点数
df = df.withColumn("salary", F.col("salary").cast(DoubleType()))
4. 新增列(withColumn
)
通过现有列计算或常量值新增列:
# 基于现有列计算(年龄是否成年)
df = df.withColumn("is_adult", F.col("age") >= 18)
# 常量列(所有行值相同)
df = df.withColumn("source", F.lit("csv_file")) # F.lit()表示常量
# 条件列(when/otherwise,类似if-else)
df = df.withColumn(
"age_group",
F.when(F.col("age") < 18, "少年")
.when((F.col("age") >= 18) & (F.col("age") < 30), "青年")
.when((F.col("age") >= 30) & (F.col("age") < 50), "中年")
.otherwise("老年") # 其他情况
)
5. 字符串处理
PySpark 提供丰富的字符串函数(pyspark.sql.functions
):
函数 | 功能 | 示例 |
---|---|---|
regexp_replace |
正则替换 | F.regexp_replace("name", "A", "a") |
substring |
截取子串(索引从 1 开始) | F.substring("birth_str", 1, 4) # 取年份 |
upper /lower |
大小写转换 | F.upper("name") |
trim /ltrim /rtrim |
去除空格 | F.trim("name") |
split |
分割字符串为数组 | F.split("address", ",") |
concat /concat_ws |
拼接字符串 | F.concat_ws("-", "year", "month") |
示例:
# 清洗日期字符串(移除所有非数字字符,如"2000/01/01"→"20000101")
df = df.withColumn(
"cleaned_birth",
F.regexp_replace(F.col("birth_date_str"), "[^0-9]", "")
)
# 拆分姓名为姓和名(假设格式为" lastName, firstName")
df = df.withColumn(
"split_name",
F.split(F.trim("name"), ", ") # 先去空格,再按", "分割
).withColumn(
"first_name",
F.col("split_name")[1] # 取数组第二个元素
).withColumn(
"last_name",
F.col("split_name")[0] # 取数组第一个元素
).drop("split_name") # 删除临时列
6. 日期时间处理
针对DateType
或TimestampType
列的操作:
函数 | 功能 | 示例 |
---|---|---|
to_date |
字符串转日期 | F.to_date("str", "yyyy-MM-dd") |
date_add /date_sub |
日期加减天数 | F.date_add("birth_date", 1) |
months_between |
两个日期相差月数 | F.months_between("end", "start") |
year /month /day |
提取年 / 月 / 日 | F.year("birth_date") |
current_date /current_timestamp |
当前日期 / 时间 | F.current_date() |
示例:
# 计算年龄(当前年份 - 出生年份)
df = df.withColumn(
"calc_age",
F.year(F.current_date()) - F.year(F.col("birth_date"))
)
# 计算距今天数
df = df.withColumn(
"days_since_birth",
F.datediff(F.current_date(), F.col("birth_date"))
)
五、高级操作:聚合、连接与窗口函数
1. 聚合操作(groupBy
)
按列分组后进行统计(如计数、求和、均值等):
# 按性别分组,统计人数和平均年龄
gender_stats = df.groupBy("gender") \
.agg(
F.count("id").alias("total_people"), # 计数(非空id的数量)
F.avg("age").alias("avg_age"), # 平均年龄
F.max("salary").alias("max_salary"), # 最高薪资
F.min("salary").alias("min_salary") # 最低薪资
)
gender_stats.show()
# 全局聚合(不分组,统计整体)
total_stats = df.agg(
F.count("*").alias("total_rows"), # 总条数(包括null)
F.sum("salary").alias("total_salary")
)
2. 连接操作(join
,多表关联)
合并多个 DataFrame(类似 SQL 的 JOIN),支持内连接、左连接等:
# 示例:员工表(df_emp)与部门表(df_dept)关联
df_emp = spark.createDataFrame(
[("1", "Alice", "10"), ("2", "Bob", "20")],
["emp_id", "name", "dept_id"]
)
df_dept = spark.createDataFrame(
[("10", "HR"), ("20", "Tech"), ("30", "Finance")],
["dept_id", "dept_name"]
)
# 内连接(只保留两表都匹配的行)
inner_join = df_emp.join(
df_dept,
on="dept_id", # 连接键(若列名不同,用on=[df_emp.a == df_dept.b])
how="inner"
)
# 左连接(保留左表所有行,右表无匹配则为null)
left_join = df_emp.join(
df_dept,
on="dept_id",
how="left" # 或"left_outer"
)
# 右连接(保留右表所有行)
right_join = df_emp.join(df_dept, on="dept_id", how="right")
# 全连接(保留两表所有行)
full_join = df_emp.join(df_dept, on="dept_id", how="full")
注意 :连接后若有重名列(非连接键),需用alias
区分:
df_emp.alias("e").join(
df_dept.alias("d"),
F.col("e.dept_id") == F.col("d.dept_id"),
how="inner"
).select("e.emp_id", "e.name", "d.dept_name").show() # 明确指定列来源
3. 窗口函数(Window
,分组内的精细计算)
用于实现 "分组内排序""Top N""累计求和" 等场景,需先定义窗口规则:
from pyspark.sql.window import Window
# 示例:按部门分组,计算每个员工的薪资排名
# 1. 定义窗口:按部门分区(partitionBy),按薪资降序排序(orderBy)
window_spec = Window \
.partitionBy("dept_id") \
.orderBy(F.col("salary").desc())
# 2. 应用窗口函数
df = df.withColumn(
"rank_in_dept", # 排名(相同值会占用相同名次,后续名次跳过)
F.rank().over(window_spec)
).withColumn(
"dense_rank_in_dept", # 密集排名(相同值占用相同名次,后续名次不跳过)
F.dense_rank().over(window_spec)
).withColumn(
"row_num_in_dept", # 行号(即使值相同,名次也唯一)
F.row_number().over(window_spec)
)
# 3. 取每个部门薪资前2的员工
top2_in_dept = df.filter(F.col("row_num_in_dept") <= 2)
常用窗口函数:
- 排名类:
rank()
、dense_rank()
、row_number()
- 聚合类:
sum()over()
、avg()over()
(如 "累计销售额") - 偏移类:
lag()
(取前 n 行值)、lead()
(取后 n 行值)
六、用户自定义函数(UDF)
当内置函数无法满足需求时,可自定义函数扩展功能:
1. 普通 UDF(基于 Python 函数)
# 定义Python函数:计算姓名长度
def name_length(name):
return len(name) if name is not None else 0
# 注册为UDF(指定返回类型)
name_length_udf = F.udf(name_length, IntegerType())
# 使用UDF
df = df.withColumn("name_len", name_length_udf(F.col("name")))
2. Pandas UDF(向量化 UDF,性能优于普通 UDF)
适用于大数据量场景,基于 Pandas Series 处理:
import pandas as pd
from pyspark.sql.functions import pandas_udf
# 定义Pandas UDF(输入输出为Pandas Series)
@pandas_udf(IntegerType())
def pandas_name_length(name_series: pd.Series) -> pd.Series:
return name_series.str.len().fillna(0) # 利用Pandas字符串方法
# 使用Pandas UDF
df = df.withColumn("name_len", pandas_name_length(F.col("name")))
注意:UDF 会打破 Spark 的优化逻辑,尽量优先使用内置函数;必须指定返回类型,否则可能报错。
七、数据写出(持久化结果)
处理完成后,将 DataFrame 写出到文件或数据库:
1. 写出到文件
# 写出为Parquet(推荐,压缩率高,保留schema)
df.write \
.mode("overwrite") # 写出模式:overwrite(覆盖)/append(追加)/ignore(忽略)/errorifexists(报错)
.parquet("path/to/output.parquet")
# 写出为CSV(需指定header,否则无列名)
df.write \
.mode("append") \
.option("header", "true") \
.option("sep", ",") \
.csv("path/to/output.csv")
# 写出为JSON
df.write.json("path/to/output.json")
2. 写出到数据库
df.write \
.format("jdbc") \
.option("url", "jdbc:mysql://host:port/db_name") \
.option("dbtable", "target_table") \
.option("user", "username") \
.option("password", "password") \
.mode("overwrite") \
.save()
八、性能优化技巧
-
指定 schema :读取数据时手动定义 schema,避免
inferSchema
(减少 IO 和计算开销)。 -
合理使用缓存 :对重复使用的 DataFrame 进行缓存(
cache()
或persist()
),减少重复计算:df_cached = df.cache() # 缓存到内存(默认MEMORY_AND_DISK级别) df_cached.count() # 触发缓存
-
减少数据量 :尽早过滤(
filter
)和选择必要列(select
),避免大表全量处理。 -
调整分区 :
- 读取后分区数不合理:
df.repartition(8)
(增加分区,适合大表)或df.coalesce(2)
(减少分区,不 shuffle)。 - shuffle 操作(如
groupBy
、join
)前设置spark.sql.shuffle.partitions
(根据集群资源调整,通常为核心数的 2-3 倍)。
- 读取后分区数不合理:
-
广播小表 :小表与大表连接时,用
broadcast
广播小表,避免大表 shuffle:from pyspark.sql.functions import broadcast df_large.join(broadcast(df_small), on="id", how="inner") # 广播df_small
-
避免
collect()
:collect()
会将分布式数据拉取到 Driver 端,可能导致 OOM,小数据才用;大数据用take(n)
或写出到文件。
九、综合案例:用户行为数据分析
假设需分析用户行为数据(user_behavior.csv
),包含user_id
、action
(点击 / 购买)、action_time
、product_id
,目标是:
-
清洗数据(转换时间格式,过滤无效值);
-
统计每个用户的点击和购买次数;
-
计算每个用户的首购时间。
1. 加载数据并定义schema
schema = StructType([
StructField("user_id", StringType(), False),
StructField("action", StringType(), False),
StructField("action_time", StringType(), False),
StructField("product_id", StringType(), True)
])df = spark.read
.option("header", "true")
.schema(schema)
.csv("user_behavior.csv")2. 数据清洗
df_clean = df
# 过滤无效动作(只保留点击和购买)
.filter(F.col("action").isin(["click", "purchase"]))
# 转换时间格式(字符串→时间戳)
.withColumn("action_ts", F.to_timestamp("action_time", "yyyy-MM-dd HH:mm:ss"))
# 删除无效时间行
.filter(F.col("action_ts").isNotNull())
.drop("action_time") # 丢弃原字符串时间列3. 统计每个用户的点击和购买次数
user_action_count = df_clean.groupBy("user_id")
.pivot("action", ["click", "purchase"]) # 透视action列,转为click和purchase列
.count()
.fillna(0) # 空值填充为0
.withColumnRenamed("click", "click_count")
.withColumnRenamed("purchase", "purchase_count")4. 计算每个用户的首购时间
定义窗口:按用户分区,按时间升序排序
window_first_purchase = Window
.partitionBy("user_id")
.orderBy("action_ts")first_purchase = df_clean
# 只保留购买行为
.filter(F.col("action") == "purchase")
# 标记每个用户的第一条购买记录
.withColumn("row_num", F.row_number().over(window_first_purchase))
.filter(F.col("row_num") == 1)
# 提取首购时间和商品
.select(
"user_id",
F.col("action_ts").alias("first_purchase_time"),
"product_id"
)5. 合并结果并写出
result = user_action_count.join(
first_purchase,
on="user_id",
how="left" # 左连接,保留所有用户(包括无购买的)
)result.write
.mode("overwrite")
.parquet("user_behavior_analysis_result")关闭SparkSession
spark.stop()
总结
PySpark DSL 编程以 DataFrame API 为核心,覆盖了从数据加载、清洗、转换、聚合到写出的全流程,通过链式调用实现高效的大数据处理。关键在于:
- 熟练掌握基础 API(
select
、filter
、withColumn
等); - 理解惰性执行机制,合理使用缓存;
- 灵活运用聚合、连接、窗口函数解决复杂业务问题;
- 关注性能优化,避免常见陷阱(如无意义的全表扫描、滥用 UDF)。
实际开发中,需结合具体业务场景选择合适的 API,并通过 Spark UI(默认 4040 端口)监控作业执行情况,持续优化。