在 PyFlink (Apache Flink 的 Python API)中,自定义函数分为三种主要类型:ScalarFunction
(标量函数)、TableFunction
(表函数)和 AggregateFunction
(聚合函数)。这些自定义函数可以在 Flink 的 SQL 和 Table API 中使用,用于扩展 PyFlink 的内置功能,处理自定义的计算逻辑。
1. 安装 PyFlink
在开始之前,确保你的环境中已安装了 PyFlink:
bash
pip install apache-flink
2. 自定义标量函数 (ScalarFunction)
ScalarFunction 是最常见的自定义函数。它接受多个输入并返回一个标量值,类似于 SQL 中的普通函数。
2.1 创建一个标量函数
我们可以通过继承 ScalarFunction
类来定义一个自定义标量函数。例如,定义一个计算两个数之和的标量函数:
python
from pyflink.table.udf import ScalarFunction, udf
from pyflink.table import EnvironmentSettings, TableEnvironment
class AddScalarFunction(ScalarFunction):
def eval(self, a, b):
return a + b
# 创建 Table Environment
env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)
# 注册并使用标量函数
add_func = udf(AddScalarFunction(), result_type='BIGINT')
table_env.create_temporary_system_function("add_func", add_func)
# 创建示例数据
table_env.execute_sql("""
CREATE TEMPORARY VIEW input_table (a BIGINT, b BIGINT) AS
VALUES (1, 2), (3, 4), (5, 6)
""")
# 查询并使用自定义标量函数
result = table_env.sql_query("SELECT add_func(a, b) FROM input_table")
result.execute().print()
2.2 代码解析
- 创建标量函数 : 继承
ScalarFunction
类,并实现eval
方法。eval
方法接受多个参数并返回计算结果。 - 注册函数 : 使用
udf()
包装自定义函数并注册到TableEnvironment
中,以便在 SQL 查询中使用。 - 在 SQL 中使用自定义函数 : 使用注册的函数名
add_func
在 SQL 中调用该自定义函数。
3. 自定义表函数 (TableFunction)
TableFunction 返回一个表,而不是单一值。它类似于 SQL 中的 LATERAL VIEW
或 UNNEST
操作,允许将一行数据转换成多行输出。
3.1 创建一个表函数
我们可以通过继承 TableFunction
类来定义自定义表函数。例如,定义一个表函数,将字符串按逗号分割并返回多个值:
python
from pyflink.table.udf import TableFunction, udtf
class SplitTableFunction(TableFunction):
def eval(self, text):
for word in text.split(","):
self.collect(word)
# 注册并使用表函数
split_func = udtf(SplitTableFunction(), result_types=['STRING'])
table_env.create_temporary_system_function("split_func", split_func)
# 创建示例数据
table_env.execute_sql("""
CREATE TEMPORARY VIEW input_table (text STRING) AS
VALUES ('hello,world'), ('foo,bar,baz')
""")
# 查询并使用自定义表函数
result = table_env.sql_query("""
SELECT text, word
FROM input_table, LATERAL TABLE(split_func(text)) AS T(word)
""")
result.execute().print()
3.2 代码解析
- 创建表函数 : 继承
TableFunction
类,并在eval
方法中使用self.collect()
收集每行数据的输出。 - 注册函数 : 使用
udtf()
包装自定义表函数并注册到TableEnvironment
中。 - 在 SQL 中使用表函数 : 使用
LATERAL TABLE
在 SQL 查询中调用自定义表函数,从每行数据生成多个输出。
4. 自定义聚合函数 (AggregateFunction)
AggregateFunction 用于定义自定义的聚合逻辑,类似于 SQL 中的聚合函数(如 SUM
、COUNT
等)。它接收多行输入并返回聚合结果。
4.1 创建一个聚合函数
我们可以通过继承 AggregateFunction
类来定义自定义聚合函数。例如,定义一个求平均值的聚合函数:
python
from pyflink.table.udf import AggregateFunction, udaf
class AvgAggregateFunction(AggregateFunction):
class Accumulator:
def __init__(self):
self.sum = 0
self.count = 0
def get_value(self, accumulator):
return accumulator.sum / accumulator.count if accumulator.count != 0 else 0
def create_accumulator(self):
return AvgAggregateFunction.Accumulator()
def accumulate(self, accumulator, value):
if value is not None:
accumulator.sum += value
accumulator.count += 1
# 注册并使用聚合函数
avg_func = udaf(AvgAggregateFunction(), result_type='DOUBLE', accumulator_type='ROW<sum DOUBLE, count BIGINT>')
table_env.create_temporary_system_function("avg_func", avg_func)
# 创建示例数据
table_env.execute_sql("""
CREATE TEMPORARY VIEW input_table (a BIGINT) AS
VALUES (1), (2), (3), (4), (5)
""")
# 查询并使用自定义聚合函数
result = table_env.sql_query("SELECT avg_func(a) FROM input_table")
result.execute().print()
4.2 代码解析
- 创建聚合函数 : 继承
AggregateFunction
类。create_accumulator
用于创建累加器,accumulate
用于聚合数据,get_value
用于返回聚合结果。 - 定义累加器 : 定义了一个
Accumulator
类来保存聚合的中间状态(例如总和和计数)。 - 注册函数 : 使用
udaf()
包装自定义聚合函数并注册到TableEnvironment
中。 - 在 SQL 中使用聚合函数: 在 SQL 查询中调用自定义聚合函数。
5. 完整示例
以下是一个完整的示例,展示了如何在一个 PyFlink 程序中定义并使用 ScalarFunction
、TableFunction
和 AggregateFunction
:
python
from pyflink.table.udf import ScalarFunction, TableFunction, AggregateFunction, udf, udtf, udaf
from pyflink.table import EnvironmentSettings, TableEnvironment
# Scalar Function: 加法
class AddScalarFunction(ScalarFunction):
def eval(self, a, b):
return a + b
# Table Function: 按逗号分割字符串
class SplitTableFunction(TableFunction):
def eval(self, text):
for word in text.split(","):
self.collect(word)
# Aggregate Function: 计算平均值
class AvgAggregateFunction(AggregateFunction):
class Accumulator:
def __init__(self):
self.sum = 0
self.count = 0
def get_value(self, accumulator):
return accumulator.sum / accumulator.count if accumulator.count != 0 else 0
def create_accumulator(self):
return AvgAggregateFunction.Accumulator()
def accumulate(self, accumulator, value):
if value is not None:
accumulator.sum += value
accumulator.count += 1
# 创建 Table Environment
env_settings = EnvironmentSettings.in_streaming_mode()
table_env = TableEnvironment.create(env_settings)
# 注册自定义函数
add_func = udf(AddScalarFunction(), result_type='BIGINT')
table_env.create_temporary_system_function("add_func", add_func)
split_func = udtf(SplitTableFunction(), result_types=['STRING'])
table_env.create_temporary_system_function("split_func", split_func)
avg_func = udaf(AvgAggregateFunction(), result_type='DOUBLE', accumulator_type='ROW<sum DOUBLE, count BIGINT>')
table_env.create_temporary_system_function("avg_func", avg_func)
# 示例数据
table_env.execute_sql("""
CREATE TEMPORARY VIEW input_table (a BIGINT, text STRING) AS
VALUES (1, 'foo,bar'), (2, 'hello,world'), (3, 'foo,baz')
""")
# 使用标量函数
result = table_env.sql_query("SELECT add_func(a, a) FROM input_table")
result.execute().print()
# 使用表函数
result = table_env.sql_query("""
SELECT text, word
FROM input_table, LATERAL TABLE(split_func(text)) AS T(word)
""")
result.execute().print()
# 使用聚合函数
result = table_env.sql_query("SELECT avg_func(a) FROM input_table")
result.execute().print()
6. 总结
在 PyFlink 中,自定义函数(ScalarFunction
、TableFunction
和 AggregateFunction
)是扩展 Flink SQL 和 Table API 功能的重要工具。通过编写自定义函数,你可以将复杂的业务逻辑集成到 Flink 的数据处理管道中,从而实现更灵活、更强大的数据处理应用。