1. 基本概念与定义
抽象语法树(Abstract Syntax Tree,AST)是源代码抽象语法结构的树状表示,它捕获了程序的逻辑结构,同时省略了具体的语法细节(如分号、括号等)。
核心特征:
- 抽象性:不包含具体语法符号
- 层次结构:反映程序的嵌套结构
- 语义焦点:专注于程序的含义而非形式
2. AST vs 具体语法树(CST)
text
# 源代码: x = 10 + 5 * 2
# 具体语法树(包含所有语法细节)
# Assignment
# ├── Identifier "x"
# ├── Equals "="
# └── Expression
#     ├── Integer "10"
#     ├── Plus "+"
#     └── Expression
#         ├── Integer "5"
#         ├── Multiply "*"
#         └── Integer "2"
# 抽象语法树(只保留逻辑结构)
# Assignment("x")
# └── BinaryExpression("+")
#     ├── IntegerLiteral(10)
#     └── BinaryExpression("*")
#         ├── IntegerLiteral(5)
#         └── IntegerLiteral(2)
3. AST在编译流程中的位置
text
源代码 → 词法分析 → 语法分析 → 语义分析 → 中间代码生成 → 优化 → 代码生成
                        ↓
                     AST生成(语法分析阶段的输出)
4. 专业设计模式
访问者模式(Visitor Pattern)
python
class ASTVisitor:
    """Minimal visitor base: ``visit`` delegates to the node's ``accept``.

    Double dispatch: the node's ``accept`` selects the handler by node type,
    and that handler is looked up on this visitor (the visitor type).
    """

    def visit(self, node):
        """Visit *node* and return whatever its ``accept`` returns."""
        # Returning the result lets callers use the value computed for a
        # subtree (e.g. the operand types in a type checker).
        return node.accept(self)
class TypeChecker(ASTVisitor):
    """Type-check binary expressions from the types of their two operands."""

    def visit_binary_expression(self, node):
        """Return the result of checking *node*'s operator against its operand types.

        ``check_compatibility`` is not defined here; presumably a subclass or
        a fuller version of this class provides it.
        """
        # Left operand first, then right, exactly as the visits are listed.
        operand_types = [self.visit(side) for side in (node.left, node.right)]
        return self.check_compatibility(*operand_types, node.operator)
