《LLVM IR 学习手记(三):赋值表达式与错误处理的实现与解析》

在上一篇博客 《LLVM IR 学习手记(二):变量表达式编译器的实现与深入解析》中,我们实现了一个支持变量表达式的基础编译器,但其中存在一个 BUG------无法对变量重新赋值,也不支持连续赋值操作。本篇博客将首先修复这个问题。

1、赋值表达式

(1)、测试文件

expr.txt 内容如下:

ini 复制代码
int a = 4;
a = 5;
int b = a = 3;
a * b - 4 / 2;

(2)、词法分析器(Lexer)

为实现赋值表达式,我们需要修改词法分析器(Lexer)语法分析器(Parser)代码生成(CodeGen)

文法定义

ini 复制代码
prog : stmt*     
stmt : decl-stmt | expr-stmt | null-stmt      
null-stmt : ";"      
decl-stmt : "int" identifier ("," identifier (= expr)?)* ";"
expr-stmt : expr ";"
expr : assign-expr | add-expr
assign-expr: identifier "=" expr
add-expr : mult-expr (("+" | "-") mult-expr)* 
mult-expr : primary-expr (("*" | "/") primary-expr)* 
primary-expr : identifier | number | "(" expr ")" 
number: ([0-9])+ 
identifier : (a-zA-Z_)(a-zA-Z0-9_)*

实现代码

lexer.h

cpp 复制代码
#pragma once

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"

// char stream -> token

enum class TokenType
{
    unknown,
    number,      // [0-9]+
    indentifier, // 变量
    kw_int,      // int
    plus,        // +
    minus,       // -
    star,        // *
    slash,       // /
    equal,       // =
    l_parent,    // (
    r_parent,    // )
    semi,        // ;
    comma,       // ,
    eof          // end of file
};

class Token
{
public:
    TokenType tokenType; // token 的种类
    int row, col;

    int value; // for number

    llvm::StringRef content; // for debug

    CType *type; // for built-in type

public:
    void Dump();
};

class Lexer
{
public:
    Lexer(llvm::StringRef buf)
    {
        BufPtr = buf.begin();
        LineHeadPtr = buf.begin();
        BufEnd = buf.end();
        row = 1;
    }
    void NextToken(Token &token);

    void SaveState();    // 用于保存当前状态
    void RestoreState(); // 恢复存储的状态

private:
    struct State // 赋值表达式的关键要素
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

private:
    const char *BufPtr;
    const char *LineHeadPtr;
    const char *BufEnd;
    int row;

    State state; // 可以记录此时 lexer 的状态
};

lexer.cc

cpp 复制代码
#include "lexer.h"

void Token::Dump()
{
    llvm::outs() << "{" << content << ", row = " << row << ", col = " << col << "}\n";
}

// number,     // [0-9]+
// indentifier,// 变量
// kw_int,     // int
// plus,       // +
// minus,      // -
// star,       // *
// slash,      // /
// equal,      // =
// l_parent,   // (
// r_parent,   // )
// semi,       // ;
// comma,      // ,
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    switch (tokenType)
    {
    case TokenType::kw_int:
        return "int";
    case TokenType::plus:
        return "+";
    case TokenType::minus:
        return "-";
    case TokenType::star:
        return "*";
    case TokenType::slash:
        return "/";
    case TokenType::equal:
        return "=";
    case TokenType::l_parent:
        return "(";
    case TokenType::r_parent:
        return ")";
    case TokenType::semi:
        return ";";
    case TokenType::comma:
        return ",";
    case TokenType::number:
        return "number";
    case TokenType::indentifier:
        return "indentifier";
    default:
        llvm::llvm_unreachable_internal();
    }
}

bool IsWhiteSpace(char ch)
{
    return ch == ' ' || ch == '\r' || ch == '\n';
}

bool IsDigit(char ch)
{
    return ch >= '0' && ch <= '9';
}

bool IsLetter(char ch)
{
    // a-z, A-Z, _
    return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch == '_');
}

