summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/TailRecursionElimination.cpp
blob: 447c0ae849ed6fb19d7921eb6e6aa0d78115187a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===//
//
// This file implements tail recursion elimination.
//
// Caveats: The algorithm implemented is trivially simple.  There are several
// improvements that could be made:
//
//  1. If the function has any alloca instructions, these instructions will not
//     remain in the entry block of the function.  Doing this requires analysis
//     to prove that the alloca is not reachable by the recursively invoked
//     function call.
//  2. Tail recursion is only performed if the call immediately precedes the
//     return instruction.  Would it be useful to generalize this somehow?
//  3. TRE is only performed if the function returns void or if the return
//     returns the result returned by the call.  It is possible, but unlikely,
//     that the return returns something else (like constant 0), and can still
//     be TRE'd.  It can be TRE'd if ALL OTHER return instructions in the
//     function return the exact same value.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/Pass.h"
#include "Support/Statistic.h"

namespace {
  // Number of self-recursive tail calls this pass rewrote into branches.
  Statistic<> NumEliminated("tailcallelim", "Number of tail calls removed");

  /// TailCallElim - Function pass that turns a direct self-call which sits
  /// immediately before a return into a branch back to the function's entry.
  class TailCallElim : public FunctionPass {
  public:
    virtual bool runOnFunction(Function &F);
  };

  // Make the pass available under the "tailcallelim" name.
  RegisterOpt<TailCallElim> RegisterTailCallElim("tailcallelim",
                                                 "Tail Call Elimination");
}

/// createTailCallEliminationPass - Factory for the tail recursion elimination
/// pass.  Returns a newly heap-allocated TailCallElim; the receiver is
/// responsible for it (presumably a pass manager — confirm at call sites).
FunctionPass *createTailCallEliminationPass() {
  TailCallElim *Pass = new TailCallElim();
  return Pass;
}


/// runOnFunction - Rewrite every "call F(...); ret" pair in F into an
/// unconditional branch back to the top of F.  A pair qualifies when the call
/// is a direct call to F itself and the ret either has no operand or returns
/// exactly the call's result.
///
/// On the first qualifying pair, a new block named "tailrecurse" is created in
/// front of the old entry block and branches to it.  Each formal argument is
/// then routed through a PHI node placed at the top of the old entry block.
/// Every eliminated call adds its actual arguments as incoming values of those
/// PHIs.
///
/// Returns true if at least one call was eliminated.
bool TailCallElim::runOnFunction(Function &F) {
  // Varargs functions are skipped: the arguments could not be PHI'd correctly.
  if (F.getFunctionType()->isVarArg()) return false;

  BasicBlock *LoopHeader = 0;      // Old entry block; target of the back-edges.
  std::vector<PHINode*> ArgPHIs;   // One PHI per formal argument, in order.
  bool Changed = false;

  for (Function::iterator BB = F.begin(), BBEnd = F.end(); BB != BBEnd; ++BB) {
    // Candidate blocks end in a return that has something in front of it.
    ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator());
    if (Ret == 0 || Ret == BB->begin())
      continue;

    // The instruction just before the return must be a direct call to F...
    CallInst *CI = dyn_cast<CallInst>(Ret->getPrev());
    if (CI == 0 || CI->getCalledFunction() != &F)
      continue;

    // ...and the return must return nothing or exactly the call's value.
    if (Ret->getNumOperands() != 0 && Ret->getReturnValue() != CI)
      continue;

    if (LoopHeader == 0) {
      // First tail call found: place a new entry block ahead of the old one
      // whose only instruction is a branch to the old entry.
      LoopHeader = &F.getEntryBlock();
      BasicBlock *Preheader = new BasicBlock("tailrecurse", LoopHeader);
      Preheader->getInstList().push_back(new BranchInst(LoopHeader));

      // Give every formal argument a PHI at the top of the old entry block.
      // InsertPos is captured once, before any PHI exists, so the PHIs come
      // out in argument order.  The incoming value from the new entry is added
      // only after replaceAllUsesWith, so that use is not redirected to the
      // PHI itself.
      Instruction *InsertPos = LoopHeader->begin();
      for (Function::aiterator AI = F.abegin(), AEnd = F.aend(); AI != AEnd;
           ++AI) {
        PHINode *ArgPHI = new PHINode(AI->getType(), AI->getName()+".tr",
                                      InsertPos);
        AI->replaceAllUsesWith(ArgPHI);
        ArgPHI->addIncoming(AI, Preheader);
        ArgPHIs.push_back(ArgPHI);
      }
    }

    // Feed this call's actual arguments into the matching argument PHIs, with
    // BB as the incoming block.  Argument operands start at index 1 (operand 0
    // is presumably the callee), so operand OpNo maps to PHI OpNo-1.
    for (unsigned OpNo = 1, NumOps = CI->getNumOperands(); OpNo != NumOps;
         ++OpNo)
      ArgPHIs[OpNo-1]->addIncoming(CI->getOperand(OpNo), BB);

    // Insert the back-edge just before the call, then drop the two trailing
    // instructions.  The ret goes first (it is last, and may use the call's
    // value), then the call.
    new BranchInst(LoopHeader, CI);
    BB->getInstList().pop_back();
    BB->getInstList().pop_back();
    Changed = true;
    NumEliminated++;
  }

  return Changed;
}