SpringBoot+MybatisPlus+自定义注解+切面实现水平数据隔离功能(附代码下载)

场景

业务场景中,需要对某些表中的数据做水平的数据隔离,比如某些表中如果含有某个字段,比如store_id(门店id)这个字段,

则对某些有对应门店权限的用户角色开放数据,如果请求的用户没有对该门店的权限,则自动对sql进行拦截添加where条件。

当然如果同一张表,又必须要查询全量数据,又可以通过添加自定义注解的方式,跳过数据隔离,返回全量数据。

并且如果用户没有任何门店的权限,或其他类似权限限制,则直接不执行查询,返回数据为空。

注:

博客:
https://blog.csdn.net/badao_liumang_qizhi

实现

新建SpringBoot项目,并引入相关依赖

如下依赖特别关注:

复制代码
        <!--MybatisPlus依赖-->
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-boot-starter</artifactId>
            <version>3.5.1</version>
        </dependency>
        <dependency>
            <groupId>org.aspectj</groupId>
            <artifactId>aspectjrt</artifactId>
            <version>1.9.7</version>
        </dependency>
        <dependency>
            <groupId>org.aspectj</groupId>
            <artifactId>aspectjweaver</artifactId>
            <version>1.9.7</version>
        </dependency>
        <dependency>
            <groupId>org.springframework</groupId>
            <artifactId>spring-aspects</artifactId>
        </dependency>

注意:

mybatis-plus-boot-starter 3.5.1 已包含 JSqlParser 依赖

所以此处不需要额外引入如下依赖:

复制代码
<dependency>
    <groupId>com.github.jsqlparser</groupId>
    <artifactId>jsqlparser</artifactId>
    <version>4.3</version> <!-- MyBatis-Plus 3.5.1 使用的版本 -->
</dependency>

另外还需引入其它非关键依赖,按需选择:

复制代码
       <!-- spring-boot -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- spring-boot-test -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.26</version>
            <scope>provided</scope>
        </dependency>
        <!-- 数据库连接 -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
        </dependency>

添加mybatisplus的配置类,在配置类中实现初始化表缓存、定期刷新表缓存、注册数据隔离拦截器操作

代码实现如下:

复制代码
@Configuration
@MapperScan("com.badao.demo.mapper")
public class MybatisPlusConfig {

    // 关键功能:
    // 1. 初始化时扫描数据库表结构(initTableCache)
    // 2. 定时刷新表结构缓存(scheduleCacheRefresh)
    // 3. 注册MyBatis-Plus拦截器链
    public MybatisPlusConfig(DataSource dataSource) {
        this.dataSource = dataSource;
        initTableCache();
        //每5分钟刷新缓存(应对表结构变更)
        scheduleCacheRefresh();
    }

