summaryrefslogtreecommitdiff
path: root/lib/Transforms/IPO/SimplifyLibCalls.cpp
diff options
context:
space:
mode:
authorChris Lattner <sabre@nondot.org>2007-04-07 21:17:51 +0000
committerChris Lattner <sabre@nondot.org>2007-04-07 21:17:51 +0000
commit73f5d42a9722eef5127cdf3e5967c8518dbad812 (patch)
treef525c56cfbb5dedfa961f4ddd8368229946c2542 /lib/Transforms/IPO/SimplifyLibCalls.cpp
parent3492cda48f85ca9048824baff18e24a19edaf7ed (diff)
downloadllvm-73f5d42a9722eef5127cdf3e5967c8518dbad812.tar.gz
llvm-73f5d42a9722eef5127cdf3e5967c8518dbad812.tar.bz2
llvm-73f5d42a9722eef5127cdf3e5967c8518dbad812.tar.xz
Fix problems in the sprintf optimizer
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@35754 91177308-0d34-0410-b5e6-96231b3b80d8
Diffstat (limited to 'lib/Transforms/IPO/SimplifyLibCalls.cpp')
-rw-r--r--lib/Transforms/IPO/SimplifyLibCalls.cpp141
1 files changed, 60 insertions, 81 deletions
diff --git a/lib/Transforms/IPO/SimplifyLibCalls.cpp b/lib/Transforms/IPO/SimplifyLibCalls.cpp
index 6b55eb6759..662dbb7e63 100644
--- a/lib/Transforms/IPO/SimplifyLibCalls.cpp
+++ b/lib/Transforms/IPO/SimplifyLibCalls.cpp
@@ -1276,8 +1276,7 @@ public:
if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
return false;
- // All the optimizations depend on the length of the second argument and the
- // fact that it is a constant string array. Check that now
+ // All the optimizations depend on the format string.
uint64_t FormatLen, FormatStartIdx;
ConstantArray *CA = 0;
if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
@@ -1368,108 +1367,88 @@ public:
SPrintFOptimization() : LibCallOptimization("sprintf",
"Number of 'sprintf' calls simplified") {}
- /// @brief Make sure that the "fprintf" function has the right prototype
- virtual bool ValidateCalledFunction(const Function *f, SimplifyLibCalls &SLC){
- // Just make sure this has at least 2 arguments
- return (f->getReturnType() == Type::Int32Ty && f->arg_size() >= 2);
+ /// @brief Make sure that the "sprintf" function has the right prototype
+ virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+ const FunctionType *FT = F->getFunctionType();
+ return FT->getNumParams() == 2 && // two fixed arguments.
+ FT->getParamType(1) == PointerType::get(Type::Int8Ty) &&
+ FT->getParamType(0) == FT->getParamType(1) &&
+ isa<IntegerType>(FT->getReturnType());
}
/// @brief Perform the sprintf optimization.
- virtual bool OptimizeCall(CallInst *ci, SimplifyLibCalls &SLC) {
+ virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
// If the call has more than 3 operands, we can't optimize it
- if (ci->getNumOperands() > 4 || ci->getNumOperands() < 3)
+ if (CI->getNumOperands() != 3 && CI->getNumOperands() != 4)
return false;
- // All the optimizations depend on the length of the second argument and the
- // fact that it is a constant string array. Check that now
- uint64_t len, StartIdx;
- ConstantArray* CA = 0;
- if (!GetConstantStringInfo(ci->getOperand(2), CA, len, StartIdx))
+ uint64_t FormatLen, FormatStartIdx;
+ ConstantArray *CA = 0;
+ if (!GetConstantStringInfo(CI->getOperand(2), CA, FormatLen,FormatStartIdx))
return false;
-
- if (ci->getNumOperands() == 3) {
- if (len == 0) {
- // If the length is 0, we just need to store a null byte
- new StoreInst(ConstantInt::get(Type::Int8Ty,0),ci->getOperand(1),ci);
- return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,0));
- }
-
+
+ if (CI->getNumOperands() == 3) {
+ if (!CA->isCString()) return false;
+
// Make sure there's no % in the constant array
- for (unsigned i = 0; i < len; ++i) {
- if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(i))) {
- // Check for the null terminator
- if (CI->getZExtValue() == '%')
- return false; // we found a %, can't optimize
- } else {
- return false; // initializer is not constant int, can't optimize
- }
- }
-
- // Increment length because we want to copy the null byte too
- len++;
-
+ std::string S = CA->getAsString();
+ for (unsigned i = FormatStartIdx, e = S.size(); i != e; ++i)
+ if (S[i] == '%')
+ return false; // we found a format specifier
+
// sprintf(str,fmt) -> llvm.memcpy(str,fmt,strlen(fmt),1)
- Value *args[4] = {
- ci->getOperand(1),
- ci->getOperand(2),
- ConstantInt::get(SLC.getIntPtrType(),len),
+ Value *MemCpyArgs[] = {
+ CI->getOperand(1), CI->getOperand(2),
+ ConstantInt::get(SLC.getIntPtrType(), FormatLen+1), // Copy the nul byte
ConstantInt::get(Type::Int32Ty, 1)
};
- new CallInst(SLC.get_memcpy(), args, 4, "", ci);
- return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty,len));
+ new CallInst(SLC.get_memcpy(), MemCpyArgs, 4, "", CI);
+ return ReplaceCallWith(CI, ConstantInt::get(CI->getType(), FormatLen));
}
- // The remaining optimizations require the format string to be length 2
- // "%s" or "%c".
- if (len != 2)
+ // The remaining optimizations require the format string to be "%s" or "%c".
+ if (FormatLen != 2 ||
+ cast<ConstantInt>(CA->getOperand(FormatStartIdx))->getZExtValue() !='%')
return false;
- // The first character has to be a %
- if (ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(0)))
- if (CI->getZExtValue() != '%')
- return false;
-
// Get the second character and switch on its value
- ConstantInt* CI = dyn_cast<ConstantInt>(CA->getOperand(1));
- switch (CI->getZExtValue()) {
+ switch (cast<ConstantInt>(CA->getOperand(1))->getZExtValue()) {
+ case 'c': {
+ // sprintf(dest,"%c",chr) -> store chr, dest
+ Value *V = CastInst::createTruncOrBitCast(CI->getOperand(3),
+ Type::Int8Ty, "char", CI);
+ new StoreInst(V, CI->getOperand(1), CI);
+ Value *Ptr = new GetElementPtrInst(CI->getOperand(1),
+ ConstantInt::get(Type::Int32Ty, 1),
+ CI->getOperand(1)->getName()+".end",
+ CI);
+ new StoreInst(ConstantInt::get(Type::Int8Ty,0), Ptr, CI);
+ return ReplaceCallWith(CI, ConstantInt::get(Type::Int32Ty, 1));
+ }
case 's': {
// sprintf(dest,"%s",str) -> llvm.memcpy(dest, str, strlen(str)+1, 1)
Value *Len = new CallInst(SLC.get_strlen(),
- CastToCStr(ci->getOperand(3), ci),
- ci->getOperand(3)->getName()+".len", ci);
- Value *Len1 = BinaryOperator::createAdd(Len,
- ConstantInt::get(Len->getType(), 1),
- Len->getName()+"1", ci);
- if (Len1->getType() != SLC.getIntPtrType())
- Len1 = CastInst::createIntegerCast(Len1, SLC.getIntPtrType(), false,
- Len1->getName(), ci);
- Value *args[4] = {
- CastToCStr(ci->getOperand(1), ci),
- CastToCStr(ci->getOperand(3), ci),
- Len1,
- ConstantInt::get(Type::Int32Ty,1)
+ CastToCStr(CI->getOperand(3), CI),
+ CI->getOperand(3)->getName()+".len", CI);
+ Value *UnincLen = Len;
+ Len = BinaryOperator::createAdd(Len, ConstantInt::get(Len->getType(), 1),
+ Len->getName()+"1", CI);
+ Value *MemcpyArgs[4] = {
+ CI->getOperand(1),
+ CastToCStr(CI->getOperand(3), CI),
+ Len,
+ ConstantInt::get(Type::Int32Ty, 1)
};
- new CallInst(SLC.get_memcpy(), args, 4, "", ci);
+ new CallInst(SLC.get_memcpy(), MemcpyArgs, 4, "", CI);
// The strlen result is the unincremented number of bytes in the string.
- if (!ci->use_empty()) {
- if (Len->getType() != ci->getType())
- Len = CastInst::createIntegerCast(Len, ci->getType(), false,
- Len->getName(), ci);
- ci->replaceAllUsesWith(Len);
+ if (!CI->use_empty()) {
+ if (UnincLen->getType() != CI->getType())
+ UnincLen = CastInst::createIntegerCast(UnincLen, CI->getType(), false,
+ Len->getName(), CI);
+ CI->replaceAllUsesWith(UnincLen);
}
- return ReplaceCallWith(ci, 0);
- }
- case 'c': {
- // sprintf(dest,"%c",chr) -> store chr, dest
- CastInst* cast = CastInst::createTruncOrBitCast(
- ci->getOperand(3), Type::Int8Ty, "char", ci);
- new StoreInst(cast, ci->getOperand(1), ci);
- GetElementPtrInst* gep = new GetElementPtrInst(ci->getOperand(1),
- ConstantInt::get(Type::Int32Ty,1),ci->getOperand(1)->getName()+".end",
- ci);
- new StoreInst(ConstantInt::get(Type::Int8Ty,0),gep,ci);
- return ReplaceCallWith(ci, ConstantInt::get(Type::Int32Ty, 1));
+ return ReplaceCallWith(CI, 0);
}
}
return false;