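In the previous post, "LLVM IR Study Notes (2): Implementing and Analyzing a Variable-Expression Compiler", we built a basic compiler that supports variable expressions. It had one bug, though: a variable could not be reassigned, and chained assignment was not supported. This post starts by fixing that.
1. Assignment expressions
(1) Test file
expr.txt contains: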
ini
int a = 4;
a = 5;
int b = a = 3;
a * b - 4 / 2;
(2) Lexer
Supporting assignment expressions requires changes to the lexer, the parser, and code generation (CodeGen).
Grammar
ini
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
expr-stmt : expr ";"
expr : assign-expr | add-expr
assign-expr: identifier "=" expr
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number : [0-9]+
identifier : [a-zA-Z_][a-zA-Z0-9_]*
Implementation
lexer.h
cpp
#pragma once
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
// char stream -> token
enum class TokenType
{
unknown,
number, // [0-9]+
indentifier, // identifier (variable name)
kw_int, // int
plus, // +
minus, // -
star, // *
slash, // /
equal, // =
l_parent, // (
r_parent, // )
semi, // ;
comma, // ,
eof // end of file
};
class Token
{
public:
TokenType tokenType; // kind of token
int row, col;
int value; // for number
llvm::StringRef content; // for debug
CType *type; // for built-in type
public:
void Dump();
};
class Lexer
{
public:
Lexer(llvm::StringRef buf)
{
BufPtr = buf.begin();
LineHeadPtr = buf.begin();
BufEnd = buf.end();
row = 1;
}
void NextToken(Token &token);
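void SaveState(); // save the current lexer state
void RestoreState(); // restore the saved state
private:
struct State // snapshot of the lexer cursor; this is what makes the assignment lookahead possible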
{
const char *BufPtr;
const char *LineHeadPtr;
const char *BufEnd;
int row;
};
private:
const char *BufPtr;
const char *LineHeadPtr;
const char *BufEnd;
int row;
State state; // the most recently saved lexer state
};
cpp
#include "lexer.h"
void Token::Dump()
{
llvm::outs() << "{" << content << ", row = " << row << ", col = " << col << "}\n";
}
// number, // [0-9]+
// indentifier,// identifier (variable name)
// kw_int, // int
// plus, // +
// minus, // -
// star, // *
// slash, // /
// equal, // =
// l_parent, // (
// r_parent, // )
// semi, // ;
// comma, // ,
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
switch (tokenType)
{
case TokenType::kw_int:
return "int";
case TokenType::plus:
return "+";
case TokenType::minus:
return "-";
case TokenType::star:
return "*";
case TokenType::slash:
return "/";
case TokenType::equal:
return "=";
case TokenType::l_parent:
return "(";
case TokenType::r_parent:
return ")";
case TokenType::semi:
return ";";
case TokenType::comma:
return ",";
case TokenType::number:
return "number";
case TokenType::indentifier:
return "indentifier";
default:
llvm::llvm_unreachable_internal();
}
}
bool IsWhiteSpace(char ch)
{
return ch == ' ' || ch == '\r' || ch == '\n';
}
bool IsDigit(char ch)
{
return ch >= '0' && ch <= '9';
}
bool IsLetter(char ch)
{
// a-z, A-Z, _
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch == '_');
}
void Lexer::NextToken(Token &token)
{
// skip whitespace
while (IsWhiteSpace(*BufPtr))
{
if (*BufPtr == '\n')
{
row += 1;
LineHeadPtr = BufPtr + 1;
}
BufPtr++;
}
token.row = row;
token.col = BufPtr - LineHeadPtr + 1;
// check for end of input
if (BufPtr >= BufEnd)
{
token.tokenType = TokenType::eof;
return;
}
token.ptr = BufPtr;
token.length = 0;
// number literal
if (IsDigit(*BufPtr))
{
int val = 0;
while (IsDigit(*BufPtr))
{
val = val * 10 + *BufPtr++ - '0';
token.length++;
}
token.value = val;
token.tokenType = TokenType::number;
token.type = CType::getIntTy();
}
else if (IsLetter(*BufPtr)) // identifier or keyword
{
while (IsLetter(*BufPtr) || IsDigit(*BufPtr))
{
BufPtr++;
}
token.tokenType = TokenType::indentifier;
token.length = BufPtr - token.ptr;
llvm::StringRef text(token.ptr, BufPtr - token.ptr);
if (text == "int")
{
token.tokenType = TokenType::kw_int;
}
}
else // punctuation
{
switch (*BufPtr)
{
case '+':
{
token.tokenType = TokenType::plus;
token.length = 1;
BufPtr++;
break;
}
case '-':
{
token.tokenType = TokenType::minus;
token.length = 1;
BufPtr++;
break;
}
case '*':
{
token.tokenType = TokenType::star;
token.length = 1;
BufPtr++;
break;
}
case '/':
{
token.tokenType = TokenType::slash;
token.length = 1;
BufPtr++;
break;
}
case '=':
{
token.tokenType = TokenType::equal;
token.length = 1;
BufPtr++;
break;
}
case '(':
{
token.tokenType = TokenType::l_parent;
token.length = 1;
BufPtr++;
break;
}
case ')':
{
token.tokenType = TokenType::r_parent;
token.length = 1;
BufPtr++;
break;
}
case ';':
{
token.tokenType = TokenType::semi;
token.length = 1;
BufPtr++;
break;
}
case ',':
{
token.tokenType = TokenType::comma;
token.length = 1;
BufPtr++;
break;
}
default:
{
llvm::llvm_unreachable_internal();
BufPtr++;
break;
}
}
}
}
void Lexer::SaveState()
{
state.LineHeadPtr = LineHeadPtr;
state.BufPtr = BufPtr;
state.BufEnd = BufEnd;
state.row = row;
}
void Lexer::RestoreState()
{
LineHeadPtr = state.LineHeadPtr;
BufPtr = state.BufPtr;
BufEnd = state.BufEnd;
row = state.row;
}
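The SaveState/RestoreState pair above amounts to a one-token peek: remember the cursor, scan a token, then rewind. A minimal standalone sketch of that idea follows. `MiniLexer`, `Next`, and `Peek` are hypothetical names for illustration only, written against the standard library rather than LLVM, and the token rules are simplified to words and single-character punctuation.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string_view>

// Hypothetical miniature lexer: only the cursor handling matters here.
class MiniLexer {
public:
    explicit MiniLexer(std::string_view src) : src_(src) {}

    // Return the next token's text, or an empty view at end of input.
    std::string_view Next() {
        while (pos_ < src_.size() && IsSpace(src_[pos_])) ++pos_;
        if (pos_ >= src_.size()) return {};
        const std::size_t start = pos_;
        if (IsWordChar(src_[pos_])) {
            while (pos_ < src_.size() && IsWordChar(src_[pos_])) ++pos_;
        } else {
            ++pos_; // single-character punctuation
        }
        assert(pos_ > start); // every token consumes at least one character
        return src_.substr(start, pos_ - start);
    }

    // Look at the next token without consuming it: save, scan, restore.
    std::string_view Peek() {
        const std::size_t saved = pos_; // plays the role of SaveState()
        std::string_view tok = Next();
        pos_ = saved;                   // plays the role of RestoreState()
        return tok;
    }

private:
    static bool IsSpace(char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; }
    static bool IsWordChar(char c) {
        return std::isalnum(static_cast<unsigned char>(c)) != 0 || c == '_';
    }

    std::string_view src_;
    std::size_t pos_ = 0;
};
```

Because the whole lexer state here is a single index, saving it is one assignment; the article's lexer has to snapshot the row and line-head pointer as well so that reported positions stay correct after the rewind.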
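(3) Parser
To decide whether an expression is an assignment, the parser has to look one token ahead: if the current token is an identifier and the next token is an equals sign (equal), the expression is an assignment. Assignment is right-associative, so its right-hand side is parsed recursively as a full expression.
Following the grammar above, expression parsing gains a new branch for this case.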
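To make the lookahead and the right-associativity concrete, here is a small sketch that evaluates expressions directly instead of building an AST. It is not the article's Parser: `MiniParser` is a hypothetical helper that works on pre-split tokens, and it only handles `+`, `-`, and assignment (no `*`, `/`, or parentheses).

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical evaluator over pre-split tokens.
struct MiniParser {
    std::vector<std::string> toks;
    std::size_t pos = 0;
    std::map<std::string, int> vars;

    static bool IsIdent(const std::string &t) {
        return !t.empty() &&
               (std::isalpha(static_cast<unsigned char>(t[0])) != 0 || t[0] == '_');
    }

    const std::string &Cur() const {
        static const std::string eof; // empty string marks end of input
        return pos < toks.size() ? toks[pos] : eof;
    }

    // expr : identifier "=" expr | add-expr
    int ParseExpr() {
        // One token of lookahead chooses between the two alternatives.
        if (IsIdent(Cur()) && pos + 1 < toks.size() && toks[pos + 1] == "=") {
            const std::string name = Cur();
            pos += 2;                      // consume identifier and "="
            const int value = ParseExpr(); // recursing on the right gives right-associativity
            vars[name] = value;
            return value;                  // an assignment yields the assigned value
        }
        int left = ParsePrimary();
        while (Cur() == "+" || Cur() == "-") {
            const char op = Cur()[0];
            ++pos;
            const int right = ParsePrimary();
            left = (op == '+') ? left + right : left - right;
        }
        return left;
    }

    // primary : identifier | number
    int ParsePrimary() {
        assert(pos < toks.size());
        const std::string t = toks[pos++];
        return IsIdent(t) ? vars.at(t) : std::stoi(t);
    }
};
```

With the tokens of `a = b = 3 + 1`, the inner assignment `b = 3 + 1` is evaluated first and its value is then stored into `a`, which is exactly the grouping `a = (b = (3 + 1))`.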
Implementation
parser.h
cpp
#pragma once
#include "ast.h"
#include "lexer.h"
#include "sema.h"
class Parser
{
public:
Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
{
Advance();
}
std::shared_ptr<Program> ParseProgram();
private:
std::vector<std::shared_ptr<ASTNode>> ParseDeclStmt();
std::shared_ptr<ASTNode> ParseExprStmt();
std::shared_ptr<ASTNode> ParseAssignExpr(); // new: parses an assignment expression
std::shared_ptr<ASTNode> ParseExpr();
std::shared_ptr<ASTNode> ParseTerm();
std::shared_ptr<ASTNode> ParseFactor();
// token-consuming helpers
// check the type of the current token
bool Expect(TokenType tokenType);
// check the type of the current token and consume it
bool Consume(TokenType tokenType);
// consume the current token unconditionally
bool Advance();
private:
Lexer &lexer;
Sema &sema;
Token token;
};
cpp
#include "parser.h"
#include <cassert>
// prog : stmt*
// stmt : decl-stmt | expr-stmt | null-stmt
// null-stmt : ";"
// decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
// expr-stmt : expr ";"
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
// mult-expr : primary-expr (("*" | "/") primary-expr)*
// primary-expr : identifier | number | "(" expr ")"
// number : [0-9]+
// identifier : [a-zA-Z_][a-zA-Z0-9_]*
// parse the whole program
// stmt : decl-stmt | expr-stmt | null-stmt
std::shared_ptr<Program> Parser::ParseProgram()
{
std::vector<std::shared_ptr<ASTNode>> ExprVec;
while (token.tokenType != TokenType::eof)
{
// a lone ';' is a null statement: just consume the token
// null-stmt
if (token.tokenType == TokenType::semi)
{
Advance();
continue;
}
// decl-stmt
if (token.tokenType == TokenType::kw_int)
{
auto exprs = ParseDeclStmt();
for (auto expr : exprs)
{
ExprVec.push_back(expr);
}
}
else // expr-stmt
{
auto expr = ParseExprStmt();
ExprVec.push_back(expr);
}
}
auto program = std::make_shared<Program>();
program->ExprVec = std::move(ExprVec);
return program;
}
// parse a declaration statement
std::vector<std::shared_ptr<ASTNode>> Parser::ParseDeclStmt()
{
/// int a, b = 3;
/// int a = 3;
Consume(TokenType::kw_int);
CType *baseTy = CType::getIntTy();
/// a , b = 3;
/// a = 3;
std::vector<std::shared_ptr<ASTNode>> astArr;
/// a, b = 3;
/// a = 3;
int i = 0;
while (token.tokenType != TokenType::semi)
{
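if (i++ > 0) // every declarator after the first must be preceded by a comma
{
// keep the side effect outside assert(): when NDEBUG is defined, assert's argument is never evaluated
bool consumed = Consume(TokenType::comma);
assert(consumed);
(void)consumed;
}
/// a declaration with an initializer is split in two: int a = 3; -> int a; a = 3;
// a = 3;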
auto varName = token.content;
auto variableDecl = sema.SemaVariableDeclNode(varName, baseTy); // get a type
astArr.push_back(variableDecl);
Consume(TokenType::indentifier);
// = 3;
if (token.tokenType == TokenType::equal)
{
llvm::StringRef name = varName;
Advance();
// 3;
auto right = ParseExpr();
auto left = sema.SemaVariableAccessNode(name);
auto assign = sema.SemaAssignExprNode(left, right);
astArr.push_back(assign);
}
}
Advance();
return astArr;
}
// parse an expression statement
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
auto expr = ParseExpr();
Consume(TokenType::semi);
return expr;
}
// parse an expression
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
lexer.SaveState(); // save the lexer's current state
bool isAssign = false; // whether this expression is an assignment
Token tmp = token; // scratch token used for the lookahead
// a = b;
if (tmp.tokenType == TokenType::indentifier) // current token is an identifier
{
lexer.NextToken(tmp);
if (tmp.tokenType == TokenType::equal) // next token is '='
{
isAssign = true;
}
}
lexer.RestoreState(); // rewind the lexer to the saved state
if (isAssign) // this is an assignment expression
{
return ParseAssignExpr();
}
// add-expr
std::shared_ptr<ASTNode> left = ParseTerm();
while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
{
OpCode op;
if (token.tokenType == TokenType::plus)
{
op = OpCode::add;
}
else
{
op = OpCode::sub;
}
Advance();
auto right = ParseTerm();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// parse an assignment expression
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
// a = b;
llvm::StringRef text = token.content;
Consume(TokenType::indentifier);
auto expr = sema.SemaVariableAccessNode(text);
Consume(TokenType::equal);
return sema.SemaAssignExprNode(expr, ParseExpr());
}
// parse a term (mult-expr)
std::shared_ptr<ASTNode> Parser::ParseTerm()
{
std::shared_ptr<ASTNode> left = ParseFactor();
while (token.tokenType == TokenType::star || token.tokenType == TokenType::slash)
{
OpCode op;
if (token.tokenType == TokenType::star)
{
op = OpCode::mul;
}
else
{
op = OpCode::div;
}
Advance();
auto right = ParseFactor();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// parse a factor (primary-expr)
std::shared_ptr<ASTNode> Parser::ParseFactor()
{
if (token.tokenType == TokenType::l_parent)
{
Advance();
auto expr = ParseExpr();
assert(Expect(TokenType::r_parent));
Advance();
return expr;
}
else if(token.tokenType == TokenType::indentifier)
{
auto variableAccessExpr = sema.SemaVariableAccessNode(token.content);
Advance();
return variableAccessExpr;
}
else
{
auto factorExpr = sema.SemaNumberExprNode(token.value, token.type);
Advance();
return factorExpr;
}
}
/// token-consuming helpers
bool Parser::Expect(TokenType tokenType)
{
if (token.tokenType == tokenType)
return true;
return false;
}
bool Parser::Consume(TokenType tokenType)
{
if (Expect(tokenType))
{
Advance();
return true;
}
return false;
}
bool Parser::Advance()
{
lexer.NextToken(token);
return true;
}
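(4) Code generation (CodeGen)
The last piece is code generation. The main change is that operations which used to return an lvalue (the variable's address) now return an rvalue (the value loaded from that address).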
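The effect of returning an rvalue can be pictured without LLVM at all. In the sketch below, which is an illustration under stated assumptions rather than the article's CodeGen, an "address" is just an index into a vector, and `SlotMachine`, `Alloca`, `Store`, `Load`, `Assign`, and `Read` are hypothetical names that loosely mirror alloca, store, and load instructions.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for stack slots: an "address" is an index into memory_.
class SlotMachine {
public:
    // Reserve a slot for a named variable (loosely like an alloca).
    std::size_t Alloca(const std::string &name) {
        memory_.push_back(0);
        const std::size_t addr = memory_.size() - 1;
        addr_[name] = addr;
        return addr;
    }

    void Store(int value, std::size_t addr) {
        assert(addr < memory_.size());
        memory_[addr] = value;
    }

    int Load(std::size_t addr) const {
        assert(addr < memory_.size());
        return memory_[addr];
    }

    // Assignment: store into the left-hand address, then load the value back
    // so that the caller receives an rvalue rather than an address.
    int Assign(const std::string &name, int rvalue) {
        const std::size_t addr = addr_.at(name);
        Store(rvalue, addr);
        return Load(addr);
    }

    int Read(const std::string &name) const { return Load(addr_.at(name)); }

private:
    std::vector<int> memory_;
    std::unordered_map<std::string, std::size_t> addr_;
};
```

Because `Assign` returns a plain value, it can be used directly as the right-hand side of another assignment, for example `m.Assign("b", m.Assign("a", 3))`, which is the shape of `b = a = 3`.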
Implementation
codegen.h
cpp
#pragma once
#include "ast.h"
#include "parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/StringMap.h"
// generate code via the visitor pattern
class CodeGen : public Visitor
{
public:
CodeGen(std::shared_ptr<Program> program)
{
module = std::make_shared<llvm::Module>("expr", context);
VisitProgram(program.get());
}
public:
llvm::Value *VisitProgram(Program *p) override;
llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
private:
llvm::LLVMContext context;
llvm::IRBuilder<> irBuilder{context};
std::shared_ptr<llvm::Module> module;
// producing an rvalue requires irBuilder.CreateLoad(), which needs the variable's type,
// so the map stores an (address, type) pair for each variable
llvm::StringMap<std::pair<llvm::Value *, llvm::Type *>> varAddrTyMap;
};
cpp
#include "codegen.h"
#include "llvm/IR/Verifier.h"
llvm::Value *CodeGen::VisitProgram(Program *p)
{
// declare printf
auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
// create the main function
auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
// create the entry basic block of main
llvm::BasicBlock *entryBlock = llvm::BasicBlock::Create(context, "entry", mainFunction);
// emit subsequent instructions into this block
irBuilder.SetInsertPoint(entryBlock);
llvm::Value *lastVal = nullptr;
for (auto expr : p->ExprVec)
{
lastVal = expr->Accept(this);
}
irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastVal});
// emit the return instruction
llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));
llvm::verifyFunction(*mainFunction);
module->print(llvm::outs(), nullptr);
return ret;
}
// this function needs a small adjustment
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
llvm::Type *ty = nullptr;
if (decl->type == CType::getIntTy())
{
ty = irBuilder.getInt32Ty();
}
llvm::Value *varAddr = irBuilder.CreateAlloca(ty, nullptr, decl->name);
varAddrTyMap.insert({decl->name, {varAddr, ty}}); // the mapped value is now an (address, type) pair
return varAddr;
}
llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
std::pair pair = varAddrTyMap[varaccExpr->name];
llvm::Value *varAddr = pair.first; // address of the variable
llvm::Type *ty = pair.second; // type of the variable
if (varaccExpr->type == CType::getIntTy())
{
ty = irBuilder.getInt32Ty();
}
// return an rvalue
return irBuilder.CreateLoad(ty, varAddr, varaccExpr->name);
}
// a = 3; // right value
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
VariableAccessExpr *varAccExpr = (VariableAccessExpr *)assignExpr->left.get();
std::pair pair = varAddrTyMap[varAccExpr->name];
llvm::Value *addr = pair.first;
llvm::Type *ty = pair.second;
llvm::Value *rValue = assignExpr->right->Accept(this);
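// store the right-hand value into the variable's address
irBuilder.CreateStore(rValue, addr);
// load it back and return an rvalue, so that chained assignments such as b = a = 3 work
return irBuilder.CreateLoad(ty, addr, varAccExpr->name);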
}
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
auto left = binaryExpr->left->Accept(this);
auto right = binaryExpr->right->Accept(this);
switch (binaryExpr->op)
{
case OpCode::add:
{
return irBuilder.CreateNSWAdd(left, right, "add"); // NSW = "no signed wrap": signed overflow yields a poison value; the flag does not prevent overflow
}
case OpCode::sub:
{
return irBuilder.CreateNSWSub(left, right, "sub");
}
case OpCode::mul:
{
return irBuilder.CreateNSWMul(left, right, "mul");
}
case OpCode::div:
{
return irBuilder.CreateSDiv(left, right, "div");
}
default:
{
break;
}
}
return nullptr;
}
llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
return irBuilder.getInt32(factorExpr->token.value);
}
(5) Testing the compiler
Generate the IR file
bash
bin/expr test/expr.txt > test/expr.ll
Generated IR
ini
; ModuleID = 'expr'
source_filename = "expr"
@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1
declare i32 @printf(ptr, ...)
define i32 @main() {
entry:
%a = alloca i32, align 4
store i32 4, ptr %a, align 4
%a1 = load i32, ptr %a, align 4
store i32 5, ptr %a, align 4
%a2 = load i32, ptr %a, align 4
%b = alloca i32, align 4
store i32 3, ptr %a, align 4
%a3 = load i32, ptr %a, align 4
store i32 %a3, ptr %b, align 4
%b4 = load i32, ptr %b, align 4
%a5 = load i32, ptr %a, align 4
%b6 = load i32, ptr %b, align 4
%mul = mul nsw i32 %a5, %b6
%sub = sub nsw i32 %mul, 2
%0 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
ret i32 0
}
Run the IR
bash
lli test/expr.ll
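Output
bash
expr value: 7
The result is correct: after int b = a = 3 both a and b hold 3, so a * b - 4 / 2 = 9 - 2 = 7. This confirms that assignment now works.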
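2. Error handling
When a compiler rejects a program, it should say clearly what went wrong and where; this makes debugging far more efficient. The error-handling module exists to provide exactly that.
(1) Test files
expr_01.txt (contains an unknown character, '{'):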
ini
int a = 4;
{
int b = 2;
expr_02.txt (the ';' after the declaration is missing):
ini
int a = 4
a + b * 3;
expr_03.txt (b is used without being declared):
ini
int a = 4;
a + b * 3;
expr_04.txt (a is declared twice):
ini
int a = 4;
int a;
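(2) Diagnostic engine (DiagEngine)
Error reporting is handled by a diagnostic engine. Before writing the DiagEngine itself, we first write a definition file that collects every diagnostic in one place.
Implementation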
diag.inc
cpp
#ifndef DIAG
#define DIAG(ID, KIND, MSG)
#endif
/// lexer
DIAG(err_unknown_char, Error, "unknown char '{0}'")
/// parser
DIAG(err_expected, Error, "expect '{0}' , but get '{1}'")
DIAG(err_redefine, Error, "redefine symbol '{0}'")
DIAG(err_undefine, Error, "undefine symbol '{0}'")
DIAG(err_lvalue, Error, "require left value on the assign operation left side")
#undef DIAG
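A file like diag.inc is an instance of the "X-macro" technique: one list of entries is expanded several times, each time with a different macro definition, to produce an enum and parallel lookup tables that cannot drift apart. The sketch below shows the same idea in a single file. Everything prefixed `Mini` or `mini_` is hypothetical, the list is kept in a macro instead of a separate include file, and `warn_unused` is an invented entry added only to show a second kind.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>

// Hypothetical in-file list standing in for diag.inc.
#define MINI_DIAG_LIST(X) \
    X(err_unknown_char, Error, "unknown char '{0}'") \
    X(err_redefine, Error, "redefine symbol '{0}'") \
    X(warn_unused, Warning, "unused symbol '{0}'")

enum class MiniKind { Error, Warning };

namespace mini_diag {
// First expansion: keep only the ID column to build the enum.
enum Id : unsigned {
#define X(ID, KIND, MSG) ID,
    MINI_DIAG_LIST(X)
#undef X
    count
};
} // namespace mini_diag

// Second expansion: keep only the KIND column.
static const MiniKind kMiniKinds[] = {
#define X(ID, KIND, MSG) MiniKind::KIND,
    MINI_DIAG_LIST(X)
#undef X
};

// Third expansion: keep only the MSG column.
static const char *const kMiniMsgs[] = {
#define X(ID, KIND, MSG) MSG,
    MINI_DIAG_LIST(X)
#undef X
};

static_assert(sizeof(kMiniKinds) / sizeof(kMiniKinds[0]) == mini_diag::count,
              "tables are generated from the same list");

inline MiniKind MiniKindOf(unsigned id) {
    assert(id < mini_diag::count);
    return kMiniKinds[id];
}

inline const char *MiniMsgOf(unsigned id) {
    assert(id < mini_diag::count);
    return kMiniMsgs[id];
}
```

Adding a diagnostic then means adding one line to the list; the enum value and both table entries follow automatically.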
diag_engine.h
cpp
#pragma once
#include <cstdlib> // exit
#include <utility> // std::forward
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/FormatVariadic.h"
namespace diag // namespace holding the enumeration of diagnostic IDs
{
enum
{
#define DIAG(ID, KIND, MSG) ID, // keep only the ID field of each DIAG entry
#include "diag.inc" // the preprocessor pastes the full contents of diag.inc at this point
};
};
class DiagEngine
{
public:
DiagEngine(llvm::SourceMgr &mgr) : mgr(mgr) {}
template <typename... Args> // variadic arguments substituted into the message
void Report(llvm::SMLoc loc, unsigned diagID, Args... args)
{
// diagID selects both the kind and the message format
auto kind = GetDiagKind(diagID);
const char *fmt = GetDiagMsg(diagID);
auto f = llvm::formatv(fmt, std::forward<Args>(args)...).str();
mgr.PrintMessage(loc, kind, f); // SourceMgr::PrintMessage does the actual reporting
if (kind == llvm::SourceMgr::DK_Error)
{
exit(1);
}
}
private:
llvm::SourceMgr::DiagKind GetDiagKind(unsigned id);
const char *GetDiagMsg(unsigned id);
private:
llvm::SourceMgr &mgr; // reporting relies on LLVM's SourceMgr class
};
diag_engine.cc
cpp
#include "diag_engine.h"
static llvm::SourceMgr::DiagKind diag_kind[] =
{
#define DIAG(ID, KIND, MSG) llvm::SourceMgr::DK_##KIND,
#include "diag.inc"
};
static const char* diag_msg[] =
{
#define DIAG(ID, KIND, MSG) MSG,
#include "diag.inc"
};
llvm::SourceMgr::DiagKind DiagEngine::GetDiagKind(unsigned id)
{
return diag_kind[id];
}
const char* DiagEngine::GetDiagMsg(unsigned id)
{
return diag_msg[id];
}
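The message strings use positional placeholders such as `{0}` and `{1}`, which `Report` fills in through `llvm::formatv`. As a rough picture of what that substitution does, here is a deliberately tiny stand-in written with the standard library only. `MiniFormat` and `MiniReport` are hypothetical helpers; they support just `{N}` with string arguments, whereas `llvm::formatv` handles far more (arbitrary types, alignment, format styles), and the `row:col: error:` prefix is an assumed layout, not the exact output of `SourceMgr::PrintMessage`.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper: replace each "{N}" in fmt with args[N].
inline std::string MiniFormat(const std::string &fmt,
                              const std::vector<std::string> &args) {
    std::string out;
    std::size_t i = 0;
    while (i < fmt.size()) {
        if (fmt[i] == '{') {
            std::size_t j = i + 1;
            std::size_t index = 0;
            while (j < fmt.size() && std::isdigit(static_cast<unsigned char>(fmt[j])) != 0) {
                index = index * 10 + static_cast<std::size_t>(fmt[j] - '0');
                ++j;
            }
            // Accept only "{digits}" whose index has a matching argument.
            if (j > i + 1 && j < fmt.size() && fmt[j] == '}' && index < args.size()) {
                out += args[index];
                i = j + 1;
                continue;
            }
        }
        out += fmt[i++]; // anything else is copied through unchanged
    }
    return out;
}

// Hypothetical report line in an assumed "row:col: error: message" layout.
inline std::string MiniReport(int row, int col, const std::string &fmt,
                              const std::vector<std::string> &args) {
    assert(row >= 1 && col >= 1); // positions are 1-based
    return std::to_string(row) + ":" + std::to_string(col) + ": error: " +
           MiniFormat(fmt, args);
}
```

For example, `MiniFormat("expect '{0}' , but get '{1}'", {";", "a"})` produces `expect ';' , but get 'a'`.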
(3) Lexer
With the diagnostic engine in place, the next step is to wire it into the **lexer**, the **parser**, and the **semantic analyzer (Sema)**.
Implementation
lexer.h
cpp
#pragma once
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"
// char stream -> token
enum class TokenType
{
// unknown, // removed: an unknown character is now reported as an error
number, // [0-9]+
indentifier, // identifier (variable name)
kw_int, // int
plus, // +
minus, // -
star, // *
slash, // /
equal, // =
l_parent, // (
r_parent, // )
semi, // ;
comma, // ,
eof // end of file
};
class Token
{
public:
TokenType tokenType; // kind of token
int row, col;
int value; // for number
// The SourceMgr inside DiagEngine needs a source location to report an error,
// and that location is built with llvm::SMLoc::getFromPointer(),
// which takes a pointer into the source buffer. So content is split into ptr and length;
// the text can still be recovered with llvm::StringRef(ptr, length).
// llvm::StringRef content;
const char *ptr; // start of the token in the source buffer
int length;
CType *type; // for built-in type
public:
void Dump();
static llvm::StringRef GetSpellingText(TokenType tokenType);
};
class Lexer
{
private:
DiagEngine &diagEngine;
llvm::SourceMgr &mgr;
public:
Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr) : diagEngine(diagEngine), mgr(mgr)
{
unsigned id = mgr.getMainFileID();
llvm::StringRef buf = mgr.getMemoryBuffer(id)->getBuffer();
BufPtr = buf.begin();
LineHeadPtr = buf.begin();
BufEnd = buf.end();
row = 1;
}
void NextToken(Token &token);
void SaveState();
void RestoreState();
DiagEngine &GetDiagEngine() const
{
return diagEngine;
}
private:
struct State
{
const char *BufPtr;
const char *LineHeadPtr;
const char *BufEnd;
int row;
};
private:
const char *BufPtr;
const char *LineHeadPtr;
const char *BufEnd;
int row;
State state;
};
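Storing a raw pointer plus a length is enough to recover both the token text and its position, because the pointer can be compared against the start of the buffer. The following sketch shows that relationship with the standard library only. `MiniToken`, `TokenText`, and `LineAndColumn` are hypothetical names, and the linear scan is just the simplest way to derive a position, not a description of how `llvm::SourceMgr` does it internally.

```cpp
#include <cassert>
#include <cstddef>
#include <string_view>
#include <utility>

// Hypothetical token that records only where it starts and how long it is.
struct MiniToken {
    const char *ptr;
    int length;
};

// Recover the token's text from (ptr, length).
inline std::string_view TokenText(const MiniToken &t) {
    return std::string_view(t.ptr, static_cast<std::size_t>(t.length));
}

// Compute the 1-based (row, col) of loc, which must point into buf.
inline std::pair<int, int> LineAndColumn(std::string_view buf, const char *loc) {
    assert(loc >= buf.data() && loc <= buf.data() + buf.size());
    int row = 1;
    const char *lineHead = buf.data();
    for (const char *p = buf.data(); p < loc; ++p) {
        if (*p == '\n') {
            ++row;
            lineHead = p + 1; // the next line starts right after the newline
        }
    }
    return {row, static_cast<int>(loc - lineHead) + 1};
}
```

The column computation is the same arithmetic the lexer uses for `token.col` (`BufPtr - LineHeadPtr + 1`).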
cpp
#include "lexer.h"
void Token::Dump()
{
llvm::StringRef text(ptr, length);
llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}
// number, // [0-9]+
// indentifier,// identifier (variable name)
// kw_int, // int
// plus, // +
// minus, // -
// star, // *
// slash, // /
// equal, // =
// l_parent, // (
// r_parent, // )
// semi, // ;
// comma, // ,
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
switch (tokenType)
{
case TokenType::kw_int:
return "int";
case TokenType::plus:
return "+";
case TokenType::minus:
return "-";
case TokenType::star:
return "*";
case TokenType::slash:
return "/";
case TokenType::equal:
return "=";
case TokenType::l_parent:
return "(";
case TokenType::r_parent:
return ")";
case TokenType::semi:
return ";";
case TokenType::comma:
return ",";
case TokenType::number:
return "number";
case TokenType::indentifier:
return "indentifier";
default:
llvm::llvm_unreachable_internal(); // logically unreachable; getting here means there is a bug in the compiler itself
}
}
bool IsWhiteSpace(char ch)
{
return ch == ' ' || ch == '\r' || ch == '\n';
}
bool IsDigit(char ch)
{
return ch >= '0' && ch <= '9';
}
bool IsLetter(char ch)
{
// a-z, A-Z, _
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch == '_');
}
void Lexer::NextToken(Token &token)
{
// skip whitespace
while (IsWhiteSpace(*BufPtr))
{
if (*BufPtr == '\n')
{
row += 1;
LineHeadPtr = BufPtr + 1;
}
BufPtr++;
}
token.row = row;
token.col = BufPtr - LineHeadPtr + 1;
// check for end of input
if (BufPtr >= BufEnd)
{
token.tokenType = TokenType::eof;
return;
}
token.ptr = BufPtr;
token.length = 0;
// number literal
if (IsDigit(*BufPtr))
{
int val = 0;
while (IsDigit(*BufPtr))
{
val = val * 10 + *BufPtr++ - '0';
token.length++;
}
token.value = val;
token.tokenType = TokenType::number;
token.type = CType::getIntTy();
}
else if (IsLetter(*BufPtr)) // identifier or keyword
{
while (IsLetter(*BufPtr) || IsDigit(*BufPtr))
{
BufPtr++;
}
token.tokenType = TokenType::indentifier;
token.length = BufPtr - token.ptr;
llvm::StringRef text(token.ptr, BufPtr - token.ptr);
if (text == "int")
{
token.tokenType = TokenType::kw_int;
}
}
else // punctuation
{
switch (*BufPtr)
{
case '+':
{
token.tokenType = TokenType::plus;
token.length = 1;
BufPtr++;
break;
}
case '-':
{
token.tokenType = TokenType::minus;
token.length = 1;
BufPtr++;
break;
}
case '*':
{
token.tokenType = TokenType::star;
token.length = 1;
BufPtr++;
break;
}
case '/':
{
token.tokenType = TokenType::slash;
token.length = 1;
BufPtr++;
break;
}
case '=':
{
token.tokenType = TokenType::equal;
token.length = 1;
BufPtr++;
break;
}
case '(':
{
token.tokenType = TokenType::l_parent;
token.length = 1;
BufPtr++;
break;
}
case ')':
{
token.tokenType = TokenType::r_parent;
token.length = 1;
BufPtr++;
break;
}
case ';':
{
token.tokenType = TokenType::semi;
token.length = 1;
BufPtr++;
break;
}
case ',':
{
token.tokenType = TokenType::comma;
token.length = 1;
BufPtr++;
break;
}
default:
{
diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
}
}
}
}
void Lexer::SaveState()
{
state.LineHeadPtr = LineHeadPtr;
state.BufPtr = BufPtr;
state.BufEnd = BufEnd;
state.row = row;
}
void Lexer::RestoreState()
{
LineHeadPtr = state.LineHeadPtr;
BufPtr = state.BufPtr;
BufEnd = state.BufEnd;
row = state.row;
}
(4) Parser
For the parser, the main changes are to the abstract syntax tree (AST) nodes and to the token-consuming helper functions.
Implementation
ast.h
cpp
#pragma once
#include <vector>
#include <memory>
#include "llvm/IR/Value.h"
#include "type.h"
#include "lexer.h"
enum class OpCode
{
add,
sub,
mul,
div
};
// forward declarations
class Program;
class Expr;
class VariableDecl;
class VariableAccessExpr;
class BinaryExpr;
class AssignExpr;
class NumberExpr;
// visitor pattern
class Visitor
{
public:
virtual ~Visitor() {}
virtual llvm::Value *VisitProgram(Program *p) = 0;
virtual llvm::Value *VisitVariableDecl(VariableDecl* decl) = 0;
virtual llvm::Value *VisitVariableAccessExpr(VariableAccessExpr* varaccExpr) = 0;
virtual llvm::Value *VisitAssignExpr(AssignExpr* assignExpr) = 0;
virtual llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) = 0;
virtual llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) = 0;
};
// common base class of all AST nodes
class ASTNode
{
public:
enum Kind
{
ND_VariableDecl,
ND_BinaryExpr,
ND_NumberExpr,
ND_VariableAccessExpr,
ND_AssignExpr
};
private:
const Kind kind;
public:
ASTNode(Kind kind) : kind(kind) {}
virtual ~ASTNode() {}
virtual llvm::Value *Accept(Visitor *v) { return nullptr; } // virtual dispatch routes each node to its own Visit method
const Kind getKind() const { return kind; }
public:
CType *type;
Token token;
};
// variable declaration node
class VariableDecl : public ASTNode
{
public:
VariableDecl() : ASTNode(ND_VariableDecl) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitVariableDecl(this);
}
static bool classof(const ASTNode *node) // lets llvm::isa / llvm::dyn_cast check whether the downcast is valid
{
return node->getKind() == ND_VariableDecl;
}
};
// binary expression node
class BinaryExpr : public ASTNode
{
public:
BinaryExpr() : ASTNode(ND_BinaryExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitBinaryExpr(this);
}
static bool classof(const ASTNode *node) // lets llvm::isa / llvm::dyn_cast check whether the downcast is valid
{
return node->getKind() == ND_BinaryExpr;
}
public:
OpCode op;
std::shared_ptr<ASTNode> left;
std::shared_ptr<ASTNode> right;
};
// assignment expression node
class AssignExpr : public ASTNode
{
public:
AssignExpr() : ASTNode(ND_AssignExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitAssignExpr(this);
}
static bool classof(const ASTNode *node) // lets llvm::isa / llvm::dyn_cast check whether the downcast is valid
{
return node->getKind() == ND_AssignExpr;
}
public:
std::shared_ptr<ASTNode> left;
std::shared_ptr<ASTNode> right;
};
// number expression node
class NumberExpr : public ASTNode
{
public:
NumberExpr() : ASTNode(ND_NumberExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitNumberExpr(this);
}
static bool classof(const ASTNode *node) // lets llvm::isa / llvm::dyn_cast check whether the downcast is valid
{
return node->getKind() == ND_NumberExpr;
}
};
// variable access node
class VariableAccessExpr : public ASTNode // variable expression
{
public:
VariableAccessExpr() : ASTNode(ND_VariableAccessExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitVariableAccessExpr(this);
}
static bool classof(const ASTNode *node) // lets llvm::isa / llvm::dyn_cast check whether the downcast is valid
{
return node->getKind() == ND_VariableAccessExpr;
}
};
// the whole program
class Program
{
public:
std::vector<std::shared_ptr<ASTNode>> ExprVec;
};
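The `Kind` enum together with the static `classof` functions is the hook that LLVM-style RTTI builds on: a cast helper asks the target class whether a given node is one of its own, then performs a `static_cast`. The sketch below imitates that mechanism in miniature. `MiniNode`, `MiniNumber`, `MiniAssign`, `mini_isa`, and `mini_dyn_cast` are hypothetical simplifications; the real `llvm::isa` and `llvm::dyn_cast` in `llvm/Support/Casting.h` are considerably more general.

```cpp
#include <cassert>

// Hypothetical node hierarchy with a kind tag, mirroring the ASTNode design.
struct MiniNode {
    enum Kind { ND_Number, ND_Assign };
    explicit MiniNode(Kind k) : kind(k) {}
    virtual ~MiniNode() = default;
    Kind getKind() const { return kind; }

private:
    const Kind kind;
};

struct MiniNumber : MiniNode {
    explicit MiniNumber(int v) : MiniNode(ND_Number), value(v) {}
    static bool classof(const MiniNode *n) { return n->getKind() == ND_Number; }
    int value;
};

struct MiniAssign : MiniNode {
    MiniAssign() : MiniNode(ND_Assign) {}
    static bool classof(const MiniNode *n) { return n->getKind() == ND_Assign; }
};

// Simplified stand-in for llvm::isa: is this node a To?
template <typename To>
bool mini_isa(const MiniNode *n) {
    assert(n && "null node");
    return To::classof(n);
}

// Simplified stand-in for llvm::dyn_cast: checked downcast, nullptr on mismatch.
template <typename To>
To *mini_dyn_cast(MiniNode *n) {
    assert(n && "null node");
    return To::classof(n) ? static_cast<To *>(n) : nullptr;
}
```

A checked cast of this kind is a safer alternative to a C-style cast on the left-hand side of an assignment, since a mismatch yields `nullptr` that can be turned into a diagnostic.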
parser.h
cpp
#pragma once
#include "ast.h"
#include "lexer.h"
#include "sema.h"
class Parser
{
public:
Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
{
Advance();
}
std::shared_ptr<Program> ParseProgram();
private:
std::vector<std::shared_ptr<ASTNode>> ParseDeclStmt();
std::shared_ptr<ASTNode> ParseExprStmt();
std::shared_ptr<ASTNode> ParseAssignExpr();
std::shared_ptr<ASTNode> ParseExpr();
std::shared_ptr<ASTNode> ParseTerm();
std::shared_ptr<ASTNode> ParseFactor();
// token-consuming helpers
// check the type of the current token
bool Expect(TokenType tokenType);
// check the type of the current token and consume it
bool Consume(TokenType tokenType);
// consume the current token unconditionally
bool Advance();
DiagEngine &GetDiagEngine() const
{
return lexer.GetDiagEngine();
}
private:
Lexer &lexer;
Sema &sema;
Token token;
};
cpp
#include "parser.h"
#include <cassert>
// prog : stmt*
// stmt : decl-stmt | expr-stmt | null-stmt
// null-stmt : ";"
// decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
// expr-stmt : expr ";"
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
// mult-expr : primary-expr (("*" | "/") primary-expr)*
// primary-expr : identifier | number | "(" expr ")"
// number : [0-9]+
// identifier : [a-zA-Z_][a-zA-Z0-9_]*
// parse the whole program
// stmt : decl-stmt | expr-stmt | null-stmt
std::shared_ptr<Program> Parser::ParseProgram()
{
std::vector<std::shared_ptr<ASTNode>> ExprVec;
while (token.tokenType != TokenType::eof)
{
// a lone ';' is a null statement: just consume the token
// null-stmt
if (token.tokenType == TokenType::semi)
{
Advance();
continue;
}
// decl-stmt
if (token.tokenType == TokenType::kw_int)
{
auto exprs = ParseDeclStmt();
for (auto expr : exprs)
{
ExprVec.push_back(expr);
}
}
else // expr-stmt
{
auto expr = ParseExprStmt();
ExprVec.push_back(expr);
}
}
auto program = std::make_shared<Program>();
program->ExprVec = std::move(ExprVec);
return program;
}
// parse a declaration statement
std::vector<std::shared_ptr<ASTNode>> Parser::ParseDeclStmt()
{
/// int a, b = 3;
/// int a = 3;
Consume(TokenType::kw_int);
CType *baseTy = CType::getIntTy();
/// a , b = 3;
/// a = 3;
std::vector<std::shared_ptr<ASTNode>> astArr;
/// a, b = 3;
/// a = 3;
int i = 0;
while (token.tokenType != TokenType::semi)
{
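if (i++ > 0) // every declarator after the first must be preceded by a comma
{
Consume(TokenType::comma); // reports err_expected on mismatch; kept out of assert() so it still runs when NDEBUG is defined
}
/// a declaration with an initializer is split in two: int a = 3; -> int a; a = 3;
// a = 3;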
auto variableDecl = sema.SemaVariableDeclNode(token, baseTy); // get a type
astArr.push_back(variableDecl);
Token tmp = token;
Consume(TokenType::indentifier);
// = 3;
if (token.tokenType == TokenType::equal)
{
Token opToken = token;
Advance();
// 3;
auto right = ParseExpr();
auto left = sema.SemaVariableAccessNode(tmp);
auto assign = sema.SemaAssignExprNode(left, right, opToken);
astArr.push_back(assign);
}
}
Advance();
return astArr;
}
// parse an expression statement
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
auto expr = ParseExpr();
Consume(TokenType::semi);
return expr;
}
// parse an expression
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
lexer.SaveState();
bool isAssign = false;
Token tmp = token;
// a = b;
if (tmp.tokenType == TokenType::indentifier)
{
lexer.NextToken(tmp);
if (tmp.tokenType == TokenType::equal)
{
isAssign = true;
}
}
lexer.RestoreState();
if (isAssign)
{
return ParseAssignExpr();
}
// add-expr
std::shared_ptr<ASTNode> left = ParseTerm();
while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
{
OpCode op;
if (token.tokenType == TokenType::plus)
{
op = OpCode::add;
}
else
{
op = OpCode::sub;
}
Advance();
auto right = ParseTerm();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// parse an assignment expression
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
// a = b;
Token tmp = token;
Consume(TokenType::indentifier);
auto expr = sema.SemaVariableAccessNode(tmp);
Token opToken = token;
Consume(TokenType::equal);
return sema.SemaAssignExprNode(expr, ParseExpr(), opToken);
}
// parse a term (mult-expr)
std::shared_ptr<ASTNode> Parser::ParseTerm()
{
std::shared_ptr<ASTNode> left = ParseFactor();
while (token.tokenType == TokenType::star || token.tokenType == TokenType::slash)
{
OpCode op;
if (token.tokenType == TokenType::star)
{
op = OpCode::mul;
}
else
{
op = OpCode::div;
}
Advance();
auto right = ParseFactor();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// parse a factor (primary-expr)
std::shared_ptr<ASTNode> Parser::ParseFactor()
{
if (token.tokenType == TokenType::l_parent)
{
Advance();
auto expr = ParseExpr();
Expect(TokenType::r_parent); // reports an error on mismatch; not wrapped in assert() so the check survives NDEBUG builds
Advance();
return expr;
}
else if (token.tokenType == TokenType::indentifier)
{
auto variableAccessExpr = sema.SemaVariableAccessNode(token);
Advance();
return variableAccessExpr;
}
else
{
Expect(TokenType::number);
auto factorExpr = sema.SemaNumberExprNode(token, token.type);
Advance();
return factorExpr;
}
}
/// token-consuming helpers
bool Parser::Expect(TokenType tokenType)
{
if (token.tokenType == tokenType)
return true;
// a token of any other type means the source does not match the grammar, so report an error right away
GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
diag::err_expected,
Token::GetSpellingText(tokenType),
llvm::StringRef(token.ptr, token.length));
return false;
}
bool Parser::Consume(TokenType tokenType)
{
if (Expect(tokenType))
{
Advance();
return true;
}
return false;
}
bool Parser::Advance()
{
lexer.NextToken(token);
return true;
}
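The behaviour of `Expect` and `Consume` can be tried out in isolation. The sketch below is a hypothetical stand-in, not the article's Parser: `MiniCursor` works on pre-split string tokens, and it throws `std::runtime_error` where the real code calls `DiagEngine::Report` and exits, purely so that the failure can be observed in a test. The message text follows the `err_expected` format from diag.inc.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical token cursor; throws instead of exiting so it can be tested.
class MiniCursor {
public:
    explicit MiniCursor(std::vector<std::string> toks) : toks_(std::move(toks)) {}

    const std::string &Cur() const {
        static const std::string eof = "<eof>";
        return pos_ < toks_.size() ? toks_[pos_] : eof;
    }

    // Check the current token; a mismatch is a grammar violation.
    void Expect(const std::string &want) const {
        if (Cur() != want)
            throw std::runtime_error("expect '" + want + "' , but get '" + Cur() + "'");
    }

    // Check the current token and then move past it.
    void Consume(const std::string &want) {
        Expect(want);
        Advance();
    }

    // Move past the current token unconditionally.
    void Advance() {
        if (pos_ < toks_.size()) ++pos_;
        assert(pos_ <= toks_.size());
    }

private:
    std::vector<std::string> toks_;
    std::size_t pos_ = 0;
};
```

Feeding it the tokens `int a = 4 a` and asking for a `;` after the `4` raises `expect ';' , but get 'a'`, the kind of message a missing semicolon should produce.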
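Because the AST node definitions changed, and the AST is traversed through the visitor pattern, the printing visitor that implements the Visitor interface has to be updated as well.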
printVisitor.h
cpp
#pragma once
#include "ast.h"
#include "parser.h"
class PrintVisitor : public Visitor
{
public:
PrintVisitor(std::shared_ptr<Program> program);
public:
llvm::Value *VisitProgram(Program *p) override;
llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
};
cpp
#include "printVisitor.h"
PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
VisitProgram(program.get());
}
llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
for (auto &expr : p->ExprVec)
{
expr->Accept(this);
llvm::outs() << "\n";
}
return nullptr;
}
llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
if(decl->type == CType::getIntTy())
{
llvm::outs() << "int " << llvm::StringRef(decl->token.ptr, decl->token.length);
}
return nullptr;
}
llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
llvm::outs() << llvm::StringRef(varaccExpr->token.ptr, varaccExpr->token.length);
return nullptr;
}
llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
assignExpr->left->Accept(this);
llvm::outs() << " = ";
assignExpr->right->Accept(this);
return nullptr;
}
llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
// in-order traversal: left operand, operator, right operand
binaryExpr->left->Accept(this);
switch (binaryExpr->op)
{
case OpCode::add:
{
llvm::outs() << " + ";
break;
}
case OpCode::sub:
{
llvm::outs() << " - ";
break;
}
case OpCode::mul:
{
llvm::outs() << " * ";
break;
}
case OpCode::div:
{
llvm::outs() << " / ";
break;
}
default:
{
break;
}
}
binaryExpr->right->Accept(this);
return nullptr;
}
llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
llvm::outs() << llvm::StringRef(factorExpr->token.ptr, factorExpr->token.length);
return nullptr;
}
(5) Semantic Analyzer (Sema)
Implementation
sema.h
cpp
#pragma once
#include "scope.h"
#include "ast.h"
class Sema // semantic analysis
{
public:
Sema(DiagEngine &diagEngine) : diagEngine(diagEngine) {}
std::shared_ptr<ASTNode> SemaVariableDeclNode(Token token, CType* ty);
std::shared_ptr<ASTNode> SemaVariableAccessNode(Token token);
std::shared_ptr<ASTNode> SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token);
std::shared_ptr<ASTNode> SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op);
std::shared_ptr<ASTNode> SemaNumberExprNode(Token token, CType* ty);
DiagEngine &GetDiaEngine() const
{
return diagEngine;
}
private:
Scope scope;
DiagEngine &diagEngine;
};
cpp
#include "sema.h"
#include <cassert>
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Casting.h"
// Variable declaration node
std::shared_ptr<ASTNode> Sema::SemaVariableDeclNode(Token token, CType* ty)
{
llvm::StringRef text(token.ptr, token.length);
auto symbol = scope.FindVarSymbolInCurEnv(text);
if (symbol) // already declared in the current scope: report a redefinition
{
GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
diag::err_redefine,
llvm::StringRef(token.ptr, token.length));
}
scope.AddVarSymbol(SymbolKind::LocalVariable, ty, text); // add to the symbol table
auto variableDecl = std::make_shared<VariableDecl>();
variableDecl->token = token;
variableDecl->type = ty;
return variableDecl;
}
std::shared_ptr<ASTNode> Sema::SemaVariableAccessNode(Token token)
{
llvm::StringRef text(token.ptr, token.length);
auto symbol = scope.FindVarSymbol(text);
// FindVarSymbolInCurEnv would be wrong here: a variable use must also see enclosing scopes.
if (symbol == nullptr)
{
GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
diag::err_undefine,
llvm::StringRef(token.ptr, token.length));
}
auto varAcc = std::make_shared<VariableAccessExpr>();
varAcc->token = token;
// Guard the dereference in case Report returns instead of aborting on an error.
varAcc->type = symbol ? symbol->getTy() : nullptr;
return varAcc;
}
std::shared_ptr<ASTNode> Sema::SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token)
{
assert(left && right);
if (!llvm::isa<VariableAccessExpr>(left.get()))
{
diagEngine.Report(llvm::SMLoc::getFromPointer(left->token.ptr), diag::err_lvalue);
}
auto assign = std::make_shared<AssignExpr>();
assign->left = left;
assign->right = right;
return assign;
}
std::shared_ptr<ASTNode> Sema::SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op)
{
auto binaryExpr = std::make_shared<BinaryExpr>();
binaryExpr->left = left;
binaryExpr->right = right;
binaryExpr->op = op;
return binaryExpr;
}
std::shared_ptr<ASTNode> Sema::SemaNumberExprNode(Token token, CType* ty)
{
auto factorExpr = std::make_shared<NumberExpr>();
factorExpr->token = token;
factorExpr->type = ty;
return factorExpr;
}
(6) Testing the Error Handling
Generating the IR files
expr_01.txt
bash
bin/expr test/expr_01.txt > test/expr_01.ll
expr_02.txt
bash
bin/expr test/expr_02.txt > test/expr_02.ll
expr_03.txt
bash
bin/expr test/expr_03.txt > test/expr_03.ll
expr_04.txt
bash
bin/expr test/expr_04.txt > test/expr_04.ll
Diagnostic output
expr_01.txt
bash
test/expr.txt:2:1: error: unknown char '{'
{
^
expr_02.txt
bash
test/expr.txt:2:1: error: expect ',' , but get 'a'
a + b * 3;
expr_03.txt
bash
test/expr.txt:2:5: error: undefine symbol 'b'
a + b * 3;
expr_04.txt
bash
test/expr.txt:2:5: error: redefine symbol 'a'
int a;
^
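The last diagnostic, err_lvalue, can only fire for an assignment expression whose left-hand side is not an lvalue. To trigger it, the left-hand side would have to be either a constant or an rvalue such as the result of std::move(). The latter is not supported yet, and the former never reaches Sema because the grammar requires an identifier on the left of "=", so this case cannot be tested through a source file for now.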
3. Unit Testing
The unit tests use GoogleTest.
First, make a small change to CMakeLists.txt.
CMakeLists.txt
ini
cmake_minimum_required(VERSION 3.18)
project(errhandle_unittest)
set(CMAKE_C_STANDARD 99)
set(CMAKE_CXX_STANDARD 17)
SET(EXECUTABLE_OUTPUT_PATH ${CMAKE_SOURCE_DIR}/bin)
find_package(LLVM REQUIRED CONFIG)
list(APPEND CMAKE_MODULE_PATH ${LLVM_CMAKE_DIR})
include(AddLLVM)
include_directories("${LLVM_BINARY_DIR}/include" "${LLVM_INCLUDE_DIR}")
add_definitions(${LLVM_DEFINITIONS})
if (NOT ${LLVM_ENABLE_RTTI})
# For non-MSVC compilers
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")
endif()
set(LLVM_LINK_COMPONENTS Support Core)
add_llvm_executable(${PROJECT_NAME} main.cc lexer.cc parser.cc printVisitor.cc codegen.cc type.cc scope.cc sema.cc diag_engine.cc)
add_subdirectory(unittest) # add this line at the end so that CMake also configures the unit-test targets
Create a corresponding CMakeLists.txt in the unittest folder as well.
unittest/CMakeLists.txt
ini
include(FetchContent)
FetchContent_Declare(
googletest
URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip
)
# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
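add_subdirectory(lexer) # name this after the module you are actually testing
When CMake configures the project, FetchContent also downloads GoogleTest, which is then built together with the project.
Inside the unittest folder, create one more folder; here I use lexer as the example.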
testset/lexer_01.txt
ini
int a = 3, b = 5;
a + 3 * b - 6;
unittest/lexer/CMakeLists.txt
ini
enable_testing()
add_executable(
lexer_test
lexer_test.cc
../../lexer.cc
../../type.cc
../../diag_engine.cc
)
llvm_map_components_to_libnames(llvm_all Support Core)
target_link_libraries(
lexer_test
GTest::gtest_main
${llvm_all}
)
include(GoogleTest)
gtest_discover_tests(lexer_test)
unittest/lexer/lexer_test.cc
cpp
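#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "../../lexer.h"
class LexerTest : public ::testing::Test
{
public:
void SetUp() override
{
// Not static: ownership of the buffer moves into mgr below, so each test needs a fresh one.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> buf = llvm::MemoryBuffer::getFile("../testset/lexer_01.txt");
// A fatal failure in SetUp skips the test body, so lexer is never used while null.
ASSERT_TRUE(static_cast<bool>(buf)) << "can't open file!";
diagEngine = std::make_unique<DiagEngine>(mgr);
mgr.AddNewSourceBuffer(std::move(*buf), llvm::SMLoc());
lexer = std::make_unique<Lexer>(*diagEngine, mgr);
}
// mgr owns the source buffer and must outlive the lexer, so these are fixture
// members rather than locals of SetUp. Members are destroyed in reverse
// declaration order: lexer first, mgr last.
llvm::SourceMgr mgr;
std::unique_ptr<DiagEngine> diagEngine;
std::unique_ptr<Lexer> lexer;
};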
/*
int a = 3, b = 5;
a + 3 * b - 6;
*/
TEST_F(LexerTest, NextToken)
{
std::vector<Token> expectedVec; // expected tokens
std::vector<Token> curVec; // tokens actually produced by the lexer
expectedVec.push_back(Token{TokenType::kw_int, 1, 1});
expectedVec.push_back(Token{TokenType::indentifier, 1, 5});
expectedVec.push_back(Token{TokenType::equal, 1, 7});
expectedVec.push_back(Token{TokenType::number, 1, 9});
expectedVec.push_back(Token{TokenType::comma, 1, 10});
expectedVec.push_back(Token{TokenType::indentifier, 1, 12});
expectedVec.push_back(Token{TokenType::equal, 1, 14});
expectedVec.push_back(Token{TokenType::number, 1, 16});
expectedVec.push_back(Token{TokenType::semi, 1, 17});
expectedVec.push_back(Token{TokenType::indentifier, 2, 1});
expectedVec.push_back(Token{TokenType::plus, 2, 3});
expectedVec.push_back(Token{TokenType::number, 2, 5});
expectedVec.push_back(Token{TokenType::star, 2, 7});
expectedVec.push_back(Token{TokenType::indentifier, 2, 9});
expectedVec.push_back(Token{TokenType::minus, 2, 11});
expectedVec.push_back(Token{TokenType::number, 2, 13});
expectedVec.push_back(Token{TokenType::semi, 2, 14});
Token token;
while(true)
{
lexer->NextToken(token);
if(token.tokenType == TokenType::eof)
{
break;
}
curVec.push_back(token);
}
ASSERT_EQ(expectedVec.size(), curVec.size());
for(size_t i = 0; i < expectedVec.size(); i++)
{
const auto &expected_tok = expectedVec[i];
const auto &cur_tok = curVec[i];
EXPECT_EQ(expected_tok.tokenType, cur_tok.tokenType);
EXPECT_EQ(expected_tok.row, cur_tok.row);
EXPECT_EQ(expected_tok.col, cur_tok.col);
}
}
Testing the Lexer
Generate the build files
bash
cmake -GNinja
After building, run the resulting test executable lexer_test
bash
lexer_test
Output
bash
Running main() from build/_deps/googletest-src/googletest/src/gtest_main.cc
[==========] Running 1 test from 1 test suite.
[----------] Global test environment set-up.
[----------] 1 test from LexerTest
[ RUN ] LexerTest.NextToken
[ OK ] LexerTest.NextToken (2 ms)
[----------] 1 test from LexerTest (2 ms total)
[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (2 ms total)
[ PASSED ] 1 test.
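🎉 With that, we have added support for assignment expressions and strengthened the compiler's error handling. With the diagnostic engine in place, the compiler now reports clearer error messages, which improves the development experience. From here, we can keep extending the language features and improve test coverage.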