讲解如何获取 Spring Boot 的 ClassLoader，以及如何向 Spring Boot 的 ClassLoader 添加 jar 包
获取 Spring Boot ClassLoader
java
/**
 * Builds the class loader that the Spring Boot launcher of the running packaged application uses
 * for the application's dependencies.
 *
 * <p>The launcher class is taken from the {@code Main-Class} entry of the jar that contains
 * {@code STANDALONE_CLASS}; it may be JarLauncher or WarLauncher. The launcher is then instantiated
 * reflectively and asked for its class-path archives and class loader.
 *
 * @return the launcher-created class loader, or {@code null} when not running in standalone mode,
 *     when the jar path / manifest / required launcher methods cannot be found, or when any
 *     reflective step fails (the failure is logged)
 */
public static ClassLoader genClassLoader() {
    // NOTE(review): "standalone" presumably means "packaged as a Spring Boot fat jar" — confirm.
    if (!SystemInfo.INSTANCE.isStandalone()) {
        return null;
    }
    AgentLogger.info("[SF Agent] Use springboot classloader.");
    try {
        // Class.forName throws instead of returning null, so a missing launcher class
        // (i.e. not a Spring Boot packaged app) ends up in the catch block below.
        Class<?> launcherClass = Class.forName(STANDALONE_CLASS);
        String jarfilePath = ClassUtils.getClassFilePath(launcherClass);
        if (null == jarfilePath) {
            return null;
        }
        // Read the manifest's Main-Class, which may be JarLauncher or WarLauncher.
        // NOTE(review): this JarFileArchive is never closed; consider closing it once the
        // manifest has been read if the Spring Boot version in use makes Archive AutoCloseable.
        Class<?> jarFileArchiveClass = Class.forName("org.springframework.boot.loader.archive.JarFileArchive");
        Object jarFileArchiveObj = jarFileArchiveClass.getDeclaredConstructor(File.class).newInstance(new File(jarfilePath));
        Method manifestMethod = jarFileArchiveClass.getDeclaredMethod("getManifest");
        Manifest manifest = (Manifest) manifestMethod.invoke(jarFileArchiveObj);
        String mainClass = (null == manifest) ? null : manifest.getMainAttributes().getValue("Main-Class");
        if (null == mainClass) {
            AgentLogger.error("[SF Agent] No Main-Class found in manifest of " + jarfilePath);
            return null;
        }
        // Reflectively drive the launcher to obtain the class loader that can load the
        // Spring Boot application's dependencies.
        Class<?> launcher = Class.forName(mainClass);
        Map<String, Method> methods = getClassLoadderMethod(launcher);
        Method classLoaderMethod = methods.get("classLoader");
        Method archivesMethod = methods.get("archive");
        if (null == classLoaderMethod || null == archivesMethod) {
            AgentLogger.error("No class loader found.");
            return null;
        }
        // These launcher methods are presumably non-public, hence setAccessible.
        archivesMethod.setAccessible(true);
        classLoaderMethod.setAccessible(true);
        Object launcherObj = launcher.getDeclaredConstructor().newInstance();
        Object classPathArchives = archivesMethod.invoke(launcherObj);
        return (ClassLoader) classLoaderMethod.invoke(launcherObj, classPathArchives);
    } catch (Exception e) {
        // Best effort: callers fall back to "no Spring Boot class loader", but the reason
        // must not be lost (previously this catch block was empty).
        AgentLogger.error("[SF Agent] Failed to create springboot classloader: " + e);
    }
    return null;
}
通过字节码增强向 Spring Boot ClassLoader 添加 jar 包
java
/**
 * Base type for bytecode transformers.
 *
 * <p>An implementation receives a class's class-file bytes and returns the bytes that should be
 * used in their place.
 */
public abstract class AbstractEnhanceModel {

    /**
     * Transforms the class file of one class.
     *
     * @param className name of the class being transformed (format decided by the caller)
     * @param source    the original class-file bytes
     * @return the class-file bytes to use instead of {@code source}
     */
    public abstract byte[] enhance(String className, byte[] source);
}
java
import enhance.AbstractEnhanceModel;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
/**
 * Transformer that rewrites a class with {@link AgentJarEnhance}.
 *
 * <p>NOTE(review): presumably applied to the Spring Boot launcher class — confirm at the
 * registration site. The class holds no per-call state, so one shared instance is used.
 */
