diff options
author | Chris Lattner <sabre@nondot.org> | 2011-02-10 05:23:05 +0000 |
---|---|---|
committer | Chris Lattner <sabre@nondot.org> | 2011-02-10 05:23:05 +0000 |
commit | b20c0b5092f11ff349855ec1e732590160aeba23 (patch) | |
tree | a5c04b770a04d837eec981577a95a7963c82b504 /lib | |
parent | 44cc997d42f896c42a0d37fd8b98d9ec0cb28501 (diff) | |
download | llvm-b20c0b5092f11ff349855ec1e732590160aeba23.tar.gz llvm-b20c0b5092f11ff349855ec1e732590160aeba23.tar.bz2 llvm-b20c0b5092f11ff349855ec1e732590160aeba23.tar.xz |
Enhance the "compare with shift" and "compare with div"
optimizations to be much more aggressive in the face of
exact/nsw/nuw div and shifts. For example, these (which
are the same except the first is 'exact' sdiv:
define i1 @sdiv_icmp4_exact(i64 %X) nounwind {
%A = sdiv exact i64 %X, -5 ; X/-5 == 0 --> x == 0
%B = icmp eq i64 %A, 0
ret i1 %B
}
define i1 @sdiv_icmp4(i64 %X) nounwind {
%A = sdiv i64 %X, -5 ; X/-5 == 0 --> x == 0
%B = icmp eq i64 %A, 0
ret i1 %B
}
compile down to:
define i1 @sdiv_icmp4_exact(i64 %X) nounwind {
%1 = icmp eq i64 %X, 0
ret i1 %1
}
define i1 @sdiv_icmp4(i64 %X) nounwind {
%X.off = add i64 %X, 4
%1 = icmp ult i64 %X.off, 9
ret i1 %1
}
This happens when you do something like:
(ptr1-ptr2) == 42
where the pointers are pointers to non-unit types.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125266 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 96 |
1 files changed, 52 insertions, 44 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 8c5e7e48c4..a24d4ca7eb 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -22,13 +22,17 @@ using namespace llvm; using namespace PatternMatch; +static ConstantInt *getOne(Constant *C) { + return ConstantInt::get(cast<IntegerType>(C->getType()), 1); +} + /// AddOne - Add one to a ConstantInt static Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } /// SubOne - Subtract one from a ConstantInt -static Constant *SubOne(ConstantInt *C) { - return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); +static Constant *SubOne(Constant *C) { + return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); } static ConstantInt *ExtractElement(Constant *V, Constant *Idx) { @@ -782,7 +786,7 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // results than (x /s C1) <u C2 or (x /u C1) <s C2 or even // (x /u C1) <u C2. Simply casting the operands and result won't // work. :( The if statement below tests that condition and bails - // if it finds it. + // if it finds it. bool DivIsSigned = DivI->getOpcode() == Instruction::SDiv; if (!ICI.isEquality() && DivIsSigned != ICI.isSigned()) return 0; @@ -809,6 +813,10 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // Get the ICmp opcode ICmpInst::Predicate Pred = ICI.getPredicate(); + /// If the division is known to be exact, then there is no remainder from the + /// divide, so the covered range size is unit, otherwise it is the divisor. + ConstantInt *RangeSize = DivI->isExact() ? getOne(Prod) : DivRHS; + // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). // Compute this interval based on the constants involved and the signedness of @@ -818,38 +826,43 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; Constant *LoBound = 0, *HiBound = 0; - + if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) LoBound = Prod; HiOverflow = LoOverflow = ProdOV; - if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, LoBound, DivRHS, false); + if (!HiOverflow) { + // If this is not an exact divide, then many values in the range collapse + // to the same result value. + HiOverflow = AddWithOverflow(HiBound, LoBound, RangeSize, false); + } + } else if (DivRHS->getValue().isStrictlyPositive()) { // Divisor is > 0. if (CmpRHSV == 0) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) - LoBound = cast<ConstantInt>(ConstantExpr::getNeg(SubOne(DivRHS))); - HiBound = DivRHS; + LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); + HiBound = RangeSize; } else if (CmpRHSV.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) - HiOverflow = AddWithOverflow(HiBound, Prod, DivRHS, true); + HiOverflow = AddWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) HiBound = AddOne(Prod); LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - ConstantInt* DivNeg = - cast<ConstantInt>(ConstantExpr::getNeg(DivRHS)); + ConstantInt *DivNeg =cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); LoOverflow = AddWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; - } + } } } else if (DivRHS->getValue().isNegative()) { // Divisor is < 0. + if (DivI->isExact()) + RangeSize = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); if (CmpRHSV == 0) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) - LoBound = AddOne(DivRHS); - HiBound = cast<ConstantInt>(ConstantExpr::getNeg(DivRHS)); + LoBound = AddOne(RangeSize); + HiBound = cast<ConstantInt>(ConstantExpr::getNeg(RangeSize)); if (HiBound == DivRHS) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) HiBound = 0; // e.g. X/INTMIN = 0 --> X > INTMIN @@ -859,12 +872,12 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, HiBound = AddOne(Prod); HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = AddWithOverflow(LoBound, HiBound, DivRHS, true) ? -1 : 0; + LoOverflow = AddWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; } else { // (X / neg) op neg LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) - HiOverflow = SubWithOverflow(HiBound, Prod, DivRHS, true); + HiOverflow = SubWithOverflow(HiBound, Prod, RangeSize, true); } // Dividing by a negative swaps the condition. LT <-> GT @@ -883,9 +896,8 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, X, HiBound); - return ReplaceInstUsesWith(ICI, - InsertRangeTest(X, LoBound, HiBound, DivIsSigned, - true)); + return ReplaceInstUsesWith(ICI, InsertRangeTest(X, LoBound, HiBound, + DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); @@ -908,12 +920,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI, case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. return ReplaceInstUsesWith(ICI, ConstantInt::getFalse(ICI.getContext())); - else if (HiOverflow == -1) // High bound less than input range. + if (HiOverflow == -1) // High bound less than input range. return ReplaceInstUsesWith(ICI, ConstantInt::getTrue(ICI.getContext())); if (Pred == ICmpInst::ICMP_UGT) return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - else - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); } } @@ -1182,6 +1193,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, return ReplaceInstUsesWith(ICI, Cst); } + // If the shift is NUW, then it is just shifting out zeros, no need for an + // AND. + if (cast<BinaryOperator>(LHSI)->hasNoUnsignedWrap()) + return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), + ConstantExpr::getLShr(RHS, ShAmt)); + if (LHSI->hasOneUse()) { // Otherwise strength reduce the shift into an and. uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); @@ -1192,8 +1209,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, Value *And = Builder->CreateAnd(LHSI->getOperand(0),Mask, LHSI->getName()+".mask"); return new ICmpInst(ICI.getPredicate(), And, - ConstantInt::get(ICI.getContext(), - RHSV.lshr(ShAmtVal))); + ConstantExpr::getLShr(RHS, ShAmt)); } } @@ -1222,10 +1238,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // undefined shifts. When the shift is visited it will be // simplified. uint32_t TypeBits = RHSV.getBitWidth(); - if (ShAmt->uge(TypeBits)) - break; - uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits); + if (ShAmtVal >= TypeBits) + break; // If we are comparing against bits always shifted out, the // comparison cannot succeed. @@ -1245,13 +1260,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI, // Otherwise, check to see if the bits shifted out are known to be zero. // If so, we can compare against the unshifted value: // (X & 4) >> 1 == 2 --> (X & 4) == 4. - if (LHSI->hasOneUse() && - MaskedValueIsZero(LHSI->getOperand(0), - APInt::getLowBitsSet(Comp.getBitWidth(), ShAmtVal))) { + if (LHSI->hasOneUse() && cast<BinaryOperator>(LHSI)->isExact()) return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0), ConstantExpr::getShl(RHS, ShAmt)); - } - + if (LHSI->hasOneUse()) { // Otherwise strength reduce the shift into an and. APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); @@ -1911,14 +1923,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, // then turn "((8 >>u x)&1) == 0" into "x != 3". - ConstantInt *CI = 0; + const APInt *CI; if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) && - CI->getValue().isPowerOf2()) { - unsigned CmpVal = CI->getValue().countTrailingZeros(); + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) return new ICmpInst(ICmpInst::ICMP_NE, X, - ConstantInt::get(X->getType(), CmpVal)); - } + ConstantInt::get(X->getType(), + CI->countTrailingZeros())); } break; @@ -1950,14 +1960,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // If the LHS is 8 >>u x, and we know the result is a power of 2 like 1, // then turn "((8 >>u x)&1) != 0" into "x == 3". - ConstantInt *CI = 0; + const APInt *CI; if (Op0KnownZeroInverted == 1 && - match(LHS, m_LShr(m_ConstantInt(CI), m_Value(X))) && - CI->getValue().isPowerOf2()) { - unsigned CmpVal = CI->getValue().countTrailingZeros(); + match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) return new ICmpInst(ICmpInst::ICMP_EQ, X, - ConstantInt::get(X->getType(), CmpVal)); - } + ConstantInt::get(X->getType(), + CI->countTrailingZeros())); } break; |