Apache Flink Keyed State 详解之一

1. 基本概念

Keyed State(键控状态)是 Flink 中最常用的状态类型之一。它与特定的 key 相关联,只能在 KeyedStream 上使用。Keyed State 的特点是:

  • 作用域限定在当前元素的 key 上
  • 每个 key 都有其独立的状态副本
  • 状态会根据 key 自动分区和分布

2. 适用场景

Keyed State 适用于以下场景:

  1. 需要按 key 聚合数据的场景
  2. 需要维护每个 key 的状态信息
  3. 窗口操作中的状态管理
  4. 用户自定义函数中需要维护 key 相关状态的场景

3. Keyed State 类型

Flink 提供了多种 Keyed State 类型:

3.1 ValueState

ValueState 用于存储单个值,每个 key 对应一个值;当前 key 尚未写入任何值时,value() 返回 null。

Java 示例代码:
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * ValueState example: counts how many times each key (a user in this example) has been seen.
 *
 * <p>For every incoming element, increments the per-key counter and emits a message with the
 * updated count.
 */
public class UserVisitCountFunction extends KeyedProcessFunction<String, String, String> {
    /** Per-key visit counter; {@code value()} is {@code null} until the key's first element. */
    private ValueState<Integer> visitCountState;

    @Override
    public void open(Configuration parameters) {
        // The (name, class, defaultValue) constructor is deprecated in Flink; the "no value yet"
        // case is handled explicitly in processElement instead.
        ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>(
            "visit-count",  // state name
            Integer.class   // state type
        );
        visitCountState = getRuntimeContext().getState(descriptor);
    }

    @Override
    public void processElement(String userId, Context ctx, Collector<String> out) throws Exception {
        // value() returns null when nothing has been stored for the current key yet
        Integer storedCount = visitCountState.value();
        int currentCount = (storedCount == null ? 0 : storedCount) + 1;

        visitCountState.update(currentCount);

        out.collect("User " + userId + " has visited " + currentCount + " times");
    }
}

3.2 ListState

ListState 用于存储元素列表,每个 key 对应一个元素列表。

Java 示例代码:
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.List;

/**
 * ListState example: keeps the most recent visit records for each key (a user in this example).
 *
 * <p>After every element, emits at most {@code MAX_RECENT_VISITS} records for the current key,
 * oldest first.
 */
public class RecentVisitsFunction extends KeyedProcessFunction<String, String, List<String>> {
    /** Maximum number of visit records retained per key. */
    private static final int MAX_RECENT_VISITS = 5;

    private ListState<String> recentVisitsState;

    @Override
    public void open(Configuration parameters) {
        ListStateDescriptor<String> descriptor = new ListStateDescriptor<>(
            "recent-visits",  // state name
            String.class      // element type
        );
        // List state is obtained with getListState; getState only accepts a ValueStateDescriptor.
        recentVisitsState = getRuntimeContext().getListState(descriptor);
    }

    @Override
    public void processElement(String visitRecord, Context ctx, Collector<List<String>> out) throws Exception {
        // Append the new record to the stored list
        recentVisitsState.add(visitRecord);

        // Materialize all stored records for the current key
        List<String> visitList = new ArrayList<>();
        for (String visit : recentVisitsState.get()) {
            visitList.add(visit);
        }

        // Keep only the newest MAX_RECENT_VISITS records
        if (visitList.size() > MAX_RECENT_VISITS) {
            // Copy the tail so the emitted list is an independent list, not a subList view
            visitList = new ArrayList<>(
                visitList.subList(visitList.size() - MAX_RECENT_VISITS, visitList.size()));
            // Replace the stored list in one call instead of clear() followed by repeated add()
            recentVisitsState.update(visitList);
        }

        out.collect(visitList);
    }
}

3.3 MapState<UK, UV>

MapState 用于存储键值对映射,每个 key 对应一个 Map。

Java 示例代码:
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * MapState example: for each key (a user in this example), counts visits per page.
 *
 * <p>The map key is the page and the map value is that page's visit count for the current key.
 * Emits a message with the updated count for every element.
 */
public class PageVisitCountFunction
        extends KeyedProcessFunction<String, PageVisitCountFunction.PageVisit, String> {
    // Note: the nested PageVisit type is qualified in the extends clause because class members
    // are not in scope there.

    private MapState<String, Integer> pageVisitCountState;

    @Override
    public void open(Configuration parameters) {
        MapStateDescriptor<String, Integer> descriptor = new MapStateDescriptor<>(
            "page-visit-count",  // state name
            String.class,        // map key type
            Integer.class        // map value type
        );
        // Map state is obtained with getMapState; getState only accepts a ValueStateDescriptor.
        pageVisitCountState = getRuntimeContext().getMapState(descriptor);
    }

    @Override
    public void processElement(PageVisit pageVisit, Context ctx, Collector<String> out) throws Exception {
        String page = pageVisit.getPage();
        String user = pageVisit.getUser();

        // get() returns null for a page not yet seen under the current key
        Integer storedCount = pageVisitCountState.get(page);
        int currentCount = (storedCount == null ? 0 : storedCount) + 1;

        pageVisitCountState.put(page, currentCount);

        out.collect("User " + user + " visited page " + page + " " + currentCount + " times");
    }

    /**
     * Page-visit record.
     *
     * <p>Has a public no-arg constructor plus a getter and setter for each private field, so Flink
     * can serialize it as a POJO instead of falling back to generic (Kryo) serialization.
     */
    public static class PageVisit {
        private String user;
        private String page;

        public PageVisit() {}

        public PageVisit(String user, String page) {
            this.user = user;
            this.page = page;
        }

        public String getUser() { return user; }
        public void setUser(String user) { this.user = user; }
        public String getPage() { return page; }
        public void setPage(String page) { this.page = page; }
    }
}

3.4 ReducingState

ReducingState 用于存储聚合值,通过 ReduceFunction 对添加的元素进行聚合;添加的元素与聚合结果必须是同一类型。

Java 示例代码:
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * ReducingState example: keeps a running total of transaction amounts per key (a user in this
 * example) and emits the updated total for every transaction.
 */
public class TotalAmountFunction
        extends KeyedProcessFunction<String, TotalAmountFunction.Transaction, Double> {
    // Note: the nested Transaction type is qualified in the extends clause because class members
    // are not in scope there.

    private ReducingState<Double> totalAmountState;

    @Override
    public void open(Configuration parameters) {
        ReducingStateDescriptor<Double> descriptor = new ReducingStateDescriptor<>(
            "total-amount",           // state name
            new SumReduceFunction(),  // reduce function
            Double.class              // state type
        );
        // Reducing state is obtained with getReducingState; getState only accepts a ValueStateDescriptor.
        totalAmountState = getRuntimeContext().getReducingState(descriptor);
    }

    @Override
    public void processElement(Transaction transaction, Context ctx, Collector<Double> out) throws Exception {
        // add() folds the amount into the stored value using SumReduceFunction
        totalAmountState.add(transaction.getAmount());

        // After the add above, get() returns the reduced total for the current key
        Double totalAmount = totalAmountState.get();

        out.collect(totalAmount);
    }

    /** Reduce function that sums two amounts. */
    public static class SumReduceFunction implements ReduceFunction<Double> {
        @Override
        public Double reduce(Double value1, Double value2) throws Exception {
            return value1 + value2;
        }
    }

    /**
     * Transaction record.
     *
     * <p>Has a public no-arg constructor plus a getter and setter for each private field, so Flink
     * can serialize it as a POJO instead of falling back to generic (Kryo) serialization.
     */
    public static class Transaction {
        private String user;
        private Double amount;

        public Transaction() {}

        public Transaction(String user, Double amount) {
            this.user = user;
            this.amount = amount;
        }

        public String getUser() { return user; }
        public void setUser(String user) { this.user = user; }
        public Double getAmount() { return amount; }
        public void setAmount(Double amount) { this.amount = amount; }
    }
}

3.5 AggregatingState<IN, OUT>

AggregatingState 与 ReducingState 类似,但通过 AggregateFunction 进行聚合,因此输入类型(IN)、中间累加器类型(ACC)和输出类型(OUT)可以互不相同。

Java 示例代码:
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * AggregatingState example: keeps a running average of transaction amounts per key (a user in
 * this example) and emits the updated average for every transaction.
 */
public class AverageAmountFunction
        extends KeyedProcessFunction<String, AverageAmountFunction.Transaction, Double> {
    // Note: the nested Transaction type is qualified in the extends clause because class members
    // are not in scope there.

    /** Takes Transaction inputs and yields the average amount; the accumulator is AverageAccumulator. */
    private AggregatingState<Transaction, Double> averageAmountState;

    @Override
    public void open(Configuration parameters) {
        AggregatingStateDescriptor<Transaction, AverageAccumulator, Double> descriptor =
            new AggregatingStateDescriptor<>(
                "average-amount",                // state name
                new AverageAggregateFunction(),  // aggregate function
                AverageAccumulator.class         // accumulator type (a Class, not an instance)
            );
        // Aggregating state is obtained with getAggregatingState; getState only accepts a ValueStateDescriptor.
        averageAmountState = getRuntimeContext().getAggregatingState(descriptor);
    }

    @Override
    public void processElement(Transaction transaction, Context ctx, Collector<Double> out) throws Exception {
        // add() folds the transaction into the stored accumulator via AverageAggregateFunction.add
        averageAmountState.add(transaction);

        // get() applies AverageAggregateFunction.getResult to the current accumulator
        Double averageAmount = averageAmountState.get();

        out.collect(averageAmount);
    }

    /** Aggregates transactions into a (sum, count) accumulator and returns sum / count. */
    public static class AverageAggregateFunction implements AggregateFunction<Transaction, AverageAccumulator, Double> {
        @Override
        public AverageAccumulator createAccumulator() {
            return new AverageAccumulator();
        }

        @Override
        public AverageAccumulator add(Transaction transaction, AverageAccumulator accumulator) {
            accumulator.sum += transaction.getAmount();
            accumulator.count++;
            return accumulator;
        }

        @Override
        public Double getResult(AverageAccumulator accumulator) {
            // Avoid division by zero for an empty accumulator
            if (accumulator.count == 0) {
                return 0.0;
            }
            return accumulator.sum / accumulator.count;
        }

        @Override
        public AverageAccumulator merge(AverageAccumulator a, AverageAccumulator b) {
            // Reuses accumulator a as the merge result
            a.sum += b.sum;
            a.count += b.count;
            return a;
        }
    }

    /** Accumulator holding the running sum and element count (public fields, so it is a Flink POJO). */
    public static class AverageAccumulator {
        public double sum = 0.0;
        public long count = 0;
    }

    /**
     * Transaction record.
     *
     * <p>Has a public no-arg constructor plus a getter and setter for each private field, so Flink
     * can serialize it as a POJO instead of falling back to generic (Kryo) serialization.
     */
    public static class Transaction {
        private String user;
        private Double amount;

        public Transaction() {}

        public Transaction(String user, Double amount) {
            this.user = user;
            this.amount = amount;
        }

        public String getUser() { return user; }
        public void setUser(String user) { this.user = user; }
        public Double getAmount() { return amount; }
        public void setAmount(Double amount) { this.amount = amount; }
    }
}

4. 配置方法

4.1 状态描述符

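每种状态类型都有对应的描述符,并且必须通过 RuntimeContext 中对应的方法获取状态句柄:ValueState 用 getState,ListState 用 getListState,MapState 用 getMapState,ReducingState 用 getReducingState,AggregatingState 用 getAggregatingState。各描述符的创建方式如下: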

Java 示例代码:
// Note: SumReduceFunction, AverageAggregateFunction, AverageAccumulator and Transaction below are the
// nested classes from the ReducingState / AggregatingState examples; qualify or import them as needed.

// ValueStateDescriptor -> RuntimeContext#getState
// No default value: the (name, class, defaultValue) constructor is deprecated; treat a null value() as "absent".
ValueStateDescriptor<Integer> valueDescriptor = new ValueStateDescriptor<>(
    "value-state-name",
    Integer.class
);

// ListStateDescriptor -> RuntimeContext#getListState
ListStateDescriptor<String> listDescriptor = new ListStateDescriptor<>(
    "list-state-name",
    String.class
);

// MapStateDescriptor -> RuntimeContext#getMapState
MapStateDescriptor<String, Integer> mapDescriptor = new MapStateDescriptor<>(
    "map-state-name",
    String.class,   // map key type
    Integer.class   // map value type
);

// ReducingStateDescriptor -> RuntimeContext#getReducingState
ReducingStateDescriptor<Double> reducingDescriptor = new ReducingStateDescriptor<>(
    "reducing-state-name",
    new SumReduceFunction(),
    Double.class
);

// AggregatingStateDescriptor -> RuntimeContext#getAggregatingState
// The third argument is the accumulator type (a Class), not an accumulator instance.
AggregatingStateDescriptor<Transaction, AverageAccumulator, Double> aggregatingDescriptor =
    new AggregatingStateDescriptor<>(
        "aggregating-state-name",
        new AverageAggregateFunction(),
        AverageAccumulator.class
    );

4.2 状态 TTL(Time-To-Live)

可以为状态设置 TTL(生存时间),使过期数据不再可见并被自动清理。TTL 需要在获取状态句柄之前通过描述符的 enableTimeToLive 启用,且目前仅支持基于处理时间(processing time)的 TTL:

Java 示例代码:
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;

// 配置状态 TTL
StateTtlConfig ttlConfig = StateTtlConfig
    .newBuilder(Time.hours(1))  // 设置 TTL 为 1 小时
    .setUpdateType(StateTtlConfig.UpdateType.OnCreateAndWrite)  // 更新类型
    .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)  // 状态可见性
    .cleanupFullSnapshot()  // 清理策略
    .build();

// 应用 TTL 配置到状态描述符
ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>(
    "ttl-state",
    Integer.class,
    0
);
descriptor.enableTimeToLive(ttlConfig);

5. 完整使用示例

Java 示例代码:
import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;

/**
 * End-to-end Keyed State example: counts visits per user in real time.
 *
 * <p>Builds a small bounded stream of user ids, keys it by user id, and prints a running visit
 * count per user.
 */
public class KeyedStateExample {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        // Demo source: each element is a user id representing one visit
        DataStream<String> userVisits = env.fromElements(
            "user1", "user2", "user1", "user3", "user2", "user1"
        );

        // Key by user id so each user gets its own counter state
        DataStream<String> visitCounts = userVisits
            .keyBy(user -> user)
            .process(new VisitCountFunction());

        visitCounts.print();

        env.execute("Keyed State Example");
    }

    /**
     * Increments a per-key visit counter and emits a message with the updated count.
     */
    public static class VisitCountFunction extends KeyedProcessFunction<String, String, String> {
        /** Per-key counter; {@code value()} is {@code null} until the key's first element. */
        private ValueState<Integer> countState;

        @Override
        public void open(Configuration parameters) {
            // The (name, class, defaultValue) constructor is deprecated; the missing value is
            // handled explicitly in processElement instead.
            ValueStateDescriptor<Integer> descriptor = new ValueStateDescriptor<>(
                "visit-count",
                Integer.class
            );
            countState = getRuntimeContext().getState(descriptor);
        }

        @Override
        public void processElement(String user, Context ctx, Collector<String> out) throws Exception {
            // value() returns null when nothing has been stored for the current key yet
            Integer storedCount = countState.value();
            int count = (storedCount == null ? 0 : storedCount) + 1;

            countState.update(count);

            out.collect("User " + user + " has visited " + count + " times");
        }
    }
}

6. 最佳实践建议

  1. 合理选择状态类型

    • 单个值使用 ValueState
    • 列表数据使用 ListState
    • 键值对数据使用 MapState
    • 需要聚合的数据使用 ReducingState 或 AggregatingState
  2. 状态命名规范

    • 使用有意义的名称
    • 避免重复名称
    • 建议使用小写字母和连字符
  3. 状态清理

    • 及时清理不需要的状态
    • 使用 TTL 自动清理过期数据
    • 在适当的时候调用 clear() 方法
  4. 性能优化

    • 避免在状态中存储大量数据
    • 合理设置状态后端
    • 考虑使用 RocksDB 状态后端处理大状态
  5. 容错处理

    • 确保状态操作的幂等性
    • 处理状态恢复时的异常情况
    • 启用定期检查点(checkpoint),以便在故障恢复后状态保持一致

通过合理使用 Keyed State,可以有效地在 Flink 应用程序中维护和处理与特定键相关联的状态信息,实现复杂的状态管理和计算逻辑。

相关推荐
CoovallyAIHub3 小时前
Arm重磅加码边缘AI!Flexible Access开放v9平台,实现高端算力普惠
深度学习·算法·计算机视觉
louisdlee.4 小时前
树状数组维护DP——前缀最大值
数据结构·c++·算法·dp
Q741_1474 小时前
C++ 分治 归并排序 归并排序VS快速排序 力扣 912. 排序数组 题解 每日一题
c++·算法·leetcode·归并排序·分治
victory04315 小时前
K8S 安装 部署 文档
算法·贪心算法·kubernetes
月疯5 小时前
样本熵和泊松指数的计算流程!!!
算法·机器学习·概率论
机器学习之心5 小时前
MATLAB基于自适应动态特征加权的K-means算法
算法·matlab·kmeans
minji...5 小时前
算法题 逆波兰表达式/计算器
数据结构·c++·算法·1024程序员节
编码追梦人6 小时前
基于 STM32 的智能语音唤醒与关键词识别系统设计 —— 从硬件集成到算法实现
stm32·算法·struts
循着风8 小时前
二叉树的多种遍历方式
数据结构·算法