From b1297fbd8899265651ee3bf8ce3513bd1f8bcdd7 Mon Sep 17 00:00:00 2001 From: shawn_zhu1 Date: Fri, 5 Jul 2024 17:54:20 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Dcode=20check?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/atat/core/utils.py | 2 +- .../atat/pytorch/compare/acc_compare.py | 146 +++++++++++++----- .../atat/pytorch/compare/npy_compare.py | 4 + 3 files changed, 110 insertions(+), 42 deletions(-) diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/utils.py index e3a30579a80..224e30aef35 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/utils.py @@ -491,7 +491,7 @@ def get_dump_data_path(dump_dir): file_is_exist = False check_file_or_directory_path(dump_dir, True) - for dir_path, sub_paths, files in os.walk(dump_dir): + for dir_path, _, files in os.walk(dump_dir): if len(files) != 0: dump_data_path = dir_path file_is_exist = True diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index bd903ef2d3a..d3a072f5c78 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -28,6 +28,7 @@ import pandas as pd import openpyxl from openpyxl.styles import PatternFill from collections import namedtuple +from dataclasses import dataclass from .match import graph_mapping from .highlight import HighlightRules, get_header_index @@ -139,7 +140,7 @@ def merge_tensor(tensor_list, summary_compare, md5_compare): op_dict["summary"] = [] op_dict["stack_info"] = [] - all_mode_bool = summary_compare == False and md5_compare == False + all_mode_bool = not (summary_compare or md5_compare) if all_mode_bool: op_dict["data_name"] = [] @@ -192,7 +193,7 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals bench_stack_info = b_dict.get("stack_info", None) has_stack = npu_stack_info and bench_stack_info - all_mode_bool = summary_compare == False and md5_compare == False + all_mode_bool = not (summary_compare or md5_compare) if all_mode_bool: npu_data_name = n_dict.get("data_name", None) bench_data_name = b_dict.get("data_name", None) @@ -206,7 +207,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals err_msg = "" if md5_compare: result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], - n_struct[2], b_struct[2], CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF] + n_struct[2], b_struct[2], + CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF] if has_stack and index == 0 and key == "input_struct": result_item.extend(npu_stack_info) else: @@ -233,7 +235,7 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)): diff = npu_val - bench_val if bench_val != 0: - relative = str(abs((diff/bench_val) * 100)) + '%' + relative = str(abs((diff / bench_val) * 100)) + '%' else: relative = "N/A" result_item[start_idx + i] = diff @@ -245,7 +247,8 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals result_item[start_idx + i] = CompareConst.NONE accuracy_check = CompareConst.WARNING if warning_flag else "" err_msg += "Need double check api accuracy." 
if warning_flag else "" - result_item[start_idx:] = [f'{str(x)}\t' if str(x) in ('inf', '-inf', 'nan') else x for x in result_item[start_idx:]] + result_item[start_idx:] = [f'{str(x)}\t' if str(x) in ('inf', '-inf', 'nan') else x for x in + result_item[start_idx:]] result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES) result_item.append(err_msg) @@ -313,12 +316,10 @@ def read_dump_data(result_df): try: npu_dump_name_list = result_df.iloc[0:, 0].tolist() npu_dump_tensor_list = result_df.iloc[0:, -1].tolist() - # bench_dump_name_list = csv_pd.iloc[0:, 1].tolist() op_name_mapping_dict = {} for index, _ in enumerate(npu_dump_name_list): npu_dump_name = npu_dump_name_list[index] npu_dump_tensor = npu_dump_tensor_list[index] - # bench_dump_name = bench_dump_name_list[index] op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor] return op_name_mapping_dict except ValueError as e: @@ -352,8 +353,8 @@ def _handle_multi_process(func, input_parma, result_df, lock): for process_idx, df_chunk in enumerate(df_chunks): idx = df_chunk_size * process_idx result = pool.apply_async(func, - args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma), - error_callback=err_call) + args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma), + error_callback=err_call) results.append(result) final_results = [r.get() for r in results] pool.close() @@ -372,33 +373,67 @@ def compare_ops(idx, dump_path_dict, result_df, lock, input_parma): for i in range(len(result_df)): op_name = result_df.iloc[i, 0] if is_print_compare_log: - print("start compare: {}".format(op_name)) - cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = compare_by_op(op_name, dump_path_dict, input_parma) + print_info_log("start compare: {}".format(op_name)) + cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = compare_by_op( + op_name, dump_path_dict, input_parma) if is_print_compare_log: - print("[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg, one_thousand_err_ratio, five_thousand_err_ratio)) + print_info_log( + "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, " + "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg, + one_thousand_err_ratio, five_thousand_err_ratio)) cos_result.append(cos_sim) max_err_result.append(max_abs_err) max_relative_err_result.append(max_relative_err) err_mess.append(err_msg) one_thousand_err_ratio_result.append(one_thousand_err_ratio) five_thousand_err_ratio_result.append(five_thousand_err_ratio) - result_df = _save_cmp_result(idx, cos_result, max_err_result, max_relative_err_result, err_mess, one_thousand_err_ratio_result, - five_thousand_err_ratio_result, result_df, lock) - return result_df + + cr = ComparisonResult( + cos_result=cos_result, + max_err_result=max_err_result, + max_relative_err_result=max_relative_err_result, + err_msgs=err_mess, + one_thousand_err_ratio_result=one_thousand_err_ratio_result, + five_thousand_err_ratio_result=five_thousand_err_ratio_result + ) + + return _save_cmp_result(idx, cr, result_df, lock) -def _save_cmp_result(idx, cos_result, max_err_result, max_relative_err_result, err_msg, one_thousand_err_ratio_result, five_thousand_err_ratio_result, result_df, lock): +@dataclass +class ComparisonResult: + 
cos_result: list + max_err_result: list + max_relative_err_result: list + err_msgs: list + one_thousand_err_ratio_result: list + five_thousand_err_ratio_result: list + + +def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): + """ + Save comparison results into the result DataFrame with thread safety. + Args: + offset: offset for index + result: data struct of ComparisonResult + result_df: result of DataFrame + lock: thread lock + + Returns: + comparison results in DataFrame + """ + lock.acquire() try: - for i, _ in enumerate(cos_result): - process_index = i + idx - result_df.loc[process_index, CompareConst.COSINE] = cos_result[i] - result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = max_err_result[i] - result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = max_relative_err_result[i] - result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = err_msg[i] - result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(cos_result[i], max_err_result[i]) - result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = one_thousand_err_ratio_result[i] - result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = five_thousand_err_ratio_result[i] + for i, _ in enumerate(result.cos_result): + process_index = i + offset + result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i] + result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] + result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] + result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i] + result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(result.cos_result[i], result.max_err_result[i]) + result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[i] + result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[i] return result_df except ValueError as e: print_error_log('result dataframe is not found.') @@ -510,9 +545,9 @@ def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compa continue if not isinstance(api_out[npu_max_index], (float, int)) \ or not isinstance(api_out[bench_max_index], (float, int)) \ - or not isinstance(api_out[max_diff_index],(float, int)): + or not isinstance(api_out[max_diff_index], (float, int)): continue - for m, api_in in enumerate(result[0:n_num_input]): + for _, api_in in enumerate(result[0:n_num_input]): if not isinstance(api_in[npu_max_index], (float, int)) \ or not isinstance(api_in[bench_max_index], (float, int)) \ or not isinstance(api_in[max_diff_index], (float, int)): @@ -558,7 +593,8 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare): num, last_state = 1, state else: output_num = num - find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare) + find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, + summary_compare) num, last_api_name, last_state = 1, api_name, state start += input_num + output_num input_num, output_num = 1, 0 @@ -614,8 +650,32 @@ def compare(input_parma, output_path, stack_mode=False, auto_analyze=True, md5_compare=md5_compare) -def compare_core(input_parma, output_path, stack_mode=False, auto_analyze=True, - suffix='', fuzzy_match=False, summary_compare=False, md5_compare=False): +def compare_core(input_parma, output_path, **kwargs): + """ + Compares data from 
multiple JSON files and generates a comparison report. + + Args: + input_parma (dict): A dictionary containing paths to JSON files ("npu_json_path", "bench_json_path", + "stack_json_path"). + output_path (str): The path where the output Excel report will be saved. + **kwargs: Additional keyword arguments including: + - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. + - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. + - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. + - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. + - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False. + - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False. + + Returns: + """ + # get kwargs or set default value + stack_mode = kwargs.get('stack_mode', False) + auto_analyze = kwargs.get('auto_analyze', True) + suffix = kwargs.get('suffix', '') + fuzzy_match = kwargs.get('fuzzy_match', False) + summary_compare = kwargs.get('summary_compare', False) + md5_compare = kwargs.get('md5_compare', False) + print_info_log("Please check whether the input data belongs to you. If not, there may be security risks.") file_name = add_time_with_xlsx("compare_result" + suffix) file_path = os.path.join(os.path.realpath(output_path), file_name) @@ -658,29 +718,33 @@ def parse(pkl_file, module_name_prefix): continue if info_prefix.find("stack_info") != -1: - print("\nTrace back({}):".format(msg[0])) + print_info_log("\nTrace back({}):".format(msg[0])) for item in reversed(msg[1]): - print(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) - print(" {}".format(item[3])) + print_info_log(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) + print_info_log(" {}".format(item[3])) continue if len(msg) > 5: summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \ .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2]) if not title_printed: - print("\nStatistic Info:") + print_info_log("\nStatistic Info:") title_printed = True - print(summary_info) + print_info_log(summary_info) -def op_item_parse(item, op_name, index, item_list=[], top_bool=True): - if item == None or (isinstance(item, dict) and len(item) == 0): +def op_item_parse(item, op_name, index, item_list=None, top_bool=True): + if item_list is None: + item_list = [] + if item is None or (isinstance(item, dict) and not item): if not top_bool: - tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'} + tmp = {'full_op_name': op_name + '.' 
+ str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, + 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'} else: - tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'} + tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None, + 'shape': None, 'md5': None, 'data_name': '-1'} item_list.append(tmp) return item_list - if index == None: + if index is None: if isinstance(item, dict): full_op_name = op_name + '.0' else: @@ -889,7 +953,7 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False else: header = CompareConst.COMPARE_RESULT_HEADER[:] - all_mode_bool = summary_compare == False and md5_compare == False + all_mode_bool = not (summary_compare or md5_compare) if stack_mode: if all_mode_bool: header.append(CompareConst.STACK) diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py index f16a807fefd..b94a83f1349 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py @@ -196,6 +196,8 @@ class GetThousandErrRatio(TensorComparisonBasic): return CompareConst.NAN, "" if relative_err is None: relative_err = get_relative_err(n_value, b_value) + if not np.size(relative_err): + return CompareConst.NAN, "" return format_value(np.sum(relative_err < CompareConst.THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), "" @@ -216,6 +218,8 @@ class GetFiveThousandErrRatio(TensorComparisonBasic): return CompareConst.NAN, "" if relative_err is None: relative_err = get_relative_err(n_value, b_value) + if not np.size(relative_err): + return CompareConst.NAN, "" return format_value(np.sum(relative_err < CompareConst.FIVE_THOUSAND_RATIO_THRESHOLD) / np.size(relative_err)), "" -- Gitee From 91bb15ab17e91d4c83ddff9a5f9f06d90cddffcd Mon Sep 17 00:00:00 2001 From: shawn_zhu1 Date: Sat, 6 Jul 2024 16:30:37 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91code=20?= =?UTF-8?q?check=E6=95=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/atat/atat.py | 12 +-- .../atat/core/file_check_util.py | 2 +- debug/accuracy_tools/atat/core/utils.py | 40 +--------- .../atat/pytorch/advisor/advisor.py | 69 +++++++++--------- .../compare/api_precision_compare.py | 46 ++++++------ .../api_accuracy_checker/compare/compare.py | 73 +++++++++---------- .../api_accuracy_checker/dump/api_info.py | 31 ++++---- .../run_ut/run_overflow_check.py | 2 +- .../api_accuracy_checker/run_ut/run_ut.py | 2 +- .../api_accuracy_checker/test/run_ut.py | 30 +++++--- .../test/ut/run_ut/test_data_generate.py | 17 +++-- .../atat/pytorch/common/exceptions.py | 15 +++- .../accuracy_tools/atat/pytorch/common/log.py | 16 +++- .../atat/pytorch/common/recursive.py | 11 ++- .../atat/pytorch/common/utils.py | 9 ++- .../atat/pytorch/compare/acc_compare.py | 12 ++- .../atat/pytorch/debugger/debugger_config.py | 11 +-- .../pytorch/debugger/precision_debugger.py | 6 +- debug/accuracy_tools/atat/pytorch/service.py | 22 +++--- 19 files changed, 224 insertions(+), 202 deletions(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index 799200ae41c..12c4042bee9 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -18,16 +18,18 @@ import sys from 
atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse from atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut -from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, _api_precision_compare_command -from atat.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command +from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \ + _api_precision_compare_command +from atat.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \ + _run_overflow_check_command def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description="atat(ascend training accuracy tools), [Powered by MindStudio].\n" - "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" - f"For any issue, refer README.md first", + "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" + f"For any issue, refer README.md first", ) parser.set_defaults(print_help=parser.print_help) parser.add_argument('-f', '--framework', required=True, choices=['pytorch'], @@ -62,4 +64,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/debug/accuracy_tools/atat/core/file_check_util.py b/debug/accuracy_tools/atat/core/file_check_util.py index b10cdd61049..7cb071bd652 100644 --- a/debug/accuracy_tools/atat/core/file_check_util.py +++ b/debug/accuracy_tools/atat/core/file_check_util.py @@ -241,7 +241,7 @@ def _user_interactive_confirm(message): print_warn_log("User canceled.") raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) else: - print("Input is error, please enter 'c' or 'e'.") + print_error_log("Input is error, please enter 'c' or 'e'.") def check_path_owner_consistent(path): diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/utils.py index 224e30aef35..25ddf51b2c9 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/utils.py @@ -20,15 +20,14 @@ import re import shutil import stat import subprocess -import sys import time import json -from json.decoder import JSONDecodeError from datetime import datetime, timezone from pathlib import Path import numpy as np from .file_check_util import FileOpen, FileChecker, FileCheckConst +from .log import print_info_log, print_warn_log, print_error_log device = collections.namedtuple('device', ['type', 'index']) @@ -271,43 +270,6 @@ def make_dump_path_if_not_exists(dump_path): print_error_log('{} already exists and is not a directory.'.format(dump_path)) -def _print_log(level, msg, end='\n'): - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getgid() - print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) - sys.stdout.flush() - - -def print_info_log(info_msg, end='\n'): - """ - Function Description: - print info log. - Parameter: - info_msg: the info message. - """ - _print_log("INFO", info_msg, end=end) - - -def print_error_log(error_msg): - """ - Function Description: - print error log. - Parameter: - error_msg: the error message. 
- """ - _print_log("ERROR", error_msg) - - -def print_warn_log(warn_msg): - """ - Function Description: - print warn log. - Parameter: - warn_msg: the warning message. - """ - _print_log("WARNING", warn_msg) - - def check_mode_valid(mode, scope=None, api_list=None): if scope is None: scope = [] diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py index 5ae692a998d..db193dcd833 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py @@ -32,40 +32,7 @@ class Advisor: def __init__(self, input_data, out_path=""): self.input_data = input_data self.out_path = os.path.realpath(out_path) - - def _parse_input_data(self): - data_columns = self.input_data.columns.values - if {CompareConst.ACCURACY, CompareConst.NPU_NAME}.issubset(data_columns): - self.file_type = Const.ALL - elif {CompareConst.RESULT, CompareConst.NPU_MD5}.issubset(data_columns): - self.file_type = Const.MD5 - elif {CompareConst.MAX_DIFF, CompareConst.RESULT}.issubset(data_columns): - self.file_type = Const.SUMMARY - else: - print_error_log('Compare result does not meet the required conditions.') - raise CompareException(CompareException.INVALID_DATA_ERROR) - df = self.input_data.reset_index() - return df - - def _check_path_vaild(self): - out_path_checker = FileChecker(self.out_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) - out_path_checker.common_check() - - def gen_advisor_message(self, node_name): - if AdvisorConst.FORWARD in node_name: - if AdvisorConst.INPUT in node_name: - message = AdvisorConst.FORWARD_INPUT_SUGGEST - else: - message = AdvisorConst.FORWARD_OUTPUT_SUGGEST - message = self.deterministic_advisor(message, node_name) - else: - if AdvisorConst.INPUT in node_name: - message = AdvisorConst.BACKWARD_INPUT_SUGGEST - else: - message = AdvisorConst.BACKWARD_OUTPUT_SUGGEST - message = self.deterministic_advisor(message, node_name) - message = self.batch_norm_advisor(message, node_name) - return message + self.file_type = None @staticmethod def deterministic_advisor(message, node_name): @@ -102,6 +69,22 @@ class Advisor: result = AdvisorResult(node_name, index, message) return result + def gen_advisor_message(self, node_name): + if AdvisorConst.FORWARD in node_name: + if AdvisorConst.INPUT in node_name: + message = AdvisorConst.FORWARD_INPUT_SUGGEST + else: + message = AdvisorConst.FORWARD_OUTPUT_SUGGEST + message = self.deterministic_advisor(message, node_name) + else: + if AdvisorConst.INPUT in node_name: + message = AdvisorConst.BACKWARD_INPUT_SUGGEST + else: + message = AdvisorConst.BACKWARD_OUTPUT_SUGGEST + message = self.deterministic_advisor(message, node_name) + message = self.batch_norm_advisor(message, node_name) + return message + def analysis(self): self._check_path_vaild() analyze_data = self._parse_input_data() @@ -120,3 +103,21 @@ class Advisor: result = self.gen_advisor_result(failing_data) message_list = result.print_advisor_log() result.gen_summary_file(self.out_path, message_list) + + def _parse_input_data(self): + data_columns = self.input_data.columns.values + if {CompareConst.ACCURACY, CompareConst.NPU_NAME}.issubset(data_columns): + self.file_type = Const.ALL + elif {CompareConst.RESULT, CompareConst.NPU_MD5}.issubset(data_columns): + self.file_type = Const.MD5 + elif {CompareConst.MAX_DIFF, CompareConst.RESULT}.issubset(data_columns): + self.file_type = Const.SUMMARY + else: + print_error_log('Compare result does not meet the required conditions.') + raise 
CompareException(CompareException.INVALID_DATA_ERROR) + df = self.input_data.reset_index() + return df + + def _check_path_vaild(self): + out_path_checker = FileChecker(self.out_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) + out_path_checker.common_check() diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 6a544de21a0..9484833e52c 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -83,6 +83,24 @@ class BenchmarkStandard: def __str__(self): return "%s" % (self.api_name) + @staticmethod + def _get_status(ratio, algorithm): + error_threshold = benchmark_algorithms_thresholds.get(algorithm).get('error_threshold') + warning_threshold = benchmark_algorithms_thresholds.get(algorithm).get('warning_threshold') + if ratio > error_threshold: + return CompareConst.ERROR + elif ratio > warning_threshold: + return CompareConst.WARNING + return CompareConst.PASS + + @staticmethod + def _calc_ratio(x, y, default_value=1.0): + x, y = convert_str_to_float(x), convert_str_to_float(y) + if math.isclose(y, 0.0): + return 1.0 if math.isclose(x, 0.0) else default_value + else: + return abs(x / y) + def get_result(self): self._compare_ratio() self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') @@ -99,6 +117,11 @@ class BenchmarkStandard: elif CompareConst.WARNING in self.check_result_list: self.final_result = CompareConst.WARNING + def to_column_value(self): + return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, + self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, + self.mean_rel_err_status, self.eb_ratio, self.eb_status] + def _compare_ratio(self): self.small_value_err_ratio = self._calc_ratio( self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), @@ -114,29 +137,6 @@ class BenchmarkStandard: self.eb_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.EB), self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0) - def to_column_value(self): - return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, - self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, - self.mean_rel_err_status, self.eb_ratio, self.eb_status] - - @staticmethod - def _get_status(ratio, algorithm): - error_threshold = benchmark_algorithms_thresholds.get(algorithm).get('error_threshold') - warning_threshold = benchmark_algorithms_thresholds.get(algorithm).get('warning_threshold') - if ratio > error_threshold: - return CompareConst.ERROR - elif ratio > warning_threshold: - return CompareConst.WARNING - return CompareConst.PASS - - @staticmethod - def _calc_ratio(x, y, default_value=1.0): - x, y = convert_str_to_float(x), convert_str_to_float(y) - if math.isclose(y, 0.0): - return 1.0 if math.isclose(x, 0.0) else default_value - else: - return abs(x / y) - def write_detail_csv(content, save_path): rows = [] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py index 67aa69e209b..350a4077475 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py @@ -39,6 +39,42 @@ class Comparator: "total_num": 0, "forward_or_backward_fail_num": 0 } + @staticmethod + def _compare_dropout(api_name, bench_output, device_output): + tensor_num = bench_output.numel() + if tensor_num >= 100: + if abs((bench_output == 0).sum() - (device_output == 0).cpu().sum()) / tensor_num < 0.1: + return CompareConst.PASS, 1 + else: + return CompareConst.ERROR, 0 + else: + return CompareConst.PASS, 1 + + @staticmethod + def _compare_builtin_type(bench_output, device_output, compare_column): + if not isinstance(bench_output, (bool, int, float, str)): + return CompareConst.PASS, compare_column, "" + if bench_output != device_output: + return CompareConst.ERROR, compare_column, "" + compare_column.error_rate = 0 + return CompareConst.PASS, compare_column, "" + + @staticmethod + def _compare_bool_tensor(bench_output, device_output): + error_nums = (bench_output != device_output).sum() + if bench_output.size == 0: + return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." + error_rate = float(error_nums / bench_output.size) + result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR + return error_rate, result, "" + + @staticmethod + def _get_absolute_threshold_attribute(api_name, dtype): + small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value') + small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol') + rtol = apis_threshold.get(api_name).get(dtype).get('rtol') + return small_value_threshold, small_value_atol, rtol + def print_pretest_result(self): self.get_statistics_from_result_csv() total_tests = self.test_result_cnt.get("total_num", 0) @@ -333,40 +369,3 @@ class Comparator: return CompareConst.WARNING, compare_column, message message += "Relative error is less than 0.0001, consider as pass.\n" return CompareConst.PASS, compare_column, message - - @staticmethod - def _compare_dropout(api_name, bench_output, device_output): - tensor_num = bench_output.numel() - if tensor_num >= 100: - if abs((bench_output == 0).sum() - (device_output == 0).cpu().sum()) / tensor_num < 0.1: - return CompareConst.PASS, 1 - else: - return CompareConst.ERROR, 0 - else: - return CompareConst.PASS, 1 - - @staticmethod - def _compare_builtin_type(bench_output, device_output, compare_column): - if not isinstance(bench_output, (bool, int, float, str)): - return CompareConst.PASS, compare_column, "" - if bench_output != device_output: - return CompareConst.ERROR, compare_column, "" - compare_column.error_rate = 0 - return CompareConst.PASS, compare_column, "" - - - @staticmethod - def _compare_bool_tensor(bench_output, device_output): - error_nums = (bench_output != device_output).sum() - if bench_output.size == 0: - return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." 
- error_rate = float(error_nums / bench_output.size) - result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR - return error_rate, result, "" - - @staticmethod - def _get_absolute_threshold_attribute(api_name, dtype): - small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value') - small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol') - rtol = apis_threshold.get(api_name).get(dtype).get('rtol') - return small_value_threshold, small_value_atol, rtol diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py index 7452cec74e8..a8c59a229fe 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py @@ -104,6 +104,22 @@ class APIInfo: else: return os.path.join(save_path, dir_name) + @staticmethod + def _convert_numpy_to_builtin(arg): + type_mapping = { + np.integer: int, + np.floating: float, + np.bool_: bool, + np.complexfloating: complex, + np.str_: str, + np.bytes_: bytes, + np.unicode_: str + } + for numpy_type, builtin_type in type_mapping.items(): + if isinstance(arg, numpy_type): + return builtin_type(arg), get_type_name(str(type(arg))) + return arg, '' + def analyze_element(self, element): if isinstance(element, (list, tuple)): out = [] @@ -180,21 +196,6 @@ class APIInfo: single_arg.update({'type': numpy_type}) single_arg.update({'value': value}) return single_arg - - def _convert_numpy_to_builtin(self, arg): - type_mapping = { - np.integer: int, - np.floating: float, - np.bool_: bool, - np.complexfloating: complex, - np.str_: str, - np.bytes_: bytes, - np.unicode_: str - } - for numpy_type, builtin_type in type_mapping.items(): - if isinstance(arg, numpy_type): - return builtin_type(arg), get_type_name(str(type(arg))) - return arg, '' class ForwardAPIInfo(APIInfo): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index bea882f7507..86c52b23bf8 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -36,7 +36,7 @@ def check_tensor_overflow(x): def check_data_overflow(x): if isinstance(x, (tuple, list)) and x: - for i, item in enumerate(x): + for _, item in enumerate(x): if check_data_overflow(item): return True return False diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index 3186913e948..b0eac9da4bd 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -160,7 +160,7 @@ def run_ut(config): csv_reader = csv.reader(file) next(csv_reader) api_name_set = {row[0] for row in csv_reader} - for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): + for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): if api_full_name in api_name_set: continue if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py index c7394969794..16e8f091014 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py @@ -3,6 +3,7 @@ import shutil import subprocess import sys + def run_ut(): cur_dir = os.path.realpath(os.path.dirname(__file__)) top_dir = os.path.realpath(os.path.dirname(cur_dir)) @@ -10,14 +11,25 @@ def run_ut(): src_dir = top_dir report_dir = os.path.join(cur_dir, "report") + # cleanup and recreate report dir if os.path.exists(report_dir): shutil.rmtree(report_dir) - os.makedirs(report_dir) - cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + report_dir + "/final.xml", - "--cov=" + src_dir, "--cov-branch", "--cov-report=xml:" + report_dir + "/coverage.xml"] - + # set paths for multi-platform compatibility + junit_report = os.path.join(report_dir, "final.xml") + coverage_report = os.path.join(report_dir, "coverage.xml") + + cmd = [ + "python3", + "-m", "pytest", + ut_path, + f"--junitxml={junit_report}", + f"--cov={src_dir}", + "--cov-branch", + f"--cov-report=xml:{coverage_report}" + ] + result_ut = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) while result_ut.poll() is None: @@ -25,16 +37,16 @@ def run_ut(): if line: print(line) - ut_flag = False - if result_ut.returncode == 0: - ut_flag = True + tests_passed = result_ut.returncode == 0 + if tests_passed: print("run ut successfully.") else: print("run ut failed.") - return ut_flag + return tests_passed + -if __name__=="__main__": +if __name__ == "__main__": if run_ut(): sys.exit(0) else: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py index b98f84d5164..2594b58dc26 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py @@ -16,6 +16,7 @@ for api_full_name, api_info_dict in forward_content.items(): max_value = 5.7421875 min_value = -5.125 + class TestDataGenerateMethods(unittest.TestCase): def test_gen_api_params(self): api_info = copy.deepcopy(api_info_dict) @@ -44,7 +45,7 @@ class TestDataGenerateMethods(unittest.TestCase): max_diff = abs(data.max() - max_value) min_diff = abs(data.min() - min_value) self.assertEqual(data.dtype, torch.float32) - self.assertEqual(data.requires_grad, True) + self.assertTrue(data.requires_grad) self.assertLessEqual(max_diff, 0.001) self.assertLessEqual(min_diff, 0.001) self.assertEqual(data.shape, torch.Size([2, 2560, 24, 24])) @@ -53,23 +54,23 @@ class TestDataGenerateMethods(unittest.TestCase): api_info = copy.deepcopy(api_info_dict) kwargs_params = gen_kwargs(api_info, None) self.assertEqual(kwargs_params, {'inplace': False}) - + def test_gen_kwargs_2(self): k_dict = {"inplace": {"type": "bool", "value": "False"}} for key, value in k_dict.items(): gen_torch_kwargs(k_dict, key, value) self.assertEqual(k_dict, {'inplace': False}) - + def test_gen_random_tensor(self): data = gen_random_tensor(api_info_dict.get('args')[0], None) max_diff = abs(data.max() - max_value) min_diff = abs(data.min() - min_value) self.assertEqual(data.dtype, torch.float32) - self.assertEqual(data.requires_grad, False) + self.assertFalse(data.requires_grad) self.assertLessEqual(max_diff, 0.001) self.assertLessEqual(min_diff, 0.001) 
self.assertEqual(data.shape, torch.Size([2, 2560, 24, 24]))
-    
+
     def test_gen_common_tensor(self):
         info = api_info_dict.get('args')[0]
         low, high = info.get('Min'), info.get('Max')
@@ -82,14 +83,14 @@ class TestDataGenerateMethods(unittest.TestCase):
         max_diff = abs(data.max() - max_value)
         min_diff = abs(data.min() - min_value)
         self.assertEqual(data.dtype, torch.float32)
-        self.assertEqual(data.requires_grad, False)
+        self.assertFalse(data.requires_grad)
         self.assertLessEqual(max_diff, 0.001)
         self.assertLessEqual(min_diff, 0.001)
         self.assertEqual(data.shape, torch.Size([2, 2560, 24, 24]))
-    
+
     def test_gen_bool_tensor(self):
         info = {"type": "torch.Tensor", "dtype": "torch.bool", "shape": [1, 1, 160, 256], \
-                "Max": 1, "Min": 0, "requires_grad": False}
+               "Max": 1, "Min": 0, "requires_grad": False}
         low, high = info.get("Min"), info.get("Max")
         shape = tuple(info.get("shape"))
         data = gen_bool_tensor(low, high, shape)
diff --git a/debug/accuracy_tools/atat/pytorch/common/exceptions.py b/debug/accuracy_tools/atat/pytorch/common/exceptions.py
index 17733b5bfd5..9e6d78cb4f0 100644
--- a/debug/accuracy_tools/atat/pytorch/common/exceptions.py
+++ b/debug/accuracy_tools/atat/pytorch/common/exceptions.py
@@ -1,6 +1,6 @@
-
 class CodedException(Exception):
     def __init__(self, code, error_info=''):
+        super().__init__()
         self.error_info = self.err_strs.get(code) + error_info
 
     def __str__(self):
@@ -10,7 +10,7 @@ class CodedException(Exception):
 class MsaccException(CodedException):
     INVALID_PARAM_ERROR = 0
     OVERFLOW_NUMS_ERROR = 1
-    
+
     err_strs = {
         INVALID_PARAM_ERROR: "[msacc] 无效参数: ",
         OVERFLOW_NUMS_ERROR: "[msacc] 超过预设溢出次数 当前溢出次数:"
@@ -68,8 +68,17 @@ class StepException(CodedException):
         InvalidPostProcess: "[msacc] 错误的step后处理配置: ",
     }
 
+
 class FreeBenchmarkException(CodedException):
     UnsupportedType = 0
     err_strs = {
         UnsupportedType: "[msacc] Free benchmark get unsupported type: "
-    }
\ No newline at end of file
+    }
+
+
+class DistributedNotInitializedError(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return self.msg
diff --git a/debug/accuracy_tools/atat/pytorch/common/log.py b/debug/accuracy_tools/atat/pytorch/common/log.py
index fab5aca45c0..c37887e2b61 100644
--- a/debug/accuracy_tools/atat/pytorch/common/log.py
+++ b/debug/accuracy_tools/atat/pytorch/common/log.py
@@ -1,14 +1,23 @@
 import os
 import time
 import sys
+
 from .utils import get_rank_if_initialized
+from .exceptions import DistributedNotInitializedError
 
 
 def on_rank_0(func):
     def func_rank_0(*args, **kwargs):
-        current_rank = get_rank_if_initialized()
+        try:
+            current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            current_rank = None
+
         if current_rank is None or current_rank == 0:
             return func(*args, **kwargs)
+        else:
+            raise RuntimeError("Function can only be called on rank 0 or when the distributed environment is not initialized. 
" + f"current rank: {current_rank}") return func_rank_0 @@ -17,7 +26,10 @@ def _print_log(level, msg, end='\n'): current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) pid = os.getpid() full_msg = current_time + "(" + str(pid) + ")-[" + level + "]" + msg - current_rank = get_rank_if_initialized() + try: + current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + current_rank = None if current_rank is not None: full_msg = f"[rank {current_rank}]-" + full_msg print(full_msg, end=end) diff --git a/debug/accuracy_tools/atat/pytorch/common/recursive.py b/debug/accuracy_tools/atat/pytorch/common/recursive.py index c8a19a63117..9b222f5f521 100644 --- a/debug/accuracy_tools/atat/pytorch/common/recursive.py +++ b/debug/accuracy_tools/atat/pytorch/common/recursive.py @@ -1,10 +1,13 @@ -import torch import numpy as np +import torch + from .log import print_warn_log _recursive_key_stack = [] special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, np.integer, np.floating, np.bool_, np.complexfloating, \ np.str_, np.byte, np.unicode_, bool, int, float, str, slice) + + def recursive_apply_transform(args, transform): global _recursive_key_stack if isinstance(args, special_type): @@ -18,11 +21,11 @@ def recursive_apply_transform(args, transform): _recursive_key_stack.pop() return type(args)(transform_result) elif isinstance(args, dict): - transform_result = {} + transform_dict = {} for k, arg in args.items(): _recursive_key_stack.append(str(k)) - transform_result[k] = recursive_apply_transform(arg, transform) + transform_dict[k] = recursive_apply_transform(arg, transform) _recursive_key_stack.pop() - return transform_result + return transform_dict elif args is not None: print_warn_log(f"Data type {type(args)} is not supported.") diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py index e88d506b2c3..bbb154d7bd6 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/common/utils.py @@ -21,6 +21,9 @@ import stat import torch import numpy as np from functools import wraps + +from .exceptions import DistributedNotInitializedError + try: import torch_npu except ImportError: @@ -93,9 +96,13 @@ def torch_device_guard(func): def get_rank_if_initialized(): + """ + return rank id if it is initialized or raise Exception: DistributedNotInitializedError + """ if torch.distributed.is_initialized(): return torch.distributed.get_rank() - return None + else: + raise DistributedNotInitializedError("torch distributed environment is not initialized") def seed_all(seed=1234, mode=False): diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index d3a072f5c78..6ef6795f8b2 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -508,7 +508,11 @@ def handle_inf_nan(n_value, b_value): b_inf = np.isinf(b_value) n_nan = np.isnan(n_value) b_nan = np.isnan(b_value) - if np.any(n_inf) or np.any(b_inf) or np.any(n_nan) or np.any(b_nan): + + # merge boolean expressions + any_inf = np.any(n_inf) or np.any(b_inf) + any_nan = np.any(n_nan) or np.any(b_nan) + if any_inf or any_nan: if np.array_equal(n_inf, b_inf) and np.array_equal(n_nan, b_nan): n_value[n_inf] = 0 b_value[b_inf] = 0 @@ -926,8 +930,10 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False except StopIteration: 
read_err_bench = False - if len(npu_ops_queue) == 0 or len(bench_ops_queue) == 0 or ( - len(npu_ops_queue) == last_npu_ops_len and len(bench_ops_queue) == last_bench_ops_len): + # merge all boolean expressions + both_empty = not npu_ops_queue and not bench_ops_queue + no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len) + if both_empty or no_change: continue n_match_point, b_match_point = match_op(npu_ops_queue, bench_ops_queue, fuzzy_match) diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py index 9fc97332f64..2d4147cd67c 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py @@ -67,6 +67,12 @@ class DebuggerConfig: self._check_step() return True + def check_model(self, model): + if self.level in ["L0", "mix"] and not model: + raise Exception( + f"For level {self.level}, PrecisionDebugger must receive a model argument.", + ) + def _check_rank(self): if self.rank: for rank_id in self.rank: @@ -81,8 +87,3 @@ class DebuggerConfig: if not isinstance(s, int) or s < 0: raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.") - def check_model(self, model): - if self.level in ["L0", "mix"] and not model: - raise Exception( - f"For level {self.level}, PrecisionDebugger must receive a model argument.", - ) \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py index e0ffa4e4d6e..8d67ae9ba6e 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py @@ -42,6 +42,10 @@ class PrecisionDebugger: print_warn_log_rank_0("The enable_dataloader feature will be deprecated in the future.") dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) + @property + def instance(self): + return self._instance + @classmethod def start(cls): instance = cls._instance @@ -79,7 +83,7 @@ class PrecisionDebugger: def iter_tracer(func): def func_wrapper(*args, **kwargs): - debugger_instance = PrecisionDebugger._instance + debugger_instance = PrecisionDebugger.instance debugger_instance.enable_dataloader = False if not debugger_instance.service.first_start: debugger_instance.stop() diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py index 9c079aedebe..093510b4ed9 100644 --- a/debug/accuracy_tools/atat/pytorch/service.py +++ b/debug/accuracy_tools/atat/pytorch/service.py @@ -1,15 +1,15 @@ +import functools import os from pathlib import Path -import functools -import torch + +from .common import print_info_log_rank_0 +from .common.file_check import FileChecker, FileCheckConst, check_path_before_create +from .common.utils import get_rank_if_initialized, is_gpu, Const from .functional import build_repair, build_data_collector, build_step_post_process +from .functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from .functional.scope import BaseScope -from .common.utils import get_rank_if_initialized, is_gpu, Const -from .common.file_check import FileChecker, FileCheckConst, check_path_before_create -from .common import print_info_log_rank_0 -from .hook_module.api_registry import api_register from .hook_module import remove_dropout -from 
.functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from .hook_module.api_registry import api_register from .module_processer import ModuleProcesser @@ -29,6 +29,7 @@ class Service: self.first_start = True self.current_rank = None self.first_touch_dir = True + self.dump_iter_dir = None def build_hook(self, module_type, name): def pre_hook(repair, api_or_module_name, module, args, kwargs): @@ -53,7 +54,7 @@ class Service: self.data_collector.visit_and_clear_overflow_status(api_or_module_name) if not self.switch: - return + return None if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output) @@ -151,12 +152,13 @@ class Service: print_info_log_rank_0("The {} hook function is successfully mounted to the model.".format(hook_name)) if self.config.level in ["L0", "mix"]: - assert self.model is not None + if self.model is None: + raise Exception("Model is None") print_info_log_rank_0("The init dump mode is enabled, and the module dump function will not be available") for name, module in self.model.named_modules(): if module == self.model: continue - prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP +\ + prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \ module.__class__.__name__ + Const.SEP pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix) -- Gitee From fbf67637253ea5fe46f3b4d5a9d6f74e47bc921f Mon Sep 17 00:00:00 2001 From: shawn_zhu1 Date: Sat, 6 Jul 2024 16:30:37 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E3=80=90=E4=BC=98=E5=8C=96=E3=80=91code=20?= =?UTF-8?q?check=E6=95=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/atat/atat.py | 12 +-- .../atat/core/file_check_util.py | 2 +- debug/accuracy_tools/atat/core/utils.py | 40 +--------- .../atat/pytorch/advisor/advisor.py | 69 +++++++++--------- .../compare/api_precision_compare.py | 46 ++++++------ .../api_accuracy_checker/compare/compare.py | 73 +++++++++---------- .../api_accuracy_checker/dump/api_info.py | 31 ++++---- .../run_ut/run_overflow_check.py | 2 +- .../api_accuracy_checker/run_ut/run_ut.py | 2 +- .../api_accuracy_checker/test/run_ut.py | 30 +++++--- .../test/ut/run_ut/test_data_generate.py | 17 +++-- .../atat/pytorch/common/exceptions.py | 16 +++- .../accuracy_tools/atat/pytorch/common/log.py | 16 +++- .../atat/pytorch/common/recursive.py | 11 ++- .../atat/pytorch/common/utils.py | 9 ++- .../atat/pytorch/compare/acc_compare.py | 12 ++- .../atat/pytorch/debugger/debugger_config.py | 11 +-- .../pytorch/debugger/precision_debugger.py | 6 +- debug/accuracy_tools/atat/pytorch/service.py | 22 +++--- 19 files changed, 225 insertions(+), 202 deletions(-) diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/atat/atat.py index 799200ae41c..12c4042bee9 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/atat/atat.py @@ -18,16 +18,18 @@ import sys from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse from atat.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut -from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, _api_precision_compare_command -from 
atat.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command +from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \ + _api_precision_compare_command +from atat.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \ + _run_overflow_check_command def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, description="atat(ascend training accuracy tools), [Powered by MindStudio].\n" - "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" - f"For any issue, refer README.md first", + "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" + f"For any issue, refer README.md first", ) parser.set_defaults(print_help=parser.print_help) parser.add_argument('-f', '--framework', required=True, choices=['pytorch'], @@ -62,4 +64,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/debug/accuracy_tools/atat/core/file_check_util.py b/debug/accuracy_tools/atat/core/file_check_util.py index b10cdd61049..7cb071bd652 100644 --- a/debug/accuracy_tools/atat/core/file_check_util.py +++ b/debug/accuracy_tools/atat/core/file_check_util.py @@ -241,7 +241,7 @@ def _user_interactive_confirm(message): print_warn_log("User canceled.") raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) else: - print("Input is error, please enter 'c' or 'e'.") + print_error_log("Input is error, please enter 'c' or 'e'.") def check_path_owner_consistent(path): diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/utils.py index 224e30aef35..25ddf51b2c9 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/utils.py @@ -20,15 +20,14 @@ import re import shutil import stat import subprocess -import sys import time import json -from json.decoder import JSONDecodeError from datetime import datetime, timezone from pathlib import Path import numpy as np from .file_check_util import FileOpen, FileChecker, FileCheckConst +from .log import print_info_log, print_warn_log, print_error_log device = collections.namedtuple('device', ['type', 'index']) @@ -271,43 +270,6 @@ def make_dump_path_if_not_exists(dump_path): print_error_log('{} already exists and is not a directory.'.format(dump_path)) -def _print_log(level, msg, end='\n'): - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getgid() - print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) - sys.stdout.flush() - - -def print_info_log(info_msg, end='\n'): - """ - Function Description: - print info log. - Parameter: - info_msg: the info message. - """ - _print_log("INFO", info_msg, end=end) - - -def print_error_log(error_msg): - """ - Function Description: - print error log. - Parameter: - error_msg: the error message. - """ - _print_log("ERROR", error_msg) - - -def print_warn_log(warn_msg): - """ - Function Description: - print warn log. - Parameter: - warn_msg: the warning message. 
- """ - _print_log("WARNING", warn_msg) - - def check_mode_valid(mode, scope=None, api_list=None): if scope is None: scope = [] diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py index 5ae692a998d..db193dcd833 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py @@ -32,40 +32,7 @@ class Advisor: def __init__(self, input_data, out_path=""): self.input_data = input_data self.out_path = os.path.realpath(out_path) - - def _parse_input_data(self): - data_columns = self.input_data.columns.values - if {CompareConst.ACCURACY, CompareConst.NPU_NAME}.issubset(data_columns): - self.file_type = Const.ALL - elif {CompareConst.RESULT, CompareConst.NPU_MD5}.issubset(data_columns): - self.file_type = Const.MD5 - elif {CompareConst.MAX_DIFF, CompareConst.RESULT}.issubset(data_columns): - self.file_type = Const.SUMMARY - else: - print_error_log('Compare result does not meet the required conditions.') - raise CompareException(CompareException.INVALID_DATA_ERROR) - df = self.input_data.reset_index() - return df - - def _check_path_vaild(self): - out_path_checker = FileChecker(self.out_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) - out_path_checker.common_check() - - def gen_advisor_message(self, node_name): - if AdvisorConst.FORWARD in node_name: - if AdvisorConst.INPUT in node_name: - message = AdvisorConst.FORWARD_INPUT_SUGGEST - else: - message = AdvisorConst.FORWARD_OUTPUT_SUGGEST - message = self.deterministic_advisor(message, node_name) - else: - if AdvisorConst.INPUT in node_name: - message = AdvisorConst.BACKWARD_INPUT_SUGGEST - else: - message = AdvisorConst.BACKWARD_OUTPUT_SUGGEST - message = self.deterministic_advisor(message, node_name) - message = self.batch_norm_advisor(message, node_name) - return message + self.file_type = None @staticmethod def deterministic_advisor(message, node_name): @@ -102,6 +69,22 @@ class Advisor: result = AdvisorResult(node_name, index, message) return result + def gen_advisor_message(self, node_name): + if AdvisorConst.FORWARD in node_name: + if AdvisorConst.INPUT in node_name: + message = AdvisorConst.FORWARD_INPUT_SUGGEST + else: + message = AdvisorConst.FORWARD_OUTPUT_SUGGEST + message = self.deterministic_advisor(message, node_name) + else: + if AdvisorConst.INPUT in node_name: + message = AdvisorConst.BACKWARD_INPUT_SUGGEST + else: + message = AdvisorConst.BACKWARD_OUTPUT_SUGGEST + message = self.deterministic_advisor(message, node_name) + message = self.batch_norm_advisor(message, node_name) + return message + def analysis(self): self._check_path_vaild() analyze_data = self._parse_input_data() @@ -120,3 +103,21 @@ class Advisor: result = self.gen_advisor_result(failing_data) message_list = result.print_advisor_log() result.gen_summary_file(self.out_path, message_list) + + def _parse_input_data(self): + data_columns = self.input_data.columns.values + if {CompareConst.ACCURACY, CompareConst.NPU_NAME}.issubset(data_columns): + self.file_type = Const.ALL + elif {CompareConst.RESULT, CompareConst.NPU_MD5}.issubset(data_columns): + self.file_type = Const.MD5 + elif {CompareConst.MAX_DIFF, CompareConst.RESULT}.issubset(data_columns): + self.file_type = Const.SUMMARY + else: + print_error_log('Compare result does not meet the required conditions.') + raise CompareException(CompareException.INVALID_DATA_ERROR) + df = self.input_data.reset_index() + return df + + def _check_path_vaild(self): + out_path_checker = 
FileChecker(self.out_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) + out_path_checker.common_check() diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 6a544de21a0..9484833e52c 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -83,6 +83,24 @@ class BenchmarkStandard: def __str__(self): return "%s" % (self.api_name) + @staticmethod + def _get_status(ratio, algorithm): + error_threshold = benchmark_algorithms_thresholds.get(algorithm).get('error_threshold') + warning_threshold = benchmark_algorithms_thresholds.get(algorithm).get('warning_threshold') + if ratio > error_threshold: + return CompareConst.ERROR + elif ratio > warning_threshold: + return CompareConst.WARNING + return CompareConst.PASS + + @staticmethod + def _calc_ratio(x, y, default_value=1.0): + x, y = convert_str_to_float(x), convert_str_to_float(y) + if math.isclose(y, 0.0): + return 1.0 if math.isclose(x, 0.0) else default_value + else: + return abs(x / y) + def get_result(self): self._compare_ratio() self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') @@ -99,6 +117,11 @@ class BenchmarkStandard: elif CompareConst.WARNING in self.check_result_list: self.final_result = CompareConst.WARNING + def to_column_value(self): + return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, + self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, + self.mean_rel_err_status, self.eb_ratio, self.eb_status] + def _compare_ratio(self): self.small_value_err_ratio = self._calc_ratio( self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), @@ -114,29 +137,6 @@ class BenchmarkStandard: self.eb_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.EB), self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0) - def to_column_value(self): - return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, - self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, - self.mean_rel_err_status, self.eb_ratio, self.eb_status] - - @staticmethod - def _get_status(ratio, algorithm): - error_threshold = benchmark_algorithms_thresholds.get(algorithm).get('error_threshold') - warning_threshold = benchmark_algorithms_thresholds.get(algorithm).get('warning_threshold') - if ratio > error_threshold: - return CompareConst.ERROR - elif ratio > warning_threshold: - return CompareConst.WARNING - return CompareConst.PASS - - @staticmethod - def _calc_ratio(x, y, default_value=1.0): - x, y = convert_str_to_float(x), convert_str_to_float(y) - if math.isclose(y, 0.0): - return 1.0 if math.isclose(x, 0.0) else default_value - else: - return abs(x / y) - def write_detail_csv(content, save_path): rows = [] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py index 67aa69e209b..350a4077475 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py @@ -39,6 +39,42 @@ class Comparator: "total_num": 0, "forward_or_backward_fail_num": 0 } + @staticmethod + def 
_compare_dropout(api_name, bench_output, device_output): + tensor_num = bench_output.numel() + if tensor_num >= 100: + if abs((bench_output == 0).sum() - (device_output == 0).cpu().sum()) / tensor_num < 0.1: + return CompareConst.PASS, 1 + else: + return CompareConst.ERROR, 0 + else: + return CompareConst.PASS, 1 + + @staticmethod + def _compare_builtin_type(bench_output, device_output, compare_column): + if not isinstance(bench_output, (bool, int, float, str)): + return CompareConst.PASS, compare_column, "" + if bench_output != device_output: + return CompareConst.ERROR, compare_column, "" + compare_column.error_rate = 0 + return CompareConst.PASS, compare_column, "" + + @staticmethod + def _compare_bool_tensor(bench_output, device_output): + error_nums = (bench_output != device_output).sum() + if bench_output.size == 0: + return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." + error_rate = float(error_nums / bench_output.size) + result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR + return error_rate, result, "" + + @staticmethod + def _get_absolute_threshold_attribute(api_name, dtype): + small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value') + small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol') + rtol = apis_threshold.get(api_name).get(dtype).get('rtol') + return small_value_threshold, small_value_atol, rtol + def print_pretest_result(self): self.get_statistics_from_result_csv() total_tests = self.test_result_cnt.get("total_num", 0) @@ -333,40 +369,3 @@ class Comparator: return CompareConst.WARNING, compare_column, message message += "Relative error is less than 0.0001, consider as pass.\n" return CompareConst.PASS, compare_column, message - - @staticmethod - def _compare_dropout(api_name, bench_output, device_output): - tensor_num = bench_output.numel() - if tensor_num >= 100: - if abs((bench_output == 0).sum() - (device_output == 0).cpu().sum()) / tensor_num < 0.1: - return CompareConst.PASS, 1 - else: - return CompareConst.ERROR, 0 - else: - return CompareConst.PASS, 1 - - @staticmethod - def _compare_builtin_type(bench_output, device_output, compare_column): - if not isinstance(bench_output, (bool, int, float, str)): - return CompareConst.PASS, compare_column, "" - if bench_output != device_output: - return CompareConst.ERROR, compare_column, "" - compare_column.error_rate = 0 - return CompareConst.PASS, compare_column, "" - - - @staticmethod - def _compare_bool_tensor(bench_output, device_output): - error_nums = (bench_output != device_output).sum() - if bench_output.size == 0: - return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." 
- error_rate = float(error_nums / bench_output.size) - result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR - return error_rate, result, "" - - @staticmethod - def _get_absolute_threshold_attribute(api_name, dtype): - small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value') - small_value_atol = apis_threshold.get(api_name).get(dtype).get('small_value_atol') - rtol = apis_threshold.get(api_name).get(dtype).get('rtol') - return small_value_threshold, small_value_atol, rtol diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py index 7452cec74e8..a8c59a229fe 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py @@ -104,6 +104,22 @@ class APIInfo: else: return os.path.join(save_path, dir_name) + @staticmethod + def _convert_numpy_to_builtin(arg): + type_mapping = { + np.integer: int, + np.floating: float, + np.bool_: bool, + np.complexfloating: complex, + np.str_: str, + np.bytes_: bytes, + np.unicode_: str + } + for numpy_type, builtin_type in type_mapping.items(): + if isinstance(arg, numpy_type): + return builtin_type(arg), get_type_name(str(type(arg))) + return arg, '' + def analyze_element(self, element): if isinstance(element, (list, tuple)): out = [] @@ -180,21 +196,6 @@ class APIInfo: single_arg.update({'type': numpy_type}) single_arg.update({'value': value}) return single_arg - - def _convert_numpy_to_builtin(self, arg): - type_mapping = { - np.integer: int, - np.floating: float, - np.bool_: bool, - np.complexfloating: complex, - np.str_: str, - np.bytes_: bytes, - np.unicode_: str - } - for numpy_type, builtin_type in type_mapping.items(): - if isinstance(arg, numpy_type): - return builtin_type(arg), get_type_name(str(type(arg))) - return arg, '' class ForwardAPIInfo(APIInfo): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index bea882f7507..86c52b23bf8 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -36,7 +36,7 @@ def check_tensor_overflow(x): def check_data_overflow(x): if isinstance(x, (tuple, list)) and x: - for i, item in enumerate(x): + for _, item in enumerate(x): if check_data_overflow(item): return True return False diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index 3186913e948..b0eac9da4bd 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -160,7 +160,7 @@ def run_ut(config): csv_reader = csv.reader(file) next(csv_reader) api_name_set = {row[0] for row in csv_reader} - for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): + for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): if api_full_name in api_name_set: continue if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py
index c7394969794..16e8f091014 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py
@@ -3,6 +3,7 @@ import shutil
 import subprocess
 import sys
 
+
 def run_ut():
     cur_dir = os.path.realpath(os.path.dirname(__file__))
     top_dir = os.path.realpath(os.path.dirname(cur_dir))
@@ -10,14 +11,26 @@ def run_ut():
     src_dir = top_dir
     report_dir = os.path.join(cur_dir, "report")
 
+    # clean up and recreate the report dir
     if os.path.exists(report_dir):
         shutil.rmtree(report_dir)
     os.makedirs(report_dir)
-    cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + report_dir + "/final.xml",
-           "--cov=" + src_dir, "--cov-branch", "--cov-report=xml:" + report_dir + "/coverage.xml"]
-
+    # set paths for multi-platform compatibility
+    junit_report = os.path.join(report_dir, "final.xml")
+    coverage_report = os.path.join(report_dir, "coverage.xml")
+
+    cmd = [
+        "python3",
+        "-m", "pytest",
+        ut_path,
+        f"--junitxml={junit_report}",
+        f"--cov={src_dir}",
+        "--cov-branch",
+        f"--cov-report=xml:{coverage_report}"
+    ]
+
     result_ut = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE,
                                  stderr=subprocess.STDOUT)
     while result_ut.poll() is None:
@@ -25,16 +38,16 @@ def run_ut():
         if line:
             print(line)
 
-    ut_flag = False
-    if result_ut.returncode == 0:
-        ut_flag = True
+    tests_passed = result_ut.returncode == 0
+    if tests_passed:
         print("run ut successfully.")
     else:
         print("run ut failed.")
-    return ut_flag
+    return tests_passed
+
 
-if __name__=="__main__":
+if __name__ == "__main__":
     if run_ut():
         sys.exit(0)
     else:
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py
index b98f84d5164..2594b58dc26 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py
@@ -16,6 +16,7 @@ for api_full_name, api_info_dict in forward_content.items():
 max_value = 5.7421875
 min_value = -5.125
 
+
 class TestDataGenerateMethods(unittest.TestCase):
     def test_gen_api_params(self):
         api_info = copy.deepcopy(api_info_dict)
@@ -44,7 +45,7 @@ class TestDataGenerateMethods(unittest.TestCase):
         max_diff = abs(data.max() - max_value)
         min_diff = abs(data.min() - min_value)
         self.assertEqual(data.dtype, torch.float32)
-        self.assertEqual(data.requires_grad, True)
+        self.assertTrue(data.requires_grad)
         self.assertLessEqual(max_diff, 0.001)
         self.assertLessEqual(min_diff, 0.001)
         self.assertEqual(data.shape, torch.Size([2, 2560, 24, 24]))
@@ -53,23 +54,23 @@ class TestDataGenerateMethods(unittest.TestCase):
         api_info = copy.deepcopy(api_info_dict)
         kwargs_params = gen_kwargs(api_info, None)
         self.assertEqual(kwargs_params, {'inplace': False})
-    
+
     def test_gen_kwargs_2(self):
         k_dict = {"inplace": {"type": "bool", "value": "False"}}
         for key, value in k_dict.items():
             gen_torch_kwargs(k_dict, key, value)
         self.assertEqual(k_dict, {'inplace': False})
-    
+
     def test_gen_random_tensor(self):
         data = gen_random_tensor(api_info_dict.get('args')[0], None)
         max_diff = abs(data.max() - max_value)
         min_diff = abs(data.min() - min_value)
         self.assertEqual(data.dtype, torch.float32)
-        self.assertEqual(data.requires_grad, False)
+        self.assertFalse(data.requires_grad)
         self.assertLessEqual(max_diff, 0.001)
         self.assertLessEqual(min_diff, 0.001)
         self.assertEqual(data.shape, torch.Size([2, 2560, 24, 24]))
-    
+
     def test_gen_common_tensor(self):
         info = api_info_dict.get('args')[0]
         low, high = info.get('Min'), info.get('Max')
@@ -82,14 +83,14 @@ class TestDataGenerateMethods(unittest.TestCase):
         max_diff = abs(data.max() - max_value)
         min_diff = abs(data.min() - min_value)
         self.assertEqual(data.dtype, torch.float32)
-        self.assertEqual(data.requires_grad, False)
+        self.assertFalse(data.requires_grad)
         self.assertLessEqual(max_diff, 0.001)
         self.assertLessEqual(min_diff, 0.001)
         self.assertEqual(data.shape, torch.Size([2, 2560, 24, 24]))
-    
+
     def test_gen_bool_tensor(self):
         info = {"type": "torch.Tensor", "dtype": "torch.bool", "shape": [1, 1, 160, 256], \
-            "Max": 1, "Min": 0, "requires_grad": False}
+                "Max": 1, "Min": 0, "requires_grad": False}
         low, high = info.get("Min"), info.get("Max")
         shape = tuple(info.get("shape"))
         data = gen_bool_tensor(low, high, shape)
diff --git a/debug/accuracy_tools/atat/pytorch/common/exceptions.py b/debug/accuracy_tools/atat/pytorch/common/exceptions.py
index 17733b5bfd5..c4858baaa08 100644
--- a/debug/accuracy_tools/atat/pytorch/common/exceptions.py
+++ b/debug/accuracy_tools/atat/pytorch/common/exceptions.py
@@ -1,6 +1,6 @@
-
 class CodedException(Exception):
     def __init__(self, code, error_info=''):
+        super().__init__()
         self.error_info = self.err_strs.get(code) + error_info
 
     def __str__(self):
@@ -10,7 +10,7 @@ class CodedException(Exception):
 class MsaccException(CodedException):
     INVALID_PARAM_ERROR = 0
     OVERFLOW_NUMS_ERROR = 1
-    
+
     err_strs = {
         INVALID_PARAM_ERROR: "[msacc] 无效参数: ",
         OVERFLOW_NUMS_ERROR: "[msacc] 超过预设溢出次数 当前溢出次数:"
@@ -68,8 +68,18 @@ class StepException(CodedException):
         InvalidPostProcess: "[msacc] 错误的step后处理配置: ",
     }
 
+
 class FreeBenchmarkException(CodedException):
     UnsupportedType = 0
     err_strs = {
         UnsupportedType: "[msacc] Free benchmark get unsupported type: "
-    }
\ No newline at end of file
+    }
+
+
+class DistributedNotInitializedError(Exception):
+    def __init__(self, msg):
+        super().__init__()
+        self.msg = msg
+
+    def __str__(self):
+        return self.msg
diff --git a/debug/accuracy_tools/atat/pytorch/common/log.py b/debug/accuracy_tools/atat/pytorch/common/log.py
index fab5aca45c0..c37887e2b61 100644
--- a/debug/accuracy_tools/atat/pytorch/common/log.py
+++ b/debug/accuracy_tools/atat/pytorch/common/log.py
@@ -1,14 +1,21 @@
 import os
 import time
 import sys
+
 from .utils import get_rank_if_initialized
+from .exceptions import DistributedNotInitializedError
 
 
 def on_rank_0(func):
     def func_rank_0(*args, **kwargs):
-        current_rank = get_rank_if_initialized()
+        try:
+            current_rank = get_rank_if_initialized()
+        except DistributedNotInitializedError:
+            current_rank = None
+
         if current_rank is None or current_rank == 0:
             return func(*args, **kwargs)
+        return None
     return func_rank_0
 
 
@@ -17,7 +24,10 @@ def _print_log(level, msg, end='\n'):
     current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time())))
     pid = os.getpid()
     full_msg = current_time + "(" + str(pid) + ")-[" + level + "]" + msg
-    current_rank = get_rank_if_initialized()
+    try:
+        current_rank = get_rank_if_initialized()
+    except DistributedNotInitializedError:
+        current_rank = None
     if current_rank is not None:
         full_msg = f"[rank {current_rank}]-" + full_msg
     print(full_msg, end=end)
diff --git a/debug/accuracy_tools/atat/pytorch/common/recursive.py b/debug/accuracy_tools/atat/pytorch/common/recursive.py
index c8a19a63117..9b222f5f521 100644
--- a/debug/accuracy_tools/atat/pytorch/common/recursive.py
+++ b/debug/accuracy_tools/atat/pytorch/common/recursive.py
@@ -1,10 +1,13 @@
-import torch
 import numpy as np
+import torch
+
 from .log import print_warn_log
 
 _recursive_key_stack = []
 special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, np.integer, np.floating, np.bool_, np.complexfloating, \
     np.str_, np.byte, np.unicode_, bool, int, float, str, slice)
+
+
 def recursive_apply_transform(args, transform):
     global _recursive_key_stack
     if isinstance(args, special_type):
@@ -18,11 +21,11 @@ def recursive_apply_transform(args, transform):
         _recursive_key_stack.pop()
         return type(args)(transform_result)
     elif isinstance(args, dict):
-        transform_result = {}
+        transform_dict = {}
         for k, arg in args.items():
             _recursive_key_stack.append(str(k))
-            transform_result[k] = recursive_apply_transform(arg, transform)
+            transform_dict[k] = recursive_apply_transform(arg, transform)
             _recursive_key_stack.pop()
-        return transform_result
+        return transform_dict
     elif args is not None:
         print_warn_log(f"Data type {type(args)} is not supported.")
diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py
index e88d506b2c3..bbb154d7bd6 100644
--- a/debug/accuracy_tools/atat/pytorch/common/utils.py
+++ b/debug/accuracy_tools/atat/pytorch/common/utils.py
@@ -21,6 +21,9 @@ import stat
 import torch
 import numpy as np
 from functools import wraps
+
+from .exceptions import DistributedNotInitializedError
+
 try:
     import torch_npu
 except ImportError:
@@ -93,9 +96,13 @@ def torch_device_guard(func):
 
 
 def get_rank_if_initialized():
+    """
+    Return the rank id if torch.distributed is initialized; otherwise raise DistributedNotInitializedError.
+    """
     if torch.distributed.is_initialized():
         return torch.distributed.get_rank()
-    return None
+    else:
+        raise DistributedNotInitializedError("torch distributed environment is not initialized")
 
 
 def seed_all(seed=1234, mode=False):
diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
index d3a072f5c78..6ef6795f8b2 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
+++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
@@ -508,7 +508,11 @@ def handle_inf_nan(n_value, b_value):
     b_inf = np.isinf(b_value)
     n_nan = np.isnan(n_value)
     b_nan = np.isnan(b_value)
-    if np.any(n_inf) or np.any(b_inf) or np.any(n_nan) or np.any(b_nan):
+
+    # merge boolean expressions
+    any_inf = np.any(n_inf) or np.any(b_inf)
+    any_nan = np.any(n_nan) or np.any(b_nan)
+    if any_inf or any_nan:
         if np.array_equal(n_inf, b_inf) and np.array_equal(n_nan, b_nan):
             n_value[n_inf] = 0
             b_value[b_inf] = 0
@@ -926,8 +930,10 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False
         except StopIteration:
            read_err_bench = False
 
-        if len(npu_ops_queue) == 0 or len(bench_ops_queue) == 0 or (
-                len(npu_ops_queue) == last_npu_ops_len and len(bench_ops_queue) == last_bench_ops_len):
+        # continue if either queue is empty or neither queue has changed
+        either_empty = not npu_ops_queue or not bench_ops_queue
+        no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len)
+        if either_empty or no_change:
             continue
 
         n_match_point, b_match_point = match_op(npu_ops_queue, bench_ops_queue, fuzzy_match)
diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
index 9fc97332f64..2d4147cd67c 100644
--- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
+++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
@@ -67,6 +67,12 @@ class DebuggerConfig:
         self._check_step()
         return True
 
+    def check_model(self, model):
+        if self.level in ["L0", "mix"] and not model:
+            raise Exception(
+                f"For level {self.level}, PrecisionDebugger must receive a model argument.",
+            )
+
     def _check_rank(self):
         if self.rank:
             for rank_id in self.rank:
@@ -81,8 +87,3 @@ class DebuggerConfig:
             if not isinstance(s, int) or s < 0:
                 raise ValueError(f"step element {s} must be an integer and greater than or equal to 0.")
 
-    def check_model(self, model):
-        if self.level in ["L0", "mix"] and not model:
-            raise Exception(
-                f"For level {self.level}, PrecisionDebugger must receive a model argument.",
-            )
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py
index e0ffa4e4d6e..8d67ae9ba6e 100644
--- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py
+++ b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py
@@ -42,6 +42,10 @@ class PrecisionDebugger:
             print_warn_log_rank_0("The enable_dataloader feature will be deprecated in the future.")
             dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__)
 
+    @classmethod
+    def get_instance(cls):
+        return cls._instance
+
     @classmethod
     def start(cls):
         instance = cls._instance
@@ -79,7 +83,7 @@
 
 def iter_tracer(func):
     def func_wrapper(*args, **kwargs):
-        debugger_instance = PrecisionDebugger._instance
+        debugger_instance = PrecisionDebugger.get_instance()
         debugger_instance.enable_dataloader = False
         if not debugger_instance.service.first_start:
             debugger_instance.stop()
diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py
index 9c079aedebe..093510b4ed9 100644
--- a/debug/accuracy_tools/atat/pytorch/service.py
+++ b/debug/accuracy_tools/atat/pytorch/service.py
@@ -1,15 +1,15 @@
+import functools
 import os
 from pathlib import Path
-import functools
-import torch
+
+from .common import print_info_log_rank_0
+from .common.file_check import FileChecker, FileCheckConst, check_path_before_create
+from .common.utils import get_rank_if_initialized, is_gpu, Const
 from .functional import build_repair, build_data_collector, build_step_post_process
+from .functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
 from .functional.scope import BaseScope
-from .common.utils import get_rank_if_initialized, is_gpu, Const
-from .common.file_check import FileChecker, FileCheckConst, check_path_before_create
-from .common import print_info_log_rank_0
-from .hook_module.api_registry import api_register
 from .hook_module import remove_dropout
-from 
.functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from .hook_module.api_registry import api_register from .module_processer import ModuleProcesser @@ -29,6 +29,7 @@ class Service: self.first_start = True self.current_rank = None self.first_touch_dir = True + self.dump_iter_dir = None def build_hook(self, module_type, name): def pre_hook(repair, api_or_module_name, module, args, kwargs): @@ -53,7 +54,7 @@ class Service: self.data_collector.visit_and_clear_overflow_status(api_or_module_name) if not self.switch: - return + return None if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output) @@ -151,12 +152,13 @@ class Service: print_info_log_rank_0("The {} hook function is successfully mounted to the model.".format(hook_name)) if self.config.level in ["L0", "mix"]: - assert self.model is not None + if self.model is None: + raise Exception("Model is None") print_info_log_rank_0("The init dump mode is enabled, and the module dump function will not be available") for name, module in self.model.named_modules(): if module == self.model: continue - prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP +\ + prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \ module.__class__.__name__ + Const.SEP pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix) -- Gitee
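
Note: the hunks above move logging into atat/pytorch/common/log.py and change get_rank_if_initialized to raise DistributedNotInitializedError instead of returning None. For readers outside the atat codebase, a minimal, self-contained sketch of that pattern follows. It is illustrative only, not part of the patch; the print_info_log_rank_0 body here is a bare stand-in for the real formatted logger, and only standard torch.distributed APIs are used.

    import torch.distributed as dist


    class DistributedNotInitializedError(Exception):
        def __init__(self, msg):
            super().__init__()
            self.msg = msg

        def __str__(self):
            return self.msg


    def get_rank_if_initialized():
        # Raise instead of returning None so that callers must handle
        # the "no distributed environment" case explicitly.
        if dist.is_initialized():
            return dist.get_rank()
        raise DistributedNotInitializedError("torch distributed environment is not initialized")


    def on_rank_0(func):
        # Run the wrapped function on rank 0, or when training is not
        # distributed at all; every other rank silently skips the call.
        def func_rank_0(*args, **kwargs):
            try:
                current_rank = get_rank_if_initialized()
            except DistributedNotInitializedError:
                current_rank = None
            if current_rank is None or current_rank == 0:
                return func(*args, **kwargs)
            return None
        return func_rank_0


    @on_rank_0
    def print_info_log_rank_0(msg):
        print(msg)

In a single-process run, print_info_log_rank_0("hello") prints immediately; under torchrun, only rank 0 prints and the other ranks return None without side effects.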