diff --git a/torch_npu/_inductor/decomposition.py b/torch_npu/_inductor/decomposition.py index b9c725f3ff2d762e060349b7f2b25afe04c5ec51..8770f2b42c4bd480c0ea731370bdd147d8e968d0 100644 --- a/torch_npu/_inductor/decomposition.py +++ b/torch_npu/_inductor/decomposition.py @@ -1,32 +1,10 @@ import torch._ops -from torch._inductor.decomposition import decompositions, pw_cast_for_opmath -from torch._inductor.decomposition import register_decomposition - -from .lowering import _init_set +from torch._inductor.decomposition import pw_cast_for_opmath, register_decomposition aten = torch.ops.aten -DECOMPOSITION_OVERLOAD_OP = [ - aten._log_softmax, - aten.nll_loss_forward, - # aten.gelu_backward, - # aten.gelu, - aten.nll_loss_backward, - aten._log_softmax_backward_data, - aten.embedding_dense_backward, - aten.addmm, - aten.gelu -] - def _register_npu_inductor_decompositons(): - overload_op_set = set() - _init_set(DECOMPOSITION_OVERLOAD_OP, overload_op_set) - - for op in overload_op_set: - if (op in decompositions): - del decompositions[op] - @register_decomposition([aten.scatter.src]) @pw_cast_for_opmath def scatter_src(self, input_tensor, dim, index_tensor, source_tensor): @@ -45,5 +23,5 @@ def _register_npu_inductor_decompositons(): @register_decomposition([aten.erfc]) def erfc(x): - tensor = torch.ones_like(x) - torch.exp(x) + tensor = torch.ones_like(x) - torch.erf(x) return tensor diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py index 2b47e091af8f49d3662f4c613a97b505f3e9266b..4448a0b059b6c16ac705676da17d73ad645bb97a 100644 --- a/torch_npu/_inductor/lowering.py +++ b/torch_npu/_inductor/lowering.py @@ -126,8 +126,10 @@ def _register_npu_inductor_fallbacks(): if flag: continue else: - make_fallback(op) FALLBACK_LIST.append(op) + for op in FALLBACK_LIST: + make_fallback(op) + # 把需要overload的op在lowering里删除 for op in overload_op_set: if op in lowerings: @@ -262,8 +264,4 @@ def _register_npu_inductor_fallbacks(): @register_lowering(aten.cat) def cat(inputs, dim=0): - return fallback_handler(aten.cat.default)(inputs, dim) - - make_fallback(aten._log_softmax) - make_fallback(aten.gather) - make_fallback(aten.nll_loss_forward) + return fallback_handler(aten.cat.default)(inputs, dim) \ No newline at end of file diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index db9c427e60d69e95e39e9a7b83396198831d6070..c101457f5e5efc0d2258bf19baad7514fdf106f0 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -81,7 +81,9 @@ GENERATE_LIST2 = [ "foreach" ] -FALLBACK_LIST = [] +FALLBACK_LIST = [ + aten.gather, +] # Delete these op in lowering list and then update lowering list with new lowering, # otherwise, it will not use npu overload lowering. @@ -99,9 +101,5 @@ LOWERING_OVERLOAD_OP = [ aten.var, aten.embedding, - aten.split, - aten.split_with_sizes, - aten.nll_loss_forward, - aten.gather, aten.cat, ]