diff --git a/Makefile b/Makefile index 0a535438f5d6a48d4527faa530325b7946a91d54..5eda71036844fb3a79f158b17eeec85ecf660774 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,14 @@ export PROJECT_NAME = exprAuto CC = gcc -CPP = g++ -INCLUDE = -Iinclude -I/usr/include/python3.8 +CPP = clang++ +INCLUDE = -Iinclude -I/usr/include/python3.8 `llvm-config --cxxflags | sed -r 's@(-I[a-zA-Z0-9/]+).*@\1@g'` # LIBS= -L/usr/lib/python3.8/config-3.8-x86_64-linux-gnu -lpython3.8 # may need -L to assign the Python lobrary path -LIBS = -lpython3.8 -lfmt +LIBS = `llvm-config --ldflags --system-libs --libs core orcjit native` -lpython3.8 -lfmt -rdynamic ECHO = printf # $(info $(CFLAGS) ) -override CFLAGS += -g -Wall -Wextra -Wpedantic -Wno-unused-function -fdiagnostics-color=always +override CFLAGS += -g -Wall -Wextra -Wpedantic -Wno-unused-function -Wno-unused-parameter -fdiagnostics-color=always `llvm-config --cxxflags | sed -r 's@(-I[a-zA-Z0-9/]+) *@@g' | sed "s@-std=c++[0-9]* *@@g" | sed "s@-fno-rtti *@@g"` # $(info $(CFLAGS) ) EXPRAUTO_ALL_SRCS_CPP = $(wildcard src/*.cpp) diff --git a/include/KaleidoscopeJIT.h b/include/KaleidoscopeJIT.h new file mode 100644 index 0000000000000000000000000000000000000000..a9592cb3f7b309420a73069a46dbc8a612c6e0f2 --- /dev/null +++ b/include/KaleidoscopeJIT.h @@ -0,0 +1,108 @@ +//===- KaleidoscopeJIT.h - A simple JIT for Kaleidoscope --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Contains a simple JIT definition for use in the kaleidoscope tutorials. +// +//===----------------------------------------------------------------------===// + +// git directory https://github.com/llvm/llvm-project/blob/main/llvm/examples/Kaleidoscope/include/KaleidoscopeJIT.h +// git commit id da83b70a6fe602f047c8007009cfd646a2166cce +// git commit time Aug 19, 2021 + +#ifndef LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H +#define LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H + +#include "llvm/ADT/StringRef.h" +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/CompileUtils.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" +#include "llvm/ExecutionEngine/Orc/ExecutorProcessControl.h" +#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" +#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" +#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/SectionMemoryManager.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include + +namespace llvm { +namespace orc { + +class KaleidoscopeJIT { +private: + std::unique_ptr ES; + + DataLayout DL; + MangleAndInterner Mangle; + + RTDyldObjectLinkingLayer ObjectLayer; + IRCompileLayer CompileLayer; + + JITDylib &MainJD; + +public: + KaleidoscopeJIT(std::unique_ptr ES, + JITTargetMachineBuilder JTMB, DataLayout DL) + : ES(std::move(ES)), DL(std::move(DL)), Mangle(*this->ES, this->DL), + ObjectLayer(*this->ES, + []() { return std::make_unique(); }), + CompileLayer(*this->ES, ObjectLayer, + std::make_unique(std::move(JTMB))), + MainJD(this->ES->createBareJITDylib("
")) { + MainJD.addGenerator( + cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( + DL.getGlobalPrefix()))); + if (JTMB.getTargetTriple().isOSBinFormatCOFF()) { + ObjectLayer.setOverrideObjectFlagsWithResponsibilityFlags(true); + ObjectLayer.setAutoClaimResponsibilityForObjectSymbols(true); + } + } + + ~KaleidoscopeJIT() { + if (auto Err = ES->endSession()) + ES->reportError(std::move(Err)); + } + + static Expected> Create() { + auto EPC = SelfExecutorProcessControl::Create(); + if (!EPC) + return EPC.takeError(); + + auto ES = std::make_unique(std::move(*EPC)); + + JITTargetMachineBuilder JTMB( + ES->getExecutorProcessControl().getTargetTriple()); + + auto DL = JTMB.getDefaultDataLayoutForTarget(); + if (!DL) + return DL.takeError(); + + return std::make_unique(std::move(ES), std::move(JTMB), + std::move(*DL)); + } + + const DataLayout &getDataLayout() const { return DL; } + + JITDylib &getMainJITDylib() { return MainJD; } + + Error addModule(ThreadSafeModule TSM, ResourceTrackerSP RT = nullptr) { + if (!RT) + RT = MainJD.getDefaultResourceTracker(); + return CompileLayer.add(RT, std::move(TSM)); + } + + Expected lookup(StringRef Name) { + return ES->lookup({&MainJD}, Mangle(Name.str())); + } +}; + +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_KALEIDOSCOPEJIT_H \ No newline at end of file diff --git a/include/basic.hpp b/include/basic.hpp index fbe60ea0812f2409ce4ba416fa7fdf101833b19c..3187b2ae66b0c9956cb1f738e88631a1e5b5afe3 100644 --- a/include/basic.hpp +++ b/include/basic.hpp @@ -13,6 +13,10 @@ #include #include +// for llvm codegen +#include "llvm/IR/Value.h" +#include "llvm/IR/Function.h" + using std::string; using std::vector; using std::ofstream; @@ -44,6 +48,8 @@ public: virtual string type() { return "Base"; } virtual std::unique_ptr Clone() { return makePtr(); } + + virtual llvm::Value *codegen(); }; /// NumberExprAST - Expression class for numeric literals like "1.0". @@ -54,13 +60,15 @@ class NumberExprAST : public ExprAST public: NumberExprAST(double Val) : Val(Val) {} - void printExpr() { fprintf(stderr, "num = %f\n", Val); } + void printExpr() override { fprintf(stderr, "num = %f\n", Val); } - string type() { return "Number"; } + string type() override { return "Number"; } double getNumber() { return Val; } - std::unique_ptr Clone() { return makePtr(Val); } + std::unique_ptr Clone() override { return makePtr(Val); } + + llvm::Value *codegen() override; }; /// VariableExprAST - Expression class for referencing a variable, like "a". @@ -71,13 +79,15 @@ class VariableExprAST : public ExprAST public: VariableExprAST(const string &Name) : Name(Name) {} - void printExpr() { fprintf(stderr, "variable = %s\n", Name.c_str()); } + void printExpr() override { fprintf(stderr, "variable = %s\n", Name.c_str()); } - string type() { return "Variable"; } + string type() override { return "Variable"; } string getVariable() { return Name; } - std::unique_ptr Clone() { return makePtr(Name); } + std::unique_ptr Clone() override { return makePtr(Name); } + + llvm::Value *codegen() override; }; /// BinaryExprAST - Expression class for a binary operator. @@ -89,14 +99,14 @@ class BinaryExprAST : public ExprAST public: BinaryExprAST(char Op, std::unique_ptr LHS, std::unique_ptr RHS) : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} - void printExpr() + void printExpr() override { fprintf(stderr, "op = %c\n", Op); LHS->printExpr(); RHS->printExpr(); } - string type() { return "Binary"; } + string type() override { return "Binary"; } char getOp() { return Op; } void setOp(char opNew) { Op = opNew; } @@ -106,12 +116,14 @@ public: void setLHS(std::unique_ptr &newLHS) { LHS = newLHS->Clone(); } void setRHS(std::unique_ptr &newRHS) { RHS = newRHS->Clone(); } - std::unique_ptr Clone() + std::unique_ptr Clone() override { auto LHSNew = LHS->Clone(); auto RHSNew = RHS->Clone(); return makePtr(Op, std::move(LHSNew), std::move(RHSNew)); } + + llvm::Value *codegen() override; }; /// CallExprAST - Expression class for function calls. @@ -125,7 +137,7 @@ public: CallExprAST(std::unique_ptr &func) : Callee(func->Callee), Args(std::move(func->Args)) {} - void printExpr() + void printExpr() override { fprintf(stderr, "call function name = %s\n", Callee.c_str()); fprintf(stderr, "call function args =\n"); @@ -135,12 +147,12 @@ public: } } - string type() { return "Call"; } + string type() override { return "Call"; } string getCallee() { return Callee; } vector> &getArgs() { return Args; } - std::unique_ptr Clone() + std::unique_ptr Clone() override { vector> ArgsNew; for (long unsigned int i = 0; i < Args.size(); ++i) @@ -151,6 +163,8 @@ public: return makePtr(Callee, std::move(ArgsNew)); } + + llvm::Value *codegen() override; }; /// PrototypeAST - This class represents the "prototype" for a function, @@ -163,7 +177,8 @@ class PrototypeAST public: PrototypeAST(const string &Name, vector Args) : Name(Name), Args(std::move(Args)) {} - + + llvm::Function *codegen(); const string &getName() const { return Name; } const vector &getArgs() const { return Args; } }; @@ -177,6 +192,8 @@ class FunctionAST public: FunctionAST(std::unique_ptr Proto, std::unique_ptr Body) : Proto(std::move(Proto)), Body(std::move(Body)) {} + llvm::Function *codegen(); + const string &getFuncName() const { return Proto->getName(); } const vector &getFuncArgs() const { return Proto->getArgs(); } std::unique_ptr &getFuncBody() { return Body; } @@ -234,6 +251,14 @@ vector> combination(const int num, const vector& indexs); size_t combination(size_t k, size_t n); +/// LogError* - These are little helper functions for error handling. +ast_ptr LogError(const char *Str); + +std::unique_ptr LogErrorP(const char *Str); + +// LogErrorV for codegen +llvm::Value *LogErrorV(const char *Str); + // } // end anonymous namespace #endif \ No newline at end of file diff --git a/include/codegenLLVM.hpp b/include/codegenLLVM.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0e7188333b75e91910c8c5aba8161481ac65bf7e --- /dev/null +++ b/include/codegenLLVM.hpp @@ -0,0 +1,37 @@ +#ifndef _CODEGENLLVM +#define _CODEGENLLVM + +class PrototypeAST; + +#include "KaleidoscopeJIT.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Verifier.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" + +// codegen +extern std::unique_ptr TheContext; +extern std::unique_ptr TheModule; +extern std::unique_ptr> Builder; +extern std::map NamedValues; +extern std::unique_ptr TheFPM; +extern std::unique_ptr TheJIT; +extern std::map> FunctionProtos; +extern llvm::ExitOnError ExitOnErr; +void InitializeModule(); +void InitializeModuleAndPassManager(void); + +#endif \ No newline at end of file diff --git a/include/laxerAST.hpp b/include/laxerAST.hpp index bb3a2c0fe5c127b1feb02b7c515ea4dd15fb2424..161ce7bb63db8bfebcf2275edc35ad7a34ce44a2 100644 --- a/include/laxerAST.hpp +++ b/include/laxerAST.hpp @@ -49,9 +49,4 @@ void installOperators(); /// GetTokPrecedence - Get the precedence of the pending binary operator token. int GetTokPrecedence(); -/// LogError* - These are little helper functions for error handling. -ast_ptr LogError(const char *Str); - -std::unique_ptr LogErrorP(const char *Str); - #endif diff --git a/include/parserASTLY.hpp b/include/parserASTLY.hpp index 530f5783f46061ec10876318b17131f0022937cb..b840038d907c4a771587686a03482412c1c3fba7 100644 --- a/include/parserASTLY.hpp +++ b/include/parserASTLY.hpp @@ -46,4 +46,10 @@ string readFileIntoString(const char * filename); ast_ptr ParseExpressionFromString(); ast_ptr ParseExpressionFromString(string str); +std::unique_ptr ParseTopLevelExprForStr(string str); + +void HandleTopLevelExpression(string str); + +void MainLoopForStr(string str); + #endif \ No newline at end of file diff --git a/src/basic.cpp b/src/basic.cpp index cd8ade9258cb8279639b3d0a4d6cd30c5b280d3e..d46229b4f094712f7ebf161d6492f95957db9c6a 100644 --- a/src/basic.cpp +++ b/src/basic.cpp @@ -7,6 +7,8 @@ #include #include +#include "codegenLLVM.hpp" + using std::cerr; using std::cout; using std::endl; @@ -736,4 +738,196 @@ size_t combination(size_t k, size_t n) return 0; } return factorial(n) / (factorial(k) * factorial(n - k)); +} + +/// LogError* - These are little helper functions for error handling. +ast_ptr LogError(const char *Str) +{ + fprintf(stderr, "Error: %s\n", Str); + return nullptr; +} +std::unique_ptr LogErrorP(const char *Str) +{ + LogError(Str); + return nullptr; +} + +// codegen +llvm::Value *LogErrorV(const char *Str) +{ + LogError(Str); + return nullptr; +} + +// codegen +std::unique_ptr TheContext; +std::unique_ptr TheModule; +std::unique_ptr> Builder; +std::map NamedValues; +std::unique_ptr TheFPM; +std::map> FunctionProtos; + +void InitializeModule() +{ + // Open a new context and module. + TheContext = std::make_unique(); + TheModule = std::make_unique("my cool jit", *TheContext); + + // Create a new builder for the module. + Builder = std::make_unique>(*TheContext); +} + +void InitializeModuleAndPassManager(void) { + // Open a new module. + TheContext = std::make_unique(); + TheModule = std::make_unique("my cool jit", *TheContext); + TheModule->setDataLayout(TheJIT->getDataLayout()); + + // Create a new builder for the module. + Builder = std::make_unique>(*TheContext); + + // Create a new pass manager attached to it. + TheFPM = std::make_unique(TheModule.get()); + + // Do simple "peephole" optimizations and bit-twiddling optzns. + TheFPM->add(llvm::createInstructionCombiningPass()); + // Reassociate expressions. + TheFPM->add(llvm::createReassociatePass()); + // Eliminate Common SubExpressions. + TheFPM->add(llvm::createGVNPass()); + // Simplify the control flow graph (deleting unreachable blocks, etc). + TheFPM->add(llvm::createCFGSimplificationPass()); + + TheFPM->doInitialization(); +} + +llvm::Function *getFunction(std::string Name) +{ + // First, see if the function has already been added to the current module. + if(auto *F = TheModule->getFunction(Name)) + return F; + + // If not, check whether we can codegen the declaration from some existing + // prototype. + auto FI = FunctionProtos.find(Name); + if(FI != FunctionProtos.end()) + return FI->second->codegen(); + + // If no existing prototype exists, return null. + return nullptr; +} + +llvm::Value *ExprAST::codegen() { return 0; } + +llvm::Value *NumberExprAST::codegen() { return llvm::ConstantFP::get(*TheContext, llvm::APFloat(Val)); } + +llvm::Value *VariableExprAST::codegen() +{ // Look this variable up in the function. + llvm::Value *V = NamedValues[Name]; + if(!V) + return LogErrorV("Unknown variable name"); + return V; +} + +llvm::Value *BinaryExprAST::codegen() +{ + llvm::Value *L = LHS->codegen(); + llvm::Value *R = RHS->codegen(); + if(!L || !R) + return nullptr; + + switch(Op) + { + case '+': + return Builder->CreateFAdd(L, R, "addtmp"); + case '-': + return Builder->CreateFSub(L, R, "subtmp"); + case '*': + return Builder->CreateFMul(L, R, "multmp"); + case '<': + L = Builder->CreateFCmpULT(L, R, "cmptmp"); + // Convert bool 0/1 to double 0.0 or 1.0 + return Builder->CreateUIToFP(L, llvm::Type::getDoubleTy(*TheContext), "booltmp"); + default: + return LogErrorV("invalid binary operator"); + } +} + +llvm::Value *CallExprAST::codegen() +{ + // Look up the name in the global module table. + llvm::Function *CalleeF = getFunction(Callee); + if(!CalleeF) + return LogErrorV("Unknown function referenced"); + + // If argument mismatch error. + if(CalleeF->arg_size() != Args.size()) + return LogErrorV("Incorrect # arguments passed"); + + std::vector ArgsV; + for(unsigned i = 0, e = Args.size(); i != e; ++i) + { + ArgsV.push_back(Args[i]->codegen()); + if(!ArgsV.back()) + return nullptr; + } + + return Builder->CreateCall(CalleeF, ArgsV, "calltmp"); +} + +llvm::Function *PrototypeAST::codegen() +{ + // Make the function type: double(double,double) etc. + std::vector Doubles(Args.size(), llvm::Type::getDoubleTy(*TheContext)); + llvm::FunctionType *FT = llvm::FunctionType::get(llvm::Type::getDoubleTy(*TheContext), Doubles, false); + + llvm::Function *F = llvm::Function::Create(FT, llvm::Function::ExternalLinkage, Name, TheModule.get()); + + // Set names for all arguments. + unsigned Idx = 0; + for(auto &Arg : F->args()) + Arg.setName(Args[Idx++]); + + return F; +} + +llvm::Function *FunctionAST::codegen() +{ + // Transfer ownership of the prototype to the FunctionProtos map, but keep a + // reference to it for use below. + auto &P = *Proto; + FunctionProtos[Proto->getName()] = std::move(Proto); + llvm::Function *TheFunction = getFunction(P.getName()); + if(!TheFunction) + return nullptr; + // if(!TheFunction) + // TheFunction = Proto->codegen(); + + + // Create a new basic block to start insertion into. + llvm::BasicBlock *BB = llvm::BasicBlock::Create(*TheContext, "entry", TheFunction); + Builder->SetInsertPoint(BB); + + // Record the function arguments in the NamedValues map. + NamedValues.clear(); + for(auto &Arg : TheFunction->args()) + NamedValues[std::string(Arg.getName())] = &Arg; + + if(llvm::Value *RetVal = Body->codegen()) + { + // Finish off the function. + Builder->CreateRet(RetVal); + + // Validate the generated code, checking for consistency. + verifyFunction(*TheFunction); + + // Optimize the function. + TheFPM->run(*TheFunction); + + return TheFunction; + } + + // Error reading body, remove function. + TheFunction->eraseFromParent(); + return nullptr; } \ No newline at end of file diff --git a/src/laxerAST.cpp b/src/laxerAST.cpp index 0af671852bc526cb8c2e74eadc3e5a0cb4c6a23d..18e452f77739f237f59520e724930cddbd613452 100644 --- a/src/laxerAST.cpp +++ b/src/laxerAST.cpp @@ -138,15 +138,3 @@ int GetTokPrecedence() return -1; return TokPrec; } - -/// LogError* - These are little helper functions for error handling. -ast_ptr LogError(const char *Str) -{ - fprintf(stderr, "Error: %s\n", Str); - return nullptr; -} -std::unique_ptr LogErrorP(const char *Str) -{ - LogError(Str); - return nullptr; -} diff --git a/src/main.cpp b/src/main.cpp index 1008febc1d9f66c355ee3028699ca1e5f716d49d..07de0e585faae2bc7d779685ad80d9fead29b2eb 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -7,6 +7,7 @@ #include "geneCode.hpp" #include "tools.hpp" #include "benchMark.hpp" +#include "codegenLLVM.hpp" #include #include @@ -26,10 +27,24 @@ using std::vector; // Main driver code. //===----------------------------------------------------------------------===// +// codegen +std::unique_ptr TheJIT; +llvm::ExitOnError ExitOnErr; int main() { installOperatorsForStr(); initPython(); + // codegen + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + llvm::InitializeNativeTargetAsmParser(); + + // TheJIT = std::make_unique(); + TheJIT = ExitOnErr(llvm::orc::KaleidoscopeJIT::Create()); + + // Make the module, which holds all the code. + // InitializeModule(); + InitializeModuleAndPassManager(); // tmp files for input and output. ifstream infile; @@ -69,11 +84,11 @@ int main() fprintf(stderr, GREEN "ready> " RESET); continue; } - else if (inputStr.back() == ';') - { - fprintf(stderr, "you do not need to add a ';' after the expression\n"); - inputStr.pop_back(); // remove the last char ';' - } + // else if (inputStr.back() == ';') + // { + // fprintf(stderr, "you do not need to add a ';' after the expression\n"); + // inputStr.pop_back(); // remove the last char ';' + // } auto benchMarkData = initalBenchMark(); auto pos = benchMarkData.find(inputStr); @@ -90,7 +105,15 @@ int main() const char *split = " "; // Get the information about the input expr + // codegen test + cout << "input:" << inputStr << endl; + MainLoopForStr(inputStr); + fprintf(stderr, GREEN "ready> " RESET); + continue; + // codegen test end + auto originExpr = ParseExpressionFromString(inputStr); + printExpr(originExpr, "codegen test: "); vector vars; getVariablesFromExpr(originExpr, vars); // cout << vars.size() << endl; diff --git a/src/mathfuncRewrite.cpp b/src/mathfuncRewrite.cpp index b0b84f427f10a532d7426ec3b52b5c79ba4a7bec..e7c75549459dbed1209d7ed0a111ee885e1f68ce 100644 --- a/src/mathfuncRewrite.cpp +++ b/src/mathfuncRewrite.cpp @@ -98,7 +98,7 @@ ast_ptr logTolog1p(const ast_ptr &expr) double number = numberExpr->getNumber(); if(number == 1) { - argsNew.push_back(std::move(rhs->Clone())); + argsNew.push_back(rhs->Clone()); return makePtr("log1p", std::move(argsNew)); } } @@ -108,7 +108,7 @@ ast_ptr logTolog1p(const ast_ptr &expr) double number = numberExpr->getNumber(); if(number == 1) { - argsNew.push_back(std::move(lhs->Clone())); + argsNew.push_back(lhs->Clone()); return makePtr("log1p", std::move(argsNew)); } } @@ -161,8 +161,8 @@ ast_ptr sqrtTohypot(const ast_ptr &expr) if(isEqual(lhsL, rhsL) && isEqual(lhsR, rhsR)) { vector argsNew; // store the parameters for hypot - argsNew.push_back(std::move(lhsL->Clone())); - argsNew.push_back(std::move(lhsR->Clone())); + argsNew.push_back(lhsL->Clone()); + argsNew.push_back(lhsR->Clone()); return makePtr("hypot", std::move(argsNew)); } @@ -666,9 +666,9 @@ vector toFma(const ast_ptr &expr) if(!isSpecialNumber(lhsL) && !isSpecialNumber(rhsL)) // if(!isEqual(lhsL, tmpNegOne) && !isEqual(rhsL, tmpNegOne)) { - argsNew.push_back(std::move(lhsL->Clone())); - argsNew.push_back(std::move(rhsL->Clone())); - argsNew.push_back(std::move(rhs->Clone())); + argsNew.push_back(lhsL->Clone()); + argsNew.push_back(rhsL->Clone()); + argsNew.push_back(rhs->Clone()); string calleeNew = "fma"; ast_ptr exprFinal = makePtr(calleeNew, std::move(argsNew)); results.push_back(std::move(exprFinal)); @@ -692,7 +692,7 @@ vector toFma(const ast_ptr &expr) // the key to control the rewrite results if(results.size() == 0) { - results.push_back(std::move(expr->Clone())); + results.push_back(expr->Clone()); } // printExprs(results, prompt + "results: "); // cout << prompt << "end--------" << endl; @@ -742,7 +742,7 @@ vector fmaRewrite(const ast_ptr &expr) if (expr->type() != "Binary") // May be call, variable or number { vector exprsFinal; - exprsFinal.push_back(std::move(expr->Clone())); + exprsFinal.push_back(expr->Clone()); if (callCount == 1) printExprs(exprsFinal, prompt + "exprsFinal: "); if (callCount == 1) cout << prompt << "end--------" << endl; callCount--; diff --git a/src/parserASTLY.cpp b/src/parserASTLY.cpp index b72572dbfe1016760646252aaba5579c9c1159bb..719683d6e9e42fb9b49e69f97712b5499fe2dd3b 100644 --- a/src/parserASTLY.cpp +++ b/src/parserASTLY.cpp @@ -2,6 +2,7 @@ #include "laxerASTLY.hpp" #include "parserASTLY.hpp" #include "string.h" +#include "codegenLLVM.hpp" #include using std::string; @@ -201,7 +202,187 @@ ast_ptr ParseExpressionFromString(string str) if (auto E = ParseExpressionForStr()){ ast_ptr es = E->Clone(); + // if(auto* EIR = E->codegen()) + // { + // fprintf(stderr, "Read top-level expression:"); + // EIR->print(llvm::errs()); + // fprintf(stderr, "\n"); + + // // Remove the anonymous expression. + // // EIR->eraseFromParent(); + // } return es; } return nullptr; +} + +/// prototype +/// ::= id '(' id* ')' +std::unique_ptr ParsePrototypeForStr() +{ + if(CurTokForStr != tok_identifier_forstr) + return LogErrorP("Expected function name in prototype"); + + string FnName = IdentifierStr1; + getNextTokenForStr(); + + if(CurTokForStr != '(') + return LogErrorP("Expected '(' in prototype"); + + vector ArgNames; + while(getNextTokenForStr() == tok_identifier_forstr) + ArgNames.push_back(IdentifierStr1); + if(CurTokForStr != ')') + return LogErrorP("Expected ')' in prototype"); + + // success. + getNextTokenForStr(); // eat ')'. + + return makePtr(FnName, std::move(ArgNames)); +} + +/// definition ::= 'def' prototype expression +std::unique_ptr ParseDefinitionForStr() +{ + getNextTokenForStr(); // eat def. + auto Proto = ParsePrototypeForStr(); + if(!Proto) + return nullptr; + + if(auto E = ParseExpressionForStr()) + return makePtr(std::move(Proto), std::move(E)); + return nullptr; +} + +/// toplevelexpr ::= expression +std::unique_ptr ParseTopLevelExprForStr() +{ + if(auto E = ParseExpressionForStr()) + { + // for test purposes + // printExpr(E, "codegen test: "); + // Make an anonymous proto. + auto Proto = makePtr("__anon_expr", vector()); + return makePtr(std::move(Proto), std::move(E)); + } + return nullptr; +} + +/// external ::= 'extern' prototype +std::unique_ptr ParseExternForStr() +{ + getNextTokenForStr(); // eat extern. + return ParsePrototypeForStr(); +} + +void HandleDefinitionForStr() +{ + if(auto FnAST = ParseDefinitionForStr()) + { + if(auto* FnIR = FnAST->codegen()) + { + fprintf(stderr, "Read function definition:"); + FnIR->print(llvm::errs()); + fprintf(stderr, "\n"); + ExitOnErr(TheJIT->addModule(llvm::orc::ThreadSafeModule(std::move(TheModule), std::move(TheContext)))); + InitializeModuleAndPassManager(); + } + } + else + { + // Skip token for error recovery. + getNextTokenForStr(); + } +} + +void HandleExternForStr() +{ + if(auto ProtoAST = ParseExternForStr()) + { + if(auto* FnIR = ProtoAST->codegen()) + { + fprintf(stderr, "Read extern: "); + FnIR->print(llvm::errs()); + fprintf(stderr, "\n"); + FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); + } + } + else + { + // Skip token for error recovery. + getNextTokenForStr(); + } +} + +void HandleTopLevelExpressionForStr() +{ + // Evaluate a top-level expression into an anonymous function. + if(auto FnAST = ParseTopLevelExprForStr()) + { + if(auto* FnIR = FnAST->codegen()) + { + // JIT the module containing the anonymous expression, keeping a handle so + // we can free it later. + auto RT = TheJIT->getMainJITDylib().createResourceTracker(); + + auto TSM = llvm::orc::ThreadSafeModule(std::move(TheModule), std::move(TheContext)); + ExitOnErr(TheJIT->addModule(std::move(TSM), RT)); + InitializeModuleAndPassManager(); + + // Search the JIT for the __anon_expr symbol. + auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr")); + + // Get the symbol's address and cast it to the right type (takes no + // arguments, returns a double) so we can call it as a native function. + double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress(); + fprintf(stderr, "Evaluated to %f\n", FP()); + + // Delete the anonymous expression module from the JIT. + ExitOnErr(RT->remove()); + + // fprintf(stderr, "Read top-level expression:"); + // FnIR->print(llvm::errs()); + // fprintf(stderr, "\n"); + + // // Remove the anonymous expression. + // FnIR->eraseFromParent(); + + } + } +} + +/// top ::= definition | external | expression | ';' +void MainLoopForStr(string str) +{ + if(CurTokForStr == ';') + { + filestring.clear(); + filestring.shrink_to_fit(); + flag = 0; + } + + filestring = str; + getNextTokenForStr(); + std::cout << "str: " << str << std::endl; + std::cout << "filestring: " << filestring << std::endl; + while(true) // TODO: check + { + switch(CurTokForStr) + { + case tok_eof_forstr: + return; + case ';': // ignore top-level semicolons. + getNextTokenForStr(); + return; + case tok_def_forstr: + HandleDefinitionForStr(); + break; + case tok_extern_forstr: + HandleExternForStr(); + break; + default: + HandleTopLevelExpressionForStr(); + break; + } + } } \ No newline at end of file