From 10da1651ed7428b35ee89e0dc72ec177bf7041aa Mon Sep 17 00:00:00 2001 From: Justin Holewinski Date: Fri, 27 Jun 2014 18:35:37 +0000 Subject: [NVPTX] Implement fma and imad contraction as target DAGCombiner patterns This also introduces DAGCombiner patterns for mul.wide to multiply two smaller integers and produce a larger integer git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@211935 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 13 +- lib/Target/NVPTX/NVPTXISelLowering.cpp | 411 +++++++++++++++++++++++++++++++++ lib/Target/NVPTX/NVPTXISelLowering.h | 4 + lib/Target/NVPTX/NVPTXInstrInfo.td | 247 ++++++++++---------- 4 files changed, 549 insertions(+), 126 deletions(-) (limited to 'lib') diff --git a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 1ea47fc8d5..21e4ba5c02 100644 --- a/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -24,11 +24,14 @@ using namespace llvm; #define DEBUG_TYPE "nvptx-isel" -static cl::opt -FMAContractLevel("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden, - cl::desc("NVPTX Specific: FMA contraction (0: don't do it" - " 1: do it 2: do it aggressively"), - cl::init(2)); +unsigned FMAContractLevel = 0; + +static cl::opt +FMAContractLevelOpt("nvptx-fma-level", cl::ZeroOrMore, cl::Hidden, + cl::desc("NVPTX Specific: FMA contraction (0: don't do it" + " 1: do it 2: do it aggressively"), + cl::location(FMAContractLevel), + cl::init(2)); static cl::opt UsePrecDivF32( "nvptx-prec-divf32", cl::ZeroOrMore, cl::Hidden, diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp index 60da5f1f77..09e5a613f2 100644 --- a/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -243,6 +243,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) setOperationAction(ISD::CTPOP, MVT::i32, Legal); setOperationAction(ISD::CTPOP, MVT::i64, Legal); + // We have some custom DAG combine patterns for these nodes + setTargetDAGCombine(ISD::ADD); + setTargetDAGCombine(ISD::AND); + setTargetDAGCombine(ISD::FADD); + setTargetDAGCombine(ISD::MUL); + setTargetDAGCombine(ISD::SHL); + // Now deduce the information based on the above mentioned // actions computeRegisterProperties(); @@ -334,6 +341,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { return "NVPTXISD::StoreV2"; case NVPTXISD::StoreV4: return "NVPTXISD::StoreV4"; + case NVPTXISD::FUN_SHFL_CLAMP: + return "NVPTXISD::FUN_SHFL_CLAMP"; + case NVPTXISD::FUN_SHFR_CLAMP: + return "NVPTXISD::FUN_SHFR_CLAMP"; case NVPTXISD::Tex1DFloatI32: return "NVPTXISD::Tex1DFloatI32"; case NVPTXISD::Tex1DFloatFloat: return "NVPTXISD::Tex1DFloatFloat"; case NVPTXISD::Tex1DFloatFloatLevel: @@ -2475,6 +2486,406 @@ unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const { return 4; } +//===----------------------------------------------------------------------===// +// NVPTX DAG Combining +//===----------------------------------------------------------------------===// + +extern unsigned FMAContractLevel; + +/// PerformADDCombineWithOperands - Try DAG combinations for an ADD with +/// operands N0 and N1. This is a helper for PerformADDCombine that is +/// called with the default operands, and if that fails, with commuted +/// operands. +static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1, + TargetLowering::DAGCombinerInfo &DCI, + const NVPTXSubtarget &Subtarget, + CodeGenOpt::Level OptLevel) { + SelectionDAG &DAG = DCI.DAG; + // Skip non-integer, non-scalar case + EVT VT=N0.getValueType(); + if (VT.isVector()) + return SDValue(); + + // fold (add (mul a, b), c) -> (mad a, b, c) + // + if (N0.getOpcode() == ISD::MUL) { + assert (VT.isInteger()); + // For integer: + // Since integer multiply-add costs the same as integer multiply + // but is more costly than integer add, do the fusion only when + // the mul is only used in the add. + if (OptLevel==CodeGenOpt::None || VT != MVT::i32 || + !N0.getNode()->hasOneUse()) + return SDValue(); + + // Do the folding + return DAG.getNode(NVPTXISD::IMAD, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1), N1); + } + else if (N0.getOpcode() == ISD::FMUL) { + if (VT == MVT::f32 || VT == MVT::f64) { + if (FMAContractLevel == 0) + return SDValue(); + + // For floating point: + // Do the fusion only when the mul has less than 5 uses and all + // are add. + // The heuristic is that if a use is not an add, then that use + // cannot be fused into fma, therefore mul is still needed anyway. + // If there are more than 4 uses, even if they are all add, fusing + // them will increase register pressue. + // + int numUses = 0; + int nonAddCount = 0; + for (SDNode::use_iterator UI = N0.getNode()->use_begin(), + UE = N0.getNode()->use_end(); + UI != UE; ++UI) { + numUses++; + SDNode *User = *UI; + if (User->getOpcode() != ISD::FADD) + ++nonAddCount; + } + if (numUses >= 5) + return SDValue(); + if (nonAddCount) { + int orderNo = N->getIROrder(); + int orderNo2 = N0.getNode()->getIROrder(); + // simple heuristics here for considering potential register + // pressure, the logics here is that the differnce are used + // to measure the distance between def and use, the longer distance + // more likely cause register pressure. + if (orderNo - orderNo2 < 500) + return SDValue(); + + // Now, check if at least one of the FMUL's operands is live beyond the node N, + // which guarantees that the FMA will not increase register pressure at node N. + bool opIsLive = false; + const SDNode *left = N0.getOperand(0).getNode(); + const SDNode *right = N0.getOperand(1).getNode(); + + if (dyn_cast(left) || dyn_cast(right)) + opIsLive = true; + + if (!opIsLive) + for (SDNode::use_iterator UI = left->use_begin(), UE = left->use_end(); UI != UE; ++UI) { + SDNode *User = *UI; + int orderNo3 = User->getIROrder(); + if (orderNo3 > orderNo) { + opIsLive = true; + break; + } + } + + if (!opIsLive) + for (SDNode::use_iterator UI = right->use_begin(), UE = right->use_end(); UI != UE; ++UI) { + SDNode *User = *UI; + int orderNo3 = User->getIROrder(); + if (orderNo3 > orderNo) { + opIsLive = true; + break; + } + } + + if (!opIsLive) + return SDValue(); + } + + return DAG.getNode(ISD::FMA, SDLoc(N), VT, + N0.getOperand(0), N0.getOperand(1), N1); + } + } + + return SDValue(); +} + +/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD. +/// +static SDValue PerformADDCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + const NVPTXSubtarget &Subtarget, + CodeGenOpt::Level OptLevel) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + + // First try with the default operand order. + SDValue Result = PerformADDCombineWithOperands(N, N0, N1, DCI, Subtarget, + OptLevel); + if (Result.getNode()) + return Result; + + // If that didn't work, try again with the operands commuted. + return PerformADDCombineWithOperands(N, N1, N0, DCI, Subtarget, OptLevel); +} + +static SDValue PerformANDCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + // The type legalizer turns a vector load of i8 values into a zextload to i16 + // registers, optionally ANY_EXTENDs it (if target type is integer), + // and ANDs off the high 8 bits. Since we turn this load into a + // target-specific DAG node, the DAG combiner fails to eliminate these AND + // nodes. Do that here. + SDValue Val = N->getOperand(0); + SDValue Mask = N->getOperand(1); + + if (isa(Val)) { + std::swap(Val, Mask); + } + + SDValue AExt; + // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and + if (Val.getOpcode() == ISD::ANY_EXTEND) { + AExt = Val; + Val = Val->getOperand(0); + } + + if (Val->isMachineOpcode() && Val->getMachineOpcode() == NVPTX::IMOV16rr) { + Val = Val->getOperand(0); + } + + if (Val->getOpcode() == NVPTXISD::LoadV2 || + Val->getOpcode() == NVPTXISD::LoadV4) { + ConstantSDNode *MaskCnst = dyn_cast(Mask); + if (!MaskCnst) { + // Not an AND with a constant + return SDValue(); + } + + uint64_t MaskVal = MaskCnst->getZExtValue(); + if (MaskVal != 0xff) { + // Not an AND that chops off top 8 bits + return SDValue(); + } + + MemSDNode *Mem = dyn_cast(Val); + if (!Mem) { + // Not a MemSDNode?!? + return SDValue(); + } + + EVT MemVT = Mem->getMemoryVT(); + if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) { + // We only handle the i8 case + return SDValue(); + } + + unsigned ExtType = + cast(Val->getOperand(Val->getNumOperands()-1))-> + getZExtValue(); + if (ExtType == ISD::SEXTLOAD) { + // If for some reason the load is a sextload, the and is needed to zero + // out the high 8 bits + return SDValue(); + } + + bool AddTo = false; + if (AExt.getNode() != 0) { + // Re-insert the ext as a zext. + Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), + AExt.getValueType(), Val); + AddTo = true; + } + + // If we get here, the AND is unnecessary. Just replace it with the load + DCI.CombineTo(N, Val, AddTo); + } + + return SDValue(); +} + +enum OperandSignedness { + Signed = 0, + Unsigned, + Unknown +}; + +/// IsMulWideOperandDemotable - Checks if the provided DAG node is an operand +/// that can be demoted to \p OptSize bits without loss of information. The +/// signedness of the operand, if determinable, is placed in \p S. +static bool IsMulWideOperandDemotable(SDValue Op, + unsigned OptSize, + OperandSignedness &S) { + S = Unknown; + + if (Op.getOpcode() == ISD::SIGN_EXTEND || + Op.getOpcode() == ISD::SIGN_EXTEND_INREG) { + EVT OrigVT = Op.getOperand(0).getValueType(); + if (OrigVT.getSizeInBits() == OptSize) { + S = Signed; + return true; + } + } else if (Op.getOpcode() == ISD::ZERO_EXTEND) { + EVT OrigVT = Op.getOperand(0).getValueType(); + if (OrigVT.getSizeInBits() == OptSize) { + S = Unsigned; + return true; + } + } + + return false; +} + +/// AreMulWideOperandsDemotable - Checks if the given LHS and RHS operands can +/// be demoted to \p OptSize bits without loss of information. If the operands +/// contain a constant, it should appear as the RHS operand. The signedness of +/// the operands is placed in \p IsSigned. +static bool AreMulWideOperandsDemotable(SDValue LHS, SDValue RHS, + unsigned OptSize, + bool &IsSigned) { + + OperandSignedness LHSSign; + + // The LHS operand must be a demotable op + if (!IsMulWideOperandDemotable(LHS, OptSize, LHSSign)) + return false; + + // We should have been able to determine the signedness from the LHS + if (LHSSign == Unknown) + return false; + + IsSigned = (LHSSign == Signed); + + // The RHS can be a demotable op or a constant + if (ConstantSDNode *CI = dyn_cast(RHS)) { + APInt Val = CI->getAPIntValue(); + if (LHSSign == Unsigned) { + if (Val.isIntN(OptSize)) { + return true; + } + return false; + } else { + if (Val.isSignedIntN(OptSize)) { + return true; + } + return false; + } + } else { + OperandSignedness RHSSign; + if (!IsMulWideOperandDemotable(RHS, OptSize, RHSSign)) + return false; + + if (LHSSign != RHSSign) + return false; + + return true; + } +} + +/// TryMULWIDECombine - Attempt to replace a multiply of M bits with a multiply +/// of M/2 bits that produces an M-bit result (i.e. mul.wide). This transform +/// works on both multiply DAG nodes and SHL DAG nodes with a constant shift +/// amount. +static SDValue TryMULWIDECombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + EVT MulType = N->getValueType(0); + if (MulType != MVT::i32 && MulType != MVT::i64) { + return SDValue(); + } + + unsigned OptSize = MulType.getSizeInBits() >> 1; + SDValue LHS = N->getOperand(0); + SDValue RHS = N->getOperand(1); + + // Canonicalize the multiply so the constant (if any) is on the right + if (N->getOpcode() == ISD::MUL) { + if (isa(LHS)) { + std::swap(LHS, RHS); + } + } + + // If we have a SHL, determine the actual multiply amount + if (N->getOpcode() == ISD::SHL) { + ConstantSDNode *ShlRHS = dyn_cast(RHS); + if (!ShlRHS) { + return SDValue(); + } + + APInt ShiftAmt = ShlRHS->getAPIntValue(); + unsigned BitWidth = MulType.getSizeInBits(); + if (ShiftAmt.sge(0) && ShiftAmt.slt(BitWidth)) { + APInt MulVal = APInt(BitWidth, 1) << ShiftAmt; + RHS = DCI.DAG.getConstant(MulVal, MulType); + } else { + return SDValue(); + } + } + + bool Signed; + // Verify that our operands are demotable + if (!AreMulWideOperandsDemotable(LHS, RHS, OptSize, Signed)) { + return SDValue(); + } + + EVT DemotedVT; + if (MulType == MVT::i32) { + DemotedVT = MVT::i16; + } else { + DemotedVT = MVT::i32; + } + + // Truncate the operands to the correct size. Note that these are just for + // type consistency and will (likely) be eliminated in later phases. + SDValue TruncLHS = + DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, LHS); + SDValue TruncRHS = + DCI.DAG.getNode(ISD::TRUNCATE, SDLoc(N), DemotedVT, RHS); + + unsigned Opc; + if (Signed) { + Opc = NVPTXISD::MUL_WIDE_SIGNED; + } else { + Opc = NVPTXISD::MUL_WIDE_UNSIGNED; + } + + return DCI.DAG.getNode(Opc, SDLoc(N), MulType, TruncLHS, TruncRHS); +} + +/// PerformMULCombine - Runs PTX-specific DAG combine patterns on MUL nodes. +static SDValue PerformMULCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + CodeGenOpt::Level OptLevel) { + if (OptLevel > 0) { + // Try mul.wide combining at OptLevel > 0 + SDValue Ret = TryMULWIDECombine(N, DCI); + if (Ret.getNode()) + return Ret; + } + + return SDValue(); +} + +/// PerformSHLCombine - Runs PTX-specific DAG combine patterns on SHL nodes. +static SDValue PerformSHLCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + CodeGenOpt::Level OptLevel) { + if (OptLevel > 0) { + // Try mul.wide combining at OptLevel > 0 + SDValue Ret = TryMULWIDECombine(N, DCI); + if (Ret.getNode()) + return Ret; + } + + return SDValue(); +} + +SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N, + DAGCombinerInfo &DCI) const { + // FIXME: Get this from the DAG somehow + CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive; + switch (N->getOpcode()) { + default: break; + case ISD::ADD: + case ISD::FADD: + return PerformADDCombine(N, DCI, nvptxSubtarget, OptLevel); + case ISD::MUL: + return PerformMULCombine(N, DCI, OptLevel); + case ISD::SHL: + return PerformSHLCombine(N, DCI, OptLevel); + case ISD::AND: + return PerformANDCombine(N, DCI); + } + return SDValue(); +} + /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads. static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG, SmallVectorImpl &Results) { diff --git a/lib/Target/NVPTX/NVPTXISelLowering.h b/lib/Target/NVPTX/NVPTXISelLowering.h index 4f5e02577b..e909e0a10f 100644 --- a/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/lib/Target/NVPTX/NVPTXISelLowering.h @@ -49,6 +49,9 @@ enum NodeType { CallSeqBegin, CallSeqEnd, CallPrototype, + MUL_WIDE_SIGNED, + MUL_WIDE_UNSIGNED, + IMAD, Dummy, LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE, @@ -258,6 +261,7 @@ private: void ReplaceNodeResults(SDNode *N, SmallVectorImpl &Results, SelectionDAG &DAG) const override; + SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override; unsigned getArgumentAlignment(SDValue Callee, const ImmutableCallSite *CS, Type *Ty, unsigned Idx) const; diff --git a/lib/Target/NVPTX/NVPTXInstrInfo.td b/lib/Target/NVPTX/NVPTXInstrInfo.td index 725d6fc91c..99309f9f61 100644 --- a/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -464,33 +464,45 @@ def SHL2MUL16 : SDNodeXFormgetTargetConstant(temp.shl(v), MVT::i16); }]>; -def MULWIDES64 : NVPTXInst<(outs Int64Regs:$dst), - (ins Int32Regs:$a, Int32Regs:$b), +def MULWIDES64 + : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b), + "mul.wide.s32 \t$dst, $a, $b;", []>; +def MULWIDES64Imm + : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b), "mul.wide.s32 \t$dst, $a, $b;", []>; -def MULWIDES64Imm : NVPTXInst<(outs Int64Regs:$dst), - (ins Int32Regs:$a, i64imm:$b), +def MULWIDES64Imm64 + : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b), "mul.wide.s32 \t$dst, $a, $b;", []>; -def MULWIDEU64 : NVPTXInst<(outs Int64Regs:$dst), - (ins Int32Regs:$a, Int32Regs:$b), +def MULWIDEU64 + : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b), + "mul.wide.u32 \t$dst, $a, $b;", []>; +def MULWIDEU64Imm + : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i32imm:$b), "mul.wide.u32 \t$dst, $a, $b;", []>; -def MULWIDEU64Imm : NVPTXInst<(outs Int64Regs:$dst), - (ins Int32Regs:$a, i64imm:$b), +def MULWIDEU64Imm64 + : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$a, i64imm:$b), "mul.wide.u32 \t$dst, $a, $b;", []>; -def MULWIDES32 : NVPTXInst<(outs Int32Regs:$dst), - (ins Int16Regs:$a, Int16Regs:$b), +def MULWIDES32 + : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b), "mul.wide.s16 \t$dst, $a, $b;", []>; -def MULWIDES32Imm : NVPTXInst<(outs Int32Regs:$dst), - (ins Int16Regs:$a, i32imm:$b), +def MULWIDES32Imm + : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b), + "mul.wide.s16 \t$dst, $a, $b;", []>; +def MULWIDES32Imm32 + : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b), "mul.wide.s16 \t$dst, $a, $b;", []>; -def MULWIDEU32 : NVPTXInst<(outs Int32Regs:$dst), - (ins Int16Regs:$a, Int16Regs:$b), - "mul.wide.u16 \t$dst, $a, $b;", []>; -def MULWIDEU32Imm : NVPTXInst<(outs Int32Regs:$dst), - (ins Int16Regs:$a, i32imm:$b), +def MULWIDEU32 + : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b), + "mul.wide.u16 \t$dst, $a, $b;", []>; +def MULWIDEU32Imm + : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i16imm:$b), "mul.wide.u16 \t$dst, $a, $b;", []>; +def MULWIDEU32Imm32 + : NVPTXInst<(outs Int32Regs:$dst), (ins Int16Regs:$a, i32imm:$b), + "mul.wide.u16 \t$dst, $a, $b;", []>; def : Pat<(shl (sext Int32Regs:$a), (i32 Int5Const:$b)), (MULWIDES64Imm Int32Regs:$a, (SHL2MUL32 node:$b))>, @@ -510,25 +522,63 @@ def : Pat<(mul (sext Int32Regs:$a), (sext Int32Regs:$b)), (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>; def : Pat<(mul (sext Int32Regs:$a), (i64 SInt32Const:$b)), - (MULWIDES64Imm Int32Regs:$a, (i64 SInt32Const:$b))>, + (MULWIDES64Imm64 Int32Regs:$a, (i64 SInt32Const:$b))>, Requires<[doMulWide]>; def : Pat<(mul (zext Int32Regs:$a), (zext Int32Regs:$b)), - (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, Requires<[doMulWide]>; + (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, + Requires<[doMulWide]>; def : Pat<(mul (zext Int32Regs:$a), (i64 UInt32Const:$b)), - (MULWIDEU64Imm Int32Regs:$a, (i64 UInt32Const:$b))>, + (MULWIDEU64Imm64 Int32Regs:$a, (i64 UInt32Const:$b))>, Requires<[doMulWide]>; def : Pat<(mul (sext Int16Regs:$a), (sext Int16Regs:$b)), - (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>; + (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, + Requires<[doMulWide]>; def : Pat<(mul (sext Int16Regs:$a), (i32 SInt16Const:$b)), - (MULWIDES32Imm Int16Regs:$a, (i32 SInt16Const:$b))>, + (MULWIDES32Imm32 Int16Regs:$a, (i32 SInt16Const:$b))>, Requires<[doMulWide]>; def : Pat<(mul (zext Int16Regs:$a), (zext Int16Regs:$b)), - (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, Requires<[doMulWide]>; + (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, + Requires<[doMulWide]>; def : Pat<(mul (zext Int16Regs:$a), (i32 UInt16Const:$b)), - (MULWIDEU32Imm Int16Regs:$a, (i32 UInt16Const:$b))>, + (MULWIDEU32Imm32 Int16Regs:$a, (i32 UInt16Const:$b))>, + Requires<[doMulWide]>; + + +def SDTMulWide + : SDTypeProfile<1, 2, [SDTCisSameAs<1, 2>]>; +def mul_wide_signed + : SDNode<"NVPTXISD::MUL_WIDE_SIGNED", SDTMulWide>; +def mul_wide_unsigned + : SDNode<"NVPTXISD::MUL_WIDE_UNSIGNED", SDTMulWide>; + +def : Pat<(i32 (mul_wide_signed Int16Regs:$a, Int16Regs:$b)), + (MULWIDES32 Int16Regs:$a, Int16Regs:$b)>, + Requires<[doMulWide]>; +def : Pat<(i32 (mul_wide_signed Int16Regs:$a, imm:$b)), + (MULWIDES32Imm Int16Regs:$a, imm:$b)>, + Requires<[doMulWide]>; +def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, Int16Regs:$b)), + (MULWIDEU32 Int16Regs:$a, Int16Regs:$b)>, + Requires<[doMulWide]>; +def : Pat<(i32 (mul_wide_unsigned Int16Regs:$a, imm:$b)), + (MULWIDEU32Imm Int16Regs:$a, imm:$b)>, + Requires<[doMulWide]>; + + +def : Pat<(i64 (mul_wide_signed Int32Regs:$a, Int32Regs:$b)), + (MULWIDES64 Int32Regs:$a, Int32Regs:$b)>, + Requires<[doMulWide]>; +def : Pat<(i64 (mul_wide_signed Int32Regs:$a, imm:$b)), + (MULWIDES64Imm Int32Regs:$a, imm:$b)>, + Requires<[doMulWide]>; +def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, Int32Regs:$b)), + (MULWIDEU64 Int32Regs:$a, Int32Regs:$b)>, + Requires<[doMulWide]>; +def : Pat<(i64 (mul_wide_unsigned Int32Regs:$a, imm:$b)), + (MULWIDEU64Imm Int32Regs:$a, imm:$b)>, Requires<[doMulWide]>; defm MULT : I3<"mul.lo.s", mul>; @@ -544,69 +594,75 @@ defm SREM : I3<"rem.s", srem>; defm UREM : I3<"rem.u", urem>; // The ri version will not be selected as DAGCombiner::visitUREM will lower it. +def SDTIMAD + : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisInt<0>, + SDTCisInt<2>, SDTCisSameAs<0, 2>, + SDTCisSameAs<0, 3>]>; +def imad + : SDNode<"NVPTXISD::IMAD", SDTIMAD>; + def MAD16rrr : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int16Regs:$dst, (add - (mul Int16Regs:$a, Int16Regs:$b), Int16Regs:$c))]>; + [(set Int16Regs:$dst, + (imad Int16Regs:$a, Int16Regs:$b, Int16Regs:$c))]>; def MAD16rri : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, i16imm:$c), "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int16Regs:$dst, (add - (mul Int16Regs:$a, Int16Regs:$b), imm:$c))]>; + [(set Int16Regs:$dst, + (imad Int16Regs:$a, Int16Regs:$b, imm:$c))]>; def MAD16rir : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b, Int16Regs:$c), "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int16Regs:$dst, (add - (mul Int16Regs:$a, imm:$b), Int16Regs:$c))]>; + [(set Int16Regs:$dst, + (imad Int16Regs:$a, imm:$b, Int16Regs:$c))]>; def MAD16rii : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, i16imm:$b, i16imm:$c), "mad.lo.s16 \t$dst, $a, $b, $c;", - [(set Int16Regs:$dst, (add (mul Int16Regs:$a, imm:$b), - imm:$c))]>; + [(set Int16Regs:$dst, + (imad Int16Regs:$a, imm:$b, imm:$c))]>; def MAD32rrr : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), "mad.lo.s32 \t$dst, $a, $b, $c;", - [(set Int32Regs:$dst, (add - (mul Int32Regs:$a, Int32Regs:$b), Int32Regs:$c))]>; + [(set Int32Regs:$dst, + (imad Int32Regs:$a, Int32Regs:$b, Int32Regs:$c))]>; def MAD32rri : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, Int32Regs:$b, i32imm:$c), "mad.lo.s32 \t$dst, $a, $b, $c;", - [(set Int32Regs:$dst, (add - (mul Int32Regs:$a, Int32Regs:$b), imm:$c))]>; + [(set Int32Regs:$dst, + (imad Int32Regs:$a, Int32Regs:$b, imm:$c))]>; def MAD32rir : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b, Int32Regs:$c), "mad.lo.s32 \t$dst, $a, $b, $c;", - [(set Int32Regs:$dst, (add - (mul Int32Regs:$a, imm:$b), Int32Regs:$c))]>; + [(set Int32Regs:$dst, + (imad Int32Regs:$a, imm:$b, Int32Regs:$c))]>; def MAD32rii : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$a, i32imm:$b, i32imm:$c), "mad.lo.s32 \t$dst, $a, $b, $c;", - [(set Int32Regs:$dst, (add - (mul Int32Regs:$a, imm:$b), imm:$c))]>; + [(set Int32Regs:$dst, + (imad Int32Regs:$a, imm:$b, imm:$c))]>; def MAD64rrr : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c), "mad.lo.s64 \t$dst, $a, $b, $c;", - [(set Int64Regs:$dst, (add - (mul Int64Regs:$a, Int64Regs:$b), Int64Regs:$c))]>; + [(set Int64Regs:$dst, + (imad Int64Regs:$a, Int64Regs:$b, Int64Regs:$c))]>; def MAD64rri : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, Int64Regs:$b, i64imm:$c), "mad.lo.s64 \t$dst, $a, $b, $c;", - [(set Int64Regs:$dst, (add - (mul Int64Regs:$a, Int64Regs:$b), imm:$c))]>; + [(set Int64Regs:$dst, + (imad Int64Regs:$a, Int64Regs:$b, imm:$c))]>; def MAD64rir : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b, Int64Regs:$c), "mad.lo.s64 \t$dst, $a, $b, $c;", - [(set Int64Regs:$dst, (add - (mul Int64Regs:$a, imm:$b), Int64Regs:$c))]>; + [(set Int64Regs:$dst, + (imad Int64Regs:$a, imm:$b, Int64Regs:$c))]>; def MAD64rii : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$a, i64imm:$b, i64imm:$c), "mad.lo.s64 \t$dst, $a, $b, $c;", - [(set Int64Regs:$dst, (add - (mul Int64Regs:$a, imm:$b), imm:$c))]>; - + [(set Int64Regs:$dst, + (imad Int64Regs:$a, imm:$b, imm:$c))]>; def INEG16 : NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$src), "neg.s16 \t$dst, $src;", @@ -812,36 +868,26 @@ multiclass FPCONTRACT32 { def rrr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float32Regs:$dst, (fadd - (fmul Float32Regs:$a, Float32Regs:$b), - Float32Regs:$c))]>, Requires<[Pred]>; - // This is to WAR a weird bug in Tablegen that does not automatically - // generate the following permutated rule rrr2 from the above rrr. - // So we explicitly add it here. This happens to FMA32 only. - // See the comments at FMAD32 and FMA32 for more information. - def rrr2 : NVPTXInst<(outs Float32Regs:$dst), - (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c), - !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float32Regs:$dst, (fadd Float32Regs:$c, - (fmul Float32Regs:$a, Float32Regs:$b)))]>, + [(set Float32Regs:$dst, + (fma Float32Regs:$a, Float32Regs:$b, Float32Regs:$c))]>, Requires<[Pred]>; def rri : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float32Regs:$dst, (fadd - (fmul Float32Regs:$a, Float32Regs:$b), fpimm:$c))]>, + [(set Float32Regs:$dst, + (fma Float32Regs:$a, Float32Regs:$b, fpimm:$c))]>, Requires<[Pred]>; def rir : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a, f32imm:$b, Float32Regs:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float32Regs:$dst, (fadd - (fmul Float32Regs:$a, fpimm:$b), Float32Regs:$c))]>, + [(set Float32Regs:$dst, + (fma Float32Regs:$a, fpimm:$b, Float32Regs:$c))]>, Requires<[Pred]>; def rii : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$a, f32imm:$b, f32imm:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float32Regs:$dst, (fadd - (fmul Float32Regs:$a, fpimm:$b), fpimm:$c))]>, + [(set Float32Regs:$dst, + (fma Float32Regs:$a, fpimm:$b, fpimm:$c))]>, Requires<[Pred]>; } @@ -849,73 +895,32 @@ multiclass FPCONTRACT64 { def rrr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a, Float64Regs:$b, Float64Regs:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float64Regs:$dst, (fadd - (fmul Float64Regs:$a, Float64Regs:$b), - Float64Regs:$c))]>, Requires<[Pred]>; + [(set Float64Regs:$dst, + (fma Float64Regs:$a, Float64Regs:$b, Float64Regs:$c))]>, + Requires<[Pred]>; def rri : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a, Float64Regs:$b, f64imm:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float64Regs:$dst, (fadd (fmul Float64Regs:$a, - Float64Regs:$b), fpimm:$c))]>, Requires<[Pred]>; + [(set Float64Regs:$dst, + (fma Float64Regs:$a, Float64Regs:$b, fpimm:$c))]>, + Requires<[Pred]>; def rir : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a, f64imm:$b, Float64Regs:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float64Regs:$dst, (fadd - (fmul Float64Regs:$a, fpimm:$b), Float64Regs:$c))]>, + [(set Float64Regs:$dst, + (fma Float64Regs:$a, fpimm:$b, Float64Regs:$c))]>, Requires<[Pred]>; def rii : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$a, f64imm:$b, f64imm:$c), !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set Float64Regs:$dst, (fadd - (fmul Float64Regs:$a, fpimm:$b), fpimm:$c))]>, + [(set Float64Regs:$dst, + (fma Float64Regs:$a, fpimm:$b, fpimm:$c))]>, Requires<[Pred]>; } -// Due to a unknown reason (most likely a bug in tablegen), tablegen does not -// automatically generate the rrr2 rule from -// the rrr rule (see FPCONTRACT32) for FMA32, though it does for FMAD32. -// If we reverse the order of the following two lines, then rrr2 rule will be -// generated for FMA32, but not for rrr. -// Therefore, we manually write the rrr2 rule in FPCONTRACT32. -defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doFMAF32_ftz>; -defm FMA32 : FPCONTRACT32<"fma.rn.f32", doFMAF32>; -defm FMA64 : FPCONTRACT64<"fma.rn.f64", doFMAF64>; - -// b*c-a => fmad(b, c, -a) -multiclass FPCONTRACT32_SUB_PAT_MAD { - def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a), - (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>, - Requires<[Pred]>; -} - -// a-b*c => fmad(-b,c, a) -// - legal because a-b*c <=> a+(-b*c) <=> a+(-b)*c -// b*c-a => fmad(b, c, -a) -// - legal because b*c-a <=> b*c+(-a) -multiclass FPCONTRACT32_SUB_PAT { - def : Pat<(fsub Float32Regs:$a, (fmul Float32Regs:$b, Float32Regs:$c)), - (Inst (FNEGf32 Float32Regs:$b), Float32Regs:$c, Float32Regs:$a)>, - Requires<[Pred]>; - def : Pat<(fsub (fmul Float32Regs:$b, Float32Regs:$c), Float32Regs:$a), - (Inst Float32Regs:$b, Float32Regs:$c, (FNEGf32 Float32Regs:$a))>, - Requires<[Pred]>; -} - -// a-b*c => fmad(-b,c, a) -// b*c-a => fmad(b, c, -a) -multiclass FPCONTRACT64_SUB_PAT { - def : Pat<(fsub Float64Regs:$a, (fmul Float64Regs:$b, Float64Regs:$c)), - (Inst (FNEGf64 Float64Regs:$b), Float64Regs:$c, Float64Regs:$a)>, - Requires<[Pred]>; - - def : Pat<(fsub (fmul Float64Regs:$b, Float64Regs:$c), Float64Regs:$a), - (Inst Float64Regs:$b, Float64Regs:$c, (FNEGf64 Float64Regs:$a))>, - Requires<[Pred]>; -} - -defm FMAF32ext_ftz : FPCONTRACT32_SUB_PAT; -defm FMAF32ext : FPCONTRACT32_SUB_PAT; -defm FMAF64ext : FPCONTRACT64_SUB_PAT; +defm FMA32_ftz : FPCONTRACT32<"fma.rn.ftz.f32", doF32FTZ>; +defm FMA32 : FPCONTRACT32<"fma.rn.f32", doNoF32FTZ>; +defm FMA64 : FPCONTRACT64<"fma.rn.f64", doNoF32FTZ>; def SINF: NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src), "sin.approx.f32 \t$dst, $src;", -- cgit v1.2.3