//===- LoopVR.cpp - Value Range analysis driven by loop information -------===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// // // FIXME: What does this do? // //===----------------------------------------------------------------------===// #define DEBUG_TYPE "loopvr" #include "llvm/Analysis/LoopVR.h" #include "llvm/Constants.h" #include "llvm/Instructions.h" #include "llvm/LLVMContext.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Assembly/Writer.h" #include "llvm/Support/CFG.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; char LoopVR::ID = 0; static RegisterPass X("loopvr", "Loop Value Ranges", false, true); /// getRange - determine the range for a particular SCEV within a given Loop ConstantRange LoopVR::getRange(const SCEV *S, Loop *L, ScalarEvolution &SE) { const SCEV *T = SE.getBackedgeTakenCount(L); if (isa(T)) return ConstantRange(cast(S->getType())->getBitWidth(), true); T = SE.getTruncateOrZeroExtend(T, S->getType()); return getRange(S, T, SE); } /// getRange - determine the range for a particular SCEV with a given trip count ConstantRange LoopVR::getRange(const SCEV *S, const SCEV *T, ScalarEvolution &SE){ if (const SCEVConstant *C = dyn_cast(S)) return ConstantRange(C->getValue()->getValue()); ConstantRange FullSet(cast(S->getType())->getBitWidth(), true); // {x,+,y,+,...z}. We detect overflow by checking the size of the set after // summing the upper and lower. if (const SCEVAddExpr *Add = dyn_cast(S)) { ConstantRange X = getRange(Add->getOperand(0), T, SE); if (X.isFullSet()) return FullSet; for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) { ConstantRange Y = getRange(Add->getOperand(i), T, SE); if (Y.isFullSet()) return FullSet; APInt Spread_X = X.getSetSize(), Spread_Y = Y.getSetSize(); APInt NewLower = X.getLower() + Y.getLower(); APInt NewUpper = X.getUpper() + Y.getUpper() - 1; if (NewLower == NewUpper) return FullSet; X = ConstantRange(NewLower, NewUpper); if (X.getSetSize().ult(Spread_X) || X.getSetSize().ult(Spread_Y)) return FullSet; // we've wrapped, therefore, full set. } return X; } // {x,*,y,*,...,z}. In order to detect overflow, we use k*bitwidth where // k is the number of terms being multiplied. if (const SCEVMulExpr *Mul = dyn_cast(S)) { ConstantRange X = getRange(Mul->getOperand(0), T, SE); if (X.isFullSet()) return FullSet; const IntegerType *Ty = IntegerType::get(SE.getContext(), X.getBitWidth()); const IntegerType *ExTy = IntegerType::get(SE.getContext(), X.getBitWidth() * Mul->getNumOperands()); ConstantRange XExt = X.zeroExtend(ExTy->getBitWidth()); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) { ConstantRange Y = getRange(Mul->getOperand(i), T, SE); if (Y.isFullSet()) return FullSet; ConstantRange YExt = Y.zeroExtend(ExTy->getBitWidth()); XExt = ConstantRange(XExt.getLower() * YExt.getLower(), ((XExt.getUpper()-1) * (YExt.getUpper()-1)) + 1); } return XExt.truncate(Ty->getBitWidth()); } // X smax Y smax ... Z is: range(smax(X_smin, Y_smin, ..., Z_smin), // smax(X_smax, Y_smax, ..., Z_smax)) // It doesn't matter if one of the SCEVs has FullSet because we're taking // a maximum of the minimums across all of them. if (const SCEVSMaxExpr *SMax = dyn_cast(S)) { ConstantRange X = getRange(SMax->getOperand(0), T, SE); if (X.isFullSet()) return FullSet; APInt smin = X.getSignedMin(), smax = X.getSignedMax(); for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i) { ConstantRange Y = getRange(SMax->getOperand(i), T, SE); smin = APIntOps::smax(smin, Y.getSignedMin()); smax = APIntOps::smax(smax, Y.getSignedMax()); } if (smax + 1 == smin) return FullSet; return ConstantRange(smin, smax + 1); } // X umax Y umax ... Z is: range(umax(X_umin, Y_umin, ..., Z_umin), // umax(X_umax, Y_umax, ..., Z_umax)) // It doesn't matter if one of the SCEVs has FullSet because we're taking // a maximum of the minimums across all of them. if (const SCEVUMaxExpr *UMax = dyn_cast(S)) { ConstantRange X = getRange(UMax->getOperand(0), T, SE); if (X.isFullSet()) return FullSet; APInt umin = X.getUnsignedMin(), umax = X.getUnsignedMax(); for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i) { ConstantRange Y = getRange(UMax->getOperand(i), T, SE); umin = APIntOps::umax(umin, Y.getUnsignedMin()); umax = APIntOps::umax(umax, Y.getUnsignedMax()); } if (umax + 1 == umin) return FullSet; return ConstantRange(umin, umax + 1); } // L udiv R. Luckily, there's only ever 2 sides to a udiv. if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { ConstantRange L = getRange(UDiv->getLHS(), T, SE); ConstantRange R = getRange(UDiv->getRHS(), T, SE); if (L.isFullSet() && R.isFullSet()) return FullSet; if (R.getUnsignedMax() == 0) { // RHS must be single-element zero. Return an empty set. return ConstantRange(R.getBitWidth(), false); } APInt Lower = L.getUnsignedMin().udiv(R.getUnsignedMax()); APInt Upper; if (R.getUnsignedMin() == 0) { // Just because it contains zero, doesn't mean it will also contain one. ConstantRange NotZero(APInt(L.getBitWidth(), 1), APInt::getNullValue(L.getBitWidth())); R = R.intersectWith(NotZero); } // But, the intersection might still include zero. If it does, then we know // it also included one. if (R.contains(APInt::getNullValue(L.getBitWidth()))) Upper = L.getUnsignedMax(); else Upper = L.getUnsignedMax().udiv(R.getUnsignedMin()); return ConstantRange(Lower, Upper); } // ConstantRange already implements the cast operators. if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { T = SE.getTruncateOrZeroExtend(T, ZExt->getOperand()->getType()); ConstantRange X = getRange(ZExt->getOperand(), T, SE); return X.zeroExtend(cast(ZExt->getType())->getBitWidth()); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { T = SE.getTruncateOrZeroExtend(T, SExt->getOperand()->getType()); ConstantRange X = getRange(SExt->getOperand(), T, SE); return X.signExtend(cast(SExt->getType())->getBitWidth()); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { T = SE.getTruncateOrZeroExtend(T, Trunc->getOperand()->getType()); ConstantRange X = getRange(Trunc->getOperand(), T, SE); if (X.isFullSet()) return FullSet; return X.truncate(cast(Trunc->getType())->getBitWidth()); } if (const SCEVAddRecExpr *AddRec = dyn_cast(S)) { const SCEVConstant *Trip = dyn_cast(T); if (!Trip) return FullSet; if (AddRec->isAffine()) { const SCEV *StartHandle = AddRec->getStart(); const SCEV *StepHandle = AddRec->getOperand(1); const SCEVConstant *Step = dyn_cast(StepHandle); if (!Step) return FullSet; uint32_t ExWidth = 2 * Trip->getValue()->getBitWidth(); APInt TripExt = Trip->getValue()->getValue(); TripExt.zext(ExWidth); APInt StepExt = Step->getValue()->getValue(); StepExt.zext(ExWidth); if ((TripExt * StepExt).ugt(APInt::getLowBitsSet(ExWidth, ExWidth >> 1))) return FullSet; const SCEV *EndHandle = SE.getAddExpr(StartHandle, SE.getMulExpr(T, StepHandle)); const SCEVConstant *Start = dyn_cast(StartHandle); const SCEVConstant *End = dyn_cast(EndHandle); if (!Start || !End) return FullSet; const APInt &StartInt = Start->getValue()->getValue(); const APInt &EndInt = End->getValue()->getValue(); const APInt &StepInt = Step->getValue()->getValue(); if (StepInt.isNegative()) { if (EndInt == StartInt + 1) return FullSet; return ConstantRange(EndInt, StartInt + 1); } else { if (StartInt == EndInt + 1) return FullSet; return ConstantRange(StartInt, EndInt + 1); } } } // TODO: non-affine addrec, udiv, SCEVUnknown (narrowed from elsewhere)? return FullSet; } void LoopVR::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequiredTransitive(); AU.addRequiredTransitive(); AU.setPreservesAll(); } bool LoopVR::runOnFunction(Function &F) { Map.clear(); return false; } void LoopVR::print(raw_ostream &OS, const Module *) const { for (std::map::const_iterator I = Map.begin(), E = Map.end(); I != E; ++I) { OS << *I->first << ": " << *I->second << '\n'; } } void LoopVR::releaseMemory() { for (std::map::iterator I = Map.begin(), E = Map.end(); I != E; ++I) { delete I->second; } Map.clear(); } ConstantRange LoopVR::compute(Value *V) { if (ConstantInt *CI = dyn_cast(V)) return ConstantRange(CI->getValue()); Instruction *I = dyn_cast(V); if (!I) return ConstantRange(cast(V->getType())->getBitWidth(), false); LoopInfo &LI = getAnalysis(); Loop *L = LI.getLoopFor(I->getParent()); if (!L || L->isLoopInvariant(I)) return ConstantRange(cast(V->getType())->getBitWidth(), false); ScalarEvolution &SE = getAnalysis(); const SCEV *S = SE.getSCEV(I); if (isa(S) || isa(S)) return ConstantRange(cast(V->getType())->getBitWidth(), false); return ConstantRange(getRange(S, L, SE)); } ConstantRange LoopVR::get(Value *V) { std::map::iterator I = Map.find(V); if (I == Map.end()) { ConstantRange *CR = new ConstantRange(compute(V)); Map[V] = CR; return *CR; } return *I->second; } void LoopVR::remove(Value *V) { std::map::iterator I = Map.find(V); if (I != Map.end()) { delete I->second; Map.erase(I); } } void LoopVR::narrow(Value *V, const ConstantRange &CR) { if (CR.isFullSet()) return; std::map::iterator I = Map.find(V); if (I == Map.end()) Map[V] = new ConstantRange(CR); else Map[V] = new ConstantRange(Map[V]->intersectWith(CR)); }