summaryrefslogtreecommitdiff
path: root/lib/Target/PTX/PTXISelLowering.cpp
blob: 6e68c3760187776162ef5a47ea4862a48a314d6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
//===-- PTXISelLowering.cpp - PTX DAG Lowering Implementation -------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements the PTXTargetLowering class.
//
//===----------------------------------------------------------------------===//

#include "PTX.h"
#include "PTXISelLowering.h"
#include "PTXRegisterInfo.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"

using namespace llvm;

/// Construct the PTX DAG lowering object.
///
/// Uses an ELF object-file lowering (the object is handed to the
/// TargetLowering base), registers the only two register classes PTX uses,
/// and then derives the remaining register properties from them.
PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
  : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
  // Set up the register classes.
  // i1 values go in predicate registers, i32 values in 32-bit registers.
  addRegisterClass(MVT::i1,  PTX::PredsRegisterClass);
  addRegisterClass(MVT::i32, PTX::RRegs32RegisterClass);

  // Compute derived properties from the register classes.
  // Must run after every addRegisterClass call above.
  computeRegisterProperties();
}

/// Return a printable name for a PTX-specific SelectionDAG opcode.
///
/// Only the PTXISD opcodes this file creates (EXIT and RET) are known;
/// any other opcode aborts via llvm_unreachable.
const char *PTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
  if (Opcode == PTXISD::EXIT)
    return "PTXISD::EXIT";
  if (Opcode == PTXISD::RET)
    return "PTXISD::RET";

  llvm_unreachable("Unknown opcode");
}

//===----------------------------------------------------------------------===//
//                      Calling Convention Implementation
//===----------------------------------------------------------------------===//

static struct argmap_entry {
  MVT::SimpleValueType VT;
  TargetRegisterClass *RC;
  TargetRegisterClass::iterator loc;

  argmap_entry(MVT::SimpleValueType _VT, TargetRegisterClass *_RC)
    : VT(_VT), RC(_RC), loc(_RC->begin()) {}

  void reset(void) { loc = RC->begin(); }
  bool operator==(MVT::SimpleValueType _VT) { return VT == _VT; }
} argmap[] = {
  argmap_entry(MVT::i1,  PTX::PredsRegisterClass),
  argmap_entry(MVT::i32, PTX::RRegs32RegisterClass)
};

/// Lower one formal argument of a PTX_Kernel function.
///
/// Kernel argument lowering has not been written yet. Any kernel that has at
/// least one formal argument aborts here. The parameters are left unnamed
/// because the stub reads none of them; the signature must still match
/// lower_argument_func.
static SDValue lower_kernel_argument(int /*i*/,
                                     SDValue /*Chain*/,
                                     DebugLoc /*dl*/,
                                     MVT::SimpleValueType /*VT*/,
                                     argmap_entry * /*entry*/,
                                     SelectionDAG & /*DAG*/,
                                     unsigned * /*argreg*/) {
  // TODO: implement kernel argument lowering.
  llvm_unreachable("Not implemented yet");
}

/// Lower one formal argument of a PTX_Device function into a register.
///
/// Takes the next physical register from entry's register class. It records
/// that register as a function live-in bound to a fresh virtual register,
/// then returns a CopyFromReg of the virtual register.
///
/// \param i       Index of the argument (not used by this routine).
/// \param Chain   Incoming chain the copy is built on.
/// \param dl      Debug location for the created node.
/// \param VT      Value type of the argument.
/// \param entry   argmap entry for VT; its cursor is advanced by one. The
///                cursor must have been reset() for the current function.
/// \param DAG     DAG being built.
/// \param argreg  [out] Receives the physical register that was assigned.
/// \returns       The CopyFromReg value carrying the argument.
static SDValue lower_device_argument(int i,
                                     SDValue Chain,
                                     DebugLoc dl,
                                     MVT::SimpleValueType VT,
                                     argmap_entry *entry,
                                     SelectionDAG &DAG,
                                     unsigned *argreg) {
  MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();

  // Pre-increment: allocation starts from register 1, so the first register
  // of each class is never handed out. For the i32 class that is presumably
  // R0, which LowerReturn uses for the return value -- confirm.
  ++entry->loc;

  // Without this check, a function with more arguments of one type than the
  // class has registers would dereference past the end of the class.
  if (entry->loc == entry->RC->end())
    llvm_unreachable("Too many arguments for the available registers");

  unsigned preg = *entry->loc;
  unsigned vreg = RegInfo.createVirtualRegister(entry->RC);
  RegInfo.addLiveIn(preg, vreg);

  *argreg = preg;
  return DAG.getCopyFromReg(Chain, dl, vreg, VT);
}

