Java 数据权限过滤
使用切面
- 数据权限过滤注解
java
/**
 * Marks a method whose query should be restricted by a data-permission filter.
 *
 * <p>The annotation only carries the pieces of the condition (table alias,
 * column and operator); it is retained at runtime so that it can be read
 * reflectively.</p>
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface DataFilter {

    /** Operator constant for an equality comparison. */
    String FILTER_TYPE_EQ = "=";

    /** Operator constant for a set-membership comparison. */
    String FILTER_TYPE_IN = "IN";

    /**
     * Table alias used to qualify the filtered column.
     *
     * @return the alias, or an empty string when the column is unqualified
     */
    String alias() default "";

    /**
     * Column the filter is applied to.
     *
     * @return the column name, empty by default
     */
    String field() default "";

    /**
     * Operator placed between the column and the filter value.
     *
     * @return the operator, {@link #FILTER_TYPE_IN} by default
     */
    String type() default FILTER_TYPE_IN;
}
- 基础查询类:想要使用数据过滤的方法,其第一个参数的类型必须是该类或其子类(切面只检查第一个参数)。
java
/**
 * Base query object that carries the data-permission condition alongside the
 * caller's own query fields.
 *
 * <p>Both fields are excluded from JSON via {@code @JsonIgnore}.</p>
 */
@Data
public class Query {

    /**
     * SQL condition fragment (without a leading {@code AND}) that the mapper
     * splices into the statement.
     */
    @JsonIgnore
    private String clause;

    /**
     * Right-hand operand of the filter condition, used verbatim when present.
     */
    @JsonIgnore
    private String filterParams;
}
- 切面类
java
/**
 * Aspect that prepares the data-permission SQL condition for methods annotated
 * with {@link DataFilter}.
 *
 * <p>The condition is written to {@link Query#setClause(String)} on the first
 * argument of the intercepted method; the mapper SQL is expected to splice it
 * in (for example with {@code AND ${query.clause}}).</p>
 *
 * <p>NOTE(review): {@code clause} and {@code filterParams} end up in the SQL
 * text verbatim, and a non-empty {@code clause} found on the incoming
 * {@code Query} is kept as-is. Both must therefore only ever be populated by
 * trusted server-side code. If {@code Query} objects are bound from request
 * parameters, a client could supply its own values and bypass the filter;
 * confirm that binding excludes these two properties.</p>
 */
@Aspect
@Component
public class DataFilterAspect {

    private static final Logger log = LoggerFactory.getLogger(DataFilterAspect.class);

    /**
     * Builds the filter condition before the annotated method runs.
     *
     * <p>Nothing is written when the first argument is not a {@link Query},
     * when the annotation has no field or type, or when the query already
     * carries a clause. For an administrator the clause is reset to
     * {@code null} so that no filtering takes place.</p>
     *
     * @param point      the intercepted join point
     * @param dataFilter the annotation instance found on the intercepted method
     */
    @Before("@annotation(dataFilter)")
    public void doBefore(JoinPoint point, DataFilter dataFilter) {
        String methodName = point.getSignature().getName();

        // 1. The condition travels on the first argument. A method without
        //    parameters must be skipped instead of failing on args[0].
        Object[] args = point.getArgs();
        if (args == null || args.length == 0 || !(args[0] instanceof Query)) {
            log.warn("调用方法[{}], 参数不是Query对象", methodName);
            return;
        }
        Query query = (Query) args[0];

        // Resolve the current user; the super administrator is never filtered.
        LoginUser loginUser = SecurityUtils.getLoginUser();
        SysUser user = loginUser.getUser();
        if (user.isAdmin()) {
            query.setClause(null);
            return;
        }

        // 2. Validate the annotation attributes.
        String field = dataFilter.field();
        if (StringUtils.isEmpty(field)) {
            log.warn("调用方法[{}], 数据过滤注解的field参数为空", methodName);
            return;
        }
        String type = dataFilter.type();
        if (StringUtils.isEmpty(type)) {
            log.warn("调用方法[{}], 数据过滤注解的type参数为空", methodName);
            return;
        }

        // 3. Keep a clause that has already been set on the query.
        if (StringUtils.isNotEmpty(query.getClause())) {
            return;
        }

        // Qualify the column once instead of duplicating every format string.
        String alias = dataFilter.alias();
        String column = StringUtils.isNotEmpty(alias) ? alias + "." + field : field;

        String filterParams = query.getFilterParams();
        String clause;
        if (StringUtils.isNotEmpty(filterParams)) {
            // Explicit operand supplied on the query: "<column> <type> <operand>".
            clause = StringUtils.format("{} {} {}", column, type, filterParams);
        } else {
            // Default: restrict to the companies linked to the current user.
            clause = StringUtils.format(
                    "{} IN ( SELECT company_code FROM sys_user_company WHERE user_id = {} )",
                    column, user.getUserId());
        }
        query.setClause(clause);
    }
}
- 使用
- mapper.xml 里的sql需要加入如下内容:
xml
<if test="query.clause != null and query.clause != ''"> AND ${query.clause} </if>
- mapper类方法上加上注解:
@DataFilter(alias = "t1", field = "company_code")
使用 MyBatis 拦截器
- 工具类
java
/**
 * Holder for the data-filter SQL clause of the current thread.
 *
 * <p>The clause is kept in a {@link ThreadLocal}; callers register it with
 * {@link #startFilter(String)} and are responsible for calling
 * {@link #clear()} once the filtered statements have run.</p>
 */