    // 数据源
    private final DataSource dataSource;
    // 模式名称
    public static final String DATABASE_B_GAS_STATION = "test";
    // 含有门店id字段的数据表
    // 使用ConcurrentHashMap保证线程安全
    private final Set<String> tablesWithStoreId = Collections.newSetFromMap(new ConcurrentHashMap<>());

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // 注册数据水平隔离拦截器
        interceptor.addInnerInterceptor(new StoreDataInterceptor(tablesWithStoreId));
        // 注册分页拦截器
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return interceptor;
    }


    /**
     * 初始化表缓存
     */
    private void initTableCache() {
        try (Connection conn = dataSource.getConnection()) {
            DatabaseMetaData metaData = conn.getMetaData();

            Map<String, Set<String>> tableColumnsMap = new HashMap<>();
            try (ResultSet columns = metaData.getColumns(DATABASE_B_GAS_STATION, null, "%", "%")) {
                while (columns.next()) {
                    String tableName = columns.getString("TABLE_NAME").toLowerCase();
                    String columnName = columns.getString("COLUMN_NAME").toLowerCase();
                    tableColumnsMap.computeIfAbsent(tableName, k -> new HashSet<>()).add(columnName);
                }
            }

            // 获取所有表
            try (ResultSet tables = metaData.getTables(DATABASE_B_GAS_STATION, null, "%", new String[]{"TABLE"})) {
                while (tables.next()) {
                    String tableName = tables.getString("TABLE_NAME").toLowerCase();
                    Set<String> columns = tableColumnsMap.getOrDefault(tableName, Collections.emptySet());

                    if (columns.contains(StoreDataInterceptor.STORE_ID)) {
                        tablesWithStoreId.add(tableName);
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException("SellerIso: Failed to init table cache", e);
        }
    }

    /**
     * 定时刷新表结构缓存
     */
    private void scheduleCacheRefresh() {
        //刷新表结构缓存
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        scheduler.scheduleAtFixedRate(this::initTableCache, 5, 5, TimeUnit.MINUTES);
    }
}

数据隔离拦截器实现代码

复制代码
public class StoreDataInterceptor implements InnerInterceptor {

    private final Set<String> tablesWithStoreId;

    //隔离字段column
    public static final String STORE_ID = "store_id";

    public StoreDataInterceptor(Set<String> tablesWithStoreId) {
        this.tablesWithStoreId = tablesWithStoreId;
    }

    /**
     * 优先级高于SQL改写
     * 若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)
     */
    @Override
    public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        // 指定跳过数据隔离
        if (SkipDataIsolation.getMethodSkipDataIsolation()) {
            return true;
        }
        // 其它业务逻辑则不予查询,比如获取请求头中的数据做权限校验,完全禁止无权限的查询(如未登录用户)
//        if(!CollectionUtils.isEmpty(UserContextHolder.getStoreIds())
//        {
//            return false;
//        }
        return true;
    }

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,
                            ResultHandler resultHandler, BoundSql boundSql) {
        // 如果用户是超管用户则跳过拦截器-自己添加逻辑判断
        if (false) {
            return;
        }
        // 指定跳过数据隔离
        if (SkipDataIsolation.getMethodSkipDataIsolation()) {
            return;
        }

        String sql = boundSql.getSql();
        try {
            //解析SQL并重写
            Select select = (Select) CCJSqlParserUtil.parse(sql);
            SelectBody selectBody = select.getSelectBody();

            // 递归处理所有SELECT部分
            processSelectBody(selectBody);

            PluginUtils.mpBoundSql(boundSql).sql(select.toString());
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * 处理PlainSelect
     */
    private void processSelectBody(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof SetOperationList) {
            // 处理UNION/INTERSECT等
            for (SelectBody body : ((SetOperationList) selectBody).getSelects()) {
                processSelectBody(body);
            }
        }
        // 其他类型如WithItem暂不处理
    }

    /**
     * 处理FROM项
     */
    private void processPlainSelect(PlainSelect plainSelect) {
        // 1. 处理FROM项
        Map<String, String> aliasTableMap = new HashMap<>();
        processFromItem(plainSelect.getFromItem(), aliasTableMap);

        // 2. 处理JOIN表
        if (plainSelect.getJoins() != null) {
            for (Join join : plainSelect.getJoins()) {
                processFromItem(join.getRightItem(), aliasTableMap);
            }
        }

        // 3. 添加条件到当前SELECT
        addConditionsToSelect(plainSelect, aliasTableMap);

        // 4. 递归处理子查询
        processSubQueries(plainSelect);
    }

    /**
     * 处理查询
     */
    private void processFromItem(FromItem fromItem, Map<String, String> aliasTableMap) {
        if (fromItem instanceof Table) {
            Table table = (Table) fromItem;
            String tableName = table.getName().toLowerCase();
            String alias = table.getAlias() != null ?
                    table.getAlias().getName().toLowerCase() : tableName;

            // 缓存别名映射
            aliasTableMap.put(alias, tableName);
        } else if (fromItem instanceof SubSelect) {
            // 处理子查询
            processSelectBody(((SubSelect) fromItem).getSelectBody());
        }
    }

    /**
     * 处理子查询
     */
    private void processSubQueries(PlainSelect plainSelect) {
        // 1. 处理WHERE子句中的子查询
        if (plainSelect.getWhere() != null) {
            plainSelect.getWhere().accept(new SafeExpressionVisitor());
        }

        // 2. 处理SELECT列表中的子查询
        for (SelectItem item : plainSelect.getSelectItems()) {
            item.accept(new SafeSelectItemVisitor());
        }
    }

    /**
     * 添加查询条件到SELECT
     */
    private void addConditionsToSelect(PlainSelect plainSelect, Map<String, String> aliasTableMap) {
        // 检查哪些表需要添加条件
        for (Map.Entry<String, String> entry : aliasTableMap.entrySet()) {
            String alias = entry.getKey();
            String tableName = entry.getValue();
            //List<String> storeIds = UserContextHolder.getStoreIds();
            //此处用模拟数据示例
            List<String> storeIds = new ArrayList(){{
                this.add("1");
                this.add("2");
            }};
            //对含store_id的表自动添加条件:
            if (tablesWithStoreId.contains(tableName)) {
                handleSelectSql(alias, plainSelect, storeIds);
            }
        }
    }

    /**
     * 创建查询表达式
     */
    private static void handleSelectSql(String alias, PlainSelect plainSelect,
                                        List<String> companyChannelIds) {
        // 创建条件表达式
        Column channelColumn = new Column(alias + "." + StoreDataInterceptor.STORE_ID);

        // 创建表达式列表
        ExpressionList expressionList = new ExpressionList();
        // 手动初始化expressions列表
        expressionList.setExpressions(new ArrayList<>());

        for (String id : companyChannelIds) {
            expressionList.getExpressions().add(new StringValue(id));
        }

        // 构建条件表达式:WHERE (store_id IN (1,2) OR store_id IS NULL)
        // 创建IN表达式
        InExpression inExpression = new InExpression(channelColumn, expressionList);
        // 添加or 数据隔离字段is null 条件避免联表查询时未能关联数据导致全部数据被过滤
        IsNullExpression isNullExpression = new IsNullExpression();
        isNullExpression.setLeftExpression(channelColumn);
        OrExpression orExpression = new OrExpression(isNullExpression, inExpression);
        // 调整or条件优先级 加()
        Parenthesis parenthesis = new Parenthesis(orExpression);

        // 获取现有WHERE条件
        Expression where = plainSelect.getWhere();
        plainSelect.setWhere(where == null ? parenthesis : new AndExpression(where, parenthesis));
    }

    /**
     * 避免查询无限递归
     */
    private class SafeExpressionVisitor extends ExpressionVisitorAdapter {
        private final Set<Object> visitedObjects = Collections.newSetFromMap(new IdentityHashMap<>());

        @Override
        public void visit(SubSelect subSelect) {
            // 防止SubSelect无限递归
            if (visitedObjects.add(subSelect)) {
                try {
                    // 限制递归深度
                    if (visitedObjects.size() < 50) {
                        processSelectBody(subSelect.getSelectBody());
                    }
                } catch (Exception e) {
                    System.out.println(e.getMessage());
                } finally {
                    visitedObjects.remove(subSelect);
                }
            }
        }

        @Override
        public void visit(AllColumns allColumns) {
            // 关键:避免处理AllColumns时的无限递归
            // 在JSqlParser 4.3中,这里不能调用super.visit(allColumns)
        }
    }

    /**
     * 避免查询无限递归
     */
    private class SafeSelectItemVisitor extends SelectItemVisitorAdapter {
        @Override
        public void visit(SelectExpressionItem item) {
            try {
                item.getExpression().accept(new SafeExpressionVisitor());
            } catch (Exception e) {
                System.out.println(e.getMessage());
            }
        }
    }
}

代码如下:

注意:

1、willDoQuery中

核心作用

拦截器开关控制

决定是否允许当前SQL查询继续执行(true放行,false拦截)

与跳过机制集成

通过检查SkipDataIsolation的线程状态,实现动态拦截控制

方法调用时机

sequenceDiagram

MyBatis->>StoreDataInterceptor: 执行查询前

StoreDataInterceptor->>willDoQuery: 检查拦截条件

alt 返回true

MyBatis->>DB: 正常执行查询

else 返回false

MyBatis->>调用方: 直接返回空结果

end

优先级高于SQL改写

若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)

