diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
index 4ac42ac81e1a60158d8fb0beb1d2f951851c614c..763a4505b2ddfa466fb1b3b1cd40b1c3bd799805 100644
--- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -87,10 +87,6 @@ def signal_handler(signum, frame):
     raise KeyboardInterrupt()
 
 
-signal.signal(signal.SIGINT, signal_handler)
-signal.signal(signal.SIGTERM, signal_handler)
-
-
 ParallelUTConfig = namedtuple('ParallelUTConfig', ['api_files', 'out_path', 'num_splits',
                                                    'save_error_data_flag', 'jit_compile_flag', 'device_id',
                                                    'result_csv_path', 'total_items', 'config_path'])
@@ -217,6 +213,8 @@ def prepare_config(args):
 
 
 def main():
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
     parser = argparse.ArgumentParser(description='Run UT in parallel')
     _run_ut_parser(parser)
     parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8,
diff --git a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py
index 7dfd9564ebc21327f3e7e29be90da7f78c3b0393..f9ca5592aaa153bc0446443548c3e18329784a18 100644
--- a/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py
+++ b/debug/accuracy_tools/msprobe/test/visualization_ut/test_graph_service.py
@@ -7,7 +7,7 @@ import argparse
 from dataclasses import dataclass
 from unittest.mock import patch
 
-from msprobe.visualization.graph_service import _compare_graph, _build_graph, _compare_graph_ranks, \
+from msprobe.visualization.graph_service import _compare_graph_result, _build_graph_result, _compare_graph_ranks, \
     _compare_graph_steps, _build_graph_ranks, _build_graph_steps, _graph_service_command, _graph_service_parser
 from msprobe.core.common.utils import CompareException
 
@@ -45,30 +45,31 @@ class TestGraphService(unittest.TestCase):
         last_call_args = mock_log_info.call_args[0][0]
         self.assertIn(log_info, last_call_args)
         matches = re.findall(self.pattern, last_call_args)
-        self.assertTrue(os.path.exists(os.path.join(self.output, matches[0])))
+        if matches:
+            self.assertTrue(os.path.exists(os.path.join(self.output, matches[0])))
 
     @patch('msprobe.core.common.log.logger.info')
-    def test_compare_graph(self, mock_log_info):
+    def test_compare_graph_result(self, mock_log_info):
         args = Args(output_path=self.output, framework='pytorch')
-        result = _compare_graph(self.input_param, args)
+        result = _compare_graph_result(self.input_param, args)
         self.assertEqual(mock_log_info.call_count, 2)
         self.assertIsNotNone(result)
 
         args = Args(output_path=self.output, framework='mindspore')
-        result = _compare_graph(self.input_param, args)
+        result = _compare_graph_result(self.input_param, args)
         self.assertIsNotNone(result)
 
         args = Args(output_path=self.output, framework='pytorch', layer_mapping=self.layer_mapping)
-        result = _compare_graph(self.input_param, args)
+        result = _compare_graph_result(self.input_param, args)
         self.assertIsNotNone(result)
 
         args = Args(output_path=self.output, framework='pytorch', overflow_check=True)
-        result = _compare_graph(self.input_param, args)
+        result = _compare_graph_result(self.input_param, args)
         self.assertIsNotNone(result)
 
     @patch('msprobe.core.common.log.logger.info')
-    def test_build_graph(self, mock_log_info):
-        result = _build_graph(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True))
+    def test_build_graph_result(self, mock_log_info):
+        result = _build_graph_result(os.path.join(self.input, 'step0', 'rank0'), Args(overflow_check=True))
         self.assertEqual(mock_log_info.call_count, 1)
         self.assertIsNotNone(result)
 
@@ -81,7 +82,7 @@ class TestGraphService(unittest.TestCase):
         }
         args = Args(output_path=self.output, framework='pytorch')
         _compare_graph_ranks(input_param, args)
-        self.assert_log_info(mock_log_info)
+        self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.')
 
         input_param1 = {
             'npu_path': os.path.join(self.input, 'step0'),
@@ -101,7 +102,7 @@ class TestGraphService(unittest.TestCase):
         }
         args = Args(output_path=self.output, framework='pytorch')
         _compare_graph_steps(input_param, args)
