From 7e8832dc4da388c2ea1937551f36fcdbbeb9edbd Mon Sep 17 00:00:00 2001 From: Roman Rusyaev Date: Fri, 18 Feb 2022 17:18:39 +0300 Subject: [PATCH 1/5] Make temporary fix of integer overflow in constant folding optimization --- src/mapleall/mpl2mpl/src/constantfold.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mapleall/mpl2mpl/src/constantfold.cpp b/src/mapleall/mpl2mpl/src/constantfold.cpp index 555ebfe420..a659fd2f0e 100644 --- a/src/mapleall/mpl2mpl/src/constantfold.cpp +++ b/src/mapleall/mpl2mpl/src/constantfold.cpp @@ -1743,7 +1743,7 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { result = l; } else if (op == OP_mul && lp.second != 0 && lp.second > -kMaxOffset) { // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)] - sum = lp.second * cst; + sum = lp.second * static_cast(cst); if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) { lp.first = mirModule->CurFuncCodeMemPool()->New(OP_cvt, primType, PTY_i32, lp.first); } -- Gitee From 2dd22b627929a5c9bb47e645ba6c80cc1d23fedd Mon Sep 17 00:00:00 2001 From: Roman Rusyaev Date: Fri, 18 Feb 2022 17:20:26 +0300 Subject: [PATCH 2/5] Dump function IR before constant folding at CG phase Make refactoring related to function dumping at CG phase --- .../maple_be/src/cg/cg_phasemanager.cpp | 39 +++++++++++++------ 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/src/mapleall/maple_be/src/cg/cg_phasemanager.cpp b/src/mapleall/maple_be/src/cg/cg_phasemanager.cpp index a5a7f561d1..206ee0f49d 100644 --- a/src/mapleall/maple_be/src/cg/cg_phasemanager.cpp +++ b/src/mapleall/maple_be/src/cg/cg_phasemanager.cpp @@ -69,6 +69,24 @@ namespace maplebe { } \ } while (0) +namespace { + +void DumpMIRFunc(MIRFunction &func, const char *msg, bool printAlways = false, const char* extraMsg = nullptr) { + bool dumpAll = (CGOptions::GetDumpPhases().find("*") != CGOptions::GetDumpPhases().end()); + bool dumpFunc = CGOptions::FuncFilter(func.GetName()); + + if (printAlways || (dumpAll && dumpFunc)) { + LogInfo::MapleLogger() << msg << '\n'; + func.Dump(); + + if (extraMsg) { + LogInfo::MapleLogger() << extraMsg << '\n'; + } + } +} + +} // anonymous namespace + void CgFuncPM::GenerateOutPutFile(MIRModule &m) { CHECK_FATAL(cg != nullptr, "cg is null"); CHECK_FATAL(cg->GetEmitter(), "emitter is null"); @@ -143,6 +161,7 @@ bool CgFuncPM::PhaseRun(MIRModule &m) { /* LowerIR. */ m.SetCurFunction(mirFunc); if (cg->DoConstFold()) { + DumpMIRFunc(*mirFunc, "************* before ConstantFold **************"); ConstantFold cf(m); (void)cf.Simplify(mirFunc->GetBody()); } @@ -326,20 +345,18 @@ void CgFuncPM::DoFuncCGLower(const MIRModule &m, MIRFunction &mirFunc) { if (m.GetFlavor() <= kFeProduced) { mirLower->SetLowerCG(); mirLower->SetMirFunc(&mirFunc); + + DumpMIRFunc(mirFunc, "************* before MIRLowerer **************"); mirLower->LowerFunc(mirFunc); } - bool dumpAll = (CGOptions::GetDumpPhases().find("*") != CGOptions::GetDumpPhases().end()); - bool dumpFunc = CGOptions::FuncFilter(mirFunc.GetName()); - if (!cg->IsQuiet() || (dumpAll && dumpFunc)) { - LogInfo::MapleLogger() << "************* before CGLowerer **************" << '\n'; - mirFunc.Dump(); - } + + bool isNotQuiet = !cg->IsQuiet(); + DumpMIRFunc(mirFunc, "************* before CGLowerer **************", isNotQuiet); + cgLower->LowerFunc(mirFunc); - if (!cg->IsQuiet() || (dumpAll && dumpFunc)) { - LogInfo::MapleLogger() << "************* after CGLowerer **************" << '\n'; - mirFunc.Dump(); - LogInfo::MapleLogger() << "************* end CGLowerer **************" << '\n'; - } + + DumpMIRFunc(mirFunc, "************* after CGLowerer **************", isNotQuiet, + "************* end CGLowerer **************"); } void CgFuncPM::EmitDuplicatedAsmFunc(MIRModule &m) const { -- Gitee From 4637ff2a8ebf55a05368c6894ebe2350479cf34a Mon Sep 17 00:00:00 2001 From: Roman Rusyaev Date: Fri, 18 Feb 2022 17:22:58 +0300 Subject: [PATCH 3/5] Check that the result of integer constants folding is not overflowed while constant folding optimization --- src/mapleall/maple_me/src/irmap.cpp | 2 +- src/mapleall/mpl2mpl/include/constantfold.h | 2 +- src/mapleall/mpl2mpl/src/constantfold.cpp | 141 +++++++++++--------- 3 files changed, 81 insertions(+), 64 deletions(-) diff --git a/src/mapleall/maple_me/src/irmap.cpp b/src/mapleall/maple_me/src/irmap.cpp index 9c4dbe3d2f..67e9a56ab0 100644 --- a/src/mapleall/maple_me/src/irmap.cpp +++ b/src/mapleall/maple_me/src/irmap.cpp @@ -1225,7 +1225,7 @@ MeExpr *IRMap::SimplifyAddExpr(const OpMeExpr *addExpr) { if (opnd1->GetMeOp() == kMeOpConst) { auto constA = static_cast(opndA)->GetIntValue(); auto constB = static_cast(opnd1)->GetIntValue(); - if (ConstantFold::IntegerOpIsOverflow(OP_add, opndA->GetPrimType(), constA, constB)) { + if (ConstantFold::IsIntegerOpOverflow(OP_add, opndA->GetPrimType(), constA, constB)) { return nullptr; } retOpMeExpr = static_cast(CreateCanonicalizedMeExpr(addExpr->GetPrimType(), OP_sub, OP_add, diff --git a/src/mapleall/mpl2mpl/include/constantfold.h b/src/mapleall/mpl2mpl/include/constantfold.h index e14ae4bbdf..d1fd4e4790 100644 --- a/src/mapleall/mpl2mpl/include/constantfold.h +++ b/src/mapleall/mpl2mpl/include/constantfold.h @@ -53,7 +53,7 @@ class ConstantFold : public FuncOptimizeImpl { MIRConst *FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *intConst0, const MIRIntConst *intConst1) const; MIRConst *FoldConstComparisonMIRConst(Opcode, PrimType, PrimType, const MIRConst&, const MIRConst&); - static bool IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB); + static bool IsIntegerOpOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB); private: StmtNode *SimplifyBinary(BinaryStmtNode *node); StmtNode *SimplifyBlock(BlockNode *node); diff --git a/src/mapleall/mpl2mpl/src/constantfold.cpp b/src/mapleall/mpl2mpl/src/constantfold.cpp index a659fd2f0e..3547160315 100644 --- a/src/mapleall/mpl2mpl/src/constantfold.cpp +++ b/src/mapleall/mpl2mpl/src/constantfold.cpp @@ -24,19 +24,6 @@ #include "me_option.h" #include "maple_phase_manager.h" -namespace { -constexpr maple::uint64 kJsTypeNumber = 4; -constexpr maple::uint64 kJsTypeNumberInHigh32Bit = kJsTypeNumber << 32; // set high 32 bit as JSTYPE_NUMBER -constexpr maple::uint32 kByteSizeOfBit64 = 8; // byte number for 64 bit -constexpr maple::uint32 kBitSizePerByte = 8; -constexpr maple::int32 kMaxOffset = INT_MAX - 8; -enum CompareRes : maple::int64 { - kLess = -1, - kEqual = 0, - kGreater = 1 -}; -} - namespace maple { // This phase is designed to achieve compiler optimization by // simplifying constant expressions. The constant expression @@ -47,6 +34,41 @@ namespace maple { // A. Analyze expression type // B. Analysis operator type // C. Replace the expression with the result of the operation + +namespace { + +constexpr uint64 kJsTypeNumber = 4; +constexpr uint64 kJsTypeNumberInHigh32Bit = kJsTypeNumber << 32; // set high 32 bit as JSTYPE_NUMBER +constexpr uint32 kByteSizeOfBit64 = 8; // byte number for 64 bit +constexpr uint32 kBitSizePerByte = 8; +constexpr int32 kMaxOffset = INT_MAX - 8; + +enum CompareRes : int64 { + kLess = -1, + kEqual = 0, + kGreater = 1 +}; + +template +std::pair GetPrimTypeIntMinMax(PrimType type, bool isSigned) { + static_assert(std::is_integral_v && std::is_integral_v, "types must be integral"); + + switch (GetPrimTypeSize(type)) { + case 1: + return { isSigned ? INT8_MIN : 0, isSigned ? INT8_MAX : UINT8_MAX }; + case 2: + return { isSigned ? INT16_MIN : 0, isSigned ? INT16_MAX : UINT16_MAX }; + case 4: + return { isSigned ? INT32_MIN : 0, isSigned ? INT32_MAX : UINT32_MAX }; + case 8: + return { isSigned ? INT64_MIN : 0, isSigned ? INT64_MAX : UINT64_MAX }; + default: + CHECK_FATAL(false, "Unsupported integer type size"); + } +} + +} // anonymous namespace + BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, BaseNode *lhs, BaseNode *rhs) const { CHECK_NULL_FATAL(old); @@ -1575,32 +1597,36 @@ std::pair ConstantFold::FoldIread(IreadNode *node) { return std::make_pair(result, 0); } -bool ConstantFold::IntegerOpIsOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB) { - switch (op ){ - case OP_add: { - int64 res = static_cast(static_cast(cstA) + static_cast(cstB)); - if (IsUnsignedInteger(primType)) { - return static_cast(res) < static_cast(cstA); +bool ConstantFold::IsIntegerOpOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB) { + auto [sIntMin, sIntMax] = GetPrimTypeIntMinMax(primType, /*is_signed*/ true); + uint64 uIntMax; + std::tie(std::ignore, uIntMax) = GetPrimTypeIntMinMax(primType, /*is_signed*/ false); + + uint64 uCstA = cstA; + uint64 uCstB = cstB; + + bool isUnsigned = IsUnsignedInteger(primType); + ASSERT(isUnsigned || IsSignedInteger(primType), "Integer type must be signed or unsigned"); + + switch (op) { + case OP_add: + return isUnsigned ? (uIntMax - uCstA) < uCstB + : ((cstB > 0 && cstA > (sIntMax - cstB)) || (cstB < 0 && cstA < (sIntMin - cstB))); + case OP_sub: + return isUnsigned ? uCstA < uCstB + : (cstB > 0 && cstA < (sIntMin + cstB)) || (cstB < 0 && cstA > (sIntMax + cstB)); + case OP_mul: + if (isUnsigned) { + return cstA > (uIntMax / cstB); } - auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1; - return (static_cast(res) >> rightShiftNumToGetSignFlag != - static_cast(cstA) >> rightShiftNumToGetSignFlag) && - (static_cast(res) >> rightShiftNumToGetSignFlag != - static_cast(cstB) >> rightShiftNumToGetSignFlag ); - } - case OP_sub: { - if (IsUnsignedInteger(primType)) { - return cstA < cstB; + + if (cstA > 0) { + return (cstB > 0) ? cstA > (sIntMax / cstB) : cstB < (sIntMin / cstA); + } else { + return (cstB > 0) ? cstA < (sIntMin / cstB) : cstA != 0 && cstB < (sIntMax / cstA); } - int64 res = static_cast(static_cast(cstA) - static_cast(cstB)); - auto rightShiftNumToGetSignFlag = GetPrimTypeBitSize(primType) - 1; - return (static_cast(cstA) >> rightShiftNumToGetSignFlag != - static_cast(cstB) >> rightShiftNumToGetSignFlag) && - (static_cast(res) >> rightShiftNumToGetSignFlag != - static_cast(cstA) >> rightShiftNumToGetSignFlag); - } default: { - return false; + CHECK_FATAL(false, "NIY"); } } } @@ -1644,12 +1670,12 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { } else if (lConst != nullptr && isInt) { MIRIntConst *mcst = safe_cast(lConst->GetConstVal()); ASSERT_NOT_NULL(mcst); - PrimType cstTyp = mcst->GetType().GetPrimType(); + PrimType cstType = mcst->GetType().GetPrimType(); int64 cst = mcst->GetValue(); - if (op == OP_add) { + if (op == OP_add && !IsIntegerOpOverflow(op, cstType, cst, rp.second)) { sum = cst + rp.second; result = r; - } else if (op == OP_sub && r->GetPrimType() != PTY_u1) { + } else if (op == OP_sub && r->GetPrimType() != PTY_u1 && !IsIntegerOpOverflow(op, cstType, cst, rp.second)) { // We exclude u1 type for fixing the following wrong example: // before cf: // sub i32 (constval i32 17, eq u1 i32 (dread i32 %i, constval i32 16))) @@ -1668,7 +1694,7 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { // 0 & X -> 0 // 0 && X -> 0 sum = 0; - result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp); + result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstType); } else if (op == OP_mul && cst == 1) { // 1 * X --> X sum = rp.second; @@ -1676,8 +1702,8 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { } else if (op == OP_bior && cst == -1) { // (-1) | X -> -1 sum = 0; - result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstTyp); - } else if (op == OP_mul && rp.second != 0) { + result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstType); + } else if (op == OP_mul && rp.second != 0 && !IsIntegerOpOverflow(op, cstType, cst, rp.second)) { // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)] sum = cst * rp.second; if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) { @@ -1688,7 +1714,7 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { sum = 0; if (cst != 0) { // 5 || X -> 1 - result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp); + result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstType); } else { // when cst is zero // 0 || X -> (X != 0); @@ -1715,33 +1741,24 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { } else if (rConst != nullptr && isInt) { MIRIntConst *mcst = safe_cast(rConst->GetConstVal()); ASSERT_NOT_NULL(mcst); - PrimType cstTyp = mcst->GetType().GetPrimType(); + PrimType cstType = mcst->GetType().GetPrimType(); int64 cst = mcst->GetValue(); - if (op == OP_add) { - if (IntegerOpIsOverflow(op, cstTyp, lp.second, cst)) { - result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), PairToExpr(rPrimTypes, rp)); - sum = 0; - } else { - result = l; - sum = lp.second + cst; - } - } else if (op == OP_sub && cst != INT_MIN) { - { - result = l; - sum = lp.second - cst; - } + if ((op == OP_add || op == OP_sub) && !IsIntegerOpOverflow(op, cstType, lp.second, cst)) { + result = l; + sum = (op == OP_add) ? lp.second + cst : lp.second - cst; } else if ((op == OP_mul || op == OP_band || op == OP_cand || op == OP_land) && cst == 0) { // X * 0 -> 0 // X & 0 -> 0 // X && 0 -> 0 sum = 0; - result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp); + result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstType); } else if ((op == OP_mul || op == OP_div) && cst == 1) { // case [X * 1 -> X] // case [X / 1 = X] sum = lp.second; result = l; - } else if (op == OP_mul && lp.second != 0 && lp.second > -kMaxOffset) { + } else if (op == OP_mul && lp.second != 0 && lp.second > -kMaxOffset && + !IsIntegerOpOverflow(op, cstType, lp.second, cst)) { // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)] sum = lp.second * static_cast(cst); if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) { @@ -1755,12 +1772,12 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { } else if (op == OP_bior && cst == -1) { // X | (-1) -> -1 sum = 0; - result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstTyp); + result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstType); } else if ((op == OP_lior || op == OP_cior)) { sum = 0; if (cst > 0) { // X || 5 -> 1 - result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstTyp); + result = mirModule->GetMIRBuilder()->CreateIntConst(1, cstType); } else if (cst == 0) { // X || 0 -> X sum = lp.second; @@ -1794,7 +1811,7 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { } else if (op == OP_rem && cst == 1) { // X % 1 -> 0 sum = 0; - result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstTyp); + result = mirModule->GetMIRBuilder()->CreateIntConst(0, cstType); } else { result = NewBinaryNode(node, op, primType, PairToExpr(lPrimTypes, lp), r); sum = 0; -- Gitee From 273a392ede00f17016088665263c5c70278224d8 Mon Sep 17 00:00:00 2001 From: Roman Rusyaev Date: Fri, 18 Feb 2022 18:10:08 +0300 Subject: [PATCH 4/5] Remove redundant check for a32 type in MIRIntConst constructor Add check that value fits into give type --- src/mapleall/maple_ir/include/mir_const.h | 12 ++++++++---- src/mapleall/maple_ir/src/mir_const.cpp | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/mapleall/maple_ir/include/mir_const.h b/src/mapleall/maple_ir/include/mir_const.h index a8abffac85..c06d3d203f 100644 --- a/src/mapleall/maple_ir/include/mir_const.h +++ b/src/mapleall/maple_ir/include/mir_const.h @@ -94,13 +94,15 @@ class MIRIntConst : public MIRConst { public: using value_type = int64; MIRIntConst(int64 val, MIRType &type) : MIRConst(type, kConstInt), value(val) { - if (!IsPrimitiveDynType(type.GetPrimType())) { + auto primType = type.GetPrimType(); + + if (!IsPrimitiveDynType(primType)) { + ASSERT(DoeValueFitIntoPrimType(primType), "value doesn't fit into primitive type"); + if (type.GetPrimType() == PTY_u128 || type.GetPrimType() == PTY_i128) { Trunc(64u); - } else if (type.GetPrimType() == PTY_a32) { - CHECK_FATAL(val <= INT32_MAX && val >= INT32_MIN, "address out of range"); } else { - Trunc(GetPrimTypeBitSize(type.GetPrimType())); + Trunc(GetPrimTypeBitSize(primType)); } } } @@ -155,6 +157,8 @@ class MIRIntConst : public MIRConst { } private: + bool DoeValueFitIntoPrimType(PrimType primType) const; + int64 value; }; diff --git a/src/mapleall/maple_ir/src/mir_const.cpp b/src/mapleall/maple_ir/src/mir_const.cpp index f4783781e5..39e51cf766 100644 --- a/src/mapleall/maple_ir/src/mir_const.cpp +++ b/src/mapleall/maple_ir/src/mir_const.cpp @@ -83,6 +83,28 @@ int64 MIRIntConst::GetValueUnderType() const { return static_cast((unsignedVal << shiftBitNum) >> shiftBitNum); } +bool MIRIntConst::DoeValueFitIntoPrimType(PrimType primType) const { + bool isUnsigned = IsPrimitiveUnsigned(primType); + uint64 uValue = value; + +#define VALUE_FIT_INTO_SIZE(S) \ + (isUnsigned ? static_cast(value) == uValue : static_cast(value) == value) + + switch (GetPrimTypeSize(primType)) { + case 1: + return VALUE_FIT_INTO_SIZE(8); + case 2: + return VALUE_FIT_INTO_SIZE(16); + case 4: + return VALUE_FIT_INTO_SIZE(32); + default: + // sizes that are equal to or bigger than 64-bit or unknown sizes + return true; + } + +#undef VALUE_FIT_INTO_SIZE +} + void MIRAddrofConst::Dump(const MIRSymbolTable *localSymTab) const { LogInfo::MapleLogger() << "addrof " << GetPrimTypeName(PTY_ptr); const MIRSymbol *sym = stIdx.IsGlobal() ? GlobalTables::GetGsymTable().GetSymbolFromStidx(stIdx.Idx()) -- Gitee From f9cf63ca4045ddf179efe7a0285ba018fa5391a2 Mon Sep 17 00:00:00 2001 From: Roman Rusyaev Date: Fri, 18 Feb 2022 21:15:45 +0300 Subject: [PATCH 5/5] Add bit-width checker in constant constructor Add additional checkers on integer overflow Add overflow check for rem and div operations --- src/mapleall/maple_ir/include/mir_const.h | 12 ++-- src/mapleall/maple_ir/src/mir_const.cpp | 22 ------- src/mapleall/maple_me/src/irmap.cpp | 19 ++---- src/mapleall/mpl2mpl/include/constantfold.h | 1 + src/mapleall/mpl2mpl/src/constantfold.cpp | 67 +++++++++++++++------ 5 files changed, 61 insertions(+), 60 deletions(-) diff --git a/src/mapleall/maple_ir/include/mir_const.h b/src/mapleall/maple_ir/include/mir_const.h index c06d3d203f..7652c76c29 100644 --- a/src/mapleall/maple_ir/include/mir_const.h +++ b/src/mapleall/maple_ir/include/mir_const.h @@ -97,12 +97,18 @@ class MIRIntConst : public MIRConst { auto primType = type.GetPrimType(); if (!IsPrimitiveDynType(primType)) { - ASSERT(DoeValueFitIntoPrimType(primType), "value doesn't fit into primitive type"); + // FIXME: 'typeSize < 32' condition is a temporary workaround. Currently, FoldIntConstBinaryMIRConst has + // incorrect implementation because it performs arithmetic operations on int32 and int64 types. + // But these operations should be performed based on real constant types in IR. + // It's necessary to implement operator+, operator- etc in MIRIntConst class + // to make constant folding correctly. + uint32 typeSize = GetPrimTypeBitSize(primType); + ASSERT(GetBitWidth() <= typeSize || typeSize < 32, "value doesn't fit into primitive type"); if (type.GetPrimType() == PTY_u128 || type.GetPrimType() == PTY_i128) { Trunc(64u); } else { - Trunc(GetPrimTypeBitSize(primType)); + Trunc(typeSize); } } } @@ -157,8 +163,6 @@ class MIRIntConst : public MIRConst { } private: - bool DoeValueFitIntoPrimType(PrimType primType) const; - int64 value; }; diff --git a/src/mapleall/maple_ir/src/mir_const.cpp b/src/mapleall/maple_ir/src/mir_const.cpp index 39e51cf766..f4783781e5 100644 --- a/src/mapleall/maple_ir/src/mir_const.cpp +++ b/src/mapleall/maple_ir/src/mir_const.cpp @@ -83,28 +83,6 @@ int64 MIRIntConst::GetValueUnderType() const { return static_cast((unsignedVal << shiftBitNum) >> shiftBitNum); } -bool MIRIntConst::DoeValueFitIntoPrimType(PrimType primType) const { - bool isUnsigned = IsPrimitiveUnsigned(primType); - uint64 uValue = value; - -#define VALUE_FIT_INTO_SIZE(S) \ - (isUnsigned ? static_cast(value) == uValue : static_cast(value) == value) - - switch (GetPrimTypeSize(primType)) { - case 1: - return VALUE_FIT_INTO_SIZE(8); - case 2: - return VALUE_FIT_INTO_SIZE(16); - case 4: - return VALUE_FIT_INTO_SIZE(32); - default: - // sizes that are equal to or bigger than 64-bit or unknown sizes - return true; - } - -#undef VALUE_FIT_INTO_SIZE -} - void MIRAddrofConst::Dump(const MIRSymbolTable *localSymTab) const { LogInfo::MapleLogger() << "addrof " << GetPrimTypeName(PTY_ptr); const MIRSymbol *sym = stIdx.IsGlobal() ? GlobalTables::GetGsymTable().GetSymbolFromStidx(stIdx.Idx()) diff --git a/src/mapleall/maple_me/src/irmap.cpp b/src/mapleall/maple_me/src/irmap.cpp index 67e9a56ab0..03eb1fa40a 100644 --- a/src/mapleall/maple_me/src/irmap.cpp +++ b/src/mapleall/maple_me/src/irmap.cpp @@ -1072,12 +1072,8 @@ MeExpr *IRMap::FoldConstExpr(PrimType primType, Opcode op, ConstMeExpr *opndA, C maple::ConstantFold cf(mirModule); auto *constA = static_cast(opndA->GetConstVal()); auto *constB = static_cast(opndB->GetConstVal()); - if ((op == OP_div || op == OP_rem)) { - if (constB->GetValue() == 0 || - (constB->GetValue() == -1 && ((primType == PTY_i32 && constA->GetValue() == INT32_MIN) || - (primType == PTY_i64 && constA->GetValue() == INT64_MIN)))) { - return nullptr; - } + if (cf.IsIntegerOpOverflow(op, primType, constA->GetValue(), constB->GetValue())) { + return nullptr; } MIRConst *resconst = cf.FoldIntConstBinaryMIRConst(op, primType, constA, constB); return CreateConstMeExpr(primType, *resconst); @@ -1815,15 +1811,8 @@ MeExpr *IRMap::SimplifyOpMeExpr(OpMeExpr *opmeexpr) { maple::ConstantFold cf(mirModule); MIRIntConst *opnd0const = static_cast(static_cast(opnd0)->GetConstVal()); MIRIntConst *opnd1const = static_cast(static_cast(opnd1)->GetConstVal()); - if ((opop == OP_div || opop == OP_rem)) { - int64 opnd0constValue = opnd0const->GetValue(); - int64 opnd1constValue = opnd1const->GetValue(); - PrimType resPtyp = opmeexpr->GetPrimType(); - if (opnd1constValue == 0 || - (opnd1constValue == -1 && ((resPtyp == PTY_i32 && opnd0constValue == INT32_MIN) || - (resPtyp == PTY_i64 && opnd0constValue == INT64_MIN)))) { - return nullptr; - } + if (cf.IsIntegerOpOverflow(opop, opmeexpr->GetPrimType(), opnd0const->GetValue(), opnd1const->GetValue())) { + return nullptr; } MIRConst *resconst = cf.FoldIntConstBinaryMIRConst(opmeexpr->GetOp(), opmeexpr->GetPrimType(), opnd0const, opnd1const); diff --git a/src/mapleall/mpl2mpl/include/constantfold.h b/src/mapleall/mpl2mpl/include/constantfold.h index d1fd4e4790..89d3cdb0be 100644 --- a/src/mapleall/mpl2mpl/include/constantfold.h +++ b/src/mapleall/mpl2mpl/include/constantfold.h @@ -53,6 +53,7 @@ class ConstantFold : public FuncOptimizeImpl { MIRConst *FoldIntConstBinaryMIRConst(Opcode opcode, PrimType resultType, const MIRIntConst *intConst0, const MIRIntConst *intConst1) const; MIRConst *FoldConstComparisonMIRConst(Opcode, PrimType, PrimType, const MIRConst&, const MIRConst&); + static bool IsIntegerOpOverflow(Opcode op, PrimType primType, const ConstvalNode& cstA, const ConstvalNode& cstB); static bool IsIntegerOpOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB); private: StmtNode *SimplifyBinary(BinaryStmtNode *node); diff --git a/src/mapleall/mpl2mpl/src/constantfold.cpp b/src/mapleall/mpl2mpl/src/constantfold.cpp index 3547160315..c007caddfa 100644 --- a/src/mapleall/mpl2mpl/src/constantfold.cpp +++ b/src/mapleall/mpl2mpl/src/constantfold.cpp @@ -67,6 +67,21 @@ std::pair GetPrimTypeIntMinMax(PrimType type, bool isSigned) { } } +bool isValueMinimumForPrimType(int64 value, PrimType type) { + switch (GetPrimTypeSize(type)) { + case 1: + return value == INT8_MIN; + case 2: + return value == INT16_MIN; + case 4: + return value == INT32_MIN; + case 8: + return value == INT64_MIN; + default: + CHECK_FATAL(false, "Unsupported integer type size"); + } +} + } // anonymous namespace BinaryNode *ConstantFold::NewBinaryNode(BinaryNode *old, Opcode op, PrimType primType, @@ -104,7 +119,7 @@ BaseNode *ConstantFold::PairToExpr(PrimType resultType, const std::pair(pair.first)->Opnd(0); result = mirModule->CurFuncCodeMemPool()->New(OP_sub, resultType, val, r); } else { - if (pair.second > 0 || pair.second == LLONG_MIN) { + if (pair.second > 0 || isValueMinimumForPrimType(pair.second, resultType)) { // +-a, 5 -> a + 5 ConstvalNode *val = mirModule->GetMIRBuilder()->CreateIntConst(pair.second, resultType); result = mirModule->CurFuncCodeMemPool()->New(OP_add, resultType, pair.first, val); @@ -1384,8 +1399,13 @@ MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromTy } if (IsPrimitiveInteger(fromType) && IsPrimitiveInteger(toType)) { MIRConst *toConst = nullptr; + + const MIRIntConst *constVal = safe_cast(cst); + ASSERT_NOT_NULL(constVal); + uint32 fromSize = GetPrimTypeBitSize(fromType); uint32 toSize = GetPrimTypeBitSize(toType); + // GetPrimTypeBitSize(PTY_u1) will return 8, which is not expected here. if (fromType == PTY_u1) { fromSize = 1; @@ -1399,9 +1419,7 @@ MIRConst *ConstantFold::FoldTypeCvtMIRConst(const MIRConst &cst, PrimType fromTy op = OP_sext; } toConst = FoldSignExtendMIRConst(op, toType, fromSize, cst); - } else { - const MIRIntConst *constVal = safe_cast(cst); - ASSERT_NOT_NULL(constVal); + } else if (toSize >= constVal->GetBitWidth()) { MIRType &type = *GlobalTables::GetTypeTable().GetPrimType(toType); toConst = GlobalTables::GetIntConstTable().GetOrCreateIntConst(constVal->GetValue(), type); @@ -1597,6 +1615,22 @@ std::pair ConstantFold::FoldIread(IreadNode *node) { return std::make_pair(result, 0); } +bool ConstantFold::IsIntegerOpOverflow(Opcode op, PrimType primType, const ConstvalNode &nodeA, + const ConstvalNode &nodeB) { + if (!IsPrimitiveInteger(primType)) { + return false; + } + + const MIRConst *valA = nodeA.GetConstVal(); + const MIRConst *valB = nodeB.GetConstVal(); + ASSERT_NOT_NULL(valA); + ASSERT_NOT_NULL(valB); + int64 cstA = static_cast(valA)->GetValue(); + int64 cstB = static_cast(valB)->GetValue(); + + return IsIntegerOpOverflow(op, primType, cstA, cstB); +} + bool ConstantFold::IsIntegerOpOverflow(Opcode op, PrimType primType, int64 cstA, int64 cstB) { auto [sIntMin, sIntMax] = GetPrimTypeIntMinMax(primType, /*is_signed*/ true); uint64 uIntMax; @@ -1625,9 +1659,12 @@ bool ConstantFold::IsIntegerOpOverflow(Opcode op, PrimType primType, int64 cstA, } else { return (cstB > 0) ? cstA < (sIntMin / cstB) : cstA != 0 && cstB < (sIntMax / cstA); } - default: { - CHECK_FATAL(false, "NIY"); - } + case OP_div: + case OP_rem: + return !isUnsigned && (cstB == 0 || (cstA == sIntMin && cstB == -1)); + default: + // TODO: implement check for shift-left and shift-right operations + return false; } } @@ -1647,15 +1684,8 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { ConstvalNode *rConst = safe_cast(r); bool isInt = IsPrimitiveInteger(primType); if (lConst != nullptr && rConst != nullptr) { - MIRConst *lConstVal = lConst->GetConstVal(); - MIRConst *rConstVal = rConst->GetConstVal(); - ASSERT_NOT_NULL(lConstVal); - ASSERT_NOT_NULL(rConstVal); - // Don't fold div by 0, for floats div by 0 is well defined. - if ((op == OP_div || op == OP_rem) && isInt && - (static_cast(rConstVal)->GetValue() == 0 || - static_cast(lConstVal)->GetValue() == LONG_MIN || - static_cast(lConstVal)->GetValue() == INT_MIN)) { + // Don't fold if operation on integer number overflows. + if (isInt && IsIntegerOpOverflow(op, primType, *lConst, *rConst)) { result = NewBinaryNode(node, op, primType, lConst, rConst); sum = 0; } else { @@ -1705,7 +1735,7 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { result = mirModule->GetMIRBuilder()->CreateIntConst(-1, cstType); } else if (op == OP_mul && rp.second != 0 && !IsIntegerOpOverflow(op, cstType, cst, rp.second)) { // lConst * (X + konst) -> the pair [(lConst*X), (lConst*konst)] - sum = cst * rp.second; + sum = cst * static_cast(rp.second); if (GetPrimTypeSize(primType) > GetPrimTypeSize(rp.first->GetPrimType())) { rp.first = mirModule->CurFuncCodeMemPool()->New(OP_cvt, primType, PTY_i32, rp.first); } @@ -1757,8 +1787,7 @@ std::pair ConstantFold::FoldBinary(BinaryNode *node) { // case [X / 1 = X] sum = lp.second; result = l; - } else if (op == OP_mul && lp.second != 0 && lp.second > -kMaxOffset && - !IsIntegerOpOverflow(op, cstType, lp.second, cst)) { + } else if (op == OP_mul && lp.second != 0 && !IsIntegerOpOverflow(op, cstType, lp.second, cst)) { // (X + konst) * rConst -> the pair [(X*rConst), (konst*rConst)] sum = lp.second * static_cast(cst); if (GetPrimTypeSize(primType) > GetPrimTypeSize(lp.first->GetPrimType())) { -- Gitee