基于拦截器处理mybatis模糊查询中的‘%’、‘_’、‘\’特殊字符问题

问题背景

问题来源一个二开项目的bug,测试反馈如下:

【所有查询模块】关键字里面输入'%',或者'_'查询出来的是所有数据

问题定位

当bug被提出来后,立刻去查看了相应的sql,在xml代码中实现模糊查询的代码都是如下路数

bash 复制代码
<if test="param.abc!=null and param.abc!=''">
    and t.abc like concat('%',#{param.abc},'%')
</if>

所以当前端传来的参数为%或者_时,最终的like部分sql如下: and t.abc like concat('%','%','%') 或者 and t.abc like concat('%','_','%'),而%_恰好是mysql内定的通配符,所以才会出现bug描述上的问题。

解决思路

既然问题是因为参数值撞脸通配符导致的,只要把参数值中的通配符%_转义一下就可以了,由于java中转义字符是\,所以\本身也要纳入转义的范围

项目技术栈用的是springboot+mybatis,那么实现转义有多种实现方式

  • 接口层处理

大致思路是用一个统一的aop,识别出所有查询接口的参数,如果参数为字符串类型,且包含通配符的,就将其转义后再往业务层传值。

  • 业务层处理

业务层各自处理各自的,凡有查询之处,凡是字符串类型的参数,统统进行校验并转义。

  • mapper层处理

借助mybatis的拦截器,在sql执行前,对sql进行拦截,判断sql是否包含模糊匹配的需求,如果包含,则获取相应的参数和参数值,对参数值进行过滤,将通配符进行转义

综合考虑后,决定采用最后一种,sql拦截器方式进行实现。

拦截器实现

ini 复制代码
package com.abc.efg.interceptors;

import cn.hutool.core.util.StrUtil;
import org.apache.ibatis.builder.SqlSourceBuilder;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.ReflectorFactory;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.factory.ObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.reflection.wrapper.ObjectWrapperFactory;
import org.apache.ibatis.scripting.defaults.RawSqlSource;
import org.apache.ibatis.scripting.xmltags.DynamicContext;
import org.apache.ibatis.scripting.xmltags.SqlNode;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.lang.reflect.Field;
import java.util.*;


@Intercepts({
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(
                type = Executor.class,
                method = "query",
                args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})

})
public class SpecialCharConvertSqlInterceptor implements Interceptor {

    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
    private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();
    private static final String SQL_SOURCE = "sqlSource";
    private static final String ROOT_SQL_NODE = "sqlSource.rootSqlNode";
    private static final String LIKE_KEYWORD = "like";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object parameter = invocation.getArgs()[1];
        MappedStatement statement = (MappedStatement) invocation.getArgs()[0];
        MetaObject metaMappedStatement = MetaObject.forObject(statement, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
        BoundSql boundSql = statement.getBoundSql(parameter);
        if (metaMappedStatement.hasGetter(SQL_SOURCE)) {
            SqlSource sqlSourceObj = (SqlSource) metaMappedStatement.getValue(SQL_SOURCE);
            Configuration configuration = statement.getConfiguration();
            Object parameterObject = boundSql.getParameterObject();
            Class<?> parameterType = parameterObject == null ? Object.class : parameterObject.getClass();
            DynamicContext context = new DynamicContext(statement.getConfiguration(), boundSql.getParameterObject());
            String sql;
            SqlSource newSqlSource = sqlSourceObj;
            // 没有占位符 单参数类型
            if (sqlSourceObj instanceof RawSqlSource) {
                RawSqlSource sqlSource = (RawSqlSource) sqlSourceObj;
                Class<? extends RawSqlSource> aClass = sqlSource.getClass();
                Field sqlField = aClass.getDeclaredField(SQL_SOURCE);
                sqlField.setAccessible(true);
                Object staticSqlSource = sqlField.get(sqlSource);
                if (staticSqlSource instanceof StaticSqlSource) {
                    Class<? extends StaticSqlSource> rawSqlSource = ((StaticSqlSource) staticSqlSource).getClass();
                    Field sqlInStatic = rawSqlSource.getDeclaredField("sql");
                    sqlInStatic.setAccessible(true);
                    String sqlStr = (String) sqlInStatic.get(staticSqlSource);
                    if (sqlStr.toLowerCase().contains(LIKE_KEYWORD)) {
                        sql = modifyLikeSqlForRawSqlSource(sqlStr, parameterObject);
                        //构建新的sqlSource;
                        newSqlSource = new StaticSqlSource(configuration, sql, ((StaticSqlSource) staticSqlSource).getBoundSql(parameterObject).getParameterMappings());
                    }
                }
            } else if (metaMappedStatement.hasGetter(ROOT_SQL_NODE)) {
                // 有占位符的类型
                SqlNode sqlNode = (SqlNode) metaMappedStatement.getValue(ROOT_SQL_NODE);
                sqlNode.apply(context);
                String contextSql = context.getSql();
                // like sql 特殊处理
                sql = modifyLikeSql(contextSql, parameterObject);
                //构建新的sqlSource;
                SqlSourceBuilder sqlSourceBuilder = new SqlSourceBuilder(configuration);
                newSqlSource = sqlSourceBuilder.parse(sql, parameterType, context.getBindings());
            }
            MappedStatement newMs = newMappedStatement(statement, buildNewBoundSqlSource(newSqlSource, parameterObject, context.getBindings()));
            invocation.getArgs()[0] = newMs;
        }
        return invocation.proceed();
    }