-        self.assert_log_info(mock_log_info)
+        self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.')
 
         input_param1 = {
             'npu_path': self.input,
@@ -115,12 +116,12 @@ class TestGraphService(unittest.TestCase):
     @patch('msprobe.core.common.log.logger.info')
     def test_build_graph_ranks(self, mock_log_info):
         _build_graph_ranks(os.path.join(self.input, 'step0'), Args(output_path=self.output))
-        self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in")
+        self.assert_log_info(mock_log_info, "Successfully exported build graph results.")
 
     @patch('msprobe.core.common.log.logger.info')
     def test_build_graph_steps(self, mock_log_info):
         _build_graph_steps(self.input, Args(output_path=self.output))
-        self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in")
+        self.assert_log_info(mock_log_info, "Successfully exported build graph results.")
 
     @patch('msprobe.core.common.log.logger.info')
     def test_graph_service_command(self, mock_log_info):
@@ -129,7 +130,7 @@ class TestGraphService(unittest.TestCase):
             json.dump(input_param, f, indent=4)
         args = Args(input_path=self.output_json[0], output_path=self.output, framework='pytorch')
         _graph_service_command(args)
-        self.assert_log_info(mock_log_info)
+        self.assert_log_info(mock_log_info, 'Exporting compare graph result successfully, the result file is saved in')
 
         input_param1 = {
             'npu_path': os.path.join(self.input, 'step0', 'rank0'),
@@ -139,7 +140,7 @@ class TestGraphService(unittest.TestCase):
             json.dump(input_param1, f, indent=4)
         args = Args(input_path=self.output_json[1], output_path=self.output, framework='pytorch')
         _graph_service_command(args)
-        self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in")
+        self.assert_log_info(mock_log_info, "Model graph exported successfully, the result file is saved in")
 
         input_param2 = {
             'npu_path': os.path.join(self.input, 'step0'),
@@ -150,7 +151,7 @@ class TestGraphService(unittest.TestCase):
             json.dump(input_param2, f, indent=4)
         args = Args(input_path=self.output_json[2], output_path=self.output, framework='pytorch')
         _graph_service_command(args)
-        self.assert_log_info(mock_log_info)
+        self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.')
 
         input_param3 = {
             'npu_path': self.input,
@@ -161,7 +162,7 @@ class TestGraphService(unittest.TestCase):
             json.dump(input_param3, f, indent=4)
         args = Args(input_path=self.output_json[3], output_path=self.output, framework='pytorch')
         _graph_service_command(args)
-        self.assert_log_info(mock_log_info)
+        self.assert_log_info(mock_log_info, 'Successfully exported compare graph results.')
 
         input_param4 = {
             'npu_path': os.path.join(self.input, 'step0'),
@@ -171,7 +172,7 @@ class TestGraphService(unittest.TestCase):
             json.dump(input_param4, f, indent=4)
         args = Args(input_path=self.output_json[4], output_path=self.output, framework='pytorch')
         _graph_service_command(args)
-        self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in")
+        self.assert_log_info(mock_log_info, "Successfully exported build graph results.")
 
         input_param5 = {
             'npu_path': self.input,
@@ -181,7 +182,7 @@ class TestGraphService(unittest.TestCase):
             json.dump(input_param5, f, indent=4)
         args = Args(input_path=self.output_json[5], output_path=self.output, framework='pytorch')
         _graph_service_command(args)
-        self.assert_log_info(mock_log_info, "Model graph built successfully, the result file is saved in")
+        self.assert_log_info(mock_log_info, "Successfully exported build graph results.")
 
         input_param6 = {
             'npu_path': self.input,
diff --git a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py
index bec99d675f4b1238fde3905037ec5f7fb5a0c8fe..07e7400e8dcf2f9a0baadc48d0a2bdfc195be647 100644
--- a/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py
+++ b/debug/accuracy_tools/msprobe/visualization/builder/graph_builder.py
@@ -285,3 +285,20 @@ class GraphExportConfig:
         self.micro_steps = micro_steps
         self.task = task
         self.overflow_check = overflow_check
+
+
+class GraphInfo:
+    def __init__(self, graph: Graph, construct_path: str, data_path: str, stack_path: str):
+        self.graph = graph
+        self.construct_path = construct_path
+        self.data_path = data_path
+        self.stack_path = stack_path
+
+
+class BuildGraphTaskInfo:
+    def __init__(self, graph_info_n: GraphInfo, graph_info_b: GraphInfo, npu_rank, bench_rank, time_str):
+        self.graph_info_n = graph_info_n
+        self.graph_info_b = graph_info_b
+        self.npu_rank = npu_rank
+        self.bench_rank = bench_rank
+        self.time_str = time_str
diff --git a/debug/accuracy_tools/msprobe/visualization/graph_service.py b/debug/accuracy_tools/msprobe/visualization/graph_service.py
index 887c79860554077d93646ec4a92643e7d5299ff7..67da755feb1b394220c34948d6c4753e77d6fbc2 100644
--- a/debug/accuracy_tools/msprobe/visualization/graph_service.py
+++ b/debug/accuracy_tools/msprobe/visualization/graph_service.py
@@ -15,13 +15,14 @@
 
 import os
 import time
+from multiprocessing import cpu_count, Pool
 from msprobe.core.common.file_utils import (check_file_type, create_directory, FileChecker,
                                             check_file_or_directory_path, load_json)
 from msprobe.core.common.const import FileCheckConst, Const
 from msprobe.core.common.utils import CompareException
 from msprobe.visualization.compare.graph_comparator import GraphComparator
-from msprobe.visualization.utils import GraphConst, check_directory_content
-from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig
+from msprobe.visualization.utils import GraphConst, check_directory_content, SerializableArgs
+from msprobe.visualization.builder.graph_builder import GraphBuilder, GraphExportConfig, GraphInfo, BuildGraphTaskInfo
 from msprobe.core.common.log import logger
 from msprobe.visualization.graph.node_colors import NodeColors
 from msprobe.core.compare.layer_mapping import generate_api_mapping_by_layer_mapping
@@ -32,72 +33,74 @@ from msprobe.visualization.graph.distributed_analyzer import DistributedAnalyzer
 
 current_time = time.strftime("%Y%m%d%H%M%S")
 
 
-def _compare_graph(input_param, args):
-    logger.info('Start building model graphs...')
-    # Build graphs for the two sets of dump data
-    dump_path_n = input_param.get('npu_path')
-    dump_path_b = input_param.get('bench_path')
-    construct_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE),
-                                   FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
-    construct_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE),
-                                   FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
-    data_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.DUMP_FILE), FileCheckConst.FILE,
-                              FileCheckConst.READ_ABLE).common_check()
-    data_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.DUMP_FILE), FileCheckConst.FILE,
-                              FileCheckConst.READ_ABLE).common_check()
-    stack_path_n = FileChecker(os.path.join(dump_path_n, GraphConst.STACK_FILE), FileCheckConst.FILE,
-                               FileCheckConst.READ_ABLE).common_check()
-    stack_path_b = FileChecker(os.path.join(dump_path_b, GraphConst.STACK_FILE), FileCheckConst.FILE,
-                               FileCheckConst.READ_ABLE).common_check()
-    graph_n = GraphBuilder.build(construct_path_n, data_path_n, stack_path_n, complete_stack=args.complete_stack)
-    graph_b = GraphBuilder.build(construct_path_b, data_path_b, stack_path_b, complete_stack=args.complete_stack)
-    logger.info('Model graphs built successfully, start Comparing graphs...')
-    # Compare based on the graphs, stack and data
+def _compare_graph(graph_n: GraphInfo, graph_b: GraphInfo, input_param, args):
     dump_path_param = {
-        'npu_json_path': data_path_n,
-        'bench_json_path': data_path_b,
-        'stack_json_path': stack_path_n,
+        'npu_json_path': graph_n.data_path,
+        'bench_json_path': graph_b.data_path,
+        'stack_json_path': graph_n.stack_path,
         'is_print_compare_log': input_param.get("is_print_compare_log", True)
     }
     mapping_dict = {}
     if args.layer_mapping:
         try:
-            mapping_dict = generate_api_mapping_by_layer_mapping(data_path_n, data_path_b, args.layer_mapping)
+            mapping_dict = generate_api_mapping_by_layer_mapping(graph_n.data_path, graph_b.data_path,
+                                                                 args.layer_mapping)
         except Exception:
             logger.warning('The layer mapping file parsing failed, please check file format, mapping is not effective.')
-
-    is_cross_framework = detect_framework_by_dump_json(data_path_n) != detect_framework_by_dump_json(data_path_b)
+    is_cross_framework = detect_framework_by_dump_json(graph_n.data_path) != \
+        detect_framework_by_dump_json(graph_b.data_path)
     if is_cross_framework and not args.layer_mapping:
         logger.error('The cross_frame graph comparison failed. '
                      'Please specify -lm or --layer_mapping when performing cross_frame graph comparison.')
         raise CompareException(CompareException.CROSS_FRAME_ERROR)
-    graph_comparator = GraphComparator([graph_n, graph_b], dump_path_param, args, is_cross_framework,
+    graph_comparator = GraphComparator([graph_n.graph, graph_b.graph], dump_path_param, args, is_cross_framework,
                                        mapping_dict=mapping_dict)
     graph_comparator.compare()
-    micro_steps = graph_n.paging_by_micro_step(graph_b)
+    return graph_comparator
+
+
+def _compare_graph_result(input_param, args):
+    logger.info('Start building model graphs...')
+    # Build graphs for the two sets of dump data
+    graph_n = _build_graph_info(input_param.get('npu_path'), args)
+    graph_b = _build_graph_info(input_param.get('bench_path'), args)
+    logger.info('Model graphs built successfully, start Comparing graphs...')
+    # Compare based on the graphs, stack and data
+    graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
+    # Add micro step tags
+    micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
     # Enable overflow check
     if args.overflow_check:
-        graph_n.overflow_check()
-        graph_b.overflow_check()
+        graph_n.graph.overflow_check()
+        graph_b.graph.overflow_check()
 
-    return CompareGraphResult(graph_n, graph_b, graph_comparator, micro_steps)
+    return CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
 
 
-def _export_compare_graph_result(args, graphs, graph_comparator, micro_steps,
-                                 output_file_name=f'compare_{current_time}.vis'):
-    create_directory(args.output_path)
+def _export_compare_graph_result(args, result):
+    graphs = [result.graph_n, result.graph_b]
+    graph_comparator = result.graph_comparator
+    micro_steps = result.micro_steps
+    output_file_name = result.output_file_name
+    if not output_file_name:
+        output_file_name = f'compare_{current_time}.vis'
+    logger.info(f'Start exporting compare graph result, file name: {output_file_name}...')
     output_path = os.path.join(args.output_path, output_file_name)
     task = GraphConst.GRAPHCOMPARE_MODE_TO_DUMP_MODE_TO_MAPPING.get(graph_comparator.ma.compare_mode)
     export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(),
                                       NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task,
                                       args.overflow_check)
-    GraphBuilder.to_json(output_path, export_config)
-    logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}')
+    try:
+        GraphBuilder.to_json(output_path, export_config)
+        logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_path}')
+        return ''
+    except RuntimeError as e:
+        logger.error(f'Failed to export compare graph result, file: {output_file_name}, error: {e}')
+        return output_file_name
 
 
-def _build_graph(dump_path, args):
-    logger.info('Start building model graph...')
+def _build_graph_info(dump_path, args):
     construct_path = FileChecker(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE),
                                  FileCheckConst.FILE, FileCheckConst.READ_ABLE).common_check()
     data_path = FileChecker(os.path.join(dump_path, GraphConst.DUMP_FILE), FileCheckConst.FILE,
@@ -105,6 +108,13 @@
     stack_path = FileChecker(os.path.join(dump_path, GraphConst.STACK_FILE), FileCheckConst.FILE,
                              FileCheckConst.READ_ABLE).common_check()
     graph = GraphBuilder.build(construct_path, data_path, stack_path, complete_stack=args.complete_stack)
+    return GraphInfo(graph, construct_path, data_path, stack_path)
+
+
+def _build_graph_result(dump_path, args):
+    logger.info('Start building model graphs...')
+    graph = _build_graph_info(dump_path, args).graph
+    # Add micro step tags
     micro_steps = graph.paging_by_micro_step()
     # Enable overflow check
     if args.overflow_check:
@@ -112,12 +122,71 @@
     return BuildGraphResult(graph, micro_steps)
 
 
-def _export_build_graph_result(out_path, graph, micro_steps, overflow_check,
-                               output_file_name=f'build_{current_time}.vis'):
-    create_directory(out_path)
+def _run_build_graph_compare(input_param, args, nr, br):
+    logger.info(f'Start building graph for {nr}...')
+    graph_n = _build_graph_info(input_param.get('npu_path'), args)
+    graph_b = _build_graph_info(input_param.get('bench_path'), args)
+    logger.info(f'Building graph for {nr} finished.')
+    return BuildGraphTaskInfo(graph_n, graph_b, nr, br, current_time)
+
+
+def _run_build_graph_single(dump_ranks_path, rank, step, args):
+    logger.info(f'Start building graph for {rank}...')
+    dump_path = os.path.join(dump_ranks_path, rank)
+    output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
+    result = _build_graph_result(dump_path, args)
+    result.output_file_name = output_file_name
+    if rank != Const.RANK:
+        try:
+            result.rank = int(rank.replace(Const.RANK, ""))
+        except Exception as e:
+            logger.error('The folder name format is incorrect, expected rank+number.')
+            raise CompareException(CompareException.INVALID_PATH_ERROR) from e
+    logger.info(f'Building graph for {rank} finished.')
+    return result
+
+
+def _run_graph_compare(graph_task_info, input_param, args, output_file_name):
+    logger.info(f'Start comparing data for {graph_task_info.npu_rank}...')
+    graph_n = graph_task_info.graph_info_n
+    graph_b = graph_task_info.graph_info_b
+    nr = graph_task_info.npu_rank
+    graph_comparator = _compare_graph(graph_n, graph_b, input_param, args)
+    micro_steps = graph_n.graph.paging_by_micro_step(graph_b.graph)
+    # Enable overflow check
+    if args.overflow_check:
+        graph_n.graph.overflow_check()
+        graph_b.graph.overflow_check()
+    graph_result = CompareGraphResult(graph_n.graph, graph_b.graph, graph_comparator, micro_steps)
+    graph_result.output_file_name = output_file_name
+    if nr != Const.RANK:
+        try:
+            graph_result.rank = int(nr.replace(Const.RANK, ""))
+        except Exception as e:
+            logger.error('The folder name format is incorrect, expected rank+number.')
+            raise CompareException(CompareException.INVALID_PATH_ERROR) from e
+    logger.info(f'Comparing data for {graph_task_info.npu_rank} finished.')
+    return graph_result
+
+
+def _export_build_graph_result(args, result):
+    out_path = args.output_path
+    graph = result.graph
+    micro_steps = result.micro_steps
+    overflow_check = args.overflow_check
+    output_file_name = result.output_file_name
+    if not output_file_name:
+        output_file_name = f'build_{current_time}.vis'
+    logger.info(f'Start exporting graph for {output_file_name}...')
     output_path = os.path.join(out_path, output_file_name)
-    GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check))
-    logger.info(f'Model graph built successfully, the result file is saved in {output_path}')
+    try:
+        GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps,
+                                                            overflow_check=overflow_check))
+        logger.info(f'Model graph exported successfully, the result file is saved in {output_path}')
+        return None
+    except RuntimeError as e:
+        logger.error(f'Failed to export model graph, file: {output_file_name}, error: {e}')
+        return output_file_name
 
 
 def _compare_graph_ranks(input_param, args, step=None):
