From 710c1a449dd7bee747ecf9c344a6f6d5461a158d Mon Sep 17 00:00:00 2001 From: Reid Kleckner Date: Thu, 24 Apr 2014 20:14:34 +0000 Subject: Add 'musttail' marker to call instructions This is similar to the 'tail' marker, except that it guarantees that tail call optimization will occur. It also comes with convervative IR verification rules that ensure that tail call optimization is possible. Reviewers: nicholas Differential Revision: http://llvm-reviews.chandlerc.com/D3240 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@207143 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/IR/Verifier.cpp | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) (limited to 'lib/IR/Verifier.cpp') diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp index e7c67c7e4f..43534385f3 100644 --- a/lib/IR/Verifier.cpp +++ b/lib/IR/Verifier.cpp @@ -301,6 +301,7 @@ private: void visitLandingPadInst(LandingPadInst &LPI); void VerifyCallSite(CallSite CS); + void verifyMustTailCall(CallInst &CI); bool PerformTypeCheck(Intrinsic::ID ID, Function *F, Type *Ty, int VT, unsigned ArgNo, std::string &Suffix); bool VerifyIntrinsicType(Type *Ty, ArrayRef &Infos, @@ -1545,9 +1546,97 @@ void Verifier::VerifyCallSite(CallSite CS) { visitInstruction(*I); } +/// Two types are "congruent" if they are identical, or if they are both pointer +/// types with different pointee types and the same address space. +static bool isTypeCongruent(Type *L, Type *R) { + if (L == R) + return true; + PointerType *PL = dyn_cast(L); + PointerType *PR = dyn_cast(R); + if (!PL || !PR) + return false; + return PL->getAddressSpace() == PR->getAddressSpace(); +} + +void Verifier::verifyMustTailCall(CallInst &CI) { + Assert1(!CI.isInlineAsm(), "cannot use musttail call with inline asm", &CI); + + // - The caller and callee prototypes must match. Pointer types of + // parameters or return types may differ in pointee type, but not + // address space. + Function *F = CI.getParent()->getParent(); + auto GetFnTy = [](Value *V) { + return cast( + cast(V->getType())->getElementType()); + }; + FunctionType *CallerTy = GetFnTy(F); + FunctionType *CalleeTy = GetFnTy(CI.getCalledValue()); + Assert1(CallerTy->getNumParams() == CalleeTy->getNumParams(), + "cannot guarantee tail call due to mismatched parameter counts", &CI); + Assert1(CallerTy->isVarArg() == CalleeTy->isVarArg(), + "cannot guarantee tail call due to mismatched varargs", &CI); + Assert1(isTypeCongruent(CallerTy->getReturnType(), CalleeTy->getReturnType()), + "cannot guarantee tail call due to mismatched return types", &CI); + for (int I = 0, E = CallerTy->getNumParams(); I != E; ++I) { + Assert1( + isTypeCongruent(CallerTy->getParamType(I), CalleeTy->getParamType(I)), + "cannot guarantee tail call due to mismatched parameter types", &CI); + } + + // - The calling conventions of the caller and callee must match. + Assert1(F->getCallingConv() == CI.getCallingConv(), + "cannot guarantee tail call due to mismatched calling conv", &CI); + + // - All ABI-impacting function attributes, such as sret, byval, inreg, + // returned, and inalloca, must match. + static const Attribute::AttrKind ABIAttrs[] = { + Attribute::Alignment, Attribute::StructRet, Attribute::ByVal, + Attribute::InAlloca, Attribute::InReg, Attribute::Returned}; + AttributeSet CallerAttrs = F->getAttributes(); + AttributeSet CalleeAttrs = CI.getAttributes(); + for (int I = 0, E = CallerTy->getNumParams(); I != E; ++I) { + AttrBuilder CallerABIAttrs; + AttrBuilder CalleeABIAttrs; + for (auto AK : ABIAttrs) { + if (CallerAttrs.hasAttribute(I + 1, AK)) + CallerABIAttrs.addAttribute(AK); + if (CalleeAttrs.hasAttribute(I + 1, AK)) + CalleeABIAttrs.addAttribute(AK); + } + Assert2(CallerABIAttrs == CalleeABIAttrs, + "cannot guarantee tail call due to mismatched ABI impacting " + "function attributes", &CI, CI.getOperand(I)); + } + + // - The call must immediately precede a :ref:`ret ` instruction, + // or a pointer bitcast followed by a ret instruction. + // - The ret instruction must return the (possibly bitcasted) value + // produced by the call or void. + Value *RetVal = &CI; + Instruction *Next = CI.getNextNode(); + + // Handle the optional bitcast. + if (BitCastInst *BI = dyn_cast_or_null(Next)) { + Assert1(BI->getOperand(0) == RetVal, + "bitcast following musttail call must use the call", BI); + RetVal = BI; + Next = BI->getNextNode(); + } + + // Check the return. + ReturnInst *Ret = dyn_cast_or_null(Next); + Assert1(Ret, "musttail call must be precede a ret with an optional bitcast", + &CI); + Assert1(!Ret->getReturnValue() || Ret->getReturnValue() == RetVal, + "musttail call result must be returned", Ret); +} + void Verifier::visitCallInst(CallInst &CI) { VerifyCallSite(&CI); + if (CI.isMustTailCall()) + verifyMustTailCall(CI); + if (Function *F = CI.getCalledFunction()) if (Intrinsic::ID ID = (Intrinsic::ID)F->getIntrinsicID()) visitIntrinsicFunctionCall(ID, CI); -- cgit v1.2.3