手写mybatis

手写mybatis

https://gitee.com/laomaodu/handwritten-mybatis

内容回顾

sqlsession线程不安全
复制代码
  public static SqlSession getSqlSession(SqlSessionFactory sessionFactory, ExecutorType executorType,
      PersistenceExceptionTranslator exceptionTranslator) {
​
    notNull(sessionFactory, NO_SQL_SESSION_FACTORY_SPECIFIED);
    notNull(executorType, NO_EXECUTOR_TYPE_SPECIFIED);
​
    /// public abstract class TransactionSynchronizationManager {
    ///     private static final ThreadLocal<Map<Object, Object>> resources = new NamedThreadLocal("Transactional resources");
    ///     private static final ThreadLocal<Set<TransactionSynchronization>> synchronizations = new NamedThreadLocal("Transaction synchronizations");
    ///     private static final ThreadLocal<String> currentTransactionName = new NamedThreadLocal("Current transaction name");
    ///     private static final ThreadLocal<Boolean> currentTransactionReadOnly = new NamedThreadLocal("Current transaction read-only status");
    ///     private static final ThreadLocal<Integer> currentTransactionIsolationLevel = new NamedThreadLocal("Current transaction isolation level");
    ///     private static final ThreadLocal<Boolean> actualTransactionActive = new NamedThreadLocal("Actual transaction active");
    ///
    ///     public TransactionSynchronizationManager() {
    ///     }
    //从 Spring 的事务上下文里,取出当前线程绑定的 SqlSession
    SqlSessionHolder holder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sessionFactory);
​
    SqlSession session = sessionHolder(executorType, holder);
    if (session != null) {
      return session;
    }
​
    LOGGER.debug(() -> "Creating a new SqlSession");
    session = sessionFactory.openSession(executorType);
​
    registerSessionHolder(sessionFactory, executorType, exceptionTranslator, session);
​
    return session;
  }
​

线程私有

复制代码
 session = sessionFactory.openSession(executorType);
    //把 SqlSession 放进 SqlSessionHolder
    //绑定到 Spring 的事务资源管理器
    //注册事务同步回调(事务结束时自动清理)
    registerSessionHolder(sessionFactory, executorType, exceptionTranslator, session);
复制代码
​
  /**
   * 将新创建的 SqlSession 注册到 Spring 的事务同步管理器中
   *
   * 作用:
   * 1)让 SqlSession 绑定到当前线程事务上下文(ThreadLocal)
   * 2)保证事务内复用同一个 SqlSession
   * 3)事务提交/回滚时自动关闭 SqlSession
   */
  private static void registerSessionHolder(SqlSessionFactory sessionFactory, ExecutorType executorType,
      PersistenceExceptionTranslator exceptionTranslator, SqlSession session) {
    SqlSessionHolder holder;
    if (TransactionSynchronizationManager.isSynchronizationActive()) {
      // 获取 MyBatis 的运行环境对象 Environment
      // Environment 内部包含:
      //         * - DataSource(数据源)
      //         * - TransactionFactory(事务工厂)
      Environment environment = sessionFactory.getConfiguration().getEnvironment();
      //事务由spring来实现
      if (environment.getTransactionFactory() instanceof SpringManagedTransactionFactory) {
        LOGGER.debug(() -> "Registering transaction synchronization for SqlSession [" + session + "]");
      /// 创建 SqlSessionHolder,封装 SqlSession
        holder = new SqlSessionHolder(session, executorType, exceptionTranslator);
        // 核心操作:绑定资源到 Spring 的事务管理器中
        //当前线程的事务上下文中存入 SqlSessionHolder
        // 后续 Mapper 调用都会复用这个 SqlSession
        TransactionSynchronizationManager.bindResource(sessionFactory, holder);
        //册事务同步回调 SqlSessionSynchronization
        TransactionSynchronizationManager
            .registerSynchronization(new SqlSessionSynchronization(holder, sessionFactory));
        holder.setSynchronizedWithTransaction(true);
        holder.requested();
      } else {
        if (TransactionSynchronizationManager.getResource(environment.getDataSource()) == null) {
          LOGGER.debug(() -> "SqlSession [" + session
              + "] was not registered for synchronization because DataSource is not transactional");
        } else {
          throw new TransientDataAccessResourceException(
              "SqlSessionFactory must be using a SpringManagedTransactionFactory in order to use Spring transaction synchronization");
        }
      }
    } else {
      LOGGER.debug(() -> "SqlSession [" + session
          + "] was not registered for synchronization because synchronization is not active");
    }
​
  }

