diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py index 47e391f3febd9a1bfba7461cda4a28eb8b2c1251..1609ff5de2cf577f5b51ff9aa10d534f18c67663 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py @@ -12,13 +12,15 @@ from tqdm import tqdm from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, \ check_file_suffix, check_link, FileOpen from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path +from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path, preprocess_forward_content from api_accuracy_checker.common.utils import print_error_log, print_warn_log, print_info_log -def split_json_file(input_file, num_splits): +def split_json_file(input_file, num_splits, filter_api): with FileOpen(input_file, 'r') as file: data = json.load(file) + if filter_api: + data = preprocess_forward_content(data) items = list(data.items()) total_items = len(items) @@ -138,7 +140,7 @@ def prepare_config(args): out_path = os.path.realpath(args.out_path) if args.out_path else "./" out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() - forward_splits, total_items = split_json_file(args.forward_input_file, args.num_splits) + forward_splits, total_items = split_json_file(args.forward_input_file, args.num_splits, args.filter_api) backward_splits = [backward_file] * args.num_splits if backward_file else [None] * args.num_splits result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") if not args.result_csv_path: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 5101705434de0b0325817b05758922197f8bf381..f153870999553b2aa5b6e8a74682ede5512c9a24 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -332,6 +332,37 @@ def _run_ut_parser(parser): help=" In real data mode, the root directory for storing real data " "must be configured.", required=False) + parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true", + help=" Whether to filter the api in the forward_input_file.", required=False) + + +def preprocess_forward_content(forward_content): + processed_content = {} + base_keys_variants = {} + for key, value in forward_content.items(): + base_key = key.rsplit('*', 1)[0] + new_args = value['args'] + new_kwargs = value['kwargs'] + filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] + if base_key in base_keys_variants: + is_duplicate = False + for variant in base_keys_variants.get(base_key, []): + try: + existing_args = processed_content[variant].get('args', []) + existing_kwargs = processed_content[variant].get('kwargs', {}) + filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] + except KeyError as e: + print_error_log(f"KeyError: {e} when processing {key}") + if filtered_existing_args == filtered_new_args and existing_kwargs == new_kwargs: + is_duplicate = True + break + if not is_duplicate: + processed_content[key] = value + base_keys_variants[base_key].append(key) + else: + processed_content[key] = value + base_keys_variants[base_key] = [key] + return processed_content def _run_ut(): @@ -357,6 +388,8 @@ def _run_ut(): out_path = out_path_checker.common_check() save_error_data = args.save_error_data forward_content = get_json_contents(forward_file) + if args.filter_api: + forward_content = preprocess_forward_content(forward_content) backward_content = {} if args.backward_input_file: check_link(args.backward_input_file)