From 646e3e3ebcb379143edbd7826787b67439f8a9f0 Mon Sep 17 00:00:00 2001 From: hjw Date: Fri, 18 Nov 2022 09:20:55 +0800 Subject: [PATCH 1/2] Implement code generation --- Makefile | 8 +- include/basic.hpp | 67 +++++++++++++---- include/laxerAST.hpp | 5 -- include/parserASTLY.hpp | 6 ++ src/basic.cpp | 145 ++++++++++++++++++++++++++++++++++++ src/laxerAST.cpp | 12 --- src/main.cpp | 21 ++++-- src/mathfuncRewrite.cpp | 18 ++--- src/parserASTLY.cpp | 158 ++++++++++++++++++++++++++++++++++++++++ 9 files changed, 392 insertions(+), 48 deletions(-) diff --git a/Makefile b/Makefile index 0a53543..f7e1b40 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,14 @@ export PROJECT_NAME = exprAuto CC = gcc -CPP = g++ -INCLUDE = -Iinclude -I/usr/include/python3.8 +CPP = clang++ +INCLUDE = -Iinclude -I/usr/include/python3.8 `llvm-config --cxxflags | sed -r 's@(-I[a-zA-Z0-9/]+).*@\1@g'` # LIBS= -L/usr/lib/python3.8/config-3.8-x86_64-linux-gnu -lpython3.8 # may need -L to assign the Python lobrary path -LIBS = -lpython3.8 -lfmt +LIBS = `llvm-config --ldflags --system-libs --libs core` -lpython3.8 -lfmt ECHO = printf # $(info $(CFLAGS) ) -override CFLAGS += -g -Wall -Wextra -Wpedantic -Wno-unused-function -fdiagnostics-color=always +override CFLAGS += -g -Wall -Wextra -Wpedantic -Wno-unused-function -Wno-unused-parameter -fdiagnostics-color=always `llvm-config --cxxflags | sed -r 's@(-I[a-zA-Z0-9/]+) *@@g' | sed "s@-std=c++[0-9]* *@@g" | sed "s@-fno-rtti *@@g"` # $(info $(CFLAGS) ) EXPRAUTO_ALL_SRCS_CPP = $(wildcard src/*.cpp) diff --git a/include/basic.hpp b/include/basic.hpp index fbe60ea..c2fade1 100644 --- a/include/basic.hpp +++ b/include/basic.hpp @@ -13,6 +13,19 @@ #include #include +// for codegen +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" + using std::string; using std::vector; using std::ofstream; @@ -44,6 +57,8 @@ public: virtual string type() { return "Base"; } virtual std::unique_ptr Clone() { return makePtr(); } + + virtual llvm::Value *codegen(); }; /// NumberExprAST - Expression class for numeric literals like "1.0". @@ -54,13 +69,15 @@ class NumberExprAST : public ExprAST public: NumberExprAST(double Val) : Val(Val) {} - void printExpr() { fprintf(stderr, "num = %f\n", Val); } + void printExpr() override { fprintf(stderr, "num = %f\n", Val); } - string type() { return "Number"; } + string type() override { return "Number"; } double getNumber() { return Val; } - std::unique_ptr Clone() { return makePtr(Val); } + std::unique_ptr Clone() override { return makePtr(Val); } + + llvm::Value *codegen() override; }; /// VariableExprAST - Expression class for referencing a variable, like "a". @@ -71,13 +88,15 @@ class VariableExprAST : public ExprAST public: VariableExprAST(const string &Name) : Name(Name) {} - void printExpr() { fprintf(stderr, "variable = %s\n", Name.c_str()); } + void printExpr() override { fprintf(stderr, "variable = %s\n", Name.c_str()); } - string type() { return "Variable"; } + string type() override { return "Variable"; } string getVariable() { return Name; } - std::unique_ptr Clone() { return makePtr(Name); } + std::unique_ptr Clone() override { return makePtr(Name); } + + llvm::Value *codegen() override; }; /// BinaryExprAST - Expression class for a binary operator. @@ -89,14 +108,14 @@ class BinaryExprAST : public ExprAST public: BinaryExprAST(char Op, std::unique_ptr LHS, std::unique_ptr RHS) : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} - void printExpr() + void printExpr() override { fprintf(stderr, "op = %c\n", Op); LHS->printExpr(); RHS->printExpr(); } - string type() { return "Binary"; } + string type() override { return "Binary"; } char getOp() { return Op; } void setOp(char opNew) { Op = opNew; } @@ -106,12 +125,14 @@ public: void setLHS(std::unique_ptr &newLHS) { LHS = newLHS->Clone(); } void setRHS(std::unique_ptr &newRHS) { RHS = newRHS->Clone(); } - std::unique_ptr Clone() + std::unique_ptr Clone() override { auto LHSNew = LHS->Clone(); auto RHSNew = RHS->Clone(); return makePtr(Op, std::move(LHSNew), std::move(RHSNew)); } + + llvm::Value *codegen() override; }; /// CallExprAST - Expression class for function calls. @@ -125,7 +146,7 @@ public: CallExprAST(std::unique_ptr &func) : Callee(func->Callee), Args(std::move(func->Args)) {} - void printExpr() + void printExpr() override { fprintf(stderr, "call function name = %s\n", Callee.c_str()); fprintf(stderr, "call function args =\n"); @@ -135,12 +156,12 @@ public: } } - string type() { return "Call"; } + string type() override { return "Call"; } string getCallee() { return Callee; } vector> &getArgs() { return Args; } - std::unique_ptr Clone() + std::unique_ptr Clone() override { vector> ArgsNew; for (long unsigned int i = 0; i < Args.size(); ++i) @@ -151,6 +172,8 @@ public: return makePtr(Callee, std::move(ArgsNew)); } + + llvm::Value *codegen() override; }; /// PrototypeAST - This class represents the "prototype" for a function, @@ -163,7 +186,8 @@ class PrototypeAST public: PrototypeAST(const string &Name, vector Args) : Name(Name), Args(std::move(Args)) {} - + + llvm::Function *codegen(); const string &getName() const { return Name; } const vector &getArgs() const { return Args; } }; @@ -177,6 +201,8 @@ class FunctionAST public: FunctionAST(std::unique_ptr Proto, std::unique_ptr Body) : Proto(std::move(Proto)), Body(std::move(Body)) {} + llvm::Function *codegen(); + const string &getFuncName() const { return Proto->getName(); } const vector &getFuncArgs() const { return Proto->getArgs(); } std::unique_ptr &getFuncBody() { return Body; } @@ -234,6 +260,21 @@ vector> combination(const int num, const vector& indexs); size_t combination(size_t k, size_t n); +/// LogError* - These are little helper functions for error handling. +ast_ptr LogError(const char *Str); + +std::unique_ptr LogErrorP(const char *Str); + +// codegen +llvm::Value *LogErrorV(const char *Str); + +// codegen +extern std::unique_ptr TheContext; +extern std::unique_ptr TheModule; +extern std::unique_ptr> Builder; +extern std::map NamedValues; +void InitializeModule(); + // } // end anonymous namespace #endif \ No newline at end of file diff --git a/include/laxerAST.hpp b/include/laxerAST.hpp index bb3a2c0..161ce7b 100644 --- a/include/laxerAST.hpp +++ b/include/laxerAST.hpp @@ -49,9 +49,4 @@ void installOperators(); /// GetTokPrecedence - Get the precedence of the pending binary operator token. int GetTokPrecedence(); -/// LogError* - These are little helper functions for error handling. -ast_ptr LogError(const char *Str); - -std::unique_ptr LogErrorP(const char *Str); - #endif diff --git a/include/parserASTLY.hpp b/include/parserASTLY.hpp index 530f578..b840038 100644 --- a/include/parserASTLY.hpp +++ b/include/parserASTLY.hpp @@ -46,4 +46,10 @@ string readFileIntoString(const char * filename); ast_ptr ParseExpressionFromString(); ast_ptr ParseExpressionFromString(string str); +std::unique_ptr ParseTopLevelExprForStr(string str); + +void HandleTopLevelExpression(string str); + +void MainLoopForStr(string str); + #endif \ No newline at end of file diff --git a/src/basic.cpp b/src/basic.cpp index cd8ade9..3626816 100644 --- a/src/basic.cpp +++ b/src/basic.cpp @@ -736,4 +736,149 @@ size_t combination(size_t k, size_t n) return 0; } return factorial(n) / (factorial(k) * factorial(n - k)); +} + +/// LogError* - These are little helper functions for error handling. +ast_ptr LogError(const char *Str) +{ + fprintf(stderr, "Error: %s\n", Str); + return nullptr; +} +std::unique_ptr LogErrorP(const char *Str) +{ + LogError(Str); + return nullptr; +} + +// codegen +llvm::Value *LogErrorV(const char *Str) +{ + LogError(Str); + return nullptr; +} + +// codegen +std::unique_ptr TheContext; +std::unique_ptr TheModule; +std::unique_ptr> Builder; +std::map NamedValues; + +void InitializeModule() +{ + // Open a new context and module. + TheContext = std::make_unique(); + TheModule = std::make_unique("my cool jit", *TheContext); + + // Create a new builder for the module. + Builder = std::make_unique>(*TheContext); +} + +llvm::Value *ExprAST::codegen() { return 0; } + +llvm::Value *NumberExprAST::codegen() { return llvm::ConstantFP::get(*TheContext, llvm::APFloat(Val)); } + +llvm::Value *VariableExprAST::codegen() +{ // Look this variable up in the function. + llvm::Value *V = NamedValues[Name]; + if(!V) + return LogErrorV("Unknown variable name"); + return V; +} + +llvm::Value *BinaryExprAST::codegen() +{ + llvm::Value *L = LHS->codegen(); + llvm::Value *R = RHS->codegen(); + if(!L || !R) + return nullptr; + + switch(Op) + { + case '+': + return Builder->CreateFAdd(L, R, "addtmp"); + case '-': + return Builder->CreateFSub(L, R, "subtmp"); + case '*': + return Builder->CreateFMul(L, R, "multmp"); + case '<': + L = Builder->CreateFCmpULT(L, R, "cmptmp"); + // Convert bool 0/1 to double 0.0 or 1.0 + return Builder->CreateUIToFP(L, llvm::Type::getDoubleTy(*TheContext), "booltmp"); + default: + return LogErrorV("invalid binary operator"); + } +} + +llvm::Value *CallExprAST::codegen() +{ + // Look up the name in the global module table. + llvm::Function *CalleeF = TheModule->getFunction(Callee); + if(!CalleeF) + return LogErrorV("Unknown function referenced"); + + // If argument mismatch error. + if(CalleeF->arg_size() != Args.size()) + return LogErrorV("Incorrect # arguments passed"); + + std::vector ArgsV; + for(unsigned i = 0, e = Args.size(); i != e; ++i) + { + ArgsV.push_back(Args[i]->codegen()); + if(!ArgsV.back()) + return nullptr; + } + + return Builder->CreateCall(CalleeF, ArgsV, "calltmp"); +} + +llvm::Function *PrototypeAST::codegen() +{ + // Make the function type: double(double,double) etc. + std::vector Doubles(Args.size(), llvm::Type::getDoubleTy(*TheContext)); + llvm::FunctionType *FT = llvm::FunctionType::get(llvm::Type::getDoubleTy(*TheContext), Doubles, false); + + llvm::Function *F = llvm::Function::Create(FT, llvm::Function::ExternalLinkage, Name, TheModule.get()); + + // Set names for all arguments. + unsigned Idx = 0; + for(auto &Arg : F->args()) + Arg.setName(Args[Idx++]); + + return F; +} + +llvm::Function *FunctionAST::codegen() +{ + // First, check for an existing function from a previous 'extern' declaration. + llvm::Function *TheFunction = TheModule->getFunction(Proto->getName()); + + if(!TheFunction) + TheFunction = Proto->codegen(); + + if(!TheFunction) + return nullptr; + + // Create a new basic block to start insertion into. + llvm::BasicBlock *BB = llvm::BasicBlock::Create(*TheContext, "entry", TheFunction); + Builder->SetInsertPoint(BB); + + // Record the function arguments in the NamedValues map. + NamedValues.clear(); + for(auto &Arg : TheFunction->args()) + NamedValues[std::string(Arg.getName())] = &Arg; + + if(llvm::Value *RetVal = Body->codegen()) + { + // Finish off the function. + Builder->CreateRet(RetVal); + + // Validate the generated code, checking for consistency. + verifyFunction(*TheFunction); + + return TheFunction; + } + + // Error reading body, remove function. + TheFunction->eraseFromParent(); + return nullptr; } \ No newline at end of file diff --git a/src/laxerAST.cpp b/src/laxerAST.cpp index 0af6718..18e452f 100644 --- a/src/laxerAST.cpp +++ b/src/laxerAST.cpp @@ -138,15 +138,3 @@ int GetTokPrecedence() return -1; return TokPrec; } - -/// LogError* - These are little helper functions for error handling. -ast_ptr LogError(const char *Str) -{ - fprintf(stderr, "Error: %s\n", Str); - return nullptr; -} -std::unique_ptr LogErrorP(const char *Str) -{ - LogError(Str); - return nullptr; -} diff --git a/src/main.cpp b/src/main.cpp index 1008feb..9e61b04 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -30,6 +30,9 @@ int main() { installOperatorsForStr(); initPython(); + // codegen + // Make the module, which holds all the code. + InitializeModule(); // tmp files for input and output. ifstream infile; @@ -69,11 +72,11 @@ int main() fprintf(stderr, GREEN "ready> " RESET); continue; } - else if (inputStr.back() == ';') - { - fprintf(stderr, "you do not need to add a ';' after the expression\n"); - inputStr.pop_back(); // remove the last char ';' - } + // else if (inputStr.back() == ';') + // { + // fprintf(stderr, "you do not need to add a ';' after the expression\n"); + // inputStr.pop_back(); // remove the last char ';' + // } auto benchMarkData = initalBenchMark(); auto pos = benchMarkData.find(inputStr); @@ -90,7 +93,15 @@ int main() const char *split = " "; // Get the information about the input expr + // codegen test + cout << "input:" << inputStr << endl; + MainLoopForStr(inputStr); + fprintf(stderr, GREEN "ready> " RESET); + continue; + // codegen test end + auto originExpr = ParseExpressionFromString(inputStr); + printExpr(originExpr, "codegen test: "); vector vars; getVariablesFromExpr(originExpr, vars); // cout << vars.size() << endl; diff --git a/src/mathfuncRewrite.cpp b/src/mathfuncRewrite.cpp index b0b84f4..e7c7554 100644 --- a/src/mathfuncRewrite.cpp +++ b/src/mathfuncRewrite.cpp @@ -98,7 +98,7 @@ ast_ptr logTolog1p(const ast_ptr &expr) double number = numberExpr->getNumber(); if(number == 1) { - argsNew.push_back(std::move(rhs->Clone())); + argsNew.push_back(rhs->Clone()); return makePtr("log1p", std::move(argsNew)); } } @@ -108,7 +108,7 @@ ast_ptr logTolog1p(const ast_ptr &expr) double number = numberExpr->getNumber(); if(number == 1) { - argsNew.push_back(std::move(lhs->Clone())); + argsNew.push_back(lhs->Clone()); return makePtr("log1p", std::move(argsNew)); } } @@ -161,8 +161,8 @@ ast_ptr sqrtTohypot(const ast_ptr &expr) if(isEqual(lhsL, rhsL) && isEqual(lhsR, rhsR)) { vector argsNew; // store the parameters for hypot - argsNew.push_back(std::move(lhsL->Clone())); - argsNew.push_back(std::move(lhsR->Clone())); + argsNew.push_back(lhsL->Clone()); + argsNew.push_back(lhsR->Clone()); return makePtr("hypot", std::move(argsNew)); } @@ -666,9 +666,9 @@ vector toFma(const ast_ptr &expr) if(!isSpecialNumber(lhsL) && !isSpecialNumber(rhsL)) // if(!isEqual(lhsL, tmpNegOne) && !isEqual(rhsL, tmpNegOne)) { - argsNew.push_back(std::move(lhsL->Clone())); - argsNew.push_back(std::move(rhsL->Clone())); - argsNew.push_back(std::move(rhs->Clone())); + argsNew.push_back(lhsL->Clone()); + argsNew.push_back(rhsL->Clone()); + argsNew.push_back(rhs->Clone()); string calleeNew = "fma"; ast_ptr exprFinal = makePtr(calleeNew, std::move(argsNew)); results.push_back(std::move(exprFinal)); @@ -692,7 +692,7 @@ vector toFma(const ast_ptr &expr) // the key to control the rewrite results if(results.size() == 0) { - results.push_back(std::move(expr->Clone())); + results.push_back(expr->Clone()); } // printExprs(results, prompt + "results: "); // cout << prompt << "end--------" << endl; @@ -742,7 +742,7 @@ vector fmaRewrite(const ast_ptr &expr) if (expr->type() != "Binary") // May be call, variable or number { vector exprsFinal; - exprsFinal.push_back(std::move(expr->Clone())); + exprsFinal.push_back(expr->Clone()); if (callCount == 1) printExprs(exprsFinal, prompt + "exprsFinal: "); if (callCount == 1) cout << prompt << "end--------" << endl; callCount--; diff --git a/src/parserASTLY.cpp b/src/parserASTLY.cpp index b72572d..e69c53c 100644 --- a/src/parserASTLY.cpp +++ b/src/parserASTLY.cpp @@ -201,7 +201,165 @@ ast_ptr ParseExpressionFromString(string str) if (auto E = ParseExpressionForStr()){ ast_ptr es = E->Clone(); + // if(auto* EIR = E->codegen()) + // { + // fprintf(stderr, "Read top-level expression:"); + // EIR->print(llvm::errs()); + // fprintf(stderr, "\n"); + + // // Remove the anonymous expression. + // // EIR->eraseFromParent(); + // } return es; } return nullptr; +} + +/// prototype +/// ::= id '(' id* ')' +std::unique_ptr ParsePrototypeForStr() +{ + if(CurTokForStr != tok_identifier_forstr) + return LogErrorP("Expected function name in prototype"); + + string FnName = IdentifierStr1; + getNextTokenForStr(); + + if(CurTokForStr != '(') + return LogErrorP("Expected '(' in prototype"); + + vector ArgNames; + while(getNextTokenForStr() == tok_identifier_forstr) + ArgNames.push_back(IdentifierStr1); + if(CurTokForStr != ')') + return LogErrorP("Expected ')' in prototype"); + + // success. + getNextTokenForStr(); // eat ')'. + + return makePtr(FnName, std::move(ArgNames)); +} + +/// definition ::= 'def' prototype expression +std::unique_ptr ParseDefinitionForStr() +{ + getNextTokenForStr(); // eat def. + auto Proto = ParsePrototypeForStr(); + if(!Proto) + return nullptr; + + if(auto E = ParseExpressionForStr()) + return makePtr(std::move(Proto), std::move(E)); + return nullptr; +} + +/// toplevelexpr ::= expression +std::unique_ptr ParseTopLevelExprForStr() +{ + if(auto E = ParseExpressionForStr()) + { + // for test purposes + printExpr(E, "codegen test: "); + // Make an anonymous proto. + auto Proto = makePtr("__anon_expr", vector()); + return makePtr(std::move(Proto), std::move(E)); + } + return nullptr; +} + +/// external ::= 'extern' prototype +std::unique_ptr ParseExternForStr() +{ + getNextTokenForStr(); // eat extern. + return ParsePrototypeForStr(); +} + +void HandleDefinitionForStr() +{ + if(auto FnAST = ParseDefinitionForStr()) + { + if(auto* FnIR = FnAST->codegen()) + { + fprintf(stderr, "Read function definition:"); + FnIR->print(llvm::errs()); + fprintf(stderr, "\n"); + } + } + else + { + // Skip token for error recovery. + getNextTokenForStr(); + } +} + +void HandleExternForStr() +{ + if(auto ProtoAST = ParseExternForStr()) + { + if(auto* FnIR = ProtoAST->codegen()) + { + fprintf(stderr, "Read extern: "); + FnIR->print(llvm::errs()); + fprintf(stderr, "\n"); + } + } + else + { + // Skip token for error recovery. + getNextTokenForStr(); + } +} + +void HandleTopLevelExpressionForStr() +{ + // Evaluate a top-level expression into an anonymous function. + if(auto FnAST = ParseTopLevelExprForStr()) + { + if(auto* FnIR = FnAST->codegen()) + { + fprintf(stderr, "Read top-level expression:"); + FnIR->print(llvm::errs()); + fprintf(stderr, "\n"); + + // Remove the anonymous expression. + FnIR->eraseFromParent(); + + } + } +} + +/// top ::= definition | external | expression | ';' +void MainLoopForStr(string str) +{ + if(CurTokForStr == ';') + { + filestring.clear(); + filestring.shrink_to_fit(); + flag = 0; + } + + filestring = str; + getNextTokenForStr(); + std::cout << "str: " << str << std::endl; + std::cout << "filestring: " << filestring << std::endl; + while(true) // TODO: check + { + switch(CurTokForStr) + { + case tok_eof_forstr: + return; + case ';': // ignore top-level semicolons. + getNextTokenForStr(); + return; + case tok_def_forstr: + HandleDefinitionForStr(); + break; + case tok_extern_forstr: + HandleExternForStr(); + break; + default: + HandleTopLevelExpressionForStr(); + break; + } + } } \ No newline at end of file -- Gitee From 3fc39120fa1c569c35a7884c5b6df0050ca14fae Mon Sep 17 00:00:00 2001 From: hjw Date: Sat, 10 Dec 2022 21:55:32 +0800 Subject: [PATCH 2/2] Adding JIT and Optimizer Support --- Makefile | 2 +- include/KaleidoscopeJIT.h | 108 ++++++++++++++++++++++++++++++++++++++ include/basic.hpp | 22 ++------ include/codegenLLVM.hpp | 37 +++++++++++++ src/basic.cpp | 63 +++++++++++++++++++--- src/main.cpp | 14 ++++- src/parserASTLY.cpp | 35 +++++++++--- 7 files changed, 247 insertions(+), 34 deletions(-) create mode 100644 include/KaleidoscopeJIT.h create mode 100644 include/codegenLLVM.hpp diff --git a/Makefile b/Makefile index f7e1b40..5eda710 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ CC = gcc CPP = clang++ INCLUDE = -Iinclude -I/usr/include/python3.8 `llvm-config --cxxflags | sed -r 's@(-I[a-zA-Z0-9/]+).*@\1@g'` # LIBS= -L/usr/lib/python3.8/config-3.8-x86_64-linux-gnu -lpython3.8 # may need -L to assign the Python lobrary path -LIBS = `llvm-config --ldflags --system-libs --libs core` -lpython3.8 -lfmt +LIBS = `llvm-config --ldflags --system-libs --libs core orcjit native` -lpython3.8 -lfmt -rdynamic ECHO = printf # $(info $(CFLAGS) ) diff --git a/include/KaleidoscopeJIT.h b/include/KaleidoscopeJIT.h new file mode 100644 index 0000000..a9592cb --- /dev/null +++ b/include/KaleidoscopeJIT.h @@ -0,0 +1,108 @@ +//===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains a simple JIT definition for use in the kaleidoscope tutorials. +// +//===----------------------------------------------------------------------===// + +// git directory https://github.com/llvm/llvm-project/blob/main/llvm/examples/Kaleidoscope/include/KaleidoscopeJIT.h +// git commit id da83b70a6fe602f047c8007009cfd646a2166cce +// git commit time Aug 19, 2021 + +#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H +#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include + +namespace llvm { +namespace orc { + +class KaleidoscopeJIT { +private: + std::unique_ptr ES; + + DataLayout DL; + MangleAndInterner Mangle; + + RTDyldObjectLinkingLayer ObjectLayer; + IRCompileLayer CompileLayer; + + JITDylib &MainJD; + +public: + KaleidoscopeJIT(std::unique_ptr ES, + JITTargetMachineBuilder JTMB, DataLayout DL) + : ES(std::move(ES)), DL(std::move(DL)), Mangle(*this->ES, this->DL), + ObjectLayer(*this->ES, + []() { return std::make_unique(); }), + CompileLayer(*this->ES, ObjectLayer, + std::make_unique(std::move(JTMB))), + MainJD(this->ES->createBareJITDylib("
")) { + MainJD.addGenerator( + cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( + DL.getGlobalPrefix()))); + if (JTMB.getTargetTriple().isOSBinFormatCOFF()) { + ObjectLayer.setOverrideObjectFlagsWithResponsibilityFlags(true); + ObjectLayer.setAutoClaimResponsibilityForObjectSymbols(true); + } + } + + ~KaleidoscopeJIT() { + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + + static Expected> Create() { + auto EPC = SelfExecutorProcessControl::Create(); + if (!EPC) + return EPC.takeError(); + + auto ES = std::make_unique(std::move(*EPC)); + + JITTargetMachineBuilder JTMB( + ES->getExecutorProcessControl().getTargetTriple()); + + auto DL = JTMB.getDefaultDataLayoutForTarget(); + if (!DL) + return DL.takeError(); + + return std::make_unique(std::move(ES), std::move(JTMB), + std::move(*DL)); + } + + const DataLayout &getDataLayout() const { return DL; } + + JITDylib &getMainJITDylib() { return MainJD; } + + Error addModule(ThreadSafeModule TSM, ResourceTrackerSP RT = nullptr) { + if (!RT) + RT = MainJD.getDefaultResourceTracker(); + return CompileLayer.add(RT, std::move(TSM)); + } + + Expected lookup(StringRef Name) { + return ES->lookup({&MainJD}, Mangle(Name.str())); + } +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H \ No newline at end of file diff --git a/include/basic.hpp b/include/basic.hpp index c2fade1..3187b2a 100644 --- a/include/basic.hpp +++ b/include/basic.hpp @@ -13,18 +13,9 @@ #include #include -// for codegen -#include "llvm/ADT/APFloat.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" +// for llvm codegen +#include "llvm/IR/Value.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Verifier.h" using std::string; using std::vector; @@ -265,16 +256,9 @@ ast_ptr LogError(const char *Str); std::unique_ptr LogErrorP(const char *Str); -// codegen +// LogErrorV for codegen llvm::Value *LogErrorV(const char *Str); -// codegen -extern std::unique_ptr TheContext; -extern std::unique_ptr TheModule; -extern std::unique_ptr> Builder; -extern std::map NamedValues; -void InitializeModule(); - // } // end anonymous namespace #endif \ No newline at end of file diff --git a/include/codegenLLVM.hpp b/include/codegenLLVM.hpp new file mode 100644 index 0000000..0e71883 --- /dev/null +++ b/include/codegenLLVM.hpp @@ -0,0 +1,37 @@ +#ifndef _CODEGENLLVM +#define _CODEGENLLVM + +class PrototypeAST; + +#include "KaleidoscopeJIT.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" + +// codegen +extern std::unique_ptr TheContext; +extern std::unique_ptr TheModule; +extern std::unique_ptr> Builder; +extern std::map NamedValues; +extern std::unique_ptr TheFPM; +extern std::unique_ptr TheJIT; +extern std::map> FunctionProtos; +extern llvm::ExitOnError ExitOnErr; +void InitializeModule(); +void InitializeModuleAndPassManager(void); + +#endif \ No newline at end of file diff --git a/src/basic.cpp b/src/basic.cpp index 3626816..d46229b 100644 --- a/src/basic.cpp +++ b/src/basic.cpp @@ -7,6 +7,8 @@ #include #include +#include "codegenLLVM.hpp" + using std::cerr; using std::cout; using std::endl; @@ -762,6 +764,8 @@ std::unique_ptr TheContext; std::unique_ptr TheModule; std::unique_ptr> Builder; std::map NamedValues; +std::unique_ptr TheFPM; +std::map> FunctionProtos; void InitializeModule() { @@ -773,6 +777,46 @@ void InitializeModule() Builder = std::make_unique>(*TheContext); } +void InitializeModuleAndPassManager(void) { + // Open a new module. + TheContext = std::make_unique(); + TheModule = std::make_unique("my cool jit", *TheContext); + TheModule->setDataLayout(TheJIT->getDataLayout()); + + // Create a new builder for the module. + Builder = std::make_unique>(*TheContext); + + // Create a new pass manager attached to it. + TheFPM = std::make_unique(TheModule.get()); + + // Do simple "peephole" optimizations and bit-twiddling optzns. + TheFPM->add(llvm::createInstructionCombiningPass()); + // Reassociate expressions. + TheFPM->add(llvm::createReassociatePass()); + // Eliminate Common SubExpressions. + TheFPM->add(llvm::createGVNPass()); + // Simplify the control flow graph (deleting unreachable blocks, etc). + TheFPM->add(llvm::createCFGSimplificationPass()); + + TheFPM->doInitialization(); +} + +llvm::Function *getFunction(std::string Name) +{ + // First, see if the function has already been added to the current module. + if(auto *F = TheModule->getFunction(Name)) + return F; + + // If not, check whether we can codegen the declaration from some existing + // prototype. + auto FI = FunctionProtos.find(Name); + if(FI != FunctionProtos.end()) + return FI->second->codegen(); + + // If no existing prototype exists, return null. + return nullptr; +} + llvm::Value *ExprAST::codegen() { return 0; } llvm::Value *NumberExprAST::codegen() { return llvm::ConstantFP::get(*TheContext, llvm::APFloat(Val)); } @@ -812,7 +856,7 @@ llvm::Value *BinaryExprAST::codegen() llvm::Value *CallExprAST::codegen() { // Look up the name in the global module table. - llvm::Function *CalleeF = TheModule->getFunction(Callee); + llvm::Function *CalleeF = getFunction(Callee); if(!CalleeF) return LogErrorV("Unknown function referenced"); @@ -849,14 +893,16 @@ llvm::Function *PrototypeAST::codegen() llvm::Function *FunctionAST::codegen() { - // First, check for an existing function from a previous 'extern' declaration. - llvm::Function *TheFunction = TheModule->getFunction(Proto->getName()); - - if(!TheFunction) - TheFunction = Proto->codegen(); - + // Transfer ownership of the prototype to the FunctionProtos map, but keep a + // reference to it for use below. + auto &P = *Proto; + FunctionProtos[Proto->getName()] = std::move(Proto); + llvm::Function *TheFunction = getFunction(P.getName()); if(!TheFunction) return nullptr; + // if(!TheFunction) + // TheFunction = Proto->codegen(); + // Create a new basic block to start insertion into. llvm::BasicBlock *BB = llvm::BasicBlock::Create(*TheContext, "entry", TheFunction); @@ -875,6 +921,9 @@ llvm::Function *FunctionAST::codegen() // Validate the generated code, checking for consistency. verifyFunction(*TheFunction); + // Optimize the function. + TheFPM->run(*TheFunction); + return TheFunction; } diff --git a/src/main.cpp b/src/main.cpp index 9e61b04..07de0e5 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,6 +7,7 @@ #include "geneCode.hpp" #include "tools.hpp" #include "benchMark.hpp" +#include "codegenLLVM.hpp" #include #include @@ -26,13 +27,24 @@ using std::vector; // Main driver code. //===----------------------------------------------------------------------===// +// codegen +std::unique_ptr TheJIT; +llvm::ExitOnError ExitOnErr; int main() { installOperatorsForStr(); initPython(); // codegen + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + // TheJIT = std::make_unique(); + TheJIT = ExitOnErr(llvm::orc::KaleidoscopeJIT::Create()); + // Make the module, which holds all the code. - InitializeModule(); + // InitializeModule(); + InitializeModuleAndPassManager(); // tmp files for input and output. ifstream infile; diff --git a/src/parserASTLY.cpp b/src/parserASTLY.cpp index e69c53c..719683d 100644 --- a/src/parserASTLY.cpp +++ b/src/parserASTLY.cpp @@ -2,6 +2,7 @@ #include "laxerASTLY.hpp" #include "parserASTLY.hpp" #include "string.h" +#include "codegenLLVM.hpp" #include using std::string; @@ -259,7 +260,7 @@ std::unique_ptr ParseTopLevelExprForStr() if(auto E = ParseExpressionForStr()) { // for test purposes - printExpr(E, "codegen test: "); + // printExpr(E, "codegen test: "); // Make an anonymous proto. auto Proto = makePtr("__anon_expr", vector()); return makePtr(std::move(Proto), std::move(E)); @@ -283,6 +284,8 @@ void HandleDefinitionForStr() fprintf(stderr, "Read function definition:"); FnIR->print(llvm::errs()); fprintf(stderr, "\n"); + ExitOnErr(TheJIT->addModule(llvm::orc::ThreadSafeModule(std::move(TheModule), std::move(TheContext)))); + InitializeModuleAndPassManager(); } } else @@ -301,6 +304,7 @@ void HandleExternForStr() fprintf(stderr, "Read extern: "); FnIR->print(llvm::errs()); fprintf(stderr, "\n"); + FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); } } else @@ -317,12 +321,31 @@ void HandleTopLevelExpressionForStr() { if(auto* FnIR = FnAST->codegen()) { - fprintf(stderr, "Read top-level expression:"); - FnIR->print(llvm::errs()); - fprintf(stderr, "\n"); + // JIT the module containing the anonymous expression, keeping a handle so + // we can free it later. + auto RT = TheJIT->getMainJITDylib().createResourceTracker(); + + auto TSM = llvm::orc::ThreadSafeModule(std::move(TheModule), std::move(TheContext)); + ExitOnErr(TheJIT->addModule(std::move(TSM), RT)); + InitializeModuleAndPassManager(); + + // Search the JIT for the __anon_expr symbol. + auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr")); + + // Get the symbol's address and cast it to the right type (takes no + // arguments, returns a double) so we can call it as a native function. + double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress(); + fprintf(stderr, "Evaluated to %f\n", FP()); + + // Delete the anonymous expression module from the JIT. + ExitOnErr(RT->remove()); + + // fprintf(stderr, "Read top-level expression:"); + // FnIR->print(llvm::errs()); + // fprintf(stderr, "\n"); - // Remove the anonymous expression. - FnIR->eraseFromParent(); + // // Remove the anonymous expression. + // FnIR->eraseFromParent(); } } -- Gitee