Spring Boot利用dag加速Spring beans初始化

1.什么是Dag?

有向无环图(Directed Acyclic Graph),简称DAG,是一种有向图,其中没有从节点出发经过若干条边后再回到该节点的路径。换句话说,DAG中不存在环路。这种数据结构常用于表示并解决具有依赖关系的问题。

DAG的特性

  • 首先,DAG中的节点可以有入度和出度。节点的入度是指指向该节点的边的数量,而节点的出度是指由该节点指向其他节点的边的数量。在DAG中,节点的入度可以是0或正整数,而出度可以是0或正整数,但不能同时为负数。
  • DAG的另一个重要性质是存在一个或多个拓扑排序。拓扑排序是DAG中节点的线性排列,满足任意一条有向边的起点在排序中都位于终点之前。可以使用深度优先搜索(DFS)或宽度优先搜索(BFS)算法来生成拓扑排序。

DAG的应用

  1. 任务调度
  2. 编译器优化
  3. 数据流分析
  4. 电路设计

2.如何加速Spring Bean初始化?

在Spring框架中进行DAG(有向无环图)分析以实现并行初始化,可以有效提升应用程序启动的性能。通常情况下,Spring应用程序的Bean是按依赖顺序初始化的,而通过DAG分析可以识别出哪些Bean之间没有依赖关系,并行初始化这些Bean可以减少启动时间。以下是实现思路:

1. 识别依赖关系,构建DAG

首先需要识别Spring Bean之间的依赖关系。可以通过@DependsOn注解、构造器注入或@Autowired等方式获取Bean依赖。具体步骤:

  • 遍历Spring上下文中的所有Bean定义。
  • 根据Bean的依赖关系构建DAG,节点代表Bean,边表示依赖关系。

Spring的ApplicationContext提供了getBeanDefinitionNames()方法可以列出所有的Bean,通过BeanDefinition可以分析出依赖。

2. 拓扑排序

通过拓扑排序(Topological Sorting)对DAG进行排序,以确保Bean按依赖顺序初始化。拓扑排序可以确定哪些Bean可以并行初始化,哪些Bean必须在某些Bean之后初始化。 使用算法如Kahn's Algorithm或DFS找到所有没有依赖的Bean(入度为0的节点),这些节点可以并行初始化。

3. 并行初始化Bean

在完成拓扑排序后,使用多线程来并行初始化可以同时启动的Bean。可以使用Java的ExecutorService或类似的线程池机制来管理并发的Bean初始化过程。 步骤:

  • 针对所有入度为0的节点,启动一个线程来初始化它们。
  • 当某个Bean初始化完成后,减少它所依赖的其他Bean的入度值。
  • 当某个Bean的入度为0时,可以在另一个线程中启动它的初始化。

4. Spring Integration

可以通过BeanFactoryPostProcessorApplicationContextInitializer来挂钩到Spring的初始化流程中,分析Bean之间的依赖关系,并将并行化初始化逻辑集成到Spring容器的启动过程中。 具体方法:

  • 实现BeanFactoryPostProcessor,在Bean初始化之前分析Bean的依赖并构建DAG。
  • 在Bean初始化阶段,使用多线程并行处理独立的Bean。

3.代码工程

整体的依赖关系如下:

实验目标

实现自定义Bean并行化加载

pom.xml

xml 复制代码
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <artifactId>dag</artifactId>

    <properties>
        <maven.compiler.source>17</maven.compiler.source>
        <maven.compiler.target>17</maven.compiler.target>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-autoconfigure</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.jgrapht</groupId>
            <artifactId>jgrapht-core</artifactId>
            <version>1.5.1</version>
        </dependency>

    </dependencies>
    <build>
        <pluginManagement>
            <plugins>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-compiler-plugin</artifactId>
                    <version>3.8.1</version>
                    <configuration>
                        <fork>true</fork>
                        <failOnError>false</failOnError>
                    </configuration>
                </plugin>

                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-surefire-plugin</artifactId>
                    <version>2.22.2</version>
                    <configuration>
                        <forkCount>0</forkCount>
                        <failIfNoTests>false</failIfNoTests>
                    </configuration>
                </plugin>
            </plugins>
        </pluginManagement>
    </build>
</project>

Bean初始化

scss 复制代码
package com.et.config;

import org.jgrapht.graph.DefaultEdge;
import org.jgrapht.graph.DirectedAcyclicGraph;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.stereotype.Component;

import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;

@Component
public class DAGBeanInitializer implements BeanFactoryPostProcessor {

    private final ExecutorService executorService = Executors.newFixedThreadPool(10);

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        Map<String, BeanDefinition> beanDefinitionMap = new HashMap<>();
        for (String beanName : beanFactory.getBeanDefinitionNames()) {
            BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
            beanDefinitionMap.put(beanName, beanDefinition);
        }

        // build DAG
        DirectedAcyclicGraph<String, DefaultEdge> dag = buildDAG(beanDefinitionMap,beanFactory);

