summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/Reassociate.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Scalar/Reassociate.cpp')
-rw-r--r--lib/Transforms/Scalar/Reassociate.cpp257
1 files changed, 247 insertions, 10 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index cb408a137e..c4079e37a1 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -31,10 +31,12 @@
#include "llvm/Pass.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Support/CFG.h"
+#include "llvm/Support/IRBuilder.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ValueHandle.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/DenseMap.h"
#include <algorithm>
@@ -72,6 +74,45 @@ static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) {
#endif
namespace {
+ /// \brief Utility class representing a base and exponent pair which form one
+ /// factor of some product.
+ struct Factor {
+ Value *Base;
+ unsigned Power;
+
+ Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}
+
+ /// \brief Sort factors by their Base.
+ struct BaseSorter {
+ bool operator()(const Factor &LHS, const Factor &RHS) {
+ return LHS.Base < RHS.Base;
+ }
+ };
+
+ /// \brief Compare factors for equal bases.
+ struct BaseEqual {
+ bool operator()(const Factor &LHS, const Factor &RHS) {
+ return LHS.Base == RHS.Base;
+ }
+ };
+
+ /// \brief Sort factors in descending order by their power.
+ struct PowerDescendingSorter {
+ bool operator()(const Factor &LHS, const Factor &RHS) {
+ return LHS.Power > RHS.Power;
+ }
+ };
+
+ /// \brief Compare factors for equal powers.
+ struct PowerEqual {
+ bool operator()(const Factor &LHS, const Factor &RHS) {
+ return LHS.Power == RHS.Power;
+ }
+ };
+ };
+}
+
+namespace {
class Reassociate : public FunctionPass {
DenseMap<BasicBlock*, unsigned> RankMap;
DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
@@ -98,6 +139,11 @@ namespace {
Value *OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<ValueEntry> &Ops);
Value *OptimizeAdd(Instruction *I, SmallVectorImpl<ValueEntry> &Ops);
+ bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
+ SmallVectorImpl<Factor> &Factors);
+ Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+ SmallVectorImpl<Factor> &Factors);
+ Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
void LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
void LinearizeExpr(BinaryOperator *I);
Value *RemoveFactorFromExpression(Value *V, Value *Factor);
@@ -888,6 +934,199 @@ Value *Reassociate::OptimizeAdd(Instruction *I,
return 0;
}
+namespace {
+ /// \brief Predicate tests whether a ValueEntry's op is in a map.
+ struct IsValueInMap {
+ const DenseMap<Value *, unsigned> &Map;
+
+ IsValueInMap(const DenseMap<Value *, unsigned> &Map) : Map(Map) {}
+
+ bool operator()(const ValueEntry &Entry) {
+ return Map.find(Entry.Op) != Map.end();
+ }
+ };
+}
+
+/// \brief Build up a vector of value/power pairs factoring a product.
+///
+/// Given a series of multiplication operands, build a vector of factors and
+/// the powers each is raised to when forming the final product. Sort them in
+/// the order of descending power.
+///
+/// (x*x) -> [(x, 2)]
+/// ((x*x)*x) -> [(x, 3)]
+/// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)]
+///
+/// \returns Whether any factors have a power greater than one.
+bool Reassociate::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
+ SmallVectorImpl<Factor> &Factors) {
+ unsigned FactorPowerSum = 0;
+ DenseMap<Value *, unsigned> FactorCounts;
+ for (unsigned LastIdx = 0, Idx = 0, Size = Ops.size(); Idx < Size; ++Idx) {
+ // Note that 'use_empty' uses means the only use is in the linearized tree
+ // represented by Ops -- we remove the values from the actual operations to
+ // reduce their use count.
+ if (!Ops[Idx].Op->use_empty()) {
+ if (LastIdx == Idx)
+ ++LastIdx;
+ continue;
+ }
+ if (LastIdx == Idx || Ops[LastIdx].Op != Ops[Idx].Op) {
+ LastIdx = Idx;
+ continue;
+ }
+ // Track for simplification all factors which occur 2 or more times.
+ DenseMap<Value *, unsigned>::iterator CountIt;
+ bool Inserted;
+ llvm::tie(CountIt, Inserted)
+ = FactorCounts.insert(std::make_pair(Ops[Idx].Op, 2));
+ if (Inserted) {
+ FactorPowerSum += 2;
+ Factors.push_back(Factor(Ops[Idx].Op, 2));
+ } else {
+ ++CountIt->second;
+ ++FactorPowerSum;
+ }
+ }
+ // We can only simplify factors if the sum of the powers of our simplifiable
+ // factors is 4 or higher. When that is the case, we will *always* have
+ // a simplification. This is an important invariant to prevent cyclicly
+ // trying to simplify already minimal formations.
+ if (FactorPowerSum < 4)
+ return false;
+
+ // Remove all the operands which are in the map.
+ Ops.erase(std::remove_if(Ops.begin(), Ops.end(), IsValueInMap(FactorCounts)),
+ Ops.end());
+
+ // Record the adjusted power for the simplification factors. We add back into
+ // the Ops list any values with an odd power, and make the power even. This
+ // allows the outer-most multiplication tree to remain in tact during
+ // simplification.
+ unsigned OldOpsSize = Ops.size();
+ for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) {
+ Factors[Idx].Power = FactorCounts[Factors[Idx].Base];
+ if (Factors[Idx].Power & 1) {
+ Ops.push_back(ValueEntry(getRank(Factors[Idx].Base), Factors[Idx].Base));
+ --Factors[Idx].Power;
+ --FactorPowerSum;
+ }
+ }
+ // None of the adjustments above should have reduced the sum of factor powers
+ // below our mininum of '4'.
+ assert(FactorPowerSum >= 4);
+
+ // Patch up the sort of the ops vector by sorting the factors we added back
+ // onto the back, and merging the two sequences.
+ if (OldOpsSize != Ops.size()) {
+ SmallVectorImpl<ValueEntry>::iterator MiddleIt = Ops.begin() + OldOpsSize;
+ std::sort(MiddleIt, Ops.end());
+ std::inplace_merge(Ops.begin(), MiddleIt, Ops.end());
+ }
+
+ std::sort(Factors.begin(), Factors.end(), Factor::PowerDescendingSorter());
+ return true;
+}
+
+/// \brief Build a tree of multiplies, computing the product of Ops.
+static Value *buildMultiplyTree(IRBuilder<> &Builder,
+ SmallVectorImpl<Value*> &Ops) {
+ if (Ops.size() == 1)
+ return Ops.back();
+
+ Value *LHS = Ops.pop_back_val();
+ do {
+ LHS = Builder.CreateMul(LHS, Ops.pop_back_val());
+ } while (!Ops.empty());
+
+ return LHS;
+}
+
+/// \brief Build a minimal multiplication DAG for (a^x)*(b^y)*(c^z)*...
+///
+/// Given a vector of values raised to various powers, where no two values are
+/// equal and the powers are sorted in decreasing order, compute the minimal
+/// DAG of multiplies to compute the final product, and return that product
+/// value.
+Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+ SmallVectorImpl<Factor> &Factors) {
+ assert(Factors[0].Power);
+ SmallVector<Value *, 4> OuterProduct;
+ for (unsigned LastIdx = 0, Idx = 1, Size = Factors.size();
+ Idx < Size && Factors[Idx].Power > 0; ++Idx) {
+ if (Factors[Idx].Power != Factors[LastIdx].Power) {
+ LastIdx = Idx;
+ continue;
+ }
+
+ // We want to multiply across all the factors with the same power so that
+ // we can raise them to that power as a single entity. Build a mini tree
+ // for that.
+ SmallVector<Value *, 4> InnerProduct;
+ InnerProduct.push_back(Factors[LastIdx].Base);
+ do {
+ InnerProduct.push_back(Factors[Idx].Base);
+ ++Idx;
+ } while (Idx < Size && Factors[Idx].Power == Factors[LastIdx].Power);
+
+ // Reset the base value of the first factor to the new expression tree.
+ // We'll remove all the factors with the same power in a second pass.
+ Factors[LastIdx].Base
+ = ReassociateExpression(
+ cast<BinaryOperator>(buildMultiplyTree(Builder, InnerProduct)));
+
+ LastIdx = Idx;
+ }
+ // Unique factors with equal powers -- we've folded them into the first one's
+ // base.
+ Factors.erase(std::unique(Factors.begin(), Factors.end(),
+ Factor::PowerEqual()),
+ Factors.end());
+
+ // Iteratively collect the base of each factor with an add power into the
+ // outer product, and halve each power in preparation for squaring the
+ // expression.
+ for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) {
+ if (Factors[Idx].Power & 1)
+ OuterProduct.push_back(Factors[Idx].Base);
+ Factors[Idx].Power >>= 1;
+ }
+ if (Factors[0].Power) {
+ Value *SquareRoot = buildMinimalMultiplyDAG(Builder, Factors);
+ OuterProduct.push_back(SquareRoot);
+ OuterProduct.push_back(SquareRoot);
+ }
+ if (OuterProduct.size() == 1)
+ return OuterProduct.front();
+
+ return ReassociateExpression(
+ cast<BinaryOperator>(buildMultiplyTree(Builder, OuterProduct)));
+}
+
+Value *Reassociate::OptimizeMul(BinaryOperator *I,
+ SmallVectorImpl<ValueEntry> &Ops) {
+ // We can only optimize the multiplies when there is a chain of more than
+ // three, such that a balanced tree might require fewer total multiplies.
+ if (Ops.size() < 4)
+ return 0;
+
+ // Try to turn linear trees of multiplies without other uses of the
+ // intermediate stages into minimal multiply DAGs with perfect sub-expression
+ // re-use.
+ SmallVector<Factor, 4> Factors;
+ if (!collectMultiplyFactors(Ops, Factors))
+ return 0; // All distinct factors, so nothing left for us to do.
+
+ IRBuilder<> Builder(I);
+ Value *V = buildMinimalMultiplyDAG(Builder, Factors);
+ if (Ops.empty())
+ return V;
+
+ ValueEntry NewEntry = ValueEntry(getRank(V), V);
+ Ops.insert(std::lower_bound(Ops.begin(), Ops.end(), NewEntry), NewEntry);
+ return 0;
+}
+
Value *Reassociate::OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<ValueEntry> &Ops) {
// Now that we have the linearized expression tree, try to optimize it.
@@ -937,30 +1176,28 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I,
// Handle destructive annihilation due to identities between elements in the
// argument list here.
+ unsigned NumOps = Ops.size();
switch (Opcode) {
default: break;
case Instruction::And:
case Instruction::Or:
- case Instruction::Xor: {
- unsigned NumOps = Ops.size();
+ case Instruction::Xor:
if (Value *Result = OptimizeAndOrXor(Opcode, Ops))
return Result;
- IterateOptimization |= Ops.size() != NumOps;
break;
- }
- case Instruction::Add: {
- unsigned NumOps = Ops.size();
+ case Instruction::Add:
if (Value *Result = OptimizeAdd(I, Ops))
return Result;
- IterateOptimization |= Ops.size() != NumOps;
- }
+ break;
+ case Instruction::Mul:
+ if (Value *Result = OptimizeMul(I, Ops))
+ return Result;
break;
- //case Instruction::Mul:
}
- if (IterateOptimization)
+ if (IterateOptimization || Ops.size() != NumOps)
return OptimizeExpression(I, Ops);
return 0;
}