diff --git a/src/mapleall/maple_be/include/be/lower.h b/src/mapleall/maple_be/include/be/lower.h index 67ea31b73ec3237064592f741d95c3f3d0bdd3e1..6b3a54edad67726d120836ec42a1dc853c05392c 100644 --- a/src/mapleall/maple_be/include/be/lower.h +++ b/src/mapleall/maple_be/include/be/lower.h @@ -136,7 +136,7 @@ class CGLowerer { BaseNode *LowerArray(ArrayNode &array, const BaseNode &parent); DassignNode *SaveReturnValueInLocal(StIdx, uint16); - void LowerCallStmt(StmtNode&, StmtNode*&, BlockNode&); + void LowerCallStmt(StmtNode&, StmtNode*&, BlockNode&, MIRType *retTy = nullptr); BlockNode *LowerCallAssignedStmt(StmtNode &stmt); BaseNode *LowerRem(BaseNode &rem, BlockNode &block); @@ -156,7 +156,7 @@ class CGLowerer { virtual BlockNode *LowerReturn(NaryStmtNode &retNode); void LowerEntry(MIRFunction &func); - StmtNode *LowerCall(CallNode &call, StmtNode *&stmt, BlockNode &block); + StmtNode *LowerCall(CallNode &call, StmtNode *&stmt, BlockNode &block, MIRType *reTty = nullptr); void SplitCallArg(CallNode &callNode, BaseNode *newOpnd, size_t i, BlockNode &newBlk); void CleanupBranches(MIRFunction &func) const; diff --git a/src/mapleall/maple_be/src/be/lower.cpp b/src/mapleall/maple_be/src/be/lower.cpp index 36f85a18a5354372913dae64f4ecb47558fbb1e9..df2fc77381f17f6197047f28e9f61c34fbe8579c 100644 --- a/src/mapleall/maple_be/src/be/lower.cpp +++ b/src/mapleall/maple_be/src/be/lower.cpp @@ -727,7 +727,7 @@ BaseNode *CGLowerer::LowerRem(BaseNode &expr, BlockNode &blk) { } /* to lower call (including icall) and intrinsicall statements */ -void CGLowerer::LowerCallStmt(StmtNode &stmt, StmtNode *&nextStmt, BlockNode &newBlk) { +void CGLowerer::LowerCallStmt(StmtNode &stmt, StmtNode *&nextStmt, BlockNode &newBlk, MIRType *retTy) { StmtNode *newStmt = nullptr; if (stmt.GetOpCode() == OP_intrinsiccall) { auto &intrnNode = static_cast(stmt); @@ -743,7 +743,7 @@ void CGLowerer::LowerCallStmt(StmtNode &stmt, StmtNode *&nextStmt, BlockNode &ne } if ((newStmt->GetOpCode() == OP_call) || (newStmt->GetOpCode() == OP_icall)) { - newStmt = LowerCall(static_cast(*newStmt), nextStmt, newBlk); + newStmt = LowerCall(static_cast(*newStmt), nextStmt, newBlk, retTy); } newStmt->SetSrcPos(stmt.GetSrcPos()); newBlk.AddStatement(newStmt); @@ -816,8 +816,18 @@ BlockNode *CGLowerer::GenBlockNode(StmtNode &newCall, const CallReturnVector &p2 CHECK_FATAL(p2nRets.size() <= 1, "make sure p2nRets size <= 1"); /* Create DassignStmt to save kSregRetval0. */ StmtNode *dStmt = nullptr; + MIRType *retType = nullptr; if (p2nRets.size() == 1) { + MIRSymbol *sym; StIdx stIdx = p2nRets[0].first; + if (stIdx.IsGlobal()) { + sym = GlobalTables::GetGsymTable().GetSymbolFromStidx(stIdx.Idx()); + } else { + sym = GetCurrentFunc()->GetSymbolTabItem(stIdx.Idx()); + } + if (sym) { + retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(sym->GetTyIdx()); + } RegFieldPair regFieldPair = p2nRets[0].second; if (!regFieldPair.IsReg()) { uint16 fieldID = static_cast(regFieldPair.GetFieldID()); @@ -851,7 +861,7 @@ BlockNode *CGLowerer::GenBlockNode(StmtNode &newCall, const CallReturnVector &p2 blk->AddStatement(cmnt); } CHECK_FATAL(dStmt == nullptr || dStmt->GetNext() == nullptr, "make sure dStmt or dStmt's next is nullptr"); - LowerCallStmt(newCall, dStmt, *blk); + LowerCallStmt(newCall, dStmt, *blk, retType); if (dStmt != nullptr) { dStmt->SetSrcPos(newCall.GetSrcPos()); blk->AddStatement(dStmt); @@ -1117,7 +1127,7 @@ void CGLowerer::SplitCallArg(CallNode &callNode, BaseNode *newOpnd, size_t i, Bl } } -StmtNode *CGLowerer::LowerCall(CallNode &callNode, StmtNode *&nextStmt, BlockNode &newBlk) { +StmtNode *CGLowerer::LowerCall(CallNode &callNode, StmtNode *&nextStmt, BlockNode &newBlk, MIRType *retTy) { /* * nextStmt in-out * call $foo(constval u32 128) @@ -1166,10 +1176,6 @@ StmtNode *CGLowerer::LowerCall(CallNode &callNode, StmtNode *&nextStmt, BlockNod } } - if (callNode.GetOpCode() == OP_icall) { - return &callNode; - } - DassignNode *dassignNode = nullptr; if ((nextStmt != nullptr) && (nextStmt->GetOpCode() == OP_dassign)) { dassignNode = static_cast(nextStmt); @@ -1180,14 +1186,30 @@ StmtNode *CGLowerer::LowerCall(CallNode &callNode, StmtNode *&nextStmt, BlockNod return &callNode; } - MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callNode.GetPUIdx()); - MIRType *retType = calleeFunc->GetReturnType(); - if (calleeFunc->IsReturnStruct() && (retType->GetPrimType() == PTY_void)) { - MIRPtrType *pretType = static_cast((calleeFunc->GetNthParamType(0))); - CHECK_FATAL(pretType != nullptr, "nullptr is not expected"); - retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(pretType->GetPointedTyIdx()); - CHECK_FATAL((retType->GetKind() == kTypeStruct) || (retType->GetKind() == kTypeUnion), - "make sure retType is a struct type"); + if (retTy && beCommon.GetTypeSize(retTy->GetTypeIndex().GetIdx()) <= k16ByteSize) { + // return structure fitting in one or two regs. + return &callNode; + } + + MIRType *retType = nullptr; + if (callNode.op == OP_icall) { + if (retTy == nullptr) { + return &callNode; + } else { + retType = retTy; + } + } + + if (retType == nullptr) { + MIRFunction *calleeFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callNode.GetPUIdx()); + retType = calleeFunc->GetReturnType(); + if (calleeFunc->IsReturnStruct() && (retType->GetPrimType() == PTY_void)) { + MIRPtrType *pretType = static_cast((calleeFunc->GetNthParamType(0))); + CHECK_FATAL(pretType != nullptr, "nullptr is not expected"); + retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(pretType->GetPointedTyIdx()); + CHECK_FATAL((retType->GetKind() == kTypeStruct) || (retType->GetKind() == kTypeUnion), + "make sure retType is a struct type"); + } } /* if return type is not of a struct, return */ @@ -1213,9 +1235,18 @@ StmtNode *CGLowerer::LowerCall(CallNode &callNode, StmtNode *&nextStmt, BlockNod addrofNode->SetPrimType(LOWERED_PTR_TYPE); addrofNode->SetStIdx(dsgnSt->GetStIdx()); addrofNode->SetFieldID(0); - newNopnd.emplace_back(addrofNode); - for (auto *opnd : callNode.GetNopnd()) { - newNopnd.emplace_back(opnd); + if (callNode.op == OP_icall) { + auto ond = callNode.GetNopnd().begin(); + newNopnd.emplace_back(*ond); + newNopnd.emplace_back(addrofNode); + for (++ond; ond != callNode.GetNopnd().end(); ++ond) { + newNopnd.emplace_back(*ond); + } + } else { + newNopnd.emplace_back(addrofNode); + for (auto *opnd : callNode.GetNopnd()) { + newNopnd.emplace_back(opnd); + } } callNode.SetNOpnd(newNopnd); callNode.SetNumOpnds(static_cast(newNopnd.size()));