summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineMulDivRem.cpp91
1 files changed, 60 insertions, 31 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 29846c156c..8e4267f898 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -377,6 +377,8 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), TD))
return ReplaceInstUsesWith(I, V);
+ bool AllowReassociate = I.hasUnsafeAlgebra();
+
// Simplify mul instructions with a constant RHS.
if (isa<Constant>(Op1)) {
// Try to fold constant mul into select arguments.
@@ -389,7 +391,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
return NV;
ConstantFP *C = dyn_cast<ConstantFP>(Op1);
- if (C && I.hasUnsafeAlgebra() && C->getValueAPF().isNormal()) {
+ if (C && AllowReassociate && C->getValueAPF().isNormal()) {
// Let MDC denote an expression in one of these forms:
// X * C, C/X, X/C, where C is a constant.
//
@@ -430,7 +432,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
BinaryOperator::CreateFAdd(M0, M1) :
BinaryOperator::CreateFSub(M0, M1);
Instruction *RI = cast<Instruction>(R);
- RI->setHasUnsafeAlgebra(true);
+ RI->copyFastMathFlags(&I);
return RI;
}
}
@@ -438,9 +440,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
}
}
- if (Value *Op0v = dyn_castFNegVal(Op0)) // -X * -Y = X*Y
- if (Value *Op1v = dyn_castFNegVal(Op1))
- return BinaryOperator::CreateFMul(Op0v, Op1v);
// Under unsafe algebra do:
// X * log2(0.5*Y) = X*log2(Y) - X
@@ -469,36 +468,66 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
}
}
- // X * cond ? 1.0 : 0.0 => cond ? X : 0.0
- if (I.hasNoNaNs() && I.hasNoSignedZeros()) {
- Value *V0 = I.getOperand(0);
- Value *V1 = I.getOperand(1);
- Value *Cond, *SLHS, *SRHS;
- bool Match = false;
-
- if (match(V0, m_Select(m_Value(Cond), m_Value(SLHS), m_Value(SRHS)))) {
- Match = true;
- } else if (match(V1, m_Select(m_Value(Cond), m_Value(SLHS),
- m_Value(SRHS)))) {
- Match = true;
- std::swap(V0, V1);
+ // Handle symmetric situation in a 2-iteration loop
+ Value *Opnd0 = Op0;
+ Value *Opnd1 = Op1;
+ for (int i = 0; i < 2; i++) {
+ bool IgnoreZeroSign = I.hasNoSignedZeros();
+ if (BinaryOperator::isFNeg(Opnd0, IgnoreZeroSign)) {
+ Value *N0 = dyn_castFNegVal(Opnd0, IgnoreZeroSign);
+ Value *N1 = dyn_castFNegVal(Opnd1, IgnoreZeroSign);
+
+ // -X * -Y => X*Y
+ if (N1)
+ return BinaryOperator::CreateFMul(N0, N1);
+
+ if (Opnd0->hasOneUse()) {
+ // -X * Y => -(X*Y) (Promote negation as high as possible)
+ Value *T = Builder->CreateFMul(N0, Opnd1);
+ cast<Instruction>(T)->setDebugLoc(I.getDebugLoc());
+ Instruction *Neg = BinaryOperator::CreateFNeg(T);
+ if (I.getFastMathFlags().any()) {
+ cast<Instruction>(T)->copyFastMathFlags(&I);
+ Neg->copyFastMathFlags(&I);
+ }
+ return Neg;
+ }
}
- if (Match) {
- ConstantFP *C0 = dyn_cast<ConstantFP>(SLHS);
- ConstantFP *C1 = dyn_cast<ConstantFP>(SRHS);
-
- if (C0 && C1 &&
- ((C0->isZero() && C1->isExactlyValue(1.0)) ||
- (C1->isZero() && C0->isExactlyValue(1.0)))) {
- Value *T;
- if (C0->isZero())
- T = Builder->CreateSelect(Cond, SLHS, V1);
- else
- T = Builder->CreateSelect(Cond, V1, SRHS);
- return ReplaceInstUsesWith(I, T);
+ // (X*Y) * X => (X*X) * Y where Y != X
+ // The purpose is two-fold:
+ // 1) to form a power expression (of X).
+ // 2) potentially shorten the critical path: After transformation, the
+ // latency of the instruction Y is amortized by the expression of X*X,
+ // and therefore Y is in a "less critical" position compared to what it
+ // was before the transformation.
+ //
+ if (AllowReassociate) {
+ Value *Opnd0_0, *Opnd0_1;
+ if (Opnd0->hasOneUse() &&
+ match(Opnd0, m_FMul(m_Value(Opnd0_0), m_Value(Opnd0_1)))) {
+ Value *Y = 0;
+ if (Opnd0_0 == Opnd1 && Opnd0_1 != Opnd1)
+ Y = Opnd0_1;
+ else if (Opnd0_1 == Opnd1 && Opnd0_0 != Opnd1)
+ Y = Opnd0_0;
+
+ if (Y) {
+ Instruction *T = cast<Instruction>(Builder->CreateFMul(Opnd1, Opnd1));
+ T->copyFastMathFlags(&I);
+ T->setDebugLoc(I.getDebugLoc());
+
+ Instruction *R = BinaryOperator::CreateFMul(T, Y);
+ R->copyFastMathFlags(&I);
+ return R;
+ }
}
}
+
+ if (!isa<Constant>(Op1))
+ std::swap(Opnd0, Opnd1);
+ else
+ break;
}
return Changed ? &I : 0;