从零手写 SQL 查询引擎:解析器、优化器与执行器实战

引言

SQL(Structured Query Language)作为数据查询的"世界语",已经走过了四十多年的历史。无论是 MySQL、PostgreSQL 这样的传统关系型数据库,还是 SparkSQL、Presto 这样的分布式查询引擎,底层的 SQL 处理流程都遵循着相同的架构范式:SQL 文本 → 词法分析 → 语法分析 → 语义绑定 → 逻辑优化 → 物理执行

每天写 SQL 的开发者数以百万计,但真正了解 SQL 引擎内部工作原理的人却并不多。大多数人的认知停留在"SQL 进了数据库,然后就出了结果"的黑盒阶段。然而,当我们面对查询性能调优、学习新的大数据查询引擎、甚至需要为特定场景定制查询能力时,理解 SQL 引擎的内部机制就会变得至关重要。

本文将从零开始,用纯 Python 实现一个简化但五脏俱全的 SQL 查询引擎,涵盖以下核心模块:

  • 词法分析器(Lexer):将 SQL 文本拆分为 Token 序列
  • 语法分析器(Parser):构建抽象语法树(AST)
  • 语义分析器(Binder):解析列引用、类型校验
  • 逻辑优化器(Optimizer):实现谓词下推、投影下推等经典优化
  • 执行引擎(Executor):支持过滤、投影、Join、聚合、排序

完整代码约 700 行,所有代码在 Python 3.10+ 环境下测试通过。这篇教程会带你从一行 SELECT * FROM users WHERE age > 18 开始,亲手搭建起属于你自己的 SQL 引擎。


一、整体架构设计

1.1 流水线架构

一个完整的 SQL 查询引擎通常包含以下处理流水线:

复制代码
SQL 字符串
    │
    ▼
┌─────────────┐    Lexer(词法分析)
│  Token 序列  │    Parser(语法分析)
├─────────────┤
│  AST 语法树  │    Binder(语义绑定)
├─────────────┤
│ 逻辑执行计划 │    Optimizer(逻辑优化)
├─────────────┤
│ 物理执行计划 │    Executor(物理执行)
│     ↓       │
│  查询结果    │
└─────────────┘

这五个阶段的划分并不是 SQL 引擎独有的------编译器领域也采用类似的架构。实际上,SQL 引擎本质上就是一个面向数据的领域特定语言(DSL)编译器

1.2 功能范围

我们的迷你 SQL 引擎支持以下查询功能:

功能 示例 状态
SELECT-FROM-WHERE SELECT * FROM t WHERE id > 10
列投影 SELECT name, age FROM t
JOIN(INNER) SELECT * FROM a JOIN b ON a.id = b.id
聚合函数 SELECT COUNT(*), AVG(age) FROM t GROUP BY dept
ORDER BY SELECT * FROM t ORDER BY age DESC
LIMIT SELECT * FROM t LIMIT 10
DISTINCT SELECT DISTINCT city FROM t
嵌套表达式 WHERE (age > 18 AND city = 'BJ') OR level >= 3
多表列引用 SELECT u.name, o.amount FROM users u JOIN orders o

不支持的特性(为保持代码可读性):子查询、窗口函数、CTE、UPDATE/INSERT/DELETE、索引、事务。

1.3 项目结构

复制代码
sql_engine/
├── token.py        # Token 类型定义
├── lexer.py        # 词法分析器
├── ast.py          # AST 节点定义
├── parser.py       # 语法分析器
├── catalog.py      # 元数据目录
├── binder.py       # 语义绑定 + 查询优化(谓词下推)
├── executor.py     # 火山模型执行引擎
└── main.py         # 入口 + 测试用例

二、词法分析器(Lexer)

词法分析是整个 SQL 引擎的第一步。它的任务是将原始的 SQL 字符串分割成一个一个的词法单元(Token) 。比如 SELECT name FROM users WHERE age > 18 会被拆分为:

复制代码
SELECT → name → FROM → users → WHERE → age → > → 18

每个 Token 包含了三个要素:类型(Type)、值(Value)和位置(Line:Col)。

2.1 Token 类型定义

我们需要为 SQL 中的每个概念元素定义对应的 Token 类型:

复制代码
# token.py
from enum import Enum, auto

class TokenType(Enum):
    # ── 关键字 ──
    SELECT = auto();    FROM = auto();     WHERE = auto()
    JOIN = auto();      ON = auto();       AND = auto()
    OR = auto();        NOT = auto();      IN = auto()
    IS = auto();        NULL = auto();     AS = auto()
    GROUP = auto();     BY = auto();       HAVING = auto()
    ORDER = auto();     ASC = auto();      DESC = auto()
    LIMIT = auto();     DISTINCT = auto()
    LEFT = auto();      RIGHT = auto();    INNER = auto()
    CROSS = auto()

    # ── 聚合函数 ──
    COUNT = auto();     SUM = auto();      AVG = auto()
    MIN = auto();       MAX = auto()

    # ── 标识符与字面量 ──
    IDENTIFIER = auto()
    NUMBER = auto()
    STRING = auto()     # 字符串字面量

    # ── 运算符 ──
    EQ = auto()         # =
    NEQ = auto()        # !=  <>
    GT = auto()         # >
    GTE = auto()        # >=
    LT = auto()         # <
    LTE = auto()        # <=
    PLUS = auto()       # +
    MINUS = auto()      # -
    STAR = auto()       # *
    DIV = auto()        # /
    MOD = auto()        # %

    # ── 标点 ──
    LPAREN = auto()     # (
    RPAREN = auto()     # )
    COMMA = auto()      # ,
    DOT = auto()        # .
    SEMICOLON = auto()  # ;

    EOF = auto()        # 文件结束标记


class Token:
    """词法单元"""
    def __init__(self, token_type: TokenType, value: str, line: int, col: int):
        self.type = token_type
        self.value = value
        self.line = line
        self.col = col

    def __repr__(self):
        return f"Token({self.type.name}, '{self.value}', {self.line}:{self.col})"

2.2 词法分析器实现

