diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 5a5c779ff8232f7d4a5eb5e08a6b1a0044f01305..ed70da506aaba668337e1b7c3c5c8eaf82fe0c72 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -54,9 +54,9 @@ class Const: SIX_SEGMENT = 6 SEVEN_SEGMENT = 7 - MAX_DEPTH = 10 + MAX_DEPTH = 400 CPU_QUARTER = 4 - DUMP_MAX_DEPTH = 50 + DUMP_MAX_DEPTH = 400 EXTERN_INPUT_LIST_MAX_LEN = 100 MAX_PROCESS_NUM = 128 diff --git a/debug/accuracy_tools/msprobe/core/common/log.py b/debug/accuracy_tools/msprobe/core/common/log.py index f20d25d991ef2d3da1307336e4aa05ec3bc87d86..4ce19e4961c62cad736ad90072de12cb44b4fb95 100644 --- a/debug/accuracy_tools/msprobe/core/common/log.py +++ b/debug/accuracy_tools/msprobe/core/common/log.py @@ -89,6 +89,13 @@ class BaseLogger: self.error(msg) raise exception + def warning_log_with_exp(self, msg, exception): + """ + 打印警告日志并抛出指定异常 + """ + self.warning(msg) + raise exception + def _print_log(self, level, msg, end='\n'): current_rank = self.get_rank() current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 740b7452f247637aab528762f0b870b021d12b93..79c6202bca61abe204f847aa4c05b0cbf2be1f16 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -65,6 +65,26 @@ class DataWriter: if is_new_file: change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + @recursion_depth_decorator("JsonWriter: DataWriter._replace_crc32_placeholders") + def _replace_crc32_placeholders(self, data, crc32_results): + """ + 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32 + """ + if isinstance(data, dict): + for k, v in list(data.items()): + if k == Const.MD5_INDEX and isinstance(v, int): + idx = v + # 防越界 + crc = crc32_results[idx] if idx < len(crc32_results) else None + # 删除占位符,改成真实字段 + del data[k] + data[Const.MD5] = crc + else: + self._replace_crc32_placeholders(v, crc32_results) + elif isinstance(data, (list, tuple)): + for item in data: + self._replace_crc32_placeholders(item, crc32_results) + @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders") def _replace_stat_placeholders(self, data, stat_result): if isinstance(data, dict): @@ -324,21 +344,3 @@ class DataWriter: if self.cache_debug: self.write_debug_info_json(self.debug_file_path) - def _replace_crc32_placeholders(self, data, crc32_results): - """ - 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32 - """ - if isinstance(data, dict): - for k, v in list(data.items()): - if k == Const.MD5_INDEX and isinstance(v, int): - idx = v - # 防越界 - crc = crc32_results[idx] if idx < len(crc32_results) else None - # 删除占位符,改成真实字段 - del data[k] - data[Const.MD5] = crc - else: - self._replace_crc32_placeholders(v, crc32_results) - elif isinstance(data, (list, tuple)): - for item in data: - self._replace_crc32_placeholders(item, crc32_results) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py index 8ec8125f75b3669f626ec97c751c69653c5f638d..8ebe4e47b841f7fb115dd12ecd093874f14f9d53 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/api_runner.py @@ -153,17 +153,17 @@ class ApiRunner: api_name_list = api_name_str.split(Const.SEP) if len(api_name_list) != 3: err_msg = f"ApiRunner.get_info_from_name failed: api_name_str: {api_name_str} is not in defined format" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) api_type_str, api_sub_name = api_name_list[0], api_name_list[1] if api_type_str not in [MsCompareConst.MINT, MsCompareConst.MINT_FUNCTIONAL, MsCompareConst.TENSOR_API, MsCompareConst.FUNCTIONAL_API] \ and api_platform == Const.MS_FRAMEWORK: err_msg = f"ApiRunner.get_info_from_name failed: not mint, mint.nn.functional or Tensor api" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) if api_type_str not in MsCompareConst.MT_VALID_API_TYPES and api_platform == Const.MT_FRAMEWORK: err_msg = f"ApiRunner.get_info_from_name failed: not torch, functional or Tensor api" - logger.error_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) + logger.warning_log_with_exp(err_msg, ApiAccuracyCheckerException(ApiAccuracyCheckerException.WrongValue)) return api_type_str, api_sub_name @staticmethod diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index 7a62d9f02571b5130681f181ab8f02c065f79ddc..c9214d762fc8341843eb646b92b34402b1faad0f 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -508,12 +508,13 @@ class TestIsSaveVariableValid(unittest.TestCase): def setUp(self): self.valid_special_types = (int, float, str, bool) + @patch.object(Const, "DUMP_MAX_DEPTH", 5) def test_is_save_variable_valid_DepthExceeded_ReturnsFalse(self): - # 创建一个深度超过 Const.DUMP_MAX_DEPTH 的嵌套结构 - nested_structure = [0] * Const.DUMP_MAX_DEPTH - for _ in range(Const.DUMP_MAX_DEPTH): - nested_structure = [nested_structure] - self.assertFalse(is_save_variable_valid(nested_structure, self.valid_special_types)) + # 构造深度 = 阈值 + 1 + nested = [0] * 3 + for _ in range(Const.DUMP_MAX_DEPTH + 1): # 注意 +1,确保“超过”阈值 + nested = [nested] + self.assertFalse(is_save_variable_valid(nested, self.valid_special_types)) def test_is_save_variable_valid_ValidSpecialTypes_ReturnsTrue(self): for valid_type in self.valid_special_types: diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py index 1a0a33f799724ffefe73bf8f024e0146b2925464..60c72bcb68aca1d7fd4583bed83acb52aa55f0d6 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py @@ -75,7 +75,7 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(context.exception.code, CompareException.INVALID_CHAR_ERROR) def test_check_json_key_value_max_depth(self): - result = check_json_key_value(input_output, op_name, depth=11) + result = check_json_key_value(input_output, op_name, depth=401) self.assertEqual(result, None) def test_valid_key_value_type_shape(self): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index 9d418ae5e4e7de74ba81216325a69ee057441236..95002983adebb009624bf8bd68c0facee2a5ac71 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -396,7 +396,7 @@ class TestUtilsMethods(unittest.TestCase): def test_op_item_parse_max_depth(self): with self.assertRaises(CompareException) as context: - op_item_parse(parse_item, parse_op_name, 'input', depth=11) + op_item_parse(parse_item, parse_op_name, 'input', depth=401) self.assertEqual(context.exception.code, CompareException.RECURSION_LIMIT_ERROR) def test_get_rela_diff_summary_mode_float_or_int(self): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py index a5d83ba3830c558d68884862b9870342c61701fa..8d3a922421e30a09a07057b94fc81e4d8cfe2e9c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py @@ -29,12 +29,17 @@ class TestUtils(unittest.TestCase): self.processor.save_tensors_in_element(api_name, tensor) file_path = os.path.join(self.save_path, f'{api_name}.0.pt') self.assertTrue(os.path.exists(file_path)) - + + @patch.object(Const, "MAX_DEPTH", 50) def test_recursion_limit_error(self): tensor = torch.randn(10, 10) with self.assertRaises(DumpException) as context: - self.processor._save_recursive("test_api", [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, - [tensor, [tensor, [tensor, [tensor]]]]]]]]]]], 0) + self.processor._save_recursive("test_api", [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, + [tensor, [tensor, [tensor, [tensor + ]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]], 0) self.assertTrue(isinstance(context.exception, DumpException)) self.assertEqual(context.exception.code, DumpException.RECURSION_LIMIT_ERROR)