diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 21fe448218bc7728c30e525d6fef411860ca26d0..a98346156066d62e271bccdaa88d362fe58b77ee 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -33,12 +33,14 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/KnownBits.h" #include using namespace llvm; @@ -51,6 +53,13 @@ enum { RecursionLimit = 3 }; STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); +namespace llvm { +cl::opt EnableTernaryAbsOptimization( + "enable-ternary-abs-optimization", + cl::desc("Enable optimization of abs() call in ternary expression"), + cl::init(false)); +} // namespace llvm + static Value *simplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); static Value *simplifyUnOp(unsigned, Value *, const SimplifyQuery &, unsigned); @@ -4257,6 +4266,39 @@ static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS, Pred == ICmpInst::ICMP_EQ); } +static Value *simplifySelectFromAbs(Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal, + const SimplifyQuery &Q) { + Value *X, *Y; + ConstantInt *CI; + if (match(CmpLHS, + m_NSWAdd(m_NSWSub(m_Value(X), m_Value(Y)), m_ConstantInt(CI))) && + match(CmpRHS, m_Zero())) + if (Optional Flag = + isImpliedByDomCondition(ICmpInst::ICMP_SGE, X, Y, Q.CxtI, Q.DL)) { + // x-y+1 is positive when x>=y or non-positive when xisOne()) + return *Flag ? FalseVal : TrueVal; + // x-y+n and x+n-y is positive when x>=y and n>=0 + if (!CI->isNegative()) + return *Flag ? FalseVal : nullptr; + } + // x-y-1 is negative when x<=y or non-negative when x>y + if (match(CmpLHS, m_Add(m_Xor(m_Value(Y), m_AllOnes()), m_Value(X))) && + match(CmpRHS, m_Zero())) + if (Optional Flag = + isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, Q.CxtI, Q.DL)) + return *Flag ? TrueVal : FalseVal; + // x-y-n is negative when x-y<=0 and -n<0 + if (match(CmpLHS, m_Add(m_Sub(m_Value(X), m_Value(Y)), m_Negative())) && + match(CmpRHS, m_Zero())) { + if (Optional Flag = + isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, Q.CxtI, Q.DL)) + return *Flag ? TrueVal : nullptr; + } + return nullptr; +} + /// Try to simplify a select instruction when its condition operand is an /// integer comparison. static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, @@ -4360,6 +4402,13 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, return FalseVal; } + // when select expression is converted from abs() call, it's right TrueVal and + // FalseVal are complement, and we try to optimize its value to one of its arm + // value based on the signess + if (Pred == ICmpInst::ICMP_SLT && EnableTernaryAbsOptimization) { + return simplifySelectFromAbs(CmpLHS, CmpRHS, TrueVal, FalseVal, Q); + } + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 52596b30494fa3083634d46e82c2185bc7424dbf..31ca9ffa12770517bf875d8d8e6aaa743b93bfe2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -92,6 +92,8 @@ namespace llvm { /// enable preservation of attributes in assume like: /// call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ] extern cl::opt EnableKnowledgeRetention; +/// enable optimization of abs call in ternary expression +extern cl::opt EnableTernaryAbsOptimization; } // namespace llvm /// Return the specified type promoted as it would be to pass though a va_arg @@ -822,9 +824,18 @@ static Optional getKnownSign(Value *Op, Instruction *CxtI, return true; Value *X, *Y; - if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) + ConstantInt *CI = nullptr; + // abs(n*(x-y)) -> n*(x-y) or n*(y-x) when n>0 + if (match(Op, m_NSWSub(m_Value(X), m_Value(Y))) || + ((EnableTernaryAbsOptimization && + (match(Op, m_NSWShl(m_NSWSub(m_Value(X), m_Value(Y)), + m_StrictlyPositive())) || + match(Op, m_NSWMul(m_NSWSub(m_Value(X), m_Value(Y)), + m_ConstantInt(CI))))))) { + if (CI && CI->isNegative()) + return isImpliedByDomCondition(ICmpInst::ICMP_SGT, X, Y, CxtI, DL); return isImpliedByDomCondition(ICmpInst::ICMP_SLT, X, Y, CxtI, DL); - + } return isImpliedByDomCondition( ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } diff --git a/llvm/test/Transforms/InstCombine/abs-ternary.ll b/llvm/test/Transforms/InstCombine/abs-ternary.ll new file mode 100644 index 0000000000000000000000000000000000000000..17310efd2f6b1d193483686c26743a1b29bfdda2 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/abs-ternary.ll @@ -0,0 +1,337 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -enable-ternary-abs-optimization -S | FileCheck %s + +declare i32 @abs(i32) +declare i32 @llvm.abs.i32(i32, i1) + +; x>y ? abs(x-y+1) : 0 -> x>y ? x-y+1 : 0 +; https://alive2.llvm.org/ce/z/XZsbMD +define i32 @abs_sub_with_pos_constant_sgt(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_pos_constant_sgt( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[COND_TRUE:%.*]], label [[COND_END:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]] +; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[SUB]], 1 +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[ADD]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp sgt i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %add = add nsw i32 %sub, 1 + %call = call i32 @abs(i32 %add) + br label %cond.end + +cond.end: + %cond = phi i32 [ %call, %cond.true ], [ 0, %entry ] + ret i32 %cond +} + +; only optimize abs() call to select expression +; https://alive2.llvm.org/ce/z/fMVLDa +define i32 @abs_sub_with_pos_constant_sgt_nonsw(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_pos_constant_sgt_nonsw( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[COND_TRUE:%.*]], label [[COND_END:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[SUB:%.*]] = sub i32 [[X]], [[Y]] +; CHECK-NEXT: [[ADD:%.*]] = add i32 [[SUB]], 1 +; CHECK-NEXT: [[TMP0:%.*]] = icmp slt i32 [[ADD]], 0 +; CHECK-NEXT: [[NEG:%.*]] = xor i32 [[SUB]], -1 +; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[TMP0]], i32 [[NEG]], i32 [[ADD]] +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[TMP1]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp sgt i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub i32 %x, %y + %add = add i32 %sub, 1 + %call = call i32 @abs(i32 %add) + br label %cond.end + +cond.end: + %cond = phi i32 [ %call, %cond.true ], [ 0, %entry ] + ret i32 %cond +} + +; x>=y ? abs(x-y+2) : 0 -> x>=y ? x-y+2 : 0 +; https://alive2.llvm.org/ce/z/APiJXd +define i32 @abs_sub_with_pos_constant_sge(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_pos_constant_sge( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP_NOT]], label [[COND_END:%.*]], label [[COND_TRUE:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]] +; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[SUB]], 2 +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[ADD]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp sge i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %add = add nsw i32 %sub, 2 + %call = call i32 @abs(i32 %add) + br label %cond.end + +cond.end: + %cond = phi i32 [ %call, %cond.true ], [ 0, %entry ] + ret i32 %cond +} + +; x x x<=y ? y-x+3: 0 +; https://alive2.llvm.org/ce/z/z73eBY +define i32 @abs_sub_with_neg_constant_sle(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_neg_constant_sle( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i32 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: br i1 [[CMP_NOT]], label [[COND_END:%.*]], label [[COND_TRUE:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[SUB_NEG:%.*]] = sub i32 [[Y]], [[X]] +; CHECK-NEXT: [[NEG:%.*]] = add i32 [[SUB_NEG]], 3 +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[NEG]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp sle i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %sub1 = sub nsw i32 %sub, 3 + %call = call i32 @abs(i32 %sub1) + br label %cond.end + +cond.end: + %cond = phi i32 [ %call, %cond.true ], [ 0, %entry ] + ret i32 %cond +} + +; only optimize abs() call to select expression +define i32 @abs_sub_with_neg_constant_sgt(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_neg_constant_sgt( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[COND_TRUE:%.*]], label [[COND_END:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[TMP0:%.*]] = xor i32 [[Y]], -1 +; CHECK-NEXT: [[SUB1:%.*]] = add i32 [[TMP0]], [[X]] +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[SUB1]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp sgt i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %sub1 = sub nsw i32 %sub, 1 + %call = call i32 @abs(i32 %sub1) + br label %cond.end + +cond.end: + %cond = phi i32 [ %call, %cond.true ], [ 0, %entry ] + ret i32 %cond +} + +; only optimize abs() call to select expression +define i32 @abs_sub_with_neg_constant_sge(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_neg_constant_sge( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP_NOT]], label [[COND_END:%.*]], label [[COND_TRUE:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[TMP0:%.*]] = xor i32 [[Y]], -1 +; CHECK-NEXT: [[SUB1:%.*]] = add i32 [[TMP0]], [[X]] +; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.abs.i32(i32 [[SUB1]], i1 true) +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[COND:%.*]] = phi i32 [ [[TMP1]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[COND]] +; +entry: + %cmp = icmp sge i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %sub1 = sub nsw i32 %sub, 1 + %call = call i32 @abs(i32 %sub1) + br label %cond.end + +cond.end: + %cond = phi i32 [ %call, %cond.true ], [ 0, %entry ] + ret i32 %cond +} + +; only optimize abs() call to select expression +define i32 @abs_sub_with_pos_multiply(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_pos_multiply( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[COND_TRUE:%.*]], label [[COND_END:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]] +; CHECK-NEXT: [[MUL:%.*]] = shl nsw i32 [[SUB]], 1 +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[R:%.*]] = phi i32 [ [[MUL]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[R]] +; +entry: + %cmp = icmp sgt i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %mul = shl nsw i32 %sub, 1 + %0 = call i32 @llvm.abs.i32(i32 %mul, i1 true) + br label %cond.end + +cond.end: + %r = phi i32 [ %0, %cond.true ], [ 0, %entry ] + ret i32 %r +} + +; only optimize abs() call to select expression +define i32 @abs_sub_with_neg_multiply(i32 %x, i32 %y) { +; CHECK-LABEL: @abs_sub_with_neg_multiply( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: br i1 [[CMP]], label [[COND_TRUE:%.*]], label [[COND_END:%.*]] +; CHECK: cond.true: +; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[X]], [[Y]] +; CHECK-NEXT: [[MUL_NEG:%.*]] = mul i32 [[SUB]], 3 +; CHECK-NEXT: br label [[COND_END]] +; CHECK: cond.end: +; CHECK-NEXT: [[R:%.*]] = phi i32 [ [[MUL_NEG]], [[COND_TRUE]] ], [ 0, [[ENTRY:%.*]] ] +; CHECK-NEXT: ret i32 [[R]] +; +entry: + %cmp = icmp sgt i32 %x, %y + br i1 %cmp, label %cond.true, label %cond.end + +cond.true: + %sub = sub nsw i32 %x, %y + %mul = mul nsw i32 %sub, -3 + %0 = call i32 @llvm.abs.i32(i32 %mul, i1 true) + br label %cond.end + +cond.end: + %r = phi i32 [ %0, %cond.true ], [ 0, %entry ] + ret i32 %r +}