diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 10ebcfe20bc4a5c9f7951146fe3aff937c895dc2..2cbe1c293394ca5b1d80429e4e337d80c7df9613 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -56,9 +56,9 @@ class Const: SIX_SEGMENT = 6 SEVEN_SEGMENT = 7 - MAX_DEPTH = 10 + MAX_DEPTH = 999 CPU_QUARTER = 4 - DUMP_MAX_DEPTH = 50 + DUMP_MAX_DEPTH = 999 # dump mode ALL = "all" diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 1d0e356d5c4b2889cd4ff2066fb8f322130d1ff1..32141f7c2885dc3d0cca52942dac24db54b371ac 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -62,7 +62,27 @@ class DataWriter: if is_new_file: change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders", max_depth=Const.DUMP_MAX_DEPTH) + @recursion_depth_decorator("JsonWriter: DataWriter._replace_crc32_placeholders") + def _replace_crc32_placeholders(self, data, crc32_results): + """ + 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32 + """ + if isinstance(data, dict): + for k, v in list(data.items()): + if k == Const.MD5_INDEX and isinstance(v, int): + idx = v + # 防越界 + crc = crc32_results[idx] if idx < len(crc32_results) else None + # 删除占位符,改成真实字段 + del data[k] + data[Const.MD5] = crc + else: + self._replace_crc32_placeholders(v, crc32_results) + elif isinstance(data, (list, tuple)): + for item in data: + self._replace_crc32_placeholders(item, crc32_results) + + @recursion_depth_decorator("JsonWriter: DataWriter._replace_stat_placeholders") def _replace_stat_placeholders(self, data, stat_result): if isinstance(data, dict): keys = list(data.keys()) # 获取当前所有键 @@ -288,21 +308,3 @@ class DataWriter: if self.cache_debug: self.write_debug_info_json(self.debug_file_path) - def _replace_crc32_placeholders(self, data, crc32_results): - """ - 遍历 JSON 结构,将所有 md5_index 占位符替换成真实的 CRC32 - """ - if isinstance(data, dict): - for k, v in list(data.items()): - if k == Const.MD5_INDEX and isinstance(v, int): - idx = v - # 防越界 - crc = crc32_results[idx] if idx < len(crc32_results) else None - # 删除占位符,改成真实字段 - del data[k] - data[Const.MD5] = crc - else: - self._replace_crc32_placeholders(v, crc32_results) - elif isinstance(data, (list, tuple)): - for item in data: - self._replace_crc32_placeholders(item, crc32_results) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py index fdfd124222f03599f914c77eb16c42c8d3578a7b..7c9a7bdf3e3a925843c00ae6415bb8d5a8dd3ff1 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_check.py @@ -76,7 +76,7 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(context.exception.code, CompareException.INVALID_CHAR_ERROR) def test_check_json_key_value_max_depth(self): - result = check_json_key_value(input_output, op_name, depth=11) + result = check_json_key_value(input_output, op_name, depth=51) self.assertEqual(result, None) def test_valid_key_value_type_shape(self): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py index 9d418ae5e4e7de74ba81216325a69ee057441236..96b650a2b9d0992f97a355aa69e5f709fe61aecd 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/compare/test_acc_compare_utils.py @@ -396,7 +396,7 @@ class TestUtilsMethods(unittest.TestCase): def test_op_item_parse_max_depth(self): with self.assertRaises(CompareException) as context: - op_item_parse(parse_item, parse_op_name, 'input', depth=11) + op_item_parse(parse_item, parse_op_name, 'input', depth=1000) self.assertEqual(context.exception.code, CompareException.RECURSION_LIMIT_ERROR) def test_get_rela_diff_summary_mode_float_or_int(self): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py index a5d83ba3830c558d68884862b9870342c61701fa..212a558b605707a2d2c12473381b8bde3420c76b 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/common/test_common_utils.py @@ -29,14 +29,6 @@ class TestUtils(unittest.TestCase): self.processor.save_tensors_in_element(api_name, tensor) file_path = os.path.join(self.save_path, f'{api_name}.0.pt') self.assertTrue(os.path.exists(file_path)) - - def test_recursion_limit_error(self): - tensor = torch.randn(10, 10) - with self.assertRaises(DumpException) as context: - self.processor._save_recursive("test_api", [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, [tensor, - [tensor, [tensor, [tensor, [tensor]]]]]]]]]]], 0) - self.assertTrue(isinstance(context.exception, DumpException)) - self.assertEqual(context.exception.code, DumpException.RECURSION_LIMIT_ERROR) def test_save_recursive_non_tensor_types(self): api_name = "test_api"