Pārlūkot izejas kodu

Chapter 4 fully functioning

Andrea Gussoni 8 gadi atpakaļ
vecāks
revīzija
9f1bb90d25
7 mainītis faili ar 265 papildinājumiem un 15 dzēšanām
  1. 12 9
      source/Ast.cpp
  2. 2 2
      source/CMakeLists.txt
  3. 54 0
      source/JIT.cpp
  4. 26 0
      source/JIT.h
  5. 114 0
      source/KaleidoscopeJIT.h
  6. 28 0
      source/Main.cpp
  7. 29 4
      source/Parser.cpp

+ 12 - 9
source/Ast.cpp

@@ -4,6 +4,9 @@
 // Local includes
 #include "Ast.h"
 #include "Parser.h"
+#include "JIT.h"
+
+using namespace jit;
 
 namespace ast{
 // attributes added from chapter 3 of the tutorial
@@ -54,7 +57,7 @@ Value *BinaryExprAST::codegen() {
 
 Value *CallExprAST::codegen() {
   // Look up the name in the global module table.
-  Function *CalleeF = AstObjects::TheModule->getFunction(Callee);
+  Function *CalleeF = getFunction(Callee);
   if (!CalleeF)
   return ErrorV("Unknown function referenced");
 
@@ -90,14 +93,11 @@ Function *PrototypeAST::codegen() {
 }
 
 Function *FunctionAST::codegen() {
-
-  // First, check for an existing function from a previous 'extern' declaration.
-  Function *TheFunction =
-      AstObjects::TheModule->getFunction(Proto->getName());
-
-  if (!TheFunction)
-    TheFunction = Proto->codegen();
-
+  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
+  // reference to it for use below.
+  auto &P = *Proto;
+  JITObjects::FunctionProtos[Proto->getName()] = std::move(Proto);
+  Function *TheFunction = getFunction(P.getName());
   if (!TheFunction)
     return nullptr;
 
@@ -119,6 +119,9 @@ Function *FunctionAST::codegen() {
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
+    // Optimize the function.
+    JITObjects::TheFPM->run(*TheFunction);
+
     return TheFunction;
   }
 

+ 2 - 2
source/CMakeLists.txt

@@ -17,11 +17,11 @@ set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS} ${LLVM_CXX_FLAGS})
 include_directories(${LLVM_INCLUDE_DIRS})
 add_definitions(${LLVM_DEFINITIONS})
 
-add_executable(kaleidoscope Main.cpp Lexer.cpp Parser.cpp Ast.cpp)
+add_executable(kaleidoscope Main.cpp Lexer.cpp Parser.cpp Ast.cpp JIT.cpp)
 
 # Find the libraries that correspond to the LLVM components
 # that we wish to use
-llvm_map_components_to_libnames(llvm_libs core support native mcjit)
+llvm_map_components_to_libnames(llvm_libs core support native mcjit object scalaropts instcombine RuntimeDyld ExecutionEngine)
 
 # Link against LLVM libraries
 target_link_libraries(kaleidoscope ${llvm_libs})

+ 54 - 0
source/JIT.cpp

@@ -0,0 +1,54 @@
+//LLVM includes
+#include "llvm/Analysis/Passes.h"
+
+// Local includes
+#include "JIT.h"
+
+using namespace ast;
+using namespace llvm;
+using namespace llvm::orc;
+
+namespace jit{
+
+std::unique_ptr<llvm::legacy::FunctionPassManager> JITObjects::TheFPM =
+    std::make_unique<llvm::legacy::FunctionPassManager>(AstObjects::TheModule.get());
+std::unique_ptr<llvm::orc::KaleidoscopeJIT> JITObjects::TheJIT =
+    std::unique_ptr<llvm::orc::KaleidoscopeJIT>(nullptr);
+std::map<std::string, std::unique_ptr<ast::PrototypeAST>> JITObjects::FunctionProtos{};
+
+void InitializeModuleAndPassManager(void) {
+  // Open a new module.
+  AstObjects::TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
+  AstObjects::TheModule->setDataLayout(JITObjects::TheJIT->getTargetMachine().createDataLayout());
+
+  // Create a new pass manager attached to it.
+  JITObjects::TheFPM = llvm::make_unique<legacy::FunctionPassManager>(AstObjects::TheModule.get());
+
+  // Do simple "peephole" optimizations and bit-twiddling optzns.
+  JITObjects::TheFPM->add(createInstructionCombiningPass());
+  // Reassociate expressions.
+  JITObjects::TheFPM->add(createReassociatePass());
+  // Eliminate Common SubExpressions.
+  JITObjects::TheFPM->add(createGVNPass());
+  // Simplify the control flow graph (deleting unreachable blocks, etc).
+  JITObjects::TheFPM->add(createCFGSimplificationPass());
+
+  JITObjects::TheFPM->doInitialization();
+}
+
+Function *getFunction(std::string Name) {
+  // First, see if the function has already been added to the current module.
+  if (auto *F = AstObjects::TheModule->getFunction(Name))
+    return F;
+
+  // If not, check whether we can codegen the declaration from some existing
+  // prototype.
+  auto FI = JITObjects::FunctionProtos.find(Name);
+  if (FI != JITObjects::FunctionProtos.end())
+    return FI->second->codegen();
+
+  // If no existing prototype exists, return null.
+  return nullptr;
+}
+
+}

+ 26 - 0
source/JIT.h

@@ -0,0 +1,26 @@
+#ifndef _JIT_H
+#define _JIT_H
+
+//LLVM includes
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/Transforms/Scalar.h"
+
+// Local includes
+#include "KaleidoscopeJIT.h"
+#include "Ast.h"
+
+namespace jit {
+
+struct JITObjects {
+public:
+  static std::unique_ptr<llvm::legacy::FunctionPassManager> TheFPM;
+  static std::unique_ptr<llvm::orc::KaleidoscopeJIT> TheJIT;
+  static std::map<std::string, std::unique_ptr<ast::PrototypeAST>> FunctionProtos;
+};
+
+void InitializeModuleAndPassManager(void);
+
+Function *getFunction(std::string Name);
+
+}
+#endif

+ 114 - 0
source/KaleidoscopeJIT.h

@@ -0,0 +1,114 @@
+//===----- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope ----*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Contains a simple JIT definition for use in the kaleidoscope tutorials.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
+#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H
+
+#include "llvm/ExecutionEngine/ExecutionEngine.h"
+#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
+#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
+#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
+#include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "llvm/IR/Mangler.h"
+#include "llvm/Support/DynamicLibrary.h"
+
+namespace llvm {
+namespace orc {
+
+class KaleidoscopeJIT {
+public:
+  typedef ObjectLinkingLayer<> ObjLayerT;
+  typedef IRCompileLayer<ObjLayerT> CompileLayerT;
+  typedef CompileLayerT::ModuleSetHandleT ModuleHandleT;
+
+  KaleidoscopeJIT()
+      : TM(EngineBuilder().selectTarget()), DL(TM->createDataLayout()),
+        CompileLayer(ObjectLayer, SimpleCompiler(*TM)) {
+    llvm::sys::DynamicLibrary::LoadLibraryPermanently(nullptr);
+  }
+
+  TargetMachine &getTargetMachine() { return *TM; }
+
+  ModuleHandleT addModule(std::unique_ptr<Module> M) {
+    // We need a memory manager to allocate memory and resolve symbols for this
+    // new module. Create one that resolves symbols by looking back into the
+    // JIT.
+    auto Resolver = createLambdaResolver(
+        [&](const std::string &Name) {
+          if (auto Sym = findMangledSymbol(Name))
+            return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+          return RuntimeDyld::SymbolInfo(nullptr);
+        },
+        [](const std::string &S) { return nullptr; });
+    auto H = CompileLayer.addModuleSet(singletonSet(std::move(M)),
+                                       make_unique<SectionMemoryManager>(),
+                                       std::move(Resolver));
+
+    ModuleHandles.push_back(H);
+    return H;
+  }
+
+  void removeModule(ModuleHandleT H) {
+    ModuleHandles.erase(
+        std::find(ModuleHandles.begin(), ModuleHandles.end(), H));
+    CompileLayer.removeModuleSet(H);
+  }
+
+  JITSymbol findSymbol(const std::string Name) {
+    return findMangledSymbol(mangle(Name));
+  }
+
+private:
+
+  std::string mangle(const std::string &Name) {
+    std::string MangledName;
+    {
+      raw_string_ostream MangledNameStream(MangledName);
+      Mangler::getNameWithPrefix(MangledNameStream, Name, DL);
+    }
+    return MangledName;
+  }
+
+  template <typename T> static std::vector<T> singletonSet(T t) {
+    std::vector<T> Vec;
+    Vec.push_back(std::move(t));
+    return Vec;
+  }
+
+  JITSymbol findMangledSymbol(const std::string &Name) {
+    // Search modules in reverse order: from last added to first added.
+    // This is the opposite of the usual search order for dlsym, but makes more
+    // sense in a REPL where we want to bind to the newest available definition.
+    for (auto H : make_range(ModuleHandles.rbegin(), ModuleHandles.rend()))
+      if (auto Sym = CompileLayer.findSymbolIn(H, Name, true))
+        return Sym;
+
+    // If we can't find the symbol in the JIT, try looking in the host process.
+    if (auto SymAddr = RTDyldMemoryManager::getSymbolAddressInProcess(Name))
+      return JITSymbol(SymAddr, JITSymbolFlags::Exported);
+
+    return nullptr;
+  }
+
+  std::unique_ptr<TargetMachine> TM;
+  const DataLayout DL;
+  ObjLayerT ObjectLayer;
+  CompileLayerT CompileLayer;
+  std::vector<ModuleHandleT> ModuleHandles;
+};
+
+} // End namespace orc.
+} // End namespace llvm
+
+#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H

+ 28 - 0
source/Main.cpp

@@ -1,11 +1,17 @@
 // Standard includes
 #include <cstdio>
 
+//LLVM includes
+#include "llvm/Support/TargetSelect.h"
+#include "KaleidoscopeJIT.h"
+
 // Local includes
 #include "Ast.h"
 #include "Parser.h"
+#include "JIT.h"
 
 using namespace parser;
+using namespace jit;
 
 namespace helper {
 // Cloning make_unique here until it's standard in C++14.
@@ -19,12 +25,34 @@ static
 }
 }
 
+//===----------------------------------------------------------------------===//
+// "Library" functions that can be "extern'd" from user code.
+//===----------------------------------------------------------------------===//
+
+/// putchard - putchar that takes a double and returns 0.
+extern "C" double putchard(double X) {
+  fputc((char)X, stderr);
+  return 0;
+}
+
+/// printd - printf that takes a double prints it as "%f\n", returning 0.
+extern "C" double printd(double X) {
+  fprintf(stderr, "%f\n", X);
+  return 0;
+}
 
 //===----------------------------------------------------------------------===//
 // Main driver code.
 //===----------------------------------------------------------------------===//
 
 int main() {
+  InitializeNativeTarget();
+  InitializeNativeTargetAsmPrinter();
+  InitializeNativeTargetAsmParser();
+
+  JITObjects::TheJIT = llvm::make_unique<llvm::orc::KaleidoscopeJIT>();
+  jit::InitializeModuleAndPassManager();
+
   // Prime the first token.
   fprintf(stderr, "ready> ");
   getNextToken();

+ 29 - 4
source/Parser.cpp

@@ -1,12 +1,19 @@
 // Standard includes
+#include <cassert>
+#include <cstdint>
 #include <cstdio>
 #include <stdexcept>
 
+//LLVM includes
+
 // Local includes
 #include "Parser.h"
 #include "Lexer.h"
+#include "JIT.h"
 
 using namespace lexer;
+using namespace jit;
+
 namespace parser{
 //===----------------------------------------------------------------------===//
 // Parser
@@ -227,10 +234,12 @@ static void HandleDefinition() {
     if (auto *FnIR = FnAST->codegen()) {
       fprintf(stderr, "Read function definition:");
       FnIR->dump();
+      JITObjects::TheJIT->addModule(std::move(AstObjects::TheModule));
+      InitializeModuleAndPassManager();
     }
   } else {
     // Skip token for error recovery.
-    getNextToken();
+     getNextToken();
   }
 }
 
@@ -239,6 +248,7 @@ static void HandleExtern() {
     if (auto *FnIR = ProtoAST->codegen()) {
       fprintf(stderr, "Read extern: ");
       FnIR->dump();
+      JITObjects::FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
     }
   } else {
     // Skip token for error recovery.
@@ -249,9 +259,24 @@ static void HandleExtern() {
 static void HandleTopLevelExpression() {
   // Evaluate a top-level expression into an anonymous function.
   if (auto FnAST = ParseTopLevelExpr()) {
-    if (auto *FnIR = FnAST->codegen()) {
-      fprintf(stderr, "Read top-level expression:");
-      FnIR->dump();
+    if (FnAST->codegen()) {
+
+      // JIT the module containing the anonymous expression, keeping a handle so
+      // we can free it later.
+      auto H = JITObjects::TheJIT->addModule(std::move(AstObjects::TheModule));
+      InitializeModuleAndPassManager();
+
+      // Search the JIT for the __anon_expr symbol.
+      auto ExprSymbol = JITObjects::TheJIT->findSymbol("__anon_expr");
+      assert(ExprSymbol && "Function not found");
+
+      // Get the symbol's address and cast it to the right type (takes no
+      // arguments, returns a double) so we can call it as a native function.
+      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
+      fprintf(stderr, "Evaluated to %f\n", FP());
+
+      // Delete the anonymous expression module from the JIT.
+      JITObjects::TheJIT->removeModule(H);
     }
   } else {
     // Skip token for error recovery.