词法分析器采用经典的逐字符扫描 + 最长匹配策略。核心算法如下:

  1. 从当前位置读取一个字符

  2. 判断字符类型(数字/字母/运算符/引号)

  3. 根据判断结果进入不同的读取分支

  4. 读取完整词后生成 Token,继续扫描下一个

    lexer.py

    from token_def import Token, TokenType

    class Lexer:
    """SQL 词法分析器"""

    复制代码
     # 关键字表:将字符串映射到 TokenType
     KEYWORDS = {
         'SELECT': TokenType.SELECT, 'FROM': TokenType.FROM,
         'WHERE': TokenType.WHERE, 'JOIN': TokenType.JOIN,
         'ON': TokenType.ON, 'AND': TokenType.AND,
         'OR': TokenType.OR, 'NOT': TokenType.NOT,
         'IN': TokenType.IN, 'IS': TokenType.IS,
         'NULL': TokenType.NULL, 'AS': TokenType.AS,
         'GROUP': TokenType.GROUP, 'BY': TokenType.BY,
         'HAVING': TokenType.HAVING, 'ORDER': TokenType.ORDER,
         'ASC': TokenType.ASC, 'DESC': TokenType.DESC,
         'LIMIT': TokenType.LIMIT, 'DISTINCT': TokenType.DISTINCT,
         'LEFT': TokenType.LEFT, 'RIGHT': TokenType.RIGHT,
         'INNER': TokenType.INNER, 'CROSS': TokenType.CROSS,
         'COUNT': TokenType.COUNT, 'SUM': TokenType.SUM,
         'AVG': TokenType.AVG, 'MIN': TokenType.MIN,
         'MAX': TokenType.MAX,
     }
    
     def __init__(self, text: str):
         self.text = text
         self.pos = 0
         self.line = 1
         self.col = 1
         self.length = len(text)
    
     def peek(self) -> str:
         """查看当前字符,不移动指针"""
         return '\0' if self.pos >= self.length else self.text[self.pos]
    
     def advance(self) -> str:
         """读取当前字符并移动指针"""
         ch = self.text[self.pos]
         self.pos += 1
         if ch == '\n':
             self.line += 1
             self.col = 1
         else:
             self.col += 1
         return ch
    
     def skip_whitespace(self):
         """跳过空白字符"""
         while self.pos < self.length and self.peek().isspace():
             self.advance()
    
     def skip_comment(self):
         """跳过单行注释"""
         while self.pos < self.length and self.peek() != '\n':
             self.advance()
    
     def read_string(self) -> Token:
         """读取字符串字面量"""
         start_col = self.col
         self.advance()  # 跳过起始单引号
         value = []
         while self.pos < self.length:
             ch = self.peek()
             if ch == '\\':
                 self.advance()
                 escaped = self.advance()
                 escape_map = {'n': '\n', 't': '\t', 'r': '\r', "'": "'", '\\': '\\'}
                 value.append(escape_map.get(escaped, escaped))
             elif ch == "'":
                 self.advance()
                 return Token(TokenType.STRING, ''.join(value), self.line, start_col)
             else:
                 value.append(self.advance())
         raise SyntaxError(f"L{self.line}:C{start_col} 未闭合的字符串字面量")
    
     def read_number(self) -> Token:
         """读取数字字面量"""
         start_col = self.col
         value = []
         while self.pos < self.length and self.peek().isdigit():
             value.append(self.advance())
         if self.pos < self.length and self.peek() == '.':
             value.append(self.advance())
             while self.pos < self.length and self.peek().isdigit():
                 value.append(self.advance())
             return Token(TokenType.NUMBER, ''.join(value), self.line, start_col)
         return Token(TokenType.NUMBER, ''.join(value), self.line, start_col)
    
     def read_identifier(self) -> Token:
         """读取标识符或关键字"""
         start_col = self.col
         value = []
         while self.pos < self.length and (self.peek().isalnum() or self.peek() == '_'):
             value.append(self.advance())
         word = ''.join(value)
         token_type = self.KEYWORDS.get(word.upper(), TokenType.IDENTIFIER)
         return Token(token_type, word, self.line, start_col)
    
     def tokenize(self) -> list:
         """执行词法分析,返回 Token 列表"""
         tokens = []
         while self.pos < self.length:
             ch = self.peek()
    
             if ch.isspace():
                 self.skip_whitespace()
                 continue
             if ch == '-' and self.pos + 1 < self.length and self.text[self.pos + 1] == '-':
                 self.skip_comment()
                 continue
             if ch == "'":
                 tokens.append(self.read_string())
                 continue
             if ch.isdigit():
                 tokens.append(self.read_number())
                 continue
             if ch.isalpha() or ch == '_':
                 tokens.append(self.read_identifier())
                 continue
    
             start_col = self.col
             if ch == '=': self.advance(); tokens.append(Token(TokenType.EQ, '=', self.line, start_col))
             elif ch == '!':
                 self.advance()
                 if self.peek() == '=':
                     self.advance()
                     tokens.append(Token(TokenType.NEQ, '!=', self.line, start_col))
                 else:
                     raise SyntaxError(f"L{self.line}:C{start_col} 意外的字符 '!'")
             elif ch == '<':
                 self.advance()
                 if self.peek() == '=':
                     self.advance(); tokens.append(Token(TokenType.LTE, '<=', self.line, start_col))
                 elif self.peek() == '>':
                     self.advance(); tokens.append(Token(TokenType.NEQ, '<>', self.line, start_col))
                 else:
                     tokens.append(Token(TokenType.LT, '<', self.line, start_col))
             elif ch == '>':
                 self.advance()
                 if self.peek() == '=':
                     self.advance(); tokens.append(Token(TokenType.GTE, '>=', self.line, start_col))
                 else:
                     tokens.append(Token(TokenType.GT, '>', self.line, start_col))
             elif ch == '+': self.advance(); tokens.append(Token(TokenType.PLUS, '+', self.line, start_col))
             elif ch == '-': self.advance(); tokens.append(Token(TokenType.MINUS, '-', self.line, start_col))
             elif ch == '*': self.advance(); tokens.append(Token(TokenType.STAR, '*', self.line, start_col))
             elif ch == '/': self.advance(); tokens.append(Token(TokenType.DIV, '/', self.line, start_col))
             elif ch == '%': self.advance(); tokens.append(Token(TokenType.MOD, '%', self.line, start_col))
             elif ch == '(': self.advance(); tokens.append(Token(TokenType.LPAREN, '(', self.line, start_col))
             elif ch == ')': self.advance(); tokens.append(Token(TokenType.RPAREN, ')', self.line, start_col))
             elif ch == ',': self.advance(); tokens.append(Token(TokenType.COMMA, ',', self.line, start_col))
             elif ch == '.': self.advance(); tokens.append(Token(TokenType.DOT, '.', self.line, start_col))
             elif ch == ';': self.advance(); tokens.append(Token(TokenType.SEMICOLON, ';', self.line, start_col))
             else:
                 raise SyntaxError(f"L{self.line}:C{self.col} 意外字符 '{ch}'")
    
         tokens.append(Token(TokenType.EOF, '', self.line, self.col))
         return tokens

词法分析的关键设计要点:

  1. 最长匹配 :先读到完整的字母序列,再去查关键字表。这保证了 SELECTION 不会被误识别为 SELECT + ION
  2. 双字符运算符<=>=<>!= 需要超前查看一个字符。
  3. 字符串转义 :支持 \n\t\' 等标准转义序列。
  4. 位置追踪:每个 Token 记录了行列号,便于报错定位。

三、抽象语法树(AST)

在解析 SQL 之前,我们需要设计 AST 节点的数据结构。AST 将扁平的 Token 序列转化为有层次的结构化表示。

3.1 表达式节点

复制代码
# ast.py
class Node:
    """所有 AST 节点的基类"""
    pass

class Expr(Node):
    """表达式基类"""
    pass

