SpringBoot+MybatisPlus+自定义注解+切面实现水平数据隔离功能(附代码下载)

场景

业务场景中,需要对某些表中的数据做水平的数据隔离,比如某些表中如果含有某个字段,比如store_id(门店id)这个字段,

则对某些有对应门店权限的用户角色开放数据,如果请求的用户没有对该门店的权限,则自动对sql进行拦截添加where条件。

当然如果同一张表,又必须要查询全量数据,又可以通过添加自定义注解的方式,跳过数据隔离,返回全量数据。

并且如果用户没有任何门店的权限,或其他类似权限限制,则直接不执行查询,返回数据为空。

注:

博客:
https://blog.csdn.net/badao_liumang_qizhi

实现

新建SpringBoot项目,并引入相关依赖

如下依赖特别关注:

复制代码
        <!--MybatisPlus依赖-->
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-boot-starter</artifactId>
            <version>3.5.1</version>
        </dependency>
        <dependency>
            <groupId>org.aspectj</groupId>
            <artifactId>aspectjrt</artifactId>
            <version>1.9.7</version>
        </dependency>
        <dependency>
            <groupId>org.aspectj</groupId>
            <artifactId>aspectjweaver</artifactId>
            <version>1.9.7</version>
        </dependency>
        <dependency>
            <groupId>org.springframework</groupId>
            <artifactId>spring-aspects</artifactId>
        </dependency>

注意:

mybatis-plus-boot-starter 3.5.1 已包含 JSqlParser 依赖

所以此处不需要额外引入如下依赖:

复制代码
<dependency>
    <groupId>com.github.jsqlparser</groupId>
    <artifactId>jsqlparser</artifactId>
    <version>4.3</version> <!-- MyBatis-Plus 3.5.1 使用的版本 -->
</dependency>

另外还需引入其它非关键依赖,按需选择:

复制代码
       <!-- spring-boot -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- spring-boot-test -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
            <exclusions>
                <exclusion>
                    <groupId>org.junit.vintage</groupId>
                    <artifactId>junit-vintage-engine</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <!-- lombok -->
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>1.18.26</version>
            <scope>provided</scope>
        </dependency>
        <!-- 数据库连接 -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
        </dependency>

添加mybatisplus的配置类,在配置类中实现初始化表缓存、定期刷新表缓存、注册数据隔离拦截器操作

代码实现如下:

复制代码
@Configuration
@MapperScan("com.badao.demo.mapper")
public class MybatisPlusConfig {

    // 关键功能:
    // 1. 初始化时扫描数据库表结构(initTableCache)
    // 2. 定时刷新表结构缓存(scheduleCacheRefresh)
    // 3. 注册MyBatis-Plus拦截器链
    public MybatisPlusConfig(DataSource dataSource) {
        this.dataSource = dataSource;
        initTableCache();
        //每5分钟刷新缓存(应对表结构变更)
        scheduleCacheRefresh();
    }

    // 数据源
    private final DataSource dataSource;
    // 模式名称
    public static final String DATABASE_B_GAS_STATION = "test";
    // 含有门店id字段的数据表
    // 使用ConcurrentHashMap保证线程安全
    private final Set<String> tablesWithStoreId = Collections.newSetFromMap(new ConcurrentHashMap<>());

    @Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor();
        // 注册数据水平隔离拦截器
        interceptor.addInnerInterceptor(new StoreDataInterceptor(tablesWithStoreId));
        // 注册分页拦截器
        interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return interceptor;
    }


    /**
     * 初始化表缓存
     */
    private void initTableCache() {
        try (Connection conn = dataSource.getConnection()) {
            DatabaseMetaData metaData = conn.getMetaData();

            Map<String, Set<String>> tableColumnsMap = new HashMap<>();
            try (ResultSet columns = metaData.getColumns(DATABASE_B_GAS_STATION, null, "%", "%")) {
                while (columns.next()) {
                    String tableName = columns.getString("TABLE_NAME").toLowerCase();
                    String columnName = columns.getString("COLUMN_NAME").toLowerCase();
                    tableColumnsMap.computeIfAbsent(tableName, k -> new HashSet<>()).add(columnName);
                }
            }

            // 获取所有表
            try (ResultSet tables = metaData.getTables(DATABASE_B_GAS_STATION, null, "%", new String[]{"TABLE"})) {
                while (tables.next()) {
                    String tableName = tables.getString("TABLE_NAME").toLowerCase();
                    Set<String> columns = tableColumnsMap.getOrDefault(tableName, Collections.emptySet());

                    if (columns.contains(StoreDataInterceptor.STORE_ID)) {
                        tablesWithStoreId.add(tableName);
                    }
                }
            }
        } catch (Exception e) {
            throw new RuntimeException("SellerIso: Failed to init table cache", e);
        }
    }

    /**
     * 定时刷新表结构缓存
     */
    private void scheduleCacheRefresh() {
        //刷新表结构缓存
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        scheduler.scheduleAtFixedRate(this::initTableCache, 5, 5, TimeUnit.MINUTES);
    }
}

