From e887c3c1d8cbc6773f31012b8dbe60cbb53bf878 Mon Sep 17 00:00:00 2001
From: Fred Chow
Date: Thu, 9 Dec 2021 15:10:40 -0800
Subject: [PATCH] Implemented LFO's loop unrolling, only handling loops with constant trip count

Summary of the change:
* DoFullUnroll() and DoUnroll() return the replacement block; Process()
  substitutes it for the doloop in the parent block and increments
  LfoUnrollOneLoop::countOfLoopsUnrolled, which the driver prints next to
  the vectorization counters.
* DoUnroll(times, tripCount) emits the tripCount % times leftover
  iterations first (nothing, one inlined copy of the body, or a shortened
  clone of the doloop), followed by a clone of the doloop whose body holds
  `times` copies of the original body and whose increment is multiplied
  by `times`.  Copy i addresses the induction variable as
  IV + stepAmount * i.
* The remainder bound and the unrolled loop's start are offset by
  remainderTripCount * stepAmount so that they are expressed in IV units,
  like the per-copy offsets.
  NOTE(review): the remainder loop's condition is built with OP_lt, which
  presumes an ascending loop; confirm that Process() rejects other loops.
* The unroll factor is capped by unrollMax (8).
* IfStmtNode::NumOpnds() derives the operand count from elsePart, so the
  parser no longer bumps numOpnds when an else block is parsed.

---
 .../maple_driver/src/driver_runner.cpp     |  1 +
 src/mapleall/maple_ir/include/mir_nodes.h  |  7 +-
 src/mapleall/maple_ir/src/mir_parser.cpp   |  3 -
 src/mapleall/maple_me/include/lfo_unroll.h |  5 +-
 src/mapleall/maple_me/src/lfo_unroll.cpp   | 77 ++++++++++++++++---
 5 files changed, 74 insertions(+), 19 deletions(-)

diff --git a/src/mapleall/maple_driver/src/driver_runner.cpp b/src/mapleall/maple_driver/src/driver_runner.cpp
index 308fca5a3a..2683bf9c41 100644
--- a/src/mapleall/maple_driver/src/driver_runner.cpp
+++ b/src/mapleall/maple_driver/src/driver_runner.cpp
@@ -219,6 +219,7 @@ void DriverRunner::RunNewPM(const std::string &output, const std::string &vtable
   {
     LogInfo::MapleLogger() << "\n" << LoopVectorization::vectorizedLoop << " loop vectorized\n";
     LogInfo::MapleLogger() << "\n" << SeqVectorize::seqVecStores << " sequencestores vectorized\n";
+    LogInfo::MapleLogger() << "\n" << LfoUnrollOneLoop::countOfLoopsUnrolled << " loops unrolled\n";
   }
 }
 
