1、实现 break 语句和 continue 语句功能
1.1 测试文件
expr.txt
int a = 3;
int b = 5;
for(int i = 0; i < 4; i = i+1)
{
if ( a >= 6 )
continue;
a = a + i;
if ( b > 10 )
break;
b = b + (2 * i);
}
a * (b + a -3) - 2;
1.2 词法分析器 (Lexer)
对词法分析器新增 `break`、`continue` 两种关键字类型 (TokenType::kw_break、TokenType::kw_continue) 以及对应的识别逻辑。
文法定义
ebnf.txt
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt | for-stmt | break-stmt | continue-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
for-stmt : "for" "(" expr? ";" expr? ";" expr? ")" stmt
         | "for" "(" decl-stmt expr? ";" expr? ")" stmt
block-stmt: "{" stmt* "}"
break-stmt: "break" ";"
continue-stmt: "continue" ";"
expr-stmt : expr ";"
expr : assign-expr | equal-expr
assign-expr: identifier "=" expr
equal-expr : relational-expr (("==" | "!=") relational-expr)*
relational-expr: add-expr (("<"|">"|"<="|">=") add-expr)*
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number: ([0-9])+
identifier : [a-zA-Z_][a-zA-Z0-9_]*
实现代码
lexer.h
#pragma once
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"
// Token kinds produced by the lexer (char stream -> token).
// NOTE: the enumerator spelling `indentifier` is used as-is throughout the
// lexer and parser, so it is kept despite the typo.
enum class TokenType
{
number, // integer literal: [0-9]+
indentifier, // identifier / variable name: [a-zA-Z_][a-zA-Z0-9_]*
kw_int, // int
kw_if, // if
kw_else, // else
kw_for, // for
kw_break, // break
kw_continue, // continue
plus, // +
minus, // -
star, // *
slash, // /
equal, // =
equal_equal, // ==
not_equal, // !=
less, // <
less_equal, // <=
greater, // >
greater_equal, // >=
l_parent, // (
r_parent, // )
l_brace, // {
r_brace, // }
semi, // ;
comma, // ,
eof // end of file
};
/// A single lexical token.
///
/// Every field has a default value. A default-constructed Token, or a copy of
/// a token whose kind does not fill every field (e.g. `value`/`type` for an
/// identifier, `ptr` for eof), never exposes indeterminate data.
class Token
{
public:
    TokenType tokenType{TokenType::eof}; // kind of the token
    int row{0};                          // 1-based line number
    int col{0};                          // 1-based column number
    int value{0};                        // numeric value; only meaningful for TokenType::number
    const char *ptr{nullptr};            // start of the token text in the source buffer
    int length{0};                       // length of the token text in bytes
    CType *type{nullptr};                // built-in type; set by the lexer for number tokens

public:
    /// Prints the token text and its position to llvm::outs().
    void Dump();
    /// Returns the canonical spelling of a token kind (used in diagnostics).
    static llvm::StringRef GetSpellingText(TokenType tokenType);
};
/// Hand-written lexer over the main buffer of an llvm::SourceMgr.
///
/// NextToken() produces one token per call. SaveState()/RestoreState() keep a
/// single checkpoint that the parser uses for one-token look-ahead; a second
/// SaveState() overwrites the first.
class Lexer
{
private:
    DiagEngine &diagEngine;
    llvm::SourceMgr &mgr;

public:
    Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr) : diagEngine(diagEngine), mgr(mgr)
    {
        // Scan the main file from its first character, starting at line 1.
        const llvm::StringRef source = mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
        BufPtr = source.begin();
        LineHeadPtr = BufPtr;
        BufEnd = source.end();
        row = 1;
    }

    /// Scans the next token into `token`.
    void NextToken(Token &token);
    /// Remembers the current scanning position.
    void SaveState();
    /// Returns to the position recorded by the last SaveState().
    void RestoreState();

    DiagEngine &GetDiagEngine() const { return diagEngine; }

private:
    // Snapshot of the scanning position, used by SaveState()/RestoreState().
    struct State
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

private:
    const char *BufPtr;      // next character to scan
    const char *LineHeadPtr; // first character of the current line (for column numbers)
    const char *BufEnd;      // one past the last character of the buffer
    int row;                 // current 1-based line number
    State state;             // the single saved checkpoint
};
lexer.cpp
#include "lexer.h"
void Token::Dump()
{
llvm::StringRef text(ptr, length);
llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}
/// Returns the canonical spelling of a token kind, used when building
/// diagnostics such as err_expected.
///
/// Keyword and punctuator kinds map to their source spelling. `number`,
/// `indentifier` and `eof` map to a descriptive name instead. The switch has
/// no `default`, so the compiler warns (-Wswitch) if a new TokenType is added
/// without a spelling here.
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    switch (tokenType)
    {
    // literals / names
    case TokenType::number:
        return "number";
    case TokenType::indentifier:
        return "indentifier";
    // keywords
    case TokenType::kw_int:
        return "int";
    case TokenType::kw_if:
        return "if";
    case TokenType::kw_else:
        return "else";
    case TokenType::kw_for:
        return "for";
    case TokenType::kw_break:
        return "break";
    case TokenType::kw_continue:
        return "continue";
    // operators
    case TokenType::plus:
        return "+";
    case TokenType::minus:
        return "-";
    case TokenType::star:
        return "*";
    case TokenType::slash:
        return "/";
    case TokenType::equal:
        return "=";
    case TokenType::equal_equal:
        return "==";
    case TokenType::not_equal:
        return "!=";
    case TokenType::less:
        return "<";
    case TokenType::less_equal:
        return "<=";
    case TokenType::greater:
        return ">";
    case TokenType::greater_equal:
        return ">=";
    // punctuation
    case TokenType::l_parent:
        return "(";
    case TokenType::r_parent:
        return ")";
    case TokenType::l_brace:
        return "{";
    case TokenType::r_brace:
        return "}";
    case TokenType::semi:
        return ";";
    case TokenType::comma:
        return ",";
    // end of input (previously missing, which hit the unreachable path)
    case TokenType::eof:
        return "eof";
    }
    llvm_unreachable("invalid TokenType"); // only reachable with an out-of-range value
}
/// Returns true for the characters skipped between tokens:
/// space, horizontal tab, carriage return and line feed.
bool IsWhiteSpace(char ch)
{
    switch (ch)
    {
    case ' ':
    case '\t':
    case '\r':
    case '\n':
        return true;
    default:
        return false;
    }
}
/// Returns true if `ch` is an ASCII decimal digit ('0'..'9').
bool IsDigit(char ch)
{
    return '0' <= ch && ch <= '9';
}
/// Returns true if `ch` may start an identifier: an ASCII letter or '_'.
bool IsLetter(char ch)
{
    const bool isLower = 'a' <= ch && ch <= 'z';
    const bool isUpper = 'A' <= ch && ch <= 'Z';
    return isLower || isUpper || ch == '_';
}
/// Scans the next token from the buffer into `token`.
///
/// Skips whitespace while tracking line numbers, then recognizes number
/// literals, identifiers/keywords and punctuators. At the end of the buffer
/// it produces an `eof` token. Unknown characters are reported through
/// err_unknown_char.
void Lexer::NextToken(Token &token)
{
    // Skip whitespace; every '\n' starts a new line.
    while (IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;
    // Set the text range before the eof check. Diagnostics that print an eof
    // token (e.g. err_expected) must never read a stale or uninitialized
    // ptr/length.
    token.ptr = BufPtr;
    token.length = 0;

    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    // Finishes a punctuator token that is `len` characters long.
    auto makePunct = [&](TokenType kind, int len) {
        token.tokenType = kind;
        token.length = len;
        BufPtr += len;
    };
    // Reports the current character as unknown. If Report() returns, skip the
    // character and end the token stream. Otherwise the next call would hit
    // the same character again and `token` would keep a stale kind.
    auto reportUnknownChar = [&]() {
        diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
        token.tokenType = TokenType::eof;
        token.length = 1;
        BufPtr++;
    };

    // number: [0-9]+
    if (IsDigit(*BufPtr))
    {
        // Accumulate in unsigned so an over-long literal wraps instead of
        // triggering signed-overflow UB.
        unsigned val = 0;
        while (IsDigit(*BufPtr))
        {
            val = val * 10 + static_cast<unsigned>(*BufPtr - '0');
            BufPtr++;
        }
        token.length = BufPtr - token.ptr;
        token.value = static_cast<int>(val);
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
        return;
    }

    // identifier or keyword: [a-zA-Z_][a-zA-Z0-9_]*
    if (IsLetter(*BufPtr))
    {
        while (IsLetter(*BufPtr) || IsDigit(*BufPtr))
            BufPtr++;
        token.length = BufPtr - token.ptr;

        const llvm::StringRef text(token.ptr, token.length);
        if (text == "int")
            token.tokenType = TokenType::kw_int;
        else if (text == "if")
            token.tokenType = TokenType::kw_if;
        else if (text == "else")
            token.tokenType = TokenType::kw_else;
        else if (text == "for")
            token.tokenType = TokenType::kw_for;
        else if (text == "break")
            token.tokenType = TokenType::kw_break;
        else if (text == "continue")
            token.tokenType = TokenType::kw_continue;
        else
            token.tokenType = TokenType::indentifier;
        return;
    }

    // punctuators
    switch (*BufPtr)
    {
    case '+':
        makePunct(TokenType::plus, 1);
        break;
    case '-':
        makePunct(TokenType::minus, 1);
        break;
    case '*':
        makePunct(TokenType::star, 1);
        break;
    case '/':
        makePunct(TokenType::slash, 1);
        break;
    case '(':
        makePunct(TokenType::l_parent, 1);
        break;
    case ')':
        makePunct(TokenType::r_parent, 1);
        break;
    case '{':
        makePunct(TokenType::l_brace, 1);
        break;
    case '}':
        makePunct(TokenType::r_brace, 1);
        break;
    case ';':
        makePunct(TokenType::semi, 1);
        break;
    case ',':
        makePunct(TokenType::comma, 1);
        break;
    // The buffer ends in a NUL character, so peeking BufPtr[1] stays in bounds.
    case '=':
        if (BufPtr[1] == '=')
            makePunct(TokenType::equal_equal, 2);
        else
            makePunct(TokenType::equal, 1);
        break;
    case '<':
        if (BufPtr[1] == '=')
            makePunct(TokenType::less_equal, 2);
        else
            makePunct(TokenType::less, 1);
        break;
    case '>':
        if (BufPtr[1] == '=')
            makePunct(TokenType::greater_equal, 2);
        else
            makePunct(TokenType::greater, 1);
        break;
    case '!':
        // '!' is only valid as part of "!=". A lone '!' is reported as unknown
        // rather than falling through to the '<' case, which would produce a
        // bogus `less` token.
        if (BufPtr[1] == '=')
            makePunct(TokenType::not_equal, 2);
        else
            reportUnknownChar();
        break;
    default:
        reportUnknownChar();
        break;
    }
}
void Lexer::SaveState()
{
state.LineHeadPtr = LineHeadPtr;
state.BufPtr = BufPtr;
state.BufEnd = BufEnd;
state.row = row;
}
/// Rewinds the lexer to the position recorded by the last SaveState().
void Lexer::RestoreState()
{
    row = state.row;
    BufEnd = state.BufEnd;
    BufPtr = state.BufPtr;
    LineHeadPtr = state.LineHeadPtr;
}
1.3 语法分析器 (Parser)
根据文法,为 break 和 continue 语句新增对应的 AST 节点类 (BreakStmt、ContinueStmt) 和解析函数 (ParseBreakStmt、ParseContinueStmt),并在 ParseStmt() 中增加对这两种语句的分派;同时把类型名的判断封装为 IsTypeName(),便于解析变量声明。
实现代码
ast.h
#pragma once
#include <vector>
#include <memory>
#include "llvm/IR/Value.h"
#include "type.h"
#include "lexer.h"
// Operator stored in a BinaryExpr node.
enum class OpCode
{
add, // +
sub, // -
mul, // *
div, // /
equal_equal, // ==
not_equal, // !=
less, // <
less_equal, // <=
greater, // >
greater_equal // >=
};
// 进行声明
class Program;
class Expr;
class DeclStmt;
class BlockStmt;
class ForStmt;
class BreakStmt;
class ContinueStmt;
class VariableDecl;
class IfStmt;
class VariableAccessExpr;
class BinaryExpr;
class AssignExpr;
class NumberExpr;
/// Visitor interface for the AST (visitor pattern).
///
/// Each concrete node's Accept() forwards to the matching Visit* method.
/// Implementations return an llvm::Value when they produce one and may
/// return nullptr otherwise.
class Visitor
{
public:
    virtual ~Visitor() = default;

    // program and statements
    virtual llvm::Value *VisitProgram(Program *program) = 0;
    virtual llvm::Value *VisitDeclStmt(DeclStmt *declStmt) = 0;
    virtual llvm::Value *VisitBlockStmt(BlockStmt *blockStmt) = 0;
    virtual llvm::Value *VisitIfStmt(IfStmt *ifStmt) = 0;
    virtual llvm::Value *VisitForStmt(ForStmt *forStmt) = 0;
    virtual llvm::Value *VisitBreakStmt(BreakStmt *breakStmt) = 0;
    virtual llvm::Value *VisitContinueStmt(ContinueStmt *continueStmt) = 0;
    virtual llvm::Value *VisitVariableDecl(VariableDecl *variableDecl) = 0;

    // expressions
    virtual llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *accessExpr) = 0;
    virtual llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) = 0;
    virtual llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) = 0;
    virtual llvm::Value *VisitNumberExpr(NumberExpr *numberExpr) = 0;
};
/// Common base class of all AST nodes.
class ASTNode
{
public:
    /// Discriminator used by the classof() hooks of the derived classes for
    /// checked downcasts.
    enum Kind
    {
        ND_BlockStmt,
        ND_ForStmt,
        ND_IfStmt,
        ND_BreakStmt,
        ND_ContinueStmt,
        ND_DeclStmt,
        ND_VariableDecl,
        ND_BinaryExpr,
        ND_NumberExpr,
        ND_VariableAccessExpr,
        ND_AssignExpr
    };

private:
    const Kind kind;

public:
    ASTNode(Kind kind) : kind(kind) {}
    virtual ~ASTNode() = default;

    /// Double-dispatch entry point; overridden by every concrete node to call
    /// the matching Visitor::Visit* method. The base version does nothing.
    virtual llvm::Value *Accept(Visitor * /*v*/) { return nullptr; }

    Kind getKind() const { return kind; }

public:
    // Type attached to this node. Initialized to nullptr so nodes whose type
    // is never assigned do not carry an indeterminate pointer.
    CType *type{nullptr};
    // Token this node was built from (name, number or operator), used for
    // printing and diagnostics.
    Token token;
};
// 声明语句节点
class DeclStmt : public ASTNode
{
public:
DeclStmt() : ASTNode(ND_DeclStmt) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitDeclStmt(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_DeclStmt;
}
public:
std::vector<std::shared_ptr<ASTNode>> nodeVec;
};
// 块语句节点
class BlockStmt : public ASTNode
{
public:
BlockStmt() : ASTNode(ND_BlockStmt) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitBlockStmt(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_BlockStmt;
}
public:
std::vector<std::shared_ptr<ASTNode>> nodeVec;
};
// 条件判断节点
class IfStmt : public ASTNode
{
public:
IfStmt() : ASTNode(ND_IfStmt) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitIfStmt(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_IfStmt;
}
public:
std::shared_ptr<ASTNode> condNode;
std::shared_ptr<ASTNode> thenNode;
std::shared_ptr<ASTNode> elseNode;
};
// 循环节点
class ForStmt : public ASTNode
{
public:
ForStmt() : ASTNode(ND_ForStmt) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitForStmt(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_ForStmt;
}
public:
std::shared_ptr<ASTNode> initNode;
std::shared_ptr<ASTNode> condNode;
std::shared_ptr<ASTNode> incNode;
std::shared_ptr<ASTNode> bodyNode;
};
class BreakStmt : public ASTNode
{
public:
BreakStmt() : ASTNode(ND_BreakStmt) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitBreakStmt(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_BreakStmt;
}
public:
std::shared_ptr<ASTNode> target;
};
class ContinueStmt : public ASTNode
{
public:
ContinueStmt() : ASTNode(ND_ContinueStmt) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitContinueStmt(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_ContinueStmt;
}
public:
std::shared_ptr<ASTNode> target;
};
// 变量声明节点
class VariableDecl : public ASTNode
{
public:
VariableDecl() : ASTNode(ND_VariableDecl) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitVariableDecl(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_VariableDecl;
}
};
// 二元表达式节点
class BinaryExpr : public ASTNode
{
public:
BinaryExpr() : ASTNode(ND_BinaryExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitBinaryExpr(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_BinaryExpr;
}
public:
OpCode op;
std::shared_ptr<ASTNode> left;
std::shared_ptr<ASTNode> right;
};
// 赋值表达式节点
class AssignExpr : public ASTNode
{
public:
AssignExpr() : ASTNode(ND_AssignExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitAssignExpr(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_AssignExpr;
}
public:
std::shared_ptr<ASTNode> left;
std::shared_ptr<ASTNode> right;
};
// 数字表达式节点
class NumberExpr : public ASTNode
{
public:
NumberExpr() : ASTNode(ND_NumberExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitNumberExpr(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_NumberExpr;
}
};
// 变量访问节点
class VariableAccessExpr : public ASTNode // 变量表达式
{
public:
VariableAccessExpr() : ASTNode(ND_VariableAccessExpr) {}
llvm::Value *Accept(Visitor *v) override
{
return v->VisitVariableAccessExpr(this);
}
static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
{
return node->getKind() == ND_VariableAccessExpr;
}
};
/// Root of the AST: the top-level statements of the input in source order.
class Program
{
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec{};
};
parser.h
#pragma once
#include "ast.h"
#include "lexer.h"
#include "sema.h"
#include <vector>
/// Recursive-descent parser that builds the AST. Most node construction and
/// semantic checks are delegated to Sema.
class Parser
{
public:
    Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
    {
        Advance(); // load the first token into `token`
    }

    std::shared_ptr<Program> ParseProgram();

private:
    // statements
    std::shared_ptr<ASTNode> ParseStmt();
    std::shared_ptr<ASTNode> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseBlockStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseIfStmt();
    std::shared_ptr<ASTNode> ParseForStmt();
    std::shared_ptr<ASTNode> ParseBreakStmt();
    std::shared_ptr<ASTNode> ParseContinueStmt();

    // expressions, from lowest to highest precedence
    std::shared_ptr<ASTNode> ParseExpr();
    std::shared_ptr<ASTNode> ParseAssignExpr();
    std::shared_ptr<ASTNode> ParseEqualExpr();
    std::shared_ptr<ASTNode> ParseRelationalExpr();
    std::shared_ptr<ASTNode> ParseAddExpr();
    std::shared_ptr<ASTNode> ParseMultiExpr();
    std::shared_ptr<ASTNode> ParsePrimary();

    // token helpers
    /// Checks that the current token has the given kind; reports a diagnostic otherwise.
    bool Expect(TokenType tokenType);
    /// Like Expect(), and on success advances to the next token.
    bool Consume(TokenType tokenType);
    /// Unconditionally advances to the next token.
    bool Advance();
    /// True if the current token is a type name.
    bool IsTypeName();

    DiagEngine &GetDiagEngine() const { return lexer.GetDiagEngine(); }

private:
    Lexer &lexer;
    Sema &sema;
    Token token; // current token (one-token look-ahead)
    // Loops that a `break` / `continue` may currently target, innermost last.
    // They are empty outside any loop, which makes break/continue there an error.
    std::vector<std::shared_ptr<ASTNode>> breakNodes;
    std::vector<std::shared_ptr<ASTNode>> continueNodes;
};
parser.cpp
#include "parser.h"
#include <cassert>
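/*
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt | for-stmt | break-stmt | continue-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
for-stmt : "for" "(" expr? ";" expr? ";" expr? ")" stmt
         | "for" "(" decl-stmt expr? ";" expr? ")" stmt
block-stmt : "{" stmt* "}"
break-stmt : "break" ";"
continue-stmt : "continue" ";"
expr-stmt : expr ";"
expr : assign-expr | equal-expr
assign-expr : identifier "=" expr
equal-expr : relational-expr (("==" | "!=") relational-expr)*
relational-expr : add-expr (("<" | ">" | "<=" | ">=") add-expr)*
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number : [0-9]+
identifier : [a-zA-Z_][a-zA-Z0-9_]*
*/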
// prog : stmt*
// Parses statements until end of input. Null statements (";") yield no node
// and are therefore not stored.
std::shared_ptr<Program> Parser::ParseProgram()
{
    auto program = std::make_shared<Program>();
    while (token.tokenType != TokenType::eof)
    {
        if (std::shared_ptr<ASTNode> stmt = ParseStmt())
            program->nodeVec.push_back(std::move(stmt));
    }
    return program;
}
// stmt : null-stmt | decl-stmt | block-stmt | if-stmt | for-stmt
//      | break-stmt | continue-stmt | expr-stmt
// Dispatches on the current token. Returns nullptr for a null statement.
std::shared_ptr<ASTNode> Parser::ParseStmt()
{
    if (token.tokenType == TokenType::semi)
    {
        Consume(TokenType::semi); // null-stmt: eat the ';' and produce nothing
        return nullptr;
    }
    if (IsTypeName())
        return ParseDeclStmt();

    switch (token.tokenType)
    {
    case TokenType::l_brace:
        return ParseBlockStmt();
    case TokenType::kw_if:
        return ParseIfStmt();
    case TokenType::kw_for:
        return ParseForStmt();
    case TokenType::kw_break:
        return ParseBreakStmt();
    case TokenType::kw_continue:
        return ParseContinueStmt();
    default:
        return ParseExprStmt();
    }
}
// 解析声明语句
std::shared_ptr<ASTNode> Parser::ParseDeclStmt()
{
/// int a, b = 3;
/// int a = 3;
Consume(TokenType::kw_int);
CType *baseTy = CType::getIntTy();
/// a , b = 3;
/// a = 3;
auto declStmt = std::make_shared<DeclStmt>();
/// a, b = 3;
/// a = 3;
int i = 0;
while (token.tokenType != TokenType::semi)
{
if (i++ > 0) // if (i++)
{
assert(Consume(TokenType::comma));
}
/// 变量声明的节点: int a = 3; -> int a; a = 3;
// a = 3;
auto variableDecl = sema.SemaVariableDeclNode(token, baseTy); // get a type
declStmt->nodeVec.push_back(variableDecl);
Token tmp = token;
Consume(TokenType::indentifier);
// = 3;
if (token.tokenType == TokenType::equal)
{
Token opToken = token;
Advance();
// 3;
auto right = ParseExpr();
auto left = sema.SemaVariableAccessNode(tmp);
auto assign = sema.SemaAssignExprNode(left, right, opToken);
declStmt->nodeVec.push_back(assign);
}
}
Consume(TokenType::semi);
return declStmt;
}
// block-stmt : "{" stmt* "}"
// Parses a braced block inside its own scope. Null statements are dropped.
std::shared_ptr<ASTNode> Parser::ParseBlockStmt()
{
    sema.EnterScope();
    auto blockStmt = std::make_shared<BlockStmt>();
    Consume(TokenType::l_brace);
    // Also stop at eof. An unterminated block is then reported by the
    // Consume(r_brace) below, instead of ParseStmt() being called on eof
    // again and again if the diagnostic engine does not abort.
    while (token.tokenType != TokenType::r_brace && token.tokenType != TokenType::eof)
    {
        if (auto stmt = ParseStmt())
            blockStmt->nodeVec.push_back(stmt);
    }
    Consume(TokenType::r_brace);
    sema.ExitScope();
    return blockStmt;
}
// expr-stmt : expr ";"
// Parses an expression used as a statement and returns the expression node.
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    std::shared_ptr<ASTNode> node = ParseExpr();
    Consume(TokenType::semi); // terminating ';'
    return node;
}
// if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
/*
if (a)
b = 3;
else
b = 4;
*/
std::shared_ptr<ASTNode> Parser::ParseIfStmt()
{
Consume(TokenType::kw_if);
Consume(TokenType::l_parent);
auto condExpr = ParseExpr();
Consume(TokenType::r_parent);
auto thenStmt = ParseStmt();
std::shared_ptr<ASTNode> elseStmt = nullptr;
if (token.tokenType == TokenType::kw_else)
{
Consume(TokenType::kw_else);
elseStmt = ParseStmt();
}
return sema.SemaIfStmtNode(condExpr, thenStmt, elseStmt);
}
/*
for-stmt :
"for" "(" expr? ";" expr? ";" expr? ")" stmt
"for" "(" decl-stmt expr? ";" expr? ")" stmt
*/
// 解析 for-stmt
std::shared_ptr<ASTNode> Parser::ParseForStmt()
{
sema.EnterScope(); // 进入作用域
Consume(TokenType::kw_for);
Consume(TokenType::l_parent);
auto forStmt = std::make_shared<ForStmt>();
breakNodes.push_back(forStmt);
continueNodes.push_back(forStmt);
std::shared_ptr<ASTNode> initNode = nullptr;
std::shared_ptr<ASTNode> condNode = nullptr;
std::shared_ptr<ASTNode> incNode = nullptr;
std::shared_ptr<ASTNode> bodyNode = nullptr;
if (IsTypeName())
{
initNode = ParseDeclStmt();
}
else
{
initNode = ParseExpr();
Consume(TokenType::semi);
}
condNode = ParseExpr();
Consume(TokenType::semi);
incNode = ParseExpr();
Consume(TokenType::r_parent);
bodyNode = ParseStmt();
forStmt->initNode = initNode;
forStmt->condNode = condNode;
forStmt->incNode = incNode;
forStmt->bodyNode = bodyNode;
breakNodes.pop_back();
continueNodes.pop_back();
sema.ExitScope(); // 离开作用域
return forStmt;
}
// break-stmt : "break" ";"
// Builds a BreakStmt that targets the innermost enclosing loop. A `break`
// outside any loop is diagnosed with err_break_stmt.
std::shared_ptr<ASTNode> Parser::ParseBreakStmt()
{
    if (breakNodes.empty())
    {
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr), diag::err_break_stmt);
    }
    Consume(TokenType::kw_break);

    auto node = std::make_shared<BreakStmt>();
    // Never call back() on an empty stack, in case Report() returns.
    if (!breakNodes.empty())
        node->target = breakNodes.back();

    // The grammar requires ';' after `break`. Consuming it here matters for
    // `if (c) break; else ...`: a leftover ';' would end the if-statement
    // before `else` is seen.
    Consume(TokenType::semi);
    return node;
}
// continue-stmt : "continue" ";"
// Builds a ContinueStmt that targets the innermost enclosing loop. A
// `continue` outside any loop is diagnosed with err_continue_stmt.
std::shared_ptr<ASTNode> Parser::ParseContinueStmt()
{
    if (continueNodes.empty())
    {
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr), diag::err_continue_stmt);
    }
    Consume(TokenType::kw_continue);

    auto node = std::make_shared<ContinueStmt>();
    // Never call back() on an empty stack, in case Report() returns.
    if (!continueNodes.empty())
        node->target = continueNodes.back();

    // The grammar requires ';' after `continue`. Consuming it here matters for
    // `if (c) continue; else ...`: a leftover ';' would end the if-statement
    // before `else` is seen.
    Consume(TokenType::semi);
    return node;
}
// 解析表达式
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
lexer.SaveState();
bool isAssign = false;
Token tmp = token;
// a = b;
if (tmp.tokenType == TokenType::indentifier)
{
lexer.NextToken(tmp);
if (tmp.tokenType == TokenType::equal)
{
isAssign = true;
}
}
lexer.RestoreState();
if (isAssign)
{
return ParseAssignExpr();
}
else // equal-expr
return ParseEqualExpr();
}
// 解析赋值表达式
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
// a = b;
Token tmp = token;
Consume(TokenType::indentifier);
auto expr = sema.SemaVariableAccessNode(tmp);
Token opToken = token;
Consume(TokenType::equal);
return sema.SemaAssignExprNode(expr, ParseExpr(), opToken);
}
// 解析等号表达式
std::shared_ptr<ASTNode> Parser::ParseEqualExpr()
{
std::shared_ptr<ASTNode> left = ParseRelationalExpr();
while (token.tokenType == TokenType::equal_equal || token.tokenType == TokenType::not_equal)
{
OpCode op;
if (token.tokenType == TokenType::equal_equal)
{
op = OpCode::equal_equal;
}
else
{
op = OpCode::not_equal;
}
Advance();
auto right = ParseRelationalExpr();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// relational-expr : add-expr (("<" | ">" | "<=" | ">=") add-expr)*   (left-associative)
std::shared_ptr<ASTNode> Parser::ParseRelationalExpr()
{
    std::shared_ptr<ASTNode> lhs = ParseAddExpr();
    while (true)
    {
        OpCode op;
        switch (token.tokenType)
        {
        case TokenType::less:
            op = OpCode::less;
            break;
        case TokenType::less_equal:
            op = OpCode::less_equal;
            break;
        case TokenType::greater:
            op = OpCode::greater;
            break;
        case TokenType::greater_equal:
            op = OpCode::greater_equal;
            break;
        default:
            return lhs; // no more relational operators
        }
        Advance(); // consume the operator
        auto rhs = ParseAddExpr();
        lhs = sema.SemaBinaryExprNode(lhs, rhs, op);
    }
}
// 解析加法表达式
std::shared_ptr<ASTNode> Parser::ParseAddExpr()
{
std::shared_ptr<ASTNode> left = ParseMultiExpr();
while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
{
OpCode op;
if (token.tokenType == TokenType::plus)
{
op = OpCode::add;
}
else
{
op = OpCode::sub;
}
Advance();
auto right = ParseMultiExpr();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// 解析项
std::shared_ptr<ASTNode> Parser::ParseMultiExpr()
{
std::shared_ptr<ASTNode> left = ParsePrimary();
while (token.tokenType == TokenType::star || token.tokenType == TokenType::slash)
{
OpCode op;
if (token.tokenType == TokenType::star)
{
op = OpCode::mul;
}
else
{
op = OpCode::div;
}
Advance();
auto right = ParsePrimary();
auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);
left = binaryExpr;
}
return left;
}
// primary-expr : identifier | number | "(" expr ")"
std::shared_ptr<ASTNode> Parser::ParsePrimary()
{
    switch (token.tokenType)
    {
    case TokenType::l_parent:
    {
        Consume(TokenType::l_parent);
        auto inner = ParseExpr();
        Consume(TokenType::r_parent);
        return inner;
    }
    case TokenType::indentifier:
    {
        auto access = sema.SemaVariableAccessNode(token);
        Consume(TokenType::indentifier);
        return access;
    }
    default:
    {
        // Anything else must be a number literal; Expect() reports otherwise.
        Expect(TokenType::number);
        auto literal = sema.SemaNumberExprNode(token, token.type);
        Consume(TokenType::number);
        return literal;
    }
    }
}
/// Token helpers.
/// Returns true if the current token has kind `tokenType`. Otherwise reports
/// err_expected at the current token, naming the wanted kind and the actual
/// token text, and returns false. Never consumes a token.
bool Parser::Expect(TokenType tokenType)
{
    if (token.tokenType != tokenType)
    {
        const llvm::StringRef actual(token.ptr, token.length);
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                               diag::err_expected,
                               Token::GetSpellingText(tokenType),
                               actual);
        return false;
    }
    return true;
}
/// Expects a token of kind `tokenType` and, if present, advances past it.
/// Returns whether the expected token was found.
bool Parser::Consume(TokenType tokenType)
{
    if (!Expect(tokenType))
        return false;
    Advance();
    return true;
}
/// Replaces the current token with the next one from the lexer.
/// Always returns true.
bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}
/// True if the current token names a type; `int` is currently the only one.
bool Parser::IsTypeName()
{
    return token.tokenType == TokenType::kw_int;
}
printVisitor.h
由于基类 Visitor 新增了访问 break、continue 语句的纯虚函数 VisitBreakStmt() 和 VisitContinueStmt(),子类 PrintVisitor 必须重写这两个函数。
#pragma once
#include "ast.h"
#include "parser.h"
/// Visitor that pretty-prints the AST as C-like text to llvm::outs().
/// Printing happens in the constructor; every Visit* method returns nullptr.
class PrintVisitor : public Visitor
{
public:
    PrintVisitor(std::shared_ptr<Program> program);

public:
    // program and statements
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitBlockStmt(BlockStmt *block) override;
    llvm::Value *VisitIfStmt(IfStmt *ifStmt) override;
    llvm::Value *VisitForStmt(ForStmt *forStmt) override;
    llvm::Value *VisitBreakStmt(BreakStmt *breakStmt) override;
    llvm::Value *VisitContinueStmt(ContinueStmt *continueStmt) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;

    // expressions
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
};
printVisitor.cpp
#include "printVisitor.h"
/// Prints `program` to llvm::outs() immediately.
PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
    Program *root = program.get();
    VisitProgram(root);
}
/// Prints each top-level statement followed by a newline.
llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
    for (std::shared_ptr<ASTNode> &stmt : p->nodeVec)
    {
        stmt->Accept(this);
        llvm::outs() << "\n";
    }
    return nullptr;
}
/// Prints the declarations and initializers of a DeclStmt back to back.
llvm::Value *PrintVisitor::VisitDeclStmt(DeclStmt *decl)
{
    for (const auto &child : decl->nodeVec)
        child->Accept(this);
    return nullptr;
}
/// Prints a block as "{", one statement per line, then "}".
llvm::Value *PrintVisitor::VisitBlockStmt(BlockStmt *block)
{
    llvm::raw_ostream &out = llvm::outs();
    out << "{\n";
    for (const std::shared_ptr<ASTNode> &stmt : block->nodeVec)
    {
        stmt->Accept(this);
        out << "\n";
    }
    out << "}\n";
    return nullptr;
}
/// Prints `if (cond)`, the then-branch and, when present, `else` and the
/// else-branch.
llvm::Value *PrintVisitor::VisitIfStmt(IfStmt *ifStmt)
{
    llvm::outs() << "if (";
    ifStmt->condNode->Accept(this);
    llvm::outs() << ")\n";
    // The parser produces nullptr for a null statement, so `if (a) ;` has no
    // then-node. Print the ';' instead of dereferencing nullptr.
    if (ifStmt->thenNode)
        ifStmt->thenNode->Accept(this);
    else
        llvm::outs() << ";";
    llvm::outs() << "\n";
    if (ifStmt->elseNode)
    {
        llvm::outs() << "else\n";
        ifStmt->elseNode->Accept(this);
    }
    return nullptr;
}
/// Prints `for (init; cond; inc)` followed by the body on the next line.
/// Omitted parts (nullptr fields) print nothing.
llvm::Value *PrintVisitor::VisitForStmt(ForStmt *forStmt)
{
    auto printIfPresent = [this](const std::shared_ptr<ASTNode> &node) {
        if (node)
            node->Accept(this);
    };

    llvm::outs() << "for (";
    printIfPresent(forStmt->initNode);
    llvm::outs() << "; ";
    printIfPresent(forStmt->condNode);
    llvm::outs() << "; ";
    printIfPresent(forStmt->incNode);
    // Break the line after ')' as VisitIfStmt does, so a non-block body is not
    // glued to the header (previously printed as "for (...)a = 1").
    llvm::outs() << ")\n";
    printIfPresent(forStmt->bodyNode);
    return nullptr;
}
/// Prints `break;` (the target loop is not shown).
llvm::Value *PrintVisitor::VisitBreakStmt(BreakStmt * /*breakStmt*/)
{
    llvm::outs() << "break" << ";";
    return nullptr;
}
/// Prints `continue;` (the target loop is not shown).
llvm::Value *PrintVisitor::VisitContinueStmt(ContinueStmt * /*continueStmt*/)
{
    llvm::outs() << "continue" << ";";
    return nullptr;
}
/// Prints `int <name>;` for int declarations; other types print nothing.
llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
    if (decl->type != CType::getIntTy())
        return nullptr;
    const llvm::StringRef name(decl->token.ptr, decl->token.length);
    llvm::outs() << "int " << name << ";";
    return nullptr;
}
/// Prints the variable name exactly as written in the source.
llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const Token &name = varaccExpr->token;
    llvm::outs() << llvm::StringRef(name.ptr, name.length);
    return nullptr;
}
/// Prints `left = right`.
llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
    ASTNode *lhs = assignExpr->left.get();
    ASTNode *rhs = assignExpr->right.get();
    lhs->Accept(this);
    llvm::outs() << " = ";
    rhs->Accept(this);
    return nullptr;
}
/// Prints a binary expression in infix form (in-order traversal).
///
/// Operands that are themselves binary or assignment expressions are wrapped
/// in parentheses, so the printed text keeps the tree's grouping. For example,
/// `a * (b + a - 3)` prints as `a * ((b + a) - 3)` instead of the misleading
/// `a * b + a - 3`.
llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    llvm::StringRef opText; // stays empty for an unknown operator (prints nothing)
    switch (binaryExpr->op)
    {
    case OpCode::add:
        opText = " + ";
        break;
    case OpCode::sub:
        opText = " - ";
        break;
    case OpCode::mul:
        opText = " * ";
        break;
    case OpCode::div:
        opText = " / ";
        break;
    case OpCode::equal_equal:
        opText = " == ";
        break;
    case OpCode::not_equal:
        opText = " != ";
        break;
    case OpCode::less:
        opText = " < ";
        break;
    case OpCode::less_equal:
        opText = " <= ";
        break;
    case OpCode::greater:
        opText = " > ";
        break;
    case OpCode::greater_equal:
        opText = " >= ";
        break;
    }

    // Prints one operand, parenthesized when it is a compound expression.
    auto printOperand = [this](const std::shared_ptr<ASTNode> &operand) {
        const bool compound = BinaryExpr::classof(operand.get()) || AssignExpr::classof(operand.get());
        if (compound)
            llvm::outs() << "(";
        operand->Accept(this);
        if (compound)
            llvm::outs() << ")";
    };

    printOperand(binaryExpr->left);
    llvm::outs() << opText;
    printOperand(binaryExpr->right);
    return nullptr;
}
/// Prints the number literal exactly as written in the source.
llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
    const Token &literal = factorExpr->token;
    llvm::outs() << llvm::StringRef(literal.ptr, literal.length);
    return nullptr;
}
1.4 诊断引擎 (DiagEngine)
由于 break 语句只能出现在循环 (目前只有 for) 或 switch 语句中,而 continue 语句只能出现在循环中,所以需要新增对应的错误类型。
实现代码
diag.inc
// X-macro list of diagnostics. The includer defines DIAG(ID, KIND, MSG) before
// including this file; if it does not, every entry expands to nothing.
// `{N}` placeholders are filled from the extra arguments passed to
// DiagEngine::Report().
#ifndef DIAG
#define DIAG(ID, KIND, MSG)
#endif
/// lexer
DIAG(err_unknown_char, Error, "unknown char '{0}'")
/// parser
DIAG(err_expected, Error, "expect '{0}' , but get '{1}'")
DIAG(err_redefine, Error, "redefine symbol '{0}'")
DIAG(err_undefine, Error, "undefine symbol '{0}'")
DIAG(err_lvalue, Error, "require left value on the assign operation left side")
DIAG(err_break_stmt, Error, "'break' statement not in loop or switch statement")
// `continue` is only valid inside a loop (unlike `break`, never in a switch).
DIAG(err_continue_stmt, Error, "'continue' statement not in loop statement")
#undef DIAG
1.5 代码生成 (CodeGen)
由于基类 Visitor 新增了 VisitBreakStmt() 和 VisitContinueStmt() 两个纯虚函数,子类 CodeGen 也必须重写这两个函数。
实现代码
codegen.h
#pragma once
#include "ast.h"
#include "parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/DenseMap.h"
// 通过访问者模式生成代码
class CodeGen : public Visitor
{
public:
CodeGen(std::shared_ptr<Program> program)
{
module = std::make_shared<llvm::Module>("expr", context);
VisitProgram(program.get());
}
public:
llvm::Value *VisitProgram(Program *p) override;
llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
llvm::Value *VisitBlockStmt(BlockStmt *block) override;
llvm::Value *VisitIfStmt(IfStmt *ifStmt) override;
llvm::Value *VisitForStmt(ForStmt *forStmt) override;
llvm::Value *VisitBreakStmt(BreakStmt *breakStmt) override;
llvm::Value *VisitContinueStmt(ContinueStmt *continueStmt) override;
llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
private:
llvm::LLVMContext context;
llvm::IRBuilder<> irBuilder{context};
std::shared_ptr<llvm::Module> module;
llvm::Function *curFunc{nullptr};
llvm::StringMap<std::pair<llvm::Value *, llvm::Type *>> varAddrTyMap;
llvm::DenseMap<ASTNode *, llvm::BasicBlock *> breakBBs;
llvm::DenseMap<ASTNode *, llvm::BasicBlock *> continueBBs;
};
codegen.cpp
#include "codegen.h"
#include "llvm/IR/Verifier.h"
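/*
Sample IR for a program containing an if/else statement.
NOTE: VisitIfStmt now names its blocks if_cond / if_then / if_else / if_last
rather than cond / then / else / last as shown below.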
; ModuleID = 'expr'
source_filename = "expr"
@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1
declare i32 @printf(ptr, ...)
define i32 @main() {
entry:
%a = alloca i32, align 4
store i32 0, ptr %a, align 4
%a1 = load i32, ptr %a, align 4
%b = alloca i32, align 4
store i32 5, ptr %b, align 4
%b2 = load i32, ptr %b, align 4
br label %cond
cond: ; preds = %entry
%a3 = load i32, ptr %a, align 4
%0 = icmp ne i32 %a3, 0
br i1 %0, label %then, label %else
then: ; preds = %cond
store i32 3, ptr %b, align 4
%b4 = load i32, ptr %b, align 4
br label %last
else: ; preds = %cond
store i32 4, ptr %a, align 4
%a5 = load i32, ptr %a, align 4
br label %last
last: ; preds = %else, %then
%a6 = load i32, ptr %a, align 4
%b7 = load i32, ptr %b, align 4
%mul = mul nsw i32 %a6, %b7
%sub = sub nsw i32 %mul, 4
%1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
ret i32 0
}
*/
/// Emits the whole program into `i32 main()` of the module and prints the
/// module to llvm::outs().
///
/// The value of the last top-level statement, if it produced one, is printed
/// at run time through printf("expr value: %d\n"). Otherwise a fixed message
/// is printed. Returns the `ret` instruction of main.
llvm::Value *CodeGen::VisitProgram(Program *p)
{
    // declare i32 @printf(ptr, ...)
    auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
    auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
    // define i32 @main()
    auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
    auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
    // Entry block of main; all top-level code starts here.
    llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(context, "entry", mainFunction);
    irBuilder.SetInsertPoint(entryBB);
    curFunc = mainFunction;

    llvm::Value *lastVal = nullptr;
    for (const auto &node : p->nodeVec)
    {
        lastVal = node->Accept(this);
    }
    if (lastVal)
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastVal});
    else
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("last instruction is not expr!\n")});

    llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));
    // verifyFunction() only returns a flag. Pass an error stream so problems in
    // the generated IR (e.g. a block with two terminators) are reported instead
    // of being silently ignored. The module is still printed to help debugging.
    if (llvm::verifyFunction(*mainFunction, &llvm::errs()))
        llvm::errs() << "error: generated function 'main' failed verification\n";
    module->print(llvm::outs(), nullptr);
    return ret;
}
// Emit each child of a declaration statement in source order.
// Returns the value produced by the final child, or nullptr if there are no children.
llvm::Value *CodeGen::VisitDeclStmt(DeclStmt *declStmt)
{
    llvm::Value *result = nullptr;
    for (const auto &child : declStmt->nodeVec)
        result = child->Accept(this);
    return result;
}
// Emit the statements of a `{ ... }` block in order.
// Returns whatever the last statement produced, or nullptr for an empty block.
llvm::Value *CodeGen::VisitBlockStmt(BlockStmt *block)
{
    llvm::Value *latest = nullptr;
    for (const auto &stmt : block->nodeVec)
        latest = stmt->Accept(this);
    return latest;
}
// Lower `if (cond) then-stmt [else else-stmt]`.
//
// Blocks: current -> if_cond -> if_then / (if_else | if_last) -> if_last.
// A fall-through branch is only added when the block being closed is not
// already terminated (e.g. an arm ending in `break` or `continue`). An LLVM
// basic block must end in exactly one terminator. Returns nullptr because a
// statement has no value.
llvm::Value *CodeGen::VisitIfStmt(IfStmt *ifStmt)
{
    llvm::BasicBlock *condBB = llvm::BasicBlock::Create(context, "if_cond", curFunc);
    llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "if_then", curFunc);
    llvm::BasicBlock *elseBB = nullptr;
    if (ifStmt->elseNode)
        elseBB = llvm::BasicBlock::Create(context, "if_else", curFunc);
    llvm::BasicBlock *lastBB = llvm::BasicBlock::Create(context, "if_last", curFunc);

    // Append `br target` unless the current block already ends in a terminator.
    auto branchIfOpen = [this](llvm::BasicBlock *target) {
        if (!irBuilder.GetInsertBlock()->getTerminator())
            irBuilder.CreateBr(target);
    };

    // LLVM never falls through between blocks; the jump into if_cond must be explicit.
    branchIfOpen(condBB);
    irBuilder.SetInsertPoint(condBB);
    llvm::Value *ret = ifStmt->condNode->Accept(this);
    // C semantics: any non-zero condition value is true.
    llvm::Value *condVal = irBuilder.CreateICmpNE(ret, irBuilder.getInt32(0));
    // Without an else arm, a false condition goes straight to if_last.
    irBuilder.CreateCondBr(condVal, thenBB, elseBB ? elseBB : lastBB);

    irBuilder.SetInsertPoint(thenBB);
    ifStmt->thenNode->Accept(this);
    branchIfOpen(lastBB);

    if (elseBB)
    {
        irBuilder.SetInsertPoint(elseBB);
        ifStmt->elseNode->Accept(this);
        branchIfOpen(lastBB);
    }

    irBuilder.SetInsertPoint(lastBB);
    return nullptr;
}
// Lower `for (init; cond; inc) body`.
//
// Blocks: for_init -> for_cond -> (for_body -> for_inc -> for_cond) | for_last.
// Before the body is generated, the loop registers for_last as its `break`
// target and for_inc as its `continue` target. A missing condition loops
// unconditionally. Fall-through branches are skipped when the block is
// already terminated (e.g. the body ends in `break`/`continue`), so each
// block keeps exactly one terminator. Returns nullptr (statement, no value).
llvm::Value *CodeGen::VisitForStmt(ForStmt *forStmt)
{
    llvm::BasicBlock *initBB = llvm::BasicBlock::Create(context, "for_init", curFunc);
    llvm::BasicBlock *condBB = llvm::BasicBlock::Create(context, "for_cond", curFunc);
    llvm::BasicBlock *incBB = llvm::BasicBlock::Create(context, "for_inc", curFunc);
    llvm::BasicBlock *bodyBB = llvm::BasicBlock::Create(context, "for_body", curFunc);
    llvm::BasicBlock *lastBB = llvm::BasicBlock::Create(context, "for_last", curFunc);
    breakBBs.insert({forStmt, lastBB});
    continueBBs.insert({forStmt, incBB});

    // Append `br target` unless the current block already ends in a terminator.
    auto branchIfOpen = [this](llvm::BasicBlock *target) {
        if (!irBuilder.GetInsertBlock()->getTerminator())
            irBuilder.CreateBr(target);
    };

    branchIfOpen(initBB);
    irBuilder.SetInsertPoint(initBB);
    if (forStmt->initNode)
    {
        forStmt->initNode->Accept(this);
    }
    branchIfOpen(condBB);

    irBuilder.SetInsertPoint(condBB);
    if (forStmt->condNode)
    {
        llvm::Value *val = forStmt->condNode->Accept(this);
        // Non-zero means true, as in C.
        llvm::Value *condVal = irBuilder.CreateICmpNE(val, irBuilder.getInt32(0));
        irBuilder.CreateCondBr(condVal, bodyBB, lastBB);
    }
    else
    {
        irBuilder.CreateBr(bodyBB);
    }

    irBuilder.SetInsertPoint(bodyBB);
    if (forStmt->bodyNode)
    {
        forStmt->bodyNode->Accept(this);
    }
    branchIfOpen(incBB);

    irBuilder.SetInsertPoint(incBB);
    if (forStmt->incNode)
    {
        forStmt->incNode->Accept(this);
    }
    branchIfOpen(condBB);

    // Code following the loop is emitted into for_last.
    irBuilder.SetInsertPoint(lastBB);
    return nullptr;
}
// Lower `break`: jump to the exit block (for_last) registered by the target loop.
//
// The branch terminates the current block. Any code that follows the `break`
// (including the fall-through branch an enclosing `if` appends) must not be
// added after that terminator, so it goes into a fresh, unreachable block
// attached to the current function. Returns nullptr.
llvm::Value *CodeGen::VisitBreakStmt(BreakStmt *breakStmt)
{
    // lookup() does not default-insert a null entry the way operator[] does.
    llvm::BasicBlock *bb = breakBBs.lookup(breakStmt->target.get());
    assert(bb && "break target loop has no registered exit block");
    irBuilder.CreateBr(bb);
    llvm::BasicBlock *dead = llvm::BasicBlock::Create(context, "for_break_death", curFunc);
    irBuilder.SetInsertPoint(dead);
    return nullptr;
}
// Lower `continue`: jump to the increment block (for_inc) registered by the target loop.
//
// Code after the `continue` is emitted into a fresh, unreachable block. That
// block is attached to the current function. The original created it without
// a parent, so it was never owned by the function, leaked, and still showed
// up as a predecessor of later blocks. Returns nullptr.
llvm::Value *CodeGen::VisitContinueStmt(ContinueStmt *continueStmt)
{
    // lookup() does not default-insert a null entry the way operator[] does.
    llvm::BasicBlock *bb = continueBBs.lookup(continueStmt->target.get());
    assert(bb && "continue target loop has no registered increment block");
    irBuilder.CreateBr(bb);
    llvm::BasicBlock *dead = llvm::BasicBlock::Create(context, "for_continue_death", curFunc);
    irBuilder.SetInsertPoint(dead);
    return nullptr;
}
// Declare a local variable: reserve a stack slot and record (address, type)
// under the variable's name in varAddrTyMap. Returns the slot's address.
//
// The alloca is placed at the top of the function's entry block rather than
// at the current insertion point. Otherwise a declaration inside a loop body
// would allocate fresh stack on every iteration, and allocas outside the
// entry block are not promoted to registers by mem2reg.
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
    llvm::Type *ty = nullptr;
    if (decl->type == CType::getIntTy())
    {
        ty = irBuilder.getInt32Ty();
    }
    llvm::StringRef text(decl->token.ptr, decl->token.length);

    // Separate builder so the main irBuilder's insertion point is untouched.
    llvm::BasicBlock &entryBB = curFunc->getEntryBlock();
    llvm::IRBuilder<> allocaBuilder(&entryBB, entryBB.begin());
    llvm::Value *varAddr = allocaBuilder.CreateAlloca(ty, nullptr, text);

    varAddrTyMap.insert({text, {varAddr, ty}});
    return varAddr;
}
// Read a variable: load its current value from the stack slot recorded in
// varAddrTyMap and return it as an rvalue.
llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const llvm::StringRef name(varaccExpr->token.ptr, varaccExpr->token.length);
    auto [slot, slotTy] = varAddrTyMap[name];
    // int variables load as i32; anything else uses the type recorded at declaration.
    llvm::Type *loadTy = (varaccExpr->type == CType::getIntTy()) ? irBuilder.getInt32Ty() : slotTy;
    return irBuilder.CreateLoad(loadTy, slot, name);
}
// Lower `identifier = expr` (e.g. `a = 3;`).
//
// Evaluates the right-hand side, stores it into the variable's stack slot,
// and returns a fresh load of the variable. The assignment itself can then be
// used as an rvalue (e.g. `a = b = 3`).
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
    // The grammar only allows an identifier on the left of `=`
    // (assign-expr: identifier "=" expr), so this downcast is expected to hold.
    auto *varAccExpr = static_cast<VariableAccessExpr *>(assignExpr->left.get());
    llvm::StringRef text(varAccExpr->token.ptr, varAccExpr->token.length);
    auto [addr, ty] = varAddrTyMap[text];
    llvm::Value *rValue = assignExpr->right->Accept(this);
    irBuilder.CreateStore(rValue, addr);
    return irBuilder.CreateLoad(ty, addr, text);
}
// Lower a binary expression on two i32 operands.
//
// Arithmetic maps to add/sub/mul (with nsw) and sdiv. Comparisons produce an
// i1, which is widened to an i32 0 or 1, matching C's result of relational
// and equality operators. Returns nullptr for an unhandled opcode.
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    auto left = binaryExpr->left->Accept(this);
    auto right = binaryExpr->right->Accept(this);

    // Zero-extend: sign-extending an i1 would turn "true" into -1 instead of 1.
    auto toInt = [this](llvm::Value *flag) { return irBuilder.CreateZExt(flag, irBuilder.getInt32Ty()); };

    switch (binaryExpr->op)
    {
    case OpCode::add:
        // nsw = "no signed wrap": signed overflow yields poison, mirroring C's
        // undefined behavior for int overflow. It does not prevent overflow.
        return irBuilder.CreateNSWAdd(left, right, "add");
    case OpCode::sub:
        return irBuilder.CreateNSWSub(left, right, "sub");
    case OpCode::mul:
        return irBuilder.CreateNSWMul(left, right, "mul");
    case OpCode::div:
        return irBuilder.CreateSDiv(left, right, "div");
    case OpCode::equal_equal:
        return toInt(irBuilder.CreateICmpEQ(left, right, "equal_equal"));
    case OpCode::not_equal:
        return toInt(irBuilder.CreateICmpNE(left, right, "not_equal"));
    case OpCode::less:
        return toInt(irBuilder.CreateICmpSLT(left, right, "less"));
    case OpCode::less_equal:
        return toInt(irBuilder.CreateICmpSLE(left, right, "less_equal"));
    case OpCode::greater:
        return toInt(irBuilder.CreateICmpSGT(left, right, "greater"));
    case OpCode::greater_equal:
        return toInt(irBuilder.CreateICmpSGE(left, right, "greater_equal"));
    default:
        break;
    }
    return nullptr;
}
// Integer literal: materialize the token's numeric value as an i32 constant.
llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
    const auto &literal = factorExpr->token;
    return irBuilder.getInt32(literal.value);
}
1.6 测试 break 语句和 continue 语句功能
生成 IR 文件
bash
bin/for test/expr.txt > test/expr.ll
生成 IR 的内容
ini
; ModuleID = 'expr'
source_filename = "expr"
@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1
declare i32 @printf(ptr, ...)
define i32 @main() {
entry:
%a = alloca i32, align 4
store i32 3, ptr %a, align 4
%a1 = load i32, ptr %a, align 4
%b = alloca i32, align 4
store i32 5, ptr %b, align 4
%b2 = load i32, ptr %b, align 4
br label %for_init
for_init: ; preds = %entry
%i = alloca i32, align 4
store i32 0, ptr %i, align 4
%i3 = load i32, ptr %i, align 4
br label %for_cond
for_cond: ; preds = %for_inc, %for_init
%i4 = load i32, ptr %i, align 4
%less = icmp slt i32 %i4, 4
%0 = sext i1 %less to i32
%1 = icmp ne i32 %0, 0
br i1 %1, label %for_body, label %for_last
for_inc: ; preds = %if_last11, %if_then
%i17 = load i32, ptr %i, align 4
%add18 = add nsw i32 %i17, 1
store i32 %add18, ptr %i, align 4
%i19 = load i32, ptr %i, align 4
br label %for_cond
for_body: ; preds = %for_cond
br label %if_cond
for_last: ; preds = %if_then10, %for_cond
%a20 = load i32, ptr %a, align 4
%b21 = load i32, ptr %b, align 4
%a22 = load i32, ptr %a, align 4
%add23 = add nsw i32 %b21, %a22
%sub = sub nsw i32 %add23, 3
%mul24 = mul nsw i32 %a20, %sub
%sub25 = sub nsw i32 %mul24, 2
%2 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub25)
ret i32 0
if_cond: ; preds = %for_body
%a5 = load i32, ptr %a, align 4
%greater_equal = icmp sge i32 %a5, 6
%3 = sext i1 %greater_equal to i32
%4 = icmp ne i32 %3, 0
br i1 %4, label %if_then, label %if_last
if_then: ; preds = %if_cond
br label %for_inc
if_last: ; preds = %for_continue_death, %if_cond
%a6 = load i32, ptr %a, align 4
%i7 = load i32, ptr %i, align 4
%add = add nsw i32 %a6, %i7
store i32 %add, ptr %a, align 4
%a8 = load i32, ptr %a, align 4
br label %if_cond9
if_cond9: ; preds = %if_last
%b12 = load i32, ptr %b, align 4
%greater = icmp sgt i32 %b12, 10
%5 = sext i1 %greater to i32
%6 = icmp ne i32 %5, 0
br i1 %6, label %if_then10, label %if_last11
if_then10: ; preds = %if_cond9
br label %for_last
br label %if_last11
if_last11: ; preds = %if_then10, %if_cond9
%b13 = load i32, ptr %b, align 4
%i14 = load i32, ptr %i, align 4
%mul = mul nsw i32 2, %i14
%add15 = add nsw i32 %b13, %mul
store i32 %add15, ptr %b, align 4
%b16 = load i32, ptr %b, align 4
br label %for_inc
}
运行 IR
bash
lli test/expr.ll
运行结果
bash
expr value: 82
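结果正确（循环结束时 a = 6，b = 11，6 × (11 + 6 − 3) − 2 = 82），验证了编译器目前基础的 continue 语句功能的正确性。需要注意：在本测试输入下，i = 2 时 b 先被判断（7 > 10 不成立）后才更新为 11，而 i = 3 时 continue 先于 break 的判断触发，因此 break 分支在运行时并未真正执行，本例只能从生成的 IR（if_then10 跳转到 for_last）上检查 break 的跳转目标；若要在运行时验证 break，需要换用能让 b > 10 在 break 判断处成立的输入。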