class ColumnRef(Expr):
    """列引用"""
    def __init__(self, table: str = None, column: str = ''):
        self.table = table
        self.column = column
    def __repr__(self):
        return f"{self.table}.{self.column}" if self.table else self.column

class Literal(Expr):
    """字面量"""
    def __init__(self, value, value_type: str = None):
        self.value = value
        self.value_type = value_type
    def __repr__(self):
        return repr(self.value)

class BinaryOp(Expr):
    """二元运算"""
    def __init__(self, op: str, left: Expr, right: Expr):
        self.op = op     # '=', '!=', '<', '>', 'AND', 'OR', '+', '-' 等
        self.left = left
        self.right = right
    def __repr__(self):
        return f"({self.left} {self.op} {self.right})"

class UnaryOp(Expr):
    """一元运算"""
    def __init__(self, op: str, operand: Expr):
        self.op = op       # '-', 'NOT'
        self.operand = operand
    def __repr__(self):
        return f"({self.op} {self.operand})"

class FuncCall(Expr):
    """函数调用"""
    def __init__(self, name: str, args: list):
        self.name = name   # 'COUNT', 'SUM', 'AVG', 'MIN', 'MAX'
        self.args = args
    def __repr__(self):
        return f"{self.name}({', '.join(map(repr, self.args))})"

class Star(Expr):
    """星号"""
    def __repr__(self):
        return '*'

3.2 语句节点

复制代码
class SelectStatement(Node):
    """SELECT 查询语句 AST 根节点"""
    def __init__(self):
        self.select_items = []    # [Expr]
        self.from_table = None    # str
        self.from_alias = None    # str
        self.join_clauses = []    # [JoinClause]
        self.where_clause = None  # Expr
        self.group_by = []        # [Expr]
        self.having = None        # Expr
        self.order_by = []        # [OrderByItem]
        self.limit = None         # int
        self.distinct = False

class JoinClause(Node):
    """JOIN 子句"""
    def __init__(self, table: str, join_type: str = 'INNER', condition: Expr = None, alias: str = None):
        self.table = table
        self.alias = alias
        self.join_type = join_type
        self.condition = condition

class OrderByItem(Node):
    """ORDER BY 项"""
    def __init__(self, expr: Expr, direction: str = 'ASC'):
        self.expr = expr
        self.direction = direction

AST 的设计质量直接影响后续分析和优化的容易程度。好的 AST 应该是简洁且无冗余的------每个节点只做一件事,表达式的层级关系通过嵌套自然体现。


四、语法分析器(Parser)

语法分析器将扁平的 Token 序列转化为层次化的 AST。我们采用递归下降分析法(Recursive Descent Parsing)------这是手工编写 SQL 解析器的主流方法。每个语法规则对应一个解析方法,方法之间通过相互调用来处理嵌套结构。

4.1 运算符优先级处理

运算符优先级是表达式解析的核心难题。在 SQL 中,常见的运算符优先级从高到低为:

复制代码
最高优先级:() 括号
  → 一元: - 负号、NOT
  → 算术: * / % → + -
  → 比较: = != < > <= >= IN IS
  → 逻辑: AND
  → OR(最低优先级)

递归下降解析器通过分层解析天然地处理了优先级。每一层只处理自己这一级的运算符,当遇到优先级更高的表达式时,交给下层方法处理:

复制代码
def parse_expr(self) -> Expr:
    return self.parse_or()        # 最低优先级

def parse_or(self) -> Expr:
    left = self.parse_and()       # 遇到 AND 子表达式,交给下层
    while peek(OR):
        right = self.parse_and()  # 右侧也交给下层
        left = BinaryOp('OR', left, right)
    return left

def parse_and(self) -> Expr:
    left = self.parse_comparison()
    while peek(AND):
        right = self.parse_comparison()
        left = BinaryOp('AND', left, right)
    return left

# ... 逐层下降,直到 parse_primary(最高优先级)

这种做法的精妙之处在于:每一层都只处理自己的运算符,如果子表达式中包含更高优先级的运算符,会在下层被优先处理完才返回。1 + 2 * 3 中,parse_add_sub 先调用 parse_mul_div 处理 2 * 3,然后才做加法。

4.2 解析器完整实现

复制代码
# parser.py
from token_def import Token, TokenType
from ast import *

