From ac52482e038dce667e5edbf7d6c053b0a9c2953f Mon Sep 17 00:00:00 2001 From: XinfangZhang Date: Fri, 19 Apr 2024 10:53:12 +0800 Subject: [PATCH 1/2] Add meta infershape for gmm_tensor --- test/test_fake_tensor.py | 53 ++++++++++++++++++++++++++++ torch_npu/meta/meta_registrations.py | 13 +++++++ 2 files changed, 66 insertions(+) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 3d8ad3d2db..6de21c790f 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1749,6 +1749,59 @@ class TestGroupedMatmul(TestCase): self.assertTrue(w[0].shape[1] == res[0].shape[1]) +class TestGMMTensor(TestCase): + def test_npu_gmm_tensor_meta(self): + with FakeTensorMode(): + torch.manual_seed(0) + x = torch.randn(6, 256, dtype=torch.float16).npu() + w = torch.randn(3, 256, 128, dtype=torch.float16).npu() + b = torch.randn(3, 128, dtype=torch.float16).npu() + group_list = torch.zeros(3) + group_list[0] = 1 + group_list[1] = 3 + group_list[2] = 6 + group_type = 0 + + res = torch_npu.npu_gmm_tensor(x, w, group_list.npu(), bias=b, group_type=group_type) + self.assertTrue(x.shape[0] == res.shape[0]) + self.assertTrue(w.shape[2] == res.shape[1]) + self.assertTrue(b.shape[1] == res.shape[1]) + + def test_npu_gmm_tensor_meta_x_transposed(self): + with FakeTensorMode(): + torch.manual_seed(0) + x = torch.randn(256, 6, dtype=torch.float16).npu() + w = torch.randn(3, 256, 128, dtype=torch.float16).npu() + b = torch.randn(3, 128, dtype=torch.float16).npu() + group_list = torch.zeros(3) + group_list[0] = 1 + group_list[1] = 3 + group_list[2] = 6 + group_type = 0 + + res = torch_npu.npu_gmm_tensor(x, w, group_list.npu(), bias=b, group_type=group_type) + self.assertTrue(x.shape[1] == res.shape[0]) + self.assertTrue(w.shape[2] == res.shape[1]) + self.assertTrue(b.shape[1] == res.shape[1]) + + def test_npu_gmm_tensor_meta_weight_transposed(self): + with FakeTensorMode(): + torch.manual_seed(0) + x = torch.randn(6, 256, dtype=torch.float16).npu() + w = torch.randn(3, 128, 256, dtype=torch.float16).npu() + b = torch.randn(3, 128, dtype=torch.float16).npu() + group_list = torch.zeros(3) + group_list[0] = 1 + group_list[1] = 3 + group_list[2] = 6 + group_type = 0 + + res = torch_npu.npu_gmm_tensor(x, w, group_list.npu(), bias=b, group_type=group_type) + self.assertTrue(x.shape[0] == res.shape[0]) + self.assertTrue(w.shape[1] == res.shape[1]) + self.assertTrue(b.shape[1] == res.shape[1]) + + class TestQuantMatmul(TestCase): def test_npu_quant_matmul_meta(self): with FakeTensorMode(): diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index e43e65035c..62e0643826 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -298,6 +298,19 @@ def npu_grouped_matmul_meta(x, weight, *, bias, scale=None, offset=None, antiqua return y +@impl(m, "npu_gmm_tensor") +def npu_gmm_tensor_meta(x, weight, group_list, *, bias=None, group_type=-1): + BM = x.shape[0] + N = weight.shape[2] + if x.shape[0] == weight.shape[1]: + BM = x.shape[1] + elif x.shape[1] == weight.shape[2]: + N = weight.shape[1] + + y = x.new_empty((BM, N), dtype=x.dtype) + return y + + @impl(m, "npu_scatter_nd_update") def scatter_nd_update_meta(self, indices, updates): return torch.empty_like(self, dtype=self.dtype) -- Gitee From a9aa8f390226cefb936a47bde8bb2151ac4eefbc Mon Sep 17 00:00:00 2001 From: XinfangZhang Date: Sat, 20 Apr 2024 19:04:24 +0800 Subject: [PATCH 2/2] Correct meta infershape of npu_gmm_tensor --- torch_npu/meta/meta_registrations.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 62e0643826..ff4bb0cbae 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -301,11 +301,7 @@ def npu_grouped_matmul_meta(x, weight, *, bias, scale=None, offset=None, antiqua @impl(m, "npu_gmm_tensor") def npu_gmm_tensor_meta(x, weight, group_list, *, bias=None, group_type=-1): BM = x.shape[0] - N = weight.shape[2] - if x.shape[0] == weight.shape[1]: - BM = x.shape[1] - elif x.shape[1] == weight.shape[2]: - N = weight.shape[1] + N = weight.shape[-1] y = x.new_empty((BM, N), dtype=x.dtype) return y -- Gitee