典型使用场景

完全禁止无权限的查询(如未登录用户)

快速跳过无需处理的查询类型(如特定Mapper方法)

2、beforeQuery中

如果用户是超管用户则跳过拦截器-自己添加逻辑判断

addConditionsToSelect添加查询条件中SELECT中,获取当前用户的门店id权限使用模拟数据演示效果。

正常应该是从权限控制相关业务中获取,此处注意使用时修改。

自定义跳过数据隔离注解实现

复制代码
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SkipDataIsolationAnnotation {
}

跳过数据隔离切面实现

复制代码
/**
 * 跳过数据隔离切面
 */
@Aspect
@Component
public class SkipDataIsolationAspect {

    //Around增强:在方法执行前后插入逻辑
    @Around("@annotation(skipDataIsolationAnnotation)")
    public Object handleSkipDataIsolation(ProceedingJoinPoint joinPoint,
                                          SkipDataIsolationAnnotation skipDataIsolationAnnotation) throws Throwable {
        try {
            //进入方法时设置ThreadLocal标志为true
            SkipDataIsolation.setMethodSkipDataIsolation(true);// 设置线程标志
            return joinPoint.proceed();
        } finally {
            //通过try-finally确保异常时也能清理状态
            SkipDataIsolation.methodClear(); // 清理线程状态
        }
    }
}

