目录
-
- 问题引入:很多bean初始化很慢
- 统计bean初始化方法耗时:自定义BeanPostProcessor
- 自定义beanFactory
-
- [附 bean的生命周期](#附 bean的生命周期)
- [附 `DefaultListableBeanFactory`类图](#附 DefaultListableBeanFactory类图)
- [附 AbstractAutowireCapableBeanFactory的invokeInitMethods](#附 AbstractAutowireCapableBeanFactory的invokeInitMethods)
- 自定义BeanFactoryPostProcessor给bean打标
- [要异步初始化的bean 例子如下](#要异步初始化的bean 例子如下)
- [AnnotationConfigApplicationContext 测试](#AnnotationConfigApplicationContext 测试)
- 测试
问题引入:很多bean初始化很慢
考虑如下简单的程序
java
package org.example;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
public class Main {
    /**
     * Baseline demo: boots the context with the default bean factory, prints how long
     * {@code refresh()} took (in whole seconds) and then calls bean {@code a}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        // try-with-resources closes the context (publishing the close event and destroying its
        // singletons); the previous stop() call only stopped Lifecycle beans and left it open.
        try (AnnotationConfigApplicationContext applicationContext = new AnnotationConfigApplicationContext()) {
            applicationContext.register(Config.class);
            long startTime = System.currentTimeMillis();
            applicationContext.refresh();
            long cost = System.currentTimeMillis() - startTime;
            System.out.println(String.format("applicationContext refresh cost:%d s", cost / 1000));
            A a = applicationContext.getBean("a", A.class);
            a.sayHello();
        }
    }
}
- A、B 两个 bean 的基本定义如下(这里只列出 A,B 与之类似),其初始化方法可能很耗时:
java
package org.example;
import org.springframework.stereotype.Component;
import org.springframework.stereotype.Service;
import javax.annotation.PostConstruct;
import java.util.concurrent.TimeUnit;
@Service
public class A {
    /**
     * Simulates a slow initialization: sleeps about 2 seconds, then prints a success line.
     * Annotated with {@code @PostConstruct}, so it runs once the bean's dependencies are injected.
     */
    @PostConstruct
    public void init() {
        try {
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of silently swallowing it, so the caller can react.
            Thread.currentThread().interrupt();
        }
        System.out.println("A.init success");
    }

    /** Prints a fixed greeting. */
    public void sayHello() {
        System.out.println("A.sayHello");
    }
}
运行结果如下,applicationContext 的 refresh 耗时 5 秒多。
统计bean初始化方法耗时:自定义BeanPostProcessor
即在 postProcessBeforeInitialization 中记录 bean 初始化的开始时间,
在 postProcessAfterInitialization 中记录 bean 初始化的完成时间,两者相减即可得到 bean 初始化方法的耗时。
java
package org.example;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationEvent;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import static java.util.Collections.reverseOrder;
@Component
public class BeanInitMethodCostTimeBeanPostProcessor implements BeanPostProcessor, ApplicationListener<ApplicationEvent> {
    /** Global counter that gives each recorded measurement a serial number in recording order. */
    private static final AtomicInteger ATOMIC_INTEGER = new AtomicInteger(0);

    /**
     * Bean name -> timestamp (ms) taken right before the bean's init callbacks. Concurrent because
     * beans may be created on more than one thread (e.g. from asynchronously running init methods).
     */
    private final Map<String, Long> startTime = new java.util.concurrent.ConcurrentHashMap<>(1024);

    /**
     * Bean name -> first measured init cost. Keyed by name so the "already recorded" check is O(1)
     * instead of a linear scan per bean. Guarded by its own monitor.
     */
    private final Map<String, Initialization> costTime = new HashMap<>(1024);

    /** Records the time just before the bean's init callbacks run. */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        startTime.put(beanName, System.currentTimeMillis());
        return bean;
    }

    /**
     * Records the init cost of {@code beanName}, keeping only the first measurement per name.
     * An "after" callback with no matching "before" timestamp is ignored.
     */
    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        long end = System.currentTimeMillis();
        // remove(): a later "after" call for the same name without a fresh "before" call must not be
        // timed against a stale start.
        Long start = startTime.remove(beanName);
        if (start != null) {
            synchronized (costTime) {
                costTime.computeIfAbsent(beanName, name -> Initialization.parseInitialization(name, start, end));
            }
        }
        return bean;
    }

    /**
     * On {@link ContextRefreshedEvent}, prints all measurements (slowest first, ties in recording order)
     * and resets the collected data. Other events are ignored.
     */
    @Override
    public void onApplicationEvent(ApplicationEvent event) {
        if (!(event instanceof ContextRefreshedEvent)) {
            return;
        }
        List<Initialization> report;
        synchronized (costTime) {
            report = new ArrayList<>(costTime.values());
            costTime.clear();
        }
        startTime.clear();
        report.sort((x, y) -> x.costTime != y.costTime
                ? Long.compare(y.costTime, x.costTime)
                : Integer.compare(x.serialNumber, y.serialNumber));
        for (Initialization initialization : report) {
            System.out.println(initialization);
        }
    }

    /** One immutable init-time measurement for a single bean. */
    private static class Initialization implements Comparable<Initialization> {
        /** DateTimeFormatter is immutable and thread-safe, so one shared instance suffices. */
        private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

        private final int serialNumber;
        private final String beanName;
        private final long costTime;
        private final long start;
        private final long end;

        private Initialization(int serialNumber, String beanName, long start, long end) {
            this.serialNumber = serialNumber;
            this.beanName = beanName;
            this.costTime = end - start;
            this.start = start;
            this.end = end;
        }

        /** Creates a measurement and assigns it the next serial number. */
        public static Initialization parseInitialization(String beanName, long start, long end) {
            return new Initialization(ATOMIC_INTEGER.incrementAndGet(), beanName, start, end);
        }

        @Override
        public String toString() {
            return "serialNumber: " + serialNumber + ",beanName: " + beanName + ",cost " + costTime + " ms,"
                    + " start: " + convertTimeToString(start) + ", end:" + convertTimeToString(end);
        }

        /** Formats an epoch-millisecond timestamp in the system default time zone. */
        public static String convertTimeToString(long time) {
            return TIME_FORMATTER.format(LocalDateTime.ofInstant(Instant.ofEpochMilli(time), ZoneId.systemDefault()));
        }

        /** Orders measurements by init cost, ascending. */
        @Override
        public int compareTo(Initialization o) {
            return Long.compare(costTime, o.costTime);
        }
    }
}
可以看到如下:a,b 两个bean 初始化耗时很久, applicationContext refresh耗时也主要是由于a,b两个bean初始化导致
自定义beanFactory
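继承 DefaultListableBeanFactory,重写 invokeInitMethods 方法。注意:@PostConstruct 方法是由 CommonAnnotationBeanPostProcessor 在 postProcessBeforeInitialization 阶段调用的,并不经过 invokeInitMethods,所以这种方式只能把 afterPropertiesSet 和自定义 init-method 异步化(这也是下文的 A 改为实现 InitializingBean 的原因)。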
java
package org.example;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.beans.factory.support.RootBeanDefinition;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import static java.lang.Boolean.TRUE;
/**
 * Bean factory that runs the init methods (afterPropertiesSet and custom init-method) of bean
 * definitions marked with {@code Constant.ASYNC_INIT} on a thread pool instead of the creating thread.
 * <p>Such beans are handed back to the container before their init methods have finished, so callers
 * must invoke {@link #confirmAllAsyncTaskHadSuccessfulInvoked()} after {@code refresh()} before using them.
 */
public class CustomBeanFactory extends DefaultListableBeanFactory {
    /** Worker count used by the no-arg constructor. */
    private static final int DEFAULT_POOL_SIZE = 4;

    /**
     * Name of {@link InitializingBean#afterPropertiesSet()}, the callback that runs custom
     * initialization logic after dependency injection has completed.
     */
    private final String afterPropertiesSetMethodName = "afterPropertiesSet";

    /**
     * Futures of the submitted async init tasks; each yields (beanName, failure or {@code null}).
     * Per instance (previously static) so separate factories neither share nor clobber each other's state.
     */
    private final List<Future<Pair<String, Throwable>>> taskList = Collections.synchronizedList(new ArrayList<>());

    /** Pool that runs the async init methods; owned by, and shut down with, this factory. */
    private final ExecutorService asyncInitPoll;

    /**
     * Once {@code true}, every bean is initialized synchronously. Volatile because it is written by the
     * thread that awaits/destroys and read by whichever thread creates beans.
     */
    private volatile boolean contextFinished = false;

    /** Creates a factory that runs async init methods on {@value #DEFAULT_POOL_SIZE} threads. */
    public CustomBeanFactory() {
        this(DEFAULT_POOL_SIZE);
    }

    /**
     * Creates a factory whose async init methods run on a fixed pool of {@code poolSize} threads.
     *
     * @param poolSize number of worker threads; must be positive
     * @throws IllegalArgumentException if {@code poolSize} is not positive
     */
    public CustomBeanFactory(int poolSize) {
        super();
        asyncInitPoll = new ThreadPoolExecutor(poolSize, poolSize, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    /**
     * Blocks until every async init task has finished, then shuts the pool down (also on failure).
     * From the moment this method is entered, newly created beans are initialized synchronously.
     *
     * @return always {@code true}
     * @throws BeanCreationException if an async init method failed or the wait was interrupted
     */
    public boolean confirmAllAsyncTaskHadSuccessfulInvoked() {
        // Stop scheduling async work first: the pool is shut down below, so any bean created from now on
        // (including by a still-running init method) must initialize on its own thread.
        contextFinished = true;
        try {
            List<Future<Pair<String, Throwable>>> pending;
            // Loop: running tasks may have registered further tasks before the flag was set.
            while (!(pending = drainTasks()).isEmpty()) {
                for (Future<Pair<String, Throwable>> task : pending) {
                    Throwable failure = awaitFailure(task);
                    if (failure != null) {
                        throw toBeanCreationException(failure);
                    }
                }
            }
        } finally {
            // Always release the non-daemon pool threads, even on failure, so the JVM can exit.
            asyncInitPoll.shutdown();
        }
        return contextFinished;
    }

    /**
     * Cancels outstanding async init work, then destroys the singletons as usual.
     * <p>The application context is expected to call this on close and when a refresh fails; without it
     * a failed refresh would leave the non-daemon pool threads running.
     */
    @Override
    public void destroySingletons() {
        contextFinished = true;
        for (Runnable queued : asyncInitPoll.shutdownNow()) {
            // Cancel never-started tasks so a concurrent waiter's Future.get() does not block forever.
            if (queued instanceof Future) {
                ((Future<?>) queued).cancel(false);
            }
        }
        super.destroySingletons();
    }

    /**
     * Runs the init methods asynchronously for beans accepted by {@link #canAsyncInit}; all other beans
     * use the inherited synchronous behaviour. For async beans this returns right after scheduling.
     */
    @Override
    protected void invokeInitMethods(final String beanName, final Object bean, final RootBeanDefinition mbd)
            throws Throwable {
        if (!canAsyncInit(bean, mbd)) {
            super.invokeInitMethods(beanName, bean, mbd);
            return;
        }
        // canAsyncInit guarantees mbd != null from here on.
        final boolean isInitializingBean = bean instanceof InitializingBean;
        final boolean needInvokeAfterPropertiesSetMethod = isInitializingBean
                && !mbd.isExternallyManagedInitMethod(afterPropertiesSetMethodName);
        final String initMethodName = mbd.getInitMethodName();
        // Mirrors the synchronous rules: skip an empty name, a duplicate of afterPropertiesSet,
        // or an externally managed method.
        final boolean needInvokeInitMethod = initMethodName != null && !initMethodName.isEmpty()
                && !(isInitializingBean && afterPropertiesSetMethodName.equals(initMethodName))
                && !mbd.isExternallyManagedInitMethod(initMethodName);
        if (!needInvokeAfterPropertiesSetMethod && !needInvokeInitMethod) {
            return;
        }
        final boolean enforceInitMethod = mbd.isEnforceInitMethod();
        asyncInvoke(beanName, () -> {
            if (needInvokeAfterPropertiesSetMethod) {
                ((InitializingBean) bean).afterPropertiesSet();
            }
            if (needInvokeInitMethod) {
                invokeInitMethod(beanName, bean, initMethodName, enforceInitMethod);
            }
        });
    }

    /**
     * Invokes the no-arg method {@code method} on {@code bean} reflectively.
     *
     * @throws NoSuchMethodException if the method is missing and {@code enforceInitMethod} is set
     * @throws Throwable             whatever the init method itself throws (unwrapped from reflection)
     */
    private void invokeInitMethod(String beanName, Object bean, String method, boolean enforceInitMethod)
            throws Throwable {
        Method initMethod = BeanUtils.findMethod(bean.getClass(), method);
        if (initMethod == null) {
            if (enforceInitMethod) {
                throw new NoSuchMethodException("Couldn't find an init method named '" + method +
                        "' on bean with name '" + beanName + "'");
            }
            return;
        }
        initMethod.setAccessible(true);
        try {
            initMethod.invoke(bean);
        } catch (java.lang.reflect.InvocationTargetException ex) {
            // Surface the init method's own exception rather than the reflective wrapper.
            throw ex.getTargetException();
        }
    }

    /** Submits {@code invoker} to the pool and records its future; failures are captured, not thrown. */
    private void asyncInvoke(final String beanName, final BeanInitMethodsInvoker invoker) {
        taskList.add(asyncInitPoll.submit(() -> {
            long start = System.currentTimeMillis();
            try {
                invoker.invoke();
                return Pair.<String, Throwable>of(beanName, null);
            } catch (Throwable throwable) {
                return Pair.<String, Throwable>of(beanName, new BeanCreationException(
                        beanName + ": Async Invocation of init method failed", throwable));
            } finally {
                System.out.println("asyncInvokeInitMethod " + beanName + " cost:"
                        + (System.currentTimeMillis() - start) + "ms.");
            }
        }));
    }

    /** Atomically takes, and removes, every future registered so far. */
    private List<Future<Pair<String, Throwable>>> drainTasks() {
        // Copying a synchronizedList iterates it, which requires holding its monitor.
        synchronized (taskList) {
            List<Future<Pair<String, Throwable>>> snapshot = new ArrayList<>(taskList);
            taskList.clear();
            return snapshot;
        }
    }

    /** Waits for one task and returns its failure, or {@code null} if its init methods succeeded. */
    private static Throwable awaitFailure(Future<Pair<String, Throwable>> task) {
        try {
            return task.get().getRight();
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller; report the interrupted wait as the failure.
            Thread.currentThread().interrupt();
            return e;
        } catch (java.util.concurrent.ExecutionException e) {
            return e.getCause();
        } catch (RuntimeException e) {
            // e.g. CancellationException for a task cancelled by destroySingletons()
            return e;
        }
    }

    /** Returns {@code failure} itself when it already is a BeanCreationException, otherwise wraps it. */
    private static BeanCreationException toBeanCreationException(Throwable failure) {
        if (failure instanceof BeanCreationException) {
            return (BeanCreationException) failure;
        }
        return new BeanCreationException(failure.getMessage(), failure);
    }

    /**
     * Returns true only for non-lazy, non-FactoryBean beans whose definition carries the
     * {@code Constant.ASYNC_INIT} attribute as {@code true} or {@code "true"}, and only before
     * the async tasks have been awaited.
     */
    private boolean canAsyncInit(Object bean, RootBeanDefinition mbd) {
        if (contextFinished || mbd == null || mbd.isLazyInit() || bean instanceof FactoryBean) {
            return false;
        }
        Object value = mbd.getAttribute(Constant.ASYNC_INIT);
        return TRUE.equals(value) || "true".equals(value);
    }

    /** Body of one async init task. */
    @FunctionalInterface
    private interface BeanInitMethodsInvoker {
        void invoke() throws Throwable;
    }
}
即把初始化方法作为任务提交到线程池中异步执行;
另外提供一个方法,在 refresh() 之后等待所有异步任务完成,并确认它们都执行成功。
附 bean的生命周期
参考:https://doctording.blog.csdn.net/article/details/145044487
附 DefaultListableBeanFactory 类图

附 AbstractAutowireCapableBeanFactory的invokeInitMethods
java
/**
 * Reference excerpt of Spring's {@code AbstractAutowireCapableBeanFactory#invokeInitMethods}.
 * <p>Everything runs on the calling thread:
 * (1) if the bean implements {@code InitializingBean} and "afterPropertiesSet" is not an externally
 * managed init method of {@code mbd}, {@code afterPropertiesSet()} is called;
 * (2) if {@code mbd} is non-null and the bean is not a {@code NullBean}, the custom init method named by
 * {@code mbd} is called via {@code invokeCustomInitMethod} — unless the name is empty, equals
 * "afterPropertiesSet" on an {@code InitializingBean} (handled by step 1), or is externally managed.
 * NOTE(review): the {@code var6} name suggests this was copied from decompiled bytecode.
 *
 * @param beanName the bean name, used in the debug message and passed to invokeCustomInitMethod
 * @param bean     the bean instance whose init callbacks are run
 * @param mbd      the bean definition; may be {@code null}, in which case only step 1 can run
 * @throws Throwable whatever the init callbacks throw; under a SecurityManager the exception wrapped in
 *                   the PrivilegedActionException is unwrapped and rethrown
 */
protected void invokeInitMethods(String beanName, Object bean, @Nullable RootBeanDefinition mbd) throws Throwable {
    boolean isInitializingBean = bean instanceof InitializingBean;
    // Step 1: InitializingBean callback, unless it is registered as externally managed.
    if (isInitializingBean && (mbd == null || !mbd.isExternallyManagedInitMethod("afterPropertiesSet"))) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Invoking afterPropertiesSet() on bean with name '" + beanName + "'");
        }
        if (System.getSecurityManager() != null) {
            // With a SecurityManager installed, run the callback under the factory's access control context.
            try {
                AccessController.doPrivileged(() -> {
                    ((InitializingBean)bean).afterPropertiesSet();
                    return null;
                }, this.getAccessControlContext());
            } catch (PrivilegedActionException var6) {
                throw var6.getException();
            }
        } else {
            ((InitializingBean)bean).afterPropertiesSet();
        }
    }
    // Step 2: custom init method declared on the bean definition.
    if (mbd != null && bean.getClass() != NullBean.class) {
        String initMethodName = mbd.getInitMethodName();
        // Skip an empty name, a duplicate of the step-1 callback, or an externally managed method.
        if (StringUtils.hasLength(initMethodName) && (!isInitializingBean || !"afterPropertiesSet".equals(initMethodName)) && !mbd.isExternallyManagedInitMethod(initMethodName)) {
            this.invokeCustomInitMethod(beanName, bean, mbd);
        }
    }
}
自定义BeanFactoryPostProcessor给bean打标
即给需要异步初始化的 beanDefinition 打上标记属性(Constant.ASYNC_INIT)
java
package org.example;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.stereotype.Component;
import java.util.HashSet;
import java.util.Set;
@Component
public class AsyncInitBeanFactoryPostProcessor implements BeanFactoryPostProcessor {
    /** Names of the beans whose definitions get the async-init attribute. */
    private Set<String> asyncInitBeanNames = new HashSet<>();

    /** Registers the demo's bean names "a" and "b" and prints the resulting set. */
    public AsyncInitBeanFactoryPostProcessor() {
        // Hard-coded here; could be driven by configuration or another source instead.
        asyncInitBeanNames.add("a");
        asyncInitBeanNames.add("b");
        System.out.println("asyncInitBeanNames:" + asyncInitBeanNames);
    }

    /**
     * Sets attribute {@code Constant.ASYNC_INIT} to {@code true} on the definition of every configured
     * bean name that has a bean definition in {@code beanFactory}; names without one are skipped.
     */
    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        if (asyncInitBeanNames == null) {
            // The setter accepted null; treat it as "no async beans" instead of failing with an NPE.
            return;
        }
        for (String beanName : asyncInitBeanNames) {
            // Explicit check instead of catching and silently ignoring NoSuchBeanDefinitionException.
            if (beanFactory.containsBeanDefinition(beanName)) {
                beanFactory.getBeanDefinition(beanName).setAttribute(Constant.ASYNC_INIT, true);
            }
        }
    }

    /** Returns the configured bean names (the live set, not a copy). */
    public Set<String> getAsyncInitBeanNames() {
        return asyncInitBeanNames;
    }

    /** Replaces the configured bean names; {@code null} disables marking. */
    public void setAsyncInitBeanNames(Set<String> asyncInitBeanNames) {
        this.asyncInitBeanNames = asyncInitBeanNames;
    }
}
要异步初始化的bean 例子如下
java
package org.example;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import java.util.concurrent.TimeUnit;
@Service
public class A implements InitializingBean {
    /**
     * Set to 100 by {@link #afterPropertiesSet()}; stays 0 until that method has run. If the init method
     * runs on another thread, a reader needs a happens-before edge with its completion (e.g. waiting on
     * the task's Future) to be sure to observe the write.
     */
    public int a;

    /** Simulates a slow initialization: sleeps about 2 seconds, then sets {@link #a} to 100. */
    @Override
    public void afterPropertiesSet() {
        try {
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            // Restore the interrupt flag instead of silently swallowing it, so the caller can react.
            Thread.currentThread().interrupt();
        }
        a = 100;
        System.out.println("A.init success");
    }

    /** Prints a greeting including the current value of {@link #a}. */
    public void sayHello() {
        System.out.println("A.sayHello:" + a);
    }
}
AnnotationConfigApplicationContext 测试
java
package org.example;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
public class Main {
    /**
     * Async demo: boots the context on a {@link CustomBeanFactory}, waits for the asynchronous init
     * methods to finish, prints the elapsed time in whole seconds and calls beans {@code a} and {@code b}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        CustomBeanFactory customBeanFactory = new CustomBeanFactory();
        // try-with-resources closes the context (destroying its singletons) even when refresh() or an
        // async init method fails; the previous stop() call never closed it.
        try (AnnotationConfigApplicationContext applicationContext =
                     new AnnotationConfigApplicationContext(customBeanFactory)) {
            applicationContext.register(Config.class);
            long startTime = System.currentTimeMillis();
            applicationContext.refresh();
            // Beans marked for async init may still be initializing when refresh() returns:
            // wait for all of them (rethrowing any failure) before timing and using them.
            customBeanFactory.confirmAllAsyncTaskHadSuccessfulInvoked();
            long cost = System.currentTimeMillis() - startTime;
            System.out.println(String.format("======= applicationContext refresh cost:%d s", cost / 1000));
            A a = applicationContext.getBean("a", A.class);
            a.sayHello();
            B b = applicationContext.getBean("b", B.class);
            b.sayHello();
        }
    }
}
测试

与之前相比,现在启动只需要消耗初始化最慢的那个 bean 的时间(前提是需要异步初始化的 bean 数量不超过线程池大小,本例为 4),且 bean 也都被正确初始化了。