组合模式(Composite Pattern)
python
# AST nodes form a tree and share one interface (Composite pattern).
class ASTNode:
    """Common interface implemented by every AST node."""

    def accept(self, visitor):
        """Dispatch *visitor* to this node; concrete node classes override this.

        Raises:
            NotImplementedError: the node class did not override ``accept``.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement accept()")
class Expression(ASTNode):
    """Base class for expression nodes."""
class Statement(ASTNode):
    """Base class for statement nodes."""
5. AST的核心属性
结构属性
- 深度:从根节点到最远叶节点的距离
- 广度:每层的节点数量
- 子树:每个节点及其后代构成的子结构
语义属性
- 作用域:变量可见性范围
- 类型信息:表达式类型注解
- 控制流:程序执行路径
6. AST在编译器中的应用
语义分析
python
class SemanticAnalyzer(ASTVisitor):
    """Semantic checks over the AST: declaration and assignment-type checks."""

    def __init__(self):
        super().__init__()
        # Queried for whether a name is declared and for its declared type.
        self.symbol_table = SymbolTable()

    def visit_assignment(self, node):
        """Check that ``node.variable`` is declared and ``node.value`` is assignable to it.

        Raises:
            SemanticError: the target variable has not been declared.
            TypeError: the value's type is not assignable to the variable's type.
        """
        # The target variable must have been declared.
        if not self.symbol_table.is_declared(node.variable):
            raise SemanticError(f"Undeclared variable: {node.variable}")
        # The value's type must be assignable to the declared type.
        var_type = self.symbol_table.get_type(node.variable)
        expr_type = self.visit(node.value)
        if not self.is_assignable(var_type, expr_type):
            raise TypeError(
                f"Incompatible types in assignment to {node.variable}: "
                f"cannot assign {expr_type} to {var_type}"
            )
代码生成
python
class CodeGenerator(ASTVisitor):
    """Lower expressions to three-address code, one fresh temporary per binary operation."""

    def visit_binary_expression(self, node):
        """Emit ``t = a op b`` for *node* and return the temporary ``t``."""
        # Operands are lowered left to right; each visit yields the name
        # holding that operand's value.
        lhs, rhs = self.visit(node.left), self.visit(node.right)
        dest = self.new_temp()
        self.emit(f"{dest} = {lhs} {node.operator} {rhs}")
        return dest
优化转换
python
import operator


class ConstantFolding(ASTVisitor):
    """Fold binary expressions whose operands are both constants (e.g. 5 * 2 -> 10).

    Operators are applied through a fixed table of Python functions rather
    than ``eval``. Rebuilding source text and evaluating it breaks on string
    constants and executes arbitrary text.
    """

    # Operators that may be folded at compile time; any other operator is
    # left unfolded.
    _FOLDERS = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.truediv,
        "//": operator.floordiv,
        "%": operator.mod,
        "==": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        "<=": operator.le,
        ">": operator.gt,
        ">=": operator.ge,
    }

    def visit_binary_expression(self, node):
        """Return a folded ``Constant``, or a ``BinaryExpression`` over the visited operands."""
        left = self.visit(node.left)
        right = self.visit(node.right)
        fold = self._FOLDERS.get(node.operator)
        if fold is not None and isinstance(left, Constant) and isinstance(right, Constant):
            try:
                return Constant(fold(left.value, right.value))
            except (ArithmeticError, TypeError):
                # e.g. 1 / 0 or "a" - 1: don't crash the compiler; keep the
                # expression so the error surfaces where it is evaluated.
                pass
        return BinaryExpression(node.operator, left, right)
7. 现代AST设计趋势
可持久化AST
python
from dataclasses import dataclass, replace
from typing import Tuple


# Immutable AST node: updates build a new node and leave the old one intact,
# so earlier versions of the tree stay valid (version management).
@dataclass(frozen=True)
class PersistentASTNode:
    node_id: str
    # Child nodes; a tuple keeps the node immutable.
    children: Tuple['PersistentASTNode', ...]

    def with_updated_child(self, index, new_child):
        """Return a copy of this node whose child at *index* is *new_child*.

        ``self`` is not modified, and the other children are shared with the
        copy. ``dataclasses.replace`` keeps the concrete class and any extra
        fields a subclass declares.

        Raises:
            IndexError: *index* is out of range for ``children``.
        """
        new_children = list(self.children)
        new_children[index] = new_child
        return replace(self, children=tuple(new_children))
属性文法集成
python
class AttributedAST:
    """AST node wrapper that carries attribute-grammar attributes.

    ``attributes`` stores both synthesized and inherited attribute values.
    A value may be pre-seeded (e.g. an inherited attribute set from outside).
    Otherwise it is computed on first request and cached.
    """

    def __init__(self, node):
        self.node = node
        # Synthesized and inherited attribute values, keyed by attribute name.
        self.attributes = {}

    def get_attribute(self, name):
        """Return attribute *name*, computing and caching it on first access."""
        # Use `in` rather than .get() so a legitimately-None value is cached too.
        if name not in self.attributes:
            self.attributes[name] = self.compute_attribute(name)
        return self.attributes[name]

    def compute_attribute(self, name):
        """Compute attribute *name* from its dependencies.

        Hook for subclasses; the base implementation returns None.
        """
        return None
8. AST序列化与持久化
JSON序列化
json
{
"type": "BinaryExpression",
"operator": "+",
"left": {
"type": "IntegerLiteral",
"value": 10
},
"right": {
"type": "BinaryExpression",
"operator": "*",
"left": {"type": "IntegerLiteral", "value": 5},
"right": {"type": "IntegerLiteral", "value": 2}
}
}
9. 专业工具与库
工业级AST库
- ANTLR: 强大的解析器生成器
- Tree-sitter: 增量解析,支持多种语言
- Eclipse JDT: Java AST工具包
- Clang AST: C/C++前端
10. AST的扩展应用
静态分析
python
class DataFlowAnalyzer:
    """Placeholder for AST-based data-flow analyses.

    Planned analyses:
    - definition-use chains
    - liveness analysis
    - constant propagation

    None are implemented yet.
    """

    def analyze_ast(self, ast):
        """Analyze *ast*; currently a no-op that returns None."""
        return None
重构工具
python
class RefactoringEngine:
    """Placeholder for AST-driven refactorings."""

    def extract_method(self, ast, start_line, end_line):
        """Extract the code between *start_line* and *end_line* into a new method.

        Intended steps:
        1. identify the code fragment;
        2. create the new method;
        3. replace the original code with a call to it.

        Not implemented yet: currently returns None and changes nothing.
        """
        return None
语言服务器协议(LSP)
python
class LanguageServer:
    """Placeholder for LSP features built on the AST."""

    def provide_completions(self, ast, position):
        """Return completion candidates for *position* in *ast*.

        Intended to combine:
        - AST-based smart completion
        - type inference
        - scope analysis

        Not implemented yet: currently returns None.
        """
        return None
总结
AST是现代编译器、静态分析工具和IDE的核心数据结构。它的专业价值体现在:
- 语义准确性:精确表达程序含义
- 可扩展性:支持多种分析和转换
- 工具友好:便于构建开发工具链
- 跨平台性:独立于具体目标机器
下面给出一个完整的Python实现示例(Demo):
python
import operator
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
class ASTNode(ABC):
    """Abstract root of every node in the demo AST.

    A concrete node implements double dispatch through ``accept`` and
    supplies a one-line label through ``__str__``.
    """

    @abstractmethod
    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Call the ``visit_*`` method of *visitor* that matches this node's class."""

    @abstractmethod
    def __str__(self) -> str:
        """Return the short label printed for this node (ASTPrinter uses it)."""
