diff --git a/src/mapleall/mpl2mpl/include/ext_constantfold.h b/src/mapleall/mpl2mpl/include/ext_constantfold.h index ac1d8e55b2af3c73106d0bf8846f2d20182523bc..bdaf517cb421ae35b565c7fa557b15c9dae8eb3b 100644 --- a/src/mapleall/mpl2mpl/include/ext_constantfold.h +++ b/src/mapleall/mpl2mpl/include/ext_constantfold.h @@ -21,9 +21,13 @@ namespace maple { class ExtConstantFold { public: explicit ExtConstantFold(MIRModule *mod) : mirModule(mod) {} + BaseNode *ExtFoldUnary(UnaryNode *node); + BaseNode *ExtFoldBinary(BinaryNode *node); + BaseNode* ExtFoldTernary(TernaryNode *node); StmtNode *ExtSimplify(StmtNode *node); - BaseNode *Fold(BaseNode *node); - BaseNode *FoldIor(BinaryNode *node); + BaseNode *ExtFold(BaseNode *node); + BaseNode *ExtFoldIor(BinaryNode *node); + BaseNode *ExtFoldXand(BinaryNode *node); StmtNode *ExtSimplifyBlock(BlockNode *node); StmtNode *ExtSimplifyIf(IfStmtNode *node); StmtNode *ExtSimplifyDassign(DassignNode *node); diff --git a/src/mapleall/mpl2mpl/src/ext_constantfold.cpp b/src/mapleall/mpl2mpl/src/ext_constantfold.cpp index c08391fee7efb23d38a608a4910e6dda37f36fb9..12f492d0bfd27c712f2675f0647affa5c6e1bc58 100644 --- a/src/mapleall/mpl2mpl/src/ext_constantfold.cpp +++ b/src/mapleall/mpl2mpl/src/ext_constantfold.cpp @@ -39,25 +39,105 @@ StmtNode *ExtConstantFold::ExtSimplify(StmtNode *node) { } BaseNode* ExtConstantFold::DispatchFold(BaseNode *node) { + // Not trying all possiblities. + // For simplicity, stop looking further down the expression once OP_OP_cior/OP_cand (etc) are seen CHECK_NULL_FATAL(node); - switch (node->GetOpCode()) { case OP_cior: case OP_lior: - return FoldIor(static_cast(node)); + return ExtFoldIor(static_cast(node)); + case OP_cand: + case OP_land: + return ExtFoldXand(static_cast(node)); + case OP_abs: + case OP_bnot: + case OP_lnot: + case OP_neg: + case OP_recip: + case OP_sqrt: + return ExtFoldUnary(static_cast(node)); + case OP_add: + case OP_ashr: + case OP_band: + case OP_bior: + case OP_bxor: +// case OP_cand: +// case OP_cior: + case OP_div: +// case OP_land: +// case OP_lior: + case OP_lshr: + case OP_max: + case OP_min: + case OP_mul: + case OP_rem: + case OP_shl: + case OP_sub: + case OP_eq: + case OP_ne: + case OP_ge: + case OP_gt: + case OP_le: + case OP_lt: + case OP_cmp: + return ExtFoldBinary(static_cast(node)); + case OP_select: + return ExtFoldTernary(static_cast(node)); default: return node; } } -BaseNode *ExtConstantFold::Fold(BaseNode *node) { +BaseNode *ExtConstantFold::ExtFoldUnary(UnaryNode *node) { + CHECK_NULL_FATAL(node); + BaseNode *result = nullptr; + result = DispatchFold(node->Opnd(0)); + if (result != node->Opnd(0)) { + node->SetOpnd(result, 0); + } + return node; +} + +BaseNode *ExtConstantFold::ExtFoldBinary(BinaryNode *node) { + CHECK_NULL_FATAL(node); + BaseNode *result = nullptr; + result = DispatchFold(node->Opnd(0)); + if (result != node->Opnd(0)) { + node->SetOpnd(result, 0); + } + result = DispatchFold(node->Opnd(1)); + if (result != node->Opnd(1)) { + node->SetOpnd(result, 1); + } + return node; +} + +BaseNode* ExtConstantFold::ExtFoldTernary(TernaryNode *node) { + CHECK_NULL_FATAL(node); + BaseNode *result = nullptr; + result = DispatchFold(node->Opnd(0)); + if (result != node->Opnd(0)) { + node->SetOpnd(result, 0); + } + result = DispatchFold(node->Opnd(1)); + if (result != node->Opnd(1)) { + node->SetOpnd(result, 1); + } + result = DispatchFold(node->Opnd(2)); + if (result != node->Opnd(2)) { + node->SetOpnd(result, 2); + } + return node; +} + +BaseNode *ExtConstantFold::ExtFold(BaseNode *node) { if (node == nullptr || kOpcodeInfo.IsStmt(node->GetOpCode())) { return nullptr; } return DispatchFold(node); } -BaseNode *ExtConstantFold::FoldIor(BinaryNode *node) { +BaseNode *ExtConstantFold::ExtFoldIor(BinaryNode *node) { CHECK_NULL_FATAL(node); // The target pattern (Cior, Lior): // x == c || x == c+1 || ... || x == c+k @@ -133,6 +213,73 @@ BaseNode *ExtConstantFold::FoldIor(BinaryNode *node) { } } +BaseNode *ExtConstantFold::ExtFoldXand(BinaryNode *node) { + // The target pattern (Cand, Land): + // (x & m1) == c1 && (x & m2) == c2 && ... && (x & Mk) == ck + // where mi and ci shall be all int constants + // ==> (x & M) == C + + CHECK_NULL_FATAL(node); + CHECK_FATAL(node->GetOpCode() == OP_cand || node->GetOpCode() == OP_land, "Operator is neither OP_cand nor OP_land"); + + BaseNode * lnode = DispatchFold(node->Opnd(0)); + if (lnode != node->Opnd(0)) { + node->SetOpnd(lnode, 0); + } + + BaseNode * rnode = DispatchFold(node->Opnd(1)); + if (rnode != node->Opnd(1)) { + node->SetOpnd(rnode, 1); + } + + // Check if it is of the form of (x & m) == c cand (x & m') == c' + if ((lnode->GetOpCode() == OP_eq) && (rnode->GetOpCode() == OP_eq) && + (lnode->Opnd(0)->GetOpCode() == OP_band) && + (lnode->Opnd(0)->Opnd(1)->GetOpCode() == OP_constval) && + (IsPrimitiveInteger(lnode->Opnd(0)->Opnd(1)->GetPrimType())) && + (lnode->Opnd(1)->GetOpCode() == OP_constval) && + (IsPrimitiveInteger(lnode->Opnd(1)->GetPrimType())) && + (rnode->Opnd(0)->GetOpCode() == OP_band) && + (rnode->Opnd(0)->Opnd(1)->GetOpCode() == OP_constval) && + (IsPrimitiveInteger(rnode->Opnd(0)->Opnd(1)->GetPrimType())) && + (rnode->Opnd(1)->GetOpCode() == OP_constval) && + (IsPrimitiveInteger(rnode->Opnd(1)->GetPrimType())) && + (lnode->Opnd(0)->Opnd(0)->IsSameContent(rnode->Opnd(0)->Opnd(0)))) { + MIRConst *lmConstVal = safe_cast(lnode->Opnd(0)->Opnd(1))->GetConstVal(); + uint64 lmVal = static_cast(static_cast(lmConstVal)->GetValue()); + MIRConst *rmConstVal = safe_cast(rnode->Opnd(0)->Opnd(1))->GetConstVal(); + uint64 rmVal = static_cast(static_cast(rmConstVal)->GetValue()); + MIRConst *lcConstVal = safe_cast(lnode->Opnd(1))->GetConstVal(); + uint64 lcVal = static_cast(static_cast(lcConstVal)->GetValue()); + MIRConst *rcConstVal = safe_cast(rnode->Opnd(1))->GetConstVal(); + uint64 rcVal = static_cast(static_cast(rcConstVal)->GetValue()); + + bool isWorkable = true; + for (uint32 i = 0; i < 64; i++) { + if ((lmVal & (1 << i)) == (rmVal & (1 << i)) && + (lcVal & (1 << i)) != (rcVal & (1 << i))) { + isWorkable = false; + break; + } + } + + if (isWorkable) { + uint64 mVal = lmVal | rmVal; + uint64 cVal = lcVal | rcVal; + PrimType mPrimType = lnode->Opnd(0)->Opnd(1)->GetPrimType(); + ConstvalNode *mIntConst = mirModule->GetMIRBuilder()->CreateIntConst(mVal, mPrimType); + PrimType cPrimType = lnode->Opnd(1)->GetPrimType(); + ConstvalNode *cIntConst = mirModule->GetMIRBuilder()->CreateIntConst(cVal, cPrimType); + BinaryNode *eqNode = static_cast(lnode); + BinaryNode *bandNode = static_cast(eqNode->Opnd(0)); + bandNode->SetOpnd(mIntConst, 1); + eqNode->SetOpnd(cIntConst, 1); + return eqNode; + } + } + return node; +} + StmtNode *ExtConstantFold::ExtSimplifyBlock(BlockNode *node) { CHECK_NULL_FATAL(node); if (node->GetFirst() == nullptr) { @@ -153,7 +300,7 @@ StmtNode *ExtConstantFold::ExtSimplifyIf(IfStmtNode *node) { (void)ExtSimplify(node->GetElsePart()); } BaseNode *origTest = node->Opnd(); - BaseNode *returnValue = Fold(node->Opnd()); + BaseNode *returnValue = ExtFold(node->Opnd()); if (returnValue != origTest) { node->SetOpnd(returnValue, 0); } @@ -163,7 +310,7 @@ StmtNode *ExtConstantFold::ExtSimplifyIf(IfStmtNode *node) { StmtNode *ExtConstantFold::ExtSimplifyDassign(DassignNode *node) { CHECK_NULL_FATAL(node); BaseNode *returnValue; - returnValue = Fold(node->GetRHS()); + returnValue = ExtFold(node->GetRHS()); if (returnValue != node->GetRHS()) { node->SetRHS(returnValue); } @@ -173,7 +320,7 @@ StmtNode *ExtConstantFold::ExtSimplifyDassign(DassignNode *node) { StmtNode *ExtConstantFold::ExtSimplifyIassign(IassignNode *node) { CHECK_NULL_FATAL(node); BaseNode *returnValue; - returnValue = Fold(node->GetRHS()); + returnValue = ExtFold(node->GetRHS()); if (returnValue != node->GetRHS()) { node->SetRHS(returnValue); } @@ -186,7 +333,7 @@ StmtNode *ExtConstantFold::ExtSimplifyWhile(WhileStmtNode *node) { if (node->Opnd(0) == nullptr) { return node; } - returnValue = Fold(node->Opnd(0)); + returnValue = ExtFold(node->Opnd(0)); if (returnValue != node->Opnd(0)) { node->SetOpnd(returnValue, 0); }