From 3e5a9434702e14c71277bcdb107151ddddbc5673 Mon Sep 17 00:00:00 2001 From: yang-minghai22 Date: Thu, 7 Dec 2023 17:17:16 +0800 Subject: [PATCH] [Feature]support pkl compare without increase parameter --- .../src/python/ptdbg_ascend/common/utils.py | 21 ++++++++++++++++--- .../ptdbg_ascend/compare/acc_compare.py | 10 +++++---- .../compare/distributed_compare.py | 7 ++++--- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index a78eb595f6..213aa01589 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -358,9 +358,24 @@ def check_compare_param(input_parma, output_path, stack_mode=False, summary_comp check_pkl_file(input_parma, npu_pkl, bench_pkl, stack_mode) -def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False, summary_compare=False): - if not (isinstance(stack_mode, bool) and isinstance(auto_analyze, bool) and isinstance(fuzzy_match, bool) and - isinstance(summary_compare, bool)): +def is_summary_compare(input_param): + npu_pkl_path = input_param.get("npu_pkl_path", None) + bench_pkl_path = input_param.get("bench_pkl_path", None) + npu_dump_data_dir = input_param.get("npu_dump_data_dir", None) + bench_dump_data_dir = input_param.get("bench_dump_data_dir", None) + if not npu_pkl_path or not bench_pkl_path: + print_error_log(f"Please check the pkl path is valid.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + if not (npu_dump_data_dir and bench_dump_data_dir): + return True + if npu_dump_data_dir and bench_dump_data_dir: + return False + print_error_log(f"Please check the dump data dir is valid.") + raise CompareException(CompareException.INVALID_PATH_ERROR) + + +def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False): + if not (isinstance(stack_mode, bool) and isinstance(auto_analyze, bool) and isinstance(fuzzy_match, bool)): print_error_log("Invalid input parameters which should be only bool type.") raise CompareException(CompareException.INVALID_PARAM_ERROR) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py index bd5cd01573..3a72c196ec 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py @@ -28,7 +28,8 @@ from .match import graph_mapping from ..advisor.advisor import Advisor from ..common.utils import check_compare_param, add_time_as_suffix, \ print_warn_log, print_error_log, CompareException, Const,\ - CompareConst, format_value, check_file_not_exists, check_configuration_param + CompareConst, format_value, check_file_not_exists, check_configuration_param, \ + is_summary_compare from ..common.file_check_util import FileChecker, FileCheckConst, change_mode, FileOpen @@ -559,10 +560,11 @@ def handle_inf_nan(n_value, b_value): def compare(input_parma, output_path, stack_mode=False, auto_analyze=True, - fuzzy_match=False, summary_compare=False): + fuzzy_match=False): try: + summary_compare = is_summary_compare(input_parma) check_compare_param(input_parma, output_path, stack_mode, summary_compare) - check_configuration_param(stack_mode, auto_analyze, fuzzy_match, summary_compare) + check_configuration_param(stack_mode, auto_analyze, fuzzy_match) except CompareException as error: print_error_log('Compare failed. Please check the arguments and do it again!') sys.exit(error.code) @@ -647,7 +649,7 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False un_match_data = npu_ops_queue[0: n_match_point] for npu_data in un_match_data: get_un_match_accuracy(result, npu_data) - get_accuracy(result, n_match_data, b_match_data) + get_accuracy(result, n_match_data, b_match_data, summary_compare) del npu_ops_queue[0: n_match_point + 1] del bench_ops_queue[0: b_match_point + 1] if npu_ops_queue: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/distributed_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/distributed_compare.py index b1e13c85c0..7133ffe66b 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/distributed_compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/distributed_compare.py @@ -18,7 +18,7 @@ import os import sys import re from ..common.utils import print_error_log, CompareException, check_compare_param, check_file_or_directory_path, \ - check_configuration_param + check_configuration_param, is_summary_compare from .acc_compare import compare_core @@ -87,9 +87,10 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): 'is_print_compare_log': True } try: - check_compare_param(dump_result_param, output_path, **kwargs) + summary_compare = is_summary_compare(dump_result_param) + check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, **kwargs) check_configuration_param(**kwargs) except CompareException as error: print_error_log('Compare failed. Please check the arguments and do it again!') sys.exit(error.code) - compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', **kwargs) + compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, **kwargs) -- Gitee