class Expression(ASTNode):
    """Base class for expression nodes (literals, variables, binary operations)."""
class Statement(ASTNode):
    """Base class for statement nodes (assignment, print, block)."""
class Literal(Expression):
    """Base class for constant leaf expressions wrapping a Python value.

    Still abstract: subclasses supply ``accept`` and ``__str__``.
    """

    def __init__(self, value: Any):
        # Read back by Evaluator through get_value().
        self.value = value

    def get_value(self) -> Any:
        """Return the wrapped constant unchanged."""
        return self.value
class IntegerLiteral(Literal):
    """Integer constant leaf, e.g. the ``10`` in ``10 + 5 * 2``."""

    def __init__(self, value: int):
        # Only narrows the annotated type; Literal stores the value.
        super().__init__(value)

    def __str__(self) -> str:
        return f"Integer({self.value})"

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_integer_literal``."""
        return visitor.visit_integer_literal(self)
class StringLiteral(Literal):
    """String constant leaf."""

    def __init__(self, value: str):
        # Only narrows the annotated type; Literal stores the value.
        super().__init__(value)

    def __str__(self) -> str:
        # The value is wrapped in double quotes as-is (no escaping).
        return f'String("{self.value}")'

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_string_literal``."""
        return visitor.visit_string_literal(self)
