读go语言自制解释器(一)生成ast

简介

本节主要介绍使用 Go 语言解析自定义语言 Monkey、生成 AST(抽象语法树)的过程,整体分为两部分:词法分析和语法分析。

词法分析代码如下

package lexer

import "monkey/token"

// Lexer turns a Monkey source string into a stream of tokens.
// It keeps two cursors into the input: position points at the character
// currently held in ch, readPosition points at the next character to read.
type Lexer struct {
    input        string
    position     int  // current position in input (points to current char)
    readPosition int  // current reading position in input (after current char)
    ch           byte // current char under examination
}

// New creates a Lexer for input and primes it by loading the first
// character into l.ch, so NextToken can be called immediately.
func New(input string) *Lexer {
    l := &Lexer{input: input}
    l.readChar()
    return l
}

// NextToken scans and returns the next token from the input, advancing
// the lexer past it. At the end of the input it returns an EOF token.
func (l *Lexer) NextToken() token.Token {
    var tok token.Token

    l.skipWhitespace()

    switch l.ch {
    case '=':
       // Could be "=" or "==": peek one character ahead to decide.
       if l.peekChar() == '=' {
          ch := l.ch
          l.readChar()
          literal := string(ch) + string(l.ch)
          tok = token.Token{Type: token.EQ, Literal: literal}
       } else {
          tok = newToken(token.ASSIGN, l.ch)
       }
    case '+':
       tok = newToken(token.PLUS, l.ch)
    case '-':
       tok = newToken(token.MINUS, l.ch)
    case '!':
       // Could be "!" or "!=": same two-character lookahead as '='.
       if l.peekChar() == '=' {
          ch := l.ch
          l.readChar()
          literal := string(ch) + string(l.ch)
          tok = token.Token{Type: token.NOT_EQ, Literal: literal}
       } else {
          tok = newToken(token.BANG, l.ch)
       }
    case '/':
       tok = newToken(token.SLASH, l.ch)
    case '*':
       tok = newToken(token.ASTERISK, l.ch)
    case '<':
       tok = newToken(token.LT, l.ch)
    case '>':
       tok = newToken(token.GT, l.ch)
    case ';':
       tok = newToken(token.SEMICOLON, l.ch)
    case ',':
       tok = newToken(token.COMMA, l.ch)
    case '{':
       tok = newToken(token.LBRACE, l.ch)
    case '}':
       tok = newToken(token.RBRACE, l.ch)
    case '(':
       tok = newToken(token.LPAREN, l.ch)
    case ')':
       tok = newToken(token.RPAREN, l.ch)
    case 0:
       // 0 is the sentinel readChar stores past the end of the input.
       tok.Literal = ""
       tok.Type = token.EOF
    default:
       if isLetter(l.ch) {
          // Identifier or keyword. readIdentifier already advanced past
          // the last letter, so return here to skip the readChar below.
          tok.Literal = l.readIdentifier()
          tok.Type = token.LookupIdent(tok.Literal)
          return tok
       } else if isDigit(l.ch) {
          // Same early return as above: readNumber leaves the cursor
          // on the first non-digit character.
          tok.Type = token.INT
          tok.Literal = l.readNumber()
          return tok
       } else {
          tok = newToken(token.ILLEGAL, l.ch)
       }
    }

    l.readChar()
    return tok
}

// skipWhitespace advances past spaces, tabs, newlines and carriage
// returns so the next token starts on a significant character.
func (l *Lexer) skipWhitespace() {
    for {
        switch l.ch {
        case ' ', '\t', '\n', '\r':
            l.readChar()
        default:
            return
        }
    }
}

// readChar loads the next character of the input into l.ch and moves
// both cursors forward. Past the end of the input it stores 0 (NUL),
// which NextToken interprets as EOF.
func (l *Lexer) readChar() {
    l.ch = 0
    if l.readPosition < len(l.input) {
        l.ch = l.input[l.readPosition]
    }
    l.position = l.readPosition
    l.readPosition++
}

// peekChar returns the character after the current one without
// advancing the lexer; 0 signals end of input.
func (l *Lexer) peekChar() byte {
    if l.readPosition >= len(l.input) {
        return 0
    }
    return l.input[l.readPosition]
}

// readIdentifier consumes a maximal run of letter characters starting
// at the current position and returns it as a slice of the input.
func (l *Lexer) readIdentifier() string {
    start := l.position
    for isLetter(l.ch) {
        l.readChar()
    }
    return l.input[start:l.position]
}

// readNumber consumes a maximal run of decimal digits starting at the
// current position and returns it as a slice of the input.
func (l *Lexer) readNumber() string {
    start := l.position
    for isDigit(l.ch) {
        l.readChar()
    }
    return l.input[start:l.position]
}

// isLetter reports whether ch may appear in an identifier: ASCII
// letters plus underscore ('_' lets identifiers like foo_bar lex).
func isLetter(ch byte) bool {
    switch {
    case ch >= 'a' && ch <= 'z':
        return true
    case ch >= 'A' && ch <= 'Z':
        return true
    default:
        return ch == '_'
    }
}

// isDigit reports whether ch is an ASCII decimal digit.
func isDigit(ch byte) bool {
    return ch >= '0' && ch <= '9'
}

// newToken builds a single-character token of the given type.
func newToken(tokenType token.TokenType, ch byte) token.Token {
    literal := string(ch)
    return token.Token{Type: tokenType, Literal: literal}
}

token定义如下

package token

// TokenType identifies the kind of a token. A string (rather than an
// int) costs a little performance but makes debugging output readable.
type TokenType string

// Every token type the lexer can produce.
const (
    ILLEGAL = "ILLEGAL" // a character the lexer does not recognize
    EOF     = "EOF"     // end of input

    // Identifiers + literals
    IDENT = "IDENT" // add, foobar, x, y, ...
    INT   = "INT"   // 1343456

    // Operators
    ASSIGN   = "="
    PLUS     = "+"
    MINUS    = "-"
    BANG     = "!"
    ASTERISK = "*"
    SLASH    = "/"

    LT = "<"
    GT = ">"

    EQ     = "=="
    NOT_EQ = "!="

    // Delimiters
    COMMA     = ","
    SEMICOLON = ";"

    LPAREN = "("
    RPAREN = ")"
    LBRACE = "{"
    RBRACE = "}"

    // Keywords
    FUNCTION = "FUNCTION"
    LET      = "LET"
    TRUE     = "TRUE"
    FALSE    = "FALSE"
    IF       = "IF"
    ELSE     = "ELSE"
    RETURN   = "RETURN"
)

