Asynchronous Spring Bean Initialization


The problem: many beans are slow to initialize

Consider the following simple program:

```java
package org.example;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;

public class Main {
    public static void main(String[] args) {
        AnnotationConfigApplicationContext applicationContext
                = new AnnotationConfigApplicationContext();
        applicationContext.register(Config.class);
        long startTime = System.currentTimeMillis();
        // refresh() instantiates all non-lazy singletons and runs their init methods
        applicationContext.refresh();
        long cost = System.currentTimeMillis() - startTime;
        // report milliseconds: integer division by 1000 would hide the fractional seconds
        System.out.println(String.format("applicationContext refresh cost:%d ms", cost));
        A a = (A) applicationContext.getBean("a");
        a.sayHello();
        // close() (rather than stop()) shuts the context down and destroys its singletons
        applicationContext.close();
    }
}
```
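The program registers a `Config` class that the post does not show. A minimal sketch, assuming `Config` only needs to component-scan `org.example` so that `A`, `B` and the post-processors below are picked up (the scanned base package is an assumption):

```java
package org.example;

import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;

// Hypothetical: Config is not shown in the original post.
// Assumed to do nothing except scan org.example for @Service/@Component classes.
@Configuration
@ComponentScan("org.example")
public class Config {
}
```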
  • Beans A and B are both defined along the following lines (A is shown); their init methods may take a long time:
```java
package org.example;

import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import java.util.concurrent.TimeUnit;

@Service
public class A {
    @PostConstruct
    public void init() {
        try {
            // simulate a slow initialization (2 s)
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            // restore the interrupt flag instead of silently swallowing it
            Thread.currentThread().interrupt();
        }
        System.out.println("A.init success");
    }

    public void sayHello() {
        System.out.println("A.sayHello");
    }
}
```

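With these definitions, applicationContext.refresh() takes a little over 5 seconds, because the init methods run one after another on the main thread and their durations add up.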

Measuring init-method cost: a custom BeanPostProcessor

postProcessBeforeInitialization records when each bean's initialization starts, and postProcessAfterInitialization records when it finishes; the difference is the time spent in that bean's init methods.

```java
package org.example;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;

import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import static java.util.Collections.reverseOrder;

@Component
public class BeanInitMethodCostTimeBeanPostProcessor implements BeanPostProcessor, ApplicationListener<ApplicationEvent> {

    private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);

    // beanName -> time (ms) at which its initialization started
    private final Map<String, Long> startTime = new HashMap<>(1024);
    // finished measurements, printed once the context has been refreshed
    private final List<Initialization> costTime = new ArrayList<>(1024);

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        startTime.put(beanName, System.currentTimeMillis());
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        // remove() pairs each end time with exactly one start time, avoiding a linear scan of costTime
        Long start = startTime.remove(beanName);
        if (start != null) {
            costTime.add(Initialization.parseInitialization(beanName, start, System.currentTimeMillis()));
        }
        return bean;
    }

    @Override
    public void onApplicationEvent(ApplicationEvent event) {
        if (event instanceof ContextRefreshedEvent) {
            // print the slowest beans first
            costTime.sort(reverseOrder());
            for (Initialization initialization : costTime) {
                System.out.println(initialization);
            }
            startTime.clear();
            costTime.clear();
        }
    }

    private static class Initialization implements Comparable<Initialization> {
        private int serialNumber;
        private String beanName;
        private long costTime;
        private long start;
        private long end;

        public static Initialization parseInitialization(String beanName, long start, long end) {
            Initialization initialization = new Initialization();
            initialization.serialNumber = ATOMIC_INTEGER.incrementAndGet();
            initialization.costTime = end - start;
            initialization.start = start;
            initialization.end = end;
            initialization.beanName = beanName;
            return initialization;
        }

        @Override
        public String toString() {
            return "serialNumber:  " + serialNumber + ",beanName:   " + beanName + ",cost " + costTime + "  ms,"
                    + " start: " + convertTimeToString(start) + ", end:" + convertTimeToString(end);
        }

        public static String convertTimeToString(long time) {
            DateTimeFormatter ftf = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
            return ftf.format(LocalDateTime.ofInstant(Instant.ofEpochMilli(time), ZoneId.systemDefault()));
        }

        @Override
        public int compareTo(Initialization o) {
            return Long.compare(costTime, o.costTime);
        }
    }
}
```

The output shows that beans a and b take a long time to initialize, and that they account for most of the time spent in applicationContext.refresh().

A custom BeanFactory

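Extend DefaultListableBeanFactory and override invokeInitMethods, the step in which Spring calls InitializingBean.afterPropertiesSet() and any custom init-method. @PostConstruct methods are not called from here (CommonAnnotationBeanPostProcessor invokes them in postProcessBeforeInitialization), so this approach only covers beans that initialize through afterPropertiesSet() or an init-method.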

```java
package org.example;

import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

import static java.lang.Boolean.TRUE;

public class CustomBeanFactory extends DefaultListableBeanFactory {
    // one Future per submitted init task: (beanName, failure or null)
    private final List<Future<Pair<String, Throwable>>> taskList = Collections.synchronizedList(new ArrayList<>());

    private final ExecutorService asyncInitPool;

    // set once all async tasks have been awaited; beans created later are initialized synchronously
    private boolean contextFinished = false;
    /**
     * Method of the InitializingBean interface, used to run custom initialization logic
     * after dependency injection has completed.
     */
    private final String afterPropertiesSetMethodName = "afterPropertiesSet";

    public CustomBeanFactory() {
        super();
        int poolSize = 4;
        asyncInitPool = new ThreadPoolExecutor(poolSize, poolSize, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    /**
     * Waits until every async init task has finished; call this right after refresh().
     *
     * @return true once all tasks have completed successfully
     * @throws BeanCreationException if any async init method failed
     */
    public boolean confirmAllAsyncTaskHadSuccessfulInvoked() {
        try {
            for (Future<Pair<String, Throwable>> task : taskList) {
                Pair<String, Throwable> result = task.get();
                if (result.getRight() != null) {
                    throw result.getRight();
                }
            }
        } catch (BeanCreationException e) {
            throw e;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new BeanCreationException("Interrupted while waiting for async init methods", e);
        } catch (Throwable e) {
            throw new BeanCreationException(e.getMessage(), e);
        } finally {
            // no more async work after this point, whether or not a task failed
            contextFinished = true;
            asyncInitPool.shutdown();
        }
        return true;
    }

    /**
     * Overrides the init-method step: for beans marked for async init,
     * afterPropertiesSet() and/or the custom init-method are submitted to the pool.
     */
    @Override
    protected void invokeInitMethods(final String beanName, final Object bean, final RootBeanDefinition mbd)
            throws Throwable {

        if (!canAsyncInit(bean, mbd)) {
            super.invokeInitMethods(beanName, bean, mbd);
            return;
        }

        // does the bean implement InitializingBean?
        boolean isInitializingBean = (bean instanceof InitializingBean);

        final boolean needInvokeAfterPropertiesSetMethod = isInitializingBean && (mbd == null || !mbd
                .isExternallyManagedInitMethod(afterPropertiesSetMethodName));

        final String initMethodName = (mbd != null ? mbd.getInitMethodName() : null);

        // mirror the superclass: skip a custom init-method that is afterPropertiesSet itself,
        // otherwise it would run twice (mbd is non-null here because canAsyncInit checked it)
        final boolean needInvokeInitMethod = initMethodName != null && !(isInitializingBean
                && afterPropertiesSetMethodName.equals(initMethodName)) &&
                !mbd.isExternallyManagedInitMethod(initMethodName);

        if (needInvokeAfterPropertiesSetMethod || needInvokeInitMethod) {
            asyncInvoke(new BeanInitMethodsInvoker() {
                @Override
                public void invoke() throws Throwable {
                    if (needInvokeAfterPropertiesSetMethod) {
                        invokeInitMethod(beanName, bean, afterPropertiesSetMethodName, false);
                    }
                    if (needInvokeInitMethod) {
                        invokeInitMethod(beanName, bean, initMethodName, mbd.isEnforceInitMethod());
                    }
                }

                @Override
                public String getBeanName() {
                    return beanName;
                }
            });
        }
    }

    // invoke a no-arg init method reflectively
    private void invokeInitMethod(String beanName, Object bean, String method, boolean enforceInitMethod)
            throws Throwable {
        Method initMethod = BeanUtils.findMethod(bean.getClass(), method);
        if (initMethod == null) {
            if (enforceInitMethod) {
                throw new NoSuchMethodException("Couldn't find an init method named '" + method +
                        "' on bean with name '" + beanName + "'");
            }
        } else {
            initMethod.setAccessible(true);
            try {
                initMethod.invoke(bean);
            } catch (InvocationTargetException e) {
                // unwrap so the init method's own exception is reported
                throw e.getTargetException();
            }
        }
    }

    private void asyncInvoke(final BeanInitMethodsInvoker beanInitMethodsInvoker) {
        taskList.add(asyncInitPool.submit(() -> {
            long start = System.currentTimeMillis();
            try {
                beanInitMethodsInvoker.invoke();
                return Pair.of(beanInitMethodsInvoker.getBeanName(), null);
            } catch (Throwable throwable) {
                return Pair.of(beanInitMethodsInvoker.getBeanName(), new BeanCreationException(
                        beanInitMethodsInvoker.getBeanName() + ": Async Invocation of init method failed", throwable));
            } finally {
                System.out.println("asyncInvokeInitMethod " + beanInitMethodsInvoker.getBeanName() + " cost:"
                        + (System.currentTimeMillis() - start) + "ms.");
            }
        }));
    }

    // only eager, non-FactoryBean beans carrying the ASYNC_INIT attribute are initialized asynchronously
    private boolean canAsyncInit(Object bean, RootBeanDefinition mbd) {
        if (contextFinished || mbd == null || mbd.isLazyInit() || bean instanceof FactoryBean) {
            return false;
        }
        Object value = mbd.getAttribute(Constant.ASYNC_INIT);
        return TRUE.equals(value) || "true".equals(value);
    }

    private interface BeanInitMethodsInvoker {
        void invoke() throws Throwable;

        String getBeanName();
    }

}
```

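A thread pool is used: each marked bean's init methods are submitted to it as a task and invoked via reflection, so refresh() no longer waits for them.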
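Stripped of Spring, the core idea (look up an init method by name, run it on a pool via reflection, keep the Future) can be sketched with the JDK alone. `SlowBean` and `ReflectiveAsyncInit` are hypothetical names for illustration; the task unwraps `InvocationTargetException` so the Future reports the init method's own exception:

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

// Hypothetical bean whose init method is slow.
class SlowBean {
    volatile int value;

    public void init() throws InterruptedException {
        Thread.sleep(100); // simulate slow initialization
        value = 100;
    }
}

// Hypothetical helper: run a named no-arg init method on a pool via reflection.
class ReflectiveAsyncInit {
    static Future<?> submitInit(ExecutorService pool, Object bean, String methodName)
            throws NoSuchMethodException {
        // resolve the method up front so a missing method fails fast on the caller's thread
        Method initMethod = bean.getClass().getMethod(methodName);
        initMethod.setAccessible(true);
        return pool.submit(() -> {
            try {
                initMethod.invoke(bean);
            } catch (InvocationTargetException e) {
                // unwrap so the Future's ExecutionException carries the init method's exception
                Throwable cause = e.getTargetException();
                if (cause instanceof Exception) {
                    throw (Exception) cause;
                }
                if (cause instanceof Error) {
                    throw (Error) cause;
                }
                throw e;
            }
            return null;
        });
    }
}
```

Calling `get()` on the returned Future blocks until `init()` has run on the pool thread; after that, `bean.value` is 100.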

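In addition, confirmAllAsyncTaskHadSuccessfulInvoked() waits for every submitted task to finish and rethrows a failure as a BeanCreationException; it has to be called after refresh() and before the async beans are used.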
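The "wait for everything, then fail loudly" step can likewise be sketched with plain futures. `AsyncInitBarrier` is a hypothetical name, and `IllegalStateException` stands in for Spring's `BeanCreationException` so the sketch has no dependencies:

```java
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

// Hypothetical helper mirroring the "confirm all async tasks" step.
class AsyncInitBarrier {
    /**
     * Blocks until every task has completed. The first failed task in list order is
     * rethrown, wrapped in an IllegalStateException (stand-in for BeanCreationException).
     */
    static void awaitAll(List<? extends Future<?>> tasks) throws InterruptedException {
        for (Future<?> task : tasks) {
            try {
                task.get();
            } catch (ExecutionException e) {
                // getCause() is the exception thrown inside the task
                throw new IllegalStateException("async init failed", e.getCause());
            }
        }
    }
}
```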

Appendix: the bean lifecycle

Reference: https://doctording.blog.csdn.net/article/details/145044487

[Figure: DefaultListableBeanFactory class diagram]

Appendix: AbstractAutowireCapableBeanFactory.invokeInitMethods (Spring source)

```java
protected void invokeInitMethods(String beanName, Object bean, @Nullable RootBeanDefinition mbd) throws Throwable {
    boolean isInitializingBean = bean instanceof InitializingBean;
    if (isInitializingBean && (mbd == null || !mbd.isExternallyManagedInitMethod("afterPropertiesSet"))) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Invoking afterPropertiesSet() on bean with name '" + beanName + "'");
        }

        if (System.getSecurityManager() != null) {
            try {
                AccessController.doPrivileged(() -> {
                    ((InitializingBean)bean).afterPropertiesSet();
                    return null;
                }, this.getAccessControlContext());
            } catch (PrivilegedActionException var6) {
                throw var6.getException();
            }
        } else {
            ((InitializingBean)bean).afterPropertiesSet();
        }
    }

    if (mbd != null && bean.getClass() != NullBean.class) {
        String initMethodName = mbd.getInitMethodName();
        if (StringUtils.hasLength(initMethodName) && (!isInitializingBean || !"afterPropertiesSet".equals(initMethodName)) && !mbd.isExternallyManagedInitMethod(initMethodName)) {
            this.invokeCustomInitMethod(beanName, bean, mbd);
        }
    }

}
```

Tagging beans with a custom BeanFactoryPostProcessor

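That is, set a marker attribute on each target BeanDefinition; CustomBeanFactory.canAsyncInit later checks this attribute.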

```java
package org.example;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.stereotype.Component;

import java.util.HashSet;
import java.util.Set;

@Component
public class AsyncInitBeanFactoryPostProcessor implements BeanFactoryPostProcessor {

    private Set<String> asyncInitBeanNames = new HashSet<>();

    public AsyncInitBeanFactoryPostProcessor() {
        // hard-coded here; the names could instead come from configuration or another source
        asyncInitBeanNames.add("a");
        asyncInitBeanNames.add("b");
        System.out.println("asyncInitBeanNames:" + asyncInitBeanNames);
    }

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        // mark each listed bean with the async-init attribute
        for (String beanName : asyncInitBeanNames) {
            try {
                BeanDefinition beanDefinition = beanFactory.getBeanDefinition(beanName);
                beanDefinition.setAttribute(Constant.ASYNC_INIT, true);
            } catch (NoSuchBeanDefinitionException e) {
                // no bean with this name: nothing to mark
            }
        }
    }

    public Set<String> getAsyncInitBeanNames() {
        return asyncInitBeanNames;
    }

    public void setAsyncInitBeanNames(Set<String> asyncInitBeanNames) {
        this.asyncInitBeanNames = asyncInitBeanNames;
    }
}
```
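CustomBeanFactory and this post-processor both refer to `Constant.ASYNC_INIT`, which the post does not define. A minimal sketch, assuming it is simply a String key for the BeanDefinition attribute (`setAttribute`/`getAttribute` take a String name); the key's value below is made up, and any unique string would do:

```java
// Hypothetical: Constant is not shown in the original post.
// Belongs in package org.example next to the classes that use it.
final class Constant {
    /** BeanDefinition attribute marking a bean for asynchronous init (value is an assumption). */
    static final String ASYNC_INIT = "org.example.asyncInit";

    private Constant() {
        // constants holder, not meant to be instantiated
    }
}
```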

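An example of a bean to be initialized asynchronously. It implements InitializingBean instead of using @PostConstruct, so that its init logic goes through invokeInitMethods: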

```java
package org.example;

import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;

import java.util.concurrent.TimeUnit;

@Service
public class A implements InitializingBean {

    public int a;

    @Override
    public void afterPropertiesSet() {
        try {
            // simulate a slow initialization (2 s); runs on a pool thread when the bean is marked async
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            // restore the interrupt flag instead of silently swallowing it
            Thread.currentThread().interrupt();
        }
        a = 100;
        System.out.println("A.init success");
    }

    public void sayHello() {
        System.out.println("A.sayHello:" + a);
    }
}
```

Testing with AnnotationConfigApplicationContext

```java
package org.example;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;

public class Main {
    public static void main(String[] args) {
        // instead of new AnnotationConfigApplicationContext(), pass in the custom bean factory
        CustomBeanFactory customBeanFactory = new CustomBeanFactory();
        AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext(customBeanFactory);

        applicationContext.register(Config.class);
        long startTime = System.currentTimeMillis();
        applicationContext.refresh();
        // block until every async init method has finished; failures are rethrown here
        customBeanFactory.confirmAllAsyncTaskHadSuccessfulInvoked();
        long cost = System.currentTimeMillis() - startTime;
        System.out.println(String.format("======= applicationContext refresh cost:%d ms", cost));

        A a = (A) applicationContext.getBean("a");
        a.sayHello();
        B b = (B) applicationContext.getBean("b");
        b.sayHello();

        applicationContext.close();
    }
}
```

Results

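Compared with before, startup now takes roughly as long as the slowest bean's initialization instead of the sum of all of them, and initialization is still correct: sayHello() sees the value set in afterPropertiesSet(). One caveat: an async bean is handed to its dependents before its init method has finished, so only mark beans that no other bean needs fully initialized during startup.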
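The timing claim can be checked without Spring: running simulated init tasks one after another costs roughly their sum, while submitting them to a pool and waiting for all of them costs roughly the slowest one. `StartupCostDemo` and the 200 ms / 300 ms durations are hypothetical, scaled down from the article's seconds:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

// Hypothetical demo: sequential vs. pooled execution of slow "init" tasks.
class StartupCostDemo {
    // simulated init durations in ms (assumed values)
    static final long[] INIT_MS = {200, 300};

    static void sleep(long ms) {
        try {
            TimeUnit.MILLISECONDS.sleep(ms);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // like refresh() with synchronous init methods: durations add up
    static long sequentialMs() {
        long start = System.nanoTime();
        for (long ms : INIT_MS) {
            sleep(ms);
        }
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
    }

    // like the custom bean factory: submit all tasks, then wait for every one
    static long parallelMs() throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(INIT_MS.length);
        try {
            long start = System.nanoTime();
            List<Future<?>> tasks = new ArrayList<>();
            for (long ms : INIT_MS) {
                tasks.add(pool.submit(() -> sleep(ms)));
            }
            for (Future<?> task : tasks) {
                task.get();
            }
            return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
        } finally {
            pool.shutdown();
        }
    }
}
```

With these durations, `sequentialMs()` typically reports a little over 500 ms and `parallelMs()` a little over 300 ms; the pool needs at least as many threads as there are slow tasks for the cost to approach the maximum.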