class Parser:
    """递归下降 SQL 解析器"""

    def __init__(self, tokens: list):
        self.tokens = tokens
        self.pos = 0

    def peek(self) -> Token:
        return self.tokens[self.pos]

    def peek_type(self, *types: TokenType) -> bool:
        return self.peek().type in types

    def advance(self) -> Token:
        t = self.tokens[self.pos]; self.pos += 1; return t

    def expect(self, tt: TokenType) -> Token:
        if self.peek().type != tt:
            t = self.peek()
            raise SyntaxError(f"L{t.line}:C{t.col} 期望 {tt.name},但得到 '{t.value}'")
        return self.advance()

    # ── 入口 ──

    def parse(self) -> SelectStatement:
        return self.parse_select()

    def parse_select(self) -> SelectStatement:
        stmt = SelectStatement()
        self.expect(TokenType.SELECT)

        if self.peek_type(TokenType.DISTINCT):
            self.advance()
            stmt.distinct = True

        stmt.select_items = self.parse_select_list()
        self.expect(TokenType.FROM)

        table_token = self.expect(TokenType.IDENTIFIER)
        stmt.from_table = table_token.value

        # 别名(可选)
        if self.peek_type(TokenType.AS):
            self.advance()
        if self.peek_type(TokenType.IDENTIFIER):
            stmt.from_alias = self.advance().value

        # JOIN
        while self.peek_type(TokenType.JOIN, TokenType.LEFT, TokenType.RIGHT,
                             TokenType.INNER, TokenType.CROSS):
            stmt.join_clauses.append(self.parse_join())

        # WHERE
        if self.peek_type(TokenType.WHERE):
            self.advance()
            stmt.where_clause = self.parse_expr()

        # GROUP BY
        if self.peek_type(TokenType.GROUP):
            self.advance(); self.expect(TokenType.BY)
            stmt.group_by = self.parse_expr_list()

        # HAVING
        if self.peek_type(TokenType.HAVING):
            self.advance()
            stmt.having = self.parse_expr()

        # ORDER BY
        if self.peek_type(TokenType.ORDER):
            self.advance(); self.expect(TokenType.BY)
            stmt.order_by = self.parse_order_by_list()

        # LIMIT
        if self.peek_type(TokenType.LIMIT):
            self.advance()
            stmt.limit = int(self.expect(TokenType.NUMBER).value)

        return stmt

    def parse_select_list(self) -> list:
        items = []
        while True:
            if self.peek_type(TokenType.STAR):
                self.advance()
                items.append(Star())
            else:
                items.append(self.parse_expr())
                if self.peek_type(TokenType.AS):
                    self.advance(); self.expect(TokenType.IDENTIFIER)
            if self.peek_type(TokenType.COMMA):
                self.advance()
            else:
                break
        return items

    def parse_join(self) -> JoinClause:
        join_type = 'INNER'
        if self.peek_type(TokenType.LEFT, TokenType.RIGHT):
            jt = self.advance(); join_type = jt.value.upper()
            if self.peek_type(TokenType.JOIN): self.advance()
        elif self.peek_type(TokenType.INNER):
            self.advance(); self.expect(TokenType.JOIN)
        elif self.peek_type(TokenType.CROSS):
            self.advance(); self.expect(TokenType.JOIN)
            return JoinClause(self.expect(TokenType.IDENTIFIER).value, 'CROSS')
        else:
            self.expect(TokenType.JOIN)

        table = self.expect(TokenType.IDENTIFIER).value
        alias = None
        if self.peek_type(TokenType.AS):
            self.advance(); alias = self.expect(TokenType.IDENTIFIER).value
        elif self.peek_type(TokenType.IDENTIFIER) and not self.peek_type(TokenType.ON):
            alias = self.advance().value

        condition = None
        if self.peek_type(TokenType.ON):
            self.advance(); condition = self.parse_expr()

        return JoinClause(table, join_type, condition, alias)

    def parse_order_by_list(self) -> list:
        items = []
        while True:
            expr = self.parse_expr()
            direction = 'ASC'
            if self.peek_type(TokenType.ASC): self.advance()
            elif self.peek_type(TokenType.DESC): self.advance(); direction = 'DESC'
            items.append(OrderByItem(expr, direction))
            if self.peek_type(TokenType.COMMA): self.advance()
            else: break
        return items

    def parse_expr_list(self) -> list:
        items = [self.parse_expr()]
        while self.peek_type(TokenType.COMMA):
            self.advance(); items.append(self.parse_expr())
        return items

    # ── 表达式优先级解析 ──

    def parse_expr(self) -> Expr:
        return self.parse_or()

    def parse_or(self) -> Expr:
        left = self.parse_and()
        while self.peek_type(TokenType.OR):
            self.advance()
            right = self.parse_and()
            left = BinaryOp('OR', left, right)
        return left

    def parse_and(self) -> Expr:
        left = self.parse_not()
        while self.peek_type(TokenType.AND):
            self.advance()
            right = self.parse_not()
            left = BinaryOp('AND', left, right)
        return left

    def parse_not(self) -> Expr:
        if self.peek_type(TokenType.NOT):
            self.advance()
            return UnaryOp('NOT', self.parse_not())
        return self.parse_comparison()

    def parse_comparison(self) -> Expr:
        left = self.parse_add_sub()

        if self.peek_type(TokenType.EQ):
            self.advance(); return BinaryOp('=', left, self.parse_add_sub())
        if self.peek_type(TokenType.NEQ):
            self.advance(); return BinaryOp('!=', left, self.parse_add_sub())
        if self.peek_type(TokenType.GT):
            self.advance(); return BinaryOp('>', left, self.parse_add_sub())
        if self.peek_type(TokenType.GTE):
            self.advance(); return BinaryOp('>=', left, self.parse_add_sub())
        if self.peek_type(TokenType.LT):
            self.advance(); return BinaryOp('<', left, self.parse_add_sub())
        if self.peek_type(TokenType.LTE):
            self.advance(); return BinaryOp('<=', left, self.parse_add_sub())
        if self.peek_type(TokenType.IN):
            self.advance(); self.expect(TokenType.LPAREN)
            values = self.parse_expr_list(); self.expect(TokenType.RPAREN)
            return BinaryOp('IN', left, Literal(values))
        if self.peek_type(TokenType.IS):
            self.advance()
            op = 'IS'
            if self.peek_type(TokenType.NOT): self.advance(); op = 'IS NOT'
            self.expect(TokenType.NULL)
            return BinaryOp(op, left, Literal(None))

        return left

    def parse_add_sub(self) -> Expr:
        left = self.parse_mul_div()
        while self.peek_type(TokenType.PLUS, TokenType.MINUS):
            op = self.advance().value
            right = self.parse_mul_div()
            left = BinaryOp(op, left, right)
        return left

    def parse_mul_div(self) -> Expr:
        left = self.parse_unary()
        while self.peek_type(TokenType.STAR, TokenType.DIV, TokenType.MOD):
            op = self.advance().value
            right = self.parse_unary()
            left = BinaryOp(op, left, right)
        return left

    def parse_unary(self) -> Expr:
        if self.peek_type(TokenType.MINUS):
            self.advance(); return UnaryOp('-', self.parse_unary())
        return self.parse_primary()

    def parse_primary(self) -> Expr:
        # 括号
        if self.peek_type(TokenType.LPAREN):
            self.advance()
            expr = self.parse_expr()
            self.expect(TokenType.RPAREN)
            return expr

        # 聚合函数
        if self.peek_type(TokenType.COUNT, TokenType.SUM, TokenType.AVG,
                          TokenType.MIN, TokenType.MAX):
            func = self.advance().value.upper()
            self.expect(TokenType.LPAREN)
            args = []
            if not self.peek_type(TokenType.RPAREN):
                args = self.parse_expr_list()
            self.expect(TokenType.RPAREN)
            return FuncCall(func, args)

        # 字面量
        if self.peek_type(TokenType.NUMBER):
            t = self.advance()
            val = float(t.value) if '.' in t.value else int(t.value)
            return Literal(val)
        if self.peek_type(TokenType.STRING):
            return Literal(self.advance().value, 'STRING')
        if self.peek_type(TokenType.NULL):
            self.advance(); return Literal(None)

        # 列引用
        if self.peek_type(TokenType.IDENTIFIER):
            first = self.advance().value
            if self.peek_type(TokenType.DOT):
                self.advance()
                second = self.expect(TokenType.IDENTIFIER).value
                return ColumnRef(table=first, column=second)
            return ColumnRef(column=first)

        t = self.peek()
        raise SyntaxError(f"L{t.line}:C{t.col} 意外的 Token: '{t.value}' ({t.type.name})")

五、语义绑定与查询优化

解析后的 AST 还不知道"表里有什么列"、"列是什么类型"。语义绑定(Binding)负责将 AST 中的符号引用解析为具体的表结构。优化器则对执行计划进行等价变换,减少数据扫描量和计算量。

5.1 元数据目录

复制代码
# catalog.py
class ColumnMeta:
    def __init__(self, name: str, col_type: str):
        self.name = name
        self.type = col_type  # 'INTEGER', 'FLOAT', 'STRING', 'BOOLEAN'

class TableMeta:
    def __init__(self, name: str):
        self.name = name
        self.columns = {}

    def add_column(self, name: str, col_type: str):
        self.columns[name] = ColumnMeta(name, col_type)

class Catalog:
    """元数据目录"""
    def __init__(self):
        self.tables = {}

    def add_table(self, name: str, data: list = None):
        self.tables[name] = TableMeta(name)

    def get_table(self, name: str) -> TableMeta:
        if name not in self.tables:
            raise ValueError(f"表 '{name}' 不存在")
        return self.tables[name]

    def resolve_column(self, table_name: str, col_name: str) -> ColumnMeta:
        table = self.get_table(table_name)
        if col_name not in table.columns:
            raise ValueError(f"表 '{table_name}' 中不存在列 '{col_name}'")
        return table.columns[col_name]