线程缓存使用

一个关掉一个自减

手写mybatis v1.0

加深理解

jdbc单纯操作

1.重复代码

2.耦合度高

3.返回值无处理

借鉴

定义sqlSession

复制代码
​
public class GPSqlSession {
    private GPConfiguration configuration;
​
    private GPExecutor executor;
​
    public GPSqlSession(GPConfiguration configuration, GPExecutor executor){
        this.configuration = configuration;
        this.executor = executor;
    }
    /// 获取数据的地方
    public <T> T selectOne(String statementId, Object paramater){
        // 根据statementId拿到SQL
        String sql = GPConfiguration.sqlMappings.getString(statementId);
        if(null != sql && !"".equals(sql)){
            return executor.query(sql, paramater );
        }
        return null;
    }
      //返回代理的proxy->invoke计算出sql com.gupaoedu.mybatis.v1.mapper.BlogMapper.selectBlogById=select * from blog where bid = %d
    public <T> T getMapper(Class clazz){
        return configuration.getMapper(clazz, this);
    }
}

定义Executor

复制代码
package com.gupaoedu.mybatis.v1;
​
import com.gupaoedu.mybatis.v1.mapper.Blog;
​
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
​
public class GPExecutor {
    public <T> T query(String sql, Object paramater) {
        Connection conn = null;
        Statement stmt = null;
        Blog blog = new Blog();
​
        try {
            // 注册 JDBC 驱动
            Class.forName("com.mysql.jdbc.Driver");
​
            // 打开连接
            conn = DriverManager.getConnection("jdbc:mysql://localhost:3306/mybatis", "root", "123456");
​
            // 执行查询
            stmt = conn.createStatement();
            ResultSet rs = stmt.executeQuery(String.format(sql, paramater));
​
            // 获取结果集
            while (rs.next()) {
                Integer bid = rs.getInt("bid");
                String name = rs.getString("name");
                Integer authorId = rs.getInt("author_id");
                blog.setAuthorId(authorId);
                blog.setBid(bid);
                blog.setName(name);
            }
            System.out.println(blog);
​
            rs.close();
            stmt.close();
            conn.close();
        } catch (SQLException se) {
            se.printStackTrace();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                if (stmt != null) stmt.close();
            } catch (SQLException se2) {
            }
            try {
                if (conn != null) conn.close();
            } catch (SQLException se) {
                se.printStackTrace();
            }
        }
        return (T)blog;
    }
}
​

mapper

属性文件解析

复制代码
com.gupaoedu.mybatis.v1.mapper.BlogMapper.selectBlogById=select * from blog where bid = %d

config保存所有信息

解析sql资源文件

sqlsession

1.提供单条查询

executor

需要sql语句再config里面

复制代码
​
public class TestMain {
    public static void main(String[] args) {
        GPConfiguration configuration = new GPConfiguration();
        GPExecutor executor = new GPExecutor();
        GPSqlSession sqlSession = new GPSqlSession(configuration, executor);
        Blog blog = sqlSession.selectOne("com.gupaoedu.mybatis.v1.mapper.BlogMapper.selectBlogById", 1);
        System.out.println(blog);
    }
}

executor执行

先跑通

最基础

代理方式调用

复制代码
        V1Mapper mapper = sqlSession.getMapper(V1Mapper.class);
        String sql = mapper.selectBlogById(1);

接下来

复制代码
    public <T> T getMapper(Class clazz){
        return configuration.getMapper(clazz, this);
    }
复制代码
public class GPConfiguration {
    public static final ResourceBundle sqlMappings;
​
    static{
        sqlMappings = ResourceBundle.getBundle("v1sql");
    }
​
    public <T> T getMapper(Class clazz, GPSqlSession sqlSession) {
        return (T) Proxy.newProxyInstance(this.getClass().getClassLoader(),
                new Class[]{clazz},
                new GPMapperProxy(sqlSession));
    }
}
​
复制代码
public class GPMapperProxy implements InvocationHandler {
    private GPSqlSession sqlSession;
​
    public GPMapperProxy(GPSqlSession sqlSession){
        this.sqlSession = sqlSession;
    }
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        String mapperInterface = method.getDeclaringClass().getName();
        String methodName = method.getName();
        String statementId = mapperInterface + "." + methodName;
        return sqlSession.selectOne(statementId, args[0]);
    }
}
​

迁移

继续改造

手写2.0

statem细化

参数处理

