diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 204377a56513f6b31310747e53471e10c5b30451..4aca49b926d7be26e56547f125297c334123af17 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -407,3 +407,10 @@ def safe_get_value(container, index, container_name, key=None): f"{container_name} is {container}" logger.error(err_msg) raise MsprobeBaseException(MsprobeBaseException.INVALID_KEY_ERROR) from e + except TypeError as e: + err_msg = "wrong type, please check!\n" \ + f"{container_name} is {container}\n" \ + f"index is {index}\n" \ + f"key is {key}" + logger.error(err_msg) + raise MsprobeBaseException(MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) from e diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index b23840c3c8472a005f413e113f6eaaab155e310d..4a2b221ce51cbd7ab13fbebfb75126fa27079777 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -16,6 +16,7 @@ import os import re import math +from dataclasses import dataclass import numpy as np from msprobe.core.common.const import Const, CompareConst from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger @@ -477,3 +478,42 @@ def _compare_parser(parser): help=" The data mapping file path.", required=False) parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True, help=" The layer mapping file path.", required=False) + + +@dataclass +class ApiItemInfo: + name: str + struct: tuple + stack_info: list + + +def stack_column_process(result_item, has_stack, index, key, npu_stack_info): + if has_stack and index == 0 and key == CompareConst.INPUT_STRUCT: + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + return result_item + + +def result_item_init(n_info, b_info, dump_mode): + n_len = len(n_info.struct) + b_len = len(b_info.struct) + struct_long_enough = (n_len > 2 and b_len > 2) if dump_mode == Const.MD5 else (n_len > 1 and b_len > 1) + if struct_long_enough: + result_item = [ + n_info.name, b_info.name, n_info.struct[0], b_info.struct[0], n_info.struct[1], b_info.struct[1] + ] + if dump_mode == Const.MD5: + md5_compare_result = CompareConst.PASS if n_info.struct[2] == b_info.struct[2] else CompareConst.DIFF + result_item.extend([n_info.struct[2], b_info.struct[2], md5_compare_result]) + elif dump_mode == Const.SUMMARY: + result_item.extend([" "] * 8) + else: + result_item.extend([" "] * 5) + else: + err_msg = "index out of bounds error will occur in result_item_init, please check!\n" \ + f"npu_info_struct is {n_info.struct}\n" \ + f"bench_info_struct is {b_info.struct}" + logger.error(err_msg) + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) + return result_item diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index c0235840db4dbddcfacb4ee79b31c573c9bce823..e16f7138779240f88fd76553fc2a2476ea146cae 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -19,6 +19,7 @@ import os import tempfile from unittest import TestCase from unittest.mock import patch, MagicMock, mock_open +import numpy as np from msprobe.core.common.const import Const from msprobe.core.common.file_utils import (FileCheckConst, @@ -41,7 +42,10 @@ from msprobe.core.common.utils import (CompareException, get_dump_mode, get_real_step_or_rank, get_step_or_rank_from_string, - get_stack_construct_by_dump_json_path) + get_stack_construct_by_dump_json_path, + safe_get_value, + MsprobeBaseException) + class TestUtils(TestCase): @@ -312,3 +316,50 @@ class TestUtils(TestCase): self.assertEqual(stack, {'stack_key': 'stack_value'}) self.assertEqual(construct, {'construct_key': 'construct_value'}) + + def test_safe_get_value_dict_valid_key_index(self): + # Test valid key and index in a dictionary + dict_container = {'a': [1, 2, 3], 'b': [4, 5, 6]} + self.assertEqual(safe_get_value(dict_container, 1, 'dict_container', key='a'), 2) + + def test_safe_get_value_invalid_key(self): + # Test invalid key in dictionary + dict_container = {'a': [1, 2, 3], 'b': [4, 5, 6]} + with self.assertRaises(MsprobeBaseException) as context: + safe_get_value(dict_container, 1, 'dict_container', key='invalid_key') + self.assertEqual(context.exception.code, MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) + + def test_safe_get_value_valid_key_invalid_index(self): + # Test invalid index in dictionary[key] + dict_container = {'a': [1, 2, 3], 'b': [4, 5, 6]} + with self.assertRaises(MsprobeBaseException) as context: + safe_get_value(dict_container, 5, 'dict_container', key='a') + self.assertEqual(context.exception.code, MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) + + def test_safe_get_value_list_valid_index(self): + # Test valid index in a list + list_container = [10, 20, 30] + self.assertEqual(safe_get_value(list_container, 1, 'list_container'), 20) + + def test_safe_get_value_list_index_out_of_bounds(self): + # Test index out of bounds in a list + list_container = [10, 20, 30] + with self.assertRaises(MsprobeBaseException) as context: + safe_get_value(list_container, 10, 'list_container') + self.assertEqual(context.exception.code, MsprobeBaseException.INDEX_OUT_OF_BOUNDS_ERROR) + + def test_safe_get_value_tuple_valid_index(self): + # Test valid index in a tuple + tuple_container = (100, 200, 300) + self.assertEqual(safe_get_value(tuple_container, 2, 'tuple_container'), 300) + + def test_safe_get_value_array_valid_index(self): + # Test valid index in a numpy array + array_container = np.array([1000, 2000, 3000]) + self.assertEqual(safe_get_value(array_container, 0, 'array_container'), 1000) + + def test_safe_get_value_unsupported_container_type(self): + # Test unsupported container type (e.g., a string) + with self.assertRaises(MsprobeBaseException) as context: + safe_get_value("unsupported_type", 0, 'string_container') + self.assertEqual(context.exception.code, MsprobeBaseException.INVALID_OBJECT_TYPE_ERROR) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index 3150ee14f5dedb45456a9ca1b38cdcab88862fe4..145fec43f2cb24b88c418b8a918ff6f784dacb30 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -6,9 +6,10 @@ import unittest import argparse from msprobe.core.compare.utils import extract_json, rename_api, read_op, op_item_parse, \ check_and_return_dir_contents, resolve_api_special_parameters, get_rela_diff_summary_mode, \ - get_accuracy, get_un_match_accuracy, merge_tensor, _compare_parser + get_accuracy, get_un_match_accuracy, merge_tensor, _compare_parser, stack_column_process, result_item_init, \ + ApiItemInfo from msprobe.core.common.utils import CompareException -from msprobe.core.common.const import Const +from msprobe.core.common.const import Const, CompareConst # test_read_op_1 @@ -367,3 +368,71 @@ class TestUtilsMethods(unittest.TestCase): self.assertIsNone(args.api_mapping) # 默认值应为 None self.assertEqual(args.data_mapping, "data_mapping.txt") self.assertEqual(args.layer_mapping, "layer_mapping.txt") + + def test_stack_column_process_stack_info(self): + result_item = [] + has_stack = True + index = 0 + key = CompareConst.INPUT_STRUCT + npu_stack_info = ['abc'] + result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) + self.assertEqual(result_item, ['abc']) + + def test_stack_column_process_None(self): + result_item = [] + has_stack = True + index = 1 + key = CompareConst.INPUT_STRUCT + npu_stack_info = ['abc'] + result_item = stack_column_process(result_item, has_stack, index, key, npu_stack_info) + self.assertEqual(result_item, ['None']) + + def test_result_item_init_all_and_summary(self): + n_name = 'Tensor.add.0.forward.input.0' + n_struct = ('torch.float32', [96]) + npu_stack_info = ['abc'] + b_name = 'Tensor.add.0.forward.input.0' + b_struct = ('torch.float32', [96]) + bench_stack_info = ['abc'] + n_info = ApiItemInfo(n_name, n_struct, npu_stack_info) + b_info = ApiItemInfo(b_name, b_struct, bench_stack_info) + + dump_mode = Const.ALL + result_item = result_item_init(n_info, b_info, dump_mode) + self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', + 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ']) + + dump_mode = Const.SUMMARY + result_item = result_item_init(n_info, b_info, dump_mode) + self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', + 'torch.float32', 'torch.float32', [96], [96], ' ', ' ', ' ', ' ', ' ', ' ', ' ', ' ']) + + def test_result_item_init_md5(self): + n_name = 'Tensor.add.0.forward.input.0' + n_struct = ('torch.float32', [96], 'e87000dc') + npu_stack_info = ['abc'] + b_name = 'Tensor.add.0.forward.input.0' + b_struct = ('torch.float32', [96], 'e87000dc') + bench_stack_info = ['abc'] + n_info = ApiItemInfo(n_name, n_struct, npu_stack_info) + b_info = ApiItemInfo(b_name, b_struct, bench_stack_info) + + dump_mode = Const.MD5 + result_item = result_item_init(n_info, b_info, dump_mode) + self.assertEqual(result_item, ['Tensor.add.0.forward.input.0', 'Tensor.add.0.forward.input.0', + 'torch.float32', 'torch.float32', [96], [96], 'e87000dc', 'e87000dc', 'pass']) + + def test_result_item_init_md5_index_error(self): + n_name = 'Tensor.add.0.forward.input.0' + n_struct = ('torch.float32', [96]) + npu_stack_info = ['abc'] + b_name = 'Tensor.add.0.forward.input.0' + b_struct = ('torch.float32', [96]) + bench_stack_info = ['abc'] + n_info = ApiItemInfo(n_name, n_struct, npu_stack_info) + b_info = ApiItemInfo(b_name, b_struct, bench_stack_info) + + dump_mode = Const.MD5 + with self.assertRaises(CompareException) as context: + result_item = result_item_init(n_info, b_info, dump_mode) + self.assertEqual(context.exception.code, CompareException.INDEX_OUT_OF_BOUNDS_ERROR)