Apache Flink Keyed State 详解
1. 基本概念
Keyed State(键控状态)是 Flink 中最常用的状态类型之一。它与特定的 key 相关联,只能在 KeyedStream(即 keyBy() 之后)的算子中使用。Keyed State 的特点是:
- 作用域限定在当前元素的 key 上
- 每个 key 都有其独立的状态副本
- 状态会根据 key 自动分区和分布
2. 适用场景
Keyed State 适用于以下场景:
- 需要按 key 聚合数据的场景
- 需要维护每个 key 的状态信息
- 窗口操作中的状态管理
- 用户自定义函数中需要维护 key 相关状态的场景
3. Keyed State 类型
Flink 提供了多种 Keyed State 类型:
3.1 ValueState<T>
ValueState 用于存储单个值,每个 key 对应一个值。
示例代码(Java):
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
/**
 * ValueState example: counts how many times each user has visited.
 *
 * <p>Each element is a user id; the stream is presumably keyed by that id, so
 * {@code visitCountState} holds one counter per user (confirm against the keyBy call).
 */
public class UserVisitCountFunction extends KeyedProcessFunction<String, String, String> {

    /** Per-key visit counter; {@code null} until the first element for a key is processed. */
    private ValueState<Integer> visitCountState;

    @Override
    public void open(Configuration parameters) {
        // The ValueStateDescriptor constructor that takes a default value is deprecated.
        // Use the plain (name, type) constructor and treat a missing value (null) as 0 below.
        ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>(
                "visit-count", // state name
                Integer.class  // state type
        );
        visitCountState = getRuntimeContext().getState(descriptor);
    }

    @Override
    public void processElement(String userId, Context ctx, Collector<String> out) throws Exception {
        // value() returns null for a key that has no state yet.
        Integer stored = visitCountState.value();
        int currentCount = (stored == null ? 0 : stored) + 1;
        visitCountState.update(currentCount);
        out.collect("User " + userId + " has visited " + currentCount + " times");
    }
}
3.2 ListState<T>
ListState 用于存储元素列表,每个 key 对应一个元素列表。
示例代码(Java):
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.List;
/**
 * ListState example: keeps the most recent visit records of each key (e.g. each user).
 *
 * <p>Every incoming record is appended to the per-key list. Once the list grows beyond
 * {@code MAX_RECENT_VISITS} entries, the oldest ones are dropped. After each element, the
 * retained records are emitted in insertion order (oldest first).
 */
public class RecentVisitsFunction extends KeyedProcessFunction<String, String, List<String>> {

    /** Maximum number of visit records retained per key. */
    private static final int MAX_RECENT_VISITS = 5;

    private ListState<String> recentVisitsState;

    @Override
    public void open(Configuration parameters) {
        ListStateDescriptor<String> descriptor = new ListStateDescriptor<>(
                "recent-visits", // state name
                String.class     // element type
        );
        // A ListState must be obtained via getListState; getState only accepts a ValueStateDescriptor.
        recentVisitsState = getRuntimeContext().getListState(descriptor);
    }

    @Override
    public void processElement(String visitRecord, Context ctx, Collector<List<String>> out) throws Exception {
        // Append the new record, then materialize the whole list for this key.
        recentVisitsState.add(visitRecord);
        List<String> visitList = new ArrayList<>();
        for (String visit : recentVisitsState.get()) {
            visitList.add(visit);
        }

        if (visitList.size() > MAX_RECENT_VISITS) {
            // Keep only the newest entries. Copy so that we neither emit nor store a subList view.
            visitList = new ArrayList<>(
                    visitList.subList(visitList.size() - MAX_RECENT_VISITS, visitList.size()));
            // update() replaces the whole list in one call (instead of clear() + repeated add()).
            recentVisitsState.update(visitList);
        }

        out.collect(visitList);
    }
}
3.3 MapState<UK, UV>
MapState 用于存储键值对映射,每个 key 对应一个 Map。
示例代码(Java):
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
/**
 * MapState example: for the current key, counts how many times each page has been visited.
 *
 * <p>The stream key is presumably the user (confirm against the keyBy call). Inside each key's
 * state, the map key is the page and the map value is that page's visit count.
 */
public class PageVisitCountFunction
        extends KeyedProcessFunction<String, PageVisitCountFunction.PageVisit, String> {
    // PageVisit must be qualified in the extends clause: nested types are only in scope inside
    // the class body, not in its header.

    /** page -> visit count for the current key. */
    private MapState<String, Integer> pageVisitCountState;

    @Override
    public void open(Configuration parameters) {
        MapStateDescriptor<String, Integer> descriptor = new MapStateDescriptor<>(
                "page-visit-count", // state name
                String.class,       // map key type
                Integer.class       // map value type
        );
        // A MapState must be obtained via getMapState; getState only accepts a ValueStateDescriptor.
        pageVisitCountState = getRuntimeContext().getMapState(descriptor);
    }

    @Override
    public void processElement(PageVisit pageVisit, Context ctx, Collector<String> out) throws Exception {
        String page = pageVisit.getPage();
        String user = pageVisit.getUser();
        // get() returns null for a page that has no entry yet under this key.
        Integer previous = pageVisitCountState.get(page);
        int currentCount = (previous == null ? 0 : previous) + 1;
        pageVisitCountState.put(page, currentCount);
        out.collect("User " + user + " visited page " + page + " " + currentCount + " times");
    }

    /**
     * A single page-visit event.
     *
     * <p>It has a public no-arg constructor plus a getter and setter for every field. This lets
     * Flink treat it as a POJO type instead of falling back to generic (Kryo) serialization.
     */
    public static class PageVisit {
        private String user;
        private String page;

        public PageVisit() {}

        public PageVisit(String user, String page) {
            this.user = user;
            this.page = page;
        }

        public String getUser() { return user; }

        public void setUser(String user) { this.user = user; }

        public String getPage() { return page; }

        public void setPage(String page) { this.page = page; }
    }
}
3.4 ReducingState<T>
ReducingState 用于存储聚合值:每次调用 add() 时,通过 ReduceFunction 将新元素与当前聚合值合并,因此输入元素与聚合结果的类型必须相同。
示例代码(Java):
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
/**
 * ReducingState example: maintains the running total amount per key (presumably per user).
 *
 * <p>Each transaction amount is added to the ReducingState, which combines it with the current
 * total via {@code SumReduceFunction`. The updated total is emitted after every element.
 * {@code Double} is used here for brevity; real monetary amounts should use BigDecimal or a
 * scaled long.
 */
public class TotalAmountFunction
        extends KeyedProcessFunction<String, TotalAmountFunction.Transaction, Double> {
    // Transaction must be qualified in the extends clause: nested types are only in scope inside
    // the class body, not in its header.

    /** Running sum of all amounts seen for the current key. */
    private ReducingState<Double> totalAmountState;

    @Override
    public void open(Configuration parameters) {
        ReducingStateDescriptor<Double> descriptor = new ReducingStateDescriptor<>(
                "total-amount",           // state name
                new SumReduceFunction(),  // reduce function
                Double.class              // state type
        );
        // A ReducingState must be obtained via getReducingState; getState only accepts a ValueStateDescriptor.
        totalAmountState = getRuntimeContext().getReducingState(descriptor);
    }

    @Override
    public void processElement(Transaction transaction, Context ctx, Collector<Double> out) throws Exception {
        // add() folds the amount into the stored total using SumReduceFunction.
        totalAmountState.add(transaction.getAmount());
        out.collect(totalAmountState.get());
    }

    /** Sums two partial totals. */
    public static class SumReduceFunction implements ReduceFunction<Double> {
        @Override
        public Double reduce(Double value1, Double value2) throws Exception {
            return value1 + value2;
        }
    }

    /**
     * A single transaction record.
     *
     * <p>It has a public no-arg constructor plus a getter and setter for every field, so Flink can
     * serialize it as a POJO.
     */
    public static class Transaction {
        private String user;
        private Double amount;

        public Transaction() {}

        public Transaction(String user, Double amount) {
            this.user = user;
            this.amount = amount;
        }

        public String getUser() { return user; }

        public void setUser(String user) { this.user = user; }

        public Double getAmount() { return amount; }

        public void setAmount(Double amount) { this.amount = amount; }
    }
}
3.5 AggregatingState<IN, OUT>
AggregatingState 与 ReducingState 类似,但通过 AggregateFunction 进行聚合,其输入类型(IN)、中间累加器类型(ACC)和输出类型(OUT)可以各不相同。
示例代码(Java):
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
/**
 * AggregatingState example: maintains the average transaction amount per key (presumably per user).
 *
 * <p>Each transaction is added to the AggregatingState. The state keeps an
 * {@code AverageAccumulator} (sum and count) and exposes the current average through
 * {@code AverageAggregateFunction.getResult}.
 */
public class AverageAmountFunction
        extends KeyedProcessFunction<String, AverageAmountFunction.Transaction, Double> {
    // Transaction must be qualified in the extends clause: nested types are only in scope inside
    // the class body, not in its header.

    /** Accepts transactions and yields the current average amount for the key. */
    private AggregatingState<Transaction, Double> averageAmountState;

    @Override
    public void open(Configuration parameters) {
        AggregatingStateDescriptor<Transaction, AverageAccumulator, Double> descriptor =
                new AggregatingStateDescriptor<>(
                        "average-amount",               // state name
                        new AverageAggregateFunction(), // aggregate function
                        AverageAccumulator.class        // accumulator TYPE; createAccumulator() builds the initial value
                );
        // An AggregatingState must be obtained via getAggregatingState; getState only accepts a ValueStateDescriptor.
        averageAmountState = getRuntimeContext().getAggregatingState(descriptor);
    }

    @Override
    public void processElement(Transaction transaction, Context ctx, Collector<Double> out) throws Exception {
        averageAmountState.add(transaction);
        out.collect(averageAmountState.get());
    }

    /** Aggregates transactions into (sum, count) and reports sum / count, or 0.0 when empty. */
    public static class AverageAggregateFunction
            implements AggregateFunction<Transaction, AverageAccumulator, Double> {
        @Override
        public AverageAccumulator createAccumulator() {
            return new AverageAccumulator();
        }

        @Override
        public AverageAccumulator add(Transaction transaction, AverageAccumulator accumulator) {
            accumulator.sum += transaction.getAmount();
            accumulator.count++;
            return accumulator;
        }

        @Override
        public Double getResult(AverageAccumulator accumulator) {
            if (accumulator.count == 0) {
                return 0.0;
            }
            return accumulator.sum / accumulator.count;
        }

        @Override
        public AverageAccumulator merge(AverageAccumulator a, AverageAccumulator b) {
            a.sum += b.sum;
            a.count += b.count;
            return a;
        }
    }

    /** Running sum and element count used to compute the average. */
    public static class AverageAccumulator {
        public double sum = 0.0;
        public long count = 0;
    }

    /**
     * A single transaction record.
     *
     * <p>It has a public no-arg constructor plus a getter and setter for every field, so Flink can
     * serialize it as a POJO.
     */
    public static class Transaction {
        private String user;
        private Double amount;

        public Transaction() {}

        public Transaction(String user, Double amount) {
            this.user = user;
            this.amount = amount;
        }

        public String getUser() { return user; }

        public void setUser(String user) { this.user = user; }

        public Double getAmount() { return amount; }

        public void setAmount(Double amount) { this.amount = amount; }
    }
}
4. 配置方法
4.1 状态描述符
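每种状态类型都有对应的描述符,并且必须通过 RuntimeContext 中对应的方法获取状态句柄:ValueStateDescriptor → getState()、ListStateDescriptor → getListState()、MapStateDescriptor → getMapState()、ReducingStateDescriptor → getReducingState()、AggregatingStateDescriptor → getAggregatingState()。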
示例代码(Java):
// ValueStateDescriptor -> getRuntimeContext().getState(...)
// The constructor taking a default value is deprecated; a missing value reads as null.
ValueStateDescriptor<Integer> valueDescriptor = new ValueStateDescriptor<>(
        "value-state-name",
        Integer.class
);
// ListStateDescriptor -> getRuntimeContext().getListState(...)
ListStateDescriptor<String> listDescriptor = new ListStateDescriptor<>(
        "list-state-name",
        String.class
);
// MapStateDescriptor -> getRuntimeContext().getMapState(...)
MapStateDescriptor<String, Integer> mapDescriptor = new MapStateDescriptor<>(
        "map-state-name",
        String.class,   // map key type
        Integer.class   // map value type
);
// ReducingStateDescriptor -> getRuntimeContext().getReducingState(...)
// Reuses the reduce function from the TotalAmountFunction example.
ReducingStateDescriptor<Double> reducingDescriptor = new ReducingStateDescriptor<>(
        "reducing-state-name",
        new TotalAmountFunction.SumReduceFunction(),
        Double.class
);
// AggregatingStateDescriptor -> getRuntimeContext().getAggregatingState(...)
// Reuses the types from the AverageAmountFunction example. The third argument is the accumulator
// TYPE, not an instance; the initial accumulator comes from AggregateFunction.createAccumulator().
AggregatingStateDescriptor<AverageAmountFunction.Transaction, AverageAmountFunction.AverageAccumulator, Double>
        aggregatingDescriptor = new AggregatingStateDescriptor<>(
                "aggregating-state-name",
                new AverageAmountFunction.AverageAggregateFunction(),
                AverageAmountFunction.AverageAccumulator.class
        );
4.2 状态 TTL(Time-To-Live)
可以为状态设置 TTL(生存时间),使过期数据不再可见并被自动清理。注意:目前 Flink 仅支持基于处理时间(processing time)的 TTL。
示例代码(Java):
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
// Configure state TTL.
StateTtlConfig ttlConfig = StateTtlConfig
        .newBuilder(Time.hours(1)) // TTL of 1 hour
        .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite) // refresh the TTL on create and write (not on read)
        .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired) // expired values are never returned
        .cleanupFullSnapshot() // remove expired entries when a full snapshot is taken
        .build();
// Attach the TTL configuration to the state descriptor.
// The default-value constructor is deprecated; a missing or expired value reads as null,
// so the code that reads this state must handle null.
ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>(
        "ttl-state",
        Integer.class
);
descriptor.enableTimeToLive(ttlConfig);
5. 完整使用示例
示例代码(Java):
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
/**
 * Keyed State end-to-end example: counts visits per user in real time.
 *
 * <p>Builds a small bounded stream of user ids, keys it by the id itself and runs
 * {@code VisitCountFunction}, which keeps one counter per user in a ValueState.
 */
public class KeyedStateExample {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // Sample user-visit stream: each element is a user id.
        DataStream<String> userVisits = env.fromElements(
                "user1", "user2", "user1", "user3", "user2", "user1"
        );

        // Key by the user id so that each user gets its own counter state.
        DataStream<String> visitCounts = userVisits
                .keyBy(user -> user)
                .process(new VisitCountFunction());

        visitCounts.print();
        env.execute("Keyed State Example");
    }

    /** Counts how many times each key (user id) has been seen. */
    public static class VisitCountFunction extends KeyedProcessFunction<String, String, String> {

        /** Per-user counter; {@code null} until the user's first element is processed. */
        private ValueState<Integer> countState;

        @Override
        public void open(Configuration parameters) {
            // Use the non-deprecated constructor (no default value); null is treated as 0 below.
            ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>(
                    "visit-count",
                    Integer.class
            );
            countState = getRuntimeContext().getState(descriptor);
        }

        @Override
        public void processElement(String user, Context ctx, Collector<String> out) throws Exception {
            Integer stored = countState.value();
            int count = (stored == null ? 0 : stored) + 1;
            countState.update(count);
            out.collect("User " + user + " has visited " + count + " times");
        }
    }
}
6. 最佳实践建议
- 合理选择状态类型:
  - 单个值使用 ValueState
  - 列表数据使用 ListState
  - 键值对数据使用 MapState
  - 需要聚合的数据使用 ReducingState 或 AggregatingState
- 状态命名规范:
  - 使用有意义的名称
  - 同一算子内避免重复的状态名称
  - 建议使用小写字母和连字符
- 状态清理:
  - 及时清理不需要的状态
  - 使用 TTL 自动清理过期数据
  - 在适当的时候调用 clear() 方法
- 性能优化:
  - 控制单个 key 的状态大小,避免其无界增长(例如不断追加元素的 ListState)
  - 根据状态规模合理选择状态后端
  - 状态总量较大(超出内存容量)时,考虑使用 RocksDB 状态后端
- 容错处理:
  - 确保状态操作的幂等性
  - 处理状态恢复时的异常情况
  - 启用定期检查点(checkpoint),以保证故障恢复后状态的一致性
通过合理使用 Keyed State,可以有效地在 Flink 应用程序中维护和处理与特定键相关联的状态信息,实现复杂的状态管理和计算逻辑。