From 456d5dc6d7c175c82c381d9e681ab3f82fb6e3f1 Mon Sep 17 00:00:00 2001
From: qinlcdy
Date: Tue, 7 May 2024 20:36:04 +0800
Subject: [PATCH] pta add npu_moe_gating_top_k_softmax onnx

---
 test/onnx/test_wrapper_onnx_ops.py    | 24 ++++++++++++++++++++++++
 test/test_fake_tensor.py              | 18 ++++++++++++++++++
 torch_npu/meta/_meta_registrations.py | 20 ++++++++++++++++++++
 torch_npu/onnx/wrapper_onnx_ops.py    | 21 ++++++++++++++++++++-
 4 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py
index ced26ab6aa..d33ae4753b 100644
--- a/test/onnx/test_wrapper_onnx_ops.py
+++ b/test/onnx/test_wrapper_onnx_ops.py
@@ -1435,5 +1435,29 @@ class TestOnnxOps(TestCase):
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                             onnx_model_name)))
 
+    @SupportedDevices(['Ascend910B'])
+    def test_wrapper_npu_moe_gating_top_k_softmax(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+
+            def forward(self, x, finished=None, k=1):
+                return torch_npu.npu_moe_gating_top_k_softmax(x, finished, k=k)
+
+        def export_onnx(onnx_model_name):
+            x = torch.tensor([[0.1, 0.1, 0.1, 0.1],
+                              [0.2, 0.2, 0.2, 0.2],
+                              [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32).to("npu")
+            model = Model().to("npu")
+            model(x, None, 2)
+            self.onnx_export(model, (x, None, 2), onnx_model_name,
+                             input_names=["x", "finished", "k"],
+                             output_names=["y", "expert_idx", "row_idx"])
+
+        onnx_model_name = "model_npu_moe_gating_top_k_softmax.onnx"
+        export_onnx(onnx_model_name)
+        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
+                                            onnx_model_name)))
+
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 4feb9c502b..bd1e753071 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1399,6 +1399,24 @@ class TestNpuQuantScatterMeta(TestCase):
         self.assertIsNot(fake_result, in_var)
 
 
+class TestNpuMoeGatingTopKSoftmax(TestCase):
+    # meta shape inference
+    def testNpuMoeGatingTopKSoftmax(self):
+        with FakeTensorMode():
+            x = torch.randn(3, 4, dtype=torch.float).npu()
+            y_golden = torch.randn(3, 2, dtype=torch.float).npu()
+            expert_idx_golden = torch.randint(-1, 1, (3, 2), dtype=torch.int32).npu()
+            row_idx_golden = torch.randint(-1, 1, (3, 2), dtype=torch.int32).npu()
+            y, expert_idx, row_idx = torch.ops.npu.npu_moe_gating_top_k_softmax(x, None, k=2)
+
+            self.assertTrue(y.dtype == y_golden.dtype)
+            self.assertTrue(expert_idx.dtype == expert_idx_golden.dtype)
+            self.assertTrue(row_idx.dtype == row_idx_golden.dtype)
+            self.assertTrue(y.shape == y_golden.shape)
+            self.assertTrue(expert_idx.shape == expert_idx_golden.shape)
+            self.assertTrue(row_idx.shape == row_idx_golden.shape)
+
+
 class TestNpuApplyRotoryPosEmbMeta(TestCase):
 
     def test_npu_apply_rotary_pos_emb_meta(self):
diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index cd3c7f5ba9..14d2db7078 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -552,6 +552,26 @@ def npu_apply_rotary_pos_emb_meta(query, key, cos, sin, layout=1):
     return (torch.empty_like(query, dtype=query.dtype), torch.empty_like(key, dtype=key.dtype))
 
 
+@impl(m, "npu_moe_gating_top_k_softmax")
+def npu_moe_gating_top_k_softmax_meta(x, finished=None, k=1):
+    x_dim = x.dim()
+    torch._check(
+        x_dim == 2 or x_dim == 3,
+        lambda: "the x shape supports only 2D and 3D",
+    )
+    if x_dim == 3:
+        y_dim_list = [x.size(0), x.size(1), k]
+        expert_idx_dim_list = [x.size(0), x.size(1), k]
+        row_idx_dim_list = [x.size(0), x.size(1), k]
+    else:
+        y_dim_list = [x.size(0), k]
+        expert_idx_dim_list = [x.size(0), k]
+        row_idx_dim_list = [x.size(0), k]
+    return (x.new_empty(tuple(y_dim_list), dtype=x.dtype),
+            x.new_empty(tuple(expert_idx_dim_list), dtype=torch.int32),
+            x.new_empty(tuple(row_idx_dim_list), dtype=torch.int32))
+
+
 @impl(m, "npu_quant_conv2d")
 def npu_quant_conv2d(input_, weight, scale, strides, pads, dilations,
                      groups=1, offset_x=0, round_mode='rint', output_dtype=None, bias=None, offset=None):
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index bdd340b98a..0c036028dc 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -846,6 +846,20 @@ class NPUMoeFinalizeRoutingOP(torch.autograd.Function):
                                             scales, expanded_src_to_dst_row, expert_for_source_row)
 
 
+class NPUMoeGatingTopKSoftmaxOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch.ops.npu.npu_moe_gating_top_k_softmax(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g,
+                 x: torch.Tensor,
+                 finished: Optional[Tensor],
+                 k: int = 1):
+        return g.op("npu::NPUMoeGatingTopKSoftmax", x, finished, k_i=k, outputs=3)
+
+
 def wrapper_npu_masked_softmax_with_rel_pos_bias(x, atten_mask, relative_pos_bias, scale_value=1.0,
                                                  inner_precision_mode=0):
     return NPUMaskedSoftmaxWithRelPosBiasOP.apply(x, atten_mask, relative_pos_bias, scale_value,
                                                   inner_precision_mode)
@@ -1129,6 +1143,10 @@ def wrapper_npu_quantize(inputs, scales, zero_points, dtype, axis):
     return NPUQuantizeOP.apply(inputs, scales, zero_points, dtype, axis)
 
 
+def wrapper_npu_moe_gating_top_k_softmax(x, finished, k):
+    return NPUMoeGatingTopKSoftmaxOP.apply(x, finished, k)
+
+
 def wrapper_npu_moe_finalize_routing(expanded_permuted_rows, skip1, skip2_optional, bias,
                                      scales, expanded_src_to_dst_row, expert_for_source_row):
     return NPUMoeFinalizeRoutingOP.apply(expanded_permuted_rows, skip1, skip2_optional, bias,
                                          scales, expanded_src_to_dst_row, expert_for_source_row)
@@ -1194,4 +1212,5 @@ def add_onnx_ops():
     torch_npu.npu_anti_quant = wrapper_npu_anti_quant
     torch_npu.npu_quantize = wrapper_npu_quantize
     torch_npu.npu_moe_compute_expert_tokens = wrapper_npu_moe_compute_expert_tokens
-    torch_npu.npu_moe_finalize_routing = wrapper_npu_moe_finalize_routing
\ No newline at end of file
+    torch_npu.npu_moe_finalize_routing = wrapper_npu_moe_finalize_routing
+    torch_npu.npu_moe_gating_top_k_softmax = wrapper_npu_moe_gating_top_k_softmax
\ No newline at end of file
-- 
Gitee
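
A minimal sketch of the shape contract the new meta registration encodes, assuming a torch_npu build that already includes this patch (the 8x16 input and k=4 are illustrative values, not taken from the patch): for a 2D gating input of shape (num_rows, num_experts), all three outputs come back as (num_rows, k), with y keeping x's dtype and expert_idx/row_idx as int32.

    import torch
    import torch_npu
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Under FakeTensorMode the call dispatches to the meta kernel registered
    # above (npu_moe_gating_top_k_softmax_meta), so no NPU kernel runs and
    # only shapes/dtypes are computed.
    with FakeTensorMode():
        x = torch.randn(8, 16, dtype=torch.float32).npu()  # 8 rows, 16 experts
        y, expert_idx, row_idx = torch.ops.npu.npu_moe_gating_top_k_softmax(x, None, k=4)
        assert y.shape == (8, 4) and y.dtype == torch.float32
        assert expert_idx.shape == (8, 4) and expert_idx.dtype == torch.int32
        assert row_idx.shape == (8, 4) and row_idx.dtype == torch.int32

This is what lets FakeTensor-based tooling (e.g. tracing and shape propagation) reason about the op without an NPU device executing the real kernel.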