/// Signature shared by the per-calling-convention argument lowering routines
/// (lower_kernel_argument, lower_device_argument). In order, the parameters
/// are: argument index, incoming chain, debug location, argument value type,
/// argmap entry for that type, the DAG, and an out-pointer that receives the
/// assigned physical register.
typedef SDValue (*lower_argument_func)(int, SDValue, DebugLoc,
                                       MVT::SimpleValueType, argmap_entry *,
                                       SelectionDAG &, unsigned *);

/// Lower the incoming formal arguments of a PTX function.
///
/// Selects the per-argument lowering routine for the calling convention.
/// It then rewinds every argmap cursor and lowers each argument in order,
/// appending the resulting values to InVals.
///
/// Aborts (llvm_unreachable) on varargs, on calling conventions other than
/// PTX_Kernel / PTX_Device, and on argument types that have no argmap entry.
/// The returned chain is the incoming Chain, unchanged.
SDValue PTXTargetLowering::
  LowerFormalArguments(SDValue Chain,
                       CallingConv::ID CallConv,
                       bool isVarArg,
                       const SmallVectorImpl<ISD::InputArg> &Ins,
                       DebugLoc dl,
                       SelectionDAG &DAG,
                       SmallVectorImpl<SDValue> &InVals) const {
  if (isVarArg) llvm_unreachable("PTX does not support varargs");

  // Choose how each argument is lowered for this calling convention.
  lower_argument_func lower_argument = 0;
  if (CallConv == CallingConv::PTX_Kernel)
    lower_argument = lower_kernel_argument;
  else if (CallConv == CallingConv::PTX_Device)
    lower_argument = lower_device_argument;
  else
    llvm_unreachable("Unsupported calling convention");

  const unsigned NumEntries = array_lengthof(argmap);
  argmap_entry *const ArgMapEnd = argmap + NumEntries;

  // Register allocation for arguments restarts for every function.
  for (unsigned k = 0; k != NumEntries; ++k)
    argmap[k].reset();

  for (int ArgNo = 0, NumArgs = Ins.size(); ArgNo != NumArgs; ++ArgNo) {
    MVT::SimpleValueType ArgVT = Ins[ArgNo].VT.getSimpleVT().SimpleTy;

    argmap_entry *Entry = std::find(argmap, ArgMapEnd, ArgVT);
    if (Entry == ArgMapEnd)
      llvm_unreachable("Type of argument is not supported");

    // Written by the lowering routine; not consumed here.
    unsigned ArgReg;
    InVals.push_back(
        lower_argument(ArgNo, Chain, dl, ArgVT, Entry, DAG, &ArgReg));
  }

  return Chain;
}

/// Lower a function return.
///
/// - PTX_Kernel functions must return void and lower to PTXISD::EXIT.
/// - PTX_Device functions return either void (a bare PTXISD::RET) or a
///   single i32. That value is copied into PTX::R0, R0 is added to the
///   function's live-out set (once per function), and PTXISD::RET is
///   emitted glued to the copy.
///
/// Unsupported return shapes abort via llvm_unreachable. These are hard
/// checks rather than asserts, so release builds do not silently drop or
/// miscopy return values. This matches how LowerFormalArguments rejects
/// unsupported argument types.
SDValue PTXTargetLowering::
  LowerReturn(SDValue Chain,
              CallingConv::ID CallConv,
              bool isVarArg,
              const SmallVectorImpl<ISD::OutputArg> &Outs,
              const SmallVectorImpl<SDValue> &OutVals,
              DebugLoc dl,
              SelectionDAG &DAG) const {
  if (isVarArg) llvm_unreachable("PTX does not support varargs");

  switch (CallConv) {
    default:
      llvm_unreachable("Unsupported calling convention.");
    case CallingConv::PTX_Kernel:
      if (!Outs.empty())
        llvm_unreachable("Kernel must return void.");
      return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
    case CallingConv::PTX_Device:
      if (Outs.size() > 1)
        llvm_unreachable("Can at most return one value.");
      break;
  }

  // PTX_Device from here on.

  // return void
  if (Outs.empty())
    return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);

  if (Outs[0].VT != MVT::i32)
    llvm_unreachable("Can return only basic types");

  SDValue Flag;
  unsigned reg = PTX::R0;

  // If this is the first return lowered for this function, add the regs to the
  // liveout set for the function
  if (DAG.getMachineFunction().getRegInfo().liveout_empty())
    DAG.getMachineFunction().getRegInfo().addLiveOut(reg);

  // Copy the result value into the output register.
  Chain = DAG.getCopyToReg(Chain, dl, reg, OutVals[0], Flag);

  // Take the copy's flag (glue) result and pass it as an operand of RET.
  // This keeps the copy into R0 and the return together, so nothing can be
  // scheduled between them and clobber the returned value.
  Flag = Chain.getValue(1);

  return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
}