《LLVM IR 学习手记(五):关系运算与循环语句的实现与解析》

1、实现关系运算

1.1 测试文件

expr.txt

ini 复制代码
int a = 1;
int b = 4;
if( a > b)
{
  b = a + b;
  if(a <= b)
  {
    b = 2 * a - b;
  }
}
a * (b + 3) - 2;

1.2 词法分析器 (Lexer)

对于词法分析器,需要增加 equal_equal, not_equal, less, less_equal, greater, greater_equal 这六种 token 类型,以及增加对应的判断。

文法定义

ebnf.txt

ini 复制代码
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
block-stmt: "{" stmt* "}"
expr-stmt : expr ";"
expr : assign-expr | equal-expr
assign-expr: identifier "=" expr
equal-expr : relational-expr (("==" | "!=") relational-expr)*
relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
add-expr : mult-expr (("+" | "-") mult-expr)* 
mult-expr : primary-expr (("*" | "/") primary-expr)* 
primary-expr : identifier | number | "(" expr ")" 
number: ([0-9])+ 
identifier : (a-zA-Z_)(a-zA-Z0-9_)*

实现代码

lexer.h

cpp 复制代码
#pragma once

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"

// char stream -> token

/// Every kind of token the lexer can produce.
/// The enumerator order (and therefore each underlying value) is significant
/// only in that it must stay stable for existing users of the enum.
enum class TokenType
{
    // --- literals and names ---
    number,      ///< integer literal: [0-9]+
    indentifier, ///< variable name (historical spelling, kept for compatibility)

    // --- keywords ---
    kw_int,  ///< int
    kw_if,   ///< if
    kw_else, ///< else

    // --- arithmetic and assignment operators ---
    plus,  ///< +
    minus, ///< -
    star,  ///< *
    slash, ///< /
    equal, ///< =

    // --- equality and relational operators ---
    equal_equal,   ///< ==
    not_equal,     ///< !=
    less,          ///< <
    less_equal,    ///< <=
    greater,       ///< >
    greater_equal, ///< >=

    // --- punctuation ---
    l_parent, ///< (
    r_parent, ///< )
    l_brace,  ///< {
    r_brace,  ///< }
    semi,     ///< ;
    comma,    ///< ,

    eof ///< end of input
};

/// A single lexical token together with its source location.
///
/// Every field has a default so that a token which the lexer only partially
/// fills in (for example `type`/`value` are only written for number tokens)
/// never exposes indeterminate values to the parser or to diagnostics.
class Token
{
public:
    TokenType tokenType = TokenType::eof; // kind of this token
    int row = 0, col = 0;                 // position of the token's first character as reported by the lexer

    int value = 0; // numeric value; only meaningful for TokenType::number

    const char *ptr = nullptr; // start of the token's spelling inside the source buffer
    int length = 0;            // number of characters in the spelling

    CType *type = nullptr; // built-in type attached by the lexer (set for number tokens)

public:
    /// Print the token's spelling and position to llvm::outs().
    void Dump();

    /// Return a human-readable spelling for a token kind (used in diagnostics).
    static llvm::StringRef GetSpellingText(TokenType tokenType);
};

/// Turns the main source buffer of an llvm::SourceMgr into a token stream.
class Lexer
{
public:
    /// Position the lexer at the beginning of the SourceMgr's main file.
    Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr)
        : diagEngine(diagEngine), mgr(mgr)
    {
        const llvm::StringRef mainBuffer = mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
        BufPtr = mainBuffer.begin();
        LineHeadPtr = mainBuffer.begin();
        BufEnd = mainBuffer.end();
        row = 1; // rows are 1-based
    }

    /// Scan the next token from the current position into `token`.
    void NextToken(Token &token);

    /// Remember the current scanning position (a single slot, not a stack).
    void SaveState();
    /// Go back to the position recorded by the last SaveState().
    void RestoreState();

    /// Diagnostics sink shared with the parser.
    DiagEngine &GetDiagEngine() const
    {
        return diagEngine;
    }

private:
    /// Snapshot of the scanning position used by SaveState()/RestoreState().
    struct State
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

private:
    DiagEngine &diagEngine;
    llvm::SourceMgr &mgr;

    const char *BufPtr;      // next character to scan
    const char *LineHeadPtr; // first character of the current line (for column computation)
    const char *BufEnd;      // one past the last character of the buffer
    int row;                 // current line number

    State state; // written by SaveState(), read by RestoreState()
};

lexer.cc

cpp 复制代码
#include "lexer.h"

void Token::Dump()
{
    llvm::StringRef text(ptr, length);
    llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}

/// Map a token kind to the text shown in diagnostics.
/// Kinds without a case below (e.g. eof) are treated as unreachable.
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    switch (tokenType)
    {
    // literals and names
    case TokenType::number:        return "number";
    case TokenType::indentifier:   return "indentifier";

    // keywords
    case TokenType::kw_int:        return "int";
    case TokenType::kw_if:         return "if";
    case TokenType::kw_else:       return "else";

    // arithmetic and assignment
    case TokenType::plus:          return "+";
    case TokenType::minus:         return "-";
    case TokenType::star:          return "*";
    case TokenType::slash:         return "/";
    case TokenType::equal:         return "=";

    // equality and relational
    case TokenType::equal_equal:   return "==";
    case TokenType::not_equal:     return "!=";
    case TokenType::less:          return "<";
    case TokenType::less_equal:    return "<=";
    case TokenType::greater:       return ">";
    case TokenType::greater_equal: return ">=";

    // punctuation
    case TokenType::l_parent:      return "(";
    case TokenType::r_parent:      return ")";
    case TokenType::l_brace:       return "{";
    case TokenType::r_brace:       return "}";
    case TokenType::semi:          return ";";
    case TokenType::comma:         return ",";

    default:
        llvm::llvm_unreachable_internal(); // no spelling is defined for the remaining kinds
    }
}

/// True for the blank characters the lexer skips: space, tab, CR and LF.
bool IsWhiteSpace(char ch)
{
    switch (ch)
    {
    case ' ':
    case '\t':
    case '\r':
    case '\n':
        return true;
    default:
        return false;
    }
}

/// True when `ch` is an ASCII decimal digit ('0'..'9').
bool IsDigit(char ch)
{
    const bool belowRange = ch < '0';
    const bool aboveRange = ch > '9';
    return !(belowRange || aboveRange);
}

/// True for characters that may start an identifier: a-z, A-Z or '_'.
bool IsLetter(char ch)
{
    if (ch == '_')
    {
        return true;
    }
    if ('a' <= ch && ch <= 'z')
    {
        return true;
    }
    return 'A' <= ch && ch <= 'Z';
}

/// Scan the next token starting at the current buffer position.
///
/// On return `token` holds the kind, source position (row/col), spelling
/// (ptr/length) and, for number tokens, the numeric value and type.
/// At end of input the kind is TokenType::eof, with ptr pointing at the
/// buffer end and length 0.
///
/// Characters that cannot start any token (including a '!' that is not
/// followed by '=') are reported through the diagnostics engine.
void Lexer::NextToken(Token &token)
{
    // Skip whitespace, keeping the row counter and line start up to date.
    while (BufPtr < BufEnd && IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;

    // Record the spelling start before the eof check so that an eof token
    // does not keep pointing at the previous token's text.
    token.ptr = BufPtr;
    token.length = 0;

    // End of input?
    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    if (IsDigit(*BufPtr)) // integer literal
    {
        int val = 0;
        while (BufPtr < BufEnd && IsDigit(*BufPtr))
        {
            val = val * 10 + *BufPtr++ - '0';
            token.length++;
        }
        token.value = val;
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
    }
    else if (IsLetter(*BufPtr)) // identifier or keyword
    {
        while (BufPtr < BufEnd && (IsLetter(*BufPtr) || IsDigit(*BufPtr)))
        {
            BufPtr++;
        }
        token.tokenType = TokenType::indentifier;
        token.length = BufPtr - token.ptr;
        llvm::StringRef text(token.ptr, BufPtr - token.ptr);
        if (text == "int")
        {
            token.tokenType = TokenType::kw_int;
        }
        else if (text == "if")
        {
            token.tokenType = TokenType::kw_if;
        }
        else if (text == "else")
        {
            token.tokenType = TokenType::kw_else;
        }
    }
    else // operator or punctuation
    {
        // Emit a token of `len` characters and move past it.
        auto emit = [&](TokenType type, int len)
        {
            token.tokenType = type;
            token.length = len;
            BufPtr += len;
        };
        // True when the character after the current one is '='.
        auto followedByEqual = [&]()
        {
            return BufPtr + 1 < BufEnd && *(BufPtr + 1) == '=';
        };

        switch (*BufPtr)
        {
        case '+':
            emit(TokenType::plus, 1);
            break;
        case '-':
            emit(TokenType::minus, 1);
            break;
        case '*':
            emit(TokenType::star, 1);
            break;
        case '/':
            emit(TokenType::slash, 1);
            break;
        case '=':
            if (followedByEqual())
                emit(TokenType::equal_equal, 2);
            else
                emit(TokenType::equal, 1);
            break;
        case '<':
            if (followedByEqual())
                emit(TokenType::less_equal, 2);
            else
                emit(TokenType::less, 1);
            break;
        case '>':
            if (followedByEqual())
                emit(TokenType::greater_equal, 2);
            else
                emit(TokenType::greater, 1);
            break;
        case '(':
            emit(TokenType::l_parent, 1);
            break;
        case ')':
            emit(TokenType::r_parent, 1);
            break;
        case ';':
            emit(TokenType::semi, 1);
            break;
        case ',':
            emit(TokenType::comma, 1);
            break;
        case '{':
            emit(TokenType::l_brace, 1);
            break;
        case '}':
            emit(TokenType::r_brace, 1);
            break;
        case '!':
            if (followedByEqual())
            {
                emit(TokenType::not_equal, 2);
                break;
            }
            // A lone '!' is not a token of this grammar. It must be reported
            // as an unknown character, not silently lexed as '<'.
            [[fallthrough]];
        default:
            diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
            // In case Report() returns instead of terminating: skip the
            // offending character and hand back the next valid token, so the
            // caller never sees a stale token kind or a lexer that is stuck.
            BufPtr++;
            NextToken(token);
            break;
        }
    }
}

void Lexer::SaveState()
{
    state.LineHeadPtr = LineHeadPtr;
    state.BufPtr = BufPtr;
    state.BufEnd = BufEnd;
    state.row = row;
}

void Lexer::RestoreState()
{
    LineHeadPtr = state.LineHeadPtr;
    BufPtr = state.BufPtr;
    BufEnd = state.BufEnd;
    row = state.row;
}

1.3 语法分析器 (Parser)

由于文法中新增了等号表达式和关系表达式,以及对下面的文法也进行了更详细的说明,对于这些变化需要对语法分析器进行修改。

实现代码

parser.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "lexer.h"
#include "sema.h"
#include <vector>

/// Recursive-descent parser: pulls tokens from a Lexer and builds AST nodes
/// through Sema.
class Parser
{
public:
    /// Bind the parser to its lexer and semantic analyser and load the first token.
    Parser(Lexer &lexer, Sema &sema)
        : lexer(lexer), sema(sema)
    {
        Advance();
    }

    /// Parse the whole input until eof and return the resulting program node.
    std::shared_ptr<Program> ParseProgram();

private:
    // One method per grammar rule.
    std::shared_ptr<ASTNode> ParseStmt();
    std::shared_ptr<ASTNode> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseBlockStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseIfStmt();
    std::shared_ptr<ASTNode> ParseExpr();
    std::shared_ptr<ASTNode> ParseAssignExpr();
    std::shared_ptr<ASTNode> ParseEqualExpr();
    std::shared_ptr<ASTNode> ParseRelationalExpr();
    std::shared_ptr<ASTNode> ParseAddExpr();
    std::shared_ptr<ASTNode> ParseMultiExpr();
    std::shared_ptr<ASTNode> ParsePrimary();

    // Token handling helpers.

    /// Check the kind of the current token without consuming it.
    bool Expect(TokenType tokenType);
    /// Check the kind of the current token and, if it matches, consume it.
    bool Consume(TokenType tokenType);
    /// Consume the current token unconditionally.
    bool Advance();

    /// Whether the current token names a type.
    bool IsTypeName();

    /// Diagnostics engine, obtained from the lexer.
    DiagEngine &GetDiagEngine() const
    {
        return lexer.GetDiagEngine();
    }

private:
    Lexer &lexer;
    Sema &sema;
    Token token; // the current (look-ahead) token
    std::vector<std::shared_ptr<ASTNode>> breakNodes;
    std::vector<std::shared_ptr<ASTNode>> continueNodes;
};

parser.cc

cpp 复制代码
#include "parser.h"
#include <cassert>

/*
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
block-stmt: "{" stmt* "}"
expr-stmt : expr ";"
expr : assign-expr | equal-expr
assign-expr: identifier "=" expr
equal-expr : relational-expr (("==" | "!=") relational-expr)*
relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number: ([0-9])+
identifier : (a-zA-Z_)(a-zA-Z0-9_)*
*/

// 解析目标程序
// stmt : decl-stmt | expr-stmt | null-stmt
std::shared_ptr<Program> Parser::ParseProgram()
{
    std::vector<std::shared_ptr<ASTNode>> nodeVec;
    while (token.tokenType != TokenType::eof)
    {
        auto stmt = ParseStmt();
        if (stmt)
            nodeVec.push_back(stmt);
    }
    auto program = std::make_shared<Program>();
    program->nodeVec = std::move(nodeVec);
    return program;
}

// Parse one statement, dispatching on the current token.
// Returns nullptr for a null statement (a lone ';').
std::shared_ptr<ASTNode> Parser::ParseStmt()
{
    // null-stmt: just consume the ';'
    if (token.tokenType == TokenType::semi)
    {
        Consume(TokenType::semi);
        return nullptr;
    }

    // decl-stmt
    if (IsTypeName())
    {
        return ParseDeclStmt();
    }

    // block-stmt
    if (token.tokenType == TokenType::l_brace)
    {
        return ParseBlockStmt();
    }

    // if-stmt
    if (token.tokenType == TokenType::kw_if)
    {
        return ParseIfStmt();
    }

    // anything else is an expr-stmt
    return ParseExprStmt();
}

// 解析声明语句
std::shared_ptr<ASTNode> Parser::ParseDeclStmt()
{
    /// int a, b = 3;
    /// int a = 3;
    Consume(TokenType::kw_int);
    CType *baseTy = CType::getIntTy();
    /// a , b = 3;
    /// a = 3;

    auto declStmt = std::make_shared<DeclStmt>();

    /// a, b = 3;
    /// a = 3;
    int i = 0;
    while (token.tokenType != TokenType::semi)
    {
        if (i++ > 0) // if (i++)
        {
            assert(Consume(TokenType::comma));
        }

        /// 变量声明的节点: int a = 3; -> int a; a = 3;
        // a = 3;
        auto variableDecl = sema.SemaVariableDeclNode(token, baseTy); // get a type
        declStmt->nodeVec.push_back(variableDecl);

        Token tmp = token;
        Consume(TokenType::indentifier);

        // = 3;
        if (token.tokenType == TokenType::equal)
        {
            Token opToken = token;
            Advance();

            // 3;
            auto right = ParseExpr();
            auto left = sema.SemaVariableAccessNode(tmp);
            auto assign = sema.SemaAssignExprNode(left, right, opToken);

            declStmt->nodeVec.push_back(assign);
        }
    }

    Consume(TokenType::semi);

    return declStmt;
}

std::shared_ptr<ASTNode> Parser::ParseBlockStmt()
{
    sema.EnterScope(); // 进入作用域

    auto blockStmt = std::make_shared<BlockStmt>();
    Consume(TokenType::l_brace);
    while (token.tokenType != TokenType::r_brace)
    {
        auto stmt = ParseStmt();
        if (stmt)
            blockStmt->nodeVec.push_back(stmt);
    }
    Consume(TokenType::r_brace);

    sema.ExitScope(); // 离开作用域

    return blockStmt;
}

// Parse an expression statement.
// expr-stmt : expr ";"
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    std::shared_ptr<ASTNode> exprNode = ParseExpr();
    Consume(TokenType::semi); // the terminating ';'
    return exprNode;
}

// Parse an if statement with an optional else branch.
// if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
//
// Example:
//   if (a)
//       b = 3;
//   else
//       b = 4;
std::shared_ptr<ASTNode> Parser::ParseIfStmt()
{
    // "if" "(" expr ")"
    Consume(TokenType::kw_if);
    Consume(TokenType::l_parent);
    auto condition = ParseExpr();
    Consume(TokenType::r_parent);

    // then branch
    auto thenBranch = ParseStmt();

    // optional else branch; stays null when absent
    std::shared_ptr<ASTNode> elseBranch = nullptr;
    const bool hasElse = (token.tokenType == TokenType::kw_else);
    if (hasElse)
    {
        Consume(TokenType::kw_else);
        elseBranch = ParseStmt();
    }

    return sema.SemaIfStmtNode(condition, thenBranch, elseBranch);
}

// 解析表达式
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
    lexer.SaveState();
    bool isAssign = false;
    Token tmp = token;
    // a = b;
    if (tmp.tokenType == TokenType::indentifier)
    {
        lexer.NextToken(tmp);
        if (tmp.tokenType == TokenType::equal)
        {
            isAssign = true;
        }
    }
    lexer.RestoreState();
    if (isAssign)
    {
        return ParseAssignExpr();
    }
    else // equal-expr
        return ParseEqualExpr();
}

// 解析赋值表达式
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
    // a = b;
    Token tmp = token;
    Consume(TokenType::indentifier);
    auto expr = sema.SemaVariableAccessNode(tmp);
    Token opToken = token;
    Consume(TokenType::equal);
    return sema.SemaAssignExprNode(expr, ParseExpr(), opToken);
}

// 解析等号表达式
std::shared_ptr<ASTNode> Parser::ParseEqualExpr()
{
    std::shared_ptr<ASTNode> left = ParseRelationalExpr();
    while (token.tokenType == TokenType::equal_equal || token.tokenType == TokenType::not_equal)
    {
        OpCode op;
        if (token.tokenType == TokenType::equal_equal)
        {
            op = OpCode::equal_equal;
        }
        else
        {
            op = OpCode::not_equal;
        }
        Advance();
        auto right = ParseRelationalExpr();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// Parse a relational expression (left-associative).
// relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
std::shared_ptr<ASTNode> Parser::ParseRelationalExpr()
{
    std::shared_ptr<ASTNode> result = ParseAddExpr();
    for (;;)
    {
        OpCode op;
        switch (token.tokenType)
        {
        case TokenType::less:
            op = OpCode::less;
            break;
        case TokenType::less_equal:
            op = OpCode::less_equal;
            break;
        case TokenType::greater:
            op = OpCode::greater;
            break;
        case TokenType::greater_equal:
            op = OpCode::greater_equal;
            break;
        default:
            // not a relational operator: the expression ends here
            return result;
        }

        Advance(); // consume the operator
        auto rhs = ParseAddExpr();
        result = sema.SemaBinaryExprNode(result, rhs, op);
    }
}

// 解析加法表达式
std::shared_ptr<ASTNode> Parser::ParseAddExpr()
{
    std::shared_ptr<ASTNode> left = ParseMultiExpr();
    while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
    {
        OpCode op;
        if (token.tokenType == TokenType::plus)
        {
            op = OpCode::add;
        }
        else
        {
            op = OpCode::sub;
        }
        Advance();
        auto right = ParseMultiExpr();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// 解析项
std::shared_ptr<ASTNode> Parser::ParseMultiExpr()
{
    std::shared_ptr<ASTNode> left = ParsePrimary();

    while (token.tokenType == TokenType::star || token.tokenType == TokenType::slash)
    {
        OpCode op;
        if (token.tokenType == TokenType::star)
        {
            op = OpCode::mul;
        }
        else
        {
            op = OpCode::div;
        }
        Advance();
        auto right = ParsePrimary();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// Parse a primary expression.
// primary-expr : identifier | number | "(" expr ")"
std::shared_ptr<ASTNode> Parser::ParsePrimary()
{
    switch (token.tokenType)
    {
    case TokenType::l_parent:
    {
        // parenthesised sub-expression
        Consume(TokenType::l_parent);
        auto inner = ParseExpr();
        Consume(TokenType::r_parent);
        return inner;
    }
    case TokenType::indentifier:
    {
        // variable reference; the node is built from the token before it is consumed
        auto access = sema.SemaVariableAccessNode(token);
        Consume(TokenType::indentifier);
        return access;
    }
    default:
    {
        // anything else has to be a number literal
        Expect(TokenType::number);
        auto literal = sema.SemaNumberExprNode(token, token.type);
        Consume(TokenType::number);
        return literal;
    }
    }
}

/// Token helpers.
/// Check that the current token has the given kind without consuming it.
/// On mismatch an "expected" diagnostic is reported at the current token and
/// false is returned.
bool Parser::Expect(TokenType tokenType)
{
    if (token.tokenType != tokenType)
    {
        const llvm::StringRef actualText(token.ptr, token.length);
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                               diag::err_expected,
                               Token::GetSpellingText(tokenType),
                               actualText);
        return false;
    }
    return true;
}

/// Consume the current token if it has the given kind.
/// Returns false (after Expect() reported the mismatch) without advancing otherwise.
bool Parser::Consume(TokenType tokenType)
{
    if (!Expect(tokenType))
    {
        return false;
    }
    Advance();
    return true;
}

bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}

printVisitor.cc

新增了关系运算,所以需要对 VisitBinaryExpr() 这个函数进行修改。

cpp 复制代码
#include "printVisitor.h"

/// Constructing the visitor immediately prints the given program.
PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
    Program *root = program.get();
    VisitProgram(root);
}

/// Print every top-level node of the program, each followed by a newline.
/// The printer produces no IR, so the result is always nullptr.
llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
    auto &topLevel = p->nodeVec;
    for (auto it = topLevel.begin(); it != topLevel.end(); ++it)
    {
        (*it)->Accept(this);
        llvm::outs() << "\n";
    }
    return nullptr;
}

/// Print each child node of a declaration statement in order, with no separator.
llvm::Value *PrintVisitor::VisitDeclStmt(DeclStmt *decl)
{
    for (const auto &child : decl->nodeVec)
    {
        child->Accept(this);
    }
    return nullptr;
}

/// Print a block: an opening brace line, one line per child, a closing brace line.
llvm::Value *PrintVisitor::VisitBlockStmt(BlockStmt *block)
{
    auto &out = llvm::outs();
    out << "{\n";
    for (auto it = block->nodeVec.begin(); it != block->nodeVec.end(); ++it)
    {
        (*it)->Accept(this);
        out << "\n";
    }
    out << "}\n";
    return nullptr;
}

/// Print an if statement: condition in parentheses, then branch, and the
/// else branch when one exists.
llvm::Value *PrintVisitor::VisitIfStmt(IfStmt *ifStmt)
{
    // header
    llvm::outs() << "if (";
    ifStmt->condNode->Accept(this);
    llvm::outs() << ")\n";

    // then branch
    ifStmt->thenNode->Accept(this);
    llvm::outs() << "\n";

    // optional else branch
    if (!ifStmt->elseNode)
    {
        return nullptr;
    }
    llvm::outs() << "else\n";
    ifStmt->elseNode->Accept(this);
    return nullptr;
}

/// Print a for statement as `for (init; cond; inc)body`.
/// Each of the four parts is optional and simply left out when null.
llvm::Value *PrintVisitor::VisitForStmt(ForStmt *forStmt)
{
    // Visit a part only when it is present.
    auto printIfPresent = [this](const auto &part)
    {
        if (part)
        {
            part->Accept(this);
        }
    };

    llvm::outs() << "for (";
    printIfPresent(forStmt->initNode);
    llvm::outs() << "; ";
    printIfPresent(forStmt->condNode);
    llvm::outs() << "; ";
    printIfPresent(forStmt->incNode);
    llvm::outs() << ")";
    printIfPresent(forStmt->bodyNode);
    return nullptr;
}

/// Print a break statement; the node itself carries nothing that is printed.
llvm::Value *PrintVisitor::VisitBreakStmt(BreakStmt *breakStmt)
{
    auto &out = llvm::outs();
    out << "break;";
    return nullptr;
}

/// Print a continue statement; the node itself carries nothing that is printed.
llvm::Value *PrintVisitor::VisitContinueStmt(ContinueStmt *continueStmt)
{
    auto &out = llvm::outs();
    out << "continue;";
    return nullptr;
}

/// Print `int <name>;` for an int variable declaration.
/// Declarations of any other type print nothing.
llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
    if (decl->type != CType::getIntTy())
    {
        return nullptr;
    }
    const llvm::StringRef name(decl->token.ptr, decl->token.length);
    llvm::outs() << "int " << name << ";";
    return nullptr;
}

/// Print the variable's name exactly as spelled in the source.
llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const Token &nameToken = varaccExpr->token;
    llvm::outs() << llvm::StringRef(nameToken.ptr, nameToken.length);
    return nullptr;
}

/// Print an assignment as `<left> = <right>`.
llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
    assignExpr->left->Accept(this); // target
    llvm::outs() << " = ";
    assignExpr->right->Accept(this); // assigned value
    return nullptr;
}

/// Print a binary expression in infix form: left operand, operator, right
/// operand (an in-order walk). An opcode without a spelling prints no
/// operator text between the operands.
llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    binaryExpr->left->Accept(this);

    // Operator text including the surrounding spaces; null when unknown.
    const char *opText = nullptr;
    switch (binaryExpr->op)
    {
    case OpCode::add:           opText = " + ";  break;
    case OpCode::sub:           opText = " - ";  break;
    case OpCode::mul:           opText = " * ";  break;
    case OpCode::div:           opText = " / ";  break;
    case OpCode::equal_equal:   opText = " == "; break;
    case OpCode::not_equal:     opText = " != "; break;
    case OpCode::less:          opText = " < ";  break;
    case OpCode::less_equal:    opText = " <= "; break;
    case OpCode::greater:       opText = " > ";  break;
    case OpCode::greater_equal: opText = " >= "; break;
    default:                    break;
    }
    if (opText != nullptr)
    {
        llvm::outs() << opText;
    }

    binaryExpr->right->Accept(this);

    return nullptr;
}

/// Print a number literal exactly as spelled in the source.
llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
    const Token &literal = factorExpr->token;
    llvm::outs() << llvm::StringRef(literal.ptr, literal.length);
    return nullptr;
}

1.4 代码生成 (CodeGen)

新增了关系运算,所以需要对 VisitBinaryExpr() 这个函数进行修改。

实现代码

codegen.cc

cpp 复制代码
#include "codegen.h"
#include "llvm/IR/Verifier.h"

/// Emit the whole module: declare printf, define main(), generate code for
/// every top-level node, print the value of the last one via printf, return 0
/// from main, verify the function and dump the module to llvm::outs().
/// Returns the `ret` instruction of main.
llvm::Value *CodeGen::VisitProgram(Program *p)
{
    llvm::Type *int32Ty = irBuilder.getInt32Ty();

    // declare i32 @printf(i8*, ...)
    auto printfTy = llvm::FunctionType::get(int32Ty, {irBuilder.getInt8PtrTy()}, true);
    auto printfFn = llvm::Function::Create(printfTy, llvm::GlobalValue::ExternalLinkage, "printf", module.get());

    // define i32 @main()
    auto mainTy = llvm::FunctionType::get(int32Ty, false);
    auto mainFn = llvm::Function::Create(mainTy, llvm::GlobalValue::ExternalLinkage, "main", module.get());

    // All code is emitted into main, starting at its entry block.
    llvm::BasicBlock *entryBlock = llvm::BasicBlock::Create(context, "entry", mainFn);
    irBuilder.SetInsertPoint(entryBlock);
    curFunc = mainFn; // visitors that create basic blocks attach them to this function

    // Generate every top-level node, remembering the value of the last one.
    llvm::Value *lastValue = nullptr;
    for (const auto &node : p->nodeVec)
    {
        lastValue = node->Accept(this);
    }

    // Print the final value, or a notice when the last node produced none.
    if (lastValue != nullptr)
    {
        irBuilder.CreateCall(printfFn, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastValue});
    }
    else
    {
        irBuilder.CreateCall(printfFn, {irBuilder.CreateGlobalStringPtr("last instruction is not expr!\n")});
    }

    // return 0;
    llvm::Value *retInst = irBuilder.CreateRet(irBuilder.getInt32(0));

    llvm::verifyFunction(*mainFn);

    module->print(llvm::outs(), nullptr);
    return retInst;
}

/// Generate code for every child of a declaration statement.
/// Returns the value of the last child (nullptr when there are none).
llvm::Value *CodeGen::VisitDeclStmt(DeclStmt *declStmt)
{
    llvm::Value *result = nullptr;
    for (const auto &child : declStmt->nodeVec)
    {
        result = child->Accept(this);
    }
    return result;
}

/// Generate code for every statement of a block, in order.
/// Returns the value of the last statement (nullptr when the block is empty).
llvm::Value *CodeGen::VisitBlockStmt(BlockStmt *block)
{
    llvm::Value *result = nullptr;
    for (const auto &stmt : block->nodeVec)
    {
        result = stmt->Accept(this);
    }
    return result;
}

/// Generate the control flow of an if statement.
///
/// Blocks: if_cond -> if_then [-> if_else] -> if_last. The condition value is
/// compared against 0 to obtain the branch flag. An if statement yields no
/// value, so nullptr is returned; the builder is left at if_last.
llvm::Value *CodeGen::VisitIfStmt(IfStmt *ifStmt)
{
    const bool hasElse = static_cast<bool>(ifStmt->elseNode);

    llvm::BasicBlock *condBB = llvm::BasicBlock::Create(context, "if_cond", curFunc);
    llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "if_then", curFunc);
    llvm::BasicBlock *elseBB = hasElse ? llvm::BasicBlock::Create(context, "if_else", curFunc) : nullptr;
    llvm::BasicBlock *lastBB = llvm::BasicBlock::Create(context, "if_last", curFunc);

    // A block needs an explicit terminator: jump from the current block into
    // the condition block, LLVM does not add the branch on its own.
    irBuilder.CreateBr(condBB);

    // condition: non-zero means true
    irBuilder.SetInsertPoint(condBB);
    llvm::Value *condInt = ifStmt->condNode->Accept(this);
    llvm::Value *condFlag = irBuilder.CreateICmpNE(condInt, irBuilder.getInt32(0));

    // When there is no else branch, a false condition goes straight to if_last.
    irBuilder.CreateCondBr(condFlag, thenBB, hasElse ? elseBB : lastBB);

    // then branch
    irBuilder.SetInsertPoint(thenBB);
    ifStmt->thenNode->Accept(this);
    irBuilder.CreateBr(lastBB);

    // else branch
    if (hasElse)
    {
        irBuilder.SetInsertPoint(elseBB);
        ifStmt->elseNode->Accept(this);
        irBuilder.CreateBr(lastBB);
    }

    // code following the if statement continues here
    irBuilder.SetInsertPoint(lastBB);
    return nullptr;
}

/// Allocate a stack slot for the declared variable and record its address and
/// IR type under the variable's name. Returns the slot's address.
/// Only the int type is mapped to an IR type (i32) here.
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
    const bool isInt = (decl->type == CType::getIntTy());
    llvm::Type *ty = isInt ? irBuilder.getInt32Ty() : nullptr;

    llvm::StringRef name(decl->token.ptr, decl->token.length);
    llvm::Value *slot = irBuilder.CreateAlloca(ty, nullptr, name);
    varAddrTyMap.insert({name, {slot, ty}});
    return slot;
}

/// Load the current value of a variable (an rvalue) from its recorded address.
llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    llvm::StringRef name(varaccExpr->token.ptr, varaccExpr->token.length);

    // address and IR type recorded under this name
    auto entry = varAddrTyMap[name];
    llvm::Value *slot = entry.first;
    llvm::Type *loadTy = entry.second;

    // int expressions are always loaded as i32
    if (varaccExpr->type == CType::getIntTy())
    {
        loadTy = irBuilder.getInt32Ty();
    }

    return irBuilder.CreateLoad(loadTy, slot, name);
}

/// Generate `a = <expr>`: evaluate the right-hand side, store it into the
/// variable's slot, then reload the variable so that the assignment itself
/// yields an rvalue.
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
    // The left-hand side is treated as a variable access node.
    auto *target = (VariableAccessExpr *)assignExpr->left.get();
    llvm::StringRef name(target->token.ptr, target->token.length);

    // address and IR type recorded under this name
    auto entry = varAddrTyMap[name];
    llvm::Value *slot = entry.first;
    llvm::Type *slotTy = entry.second;

    llvm::Value *newValue = assignExpr->right->Accept(this);

    // write through the address (the lvalue) ...
    irBuilder.CreateStore(newValue, slot);
    // ... and hand back the stored value as an rvalue
    return irBuilder.CreateLoad(slotTy, slot, name);
}

/// Generate code for a binary expression.
///
/// Both operands are evaluated first (left, then right). Arithmetic operators
/// yield an i32. Comparison operators produce an i1 that is widened to i32
/// with ZERO extension, so a true comparison has the value 1 and a false one
/// 0, as in C. (Sign extension would turn true into -1 and break expressions
/// such as `(a < b) + 1`.) Returns nullptr for an unknown opcode.
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    auto left = binaryExpr->left->Accept(this);
    auto right = binaryExpr->right->Accept(this);

    // Widen an i1 comparison result to i32: false -> 0, true -> 1.
    auto boolToInt32 = [this](llvm::Value *cmp)
    {
        return irBuilder.CreateIntCast(cmp, irBuilder.getInt32Ty(), /*isSigned=*/false);
    };

    switch (binaryExpr->op)
    {
    case OpCode::add:
    {
        // nsw = "no signed wrap": signed overflow yields a poison value
        return irBuilder.CreateNSWAdd(left, right, "add");
    }
    case OpCode::sub:
    {
        return irBuilder.CreateNSWSub(left, right, "sub");
    }
    case OpCode::mul:
    {
        return irBuilder.CreateNSWMul(left, right, "mul");
    }
    case OpCode::div:
    {
        return irBuilder.CreateSDiv(left, right, "div");
    }
    case OpCode::equal_equal:
    {
        return boolToInt32(irBuilder.CreateICmpEQ(left, right, "equal_equal"));
    }
    case OpCode::not_equal:
    {
        return boolToInt32(irBuilder.CreateICmpNE(left, right, "not_equal"));
    }
    case OpCode::less:
    {
        return boolToInt32(irBuilder.CreateICmpSLT(left, right, "less"));
    }
    case OpCode::less_equal:
    {
        return boolToInt32(irBuilder.CreateICmpSLE(left, right, "less_equal"));
    }
    case OpCode::greater:
    {
        return boolToInt32(irBuilder.CreateICmpSGT(left, right, "greater"));
    }
    case OpCode::greater_equal:
    {
        return boolToInt32(irBuilder.CreateICmpSGE(left, right, "greater_equal"));
    }
    default:
    {
        break;
    }
    }
    return nullptr;
}

/// A number literal becomes an i32 constant holding the token's value.
llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
    const int literalValue = factorExpr->token.value;
    return irBuilder.getInt32(literalValue);
}

1.5 测试关系运算功能

生成 IR 文件

bash 复制代码
bin/for test/expr.txt > test/expr.ll

生成的 IR 内容

ini 复制代码
; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 1, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 4, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %if_cond

if_cond:                                          ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %b4 = load i32, ptr %b, align 4
  %greater = icmp sgt i32 %a3, %b4
  %0 = sext i1 %greater to i32
  %1 = icmp ne i32 %0, 0
  br i1 %1, label %if_then, label %if_last

if_then:                                          ; preds = %if_cond
  %a5 = load i32, ptr %a, align 4
  %b6 = load i32, ptr %b, align 4
  %add = add nsw i32 %a5, %b6
  store i32 %add, ptr %b, align 4
  %b7 = load i32, ptr %b, align 4
  br label %if_cond8

if_last:                                          ; preds = %if_last10, %if_cond
  %a16 = load i32, ptr %a, align 4
  %b17 = load i32, ptr %b, align 4
  %add18 = add nsw i32 %b17, 3
  %mul19 = mul nsw i32 %a16, %add18
  %sub20 = sub nsw i32 %mul19, 2
  %2 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub20)
  ret i32 0

if_cond8:                                         ; preds = %if_then
  %a11 = load i32, ptr %a, align 4
  %b12 = load i32, ptr %b, align 4
  %less_equal = icmp sle i32 %a11, %b12
  %3 = sext i1 %less_equal to i32
  %4 = icmp ne i32 %3, 0
  br i1 %4, label %if_then9, label %if_last10

if_then9:                                         ; preds = %if_cond8
  %a13 = load i32, ptr %a, align 4
  %mul = mul nsw i32 2, %a13
  %b14 = load i32, ptr %b, align 4
  %sub = sub nsw i32 %mul, %b14
  store i32 %sub, ptr %b, align 4
  %b15 = load i32, ptr %b, align 4
  br label %if_last10

if_last10:                                        ; preds = %if_then9, %if_cond8
  br label %if_last
}

运行 IR

bash 复制代码
lli test/expr.ll

运行结果

bash 复制代码
expr value: 5

结果正确,验证了编译器目前基础的关系运算功能的正确性。

2、实现 for 语句功能

2.1 测试文件

expr.txt

ini 复制代码
int a = 1;
int b = 4;
for (int i = 0; i < 4; i = i + 1)
{
	a = a + i;
	b = b * a + 2 * i;
}
a * (b + 3) - 2;

2.2 词法分析器 (Lexer)

文法定义

ini 复制代码
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt | for-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
for-stmt : "for" "(" expr? ";" expr? ";" expr? ")" stmt
		| "for" "(" decl-stmt expr? ";" expr? ")" stmt
block-stmt: "{" stmt* "}"
expr-stmt : expr ";"
expr : assign-expr | equal-expr
assign-expr: identifier "=" expr
equal-expr : relational-expr (("==" | "!=") relational-expr)*
relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
add-expr : mult-expr (("+" | "-") mult-expr)* 
mult-expr : primary-expr (("*" | "/") primary-expr)* 
primary-expr : identifier | number | "(" expr ")" 
number: ([0-9])+ 
identifier : (a-zA-Z_)(a-zA-Z0-9_)*

实现代码

对词法分析器新增 for 类型以及对应的检测方法。

lexer.h

cpp 复制代码
#pragma once

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"

// char stream -> token

// Kinds of token the lexer can produce.
// NOTE: `indentifier` is a misspelling of "identifier"; the parser code refers to it
// by this name, so it is left as-is.
enum class TokenType
{
    number,        // [0-9]+
    indentifier,   // variable name
    kw_int,        // int
    kw_if,         // if
    kw_else,       // else
    kw_for,        // for
    plus,          // +
    minus,         // -
    star,          // *
    slash,         // /
    equal,         // =
    equal_equal,   // ==
    not_equal,     // !=
    less,          // <
    less_equal,    // <=
    greater,       // >
    greater_equal, // >=
    l_parent,      // (
    r_parent,      // )
    l_brace,       // {
    r_brace,       // }
    semi,          // ;
    comma,         // ,
    eof            // end of file
};

// A single lexical token together with its source position.
// Every member has a default so that a default-constructed Token (for example the one
// embedded in each ASTNode) never exposes indeterminate values.
class Token
{
public:
    TokenType tokenType{TokenType::eof}; // kind of token
    int row{0};                          // line of the token (the lexer counts from 1)
    int col{0};                          // column of the token (the lexer counts from 1)

    int value{0}; // numeric value, meaningful for number tokens only

    const char *ptr{nullptr}; // start of the token text in the source buffer (for debug/diagnostics)
    int length{0};            // length of the token text in characters

    CType *type{nullptr}; // built-in type attached by the lexer (set for number tokens)

public:
    // Prints the token text and its position to llvm::outs().
    void Dump();
    // Returns the display spelling of a token kind.
    static llvm::StringRef GetSpellingText(TokenType tokenType);
};

// Turns the main source buffer of a SourceMgr into a stream of tokens.
class Lexer
{
public:
    // Positions the cursor at the start of the main file buffer; rows are counted from 1.
    Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr) : diagEngine(diagEngine), mgr(mgr)
    {
        const llvm::StringRef source = mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
        LineHeadPtr = BufPtr = source.begin();
        BufEnd = source.end();
        row = 1;
    }

    // Scans the next token into `token`.
    void NextToken(Token &token);

    // Single-slot snapshot of the cursor, used for look-ahead.
    void SaveState();
    void RestoreState();

    DiagEngine &GetDiagEngine() const { return diagEngine; }

private:
    // Everything needed to rewind the cursor.
    struct State
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

    DiagEngine &diagEngine;
    llvm::SourceMgr &mgr;

    const char *BufPtr;      // current read position
    const char *LineHeadPtr; // first character of the current line
    const char *BufEnd;      // one past the last character of the buffer
    int row;                 // current line number

    State state; // last snapshot taken by SaveState()
};

lexer.cc

cpp 复制代码
#include "lexer.h"

void Token::Dump()
{
    llvm::StringRef text(ptr, length);
    llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}

// Maps a token kind to the text used for it in diagnostics.
// TokenType::eof has no spelling here and ends up in the unreachable default.
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    switch (tokenType)
    {
    // literals and names
    case TokenType::number:
        return "number";
    case TokenType::indentifier:
        return "indentifier";

    // keywords
    case TokenType::kw_int:
        return "int";
    case TokenType::kw_if:
        return "if";
    case TokenType::kw_else:
        return "else";
    case TokenType::kw_for:
        return "for";

    // arithmetic and assignment
    case TokenType::plus:
        return "+";
    case TokenType::minus:
        return "-";
    case TokenType::star:
        return "*";
    case TokenType::slash:
        return "/";
    case TokenType::equal:
        return "=";

    // comparisons
    case TokenType::equal_equal:
        return "==";
    case TokenType::not_equal:
        return "!=";
    case TokenType::less:
        return "<";
    case TokenType::less_equal:
        return "<=";
    case TokenType::greater:
        return ">";
    case TokenType::greater_equal:
        return ">=";

    // punctuation
    case TokenType::l_parent:
        return "(";
    case TokenType::r_parent:
        return ")";
    case TokenType::l_brace:
        return "{";
    case TokenType::r_brace:
        return "}";
    case TokenType::semi:
        return ";";
    case TokenType::comma:
        return ",";

    default:
        llvm::llvm_unreachable_internal(); // no other kind is expected here
    }
}

// True for the four characters the lexer skips: space, tab, carriage return, newline.
bool IsWhiteSpace(char ch)
{
    switch (ch)
    {
    case ' ':
    case '\t':
    case '\r':
    case '\n':
        return true;
    default:
        return false;
    }
}

// True for the ASCII decimal digits '0'..'9'.
bool IsDigit(char ch)
{
    if (ch < '0')
        return false;
    return ch <= '9';
}

// True for characters that may start an identifier: a-z, A-Z and underscore.
bool IsLetter(char ch)
{
    if (ch == '_')
        return true;
    const bool isLower = ('a' <= ch) && (ch <= 'z');
    const bool isUpper = ('A' <= ch) && (ch <= 'Z');
    return isLower || isUpper;
}

// Scans the next token from the buffer into `token`.
//
// Always sets row, col, ptr and length. For numbers it also sets value and type;
// for names it decides between a keyword and TokenType::indentifier. At the end of
// the buffer the token kind is TokenType::eof with an empty text.
// Characters that cannot start a token are reported through the DiagEngine as
// diag::err_unknown_char; the cursor is not moved past them.
void Lexer::NextToken(Token &token)
{
    // Skip whitespace, remembering where each new line starts so row/col stay correct.
    while (BufPtr < BufEnd && IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;
    // Set the text range before the eof check so an eof token never carries the
    // pointer/length of an earlier token into diagnostics.
    token.ptr = BufPtr;
    token.length = 0;

    // End of input.
    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    // Emits a punctuation/operator token of `len` characters and steps past it.
    auto emit = [&](TokenType type, int len)
    {
        token.tokenType = type;
        token.length = len;
        BufPtr += len;
    };
    // True when the character right after the current one is `expected`.
    auto nextIs = [&](char expected)
    {
        return BufPtr + 1 < BufEnd && *(BufPtr + 1) == expected;
    };

    if (IsDigit(*BufPtr))
    {
        // number: [0-9]+
        int val = 0;
        while (BufPtr < BufEnd && IsDigit(*BufPtr))
        {
            val = val * 10 + (*BufPtr++ - '0');
            token.length++;
        }
        token.value = val;
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
    }
    else if (IsLetter(*BufPtr))
    {
        // identifier or keyword: [a-zA-Z_][a-zA-Z0-9_]*
        while (BufPtr < BufEnd && (IsLetter(*BufPtr) || IsDigit(*BufPtr)))
        {
            BufPtr++;
        }
        token.tokenType = TokenType::indentifier;
        token.length = BufPtr - token.ptr;
        llvm::StringRef text(token.ptr, BufPtr - token.ptr);
        if (text == "int")
        {
            token.tokenType = TokenType::kw_int;
        }
        else if (text == "if")
        {
            token.tokenType = TokenType::kw_if;
        }
        else if (text == "else")
        {
            token.tokenType = TokenType::kw_else;
        }
        else if (text == "for")
        {
            token.tokenType = TokenType::kw_for;
        }
    }
    else
    {
        // operators and punctuation
        switch (*BufPtr)
        {
        case '+':
            emit(TokenType::plus, 1);
            break;
        case '-':
            emit(TokenType::minus, 1);
            break;
        case '*':
            emit(TokenType::star, 1);
            break;
        case '/':
            emit(TokenType::slash, 1);
            break;
        case '=':
            if (nextIs('='))
                emit(TokenType::equal_equal, 2);
            else
                emit(TokenType::equal, 1);
            break;
        case '!':
            if (nextIs('='))
            {
                emit(TokenType::not_equal, 2);
            }
            else
            {
                // A lone '!' is not a token of this language. Report it like any other
                // unknown character instead of letting it run into the '<' case.
                diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
            }
            break;
        case '<':
            if (nextIs('='))
                emit(TokenType::less_equal, 2);
            else
                emit(TokenType::less, 1);
            break;
        case '>':
            if (nextIs('='))
                emit(TokenType::greater_equal, 2);
            else
                emit(TokenType::greater, 1);
            break;
        case '(':
            emit(TokenType::l_parent, 1);
            break;
        case ')':
            emit(TokenType::r_parent, 1);
            break;
        case ';':
            emit(TokenType::semi, 1);
            break;
        case ',':
            emit(TokenType::comma, 1);
            break;
        case '{':
            emit(TokenType::l_brace, 1);
            break;
        case '}':
            emit(TokenType::r_brace, 1);
            break;
        default:
        {
            diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
        }
        }
    }
}

void Lexer::SaveState()
{
    state.LineHeadPtr = LineHeadPtr;
    state.BufPtr = BufPtr;
    state.BufEnd = BufEnd;
    state.row = row;
}

void Lexer::RestoreState()
{
    LineHeadPtr = state.LineHeadPtr;
    BufPtr = state.BufPtr;
    BufEnd = state.BufEnd;
    row = state.row;
}

2.3 语法分析器 (Parser)

根据文法中新增的 for 语句,需要在 AST 中新增一个 for 语句节点类 (ForStmt),在语法分析器中新增解析 for 语句的函数 ParseForStmt(),并在 ParseStmt() 函数中增加对 for 语句的判断。

实现代码

ast.h

cpp 复制代码
#pragma once

#include <vector>
#include <memory>
#include "llvm/IR/Value.h"
#include "type.h"
#include "lexer.h"

// Operator of a BinaryExpr node.
enum class OpCode
{
    add,          // +
    sub,          // -
    mul,          // *
    div,          // /
    equal_equal,  // ==
    not_equal,    // !=
    less,         // <
    less_equal,   // <=
    greater,      // >
    greater_equal // >=
};

// 进行声明
class Program;
class Expr;
class DeclStmt;
class BlockStmt;
class ForStmt;
class VariableDecl;
class IfStmt;
class VariableAccessExpr;
class BinaryExpr;
class AssignExpr;
class NumberExpr;

// Visitor interface over the AST: one pure-virtual hook per node class (plus Program).
// Each hook returns an llvm::Value*; what that value means is up to the concrete visitor.
class Visitor
{
public:
    virtual ~Visitor() = default;

    // program and statements
    virtual llvm::Value *VisitProgram(Program *p) = 0;
    virtual llvm::Value *VisitDeclStmt(DeclStmt *decl) = 0;
    virtual llvm::Value *VisitBlockStmt(BlockStmt *block) = 0;
    virtual llvm::Value *VisitIfStmt(IfStmt *ifStmt) = 0;
    virtual llvm::Value *VisitForStmt(ForStmt *forStmt) = 0;

    // declarations and expressions
    virtual llvm::Value *VisitVariableDecl(VariableDecl *decl) = 0;
    virtual llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) = 0;
    virtual llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) = 0;
    virtual llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) = 0;
    virtual llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) = 0;
};

// Common base class of every syntax-tree node.
class ASTNode
{
public:
    // Discriminator used by the classof() helpers of the derived node classes.
    enum Kind
    {
        ND_BlockStmt,
        ND_ForStmt,
        ND_IfStmt,
        ND_DeclStmt,
        ND_VariableDecl,
        ND_BinaryExpr,
        ND_NumberExpr,
        ND_VariableAccessExpr,
        ND_AssignExpr
    };

private:
    const Kind kind;

public:
    ASTNode(Kind kind) : kind(kind) {}
    virtual ~ASTNode() {}
    // Dispatches to the matching Visitor hook; derived classes override it.
    // The base version visits nothing and returns nullptr.
    virtual llvm::Value *Accept(Visitor *v) { return nullptr; }
    // Returned by value, so a const qualifier on the return type would have no effect.
    Kind getKind() const { return kind; }

public:
    CType *type{nullptr}; // type of the node; null until something assigns it
    Token token;          // token the node was built from
};

// 声明语句节点
class DeclStmt : public ASTNode
{
public:
    DeclStmt() : ASTNode(ND_DeclStmt) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitDeclStmt(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_DeclStmt;
    }

public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec;
};

// 块语句节点
class BlockStmt : public ASTNode
{
public:
    BlockStmt() : ASTNode(ND_BlockStmt) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitBlockStmt(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_BlockStmt;
    }

public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec;
};

// 条件判断节点
class IfStmt : public ASTNode
{
public:
    IfStmt() : ASTNode(ND_IfStmt) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitIfStmt(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_IfStmt;
    }

public:
    std::shared_ptr<ASTNode> condNode;
    std::shared_ptr<ASTNode> thenNode;
    std::shared_ptr<ASTNode> elseNode;
};

// 循环节点
class ForStmt : public ASTNode
{
public:
    ForStmt() : ASTNode(ND_ForStmt) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitForStmt(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_ForStmt;
    }

public:
    std::shared_ptr<ASTNode> initNode;
    std::shared_ptr<ASTNode> condNode;
    std::shared_ptr<ASTNode> incNode;
    std::shared_ptr<ASTNode> bodyNode;
};

// 变量声明节点
class VariableDecl : public ASTNode
{
public:
    VariableDecl() : ASTNode(ND_VariableDecl) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitVariableDecl(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_VariableDecl;
    }
};

// 二元表达式节点
class BinaryExpr : public ASTNode
{
public:
    BinaryExpr() : ASTNode(ND_BinaryExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitBinaryExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_BinaryExpr;
    }

public:
    OpCode op;
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;
};

// 赋值表达式节点
class AssignExpr : public ASTNode
{
public:
    AssignExpr() : ASTNode(ND_AssignExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitAssignExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_AssignExpr;
    }

public:
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;
};

// 数字表达式节点
class NumberExpr : public ASTNode
{
public:
    NumberExpr() : ASTNode(ND_NumberExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitNumberExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_NumberExpr;
    }
};

// 变量访问节点
class VariableAccessExpr : public ASTNode // 变量表达式
{
public:
    VariableAccessExpr() : ASTNode(ND_VariableAccessExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitVariableAccessExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_VariableAccessExpr;
    }
};

// Root of a parsed program.
class Program
{
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec; // top-level statements, in source order
};

parser.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "lexer.h"
#include "sema.h"
#include <vector>

// Recursive-descent parser: pulls tokens from a Lexer and builds AST nodes through Sema.
class Parser
{
public:
    // Reads the first token so that `token` is valid before any Parse* call.
    Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema) { Advance(); }

    // Parses the whole input up to end of file.
    std::shared_ptr<Program> ParseProgram();

private:
    // --- statements ---
    std::shared_ptr<ASTNode> ParseStmt();
    std::shared_ptr<ASTNode> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseBlockStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseIfStmt();
    std::shared_ptr<ASTNode> ParseForStmt();

    // --- expressions, from lowest to highest precedence ---
    std::shared_ptr<ASTNode> ParseExpr();
    std::shared_ptr<ASTNode> ParseAssignExpr();
    std::shared_ptr<ASTNode> ParseEqualExpr();
    std::shared_ptr<ASTNode> ParseRelationalExpr();
    std::shared_ptr<ASTNode> ParseAddExpr();
    std::shared_ptr<ASTNode> ParseMultiExpr();
    std::shared_ptr<ASTNode> ParsePrimary();

    // --- token helpers ---
    // Checks the current token's kind without consuming it; reports a diagnostic on mismatch.
    bool Expect(TokenType tokenType);
    // Checks the current token's kind and, if it matches, moves on to the next token.
    bool Consume(TokenType tokenType);
    // Unconditionally moves on to the next token.
    bool Advance();

    // True when the current token starts a type name.
    bool IsTypeName();

    DiagEngine &GetDiagEngine() const { return lexer.GetDiagEngine(); }

    Lexer &lexer;
    Sema &sema;
    Token token; // current (not yet consumed) token
    // Node stacks used by ParseForStmt.
    // NOTE(review): presumably the targets for break/continue statements — confirm.
    std::vector<std::shared_ptr<ASTNode>> breakNodes;
    std::vector<std::shared_ptr<ASTNode>> continueNodes;
};

parser.cc

cpp 复制代码
#include "parser.h"
#include <cassert>

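/*
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt | for-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
for-stmt : "for" "(" expr? ";" expr? ";" expr? ")" stmt
         | "for" "(" decl-stmt expr? ";" expr? ")" stmt
block-stmt: "{" stmt* "}"
expr-stmt : expr ";"
expr : assign-expr | equal-expr
assign-expr: identifier "=" expr
equal-expr : relational-expr (("==" | "!=") relational-expr)*
relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number: ([0-9])+
identifier : (a-zA-Z_)(a-zA-Z0-9_)*
*/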

// Parses the whole input.
// prog : stmt*
// Statements for which ParseStmt() returns null (a bare ";") are left out of the result.
std::shared_ptr<Program> Parser::ParseProgram()
{
    auto program = std::make_shared<Program>();
    for (; token.tokenType != TokenType::eof;)
    {
        if (auto stmt = ParseStmt())
        {
            program->nodeVec.push_back(stmt);
        }
    }
    return program;
}

// Parses one statement, choosing the rule from the current token.
// stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt | for-stmt
// Returns nullptr for a null statement (a bare ";"); callers must cope with that.
//
// Only the statement kinds that this version of the lexer and parser define are
// dispatched here: TokenType has no break/continue keywords and Parser declares no
// ParseBreakStmt()/ParseContinueStmt(), so those branches cannot be part of this function.
std::shared_ptr<ASTNode> Parser::ParseStmt()
{
    // null-stmt: consume the ';' and produce no node
    if (token.tokenType == TokenType::semi)
    {
        Consume(TokenType::semi);
        return nullptr;
    }
    // decl-stmt
    if (IsTypeName())
    {
        return ParseDeclStmt();
    }
    // block-stmt
    if (token.tokenType == TokenType::l_brace)
    {
        return ParseBlockStmt();
    }
    // if-stmt
    if (token.tokenType == TokenType::kw_if)
    {
        return ParseIfStmt();
    }
    // for-stmt
    if (token.tokenType == TokenType::kw_for)
    {
        return ParseForStmt();
    }
    // expr-stmt
    return ParseExprStmt();
}

// Parses a declaration statement such as "int a, b = 3;".
// Each declarator contributes a variable declaration node and, when it has an
// initializer, an assignment node right after it ("int a = 3;" becomes "int a; a = 3;").
// The terminating ';' is consumed.
std::shared_ptr<ASTNode> Parser::ParseDeclStmt()
{
    Consume(TokenType::kw_int);
    CType *baseTy = CType::getIntTy();

    auto declStmt = std::make_shared<DeclStmt>();

    bool firstDeclarator = true;
    while (token.tokenType != TokenType::semi)
    {
        if (!firstDeclarator)
        {
            // The call is kept outside assert(): with NDEBUG the assert expression is
            // not evaluated, and the comma would never be consumed.
            const bool sawComma = Consume(TokenType::comma);
            assert(sawComma && "declarators must be separated by ','");
            if (!sawComma)
            {
                break; // Consume() already reported the error; stop instead of spinning
            }
        }
        firstDeclarator = false;

        // declaration part: "a"
        auto variableDecl = sema.SemaVariableDeclNode(token, baseTy); // get a type
        declStmt->nodeVec.push_back(variableDecl);

        Token tmp = token;
        Consume(TokenType::indentifier);

        // optional initializer: "= 3"
        if (token.tokenType == TokenType::equal)
        {
            Token opToken = token;
            Advance();

            auto right = ParseExpr();
            auto left = sema.SemaVariableAccessNode(tmp);
            auto assign = sema.SemaAssignExprNode(left, right, opToken);

            declStmt->nodeVec.push_back(assign);
        }
    }

    Consume(TokenType::semi);

    return declStmt;
}

// Parses a block statement: "{" stmt* "}".
// The block gets its own scope in Sema. Null statements inside the block produce no node.
std::shared_ptr<ASTNode> Parser::ParseBlockStmt()
{
    sema.EnterScope(); // enter the block's scope

    auto blockStmt = std::make_shared<BlockStmt>();
    Consume(TokenType::l_brace);
    // Also stop at end of file: with a missing '}' the loop would otherwise never end.
    // The Consume() below then reports the missing brace.
    while (token.tokenType != TokenType::r_brace && token.tokenType != TokenType::eof)
    {
        auto stmt = ParseStmt();
        if (stmt)
            blockStmt->nodeVec.push_back(stmt);
    }
    Consume(TokenType::r_brace);

    sema.ExitScope(); // leave the block's scope

    return blockStmt;
}

// Parses an expression statement.
// expr-stmt : expr ";"
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    std::shared_ptr<ASTNode> exprNode = ParseExpr();
    Consume(TokenType::semi);
    return exprNode;
}

// if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
/*
if (a)
  b = 3;
else
    b = 4;
*/
std::shared_ptr<ASTNode> Parser::ParseIfStmt()
{
    Consume(TokenType::kw_if);
    Consume(TokenType::l_parent);
    auto condExpr = ParseExpr();
    Consume(TokenType::r_parent);
    auto thenStmt = ParseStmt();
    std::shared_ptr<ASTNode> elseStmt = nullptr;
    if (token.tokenType == TokenType::kw_else)
    {
        Consume(TokenType::kw_else);
        elseStmt = ParseStmt();
    }
    return sema.SemaIfStmtNode(condExpr, thenStmt, elseStmt);
}

/*
for-stmt :
"for" "(" expr? ";" expr? ";" expr? ")" stmt
"for" "(" decl-stmt expr? ";" expr? ")" stmt
*/
// 解析 for-stmt
std::shared_ptr<ASTNode> Parser::ParseForStmt()
{
    sema.EnterScope();

    Consume(TokenType::kw_for);
    Consume(TokenType::l_parent);

    auto forStmt = std::make_shared<ForStmt>();

    std::shared_ptr<ASTNode> initNode = nullptr;
    std::shared_ptr<ASTNode> condNode = nullptr;
    std::shared_ptr<ASTNode> incNode  = nullptr;
    std::shared_ptr<ASTNode> bodyNode = nullptr;

    if (IsTypeName())
    {
        initNode = ParseDeclStmt();
    }
    else
    {
        initNode = ParseExpr();
        Consume(TokenType::semi);
    }

    condNode = ParseExpr();
    Consume(TokenType::semi);

    incNode = ParseExpr();
    Consume(TokenType::r_parent);

    bodyNode = ParseStmt();

    forStmt->initNode = initNode;
    forStmt->condNode = condNode;
    forStmt->incNode = incNode;
    forStmt->bodyNode = bodyNode;

    breakNodes.pop_back();
    continueNodes.pop_back();

    sema.ExitScope();

    return forStmt;
}

// 解析表达式
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
    lexer.SaveState();
    bool isAssign = false;
    Token tmp = token;
    // a = b;
    if (tmp.tokenType == TokenType::indentifier)
    {
        lexer.NextToken(tmp);
        if (tmp.tokenType == TokenType::equal)
        {
            isAssign = true;
        }
    }
    lexer.RestoreState();
    if (isAssign)
    {
        return ParseAssignExpr();
    }
    else // equal-expr
        return ParseEqualExpr();
}

// 解析赋值表达式
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
    // a = b;
    Token tmp = token;
    Consume(TokenType::indentifier);
    auto expr = sema.SemaVariableAccessNode(tmp);
    Token opToken = token;
    Consume(TokenType::equal);
    return sema.SemaAssignExprNode(expr, ParseExpr(), opToken);
}

// Parses an equality expression (left-associative).
// equal-expr : relational-expr (("==" | "!=") relational-expr)*
std::shared_ptr<ASTNode> Parser::ParseEqualExpr()
{
    std::shared_ptr<ASTNode> result = ParseRelationalExpr();
    for (;;)
    {
        const bool isEqual = (token.tokenType == TokenType::equal_equal);
        const bool isNotEqual = (token.tokenType == TokenType::not_equal);
        if (!isEqual && !isNotEqual)
        {
            return result;
        }

        OpCode op = isEqual ? OpCode::equal_equal : OpCode::not_equal;
        Advance();
        auto rhs = ParseRelationalExpr();
        result = sema.SemaBinaryExprNode(result, rhs, op);
    }
}

// Parses a relational expression (left-associative).
// relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
std::shared_ptr<ASTNode> Parser::ParseRelationalExpr()
{
    std::shared_ptr<ASTNode> result = ParseAddExpr();
    for (;;)
    {
        OpCode op;
        switch (token.tokenType)
        {
        case TokenType::less:
            op = OpCode::less;
            break;
        case TokenType::less_equal:
            op = OpCode::less_equal;
            break;
        case TokenType::greater:
            op = OpCode::greater;
            break;
        case TokenType::greater_equal:
            op = OpCode::greater_equal;
            break;
        default:
            // not a relational operator: the expression ends here
            return result;
        }

        Advance();
        auto rhs = ParseAddExpr();
        result = sema.SemaBinaryExprNode(result, rhs, op);
    }
}

// Parses an additive expression (left-associative).
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseAddExpr()
{
    std::shared_ptr<ASTNode> result = ParseMultiExpr();
    for (;;)
    {
        const bool isPlus = (token.tokenType == TokenType::plus);
        const bool isMinus = (token.tokenType == TokenType::minus);
        if (!isPlus && !isMinus)
        {
            return result;
        }

        OpCode op = isPlus ? OpCode::add : OpCode::sub;
        Advance();
        auto rhs = ParseMultiExpr();
        result = sema.SemaBinaryExprNode(result, rhs, op);
    }
}

// Parses a multiplicative expression (left-associative).
// mult-expr : primary-expr (("*" | "/") primary-expr)*
std::shared_ptr<ASTNode> Parser::ParseMultiExpr()
{
    std::shared_ptr<ASTNode> result = ParsePrimary();
    for (;;)
    {
        const bool isStar = (token.tokenType == TokenType::star);
        const bool isSlash = (token.tokenType == TokenType::slash);
        if (!isStar && !isSlash)
        {
            return result;
        }

        OpCode op = isStar ? OpCode::mul : OpCode::div;
        Advance();
        auto rhs = ParsePrimary();
        result = sema.SemaBinaryExprNode(result, rhs, op);
    }
}

// Parses a primary expression.
// primary-expr : identifier | number | "(" expr ")"
// Anything that is neither "(" nor an identifier is treated as a number; Expect()
// reports a diagnostic when it is not one.
std::shared_ptr<ASTNode> Parser::ParsePrimary()
{
    switch (token.tokenType)
    {
    case TokenType::l_parent:
    {
        // parenthesized expression
        Consume(TokenType::l_parent);
        auto inner = ParseExpr();
        Consume(TokenType::r_parent);
        return inner;
    }
    case TokenType::indentifier:
    {
        // variable access: build the node from the name token, then step past it
        auto access = sema.SemaVariableAccessNode(token);
        Consume(TokenType::indentifier);
        return access;
    }
    default:
    {
        // number literal
        Expect(TokenType::number);
        auto literal = sema.SemaNumberExprNode(token, token.type);
        Consume(TokenType::number);
        return literal;
    }
    }
}

/// Token helpers.
// Checks that the current token has kind `tokenType` without consuming it.
// On mismatch, reports diag::err_expected with the expected spelling and the text of
// the token actually found, and returns false.
bool Parser::Expect(TokenType tokenType)
{
    const bool matches = (token.tokenType == tokenType);
    if (!matches)
    {
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                               diag::err_expected,
                               Token::GetSpellingText(tokenType),
                               llvm::StringRef(token.ptr, token.length));
    }
    return matches;
}

// Like Expect(), but moves on to the next token when the kind matches.
// Returns false (after Expect() has reported the mismatch) without advancing otherwise.
bool Parser::Consume(TokenType tokenType)
{
    if (!Expect(tokenType))
    {
        return false;
    }
    Advance();
    return true;
}

// Replaces the current token with the next one from the lexer. Always returns true.
bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}

// True when the current token starts a type name; "int" is the only one checked here.
bool Parser::IsTypeName()
{
    return token.tokenType == TokenType::kw_int;
}

printVisitor.h

因为基类为了访问 for 语句,新增了对 for 语句的抽象访问函数,所以子类需要对这个函数进行复写。

cpp 复制代码
#pragma once

#include "ast.h"
#include "parser.h"

// Visitor that prints the AST back as source-like text on llvm::outs().
class PrintVisitor : public Visitor
{
public:
    // Prints `program` straight away: the constructor runs VisitProgram() on it.
    PrintVisitor(std::shared_ptr<Program> program);

    // Program and statements.
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitBlockStmt(BlockStmt *block) override;
    llvm::Value *VisitIfStmt(IfStmt *ifStmt) override;
    llvm::Value *VisitForStmt(ForStmt *forStmt) override;

    // Declarations and expressions.
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
};

printVisitor.cc

cpp 复制代码
#include "printVisitor.h"

// Prints the given program as soon as the visitor is constructed.
PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
    Program *root = program.get();
    VisitProgram(root);
}

// Prints every top-level statement, each followed by a newline. Always returns nullptr.
llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
    llvm::raw_ostream &os = llvm::outs();
    for (const auto &stmt : p->nodeVec)
    {
        stmt->Accept(this);
        os << "\n";
    }
    return nullptr;
}

// Prints the nodes of a declaration statement one after another, with no separator.
// Always returns nullptr.
llvm::Value *PrintVisitor::VisitDeclStmt(DeclStmt *decl)
{
    // Iterate by const reference, as VisitBlockStmt does: iterating by value would
    // copy a shared_ptr (a reference-count update) for every element.
    for (const auto &node : decl->nodeVec)
    {
        node->Accept(this);
    }
    return nullptr;
}

// Prints a block: "{", then each statement on its own line, then "}".
// Always returns nullptr.
llvm::Value *PrintVisitor::VisitBlockStmt(BlockStmt *block)
{
    llvm::raw_ostream &os = llvm::outs();
    os << "{\n";
    for (const auto &stmt : block->nodeVec)
    {
        stmt->Accept(this);
        os << "\n";
    }
    os << "}\n";
    return nullptr;
}

// Prints an if statement: "if (<cond>)", the then-branch, and the else-branch when present.
// Always returns nullptr.
llvm::Value *PrintVisitor::VisitIfStmt(IfStmt *ifStmt)
{
    llvm::outs() << "if (";
    ifStmt->condNode->Accept(this);
    llvm::outs() << ")\n";
    // The parser returns no node for a bare ";" statement, so "if (a) ;" can leave the
    // then-branch null. Guard it the same way VisitForStmt guards its children.
    if (ifStmt->thenNode)
    {
        ifStmt->thenNode->Accept(this);
    }
    llvm::outs() << "\n";
    if (ifStmt->elseNode)
    {
        llvm::outs() << "else\n";
        ifStmt->elseNode->Accept(this);
    }
    return nullptr;
}

// Prints a for statement as "for (<init>; <cond>; <inc>)<body>".
// Any child that is null is simply left out. Always returns nullptr.
llvm::Value *PrintVisitor::VisitForStmt(ForStmt *forStmt)
{
    // Visits `node` only when it is present.
    const auto visitIfPresent = [this](const std::shared_ptr<ASTNode> &node)
    {
        if (node)
        {
            node->Accept(this);
        }
    };

    llvm::outs() << "for (";
    visitIfPresent(forStmt->initNode);
    llvm::outs() << "; ";
    visitIfPresent(forStmt->condNode);
    llvm::outs() << "; ";
    visitIfPresent(forStmt->incNode);
    llvm::outs() << ")";
    visitIfPresent(forStmt->bodyNode);
    return nullptr;
}

// Prints "int <name>;" for a declaration whose type is the int type; prints nothing
// for any other type. Always returns nullptr.
llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
    if (decl->type != CType::getIntTy())
    {
        return nullptr;
    }
    const llvm::StringRef name(decl->token.ptr, decl->token.length);
    llvm::outs() << "int " << name << ";";
    return nullptr;
}

// Prints the variable's name exactly as it appears in the source. Always returns nullptr.
llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const Token &nameToken = varaccExpr->token;
    const llvm::StringRef name(nameToken.ptr, nameToken.length);
    llvm::outs() << name;
    return nullptr;
}

// Prints an assignment as "<left> = <right>". Always returns nullptr.
llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
    ASTNode *target = assignExpr->left.get();
    ASTNode *value = assignExpr->right.get();

    target->Accept(this);
    llvm::outs() << " = ";
    value->Accept(this);
    return nullptr;
}

// Prints a binary expression in infix form: left operand, operator surrounded by
// spaces, right operand (an in-order walk). No parentheses are printed.
// Always returns nullptr.
llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    binaryExpr->left->Accept(this);

    // Operator text; stays null for an opcode without a spelling, in which case
    // nothing is printed between the operands.
    const char *opText = nullptr;
    switch (binaryExpr->op)
    {
    case OpCode::add:
        opText = " + ";
        break;
    case OpCode::sub:
        opText = " - ";
        break;
    case OpCode::mul:
        opText = " * ";
        break;
    case OpCode::div:
        opText = " / ";
        break;
    case OpCode::equal_equal:
        opText = " == ";
        break;
    case OpCode::not_equal:
        opText = " != ";
        break;
    case OpCode::less:
        opText = " < ";
        break;
    case OpCode::less_equal:
        opText = " <= ";
        break;
    case OpCode::greater:
        opText = " > ";
        break;
    case OpCode::greater_equal:
        opText = " >= ";
        break;
    default:
        break;
    }
    if (opText)
    {
        llvm::outs() << opText;
    }

    binaryExpr->right->Accept(this);

    return nullptr;
}

// Prints the number literal exactly as it appears in the source. Always returns nullptr.
llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
    const Token &literal = factorExpr->token;
    const llvm::StringRef digits(literal.ptr, literal.length);
    llvm::outs() << digits;
    return nullptr;
}

2.4 代码生成 (CodeGen)

CodeGen 同样是 Visitor 的子类,因此也需要复写新增的 VisitForStmt() 函数,为 for 语句生成对应的 LLVM IR。

实现框架

代码实现

codegen.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/DenseMap.h"

// Code generator implemented as an AST visitor.
class CodeGen : public Visitor
{
public:
    // Creates the module named "expr" and immediately visits the whole program.
    CodeGen(std::shared_ptr<Program> program)
        : module(std::make_shared<llvm::Module>("expr", context))
    {
        VisitProgram(program.get());
    }

    // Program and statements.
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitBlockStmt(BlockStmt *block) override;
    llvm::Value *VisitIfStmt(IfStmt *ifStmt) override;
    llvm::Value *VisitForStmt(ForStmt *forStmt) override;

    // Declarations and expressions.
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;

private:
    // Declaration order matters: `context` must exist before the builder and the
    // module that are constructed from it.
    llvm::LLVMContext context;
    llvm::IRBuilder<> irBuilder{context};
    std::shared_ptr<llvm::Module> module;
    llvm::Function *curFunc{nullptr};
    // Keyed by string; each entry pairs a Value* with a Type*.
    // NOTE(review): presumably variable name -> (address, type) — confirm in codegen.cc.
    llvm::StringMap<std::pair<llvm::Value *, llvm::Type *>> varAddrTyMap;
};

codegen.cc

cpp 复制代码
#include "codegen.h"
#include "llvm/IR/Verifier.h"

/*
; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 0, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 5, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %cond

cond:                                             ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %0 = icmp ne i32 %a3, 0
  br i1 %0, label %then, label %else

then:                                             ; preds = %cond
  store i32 3, ptr %b, align 4
  %b4 = load i32, ptr %b, align 4
  br label %last

else:                                             ; preds = %cond
  store i32 4, ptr %a, align 4
  %a5 = load i32, ptr %a, align 4
  br label %last

last:                                             ; preds = %else, %then
  %a6 = load i32, ptr %a, align 4
  %b7 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a6, %b7
  %sub = sub nsw i32 %mul, 4
  %1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

*/

// Emits the whole program as the body of `i32 @main()`.
//
// Declares the variadic `printf`, creates `main` with an `entry` block,
// generates every top-level node in order, then prints the value produced by
// the last node (or a notice when that node produced no value) and returns 0.
// The finished module is printed to stdout. Returns the `ret` instruction.
llvm::Value *CodeGen::VisitProgram(Program *p)
{
    // Declare: i32 printf(i8*, ...)
    auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
    auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
    // Define: i32 main()
    auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
    auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
    // Entry block of main; all following instructions are inserted here
    // until a statement visitor moves the insertion point.
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(context, "entry", mainFunction);
    irBuilder.SetInsertPoint(entryBB);
    // Remember the function so statement visitors can append blocks to it.
    curFunc = mainFunction;

    llvm::Value *lastVal = nullptr;
    // Bind by reference: avoids copying each node handle per iteration.
    for (const auto &node : p->nodeVec)
    {
        lastVal = node->Accept(this);
    }
    if (lastVal)
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastVal});
    else
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("last instruction is not expr!\n")});

    // main returns 0.
    llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));

    // Pass a diagnostic stream so that malformed IR is reported on stderr
    // instead of being silently ignored (verifyFunction returns true on error).
    llvm::verifyFunction(*mainFunction, &llvm::errs());

    module->print(llvm::outs(), nullptr);
    return ret;
}

// Generates every node of a declaration statement in order and returns the
// value produced by the final one (nullptr when the statement has no nodes).
llvm::Value *CodeGen::VisitDeclStmt(DeclStmt *declStmt)
{
    llvm::Value *result = nullptr;
    for (const auto &child : declStmt->nodeVec)
        result = child->Accept(this);
    return result;
}

// Generates each statement of a `{ ... }` block in order and returns the
// value produced by the final one (nullptr for an empty block).
llvm::Value *CodeGen::VisitBlockStmt(BlockStmt *block)
{
    llvm::Value *result = nullptr;
    for (const auto &stmt : block->nodeVec)
        result = stmt->Accept(this);
    return result;
}

// Emits `if (cond) then [else els]` as the blocks if_cond, if_then,
// (if_else,) if_last, and leaves the builder positioned in if_last.
// An if statement yields no value, so nullptr is returned.
llvm::Value *CodeGen::VisitIfStmt(IfStmt *ifStmt)
{
    // Blocks are created in the order they should appear in the function.
    llvm::BasicBlock *condBB = llvm::BasicBlock::Create(context, "if_cond", curFunc);
    llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "if_then", curFunc);
    llvm::BasicBlock *elseBB = ifStmt->elseNode
                                   ? llvm::BasicBlock::Create(context, "if_else", curFunc)
                                   : nullptr;
    llvm::BasicBlock *lastBB = llvm::BasicBlock::Create(context, "if_last", curFunc);

    // The current block needs an explicit jump into if_cond; the builder
    // does not insert one on its own.
    irBuilder.CreateBr(condBB);
    irBuilder.SetInsertPoint(condBB);

    // The condition is an i32; it is "true" when it differs from zero.
    llvm::Value *condResult = ifStmt->condNode->Accept(this);
    llvm::Value *isTrue = irBuilder.CreateICmpNE(condResult, irBuilder.getInt32(0));

    // Without an else branch, a false condition skips straight to if_last.
    irBuilder.CreateCondBr(isTrue, thenBB, elseBB ? elseBB : lastBB);

    // then branch
    irBuilder.SetInsertPoint(thenBB);
    ifStmt->thenNode->Accept(this);
    irBuilder.CreateBr(lastBB);

    // else branch (only when present)
    if (elseBB)
    {
        irBuilder.SetInsertPoint(elseBB);
        ifStmt->elseNode->Accept(this);
        irBuilder.CreateBr(lastBB);
    }

    irBuilder.SetInsertPoint(lastBB);
    return nullptr;
}

// Emits `for (init; cond; inc) body` as five basic blocks:
//
//   for_init -> for_cond -(true)-> for_body -> for_inc -> for_cond
//                        -(false)-> for_last
//
// Every part of the statement is optional; a missing condition makes the
// loop branch unconditionally into the body. On return the builder is
// positioned in for_last. A for statement yields no value (returns nullptr).
llvm::Value *CodeGen::VisitForStmt(ForStmt *forStmt)
{
    // Creation order fixes the block order in the printed IR
    // (for_inc is listed before for_body).
    llvm::BasicBlock *initBB = llvm::BasicBlock::Create(context, "for_init", curFunc);
    llvm::BasicBlock *condBB = llvm::BasicBlock::Create(context, "for_cond", curFunc);
    llvm::BasicBlock *incBB = llvm::BasicBlock::Create(context, "for_inc", curFunc);
    llvm::BasicBlock *bodyBB = llvm::BasicBlock::Create(context, "for_body", curFunc);
    llvm::BasicBlock *lastBB = llvm::BasicBlock::Create(context, "for_last", curFunc);

    // NOTE: the break/continue target maps (breakBBs / continueBBs) are not
    // members of CodeGen at this stage and nothing here reads them, so no
    // targets are registered.

    // init: the current block must jump into for_init explicitly.
    irBuilder.CreateBr(initBB);
    irBuilder.SetInsertPoint(initBB);
    if (forStmt->initNode)
    {
        forStmt->initNode->Accept(this);
    }

    // cond: an i32 that counts as true when it differs from zero.
    irBuilder.CreateBr(condBB);
    irBuilder.SetInsertPoint(condBB);
    if (forStmt->condNode)
    {
        llvm::Value *val = forStmt->condNode->Accept(this);
        llvm::Value *condVal = irBuilder.CreateICmpNE(val, irBuilder.getInt32(0));
        irBuilder.CreateCondBr(condVal, bodyBB, lastBB);
    }
    else
    {
        irBuilder.CreateBr(bodyBB);
    }

    // body
    irBuilder.SetInsertPoint(bodyBB);
    if (forStmt->bodyNode)
    {
        forStmt->bodyNode->Accept(this);
    }

    // inc
    irBuilder.CreateBr(incBB);
    irBuilder.SetInsertPoint(incBB);
    if (forStmt->incNode)
    {
        forStmt->incNode->Accept(this);
    }

    // Back edge to the condition; code after the loop continues in for_last.
    irBuilder.CreateBr(condBB);
    irBuilder.SetInsertPoint(lastBB);

    return nullptr;
}

// Allocates a stack slot for the declared variable, records the slot and its
// type under the variable's name, and returns the slot's address.
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
    const llvm::StringRef name(decl->token.ptr, decl->token.length);

    // `int` maps to i32; for any other type the IR type stays null.
    llvm::Type *slotTy = nullptr;
    if (decl->type == CType::getIntTy())
        slotTy = irBuilder.getInt32Ty();

    llvm::Value *slot = irBuilder.CreateAlloca(slotTy, nullptr, name);
    varAddrTyMap.insert({name, {slot, slotTy}});
    return slot;
}

// Loads the current value of a variable from its recorded stack slot and
// returns the loaded value (an rvalue).
llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const llvm::StringRef name(varaccExpr->token.ptr, varaccExpr->token.length);
    const auto &slot = varAddrTyMap[name]; // (address, stored type)

    // An `int` expression is always loaded as i32; otherwise the type
    // recorded at declaration time is used.
    llvm::Type *loadTy = slot.second;
    if (varaccExpr->type == CType::getIntTy())
        loadTy = irBuilder.getInt32Ty();

    return irBuilder.CreateLoad(loadTy, slot.first, name);
}

// Emits `name = expr`, e.g. `a = 3;`. The right-hand side is stored into the
// variable's slot, then the slot is re-loaded so the assignment itself
// evaluates to an rvalue.
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
    // The left operand is treated as a plain variable access.
    auto *target = static_cast<VariableAccessExpr *>(assignExpr->left.get());
    const llvm::StringRef name(target->token.ptr, target->token.length);

    // Copy the (address, type) entry before generating the right-hand side.
    const auto [slotAddr, slotTy] = varAddrTyMap[name];

    llvm::Value *rhs = assignExpr->right->Accept(this);
    irBuilder.CreateStore(rhs, slotAddr); // write through the lvalue
    return irBuilder.CreateLoad(slotTy, slotAddr, name); // value of the expression
}

// Generates both operands (left first), then the instruction for the operator.
//
// Arithmetic operators return the i32 result directly. Comparison operators
// produce an i1, which is widened to i32 so that it can be used like any other
// integer expression: 1 when the comparison holds, 0 otherwise.
// Returns nullptr for an opcode that is not handled here.
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    llvm::Value *left = binaryExpr->left->Accept(this);
    llvm::Value *right = binaryExpr->right->Accept(this);

    llvm::Value *cmp = nullptr; // i1 result, set only by comparison operators
    switch (binaryExpr->op)
    {
    // nsw = "no signed wrap": signed overflow yields a poison value rather
    // than a wrapped result.
    case OpCode::add:
        return irBuilder.CreateNSWAdd(left, right, "add");
    case OpCode::sub:
        return irBuilder.CreateNSWSub(left, right, "sub");
    case OpCode::mul:
        return irBuilder.CreateNSWMul(left, right, "mul");
    case OpCode::div:
        return irBuilder.CreateSDiv(left, right, "div");
    case OpCode::equal_equal:
        cmp = irBuilder.CreateICmpEQ(left, right, "equal_equal");
        break;
    case OpCode::not_equal:
        cmp = irBuilder.CreateICmpNE(left, right, "not_equal");
        break;
    case OpCode::less:
        cmp = irBuilder.CreateICmpSLT(left, right, "less");
        break;
    case OpCode::less_equal:
        cmp = irBuilder.CreateICmpSLE(left, right, "less_equal");
        break;
    case OpCode::greater:
        cmp = irBuilder.CreateICmpSGT(left, right, "greater");
        break;
    case OpCode::greater_equal:
        cmp = irBuilder.CreateICmpSGE(left, right, "greater_equal");
        break;
    default:
        return nullptr;
    }

    // Zero-extend (isSigned = false): sign-extending an i1 `true` would give
    // -1, so an expression such as `(a < b) + 1` would evaluate to 0.
    return irBuilder.CreateIntCast(cmp, irBuilder.getInt32Ty(), /*isSigned=*/false);
}

// Turns a numeric literal into an i32 constant holding the token's value.
llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
    const auto literal = factorExpr->token.value;
    return irBuilder.getInt32(literal);
}

2.5 测试 for 语句功能

生成 IR 文件

bash 复制代码
bin/for test/expr.txt > test/expr.ll

生成 IR 的内容

ini 复制代码
; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 1, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 4, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %for_init

for_init:                                         ; preds = %entry
  %i = alloca i32, align 4
  store i32 0, ptr %i, align 4
  %i3 = load i32, ptr %i, align 4
  br label %for_cond

for_cond:                                         ; preds = %for_inc, %for_init
  %i4 = load i32, ptr %i, align 4
  %less = icmp slt i32 %i4, 4
  %0 = sext i1 %less to i32
  %1 = icmp ne i32 %0, 0
  br i1 %1, label %for_body, label %for_last

for_inc:                                          ; preds = %for_body
  %i14 = load i32, ptr %i, align 4
  %add15 = add nsw i32 %i14, 1
  store i32 %add15, ptr %i, align 4
  %i16 = load i32, ptr %i, align 4
  br label %for_cond

for_body:                                         ; preds = %for_cond
  %a5 = load i32, ptr %a, align 4
  %i6 = load i32, ptr %i, align 4
  %add = add nsw i32 %a5, %i6
  store i32 %add, ptr %a, align 4
  %a7 = load i32, ptr %a, align 4
  %b8 = load i32, ptr %b, align 4
  %a9 = load i32, ptr %a, align 4
  %mul = mul nsw i32 %b8, %a9
  %i10 = load i32, ptr %i, align 4
  %mul11 = mul nsw i32 2, %i10
  %add12 = add nsw i32 %mul, %mul11
  store i32 %add12, ptr %b, align 4
  %b13 = load i32, ptr %b, align 4
  br label %for_inc

for_last:                                         ; preds = %for_cond
  %a17 = load i32, ptr %a, align 4
  %b18 = load i32, ptr %b, align 4
  %add19 = add nsw i32 %b18, 3
  %mul20 = mul nsw i32 %a17, %add19
  %sub = sub nsw i32 %mul20, 2
  %2 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

运行 IR

bash 复制代码
lli test/expr.ll

运行结果

bash 复制代码
expr value: 2217

结果正确,验证了编译器目前基础的 for 语句功能的正确性。

相关推荐
aidingni8883 小时前
保护你的前端:打造数字堡垒的艺术
前端·javascript
快乐是一切3 小时前
PDF底层结构之字体与编码系统
前端
llq_3503 小时前
解决端口被占用问题的 Webpack 启动脚本
前端
沢田纲吉3 小时前
《LLVM IR 学习手记(六):break 语句与 continue 语句的实现与解析》
前端·c++·llvm
火锅小王子3 小时前
目标筑基:从0到1学习GoLang (入门 Go语言+GoFrame开发服务端+ langchain接入)
前端·后端·openai
温宇飞3 小时前
CSS 属性分类
前端
鹏多多3 小时前
使用React-OAuth进行Google/GitHub登录的教程和案例
前端·javascript·react.js
爱和冰阔落3 小时前
【C++进阶】继承上 概念及其定义 赋值兼容转换 子类默认成员函数的详解分析
c++
晓得迷路了3 小时前
栗子前端技术周刊第 101 期 - React 19.2、Next.js 16 Beta、pnpm 10.18...
前端·javascript·react.js