java
复制代码
package com.sunxiansheng.sunspring.ioc;
import com.sunxiansheng.sunspring.annotation.Component;
import com.sunxiansheng.sunspring.annotation.ComponentScan;
import com.sunxiansheng.sunspring.annotation.Resource;
import com.sunxiansheng.sunspring.annotation.Scope;
import com.sunxiansheng.sunspring.annotation.myenum.MyScope;
import com.sunxiansheng.sunspring.processor.BeanPostProcessor;
import com.sunxiansheng.sunspring.processor.InitializingBean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
/**
* Description: 自定义Spring容器
* @Author sun
* @Create 2024/8/4 16:35
* @Version 1.0
*/
public class SunSpringApplicationContext {
private static final Logger log = LoggerFactory.getLogger(SunSpringApplicationContext.class);
/**
* bean定义的map
*/
private ConcurrentHashMap<String, BeanDefintion> beanDefintionMap = new ConcurrentHashMap<>();
/**
* 单例池
*/
private ConcurrentHashMap<String, Object> singletonObjects = new ConcurrentHashMap<>();
/**
* bean的后置处理器
*/
private List<BeanPostProcessor> beanPostProcessorList = new ArrayList<>();
// 构造器,接收配置类的class对象
public SunSpringApplicationContext(Class<?> configClass) throws ClassNotFoundException, InstantiationException, IllegalAccessException {
// 完成bean的扫描,将bean的信息记录到beanDefintionMap中
beanDefinitionByScan(configClass);
// 初始化单例池
initSingletonObjects();
}
/**
* 给某个bean对象完成依赖注入
*/
private void populateBeans(Object bean) {
// 扫描beanDefintionMap中的bean信息,对bean对象中的属性进行依赖注入
// 获取Class对象
Class<?> clazz = bean.getClass();
// 获取所有字段
Field[] fields = clazz.getDeclaredFields();
// 判断字段上是否有@Resource注解
for (Field field : fields) {
if (field.isAnnotationPresent(Resource.class)) {
// 获取字段名
String fieldName = field.getName();
// 根据字段名获取bean对象
Object beanObject = null;
// 从beanDefintionMap中获取bean对象
BeanDefintion beanDefintion = beanDefintionMap.get(fieldName);
try {
// 根据bean的定义信息创建bean对象
beanObject = getBean(fieldName);
} catch (Exception e) {
throw new RuntimeException(e);
}
// 设置字段可访问
field.setAccessible(true);
try {
// 依赖注入
field.set(bean, beanObject);
log.info("依赖注入成功:{} => {}.{}", beanObject.getClass(), clazz, fieldName);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
}
/**
* 初始化单例池
*/
private void initSingletonObjects() {
// 将beanDefintionMap中的bean信息创建成bean对象放到单例池中
beanDefintionMap.forEach((beanName, beanDefintion) -> {
try {
// 根据bean的定义信息创建bean对象
Object bean = createBean(beanDefintion);
if (bean != null) {
// 将bean对象放到单例池中
singletonObjects.put(beanName, bean);
}
} catch (Exception e) {
e.printStackTrace();
}
});
// 打印单例池中的bean对象
log.info("根据bean定义信息初始化单例池:{}", singletonObjects);
}
// 返回容器中的对象
public Object getBean(String name) throws Exception {
BeanDefintion beanDefintion = beanDefintionMap.get(name);
if (beanDefintion == null) {
throw new NullPointerException("在bean定义中没有找到bean对象");
}
// 根据单例和多例来获取bean对象
MyScope scope = beanDefintion.getScope();
Object bean = null;
if (scope == MyScope.SINGLETON) {
log.info("getBean单例对象:{}", singletonObjects.get(name));
// 单例就直接从单例池中获取对象
bean = singletonObjects.get(name);
} else {
// 多例就创建一个新的对象
bean = createProtoTypeBean(beanDefintion);
}
// 给bean对象完成依赖注入
populateBeans(bean);
// 记录当前对象,因为后置处理器可能会返回一个新的对象
Object current = bean;
// 初始化方法之前调用后置处理器 postProcessBeforeInitialization
for (BeanPostProcessor beanPostProcessor : beanPostProcessorList) {
Object bean1 = beanPostProcessor.postProcessBeforeInitialization(bean, name);
// 如果beanPostProcessor返回的对象为空,则使用原来的对象
if (bean1 != null) {
current = bean1;
}
}
// 初始化bean
init(current);
// 初始化方法之后调用后置处理器 postProcessAfterInitialization
for (BeanPostProcessor beanPostProcessor : beanPostProcessorList) {
bean = beanPostProcessor.postProcessAfterInitialization(current, name);
// 如果beanPostProcessor返回的对象为空,则使用原来的对象
if (bean == null) {
bean = current;
}
}
log.info("getBean多例对象:{}", bean);
return bean;
}
/**
* 初始化bean
* @param bean
*/
public void init(Object bean) {
if (bean instanceof InitializingBean) {
((InitializingBean) bean).afterPropertiesSet();
}
}
/**
* 根据bean的定义信息创建bean对象(单例bean)
* @param beanDefintion
* @return
* @throws Exception
*/
private Object createBean(BeanDefintion beanDefintion) throws Exception {
// 得到bean的类型
Class<?> clazz = beanDefintion.getClazz();
// 根据bean的作用域创建bean对象,多例就不创建了,单例就创建
if (beanDefintion.getScope() == MyScope.PROTOTYPE) {
return null;
}
Object bean = clazz.getDeclaredConstructor().newInstance();
return bean;
}
/**
* 创建多例bean
* @param beanDefintion
* @return
* @throws InstantiationException
* @throws IllegalAccessException
* @throws InvocationTargetException
* @throws NoSuchMethodException
*/
private static Object createProtoTypeBean(BeanDefintion beanDefintion) throws InstantiationException, IllegalAccessException, InvocationTargetException, NoSuchMethodException {
// 多例就创建一个新的对象
Class<?> clazz = beanDefintion.getClazz();
Object bean = clazz.getDeclaredConstructor().newInstance();
return bean;
}
/**
* 完成bean的扫描,将bean的信息记录到beanDefintionMap中
* @param configClass
* @throws ClassNotFoundException
*/
private void beanDefinitionByScan(Class<?> configClass) {
// 传进来一个配置类的Class对象
// 一、获取要扫描的包
// 1.首先反射获取类的注解信息
ComponentScan componentScan = configClass.getDeclaredAnnotation(ComponentScan.class);
// 2.通过注解来获取要扫描的包的路径
String path = componentScan.packagePath();
log.info("扫描的包路径:{}", path);
// 二、得到要扫描包的.class文件对象,从而得到全路径进行反射
// 1.获取App类加载器
ClassLoader classLoader = SunSpringApplicationContext.class.getClassLoader();
// 2.获取要扫描包的真实路径,默认刚开始在根目录下
path = path.replace(".", "/");
URL resource = classLoader.getResource(path);
// 3.由该路径创建一个文件对象,可使用resource.getFile()将URL类型转化为String类型
File file = new File(resource.getFile());
// 4.遍历该文件夹下的所有.class文件对象
if (file.isDirectory()) {
File[] files = file.listFiles();
for (File f : files) {
// 反射注入容器
// 1.获取所有文件的绝对路径
String absolutePath = f.getAbsolutePath();
// 只处理class文件
if (absolutePath.endsWith(".class")) {
// 2.分割出类名
String className = extractClassName(absolutePath);
// 3.得到全路径
String fullPath = path.replace("/", ".") + "." + className;
// 4.判断是否需要注入容器,查看有没有自定义的注解Component
Class<?> aClass = null;
try {
aClass = classLoader.loadClass(fullPath);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
// 如果该类使用了注解Component则说明是一个spring bean
if (aClass.isAnnotationPresent(Component.class)) {
log.info("扫描到Spring Bean:{}", aClass);
// 将Bean的后置处理器加入到beanPostProcessorList中
// 判断Class对象是否实现了BeanPostProcessor接口
if (BeanPostProcessor.class.isAssignableFrom(aClass)) {
Object o = null;
try {
o = aClass.getDeclaredConstructor().newInstance();
} catch (Exception e) {
log.info("BeanPostProcessor实例化失败:{}", e);
}
if (o instanceof BeanPostProcessor) {
beanPostProcessorList.add((BeanPostProcessor) o);
}
log.info("BeanPostProcessor实例化成功:{}", o);
// 直接跳过,不需要将BeanPostProcessor加入到beanDefintionMap中
continue;
}
// 将bean的信息记录到beanDefintionMap中
BeanDefintion beanDefintion = new BeanDefintion();
// 1.获取Scope注解的value值
if (aClass.isAnnotationPresent(Scope.class)) {
Scope scope = aClass.getDeclaredAnnotation(Scope.class);
MyScope value = scope.value();
// 放到beanDefintion中
beanDefintion.setScope(value);
} else {
// 如果没有指定作用域,则默认为单例
beanDefintion.setScope(MyScope.SINGLETON);
}
beanDefintion.setClazz(aClass);
// 2.获取Component注解的value值
Component component = aClass.getDeclaredAnnotation(Component.class);
String beanName = component.value();
if ("".equals(beanName)) {
// 如果没有指定value属性,则使用类名首字母小写作为bean的id
beanName = className.substring(0, 1).toLowerCase() + className.substring(1);
}
// 3.将bean的id和bean的信息放到beanDefintionMap中
beanDefintionMap.put(beanName, beanDefintion);
} else {
log.info("这不是一个Spring Bean={}", aClass);
}
}
}
}
// 打印beanDefintionMap中的bean信息
log.info("将bean定义信息放到beanDefintionMap:{}", beanDefintionMap);
}
/**
* 分割出类名
* 类似于 com/sunxiansheng/sunspring/compent/MonsterService.class 的类名
* @param filePath
* @return
*/
private String extractClassName(String filePath) {
// 获取最后一个 '/' 的位置
int lastSlashIndex = filePath.lastIndexOf('/');
// 获取最后一个 '.' 的位置
int lastDotIndex = filePath.lastIndexOf('.');
// 提取两者之间的字符串作为类名
return filePath.substring(lastSlashIndex + 1, lastDotIndex);
}
}