From 89fe0551ddaefad3ffc51762145060aee6ba10e2 Mon Sep 17 00:00:00 2001 From: softwarezhen Date: Sat, 9 Sep 2023 03:50:34 +0000 Subject: [PATCH 1/3] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BA=86=E4=BB=A5?= =?UTF-8?q?=E4=B8=8B=E5=87=A0=E7=A7=8D=E6=95=B0=E5=AD=A6=E5=BA=93=E5=92=8C?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96,=E5=B9=B6=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E4=BA=86=E5=8A=9F=E8=83=BD=E9=AA=8C=E8=AF=81=EF=BC=9A=20pow=20?= =?UTF-8?q?(sqrt(x),=20y)=20->=20pow=20(x,=20y0.5)=20pow=20(pow=20(x,=20y)?= =?UTF-8?q?,=20z)=20->=20pow=20(x,=20yz)=20sqrt=20(Nroot(x))=20->=20pow(x,?= =?UTF-8?q?1/(2N))=20sqrt=20(pow=20(x,=20y))=20->=20pow=20(x,=20y0.5)=20cb?= =?UTF-8?q?rt(exp(X))=20->=20exp(x/3)=20cbrt(exp2(X))=20->=20exp2(x/3)=20c?= =?UTF-8?q?brt(sqrt(x))=20->=20pow(x,1/6)=20cbrt(cbrt(x))=20->=20x>=3D0=3F?= =?UTF-8?q?pow(x,1/9):-pow(-x,1/9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: softwarezhen --- .../llvm/Transforms/Utils/SimplifyLibCalls.h | 4 + .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 347 +++++++++++++++++- .../InstCombine/pow-sqrt-exp-cbrt.ll | 216 +++++++++++ 3 files changed, 566 insertions(+), 1 deletion(-) create mode 100644 llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll diff --git a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h index 1b2482a2363d..be7733ded2a1 100644 --- a/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h +++ b/llvm/include/llvm/Transforms/Utils/SimplifyLibCalls.h @@ -191,6 +191,10 @@ private: Value *replacePowWithExp(CallInst *Pow, IRBuilderBase &B); Value *replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B); Value *optimizeExp2(CallInst *CI, IRBuilderBase &B); + Value *replaceNestedPowAndSqrtWithPow(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B); + Value *replaceNestedSqrtAndPowWithPow(CallInst *Sqrt,IRBuilderBase &B); + Value *optimizeCbrt(CallInst *CI, IRBuilderBase &B); Value *optimizeFMinFMax(CallInst *CI, IRBuilderBase &B); Value *optimizeLog(CallInst *CI, IRBuilderBase &B); Value *optimizeSqrt(CallInst *CI, IRBuilderBase &B); diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index 245f2d4e442a..d154a2dc5d38 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1971,6 +1971,331 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { return Sqrt; } +// pow(sqrt(x),y) -> pow(x,y*0.5) +Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, + IRBuilderBase &B) { + Value *Base = nullptr, *y = nullptr, *NewPow = nullptr; + Base = Pow->getArgOperand(0); + y = Pow->getArgOperand(1); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + CallInst *BaseFn = dyn_cast(Base); + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()) { + Function *CalleeFn = BaseFn->getCalledFunction(); + LibFunc LibFn; + + // If Pow is an intrinsic call, and + // its first argument is an intrinsic call to Sqrt + if (IntrinsicInst *II = dyn_cast(Pow)) { + if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::sqrt) { + Value *x = BaseFn->getOperand(0); + // Create a new node y * 0.5. + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), + {x, y_05}); + } + } + // If it is a library function call + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)) { + LibFunc floatFn, doubleFn, longDFn; + switch (LibFn) { + case LibFunc_sqrtf: + case LibFunc_sqrt: + case LibFunc_sqrtl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_sqrtf_finite: + case LibFunc_sqrt_finite: + case LibFunc_sqrtl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + Value *x = BaseFn->getOperand(0); + NewPow = emitBinaryFloatFnCall( + x, B.CreateFMul(y, ConstantFP::get(Ty, 0.5)), TLI, doubleFn, floatFn, + longDFn, B, CalleeFn->getAttributes()); + } + if (NewPow) { + Pow->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} + +//pow(pow(x,y),z)-> pow(x,y*z) +Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B){ + Value *Base=nullptr,*z=nullptr,*NewPow=nullptr; + Base = Pow->getArgOperand(0); + z = Pow->getArgOperand(1); + Module *Mod = Pow->getModule(); + Type *Ty = Pow->getType(); + CallInst *BaseFn = dyn_cast(Base); + + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()){ + Function *CalleeFn = BaseFn->getCalledFunction(); + LibFunc LibFn; + // If Pow is an intrinsic call and + // its first argument is also an intrinsic call to pow + if (IntrinsicInst *II = dyn_cast(Pow)) { + if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::pow) { + Value *x = BaseFn->getOperand(0); + Value *y = BaseFn->getOperand(1); + Value *yz = B.CreateFMul(y, z); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {x, yz}); + } + } + // If it is a library function call + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)) { + LibFunc floatFn,doubleFn,longDFn; + switch (LibFn) + { + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + Value *x=BaseFn->getOperand(0); + Value *y=BaseFn->getOperand(1); + Value *yz=B.CreateFMul(y,z); + NewPow = emitBinaryFloatFnCall(x, yz, TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + if (NewPow) { + Pow->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} + +// sqrt(pow(x,y)) -> pow(|x|,y*0.5) (incorrect transformation) +// If x is transformed into abs(x), +// the result may be inconsistent. +// Therefore, the actual transformation approach is: +// sqrt(pow(x,y)) -> pow(x,y*0.5) +Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, + IRBuilderBase &B) { + Value *OldPow = nullptr, *NewPow = nullptr; + OldPow = Sqrt->getArgOperand(0); + Module *Mod = Sqrt->getModule(); + Type *Ty = Sqrt->getType(); + CallInst *Pow = dyn_cast(OldPow); + if (Pow && Pow->hasOneUse() && Pow->isFast() && Sqrt->isFast()) { + Function *CalleeFn = Pow->getCalledFunction(); + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Sqrt->getFastMathFlags()); + LibFunc LibFn; + // If Sqrt is an intrinsic call and + // its first argument is also an intrinsic call to pow + if (IntrinsicInst *II = dyn_cast(Sqrt)) { + if (II->getIntrinsicID() == Intrinsic::sqrt && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::pow) { + Value *x = Pow->getOperand(0); + Value *y = Pow->getOperand(1); + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), + {x, y_05}); + } + } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(Mod, TLI, LibFn)) { + LibFunc floatFn, doubleFn, longDFn; + switch (LibFn) { + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + floatFn = LibFunc_powf; + doubleFn = LibFunc_pow; + longDFn = LibFunc_powl; + break; + case LibFunc_powf_finite: + case LibFunc_pow_finite: + case LibFunc_powl_finite: + floatFn = LibFunc_powf_finite; + doubleFn = LibFunc_pow_finite; + longDFn = LibFunc_powl_finite; + break; + default: + return nullptr; + } + Value *x = Pow->getOperand(0); + Value *y = Pow->getOperand(1); + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = emitBinaryFloatFnCall(x, y_05, TLI, doubleFn, floatFn, longDFn, + B, CalleeFn->getAttributes()); + } + if (NewPow) { + Sqrt->replaceAllUsesWith(NewPow); + return NewPow; + } + } + return nullptr; +} + +/* cbrt(expN(X)) -> expN(x/3) + * cbrt(sqrt(x)) -> pow(x,1/6) + * cbrt(cbrt(x)) -> pow(x,1/9) (incorrect transformation) + * When x < 0, the third transformation would yield incorrect results. + * Therefore, it is necessary to handle the transformation of x differently + * based on different cases. + * cbrt(cbrt(x)) -> x>=0?pow(x,1/9):-pow(-x,1/9) + */ +Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { + Module *M = CI->getModule(); + Value *Base = CI->getArgOperand(0); + CallInst *BaseFn = dyn_cast(Base); + Type *Ty = CI->getType(); + Value *Ret = nullptr, *tempRet1 = nullptr, *tempRet2 = nullptr; + if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CI)) + return nullptr; + IRBuilderBase::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + // Confirming that the internal representation of the cbrt function + // also involves a function call, and the fast-math flag is enabled + if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && CI->isFast()) { + LibFunc LibFn; + Function *CalleeFn = BaseFn->getCalledFunction(); + Value *x; + // If the internal representation is an intrinsic call + if (IntrinsicInst *II = dyn_cast(BaseFn)) { + Intrinsic::ID IntrinsicID = II->getIntrinsicID(); + switch (IntrinsicID) { + // cbrt(exp(X)) -> exp(x/3) + // cbrt(exp2(X)) -> exp2(x/3) + case Intrinsic::exp: + case Intrinsic::exp2: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M, IntrinsicID, Ty), + B.CreateFDiv(x, ConstantFP::get(Ty, 3.0))); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case Intrinsic::sqrt: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::pow, Ty), + {x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), + ConstantFP::get(Ty, 6.0))}); + break; + default: + return nullptr; + } + } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + isLibFuncEmittable(M, TLI, LibFn)) { + switch (LibFn) { + // cbrt(exp(X)) -> exp(x/3) + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp, LibFunc_expf, LibFunc_expl, + B, CalleeFn->getAttributes()); + break; + case LibFunc_exp_finite: + case LibFunc_expf_finite: + case LibFunc_expl_finite: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp_finite, LibFunc_expf_finite, + LibFunc_expl_finite, B, + CalleeFn->getAttributes()); + break; + // cbrt(exp2(X)) -> exp2(x/3) + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp2, LibFunc_exp2f, + LibFunc_exp2l, B, CalleeFn->getAttributes()); + break; + case LibFunc_exp2_finite: + case LibFunc_exp2f_finite: + case LibFunc_exp2l_finite: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + TLI, LibFunc_exp2_finite, + LibFunc_exp2f_finite, LibFunc_exp2l_finite, + B, CalleeFn->getAttributes()); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: + x = BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes()); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt_finite: + case LibFunc_sqrtf_finite: + case LibFunc_sqrtl_finite: + x = BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + TLI, LibFunc_pow_finite, LibFunc_powf_finite, LibFunc_powl_finite, + B, BaseFn->getAttributes()); + break; + // cbrt(cbrt(x)) -> pow(x,1/9) + case LibFunc_cbrt: + case LibFunc_cbrtf: + case LibFunc_cbrtl: + x = BaseFn->getOperand(0); + // When x >= 0, it can be transformed into pow(x, 1/9) + tempRet1 = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes()); + // When x < 0, it can be transformed into -pow(-x, 1/9) + tempRet2 = B.CreateFNeg(emitBinaryFloatFnCall( + B.CreateFNeg(x), + B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes())); + Ret = B.CreateSelect(B.CreateFCmpOGE(x, ConstantFP::get(Ty, 0.0)), + tempRet1, tempRet2); + break; + default: + return nullptr; + } + } + if (Ret) { + CI->replaceAllUsesWith(Ret); + return Ret; + } + } + // Reverting to the original handling of the cbrt function + if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName())) + return optimizeUnaryDoubleFP(CI, B, TLI, true); + return nullptr; +} + static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M, IRBuilderBase &B) { Value *Args[] = {Base, Expo}; @@ -2021,6 +2346,13 @@ Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) { if (Value *Sqrt = replacePowWithSqrt(Pow, B)) return Sqrt; + + //pow(sqrt(x),y) -> pow(x,y*0.5) + if (Value *V = replaceNestedPowAndSqrtWithPow(Pow, B)) + return V; + //pow(pow(x,y),z)-> pow(x,y*z) + if (Value *V = replaceNestedPowAndPowWithPow(Pow, B)) + return V; // If we can approximate pow: // pow(x, n) -> powi(x, n) * sqrt(x) if n has exactly a 0.5 fraction @@ -2314,6 +2646,10 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) { if (!CI->isFast()) return Ret; + // sqrt(pow(x,y)) -> pow(|x|,y*0.5) + if (Value *V = replaceNestedSqrtAndPowWithPow(CI, B)) + return V; + Instruction *I = dyn_cast(CI->getArgOperand(0)); if (!I || I->getOpcode() != Instruction::FMul || !I->isFast()) return Ret; @@ -3274,6 +3610,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_powf: case LibFunc_pow: case LibFunc_powl: + case LibFunc_pow_finite: + case LibFunc_powf_finite: + case LibFunc_powl_finite: return optimizePow(CI, Builder); case LibFunc_exp2l: case LibFunc_exp2: @@ -3286,6 +3625,9 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_sqrtf: case LibFunc_sqrt: case LibFunc_sqrtl: + case LibFunc_sqrtf_finite: + case LibFunc_sqrt_finite: + case LibFunc_sqrtl_finite: return optimizeSqrt(CI, Builder); case LibFunc_logf: case LibFunc_log: @@ -3327,7 +3669,6 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_asinh: case LibFunc_atan: case LibFunc_atanh: - case LibFunc_cbrt: case LibFunc_cosh: case LibFunc_exp: case LibFunc_exp10: @@ -3354,6 +3695,10 @@ Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, case LibFunc_cabsf: case LibFunc_cabsl: return optimizeCAbs(CI, Builder); + case LibFunc_cbrtf: + case LibFunc_cbrt: + case LibFunc_cbrtl: + return optimizeCbrt(CI, Builder); default: return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll b/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll new file mode 100644 index 000000000000..24a78094012e --- /dev/null +++ b/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll @@ -0,0 +1,216 @@ +; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; In each test case, an extra instruction is introduced during the transformation, +; which will be eliminated in the subsequent dead code elimination optimization. + +define double @pow_sqrt(double %x, double %y) { +; CHECK-LABEL: @pow_sqrt( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @sqrt(double [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[Y:%.*]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +; CHECK-NEXT: ret double [[POW]] +; + %call = call fast double @sqrt(double %x) + %pow = call fast double @pow(double %call, double %y) + ret double %pow +} + +define float @powf_sqrtf(float %x, float %y) { +; CHECK-LABEL: @powf_sqrtf( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @sqrtf(float [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[Y:%.*]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +; CHECK-NEXT: ret float [[POW]] +; + %call = call fast float @sqrtf(float %x) + %pow = call fast float @powf(float %call, float %y) + ret float %pow +} + +define double @pow_pow(double %x, double %y, double %z) { +; CHECK-LABEL: @pow_pow( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[Y:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +; CHECK-NEXT: ret double [[POW]] +; + %call = call fast double @pow(double %x, double %y) + %pow = call fast double @pow(double %call, double %z) + ret double %pow +} + +define float @powf_powf(float %x, float %y, float %z) { +; CHECK-LABEL: @powf_powf( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[Y:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +; CHECK-NEXT: ret float [[POW]] +; + %call = call fast float @powf(float %x, float %y) + %pow = call fast float @powf(float %call, float %z) + ret float %pow +} + +define double @sqrt_nroot(double %x, double %n){ +; CHECK-LABEL: @sqrt_nroot( +; CHECK-NEXT: [[DIV:%.*]] = fdiv double 1.000000e+00, [[N:%.*]] +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[DIV]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[DIV]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +; CHECK-NEXT: ret double [[POW]] +; + %div = fdiv double 1.000000e+00, %n + %call = call fast double @pow(double %x, double %div) + %call6 = call fast double @sqrt(double %call) + ret double %call6 +} + +define float @sqrtf_nroot(float %x, float %n){ +; CHECK-LABEL: @sqrtf_nroot( +; CHECK-NEXT: [[DIV:%.*]] = fdiv float 1.000000e+00, [[N:%.*]] +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[DIV]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[DIV]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +; CHECK-NEXT: ret float [[POW]] +; + %div = fdiv float 1.000000e+00, %n + %call = call fast float @powf(float %x, float %div) + %call6 = call fast float @sqrtf(float %call) + ret float %call6 +} + +define double @sqrt_pow(double %x, double %y) { +; CHECK-LABEL: @sqrt_pow( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[Y:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[Y:%.*]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +; CHECK-NEXT: ret double [[POW]] +; + %call = call fast double @pow(double %x, double %y) + %pow = call fast double @sqrt(double %call) + ret double %pow +} + +define float @sqrtf_powf(float %x, float %y) { +; CHECK-LABEL: @sqrtf_powf( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[Y:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[Y:%.*]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +; CHECK-NEXT: ret float [[POW]] +; + %call = call fast float @powf(float %x, float %y) + %pow = call fast float @sqrtf(float %call) + ret float %pow +} + +define double @cbrt_exp(double %x) { +; CHECK-LABEL: @cbrt_exp( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @exp(double [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 0x3FD5555555555555 +; CHECK-NEXT: [[EXP:%.*]] = call fast double @exp(double [[MUL]]) +; CHECK-NEXT: ret double [[EXP]] +; + %call = call fast double @exp(double %x) + %pow = call fast double @cbrt(double %call) + ret double %pow +} + +define float @cbrtf_expf(float %x) { +; CHECK-LABEL: @cbrtf_expf( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @expf(float [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[X:%.*]], 0x3FD5555560000000 +; CHECK-NEXT: [[EXP:%.*]] = call fast float @expf(float [[MUL]]) +; CHECK-NEXT: ret float [[EXP]] +; + %call = call fast float @expf(float %x) + %pow = call fast float @cbrtf(float %call) + ret float %pow +} + +define double @cbrt_exp2(double %x) { +; CHECK-LABEL: @cbrt_exp2( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @exp2(double [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[X:%.*]], 0x3FD5555555555555 +; CHECK-NEXT: [[EXP:%.*]] = call fast double @exp2(double [[MUL]]) +; CHECK-NEXT: ret double [[EXP]] +; + %call = call fast double @exp2(double %x) + %pow = call fast double @cbrt(double %call) + ret double %pow +} + +define float @cbrtf_exp2f(float %x) { +; CHECK-LABEL: @cbrtf_exp2f( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @exp2f(float [[X:%.*]]) +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[X:%.*]], 0x3FD5555560000000 +; CHECK-NEXT: [[EXP:%.*]] = call fast float @exp2f(float [[MUL]]) +; CHECK-NEXT: ret float [[EXP]] +; + %call = call fast float @exp2f(float %x) + %pow = call fast float @cbrtf(float %call) + ret float %pow +} + +define double @cbrt_sqrt(double %x) { +; CHECK-LABEL: @cbrt_sqrt( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @sqrt(double [[X:%.*]]) +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double 0x3FC5555555555555) +; CHECK-NEXT: ret double [[POW]] +; + %call = call fast double @sqrt(double %x) + %pow = call fast double @cbrt(double %call) + ret double %pow +} + +define float @cbrtf_sqrtf(float %x) { +; CHECK-LABEL: @cbrtf_sqrtf( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @sqrtf(float [[X:%.*]]) +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float 0x3FC5555560000000) +; CHECK-NEXT: ret float [[POW]] +; + %call = call fast float @sqrtf(float %x) + %pow = call fast float @cbrtf(float %call) + ret float %pow +} + +define double @cbrt_cbrt(double %x) { +; CHECK-LABEL: @cbrt_cbrt( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @cbrt(double [[X:%.*]]) +; CHECK-NEXT: [[POW1:%.*]] = call fast double @pow(double [[X:%.*]], double 0x3FBC71C71C71C71C) +; CHECK-NEXT: [[NEG_X:%.*]] = fneg fast double [[X:%.*]] +; CHECK-NEXT: [[POW2:%.*]] = call fast double @pow(double [[NEG_X]], double 0x3FBC71C71C71C71C) +; CHECK-NEXT: [[NEG_POW2:%.*]] = fneg fast double [[POW2]] +; CHECK-NEXT: [[CMP:%.*]] = fcmp fast oge double [[X:%.*]], 0.000000e+00 +; CHECK-NEXT: [[SELECT:%.*]] = select fast i1 [[CMP]], double [[POW1]], double [[NEG_POW2]] +; CHECK-NEXT: ret double [[SELECT]] +; + %call = call fast double @cbrt(double %x) + %pow = call fast double @cbrt(double %call) + ret double %pow +} + +define float @cbrtf_cbrtf(float %x) { +; CHECK-LABEL: @cbrtf_cbrtf( +; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @cbrtf(float [[X:%.*]]) +; CHECK-NEXT: [[POW1:%.*]] = call fast float @powf(float [[X:%.*]], float 0x3FBC71C720000000) +; CHECK-NEXT: [[NEG_X:%.*]] = fneg fast float [[X:%.*]] +; CHECK-NEXT: [[POW2:%.*]] = call fast float @powf(float [[NEG_X]], float 0x3FBC71C720000000) +; CHECK-NEXT: [[NEG_POW2:%.*]] = fneg fast float [[POW2]] +; CHECK-NEXT: [[CMP:%.*]] = fcmp fast oge float [[X:%.*]], 0.000000e+00 +; CHECK-NEXT: [[SELECT:%.*]] = select fast i1 [[CMP]], float [[POW1]], float [[NEG_POW2]] +; CHECK-NEXT: ret float [[SELECT]] +; + %call = call fast float @cbrtf(float %x) + %pow = call fast float @cbrtf(float %call) + ret float %pow +} + +declare double @pow(double,double) +declare float @powf(float,float) +declare double @sqrt(double) +declare float @sqrtf(float) +declare double @cbrt(double) +declare float @cbrtf(float) +declare double @exp(double) +declare float @expf(float) +declare double @exp2(double) +declare float @exp2f(float) \ No newline at end of file -- Gitee From 3e01ac7d20e2376431bff8b63b20141d355b2430 Mon Sep 17 00:00:00 2001 From: softwarezhen Date: Tue, 12 Sep 2023 12:45:43 +0000 Subject: [PATCH 2/3] =?UTF-8?q?=E5=B0=86sqrt(pow(x,y))=20->=20pow(x,y*0.5)?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=9B=B4=E6=AD=A3=E4=B8=BAsqrt(pow(x,y))=20-?= =?UTF-8?q?>=20pow(|x|,y*0.5)=EF=BC=9B=20=E4=BF=AE=E6=94=B9=E4=BA=86?= =?UTF-8?q?=E9=83=A8=E5=88=86=E5=8F=98=E9=87=8F=E5=90=8D=E7=A7=B0=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: softwarezhen --- .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 157 +++++++++--------- .../InstCombine/pow-sqrt-exp-cbrt.ll | 16 +- 2 files changed, 86 insertions(+), 87 deletions(-) diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index d154a2dc5d38..e45fdb03c232 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1974,9 +1974,9 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { // pow(sqrt(x),y) -> pow(x,y*0.5) Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, IRBuilderBase &B) { - Value *Base = nullptr, *y = nullptr, *NewPow = nullptr; - Base = Pow->getArgOperand(0); - y = Pow->getArgOperand(1); + Value *NewPow = nullptr; + Value *Base = Pow->getArgOperand(0); + Value *Y = Pow->getArgOperand(1); Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); CallInst *BaseFn = dyn_cast(Base); @@ -1984,21 +1984,18 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, Function *CalleeFn = BaseFn->getCalledFunction(); LibFunc LibFn; - // If Pow is an intrinsic call, and - // its first argument is an intrinsic call to Sqrt + // Check if Pow complies with the conversion rules. if (IntrinsicInst *II = dyn_cast(Pow)) { if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && CalleeFn->getIntrinsicID() == Intrinsic::sqrt) { - Value *x = BaseFn->getOperand(0); - // Create a new node y * 0.5. - Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + Value *X = BaseFn->getOperand(0); + // Create a new node Y * 0.5. + Value *Mul = B.CreateFMul(Y, ConstantFP::get(Ty, 0.5)); NewPow = B.CreateCall( Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), - {x, y_05}); + {X, Mul}); } - } - // If it is a library function call - else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(Mod, TLI, LibFn)) { LibFunc floatFn, doubleFn, longDFn; switch (LibFn) { @@ -2019,9 +2016,9 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, default: return nullptr; } - Value *x = BaseFn->getOperand(0); + Value *X = BaseFn->getOperand(0); NewPow = emitBinaryFloatFnCall( - x, B.CreateFMul(y, ConstantFP::get(Ty, 0.5)), TLI, doubleFn, floatFn, + X, B.CreateFMul(Y, ConstantFP::get(Ty, 0.5)), TLI, doubleFn, floatFn, longDFn, B, CalleeFn->getAttributes()); } if (NewPow) { @@ -2034,9 +2031,9 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, //pow(pow(x,y),z)-> pow(x,y*z) Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B){ - Value *Base=nullptr,*z=nullptr,*NewPow=nullptr; - Base = Pow->getArgOperand(0); - z = Pow->getArgOperand(1); + Value *NewPow=nullptr; + Value *Base = Pow->getArgOperand(0); + Value *Z = Pow->getArgOperand(1); Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); CallInst *BaseFn = dyn_cast(Base); @@ -2044,20 +2041,17 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()){ Function *CalleeFn = BaseFn->getCalledFunction(); LibFunc LibFn; - // If Pow is an intrinsic call and - // its first argument is also an intrinsic call to pow + // Check if Pow complies with the conversion rules. if (IntrinsicInst *II = dyn_cast(Pow)) { if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && CalleeFn->getIntrinsicID() == Intrinsic::pow) { - Value *x = BaseFn->getOperand(0); - Value *y = BaseFn->getOperand(1); - Value *yz = B.CreateFMul(y, z); + Value *X = BaseFn->getOperand(0); + Value *Y = BaseFn->getOperand(1); + Value *Mul = B.CreateFMul(Y, Z); NewPow = B.CreateCall( - Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {x, yz}); + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {X, Mul}); } - } - // If it is a library function call - else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && + } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(Mod, TLI, LibFn)) { LibFunc floatFn,doubleFn,longDFn; switch (LibFn) @@ -2079,10 +2073,10 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder default: return nullptr; } - Value *x=BaseFn->getOperand(0); - Value *y=BaseFn->getOperand(1); - Value *yz=B.CreateFMul(y,z); - NewPow = emitBinaryFloatFnCall(x, yz, TLI, doubleFn, floatFn, longDFn, + Value *X=BaseFn->getOperand(0); + Value *Y=BaseFn->getOperand(1); + Value *Mul=B.CreateFMul(Y,Z); + NewPow = emitBinaryFloatFnCall(X, Mul, TLI, doubleFn, floatFn, longDFn, B, CalleeFn->getAttributes()); } if (NewPow) { @@ -2093,15 +2087,11 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder return nullptr; } -// sqrt(pow(x,y)) -> pow(|x|,y*0.5) (incorrect transformation) -// If x is transformed into abs(x), -// the result may be inconsistent. -// Therefore, the actual transformation approach is: -// sqrt(pow(x,y)) -> pow(x,y*0.5) +// sqrt(pow(x,y)) -> pow(|x|,y*0.5) Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, IRBuilderBase &B) { - Value *OldPow = nullptr, *NewPow = nullptr; - OldPow = Sqrt->getArgOperand(0); + Value *NewPow = nullptr; + Value *OldPow = Sqrt->getArgOperand(0); Module *Mod = Sqrt->getModule(); Type *Ty = Sqrt->getType(); CallInst *Pow = dyn_cast(OldPow); @@ -2110,17 +2100,18 @@ Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, IRBuilderBase::FastMathFlagGuard Guard(B); B.setFastMathFlags(Sqrt->getFastMathFlags()); LibFunc LibFn; - // If Sqrt is an intrinsic call and - // its first argument is also an intrinsic call to pow + // Check if Sqrt complies with the conversion rules. if (IntrinsicInst *II = dyn_cast(Sqrt)) { if (II->getIntrinsicID() == Intrinsic::sqrt && CalleeFn && CalleeFn->getIntrinsicID() == Intrinsic::pow) { - Value *x = Pow->getOperand(0); - Value *y = Pow->getOperand(1); - Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + Value *X = Pow->getOperand(0); + Value *Y = Pow->getOperand(1); + Value *AbsX = B.CreateSelect( + B.CreateFCmpOGT(X, ConstantFP::get(Ty, 0.0)), X, B.CreateFNeg(X)); + Value *Mul = B.CreateFMul(Y, ConstantFP::get(Ty, 0.5)); NewPow = B.CreateCall( Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), - {x, y_05}); + {AbsX, Mul}); } } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(Mod, TLI, LibFn)) { @@ -2143,10 +2134,12 @@ Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, default: return nullptr; } - Value *x = Pow->getOperand(0); - Value *y = Pow->getOperand(1); - Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); - NewPow = emitBinaryFloatFnCall(x, y_05, TLI, doubleFn, floatFn, longDFn, + Value *X = Pow->getOperand(0); + Value *Y = Pow->getOperand(1); + Value *AbsX = B.CreateSelect( + B.CreateFCmpOGT(X, ConstantFP::get(Ty, 0.0)), X, B.CreateFNeg(X)); + Value *Mul = B.CreateFMul(Y, ConstantFP::get(Ty, 0.5)); + NewPow = emitBinaryFloatFnCall(AbsX, Mul, TLI, doubleFn, floatFn, longDFn, B, CalleeFn->getAttributes()); } if (NewPow) { @@ -2170,7 +2163,7 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { Value *Base = CI->getArgOperand(0); CallInst *BaseFn = dyn_cast(Base); Type *Ty = CI->getType(); - Value *Ret = nullptr, *tempRet1 = nullptr, *tempRet2 = nullptr; + Value *Ret = nullptr, *TempRet1 = nullptr, *TempRet2 = nullptr; if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CI)) return nullptr; IRBuilderBase::FastMathFlagGuard Guard(B); @@ -2180,24 +2173,24 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && CI->isFast()) { LibFunc LibFn; Function *CalleeFn = BaseFn->getCalledFunction(); - Value *x; + Value *X; // If the internal representation is an intrinsic call if (IntrinsicInst *II = dyn_cast(BaseFn)) { Intrinsic::ID IntrinsicID = II->getIntrinsicID(); switch (IntrinsicID) { - // cbrt(exp(X)) -> exp(x/3) - // cbrt(exp2(X)) -> exp2(x/3) + // cbrt(exp(X)) -> exp(X/3) + // cbrt(exp2(X)) -> exp2(X/3) case Intrinsic::exp: case Intrinsic::exp2: - x = BaseFn->getOperand(0); + X = BaseFn->getOperand(0); Ret = B.CreateCall(Intrinsic::getDeclaration(M, IntrinsicID, Ty), - B.CreateFDiv(x, ConstantFP::get(Ty, 3.0))); + B.CreateFDiv(X, ConstantFP::get(Ty, 3.0))); break; - // cbrt(sqrt(x)) -> pow(x,1/6) + // cbrt(sqrt(X)) -> pow(X,1/6) case Intrinsic::sqrt: - x = BaseFn->getOperand(0); + X = BaseFn->getOperand(0); Ret = B.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::pow, Ty), - {x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), + {X, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0))}); break; default: @@ -2206,80 +2199,80 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(M, TLI, LibFn)) { switch (LibFn) { - // cbrt(exp(X)) -> exp(x/3) + // cbrt(exp(X)) -> exp(X/3) case LibFunc_exp: case LibFunc_expf: case LibFunc_expl: - x = BaseFn->getOperand(0); - Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + X = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), TLI, LibFunc_exp, LibFunc_expf, LibFunc_expl, B, CalleeFn->getAttributes()); break; case LibFunc_exp_finite: case LibFunc_expf_finite: case LibFunc_expl_finite: - x = BaseFn->getOperand(0); - Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + X = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), TLI, LibFunc_exp_finite, LibFunc_expf_finite, LibFunc_expl_finite, B, CalleeFn->getAttributes()); break; - // cbrt(exp2(X)) -> exp2(x/3) + // cbrt(exp2(X)) -> exp2(X/3) case LibFunc_exp2: case LibFunc_exp2f: case LibFunc_exp2l: - x = BaseFn->getOperand(0); - Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + X = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), TLI, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l, B, CalleeFn->getAttributes()); break; case LibFunc_exp2_finite: case LibFunc_exp2f_finite: case LibFunc_exp2l_finite: - x = BaseFn->getOperand(0); - Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), + X = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), TLI, LibFunc_exp2_finite, LibFunc_exp2f_finite, LibFunc_exp2l_finite, B, CalleeFn->getAttributes()); break; - // cbrt(sqrt(x)) -> pow(x,1/6) + // cbrt(sqrt(X)) -> pow(X,1/6) case LibFunc_sqrt: case LibFunc_sqrtf: case LibFunc_sqrtl: - x = BaseFn->getOperand(0); + X = BaseFn->getOperand(0); Ret = emitBinaryFloatFnCall( - x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + X, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, BaseFn->getAttributes()); break; - // cbrt(sqrt(x)) -> pow(x,1/6) + // cbrt(sqrt(X)) -> pow(X,1/6) case LibFunc_sqrt_finite: case LibFunc_sqrtf_finite: case LibFunc_sqrtl_finite: - x = BaseFn->getOperand(0); + X = BaseFn->getOperand(0); Ret = emitBinaryFloatFnCall( - x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + X, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), TLI, LibFunc_pow_finite, LibFunc_powf_finite, LibFunc_powl_finite, B, BaseFn->getAttributes()); break; - // cbrt(cbrt(x)) -> pow(x,1/9) + // cbrt(cbrt(X)) -> pow(X,1/9) case LibFunc_cbrt: case LibFunc_cbrtf: case LibFunc_cbrtl: - x = BaseFn->getOperand(0); - // When x >= 0, it can be transformed into pow(x, 1/9) - tempRet1 = emitBinaryFloatFnCall( - x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + X = BaseFn->getOperand(0); + // When X >= 0, it can be transformed into pow(X, 1/9) + TempRet1 = emitBinaryFloatFnCall( + X, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, BaseFn->getAttributes()); - // When x < 0, it can be transformed into -pow(-x, 1/9) - tempRet2 = B.CreateFNeg(emitBinaryFloatFnCall( - B.CreateFNeg(x), + // When X < 0, it can be transformed into -pow(-X, 1/9) + TempRet2 = B.CreateFNeg(emitBinaryFloatFnCall( + B.CreateFNeg(X), B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, BaseFn->getAttributes())); - Ret = B.CreateSelect(B.CreateFCmpOGE(x, ConstantFP::get(Ty, 0.0)), - tempRet1, tempRet2); + Ret = B.CreateSelect(B.CreateFCmpOGE(X, ConstantFP::get(Ty, 0.0)), + TempRet1, TempRet2); break; default: return nullptr; diff --git a/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll b/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll index 24a78094012e..941796b4926d 100644 --- a/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll +++ b/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll @@ -54,8 +54,9 @@ define double @sqrt_nroot(double %x, double %n){ ; CHECK-LABEL: @sqrt_nroot( ; CHECK-NEXT: [[DIV:%.*]] = fdiv double 1.000000e+00, [[N:%.*]] ; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[DIV]]) +; CHECK-NEXT: [[ABSX:%.*]] = call fast double @llvm.fabs.f64(double [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[DIV]], 5.000000e-01 -; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[ABSX]], double [[MUL]]) ; CHECK-NEXT: ret double [[POW]] ; %div = fdiv double 1.000000e+00, %n @@ -68,8 +69,9 @@ define float @sqrtf_nroot(float %x, float %n){ ; CHECK-LABEL: @sqrtf_nroot( ; CHECK-NEXT: [[DIV:%.*]] = fdiv float 1.000000e+00, [[N:%.*]] ; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[DIV]]) +; CHECK-NEXT: [[ABSX:%.*]] = call fast float @llvm.fabs.f32(float [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[DIV]], 5.000000e-01 -; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[ABSX]], float [[MUL]]) ; CHECK-NEXT: ret float [[POW]] ; %div = fdiv float 1.000000e+00, %n @@ -81,8 +83,9 @@ define float @sqrtf_nroot(float %x, float %n){ define double @sqrt_pow(double %x, double %y) { ; CHECK-LABEL: @sqrt_pow( ; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[Y:%.*]]) +; CHECK-NEXT: [[ABSX:%.*]] = call fast double @llvm.fabs.f64(double [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[Y:%.*]], 5.000000e-01 -; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[ABSX]], double [[MUL]]) ; CHECK-NEXT: ret double [[POW]] ; %call = call fast double @pow(double %x, double %y) @@ -93,8 +96,9 @@ define double @sqrt_pow(double %x, double %y) { define float @sqrtf_powf(float %x, float %y) { ; CHECK-LABEL: @sqrtf_powf( ; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[Y:%.*]]) +; CHECK-NEXT: [[ABSX:%.*]] = call fast float @llvm.fabs.f32(float [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[Y:%.*]], 5.000000e-01 -; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[ABSX]], float [[MUL]]) ; CHECK-NEXT: ret float [[POW]] ; %call = call fast float @powf(float %x, float %y) @@ -213,4 +217,6 @@ declare float @cbrtf(float) declare double @exp(double) declare float @expf(float) declare double @exp2(double) -declare float @exp2f(float) \ No newline at end of file +declare float @exp2f(float) +declare double @llvm.fabs.f64(double) +declare float @llvm.fabs.f32(float) -- Gitee From f90bbecc0651b842970820e919c897d7318fd1b5 Mon Sep 17 00:00:00 2001 From: softwarezhen Date: Sat, 9 Sep 2023 03:50:34 +0000 Subject: [PATCH 3/3] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E4=BA=86=E4=BB=A5?= =?UTF-8?q?=E4=B8=8B=E5=87=A0=E7=A7=8D=E6=95=B0=E5=AD=A6=E5=BA=93=E5=92=8C?= =?UTF-8?q?=E5=B9=B6=E4=BC=98=E5=8C=96,=E5=B9=B6=E8=BF=9B=E8=A1=8C?= =?UTF-8?q?=E4=BA=86=E5=8A=9F=E8=83=BD=E9=AA=8C=E8=AF=81=EF=BC=9A=20pow=20?= =?UTF-8?q?(sqrt(x),=20y)=20->=20pow=20(x,=20y0.5)=20pow=20(pow=20(x,=20y)?= =?UTF-8?q?,=20z)=20->=20pow=20(x,=20yz)=20sqrt=20(Nroot(x))=20->=20pow(x,?= =?UTF-8?q?1/(2N))=20sqrt=20(pow=20(x,=20y))=20->=20pow=20(|x|,=20y0.5)=20?= =?UTF-8?q?cbrt(exp(X))=20->=20exp(x/3)=20cbrt(exp2(X))=20->=20exp2(x/3)?= =?UTF-8?q?=20cbrt(sqrt(x))=20->=20pow(x,1/6)=20cbrt(cbrt(x))=20->=20x>=3D?= =?UTF-8?q?0=3Fpow(x,1/9):-pow(-x,1/9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: softwarezhen --- .../lib/Transforms/Utils/SimplifyLibCalls.cpp | 186 ++++++++++++++++++ .../InstCombine/pow-sqrt-exp-cbrt.ll | 24 +++ 2 files changed, 210 insertions(+) diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp index e45fdb03c232..2994dcb3f3b5 100644 --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -1974,9 +1974,15 @@ Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) { // pow(sqrt(x),y) -> pow(x,y*0.5) Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, IRBuilderBase &B) { +<<<<<<< HEAD Value *NewPow = nullptr; Value *Base = Pow->getArgOperand(0); Value *Y = Pow->getArgOperand(1); +======= + Value *Base = nullptr, *y = nullptr, *NewPow = nullptr; + Base = Pow->getArgOperand(0); + y = Pow->getArgOperand(1); +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); CallInst *BaseFn = dyn_cast(Base); @@ -1984,6 +1990,7 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, Function *CalleeFn = BaseFn->getCalledFunction(); LibFunc LibFn; +<<<<<<< HEAD // Check if Pow complies with the conversion rules. if (IntrinsicInst *II = dyn_cast(Pow)) { if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && @@ -1996,6 +2003,23 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, {X, Mul}); } } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && +======= + // If Pow is an intrinsic call, and + // its first argument is an intrinsic call to Sqrt + if (IntrinsicInst *II = dyn_cast(Pow)) { + if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::sqrt) { + Value *x = BaseFn->getOperand(0); + // Create a new node y * 0.5. + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), + {x, y_05}); + } + } + // If it is a library function call + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) isLibFuncEmittable(Mod, TLI, LibFn)) { LibFunc floatFn, doubleFn, longDFn; switch (LibFn) { @@ -2016,9 +2040,15 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, default: return nullptr; } +<<<<<<< HEAD Value *X = BaseFn->getOperand(0); NewPow = emitBinaryFloatFnCall( X, B.CreateFMul(Y, ConstantFP::get(Ty, 0.5)), TLI, doubleFn, floatFn, +======= + Value *x = BaseFn->getOperand(0); + NewPow = emitBinaryFloatFnCall( + x, B.CreateFMul(y, ConstantFP::get(Ty, 0.5)), TLI, doubleFn, floatFn, +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) longDFn, B, CalleeFn->getAttributes()); } if (NewPow) { @@ -2031,9 +2061,15 @@ Value *LibCallSimplifier::replaceNestedPowAndSqrtWithPow(CallInst *Pow, //pow(pow(x,y),z)-> pow(x,y*z) Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilderBase &B){ +<<<<<<< HEAD Value *NewPow=nullptr; Value *Base = Pow->getArgOperand(0); Value *Z = Pow->getArgOperand(1); +======= + Value *Base=nullptr,*z=nullptr,*NewPow=nullptr; + Base = Pow->getArgOperand(0); + z = Pow->getArgOperand(1); +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) Module *Mod = Pow->getModule(); Type *Ty = Pow->getType(); CallInst *BaseFn = dyn_cast(Base); @@ -2041,6 +2077,7 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()){ Function *CalleeFn = BaseFn->getCalledFunction(); LibFunc LibFn; +<<<<<<< HEAD // Check if Pow complies with the conversion rules. if (IntrinsicInst *II = dyn_cast(Pow)) { if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && @@ -2052,6 +2089,22 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {X, Mul}); } } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && +======= + // If Pow is an intrinsic call and + // its first argument is also an intrinsic call to pow + if (IntrinsicInst *II = dyn_cast(Pow)) { + if (II->getIntrinsicID() == Intrinsic::pow && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::pow) { + Value *x = BaseFn->getOperand(0); + Value *y = BaseFn->getOperand(1); + Value *yz = B.CreateFMul(y, z); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {x, yz}); + } + } + // If it is a library function call + else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) isLibFuncEmittable(Mod, TLI, LibFn)) { LibFunc floatFn,doubleFn,longDFn; switch (LibFn) @@ -2073,10 +2126,17 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder default: return nullptr; } +<<<<<<< HEAD Value *X=BaseFn->getOperand(0); Value *Y=BaseFn->getOperand(1); Value *Mul=B.CreateFMul(Y,Z); NewPow = emitBinaryFloatFnCall(X, Mul, TLI, doubleFn, floatFn, longDFn, +======= + Value *x=BaseFn->getOperand(0); + Value *y=BaseFn->getOperand(1); + Value *yz=B.CreateFMul(y,z); + NewPow = emitBinaryFloatFnCall(x, yz, TLI, doubleFn, floatFn, longDFn, +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) B, CalleeFn->getAttributes()); } if (NewPow) { @@ -2087,11 +2147,23 @@ Value *LibCallSimplifier::replaceNestedPowAndPowWithPow(CallInst *Pow, IRBuilder return nullptr; } +<<<<<<< HEAD // sqrt(pow(x,y)) -> pow(|x|,y*0.5) Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, IRBuilderBase &B) { Value *NewPow = nullptr; Value *OldPow = Sqrt->getArgOperand(0); +======= +// sqrt(pow(x,y)) -> pow(|x|,y*0.5) (incorrect transformation) +// If x is transformed into abs(x), +// the result may be inconsistent. +// Therefore, the actual transformation approach is: +// sqrt(pow(x,y)) -> pow(x,y*0.5) +Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, + IRBuilderBase &B) { + Value *OldPow = nullptr, *NewPow = nullptr; + OldPow = Sqrt->getArgOperand(0); +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) Module *Mod = Sqrt->getModule(); Type *Ty = Sqrt->getType(); CallInst *Pow = dyn_cast(OldPow); @@ -2100,6 +2172,7 @@ Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, IRBuilderBase::FastMathFlagGuard Guard(B); B.setFastMathFlags(Sqrt->getFastMathFlags()); LibFunc LibFn; +<<<<<<< HEAD // Check if Sqrt complies with the conversion rules. if (IntrinsicInst *II = dyn_cast(Sqrt)) { if (II->getIntrinsicID() == Intrinsic::sqrt && CalleeFn && @@ -2112,6 +2185,19 @@ Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, NewPow = B.CreateCall( Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), {AbsX, Mul}); +======= + // If Sqrt is an intrinsic call and + // its first argument is also an intrinsic call to pow + if (IntrinsicInst *II = dyn_cast(Sqrt)) { + if (II->getIntrinsicID() == Intrinsic::sqrt && CalleeFn && + CalleeFn->getIntrinsicID() == Intrinsic::pow) { + Value *x = Pow->getOperand(0); + Value *y = Pow->getOperand(1); + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = B.CreateCall( + Intrinsic::getDeclaration(Mod, Pow->getIntrinsicID(), Ty), + {x, y_05}); +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) } } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(Mod, TLI, LibFn)) { @@ -2134,12 +2220,19 @@ Value *LibCallSimplifier::replaceNestedSqrtAndPowWithPow(CallInst *Sqrt, default: return nullptr; } +<<<<<<< HEAD Value *X = Pow->getOperand(0); Value *Y = Pow->getOperand(1); Value *AbsX = B.CreateSelect( B.CreateFCmpOGT(X, ConstantFP::get(Ty, 0.0)), X, B.CreateFNeg(X)); Value *Mul = B.CreateFMul(Y, ConstantFP::get(Ty, 0.5)); NewPow = emitBinaryFloatFnCall(AbsX, Mul, TLI, doubleFn, floatFn, longDFn, +======= + Value *x = Pow->getOperand(0); + Value *y = Pow->getOperand(1); + Value *y_05 = B.CreateFMul(y, ConstantFP::get(Ty, 0.5)); + NewPow = emitBinaryFloatFnCall(x, y_05, TLI, doubleFn, floatFn, longDFn, +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) B, CalleeFn->getAttributes()); } if (NewPow) { @@ -2163,7 +2256,11 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { Value *Base = CI->getArgOperand(0); CallInst *BaseFn = dyn_cast(Base); Type *Ty = CI->getType(); +<<<<<<< HEAD Value *Ret = nullptr, *TempRet1 = nullptr, *TempRet2 = nullptr; +======= + Value *Ret = nullptr, *tempRet1 = nullptr, *tempRet2 = nullptr; +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) if (!TargetLibraryInfoImpl::isCallingConvCCompatible(CI)) return nullptr; IRBuilderBase::FastMathFlagGuard Guard(B); @@ -2173,11 +2270,16 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && CI->isFast()) { LibFunc LibFn; Function *CalleeFn = BaseFn->getCalledFunction(); +<<<<<<< HEAD Value *X; +======= + Value *x; +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) // If the internal representation is an intrinsic call if (IntrinsicInst *II = dyn_cast(BaseFn)) { Intrinsic::ID IntrinsicID = II->getIntrinsicID(); switch (IntrinsicID) { +<<<<<<< HEAD // cbrt(exp(X)) -> exp(X/3) // cbrt(exp2(X)) -> exp2(X/3) case Intrinsic::exp: @@ -2191,6 +2293,21 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { X = BaseFn->getOperand(0); Ret = B.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::pow, Ty), {X, B.CreateFDiv(ConstantFP::get(Ty, 1.0), +======= + // cbrt(exp(X)) -> exp(x/3) + // cbrt(exp2(X)) -> exp2(x/3) + case Intrinsic::exp: + case Intrinsic::exp2: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M, IntrinsicID, Ty), + B.CreateFDiv(x, ConstantFP::get(Ty, 3.0))); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case Intrinsic::sqrt: + x = BaseFn->getOperand(0); + Ret = B.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::pow, Ty), + {x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) ConstantFP::get(Ty, 6.0))}); break; default: @@ -2199,42 +2316,71 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { } else if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) && isLibFuncEmittable(M, TLI, LibFn)) { switch (LibFn) { +<<<<<<< HEAD // cbrt(exp(X)) -> exp(X/3) case LibFunc_exp: case LibFunc_expf: case LibFunc_expl: X = BaseFn->getOperand(0); Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), +======= + // cbrt(exp(X)) -> exp(x/3) + case LibFunc_exp: + case LibFunc_expf: + case LibFunc_expl: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) TLI, LibFunc_exp, LibFunc_expf, LibFunc_expl, B, CalleeFn->getAttributes()); break; case LibFunc_exp_finite: case LibFunc_expf_finite: case LibFunc_expl_finite: +<<<<<<< HEAD X = BaseFn->getOperand(0); Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), +======= + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) TLI, LibFunc_exp_finite, LibFunc_expf_finite, LibFunc_expl_finite, B, CalleeFn->getAttributes()); break; +<<<<<<< HEAD // cbrt(exp2(X)) -> exp2(X/3) case LibFunc_exp2: case LibFunc_exp2f: case LibFunc_exp2l: X = BaseFn->getOperand(0); Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), +======= + // cbrt(exp2(X)) -> exp2(x/3) + case LibFunc_exp2: + case LibFunc_exp2f: + case LibFunc_exp2l: + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) TLI, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l, B, CalleeFn->getAttributes()); break; case LibFunc_exp2_finite: case LibFunc_exp2f_finite: case LibFunc_exp2l_finite: +<<<<<<< HEAD X = BaseFn->getOperand(0); Ret = emitUnaryFloatFnCall(B.CreateFDiv(X, ConstantFP::get(Ty, 3.0)), +======= + x = BaseFn->getOperand(0); + Ret = emitUnaryFloatFnCall(B.CreateFDiv(x, ConstantFP::get(Ty, 3.0)), +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) TLI, LibFunc_exp2_finite, LibFunc_exp2f_finite, LibFunc_exp2l_finite, B, CalleeFn->getAttributes()); break; +<<<<<<< HEAD // cbrt(sqrt(X)) -> pow(X,1/6) case LibFunc_sqrt: case LibFunc_sqrtf: @@ -2273,6 +2419,46 @@ Value *LibCallSimplifier::optimizeCbrt(CallInst *CI, IRBuilderBase &B) { BaseFn->getAttributes())); Ret = B.CreateSelect(B.CreateFCmpOGE(X, ConstantFP::get(Ty, 0.0)), TempRet1, TempRet2); +======= + // cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt: + case LibFunc_sqrtf: + case LibFunc_sqrtl: + x = BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes()); + break; + // cbrt(sqrt(x)) -> pow(x,1/6) + case LibFunc_sqrt_finite: + case LibFunc_sqrtf_finite: + case LibFunc_sqrtl_finite: + x = BaseFn->getOperand(0); + Ret = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 6.0)), + TLI, LibFunc_pow_finite, LibFunc_powf_finite, LibFunc_powl_finite, + B, BaseFn->getAttributes()); + break; + // cbrt(cbrt(x)) -> pow(x,1/9) + case LibFunc_cbrt: + case LibFunc_cbrtf: + case LibFunc_cbrtl: + x = BaseFn->getOperand(0); + // When x >= 0, it can be transformed into pow(x, 1/9) + tempRet1 = emitBinaryFloatFnCall( + x, B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes()); + // When x < 0, it can be transformed into -pow(-x, 1/9) + tempRet2 = B.CreateFNeg(emitBinaryFloatFnCall( + B.CreateFNeg(x), + B.CreateFDiv(ConstantFP::get(Ty, 1.0), ConstantFP::get(Ty, 9.0)), + TLI, LibFunc_pow, LibFunc_powf, LibFunc_powl, B, + BaseFn->getAttributes())); + Ret = B.CreateSelect(B.CreateFCmpOGE(x, ConstantFP::get(Ty, 0.0)), + tempRet1, tempRet2); +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) break; default: return nullptr; diff --git a/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll b/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll index 941796b4926d..0f3dc48cc328 100644 --- a/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll +++ b/llvm/test/Transforms/InstCombine/pow-sqrt-exp-cbrt.ll @@ -54,9 +54,14 @@ define double @sqrt_nroot(double %x, double %n){ ; CHECK-LABEL: @sqrt_nroot( ; CHECK-NEXT: [[DIV:%.*]] = fdiv double 1.000000e+00, [[N:%.*]] ; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[DIV]]) +<<<<<<< HEAD ; CHECK-NEXT: [[ABSX:%.*]] = call fast double @llvm.fabs.f64(double [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[DIV]], 5.000000e-01 ; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[ABSX]], double [[MUL]]) +======= +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[DIV]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) ; CHECK-NEXT: ret double [[POW]] ; %div = fdiv double 1.000000e+00, %n @@ -69,9 +74,14 @@ define float @sqrtf_nroot(float %x, float %n){ ; CHECK-LABEL: @sqrtf_nroot( ; CHECK-NEXT: [[DIV:%.*]] = fdiv float 1.000000e+00, [[N:%.*]] ; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[DIV]]) +<<<<<<< HEAD ; CHECK-NEXT: [[ABSX:%.*]] = call fast float @llvm.fabs.f32(float [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[DIV]], 5.000000e-01 ; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[ABSX]], float [[MUL]]) +======= +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[DIV]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) ; CHECK-NEXT: ret float [[POW]] ; %div = fdiv float 1.000000e+00, %n @@ -83,9 +93,14 @@ define float @sqrtf_nroot(float %x, float %n){ define double @sqrt_pow(double %x, double %y) { ; CHECK-LABEL: @sqrt_pow( ; CHECK-NEXT: [[UNUSED:%.*]] = call fast double @pow(double [[X:%.*]], double [[Y:%.*]]) +<<<<<<< HEAD ; CHECK-NEXT: [[ABSX:%.*]] = call fast double @llvm.fabs.f64(double [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[Y:%.*]], 5.000000e-01 ; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[ABSX]], double [[MUL]]) +======= +; CHECK-NEXT: [[MUL:%.*]] = fmul fast double [[Y:%.*]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast double @pow(double [[X:%.*]], double [[MUL]]) +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) ; CHECK-NEXT: ret double [[POW]] ; %call = call fast double @pow(double %x, double %y) @@ -96,9 +111,14 @@ define double @sqrt_pow(double %x, double %y) { define float @sqrtf_powf(float %x, float %y) { ; CHECK-LABEL: @sqrtf_powf( ; CHECK-NEXT: [[UNUSED:%.*]] = call fast float @powf(float [[X:%.*]], float [[Y:%.*]]) +<<<<<<< HEAD ; CHECK-NEXT: [[ABSX:%.*]] = call fast float @llvm.fabs.f32(float [[X:%.*]]) ; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[Y:%.*]], 5.000000e-01 ; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[ABSX]], float [[MUL]]) +======= +; CHECK-NEXT: [[MUL:%.*]] = fmul fast float [[Y:%.*]], 5.000000e-01 +; CHECK-NEXT: [[POW:%.*]] = call fast float @powf(float [[X:%.*]], float [[MUL]]) +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) ; CHECK-NEXT: ret float [[POW]] ; %call = call fast float @powf(float %x, float %y) @@ -217,6 +237,10 @@ declare float @cbrtf(float) declare double @exp(double) declare float @expf(float) declare double @exp2(double) +<<<<<<< HEAD declare float @exp2f(float) declare double @llvm.fabs.f64(double) declare float @llvm.fabs.f32(float) +======= +declare float @exp2f(float) +>>>>>>> 89fe0551ddae (实现了以下几种数学库和并优化,并进行了功能验证:) -- Gitee