上下文控制器实现

复制代码
/**
 * 上下文控制器
 */
public class SkipDataIsolation {

    // 单次sql语句级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志,默认不跳过value=false
    private static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION = ThreadLocal.withInitial(() -> false);
    // 方法级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志, 默认不跳过value=false
    private static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION_METHOD = ThreadLocal.withInitial(() -> false);

    /**
     * 单次sql级别:设置跳过数据隔离标志
     */
    public static void setSkipDataIsolation(Boolean skip) {
        SKIP_DATA_ISOLATION.set(skip);
    }

    /**
     * 单次sql级别:获取跳过数据隔离标志
     */
    public static Boolean getSkipDataIsolation() {
        return SKIP_DATA_ISOLATION.get();
    }

    /**
     * 单次sql级别:清理ThreadLocal,防止内存泄漏
     */
    public static void clear() {
        SKIP_DATA_ISOLATION.remove();
    }

    /**
     * 方法级别:设置跳过数据隔离标志
     */
    public static void setMethodSkipDataIsolation(Boolean skip) {
        SKIP_DATA_ISOLATION_METHOD.set(skip);
    }

    /**
     * 方法级别:获取跳过数据隔离标志
     */
    public static Boolean getMethodSkipDataIsolation() {
        return SKIP_DATA_ISOLATION_METHOD.get();
    }

    /**
     * 方法级别:清理ThreadLocal,防止内存泄漏
     */
    public static void methodClear() {
        SKIP_DATA_ISOLATION_METHOD.remove();
    }
}

测试效果

新建一个包含store_id字段的表,并生成5条数据,其中有两条数据store_id为1和2。

新建两个controller并且一个添加跳过数据隔离注解,一个不添加,执行同样的mp的条件查询。

不进行数据隔离的查询效果

带数据隔离的效果

完整示例代码以及SQL文件资源下载

https://download.csdn.net/download/BADAO_LIUMANG_QIZHI/92218402

相关推荐
间彧3 小时前
RocketMQ消息幂等控制:借助ConcurrentHashMap的putIfAbsent方法实现
后端
海边夕阳20063 小时前
深入解析volatile关键字:多线程环境下的内存可见性与指令重排序防护
java·开发语言·jvm·架构
ZeroKoop3 小时前
JDK版本管理工具JVMS
java·开发语言
rengang663 小时前
101-Spring AI Alibaba RAG 示例
java·人工智能·spring·rag·spring ai·ai应用编程
韩立学长3 小时前
【开题答辩实录分享】以《智慧校园勤工俭学信息管理系统的设计与实现》为例进行答辩实录分享
vue.js·spring boot·微信小程序
乾坤瞬间3 小时前
【Java后端进行ai coding实践系列二】记住规范,记住内容,如何使用iflow进行上下文管理
java·开发语言·ai编程
迦蓝叶3 小时前
JAiRouter v1.1.0 发布:把“API 调没调通”从 10 分钟压缩到 10 秒
java·人工智能·网关·openai·api·协议归一
不知道累,只知道类3 小时前
记一次诡异的“偶发 404”排查:CDN 回源到 OSS 导致 REST API 失败
java·云原生
lang201509283 小时前
Spring数据库连接控制全解析
java·数据库·spring