    private SqlSource buildNewBoundSqlSource(SqlSource newSqlSource, Object paramObject, Map<String, Object> objectMap) {
        BoundSql newBoundSql = newSqlSource.getBoundSql(paramObject);
        for (Map.Entry<String, Object> entry : objectMap.entrySet()) {
            newBoundSql.setAdditionalParameter(entry.getKey(), entry.getValue());
        }
        return new SqlSourceWrapper(newBoundSql);
    }

    /**
     * 对 RawSqlSource 类型的 SQL 语句进行 Like 查询的修改
     *
     * @param sqlStr          原始 SQL 语句
     * @param parameterObject 查询参数对象
     * @return 修改后的 SQL 语句
     */
    private String modifyLikeSqlForRawSqlSource(String sqlStr, Object parameterObject) {
        MetaObject metaObject = MetaObject.forObject(parameterObject, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
        String[] values = metaObject.getGetterNames();
        return setValueForMetaObject(sqlStr, Arrays.asList(values), metaObject);

    }

    private String setValueForMetaObject(String sql, List<String> values, MetaObject metaObject) {
        for (String param : values) {
            Object val = metaObject.getValue(param);
            if (val != null && val instanceof String && (val.toString().contains("%") || val.toString().contains("_") || val.toString().contains("\"))) {
                val = specialCharacterReplace(val.toString());
                metaObject.setValue(param, val);
            }
        }
        return sql;
    }

    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new
                MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.parameterMap(ms.getParameterMap());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.resultOrdered(ms.isResultOrdered());
        return builder.build();
    }

    private String modifyLikeSql(String sql, Object parameterObject) {
        if (!sql.toLowerCase().contains(LIKE_KEYWORD)) {
            return sql;
        }
        List<String> replaceFiled = new ArrayList<>();
        String[] likes = sql.split(LIKE_KEYWORD);
        for (String str : likes) {
            String val = getParameterKey(str);
            if (StrUtil.isNotBlank(val)) {
                replaceFiled.add(val);
            }
        }
        // 修改参数
        MetaObject metaObject = MetaObject.forObject(parameterObject, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
        return setValueForMetaObject(sql, replaceFiled, metaObject);

    }

    /**
     * 将 % 替换成 %
     * 将 _ 替换成 _
     * 将 \ 替换成 \
     *
     * @param str 编译后的sql
     * @return str
     */
    private String specialCharacterReplace(String str) {
        str = str.replace("\", "\\");
        str = str.replace("%", "\%");
        str = str.replace("_", "\_");
        return str;
    }

    /**
     * 将 like 后的参数名取出来
     *
     * @param input 编译后sql
     * @return 参数名称
     */
    private String getParameterKey(String input) {
        String key = "";
        // 只取包含concat的那部分
        if (input.contains("concat")) {
            String[] temp = input.split("#");
            if (temp.length > 1) {
                key = temp[1];
                key = key.replace("{", "").replace("}", "").split(",")[0];
            }
        }
        return key;
    }

    @Override
    public void setProperties(Properties properties) {

    }

    static class SqlSourceWrapper implements SqlSource {
        private final BoundSql boundSql;

        @SuppressWarnings("checkstyle:RedundantModifier")
        public SqlSourceWrapper(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }
}

遇到的问题

在第一版中,直接将SpecialCharConvertSqlInterceptor注入spring容器并运行,就出现了一个问题:

在列表查询中包含like条件的可以正常查询,而在分页中结果就与预期不一致,切确的说,是在统计分页总数的时候正确的,但到了具体查询数据时,结果就不正确了。

查看sql后发现是关键字符被替换了两次

这是为什么呢?

项目中分页用的是com.github.pagehelper

它在执行分页查询前会先执行COUNT_SQL

而在执行COUNT_SQL时,参数值已经被替换了一次,再次执行数据查询时,并不会去再次解析参数和参数值,而是直接复用执行COUNT_SQL时已解析的值,这就会导致包含特殊字符的参数值被替换两次,从而导致结果不正确。

经过一番研究后认为是拦截器执行顺序问题,因为pagehelper也是基于拦截器实现的,应当先执行我的转义拦截器,再执行pagehelper的拦截器,所以不能采用直接注入的方式,而是要手工注入,并将拦截器的顺序进行调整。

java 复制代码
package com.abc.ef.conf;

import com.google.common.collect.ImmutableList;
import com.abc.efg.interceptors.BaseEntityInterceptor;
import com.abc.efg.interceptors.SpecialCharConvertSqlInterceptor;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;

import java.util.List;

/**
 * org.apache.ibatis.session.Configuration#newExecutor(org.apache.ibatis.transaction.Transaction, org.apache.ibatis.session.ExecutorType)
 * interceptorChain.pluginAll(executor); 该方法会逆序创建代理对象,自定义拦截器需要添加到最后一个中
 */
@Component
public class CustomerInterceptorRegister implements ApplicationListener<ContextRefreshedEvent> {
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();
            for (Interceptor interceptor : getCustomerInterceptors()) {
                if (!containsInterceptor(configuration, interceptor)) {
                    configuration.addInterceptor(interceptor);
                }
            }
        }
    }

    private List<Interceptor> getCustomerInterceptors() {
        return ImmutableList.of(new SpecialCharConvertSqlInterceptor(), new BaseEntityInterceptor());
    }

    /**
     * 是否已经存在相同的拦截器
     *
     * @param configuration 配置类
     * @param interceptor   拦截器
     * @return 是否存在
     */
    private boolean containsInterceptor(org.apache.ibatis.session.Configuration configuration, Interceptor interceptor) {
        try {
            // getInterceptors since 3.2.2
            return configuration.getInterceptors().stream().anyMatch(config -> interceptor.getClass().isAssignableFrom(config.getClass()));
        } catch (Exception e) {
            return false;
        }
    }
}

经过以上改造,问题得以解决,顺手记录一下。

附录

mybatis相关版本

xml 复制代码
<dependency>
    <groupId>tk.mybatis</groupId>
    <artifactId>mapper</artifactId>
    <version>4.2.3</version>
</dependency>
<dependency>
    <groupId>org.mybatis</groupId>
    <artifactId>mybatis</artifactId>
    <version>3.5.13</version>
</dependency>
<dependency>
    <groupId>org.mybatis</groupId>
    <artifactId>mybatis-spring</artifactId>
    <version>2.1.1</version>
</dependency>

pagehelper版本

xml 复制代码
<dependency>
    <groupId>com.github.pagehelper</groupId>
    <artifactId>pagehelper-spring-boot-starter</artifactId>
    <version>1.4.7</version>
</dependency>
相关推荐
代码对我眨眼睛14 分钟前
springboot从分层到解耦
spring boot·后端
The Straggling Crow23 分钟前
go 战略
开发语言·后端·golang
ai安歌29 分钟前
【JavaWeb】利用IDEA2024+tomcat10配置web6.0版本搭建JavaWeb开发项目
java·开发语言·后端·tomcat·web·intellij idea
尘浮生41 分钟前
Java项目实战II基于Java+Spring Boot+MySQL的作业管理系统设计与实现(源码+数据库+文档)
java·开发语言·数据库·spring boot·后端·mysql·spring
程序员阿鹏2 小时前
ArrayList 与 LinkedList 的区别?
java·开发语言·后端·eclipse·intellij-idea
java_heartLake3 小时前
微服务中间件之Nacos
后端·中间件·nacos·架构
GoFly开发者4 小时前
GoFly快速开发框架/Go语言封装的图像相似性比较插件使用说明
开发语言·后端·golang
苹果酱05674 小时前
通过springcloud gateway优雅的进行springcloud oauth2认证和权限控制
java·开发语言·spring boot·后端·中间件
豌豆花下猫5 小时前
Python 潮流周刊#70:微软 Excel 中的 Python 正式发布!(摘要)
后端·python·ai
芯冰乐6 小时前
综合时如何计算net delay?
后端·fpga开发