summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/Reassociate.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/Scalar/Reassociate.cpp')
-rw-r--r--lib/Transforms/Scalar/Reassociate.cpp229
1 files changed, 204 insertions, 25 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index 275d19adbc..50de946a30 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -143,7 +143,6 @@ namespace {
Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder,
SmallVectorImpl<Factor> &Factors);
Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
- void LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
Value *RemoveFactorFromExpression(Value *V, Value *Factor);
void EraseInst(Instruction *I);
void OptimizeInst(Instruction *I);
@@ -251,10 +250,148 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
return Res;
}
+/// CarmichaelShift - Returns k such that lambda(2^Bitwidth) = 2^k, where lambda
+/// is the Carmichael function. This means that x^(2^k) === 1 mod 2^Bitwidth for
+/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.
+/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every
+/// even x in Bitwidth-bit arithmetic.
+static unsigned CarmichaelShift(unsigned Bitwidth) {
+ if (Bitwidth < 3)
+ return Bitwidth - 1;
+ return Bitwidth - 2;
+}
+
+/// IncorporateWeight - Add the extra weight 'RHS' to the existing weight 'LHS',
+/// reducing the combined weight using any special properties of the operation.
+/// The existing weight LHS represents the computation X op X op ... op X where
+/// X occurs LHS times. The combined weight represents X op X op ... op X with
+/// X occurring LHS + RHS times. If op is "Xor" for example then the combined
+/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
+/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second.
+static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
+ // If we were working with infinite precision arithmetic then the combined
+ // weight would be LHS + RHS. But we are using finite precision arithmetic,
+ // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
+ // for nilpotent operations and addition, but not for idempotent operations
+ // and multiplication), so it is important to correctly reduce the combined
+ // weight back into range if wrapping would be wrong.
+
+ // If RHS is zero then the weight didn't change.
+ if (RHS.isMinValue())
+ return;
+ // If LHS is zero then the combined weight is RHS.
+ if (LHS.isMinValue()) {
+ LHS = RHS;
+ return;
+ }
+ // From this point on we know that neither LHS nor RHS is zero.
+
+ if (Instruction::isIdempotent(Opcode)) {
+ // Idempotent means X op X === X, so any non-zero weight is equivalent to a
+ // weight of 1. Keeping weights at zero or one also means that wrapping is
+ // not a problem.
+ assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
+ return; // Return a weight of 1.
+ }
+ if (Instruction::isNilpotent(Opcode)) {
+ // Nilpotent means X op X === 0, so reduce weights modulo 2.
+ assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
+ LHS = 0; // 1 + 1 === 0 modulo 2.
+ return;
+ }
+ if (Opcode == Instruction::Add) {
+ // TODO: Reduce the weight by exploiting nsw/nuw?
+ LHS += RHS;
+ return;
+ }
+
+ assert(Opcode == Instruction::Mul && "Unknown associative operation!");
+ unsigned Bitwidth = LHS.getBitWidth();
+ // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth
+ // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth
+ // bit number x, since either x is odd in which case x^CM = 1, or x is even in
+ // which case both x^W and x^(W - CM) are zero. By subtracting off multiples
+ // of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
+ // which by a happy accident means that they can always be represented using
+ // Bitwidth bits.
+ // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than
+ // the Carmichael number).
+ if (Bitwidth > 3) {
+ /// CM - The value of Carmichael's lambda function.
+ APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
+ // Any weight W >= Threshold can be replaced with W - CM.
+ APInt Threshold = CM + Bitwidth;
+ assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
+ // For Bitwidth 4 or more the following sum does not overflow.
+ LHS += RHS;
+ while (LHS.uge(Threshold))
+ LHS -= CM;
+ } else {
+ // To avoid problems with overflow do everything the same as above but using
+ // a larger type.
+ unsigned CM = 1U << CarmichaelShift(Bitwidth);
+ unsigned Threshold = CM + Bitwidth;
+ assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
+ "Weights not reduced!");
+ unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
+ while (Total >= Threshold)
+ Total -= CM;
+ LHS = Total;
+ }
+}
+
+/// EvaluateRepeatedConstant - Compute C op C op ... op C where the constant C
+/// is repeated Weight times.
+static Constant *EvaluateRepeatedConstant(unsigned Opcode, Constant *C,
+ APInt Weight) {
+ // For addition the result can be efficiently computed as the product of the
+ // constant and the weight.
+ if (Opcode == Instruction::Add)
+ return ConstantExpr::getMul(C, ConstantInt::get(C->getContext(), Weight));
+
+ // The weight might be huge, so compute by repeated squaring to ensure that
+ // compile time is proportional to the logarithm of the weight.
+ Constant *Result = 0;
+ Constant *Power = C; // Successively C, C op C, (C op C) op (C op C) etc.
+ // Visit the bits in Weight.
+ while (Weight != 0) {
+ // If the current bit in Weight is non-zero do Result = Result op Power.
+ if (Weight[0])
+ Result = Result ? ConstantExpr::get(Opcode, Result, Power) : Power;
+ // Move on to the next bit if any more are non-zero.
+ Weight = Weight.lshr(1);
+ if (Weight.isMinValue())
+ break;
+ // Square the power.
+ Power = ConstantExpr::get(Opcode, Power, Power);
+ }
+
+ assert(Result && "Only positive weights supported!");
+ return Result;
+}
+
+typedef std::pair<Value*, APInt> RepeatedValue;
+
/// LinearizeExprTree - Given an associative binary expression, return the leaf
-/// nodes in Ops. The original expression is the same as Ops[0] op ... Ops[N].
-/// Note that a node may occur multiple times in Ops, but if so all occurrences
-/// are consecutive in the vector.
+/// nodes in Ops along with their weights (how many times the leaf occurs). The
+/// original expression is the same as
+/// (Ops[0].first op Ops[0].first op ... Ops[0].first) <- Ops[0].second times
+/// op
+/// (Ops[1].first op Ops[1].first op ... Ops[1].first) <- Ops[1].second times
+/// op
+/// ...
+/// op
+/// (Ops[N].first op Ops[N].first op ... Ops[N].first) <- Ops[N].second times
+///
+/// Note that the values Ops[0].first, ..., Ops[N].first are all distinct, and
+/// they are all non-constant except possibly for the last one, which if it is
+/// constant will have weight one (Ops[N].second === 1).
+///
+/// This routine may modify the function, in which case it returns 'true'. The
+/// changes it makes may well be destructive, changing the value computed by 'I'
+/// to something completely different. Thus if the routine returns 'true' then
+/// you MUST either replace I with a new expression computed from the Ops array,
+/// or use RewriteExprTree to put the values back in.
///
/// A leaf node is either not a binary operation of the same kind as the root
/// node 'I' (i.e. is not a binary operator at all, or is, but with a different
@@ -276,7 +413,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
/// + * | F, G
///
/// The leaf nodes are C, E, F and G. The Ops array will contain (maybe not in
-/// that order) C, E, F, F, G, G.
+/// that order) (C, 1), (E, 1), (F, 2), (G, 2).
///
/// The expression is maximal: if some instruction is a binary operator of the
/// same kind as 'I', and all of its uses are non-leaf nodes of the expression,
@@ -287,7 +424,8 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
/// order to ensure that every non-root node in the expression has *exactly one*
/// use by a non-leaf node of the expression. This destruction means that the
/// caller MUST either replace 'I' with a new expression or use something like
-/// RewriteExprTree to put the values back in.
+/// RewriteExprTree to put the values back in if the routine indicates that it
+/// made a change by returning 'true'.
///
/// In the above example either the right operand of A or the left operand of B
/// will be replaced by undef. If it is B's operand then this gives:
@@ -310,9 +448,14 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
/// of the expression) if it can turn them into binary operators of the right
/// type and thus make the expression bigger.
-void Reassociate::LinearizeExprTree(BinaryOperator *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+static bool LinearizeExprTree(BinaryOperator *I,
+ SmallVectorImpl<RepeatedValue> &Ops) {
DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
+ unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits();
+ unsigned Opcode = I->getOpcode();
+ assert(Instruction::isAssociative(Opcode) &&
+ Instruction::isCommutative(Opcode) &&
+ "Expected an associative and commutative operation!");
// Visit all operands of the expression, keeping track of their weight (the
// number of paths from the expression root to the operand, or if you like
@@ -324,9 +467,9 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
// with their weights, representing a certain number of paths to the operator.
// If an operator occurs in the worklist multiple times then we found multiple
// ways to get to it.
- SmallVector<std::pair<BinaryOperator*, unsigned>, 8> Worklist; // (Op, Weight)
- Worklist.push_back(std::make_pair(I, 1));
- unsigned Opcode = I->getOpcode();
+ SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight)
+ Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1)));
+ bool MadeChange = false;
// Leaves of the expression are values that either aren't the right kind of
// operation (eg: a constant, or a multiply in an add tree), or are, but have
@@ -343,7 +486,7 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
// Leaves - Keeps track of the set of putative leaves as well as the number of
// paths to each leaf seen so far.
- typedef SmallMap<Value*, unsigned, 8> LeafMap;
+ typedef SmallMap<Value*, APInt, 8> LeafMap;
LeafMap Leaves; // Leaf -> Total weight so far.
SmallVector<Value*, 8> LeafOrder; // Ensure deterministic leaf output order.
@@ -351,13 +494,12 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
SmallPtrSet<Value*, 8> Visited; // For sanity checking the iteration scheme.
#endif
while (!Worklist.empty()) {
- std::pair<BinaryOperator*, unsigned> P = Worklist.pop_back_val();
+ std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val();
I = P.first; // We examine the operands of this binary operator.
- assert(P.second >= 1 && "No paths to here, so how did we get here?!");
for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
- unsigned Weight = P.second; // Number of paths to this operand.
+ APInt Weight = P.second; // Number of paths to this operand.
DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n");
assert(!Op->use_empty() && "No uses, so how did we get to it?!");
@@ -389,7 +531,7 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
assert(Visited.count(Op) && "In leaf map but not visited!");
// Update the number of paths to the leaf.
- It->second += Weight;
+ IncorporateWeight(It->second, Weight, Opcode);
// The leaf already has one use from inside the expression. As we want
// exactly one such use, drop this new use of the leaf.
@@ -450,21 +592,44 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
// The leaves, repeated according to their weights, represent the linearized
// form of the expression.
+ Constant *Cst = 0; // Accumulate constants here.
for (unsigned i = 0, e = LeafOrder.size(); i != e; ++i) {
Value *V = LeafOrder[i];
LeafMap::iterator It = Leaves.find(V);
if (It == Leaves.end())
- // Leaf already output, or node initially thought to be a leaf wasn't.
+ // Node initially thought to be a leaf wasn't.
continue;
assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!");
- unsigned Weight = It->second;
- assert(Weight > 0 && "No paths to this value!");
- // FIXME: Rather than repeating values Weight times, use a vector of
- // (ValueEntry, multiplicity) pairs.
- Ops.append(Weight, ValueEntry(getRank(V), V));
+ APInt Weight = It->second;
+ if (Weight.isMinValue())
+ // Leaf already output or weight reduction eliminated it.
+ continue;
// Ensure the leaf is only output once.
- Leaves.erase(It);
+ It->second = 0;
+ // Glob all constants together into Cst.
+ if (Constant *C = dyn_cast<Constant>(V)) {
+ C = EvaluateRepeatedConstant(Opcode, C, Weight);
+ Cst = Cst ? ConstantExpr::get(Opcode, Cst, C) : C;
+ continue;
+ }
+ // Add non-constant
+ Ops.push_back(std::make_pair(V, Weight));
+ }
+
+ // Add any constants back into Ops, all globbed together and reduced to having
+ // weight 1 for the convenience of users.
+ if (Cst && Cst != ConstantExpr::getBinOpIdentity(Opcode, I->getType()))
+ Ops.push_back(std::make_pair(Cst, APInt(Bitwidth, 1)));
+
+ // For nilpotent operations or addition there may be no operands, for example
+ // because the expression was "X xor X" or consisted of 2^Bitwidth additions:
+ // in both cases the weight reduces to 0 causing the value to be skipped.
+ if (Ops.empty()) {
+ Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType());
+ Ops.push_back(std::make_pair(Identity, APInt(Bitwidth, 1)));
}
+
+ return MadeChange;
}
// RewriteExprTree - Now that the operands for this expression tree are
@@ -775,8 +940,15 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul);
if (!BO) return 0;
+ SmallVector<RepeatedValue, 8> Tree;
+ MadeChange |= LinearizeExprTree(BO, Tree);
SmallVector<ValueEntry, 8> Factors;
- LinearizeExprTree(BO, Factors);
+ Factors.reserve(Tree.size());
+ for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
+ RepeatedValue E = Tree[i];
+ Factors.append(E.second.getZExtValue(),
+ ValueEntry(getRank(E.first), E.first));
+ }
bool FoundFactor = false;
bool NeedsNegate = false;
@@ -1439,8 +1611,15 @@ Value *Reassociate::ReassociateExpression(BinaryOperator *I) {
// First, walk the expression tree, linearizing the tree, collecting the
// operand information.
+ SmallVector<RepeatedValue, 8> Tree;
+ MadeChange |= LinearizeExprTree(I, Tree);
SmallVector<ValueEntry, 8> Ops;
- LinearizeExprTree(I, Ops);
+ Ops.reserve(Tree.size());
+ for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
+ RepeatedValue E = Tree[i];
+ Ops.append(E.second.getZExtValue(),
+ ValueEntry(getRank(E.first), E.first));
+ }
DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n');