数据隔离拦截器实现代码

复制代码
public class StoreDataInterceptor implements InnerInterceptor {

    private final Set<String> tablesWithStoreId;

    //隔离字段column
    public static final String STORE_ID = "store_id";

    public StoreDataInterceptor(Set<String> tablesWithStoreId) {
        this.tablesWithStoreId = tablesWithStoreId;
    }

    /**
     * 优先级高于SQL改写
     * 若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)
     */
    @Override
    public boolean willDoQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        // 指定跳过数据隔离
        if (SkipDataIsolation.getMethodSkipDataIsolation()) {
            return true;
        }
        // 其它业务逻辑则不予查询,比如获取请求头中的数据做权限校验,完全禁止无权限的查询(如未登录用户)
//        if(!CollectionUtils.isEmpty(UserContextHolder.getStoreIds())
//        {
//            return false;
//        }
        return true;
    }

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds,
                            ResultHandler resultHandler, BoundSql boundSql) {
        // 如果用户是超管用户则跳过拦截器-自己添加逻辑判断
        if (false) {
            return;
        }
        // 指定跳过数据隔离
        if (SkipDataIsolation.getMethodSkipDataIsolation()) {
            return;
        }

        String sql = boundSql.getSql();
        try {
            //解析SQL并重写
            Select select = (Select) CCJSqlParserUtil.parse(sql);
            SelectBody selectBody = select.getSelectBody();

            // 递归处理所有SELECT部分
            processSelectBody(selectBody);

            PluginUtils.mpBoundSql(boundSql).sql(select.toString());
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
    }

    /**
     * 处理PlainSelect
     */
    private void processSelectBody(SelectBody selectBody) {
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof SetOperationList) {
            // 处理UNION/INTERSECT等
            for (SelectBody body : ((SetOperationList) selectBody).getSelects()) {
                processSelectBody(body);
            }
        }
        // 其他类型如WithItem暂不处理
    }

    /**
     * 处理FROM项
     */
    private void processPlainSelect(PlainSelect plainSelect) {
        // 1. 处理FROM项
        Map<String, String> aliasTableMap = new HashMap<>();
        processFromItem(plainSelect.getFromItem(), aliasTableMap);

        // 2. 处理JOIN表
        if (plainSelect.getJoins() != null) {
            for (Join join : plainSelect.getJoins()) {
                processFromItem(join.getRightItem(), aliasTableMap);
            }
        }

        // 3. 添加条件到当前SELECT
        addConditionsToSelect(plainSelect, aliasTableMap);

        // 4. 递归处理子查询
        processSubQueries(plainSelect);
    }

    /**
     * 处理查询
     */
    private void processFromItem(FromItem fromItem, Map<String, String> aliasTableMap) {
        if (fromItem instanceof Table) {
            Table table = (Table) fromItem;
            String tableName = table.getName().toLowerCase();
            String alias = table.getAlias() != null ?
                    table.getAlias().getName().toLowerCase() : tableName;

            // 缓存别名映射
            aliasTableMap.put(alias, tableName);
        } else if (fromItem instanceof SubSelect) {
            // 处理子查询
            processSelectBody(((SubSelect) fromItem).getSelectBody());
        }
    }

    /**
     * 处理子查询
     */
    private void processSubQueries(PlainSelect plainSelect) {
        // 1. 处理WHERE子句中的子查询
        if (plainSelect.getWhere() != null) {
            plainSelect.getWhere().accept(new SafeExpressionVisitor());
        }

        // 2. 处理SELECT列表中的子查询
        for (SelectItem item : plainSelect.getSelectItems()) {
            item.accept(new SafeSelectItemVisitor());
        }
    }

    /**
     * 添加查询条件到SELECT
     */
    private void addConditionsToSelect(PlainSelect plainSelect, Map<String, String> aliasTableMap) {
        // 检查哪些表需要添加条件
        for (Map.Entry<String, String> entry : aliasTableMap.entrySet()) {
            String alias = entry.getKey();
            String tableName = entry.getValue();
            //List<String> storeIds = UserContextHolder.getStoreIds();
            //此处用模拟数据示例
            List<String> storeIds = new ArrayList(){{
                this.add("1");
                this.add("2");
            }};
            //对含store_id的表自动添加条件:
            if (tablesWithStoreId.contains(tableName)) {
                handleSelectSql(alias, plainSelect, storeIds);
            }
        }
    }

    /**
     * 创建查询表达式
     */
    private static void handleSelectSql(String alias, PlainSelect plainSelect,
                                        List<String> companyChannelIds) {
        // 创建条件表达式
        Column channelColumn = new Column(alias + "." + StoreDataInterceptor.STORE_ID);

        // 创建表达式列表
        ExpressionList expressionList = new ExpressionList();
        // 手动初始化expressions列表
        expressionList.setExpressions(new ArrayList<>());

        for (String id : companyChannelIds) {
            expressionList.getExpressions().add(new StringValue(id));
        }

        // 构建条件表达式:WHERE (store_id IN (1,2) OR store_id IS NULL)
        // 创建IN表达式
        InExpression inExpression = new InExpression(channelColumn, expressionList);
        // 添加or 数据隔离字段is null 条件避免联表查询时未能关联数据导致全部数据被过滤
        IsNullExpression isNullExpression = new IsNullExpression();
        isNullExpression.setLeftExpression(channelColumn);
        OrExpression orExpression = new OrExpression(isNullExpression, inExpression);
        // 调整or条件优先级 加()
        Parenthesis parenthesis = new Parenthesis(orExpression);

        // 获取现有WHERE条件
        Expression where = plainSelect.getWhere();
        plainSelect.setWhere(where == null ? parenthesis : new AndExpression(where, parenthesis));
    }

    /**
     * 避免查询无限递归
     */
    private class SafeExpressionVisitor extends ExpressionVisitorAdapter {
        private final Set<Object> visitedObjects = Collections.newSetFromMap(new IdentityHashMap<>());

        @Override
        public void visit(SubSelect subSelect) {
            // 防止SubSelect无限递归
            if (visitedObjects.add(subSelect)) {
                try {
                    // 限制递归深度
                    if (visitedObjects.size() < 50) {
                        processSelectBody(subSelect.getSelectBody());
                    }
                } catch (Exception e) {
                    System.out.println(e.getMessage());
                } finally {
                    visitedObjects.remove(subSelect);
                }
            }
        }

        @Override
        public void visit(AllColumns allColumns) {
            // 关键:避免处理AllColumns时的无限递归
            // 在JSqlParser 4.3中,这里不能调用super.visit(allColumns)
        }
    }

    /**
     * 避免查询无限递归
     */
    private class SafeSelectItemVisitor extends SelectItemVisitorAdapter {
        @Override
        public void visit(SelectExpressionItem item) {
            try {
                item.getExpression().accept(new SafeExpressionVisitor());
            } catch (Exception e) {
                System.out.println(e.getMessage());
            }
        }
    }
}

