summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/Reassociate.cpp
diff options
context:
space:
mode:
authorDuncan Sands <baldrick@free.fr>2012-06-12 14:33:56 +0000
committerDuncan Sands <baldrick@free.fr>2012-06-12 14:33:56 +0000
commitc038a7833565ecf92a699371d448135a097c9e2f (patch)
tree3df170cfa4feeaad9409335ee997d76551bcf483 /lib/Transforms/Scalar/Reassociate.cpp
parent3f696e568bae8afa5986e7af48156c2bac041ba7 (diff)
downloadllvm-c038a7833565ecf92a699371d448135a097c9e2f.tar.gz
llvm-c038a7833565ecf92a699371d448135a097c9e2f.tar.bz2
llvm-c038a7833565ecf92a699371d448135a097c9e2f.tar.xz
Now that Reassociate's LinearizeExprTree can look through arbitrary expression
topologies, it is quite possible for a leaf node to have huge multiplicity, for example: x0 = x*x, x1 = x0*x0, x2 = x1*x1, ... rapidly gives a value which is x raised to a vast power (the multiplicity, or weight, of x). This patch fixes the computation of weights by correctly computing them no matter how big they are, rather than just overflowing and getting a wrong value. It turns out that the weight for a value never needs more bits to represent than the value itself, so it is enough to represent weights as APInts of the same bitwidth and do the right overflow-avoiding dance steps when computing weights. As a side-effect it reduces the number of multiplies needed in some cases of large powers. While there, in view of external uses (eg by the vectorizer) I made LinearizeExprTree static, pushing the rank computation out into users. This is progress towards fixing PR13021. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@158358 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Transforms/Scalar/Reassociate.cpp')
-rw-r--r--lib/Transforms/Scalar/Reassociate.cpp229
1 files changed, 204 insertions, 25 deletions
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index 275d19adbc..50de946a30 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -143,7 +143,6 @@ namespace {
Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder,
SmallVectorImpl<Factor> &Factors);
Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
- void LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
Value *RemoveFactorFromExpression(Value *V, Value *Factor);
void EraseInst(Instruction *I);
void OptimizeInst(Instruction *I);
@@ -251,10 +250,148 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
return Res;
}
+/// CarmichaelShift - Returns k such that lambda(2^Bitwidth) = 2^k, where lambda
+/// is the Carmichael function. This means that x^(2^k) === 1 mod 2^Bitwidth for
+/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic.
+/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every
+/// even x in Bitwidth-bit arithmetic.
+static unsigned CarmichaelShift(unsigned Bitwidth) {
+ if (Bitwidth < 3)
+ return Bitwidth - 1;
+ return Bitwidth - 2;
+}
+
+/// IncorporateWeight - Add the extra weight 'RHS' to the existing weight 'LHS',
+/// reducing the combined weight using any special properties of the operation.
+/// The existing weight LHS represents the computation X op X op ... op X where
+/// X occurs LHS times. The combined weight represents X op X op ... op X with
+/// X occurring LHS + RHS times. If op is "Xor" for example then the combined
+/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even;
+/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second.
+static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
+ // If we were working with infinite precision arithmetic then the combined
+ // weight would be LHS + RHS. But we are using finite precision arithmetic,
+ // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct
+ // for nilpotent operations and addition, but not for idempotent operations
+ // and multiplication), so it is important to correctly reduce the combined
+ // weight back into range if wrapping would be wrong.
+
+ // If RHS is zero then the weight didn't change.
+ if (RHS.isMinValue())
+ return;
+ // If LHS is zero then the combined weight is RHS.
+ if (LHS.isMinValue()) {
+ LHS = RHS;
+ return;
+ }
+ // From this point on we know that neither LHS nor RHS is zero.
+
+ if (Instruction::isIdempotent(Opcode)) {
+ // Idempotent means X op X === X, so any non-zero weight is equivalent to a
+ // weight of 1. Keeping weights at zero or one also means that wrapping is
+ // not a problem.
+ assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
+ return; // Return a weight of 1.
+ }
+ if (Instruction::isNilpotent(Opcode)) {
+ // Nilpotent means X op X === 0, so reduce weights modulo 2.
+ assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
+ LHS = 0; // 1 + 1 === 0 modulo 2.
+ return;
+ }
+ if (Opcode == Instruction::Add) {
+ // TODO: Reduce the weight by exploiting nsw/nuw?
+ LHS += RHS;
+ return;
+ }
+
+ assert(Opcode == Instruction::Mul && "Unknown associative operation!");
+ unsigned Bitwidth = LHS.getBitWidth();
+ // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth
+ // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth
+ // bit number x, since either x is odd in which case x^CM = 1, or x is even in
+ // which case both x^W and x^(W - CM) are zero. By subtracting off multiples
+ // of CM like this weights can always be reduced to the range [0, CM+Bitwidth)
+ // which by a happy accident means that they can always be represented using
+ // Bitwidth bits.
+ // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than
+ // the Carmichael number).
+ if (Bitwidth > 3) {
+ /// CM - The value of Carmichael's lambda function.
+ APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
+ // Any weight W >= Threshold can be replaced with W - CM.
+ APInt Threshold = CM + Bitwidth;
+ assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
+ // For Bitwidth 4 or more the following sum does not overflow.
+ LHS += RHS;
+ while (LHS.uge(Threshold))
+ LHS -= CM;
+ } else {
+ // To avoid problems with overflow do everything the same as above but using
+ // a larger type.
+ unsigned CM = 1U << CarmichaelShift(Bitwidth);
+ unsigned Threshold = CM + Bitwidth;
+ assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
+ "Weights not reduced!");
+ unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
+ while (Total >= Threshold)
+ Total -= CM;
+ LHS = Total;
+ }
+}
+
+/// EvaluateRepeatedConstant - Compute C op C op ... op C where the constant C
+/// is repeated Weight times.
+static Constant *EvaluateRepeatedConstant(unsigned Opcode, Constant *C,
+ APInt Weight) {
+ // For addition the result can be efficiently computed as the product of the
+ // constant and the weight.
+ if (Opcode == Instruction::Add)
+ return ConstantExpr::getMul(C, ConstantInt::get(C->getContext(), Weight));
+
+ // The weight might be huge, so compute by repeated squaring to ensure that
+ // compile time is proportional to the logarithm of the weight.
+ Constant *Result = 0;
+ Constant *Power = C; // Successively C, C op C, (C op C) op (C op C) etc.
+ // Visit the bits in Weight.
+ while (Weight != 0) {
+ // If the current bit in Weight is non-zero do Result = Result op Power.
+ if (Weight[0])
+ Result = Result ? ConstantExpr::get(Opcode, Result, Power) : Power;
+ // Move on to the next bit if any more are non-zero.
+ Weight = Weight.lshr(1);
+ if (Weight.isMinValue())
+ break;
+ // Square the power.
+ Power = ConstantExpr::get(Opcode, Power, Power);
+ }
+
+ assert(Result && "Only positive weights supported!");
+ return Result;
+}
+
+typedef std::pair<Value*, APInt> RepeatedValue;
+
/// LinearizeExprTree - Given an associative binary expression, return the leaf
-/// nodes in Ops. The original expression is the same as Ops[0] op ... Ops[N].
-/// Note that a node may occur multiple times in Ops, but if so all occurrences
-/// are consecutive in the vector.
+/// nodes in Ops along with their weights (how many times the leaf occurs). The
+/// original expression is the same as
+/// (Ops[0].first op Ops[0].first op ... Ops[0].first) <- Ops[0].second times
+/// op
+/// (Ops[1].first op Ops[1].first op ... Ops[1].first) <- Ops[1].second times
+/// op
+/// ...
+/// op
+/// (Ops[N].first op Ops[N].first op ... Ops[N].first) <- Ops[N].second times
+///
+/// Note that the values Ops[0].first, ..., Ops[N].first are all distinct, and
+/// they are all non-constant except possibly for the last one, which if it is
+/// constant will have weight one (Ops[N].second === 1).
+///
+/// This routine may modify the function, in which case it returns 'true'. The
+/// changes it makes may well be destructive, changing the value computed by 'I'
+/// to something completely different. Thus if the routine returns 'true' then
+/// you MUST either replace I with a new expression computed from the Ops array,
+/// or use RewriteExprTree to put the values back in.
///
/// A leaf node is either not a binary operation of the same kind as the root
/// node 'I' (i.e. is not a binary operator at all, or is, but with a different
@@ -276,7 +413,7 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
/// + * | F, G
///
/// The leaf nodes are C, E, F and G. The Ops array will contain (maybe not in
-/// that order) C, E, F, F, G, G.
+/// that order) (C, 1), (E, 1), (F, 2), (G, 2).
///
/// The expression is maximal: if some instruction is a binary operator of the
/// same kind as 'I', and all of its uses are non-leaf nodes of the expression,
@@ -287,7 +424,8 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
/// order to ensure that every non-root node in the expression has *exactly one*
/// use by a non-leaf node of the expression. This destruction means that the
/// caller MUST either replace 'I' with a new expression or use something like
-/// RewriteExprTree to put the values back in.
+/// RewriteExprTree to put the values back in if the routine indicates that it
+/// made a change by returning 'true'.
///
/// In the above example either the right operand of A or the left operand of B
/// will be replaced by undef. If it is B's operand then this gives:
@@ -310,9 +448,14 @@ static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
/// of the expression) if it can turn them into binary operators of the right
/// type and thus make the expression bigger.
-void Reassociate::LinearizeExprTree(BinaryOperator *I,
- SmallVectorImpl<ValueEntry> &Ops) {
+static bool LinearizeExprTree(BinaryOperator *I,
+ SmallVectorImpl<RepeatedValue> &Ops) {
DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
+ unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits();
+ unsigned Opcode = I->getOpcode();
+ assert(Instruction::isAssociative(Opcode) &&
+ Instruction::isCommutative(Opcode) &&
+ "Expected an associative and commutative operation!");
// Visit all operands of the expression, keeping track of their weight (the
// number of paths from the expression root to the operand, or if you like
@@ -324,9 +467,9 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
// with their weights, representing a certain number of paths to the operator.
// If an operator occurs in the worklist multiple times then we found multiple
// ways to get to it.
- SmallVector<std::pair<BinaryOperator*, unsigned>, 8> Worklist; // (Op, Weight)
- Worklist.push_back(std::make_pair(I, 1));
- unsigned Opcode = I->getOpcode();
+ SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight)
+ Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1)));
+ bool MadeChange = false;
// Leaves of the expression are values that either aren't the right kind of
// operation (eg: a constant, or a multiply in an add tree), or are, but have
@@ -343,7 +486,7 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
// Leaves - Keeps track of the set of putative leaves as well as the number of
// paths to each leaf seen so far.
- typedef SmallMap<Value*, unsigned, 8> LeafMap;
+ typedef SmallMap<Value*, APInt, 8> LeafMap;
LeafMap Leaves; // Leaf -> Total weight so far.
SmallVector<Value*, 8> LeafOrder; // Ensure deterministic leaf output order.
@@ -351,13 +494,12 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
SmallPtrSet<Value*, 8> Visited; // For sanity checking the iteration scheme.
#endif
while (!Worklist.empty()) {
- std::pair<BinaryOperator*, unsigned> P = Worklist.pop_back_val();
+ std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val();
I = P.first; // We examine the operands of this binary operator.
- assert(P.second >= 1 && "No paths to here, so how did we get here?!");
for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands.
Value *Op = I->getOperand(OpIdx);
- unsigned Weight = P.second; // Number of paths to this operand.
+ APInt Weight = P.second; // Number of paths to this operand.
DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n");
assert(!Op->use_empty() && "No uses, so how did we get to it?!");
@@ -389,7 +531,7 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
assert(Visited.count(Op) && "In leaf map but not visited!");
// Update the number of paths to the leaf.
- It->second += Weight;
+ IncorporateWeight(It->second, Weight, Opcode);
// The leaf already has one use from inside the expression. As we want
// exactly one such use, drop this new use of the leaf.
@@ -450,21 +592,44 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
// The leaves, repeated according to their weights, represent the linearized
// form of the expression.
+ Constant *Cst = 0; // Accumulate constants here.
for (unsigned i = 0, e = LeafOrder.size(); i != e; ++i) {
Value *V = LeafOrder[i];
LeafMap::iterator It = Leaves.find(V);
if (It == Leaves.end())
- // Leaf already output, or node initially thought to be a leaf wasn't.
+ // Node initially thought to be a leaf wasn't.
continue;
assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!");
- unsigned Weight = It->second;
- assert(Weight > 0 && "No paths to this value!");
- // FIXME: Rather than repeating values Weight times, use a vector of
- // (ValueEntry, multiplicity) pairs.
- Ops.append(Weight, ValueEntry(getRank(V), V));
+ APInt Weight = It->second;
+ if (Weight.isMinValue())
+ // Leaf already output or weight reduction eliminated it.
+ continue;
// Ensure the leaf is only output once.
- Leaves.erase(It);
+ It->second = 0;
+ // Glob all constants together into Cst.
+ if (Constant *C = dyn_cast<Constant>(V)) {
+ C = EvaluateRepeatedConstant(Opcode, C, Weight);
+ Cst = Cst ? ConstantExpr::get(Opcode, Cst, C) : C;
+ continue;
+ }
+ // Add non-constant
+ Ops.push_back(std::make_pair(V, Weight));
+ }
+
+ // Add any constants back into Ops, all globbed together and reduced to having
+ // weight 1 for the convenience of users.
+ if (Cst && Cst != ConstantExpr::getBinOpIdentity(Opcode, I->getType()))
+ Ops.push_back(std::make_pair(Cst, APInt(Bitwidth, 1)));
+
+ // For nilpotent operations or addition there may be no operands, for example
+ // because the expression was "X xor X" or consisted of 2^Bitwidth additions:
+ // in both cases the weight reduces to 0 causing the value to be skipped.
+ if (Ops.empty()) {
+ Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType());
+ Ops.push_back(std::make_pair(Identity, APInt(Bitwidth, 1)));
}
+
+ return MadeChange;
}
// RewriteExprTree - Now that the operands for this expression tree are
@@ -775,8 +940,15 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul);
if (!BO) return 0;
+ SmallVector<RepeatedValue, 8> Tree;
+ MadeChange |= LinearizeExprTree(BO, Tree);
SmallVector<ValueEntry, 8> Factors;
- LinearizeExprTree(BO, Factors);
+ Factors.reserve(Tree.size());
+ for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
+ RepeatedValue E = Tree[i];
+ Factors.append(E.second.getZExtValue(),
+ ValueEntry(getRank(E.first), E.first));
+ }
bool FoundFactor = false;
bool NeedsNegate = false;
@@ -1439,8 +1611,15 @@ Value *Reassociate::ReassociateExpression(BinaryOperator *I) {
// First, walk the expression tree, linearizing the tree, collecting the
// operand information.
+ SmallVector<RepeatedValue, 8> Tree;
+ MadeChange |= LinearizeExprTree(I, Tree);
SmallVector<ValueEntry, 8> Ops;
- LinearizeExprTree(I, Ops);
+ Ops.reserve(Tree.size());
+ for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
+ RepeatedValue E = Tree[i];
+ Ops.append(E.second.getZExtValue(),
+ ValueEntry(getRank(E.first), E.first));
+ }
DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n');