void Lexer::NextToken(Token &token)
{
    // 过滤空格
    while (IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;

    // 判断是否到结尾了
    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    token.ptr = BufPtr;
    token.length = 0;
    // 判断是否为数字
    if (IsDigit(*BufPtr))
    {
        int len = 0;
        int val = 0;
        while (IsDigit(*BufPtr))
        {
            val = val * 10 + *BufPtr++ - '0';
            token.length++;
        }
        token.value = val;
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
    }
    else if (IsLetter(*BufPtr)) // 为变量
    {
        while (IsLetter(*BufPtr) || IsDigit(*BufPtr))
        {
            BufPtr++;
        }
        token.tokenType = TokenType::indentifier;
        token.length = BufPtr - token.ptr;
        llvm::StringRef text(token.ptr, BufPtr - token.ptr);
        if (text == "int")
        {
            token.tokenType = TokenType::kw_int;
        }
    }
    else // 为特殊字符
    {
        switch (*BufPtr)
        {
        case '+':
        {
            token.tokenType = TokenType::plus;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '-':
        {
            token.tokenType = TokenType::minus;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '*':
        {
            token.tokenType = TokenType::star;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '/':
        {
            token.tokenType = TokenType::slash;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '=':
        {
            token.tokenType = TokenType::equal;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '(':
        {
            token.tokenType = TokenType::l_parent;
            token.length = 1;
            BufPtr++;
            break;
        }
        case ')':
        {
            token.tokenType = TokenType::r_parent;
            token.length = 1;
            BufPtr++;
            break;
        }
        case ';':
        {
            token.tokenType = TokenType::semi;
            token.length = 1;
            BufPtr++;
            break;
        }
        case ',':
        {
            token.tokenType = TokenType::comma;
            token.length = 1;
            BufPtr++;
            break;
        }
        default:
        {
            llvm::llvm_unreachable_internal();
            BufPtr++;
            break;
        }
        }
    }
}

void Lexer::SaveState()
{
    state.LineHeadPtr = LineHeadPtr;
    state.BufPtr = BufPtr;
    state.BufEnd = BufEnd;
    state.row = row;
}

void Lexer::RestoreState()
{
    LineHeadPtr = state.LineHeadPtr;
    BufPtr = state.BufPtr;
    BufEnd = state.BufEnd;
    row = state.row;
}

(3)、语法分析器(Parser)

判断一个表达式是否为赋值表达式,需要向后查看一个 Token。如果该 Token 是变量(identifier),且下一个 Token 是等号(equal),则可以确定这是一个赋值表达式。赋值表达式通常采用右结合方式处理。

根据上述文法,我们需要在解析表达式语句时增加一个新的分支。

实现代码

parser.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "lexer.h"
#include "sema.h"

class Parser
{
public:
    Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
    {
        Advance();
    }
    std::shared_ptr<Program> ParseProgram();

private:
    std::vector<std::shared_ptr<ASTNode>> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseAssignExpr(); // 需要新增一个用于解析赋值表达式的函数
    std::shared_ptr<ASTNode> ParseExpr();
    std::shared_ptr<ASTNode> ParseTerm();
    std::shared_ptr<ASTNode> ParseFactor();

    // 消耗 token 的函数
    // 检测 token 的类型
    bool Expect(TokenType tokenType);
    // 检测 token 的类型并消费
    bool Consume(TokenType tokenType);
    // 直接消耗当前的 token
    bool Advance();

private:
    Lexer &lexer;
    Sema &sema;
    Token token;
};

parser.cc

cpp 复制代码
#include "parser.h"
#include <cassert>

// prog : stmt*
// stmt : decl-stmt | expr-stmt | null-stmt
// null-stmt : ";"
// decl-stmt : "int" identifier ("," identifier (= expr)?)* ";"
// expr-stmt : expr ";"
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
// mult-expr : primary-expr (("*" | "/") primary-expr)*
// primary-expr : identifier | number | "(" expr ")"
// number: ([0-9])+
// identifier : (a-zA-Z_)(a-zA-Z0-9_)*

// 解析目标程序
// stmt : decl-stmt | expr-stmt | null-stmt
std::shared_ptr<Program> Parser::ParseProgram()
{
    std::vector<std::shared_ptr<ASTNode>> ExprVec;
    while (token.tokenType != TokenType::eof)
    {
        // 遇到 ; 需要进行消费 token
        // null-stmt
        if (token.tokenType == TokenType::semi)
        {
            Advance();
            continue;
        }
        // decl-stmt
        if (token.tokenType == TokenType::kw_int)
        {
            auto exprs = ParseDeclStmt();
            for (auto expr : exprs)
            {
                ExprVec.push_back(expr);
            }
        }
        else // expr-stmt
        {
            auto expr = ParseExprStmt();
            ExprVec.push_back(expr);
        }
    }
    auto program = std::make_shared<Program>();
    program->ExprVec = std::move(ExprVec);
    return program;
}

// 解析声明
std::vector<std::shared_ptr<ASTNode>> Parser::ParseDeclStmt()
{
    /// int a, b = 3;
    /// int a = 3;
    Consume(TokenType::kw_int);
    CType *baseTy = CType::getIntTy();
    /// a , b = 3;
    /// a = 3;

    std::vector<std::shared_ptr<ASTNode>> astArr;

    /// a, b = 3;
    /// a = 3;
    int i = 0;
    while (token.tokenType != TokenType::semi)
    {
        if (i++ > 0) // if (i++)
        {
            assert(Consume(TokenType::comma));
        }

        /// 变量声明的节点: int a = 3; -> int a; a = 3;
        // a = 3;
        auto varName = token.content;
        auto variableDecl = sema.SemaVariableDeclNode(varName, baseTy); // get a type
        astArr.push_back(variableDecl);
        Consume(TokenType::indentifier);

        // = 3;
        if (token.tokenType == TokenType::equal)
        {
            llvm::StringRef name = varName;
            Advance();

            // 3;
            auto right = ParseExpr();
            auto left = sema.SemaVariableAccessNode(name);
            auto assign = sema.SemaAssignExprNode(left, right);

            astArr.push_back(assign);
        }
    }
    Advance();
    return astArr;
}

// 解析表达式语句
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    auto expr = ParseExpr();
    Consume(TokenType::semi);
    return expr;
}

// 解析表达式
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
    lexer.SaveState(); // 先保存词法分析器 lexer 当前状态
    bool isAssign = false; // 用于标记这个表达式是否为赋值表达式
    Token tmp = token; // 借助一个了临时 token 进行判断
    // a = b;
    if (tmp.tokenType == TokenType::indentifier) // 判断类型是否为变量
    {
        lexer.NextToken(tmp);
        if (tmp.tokenType == TokenType::equal) // 判断类型是否为等号
        {
            isAssign = true;
        }
    }
    lexer.RestoreState(); // 恢复词法分析器 lexer 的状态
    if (isAssign) // 说明是一个赋值表达式
    {
        return ParseAssignExpr();
    }
    // add-expr
    std::shared_ptr<ASTNode> left = ParseTerm();
    while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
    {
        OpCode op;
        if (token.tokenType == TokenType::plus)
        {
            op = OpCode::add;
        }
        else
        {
            op = OpCode::sub;
        }
        Advance();
        auto right = ParseTerm();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// 解析赋值表达式
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
    // a = b;
    llvm::StringRef text = token.content;
    Consume(TokenType::indentifier);
    auto expr = sema.SemaVariableAccessNode(text);
    Consume(TokenType::equal);
    return sema.SemaAssignExprNode(expr, ParseExpr());
}

// 解析项
std::shared_ptr<ASTNode> Parser::ParseTerm()
{
    std::shared_ptr<ASTNode> left = ParseFactor();

    while (token.tokenType == TokenType::star || token.tokenType == TokenType::slash)
    {
        OpCode op;
        if (token.tokenType == TokenType::star)
        {
            op = OpCode::mul;
        }
        else
        {
            op = OpCode::div;
        }
        Advance();
        auto right = ParseFactor();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// 解析因子
std::shared_ptr<ASTNode> Parser::ParseFactor()
{
    if (token.tokenType == TokenType::l_parent)
    {
        Advance();
        auto expr = ParseExpr();
        assert(Expect(TokenType::r_parent));
        Advance();
        return expr;
    }
    else if(token.tokenType == TokenType::indentifier)
    {
        auto variableAccessExpr = sema.SemaVariableAccessNode(token.content);
        Advance();
        return variableAccessExpr;
    }
    else
    {
        auto factorExpr = sema.SemaNumberExprNode(token.value, token.type);
        Advance();
        return factorExpr;
    }
}

/// 消耗 token 函数
bool Parser::Expect(TokenType tokenType)
{
    if (token.tokenType == tokenType)
        return true;
    return false;
}

bool Parser::Consume(TokenType tokenType)
{
    if (Expect(tokenType))
    {
        Advance();
        return true;
    }
    return false;
}

bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}

(4)、代码生成(CodeGen)

最后是代码生成部分,主要工作是将一些返回左值的操作改为返回右值。

实现代码

codegen.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/StringMap.h"

// 通过访问者模式生成代码
class CodeGen : public Visitor
{
public:
    CodeGen(std::shared_ptr<Program> program)
    {
        module = std::make_shared<llvm::Module>("expr", context);
        VisitProgram(program.get());
    }

public:
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;

private:
    llvm::LLVMContext context;
    llvm::IRBuilder<> irBuilder{context};
    std::shared_ptr<llvm::Module> module;
    
    // 要得到右值就需要 irBuilder.CreateLoad() 函数,这个函数需要变量对应的类型
    // 为了得到变量对应的类型,所以这个 Map 需要使用 pair 来进行保存
    llvm::StringMap<std::pair<llvm::Value *, llvm::Type *>> varAddrTyMap;
};

codegen.cc

cpp 复制代码
#include "codegen.h"
#include "llvm/IR/Verifier.h"

llvm::Value *CodeGen::VisitProgram(Program *p)
{
    // 创建 printf 函数
    auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
    auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
    // 创建 main 函数
    auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
    auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
    // 创建 main 函数的基本块
    llvm::BasicBlock *entryBlock = llvm::BasicBlock::Create(context, "entry", mainFunction);
    // 设置该基本块作为指令的入口
    irBuilder.SetInsertPoint(entryBlock);

    llvm::Value *lastVal = nullptr;
    for (auto expr : p->ExprVec)
    {
        lastVal = expr->Accept(this);
    }
    irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastVal});

    // 创建返回值
    llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));

    llvm::verifyFunction(*mainFunction);

    module->print(llvm::outs(), nullptr);
    return ret;
}

// 这个函数需要进行一下微调
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
    llvm::Type *ty = nullptr;
    if (decl->type == CType::getIntTy())
    {
        ty = irBuilder.getInt32Ty();
    }
    llvm::Value *varAddr = irBuilder.CreateAlloca(ty, nullptr, decl->name);
    varAddrTyMap.insert({text, {varAddr, ty}}); // 第二个参数变成了 pair
    return varAddr;
}

llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    std::pair pair = varAddrTyMap[varaccExpr->name];
    llvm::Value *varAddr = pair.first; // 得到变量的地址
    llvm::Type *ty = pair.second;      // 得到变量的类型
    if (varaccExpr->type == CType::getIntTy())
    {
        ty = irBuilder.getInt32Ty();
    }
    // 返回一个右值
    return irBuilder.CreateLoad(ty, varAddr, text);
}

