summaryrefslogtreecommitdiff
path: root/lib/Transforms/Scalar/DecomposeMultiDimRefs.cpp
blob: 5d873cda2e1f0c90bd401fd38225069cc6dc2edc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
//===- llvm/Transforms/DecomposeMultiDimRefs.cpp - Lower array refs to 1D -===//
//
// DecomposeMultiDimRefs - Convert multi-dimensional references consisting of
// any combination of 2 or more array and structure indices into a sequence of
// instructions (using getelementptr and cast) so that each instruction has at
// most one index (except structure references, which need an extra leading
// index of [0]).
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Constant.h"
#include "llvm/iMemory.h"
#include "llvm/iOther.h"
#include "llvm/BasicBlock.h"
#include "llvm/Pass.h"
#include "Support/StatisticReporter.h"

// NumAdded - Incremented once for every getelementptr or cast instruction
// that decomposeArrayRef creates while lowering a reference.
static Statistic<> NumAdded("lowerrefs\t\t- New instructions added");

namespace {
  // DecomposePass - Basic block pass that rewrites every memory access
  // instruction carrying two or more indices into a sequence of instructions
  // that each apply a single index.
  //
  class DecomposePass : public BasicBlockPass {
  public:
    // Scan BB and decompose each multi-index memory access found in it.
    // Returns true if BB was modified.
    virtual bool runOnBasicBlock(BasicBlock &BB);

  private:
    // Replace the memory access instruction at BBI with its decomposed form
    // and leave BBI just past the replacement.
    static void decomposeArrayRef(BasicBlock::iterator &BBI);
  };

  // Register the pass under the name "lowerrefs".
  RegisterPass<DecomposePass> X("lowerrefs", "Decompose multi-dimensional "
                                "structure/array references");
}

// createDecomposeMultiDimRefsPass - Allocate a new instance of the reference
// decomposition pass and hand it back to the caller as a generic Pass.
//
Pass *createDecomposeMultiDimRefsPass() {
  Pass *NewPass = new DecomposePass();
  return NewPass;
}


// runOnBasicBlock - Walk the instructions of BB and decompose every memory
// access instruction that has more than one index operand.  Returns true if
// at least one instruction was decomposed.
//
bool DecomposePass::runOnBasicBlock(BasicBlock &BB) {
  bool MadeChange = false;

  BasicBlock::iterator Cur = BB.begin();
  while (Cur != BB.end()) {
    MemAccessInst *Access = dyn_cast<MemAccessInst>(&*Cur);

    // Anything that is not a memory access, or that has at most one index
    // operand, is left alone: just step over it.
    if (!Access ||
        Access->getNumOperands() <= Access->getFirstIndexOperandNumber()+1) {
      ++Cur;
      continue;
    }

    // decomposeArrayRef repositions Cur itself, so it must not be advanced
    // again here.
    decomposeArrayRef(Cur);
    MadeChange = true;
  }

  return MadeChange;
}

// 
// For any combination of 2 or more array and structure indices,
// this function repeats the foll. until we have a one-dim. reference: {
//      ptr1 = getElementPtr [CompositeType-N] * lastPtr, uint firstIndex
//      ptr2 = cast [CompositeType-N] * ptr1 to [CompositeType-N] *
// }
// Then it replaces the original instruction with an equivalent one that
// uses the last ptr2 generated in the loop and a single index.
// If any index is (uint) 0, we omit the getElementPtr instruction.
// 

void DecomposePass::decomposeArrayRef(BasicBlock::iterator &BBI) {
  MemAccessInst &MAI = cast<MemAccessInst>(*BBI);
  BasicBlock *BB = MAI.getParent();
  Value *LastPtr = MAI.getPointerOperand();

  // Remove the instruction from the stream
  BB->getInstList().remove(BBI);

  std::vector<Instruction*> NewInsts;
  
  // Process each index except the last one.
  // 

  User::const_op_iterator OI = MAI.idx_begin(), OE = MAI.idx_end();
  for (; OI+1 != OE; ++OI) {
    assert(isa<PointerType>(LastPtr->getType()));
      
    // Check for a zero index.  This will need a cast instead of
    // a getElementPtr, or it may need neither.
    bool indexIsZero = isa<Constant>(*OI) && 
                       cast<Constant>(OI->get())->isNullValue() &&
                       OI->get()->getType() == Type::UIntTy;
      
    // Extract the first index.  If the ptr is a pointer to a structure
    // and the next index is a structure offset (i.e., not an array offset), 
    // we need to include an initial [0] to index into the pointer.
    //

    std::vector<Value*> Indices;
    const PointerType *PtrTy = cast<PointerType>(LastPtr->getType());

    if (isa<StructType>(PtrTy->getElementType())
        && !PtrTy->indexValid(*OI))
      Indices.push_back(Constant::getNullValue(Type::UIntTy));
    Indices.push_back(*OI);

    // Get the type obtained by applying the first index.
    // It must be a structure or array.
    const Type *NextTy = MemAccessInst::getIndexedType(LastPtr->getType(),
                                                       Indices, true);
    assert(isa<CompositeType>(NextTy));
    
    // Get a pointer to the structure or to the elements of the array.
    const Type *NextPtrTy =
      PointerType::get(isa<StructType>(NextTy) ? NextTy
                       : cast<ArrayType>(NextTy)->getElementType());
      
    // Instruction 1: nextPtr1 = GetElementPtr LastPtr, Indices
    // This is not needed if the index is zero.
    if (!indexIsZero) {
      LastPtr = new GetElementPtrInst(LastPtr, Indices, "ptr1");
      NewInsts.push_back(cast<Instruction>(LastPtr));
      ++NumAdded;
    }

      
    // Instruction 2: nextPtr2 = cast nextPtr1 to NextPtrTy
    // This is not needed if the two types are identical.
    //
    if (LastPtr->getType() != NextPtrTy) {
      LastPtr = new CastInst(LastPtr, NextPtrTy, "ptr2");
      NewInsts.push_back(cast<Instruction>(LastPtr));
      ++NumAdded;
    }
  }
  
  // 
  // Now create a new instruction to replace the original one
  //
  const PointerType *PtrTy = cast<PointerType>(LastPtr->getType());

  // First, get the final index vector.  As above, we may need an initial [0].

  std::vector<Value*> Indices;
  if (isa<StructType>(PtrTy->getElementType())
      && !PtrTy->indexValid(*OI))
    Indices.push_back(Constant::getNullValue(Type::UIntTy));

  Indices.push_back(*OI);

  Instruction *NewI = 0;
  switch(MAI.getOpcode()) {
  case Instruction::Load:
    NewI = new LoadInst(LastPtr, Indices, MAI.getName());
    break;
  case Instruction::Store:
    NewI = new StoreInst(MAI.getOperand(0), LastPtr, Indices);
    break;
  case Instruction::GetElementPtr:
    NewI = new GetElementPtrInst(LastPtr, Indices, MAI.getName());
    break;
  default:
    assert(0 && "Unrecognized memory access instruction");
  }
  NewInsts.push_back(NewI);

  
  // Replace all uses of the old instruction with the new
  MAI.replaceAllUsesWith(NewI);

  // Now delete the old instruction...
  delete &MAI;

  // Insert all of the new instructions...
  BB->getInstList().insert(BBI, NewInsts.begin(), NewInsts.end());
  
  // Advance the iterator to the instruction following the one just inserted...
  BBI = NewInsts.back();
  ++BBI;
}