From 60cea142c75c02cbb3ff39eacf11e1503e2ac603 Mon Sep 17 00:00:00 2001 From: hjw Date: Tue, 14 Mar 2023 08:41:54 +0800 Subject: [PATCH] Improve exprAuto Details: Feature enhancement: optimize the pickTheBest function Feature enhancement: improve exprInfo class Feature enhancement: modify the output of the geneExprCodeKernel and geneExprCode functions to an expression string Feature enhancement: add timing function to matlab function Feature enhancement: add special case processing, if the initial error is not greater than 0.5, no subsequent detection will be performed Feature enhancement: add special case handling, if there is only one equivalent expression, and it is the same as the input expression, nothing is done Feature enhancement: add timing function to main program Feature enhancement: Support for mpfr code generation for fma Feature modification: modify the error threshold of the example Feature modification: remove some rewrite related print information Bug fix: Fixed Python parsing double-multiple symbol '**' failure Bug fix: The abnormal value of error detection is not written to the file to avoid subsequent file parsing failure --- include/exprAuto.hpp | 3 +- include/tools.hpp | 1 + src/basic.cpp | 3 ++ src/exprAuto.cpp | 54 ++++++++++++++++++-------------- src/geneCode.cpp | 9 +++--- src/getUpEdge.m | 3 +- src/main.cpp | 50 ++++++++++++++++++++--------- src/mathfuncRewrite.cpp | 12 +++---- src/simplify.py | 2 +- src/tools.cpp | 10 ++++++ srcTest/test1paramFPEDParallel.c | 8 +++-- 11 files changed, 100 insertions(+), 55 deletions(-) diff --git a/include/exprAuto.hpp b/include/exprAuto.hpp index 2a9208a..1a74a7c 100644 --- a/include/exprAuto.hpp +++ b/include/exprAuto.hpp @@ -6,6 +6,7 @@ #include #include "basic.hpp" #include "preprocess.hpp" +#include "tools.hpp" using std::string; using std::vector; @@ -44,7 +45,7 @@ void sortExpr(ast_ptr &expr); vector tryRewrite(ast_ptr expr, bool addSelf = true); -string pickTheBest(string uniqueLabel, vector testSet, vector intervals, vector scales); +exprInfo pickTheBest(string uniqueLabel, vector testSet, vector intervals, vector scales); void geneSampleData(); diff --git a/include/tools.hpp b/include/tools.hpp index 6175c2f..e71a315 100644 --- a/include/tools.hpp +++ b/include/tools.hpp @@ -13,6 +13,7 @@ public: double start; double end; vector intervals; + string suffix; string exprStr; double error; double aveError; diff --git a/src/basic.cpp b/src/basic.cpp index cd8ade9..f2d4774 100644 --- a/src/basic.cpp +++ b/src/basic.cpp @@ -636,6 +636,9 @@ string mpfrCodeGenerator(const ast_ptr &expr, size_t &mpfr_variables, const std: if (call_str == "mpfr_pow") { string str1 = argsStr.at(0), str2 = argsStr.at(1); callee_str = call_str + "(mp" + to_string(mpfr_variables) + ", " + str1 + ", " + str2 + ", MPFR_RNDN);"; + } else if (call_str == "mpfr_fma") { + string str1 = argsStr.at(0), str2 = argsStr.at(1), str3 = argsStr.at(2); + callee_str = call_str + "(mp" + to_string(mpfr_variables) + ", " + str1 + ", " + str2 + ", " + str3 + ", MPFR_RNDN);"; } else { string str1 = argsStr.at(0); callee_str = call_str + "(mp" + to_string(mpfr_variables) + ", " + str1 + ", MPFR_RNDN);"; diff --git a/src/exprAuto.cpp b/src/exprAuto.cpp index 3b874d9..b38e4cf 100644 --- a/src/exprAuto.cpp +++ b/src/exprAuto.cpp @@ -370,10 +370,10 @@ vector dealWithCalls(const ast_ptr &expr) static size_t callCount = 0; callCount++; callLevel++; - string prompt(callLevel * promtTimes, callLevelChar); - prompt.append(callCount, callCountChar); - prompt += "dealWithCalls: "; - cout << prompt << "start--------" << endl; + // string prompt(callLevel * promtTimes, callLevelChar); + // prompt.append(callCount, callCountChar); + // prompt += "dealWithCalls: "; + // cout << prompt << "start--------" << endl; CallExprAST *callExpr = dynamic_cast(expr.get()); string callee = callExpr->getCallee(); vector &args = callExpr->getArgs(); @@ -399,8 +399,8 @@ vector dealWithCalls(const ast_ptr &expr) } deleteTheSame(allResults); - printExprs(allResults, prompt + "at the end, "); - cout << prompt << "end--------" << endl; + // printExprs(allResults, prompt + "at the end, "); + // cout << prompt << "end--------" << endl; callCount--; callLevel--; return allResults; @@ -953,7 +953,7 @@ vector tryRewrite(ast_ptr expr, bool addSelf) } deleteTheSame(results); - if(callCount == 1) printExprs(results, prompt + "at the last: "); + // if(callCount == 1) printExprs(results, prompt + "at the last: "); // cout << prompt << "end--------" < testSet, vector intervals, vector scales) +exprInfo pickTheBest(string uniqueLabel, vector testSet, vector intervals, vector scales) { + exprInfo result; string bestExpr = "origin"; double maxError = INFINITY; double aveError = INFINITY; @@ -1030,7 +1031,12 @@ string pickTheBest(string uniqueLabel, vector testSet, vector in aveError = tempError.aveError; } } - return bestExpr; + cout << "pick \'" << bestExpr << "\' as the init expression.\n"; + result.intervals = intervals; + result.suffix = bestExpr; + result.exprStr = ""; + result.maxError = maxError; + return result; } size_t pickTheBest(vector &items, ast_ptr &originExpr) @@ -1285,38 +1291,38 @@ vector exprAutoNew(const ast_ptr &expr, bool addSelf) string prompt(callLevel * promtTimes, callLevelChar); prompt.append(callCount, callCountChar); prompt += "exprAutoNew: "; - cout << prompt << "start-----------" << endl; + if(callCount == 1) cout << prompt << "start-----------" << endl; if (expr == nullptr) { cerr << prompt << "ERROR: the input expr is nullptr!" << endl; exit(EXIT_FAILURE); } - cout << prompt << "at the beginning: expr = " << PrintExpression(expr) << endl; - cout << prompt << "step1: preprocess" << endl; + if(callCount == 1) cout << prompt << "at the beginning: expr = " << PrintExpression(expr) << endl; + if(callCount == 1) cout << prompt << "step1: preprocess" << endl; ast_ptr exprNew = preprocess(expr); // exprNew = simplifyExpr(exprNew); exprNew = minusRewrite(exprNew); combineConstant(exprNew); - cout << prompt << "after preprocess: exprNew = " << PrintExpression(exprNew) << endl; + if(callCount == 1) cout << prompt << "after preprocess: exprNew = " << PrintExpression(exprNew) << endl; vector results; - cout << prompt << "step2: judge if exprNew is a fraction" << endl; + if(callCount == 1) cout << prompt << "step2: judge if exprNew is a fraction" << endl; if (isFraction(exprNew)) { - cout << prompt << "exprNew is a fraction, so perform step3 and step4" << endl; + if(callCount == 1) cout << prompt << "exprNew is a fraction, so perform step3 and step4" << endl; ast_ptr numerator = getNumerator(exprNew); ast_ptr denominator = getDenominator(exprNew); - cout << prompt << "step3: perform on numerator." << endl; + if(callCount == 1) cout << prompt << "step3: perform on numerator." << endl; auto numerators = tryRewrite(std::move(numerator), addSelf); - cout << prompt << "step3: end perform on numerator." << endl; + if(callCount == 1) cout << prompt << "step3: end perform on numerator." << endl; - cout << prompt << "step3: perform on denominator." << endl; + if(callCount == 1) cout << prompt << "step3: perform on denominator." << endl; auto denominators = tryRewrite(std::move(denominator), addSelf); - cout << prompt << "step3: end perform on denominator." << endl; + if(callCount == 1) cout << prompt << "step3: end perform on denominator." << endl; - cout << prompt << "step4: combine numerator and denominator." << endl; + if(callCount == 1) cout << prompt << "step4: combine numerator and denominator." << endl; if(isConstant(denominators)) // only one element, and the element is constant { if(isConstant(numerators)) @@ -1347,14 +1353,14 @@ vector exprAutoNew(const ast_ptr &expr, bool addSelf) } else { - cout << prompt << "exprNew is not a fraction, so perform step4" << endl; + if(callCount == 1) cout << prompt << "exprNew is not a fraction, so perform step4" << endl; results = tryRewrite(std::move(exprNew), addSelf); // results = tryRewriteRandom(std::move(exprNew)); } - cout << prompt << "at the last: results size = " << results.size() << endl; - printExprs(results, prompt); - cout << prompt << "end-----------" << endl; + if(callCount == 1) cout << prompt << "at the last: results size = " << results.size() << endl; + if(callCount == 1) printExprs(results, prompt); + if(callCount == 1) cout << prompt << "end-----------" << endl; callCount--; callLevel--; return results; diff --git a/src/geneCode.cpp b/src/geneCode.cpp index c792f1f..98c2606 100644 --- a/src/geneCode.cpp +++ b/src/geneCode.cpp @@ -151,7 +151,7 @@ string geneExprCodeKernel(string exprStr, vector vars, string uniqueLabe fout << "}\n"; fout.close(); - return funcName; + return exprStr; } string geneExprCode(string exprStr, string uniqueLabel, string tail) @@ -160,9 +160,9 @@ string geneExprCode(string exprStr, string uniqueLabel, string tail) vector vars; getVariablesFromExpr(originExpr, vars); - auto funcName = geneExprCodeKernel(exprStr, vars, uniqueLabel, tail); + geneExprCodeKernel(exprStr, vars, uniqueLabel, tail); - return funcName; + return exprStr; } string geneTGenCode(string exprStr, vector vars, string uniqueLabel, string tail) @@ -243,7 +243,7 @@ string geneHerbieCode(string uniqueLabel) {"exp1x_log", "expm1(x) / x"}, {"intro_example", "(x / (1.0 + pow(x, 3.0))) * fma(x, x, (1.0 - x))"}, {"logexp", "log1p(exp(x))"}, - {"NMSEexample31", "(x + (1.0 - x)) / fma(pow(x, 0.25), pow(x, 0.25), sqrt((x + 1.0)))"}, + {"NMSEexample31", "(x + (1.0 - x)) / fma(pow(x, 0.25), pow(x, 0.25), sqrt(x + 1.0))"}, {"NMSEexample310", "log1p(-x) / log1p(x)"}, {"NMSEexample34", "tan((x / 2.0))"}, {"NMSEexample35", "atan2((x + (1.0- x)), (1.0+fma(sqrt(x), sqrt(x), (x * x))))"}, @@ -321,6 +321,7 @@ string geneMpfrCode(const ast_ptr &exprAst, const string uniqueLabel, vector vars; // getVariablesFromExpr(exprAst, vars); diff --git a/src/getUpEdge.m b/src/getUpEdge.m index c45ef64..31633ee 100644 --- a/src/getUpEdge.m +++ b/src/getUpEdge.m @@ -1,4 +1,5 @@ % disp(sampleFileName); +tic; sampleData = importdata(sampleFileName); lenTmp = length(sampleData); maxTmp = max(sampleData(:,2)); @@ -186,7 +187,7 @@ if maxTmp > 2 else meanTmp = mean(sampleData(:, 2)); end - +toc % compute the threshold for deviding intervals % thresholdTmp = (maxTmp + meanTmp) / 2; % fprintf("%s: max = %g, average = %g, threshold = %d\n", sampleFileName, maxTmp, meanTmp, thresholdTmp); diff --git a/src/main.cpp b/src/main.cpp index a8a622a..dacc164 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -28,10 +28,10 @@ map> benchmarkThresholds = { {"exp1x", {1}}, {"exp1x_log", {1}}, {"intro_example", {1}}, - {"logexp", {10}}, + {"logexp", {1}}, {"NMSEexample31", {2}}, {"NMSEexample310", {2}}, - {"NMSEexample34", {2.1}}, + {"NMSEexample34", {0.5}}, {"NMSEexample35", {1}}, {"NMSEexample36", {2}}, {"NMSEexample37", {0.5}}, @@ -212,30 +212,43 @@ int main() string mkdirCommand = "mkdir -p srcTest/" + uniqueLabel + " outputs/" + uniqueLabel; system(mkdirCommand.c_str()); - auto funcNameOrigin = geneExprCodeKernel(inputStr, vars, uniqueLabel, "origin"); + auto exprOrigin = geneExprCodeKernel(inputStr, vars, uniqueLabel, "origin"); // auto funcNameOrigin = geneExprCode(inputStr, uniqueLabel, "origin"); // auto funcNameHerbie = geneHerbieCode(inputStr, uniqueLabel, "herbie"); - auto funcNameHerbie = geneHerbieCode(uniqueLabel); - // auto funcNameDaisy = geneDaisyCode(inputStr, uniqueLabel, "daisy"); + auto exprHerbie = geneHerbieCode(uniqueLabel); + // auto exprDaisy = geneDaisyCode(inputStr, uniqueLabel, "daisy"); auto funcNameMpfr = geneMpfrCode(inputStr, uniqueLabel, vars); // TODO: improve pickTheBest to support more suffix // pickTheBest(uniqueLabel, 0, 1, 100); vector suffixSet = {"origin", "herbie"}; - auto suffix = pickTheBest(uniqueLabel, suffixSet, intervals, scales); + auto initExprInfo = pickTheBest(uniqueLabel, suffixSet, intervals, scales); + auto suffix = initExprInfo.suffix; if (suffix == "origin") { - inputStr = funcNameOrigin; + inputStr = exprOrigin; } else if (suffix == "herbie") { - inputStr = funcNameHerbie; + inputStr = exprHerbie; } else { fprintf(stderr, "ERROR: main: we do not support the suffix %s now\n", suffix.c_str()); exit(EXIT_FAILURE); } + cout << "the pick expr is " << inputStr << "\n"; + auto timeTmp1 = std::chrono::high_resolution_clock::now(); // init over + std::chrono::duration init_seconds = timeTmp1 - timeStart; + cout << BLUE << "init time: " << init_seconds.count() << " s" << RESET << endl; + auto &initExprMaxError = initExprInfo.maxError; + if (initExprMaxError <= 0.5) + { + fprintf(stderr, "the error of %s is no bigger than 0.5, so do not need precision improvement.\n", inputStr.c_str()); + fprintf(stderr, GREEN "ready> " RESET); + continue; + } + // cout << BLUE << "main: start testError for origin: " << inputStr << RESET << endl; // auto timeTmp1 = std::chrono::high_resolution_clock::now(); // geneSampleData(); @@ -243,26 +256,31 @@ int main() // auto timeTmp2 = std::chrono::high_resolution_clock::now(); // cout << BLUE << "main: ending testError for origin: " << inputStr << RESET << endl; // std::chrono::duration testError_seconds = timeTmp2 - timeTmp1; - // cout << BLUE << "testError time: " << testError_seconds.count() << "s" << RESET << endl; + // cout << BLUE << "testError time: " << testError_seconds.count() << " s" << RESET << endl; // fprintf(stderr, GREEN "ready> " RESET); // continue; // auto infoTmp = testError(uniqueLabel, "origin", intervals, scales); // TODO-done: 完善origin误差测试 // WHY do it? Because no pickTheBest before. So, have to use testError to get the sample data. auto upEdgeFileName = geneBoundaryData(uniqueLabel, suffix); // sample data == matlab ==> upEdge data // TODO: support multiple dimension cout << "upEdgeFileName: " << upEdgeFileName << "\n"; + auto timeTmp2 = std::chrono::high_resolution_clock::now(); // matlab over + std::chrono::duration matlab_seconds = timeTmp2 - timeTmp1; + cout << BLUE << "regime time (matlab part): " << matlab_seconds.count() << " s" << RESET << endl; vector upEdgeFileNames; upEdgeFileNames.push_back(upEdgeFileName); auto intervalData = getIntervalData(upEdgeFileNames, thresholds); fmt::print("after regime, we have {} intervals: {}\n", intervalData.size(), intervalData); + auto timeTmp3 = std::chrono::high_resolution_clock::now(); + std::chrono::duration regime_seconds = timeTmp3 - timeTmp2; + cout << BLUE << "regime time (other part): " << regime_seconds.count() << " s" << RESET << endl; // fprintf(stderr, GREEN "ready> " RESET); // continue; cout << "=-=-=-=-=-=-=-=-=-=-=-=-= rewrite start =-=-=-=-=-=-=-=-=-=-=-=-=" << endl; - // auto timeTmp3 = std::chrono::high_resolution_clock::now(); auto exprInfoVector = rewrite(inputStr, uniqueLabel, intervalData); - // auto timeTmp4 = std::chrono::high_resolution_clock::now(); - // std::chrono::duration rewrite_seconds = timeTmp4 - timeTmp3; - // cout << BLUE << "rewrite time: " << rewrite_seconds.count() << " s" << RESET << endl; + auto timeTmp4 = std::chrono::high_resolution_clock::now(); // rewrite over + std::chrono::duration rewrite_seconds = timeTmp4 - timeTmp3; + cout << BLUE << "rewrite time: " << rewrite_seconds.count() << " s" << RESET << endl; // fprintf(stderr, GREEN "ready> " RESET); // continue; cout << "=-=-=-=-=-=-=-=-=-=-=-=-= rewrite end =-=-=-=-=-=-=-=-=-=-=-=-=" << endl; @@ -274,7 +292,9 @@ int main() infoTmp.performance = testPerformance(uniqueLabel, "final", intervals); cout << "performance: " << infoTmp.performance << "\n\n"; cout << "=-=-=-=-=-=-=-=-=-=-=-=-= test final code's error and performance end =-=-=-=-=-=-=-=-=-=-=-=-=\n"; - + auto timeTmp5 = std::chrono::high_resolution_clock::now(); + std::chrono::duration final_seconds = timeTmp5 - timeTmp4; + cout << BLUE << "final time: " << final_seconds.count() << " s" << RESET << endl; // write the results to file // auto results = exprAutoWrapper(inputStr, intervals, scales); // for (size_t i = 0; i < results.size(); i++) @@ -314,7 +334,7 @@ int main() auto timeEnd = std::chrono::high_resolution_clock::now(); std::chrono::duration elapsed_seconds = timeEnd - timeStart; - cout << BLUE << "elapsed time: " << elapsed_seconds.count() << "s" << RESET << endl; + cout << BLUE << "the whole time: " << elapsed_seconds.count() << " s" << RESET << endl; fprintf(stderr, GREEN "ready> " RESET); } diff --git a/src/mathfuncRewrite.cpp b/src/mathfuncRewrite.cpp index b0b84f4..fe4a606 100644 --- a/src/mathfuncRewrite.cpp +++ b/src/mathfuncRewrite.cpp @@ -727,14 +727,14 @@ vector fmaRewrite(const ast_ptr &expr) static size_t callCount = 0; callCount++; callLevel++; - string prompt(callLevel * promtTimes, callLevelChar); - prompt.append(callCount, callCountChar); - prompt += "fmaRewrite: "; + // string prompt(callLevel * promtTimes, callLevelChar); + // prompt.append(callCount, callCountChar); + // prompt += "fmaRewrite: "; // if (callCount == 1) cout << prompt << "start--------" << endl; if (expr == nullptr) { - cerr << prompt << "ERROR: expr is nullptr" << endl; + // cerr << prompt << "ERROR: expr is nullptr" << endl; exit(EXIT_FAILURE); } // printExpr(expr, "\tfmaRewrite: at the beginning: "); @@ -743,8 +743,8 @@ vector fmaRewrite(const ast_ptr &expr) { vector exprsFinal; exprsFinal.push_back(std::move(expr->Clone())); - if (callCount == 1) printExprs(exprsFinal, prompt + "exprsFinal: "); - if (callCount == 1) cout << prompt << "end--------" << endl; + // if (callCount == 1) printExprs(exprsFinal, prompt + "exprsFinal: "); + // if (callCount == 1) cout << prompt << "end--------" << endl; callCount--; callLevel--; return exprsFinal; diff --git a/src/simplify.py b/src/simplify.py index 28a66d1..f74b62b 100644 --- a/src/simplify.py +++ b/src/simplify.py @@ -36,7 +36,7 @@ def find_index(expr): def find_base(expr, i): - op_lst = ['+', '-', '*', '/'] + op_lst = ['+', '-', '*', '/', ',', ' '] if expr[i - 1] == ')': flag = 0 j = i - 2 diff --git a/src/tools.cpp b/src/tools.cpp index 6b2b892..4f3a4d9 100644 --- a/src/tools.cpp +++ b/src/tools.cpp @@ -678,6 +678,16 @@ vector rewrite(string exprStr, string uniqueLabel, vector maxReUlp) { // flag = 0; @@ -111,6 +112,7 @@ struct errorInfo test1FPEDparamParallel(DL x0Start, DL x0End, unsigned long int maxReUlp = reUlp; } } + } // aveReUlp = aveReUlp / (testNumX0); // if(flag == 1) { -- Gitee