Chapter3.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. #include "llvm/ADT/STLExtras.h"
  2. #include "llvm/IR/IRBuilder.h"
  3. #include "llvm/IR/LLVMContext.h"
  4. #include "llvm/IR/Module.h"
  5. #include "llvm/IR/Verifier.h"
  6. #include <cctype>
  7. #include <cstdio>
  8. #include <map>
  9. #include <string>
  10. #include <vector>
  11. using namespace llvm;
  //===----------------------------------------------------------------------===//
  // Lexer
  //===----------------------------------------------------------------------===//

  /// Token - Kinds of token produced by gettok().
  ///
  /// Any character the lexer does not recognise is returned as its own
  /// character value (0-255). The known token kinds below are negative so they
  /// can never collide with a character value.
  enum Token {
    tok_eof = -1, // end of input

    // Commands.
    tok_def = -2,    // the "def" keyword
    tok_extern = -3, // the "extern" keyword

    // Primary expressions.
    tok_identifier = -4, // identifier; its text is left in IdentifierStr
    tok_number = -5,     // numeric literal; its value is left in NumVal
  };
  /// Text of the most recent tok_identifier token (written by gettok()).
  static std::string IdentifierStr;
  /// Value of the most recent tok_number token (written by gettok()).
  static double NumVal = 0.0;
  28. /// gettok - Return the next token from standard input.
  29. static int gettok() {
  30. static int LastChar = ' ';
  31. // Skip any whitespace.
  32. while (isspace(LastChar))
  33. LastChar = getchar();
  34. if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
  35. IdentifierStr = LastChar;
  36. while (isalnum((LastChar = getchar())))
  37. IdentifierStr += LastChar;
  38. if (IdentifierStr == "def")
  39. return tok_def;
  40. if (IdentifierStr == "extern")
  41. return tok_extern;
  42. return tok_identifier;
  43. }
  44. if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
  45. std::string NumStr;
  46. do {
  47. NumStr += LastChar;
  48. LastChar = getchar();
  49. } while (isdigit(LastChar) || LastChar == '.');
  50. NumVal = strtod(NumStr.c_str(), nullptr);
  51. return tok_number;
  52. }
  53. if (LastChar == '#') {
  54. // Comment until end of line.
  55. do
  56. LastChar = getchar();
  57. while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
  58. if (LastChar != EOF)
  59. return gettok();
  60. }
  61. // Check for end of file. Don't eat the EOF.
  62. if (LastChar == EOF)
  63. return tok_eof;
  64. // Otherwise, just return the character as its ascii value.
  65. int ThisChar = LastChar;
  66. LastChar = getchar();
  67. return ThisChar;
  68. }
  69. //===----------------------------------------------------------------------===//
  70. // Abstract Syntax Tree (aka Parse Tree)
  71. //===----------------------------------------------------------------------===//
  72. namespace {
  73. /// ExprAST - Base class for all expression nodes.
  74. class ExprAST {
  75. public:
  76. virtual ~ExprAST() {}
  77. virtual Value *codegen() = 0;
  78. };
  79. /// NumberExprAST - Expression class for numeric literals like "1.0".
  80. class NumberExprAST : public ExprAST {
  81. double Val;
  82. public:
  83. NumberExprAST(double Val) : Val(Val) {}
  84. Value *codegen() override;
  85. };
  86. /// VariableExprAST - Expression class for referencing a variable, like "a".
  87. class VariableExprAST : public ExprAST {
  88. std::string Name;
  89. public:
  90. VariableExprAST(const std::string &Name) : Name(Name) {}
  91. Value *codegen() override;
  92. };
  93. /// BinaryExprAST - Expression class for a binary operator.
  94. class BinaryExprAST : public ExprAST {
  95. char Op;
  96. std::unique_ptr<ExprAST> LHS, RHS;
  97. public:
  98. BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
  99. std::unique_ptr<ExprAST> RHS)
  100. : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
  101. Value *codegen() override;
  102. };
  103. /// CallExprAST - Expression class for function calls.
  104. class CallExprAST : public ExprAST {
  105. std::string Callee;
  106. std::vector<std::unique_ptr<ExprAST>> Args;
  107. public:
  108. CallExprAST(const std::string &Callee,
  109. std::vector<std::unique_ptr<ExprAST>> Args)
  110. : Callee(Callee), Args(std::move(Args)) {}
  111. Value *codegen() override;
  112. };
  113. /// PrototypeAST - This class represents the "prototype" for a function,
  114. /// which captures its name, and its argument names (thus implicitly the number
  115. /// of arguments the function takes).
  116. class PrototypeAST {
  117. std::string Name;
  118. std::vector<std::string> Args;
  119. public:
  120. PrototypeAST(const std::string &Name, std::vector<std::string> Args)
  121. : Name(Name), Args(std::move(Args)) {}
  122. Function *codegen();
  123. const std::string &getName() const { return Name; }
  124. };
  125. /// FunctionAST - This class represents a function definition itself.
  126. class FunctionAST {
  127. std::unique_ptr<PrototypeAST> Proto;
  128. std::unique_ptr<ExprAST> Body;
  129. public:
  130. FunctionAST(std::unique_ptr<PrototypeAST> Proto,
  131. std::unique_ptr<ExprAST> Body)
  132. : Proto(std::move(Proto)), Body(std::move(Body)) {}
  133. Function *codegen();
  134. };
  135. } // end anonymous namespace
  136. //===----------------------------------------------------------------------===//
  137. // Parser
  138. //===----------------------------------------------------------------------===//
  139. /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
  140. /// token the parser is looking at. getNextToken reads another token from the
  141. /// lexer and updates CurTok with its results.
  142. static int CurTok;
  143. static int getNextToken() { return CurTok = gettok(); }
  144. /// BinopPrecedence - This holds the precedence for each binary operator that is
  145. /// defined.
  146. static std::map<char, int> BinopPrecedence;
  147. /// GetTokPrecedence - Get the precedence of the pending binary operator token.
  148. static int GetTokPrecedence() {
  149. if (!isascii(CurTok))
  150. return -1;
  151. // Make sure it's a declared binop.
  152. int TokPrec = BinopPrecedence[CurTok];
  153. if (TokPrec <= 0)
  154. return -1;
  155. return TokPrec;
  156. }
  157. /// Error* - These are little helper functions for error handling.
  158. std::unique_ptr<ExprAST> Error(const char *Str) {
  159. fprintf(stderr, "Error: %s\n", Str);
  160. return nullptr;
  161. }
  162. std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
  163. Error(Str);
  164. return nullptr;
  165. }
  166. static std::unique_ptr<ExprAST> ParseExpression();
  167. /// numberexpr ::= number
  168. static std::unique_ptr<ExprAST> ParseNumberExpr() {
  169. auto Result = llvm::make_unique<NumberExprAST>(NumVal);
  170. getNextToken(); // consume the number
  171. return std::move(Result);
  172. }
  173. /// parenexpr ::= '(' expression ')'
  174. static std::unique_ptr<ExprAST> ParseParenExpr() {
  175. getNextToken(); // eat (.
  176. auto V = ParseExpression();
  177. if (!V)
  178. return nullptr;
  179. if (CurTok != ')')
  180. return Error("expected ')'");
  181. getNextToken(); // eat ).
  182. return V;
  183. }
  184. /// identifierexpr
  185. /// ::= identifier
  186. /// ::= identifier '(' expression* ')'
  187. static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
  188. std::string IdName = IdentifierStr;
  189. getNextToken(); // eat identifier.
  190. if (CurTok != '(') // Simple variable ref.
  191. return llvm::make_unique<VariableExprAST>(IdName);
  192. // Call.
  193. getNextToken(); // eat (
  194. std::vector<std::unique_ptr<ExprAST>> Args;
  195. if (CurTok != ')') {
  196. while (1) {
  197. if (auto Arg = ParseExpression())
  198. Args.push_back(std::move(Arg));
  199. else
  200. return nullptr;
  201. if (CurTok == ')')
  202. break;
  203. if (CurTok != ',')
  204. return Error("Expected ')' or ',' in argument list");
  205. getNextToken();
  206. }
  207. }
  208. // Eat the ')'.
  209. getNextToken();
  210. return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
  211. }
  212. /// primary
  213. /// ::= identifierexpr
  214. /// ::= numberexpr
  215. /// ::= parenexpr
  216. static std::unique_ptr<ExprAST> ParsePrimary() {
  217. switch (CurTok) {
  218. default:
  219. return Error("unknown token when expecting an expression");
  220. case tok_identifier:
  221. return ParseIdentifierExpr();
  222. case tok_number:
  223. return ParseNumberExpr();
  224. case '(':
  225. return ParseParenExpr();
  226. }
  227. }
  228. /// binoprhs
  229. /// ::= ('+' primary)*
  230. static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
  231. std::unique_ptr<ExprAST> LHS) {
  232. // If this is a binop, find its precedence.
  233. while (1) {
  234. int TokPrec = GetTokPrecedence();
  235. // If this is a binop that binds at least as tightly as the current binop,
  236. // consume it, otherwise we are done.
  237. if (TokPrec < ExprPrec)
  238. return LHS;
  239. // Okay, we know this is a binop.
  240. int BinOp = CurTok;
  241. getNextToken(); // eat binop
  242. // Parse the primary expression after the binary operator.
  243. auto RHS = ParsePrimary();
  244. if (!RHS)
  245. return nullptr;
  246. // If BinOp binds less tightly with RHS than the operator after RHS, let
  247. // the pending operator take RHS as its LHS.
  248. int NextPrec = GetTokPrecedence();
  249. if (TokPrec < NextPrec) {
  250. RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
  251. if (!RHS)
  252. return nullptr;
  253. }
  254. // Merge LHS/RHS.
  255. LHS =
  256. llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
  257. }
  258. }
  259. /// expression
  260. /// ::= primary binoprhs
  261. ///
  262. static std::unique_ptr<ExprAST> ParseExpression() {
  263. auto LHS = ParsePrimary();
  264. if (!LHS)
  265. return nullptr;
  266. return ParseBinOpRHS(0, std::move(LHS));
  267. }
  268. /// prototype
  269. /// ::= id '(' id* ')'
  270. static std::unique_ptr<PrototypeAST> ParsePrototype() {
  271. if (CurTok != tok_identifier)
  272. return ErrorP("Expected function name in prototype");
  273. std::string FnName = IdentifierStr;
  274. getNextToken();
  275. if (CurTok != '(')
  276. return ErrorP("Expected '(' in prototype");
  277. std::vector<std::string> ArgNames;
  278. while (getNextToken() == tok_identifier)
  279. ArgNames.push_back(IdentifierStr);
  280. if (CurTok != ')')
  281. return ErrorP("Expected ')' in prototype");
  282. // success.
  283. getNextToken(); // eat ')'.
  284. return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
  285. }
  286. /// definition ::= 'def' prototype expression
  287. static std::unique_ptr<FunctionAST> ParseDefinition() {
  288. getNextToken(); // eat def.
  289. auto Proto = ParsePrototype();
  290. if (!Proto)
  291. return nullptr;
  292. if (auto E = ParseExpression())
  293. return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  294. return nullptr;
  295. }
  296. /// toplevelexpr ::= expression
  297. static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
  298. if (auto E = ParseExpression()) {
  299. // Make an anonymous proto.
  300. auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
  301. std::vector<std::string>());
  302. return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  303. }
  304. return nullptr;
  305. }
  306. /// external ::= 'extern' prototype
  307. static std::unique_ptr<PrototypeAST> ParseExtern() {
  308. getNextToken(); // eat extern.
  309. return ParsePrototype();
  310. }
  311. //===----------------------------------------------------------------------===//
  312. // Code Generation
  313. //===----------------------------------------------------------------------===//
  314. static std::unique_ptr<Module> TheModule;
  315. static IRBuilder<> Builder(getGlobalContext());
  316. static std::map<std::string, Value *> NamedValues;
  317. Value *ErrorV(const char *Str) {
  318. Error(Str);
  319. return nullptr;
  320. }
  321. Value *NumberExprAST::codegen() {
  322. return ConstantFP::get(getGlobalContext(), APFloat(Val));
  323. }
  324. Value *VariableExprAST::codegen() {
  325. // Look this variable up in the function.
  326. Value *V = NamedValues[Name];
  327. if (!V)
  328. return ErrorV("Unknown variable name");
  329. return V;
  330. }
  331. Value *BinaryExprAST::codegen() {
  332. Value *L = LHS->codegen();
  333. Value *R = RHS->codegen();
  334. if (!L || !R)
  335. return nullptr;
  336. switch (Op) {
  337. case '+':
  338. return Builder.CreateFAdd(L, R, "addtmp");
  339. case '-':
  340. return Builder.CreateFSub(L, R, "subtmp");
  341. case '*':
  342. return Builder.CreateFMul(L, R, "multmp");
  343. case '<':
  344. L = Builder.CreateFCmpULT(L, R, "cmptmp");
  345. // Convert bool 0/1 to double 0.0 or 1.0
  346. return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
  347. "booltmp");
  348. default:
  349. return ErrorV("invalid binary operator");
  350. }
  351. }
  352. Value *CallExprAST::codegen() {
  353. // Look up the name in the global module table.
  354. Function *CalleeF = TheModule->getFunction(Callee);
  355. if (!CalleeF)
  356. return ErrorV("Unknown function referenced");
  357. // If argument mismatch error.
  358. if (CalleeF->arg_size() != Args.size())
  359. return ErrorV("Incorrect # arguments passed");
  360. std::vector<Value *> ArgsV;
  361. for (unsigned i = 0, e = Args.size(); i != e; ++i) {
  362. ArgsV.push_back(Args[i]->codegen());
  363. if (!ArgsV.back())
  364. return nullptr;
  365. }
  366. return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
  367. }
  368. Function *PrototypeAST::codegen() {
  369. // Make the function type: double(double,double) etc.
  370. std::vector<Type *> Doubles(Args.size(),
  371. Type::getDoubleTy(getGlobalContext()));
  372. FunctionType *FT =
  373. FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
  374. Function *F =
  375. Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
  376. // Set names for all arguments.
  377. unsigned Idx = 0;
  378. for (auto &Arg : F->args())
  379. Arg.setName(Args[Idx++]);
  380. return F;
  381. }
  382. Function *FunctionAST::codegen() {
  383. // First, check for an existing function from a previous 'extern' declaration.
  384. Function *TheFunction = TheModule->getFunction(Proto->getName());
  385. if (!TheFunction)
  386. TheFunction = Proto->codegen();
  387. if (!TheFunction)
  388. return nullptr;
  389. if (!TheFunction->empty())
  390. return (Function*)ErrorV("Function cannot be redefined.");
  391. // Create a new basic block to start insertion into.
  392. BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
  393. Builder.SetInsertPoint(BB);
  394. // Record the function arguments in the NamedValues map.
  395. NamedValues.clear();
  396. for (auto &Arg : TheFunction->args())
  397. NamedValues[Arg.getName()] = &Arg;
  398. if (Value *RetVal = Body->codegen()) {
  399. // Finish off the function.
  400. Builder.CreateRet(RetVal);
  401. // Validate the generated code, checking for consistency.
  402. verifyFunction(*TheFunction);
  403. return TheFunction;
  404. }
  405. // Error reading body, remove function.
  406. TheFunction->eraseFromParent();
  407. return nullptr;
  408. }
  409. //===----------------------------------------------------------------------===//
  410. // Top-Level parsing and JIT Driver
  411. //===----------------------------------------------------------------------===//
  412. static void HandleDefinition() {
  413. if (auto FnAST = ParseDefinition()) {
  414. if (auto *FnIR = FnAST->codegen()) {
  415. fprintf(stderr, "Read function definition:");
  416. FnIR->dump();
  417. }
  418. } else {
  419. // Skip token for error recovery.
  420. getNextToken();
  421. }
  422. }
  423. static void HandleExtern() {
  424. if (auto ProtoAST = ParseExtern()) {
  425. if (auto *FnIR = ProtoAST->codegen()) {
  426. fprintf(stderr, "Read extern: ");
  427. FnIR->dump();
  428. }
  429. } else {
  430. // Skip token for error recovery.
  431. getNextToken();
  432. }
  433. }
  434. static void HandleTopLevelExpression() {
  435. // Evaluate a top-level expression into an anonymous function.
  436. if (auto FnAST = ParseTopLevelExpr()) {
  437. if (auto *FnIR = FnAST->codegen()) {
  438. fprintf(stderr, "Read top-level expression:");
  439. FnIR->dump();
  440. }
  441. } else {
  442. // Skip token for error recovery.
  443. getNextToken();
  444. }
  445. }
  446. /// top ::= definition | external | expression | ';'
  447. static void MainLoop() {
  448. while (1) {
  449. fprintf(stderr, "ready> ");
  450. switch (CurTok) {
  451. case tok_eof:
  452. return;
  453. case ';': // ignore top-level semicolons.
  454. getNextToken();
  455. break;
  456. case tok_def:
  457. HandleDefinition();
  458. break;
  459. case tok_extern:
  460. HandleExtern();
  461. break;
  462. default:
  463. HandleTopLevelExpression();
  464. break;
  465. }
  466. }
  467. }
  468. //===----------------------------------------------------------------------===//
  469. // Main driver code.
  470. //===----------------------------------------------------------------------===//
  471. int main() {
  472. // Install standard binary operators.
  473. // 1 is lowest precedence.
  474. BinopPrecedence['<'] = 10;
  475. BinopPrecedence['+'] = 20;
  476. BinopPrecedence['-'] = 20;
  477. BinopPrecedence['*'] = 40; // highest.
  478. // Prime the first token.
  479. fprintf(stderr, "ready> ");
  480. getNextToken();
  481. // Make the module, which holds all the code.
  482. TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
  483. // Run the main "interpreter loop" now.
  484. MainLoop();
  485. // Print out all of the generated code.
  486. TheModule->dump();
  487. return 0;
  488. }