Built-in methods of the pyspark.sql.group.GroupedData class
agg (aggregation)
df = spark.createDataFrame(
    [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"])
df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 3|Alice|
| 5| Bob|
| 10| Bob|
+---+-----+
df.groupBy(df.name)
<pyspark.sql.group.GroupedData object at 0x7f74be2e4240>
df.groupBy(df.name).agg({'*':'count'}).show()
+-----+--------+
| name|count(1)|
+-----+--------+
|Alice| 2|
| Bob| 2|
+-----+--------+
df.groupBy(df.name).agg({'age':'min'}).sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
from pyspark.sql.functions import lit, col, min
df.groupBy(df.name).agg(min(df.age)).sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
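agg can also take several column expressions at once, and alias renames the resulting columns; a small sketch on the same df (the min_age / max_age names are just illustrative):
from pyspark.sql.functions import min, max
df.groupBy(df.name).agg(
    min(df.age).alias("min_age"),
    max(df.age).alias("max_age"),
).sort("name").show()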
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
@pandas_udf('int', PandasUDFType.GROUPED_AGG)
def min_udf(v):
    print(type(v))
    print(v)
    return v.min()
df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()
<class 'pandas.core.series.Series'>
0 2
1 3
+-----+------------+
| name|min_udf(age)|
+-----+------------+
|Alice| 2|
| Bob| 5|
+-----+------------+
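Note: in Spark 3.0 and later the same grouped-aggregate UDF can be declared with Python type hints instead of PandasUDFType.GROUPED_AGG; a minimal sketch, assuming a Spark 3.x environment with pandas installed:
import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("int")
def min_udf(v: pd.Series) -> int:
    # each group's age values arrive as a single pandas Series
    return int(v.min())

df.groupBy(df.name).agg(min_udf(df.age)).sort("name").show()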
apply
Parameter: a function decorated with pandas_udf
pyspark.sql.functions.pandas_udf()
from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
df.show()
+---+----+
| id| v|
+---+----+
| 1| 1.0|
| 1| 2.0|
| 2| 3.0|
| 2| 5.0|
| 2|10.0|
+---+----+
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
    print(pdf)
    v = pdf.v
    d1 = pdf.assign(v=(v - v.mean()) / v.std())
    print(d1, type(d1))
    return d1
df.groupby("id").apply(normalize).show()
id v
0 1 1.0
1 1 2.0
id v
0 1 -0.707107
1 1 0.707107 <class 'pandas.core.frame.DataFrame'>
id v
0 2 3.0
1 2 5.0
2 2 10.0
id v
0 2 -0.83205
1 2 -0.27735
2 2 1.10940 <class 'pandas.core.frame.DataFrame'>
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
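Note: since Spark 3.0, GroupedData.apply() is kept only for backward compatibility and emits a deprecation warning; the recommended entry point is applyInPandas, covered in the next section, which behaves the same way but takes a plain function plus an explicit schema.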
applyInPandas
Parameters: a plain Python function and the schema of the DataFrame it returns
import pandas as pd
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
def normalize(pdf):
    v = pdf.v
    return pdf.assign(v=(v - v.mean()) / v.std())
df.groupby("id").applyInPandas(
    normalize, schema="id long, v double").show()
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
# Returning an empty DataFrame produces no output rows for the group
def normalize(pdf):
    v = pdf.v
    return pd.DataFrame()
df.groupby("id").applyInPandas(
    normalize, schema="id long, v double").show()
+---+---+
| id| v|
+---+---+
+---+---+
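applyInPandas also accepts a function of two arguments, where the first argument is the grouping key as a tuple; a minimal sketch computing the per-group mean (the group_mean function and the mean_v column name are just illustrative):
def group_mean(key, pdf):
    # key is a tuple holding the grouping value(s), e.g. (1,) or (2,)
    return pd.DataFrame([{"id": key[0], "mean_v": pdf.v.mean()}])

df.groupby("id").applyInPandas(
    group_mean, schema="id long, mean_v double").show()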
avg (alias: mean)
df = spark.createDataFrame([
    (2, "Alice", 80), (3, "Alice", 100),
    (5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"])
df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
# Use the avg method directly
df.groupBy("name").avg('age').sort("name").show()
+-----+--------+
| name|avg(age)|
+-----+--------+
|Alice| 2.5|
| Bob| 7.5|
+-----+--------+
# Equivalent form using agg
df.groupBy("name").agg({"age":'avg'}).sort("name").show()
+-----+--------+
| name|avg(age)|
+-----+--------+
|Alice| 2.5|
| Bob| 7.5|
+-----+--------+
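mean is an alias for avg, so the following call returns the same result:
df.groupBy("name").mean("age").sort("name").show()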
count
df = spark.createDataFrame(
    [(2, "Alice"), (3, "Alice"), (5, "Bob"), (10, "Bob")], ["age", "name"])
df.show()
+---+-----+
|age| name|
+---+-----+
| 2|Alice|
| 3|Alice|
| 5| Bob|
| 10| Bob|
+---+-----+
df.groupBy(df.name).count().sort("name").show()
+-----+-----+
| name|count|
+-----+-----+
|Alice| 2|
| Bob| 2|
+-----+-----+
df.groupBy("name").agg({"name":'count'}).sort("name").show()
+-----+-----------+
| name|count(name)|
+-----+-----------+
|Alice| 2|
| Bob| 2|
+-----+-----------+
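count() counts all rows in each group; to count distinct values instead, agg can be combined with pyspark.sql.functions.countDistinct. A small sketch on the same df:
from pyspark.sql.functions import countDistinct
df.groupBy("name").agg(countDistinct("age")).sort("name").show()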
max
df = spark.createDataFrame([
(2, "Alice", 80), (3, "Alice", 100),
(5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"])
df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
df.groupBy("name").max("age").sort("name").show()
+-----+--------+
| name|max(age)|
+-----+--------+
|Alice| 3|
| Bob| 10|
+-----+--------+
df.groupBy("name").agg({"age":'max','height':"min"}).sort("name").show()
+-----+--------+-----------+
| name|max(age)|min(height)|
+-----+--------+-----------+
|Alice| 3| 80|
| Bob| 10| 120|
+-----+--------+-----------+
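Called without arguments, max (like min, avg and sum) is computed over every numeric column in each group:
df.groupBy("name").max().sort("name").show()
# one max(...) column per numeric column, here max(age) and max(height)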
min
df = spark.createDataFrame([
(2, "Alice", 80), (3, "Alice", 100),
(5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"])
df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
df.groupBy("name").min("age").sort("name").show()
+-----+--------+
| name|min(age)|
+-----+--------+
|Alice| 2|
| Bob| 5|
+-----+--------+
df.groupBy("name").agg({"age":'max','height':"min"}).sort("name").show()
+-----+--------+-----------+
| name|max(age)|min(height)|
+-----+--------+-----------+
|Alice| 3| 80|
| Bob| 10| 120|
+-----+--------+-----------+
sum
df = spark.createDataFrame([
(2, "Alice", 80), (3, "Alice", 100),
(5, "Bob", 120), (10, "Bob", 140)], ["age", "name", "height"])
df.show()
+---+-----+------+
|age| name|height|
+---+-----+------+
| 2|Alice| 80|
| 3|Alice| 100|
| 5| Bob| 120|
| 10| Bob| 140|
+---+-----+------+
df.groupBy("name").sum("age").sort("name").show()
+-----+--------+
| name|sum(age)|
+-----+--------+
|Alice| 5|
| Bob| 15|
+-----+--------+
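sum (like the other numeric aggregations) also accepts several column names at once:
df.groupBy("name").sum("age", "height").sort("name").show()
# yields sum(age) and sum(height) per name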
pivot
After grouping, pivot turns the specified values of a chosen column into separate columns and computes the aggregation for each of them.
from pyspark.sql import Row
df1 = spark.createDataFrame([
    Row(course="dotNET", year=2012, earnings=10000),
    Row(course="Java", year=2012, earnings=20000),
    Row(course="dotNET", year=2012, earnings=5000),
    Row(course="dotNET", year=2013, earnings=48000),
    Row(course="Java", year=2013, earnings=30000),
])
df1.show()
+------+----+--------+
|course|year|earnings|
+------+----+--------+
|dotNET|2012| 10000|
| Java|2012| 20000|
|dotNET|2012| 5000|
|dotNET|2013| 48000|
| Java|2013| 30000|
+------+----+--------+
df1.groupBy("year").pivot("course", ["dotNET", "Java"]).max("earnings").show()
+----+------+-----+
|year|dotNET| Java|
+----+------+-----+
|2012| 10000|20000|
|2013| 48000|30000|
+----+------+-----+
df1.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").show()
+----+------+-----+
|year|dotNET| Java|
+----+------+-----+
|2012| 15000|20000|
|2013| 48000|30000|
+----+------+-----+
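The second argument of pivot is optional; when it is omitted, Spark first computes the distinct values of the pivot column itself, which is more flexible but less efficient than listing them explicitly. A minimal sketch:
# Spark determines the pivot values ("dotNET", "Java") at runtime
df1.groupBy("year").pivot("course").sum("earnings").show()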