summaryrefslogtreecommitdiff
path: root/lib/Target/AArch64/AArch64AddressTypePromotion.cpp
blob: 04906f6078f8b13c3175d457b52fe4e1c9b05817 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
//===-- AArch64AddressTypePromotion.cpp --- Promote type for addr accesses -==//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass tries to promote the computations used to obtain a sign-extended
// value that is used in memory accesses.
// E.g.
// a = add nsw i32 b, 3
// d = sext i32 a to i64
// e = getelementptr ..., i64 d
//
// =>
// f = sext i32 b to i64
// a = add nsw i64 f, 3
// e = getelementptr ..., i64 a
//
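// This is legal only if the computations are marked with the nsw flag: sign
// extension distributes over arithmetic that cannot overflow in the signed
// sense, whereas the nuw flag alone does not make the promotion legal.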
// Moreover, the current heuristic is simple: it does not create new sext
// operations, i.e., it gives up when a sext would have forked (e.g., if
// a = add i32 b, c, two sexts are required to promote the computation).
//
// FIXME: This pass may be useful for other targets too.
// ===---------------------------------------------------------------------===//

#include "AArch64.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

using namespace llvm;

#define DEBUG_TYPE "aarch64-type-promotion"

/// Master switch: when false, runOnFunction returns immediately.
static cl::opt<bool> EnableAddressTypePromotion(
    "aarch64-type-promotion", cl::Hidden,
    cl::desc("Enable the type promotion pass"), cl::init(true));

/// When true, sign extensions of the same value are merged into the one that
/// dominates the others (see mergeSExts).
static cl::opt<bool> EnableMerge(
    "aarch64-type-promotion-merge", cl::Hidden,
    cl::desc("Enable merging of redundant sexts when one is dominating"
             " the other."),
    cl::init(true));

//===----------------------------------------------------------------------===//
//                       AArch64AddressTypePromotion
//===----------------------------------------------------------------------===//

namespace llvm {
void initializeAArch64AddressTypePromotionPass(PassRegistry &);
}

namespace {
/// Promotes the computations whose result is sign extended to i64 and fed to
/// a getelementptr: the sext is moved as far up the def chain as is legal and
/// profitable, so that the arithmetic itself is performed in 64 bits.
class AArch64AddressTypePromotion : public FunctionPass {

public:
  static char ID;

  AArch64AddressTypePromotion()
      : FunctionPass(ID), Func(nullptr), ConsideredSExtType(nullptr) {
    initializeAArch64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
  }

  const char *getPassName() const override {
    return "AArch64 Address Type Promotion";
  }

  /// Promote the computations feeding the interesting sext instructions of
  /// \p F. The return value is the one computed by propagateSignExtension.
  bool runOnFunction(Function &F) override;

private:
  typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
  typedef SmallVector<Instruction *, 16> Instructions;
  typedef DenseMap<Value *, Instructions> ValueToInsts;

  /// Function being processed (set by runOnFunction).
  Function *Func;
  /// Only sexts producing this type are candidates; runOnFunction sets it to
  /// i64.
  Type *ConsideredSExtType;

  /// Requires the dominator tree (queried by mergeSExts) and preserves both
  /// the CFG and the dominator tree.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  // ---- Legality / profitability queries ----

  /// Legality: can a sext be moved through \p Inst? Accepted are another
  /// sext, an arithmetic operation carrying the required no-wrap flag (see
  /// the definition for the exact flag), and a trunc that only drops
  /// sign-extended bits.
  bool canGetThrough(const Instruction *Inst);

  /// Profitability: moving a sext through \p Inst must not require inserting
  /// truncates (multiple uses) nor creating more than one sext for its
  /// operands. See the definition for the exact rules.
  bool shouldGetThrough(const Instruction *Inst);

  /// True if \p SExt produces ConsideredSExtType and is used by at least one
  /// getelementptr.
  bool shouldConsiderSExt(const Instruction *SExt) const;

  // ---- Transformation steps ----

  /// Fill \p SExtInsts with the interesting sexts: considered sexts that are
  /// used by a getelementptr with more than 2 operands, or whose promotable
  /// chain starts at the same value as another considered sext.
  void analyzeSExtension(Instructions &SExtInsts);

  /// Move each sext of \p SExtInsts through its safe-to-extend definitions.
  bool propagateSignExtension(Instructions &SExtInsts);

  /// Replace the sexts of a same value that are dominated by another one with
  /// the dominating sext; the dead ones are recorded in \p ToRemove.
  void mergeSExts(ValueToInsts &ValToSExtendedUses,
                  SetOfInstructions &ToRemove);
};
} // end anonymous namespace.

// The address of ID (not its value) identifies the pass for the legacy pass
// manager.
char AArch64AddressTypePromotion::ID = 0;

// Register the pass under the "aarch64-type-promotion" name and declare its
// dependency on the dominator tree. This expansion also provides the
// initializeAArch64AddressTypePromotionPass function declared above.
INITIALIZE_PASS_BEGIN(AArch64AddressTypePromotion, "aarch64-type-promotion",
                      "AArch64 Type Promotion Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(AArch64AddressTypePromotion, "aarch64-type-promotion",
                    "AArch64 Type Promotion Pass", false, false)

/// Create a fresh instance of the address type promotion pass.
FunctionPass *llvm::createAArch64AddressTypePromotionPass() {
  AArch64AddressTypePromotion *Pass = new AArch64AddressTypePromotion();
  return Pass;
}

/// Return true if it is legal to move a sign extension through \p Inst, i.e.,
/// sext(Inst(x, ...)) may be rewritten as Inst'(sext(x), ...) where Inst' is
/// Inst computed in the wider type. This holds for:
/// - another sext: the two extensions fold into one;
/// - an overflowing binary operator (add/sub/mul/shl) carrying the nsw flag:
///   sign extension distributes over arithmetic that cannot overflow in the
///   signed sense. The nuw flag alone is NOT sufficient, e.g., in i32:
///     sext(add nuw 0x7fffffff, 1) = 0xffffffff80000000
///   whereas
///     add(sext 0x7fffffff, 1)     = 0x0000000080000000;
/// - a trunc of a sext that only drops bits introduced by that sext:
///   sext(trunc(sext x)) == sext x.
bool AArch64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
  if (isa<SExtInst>(Inst))
    return true;

  const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
  if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
      BinOp->hasNoSignedWrap())
    return true;

  // sext(trunc(sext)) --> sext
  if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
    const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
    // Check that the truncate just drops sign extended bits: the truncated
    // type is at least as wide as the value the inner sext extended, and the
    // inner sext is not wider than the type we promote to.
    if (Inst->getType()->getIntegerBitWidth() >=
            Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
        Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
            ConsideredSExtType->getIntegerBitWidth())
      return true;
  }

  return false;
}

/// Return true if moving a sext through \p Inst is profitable on its own:
/// no truncate has to be inserted to preserve the original value, and at most
/// one operand of \p Inst needs a sext afterwards.
bool AArch64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
  const bool HasSingleUse = Inst->hasOneUse();

  if (isa<SExtInst>(Inst)) {
    // A sext to the considered type becomes useless once we go through it;
    // any other sext is fine as long as nobody else needs its value.
    if (Inst->getType() == ConsideredSExtType || HasSingleUse)
      return true;
    // Otherwise it has several uses and is rejected just below.
  }

  // Several users would require inserting truncates, which we do not do.
  if (!HasSingleUse)
    return false;

  // A single-use truncate we can get through simply disappears.
  if (isa<TruncInst>(Inst))
    return true;

  // With a constant second operand, only the first operand needs a sext, so
  // the number of sexts does not grow. Every step must be profitable by
  // itself: we never create extra sexts hoping it pays off later.
  return isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1));
}

/// Return true if operand \p OpIdx of \p Inst must be sign extended when
/// \p Inst is promoted; the condition of a select (operand 0) keeps its type.
static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
  const bool IsSelectCondition = OpIdx == 0 && isa<SelectInst>(Inst);
  return !IsSelectCondition;
}

/// Return true if \p SExt produces ConsideredSExtType and at least one of its
/// users is a getelementptr, i.e., it feeds an address computation.
bool
AArch64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
  if (SExt->getType() != ConsideredSExtType)
    return false;

  // Inspect the *user* of each use. Dereferencing a Use converts it to the
  // used value, i.e., SExt itself, which is never a getelementptr.
  for (const Use &U : SExt->uses()) {
    if (isa<GetElementPtrInst>(U.getUser()))
      return true;
  }

  return false;
}

// Input:
// - SExtInsts contains all the sext instructions that are used directly in
//   GetElementPtrInst, i.e., access to memory.
// Algorithm:
// - For each sext operation in SExtInsts:
//   Let var be the operand of sext.
//   while it is profitable (see shouldGetThrough), legal, and safe
//   (see canGetThrough) to move sext through var's definition:
//   * promote the type of var's definition.
//   * fold var into sext uses.
//   * move sext above var's definition.
//   * update sext operand to use the operand of var that should be sign
//     extended (by construction there is only one).
//
//   E.g.,
//   a = ... i32 c, 3
//   b = sext i32 a to i64 <- is it legal/safe/profitable to get through 'a'
//   ...
//   = b
// => Yes, update the code
//   b = sext i32 c to i64
//   a = ... i64 b, 3
//   ...
//   = a
// Iterate on 'c'.
// - Then, if EnableMerge is set, redundant sexts of the same value are merged
//   and every instruction made dead is erased.
// Returns true if the IR was modified.
bool
AArch64AddressTypePromotion::propagateSignExtension(Instructions &SExtInsts) {
  DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");

  bool LocalChange = false;
  SetOfInstructions ToRemove;
  ValueToInsts ValToSExtendedUses;
  while (!SExtInsts.empty()) {
    // Get through simple chain.
    Instruction *SExt = SExtInsts.pop_back_val();

    DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');

    // If this SExt has already been merged continue.
    if (SExt->use_empty() && ToRemove.count(SExt)) {
      DEBUG(dbgs() << "No uses => marked as delete\n");
      continue;
    }

    // Now try to get through the chain of definitions.
    while (isa<Instruction>(SExt->getOperand(0))) {
      Instruction *Inst = cast<Instruction>(SExt->getOperand(0));
      DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
      if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
        // We cannot get through something that is not an Instruction
        // or not safe to SExt.
        DEBUG(dbgs() << "Cannot get through\n");
        break;
      }

      LocalChange = true;
      // If this is a sign extend, it becomes useless.
      if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
        DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
        // We cannot use replaceAllUsesWith here because we may trigger some
        // assertion on the type as all involved sext operation may have not
        // been moved yet.
        // The instruction to rewrite is the *user* of each use: dereferencing
        // a Use would yield Inst itself and rewrite the wrong instruction.
        while (!Inst->use_empty()) {
          Value::use_iterator UseIt = Inst->use_begin();
          Instruction *UserInst = dyn_cast<Instruction>(UseIt->getUser());
          assert(UserInst && "Use of sext is not an Instruction!");
          UserInst->setOperand(UseIt->getOperandNo(), SExt);
        }
        ToRemove.insert(Inst);
        SExt->setOperand(0, Inst->getOperand(0));
        SExt->moveBefore(Inst);
        continue;
      }

      // Get through the Instruction:
      // 1. Update its type.
      // 2. Replace the uses of SExt by Inst.
      // 3. Sign extend each operand that needs to be sign extended.

      // Step #1.
      Inst->mutateType(SExt->getType());
      // Step #2.
      SExt->replaceAllUsesWith(Inst);
      // Step #3.
      Instruction *SExtForOpnd = SExt;

      DEBUG(dbgs() << "Propagate SExt to operands\n");
      for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
           ++OpIdx) {
        DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
        if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
            !shouldSExtOperand(Inst, OpIdx)) {
          DEBUG(dbgs() << "No need to propagate\n");
          continue;
        }
        // Check if we can statically sign extend the operand.
        Value *Opnd = Inst->getOperand(OpIdx);
        if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
          DEBUG(dbgs() << "Statically sign extend\n");
          Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
                                                         Cst->getSExtValue()));
          continue;
        }
        // UndefValue are typed, so we have to statically sign extend them.
        if (isa<UndefValue>(Opnd)) {
          DEBUG(dbgs() << "Statically sign extend\n");
          Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
          continue;
        }

        // Otherwise we have to explicitly sign extend it.
        assert(SExtForOpnd &&
               "Only one operand should have been sign extended");

        SExtForOpnd->setOperand(0, Opnd);

        DEBUG(dbgs() << "Move before:\n" << *Inst << "\nSign extend\n");
        // Move the sign extension before the insertion point.
        SExtForOpnd->moveBefore(Inst);
        Inst->setOperand(OpIdx, SExtForOpnd);
        // If more sext are required, new instructions will have to be created.
        SExtForOpnd = nullptr;
      }
      if (SExtForOpnd == SExt) {
        DEBUG(dbgs() << "Sign extension is useless now\n");
        ToRemove.insert(SExt);
        break;
      }
    }

    // If the use is already of the right type, connect its uses to its argument
    // and delete it.
    // This can happen for an Instruction which all uses are sign extended.
    if (!ToRemove.count(SExt) &&
        SExt->getType() == SExt->getOperand(0)->getType()) {
      DEBUG(dbgs() << "Sign extension is useless, attach its use to "
                      "its argument\n");
      SExt->replaceAllUsesWith(SExt->getOperand(0));
      ToRemove.insert(SExt);
    } else
      ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
  }

  if (EnableMerge)
    mergeSExts(ValToSExtendedUses, ToRemove);

  // mergeSExts may rewrite uses even when nothing was promoted above, so any
  // instruction about to be erased also means the IR changed.
  bool Changed = LocalChange || !ToRemove.empty();

  // Remove all instructions marked as ToRemove.
  for (Instruction *I: ToRemove)
    I->eraseFromParent();
  return Changed;
}

/// For each value of \p ValToSExtendedUses, merge its sign extensions: when
/// one dominates another, the dominated one has its uses redirected to the
/// dominating one and is added to \p ToRemove. Sexts that could only be
/// merged in a common dominator are left alone.
void AArch64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
                                             SetOfInstructions &ToRemove) {
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  for (auto &Entry : ValToSExtendedUses) {
    Instructions &Insts = Entry.second;
    // Surviving sexts of the current value; none of them dominates another.
    Instructions CurPts;
    for (Instruction *Inst : Insts) {
      if (ToRemove.count(Inst))
        continue;
      bool Inserted = false;
      // Iterate by reference: when Inst dominates Pt, Pt must be replaced by
      // Inst inside CurPts. With a copy, the now-dead Pt would stay in CurPts
      // and later sexts could be redirected to an instruction that is about
      // to be erased.
      for (auto &Pt : CurPts) {
        if (DT.dominates(Inst, Pt)) {
          DEBUG(dbgs() << "Replace all uses of:\n" << *Pt << "\nwith:\n"
                       << *Inst << '\n');
          Pt->replaceAllUsesWith(Inst);
          ToRemove.insert(Pt);
          Pt = Inst;
          Inserted = true;
          break;
        }
        if (!DT.dominates(Pt, Inst))
          // Give up if we need to merge in a common dominator as the
          // experiments show it is not profitable.
          continue;

        DEBUG(dbgs() << "Replace all uses of:\n" << *Inst << "\nwith:\n"
                     << *Pt << '\n');
        Inst->replaceAllUsesWith(Pt);
        ToRemove.insert(Inst);
        Inserted = true;
        break;
      }
      if (!Inserted)
        CurPts.push_back(Inst);
    }
  }
}

/// Collect in \p SExtInsts the sexts worth promoting: sexts accepted by
/// shouldConsiderSExt that either are used by a getelementptr with more than
/// 2 operands, or whose promotable chain starts at the same value as another
/// considered sext.
void AArch64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
  DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");

  // Head of chain -> first sext seen with that head that was not inserted yet
  // (nullptr once it has been inserted).
  DenseMap<Value *, Instruction *> SeenChains;

  for (auto &BB : *Func) {
    for (auto &II : BB) {
      Instruction *SExt = &II;

      // Collect all sext operation per type.
      if (!isa<SExtInst>(SExt) || !shouldConsiderSExt(SExt))
        continue;

      DEBUG(dbgs() << "Found:\n" << (*SExt) << '\n');

      // Cases where we actually perform the optimization:
      // 1. SExt is used in a getelementptr with more than 2 operands =>
      //    likely we can merge some computation if they are done on 64 bits.
      // 2. The beginning of the SExt chain is SExt several times. =>
      //    code sharing is possible.

      bool insert = false;
      // #1.
      // Inspect the user of each use: the used value is SExt itself.
      for (const Use &U : SExt->uses()) {
        const Instruction *GEP = dyn_cast<GetElementPtrInst>(U.getUser());
        if (GEP && GEP->getNumOperands() > 2) {
          DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" << *GEP
                       << '\n');
          insert = true;
          break;
        }
      }

      // #2.
      // Check the head of the chain.
      Instruction *Inst = SExt;
      Value *Last;
      do {
        int OpdIdx = 0;
        const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
        if (BinOp && isa<ConstantInt>(BinOp->getOperand(0)))
          OpdIdx = 1;
        Last = Inst->getOperand(OpdIdx);
        Inst = dyn_cast<Instruction>(Last);
      } while (Inst && canGetThrough(Inst) && shouldGetThrough(Inst));

      DEBUG(dbgs() << "Head of the chain:\n" << *Last << '\n');
      DenseMap<Value *, Instruction *>::iterator AlreadySeen =
          SeenChains.find(Last);
      if (insert || AlreadySeen != SeenChains.end()) {
        DEBUG(dbgs() << "Insert\n");
        SExtInsts.push_back(SExt);
        if (AlreadySeen != SeenChains.end() && AlreadySeen->second != nullptr) {
          DEBUG(dbgs() << "Insert chain member\n");
          SExtInsts.push_back(AlreadySeen->second);
          SeenChains[Last] = nullptr;
        }
      } else {
        DEBUG(dbgs() << "Record its chain membership\n");
        SeenChains[Last] = SExt;
      }
    }
  }
}

/// Entry point: gather the interesting i64 sexts of \p F, then promote the
/// computations feeding them. Returns what propagateSignExtension reports.
bool AArch64AddressTypePromotion::runOnFunction(Function &F) {
  const bool Skip = !EnableAddressTypePromotion || F.isDeclaration();
  if (Skip)
    return false;

  Func = &F;
  ConsideredSExtType = Type::getInt64Ty(F.getContext());
  DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << '\n');

  Instructions Candidates;
  analyzeSExtension(Candidates);
  return propagateSignExtension(Candidates);
}