diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index c7b175cd4913e2072acca36d53c2952bb0eadce0..1b796359045f24e5ce1ce620c1002b5e7dcff782 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,13 +1,10 @@ # 进行比对及结果展示 import os -import csv from collections import namedtuple import torch import numpy as np -from rich.table import Table -from rich.console import Console -from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log, Const +from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_info_log, Const from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, \ ThousandthStandardApi, apis_threshold @@ -17,7 +14,6 @@ from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err from api_accuracy_checker.common.config import msCheckerConfig -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status', @@ -49,83 +45,13 @@ class Comparator: else: self.stack_info = None - self.test_result_cnt = { - "success_num": 0, "warning_num": 0, "error_num": 0, - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, - "total_num": 0, "total_skip_num": 0 - } - @staticmethod def get_path_from_rank(rank, path_list, path_pattern): return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank) - def print_pretest_result(self): - for save_path in self.save_path_list: - self.get_statistics_from_result_csv(save_path) - total_tests = self.test_result_cnt.get("total_num", 0) - if total_tests != 0: - passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests) - else: - passing_rate = "0%" - - print_warn_log("The follwing tables will be deprecated in the future." - "The following results are for reference only.") - console = Console() - table_total = Table( - show_header=True, title="Overall Statistics", show_lines=True, width=75 - ) - table_total.add_column("Result") - table_total.add_column("Statistics") - table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0))) - table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0))) - table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0))) - table_total.add_row("Passing Rate", passing_rate) - table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0))) - - table_detail = Table( - show_header=True, title="Detail Statistics", show_lines=True, width=75 - ) - table_detail.add_column("Result") - table_detail.add_column("Statistics") - table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0))) - table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0))) - table_detail.add_row("Both Forward & Backward Error", str(self.test_result_cnt.get("forward_and_backward_fail_num", 0))) - - console.print(table_total) - console.print(table_detail) - - def get_statistics_from_result_csv(self, save_path): - checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, "skip"] - with FileOpen(save_path, 'r') as file: - reader = csv.reader(file) - result_csv_rows = [row for row in reader] - result_csv_name = os.path.basename(save_path) - for item in result_csv_rows[1:]: - if not isinstance(item, list) or len(item) < 3: - raise ValueError("The number of columns in %s is incorrect" % result_csv_name) - if not all(item[i] and item[i] in checklist for i in (1, 2)): - raise ValueError( - "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE" - % result_csv_name) - column1 = item[1] - column2 = item[2] - if column1.upper() == CompareConst.SKIP: - self.test_result_cnt["total_skip_num"] += 1 - continue - self.test_result_cnt["total_num"] += 1 - if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.SPACE, CompareConst.SKIP]: - self.test_result_cnt['success_num'] += 1 - elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR: - self.test_result_cnt['forward_and_backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.ERROR: - self.test_result_cnt['forward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column2 == CompareConst.ERROR: - self.test_result_cnt['backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING: - self.test_result_cnt['warning_num'] += 1 + @staticmethod + def print_pretest_result(): + print_info_log("Successfully completed run_ut/multi_run_ut.") def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, diff --git a/debug/accuracy_tools/atat/core/common/const.py b/debug/accuracy_tools/atat/core/common/const.py new file mode 100644 index 0000000000000000000000000000000000000000..89de3a4e5a3519d823a61d561cd53e5293272ee2 --- /dev/null +++ b/debug/accuracy_tools/atat/core/common/const.py @@ -0,0 +1,237 @@ +import os +import stat +import numpy as np + +class Const: + """ + Class for const + """ + SEP = "." + REGEX_PREFIX_MAX_LENGTH = 20 + REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + COMMA = "," + FLOAT_EPSILON = np.finfo(float).eps + OFF = 'OFF' + BACKWARD = 'backward' + FORWARD = 'forward' + + # dump mode + ALL = "all" + LIST = "list" + RANGE = "range" + STACK = "stack" + ACL = "acl" + API_LIST = "api_list" + API_STACK = "api_stack" + DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] + SUMMARY = "summary" + MD5 = "md5" + SUMMARY_MODE = [ALL, SUMMARY, MD5] + + WRITE_FLAGS = os.O_WRONLY | os.O_CREAT + WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR + OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC + + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + ONE_GB = 1073741824 # 1 * 1024 * 1024 * 1024 + TEN_GB = 10737418240 # 10 * 1024 * 1024 * 1024 + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + DISTRIBUTED_PREFIX_LENGTH = 60 + # env dump path + KWARGS = 'kwargs' + INPUT = 'input' + OUTPUT = 'output' + INPUT_ARGS = 'input_args' + INPUT_KWARGS = 'input_kwargs' + GRAD_INPUT = 'grad_input' + GRAD_OUTPUT = 'grad_output' + START = "start" + STOP = "stop" + ENV_ENABLE = "1" + ENV_DISABLE = "0" + MAX_SEED_VALUE = 4294967295 # 2**32 - 1 + TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] + LEVEL_LIST = ["L0", "L1", "L2", "mix"] + STATISTICS = "statistics" + TENSOR = "tensor" + OVERFLOW_CHECK = "overflow_check" + FREE_BENCHMARK = "free_benchmark" + ATTR_NAME_PREFIX = "wrap_" + KERNEL_DUMP = "kernel_dump" + DATA = "data" + PT_FRAMEWORK = "pytorch" + MS_FRAMEWORK = "mindspore" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] + BOOL_TYPE = [bool, np.uint8] + INT_TYPE = [np.int32, np.int64] + NPU = 'NPU' + DISTRIBUTED = 'Distributed' + + INPLACE_LIST = [ + "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", + "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single" + ] + + CONVERT = { + "int32_to_int64": ["torch.int32", "torch.int64"], + } + + CONVERT_API = { + "int32_to_int64": ["cross_entropy"] + } + +class CompareConst: + """ + Class for compare module const + """ + SPACE = " " + # compare result column name + NPU_NAME = "NPU Name" + BENCH_NAME = "Bench Name" + NPU_DTYPE = "NPU Dtype" + BENCH_DTYPE = "Bench Dtype" + NPU_SHAPE = "NPU Tensor Shape" + BENCH_SHAPE = "Bench Tensor Shape" + NPU_MAX = "NPU max" + NPU_MIN = "NPU min" + NPU_MEAN = "NPU mean" + NPU_NORM = "NPU l2norm" + BENCH_MAX = "Bench max" + BENCH_MIN = "Bench min" + BENCH_MEAN = "Bench mean" + BENCH_NORM = "Bench l2norm" + MAX_DIFF = "Max diff" + MIN_DIFF = "Min diff" + MEAN_DIFF = "Mean diff" + NORM_DIFF = "L2norm diff" + COSINE = "Cosine" + MAX_ABS_ERR = "MaxAbsErr" + MAX_RELATIVE_ERR = "MaxRelativeErr" + MIN_RELATIVE_ERR = "MinRelativeErr" + MEAN_RELATIVE_ERR = "MeanRelativeErr" + NORM_RELATIVE_ERR = "NormRelativeErr" + ACCURACY = "Accuracy Reached or Not" + STACK = "NPU_Stack_Info" + DATA_NAME = "Data_name" + ERROR_MESSAGE = "Err_message" + ONE_THOUSANDTH_ERR_RATIO = "One Thousandth Err Ratio" + FIVE_THOUSANDTHS_ERR_RATIO = "Five Thousandths Err Ratio" + NPU_MD5 = "NPU MD5" + BENCH_MD5 = "BENCH MD5" + RESULT = "Result" + + COMPARE_RESULT_HEADER = [ + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, + ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, + NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE + ] + + SUMMARY_COMPARE_RESULT_HEADER = [ + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, + MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR, + NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE + ] + + MD5_COMPARE_RESULT_HEADER = [ + NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT + ] + + # compare standard + THOUSAND_RATIO_THRESHOLD = 0.001 + FIVE_THOUSAND_RATIO_THRESHOLD = 0.005 + COSINE_THRESHOLD = 0.9999 + + # compare result data + READ_NONE = 'No data' + NONE = 'None' + SHAPE_UNMATCH = 'shape unmatched' + DIFF = 'Different' + UNSUPPORTED = 'unsupported' + NAN = 'Nan' + PASS = 'pass' + WARNING = 'Warning' + ERROR = 'error' + SKIP = 'SKIP' + BFLOAT16_MIN = -3.3895313892515355e+38 + BFLOAT16_MAX = 3.3895313892515355e+38 + BFLOAT16_EPS = 3.90625e-3 # 2 ** -8 + + # accuracy standards + COS_THRESHOLD = 0.99 + MAX_ABS_ERR_THRESHOLD = 0.001 + COS_MAX_THRESHOLD = 0.9 + MAX_ABS_ERR_MAX_THRESHOLD = 1 + ACCURACY_CHECK_YES = "Yes" + ACCURACY_CHECK_NO = "No" + ACCURACY_CHECK_UNMATCH = "Unmatched" + + # error message + NO_BENCH = "No bench data matched." + + # compare const + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble] + + # highlight xlsx color const + RED = "FFFF0000" + YELLOW = "FFFF00" + BLUE = "0000FF" + + # highlight rules const + OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf'] + MAX_DIFF_RED = 1e+10 + ORDER_MAGNITUDE_DIFF_YELLOW = 1 + ONE_THOUSAND_ERROR_IN_RED = 0.9 + ONE_THOUSAND_ERROR_OUT_RED = 0.6 + ONE_THOUSAND_ERROR_DIFF_YELLOW = 0.1 + COSINE_DIFF_YELLOW = 0.1 + MAX_RELATIVE_OUT_RED = 0.5 + MAX_RELATIVE_OUT_YELLOW = 0.1 + MAX_RELATIVE_IN_YELLOW = 0.01 + +class FileCheckConst: + """ + Class for file check const + """ + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + JSON_SUFFIX = ".json" + PT_SUFFIX = ".pt" + CSV_SUFFIX = ".csv" + YAML_SUFFIX = ".yaml" + MAX_PKL_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_NUMPY_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_PT_SIZE = 10737418240 # 10 * 1024 * 1024 * 1024 + MAX_CSV_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + MAX_YAML_SIZE = 1048576 # 10 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + PKL_SUFFIX: MAX_PKL_SIZE, + NUMPY_SUFFIX: MAX_NUMPY_SIZE, + JSON_SUFFIX: MAX_JSON_SIZE, + PT_SUFFIX: MAX_PT_SIZE, + CSV_SUFFIX: MAX_CSV_SIZE, + YAML_SUFFIX: MAX_YAML_SIZE + } + +class OverflowConst: + """ + Class for Overflow + """ + OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" + OVERFLOW_ORIGINAL_MODE = 0 + OVERFLOW_DEBUG_MODE = 1 diff --git a/debug/accuracy_tools/atat/core/common/file_check.py b/debug/accuracy_tools/atat/core/common/file_check.py index 43207e85e71f6413ed46da77b25b11a23c7b1c0a..2df825aa35108fc08b9d886bf68f0ef3e2bc1533 100644 --- a/debug/accuracy_tools/atat/core/common/file_check.py +++ b/debug/accuracy_tools/atat/core/common/file_check.py @@ -19,43 +19,7 @@ import re from atat.core.common.log import logger from atat.core.common.exceptions import FileCheckException - - -class FileCheckConst: - """ - Class for file check const - """ - READ_ABLE = "read" - WRITE_ABLE = "write" - READ_WRITE_ABLE = "read and write" - DIRECTORY_LENGTH = 4096 - FILE_NAME_LENGTH = 255 - FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - PKL_SUFFIX = ".pkl" - NUMPY_SUFFIX = ".npy" - JSON_SUFFIX = ".json" - PT_SUFFIX = ".pt" - CSV_SUFFIX = ".csv" - YAML_SUFFIX = ".yaml" - MAX_PKL_SIZE = 1 * 1024 * 1024 * 1024 - MAX_NUMPY_SIZE = 10 * 1024 * 1024 * 1024 - MAX_JSON_SIZE = 1 * 1024 * 1024 * 1024 - MAX_PT_SIZE = 10 * 1024 * 1024 * 1024 - MAX_CSV_SIZE = 1 * 1024 * 1024 * 1024 - MAX_YAML_SIZE = 10 * 1024 * 1024 - DIR = "dir" - FILE = "file" - DATA_DIR_AUTHORITY = 0o750 - DATA_FILE_AUTHORITY = 0o640 - FILE_SIZE_DICT = { - PKL_SUFFIX: MAX_PKL_SIZE, - NUMPY_SUFFIX: MAX_NUMPY_SIZE, - JSON_SUFFIX: MAX_JSON_SIZE, - PT_SUFFIX: MAX_PT_SIZE, - CSV_SUFFIX: MAX_CSV_SIZE, - YAML_SUFFIX: MAX_YAML_SIZE - } +from atat.core.common.const import FileCheckConst class FileChecker: diff --git a/debug/accuracy_tools/atat/core/common/utils.py b/debug/accuracy_tools/atat/core/common/utils.py index 0c74bf038d29b19d68d20b6ffc398b78aee30abe..088530f3c5c88e4e97cd1eda470b02a6a2176fdf 100644 --- a/debug/accuracy_tools/atat/core/common/utils.py +++ b/debug/accuracy_tools/atat/core/common/utils.py @@ -26,7 +26,8 @@ from datetime import datetime, timezone from pathlib import Path import numpy as np -from atat.core.common.file_check import FileOpen, FileChecker, FileCheckConst +from atat.core.common.file_check import FileOpen, FileChecker +from atat.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst from atat.core.common.log import logger @@ -34,206 +35,6 @@ device = collections.namedtuple('device', ['type', 'index']) prefixes = ['api_stack', 'list', 'range', 'acl'] -class Const: - """ - Class for const - """ - SEP = "." - MODEL_TYPE = ['.onnx', '.pb', '.om'] - DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*" - REGEX_PREFIX_MAX_LENGTH = 20 - REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$" - SEMICOLON = ";" - COLON = ":" - EQUAL = "=" - COMMA = "," - DOT = "." - DUMP_RATIO_MAX = 100 - SUMMERY_DATA_NUMS = 256 - FLOAT_EPSILON = np.finfo(float).eps - SUPPORT_DUMP_MODE = ['api', 'acl'] - ON = 'ON' - OFF = 'OFF' - BACKWARD = 'backward' - FORWARD = 'forward' - PRE_FORWARD = "pre_forward" - - # dump mode - ALL = "all" - LIST = "list" - RANGE = "range" - STACK = "stack" - ACL = "acl" - API_LIST = "api_list" - API_STACK = "api_stack" - DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] - AUTO = "auto" - ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF] - SUMMARY = "summary" - MD5 = "md5" - SUMMARY_MODE = [ALL, SUMMARY, MD5] - - WRITE_FLAGS = os.O_WRONLY | os.O_CREAT - WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR - OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC - - PKL_SUFFIX = ".pkl" - NUMPY_SUFFIX = ".npy" - ONE_GB = 1 * 1024 * 1024 * 1024 - TEN_GB = 10 * 1024 * 1024 * 1024 - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - FILE_NAME_LENGTH = 255 - DIRECTORY_LENGTH = 4096 - DISTRIBUTED_PREFIX_LENGTH = 60 - SUMMARY_COLUMN_NUM = 6 - STACK_COLUMN_NUM = 2 - # env dump path - ASCEND_WORK_PATH = "ASCEND_WORK_PATH" - DUMP_DIR = "dump_data" - - KWARGS = 'kwargs' - INPUT = 'input' - OUTPUT = 'output' - INPUT_ARGS = 'input_args' - INPUT_KWARGS = 'input_kwargs' - GRAD_INPUT = 'grad_input' - GRAD_OUTPUT = 'grad_output' - START = "start" - STOP = "stop" - ENV_ENABLE = "1" - ENV_DISABLE = "0" - - MAX_SEED_VALUE = 2**32 - 1 - - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base"] - - TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"] - LEVEL_LIST = ["L0", "L1", "L2", "mix"] - STATISTICS = "statistics" - TENSOR = "tensor" - OVERFLOW_CHECK = "overflow_check" - FREE_BENCHMARK = "free_benchmark" - KERNEL_DUMP = "kernel_dump" - DATA = "data" - PT_FRAMEWORK = "pytorch" - MS_FRAMEWORK = "mindspore" - DIRECTORY_LENGTH = 4096 - FILE_NAME_LENGTH = 255 - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] - BOOL_TYPE = [bool, np.uint8] - INT_TYPE = [np.int32, np.int64] - NPU = 'NPU' - DISTRIBUTED = 'Distributed' - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"] - - -class CompareConst: - """ - Class for compare module const - """ - # compare result column name - NPU_NAME = "NPU Name" - BENCH_NAME = "Bench Name" - NPU_DTYPE = "NPU Dtype" - BENCH_DTYPE = "Bench Dtype" - NPU_SHAPE = "NPU Tensor Shape" - BENCH_SHAPE = "Bench Tensor Shape" - NPU_MAX = "NPU max" - NPU_MIN = "NPU min" - NPU_MEAN = "NPU mean" - NPU_NORM = "NPU l2norm" - BENCH_MAX = "Bench max" - BENCH_MIN = "Bench min" - BENCH_MEAN = "Bench mean" - BENCH_NORM = "Bench l2norm" - MAX_DIFF = "Max diff" - MIN_DIFF = "Min diff" - MEAN_DIFF = "Mean diff" - NORM_DIFF = "L2norm diff" - COSINE = "Cosine" - MAX_ABS_ERR = "MaxAbsErr" - MAX_RELATIVE_ERR = "MaxRelativeErr" - MIN_RELATIVE_ERR = "MinRelativeErr" - MEAN_RELATIVE_ERR = "MeanRelativeErr" - NORM_RELATIVE_ERR = "NormRelativeErr" - ACCURACY = "Accuracy Reached or Not" - STACK = "NPU_Stack_Info" - DATA_NAME = "Data_name" - ERROR_MESSAGE = "Err_message" - ONE_THOUSANDTH_ERR_RATIO = "One Thousandth Err Ratio" - FIVE_THOUSANDTHS_ERR_RATIO = "Five Thousandths Err Ratio" - NPU_MD5 = "NPU MD5" - BENCH_MD5 = "BENCH MD5" - RESULT = "Result" - - COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, - ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, - NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE - ] - - SUMMARY_COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF, - MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR, - NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE - ] - - MD5_COMPARE_RESULT_HEADER = [ - NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT - ] - - # compare standard - THOUSAND_RATIO_THRESHOLD = 0.001 - FIVE_THOUSAND_RATIO_THRESHOLD = 0.005 - COSINE_THRESHOLD = 0.9999 - - # compare result data - READ_NONE = 'No data' - NAN = 'Nan' - NONE = 'None' - SHAPE_UNMATCH = 'shape unmatched' - DTYPE_UNMATCH = 'dtype unmatched' - PASS = 'Pass' - WARNING = 'Warning' - DIFF = 'Different' - UNSUPPORTED = 'unsupported' - - # accuracy standards - COS_THRESHOLD = 0.99 - MAX_ABS_ERR_THRESHOLD = 0.001 - COS_MAX_THRESHOLD = 0.9 - MAX_ABS_ERR_MAX_THRESHOLD = 1 - ACCURACY_CHECK_YES = "Yes" - ACCURACY_CHECK_NO = "No" - ACCURACY_CHECK_UNMATCH = "Unmatched" - - # error message - NO_BENCH = "No bench data matched." - - # compare const - FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble] - - # highlight xlsx color const - RED = "FFFF0000" - YELLOW = "FFFF00" - BLUE = "0000FF" - - # highlight rules const - OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf'] - MAX_DIFF_RED = 1e+10 - ORDER_MAGNITUDE_DIFF_YELLOW = 1 - ONE_THOUSAND_ERROR_IN_RED = 0.9 - ONE_THOUSAND_ERROR_OUT_RED = 0.6 - ONE_THOUSAND_ERROR_DIFF_YELLOW = 0.1 - COSINE_DIFF_YELLOW = 0.1 - MAX_RELATIVE_OUT_RED = 0.5 - MAX_RELATIVE_OUT_YELLOW = 0.1 - MAX_RELATIVE_IN_YELLOW = 0.01 - - class CompareException(Exception): """ Class for Accuracy Compare Exception @@ -273,15 +74,6 @@ class DumpException(CompareException): pass -class OverflowConst: - """ - Class for Overflow - """ - OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" - OVERFLOW_ORIGINAL_MODE = 0 - OVERFLOW_DEBUG_MODE = 1 - - def make_dump_path_if_not_exists(dump_path): if not os.path.exists(dump_path): try: diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/atat/core/common_config.py index bc4ffd8090b0c91bd04091a6f76e837e7d9249bd..e256372ca877be1ca5474dd87c00decdf9d3c1c1 100644 --- a/debug/accuracy_tools/atat/core/common_config.py +++ b/debug/accuracy_tools/atat/core/common_config.py @@ -1,4 +1,4 @@ -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.core.common.log import logger from atat.core.common.exceptions import MsaccException diff --git a/debug/accuracy_tools/atat/core/data_dump/data_collector.py b/debug/accuracy_tools/atat/core/data_dump/data_collector.py index 2a0bc34ba8d8f4849d8056437f2640c11e91b340..f6a9a70b138f1f58c777c5eacf1168b628de3f38 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_collector.py @@ -4,7 +4,7 @@ import os from atat.core.data_dump.scope import build_scope, ListScope from atat.core.data_dump.json_writer import DataWriter from atat.core.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.core.data_dump.data_processor.factory import DataProcessorFactory diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py index 1ee3314b368f3c3382bcb1a221ca846b1beb90f1..208c053192c6da674ec9c7f522a0affdf1d091e2 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py @@ -4,7 +4,8 @@ from dataclasses import dataclass from typing import Tuple, Dict, Optional, Any import numpy as np from atat.core.common.log import logger -from atat.core.common.utils import Const, convert_tuple +from atat.core.common.utils import convert_tuple +from atat.core.common.const import Const @dataclass diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py index 00f2f72e7a8a04fa0fad12cd2935958d50171b4a..bcc771f3684aa55a422691bd9dbfff31f07773dd 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py @@ -1,4 +1,4 @@ -from atat.core.common.utils import Const +from atat.core.common.const import Const class DataProcessorFactory: diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py index 9f96635e9a3978ff9ff535b2001a15ec89f4f4f5..cf3c5ebe5864ff48d9f2442c020cb3ae99473b50 100644 --- a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py @@ -6,9 +6,9 @@ from typing import List import numpy as np import torch from atat.core.common.exceptions import MsaccException -from atat.core.common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst +from atat.core.common.file_check import path_len_exceeds_limit, change_mode from atat.core.common.log import logger -from atat.core.common.utils import Const, OverflowConst +from atat.core.common.const import Const, OverflowConst, FileCheckConst from atat.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo from atat.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow diff --git a/debug/accuracy_tools/atat/core/data_dump/json_writer.py b/debug/accuracy_tools/atat/core/data_dump/json_writer.py index dd0d2f9c7b539704ecd02a7f8aa2cc006c50b81b..23f37b2342e9bde9d69bb65022a40911d5a54dc4 100644 --- a/debug/accuracy_tools/atat/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/atat/core/data_dump/json_writer.py @@ -4,9 +4,9 @@ import fcntl import json from pathlib import Path -from atat.core.common.file_check import FileCheckConst, change_mode +from atat.core.common.file_check import change_mode from atat.core.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const, FileCheckConst class DataWriter: diff --git a/debug/accuracy_tools/atat/core/data_dump/scope.py b/debug/accuracy_tools/atat/core/data_dump/scope.py index dc473d7e1460e977d8f3e08d690ad554415239d5..e7114f343fe724ffdd40d4837a09b8417d03a1b0 100644 --- a/debug/accuracy_tools/atat/core/data_dump/scope.py +++ b/debug/accuracy_tools/atat/core/data_dump/scope.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from atat.core.common.exceptions import ScopeException -from atat.core.common.utils import Const +from atat.core.common.const import Const def build_scope(scope_class, scope=None, api_list=None): diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py index f4cb441f5e6f74a52282db1789bb2a29cf97ea79..43b3f40f97948808a987bd4211530cfca2cb025a 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py @@ -20,9 +20,9 @@ import os from atat.pytorch.advisor.advisor_result import AdvisorResult from atat.pytorch.advisor.advisor_const import AdvisorConst from atat.pytorch.common.log import logger -from atat.core.common.utils import CompareException, CompareConst, Const -from atat.core.common.file_check import FileChecker, FileCheckConst - +from atat.core.common.utils import CompareException +from atat.core.common.file_check import FileChecker +from atat.core.common.const import Const, CompareConst, FileCheckConst class Advisor: """ diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py index 59845a75415823246a1fadeba588d5289c3eb272..a24fa2a1155501d91eb2462528f71824f091f318 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py @@ -19,8 +19,8 @@ import time from atat.pytorch.advisor.advisor_const import AdvisorConst from atat.pytorch.common.log import logger -from atat.core.common.utils import Const -from atat.core.common.file_check import FileCheckConst, change_mode +from atat.core.common.const import Const, FileCheckConst +from atat.core.common.file_check import change_mode class AdvisorResult: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py index 022edbfcf308a252cf9ef477e62ee49b4d7873ed..9e1b02c0154760f468c8d603e4bf385777258434 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py @@ -29,8 +29,8 @@ else: IS_GPU = False from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileCheckConst, FileChecker, FileOpen, change_mode, create_directory -from atat.pytorch.common.utils import Const +from atat.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory +from atat.core.common.const import Const, FileCheckConst from atat.core.common.utils import CompareException diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py index a450edb929161dd14f2d1476509ff5ce5b7a9d1c..3f13534a5ad986aea8c8c1ed320e15386674445f 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py @@ -1,7 +1,7 @@ # 定义比对算法及比对标准 import torch import numpy as np -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst +from atat.core.common.const import CompareConst #cos diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 7e0617eb3ae1a91062fa25c5cb478e5e674ffb0a..89ed3a1008f08532b3e2ec93446493c40a05e29c 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -8,15 +8,16 @@ import pandas as pd from atat.pytorch.api_accuracy_checker.common.utils import write_csv from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ +from atat.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \ convert_str_to_float, CompareMessage from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn from atat.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from atat.core.common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create, create_directory +from atat.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory from atat.pytorch.common.log import logger from atat.core.common.utils import CompareException +from atat.core.common.const import CompareConst, FileCheckConst CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) unsupported_message = 'This data type does not support benchmark compare.' diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py index fbba1dca002663df9e3c6df983fe4ee3546be13f..cfc783bd75ac0ac89c681864a576c702d5722428 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py @@ -1,13 +1,10 @@ # 进行比对及结果展示 import os -import csv import torch import numpy as np -from rich.table import Table -from rich.console import Console from atat.pytorch.common.log import logger -from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv, Const -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, \ +from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv +from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \ apis_threshold from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn @@ -16,7 +13,7 @@ from atat.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_er get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig -from atat.core.common.file_check import FileOpen +from atat.core.common.const import Const, CompareConst class Comparator: @@ -36,11 +33,10 @@ class Comparator: else: self.stack_info = None - self.test_result_cnt = { - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0, - "total_num": 0, "forward_or_backward_fail_num": 0 - } - + @staticmethod + def print_pretest_result(): + logger.info("Successfully completed run_ut/multi_run_ut.") + @staticmethod def _compare_dropout(bench_output, device_output): tensor_num = bench_output.numel() @@ -77,80 +73,6 @@ class Comparator: rtol = apis_threshold.get(api_name).get(dtype).get('rtol') return small_value_threshold, small_value_atol, rtol - def print_pretest_result(self): - self.get_statistics_from_result_csv() - total_tests = self.test_result_cnt.get("total_num", 0) - if total_tests != 0: - passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests) - else: - passing_rate = "0%" - - logger.warning("The follwing tables will be deprecated in the future." - "The following results are for reference only.") - console = Console() - table_total = Table( - show_header=True, title="Overall Statistics", show_lines=True, width=75 - ) - table_total.add_column("Result") - table_total.add_column("Statistics") - table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0))) - table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0))) - table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0))) - table_total.add_row("Passing Rate", passing_rate) - table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0))) - - table_detail = Table( - show_header=True, title="Detail Statistics", show_lines=True, width=75 - ) - table_detail.add_column("Result") - table_detail.add_column("Statistics") - table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0))) - table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0))) - table_detail.add_row("Both Forward & Backward Error", - str(self.test_result_cnt.get("forward_and_backward_fail_num", 0))) - - console.print(table_total) - console.print(table_detail) - - def get_statistics_from_result_csv(self): - checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, - "skip"] - self.test_result_cnt = { - "success_num": 0, "warning_num": 0, "error_num": 0, - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, - "total_num": 0, "total_skip_num": 0 - } - with FileOpen(self.save_path, 'r') as file: - reader = csv.reader(file) - result_csv_rows = [row for row in reader] - result_csv_name = os.path.basename(self.save_path) - for item in result_csv_rows[1:]: - if not isinstance(item, list) or len(item) < 3: - raise ValueError("The number of columns in %s is incorrect" % result_csv_name) - if not all(item[i] and item[i] in checklist for i in (1, 2)): - raise ValueError( - "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE" - % result_csv_name) - column1 = item[1] - column2 = item[2] - if column1.upper() == CompareConst.SKIP: - self.test_result_cnt["total_skip_num"] += 1 - continue - self.test_result_cnt["total_num"] += 1 - if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.SPACE]: - self.test_result_cnt['success_num'] += 1 - elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR: - self.test_result_cnt['forward_and_backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.ERROR: - self.test_result_cnt['forward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column2 == CompareConst.ERROR: - self.test_result_cnt['backward_fail_num'] += 1 - self.test_result_cnt['error_num'] += 1 - elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING: - self.test_result_cnt['warning_num'] += 1 - def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py index bd88d6742f3a26f082f735d514023d74a49c7541..a018b192753d060680344fc4330e8a3427d4e6a5 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py @@ -1,4 +1,4 @@ -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst +from atat.core.common.const import CompareConst class CompareColumn: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py index fe841eb06397443f7ef7d89edc6852d05f7579f5..bcf6b8ea196f5e14085d0200ba484e10017a6059 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -3,7 +3,8 @@ import os import numpy as np import torch import yaml -from atat.core.common.utils import Const, CompareException +from atat.core.common.utils import CompareException +from atat.core.common.const import Const from atat.pytorch.common.log import logger from atat.core.common.file_check import FileOpen @@ -77,21 +78,6 @@ precision_configs = { } } - -class CompareConst: - NAN = np.nan - NA = "N/A" - PASS = 'pass' - WARNING = 'warning' - ERROR = 'error' - SKIP = 'SKIP' - TRUE = 'TRUE' - FALSE = 'FALSE' - BFLOAT16_MIN = -3.3895313892515355e+38 - BFLOAT16_MAX = 3.3895313892515355e+38 - BFLOAT16_EPS = 2 ** -8 - SPACE = " " - class ApiPrecisionCompareColumn: API_NAME = 'API Name' diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py index e983413bf00ea32aaea4d6ee7e215097e9f12cff..c6b721eee239fa943861d58ce31c1d8e9145bda9 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,8 +20,9 @@ import math import torch import numpy -from atat.pytorch.api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, get_full_data_path, CompareException +from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, get_full_data_path, CompareException from atat.pytorch.common.log import logger +from atat.core.common.const import Const TORCH_TYPE = ["torch.device", "torch.dtype"] TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index b9d1a4fd1f3da5c4f115bb8542c9f9150c4d5f35..d2ab9c1e952001f4551044f4395662dd25931e08 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -13,9 +13,10 @@ from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_ get_validated_details_csv_path, preprocess_forward_content from atat.pytorch.api_accuracy_checker.compare.compare import Comparator from atat.pytorch.common import parse_json_info_forward_backward -from atat.core.common.file_check import FileCheckConst, FileChecker, check_file_suffix, check_link, FileOpen, \ +from atat.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \ check_path_before_create, create_directory from atat.pytorch.common.log import logger +from atat.core.common.const import FileCheckConst def split_json_file(input_file, num_splits, filter_api): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index 77f3bf714a37eede95de04fdf1240ff655451e3a..47cbd9944702530680f31029bafb2ca1705b287a 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -27,10 +27,10 @@ from atat.pytorch.hook_module.wrap_functional import FunctionalOPTemplate from atat.pytorch.hook_module.wrap_torch import TorchOPTemplate from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig from atat.pytorch.common.parse_json import parse_json_info_forward_backward -from atat.core.common.file_check import FileOpen, FileCheckConst, FileChecker, \ +from atat.core.common.file_check import FileOpen, FileChecker, \ change_mode, check_file_suffix, check_link, check_path_before_create, create_directory from atat.pytorch.common.log import logger -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const, FileCheckConst current_time = time.strftime("%Y%m%d%H%M%S") UT_ERROR_DATA_DIR = 'ut_error_data' + current_time @@ -40,7 +40,11 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} - +RAISE_PRECISION = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.float32: torch.float64 +} tqdm_params = { 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 'desc': 'Processing', # 进度条前的描述文字 @@ -75,7 +79,7 @@ def deal_detach(arg, to_detach=True): def deal_dtype(arg, raise_dtype=None): - if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype: + if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype: return arg return arg.type(raise_dtype) @@ -120,7 +124,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): return arg_in def is_tensor_with_raise_precision(arg_in, check_kwargs=False): - if arg_in.dtype in Const.RAISE_PRECISION: + if arg_in.dtype in RAISE_PRECISION: return True if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]: return True @@ -139,7 +143,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): need_raise_dtypes = recursive_find_dtypes(input_args) need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True)) if len(need_raise_dtypes) == 1: - raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32) + raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32) elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index 2d7bdcfff34087c38c783b20dfaafa82977a178d..061c9cdfca872bb76ad8ffa0e292fb24ce1a4ada 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -32,9 +32,10 @@ from atat.pytorch.compare.highlight import HighlightRules, get_header_index from atat.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, get_error_message from atat.pytorch.advisor.advisor import Advisor from atat.pytorch.common.log import logger -from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, CompareConst, \ - format_value, check_file_not_exists, check_configuration_param, task_dumppath_get, Const -from atat.core.common.file_check import FileChecker, FileCheckConst, change_mode, FileOpen, create_directory +from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ + format_value, check_file_not_exists, check_configuration_param, task_dumppath_get +from atat.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory +from atat.core.common.const import Const, CompareConst, FileCheckConst def check_graph_mode(a_op_name, b_op_name): diff --git a/debug/accuracy_tools/atat/pytorch/compare/highlight.py b/debug/accuracy_tools/atat/pytorch/compare/highlight.py index d94e86b013b9a89c1fe0c412db612ccef3e51991..3a6898dedbb6910d0c1c9e55f80b55eb4fa0ed3c 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/atat/pytorch/compare/highlight.py @@ -1,7 +1,8 @@ import math import abc import numpy as np -from atat.core.common.utils import CompareConst, get_header_index +from atat.core.common.utils import get_header_index +from atat.core.common.const import CompareConst class HighlightCheck(abc.ABC): diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py index 2e1f22ab3f5243fa4bb6b2e9bb6e7094680081b2..0cf4c6c00a0a671a6ee46dc6dacce48af6e67adf 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py @@ -1,6 +1,7 @@ import abc import numpy as np -from atat.core.common.utils import CompareConst, Const, format_value +from atat.core.common.utils import format_value +from atat.core.common.const import Const, CompareConst from atat.pytorch.common.log import logger diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py index 6f2bfe8551062e82d552661f186130195b41f4dd..1ad69701e4167db3e2e9f61f7f725f024ec21d16 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py @@ -1,6 +1,6 @@ from atat.pytorch.common import seed_all from atat.pytorch.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const class DebuggerConfig: diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py index f86fc41d557b1801318303ce18a934b2306c223e..b9d41330a87605af37f7b538e8d7bcf56f9725f1 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py @@ -1,6 +1,6 @@ from atat.core.common.log import logger from atat.core.common.exceptions import FreeBenchmarkException -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from .main import FreeBenchmarkCheck from .common.params import UnequalRow diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py index d7c2eba377fb228b1ed02f49d9d540108a160ea6..2ebc0a6db917334f492019fa694330d86a0f37e1 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py @@ -1,7 +1,8 @@ from abc import ABC import torch -from atat.pytorch.free_benchmark import Const, logger +from atat.core.common.const import Const +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import CommonField from atat.pytorch.free_benchmark.common.enums import ( DeviceType, diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index 2df26afc1beecc21a8a77bddbe76de6142d68862..03718e3c4d6c4c4ea28ae6eec5daddc02bcedb7d 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -1,5 +1,6 @@ import torch -from atat.pytorch.free_benchmark import Const, logger +from atat.core.common.const import Const +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import CommonField from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py index c1bfbae24a5116dc07fdbc97cc5b7691a7003e07..c57d7e390a0fad73a5e87d72ec72e434835618cf 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py @@ -3,7 +3,8 @@ from abc import ABC, abstractmethod from typing import Any, Optional, Tuple import torch -from atat.pytorch.free_benchmark import Const, logger +from atat.core.common.const import Const +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.enums import ( FuzzThreshold, diff --git a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py index 8652f13f9bcaa46d13fc19bfe74c796ac45cdadf..675fa2a1bfdfdef9b12bf99f2428e3236c86a906 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py +++ b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py @@ -1,6 +1,6 @@ import torch.nn as nn from atat.pytorch.common.log import logger -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.hook_module.api_registry import api_register from atat.pytorch.debugger.precision_debugger import PrecisionDebugger from atat.core.common.exceptions import MsaccException diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py index 6910276f94462018e09cbd3ae865cecff0d0cc1f..3b971cc71ecaf1d229337f4acf5afab6b5b0f9db 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py @@ -25,7 +25,8 @@ from atat.pytorch.hook_module.wrap_functional import get_functional_ops from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops from atat.pytorch.hook_module.wrap_torch import get_torch_ops from atat.pytorch.hook_module.wrap_vf import get_vf_ops -from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu, Const +from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu +from atat.core.common.const import Const torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py index d45a951d479cd235aa6c29e45443d6b22e56dbc9..57212b6e45c572db3d3a235035647bbb614a3cb3 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py @@ -20,7 +20,7 @@ import threading import torch import torch.nn as nn import torch.utils.hooks as full_hooks -from atat.core.common.utils import Const +from atat.core.common.const import Const class HOOKModule(nn.Module): diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py index c247a27082edf9af738e48706c6308d4916fc586..c5a3c6365d1841927c3acd2993ac2828092ae80b 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py index 1059bf748843ae381e48a1c7103811ad71af83c2..e02189ac1bf2a5d2e1607f2625200eb4ee2ea6e8 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py @@ -21,7 +21,8 @@ import torch.distributed as dist import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py index 8c829904cbe848b7e8d57abb1f5f3a2c0bc6d494..fa97f5ee3106c62ab63d80e2b2ebe349494ce695 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.pytorch.common.log import logger from atat.core.common.file_check import FileOpen diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py index 90ad9cb9c4f3865a7dd110dcd5701acdcaf5ce64..7d0882804f478d8657faeb5f90b0b718914d72e5 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py @@ -21,7 +21,8 @@ import torch_npu import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version, Const +from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py index d53291b78faa594b69ac686100a3f48eccce4dc0..6fac18140238e016004c2329f0d8ff647045c0c0 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, parameter_adapter, Const +from atat.pytorch.common.utils import torch_device_guard, parameter_adapter +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py index 3cdece23065a7ab830c6125929c7cd4e86bab711..f0bd01fe4624ccc1a50d1f3e1fb5d614b30a7a8d 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py @@ -21,7 +21,8 @@ import torch import yaml from atat.pytorch.hook_module.hook_module import HOOKModule -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const from atat.core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py index c5f3cb7ee0624757b617b049935b6aabc593ec8c..d4c570221d4b219be168b25dafd76264769da197 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py @@ -22,7 +22,8 @@ import yaml from atat.pytorch.hook_module.hook_module import HOOKModule from atat.core.common.file_check import FileOpen -from atat.pytorch.common.utils import torch_device_guard, Const +from atat.pytorch.common.utils import torch_device_guard +from atat.core.common.const import Const cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index f56513907c262fbc43d1ce76aace271e04caf944..8ce9140e32cbc9f7d2d62eede164e24127a40f34 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -1,7 +1,7 @@ from functools import wraps import torch from torch.utils.hooks import BackwardHook -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.core.data_dump.scope import ModuleRangeScope diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py index d7f9e4e339d40bd9863b6040430948cf2379d91b..e6d55ca0614767338023673b6d0cfa5c7d990c0f 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py @@ -7,7 +7,8 @@ from collections import namedtuple from rich.table import Table from rich.console import Console from .single_compare import single_benchmark_compare_wrap -from .utils import DispatchException, CompareConst +from .utils import DispatchException +from atat.core.common.const import CompareConst from atat.core.common.file_check import FileOpen from atat.pytorch.common.log import logger from atat.core.common.utils import CompareException diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py index 386c3eac1797e177237e9a62389ce3567b982e31..7502d746acf38a52b427659b8d97f5f55a4ce96c 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py @@ -22,8 +22,8 @@ from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logge DispatchException from .compare import Comparator from atat.core.common.file_check import FileOpen -from atat.pytorch.common.utils import Const -from atat.core.common.utils import CompareConst, check_file_or_directory_path, check_path_before_create +from atat.core.common.utils import check_file_or_directory_path, check_path_before_create +from atat.core.common.const import Const, CompareConst current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py index b8d824fd1eafbce9603b18535500d21127417fff..cd7c5a3f282d17b4e064a4b00928861b2e32f737 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py @@ -5,11 +5,10 @@ from datetime import datetime, timezone import pandas as pd import torch -from atat.pytorch.common.utils import Const from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ COLOR_RESET, CSV_COLUMN_NAME -from atat.core.common.file_check import FileOpen, change_mode, FileCheckConst -from atat.core.common.utils import CompareConst +from atat.core.common.file_check import FileOpen, change_mode +from atat.core.common.const import CompareConst, FileCheckConst, Const from atat.pytorch.common.log import logger class DispatchRunParam: diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py index 1f9c2e916c187615160bfb1be64a262b2cd6bd95..f3fcffb6f26adbd2b2ff6b77da720f7ecdfcb0ca 100644 --- a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py @@ -12,8 +12,8 @@ except ImportError: else: pta_cpu_device = torch.device("cpu") -from atat.core.common.utils import CompareConst -from atat.core.common.file_check import change_mode, FileCheckConst +from atat.core.common.const import CompareConst, FileCheckConst +from atat.core.common.file_check import change_mode cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' @@ -58,17 +58,6 @@ BOOL_TYPE = [bool, np.uint8] INT_TYPE = [np.int32, np.int64] -class CompareConst: - NAN = np.nan - NA = "N/A" - PASS = 'pass' - WARNING = 'warning' - ERROR = 'error' - SKIP = 'SKIP' - TRUE = 'TRUE' - FALSE = 'FALSE' - - def get_callstack(): callstack = [] for (_, path, line, func, code, _) in inspect.stack()[2:]: diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py index aeb1d7f6d2f7d3a0488822bfd7859633dfd70366..ce42d242ba29ef4dcf7f2af042242b1b9987de42 100644 --- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py @@ -30,7 +30,7 @@ from atat.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc from atat.pytorch.parse_tool.lib.parse_exception import ParseException from atat.core.common.file_check import change_mode, check_other_user_writable,\ check_path_executable, check_path_owner_consistent -from atat.core.common.file_check import FileCheckConst +from atat.core.common.const import FileCheckConst from atat.core.common.file_check import FileOpen from atat.core.common.utils import check_file_or_directory_path from atat.pytorch.common.log import logger diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py index e04b88bb96633e54bf6c0e085247907a29752503..0674b91b3410765ca7dcb9a6f38c6f03fa6e94f1 100644 --- a/debug/accuracy_tools/atat/pytorch/pt_config.py +++ b/debug/accuracy_tools/atat/pytorch/pt_config.py @@ -3,7 +3,7 @@ import os from atat.core.common_config import CommonConfig, BaseConfig from atat.core.common.file_check import FileOpen -from atat.core.common.utils import Const +from atat.core.common.const import Const class TensorConfig(BaseConfig): diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py index cd80d0852ac7526615b0a14f059803ef6222ed7f..d0b9c4d4b27bd13672e0529de6dd3349860e533f 100644 --- a/debug/accuracy_tools/atat/pytorch/service.py +++ b/debug/accuracy_tools/atat/pytorch/service.py @@ -3,8 +3,8 @@ import os from pathlib import Path from atat.pytorch.common.log import logger -from atat.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create -from atat.core.common.utils import Const +from atat.core.common.file_check import FileChecker, check_path_before_create +from atat.core.common.const import Const, FileCheckConst from atat.core.common.exceptions import DistributedNotInitializedError, MsaccException from atat.core.data_dump.data_collector import build_data_collector from atat.core.data_dump.scope import BaseScope diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7882aa5906b8e472495ede42653f9f5584f573 --- /dev/null +++ b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os + +from unittest import TestCase +from unittest.mock import patch, MagicMock + +from atat.core.common.log import logger +from atat.core.common.const import FileCheckConst +from atat.core.common.exceptions import FileCheckException +from atat.core.common.file_check import (check_link, + check_path_length, + check_path_exists, + check_path_readability, + check_path_writability, + check_path_executable, + check_other_user_writable, + check_path_owner_consistent, + check_path_pattern_vaild, + check_file_size, + check_common_file_size, + check_file_suffix, + check_path_type) + + +class TestFileCheckUtil(TestCase): + @patch.object(logger, "error") + def test_check_link(self, mock_logger_error): + with patch("atat.core.common.file_check.os.path.islink", return_value=True): + with self.assertRaises(FileCheckException) as context: + check_link("link_path") + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.SOFT_LINK_ERROR)) + mock_logger_error.assert_called_with("The file path link_path is a soft link.") + + @patch.object(logger, "error") + def test_check_path_length(self, mock_logger_error): + path = "P" * (FileCheckConst.DIRECTORY_LENGTH + 1) + with self.assertRaises(FileCheckException) as context: + check_path_length(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + path = "P" * (FileCheckConst.FILE_NAME_LENGTH + 1) + with self.assertRaises(FileCheckException) as context: + check_path_length(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + path = "P" * (FileCheckConst.FILE_NAME_LENGTH - 5) + with self.assertRaises(FileCheckException) as context: + check_path_length(path, name_length=FileCheckConst.FILE_NAME_LENGTH - 6) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path length exceeds limit.") + + @patch.object(logger, "error") + def test_check_path_exists(self, mock_logger_error): + with patch("atat.core.common.file_check.os.path.exists", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_exists("file_path") + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with("The file path file_path does not exist.") + + @patch.object(logger, "error") + def test_check_path_readability(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_readability(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not readable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_readability(path) + self.assertEqual(mock_access.call_args[0], (path, os.R_OK)) + + @patch.object(logger, "error") + def test_check_path_writability(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_writability(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not writable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_writability(path) + self.assertEqual(mock_access.call_args[0], (path, os.W_OK)) + + @patch.object(logger, "error") + def test_check_path_executable(self, mock_logger_error): + path = "file_path" + with patch("atat.core.common.file_check.os.access", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_executable(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} is not executable.") + + mock_access = MagicMock() + mock_access.return_value = True + with patch("atat.core.common.file_check.os.access", new=mock_access): + check_path_executable(path) + self.assertEqual(mock_access.call_args[0], (path, os.X_OK)) + + @patch.object(logger, "error") + def test_check_other_user_writable(self, mock_logger_error): + class TestStat: + def __init__(self, mode): + self.st_mode = mode + + path = "file_path" + mock_stat = TestStat(0o002) + with patch("atat.core.common.file_check.os.stat", return_value=mock_stat): + with self.assertRaises(FileCheckException) as context: + check_other_user_writable(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} may be insecure " + "because other users have write permissions. ") + + @patch.object(logger, "error") + def test_check_path_owner_consistent(self, mock_logger_error): + file_path = os.path.realpath(__file__) + file_owner = os.stat(file_path).st_uid + with patch("atat.core.common.file_check.os.getuid", return_value=file_owner+1): + with self.assertRaises(FileCheckException) as context: + check_path_owner_consistent(file_path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR)) + mock_logger_error.assert_called_with(f"The file path {file_path} may be insecure " + "because is does not belong to you.") + + @patch.object(logger, "error") + def test_check_path_pattern_vaild(self, mock_logger_error): + path = "path" + mock_re_match = MagicMock() + mock_re_match.return_value = False + with patch("atat.core.common.file_check.re.match", new=mock_re_match): + with self.assertRaises(FileCheckException) as context: + check_path_pattern_vaild(path) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR)) + mock_logger_error.assert_called_with(f"The file path {path} contains special characters.") + mock_re_match.assert_called_with(FileCheckConst.FILE_VALID_PATTERN, path) + + @patch.object(logger, "error") + def test_check_file_size(self, mock_logger_error): + file_path = os.path.realpath(__file__) + file_size = os.path.getsize(file_path) + max_size = file_size + with self.assertRaises(FileCheckException) as context: + check_file_size(file_path, max_size) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.FILE_TOO_LARGE_ERROR)) + mock_logger_error.assert_called_with(f"The size of file path {file_path} exceeds {max_size} bytes.") + + def test_check_common_file_size(self): + mock_check_file_size = MagicMock() + with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ + patch("atat.core.common.file_check.check_file_size", new=mock_check_file_size): + for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): + check_common_file_size(suffix) + mock_check_file_size.assert_called_with(suffix, max_size) + + @patch.object(logger, "error") + def test_check_file_suffix(self, mock_logger_error): + file_path = "file_path" + suffix = "suffix" + with self.assertRaises(FileCheckException) as context: + check_file_suffix(file_path, suffix) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a {suffix} file!") + + @patch.object(logger, "error") + def test_check_path_type(self, mock_logger_error): + file_path = "file_path" + + with patch("atat.core.common.file_check.os.path.isfile", return_value=False), \ + patch("atat.core.common.file_check.os.path.isdir", return_value=True): + with self.assertRaises(FileCheckException) as context: + check_path_type(file_path, FileCheckConst.FILE) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a file!") + + with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \ + patch("atat.core.common.file_check.os.path.isdir", return_value=False): + with self.assertRaises(FileCheckException) as context: + check_path_type(file_path, FileCheckConst.DIR) + self.assertEqual(str(context.exception), + FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) + mock_logger_error.assert_called_with(f"The {file_path} should be a dictionary!") diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py index 89734f2c572bff2aa864db16f23dfe8665042f74..b3273358e43593e14f931a5675e9039c684d1c59 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py +++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py @@ -1,13 +1,52 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import uuid + from unittest import TestCase -from unittest.mock import patch +from unittest.mock import patch, MagicMock, mock_open -from atat.core.common.utils import check_seed_all, Const, CompareException, check_inplace_op from atat.core.common.log import logger +from atat.core.common.const import Const +from atat.core.common.utils import (CompareException, + check_seed_all, + check_inplace_op, + make_dump_path_if_not_exists, + check_mode_valid, + check_switch_valid, + check_dump_mode_valid, + check_summary_mode_valid, + check_summary_only_valid, + check_file_or_directory_path, + check_compare_param, + check_configuration_param, + is_starts_with, + _check_json, + check_json_file, + check_file_size, + check_regex_prefix_format_valid, + get_dump_data_path, + task_dumppath_get) +from atat.core.common.file_check import FileCheckConst class TestUtils(TestCase): @patch.object(logger, "error") - def test_check_seed_all(self, mock_print_error_log): + def test_check_seed_all(self, mock_error): self.assertIsNone(check_seed_all(1234, True)) self.assertIsNone(check_seed_all(0, True)) self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True)) @@ -15,23 +54,23 @@ class TestUtils(TestCase): with self.assertRaises(CompareException) as context: check_seed_all(-1, True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") with self.assertRaises(CompareException) as context: check_seed_all(Const.MAX_SEED_VALUE + 1, True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") with self.assertRaises(CompareException) as context: check_seed_all("1234", True) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with("Seed must be integer.") + mock_error.assert_called_with("Seed must be integer.") with self.assertRaises(CompareException) as context: check_seed_all(1234, 1) self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) - mock_print_error_log.assert_called_with("seed_all mode must be bool.") - + mock_error.assert_called_with("seed_all mode must be bool.") + def test_check_inplace_op(self): test_prefix_1 = "Distributed.broadcast.0.forward.input.0" self.assertTrue(check_inplace_op(test_prefix_1)) @@ -39,3 +78,268 @@ class TestUtils(TestCase): self.assertFalse(check_inplace_op(test_prefix_2)) test_prefix_3 = "Torch.sum.0.backward.output.0" self.assertFalse(check_inplace_op(test_prefix_3)) + + @patch.object(logger, "error") + def test_make_dump_path_if_not_exists(self, mock_error): + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + str(uuid.uuid4()) + + def test_mkdir(self, **kwargs): + raise OSError + + if not os.path.exists(dirname): + with patch("atat.core.common.utils.Path.mkdir", new=test_mkdir): + with self.assertRaises(CompareException) as context: + make_dump_path_if_not_exists(dirname) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + + make_dump_path_if_not_exists(file_path) + mock_error.assert_called_with(f"{file_path} already exists and is not a directory.") + + def test_check_mode_valid(self): + with self.assertRaises(ValueError) as context: + check_mode_valid("all", scope="scope") + self.assertEqual(str(context.exception), "scope param set invalid, it's must be a list.") + + with self.assertRaises(ValueError) as context: + check_mode_valid("all", api_list="api_list") + self.assertEqual(str(context.exception), "api_list param set invalid, it's must be a list.") + + mode = "all_list" + with self.assertRaises(CompareException) as context: + check_mode_valid(mode) + self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_MODE) + self.assertEqual(str(context.exception), + f"Current mode '{mode}' is not supported. Please use the field in {Const.DUMP_MODE}") + + mode = "list" + with self.assertRaises(ValueError) as context: + check_mode_valid(mode) + self.assertEqual(str(context.exception), + "set_dump_switch, scope param set invalid, it's should not be an empty list.") + + @patch.object(logger, "error") + def test_check_switch_valid(self, mock_error): + with self.assertRaises(CompareException) as context: + check_switch_valid("Close") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.") + + @patch.object(logger, "warning") + def test_check_dump_mode_valid(self, mock_warning): + dump_mode = check_dump_mode_valid("all") + mock_warning.assert_called_with("Please set dump_mode as a list.") + self.assertEqual(dump_mode, ["forward", "backward", "input", "output"]) + + with self.assertRaises(ValueError) as context: + check_dump_mode_valid("all_forward") + self.assertEqual(str(context.exception), + "Please set dump_mode as a list containing one or more of the following: " + + "'all', 'forward', 'backward', 'input', 'output'.") + + def test_check_summary_mode_valid(self): + with self.assertRaises(CompareException) as context: + check_summary_mode_valid("MD5") + self.assertEqual(context.exception.code, CompareException.INVALID_SUMMARY_MODE) + self.assertEqual(str(context.exception), "The summary_mode is not valid") + + @patch.object(logger, "error") + def test_check_summary_only_valid(self, mock_error): + summary_only = check_summary_only_valid(True) + self.assertTrue(summary_only) + + with self.assertRaises(CompareException) as context: + check_summary_only_valid("True") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Params summary_only only support True or False.") + + def test_check_file_or_directory_path(self): + class TestFileChecker: + file_path = "" + path_type = "" + ability = "" + checked = False + + def __init__(self, file_path, path_type, ability=None): + TestFileChecker.file_path = file_path + TestFileChecker.path_type = path_type + TestFileChecker.ability = ability + + def common_check(self): + TestFileChecker.checked = True + + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + + with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + check_file_or_directory_path(file_path, isdir=False) + self.assertTrue(TestFileChecker.checked) + self.assertEqual(TestFileChecker.file_path, file_path) + self.assertEqual(TestFileChecker.path_type, FileCheckConst.FILE) + self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE) + + TestFileChecker.checked = False + with patch("atat.core.common.utils.FileChecker", new=TestFileChecker): + check_file_or_directory_path(dirname, isdir=True) + self.assertTrue(TestFileChecker.checked) + self.assertEqual(TestFileChecker.file_path, dirname) + self.assertEqual(TestFileChecker.path_type, FileCheckConst.DIR) + self.assertEqual(TestFileChecker.ability, FileCheckConst.WRITE_ABLE) + + @patch.object(logger, "error") + def test_check_compare_param(self, mock_error): + params = { + "npu_json_path": "npu_json_path", + "bench_json_path": "bench_json_path", + "stack_json_path": "stack_json_path", + "npu_dump_data_dir": "npu_dump_data_dir", + "bench_dump_data_dir": "bench_dump_data_dir" + } + + call_args = [ + ("npu_json_path", False), + ("bench_json_path", False), + ("stack_json_path", False), + ("npu_dump_data_dir", True), + ("bench_dump_data_dir", True), + ("output_path", True), + ("npu_json_path", False), + ("bench_json_path", False), + ("stack_json_path", False), + ("output_path", True) + ] + + with self.assertRaises(CompareException) as context: + check_compare_param("npu_json_path", "output_path") + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameters") + + mock_check_file_or_directory_path = MagicMock() + mock_check_json_file = MagicMock() + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.check_json_file", new=mock_check_json_file), \ + patch("atat.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path): + check_compare_param(params, "output_path") + check_compare_param(params, "output_path", summary_compare=False, md5_compare=True) + for i in range(len(call_args)): + self.assertEqual(mock_check_file_or_directory_path.call_args_list[i][0], call_args[i]) + self.assertEqual(len(mock_check_json_file.call_args[0]), 4) + self.assertEqual(mock_check_json_file.call_args[0][0], params) + + @patch.object(logger, "error") + def test_check_configuration_param(self, mock_error): + with self.assertRaises(CompareException) as context: + check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False) + self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) + mock_error.assert_called_with("Invalid input parameters which should be only bool type.") + + def test_is_starts_with(self): + string = "input_slot0" + self.assertFalse(is_starts_with(string, [])) + self.assertFalse(is_starts_with("", ["input"])) + self.assertFalse(is_starts_with(string, ["output"])) + self.assertTrue(is_starts_with(string, ["input", "output"])) + + @patch.object(logger, "error") + def test__check_json(self, mock_error): + class TestOpen: + def __init__(self, string): + self.string = string + + def readline(self): + return self.string + + def seek(self, begin, end): + self.string = str(begin) + "_" + str(end) + + with self.assertRaises(CompareException) as context: + _check_json(TestOpen(""), "test.json") + self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_FILE) + mock_error.assert_called_with("dump file test.json have empty line!") + + handler = TestOpen("jons file\n") + _check_json(handler, "test.json") + self.assertEqual(handler.string, "0_0") + + @patch("atat.core.common.utils._check_json") + def test_check_json_file(self, _mock_check_json): + input_param = { + "npu_json_path": "npu_json_path", + "bench_json_path": "bench_json_path", + "stack_json_path": "stack_json_path" + } + check_json_file(input_param, "npu_json", "bench_json", "stack_json") + self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path")) + self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path")) + self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path")) + + @patch.object(logger, "error") + def test_check_file_size(self, mock_error): + with patch("atat.core.common.utils.os.path.getsize", return_value=120): + with self.assertRaises(CompareException) as context: + check_file_size("input_file", 100) + self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR) + mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.") + + def test_check_regex_prefix_format_valid(self): + prefix = "A" * 21 + with self.assertRaises(ValueError) as context: + check_regex_prefix_format_valid(prefix) + self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, " + f"while current length is {len(prefix)}") + + prefix = "(prefix)" + with self.assertRaises(ValueError) as context: + check_regex_prefix_format_valid(prefix) + self.assertEqual(str(context.exception), f"prefix contains invalid characters, " + f"prefix pattern {Const.REGEX_PREFIX_PATTERN}") + + @patch("atat.core.common.utils.check_file_or_directory_path") + def test_get_dump_data_path(self, mock_check_file_or_directory_path): + file_path = os.path.realpath(__file__) + dirname = os.path.dirname(file_path) + + dump_data_path, file_is_exist = get_dump_data_path(dirname) + self.assertEqual(mock_check_file_or_directory_path.call_args[0], (dirname, True)) + self.assertEqual(dump_data_path, dirname) + self.assertTrue(file_is_exist) + + @patch.object(logger, "error") + def test_task_dumppath_get(self, mock_error): + input_param = { + "npu_json_path": None, + "bench_json_path": "bench_json_path" + } + npu_json = { + "task": Const.TENSOR, + "dump_data_dir": "dump_data_dir", + "data": "data" + } + + with self.assertRaises(CompareException) as context: + task_dumppath_get(input_param) + self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) + mock_error.assert_called_with("Please check the json path is valid.") + + input_param["npu_json_path"] = "npu_json_path" + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json): + summary_compare, md5_compare = task_dumppath_get(input_param) + self.assertFalse(summary_compare) + self.assertFalse(md5_compare) + + npu_json["task"] = Const.STATISTICS + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json), \ + patch("atat.core.common.utils.md5_find", return_value=True): + summary_compare, md5_compare = task_dumppath_get(input_param) + self.assertFalse(summary_compare) + self.assertTrue(md5_compare) + + npu_json["task"] = Const.OVERFLOW_CHECK + with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \ + patch("atat.core.common.utils.json.load", return_value=npu_json): + with self.assertRaises(CompareException) as context: + task_dumppath_get(input_param) + self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR) + mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.") diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py index 6be8949684c89012f0dc2165ba24eab4e7a77f1c..fe92a90aa1ce85e6efb38a8b089afcb34a199b59 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.mindspore.ms_config import parse_json_config diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py index aab90f122f64a148955a7c824e8975c7d19cb679..701c67b07483652850831d1f720f417cd5554481 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py @@ -9,7 +9,7 @@ from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import ( check_error_rate, get_api_checker_result, ) -from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst +from atat.core.common.const import CompareConst class TestApiPrecisionCompare(unittest.TestCase): diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py index 448b518c3344b2898ce0487687c374b0071dacb6..828d646c52f363e758a10939dbfcd98d942eb9f2 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py @@ -1,7 +1,7 @@ from unittest import TestCase import torch -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode from atat.pytorch.free_benchmark.common.params import data_pre_deal from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py index 948fdaecea5c39be5602f66fa565985655054050..d46e26e09488d481ac9e96a8b4002dd3b62446bd 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py @@ -2,7 +2,7 @@ from abc import ABC from unittest import TestCase import torch -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig from atat.pytorch.free_benchmark.common.counter import preheat_counter from atat.pytorch.free_benchmark.common.enums import ( diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py index 4c1aa1deff4e805c9576892ef6d1775c127b9d53..d326e993c07d66a1baf5ae785ed4b519624bb982 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py @@ -4,7 +4,7 @@ from unittest import TestCase import torch import torch.nn as nn -from atat.pytorch.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.free_benchmark import FreeBenchmarkCheck from atat.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig from atat.pytorch.free_benchmark.common.enums import ( diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py index c931c8550716bc65b58655db6140684408f596cc..fa52fe0e1b05701cec9c8cdf41fe1586029c826e 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.common.utils import Const +from atat.core.common.const import Const from atat.pytorch.pt_config import parse_json_config diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py index f940ef5135503e68ffa4559b3c55d5cc280f8b25..d3254ae71f9a8fccb8608088462c4733c166814d 100644 --- a/debug/accuracy_tools/grad_tool/common/base_comparator.py +++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py @@ -40,9 +40,9 @@ class BaseComparator(ABC): create_directory(output_dir) for rank in tqdm(ranks, desc="rank"): print_info_log(f"now comparing rank {rank}:") - cls.compare(os.path.join(path1, f"rank_{rank}"), - os.path.join(path2, f"rank_{rank}"), - os.path.join(output_dir, f"rank_{rank}")) + cls.compare(os.path.join(path1, f"rank{rank}"), + os.path.join(path2, f"rank{rank}"), + os.path.join(output_dir, f"rank{rank}")) @classmethod def compare(cls, path1: str, path2: str, output_dir: str): @@ -59,15 +59,15 @@ class BaseComparator(ABC): check_file_or_directory_path(path1, file_type=GradConst.DIR) check_file_or_directory_path(path2, file_type=GradConst.DIR) dirs = [] - for dirname in os.listdir(path1): - splits = dirname.split('_') - if not splits or splits[0] != dir_prefix or not splits[1].isdigit(): + for dir_name in os.listdir(path1): + index = dir_name.replace(dir_prefix, "", 1) + if not dir_name.startswith(dir_prefix) or not index.isdigit(): continue - folder2 = os.path.join(path2, dirname) + folder2 = os.path.join(path2, dir_name) if not os.path.isdir(folder2): continue - dirs.append(int(splits[1])) + dirs.append(int(index)) dirs = sorted(dirs) return dirs @@ -101,8 +101,8 @@ class BaseComparator(ABC): total_count_summary = 0 for grad_name in grad_weight_order: grad_file = cls._get_name_matched_grad_file(grad_name, grad_files) - grad1 = os.path.join(path1, f"step_{step}", grad_file) - grad2 = os.path.join(path2, f"step_{step}", grad_file) + grad1 = os.path.join(path1, f"step{step}", grad_file) + grad2 = os.path.join(path2, f"step{step}", grad_file) same_count, total_count = cls._calculate_similarity(grad1, grad2) same_count_summary += same_count total_count_summary += total_count @@ -124,8 +124,8 @@ class BaseComparator(ABC): @classmethod def _get_matched_grad_files(cls, path1: str, path2: str, step: int): - path1 = os.path.join(path1, f"step_{step}") - path2 = os.path.join(path2, f"step_{step}") + path1 = os.path.join(path1, f"step{step}") + path2 = os.path.join(path2, f"step{step}") check_file_or_directory_path(path1, file_type=GradConst.DIR) check_file_or_directory_path(path2, file_type=GradConst.DIR) grad_files = [] diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py index 6733d566d6aaadd2fbf97076bfaca84a3e05160d..f3079e622c29eefe20a6e1fdc9d372002b596610 100644 --- a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py +++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py @@ -96,8 +96,8 @@ class PtGradientMonitor(BaseMonitor): output_lines.append(grad_info) if self._level_adp["have_grad_direction"]: PtGradientMonitor.save_grad_direction(param_name, grad, - f'{self._output_path}/rank_{self._rank}/step_{self._step}') - output_path = os.path.join(self._output_path, f"rank_{getattr(self, '_rank')}", + f'{self._output_path}/rank{self._rank}/step{self._step}') + output_path = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}", f"grad_summary_{self._step}.csv") write_csv(output_path, output_lines, GradStatCsv.generate_csv_header(self._level_adp, self._bounds)) diff --git a/plugins/tensorboard-plugins/ OWNERS b/plugins/tensorboard-plugins/OWNERS similarity index 93% rename from plugins/tensorboard-plugins/ OWNERS rename to plugins/tensorboard-plugins/OWNERS index 34c383beaf138da92df0991b472135496450a827..507672c7399438d6c350856386f1818652ea21f8 100644 --- a/plugins/tensorboard-plugins/ OWNERS +++ b/plugins/tensorboard-plugins/OWNERS @@ -1,9 +1,9 @@ -options: - no_parent_owners: true -approvers: -- wo-wenjie -- ly-qianxiao -reviewers: -- wo-wenjie -- ly-qianxiao -- leo920320 +options: + no_parent_owners: true +approvers: +- wo-wenjie +- ly-qianxiao +reviewers: +- wo-wenjie +- ly-qianxiao +- leo920320 diff --git a/profiler/advisor/README.md b/profiler/advisor/README.md index ccaccdda017ad61674ea245eb5f9c9465747f63e..c650f40b3ea8ef48b3c7644e279b00a1cb99f29a 100644 --- a/profiler/advisor/README.md +++ b/profiler/advisor/README.md @@ -92,31 +92,31 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 - 总体性能瓶颈 ```bash - msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [-D] [-h] + msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] ``` - 计算瓶颈 ```bash - msprof-analyze advisor computation -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [-D] [-h] + msprof-analyze advisor computation -d {profiling_path} [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h] ``` - 调度瓶颈 ```bash - msprof-analyze advisor schedule -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-D] [-h] + msprof-analyze advisor schedule -d {profiling_path} [-cv cann_version] [-tv torch_version] [--debug] [-h] ``` #### 参数介绍 | 参数 | 说明 | 是否必选 | | ---------------------------------- | ------------------------------------------------------------ | -------- | -| -d
--profiling_path | 性能数据所在目录。性能数据通过Profiling工具采集获取。请确保性能数据采集时配置“aic-metrics”参数为“PipeUtilization”,“aicpu”参数为“on”。advisor依赖Profiling工具解析后的timeline数据、summary数据以及info.json*文件,请确保指定的“profiling_dir”目录下存在以上文件。 | 是 | +| -d
--profiling_path | 性能数据文件或目录所在路径,Ascend PyTorch Profiler采集场景指定为`*_ascend_pt`性能数据结果目录,其他场景指定为`PROF_XXX`性能数据结果目录。建议通过Ascend PyTorch Profiler获取性能数据。
advisor依赖Profiling工具解析后的timeline数据(.json)、summary(.csv)数据以及info.json*文件,请确保指定的“profiling_path”目录下存在以上文件。 | 是 | | -bp
--benchmark_profiling_path | 基准性能数据所在目录,用于性能比对。性能数据通过Profiling工具采集获取。
**computation和schedule不支持该参数。** | 否 | | -cv
--cann_version | 使用Profiling工具采集时对应的CANN软件版本,可通过在环境中执行如下命令获取其version字段,目前配套的兼容版本为“6.3.RC2”,“7.0.RC1”、“7.0.0”、“8.0.RC1”,此字段不填默认按“8.0.RC1”版本数据进行处理,其余版本采集的Profiling数据在分析时可能会导致不可知问题:`cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info` | 否 | | -tv
--torch_version | 运行环境的torch版本,默认为1.11.0,支持torch1.11.0和torch2.1.0,当运行环境torch版本为其他版本如torch1.11.3时,可以忽略小版本号差异选择相近的torch版本如1.11.0。 | 否 | -| -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyThon Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。
**schedule不支持该参数。** | 否 | -| -D
--debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 | +| -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyThon Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。功能完善中,暂不建议使用。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。不建议使用。
**schedule不支持该参数。** | 否 | +| --debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 | | -h,-H
--help | 在需要查询当前命令附属子命令或相关参数时,给出帮助建议。 | 否 | ### 报告解析 diff --git a/profiler/compare_tools/compare_backend/utils/constant.py b/profiler/compare_tools/compare_backend/utils/constant.py index 1b77b214c85f6733e36298e119e43a778fd7969f..e2854692ae3218c873171b75878e3e69203effa2 100644 --- a/profiler/compare_tools/compare_backend/utils/constant.py +++ b/profiler/compare_tools/compare_backend/utils/constant.py @@ -74,7 +74,7 @@ class Constant(object): MEMORY_LIST = "memory_list" COMMUNICATION_DICT = "comm_dict" - #compare type + # compare type OVERALL_COMPARE = "overall" BWD_LIST = ["bwd", "backward", "back"] diff --git a/profiler/module_visualization/__init__.py b/profiler/module_visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/module_visualization/graph/__init__.py b/profiler/module_visualization/graph/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/module_visualization/graph/prof_node.py b/profiler/module_visualization/graph/prof_node.py new file mode 100644 index 0000000000000000000000000000000000000000..cfcdabbb991d2abb86f31e5a5866e788cf9a3c6e --- /dev/null +++ b/profiler/module_visualization/graph/prof_node.py @@ -0,0 +1,90 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.prof_common.constant import Constant +from profiler.prof_common.base_node import BaseNode +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class ProfNode(BaseNode): + MODULE_TYPE = 1 + + def __init__(self, event: TraceEventBean, parent_node=None): + super().__init__(event, parent_node) + self._kernel_total_list = [] + + @property + def node_id(self): + return self._event.unique_id + + @property + def total_kernels(self): + return self._kernel_total_list + + @property + def host_total_dur(self): + if self.is_root_node: + return sum((node.host_total_dur for node in self.child_nodes)) + return self._event.dur + + @property + def host_self_dur(self): + return self.host_total_dur - sum((node.host_total_dur for node in self.child_nodes)) + + @property + def device_total_dur(self): + if self.is_root_node: + return sum((node.device_total_dur for node in self.child_nodes)) + return sum((kernel.dur for kernel in self._kernel_total_list)) + + @property + def device_self_dur(self): + return self.device_total_dur - sum((node.device_total_dur for node in self.child_nodes)) + + @property + def input_data(self) -> dict: + data = {} + input_dim = self._event.args.get("Input Dims") + if input_dim: + data["Input Dims"] = input_dim + input_type = self._event.args.get("Input type") + if input_type: + data["Input type"] = input_type + return data + + @property + def data(self): + return {"Input Data": self.input_data, + "Host Self Duration(us)": round(self.host_self_dur, 2), + "Host Total Duration(us)": round(self.host_total_dur, 2), + "Device Self Duration(us)": round(self.device_self_dur, 2), + "Device Total Duration(us)": round(self.device_total_dur, 2)} + + @property + def info(self): + return {"id": self.node_id, + "node_type": self.MODULE_TYPE, + "data": self.data, + "upnode": self.parent_node.node_id if self.parent_node else "None", + "subnodes": [node.node_id for node in iter(self.child_nodes)]} + + @property + def is_root_node(self): + return self.node_id == Constant.NPU_ROOT_ID + + def update_child_nodes(self, node): + self._child_nodes.append(node) + + def update_kernel_total_list(self, kernel_list: list): + self._kernel_total_list.extend(kernel_list) diff --git a/profiler/module_visualization/graph_build/__init__.py b/profiler/module_visualization/graph_build/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/module_visualization/graph_build/fwd_module_node.py b/profiler/module_visualization/graph_build/fwd_module_node.py new file mode 100644 index 0000000000000000000000000000000000000000..34d7ab829649f482c97fb489ac0399d3a876c100 --- /dev/null +++ b/profiler/module_visualization/graph_build/fwd_module_node.py @@ -0,0 +1,29 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.prof_common.base_node import BaseNode +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class FwdModuleNode(BaseNode): + def __init__(self, event: TraceEventBean, parent_node=None): + super().__init__(event, parent_node) + self._bwd_op_list = [] + + @property + def bwd_op_list(self): + return self._bwd_op_list + + def update_bwd_op(self, bwd_op_list: list): + self._bwd_op_list.extend(bwd_op_list) diff --git a/profiler/module_visualization/graph_build/prof_graph_builder.py b/profiler/module_visualization/graph_build/prof_graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..83331b6250211e32399b05cabf19a293759a3741 --- /dev/null +++ b/profiler/module_visualization/graph_build/prof_graph_builder.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.module_visualization.graph.prof_node import ProfNode +from profiler.module_visualization.graph_build.fwd_module_node import FwdModuleNode +from profiler.prof_common.tree_builder import TreeBuilder +from profiler.prof_common.trace_event_bean import TraceEventBean +from profiler.prof_common.constant import Constant +from profiler.module_visualization.prof_parse.prof_data_pre_process import ProfDataPreProcess + + +class ProfGraphBuilder: + def __init__(self, prof_data_path: str): + self._prof_data_path = prof_data_path + self._prof_data = {} + + @classmethod + def _create_event_bean_from_ops(cls, op_list: list, name: str) -> TraceEventBean: + min_start = min((op.start_time for op in iter(op_list))) + max_end = max((op.end_time for op in iter(op_list))) + # 以反向算子的区间作为反向module的区间范围,为了module包含算子,做了+1 +2处理 + return TraceEventBean({"ts": min_start - 1, "dur": float(max_end - min_start) + 2, "name": name}) + + @classmethod + def _trans_flow_to_dict(cls, flow_events: dict, end_events: list) -> dict: + end_event_dict = {} + for event in end_events: + end_event_dict[event.start_time] = event + result_data = {} + for flow in flow_events.values(): + start_point = flow.get("start") + end_point = flow.get("end") + if not start_point or not end_point: + continue + end_event = end_event_dict.get(end_point.start_time) + if end_event: + result_data.setdefault(start_point.start_time, []).append(end_event) + return result_data + + def build_graph(self): + self._prof_data = ProfDataPreProcess(self._prof_data_path).run() + all_data = [*self._prof_data.get(Constant.MODULE_EVENT, []), + *self.find_bwd_module(), + *self._prof_data.get(Constant.CPU_OP_EVENT, [])] + all_data.sort(key=lambda x: x.start_time) + name_dict = {} + for event in all_data: + order_id = name_dict.get(event.name, 0) + event.set_id(f"{event.name}_{order_id}") + name_dict[event.name] = order_id + 1 + root_node = TreeBuilder.build_tree(all_data, ProfNode, TraceEventBean({}, Constant.NPU_ROOT_ID)) + kernel_flow_dict = self._trans_flow_to_dict(self._prof_data.get(Constant.TORCH_TO_NPU_FLOW, {}), + self._prof_data.get(Constant.KERNEL_EVENT, [])) + for start_time, kernels in kernel_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_kernel_total_list(kernels) + matched_node = matched_node.binary_search(start_time) + all_data = root_node.find_all_child_nodes() + all_data.append(root_node) + return all_data + + def find_bwd_module(self) -> list: + bwd_module_list = [] + fwdbwd_flow = self._prof_data.get(Constant.FWD_BWD_FLOW, {}) + module_list = self._prof_data.get(Constant.MODULE_EVENT, []) + cpu_op_list = self._prof_data.get(Constant.CPU_OP_EVENT, []) + if not fwdbwd_flow or not module_list or not cpu_op_list: + return bwd_module_list + fwd_tid = module_list[0].tid + bwd_tid = fwd_tid + for end_point in (flow.get("end") for flow in fwdbwd_flow.values()): + if end_point: + bwd_tid = end_point.tid + break + if fwd_tid == bwd_tid: + return bwd_module_list + # 将每一个反向包成一个module,名字叫“nn.Module: BACKWARD_0” + cpu_op_list.sort(key=lambda x: x.start_time) + pre_status = Constant.FWD_OR_OPT + bwd_op_list = [] + for op in cpu_op_list: + if op.tid == bwd_tid: + bwd_op_list.append(op) + pre_status = Constant.BACKWARD + elif pre_status == Constant.BACKWARD: + bwd_module_list.append(self._create_event_bean_from_ops(bwd_op_list, "nn.Module: BACKWARD")) + bwd_op_list.clear() + pre_status = Constant.FWD_OR_OPT + + # 通过连线匹配正向module,构建出反向的整体module关系 + root_node = TreeBuilder.build_tree(module_list, FwdModuleNode, TraceEventBean({})) + fwdbwd_flow_dict = self._trans_flow_to_dict(fwdbwd_flow, cpu_op_list) + for start_time, end_events in fwdbwd_flow_dict.items(): + matched_node = root_node.binary_search(start_time) + while matched_node != Constant.INVALID_RETURN: + matched_node.update_bwd_op(end_events) + matched_node = matched_node.binary_search(start_time) + all_nodes = root_node.find_all_child_nodes() + for module_node in all_nodes: + if module_node.bwd_op_list: + bwd_module_list.append( + self._create_event_bean_from_ops(module_node.bwd_op_list, f"{module_node.name} [BACKWARD]")) + return bwd_module_list diff --git a/profiler/module_visualization/prof_graph_export.py b/profiler/module_visualization/prof_graph_export.py new file mode 100644 index 0000000000000000000000000000000000000000..d336e97f7419b53d011fa4c043948c60afa5174d --- /dev/null +++ b/profiler/module_visualization/prof_graph_export.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from datetime import datetime + +from profiler.prof_common.constant import Constant +from profiler.prof_common.file_reader import FileReader +from profiler.prof_common.path_manager import PathManager +from profiler.module_visualization.graph_build.prof_graph_builder import ProfGraphBuilder + + +class ProfGraphExport: + @staticmethod + def export_to_json(prof_data_path: str, output_path: str): + logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s") + try: + PathManager.input_path_common_check(prof_data_path) + PathManager.check_input_directory_path(output_path) + PathManager.make_dir_safety(output_path) + all_nodes = ProfGraphBuilder(prof_data_path).build_graph() + result_data = {"root": Constant.NPU_ROOT_ID, "node": {}} + for node in all_nodes: + result_data["node"][node.node_id] = node.info + file_name = "prof_graph_json_{}.vis".format(datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:-3]) + FileReader.write_json_file(output_path, result_data, file_name) + except RuntimeError as err: + logging.error(err) diff --git a/profiler/module_visualization/prof_parse/__init__.py b/profiler/module_visualization/prof_parse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/module_visualization/prof_parse/prof_data_pre_process.py b/profiler/module_visualization/prof_parse/prof_data_pre_process.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc820e4ca560f816b7738243197b90f1adb8c25 --- /dev/null +++ b/profiler/module_visualization/prof_parse/prof_data_pre_process.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from profiler.prof_common.file_reader import FileReader +from profiler.prof_common.constant import Constant +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class ProfDataPreProcess: + def __init__(self, prof_data_path: str): + self._prof_data_path = prof_data_path + self._trace_path = "" + self._kernel_pid = None + self._result_data = {Constant.CPU_OP_EVENT: [], Constant.MODULE_EVENT: [], Constant.KERNEL_EVENT: [], + Constant.TORCH_TO_NPU_FLOW: {}, Constant.FWD_BWD_FLOW: {}} + + def run(self) -> dict: + self._check_trace_path() + self._parse_trace_events() + self._check_result_data() + return self._result_data + + def _check_trace_path(self): + if os.path.isfile(self._prof_data_path): + (split_file_path, split_file_name) = os.path.split(self._prof_data_path) + (shot_name, extension) = os.path.splitext(split_file_name) + if extension != ".json": + msg = f"Invalid profiling path suffix: {self._prof_data_path}. " \ + f"You should input in a json file path, such as trace_view.json." + raise RuntimeError(msg) + self._trace_path = self._prof_data_path + return + ascend_output = os.path.join(self._prof_data_path, "ASCEND_PROFILER_OUTPUT") + profiler_output = ascend_output if os.path.isdir(ascend_output) else self._prof_data_path + json_path = os.path.join(profiler_output, "trace_view.json") + if not os.path.isfile(json_path): + msg = f"Invalid profiling path: {self._prof_data_path}. The data path should be the " \ + f"folder that ends with the ascend_pt collected by the Ascend PyTorch Profiler." + raise RuntimeError(msg) + self._trace_path = json_path + + def _parse_trace_events(self): + trace_data = FileReader.read_json_file(self._trace_path) + self._check_trace_data(trace_data) + iter_trace_data = iter(trace_data) + for event in iter_trace_data: + bean = TraceEventBean(event) + if bean.is_optimizer(): + self._result_data[Constant.MODULE_EVENT].append(bean) + elif bean.is_cpu_op(): + if not bean.is_step(): + self._result_data[Constant.CPU_OP_EVENT].append(bean) + elif bean.is_nn_module(): + self._result_data[Constant.MODULE_EVENT].append(bean) + elif bean.is_torch_to_npu(): + if bean.is_flow_start(): + self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(bean.id, {})["start"] = bean + else: + self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(bean.id, {})["end"] = bean + elif bean.is_fwd_bwd_flow(): + if bean.is_flow_start(): + self._result_data[Constant.FWD_BWD_FLOW].setdefault(bean.id, {})["start"] = bean + else: + self._result_data[Constant.FWD_BWD_FLOW].setdefault(bean.id, {})["end"] = bean + elif bean.is_kernel_event(self._kernel_pid): + self._result_data[Constant.KERNEL_EVENT].append(bean) + + def _check_trace_data(self, trace_data): + if not isinstance(trace_data, list): + msg = f"Invalid profiling data path, this feature only supports performance data " \ + f"collected by Ascend PyTorch Profiler." + raise RuntimeError(msg) + iter_trace_data = iter(trace_data) + for event in iter_trace_data: + bean = TraceEventBean(event) + if bean.is_npu_process(): + self._kernel_pid = bean.pid + break + if self._kernel_pid is None: + msg = f"There is no operator on the NPU side for this data, please check whether the NPU switch is enabled." + raise RuntimeError(msg) + + def _check_result_data(self): + if not self._result_data.get(Constant.CPU_OP_EVENT): + msg = f"This data does not have any aten operator, please make sure to enable the CPU switch." + raise RuntimeError(msg) + if not self._result_data.get(Constant.MODULE_EVENT): + msg = f"This data does not collect any modules, please make sure to turn on the with_stack switch." + raise RuntimeError(msg) diff --git a/profiler/prof_common/base_node.py b/profiler/prof_common/base_node.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cd6780003f9e0e5c58495ac43a893214e68beb --- /dev/null +++ b/profiler/prof_common/base_node.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from math import ceil +from queue import Queue + +from decimal import Decimal + +from profiler.prof_common.constant import Constant +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class BaseNode: + def __init__(self, event: TraceEventBean, parent_node=None): + self._event = event + self._parent_node = parent_node + self._child_nodes = [] + + @property + def parent_node(self): + return self._parent_node + + @property + def child_nodes(self): + return self._child_nodes + + @property + def name(self): + return self._event.name + + @property + def start_time(self) -> Decimal: + return self._event.start_time + + @property + def end_time(self) -> Decimal: + return self._event.end_time + + def update_child_nodes(self, node): + self._child_nodes.append(node) + + def binary_search(self, ts_time): + if not self.child_nodes: + return Constant.INVALID_RETURN + right = len(self.child_nodes) - 1 + left = 0 + while right > left: + mid = left + ceil((right - left) / 2) + if ts_time >= self.child_nodes[mid].start_time: + left = mid + else: + right = mid - 1 + if self.child_nodes[left].start_time < ts_time < self.child_nodes[left].end_time: + return self.child_nodes[left] + return Constant.INVALID_RETURN + + def find_all_child_nodes(self) -> list: + result_data = [] + node_queue = Queue() + for child_node in self.child_nodes: + node_queue.put(child_node) + while not node_queue.empty(): + tree_node = node_queue.get() + result_data.append(tree_node) + for child_node in tree_node.child_nodes: + node_queue.put(child_node) + return result_data diff --git a/profiler/prof_common/constant.py b/profiler/prof_common/constant.py index 5789b89cb1a248977b64839339395acc5288b2ab..87bc51b56bc71c2a70e35a6b08aa4de7bd521f1d 100644 --- a/profiler/prof_common/constant.py +++ b/profiler/prof_common/constant.py @@ -15,4 +15,17 @@ class Constant(object): COLLECTION_PATH = "collection_path" ANALYSIS_MODE = "analysis_mode" - CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help']) \ No newline at end of file + CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help']) + + MAX_FILE_SIZE_5_GB = 1024 * 1024 * 1024 * 5 + + MODULE_EVENT = "module_event" + CPU_OP_EVENT = "op_event" + TORCH_TO_NPU_FLOW = "torch_to_device" + KERNEL_EVENT = "kernel_event" + FWD_BWD_FLOW = "fwd_to_bwd" + NPU_ROOT_ID = "NPU" + + FWD_OR_OPT = 0 + BACKWARD = 1 + INVALID_RETURN = -1 diff --git a/profiler/prof_common/file_reader.py b/profiler/prof_common/file_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a9c8fb4d6599edf46973f8e93aa708903ff007 --- /dev/null +++ b/profiler/prof_common/file_reader.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import logging +import os + +from profiler.prof_common.path_manager import PathManager +from profiler.prof_common.constant import Constant + + +class FileReader: + DATA_FILE_AUTHORITY = 0o640 + DATA_DIR_AUTHORITY = 0o750 + + @classmethod + def read_json_file(cls, file_path: str) -> any: + PathManager.check_path_readable(file_path) + if not os.path.isfile(file_path): + raise FileNotFoundError("File not exists.") + file_size = os.path.getsize(file_path) + if file_size <= 0: + return [] + if file_size > Constant.MAX_FILE_SIZE_5_GB: + msg = f"The file({file_path}) size exceeds the preset max value, failed to read the file." + raise RuntimeError(msg) + try: + with open(file_path, "rt") as file: + json_data = json.loads(file.read()) + except Exception as e: + msg = f"Can't read file: {file_path}" + raise RuntimeError(msg) from e + return json_data + + @classmethod + def write_json_file(cls, output_path: str, data: dict, file_name: str, format_json: bool = False) -> None: + if not data: + return + output_file = os.path.join(output_path, file_name) + PathManager.check_path_writeable(output_path) + try: + with os.fdopen( + os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), 'w' + ) as file: + indent = 4 if format_json else None + file.write(json.dumps(data, indent=indent)) + except Exception as e: + raise RuntimeError(f"Can't create the file: {output_path}") from e diff --git a/profiler/prof_common/path_manager.py b/profiler/prof_common/path_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..3e41b8b50aca42ba33071b2661966d221102e106 --- /dev/null +++ b/profiler/prof_common/path_manager.py @@ -0,0 +1,191 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +import shutil +import platform + + +class PathManager: + MAX_PATH_LENGTH = 4096 + MAX_FILE_NAME_LENGTH = 255 + DATA_FILE_AUTHORITY = 0o640 + DATA_DIR_AUTHORITY = 0o750 + WINDOWS = "windows" + + @classmethod + def check_input_directory_path(cls, path: str): + """ + Function Description: + check whether the path is valid, some businesses can accept a path that does not exist, + so the function do not verify whether the path exists + Parameter: + path: the path to check, whether the incoming path is absolute or relative depends on the business + Exception Description: + when invalid data throw exception + """ + cls.input_path_common_check(path) + base_name = os.path.basename(path) + if os.path.isfile(path): + msg = f"Invalid input path which is a file path: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_input_file_path(cls, path: str): + """ + Function Description: + check whether the file path is valid, some businesses can accept a path that does not exist, + so the function do not verify whether the path exists + Parameter: + path: the file path to check, whether the incoming path is absolute or relative depends on the business + Exception Description: + when invalid data throw exception + """ + cls.input_path_common_check(path) + base_name = os.path.basename(path) + if os.path.isdir(path): + msg = f"Invalid input path which is a directory path: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_path_length(cls, path: str): + if len(path) > cls.MAX_PATH_LENGTH: + raise RuntimeError("Length of input path exceeds the limit.") + path_split_list = path.split("/") + for path in path_split_list: + path_list = path.split("\\") + for name in path_list: + if len(name) > cls.MAX_FILE_NAME_LENGTH: + raise RuntimeError("Length of input path exceeds the limit.") + + @classmethod + def input_path_common_check(cls, path: str): + cls.check_path_length(path) + + if os.path.islink(path): + msg = f"Invalid input path which is a soft link." + raise RuntimeError(msg) + + if platform.system().lower() == cls.WINDOWS: + pattern = r'(\.|:|\\|/|_|-|\s|[~0-9a-zA-Z\u4e00-\u9fa5])+' + else: + pattern = r'(\.|/|_|-|\s|[~0-9a-zA-Z])+' + if not re.fullmatch(pattern, path): + msg = f"Invalid input path." + raise RuntimeError(msg) + + @classmethod + def check_path_owner_consistent(cls, path: str): + """ + Function Description: + check whether the path belong to process owner + Parameter: + path: the path to check + Exception Description: + when invalid path, prompt the user + """ + base_name = os.path.basename(path) + if not os.path.exists(path): + msg = f"Invalid path: {base_name}" + raise RuntimeError(msg) + if platform.system().lower() == cls.WINDOWS: + return + if os.stat(path).st_uid != os.getuid(): + check_msg = input("The path does not belong to you, do you want to continue? [y/n]") + if check_msg.lower() != "y": + raise RuntimeError("The user choose not to continue.") + + @classmethod + def check_path_writeable(cls, path): + """ + Function Description: + check whether the path is writable + Parameter: + path: the path to check + Exception Description: + when invalid data throw exception + """ + cls.check_path_owner_consistent(path) + if os.path.islink(path): + msg = f"Invalid path which is a soft link." + raise RuntimeError(msg) + base_name = os.path.basename(path) + if not os.access(path, os.W_OK): + msg = f"The path permission check failed: {base_name}" + raise RuntimeError(msg) + + @classmethod + def check_path_readable(cls, path): + """ + Function Description: + check whether the path is writable + Parameter: + path: the path to check + Exception Description: + when invalid data throw exception + """ + cls.check_path_owner_consistent(path) + if os.path.islink(path): + msg = f"Invalid path which is a soft link." + raise RuntimeError(msg) + base_name = os.path.basename(path) + if not os.access(path, os.R_OK): + msg = f"The path permission check failed: {base_name}" + raise RuntimeError(msg) + + @classmethod + def remove_path_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to remove path: {base_name}" + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + try: + shutil.rmtree(path) + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def make_dir_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to make directory: {base_name}" + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + return + try: + os.makedirs(path, mode=cls.DATA_DIR_AUTHORITY) + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def create_file_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to create file: {base_name}" + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + return + try: + os.close(os.open(path, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY)) + except Exception as err: + raise RuntimeError(msg) from err + + @classmethod + def get_realpath(cls, path: str) -> str: + if os.path.islink(path): + msg = f"Invalid input path which is a soft link." + raise RuntimeError(msg) + return os.path.realpath(path) diff --git a/profiler/prof_common/trace_event_bean.py b/profiler/prof_common/trace_event_bean.py new file mode 100644 index 0000000000000000000000000000000000000000..2d4b96e4f6aa84ce225531da89085ba4a07335a5 --- /dev/null +++ b/profiler/prof_common/trace_event_bean.py @@ -0,0 +1,69 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from decimal import Decimal + +from profiler.prof_common.utils import convert_to_decimal +from profiler.prof_common.analyze_dict import AnalyzeDict + + +class TraceEventBean(AnalyzeDict): + def __init__(self, data: dict, unique_id: int = None): + super().__init__(data) + self._id = unique_id + + @property + def unique_id(self): + return self._id + + @property + def start_time(self) -> Decimal: + return convert_to_decimal(self.ts) + + @property + def end_time(self) -> Decimal: + return self.start_time + convert_to_decimal(self.dur) + + def set_id(self, name_id): + self._id = name_id + + def is_cpu_op(self): + return self.cat == "cpu_op" + + def is_optimizer(self): + return self.cat == "cpu_op" and self.name.lower().startswith("optimizer") + + def is_nn_module(self): + return self.cat == "python_function" and self.name.lower().startswith("nn.module") + + def is_step(self): + return self.name.lower().startswith("profilerstep#") + + def is_torch_to_npu(self): + return self.cat == "async_npu" + + def is_fwd_bwd_flow(self): + return self.cat == "fwdbwd" + + def is_flow_start(self): + return self.ph == "s" + + def is_flow_end(self): + return self.ph == "f" + + def is_kernel_event(self, kernel_pid): + return self.ph == "X" and self.pid == kernel_pid + + def is_npu_process(self): + return self.ph == "M" and self.name == "process_name" and self.args.get("name", "") == "Ascend Hardware" diff --git a/profiler/prof_common/tree_builder.py b/profiler/prof_common/tree_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..b7d3e1baf6aa48c480124056ced422178f8fe7a2 --- /dev/null +++ b/profiler/prof_common/tree_builder.py @@ -0,0 +1,33 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from profiler.prof_common.trace_event_bean import TraceEventBean + + +class TreeBuilder: + @staticmethod + def build_tree(event_list: list, node_class: any, root_bean: any): + root_node = node_class(root_bean) + event_list.sort(key=lambda x: x.start_time) + last_node = root_node + for event in event_list: + while last_node: + if last_node != root_node and event.start_time > last_node.end_time: + last_node = last_node.parent_node + continue + tree_node = node_class(event, last_node) + last_node.update_child_nodes(tree_node) + last_node = tree_node + break + return root_node diff --git a/profiler/prof_common/utils.py b/profiler/prof_common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a9db41ad0b8d9dd91132959fd5b583f5711d88db --- /dev/null +++ b/profiler/prof_common/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from decimal import Decimal + + +def convert_to_decimal(data: any) -> Decimal: + try: + decimal_value = Decimal(data) + except Exception: + logging.error('Invalid profiling data which failed to convert data to decimal.') + return 0.0 + return decimal_value