结果集处理

数据库硬编码

复制代码
    public static void main(String[] args) {
        SqlSessionFactory factory = new SqlSessionFactory();
        DefaultSqlSession sqlSession = factory.build().openSqlSession();
        // 获取MapperProxy代理
        BlogMapper mapper = sqlSession.getMapper(BlogMapper.class);
        Blog blog = mapper.selectBlogById(1);
​
        System.out.println("第一次查询: " + blog);
        System.out.println();
        blog = mapper.selectBlogById(1);
        System.out.println("第二次查询: " + blog);
    }
复制代码
StatementHandler

解决硬编码过多

复制代码
 /**
     * 获取连接
     * @return
     * @throws SQLException
     */
    private Connection getConnection() {
        String driver = Configuration.properties.getString("jdbc.driver");
        String url =  Configuration.properties.getString("jdbc.url");
        String username = Configuration.properties.getString("jdbc.username");
        String password = Configuration.properties.getString("jdbc.password");
        Connection conn = null;
        try {
            Class.forName(driver);
            conn = DriverManager.getConnection(url, username, password);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (SQLException e) {
            e.printStackTrace();
        }
        return conn;
    }

无连接器

职责细分

缓存和默认

复制代码
    @Override
    public <T> T query(String statement, Object[] parameter, Class pojo)  {
        // 计算CacheKey
        CacheKey cacheKey = new CacheKey();
        cacheKey.update(statement);
        cacheKey.update(joinStr(parameter));
        // 是否拿到缓存
        if (cache.containsKey(cacheKey.getCode())) {
            // 命中缓存
            System.out.println("【命中缓存】");
            return (T)cache.get(cacheKey.getCode());
        }else{
            // 没有的话调用被装饰的SimpleExecutor从数据库查询
            Object obj = delegate.query(statement, parameter, pojo);
            cache.put(cacheKey.getCode(), obj);
            return (T)obj;
        }
    }
复制代码
    @Override
    public <T> T query(String statement, Object[] parameter, Class pojo) {
        StatementHandler statementHandler = new StatementHandler();
        return statementHandler.query(statement, parameter, pojo);
    }

Executor 负责"调度流程",Handler 负责"具体干活",这是典型的职责拆分设计。

参数处理paramteHadnle

结果集处理

复制代码
private ResultSetHandler resultSetHandler = new ResultSetHandler();
​
    public <T> T query(String statement, Object[] parameter, Class pojo){
        // JDBC数据库连接对象
        Connection conn = null;
        // 预编译SQL执行对象
        PreparedStatement preparedStatement = null;
        // 查询结果临时存储
        Object result = null;
​
        try {
            conn = getConnection();
​
            preparedStatement = conn.prepareStatement(statement);
            // 参数处理器封装
​
            ParameterHandler parameterHandler = new ParameterHandler(preparedStatement);
            //将参数填充到SQL占位符 ? 中
            //            //Java 默认行为:
​
            //            //对象作为参数传递时,传的是引用值(指针)
            parameterHandler.setParameters(parameter);
            preparedStatement.execute();
            try {
                //结果处理
                result = resultSetHandler.handle(preparedStatement.getResultSet(), pojo);
            } catch (Exception e) {
                e.printStackTrace();
            }
            return (T)result;
        } catch (Exception e){
            e.printStackTrace();
        } finally {
            if (conn != null) {
                try {
                    conn.close();
                } catch (SQLException e) {
                    e.printStackTrace();
                }
                conn = null;
            }
        }
        // 只在try里面return会报错
        return null;
    }

config增强

复制代码
public class Configuration {
    // SQL映射关系配置,使用注解时不用重复配置
    public static final ResourceBundle sqlMappings;
    // 全局配置
    public static final ResourceBundle properties;
    // 维护接口与工厂类关系
    public static final MapperRegistry MAPPER_REGISTRY = new MapperRegistry();
    // 维护接口方法与SQL关系
    public static final Map<String, String> mappedStatements = new HashMap<>();
​
    // 插件
    private InterceptorChain interceptorChain = new InterceptorChain();
    // 所有Mapper接口
    private List<Class<?>> mapperList = new ArrayList<>();
    // 类所有文件
    private List<String> classPaths = new ArrayList<>();
​
    static{
        sqlMappings = ResourceBundle.getBundle("v2sql");
        properties = ResourceBundle.getBundle("mybatis");
    }

注解实现

缓存实现

复制代码
   public DefaultSqlSession(Configuration configuration) {
        this.configuration = configuration;
        // 根据全局配置决定是否使用缓存装饰
        this.executor = configuration.newExecutor();
    }
复制代码
   /**
     * 创建执行器,当开启缓存时使用缓存装饰
     * 当配置插件时,使用插件代理
     * @return
     */
    public Executor newExecutor() {
        Executor executor = null;
        if (properties.getString("cache.enabled").equals("true")) {
            executor = new CachingExecutor(new SimpleExecutor());
        }else{
            executor = new SimpleExecutor();
        }
​
        // 目前只拦截了Executor,所有的插件都对Executor进行代理,没有对拦截类和方法签名进行判断
        if (interceptorChain.hasPlugin()) {
            return (Executor)interceptorChain.pluginAll(executor);
        }
        return executor;
    }
​

插件

MyBatis V2 的插件系统采用 *责任链模式 + JDK 动态代理* 实现,允许用户在 SQL 执行的关键节点插入自定义逻辑。

复制代码
┌─────────────────────────────────────────────────────────────┐
│                      插件核心组件                              │
├─────────────────────────────────────────────────────────────┤
│  @Intercepts      →  注解标记,指定拦截的方法名                  │
│  Interceptor      →  拦截器接口,定义拦截逻辑                   │
│  InterceptorChain →  拦截器链,管理多个插件的层层代理            │
│  Plugin           →  JDK动态代理类,执行实际的拦截判断            │
│  Invocation       →  方法调用封装,包装目标方法信息               │
└─────────────────────────────────────────────────────────────┘
复制代码
/**
 * 拦截器接口,所有自定义拦截器必须实现此接口
 * @Author: qingshan
 */
public interface Interceptor {
    /**
     * 插件的核心逻辑实现
     * @param invocation
     * @return
     * @throws Throwable
     */
    Object intercept(Invocation invocation) throws Throwable;
​
    /**
     * 对被拦截对象进行代理
     * @param target
     * @return
     */
    Object plugin(Object target);
}
​
  • intercept(): 执行拦截逻辑,通过 Invocation.proceed() 调用原方法

  • plugin(): 生成代理对象,通常调用 Plugin.wrap(target, this)

复制代码
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Intercepts {
    String value();  // 指定要拦截的方法名
}
复制代码
@Intercepts("query")  // 拦截名为 "query" 的方法
public class MyPlugin implements Interceptor {
    // ...
}

Plugin 代理类

复制代码
public class Plugin implements InvocationHandler {
    private Object target;      // 被代理对象
    private Interceptor interceptor;  // 拦截器
    
    // 创建代理对象
    public static Object wrap(Object obj, Interceptor interceptor) {
        Class clazz = obj.getClass();
        return Proxy.newProxyInstance(
            clazz.getClassLoader(), 
            clazz.getInterfaces(), 
            new Plugin(obj, interceptor)
        );
    }
    
    // 方法调用拦截
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        // 检查是否有@Intercepts注解
        if (interceptor.getClass().isAnnotationPresent(Intercepts.class)) {
            // 判断方法名是否匹配
            if (method.getName().equals(
                interceptor.getClass().getAnnotation(Intercepts.class).value())) {
                // 进入拦截逻辑
                return interceptor.intercept(new Invocation(target, method, args));
            }
        }
        // 非拦截方法,执行原逻辑
        return method.invoke(target, args);
    }
}

Invocation 封装类

复制代码
public class Invocation {
    private Object target;    // 目标对象
    private Method method;    // 目标方法
    private Object[] args;    // 方法参数
    
    public Object proceed() throws InvocationTargetException, IllegalAccessException {
        return method.invoke(target, args);  // 调用原方法
    }
    
    // Getter 方法...
}

作用:封装被拦截方法的完整信息,让拦截器可以获取参数并控制是否执行原方法。

InterceptorChain 拦截器链

复制代码
public class InterceptorChain {
    private final List<Interceptor> interceptors = new ArrayList<>();
    
    // 添加拦截器
    public void addInterceptor(Interceptor interceptor) {
        interceptors.add(interceptor);
    }
    
    // 层层代理:对每个拦截器都进行a一次代理包装
    public Object pluginAll(Object target) {
        for (Interceptor interceptor : interceptors) {
            target = interceptor.plugin(target);  // 代理包裹代理
        }
        return target;
    }
}
复制代码
原始 Executor
    ↓ proxy
Plugin3(拦截器3)  ← 最外层代理
    ↓ proxy  
Plugin2(拦截器2)
    ↓ proxy
Plugin1(拦截器1)  ← 最内层代理
    ↓ invoke
SimpleExecutor    ← 原始对象
复制代码
plugin.path=com.gupaoedu.mybatis.v2.interceptor.MyPlugin

初始化

文件位置 : com.gupaoedu.mybatis.v2.session.Configuration

复制代码
    String pluginPathValue = properties.getString("plugin.path");
        String[] pluginPaths = pluginPathValue.split(",");
        for (String plugin : pluginPaths) {
            // 反射实例化拦截器
            Interceptor interceptor = (Interceptor) Class.forName(plugin).newInstance();
            interceptorChain.addInterceptor(interceptor);
        }
复制代码
// 创建执行器时应用插件代理
    public Executor newExecutor() {
        Executor executor = new SimpleExecutor();
        
        // 如果有插件,对 Executor 进行层层代理
        if (interceptorChain.hasPlugin()) {
            return (Executor) interceptorChain.pluginAll(executor);
        }
        return executor;
    }
复制代码
┌────────────────┐
│  加载配置文件   │
└───────┬────────┘
        ↓
┌────────────────┐
│ 解析plugin.path │
└───────┬────────┘
        ↓
┌────────────────┐     ┌──────────────────┐
│ 反射创建拦截器  │────→│ InterceptorChain  │
└───────┬────────┘     │   .addInterceptor()│
        ↓              └──────────────────┘
┌────────────────┐
│  newExecutor() │
└───────┬────────┘
        ↓
┌────────────────┐     ┌──────────────────┐
│ 创建SimpleExecutor│──→│ interceptorChain │
└───────┬────────┘     │   .pluginAll()   │
        ↓              └────────┬─────────┘
┌────────────────┐              ↓
│ 返回代理后的Executor│←──层层包装Plugin代理
└────────────────┘

流程

com.gupaoedu.mybatis.v2.interceptor.MyPlugin`

复制代码
package com.gupaoedu.mybatis.v2.interceptor;
​
import com.gupaoedu.mybatis.v2.annotation.Intercepts;
import com.gupaoedu.mybatis.v2.plugin.Interceptor;
import com.gupaoedu.mybatis.v2.plugin.Invocation;
import com.gupaoedu.mybatis.v2.plugin.Plugin;
​
import java.util.Arrays;
​
/**
 * 自定义插件 - 拦截 query 方法,打印 SQL 和参数
 */
@Intercepts("query")  // 指定拦截 Executor.query() 方法
public class MyPlugin implements Interceptor {
    
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // 获取方法参数
        String statement = (String) invocation.getArgs()[0];
        Object[] parameter = (Object[]) invocation.getArgs()[1];
        Class pojo = (Class) invocation.getArgs()[2];
        
        // 自定义逻辑:打印 SQL 信息
        System.out.println("进入自定义插件:MyPlugin");
        System.out.println("SQL:[" + statement + "]");
        System.out.println("Parameters:" + Arrays.toString(parameter));
        
        // 继续执行原方法(可在此处做权限校验、修改参数等)
        return invocation.proceed();
    }
​
    @Override
    public Object plugin(Object target) {
        // 使用 Plugin 工具类创建代理
        return Plugin.wrap(target, this);
    }
}
相关推荐
两点王爷5 小时前
Java基础面试题——【Java语言特性】
java·开发语言
海山数据库5 小时前
移动云大云海山数据库(He3DB)postgresql_anonymizer插件原理介绍与安装
数据库·he3db·大云海山数据库·移动云数据库
choke2335 小时前
[特殊字符] Python 文件与路径操作
java·前端·javascript
云飞云共享云桌面5 小时前
高性能图形工作站的资源如何共享给10个SolidWorks研发设计用
linux·运维·服务器·前端·网络·数据库·人工智能
choke2335 小时前
Python 基础语法精讲:数据类型、运算符与输入输出
java·linux·服务器
2501_927993535 小时前
SQL Server 2022安装详细教程(图文详解,非常详细)
数据库·sqlserver
星火s漫天5 小时前
第一篇: 使用Docker部署flask项目(Flask + DB 容器化)
数据库·docker·flask
岁岁种桃花儿5 小时前
CentOS7 彻底卸载所有JDK/JRE + 重新安装JDK8(实操完整版,解决kafka/jps报错)
java·开发语言·kafka
xcLeigh5 小时前
Python 项目实战:用 Flask 实现 MySQL 数据库增删改查 API
数据库·python·mysql·flask·教程·python3