From 7ba787a64c298c341990af17a72c64d1dc2b5bc1 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 30 Oct 2024 16:31:57 +0800 Subject: [PATCH 1/4] delete process num --- .../msprobe/docs/18.online_dispatch.md | 1 - .../pytorch/online_dispatch/dispatch.py | 67 +++---------------- .../pytorch/online_dispatch/dump_compare.py | 36 ++-------- 3 files changed, 12 insertions(+), 92 deletions(-) diff --git a/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md b/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md index e686c61b68..a1cd711132 100644 --- a/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md +++ b/debug/accuracy_tools/msprobe/docs/18.online_dispatch.md @@ -70,7 +70,6 @@ PyTorch NPU在线精度比对是msprobe工具实现在PyTorch训练过程中直 | api_list | dump范围,dump_mode="list"时设置,需要Dump Aten Ir API名称,默认为None,Aten Ir API名称可以通过dir(torch.ops.aten)查看。 | 否 | | dump_path| dump文件生成的路径。 | 是 | | tag | 传入tag字符串,成为dump文件夹名一部分,默认为None。 | 否 | -| process_num | 多进程并发数,默认为0。 | 否 | | debug | debug信息打印,默认为False。 | 否 | ### dump数据存盘说明 dump数据存盘目录名格式:`atat_tag_rankid_{timestamp}`。 diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index df8899954a..3f8e5774be 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -31,8 +31,8 @@ else: from msprobe.core.common.file_utils import check_file_or_directory_path, load_yaml from msprobe.core.common.const import Const, CompareConst from msprobe.pytorch.common.log import logger -from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, \ - TimeStatistics, DispatchRunParam, DisPatchDataInfo +from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, TimeStatistics, DispatchRunParam, \ + DisPatchDataInfo from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, get_sys_info, DispatchException, \ COMPARE_LOGO from msprobe.pytorch.online_dispatch.compare import Comparator @@ -44,7 +44,7 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" class PtdbgDispatch(TorchDispatchMode): - def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None, process_num=0): + def __init__(self, dump_mode=Const.OFF, api_list=None, debug=False, dump_path=None, tag=None): super(PtdbgDispatch, self).__init__() logger.info(COMPARE_LOGO) if not is_npu: @@ -64,7 +64,6 @@ class PtdbgDispatch(TorchDispatchMode): self.device_dump_path_npu = None self.all_summary = [] self.call_stack_list = [] - self.process_num = process_num self.filter_dump_api() self.check_param() dir_name = self.get_dir_name(tag) @@ -83,13 +82,9 @@ class PtdbgDispatch(TorchDispatchMode): yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "torch_ops_config.yaml") self.get_ops(yaml_path) - self.lock = None - if process_num > 0: - self.pool = Pool(process_num) if debug: logger.info(f'Main pid:{os.getpid()} device:{self.device_id} dump_list:{self.dump_api_list} ' - f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], ' - f'process[{process_num}]') + f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}].') def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) @@ -98,25 +93,6 @@ class PtdbgDispatch(TorchDispatchMode): return logger.info(f'start write compare csv: Rank[{self.device_id}], Pid[{os.getpid()}') - if self.process_num > 0: - self.pool.close() - self.pool.join() - summary_path = os.path.join(self.root_cpu_path, f'summary.json') - if not os.path.exists(summary_path): - logger.error("Please check train log, An exception may have occurred!") - return - check_file_or_directory_path(summary_path, False) - fp_handle = FileOpen(summary_path, "r") - while True: - json_line_data = fp_handle.readline() - if json_line_data == '\n': - continue - if len(json_line_data) == 0: - break - msg = json.loads(json_line_data) - self.all_summary[msg[0]] = msg[1] - fp_handle.close() - if self.debug_flag: input_num = 0 output_num = 0 @@ -185,35 +161,11 @@ class PtdbgDispatch(TorchDispatchMode): if isinstance(cpu_out, torch.Tensor) and cpu_out.dtype in [torch.bfloat16, torch.float16, torch.half]: cpu_out = cpu_out.float() - if self.process_num == 0: - self.all_summary.append([]) - data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, func, npu_out_cpu, cpu_out, self.lock) - dispatch_workflow(run_param, data_info) - else: - self.lock.acquire() - self.all_summary.append([]) - self.lock.release() - run_param.process_flag = True - if self.check_fun(func, run_param): - data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, None, npu_out_cpu, cpu_out, - self.lock) - self.pool.apply_async(func=dispatch_multiprocess, args=(run_param, data_info), - error_callback=error_call) - else: - logger.error("can not get correct function please set process_num=0") + self.all_summary.append([]) + data_info = DisPatchDataInfo(cpu_args, cpu_kwargs, self.all_summary, func, npu_out_cpu, cpu_out) + dispatch_workflow(run_param, data_info) return npu_out - @staticmethod - def check_fun(func, run_param): - if hasattr(torch.ops.aten, run_param.aten_api): - aten_func = getattr(torch.ops.aten, run_param.aten_api) - if hasattr(aten_func, run_param.aten_api_overload_name): - aten_overload_func = getattr(aten_func, run_param.aten_api_overload_name) - if id(aten_overload_func) == id(func): - run_param.func_namespace = "aten" - return True - return False - def get_dir_name(self, tag): # guarantee file uniqueness time.sleep(1) @@ -245,7 +197,7 @@ class PtdbgDispatch(TorchDispatchMode): def get_run_param(self, aten_api, func_name, aten_api_overload_name): run_param = DispatchRunParam(self.debug_flag, self.device_id, self.root_npu_path, self.root_cpu_path, - self.process_num, self.comparator) + self.comparator) run_param.dump_flag, run_param.auto_dump_flag = self.get_dump_flag(aten_api) run_param.func_name = func_name run_param.aten_api = aten_api @@ -275,9 +227,6 @@ class PtdbgDispatch(TorchDispatchMode): if not isinstance(self.debug_flag, bool): logger.error('The type of parameter "debug" can only be bool.') raise DispatchException(DispatchException.INVALID_PARAMETER) - if not isinstance(self.process_num, int) or self.process_num < 0: - logger.error('The type of parameter "process_num" can only be int and it should not be less than 0.') - raise DispatchException(DispatchException.INVALID_PARAMETER) def enable_autograd(self, aten_api): if aten_api in self.npu_adjust_autograd: diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index edb9c40d38..865e9610bd 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -24,13 +24,12 @@ from msprobe.pytorch.common.log import logger class DispatchRunParam: - def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, process_num, comparator): + def __init__(self, debug_flag, device_id, root_npu_path, root_cpu_path, comparator): # static parameters are initialized by constructors, and dynamic parameters are constructed at run time self.debug_flag = debug_flag self.device_id = device_id self.root_npu_path = root_npu_path self.root_cpu_path = root_cpu_path - self.process_num = process_num self.process_flag = False self.func_name = None self.func_namespace = None @@ -44,14 +43,13 @@ class DispatchRunParam: class DisPatchDataInfo: - def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out, lock): + def __init__(self, cpu_args, cpu_kwargs, all_summary, func, npu_out_cpu, cpu_out): self.cpu_args = cpu_args self.cpu_kwargs = cpu_kwargs self.all_summary = all_summary self.func = func self.npu_out_cpu = npu_out_cpu self.cpu_out = cpu_out - self.lock = lock class TimeStatistics: @@ -60,7 +58,6 @@ class TimeStatistics: if self.debug: self.fun = run_param.func_name self.device = run_param.device_id - self.process = run_param.process_num self.index = run_param.single_api_index self.tag = name_tag self.timeout = timeout @@ -104,19 +101,10 @@ def dump_data(data, prefix, dump_path): save_npy(data, path) -def save_temp_summary(api_index, single_api_summary, path, lock): - summary_path = os.path.join(path, f'summary.json') - lock.acquire() - with FileOpen(summary_path, "a") as f: - json.dump([api_index, single_api_summary], f) - f.write('\n') - lock.release() - - def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): cpu_args, cpu_kwargs = data_info.cpu_args, data_info.cpu_kwargs all_summary, func = data_info.all_summary, data_info.func - npu_out_cpu, cpu_out, lock = data_info.npu_out_cpu, data_info.cpu_out, data_info.lock + npu_out_cpu, cpu_out = data_info.npu_out_cpu, data_info.cpu_out single_api_summary = [] prefix_input = f'{run_param.aten_api}_{run_param.single_api_index}_input' @@ -139,10 +127,7 @@ def dispatch_workflow(run_param: DispatchRunParam, data_info: DisPatchDataInfo): dump_data(cpu_out, prefix_output, run_param.root_cpu_path) dump_data(npu_out_cpu, prefix_output, run_param.root_npu_path) - if run_param.process_num == 0: - all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary) - else: - save_temp_summary(run_param.api_index - 1, single_api_summary, run_param.root_cpu_path, lock) + all_summary[run_param.api_index - 1] = copy.deepcopy(single_api_summary) def get_torch_func(run_param): @@ -154,16 +139,3 @@ def get_torch_func(run_param): ops_aten_overlaod_func = getattr(ops_aten_func, run_param.aten_api_overload_name) return ops_aten_overlaod_func return None - - -def dispatch_multiprocess(run_param, dispatch_data_info): - torch_func = get_torch_func(run_param) - if torch_func is None: - logger.error(f'can not find suitable call api:{run_param.aten_api}') - else: - dispatch_data_info.func = torch_func - dispatch_workflow(run_param, dispatch_data_info) - - -def error_call(err): - logger.error(f'multiprocess {err}') -- Gitee From e58eedacd65befc7f31cb4755fda86fc81d87c29 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 30 Oct 2024 17:34:33 +0800 Subject: [PATCH 2/4] fix --- .../online_dispatch/test_dump_compare.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py index dfd45d7765..5223b4d131 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py @@ -24,7 +24,7 @@ from unittest.mock import Mock, patch from unittest.mock import MagicMock from msprobe.core.common.const import CompareConst -from msprobe.pytorch.online_dispatch.dump_compare import support_basic_type, dump_data, save_temp_summary, \ +from msprobe.pytorch.online_dispatch.dump_compare import support_basic_type, dump_data, \ dispatch_workflow, get_torch_func, dispatch_multiprocess, error_call, DispatchRunParam, DisPatchDataInfo @@ -173,21 +173,6 @@ class TestDumpCompare(unittest.TestCase): # Meta tensor 不应该生成文件 self.assertFalse(os.path.exists(expected_path)) - def test_save_temp_summary(self): - api_index='1' - single_api_summary="conv2d" - path = '' - data = [] - lock=threading.Lock() - - save_temp_summary(api_index=api_index,single_api_summary=single_api_summary,path=path,lock=lock) - - with open(self.summary_path, 'r') as f: - content = f.readlines() - for line in content: - data.append(json.loads(line)) - self.assertEqual([['1','conv2d']],data) - @patch('msprobe.pytorch.online_dispatch.dump_compare.dump_data') @patch('msprobe.pytorch.online_dispatch.dump_compare.save_temp_summary') def test_dispatch_workflow_should_dump_when_flag_is_True(self,mock_save_temp_summary,mock_dump_data): -- Gitee From 23d1524895722f3632cdbfa429ba95003dea2c1a Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 31 Oct 2024 10:07:44 +0800 Subject: [PATCH 3/4] delete multiprocess --- .../online_dispatch/test_dump_compare.py | 31 ++----------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py index 5223b4d131..0eb3a084ab 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py @@ -25,7 +25,7 @@ from unittest.mock import MagicMock from msprobe.core.common.const import CompareConst from msprobe.pytorch.online_dispatch.dump_compare import support_basic_type, dump_data, \ - dispatch_workflow, get_torch_func, dispatch_multiprocess, error_call, DispatchRunParam, DisPatchDataInfo + dispatch_workflow, get_torch_func, DispatchRunParam, DisPatchDataInfo class TestDumpCompare(unittest.TestCase): @@ -174,8 +174,7 @@ class TestDumpCompare(unittest.TestCase): self.assertFalse(os.path.exists(expected_path)) @patch('msprobe.pytorch.online_dispatch.dump_compare.dump_data') - @patch('msprobe.pytorch.online_dispatch.dump_compare.save_temp_summary') - def test_dispatch_workflow_should_dump_when_flag_is_True(self,mock_save_temp_summary,mock_dump_data): + def test_dispatch_workflow_should_dump_when_flag_is_True(self,mock_dump_data): mock_run_param = Mock() mock_run_param.aten_api="aten_api" mock_run_param.single_api_index="single_api_index" @@ -191,11 +190,9 @@ class TestDumpCompare(unittest.TestCase): dispatch_workflow(mock_run_param, mock_data_info) mock_dump_data.assert_called() - mock_save_temp_summary.assert_not_called() @patch('msprobe.pytorch.online_dispatch.dump_compare.dump_data') - @patch('msprobe.pytorch.online_dispatch.dump_compare.save_temp_summary') - def test_dispatch_workflow_should_not_dump_when_flag_is_false(self,mock_save_temp_summary,mock_dump_data): + def test_dispatch_workflow_should_not_dump_when_flag_is_false(self,mock_dump_data): mock_run_param = Mock() mock_run_param.aten_api="aten_api" mock_run_param.single_api_index="single_api_index" @@ -212,7 +209,6 @@ class TestDumpCompare(unittest.TestCase): dispatch_workflow(mock_run_param, mock_data_info) mock_dump_data.assert_not_called() - mock_save_temp_summary.assert_called() def test_get_torch_func_should_return_None_when_outside_input(self): mock_run_param = Mock() @@ -221,27 +217,6 @@ class TestDumpCompare(unittest.TestCase): mock_run_param.aten_api_overload_name="new_attr3" self.assertIsNone(get_torch_func(mock_run_param)) - @patch('msprobe.core.common.log.BaseLogger.error') - def test_dispatch_multiprocess_should_logger_error_when_wrong_api_input(self,mock_error): - mock_run_param = Mock() - mock_run_param.func_namespace="new_attr1" - mock_run_param.aten_api="new_attr2" - mock_run_param.aten_api_overload_name="new_attr3" - mock_dispatch_data_info=Mock() - dispatch_multiprocess(mock_run_param,mock_dispatch_data_info) - mock_error.assert_called_once_with(f'can not find suitable call api:{mock_run_param.aten_api}') - - @patch('msprobe.pytorch.online_dispatch.dump_compare.dispatch_workflow') - def test_dispatch_multiprocess_should_workflow_when_right_api_input(self,mock_workflow): - mock_run_param = Mock() - mock_run_param.func_namespace="aten" - mock_run_param.aten_api="add" - mock_run_param.aten_api_overload_name="Scalar" - mock_dispatch_data_info=Mock() - mock_workflow.return_value=1 - dispatch_multiprocess(mock_run_param,mock_dispatch_data_info) - mock_workflow.assert_called_once_with(mock_run_param,mock_dispatch_data_info) - @patch('msprobe.core.common.log.BaseLogger.error') def test_error_call(self,mock_error): error_call("messages") -- Gitee From cca8946ec6b56941ae09eb3b7b0d8c6b1ddd3f7b Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 31 Oct 2024 11:00:12 +0800 Subject: [PATCH 4/4] delete useless param --- .../online_dispatch/test_dispatch.py | 25 ------------------- .../online_dispatch/test_dump_compare.py | 16 ++---------- 2 files changed, 2 insertions(+), 39 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dispatch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dispatch.py index cace480e7c..3655d51021 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dispatch.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dispatch.py @@ -42,26 +42,6 @@ class TestPtdbgDispatch(unittest.TestCase): args = mock_info.call_args[0] self.assertTrue(args[0].startswith('Dispatch exit')) - @patch('torch.ops.aten') - def test_check_fun_success(self, mock_aten): - run_param = RunParam('my_api', 'my_overload') - mock_func = Mock() - mock_aten.my_api = Mock() - mock_aten.my_api.my_overload = mock_func - res = self.PtdbgDispatch.check_fun(mock_func, run_param) - - self.assertTrue(res) - self.assertEqual(run_param.func_namespace, 'aten') - - @patch('torch.ops.aten') - def test_check_fun_failure(self, mock_aten): - run_param = RunParam('invalid_api', 'invalid_overload') - - res = self.PtdbgDispatch.check_fun(None, run_param) - - self.assertFalse(res) - self.assertIsNone(run_param.func_namespace) - def test_get_dir_name(self): res = self.PtdbgDispatch.get_dir_name('my_tag') @@ -99,11 +79,6 @@ class TestPtdbgDispatch(unittest.TestCase): self.PtdbgDispatch.debug_flag = 'awfef' self.PtdbgDispatch.check_param() - def test_check_param_process_num(self): - with self.assertRaises(DispatchException): - self.PtdbgDispatch.process_num = 'awfef' - self.PtdbgDispatch.check_param() - @patch('torch._C._dispatch_tls_set_dispatch_key_excluded') def test_enable_autograd(self, mock__dispatch_tls_set_dispatch_key_excluded): self.PtdbgDispatch.npu_adjust_autograd.append('to') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py index 0eb3a084ab..fa6e19e340 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/online_dispatch/test_dump_compare.py @@ -82,18 +82,15 @@ class TestDumpCompare(unittest.TestCase): device_id = 0 root_npu_path = '/path/to/npu' root_cpu_path = '/path/to/cpu' - process_num = 4 comparator = 'comparator_function' - dispatch_run_param = DispatchRunParam(debug_flag, device_id, root_npu_path, root_cpu_path, process_num, - comparator) + dispatch_run_param = DispatchRunParam(debug_flag, device_id, root_npu_path, root_cpu_path, comparator) # 验证静态参数是否正确初始化 self.assertEqual(dispatch_run_param.debug_flag, debug_flag) self.assertEqual(dispatch_run_param.device_id, device_id) self.assertEqual(dispatch_run_param.root_npu_path, root_npu_path) self.assertEqual(dispatch_run_param.root_cpu_path, root_cpu_path) - self.assertEqual(dispatch_run_param.process_num, process_num) self.assertEqual(dispatch_run_param.comparator, comparator) # 验证动态参数是否有默认值 @@ -109,14 +106,13 @@ class TestDumpCompare(unittest.TestCase): def test_DisPatchDataInfo(self): mock_func = MagicMock() - mock_lock = MagicMock() cpu_args = (1, 2, 3) cpu_kwargs = {'arg1': 1, 'arg2': 2} all_summary = ['summary1', 'summary2'] npu_out_cpu = [10, 20] cpu_out = [30, 40] - dispatch_data = DisPatchDataInfo(cpu_args, cpu_kwargs, all_summary, mock_func, npu_out_cpu, cpu_out, mock_lock) + dispatch_data = DisPatchDataInfo(cpu_args, cpu_kwargs, all_summary, mock_func, npu_out_cpu, cpu_out) self.assertEqual(dispatch_data.cpu_args, cpu_args) self.assertEqual(dispatch_data.cpu_kwargs, cpu_kwargs) @@ -124,7 +120,6 @@ class TestDumpCompare(unittest.TestCase): self.assertEqual(dispatch_data.func, mock_func) self.assertEqual(dispatch_data.npu_out_cpu, npu_out_cpu) self.assertEqual(dispatch_data.cpu_out, cpu_out) - self.assertEqual(dispatch_data.lock, mock_lock) def test_support_basic_type_should_return_true_when_is_instance(self): self.assertTrue(support_basic_type(2.3)) @@ -184,7 +179,6 @@ class TestDumpCompare(unittest.TestCase): mock_data_info.cpu_kwargs=[] mock_run_param.dump_flag=True - mock_run_param.process_num = 0 mock_run_param.api_index = 1 mock_data_info.all_summary=[1] @@ -203,7 +197,6 @@ class TestDumpCompare(unittest.TestCase): mock_run_param.dump_flag=False mock_run_param.auto_dump_flag=False - mock_run_param.process_num = 1 mock_run_param.api_index = 1 mock_data_info.all_summary=[1] @@ -216,8 +209,3 @@ class TestDumpCompare(unittest.TestCase): mock_run_param.aten_api="new_attr2" mock_run_param.aten_api_overload_name="new_attr3" self.assertIsNone(get_torch_func(mock_run_param)) - - @patch('msprobe.core.common.log.BaseLogger.error') - def test_error_call(self,mock_error): - error_call("messages") - mock_error.assert_called_once_with("multiprocess messages") -- Gitee