代码如下:

注意:

1、willDoQuery中

核心作用

拦截器开关控制

决定是否允许当前SQL查询继续执行(true放行,false拦截)

与跳过机制集成

通过检查SkipDataIsolation的线程状态,实现动态拦截控制

方法调用时机

sequenceDiagram

MyBatis->>StoreDataInterceptor: 执行查询前

StoreDataInterceptor->>willDoQuery: 检查拦截条件

alt 返回true

MyBatis->>DB: 正常执行查询

else 返回false

MyBatis->>调用方: 直接返回空结果

end

优先级高于SQL改写

若返回false,则不会触发后续的beforeQuery(SQL重写逻辑)

典型使用场景

完全禁止无权限的查询(如未登录用户)

快速跳过无需处理的查询类型(如特定Mapper方法)

2、beforeQuery中

如果用户是超管用户则跳过拦截器-自己添加逻辑判断

addConditionsToSelect添加查询条件中SELECT中,获取当前用户的门店id权限使用模拟数据演示效果。

正常应该是从权限控制相关业务中获取,此处注意使用时修改。

自定义跳过数据隔离注解实现

复制代码
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface SkipDataIsolationAnnotation {
}

跳过数据隔离切面实现

复制代码
/**
 * 跳过数据隔离切面
 */
@Aspect
@Component
public class SkipDataIsolationAspect {