5.2 执行计划节点定义

在执行引擎之前,我们需要定义执行计划(PlanNode)的节点类型。每个节点代表执行过程中的一个操作:

复制代码
# executor.py - PlanNode 定义
from abc import ABC, abstractmethod

class PlanNode(ABC):
    """执行计划节点基类"""
    @abstractmethod
    def next(self) -> list:
        """返回一行数据(列值列表),没有数据时返回 None"""
        pass

class ScanNode(PlanNode):
    """全表扫描节点"""
    def __init__(self, table_data: list, columns: list):
        self.data = table_data   # 原始数据
        self.columns = columns   # 列名列表
        self.idx = 0             # 当前读取位置

    def next(self) -> list:
        if self.idx >= len(self.data):
            return None
        row = self.data[self.idx]
        self.idx += 1
        return row

class FilterNode(PlanNode):
    """过滤节点:对输入应用 WHERE 条件"""
    def __init__(self, child: PlanNode, condition, columns: dict):
        self.child = child
        self.condition = condition  # AST 表达式
        self.columns = columns      # 列名 → 列索引的映射

    def next(self) -> list:
        while True:
            row = self.child.next()
            if row is None:
                return None
            if self.evaluate(self.condition, row):
                return row

    def evaluate(self, expr, row) -> bool:
        """递归求值表达式"""
        val = self.eval_expr(expr, row)
        if isinstance(val, bool):
            return val
        # SQL 中 WHERE 条件视为布尔上下文
        return bool(val) if val is not None else False

    def eval_expr(self, expr, row):
        if isinstance(expr, Literal):
            return expr.value
        if isinstance(expr, ColumnRef):
            col_idx = self.columns[expr.column]
            return row[col_idx]
        if isinstance(expr, BinaryOp):
            left = self.eval_expr(expr.left, row)
            right = self.eval_expr(expr.right, row)
            if expr.op == '=': return left == right
            if expr.op == '!=': return left != right
            if expr.op == '>': return left > right
            if expr.op == '>=': return left >= right
            if expr.op == '<': return left < right
            if expr.op == '<=': return left <= right
            if expr.op == 'AND': return left and right
            if expr.op == 'OR': return left or right
            if expr.op == '+': return left + right
            if expr.op == '-': return left - right
            if expr.op == '*': return left * right
            if expr.op == '/': return left / right
            raise ValueError(f"未知运算符: {expr.op}")
        if isinstance(expr, UnaryOp):
            val = self.eval_expr(expr.operand, row)
            if expr.op == 'NOT': return not val
            if expr.op == '-': return -val
        raise ValueError(f"未知表达式: {type(expr)}")