// Token is a single lexical unit: its type plus the exact source text
// (Literal) it was built from.
type Token struct {
    Type    TokenType
    Literal string
}

// keywords maps reserved identifier spellings to their keyword token
// types; any spelling not listed here lexes as a plain IDENT.
var keywords = map[string]TokenType{
    "fn":     FUNCTION,
    "let":    LET,
    "true":   TRUE,
    "false":  FALSE,
    "if":     IF,
    "else":   ELSE,
    "return": RETURN,
}

// LookupIdent returns the keyword token type for ident, or IDENT when
// the spelling is not reserved.
func LookupIdent(ident string) TokenType {
    tokType, ok := keywords[ident]
    if !ok {
        return IDENT
    }
    return tokType
}

词法分析的主要功能是把从源码中读入的字符流切分、拼接成一个个具有特定含义的 token,供后续的语法分析阶段使用,便于生成 AST。

语法分析如下

package parser

import (
    "fmt"
    "monkey/ast"
    "monkey/lexer"
    "monkey/token"
    "strconv"
)

// Operator precedence levels, lowest to highest. The blank first value
// starts the scale at 1, so the int zero value is never a valid
// precedence.
const (
    _ int = iota
    LOWEST
    EQUALS      // ==
    LESSGREATER // > or <
    SUM         // +
    PRODUCT     // *
    PREFIX      // -X or !X
    CALL        // myFunction(X)
)

// precedences maps infix operator token types to their binding power.
// LPAREN is highest so a call like f(x) binds tighter than any operator.
var precedences = map[token.TokenType]int{
    token.EQ:       EQUALS,
    token.NOT_EQ:   EQUALS,
    token.LT:       LESSGREATER,
    token.GT:       LESSGREATER,
    token.PLUS:     SUM,
    token.MINUS:    SUM,
    token.SLASH:    PRODUCT,
    token.ASTERISK: PRODUCT,
    token.LPAREN:   CALL,
}

// Pratt-parser callback types: a prefix function parses a token that
// begins an expression; an infix function receives the already-parsed
// left operand.
type (
    prefixParseFn func() ast.Expression
    infixParseFn  func(ast.Expression) ast.Expression
)

// Parser consumes tokens from a Lexer and builds the AST. It keeps a
// two-token lookahead window (curToken/peekToken) and accumulates
// error messages instead of aborting on the first failure.
type Parser struct {
    l      *lexer.Lexer
    errors []string

    curToken  token.Token
    peekToken token.Token

    prefixParseFns map[token.TokenType]prefixParseFn
    infixParseFns  map[token.TokenType]infixParseFn
}

// New constructs a Parser over l, wires up the Pratt parse functions
// for every token type the language supports, and pre-loads the
// curToken/peekToken window.
func New(l *lexer.Lexer) *Parser {
    p := &Parser{l: l, errors: []string{}}

    // Prefix parse functions: invoked when the token type can start an
    // expression.
    p.prefixParseFns = make(map[token.TokenType]prefixParseFn)
    for tokenType, fn := range map[token.TokenType]prefixParseFn{
        token.IDENT:    p.parseIdentifier,
        token.INT:      p.parseIntegerLiteral,
        token.BANG:     p.parsePrefixExpression,
        token.MINUS:    p.parsePrefixExpression,
        token.TRUE:     p.parseBoolean,
        token.FALSE:    p.parseBoolean,
        token.LPAREN:   p.parseGroupedExpression,
        token.IF:       p.parseIfExpression,
        token.FUNCTION: p.parseFunctionLiteral,
    } {
        p.registerPrefix(tokenType, fn)
    }

    // Infix parse functions: invoked when the token appears after a
    // left operand. All binary operators share one implementation.
    p.infixParseFns = make(map[token.TokenType]infixParseFn)
    for _, tokenType := range []token.TokenType{
        token.PLUS, token.MINUS, token.SLASH, token.ASTERISK,
        token.EQ, token.NOT_EQ, token.LT, token.GT,
    } {
        p.registerInfix(tokenType, p.parseInfixExpression)
    }
    // '(' after an expression denotes a function call.
    p.registerInfix(token.LPAREN, p.parseCallExpression)

    // Read two tokens, so curToken and peekToken are both set
    p.nextToken()
    p.nextToken()

    return p
}

// nextToken slides the two-token window forward: peekToken becomes
// curToken and a fresh token is pulled from the lexer.
func (p *Parser) nextToken() {
    p.curToken = p.peekToken
    p.peekToken = p.l.NextToken()
}

// curTokenIs reports whether the current token has type t.
func (p *Parser) curTokenIs(t token.TokenType) bool {
    return p.curToken.Type == t
}

// peekTokenIs reports whether the upcoming token has type t.
func (p *Parser) peekTokenIs(t token.TokenType) bool {
    return p.peekToken.Type == t
}

// expectPeek advances to the next token when it has type t; otherwise
// it records a peek error and leaves the parser where it is. The
// boolean result tells the caller whether the expectation held.
func (p *Parser) expectPeek(t token.TokenType) bool {
    if !p.peekTokenIs(t) {
        p.peekError(t)
        return false
    }
    p.nextToken()
    return true
}

// Errors exposes the messages accumulated while parsing; callers check
// this after ParseProgram instead of handling an error return value.
func (p *Parser) Errors() []string {
    return p.errors
}

// peekError records that the upcoming token did not have the expected
// type t.
func (p *Parser) peekError(t token.TokenType) {
    msg := fmt.Sprintf("expected next token to be %s, got %s instead",
       t, p.peekToken.Type)
    p.errors = append(p.errors, msg)
}

