Python操作Spark的常用命令指南,涵盖从环境配置到数据分析的核心操作。
一、环境配置与Spark初始化
PySpark是Apache Spark的Python API,它结合了Python的易用性和Spark的分布式计算能力。
1. 安装与基础导入
python
# 安装PySpark
!pip install pyspark
# 导入必要的库
from pyspark.sql import SparkSession
import pyspark.sql.functions as F # 常用函数
import pyspark.sql.types as T # 数据类型
2. 创建SparkSession
SparkSession是与Spark集群交互的统一入口点,绝大多数操作都从这里开始。
python
spark = SparkSession.builder \
.appName("MySparkApp") \ # 设置应用名称
.config("spark.driver.memory", "2g") \ # 配置参数(可选)
.getOrCreate()
3. 基础环境检查
python
# 检查Spark版本
print(spark.version)
# 查看当前配置
print(spark.conf.getAll())
# 列出数据库和表
spark.sql("SHOW DATABASES").show()
spark.sql("USE your_database") # 切换数据库
spark.sql("SHOW TABLES").show()
二、核心操作:数据读写
1. 创建DataFrame
有多种方式创建DataFrame,这是Spark中的核心数据结构。
python
# 方式1:从Python列表创建
data = [("Alice", 29), ("Bob", 35)]
columns = ["name", "age"]
df = spark.createDataFrame(data, schema=columns)
# 方式2:指定详细模式(Schema)
schema = T.StructType([
T.StructField("name", T.StringType(), True),
T.StructField("age", T.IntegerType(), True)
])
df = spark.createDataFrame(data, schema=schema)
2. 从文件读取数据
PySpark支持CSV、Parquet、JSON等多种格式。
python
# 读取CSV文件(常用)
df = spark.read.csv(
"path/to/file.csv",
header=True, # 第一行作为列名
inferSchema=True # 自动推断列类型
)
# 读取Parquet文件(列式存储,高效)
df = spark.read.parquet("path/to/file.parquet")
# 读取JSON文件
df = spark.read.json("path/to/file.json")
3. 数据写入与保存
python
# 保存为Parquet格式(推荐,压缩率高)
df.write.mode("overwrite").parquet("output_path.parquet")
# 保存为CSV格式
df.write.mode("overwrite") \
.option("header", True) \
.csv("output_path.csv")
# 保存为Spark表(可在集群中持久化)
df.write.saveAsTable("table_name")
三、数据处理与转换
数据处理的核心是对DataFrame进行列和行的操作。
1. 列操作
python
# 选择特定列
df.select("name", "age").show()
# 创建新列(示例:年龄加1)
df = df.withColumn("age_plus_one", F.col("age") + 1)
# 重命名列
df = df.withColumnRenamed("old_name", "new_name")
# 更改列类型(示例:整型转字符串)
df = df.withColumn("age_str", F.col("age").cast(T.StringType()))
# 删除列
df = df.drop("column_to_remove")
2. 行操作(过滤与排序)
python
# 过滤行(示例:年龄大于30)
df_filtered = df.filter(F.col("age") > 30)
# 等价写法
df_filtered = df.where(df["age"] > 30)
# 排序
df_sorted = df.orderBy(F.col("age").desc()) # 按年龄降序
3. 处理缺失值
python
# 删除包含任何空值的行
df_clean = df.dropna()
# 填充缺失值(示例:用0填充特定列)
df_filled = df.fillna({"age": 0, "name": "Unknown"})
四、数据聚合与高级分析
1. 分组与聚合
这是数据分析中最常用的操作之一。
python
# 基础分组聚合(示例:按部门计算平均工资)
df.groupBy("department").agg(
F.avg("salary").alias("avg_salary"),
F.count("*").alias("employee_count")
).show()
2. 连接(Join)操作
用于合并两个DataFrame。
python
# 假设有另一个部门信息表df_dept
df_joined = df.join(
df_dept,
on="department_id", # 连接键
how="inner" # 连接方式:inner, left, right, outer等
)
3. 窗口函数
用于计算排名、移动平均等高级分析。
python
from pyspark.sql.window import Window
# 定义窗口:按部门分区,按工资降序
window_spec = Window.partitionBy("department").orderBy(F.col("salary").desc())
# 计算部门内工资排名
df.withColumn("salary_rank", F.rank().over(window_spec)).show()
4. 用户自定义函数(UDF)
当内置函数无法满足需求时使用。
python
from pyspark.sql.functions import udf
# 定义Python函数
def categorize_age(age):
return "Young" if age < 30 else "Senior"
# 注册为UDF(需指定返回类型)
categorize_udf = udf(categorize_age, T.StringType())
# 应用UDF
df.withColumn("age_group", categorize_udf(F.col("age"))).show()
五、在PySpark中运行SQL查询
你可以直接在PySpark中执行SQL语句,这为熟悉SQL的用户提供了便利。
python
# 将DataFrame注册为临时视图
df.createOrReplaceTempView("people")
# 执行SQL查询
result = spark.sql("""
SELECT department, AVG(salary) as avg_sal
FROM people
WHERE age > 25
GROUP BY department
ORDER BY avg_sal DESC
""")
result.show()
六、性能优化与最佳实践
-
避免数据混洗 :
groupBy、join等操作可能导致数据在节点间大量移动,应尽量减少这类操作或提前过滤数据。 -
选择合适的数据格式 :生产环境中,Parquet通常是比CSV更好的选择,因为它支持列式存储和谓词下推,能显著提高查询性能。
-
利用缓存 :对需要多次使用的中间结果进行缓存。
pythondf.cache() # 将DataFrame缓存到内存 df.unpersist() # 使用后释放缓存 -
及时关闭会话 :处理完成后关闭SparkSession以释放资源。
pythonspark.stop()
七、一个完整的示例
python
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
# 1. 初始化
spark = SparkSession.builder.appName("Example").getOrCreate()
# 2. 读取数据
df = spark.read.csv("sales.csv", header=True, inferSchema=True)
# 3. 数据处理
result = (df
.filter(F.col("amount") > 100) # 筛选大额交易
.groupBy("region", "product") # 按地区和产品分组
.agg(F.sum("amount").alias("total_sales")) # 计算总销售额
.orderBy(F.col("total_sales").desc()) # 按销售额降序排序
)
# 4. 输出
result.show()
result.write.mode("overwrite").parquet("sales_summary.parquet")
# 5. 清理
spark.stop()