Flink SQL自定义函数开发:标量、聚合、表值函数实现

1. Flink函数体系架构

1.1 函数类型全景图

Flink自定义函数分为三大类,满足不同场景的计算需求:

复制代码
自定义函数体系
├── 标量函数 (Scalar Function)
│   ├── 一对一转换
│   └── 无状态计算
├── 表值函数 (Table Function)  
│   ├── 一对多展开
│   └── 返回多行结果
└── 聚合函数 (Aggregate Function)
    ├── 多对一聚合
    └── 有状态计算

1.2 函数开发基础环境

java 复制代码
// Maven依赖配置
<dependencies>
    <dependency>
        <groupId>org.apache.flink</groupId>
        <artifactId>flink-table-api-java-bridge</artifactId>
        <version>1.18.0</version>
        <scope>provided</scope>
    </dependency>
</dependencies>

// 函数基类引入
import org.apache.flink.table.functions.*;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.types.Row;

2. 标量函数开发实战

2.1 基础标量函数实现

一对一数据转换,输入一行输出一行。

java 复制代码
// 手机号脱敏函数
public class PhoneMaskFunction extends ScalarFunction {
    
    public String eval(String phone) {
        if (phone == null || phone.length() != 11) {
            return phone;
        }
        return phone.substring(0, 3) + "****" + phone.substring(7);
    }
    
    // 自动类型推导(可选)
    @Override
    public TypeInference getTypeInference(DataTypeFactory typeFactory) {
        return TypeInference.newBuilder()
            .outputTypeStrategy(TypeStrategies.explicit(DataTypes.STRING()))
            .build();
    }
}

// SQL注册使用
// CREATE FUNCTION phone_mask AS 'com.example.PhoneMaskFunction';
// SELECT phone_mask(user_phone) AS masked_phone FROM users;

2.2 多参数标量函数

支持多个输入参数的复杂计算。

java 复制代码
// 地理距离计算函数
public class GeoDistanceFunction extends ScalarFunction {
    
    private static final double EARTH_RADIUS = 6371.0; // 地球半径(km)
    
    public Double eval(Double lat1, Double lon1, Double lat2, Double lon2) {
        if (lat1 == null || lon1 == null || lat2 == null || lon2 == null) {
            return null;
        }
        
        double dLat = Math.toRadians(lat2 - lat1);
        double dLon = Math.toRadians(lon2 - lon1);
        
        double a = Math.sin(dLat / 2) * Math.sin(dLat / 2) +
                  Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) *
                  Math.sin(dLon / 2) * Math.sin(dLon / 2);
        
        double c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a));
        return EARTH_RADIUS * c;
    }
}

// SQL使用示例
// SELECT geo_distance(lat1, lon1, lat2, lon2) AS distance_km FROM locations;

3. 聚合函数深度开发

3.1 基础聚合函数实现

多行数据聚合成单个结果。

java 复制代码
// 加权平均值聚合函数
public class WeightedAvgAccum {
    public double sum = 0.0;
    public long count = 0;
}

public class WeightedAvgFunction extends AggregateFunction<Double, WeightedAvgAccum> {
    
    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }
    
    public void accumulate(WeightedAvgAccum acc, Double value, Integer weight) {
        if (value != null && weight != null) {
            acc.sum += value * weight;
            acc.count += weight;
        }
    }
    
    public void retract(WeightedAvgAccum acc, Double value, Integer weight) {
        if (value != null && weight != null) {
            acc.sum -= value * weight;
            acc.count -= weight;
        }
    }
    
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> iterable) {
        for (WeightedAvgAccum otherAcc : iterable) {
            acc.sum += otherAcc.sum;
            acc.count += otherAcc.count;
        }
    }
    
    @Override
    public Double getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return null;
        }
        return acc.sum / acc.count;
    }
    
    @Override
    public TypeInference getTypeInference(DataTypeFactory typeFactory) {
        return TypeInference.newBuilder()
            .outputTypeStrategy(TypeStrategies.explicit(DataTypes.DOUBLE()))
            .accumulatorTypeStrategy(TypeStrategies.explicit(
                DataTypes.STRUCTURED(
                    WeightedAvgAccum.class,
                    DataTypes.FIELD("sum", DataTypes.DOUBLE()),
                    DataTypes.FIELD("count", DataTypes.BIGINT())
                )
            ))
            .build();
    }
}

// SQL使用
// SELECT product_id, weighted_avg(rating, user_weight) AS weighted_rating 
// FROM product_reviews GROUP BY product_id;