// noPrefixParseFnError records that no expression can start with the
// token type t.
func (p *Parser) noPrefixParseFnError(t token.TokenType) {
    msg := fmt.Sprintf("no prefix parse function for %s found", t)
    p.errors = append(p.errors, msg)
}

// ParseProgram parses the whole token stream into an *ast.Program,
// one statement per loop iteration until EOF. Statements that fail to
// parse come back nil and are skipped; their errors live in p.errors.
func (p *Parser) ParseProgram() *ast.Program {
    program := &ast.Program{}
    program.Statements = []ast.Statement{}

    for !p.curTokenIs(token.EOF) {
       stmt := p.parseStatement()
       if stmt != nil {
          program.Statements = append(program.Statements, stmt)
       }
       p.nextToken()
    }

    return program
}

// parseStatement dispatches on the current token type: only "let" and
// "return" begin dedicated statements; everything else is parsed as an
// expression statement.
func (p *Parser) parseStatement() ast.Statement {
    switch p.curToken.Type {
    case token.LET:
       return p.parseLetStatement()
    case token.RETURN:
       return p.parseReturnStatement()
    default:
       return p.parseExpressionStatement()
    }
}

// parseLetStatement parses "let <ident> = <expression>;" with the
// cursor on the LET token. It returns nil (after recording an error
// via expectPeek) when the identifier or '=' is missing.
func (p *Parser) parseLetStatement() *ast.LetStatement {
    stmt := &ast.LetStatement{Token: p.curToken}

    if !p.expectPeek(token.IDENT) {
       return nil
    }

    stmt.Name = &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal}

    if !p.expectPeek(token.ASSIGN) {
       return nil
    }

    // Step past '=' onto the first token of the value expression.
    p.nextToken()

    stmt.Value = p.parseExpression(LOWEST)

    // The trailing semicolon is optional.
    if p.peekTokenIs(token.SEMICOLON) {
       p.nextToken()
    }

    return stmt
}

// parseReturnStatement parses "return <expression>;" with the cursor
// on the RETURN token.
func (p *Parser) parseReturnStatement() *ast.ReturnStatement {
    stmt := &ast.ReturnStatement{Token: p.curToken}

    p.nextToken()

    stmt.ReturnValue = p.parseExpression(LOWEST)

    if p.peekTokenIs(token.SEMICOLON) {
       p.nextToken()
    }

    return stmt
}

// parseExpressionStatement wraps a bare expression in a statement node
// so expressions can appear at statement position.
func (p *Parser) parseExpressionStatement() *ast.ExpressionStatement {
    stmt := &ast.ExpressionStatement{Token: p.curToken}

    stmt.Expression = p.parseExpression(LOWEST)

    if p.peekTokenIs(token.SEMICOLON) {
       p.nextToken()
    }

    return stmt
}

// parseExpression is the core of the Pratt parser. It first parses a
// prefix expression for the current token, then repeatedly lets infix
// parse functions extend it to the right for as long as the upcoming
// operator binds tighter than the caller's precedence argument.
func (p *Parser) parseExpression(precedence int) ast.Expression {
    prefix := p.prefixParseFns[p.curToken.Type]
    if prefix == nil {
       p.noPrefixParseFnError(p.curToken.Type)
       return nil
    }
    leftExp := prefix()

    for !p.peekTokenIs(token.SEMICOLON) && precedence < p.peekPrecedence() {
       infix := p.infixParseFns[p.peekToken.Type]
       if infix == nil {
          // The next token is not an operator we know; leftExp is done.
          return leftExp
       }

       p.nextToken()

       leftExp = infix(leftExp)
    }

    return leftExp
}

// peekPrecedence returns the binding power of the upcoming token, or
// LOWEST when that token is not a known infix operator.
func (p *Parser) peekPrecedence() int {
    // The original shadowed the receiver p with the map value, which
    // shadow linters flag and readers stumble over; renamed to prec.
    if prec, ok := precedences[p.peekToken.Type]; ok {
        return prec
    }

    return LOWEST
}

// curPrecedence returns the binding power of the current token, or
// LOWEST when it is not a known infix operator.
func (p *Parser) curPrecedence() int {
    if prec, ok := precedences[p.curToken.Type]; ok {
        return prec
    }

    return LOWEST
}

// parseIdentifier turns the current IDENT token into an AST node.
func (p *Parser) parseIdentifier() ast.Expression {
    return &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal}
}

// parseIntegerLiteral converts the current INT token into an
// *ast.IntegerLiteral, recording a parser error and returning nil when
// the literal does not fit in an int64.
func (p *Parser) parseIntegerLiteral() ast.Expression {
    value, err := strconv.ParseInt(p.curToken.Literal, 0, 64)
    if err != nil {
        p.errors = append(p.errors,
            fmt.Sprintf("could not parse %q as integer", p.curToken.Literal))
        return nil
    }
    return &ast.IntegerLiteral{Token: p.curToken, Value: value}
}

// parsePrefixExpression parses "!X" / "-X": it captures the operator,
// advances, and parses the operand at PREFIX precedence so prefix
// operators bind tighter than any infix operator.
func (p *Parser) parsePrefixExpression() ast.Expression {
    expression := &ast.PrefixExpression{
       Token:    p.curToken,
       Operator: p.curToken.Literal,
    }

    p.nextToken()

    expression.Right = p.parseExpression(PREFIX)

    return expression
}

// parseInfixExpression is called with left already parsed and the
// cursor on the operator token. Parsing the right side at the
// operator's own precedence makes equal-precedence operators
// left-associative.
func (p *Parser) parseInfixExpression(left ast.Expression) ast.Expression {
    expression := &ast.InfixExpression{
       Token:    p.curToken,
       Operator: p.curToken.Literal,
       Left:     left,
    }

    precedence := p.curPrecedence()
    p.nextToken()
    expression.Right = p.parseExpression(precedence)

    return expression
}

// parseBoolean turns the current TRUE/FALSE token into an AST node.
func (p *Parser) parseBoolean() ast.Expression {
    return &ast.Boolean{Token: p.curToken, Value: p.curTokenIs(token.TRUE)}
}

