1. UDF函数(用户自定义函数)
一般指的是用户自己定义的单行函数。一进一出,函数接受的是一行中的一个或者多个字段值,返回一个值。比如MySQL中的,日期相关的dateDiff函数,字符串相关的substring函数。
先准备数据:
1.1 导入必要的包
首先,确保导入必要的Spark包:
Scala
import org.apache.spark.sql.SparkSession
1.2 创建SparkSession
创建一个SparkSession对象,这是与Spark交互的入口。
1.3 定义UDF并注册到SparkSQL
定义一个Scala函数,并将其注册为UDF。示例
1.4 使用UDF在SQL查询中:
调用udf的register方法,第一个参数是udf函数的函数名,第二个参数是要注册为UDF的函数。
Scala
session.udf.register("all_income",(sal:Int,bonus:Int)=>{
sal*12 + bonus
})
1.5 代码:
尽量使用SparkSQL的sql形式的写法,api写法太麻烦了。
Scala
object TestUDF{
def main(args: Array[String]): Unit = {
val session = SparkSession.builder().master("local[*]").appName("testUDF").getOrCreate()
import session.implicits._
val df = session.sparkContext.textFile("D:\\software\\Spark\\SparkProgram1\\atguigu-classes\\data\\a.txt")
.map(t => {
val strs = t.split(" ")
(strs(0), strs(1), strs(2).toInt, strs(3).toInt)
}).toDF("id", "name", "salary", "bonus")
session.udf.register("all_income",(sal:Int,bonus:Int)=>{
sal*12 + bonus
})
import org.apache.spark.sql.functions
// df.withColumn("all",functions.callUDF("all_income",$"salary",$"bonus"))
// .select("id","name","all")
// .show()
df.createTempView("salary")
session.sql(
"""
|select id,name,all_income(salary,bonus) all from salary
|""".stripMargin)
.show()
}
}
输出:
2. UDAF(用户自定义的聚合函数)
指的是用户自定义的聚合函数,多进一出,比如MySQL中的,count函数,avg函数。
以学生信息为主进行统计,所有人员的年龄的总和
或者每个性别的年龄的平均值
计算所有人的年龄之和:
Scala
package com.atguigu.bigdata.test
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator
/**
* ClassName : TestUDAF
* Package : com.atguigu.bigdata.test
* Description
*
* @Author HeXua
* @Create 2024/11/29 19:09
* Version 1.0
*/
object TestUDAF {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder().appName("test udaf").master("local[*]").getOrCreate()
import session.implicits._
val df = session.sparkContext.textFile("D:\\software\\Spark\\SparkProgram1\\atguigu-classes\\data\\a.txt")
.map(t => {
val strs = t.split(" ")
(strs(0), strs(1), strs(2).toInt, strs(3))
}).toDF("id", "name", "age", "gender")
import org.apache.spark.sql.functions._
// 注册udaf函数
session.udf.register("mysum",udaf(new MySum))
df.createTempView("student")
session.sql(
"""
|select mysum(age) from student
|""".stripMargin)
.show()
}
}
// udaf的类继承Aggregator抽象类
class MySum extends Aggregator[Int,Int,Int]{
//初始化
def zero: Int = 0
//聚合逻辑
def reduce(b: Int, a: Int): Int = a+b
//整体聚合
def merge(b1: Int, b2: Int): Int = b1+b2
//最终返回值
def finish(reduction: Int): Int = reduction
//累加值的类型
def bufferEncoder: Encoder[Int] = Encoders.scalaInt
//输出结果的类型
def outputEncoder: Encoder[Int] = Encoders.scalaInt
}
定义用户自定义聚合函数时,继承Aggregator类需要指定三个泛型参数。这三个泛型参数分别代表不同的概念。
泛型参数解释:
- 输入类型(IN)
这是聚合函数的输入类型,即每次调用reduce方法时传入的单个元素的类型。例如你要计算一组整数的平均值,输入类型就是int。
- 缓冲区类型(BUFFER)
这是聚合函数的中间状态类型,也称为缓冲区类型。
例如你要计算一组整数的平均值,缓冲区可能包含两个字段:总和和计数,因为iBUF可能是一个元组。
- 输出类型(OUT)
这是聚合函数的最终输出类型,即finish方法返回的类型。例如你要计算平均值,最终输出类型是Double。
方法解释:
zero:初始化缓冲区的值,对于平均值计算,初始化和计数都是0。
reduce:更新缓冲区,每次传入一个新的输入值时,更新总和和计数。
finish:计算最终结果,根据缓冲区中的总和和计数,计算平均值。
bufferEncoder:定义缓冲区类型的编码器,用于序列化和反序列化缓冲区。
outputEncoder:定义最终输出类型的编码器,用于序列化和反序列化输出结果。
计算每个性别的年龄的平均值:
Scala
case class AggragateVo(var cnt:Int,var sum:Int)
object MyAvg extends Aggregator[Int,AggragateVo,Double]{
override def zero: AggragateVo = AggragateVo(0,0)
override def reduce(b: AggragateVo, a: Int): AggragateVo = {
b.cnt += 1
b.sum += a
b
}
override def merge(b1: AggragateVo, b2: AggragateVo): AggragateVo = {
b1.cnt += b2.cnt
b1.sum += b2.sum
b1
}
override def finish(reduction: AggragateVo): Double = {
reduction.sum.toDouble /reduction.cnt
}
override def bufferEncoder: Encoder[AggragateVo] = Encoders.product
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
3. UDTF(用户自定义炸裂函数)
拆分函数,进入的是一行内容出现的结果是多行内容。
spark中并不直接支持UDTF函数。但可以使用hive中的炸裂函数达到效果。
Scala
import org.apache.spark.sql.SparkSession
object TestUDTF {
def main(args: Array[String]): Unit = {
val session = SparkSession.builder().appName("test udtf").master("local[*]").getOrCreate()
import session.implicits._
val df = session.sparkContext.textFile("file:///headless/workspace/spark/data/m.txt")
.map(t => {
val strs = t.split(",")
(strs(0), strs(1), strs(2))
}).toDF("id", "name", "actors")
//explode map array
df.createTempView("movies")
session.sql(
"""
|select id,name,actor from movies lateral view explode(split(actors,'\\|')) t as actor
|""".stripMargin)
.createTempView("movies1")
session.sql(
"""
|select count(1),actor from movies1 group by actor
|""".stripMargin)
.show()
}
}