diff options
-rw-r--r-- | include/llvm/Analysis/TargetTransformInfo.h | 10 | ||||
-rw-r--r-- | include/llvm/Target/TargetLowering.h | 12 | ||||
-rw-r--r-- | lib/Analysis/TargetTransformInfo.cpp | 17 | ||||
-rw-r--r-- | lib/CodeGen/BasicTargetTransformInfo.cpp | 14 | ||||
-rw-r--r-- | lib/Transforms/Scalar/LoopStrengthReduce.cpp | 46 |
5 files changed, 96 insertions, 3 deletions
diff --git a/include/llvm/Analysis/TargetTransformInfo.h b/include/llvm/Analysis/TargetTransformInfo.h index a9d6725d86..eb29e3483d 100644 --- a/include/llvm/Analysis/TargetTransformInfo.h +++ b/include/llvm/Analysis/TargetTransformInfo.h @@ -225,6 +225,16 @@ public: int64_t BaseOffset, bool HasBaseReg, int64_t Scale) const; + /// \brief Return the cost of the scaling factor used in the addressing + /// mode represented by AM for this target, for a load/store + /// of the specified type. + /// If the AM is supported, the return value must be >= 0. + /// If the AM is not supported, it returns a negative value. + /// TODO: Handle pre/postinc as well. + virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, + int64_t BaseOffset, bool HasBaseReg, + int64_t Scale) const; + /// isTruncateFree - Return true if it's free to truncate a value of /// type Ty1 to type Ty2. e.g. On x86 it's free to truncate a i32 value in /// register EAX to i16 by referencing its sub-register AX. diff --git a/include/llvm/Target/TargetLowering.h b/include/llvm/Target/TargetLowering.h index 41a4a2b838..d67e55dc66 100644 --- a/include/llvm/Target/TargetLowering.h +++ b/include/llvm/Target/TargetLowering.h @@ -1139,6 +1139,18 @@ public: /// TODO: Handle pre/postinc as well. virtual bool isLegalAddressingMode(const AddrMode &AM, Type *Ty) const; + /// \brief Return the cost of the scaling factor used in the addressing + /// mode represented by AM for this target, for a load/store + /// of the specified type. + /// If the AM is supported, the return value must be >= 0. + /// If the AM is not supported, it returns a negative value. + /// TODO: Handle pre/postinc as well. + virtual int getScalingFactorCost(const AddrMode &AM, Type *Ty) const { + // Default: assume that any scaling factor used in a legal AM is free. + if (isLegalAddressingMode(AM, Ty)) return 0; + return -1; + } + /// isLegalICmpImmediate - Return true if the specified immediate is legal /// icmp immediate, that is the target has icmp instructions which can compare /// a register against the immediate without having to materialize the diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index 64f8e96884..35ce794c7f 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -108,6 +108,14 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, Scale); } +int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, + int64_t BaseOffset, + bool HasBaseReg, + int64_t Scale) const { + return PrevTTI->getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg, + Scale); +} + bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const { return PrevTTI->isTruncateFree(Ty1, Ty2); } @@ -457,6 +465,15 @@ struct NoTTI : ImmutablePass, TargetTransformInfo { return !BaseGV && BaseOffset == 0 && Scale <= 1; } + int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, + bool HasBaseReg, int64_t Scale) const { + // Guess that all legal addressing mode are free. + if(isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale)) + return 0; + return -1; + } + + bool isTruncateFree(Type *Ty1, Type *Ty2) const { return false; } diff --git a/lib/CodeGen/BasicTargetTransformInfo.cpp b/lib/CodeGen/BasicTargetTransformInfo.cpp index 4a99184f5e..92a5bb70f4 100644 --- a/lib/CodeGen/BasicTargetTransformInfo.cpp +++ b/lib/CodeGen/BasicTargetTransformInfo.cpp @@ -71,6 +71,9 @@ public: virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale) const; + virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, + int64_t BaseOffset, bool HasBaseReg, + int64_t Scale) const; virtual bool isTruncateFree(Type *Ty1, Type *Ty2) const; virtual bool isTypeLegal(Type *Ty) const; virtual unsigned getJumpBufAlignment() const; @@ -139,6 +142,17 @@ bool BasicTTI::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, return TLI->isLegalAddressingMode(AM, Ty); } +int BasicTTI::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, + int64_t BaseOffset, bool HasBaseReg, + int64_t Scale) const { + TargetLoweringBase::AddrMode AM; + AM.BaseGV = BaseGV; + AM.BaseOffs = BaseOffset; + AM.HasBaseReg = HasBaseReg; + AM.Scale = Scale; + return TLI->getScalingFactorCost(AM, Ty); +} + bool BasicTTI::isTruncateFree(Type *Ty1, Type *Ty2) const { return TLI->isTruncateFree(Ty1, Ty2); } diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index ecc96ae0b2..b107fef35a 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -779,6 +779,9 @@ class LSRUse; // Check if it is legal to fold 2 base registers. static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU, const Formula &F); +// Get the cost of the scaling factor used in F for LU. +static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, + const LSRUse &LU, const Formula &F); namespace { @@ -792,11 +795,12 @@ class Cost { unsigned NumBaseAdds; unsigned ImmCost; unsigned SetupCost; + unsigned ScaleCost; public: Cost() : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0), - SetupCost(0) {} + SetupCost(0), ScaleCost(0) {} bool operator<(const Cost &Other) const; @@ -806,9 +810,9 @@ public: // Once any of the metrics loses, they must all remain losers. bool isValid() { return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds - | ImmCost | SetupCost) != ~0u) + | ImmCost | SetupCost | ScaleCost) != ~0u) || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds - & ImmCost & SetupCost) == ~0u); + & ImmCost & SetupCost & ScaleCost) == ~0u); } #endif @@ -947,6 +951,9 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // allows to fold 2 registers. NumBaseAdds += NumBaseParts - (1 + isLegal2RegAMUse(TTI, LU, F)); + // Accumulate non-free scaling amounts. + ScaleCost += getScalingFactorCost(TTI, LU, F); + // Tally up the non-zero immediates. for (SmallVectorImpl<int64_t>::const_iterator I = Offsets.begin(), E = Offsets.end(); I != E; ++I) { @@ -968,6 +975,7 @@ void Cost::Loose() { NumBaseAdds = ~0u; ImmCost = ~0u; SetupCost = ~0u; + ScaleCost = ~0u; } /// operator< - Choose the lower cost. @@ -980,6 +988,8 @@ bool Cost::operator<(const Cost &Other) const { return NumIVMuls < Other.NumIVMuls; if (NumBaseAdds != Other.NumBaseAdds) return NumBaseAdds < Other.NumBaseAdds; + if (ScaleCost != Other.ScaleCost) + return ScaleCost < Other.ScaleCost; if (ImmCost != Other.ImmCost) return ImmCost < Other.ImmCost; if (SetupCost != Other.SetupCost) @@ -996,6 +1006,8 @@ void Cost::print(raw_ostream &OS) const { if (NumBaseAdds != 0) OS << ", plus " << NumBaseAdds << " base add" << (NumBaseAdds == 1 ? "" : "s"); + if (ScaleCost != 0) + OS << ", plus " << ScaleCost << " scale cost"; if (ImmCost != 0) OS << ", plus " << ImmCost << " imm cost"; if (SetupCost != 0) @@ -1396,6 +1408,34 @@ static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU, F.BaseGV, F.BaseOffset, F.HasBaseReg, 1); } +static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, + const LSRUse &LU, const Formula &F) { + if (!F.Scale) + return 0; + assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, + LU.AccessTy, F) && "Illegal formula in use."); + + switch (LU.Kind) { + case LSRUse::Address: { + int CurScaleCost = TTI.getScalingFactorCost(LU.AccessTy, F.BaseGV, + F.BaseOffset, F.HasBaseReg, + F.Scale); + assert(CurScaleCost >= 0 && "Legal addressing mode has an illegal cost!"); + return CurScaleCost; + } + case LSRUse::ICmpZero: + // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg. + // Therefore, return 0 in case F.Scale == -1. + return F.Scale != -1; + + case LSRUse::Basic: + case LSRUse::Special: + return 0; + } + + llvm_unreachable("Invalid LSRUse Kind!"); +} + static bool isAlwaysFoldable(const TargetTransformInfo &TTI, LSRUse::KindType Kind, Type *AccessTy, GlobalValue *BaseGV, int64_t BaseOffset, |