// parseGroupedExpression parses "(<expression>)". The parentheses only
// reset the precedence context; no extra AST node is produced.
func (p *Parser) parseGroupedExpression() ast.Expression {
    p.nextToken()

    exp := p.parseExpression(LOWEST)

    if !p.expectPeek(token.RPAREN) {
       return nil
    }

    return exp
}

// parseIfExpression parses "if (<cond>) { ... }" with an optional
// "else { ... }". Note if/else is an expression here, not a statement.
func (p *Parser) parseIfExpression() ast.Expression {
    expression := &ast.IfExpression{Token: p.curToken}

    if !p.expectPeek(token.LPAREN) {
       return nil
    }

    p.nextToken()
    expression.Condition = p.parseExpression(LOWEST)

    if !p.expectPeek(token.RPAREN) {
       return nil
    }

    if !p.expectPeek(token.LBRACE) {
       return nil
    }

    expression.Consequence = p.parseBlockStatement()

    // Optional else branch; Alternative stays nil when it is absent.
    if p.peekTokenIs(token.ELSE) {
       p.nextToken()

       if !p.expectPeek(token.LBRACE) {
          return nil
       }

       expression.Alternative = p.parseBlockStatement()
    }

    return expression
}

// parseBlockStatement parses statements between '{' and '}' with the
// cursor on the '{'. The EOF check prevents an infinite loop when the
// closing brace is missing.
func (p *Parser) parseBlockStatement() *ast.BlockStatement {
    block := &ast.BlockStatement{Token: p.curToken}
    block.Statements = []ast.Statement{}

    p.nextToken()

    for !p.curTokenIs(token.RBRACE) && !p.curTokenIs(token.EOF) {
       stmt := p.parseStatement()
       if stmt != nil {
          block.Statements = append(block.Statements, stmt)
       }
       p.nextToken()
    }

    return block
}

// parseFunctionLiteral parses "fn(<params>) { <body> }" with the
// cursor on the FN token.
func (p *Parser) parseFunctionLiteral() ast.Expression {
    lit := &ast.FunctionLiteral{Token: p.curToken}

    if !p.expectPeek(token.LPAREN) {
       return nil
    }

    lit.Parameters = p.parseFunctionParameters()

    if !p.expectPeek(token.LBRACE) {
       return nil
    }

    lit.Body = p.parseBlockStatement()

    return lit
}

// parseFunctionParameters parses the comma-separated identifier list
// of a function literal, returning an empty slice for "fn()". The
// double nextToken inside the loop skips the comma and lands on the
// next identifier.
func (p *Parser) parseFunctionParameters() []*ast.Identifier {
    identifiers := []*ast.Identifier{}

    if p.peekTokenIs(token.RPAREN) {
       p.nextToken()
       return identifiers
    }

    p.nextToken()

    ident := &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal}
    identifiers = append(identifiers, ident)

    for p.peekTokenIs(token.COMMA) {
       p.nextToken()
       p.nextToken()
       ident := &ast.Identifier{Token: p.curToken, Value: p.curToken.Literal}
       identifiers = append(identifiers, ident)
    }

    if !p.expectPeek(token.RPAREN) {
       return nil
    }

    return identifiers
}

// parseCallExpression is registered as the infix function for '(':
// the already-parsed callee (identifier or fn literal) arrives as the
// left operand.
func (p *Parser) parseCallExpression(function ast.Expression) ast.Expression {
    exp := &ast.CallExpression{Token: p.curToken, Function: function}
    exp.Arguments = p.parseCallArguments()
    return exp
}

// parseCallArguments parses a comma-separated argument list up to the
// closing ')', returning an empty slice for "f()".
func (p *Parser) parseCallArguments() []ast.Expression {
    args := []ast.Expression{}

    if p.peekTokenIs(token.RPAREN) {
       p.nextToken()
       return args
    }

    p.nextToken()
    args = append(args, p.parseExpression(LOWEST))

    for p.peekTokenIs(token.COMMA) {
       // Skip the comma, then land on the next argument's first token.
       p.nextToken()
       p.nextToken()
       args = append(args, p.parseExpression(LOWEST))
    }

    if !p.expectPeek(token.RPAREN) {
       return nil
    }

    return args
}

// registerPrefix associates a prefix parse function with a token type.
func (p *Parser) registerPrefix(tokenType token.TokenType, fn prefixParseFn) {
    p.prefixParseFns[tokenType] = fn
}

// registerInfix associates an infix parse function with a token type.
func (p *Parser) registerInfix(tokenType token.TokenType, fn infixParseFn) {
    p.infixParseFns[tokenType] = fn
}

语法分析采用递归下降(结合 Pratt 算符优先)的解析方式,根据 token 流生成 AST,供后续的语义分析阶段使用。

测试用例如下

package parser

import (
    "fmt"
    "monkey/ast"
    "monkey/lexer"
    "testing"
)

// TestLetStatements table-tests "let <ident> = <value>;" over integer,
// boolean and identifier values, checking both the bound name and the
// parsed value.
func TestLetStatements(t *testing.T) {
    tests := []struct {
       input              string
       expectedIdentifier string
       expectedValue      interface{}
    }{
       {"let x = 5;", "x", 5},
       {"let y = true;", "y", true},
       {"let foobar = y;", "foobar", "y"},
    }

    for _, tt := range tests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       if len(program.Statements) != 1 {
          t.Fatalf("program.Statements does not contain 1 statements. got=%d",
             len(program.Statements))
       }

       stmt := program.Statements[0]
       if !testLetStatement(t, stmt, tt.expectedIdentifier) {
          return
       }

       val := stmt.(*ast.LetStatement).Value
       if !testLiteralExpression(t, val, tt.expectedValue) {
          return
       }
    }
}