    //Around增强:在方法执行前后插入逻辑
    @Around("@annotation(skipDataIsolationAnnotation)")
    public Object handleSkipDataIsolation(ProceedingJoinPoint joinPoint,
                                          SkipDataIsolationAnnotation skipDataIsolationAnnotation) throws Throwable {
        try {
            //进入方法时设置ThreadLocal标志为true
            SkipDataIsolation.setMethodSkipDataIsolation(true);// 设置线程标志
            return joinPoint.proceed();
        } finally {
            //通过try-finally确保异常时也能清理状态
            SkipDataIsolation.methodClear(); // 清理线程状态
        }
    }
}

上下文控制器实现

复制代码
/**
 * 上下文控制器
 */
public class SkipDataIsolation {

    // 单次sql语句级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志,默认不跳过value=false
    private static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION = ThreadLocal.withInitial(() -> false);
    // 方法级别跳过数据隔离: 使用ThreadLocal存储跳过数据隔离的标志, 默认不跳过value=false
    private static final ThreadLocal<Boolean> SKIP_DATA_ISOLATION_METHOD = ThreadLocal.withInitial(() -> false);

    /**
     * 单次sql级别:设置跳过数据隔离标志
     */
    public static void setSkipDataIsolation(Boolean skip) {
        SKIP_DATA_ISOLATION.set(skip);
    }

    /**
     * 单次sql级别:获取跳过数据隔离标志
     */
    public static Boolean getSkipDataIsolation() {
        return SKIP_DATA_ISOLATION.get();
    }

    /**
     * 单次sql级别:清理ThreadLocal,防止内存泄漏
     */
    public static void clear() {
        SKIP_DATA_ISOLATION.remove();
    }

    /**
     * 方法级别:设置跳过数据隔离标志
     */
    public static void setMethodSkipDataIsolation(Boolean skip) {
        SKIP_DATA_ISOLATION_METHOD.set(skip);
    }

    /**
     * 方法级别:获取跳过数据隔离标志
     */
    public static Boolean getMethodSkipDataIsolation() {
        return SKIP_DATA_ISOLATION_METHOD.get();
    }

    /**
     * 方法级别:清理ThreadLocal,防止内存泄漏
     */
    public static void methodClear() {
        SKIP_DATA_ISOLATION_METHOD.remove();
    }
}

测试效果

新建一个包含store_id字段的表,并生成5条数据,其中有两条数据store_id为1和2。

新建两个controller并且一个添加跳过数据隔离注解,一个不添加,执行同样的mp的条件查询。

不进行数据隔离的查询效果

带数据隔离的效果

完整示例代码以及SQL文件资源下载

https://download.csdn.net/download/BADAO_LIUMANG_QIZHI/92218402

相关推荐
陌殇殇16 分钟前
001 Spring AI Alibaba框架整合百炼大模型平台 — 快速入门
人工智能·spring boot·ai
言慢行善30 分钟前
sqlserver模糊查询问题
java·数据库·sqlserver
专吃海绵宝宝菠萝屋的派大星35 分钟前
使用Dify对接自己开发的mcp
java·服务器·前端
大数据新鸟1 小时前
操作系统之虚拟内存
java·服务器·网络
Tong Z1 小时前
常见的限流算法和实现原理
java·开发语言
凭君语未可1 小时前
Java 中的实现类是什么
java·开发语言
He少年1 小时前
【基础知识、Skill、Rules和MCP案例介绍】
java·前端·python
克里斯蒂亚诺更新1 小时前
myeclipse的pojie
java·ide·myeclipse
迷藏4941 小时前
**eBPF实战进阶:从零构建网络流量监控与过滤系统**在现代云原生架构中,**网络可观测性**和**安全隔离**已成为
java·网络·python·云原生·架构