From a825fe6ad9d8def61fe3b4262b2ec91de980610e Mon Sep 17 00:00:00 2001 From: keith Date: Wed, 8 May 2024 17:13:10 +0800 Subject: [PATCH 1/3] =?UTF-8?q?ptdbg=E3=80=81API=E9=A2=84=E6=A3=80?= =?UTF-8?q?=E5=AD=90=E5=8A=9F=E8=83=BDAPI=E5=91=BD=E5=90=8D=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F=E7=BB=9F=E4=B8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/dump/dump.py | 1 + .../api_accuracy_checker/run_ut/run_ut.py | 7 ++-- .../src/python/ptdbg_ascend/dump/dump.py | 35 +++++++++++++++++++ .../ptdbg_ascend/test/ut/test_dump.py | 9 +++++ 4 files changed, 48 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index d8b317aa2..f5a4620e1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -87,6 +87,7 @@ class DumpConst: def pretest_info_dump(name, out_feat, module, phase): if not DumpUtil.get_dump_switch(): return + name = name.replace('*', '.') if phase == DumpConst.forward: api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) elif phase == DumpConst.backward: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 05bd4305a..71018a86d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -170,7 +170,7 @@ def run_ut(config): continue try: if msCheckerConfig.white_list: - [_, api_name, _] = api_full_name.split("*") + [_, api_name, _] = api_full_name.split(".") if api_name not in set(msCheckerConfig.white_list): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) @@ -179,7 +179,7 @@ def run_ut(config): if config.save_error_data: do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: - [_, api_name, _] = api_full_name.split("*") + [_, api_name, _] = api_full_name.split(".") if "expected scalar type Long" in str(err): print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") @@ -201,7 +201,6 @@ def run_ut(config): def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") for element in data_info.in_fwd_data_list: UtAPIInfo(api_full_name + '.forward.input', element) UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) @@ -214,7 +213,7 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): in_fwd_data_list = [] backward_message = '' - [api_type, api_name, _] = api_full_name.split("*") + [api_type, api_name, _] = api_full_name.split(".") args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 2e49a9743..8d64a5ddf 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -188,8 +188,43 @@ def dump_data(prefix, data_info): def thread_dump_data(prefix, data_info): DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info) +def underscore_replacement(prefix): + """ + Replacing symbols to unify the format of ptdbg and pretest + """ + replaced_prefix = [] + consecutive_underscore_count = 0 + three_underscore_time = 0 + + for char in prefix: + if char == '_': + consecutive_underscore_count += 1 + if consecutive_underscore_count == 2: + # Two consecutive underscores, leave them unchanged + replaced_prefix.pop() + replaced_prefix.append('__') + elif consecutive_underscore_count == 3: + # Three consecutive underscores + three_underscore_time += 1 + replaced_prefix.pop() + if three_underscore_time % 2 == 1: + # Even index, replace the first underscore + replaced_prefix.append('.__') + else: + # Odd index, replace the third underscore + replaced_prefix.append('__.') + else: + # Single underscore, replace with a period + replaced_prefix.append('.') + else: + # Not an underscore, reset the count + consecutive_underscore_count = 0 + replaced_prefix.append(char) + replaced_prefix = ''.join(replaced_prefix).replace("stack.info", "stack_info") + return replaced_prefix def dump_data_by_rank_count(dump_step, prefix, data_info): + prefix = underscore_replacement(prefix) print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: thread_dump_data(prefix, data_info) diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py index 9673c292b..56440bbe5 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py @@ -47,3 +47,12 @@ class TestDump(unittest.TestCase): result = get_pkl_file_path() self.assertEqual(result, "") + def test_underscore_replacement(self): + prefix = "Torch_matmul_605_forward_input.0" + replaced_prefix = underscore_replacement(prefix) + self.assertEqual(replaced_prefix, "Torch.matmul.605.forward.input.0") + + prefix = "Tensor___getitem___488_forward_stack_info" + replaced_prefix = underscore_replacement(prefix) + self.assertEqual(replaced_prefix, "Tensor.__getitem__.488.forward.stack_info") + -- Gitee From 949a9f43c26a65a964f9b6eca3cbbe6dcefbb205 Mon Sep 17 00:00:00 2001 From: keith Date: Wed, 8 May 2024 20:14:46 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=86=92=E7=83=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 8d64a5ddf..112b2102d 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -188,7 +188,7 @@ def dump_data(prefix, data_info): def thread_dump_data(prefix, data_info): DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info) -def underscore_replacement(prefix): +def underscore_replace(prefix): """ Replacing symbols to unify the format of ptdbg and pretest """ @@ -220,11 +220,9 @@ def underscore_replacement(prefix): # Not an underscore, reset the count consecutive_underscore_count = 0 replaced_prefix.append(char) - replaced_prefix = ''.join(replaced_prefix).replace("stack.info", "stack_info") - return replaced_prefix + return ''.join(replaced_prefix) def dump_data_by_rank_count(dump_step, prefix, data_info): - prefix = underscore_replacement(prefix) print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool: thread_dump_data(prefix, data_info) @@ -248,7 +246,7 @@ def dump_stack_info(name_template): except Exception as e: print_warn_log("Dump stack info failed, error: {}".format(e)) stack_str.append('') - + prefix = name_template.format("stack_info") if DumpUtil.dump_switch_mode in Const.DUMP_MODE: complement_set = set(['forward', 'backward', 'input', 'output']) - set(DumpUtil.dump_mode) @@ -339,7 +337,7 @@ def dump_acc_cmp(name, in_feat, out_feat, dump_step, module): print_warn_log("The file does not exist, error: {}".format(e)) name_prefix = name - name_template = f"{name_prefix}" + "_{}" + name_template = f"{underscore_replace(name_prefix)}" + ".{}" if DumpUtil.is_single_rank is None: DumpUtil.is_single_rank = check_single_rank_folder(dump_dir) if DumpUtil.dump_switch_mode in [Const.ALL, Const.API_LIST]: -- Gitee From cb042754fd119d87c09dee9bce04bc3ea98b3288 Mon Sep 17 00:00:00 2001 From: keith Date: Wed, 8 May 2024 21:47:14 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E5=86=92=E7=83=9F=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/compare/compare.py | 9 ++- .../api_accuracy_checker/run_ut/run_ut.py | 74 ++++++++++++------- 2 files changed, 54 insertions(+), 29 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index db549ec1d..be65add1f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -8,7 +8,7 @@ from rich.console import Console from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, \ - apis_threshold + ThousandthStandardApi, apis_threshold from api_accuracy_checker.compare.compare_column import CompareColumn from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err, \ get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ @@ -160,7 +160,7 @@ class Comparator: self.write_detail_csv(args) def compare_output(self, full_api_name, data_info): - _, api_name, _ = full_api_name.split("*") + _, api_name, _ = full_api_name.split(".") bench_output = data_info.bench_output device_output = data_info.device_output bench_grad = data_info.bench_grad @@ -279,6 +279,10 @@ class Comparator: message = "" abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) abs_err = get_abs_err(bench_output, device_output) + rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) + if api_name in ThousandthStandardApi: + thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, 0.001) + compare_column.rel_err_thousandth = thousand_res if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output) if api_name in BinaryStandardApi: @@ -330,7 +334,6 @@ class Comparator: message += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n" return CompareConst.PASS, compare_column, message - rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) if dtype in [torch.float16, torch.bfloat16]: hundred_res, hundred_status = get_rel_err_ratio(rel_err_orign, 0.01) compare_column.rel_err_hundredth = hundred_res diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 71018a86d..1f7cf8a69 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -17,7 +17,7 @@ else: import torch from tqdm import tqdm from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.run_ut.run_ut_utils import Backward_Message +from api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log, initialize_save_path, Const, create_directory from api_accuracy_checker.compare.compare import Comparator @@ -77,7 +77,18 @@ def deal_detach(arg, to_detach=True): return arg.detach() if to_detach else arg -def deal_dtype(arg, raise_dtype=None): +def raise_bench_data_dtype(api_name, arg, raise_dtype=None): + ''' + 将标杆数据的dtype转换为raise_dtype + 输入: + api_name:api名称 + arg:标杆输入 + raise_dtype:需要转换的dtype + 输出: + arg: 转换dtype的标杆输入 + ''' + if api_name in hf_32_standard_api and arg.dtype == torch.float32: + return arg if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype: return arg return arg.type(raise_dtype) @@ -112,13 +123,13 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - arg_in = deal_detach(deal_dtype(arg_in.clone(), raise_dtype), to_detach).requires_grad_() + arg_in = deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype), to_detach).requires_grad_() temp_arg_in = arg_in * 1 arg_in = temp_arg_in.type_as(arg_in) arg_in.retain_grad() return arg_in else: - return deal_detach(deal_dtype(arg_in.clone(), raise_dtype=raise_dtype), to_detach) + return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach) else: return arg_in @@ -201,13 +212,14 @@ def run_ut(config): def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: + # api_full_name = api_full_name.replace("*", ".") for element in data_info.in_fwd_data_list: UtAPIInfo(api_full_name + '.forward.input', element) - UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) - UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_out) + UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_output) + UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_output) UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) - UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) - UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad_out) + UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad) + UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): @@ -245,7 +257,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict if need_backward: if need_to_backward(grad_index, out): backward_args = backward_content[api_full_name] - grad = gen_args(backward_args, real_data_path=real_data_path)[0] + grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0] bench_grad, _ = generate_cpu_params(grad, {}, False, api_name) bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out) device_grad = grad.clone().detach().to(current_device) @@ -261,7 +273,7 @@ def get_api_info(api_info_dict, api_name, real_data_path): need_grad = True if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False - args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type, real_data_path) + args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path) return args, kwargs, need_grad @@ -365,29 +377,37 @@ def _run_ut_parser(parser): def preprocess_forward_content(forward_content): processed_content = {} base_keys_variants = {} + arg_cache = {} + for key, value in forward_content.items(): base_key = key.rsplit('*', 1)[0] - new_args = value['args'] - new_kwargs = value['kwargs'] - filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] - if base_key in base_keys_variants: + + if key not in arg_cache: + new_args = value['args'] + new_kwargs = value['kwargs'] + filtered_new_args = [ + {k: v for k, v in arg.items() if k not in ['Max', 'Min']} + for arg in new_args if isinstance(arg, dict) + ] + arg_cache[key] = (filtered_new_args, new_kwargs) + + filtered_new_args, new_kwargs = arg_cache[key] + + if base_key not in base_keys_variants: + processed_content[key] = value + base_keys_variants[base_key] = {key} + else: is_duplicate = False - for variant in base_keys_variants.get(base_key, []): - try: - existing_args = processed_content[variant].get('args', []) - existing_kwargs = processed_content[variant].get('kwargs', {}) - filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)] - except KeyError as e: - print_error_log(f"KeyError: {e} when processing {key}") - if filtered_existing_args == filtered_new_args and existing_kwargs == new_kwargs: + for variant in base_keys_variants[base_key]: + existing_args, existing_kwargs = arg_cache[variant] + if existing_args == filtered_new_args and existing_kwargs == new_kwargs: is_duplicate = True break + if not is_duplicate: processed_content[key] = value - base_keys_variants[base_key].append(key) - else: - processed_content[key] = value - base_keys_variants[base_key] = [key] + base_keys_variants[base_key].add(key) + return processed_content @@ -422,7 +442,9 @@ def run_ut_command(args): save_error_data = args.save_error_data forward_content = get_json_contents(forward_file) if args.filter_api: + print_info_log("Start filtering the api in the forward_input_file.") forward_content = preprocess_forward_content(forward_content) + print_info_log("Finish filtering the api in the forward_input_file.") backward_content = {} if args.backward_input_file: check_link(args.backward_input_file) -- Gitee