@@ -128,33 +197,49 @@ def _compare_graph_ranks(input_param, args, step=None):
     if npu_ranks != bench_ranks:
         logger.error('The number of ranks in the two runs are different. Unable to match the ranks.')
         raise CompareException(CompareException.INVALID_PATH_ERROR)
+    mp_res_dict = {}
     compare_graph_results = []
-    for nr, br in zip(npu_ranks, bench_ranks):
-        logger.info(f'Start processing data for {nr}...')
-        input_param['npu_path'] = os.path.join(dump_rank_n, nr)
-        input_param['bench_path'] = os.path.join(dump_rank_b, br)
-        output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
-        result = _compare_graph(input_param, args)
-        result.output_file_name = output_file_name
-        if nr != Const.RANK:
+    with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
+        def err_call(err):
+            logger.error(f'Error occurred while comparing graph ranks: {err}')
             try:
-                result.rank = int(nr.replace(Const.RANK, ""))
-            except Exception as e:
-                logger.error('The folder name format is incorrect, expected rank+number.')
-                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
-        # Cache the graphs of all ranks to match distributed nodes across ranks
-        compare_graph_results.append(result)
-
-    # Match distributed nodes across ranks
-    if len(compare_graph_results) > 1:
-        DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
-                            args.overflow_check).distributed_match()
-        DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
-                            args.overflow_check).distributed_match()
-
-    for result in compare_graph_results:
-        _export_compare_graph_result(args, [result.graph_n, result.graph_b], result.graph_comparator,
-                                     result.micro_steps, output_file_name=result.output_file_name)
+                pool.terminate()
+            except OSError as e:
+                logger.error(f'Error occurred while terminating the pool: {e}')
+
+        serializable_args = SerializableArgs(args)
+        for nr, br in zip(npu_ranks, bench_ranks):
+            input_param['npu_path'] = os.path.join(dump_rank_n, nr)
+            input_param['bench_path'] = os.path.join(dump_rank_b, br)
+            output_file_name = f'compare_{step}_{nr}_{current_time}.vis' if step else f'compare_{nr}_{current_time}.vis'
+            mp_res_dict[output_file_name] = pool.apply_async(_run_build_graph_compare,
+                                                             args=(input_param, serializable_args, nr, br),
+                                                             error_callback=err_call)
+
+        for output_file_name, mp_res in mp_res_dict.items():
+            # Cache the graphs of all ranks to match distributed nodes across ranks
+            compare_graph_results.append(_run_graph_compare(mp_res.get(), input_param, serializable_args,
+                                                            output_file_name))
+
+        # Match distributed nodes across ranks
+        if len(compare_graph_results) > 1:
+            DistributedAnalyzer({obj.rank: obj.graph_n for obj in compare_graph_results},
+                                args.overflow_check).distributed_match()
+            DistributedAnalyzer({obj.rank: obj.graph_b for obj in compare_graph_results},
+                                args.overflow_check).distributed_match()
+
+        export_res_task_list = []
+        create_directory(args.output_path)
+        for result in compare_graph_results:
+            export_res_task_list.append(pool.apply_async(_export_compare_graph_result,
+                                                         args=(serializable_args, result),
+                                                         error_callback=err_call))
+        export_res_list = [res.get() for res in export_res_task_list]
+        if any(export_res_list):
+            failed_names = list(filter(lambda x: x, export_res_list))
+            logger.error(f'Unable to export compare graph results: {", ".join(failed_names)}.')
+        else:
+            logger.info('Successfully exported compare graph results.')
 
 
 def _compare_graph_steps(input_param, args):