3.2 复杂状态聚合函数

维护复杂状态结构的聚合计算。

java 复制代码
// 中位数计算聚合函数
public class MedianAccum {
    public List<Double> values = new ArrayList<>();
}

public class MedianFunction extends AggregateFunction<Double, MedianAccum> {
    
    @Override
    public MedianAccum createAccumulator() {
        return new MedianAccum();
    }
    
    public void accumulate(MedianAccum acc, Double value) {
        if (value != null) {
            acc.values.add(value);
        }
    }
    
    public void retract(MedianAccum acc, Double value) {
        if (value != null) {
            acc.values.remove(value);
        }
    }
    
    @Override
    public Double getValue(MedianAccum acc) {
        if (acc.values.isEmpty()) {
            return null;
        }
        
        List<Double> sorted = new ArrayList<>(acc.values);
        Collections.sort(sorted);
        
        int size = sorted.size();
        if (size % 2 == 1) {
            return sorted.get(size / 2);
        } else {
            return (sorted.get(size / 2 - 1) + sorted.get(size / 2)) / 2.0;
        }
    }
    
    // 优化:使用TreeMap维护有序状态(生产环境推荐)
    public static class OptimizedMedianAccum {
        public TreeMap<Double, Long> valueCounts = new TreeMap<>();
        public long totalCount = 0;
    }
}

4. 表值函数开发实战

4.1 基础表值函数

一行输入,多行输出的数据展开。

java 复制代码
// JSON数组展开函数
public class JsonArrayExplodeFunction extends TableFunction<Row> {
    
    private static final ObjectMapper mapper = new ObjectMapper();
    
    public void eval(String jsonArray) {
        if (jsonArray == null || jsonArray.trim().isEmpty()) {
            return;
        }
        
        try {
            JsonNode arrayNode = mapper.readTree(jsonArray);
            if (arrayNode.isArray()) {
                for (JsonNode element : arrayNode) {
                    if (element.isObject()) {
                        collect(Row.of(
                            element.get("id").asText(),
                            element.get("name").asText(),
                            element.get("value").asDouble()
                        ));
                    }
                }
            }
        } catch (Exception e) {
            // 解析失败时跳过
        }
    }
    
    @Override
    public TypeInference getTypeInference(DataTypeFactory typeFactory) {
        return TypeInference.newBuilder()
            .outputTypeStrategy(callContext -> {
                DataType rowType = DataTypes.ROW(
                    DataTypes.FIELD("id", DataTypes.STRING()),
                    DataTypes.FIELD("name", DataTypes.STRING()),
                    DataTypes.FIELD("value", DataTypes.DOUBLE())
                );
                return Optional.of(DataTypes.STRUCTURED(
                    Row.class,
                    rowType.getChildren()
                ));
            })
            .build();
    }
}

// SQL使用
// SELECT original_id, exploded.* 
// FROM source_table, 
// LATERAL TABLE(json_array_explode(json_data)) AS exploded(id, name, value);

5. 高级函数特性实现

5.1 异步表值函数

支持异步IO的高性能表函数。

java 复制代码
// 异步维表关联函数
public class AsyncDimensionLookupFunction extends AsyncTableFunction<Row> {
    
    private final String dimensionTable;
    private transient ExecutorService executor;
    
    public AsyncDimensionLookupFunction(String dimensionTable) {
        this.dimensionTable = dimensionTable;
    }
    
    public void eval(CompletableFuture<Collection<Row>> result, String key) {
        executor.submit(() -> {
            try {
                // 模拟异步查询
                List<Row> dimensionData = queryDimensionTable(key);
                result.complete(dimensionData);
            } catch (Exception e) {
                result.completeExceptionally(e);
            }
        });
    }
    
    @Override
    public void open(FunctionContext context) {
        this.executor = Executors.newFixedThreadPool(10);
    }
    
    @Override
    public void close() {
        if (executor != null) {
            executor.shutdown();
        }
    }
    
    private List<Row> queryDimensionTable(String key) {
        // 实际查询逻辑
        return Collections.singletonList(Row.of(key, "dimension_value"));
    }
}

5.2 函数参数校验与错误处理

健壮的函数实现最佳实践。

java 复制代码
public class SafeStringFunction extends ScalarFunction {
    
    public String eval(String input, String defaultValue) {
        try {
            if (input == null || input.trim().isEmpty()) {
                return defaultValue;
            }
            return processString(input);
        } catch (Exception e) {
            // 记录日志并返回默认值
            System.err.println("String processing failed: " + e.getMessage());
            return defaultValue;
        }
    }
    
