// TreeUtil.java - tree-structure utilities.
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.mutable.MutableLong;
import org.jetbrains.annotations.NotNull;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
/**
 * Tree-structure utility class.
 * <p>
 * Provides helpers that:
 * <ul>
 *     <li>build trees from flat lists, either by parent id or by hierarchical code;</li>
 *     <li>walk up or down a flat list by id/parentId;</li>
 *     <li>search, flatten, number and filter trees.</li>
 * </ul>
 * Node objects are never copied; every method works on the caller's instances.
 *
 * @author Mr丶s
 * @date 2024/7/16 4:58 PM
 */
@SuppressWarnings("unused")
public class TreeUtil {

    /** Stand-in for a {@code null} key, because {@link ConcurrentHashMap} rejects null keys. */
    private static final Object NULL_KEY = new Object();

    /**
     * Builds a tree from parent-id links. A node whose parentId is not the id of any node in
     * {@code list} is treated as a root.
     * <p>
     * {@code setSub} is called only for nodes that have at least one child; leaves are left untouched.
     * Children are grouped by parent id once up front, so each node is visited a single time.
     * Ids are assumed to be unique.
     *
     * @param list        all nodes taking part in the build
     * @param getId       extracts a node's id
     * @param getParentId extracts a node's parent id
     * @param comparator  ordering among siblings (optional; {@code null} keeps the order of {@code list})
     * @param setSub      stores a node's list of children
     * @param <T>         tree node type
     * @param <I>         tree node id type
     * @return the root nodes, with their subtrees attached
     */
    public static <T, I> List<T> buildByParentId(@NonNull List<T> list,
                                                 @NonNull Function<T, I> getId,
                                                 @NonNull Function<T, I> getParentId,
                                                 @Nullable Comparator<T> comparator,
                                                 @NonNull BiConsumer<T, List<T>> setSub) {
        List<T> tree = rootListByParentId(list, getId, getParentId);
        sortList(tree, comparator);
        Map<I, List<T>> parentId2Children = groupByParentId(list, getParentId);
        tree.forEach(root -> attachChildren(root, parentId2Children, getId, comparator, setSub));
        return tree;
    }

    /**
     * Builds a tree from hierarchical codes.
     * <p>
     * A node's parent is the node whose code is the longest other code that is a proper prefix of
     * its own. A node whose code starts with no other node's code is a root.
     * <p>
     * When {@code getSub} returns {@code null} for a parent, a new {@link ArrayList} is created and
     * stored through {@code setSub}. Otherwise children are appended to the list {@code getSub}
     * returns, so that list must be mutable.
     * <p>
     * Without a comparator:
     * <ul>
     *     <li>roots come out in ascending code order;</li>
     *     <li>children are appended to their parent in descending code order.</li>
     * </ul>
     *
     * @param list       all nodes taking part in the build
     * @param getCode    extracts a node's code
     * @param comparator ordering among siblings, applied to the whole tree (optional)
     * @param getSub     reads a node's list of children (may return {@code null})
     * @param setSub     stores a node's list of children
     * @param <T>        tree node type
     * @param <C>        tree node code type
     * @return the root nodes, with their subtrees attached
     * @throws IllegalStateException if a non-root node has no parent, e.g. because of duplicate codes
     */
    public static <T, C extends String> List<T> buildByCode(@NonNull List<T> list,
                                                            @NonNull Function<T, C> getCode,
                                                            @Nullable Comparator<T> comparator,
                                                            @NonNull Function<T, List<T>> getSub,
                                                            @NonNull BiConsumer<T, List<T>> setSub) {
        Map<C, List<T>> code2List = groupByCode(list, getCode);
        List<T> tree = new ArrayList<>(code2List.size());
        for (List<T> group : code2List.values()) {
            tree.add(buildNodeByCode(group, getCode, getSub, setSub));
        }
        sortTree(tree, comparator, getSub);
        return tree;
    }