diff --git a/src/mapleall/maple_ir/include/mir_nodes.h b/src/mapleall/maple_ir/include/mir_nodes.h
index 08d7f47b79..0a0be4f646 100755
--- a/src/mapleall/maple_ir/include/mir_nodes.h
+++ b/src/mapleall/maple_ir/include/mir_nodes.h
@@ -2450,7 +2450,7 @@ class BlockNode : public StmtNode {
 class IfStmtNode : public UnaryStmtNode {
  public:
   IfStmtNode() : UnaryStmtNode(OP_if) {
-    numOpnds = kOperandNumBinary;
+    numOpnds = kOperandNumTernary;
   }
 
   virtual ~IfStmtNode() = default;
@@ -2500,7 +2500,10 @@ class IfStmtNode : public UnaryStmtNode {
   }
 
   size_t NumOpnds() const override {
-    return numOpnds;
+    if (elsePart == nullptr) {
+      return kOperandNumBinary;
+    }
+    return kOperandNumTernary;
   }
 
  private:
diff --git a/src/mapleall/maple_ir/src/mir_parser.cpp b/src/mapleall/maple_ir/src/mir_parser.cpp
index 7f74d8c1d1..386a12083d 100755
--- a/src/mapleall/maple_ir/src/mir_parser.cpp
+++ b/src/mapleall/maple_ir/src/mir_parser.cpp
@@ -400,9 +400,6 @@ bool MIRParser::ParseStmtIf(StmtNodePtr &stmt) {
       return false;
     }
     ifStmt->SetElsePart(elseBlock);
-    if (elseBlock != nullptr) {
-      ifStmt->SetNumOpnds(ifStmt->GetNumOpnds() + 1);
-    }
   }
   stmt = ifStmt;
   return true;
diff --git a/src/mapleall/maple_me/include/lfo_unroll.h b/src/mapleall/maple_me/include/lfo_unroll.h
index fcf753a69e..a7f43d23c7 100644
--- a/src/mapleall/maple_me/include/lfo_unroll.h
+++ b/src/mapleall/maple_me/include/lfo_unroll.h
@@ -34,8 +34,8 @@ class LfoUnrollOneLoop {
   BaseNode *CloneIVNode();
   bool IsIVNode(BaseNode *x);
   void ReplaceIV(BaseNode *x, BaseNode *repNode);
-  void DoFullUnroll(size_t tripCount);
-  void DoUnroll(size_t times);
+  BlockNode *DoFullUnroll(size_t tripCount);
+  BlockNode *DoUnroll(size_t times, size_t tripCount);
   void Process();
 
   LfoFunction *lfoFunc;
@@ -47,6 +47,7 @@ class LfoUnrollOneLoop {
   MIRBuilder *mirBuilder;
   int64 stepAmount = 0;
   PrimType ivPrimType = PTY_unknown;
+  static uint32 countOfLoopsUnrolled;
 };
 
 MAPLE_FUNC_PHASE_DECLARE(MELfoUnroll, MeFunction)
diff --git a/src/mapleall/maple_me/src/lfo_unroll.cpp b/src/mapleall/maple_me/src/lfo_unroll.cpp
index 2489ea210b..085cf42bb0 100644
--- a/src/mapleall/maple_me/src/lfo_unroll.cpp
+++ b/src/mapleall/maple_me/src/lfo_unroll.cpp
@@ -16,8 +16,10 @@
 #include "me_loop_analysis.h"
 
 namespace maple {
+uint32 LfoUnrollOneLoop::countOfLoopsUnrolled = 0;
 
 constexpr size_t unrolledSizeLimit = 12;  // unrolled loop body size to be < this value
+constexpr size_t unrollMax = 8;  // times to unroll never more than this
 
 BaseNode *LfoUnrollOneLoop::CloneIVNode() {
   if (doloop->IsPreg()) {
@@ -61,7 +63,7 @@ void LfoUnrollOneLoop::ReplaceIV(BaseNode *x, BaseNode *repNode) {
   }
 }
 
-void LfoUnrollOneLoop::DoFullUnroll(size_t tripCount) {
+BlockNode *LfoUnrollOneLoop::DoFullUnroll(size_t tripCount) {
   BlockNode *unrolledBlk = doloop->GetDoBody()->CloneTreeWithSrcPosition(*mirModule);
   ReplaceIV(unrolledBlk, doloop->GetStartExpr());
   BlockNode *nextIterBlk = nullptr;
@@ -76,16 +78,52 @@ void LfoUnrollOneLoop::DoFullUnroll(size_t tripCount) {
     unrolledBlk->InsertBlockAfter(*nextIterBlk, unrolledBlk->GetLast());
     tripCount--;
   }
-
-  // replace doloop by the statements in unrolledBlk
-  LfoPart *lfopart = (*preEmit->GetLfoStmtMap())[doloop->GetStmtID()];
-  BaseNode *parent = lfopart->GetParent();
-  ASSERT(parent && (parent->GetOpCode() == OP_block), "LfoUnroll: parent of doloop is not OP_block");
-  BlockNode *pblock = static_cast<BlockNode *>(parent);
-  pblock->ReplaceStmtWithBlock(*doloop, *unrolledBlk);
+  return unrolledBlk;
 }
 
-void LfoUnrollOneLoop::DoUnroll(size_t times) {
+// unroll by `times` for a constant tripCount; returns the leftover iterations followed by the unrolled loop
+BlockNode *LfoUnrollOneLoop::DoUnroll(size_t times, size_t tripCount) {
+  BlockNode *unrolledBlk = nullptr;
+  // the tripCount % times leftover iterations run first, before the unrolled loop
+  size_t remainderTripCount = tripCount % times;
+  if (remainderTripCount == 0) {
+    unrolledBlk = codeMP->New<BlockNode>();
+  } else if (remainderTripCount == 1) {
+    unrolledBlk = doloop->GetDoBody()->CloneTreeWithSrcPosition(*mirModule);
+    ReplaceIV(unrolledBlk, doloop->GetStartExpr());
+  } else {
+    DoloopNode *remDoloop = doloop->CloneTree(*preEmit->GetCodeMPAlloc());
+    // remDoloop stops after remainderTripCount iterations: IV < start + remainderTripCount * stepAmount
+    BaseNode *terminationRHS = codeMP->New<BinaryNode>(OP_add, ivPrimType,
+        doloop->GetStartExpr()->CloneTree(*preEmit->GetCodeMPAlloc()),
+        mirBuilder->CreateIntConst(static_cast<int64>(remainderTripCount) * stepAmount, ivPrimType));
+    remDoloop->SetContExpr(codeMP->New<CompareNode>(OP_lt, PTY_i32, ivPrimType, CloneIVNode(), terminationRHS));
+    unrolledBlk = codeMP->New<BlockNode>();
+    unrolledBlk->AddStatement(remDoloop);
+  }
+  // form the unrolled loop: the clone holds copy 0; append copies 1..times-1 with the IV shifted by i steps
+  DoloopNode *unrolledDoloop = doloop->CloneTree(*preEmit->GetCodeMPAlloc());
+  uint32 i = 1;
+  BlockNode *nextIterBlk = nullptr;
+  do {
+    nextIterBlk = doloop->GetDoBody()->CloneTreeWithSrcPosition(*mirModule);
+    BaseNode *adjExpr = mirBuilder->CreateIntConst(stepAmount * i, ivPrimType);
+    BaseNode *repExpr = codeMP->New<BinaryNode>(OP_add, ivPrimType, CloneIVNode(), adjExpr);
+    ReplaceIV(nextIterBlk, repExpr);
+    unrolledDoloop->GetDoBody()->InsertBlockAfter(*nextIterBlk, unrolledDoloop->GetDoBody()->GetLast());
+    i++;
+  } while (i != times);
+  if (remainderTripCount != 0) {  // start past the iterations already emitted as the remainder
+    BaseNode *newStartExpr = codeMP->New<BinaryNode>(OP_add, ivPrimType, unrolledDoloop->GetStartExpr(),
+        mirBuilder->CreateIntConst(static_cast<int64>(remainderTripCount) * stepAmount, ivPrimType));
+    unrolledDoloop->SetStartExpr(newStartExpr);
+  }
+  // each trip of the unrolled loop advances the IV by `times` original increments
+  ConstvalNode *stepNode = static_cast<ConstvalNode *>(unrolledDoloop->GetIncrExpr());
+  int64 origIncr = static_cast<MIRIntConst *>(stepNode->GetConstVal())->GetValue();
+  unrolledDoloop->SetIncrExpr(mirBuilder->CreateIntConst(origIncr*times, ivPrimType));
+  unrolledBlk->AddStatement(unrolledDoloop);
+  return unrolledBlk;
 }
 
 static size_t CountBlockStmts(BlockNode *blk) {
@@ -166,19 +204,30 @@ void LfoUnrollOneLoop::Process() {
   }
   size_t unrollTimes = 1;
   size_t unrolledStmtCount = stmtCount;
-  while (unrolledStmtCount < unrolledSizeLimit) {
+  while (unrolledStmtCount < unrolledSizeLimit && unrollTimes < unrollMax) {
     unrollTimes++;
     unrolledStmtCount += stmtCount;
   }
   bool fullUnroll = tripCount < (unrollTimes * 2);
+  BlockNode *unrolledBlk = nullptr;
   if (fullUnroll) {
-    DoFullUnroll(tripCount);
+    unrolledBlk = DoFullUnroll(tripCount);
   } else {
     if (unrollTimes == 1) {
       return;
     }
-    DoUnroll(unrollTimes);
+    unrolledBlk = DoUnroll(unrollTimes, tripCount);
   }
+
+  // replace doloop by the statements in unrolledBlk
+  LfoPart *lfopart = (*preEmit->GetLfoStmtMap())[doloop->GetStmtID()];
+  BaseNode *parent = lfopart->GetParent();
+  ASSERT(parent && (parent->GetOpCode() == OP_block), "LfoUnroll: parent of doloop is not OP_block");
+  BlockNode *pblock = static_cast<BlockNode *>(parent);
+  pblock->ReplaceStmtWithBlock(*doloop, *unrolledBlk);
+
+  // update counter
+  countOfLoopsUnrolled++;
 };
 
 bool MELfoUnroll::PhaseRun(MeFunction &f) {
@@ -187,6 +236,7 @@ bool MELfoUnroll::PhaseRun(MeFunction &f) {
   LfoDepInfo *lfoDepInfo = GET_ANALYSIS(MELfoDepTest, f);
   ASSERT(lfoDepInfo != nullptr, "lfo dep test phase has problem");
   LfoFunction *lfoFunc = f.GetLfoFunc();
+//uint32 savedCountOfLoopsUnrolled = LfoUnrollOneLoop::countOfLoopsUnrolled;
 
   MapleMap<DoloopNode *, DoloopInfo *>::iterator mapit = lfoDepInfo->doloopInfoMap.begin();
   for (; mapit != lfoDepInfo->doloopInfoMap.end(); mapit++) {
@@ -196,6 +246,9 @@ bool MELfoUnroll::PhaseRun(MeFunction &f) {
     LfoUnrollOneLoop unroll(lfoFunc, preEmit, mapit->second);
     unroll.Process();
   }
+//if (!MeOption::quiet && savedCountOfLoopsUnrolled != LfoUnrollOneLoop::countOfLoopsUnrolled) {
+//  f.GetMirFunc()->Dump();
+//}
   return false;
 }
 
-- 
Gitee