// TestReturnStatements table-tests "return <value>;" over integer,
// boolean and identifier values, checking the "return" token literal
// and the parsed return value.
func TestReturnStatements(t *testing.T) {
    tests := []struct {
       input         string
       expectedValue interface{}
    }{
       {"return 5;", 5},
       {"return true;", true},
       {"return foobar;", "foobar"},
    }

    for _, tt := range tests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       if len(program.Statements) != 1 {
          t.Fatalf("program.Statements does not contain 1 statements. got=%d",
             len(program.Statements))
       }

       stmt := program.Statements[0]
       returnStmt, ok := stmt.(*ast.ReturnStatement)
       if !ok {
          t.Fatalf("stmt not *ast.ReturnStatement. got=%T", stmt)
       }
       if returnStmt.TokenLiteral() != "return" {
          t.Fatalf("returnStmt.TokenLiteral not 'return', got %q",
             returnStmt.TokenLiteral())
       }
       // BUG FIX: the condition was inverted ("if testLiteralExpression
       // { return }"), which aborted the table loop after the first
       // PASSING row and kept looping after a failing one. Negate it —
       // matching TestLetStatements — so every row is checked and the
       // loop stops early only on failure.
       if !testLiteralExpression(t, returnStmt.ReturnValue, tt.expectedValue) {
          return
       }
    }
}

// TestIdentifierExpression checks that a bare identifier parses to a
// single ExpressionStatement wrapping an *ast.Identifier.
func TestIdentifierExpression(t *testing.T) {
    input := "foobar;"

    l := lexer.New(input)
    p := New(l)
    program := p.ParseProgram()
    checkParserErrors(t, p)

    if len(program.Statements) != 1 {
       t.Fatalf("program has not enough statements. got=%d",
          len(program.Statements))
    }
    stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
          program.Statements[0])
    }

    ident, ok := stmt.Expression.(*ast.Identifier)
    if !ok {
       t.Fatalf("exp not *ast.Identifier. got=%T", stmt.Expression)
    }
    if ident.Value != "foobar" {
       t.Errorf("ident.Value not %s. got=%s", "foobar", ident.Value)
    }
    if ident.TokenLiteral() != "foobar" {
       t.Errorf("ident.TokenLiteral not %s. got=%s", "foobar",
          ident.TokenLiteral())
    }
}

// TestIntegerLiteralExpression checks that a bare integer parses to a
// single ExpressionStatement wrapping an *ast.IntegerLiteral.
func TestIntegerLiteralExpression(t *testing.T) {
    input := "5;"

    l := lexer.New(input)
    p := New(l)
    program := p.ParseProgram()
    checkParserErrors(t, p)

    if len(program.Statements) != 1 {
       t.Fatalf("program has not enough statements. got=%d",
          len(program.Statements))
    }
    stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
          program.Statements[0])
    }

    literal, ok := stmt.Expression.(*ast.IntegerLiteral)
    if !ok {
       t.Fatalf("exp not *ast.IntegerLiteral. got=%T", stmt.Expression)
    }
    if literal.Value != 5 {
       t.Errorf("literal.Value not %d. got=%d", 5, literal.Value)
    }
    if literal.TokenLiteral() != "5" {
       t.Errorf("literal.TokenLiteral not %s. got=%s", "5",
          literal.TokenLiteral())
    }
}

// TestParsingPrefixExpressions table-tests '!'/'-' prefix expressions
// over integer, identifier and boolean operands.
func TestParsingPrefixExpressions(t *testing.T) {
    prefixTests := []struct {
       input    string
       operator string
       value    interface{}
    }{
       {"!5;", "!", 5},
       {"-15;", "-", 15},
       {"!foobar;", "!", "foobar"},
       {"-foobar;", "-", "foobar"},
       {"!true;", "!", true},
       {"!false;", "!", false},
    }

    for _, tt := range prefixTests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       if len(program.Statements) != 1 {
          t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
             1, len(program.Statements))
       }

       stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
       if !ok {
          t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
             program.Statements[0])
       }

       exp, ok := stmt.Expression.(*ast.PrefixExpression)
       if !ok {
          t.Fatalf("stmt is not ast.PrefixExpression. got=%T", stmt.Expression)
       }
       if exp.Operator != tt.operator {
          t.Fatalf("exp.Operator is not '%s'. got=%s",
             tt.operator, exp.Operator)
       }
       if !testLiteralExpression(t, exp.Right, tt.value) {
          return
       }
    }
}

// TestParsingInfixExpressions table-tests every binary operator with
// integer, identifier and boolean operands on both sides.
func TestParsingInfixExpressions(t *testing.T) {
    infixTests := []struct {
       input      string
       leftValue  interface{}
       operator   string
       rightValue interface{}
    }{
       {"5 + 5;", 5, "+", 5},
       {"5 - 5;", 5, "-", 5},
       {"5 * 5;", 5, "*", 5},
       {"5 / 5;", 5, "/", 5},
       {"5 > 5;", 5, ">", 5},
       {"5 < 5;", 5, "<", 5},
       {"5 == 5;", 5, "==", 5},
       {"5 != 5;", 5, "!=", 5},
       {"foobar + barfoo;", "foobar", "+", "barfoo"},
       {"foobar - barfoo;", "foobar", "-", "barfoo"},
       {"foobar * barfoo;", "foobar", "*", "barfoo"},
       {"foobar / barfoo;", "foobar", "/", "barfoo"},
       {"foobar > barfoo;", "foobar", ">", "barfoo"},
       {"foobar < barfoo;", "foobar", "<", "barfoo"},
       {"foobar == barfoo;", "foobar", "==", "barfoo"},
       {"foobar != barfoo;", "foobar", "!=", "barfoo"},
       {"true == true", true, "==", true},
       {"true != false", true, "!=", false},
       {"false == false", false, "==", false},
    }

    for _, tt := range infixTests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       if len(program.Statements) != 1 {
          t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
             1, len(program.Statements))
       }

       stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
       if !ok {
          t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
             program.Statements[0])
       }

       if !testInfixExpression(t, stmt.Expression, tt.leftValue,
          tt.operator, tt.rightValue) {
          return
       }
    }
}

