275. Java Stream API - flatMap 操作:展开一对多的关系,拉平你的流!
🧠 背景:我们为什么需要 flatMap?
假设我们有以下结构:
- 每个
Country拥有多个City - 每个
City有一个人口数population
我们的目标是:统计所有城市的总人口数。
最直接的写法当然是嵌套 for 循环:
java
int totalPopulation = 0;
for (Country country : countries) {
for (City city : country.cities()) {
totalPopulation += city.population();
}
}
System.out.println("Total population = " + totalPopulation);
📌 输出:
java
Total population = 24493
虽然有效,但 Java 8 之后我们有了更优雅的方式:使用流 + flatMap 来处理一对多的关系。
🔁 用 flatMap 优雅替代嵌套循环
✅ 定义模型结构
java
record City(String name, int population) {}
record Country(String name, List<City> cities) {}
✅ 初始化数据
java
City newYork = new City("New York", 8_258);
City losAngeles = new City("Los Angeles", 3_821);
Country usa = new Country("USA", List.of(newYork, losAngeles));
City london = new City("London", 8_866);
City manchester = new City("Manchester", 568);
Country uk = new Country("United Kingdom", List.of(london, manchester));
City paris = new City("Paris", 2_103);
City marseille = new City("Marseille", 877);
Country france = new Country("France", List.of(paris, marseille));
List<Country> countries = List.of(usa, uk, france);
🚀 使用 flatMap 重写统计逻辑
java
int totalPopulation = countries.stream()
.flatMap(country -> country.cities().stream()) // 展开所有城市
.mapToInt(City::population) // 提取人口
.sum(); // 累加总人口
System.out.println("Total population = " + totalPopulation);
📌 输出:
java
Total population = 24493
🔍 flatMap 是如何工作的?
flatMap 是两个操作的组合:
步骤 1️⃣:映射(map)
java
country -> country.cities().stream()
这一步将每个 Country 映射为它的城市流,得到的是一个 Stream<Stream<City>>(流的流)。
步骤 2️⃣:展平(flat)
flatMap 会自动帮你把多个子流合并为一个连续的扁平流 (Stream<City>),这样你就可以对所有城市统一处理!
🎯 类比图示:
java
Stream<Country> ---映射---> Stream<Stream<City>>
|
+---> 展平(flatten)---> Stream<City>
📚 延伸案例:Map 结构的 flatMap
假设我们有一个 Continent 类型,它包含一个 Map:
java
record Continent(Map<String, Country> countries) {}
此时,如果你想从 Continent 中提取所有国家,可以这样写:
java
Function<Continent, Stream<Country>> continentToCountry =
continent -> continent.countries().values().stream();
再进一步,还可以这样嵌套 flatMap:
java
int total = continents.stream()
.flatMap(continent -> continent.countries().values().stream())
.flatMap(country -> country.cities().stream())
.mapToInt(City::population)
.sum();
🧠 小结:flatMap 用法口诀
| 用法场景 | 对应方法 |
|---|---|
| 一对一映射(每个元素 → 单个新值) | .map() |
| 一对多映射(每个元素 → 多个新值) | .flatMap() |
| 提取嵌套集合中的内容并扁平化 | .flatMap() |
转换成基础类型流(int/long/double) |
.mapToInt() 等 |
🧪 练习建议(课堂可选)
❓ 问题:下面代码的输出是什么?
java
List<String> words = List.of("java", "stream", "api");
List<Character> chars = words.stream()
.flatMap(word -> word.chars().mapToObj(c -> (char) c))
.toList();
System.out.println(chars);
🎯 答案:
java
[j, a, v, a, s, t, r, e, a, m, a, p, i]