public class SpringBootEnhance extends AbstractEnhanceModel {

    /** The single shared instance. */
    public static final SpringBootEnhance INSTANCE = new SpringBootEnhance();

    private SpringBootEnhance() {
    }

    /**
     * Runs {@code source} through {@link AgentJarEnhance}.
     *
     * <p>Debug information in the input is skipped. Max stack/locals are recomputed by ASM; stack
     * map frames are not, so the visitor must emit correct frames for any code it generates.
     *
     * @param className name of the class (not used by this transformer)
     * @param source    original class-file bytes
     * @return the rewritten class-file bytes
     */
    @Override
    public byte[] enhance(String className, byte[] source) {
        ClassReader reader = new ClassReader(source);
        ClassWriter writer = new ClassWriter(ClassWriter.COMPUTE_MAXS);
        reader.accept(new AgentJarEnhance(writer), ClassReader.SKIP_DEBUG);
        return writer.toByteArray();
    }
}
java
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import static org.objectweb.asm.Opcodes.*;
/**
 * ASM class visitor that replaces a Spring Boot launcher's class-path post-processing hooks so
 * that the jars listed by {@code agent.config.SupportConfig.getAll()} are added to the launcher's
 * class-path archives.
 *
 * <p>While visiting, existing methods named {@code isPostProcessingClassPathArchives},
 * {@code postProcessClassPathArchives} and {@code genArchive} are dropped. In {@link #visitEnd()},
 * new implementations of the first two are generated. The generated code references
 * {@code org.springframework.boot.loader.jar.JarFile},
 * {@code org.springframework.boot.loader.archive.JarFileArchive} and
 * {@code agent.config.SupportConfig}, so they must be visible to the transformed class.
 *
 * <p>The owner used for {@code getArchive()} and for the {@code this} local is the internal name
 * of the class actually being visited. It is no longer hard-coded to JarLauncher, so the injected
 * code also verifies when another launcher class is transformed.
 */
public class AgentJarEnhance extends ClassVisitor {

    /** Fallback owner, used only if {@link #visit} never reports the class name. */
    private static final String DEFAULT_OWNER = "org/springframework/boot/loader/JarLauncher";

    /** Internal name of the class being transformed. */
    private String owner = DEFAULT_OWNER;

    public AgentJarEnhance(ClassVisitor classVisitor) {
        super(ASM4, classVisitor);
    }

    /** Records the internal name of the visited class, then forwards to the next visitor. */
    @Override
    public void visit(int version, int access, String name, String signature,
                      String superName, String[] interfaces) {
        if (name != null) {
            owner = name;
        }
        super.visit(version, access, name, signature, superName, interfaces);
    }

