一. 安装(mac环境)
要使用PySpark,本地要有Java开发环境。
-
Java 8 :
brew install --cask homebrew/cask-versions/adoptopenjdk8
-
pyspark安装:
pip install pyspark
二. spark 的配置
pyspark使用前,需要import 出SparkSession
from pyspark.sql import SparkSession, DataFrame
根据实际情况增减配置项
spark = (SparkSession.builder
.config("hive.metastore.uris", "thrift://***.com:8106")
.config("spark.sql.catalog.iceberg", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.iceberg.type", "hive")
.config("spark.jars.packages", "org.apache.iceberg:iceberg-spark-runtime-3.2_2.12:1.4.1")
.config("spark.sql.catalog.spark_catalog", "org.apache.iceberg.spark.SparkSessionCatalog")
.config("spark.sql.catalog.spark_catalog.type", "hive")
.config("spark.sql.shuffle.partitions", 500)
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.skewJoin.enabled", "true")
.config("spark.driver.maxResultSize", "40g")
.config('spark.driver.memory','40g')
.config("spark.executor.memoryOverhead", "8g")
.config('spark.executor.memory','8g')
.enableHiveSupport().getOrCreate())
结束时,停止spark
spark.stop()
三. 常用的代码------dataframe
3.1 读取csv
spark在读取csv上优势就很明显了,
1)能直接快速读取几个G的大文件;
2)能直接读取一个目录下的所有文件;
pyspark读取csv,快速高效
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('learn').master("local").getOrCreate()
print(spark)
df = spark.read.csv(path,header=True)
或者读取某个目录下的所有csv文件
df = spark.read.format("csv").option("delimiter", ",").option("header", "true").load(f{load_path}/*.csv")
普通的pandas读取大的csv,只能将其拆分为多个chunk进行读取,假如我们直接读取csv,可能会直接报内存不够导致进程被干掉。
import pandas as pd
df = pd.read_csv(path, index_col=False, iterator=True, chunksize=100000)
for df_i in df:
print(df_i)
3.2 写csv
pandas写入csv
df.to_csv('test.csv',index=False)
pyspark写入csv时,指定某个目录,这里推荐使用repartition(1),让所有分区文件合并成一个,不然得话存储为多个分片文件;分区数可以自行设置,可以根据文件大小或者行数提前计算分区数,有几个分区,就是几个分片文件
spark_df.repartition(1).write.csv("data/", encoding="utf-8", header=True,mode='overwrite')
或者这种写法
partition_num = max(1, ceil(df.count() / per_file_num))
print(f"partition_num={partition_num}")
df.repartition(partition_num).write.option("header", "true").format("csv").mode("overwrite").save(save_path)
df.repartition(partition_num).write \
.option("header", "true") \
.option("quote", "\"") \
.option("escape", "\"") \
.format("csv") \
.mode("overwrite") \
.save(save_path)
mode: overwrite/ append
3.3 构建Dataframe
pandas构建dataframe
df = pd.DataFrame([['Sport', 1, 1], ['Flow', 2, 9], ['Sport', 2, 2],['Hear', 1, 6]],
columns=['type', 'lenth', 'score'])
pyspark构建dataframe
spark_df = spark.createDataFrame([['Sport', 1, 1], ['Flow', 2, 9], ['Sport', 2, 2],['Hear', 1, 6]],
['type', 'lenth', 'score'])
pandas的dataframe 转 pyspark的dataframe
spark_df = spark.createDataFrame(df)
spark_df.show()
3.4 自定义函数
在处理同一批数据(130w条样本测试)时,使用pyspark(local模式)处理需要0.05s,而pandas的apply函数则需要15s,快了300倍!
pandas的自定义函数apply
def is_sport(x):
if x == 'sport':
return 1
else:
return 0
df['is_sport'] = df['type'].apply(is_sport)
pyspark的自定义函数udf
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
type_formater = F.udf(is_sport,IntegerType())
new_type = type_formater(F.col('type')).alias('new_type')
spark_df.select(['type','lenth',new_type]).show()
3.5 查询函数
pandas查询函数query
df = df.query('score == 1')
df = df[(df['a'] > -100) & (df['b'] < 100)]
pyspark查询函数filter
写法1
spark_df.filter("score == 1").show()
写法2
spark_df.filter(F.col('score') == 1)
过滤df中name不为空
df.filter(df.name.isNotNull())
df 中 link_form 为空,is_roundabout 列为0
link_form 不为空,且link_form中含有12或13,is_roundabout 列为1;
都不含,is_roundabout 列为0
df = df.withColumn("is_roundabout", F.when(F.col("link_form").isNull(), 0).otherwise(F.when(
(F.array_contains(F.split(F.col("link_form"), "\\|"), "12")) |
(F.array_contains(F.split(F.col("link_form"), "\\|"), "13")),1).otherwise(0)))
多个条件组合过滤
df = df.filter((F.col('a') == 1) & (F.col('b') == 1))
df = df.filter(F.size(F.col('data'))>0)
3.6 分组聚合函数
pandas分组函数groupby,通过lambda 函数聚合其它列
df.groupby('type').sum()
# 原始数据
a_lst = [{'a': "a_111", 'b': "b_111", 'c': "c_111"},
{'a': "a_222", 'b': "b_222", 'c': "c_222"},
{'a': "a_333", 'b': "b_333", 'c': "c_333"},
{'a': "a_444", 'b': "b_444", 'c': "c_444"},
{'a': "a_111", 'b': "b_555", 'c': "c_555"},
{'a': "a_333", 'b': "b_666", 'c': "c_666"},
]
df = pd.DataFrame(a_lst)
# 添加filter_rsn列
df['filter_rsn'] = 0
df.loc[df['a'] == 'a_111', 'filter_rsn'] = 111
df.loc[df['a'] == 'a_222', 'filter_rsn'] = 222
# 按a列分组并合并b列和c列
result = df.groupby('a').agg({
'b': lambda x: ', '.join(x),
'c': lambda x: ', '.join(x),
'filter_rsn': 'first' # 取每组第一个filter_rsn值
}).reset_index()
print("合并后的结果:")
print(result)
pyspark分组函数groupBy
spark_df.groupBy('type').sum().show()
#按照'adcode','name'两列聚合,
df.groupby('adcode','name').agg(F.concat_ws(",",F.collect_list("link")).alias("links"))
df.groupby('adcode','name').agg(F.concat_ws(",",F.collect_set("link")).alias("links"))
填充空值为 ''
df.fillna(value='')
将link_layer 列按照; 分开,并扩展成多行,扩展列名为 link
df = df.withColumn('link', F.explode(F.split(F.col('link_layer'), ';')))
3.7 合并函数
左右join 相同
df1 和 df2 合并,通过link_id字段合并,左连接
merge_df = df1.join(df2, on='link_id', how='left')
上下合并
pyspark,使用union, 注意这种方式,要求df1 和 df2 所有列名完全相同,且列名顺序完全一致
df = df1.union(df2)
如果列名相同,但是顺序不同,可以用 unionByName
df = df1.unionByName(df2)
3.8 选择部分列
pyspark,使用select 函数,同时其它函数
df = df.select("link_id", "dt", "link_id_map", F.split("link_id_map", ",").alias("link_id_array"))
pandas,使用[[]]
df = df[["link_id", "dt", "link_id_map"]]
3.9 互转格式
# pandas ==> pyspark
pyspark_df = spark.createDataFrame(pandas_df)
# pyspark ==> pandas
pandas_df = pyspark_df.toPandas()
转化代码中有哪些坑呢?
-
pandas转pyspark的时候,如果你的pandas的版本过低,就会报错,这里你可以选择以下2个方案解决:
-
升级pandas
-
在代码中添加:
pd.DataFrame.iteritems = pd.DataFrame.items
-
-
耗时过长,这里也有以下方案能缩减耗时:
-
减少df的列和行 ==> 减少数据
-
利用pyArrow加速:
pip install pyarrow
spark = SparkSession.builder.config("spark.sql.execution.arrow.pyspark.enabled", "true") # 加速转pandasdf的速度
-
3.10 去重
pyspark
根据某一或者几列去重
df = df.dropDuplicates(subset=['a','b'])
python
df=df.drop_duplicates(subset=['a', 'b']).reset_index(drop=True)
3.11 新增常数列
pyspark
df = df.withColumn('a', F.lit('0'))
dataframe
df['a'] = 0
3.12 差值集
两个df按照某一列进行计算差集
diff_df = df1.select("a").subtract(df2.select("a")).distinct()
3.13 取某列值的子串
substr
df = df.filter(col("a_str").substr(1, 4) == "1234")
3.14 dataframe 是否为空判断
pyspark
if df.rdd.isEmpty():
print(f"df is empty")
dataframe
if df.empty:
print("df is empty")
3.15 删除某列
df = df.drop('result')
四. 常用代码-------常见操作udf 和 sql
4.1 udf 使用
4.1.1 类中udf使用
可能在一个处理的过程中往往会使用多个自定义的udf函数,但是当项目非常大的时候,最好还是把归属于这个处理类的udf集成到类中:
class A:
@staticmethod
@F.udf(returnType=IntegerType())
def is_a_equal0(a):
if a == 0:
return 1
else:
return 0
4.1.2 需要返回多列
def aaa(var_list):
@F.udf(returnType=StringType())
def bbb(value):
# 在这里可以对每个值进行自定义的处理操作
rs = ''
value_js = json.loads(value)
for v in var_list:
if rs:
rs += (';' + str(value_js[v]))
else:
rs += str(value_js[v])
return rs
return bbb
need_vars = ['a','b','c']
df = df.withColumn("need_data", aaa(need_vars)(F.col("data")))
df = df.withColumn("s", F.split(df['data'], ";"))
for i, v in enumerate(need_vars):
df = df.withColumn(v, df['s'].getItem(i))
4.1.3 udf使用完整示例
def topo(tx_link, admin_code):
try:
code, msg, result = process_topo(tx_link=tx_link)
return {"code": code, "msg": msg, "topo_result": result}
except Exception as e:
return {"code": -1, "msg": str(e), "topo_result": ''}
# 定义拓扑结果 schema
result_schema = StructType([
StructField("code", IntegerType()),
StructField("msg", StringType()),
StructField("topo_result", StringType())
])
# 注册 UDF
topo_udf = F.udf(topo, result_schema)
df = df.withColumn("topo_results", topo_udf(df['link'], df['adcode']))
# 展开拓扑结果
df = df.withColumn("code", df.topo_results.code) \
.withColumn("msg", df.topo_results.msg) \
.withColumn("topo_result", df.topo_results.topo_result)
4.2 sql相关操作
4.2.1 读取hive表
读取hive表
方法一
branch = spark.sql(f'select branch from {self.branch} where status=1 order by create_time desc limit 1').collect()[0][0]
方法二
tx_df = spark.read.format("iceberg").load(self.unimap_data_tabel).where(f"branch = '{branch}'").select('tile_data')
方法一的举例
df = spark.sql(f"SELECT * from {database_name} where dt ='{batch_id}' and err_info is null and link_sequence <> ''")
df.createOrReplaceTempView("tmp_data")
4.2.2 存储hive表
df.write.format("iceberg").option("mergeSchema", "true").mode("overwrite").save(f"{database_name}")
4.3 聚合 exit_links_array 中有相同数据的所有行
# 收集每个唯一链接对应的所有记录
all_links = df.select(F.explode("exit_links_array").alias("link")).distinct().collect()
all_links = [row.link for row in all_links]
print("create network Graph")
# 创建图
G = nx.Graph()
# 添加节点和边
for link in all_links:
G.add_node(link)
for row in df.collect():
links = row.exit_links_array
if len(links) > 1:
for i in range(len(links)):
for j in range(i+1, len(links)):
G.add_edge(links[i], links[j])
print(f"G add node edge finish")
# 找到所有连通分量
connected_components = list(nx.connected_components(G))
# 为每个连通分量创建一个DataFrame并合并
link_to_component = {}
for component_id, component in enumerate(connected_components):
# print(f"component_id={component_id}, component={component}")
for link in component:
link_to_component[link]=component_id
print("link_to_component")
print(f"link_to_component={link_to_component}")
# 为每条记录分配连通分量ID
component_df = df.withColumn("component_id", component_id_udf(F.col("exit_links_array"), link_to_component))
print("component_df")
component_df.show(5)
exploded_df = component_df.select(
"component_id",
"exit_comments",
"exit_geometry",
"dt",
F.explode("exit_links_array").alias("exit_link")
)
print("exploded_df")
exploded_df.show(5)
exploded_df = exploded_df.filter(F.col("component_id") != -1)
print("filter exploded_df != -1")
exploded_df.show(5)
final_df = exploded_df.groupBy("component_id").agg(
F.collect_set("exit_comments").alias("exit_comments_set"),
F.collect_set("exit_link").alias("exit_links_set"),
F.collect_set("exit_geometry").alias("exit_geometry_set"),
F.collect_set("dt").alias("dt_set")
) \
.withColumn("exit_comments", F.concat_ws(",", "exit_comments_set")) \
.withColumn("exit_links", F.concat_ws(",", "exit_links_set")) \
.withColumn("exit_geometry", F.concat_ws("@@", "exit_geometry_set")) \
.withColumn("dt", F.concat_ws(",", "dt_set")) \
.select("exit_links", "exit_geometry", "exit_comments", "dt")
五. 机器学习
5.1 构建特征
VectorAssembler是一个Transformer,用来将数据集中多个属性按次序组合成一个类型为向量vector的属性。
from pyspark.ml.feature import VectorAssembler
featureassembler=VectorAssembler(inputCols=["lenth","score"],outputCol="Features")
output=featureassembler.transform(spark_df)
output.show()
5.2 构建label
使用StringIndexer来构建数据集的label,默认的index是从0开始
indexer=StringIndexer(inputCol="type",outputCol="label")
output=indexer.fit(output).transform(output)
output.show()
5.3 训练模型
选择需要的特征后,将数据集拆分,进行训练,这里使用的随机森林模型
finalized_data=output.select("Features","label")
train_data,test_data=finalized_data.randomSplit([0.9,0.1])
rf=RandomForestClassificationModel(labelCol='label',featuresCol='Features',numTrees=20,maxBins=122)
rf=rf.fit(train_data)
rf.save('./model')
六. 常见问题及解决办法
6.1 OOM问题
在使用spark的时候,经常在save数据的时候,都会遇到内存溢出的问题,这通常是由于数据量过大导致的。以下是一些可能的解决方案:
6.1.1 配置方面优化
- 增加分区数:如果数据集非常大,可以尝试增加分区数。可以使用
repartition()
或coalesce()
函数来增加分区数。增加分区数可以将数据均匀地分布在更多的节点上,从而减少每个节点上的内存压力。 - 压缩数据:如果数据集包含大量重复的值,可以考虑使用压缩算法来减少内存使用。Pyspark提供了多种压缩算法,如Snappy、Gzip等。可以使用
option("compression", "snappy")
来设置压缩算法。 - 增加集群资源:可以考虑增加集群资源。可以增加集群的节点数或增加每个节点的内存。可以通过调整
spark.driver.memory
和spark.executor.memory
参数来增加内存分配,特别对于driver而言,最好把内存设置大一些。
6.1.2 代码方面的优化
- UDF过于复杂:尽可能将结果拆分不同的列,然后再用简单的udf来组合这些列进行计算。
- 多用filter算子:提前将大量数据剔除
- 多用select算子:只保留需要的列,减少内存的使用
- 尽量少用collect、count算子:像这些action算子基本都会把executor的数据全部加载回driver上,导致driver的内存吃紧。
6.2 未执行问题
Spark SQL 代码看起来执行到了最后一行,但实际上没有真正执行操作,这通常是因为 Spark 的惰性执行特性导致的。
Spark 采用惰性执行(Lazy Evaluation)机制,这意味着:
- 转换操作(Transformations)不会立即执行
- 只有遇到行动操作(Actions)时才会真正执行
- 打印日志只表示代码被解析,不代表数据被处理
解决方案
- 添加行动操作(Action)
确保你的代码中包含至少一个行动操作:
# 转换操作 (不会立即执行)
df_filtered = df.filter(df.is_roundabout == 0)
# 行动操作 (会触发真正执行)
df_filtered.show() # 显示数据
# 或
df_filtered.count() # 计算行数
# 或
df_filtered.printSchema() #
#或
df_filtered.write.saveAsTable("result_table") # 写入存储
# print 列名不是Action操作
print(tx_name_df.columns)
- 检查 Spark UI
访问 Spark UI (通常位于 http://driver-node:4040
) 查看:
-
是否有作业(Jobs)被提交
-
是否有阶段(Stages)在执行
-
任务(Tasks)是否有进度
-
强制立即执行
可以使用 cache()
或 count()
强制立即执行:
df.cache().count() # 缓存并计算行数,触发执行
- 检查数据源
确保数据源路径正确且可访问:
检查文件是否存在 dbutils.fs.ls(file_path) # 尝试读取小样本测试 spark.read.csv(file_path).limit(10).show()
- 权限问题:检查对输入/输出路径的访问权限
6.资源不足:检查 Executor 是否分配了足够资源