From 1944c5cf21879acf1d45675931058ae3c4c1d097 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 26 Feb 2024 16:16:23 +0800 Subject: [PATCH 1/9] filte same api --- .../api_accuracy_checker/run_ut/multi_run_ut.py | 3 ++- .../api_accuracy_checker/run_ut/run_ut.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py index 47e391f3f..02d4b1ef3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py @@ -12,13 +12,14 @@ from tqdm import tqdm from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, \ check_file_suffix, check_link, FileOpen from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path +from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path, preprocess_forward_content from api_accuracy_checker.common.utils import print_error_log, print_warn_log, print_info_log def split_json_file(input_file, num_splits): with FileOpen(input_file, 'r') as file: data = json.load(file) + data = preprocess_forward_content(data) items = list(data.items()) total_items = len(items) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 510170543..0cef494fb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -334,6 +334,18 @@ def _run_ut_parser(parser): required=False) +def preprocess_forward_content(forward_content): + unique_apis = {} + for api_full_name, api_info in forward_content.items(): + api_type, _, _ = api_full_name.rsplit("*", 2) + args = api_info.get("args", []) + dtype_shape_key = tuple((arg.get("dtype"), tuple(arg.get("shape", []))) for arg in args if "type" in arg and arg["type"] == "torch.Tensor") + unique_key = (api_type, dtype_shape_key) + if unique_key not in unique_apis: + unique_apis[unique_key] = api_full_name + filtered_forward_content = {unique_apis[key]: forward_content[api_name] for key, api_name in unique_apis.items()} + return filtered_forward_content + def _run_ut(): parser = argparse.ArgumentParser() _run_ut_parser(parser) @@ -357,6 +369,7 @@ def _run_ut(): out_path = out_path_checker.common_check() save_error_data = args.save_error_data forward_content = get_json_contents(forward_file) + forward_content = preprocess_forward_content(forward_content) backward_content = {} if args.backward_input_file: check_link(args.backward_input_file) -- Gitee From 91acf1162bf5fe09e2f30c723ce8c03edfdfe343 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 26 Feb 2024 16:28:28 +0800 Subject: [PATCH 2/9] clean code --- .../api_accuracy_checker/run_ut/run_ut.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0cef494fb..79fb3cdd1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -337,15 +337,24 @@ def _run_ut_parser(parser): def preprocess_forward_content(forward_content): unique_apis = {} for api_full_name, api_info in forward_content.items(): - api_type, _, _ = api_full_name.rsplit("*", 2) - args = api_info.get("args", []) - dtype_shape_key = tuple((arg.get("dtype"), tuple(arg.get("shape", []))) for arg in args if "type" in arg and arg["type"] == "torch.Tensor") - unique_key = (api_type, dtype_shape_key) - if unique_key not in unique_apis: - unique_apis[unique_key] = api_full_name - filtered_forward_content = {unique_apis[key]: forward_content[api_name] for key, api_name in unique_apis.items()} + try: + api_type, _, _ = api_full_name.rsplit("*", 2) + args = api_info.get("args", []) + dtype_shape_key = tuple((arg.get("dtype"), tuple(arg.get("shape", []))) for arg in args if "type" in arg and arg["type"] == "torch.Tensor") + unique_key = (api_type, dtype_shape_key) + if unique_key not in unique_apis: + unique_apis[unique_key] = api_full_name + except KeyError as e: + raise KeyError(f"The api {api_full_name} has no args or dtype_shape_key, please check the forward_content.") + filtered_forward_content = {} + for key, api_name in unique_apis.items(): + try: + filtered_forward_content[unique_apis[key]] = forward_content.get(api_name) + except KeyError as e: + raise KeyError(f"The api {api_name} is not in forward_content, please check the forward_content.") return filtered_forward_content + def _run_ut(): parser = argparse.ArgumentParser() _run_ut_parser(parser) -- Gitee From e4271999d1916a9d02c4365fd442b9f7dd3a3122 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 26 Feb 2024 16:47:29 +0800 Subject: [PATCH 3/9] clean code --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 79fb3cdd1..841839e88 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -345,13 +345,13 @@ def preprocess_forward_content(forward_content): if unique_key not in unique_apis: unique_apis[unique_key] = api_full_name except KeyError as e: - raise KeyError(f"The api {api_full_name} has no args or dtype_shape_key, please check the forward_content.") + raise KeyError(f"The api {api_full_name} has no args or dtype_shape_key, please check the forward_content.") from e filtered_forward_content = {} for key, api_name in unique_apis.items(): try: filtered_forward_content[unique_apis[key]] = forward_content.get(api_name) except KeyError as e: - raise KeyError(f"The api {api_name} is not in forward_content, please check the forward_content.") + raise KeyError(f"The api {api_name} is not in forward_content, please check the forward_content.") from e return filtered_forward_content -- Gitee From 93be0135930507dacb3276718befd7b7c874fd25 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 27 Feb 2024 16:59:28 +0800 Subject: [PATCH 4/9] update --- .../api_accuracy_checker/run_ut/run_ut.py | 33 +++++++++---------- 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 841839e88..e030b8d6d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -335,24 +335,21 @@ def _run_ut_parser(parser): def preprocess_forward_content(forward_content): - unique_apis = {} - for api_full_name, api_info in forward_content.items(): - try: - api_type, _, _ = api_full_name.rsplit("*", 2) - args = api_info.get("args", []) - dtype_shape_key = tuple((arg.get("dtype"), tuple(arg.get("shape", []))) for arg in args if "type" in arg and arg["type"] == "torch.Tensor") - unique_key = (api_type, dtype_shape_key) - if unique_key not in unique_apis: - unique_apis[unique_key] = api_full_name - except KeyError as e: - raise KeyError(f"The api {api_full_name} has no args or dtype_shape_key, please check the forward_content.") from e - filtered_forward_content = {} - for key, api_name in unique_apis.items(): - try: - filtered_forward_content[unique_apis[key]] = forward_content.get(api_name) - except KeyError as e: - raise KeyError(f"The api {api_name} is not in forward_content, please check the forward_content.") from e - return filtered_forward_content + processed_content = {} + for key, value in forward_content.items(): + base_key = key.rsplit('*', 1)[0] + if base_key in processed_content: + existing_args = processed_content[base_key]['args'] + new_args = value['args'] + filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] + filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] + if filtered_existing_args == filtered_new_args and processed_content[base_key]['kwargs'] == value['kwargs']: + continue + else: + processed_content[key] = value + else: + processed_content[base_key] = value + return processed_content def _run_ut(): -- Gitee From 5954c2002576d958920767b374feb7b682022435 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 27 Feb 2024 18:12:10 +0800 Subject: [PATCH 5/9] update --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e030b8d6d..a8a5fd293 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -337,13 +337,13 @@ def _run_ut_parser(parser): def preprocess_forward_content(forward_content): processed_content = {} for key, value in forward_content.items(): - base_key = key.rsplit('*', 1)[0] + base_key = key.rsplit('*', 1)[0] + "*0" if base_key in processed_content: existing_args = processed_content[base_key]['args'] new_args = value['args'] filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] - if filtered_existing_args == filtered_new_args and processed_content[base_key]['kwargs'] == value['kwargs']: + if filtered_existing_args == filtered_new_args: continue else: processed_content[key] = value -- Gitee From 3da9b768dd8c840656f4e17e4e9f7d5010eac3eb Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 28 Feb 2024 10:25:57 +0800 Subject: [PATCH 6/9] add args --- .../api_accuracy_checker/run_ut/multi_run_ut.py | 7 ++++--- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 5 ++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py index 02d4b1ef3..1609ff5de 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py @@ -16,10 +16,11 @@ from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_res from api_accuracy_checker.common.utils import print_error_log, print_warn_log, print_info_log -def split_json_file(input_file, num_splits): +def split_json_file(input_file, num_splits, filter_api): with FileOpen(input_file, 'r') as file: data = json.load(file) - data = preprocess_forward_content(data) + if filter_api: + data = preprocess_forward_content(data) items = list(data.items()) total_items = len(items) @@ -139,7 +140,7 @@ def prepare_config(args): out_path = os.path.realpath(args.out_path) if args.out_path else "./" out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() - forward_splits, total_items = split_json_file(args.forward_input_file, args.num_splits) + forward_splits, total_items = split_json_file(args.forward_input_file, args.num_splits, args.filter_api) backward_splits = [backward_file] * args.num_splits if backward_file else [None] * args.num_splits result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") if not args.result_csv_path: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index a8a5fd293..bfb55003f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -332,6 +332,8 @@ def _run_ut_parser(parser): help=" In real data mode, the root directory for storing real data " "must be configured.", required=False) + parser.add_argument("-f", "--filter_api", dest="filter_api", action="store_true", + help=" Whether to filter the api in the forward_input_file.", required=False) def preprocess_forward_content(forward_content): @@ -375,7 +377,8 @@ def _run_ut(): out_path = out_path_checker.common_check() save_error_data = args.save_error_data forward_content = get_json_contents(forward_file) - forward_content = preprocess_forward_content(forward_content) + if args.filter_api: + forward_content = preprocess_forward_content(forward_content) backward_content = {} if args.backward_input_file: check_link(args.backward_input_file) -- Gitee From c43d46d88fe56d9b211008f50ffbad70add75d03 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Thu, 29 Feb 2024 09:38:19 +0800 Subject: [PATCH 7/9] add kwargs --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index bfb55003f..674fc1efa 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -345,7 +345,7 @@ def preprocess_forward_content(forward_content): new_args = value['args'] filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] - if filtered_existing_args == filtered_new_args: + if filtered_existing_args == filtered_new_args and processed_content[base_key]['kwargs'] == value['kwargs']: continue else: processed_content[key] = value -- Gitee From 6ca4fae733c4a39f516e47230f1eebf14e334f58 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 1 Mar 2024 01:35:03 +0000 Subject: [PATCH 8/9] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: sunyiming --- .../api_accuracy_checker/run_ut/run_ut.py | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 674fc1efa..c71f65400 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -338,19 +338,32 @@ def _run_ut_parser(parser): def preprocess_forward_content(forward_content): processed_content = {} + base_keys_variants = {} + for key, value in forward_content.items(): - base_key = key.rsplit('*', 1)[0] + "*0" - if base_key in processed_content: - existing_args = processed_content[base_key]['args'] - new_args = value['args'] - filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] - filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] - if filtered_existing_args == filtered_new_args and processed_content[base_key]['kwargs'] == value['kwargs']: - continue - else: + base_key = key.rsplit('*', 1)[0] + new_args = value['args'] + new_kwargs = value['kwargs'] + filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] + + if base_key in base_keys_variants: + is_duplicate = False + for variant in base_keys_variants[base_key]: + existing_args = processed_content[variant]['args'] + existing_kwargs = processed_content[variant]['kwargs'] + filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] + + if filtered_existing_args == filtered_new_args and existing_kwargs == new_kwargs: + is_duplicate = True + break + + if not is_duplicate: processed_content[key] = value + base_keys_variants[base_key].append(key) else: - processed_content[base_key] = value + processed_content[key] = value + base_keys_variants[base_key] = [key] + return processed_content -- Gitee From da7147da398f8afcac06426fa1b709396624ebba Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 1 Mar 2024 02:24:06 +0000 Subject: [PATCH 9/9] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: sunyiming --- .../api_accuracy_checker/run_ut/run_ut.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index c71f65400..f15387099 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -339,31 +339,29 @@ def _run_ut_parser(parser): def preprocess_forward_content(forward_content): processed_content = {} base_keys_variants = {} - for key, value in forward_content.items(): base_key = key.rsplit('*', 1)[0] new_args = value['args'] new_kwargs = value['kwargs'] filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] - if base_key in base_keys_variants: is_duplicate = False - for variant in base_keys_variants[base_key]: - existing_args = processed_content[variant]['args'] - existing_kwargs = processed_content[variant]['kwargs'] - filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] - + for variant in base_keys_variants.get(base_key, []): + try: + existing_args = processed_content[variant].get('args', []) + existing_kwargs = processed_content[variant].get('kwargs', {}) + filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] + except KeyError as e: + print_error_log(f"KeyError: {e} when processing {key}") if filtered_existing_args == filtered_new_args and existing_kwargs == new_kwargs: is_duplicate = True break - if not is_duplicate: processed_content[key] = value base_keys_variants[base_key].append(key) else: processed_content[key] = value base_keys_variants[base_key] = [key] - return processed_content -- Gitee