summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorChris Lattner <sabre@nondot.org>2011-02-10 05:23:05 +0000
committerChris Lattner <sabre@nondot.org>2011-02-10 05:23:05 +0000
commitb20c0b5092f11ff349855ec1e732590160aeba23 (patch)
treea5c04b770a04d837eec981577a95a7963c82b504 /lib
parent44cc997d42f896c42a0d37fd8b98d9ec0cb28501 (diff)
downloadllvm-b20c0b5092f11ff349855ec1e732590160aeba23.tar.gz
llvm-b20c0b5092f11ff349855ec1e732590160aeba23.tar.bz2
llvm-b20c0b5092f11ff349855ec1e732590160aeba23.tar.xz
Enhance the "compare with shift" and "compare with div"
optimizations to be much more aggressive in the face of exact/nsw/nuw div and shifts. For example, these (which are the same except the first is 'exact' sdiv: define i1 @sdiv_icmp4_exact(i64 %X) nounwind { %A = sdiv exact i64 %X, -5 ; X/-5 == 0 --> x == 0 %B = icmp eq i64 %A, 0 ret i1 %B } define i1 @sdiv_icmp4(i64 %X) nounwind { %A = sdiv i64 %X, -5 ; X/-5 == 0 --> x == 0 %B = icmp eq i64 %A, 0 ret i1 %B } compile down to: define i1 @sdiv_icmp4_exact(i64 %X) nounwind { %1 = icmp eq i64 %X, 0 ret i1 %1 } define i1 @sdiv_icmp4(i64 %X) nounwind { %X.off = add i64 %X, 4 %1 = icmp ult i64 %X.off, 9 ret i1 %1 } This happens when you do something like: (ptr1-ptr2) == 42 where the pointers are pointers to non-unit types. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125266 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib')
-rw-r--r--lib/Transforms/InstCombine/InstCombineCompares.cpp96
1 files changed, 52 insertions, 44 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8c5e7e48c4..a24d4ca7eb 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -22,13 +22,17 @@
using namespace llvm;
using namespace PatternMatch;
+static ConstantInt *getOne(Constant *C) {
+ return ConstantInt::get(cast<IntegerType>(C->getType()), 1);
+}
+
/// AddOne - Add one to a ConstantInt
static Constant *AddOne(Constant *C) {
return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1));
}
/// SubOne - Subtract one from a ConstantInt
-static Constant *SubOne(ConstantInt *C) {
- return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
+static Constant *SubOne(Constant *C) {
+ return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1));
}
static ConstantInt *ExtractElement(Constant *V, Constant *Idx) {
@@ -782,7 +786,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
// results than (x /s C1) <u C2 or (x /u C1) <s C2 or even
// (x /u C1) <u C2. Simply casting the operands and result won't
// work. :( The if statement below tests that condition and bails
- // if it finds it.
+ // if it finds it.
bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv;
if (!ICI.isEquality() && DivIsSigned != ICI.isSigned())
return 0;
@@ -809,6 +813,10 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
// Get the ICmp opcode
ICmpInst::Predicate Pred = ICI.getPredicate();
+ /// If the division is known to be exact, then there is no remainder from the
+ /// divide, so the covered range size is unit, otherwise it is the divisor.
+ ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS;
+
// Figure out the interval that is being checked. For example, a comparison
// like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
// Compute this interval based on the constants involved and the signedness of
@@ -818,38 +826,43 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
// -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
int LoOverflow = 0, HiOverflow = 0;
Constant *LoBound = 0, *HiBound = 0;
-
+
if (!DivIsSigned) { // udiv
// e.g. X/5 op 3 --> [15, 20)
LoBound = Prod;
HiOverflow = LoOverflow = ProdOV;
- if (!HiOverflow)
- HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false);
+ if (!HiOverflow) {
+ // If this is not an exact divide, then many values in the range collapse
+ // to the same result value.
+ HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false);
+ }
+
} else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0.
if (CmpRHSV == 0) { // (X / pos) op 0
// Can't overflow. e.g. X/2 op 0 --> [-1, 2)
- LoBound = cast<ConstantInt>(ConstantExpr::getNeg(SubOne(DivRHS)));
- HiBound = DivRHS;
+ LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
+ HiBound = RangeSize;
} else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos
LoBound = Prod; // e.g. X/5 op 3 --> [15, 20)
HiOverflow = LoOverflow = ProdOV;
if (!HiOverflow)
- HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true);
+ HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true);
} else { // (X / pos) op neg
// e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14)
HiBound = AddOne(Prod);
LoOverflow = HiOverflow = ProdOV ? -1 : 0;
if (!LoOverflow) {
- ConstantInt* DivNeg =
- cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
+ ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
- }
+ }
}
} else if (DivRHS->getValue().isNegative()) { // Divisor is < 0.
+ if (DivI->isExact())
+ RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
if (CmpRHSV == 0) { // (X / neg) op 0
// e.g. X/-5 op 0 --> [-4, 5)
- LoBound = AddOne(DivRHS);
- HiBound = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS));
+ LoBound = AddOne(RangeSize);
+ HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize));
if (HiBound == DivRHS) { // -INTMIN = INTMIN
HiOverflow = 1; // [INTMIN+1, overflow)
HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN
@@ -859,12 +872,12 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
HiBound = AddOne(Prod);
HiOverflow = LoOverflow = ProdOV ? -1 : 0;
if (!LoOverflow)
- LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0;
+ LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
} else { // (X / neg) op neg
LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20)
LoOverflow = HiOverflow = ProdOV;
if (!HiOverflow)
- HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true);
+ HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true);
}
// Dividing by a negative swaps the condition. LT <-> GT
@@ -883,9 +896,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
ICmpInst::ICMP_ULT, X, HiBound);
- return ReplaceInstUsesWith(ICI,
- InsertRangeTest(X, LoBound, HiBound, DivIsSigned,
- true));
+ return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound,
+ DivIsSigned, true));
case ICmpInst::ICMP_NE:
if (LoOverflow && HiOverflow)
return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext()));
@@ -908,12 +920,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
case ICmpInst::ICMP_SGT:
if (HiOverflow == +1) // High bound greater than input range.
return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext()));
- else if (HiOverflow == -1) // High bound less than input range.
+ if (HiOverflow == -1) // High bound less than input range.
return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext()));
if (Pred == ICmpInst::ICMP_UGT)
return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
- else
- return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
+ return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
}
}
@@ -1182,6 +1193,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
return ReplaceInstUsesWith(ICI, Cst);
}
+ // If the shift is NUW, then it is just shifting out zeros, no need for an
+ // AND.
+ if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap())
+ return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
+ ConstantExpr::getLShr(RHS, ShAmt));
+
if (LHSI->hasOneUse()) {
// Otherwise strength reduce the shift into an and.
uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
@@ -1192,8 +1209,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
Value *And =
Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask");
return new ICmpInst(ICI.getPredicate(), And,
- ConstantInt::get(ICI.getContext(),
- RHSV.lshr(ShAmtVal)));
+ ConstantExpr::getLShr(RHS, ShAmt));
}
}
@@ -1222,10 +1238,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
// undefined shifts. When the shift is visited it will be
// simplified.
uint32_t TypeBits = RHSV.getBitWidth();
- if (ShAmt->uge(TypeBits))
- break;
-
uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
+ if (ShAmtVal >= TypeBits)
+ break;
// If we are comparing against bits always shifted out, the
// comparison cannot succeed.
@@ -1245,13 +1260,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
// Otherwise, check to see if the bits shifted out are known to be zero.
// If so, we can compare against the unshifted value:
// (X & 4) >> 1 == 2 --> (X & 4) == 4.
- if (LHSI->hasOneUse() &&
- MaskedValueIsZero(LHSI->getOperand(0),
- APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) {
+ if (LHSI->hasOneUse() && cast<BinaryOperator>(LHSI)->isExact())
return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
ConstantExpr::getShl(RHS, ShAmt));
- }
-
+
if (LHSI->hasOneUse()) {
// Otherwise strength reduce the shift into an and.
APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal));
@@ -1911,14 +1923,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
// If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
// then turn "((8 >>u x)&1) == 0" into "x != 3".
- ConstantInt *CI = 0;
+ const APInt *CI;
if (Op0KnownZeroInverted == 1 &&
- match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) &&
- CI->getValue().isPowerOf2()) {
- unsigned CmpVal = CI->getValue().countTrailingZeros();
+ match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
return new ICmpInst(ICmpInst::ICMP_NE, X,
- ConstantInt::get(X->getType(), CmpVal));
- }
+ ConstantInt::get(X->getType(),
+ CI->countTrailingZeros()));
}
break;
@@ -1950,14 +1960,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
// If the LHS is 8 >>u x, and we know the result is a power of 2 like 1,
// then turn "((8 >>u x)&1) != 0" into "x == 3".
- ConstantInt *CI = 0;
+ const APInt *CI;
if (Op0KnownZeroInverted == 1 &&
- match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) &&
- CI->getValue().isPowerOf2()) {
- unsigned CmpVal = CI->getValue().countTrailingZeros();
+ match(LHS, m_LShr(m_Power2(CI), m_Value(X))))
return new ICmpInst(ICmpInst::ICMP_EQ, X,
- ConstantInt::get(X->getType(), CmpVal));
- }
+ ConstantInt::get(X->getType(),
+ CI->countTrailingZeros()));
}
break;