// TestOperatorPrecedenceParsing checks precedence and associativity by
// comparing the fully-parenthesized String() form of each parsed
// program against the expected grouping.
func TestOperatorPrecedenceParsing(t *testing.T) {
    tests := []struct {
       input    string
       expected string
    }{
       {
          "-a * b",
          "((-a) * b)",
       },
       {
          "!-a",
          "(!(-a))",
       },
       {
          "a + b + c",
          "((a + b) + c)",
       },
       {
          "a + b - c",
          "((a + b) - c)",
       },
       {
          "a * b * c",
          "((a * b) * c)",
       },
       {
          "a * b / c",
          "((a * b) / c)",
       },
       {
          "a + b / c",
          "(a + (b / c))",
       },
       {
          "a + b * c + d / e - f",
          "(((a + (b * c)) + (d / e)) - f)",
       },
       {
          "3 + 4; -5 * 5",
          "(3 + 4)((-5) * 5)",
       },
       {
          "5 > 4 == 3 < 4",
          "((5 > 4) == (3 < 4))",
       },
       {
          "5 < 4 != 3 > 4",
          "((5 < 4) != (3 > 4))",
       },
       {
          "3 + 4 * 5 == 3 * 1 + 4 * 5",
          "((3 + (4 * 5)) == ((3 * 1) + (4 * 5)))",
       },
       {
          "true",
          "true",
       },
       {
          "false",
          "false",
       },
       {
          "3 > 5 == false",
          "((3 > 5) == false)",
       },
       {
          "3 < 5 == true",
          "((3 < 5) == true)",
       },
       {
          "1 + (2 + 3) + 4",
          "((1 + (2 + 3)) + 4)",
       },
       {
          "(5 + 5) * 2",
          "((5 + 5) * 2)",
       },
       {
          "2 / (5 + 5)",
          "(2 / (5 + 5))",
       },
       {
          "(5 + 5) * 2 * (5 + 5)",
          "(((5 + 5) * 2) * (5 + 5))",
       },
       {
          "-(5 + 5)",
          "(-(5 + 5))",
       },
       {
          "!(true == true)",
          "(!(true == true))",
       },
       {
          "a + add(b * c) + d",
          "((a + add((b * c))) + d)",
       },
       {
          "add(a, b, 1, 2 * 3, 4 + 5, add(6, 7 * 8))",
          "add(a, b, 1, (2 * 3), (4 + 5), add(6, (7 * 8)))",
       },
       {
          "add(a + b + c * d / f + g)",
          "add((((a + b) + ((c * d) / f)) + g))",
       },
    }

    for _, tt := range tests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       actual := program.String()
       if actual != tt.expected {
          t.Errorf("expected=%q, got=%q", tt.expected, actual)
       }
    }
}

// TestBooleanExpression checks that bare true/false parse to
// *ast.Boolean nodes with the matching value.
func TestBooleanExpression(t *testing.T) {
    tests := []struct {
       input           string
       expectedBoolean bool
    }{
       {"true;", true},
       {"false;", false},
    }

    for _, tt := range tests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       if len(program.Statements) != 1 {
          t.Fatalf("program has not enough statements. got=%d",
             len(program.Statements))
       }

       stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
       if !ok {
          t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
             program.Statements[0])
       }

       boolean, ok := stmt.Expression.(*ast.Boolean)
       if !ok {
          t.Fatalf("exp not *ast.Boolean. got=%T", stmt.Expression)
       }
       if boolean.Value != tt.expectedBoolean {
          t.Errorf("boolean.Value not %t. got=%t", tt.expectedBoolean,
             boolean.Value)
       }
    }
}

// TestIfExpression checks an if without else: condition, one-statement
// consequence, and a nil Alternative.
func TestIfExpression(t *testing.T) {
    input := `if (x < y) { x }`

    l := lexer.New(input)
    p := New(l)
    program := p.ParseProgram()
    checkParserErrors(t, p)

    if len(program.Statements) != 1 {
       t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
          1, len(program.Statements))
    }

    stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
          program.Statements[0])
    }

    exp, ok := stmt.Expression.(*ast.IfExpression)
    if !ok {
       t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T",
          stmt.Expression)
    }

    if !testInfixExpression(t, exp.Condition, "x", "<", "y") {
       return
    }

    if len(exp.Consequence.Statements) != 1 {
       t.Errorf("consequence is not 1 statements. got=%d\n",
          len(exp.Consequence.Statements))
    }

    consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T",
          exp.Consequence.Statements[0])
    }

    if !testIdentifier(t, consequence.Expression, "x") {
       return
    }

    // No else branch in the input, so Alternative must stay nil.
    if exp.Alternative != nil {
       t.Errorf("exp.Alternative.Statements was not nil. got=%+v", exp.Alternative)
    }
}

// TestIfElseExpression checks a full if/else: condition, consequence
// and alternative each carrying a single identifier statement.
func TestIfElseExpression(t *testing.T) {
    input := `if (x < y) { x } else { y }`

    l := lexer.New(input)
    p := New(l)
    program := p.ParseProgram()
    checkParserErrors(t, p)

    if len(program.Statements) != 1 {
       t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
          1, len(program.Statements))
    }

    stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
          program.Statements[0])
    }

    exp, ok := stmt.Expression.(*ast.IfExpression)
    if !ok {
       t.Fatalf("stmt.Expression is not ast.IfExpression. got=%T", stmt.Expression)
    }

    if !testInfixExpression(t, exp.Condition, "x", "<", "y") {
       return
    }

    if len(exp.Consequence.Statements) != 1 {
       t.Errorf("consequence is not 1 statements. got=%d\n",
          len(exp.Consequence.Statements))
    }

    consequence, ok := exp.Consequence.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T",
          exp.Consequence.Statements[0])
    }

    if !testIdentifier(t, consequence.Expression, "x") {
       return
    }

    if len(exp.Alternative.Statements) != 1 {
       t.Errorf("exp.Alternative.Statements does not contain 1 statements. got=%d\n",
          len(exp.Alternative.Statements))
    }

    alternative, ok := exp.Alternative.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("Statements[0] is not ast.ExpressionStatement. got=%T",
          exp.Alternative.Statements[0])
    }

    if !testIdentifier(t, alternative.Expression, "y") {
       return
    }
}