        // bean layers
        List<Set<String>> layers = getBeansByLayer(dag);
        System.out.println("layers:"+layers);
        // init bean by layers
        initializeBeansInLayers(layers, beanFactory);
    }

    // DAG Bean
    private DirectedAcyclicGraph<String, DefaultEdge> buildDAG(Map<String, BeanDefinition> beanDefinitionMap, ConfigurableListableBeanFactory beanFactory) {
        DependencyResolver resolver = new DependencyResolver(beanFactory);
        DirectedAcyclicGraph<String, DefaultEdge> dag = new DirectedAcyclicGraph<>(DefaultEdge.class);
        for (String beanName : beanDefinitionMap.keySet()) {
            if(shouldLoadBean(beanName)) {
                dag.addVertex(beanName);
                String[] dependencies = beanDefinitionMap.get(beanName).getDependsOn();
                if (dependencies != null) {
                    for (String dependency : dependencies) {
                        dag.addEdge(dependency, beanName); 
                    }
                }
                // get @Autowired dependencies
                Set<String> autowireDependencies = resolver.getAllDependencies(beanName);
                for (String autowireDependency : autowireDependencies) {
                    // convert beanName
                    String autowireBeanName = convertToBeanName(autowireDependency);
                    dag.addVertex(autowireBeanName);
                    dag.addEdge(autowireBeanName, beanName);
                }
            }
        }
        return dag;
    }
    private String convertToBeanName(String className) {
        String simpleName = className.substring(className.lastIndexOf('.') + 1);
        return Character.toLowerCase(simpleName.charAt(0)) + simpleName.substring(1);
    }
    private List<Set<String>> getBeansByLayer(DirectedAcyclicGraph<String,DefaultEdge> dag) {
        List<Set<String>> layers = new ArrayList<>();
        Map<String, Integer> inDegree = new HashMap<>();
        Queue<String> queue = new LinkedList<>();

        // init all nodes degree
        for (String vertex : dag) {
            int degree = dag.inDegreeOf(vertex);
            inDegree.put(vertex, degree);
            if (degree == 0) {
                queue.offer(vertex);  //zero degree as the first layer
            }
        }

        // BFS process everyLayers
        while (!queue.isEmpty()) {
            Set<String> currentLayer = new HashSet<>();
            int size = queue.size();
            for (int i = 0; i < size; i++) {
                String currentBean = queue.poll();
                currentLayer.add(currentBean);

                // iterator layers
                for (String successor : getSuccessors(dag,currentBean)) {
                    inDegree.put(successor, inDegree.get(successor) - 1);
                    if (inDegree.get(successor) == 0) {
                        queue.offer(successor);  // add next layer when the degress is zero
                    }
                }
            }
            layers.add(currentLayer);
        }

        return layers;
    }
    // get next node
    private Set<String> getSuccessors(DirectedAcyclicGraph<String, DefaultEdge> dag, String vertex) {
        // get outgoingEdges
        Set<DefaultEdge> outgoingEdges = dag.outgoingEdgesOf(vertex);

        // find the next node
        return outgoingEdges.stream()
                .map(edge -> dag.getEdgeTarget(edge))
                .collect(Collectors.toSet());
    }
    // init beans by layer
    private void initializeBeansInLayers(List<Set<String>> layers, ConfigurableListableBeanFactory beanFactory) {
        for (Set<String> layer : layers) {
            // Beans of the same layer can be initialized in parallel
            List<CompletableFuture<Void>> futures = new ArrayList<>();
            for (String beanName : layer) {
                // only load beans that  wrote by yourself
                if (shouldLoadBean(beanName)) {
                    CompletableFuture<Void> future = CompletableFuture.runAsync(() -> {
                        try {
                            beanFactory.getBean(beanName);  // init Bean
                        } catch (Exception e) {
                            System.err.println("Failed to initialize bean: " + beanName);
                            e.printStackTrace();
                        }
                    }, executorService);
                    futures.add(future);
                }
            }
            //Wait for all beans in the current layer to be initialized before initializing the next layer.
            CompletableFuture<Void> allOf = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]));
            allOf.join();  // make sure to be done on current layer
        }
    }

    private boolean shouldLoadBean(String beanName) {
        return beanName.startsWith("helloWorldController")
                ||beanName.startsWith("serviceOne")
                ||beanName.startsWith("serviceTwo")
                ||beanName.startsWith("serviceThree");
    }
}

获取bean@Autowired依赖

java 复制代码
package com.et.config;

import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.util.HashSet;
import java.util.Set;

@Component
public class DependencyResolver {

    private final ConfigurableListableBeanFactory beanFactory;

    @Autowired
    public DependencyResolver(ConfigurableListableBeanFactory beanFactory) {
        this.beanFactory = beanFactory;
    }

    public Set<String> getAllDependencies(String beanName) {
        Set<String> dependencies = new HashSet<>();

        // get Bean definite
        BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);

