// Chapter4.cpp
  1. #include "llvm/ADT/APFloat.h"
  2. #include "llvm/ADT/STLExtras.h"
  3. #include "llvm/IR/BasicBlock.h"
  4. #include "llvm/IR/Constants.h"
  5. #include "llvm/IR/DerivedTypes.h"
  6. #include "llvm/IR/Function.h"
  7. #include "llvm/IR/IRBuilder.h"
  8. #include "llvm/IR/LLVMContext.h"
  9. #include "llvm/IR/LegacyPassManager.h"
  10. #include "llvm/IR/Module.h"
  11. #include "llvm/IR/Type.h"
  12. #include "llvm/IR/Verifier.h"
  13. #include "llvm/Support/TargetSelect.h"
  14. #include "llvm/Target/TargetMachine.h"
  15. #include "llvm/Transforms/Scalar.h"
  16. #include "llvm/Transforms/Scalar/GVN.h"
  17. #include "../include/KaleidoscopeJIT.h"
  18. #include <algorithm>
  19. #include <cassert>
  20. #include <cctype>
  21. #include <cstdint>
  22. #include <cstdio>
  23. #include <cstdlib>
  24. #include <map>
  25. #include <memory>
  26. #include <string>
  27. #include <vector>
  28. using namespace llvm;
  29. using namespace llvm::orc;
  30. //===----------------------------------------------------------------------===//
  31. // Lexer
  32. //===----------------------------------------------------------------------===//
  // Token codes produced by the lexer. Recognized constructs map to one of the
  // negative enumerators below; any other character is returned as its own
  // value in [0-255].
  enum Token {
    // End of input.
    tok_eof = -1,

    // Commands (keywords).
    tok_def = -2,
    tok_extern = -3,

    // Primary expression pieces.
    tok_identifier = -4,
    tok_number = -5,
  };
  44. static std::string IdentifierStr; // Filled in if tok_identifier
  45. static double NumVal; // Filled in if tok_number
  46. /// gettok - Return the next token from standard input.
  47. static int gettok() {
  48. static int LastChar = ' ';
  49. // Skip any whitespace.
  50. while (isspace(LastChar))
  51. LastChar = getchar();
  52. if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
  53. IdentifierStr = LastChar;
  54. while (isalnum((LastChar = getchar())))
  55. IdentifierStr += LastChar;
  56. if (IdentifierStr == "def")
  57. return tok_def;
  58. if (IdentifierStr == "extern")
  59. return tok_extern;
  60. return tok_identifier;
  61. }
  62. if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
  63. std::string NumStr;
  64. do {
  65. NumStr += LastChar;
  66. LastChar = getchar();
  67. } while (isdigit(LastChar) || LastChar == '.');
  68. NumVal = strtod(NumStr.c_str(), nullptr);
  69. return tok_number;
  70. }
  71. if (LastChar == '#') {
  72. // Comment until end of line.
  73. do
  74. LastChar = getchar();
  75. while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
  76. if (LastChar != EOF)
  77. return gettok();
  78. }
  79. // Check for end of file. Don't eat the EOF.
  80. if (LastChar == EOF)
  81. return tok_eof;
  82. // Otherwise, just return the character as its ascii value.
  83. int ThisChar = LastChar;
  84. LastChar = getchar();
  85. return ThisChar;
  86. }
  87. //===----------------------------------------------------------------------===//
  88. // Abstract Syntax Tree (aka Parse Tree)
  89. //===----------------------------------------------------------------------===//
  90. namespace {
  91. /// ExprAST - Base class for all expression nodes.
  92. class ExprAST {
  93. public:
  94. virtual ~ExprAST() = default;
  95. virtual Value *codegen() = 0;
  96. };
  97. /// NumberExprAST - Expression class for numeric literals like "1.0".
  98. class NumberExprAST : public ExprAST {
  99. double Val;
  100. public:
  101. NumberExprAST(double Val) : Val(Val) {}
  102. Value *codegen() override;
  103. };
  104. /// VariableExprAST - Expression class for referencing a variable, like "a".
  105. class VariableExprAST : public ExprAST {
  106. std::string Name;
  107. public:
  108. VariableExprAST(const std::string &Name) : Name(Name) {}
  109. Value *codegen() override;
  110. };
  111. /// BinaryExprAST - Expression class for a binary operator.
  112. class BinaryExprAST : public ExprAST {
  113. char Op;
  114. std::unique_ptr<ExprAST> LHS, RHS;
  115. public:
  116. BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
  117. std::unique_ptr<ExprAST> RHS)
  118. : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
  119. Value *codegen() override;
  120. };
  121. /// CallExprAST - Expression class for function calls.
  122. class CallExprAST : public ExprAST {
  123. std::string Callee;
  124. std::vector<std::unique_ptr<ExprAST>> Args;
  125. public:
  126. CallExprAST(const std::string &Callee,
  127. std::vector<std::unique_ptr<ExprAST>> Args)
  128. : Callee(Callee), Args(std::move(Args)) {}
  129. Value *codegen() override;
  130. };
  131. /// PrototypeAST - This class represents the "prototype" for a function,
  132. /// which captures its name, and its argument names (thus implicitly the number
  133. /// of arguments the function takes).
  134. class PrototypeAST {
  135. std::string Name;
  136. std::vector<std::string> Args;
  137. public:
  138. PrototypeAST(const std::string &Name, std::vector<std::string> Args)
  139. : Name(Name), Args(std::move(Args)) {}
  140. Function *codegen();
  141. const std::string &getName() const { return Name; }
  142. };
  143. /// FunctionAST - This class represents a function definition itself.
  144. class FunctionAST {
  145. std::unique_ptr<PrototypeAST> Proto;
  146. std::unique_ptr<ExprAST> Body;
  147. public:
  148. FunctionAST(std::unique_ptr<PrototypeAST> Proto,
  149. std::unique_ptr<ExprAST> Body)
  150. : Proto(std::move(Proto)), Body(std::move(Body)) {}
  151. Function *codegen();
  152. };
  153. } // end anonymous namespace
  154. //===----------------------------------------------------------------------===//
  155. // Parser
  156. //===----------------------------------------------------------------------===//
  157. /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
  158. /// token the parser is looking at. getNextToken reads another token from the
  159. /// lexer and updates CurTok with its results.
  160. static int CurTok;
  161. static int getNextToken() { return CurTok = gettok(); }
  162. /// BinopPrecedence - This holds the precedence for each binary operator that is
  163. /// defined.
  164. static std::map<char, int> BinopPrecedence;
  165. /// GetTokPrecedence - Get the precedence of the pending binary operator token.
  166. static int GetTokPrecedence() {
  167. if (!isascii(CurTok))
  168. return -1;
  169. // Make sure it's a declared binop.
  170. int TokPrec = BinopPrecedence[CurTok];
  171. if (TokPrec <= 0)
  172. return -1;
  173. return TokPrec;
  174. }
  175. /// LogError* - These are little helper functions for error handling.
  176. std::unique_ptr<ExprAST> LogError(const char *Str) {
  177. fprintf(stderr, "Error: %s\n", Str);
  178. return nullptr;
  179. }
  180. std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
  181. LogError(Str);
  182. return nullptr;
  183. }
  184. static std::unique_ptr<ExprAST> ParseExpression();
  185. /// numberexpr ::= number
  186. static std::unique_ptr<ExprAST> ParseNumberExpr() {
  187. auto Result = llvm::make_unique<NumberExprAST>(NumVal);
  188. getNextToken(); // consume the number
  189. return std::move(Result);
  190. }
  191. /// parenexpr ::= '(' expression ')'
  192. static std::unique_ptr<ExprAST> ParseParenExpr() {
  193. getNextToken(); // eat (.
  194. auto V = ParseExpression();
  195. if (!V)
  196. return nullptr;
  197. if (CurTok != ')')
  198. return LogError("expected ')'");
  199. getNextToken(); // eat ).
  200. return V;
  201. }
  202. /// identifierexpr
  203. /// ::= identifier
  204. /// ::= identifier '(' expression* ')'
  205. static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
  206. std::string IdName = IdentifierStr;
  207. getNextToken(); // eat identifier.
  208. if (CurTok != '(') // Simple variable ref.
  209. return llvm::make_unique<VariableExprAST>(IdName);
  210. // Call.
  211. getNextToken(); // eat (
  212. std::vector<std::unique_ptr<ExprAST>> Args;
  213. if (CurTok != ')') {
  214. while (true) {
  215. if (auto Arg = ParseExpression())
  216. Args.push_back(std::move(Arg));
  217. else
  218. return nullptr;
  219. if (CurTok == ')')
  220. break;
  221. if (CurTok != ',')
  222. return LogError("Expected ')' or ',' in argument list");
  223. getNextToken();
  224. }
  225. }
  226. // Eat the ')'.
  227. getNextToken();
  228. return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
  229. }
  230. /// primary
  231. /// ::= identifierexpr
  232. /// ::= numberexpr
  233. /// ::= parenexpr
  234. static std::unique_ptr<ExprAST> ParsePrimary() {
  235. switch (CurTok) {
  236. default:
  237. return LogError("unknown token when expecting an expression");
  238. case tok_identifier:
  239. return ParseIdentifierExpr();
  240. case tok_number:
  241. return ParseNumberExpr();
  242. case '(':
  243. return ParseParenExpr();
  244. }
  245. }
  246. /// binoprhs
  247. /// ::= ('+' primary)*
  248. static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
  249. std::unique_ptr<ExprAST> LHS) {
  250. // If this is a binop, find its precedence.
  251. while (true) {
  252. int TokPrec = GetTokPrecedence();
  253. // If this is a binop that binds at least as tightly as the current binop,
  254. // consume it, otherwise we are done.
  255. if (TokPrec < ExprPrec)
  256. return LHS;
  257. // Okay, we know this is a binop.
  258. int BinOp = CurTok;
  259. getNextToken(); // eat binop
  260. // Parse the primary expression after the binary operator.
  261. auto RHS = ParsePrimary();
  262. if (!RHS)
  263. return nullptr;
  264. // If BinOp binds less tightly with RHS than the operator after RHS, let
  265. // the pending operator take RHS as its LHS.
  266. int NextPrec = GetTokPrecedence();
  267. if (TokPrec < NextPrec) {
  268. RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
  269. if (!RHS)
  270. return nullptr;
  271. }
  272. // Merge LHS/RHS.
  273. LHS =
  274. llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
  275. }
  276. }
  277. /// expression
  278. /// ::= primary binoprhs
  279. ///
  280. static std::unique_ptr<ExprAST> ParseExpression() {
  281. auto LHS = ParsePrimary();
  282. if (!LHS)
  283. return nullptr;
  284. return ParseBinOpRHS(0, std::move(LHS));
  285. }
  286. /// prototype
  287. /// ::= id '(' id* ')'
  288. static std::unique_ptr<PrototypeAST> ParsePrototype() {
  289. if (CurTok != tok_identifier)
  290. return LogErrorP("Expected function name in prototype");
  291. std::string FnName = IdentifierStr;
  292. getNextToken();
  293. if (CurTok != '(')
  294. return LogErrorP("Expected '(' in prototype");
  295. std::vector<std::string> ArgNames;
  296. while (getNextToken() == tok_identifier)
  297. ArgNames.push_back(IdentifierStr);
  298. if (CurTok != ')')
  299. return LogErrorP("Expected ')' in prototype");
  300. // success.
  301. getNextToken(); // eat ')'.
  302. return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
  303. }
  304. /// definition ::= 'def' prototype expression
  305. static std::unique_ptr<FunctionAST> ParseDefinition() {
  306. getNextToken(); // eat def.
  307. auto Proto = ParsePrototype();
  308. if (!Proto)
  309. return nullptr;
  310. if (auto E = ParseExpression())
  311. return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  312. return nullptr;
  313. }
  314. /// toplevelexpr ::= expression
  315. static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
  316. if (auto E = ParseExpression()) {
  317. // Make an anonymous proto.
  318. auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
  319. std::vector<std::string>());
  320. return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  321. }
  322. return nullptr;
  323. }
  324. /// external ::= 'extern' prototype
  325. static std::unique_ptr<PrototypeAST> ParseExtern() {
  326. getNextToken(); // eat extern.
  327. return ParsePrototype();
  328. }
  329. //===----------------------------------------------------------------------===//
  330. // Code Generation
  331. //===----------------------------------------------------------------------===//
  332. static LLVMContext TheContext;
  333. static IRBuilder<> Builder(TheContext);
  334. static std::unique_ptr<Module> TheModule;
  335. static std::map<std::string, Value *> NamedValues;
  336. static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
  337. static std::unique_ptr<KaleidoscopeJIT> TheJIT;
  338. static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
  339. Value *LogErrorV(const char *Str) {
  340. LogError(Str);
  341. return nullptr;
  342. }
  343. Function *getFunction(std::string Name) {
  344. // First, see if the function has already been added to the current module.
  345. if (auto *F = TheModule->getFunction(Name))
  346. return F;
  347. // If not, check whether we can codegen the declaration from some existing
  348. // prototype.
  349. auto FI = FunctionProtos.find(Name);
  350. if (FI != FunctionProtos.end())
  351. return FI->second->codegen();
  352. // If no existing prototype exists, return null.
  353. return nullptr;
  354. }
  355. Value *NumberExprAST::codegen() {
  356. return ConstantFP::get(TheContext, APFloat(Val));
  357. }
  358. Value *VariableExprAST::codegen() {
  359. // Look this variable up in the function.
  360. Value *V = NamedValues[Name];
  361. if (!V)
  362. return LogErrorV("Unknown variable name");
  363. return V;
  364. }
  365. Value *BinaryExprAST::codegen() {
  366. Value *L = LHS->codegen();
  367. Value *R = RHS->codegen();
  368. if (!L || !R)
  369. return nullptr;
  370. switch (Op) {
  371. case '+':
  372. return Builder.CreateFAdd(L, R, "addtmp");
  373. case '-':
  374. return Builder.CreateFSub(L, R, "subtmp");
  375. case '*':
  376. return Builder.CreateFMul(L, R, "multmp");
  377. case '<':
  378. L = Builder.CreateFCmpULT(L, R, "cmptmp");
  379. // Convert bool 0/1 to double 0.0 or 1.0
  380. return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
  381. default:
  382. return LogErrorV("invalid binary operator");
  383. }
  384. }
  385. Value *CallExprAST::codegen() {
  386. // Look up the name in the global module table.
  387. Function *CalleeF = getFunction(Callee);
  388. if (!CalleeF)
  389. return LogErrorV("Unknown function referenced");
  390. // If argument mismatch error.
  391. if (CalleeF->arg_size() != Args.size())
  392. return LogErrorV("Incorrect # arguments passed");
  393. std::vector<Value *> ArgsV;
  394. for (unsigned i = 0, e = Args.size(); i != e; ++i) {
  395. ArgsV.push_back(Args[i]->codegen());
  396. if (!ArgsV.back())
  397. return nullptr;
  398. }
  399. return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
  400. }
  401. Function *PrototypeAST::codegen() {
  402. // Make the function type: double(double,double) etc.
  403. std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
  404. FunctionType *FT =
  405. FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
  406. Function *F =
  407. Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
  408. // Set names for all arguments.
  409. unsigned Idx = 0;
  410. for (auto &Arg : F->args())
  411. Arg.setName(Args[Idx++]);
  412. return F;
  413. }
  414. Function *FunctionAST::codegen() {
  415. // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  416. // reference to it for use below.
  417. auto &P = *Proto;
  418. FunctionProtos[Proto->getName()] = std::move(Proto);
  419. Function *TheFunction = getFunction(P.getName());
  420. if (!TheFunction)
  421. return nullptr;
  422. // Create a new basic block to start insertion into.
  423. BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
  424. Builder.SetInsertPoint(BB);
  425. // Record the function arguments in the NamedValues map.
  426. NamedValues.clear();
  427. for (auto &Arg : TheFunction->args())
  428. NamedValues[Arg.getName()] = &Arg;
  429. if (Value *RetVal = Body->codegen()) {
  430. // Finish off the function.
  431. Builder.CreateRet(RetVal);
  432. // Validate the generated code, checking for consistency.
  433. verifyFunction(*TheFunction);
  434. // Run the optimizer on the function.
  435. TheFPM->run(*TheFunction);
  436. return TheFunction;
  437. }
  438. // Error reading body, remove function.
  439. TheFunction->eraseFromParent();
  440. return nullptr;
  441. }
  442. //===----------------------------------------------------------------------===//
  443. // Top-Level parsing and JIT Driver
  444. //===----------------------------------------------------------------------===//
  445. static void InitializeModuleAndPassManager() {
  446. // Open a new module.
  447. TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
  448. TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
  449. // Create a new pass manager attached to it.
  450. TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
  451. // Do simple "peephole" optimizations and bit-twiddling optzns.
  452. TheFPM->add(createInstructionCombiningPass());
  453. // Reassociate expressions.
  454. TheFPM->add(createReassociatePass());
  455. // Eliminate Common SubExpressions.
  456. TheFPM->add(createGVNPass());
  457. // Simplify the control flow graph (deleting unreachable blocks, etc).
  458. TheFPM->add(createCFGSimplificationPass());
  459. TheFPM->doInitialization();
  460. }
  461. static void HandleDefinition() {
  462. if (auto FnAST = ParseDefinition()) {
  463. if (auto *FnIR = FnAST->codegen()) {
  464. fprintf(stderr, "Read function definition:");
  465. FnIR->dump();
  466. TheJIT->addModule(std::move(TheModule));
  467. InitializeModuleAndPassManager();
  468. }
  469. } else {
  470. // Skip token for error recovery.
  471. getNextToken();
  472. }
  473. }
  474. static void HandleExtern() {
  475. if (auto ProtoAST = ParseExtern()) {
  476. if (auto *FnIR = ProtoAST->codegen()) {
  477. fprintf(stderr, "Read extern: ");
  478. FnIR->dump();
  479. FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
  480. }
  481. } else {
  482. // Skip token for error recovery.
  483. getNextToken();
  484. }
  485. }
  486. static void HandleTopLevelExpression() {
  487. // Evaluate a top-level expression into an anonymous function.
  488. if (auto FnAST = ParseTopLevelExpr()) {
  489. if (FnAST->codegen()) {
  490. // JIT the module containing the anonymous expression, keeping a handle so
  491. // we can free it later.
  492. auto H = TheJIT->addModule(std::move(TheModule));
  493. InitializeModuleAndPassManager();
  494. // Search the JIT for the __anon_expr symbol.
  495. auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
  496. assert(ExprSymbol && "Function not found");
  497. // Get the symbol's address and cast it to the right type (takes no
  498. // arguments, returns a double) so we can call it as a native function.
  499. double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
  500. fprintf(stderr, "Evaluated to %f\n", FP());
  501. // Delete the anonymous expression module from the JIT.
  502. TheJIT->removeModule(H);
  503. }
  504. } else {
  505. // Skip token for error recovery.
  506. getNextToken();
  507. }
  508. }
  509. /// top ::= definition | external | expression | ';'
  510. static void MainLoop() {
  511. while (true) {
  512. fprintf(stderr, "ready> ");
  513. switch (CurTok) {
  514. case tok_eof:
  515. return;
  516. case ';': // ignore top-level semicolons.
  517. getNextToken();
  518. break;
  519. case tok_def:
  520. HandleDefinition();
  521. break;
  522. case tok_extern:
  523. HandleExtern();
  524. break;
  525. default:
  526. HandleTopLevelExpression();
  527. break;
  528. }
  529. }
  530. }
  531. //===----------------------------------------------------------------------===//
  532. // "Library" functions that can be "extern'd" from user code.
  533. //===----------------------------------------------------------------------===//
  /// putchard - Write X, converted to char, to stderr. Always returns 0.
  extern "C" double putchard(double X) {
    const char Ch = static_cast<char>(X);
    fputc(Ch, stderr);
    return 0;
  }
  /// printd - Print X to stderr as "%f" followed by a newline. Always
  /// returns 0.
  extern "C" double printd(double X) {
    fprintf(stderr, "%f\n", X);
    const double Result = 0;
    return Result;
  }
  544. //===----------------------------------------------------------------------===//
  545. // Main driver code.
  546. //===----------------------------------------------------------------------===//
  547. int main() {
  548. InitializeNativeTarget();
  549. InitializeNativeTargetAsmPrinter();
  550. InitializeNativeTargetAsmParser();
  551. // Install standard binary operators.
  552. // 1 is lowest precedence.
  553. BinopPrecedence['<'] = 10;
  554. BinopPrecedence['+'] = 20;
  555. BinopPrecedence['-'] = 20;
  556. BinopPrecedence['*'] = 40; // highest.
  557. // Prime the first token.
  558. fprintf(stderr, "ready> ");
  559. getNextToken();
  560. TheJIT = llvm::make_unique<KaleidoscopeJIT>();
  561. InitializeModuleAndPassManager();
  562. // Run the main "interpreter loop" now.
  563. MainLoop();
  564. return 0;
  565. }