summaryrefslogtreecommitdiff
path: root/lib/Transforms/IPO/RaiseAllocations.cpp
blob: dcfdb34b91e53fc433ed67884fcdca6fffcfd631 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
//===- RaiseAllocations.cpp - Convert %malloc & %free calls to insts ------===//
//
// This file defines the RaiseAllocations pass, which converts calls to malloc
// and free into malloc and free instructions.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Module.h"
#include "llvm/DerivedTypes.h"
#include "llvm/iMemory.h"
#include "llvm/iOther.h"
#include "llvm/Pass.h"
#include "Support/Statistic.h"

namespace {
  // NumRaised - Counts the malloc and free calls that runOnBasicBlock has
  // replaced with instructions.
  Statistic<> NumRaised("raiseallocs", "Number of allocations raised");

  // RaiseAllocations - Turn calls to the external %malloc and %free functions
  // into the corresponding malloc and free instructions, one basic block at a
  // time.
  //
  class RaiseAllocations : public BasicBlockPass {
    // The external malloc/free declarations found in the module being
    // processed.  Both start out null and are filled in by doInitialization;
    // while a pointer is null, no call matches it and nothing is raised.
    Function *MallocFunc;
    Function *FreeFunc;
  public:
    RaiseAllocations() : MallocFunc(0), FreeFunc(0) {}
    
    // doInitialization - Look up external declarations of malloc and free in
    // M, trying several common prototypes, and cache them in MallocFunc and
    // FreeFunc.  Always returns false.
    //
    bool doInitialization(Module &M);
    
    // runOnBasicBlock - Replace recognized malloc/free calls in BB with the
    // equivalent instructions.  Assumes doInitialization has already run.
    // Returns true if BB was changed.
    //
    bool runOnBasicBlock(BasicBlock &BB);
  };
  
  // Register the pass under the name "raiseallocs".  Presumably this is what
  // exposes it as an optimizer option; RegisterOpt is declared elsewhere.
  RegisterOpt<RaiseAllocations>
  X("raiseallocs", "Raise allocations from calls to instructions");
}  // end anonymous namespace


// createRaiseAllocationsPass - Public entry point of this file: allocate a
// fresh RaiseAllocations pass with new and hand it back as a generic Pass.
// NOTE(review): ownership presumably passes to the caller (e.g. a pass
// manager); confirm against the callers.
Pass *createRaiseAllocationsPass() {
  RaiseAllocations *NewPass = new RaiseAllocations();
  return NewPass;
}


// doInitialization - Cache the module's external declarations of malloc and
// free so that runOnBasicBlock can recognize calls to them.
//
// Each name is looked up under a sequence of prototypes, most specific first,
// stopping at the first match.  A fallback prototype's type is only built if
// every earlier lookup failed.  A match that is a local definition rather
// than an external declaration is discarded, so calls to it are never raised.
// Always returns false.
//
bool RaiseAllocations::doInitialization(Module &M) {
  const Type *BytePtrTy = PointerType::get(Type::SByteTy);
  std::vector<const Type*> NoParams;

  // Preferred prototypes: 'sbyte *malloc(ulong)' and 'void free(sbyte*)'.
  const FunctionType *MallocType =
    FunctionType::get(BytePtrTy, std::vector<const Type*>(1, Type::ULongTy),
                      false);
  const FunctionType *FreeType =
    FunctionType::get(Type::VoidTy, std::vector<const Type*>(1, BytePtrTy),
                      false);
  MallocFunc = M.getFunction("malloc", MallocType);
  FreeFunc   = M.getFunction("free", FreeType);

  // 'void *malloc(unsigned);' shows up as sbyte *(uint).
  if (!MallocFunc)
    MallocFunc = M.getFunction("malloc",
                     FunctionType::get(BytePtrTy,
                                   std::vector<const Type*>(1, Type::UIntTy),
                                   false));

  // 'void *malloc();' (no prototype) shows up as sbyte *(...).
  if (!MallocFunc)
    MallocFunc = M.getFunction("malloc",
                               FunctionType::get(BytePtrTy, NoParams, true));

  // A forward declaration 'void free();' shows up as void (...).
  if (!FreeFunc)
    FreeFunc = M.getFunction("free",
                             FunctionType::get(Type::VoidTy, NoParams, true));

  // With no declaration at all, free is implicitly 'int (...)'.
  if (!FreeFunc)
    FreeFunc = M.getFunction("free",
                             FunctionType::get(Type::IntTy, NoParams, true));

  // Only external declarations are raised; leave locally defined versions of
  // these functions alone.
  if (MallocFunc != 0 && !MallocFunc->isExternal())
    MallocFunc = 0;
  if (FreeFunc != 0 && !FreeFunc->isExternal())
    FreeFunc = 0;

  return false;
}

// runOnBasicBlock - Replace each direct call in BB to the cached external
// malloc or free declaration with a MallocInst or FreeInst.  Returns true if
// any call was replaced.
//
// Two kinds of call are deliberately left alone:
//  * Calls with no arguments.  The varargs fallback prototypes chosen by
//    doInitialization ('sbyte *(...)' for malloc, 'void (...)' and
//    'int (...)' for free) accept such calls.  Operand 1, the first actual
//    argument, would not exist for them.
//  * Calls to free whose result is used.  Under the 'int (...)' fallback the
//    call produces a value, but FreeInst produces none.  Erasing such a call
//    would leave its uses dangling.
//
bool RaiseAllocations::runOnBasicBlock(BasicBlock &BB) {
  bool Changed = false;
  BasicBlock::InstListType &BIL = BB.getInstList();

  for (BasicBlock::iterator BI = BB.begin(); BI != BB.end(); ++BI) {
    Instruction *I = BI;
    CallInst *CI = dyn_cast<CallInst>(I);

    // Operand 0 is the callee, so a call without arguments has one operand.
    if (CI == 0 || CI->getNumOperands() < 2)
      continue;

    if (CI->getCalledValue() == MallocFunc) {      // Replace call to malloc?
      Value *Source = CI->getOperand(1);

      // If no prototype was provided for malloc, we may need to cast the
      // source size to uint before handing it to the MallocInst.
      if (Source->getType() != Type::UIntTy)
        Source = new CastInst(Source, Type::UIntTy, "MallocAmtCast", BI);

      // Transfer the call's name to the new instruction.
      std::string Name(CI->getName()); CI->setName("");
      BI = new MallocInst(Type::SByteTy, Source, Name, BI);
      CI->replaceAllUsesWith(BI);
      BIL.erase(I);   // BI now refers to the MallocInst, so ++BI stays valid.
      Changed = true;
      ++NumRaised;
    } else if (CI->getCalledValue() == FreeFunc && CI->use_empty()) {
      // Replace call to free.  If no prototype was provided for free, we may
      // need to cast the source pointer.  This should be really uncommon, but
      // it's necessary just in case we are dealing with weird code like this:
      //   free((long)ptr);
      //
      Value *Source = CI->getOperand(1);
      if (!isa<PointerType>(Source->getType()))
        Source = new CastInst(Source, PointerType::get(Type::SByteTy),
                              "FreePtrCast", BI);
      BI = new FreeInst(Source, BI);
      BIL.erase(I);   // BI now refers to the FreeInst, so ++BI stays valid.
      Changed = true;
      ++NumRaised;
    }
  }

  return Changed;
}