Spring Extension: Dynamically Disabling a @Configuration Class Imported via @Import

Background

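In medium and large Spring/Spring Boot projects, configuration is usually split across modules (data sources, caching, messaging, business feature packages, and so on). We want to compose configuration module by module. We also want to bring in certain configuration classes explicitly, without relying on component scanning. @Import is designed for exactly this: it brings bean definitions and configuration defined elsewhere into the current context.

@Import is handled while configuration classes are being parsed, before any bean is instantiated. That makes it a metaprogramming tool for registering BeanDefinitions. However, several jars may define classes with the same name, and therefore the same bean name. Since Spring Boot 2.1 disables bean definition overriding by default, startup then fails with a BeanDefinitionOverrideException. There are generally a few ways to deal with this: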

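  • Redefine the class in your own project (same fully qualified name) and give the bean an alias. This stops working if the class loader loads the framework's copy of the class first.
  • Dynamically exclude the @Configuration class so that it never takes effect.

How can the second approach be implemented by extending Spring?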
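To make the conflict concrete, here is a toy model in plain Java of what happens when bean definition overriding is disabled. It is not Spring's API; in Spring, DefaultListableBeanFactory throws BeanDefinitionOverrideException at the corresponding point. The class, bean, and package names below are made up for illustration.

```java
import java.util.HashMap;
import java.util.Map;

// Toy model, not Spring's API: a registry that maps bean names to implementation class names
class ToyBeanRegistry {
    private final Map<String, String> definitions = new HashMap<>();
    private final boolean allowOverriding;

    ToyBeanRegistry(boolean allowOverriding) {
        this.allowOverriding = allowOverriding;
    }

    void register(String beanName, String className) {
        String existing = definitions.get(beanName);
        if (existing != null && !allowOverriding) {
            // This is the point where Spring would throw BeanDefinitionOverrideException
            throw new IllegalStateException("Cannot register " + className + " as '" + beanName
                    + "': already bound to " + existing);
        }
        definitions.put(beanName, className);
    }

    String classFor(String beanName) {
        return definitions.get(beanName);
    }

    public static void main(String[] args) {
        ToyBeanRegistry registry = new ToyBeanRegistry(false);
        // Two hypothetical jars each contribute a bean named "taskExecutor"
        registry.register("taskExecutor", "com.example.a.TaskExecutor");
        try {
            registry.register("taskExecutor", "com.example.b.TaskExecutor");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The first registration wins, and the second one fails instead of silently replacing it. Excluding one of the two configuration classes removes the conflict at its source.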

Solution

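@Import is processed early, while Spring parses configuration classes, and before the deferred auto-configuration imports. Spring Boot auto-configuration classes can be excluded with spring.autoconfigure.exclude, but classes brought in with @Import cannot. We therefore need a way to exclude @Configuration classes imported via @Import as well. Let's review where and when @Import is processed in the source code: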

  • org.springframework.context.annotation.AnnotationConfigUtils#registerAnnotationConfigProcessors

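    • Registers ConfigurationClassPostProcessor (bean name org.springframework.context.annotation.internalConfigurationAnnotationProcessor), which scans and processes @Configuration classes.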
  • org.springframework.context.support.AbstractApplicationContext#invokeBeanFactoryPostProcessors

    • PostProcessorRegistrationDelegate.invokeBeanFactoryPostProcessors(beanFactory, getBeanFactoryPostProcessors());
    • getBeanFactoryPostProcessors() returns the post-processors added directly to the AbstractApplicationContext. postProcessBeanDefinitionRegistry is called on those that are BeanDefinitionRegistryPostProcessors.
  • org.springframework.context.support.PostProcessorRegistrationDelegate#invokeBeanFactoryPostProcessors

    • The logic inside the method:
    java
    for (BeanFactoryPostProcessor postProcessor : beanFactoryPostProcessors) {  
        if (postProcessor instanceof BeanDefinitionRegistryPostProcessor) {  
           BeanDefinitionRegistryPostProcessor registryProcessor =  
                 (BeanDefinitionRegistryPostProcessor) postProcessor;  
           registryProcessor.postProcessBeanDefinitionRegistry(registry);  
           registryProcessors.add(registryProcessor);  
        }  
        else {  
           regularPostProcessors.add(postProcessor);  
        }  
    }
    • This registryProcessor.postProcessBeanDefinitionRegistry(registry); call is one of our extension points.
  • org.springframework.context.annotation.ConfigurationClassPostProcessor#postProcessBeanDefinitionRegistry

    • Next, scans the @Configuration classes.
  • org.springframework.context.annotation.ConfigurationClassPostProcessor#processConfigBeanDefinitions

    • Next, processes the configuration classes.
  • org.springframework.context.annotation.ConfigurationClassBeanDefinitionReader#loadBeanDefinitions

    • Next, processes the @Bean methods of each configuration class.
  • org.springframework.context.annotation.ConfigurationClassBeanDefinitionReader#loadBeanDefinitionsForBeanMethod

    • Next, handles each @Bean method in turn.
  • org.springframework.beans.factory.support.BeanDefinitionRegistry#registerBeanDefinition

    • Finally, registers the bean definition.

That is the main source path for processing configuration classes. The hook we can use is the registryProcessor.postProcessBeanDefinitionRegistry(registry) call. Post-processors added directly to the context run there, before ConfigurationClassPostProcessor itself is instantiated. The extension looks like this:
java
public static void main(String[] args) {  
    ConfigurableApplicationContext context = new SpringApplicationBuilder(BeanOverrideApplication.class)  
            .initializers(new ApplicationContextInitializer<ConfigurableApplicationContext>() {  
                @Override  
                public void initialize(ConfigurableApplicationContext applicationContext) {  
                    applicationContext.addBeanFactoryPostProcessor(new MyBeanFactoryPostProcessor());  
                }  
            })  
            .run(args);  
}
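  • Register MyBeanFactoryPostProcessor in an ApplicationContextInitializer.
    • Alternatively, register the initializer in META-INF/spring.factories with org.springframework.context.ApplicationContextInitializer=xxxx.
    • An ApplicationContextInitializer runs before the context is refreshed. That is early enough to guarantee that MyBeanFactoryPostProcessor takes effect.
    • A BeanFactoryPostProcessor registered the usual way, with @Component or @Bean, does not work here. Such beans are only discovered by ConfigurationClassPostProcessor itself. By the time they run, the original ConfigurationClassPostProcessor has already been instantiated and has already processed the configuration classes.
  • MyBeanFactoryPostProcessor replaces the bean class of the registered ConfigurationClassPostProcessor definition with a custom subclass: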
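java
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.AnnotationConfigUtils;
import org.springframework.context.annotation.ConfigurationClassPostProcessor;
import org.springframework.context.annotation.MyConfigurationClassPostProcessor;

/**
 * Swaps the bean class of Spring's built-in ConfigurationClassPostProcessor definition
 * before that post-processor is instantiated.
 */
public class MyBeanFactoryPostProcessor implements BeanDefinitionRegistryPostProcessor {
    private static final Logger log = LoggerFactory.getLogger(MyBeanFactoryPostProcessor.class);

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        // Nothing to do; all work happens in postProcessBeanDefinitionRegistry
    }

    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        // "org.springframework.context.annotation.internalConfigurationAnnotationProcessor"
        String beanName = AnnotationConfigUtils.CONFIGURATION_ANNOTATION_PROCESSOR_BEAN_NAME;
        if (!registry.containsBeanDefinition(beanName)) {
            return;
        }
        BeanDefinition beanDefinition = registry.getBeanDefinition(beanName);
        if (beanDefinition instanceof RootBeanDefinition
                && ConfigurationClassPostProcessor.class.getName().equals(beanDefinition.getBeanClassName())) {
            RootBeanDefinition rootBeanDefinition = (RootBeanDefinition) beanDefinition;
            log.info("beanDefinitionName: {} beanDefinition: {}", beanName, rootBeanDefinition);
            // Remove and re-register so any cached merged definition for this name is reset;
            // property values already set on the definition are kept
            registry.removeBeanDefinition(beanName);
            rootBeanDefinition.setBeanClass(MyConfigurationClassPostProcessor.class);
            registry.registerBeanDefinition(beanName, rootBeanDefinition);
        }
    }
}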
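The bullets above argue that a processor added through an ApplicationContextInitializer runs before ConfigurationClassPostProcessor, while @Component/@Bean processors only run after it. Below is a simplified plain-Java model of that ordering. It is not Spring code: PriorityOrdered/Ordered sorting is left out, and the interface and helper names are made up for illustration. A registry processor that adds new bean-defined processors during its registry phase stands in for ConfigurationClassPostProcessor discovering scanned components.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Simplified model (not Spring code) of the call order in
// PostProcessorRegistrationDelegate.invokeBeanFactoryPostProcessors.
// PriorityOrdered/Ordered sorting is omitted: bean-defined processors run in registration order.
class PostProcessorOrderModel {

    // Stands in for BeanFactoryPostProcessor
    interface FactoryProcessor {
        String name();
    }

    // Stands in for BeanDefinitionRegistryPostProcessor; may register more bean-defined processors
    interface RegistryProcessor extends FactoryProcessor {
        void postProcessRegistry(List<FactoryProcessor> beanDefined);
    }

    static List<String> invoke(List<FactoryProcessor> contextAdded, List<FactoryProcessor> beanDefined) {
        List<String> calls = new ArrayList<>();
        List<FactoryProcessor> registryProcessors = new ArrayList<>();
        List<FactoryProcessor> regularProcessors = new ArrayList<>();

        // 1. Processors added via context.addBeanFactoryPostProcessor(...) come first
        for (FactoryProcessor p : contextAdded) {
            if (p instanceof RegistryProcessor) {
                ((RegistryProcessor) p).postProcessRegistry(beanDefined);
                calls.add("registry:" + p.name());
                registryProcessors.add(p);
            } else {
                regularProcessors.add(p);
            }
        }
        // 2. Bean-defined registry processors; the index loop also picks up processors
        //    registered while iterating (e.g. scanned @Component processors)
        for (int i = 0; i < beanDefined.size(); i++) {
            FactoryProcessor p = beanDefined.get(i);
            if (p instanceof RegistryProcessor) {
                ((RegistryProcessor) p).postProcessRegistry(beanDefined);
                calls.add("registry:" + p.name());
                registryProcessors.add(p);
            }
        }
        // 3. postProcessBeanFactory on registry processors, then on context-added regular ones
        registryProcessors.forEach(p -> calls.add("factory:" + p.name()));
        regularProcessors.forEach(p -> calls.add("factory:" + p.name()));
        // 4. Finally, bean-defined processors that are not registry processors
        for (FactoryProcessor p : beanDefined) {
            if (!(p instanceof RegistryProcessor)) {
                calls.add("factory:" + p.name());
            }
        }
        return calls;
    }

    // Registry processor that registers the given processors as bean definitions when invoked
    static RegistryProcessor registry(String name, FactoryProcessor... registers) {
        return new RegistryProcessor() {
            @Override
            public String name() {
                return name;
            }

            @Override
            public void postProcessRegistry(List<FactoryProcessor> beanDefined) {
                beanDefined.addAll(Arrays.asList(registers));
            }
        };
    }

    static FactoryProcessor regular(String name) {
        return () -> name;
    }

    public static void main(String[] args) {
        List<FactoryProcessor> contextAdded = new ArrayList<>();
        contextAdded.add(registry("MyBeanFactoryPostProcessor"));
        List<FactoryProcessor> beanDefined = new ArrayList<>();
        // "ConfigurationClassPostProcessor" discovers a @Component registry processor while scanning
        beanDefined.add(registry("ConfigurationClassPostProcessor", registry("ScannedRegistryProcessor")));
        invoke(contextAdded, beanDefined).forEach(System.out::println);
    }
}
```

Running main prints the registry phase of MyBeanFactoryPostProcessor first, then ConfigurationClassPostProcessor, and the scanned processor only after that. This mirrors why the bean class has to be swapped by a processor registered on the context rather than by a scanned bean.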
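  • The custom ConfigurationClassPostProcessor that processes the configuration classes. It overrides processConfigBeanDefinitions with what appears to be a copy of the Spring Framework 5.3.x implementation (StartupStep and ApplicationStartup were introduced in 5.3). It adds an exclusion filter before loadBeanDefinitions and reads or writes the parent's private fields through reflection. Because of this, it has to be kept in sync with the Spring version in use.
java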
package org.springframework.context.annotation;  
  
import org.apache.commons.logging.Log;  
import org.apache.commons.logging.LogFactory;  
import org.springframework.beans.factory.config.BeanDefinition;  
import org.springframework.beans.factory.config.BeanDefinitionHolder;  
import org.springframework.beans.factory.config.SingletonBeanRegistry;  
import org.springframework.beans.factory.parsing.ProblemReporter;  
import org.springframework.beans.factory.parsing.SourceExtractor;  
import org.springframework.beans.factory.support.BeanDefinitionRegistry;  
import org.springframework.beans.factory.support.BeanNameGenerator;  
import org.springframework.core.env.Environment;  
import org.springframework.core.env.StandardEnvironment;  
import org.springframework.core.io.ResourceLoader;  
import org.springframework.core.metrics.ApplicationStartup;  
import org.springframework.core.metrics.StartupStep;  
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;  
import org.springframework.core.type.classreading.MetadataReaderFactory;  
import org.springframework.lang.Nullable;  
import org.springframework.util.ReflectionUtils;  
  
import java.lang.reflect.Field;  
import java.util.*;  
  
/**
 * ConfigurationClassPostProcessor that can skip selected configuration classes.
 *
 * @author wxl
 */
public class MyConfigurationClassPostProcessor extends ConfigurationClassPostProcessor {
    private final Log logger = LogFactory.getLog(getClass());  
    private static final String IMPORT_REGISTRY_BEAN_NAME =  
            ConfigurationClassPostProcessor.class.getName() + ".importRegistry";  
  
    @Override  
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) {  
        super.postProcessBeanDefinitionRegistry(registry);  
    }  
  
    /**
     * Build and validate a configuration model based on the registry of
     * {@link Configuration} classes.
     */
    @Override
    public void processConfigBeanDefinitions(BeanDefinitionRegistry registry) {  
        List<BeanDefinitionHolder> configCandidates = new ArrayList<>();  
        String[] candidateNames = registry.getBeanDefinitionNames();  
  
        for (String beanName : candidateNames) {  
            BeanDefinition beanDef = registry.getBeanDefinition(beanName);  
            if (beanDef.getAttribute(ConfigurationClassUtils.CONFIGURATION_CLASS_ATTRIBUTE) != null) {  
                if (logger.isDebugEnabled()) {  
                    logger.debug("Bean definition has already been processed as a configuration class: " + beanDef);  
                }  
            } else if (ConfigurationClassUtils.checkConfigurationClassCandidate(beanDef, getMetadataReaderFactory())) {  
                configCandidates.add(new BeanDefinitionHolder(beanDef, beanName));  
            }  
        }  
  
        // Return immediately if no @Configuration classes were found  
        if (configCandidates.isEmpty()) {  
            return;  
        }  
  
        // Sort by previously determined @Order value, if applicable  
        configCandidates.sort((bd1, bd2) -> {  
            int i1 = ConfigurationClassUtils.getOrder(bd1.getBeanDefinition());  
            int i2 = ConfigurationClassUtils.getOrder(bd2.getBeanDefinition());  
            return Integer.compare(i1, i2);  
        });  
  
        // Detect any custom bean name generation strategy supplied through the enclosing application context  
        SingletonBeanRegistry sbr = null;  
        if (registry instanceof SingletonBeanRegistry) {  
            sbr = (SingletonBeanRegistry) registry;  
            if (!isLocalBeanNameGeneratorSet()) {  
                BeanNameGenerator generator = (BeanNameGenerator) sbr.getSingleton(  
                        AnnotationConfigUtils.CONFIGURATION_BEAN_NAME_GENERATOR);  
                if (generator != null) {  
                    setComponentScanBeanNameGenerator(generator);  
                    setImportBeanNameGenerator(generator);  
                }  
            }  
        }  
  
        if (getEnvironment() == null) {  
            setEnvironment(new StandardEnvironment());  
        }  
  
        // Parse each @Configuration class  
        ConfigurationClassParser parser = new ConfigurationClassParser(  
                getMetadataReaderFactory(), getProblemReporter(), getEnvironment(),  
                getResourceLoader(), getComponentScanBeanNameGenerator(), registry);  
  
        Set<BeanDefinitionHolder> candidates = new LinkedHashSet<>(configCandidates);  
        Set<ConfigurationClass> alreadyParsed = new HashSet<>(configCandidates.size());  
        do {  
            StartupStep processConfig = getApplicationStartup().start("spring.context.config-classes.parse");  
            parser.parse(candidates);  
            parser.validate();  
  
            Set<ConfigurationClass> configClasses = new LinkedHashSet<>(parser.getConfigurationClasses());  
            configClasses.removeAll(alreadyParsed);  
  
            // Read the model and create bean definitions based on its content  
            if (getReader() == null) {  
                ConfigurationClassBeanDefinitionReader configurationClassBeanDefinitionReader = new ConfigurationClassBeanDefinitionReader(  
                        registry, getSourceExtractor(), getResourceLoader(), getEnvironment(),  
                        getImportBeanNameGenerator(), parser.getImportRegistry());  
                setReader(configurationClassBeanDefinitionReader);  
            }  
  
            // Filter out excluded configuration classes before their bean definitions are loaded
            configClasses.removeIf(this::isExcludedClass);

            // Register bean definitions for all remaining @Configuration classes
            this.getReader().loadBeanDefinitions(configClasses);  
            alreadyParsed.addAll(configClasses);  
            processConfig.tag("classCount", () -> String.valueOf(configClasses.size())).end();  
  
            candidates.clear();  
            if (registry.getBeanDefinitionCount() > candidateNames.length) {  
                String[] newCandidateNames = registry.getBeanDefinitionNames();  
                Set<String> oldCandidateNames = new HashSet<>(Arrays.asList(candidateNames));  
                Set<String> alreadyParsedClasses = new HashSet<>();  
                for (ConfigurationClass configurationClass : alreadyParsed) {  
                    alreadyParsedClasses.add(configurationClass.getMetadata().getClassName());  
                }  
                for (String candidateName : newCandidateNames) {  
                    if (!oldCandidateNames.contains(candidateName)) {  
                        BeanDefinition bd = registry.getBeanDefinition(candidateName);  
                        if (ConfigurationClassUtils.checkConfigurationClassCandidate(bd, getMetadataReaderFactory()) &&  
                                !alreadyParsedClasses.contains(bd.getBeanClassName())) {  
                            candidates.add(new BeanDefinitionHolder(bd, candidateName));  
                        }  
                    }  
                }  
                candidateNames = newCandidateNames;  
            }  
        }  
        while (!candidates.isEmpty());  
  
        // Register the ImportRegistry as a bean in order to support ImportAware @Configuration classes  
        if (sbr != null && !sbr.containsSingleton(IMPORT_REGISTRY_BEAN_NAME)) {  
            sbr.registerSingleton(IMPORT_REGISTRY_BEAN_NAME, parser.getImportRegistry());  
        }  
  
  
        if (getMetadataReaderFactory() instanceof CachingMetadataReaderFactory) {
            // Clear cache in externally provided MetadataReaderFactory; this is a no-op
            // for a shared cache since it'll be cleared by the ApplicationContext.
            ((CachingMetadataReaderFactory) getMetadataReaderFactory()).clearCache();
        }
    }  
  
    private boolean isExcludedClass(ConfigurationClass configurationClass) {  
        // Exclusion logic, for example:
//        if ("com.ccwxl.practice.cc.TaskBeanConfiguration".equals(configurationClass.getMetadata().getClassName())) {  
//            logger.info("exclude class:" + configurationClass.getMetadata().getClassName());  
//            return true;  
//        }  
        return false;  
    }  
  
    private void setReader(ConfigurationClassBeanDefinitionReader configurationClassBeanDefinitionReader) {
        Field reader = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "reader");  
        if (reader != null) {  
            reader.setAccessible(true);  
            ReflectionUtils.setField(reader, this, configurationClassBeanDefinitionReader);  
        }  
    }  
  
    @Nullable  
    public MetadataReaderFactory getMetadataReaderFactory() {  
        Field metadataReaderFactory = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "metadataReaderFactory");  
        if (metadataReaderFactory != null) {  
            metadataReaderFactory.setAccessible(true);  
            Object field = ReflectionUtils.getField(metadataReaderFactory, this);  
            if (field instanceof MetadataReaderFactory) {  
                return (MetadataReaderFactory) field;  
            }  
        }  
        return null;  
    }  
  
    @Nullable  
    public ConfigurationClassBeanDefinitionReader getReader() {  
        Field reader = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "reader");  
        if (reader != null) {  
            reader.setAccessible(true);  
            Object field = ReflectionUtils.getField(reader, this);  
            if (field instanceof ConfigurationClassBeanDefinitionReader) {  
                return (ConfigurationClassBeanDefinitionReader) field;  
            }  
        }  
        return null;  
    }  
  
    // Additional reflective accessors for private fields of ConfigurationClassPostProcessor
    private boolean isLocalBeanNameGeneratorSet() {
        Field localBeanNameGeneratorSet = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "localBeanNameGeneratorSet");  
        if (localBeanNameGeneratorSet != null) {  
            localBeanNameGeneratorSet.setAccessible(true);  
            Object field = ReflectionUtils.getField(localBeanNameGeneratorSet, this);  
            if (field instanceof Boolean) {  
                return (Boolean) field;  
            }  
        }  
        return false;  
    }  
  
    private void setComponentScanBeanNameGenerator(BeanNameGenerator generator) {
        Field componentScanBeanNameGenerator = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "componentScanBeanNameGenerator");  
        if (componentScanBeanNameGenerator != null) {  
            componentScanBeanNameGenerator.setAccessible(true);  
            ReflectionUtils.setField(componentScanBeanNameGenerator, this, generator);  
        }  
    }  
  
    private void setImportBeanNameGenerator(BeanNameGenerator generator) {
        Field importBeanNameGenerator = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "importBeanNameGenerator");  
        if (importBeanNameGenerator != null) {  
            importBeanNameGenerator.setAccessible(true);  
            ReflectionUtils.setField(importBeanNameGenerator, this, generator);  
        }  
    }  
  
    @Nullable  
    private Environment getEnvironment() {  
        Field environment = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "environment");  
        if (environment != null) {  
            environment.setAccessible(true);  
            Object field = ReflectionUtils.getField(environment, this);  
            if (field instanceof Environment) {  
                return (Environment) field;  
            }  
        }  
        return null;  
    }  
  
    @Override  
    public void setEnvironment(Environment environment) {  
        Field envField = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "environment");  
        if (envField != null) {  
            envField.setAccessible(true);  
            ReflectionUtils.setField(envField, this, environment);  
        }  
    }  
  
    @Nullable  
    private ProblemReporter getProblemReporter() {  
        Field problemReporter = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "problemReporter");  
        if (problemReporter != null) {  
            problemReporter.setAccessible(true);  
            Object field = ReflectionUtils.getField(problemReporter, this);  
            if (field instanceof ProblemReporter) {  
                return (ProblemReporter) field;  
            }  
        }  
        return null;  
    }  
  
    @Nullable  
    private ResourceLoader getResourceLoader() {  
        Field resourceLoader = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "resourceLoader");  
        if (resourceLoader != null) {  
            resourceLoader.setAccessible(true);  
            Object field = ReflectionUtils.getField(resourceLoader, this);  
            if (field instanceof ResourceLoader) {  
                return (ResourceLoader) field;  
            }  
        }  
        return null;  
    }  
  
    @Nullable  
    private BeanNameGenerator getComponentScanBeanNameGenerator() {  
        Field componentScanBeanNameGenerator = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "componentScanBeanNameGenerator");  
        if (componentScanBeanNameGenerator != null) {  
            componentScanBeanNameGenerator.setAccessible(true);  
            Object field = ReflectionUtils.getField(componentScanBeanNameGenerator, this);  
            if (field instanceof BeanNameGenerator) {  
                return (BeanNameGenerator) field;  
            }  
        }  
        return null;  
    }  
  
    @Nullable  
    private SourceExtractor getSourceExtractor() {  
        Field sourceExtractor = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "sourceExtractor");  
        if (sourceExtractor != null) {  
            sourceExtractor.setAccessible(true);  
            Object field = ReflectionUtils.getField(sourceExtractor, this);  
            if (field instanceof SourceExtractor) {  
                return (SourceExtractor) field;  
            }  
        }  
        return null;  
    }  
  
    @Nullable  
    private BeanNameGenerator getImportBeanNameGenerator() {  
        Field importBeanNameGenerator = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "importBeanNameGenerator");  
        if (importBeanNameGenerator != null) {  
            importBeanNameGenerator.setAccessible(true);  
            Object field = ReflectionUtils.getField(importBeanNameGenerator, this);  
            if (field instanceof BeanNameGenerator) {  
                return (BeanNameGenerator) field;  
            }  
        }  
        return null;  
    }  
  
    @Nullable  
    private ApplicationStartup getApplicationStartup() {  
        Field applicationStartup = ReflectionUtils.findField(ConfigurationClassPostProcessor.class, "applicationStartup");  
        if (applicationStartup != null) {  
            applicationStartup.setAccessible(true);  
            Object field = ReflectionUtils.getField(applicationStartup, this);  
            if (field instanceof ApplicationStartup) {  
                return (ApplicationStartup) field;  
            }  
        }  
        return null;  
    }  
}
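Each reflective getter and setter in MyConfigurationClassPostProcessor repeats the same find-field, make-accessible, get/set steps. A possible refactoring is a single generic helper. The sketch below uses plain JDK reflection; Spring's ReflectionUtils could be used in the same way. The class name PrivateFields and the demo classes are hypothetical. Unlike the original getters, which return null when a field is missing, this helper fails fast. A field renamed in a newer Spring version is then noticed at startup instead of surfacing later as a NullPointerException.

```java
import java.lang.reflect.Field;

// Hypothetical helper: one generic accessor instead of one reflective method per private field
final class PrivateFields {
    private PrivateFields() {
    }

    // Reads a field declared on declaringClass; returns null if the value is not of the expected type
    static <T> T read(Object target, Class<?> declaringClass, String fieldName, Class<T> type) {
        try {
            Object value = field(declaringClass, fieldName).get(target);
            return type.isInstance(value) ? type.cast(value) : null;
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(e);
        }
    }

    // Writes a field declared on declaringClass (primitive fields accept boxed values)
    static void write(Object target, Class<?> declaringClass, String fieldName, Object value) {
        try {
            field(declaringClass, fieldName).set(target, value);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(e);
        }
    }

    private static Field field(Class<?> declaringClass, String fieldName) {
        try {
            Field field = declaringClass.getDeclaredField(fieldName);
            field.setAccessible(true);
            return field;
        } catch (NoSuchFieldException e) {
            // Fail fast instead of silently returning null
            throw new IllegalStateException("No field '" + fieldName + "' on " + declaringClass.getName(), e);
        }
    }
}

// Demo types standing in for ConfigurationClassPostProcessor and its subclass
class BaseProcessor {
    private String environment = "default";
    private boolean localBeanNameGeneratorSet;
}

class SubProcessor extends BaseProcessor {
    public static void main(String[] args) {
        SubProcessor p = new SubProcessor();
        PrivateFields.write(p, BaseProcessor.class, "localBeanNameGeneratorSet", true);
        System.out.println(PrivateFields.read(p, BaseProcessor.class, "environment", String.class));
        System.out.println(PrivateFields.read(p, BaseProcessor.class, "localBeanNameGeneratorSet", Boolean.class));
    }
}
```

In the subclass, a call such as PrivateFields.read(this, ConfigurationClassPostProcessor.class, "resourceLoader", ResourceLoader.class) would replace the hand-written getResourceLoader().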
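  • MyConfigurationClassPostProcessor must be placed in the org.springframework.context.annotation package. It uses package-private types such as ConfigurationClassParser, ConfigurationClass and ConfigurationClassBeanDefinitionReader.
  • Put your own logic in isExcludedClass to decide which configuration classes are excluded. An excluded class's @Bean methods are not registered, and for an imported class, neither is the class itself. Classes that it imports in turn are parsed as separate entries and have to be excluded explicitly if needed.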
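isExcludedClass above only contains a commented-out, hard-coded check. One possible way to make it configurable is to read a comma-separated list of class names from the Environment. The plain-Java sketch below shows only the matching part. The property name my.import.exclude and the ".*" package-prefix syntax are hypothetical conventions, not Spring features.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical exclusion rules: exact class names, or package prefixes written as "com.example.legacy.*"
final class ImportExclusionMatcher {
    private final List<String> rules;

    private ImportExclusionMatcher(List<String> rules) {
        this.rules = rules;
    }

    // Parses a comma-separated value such as the (hypothetical) property "my.import.exclude"
    static ImportExclusionMatcher parse(String commaSeparated) {
        List<String> rules = new ArrayList<>();
        if (commaSeparated != null) {
            for (String part : commaSeparated.split(",")) {
                String rule = part.trim();
                if (!rule.isEmpty()) {
                    rules.add(rule);
                }
            }
        }
        return new ImportExclusionMatcher(rules);
    }

    boolean isExcluded(String className) {
        for (String rule : rules) {
            if (rule.endsWith(".*")) {
                // Keep the trailing dot so "com.example.legacy.*" does not match "com.example.legacyx.Foo"
                String packagePrefix = rule.substring(0, rule.length() - 1);
                if (className.startsWith(packagePrefix)) {
                    return true;
                }
            } else if (rule.equals(className)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        ImportExclusionMatcher matcher = parse("com.ccwxl.practice.cc.TaskBeanConfiguration, com.example.legacy.*");
        System.out.println(matcher.isExcluded("com.ccwxl.practice.cc.TaskBeanConfiguration"));
        System.out.println(matcher.isExcluded("com.example.legacy.cache.CacheConfig"));
        System.out.println(matcher.isExcluded("com.example.other.Config"));
    }
}
```

Inside MyConfigurationClassPostProcessor, isExcludedClass could then call ImportExclusionMatcher.parse(getEnvironment().getProperty("my.import.exclude")).isExcluded(configurationClass.getMetadata().getClassName()). Ideally the parsed matcher would be cached rather than rebuilt for every class.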

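Summary

  • With this extension, a @Configuration class imported via @Import can be disabled dynamically, giving fine-grained control over which configuration takes effect.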