    /**
     * Collects the ancestors of the nodes whose id is in {@code ids}.
     * <p>
     * The result is built as follows:
     * <ul>
     *     <li>For each id, in order: the node itself (only when {@code containSelf}), then its
     *     ancestors, nearest first.</li>
     *     <li>The list is then deduplicated by id, keeping the first occurrence.</li>
     * </ul>
     * Walking up stops at a parent that is missing from {@code list}, or when a node repeats
     * (cyclic data).
     *
     * @param list              all nodes
     * @param ids               ids of the nodes whose ancestors are wanted
     * @param idExtractor       extracts a node's id
     * @param parentIdExtractor extracts a node's parent id
     * @param containSelf       whether the nodes matching {@code ids} are included themselves
     * @param <T>               node type
     * @param <R>               id type
     * @return the ancestor list (empty when {@code list} or {@code ids} is empty)
     * @throws IllegalStateException if two nodes in {@code list} share an id (from {@link Collectors#toMap})
     */
    public static <T, R> List<T> getParent(List<T> list, List<R> ids,
                                           Function<? super T, ? extends R> idExtractor,
                                           Function<? super T, ? extends R> parentIdExtractor,
                                           boolean containSelf) {
        if (CollectionUtils.isEmpty(list) || CollectionUtils.isEmpty(ids)) {
            return new ArrayList<>();
        }
        Map<R, T> id2Node = list.stream()
                .collect(Collectors.toMap(idExtractor, Function.identity()));
        List<T> result = new ArrayList<>();
        for (R id : ids) {
            T start = id2Node.get(id);
            if (start == null) {
                continue;
            }
            if (containSelf) {
                result.add(start);
            }
            // Identity set: guards against cyclic parent links, which would otherwise loop forever.
            Set<T> visited = Collections.newSetFromMap(new IdentityHashMap<>());
            visited.add(start);
            T parent = id2Node.get(parentIdExtractor.apply(start));
            while (parent != null && visited.add(parent)) {
                result.add(parent);
                parent = id2Node.get(parentIdExtractor.apply(parent));
            }
        }
        return result.stream()
                .filter(distinctByKey(idExtractor))
                .collect(Collectors.toList());
    }

    /**
     * Collects all descendants of the nodes whose id is in {@code ids}.
     * <p>
     * The result is built as follows:
     * <ul>
     *     <li>When {@code containSelf} is set, the matching nodes come first, in {@code list} order.</li>
     *     <li>Descendants follow in breadth-first order.</li>
     *     <li>The list is then deduplicated by id, keeping the first occurrence.</li>
     * </ul>
     * Each parent id is expanded at most once, so cyclic data terminates.
     *
     * @param list              all nodes
     * @param ids               ids of the nodes whose descendants are wanted
     * @param idExtractor       extracts a node's id
     * @param parentIdExtractor extracts a node's parent id (nodes with a {@code null} parent id are never children)
     * @param containSelf       whether the nodes matching {@code ids} are included themselves
     * @param <T>               node type
     * @param <R>               id type
     * @return the descendant list (empty when {@code list} or {@code ids} is empty)
     */
    public static <T, R> List<T> getChildren(List<T> list, List<R> ids,
                                             Function<? super T, ? extends R> idExtractor,
                                             Function<? super T, ? extends R> parentIdExtractor,
                                             boolean containSelf) {
        if (CollectionUtils.isEmpty(list) || CollectionUtils.isEmpty(ids)) {
            return new ArrayList<>();
        }
        List<T> result = new ArrayList<>();
        if (containSelf) {
            for (T item : list) {
                if (ids.contains(idExtractor.apply(item))) {
                    result.add(item);
                }
            }
        }
        Map<R, List<T>> parentId2Children = list.stream()
                .filter(c -> Objects.nonNull(parentIdExtractor.apply(c)))
                .collect(Collectors.groupingBy(parentIdExtractor));
        // Index-based queue (rather than ArrayDeque) because ids may legitimately contain null.
        List<R> queue = new ArrayList<>(ids);
        Set<R> expanded = new HashSet<>();
        for (int i = 0; i < queue.size(); i++) {
            R parentId = queue.get(i);
            if (!expanded.add(parentId)) {
                continue;
            }
            List<T> children = parentId2Children.get(parentId);
            if (children == null) {
                continue;
            }
            result.addAll(children);
            for (T child : children) {
                queue.add(idExtractor.apply(child));
            }
        }
        return result.stream()
                .filter(distinctByKey(idExtractor))
                .collect(Collectors.toList());
    }

