summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLang Hames <lhames@gmail.com>2013-10-21 17:51:24 +0000
committerLang Hames <lhames@gmail.com>2013-10-21 17:51:24 +0000
commit1d82537762a0f4019bde301d498d190140585f57 (patch)
tree85024ec02b3891da0b68ca933c8b5a251534c394
parentef2d919f7f767b9d70bc44c4ac918b0a6b540cac (diff)
downloadllvm-1d82537762a0f4019bde301d498d190140585f57.tar.gz
llvm-1d82537762a0f4019bde301d498d190140585f57.tar.bz2
llvm-1d82537762a0f4019bde301d498d190140585f57.tar.xz
X86 vector element shift-by-immediate instructions take i8 immediates. Make
the instruction defenitions and ISEL reflect this. Prior to this patch these instructions took an i32i8imm, and the high bits were dropped during encoding. This led to incorrect behavior for shifts by immediates higher than 255. This patch fixes that issue by detecting large immediate shifts and returning constant zero (for logical shifts) or capping the shift amount at an encodable value (for arithmetic shifts). Fixes <rdar://problem/14968098> git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@193096 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r--lib/Target/X86/X86ISelLowering.cpp109
-rw-r--r--lib/Target/X86/X86InstrAVX512.td12
-rw-r--r--lib/Target/X86/X86InstrSSE.td12
-rw-r--r--test/CodeGen/X86/avx2-vector-shifts.ll4
-rw-r--r--test/CodeGen/X86/sse2-vector-shifts.ll4
5 files changed, 77 insertions, 64 deletions
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index dea4a4616f..e8918f4c34 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -10952,6 +10952,26 @@ static SDValue LowerVACOPY(SDValue Op, const X86Subtarget *Subtarget,
MachinePointerInfo(DstSV), MachinePointerInfo(SrcSV));
}
+// getTargetVShiftByConstNode - Handle vector element shifts where the shift
+// amount is a constant. Takes immediate version of shift as input.
+static SDValue getTargetVShiftByConstNode(unsigned Opc, SDLoc dl, EVT VT,
+ SDValue SrcOp, uint64_t ShiftAmt,
+ SelectionDAG &DAG) {
+
+ // Check for ShiftAmt >= element width
+ if (ShiftAmt >= VT.getVectorElementType().getSizeInBits()) {
+ if (Opc == X86ISD::VSRAI)
+ ShiftAmt = VT.getVectorElementType().getSizeInBits() - 1;
+ else
+ return DAG.getConstant(0, VT);
+ }
+
+ assert((Opc == X86ISD::VSHLI || Opc == X86ISD::VSRLI || Opc == X86ISD::VSRAI)
+ && "Unknown target vector shift-by-constant node");
+
+ return DAG.getNode(Opc, dl, VT, SrcOp, DAG.getConstant(ShiftAmt, MVT::i8));
+}
+
// getTargetVShiftNode - Handle vector element shifts where the shift amount
// may or may not be a constant. Takes immediate version of shift as input.
static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
@@ -10959,18 +10979,10 @@ static SDValue getTargetVShiftNode(unsigned Opc, SDLoc dl, EVT VT,
SelectionDAG &DAG) {
assert(ShAmt.getValueType() == MVT::i32 && "ShAmt is not i32");
- if (isa<ConstantSDNode>(ShAmt)) {
- // Constant may be a TargetConstant. Use a regular constant.
- uint32_t ShiftAmt = cast<ConstantSDNode>(ShAmt)->getZExtValue();
- switch (Opc) {
- default: llvm_unreachable("Unknown target vector shift node");
- case X86ISD::VSHLI:
- case X86ISD::VSRLI:
- case X86ISD::VSRAI:
- return DAG.getNode(Opc, dl, VT, SrcOp,
- DAG.getConstant(ShiftAmt, MVT::i32));
- }
- }
+ // Catch shift-by-constant.
+ if (ConstantSDNode *CShAmt = dyn_cast<ConstantSDNode>(ShAmt))
+ return getTargetVShiftByConstNode(Opc, dl, VT, SrcOp,
+ CShAmt->getZExtValue(), DAG);
// Change opcode to non-immediate version
switch (Opc) {
@@ -12416,10 +12428,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
// AhiBlo = psllqi(AhiBlo, 32);
// return AloBlo + AloBhi + AhiBlo;
- SDValue ShAmt = DAG.getConstant(32, MVT::i32);
-
- SDValue Ahi = DAG.getNode(X86ISD::VSRLI, dl, VT, A, ShAmt);
- SDValue Bhi = DAG.getNode(X86ISD::VSRLI, dl, VT, B, ShAmt);
+ SDValue Ahi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, A, 32, DAG);
+ SDValue Bhi = getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, B, 32, DAG);
// Bit cast to 32-bit vectors for MULUDQ
EVT MulVT = (VT == MVT::v2i64) ? MVT::v4i32 :
@@ -12433,8 +12443,8 @@ static SDValue LowerMUL(SDValue Op, const X86Subtarget *Subtarget,
SDValue AloBhi = DAG.getNode(X86ISD::PMULUDQ, dl, VT, A, Bhi);
SDValue AhiBlo = DAG.getNode(X86ISD::PMULUDQ, dl, VT, Ahi, B);
- AloBhi = DAG.getNode(X86ISD::VSHLI, dl, VT, AloBhi, ShAmt);
- AhiBlo = DAG.getNode(X86ISD::VSHLI, dl, VT, AhiBlo, ShAmt);
+ AloBhi = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AloBhi, 32, DAG);
+ AhiBlo = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, AhiBlo, 32, DAG);
SDValue Res = DAG.getNode(ISD::ADD, dl, VT, AloBlo, AloBhi);
return DAG.getNode(ISD::ADD, dl, VT, Res, AhiBlo);
@@ -12462,7 +12472,7 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) {
if ((SplatValue != 0) &&
(SplatValue.isPowerOf2() || (-SplatValue).isPowerOf2())) {
- unsigned lg2 = SplatValue.countTrailingZeros();
+ unsigned Lg2 = SplatValue.countTrailingZeros();
// Splat the sign bit.
SmallVector<SDValue, 16> Sz(NumElts,
DAG.getConstant(EltTy.getSizeInBits() - 1,
@@ -12472,13 +12482,13 @@ static SDValue LowerSDIV(SDValue Op, SelectionDAG &DAG) {
NumElts));
// Add (N0 < 0) ? abs2 - 1 : 0;
SmallVector<SDValue, 16> Amt(NumElts,
- DAG.getConstant(EltTy.getSizeInBits() - lg2,
+ DAG.getConstant(EltTy.getSizeInBits() - Lg2,
EltTy));
SDValue SRL = DAG.getNode(ISD::SRL, dl, VT, SGN,
DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Amt[0],
NumElts));
SDValue ADD = DAG.getNode(ISD::ADD, dl, VT, N0, SRL);
- SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(lg2, EltTy));
+ SmallVector<SDValue, 16> Lg2Amt(NumElts, DAG.getConstant(Lg2, EltTy));
SDValue SRA = DAG.getNode(ISD::SRA, dl, VT, ADD,
DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Lg2Amt[0],
NumElts));
@@ -12514,21 +12524,22 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
(Subtarget->hasAVX512() &&
(VT == MVT::v8i64 || VT == MVT::v16i32))) {
if (Op.getOpcode() == ISD::SHL)
- return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
+ DAG);
if (Op.getOpcode() == ISD::SRL)
- return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
+ DAG);
if (Op.getOpcode() == ISD::SRA && VT != MVT::v2i64 && VT != MVT::v4i64)
- return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
+ DAG);
}
if (VT == MVT::v16i8) {
if (Op.getOpcode() == ISD::SHL) {
// Make a large shift.
- SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v8i16, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
+ MVT::v8i16, R, ShiftAmt,
+ DAG);
SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
// Zero out the rightmost bits.
SmallVector<SDValue, 16> V(16,
@@ -12539,8 +12550,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
}
if (Op.getOpcode() == ISD::SRL) {
// Make a large shift.
- SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v8i16, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
+ MVT::v8i16, R, ShiftAmt,
+ DAG);
SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
// Zero out the leftmost bits.
SmallVector<SDValue, 16> V(16,
@@ -12571,8 +12583,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
if (Subtarget->hasInt256() && VT == MVT::v32i8) {
if (Op.getOpcode() == ISD::SHL) {
// Make a large shift.
- SDValue SHL = DAG.getNode(X86ISD::VSHLI, dl, MVT::v16i16, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ SDValue SHL = getTargetVShiftByConstNode(X86ISD::VSHLI, dl,
+ MVT::v16i16, R, ShiftAmt,
+ DAG);
SHL = DAG.getNode(ISD::BITCAST, dl, VT, SHL);
// Zero out the rightmost bits.
SmallVector<SDValue, 32> V(32,
@@ -12583,8 +12596,9 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
}
if (Op.getOpcode() == ISD::SRL) {
// Make a large shift.
- SDValue SRL = DAG.getNode(X86ISD::VSRLI, dl, MVT::v16i16, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ SDValue SRL = getTargetVShiftByConstNode(X86ISD::VSRLI, dl,
+ MVT::v16i16, R, ShiftAmt,
+ DAG);
SRL = DAG.getNode(ISD::BITCAST, dl, VT, SRL);
// Zero out the leftmost bits.
SmallVector<SDValue, 32> V(32,
@@ -12649,14 +12663,14 @@ static SDValue LowerScalarImmediateShift(SDValue Op, SelectionDAG &DAG,
default:
llvm_unreachable("Unknown shift opcode!");
case ISD::SHL:
- return DAG.getNode(X86ISD::VSHLI, dl, VT, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ return getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, R, ShiftAmt,
+ DAG);
case ISD::SRL:
- return DAG.getNode(X86ISD::VSRLI, dl, VT, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ return getTargetVShiftByConstNode(X86ISD::VSRLI, dl, VT, R, ShiftAmt,
+ DAG);
case ISD::SRA:
- return DAG.getNode(X86ISD::VSRAI, dl, VT, R,
- DAG.getConstant(ShiftAmt, MVT::i32));
+ return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, R, ShiftAmt,
+ DAG);
}
}
@@ -12869,8 +12883,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
// r = VSELECT(r, psllw(r & (char16)15, 4), a);
SDValue M = DAG.getNode(ISD::AND, dl, VT, R, CM1);
- M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
- DAG.getConstant(4, MVT::i32), DAG);
+ M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 4, DAG);
M = DAG.getNode(ISD::BITCAST, dl, VT, M);
R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
@@ -12881,8 +12894,7 @@ static SDValue LowerShift(SDValue Op, const X86Subtarget* Subtarget,
// r = VSELECT(r, psllw(r & (char16)63, 2), a);
M = DAG.getNode(ISD::AND, dl, VT, R, CM2);
- M = getTargetVShiftNode(X86ISD::VSHLI, dl, MVT::v8i16, M,
- DAG.getConstant(2, MVT::i32), DAG);
+ M = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, MVT::v8i16, M, 2, DAG);
M = DAG.getNode(ISD::BITCAST, dl, VT, M);
R = DAG.getNode(ISD::VSELECT, dl, VT, OpVSel, M, R);
@@ -13025,7 +13037,6 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
unsigned BitsDiff = VT.getScalarType().getSizeInBits() -
ExtraVT.getScalarType().getSizeInBits();
- SDValue ShAmt = DAG.getConstant(BitsDiff, MVT::i32);
switch (VT.getSimpleVT().SimpleTy) {
default: return SDValue();
@@ -13075,8 +13086,10 @@ SDValue X86TargetLowering::LowerSIGN_EXTEND_INREG(SDValue Op,
}
// If the above didn't work, then just use Shift-Left + Shift-Right.
- Tmp1 = getTargetVShiftNode(X86ISD::VSHLI, dl, VT, Op0, ShAmt, DAG);
- return getTargetVShiftNode(X86ISD::VSRAI, dl, VT, Tmp1, ShAmt, DAG);
+ Tmp1 = getTargetVShiftByConstNode(X86ISD::VSHLI, dl, VT, Op0, BitsDiff,
+ DAG);
+ return getTargetVShiftByConstNode(X86ISD::VSRAI, dl, VT, Tmp1, BitsDiff,
+ DAG);
}
}
}
diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td
index 3fd725cd95..05e346dec5 100644
--- a/lib/Target/X86/X86InstrAVX512.td
+++ b/lib/Target/X86/X86InstrAVX512.td
@@ -1845,22 +1845,22 @@ multiclass avx512_shift_rmi<bits<8> opc, Format ImmFormR, Format ImmFormM,
ValueType vt, X86MemOperand x86memop, PatFrag mem_frag,
RegisterClass KRC> {
def ri : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
- (ins RC:$src1, i32i8imm:$src2),
+ (ins RC:$src1, i8imm:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
- [(set RC:$dst, (vt (OpNode RC:$src1, (i32 imm:$src2))))],
+ [(set RC:$dst, (vt (OpNode RC:$src1, (i8 imm:$src2))))],
SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V;
def rik : AVX512BIi8<opc, ImmFormR, (outs RC:$dst),
- (ins KRC:$mask, RC:$src1, i32i8imm:$src2),
+ (ins KRC:$mask, RC:$src1, i8imm:$src2),
!strconcat(OpcodeStr,
"\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
[], SSE_INTSHIFT_ITINS_P.rr>, EVEX_4V, EVEX_K;
def mi: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
- (ins x86memop:$src1, i32i8imm:$src2),
+ (ins x86memop:$src1, i8imm:$src2),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
[(set RC:$dst, (OpNode (mem_frag addr:$src1),
- (i32 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
+ (i8 imm:$src2)))], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V;
def mik: AVX512BIi8<opc, ImmFormM, (outs RC:$dst),
- (ins KRC:$mask, x86memop:$src1, i32i8imm:$src2),
+ (ins KRC:$mask, x86memop:$src1, i8imm:$src2),
!strconcat(OpcodeStr,
"\t{$src2, $src1, $dst {${mask}}|$dst {${mask}}, $src1, $src2}"),
[], SSE_INTSHIFT_ITINS_P.rm>, EVEX_4V, EVEX_K;
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index c2f319704d..f1bb9f84a7 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -3744,11 +3744,11 @@ multiclass PDI_binop_rmi<bits<8> opc, bits<8> opc2, Format ImmForm,
(bc_frag (memopv2i64 addr:$src2)))))], itins.rm>,
Sched<[WriteVecShiftLd, ReadAfterLd]>;
def ri : PDIi8<opc2, ImmForm, (outs RC:$dst),
- (ins RC:$src1, i32i8imm:$src2),
+ (ins RC:$src1, i8imm:$src2),
!if(Is2Addr,
!strconcat(OpcodeStr, "\t{$src2, $dst|$dst, $src2}"),
!strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}")),
- [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i32 imm:$src2))))], itins.ri>,
+ [(set RC:$dst, (DstVT (OpNode2 RC:$src1, (i8 imm:$src2))))], itins.ri>,
Sched<[WriteVecShift]>;
}
@@ -5064,12 +5064,12 @@ multiclass SS3I_unop_rm_int_y<bits<8> opc, string OpcodeStr,
// Helper fragments to match sext vXi1 to vXiY.
def v16i1sextv16i8 : PatLeaf<(v16i8 (X86pcmpgt (bc_v16i8 (v4i32 immAllZerosV)),
VR128:$src))>;
-def v8i1sextv8i16 : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i32 15)))>;
-def v4i1sextv4i32 : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i32 31)))>;
+def v8i1sextv8i16 : PatLeaf<(v8i16 (X86vsrai VR128:$src, (i8 15)))>;
+def v4i1sextv4i32 : PatLeaf<(v4i32 (X86vsrai VR128:$src, (i8 31)))>;
def v32i1sextv32i8 : PatLeaf<(v32i8 (X86pcmpgt (bc_v32i8 (v8i32 immAllZerosV)),
VR256:$src))>;
-def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i32 15)))>;
-def v8i1sextv8i32 : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i32 31)))>;
+def v16i1sextv16i16: PatLeaf<(v16i16 (X86vsrai VR256:$src, (i8 15)))>;
+def v8i1sextv8i32 : PatLeaf<(v8i32 (X86vsrai VR256:$src, (i8 31)))>;
let Predicates = [HasAVX] in {
defm VPABSB : SS3I_unop_rm_int<0x1C, "vpabsb",
diff --git a/test/CodeGen/X86/avx2-vector-shifts.ll b/test/CodeGen/X86/avx2-vector-shifts.ll
index a978d93fc5..5592e6c8a5 100644
--- a/test/CodeGen/X86/avx2-vector-shifts.ll
+++ b/test/CodeGen/X86/avx2-vector-shifts.ll
@@ -121,7 +121,7 @@ entry:
}
; CHECK-LABEL: test_sraw_3:
-; CHECK: vpsraw $16, %ymm0, %ymm0
+; CHECK: vpsraw $15, %ymm0, %ymm0
; CHECK: ret
define <8 x i32> @test_srad_1(<8 x i32> %InVec) {
@@ -151,7 +151,7 @@ entry:
}
; CHECK-LABEL: test_srad_3:
-; CHECK: vpsrad $32, %ymm0, %ymm0
+; CHECK: vpsrad $31, %ymm0, %ymm0
; CHECK: ret
; SSE Logical Shift Right
diff --git a/test/CodeGen/X86/sse2-vector-shifts.ll b/test/CodeGen/X86/sse2-vector-shifts.ll
index e2d612567a..462def980a 100644
--- a/test/CodeGen/X86/sse2-vector-shifts.ll
+++ b/test/CodeGen/X86/sse2-vector-shifts.ll
@@ -121,7 +121,7 @@ entry:
}
; CHECK-LABEL: test_sraw_3:
-; CHECK: psraw $16, %xmm0
+; CHECK: psraw $15, %xmm0
; CHECK-NEXT: ret
define <4 x i32> @test_srad_1(<4 x i32> %InVec) {
@@ -151,7 +151,7 @@ entry:
}
; CHECK-LABEL: test_srad_3:
-; CHECK: psrad $32, %xmm0
+; CHECK: psrad $31, %xmm0
; CHECK-NEXT: ret
; SSE Logical Shift Right