diff options
author | Dan Gohman <gohman@apple.com> | 2008-10-30 20:40:10 +0000 |
---|---|---|
committer | Dan Gohman <gohman@apple.com> | 2008-10-30 20:40:10 +0000 |
commit | 1975d03183966698650042e7a2bbd7198d276cfb (patch) | |
tree | f33349c5e9efe7c269db43d41ca63e460ea703ef | |
parent | d383ff313b67b08ab36e2c0fa0ceac59c167333d (diff) | |
download | llvm-1975d03183966698650042e7a2bbd7198d276cfb.tar.gz llvm-1975d03183966698650042e7a2bbd7198d276cfb.tar.bz2 llvm-1975d03183966698650042e7a2bbd7198d276cfb.tar.xz |
Canonicalize sext(i1) to i1?-1:0, and update various instcombine
optimizations accordingly.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@58457 91177308-0d34-0410-b5e6-96231b3b80d8
-rw-r--r-- | include/llvm/Support/PatternMatch.h | 46 | ||||
-rw-r--r-- | lib/Transforms/Scalar/InstructionCombining.cpp | 131 | ||||
-rw-r--r-- | test/Transforms/InstCombine/logical-select.ll | 26 |
3 files changed, 162 insertions, 41 deletions
diff --git a/include/llvm/Support/PatternMatch.h b/include/llvm/Support/PatternMatch.h index a3951e2dd3..2408103cb9 100644 --- a/include/llvm/Support/PatternMatch.h +++ b/include/llvm/Support/PatternMatch.h @@ -51,6 +51,22 @@ inline leaf_ty<Value> m_Value() { return leaf_ty<Value>(); } /// m_ConstantInt() - Match an arbitrary ConstantInt and ignore it. inline leaf_ty<ConstantInt> m_ConstantInt() { return leaf_ty<ConstantInt>(); } +struct constantint_ty { + int64_t Val; + explicit constantint_ty(int64_t val) : Val(val) {} + + template<typename ITy> + bool match(ITy *V) { + return isa<ConstantInt>(V) && cast<ConstantInt>(V)->getSExtValue() == Val; + } +}; + +/// m_ConstantInt(int64_t) - Match a ConstantInt with a specific value +/// and ignore it. +inline constantint_ty m_ConstantInt(int64_t Val) { + return constantint_ty(Val); +} + struct zero_ty { template<typename ITy> bool match(ITy *V) { @@ -322,6 +338,36 @@ m_FCmp(FCmpInst::Predicate &Pred, const LHS &L, const RHS &R) { } //===----------------------------------------------------------------------===// +// Matchers for SelectInst classes +// + +template<typename Cond_t, typename LHS_t, typename RHS_t> +struct SelectClass_match { + Cond_t C; + LHS_t L; + RHS_t R; + + SelectClass_match(const Cond_t &Cond, const LHS_t &LHS, + const RHS_t &RHS) + : C(Cond), L(LHS), R(RHS) {} + + template<typename OpTy> + bool match(OpTy *V) { + if (SelectInst *I = dyn_cast<SelectInst>(V)) + return C.match(I->getOperand(0)) && + L.match(I->getOperand(1)) && + R.match(I->getOperand(2)); + return false; + } +}; + +template<typename Cond, typename LHS, typename RHS> +inline SelectClass_match<Cond, RHS, LHS> +m_Select(const Cond &C, const LHS &L, const RHS &R) { + return SelectClass_match<Cond, LHS, RHS>(C, L, R); +} + +//===----------------------------------------------------------------------===// // Matchers for CastInst classes // diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 4ec36ad151..70b5aefa23 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -2012,6 +2012,14 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { KnownZero, KnownOne)) return &I; } + + // zext(i1) - 1 -> select i1, 0, -1 + if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) + if (CI->isAllOnesValue() && + ZI->getOperand(0)->getType() == Type::Int1Ty) + return SelectInst::Create(ZI->getOperand(0), + Constant::getNullValue(I.getType()), + ConstantInt::getAllOnesValue(I.getType())); } if (isa<PHINode>(LHS)) @@ -4338,24 +4346,55 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // (A & sext(C0)) | (B & ~sext(C0) -> C0 ? A : B - if (isa<SExtInst>(C) && - cast<User>(C)->getOperand(0)->getType() == Type::Int1Ty) { + // (A & (C0?-1:0)) | (B & ~(C0?-1:0)) -> C0 ? A : B, and commuted variants + if (match(A, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { + if (match(D, m_Not(m_Value(A)))) + return SelectInst::Create(cast<User>(A)->getOperand(0), C, B); + if (match(B, m_Not(m_Value(A)))) + return SelectInst::Create(cast<User>(A)->getOperand(0), C, D); + } + if (match(B, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { + if (match(C, m_Not(m_Value(B)))) + return SelectInst::Create(cast<User>(B)->getOperand(0), A, D); + if (match(A, m_Not(m_Value(B)))) + return SelectInst::Create(cast<User>(B)->getOperand(0), C, D); + } + if (match(C, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { if (match(D, m_Not(m_Value(C)))) return SelectInst::Create(cast<User>(C)->getOperand(0), A, B); - // And commutes, try both ways. if (match(B, m_Not(m_Value(C)))) return SelectInst::Create(cast<User>(C)->getOperand(0), A, D); } - // Or commutes, try both ways. - if (isa<SExtInst>(D) && - cast<User>(D)->getOperand(0)->getType() == Type::Int1Ty) { + if (match(D, m_Select(m_Value(), m_ConstantInt(-1), m_ConstantInt(0)))) { if (match(C, m_Not(m_Value(D)))) return SelectInst::Create(cast<User>(D)->getOperand(0), A, B); - // And commutes, try both ways. if (match(A, m_Not(m_Value(D)))) return SelectInst::Create(cast<User>(D)->getOperand(0), C, B); } + if (match(A, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(D, m_Not(m_Value(A)))) + return SelectInst::Create(cast<User>(A)->getOperand(0), B, C); + if (match(B, m_Not(m_Value(A)))) + return SelectInst::Create(cast<User>(A)->getOperand(0), D, C); + } + if (match(B, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(C, m_Not(m_Value(B)))) + return SelectInst::Create(cast<User>(B)->getOperand(0), D, A); + if (match(A, m_Not(m_Value(B)))) + return SelectInst::Create(cast<User>(B)->getOperand(0), D, C); + } + if (match(C, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(D, m_Not(m_Value(C)))) + return SelectInst::Create(cast<User>(C)->getOperand(0), B, A); + if (match(B, m_Not(m_Value(C)))) + return SelectInst::Create(cast<User>(C)->getOperand(0), D, A); + } + if (match(D, m_Select(m_Value(), m_ConstantInt(0), m_ConstantInt(-1)))) { + if (match(C, m_Not(m_Value(D)))) + return SelectInst::Create(cast<User>(D)->getOperand(0), B, A); + if (match(A, m_Not(m_Value(D)))) + return SelectInst::Create(cast<User>(D)->getOperand(0), B, C); + } } // (X >> Z) | (Y >> Z) -> (X|Y) >> Z for all shifts. @@ -7965,37 +8004,11 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { Value *Src = CI.getOperand(0); - // sext (x <s 0) -> ashr x, 31 -> all ones if signed - // sext (x >s -1) -> ashr x, 31 -> all ones if not signed - if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) { - // If we are just checking for a icmp eq of a single bit and zext'ing it - // to an integer, then shift the bit to the appropriate place and then - // cast to integer to avoid the comparison. - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) { - const APInt &Op1CV = Op1C->getValue(); - - // sext (x <s 0) to i32 --> x>>s31 true if signbit set. - // sext (x >s -1) to i32 --> (x>>s31)^-1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())){ - Value *In = ICI->getOperand(0); - Value *Sh = ConstantInt::get(In->getType(), - In->getType()->getPrimitiveSizeInBits()-1); - In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh, - In->getName()+".lobit"), - CI); - if (In->getType() != CI.getType()) - In = CastInst::CreateIntegerCast(In, CI.getType(), - true/*SExt*/, "tmp", &CI); - - if (ICI->getPredicate() == ICmpInst::ICMP_SGT) - In = InsertNewInstBefore(BinaryOperator::CreateNot(In, - In->getName()+".not"), CI); - - return ReplaceInstUsesWith(CI, In); - } - } - } + // Canonicalize sign-extend from i1 to a select. + if (Src->getType() == Type::Int1Ty) + return SelectInst::Create(Src, + ConstantInt::getAllOnesValue(CI.getType()), + Constant::getNullValue(CI.getType())); // See if the value being truncated is already sign extended. If so, just // eliminate the trunc/sext pair. @@ -8468,7 +8481,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // can be adjusted to fit the min/max idiom. We may edit ICI in // place here, so make sure the select is the only user. if (ICI->hasOneUse()) - if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) + if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) { switch (Pred) { default: break; case ICmpInst::ICMP_ULT: @@ -8513,6 +8526,44 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } + // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if signed + // (x >s -1) ? -1 : 0 -> ashr x, 31 -> all ones if not signed + CmpInst::Predicate Pred = ICI->getPredicate(); + if (match(TrueVal, m_ConstantInt(0)) && + match(FalseVal, m_ConstantInt(-1))) + Pred = CmpInst::getInversePredicate(Pred); + else if (!match(TrueVal, m_ConstantInt(-1)) || + !match(FalseVal, m_ConstantInt(0))) + Pred = CmpInst::BAD_ICMP_PREDICATE; + if (Pred != CmpInst::BAD_ICMP_PREDICATE) { + // If we are just checking for a icmp eq of a single bit and zext'ing it + // to an integer, then shift the bit to the appropriate place and then + // cast to integer to avoid the comparison. + const APInt &Op1CV = CI->getValue(); + + // sext (x <s 0) to i32 --> x>>s31 true if signbit set. + // sext (x >s -1) to i32 --> (x>>s31)^-1 true if signbit clear. + if ((Pred == ICmpInst::ICMP_SLT && Op1CV == 0) || + (Pred == ICmpInst::ICMP_SGT &&Op1CV.isAllOnesValue())) { + Value *In = ICI->getOperand(0); + Value *Sh = ConstantInt::get(In->getType(), + In->getType()->getPrimitiveSizeInBits()-1); + In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh, + In->getName()+".lobit"), + *ICI); + if (In->getType() != CI->getType()) + In = CastInst::CreateIntegerCast(In, CI->getType(), + true/*SExt*/, "tmp", ICI); + + if (Pred == ICmpInst::ICMP_SGT) + In = InsertNewInstBefore(BinaryOperator::CreateNot(In, + In->getName()+".not"), *ICI); + + return ReplaceInstUsesWith(SI, In); + } + } + } + if (CmpLHS == TrueVal && CmpRHS == FalseVal) { // Transform (X == Y) ? X : Y -> Y if (Pred == ICmpInst::ICMP_EQ) diff --git a/test/Transforms/InstCombine/logical-select.ll b/test/Transforms/InstCombine/logical-select.ll index 6369badee6..39702d390a 100644 --- a/test/Transforms/InstCombine/logical-select.ll +++ b/test/Transforms/InstCombine/logical-select.ll @@ -1,4 +1,7 @@ -; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep select | count 2 +; RUN: llvm-as < %s | opt -instcombine | llvm-dis > %t +; RUN grep select %t | count 4 +; RUN not grep and %t +; RUN not grep or %t define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { %e = icmp slt i32 %a, %b @@ -18,3 +21,24 @@ define i32 @bar(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { %j = or i32 %i, %g ret i32 %j } +define i32 @goo(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { +entry: + %0 = icmp slt i32 %a, %b + %iftmp.0.0 = select i1 %0, i32 -1, i32 0 + %1 = and i32 %iftmp.0.0, %c + %not = xor i32 %iftmp.0.0, -1 + %2 = and i32 %not, %d + %3 = or i32 %1, %2 + ret i32 %3 +} + +define i32 @par(i32 %a, i32 %b, i32 %c, i32 %d) nounwind { +entry: + %0 = icmp slt i32 %a, %b + %iftmp.1.0 = select i1 %0, i32 -1, i32 0 + %1 = and i32 %iftmp.1.0, %c + %not = xor i32 %iftmp.1.0, -1 + %2 = and i32 %not, %d + %3 = or i32 %1, %2 + ret i32 %3 +} |