// a = 3; // right value
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
    VariableAccessExpr *varAccExpr = (VariableAccessExpr *)assignExpr->left.get();
    std::pair pair = varAddrTyMap[varAccExpr->name];
    llvm::Value *addr = pair.first;
    llvm::Type *ty = pair.second;
    llvm::Value *rValue = assignExpr->right->Accept(this);
    // 这个得到的是一个左值
    irBuilder.CreateStore(rValue, addr);
    // 返回一个右值
    return irBuilder.CreateLoad(ty, addr,text);
}

llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    auto left = binaryExpr->left->Accept(this);
    auto right = binaryExpr->right->Accept(this);

    switch (binaryExpr->op)
    {
    case OpCode::add:
    {
        return irBuilder.CreateNSWAdd(left, right, "add"); // CreateNSW... 是防止溢出行为的
    }
    case OpCode::sub:
    {
        return irBuilder.CreateNSWSub(left, right, "sub");
    }
    case OpCode::mul:
    {
        return irBuilder.CreateNSWMul(left, right, "mul");
    }
    case OpCode::div:
    {
        return irBuilder.CreateSDiv(left, right, "div");
    }
    default:
    {
        break;
    }
    }
    return nullptr;
}

llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
    return irBuilder.getInt32(factorExpr->token.value);
}

(5)、测试编译器

生成 IR 文件

bash 复制代码
bin/expr test/expr.txt > test/expr.ll

生成的 IR 内容

ini 复制代码
; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 4, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  store i32 5, ptr %a, align 4
  %a2 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 3, ptr %a, align 4
  %a3 = load i32, ptr %a, align 4
  store i32 %a3, ptr %b, align 4
  %b4 = load i32, ptr %b, align 4
  %a5 = load i32, ptr %a, align 4
  %b6 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a5, %b6
  %sub = sub nsw i32 %mul, 2
  %0 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

运行 IR

bash 复制代码
lli test/expr.ll

运行结果

bash 复制代码
expr value: 15

结果正确,验证了编译器目前赋值功能的正确性。

2、错误处理

编译器在报错时通常需要提供明确的错误信息,这能大大提高调试效率。错误处理模块正是为了实现这一功能。

(1)、测试文件

expr_01.txt 文件如下:

