场景
业务场景中,需要对某些表中的数据做水平的数据隔离,比如某些表中如果含有某个字段,比如store_id(门店id)这个字段,
则对某些有对应门店权限的用户角色开放数据,如果请求的用户没有对该门店的权限,则自动对sql进行拦截添加where条件。
当然如果同一张表,又必须要查询全量数据,又可以通过添加自定义注解的方式,跳过数据隔离,返回全量数据。
并且如果用户没有任何门店的权限,或其他类似权限限制,则直接不执行查询,返回数据为空。
注:
博客:
https://blog.csdn.net/badao_liumang_qizhi
实现
新建SpringBoot项目,并引入相关依赖
如下依赖特别关注:
<!--MybatisPlus依赖-->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus-boot-starter</artifactId>
<version>3.5.1</version>
</dependency>
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjrt</artifactId>
<version>1.9.7</version>
</dependency>
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjweaver</artifactId>
<version>1.9.7</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-aspects</artifactId>
</dependency>
注意:
mybatis-plus-boot-starter 3.5.1 已包含 JSqlParser 依赖
所以此处不需要额外引入如下依赖:
<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>4.3</version> <!-- MyBatis-Plus 3.5.1 使用的版本 -->
</dependency>
另外还需引入其它非关键依赖,按需选择:
<!-- spring-boot -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- spring-boot-test -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
<exclusions>
<exclusion>
<groupId>org.junit.vintage</groupId>
<artifactId>junit-vintage-engine</artifactId>
</exclusion>
</exclusions>
</dependency>
<!-- lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.26</version>
<scope>provided</scope>
</dependency>
<!-- 数据库连接 -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
</dependency>
添加mybatisplus的配置类,在配置类中实现初始化表缓存、定期刷新表缓存、注册数据隔离拦截器操作
代码实现如下:
@Configuration
@MapperScan("com.badao.demo.mapper")
public class MybatisPlusConfig {
// 关键功能:
// 1. 初始化时扫描数据库表结构(initTableCache)
// 2. 定时刷新表结构缓存(scheduleCacheRefresh)
// 3. 注册MyBatis-Plus拦截器链
public MybatisPlusConfig(DataSource dataSource) {
this.dataSource = dataSource;
initTableCache();
//每5分钟刷新缓存(应对表结构变更)
scheduleCacheRefresh();
}
// 数据源
private final DataSource dataSource;
// 模式名称
public static final String DATABASE_B_GAS_STATION = "test";
// 含有门店id字段的数据表
// 使用ConcurrentHashMap保证线程安全
private final Set<String> tablesWithStoreId = Collections.newSetFromMap(new ConcurrentHashMap<>());
@Bean
public MybatisPlusInterceptor mybatisPlusInterceptor() {
MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
// 注册数据水平隔离拦截器
interceptor.addInnerInterceptor(new StoreDataInterceptor(tablesWithStoreId));
// 注册分页拦截器
interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
return interceptor;
}
/**
* 初始化表缓存
*/
private void initTableCache() {
try (Connection conn = dataSource.getConnection()) {
DatabaseMetaData metaData = conn.getMetaData();
Map<String, Set<String>> tableColumnsMap = new HashMap<>();
try (ResultSet columns = metaData.getColumns(DATABASE_B_GAS_STATION, null, "%", "%")) {
while (columns.next()) {
String tableName = columns.getString("TABLE_NAME").toLowerCase();
String columnName = columns.getString("COLUMN_NAME").toLowerCase();
tableColumnsMap.computeIfAbsent(tableName, k -> new HashSet<>()).add(columnName);
}
}
// 获取所有表
try (ResultSet tables = metaData.getTables(DATABASE_B_GAS_STATION, null, "%", new String[]{"TABLE"})) {
while (tables.next()) {
String tableName = tables.getString("TABLE_NAME").toLowerCase();
Set<String> columns = tableColumnsMap.getOrDefault(tableName, Collections.emptySet());
if (columns.contains(StoreDataInterceptor.STORE_ID)) {
tablesWithStoreId.add(tableName);
}
}
}
} catch (Exception e) {
throw new RuntimeException("SellerIso: Failed to init table cache", e);
}
}
/**
* 定时刷新表结构缓存
*/
private void scheduleCacheRefresh() {
//刷新表结构缓存
ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
scheduler.scheduleAtFixedRate(this::initTableCache, 5, 5, TimeUnit.MINUTES);
}
}
数据隔离拦截器实现代码
public class StoreDataInterceptor implements InnerInterceptor {
private final Set<String> tablesWithStoreId;
//隔离字段column
public static final String STORE_ID = "store_id";
public StoreDataInterceptor(Set<String> tablesWithStoreId) {
this.tablesWithStoreId = tablesWithStoreId;
}
/**
* 优先级高于SQL改写
* 若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)
*/
@Override
public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
// 指定跳过数据隔离
if (SkipDataIsolation.getMethodSkipDataIsolation()) {
return true;
}
// 其它业务逻辑则不予查询,比如获取请求头中的数据做权限校验,完全禁止无权限的查询(如未登录用户)
// if(!CollectionUtils.isEmpty(UserContextHolder.getStoreIds())
// {
// return false;
// }
return true;
}
@Override
public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,
ResultHandler resultHandler, BoundSql boundSql) {
// 如果用户是超管用户则跳过拦截器-自己添加逻辑判断
if (false) {
return;
}
// 指定跳过数据隔离
if (SkipDataIsolation.getMethodSkipDataIsolation()) {
return;
}
String sql = boundSql.getSql();
try {
//解析SQL并重写
Select select = (Select) CCJSqlParserUtil.parse(sql);
SelectBody selectBody = select.getSelectBody();
// 递归处理所有SELECT部分
processSelectBody(selectBody);
PluginUtils.mpBoundSql(boundSql).sql(select.toString());
} catch (Exception e) {
System.out.println(e.getMessage());
}
}
/**
* 处理PlainSelect
*/
private void processSelectBody(SelectBody selectBody) {
if (selectBody instanceof PlainSelect) {
processPlainSelect((PlainSelect) selectBody);
} else if (selectBody instanceof SetOperationList) {
// 处理UNION/INTERSECT等
for (SelectBody body : ((SetOperationList) selectBody).getSelects()) {
processSelectBody(body);
}
}
// 其他类型如WithItem暂不处理
}
/**
* 处理FROM项
*/
private void processPlainSelect(PlainSelect plainSelect) {
// 1. 处理FROM项
Map<String, String> aliasTableMap = new HashMap<>();
processFromItem(plainSelect.getFromItem(), aliasTableMap);
// 2. 处理JOIN表
if (plainSelect.getJoins() != null) {
for (Join join : plainSelect.getJoins()) {
processFromItem(join.getRightItem(), aliasTableMap);
}
}
// 3. 添加条件到当前SELECT
addConditionsToSelect(plainSelect, aliasTableMap);
// 4. 递归处理子查询
processSubQueries(plainSelect);
}
/**
* 处理查询
*/
private void processFromItem(FromItem fromItem, Map<String, String> aliasTableMap) {
if (fromItem instanceof Table) {
Table table = (Table) fromItem;
String tableName = table.getName().toLowerCase();
String alias = table.getAlias() != null ?
table.getAlias().getName().toLowerCase() : tableName;
// 缓存别名映射
aliasTableMap.put(alias, tableName);
} else if (fromItem instanceof SubSelect) {
// 处理子查询
processSelectBody(((SubSelect) fromItem).getSelectBody());
}
}
/**
* 处理子查询
*/
private void processSubQueries(PlainSelect plainSelect) {
// 1. 处理WHERE子句中的子查询
if (plainSelect.getWhere() != null) {
plainSelect.getWhere().accept(new SafeExpressionVisitor());
}
// 2. 处理SELECT列表中的子查询
for (SelectItem item : plainSelect.getSelectItems()) {
item.accept(new SafeSelectItemVisitor());
}
}
/**
* 添加查询条件到SELECT
*/
private void addConditionsToSelect(PlainSelect plainSelect, Map<String, String> aliasTableMap) {
// 检查哪些表需要添加条件
for (Map.Entry<String, String> entry : aliasTableMap.entrySet()) {
String alias = entry.getKey();
String tableName = entry.getValue();
//List<String> storeIds = UserContextHolder.getStoreIds();
//此处用模拟数据示例
List<String> storeIds = new ArrayList(){{
this.add("1");
this.add("2");
}};
//对含store_id的表自动添加条件:
if (tablesWithStoreId.contains(tableName)) {
handleSelectSql(alias, plainSelect, storeIds);
}
}
}
/**
* 创建查询表达式
*/
private static void handleSelectSql(String alias, PlainSelect plainSelect,
List<String> companyChannelIds) {
// 创建条件表达式
Column channelColumn = new Column(alias + "." + StoreDataInterceptor.STORE_ID);
// 创建表达式列表
ExpressionList expressionList = new ExpressionList();
// 手动初始化expressions列表
expressionList.setExpressions(new ArrayList<>());
for (String id : companyChannelIds) {
expressionList.getExpressions().add(new StringValue(id));
}
// 构建条件表达式:WHERE (store_id IN (1,2) OR store_id IS NULL)
// 创建IN表达式
InExpression inExpression = new InExpression(channelColumn, expressionList);
// 添加or 数据隔离字段is null 条件避免联表查询时未能关联数据导致全部数据被过滤
IsNullExpression isNullExpression = new IsNullExpression();
isNullExpression.setLeftExpression(channelColumn);
OrExpression orExpression = new OrExpression(isNullExpression, inExpression);
// 调整or条件优先级 加()
Parenthesis parenthesis = new Parenthesis(orExpression);
// 获取现有WHERE条件
Expression where = plainSelect.getWhere();
plainSelect.setWhere(where == null ? parenthesis : new AndExpression(where, parenthesis));
}
/**
* 避免查询无限递归
*/
private class SafeExpressionVisitor extends ExpressionVisitorAdapter {
private final Set<Object> visitedObjects = Collections.newSetFromMap(new IdentityHashMap<>());
@Override
public void visit(SubSelect subSelect) {
// 防止SubSelect无限递归
if (visitedObjects.add(subSelect)) {
try {
// 限制递归深度
if (visitedObjects.size() < 50) {
processSelectBody(subSelect.getSelectBody());
}
} catch (Exception e) {
System.out.println(e.getMessage());
} finally {
visitedObjects.remove(subSelect);
}
}
}
@Override
public void visit(AllColumns allColumns) {
// 关键:避免处理AllColumns时的无限递归
// 在JSqlParser 4.3中,这里不能调用super.visit(allColumns)
}
}
/**
* 避免查询无限递归
*/
private class SafeSelectItemVisitor extends SelectItemVisitorAdapter {
@Override
public void visit(SelectExpressionItem item) {
try {
item.getExpression().accept(new SafeExpressionVisitor());
} catch (Exception e) {
System.out.println(e.getMessage());
}
}
}
}
代码如下:
注意:
1、willDoQuery中
核心作用
拦截器开关控制
决定是否允许当前SQL查询继续执行(true放行,false拦截)
与跳过机制集成
通过检查SkipDataIsolation的线程状态,实现动态拦截控制
方法调用时机
sequenceDiagram
MyBatis->>StoreDataInterceptor: 执行查询前
StoreDataInterceptor->>willDoQuery: 检查拦截条件
alt 返回true
MyBatis->>DB: 正常执行查询
else 返回false
MyBatis->>调用方: 直接返回空结果
end
优先级高于SQL改写
若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)
典型使用场景
完全禁止无权限的查询(如未登录用户)
快速跳过无需处理的查询类型(如特定Mapper方法)
2、beforeQuery中
如果用户是超管用户则跳过拦截器-自己添加逻辑判断
addConditionsToSelect添加查询条件中SELECT中,获取当前用户的门店id权限使用模拟数据演示效果。
正常应该是从权限控制相关业务中获取,此处注意使用时修改。
自定义跳过数据隔离注解实现
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SkipDataIsolationAnnotation {
}
跳过数据隔离切面实现
/**
* 跳过数据隔离切面
*/
@Aspect
@Component
public class SkipDataIsolationAspect {
//Around增强:在方法执行前后插入逻辑
@Around("@annotation(skipDataIsolationAnnotation)")
public Object handleSkipDataIsolation(ProceedingJoinPoint joinPoint,
SkipDataIsolationAnnotation skipDataIsolationAnnotation) throws Throwable {
try {
//进入方法时设置ThreadLocal标志为true
SkipDataIsolation.setMethodSkipDataIsolation(true);// 设置线程标志
return joinPoint.proceed();
} finally {
//通过try-finally确保异常时也能清理状态
SkipDataIsolation.methodClear(); // 清理线程状态
}
}
}
上下文控制器实现
/**
* 上下文控制器
*/
public class SkipDataIsolation {
// 单次sql语句级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志,默认不跳过value=false
private static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION = ThreadLocal.withInitial(() -> false);
// 方法级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志, 默认不跳过value=false
private static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION_METHOD = ThreadLocal.withInitial(() -> false);
/**
* 单次sql级别:设置跳过数据隔离标志
*/
public static void setSkipDataIsolation(Boolean skip) {
SKIP_DATA_ISOLATION.set(skip);
}
/**
* 单次sql级别:获取跳过数据隔离标志
*/
public static Boolean getSkipDataIsolation() {
return SKIP_DATA_ISOLATION.get();
}
/**
* 单次sql级别:清理ThreadLocal,防止内存泄漏
*/
public static void clear() {
SKIP_DATA_ISOLATION.remove();
}
/**
* 方法级别:设置跳过数据隔离标志
*/
public static void setMethodSkipDataIsolation(Boolean skip) {
SKIP_DATA_ISOLATION_METHOD.set(skip);
}
/**
* 方法级别:获取跳过数据隔离标志
*/
public static Boolean getMethodSkipDataIsolation() {
return SKIP_DATA_ISOLATION_METHOD.get();
}
/**
* 方法级别:清理ThreadLocal,防止内存泄漏
*/
public static void methodClear() {
SKIP_DATA_ISOLATION_METHOD.remove();
}
}
测试效果
新建一个包含store_id字段的表,并生成5条数据,其中有两条数据store_id为1和2。
新建两个controller并且一个添加跳过数据隔离注解,一个不添加,执行同样的mp的条件查询。
不进行数据隔离的查询效果

带数据隔离的效果
完整示例代码以及SQL文件资源下载
https://download.csdn.net/download/BADAO_LIUMANG_QIZHI/92218402
