From ee3f7de62e5616242441a76a8e92260d7b0f10e5 Mon Sep 17 00:00:00 2001
From: Arnold Schwaighofer
Date: Fri, 10 Jan 2014 18:20:32 +0000
Subject: LoopVectorizer: Handle strided memory accesses by versioning

  for (i = 0; i < N; ++i)
    A[i * Stride1] += B[i * Stride2];

We take loops like this and check that the symbolic strides 'Stride1' and
'Stride2' are one, and fall back to the scalar loop if they are not.

This is currently disabled by default and hidden behind the flag
'enable-mem-access-versioning'.

radar://13075509

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@198950 91177308-0d34-0410-b5e6-96231b3b80d8
---
 lib/Transforms/Vectorize/LoopVectorize.cpp         | 491 +++++++++++++++++----
 .../LoopVectorize/runtime-check-readonly.ll        |  12 +-
 .../Transforms/LoopVectorize/version-mem-access.ll |  50 +++
 3 files changed, 465 insertions(+), 88 deletions(-)
 create mode 100644 test/Transforms/LoopVectorize/version-mem-access.ll

diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp
index 70c18edf55..74285ec245 100644
--- a/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -114,6 +114,21 @@ TinyTripCountVectorThreshold("vectorizer-min-trip-count", cl::init(16),
                                       "trip count that is smaller than this "
                                       "value."));
 
+/// This enables versioning on the strides of symbolically striding memory
+/// accesses in code like the following.
+///   for (i = 0; i < N; ++i)
+///     A[i * Stride1] += B[i * Stride2] ...
+///
+/// Will be roughly translated to
+///    if (Stride1 == 1 && Stride2 == 1) {
+///      for (i = 0; i < N; i+=4)
+///        A[i:i+3] += ...
+///    } else
+///      ...
+static cl::opt<bool> EnableMemAccessVersioning(
+    "enable-mem-access-versioning", cl::init(false), cl::Hidden,
+    cl::desc("Enable symbolic stride memory access versioning"));
+
 /// We don't unroll loops with a known constant trip count below this number.
 static const unsigned TinyTripCountUnrollThreshold = 128;
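At the source level, what the flag enables corresponds to hand-written stride
versioning along these lines (a minimal sketch with hypothetical names, not
code from this patch):

    // Versioning by hand: guard on the symbolic strides being one, run a
    // unit-stride body that the vectorizer can widen, otherwise keep the
    // original strided loop.
    void addStrided(float *A, float *B, long Stride1, long Stride2, long N) {
      if (Stride1 == 1 && Stride2 == 1) {
        for (long i = 0; i < N; ++i)   // consecutive accesses: vectorizable
          A[i] += B[i];
      } else {
        for (long i = 0; i < N; ++i)   // general case stays scalar
          A[i * Stride1] += B[i * Stride2];
      }
    }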
@@ -158,15 +173,16 @@ public:
                       unsigned UnrollFactor)
       : OrigLoop(OrigLoop), SE(SE), LI(LI), DT(DT), DL(DL), TLI(TLI),
         VF(VecWidth), UF(UnrollFactor), Builder(SE->getContext()), Induction(0),
-        OldInduction(0), WidenMap(UnrollFactor) {}
+        OldInduction(0), WidenMap(UnrollFactor), Legal(0) {}
 
   // Perform the actual loop widening (vectorization).
-  void vectorize(LoopVectorizationLegality *Legal) {
+  void vectorize(LoopVectorizationLegality *L) {
+    Legal = L;
     // Create a new empty loop. Unlink the old loop and connect the new one.
-    createEmptyLoop(Legal);
+    createEmptyLoop();
     // Widen each instruction in the old loop to a new one in the new loop.
     // Use the Legality module to find the induction and reduction variables.
-    vectorizeLoop(Legal);
+    vectorizeLoop();
     // Register the new loop and update the analysis passes.
     updateAnalysis();
   }
@@ -186,14 +202,23 @@ protected:
   typedef DenseMap<std::pair<BasicBlock*, BasicBlock*>, VectorParts>
     EdgeMaskCache;
 
-  /// Add code that checks at runtime if the accessed arrays overlap.
-  /// Returns the comparator value or NULL if no check is needed.
-  Instruction *addRuntimeCheck(LoopVectorizationLegality *Legal,
-                               Instruction *Loc);
+  /// \brief Add code that checks at runtime if the accessed arrays overlap.
+  ///
+  /// Returns a pair of instructions: the first instruction generated (the
+  /// checks may span a sequence of instructions) and the final comparator
+  /// value, or NULL if no check is needed.
+  std::pair<Instruction *, Instruction *> addRuntimeCheck(Instruction *Loc);
+
+  /// \brief Add checks for strides that were assumed to be 1.
+  ///
+  /// Returns the first and the last check instruction as a (first, last)
+  /// pair.
+  std::pair<Instruction *, Instruction *> addStrideCheck(Instruction *Loc);
+
   /// Create an empty loop, based on the loop ranges of the old loop.
-  void createEmptyLoop(LoopVectorizationLegality *Legal);
+  void createEmptyLoop();
   /// Copy and widen the instructions from the old loop.
-  virtual void vectorizeLoop(LoopVectorizationLegality *Legal);
+  virtual void vectorizeLoop();
 
   /// \brief The Loop exit block may have single value PHI nodes where the
   /// incoming value is 'Undef'. While vectorizing we only handled real values
@@ -210,14 +235,12 @@ protected:
   VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst);
 
   /// A helper function to vectorize a single BB within the innermost loop.
-  void vectorizeBlockInLoop(LoopVectorizationLegality *Legal, BasicBlock *BB,
-                            PhiVector *PV);
+  void vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV);
 
   /// Vectorize a single PHINode in a block. This method handles the induction
   /// variable canonicalization. It supports both VF = 1 for unrolled loops and
   /// arbitrary length vectors.
   void widenPHIInstruction(Instruction *PN, VectorParts &Entry,
-                           LoopVectorizationLegality *Legal,
                            unsigned UF, unsigned VF, PhiVector *PV);
 
   /// Insert the new loop to the loop hierarchy and pass manager
@@ -229,8 +252,7 @@ protected:
   virtual void scalarizeInstruction(Instruction *Instr);
 
   /// Vectorize Load and Store instructions,
-  virtual void vectorizeMemoryInstruction(Instruction *Instr,
-                                          LoopVectorizationLegality *Legal);
+  virtual void vectorizeMemoryInstruction(Instruction *Instr);
 
   /// Create a broadcast instruction. This method generates a broadcast
   /// instruction (shuffle) for loop invariant values and for the induction
@@ -345,6 +367,8 @@ protected:
   /// Maps scalars to widened vectors.
   ValueMap WidenMap;
   EdgeMaskCache MaskCache;
+
+  LoopVectorizationLegality *Legal;
 };
 
 class InnerLoopUnroller : public InnerLoopVectorizer {
@@ -356,8 +380,7 @@ public:
 
 private:
   virtual void scalarizeInstruction(Instruction *Instr);
-  virtual void vectorizeMemoryInstruction(Instruction *Instr,
-                                          LoopVectorizationLegality *Legal);
+  virtual void vectorizeMemoryInstruction(Instruction *Instr);
   virtual Value *getBroadcastInstrs(Value *V);
   virtual Value *getConsecutiveVector(Value* Val, int StartIdx, bool Negate);
   virtual Value *reverseVector(Value *Vec);
@@ -500,7 +523,7 @@ public:
 
   /// Insert a pointer and calculate the start and end SCEVs.
   void insert(ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr,
-              unsigned DepSetId);
+              unsigned DepSetId, ValueToValueMap &Strides);
 
   /// This flag indicates if we need to add the runtime check.
   bool Need;
@@ -584,6 +607,13 @@ public:
 
   unsigned getMaxSafeDepDistBytes() { return MaxSafeDepDistBytes; }
 
+  bool hasStride(Value *V) { return StrideSet.count(V); }
+  bool mustCheckStrides() { return !StrideSet.empty(); }
+  SmallPtrSet<Value *, 8>::iterator strides_begin() {
+    return StrideSet.begin();
+  }
+  SmallPtrSet<Value *, 8>::iterator strides_end() { return StrideSet.end(); }
+
 private:
   /// Check if a single basic block loop is vectorizable.
   /// At this point we know that this is a loop with a constant trip count
@@ -626,6 +656,12 @@ private:
   /// if the PHI is not an induction variable.
   InductionKind isInductionVariable(PHINode *Phi);
 
+  /// \brief Collect memory accesses with loop-invariant strides.
+  ///
+  /// Looks for accesses like "a[i * StrideA]" where "StrideA" is loop
+  /// invariant.
+  void collectStridedAccess(Value *LoadOrStoreInst);
+
   /// The loop that we evaluate.
   Loop *TheLoop;
   /// Scev analysis.
@@ -664,6 +700,9 @@ private:
   bool HasFunNoNaNAttr;
 
   unsigned MaxSafeDepDistBytes;
+
+  ValueToValueMap Strides;
+  SmallPtrSet<Value *, 8> StrideSet;
 };
 
 /// LoopVectorizationCostModel - estimates the expected speedups due to
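Downstream code queries these through the new accessors; an illustrative
fragment (assumes a populated LoopVectorizationLegality inside the pass, not
code from this patch) mirroring the loop that addStrideCheck runs below:

    // Walk the symbolic strides that versioning will assume to be one.
    static void dumpVersionedStrides(LoopVectorizationLegality *Legal) {
      if (!Legal->mustCheckStrides())
        return;
      for (SmallPtrSet<Value *, 8>::iterator I = Legal->strides_begin(),
                                             E = Legal->strides_end();
           I != E; ++I)
        dbgs() << "LV: versioning on stride " << **I << "\n";
    }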
@@ -1033,12 +1072,52 @@ struct LoopVectorize : public LoopPass {
 // LoopVectorizationCostModel.
 //===----------------------------------------------------------------------===//
 
-void
-LoopVectorizationLegality::RuntimePointerCheck::insert(ScalarEvolution *SE,
-                                                       Loop *Lp, Value *Ptr,
-                                                       bool WritePtr,
-                                                       unsigned DepSetId) {
-  const SCEV *Sc = SE->getSCEV(Ptr);
+static Value *stripCast(Value *V) {
+  if (CastInst *CI = dyn_cast<CastInst>(V))
+    return CI->getOperand(0);
+  return V;
+}
+
+///\brief Replaces the symbolic stride in a pointer SCEV expression by one.
+///
+/// If \p OrigPtr is not null, use it to look up the stride value instead of
+/// \p Ptr.
+static const SCEV *replaceSymbolicStrideSCEV(ScalarEvolution *SE,
+                                             ValueToValueMap &PtrToStride,
+                                             Value *Ptr, Value *OrigPtr = 0) {
+
+  const SCEV *OrigSCEV = SE->getSCEV(Ptr);
+
+  // If there is an entry in the map return the SCEV of the pointer with the
+  // symbolic stride replaced by one.
+  ValueToValueMap::iterator SI = PtrToStride.find(OrigPtr ? OrigPtr : Ptr);
+  if (SI != PtrToStride.end()) {
+    Value *StrideVal = SI->second;
+
+    // Strip casts.
+    StrideVal = stripCast(StrideVal);
+
+    // Replace symbolic stride by one.
+    Value *One = ConstantInt::get(StrideVal->getType(), 1);
+    ValueToValueMap RewriteMap;
+    RewriteMap[StrideVal] = One;
+
+    const SCEV *ByOne =
+        SCEVParameterRewriter::rewrite(OrigSCEV, *SE, RewriteMap, true);
+    DEBUG(dbgs() << "LV: Replacing SCEV: " << *OrigSCEV << " by: " << *ByOne
+                 << "\n");
+    return ByOne;
+  }
+
+  // Otherwise, just return the SCEV of the original pointer.
+  return SE->getSCEV(Ptr);
+}
+
+void LoopVectorizationLegality::RuntimePointerCheck::insert(
+    ScalarEvolution *SE, Loop *Lp, Value *Ptr, bool WritePtr, unsigned DepSetId,
+    ValueToValueMap &Strides) {
+  // Get the stride replaced scev.
+  const SCEV *Sc = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
   assert(AR && "Invalid addrec expression");
   const SCEV *Ex = SE->getBackedgeTakenCount(Lp);
@@ -1170,7 +1249,27 @@ int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) {
 
   // We can emit wide load/stores only if the last non-zero index is the
   // induction variable.
-  const SCEV *Last = SE->getSCEV(Gep->getOperand(InductionOperand));
+  const SCEV *Last = 0;
+  if (!Strides.count(Gep))
+    Last = SE->getSCEV(Gep->getOperand(InductionOperand));
+  else {
+    // Because of the multiplication by a stride we can have a s/zext cast.
+    // We are going to replace this stride by 1 so the cast is safe to ignore.
+    //
+    //  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+    //  %0 = trunc i64 %indvars.iv to i32
+    //  %mul = mul i32 %0, %Stride1
+    //  %idxprom = zext i32 %mul to i64  << Safe cast.
+    //  %arrayidx = getelementptr inbounds i32* %B, i64 %idxprom
+    //
+    Last = replaceSymbolicStrideSCEV(SE, Strides,
                                     Gep->getOperand(InductionOperand), Gep);
+    if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(Last))
+      Last =
+          (C->getSCEVType() == scSignExtend || C->getSCEVType() == scZeroExtend)
+              ? C->getOperand()
+              : Last;
+  }
   if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Last)) {
     const SCEV *Step = AR->getStepRecurrence(*SE);
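Concretely, for the %B access in the comment above the pointer recurrence is,
up to the exact SCEV spelling, {%B,+,(4 * (zext i32 %Stride1 to i64))}<%for.body>;
rewriting %Stride1 to 1 folds it to {%B,+,4}<%for.body>, a unit-stride
(consecutive) pointer that the existing consecutiveness and dependence logic
already understands.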
@@ -1194,6 +1293,10 @@ InnerLoopVectorizer::getVectorValue(Value *V) {
   assert(V != Induction && "The new induction variable should not be used.");
   assert(!V->getType()->isVectorTy() && "Can't widen a vector");
 
+  // If we have a stride that is replaced by one, do it here.
+  if (Legal->hasStride(V))
+    V = ConstantInt::get(V->getType(), 1);
+
   // If we have this scalar in the map, return it.
   if (WidenMap.has(V))
     return WidenMap.get(V);
@@ -1215,9 +1318,7 @@ Value *InnerLoopVectorizer::reverseVector(Value *Vec) {
                                      "reverse");
 }
 
-
-void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr,
-                                                     LoopVectorizationLegality *Legal) {
+void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) {
   // Attempt to issue a wide load.
   LoadInst *LI = dyn_cast<LoadInst>(Instr);
   StoreInst *SI = dyn_cast<StoreInst>(Instr);
@@ -1427,14 +1528,58 @@ void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr) {
   }
 }
 
-Instruction *
-InnerLoopVectorizer::addRuntimeCheck(LoopVectorizationLegality *Legal,
-                                     Instruction *Loc) {
+static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
+                                 Instruction *Loc) {
+  if (FirstInst)
+    return FirstInst;
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    return I->getParent() == Loc->getParent() ? I : 0;
+  return 0;
+}
+
+std::pair<Instruction *, Instruction *>
+InnerLoopVectorizer::addStrideCheck(Instruction *Loc) {
+  if (!Legal->mustCheckStrides())
+    return std::pair<Instruction *, Instruction *>(0, 0);
+
+  IRBuilder<> ChkBuilder(Loc);
+
+  // Emit checks.
+  Value *Check = 0;
+  Instruction *FirstInst = 0;
+  for (SmallPtrSet<Value *, 8>::iterator SI = Legal->strides_begin(),
+                                         SE = Legal->strides_end();
+       SI != SE; ++SI) {
+    Value *Ptr = stripCast(*SI);
+    Value *C = ChkBuilder.CreateICmpNE(Ptr, ConstantInt::get(Ptr->getType(), 1),
+                                       "stride.chk");
+    // Store the first instruction we create.
+    FirstInst = getFirstInst(FirstInst, C, Loc);
+    if (Check)
+      Check = ChkBuilder.CreateOr(Check, C);
+    else
+      Check = C;
+  }
+
+  // We have to do this trickery because the IRBuilder might fold the check to
+  // a constant expression in which case there is no Instruction anchored in
+  // the block.
+  LLVMContext &Ctx = Loc->getContext();
+  Instruction *TheCheck =
+      BinaryOperator::CreateAnd(Check, ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(TheCheck, "stride.not.one");
+  FirstInst = getFirstInst(FirstInst, TheCheck, Loc);
+
+  return std::make_pair(FirstInst, TheCheck);
+}
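The emitted guard is an OR-reduction of per-stride "!= 1" compares, ANDed with
true so that a real Instruction always exists to split the block at. In plain
C++ the logic amounts to the following (an illustrative sketch, not this
patch's IR-building code):

    #include <cstdint>
    #include <vector>

    // True when any symbolic stride is not one, i.e. when the versioned
    // vector loop must be bypassed in favor of the scalar loop.
    bool anyStrideNotOne(const std::vector<int64_t> &Strides) {
      bool NotOne = false;
      for (int64_t S : Strides)
        NotOne |= (S != 1); // mirrors the chained CreateOr of icmp ne results
      return NotOne;
    }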
+std::pair<Instruction *, Instruction *>
+InnerLoopVectorizer::addRuntimeCheck(Instruction *Loc) {
   LoopVectorizationLegality::RuntimePointerCheck *PtrRtCheck =
     Legal->getRuntimePointerCheck();
 
   if (!PtrRtCheck->Need)
-    return NULL;
+    return std::pair<Instruction *, Instruction *>(0, 0);
 
   unsigned NumPointers = PtrRtCheck->Pointers.size();
   SmallVector<TrackingVH<Value>, 2> Starts;
@@ -1442,6 +1587,7 @@
 
   LLVMContext &Ctx = Loc->getContext();
   SCEVExpander Exp(*SE, "induction");
+  Instruction *FirstInst = 0;
 
   for (unsigned i = 0; i < NumPointers; ++i) {
     Value *Ptr = PtrRtCheck->Pointers[i];
@@ -1495,11 +1641,16 @@
       Value *End1 = ChkBuilder.CreateBitCast(Ends[j], PtrArithTy0, "bc");
 
       Value *Cmp0 = ChkBuilder.CreateICmpULE(Start0, End1, "bound0");
+      FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
       Value *Cmp1 = ChkBuilder.CreateICmpULE(Start1, End0, "bound1");
+      FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
       Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
-      if (MemoryRuntimeCheck)
+      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+      if (MemoryRuntimeCheck) {
         IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict,
                                          "conflict.rdx");
+        FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
+      }
       MemoryRuntimeCheck = IsConflict;
     }
   }
@@ -1510,11 +1661,11 @@
   Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
                                                  ConstantInt::getTrue(Ctx));
   ChkBuilder.Insert(Check, "memcheck.conflict");
-  return Check;
+  FirstInst = getFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
 }
 
-void
-InnerLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
+void InnerLoopVectorizer::createEmptyLoop() {
   /*
    In this function we generate a new loop. The new loop will contain
    the vectorized instructions while the old loop will continue to run the
@@ -1665,22 +1816,48 @@ InnerLoopVectorizer::createEmptyLoop(LoopVectorizationLegality *Legal) {
 
   BasicBlock *LastBypassBlock = BypassBlock;
 
+  // Generate the code to check that the strides we assumed to be one are
+  // really one. We want the new basic block to start at the first instruction
+  // in a sequence of instructions that form a check.
+  Instruction *StrideCheck;
+  Instruction *FirstCheckInst;
+  tie(FirstCheckInst, StrideCheck) =
+      addStrideCheck(BypassBlock->getTerminator());
+  if (StrideCheck) {
+    // Create a new block containing the stride check.
+    BasicBlock *CheckBlock =
+        BypassBlock->splitBasicBlock(FirstCheckInst, "vector.stridecheck");
+    if (ParentLoop)
+      ParentLoop->addBasicBlockToLoop(CheckBlock, LI->getBase());
+    LoopBypassBlocks.push_back(CheckBlock);
+
+    // Replace the branch into the stride check block with a conditional branch
+    // for the "few elements case".
+    Instruction *OldTerm = BypassBlock->getTerminator();
+    BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm);
+    OldTerm->eraseFromParent();
+
+    Cmp = StrideCheck;
+    LastBypassBlock = CheckBlock;
+  }
+
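With both guards in place the bypass chain before the vector preheader looks
roughly like this (block names as emitted above; "middle" falls through to the
scalar loop):

    bypass:             too few elements?   -> middle, else vector.stridecheck
    vector.stridecheck: some stride != 1?   -> middle, else vector.memcheck
    vector.memcheck:    arrays may overlap? -> middle, else vector preheader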
   // Generate the code that checks in runtime if arrays overlap. We put the
   // checks into a separate block to make the more common case of few elements
   // faster.
-  Instruction *MemRuntimeCheck = addRuntimeCheck(Legal,
-                                                 BypassBlock->getTerminator());
+  Instruction *MemRuntimeCheck;
+  tie(FirstCheckInst, MemRuntimeCheck) =
+      addRuntimeCheck(LastBypassBlock->getTerminator());
   if (MemRuntimeCheck) {
     // Create a new block containing the memory check.
-    BasicBlock *CheckBlock = BypassBlock->splitBasicBlock(MemRuntimeCheck,
-                                                          "vector.memcheck");
+    BasicBlock *CheckBlock =
+        LastBypassBlock->splitBasicBlock(MemRuntimeCheck, "vector.memcheck");
     if (ParentLoop)
       ParentLoop->addBasicBlockToLoop(CheckBlock, LI->getBase());
     LoopBypassBlocks.push_back(CheckBlock);
 
     // Replace the branch into the memory check block with a conditional branch
     // for the "few elements case".
-    Instruction *OldTerm = BypassBlock->getTerminator();
+    Instruction *OldTerm = LastBypassBlock->getTerminator();
     BranchInst::Create(MiddleBlock, CheckBlock, Cmp, OldTerm);
     OldTerm->eraseFromParent();
 
@@ -2138,8 +2315,7 @@ static void cse(BasicBlock *BB) {
   }
 }
 
-void
-InnerLoopVectorizer::vectorizeLoop(LoopVectorizationLegality *Legal) {
+void InnerLoopVectorizer::vectorizeLoop() {
   //===------------------------------------------------===//
   //
   // Notice: any optimization or new instruction that go
@@ -2167,7 +2343,7 @@
   // Vectorize all of the blocks in the original loop.
   for (LoopBlocksDFS::RPOIterator bb = DFS.beginRPO(),
        be = DFS.endRPO(); bb != be; ++bb)
-    vectorizeBlockInLoop(Legal, *bb, &RdxPHIsToFix);
+    vectorizeBlockInLoop(*bb, &RdxPHIsToFix);
 
   // At this point every instruction in the original loop is widened to
   // a vector form. We are almost done. Now, we need to fix the PHI nodes
@@ -2434,7 +2610,6 @@ InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) {
 
 void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN,
                                               InnerLoopVectorizer::VectorParts &Entry,
-                                              LoopVectorizationLegality *Legal,
                                               unsigned UF, unsigned VF, PhiVector *PV) {
   PHINode* P = cast<PHINode>(PN);
   // Handle reduction variables:
@@ -2596,9 +2771,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN,
   }
 }
 
-void
-InnerLoopVectorizer::vectorizeBlockInLoop(LoopVectorizationLegality *Legal,
-                                          BasicBlock *BB, PhiVector *PV) {
+void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
   // For each instruction in the old loop.
   for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
     VectorParts &Entry = WidenMap.get(it);
@@ -2609,7 +2782,7 @@
       continue;
     case Instruction::PHI:{
       // Vectorize PHINodes.
-      widenPHIInstruction(it, Entry, Legal, UF, VF, PV);
+      widenPHIInstruction(it, Entry, UF, VF, PV);
       continue;
     }// End of PHI.
 
@@ -2703,7 +2876,7 @@
 
     case Instruction::Store:
     case Instruction::Load:
-      vectorizeMemoryInstruction(it, Legal);
+      vectorizeMemoryInstruction(it);
       break;
     case Instruction::ZExt:
     case Instruction::SExt:
@@ -3120,8 +3293,14 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
       Type *T = ST->getValueOperand()->getType();
       if (!VectorType::isValidElementType(T))
         return false;
+      if (EnableMemAccessVersioning)
+        collectStridedAccess(ST);
     }
 
+    if (EnableMemAccessVersioning)
+      if (LoadInst *LI = dyn_cast<LoadInst>(it))
+        collectStridedAccess(LI);
+
     // Reduction instructions are allowed to have exit users.
     // All other instructions must not have external users.
     if (hasOutsideLoopUser(TheLoop, it, AllowedExit))
@@ -3140,6 +3319,139 @@ bool LoopVectorizationLegality::canVectorizeInstrs() {
   return true;
 }
 
+///\brief Strip a gep whose indices, apart from the induction operand, are
+/// loop invariant, and return the induction operand of the gep pointer.
+static Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE,
+                                 DataLayout *DL, Loop *Lp) {
+  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return Ptr;
+
+  unsigned InductionOperand = getGEPInductionOperand(DL, GEP);
+
+  // Check that all of the gep indices are uniform except for our induction
+  // operand.
+  for (unsigned i = 0, e = GEP->getNumOperands(); i != e; ++i)
+    if (i != InductionOperand &&
+        !SE->isLoopInvariant(SE->getSCEV(GEP->getOperand(i)), Lp))
+      return Ptr;
+  return GEP->getOperand(InductionOperand);
+}
+
+///\brief Look for a unique cast use of the passed value.
+static Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty) {
+  Value *UniqueCast = 0;
+  for (Value::use_iterator UI = Ptr->use_begin(), UE = Ptr->use_end(); UI != UE;
+       ++UI) {
+    CastInst *CI = dyn_cast<CastInst>(*UI);
+    if (CI && CI->getType() == Ty) {
+      if (!UniqueCast)
+        UniqueCast = CI;
+      else
+        return 0;
+    }
+  }
+  return UniqueCast;
+}
+
+///\brief Get the stride of a pointer access in a loop.
+///
+/// Looks for symbolic strides "a[i*stride]". Returns the symbolic stride
+/// Value, or null otherwise.
+static Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE,
+                                   DataLayout *DL, Loop *Lp) {
+  const PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType());
+  if (!PtrTy || PtrTy->isAggregateType())
+    return 0;
+
+  // Try to remove a gep instruction to make the pointer (actually index at
+  // this point) easier to analyze. If OrigPtr is equal to Ptr we are
+  // analyzing the pointer, otherwise we are analyzing the index.
+  Value *OrigPtr = Ptr;
+
+  // The size of the pointer access.
+  int64_t PtrAccessSize = 1;
+
+  Ptr = stripGetElementPtr(Ptr, SE, DL, Lp);
+  const SCEV *V = SE->getSCEV(Ptr);
+
+  if (Ptr != OrigPtr)
+    // Strip off casts.
+    while (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V))
+      V = C->getOperand();
+
+  const SCEVAddRecExpr *S = dyn_cast<SCEVAddRecExpr>(V);
+  if (!S)
+    return 0;
+
+  V = S->getStepRecurrence(*SE);
+  if (!V)
+    return 0;
+
+  // Strip off the size of access multiplication if we are still analyzing the
+  // pointer.
+  if (OrigPtr == Ptr) {
+    PtrAccessSize = DL->getTypeAllocSize(PtrTy->getElementType());
+    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) {
+      if (M->getOperand(0)->getSCEVType() != scConstant)
+        return 0;
+
+      const APInt &APStepVal =
+          cast<SCEVConstant>(M->getOperand(0))->getValue()->getValue();
+
+      // Huge step value - give up.
+      if (APStepVal.getBitWidth() > 64)
+        return 0;
+
+      int64_t StepVal = APStepVal.getSExtValue();
+      if (PtrAccessSize != StepVal)
+        return 0;
+      V = M->getOperand(1);
+    }
+  }
+
+  // Strip off casts.
+  Type *StrippedOffRecurrenceCast = 0;
+  if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(V)) {
+    StrippedOffRecurrenceCast = C->getType();
+    V = C->getOperand();
+  }
+
+  // Look for the loop invariant symbolic value.
+  const SCEVUnknown *U = dyn_cast<SCEVUnknown>(V);
+  if (!U)
+    return 0;
+
+  Value *Stride = U->getValue();
+  if (!Lp->isLoopInvariant(Stride))
+    return 0;
+
+  // If we have stripped off the recurrence cast we have to make sure that we
+  // return the value that is used in this loop so that we can replace it
+  // later.
+  if (StrippedOffRecurrenceCast)
+    Stride = getUniqueCastUse(Stride, Lp, StrippedOffRecurrenceCast);
+
+  return Stride;
+}
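For the zero-extended case this peels roughly the pattern shown earlier in
isConsecutivePtr (%mul = mul i32 %0, %Stride1; %idxprom = zext i32 %mul to
i64; getelementptr %B, i64 %idxprom): the gep is stripped down to its
induction operand, the casts and the step recurrence are stripped on the SCEV
side, and only the loop-invariant %Stride1 remains. getUniqueCastUse then
returns the cast of the stride that actually appears in this loop, so it is
the value that can be replaced later.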
+
+void LoopVectorizationLegality::collectStridedAccess(Value *MemAccess) {
+  Value *Ptr = 0;
+  if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
+    Ptr = LI->getPointerOperand();
+  else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess))
+    Ptr = SI->getPointerOperand();
+  else
+    return;
+
+  Value *Stride = getStrideFromPointer(Ptr, SE, DL, TheLoop);
+  if (!Stride)
+    return;
+
+  DEBUG(dbgs() << "LV: Found a strided access that we can version");
+  DEBUG(dbgs() << "  Ptr: " << *Ptr << " Stride: " << *Stride << "\n");
+  Strides[Ptr] = Stride;
+  StrideSet.insert(Stride);
+}
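A worked example of the bookkeeping: for the loop in the commit message this
ends up with Strides = { gep of A[i*Stride1] -> Stride1, gep of B[i*Stride2]
-> Stride2 } (keyed by the access pointers, driving the per-access SCEV
rewrite in replaceSymbolicStrideSCEV) and StrideSet = { Stride1, Stride2 }
(de-duplicated stride values, driving addStrideCheck and the hasStride
queries).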
 
 void LoopVectorizationLegality::collectLoopUniforms() {
   // We now know that the loop is vectorizable!
   // Collect variables that will remain uniform after vectorization.
@@ -3201,7 +3513,8 @@ public:
   /// non-intersection.
   bool canCheckPtrAtRT(LoopVectorizationLegality::RuntimePointerCheck &RtCheck,
                        unsigned &NumComparisons, ScalarEvolution *SE,
-                       Loop *TheLoop, bool ShouldCheckStride = false);
+                       Loop *TheLoop, ValueToValueMap &Strides,
+                       bool ShouldCheckStride = false);
 
   /// \brief Goes over all memory accesses, checks whether a RT check is needed
   /// and builds sets of dependent accesses.
@@ -3261,8 +3574,9 @@ private:
 } // end anonymous namespace
 
 /// \brief Check whether a pointer can participate in a runtime bounds check.
-static bool hasComputableBounds(ScalarEvolution *SE, Value *Ptr) {
-  const SCEV *PtrScev = SE->getSCEV(Ptr);
+static bool hasComputableBounds(ScalarEvolution *SE, ValueToValueMap &Strides,
+                                Value *Ptr) {
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, Strides, Ptr);
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR)
     return false;
@@ -3273,12 +3587,12 @@ static bool hasComputableBounds(ScalarEvolution *SE, Value *Ptr) {
 /// \brief Check the stride of the pointer and ensure that it does not wrap in
 /// the address space.
 static int isStridedPtr(ScalarEvolution *SE, DataLayout *DL, Value *Ptr,
-                        const Loop *Lp);
+                        const Loop *Lp, ValueToValueMap &StridesMap);
 
 bool AccessAnalysis::canCheckPtrAtRT(
-                       LoopVectorizationLegality::RuntimePointerCheck &RtCheck,
-                       unsigned &NumComparisons, ScalarEvolution *SE,
-                       Loop *TheLoop, bool ShouldCheckStride) {
+    LoopVectorizationLegality::RuntimePointerCheck &RtCheck,
+    unsigned &NumComparisons, ScalarEvolution *SE, Loop *TheLoop,
+    ValueToValueMap &StridesMap, bool ShouldCheckStride) {
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
   unsigned NumReadPtrChecks = 0;
@@ -3306,10 +3620,11 @@ bool AccessAnalysis::canCheckPtrAtRT(
     else
       ++NumReadPtrChecks;
 
-    if (hasComputableBounds(SE, Ptr) &&
+    if (hasComputableBounds(SE, StridesMap, Ptr) &&
         // When we run after a failing dependency check we have to make sure we
         // don't have wrapping pointers.
-        (!ShouldCheckStride || isStridedPtr(SE, DL, Ptr, TheLoop) == 1)) {
+        (!ShouldCheckStride ||
+         isStridedPtr(SE, DL, Ptr, TheLoop, StridesMap) == 1)) {
       // The id of the dependence set.
       unsigned DepId;
 
@@ -3323,7 +3638,7 @@ bool AccessAnalysis::canCheckPtrAtRT(
         // Each access has its own dependence set.
         DepId = RunningDepId++;
 
-      RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId);
+      RtCheck.insert(SE, TheLoop, Ptr, IsWrite, DepId, StridesMap);
 
       DEBUG(dbgs() << "LV: Found a runtime check ptr:" << *Ptr << '\n');
     } else {
@@ -3517,7 +3832,7 @@ public:
   ///
   /// Only checks sets with elements in \p CheckDeps.
   bool areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
-                   MemAccessInfoSet &CheckDeps);
+                   MemAccessInfoSet &CheckDeps, ValueToValueMap &Strides);
 
   /// \brief The maximum number of bytes of a vector register we can vectorize
   /// the accesses safely with.
@@ -3561,7 +3876,8 @@ private:
   /// distance is smaller than any other distance encountered so far).
   /// Otherwise, this function returns true signaling a possible dependence.
   bool isDependent(const MemAccessInfo &A, unsigned AIdx,
-                   const MemAccessInfo &B, unsigned BIdx);
+                   const MemAccessInfo &B, unsigned BIdx,
+                   ValueToValueMap &Strides);
 
   /// \brief Check whether the data dependence could prevent store-load
   /// forwarding.
@@ -3578,7 +3894,7 @@ static bool isInBoundsGep(Value *Ptr) {
 
 /// \brief Check whether the access through \p Ptr has a constant stride.
 static int isStridedPtr(ScalarEvolution *SE, DataLayout *DL, Value *Ptr,
-                        const Loop *Lp) {
+                        const Loop *Lp, ValueToValueMap &StridesMap) {
   const Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Unexpected non-ptr");
 
@@ -3590,7 +3906,8 @@ static int isStridedPtr(ScalarEvolution *SE, DataLayout *DL, Value *Ptr,
     return 0;
   }
 
-  const SCEV *PtrScev = SE->getSCEV(Ptr);
+  const SCEV *PtrScev = replaceSymbolicStrideSCEV(SE, StridesMap, Ptr);
+
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
   if (!AR) {
     DEBUG(dbgs() << "LV: Bad stride - Not an AddRecExpr pointer "
@@ -3694,7 +4011,8 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance,
 }
 
 bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
-                                   const MemAccessInfo &B, unsigned BIdx) {
+                                   const MemAccessInfo &B, unsigned BIdx,
+                                   ValueToValueMap &Strides) {
   assert (AIdx < BIdx && "Must pass arguments in program order");
 
   Value *APtr = A.getPointer();
@@ -3706,11 +4024,11 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
   if (!AIsWrite && !BIsWrite)
     return false;
 
-  const SCEV *AScev = SE->getSCEV(APtr);
-  const SCEV *BScev = SE->getSCEV(BPtr);
+  const SCEV *AScev = replaceSymbolicStrideSCEV(SE, Strides, APtr);
+  const SCEV *BScev = replaceSymbolicStrideSCEV(SE, Strides, BPtr);
 
-  int StrideAPtr = isStridedPtr(SE, DL, APtr, InnermostLoop);
-  int StrideBPtr = isStridedPtr(SE, DL, BPtr, InnermostLoop);
+  int StrideAPtr = isStridedPtr(SE, DL, APtr, InnermostLoop, Strides);
+  int StrideBPtr = isStridedPtr(SE, DL, BPtr, InnermostLoop, Strides);
 
   const SCEV *Src = AScev;
   const SCEV *Sink = BScev;
@@ -3815,9 +4133,9 @@ bool MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
   return false;
 }
 
-bool
-MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
-                              MemAccessInfoSet &CheckDeps) {
+bool MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
+                                   MemAccessInfoSet &CheckDeps,
+                                   ValueToValueMap &Strides) {
   MaxSafeDepDistBytes = -1U;
 
   while (!CheckDeps.empty()) {
@@ -3841,9 +4159,9 @@ MemoryDepChecker::areDepsSafe(AccessAnalysis::DepCandidates &AccessSets,
              I1E = Accesses[*AI].end(); I1 != I1E; ++I1)
           for (std::vector<unsigned>::iterator I2 = Accesses[*OI].begin(),
                I2E = Accesses[*OI].end(); I2 != I2E; ++I2) {
-            if (*I1 < *I2 && isDependent(*AI, *I1, *OI, *I2))
+            if (*I1 < *I2 && isDependent(*AI, *I1, *OI, *I2, Strides))
              return false;
-            if (*I2 < *I1 && isDependent(*OI, *I2, *AI, *I1))
+            if (*I2 < *I1 && isDependent(*OI, *I2, *AI, *I1, Strides))
              return false;
           }
         ++OI;
       }
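To see why the rewrite matters for dependence checking: a pointer recurrence
like {%B,+,(4 * %Stride2)} has no constant stride, so isStridedPtr would
report 0 and the analysis would have to bail out conservatively. With
%Stride2 rewritten to 1, both sides become ordinary unit-stride recurrences
and the existing constant-distance logic applies unchanged.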
@@ -3974,7 +4292,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
       // read a few words, modify, and write a few words, and some of the
       // words may be written to the same address.
       bool IsReadOnlyPtr = false;
-      if (Seen.insert(Ptr) || !isStridedPtr(SE, DL, Ptr, TheLoop)) {
+      if (Seen.insert(Ptr) || !isStridedPtr(SE, DL, Ptr, TheLoop, Strides)) {
        ++NumReads;
        IsReadOnlyPtr = true;
       }
@@ -3998,8 +4316,8 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   unsigned NumComparisons = 0;
   bool CanDoRT = false;
   if (NeedRTCheck)
-    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop);
-
+    CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE, TheLoop,
+                                       Strides);
 
   DEBUG(dbgs() << "LV: We need to do " << NumComparisons <<
         " pointer comparisons.\n");
@@ -4032,8 +4350,8 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
   bool CanVecMem = true;
   if (Accesses.isDependencyCheckNeeded()) {
     DEBUG(dbgs() << "LV: Checking memory dependencies\n");
-    CanVecMem = DepChecker.areDepsSafe(DependentAccesses,
-                                       Accesses.getDependenciesToCheck());
+    CanVecMem = DepChecker.areDepsSafe(
+        DependentAccesses, Accesses.getDependenciesToCheck(), Strides);
     MaxSafeDepDistBytes = DepChecker.getMaxSafeDepDistBytes();
 
     if (!CanVecMem && DepChecker.shouldRetryWithRuntimeCheck()) {
@@ -4047,7 +4365,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() {
       PtrRtCheck.Need = true;
 
       CanDoRT = Accesses.canCheckPtrAtRT(PtrRtCheck, NumComparisons, SE,
-                                         TheLoop, true);
+                                         TheLoop, Strides, true);
       // Check that we did not collect too many pointers or found an unsizeable
       // pointer.
       if (!CanDoRT || NumComparisons > RuntimeMemoryCheckThreshold) {
@@ -4867,6 +5185,12 @@ static bool isLikelyComplexAddressComputation(Value *Ptr,
   return StepVal > MaxMergeDistance;
 }
 
+static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {
+  if (Legal->hasStride(I->getOperand(0)) || Legal->hasStride(I->getOperand(1)))
+    return true;
+  return false;
+}
+
 unsigned
 LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
   // If we know that this instruction will remain uniform, check the cost of
@@ -4909,6 +5233,9 @@ LoopVectorizationCostModel::getInstructionCost(Instruction *I, unsigned VF) {
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
+    // Since we will replace the stride by 1 the multiplication should go away.
+    if (I->getOpcode() == Instruction::Mul && isStrideMul(I, Legal))
+      return 0;
     // Certain instructions can be cheaper to vectorize if they have a constant
     // second vector operand. One example of this is shifts on x86.
     TargetTransformInfo::OperandValueKind Op1VK =
@@ -5155,9 +5482,7 @@ void InnerLoopUnroller::scalarizeInstruction(Instruction *Instr) {
   }
 }
 
-void
-InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr,
-                                              LoopVectorizationLegality*) {
+void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) {
   return scalarizeInstruction(Instr);
 }
 
diff --git a/test/Transforms/LoopVectorize/runtime-check-readonly.ll b/test/Transforms/LoopVectorize/runtime-check-readonly.ll
index a2b9ad94c8..e7b1e2a6b7 100644
--- a/test/Transforms/LoopVectorize/runtime-check-readonly.ll
+++ b/test/Transforms/LoopVectorize/runtime-check-readonly.ll
@@ -7,11 +7,13 @@ target triple = "x86_64-apple-macosx10.8.0"
 ;CHECK: br
 ;CHECK: getelementptr
 ;CHECK-NEXT: getelementptr
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: icmp uge
-;CHECK-NEXT: and
+;CHECK-DAG: icmp uge
+;CHECK-DAG: icmp uge
+;CHECK-DAG: icmp uge
+;CHECK-DAG: icmp uge
+;CHECK-DAG: and
+;CHECK-DAG: and
+;CHECK: br
 ;CHECK: ret
 define void @add_ints(i32* nocapture %A, i32* nocapture %B, i32* nocapture %C) {
 entry:
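The new test below exercises three stride flavors: a 64-bit stride, a 32-bit
stride behind a zext'ed index, and a signed 64-bit multiply. It corresponds
roughly to this C-style source (a hypothetical reconstruction, not part of
the commit):

    // Each array is indexed by a runtime stride; the vectorizer must emit
    // 'icmp ne <stride>, 1' guards for AStride, BStride and CStride before
    // the <2 x i32> loads below can be formed.
    void test(int *A, long AStride, int *B, int BStride,
              int *C, long CStride, int N) {
      for (int i = 0; i < N; i++)
        A[i * AStride] = B[i * BStride] * C[i * CStride];
    }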
diff --git a/test/Transforms/LoopVectorize/version-mem-access.ll b/test/Transforms/LoopVectorize/version-mem-access.ll
new file mode 100644
index 0000000000..e712728111
--- /dev/null
+++ b/test/Transforms/LoopVectorize/version-mem-access.ll
@@ -0,0 +1,50 @@
+; RUN: opt -basicaa -loop-vectorize -enable-mem-access-versioning -force-vector-width=2 -force-vector-unroll=1 < %s -S | FileCheck %s
+
+target datalayout = "e-m:o-i64:64-f80:128-n8:16:32:64-S128"
+
+; CHECK-LABEL: test
+define void @test(i32* noalias %A, i64 %AStride,
+                  i32* noalias %B, i32 %BStride,
+                  i32* noalias %C, i64 %CStride, i32 %N) {
+entry:
+  %cmp13 = icmp eq i32 %N, 0
+  br i1 %cmp13, label %for.end, label %for.body.preheader
+
+; CHECK-DAG: icmp ne i64 %AStride, 1
+; CHECK-DAG: icmp ne i32 %BStride, 1
+; CHECK-DAG: icmp ne i64 %CStride, 1
+; CHECK: or
+; CHECK: or
+; CHECK: br
+
+; CHECK: vector.body
+; CHECK: load <2 x i32>
+
+for.body.preheader:
+  br label %for.body
+
+for.body:
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %iv.trunc = trunc i64 %indvars.iv to i32
+  %mul = mul i32 %iv.trunc, %BStride
+  %mul64 = zext i32 %mul to i64
+  %arrayidx = getelementptr inbounds i32* %B, i64 %mul64
+  %0 = load i32* %arrayidx, align 4
+  %mul2 = mul nsw i64 %indvars.iv, %CStride
+  %arrayidx3 = getelementptr inbounds i32* %C, i64 %mul2
+  %1 = load i32* %arrayidx3, align 4
+  %mul4 = mul nsw i32 %1, %0
+  %mul3 = mul nsw i64 %indvars.iv, %AStride
+  %arrayidx7 = getelementptr inbounds i32* %A, i64 %mul3
+  store i32 %mul4, i32* %arrayidx7, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, %N
+  br i1 %exitcond, label %for.end.loopexit, label %for.body
+
+for.end.loopexit:
+  br label %for.end
+
+for.end:
+  ret void
+}
--