From bc48ce87ef608730616c3250b18c013b1b4a39fc Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Fri, 28 Jun 2013 17:57:55 +0000 Subject: [NVPTX] Add support for vectorized function return values git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@185173 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/NVPTX/NVPTXISelLowering.cpp | 164 +++++++++++++++++++++++++++------ 1 file changed, 137 insertions(+), 27 deletions(-) (limited to 'lib/Target/NVPTX/NVPTXISelLowering.cpp') diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp index 42bfab148c..9679b05ab7 100644 --- a/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1338,37 +1338,147 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( } -SDValue NVPTXTargetLowering::LowerReturn( - SDValue Chain, CallingConv::ID CallConv, bool isVarArg, - const SmallVectorImpl &Outs, - const SmallVectorImpl &OutVals, SDLoc dl, - SelectionDAG &DAG) const { +SDValue +NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, + bool isVarArg, + const SmallVectorImpl &Outs, + const SmallVectorImpl &OutVals, + SDLoc dl, SelectionDAG &DAG) const { + MachineFunction &MF = DAG.getMachineFunction(); + const Function *F = MF.getFunction(); + const Type *RetTy = F->getReturnType(); + const DataLayout *TD = getDataLayout(); bool isABI = (nvptxSubtarget.getSmVersion() >= 20); + assert(isABI && "Non-ABI compilation is not supported"); + if (!isABI) + return Chain; - unsigned sizesofar = 0; - unsigned idx = 0; - for (unsigned i = 0, e = Outs.size(); i != e; ++i) { - SDValue theVal = OutVals[i]; - EVT theValType = theVal.getValueType(); - unsigned numElems = 1; - if (theValType.isVector()) - numElems = theValType.getVectorNumElements(); - for (unsigned j = 0, je = numElems; j != je; ++j) { - SDValue tmpval = theVal; - if (theValType.isVector()) - tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, - theValType.getVectorElementType(), tmpval, - DAG.getIntPtrConstant(j)); - Chain = DAG.getNode( - isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl, - MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32), - tmpval); + if (const VectorType *VTy = dyn_cast(RetTy)) { + // If we have a vector type, the OutVals array will be the scalarized + // components and we have combine them into 1 or more vector stores. + unsigned NumElts = VTy->getNumElements(); + assert(NumElts == Outs.size() && "Bad scalarization of return value"); + + // V1 store + if (NumElts == 1) { + SDValue StoreVal = OutVals[0]; + // We only have one element, so just directly store it + if (StoreVal.getValueType().getSizeInBits() < 8) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal); + Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain, + DAG.getConstant(0, MVT::i32), StoreVal); + } else if (NumElts == 2) { + // V2 store + SDValue StoreVal0 = OutVals[0]; + SDValue StoreVal1 = OutVals[1]; + + if (StoreVal0.getValueType().getSizeInBits() < 8) { + StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0); + StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1); + } + + Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain, + DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1); + } else { + // V4 stores + // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the + // vector will be expanded to a power of 2 elements, so we know we can + // always round up to the next multiple of 4 when creating the vector + // stores. + // e.g. 4 elem => 1 st.v4 + // 6 elem => 2 st.v4 + // 8 elem => 2 st.v4 + // 11 elem => 3 st.v4 + + unsigned VecSize = 4; + if (OutVals[0].getValueType().getSizeInBits() == 64) + VecSize = 2; + + unsigned Offset = 0; + + EVT VecVT = + EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize); + unsigned PerStoreOffset = + TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); + + bool Extend = false; + if (OutVals[0].getValueType().getSizeInBits() < 8) + Extend = true; + + for (unsigned i = 0; i < NumElts; i += VecSize) { + // Get values + SDValue StoreVal; + SmallVector Ops; + Ops.push_back(Chain); + Ops.push_back(DAG.getConstant(Offset, MVT::i32)); + unsigned Opc = NVPTXISD::StoreRetvalV2; + EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType(); + + StoreVal = OutVals[i]; + if (Extend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal); + Ops.push_back(StoreVal); + + if (i + 1 < NumElts) { + StoreVal = OutVals[i + 1]; + if (Extend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal); + } else { + StoreVal = DAG.getUNDEF(ExtendedVT); + } + Ops.push_back(StoreVal); + + if (VecSize == 4) { + Opc = NVPTXISD::StoreRetvalV4; + if (i + 2 < NumElts) { + StoreVal = OutVals[i + 2]; + if (Extend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal); + } else { + StoreVal = DAG.getUNDEF(ExtendedVT); + } + Ops.push_back(StoreVal); + + if (i + 3 < NumElts) { + StoreVal = OutVals[i + 3]; + if (Extend) + StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal); + } else { + StoreVal = DAG.getUNDEF(ExtendedVT); + } + Ops.push_back(StoreVal); + } + + Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size()); + Offset += PerStoreOffset; + } + } + } else { + unsigned sizesofar = 0; + for (unsigned i = 0, e = Outs.size(); i != e; ++i) { + SDValue theVal = OutVals[i]; + EVT theValType = theVal.getValueType(); + unsigned numElems = 1; if (theValType.isVector()) - sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8; - else - sizesofar += theValType.getStoreSizeInBits() / 8; - ++idx; + numElems = theValType.getVectorNumElements(); + for (unsigned j = 0, je = numElems; j != je; ++j) { + SDValue tmpval = theVal; + if (theValType.isVector()) + tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + theValType.getVectorElementType(), tmpval, + DAG.getIntPtrConstant(j)); + EVT theStoreType = tmpval.getValueType(); + if (theStoreType.getSizeInBits() < 8) + tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval); + Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain, + DAG.getConstant(sizesofar, MVT::i32), tmpval); + if (theValType.isVector()) + sizesofar += + theValType.getVectorElementType().getStoreSizeInBits() / 8; + else + sizesofar += theValType.getStoreSizeInBits() / 8; + } } } -- cgit v1.2.3