[Source Code] Sharding-JDBC Source Analysis: How SQL Parsing Works

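Sharding-JDBC Series

1. Basic usage of database and table sharding with Sharding-JDBC

2. Sharding-JDBC database and table sharding: Spring Boot sharding strategies

3. Sharding-JDBC database and table sharding: Spring Boot primary-replica configuration

4. Integrating Sharding-JDBC 5.3.0 into Spring Boot for database and table sharding

5. Integrating Sharding-JDBC 5.3.0 into Spring Boot to create and shard tables dynamically by month

6. [Source Code] Sharding-JDBC Source Analysis: JDBC

7. [Source Code] Sharding-JDBC Source Analysis: The SPI Mechanism

8. [Source Code] Sharding-JDBC Source Analysis: How the YAML Sharding Configuration File Is Parsed

9. [Source Code] Sharding-JDBC Source Analysis: How YAML Sharding Configuration Works (Part 1)

10. [Source Code] Sharding-JDBC Source Analysis: How YAML Sharding Configuration Works (Part 2)

11. [Source Code] Sharding-JDBC Source Analysis: How the YAML Sharding Configuration Is Converted

12. [Source Code] Sharding-JDBC Source Analysis: How ShardingSphereDataSource Is Created

13. [Source Code] Sharding-JDBC Source Analysis: How the Mode Sharding Configuration Is Persisted During ContextManager Creation

14. [Source Code] Sharding-JDBC Source Analysis: How ShardingSphereDatabase Is Created During ContextManager Creation

15. [Source Code] Sharding-JDBC Source Analysis: How the Sharding Rule Builder DatabaseRuleBuilder Turns Rule Configurations into Rule Objects

16. [Source Code] Sharding-JDBC Source Analysis: How Metadata of Tables Defined in the Configured Databases Is Parsed

17. [Source Code] Sharding-JDBC Source Analysis: How ShardingSphereConnection Is Created

18. [Source Code] Sharding-JDBC Source Analysis: How SQL Parsing Works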

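Preface

Sharding-JDBC is an open-source middleware for database and table sharding. It parses each SQL statement, rewrites it according to the configured sharding rules, and routes it to the correct databases and tables. Sharding-JDBC uses ANTLR to parse SQL. Taking MySQL as the example, this article analyzes how SQL parsing is implemented, from the point of view of the source code.

For more on ANTLR, see the CSDN blog post "Antlr的使用" (Using ANTLR).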

MySQLStatement.g4

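Before ANTLR can parse anything, you need a grammar file that describes the text to be parsed; the file name ends with the .g4 suffix. In the source code these files are under the sql-parser folder. The supported databases are shown in the figure below:

For MySQL, the grammar file is MySQLStatement.g4, located in

Part of the MySQLStatement.g4 source is shown below: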

grammar MySQLStatement;

import Comments, DDLStatement, TCLStatement, DCLStatement, RLStatement;

execute
    : (select    // query
    | insert     // insert
    | update
    | delete
    | replace

// other alternatives omitted

    | delimiter
    ) (SEMI_ EOF? | EOF)
    | EOF
    ;

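Sharding-JDBC supports almost all MySQL statements. For the most commonly used ones, the DML statements, the grammar file is DMLStatement.g4. The analysis below uses the select statement as the example.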

DMLStatement.g4

Part of the DMLStatement.g4 source is shown below:

grammar DMLStatement;

import BaseRule;

// select statement
select
    : queryExpression lockClauseList?  // query expression
    | queryExpressionParens
    | selectWithInto
    ;

// query expression
queryExpression
    : withClause? (queryExpressionBody | queryExpressionParens) orderByClause? limitClause?
    ;

// query expression body
queryExpressionBody
    : queryPrimary     // primary query
    | queryExpressionParens combineClause
    | queryExpressionBody combineClause
    ;

combineClause
    : UNION combineOption? (queryPrimary | queryExpressionParens)
    ;

queryExpressionParens
    : LP_ (queryExpressionParens | queryExpression lockClauseList?) RP_
    ;

// primary query
queryPrimary
    : querySpecification
    | tableValueConstructor
    | tableStatement
    ;

// query specification
querySpecification
    : SELECT selectSpecification* projections selectIntoExpression? fromClause? whereClause? groupByClause? havingClause? windowClause?
    ;

// where clause
whereClause
    : WHERE expr
    ;

// projections
projections
    : (unqualifiedShorthand | projection) (COMMA_ projection)*
    ;

projection
    : expr (AS? alias)? | qualifiedShorthand
    ;


// other rules omitted

DMLStatement.g4 imports the rules defined in BaseRule.g4. According to the querySpecification rule, a select statement must begin with the SELECT keyword.
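
The left-recursive alternative queryExpressionBody combineClause means that a chain of UNIONs is grouped from the left. The sketch below is a plain-Java illustration of that nesting and is not ShardingSphere code: UnionNestingSketch, Body and nest are hypothetical names, and q1, q2, q3 stand in for real queryPrimary fragments. It assumes ANTLR's usual left-associative handling of direct left recursion.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: none of these types exist in ShardingSphere or ANTLR.
final class UnionNestingSketch {

    /** Either a leaf (a queryPrimary) or a left body combined with one more primary. */
    static final class Body {
        final Body left;       // null for a leaf
        final String primary;  // the leaf text, or the primary on the right of UNION

        Body(final Body left, final String primary) {
            this.left = left;
            this.primary = primary;
        }

        @Override
        public String toString() {
            return null == left ? primary : "(" + left + " UNION " + primary + ")";
        }
    }

    /** Folds the primaries left to right, mirroring: queryExpressionBody combineClause. */
    static Body nest(final List<String> primaries) {
        Body result = new Body(null, primaries.get(0));
        for (int i = 1; i < primaries.size(); i++) {
            // The body built so far becomes the left side of the next combination.
            result = new Body(result, primaries.get(i));
        }
        return result;
    }

    public static void main(final String[] args) {
        System.out.println(nest(Arrays.asList("q1", "q2", "q3")));
    }
}
```

Running main prints ((q1 UNION q2) UNION q3): the earlier queries sit deeper on the left, and each new primary is attached as the right side of the outermost combination.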

BaseRule.g4

Part of the BaseRule.g4 source is shown below:

// expression
expr
    : booleanPrimary   // primary boolean expression
    | expr andOperator expr  // AND expression
    | expr orOperator expr
    | expr XOR expr
    | notOperator expr
    ;
    
andOperator
    : AND | AND_
    ;
    
orOperator
    : OR | OR_
    ;
    
notOperator
    : NOT | NOT_
    ;
    
booleanPrimary
    : booleanPrimary IS NOT? (TRUE | FALSE | UNKNOWN | NULL)
    | booleanPrimary SAFE_EQ_ predicate
    | booleanPrimary comparisonOperator predicate    // comparison expressions such as = and >
    | booleanPrimary comparisonOperator (ALL | ANY) subquery
    | booleanPrimary assignmentOperator predicate
    | predicate
    ;
    
assignmentOperator
    : EQ_ | ASSIGNMENT_
    ;
    
comparisonOperator
    : EQ_ | GTE_ | GT_ | LTE_ | LT_ | NEQ_
    ;
    
predicate
    : bitExpr NOT? IN subquery
    | bitExpr NOT? IN LP_ expr (COMMA_ expr)* RP_
    | bitExpr NOT? BETWEEN bitExpr AND predicate
    | bitExpr SOUNDS LIKE bitExpr
    | bitExpr NOT? LIKE simpleExpr (ESCAPE simpleExpr)?
    | bitExpr NOT? REGEXP bitExpr
    | bitExpr
    ;
    
bitExpr
    : bitExpr VERTICAL_BAR_ bitExpr
    | bitExpr AMPERSAND_ bitExpr
    | bitExpr SIGNED_LEFT_SHIFT_ bitExpr
    | bitExpr SIGNED_RIGHT_SHIFT_ bitExpr
    | bitExpr PLUS_ bitExpr
    | bitExpr MINUS_ bitExpr
    | bitExpr ASTERISK_ bitExpr
    | bitExpr SLASH_ bitExpr
    | bitExpr DIV bitExpr
    | bitExpr MOD bitExpr
    | bitExpr MOD_ bitExpr
    | bitExpr CARET_ bitExpr
    | bitExpr PLUS_ intervalExpression
    | bitExpr MINUS_ intervalExpression
    | simpleExpr
    ;
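
To see how the expr, booleanPrimary and predicate rules stack on top of each other, here is a minimal hand-written recursive-descent sketch. It is not how ANTLR generates its parser and it is not ShardingSphere code: ExprChainSketch and its methods are hypothetical, only AND and the comparison operators are covered, tokens must be separated by whitespace, and the spellings assumed for NEQ_ (!= and <>) are a guess. The returned strings only loosely echo the kinds of segment the real visitor creates.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the rule layering expr -> booleanPrimary -> predicate -> simpleExpr.
final class ExprChainSketch {

    // Assumed spellings of the comparisonOperator tokens.
    private static final List<String> COMPARISON = Arrays.asList("=", ">=", ">", "<=", "<", "!=", "<>");

    private final String[] tokens;
    private int pos;

    private ExprChainSketch(final String text) {
        tokens = text.trim().split("\\s+");
    }

    static String parse(final String text) {
        ExprChainSketch parser = new ExprChainSketch(text);
        String result = parser.expr();
        if (parser.pos != parser.tokens.length) {
            throw new IllegalArgumentException("Unexpected token: " + parser.tokens[parser.pos]);
        }
        return result;
    }

    // expr : booleanPrimary | expr andOperator expr   (only AND is sketched)
    private String expr() {
        String left = booleanPrimary();
        while (pos < tokens.length && "AND".equalsIgnoreCase(tokens[pos])) {
            pos++;
            left = "AND(" + left + ", " + booleanPrimary() + ")";
        }
        return left;
    }

    // booleanPrimary : booleanPrimary comparisonOperator predicate | predicate
    private String booleanPrimary() {
        String left = predicate();
        while (pos < tokens.length && COMPARISON.contains(tokens[pos])) {
            String operator = tokens[pos++];
            left = "Binary(" + left + " " + operator + " " + predicate() + ")";
        }
        return left;
    }

    // predicate : bitExpr, bitExpr : simpleExpr   (operators and other alternatives omitted)
    private String predicate() {
        if (pos >= tokens.length || tokens[pos].isEmpty()) {
            throw new IllegalArgumentException("Unexpected end of input");
        }
        String token = tokens[pos++];
        return Character.isDigit(token.charAt(0)) ? "Literal(" + token + ")" : "Column(" + token + ")";
    }

    public static void main(final String[] args) {
        System.out.println(parse("id = 1"));
    }
}
```

For id = 1 the sketch returns Binary(Column(id) = Literal(1)): the comparison is recognized at the booleanPrimary level, while both operands fall through predicate down to the simplest expression.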
    

A Simple Example

Take the following select statement as an example:

select id, name from tb_user where id = 1

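The statement matches the querySpecification rule in DMLStatement.g4 as follows: select matches the SELECT keyword, id, name matches projections, from tb_user matches fromClause, and where id = 1 matches whereClause.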
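
As a rough illustration of which part of the statement ends up under which clause of querySpecification, the sketch below splits a simple select with a regular expression. This is a deliberately simplified approximation, not the ANTLR grammar and not ShardingSphere code: QuerySpecificationSketch, split and the map keys are hypothetical, and only the SELECT, FROM and WHERE parts are handled.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical sketch: a regex approximation of SELECT projections fromClause? whereClause?
final class QuerySpecificationSketch {

    private static final Pattern SIMPLE_SELECT = Pattern.compile(
            "^\\s*select\\s+(?<projections>.+?)(?:\\s+from\\s+(?<from>.+?))?(?:\\s+where\\s+(?<where>.+?))?\\s*$",
            Pattern.CASE_INSENSITIVE);

    /** Returns the text found for each clause, in the order the clauses appear in the rule. */
    static Map<String, String> split(final String sql) {
        Matcher matcher = SIMPLE_SELECT.matcher(sql);
        if (!matcher.matches()) {
            // Mirrors the grammar requirement that the statement starts with SELECT.
            throw new IllegalArgumentException("Not a simple select statement: " + sql);
        }
        Map<String, String> result = new LinkedHashMap<>();
        result.put("projections", matcher.group("projections"));
        if (null != matcher.group("from")) {
            result.put("tableReferences", matcher.group("from"));
        }
        if (null != matcher.group("where")) {
            result.put("whereExpr", matcher.group("where"));
        }
        return result;
    }

    public static void main(final String[] args) {
        System.out.println(split("select id, name from tb_user where id = 1"));
    }
}
```

For the example statement the map holds id, name as the projections, tb_user as the table references and id = 1 as the where expression. A real parser is needed as soon as subqueries, joins or keywords inside string literals appear, which is why Sharding-JDBC relies on the ANTLR grammar instead.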

MySQLStatementSQLVisitor

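Based on the MySQLStatement.g4 file, ANTLR's Generate ANTLR Recognizer action generates MySQLStatementBaseVisitor.java automatically. MySQLStatementSQLVisitor extends MySQLStatementBaseVisitor and overrides the relevant visit methods to create the objects that represent the parsed SQL.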
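
The override-what-you-need idea can be sketched without ANTLR. In the example below every type is a hypothetical stand-in: Node plays the role of a generated context class, BaseVisitor plays the role of the generated base visitor whose default methods only walk the children, and StatementVisitor overrides two visit methods to build its own results (plain strings here) while counting parameter markers.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the visitor idea; none of these types are ANTLR or ShardingSphere classes.
final class VisitorSketch {

    /** Stand-in for a parse-tree context: a rule name, its text and its children. */
    static final class Node {
        final String rule;
        final String text;
        final List<Node> children;

        Node(final String rule, final String text, final Node... children) {
            this.rule = rule;
            this.text = text;
            this.children = Arrays.asList(children);
        }
    }

    /** Stand-in for the generated base visitor: each visit method just walks the children. */
    static class BaseVisitor<T> {

        T visit(final Node node) {
            // Dispatch on the rule name, similar to a context dispatching to its own visit method.
            switch (node.rule) {
                case "identifier":
                    return visitIdentifier(node);
                case "parameterMarker":
                    return visitParameterMarker(node);
                default:
                    return visitChildren(node);
            }
        }

        T visitChildren(final Node node) {
            T result = null;
            for (Node each : node.children) {
                result = visit(each);
            }
            return result;
        }

        T visitIdentifier(final Node node) {
            return visitChildren(node);
        }

        T visitParameterMarker(final Node node) {
            return visitChildren(node);
        }
    }

    /** Overrides only the rules it cares about and remembers every object it creates. */
    static final class StatementVisitor extends BaseVisitor<String> {
        final List<String> created = new ArrayList<>();
        private int currentParameterIndex;

        @Override
        String visitIdentifier(final Node node) {
            return remember("Identifier(" + node.text + ")");
        }

        @Override
        String visitParameterMarker(final Node node) {
            // Each marker takes the next index, so the final counter equals the parameter count.
            return remember("ParameterMarker#" + currentParameterIndex++);
        }

        int getParameterCount() {
            return currentParameterIndex;
        }

        private String remember(final String value) {
            created.add(value);
            return value;
        }
    }

    public static void main(final String[] args) {
        Node tree = new Node("whereClause", "id = ? and name = ?",
                new Node("identifier", "id"), new Node("parameterMarker", "?"),
                new Node("identifier", "name"), new Node("parameterMarker", "?"));
        StatementVisitor visitor = new StatementVisitor();
        visitor.visit(tree);
        System.out.println(visitor.created + " parameters=" + visitor.getParameterCount());
    }
}
```

The whereClause node has no override, so the default method simply descends into its children; only the identifier and parameterMarker nodes produce objects. The real visitor follows the same pattern with a running parameter index and a collection of parameter marker segments.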

MySQLStatementSQLVisitor Source Code

package org.apache.shardingsphere.sql.parser.mysql.visitor.statement.impl;

/**
 * Visitor for MySQL SQL statements.
 */
@NoArgsConstructor
@Getter(AccessLevel.PROTECTED)
public abstract class MySQLStatementSQLVisitor extends MySQLStatementBaseVisitor<ASTNode> {
    
    // Index of the current parameter; incremented each time a parameter marker is parsed
    private int currentParameterIndex;
    
    // Parameter marker segments; each parsed parameter marker creates one segment that is added to this collection
    private final Collection<ParameterMarkerSegment> parameterMarkerSegments = new LinkedList<>();
    
    public MySQLStatementSQLVisitor(final Properties props) {
    }

    /**
     * Visits a parameter marker. Returns a parameter marker value that records the index and the ? type.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitParameterMarker(final ParameterMarkerContext ctx) {
        return new ParameterMarkerValue(currentParameterIndex++, ParameterMarkerType.QUESTION);
    }
 
    
    @Override
    public final ASTNode visitIdentifier(final IdentifierContext ctx) {
        return new IdentifierValue(ctx.getText());
    }


    /**
     * Visits a table name.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitTableName(final TableNameContext ctx) {
        // Create a simple table segment from ctx
        SimpleTableSegment result = new SimpleTableSegment(new TableNameSegment(ctx.name().getStart().getStartIndex(),
                ctx.name().getStop().getStopIndex(), new IdentifierValue(ctx.name().identifier().getText())));
        // Get the owner
        OwnerContext owner = ctx.owner();
        if (null != owner) {
            // Visit the owner part
            result.setOwner((OwnerSegment) visit(owner));
        }
        return result;
    }

    
    @Override
    public final ASTNode visitOwner(final OwnerContext ctx) {
        return new OwnerSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (IdentifierValue) visit(ctx.identifier()));
    }
 

    /**
     * Visits an expression.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitExpr(final ExprContext ctx) {
        // A plain expression is a booleanPrimary
        if (null != ctx.booleanPrimary()) {
            return visit(ctx.booleanPrimary());
        }
        // Expressions joined by XOR
        if (null != ctx.XOR()) {
            return createBinaryOperationExpression(ctx, "XOR");
        }
        // Expressions joined by AND
        if (null != ctx.andOperator()) {
            return createBinaryOperationExpression(ctx, ctx.andOperator().getText());
        }
        // Expressions joined by OR
        if (null != ctx.orOperator()) {
            return createBinaryOperationExpression(ctx, ctx.orOperator().getText());
        }
        return new NotExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), (ExpressionSegment) visit(ctx.expr(0)));
    }
    
    private BinaryOperationExpression createBinaryOperationExpression(final ExprContext ctx, final String operator) {
        ExpressionSegment left = (ExpressionSegment) visit(ctx.expr(0));
        ExpressionSegment right = (ExpressionSegment) visit(ctx.expr(1));
        String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
        return new BinaryOperationExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), left, right, operator, text);
    }

    /**
     * Visits a boolean primary, i.e. a plain expression.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitBooleanPrimary(final BooleanPrimaryContext ctx) {
        // Conditional expression with IS
        if (null != ctx.IS()) {
            // TODO optimize operatorToken
            String rightText = "";
            if (null != ctx.NOT()) {
                rightText = rightText.concat(ctx.start.getInputStream().getText(new Interval(ctx.NOT().getSymbol().getStartIndex(),
                        ctx.NOT().getSymbol().getStopIndex()))).concat(" ");
            }
            Token operatorToken = null;
            if (null != ctx.NULL()) {
                operatorToken = ctx.NULL().getSymbol();
            }
            if (null != ctx.TRUE()) {
                operatorToken = ctx.TRUE().getSymbol();
            }
            if (null != ctx.FALSE()) {
                operatorToken = ctx.FALSE().getSymbol();
            }
            int startIndex = null == operatorToken ? ctx.IS().getSymbol().getStopIndex() + 2 : operatorToken.getStartIndex();
            rightText = rightText.concat(ctx.start.getInputStream().getText(new Interval(startIndex, ctx.stop.getStopIndex())));
            ExpressionSegment right = new LiteralExpressionSegment(ctx.IS().getSymbol().getStopIndex() + 2, ctx.stop.getStopIndex(), rightText);
            String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
            ExpressionSegment left = (ExpressionSegment) visit(ctx.booleanPrimary());
            String operator = "IS";
            return new BinaryOperationExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), left, right, operator, text);
        }
        // Expression with a comparison operator
        if (null != ctx.comparisonOperator() || null != ctx.SAFE_EQ_()) {
            return createCompareSegment(ctx);
        }
        // Expression with an assignment operator
        if (null != ctx.assignmentOperator()) {
            return createAssignmentSegment(ctx);
        }
        // Visit the predicate
        return visit(ctx.predicate());
    }
    
    private ASTNode createAssignmentSegment(final BooleanPrimaryContext ctx) {
        ExpressionSegment left = (ExpressionSegment) visit(ctx.booleanPrimary());
        ExpressionSegment right = (ExpressionSegment) visit(ctx.predicate());
        String operator = ctx.assignmentOperator().getText();
        String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
        return new BinaryOperationExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), left, right, operator, text);
    }

    /**
     * Creates a comparison expression segment.
     * @param ctx
     * @return
     */
    private ASTNode createCompareSegment(final BooleanPrimaryContext ctx) {
        // Get the left-hand expression
        ExpressionSegment left = (ExpressionSegment) visit(ctx.booleanPrimary());
        ExpressionSegment right;
        // Get the right-hand expression
        if (null != ctx.predicate()) {
            right = (ExpressionSegment) visit(ctx.predicate());
        } else {
            // Subquery expression segment
            right = new SubqueryExpressionSegment(new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery())));
        }
        // Get the operator of the expression
        String operator = null != ctx.SAFE_EQ_() ? ctx.SAFE_EQ_().getText() : ctx.comparisonOperator().getText();
        // Get the original text of the expression
        String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
        // Create the binary operation expression
        return new BinaryOperationExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), left, right, operator, text);
    }
    
    @Override
    public final ASTNode visitPredicate(final PredicateContext ctx) {
        // Create the IN segment
        if (null != ctx.IN()) {
            return createInSegment(ctx);
        }
        // Create the BETWEEN segment
        if (null != ctx.BETWEEN()) {
            return createBetweenSegment(ctx);
        }
        if (null != ctx.LIKE()) {
            return createBinaryOperationExpressionFromLike(ctx);
        }
        if (null != ctx.REGEXP()) {
            return createBinaryOperationExpressionFromRegexp(ctx);
        }
        // Otherwise visit the bit expression, e.g. a literal
        return visit(ctx.bitExpr(0));
    }

    /**
     * Visits a bit expression.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitBitExpr(final BitExprContext ctx) {
        // If there is a simple expression, such as a literal, visit it
        if (null != ctx.simpleExpr()) {
            return visit(ctx.simpleExpr());
        }
        // Otherwise visit a binary operation expression
        ExpressionSegment left = (ExpressionSegment) visit(ctx.getChild(0));
        ExpressionSegment right = (ExpressionSegment) visit(ctx.getChild(2));
        String operator = ctx.getChild(1).getText();
        String text = ctx.start.getInputStream().getText(new Interval(ctx.start.getStartIndex(), ctx.stop.getStopIndex()));
        return new BinaryOperationExpression(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), left, right, operator, text);
    }

    /**
     * Visits a simple expression.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitSimpleExpr(final SimpleExprContext ctx) {
        // Get the start and stop positions
        int startIndex = ctx.start.getStartIndex();
        int stopIndex = ctx.stop.getStopIndex();
        // Visit a subquery
        if (null != ctx.subquery()) {
            SubquerySegment subquerySegment = new SubquerySegment(ctx.subquery().getStart().getStartIndex(), ctx.subquery().getStop().getStopIndex(), (MySQLSelectStatement) visit(ctx.subquery()));
            if (null != ctx.EXISTS()) {
                return new ExistsSubqueryExpression(startIndex, stopIndex, subquerySegment);
            }
            return new SubqueryExpressionSegment(subquerySegment);
        }
        // Visit a parameter marker
        if (null != ctx.parameterMarker()) {
            // Visiting the marker creates a ParameterMarkerValue that records the current parameter index and the ? type
            ParameterMarkerValue parameterMarker = (ParameterMarkerValue) visit(ctx.parameterMarker());
            // Create the parameter marker expression segment
            ParameterMarkerExpressionSegment segment = new ParameterMarkerExpressionSegment(startIndex, stopIndex, parameterMarker.getValue(), parameterMarker.getType());
            // Record the segment and return it
            parameterMarkerSegments.add(segment);
            return segment;
        }
        // Visit a literal
        if (null != ctx.literals()) {
            return SQLUtil.createLiteralExpression(visit(ctx.literals()), startIndex, stopIndex, ctx.literals().start.getInputStream().getText(new Interval(startIndex, stopIndex)));
        }
        // Visit an interval expression
        if (null != ctx.intervalExpression()) {
            return visit(ctx.intervalExpression());
        }
        // Visit a function call
        if (null != ctx.functionCall()) {
            return visit(ctx.functionCall());
        }
        if (null != ctx.collateClause()) {
            SimpleExpressionSegment collateValueSegment = (SimpleExpressionSegment) visit(ctx.collateClause());
            return new CollateExpression(startIndex, stopIndex, collateValueSegment);
        }
        // Visit a column reference
        if (null != ctx.columnRef()) {
            return visit(ctx.columnRef());
        }
        // Visit a match expression
        if (null != ctx.matchExpression()) {
            return visit(ctx.matchExpression());
        }
        // NOT operator
        if (null != ctx.notOperator()) {
            ASTNode expression = visit(ctx.simpleExpr(0));
            if (expression instanceof ExistsSubqueryExpression) {
                ((ExistsSubqueryExpression) expression).setNot(true);
                return expression;
            }
            return new NotExpression(startIndex, stopIndex, (ExpressionSegment) expression);
        }
        if (null != ctx.LP_() && 1 == ctx.expr().size()) {
            return visit(ctx.expr(0));
        }
        return visitRemainSimpleExpr(ctx);
    }

    /**
     * Visits a column reference.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitColumnRef(final ColumnRefContext ctx) {
        int identifierCount = ctx.identifier().size();
        ColumnSegment result;
        // Only one identifier
        if (1 == identifierCount) {
            // Create a column segment
            result = new ColumnSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (IdentifierValue) visit(ctx.identifier(0)));
        } else if (2 == identifierCount) {
            // Two identifiers: the second is the column, the first is the column's owner
            result = new ColumnSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (IdentifierValue) visit(ctx.identifier(1)));
            result.setOwner(new OwnerSegment(ctx.identifier(0).start.getStartIndex(), ctx.identifier(0).stop.getStopIndex(), (IdentifierValue) visit(ctx.identifier(0))));
        } else {
            // Three identifiers: owner.owner.col
            result = new ColumnSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (IdentifierValue) visit(ctx.identifier(2)));
            OwnerSegment owner = new OwnerSegment(ctx.identifier(1).start.getStartIndex(), ctx.identifier(1).stop.getStopIndex(), (IdentifierValue) visit(ctx.identifier(1)));
            owner.setOwner(new OwnerSegment(ctx.identifier(0).start.getStartIndex(), ctx.identifier(0).stop.getStopIndex(), (IdentifierValue) visit(ctx.identifier(0))));
            result.setOwner(owner);
        }
        return result;
    }

    @Override
    public ASTNode visitQueryExpression(final QueryExpressionContext ctx) {
        MySQLSelectStatement result;
        if (null != ctx.queryExpressionBody()) {
            result = (MySQLSelectStatement) visit(ctx.queryExpressionBody());
        } else {
            result = (MySQLSelectStatement) visit(ctx.queryExpressionParens());
        }
        if (null != ctx.orderByClause()) {
            result.setOrderBy((OrderBySegment) visit(ctx.orderByClause()));
        }
        if (null != ctx.limitClause()) {
            result.setLimit((LimitSegment) visit(ctx.limitClause()));
        }
        return result;
    }
    
    /**
     * Visits a query expression body.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitQueryExpressionBody(final QueryExpressionBodyContext ctx) {
        // Exactly one child and it is a QueryPrimaryContext, i.e. the queryPrimary alternative
        if (1 == ctx.getChildCount() && ctx.getChild(0) instanceof QueryPrimaryContext) {
            return visit(ctx.queryPrimary());
        }
        // If there is a nested queryExpressionBody, visit it recursively through this method
        if (null != ctx.queryExpressionBody()) {
            MySQLSelectStatement result = new MySQLSelectStatement();
            MySQLSelectStatement left = (MySQLSelectStatement) visit(ctx.queryExpressionBody());
            result.setProjections(left.getProjections());
            result.setFrom(left.getFrom());
            left.getTable().ifPresent(result::setTable);
            result.setCombine(createCombineSegment(ctx.combineClause(), left));
            return result;
        }
        return visit(ctx.queryExpressionParens());
    }
    
    private CombineSegment createCombineSegment(final CombineClauseContext ctx, final MySQLSelectStatement left) {
        CombineType combineType = (null != ctx.combineOption() && null != ctx.combineOption().ALL()) ? CombineType.UNION_ALL : CombineType.UNION;
        MySQLSelectStatement right = null != ctx.queryPrimary() ? (MySQLSelectStatement) visit(ctx.queryPrimary()) : (MySQLSelectStatement) visit(ctx.queryExpressionParens());
        return new CombineSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), left, combineType, right);
    }

    /**
     * Visits a query specification.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitQuerySpecification(final QuerySpecificationContext ctx) {
        MySQLSelectStatement result = new MySQLSelectStatement();
        // Visit the projections
        result.setProjections((ProjectionsSegment) visit(ctx.projections()));
        // Visit the select specifications, such as DISTINCT
        if (null != ctx.selectSpecification()) {
            result.getProjections().setDistinctRow(isDistinct(ctx));
        }
        // Visit FROM to get the table segment
        if (null != ctx.fromClause() && null != ctx.fromClause().tableReferences()) {
            TableSegment tableSource = (TableSegment) visit(ctx.fromClause().tableReferences());
            result.setFrom(tableSource);
        }
        // Visit WHERE to get the where condition segment
        if (null != ctx.whereClause()) {
            result.setWhere((WhereSegment) visit(ctx.whereClause()));
        }
        if (null != ctx.groupByClause()) {
            result.setGroupBy((GroupBySegment) visit(ctx.groupByClause()));
        }
        if (null != ctx.havingClause()) {
            result.setHaving((HavingSegment) visit(ctx.havingClause()));
        }
        if (null != ctx.windowClause()) {
            result.setWindow((WindowSegment) visit(ctx.windowClause()));
        }
        return result;
    }
    
    @Override
    public ASTNode visitTableStatement(final TableStatementContext ctx) {
        MySQLSelectStatement result = new MySQLSelectStatement();
        result.setTable((SimpleTableSegment) visit(ctx.tableName()));
        return result;
    }

    /**
     * Visits the ORDER BY clause.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitOrderByClause(final OrderByClauseContext ctx) {
        Collection<OrderByItemSegment> items = new LinkedList<>();
        // Visit each order-by item
        for (OrderByItemContext each : ctx.orderByItem()) {
            items.add((OrderByItemSegment) visit(each));
        }
        return new OrderBySegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), items);
    }

    /**
     * Visits an order-by item.
     * @param ctx
     * @return
     */
    @Override
    public final ASTNode visitOrderByItem(final OrderByItemContext ctx) {
        OrderDirection orderDirection;
        // Sort direction
        if (null != ctx.direction()) {
            orderDirection = null != ctx.direction().DESC() ? OrderDirection.DESC : OrderDirection.ASC;
        } else {
            orderDirection = OrderDirection.ASC;
        }
        // Numeric index, as in ORDER BY 1
        if (null != ctx.numberLiterals()) {
            return new IndexOrderByItemSegment(ctx.numberLiterals().getStart().getStartIndex(), ctx.numberLiterals().getStop().getStopIndex(),
                    SQLUtil.getExactlyNumber(ctx.numberLiterals().getText(), 10).intValue(), orderDirection, null);
        } else {
            // Visit the order-by expression, for example the column to sort by
            ASTNode expr = visitExpr(ctx.expr());
            // If it is a column
            if (expr instanceof ColumnSegment) {
                // Create a column order-by item segment
                return new ColumnOrderByItemSegment((ColumnSegment) expr, orderDirection, null);
            } else {
                // Otherwise it is an expression: create an expression order-by item segment
                return new ExpressionOrderByItemSegment(ctx.expr().getStart().getStartIndex(),
                        ctx.expr().getStop().getStopIndex(), getOriginalText(ctx.expr()), orderDirection, null, (ExpressionSegment) expr);
            }
        }
    }

    @Override
    public ASTNode visitSingleTableClause(final SingleTableClauseContext ctx) {
        SimpleTableSegment result = (SimpleTableSegment) visit(ctx.tableName());
        if (null != ctx.alias()) {
            result.setAlias((AliasSegment) visit(ctx.alias()));
        }
        return result;
    }
    
    /**
     * Visits a select statement.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitSelect(final SelectContext ctx) {
        // TODO :Unsupported for withClause.
        MySQLSelectStatement result;
        if (null != ctx.queryExpression()) {
            // If there is a query expression, visit it to obtain the MySQLSelectStatement
            result = (MySQLSelectStatement) visit(ctx.queryExpression());
            // If there is a lock clause list, visit it as a lock segment and add it to the MySQLSelectStatement
            if (null != ctx.lockClauseList()) {
                result.setLock((LockSegment) visit(ctx.lockClauseList()));
            }
        } else if (null != ctx.selectWithInto()) {
            // Visit the selectWithInto part
            result = (MySQLSelectStatement) visit(ctx.selectWithInto());
        } else {
            // For queryExpressionParens, visit the first child node
            result = (MySQLSelectStatement) visit(ctx.getChild(0));
        }
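        // Set the parameter count; currentParameterIndex starts at 0 and is incremented for each parsed parameter marker
        result.setParameterCount(currentParameterIndex);
        // Add the parameter marker segments that were collected while parsing parameter markers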
        result.getParameterMarkerSegments().addAll(getParameterMarkerSegments());
        return result;
    }

    /**
     * Checks the select specifications for DISTINCT and similar keywords.
     * @param ctx
     * @return
     */
    private boolean isDistinct(final QuerySpecificationContext ctx) {
        // Includes: duplicateSpecification | HIGH_PRIORITY | STRAIGHT_JOIN | SQL_SMALL_RESULT |
        // SQL_BIG_RESULT | SQL_BUFFER_RESULT | SQL_NO_CACHE | SQL_CALC_FOUND_ROWS
        for (SelectSpecificationContext each : ctx.selectSpecification()) {
            if (((BooleanLiteralValue) visit(each)).getValue()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Visits a select specification.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitSelectSpecification(final SelectSpecificationContext ctx) {
        // Duplicate specification
        if (null != ctx.duplicateSpecification()) {
            return visit(ctx.duplicateSpecification());
        }
        return new BooleanLiteralValue(false);
    }

    /**
     * Visits a duplicate specification.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitDuplicateSpecification(final DuplicateSpecificationContext ctx) {
        String text = ctx.getText();
        if ("DISTINCT".equalsIgnoreCase(text) || "DISTINCTROW".equalsIgnoreCase(text)) {
            return new BooleanLiteralValue(true);
        }
        return new BooleanLiteralValue(false);
    }

    /**
     * Visits the projections.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitProjections(final ProjectionsContext ctx) {
        Collection<ProjectionSegment> projections = new LinkedList<>();
        // For an unqualified shorthand (*), create a shorthand projection segment
        if (null != ctx.unqualifiedShorthand()) {
            projections.add(new ShorthandProjectionSegment(ctx.unqualifiedShorthand().getStart().getStartIndex(), ctx.unqualifiedShorthand().getStop().getStopIndex()));
        }
        // Iterate over the projections and visit each one
        for (ProjectionContext each : ctx.projection()) {
            projections.add((ProjectionSegment) visit(each));
        }
        // Create the projections segment
        ProjectionsSegment result = new ProjectionsSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex());
        // Add the projections
        result.getProjections().addAll(projections);
        return result;
    }
    
    @Override
    public ASTNode visitProjection(final ProjectionContext ctx) {
        // FIXME :The stop index of project is the stop index of projection, instead of alias.
        if (null != ctx.qualifiedShorthand()) {
            return createShorthandProjection(ctx.qualifiedShorthand());
        }
        AliasSegment alias = null == ctx.alias() ? null : (AliasSegment) visit(ctx.alias());
        ASTNode exprProjection = visit(ctx.expr());
        if (exprProjection instanceof ColumnSegment) {
            ColumnProjectionSegment result = new ColumnProjectionSegment((ColumnSegment) exprProjection);
            result.setAlias(alias);
            return result;
        }
        if (exprProjection instanceof SubquerySegment) {
            SubquerySegment subquerySegment = (SubquerySegment) exprProjection;
            String text = ctx.start.getInputStream().getText(new Interval(subquerySegment.getStartIndex(), subquerySegment.getStopIndex()));
            SubqueryProjectionSegment result = new SubqueryProjectionSegment((SubquerySegment) exprProjection, text);
            result.setAlias(alias);
            return result;
        }
        if (exprProjection instanceof ExistsSubqueryExpression) {
            ExistsSubqueryExpression existsSubqueryExpression = (ExistsSubqueryExpression) exprProjection;
            String text = ctx.start.getInputStream().getText(new Interval(existsSubqueryExpression.getStartIndex(), existsSubqueryExpression.getStopIndex()));
            SubqueryProjectionSegment result = new SubqueryProjectionSegment(((ExistsSubqueryExpression) exprProjection).getSubquery(), text);
            result.setAlias(alias);
            return result;
        }
        return createProjection(ctx, alias, exprProjection);
    }
    
    private ShorthandProjectionSegment createShorthandProjection(final QualifiedShorthandContext shorthand) {
        ShorthandProjectionSegment result = new ShorthandProjectionSegment(shorthand.getStart().getStartIndex(), shorthand.getStop().getStopIndex());
        IdentifierContext identifier = shorthand.identifier().get(shorthand.identifier().size() - 1);
        OwnerSegment owner = new OwnerSegment(identifier.getStart().getStartIndex(), identifier.getStop().getStopIndex(), new IdentifierValue(identifier.getText()));
        result.setOwner(owner);
        if (shorthand.identifier().size() > 1) {
            IdentifierContext schemaIdentifier = shorthand.identifier().get(0);
            owner.setOwner(new OwnerSegment(schemaIdentifier.getStart().getStartIndex(), schemaIdentifier.getStop().getStopIndex(), new IdentifierValue(schemaIdentifier.getText())));
        }
        return result;
    }
    
    @Override
    public ASTNode visitAlias(final AliasContext ctx) {
        return new AliasSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), new IdentifierValue(ctx.textOrIdentifier().getText()));
    }
    
    private ASTNode createProjection(final ProjectionContext ctx, final AliasSegment alias, final ASTNode projection) {
        if (projection instanceof AggregationProjectionSegment) {
            ((AggregationProjectionSegment) projection).setAlias(alias);
            return projection;
        }
        if (projection instanceof ExpressionProjectionSegment) {
            ((ExpressionProjectionSegment) projection).setAlias(alias);
            return projection;
        }
        if (projection instanceof FunctionSegment) {
            FunctionSegment functionSegment = (FunctionSegment) projection;
            ExpressionProjectionSegment result = new ExpressionProjectionSegment(functionSegment.getStartIndex(), functionSegment.getStopIndex(), functionSegment.getText(), functionSegment);
            result.setAlias(alias);
            return result;
        }
        if (projection instanceof CommonExpressionSegment) {
            CommonExpressionSegment segment = (CommonExpressionSegment) projection;
            ExpressionProjectionSegment result = new ExpressionProjectionSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getText(), segment);
            result.setAlias(alias);
            return result;
        }
        // FIXME :For DISTINCT()
        if (projection instanceof ColumnSegment) {
            ExpressionProjectionSegment result = new ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), getOriginalText(ctx), (ColumnSegment) projection);
            result.setAlias(alias);
            return result;
        }
        if (projection instanceof SubqueryExpressionSegment) {
            SubqueryExpressionSegment subqueryExpressionSegment = (SubqueryExpressionSegment) projection;
            String text = ctx.start.getInputStream().getText(new Interval(subqueryExpressionSegment.getStartIndex(), subqueryExpressionSegment.getStopIndex()));
            SubqueryProjectionSegment result = new SubqueryProjectionSegment(subqueryExpressionSegment.getSubquery(), text);
            result.setAlias(alias);
            return result;
        }
        if (projection instanceof BinaryOperationExpression) {
            int startIndex = ((BinaryOperationExpression) projection).getStartIndex();
            int stopIndex = null != alias ? alias.getStopIndex() : ((BinaryOperationExpression) projection).getStopIndex();
            ExpressionProjectionSegment result = new ExpressionProjectionSegment(startIndex, stopIndex, ((BinaryOperationExpression) projection).getText(), (BinaryOperationExpression) projection);
            result.setAlias(alias);
            return result;
        }
        if (projection instanceof ParameterMarkerExpressionSegment) {
            ParameterMarkerExpressionSegment result = (ParameterMarkerExpressionSegment) projection;
            result.setAlias(alias);
            return projection;
        }
        if (projection instanceof CaseWhenExpression) {
            ExpressionProjectionSegment result = new ExpressionProjectionSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex(), getOriginalText(ctx.expr()), (CaseWhenExpression) projection);
            result.setAlias(alias);
            return result;
        }
        LiteralExpressionSegment column = (LiteralExpressionSegment) projection;
        ExpressionProjectionSegment result = null == alias
                ? new ExpressionProjectionSegment(column.getStartIndex(), column.getStopIndex(), String.valueOf(column.getLiterals()), column)
                : new ExpressionProjectionSegment(column.getStartIndex(), ctx.alias().stop.getStopIndex(), String.valueOf(column.getLiterals()), column);
        result.setAlias(alias);
        return result;
    }
    
    @Override
    public ASTNode visitFromClause(final FromClauseContext ctx) {
        return visit(ctx.tableReferences());
    }

    /**
     * Visits the table references.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitTableReferences(final TableReferencesContext ctx) {
        // Get the first table segment
        TableSegment result = (TableSegment) visit(ctx.tableReference(0));
        if (ctx.tableReference().size() > 1) {
            // Wrap each further table reference as a join on top of the current result
            for (int i = 1; i < ctx.tableReference().size(); i++) {
                result = generateJoinTableSourceFromEscapedTableReference(ctx.tableReference(i), result);
            }
        }
        return result;
    }

    /**
     * Builds the join table segment for a comma-separated table reference.
     * @param ctx
     * @param tableSegment
     * @return
     */
    private JoinTableSegment generateJoinTableSourceFromEscapedTableReference(final TableReferenceContext ctx, final TableSegment tableSegment) {
        JoinTableSegment result = new JoinTableSegment();
        result.setStartIndex(tableSegment.getStartIndex());
        result.setStopIndex(ctx.stop.getStopIndex());
        result.setLeft(tableSegment);
        result.setJoinType(JoinType.COMMA.name());
        result.setRight((TableSegment) visit(ctx));
        return result;
    }

    /**
     * Visits a single table reference.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitTableReference(final TableReferenceContext ctx) {
        TableSegment result;
        TableSegment left;
        // Visit the tableFactor part (or the escaped table reference if there is no tableFactor)
        left = null != ctx.tableFactor() ? (TableSegment) visit(ctx.tableFactor()) : (TableSegment) visit(ctx.escapedTableReference());
        // If there are joined tables
        for (JoinedTableContext each : ctx.joinedTable()) {
            // Visit the joined table part
            left = visitJoinedTable(each, left);
        }
        result = left;
        return result;
    }

    /**
     * Visits a table factor.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitTableFactor(final TableFactorContext ctx) {
        // Subquery used as a table
        if (null != ctx.subquery()) {
            MySQLSelectStatement subquery = (MySQLSelectStatement) visit(ctx.subquery());
            SubquerySegment subquerySegment = new SubquerySegment(ctx.subquery().start.getStartIndex(), ctx.subquery().stop.getStopIndex(), subquery);
            SubqueryTableSegment result = new SubqueryTableSegment(subquerySegment);
            if (null != ctx.alias()) {
                result.setAlias((AliasSegment) visit(ctx.alias()));
            }
            return result;
        }
        // Plain table name
        if (null != ctx.tableName()) {
            // Visit the table name part to obtain a SimpleTableSegment
            SimpleTableSegment result = (SimpleTableSegment) visit(ctx.tableName());
            // If there is an alias
            if (null != ctx.alias()) {
                // Visit the alias part
                result.setAlias((AliasSegment) visit(ctx.alias()));
            }
            return result;
        }
        return visit(ctx.tableReferences());
    }

    /**
     * Visits the where clause.
     * @param ctx
     * @return
     */
    @Override
    public ASTNode visitWhereClause(final WhereClauseContext ctx) {
        // Visit the expression
        ASTNode segment = visit(ctx.expr());
        // Create the where segment
        return new WhereSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), (ExpressionSegment) segment);
    }

    // Other methods omitted
    
}
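
The loop in visitTableReferences() turns a comma-separated table list into a left-nested chain of joins: each new table reference becomes the right side of a join whose left side is everything accumulated so far. The following is a minimal sketch of that folding only. `Table`, `SimpleTable` and `JoinTable` are hypothetical, simplified stand-ins rather than ShardingSphere's real segment classes, and start/stop indexes are left out.

```java
import java.util.Arrays;
import java.util.List;

public class CommaJoinFoldSketch {

    // Hypothetical stand-in for TableSegment.
    interface Table {
        String describe();
    }

    // Hypothetical stand-in for SimpleTableSegment.
    static final class SimpleTable implements Table {
        private final String name;

        SimpleTable(String name) {
            this.name = name;
        }

        @Override
        public String describe() {
            return name;
        }
    }

    // Hypothetical stand-in for JoinTableSegment.
    static final class JoinTable implements Table {
        private final Table left;
        private final String joinType;
        private final Table right;

        JoinTable(Table left, String joinType, Table right) {
            this.left = left;
            this.joinType = joinType;
            this.right = right;
        }

        @Override
        public String describe() {
            return "(" + left.describe() + " " + joinType + " " + right.describe() + ")";
        }
    }

    // Same shape as the loop in visitTableReferences(): start with the first table,
    // then wrap the running result as the left side of each following comma join.
    static Table fold(List<Table> references) {
        Table result = references.get(0);
        for (int i = 1; i < references.size(); i++) {
            result = new JoinTable(result, "COMMA", references.get(i));
        }
        return result;
    }

    public static void main(String[] args) {
        Table tree = fold(Arrays.<Table>asList(new SimpleTable("t1"), new SimpleTable("t2"), new SimpleTable("t3")));
        System.out.println(tree.describe()); // ((t1 COMMA t2) COMMA t3)
    }
}
```

For a single table, as in the tb_user example, the loop body never runs and the first table segment is returned unchanged.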

Example walkthrough

Again taking the following SQL as an example:

select id, name from tb_user where id = 1

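When the syntax tree is traversed through Antlr's Visitor interface, the select statement above is visited in this order, starting in MySQLStatementBaseVisitor: visitSelect() -> visitQueryExpression() -> visitQueryExpressionBody() -> visitQueryPrimary() -> visitQuerySpecification().

MySQLStatementSQLVisitor overrides visitQuerySpecification(), and from there visits the matched projections, the from part and the where part step by step.

When parsing finishes, a MySQLSelectStatement object is produced.

SQL statements of other operation types follow much the same parsing flow.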
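
The call chain above can be imitated with a toy visitor. This is only a sketch under loud assumptions: the tree is built by hand instead of by Antlr, and `Node`, `BaseVisitor`, `StatementVisitor` and `SelectStatement` are hypothetical names, not ShardingSphere or Antlr classes. It shows the pattern of a base visitor that simply descends into children, plus a subclass that overrides one rule and assembles the statement there.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class VisitorChainSketch {

    // Toy parse-tree node (hypothetical; Antlr generates one context class per rule instead).
    static final class Node {
        final String rule;
        final String text;
        final List<Node> children;

        Node(String rule, String text, Node... children) {
            this.rule = rule;
            this.text = text;
            this.children = Arrays.asList(children);
        }
    }

    // Simplified result object standing in for MySQLSelectStatement.
    static final class SelectStatement {
        final List<String> projections = new ArrayList<>();
        String table;
        String where;
    }

    // Stand-in for the generated base visitor: every rule just descends into its children.
    static class BaseVisitor {
        final List<String> trace = new ArrayList<>();

        Object visit(Node node) {
            trace.add(node.rule);
            return "querySpecification".equals(node.rule) ? visitQuerySpecification(node) : visitChildren(node);
        }

        Object visitChildren(Node node) {
            Object result = null;
            for (Node child : node.children) {
                result = visit(child);
            }
            return result;
        }

        Object visitQuerySpecification(Node node) {
            return visitChildren(node);
        }
    }

    // Stand-in for the custom visitor: overrides one rule and builds the statement there.
    static class StatementVisitor extends BaseVisitor {
        @Override
        Object visitQuerySpecification(Node node) {
            SelectStatement result = new SelectStatement();
            for (Node child : node.children) {
                if ("projections".equals(child.rule)) {
                    for (Node projection : child.children) {
                        result.projections.add(projection.text);
                    }
                } else if ("fromClause".equals(child.rule)) {
                    result.table = child.text;
                } else if ("whereClause".equals(child.rule)) {
                    result.where = child.text;
                }
            }
            return result;
        }
    }

    // Hand-built tree for: select id, name from tb_user where id = 1
    static Node exampleTree() {
        Node spec = new Node("querySpecification", "",
                new Node("projections", "", new Node("projection", "id"), new Node("projection", "name")),
                new Node("fromClause", "tb_user"),
                new Node("whereClause", "id = 1"));
        return new Node("select", "",
                new Node("queryExpression", "",
                        new Node("queryExpressionBody", "",
                                new Node("queryPrimary", "", spec))));
    }

    public static void main(String[] args) {
        StatementVisitor visitor = new StatementVisitor();
        SelectStatement statement = (SelectStatement) visitor.visit(exampleTree());
        System.out.println(String.join(" -> ", visitor.trace));
        System.out.println(statement.projections + " from " + statement.table + " where " + statement.where);
    }
}
```

In this sketch the recorded trace stops at querySpecification, because the overriding method takes over from there and reads the projections, from and where children itself.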

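Summary

For reasons of length, this article stops here. To summarize:

1) The databases supported by Sharding-JDBC include: mysql, opengauss, oracle, postgresql, sql92, sqlserver;

2) Sharding-JDBC uses Antlr to parse SQL statements, and each database has its own .g4 files;

The files are located under the sql-parser/dialect directory of the source code;

The file name is XxxStatement.g4, where Xxx corresponds to the database name;

3) For MySQL, the rule file is MySQLStatement.g4, and the rules for the common insert, delete, update and select operations are defined in DMLStatement.g4;

3.1) Antlr's Generate ANTLR Recognizer automatically generates the MySQLStatementBaseVisitor file. Sharding-JDBC defines its own MySQLStatementSQLVisitor, which extends MySQLStatementBaseVisitor;

3.2) For a select query, when the generated tree is visited through Antlr's Visitor interface, the entry method is visitSelect() of MySQLStatementBaseVisitor. Parsing produces a MySQLSelectStatement object.

If you have your own thoughts or insights on this article, feel free to discuss them in the comments.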
