summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDan Gohman <gohman@apple.com>2009-04-25 17:12:48 +0000
committerDan Gohman <gohman@apple.com>2009-04-25 17:12:48 +0000
commit1c8491ecc768c410a552bc3441e456fedc2736ff (patch)
tree774f979e8a63d1f8a09f6af959514b073082f208
parent4128700ab11d0db62e5ba7ed8a8fc301c7aaa8b1 (diff)
downloadllvm-1c8491ecc768c410a552bc3441e456fedc2736ff.tar.gz
llvm-1c8491ecc768c410a552bc3441e456fedc2736ff.tar.bz2
llvm-1c8491ecc768c410a552bc3441e456fedc2736ff.tar.xz
Add several more icmp simplifications. Transform signed comparisons
into unsigned ones when the operands are known to have the same sign bit value. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@70053 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r--lib/Transforms/Scalar/InstructionCombining.cpp241
-rw-r--r--test/Transforms/InstCombine/signed-comparison.ll28
2 files changed, 187 insertions, 82 deletions
diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp
index a2658b3e3f..c8cdc4c9fc 100644
--- a/lib/Transforms/Scalar/InstructionCombining.cpp
+++ b/lib/Transforms/Scalar/InstructionCombining.cpp
@@ -708,15 +708,13 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
// set of known zero and one bits, compute the maximum and minimum values that
// could have the specified known zero and known one bits, returning them in
// min/max.
-static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
- const APInt& KnownZero,
+static void ComputeSignedMinMaxValuesFromKnownBits(const APInt& KnownZero,
const APInt& KnownOne,
APInt& Min, APInt& Max) {
- uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
- assert(KnownZero.getBitWidth() == BitWidth &&
- KnownOne.getBitWidth() == BitWidth &&
- Min.getBitWidth() == BitWidth && Max.getBitWidth() == BitWidth &&
- "Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
+ assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
+ KnownZero.getBitWidth() == Min.getBitWidth() &&
+ KnownZero.getBitWidth() == Max.getBitWidth() &&
+ "KnownZero, KnownOne and Min, Max must have equal bitwidth.");
APInt UnknownBits = ~(KnownZero|KnownOne);
// The minimum value is when all unknown bits are zeros, EXCEPT for the sign
@@ -724,9 +722,9 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
Min = KnownOne;
Max = KnownOne|UnknownBits;
- if (UnknownBits[BitWidth-1]) { // Sign bit is unknown
- Min.set(BitWidth-1);
- Max.clear(BitWidth-1);
+ if (UnknownBits.isNegative()) { // Sign bit is unknown
+ Min.set(Min.getBitWidth()-1);
+ Max.clear(Max.getBitWidth()-1);
}
}
@@ -734,14 +732,12 @@ static void ComputeSignedMinMaxValuesFromKnownBits(const Type *Ty,
// a set of known zero and one bits, compute the maximum and minimum values that
// could have the specified known zero and known one bits, returning them in
// min/max.
-static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty,
- const APInt &KnownZero,
+static void ComputeUnsignedMinMaxValuesFromKnownBits(const APInt &KnownZero,
const APInt &KnownOne,
APInt &Min, APInt &Max) {
- uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth(); BitWidth = BitWidth;
- assert(KnownZero.getBitWidth() == BitWidth &&
- KnownOne.getBitWidth() == BitWidth &&
- Min.getBitWidth() == BitWidth && Max.getBitWidth() &&
+ assert(KnownZero.getBitWidth() == KnownOne.getBitWidth() &&
+ KnownZero.getBitWidth() == Min.getBitWidth() &&
+ KnownZero.getBitWidth() == Max.getBitWidth() &&
"Ty, KnownZero, KnownOne and Min, Max must have equal bitwidth.");
APInt UnknownBits = ~(KnownZero|KnownOne);
@@ -808,9 +804,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(V != 0 && "Null pointer of Value???");
assert(Depth <= 6 && "Limit Search Depth");
uint32_t BitWidth = DemandedMask.getBitWidth();
- const IntegerType *VTy = cast<IntegerType>(V->getType());
- assert(VTy->getBitWidth() == BitWidth &&
- KnownZero.getBitWidth() == BitWidth &&
+ const Type *VTy = V->getType();
+ assert((TD || !isa<PointerType>(VTy)) &&
+ "SimplifyDemandedBits needs to know bit widths!");
+ assert((!TD || TD->getTypeSizeInBits(VTy) == BitWidth) &&
+ (!isa<IntegerType>(VTy) ||
+ VTy->getPrimitiveSizeInBits() == BitWidth) &&
+ KnownZero.getBitWidth() == BitWidth &&
KnownOne.getBitWidth() == BitWidth &&
"Value *V, DemandedMask, KnownZero and KnownOne \
must have same BitWidth");
@@ -820,7 +820,13 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownZero = ~KnownOne & DemandedMask;
return 0;
}
-
+ if (isa<ConstantPointerNull>(V)) {
+ // We know all of the bits for a constant!
+ KnownOne.clear();
+ KnownZero = DemandedMask;
+ return 0;
+ }
+
KnownZero.clear();
KnownOne.clear();
if (DemandedMask == 0) { // Not demanding any bits from V.
@@ -832,12 +838,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 6) // Limit search depth.
return 0;
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I) return 0; // Only analyze instructions.
-
APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
APInt &RHSKnownZero = KnownZero, &RHSKnownOne = KnownOne;
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I) {
+ ComputeMaskedBits(V, DemandedMask, RHSKnownZero, RHSKnownOne, Depth);
+ return 0; // Only analyze instructions.
+ }
+
// If there are multiple uses of this value and we aren't at the root, then
// we can't do any simplifications of the operands, because DemandedMask
// only reflects the bits demanded by *one* of the users.
@@ -1399,8 +1408,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the client is only demanding bits that we know, return the known
// constant.
- if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask)
- return ConstantInt::get(RHSKnownOne);
+ if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
+ Constant *C = ConstantInt::get(RHSKnownOne);
+ if (isa<PointerType>(V->getType()))
+ C = ConstantExpr::getIntToPtr(C, V->getType());
+ return C;
+ }
return false;
}
@@ -5831,6 +5844,14 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
}
}
+ unsigned BitWidth = 0;
+ if (TD)
+ BitWidth = TD->getTypeSizeInBits(Ty);
+ else if (isa<IntegerType>(Ty))
+ BitWidth = Ty->getPrimitiveSizeInBits();
+
+ bool isSignBit = false;
+
// See if we are doing a comparison with a constant.
if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
Value *A = 0, *B = 0;
@@ -5865,105 +5886,161 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
return new ICmpInst(ICmpInst::ICMP_SGT, Op0, SubOne(CI));
}
- // See if we can fold the comparison based on range information we can get
- // by checking whether bits are known to be zero or one in the input.
- uint32_t BitWidth = cast<IntegerType>(Ty)->getBitWidth();
- APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-
// If this comparison is a normal comparison, it demands all
// bits, if it is a sign bit comparison, it only demands the sign bit.
bool UnusedBit;
- bool isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
-
- if (SimplifyDemandedBits(I.getOperandUse(0),
+ isSignBit = isSignBitCheck(I.getPredicate(), CI, UnusedBit);
+ }
+
+ // See if we can fold the comparison based on range information we can get
+ // by checking whether bits are known to be zero or one in the input.
+ if (BitWidth != 0) {
+ APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
+ APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
+
+ if (SimplifyDemandedBits(I.getOperandUse(0),
isSignBit ? APInt::getSignBit(BitWidth)
: APInt::getAllOnesValue(BitWidth),
- KnownZero, KnownOne, 0))
+ Op0KnownZero, Op0KnownOne, 0))
return &I;
-
+ if (SimplifyDemandedBits(I.getOperandUse(1),
+ APInt::getAllOnesValue(BitWidth),
+ Op1KnownZero, Op1KnownOne, 0))
+ return &I;
+
// Given the known and unknown bits, compute a range that the LHS could be
// in. Compute the Min, Max and RHS values based on the known bits. For the
// EQ and NE we use unsigned values.
- APInt Min(BitWidth, 0), Max(BitWidth, 0);
- if (ICmpInst::isSignedPredicate(I.getPredicate()))
- ComputeSignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne, Min, Max);
- else
- ComputeUnsignedMinMaxValuesFromKnownBits(Ty, KnownZero, KnownOne,Min,Max);
-
+ APInt Op0Min(BitWidth, 0), Op0Max(BitWidth, 0);
+ APInt Op1Min(BitWidth, 0), Op1Max(BitWidth, 0);
+ if (ICmpInst::isSignedPredicate(I.getPredicate())) {
+ ComputeSignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
+ Op0Min, Op0Max);
+ ComputeSignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
+ Op1Min, Op1Max);
+ } else {
+ ComputeUnsignedMinMaxValuesFromKnownBits(Op0KnownZero, Op0KnownOne,
+ Op0Min, Op0Max);
+ ComputeUnsignedMinMaxValuesFromKnownBits(Op1KnownZero, Op1KnownOne,
+ Op1Min, Op1Max);
+ }
+
// If Min and Max are known to be the same, then SimplifyDemandedBits
// figured out that the LHS is a constant. Just constant fold this now so
// that code below can assume that Min != Max.
- if (Min == Max)
- return ReplaceInstUsesWith(I, ConstantExpr::getICmp(I.getPredicate(),
- ConstantInt::get(Min),
- CI));
-
+ if (!isa<Constant>(Op0) && Op0Min == Op0Max)
+ return new ICmpInst(I.getPredicate(), ConstantInt::get(Op0Min), Op1);
+ if (!isa<Constant>(Op1) && Op1Min == Op1Max)
+ return new ICmpInst(I.getPredicate(), Op0, ConstantInt::get(Op1Min));
+
// Based on the range information we know about the LHS, see if we can
// simplify this comparison. For example, (x&4) < 8 is always true.
- const APInt &RHSVal = CI->getValue();
- switch (I.getPredicate()) { // LE/GE have been folded already.
+ switch (I.getPredicate()) {
default: assert(0 && "Unknown icmp opcode!");
case ICmpInst::ICMP_EQ:
- if (Max.ult(RHSVal) || Min.ugt(RHSVal))
+ if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
case ICmpInst::ICMP_NE:
- if (Max.ult(RHSVal) || Min.ugt(RHSVal))
+ if (Op0Max.ult(Op1Min) || Op0Min.ugt(Op1Max))
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
break;
case ICmpInst::ICMP_ULT:
- if (Max.ult(RHSVal)) // A <u C -> true iff max(A) < C
+ if (Op0Max.ult(Op1Min)) // A <u B -> true if max(A) < min(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
- if (Min.uge(RHSVal)) // A <u C -> false iff min(A) >= C
+ if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
- if (RHSVal == Max) // A <u MAX -> A != MAX
+ if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (RHSVal == Min+1) // A <u MIN+1 -> A == MIN
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
-
- // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
- if (CI->isMinValue(true))
- return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ if (Op1Max == Op0Min+1) // A <u C -> A == C-1 if min(A)+1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+
+ // (x <u 2147483648) -> (x >s -1) -> true if sign bit clear
+ if (CI->isMinValue(true))
+ return new ICmpInst(ICmpInst::ICMP_SGT, Op0,
ConstantInt::getAllOnesValue(Op0->getType()));
+ }
break;
case ICmpInst::ICMP_UGT:
- if (Min.ugt(RHSVal)) // A >u C -> true iff min(A) > C
+ if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
- if (Max.ule(RHSVal)) // A >u C -> false iff max(A) <= C
+ if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
-
- if (RHSVal == Min) // A >u MIN -> A != MIN
+
+ if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (RHSVal == Max-1) // A >u MAX-1 -> A == MAX
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
-
- // (x >u 2147483647) -> (x <s 0) -> true if sign bit set
- if (CI->isMaxValue(true))
- return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
- ConstantInt::getNullValue(Op0->getType()));
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ if (Op1Min == Op0Max-1) // A >u C -> A == C+1 if max(a)-1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+
+ // (x >u 2147483647) -> (x <s 0) -> true if sign bit set
+ if (CI->isMaxValue(true))
+ return new ICmpInst(ICmpInst::ICMP_SLT, Op0,
+ ConstantInt::getNullValue(Op0->getType()));
+ }
break;
case ICmpInst::ICMP_SLT:
- if (Max.slt(RHSVal)) // A <s C -> true iff max(A) < C
+ if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
- if (Min.sge(RHSVal)) // A <s C -> false iff min(A) >= C
+ if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
- if (RHSVal == Max) // A <s MAX -> A != MAX
+ if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (RHSVal == Min+1) // A <s MIN+1 -> A == MIN
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ if (Op1Max == Op0Min+1) // A <s C -> A == C-1 if min(A)+1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, SubOne(CI));
+ }
break;
- case ICmpInst::ICMP_SGT:
- if (Min.sgt(RHSVal)) // A >s C -> true iff min(A) > C
+ case ICmpInst::ICMP_SGT:
+ if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
return ReplaceInstUsesWith(I, ConstantInt::getTrue());
- if (Max.sle(RHSVal)) // A >s C -> false iff max(A) <= C
+ if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return ReplaceInstUsesWith(I, ConstantInt::getFalse());
-
- if (RHSVal == Min) // A >s MIN -> A != MIN
+
+ if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (RHSVal == Max-1) // A >s MAX-1 -> A == MAX
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+ if (Op1Min == Op0Max-1) // A >s C -> A == C+1 if max(A)-1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, AddOne(CI));
+ }
+ break;
+ case ICmpInst::ICMP_SGE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
+ if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+ if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ break;
+ case ICmpInst::ICMP_SLE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
+ if (Op0Max.sle(Op1Min)) // A <=s B -> true if max(A) <= min(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+ if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ break;
+ case ICmpInst::ICMP_UGE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
+ if (Op0Min.uge(Op1Max)) // A >=u B -> true if min(A) >= max(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+ if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
+ break;
+ case ICmpInst::ICMP_ULE:
+ assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
+ if (Op0Max.ule(Op1Min)) // A <=u B -> true if max(A) <= min(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getTrue());
+ if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
+ return ReplaceInstUsesWith(I, ConstantInt::getFalse());
break;
}
+
+ // Turn a signed comparison into an unsigned one if both operands
+ // are known to have the same sign.
+ if (I.isSignedPredicate() &&
+ ((Op0KnownZero.isNegative() && Op1KnownZero.isNegative()) ||
+ (Op0KnownOne.isNegative() && Op1KnownOne.isNegative())))
+ return new ICmpInst(I.getUnsignedPredicate(), Op0, Op1);
}
// Test if the ICmpInst instruction is used exclusively by a select as
diff --git a/test/Transforms/InstCombine/signed-comparison.ll b/test/Transforms/InstCombine/signed-comparison.ll
new file mode 100644
index 0000000000..fdf150f9c6
--- /dev/null
+++ b/test/Transforms/InstCombine/signed-comparison.ll
@@ -0,0 +1,28 @@
+; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t
+; RUN: not grep zext %t
+; RUN: not grep slt %t
+; RUN: grep {icmp ult} %t
+
+; Instcombine should convert the zext+slt into a simple ult.
+
+define void @foo(double* %p) nounwind {
+entry:
+ br label %bb
+
+bb:
+ %indvar = phi i64 [ 0, %entry ], [ %indvar.next, %bb ]
+ %t0 = and i64 %indvar, 65535
+ %t1 = getelementptr double* %p, i64 %t0
+ %t2 = load double* %t1, align 8
+ %t3 = mul double %t2, 2.2
+ store double %t3, double* %t1, align 8
+ %i.04 = trunc i64 %indvar to i16
+ %t4 = add i16 %i.04, 1
+ %t5 = zext i16 %t4 to i32
+ %t6 = icmp slt i32 %t5, 500
+ %indvar.next = add i64 %indvar, 1
+ br i1 %t6, label %bb, label %return
+
+return:
+ ret void
+}