FlinkSQL聚合函数(Aggregate Function)详解

使用场景: 聚合函数即 UDAF,常⽤于进多条数据,出⼀条数据的场景。

上图展示了⼀个 聚合函数的例⼦ 以及 聚合函数包含的重要⽅法

案例场景:

关于饮料的表,有三个字段,分别是 id、name、price,表⾥有 5 ⾏数据,找到所有饮料⾥最贵的饮料的价格,即执⾏⼀个 max() 聚合拿到结果,遍历所有 5 ⾏数据,最终结果就只有⼀个数值。

开发流程:

实现 AggregateFunction 接⼝,其中所有的⽅法必须是 public 的、⾮ static 的;

必须实现以下⽅法:

  • Acc聚合中间结果 createAccumulator() : 为当前 Key 初始化⼀个空的 accumulator,其存储了聚合的中间结果,⽐如在执⾏ max() 时会存储当前的 max 值;
  • accumulate(Acc accumulator, Input输⼊参数) : 每⼀⾏数据,调⽤ accumulate() ⽅法更新 accumulator,处理每⼀条输⼊数据,方法必须声明为 public 和⾮ static 的,accumulate ⽅法可以重载,⽅法的参数类型可以不同,并且⽀持变⻓参数。
  • Output输出参数 getValue(Acc accumulator) : 通过调⽤ getValue ⽅法来计算和返回最终的结果。

某些场景下必须实现:

  • retract(Acc accumulator, Input输⼊参数) : 在回撤流的场景下必须实现,在计算回撤数据时调⽤,如果没有实现会直接报错。
  • merge(Acc accumulator, Iterable it) : 在批式聚合以及流式聚合中的 Session、Hop 窗⼝聚合场景下必须要实现,此外,这个⽅法对于优化也有帮助,例如,打开了两阶段聚合优化,需要 AggregateFunction 实现 merge ⽅法,在数据 shuffle 前先进⾏⼀次聚合计算。
  • resetAccumulator() : 在批式聚合中是必须实现的。

关于⼊参、出参数据类型信息的⽅法:

默认情况下,⽤户的 Input 输⼊参数( accumulate(Acc accumulator, Input输⼊参数) 的⼊参 Input输⼊参数 )、accumulator( Acc聚合中间结果 createAccumulator() 的返回结果)、 Output输出参数数据类型( Output输出参数 getValue(Acc accumulator) 的 Output输出参数 )会被 Flink 使⽤反射获取到。

对于 accumulator 和 Output 输出参数类型,Flink SQL 的类型推导在遇到复杂类型时会推导出错误的结果(注意:Input输⼊参数 因为是上游算⼦传⼊的,类型信息是确认的,不会出现推导错误),⽐如⾮基本类型 POJO 的复杂类型。

同 ScalarFunction 和 TableFunction, AggregateFunction 提供了 AggregateFunction#getResultType() 和AggregateFunction#getAccumulatorType() 指定最终返回值类型和 accumulator 的类型,两个函数的返回值类型是TypeInformation。

  • getResultType() : 即 Output 输出参数 getValue(Acc accumulator) 的输出结果数据类型;
  • getAccumulatorType() : 即 Acc聚合中间结果 createAccumulator() 的返回结果数据类型。

案例: 加权平均值

  • 定义⼀个聚合函数来计算某⼀列的加权平均
  • 在 TableEnvironment 中注册函数
  • 在查询中使⽤函数

实现思路:

为了计算加权平均值,accumulator 需要存储加权总和以及数据的条数,定义了类 WeightedAvgAccumulator 作为 accumulator,Flink 的 checkpoint 机制会⾃动保存 accumulator,在失败时进⾏恢复,保证精确⼀次的语义。

WeightedAvg(聚合函数)的 accumulate ⽅法有三个输⼊参数,第⼀个是 WeightedAvgAccum accumulator,另外两个是⽤户⾃定义的输⼊:输⼊的值 ivalue 和 输⼊的权重 iweight,尽管 retract()、merge()、resetAccumulator() ⽅法在⼤多数聚合类型中都不是必须实现的,但在样例中提供了他们的实现,并且定义了 getResultType() 和 getAccumulatorType()。

代码案例:

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.EnvironmentSettings;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.functions.AggregateFunction;

import java.io.Serializable;

import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;

/**
 * 输入数据:
 * a,1,1
 * a,10,2
 *
 * 输出结果:
 * res1=>:1> +I[a, 1.0]
 * res2=>:1> +I[a, 1.0]
 * res3=>:1> +I[a, 1.0]
 *
 * res1=>:1> -U[a, 1.0]
 * res1=>:1> +U[a, 7.0]
 * res3=>:1> -U[a, 1.0]
 * res3=>:1> +U[a, 7.0]
 * res2=>:1> -U[a, 1.0]
 * res2=>:1> +U[a, 7.0]
 */