ini 复制代码
int a = 4;
{
int b = 2;

expr_02.txt 文件如下:

ini 复制代码
int a = 4
a + b * 3;

expr_03.txt 文件如下:

ini 复制代码
int a = 4;
a + b * 3;

expr_04.txt 文件如下:

ini 复制代码
int a = 4;
int a;

(2)、诊断引擎(DiagEngine)

我们需要一个诊断引擎来帮助报错。在编写诊断引擎(DiagEngine)代码之前,需要先编写一个配置文件,用于收集所有错误信息。

实现代码

diag.inc

cpp 复制代码
#ifndef DIAG
#define DIAG(ID, KIND, MSG)
#endif 

/// lexer
DIAG(err_unknown_char, Error, "unknown char '{0}'")

/// parser
DIAG(err_expected, Error, "expect '{0}' , but get '{1}'")
DIAG(err_redefine, Error, "redefine symbol '{0}'")
DIAG(err_undefine, Error, "undefine symbol '{0}'")
DIAG(err_lvalue, Error, "require left value on the assign operation left side")

#undef DIAG

diag_engine.h

cpp 复制代码
#pragma once

#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/FormatVariadic.h"

namespace diag // 定义一个命名空间用于枚举错误编号
{
    enum
    {
#define DIAG(ID, KIND, MSG) ID, // 只要 DIAG 里的 ID 属性
#include "diag.inc" // 这种方式在预处理阶段会直接将 diag.inc 文件全部内容直接放到这个位置
    };
};

class DiagEngine
{
public:
    DiagEngine(llvm::SourceMgr &mgr) : mgr(mgr) {}

    template <typename... Args>                                 // 定义一个可变参数
    void Report(llvm::SMLoc loc, unsigned diagID, Args... args) // 传入这个可变参数
    {
        // 通过 diagID 可以获得 kind 和 msg
        auto kind = GetDiagKind(diagID);
        const char *fmt = GetDiagMsg(diagID);
        auto f = llvm::formatv(fmt, std::forward<Args>(args)...).str();
        mgr.PrintMessage(loc, kind, f); // 报错功能主要借助的就是这个函数

        if (kind == llvm::SourceMgr::DK_Error)
        {
            exit(1);
        }
    }

private:
    llvm::SourceMgr::DiagKind GetDiagKind(unsigned id);
    const char *GetDiagMsg(unsigned id);

private:
    llvm::SourceMgr &mgr; // 报错主要用的是 llvm 提供的 SourceMgr 这个类
};

diag_engine.cc

cpp 复制代码
#include "diag_engine.h"

static llvm::SourceMgr::DiagKind diag_kind[] = 
{
    #define DIAG(ID, KIND, MSG) llvm::SourceMgr::DK_##KIND,
    #include "diag.inc"
};

static const char* diag_msg[] = 
{
    #define DIAG(ID, KIND, MSG) MSG,
    #include "diag.inc"
};

llvm::SourceMgr::DiagKind DiagEngine::GetDiagKind(unsigned id)
{
    return diag_kind[id];
}

const char* DiagEngine::GetDiagMsg(unsigned id)
{
    return diag_msg[id];
}

(3)、词法分析器(Lexer)

有了诊断引擎,接下来需要将其融入到**词法分析器(Lexer)、语法分析器(Parser)、语义分析器(Sema)**中。

实现代码

lexer.h

cpp 复制代码
#pragma once

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"

// char stream -> token

enum class TokenType
{
    //unknown,   // 将这个进行注释掉,因为 unknown 本身就是一种错误
    number,      // [0-9]+
    indentifier, // 变量
    kw_int,      // int
    plus,        // +
    minus,       // -
    star,        // *
    slash,       // /
    equal,       // =
    l_parent,    // (
    r_parent,    // )
    semi,        // ;
    comma,       // ,
    eof          // end of file
};

class Token
{
public:
    TokenType tokenType; // token 的种类
    int row, col;

    int value; // for number

    // 因为 DiagEngine 中的 mgr 进行报错的时候需要位置信息,
    // 而这个位置信息依靠的是 llvm::SMLoc::getFromPointer() 这个函数,
    // 这个函数需要的参数就是一个指针,所以将 content 这个拆成 ptr 和 length
    // 借助 llvm::StringRef() 的构造函数也能得到对应的 content
    // llvm::StringRef content; 
    const char *ptr; // for debug
    int length;

    CType *type; // for built-in type

public:
    void Dump();
    static llvm::StringRef GetSpellingText(TokenType tokenType);
};

class Lexer
{
private:
    DiagEngine &diagEngine;
    llvm::SourceMgr &mgr;

public:
    Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr) : diagEngine(diagEngine), mgr(mgr)
    {
        unsigned id = mgr.getMainFileID();
        llvm::StringRef buf = mgr.getMemoryBuffer(id)->getBuffer();
        BufPtr = buf.begin();
        LineHeadPtr = buf.begin();
        BufEnd = buf.end();
        row = 1;
    }
    void NextToken(Token &token);

    void SaveState();
    void RestoreState();

    DiagEngine &GetDiagEngine() const
    {
        return diagEngine;
    }

private:
    struct State
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

private:
    const char *BufPtr;
    const char *LineHeadPtr;
    const char *BufEnd;
    int row;

    State state;
};

lexer.cc

cpp 复制代码
#include "lexer.h"

void Token::Dump()
{
    llvm::StringRef text(ptr, length);
    llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}

// number,     // [0-9]+
// indentifier,// 变量
// kw_int,     // int
// plus,       // +
// minus,      // -
// star,       // *
// slash,      // /
// equal,      // =
// l_parent,   // (
// r_parent,   // )
// semi,       // ;
// comma,      // ,
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    switch (tokenType)
    {
    case TokenType::kw_int:
        return "int";
    case TokenType::plus:
        return "+";
    case TokenType::minus:
        return "-";
    case TokenType::star:
        return "*";
    case TokenType::slash:
        return "/";
    case TokenType::equal:
        return "=";
    case TokenType::l_parent:
        return "(";
    case TokenType::r_parent:
        return ")";
    case TokenType::semi:
        return ";";
    case TokenType::comma:
        return ",";
    case TokenType::number:
        return "number";
    case TokenType::indentifier:
        return "indentifier";
    default:
        llvm::llvm_unreachable_internal(); // 逻辑上不可能到达这个位置吗,假如到达这位置会报错,说明代码有问题
    }
}

bool IsWhiteSpace(char ch)
{
    return ch == ' ' || ch == '\r' || ch == '\n';
}

bool IsDigit(char ch)
{
    return ch >= '0' && ch <= '9';
}

bool IsLetter(char ch)
{
    // a-z, A-Z, _
    return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch == '_');
}