class BinaryExpression(Expression):
    """Binary operation node: an operator symbol and two operand sub-expressions."""

    def __init__(self, operator: str, left: Expression, right: Expression):
        self.operator = operator  # symbol such as "+" or "*"
        self.left = left
        self.right = right

    def __str__(self) -> str:
        # Only the operator is shown; operands are printed as child lines.
        return f"Binary({self.operator})"

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_binary_expression``."""
        return visitor.visit_binary_expression(self)

    # Read-only accessors used by the visitors.

    def get_left(self) -> Expression:
        return self.left

    def get_right(self) -> Expression:
        return self.right

    def get_operator(self) -> str:
        return self.operator
class Variable(Expression):
    """Reference to a variable by name."""

    def __init__(self, name: str):
        self.name = name

    def __str__(self) -> str:
        return f"Variable({self.name})"

    def get_name(self) -> str:
        """Return the referenced variable name."""
        return self.name

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_variable``."""
        return visitor.visit_variable(self)
class Assignment(Statement):
    """Statement ``variable = value``; the target is a plain name string."""

    def __init__(self, variable: str, value: Expression):
        self.variable = variable  # target name
        self.value = value        # right-hand side expression

    def __str__(self) -> str:
        return f"Assignment({self.variable})"

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_assignment``."""
        return visitor.visit_assignment(self)

    def get_variable(self) -> str:
        return self.variable

    def get_value(self) -> Expression:
        return self.value
class PrintStatement(Statement):
    """Statement that prints the value of a single expression."""

    def __init__(self, expression: Expression):
        self.expression = expression

    def __str__(self) -> str:
        return "PrintStatement"

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_print_statement``."""
        return visitor.visit_print_statement(self)

    def get_expression(self) -> Expression:
        """Return the expression whose value is printed."""
        return self.expression
class Block(Statement):
    """Ordered sequence of statements; visitors walk them in insertion order."""

    def __init__(self):
        # Starts empty; filled through add_statement().
        self.statements: List[Statement] = []

    def __str__(self) -> str:
        return f"Block({len(self.statements)} statements)"

    def accept(self, visitor: 'ASTVisitor') -> Any:
        """Dispatch to ``visitor.visit_block``."""
        return visitor.visit_block(self)

    def add_statement(self, statement: Statement):
        """Append *statement* to the end of the block."""
        self.statements.append(statement)

    def get_statements(self) -> List[Statement]:
        """Return the underlying statement list (not a copy)."""
        return self.statements
class ASTVisitor(ABC):
    """Interface for operations over the demo AST: one ``visit_*`` method per node class.

    Each node's ``accept`` calls the matching method, so a subclass must
    implement every one of them before it can be instantiated.
    """

    # Leaf expressions

    @abstractmethod
    def visit_integer_literal(self, node: IntegerLiteral) -> Any:
        """Handle an IntegerLiteral node."""

    @abstractmethod
    def visit_string_literal(self, node: StringLiteral) -> Any:
        """Handle a StringLiteral node."""

    # Composite expressions and names

    @abstractmethod
    def visit_binary_expression(self, node: BinaryExpression) -> Any:
        """Handle a BinaryExpression node (its operands are not visited automatically)."""

    @abstractmethod
    def visit_variable(self, node: Variable) -> Any:
        """Handle a Variable reference."""

    # Statements

    @abstractmethod
    def visit_assignment(self, node: Assignment) -> Any:
        """Handle an Assignment statement."""

    @abstractmethod
    def visit_print_statement(self, node: PrintStatement) -> Any:
        """Handle a PrintStatement."""

    @abstractmethod
    def visit_block(self, node: Block) -> Any:
        """Handle a Block of statements."""
class ASTPrinter(ASTVisitor):
    """Visitor that prints the tree to stdout, one node label per line.

    Each nesting level indents a line by one more space.
    """

    def __init__(self):
        # Current depth; raised while the children of a node are printed.
        self.indent_level = 0

    def _print_indent(self):
        """Write the indentation for the current depth, without a newline."""
        print(" " * self.indent_level, end="")

    def _print_node(self, node):
        # One output line: indentation followed by the node's label.
        self._print_indent()
        print(str(node))

    def _print_children(self, children):
        # Visit each child one level deeper, then restore the depth.
        self.indent_level += 1
        for child in children:
            child.accept(self)
        self.indent_level -= 1

    # Leaves: label only.

    def visit_integer_literal(self, node):
        self._print_node(node)

    def visit_string_literal(self, node):
        self._print_node(node)

    def visit_variable(self, node):
        self._print_node(node)

    # Inner nodes: label, then children one level deeper.

    def visit_binary_expression(self, node):
        self._print_node(node)
        self._print_children((node.get_left(), node.get_right()))

    def visit_assignment(self, node):
        self._print_node(node)
        self._print_children((node.get_value(),))

    def visit_print_statement(self, node):
        self._print_node(node)
        self._print_children((node.get_expression(),))

    def visit_block(self, node):
        self._print_node(node)
        self._print_children(node.get_statements())