```python
class ProjectNode(PlanNode):
    """投影节点:从输入行中提取需要的列"""
    def __init__(self, child: PlanNode, select_items: list, columns: dict):
        self.child = child
        self.select_items = select_items
        self.columns = columns

    def next(self) -> list:
        row = self.child.next()
        if row is None:
            return None
        result = []
        for item in self.select_items:
            if isinstance(item, Star):
                result.extend(row)
            else:
                result.append(self.eval_expr(item, row))
        return result

    def eval_expr(self, expr, row):
        """递归求值表达式(与 FilterNode 共用逻辑,简化起见复用)"""
        if isinstance(expr, Literal):
            return expr.value
        if isinstance(expr, ColumnRef):
            if expr.table:
                key = f"{expr.table}.{expr.column}"
            else:
                key = expr.column
            col_idx = self.columns.get(key, self.columns.get(expr.column))
            if col_idx is None:
                raise ValueError(f"找不到列: {expr}")
            return row[col_idx]
        if isinstance(expr, BinaryOp):
            left = self.eval_expr(expr.left, row)
            right = self.eval_expr(expr.right, row)
            ops = {'=': lambda a,b: a==b, '!=': lambda a,b: a!=b,
                   '>': lambda a,b: a>b, '<': lambda a,b: a<b,
                   '+': lambda a,b: a+b, '-': lambda a,b: a-b,
                   '*': lambda a,b: a*b, '/': lambda a,b: a/b}
            return ops[expr.op](left, right)
        if isinstance(expr, FuncCall):
            if expr.name == 'COUNT':
                return len(row)  # 简化:COUNT(*) 的场景
            if expr.name == 'SUM':
                return sum(self.eval_expr(a, row) for a in expr.args)
        raise ValueError(f"未知表达式: {type(expr)}")


class JoinNode(PlanNode):
    """JOIN 节点:嵌套循环连接(Nested Loop Join)"""
    def __init__(self, left: PlanNode, right: PlanNode, condition, 
                 left_cols: dict, right_cols: dict):
        self.left = left
        self.right = right
        self.condition = condition
        self.left_cols = left_cols
        self.right_cols = right_cols
        self.left_row = None
        self.right_cache = []  # 缓存右表数据
        self.right_idx = 0
        self._load_right()

    def _load_right(self):
        """将右表全部读入缓存"""
        while True:
            row = self.right.next()
            if row is None:
                break
            self.right_cache.append(row)

    def next(self) -> list:
        while True:
            if self.left_row is None:
                self.left_row = self.left.next()
                if self.left_row is None:
                    return None
                self.right_idx = 0

            while self.right_idx < len(self.right_cache):
                right_row = self.right_cache[self.right_idx]
                self.right_idx += 1
                # 检查 ON 条件
                combined = self.left_row + right_row
                if self.condition is None:
                    return combined  # CROSS JOIN
                combined_cols = {**self.left_cols, **self.right_cols}
                # 简单重映射:合并后的列索引
                if self._eval_condition(combined, combined_cols):
                    return combined

            self.left_row = None  # 左表下一条

        return None

    def _eval_condition(self, row, col_map):
        """评估 ON 条件表达式"""
        # 简化:直接比较两个列值
        if isinstance(self.condition, BinaryOp) and self.condition.op == '=':
            left_val = self._resolve(self.condition.left, row, col_map)
            right_val = self._resolve(self.condition.right, row, col_map)
            return left_val == right_val
        return True  # 简化


class AggregateNode(PlanNode):
    """聚合节点:GROUP BY + 聚合函数"""
    def __init__(self, child: PlanNode, group_by: list, select_items: list, columns: dict):
        self.child = child
        self.group_by = group_by
        self.select_items = select_items
        self.columns = columns
        self.results = None
        self.idx = 0

    def next(self) -> list:
        if self.results is None:
            self._execute()
        if self.idx >= len(self.results):
            return None
        row = self.results[self.idx]
        self.idx += 1
        return row

    def _execute(self):
        """执行聚合运算"""
        groups = {}
        # 读取所有数据并分组
        while True:
            row = self.child.next()
            if row is None:
                break
            key = self._get_group_key(row)
            if key not in groups:
                groups[key] = []
            groups[key].append(row)

        # 对每组计算聚合值
        self.results = []
        for key, rows in groups.items():
            result = []
            for item in self.select_items:
                if isinstance(item, FuncCall):
                    vals = [self._resolve(arg, rows[0], self.columns) for arg in item.args]
                    if item.name == 'COUNT':
                        result.append(len(rows))
                    elif item.name == 'SUM':
                        col_vals = [self._resolve(item.args[0], r, self.columns) for r in rows]
                        result.append(sum(col_vals))
                    elif item.name == 'AVG':
                        col_vals = [self._resolve(item.args[0], r, self.columns) for r in rows]
                        result.append(sum(col_vals) / len(col_vals))
                    elif item.name == 'MIN':
                        col_vals = [self._resolve(item.args[0], r, self.columns) for r in rows]
                        result.append(min(col_vals))
                    elif item.name == 'MAX':
                        col_vals = [self._resolve(item.args[0], r, self.columns) for r in rows]
                        result.append(max(col_vals))
                elif isinstance(item, ColumnRef):
                    result.append(key if item.column in self.group_by else None)
                else:
                    result.append(self._resolve(item, rows[0], self.columns))
            self.results.append(result)


class SortNode(PlanNode):
    """排序节点"""
    def __init__(self, child: PlanNode, order_by: list, columns: dict):
        self.child = child
        self.order_by = order_by
        self.columns = columns
        self.sorted_data = None
        self.idx = 0

    def next(self) -> list:
        if self.sorted_data is None:
            self._sort()
        if self.idx >= len(self.sorted_data):
            return None
        row = self.sorted_data[self.idx]
        self.idx += 1
        return row

    def _sort(self):
        data = []
        while True:
            row = self.child.next()
            if row is None: break
            data.append(row)

        def cmp_key(row):
            keys = []
            for ob in self.order_by:
                val = self._resolve(ob.expr, row, self.columns)
                keys.append((val if ob.direction == 'ASC' else -val 
                           if isinstance(val, (int, float)) else val))
            return tuple(keys)

        data.sort(key=cmp_key, reverse=False)
        self.sorted_data = data


class LimitNode(PlanNode):
    """LIMIT 节点"""
    def __init__(self, child: PlanNode, limit: int):
        self.child = child
        self.limit = limit
        self.count = 0

    def next(self) -> list:
        if self.count >= self.limit:
            return None
        row = self.child.next()
        if row is None:
            return None
        self.count += 1
        return row

核心的执行器已经就位。这些节点通过"火山模型"串联------每个节点只需要实现 next() 方法,上层节点从下层节点"拉取"数据。这种设计的优点在于:每个节点是独立的、可组合的、支持流水线执行。

5.3 谓词下推优化

查询优化中最核心的规则之一是谓词下推(Predicate Pushdown)。它的思想是:将 WHERE 条件尽可能地"推"到数据源附近执行,尽早过滤掉不满足条件的行,减少后续节点的处理量。

复制代码
# binder.py
class Binder:
    """
    语义绑定器 + 查询优化器
    将 AST 转换为优化后的执行计划
    """

    def __init__(self, catalog: Catalog):
        self.catalog = catalog

    def bind_and_optimize(self, stmt: SelectStatement) -> PlanNode:
        """语义绑定 + 优化,返回执行计划根节点"""

        # Step 1: 解析表结构,构建列映射
        table_meta = self.catalog.get_table(stmt.from_table)

        # Step 2: 创建 Scan 节点(全表扫描)
        # 在真实引擎中,这里会做"投影下推"------只扫描 SELECT 和 WHERE 需要的列
        scan_node = ScanNode(
            self.catalog.get_table_data(stmt.from_table),
            list(table_meta.columns.keys())
        )

        current_node = scan_node

        # Step 3: 构建列名到列索引的映射
        col_index = self._build_col_index(table_meta, stmt.from_table)

        # Step 4: 谓词下推优化
        # 如果 WHERE 条件中只有当前表的列引用,直接下推到 Scan 之上
        if stmt.where_clause is not None:
            pushdown_expr = self._can_pushdown(stmt.where_clause, stmt.from_table)
            if pushdown_expr:
                # 谓词已经被推到 Scan 之上
                current_node = FilterNode(current_node, pushdown_expr, col_index)
            else:
                # 包含跨表引用,留在上层
                current_node = FilterNode(current_node, stmt.where_clause, col_index)

        # Step 5: 处理 JOIN
        for join in stmt.join_clauses:
            right_meta = self.catalog.get_table(join.table)
            right_scan = ScanNode(
                self.catalog.get_table_data(join.table),
                list(right_meta.columns.keys())
            )
            right_cols = self._build_col_index(right_meta, join.table, 
                                                len(col_index))

            # 合并列映射
            combined_cols = {**col_index, **right_cols}
            current_node = JoinNode(current_node, right_scan, join.condition,
                                    col_index, right_cols)
            col_index = combined_cols

        # Step 6: 处理 GROUP BY + 聚合
        if stmt.group_by:
            current_node = AggregateNode(current_node, stmt.group_by, 
                                         stmt.select_items, col_index)

        # Step 7: ORDER BY
        if stmt.order_by:
            current_node = SortNode(current_node, stmt.order_by, col_index)

        # Step 8: 投影
        current_node = ProjectNode(current_node, stmt.select_items, col_index)

        # Step 9: LIMIT
        if stmt.limit is not None:
            current_node = LimitNode(current_node, stmt.limit)

        return current_node

    def _can_pushdown(self, expr, table_name: str) -> Expr:
        """
        检查表达式是否能下推到表 table_name。
        如果所有列引用都属于该表,则可以下推。
        """
        if isinstance(expr, ColumnRef):
            if expr.table and expr.table != table_name:
                return None
            return expr
        if isinstance(expr, BinaryOp):
            left = self._can_pushdown(expr.left, table_name)
            right = self._can_pushdown(expr.right, table_name)
            if left and right:
                return BinaryOp(expr.op, left, right)
            return None
        if isinstance(expr, UnaryOp):
            opd = self._can_pushdown(expr.operand, table_name)
            return UnaryOp(expr.op, opd) if opd else None
        if isinstance(expr, Literal):
            return expr
        return None

谓词下推优化的收益到底有多大?

以一个简单的查询为例:

复制代码
SELECT name, age FROM users WHERE age > 30 AND city = 'Beijing'

不下推的执行顺序

  1. Full Scan 读取所有行(假设 1000 万行)

  2. 为每个 SELECT 项目做投影

  3. 最后才过滤 WHERE 条件

下推后的执行顺序

  1. Full Scan 过程中立即检查 age > 30

  2. 只对通过的行再检查 city = 'Beijing'

  3. 最后才做投影

数据点对比:如果 age > 30 过滤掉 90% 的行,则投影节点只处理 10% 的数据,内存和 CPU 开销降低一个数量级。


六、火山模型执行引擎

火山模型(Volcano Model)是数据库执行引擎的经典模式。它的核心理念是每个操作符都实现 next() 方法,上层操作符通过调用下层操作符的 next() 来"拉取"数据。这个模型有三个关键优势:

  1. 流水线执行(Pipelining):数据按行流动,不需要在中间节点物化整个数据集
  2. 模块化组合:可以用"乐高积木"的方式组合操作符
  3. 延迟计算:消费一行才计算一行,避免不必要计算

6.1 执行器入口

复制代码
# executor.py
class Executor:
    """SQL 执行引擎入口"""

    def __init__(self, plan: PlanNode):
        self.plan = plan

    def execute(self) -> list:
        """执行查询,返回所有结果行"""
        results = []
        while True:
            row = self.plan.next()
            if row is None:
                break
            results.append(row)
        return results

6.2 完整执行示例

让我们用一个实际查询来走通整条链路:

复制代码
# 1. 准备数据
users_data = [
    [1, 'Alice',   25, 'Beijing'],
    [2, 'Bob',     32, 'Shanghai'],
    [3, 'Charlie', 28, 'Beijing'],
    [4, 'David',   45, 'Guangzhou'],
    [5, 'Eve',     22, 'Beijing'],
    [6, 'Frank',   35, 'Shanghai'],
]

orders_data = [
    [101, 1, 250.0, '2024-01-15'],
    [102, 2, 180.0, '2024-02-20'],
    [103, 1, 320.0, '2024-03-10'],
    [104, 3, 150.0, '2024-01-28'],
    [105, 5, 500.0, '2024-04-05'],
]

# 2. 设置元数据
catalog = Catalog()
catalog.add_table('users')
catalog.get_table('users').add_column('id', 'INTEGER')
catalog.get_table('users').add_column('name', 'STRING')
catalog.get_table('users').add_column('age', 'INTEGER')
catalog.get_table('users').add_column('city', 'STRING')
catalog.set_table_data('users', users_data)

catalog.add_table('orders')
catalog.get_table('orders').add_column('id', 'INTEGER')
catalog.get_table('orders').add_column('user_id', 'INTEGER')
catalog.get_table('orders').add_column('amount', 'FLOAT')
catalog.get_table('orders').add_column('date', 'STRING')
catalog.set_table_data('orders', orders_data)

# 3. 解析 SQL
lexer = Lexer("SELECT u.name, o.amount, o.date "
              "FROM users u "
              "JOIN orders o ON u.id = o.user_id "
              "WHERE u.city = 'Beijing' AND o.amount > 200 "
              "ORDER BY o.amount DESC")
tokens = lexer.tokenize()
parser = Parser(tokens)
ast = parser.parse()

# 4. 绑定 + 优化
binder = Binder(catalog)
plan = binder.bind_and_optimize(ast)

# 5. 执行
executor = Executor(plan)
results = executor.execute()

# 6. 输出结果
for row in results:
    print(row)

输出结果:

复制代码
['Alice', 320.0, '2024-03-10']
['Alice', 250.0, '2024-01-15']
['Eve', 500.0, '2024-04-05']

七、性能优化实战经验

手写 SQL 引擎的过程让我对数据库查询优化有了更深的理解。这里分享几个在实践中非常有价值的优化思路:

7.1 选择正确的 Join 算法

本文实现的是嵌套循环连接(Nested Loop Join),适合数据量较小的场景。实际数据库中常用的还有:

算法 适用场景 时间复杂度
Nested Loop Join 小表 Join 大表(且有索引) O(N × M)
Hash Join 等值连接,无索引 O(N + M)
Merge Join 两个表都已排序 O(N + M)
Sort-Merge Join 需要排序输出的场景 O(N log N + M log M)

7.2 投影下推

除了谓词下推,投影下推是另一个非常重要的优化。如果 SELECT 只用了 3 列,而表有 50 列,那么扫描节点应该只读取那 3 列的数据。这在行存数据库中效果显著,在列存数据库中更是天生就实现了。

7.3 向量化执行

我们的引擎是一次处理一行数据,这被称为"行迭代器"模式。现代数据库引擎(如 ClickHouse、DuckDB)采用向量化执行------一次处理一批数据(通常为 1024 行)。向量化可以减少虚函数调用开销,更好地利用 CPU 缓存。

复制代码
# 伪代码:向量化 vs 行迭代
# 行迭代(我们的实现)
for row in data:
    process(row)

# 向量化执行
for batch in batches(data, size=1024):
    process_vectorized(batch)  # 用 NumPy 等批量处理

7.4 自适应执行计划

一个有趣的方向是根据数据统计信息动态调整执行计划。例如,如果查询优化器发现某个过滤条件的选择性很低(只匹配极少数行),可以自动切换 Join 顺序或算法。


八、扩展思路

到此为止,我们已经实现了一个迷你 SQL 引擎。如果你有兴趣继续深入,以下是几个扩展方向:

8.1 支持子查询

子查询的解析需要在语法层面支持:expr IN (SELECT ...)EXISTS (SELECT ...)。在 AST 层面,需要把子查询表示为一个子查询节点。在执行层面可以用"相关子查询"(Correlated Subquery)和"非相关子查询"(Non-Correlated Subquery)两种方式处理。

8.2 统计信息驱动的代价优化

真实的查询优化器会收集表的统计信息(行数、列的直方图等),然后计算每个操作的代价,选择总代价最小的执行计划。这是一个搜索问题------在可能的执行计划空间中找出最优解。

8.3 索引支持

为表建立索引后,ScanNode 可以改为 IndexScanNode,通过 B+ 树直接定位需要的数据,将全表扫描 O(N) 降低到索引查找 O(log N)。

8.4 MPP 分布式执行

将当前的单机执行器改为分布式执行器,通过数据分片和任务并行来加速大规模数据查询。这涉及数据分区策略、网络通信、任务调度等系统设计问题。


九、完整执行流程可视化

让我们把整个执行链路串联起来,用一个具体的例子展示数据在各个阶段的形态变化。以这条 SQL 为例:

复制代码
SELECT u.name, o.amount
FROM users u JOIN orders o ON u.id = o.user_id
WHERE u.city = 'Beijing' AND o.amount > 200
ORDER BY o.amount DESC
LIMIT 2

Step 1: 词法分析(字符串 → Token)

复制代码
SELECT   → Token(SELECT, 'SELECT')
u        → Token(IDENTIFIER, 'u')
.        → Token(DOT, '.')
name     → Token(IDENTIFIER, 'name')
FROM     → Token(FROM, 'FROM')
users    → Token(IDENTIFIER, 'users')
u        → Token(IDENTIFIER, 'u')   [别名]
JOIN     → Token(JOIN, 'JOIN')
orders   → Token(IDENTIFIER, 'orders')
o        → Token(IDENTIFIER, 'o')   [别名]
ON       → Token(ON, 'ON')
... 共 28 个 Token

Step 2: 语法分析(Token → AST)

Parser 生成 SelectStatement:

复制代码
SelectStatement(
  select_items: [ColumnRef('u.name'), ColumnRef('o.amount')],
  from_table: 'users',
  from_alias: 'u',
  join_clauses: [JoinClause(
    table: 'orders', alias: 'o',
    condition: BinaryOp(=, ColumnRef('u.id'), ColumnRef('o.user_id'))
  )],
  where_clause: BinaryOp(AND,
    BinaryOp(=, ColumnRef('u.city'), Literal('Beijing')),
    BinaryOp(>, ColumnRef('o.amount'), Literal(200))
  ),
  order_by: [OrderByItem(ColumnRef('o.amount'), 'DESC')],
  limit: 2
)

Step 3: 优化(AST → 优化后的执行计划)

Binder 应用谓词下推和投影下推后生成执行计划:

复制代码
LimitNode(2)
  └── SortNode(ORDER BY o.amount DESC)
       └── ProjectNode([u.name, o.amount])
            └── JoinNode(ON u.id = o.user_id)
                 ├── FilterNode(u.city = 'Beijing')
                 │    └── ScanNode(users, [id, name, city])  ← 投影下推只读 3 列
                 └── ScanNode(orders, [user_id, amount])      ← 投影下推只读 2 列

注意:原本 WHERE 中的 u.city = 'Beijing' 被下推到了 Join 之前执行,在 ScanNode 之上就完成了过滤。

Step 4: 执行(逐层拉取数据)

执行过程从最上层(LimitNode)开始,逐层调用 next():

复制代码
1. LimitNode.next() → SortNode.next() → ProjectNode.next() → JoinNode.next()
   → FilterNode.next() → ScanNode.next()
   → 返回 Alice 的行 (age=25, city='Beijing') → 通过过滤 ✓
   → JoinNode 读取右表
   → 返回 (Alice, 250.0) → 通过 ORDER BY DES 排序...
2. 拉取下一行
   → 返回 (Alice, 320.0)
3. 拉取下一行
   → FilterNode.next() → 下一条: Charlie (city='Beijing') → 通过 ✓
   → 返回 (Charlie, 150.0) → 金额 150 < 200,被 WHERE 过滤 ✗
4. FilterNode.next() → 下一条: Eve (city='Beijing') → 通过 ✓
   → 返回 (Eve, 500.0)
5. SortNode 收集所有通过的行,按 amount DESC 排序
6. LimitNode 取前 2 行
   → ['Eve', 500.0]
   → ['Alice', 320.0]

最终结果:

复制代码
['Eve', 500.0]
['Alice', 320.0]

这个可视化过程清晰展示了"火山模型"的核心思想:数据在生产时就消费了,不需要遍历全表后再过滤。FilterNode 在 Scan 读取完第一行就立刻判断是否满足条件,不满足就直接丢弃,不会传递到上层的 JoinNode。这就是谓词下推带来的优势。


十、总结与思考

本文从零开始,用 Python 实现了一个包含词法分析、语法分析、语义绑定、逻辑优化和火山模型执行的 SQL 查询引擎。全流程走下来,我们可以看到一个 SELECT 查询从字符串到结果的完整旅程:

  1. 词法分析 将 SQL 拆成 Token → 语法分析 构建 AST → 语义绑定 解析符号引用 → 优化器 做谓词下推等变换 → 执行器按火山模型逐行计算

理解 SQL 引擎内部机制的价值不仅在于满足好奇心,更在于日常工作中能做出更优的技术决策:

  • 写 SQL 时:知道 WHERE 条件的顺序基本不影响性能(优化器会自动调整),但知道哪个过滤条件的"选择性"高,就能帮助理解执行计划
  • 调优时:看到执行计划中的 Filter、Scan,能快速定位瓶颈,提出加索引或改写 SQL 的方案
  • 评估新数据库时:了解其查询引擎的架构(MPP vs 单机、行存 vs 列存、向量化 vs 行迭代),能做出更理性的技术选型

我们的迷你引擎只有几百行代码,但"麻雀虽小,五脏俱全"。建议你自己动手跑一遍代码,尝试修改或扩展功能------这是理解数据库系统最有效的学习路径。


附录:完整代码 & 快速上手

所有代码已上传至 GitHub(示例仓库),也可以直接复制本文中的各模块组合运行。启动示例:

复制代码
# 将所有代码放入 sql_engine/ 目录
python -c "
from lexer import Lexer
from parser import Parser
from catalog import Catalog
from binder import Binder
from executor import Executor

# 一分钟上手
catalog = Catalog()
catalog.add_table('users')
catalog.get_table('users').add_column('id', 'INTEGER')
catalog.get_table('users').add_column('name', 'STRING')
catalog.get_table('users').add_column('age', 'INTEGER')
catalog.set_table_data('users', [[1, 'Tom', 25], [2, 'Jerry', 30], [3, 'Spike', 35]])

sql = 'SELECT name, age FROM users WHERE age >= 30'
tokens = Lexer(sql).tokenize()
ast = Parser(tokens).parse()
plan = Binder(catalog).bind_and_optimize(ast)
results = Executor(plan).execute()

for row in results:
    print(row)
# 输出: ['Jerry', 30]
#       ['Spike', 35]
"

本文是「从零手写系列」的数据库引擎篇。关注我,获取更多深入底层的实战教程!

📚 相关推荐DeepSeek 模型推理优化实战指南 --- 从推理引擎视角理解系统性能优化。

如果你对本文有任何疑问或建议,欢迎在评论区留言讨论!


本文为原创技术文章,转载请注明出处。

相关推荐
Yan-英杰1 小时前
亮数据 - Ticket_Hunter_Agent
人工智能·神经网络·机器学习·ai开发工具
编码如写诗1 小时前
瑞芯微RK3588+麒麟V10国防版+昇腾310异构部署k8s集群+KubeSphere
人工智能·ai·云原生·kubernetes
ai产品老杨1 小时前
基于 Docker 与 GB28181/RTSP 的边缘计算 AI 视频管理平台:高并发流媒体解耦与源码交付架构深析
人工智能·docker·边缘计算
Dymc1 小时前
【论文解析】CoPCS — 让无人机与无人车“心有灵犀“的协同规划框架
人工智能·无人机·视觉定位·低空经济·无人集群
Raink老师1 小时前
【AI面试临阵磨枪-78】本地生活 Agent:外卖、到店、打车、酒店、售后全链路设计
人工智能·面试·生活
GlobalInfo1 小时前
人工智能NFT生成工具行业调查、市场规模、排名分析报告2026-2032
人工智能·百度
Tisfy1 小时前
LeetCode 3121.统计特殊字母的数量 II:状态机
算法·leetcode·题解·状态机
zzzsde1 小时前
【Linux网络】传输层协议UDP
linux·服务器·开发语言·网络·算法·udp
CoCo的编程之路1 小时前
像素级突围:如何利用智能前端开发助手最大化提升页面构建速度?
前端·人工智能·ai编程·智能编程助手·文心快码baiducomate