diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 2c3086184c3609c0b9042cc54e7bd33eec6db6ed..49fd004c06094a5011caf0bcedec03aa4b971915 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -54,12 +54,18 @@ class BaseAPIInfo: else: api_args = self.api_name + '.' + str(self.args_num) + rank = self.get_tensor_rank(arg) + if rank is not None: + rank = "rank" + str(rank) if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, self.forward_path) - + forward_real_data_path = os.path.join(self.save_path, self.forward_path, rank) if rank else os.path.join(self.save_path, self.forward_path) + if not os.path.exists(forward_real_data_path): + os.makedirs(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, self.backward_path) + backward_real_data_path = os.path.join(self.save_path, self.backward_path, rank) if rank else os.path.join(self.save_path, self.backward_path) + if not os.path.exists(backward_real_data_path): + os.makedirs(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) @@ -68,6 +74,24 @@ class BaseAPIInfo: single_arg.update({'requires_grad': arg.requires_grad}) return single_arg + def get_tensor_rank(self, arg): + def get_tensor_rank_single(x): + if isinstance(x, (list, tuple)): + if len(x) > 0: + return get_tensor_rank_single(x[0]) + return None + elif isinstance(x, torch.Tensor): + device = x.device + if device.type == 'cpu': + return None + else: + return device.index + return None + rank = get_tensor_rank_single(arg) + if rank is None: + return None + return rank + def analyze_builtin(self, arg): single_arg = {} if self.is_save_data: diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py index ed764987d5a9287293f183c0bde1d86afd90ccae..a68057dfb41ca38ba79e1daa992a8f51ce4d64e4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py @@ -17,5 +17,5 @@ class TestConfig(unittest.TestCase): def test_update_config(self): - self.config.update_config(dump_path='/new/path/to/dump', enable_dataloader=False) + self.config.update_config(dump_path='/new/path/to/dump') self.assertEqual(self.config.dump_path, '/new/path/to/dump') diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py index addba38e38446b177942a104b4194efe910b1f7c..b892a6077a3c26ae27343734aca8012e21d3fc2c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py @@ -10,12 +10,12 @@ class TestDumpScope(unittest.TestCase): wrapped_func = iter_tracer(dummy_func) result = wrapped_func() - self.assertEqual(DumpUtil.dump_switch, "ON") + self.assertEqual(DumpUtil.dump_switch, "OFF") self.assertEqual(result, "Hello, World!") def another_dummy_func(): return 123 wrapped_func = iter_tracer(another_dummy_func) result = wrapped_func() - self.assertEqual(DumpUtil.dump_switch, "ON") + self.assertEqual(DumpUtil.dump_switch, "OFF") self.assertEqual(result, 123) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index b4148f65ff6dd2e6cfcde803ba78a32569e3801d..35f5ad4a905f9d1ce7b055d1723fd4babcc3d9fd 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -196,7 +196,7 @@ def generate_dump_path_str(): return dump_path -def set_dump_switch(switch, mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.OFF, dump_mode=[Const.ALL], summary_only=False): +def set_dump_switch(switch, mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.ON, dump_mode=[Const.ALL], summary_only=False): check_switch_valid(switch) if not DumpUtil.dump_path: set_dump_path() @@ -210,7 +210,7 @@ def set_dump_switch(switch, mode=Const.ALL, scope=[], api_list=[], filter_switch set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode,summary_only=summary_only) -def set_dump_switch_config(mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.OFF, dump_mode=[Const.ALL], summary_only=False): +def set_dump_switch_config(mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.ON, dump_mode=[Const.ALL], summary_only=False): try: check_mode_valid(mode, scope, api_list) check_switch_valid(filter_switch) diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py index 82f3d8dfedd7ffbd57145d57ee272c8b31024800..7874d3c2fa947dedcdd28c5f463ab119e3fb712c 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py @@ -17,7 +17,7 @@ class TestUtilsMethods(unittest.TestCase): self.assertTrue(dump_util.dump_init_enable) self.assertEqual(dump_util.dump_switch_scope, []) self.assertEqual(dump_util.dump_api_list, []) - self.assertEqual(dump_util.dump_filter_switch, "OFF") + self.assertEqual(dump_util.dump_filter_switch, switch_on) self.assertEqual(dump_count, 0) def test_set_dump_switch_mode_is_list(self):