    /**
     * Drops the hooks that {@link #visitEnd()} regenerates and forwards every other method.
     *
     * <p>{@code genArchive} is dropped without being regenerated. NOTE(review): presumably it was
     * injected by an older version of this agent; confirm that nothing still calls it.
     */
    @Override
    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        if ("isPostProcessingClassPathArchives".equals(name)
                || "postProcessClassPathArchives".equals(name)
                || "genArchive".equals(name)) {
            return null;
        }
        return cv.visitMethod(access, name, descriptor, signature, exceptions);
    }

    /** Appends the generated replacement methods before the class is finished. */
    @Override
    public void visitEnd() {
        isPostProcessingClassPathArchives();
        postProcessClassPathArchives();
        cv.visitEnd();
    }

    /** Type descriptor of the class being transformed, used for the {@code this} local variable. */
    private String ownerDescriptor() {
        return "L" + owner + ";";
    }

    /**
     * Generates {@code public boolean isPostProcessingClassPathArchives() { return true; }} so that
     * the launcher always invokes {@code postProcessClassPathArchives}.
     */
    private void isPostProcessingClassPathArchives() {
        MethodVisitor methodVisitor = cv.visitMethod(ACC_PUBLIC, "isPostProcessingClassPathArchives", "()Z", null, null);
        methodVisitor.visitCode();
        Label label0 = new Label();
        methodVisitor.visitLabel(label0);
        // Line numbers are presumably those of the Java source this bytecode was derived from.
        methodVisitor.visitLineNumber(14, label0);
        methodVisitor.visitInsn(ICONST_1);
        methodVisitor.visitInsn(IRETURN);
        Label label1 = new Label();
        methodVisitor.visitLabel(label1);
        methodVisitor.visitLocalVariable("this", ownerDescriptor(), null, label0, label1, 0);
        methodVisitor.visitMaxs(1, 1);
        methodVisitor.visitEnd();
    }

    /**
     * Generates the equivalent of:
     *
     * <pre>
     * protected void postProcessClassPathArchives(List&lt;Archive&gt; archives) throws Exception {
     *     archives.add(0, getArchive());
     *     List&lt;String&gt; paths = SupportConfig.getAll();
     *     for (String path : paths) {
     *         JarFile jarFile = new JarFile(new File(path));
     *         archives.add(new JarFileArchive(jarFile));
     *     }
     * }
     * </pre>
     *
     * <p>Stack map frames are emitted by hand because the writer only computes max stack/locals.
     */
    private void postProcessClassPathArchives() {
        MethodVisitor methodVisitor = cv.visitMethod(ACC_PROTECTED, "postProcessClassPathArchives", "(Ljava/util/List;)V", "(Ljava/util/List<Lorg/springframework/boot/loader/archive/Archive;>;)V", new String[]{"java/lang/Exception"});
        methodVisitor.visitCode();
        // archives.add(0, getArchive());
        Label label0 = new Label();
        methodVisitor.visitLabel(label0);
        methodVisitor.visitLineNumber(42, label0);
        methodVisitor.visitVarInsn(ALOAD, 1);
        methodVisitor.visitInsn(ICONST_0);
        methodVisitor.visitVarInsn(ALOAD, 0);
        methodVisitor.visitMethodInsn(INVOKEVIRTUAL, owner, "getArchive", "()Lorg/springframework/boot/loader/archive/Archive;", false);
        methodVisitor.visitMethodInsn(INVOKEINTERFACE, "java/util/List", "add", "(ILjava/lang/Object;)V", true);
        // List<String> paths = SupportConfig.getAll();  -- the extra jars to add to the class loader
        Label label1 = new Label();
        methodVisitor.visitLabel(label1);
        methodVisitor.visitLineNumber(43, label1);
        methodVisitor.visitMethodInsn(INVOKESTATIC, "agent/config/SupportConfig", "getAll", "()Ljava/util/List;", false);
        methodVisitor.visitVarInsn(ASTORE, 2);
        // Iterator it = paths.iterator();
        Label label2 = new Label();
        methodVisitor.visitLabel(label2);
        methodVisitor.visitLineNumber(44, label2);
        methodVisitor.visitVarInsn(ALOAD, 2);
        methodVisitor.visitMethodInsn(INVOKEINTERFACE, "java/util/List", "iterator", "()Ljava/util/Iterator;", true);
        methodVisitor.visitVarInsn(ASTORE, 3);
        // Loop head: locals are [this, archives, paths, it].
        Label label3 = new Label();
        methodVisitor.visitLabel(label3);
        methodVisitor.visitFrame(Opcodes.F_APPEND, 2, new Object[]{"java/util/List", "java/util/Iterator"}, 0, null);
        methodVisitor.visitVarInsn(ALOAD, 3);
        methodVisitor.visitMethodInsn(INVOKEINTERFACE, "java/util/Iterator", "hasNext", "()Z", true);
        Label label4 = new Label();
        methodVisitor.visitJumpInsn(IFEQ, label4);
        // String path = (String) it.next();
        methodVisitor.visitVarInsn(ALOAD, 3);
        methodVisitor.visitMethodInsn(INVOKEINTERFACE, "java/util/Iterator", "next", "()Ljava/lang/Object;", true);
        methodVisitor.visitTypeInsn(CHECKCAST, "java/lang/String");
        methodVisitor.visitVarInsn(ASTORE, 4);
        // JarFile jarFile = new JarFile(new File(path));
        Label label5 = new Label();
        methodVisitor.visitLabel(label5);
        methodVisitor.visitLineNumber(45, label5);
        methodVisitor.visitTypeInsn(NEW, "org/springframework/boot/loader/jar/JarFile");
        methodVisitor.visitInsn(DUP);
        methodVisitor.visitTypeInsn(NEW, "java/io/File");
        methodVisitor.visitInsn(DUP);
        methodVisitor.visitVarInsn(ALOAD, 4);
        methodVisitor.visitMethodInsn(INVOKESPECIAL, "java/io/File", "<init>", "(Ljava/lang/String;)V", false);
        methodVisitor.visitMethodInsn(INVOKESPECIAL, "org/springframework/boot/loader/jar/JarFile", "<init>", "(Ljava/io/File;)V", false);
        methodVisitor.visitVarInsn(ASTORE, 5);
        // archives.add(new JarFileArchive(jarFile));
        Label label6 = new Label();
        methodVisitor.visitLabel(label6);
        methodVisitor.visitLineNumber(46, label6);
        methodVisitor.visitVarInsn(ALOAD, 1);
        methodVisitor.visitTypeInsn(NEW, "org/springframework/boot/loader/archive/JarFileArchive");
        methodVisitor.visitInsn(DUP);
        methodVisitor.visitVarInsn(ALOAD, 5);
        methodVisitor.visitMethodInsn(INVOKESPECIAL, "org/springframework/boot/loader/archive/JarFileArchive", "<init>", "(Lorg/springframework/boot/loader/jar/JarFile;)V", false);
        methodVisitor.visitMethodInsn(INVOKEINTERFACE, "java/util/List", "add", "(Ljava/lang/Object;)Z", true);
        methodVisitor.visitInsn(POP);
        Label label7 = new Label();
        methodVisitor.visitLabel(label7);
        methodVisitor.visitLineNumber(47, label7);
        methodVisitor.visitJumpInsn(GOTO, label3);
        // Loop exit: the iterator local is dropped from the frame, leaving [this, archives, paths].
        methodVisitor.visitLabel(label4);
        methodVisitor.visitLineNumber(48, label4);
        methodVisitor.visitFrame(Opcodes.F_CHOP, 1, null, 0, null);
        methodVisitor.visitInsn(RETURN);
        Label label8 = new Label();
        methodVisitor.visitLabel(label8);
        methodVisitor.visitLocalVariable("jarFile", "Lorg/springframework/boot/loader/jar/JarFile;", null, label6, label7, 5);
        methodVisitor.visitLocalVariable("path", "Ljava/lang/String;", null, label5, label7, 4);
        methodVisitor.visitLocalVariable("this", ownerDescriptor(), null, label0, label8, 0);
        methodVisitor.visitLocalVariable("archives", "Ljava/util/List;", "Ljava/util/List<Lorg/springframework/boot/loader/archive/Archive;>;", label0, label8, 1);
        methodVisitor.visitLocalVariable("paths", "Ljava/util/List;", "Ljava/util/List<Ljava/lang/String;>;", label2, label8, 2);
        methodVisitor.visitMaxs(5, 6);
        methodVisitor.visitEnd();
    }
}
java
import java.util.ArrayList;
import java.util.List;
/**
 * Registry of extra jar paths, stored as a ";"-separated list in the JVM system property
 * {@code Governance.SupportConfig}.
 *
 * <p>Using a system property lets code loaded by a different class loader read the list. The
 * bytecode generated for the Spring Boot launcher calls {@link #getAll()} and opens every returned
 * path as a jar, so blank entries are never stored or returned.
 */