// TestFunctionLiteralParsing checks a two-parameter function literal:
// parameter identifiers and a single infix-expression body statement.
func TestFunctionLiteralParsing(t *testing.T) {
    input := `fn(x, y) { x + y; }`

    l := lexer.New(input)
    p := New(l)
    program := p.ParseProgram()
    checkParserErrors(t, p)

    if len(program.Statements) != 1 {
       t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
          1, len(program.Statements))
    }

    stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("program.Statements[0] is not ast.ExpressionStatement. got=%T",
          program.Statements[0])
    }

    function, ok := stmt.Expression.(*ast.FunctionLiteral)
    if !ok {
       t.Fatalf("stmt.Expression is not ast.FunctionLiteral. got=%T",
          stmt.Expression)
    }

    if len(function.Parameters) != 2 {
       t.Fatalf("function literal parameters wrong. want 2, got=%d\n",
          len(function.Parameters))
    }

    testLiteralExpression(t, function.Parameters[0], "x")
    testLiteralExpression(t, function.Parameters[1], "y")

    if len(function.Body.Statements) != 1 {
       t.Fatalf("function.Body.Statements has not 1 statements. got=%d\n",
          len(function.Body.Statements))
    }

    bodyStmt, ok := function.Body.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("function body stmt is not ast.ExpressionStatement. got=%T",
          function.Body.Statements[0])
    }

    testInfixExpression(t, bodyStmt.Expression, "x", "+", "y")
}

// TestFunctionParameterParsing table-tests parameter lists of length
// zero, one and three.
func TestFunctionParameterParsing(t *testing.T) {
    tests := []struct {
       input          string
       expectedParams []string
    }{
       {input: "fn() {};", expectedParams: []string{}},
       {input: "fn(x) {};", expectedParams: []string{"x"}},
       {input: "fn(x, y, z) {};", expectedParams: []string{"x", "y", "z"}},
    }

    for _, tt := range tests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       stmt := program.Statements[0].(*ast.ExpressionStatement)
       function := stmt.Expression.(*ast.FunctionLiteral)

       if len(function.Parameters) != len(tt.expectedParams) {
          t.Errorf("length parameters wrong. want %d, got=%d\n",
             len(tt.expectedParams), len(function.Parameters))
       }

       for i, ident := range tt.expectedParams {
          testLiteralExpression(t, function.Parameters[i], ident)
       }
    }
}

// TestCallExpressionParsing checks a call with three arguments: the
// callee identifier plus literal and infix-expression arguments.
func TestCallExpressionParsing(t *testing.T) {
    input := "add(1, 2 * 3, 4 + 5);"

    l := lexer.New(input)
    p := New(l)
    program := p.ParseProgram()
    checkParserErrors(t, p)

    if len(program.Statements) != 1 {
       t.Fatalf("program.Statements does not contain %d statements. got=%d\n",
          1, len(program.Statements))
    }

    stmt, ok := program.Statements[0].(*ast.ExpressionStatement)
    if !ok {
       t.Fatalf("stmt is not ast.ExpressionStatement. got=%T",
          program.Statements[0])
    }

    exp, ok := stmt.Expression.(*ast.CallExpression)
    if !ok {
       t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T",
          stmt.Expression)
    }

    if !testIdentifier(t, exp.Function, "add") {
       return
    }

    if len(exp.Arguments) != 3 {
       t.Fatalf("wrong length of arguments. got=%d", len(exp.Arguments))
    }

    testLiteralExpression(t, exp.Arguments[0], 1)
    testInfixExpression(t, exp.Arguments[1], 2, "*", 3)
    testInfixExpression(t, exp.Arguments[2], 4, "+", 5)
}

// TestCallExpressionParameterParsing table-tests argument lists of
// length zero, one and three, comparing each argument's String() form.
func TestCallExpressionParameterParsing(t *testing.T) {
    tests := []struct {
       input         string
       expectedIdent string
       expectedArgs  []string
    }{
       {
          input:         "add();",
          expectedIdent: "add",
          expectedArgs:  []string{},
       },
       {
          input:         "add(1);",
          expectedIdent: "add",
          expectedArgs:  []string{"1"},
       },
       {
          input:         "add(1, 2 * 3, 4 + 5);",
          expectedIdent: "add",
          expectedArgs:  []string{"1", "(2 * 3)", "(4 + 5)"},
       },
    }

    for _, tt := range tests {
       l := lexer.New(tt.input)
       p := New(l)
       program := p.ParseProgram()
       checkParserErrors(t, p)

       stmt := program.Statements[0].(*ast.ExpressionStatement)
       exp, ok := stmt.Expression.(*ast.CallExpression)
       if !ok {
          t.Fatalf("stmt.Expression is not ast.CallExpression. got=%T",
             stmt.Expression)
       }

       if !testIdentifier(t, exp.Function, tt.expectedIdent) {
          return
       }

       if len(exp.Arguments) != len(tt.expectedArgs) {
          t.Fatalf("wrong number of arguments. want=%d, got=%d",
             len(tt.expectedArgs), len(exp.Arguments))
       }

       for i, arg := range tt.expectedArgs {
          if exp.Arguments[i].String() != arg {
             t.Errorf("argument %d wrong. want=%q, got=%q", i,
                arg, exp.Arguments[i].String())
          }
       }
    }
}

// testLetStatement asserts that s is a *ast.LetStatement binding the
// identifier name. It reports failures via t.Errorf and returns false
// so callers can stop early.
func testLetStatement(t *testing.T, s ast.Statement, name string) bool {
	t.Helper() // attribute failures to the calling test, not this helper

	if s.TokenLiteral() != "let" {
		t.Errorf("s.TokenLiteral not 'let'. got=%q", s.TokenLiteral())
		return false
	}

	letStmt, ok := s.(*ast.LetStatement)
	if !ok {
		t.Errorf("s not *ast.LetStatement. got=%T", s)
		return false
	}

	if letStmt.Name.Value != name {
		t.Errorf("letStmt.Name.Value not '%s'. got=%s", name, letStmt.Name.Value)
		return false
	}

	if letStmt.Name.TokenLiteral() != name {
		t.Errorf("letStmt.Name.TokenLiteral() not '%s'. got=%s",
			name, letStmt.Name.TokenLiteral())
		return false
	}

	return true
}