public class DataFilter {

    /** SQL condition clause registered for the current thread. */
    protected static final ThreadLocal<String> LOCAL_CLAUSE = new ThreadLocal<>();

    /**
     * Starts filtering with the given SQL clause.
     *
     * <p>The statements being executed must already contain a {@code WHERE}
     * clause, for example {@code WHERE 1 = 1}.</p>
     *
     * @param clause the SQL condition to apply
     */
    public static void startFilter(String clause) {
        setClause(clause);
    }

    /**
     * Returns the clause registered for the current thread.
     *
     * @return the clause, or {@code null} when none is set
     */
    public static String getClause() {
        final String current = LOCAL_CLAUSE.get();
        return current;
    }

    /** Removes the clause registered for the current thread. */
    public static void clear() {
        LOCAL_CLAUSE.remove();
    }

    /**
     * Stores the clause for the current thread.
     *
     * @param clause the SQL condition to store
     */
    protected static void setClause(String clause) {
        LOCAL_CLAUSE.set(clause);
    }
}
- 拦截器
java
@Intercepts({
// 拦截Executor的query方法(不带CacheKey和BoundSql参数)
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
// 拦截Executor的query方法(带CacheKey和BoundSql参数)
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
// 拦截Executor的update方法(包括update、insert、delete语句)方法
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
})
public class DataFilterInterceptor implements Interceptor {
public DataFilterInterceptor() {
log.info("数据过滤拦截器初始化成功");
}
private static final Logger log = LoggerFactory.getLogger(DataFilterInterceptor.class);
/**
* 拦截方法,在SQL执行前进行处理
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
// 获取数据过滤条件
String clause = DataFilter.getClause();
// 如果没有过滤条件,直接执行原SQL
if (clause == null) {
return invocation.proceed();
}
log.info("数据过滤拦截器-附加sql子句:{}", clause);
// 1.获取原始sql
Object[] args = invocation.getArgs();
MappedStatement ms = (MappedStatement) args[0];
BoundSql boundSql;
// 根据参数数量判断是哪种query方法
if (args.length == 6) {
// 带CacheKey和BoundSql的query方法
boundSql = (BoundSql) args[5];
} else {
// 普通query或update方法
boundSql = ms.getBoundSql(args[1]);
}
String originalSql = boundSql.getSql();
// 2.改写sql,添加数据过滤条件
String newSql = appendDataFilter(originalSql, clause);
// 3.替换SQL
BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
boundSql.getParameterMappings(), boundSql.getParameterObject());
if (args.length == 6) {
// 更新带CacheKey和BoundSql参数的query方法中的BoundSql
args[5] = newBoundSql;
} else {
// 替换普通query或update方法中的MappedStatement
MappedStatement newMs = copyFromMappedStatement(ms, new BoundSqlSqlSource(newBoundSql));
args[0] = newMs;
}
// 4.执行方法
return invocation.proceed();
// 注意:清理ThreadLocal数据不能放在这,因为使用分页拦截器时会执行count和select两个语句,
// 清除数据后会导致只有count语句会被加上过滤条件
}
/**
* 在原始SQL中添加数据过滤条件
*/
private String appendDataFilter(String originalSql, String clause) {
String lowerSql = originalSql.toLowerCase();
// 查找order by 和 limit 的位置
int orderByIndex = lowerSql.lastIndexOf("order by");
int limitIndex = lowerSql.lastIndexOf("limit");
if (orderByIndex != -1) {
// 存在order by,将条件插入到 order by前
String beforeOrderBy = originalSql.substring(0, orderByIndex);
String afterOrderBy = originalSql.substring(orderByIndex);
return beforeOrderBy + " AND" + clause + " " + afterOrderBy;
} else {
// 不存在order by,存在limit,将条件插入到 limit 前
if (limitIndex != -1) {
String beforeLimit = originalSql.substring(0, limitIndex);
String afterLimit = originalSql.substring(limitIndex);
return beforeLimit + " AND" + clause + " " + afterLimit;
}
// 不存在order by,不存在limit,将条件插入到sql的末尾
return originalSql + " AND" + clause;
}
}
/**
* 复制 MappedStatement 并替换 BoundSql
*/
private MappedStatement copyFromMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(
ms.getConfiguration(),
ms.getId(),
newSqlSource,
ms.getSqlCommandType()
);
// 复制其他属性(如 resultMaps、parameterMap 等)
builder.resource(ms.getResource());
builder.fetchSize(ms.getFetchSize());
builder.statementType(ms.getStatementType());
builder.keyGenerator(ms.getKeyGenerator());
builder.timeout(ms.getTimeout());
builder.parameterMap(ms.getParameterMap());
builder.resultMaps(ms.getResultMaps());
builder.resultSetType(ms.getResultSetType());
builder.cache(ms.getCache());
builder.flushCacheRequired(ms.isFlushCacheRequired());
builder.useCache(ms.isUseCache());
return builder.build();
}
/**
* 自定义 SqlSource,用于动态 SQL
*/
public static class BoundSqlSqlSource implements SqlSource {
private final BoundSql boundSql;
public BoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
}
- 注册拦截器
java
/**
 * Plain MyBatis setup: builds the {@code SqlSessionFactory} by hand and
 * registers the data-filter interceptor as a plugin.
 */