void Lexer::NextToken(Token &token)
{
    // 过滤空格
    while (IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;

    // 判断是否到结尾了
    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    token.ptr = BufPtr;
    token.length = 0;
    // 判断是否为数字
    if (IsDigit(*BufPtr))
    {
        int len = 0;
        int val = 0;
        while (IsDigit(*BufPtr))
    {
            val = val * 10 + *BufPtr++ - '0';
            token.length++;
        }
        token.value = val;
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
    }
    else if (IsLetter(*BufPtr)) // 为变量
    {
        while (IsLetter(*BufPtr) || IsDigit(*BufPtr))
        {
            BufPtr++;
        }
        token.tokenType = TokenType::indentifier;
        token.length = BufPtr - token.ptr;
        llvm::StringRef text(token.ptr, BufPtr - token.ptr);
        if (text == "int")
        {
            token.tokenType = TokenType::kw_int;
        }
    }
    else // 为特殊字符
    {
        switch (*BufPtr)
        {
        case '+':
        {
            token.tokenType = TokenType::plus;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '-':
        {
            token.tokenType = TokenType::minus;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '*':
        {
            token.tokenType = TokenType::star;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '/':
        {
            token.tokenType = TokenType::slash;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '=':
        {
            token.tokenType = TokenType::equal;
            token.length = 1;
            BufPtr++;
            break;
        }
        case '(':
        {
            token.tokenType = TokenType::l_parent;
            token.length = 1;
            BufPtr++;
            break;
        }
        case ')':
        {
            token.tokenType = TokenType::r_parent;
            token.length = 1;
            BufPtr++;
            break;
        }
        case ';':
        {
            token.tokenType = TokenType::semi;
            token.length = 1;
            BufPtr++;
            break;
        }
        case ',':
        {
            token.tokenType = TokenType::comma;
            token.length = 1;
            BufPtr++;
            break;
        }
        default:
        {
            diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
        }
        }
    }
}

void Lexer::SaveState()
{
    state.LineHeadPtr = LineHeadPtr;
    state.BufPtr = BufPtr;
    state.BufEnd = BufEnd;
    state.row = row;
}

void Lexer::RestoreState()
{
    LineHeadPtr = state.LineHeadPtr;
    BufPtr = state.BufPtr;
    BufEnd = state.BufEnd;
    row = state.row;
}

(4)、语法分析器(Parser)

对于语法分析器,主要修改的是抽象语法树AST消耗 Token 的函数的代码。

实现代码

ast.h

cpp 复制代码
#pragma once

#include <vector>
#include <memory>
#include "llvm/IR/Value.h"
#include "type.h"
#include "lexer.h"

enum class OpCode
{
    add,
    sub,
    mul,
    div
};

// 进行声明
class Program;
class Expr;
class VariableDecl;
class VariableAccessExpr;
class BinaryExpr;
class AssignExpr;
class NumberExpr;

// 访问者模式
class Visitor
{
public:
    virtual ~Visitor() {}
    virtual llvm::Value *VisitProgram(Program *p) = 0;
    virtual llvm::Value *VisitVariableDecl(VariableDecl* decl) = 0;
    virtual llvm::Value *VisitVariableAccessExpr(VariableAccessExpr* varaccExpr) = 0;
    virtual llvm::Value *VisitAssignExpr(AssignExpr* assignExpr) = 0;
    virtual llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) = 0;
    virtual llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) = 0;
};

// 语法树的公共节点
class ASTNode 
{
public:
    enum Kind
    {
        ND_VariableDecl,
        ND_BinaryExpr,
        ND_NumberExpr,
        ND_VariableAccessExpr,
        ND_AssignExpr
    };

private:
    const Kind kind;
public:
    ASTNode(Kind kind) : kind(kind) {}
    virtual ~ASTNode() {}
    virtual llvm::Value *Accept(Visitor *v) { return nullptr; } // 通过虚函数的特性完成分发的功能
    const Kind getKind() const { return kind; }

public:
    CType *type;
    Token token;
};

// 变量声明节点
class VariableDecl : public ASTNode 
{
public:
    VariableDecl() : ASTNode(ND_VariableDecl) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitVariableDecl(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_VariableDecl;
    }
};

// 二元表达式节点
class BinaryExpr : public ASTNode
{
public:
    BinaryExpr() : ASTNode(ND_BinaryExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitBinaryExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_BinaryExpr;
    }
public:
    OpCode op;
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;
};

// 赋值表达式节点
class AssignExpr : public ASTNode
{
public:
    AssignExpr() : ASTNode(ND_AssignExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitAssignExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_AssignExpr;
    }
public:
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;
};

// 数字表达式节点
class NumberExpr : public ASTNode
{
public:
    NumberExpr() : ASTNode(ND_NumberExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitNumberExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_NumberExpr;
    }
};

// 变量访问节点
class VariableAccessExpr : public ASTNode // 变量表达式
{
public:
    VariableAccessExpr() : ASTNode(ND_VariableAccessExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitVariableAccessExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_VariableAccessExpr;
    }
};

// 目标程序
class Program
{
public:
    std::vector<std::shared_ptr<ASTNode>> ExprVec;
};

parser.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "lexer.h"
#include "sema.h"

class Parser
{
public:
    Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
    {
        Advance();
    }
    std::shared_ptr<Program> ParseProgram();

private:
    std::vector<std::shared_ptr<ASTNode>> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseAssignExpr();
    std::shared_ptr<ASTNode> ParseExpr();
    std::shared_ptr<ASTNode> ParseTerm();
    std::shared_ptr<ASTNode> ParseFactor();

    // 消耗 token 的函数
    // 检测 token 的类型
    bool Expect(TokenType tokenType);
    // 检测 token 的类型并消费
    bool Consume(TokenType tokenType);
    // 直接消耗当前的 token
    bool Advance();

    DiagEngine &GetDiagEngine() const
    {
        return lexer.GetDiagEngine();
    }

private:
    Lexer &lexer;
    Sema &sema;
    Token token;
};

parser.cc

cpp 复制代码
#include "parser.h"
#include <cassert>

// prog : stmt*
// stmt : decl-stmt | expr-stmt | null-stmt
// null-stmt : ";"
// decl-stmt : "int" identifier ("," identifier (= expr)?)* ";"
// expr-stmt : expr ";"
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
// mult-expr : primary-expr (("*" | "/") primary-expr)*
// primary-expr : identifier | number | "(" expr ")"
// number: ([0-9])+
// identifier : (a-zA-Z_)(a-zA-Z0-9_)*

// 解析目标程序
// stmt : decl-stmt | expr-stmt | null-stmt
std::shared_ptr<Program> Parser::ParseProgram()
{
    std::vector<std::shared_ptr<ASTNode>> ExprVec;
    while (token.tokenType != TokenType::eof)
    {
        // 遇到 ; 需要进行消费 token
        // null-stmt
        if (token.tokenType == TokenType::semi)
        {
            Advance();
            continue;
        }
        // decl-stmt
        if (token.tokenType == TokenType::kw_int)
    {
            auto exprs = ParseDeclStmt();
            for (auto expr : exprs)
            {
                ExprVec.push_back(expr);
            }
        }
        else // expr-stmt
        {
            auto expr = ParseExprStmt();
            ExprVec.push_back(expr);
        }
    }
    auto program = std::make_shared<Program>();
    program->ExprVec = std::move(ExprVec);
    return program;
}

// 解析声明
std::vector<std::shared_ptr<ASTNode>> Parser::ParseDeclStmt()
{
    /// int a, b = 3;
    /// int a = 3;
    Consume(TokenType::kw_int);
    CType *baseTy = CType::getIntTy();
    /// a , b = 3;
    /// a = 3;

    std::vector<std::shared_ptr<ASTNode>> astArr;

    /// a, b = 3;
    /// a = 3;
    int i = 0;
    while (token.tokenType != TokenType::semi)
    {
        if (i++ > 0) // if (i++)
        {
            assert(Consume(TokenType::comma));
        }

        /// 变量声明的节点: int a = 3; -> int a; a = 3;
        // a = 3;
        auto variableDecl = sema.SemaVariableDeclNode(token, baseTy); // get a type
        astArr.push_back(variableDecl);
        
        Token tmp = token;
        Consume(TokenType::indentifier);

        // = 3;
        if (token.tokenType == TokenType::equal)
        {
            Token opToken = token;
            Advance();

            // 3;
            auto right = ParseExpr();
            auto left = sema.SemaVariableAccessNode(tmp);
            auto assign = sema.SemaAssignExprNode(left, right, opToken);

            astArr.push_back(assign);
        }
    }

    Advance();

    return astArr;
}

// 解析表达式语句
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    auto expr = ParseExpr();
    Consume(TokenType::semi);
    return expr;
}

// 解析表达式
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
    lexer.SaveState();
    bool isAssign = false;
    Token tmp = token;
    // a = b;
    if (tmp.tokenType == TokenType::indentifier)
    {
        lexer.NextToken(tmp);
        if (tmp.tokenType == TokenType::equal)
        {
            isAssign = true;
        }
    }
    lexer.RestoreState();
    if (isAssign)
    {
        return ParseAssignExpr();
    }
    // add-expr
    std::shared_ptr<ASTNode> left = ParseTerm();
    while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
    {
        OpCode op;
        if (token.tokenType == TokenType::plus)
        {
            op = OpCode::add;
        }
        else
        {
            op = OpCode::sub;
        }
        Advance();
        auto right = ParseTerm();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// 解析赋值表达式
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
    // a = b;
    Token tmp = token;
    Consume(TokenType::indentifier);
    auto expr = sema.SemaVariableAccessNode(tmp);
    Token opToken = token;
    Consume(TokenType::equal);
    return sema.SemaAssignExprNode(expr, ParseExpr(), opToken);
}

// 解析项
std::shared_ptr<ASTNode> Parser::ParseTerm()
{
    std::shared_ptr<ASTNode> left = ParseFactor();

    while (token.tokenType == TokenType::star || token.tokenType == TokenType::slash)
    {
        OpCode op;
        if (token.tokenType == TokenType::star)
        {
            op = OpCode::mul;
        }
        else
        {
            op = OpCode::div;
        }
        Advance();
        auto right = ParseFactor();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// 解析因子
std::shared_ptr<ASTNode> Parser::ParseFactor()
{
    if (token.tokenType == TokenType::l_parent)
    {
        Advance();
        auto expr = ParseExpr();
        assert(Expect(TokenType::r_parent));
        Advance();
        return expr;
    }
    else if (token.tokenType == TokenType::indentifier)
    {
        auto variableAccessExpr = sema.SemaVariableAccessNode(token);
        Advance();
        return variableAccessExpr;
    }
    else
    {
        Expect(TokenType::number);
        auto factorExpr = sema.SemaNumberExprNode(token, token.type);
        Advance();
        return factorExpr;
    }
}

/// 消耗 token 函数
bool Parser::Expect(TokenType tokenType)
{
    if (token.tokenType == tokenType)
        return true;
    // 当不是所期望的类型时应该直接报错,说明此时源代码不符合文法
    GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                           diag::err_expected,
                           Token::GetSpellingText(tokenType),
                           llvm::StringRef(token.ptr, token.length));
    return false;
}

bool Parser::Consume(TokenType tokenType)
{
    if (Expect(tokenType))
    {
        Advance();
        return true;
    }
    return false;
}

bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}

由于抽象语法树中的节点发生了修改,而抽象语法树的形成依靠的是访问者模式,所以继承访问者模式的打印功能的代码也要进行修改。

printVisitor.h

cpp 复制代码
#pragma once

#include "ast.h"
#include "parser.h"

class PrintVisitor : public Visitor
{
public:
    PrintVisitor(std::shared_ptr<Program> program);

public:
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
};

printVisitor.cc

cpp 复制代码
#include "printVisitor.h"

PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
    VisitProgram(program.get());
}

llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
    for (auto &expr : p->ExprVec)
    {
        expr->Accept(this);
        llvm::outs() << "\n";
    }
    return nullptr;
}

llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
    if(decl->type == CType::getIntTy())
    {
        llvm::outs() << "int " << llvm::StringRef(decl->token.ptr, decl->token.length);
    }
    return nullptr;
}

llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    llvm::outs() << llvm::StringRef(varaccExpr->token.ptr, varaccExpr->token.length);
    return nullptr;
}

llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
    assignExpr->left->Accept(this);

    llvm::outs() << " = ";

    assignExpr->right->Accept(this);
    return nullptr;

}

llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    // 后序遍历
    binaryExpr->left->Accept(this);

    switch (binaryExpr->op)
    {
    case OpCode::add:
    {
        llvm::outs() << " + ";
        break;
    }
    case OpCode::sub:
    {
        llvm::outs() << " - ";
        break;
    }
    case OpCode::mul:
    {
        llvm::outs() << " * ";
        break;
    }
    case OpCode::div:
    {
        llvm::outs() << " / ";
        break;
    }
    default:
    {
        break;
    }
    }

    binaryExpr->right->Accept(this);

    return nullptr;
}

llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
    llvm::outs() << llvm::StringRef(factorExpr->token.ptr, factorExpr->token.length);
    return nullptr;
}

(5)、语义分析器(Sema)

实现代码

sema.h

cpp 复制代码
#pragma once