@@ -178,28 +263,39 @@ def _build_graph_ranks(dump_ranks_path, args, step=None):
     ranks = sorted(check_and_return_dir_contents(dump_ranks_path, Const.RANK))
-    build_graph_results = []
-    for rank in ranks:
-        logger.info(f'Start processing data for {rank}...')
-        dump_path = os.path.join(dump_ranks_path, rank)
-        output_file_name = f'build_{step}_{rank}_{current_time}.vis' if step else f'build_{rank}_{current_time}.vis'
-        result = _build_graph(dump_path, args)
-        result.output_file_name = output_file_name
-        if rank != Const.RANK:
+    serializable_args = SerializableArgs(args)
+    with Pool(processes=max(int((cpu_count() + 1) // 4), 1)) as pool:
+        def err_call(err):
+            logger.error(f'Error occurred while comparing graph ranks: {err}')
             try:
-                result.rank = int(rank.replace(Const.RANK, ""))
-            except Exception as e:
-                logger.error('The folder name format is incorrect, expected rank+number.')
-                raise CompareException(CompareException.INVALID_PATH_ERROR) from e
-        build_graph_results.append(result)
-
-    if len(build_graph_results) > 1:
-        DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
-                            args.overflow_check).distributed_match()
+                pool.terminate()
+            except OSError as e:
+                logger.error(f'Error occurred while terminating the pool: {e}')
+
+        build_graph_tasks = []
+        for rank in ranks:
+            build_graph_tasks.append(pool.apply_async(_run_build_graph_single,
+                                                      args=(dump_ranks_path, rank, step, serializable_args),
+                                                      error_callback=err_call))
+        build_graph_results = [task.get() for task in build_graph_tasks]
+
+        if len(build_graph_results) > 1:
+            DistributedAnalyzer({obj.rank: obj.graph for obj in build_graph_results},
+                                args.overflow_check).distributed_match()
+
+        create_directory(args.output_path)
+        export_build_graph_tasks = []
+        for result in build_graph_results:
+            export_build_graph_tasks.append(pool.apply_async(_export_build_graph_result,
+                                                             args=(serializable_args, result),
+                                                             error_callback=err_call))
+        export_build_graph_result = [task.get() for task in export_build_graph_tasks]
+        if any(export_build_graph_result):
+            failed_names = list(filter(lambda x: x, export_build_graph_result))
+            logger.error(f'Unable to export build graph results: {", ".join(failed_names)}.')
+        else:
+            logger.info(f'Successfully exported build graph results.')
 
-    for result in build_graph_results:
-        _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check,
-                                   result.output_file_name)
 
 
 def _build_graph_steps(dump_steps_path, args):
@@ -215,7 +311,7 @@ def _graph_service_parser(parser):
                         help=" The compare input path, a dict json.", required=True)
     parser.add_argument("-o", "--output_path", dest="output_path", type=str,
                         help=" The compare task result out path.", required=True)
