问题背景
问题来源于一个二次开发项目的 bug,测试反馈如下:
【所有查询模块】关键字里面输入'%',或者'_'查询出来的是所有数据
问题定位
bug 提出后,我立刻查看了相应的 sql,发现 xml 中实现模糊查询的代码都是如下写法:
xml
<if test="param.abc!=null and param.abc!=''">
and t.abc like concat('%',#{param.abc},'%')
</if>
所以当前端传来的参数为 `%` 或者 `_` 时,最终 like 部分的 sql 为 `and t.abc like concat('%','%','%')` 或者 `and t.abc like concat('%','_','%')`。而 `%` 和 `_` 恰好是 MySQL 内置的通配符(`%` 匹配任意个字符,`_` 匹配任意单个字符),所以这两个条件会匹配几乎所有数据,这就是 bug 描述中的问题。
解决思路
既然问题是参数值撞上了通配符导致的,只要把参数值中的通配符 `%` 和 `_` 转义一下就可以了。由于 MySQL 中 LIKE 的默认转义字符是 `\`(未开启 `NO_BACKSLASH_ESCAPES` 模式时),所以 `\` 本身也要纳入转义的范围,即 `\` 转为 `\\`、`%` 转为 `\%`、`_` 转为 `\_`。
项目技术栈是 springboot + mybatis,实现转义有以下几种方式:
- 接口层处理
大致思路是用一个统一的aop,识别出所有查询接口的参数,如果参数为字符串类型,且包含通配符的,就将其转义后再往业务层传值。
- 业务层处理
业务层各自处理各自的,凡有查询之处,凡是字符串类型的参数,统统进行校验并转义。
- mapper层处理
借助mybatis的拦截器,在sql执行前,对sql进行拦截,判断sql是否包含模糊匹配的需求,如果包含,则获取相应的参数和参数值,对参数值进行过滤,将通配符进行转义
综合考虑后,决定采用最后一种,sql拦截器方式进行实现。
拦截器实现
java
package com.abc.efg.interceptors;
import cn.hutool.core.util.StrUtil;
import org.apache.ibatis.builder.SqlSourceBuilder;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.ReflectorFactory;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.factory.ObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.reflection.wrapper.ObjectWrapperFactory;
import org.apache.ibatis.scripting.defaults.RawSqlSource;
import org.apache.ibatis.scripting.xmltags.DynamicContext;
import org.apache.ibatis.scripting.xmltags.SqlNode;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.lang.reflect.Field;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Intercepts({
@Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
})
public class SpecialCharConvertSqlInterceptor implements Interceptor {
private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();
private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();
private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();
private static final String SQL_SOURCE = "sqlSource";
private static final String ROOT_SQL_NODE = "sqlSource.rootSqlNode";
private static final String LIKE_KEYWORD = "like";
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object parameter = invocation.getArgs()[1];
MappedStatement statement = (MappedStatement) invocation.getArgs()[0];
MetaObject metaMappedStatement = MetaObject.forObject(statement, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
BoundSql boundSql = statement.getBoundSql(parameter);
if (metaMappedStatement.hasGetter(SQL_SOURCE)) {
SqlSource sqlSourceObj = (SqlSource) metaMappedStatement.getValue(SQL_SOURCE);
Configuration configuration = statement.getConfiguration();
Object parameterObject = boundSql.getParameterObject();
Class<?> parameterType = parameterObject == null ? Object.class : parameterObject.getClass();
DynamicContext context = new DynamicContext(statement.getConfiguration(), boundSql.getParameterObject());
String sql;
SqlSource newSqlSource = sqlSourceObj;
// 没有占位符 单参数类型
if (sqlSourceObj instanceof RawSqlSource) {
RawSqlSource sqlSource = (RawSqlSource) sqlSourceObj;
Class<? extends RawSqlSource> aClass = sqlSource.getClass();
Field sqlField = aClass.getDeclaredField(SQL_SOURCE);
sqlField.setAccessible(true);
Object staticSqlSource = sqlField.get(sqlSource);
if (staticSqlSource instanceof StaticSqlSource) {
Class<? extends StaticSqlSource> rawSqlSource = ((StaticSqlSource) staticSqlSource).getClass();
Field sqlInStatic = rawSqlSource.getDeclaredField("sql");
sqlInStatic.setAccessible(true);
String sqlStr = (String) sqlInStatic.get(staticSqlSource);
if (sqlStr.toLowerCase().contains(LIKE_KEYWORD)) {
sql = modifyLikeSqlForRawSqlSource(sqlStr, parameterObject);
//构建新的sqlSource;
newSqlSource = new StaticSqlSource(configuration, sql, ((StaticSqlSource) staticSqlSource).getBoundSql(parameterObject).getParameterMappings());
}
}
} else if (metaMappedStatement.hasGetter(ROOT_SQL_NODE)) {
// 有占位符的类型
SqlNode sqlNode = (SqlNode) metaMappedStatement.getValue(ROOT_SQL_NODE);
sqlNode.apply(context);
String contextSql = context.getSql();
// like sql 特殊处理
sql = modifyLikeSql(contextSql, parameterObject);
//构建新的sqlSource;
SqlSourceBuilder sqlSourceBuilder = new SqlSourceBuilder(configuration);
newSqlSource = sqlSourceBuilder.parse(sql, parameterType, context.getBindings());
}
MappedStatement newMs = newMappedStatement(statement, buildNewBoundSqlSource(newSqlSource, parameterObject, context.getBindings()));
invocation.getArgs()[0] = newMs;
}
return invocation.proceed();
}
private SqlSource buildNewBoundSqlSource(SqlSource newSqlSource, Object paramObject, Map<String, Object> objectMap) {
BoundSql newBoundSql = newSqlSource.getBoundSql(paramObject);
for (Map.Entry<String, Object> entry : objectMap.entrySet()) {
newBoundSql.setAdditionalParameter(entry.getKey(), entry.getValue());
}
return new SqlSourceWrapper(newBoundSql);
}
/**
* 对 RawSqlSource 类型的 SQL 语句进行 Like 查询的修改
*
* @param sqlStr 原始 SQL 语句
* @param parameterObject 查询参数对象
* @return 修改后的 SQL 语句
*/
private String modifyLikeSqlForRawSqlSource(String sqlStr, Object parameterObject) {
MetaObject metaObject = MetaObject.forObject(parameterObject, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
String[] values = metaObject.getGetterNames();
return setValueForMetaObject(sqlStr, Arrays.asList(values), metaObject);
}
private String setValueForMetaObject(String sql, List<String> values, MetaObject metaObject) {
for (String param : values) {
Object val = metaObject.getValue(param);
if (val != null && val instanceof String && (val.toString().contains("%") || val.toString().contains("_") || val.toString().contains("\"))) {
val = specialCharacterReplace(val.toString());
metaObject.setValue(param, val);
}
}
return sql;
}
private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new
MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
builder.parameterMap(ms.getParameterMap());
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
builder.resultOrdered(ms.isResultOrdered());
return builder.build();
}
private String modifyLikeSql(String sql, Object parameterObject) {
if (!sql.toLowerCase().contains(LIKE_KEYWORD)) {
return sql;
}
List<String> replaceFiled = new ArrayList<>();
String[] likes = sql.split(LIKE_KEYWORD);
for (String str : likes) {
String val = getParameterKey(str);
if (StrUtil.isNotBlank(val)) {
replaceFiled.add(val);
}
}
// 修改参数
MetaObject metaObject = MetaObject.forObject(parameterObject, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY, DEFAULT_REFLECTOR_FACTORY);
return setValueForMetaObject(sql, replaceFiled, metaObject);
}
/**
* 将 % 替换成 %
* 将 _ 替换成 _
* 将 \ 替换成 \
*
* @param str 编译后的sql
* @return str
*/
private String specialCharacterReplace(String str) {
str = str.replace("\", "\\");
str = str.replace("%", "\%");
str = str.replace("_", "\_");
return str;
}
/**
* 将 like 后的参数名取出来
*
* @param input 编译后sql
* @return 参数名称
*/
private String getParameterKey(String input) {
String key = "";
// 只取包含concat的那部分
if (input.contains("concat")) {
String[] temp = input.split("#");
if (temp.length > 1) {
key = temp[1];
key = key.replace("{", "").replace("}", "").split(",")[0];
}
}
return key;
}
@Override
public void setProperties(Properties properties) {
}
static class SqlSourceWrapper implements SqlSource {
private final BoundSql boundSql;
@SuppressWarnings("checkstyle:RedundantModifier")
public SqlSourceWrapper(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
}
遇到的问题
在第一版中,直接将 `SpecialCharConvertSqlInterceptor` 注入 spring 容器并运行,结果出现了一个问题:
包含 like 条件的普通列表查询结果正常,但分页查询的结果与预期不一致。确切地说,统计分页总数时结果是正确的,但查询具体数据时结果就不正确了。
查看 sql 后发现,参数值中的特殊字符被转义了两次。
这是为什么呢?
项目中分页用的是 `com.github.pagehelper`,它在执行分页查询前会先执行 `COUNT_SQL`。由于转义拦截器是直接修改参数对象本身(`metaObject.setValue`),执行 `COUNT_SQL` 时参数值已经被转义了一次。再执行数据查询时,并不会重新获取原始参数值,而是直接复用执行 `COUNT_SQL` 时已被转义的参数对象,转义拦截器又对其转义了一次。这样包含特殊字符的参数值就被转义了两次,从而导致结果不正确。
经过一番研究,认为这是拦截器的执行顺序问题。`pagehelper` 也是基于拦截器实现的,应当先执行转义拦截器,再执行 `pagehelper` 的拦截器,这样每次分页查询只会转义一次。所以不能采用直接注入 spring 容器的方式,而是要手工注册,以调整拦截器的顺序。
java
package com.abc.ef.conf;
import com.google.common.collect.ImmutableList;
import com.abc.efg.interceptors.BaseEntityInterceptor;
import com.abc.efg.interceptors.SpecialCharConvertSqlInterceptor;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.stereotype.Component;
import java.util.List;
/**
 * Registers the project's custom MyBatis interceptors on every {@link SqlSessionFactory} once the Spring
 * context has been refreshed, instead of exposing them as beans.
 *
 * <p>{@code org.apache.ibatis.session.Configuration#newExecutor(Transaction, ExecutorType)} calls
 * {@code interceptorChain.pluginAll(executor)}, which creates the proxies in reverse registration order.
 * Custom interceptors therefore have to be added last (here, after the context is fully refreshed) so they
 * run before interceptors that were registered earlier.
 */
@Component
public class CustomerInterceptorRegister implements ApplicationListener<ContextRefreshedEvent> {
    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    /**
     * Adds each custom interceptor to every session factory's configuration, unless an interceptor of the
     * same class is already present (the event may fire more than once).
     *
     * @param event the context-refreshed event (not inspected)
     */
    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();
            // Fresh interceptor instances are created for each factory.
            for (Interceptor interceptor : getCustomerInterceptors()) {
                if (!containsInterceptor(configuration, interceptor)) {
                    configuration.addInterceptor(interceptor);
                }
            }
        }
    }

    /**
     * Returns new instances of the custom interceptors, in the order they are added to the configuration.
     */
    private List<Interceptor> getCustomerInterceptors() {
        return ImmutableList.of(new SpecialCharConvertSqlInterceptor(), new BaseEntityInterceptor());
    }

    /**
     * Checks whether an interceptor of the same class (or a subclass) is already registered.
     *
     * @param configuration MyBatis configuration to inspect
     * @param interceptor   candidate interceptor
     * @return {@code true} if an equivalent interceptor is registered; {@code false} otherwise, or when the
     *         registered interceptors cannot be read
     */
    private boolean containsInterceptor(org.apache.ibatis.session.Configuration configuration, Interceptor interceptor) {
        try {
            // getInterceptors() exists since MyBatis 3.2.2
            return configuration.getInterceptors().stream()
                    .anyMatch(existing -> interceptor.getClass().isAssignableFrom(existing.getClass()));
        } catch (Exception | NoSuchMethodError ignored) {
            // On MyBatis < 3.2.2 the missing method surfaces as NoSuchMethodError (an Error, not an
            // Exception); treat "cannot tell" as "not registered", as before.
            return false;
        }
    }
}
经过以上改造,问题得以解决,顺手记录一下。
附录
mybatis相关版本
xml
<dependency>
<groupId>tk.mybatis</groupId>
<artifactId>mapper</artifactId>
<version>4.2.3</version>
</dependency>
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
<version>3.5.13</version>
</dependency>
<dependency>
<groupId>org.mybatis</groupId>
<artifactId>mybatis-spring</artifactId>
<version>2.1.1</version>
</dependency>
pagehelper版本
xml
<dependency>
<groupId>com.github.pagehelper</groupId>
<artifactId>pagehelper-spring-boot-starter</artifactId>
<version>1.4.7</version>
</dependency>