From a181faed42bc6c91887f3043b92063fd8daf81ee Mon Sep 17 00:00:00 2001
From: 18761692867
Date: Sat, 9 Mar 2024 16:38:54 +0800
Subject: [PATCH] support npu_tome_merge

---
 test/onnx/test_wrapper_onnx_ops.py   | 23 +++++++++++++++++++++++
 test/test_fake_tensor.py             | 25 +++++++++++++++++++++++++
 torch_npu/meta/meta_registrations.py | 15 +++++++++++++++
 torch_npu/onnx/wrapper_onnx_ops.py   | 14 ++++++++++++++
 4 files changed, 77 insertions(+)

diff --git a/test/onnx/test_wrapper_onnx_ops.py b/test/onnx/test_wrapper_onnx_ops.py
index fa03e71d75f..f36ff186f7f 100644
--- a/test/onnx/test_wrapper_onnx_ops.py
+++ b/test/onnx/test_wrapper_onnx_ops.py
@@ -1295,6 +1295,29 @@ class TestOnnxOps(TestCase):
         export_onnx(onnx_model_name)
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name)))
 
+    @SupportedDevices(['Ascend910B'])
+    def test_wrapper_npu_tome_merge(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+
+            def forward(self, tokenA, tokenB, indices, argmax):
+                topRate = 0.5
+                x = torch_npu.npu_tome_merge(tokenA, tokenB, indices, argmax, topRate)
+                return x
+
+        def export_onnx(onnx_model_name):
+            tokenA = torch.rand(4, 3072, 320).uniform_(-3, 3).npu().to(torch.float16)
+            tokenB = torch.rand(4, 1024, 320).uniform_(-3, 3).npu().to(torch.float16)
+            indices = torch.rand(4, 3072).uniform_(-3, 3).npu().to(torch.int64)
+            argmax = torch.rand(4, 3072).uniform_(-3, 3).npu().to(torch.int64)
+            model = Model().to("npu")
+            model(tokenA, tokenB, indices, argmax)
+            self.onnx_export(model, (tokenA, tokenB, indices, argmax), onnx_model_name)
+        onnx_model_name = "model_npu_tome_merge.onnx"
+        export_onnx(onnx_model_name)
+        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name)))
+
     @SupportedDevices(['Ascend910B'])
     def test_wrapper_npu_weight_quant_batchmatmul(self):
         class Model(torch.nn.Module):
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 97e91a7c115..756688146a1 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1571,6 +1571,31 @@ class TestNpuAddRmsNorm(TestCase):
         self.assertEqual(result_x.device, npu_x1.device)
 
 
+class TestNpuTomeMerge(TestCase):
+    def test_npu_tome_merge(self):
+        with FakeTensorMode():
+            tokenA = torch.randn((4, 3072, 320)).npu().to(torch.float16)
+            tokenB = torch.randn((4, 1024, 320)).npu().to(torch.float16)
+            indices = torch.randn((4, 3072)).npu().to(torch.int64)
+            argmax = torch.randn((4, 3072)).npu().to(torch.int64)
+
+            output1 = torch.randn((4, 1024, 320)).npu().to(torch.float16)
+            output2 = torch.randn((4, 8, 1024, 320)).npu().to(torch.float16)
+            output3 = torch.randn((4, 8, 1024)).npu().to(torch.float32)
+
+            actual1, actual2, actual3 = torch_npu.npu_tome_merge(tokenA, tokenB, indices, argmax, 0.5)
+
+            self.assertEqual(actual1.dtype, output1.dtype)
+            self.assertEqual(actual1.shape, output1.shape)
+            self.assertEqual(actual1.device, output1.device)
+            self.assertEqual(actual2.dtype, output2.dtype)
+            self.assertEqual(actual2.shape, output2.shape)
+            self.assertEqual(actual2.device, output2.device)
+            self.assertEqual(actual3.dtype, output3.dtype)
+            self.assertEqual(actual3.shape, output3.shape)
+            self.assertEqual(actual3.device, output3.device)
+
+
 class TestFFN(TestCase):
     def test_npu_ffn_meta(self):
         with FakeTensorMode():
diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index 9765057b880..6b2e6e9974c 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -521,6 +521,21 @@ def npu_quantize_meta(self, scales, zero_points, dtype, axis=1):
     return torch.empty_like(self, dtype=torch.int8)
 
 
+@impl(m, "npu_tome_merge")
+def npu_tome_merge(token_a, token_b, topk_indice, arg_max, top_rate=0.5):
+    batch = token_a.size(0)
+    seq_len_a = token_a.size(1)
+    hidden_size = token_a.size(2)
+    seq_len_b = token_b.size(1)
+    topR = math.floor((seq_len_a + seq_len_b) * top_rate)
+    heads = 8
+    unmerge_token_a_dim_list = [batch, seq_len_a - topR, hidden_size]
+    unmerge_token_b_dim_list = [batch, heads, seq_len_b, hidden_size]
+    unreduce_count_dim_list = [batch, heads, seq_len_b]
+    unreduce_count = torch.empty(unreduce_count_dim_list, dtype=torch.float32, device='meta')
+    return (token_a.new_empty(tuple(unmerge_token_a_dim_list)), token_a.new_empty(tuple(unmerge_token_b_dim_list)), torch.empty_like(unreduce_count))
+
+
 @impl(m, "npu_anti_quant")
 def npu_anti_quant_meta(x, scale, *, offset=None, dst_dtype=None, src_dtype=None):
     if dst_dtype is None:
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 33494acce7f..6540ab07010 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -702,6 +702,16 @@ class NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function):
                              inner_precision_mode_i=inner_precision_mode)
 
 
+class NPUTomeMergeOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch.ops.npu.npu_tome_merge(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, token_a: Tensor, token_b: Tensor, token_indice: Tensor, arg_max: Tensor, top_rate: float = 0.5):
+        return g.op("npu::NPUTomeMerge", token_a, token_b, token_indice, arg_max, top_rate_f=top_rate, outputs=3)
+
 
 class NPUMmAllReduceBaseOP(torch.autograd.Function):
 
@@ -1048,6 +1058,9 @@ def wrapper_npu_mm_all_reduce_base(x1, x2, hcom, reduce_op, bias, antiquant_scal
                                         dequant_scale, antiquant_group_size, comm_turn)
 
 
+def wrapper_npu_tome_merge(token_a, token_b, token_indice, arg_max, top_rate=0.5):
+    return NPUTomeMergeOp.apply(token_a, token_b, token_indice, arg_max, top_rate)
+
 
 def wrapper_npu_weight_quant_batchmatmul(x, weight, antiquant_scale, antiquant_offset, quant_scale,
                                          quant_offset, bias, antiquant_group_size):
@@ -1112,6 +1125,7 @@ def add_onnx_ops():
     torch_npu.npu_prompt_flash_attention = wrapper_npu_prompt_flash_attention
     torch_npu.npu_incre_flash_attention = wrapper_npu_incre_flash_attention
     torch_npu.npu_masked_softmax_with_rel_pos_bias = wrapper_npu_masked_softmax_with_rel_pos_bias
+    torch_npu.npu_tome_merge = wrapper_npu_tome_merge
     torch_npu.npu_mm_all_reduce_base = wrapper_npu_mm_all_reduce_base
     torch_npu.npu_weight_quant_batchmatmul = wrapper_npu_weight_quant_batchmatmul
     torch_npu.npu_anti_quant = wrapper_npu_anti_quant
-- 
Gitee