#include "scope.h"
#include "ast.h"

class Sema // 语义分析
{
public:
    Sema(DiagEngine &diagEngine) : diagEngine(diagEngine) {}
    std::shared_ptr<ASTNode> SemaVariableDeclNode(Token token, CType* ty);
    std::shared_ptr<ASTNode> SemaVariableAccessNode(Token token);
    std::shared_ptr<ASTNode> SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token);
    std::shared_ptr<ASTNode> SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op);
    std::shared_ptr<ASTNode> SemaNumberExprNode(Token token, CType* ty);

    DiagEngine &GetDiaEngine() const
    {
        return diagEngine;
    }
private:
    Scope scope;
    DiagEngine &diagEngine;
};

sema.cc

cpp 复制代码
#include "sema.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Casting.h"

// 符号声明节点
std::shared_ptr<ASTNode> Sema::SemaVariableDeclNode(Token token, CType* ty)
{
    llvm::StringRef text(token.ptr, token.length);
    auto symbol = scope.FindVarSymbolInCurEnv(text);
    if (symbol) // 查找是否重定义
    {
        GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr), 
                              diag::err_redefine, 
                              llvm::StringRef(token.ptr, token.length));
    }

    scope.AddVarSymbol(SymbolKind::LocalVariable, ty, text); // 添加到符号表

    auto variableDecl = std::make_shared<VariableDecl>();
    variableDecl->token = token;
    variableDecl->type = ty;

    return variableDecl;
}

std::shared_ptr<ASTNode> Sema::SemaVariableAccessNode(Token token)
{
    llvm::StringRef text(token.ptr, token.length);
    auto symbol = scope.FindVarSymbol(text);
    // auto symbol = scope.FindVarSymbolInCurEnv(name); // err
    if (symbol == nullptr)
    {
        GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr), 
                              diag::err_undefine, 
                              llvm::StringRef(token.ptr, token.length));
    }
    auto varAcc = std::make_shared<VariableAccessExpr>();
    varAcc->token = token;
    varAcc->type = symbol->getTy();
    
    return varAcc;
}

std::shared_ptr<ASTNode> Sema::SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token)
{
    assert(left && right);

    if (!llvm::isa<VariableAccessExpr>(left.get()))
    {
        diagEngine.Report(llvm::SMLoc::getFromPointer(left->token.ptr), diag::err_lvalue);
    }

    auto assign = std::make_shared<AssignExpr>();
    assign->left = left;
    assign->right = right;

    return assign;
}

std::shared_ptr<ASTNode> Sema::SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op)
{
    auto binaryExpr = std::make_shared<BinaryExpr>();
    binaryExpr->left = left;
    binaryExpr->right = right;
    binaryExpr->op = op;

    return binaryExpr;
}

std::shared_ptr<ASTNode> Sema::SemaNumberExprNode(Token token, CType* ty)
{
    auto factorExpr = std::make_shared<NumberExpr>();
    factorExpr->token = token;
    factorExpr->type = ty;

    return factorExpr;
}

(6)、测试错误处理功能

生成 IR 文件

expr_01.txt

bash 复制代码
bin/expr test/expr_01.txt > test/expr_01.ll

expr_02.txt

bash 复制代码
bin/expr test/expr_02.txt > test/expr_02.ll

expr_03.txt

bash 复制代码
bin/expr test/expr_03.txt > test/expr_03.ll

expr_04.txt

bash 复制代码
bin/expr test/expr_04.txt > test/expr_04.ll

生成的 IR 内容

expr_01.txt