        // reflect
        try {
            Class<?> beanClass = Class.forName(beanDefinition.getBeanClassName());
            Field[] fields = beanClass.getDeclaredFields();
            for (Field field : fields) {
                if (field.isAnnotationPresent(Autowired.class)) {
                    dependencies.add(field.getType().getName()); 
                }
            }
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }

        return dependencies;
    }
}

controller

typescript 复制代码
package com.et.controller;

import com.et.service.ServiceTwo;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import com.et.service.*;
import java.util.HashMap;
import java.util.Map;

@RestController
public class HelloWorldController {
    @Autowired
    ServiceOne ServiceOne;
    @Autowired
    ServiceTwo ServiceTwo;
    @RequestMapping("/hello")
    public Map<String, Object> showHelloWorld(){
        Map<String, Object> map = new HashMap<>();
        map.put("msg", "HelloWorld");
        return map;
    }
}

service

kotlin 复制代码
package com.et.service;

import org.springframework.stereotype.Service;

/**
 * @author liuhaihua
 * @version 1.0
 * @ClassName ServiceOne
 * @Description todo
 * @date 2024/09/20/ 14:01
 */
@Service
public class ServiceOne {
    private   void  sayhi(){
        System.out.println("this is service one sayhi");
    }
}

package com.et.service;

import org.springframework.stereotype.Service;

/**
 * @author liuhaihua
 * @version 1.0
 * @ClassName ServiceOne
 * @Description todo
 * @date 2024/09/20/ 14:01
 */
@Service
public class ServiceThree {

    private   void  sayhi(){
        System.out.println("this is service three sayhi");
    }
}

package com.et.service;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * @author liuhaihua
 * @version 1.0
 * @ClassName ServiceOne
 * @Description todo
 * @date 2024/09/20/ 14:01
 */
@Service
public class ServiceTwo {
    @Autowired
    ServiceThree serviceThree;
    private   void  sayhi(){
        System.out.println("this is service two sayhi");
    }
}

只是一些关键代码,所有代码请参见下面代码仓库

代码仓库

4.测试

启动Spring Boot工程,查看bean加载顺序如下

lua 复制代码
2024-09-20T15:51:27.081+08:00 INFO 33188 --- [ main] com.et.DemoApplication : Starting DemoApplication using Java 17.0.9 with PID 33188 (D:\IdeaProjects\ETFramework\dag\target\classes started by Dell in D:\IdeaProjects\ETFramework)
2024-09-20T15:51:27.085+08:00 INFO 33188 --- [ main] com.et.DemoApplication : No active profile set, falling back to 1 default profile: "default"
layers:[[serviceOne, serviceThree], [serviceTwo], [helloWorldController]]
2024-09-20T15:51:28.286+08:00 INFO 33188 --- [ main] o.s.b.w.embedded.tomcat.TomcatWebServer : Tomcat initialized with port 8088 (http)
2024-09-20T15:51:28.297+08:00 INFO 33188 --- [ main] o.apache.catalina.core.StandardService : Starting service [Tomcat]
2024-09-20T15:51:28.297+08:00 INFO 33188 --- [ main] o.apache.catalina.core.StandardEngine : Starting Servlet engine: [Apache Tomcat/10.1.17]
2024-09-20T15:51:28.373+08:00 INFO 33188 --- [ main] o.a.c.c.C.[Tomcat].[localhost].[/] : Initializing Spring embedded WebApplicationContext
2024-09-20T15:51:28.374+08:00 INFO 33188 --- [ main] w.s.c.ServletWebServerApplicationContext : Root WebApplicationContext: initialization completed in 1198 ms
2024-09-20T15:51:28.725+08:00 INFO 33188 --- [ main] o.s.b.w.embedded.tomcat.TomcatWebServer : Tomcat started on port 8088 (http) with context path ''
2024-09-20T15:51:28.732+08:00 INFO 33188 --- [ main] com.et.DemoApplication

5.引用

相关推荐
monkey_meng9 分钟前
【Rust中的迭代器】
开发语言·后端·rust
余衫马12 分钟前
Rust-Trait 特征编程
开发语言·后端·rust
monkey_meng16 分钟前
【Rust中多线程同步机制】
开发语言·redis·后端·rust
七星静香17 分钟前
laravel chunkById 分块查询 使用时的问题
java·前端·laravel
Jacob程序员18 分钟前
java导出word文件(手绘)
java·开发语言·word
ZHOUPUYU18 分钟前
IntelliJ IDEA超详细下载安装教程(附安装包)
java·ide·intellij-idea
stewie622 分钟前
在IDEA中使用Git
java·git
Elaine20239137 分钟前
06 网络编程基础
java·网络
G丶AEOM38 分钟前
分布式——BASE理论
java·分布式·八股
落落鱼201339 分钟前
tp接口 入口文件 500 错误原因
java·开发语言