summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/llvm/Analysis/TargetTransformInfo.h10
-rw-r--r--include/llvm/Target/TargetLowering.h12
-rw-r--r--lib/Analysis/TargetTransformInfo.cpp17
-rw-r--r--lib/CodeGen/BasicTargetTransformInfo.cpp14
-rw-r--r--lib/Transforms/Scalar/LoopStrengthReduce.cpp46
5 files changed, 96 insertions, 3 deletions
diff --git a/include/llvm/Analysis/TargetTransformInfo.h b/include/llvm/Analysis/TargetTransformInfo.h
index a9d6725d86..eb29e3483d 100644
--- a/include/llvm/Analysis/TargetTransformInfo.h
+++ b/include/llvm/Analysis/TargetTransformInfo.h
@@ -225,6 +225,16 @@ public:
int64_t BaseOffset, bool HasBaseReg,
int64_t Scale) const;
+ /// \brief Return the cost of the scaling factor used in the addressing
+ /// mode described by BaseGV, BaseOffset, HasBaseReg and Scale for this
+ /// target, for a load/store of the specified type.
+ /// If the addressing mode is supported, the return value must be >= 0.
+ /// If the addressing mode is not supported, it returns a negative value.
+ /// TODO: Handle pre/postinc as well.
+ virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+ int64_t BaseOffset, bool HasBaseReg,
+ int64_t Scale) const;
+
/// isTruncateFree - Return true if it's free to truncate a value of
/// type Ty1 to type Ty2. e.g. On x86 it's free to truncate a i32 value in
/// register EAX to i16 by referencing its sub-register AX.
diff --git a/include/llvm/Target/TargetLowering.h b/include/llvm/Target/TargetLowering.h
index 41a4a2b838..d67e55dc66 100644
--- a/include/llvm/Target/TargetLowering.h
+++ b/include/llvm/Target/TargetLowering.h
@@ -1139,6 +1139,18 @@ public:
/// TODO: Handle pre/postinc as well.
virtual bool isLegalAddressingMode(const AddrMode &AM, Type *Ty) const;
+ /// \brief Return the cost of the scaling factor used in the addressing
+ /// mode represented by AM for this target, for a load/store
+ /// of the specified type.
+ /// If the AM is supported, the return value must be >= 0.
+ /// If the AM is not supported, the return value must be negative.
+ /// TODO: Handle pre/postinc as well.
+ virtual int getScalingFactorCost(const AddrMode &AM, Type *Ty) const {
+ // Default: assume that any scaling factor used in a legal AM is free.
+ if (isLegalAddressingMode(AM, Ty)) return 0;
+ return -1; // AM rejected by isLegalAddressingMode: report it as unsupported.
+ }
+
/// isLegalICmpImmediate - Return true if the specified immediate is legal
/// icmp immediate, that is the target has icmp instructions which can compare
/// a register against the immediate without having to materialize the
diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp
index 64f8e96884..35ce794c7f 100644
--- a/lib/Analysis/TargetTransformInfo.cpp
+++ b/lib/Analysis/TargetTransformInfo.cpp
@@ -108,6 +108,14 @@ bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
Scale);
}
+int TargetTransformInfo::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+ int64_t BaseOffset,
+ bool HasBaseReg,
+ int64_t Scale) const { // Forwards the query unchanged to PrevTTI.
+ return PrevTTI->getScalingFactorCost(Ty, BaseGV, BaseOffset, HasBaseReg,
+ Scale);
+}
+
bool TargetTransformInfo::isTruncateFree(Type *Ty1, Type *Ty2) const {
return PrevTTI->isTruncateFree(Ty1, Ty2);
}
@@ -457,6 +465,15 @@ struct NoTTI : ImmutablePass, TargetTransformInfo {
return !BaseGV && BaseOffset == 0 && Scale <= 1;
}
+ int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset,
+ bool HasBaseReg, int64_t Scale) const {
+ // Guess that all legal addressing modes are free (cost 0).
+ if(isLegalAddressingMode(Ty, BaseGV, BaseOffset, HasBaseReg, Scale))
+ return 0;
+ return -1; // Negative cost: the addressing mode is not legal.
+ }
+
+
bool isTruncateFree(Type *Ty1, Type *Ty2) const {
return false;
}
diff --git a/lib/CodeGen/BasicTargetTransformInfo.cpp b/lib/CodeGen/BasicTargetTransformInfo.cpp
index 4a99184f5e..92a5bb70f4 100644
--- a/lib/CodeGen/BasicTargetTransformInfo.cpp
+++ b/lib/CodeGen/BasicTargetTransformInfo.cpp
@@ -71,6 +71,9 @@ public:
virtual bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
int64_t BaseOffset, bool HasBaseReg,
int64_t Scale) const;
+ virtual int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+ int64_t BaseOffset, bool HasBaseReg,
+ int64_t Scale) const;
virtual bool isTruncateFree(Type *Ty1, Type *Ty2) const;
virtual bool isTypeLegal(Type *Ty) const;
virtual unsigned getJumpBufAlignment() const;
@@ -139,6 +142,17 @@ bool BasicTTI::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV,
return TLI->isLegalAddressingMode(AM, Ty);
}
+int BasicTTI::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
+ int64_t BaseOffset, bool HasBaseReg,
+ int64_t Scale) const {
+ TargetLoweringBase::AddrMode AM; // Pack the arguments into the AddrMode form that TLI takes.
+ AM.BaseGV = BaseGV;
+ AM.BaseOffs = BaseOffset;
+ AM.HasBaseReg = HasBaseReg;
+ AM.Scale = Scale;
+ return TLI->getScalingFactorCost(AM, Ty); // The cost itself is decided by TLI.
+}
+
bool BasicTTI::isTruncateFree(Type *Ty1, Type *Ty2) const {
return TLI->isTruncateFree(Ty1, Ty2);
}
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index ecc96ae0b2..b107fef35a 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -779,6 +779,9 @@ class LSRUse;
// Check if it is legal to fold 2 base registers.
static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
const Formula &F);
+// Get the cost of the scaling factor used in formula F for the use LU.
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+ const LSRUse &LU, const Formula &F);
namespace {
@@ -792,11 +795,12 @@ class Cost {
unsigned NumBaseAdds;
unsigned ImmCost;
unsigned SetupCost;
+ unsigned ScaleCost;
public:
Cost()
: NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0),
- SetupCost(0) {}
+ SetupCost(0), ScaleCost(0) {}
bool operator<(const Cost &Other) const;
@@ -806,9 +810,9 @@ public:
// Once any of the metrics loses, they must all remain losers.
bool isValid() {
return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds
- | ImmCost | SetupCost) != ~0u)
+ | ImmCost | SetupCost | ScaleCost) != ~0u)
|| ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds
- & ImmCost & SetupCost) == ~0u);
+ & ImmCost & SetupCost & ScaleCost) == ~0u);
}
#endif
@@ -947,6 +951,9 @@ void Cost::RateFormula(const TargetTransformInfo &TTI,
// allows to fold 2 registers.
NumBaseAdds += NumBaseParts - (1 + isLegal2RegAMUse(TTI, LU, F));
+ // Accumulate non-free scaling amounts.
+ ScaleCost += getScalingFactorCost(TTI, LU, F);
+
// Tally up the non-zero immediates.
for (SmallVectorImpl<int64_t>::const_iterator I = Offsets.begin(),
E = Offsets.end(); I != E; ++I) {
@@ -968,6 +975,7 @@ void Cost::Loose() {
NumBaseAdds = ~0u;
ImmCost = ~0u;
SetupCost = ~0u;
+ ScaleCost = ~0u;
}
/// operator< - Choose the lower cost.
@@ -980,6 +988,8 @@ bool Cost::operator<(const Cost &Other) const {
return NumIVMuls < Other.NumIVMuls;
if (NumBaseAdds != Other.NumBaseAdds)
return NumBaseAdds < Other.NumBaseAdds;
+ if (ScaleCost != Other.ScaleCost)
+ return ScaleCost < Other.ScaleCost;
if (ImmCost != Other.ImmCost)
return ImmCost < Other.ImmCost;
if (SetupCost != Other.SetupCost)
@@ -996,6 +1006,8 @@ void Cost::print(raw_ostream &OS) const {
if (NumBaseAdds != 0)
OS << ", plus " << NumBaseAdds << " base add"
<< (NumBaseAdds == 1 ? "" : "s");
+ if (ScaleCost != 0)
+ OS << ", plus " << ScaleCost << " scale cost";
if (ImmCost != 0)
OS << ", plus " << ImmCost << " imm cost";
if (SetupCost != 0)
@@ -1396,6 +1408,34 @@ static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
F.BaseGV, F.BaseOffset, F.HasBaseReg, 1);
}
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+ const LSRUse &LU, const Formula &F) {
+ if (!F.Scale)
+ return 0; // A scale of zero costs nothing.
+ assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
+ LU.AccessTy, F) && "Illegal formula in use.");
+
+ switch (LU.Kind) {
+ case LSRUse::Address: {
+ int CurScaleCost = TTI.getScalingFactorCost(LU.AccessTy, F.BaseGV,
+ F.BaseOffset, F.HasBaseReg, // NOTE(review): queried with F.BaseOffset only;
+ F.Scale); // the assert above also takes LU.MinOffset/LU.MaxOffset -- confirm they agree.
+ assert(CurScaleCost >= 0 && "Legal addressing mode has an illegal cost!");
+ return CurScaleCost; // Non-negative here, so the int -> unsigned conversion keeps the value.
+ }
+ case LSRUse::ICmpZero:
+ // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg.
+ // Therefore, return 0 in case F.Scale == -1, and 1 for any other scale.
+ return F.Scale != -1;
+
+ case LSRUse::Basic:
+ case LSRUse::Special:
+ return 0; // Basic and Special uses are not charged for scaling.
+ }
+
+ llvm_unreachable("Invalid LSRUse Kind!");
+}
+
static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, Type *AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,