Abstract Syntax Trees (AST) and a Python Demo Implementation

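1. Basic Concepts and Definitions

An abstract syntax tree (AST) is a tree representation of the abstract syntactic structure of source code. It captures the logical structure of a program while omitting concrete syntax details such as semicolons and parentheses.

Core characteristics:

  • Abstraction: contains no purely syntactic symbols

  • Hierarchy: mirrors the nesting structure of the program

  • Semantic focus: concerned with what the program means rather than how it is written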
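
A quick way to see this abstraction at work is Python's built-in ast module. It is used here only as an illustration and is not the AST design built later in this article. In this small sketch, redundant parentheses leave no trace in the tree:

```python
import ast

# Two spellings of the same statement: one with redundant parentheses
plain = ast.dump(ast.parse("x = 10 + 5 * 2"))
wrapped = ast.dump(ast.parse("x = (10 + (5 * 2))"))

# ast.dump omits line/column attributes by default, so only structure is compared
same_tree = plain == wrapped
print(same_tree)  # True
```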

2. AST vs. Concrete Syntax Tree (CST)

python
# Source: x = 10 + 5 * 2

# Concrete syntax tree (keeps every syntactic detail)
# Assignment
# ├── Identifier "x"
# ├── Equals "="
# └── Expression
#     ├── Integer "10"
#     ├── Plus "+"
#     └── Expression
#         ├── Integer "5"
#         ├── Multiply "*"
#         └── Integer "2"

# Abstract syntax tree (keeps only the logical structure)
# Assignment("x")
# └── BinaryExpression("+")
#     ├── IntegerLiteral(10)
#     └── BinaryExpression("*")
#         ├── IntegerLiteral(5)
#         └── IntegerLiteral(2)
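
The same shape can be inspected in a real AST. The sketch below parses the statement with Python's built-in ast module, whose node names (Assign, BinOp, Add, Mult) differ from the illustrative names in the diagram above, and shows that operator precedence is encoded purely by nesting:

```python
import ast

assign = ast.parse("x = 10 + 5 * 2").body[0]
outer = assign.value   # the "+" node at the top of the expression
inner = outer.right    # the "*" node nested under it

print(type(assign).__name__, assign.targets[0].id)                    # Assign x
print(type(outer.op).__name__, outer.left.value)                      # Add 10
print(type(inner.op).__name__, inner.left.value, inner.right.value)   # Mult 5 2
```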

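3. Where the AST Fits in the Compilation Pipeline

text
Source code → Lexical analysis → Syntax analysis → Semantic analysis → Intermediate code generation → Optimization → Code generation
                                        ↓
                                 AST construction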
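
The first two stages of the pipeline can be sketched in a few dozen lines. The following is a minimal illustration, not a production parser: tokenize and parse are hypothetical names, the grammar covers only integers, + - * / and parentheses, and the AST is represented as nested tuples rather than node classes.

```python
import re

TOKEN_RE = re.compile(r"\s*(?:(\d+)|(.))")

def tokenize(source):
    """Lexical analysis: split the source text into (kind, value) tokens."""
    tokens = []
    for number, symbol in TOKEN_RE.findall(source):
        if number:
            tokens.append(("INT", int(number)))
        elif symbol.strip():
            tokens.append(("OP", symbol))
    return tokens

def parse(tokens):
    """Syntax analysis: build a nested-tuple AST with the usual precedence."""
    pos = 0

    def peek():
        return tokens[pos] if pos < len(tokens) else ("EOF", None)

    def advance():
        nonlocal pos
        token = peek()
        pos += 1
        return token

    def factor():
        kind, value = advance()
        if kind == "INT":
            return ("int", value)
        if (kind, value) == ("OP", "("):
            node = expr()
            if advance() != ("OP", ")"):
                raise SyntaxError("expected ')'")
            return node  # the parentheses themselves leave no node behind
        raise SyntaxError(f"unexpected token: {value!r}")

    def term():
        node = factor()
        while peek() in (("OP", "*"), ("OP", "/")):
            op = advance()[1]
            node = (op, node, factor())
        return node

    def expr():
        node = term()
        while peek() in (("OP", "+"), ("OP", "-")):
            op = advance()[1]
            node = (op, node, term())
        return node

    tree = expr()
    if peek()[0] != "EOF":
        raise SyntaxError(f"unexpected token: {peek()[1]!r}")
    return tree

print(parse(tokenize("10 + 5 * 2")))
# ('+', ('int', 10), ('*', ('int', 5), ('int', 2)))
```

Parsing "(10 + 5) * 2" with the same functions puts the "+" node under the "*" node, which is how grouping survives even though the parentheses do not.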

4. Design Patterns

Visitor Pattern

python
class ASTVisitor:
    def visit(self, node):
        # Double dispatch: the result depends on both the node type and the visitor type
        return node.accept(self)

class TypeChecker(ASTVisitor):
    def visit_binary_expression(self, node):
        left_type = self.visit(node.left)
        right_type = self.visit(node.right)
        # check_compatibility is left undefined in this sketch
        return self.check_compatibility(left_type, right_type, node.operator)
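
For comparison, Python's standard library ships a visitor that dispatches on the node's class name instead of an accept method: ast.NodeVisitor looks up a method called visit_<ClassName> and falls back to generic_visit. The sketch below uses it to count binary operators; OperatorCounter is a name made up for this example.

```python
import ast
from collections import Counter

class OperatorCounter(ast.NodeVisitor):
    """Counts binary operators by operator type."""

    def __init__(self):
        self.counts = Counter()

    def visit_BinOp(self, node):
        # Called only for BinOp nodes; the method name selects the node type
        self.counts[type(node.op).__name__] += 1
        # Keep walking so nested operators are counted too
        self.generic_visit(node)

counter = OperatorCounter()
counter.visit(ast.parse("x = 10 + 5 * 2"))
print(dict(counter.counts))  # {'Add': 1, 'Mult': 1}
```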

Composite Pattern

python
# AST nodes form a tree behind one uniform interface
class ASTNode:
    def accept(self, visitor): pass

class Expression(ASTNode): pass
class Statement(ASTNode): pass
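
What the uniform interface buys is that leaves and composites can be treated alike. A minimal sketch, assuming hypothetical Node, Num and BinOp classes that are separate from the demo at the end of this article:

```python
from dataclasses import dataclass
from typing import Iterator

class Node:
    def children(self) -> Iterator["Node"]:
        # Leaves have no children; composites override this
        return iter(())

    def count(self) -> int:
        # Works for any node because every node answers children()
        return 1 + sum(child.count() for child in self.children())

@dataclass
class Num(Node):
    value: int

@dataclass
class BinOp(Node):
    op: str
    left: Node
    right: Node

    def children(self) -> Iterator[Node]:
        yield self.left
        yield self.right

tree = BinOp("+", Num(10), BinOp("*", Num(5), Num(2)))
print(tree.count())  # 5
```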

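5. Core Properties of an AST

Structural properties

  • Depth: the distance from the root to the farthest leaf

  • Breadth: the number of nodes at each level

  • Subtree: the substructure formed by a node and all of its descendants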
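
These structural properties are easy to compute. The sketch below uses a throwaway Node class (an assumption for this example) and measures the AST of x = 10 + 5 * 2, counting depth in edges:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class Node:
    label: str
    children: List["Node"] = field(default_factory=list)

def depth(node):
    """Number of edges on the longest path from node down to a leaf."""
    if not node.children:
        return 0
    return 1 + max(depth(child) for child in node.children)

def breadth_per_level(root):
    """Number of nodes at each level, starting with the root level."""
    widths = []
    level = [root]
    while level:
        widths.append(len(level))
        level = [child for node in level for child in node.children]
    return widths

product = Node("Binary(*)", [Node("Integer(5)"), Node("Integer(2)")])
total = Node("Binary(+)", [Node("Integer(10)"), product])
tree = Node("Assignment(x)", [total])

print(depth(tree))              # 3
print(breadth_per_level(tree))  # [1, 1, 2, 2]
print(depth(product))           # 1, the subtree rooted at "*"
```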

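Semantic properties

  • Scope: the region in which a variable is visible

  • Type information: type annotations on expressions

  • Control flow: the execution paths through the program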

6. Uses of the AST in a Compiler

Semantic analysis

python
class SemanticAnalyzer(ASTVisitor):
    def __init__(self):
        self.symbol_table = SymbolTable()
    
    def visit_assignment(self, node):
        # Check that the variable has been declared
        if not self.symbol_table.is_declared(node.variable):
            raise SemanticError(f"Undeclared variable: {node.variable}")

        # Check type compatibility
        var_type = self.symbol_table.get_type(node.variable)
        expr_type = self.visit(node.value)

        if not self.is_assignable(var_type, expr_type):
            raise TypeError("Incompatible types in assignment")
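
The SymbolTable and SemanticError classes above are not defined in this article. As a runnable stand-in, the sketch below performs a much simpler check on Python source with the built-in ast module: it reports names that are read before any assignment. It assumes straight-line code and ignores builtins, functions and nested scopes, so it is an illustration of the idea rather than a real analyzer; find_undefined_names is a hypothetical name.

```python
import ast

def find_undefined_names(source):
    """Return (line, name) pairs for names read before they are assigned."""
    defined = set()
    errors = []
    for stmt in ast.parse(source).body:
        names = [n for n in ast.walk(stmt) if isinstance(n, ast.Name)]
        # Reads are checked first: in "x = x + 1" the right side runs before the store
        for name in names:
            if isinstance(name.ctx, ast.Load) and name.id not in defined:
                errors.append((name.lineno, name.id))
        for name in names:
            if isinstance(name.ctx, ast.Store):
                defined.add(name.id)
    return errors

program = "x = 10 + 5 * 2\ny = x + z\n"
print(find_undefined_names(program))  # [(2, 'z')]
```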

Code generation

python
class CodeGenerator(ASTVisitor):
    def visit_binary_expression(self, node):
        # Emit three-address code
        left_temp = self.visit(node.left)
        right_temp = self.visit(node.right)
        result_temp = self.new_temp()
        
        self.emit(f"{result_temp} = {left_temp} {node.operator} {right_temp}")
        return result_temp
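
A self-contained variant of the same idea, written against Python's built-in ast module so that it runs as is. ThreeAddressGenerator and its temporary naming scheme (t1, t2, ...) are assumptions of this sketch, and only simple assignments with + - * / are handled.

```python
import ast

class ThreeAddressGenerator(ast.NodeVisitor):
    OPS = {ast.Add: "+", ast.Sub: "-", ast.Mult: "*", ast.Div: "/"}

    def __init__(self):
        self.lines = []
        self.temp_count = 0

    def new_temp(self):
        self.temp_count += 1
        return f"t{self.temp_count}"

    def visit_Constant(self, node):
        return repr(node.value)

    def visit_Name(self, node):
        return node.id

    def visit_BinOp(self, node):
        # Operands are generated first, so inner expressions get lower temp numbers
        left = self.visit(node.left)
        right = self.visit(node.right)
        temp = self.new_temp()
        self.lines.append(f"{temp} = {left} {self.OPS[type(node.op)]} {right}")
        return temp

    def visit_Assign(self, node):
        # Assumes a single plain-name target, e.g. "x = ..."
        value = self.visit(node.value)
        self.lines.append(f"{node.targets[0].id} = {value}")

generator = ThreeAddressGenerator()
generator.visit(ast.parse("x = 10 + 5 * 2"))
print("\n".join(generator.lines))
# t1 = 5 * 2
# t2 = 10 + t1
# x = t2
```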

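Optimization passes

python
import operator

# Look operators up in a table instead of calling eval() on a built string
_FOLDABLE = {"+": operator.add, "-": operator.sub, "*": operator.mul}

class ConstantFolding(ASTVisitor):
    def visit_binary_expression(self, node):
        left = self.visit(node.left)
        right = self.visit(node.right)

        # Constant folding: 5 * 2 → 10
        if (isinstance(left, Constant) and isinstance(right, Constant)
                and node.operator in _FOLDABLE):
            return Constant(_FOLDABLE[node.operator](left.value, right.value))

        return BinaryExpression(node.operator, left, right)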
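
The Constant class above is not defined in this article. For a version that runs unchanged, the sketch below applies the same bottom-up folding to Python's own AST with ast.NodeTransformer. ConstantFolder and is_number are names chosen for this example; division is deliberately not folded, and ast.unparse requires Python 3.9 or later.

```python
import ast
import operator

# Operators this sketch is willing to fold (division is left alone on purpose)
FOLDABLE = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}

def is_number(node):
    return (isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool))

class ConstantFolder(ast.NodeTransformer):
    def visit_BinOp(self, node):
        self.generic_visit(node)  # fold the operands first (bottom-up)
        fold = FOLDABLE.get(type(node.op))
        if fold and is_number(node.left) and is_number(node.right):
            folded = ast.Constant(value=fold(node.left.value, node.right.value), kind=None)
            return ast.copy_location(folded, node)
        return node

tree = ConstantFolder().visit(ast.parse("x = 10 + 5 * 2"))
print(ast.unparse(tree))  # x = 20
```

Expressions that involve a variable, such as y = x + 1, pass through unchanged because one operand is not a constant.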

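7. Modern AST Design Trends

Persistent ASTs

python
from dataclasses import dataclass
from typing import Tuple

# Immutable AST nodes, which makes it cheap to keep multiple versions
@dataclass(frozen=True)
class PersistentASTNode:
    node_id: str
    children: Tuple['ASTNode', ...]

    def with_updated_child(self, index, new_child):
        # Copy the child tuple, swap one entry, and return a new node
        new_children = list(self.children)
        new_children[index] = new_child
        return PersistentASTNode(self.node_id, tuple(new_children))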
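
The property that makes persistent trees cheap is structural sharing: an update creates a new parent, while untouched subtrees are reused rather than copied. A small sketch with a simplified Node class (an assumption for this example):

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class Node:
    label: str
    children: Tuple["Node", ...] = ()

    def with_child(self, index, new_child):
        # Build a new tuple; the other children are shared, not copied
        updated = self.children[:index] + (new_child,) + self.children[index + 1:]
        return replace(self, children=updated)

v1 = Node("+", (Node("10"), Node("*", (Node("5"), Node("2")))))
v2 = v1.with_child(0, Node("20"))

print(v1.children[0].label)              # 10, the old version is untouched
print(v2.children[0].label)              # 20
print(v2.children[1] is v1.children[1])  # True, the "*" subtree is shared
```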

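Attribute grammar integration

python
class AttributedAST:
    def __init__(self, node):
        self.node = node
        self.attributes = {}  # synthesized and inherited attributes

    def get_attribute(self, name):
        # Compute on first request, then serve from the cache
        if name not in self.attributes:
            self.attributes[name] = self.compute_attribute(name)
        return self.attributes[name]

    def compute_attribute(self, name):
        # Compute the attribute value from its dependencies
        pass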
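
To make the two kinds of attributes concrete, here is a minimal sketch over a nested-tuple AST (a representation assumed for this example). The environment env plays the role of an inherited attribute handed down the tree, and the returned type is a synthesized attribute computed from the children:

```python
def expression_type(node, env):
    """Return the synthesized type of node; env is the inherited attribute."""
    kind = node[0]
    if kind == "int":
        return "int"
    if kind == "str":
        return "str"
    if kind == "var":
        # Inherited information flows down: the declared type comes from env
        return env.get(node[1], "error")
    if kind == "binary":
        _, op, left, right = node
        left_type = expression_type(left, env)
        right_type = expression_type(right, env)
        # Synthesized information flows up: combine the children's types
        if left_type == right_type == "int":
            return "int"
        if left_type == right_type == "str" and op == "+":
            return "str"
        return "error"
    raise ValueError(f"unknown node kind: {kind}")

env = {"x": "int", "name": "str"}
ok = expression_type(("binary", "+", ("var", "x"), ("int", 1)), env)
bad = expression_type(("binary", "*", ("var", "name"), ("str", "!")), env)
print(ok, bad)  # int error
```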

8. AST Serialization and Persistence

JSON serialization

json
{
  "type": "BinaryExpression",
  "operator": "+",
  "left": {
    "type": "IntegerLiteral",
    "value": 10
  },
  "right": {
    "type": "BinaryExpression", 
    "operator": "*",
    "left": {"type": "IntegerLiteral", "value": 5},
    "right": {"type": "IntegerLiteral", "value": 2}
  }
}
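
A minimal sketch of producing and reading back the JSON shape above with the standard json module. The two dataclasses and the to_dict / from_dict helpers are assumptions made for this example and cover only the two node types that appear in the JSON:

```python
import json
from dataclasses import dataclass
from typing import Union

@dataclass
class IntegerLiteral:
    value: int

@dataclass
class BinaryExpression:
    operator: str
    left: "Expr"
    right: "Expr"

Expr = Union[IntegerLiteral, BinaryExpression]

def to_dict(node):
    """Convert a node into plain dicts, tagging each with its type name."""
    if isinstance(node, IntegerLiteral):
        return {"type": "IntegerLiteral", "value": node.value}
    return {
        "type": "BinaryExpression",
        "operator": node.operator,
        "left": to_dict(node.left),
        "right": to_dict(node.right),
    }

def from_dict(data):
    """Rebuild nodes by dispatching on the "type" tag."""
    if data["type"] == "IntegerLiteral":
        return IntegerLiteral(data["value"])
    if data["type"] == "BinaryExpression":
        return BinaryExpression(data["operator"], from_dict(data["left"]), from_dict(data["right"]))
    raise ValueError(f"unknown node type: {data['type']}")

tree = BinaryExpression("+", IntegerLiteral(10),
                        BinaryExpression("*", IntegerLiteral(5), IntegerLiteral(2)))
text = json.dumps(to_dict(tree), indent=2)
restored = from_dict(json.loads(text))
print(restored == tree)  # True
```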

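9. Tools and Libraries

Industrial-strength AST libraries

  • ANTLR: a powerful parser generator

  • Tree-sitter: incremental parsing with support for many languages

  • Eclipse JDT: AST toolkit for Java

  • Clang AST: the C/C++ front end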

10. Extended Applications of ASTs

Static analysis

python
class DataFlowAnalyzer:
    def analyze_ast(self, ast):
        # Def-use chain analysis
        # Liveness analysis
        # Constant propagation
        pass
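
Of the three analyses named above, def-use chains are the easiest to sketch. The example below works on Python source through the built-in ast module and assumes straight-line code with no branches, loops or functions; def_use_chains is a hypothetical helper, and each definition is keyed by (name, line).

```python
import ast
from collections import defaultdict

def def_use_chains(source):
    """Map each definition (name, line) to the lines that read that definition."""
    last_def = {}               # name -> line of its most recent assignment
    chains = defaultdict(list)  # (name, def_line) -> lines that use it
    for stmt in ast.parse(source).body:
        names = [n for n in ast.walk(stmt) if isinstance(n, ast.Name)]
        # Uses refer to the definition that was live before this statement
        for name in names:
            if isinstance(name.ctx, ast.Load) and name.id in last_def:
                chains[(name.id, last_def[name.id])].append(name.lineno)
        # Then the statement's own assignments become the live definitions
        for name in names:
            if isinstance(name.ctx, ast.Store):
                last_def[name.id] = name.lineno
                chains.setdefault((name.id, name.lineno), [])
    return dict(chains)

program = "x = 10\ny = x + 1\nx = y * 2\nz = x + y\n"
for definition, uses in def_use_chains(program).items():
    print(definition, uses)
```

A definition with an empty use list, such as z on line 4 here, is a candidate dead store.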

Refactoring tools

python
class RefactoringEngine:
    def extract_method(self, ast, start_line, end_line):
        # Identify the code fragment
        # Create the new method
        # Replace the original code with a call to the new method
        pass
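
Extract Method is too large for a short example, so the sketch below shows a simpler refactoring built on the same rewrite-the-tree idea: renaming a variable. It uses ast.NodeTransformer from the standard library, ignores scoping entirely (every matching name is renamed), and relies on ast.unparse, which requires Python 3.9 or later. RenameVariable is a name chosen for this example.

```python
import ast

class RenameVariable(ast.NodeTransformer):
    def __init__(self, old, new):
        self.old = old
        self.new = new

    def visit_Name(self, node):
        # Replace matching names, keeping the load/store context and position
        if node.id == self.old:
            return ast.copy_location(ast.Name(id=self.new, ctx=node.ctx), node)
        return node

tree = RenameVariable("x", "total").visit(ast.parse("x = 10 + 5 * 2\ny = x + 1"))
renamed = ast.unparse(tree)
print(renamed)
# total = 10 + 5 * 2
# y = total + 1
```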

Language Server Protocol (LSP)

python
class LanguageServer:
    def provide_completions(self, ast, position):
        # AST-based intelligent code completion
        # Type inference
        # Scope analysis
        pass
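
A very rough sketch of the scope-analysis part: suggest variables that were assigned before the cursor line and match the typed prefix. The complete function is hypothetical, uses Python's built-in ast module, and treats the file as one flat scope. A real language server would also need an error-tolerant parser, since code being edited is often syntactically incomplete.

```python
import ast

def complete(source, line, prefix):
    """Return sorted names assigned before `line` that start with `prefix`."""
    candidates = set()
    for node in ast.walk(ast.parse(source)):
        is_store = isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)
        if is_store and node.lineno < line:
            candidates.add(node.id)
    return sorted(name for name in candidates if name.startswith(prefix))

program = "total = 1\ntally = 2\nother = 3\nresult = 0\n"
print(complete(program, 4, "t"))  # ['tally', 'total']
print(complete(program, 2, "t"))  # ['total']
```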

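Summary

The AST is the core data structure of modern compilers, static analysis tools, and IDEs. Its value lies in:

  1. Semantic accuracy: it expresses the meaning of a program precisely

  2. Extensibility: it supports many kinds of analysis and transformation

  3. Tool friendliness: it makes development toolchains easier to build

  4. Platform independence: it does not depend on any particular target machine

A complete demo implementation follows.

python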
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional

class ASTNode(ABC):
    @abstractmethod
    def accept(self, visitor: 'ASTVisitor') -> Any:
        pass
    
    @abstractmethod
    def __str__(self) -> str:
        pass

class Expression(ASTNode):
    pass

class Statement(ASTNode):
    pass

class Literal(Expression):
    def __init__(self, value: Any):
        self.value = value
    
    def get_value(self) -> Any:
        return self.value

class IntegerLiteral(Literal):
    def __init__(self, value: int):
        super().__init__(value)
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_integer_literal(self)
    
    def __str__(self) -> str:
        return f"Integer({self.value})"

class StringLiteral(Literal):
    def __init__(self, value: str):
        super().__init__(value)
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_string_literal(self)
    
    def __str__(self) -> str:
        return f'String("{self.value}")'

class BinaryExpression(Expression):
    def __init__(self, operator: str, left: Expression, right: Expression):
        self.operator = operator
        self.left = left
        self.right = right
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_binary_expression(self)
    
    def __str__(self) -> str:
        return f"Binary({self.operator})"
    
    def get_operator(self) -> str:
        return self.operator
    
    def get_left(self) -> Expression:
        return self.left
    
    def get_right(self) -> Expression:
        return self.right

class Variable(Expression):
    def __init__(self, name: str):
        self.name = name
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_variable(self)
    
    def __str__(self) -> str:
        return f"Variable({self.name})"
    
    def get_name(self) -> str:
        return self.name

class Assignment(Statement):
    def __init__(self, variable: str, value: Expression):
        self.variable = variable
        self.value = value
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_assignment(self)
    
    def __str__(self) -> str:
        return f"Assignment({self.variable})"
    
    def get_variable(self) -> str:
        return self.variable
    
    def get_value(self) -> Expression:
        return self.value

class PrintStatement(Statement):
    def __init__(self, expression: Expression):
        self.expression = expression
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_print_statement(self)
    
    def __str__(self) -> str:
        return "PrintStatement"
    
    def get_expression(self) -> Expression:
        return self.expression

class Block(Statement):
    def __init__(self):
        self.statements: List[Statement] = []
    
    def add_statement(self, statement: Statement):
        self.statements.append(statement)
    
    def accept(self, visitor: 'ASTVisitor') -> Any:
        return visitor.visit_block(self)
    
    def __str__(self) -> str:
        return f"Block({len(self.statements)} statements)"
    
    def get_statements(self) -> List[Statement]:
        return self.statements

class ASTVisitor(ABC):
    @abstractmethod
    def visit_integer_literal(self, node: IntegerLiteral) -> Any:
        pass
    
    @abstractmethod
    def visit_string_literal(self, node: StringLiteral) -> Any:
        pass
    
    @abstractmethod
    def visit_binary_expression(self, node: BinaryExpression) -> Any:
        pass
    
    @abstractmethod
    def visit_variable(self, node: Variable) -> Any:
        pass
    
    @abstractmethod
    def visit_assignment(self, node: Assignment) -> Any:
        pass
    
    @abstractmethod
    def visit_print_statement(self, node: PrintStatement) -> Any:
        pass
    
    @abstractmethod
    def visit_block(self, node: Block) -> Any:
        pass



class ASTPrinter(ASTVisitor):
    def __init__(self):
        self.indent_level = 0
    
    def _print_indent(self):
        print("  " * self.indent_level, end="")
    
    def visit_integer_literal(self, node):
        self._print_indent()
        print(str(node))
    
    def visit_string_literal(self, node):
        self._print_indent()
        print(str(node))
    
    def visit_binary_expression(self, node):
        self._print_indent()
        print(str(node))
        
        self.indent_level += 1
        node.get_left().accept(self)
        node.get_right().accept(self)
        self.indent_level -= 1
    
    def visit_variable(self, node):
        self._print_indent()
        print(str(node))
    
    def visit_assignment(self, node):
        self._print_indent()
        print(str(node))
        
        self.indent_level += 1
        node.get_value().accept(self)
        self.indent_level -= 1
    
    def visit_print_statement(self, node):
        self._print_indent()
        print(str(node))
        
        self.indent_level += 1
        node.get_expression().accept(self)
        self.indent_level -= 1
    
    def visit_block(self, node):
        self._print_indent()
        print(str(node))
        
        self.indent_level += 1
        for stmt in node.get_statements():
            stmt.accept(self)
        self.indent_level -= 1




class Evaluator(ASTVisitor):
    def __init__(self):
        self.variables: Dict[str, Any] = {}
        self.result: Any = None
    
    def get_result(self) -> Any:
        return self.result
    
    def visit_integer_literal(self, node):
        self.result = node.get_value()
    
    def visit_string_literal(self, node):
        self.result = node.get_value()
    
    def visit_binary_expression(self, node):
        node.get_left().accept(self)
        left_val = self.result
        
        node.get_right().accept(self)
        right_val = self.result
        
        operator = node.get_operator()
        if operator == "+":
            self.result = left_val + right_val
        elif operator == "-":
            self.result = left_val - right_val
        elif operator == "*":
            self.result = left_val * right_val
        elif operator == "/":
            self.result = left_val / right_val
        elif operator == "==":
            self.result = left_val == right_val
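        elif operator == "!=":
            self.result = left_val != right_val
        else:
            # Fail loudly instead of silently reusing the right operand's value
            raise ValueError(f"Unsupported operator: {operator}")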
    
    def visit_variable(self, node):
        self.result = self.variables.get(node.get_name(), 0)
    
    def visit_assignment(self, node):
        node.get_value().accept(self)
        self.variables[node.get_variable()] = self.result
    
    def visit_print_statement(self, node):
        node.get_expression().accept(self)
        print(f"Print: {self.result}")
    
    def visit_block(self, node):
        for stmt in node.get_statements():
            stmt.accept(self)        


def main():
    # Build the AST for: x = 10 + 5 * 2
    ast = Block()
    
    # Create the expression: 5 * 2
    five = IntegerLiteral(5)
    two = IntegerLiteral(2)
    multiplication = BinaryExpression("*", five, two)
    
    # Create the expression: 10 + (5 * 2)
    ten = IntegerLiteral(10)
    addition = BinaryExpression("+", ten, multiplication)
    
    # Create the assignment statement: x = 10 + 5 * 2
    assignment = Assignment("x", addition)
    ast.add_statement(assignment)
    
    # Create the print statement: print x
    var_x = Variable("x")
    print_stmt = PrintStatement(var_x)
    ast.add_statement(print_stmt)
    
    # More examples
    # y = x + 1
    one = IntegerLiteral(1)
    x_plus_one = BinaryExpression("+", Variable("x"), one)
    assignment_y = Assignment("y", x_plus_one)
    ast.add_statement(assignment_y)
    
    # print y
    print_y = PrintStatement(Variable("y"))
    ast.add_statement(print_y)
    
    # Print the AST
    print("=== AST Structure ===")
    printer = ASTPrinter()
    ast.accept(printer)
    
    # Evaluate the program
    print("\n=== Evaluation ===")
    evaluator = Evaluator()
    ast.accept(evaluator)

if __name__ == "__main__":
    main()

Output:

text
=== AST Structure ===
Block(4 statements)
  Assignment(x)
    Binary(+)
      Integer(10)
      Binary(*)
        Integer(5)
        Integer(2)
  PrintStatement
    Variable(x)
  Assignment(y)
    Binary(+)
      Variable(x)
      Integer(1)
  PrintStatement
    Variable(y)

=== Evaluation ===
Print: 20
Print: 21