// testInfixExpression asserts that exp is an *ast.InfixExpression with the
// given left operand, operator, and right operand. Operands are checked via
// testLiteralExpression, so left/right may be int, int64, string, or bool.
func testInfixExpression(t *testing.T, exp ast.Expression, left interface{},
	operator string, right interface{}) bool {
	t.Helper() // attribute failures to the calling test, not this helper

	opExp, ok := exp.(*ast.InfixExpression)
	if !ok {
		t.Errorf("exp is not ast.InfixExpression. got=%T(%s)", exp, exp)
		return false
	}

	if !testLiteralExpression(t, opExp.Left, left) {
		return false
	}

	if opExp.Operator != operator {
		t.Errorf("exp.Operator is not '%s'. got=%q", operator, opExp.Operator)
		return false
	}

	if !testLiteralExpression(t, opExp.Right, right) {
		return false
	}

	return true
}

// testLiteralExpression dispatches on the dynamic type of expected and
// delegates to the matching literal checker: int/int64 -> integer literal,
// string -> identifier, bool -> boolean literal. Unsupported expected types
// fail the test.
func testLiteralExpression(
	t *testing.T,
	exp ast.Expression,
	expected interface{},
) bool {
	t.Helper() // attribute failures to the calling test, not this helper

	switch v := expected.(type) {
	case int:
		return testIntegerLiteral(t, exp, int64(v))
	case int64:
		return testIntegerLiteral(t, exp, v)
	case string:
		return testIdentifier(t, exp, v)
	case bool:
		return testBooleanLiteral(t, exp, v)
	}
	// The type switch above is on expected, so report expected's type —
	// the original reported exp's type, which was misleading here.
	t.Errorf("type of expected not handled. got=%T", expected)
	return false
}

// testIntegerLiteral asserts that il is an *ast.IntegerLiteral whose parsed
// Value and TokenLiteral both match value.
func testIntegerLiteral(t *testing.T, il ast.Expression, value int64) bool {
	t.Helper() // attribute failures to the calling test, not this helper

	integ, ok := il.(*ast.IntegerLiteral)
	if !ok {
		t.Errorf("il not *ast.IntegerLiteral. got=%T", il)
		return false
	}

	if integ.Value != value {
		t.Errorf("integ.Value not %d. got=%d", value, integ.Value)
		return false
	}

	if integ.TokenLiteral() != fmt.Sprintf("%d", value) {
		t.Errorf("integ.TokenLiteral not %d. got=%s", value,
			integ.TokenLiteral())
		return false
	}

	return true
}

// testIdentifier asserts that exp is an *ast.Identifier whose Value and
// TokenLiteral both equal value.
func testIdentifier(t *testing.T, exp ast.Expression, value string) bool {
	t.Helper() // attribute failures to the calling test, not this helper

	ident, ok := exp.(*ast.Identifier)
	if !ok {
		t.Errorf("exp not *ast.Identifier. got=%T", exp)
		return false
	}

	if ident.Value != value {
		t.Errorf("ident.Value not %s. got=%s", value, ident.Value)
		return false
	}

	if ident.TokenLiteral() != value {
		t.Errorf("ident.TokenLiteral not %s. got=%s", value,
			ident.TokenLiteral())
		return false
	}

	return true
}

// testBooleanLiteral asserts that exp is an *ast.Boolean whose Value and
// TokenLiteral both match value.
func testBooleanLiteral(t *testing.T, exp ast.Expression, value bool) bool {
	t.Helper() // attribute failures to the calling test, not this helper

	bo, ok := exp.(*ast.Boolean)
	if !ok {
		t.Errorf("exp not *ast.Boolean. got=%T", exp)
		return false
	}

	if bo.Value != value {
		t.Errorf("bo.Value not %t. got=%t", value, bo.Value)
		return false
	}

	if bo.TokenLiteral() != fmt.Sprintf("%t", value) {
		t.Errorf("bo.TokenLiteral not %t. got=%s",
			value, bo.TokenLiteral())
		return false
	}

	return true
}

// checkParserErrors fails the test immediately (FailNow) if the parser
// accumulated any errors, printing each one first. It is a no-op when
// parsing succeeded.
func checkParserErrors(t *testing.T, p *Parser) {
	t.Helper() // attribute failures to the calling test, not this helper

	errors := p.Errors()
	if len(errors) == 0 {
		return
	}

	t.Errorf("parser has %d errors", len(errors))
	for _, msg := range errors {
		t.Errorf("parser error: %q", msg)
	}
	t.FailNow()
}

总结

以上就是生成 AST 的介绍。其实只要理解了递归下降的解析方式,再参照书中的做法,自定义一个小型的语言集,相信大家也都能实现自己的 AST 解析器。

相关推荐
Re.不晚13 分钟前
Java入门15——抽象类
java·开发语言·学习·算法·intellij-idea
老秦包你会15 分钟前
Qt第三课 ----------容器类控件
开发语言·qt
凤枭香18 分钟前
Python OpenCV 傅里叶变换
开发语言·图像处理·python·opencv
ULTRA??22 分钟前
C加加中的结构化绑定(解包,折叠展开)
开发语言·c++
码农派大星。23 分钟前
Spring Boot 配置文件
java·spring boot·后端
远望清一色38 分钟前
基于MATLAB的实现垃圾分类Matlab源码
开发语言·matlab
confiself1 小时前
大模型系列——LLAMA-O1 复刻代码解读
java·开发语言
XiaoLeisj1 小时前
【JavaEE初阶 — 多线程】Thread类的方法&线程生命周期
java·开发语言·java-ee
杜杜的man1 小时前
【go从零单排】go中的结构体struct和method
开发语言·后端·golang
幼儿园老大*1 小时前
走进 Go 语言基础语法
开发语言·后端·学习·golang·go