-    parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str, nargs='?', const=True,
+    parser.add_argument("-lm", "--layer_mapping", dest="layer_mapping", type=str,
                         help=" The layer mapping file path.", required=False)
     parser.add_argument("-oc", "--overflow_check", dest="overflow_check", action="store_true",
                         help=" whether open overflow_check for graph.", required=False)
@@ -239,8 +335,11 @@ def _graph_service_command(args):
         elif content == GraphConst.STEPS:
             _build_graph_steps(npu_path, args)
         else:
-            result = _build_graph(npu_path, args)
-            _export_build_graph_result(args.output_path, result.graph, result.micro_steps, args.overflow_check)
+            result = _build_graph_result(npu_path, args)
+            create_directory(args.output_path)
+            file_name = _export_build_graph_result(args, result)
+            if file_name:
+                logger.error('Failed to export model build graph.')
     elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR:
         content_n = check_directory_content(npu_path)
         content_b = check_directory_content(bench_path)
@@ -251,9 +350,11 @@ def _graph_service_command(args):
         elif content_n == GraphConst.STEPS:
             _compare_graph_steps(input_param, args)
         else:
-            result = _compare_graph(input_param, args)
-            _export_compare_graph_result(args, [result.graph_n, result.graph_b],
-                                         result.graph_comparator, result.micro_steps)
+            result = _compare_graph_result(input_param, args)
+            create_directory(args.output_path)
+            file_name = _export_compare_graph_result(args, result)
+            if file_name:
+                logger.error('Failed to export model compare graph.')
     else:
         logger.error("The npu_path or bench_path should be a folder.")
         raise CompareException(CompareException.INVALID_COMPARE_MODE)
diff --git a/debug/accuracy_tools/msprobe/visualization/utils.py b/debug/accuracy_tools/msprobe/visualization/utils.py
index 5f428697bdef1bc1de72a3eb6da1c4c5761eae2a..4193207ee3382ec83adeaa5937604ce4660fceef 100644
--- a/debug/accuracy_tools/msprobe/visualization/utils.py
+++ b/debug/accuracy_tools/msprobe/visualization/utils.py
@@ -16,6 +16,7 @@
 import os
 import re
 import json
+import pickle
 from msprobe.core.common.file_utils import FileOpen
 from msprobe.core.common.const import CompareConst, Const
 from msprobe.core.compare.acc_compare import Comparator, ModeConfig
@@ -192,3 +193,21 @@ class GraphConst:
     OP = 'op'
     PEER = 'peer'
     GROUP_ID = 'group_id'
+
+
+def is_serializable(obj):
+    """
+    Check if an object is serializable
+    """
+    try:
+        pickle.dumps(obj)
+        return True
+    except Exception:
+        return False
+
+
+class SerializableArgs:
+    def __init__(self, args):
+        for k, v in vars(args).items():
+            if is_serializable(v):
+                setattr(self, k, v)