public class SupportConfig {
    private static final String KEY = "Governance.SupportConfig";
    private static final String SPLIT = ";";

    /**
     * Appends a jar path to the registry.
     *
     * <p>A value containing ";" is read back by {@link #getAll()} as several paths.
     *
     * @param support the path to add; {@code null} or blank values are ignored. Previously,
     *     {@code null} caused an NPE or stored the literal string "null", and a blank value
     *     produced an empty entry.
     */
    public static void add(String support) {
        if (support == null || support.trim().isEmpty()) {
            return;
        }
        String value = System.getProperty(KEY);
        if (null == value || value.isEmpty()) {
            value = support;
        } else {
            value = value + SPLIT + support;
        }
        System.setProperty(KEY, value);
    }

    /**
     * Returns every registered jar path, in insertion order.
     *
     * @return a new mutable list; it is empty when nothing is registered. Blank entries (for
     *     example from a hand-set property such as {@code ";a.jar;;b.jar"}) are skipped, because
     *     opening {@code new File("")} as a jar would fail.
     */
    public static List<String> getAll() {
        List<String> list = new ArrayList<>();
        String value = System.getProperty(KEY);
        if (null == value || value.isEmpty()) {
            return list;
        }
        for (String row : value.split(SPLIT)) {
            if (!row.trim().isEmpty()) {
                list.add(row);
            }
        }
        return list;
    }
}