From 43009d3929bbafe2cf63d104ac8945db9a9bf267 Mon Sep 17 00:00:00 2001 From: zhuceHW Date: Thu, 28 Aug 2025 11:27:42 +0800 Subject: [PATCH] 1. inductor no more need decomposition op white list, remove it 2.fix erfc decomposition logic 3.remove redundant lowering overload op --- torch_npu/_inductor/decomposition.py | 26 ++----------------------- torch_npu/_inductor/lowering.py | 10 ++++------ torch_npu/_inductor/lowering_op_list.py | 8 +++----- 3 files changed, 9 insertions(+), 35 deletions(-) diff --git a/torch_npu/_inductor/decomposition.py b/torch_npu/_inductor/decomposition.py index b9c725f3ff..8770f2b42c 100644 --- a/torch_npu/_inductor/decomposition.py +++ b/torch_npu/_inductor/decomposition.py @@ -1,32 +1,10 @@ import torch._ops -from torch._inductor.decomposition import decompositions, pw_cast_for_opmath -from torch._inductor.decomposition import register_decomposition - -from .lowering import _init_set +from torch._inductor.decomposition import pw_cast_for_opmath, register_decomposition aten = torch.ops.aten -DECOMPOSITION_OVERLOAD_OP = [ - aten._log_softmax, - aten.nll_loss_forward, - # aten.gelu_backward, - # aten.gelu, - aten.nll_loss_backward, - aten._log_softmax_backward_data, - aten.embedding_dense_backward, - aten.addmm, - aten.gelu -] - def _register_npu_inductor_decompositons(): - overload_op_set = set() - _init_set(DECOMPOSITION_OVERLOAD_OP, overload_op_set) - - for op in overload_op_set: - if (op in decompositions): - del decompositions[op] - @register_decomposition([aten.scatter.src]) @pw_cast_for_opmath def scatter_src(self, input_tensor, dim, index_tensor, source_tensor): @@ -45,5 +23,5 @@ def _register_npu_inductor_decompositons(): @register_decomposition([aten.erfc]) def erfc(x): - tensor = torch.ones_like(x) - torch.exp(x) + tensor = torch.ones_like(x) - torch.erf(x) return tensor diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py index 2b47e091af..4448a0b059 100644 --- a/torch_npu/_inductor/lowering.py +++ b/torch_npu/_inductor/lowering.py @@ -126,8 +126,10 @@ def _register_npu_inductor_fallbacks(): if flag: continue else: - make_fallback(op) FALLBACK_LIST.append(op) + for op in FALLBACK_LIST: + make_fallback(op) + # 把需要overload的op在lowering里删除 for op in overload_op_set: if op in lowerings: @@ -262,8 +264,4 @@ def _register_npu_inductor_fallbacks(): @register_lowering(aten.cat) def cat(inputs, dim=0): - return fallback_handler(aten.cat.default)(inputs, dim) - - make_fallback(aten._log_softmax) - make_fallback(aten.gather) - make_fallback(aten.nll_loss_forward) + return fallback_handler(aten.cat.default)(inputs, dim) \ No newline at end of file diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index db9c427e60..c101457f5e 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -81,7 +81,9 @@ GENERATE_LIST2 = [ "foreach" ] -FALLBACK_LIST = [] +FALLBACK_LIST = [ + aten.gather, +] # Delete these op in lowering list and then update lowering list with new lowering, # otherwise, it will not use npu overload lowering. @@ -99,9 +101,5 @@ LOWERING_OVERLOAD_OP = [ aten.var, aten.embedding, - aten.split, - aten.split_with_sizes, - aten.nll_loss_forward, - aten.gather, aten.cat, ] -- Gitee