diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h index 1b2482a2363de9ef0af19107937cdb3cc0e72e4c..dedf8482aab8fa06e785524d08e3e1bd13655c1f 100644 --- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -190,6 +190,10 @@ private: Value *optimizePow(CallInst *CI, IRBuilderBase &B); Value *replacePowWithExp(CallInst *Pow, IRBuilderBase &B); Value *replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedPowAndSqrtWithPow(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedSqrtAndPowWithPow(CallInst *Sqrt,IRBuilderBase &B); + Value *optimizeCbrt(CallInst *CI, IRBuilderBase &B); Value *optimizeExp2(CallInst *CI, IRBuilderBase &B); Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B); Value *optimizeLog(CallInst *CI, IRBuilderBase &B); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 245f2d4e442a435cedaf1f596f37e3637a968203..5f349f6141588bf155c7defd02fe3ec6ac2d303c 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1971,6 +1971,331 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { return Sqrt; } +// pow(sqrt(x),y) -> pow(x,y*0.5) +Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, + IRBuilderBase &B) { + Value *Base = nullptr, *y = nullptr, *NewPow = nullptr; + Base = Pow->getArgOperand(0); + y = Pow->getArgOperand(1); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + CallInst *BaseFn = dyn_cast(Base); + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()) { + Function *CalleeFn = BaseFn->getCalledFunction(); + LibFunc LibFn; + + // If Pow is an intrinsic call, and + // its first argument is an intrinsic call to Sqrt + if (IntrinsicInst *II = dyn_cast(Pow)) { + if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::sqrt) { + Value *x = BaseFn->getOperand(0); + // Create a new node y * 0.5. + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), + {x, y_05}); + } + } + // If it is a library function call + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)) { + LibFunc floatFn, doubleFn, longDFn; + switch (LibFn) { + case LibFunc_sqrtf: + case LibFunc_sqrt: + case LibFunc_sqrtl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_sqrtf_finite: + case LibFunc_sqrt_finite: + case LibFunc_sqrtl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + Value *x = BaseFn->getOperand(0); + NewPow = emitBinaryFloatFnCall( + x, B.CreateFMul(y, ConstantFP::get(Ty, 0.5)), TLI, doubleFn, floatFn, + longDFn, B, CalleeFn->getAttributes()); + } + if (NewPow) { + Pow->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} + +//pow(pow(x,y),z)-> pow(x,y*z) +Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B){ + Value *Base=nullptr,*z=nullptr,*NewPow=nullptr; + Base = Pow->getArgOperand(0); + z = Pow->getArgOperand(1); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + CallInst *BaseFn = dyn_cast(Base); + + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()){ + Function *CalleeFn = BaseFn->getCalledFunction(); + LibFunc LibFn; + // If Pow is an intrinsic call and + // its first argument is also an intrinsic call to pow + if (IntrinsicInst *II = dyn_cast(Pow)) { + if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::pow) { + Value *x = BaseFn->getOperand(0); + Value *y = BaseFn->getOperand(1); + Value *yz = B.CreateFMul(y, z); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {x, yz}); + } + } + // If it is a library function call + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)) { + LibFunc floatFn,doubleFn,longDFn; + switch (LibFn) + { + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + Value *x=BaseFn->getOperand(0); + Value *y=BaseFn->getOperand(1); + Value *yz=B.CreateFMul(y,z); + NewPow = emitBinaryFloatFnCall(x, yz, TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + if (NewPow) { + Pow->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} + +// sqrt(pow(x,y)) -> pow(|x|,y*0.5) (incorrect transformation) +// If x is transformed into abs(x), +// the result may be inconsistent. +// Therefore, the actual transformation approach is: +// sqrt(pow(x,y)) -> pow(x,y*0.5) +Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, + IRBuilderBase &B) { + Value *OldPow = nullptr, *NewPow = nullptr; + OldPow = Sqrt->getArgOperand(0); + Module *Mod = Sqrt->getModule(); + Type *Ty = Sqrt->getType(); + CallInst *Pow = dyn_cast(OldPow); + if (Pow && Pow->hasOneUse() && Pow->isFast() && Sqrt->isFast()) { + Function *CalleeFn = Pow->getCalledFunction(); + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Sqrt->getFastMathFlags()); + LibFunc LibFn; + // If Sqrt is an intrinsic call and + // its first argument is also an intrinsic call to pow + if (IntrinsicInst *II = dyn_cast(Sqrt)) { + if (II->getIntrinsicID() == Intrinsic::sqrt && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::pow) { + Value *x = Pow->getOperand(0); + Value *y = Pow->getOperand(1); + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), + {x, y_05}); + } + } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)) { + LibFunc floatFn, doubleFn, longDFn; + switch (LibFn) { + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + Value *x = Pow->getOperand(0); + Value *y = Pow->getOperand(1); + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = emitBinaryFloatFnCall(x, y_05, TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + if (NewPow) { + Sqrt->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} + +/* cbrt(expN(X)) -> expN(x/3) + * cbrt(sqrt(x)) -> pow(x,1/6) + * cbrt(cbrt(x)) -> pow(x,1/9) (incorrect transformation) + * When x < 0, the third transformation would yield incorrect results. + * Therefore, it is necessary to handle the transformation of x differently + * based on different cases. + * cbrt(cbrt(x)) -> x>=0?pow(x,1/9):-pow(-x,1/9) + */ +Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); + Value *Base = CI->getArgOperand(0); + CallInst *BaseFn = dyn_cast(Base); + Type *Ty = CI->getType(); + Value *Ret = nullptr, *tempRet1 = nullptr, *tempRet2 = nullptr; + if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CI)) + return nullptr; + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + // Confirming that the internal representation of the cbrt function + // also involves a function call, and the fast-math flag is enabled + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && CI->isFast()) { + LibFunc LibFn; + Function *CalleeFn = BaseFn->getCalledFunction(); + Value *x; + // If the internal representation is an intrinsic call + if (IntrinsicInst *II = dyn_cast(BaseFn)) { + Intrinsic::ID IntrinsicID = II->getIntrinsicID(); + switch (IntrinsicID) { + // cbrt(exp(X)) -> exp(x/3) + // cbrt(exp2(X)) -> exp2(x/3) + case Intrinsic::exp: + case Intrinsic::exp2: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M, IntrinsicID, Ty), + B.CreateFDiv(x, ConstantFP::get(Ty, 3.0))); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case Intrinsic::sqrt: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::pow, Ty), + {x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), + ConstantFP::get(Ty, 6.0))}); + break; + default: + return nullptr; + } + } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(M, TLI, LibFn)) { + switch (LibFn) { + // cbrt(exp(X)) -> exp(x/3) + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp, LibFunc_expf, LibFunc_expl, + B, CalleeFn->getAttributes()); + break; + case LibFunc_exp_finite: + case LibFunc_expf_finite: + case LibFunc_expl_finite: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp_finite, LibFunc_expf_finite, + LibFunc_expl_finite, B, + CalleeFn->getAttributes()); + break; + // cbrt(exp2(X)) -> exp2(x/3) + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l, B, CalleeFn->getAttributes()); + break; + case LibFunc_exp2_finite: + case LibFunc_exp2f_finite: + case LibFunc_exp2l_finite: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp2_finite, + LibFunc_exp2f_finite, LibFunc_exp2l_finite, + B, CalleeFn->getAttributes()); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: + x = BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes()); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt_finite: + case LibFunc_sqrtf_finite: + case LibFunc_sqrtl_finite: + x = BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + TLI, LibFunc_pow_finite, LibFunc_powf_finite, LibFunc_powl_finite, + B, BaseFn->getAttributes()); + break; + // cbrt(cbrt(x)) -> pow(x,1/9) + case LibFunc_cbrt: + case LibFunc_cbrtf: + case LibFunc_cbrtl: + x = BaseFn->getOperand(0); + // When x >= 0, it can be transformed into pow(x, 1/9) + tempRet1 = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes()); + // When x < 0, it can be transformed into -pow(-x, 1/9) + tempRet2 = B.CreateFNeg(emitBinaryFloatFnCall( + B.CreateFNeg(x), + B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes())); + Ret = B.CreateSelect(B.CreateFCmpOGE(x, ConstantFP::get(Ty, 0.0)), + tempRet1, tempRet2); + break; + default: + return nullptr; + } + } + if (Ret) { + CI->replaceAllUsesWith(Ret); + return Ret; + } + } + // Reverting to the original handling of the cbrt function + if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName())) + return optimizeUnaryDoubleFP(CI, B, TLI, true); + return nullptr; +} + static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M, IRBuilderBase &B) { Value *Args[] = {Base, Expo}; @@ -2021,6 +2346,12 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; + //pow(sqrt(x),y) -> pow(x,y*0.5) + if (Value *V = replaceNestedPowAndSqrtWithPow(Pow, B)) + return V; + //pow(pow(x,y),z)-> pow(x,y*z) + if (Value *V = replaceNestedPowAndPowWithPow(Pow, B)) + return V; // If we can approximate pow: // pow(x, n) -> powi(x, n) * sqrt(x) if n has exactly a 0.5 fraction @@ -2313,6 +2644,9 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { if (!CI->isFast()) return Ret; + // sqrt(pow(x,y)) -> pow(|x|,y*0.5) + if (Value *V = replaceNestedSqrtAndPowWithPow(CI, B)) + return V; Instruction *I = dyn_cast(CI->getArgOperand(0)); if (!I || I->getOpcode() != Instruction::FMul || !I->isFast()) @@ -3274,6 +3608,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_powf: case LibFunc_pow: case LibFunc_powl: + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: return optimizePow(CI, Builder); case LibFunc_exp2l: case LibFunc_exp2: @@ -3286,6 +3623,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sqrtf: case LibFunc_sqrt: case LibFunc_sqrtl: + case LibFunc_sqrtf_finite: + case LibFunc_sqrt_finite: + case LibFunc_sqrtl_finite: return optimizeSqrt(CI, Builder); case LibFunc_logf: case LibFunc_log: @@ -3327,7 +3667,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_asinh: case LibFunc_atan: case LibFunc_atanh: - case LibFunc_cbrt: case LibFunc_cosh: case LibFunc_exp: case LibFunc_exp10: @@ -3354,6 +3693,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_cabsf: case LibFunc_cabsl: return optimizeCAbs(CI, Builder); + case LibFunc_cbrtf: + case LibFunc_cbrt: + case LibFunc_cbrtl: + return optimizeCbrt(CI, Builder); default: return nullptr; }