From 2c91474140cf80baf8086ca9fd35f18a7b5e5b99 Mon Sep 17 00:00:00 2001 From: linma Date: Fri, 22 Oct 2021 11:28:17 -0700 Subject: [PATCH] lfo loop vectorization: support reduction variable in loop --- src/mapleall/maple_me/include/lfo_dep_test.h | 5 +- src/mapleall/maple_me/include/lfo_loop_vec.h | 11 +- src/mapleall/maple_me/src/lfo_dep_test.cpp | 1 + src/mapleall/maple_me/src/lfo_loop_vec.cpp | 193 +++++++++++++++++- .../maple_me/src/me_value_range_prop.cpp | 28 +++ 5 files changed, 225 insertions(+), 13 deletions(-) diff --git a/src/mapleall/maple_me/include/lfo_dep_test.h b/src/mapleall/maple_me/include/lfo_dep_test.h index dbf7e3e540..3df6f4ba37 100644 --- a/src/mapleall/maple_me/include/lfo_dep_test.h +++ b/src/mapleall/maple_me/include/lfo_dep_test.h @@ -79,6 +79,7 @@ class DoloopInfo { bool hasMayDef = false; // give up dep testing if true MapleVector outputDepTestList; // output dependence only MapleVector flowDepTestList; // include both true and anti dependences + MapleSet redVars; // reduction variables DoloopInfo(MapleAllocator *allc, LfoDepInfo *depinfo, DoloopNode *doloop, DoloopInfo *prnt) : alloc(allc), @@ -89,7 +90,8 @@ class DoloopInfo { lhsArrays(alloc->Adapter()), rhsArrays(alloc->Adapter()), outputDepTestList(alloc->Adapter()), - flowDepTestList(alloc->Adapter()) {} + flowDepTestList(alloc->Adapter()), + redVars(alloc->Adapter()) {} ~DoloopInfo() = default; bool IsLoopInvariant(MeExpr *x); bool OnlyInvariantScalars(MeExpr *x); @@ -102,6 +104,7 @@ class DoloopInfo { bool Parallelizable(); bool CheckReductionLoop(); ArrayAccessDesc* GetArrayAccessDesc(ArrayNode *node, bool isRHS); + bool IsReductionVar(StIdx stidx) { return (redVars.count(stidx) > 0); } }; class LfoDepInfo : public AnalysisResult { diff --git a/src/mapleall/maple_me/include/lfo_loop_vec.h b/src/mapleall/maple_me/include/lfo_loop_vec.h index d1d3072f43..3946eaa6e1 100644 --- a/src/mapleall/maple_me/include/lfo_loop_vec.h +++ b/src/mapleall/maple_me/include/lfo_loop_vec.h @@ -37,7 +37,9 @@ class LoopVecInfo { : vecStmtIDs(alloc.Adapter()), uniformNodes(alloc.Adapter()), uniformVecNodes(alloc.Adapter()), - constvalTypes(alloc.Adapter()) { + constvalTypes(alloc.Adapter()), + reductionVars(alloc.Adapter()), + redVecNodes(alloc.Adapter()) { largestTypeSize = 8; // i8 bit size currentRHSTypeSize = 0; } @@ -53,6 +55,8 @@ class LoopVecInfo { MapleMap uniformVecNodes; // new generated vector node // constval node need to adjust with new PrimType MapleMap constvalTypes; + MapleSet reductionVars; // reduction variables used in rhs->opnd(0) + MapleMap redVecNodes; // new generate vector node }; // tranform plan for current loop @@ -98,7 +102,7 @@ class LoopVectorization { void VectorizeDoLoop(DoloopNode *, LoopTransPlan*); void VectorizeNode(BaseNode *, LoopTransPlan *); MIRType *GenVecType(PrimType, uint8_t); - RegassignNode *GenDupScalarStmt(BaseNode *scalar, PrimType vecPrimType); + IntrinsicopNode *GenDupScalarExpr(BaseNode *scalar, PrimType vecPrimType); bool ExprVectorizable(DoloopInfo *doloopInfo, LoopVecInfo*, BaseNode *x); bool Vectorizable(DoloopInfo *doloopInfo, LoopVecInfo*, BlockNode *block); void widenDoloop(DoloopNode *doloop, LoopTransPlan *); @@ -109,6 +113,9 @@ class LoopVectorization { std::string PhaseName() const { return "lfoloopvec"; } bool CanConvert(uint32_t, uint32_t); bool CanAdjustRhsType(PrimType, ConstvalNode *); + bool IsReductionOp(Opcode op); + IntrinsicopNode *GenSumVecStmt(BaseNode *vecTemp, PrimType vecPrimType); + public: static uint32_t vectorizedLoop; private: diff --git a/src/mapleall/maple_me/src/lfo_dep_test.cpp b/src/mapleall/maple_me/src/lfo_dep_test.cpp index c0963c3071..ac06708ed8 100644 --- a/src/mapleall/maple_me/src/lfo_dep_test.cpp +++ b/src/mapleall/maple_me/src/lfo_dep_test.cpp @@ -538,6 +538,7 @@ bool DoloopInfo::CheckReductionLoop() { if (!OnlyInvariantScalars(depInfo->preEmit->GetMexpr(otherOpnd))) { return false; } + redVars.insert(stIdx); stmt = stmt->GetNext(); } return true; diff --git a/src/mapleall/maple_me/src/lfo_loop_vec.cpp b/src/mapleall/maple_me/src/lfo_loop_vec.cpp index e5f9f5794d..9589676bb6 100644 --- a/src/mapleall/maple_me/src/lfo_loop_vec.cpp +++ b/src/mapleall/maple_me/src/lfo_loop_vec.cpp @@ -257,8 +257,101 @@ MIRType* LoopVectorization::GenVecType(PrimType sPrimType, uint8 lanes) { return vecType; } +// generate instrinsic node to sum all elements of a vector type +IntrinsicopNode *LoopVectorization::GenSumVecStmt(BaseNode *vecTemp, PrimType vecPrimType) { + MIRIntrinsicID intrnID = INTRN_vector_sum_v4i32; + MIRType *retType = nullptr; + switch (vecPrimType) { + case PTY_v4i32: { + intrnID = INTRN_vector_sum_v4i32; + retType = GlobalTables::GetTypeTable().GetInt32(); + break; + } + case PTY_v2i32: { + intrnID = INTRN_vector_sum_v2i32; + retType = GlobalTables::GetTypeTable().GetInt32(); + break; + } + case PTY_v4u32: { + intrnID = INTRN_vector_sum_v4u32; + retType = GlobalTables::GetTypeTable().GetUInt32(); + break; + } + case PTY_v2u32: { + intrnID = INTRN_vector_sum_v2u32; + retType = GlobalTables::GetTypeTable().GetUInt32(); + break; + } + case PTY_v2i64: { + intrnID = INTRN_vector_sum_v2i64; + retType = GlobalTables::GetTypeTable().GetInt64(); + break; + } + case PTY_v2u64: { + intrnID = INTRN_vector_sum_v2u64; + retType = GlobalTables::GetTypeTable().GetUInt64(); + break; + } + case PTY_v8i16: { + intrnID = INTRN_vector_sum_v8i16; + retType = GlobalTables::GetTypeTable().GetInt16(); + break; + } + case PTY_v8u16: { + intrnID = INTRN_vector_sum_v8u16; + retType = GlobalTables::GetTypeTable().GetUInt16(); + break; + } + case PTY_v4i16: { + intrnID = INTRN_vector_sum_v4i16; + retType = GlobalTables::GetTypeTable().GetInt16(); + break; + } + case PTY_v4u16: { + intrnID = INTRN_vector_sum_v4u16; + retType = GlobalTables::GetTypeTable().GetUInt16(); + break; + } + case PTY_v16i8: { + intrnID = INTRN_vector_sum_v16i8; + retType = GlobalTables::GetTypeTable().GetInt8(); + break; + } + case PTY_v16u8: { + intrnID = INTRN_vector_sum_v16u8; + retType = GlobalTables::GetTypeTable().GetUInt8(); + break; + } + case PTY_v8i8: { + intrnID = INTRN_vector_sum_v8i8; + retType = GlobalTables::GetTypeTable().GetInt8(); + break; + } + case PTY_v8u8: { + intrnID = INTRN_vector_sum_v8u8; + retType = GlobalTables::GetTypeTable().GetUInt8(); + break; + } + default: + ASSERT(0, "NIY"); + } + // generate instrinsic op + IntrinsicopNode *rhs = codeMP->New(*codeMPAlloc, OP_intrinsicop, retType->GetPrimType()); + rhs->SetIntrinsic(intrnID); + rhs->SetNumOpnds(1); + rhs->GetNopnd().push_back(vecTemp); + rhs->SetTyIdx(retType->GetTypeIndex()); + return rhs; +} + +// check opcode is reduction, +/-/*///min/max +bool LoopVectorization::IsReductionOp(Opcode op) { + if (op == OP_add || op == OP_sub) return true; + return false; +} + // generate instrinsic node to copy scalar to vector type -RegassignNode *LoopVectorization::GenDupScalarStmt(BaseNode *scalar, PrimType vecPrimType) { +IntrinsicopNode *LoopVectorization::GenDupScalarExpr(BaseNode *scalar, PrimType vecPrimType) { MIRIntrinsicID intrnID = INTRN_vector_from_scalar_v4i32; MIRType *vecType = nullptr; switch (vecPrimType) { @@ -342,9 +435,7 @@ RegassignNode *LoopVectorization::GenDupScalarStmt(BaseNode *scalar, PrimType ve rhs->SetNumOpnds(1); rhs->GetNopnd().push_back(scalar); rhs->SetTyIdx(vecType->GetTypeIndex()); - PregIdx regIdx = mirFunc->GetPregTab()->CreatePreg(vecPrimType); - RegassignNode *stmtNode = codeMP->New(vecPrimType, regIdx, rhs); - return stmtNode; + return rhs; } // iterate tree node to wide scalar type to vector type @@ -404,9 +495,33 @@ void LoopVectorization::VectorizeNode(BaseNode *node, LoopTransPlan *tp) { break; } // scalar related: widen type directly or unroll instructions - case OP_dassign: - ASSERT(0, "NIY"); + case OP_dassign: { + // now only support reduction scalar + // sum = sum +/- vectorizable_expr + // => + // vec t1 = dup_scalar(sum); + // doloop { + // t1 = t1 + vectorized_node; + // } + // sum = sum +/- intrinsic_op vec_sum(t1) + // sum = intrinsic_op vec_sum(vectorized_node); + DassignNode *dassign = static_cast(node); + StIdx lhsStIdx = dassign->GetStIdx(); + MIRSymbol *lhsSym = mirFunc->GetLocalOrGlobalSymbol(lhsStIdx); + MIRType &lhsType = GetTypeFromTyIdx(lhsSym->GetTyIdx()); + MIRType *vecType = GenVecType(lhsType.GetPrimType(), tp->vecFactor); + ASSERT(vecType != nullptr, "vector type should not be null"); + BaseNode *vecNode = dassign->GetRHS()->Opnd(1); + VectorizeNode(vecNode, tp); + BaseNode *redNewVar = tp->vecInfo->redVecNodes[lhsStIdx]; + ASSERT((redNewVar != nullptr && redNewVar->GetOpCode() == OP_dread), "nullptr check"); + StIdx vecStIdx = (static_cast(redNewVar))->GetStIdx(); + dassign->SetStIdx(vecStIdx); + dassign->SetPrimType(vecType->GetPrimType()); + dassign->GetRHS()->SetOpnd(redNewVar, 0); + dassign->GetRHS()->SetPrimType(vecType->GetPrimType()); break; + } // vector type support in opcode +, -, *, &, |, <<, >>, compares, ~, ! case OP_add: case OP_sub: @@ -503,14 +618,48 @@ void LoopVectorization::VectorizeDoLoop(DoloopNode *doloop, LoopTransPlan *tp) { ptype = tp->vecInfo->constvalTypes[node]; } MIRType *vecType = GenVecType(ptype, tp->vecFactor); - RegassignNode *dupScalarStmt = GenDupScalarStmt(node, vecType->GetPrimType()); + IntrinsicopNode *dupscalar = GenDupScalarExpr(node, vecType->GetPrimType()); + PregIdx regIdx = mirFunc->GetPregTab()->CreatePreg(vecType->GetPrimType()); + RegassignNode *dupScalarStmt = codeMP->New(vecType->GetPrimType(), regIdx, dupscalar); pblock->InsertBefore(doloop, dupScalarStmt); - RegreadNode *regreadNode = codeMP->New(vecType->GetPrimType(), dupScalarStmt->GetRegIdx()); + RegreadNode *regreadNode = codeMP->New(vecType->GetPrimType(), regIdx); tp->vecInfo->uniformVecNodes[node] = regreadNode; } } } - + // step 2.2 reduction variable + if (!tp->vecInfo->reductionVars.empty()) { + LfoPart* lfopart = (*lfoStmtParts)[doloop->GetStmtID()]; + BaseNode *parent = lfopart->GetParent(); + ASSERT(parent && (parent->GetOpCode() == OP_block), "nullptr check"); + BlockNode *pblock = static_cast(parent); + auto it = tp->vecInfo->reductionVars.begin(); + int count = 0; + for (; it != tp->vecInfo->reductionVars.end(); it++) { + StIdx stIdx = *it; + MIRSymbol *redSym = mirFunc->GetLocalOrGlobalSymbol(stIdx); + PrimType ptype = GetTypeFromTyIdx(redSym->GetTyIdx()).GetPrimType(); + MIRType *vecType = GenVecType(ptype, tp->vecFactor); + // before loop: vec = dup_scalar(reduction) + AddrofNode *redScalarNode = codeMP->New(OP_dread, ptype, stIdx, 0); + IntrinsicopNode *dupscalar = GenDupScalarExpr(redScalarNode, vecType->GetPrimType()); + // new stidx + std::string redName("red"); + redName.append(std::to_string(doloop->GetStmtID())); + redName.append("_"); + redName.append(std::to_string(count++)); + GStrIdx strIdx = GlobalTables::GetStrTable().GetOrCreateStrIdxFromName(redName); + MIRSymbol *st = mirFunc->GetModule()->GetMIRBuilder()->CreateSymbol(vecType->GetTypeIndex(), strIdx, kStVar, kScAuto, mirFunc, kScopeLocal); + DassignNode *redInitStmt = codeMP->New(ptype, dupscalar, st->GetStIdx(), 0); + pblock->InsertBefore(doloop, redInitStmt); + AddrofNode *dreadNode = codeMP->New(OP_dread, vecType->GetPrimType(), st->GetStIdx(), 0); + tp->vecInfo->redVecNodes[stIdx] = dreadNode; + // after loop: reduction = vec_sum(vec) + IntrinsicopNode *intrnNode = GenSumVecStmt(dreadNode, vecType->GetPrimType()); + DassignNode *redDassign = codeMP->New(ptype, intrnNode, stIdx, 0); + pblock->InsertAfter(doloop, redDassign); + } + } // step 3: widen vectorizable stmt in doloop BlockNode *loopbody = doloop->GetDoBody(); for (auto &stmt : loopbody->GetStmtNodes()) { @@ -809,6 +958,29 @@ bool LoopVectorization::Vectorizable(DoloopInfo *doloopInfo, LoopVecInfo* vecInf } break; } + case OP_dassign: { + DassignNode *dassign = static_cast(stmt); + StIdx lhsStIdx = dassign->GetStIdx(); + MIRSymbol *lhsSym = mirFunc->GetLocalOrGlobalSymbol(lhsStIdx); + MIRType &lhsType = GetTypeFromTyIdx(lhsSym->GetTyIdx()); + BaseNode *rhs = dassign->GetRHS(); + if (IsReductionOp(rhs->GetOpCode()) && doloopInfo->IsReductionVar(lhsStIdx)) { + BaseNode *opnd0 = rhs->Opnd(0); + BaseNode *opnd1 = rhs->Opnd(1); + CHECK_FATAL((opnd0->GetOpCode() == OP_dread) && ((static_cast(opnd0))->GetStIdx() == lhsStIdx), + "opnd0 is reduction variable"); + if (ExprVectorizable(doloopInfo, vecInfo, opnd1)) { + vecInfo->vecStmtIDs.insert((stmt)->GetStmtID()); + vecInfo->UpdateWidestTypeSize(GetPrimTypeSize(lhsType.GetPrimType()) *8); + vecInfo->reductionVars.insert((static_cast(opnd0))->GetStIdx()); + } else { + return false; // only handle reduction scalar + } + } else { + return false; + } + break; + } default: return false; } stmt = stmt->GetNext(); @@ -820,7 +992,8 @@ void LoopVectorization::Perform() { // step 2: collect information, legality check and generate transform plan MapleMap::iterator mapit = depInfo->doloopInfoMap.begin(); for (; mapit != depInfo->doloopInfoMap.end(); mapit++) { - if (!mapit->second->children.empty() || !mapit->second->Parallelizable()) { + if (!mapit->second->children.empty() || + ((!mapit->second->Parallelizable()) && (!mapit->second->CheckReductionLoop()))) { continue; } LoopVecInfo *vecInfo = localMP->New(localAlloc); diff --git a/src/mapleall/maple_me/src/me_value_range_prop.cpp b/src/mapleall/maple_me/src/me_value_range_prop.cpp index 47754453e5..e22253e695 100644 --- a/src/mapleall/maple_me/src/me_value_range_prop.cpp +++ b/src/mapleall/maple_me/src/me_value_range_prop.cpp @@ -845,25 +845,38 @@ bool IsNeededPrimType(PrimType prim) { int64 GetMinNumber(PrimType primType) { switch (primType) { case PTY_i8: + case PTY_v8i8: + case PTY_v16i8: return std::numeric_limits::min(); break; case PTY_i16: + case PTY_v4i16: + case PTY_v8i16: return std::numeric_limits::min(); break; case PTY_i32: + case PTY_v2i32: + case PTY_v4i32: return std::numeric_limits::min(); break; case PTY_i64: + case PTY_v2i64: return std::numeric_limits::min(); break; case PTY_u8: + case PTY_v8u8: + case PTY_v16u8: return std::numeric_limits::min(); break; case PTY_u16: + case PTY_v4u16: + case PTY_v8u16: return std::numeric_limits::min(); break; case PTY_u32: case PTY_a32: + case PTY_v4u32: + case PTY_v2u32: return std::numeric_limits::min(); break; case PTY_ref: @@ -877,6 +890,7 @@ int64 GetMinNumber(PrimType primType) { break; case PTY_u64: case PTY_a64: + case PTY_v2u64: return std::numeric_limits::min(); break; case PTY_u1: @@ -891,25 +905,38 @@ int64 GetMinNumber(PrimType primType) { int64 GetMaxNumber(PrimType primType) { switch (primType) { case PTY_i8: + case PTY_v8i8: + case PTY_v16i8: return std::numeric_limits::max(); break; case PTY_i16: + case PTY_v4i16: + case PTY_v8i16: return std::numeric_limits::max(); break; case PTY_i32: + case PTY_v2i32: + case PTY_v4i32: return std::numeric_limits::max(); break; case PTY_i64: + case PTY_v2i64: return std::numeric_limits::max(); break; case PTY_u8: + case PTY_v8u8: + case PTY_v16u8: return std::numeric_limits::max(); break; case PTY_u16: + case PTY_v4u16: + case PTY_v8u16: return std::numeric_limits::max(); break; case PTY_u32: case PTY_a32: + case PTY_v4u32: + case PTY_v2u32: return std::numeric_limits::max(); break; case PTY_ref: @@ -923,6 +950,7 @@ int64 GetMaxNumber(PrimType primType) { break; case PTY_u64: case PTY_a64: + case PTY_v2u64: return std::numeric_limits::max(); break; case PTY_u1: -- Gitee