diff options
author | Andrea Di Biagio <Andrea_DiBiagio@sn.scee.net> | 2014-06-19 10:29:41 +0000 |
---|---|---|
committer | Andrea Di Biagio <Andrea_DiBiagio@sn.scee.net> | 2014-06-19 10:29:41 +0000 |
commit | cfdf8052865b01e8b8d321640c3f51ff938cc3c4 (patch) | |
tree | 65d2737c67d5032deb9efb9b6380d916fae819a7 /lib/Target/X86 | |
parent | 83175090522ebd6513e45033c342200cd645f89c (diff) | |
download | llvm-cfdf8052865b01e8b8d321640c3f51ff938cc3c4.tar.gz llvm-cfdf8052865b01e8b8d321640c3f51ff938cc3c4.tar.bz2 llvm-cfdf8052865b01e8b8d321640c3f51ff938cc3c4.tar.xz |
[X86] Teach how to combine horizontal binop even in the presence of undefs.
Before this change, the backend was unable to fold a build_vector dag
node with UNDEF operands into a single horizontal add/sub.
This patch teaches how to combine a build_vector with UNDEF operands into a
horizontal add/sub when possible. The algorithm conservatively avoids to combine
a build_vector with only a single non-UNDEF operand.
Added test haddsub-undef.ll to verify that we correctly fold horizontal binop
even in the presence of UNDEFs.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211265 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Target/X86')
-rw-r--r-- | lib/Target/X86/X86ISelLowering.cpp | 155 |
1 files changed, 115 insertions, 40 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 851607eac9..a7b6e70781 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -6077,21 +6077,35 @@ X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const { /// This function only analyzes elements of \p N whose indices are /// in range [BaseIdx, LastIdx). static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode, + SelectionDAG &DAG, unsigned BaseIdx, unsigned LastIdx, SDValue &V0, SDValue &V1) { + EVT VT = N->getValueType(0); + assert(BaseIdx * 2 <= LastIdx && "Invalid Indices in input!"); - assert(N->getValueType(0).isVector() && - N->getValueType(0).getVectorNumElements() >= LastIdx && + assert(VT.isVector() && VT.getVectorNumElements() >= LastIdx && "Invalid Vector in input!"); bool IsCommutable = (Opcode == ISD::ADD || Opcode == ISD::FADD); bool CanFold = true; unsigned ExpectedVExtractIdx = BaseIdx; unsigned NumElts = LastIdx - BaseIdx; + V0 = DAG.getUNDEF(VT); + V1 = DAG.getUNDEF(VT); // Check if N implements a horizontal binop. for (unsigned i = 0, e = NumElts; i != e && CanFold; ++i) { SDValue Op = N->getOperand(i + BaseIdx); + + // Skip UNDEFs. + if (Op->getOpcode() == ISD::UNDEF) { + // Update the expected vector extract index. + if (i * 2 == NumElts) + ExpectedVExtractIdx = BaseIdx; + ExpectedVExtractIdx += 2; + continue; + } + CanFold = Op->getOpcode() == Opcode && Op->hasOneUse(); if (!CanFold) @@ -6112,12 +6126,15 @@ static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode, unsigned I0 = cast<ConstantSDNode>(Op0.getOperand(1))->getZExtValue(); unsigned I1 = cast<ConstantSDNode>(Op1.getOperand(1))->getZExtValue(); - - if (i == 0) - V0 = Op0.getOperand(0); - else if (i * 2 == NumElts) { - V1 = Op0.getOperand(0); - ExpectedVExtractIdx = BaseIdx; + + if (i * 2 < NumElts) { + if (V0.getOpcode() == ISD::UNDEF) + V0 = Op0.getOperand(0); + } else { + if (V1.getOpcode() == ISD::UNDEF) + V1 = Op0.getOperand(0); + if (i * 2 == NumElts) + ExpectedVExtractIdx = BaseIdx; } SDValue Expected = (i * 2 < NumElts) ? V0 : V1; @@ -6163,9 +6180,14 @@ static bool isHorizontalBinOp(const BuildVectorSDNode *N, unsigned Opcode, /// Example: /// HADD V0_LO, V1_LO /// HADD V0_HI, V1_HI +/// +/// If \p isUndefLO is set, then the algorithm propagates UNDEF to the lower +/// 128-bits of the result. If \p isUndefHI is set, then UNDEF is propagated to +/// the upper 128-bits of the result. static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1, SDLoc DL, SelectionDAG &DAG, - unsigned X86Opcode, bool Mode) { + unsigned X86Opcode, bool Mode, + bool isUndefLO, bool isUndefHI) { EVT VT = V0.getValueType(); assert(VT.is256BitVector() && VT == V1.getValueType() && "Invalid nodes in input!"); @@ -6177,13 +6199,24 @@ static SDValue ExpandHorizontalBinOp(const SDValue &V0, const SDValue &V1, SDValue V1_HI = Extract128BitVector(V1, NumElts/2, DAG, DL); EVT NewVT = V0_LO.getValueType(); - SDValue LO, HI; + SDValue LO = DAG.getUNDEF(NewVT); + SDValue HI = DAG.getUNDEF(NewVT); + if (Mode) { - LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V0_HI); - HI = DAG.getNode(X86Opcode, DL, NewVT, V1_LO, V1_HI); + // Don't emit a horizontal binop if the result is expected to be UNDEF. + if (!isUndefLO && V0->getOpcode() != ISD::UNDEF) + LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V0_HI); + if (!isUndefHI && V1->getOpcode() != ISD::UNDEF) + HI = DAG.getNode(X86Opcode, DL, NewVT, V1_LO, V1_HI); } else { - LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V1_LO); - HI = DAG.getNode(X86Opcode, DL, NewVT, V1_HI, V1_HI); + // Don't emit a horizontal binop if the result is expected to be UNDEF. + if (!isUndefLO && (V0_LO->getOpcode() != ISD::UNDEF || + V1_LO->getOpcode() != ISD::UNDEF)) + LO = DAG.getNode(X86Opcode, DL, NewVT, V0_LO, V1_LO); + + if (!isUndefHI && (V0_HI->getOpcode() != ISD::UNDEF || + V1_HI->getOpcode() != ISD::UNDEF)) + HI = DAG.getNode(X86Opcode, DL, NewVT, V0_HI, V1_HI); } return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, LO, HI); @@ -6198,19 +6231,37 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, SDValue InVec0, InVec1; // Try to match horizontal ADD/SUB. + unsigned NumUndefsLO = 0; + unsigned NumUndefsHI = 0; + unsigned Half = NumElts/2; + + // Count the number of UNDEF operands in the build_vector in input. + for (unsigned i = 0, e = Half; i != e; ++i) + if (BV->getOperand(i)->getOpcode() == ISD::UNDEF) + NumUndefsLO++; + + for (unsigned i = Half, e = NumElts; i != e; ++i) + if (BV->getOperand(i)->getOpcode() == ISD::UNDEF) + NumUndefsHI++; + + // Early exit if this is either a build_vector of all UNDEFs or all the + // operands but one are UNDEF. + if (NumUndefsLO + NumUndefsHI + 1 >= NumElts) + return SDValue(); + if ((VT == MVT::v4f32 || VT == MVT::v2f64) && Subtarget->hasSSE3()) { // Try to match an SSE3 float HADD/HSUB. - if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1)) + if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1)) return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1); - if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1)) + if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1)) return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1); } else if ((VT == MVT::v4i32 || VT == MVT::v8i16) && Subtarget->hasSSSE3()) { // Try to match an SSSE3 integer HADD/HSUB. - if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1)) + if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1)) return DAG.getNode(X86ISD::HADD, DL, VT, InVec0, InVec1); - if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1)) + if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1)) return DAG.getNode(X86ISD::HSUB, DL, VT, InVec0, InVec1); } @@ -6221,16 +6272,20 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, // Try to match an AVX horizontal add/sub of packed single/double // precision floating point values from 256-bit vectors. SDValue InVec2, InVec3; - if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts/2, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::FADD, NumElts/2, NumElts, InVec2, InVec3) && - InVec0.getNode() == InVec2.getNode() && - InVec1.getNode() == InVec3.getNode()) + if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, Half, InVec0, InVec1) && + isHorizontalBinOp(BV, ISD::FADD, DAG, Half, NumElts, InVec2, InVec3) && + ((InVec0.getOpcode() == ISD::UNDEF || + InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) && + ((InVec1.getOpcode() == ISD::UNDEF || + InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3)) return DAG.getNode(X86ISD::FHADD, DL, VT, InVec0, InVec1); - if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts/2, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::FSUB, NumElts/2, NumElts, InVec2, InVec3) && - InVec0.getNode() == InVec2.getNode() && - InVec1.getNode() == InVec3.getNode()) + if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, Half, InVec0, InVec1) && + isHorizontalBinOp(BV, ISD::FSUB, DAG, Half, NumElts, InVec2, InVec3) && + ((InVec0.getOpcode() == ISD::UNDEF || + InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) && + ((InVec1.getOpcode() == ISD::UNDEF || + InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3)) return DAG.getNode(X86ISD::FHSUB, DL, VT, InVec0, InVec1); } else if (VT == MVT::v8i32 || VT == MVT::v16i16) { // Try to match an AVX2 horizontal add/sub of signed integers. @@ -6238,15 +6293,19 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, unsigned X86Opcode; bool CanFold = true; - if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts/2, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::ADD, NumElts/2, NumElts, InVec2, InVec3) && - InVec0.getNode() == InVec2.getNode() && - InVec1.getNode() == InVec3.getNode()) + if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, Half, InVec0, InVec1) && + isHorizontalBinOp(BV, ISD::ADD, DAG, Half, NumElts, InVec2, InVec3) && + ((InVec0.getOpcode() == ISD::UNDEF || + InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) && + ((InVec1.getOpcode() == ISD::UNDEF || + InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3)) X86Opcode = X86ISD::HADD; - else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts/2, InVec0, InVec1) && - isHorizontalBinOp(BV, ISD::SUB, NumElts/2, NumElts, InVec2, InVec3) && - InVec0.getNode() == InVec2.getNode() && - InVec1.getNode() == InVec3.getNode()) + else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, Half, InVec0, InVec1) && + isHorizontalBinOp(BV, ISD::SUB, DAG, Half, NumElts, InVec2, InVec3) && + ((InVec0.getOpcode() == ISD::UNDEF || + InVec2.getOpcode() == ISD::UNDEF) || InVec0 == InVec2) && + ((InVec1.getOpcode() == ISD::UNDEF || + InVec3.getOpcode() == ISD::UNDEF) || InVec1 == InVec3)) X86Opcode = X86ISD::HSUB; else CanFold = false; @@ -6257,29 +6316,45 @@ static SDValue PerformBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG, if (Subtarget->hasAVX2()) return DAG.getNode(X86Opcode, DL, VT, InVec0, InVec1); + // Do not try to expand this build_vector into a pair of horizontal + // add/sub if we can emit a pair of scalar add/sub. + if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half) + return SDValue(); + // Convert this build_vector into a pair of horizontal binop followed by // a concat vector. - return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, false); + bool isUndefLO = NumUndefsLO == Half; + bool isUndefHI = NumUndefsHI == Half; + return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, false, + isUndefLO, isUndefHI); } } if ((VT == MVT::v8f32 || VT == MVT::v4f64 || VT == MVT::v8i32 || VT == MVT::v16i16) && Subtarget->hasAVX()) { unsigned X86Opcode; - if (isHorizontalBinOp(BV, ISD::ADD, 0, NumElts, InVec0, InVec1)) + if (isHorizontalBinOp(BV, ISD::ADD, DAG, 0, NumElts, InVec0, InVec1)) X86Opcode = X86ISD::HADD; - else if (isHorizontalBinOp(BV, ISD::SUB, 0, NumElts, InVec0, InVec1)) + else if (isHorizontalBinOp(BV, ISD::SUB, DAG, 0, NumElts, InVec0, InVec1)) X86Opcode = X86ISD::HSUB; - else if (isHorizontalBinOp(BV, ISD::FADD, 0, NumElts, InVec0, InVec1)) + else if (isHorizontalBinOp(BV, ISD::FADD, DAG, 0, NumElts, InVec0, InVec1)) X86Opcode = X86ISD::FHADD; - else if (isHorizontalBinOp(BV, ISD::FSUB, 0, NumElts, InVec0, InVec1)) + else if (isHorizontalBinOp(BV, ISD::FSUB, DAG, 0, NumElts, InVec0, InVec1)) X86Opcode = X86ISD::FHSUB; else return SDValue(); + // Don't try to expand this build_vector into a pair of horizontal add/sub + // if we can simply emit a pair of scalar add/sub. + if (NumUndefsLO + 1 == Half || NumUndefsHI + 1 == Half) + return SDValue(); + // Convert this build_vector into two horizontal add/sub followed by // a concat vector. - return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, true); + bool isUndefLO = NumUndefsLO == Half; + bool isUndefHI = NumUndefsHI == Half; + return ExpandHorizontalBinOp(InVec0, InVec1, DL, DAG, X86Opcode, true, + isUndefLO, isUndefHI); } return SDValue(); |