diff --git a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py
index 9d8a050d60dceed57d3a7a58781ba598bdabd02d..2378d8ebc86019439788f48bb9b91e007b394ef2 100644
--- a/test/onnx/test_wrapper_onnx_ops.py
+++ b/test/onnx/test_wrapper_onnx_ops.py
@@ -1435,5 +1435,31 @@ class TestOnnxOps(TestCase):
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                             onnx_model_name)))
+
+    @SupportedDevices(['Ascend910B'])
+    def test_wrapper_npu_moe_gating_top_k_softmax(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+
+            def forward(self, x, finished=None, k=1):
+                return torch_npu.npu_moe_gating_top_k_softmax(x, finished, k=k)
+
+        def export_onnx(onnx_model_name):
+            x = torch.tensor([[0.1, 0.1, 0.1, 0.1],
+                              [0.2, 0.2, 0.2, 0.2],
+                              [0.3, 0.3, 0.3, 0.3]], dtype=torch.float32).to("npu")
+            model = Model().to("npu")
+            model(x, None, 2)
+            self.onnx_export(model, (x, None, 2), onnx_model_name,
+                             input_names=["x", "finished", "k"],
+                             output_names=["y", "expert_idx", "row_idx"])
+
+        onnx_model_name = "model_npu_moe_gating_top_k_softmax.onnx"
+        export_onnx(onnx_model_name)
+        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
+                                            onnx_model_name)))
+
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 4df9b818334fc5cb745c54d344ee876dec53aaf7..6b43743a539fa535ecb9b655bce370bebfed794e 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1377,6 +1377,24 @@ class TestMmAllReduce(TestCase):
         self.assertEqual(output.dtype, dst_dtype)
 
 
+class TestNpuMoeGatingTopKSoftmax(TestCase):
+    # meta shape inference
+    def testNpuMoeGatingTopKSoftmax(self):
+        with FakeTensorMode():
+            x = torch.randn(3, 4, dtype=torch.float).npu()
+            y_golden = torch.randn(3, 2, dtype=torch.float).npu()
+            expert_idx_golden = torch.randint(-1, 1, (3, 2), dtype=torch.int32).npu()
+            row_idx_golden = torch.randint(-1, 1, (3, 2), dtype=torch.int32).npu()
+            y, expert_idx, row_idx = torch.ops.npu.npu_moe_gating_top_k_softmax(x, None, k=2)
+
+            self.assertTrue(y.dtype == y_golden.dtype)
+            self.assertTrue(expert_idx.dtype == expert_idx_golden.dtype)
+            self.assertTrue(row_idx.dtype == row_idx_golden.dtype)
+            self.assertTrue(y.shape == y_golden.shape)
+            self.assertTrue(expert_idx.shape == expert_idx_golden.shape)
+            self.assertTrue(row_idx.shape == row_idx_golden.shape)
+
+
 class TestScatterUpdateMeta(TestCase):
 
     def test_scatter_update_meta(self):
diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index 1f152dcd94c969123796f104b8e84cdf9281136c..547640c7fb3392b3e0ba537a3c1ba7918f5eb81f 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -563,6 +563,26 @@ def npu_quantize_meta(self, scales, zero_points, dtype, axis=1):
     return torch.empty_like(self, dtype=torch.int8)
 
 
+@impl(m, "npu_moe_gating_top_k_softmax")
+def npu_moe_gating_top_k_softmax_meta(x, finished=None, k=1):
+    x_dim = x.dim()
+    torch._check(
+        x_dim == 2 or x_dim == 3,
+        lambda: "the x shape supports only 2D and 3D",
+    )
+    if x_dim == 3:
+        y_dim_list = [x.size(0), x.size(1), k]
+        expert_idx_dim_list = [x.size(0), x.size(1), k]
+        row_idx_dim_list = [x.size(0), x.size(1), k]
+    else:
+        y_dim_list = [x.size(0), k]
+        expert_idx_dim_list = [x.size(0), k]
+        row_idx_dim_list = [x.size(0), k]
+    return (x.new_empty(tuple(y_dim_list), dtype=x.dtype),
+            x.new_empty(tuple(expert_idx_dim_list), dtype=torch.int32),
+            x.new_empty(tuple(row_idx_dim_list), dtype=torch.int32))
+
+
 @impl(m, "npu_dynamic_quant")
"npu_dynamic_quant") def npu_dynamic_quant(input_dummy, *, smooth_scales=None): return (torch.empty_like(input_dummy, dtype=torch.int8), diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index b839235ccf27936fd98d8b67b440e92b8bb358c4..ec1e4ad2b1a9c1c3c0e12d0e4035a557a64c8d60 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -844,6 +844,20 @@ class NPUMoeFinalizeRoutingOP(torch.autograd.Function): scales, expanded_src_to_dst_row, expert_for_source_row) +class NPUMoeGatingTopKSoftmaxOP(torch.autograd.Function): + + @staticmethod + def forward(ctx, *args, **kwargs): + return torch.ops.npu.npu_moe_gating_top_k_softmax(*args, **kwargs) + + @staticmethod + def symbolic(g, + x: torch.Tensor, + finished: Optional[Tensor], + k: int = 1): + return g.op("npu::NPUMoeGatingTopKSoftmax", x, finished, k_i=k, outputs=3) + + class NPUMoeComputeExpertTokensOP(torch.autograd.Function): @staticmethod @@ -1126,6 +1140,9 @@ def wrapper_npu_mm_all_reduce_base(x1, x2, hcom, reduce_op, bias, antiquant_scal dequant_scale, antiquant_group_size, comm_turn) +def wrapper_npu_moe_gating_top_k_softmax(x, finished, k): + return NPUMoeGatingTopKSoftmaxOP.apply(x, finished, k) + def wrapper_npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale, quant_offset, bias, antiquant_group_size): @@ -1211,4 +1228,5 @@ def add_onnx_ops(): torch_npu.npu_anti_quant = wrapper_npu_anti_quant torch_npu.npu_quantize = wrapper_npu_quantize torch_npu.npu_moe_compute_expert_tokens = wrapper_npu_moe_compute_expert_tokens - torch_npu.npu_moe_finalize_routing = wrapper_npu_moe_finalize_routing \ No newline at end of file + torch_npu.npu_moe_finalize_routing = wrapper_npu_moe_finalize_routing + torch_npu.npu_moe_gating_top_k_softmax = wrapper_npu_moe_gating_top_k_softmax \ No newline at end of file