class Evaluator(ASTVisitor):
    """Tree-walking interpreter for the demo AST.

    Visit methods return None. Expression visits leave their value in
    ``self.result``; statements update ``self.variables`` or print.
    """

    # Supported binary operators mapped to their Python implementation.
    _BINARY_OPS = {
        "+": operator.add,
        "-": operator.sub,
        "*": operator.mul,
        "/": operator.truediv,
        "==": operator.eq,
        "!=": operator.ne,
    }

    def __init__(self):
        # Variable environment: name -> current value.
        self.variables: Dict[str, Any] = {}
        # Value of the most recently evaluated expression.
        self.result: Any = None

    def get_result(self) -> Any:
        """Return the value of the most recently evaluated expression."""
        return self.result

    def visit_integer_literal(self, node):
        self.result = node.get_value()

    def visit_string_literal(self, node):
        self.result = node.get_value()

    def visit_binary_expression(self, node):
        """Evaluate the left operand, then the right, then apply the operator.

        Raises:
            ValueError: the operator is not one of ``_BINARY_OPS``.
        """
        node.get_left().accept(self)
        left_val = self.result
        node.get_right().accept(self)
        right_val = self.result
        op = node.get_operator()
        try:
            apply_op = self._BINARY_OPS[op]
        except KeyError:
            # Fail loudly: otherwise self.result would still hold the right
            # operand and be mistaken for the expression's value.
            raise ValueError(f"Unsupported binary operator: {op!r}") from None
        self.result = apply_op(left_val, right_val)

    def visit_variable(self, node):
        # Unknown names evaluate to 0 instead of raising.
        self.result = self.variables.get(node.get_name(), 0)

    def visit_assignment(self, node):
        node.get_value().accept(self)
        self.variables[node.get_variable()] = self.result

    def visit_print_statement(self, node):
        node.get_expression().accept(self)
        print(f"Print: {self.result}")

    def visit_block(self, node):
        for stmt in node.get_statements():
            stmt.accept(self)
def main():
    """Build the AST for a tiny program, print its structure, then evaluate it.

    The program is:
        x = 10 + 5 * 2
        print x
        y = x + 1
        print y
    """
    # 10 + (5 * 2): the multiplication is the right child of the addition.
    x_value = BinaryExpression(
        "+",
        IntegerLiteral(10),
        BinaryExpression("*", IntegerLiteral(5), IntegerLiteral(2)),
    )
    y_value = BinaryExpression("+", Variable("x"), IntegerLiteral(1))

    program = Block()
    for statement in (
        Assignment("x", x_value),
        PrintStatement(Variable("x")),
        Assignment("y", y_value),
        PrintStatement(Variable("y")),
    ):
        program.add_statement(statement)

    print("=== AST Structure ===")
    program.accept(ASTPrinter())

    print("\n=== Evaluation ===")
    program.accept(Evaluator())


if __name__ == "__main__":
    main()
运行结果:
text
=== AST Structure ===
Block(4 statements)
 Assignment(x)
  Binary(+)
   Integer(10)
   Binary(*)
    Integer(5)
    Integer(2)
 PrintStatement
  Variable(x)
 Assignment(y)
  Binary(+)
   Variable(x)
   Integer(1)
 PrintStatement
  Variable(y)

=== Evaluation ===
Print: 20
Print: 21