    /**
     * Searches the whole tree, top-down and depth-first, for every node with the given key.
     *
     * @param tree   tree to search
     * @param getKey extracts the compared key from a node
     * @param getSub reads a node's children (may return {@code null})
     * @param key    key value to match (compared with {@code equals}; nodes with a {@code null} key never match)
     * @param <T>    tree node type
     * @param <I>    key type
     * @return all matching nodes in pre-order
     */
    public static <T, I> List<T> searchTree4All(@NonNull List<T> tree, @NonNull Function<T, I> getKey,
                                                @NonNull Function<T, List<T>> getSub, @NonNull I key) {
        List<T> matched = new ArrayList<>();
        for (T node : tree) {
            I currentKey = getKey.apply(node);
            if (currentKey != null && currentKey.equals(key)) {
                matched.add(node);
            }
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                matched.addAll(searchTree4All(sub, getKey, getSub, key));
            }
        }
        return matched;
    }

    /**
     * Searches the tree, top-down and depth-first, for the first node with the given key.
     *
     * @param tree   tree to search
     * @param getKey extracts the compared key from a node
     * @param getSub reads a node's children (may return {@code null})
     * @param key    key value to match (compared with {@code equals}; nodes with a {@code null} key never match)
     * @param <T>    tree node type
     * @param <I>    key type
     * @return the first matching node in pre-order, or {@link Optional#empty()}
     */
    public static <T, I> Optional<T> searchTree4One(@NonNull List<T> tree, @NonNull Function<T, I> getKey,
                                                    @NonNull Function<T, List<T>> getSub, @NonNull I key) {
        for (T node : tree) {
            I currentKey = getKey.apply(node);
            if (currentKey != null && currentKey.equals(key)) {
                return Optional.of(node);
            }
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                Optional<T> subSearchResult = searchTree4One(sub, getKey, getSub, key);
                if (subSearchResult.isPresent()) {
                    return subSearchResult;
                }
            }
        }
        return Optional.empty();
    }

    /**
     * Flattens a tree into a list in pre-order.
     * <p>
     * Nodes are not copied, so each node's children are still reachable through its child property.
     *
     * @param tree   tree to flatten
     * @param getSub reads a node's children (may return {@code null})
     * @param <T>    tree node type
     * @return all nodes, parents before their children
     */
    public static <T> List<T> tree2List(@NonNull List<T> tree, @NonNull Function<T, List<T>> getSub) {
        List<T> list = new ArrayList<>();
        for (T node : tree) {
            list.add(node);
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                list.addAll(tree2List(sub, getSub));
            }
        }
        return list;
    }

    /**
     * Assigns ids to every node of the tree and sets each node's parent id.
     * <p>
     * Despite the name, the ids are sequential, not random. They are taken from {@code idCounter}
     * in depth-first pre-order.
     *
     * @param tree        tree whose nodes receive ids
     * @param getSub      reads a node's children (may return {@code null})
     * @param setId       stores a node's id
     * @param setParentId stores a node's parent id
     * @param parentId    parent id given to the top-level nodes ({@code null} means 0)
     * @param idCounter   shared counter supplying the next id ({@code null} means start at 1)
     * @param <T>         tree node type
     */
    public static <T> void addRandomId(@NonNull List<T> tree, @NonNull Function<T, List<T>> getSub,
                                       @NonNull BiConsumer<T, Long> setId, @NonNull BiConsumer<T, Long> setParentId,
                                       @Nullable Long parentId, @Nullable MutableLong idCounter) {
        parentId = parentId == null ? 0L : parentId;
        idCounter = idCounter == null ? new MutableLong(1L) : idCounter;
        for (T node : tree) {
            long id = idCounter.longValue();
            idCounter.increment();
            setId.accept(node, id);
            setParentId.accept(node, parentId);
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                addRandomId(sub, getSub, setId, setParentId, id, idCounter);
            }
        }
    }

    /**
     * Filters the tree in place by name.
     * <p>
     * A node is kept when its name contains {@code searchName}, or when any of its descendants is kept.
     * Does nothing when {@code searchName} is empty.
     *
     * @param tree         tree to filter (mutated through its iterators)
     * @param getSub       reads a node's children (may return {@code null})
     * @param getName      extracts a node's name
     * @param searchName   text the name must contain
     * @param reserveChild if a node's name matches, keep its whole subtree without filtering it
     * @param <T>          tree node type
     */
    public static <T> void filterTreeByName(@NonNull List<T> tree, @NonNull Function<T, List<T>> getSub,
                                            @NonNull Function<T, String> getName, @NonNull String searchName,
                                            @NonNull Boolean reserveChild) {
        if (!StringUtils.hasLength(searchName)) {
            return;
        }
        filterTree(tree, getSub, node -> {
            String name = getName.apply(node);
            return StringUtils.hasLength(name) && name.contains(searchName);
        }, reserveChild);
    }

    /**
     * Filters the tree in place by id.
     * <p>
     * A node is kept when its id equals {@code searchId}, or when any of its descendants is kept.
     *
     * @param tree         tree to filter (mutated through its iterators)
     * @param getSub       reads a node's children (may return {@code null})
     * @param getId        extracts a node's id
     * @param searchId     id to keep
     * @param reserveChild if a node's id matches, keep its whole subtree without filtering it
     * @param <T>          tree node type
     */
    public static <T> void filterTreeById(@NonNull List<T> tree, @NonNull Function<T, List<T>> getSub,
                                          @NonNull Function<T, Long> getId, @NonNull Long searchId,
                                          @NonNull Boolean reserveChild) {
        // The original removal test checked `getId != null` (always true), which also dropped matching
        // leaves. Sharing the name filter's logic tests the node's own id instead.
        filterTree(tree, getSub, node -> {
            Long id = getId.apply(node);
            return id != null && id.equals(searchId);
        }, reserveChild);
    }

    /**
     * Shared in-place filter.
     * <p>
     * For each node:
     * <ul>
     *     <li>A match with {@code reserveChild} set keeps the node and its subtree as is.</li>
     *     <li>Otherwise the children are filtered first.</li>
     *     <li>The node is then removed if it ends up childless and does not match.</li>
     * </ul>
     */
    private static <T> void filterTree(List<T> tree, Function<T, List<T>> getSub, Predicate<T> matches,
                                       boolean reserveChild) {
        for (Iterator<T> iterator = tree.iterator(); iterator.hasNext(); ) {
            T node = iterator.next();
            boolean matched = matches.test(node);
            if (reserveChild && matched) {
                continue;
            }
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                // Recurse first so the emptiness check below sees the filtered children.
                filterTree(sub, getSub, matches, reserveChild);
            }
            if ((sub == null || sub.isEmpty()) && !matched) {
                iterator.remove();
            }
        }
    }

    /** Sorts {@code list} in place; a {@code null} comparator leaves it unchanged. */
    private static <T> void sortList(List<T> list, Comparator<T> comparator) {
        if (comparator == null) {
            return;
        }
        list.sort(comparator);
    }

    /** Sorts every level of the tree in place, each sibling list exactly once. */
    private static <T> void sortTree(List<T> tree, Comparator<T> comparator, Function<T, List<T>> getSub) {
        if (comparator == null) {
            return;
        }
        tree.sort(comparator);
        for (T node : tree) {
            List<T> sub = getSub.apply(node);
            if (sub != null && !sub.isEmpty()) {
                sortTree(sub, comparator, getSub);
            }
        }
    }

    /** Returns, in list order, the nodes whose parent id is not the id of any node in {@code list}. */
    private static <T, I> List<T> rootListByParentId(List<T> list, Function<T, I> getId, Function<T, I> getParentId) {
        Set<I> idSet = list.stream().map(getId).collect(Collectors.toSet());
        return list.stream()
                .filter(node -> !idSet.contains(getParentId.apply(node)))
                .collect(Collectors.toList());
    }

    /**
     * Groups nodes by parent id, preserving list order inside each group.
     * <p>
     * Uses a {@link HashMap} rather than {@code groupingBy}, because {@code groupingBy} rejects null keys.
     */
    private static <T, I> Map<I, List<T>> groupByParentId(List<T> list, Function<T, I> getParentId) {
        Map<I, List<T>> parentId2Children = new HashMap<>();
        for (T item : list) {
            parentId2Children.computeIfAbsent(getParentId.apply(item), k -> new ArrayList<>()).add(item);
        }
        return parentId2Children;
    }

    /** Recursively attaches a sorted copy of {@code node}'s children, if it has any. */
    private static <T, I> void attachChildren(T node, Map<I, List<T>> parentId2Children, Function<T, I> getId,
                                              Comparator<T> comparator, BiConsumer<T, List<T>> setSub) {
        List<T> children = parentId2Children.get(getId.apply(node));
        if (children == null || children.isEmpty()) {
            return;
        }
        List<T> sub = new ArrayList<>(children);
        sortList(sub, comparator);
        setSub.accept(node, sub);
        sub.forEach(child -> attachChildren(child, parentId2Children, getId, comparator, setSub));
    }

    /**
     * Splits the nodes into groups, one group per root: the root plus all nodes whose code starts
     * with the root's code.
     * <p>
     * Groups are keyed by the root code, in ascending code order.
     */
    private static <T, C extends String> Map<C, List<T>> groupByCode(List<T> list, Function<T, C> getCode) {
        // After sorting by code, every descendant directly follows its root, so one pass suffices.
        List<T> sortedCodeList = list.stream().sorted(Comparator.comparing(getCode)).collect(Collectors.toList());
        // LinkedHashMap keeps roots in ascending code order, so the output is deterministic.
        Map<C, List<T>> code2List = new LinkedHashMap<>();
        C flagCode = null;
        for (T item : sortedCodeList) {
            C currentCode = getCode.apply(item);
            if (flagCode == null || !currentCode.startsWith(flagCode)) {
                flagCode = currentCode;
            }
            code2List.computeIfAbsent(flagCode, k -> new ArrayList<>()).add(item);
        }
        return code2List;
    }

    /**
     * Links the nodes of one group (as produced by {@code groupByCode}, so sorted ascending) into a subtree.
     * <p>
     * Reverses {@code subList} in place and returns the group's root.
     */
    private static <T, C extends String> T buildNodeByCode(List<T> subList, Function<T, C> getCode, Function<T,
            List<T>> getSub, BiConsumer<T, List<T>> setSub) {
        if (subList.isEmpty()) {
            throw new IllegalStateException("树构建异常(子节点不存在)");
        }
        // Child-to-parent lookup is unambiguous. After reversing (descending codes), each node's
        // candidate parents all come after it, and the longest prefix is met first.
        Collections.reverse(subList);
        for (int i = 0; i < subList.size() - 1; i++) {
            T parent = findParentByCode(subList.get(i), subList.subList(i + 1, subList.size()), getCode);
            List<T> sub = getSub.apply(parent);
            if (sub == null) {
                sub = new ArrayList<>();
                setSub.accept(parent, sub);
            }
            sub.add(subList.get(i));
        }
        return subList.get(subList.size() - 1);
    }

    /**
     * Returns the first node in {@code subList} whose code is a proper prefix of {@code currentNode}'s code.
     *
     * @throws IllegalStateException if no such node exists
     */
    private static <T, C extends String> T findParentByCode(T currentNode, List<T> subList, Function<T, C> getCode) {
        C currentCode = getCode.apply(currentNode);
        for (T node : subList) {
            C searchCode = getCode.apply(node);
            // Require a strictly shorter code: otherwise bad data (e.g. duplicates) could make a node
            // its own parent.
            if (currentCode.startsWith(searchCode) && searchCode.length() != currentCode.length()) {
                return node;
            }
        }
        throw new IllegalStateException("构建异常(父节点查找失败)");
    }

    /**
     * Returns a stateful predicate that accepts only the first element seen for each key.
     * <p>
     * Backed by a {@link ConcurrentHashMap}, so it may be shared by a parallel stream. A {@code null}
     * key is treated as one distinct key instead of throwing.
     *
     * @param keyExtractor extracts the deduplication key
     * @param <T>          element type
     * @return a predicate that is {@code true} the first time each key is seen
     */
    public static <T> Predicate<T> distinctByKey(Function<? super T, ?> keyExtractor) {
        Objects.requireNonNull(keyExtractor);
        Map<Object, Boolean> seen = new ConcurrentHashMap<>();
        return t -> {
            Object key = keyExtractor.apply(t);
            return seen.putIfAbsent(key == null ? NULL_KEY : key, Boolean.TRUE) == null;
        };
    }
}