public class AggregateFunctionTest {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        EnvironmentSettings settings = EnvironmentSettings.newInstance()
                .useBlinkPlanner()
                .inStreamingMode()
                .build();

        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env, settings);

        DataStreamSource<String> source = env.socketTextStream("localhost", 8888);

        SingleOutputStreamOperator<Tuple3<String, Double, Double>> tpStream = source.map(new MapFunction<String, Tuple3<String, Double, Double>>() {
            @Override
            public Tuple3<String, Double, Double> map(String input) throws Exception {
                return new Tuple3<>(input.split(",")[0],
                        Double.parseDouble(input.split(",")[1]),
                        Double.parseDouble(input.split(",")[2]));
            }
        });

        Table table = tEnv.fromDataStream(tpStream, "field,iValue,iWeight");

        tEnv.createTemporaryView("SourceTable", table);

        Table res1 = tEnv.from("SourceTable")
                .groupBy($("field"))
                .select($("field"), call(WeightedAvg.class, $("iValue"), $("iWeight")));

        // 注册函数
        tEnv.createTemporarySystemFunction("WeightedAvg", WeightedAvg.class);

        // Table API 调⽤函数
        Table res2 = tEnv.from("SourceTable")
                .groupBy($("field"))
                .select($("field"), call("WeightedAvg", $("iValue"), $("iWeight")));

        // SQL API 调⽤函数
        Table res3 = tEnv.sqlQuery("SELECT field, WeightedAvg(`iValue`, iWeight) FROM SourceTable GROUP BY field");

        tEnv.toChangelogStream(res1).print("res1=>");
        tEnv.toChangelogStream(res2).print("res2=>");
        tEnv.toChangelogStream(res3).print("res3=>");

        env.execute();

    }

    // ⾃定义⼀个计算权重 avg 的 accmulator
    public static class WeightedAvgAccumulator implements Serializable {
        public Double sum = 0.0;
        public Double count = 0.0;
    }

    // 输⼊:Long iValue, Integer iWeight
    public static class WeightedAvg extends AggregateFunction<Double, WeightedAvgAccumulator> {
        // 创建⼀个 accumulator
        @Override
        public WeightedAvgAccumulator createAccumulator() {
            return new WeightedAvgAccumulator();
        }

        public void accumulate(WeightedAvgAccumulator acc, Double iValue, Double iWeight) {
            acc.sum += iValue * iWeight;
            acc.count += iWeight;
        }

        public void retract(WeightedAvgAccumulator acc, Double iValue, Double iWeight) {
            acc.sum -= iValue * iWeight;
            acc.count -= iWeight;
        }

        // 获取返回结果
        @Override
        public Double getValue(WeightedAvgAccumulator acc) {
            if (acc.count == 0) {
                return null;
            } else {
                return acc.sum / acc.count;
            }
        }

        // Session window 使⽤这个⽅法将⼏个单独窗⼝的结果合并
        public void merge(WeightedAvgAccumulator acc, Iterable<WeightedAvgAccumulator> it) {
            for (WeightedAvgAccumulator a : it) {
                acc.count += a.count;
                acc.sum += a.sum;
            }
        }

        public void resetAccumulator(WeightedAvgAccumulator acc) {
            acc.count = 0.0;
            acc.sum = 0.0;
        }
    }
}

测试结果:

相关推荐
众拾达人4 分钟前
Android自动化测试实战 Java篇 主流工具 框架 脚本
android·java·开发语言
皓木.6 分钟前
Mybatis-Plus
java·开发语言
不良人天码星7 分钟前
lombok插件不生效
java·开发语言·intellij-idea
守护者17027 分钟前
JAVA学习-练习试用Java实现“使用Arrays.toString方法将数组转换为字符串并打印出来”
java·学习
源码哥_博纳软云29 分钟前
JAVA同城服务场馆门店预约系统支持H5小程序APP源码
java·开发语言·微信小程序·小程序·微信公众平台
禾高网络30 分钟前
租赁小程序成品|租赁系统搭建核心功能
java·人工智能·小程序
学会沉淀。36 分钟前
Docker学习
java·开发语言·学习
如若12338 分钟前
对文件内的文件名生成目录,方便查阅
java·前端·python
初晴~1 小时前
【Redis分布式锁】高并发场景下秒杀业务的实现思路(集群模式)
java·数据库·redis·分布式·后端·spring·
黑胡子大叔的小屋2 小时前
基于springboot的海洋知识服务平台的设计与实现
java·spring boot·毕业设计