    private String processString(String input) {
        // 核心处理逻辑
        return input.trim().toUpperCase();
    }
    
    // 声明函数确定性(优化器使用)
    @Override
    public boolean isDeterministic() {
        return true;
    }
    
    // 声明函数幂等性
    @Override  
    public boolean isResultConstant(Object[] args) {
        return args.length > 0 && args[0] instanceof String 
            && ((String) args[0]).length() < 100; // 小字符串可缓存
    }
}

6. 函数注册与管理

6.1 SQL环境函数注册

java 复制代码
// 编程方式注册函数
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);

// 注册标量函数
tableEnv.createTemporarySystemFunction("phone_mask", new PhoneMaskFunction());

// 注册聚合函数
tableEnv.createTemporarySystemFunction("weighted_avg", new WeightedAvgFunction());

// 注册表值函数
tableEnv.createTemporarySystemFunction("json_explode", new JsonArrayExplodeFunction());

6.2 SQL DDL函数注册

sql 复制代码
-- 在SQL中直接注册函数
CREATE FUNCTION phone_mask AS 'com.example.PhoneMaskFunction';

CREATE FUNCTION weighted_avg AS 'com.example.WeightedAvgFunction';

CREATE FUNCTION json_explode AS 'com.example.JsonArrayExplodeFunction';

-- 查看已注册函数
SHOW FUNCTIONS;

-- 删除函数
DROP FUNCTION IF EXISTS phone_mask;

7. 测试与调试

7.1 单元测试框架

java 复制代码
// 使用JUnit测试自定义函数
public class PhoneMaskFunctionTest {
    
    private PhoneMaskFunction function = new PhoneMaskFunction();
    
    @Test
    public void testPhoneMask() {
        assertEquals("138****1234", function.eval("13812341234"));
        assertNull(function.eval(null));
        assertEquals("123", function.eval("123")); // 短于11位
    }
    
    @Test
    public void testAggregateFunction() throws Exception {
        WeightedAvgFunction avgFunc = new WeightedAvgFunction();
        WeightedAvgAccum accum = avgFunc.createAccumulator();
        
        avgFunc.accumulate(accum, 10.0, 2);
        avgFunc.accumulate(accum, 20.0, 3);
        assertEquals(16.0, avgFunc.getValue(accum), 0.001); // (10 * 2 + 20 * 3)/5
    }
}

7.2 集成测试

java 复制代码
// 在真实TableEnvironment中测试函数
public class FunctionIntegrationTest {
    
    @Test
    public void testFunctionInSQL() {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
        
        tableEnv.createTemporarySystemFunction("phone_mask", new PhoneMaskFunction());
        
        Table result = tableEnv.sqlQuery(
            "SELECT phone_mask('13812341234') AS masked FROM (VALUES(1))"
        );
        
        // 验证执行结果
        DataStream<Row> resultStream = tableEnv.toDataStream(result);
        // ... 添加断言验证
    }
}

8. 总结

自定义函数是扩展Flink SQL能力的核心手段。标量函数实现简单数据转换,聚合函数处理复杂状态计算,表值函数完成数据展开与生成。开发时需重点关注类型安全、状态管理和性能优化,生产环境要确保异常处理和资源清理。通过合理设计函数接口和优化实现逻辑,可以大幅提升流处理应用的表达能力和执行效率。

相关推荐
qq_252614412 小时前
python爬虫爬取视频
开发语言·爬虫·python
PNP Robotics2 小时前
聚焦具身智能,PNP机器人展出力反馈遥操作,VR动作捕捉等方案,获得中国科研贡献奖
大数据·人工智能·python·学习·机器人
咸鱼加辣2 小时前
【python面试】你x的启动?
开发语言·python
八月ouc2 小时前
Python实战小游戏(二): 文字冒险游戏
数据结构·python·文字冒险
Blossom.1182 小时前
多模态大模型实战:从零实现CLIP与电商跨模态检索系统
python·web安全·yolo·目标检测·机器学习·目标跟踪·开源软件
Jackyzhe2 小时前
Flink源码阅读:Checkpoint机制(下)
大数据·flink
wasp5202 小时前
AgentScope深入分析-设计模式与架构决策分分析
开发语言·python·agent·agentscope
山土成旧客2 小时前
【Python学习打卡-Day26】函数的艺术(上):从基础定义到参数魔法
开发语言·python·学习
roman_日积跬步-终至千里2 小时前
【源码分析】StarRocks EditLog 写入与 Replay 完整流程分析
java·网络·python