diff options
Diffstat (limited to 'lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | lib/Target/X86/X86ISelLowering.cpp | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp index 6df0fd880f..a878ea82ea 100644 --- a/lib/Target/X86/X86ISelLowering.cpp +++ b/lib/Target/X86/X86ISelLowering.cpp @@ -16323,6 +16323,44 @@ static SDValue XFormVExtractWithShuffleIntoLoad(SDNode *N, SelectionDAG &DAG, EltNo); } +/// Extract one bit from mask vector, like v16i1 or v8i1. +/// AVX-512 feature. +static SDValue ExtractBitFromMaskVector(SDNode *N, SelectionDAG &DAG) { + SDValue Vec = N->getOperand(0); + SDLoc dl(Vec); + MVT VecVT = Vec.getSimpleValueType(); + SDValue Idx = N->getOperand(1); + MVT EltVT = N->getSimpleValueType(0); + + assert((VecVT.getVectorElementType() == MVT::i1 && EltVT == MVT::i8) || + "Unexpected operands in ExtractBitFromMaskVector"); + + // variable index + if (!isa<ConstantSDNode>(Idx)) { + MVT ExtVT = (VecVT == MVT::v8i1 ? MVT::v8i64 : MVT::v16i32); + SDValue Ext = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtVT, Vec); + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, + ExtVT.getVectorElementType(), Ext); + return DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); + } + + unsigned IdxVal = cast<ConstantSDNode>(Idx)->getZExtValue(); + + MVT ScalarVT = MVT::getIntegerVT(VecVT.getSizeInBits()); + unsigned MaxShift = VecVT.getSizeInBits() - 1; + Vec = DAG.getNode(ISD::BITCAST, dl, ScalarVT, Vec); + Vec = DAG.getNode(ISD::SHL, dl, ScalarVT, Vec, + DAG.getConstant(MaxShift - IdxVal, ScalarVT)); + Vec = DAG.getNode(ISD::SRL, dl, ScalarVT, Vec, + DAG.getConstant(MaxShift, ScalarVT)); + + if (VecVT == MVT::v16i1) { + Vec = DAG.getNode(ISD::BITCAST, dl, MVT::i16, Vec); + return DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, Vec); + } + return DAG.getNode(ISD::BITCAST, dl, MVT::i8, Vec); +} + /// PerformEXTRACT_VECTOR_ELTCombine - Detect vector gather/scatter index /// generation and convert it from being a bunch of shuffles and extracts /// to a simple store and scalar loads to extract the elements. @@ -16333,6 +16371,11 @@ static SDValue PerformEXTRACT_VECTOR_ELTCombine(SDNode *N, SelectionDAG &DAG, return NewOp; SDValue InputVector = N->getOperand(0); + + if (InputVector.getValueType().getVectorElementType() == MVT::i1 && + !DCI.isBeforeLegalize()) + return ExtractBitFromMaskVector(N, DAG); + // Detect whether we are trying to convert from mmx to i32 and the bitcast // from mmx to v2i32 has a single usage. if (InputVector.getNode()->getOpcode() == llvm::ISD::BITCAST && |