From 4119ec930b0ee4c0cd5f9db5cd04e9374dc98ea3 Mon Sep 17 00:00:00 2001 From: yuhaiyan Date: Fri, 29 Aug 2025 14:37:05 +0800 Subject: [PATCH] Add testcases for driver version and fixed the test_fake_tensor.py --- test/npu/test_cann_version.py | 23 +++++++++++++++++++---- test/test_fake_tensor.py | 8 ++++---- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/test/npu/test_cann_version.py b/test/npu/test_cann_version.py index 3663bc78683..f96a88dd88f 100644 --- a/test/npu/test_cann_version.py +++ b/test/npu/test_cann_version.py @@ -12,14 +12,29 @@ class TestCANNversion(TestCase): version_env = get_cann_version_from_env() version = get_cann_version(module="CANN") if not version_env.startswith("CANN") and version_env >= "8.1.RC1": - is_match = (re.match("([0-9]+).([0-9]+).RC([0-9]+)", version) - or re.match("([0-9]+).([0-9]+).([0-9]+)", version) - or re.match("([0-9]+).([0-9]+).T([0-9]+)", version) - or re.match("([0-9]+).([0-9]+).RC([0-9]+).alpha([0-9]+)", version)) + is_match = (re.match("([0-9]+).([0-9]+).RC([0-9]+)$", version) + or re.match("([0-9]+).([0-9]+).([0-9]+)$", version) + or re.match("([0-9]+).([0-9]+).T([0-9]+)$", version) + or re.match("([0-9]+).([0-9]+).RC([0-9]+).alpha([0-9]+)$", version)) self.assertTrue(is_match, f"The env version is {version_env}. The format of cann version {version} is invalid.") version = get_cann_version(module="CAN") self.assertTrue(version == "", "When module is invalid, the result of get_cann_version is not right.") + + def test_get_driver_version(self): + version = get_cann_version(module="CANN") + if re.match("([0-9]+).([0-9]+).RC([0-9]+).B([0-9]+)$", version, re.IGNORECASE): + version = re.sub(".B([0-9]+)", "", version, flags=re.IGNORECASE) + if version >= "25.": + is_match = (re.match("([0-9]+).([0-9]+).RC([0-9]+)$", version, re.IGNORECASE) + or re.match("([0-9]+).([0-9]+).([0-9]+)$", version) + or re.match("([0-9]+).([0-9]+).RC([0-9]+).([0-9]+)$", version, re.IGNORECASE) + or re.match("([0-9]+).([0-9]+).([0-9]+).([0-9]+)$", version) + or re.match("([0-9]+).([0-9]+).T([0-9]+)$", version, re.IGNORECASE) + or re.match("([0-9]+).([0-9]+).RC([0-9]+).beta([0-9]+)$", version, re.IGNORECASE) + or re.match("([0-9]+).([0-9]+).RC([0-9]+).alpha([0-9]+)$", version, re.IGNORECASE) + ) + self.assertTrue(is_match, f"The format of driver version {version} is invalid.") def test_compare_cann_version(self): version_env = get_cann_version_from_env() diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 25ffbfdb385..673ed226f6b 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -1728,7 +1728,7 @@ class TestGroupedMatmul(TestCase): group_list = None split_item = 0 - res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item) + res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=-1) self.assertTrue(x[0].shape[0] == res[0].shape[0]) self.assertTrue(x[1].shape[0] == res[1].shape[0]) self.assertTrue(x[2].shape[0] == res[2].shape[0]) @@ -1752,7 +1752,7 @@ class TestGroupedMatmul(TestCase): group_list = [256, 1280, 1792] split_item = 1 - res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item) + res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=-1) self.assertTrue(group_list[0] == res[0].shape[0]) self.assertTrue(group_list[1] - group_list[0] == res[1].shape[0]) self.assertTrue(group_list[2] - group_list[1] == res[2].shape[0]) @@ -1778,7 +1778,7 @@ class TestGroupedMatmul(TestCase): group_list = None split_item = 2 - res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item) + res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=0) dim0 = 0 for xi in x: dim0 += xi.shape[0] @@ -1801,7 +1801,7 @@ class TestGroupedMatmul(TestCase): group_list = [256, 1280, 1792] split_item = 3 - res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item) + res = torch_npu.npu_grouped_matmul(x, w, bias=b, group_list=group_list, split_item=split_item, group_type=0) self.assertTrue(x[0].shape[0] == res[0].shape[0]) self.assertTrue(w[0].shape[1] == res[0].shape[1]) -- Gitee