bash 复制代码
test/expr.txt:2:1: error: unknown char '{'
{
^

expr_02.txt

bash 复制代码
test/expr.txt:2:1: error: expect ',' , but get 'a'
a + b * 3;

expr_03.txt

bash 复制代码
test/expr.txt:2:5: error: undefine symbol 'b'
a + b * 3;

expr_04.txt

bash 复制代码
test/expr.txt:2:5: error: redefine symbol 'a'
int a;
    ^

由于最后一个左值的情况只能发生在为表达式的这一种情况,但是要测试出来要么左边为常数,要么是使用std::move() 得到的这个整体也是右值,而后一种暂不支持,前一种又不符合表达式的情况,所以这个目前无法测试!!!

3 单元测试

这里单元测试使用的是 google test。

首先是先对 CMakeLists.txt进行一下简单的修改。

CMakeLists.txt

ini 复制代码
cmake_minimum_required(VERSION 3.18)

project(errhandle_unittest)

set(CMAKE_C_STANDARD 99)
set(CMAKE_CXX_STANDARD 17)

SET(EXECUTABLE_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/bin)

find_package(LLVM REQUIRED CONFIG)

list(APPEND CMAKE_MODULE_PATH ${LLVM_CMAKE_DIR})

include(AddLLVM)

include_directories("${LLVM_BINARY_DIR}/include" "${LLVM_INCLUDE_DIR}")
add_definitions(${LLVM_DEFINITIONS})

if (NOT ${LLVM_ENABLE_RTTI})
    # For non-MSVC compilers
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")
endif()

set(LLVM_LINK_COMPONENTS Support Core)

add_llvm_executable(${PROJECT_NAME} main.cc lexer.cc parser.cc printVisitor.cc codegen.cc type.cc scope.cc sema.cc diag_engine.cc)

add_subdirectory(unittest) # 在最后加上这么一句话,这样 cmake 的时候可以得到两个相应的工程文件

文件夹 unittest中也创建一个对应的 CMakeLists.txt文件。

unittest/CMakeLists.txt

ini 复制代码
include(FetchContent)
FetchContent_Declare(
  googletest
  URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
)
# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)

add_subdirectory(lexer) # 以你实际要测试的模块取名

当使用 cmake 构建工程的时候也顺便会安装 google test。

unittest这个文件夹下再创建一个文件夹,这里我以 lexer为例子。

testset/lexer_01.txt

ini 复制代码
int a = 3, b = 5;
a + 3 * b - 6;

unittest/lexer/CMakeLists.txt

ini 复制代码
enable_testing()

add_executable(
  lexer_test
  lexer_test.cc
  
  ../../lexer.cc
  ../../type.cc
  ../../diag_engine.cc
)

llvm_map_components_to_libnames(llvm_all Support Core)

target_link_libraries(
  lexer_test
  GTest::gtest_main
  ${llvm_all}
)

include(GoogleTest)
gtest_discover_tests(lexer_test)

unittest/lexer/lexer_test.cc

cpp 复制代码
#include <gtest/gtest.h>
#include "../../lexer.h"

class LexerTest : public ::testing::Test
{
public:
    void SetUp() override
    {
        static llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buf = llvm::MemoryBuffer::getFile("../testset/lexer_01.txt");
        if (!buf)
        {
            llvm::errs() << "can't open file!\n";
            return;
        }
        llvm::SourceMgr mgr;
        DiagEngine diagEngine(mgr);
        mgr.AddNewSourceBuffer(std::move(*buf), llvm::SMLoc());
        lexer = new Lexer(diagEngine, mgr);
    }

    void TearDown() override
    {
        delete lexer;
    }

    Lexer *lexer;
};

/*
int a = 3, b = 5;
a + 3 * b - 6;
*/
TEST_F(LexerTest, NextToken)
{
    std::vector<Token> expectedVec; // 正确集
    std::vector<Token> curVec;      // 当前集
    expectedVec.push_back(Token{TokenType::kw_int, 1, 1});
    expectedVec.push_back(Token{TokenType::indentifier, 1, 5});
    expectedVec.push_back(Token{TokenType::equal, 1, 7});
    expectedVec.push_back(Token{TokenType::number, 1, 9});
    expectedVec.push_back(Token{TokenType::comma, 1, 10});
    expectedVec.push_back(Token{TokenType::indentifier, 1, 12});
    expectedVec.push_back(Token{TokenType::equal, 1, 14});
    expectedVec.push_back(Token{TokenType::number, 1, 16});
    expectedVec.push_back(Token{TokenType::semi, 1, 17});
    expectedVec.push_back(Token{TokenType::indentifier, 2, 1});
    expectedVec.push_back(Token{TokenType::plus, 2, 3});
    expectedVec.push_back(Token{TokenType::number, 2, 5});
    expectedVec.push_back(Token{TokenType::star, 2, 7});
    expectedVec.push_back(Token{TokenType::indentifier, 2, 9});
    expectedVec.push_back(Token{TokenType::minus, 2, 11});
    expectedVec.push_back(Token{TokenType::number, 2, 13});
    expectedVec.push_back(Token{TokenType::semi, 2, 14});

    Token token;
    while(true)
    {
        lexer->NextToken(token);
        if(token.tokenType == TokenType::eof)
        {
            break;
        }
        curVec.push_back(token);
    }

    ASSERT_EQ(expectedVec.size(), curVec.size());
    for(int i = 0; i < expectedVec.size(); i++)
    {
        const auto expected_tok = expectedVec[i];
        const auto cur_tok = curVec[i];

        EXPECT_EQ(expected_tok.tokenType, cur_tok.tokenType);
        EXPECT_EQ(expected_tok.row, cur_tok.row);
        EXPECT_EQ(expected_tok.col, cur_tok.col);
    }
}

测试 Lexer

生成对应的工程文件
bash 复制代码
cmake -GNinja
运行得到的工程文件 lexer_test
bash 复制代码
lexer_test
运行结果
bash 复制代码
Running main() from build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from LexerTest
[ RUN      ] LexerTest.NextToken
[       OK ] LexerTest.NextToken (2 ms)
[----------] 1 test from LexerTest (2 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (2 ms total)
[  PASSED  ] 1 test.

🎉 至此,我们完成了对赋值表达式的支持,并增强了编译器的错误处理能力。通过引入诊断引擎,编译器现在能够提供更清晰的错误提示,提升了开发体验。后续可继续扩展语言特性,并完善测试覆盖。

相关推荐
sophie旭2 小时前
一道面试题,开始性能优化之旅(3)-- DNS查询+TCP(一)
前端·面试·性能优化
IT_陈寒3 小时前
JavaScript性能优化:这7个V8引擎技巧让我的应用速度提升了50%
前端·人工智能·后端
学渣y3 小时前
nvm下载node版本,npm -v查看版本报错
前端·npm·node.js
excel3 小时前
首屏加载优化总结
前端
敲代码的嘎仔3 小时前
JavaWeb零基础学习Day1——HTML&CSS
java·开发语言·前端·css·学习·html·学习方法
Tachyon.xue5 小时前
Vue 3 项目集成 Element Plus + Tailwind CSS 详细教程
前端·css·vue.js
FuckPatience6 小时前
Vue 中‘$‘符号含义
前端·javascript·vue.js
东风西巷8 小时前
K-Lite Mega/FULL Codec Pack(视频解码器)
前端·电脑·音视频·软件需求
超级大只老咪9 小时前
何为“类”?(Java基础语法)
java·开发语言·前端