From 3f222c96edb798453d3c06789aa34445b89d3f2a Mon Sep 17 00:00:00 2001 From: William Chen Date: Fri, 19 Feb 2021 15:50:48 -0800 Subject: [PATCH 1/2] Support for call return of large struct. --- .../maple_be/include/cg/aarch64/aarch64_abi.h | 3 +- .../include/cg/aarch64/aarch64_cgfunc.h | 2 +- .../include/cg/aarch64/aarch64_memlayout.h | 1 + src/mapleall/maple_be/src/be/lower.cpp | 49 ++++++++++++------- .../maple_be/src/cg/aarch64/aarch64_abi.cpp | 33 +++++++++++-- .../maple_be/src/cg/aarch64/aarch64_args.cpp | 13 ++++- .../src/cg/aarch64/aarch64_cgfunc.cpp | 45 +++++++++++++++-- .../src/cg/aarch64/aarch64_memlayout.cpp | 38 ++++++++++---- 8 files changed, 146 insertions(+), 38 deletions(-) diff --git a/src/mapleall/maple_be/include/cg/aarch64/aarch64_abi.h b/src/mapleall/maple_be/include/cg/aarch64/aarch64_abi.h index 0ad43dfb86..727d850471 100644 --- a/src/mapleall/maple_be/include/cg/aarch64/aarch64_abi.h +++ b/src/mapleall/maple_be/include/cg/aarch64/aarch64_abi.h @@ -83,7 +83,8 @@ class ParmLocator { ~ParmLocator() = default; // Return size of aggregate structure copy on stack. - int32 LocateNextParm(MIRType &mirType, PLocInfo &pLoc); + int32 LocateNextParm(MIRType &mirType, PLocInfo &pLoc, bool isFirst = false); + void InitPLocInfo(PLocInfo &pLoc); private: BECommon &beCommon; diff --git a/src/mapleall/maple_be/include/cg/aarch64/aarch64_cgfunc.h b/src/mapleall/maple_be/include/cg/aarch64/aarch64_cgfunc.h index 7a1b03be6d..d348c97278 100644 --- a/src/mapleall/maple_be/include/cg/aarch64/aarch64_cgfunc.h +++ b/src/mapleall/maple_be/include/cg/aarch64/aarch64_cgfunc.h @@ -578,7 +578,7 @@ class AArch64CGFunc : public CGFunc { AArch64ListOperand &srcOpnds); void SelectParmListForAggregate(BaseNode &argExpr, AArch64ListOperand &srcOpnds, ParmLocator &parmLocator, int32 &structCopyOffset); - + uint32 SelectParmListGetStructReturnSize(StmtNode &naryNode); void SelectParmList(StmtNode &naryNode, AArch64ListOperand &srcOpnds, bool isCallNative = false); Operand *SelectClearStackCallParam(const AddrofNode &expr, int64 &offsetValue); void SelectClearStackCallParmList(const StmtNode &naryNode, AArch64ListOperand &srcOpnds, diff --git a/src/mapleall/maple_be/include/cg/aarch64/aarch64_memlayout.h b/src/mapleall/maple_be/include/cg/aarch64/aarch64_memlayout.h index 7058782c48..c84f89a7d7 100644 --- a/src/mapleall/maple_be/include/cg/aarch64/aarch64_memlayout.h +++ b/src/mapleall/maple_be/include/cg/aarch64/aarch64_memlayout.h @@ -195,6 +195,7 @@ class AArch64MemLayout : public MemLayout { MemSegment segGrSaveArea = MemSegment(kMsGrSaveArea); MemSegment segVrSaveArea = MemSegment(kMsVrSaveArea); int32 fixStackSize = 0; + void SetSizeAlignForTypeIdx(uint32 typeIdx, uint32 &size, uint32 &align); void SetSegmentSize(AArch64SymbolAlloc &symbolAlloc, MemSegment &segment, uint32 typeIdx); void LayoutVarargParams(); void LayoutFormalParams(); diff --git a/src/mapleall/maple_be/src/be/lower.cpp b/src/mapleall/maple_be/src/be/lower.cpp index df2fc77381..a29b3e5f3a 100644 --- a/src/mapleall/maple_be/src/be/lower.cpp +++ b/src/mapleall/maple_be/src/be/lower.cpp @@ -512,7 +512,12 @@ BlockNode *CGLowerer::LowerReturnStruct(NaryStmtNode &retNode) { MIRSymbol *retSt = curFunc->GetFormal(0); MIRPtrType *retTy = static_cast(retSt->GetType()); IassignNode *iassign = mirModule.CurFuncCodeMemPool()->New(); - iassign->SetTyIdx(retTy->GetTypeIndex()); + if (beCommon.GetTypeSize(retTy->GetPointedTyIdx().GetIdx()) > k16ByteSize || !opnd0 || opnd0->GetPrimType() != PTY_agg) { + iassign->SetTyIdx(retTy->GetTypeIndex()); + } else { + /* struct goes into register. */ + iassign->SetTyIdx(retTy->GetPointedTyIdx()); + } iassign->SetFieldID(0); iassign->SetRHS(opnd0); if (retSt->IsPreg()) { @@ -825,26 +830,33 @@ BlockNode *CGLowerer::GenBlockNode(StmtNode &newCall, const CallReturnVector &p2 } else { sym = GetCurrentFunc()->GetSymbolTabItem(stIdx.Idx()); } + bool sizeIs0 = false; if (sym) { retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(sym->GetTyIdx()); + if (beCommon.GetTypeSize(retType->GetTypeIndex().GetIdx()) == 0) { + sizeIs0 = true; + } } - RegFieldPair regFieldPair = p2nRets[0].second; - if (!regFieldPair.IsReg()) { - uint16 fieldID = static_cast(regFieldPair.GetFieldID()); - DassignNode *dn = SaveReturnValueInLocal(stIdx, fieldID); - CHECK_FATAL(dn->GetFieldID() == 0, "make sure dn's fieldID return 0"); - LowerDassign(*dn, *blk); - CHECK_FATAL(&newCall == blk->GetLast() || newCall.GetNext() == blk->GetLast(), ""); - dStmt = (&newCall == blk->GetLast()) ? nullptr : blk->GetLast(); - CHECK_FATAL(newCall.GetNext() == dStmt, "make sure newCall's next equal dStmt"); - } else { - PregIdx pregIdx = static_cast(regFieldPair.GetPregIdx()); - MIRPreg *mirPreg = GetCurrentFunc()->GetPregTab()->PregFromPregIdx(pregIdx); - RegreadNode *regNode = mirModule.GetMIRBuilder()->CreateExprRegread(mirPreg->GetPrimType(), -kSregRetval0); - RegassignNode *regAssign = - mirModule.GetMIRBuilder()->CreateStmtRegassign(mirPreg->GetPrimType(), regFieldPair.GetPregIdx(), regNode); - blk->AddStatement(regAssign); - dStmt = regAssign; + if (sizeIs0 == false) { + RegFieldPair regFieldPair = p2nRets[0].second; + if (!regFieldPair.IsReg()) { + uint16 fieldID = static_cast(regFieldPair.GetFieldID()); + DassignNode *dn = SaveReturnValueInLocal(stIdx, fieldID); + CHECK_FATAL(dn->GetFieldID() == 0, "make sure dn's fieldID return 0"); + LowerDassign(*dn, *blk); + CHECK_FATAL(&newCall == blk->GetLast() || newCall.GetNext() == blk->GetLast(), ""); + dStmt = (&newCall == blk->GetLast()) ? nullptr : blk->GetLast(); + CHECK_FATAL(newCall.GetNext() == dStmt, "make sure newCall's next equal dStmt"); + } else { + PregIdx pregIdx = static_cast(regFieldPair.GetPregIdx()); + MIRPreg *mirPreg = GetCurrentFunc()->GetPregTab()->PregFromPregIdx(pregIdx); + RegreadNode *regNode = mirModule.GetMIRBuilder()->CreateExprRegread(mirPreg->GetPrimType(), -kSregRetval0); + RegassignNode *regAssign = + mirModule.GetMIRBuilder()->CreateStmtRegassign(mirPreg->GetPrimType(), regFieldPair.GetPregIdx(), + regNode); + blk->AddStatement(regAssign); + dStmt = regAssign; + } } } blk->ResetBlock(); @@ -912,6 +924,7 @@ BlockNode *CGLowerer::LowerCallAssignedStmt(StmtNode &stmt) { auto &origCall = static_cast(stmt); newCall = GenIcallNode(funcCalled, origCall); p2nRets = &origCall.GetReturnVec(); + static_cast(newCall)->SetReturnVec(*p2nRets); break; } default: diff --git a/src/mapleall/maple_be/src/cg/aarch64/aarch64_abi.cpp b/src/mapleall/maple_be/src/cg/aarch64/aarch64_abi.cpp index 20b3fc9709..1f1ab63a96 100644 --- a/src/mapleall/maple_be/src/cg/aarch64/aarch64_abi.cpp +++ b/src/mapleall/maple_be/src/cg/aarch64/aarch64_abi.cpp @@ -361,6 +361,12 @@ bool IsSpillRegInRA(AArch64reg regNO, bool has3RegOpnd) { } } /* namespace AArch64Abi */ +void ParmLocator::InitPLocInfo(PLocInfo &pLoc) { + pLoc.reg0 = kRinvalid; + pLoc.reg1 = kRinvalid; + pLoc.memOffset = nextStackArgAdress; +} + /* * Refer to ARM IHI 0055C_beta: Procedure Call Standard for * the ARM 64-bit Architecture. $5.4.2 @@ -375,12 +381,31 @@ bool IsSpillRegInRA(AArch64reg regNO, bool has3RegOpnd) { * starting from the beginning, one call per parameter in sequence; it returns * the information on how each parameter is passed in pLoc */ -int32 ParmLocator::LocateNextParm(MIRType &mirType, PLocInfo &pLoc) { +int32 ParmLocator::LocateNextParm(MIRType &mirType, PLocInfo &pLoc, bool isFirst) { + InitPLocInfo(pLoc); + + if (isFirst) { + MIRFunction *func = const_cast(beCommon.GetMIRModule().CurFunction()); + if (beCommon.HasFuncReturnType(*func)) { + uint32 size = beCommon.GetTypeSize(beCommon.GetFuncReturnType(*func)); + if (size == 0) { + /* For return struct size 0 there is no return value. */ + return 0; + } else if (size > k16ByteSize) { + /* For return struct size > 16 bytes the pointer returns in x8. */ + pLoc.reg0 = R8; + return kSizeOfPtr; + } + /* For return struct size less or equal to 16 bytes, the values + * are returned in register pairs. Do nothing here. + */ + } + } uint64 typeSize = beCommon.GetTypeSize(mirType.GetTypeIndex()); + if (typeSize == 0) { + return 0; + } int32 typeAlign = beCommon.GetTypeAlign(mirType.GetTypeIndex()); - pLoc.reg0 = kRinvalid; - pLoc.reg1 = kRinvalid; - pLoc.memOffset = nextStackArgAdress; /* * Rule C.12 states that we do round nextStackArgAdress up before we use its value * according to the alignment requirement of the argument being processed. diff --git a/src/mapleall/maple_be/src/cg/aarch64/aarch64_args.cpp b/src/mapleall/maple_be/src/cg/aarch64/aarch64_args.cpp index 36a94eb648..5089f94b0b 100644 --- a/src/mapleall/maple_be/src/cg/aarch64/aarch64_args.cpp +++ b/src/mapleall/maple_be/src/cg/aarch64/aarch64_args.cpp @@ -31,7 +31,7 @@ void AArch64MoveRegArgs::CollectRegisterArgs(std::map &argsL PLocInfo ploc; for (uint32 i = 0; i < aarchCGFunc->GetFunction().GetFormalCount(); ++i) { MIRType *ty = aarchCGFunc->GetFunction().GetNthParamType(i); - parmlocator.LocateNextParm(*ty, ploc); + parmlocator.LocateNextParm(*ty, ploc, i == 0); if (ploc.reg0 == kRinvalid) { continue; } @@ -268,8 +268,17 @@ void AArch64MoveRegArgs::MoveVRegisterArgs() { PLocInfo ploc; for (uint32 i = 0; i < aarchCGFunc->GetFunction().GetFormalCount(); ++i) { + if (i == 0) { + MIRFunction *func = const_cast(aarchCGFunc->GetBecommon().GetMIRModule().CurFunction()); + if (aarchCGFunc->GetBecommon().HasFuncReturnType(*func)) { + TyIdx idx = aarchCGFunc->GetBecommon().GetFuncReturnType(*func); + if (aarchCGFunc->GetBecommon().GetTypeSize(idx) <= 16) { + continue; + } + } + } MIRType *ty = aarchCGFunc->GetFunction().GetNthParamType(i); - parmlocator.LocateNextParm(*ty, ploc); + parmlocator.LocateNextParm(*ty, ploc, i == 0); MIRSymbol *sym = aarchCGFunc->GetFunction().GetFormal(i); /* load locarefvar formals to store in the reflocals. */ diff --git a/src/mapleall/maple_be/src/cg/aarch64/aarch64_cgfunc.cpp b/src/mapleall/maple_be/src/cg/aarch64/aarch64_cgfunc.cpp index 906b97872a..c541a958fd 100644 --- a/src/mapleall/maple_be/src/cg/aarch64/aarch64_cgfunc.cpp +++ b/src/mapleall/maple_be/src/cg/aarch64/aarch64_cgfunc.cpp @@ -4800,6 +4800,9 @@ void AArch64CGFunc::SelectParmListIreadLargeAggregate(const IreadNode &iread, MI void AArch64CGFunc::CreateCallStructParamPassByStack(int32 symSize, MIRSymbol *sym, RegOperand *addrOpnd, int32 baseOffset) { + if (symSize == 0) { + return; + } MemOperand *ldMopnd, *stMopnd; int numRegNeeded = (symSize <= k8ByteSize) ? kOneRegister : kTwoRegister; for (int j = 0; j < numRegNeeded; j++) { @@ -4887,7 +4890,16 @@ AArch64RegOperand *AArch64CGFunc::CreateCallStructParamCopyToStack(uint32 numMem MemOperand *ldMopnd, *stMopnd; for (int j = 0; j < numMemOp; j++) { if (sym) { - ldMopnd = &GetOrCreateMemOpnd(*sym, (j * static_cast(kSizeOfPtr)), k64BitSize); + if (sym->GetStorageClass() == kScFormal) { + MemOperand &baseLoadOpnd = GetOrCreateMemOpnd(*sym, 0, k64BitSize); + RegOperand *vreg1 = &CreateVirtualRegisterOperand(NewVReg(kRegTyInt, k8ByteSize)); + Insn &ldInsn = GetCG()->BuildInstruction(PickLdInsn(k64BitSize, PTY_i64), *vreg1, baseLoadOpnd); + GetCurBB()->AppendInsn(ldInsn); + ldMopnd = &GetOrCreateMemOpnd(AArch64MemOperand::kAddrModeBOi, k64BitSize, vreg1, nullptr, + &GetOrCreateOfstOpnd(j * kSizeOfPtr, k32BitSize), nullptr); + } else { + ldMopnd = &GetOrCreateMemOpnd(*sym, (j * static_cast(kSizeOfPtr)), k64BitSize); + } } else { ldMopnd = &GetOrCreateMemOpnd(AArch64MemOperand::kAddrModeBOi, k64BitSize, addrOpd, nullptr, &GetOrCreateOfstOpnd(static_cast(j) * kSizeOfPtr, k32BitSize), nullptr); @@ -4980,6 +4992,28 @@ void AArch64CGFunc::SelectParmListForAggregate(BaseNode &argExpr, AArch64ListOpe } } +uint32 AArch64CGFunc::SelectParmListGetStructReturnSize(StmtNode &naryNode) { + if (naryNode.GetOpCode() == OP_call) { + CallNode &callNode = static_cast(naryNode); + MIRFunction *callFunc = GlobalTables::GetFunctionTable().GetFunctionFromPuidx(callNode.GetPUIdx()); + TyIdx retIdx = callFunc->GetReturnTyIdx(); + if ((GetBecommon().GetTypeSize(retIdx.GetIdx()) == 0) && GetBecommon().HasFuncReturnType(*callFunc)) { + return GetBecommon().GetTypeSize(GetBecommon().GetFuncReturnType(*callFunc)); + } + } else if (naryNode.GetOpCode() == OP_icall) { + IcallNode &icallNode = static_cast(naryNode); + CallReturnVector *p2nrets = &icallNode.GetReturnVec(); + if (p2nrets->size() == 1) { + StIdx stIdx = (*p2nrets)[0].first; + MIRSymbol *sym = GetBecommon().GetMIRModule().CurFunction()->GetSymTab()->GetSymbolFromStIdx(stIdx.Idx()); + if (sym) { + return GetBecommon().GetTypeSize(sym->GetTyIdx().GetIdx()); + } + } + } + return 0; +} + /* SelectParmList generates an instrunction for each of the parameters to load the parameter value into the corresponding register. @@ -4995,7 +5029,7 @@ void AArch64CGFunc::SelectParmList(StmtNode &naryNode, AArch64ListOperand &srcOp } int32 structCopyOffset = GetMaxParamStackSize() - GetStructCopySize(); - for (; i < naryNode.NumOpnds(); ++i) { + for (uint32 pnum = 0; i < naryNode.NumOpnds(); ++i, ++pnum) { MIRType *ty = nullptr; BaseNode *argExpr = naryNode.Opnd(i); PrimType primType = argExpr->GetPrimType(); @@ -5013,7 +5047,12 @@ void AArch64CGFunc::SelectParmList(StmtNode &naryNode, AArch64ListOperand &srcOp } expRegOpnd = static_cast(opnd); - parmLocator.LocateNextParm(*ty, ploc); + if ((pnum == 0) && (SelectParmListGetStructReturnSize(naryNode) > k16ByteSize)) { + parmLocator.InitPLocInfo(ploc); + ploc.reg0 = R8; + } else { + parmLocator.LocateNextParm(*ty, ploc); + } if (ploc.reg0 != kRinvalid) { /* load to the register. */ CHECK_FATAL(expRegOpnd != nullptr, "null ptr check"); AArch64RegOperand &parmRegOpnd = GetOrCreatePhysicalRegisterOperand(ploc.reg0, expRegOpnd->GetSize(), diff --git a/src/mapleall/maple_be/src/cg/aarch64/aarch64_memlayout.cpp b/src/mapleall/maple_be/src/cg/aarch64/aarch64_memlayout.cpp index ed1a646e11..a34f0aea2c 100644 --- a/src/mapleall/maple_be/src/cg/aarch64/aarch64_memlayout.cpp +++ b/src/mapleall/maple_be/src/cg/aarch64/aarch64_memlayout.cpp @@ -51,7 +51,7 @@ uint32 AArch64MemLayout::ComputeStackSpaceRequirementForCall(StmtNode &stmt, in } aggCopySize = 0; - for (; i < stmt.NumOpnds(); ++i) { + for (uint32 anum = 0; i < stmt.NumOpnds(); ++i, ++ anum) { BaseNode *opnd = stmt.Opnd(i); MIRType *ty = nullptr; if (opnd->GetPrimType() != PTY_agg) { @@ -88,7 +88,7 @@ uint32 AArch64MemLayout::ComputeStackSpaceRequirementForCall(StmtNode &stmt, in } } PLocInfo ploc; - aggCopySize += parmLocator.LocateNextParm(*ty, ploc); + aggCopySize += parmLocator.LocateNextParm(*ty, ploc, anum == 0); if (ploc.reg0 != 0) { continue; /* passed in register, so no effect on actual area */ } @@ -97,10 +97,24 @@ uint32 AArch64MemLayout::ComputeStackSpaceRequirementForCall(StmtNode &stmt, in return sizeOfArgsToStkPass; } +void AArch64MemLayout::SetSizeAlignForTypeIdx(uint32 typeIdx, uint32 &size, uint32 &align) { + if (be.GetTypeSize(typeIdx) > k16ByteSize) { + // size > 16 is passed on stack, the formal is just a pointer to the copy on stack. + align = kSizeOfPtr; + size = kSizeOfPtr; + } else { + align = be.GetTypeAlign(typeIdx); + size = be.GetTypeSize(typeIdx); + } +} + void AArch64MemLayout::SetSegmentSize(AArch64SymbolAlloc &symbolAlloc, MemSegment &segment, uint32 typeIdx) { - segment.SetSize(static_cast(RoundUp(static_cast(segment.GetSize()), be.GetTypeAlign(typeIdx)))); + uint32 size; + uint32 align; + SetSizeAlignForTypeIdx(typeIdx, size, align); + segment.SetSize(static_cast(RoundUp(static_cast(segment.GetSize()), align))); symbolAlloc.SetOffset(segment.GetSize()); - segment.SetSize(segment.GetSize() + static_cast(be.GetTypeSize(typeIdx))); + segment.SetSize(segment.GetSize() + static_cast(size)); segment.SetSize(static_cast(RoundUp(static_cast(segment.GetSize()), kSizeOfPtr))); } @@ -150,7 +164,7 @@ void AArch64MemLayout::LayoutFormalParams() { bool noStackPara = false; MIRType *ty = mirFunction->GetNthParamType(i); uint32 ptyIdx = ty->GetTypeIndex(); - parmLocator.LocateNextParm(*ty, ploc); + parmLocator.LocateNextParm(*ty, ploc, i == 0); uint32 stIndex = sym->GetStIndex(); AArch64SymbolAlloc *symLoc = memAllocator->GetMemPool()->New(); SetSymAllocInfo(stIndex, *symLoc); @@ -160,18 +174,24 @@ void AArch64MemLayout::LayoutFormalParams() { symLoc->SetMemSegment(segRefLocals); SetSegmentSize(*symLoc, segRefLocals, ptyIdx); } else if (!sym->IsPreg()) { + uint32 size; + uint32 align; + SetSizeAlignForTypeIdx(ptyIdx, size, align); symLoc->SetMemSegment(GetSegArgsRegPassed()); /* the type's alignment requirement may be smaller than a registser's byte size */ - segArgsRegPassed.SetSize(RoundUp(segArgsRegPassed.GetSize(), be.GetTypeAlign(ptyIdx))); + segArgsRegPassed.SetSize(RoundUp(segArgsRegPassed.GetSize(), align)); symLoc->SetOffset(segArgsRegPassed.GetSize()); - segArgsRegPassed.SetSize(segArgsRegPassed.GetSize() + be.GetTypeSize(ptyIdx)); + segArgsRegPassed.SetSize(segArgsRegPassed.GetSize() + size); } noStackPara = true; } else { /* stack */ + uint32 size; + uint32 align; + SetSizeAlignForTypeIdx(ptyIdx, size, align); symLoc->SetMemSegment(GetSegArgsStkPassed()); - segArgsStkPassed.SetSize(RoundUp(segArgsStkPassed.GetSize(), be.GetTypeAlign(ptyIdx))); + segArgsStkPassed.SetSize(RoundUp(segArgsStkPassed.GetSize(), align)); symLoc->SetOffset(segArgsStkPassed.GetSize()); - segArgsStkPassed.SetSize(segArgsStkPassed.GetSize() + be.GetTypeSize(ptyIdx)); + segArgsStkPassed.SetSize(segArgsStkPassed.GetSize() + size); /* We need it as dictated by the AArch64 ABI $5.4.2 C12 */ segArgsStkPassed.SetSize(RoundUp(segArgsStkPassed.GetSize(), kSizeOfPtr)); if (mirFunction->GetNthParamAttr(i).GetAttr(ATTR_localrefvar)) { -- Gitee From 4a5d42244f03db0365bf7b32fc6c2905bab90356 Mon Sep 17 00:00:00 2001 From: William Chen Date: Wed, 24 Feb 2021 20:23:42 -0800 Subject: [PATCH 2/2] fix merge conflict. --- src/mapleall/maple_be/src/be/lower.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/mapleall/maple_be/src/be/lower.cpp b/src/mapleall/maple_be/src/be/lower.cpp index 12aeacd223..01963172e4 100644 --- a/src/mapleall/maple_be/src/be/lower.cpp +++ b/src/mapleall/maple_be/src/be/lower.cpp @@ -830,12 +830,8 @@ BlockNode *CGLowerer::GenBlockNode(StmtNode &newCall, const CallReturnVector &p2 } else { sym = GetCurrentFunc()->GetSymbolTabItem(stIdx.Idx()); } -<<<<<< bool sizeIs0 = false; - if (sym) { -======= if (sym != nullptr) { ->>>>>>> master retType = GlobalTables::GetTypeTable().GetTypeFromTyIdx(sym->GetTyIdx()); if (beCommon.GetTypeSize(retType->GetTypeIndex().GetIdx()) == 0) { sizeIs0 = true; -- Gitee