diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
index db549ec1d54060cb8a9f96ceeec39f90566b05ca..be65add1f37ce2372016cc8a3734f8902772b994 100644
--- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
@@ -8,7 +8,7 @@ from rich.console import Console
 from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log
 from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \
     precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, \
-    apis_threshold
+    ThousandthStandardApi, apis_threshold
 from api_accuracy_checker.compare.compare_column import CompareColumn
 from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err, \
     get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \
@@ -160,7 +160,7 @@ class Comparator:
         self.write_detail_csv(args)
 
     def compare_output(self, full_api_name, data_info):
-        _, api_name, _ = full_api_name.split("*")
+        _, api_name, _ = full_api_name.split(".")
         bench_output = data_info.bench_output
         device_output = data_info.device_output
         bench_grad = data_info.bench_grad
@@ -279,6 +279,10 @@ class Comparator:
         message = ""
         abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype)
         abs_err = get_abs_err(bench_output, device_output)
+        rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
+        if api_name in ThousandthStandardApi:
+            thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, 0.001)
+            compare_column.rel_err_thousandth = thousand_res
         if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST:
             both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output)
             if api_name in BinaryStandardApi:
@@ -330,7 +334,6 @@ class Comparator:
             message += "Max abs error is less than 0.001, consider as pass, skip other check and set to SPACE.\n"
             return CompareConst.PASS, compare_column, message
 
-        rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps)
         if dtype in [torch.float16, torch.bfloat16]:
             hundred_res, hundred_status = get_rel_err_ratio(rel_err_orign, 0.01)
             compare_column.rel_err_hundredth = hundred_res
diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py
index d8b317aa282bb74f7b35c6a4b6216446959fb30e..f5a4620e1ad2044a3ba0259b5ef533cb569ed102 100644
--- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py
+++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py
@@ -87,6 +87,7 @@ class DumpConst:
 def pretest_info_dump(name, out_feat, module, phase):
     if not DumpUtil.get_dump_switch():
         return
+    name = name.replace('*', '.')
     if phase == DumpConst.forward:
         api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs)
     elif phase == DumpConst.backward:
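Note (illustrative, not part of the patch): the hunks above switch the pretest naming convention from '*'-separated to '.'-separated API names — pretest_info_dump now rewrites the name with name.replace('*', '.'), and compare_output splits on '.' accordingly. A minimal sketch of the convention, using a hypothetical dump name:

# Hypothetical dump name; real names are produced by the pretest dump hooks.
raw_name = "Functional*conv2d*0"
dotted_name = raw_name.replace('*', '.')                  # what pretest_info_dump now stores
api_type, api_name, call_index = dotted_name.split(".")   # what compare_output now expects
assert (api_type, api_name, call_index) == ("Functional", "conv2d", "0")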
diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
index 05bd4305a3b8ac1366596d2947cafd467a784b1d..1f7cf8a691c67b130d41e306ba464f3aa18a774a 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
@@ -17,7 +17,7 @@ else:
 import torch
 from tqdm import tqdm
 from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
-from api_accuracy_checker.run_ut.run_ut_utils import Backward_Message
+from api_accuracy_checker.run_ut.run_ut_utils import Backward_Message, hf_32_standard_api
 from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \
     print_error_log, initialize_save_path, Const, create_directory
 from api_accuracy_checker.compare.compare import Comparator
@@ -77,7 +77,18 @@ def deal_detach(arg, to_detach=True):
     return arg.detach() if to_detach else arg
 
 
-def deal_dtype(arg, raise_dtype=None):
+def raise_bench_data_dtype(api_name, arg, raise_dtype=None):
+    '''
+    Convert the dtype of the benchmark data to raise_dtype.
+    Inputs:
+        api_name: name of the API under test
+        arg: benchmark input
+        raise_dtype: target dtype to raise the benchmark data to
+    Output:
+        arg: benchmark input converted to raise_dtype
+    '''
+    if api_name in hf_32_standard_api and arg.dtype == torch.float32:
+        return arg
     if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype:
         return arg
     return arg.type(raise_dtype)
@@ -112,13 +123,13 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
             return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
         elif isinstance(arg_in, torch.Tensor):
             if need_backward and arg_in.requires_grad:
-                arg_in = deal_detach(deal_dtype(arg_in.clone(), raise_dtype), to_detach).requires_grad_()
+                arg_in = deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype), to_detach).requires_grad_()
                 temp_arg_in = arg_in * 1
                 arg_in = temp_arg_in.type_as(arg_in)
                 arg_in.retain_grad()
                 return arg_in
             else:
-                return deal_detach(deal_dtype(arg_in.clone(), raise_dtype=raise_dtype), to_detach)
+                return deal_detach(raise_bench_data_dtype(api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach)
         else:
             return arg_in
 
@@ -170,7 +181,7 @@ def run_ut(config):
            continue
        try:
            if msCheckerConfig.white_list:
-                [_, api_name, _] = api_full_name.split("*")
+                [_, api_name, _] = api_full_name.split(".")
                if api_name not in set(msCheckerConfig.white_list):
                    continue
            data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict)
@@ -179,7 +190,7 @@ def run_ut(config):
            if config.save_error_data:
                do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
        except Exception as err:
-            [_, api_name, _] = api_full_name.split("*")
+            [_, api_name, _] = api_full_name.split(".")
            if "expected scalar type Long" in str(err):
                print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                               f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
@@ -201,20 +212,20 @@ def run_ut(config):
 
 def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
     if not is_fwd_success or not is_bwd_success:
-        api_full_name = api_full_name.replace("*", ".")
+        # api_full_name = api_full_name.replace("*", ".")
         for element in data_info.in_fwd_data_list:
             UtAPIInfo(api_full_name + '.forward.input', element)
-        UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out)
-        UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_out)
+        UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_output)
+        UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_output)
         UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in)
-        UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out)
-        UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad_out)
+        UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad)
+        UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad)
 
 
 def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict):
     in_fwd_data_list = []
     backward_message = ''
-    [api_type, api_name, _] = api_full_name.split("*")
+    [api_type, api_name, _] = api_full_name.split(".")
     args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path)
     in_fwd_data_list.append(args)
     in_fwd_data_list.append(kwargs)
@@ -246,7 +257,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
     if need_backward:
         if need_to_backward(grad_index, out):
             backward_args = backward_content[api_full_name]
-            grad = gen_args(backward_args, real_data_path=real_data_path)[0]
+            grad = gen_args(backward_args, api_name, real_data_path=real_data_path)[0]
             bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
             bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
             device_grad = grad.clone().detach().to(current_device)
@@ -262,7 +273,7 @@ def get_api_info(api_info_dict, api_name, real_data_path):
     need_grad = True
     if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"):
         need_grad = False
-    args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type, real_data_path)
+    args, kwargs = gen_api_params(api_info_dict, api_name, need_grad, convert_type, real_data_path)
     return args, kwargs, need_grad
 
 
@@ -366,29 +377,37 @@ def _run_ut_parser(parser):
 def preprocess_forward_content(forward_content):
     processed_content = {}
     base_keys_variants = {}
+    arg_cache = {}
+
     for key, value in forward_content.items():
         base_key = key.rsplit('*', 1)[0]
-        new_args = value['args']
-        new_kwargs = value['kwargs']
-        filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)]
-        if base_key in base_keys_variants:
+
+        if key not in arg_cache:
+            new_args = value['args']
+            new_kwargs = value['kwargs']
+            filtered_new_args = [
+                {k: v for k, v in arg.items() if k not in ['Max', 'Min']}
+                for arg in new_args if isinstance(arg, dict)
+            ]
+            arg_cache[key] = (filtered_new_args, new_kwargs)
+
+        filtered_new_args, new_kwargs = arg_cache[key]
+
+        if base_key not in base_keys_variants:
+            processed_content[key] = value
+            base_keys_variants[base_key] = {key}
+        else:
             is_duplicate = False
-            for variant in base_keys_variants.get(base_key, []):
-                try:
-                    existing_args = processed_content[variant].get('args', [])
-                    existing_kwargs = processed_content[variant].get('kwargs', {})
-                    filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)]
-                except KeyError as e:
-                    print_error_log(f"KeyError: {e} when processing {key}")
-                if filtered_existing_args == filtered_new_args and existing_kwargs == new_kwargs:
+            for variant in base_keys_variants[base_key]:
+                existing_args, existing_kwargs = arg_cache[variant]
+                if existing_args == filtered_new_args and existing_kwargs == new_kwargs:
                     is_duplicate = True
                     break
+
             if not is_duplicate:
                 processed_content[key] = value
-                base_keys_variants[base_key].append(key)
-        else:
-            processed_content[key] = value
-            base_keys_variants[base_key] = [key]
+                base_keys_variants[base_key].add(key)
+
     return processed_content
 
 
@@ -423,7 +442,9 @@ def run_ut_command(args):
     save_error_data = args.save_error_data
     forward_content = get_json_contents(forward_file)
     if args.filter_api:
+        print_info_log("Start filtering the api in the forward_input_file.")
         forward_content = preprocess_forward_content(forward_content)
+        print_info_log("Finish filtering the api in the forward_input_file.")
     backward_content = {}
     if args.backward_input_file:
         check_link(args.backward_input_file)
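Note (illustrative, not part of the patch): when args.filter_api is set, preprocess_forward_content now caches each call's Max/Min-stripped argument signature and keeps only one call per base API when the signatures match. A self-contained sketch of that dedup rule, using made-up entries:

# Toy re-statement of the dedup rule (not the actual preprocess_forward_content).
def _strip_max_min(args):
    return [{k: v for k, v in a.items() if k not in ('Max', 'Min')} for a in args if isinstance(a, dict)]

calls = {
    "Functional*relu*0": {"args": [{"type": "torch.Tensor", "dtype": "torch.float16", "Max": 3.1, "Min": -0.2}], "kwargs": {}},
    "Functional*relu*1": {"args": [{"type": "torch.Tensor", "dtype": "torch.float16", "Max": 7.9, "Min": -1.4}], "kwargs": {}},
}

kept = {}
for key, value in calls.items():
    base = key.rsplit('*', 1)[0]
    sig = (_strip_max_min(value["args"]), value["kwargs"])
    if not any(sig == (_strip_max_min(v["args"]), v["kwargs"])
               for k, v in kept.items() if k.rsplit('*', 1)[0] == base):
        kept[key] = value

assert list(kept) == ["Functional*relu*0"]  # the second call differs only in Max/Min, so it is filtered out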
print_info_log("Finish filtering the api in the forward_input_file.") backward_content = {} if args.backward_input_file: check_link(args.backward_input_file) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 2e49a9743b6a317ee401b2b8f0b31fb2ea68c07a..112b2102dc21294426d3a0623114408eaf49ee21 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -188,6 +188,39 @@ def dump_data(prefix, data_info): def thread_dump_data(prefix, data_info): DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info) +def underscore_replace(prefix): + """ + Replacing symbols to unify the format of ptdbg and pretest + """ + replaced_prefix = [] + consecutive_underscore_count = 0 + three_underscore_time = 0 + + for char in prefix: + if char == '_': + consecutive_underscore_count += 1 + if consecutive_underscore_count == 2: + # Two consecutive underscores, leave them unchanged + replaced_prefix.pop() + replaced_prefix.append('__') + elif consecutive_underscore_count == 3: + # Three consecutive underscores + three_underscore_time += 1 + replaced_prefix.pop() + if three_underscore_time % 2 == 1: + # Even index, replace the first underscore + replaced_prefix.append('.__') + else: + # Odd index, replace the third underscore + replaced_prefix.append('__.') + else: + # Single underscore, replace with a period + replaced_prefix.append('.') + else: + # Not an underscore, reset the count + consecutive_underscore_count = 0 + replaced_prefix.append(char) + return ''.join(replaced_prefix) def dump_data_by_rank_count(dump_step, prefix, data_info): print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r') @@ -213,7 +246,7 @@ def dump_stack_info(name_template): except Exception as e: print_warn_log("Dump stack info failed, error: {}".format(e)) stack_str.append('') - + prefix = name_template.format("stack_info") if DumpUtil.dump_switch_mode in Const.DUMP_MODE: complement_set = set(['forward', 'backward', 'input', 'output']) - set(DumpUtil.dump_mode) @@ -304,7 +337,7 @@ def dump_acc_cmp(name, in_feat, out_feat, dump_step, module): print_warn_log("The file does not exist, error: {}".format(e)) name_prefix = name - name_template = f"{name_prefix}" + "_{}" + name_template = f"{underscore_replace(name_prefix)}" + ".{}" if DumpUtil.is_single_rank is None: DumpUtil.is_single_rank = check_single_rank_folder(dump_dir) if DumpUtil.dump_switch_mode in [Const.ALL, Const.API_LIST]: diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py index 9673c292ba20cef94b926986ddac270f748645e9..56440bbe555541cdb356006acef5b65c4a4fc1cb 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py @@ -47,3 +47,12 @@ class TestDump(unittest.TestCase): result = get_pkl_file_path() self.assertEqual(result, "") + def test_underscore_replacement(self): + prefix = "Torch_matmul_605_forward_input.0" + replaced_prefix = underscore_replacement(prefix) + self.assertEqual(replaced_prefix, "Torch.matmul.605.forward.input.0") + + prefix = "Tensor___getitem___488_forward_stack_info" + replaced_prefix = underscore_replacement(prefix) + self.assertEqual(replaced_prefix, "Tensor.__getitem__.488.forward.stack_info") +