@Configuration
public class MyBatisConfig {

    /**
     * Creates the session factory with {@code DataFilterInterceptor} installed.
     *
     * @param dataSource the data source the factory works on
     * @return the configured session factory
     * @throws Exception if the factory bean fails to build the factory
     */
    @Bean
    public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
        SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
        factoryBean.setDataSource(dataSource);
        factoryBean.setPlugins(new DataFilterInterceptor());
        return factoryBean.getObject();
    }
}
// Variant for MyBatis-Plus projects
/**
 * Exposes the data-filter interceptor as a Spring bean.
 *
 * <p>NOTE(review): presumably the bean is picked up and registered as a
 * plugin by the framework's auto-configuration; confirm for the version in
 * use.</p>
 */
@Configuration
public class MyBatisPlusConfig {

    /**
     * @return a new data-filter interceptor instance
     */
    @Bean
    public DataFilterInterceptor dataFilterInterceptor() {
        return new DataFilterInterceptor();
    }
}
- 使用
示例:
java
/**
 * Lists hotel orders, restricted to the department codes resolved for the
 * current department.
 *
 * @param dto the query conditions
 * @return the matching orders wrapped in a success response
 */
@GetMapping("/list")
public R<List<HotelOrderVo>> list(HotelOrderReqDto dto) {
    // Quote every code as a SQL string literal. Embedded single quotes are
    // doubled as a minimal safeguard; the codes are assumed to come from
    // trusted data, not from the request.
    List<String> codeList = sysDeptService.selectCodeList(getDeptId()).stream()
            .map(code -> "'" + String.valueOf(code).replace("'", "''") + "'")
            .collect(Collectors.toList());

    // "IN ()" is not valid SQL: with no visible department the filter becomes
    // an always-false predicate so that nothing is returned.
    String clause = codeList.isEmpty()
            ? " 1 = 0"
            : " dept_id IN (" + StringUtils.join(codeList, ",") + ")";

    List<HotelOrderVo> list;
    DataFilter.startFilter(clause);
    try {
        //PageUtils.startPage();
        list = hotelOrderService.list(dto);
    } finally {
        // Always clear: request threads are pooled, and a clause left behind
        // after an exception would filter the next request on this thread.
        DataFilter.clear();
    }
    // With a pagination interceptor the returned list is actually a Page
    // object that extends List.
    return R.ok(list);
}
相关过滤条件也可以直接写死在拦截器里。