From d3e66ff921264a42ee334370c7d1b3d87ca00836 Mon Sep 17 00:00:00 2001 From: Fred Chow Date: Sun, 14 Nov 2021 14:39:41 -0800 Subject: [PATCH] Moved the optimization to replace switch with one branch target by conditional branch from me_cfg.cpp to mir_lower.cpp This is for fixing bugs in the old code and to avoid having to update the CFG. --- src/mapleall/maple_ir/include/mir_lower.h | 1 + src/mapleall/maple_ir/src/mir_lower.cpp | 75 ++++++++++++++++++ src/mapleall/maple_me/include/me_cfg.h | 1 - src/mapleall/maple_me/src/me_cfg.cpp | 96 ----------------------- 4 files changed, 76 insertions(+), 97 deletions(-) diff --git a/src/mapleall/maple_ir/include/mir_lower.h b/src/mapleall/maple_ir/include/mir_lower.h index 32a08de9b6..e87659da23 100644 --- a/src/mapleall/maple_ir/include/mir_lower.h +++ b/src/mapleall/maple_ir/include/mir_lower.h @@ -64,6 +64,7 @@ class MIRLower { } virtual BlockNode *LowerIfStmt(IfStmtNode &ifStmt, bool recursive); + BlockNode *LowerSwitchStmt(SwitchNode *switchNode); virtual BlockNode *LowerWhileStmt(WhileStmtNode&); BlockNode *LowerDowhileStmt(WhileStmtNode&); BlockNode *LowerDoloopStmt(DoloopNode&); diff --git a/src/mapleall/maple_ir/src/mir_lower.cpp b/src/mapleall/maple_ir/src/mir_lower.cpp index 5d1b7bcf01..b5dd9b6649 100644 --- a/src/mapleall/maple_ir/src/mir_lower.cpp +++ b/src/mapleall/maple_ir/src/mir_lower.cpp @@ -225,6 +225,73 @@ BlockNode *MIRLower::LowerIfStmt(IfStmtNode &ifStmt, bool recursive) { return blk; } +static bool ConsecutiveCaseValsAndSameTarget(const CaseVector *switchTable) { + size_t caseNum = switchTable->size(); + int lastVal = (*switchTable)[0].first; + LabelIdx lblIdx = (*switchTable)[0].second; + for (size_t id = 1; id < caseNum; id++) { + lastVal++; + if (lastVal != (*switchTable)[id].first) { + return false; + } + if (lblIdx != (*switchTable)[id].second) { + return false; + } + } + return true; +} + +// if there is only 1 case branch, replace with conditional branch(es) and +// return the optimized multiple statements; otherwise, return nullptr +BlockNode *MIRLower::LowerSwitchStmt(SwitchNode *switchNode) { + CaseVector *switchTable = &switchNode->GetSwitchTable(); + if (switchTable->empty()) { // goto @defaultLabel + BlockNode *blk = mirModule.CurFuncCodeMemPool()->New(); + LabelIdx defaultLabel = switchNode->GetDefaultLabel(); + MIRBuilder *builder = mirModule.GetMIRBuilder(); + GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel); + blk->AddStatement(gotoStmt); + return blk; + } + if (!ConsecutiveCaseValsAndSameTarget(switchTable)) { + return nullptr; + } + BlockNode *blk = mirModule.CurFuncCodeMemPool()->New(); + LabelIdx caseGotoLabel = switchTable->front().second; + LabelIdx defaultLabel = switchNode->GetDefaultLabel(); + int64 minCaseVal = switchTable->front().first; + int64 maxCaseVal = switchTable->back().first; + BaseNode *switchOpnd = switchNode->Opnd(0); + MIRBuilder *builder = mirModule.GetMIRBuilder(); + ConstvalNode *minCaseNode = builder->CreateIntConst(minCaseVal, switchOpnd->GetPrimType()); + ConstvalNode *maxCaseNode = builder->CreateIntConst(maxCaseVal, switchOpnd->GetPrimType()); + if (minCaseVal == maxCaseVal) { + // brtrue (x == minCaseVal) @case_goto_label + // goto @default_label + CompareNode *eqNode = builder->CreateExprCompare(OP_eq, *GlobalTables::GetTypeTable().GetInt32(), + *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode); + CondGotoNode *condGoto = builder->CreateStmtCondGoto(eqNode, OP_brtrue, caseGotoLabel); + blk->AddStatement(condGoto); + GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, defaultLabel); + blk->AddStatement(gotoStmt); + } else { + // brtrue (x < minCaseVal) @default_label + // brtrue (x > maxCaseVal) @default_label + // goto @case_goto_label + CompareNode *ltNode = builder->CreateExprCompare(OP_lt, *GlobalTables::GetTypeTable().GetInt32(), + *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, minCaseNode); + CondGotoNode *condGoto = builder->CreateStmtCondGoto(ltNode, OP_brtrue, defaultLabel); + blk->AddStatement(condGoto); + CompareNode *gtNode = builder->CreateExprCompare(OP_gt, *GlobalTables::GetTypeTable().GetInt32(), + *GlobalTables::GetTypeTable().GetTypeFromTyIdx(TyIdx(switchOpnd->GetPrimType())), switchOpnd, maxCaseNode); + condGoto = builder->CreateStmtCondGoto(gtNode, OP_brtrue, defaultLabel); + blk->AddStatement(condGoto); + GotoNode *gotoStmt = builder->CreateStmtGoto(OP_goto, caseGotoLabel); + blk->AddStatement(gotoStmt); + } + return blk; +} + // while // is lowered to: // brfalse @@ -377,6 +444,14 @@ BlockNode *MIRLower::LowerBlock(BlockNode &block) { tmp = LowerIfStmt(static_cast(*stmt), true); newBlock->AppendStatementsFromBlock(*tmp); break; + case OP_switch: + tmp = LowerSwitchStmt(static_cast(stmt)); + if (tmp != nullptr) { + newBlock->AppendStatementsFromBlock(*tmp); + } else { + newBlock->AddStatement(stmt); + } + break; case OP_while: newBlock->AppendStatementsFromBlock(*LowerWhileStmt(static_cast(*stmt))); break; diff --git a/src/mapleall/maple_me/include/me_cfg.h b/src/mapleall/maple_me/include/me_cfg.h index 69f306f7bc..206c6d29b0 100644 --- a/src/mapleall/maple_me/include/me_cfg.h +++ b/src/mapleall/maple_me/include/me_cfg.h @@ -297,7 +297,6 @@ class MeCFG : public AnalysisResult { void UpdateBranchTarget(BB &currBB, BB &oldTarget, BB &newTarget, MeFunction &func); private: - void ReplaceSwitchContainsOneCaseBranchWithBrtrue(BB &bb, MapleVector &exitBlocks); void AddCatchHandlerForTryBB(BB &bb, MapleVector &exitBlocks); std::string ConstructFileNameToDump(const std::string &prefix) const; void DumpToFileInStrs(std::ofstream &cfgFile) const; diff --git a/src/mapleall/maple_me/src/me_cfg.cpp b/src/mapleall/maple_me/src/me_cfg.cpp index 611a4a8e0d..c9b845ec00 100644 --- a/src/mapleall/maple_me/src/me_cfg.cpp +++ b/src/mapleall/maple_me/src/me_cfg.cpp @@ -25,18 +25,6 @@ namespace { constexpr int kFuncNameLenLimit = 80; -static bool CaseValOfSwitchIsSuccInt(const maple::CaseVector &switchTable) { - ASSERT(!switchTable.empty(), "switch table is empty"); - size_t caseNum = switchTable.size(); - int val = switchTable[0].first; - for (size_t id = 1; id < caseNum; id++) { - val++; - if (val != switchTable[id].first) { - return false; - } - } - return true; -} } namespace maple { @@ -88,82 +76,6 @@ bool MeCFG::IfReplaceWithAssertNonNull(const BB &bb) const { return true; } -void MeCFG::ReplaceSwitchContainsOneCaseBranchWithBrtrue(maple::BB &bb, MapleVector &exitBlocks) { - StmtNode &lastStmt = bb.GetStmtNodes().back(); - ASSERT(lastStmt.GetOpCode() == OP_switch, "runtime check error"); - auto &switchStmt = static_cast(lastStmt); - auto &swithcTable = switchStmt.GetSwitchTable(); - if (!CaseValOfSwitchIsSuccInt(swithcTable)) { - return; - } - LabelIdx defaultLabelIdx = switchStmt.GetDefaultLabel(); - int32 minCaseVal = swithcTable.front().first; - int32 maxCaseVal = swithcTable.back().first; - - // lfopreemit can't handle the optimized cfg for swith with one case with range value branch - if ((minCaseVal != maxCaseVal) && func.GetLfoFunc()) { - return; - } - auto &mirBuilder = func.GetMIRModule().GetMIRBuilder(); - auto *baseNode = switchStmt.Opnd(0); - auto *minCaseNode = mirBuilder->CreateIntConst(minCaseVal, PTY_i32); - auto *maxCaseNode = mirBuilder->CreateIntConst(maxCaseVal, PTY_i32); - if (minCaseVal == maxCaseVal) { - // brtrue != minCaseVal, @default_label - // caseBB - // @default_label - // defaultBB; - auto *neNode = mirBuilder->CreateExprCompare(OP_ne, GetTypeFromTyIdx(TyIdx(PTY_u1)), - GetTypeFromTyIdx(TyIdx(PTY_i32)), baseNode, minCaseNode); - auto *condGoto = mirBuilder->CreateStmtCondGoto(neNode, OP_brtrue, defaultLabelIdx); - bb.ReplaceStmt(&switchStmt, condGoto); - bb.SetKind(kBBCondGoto); - // reset bb succ - bb.RemoveAllSucc(); - BB *defaultBB = GetLabelBBAt(defaultLabelIdx); - ASSERT(defaultBB != nullptr, "null ptr check"); - BB *caseBB = GetLabelBBAt(switchStmt.GetSwitchTable().front().second); - ASSERT(caseBB != nullptr, "null ptr check"); - bb.AddSucc(*caseBB); // add fallthru - bb.AddSucc(*defaultBB); // add target - return; - } else { - // lfopreemit can't handle the optimized cfg for swith with one case branch - auto *ltNode = mirBuilder->CreateExprCompare(OP_lt, GetTypeFromTyIdx(TyIdx(PTY_u1)), - GetTypeFromTyIdx(TyIdx(PTY_i32)), baseNode, minCaseNode); - auto *condGoto = mirBuilder->CreateStmtCondGoto(ltNode, OP_brtrue, defaultLabelIdx); - bb.ReplaceStmt(&switchStmt, condGoto); - bb.SetKind(kBBCondGoto); - - auto *newBB = NewBasicBlock(); - auto *gtNode = mirBuilder->CreateExprCompare(OP_gt, GetTypeFromTyIdx(TyIdx(PTY_u1)), - GetTypeFromTyIdx(TyIdx(PTY_i32)), baseNode, maxCaseNode); - condGoto = mirBuilder->CreateStmtCondGoto(gtNode, OP_brtrue, defaultLabelIdx); - newBB->GetStmtNodes().push_back(condGoto); - newBB->SetKind(kBBCondGoto); - - BB *defaultBB = GetLabelBBAt(defaultLabelIdx); - ASSERT(defaultBB != nullptr, "null ptr check"); - while (!bb.GetSucc().empty()) { - bb.RemoveSucc(*bb.GetSucc(0)); - } - bb.AddSucc(*newBB); - bb.AddSucc(*defaultBB); - - BB *caseBB = GetLabelBBAt(switchStmt.GetSwitchTable().front().second); - ASSERT(caseBB != nullptr, "null ptr check"); - newBB->AddSucc(*caseBB); - newBB->AddSucc(*defaultBB); - - if (bb.GetAttributes(kBBAttrIsTry)) { - newBB->SetAttributes(kBBAttrIsTry); - SetBBTryNodeMap(*newBB, *GetBBTryNodeMap().at(&bb)); - AddCatchHandlerForTryBB(bb, exitBlocks); - AddCatchHandlerForTryBB(*newBB, exitBlocks); - } - } -} - void MeCFG::AddCatchHandlerForTryBB(BB &bb, MapleVector &exitBlocks) { if (!bb.GetAttributes(kBBAttrIsTry)) { return; @@ -208,7 +120,6 @@ void MeCFG::AddCatchHandlerForTryBB(BB &bb, MapleVector &exitBlocks) { void MeCFG::BuildMirCFG() { MapleVector entryBlocks(GetAlloc().Adapter()); MapleVector exitBlocks(GetAlloc().Adapter()); - std::vector switchBBsWithOneCaseBranch; auto eIt = valid_end(); for (auto bIt = valid_begin(); bIt != eIt; ++bIt) { if (bIt == common_entry() || bIt == common_exit()) { @@ -271,11 +182,9 @@ void MeCFG::BuildMirCFG() { LabelIdx lblIdx = switchStmt.GetDefaultLabel(); BB *mirBB = GetLabelBBAt(lblIdx); bb->AddSucc(*mirBB); - std::set caseLabels; for (size_t j = 0; j < switchStmt.GetSwitchTable().size(); ++j) { lblIdx = switchStmt.GetCasePair(j).second; BB *meBB = GetLabelBBAt(lblIdx); - (void)caseLabels.insert(lblIdx); // Avoid duplicate succs. if (!meBB->IsSuccBB(*bb)) { bb->AddSucc(*meBB); @@ -284,8 +193,6 @@ void MeCFG::BuildMirCFG() { if (bb->GetSucc().size() == 1) { bb->RemoveLastStmt(); bb->SetKind(kBBFallthru); - } else if (caseLabels.size() == 1) { - switchBBsWithOneCaseBranch.push_back(bb); } break; } @@ -314,9 +221,6 @@ void MeCFG::BuildMirCFG() { } } - for (BB *switchBB : switchBBsWithOneCaseBranch) { - ReplaceSwitchContainsOneCaseBranchWithBrtrue(*switchBB, exitBlocks); - } // merge all blocks in entryBlocks for (BB *bb : entryBlocks) { GetCommonEntryBB()->AddEntry(*bb); -- Gitee