summaryrefslogtreecommitdiff
path: root/lib/Target/NVPTX/NVPTXISelLowering.cpp
diff options
context:
space:
mode:
author: Justin Holewinski <jholewinski@nvidia.com> 2013-06-28 17:57:55 +0000
committer: Justin Holewinski <jholewinski@nvidia.com> 2013-06-28 17:57:55 +0000
commitbc48ce87ef608730616c3250b18c013b1b4a39fc (patch)
tree3e8731197b9e8c01eaeda13edc0d395b5a5c49bf /lib/Target/NVPTX/NVPTXISelLowering.cpp
parentb67366514316bbb3cc3cb57f72f2d1439ec474bc (diff)
downloadllvm-bc48ce87ef608730616c3250b18c013b1b4a39fc.tar.gz
llvm-bc48ce87ef608730616c3250b18c013b1b4a39fc.tar.bz2
llvm-bc48ce87ef608730616c3250b18c013b1b4a39fc.tar.xz
[NVPTX] Add support for vectorized function return values
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@185173 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Target/NVPTX/NVPTXISelLowering.cpp')
-rw-r--r--lib/Target/NVPTX/NVPTXISelLowering.cpp164
1 files changed, 137 insertions, 27 deletions
diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 42bfab148c..9679b05ab7 100644
--- a/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -1338,37 +1338,147 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
}
-SDValue NVPTXTargetLowering::LowerReturn(
- SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
- const SmallVectorImpl<ISD::OutputArg> &Outs,
- const SmallVectorImpl<SDValue> &OutVals, SDLoc dl,
- SelectionDAG &DAG) const {
+SDValue
+NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
+ bool isVarArg,
+ const SmallVectorImpl<ISD::OutputArg> &Outs,
+ const SmallVectorImpl<SDValue> &OutVals,
+ SDLoc dl, SelectionDAG &DAG) const {
+ MachineFunction &MF = DAG.getMachineFunction();
+ const Function *F = MF.getFunction();
+ const Type *RetTy = F->getReturnType();
+ const DataLayout *TD = getDataLayout();
bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
+ assert(isABI && "Non-ABI compilation is not supported");
+ if (!isABI)
+ return Chain;
- unsigned sizesofar = 0;
- unsigned idx = 0;
- for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
- SDValue theVal = OutVals[i];
- EVT theValType = theVal.getValueType();
- unsigned numElems = 1;
- if (theValType.isVector())
- numElems = theValType.getVectorNumElements();
- for (unsigned j = 0, je = numElems; j != je; ++j) {
- SDValue tmpval = theVal;
- if (theValType.isVector())
- tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
- theValType.getVectorElementType(), tmpval,
- DAG.getIntPtrConstant(j));
- Chain = DAG.getNode(
- isABI ? NVPTXISD::StoreRetval : NVPTXISD::MoveToRetval, dl,
- MVT::Other, Chain, DAG.getConstant(isABI ? sizesofar : idx, MVT::i32),
- tmpval);
+ if (const VectorType *VTy = dyn_cast<const VectorType>(RetTy)) {
+ // If we have a vector type, the OutVals array will be the scalarized
+ // components and we have to combine them into 1 or more vector stores.
+ unsigned NumElts = VTy->getNumElements();
+ assert(NumElts == Outs.size() && "Bad scalarization of return value");
+
+ // V1 store
+ if (NumElts == 1) {
+ SDValue StoreVal = OutVals[0];
+ // We only have one element, so just directly store it
+ if (StoreVal.getValueType().getSizeInBits() < 8)
+ StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+ Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
+ DAG.getConstant(0, MVT::i32), StoreVal);
+ } else if (NumElts == 2) {
+ // V2 store
+ SDValue StoreVal0 = OutVals[0];
+ SDValue StoreVal1 = OutVals[1];
+
+ if (StoreVal0.getValueType().getSizeInBits() < 8) {
+ StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal0);
+ StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal1);
+ }
+
+ Chain = DAG.getNode(NVPTXISD::StoreRetvalV2, dl, MVT::Other, Chain,
+ DAG.getConstant(0, MVT::i32), StoreVal0, StoreVal1);
+ } else {
+ // V4 stores
+ // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
+ // vector will be expanded to a power of 2 elements, so we know we can
+ // always round up to the next multiple of 4 when creating the vector
+ // stores.
+ // e.g. 4 elem => 1 st.v4
+ // 6 elem => 2 st.v4
+ // 8 elem => 2 st.v4
+ // 11 elem => 3 st.v4
+
+ unsigned VecSize = 4;
+ if (OutVals[0].getValueType().getSizeInBits() == 64)
+ VecSize = 2;
+
+ unsigned Offset = 0;
+
+ EVT VecVT =
+ EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
+ unsigned PerStoreOffset =
+ TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
+
+ bool Extend = false;
+ if (OutVals[0].getValueType().getSizeInBits() < 8)
+ Extend = true;
+
+ for (unsigned i = 0; i < NumElts; i += VecSize) {
+ // Get values
+ SDValue StoreVal;
+ SmallVector<SDValue, 8> Ops;
+ Ops.push_back(Chain);
+ Ops.push_back(DAG.getConstant(Offset, MVT::i32));
+ unsigned Opc = NVPTXISD::StoreRetvalV2;
+ EVT ExtendedVT = (Extend) ? MVT::i8 : OutVals[0].getValueType();
+
+ StoreVal = OutVals[i];
+ if (Extend)
+ StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+ Ops.push_back(StoreVal);
+
+ if (i + 1 < NumElts) {
+ StoreVal = OutVals[i + 1];
+ if (Extend)
+ StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+ } else {
+ StoreVal = DAG.getUNDEF(ExtendedVT);
+ }
+ Ops.push_back(StoreVal);
+
+ if (VecSize == 4) {
+ Opc = NVPTXISD::StoreRetvalV4;
+ if (i + 2 < NumElts) {
+ StoreVal = OutVals[i + 2];
+ if (Extend)
+ StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+ } else {
+ StoreVal = DAG.getUNDEF(ExtendedVT);
+ }
+ Ops.push_back(StoreVal);
+
+ if (i + 3 < NumElts) {
+ StoreVal = OutVals[i + 3];
+ if (Extend)
+ StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, StoreVal);
+ } else {
+ StoreVal = DAG.getUNDEF(ExtendedVT);
+ }
+ Ops.push_back(StoreVal);
+ }
+
+ Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
+ Offset += PerStoreOffset;
+ }
+ }
+ } else {
+ unsigned sizesofar = 0;
+ for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+ SDValue theVal = OutVals[i];
+ EVT theValType = theVal.getValueType();
+ unsigned numElems = 1;
if (theValType.isVector())
- sizesofar += theValType.getVectorElementType().getStoreSizeInBits() / 8;
- else
- sizesofar += theValType.getStoreSizeInBits() / 8;
- ++idx;
+ numElems = theValType.getVectorNumElements();
+ for (unsigned j = 0, je = numElems; j != je; ++j) {
+ SDValue tmpval = theVal;
+ if (theValType.isVector())
+ tmpval = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
+ theValType.getVectorElementType(), tmpval,
+ DAG.getIntPtrConstant(j));
+ EVT theStoreType = tmpval.getValueType();
+ if (theStoreType.getSizeInBits() < 8)
+ tmpval = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i8, tmpval);
+ Chain = DAG.getNode(NVPTXISD::StoreRetval, dl, MVT::Other, Chain,
+ DAG.getConstant(sizesofar, MVT::i32), tmpval);
+ if (theValType.isVector())
+ sizesofar +=
+ theValType.getVectorElementType().getStoreSizeInBits() / 8;
+ else
+ sizesofar += theValType.getStoreSizeInBits() / 8;
+ }
}
}