From 73f5d42a9722eef5127cdf3e5967c8518dbad812 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 7 Apr 2007 21:17:51 +0000 Subject: Fix problems in the sprintf optimizer git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@35754 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/IPO/SimplifyLibCalls.cpp | 141 ++++++++++++++------------------ 1 file changed, 60 insertions(+), 81 deletions(-) (limited to 'lib/Transforms/IPO/SimplifyLibCalls.cpp') diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp index 6b55eb6759..662dbb7e63 100644 --- a/lib/Transforms/IPO/SimplifyLibCalls.cpp +++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp @@ -1276,8 +1276,7 @@ public: if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4) return false; - // All the optimizations depend on the length of the second argument and the - // fact that it is a constant string array. Check that now + // All the optimizations depend on the format string. uint64_t FormatLen, FormatStartIdx; ConstantArray *CA = 0; if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx)) @@ -1368,108 +1367,88 @@ public: SPrintFOptimization() : LibCallOptimization("sprintf", "Number of 'sprintf' calls simplified") {} - /// @brief Make sure that the "fprintf" function has the right prototype - virtual bool ValidateCalledFunction(const Function *f, SimplifyLibCalls &SLC){ - // Just make sure this has at least 2 arguments - return (f->getReturnType() == Type::Int32Ty && f->arg_size() >= 2); + /// @brief Make sure that the "sprintf" function has the right prototype + virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){ + const FunctionType *FT = F->getFunctionType(); + return FT->getNumParams() == 2 && // two fixed arguments. + FT->getParamType(1) == PointerType::get(Type::Int8Ty) && + FT->getParamType(0) == FT->getParamType(1) && + isa(FT->getReturnType()); } /// @brief Perform the sprintf optimization. - virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) { + virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) { // If the call has more than 3 operands, we can't optimize it - if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3) + if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4) return false; - // All the optimizations depend on the length of the second argument and the - // fact that it is a constant string array. Check that now - uint64_t len, StartIdx; - ConstantArray* CA = 0; - if (!GetConstantStringInfo(ci->getOperand(2), CA, len, StartIdx)) + uint64_t FormatLen, FormatStartIdx; + ConstantArray *CA = 0; + if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx)) return false; - - if (ci->getNumOperands() == 3) { - if (len == 0) { - // If the length is 0, we just need to store a null byte - new StoreInst(ConstantInt::get(Type::Int8Ty,0),ci->getOperand(1),ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,0)); - } - + + if (CI->getNumOperands() == 3) { + if (!CA->isCString()) return false; + // Make sure there's no % in the constant array - for (unsigned i = 0; i < len; ++i) { - if (ConstantInt* CI = dyn_cast(CA->getOperand(i))) { - // Check for the null terminator - if (CI->getZExtValue() == '%') - return false; // we found a %, can't optimize - } else { - return false; // initializer is not constant int, can't optimize - } - } - - // Increment length because we want to copy the null byte too - len++; - + std::string S = CA->getAsString(); + for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i) + if (S[i] == '%') + return false; // we found a format specifier + // sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1) - Value *args[4] = { - ci->getOperand(1), - ci->getOperand(2), - ConstantInt::get(SLC.getIntPtrType(),len), + Value *MemCpyArgs[] = { + CI->getOperand(1), CI->getOperand(2), + ConstantInt::get(SLC.getIntPtrType(), FormatLen+1), // Copy the nul byte ConstantInt::get(Type::Int32Ty, 1) }; - new CallInst(SLC.get_memcpy(), args, 4, "", ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len)); + new CallInst(SLC.get_memcpy(), MemCpyArgs, 4, "", CI); + return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen)); } - // The remaining optimizations require the format string to be length 2 - // "%s" or "%c". - if (len != 2) + // The remaining optimizations require the format string to be "%s" or "%c". + if (FormatLen != 2 || + cast(CA->getOperand(FormatStartIdx))->getZExtValue() !='%') return false; - // The first character has to be a % - if (ConstantInt* CI = dyn_cast(CA->getOperand(0))) - if (CI->getZExtValue() != '%') - return false; - // Get the second character and switch on its value - ConstantInt* CI = dyn_cast(CA->getOperand(1)); - switch (CI->getZExtValue()) { + switch (cast(CA->getOperand(1))->getZExtValue()) { + case 'c': { + // sprintf(dest,"%c",chr) -> store chr, dest + Value *V = CastInst::createTruncOrBitCast(CI->getOperand(3), + Type::Int8Ty, "char", CI); + new StoreInst(V, CI->getOperand(1), CI); + Value *Ptr = new GetElementPtrInst(CI->getOperand(1), + ConstantInt::get(Type::Int32Ty, 1), + CI->getOperand(1)->getName()+".end", + CI); + new StoreInst(ConstantInt::get(Type::Int8Ty,0), Ptr, CI); + return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, 1)); + } case 's': { // sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1) Value *Len = new CallInst(SLC.get_strlen(), - CastToCStr(ci->getOperand(3), ci), - ci->getOperand(3)->getName()+".len", ci); - Value *Len1 = BinaryOperator::createAdd(Len, - ConstantInt::get(Len->getType(), 1), - Len->getName()+"1", ci); - if (Len1->getType() != SLC.getIntPtrType()) - Len1 = CastInst::createIntegerCast(Len1, SLC.getIntPtrType(), false, - Len1->getName(), ci); - Value *args[4] = { - CastToCStr(ci->getOperand(1), ci), - CastToCStr(ci->getOperand(3), ci), - Len1, - ConstantInt::get(Type::Int32Ty,1) + CastToCStr(CI->getOperand(3), CI), + CI->getOperand(3)->getName()+".len", CI); + Value *UnincLen = Len; + Len = BinaryOperator::createAdd(Len, ConstantInt::get(Len->getType(), 1), + Len->getName()+"1", CI); + Value *MemcpyArgs[4] = { + CI->getOperand(1), + CastToCStr(CI->getOperand(3), CI), + Len, + ConstantInt::get(Type::Int32Ty, 1) }; - new CallInst(SLC.get_memcpy(), args, 4, "", ci); + new CallInst(SLC.get_memcpy(), MemcpyArgs, 4, "", CI); // The strlen result is the unincremented number of bytes in the string. - if (!ci->use_empty()) { - if (Len->getType() != ci->getType()) - Len = CastInst::createIntegerCast(Len, ci->getType(), false, - Len->getName(), ci); - ci->replaceAllUsesWith(Len); + if (!CI->use_empty()) { + if (UnincLen->getType() != CI->getType()) + UnincLen = CastInst::createIntegerCast(UnincLen, CI->getType(), false, + Len->getName(), CI); + CI->replaceAllUsesWith(UnincLen); } - return ReplaceCallWith(ci, 0); - } - case 'c': { - // sprintf(dest,"%c",chr) -> store chr, dest - CastInst* cast = CastInst::createTruncOrBitCast( - ci->getOperand(3), Type::Int8Ty, "char", ci); - new StoreInst(cast, ci->getOperand(1), ci); - GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1), - ConstantInt::get(Type::Int32Ty,1),ci->getOperand(1)->getName()+".end", - ci); - new StoreInst(ConstantInt::get(Type::Int8Ty,0),gep,ci); - return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, 1)); + return ReplaceCallWith(CI, 0); } } return false; -- cgit v1.2.3