diff --git a/src/mapleall/maple_be/include/be/lower.h b/src/mapleall/maple_be/include/be/lower.h index cc855b64fda77e0f90b9729f2ce1bfdc15cf5299..70e9e00d4db3966936a3a41205bf1b391e907599 100644 --- a/src/mapleall/maple_be/include/be/lower.h +++ b/src/mapleall/maple_be/include/be/lower.h @@ -130,6 +130,10 @@ class CGLowerer { void AddElemToPrintf(MapleVector &argsPrintf, int num, ...) const; + bool CheckSwitchTableContinuous(SwitchNode &stmt) const; + + bool IsSwitchToRangeGoto(const BlockNode &blk) const; + std::string AssertBoundaryGetFileName(StmtNode &stmt) { size_t pos = mirModule.GetFileNameFromFileNum(stmt.GetSrcPos().FileNum()).rfind('/'); return mirModule.GetFileNameFromFileNum(stmt.GetSrcPos().FileNum()).substr(pos + 1); diff --git a/src/mapleall/maple_be/include/be/switch_lowerer.h b/src/mapleall/maple_be/include/be/switch_lowerer.h index 56c63aa97056374d46de2f31f64be5584405950c..38467837836606acc5e9aa7112ad38634a3c25a7 100644 --- a/src/mapleall/maple_be/include/be/switch_lowerer.h +++ b/src/mapleall/maple_be/include/be/switch_lowerer.h @@ -32,7 +32,7 @@ class SwitchLowerer { ~SwitchLowerer() = default; - maple::BlockNode *LowerSwitch(); + maple::BlockNode *LowerSwitch(LabelIdx newLabelIdx = 0); private: using Cluster = std::pair; @@ -57,12 +57,12 @@ class SwitchLowerer { void FindClusters(MapleVector &clusters) const; void InitSwitchItems(MapleVector &clusters); - maple::RangeGotoNode *BuildRangeGotoNode(int32 startIdx, int32 endIdx); + maple::RangeGotoNode *BuildRangeGotoNode(int32 startIdx, int32 endIdx, LabelIdx newLabelIdx); maple::CompareNode *BuildCmpNode(Opcode opCode, uint32 idx); maple::GotoNode *BuildGotoNode(int32 idx); maple::CondGotoNode *BuildCondGotoNode(int32 idx, Opcode opCode, BaseNode &cond); maple::BlockNode *BuildCodeForSwitchItems(int32 start, int32 end, bool lowBlockNodeChecked, - bool highBlockNodeChecked); + bool highBlockNodeChecked, LabelIdx newLabelIdx = 0); }; } /* namespace maplebe */ diff --git a/src/mapleall/maple_be/src/be/lower.cpp b/src/mapleall/maple_be/src/be/lower.cpp index bb39897bab9b9824045dcf33384d070df513c9e9..1ec81028166195d7568dc9022962255e18478198 100644 --- a/src/mapleall/maple_be/src/be/lower.cpp +++ b/src/mapleall/maple_be/src/be/lower.cpp @@ -89,6 +89,10 @@ const std::string kFileSymbolNamePrefix = "symname"; const std::string CGLowerer::kIntrnRetValPrefix = "__iret"; const std::string CGLowerer::kUserRetValPrefix = "__uret"; +static bool CasePairKeyLessThan(const CasePair &left, const CasePair &right) { + return left.first < right.first; +} + std::string CGLowerer::GetFileNameSymbolName(const std::string &fileName) const { return kFileSymbolNamePrefix + std::regex_replace(fileName, std::regex("-"), "_"); } @@ -1725,6 +1729,26 @@ void CGLowerer::AddElemToPrintf(MapleVector &argsPrintf, int num, ... va_end(argPtr); } +bool CGLowerer::CheckSwitchTableContinuous(SwitchNode &stmt) const { + if (!stmt.GetSwitchTable().empty()) { + stmt.SortCasePair(CasePairKeyLessThan); + if (static_cast(((stmt.GetSwitchTable().end() - 1)->first - stmt.GetSwitchTable().begin()->first) + 1) > + stmt.GetSwitchTable().size()) { + return false; + } + } + return true; +} + +bool CGLowerer::IsSwitchToRangeGoto(const BlockNode &blk) const { + for (const StmtNode *stmt = blk.GetFirst(); stmt != nullptr; stmt = stmt->GetNext()) { + if (stmt->GetOpCode() == OP_rangegoto) { + return true; + } + } + return false; +} + void CGLowerer::SwitchAssertBoundary(StmtNode &stmt, MapleVector &argsPrintf) { MIRSymbol *errMsg; MIRSymbol *fileNameSym; @@ -1845,8 +1869,18 @@ BlockNode *CGLowerer::LowerBlock(BlockNode &block) { LowerSwitchOpnd(*stmt, *newBlk); auto switchMp = std::make_unique(memPoolCtrler, "switchlowere"); MapleAllocator switchAllocator(switchMp.get()); + LabelNode *defaultLabel = nullptr; + LabelIdx newLabelIdx = 0; + if (!CheckSwitchTableContinuous(static_cast(*stmt)) && + static_cast(stmt)->GetDefaultLabel() == 0) { + newLabelIdx = GetLabelIdx(*mirModule.CurFunction()); + defaultLabel = mirBuilder->CreateStmtLabel(newLabelIdx); + } SwitchLowerer switchLowerer(mirModule, static_cast(*stmt), switchAllocator); - BlockNode *blk = switchLowerer.LowerSwitch(); + BlockNode *blk = switchLowerer.LowerSwitch(newLabelIdx); + if (blk->GetFirst() != nullptr && defaultLabel != nullptr && IsSwitchToRangeGoto(*blk)) { + blk->AddStatement(defaultLabel); + } if (blk->GetFirst() != nullptr) { newBlk->AppendStatementsFromBlock(*blk); } diff --git a/src/mapleall/maple_be/src/be/switch_lowerer.cpp b/src/mapleall/maple_be/src/be/switch_lowerer.cpp index f98360292a6a6382f6786aa5c8240002cb45c582..f22950bf629819e7e8adf71ac5d31909e1903ec8 100644 --- a/src/mapleall/maple_be/src/be/switch_lowerer.cpp +++ b/src/mapleall/maple_be/src/be/switch_lowerer.cpp @@ -95,7 +95,7 @@ void SwitchLowerer::InitSwitchItems(MapleVector &clusters) { } } -RangeGotoNode *SwitchLowerer::BuildRangeGotoNode(int32 startIdx, int32 endIdx) { +RangeGotoNode *SwitchLowerer::BuildRangeGotoNode(int32 startIdx, int32 endIdx, LabelIdx newLabelIdx) { RangeGotoNode *node = mirModule.CurFuncCodeMemPool()->New(mirModule); node->SetOpnd(stmt->GetSwitchOpnd(), 0); @@ -114,6 +114,8 @@ RangeGotoNode *SwitchLowerer::BuildRangeGotoNode(int32 startIdx, int32 endIdx) { curTag = static_cast(static_cast(++lastCaseTag) - node->GetTagOffset()); if (stmt->GetDefaultLabel() != 0) { node->AddRangeGoto(curTag, stmt->GetDefaultLabel()); + } else if (newLabelIdx != 0) { + node->AddRangeGoto(curTag, newLabelIdx); } } curTag = static_cast(stmt->GetCasePair(static_cast(i)).first - node->GetTagOffset()); @@ -173,7 +175,7 @@ CondGotoNode *SwitchLowerer::BuildCondGotoNode(int32 idx, Opcode opCode, BaseNod /* start and end is with respect to switchItems */ BlockNode *SwitchLowerer::BuildCodeForSwitchItems(int32 start, int32 end, bool lowBlockNodeChecked, - bool highBlockNodeChecked) { + bool highBlockNodeChecked, LabelIdx newLabelIdx) { ASSERT(start >= 0, "invalid args start"); ASSERT(end >= 0, "invalid args end"); BlockNode *localBlk = mirModule.CurFuncCodeMemPool()->New(); @@ -198,7 +200,7 @@ BlockNode *SwitchLowerer::BuildCodeForSwitchItems(int32 start, int32 end, bool l } } } - rangeGoto = BuildRangeGotoNode(switchItems[start].first, switchItems[start].second); + rangeGoto = BuildRangeGotoNode(switchItems[start].first, switchItems[start].second, newLabelIdx); if (stmt->GetDefaultLabel() == 0) { localBlk->AddStatement(rangeGoto); } else { @@ -222,7 +224,7 @@ BlockNode *SwitchLowerer::BuildCodeForSwitchItems(int32 start, int32 end, bool l } highBlockNodeChecked = true; } - rangeGoto = BuildRangeGotoNode(switchItems[end].first, switchItems[end].second); + rangeGoto = BuildRangeGotoNode(switchItems[end].first, switchItems[end].second, newLabelIdx); if (stmt->GetDefaultLabel() == 0) { localBlk->AddStatement(rangeGoto); } else { @@ -328,7 +330,7 @@ BlockNode *SwitchLowerer::BuildCodeForSwitchItems(int32 start, int32 end, bool l return localBlk; } -BlockNode *SwitchLowerer::LowerSwitch() { +BlockNode *SwitchLowerer::LowerSwitch(LabelIdx newLabelIdx) { if (stmt->GetSwitchTable().empty()) { /* change to goto */ BlockNode *localBlk = mirModule.CurFuncCodeMemPool()->New(); GotoNode *gotoDft = BuildGotoNode(-1); @@ -348,7 +350,7 @@ BlockNode *SwitchLowerer::LowerSwitch() { stmt->SortCasePair(CasePairKeyLessThan); FindClusters(clusters); InitSwitchItems(clusters); - BlockNode *blkNode = BuildCodeForSwitchItems(0, static_cast(switchItems.size()) - 1, false, false); + BlockNode *blkNode = BuildCodeForSwitchItems(0, static_cast(switchItems.size()) - 1, false, false, newLabelIdx); if (!jumpToDefaultBlockGenerated) { GotoNode *gotoDft = BuildGotoNode(-1); if (gotoDft != nullptr) {