diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
index c7b175cd4913e2072acca36d53c2952bb0eadce0..1b796359045f24e5ce1ce620c1002b5e7dcff782 100644
--- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
@@ -1,13 +1,10 @@
# 进行比对及结果展示
import os
-import csv
from collections import namedtuple
import torch
import numpy as np
-from rich.table import Table
-from rich.console import Console
-from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log, Const
+from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_info_log, Const
from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \
precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, ULPStandardApi, \
ThousandthStandardApi, apis_threshold
@@ -17,7 +14,6 @@ from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance,
get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err
from api_accuracy_checker.common.config import msCheckerConfig
-from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen
ResultInfo = namedtuple('ResultInfo', ['full_api_name', 'fwd_success_status', 'bwd_success_status',
@@ -49,83 +45,13 @@ class Comparator:
else:
self.stack_info = None
- self.test_result_cnt = {
- "success_num": 0, "warning_num": 0, "error_num": 0,
- "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0,
- "total_num": 0, "total_skip_num": 0
- }
-
@staticmethod
def get_path_from_rank(rank, path_list, path_pattern):
return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank)
- def print_pretest_result(self):
- for save_path in self.save_path_list:
- self.get_statistics_from_result_csv(save_path)
- total_tests = self.test_result_cnt.get("total_num", 0)
- if total_tests != 0:
- passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests)
- else:
- passing_rate = "0%"
-
- print_warn_log("The follwing tables will be deprecated in the future."
- "The following results are for reference only.")
- console = Console()
- table_total = Table(
- show_header=True, title="Overall Statistics", show_lines=True, width=75
- )
- table_total.add_column("Result")
- table_total.add_column("Statistics")
- table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0)))
- table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0)))
- table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0)))
- table_total.add_row("Passing Rate", passing_rate)
- table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0)))
-
- table_detail = Table(
- show_header=True, title="Detail Statistics", show_lines=True, width=75
- )
- table_detail.add_column("Result")
- table_detail.add_column("Statistics")
- table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0)))
- table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0)))
- table_detail.add_row("Both Forward & Backward Error", str(self.test_result_cnt.get("forward_and_backward_fail_num", 0)))
-
- console.print(table_total)
- console.print(table_detail)
-
- def get_statistics_from_result_csv(self, save_path):
- checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, "skip"]
- with FileOpen(save_path, 'r') as file:
- reader = csv.reader(file)
- result_csv_rows = [row for row in reader]
- result_csv_name = os.path.basename(save_path)
- for item in result_csv_rows[1:]:
- if not isinstance(item, list) or len(item) < 3:
- raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
- if not all(item[i] and item[i] in checklist for i in (1, 2)):
- raise ValueError(
- "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE"
- % result_csv_name)
- column1 = item[1]
- column2 = item[2]
- if column1.upper() == CompareConst.SKIP:
- self.test_result_cnt["total_skip_num"] += 1
- continue
- self.test_result_cnt["total_num"] += 1
- if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.SPACE, CompareConst.SKIP]:
- self.test_result_cnt['success_num'] += 1
- elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR:
- self.test_result_cnt['forward_and_backward_fail_num'] += 1
- self.test_result_cnt['error_num'] += 1
- elif column1 == CompareConst.ERROR:
- self.test_result_cnt['forward_fail_num'] += 1
- self.test_result_cnt['error_num'] += 1
- elif column2 == CompareConst.ERROR:
- self.test_result_cnt['backward_fail_num'] += 1
- self.test_result_cnt['error_num'] += 1
- elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING:
- self.test_result_cnt['warning_num'] += 1
+    @staticmethod
+    def print_pretest_result():  # reads no instance state; only logs that the run finished
+        print_info_log("Successfully completed run_ut/multi_run_ut.")
def write_csv_title(self):
summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
diff --git a/debug/accuracy_tools/atat/core/common/const.py b/debug/accuracy_tools/atat/core/common/const.py
new file mode 100644
index 0000000000000000000000000000000000000000..89de3a4e5a3519d823a61d561cd53e5293272ee2
--- /dev/null
+++ b/debug/accuracy_tools/atat/core/common/const.py
@@ -0,0 +1,237 @@
+import os
+import stat
+import numpy as np
+
+class Const:
+    """
+    Shared constants: separators and name patterns, dump/summary modes, file flags and limits, field names, dtype groups.
+    """
+    SEP = "."
+    REGEX_PREFIX_MAX_LENGTH = 20
+    REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
+    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
+    COMMA = ","
+    FLOAT_EPSILON = np.finfo(float).eps
+    OFF = 'OFF'
+    BACKWARD = 'backward'
+    FORWARD = 'forward'
+
+    # dump mode
+    ALL = "all"
+    LIST = "list"
+    RANGE = "range"
+    STACK = "stack"
+    ACL = "acl"
+    API_LIST = "api_list"
+    API_STACK = "api_stack"
+    DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK]
+    SUMMARY = "summary"
+    MD5 = "md5"
+    SUMMARY_MODE = [ALL, SUMMARY, MD5]
+
+    WRITE_FLAGS = os.O_WRONLY | os.O_CREAT  # write-only, create the file if missing
+    WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR  # owner read/write only (0o600)
+    OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC  # as WRITE_FLAGS, plus truncate existing content
+
+    PKL_SUFFIX = ".pkl"
+    NUMPY_SUFFIX = ".npy"
+    ONE_GB = 1073741824  # 1 * 1024 * 1024 * 1024
+    TEN_GB = 10737418240  # 10 * 1024 * 1024 * 1024
+    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'  # duplicate of the FILE_PATTERN assignment above (same value)
+    DISTRIBUTED_PREFIX_LENGTH = 60
+    # argument/gradient field names, start/stop markers and "1"/"0" switch values
+    KWARGS = 'kwargs'
+    INPUT = 'input'
+    OUTPUT = 'output'
+    INPUT_ARGS = 'input_args'
+    INPUT_KWARGS = 'input_kwargs'
+    GRAD_INPUT = 'grad_input'
+    GRAD_OUTPUT = 'grad_output'
+    START = "start"
+    STOP = "stop"
+    ENV_ENABLE = "1"
+    ENV_DISABLE = "0"
+    MAX_SEED_VALUE = 4294967295  # 2**32 - 1
+    TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
+    LEVEL_LIST = ["L0", "L1", "L2", "mix"]
+    STATISTICS = "statistics"
+    TENSOR = "tensor"
+    OVERFLOW_CHECK = "overflow_check"
+    FREE_BENCHMARK = "free_benchmark"
+    ATTR_NAME_PREFIX = "wrap_"
+    KERNEL_DUMP = "kernel_dump"
+    DATA = "data"
+    PT_FRAMEWORK = "pytorch"
+    MS_FRAMEWORK = "mindspore"
+    DIRECTORY_LENGTH = 4096
+    FILE_NAME_LENGTH = 255
+    FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
+    BOOL_TYPE = [bool, np.uint8]
+    INT_TYPE = [np.int32, np.int64]
+    NPU = 'NPU'
+    DISTRIBUTED = 'Distributed'
+
+    INPLACE_LIST = [
+        "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
+        "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single"
+    ]
+
+    CONVERT = {
+        "int32_to_int64": ["torch.int32", "torch.int64"],
+    }
+
+    CONVERT_API = {
+        "int32_to_int64": ["cross_entropy"]
+    }
+
+class CompareConst:
+    """
+    Constants for the compare module: result column names and headers, thresholds, result values, highlight rules.
+    """
+    SPACE = " "
+    # compare result column name
+    NPU_NAME = "NPU Name"
+    BENCH_NAME = "Bench Name"
+    NPU_DTYPE = "NPU Dtype"
+    BENCH_DTYPE = "Bench Dtype"
+    NPU_SHAPE = "NPU Tensor Shape"
+    BENCH_SHAPE = "Bench Tensor Shape"
+    NPU_MAX = "NPU max"
+    NPU_MIN = "NPU min"
+    NPU_MEAN = "NPU mean"
+    NPU_NORM = "NPU l2norm"
+    BENCH_MAX = "Bench max"
+    BENCH_MIN = "Bench min"
+    BENCH_MEAN = "Bench mean"
+    BENCH_NORM = "Bench l2norm"
+    MAX_DIFF = "Max diff"
+    MIN_DIFF = "Min diff"
+    MEAN_DIFF = "Mean diff"
+    NORM_DIFF = "L2norm diff"
+    COSINE = "Cosine"
+    MAX_ABS_ERR = "MaxAbsErr"
+    MAX_RELATIVE_ERR = "MaxRelativeErr"
+    MIN_RELATIVE_ERR = "MinRelativeErr"
+    MEAN_RELATIVE_ERR = "MeanRelativeErr"
+    NORM_RELATIVE_ERR = "NormRelativeErr"
+    ACCURACY = "Accuracy Reached or Not"
+    STACK = "NPU_Stack_Info"
+    DATA_NAME = "Data_name"
+    ERROR_MESSAGE = "Err_message"
+    ONE_THOUSANDTH_ERR_RATIO = "One Thousandth Err Ratio"
+    FIVE_THOUSANDTHS_ERR_RATIO = "Five Thousandths Err Ratio"
+    NPU_MD5 = "NPU MD5"
+    BENCH_MD5 = "BENCH MD5"
+    RESULT = "Result"
+
+    COMPARE_RESULT_HEADER = [
+        NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
+        ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
+        NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE
+    ]
+
+    SUMMARY_COMPARE_RESULT_HEADER = [
+        NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
+        MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR,
+        NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE
+    ]
+
+    MD5_COMPARE_RESULT_HEADER = [
+        NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
+    ]
+
+    # compare standard
+    THOUSAND_RATIO_THRESHOLD = 0.001
+    FIVE_THOUSAND_RATIO_THRESHOLD = 0.005
+    COSINE_THRESHOLD = 0.9999
+
+    # compare result data (NOTE(review): no DTYPE_UNMATCH is defined in this class — confirm nothing still reads it)
+    READ_NONE = 'No data'
+    NONE = 'None'
+    SHAPE_UNMATCH = 'shape unmatched'
+    DIFF = 'Different'
+    UNSUPPORTED = 'unsupported'
+    NAN = 'Nan'
+    PASS = 'pass'  # NOTE(review): lowercase; confirm every consumer compares against 'pass' rather than 'Pass'
+    WARNING = 'Warning'
+    ERROR = 'error'
+    SKIP = 'SKIP'
+    BFLOAT16_MIN = -3.3895313892515355e+38
+    BFLOAT16_MAX = 3.3895313892515355e+38
+    BFLOAT16_EPS = 3.90625e-3  # 2 ** -8
+
+    # accuracy standards
+    COS_THRESHOLD = 0.99
+    MAX_ABS_ERR_THRESHOLD = 0.001
+    COS_MAX_THRESHOLD = 0.9
+    MAX_ABS_ERR_MAX_THRESHOLD = 1
+    ACCURACY_CHECK_YES = "Yes"
+    ACCURACY_CHECK_NO = "No"
+    ACCURACY_CHECK_UNMATCH = "Unmatched"
+
+    # error message
+    NO_BENCH = "No bench data matched."
+
+    # compare const
+    FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble]
+
+    # highlight xlsx color const (NOTE(review): RED has 8 hex digits, YELLOW/BLUE have 6 — confirm the mix is intended)
+    RED = "FFFF0000"
+    YELLOW = "FFFF00"
+    BLUE = "0000FF"
+
+    # highlight rules const
+    OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf']
+    MAX_DIFF_RED = 1e+10
+    ORDER_MAGNITUDE_DIFF_YELLOW = 1
+    ONE_THOUSAND_ERROR_IN_RED = 0.9
+    ONE_THOUSAND_ERROR_OUT_RED = 0.6
+    ONE_THOUSAND_ERROR_DIFF_YELLOW = 0.1
+    COSINE_DIFF_YELLOW = 0.1
+    MAX_RELATIVE_OUT_RED = 0.5
+    MAX_RELATIVE_OUT_YELLOW = 0.1
+    MAX_RELATIVE_IN_YELLOW = 0.01
+
+class FileCheckConst:
+    """
+    Constants for file checks: access-mode names, path limits and patterns, suffixes, per-suffix byte limits, permissions.
+    """
+    READ_ABLE = "read"
+    WRITE_ABLE = "write"
+    READ_WRITE_ABLE = "read and write"
+    DIRECTORY_LENGTH = 4096
+    FILE_NAME_LENGTH = 255
+    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
+    PKL_SUFFIX = ".pkl"
+    NUMPY_SUFFIX = ".npy"
+    JSON_SUFFIX = ".json"
+    PT_SUFFIX = ".pt"
+    CSV_SUFFIX = ".csv"
+    YAML_SUFFIX = ".yaml"
+    MAX_PKL_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_NUMPY_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
+    MAX_JSON_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_PT_SIZE = 10737418240  # 10 * 1024 * 1024 * 1024
+    MAX_CSV_SIZE = 1073741824  # 1 * 1024 * 1024 * 1024
+    MAX_YAML_SIZE = 10485760  # 10 * 1024 * 1024
+    DIR = "dir"
+    FILE = "file"
+    DATA_DIR_AUTHORITY = 0o750
+    DATA_FILE_AUTHORITY = 0o640
+    FILE_SIZE_DICT = {
+        PKL_SUFFIX: MAX_PKL_SIZE,
+        NUMPY_SUFFIX: MAX_NUMPY_SIZE,
+        JSON_SUFFIX: MAX_JSON_SIZE,
+        PT_SUFFIX: MAX_PT_SIZE,
+        CSV_SUFFIX: MAX_CSV_SIZE,
+        YAML_SUFFIX: MAX_YAML_SIZE
+    }
+
+class OverflowConst:
+    """
+    Constants for overflow checking: a mode-switch key (presumably an environment variable name — confirm) and two mode values.
+    """
+    OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE"
+    OVERFLOW_ORIGINAL_MODE = 0
+    OVERFLOW_DEBUG_MODE = 1
diff --git a/debug/accuracy_tools/atat/core/common/file_check.py b/debug/accuracy_tools/atat/core/common/file_check.py
index 43207e85e71f6413ed46da77b25b11a23c7b1c0a..2df825aa35108fc08b9d886bf68f0ef3e2bc1533 100644
--- a/debug/accuracy_tools/atat/core/common/file_check.py
+++ b/debug/accuracy_tools/atat/core/common/file_check.py
@@ -19,43 +19,7 @@ import re
from atat.core.common.log import logger
from atat.core.common.exceptions import FileCheckException
-
-
-class FileCheckConst:
- """
- Class for file check const
- """
- READ_ABLE = "read"
- WRITE_ABLE = "write"
- READ_WRITE_ABLE = "read and write"
- DIRECTORY_LENGTH = 4096
- FILE_NAME_LENGTH = 255
- FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
- FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
- PKL_SUFFIX = ".pkl"
- NUMPY_SUFFIX = ".npy"
- JSON_SUFFIX = ".json"
- PT_SUFFIX = ".pt"
- CSV_SUFFIX = ".csv"
- YAML_SUFFIX = ".yaml"
- MAX_PKL_SIZE = 1 * 1024 * 1024 * 1024
- MAX_NUMPY_SIZE = 10 * 1024 * 1024 * 1024
- MAX_JSON_SIZE = 1 * 1024 * 1024 * 1024
- MAX_PT_SIZE = 10 * 1024 * 1024 * 1024
- MAX_CSV_SIZE = 1 * 1024 * 1024 * 1024
- MAX_YAML_SIZE = 10 * 1024 * 1024
- DIR = "dir"
- FILE = "file"
- DATA_DIR_AUTHORITY = 0o750
- DATA_FILE_AUTHORITY = 0o640
- FILE_SIZE_DICT = {
- PKL_SUFFIX: MAX_PKL_SIZE,
- NUMPY_SUFFIX: MAX_NUMPY_SIZE,
- JSON_SUFFIX: MAX_JSON_SIZE,
- PT_SUFFIX: MAX_PT_SIZE,
- CSV_SUFFIX: MAX_CSV_SIZE,
- YAML_SUFFIX: MAX_YAML_SIZE
- }
+from atat.core.common.const import FileCheckConst
class FileChecker:
diff --git a/debug/accuracy_tools/atat/core/common/utils.py b/debug/accuracy_tools/atat/core/common/utils.py
index 0c74bf038d29b19d68d20b6ffc398b78aee30abe..088530f3c5c88e4e97cd1eda470b02a6a2176fdf 100644
--- a/debug/accuracy_tools/atat/core/common/utils.py
+++ b/debug/accuracy_tools/atat/core/common/utils.py
@@ -26,7 +26,8 @@ from datetime import datetime, timezone
from pathlib import Path
import numpy as np
-from atat.core.common.file_check import FileOpen, FileChecker, FileCheckConst
+from atat.core.common.file_check import FileOpen, FileChecker
+from atat.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst
from atat.core.common.log import logger
@@ -34,206 +35,6 @@ device = collections.namedtuple('device', ['type', 'index'])
prefixes = ['api_stack', 'list', 'range', 'acl']
-class Const:
- """
- Class for const
- """
- SEP = "."
- MODEL_TYPE = ['.onnx', '.pb', '.om']
- DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*"
- REGEX_PREFIX_MAX_LENGTH = 20
- REGEX_PREFIX_PATTERN = r"^[a-zA-Z0-9_-]+$"
- SEMICOLON = ";"
- COLON = ":"
- EQUAL = "="
- COMMA = ","
- DOT = "."
- DUMP_RATIO_MAX = 100
- SUMMERY_DATA_NUMS = 256
- FLOAT_EPSILON = np.finfo(float).eps
- SUPPORT_DUMP_MODE = ['api', 'acl']
- ON = 'ON'
- OFF = 'OFF'
- BACKWARD = 'backward'
- FORWARD = 'forward'
- PRE_FORWARD = "pre_forward"
-
- # dump mode
- ALL = "all"
- LIST = "list"
- RANGE = "range"
- STACK = "stack"
- ACL = "acl"
- API_LIST = "api_list"
- API_STACK = "api_stack"
- DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK]
- AUTO = "auto"
- ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
- SUMMARY = "summary"
- MD5 = "md5"
- SUMMARY_MODE = [ALL, SUMMARY, MD5]
-
- WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
- WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR
- OVERWRITE_FLAGS = os.O_WRONLY | os.O_CREAT | os.O_TRUNC
-
- PKL_SUFFIX = ".pkl"
- NUMPY_SUFFIX = ".npy"
- ONE_GB = 1 * 1024 * 1024 * 1024
- TEN_GB = 10 * 1024 * 1024 * 1024
- FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
- FILE_NAME_LENGTH = 255
- DIRECTORY_LENGTH = 4096
- DISTRIBUTED_PREFIX_LENGTH = 60
- SUMMARY_COLUMN_NUM = 6
- STACK_COLUMN_NUM = 2
- # env dump path
- ASCEND_WORK_PATH = "ASCEND_WORK_PATH"
- DUMP_DIR = "dump_data"
-
- KWARGS = 'kwargs'
- INPUT = 'input'
- OUTPUT = 'output'
- INPUT_ARGS = 'input_args'
- INPUT_KWARGS = 'input_kwargs'
- GRAD_INPUT = 'grad_input'
- GRAD_OUTPUT = 'grad_output'
- START = "start"
- STOP = "stop"
- ENV_ENABLE = "1"
- ENV_DISABLE = "0"
-
- MAX_SEED_VALUE = 2**32 - 1
-
- INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
- "_reduce_scatter_base", "_all_gather_base"]
-
- TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark"]
- LEVEL_LIST = ["L0", "L1", "L2", "mix"]
- STATISTICS = "statistics"
- TENSOR = "tensor"
- OVERFLOW_CHECK = "overflow_check"
- FREE_BENCHMARK = "free_benchmark"
- KERNEL_DUMP = "kernel_dump"
- DATA = "data"
- PT_FRAMEWORK = "pytorch"
- MS_FRAMEWORK = "mindspore"
- DIRECTORY_LENGTH = 4096
- FILE_NAME_LENGTH = 255
- FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
- FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16]
- BOOL_TYPE = [bool, np.uint8]
- INT_TYPE = [np.int32, np.int64]
- NPU = 'NPU'
- DISTRIBUTED = 'Distributed'
- INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
- "_reduce_scatter_base", "_all_gather_base", "all_to_all_single"]
-
-
-class CompareConst:
- """
- Class for compare module const
- """
- # compare result column name
- NPU_NAME = "NPU Name"
- BENCH_NAME = "Bench Name"
- NPU_DTYPE = "NPU Dtype"
- BENCH_DTYPE = "Bench Dtype"
- NPU_SHAPE = "NPU Tensor Shape"
- BENCH_SHAPE = "Bench Tensor Shape"
- NPU_MAX = "NPU max"
- NPU_MIN = "NPU min"
- NPU_MEAN = "NPU mean"
- NPU_NORM = "NPU l2norm"
- BENCH_MAX = "Bench max"
- BENCH_MIN = "Bench min"
- BENCH_MEAN = "Bench mean"
- BENCH_NORM = "Bench l2norm"
- MAX_DIFF = "Max diff"
- MIN_DIFF = "Min diff"
- MEAN_DIFF = "Mean diff"
- NORM_DIFF = "L2norm diff"
- COSINE = "Cosine"
- MAX_ABS_ERR = "MaxAbsErr"
- MAX_RELATIVE_ERR = "MaxRelativeErr"
- MIN_RELATIVE_ERR = "MinRelativeErr"
- MEAN_RELATIVE_ERR = "MeanRelativeErr"
- NORM_RELATIVE_ERR = "NormRelativeErr"
- ACCURACY = "Accuracy Reached or Not"
- STACK = "NPU_Stack_Info"
- DATA_NAME = "Data_name"
- ERROR_MESSAGE = "Err_message"
- ONE_THOUSANDTH_ERR_RATIO = "One Thousandth Err Ratio"
- FIVE_THOUSANDTHS_ERR_RATIO = "Five Thousandths Err Ratio"
- NPU_MD5 = "NPU MD5"
- BENCH_MD5 = "BENCH MD5"
- RESULT = "Result"
-
- COMPARE_RESULT_HEADER = [
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR,
- ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO,
- NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, ACCURACY, ERROR_MESSAGE
- ]
-
- SUMMARY_COMPARE_RESULT_HEADER = [
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, MAX_DIFF, MIN_DIFF, MEAN_DIFF, NORM_DIFF,
- MAX_RELATIVE_ERR, MIN_RELATIVE_ERR, MEAN_RELATIVE_ERR, NORM_RELATIVE_ERR,
- NPU_MAX, NPU_MIN, NPU_MEAN, NPU_NORM, BENCH_MAX, BENCH_MIN, BENCH_MEAN, BENCH_NORM, RESULT, ERROR_MESSAGE
- ]
-
- MD5_COMPARE_RESULT_HEADER = [
- NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, NPU_MD5, BENCH_MD5, RESULT
- ]
-
- # compare standard
- THOUSAND_RATIO_THRESHOLD = 0.001
- FIVE_THOUSAND_RATIO_THRESHOLD = 0.005
- COSINE_THRESHOLD = 0.9999
-
- # compare result data
- READ_NONE = 'No data'
- NAN = 'Nan'
- NONE = 'None'
- SHAPE_UNMATCH = 'shape unmatched'
- DTYPE_UNMATCH = 'dtype unmatched'
- PASS = 'Pass'
- WARNING = 'Warning'
- DIFF = 'Different'
- UNSUPPORTED = 'unsupported'
-
- # accuracy standards
- COS_THRESHOLD = 0.99
- MAX_ABS_ERR_THRESHOLD = 0.001
- COS_MAX_THRESHOLD = 0.9
- MAX_ABS_ERR_MAX_THRESHOLD = 1
- ACCURACY_CHECK_YES = "Yes"
- ACCURACY_CHECK_NO = "No"
- ACCURACY_CHECK_UNMATCH = "Unmatched"
-
- # error message
- NO_BENCH = "No bench data matched."
-
- # compare const
- FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble]
-
- # highlight xlsx color const
- RED = "FFFF0000"
- YELLOW = "FFFF00"
- BLUE = "0000FF"
-
- # highlight rules const
- OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf']
- MAX_DIFF_RED = 1e+10
- ORDER_MAGNITUDE_DIFF_YELLOW = 1
- ONE_THOUSAND_ERROR_IN_RED = 0.9
- ONE_THOUSAND_ERROR_OUT_RED = 0.6
- ONE_THOUSAND_ERROR_DIFF_YELLOW = 0.1
- COSINE_DIFF_YELLOW = 0.1
- MAX_RELATIVE_OUT_RED = 0.5
- MAX_RELATIVE_OUT_YELLOW = 0.1
- MAX_RELATIVE_IN_YELLOW = 0.01
-
-
class CompareException(Exception):
"""
Class for Accuracy Compare Exception
@@ -273,15 +74,6 @@ class DumpException(CompareException):
pass
-class OverflowConst:
- """
- Class for Overflow
- """
- OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE"
- OVERFLOW_ORIGINAL_MODE = 0
- OVERFLOW_DEBUG_MODE = 1
-
-
def make_dump_path_if_not_exists(dump_path):
if not os.path.exists(dump_path):
try:
diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/atat/core/common_config.py
index bc4ffd8090b0c91bd04091a6f76e837e7d9249bd..e256372ca877be1ca5474dd87c00decdf9d3c1c1 100644
--- a/debug/accuracy_tools/atat/core/common_config.py
+++ b/debug/accuracy_tools/atat/core/common_config.py
@@ -1,4 +1,4 @@
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
from atat.core.common.log import logger
from atat.core.common.exceptions import MsaccException
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_collector.py b/debug/accuracy_tools/atat/core/data_dump/data_collector.py
index 2a0bc34ba8d8f4849d8056437f2640c11e91b340..f6a9a70b138f1f58c777c5eacf1168b628de3f38 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/atat/core/data_dump/data_collector.py
@@ -4,7 +4,7 @@ import os
from atat.core.data_dump.scope import build_scope, ListScope
from atat.core.data_dump.json_writer import DataWriter
from atat.core.common.log import logger
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
from atat.core.data_dump.data_processor.factory import DataProcessorFactory
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py
index 1ee3314b368f3c3382bcb1a221ca846b1beb90f1..208c053192c6da674ec9c7f522a0affdf1d091e2 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py
@@ -4,7 +4,8 @@ from dataclasses import dataclass
from typing import Tuple, Dict, Optional, Any
import numpy as np
from atat.core.common.log import logger
-from atat.core.common.utils import Const, convert_tuple
+from atat.core.common.utils import convert_tuple
+from atat.core.common.const import Const
@dataclass
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py
index 00f2f72e7a8a04fa0fad12cd2935958d50171b4a..bcc771f3684aa55a422691bd9dbfff31f07773dd 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py
+++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py
@@ -1,4 +1,4 @@
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
class DataProcessorFactory:
diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py
index 9f96635e9a3978ff9ff535b2001a15ec89f4f4f5..cf3c5ebe5864ff48d9f2442c020cb3ae99473b50 100644
--- a/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py
@@ -6,9 +6,9 @@ from typing import List
import numpy as np
import torch
from atat.core.common.exceptions import MsaccException
-from atat.core.common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst
+from atat.core.common.file_check import path_len_exceeds_limit, change_mode
from atat.core.common.log import logger
-from atat.core.common.utils import Const, OverflowConst
+from atat.core.common.const import Const, OverflowConst, FileCheckConst
from atat.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \
ModuleForwardInputsOutputs, TensorStatInfo
from atat.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow
diff --git a/debug/accuracy_tools/atat/core/data_dump/json_writer.py b/debug/accuracy_tools/atat/core/data_dump/json_writer.py
index dd0d2f9c7b539704ecd02a7f8aa2cc006c50b81b..23f37b2342e9bde9d69bb65022a40911d5a54dc4 100644
--- a/debug/accuracy_tools/atat/core/data_dump/json_writer.py
+++ b/debug/accuracy_tools/atat/core/data_dump/json_writer.py
@@ -4,9 +4,9 @@ import fcntl
import json
from pathlib import Path
-from atat.core.common.file_check import FileCheckConst, change_mode
+from atat.core.common.file_check import change_mode
from atat.core.common.log import logger
-from atat.core.common.utils import Const
+from atat.core.common.const import Const, FileCheckConst
class DataWriter:
diff --git a/debug/accuracy_tools/atat/core/data_dump/scope.py b/debug/accuracy_tools/atat/core/data_dump/scope.py
index dc473d7e1460e977d8f3e08d690ad554415239d5..e7114f343fe724ffdd40d4837a09b8417d03a1b0 100644
--- a/debug/accuracy_tools/atat/core/data_dump/scope.py
+++ b/debug/accuracy_tools/atat/core/data_dump/scope.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from atat.core.common.exceptions import ScopeException
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
def build_scope(scope_class, scope=None, api_list=None):
diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py
index f4cb441f5e6f74a52282db1789bb2a29cf97ea79..43b3f40f97948808a987bd4211530cfca2cb025a 100644
--- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py
+++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py
@@ -20,9 +20,9 @@ import os
from atat.pytorch.advisor.advisor_result import AdvisorResult
from atat.pytorch.advisor.advisor_const import AdvisorConst
from atat.pytorch.common.log import logger
-from atat.core.common.utils import CompareException, CompareConst, Const
-from atat.core.common.file_check import FileChecker, FileCheckConst
-
+from atat.core.common.utils import CompareException
+from atat.core.common.file_check import FileChecker
+from atat.core.common.const import Const, CompareConst, FileCheckConst
class Advisor:
"""
diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py
index 59845a75415823246a1fadeba588d5289c3eb272..a24fa2a1155501d91eb2462528f71824f091f318 100644
--- a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py
+++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py
@@ -19,8 +19,8 @@ import time
from atat.pytorch.advisor.advisor_const import AdvisorConst
from atat.pytorch.common.log import logger
-from atat.core.common.utils import Const
-from atat.core.common.file_check import FileCheckConst, change_mode
+from atat.core.common.const import Const, FileCheckConst
+from atat.core.common.file_check import change_mode
class AdvisorResult:
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py
index 022edbfcf308a252cf9ef477e62ee49b4d7873ed..9e1b02c0154760f468c8d603e4bf385777258434 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py
@@ -29,8 +29,8 @@ else:
IS_GPU = False
from atat.pytorch.common.log import logger
-from atat.core.common.file_check import FileCheckConst, FileChecker, FileOpen, change_mode, create_directory
-from atat.pytorch.common.utils import Const
+from atat.core.common.file_check import FileChecker, FileOpen, change_mode, create_directory
+from atat.core.common.const import Const, FileCheckConst
from atat.core.common.utils import CompareException
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py
index a450edb929161dd14f2d1476509ff5ce5b7a9d1c..3f13534a5ad986aea8c8c1ed320e15386674445f 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py
@@ -1,7 +1,7 @@
# 定义比对算法及比对标准
import torch
import numpy as np
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst
+from atat.core.common.const import CompareConst
#cos
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py
index 7e0617eb3ae1a91062fa25c5cb478e5e674ffb0a..89ed3a1008f08532b3e2ec93446493c40a05e29c 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py
@@ -8,15 +8,16 @@ import pandas as pd
from atat.pytorch.api_accuracy_checker.common.utils import write_csv
from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \
+from atat.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECISION_COMPARE_RESULT_FILE_NAME, \
API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \
ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \
convert_str_to_float, CompareMessage
from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn
from atat.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path
-from atat.core.common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create, create_directory
+from atat.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory
from atat.pytorch.common.log import logger
from atat.core.common.utils import CompareException
+from atat.core.common.const import CompareConst, FileCheckConst
CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path'])
unsupported_message = 'This data type does not support benchmark compare.'
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py
index fbba1dca002663df9e3c6df983fe4ee3546be13f..cfc783bd75ac0ac89c681864a576c702d5722428 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py
@@ -1,13 +1,10 @@
# 进行比对及结果展示
import os
-import csv
import torch
import numpy as np
-from rich.table import Table
-from rich.console import Console
from atat.pytorch.common.log import logger
-from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv, Const
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, \
+from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents, write_csv
+from atat.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \
DETAIL_TEST_ROWS, precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, \
apis_threshold
from atat.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn
@@ -16,7 +13,7 @@ from atat.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_er
get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \
check_small_value, check_norm_value, get_abs_bench_with_eps
from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig
-from atat.core.common.file_check import FileOpen
+from atat.core.common.const import Const, CompareConst
class Comparator:
@@ -36,11 +33,10 @@ class Comparator:
else:
self.stack_info = None
- self.test_result_cnt = {
- "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0,
- "total_num": 0, "forward_or_backward_fail_num": 0
- }
-
+ @staticmethod
+ def print_pretest_result():
+ logger.info("Successfully completed run_ut/multi_run_ut.")
+
@staticmethod
def _compare_dropout(bench_output, device_output):
tensor_num = bench_output.numel()
@@ -77,80 +73,6 @@ class Comparator:
rtol = apis_threshold.get(api_name).get(dtype).get('rtol')
return small_value_threshold, small_value_atol, rtol
- def print_pretest_result(self):
- self.get_statistics_from_result_csv()
- total_tests = self.test_result_cnt.get("total_num", 0)
- if total_tests != 0:
- passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests)
- else:
- passing_rate = "0%"
-
- logger.warning("The follwing tables will be deprecated in the future."
- "The following results are for reference only.")
- console = Console()
- table_total = Table(
- show_header=True, title="Overall Statistics", show_lines=True, width=75
- )
- table_total.add_column("Result")
- table_total.add_column("Statistics")
- table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0)))
- table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0)))
- table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0)))
- table_total.add_row("Passing Rate", passing_rate)
- table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0)))
-
- table_detail = Table(
- show_header=True, title="Detail Statistics", show_lines=True, width=75
- )
- table_detail.add_column("Result")
- table_detail.add_column("Statistics")
- table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0)))
- table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0)))
- table_detail.add_row("Both Forward & Backward Error",
- str(self.test_result_cnt.get("forward_and_backward_fail_num", 0)))
-
- console.print(table_total)
- console.print(table_detail)
-
- def get_statistics_from_result_csv(self):
- checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP,
- "skip"]
- self.test_result_cnt = {
- "success_num": 0, "warning_num": 0, "error_num": 0,
- "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0,
- "total_num": 0, "total_skip_num": 0
- }
- with FileOpen(self.save_path, 'r') as file:
- reader = csv.reader(file)
- result_csv_rows = [row for row in reader]
- result_csv_name = os.path.basename(self.save_path)
- for item in result_csv_rows[1:]:
- if not isinstance(item, list) or len(item) < 3:
- raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
- if not all(item[i] and item[i] in checklist for i in (1, 2)):
- raise ValueError(
- "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE"
- % result_csv_name)
- column1 = item[1]
- column2 = item[2]
- if column1.upper() == CompareConst.SKIP:
- self.test_result_cnt["total_skip_num"] += 1
- continue
- self.test_result_cnt["total_num"] += 1
- if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.SPACE]:
- self.test_result_cnt['success_num'] += 1
- elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR:
- self.test_result_cnt['forward_and_backward_fail_num'] += 1
- self.test_result_cnt['error_num'] += 1
- elif column1 == CompareConst.ERROR:
- self.test_result_cnt['forward_fail_num'] += 1
- self.test_result_cnt['error_num'] += 1
- elif column2 == CompareConst.ERROR:
- self.test_result_cnt['backward_fail_num'] += 1
- self.test_result_cnt['error_num'] += 1
- elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING:
- self.test_result_cnt['warning_num'] += 1
-
def write_csv_title(self):
summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
self.COLUMN_BACKWARD_SUCCESS, "Message"]]
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py
index bd88d6742f3a26f082f735d514023d74a49c7541..a018b192753d060680344fc4330e8a3427d4e6a5 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py
@@ -1,4 +1,4 @@
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst
+from atat.core.common.const import CompareConst
class CompareColumn:
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py
index fe841eb06397443f7ef7d89edc6852d05f7579f5..bcf6b8ea196f5e14085d0200ba484e10017a6059 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py
@@ -3,7 +3,8 @@ import os
import numpy as np
import torch
import yaml
-from atat.core.common.utils import Const, CompareException
+from atat.core.common.utils import CompareException
+from atat.core.common.const import Const
from atat.pytorch.common.log import logger
from atat.core.common.file_check import FileOpen
@@ -77,21 +78,6 @@ precision_configs = {
}
}
-
-class CompareConst:
- NAN = np.nan
- NA = "N/A"
- PASS = 'pass'
- WARNING = 'warning'
- ERROR = 'error'
- SKIP = 'SKIP'
- TRUE = 'TRUE'
- FALSE = 'FALSE'
- BFLOAT16_MIN = -3.3895313892515355e+38
- BFLOAT16_MAX = 3.3895313892515355e+38
- BFLOAT16_EPS = 2 ** -8
- SPACE = " "
-
class ApiPrecisionCompareColumn:
API_NAME = 'API Name'
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py
index e983413bf00ea32aaea4d6ee7e215097e9f12cff..c6b721eee239fa943861d58ce31c1d8e9145bda9 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py
@@ -20,8 +20,9 @@ import math
import torch
import numpy
-from atat.pytorch.api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, get_full_data_path, CompareException
+from atat.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path, check_object_type, get_full_data_path, CompareException
from atat.pytorch.common.log import logger
+from atat.core.common.const import Const
TORCH_TYPE = ["torch.device", "torch.dtype"]
TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
index b9d1a4fd1f3da5c4f115bb8542c9f9150c4d5f35..d2ab9c1e952001f4551044f4395662dd25931e08 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -13,9 +13,10 @@ from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_
get_validated_details_csv_path, preprocess_forward_content
from atat.pytorch.api_accuracy_checker.compare.compare import Comparator
from atat.pytorch.common import parse_json_info_forward_backward
-from atat.core.common.file_check import FileCheckConst, FileChecker, check_file_suffix, check_link, FileOpen, \
+from atat.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \
check_path_before_create, create_directory
from atat.pytorch.common.log import logger
+from atat.core.common.const import FileCheckConst
def split_json_file(input_file, num_splits, filter_api):
diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py
index 77f3bf714a37eede95de04fdf1240ff655451e3a..47cbd9944702530680f31029bafb2ca1705b287a 100644
--- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py
@@ -27,10 +27,10 @@ from atat.pytorch.hook_module.wrap_functional import FunctionalOPTemplate
from atat.pytorch.hook_module.wrap_torch import TorchOPTemplate
from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig
from atat.pytorch.common.parse_json import parse_json_info_forward_backward
-from atat.core.common.file_check import FileOpen, FileCheckConst, FileChecker, \
+from atat.core.common.file_check import FileOpen, FileChecker, \
change_mode, check_file_suffix, check_link, check_path_before_create, create_directory
from atat.pytorch.common.log import logger
-from atat.pytorch.common.utils import Const
+from atat.core.common.const import Const, FileCheckConst
current_time = time.strftime("%Y%m%d%H%M%S")
UT_ERROR_DATA_DIR = 'ut_error_data' + current_time
@@ -40,7 +40,11 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content',
'save_error_data', 'is_continue_run_ut', 'real_data_path'])
not_backward_list = ['repeat_interleave']
not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
-
+RAISE_PRECISION = {
+ torch.float16: torch.float32,
+ torch.bfloat16: torch.float32,
+ torch.float32: torch.float64
+}
tqdm_params = {
'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1
'desc': 'Processing', # 进度条前的描述文字
@@ -75,7 +79,7 @@ def deal_detach(arg, to_detach=True):
def deal_dtype(arg, raise_dtype=None):
- if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype:
+ if raise_dtype is None or arg.dtype not in RAISE_PRECISION or raise_dtype == arg.dtype:
return arg
return arg.type(raise_dtype)
@@ -120,7 +124,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
return arg_in
def is_tensor_with_raise_precision(arg_in, check_kwargs=False):
- if arg_in.dtype in Const.RAISE_PRECISION:
+ if arg_in.dtype in RAISE_PRECISION:
return True
if check_kwargs and arg_in.dtype in [torch.half, torch.bfloat16]:
return True
@@ -139,7 +143,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
need_raise_dtypes = recursive_find_dtypes(input_args)
need_raise_dtypes.update(recursive_find_dtypes(input_kwargs, check_kwargs=True))
if len(need_raise_dtypes) == 1:
- raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
+ raise_dtype = RAISE_PRECISION.get(need_raise_dtypes.pop(), torch.float32)
elif len(need_raise_dtypes) >= 2:
raise_dtype = torch.float32
diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
index 2d7bdcfff34087c38c783b20dfaafa82977a178d..061c9cdfca872bb76ad8ffa0e292fb24ce1a4ada 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
+++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py
@@ -32,9 +32,10 @@ from atat.pytorch.compare.highlight import HighlightRules, get_header_index
from atat.pytorch.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, get_error_message
from atat.pytorch.advisor.advisor import Advisor
from atat.pytorch.common.log import logger
-from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, CompareConst, \
- format_value, check_file_not_exists, check_configuration_param, task_dumppath_get, Const
-from atat.core.common.file_check import FileChecker, FileCheckConst, change_mode, FileOpen, create_directory
+from atat.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \
+ format_value, check_file_not_exists, check_configuration_param, task_dumppath_get
+from atat.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory
+from atat.core.common.const import Const, CompareConst, FileCheckConst
def check_graph_mode(a_op_name, b_op_name):
diff --git a/debug/accuracy_tools/atat/pytorch/compare/highlight.py b/debug/accuracy_tools/atat/pytorch/compare/highlight.py
index d94e86b013b9a89c1fe0c412db612ccef3e51991..3a6898dedbb6910d0c1c9e55f80b55eb4fa0ed3c 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/highlight.py
+++ b/debug/accuracy_tools/atat/pytorch/compare/highlight.py
@@ -1,7 +1,8 @@
import math
import abc
import numpy as np
-from atat.core.common.utils import CompareConst, get_header_index
+from atat.core.common.utils import get_header_index
+from atat.core.common.const import CompareConst
class HighlightCheck(abc.ABC):
diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py
index 2e1f22ab3f5243fa4bb6b2e9bb6e7094680081b2..0cf4c6c00a0a671a6ee46dc6dacce48af6e67adf 100644
--- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py
+++ b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py
@@ -1,6 +1,7 @@
import abc
import numpy as np
-from atat.core.common.utils import CompareConst, Const, format_value
+from atat.core.common.utils import format_value
+from atat.core.common.const import Const, CompareConst
from atat.pytorch.common.log import logger
diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
index 6f2bfe8551062e82d552661f186130195b41f4dd..1ad69701e4167db3e2e9f61f7f725f024ec21d16 100644
--- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
+++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py
@@ -1,6 +1,6 @@
from atat.pytorch.common import seed_all
from atat.pytorch.common.log import logger
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
class DebuggerConfig:
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py
index f86fc41d557b1801318303ce18a934b2306c223e..b9d41330a87605af37f7b538e8d7bcf56f9725f1 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py
@@ -1,6 +1,6 @@
from atat.core.common.log import logger
from atat.core.common.exceptions import FreeBenchmarkException
-from atat.pytorch.common.utils import Const
+from atat.core.common.const import Const
from .main import FreeBenchmarkCheck
from .common.params import UnequalRow
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
index d7c2eba377fb228b1ed02f49d9d540108a160ea6..2ebc0a6db917334f492019fa694330d86a0f37e1 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py
@@ -1,7 +1,8 @@
from abc import ABC
import torch
-from atat.pytorch.free_benchmark import Const, logger
+from atat.core.common.const import Const
+from atat.pytorch.free_benchmark import logger
from atat.pytorch.free_benchmark.common.constant import CommonField
from atat.pytorch.free_benchmark.common.enums import (
DeviceType,
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
index 2df26afc1beecc21a8a77bddbe76de6142d68862..03718e3c4d6c4c4ea28ae6eec5daddc02bcedb7d 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py
@@ -1,5 +1,6 @@
import torch
-from atat.pytorch.free_benchmark import Const, logger
+from atat.core.common.const import Const
+from atat.pytorch.free_benchmark import logger
from atat.pytorch.free_benchmark.common.constant import CommonField
from atat.pytorch.free_benchmark.common.enums import PerturbationMode
from atat.pytorch.free_benchmark.common.params import DataParams
diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
index c1bfbae24a5116dc07fdbc97cc5b7691a7003e07..c57d7e390a0fad73a5e87d72ec72e434835618cf 100644
--- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
+++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py
@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
import torch
-from atat.pytorch.free_benchmark import Const, logger
+from atat.core.common.const import Const
+from atat.pytorch.free_benchmark import logger
from atat.pytorch.free_benchmark.common.constant import ThresholdConfig
from atat.pytorch.free_benchmark.common.enums import (
FuzzThreshold,
diff --git a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py
index 8652f13f9bcaa46d13fc19bfe74c796ac45cdadf..675fa2a1bfdfdef9b12bf99f2428e3236c86a906 100644
--- a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py
+++ b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py
@@ -1,6 +1,6 @@
import torch.nn as nn
from atat.pytorch.common.log import logger
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
from atat.pytorch.hook_module.api_registry import api_register
from atat.pytorch.debugger.precision_debugger import PrecisionDebugger
from atat.core.common.exceptions import MsaccException
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py
index 6910276f94462018e09cbd3ae865cecff0d0cc1f..3b971cc71ecaf1d229337f4acf5afab6b5b0f9db 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py
@@ -25,7 +25,8 @@ from atat.pytorch.hook_module.wrap_functional import get_functional_ops
from atat.pytorch.hook_module.wrap_tensor import get_tensor_ops
from atat.pytorch.hook_module.wrap_torch import get_torch_ops
from atat.pytorch.hook_module.wrap_vf import get_vf_ops
-from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu, Const
+from atat.pytorch.common.utils import torch_without_guard_version, npu_distributed_api, is_gpu
+from atat.core.common.const import Const
torch_version_above_2 = torch.__version__.split('+')[0] > '2.0'
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py
index d45a951d479cd235aa6c29e45443d6b22e56dbc9..57212b6e45c572db3d3a235035647bbb614a3cb3 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py
@@ -20,7 +20,7 @@ import threading
import torch
import torch.nn as nn
import torch.utils.hooks as full_hooks
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
class HOOKModule(nn.Module):
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py
index c247a27082edf9af738e48706c6308d4916fc586..c5a3c6365d1841927c3acd2993ac2828092ae80b 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py
@@ -21,7 +21,8 @@ import torch
import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, Const
+from atat.pytorch.common.utils import torch_device_guard
+from atat.core.common.const import Const
from atat.core.common.file_check import FileOpen
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py
index 1059bf748843ae381e48a1c7103811ad71af83c2..e02189ac1bf2a5d2e1607f2625200eb4ee2ea6e8 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py
@@ -21,7 +21,8 @@ import torch.distributed as dist
import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, Const
+from atat.pytorch.common.utils import torch_device_guard
+from atat.core.common.const import Const
from atat.core.common.file_check import FileOpen
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py
index 8c829904cbe848b7e8d57abb1f5f3a2c0bc6d494..fa97f5ee3106c62ab63d80e2b2ebe349494ce695 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py
@@ -21,7 +21,8 @@ import torch
import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, Const
+from atat.pytorch.common.utils import torch_device_guard
+from atat.core.common.const import Const
from atat.pytorch.common.log import logger
from atat.core.common.file_check import FileOpen
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py
index 90ad9cb9c4f3865a7dd110dcd5701acdcaf5ce64..7d0882804f478d8657faeb5f90b0b718914d72e5 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py
@@ -21,7 +21,8 @@ import torch_npu
import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version, Const
+from atat.pytorch.common.utils import torch_device_guard, torch_without_guard_version
+from atat.core.common.const import Const
from atat.core.common.file_check import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py
index d53291b78faa594b69ac686100a3f48eccce4dc0..6fac18140238e016004c2329f0d8ff647045c0c0 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py
@@ -21,7 +21,8 @@ import torch
import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, parameter_adapter, Const
+from atat.pytorch.common.utils import torch_device_guard, parameter_adapter
+from atat.core.common.const import Const
from atat.core.common.file_check import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py
index 3cdece23065a7ab830c6125929c7cd4e86bab711..f0bd01fe4624ccc1a50d1f3e1fb5d614b30a7a8d 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py
@@ -21,7 +21,8 @@ import torch
import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
-from atat.pytorch.common.utils import torch_device_guard, Const
+from atat.pytorch.common.utils import torch_device_guard
+from atat.core.common.const import Const
from atat.core.common.file_check import FileOpen
cur_path = os.path.dirname(os.path.realpath(__file__))
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py
index c5f3cb7ee0624757b617b049935b6aabc593ec8c..d4c570221d4b219be168b25dafd76264769da197 100644
--- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py
+++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py
@@ -22,7 +22,8 @@ import yaml
from atat.pytorch.hook_module.hook_module import HOOKModule
from atat.core.common.file_check import FileOpen
-from atat.pytorch.common.utils import torch_device_guard, Const
+from atat.pytorch.common.utils import torch_device_guard
+from atat.core.common.const import Const
cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py
index f56513907c262fbc43d1ce76aace271e04caf944..8ce9140e32cbc9f7d2d62eede164e24127a40f34 100644
--- a/debug/accuracy_tools/atat/pytorch/module_processer.py
+++ b/debug/accuracy_tools/atat/pytorch/module_processer.py
@@ -1,7 +1,7 @@
from functools import wraps
import torch
from torch.utils.hooks import BackwardHook
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
from atat.core.data_dump.scope import ModuleRangeScope
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py
index d7f9e4e339d40bd9863b6040430948cf2379d91b..e6d55ca0614767338023673b6d0cfa5c7d990c0f 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py
+++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/compare.py
@@ -7,7 +7,8 @@ from collections import namedtuple
from rich.table import Table
from rich.console import Console
from .single_compare import single_benchmark_compare_wrap
-from .utils import DispatchException, CompareConst
+from .utils import DispatchException
+from atat.core.common.const import CompareConst
from atat.core.common.file_check import FileOpen
from atat.pytorch.common.log import logger
from atat.core.common.utils import CompareException
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py
index 386c3eac1797e177237e9a62389ce3567b982e31..7502d746acf38a52b427659b8d97f5f55a4ce96c 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py
+++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/dispatch.py
@@ -22,8 +22,8 @@ from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logge
DispatchException
from .compare import Comparator
from atat.core.common.file_check import FileOpen
-from atat.pytorch.common.utils import Const
-from atat.core.common.utils import CompareConst, check_file_or_directory_path, check_path_before_create
+from atat.core.common.utils import check_file_or_directory_path, check_path_before_create
+from atat.core.common.const import Const, CompareConst
current_time = time.strftime("%Y%m%d%H%M%S")
RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv"
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py
index b8d824fd1eafbce9603b18535500d21127417fff..cd7c5a3f282d17b4e064a4b00928861b2e32f737 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py
+++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/dump_compare.py
@@ -5,11 +5,10 @@ from datetime import datetime, timezone
import pandas as pd
import torch
-from atat.pytorch.common.utils import Const
from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \
COLOR_RESET, CSV_COLUMN_NAME
-from atat.core.common.file_check import FileOpen, change_mode, FileCheckConst
-from atat.core.common.utils import CompareConst
+from atat.core.common.file_check import FileOpen, change_mode
+from atat.core.common.const import CompareConst, FileCheckConst, Const
from atat.pytorch.common.log import logger
class DispatchRunParam:
diff --git a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py
index 1f9c2e916c187615160bfb1be64a262b2cd6bd95..f3fcffb6f26adbd2b2ff6b77da720f7ecdfcb0ca 100644
--- a/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py
+++ b/debug/accuracy_tools/atat/pytorch/online_dispatch/utils.py
@@ -12,8 +12,8 @@ except ImportError:
else:
pta_cpu_device = torch.device("cpu")
-from atat.core.common.utils import CompareConst
-from atat.core.common.file_check import change_mode, FileCheckConst
+from atat.core.common.const import CompareConst, FileCheckConst
+from atat.core.common.file_check import change_mode
cpu_device = torch._C.device("cpu")
COLOR_RED = '\033[31m'
@@ -58,17 +58,6 @@ BOOL_TYPE = [bool, np.uint8]
INT_TYPE = [np.int32, np.int64]
-class CompareConst:
- NAN = np.nan
- NA = "N/A"
- PASS = 'pass'
- WARNING = 'warning'
- ERROR = 'error'
- SKIP = 'SKIP'
- TRUE = 'TRUE'
- FALSE = 'FALSE'
-
-
def get_callstack():
callstack = []
for (_, path, line, func, code, _) in inspect.stack()[2:]:
diff --git a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py
index aeb1d7f6d2f7d3a0488822bfd7859633dfd70366..ce42d242ba29ef4dcf7f2af042242b1b9987de42 100644
--- a/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py
+++ b/debug/accuracy_tools/atat/pytorch/parse_tool/lib/utils.py
@@ -30,7 +30,7 @@ from atat.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, FileDesc
from atat.pytorch.parse_tool.lib.parse_exception import ParseException
from atat.core.common.file_check import change_mode, check_other_user_writable,\
check_path_executable, check_path_owner_consistent
-from atat.core.common.file_check import FileCheckConst
+from atat.core.common.const import FileCheckConst
from atat.core.common.file_check import FileOpen
from atat.core.common.utils import check_file_or_directory_path
from atat.pytorch.common.log import logger
diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py
index e04b88bb96633e54bf6c0e085247907a29752503..0674b91b3410765ca7dcb9a6f38c6f03fa6e94f1 100644
--- a/debug/accuracy_tools/atat/pytorch/pt_config.py
+++ b/debug/accuracy_tools/atat/pytorch/pt_config.py
@@ -3,7 +3,7 @@ import os
from atat.core.common_config import CommonConfig, BaseConfig
from atat.core.common.file_check import FileOpen
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
class TensorConfig(BaseConfig):
diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py
index cd80d0852ac7526615b0a14f059803ef6222ed7f..d0b9c4d4b27bd13672e0529de6dd3349860e533f 100644
--- a/debug/accuracy_tools/atat/pytorch/service.py
+++ b/debug/accuracy_tools/atat/pytorch/service.py
@@ -3,8 +3,8 @@ import os
from pathlib import Path
from atat.pytorch.common.log import logger
-from atat.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create
-from atat.core.common.utils import Const
+from atat.core.common.file_check import FileChecker, check_path_before_create
+from atat.core.common.const import Const, FileCheckConst
from atat.core.common.exceptions import DistributedNotInitializedError, MsaccException
from atat.core.data_dump.data_collector import build_data_collector
from atat.core.data_dump.scope import BaseScope
diff --git a/debug/accuracy_tools/atat/test/core_ut/test_file_check.py b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa7882aa5906b8e472495ede42653f9f5584f573
--- /dev/null
+++ b/debug/accuracy_tools/atat/test/core_ut/test_file_check.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+import os
+
+from unittest import TestCase
+from unittest.mock import patch, MagicMock
+
+from atat.core.common.log import logger
+from atat.core.common.const import FileCheckConst
+from atat.core.common.exceptions import FileCheckException
+from atat.core.common.file_check import (check_link,
+ check_path_length,
+ check_path_exists,
+ check_path_readability,
+ check_path_writability,
+ check_path_executable,
+ check_other_user_writable,
+ check_path_owner_consistent,
+ check_path_pattern_vaild,
+ check_file_size,
+ check_common_file_size,
+ check_file_suffix,
+ check_path_type)
+
+
+class TestFileCheckUtil(TestCase):  # Failure-path (and a few success-path) tests for the validators in atat.core.common.file_check.
+    @patch.object(logger, "error")
+    def test_check_link(self, mock_logger_error):  # os.path.islink patched to True -> SOFT_LINK_ERROR message and an error log.
+        with patch("atat.core.common.file_check.os.path.islink", return_value=True):
+            with self.assertRaises(FileCheckException) as context:
+                check_link("link_path")
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.SOFT_LINK_ERROR))
+            mock_logger_error.assert_called_with("The file path link_path is a soft link.")
+
+    @patch.object(logger, "error")
+    def test_check_path_length(self, mock_logger_error):  # Three over-length inputs; each must raise with the ILLEGAL_PATH_ERROR message.
+        path = "P" * (FileCheckConst.DIRECTORY_LENGTH + 1)  # one char longer than DIRECTORY_LENGTH
+        with self.assertRaises(FileCheckException) as context:
+            check_path_length(path)
+        self.assertEqual(str(context.exception),
+                         FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR))
+        mock_logger_error.assert_called_with("The file path length exceeds limit.")
+
+        path = "P" * (FileCheckConst.FILE_NAME_LENGTH + 1)  # one char longer than FILE_NAME_LENGTH
+        with self.assertRaises(FileCheckException) as context:
+            check_path_length(path)
+        self.assertEqual(str(context.exception),
+                         FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR))
+        mock_logger_error.assert_called_with("The file path length exceeds limit.")
+
+        path = "P" * (FileCheckConst.FILE_NAME_LENGTH - 5)  # below FILE_NAME_LENGTH, but above the custom name_length passed below
+        with self.assertRaises(FileCheckException) as context:
+            check_path_length(path, name_length=FileCheckConst.FILE_NAME_LENGTH - 6)
+        self.assertEqual(str(context.exception),
+                         FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR))
+        mock_logger_error.assert_called_with("The file path length exceeds limit.")
+
+    @patch.object(logger, "error")
+    def test_check_path_exists(self, mock_logger_error):  # os.path.exists patched to False -> ILLEGAL_PATH_ERROR message and an error log.
+        with patch("atat.core.common.file_check.os.path.exists", return_value=False):
+            with self.assertRaises(FileCheckException) as context:
+                check_path_exists("file_path")
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR))
+            mock_logger_error.assert_called_with("The file path file_path does not exist.")
+
+    @patch.object(logger, "error")
+    def test_check_path_readability(self, mock_logger_error):  # os.access False -> permission error; os.access True -> no raise.
+        path = "file_path"
+        with patch("atat.core.common.file_check.os.access", return_value=False):
+            with self.assertRaises(FileCheckException) as context:
+                check_path_readability(path)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR))
+            mock_logger_error.assert_called_with(f"The file path {path} is not readable.")
+
+        mock_access = MagicMock()
+        mock_access.return_value = True
+        with patch("atat.core.common.file_check.os.access", new=mock_access):
+            check_path_readability(path)
+        self.assertEqual(mock_access.call_args[0], (path, os.R_OK))  # the last os.access call must have used the read flag
+
+    @patch.object(logger, "error")
+    def test_check_path_writability(self, mock_logger_error):  # Same shape as the readability test, for the write flag.
+        path = "file_path"
+        with patch("atat.core.common.file_check.os.access", return_value=False):
+            with self.assertRaises(FileCheckException) as context:
+                check_path_writability(path)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR))
+            mock_logger_error.assert_called_with(f"The file path {path} is not writable.")
+
+        mock_access = MagicMock()
+        mock_access.return_value = True
+        with patch("atat.core.common.file_check.os.access", new=mock_access):
+            check_path_writability(path)
+        self.assertEqual(mock_access.call_args[0], (path, os.W_OK))  # the last os.access call must have used the write flag
+
+    @patch.object(logger, "error")
+    def test_check_path_executable(self, mock_logger_error):  # Same shape as the readability test, for the execute flag.
+        path = "file_path"
+        with patch("atat.core.common.file_check.os.access", return_value=False):
+            with self.assertRaises(FileCheckException) as context:
+                check_path_executable(path)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR))
+            mock_logger_error.assert_called_with(f"The file path {path} is not executable.")
+
+        mock_access = MagicMock()
+        mock_access.return_value = True
+        with patch("atat.core.common.file_check.os.access", new=mock_access):
+            check_path_executable(path)
+        self.assertEqual(mock_access.call_args[0], (path, os.X_OK))  # the last os.access call must have used the execute flag
+
+    @patch.object(logger, "error")
+    def test_check_other_user_writable(self, mock_logger_error):  # A stat mode with the others-write bit set must be rejected.
+        class TestStat:  # stand-in for the os.stat result, carrying only st_mode
+            def __init__(self, mode):
+                self.st_mode = mode
+
+        path = "file_path"
+        mock_stat = TestStat(0o002)  # 0o002 is the "writable by others" permission bit
+        with patch("atat.core.common.file_check.os.stat", return_value=mock_stat):
+            with self.assertRaises(FileCheckException) as context:
+                check_other_user_writable(path)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR))
+            mock_logger_error.assert_called_with(f"The file path {path} may be insecure "
+                                                 "because other users have write permissions. ")
+
+    @patch.object(logger, "error")
+    def test_check_path_owner_consistent(self, mock_logger_error):  # os.getuid patched to differ from this file's owner uid -> permission error.
+        file_path = os.path.realpath(__file__)  # uses this test file itself, so a real os.stat result is available
+        file_owner = os.stat(file_path).st_uid
+        with patch("atat.core.common.file_check.os.getuid", return_value=file_owner+1):
+            with self.assertRaises(FileCheckException) as context:
+                check_path_owner_consistent(file_path)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.FILE_PERMISSION_ERROR))
+            mock_logger_error.assert_called_with(f"The file path {file_path} may be insecure "
+                                                 "because is does not belong to you.")  # NOTE(review): "is" is presumably a typo for "it" in the logged message; fix it there first
+
+    @patch.object(logger, "error")
+    def test_check_path_pattern_vaild(self, mock_logger_error):  # re.match patched to return False -> ILLEGAL_PATH_ERROR message.
+        path = "path"
+        mock_re_match = MagicMock()
+        mock_re_match.return_value = False
+        with patch("atat.core.common.file_check.re.match", new=mock_re_match):
+            with self.assertRaises(FileCheckException) as context:
+                check_path_pattern_vaild(path)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.ILLEGAL_PATH_ERROR))
+            mock_logger_error.assert_called_with(f"The file path {path} contains special characters.")
+            mock_re_match.assert_called_with(FileCheckConst.FILE_VALID_PATTERN, path)  # the pattern must come from FileCheckConst
+
+    @patch.object(logger, "error")
+    def test_check_file_size(self, mock_logger_error):  # Runs against this test file on disk, without mocking the size lookup.
+        file_path = os.path.realpath(__file__)
+        file_size = os.path.getsize(file_path)
+        max_size = file_size  # limit equals the actual size; the test expects this boundary value to be rejected
+        with self.assertRaises(FileCheckException) as context:
+            check_file_size(file_path, max_size)
+        self.assertEqual(str(context.exception),
+                         FileCheckException.err_strs.get(FileCheckException.FILE_TOO_LARGE_ERROR))
+        mock_logger_error.assert_called_with(f"The size of file path {file_path} exceeds {max_size} bytes.")
+
+    def test_check_common_file_size(self):  # Each FILE_SIZE_DICT entry must be forwarded to check_file_size with its own limit.
+        mock_check_file_size = MagicMock()
+        with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \
+                patch("atat.core.common.file_check.check_file_size", new=mock_check_file_size):
+            for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
+                check_common_file_size(suffix)  # the dict key itself is passed as the file path
+                mock_check_file_size.assert_called_with(suffix, max_size)
+
+    @patch.object(logger, "error")
+    def test_check_file_suffix(self, mock_logger_error):  # "file_path" does not end with "suffix" -> INVALID_FILE_ERROR message.
+        file_path = "file_path"
+        suffix = "suffix"
+        with self.assertRaises(FileCheckException) as context:
+            check_file_suffix(file_path, suffix)
+        self.assertEqual(str(context.exception),
+                         FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR))
+        mock_logger_error.assert_called_with(f"The {file_path} should be a {suffix} file!")
+
+    @patch.object(logger, "error")
+    def test_check_path_type(self, mock_logger_error):  # A type mismatch in either direction must raise with the INVALID_FILE_ERROR message.
+        file_path = "file_path"
+
+        with patch("atat.core.common.file_check.os.path.isfile", return_value=False), \
+                patch("atat.core.common.file_check.os.path.isdir", return_value=True):  # path looks like a directory, a file is required
+            with self.assertRaises(FileCheckException) as context:
+                check_path_type(file_path, FileCheckConst.FILE)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR))
+            mock_logger_error.assert_called_with(f"The {file_path} should be a file!")
+
+        with patch("atat.core.common.file_check.os.path.isfile", return_value=True), \
+                patch("atat.core.common.file_check.os.path.isdir", return_value=False):  # path looks like a file, a directory is required
+            with self.assertRaises(FileCheckException) as context:
+                check_path_type(file_path, FileCheckConst.DIR)
+            self.assertEqual(str(context.exception),
+                             FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR))
+            mock_logger_error.assert_called_with(f"The {file_path} should be a dictionary!")  # NOTE(review): "dictionary" presumably should read "directory" in the logged message; fix it there first
diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py
index 89734f2c572bff2aa864db16f23dfe8665042f74..b3273358e43593e14f931a5675e9039c684d1c59 100644
--- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py
+++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py
@@ -1,13 +1,52 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+import os
+import uuid
+
from unittest import TestCase
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock, mock_open
-from atat.core.common.utils import check_seed_all, Const, CompareException, check_inplace_op
from atat.core.common.log import logger
+from atat.core.common.const import Const
+from atat.core.common.utils import (CompareException,
+ check_seed_all,
+ check_inplace_op,
+ make_dump_path_if_not_exists,
+ check_mode_valid,
+ check_switch_valid,
+ check_dump_mode_valid,
+ check_summary_mode_valid,
+ check_summary_only_valid,
+ check_file_or_directory_path,
+ check_compare_param,
+ check_configuration_param,
+ is_starts_with,
+ _check_json,
+ check_json_file,
+ check_file_size,
+ check_regex_prefix_format_valid,
+ get_dump_data_path,
+ task_dumppath_get)
+from atat.core.common.file_check import FileCheckConst
class TestUtils(TestCase):
@patch.object(logger, "error")
- def test_check_seed_all(self, mock_print_error_log):
+ def test_check_seed_all(self, mock_error):
self.assertIsNone(check_seed_all(1234, True))
self.assertIsNone(check_seed_all(0, True))
self.assertIsNone(check_seed_all(Const.MAX_SEED_VALUE, True))
@@ -15,23 +54,23 @@ class TestUtils(TestCase):
with self.assertRaises(CompareException) as context:
check_seed_all(-1, True)
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
- mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
+ mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
with self.assertRaises(CompareException) as context:
check_seed_all(Const.MAX_SEED_VALUE + 1, True)
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
- mock_print_error_log.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
+ mock_error.assert_called_with(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.")
with self.assertRaises(CompareException) as context:
check_seed_all("1234", True)
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
- mock_print_error_log.assert_called_with("Seed must be integer.")
+ mock_error.assert_called_with("Seed must be integer.")
with self.assertRaises(CompareException) as context:
check_seed_all(1234, 1)
self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
- mock_print_error_log.assert_called_with("seed_all mode must be bool.")
-
+ mock_error.assert_called_with("seed_all mode must be bool.")
+
def test_check_inplace_op(self):
test_prefix_1 = "Distributed.broadcast.0.forward.input.0"
self.assertTrue(check_inplace_op(test_prefix_1))
@@ -39,3 +78,268 @@ class TestUtils(TestCase):
self.assertFalse(check_inplace_op(test_prefix_2))
test_prefix_3 = "Torch.sum.0.backward.output.0"
self.assertFalse(check_inplace_op(test_prefix_3))
+
+    @patch.object(logger, "error")
+    def test_make_dump_path_if_not_exists(self, mock_error):  # Covers a failing mkdir and the "exists but is not a directory" log.
+        file_path = os.path.realpath(__file__)
+        dirname = os.path.dirname(file_path) + str(uuid.uuid4())  # this directory's path with a random uuid appended, expected not to exist
+
+        def test_mkdir(self, **kwargs):  # replacement for Path.mkdir that always fails
+            raise OSError
+
+        if not os.path.exists(dirname):  # the mkdir-failure case is silently skipped if the random path already exists
+            with patch("atat.core.common.utils.Path.mkdir", new=test_mkdir):
+                with self.assertRaises(CompareException) as context:
+                    make_dump_path_if_not_exists(dirname)
+                self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
+
+        make_dump_path_if_not_exists(file_path)  # an existing regular file: the test expects an error log and no exception
+        mock_error.assert_called_with(f"{file_path} already exists and is not a directory.")
+
+    def test_check_mode_valid(self):  # Invalid scope/api_list types, an unknown mode, and "list" mode without a scope must all raise.
+        with self.assertRaises(ValueError) as context:
+            check_mode_valid("all", scope="scope")  # scope given as a str instead of a list
+        self.assertEqual(str(context.exception), "scope param set invalid, it's must be a list.")
+
+        with self.assertRaises(ValueError) as context:
+            check_mode_valid("all", api_list="api_list")  # api_list given as a str instead of a list
+        self.assertEqual(str(context.exception), "api_list param set invalid, it's must be a list.")
+
+        mode = "all_list"  # a mode name the test expects to be unsupported
+        with self.assertRaises(CompareException) as context:
+            check_mode_valid(mode)
+        self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_MODE)
+        self.assertEqual(str(context.exception),
+                         f"Current mode '{mode}' is not supported. Please use the field in {Const.DUMP_MODE}")
+
+        mode = "list"  # called without a scope argument; the expected error complains about an empty scope list
+        with self.assertRaises(ValueError) as context:
+            check_mode_valid(mode)
+        self.assertEqual(str(context.exception),
+                         "set_dump_switch, scope param set invalid, it's should not be an empty list.")
+
+    @patch.object(logger, "error")
+    def test_check_switch_valid(self, mock_error):  # "Close" is not an accepted switch value -> INVALID_PARAM_ERROR and an error log.
+        with self.assertRaises(CompareException) as context:
+            check_switch_valid("Close")
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_error.assert_called_with("Please set switch with 'ON' or 'OFF'.")
+
+    @patch.object(logger, "warning")
+    def test_check_dump_mode_valid(self, mock_warning):  # A plain "all" string is accepted with a warning; an unknown string raises.
+        dump_mode = check_dump_mode_valid("all")  # str instead of list: expected to warn and return the four-item list below
+        mock_warning.assert_called_with("Please set dump_mode as a list.")
+        self.assertEqual(dump_mode, ["forward", "backward", "input", "output"])
+
+        with self.assertRaises(ValueError) as context:
+            check_dump_mode_valid("all_forward")  # not one of the names listed in the expected error message
+        self.assertEqual(str(context.exception),
+                         "Please set dump_mode as a list containing one or more of the following: " +
+                         "'all', 'forward', 'backward', 'input', 'output'.")
+
+    def test_check_summary_mode_valid(self):  # The value "MD5" is expected to be rejected with INVALID_SUMMARY_MODE.
+        with self.assertRaises(CompareException) as context:
+            check_summary_mode_valid("MD5")
+        self.assertEqual(context.exception.code, CompareException.INVALID_SUMMARY_MODE)
+        self.assertEqual(str(context.exception), "The summary_mode is not valid")
+
+    @patch.object(logger, "error")
+    def test_check_summary_only_valid(self, mock_error):  # A bool passes through; the string "True" is rejected.
+        summary_only = check_summary_only_valid(True)
+        self.assertTrue(summary_only)
+
+        with self.assertRaises(CompareException) as context:
+            check_summary_only_valid("True")  # a str, not a bool
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_error.assert_called_with("Params summary_only only support True or False.")
+
+    def test_check_file_or_directory_path(self):  # Checks which arguments reach FileChecker for isdir=False and isdir=True.
+        class TestFileChecker:  # fake FileChecker that records its constructor arguments on the class itself
+            file_path = ""
+            path_type = ""
+            ability = ""
+            checked = False  # set to True once common_check() has been called
+
+            def __init__(self, file_path, path_type, ability=None):
+                TestFileChecker.file_path = file_path
+                TestFileChecker.path_type = path_type
+                TestFileChecker.ability = ability
+
+            def common_check(self):
+                TestFileChecker.checked = True
+
+        file_path = os.path.realpath(__file__)
+        dirname = os.path.dirname(file_path)
+
+        with patch("atat.core.common.utils.FileChecker", new=TestFileChecker):
+            check_file_or_directory_path(file_path, isdir=False)
+        self.assertTrue(TestFileChecker.checked)
+        self.assertEqual(TestFileChecker.file_path, file_path)
+        self.assertEqual(TestFileChecker.path_type, FileCheckConst.FILE)
+        self.assertEqual(TestFileChecker.ability, FileCheckConst.READ_ABLE)  # file case is expected to request read access
+
+        TestFileChecker.checked = False  # reset the recorder before the directory case
+        with patch("atat.core.common.utils.FileChecker", new=TestFileChecker):
+            check_file_or_directory_path(dirname, isdir=True)
+        self.assertTrue(TestFileChecker.checked)
+        self.assertEqual(TestFileChecker.file_path, dirname)
+        self.assertEqual(TestFileChecker.path_type, FileCheckConst.DIR)
+        self.assertEqual(TestFileChecker.ability, FileCheckConst.WRITE_ABLE)  # directory case is expected to request write access
+
+    @patch.object(logger, "error")
+    def test_check_compare_param(self, mock_error):  # Non-dict input must raise; valid input must trigger the path checks listed in call_args.
+        params = {
+            "npu_json_path": "npu_json_path",
+            "bench_json_path": "bench_json_path",
+            "stack_json_path": "stack_json_path",
+            "npu_dump_data_dir": "npu_dump_data_dir",
+            "bench_dump_data_dir": "bench_dump_data_dir"
+        }
+
+        call_args = [  # expected positional args of check_file_or_directory_path, in call order
+            ("npu_json_path", False),
+            ("bench_json_path", False),
+            ("stack_json_path", False),
+            ("npu_dump_data_dir", True),
+            ("bench_dump_data_dir", True),
+            ("output_path", True),
+            ("npu_json_path", False),  # second call (md5_compare=True): no dump-data directories expected
+            ("bench_json_path", False),
+            ("stack_json_path", False),
+            ("output_path", True)
+        ]
+
+        with self.assertRaises(CompareException) as context:
+            check_compare_param("npu_json_path", "output_path")  # a str instead of the params dict
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_error.assert_called_with("Invalid input parameters")
+
+        mock_check_file_or_directory_path = MagicMock()
+        mock_check_json_file = MagicMock()
+        with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
+                patch("atat.core.common.utils.check_json_file", new=mock_check_json_file), \
+                patch("atat.core.common.utils.check_file_or_directory_path", new=mock_check_file_or_directory_path):
+            check_compare_param(params, "output_path")
+            check_compare_param(params, "output_path", summary_compare=False, md5_compare=True)
+            for i, expected_args in enumerate(call_args):  # indexing call_args_list keeps the IndexError if too few calls were recorded
+                self.assertEqual(mock_check_file_or_directory_path.call_args_list[i][0], expected_args)
+            self.assertEqual(len(mock_check_json_file.call_args[0]), 4)  # last check_json_file call received four positional args
+            self.assertEqual(mock_check_json_file.call_args[0][0], params)
+
+    @patch.object(logger, "error")
+    def test_check_configuration_param(self, mock_error):  # A non-bool stack_mode (the str "False") must be rejected.
+        with self.assertRaises(CompareException) as context:
+            check_configuration_param(stack_mode="False", auto_analyze=True, fuzzy_match=False)
+        self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR)
+        mock_error.assert_called_with("Invalid input parameters which should be only bool type.")
+
+    def test_is_starts_with(self):  # Expected True only when the string starts with at least one listed prefix.
+        string = "input_slot0"
+        self.assertFalse(is_starts_with(string, []))  # empty prefix list
+        self.assertFalse(is_starts_with("", ["input"]))  # empty string
+        self.assertFalse(is_starts_with(string, ["output"]))  # no matching prefix
+        self.assertTrue(is_starts_with(string, ["input", "output"]))  # first prefix matches
+
+    @patch.object(logger, "error")
+    def test__check_json(self, mock_error):  # An empty first line must raise; a non-empty one must be followed by seek(0, 0).
+        class TestOpen:  # fake file handle exposing only readline() and seek()
+            def __init__(self, string):
+                self.string = string
+
+            def readline(self):
+                return self.string
+
+            def seek(self, begin, end):
+                self.string = str(begin) + "_" + str(end)  # records the seek arguments so the test can inspect them
+
+        with self.assertRaises(CompareException) as context:
+            _check_json(TestOpen(""), "test.json")  # readline() returns an empty string
+        self.assertEqual(context.exception.code, CompareException.INVALID_DUMP_FILE)
+        mock_error.assert_called_with("dump file test.json have empty line!")
+
+        handler = TestOpen("jons file\n")
+        _check_json(handler, "test.json")
+        self.assertEqual(handler.string, "0_0")  # "0_0" is what the fake seek stores for seek(0, 0)
+
+    @patch("atat.core.common.utils._check_json")
+    def test_check_json_file(self, _mock_check_json):  # Each handle must be passed to _check_json with its matching path, in npu/bench/stack order.
+        input_param = {
+            "npu_json_path": "npu_json_path",
+            "bench_json_path": "bench_json_path",
+            "stack_json_path": "stack_json_path"
+        }
+        check_json_file(input_param, "npu_json", "bench_json", "stack_json")  # plain strings stand in for the three file handles
+        self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path"))
+        self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path"))
+        self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path"))
+
+    @patch.object(logger, "error")
+    def test_check_file_size(self, mock_error):  # A reported size of 120 against a limit of 100 must raise INVALID_FILE_ERROR.
+        with patch("atat.core.common.utils.os.path.getsize", return_value=120):  # no real file is needed; the size lookup is mocked
+            with self.assertRaises(CompareException) as context:
+                check_file_size("input_file", 100)
+            self.assertEqual(context.exception.code, CompareException.INVALID_FILE_ERROR)
+            mock_error.assert_called_with("The size (120) of input_file exceeds (100) bytes, tools not support.")
+
+ def test_check_regex_prefix_format_valid(self):
+ prefix = "A" * 21
+ with self.assertRaises(ValueError) as context:
+ check_regex_prefix_format_valid(prefix)
+ self.assertEqual(str(context.exception), f"Maximum length of prefix is {Const.REGEX_PREFIX_MAX_LENGTH}, "
+ f"while current length is {len(prefix)}")
+
+ prefix = "(prefix)"
+ with self.assertRaises(ValueError) as context:
+ check_regex_prefix_format_valid(prefix)
+ self.assertEqual(str(context.exception), f"prefix contains invalid characters, "
+ f"prefix pattern {Const.REGEX_PREFIX_PATTERN}")
+
+    @patch("atat.core.common.utils.check_file_or_directory_path")
+    def test_get_dump_data_path(self, mock_check_file_or_directory_path):  # Uses this test's own directory as the dump directory.
+        file_path = os.path.realpath(__file__)
+        dirname = os.path.dirname(file_path)
+
+        dump_data_path, file_is_exist = get_dump_data_path(dirname)
+        self.assertEqual(mock_check_file_or_directory_path.call_args[0], (dirname, True))  # the path check was called with (dirname, True)
+        self.assertEqual(dump_data_path, dirname)
+        self.assertTrue(file_is_exist)
+
+    @patch.object(logger, "error")
+    def test_task_dumppath_get(self, mock_error):  # Missing json path, then the TENSOR, STATISTICS and OVERFLOW_CHECK task values.
+        input_param = {
+            "npu_json_path": None,
+            "bench_json_path": "bench_json_path"
+        }
+        npu_json = {  # content returned by the mocked json.load; "task" is changed between cases
+            "task": Const.TENSOR,
+            "dump_data_dir": "dump_data_dir",
+            "data": "data"
+        }
+
+        with self.assertRaises(CompareException) as context:
+            task_dumppath_get(input_param)  # npu_json_path is None here
+        self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR)
+        mock_error.assert_called_with("Please check the json path is valid.")
+
+        input_param["npu_json_path"] = "npu_json_path"
+        with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
+                patch("atat.core.common.utils.json.load", return_value=npu_json):
+            summary_compare, md5_compare = task_dumppath_get(input_param)
+        self.assertFalse(summary_compare)  # TENSOR task: both flags are expected to be False
+        self.assertFalse(md5_compare)
+
+        npu_json["task"] = Const.STATISTICS
+        with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
+                patch("atat.core.common.utils.json.load", return_value=npu_json), \
+                patch("atat.core.common.utils.md5_find", return_value=True):
+            summary_compare, md5_compare = task_dumppath_get(input_param)
+        self.assertFalse(summary_compare)  # STATISTICS task with md5_find patched to True: md5 compare, not summary compare
+        self.assertTrue(md5_compare)
+
+        npu_json["task"] = Const.OVERFLOW_CHECK
+        with patch("atat.core.common.utils.FileOpen", mock_open(read_data="")), \
+                patch("atat.core.common.utils.json.load", return_value=npu_json):
+            with self.assertRaises(CompareException) as context:
+                task_dumppath_get(input_param)
+            self.assertEqual(context.exception.code, CompareException.INVALID_TASK_ERROR)
+            mock_error.assert_called_with("Compare is not required for overflow_check or free_benchmark.")
diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
index 6be8949684c89012f0dc2165ba24eab4e7a77f1c..fe92a90aa1ce85e6efb38a8b089afcb34a199b59 100644
--- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
+++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py
@@ -1,7 +1,7 @@
from unittest import TestCase
from unittest.mock import patch, mock_open
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
from atat.mindspore.ms_config import parse_json_config
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py b/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
index aab90f122f64a148955a7c824e8975c7d19cb679..701c67b07483652850831d1f720f417cd5554481 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
+++ b/debug/accuracy_tools/atat/test/pytorch_ut/api_accuracy_checker/compare/test_api_precision_compare.py
@@ -9,7 +9,7 @@ from atat.pytorch.api_accuracy_checker.compare.api_precision_compare import (
check_error_rate,
get_api_checker_result,
)
-from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst
+from atat.core.common.const import CompareConst
class TestApiPrecisionCompare(unittest.TestCase):
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
index 448b518c3344b2898ce0487687c374b0071dacb6..828d646c52f363e758a10939dbfcd98d942eb9f2 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
+++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/perturbed_layers/test_perturbed_layser.py
@@ -1,7 +1,7 @@
from unittest import TestCase
import torch
-from atat.pytorch.common.utils import Const
+from atat.core.common.const import Const
from atat.pytorch.free_benchmark.common.enums import DeviceType, PerturbationMode
from atat.pytorch.free_benchmark.common.params import data_pre_deal
from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
index 948fdaecea5c39be5602f66fa565985655054050..d46e26e09488d481ac9e96a8b4002dd3b62446bd 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
+++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py
@@ -2,7 +2,7 @@ from abc import ABC
from unittest import TestCase
import torch
-from atat.pytorch.common.utils import Const
+from atat.core.common.const import Const
from atat.pytorch.free_benchmark.common.constant import PreheatConfig, ThresholdConfig
from atat.pytorch.free_benchmark.common.counter import preheat_counter
from atat.pytorch.free_benchmark.common.enums import (
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py
index 4c1aa1deff4e805c9576892ef6d1775c127b9d53..d326e993c07d66a1baf5ae785ed4b519624bb982 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py
+++ b/debug/accuracy_tools/atat/test/pytorch_ut/free_benchmark/test_main.py
@@ -4,7 +4,7 @@ from unittest import TestCase
import torch
import torch.nn as nn
-from atat.pytorch.common.utils import Const
+from atat.core.common.const import Const
from atat.pytorch.free_benchmark import FreeBenchmarkCheck
from atat.pytorch.free_benchmark.common.constant import CommonField, PreheatConfig
from atat.pytorch.free_benchmark.common.enums import (
diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py
index c931c8550716bc65b58655db6140684408f596cc..fa52fe0e1b05701cec9c8cdf41fe1586029c826e 100644
--- a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py
+++ b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py
@@ -1,7 +1,7 @@
from unittest import TestCase
from unittest.mock import patch, mock_open
-from atat.core.common.utils import Const
+from atat.core.common.const import Const
from atat.pytorch.pt_config import parse_json_config
diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py
index f940ef5135503e68ffa4559b3c55d5cc280f8b25..d3254ae71f9a8fccb8608088462c4733c166814d 100644
--- a/debug/accuracy_tools/grad_tool/common/base_comparator.py
+++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py
@@ -40,9 +40,9 @@ class BaseComparator(ABC):
create_directory(output_dir)
for rank in tqdm(ranks, desc="rank"):
print_info_log(f"now comparing rank {rank}:")
- cls.compare(os.path.join(path1, f"rank_{rank}"),
- os.path.join(path2, f"rank_{rank}"),
- os.path.join(output_dir, f"rank_{rank}"))
+ cls.compare(os.path.join(path1, f"rank{rank}"),
+ os.path.join(path2, f"rank{rank}"),
+ os.path.join(output_dir, f"rank{rank}"))
@classmethod
def compare(cls, path1: str, path2: str, output_dir: str):
@@ -59,15 +59,15 @@ class BaseComparator(ABC):
check_file_or_directory_path(path1, file_type=GradConst.DIR)
check_file_or_directory_path(path2, file_type=GradConst.DIR)
dirs = []
- for dirname in os.listdir(path1):
- splits = dirname.split('_')
- if not splits or splits[0] != dir_prefix or not splits[1].isdigit():
+ for dir_name in os.listdir(path1):
+ index = dir_name.replace(dir_prefix, "", 1)
+ if not dir_name.startswith(dir_prefix) or not index.isdigit():
continue
- folder2 = os.path.join(path2, dirname)
+ folder2 = os.path.join(path2, dir_name)
if not os.path.isdir(folder2):
continue
- dirs.append(int(splits[1]))
+ dirs.append(int(index))
dirs = sorted(dirs)
return dirs
@@ -101,8 +101,8 @@ class BaseComparator(ABC):
total_count_summary = 0
for grad_name in grad_weight_order:
grad_file = cls._get_name_matched_grad_file(grad_name, grad_files)
- grad1 = os.path.join(path1, f"step_{step}", grad_file)
- grad2 = os.path.join(path2, f"step_{step}", grad_file)
+ grad1 = os.path.join(path1, f"step{step}", grad_file)
+ grad2 = os.path.join(path2, f"step{step}", grad_file)
same_count, total_count = cls._calculate_similarity(grad1, grad2)
same_count_summary += same_count
total_count_summary += total_count
@@ -124,8 +124,8 @@ class BaseComparator(ABC):
@classmethod
def _get_matched_grad_files(cls, path1: str, path2: str, step: int):
- path1 = os.path.join(path1, f"step_{step}")
- path2 = os.path.join(path2, f"step_{step}")
+ path1 = os.path.join(path1, f"step{step}")
+ path2 = os.path.join(path2, f"step{step}")
check_file_or_directory_path(path1, file_type=GradConst.DIR)
check_file_or_directory_path(path2, file_type=GradConst.DIR)
grad_files = []
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
index 6733d566d6aaadd2fbf97076bfaca84a3e05160d..f3079e622c29eefe20a6e1fdc9d372002b596610 100644
--- a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
+++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
@@ -96,8 +96,8 @@ class PtGradientMonitor(BaseMonitor):
output_lines.append(grad_info)
if self._level_adp["have_grad_direction"]:
PtGradientMonitor.save_grad_direction(param_name, grad,
- f'{self._output_path}/rank_{self._rank}/step_{self._step}')
- output_path = os.path.join(self._output_path, f"rank_{getattr(self, '_rank')}",
+ f'{self._output_path}/rank{self._rank}/step{self._step}')
+ output_path = os.path.join(self._output_path, f"rank{getattr(self, '_rank')}",
f"grad_summary_{self._step}.csv")
write_csv(output_path, output_lines,
GradStatCsv.generate_csv_header(self._level_adp, self._bounds))
diff --git a/plugins/tensorboard-plugins/ OWNERS b/plugins/tensorboard-plugins/OWNERS
similarity index 93%
rename from plugins/tensorboard-plugins/ OWNERS
rename to plugins/tensorboard-plugins/OWNERS
index 34c383beaf138da92df0991b472135496450a827..507672c7399438d6c350856386f1818652ea21f8 100644
--- a/plugins/tensorboard-plugins/ OWNERS
+++ b/plugins/tensorboard-plugins/OWNERS
@@ -1,9 +1,9 @@
-options:
- no_parent_owners: true
-approvers:
-- wo-wenjie
-- ly-qianxiao
-reviewers:
-- wo-wenjie
-- ly-qianxiao
-- leo920320
+options:
+ no_parent_owners: true
+approvers:
+- wo-wenjie
+- ly-qianxiao
+reviewers:
+- wo-wenjie
+- ly-qianxiao
+- leo920320
diff --git a/profiler/advisor/README.md b/profiler/advisor/README.md
index ccaccdda017ad61674ea245eb5f9c9465747f63e..c650f40b3ea8ef48b3c7644e279b00a1cb99f29a 100644
--- a/profiler/advisor/README.md
+++ b/profiler/advisor/README.md
@@ -92,31 +92,31 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的
- 总体性能瓶颈
```bash
- msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [-D] [-h]
+ msprof-analyze advisor all -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h]
```
- 计算瓶颈
```bash
- msprof-analyze advisor computation -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-pt profiling_type] [-D] [-h]
+ msprof-analyze advisor computation -d {profiling_path} [-cv cann_version] [-tv torch_version] [-pt profiling_type] [--debug] [-h]
```
- 调度瓶颈
```bash
- msprof-analyze advisor schedule -d {profiling_path} [-bp benchmark_profiling_path] [-cv cann_version] [-tv torch_version] [-D] [-h]
+ msprof-analyze advisor schedule -d {profiling_path} [-cv cann_version] [-tv torch_version] [--debug] [-h]
```
#### 参数介绍
| 参数 | 说明 | 是否必选 |
| ---------------------------------- | ------------------------------------------------------------ | -------- |
-| -d
--profiling_path | 性能数据所在目录。性能数据通过Profiling工具采集获取。请确保性能数据采集时配置“aic-metrics”参数为“PipeUtilization”,“aicpu”参数为“on”。advisor依赖Profiling工具解析后的timeline数据、summary数据以及info.json*文件,请确保指定的“profiling_dir”目录下存在以上文件。 | 是 |
+| -d
--profiling_path | 性能数据文件或目录所在路径,Ascend PyTorch Profiler采集场景指定为`*_ascend_pt`性能数据结果目录,其他场景指定为`PROF_XXX`性能数据结果目录。建议通过Ascend PyTorch Profiler获取性能数据。
advisor依赖Profiling工具解析后的timeline数据(.json)、summary(.csv)数据以及info.json*文件,请确保指定的“profiling_path”目录下存在以上文件。 | 是 |
| -bp
--benchmark_profiling_path | 基准性能数据所在目录,用于性能比对。性能数据通过Profiling工具采集获取。
**computation和schedule不支持该参数。** | 否 |
| -cv
--cann_version | 使用Profiling工具采集时对应的CANN软件版本,可通过在环境中执行如下命令获取其version字段,目前配套的兼容版本为“6.3.RC2”,“7.0.RC1”、“7.0.0”、“8.0.RC1”,此字段不填默认按“8.0.RC1”版本数据进行处理,其余版本采集的Profiling数据在分析时可能会导致不可知问题:`cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info` | 否 |
| -tv
--torch_version | 运行环境的torch版本,默认为1.11.0,支持torch1.11.0和torch2.1.0,当运行环境torch版本为其他版本如torch1.11.3时,可以忽略小版本号差异选择相近的torch版本如1.11.0。 | 否 |
-| -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyThon Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。
**schedule不支持该参数。** | 否 |
-| -D
--debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 |
+| -pt
--profiling_type | 配置性能数据采集使用的Profiling工具类型。可取值:
ascend_pytorch_profiler:使用Ascend PyTorch Profiler接口方式采集的性能数据时配置,默认值。
msprof:使用msprof命令行方式采集的性能数据时配置。功能完善中,暂不建议使用。
mslite:使用[Benchmark](https://gitee.com/ascend/tools/tree/master/ais-bench_workload/tool/ais_bench)工具采集的性能数据时配置。不建议使用。
**schedule不支持该参数。** | 否 |
+| --debug | 工具执行报错时可打开此开关,将会展示详细保存堆栈信息。 | 否 |
| -h,-H
--help | 在需要查询当前命令附属子命令或相关参数时,给出帮助建议。 | 否 |
### 报告解析
diff --git a/profiler/compare_tools/compare_backend/utils/constant.py b/profiler/compare_tools/compare_backend/utils/constant.py
index 1b77b214c85f6733e36298e119e43a778fd7969f..e2854692ae3218c873171b75878e3e69203effa2 100644
--- a/profiler/compare_tools/compare_backend/utils/constant.py
+++ b/profiler/compare_tools/compare_backend/utils/constant.py
@@ -74,7 +74,7 @@ class Constant(object):
MEMORY_LIST = "memory_list"
COMMUNICATION_DICT = "comm_dict"
- #compare type
+ # compare type
OVERALL_COMPARE = "overall"
BWD_LIST = ["bwd", "backward", "back"]
diff --git a/profiler/module_visualization/__init__.py b/profiler/module_visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/module_visualization/graph/__init__.py b/profiler/module_visualization/graph/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/module_visualization/graph/prof_node.py b/profiler/module_visualization/graph/prof_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfcdabbb991d2abb86f31e5a5866e788cf9a3c6e
--- /dev/null
+++ b/profiler/module_visualization/graph/prof_node.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from profiler.prof_common.constant import Constant
+from profiler.prof_common.base_node import BaseNode
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class ProfNode(BaseNode):
+ MODULE_TYPE = 1
+
+ def __init__(self, event: TraceEventBean, parent_node=None):
+ super().__init__(event, parent_node)
+ self._kernel_total_list = []
+
+ @property
+ def node_id(self):
+ return self._event.unique_id
+
+ @property
+ def total_kernels(self):
+ return self._kernel_total_list
+
+ @property
+ def host_total_dur(self):
+ if self.is_root_node:
+ return sum((node.host_total_dur for node in self.child_nodes))
+ return self._event.dur
+
+ @property
+ def host_self_dur(self):
+ return self.host_total_dur - sum((node.host_total_dur for node in self.child_nodes))
+
+ @property
+ def device_total_dur(self):
+ if self.is_root_node:
+ return sum((node.device_total_dur for node in self.child_nodes))
+ return sum((kernel.dur for kernel in self._kernel_total_list))
+
+ @property
+ def device_self_dur(self):
+ return self.device_total_dur - sum((node.device_total_dur for node in self.child_nodes))
+
+ @property
+ def input_data(self) -> dict:
+ data = {}
+ input_dim = self._event.args.get("Input Dims")
+ if input_dim:
+ data["Input Dims"] = input_dim
+ input_type = self._event.args.get("Input type")
+ if input_type:
+ data["Input type"] = input_type
+ return data
+
+ @property
+ def data(self):
+ return {"Input Data": self.input_data,
+ "Host Self Duration(us)": round(self.host_self_dur, 2),
+ "Host Total Duration(us)": round(self.host_total_dur, 2),
+ "Device Self Duration(us)": round(self.device_self_dur, 2),
+ "Device Total Duration(us)": round(self.device_total_dur, 2)}
+
+ @property
+ def info(self):
+ return {"id": self.node_id,
+ "node_type": self.MODULE_TYPE,
+ "data": self.data,
+ "upnode": self.parent_node.node_id if self.parent_node else "None",
+ "subnodes": [node.node_id for node in iter(self.child_nodes)]}
+
+ @property
+ def is_root_node(self):
+ return self.node_id == Constant.NPU_ROOT_ID
+
+ def update_child_nodes(self, node):
+ self._child_nodes.append(node)
+
+ def update_kernel_total_list(self, kernel_list: list):
+ self._kernel_total_list.extend(kernel_list)
diff --git a/profiler/module_visualization/graph_build/__init__.py b/profiler/module_visualization/graph_build/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/module_visualization/graph_build/fwd_module_node.py b/profiler/module_visualization/graph_build/fwd_module_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..34d7ab829649f482c97fb489ac0399d3a876c100
--- /dev/null
+++ b/profiler/module_visualization/graph_build/fwd_module_node.py
@@ -0,0 +1,29 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from profiler.prof_common.base_node import BaseNode
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class FwdModuleNode(BaseNode):
+ def __init__(self, event: TraceEventBean, parent_node=None):
+ super().__init__(event, parent_node)
+ self._bwd_op_list = []
+
+ @property
+ def bwd_op_list(self):
+ return self._bwd_op_list
+
+ def update_bwd_op(self, bwd_op_list: list):
+ self._bwd_op_list.extend(bwd_op_list)
diff --git a/profiler/module_visualization/graph_build/prof_graph_builder.py b/profiler/module_visualization/graph_build/prof_graph_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..83331b6250211e32399b05cabf19a293759a3741
--- /dev/null
+++ b/profiler/module_visualization/graph_build/prof_graph_builder.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from profiler.module_visualization.graph.prof_node import ProfNode
+from profiler.module_visualization.graph_build.fwd_module_node import FwdModuleNode
+from profiler.prof_common.tree_builder import TreeBuilder
+from profiler.prof_common.trace_event_bean import TraceEventBean
+from profiler.prof_common.constant import Constant
+from profiler.module_visualization.prof_parse.prof_data_pre_process import ProfDataPreProcess
+
+
+class ProfGraphBuilder:
+ def __init__(self, prof_data_path: str):
+ self._prof_data_path = prof_data_path
+ self._prof_data = {}
+
+    @classmethod
+    def _create_event_bean_from_ops(cls, op_list: list, name: str) -> TraceEventBean:
+        min_start = min((op.start_time for op in iter(op_list)))
+        max_end = max((op.end_time for op in iter(op_list)))
+        # The backward module spans its operators; ts - 1 and dur + 2 make the module strictly enclose them.
+        return TraceEventBean({"ts": min_start - 1, "dur": float(max_end - min_start) + 2, "name": name})
+
+ @classmethod
+ def _trans_flow_to_dict(cls, flow_events: dict, end_events: list) -> dict:
+ end_event_dict = {}
+ for event in end_events:
+ end_event_dict[event.start_time] = event
+ result_data = {}
+ for flow in flow_events.values():
+ start_point = flow.get("start")
+ end_point = flow.get("end")
+ if not start_point or not end_point:
+ continue
+ end_event = end_event_dict.get(end_point.start_time)
+ if end_event:
+ result_data.setdefault(start_point.start_time, []).append(end_event)
+ return result_data
+
+ def build_graph(self):
+ self._prof_data = ProfDataPreProcess(self._prof_data_path).run()
+ all_data = [*self._prof_data.get(Constant.MODULE_EVENT, []),
+ *self.find_bwd_module(),
+ *self._prof_data.get(Constant.CPU_OP_EVENT, [])]
+ all_data.sort(key=lambda x: x.start_time)
+ name_dict = {}
+ for event in all_data:
+ order_id = name_dict.get(event.name, 0)
+ event.set_id(f"{event.name}_{order_id}")
+ name_dict[event.name] = order_id + 1
+ root_node = TreeBuilder.build_tree(all_data, ProfNode, TraceEventBean({}, Constant.NPU_ROOT_ID))
+ kernel_flow_dict = self._trans_flow_to_dict(self._prof_data.get(Constant.TORCH_TO_NPU_FLOW, {}),
+ self._prof_data.get(Constant.KERNEL_EVENT, []))
+ for start_time, kernels in kernel_flow_dict.items():
+ matched_node = root_node.binary_search(start_time)
+ while matched_node != Constant.INVALID_RETURN:
+ matched_node.update_kernel_total_list(kernels)
+ matched_node = matched_node.binary_search(start_time)
+ all_data = root_node.find_all_child_nodes()
+ all_data.append(root_node)
+ return all_data
+
+    def find_bwd_module(self) -> list:
+        """Return synthetic backward-module events built from backward-thread operators and fwd->bwd flow links."""
+        bwd_module_list = []
+        fwdbwd_flow = self._prof_data.get(Constant.FWD_BWD_FLOW, {})
+        module_list = self._prof_data.get(Constant.MODULE_EVENT, [])
+        cpu_op_list = self._prof_data.get(Constant.CPU_OP_EVENT, [])
+        if not fwdbwd_flow or not module_list or not cpu_op_list:
+            return bwd_module_list
+        fwd_tid = module_list[0].tid
+        bwd_tid = fwd_tid
+        for end_point in (flow.get("end") for flow in fwdbwd_flow.values()):
+            if end_point:
+                bwd_tid = end_point.tid
+                break
+        if fwd_tid == bwd_tid:
+            return bwd_module_list
+        # Wrap each consecutive run of backward-thread operators in a module named "nn.Module: BACKWARD".
+        cpu_op_list.sort(key=lambda x: x.start_time)
+        bwd_op_list = []
+        for op in cpu_op_list:
+            if op.tid == bwd_tid:
+                bwd_op_list.append(op)
+            elif bwd_op_list:
+                bwd_module_list.append(self._create_event_bean_from_ops(bwd_op_list, "nn.Module: BACKWARD"))
+                bwd_op_list.clear()
+        if bwd_op_list:  # the trace ended inside a backward run: flush it instead of silently dropping it
+            bwd_module_list.append(self._create_event_bean_from_ops(bwd_op_list, "nn.Module: BACKWARD"))
+
+        # Follow the fwd->bwd flow links to attach backward operators to their forward modules.
+        root_node = TreeBuilder.build_tree(module_list, FwdModuleNode, TraceEventBean({}))
+        fwdbwd_flow_dict = self._trans_flow_to_dict(fwdbwd_flow, cpu_op_list)
+        for start_time, end_events in fwdbwd_flow_dict.items():
+            matched_node = root_node.binary_search(start_time)
+            while matched_node != Constant.INVALID_RETURN:
+                matched_node.update_bwd_op(end_events)
+                matched_node = matched_node.binary_search(start_time)
+        all_nodes = root_node.find_all_child_nodes()
+        for module_node in all_nodes:
+            if module_node.bwd_op_list:
+                bwd_module_list.append(
+                    self._create_event_bean_from_ops(module_node.bwd_op_list, f"{module_node.name} [BACKWARD]"))
+        return bwd_module_list
diff --git a/profiler/module_visualization/prof_graph_export.py b/profiler/module_visualization/prof_graph_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d336e97f7419b53d011fa4c043948c60afa5174d
--- /dev/null
+++ b/profiler/module_visualization/prof_graph_export.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from datetime import datetime
+
+from profiler.prof_common.constant import Constant
+from profiler.prof_common.file_reader import FileReader
+from profiler.prof_common.path_manager import PathManager
+from profiler.module_visualization.graph_build.prof_graph_builder import ProfGraphBuilder
+
+
+class ProfGraphExport:
+ @staticmethod
+ def export_to_json(prof_data_path: str, output_path: str):
+ logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
+ try:
+ PathManager.input_path_common_check(prof_data_path)
+ PathManager.check_input_directory_path(output_path)
+ PathManager.make_dir_safety(output_path)
+ all_nodes = ProfGraphBuilder(prof_data_path).build_graph()
+ result_data = {"root": Constant.NPU_ROOT_ID, "node": {}}
+ for node in all_nodes:
+ result_data["node"][node.node_id] = node.info
+ file_name = "prof_graph_json_{}.vis".format(datetime.utcnow().strftime("%Y%m%d%H%M%S%f")[:-3])
+ FileReader.write_json_file(output_path, result_data, file_name)
+ except RuntimeError as err:
+ logging.error(err)
diff --git a/profiler/module_visualization/prof_parse/__init__.py b/profiler/module_visualization/prof_parse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/module_visualization/prof_parse/prof_data_pre_process.py b/profiler/module_visualization/prof_parse/prof_data_pre_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dc820e4ca560f816b7738243197b90f1adb8c25
--- /dev/null
+++ b/profiler/module_visualization/prof_parse/prof_data_pre_process.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from profiler.prof_common.file_reader import FileReader
+from profiler.prof_common.constant import Constant
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class ProfDataPreProcess:
+ def __init__(self, prof_data_path: str):
+ self._prof_data_path = prof_data_path
+ self._trace_path = ""
+ self._kernel_pid = None
+ self._result_data = {Constant.CPU_OP_EVENT: [], Constant.MODULE_EVENT: [], Constant.KERNEL_EVENT: [],
+ Constant.TORCH_TO_NPU_FLOW: {}, Constant.FWD_BWD_FLOW: {}}
+
+ def run(self) -> dict:
+ self._check_trace_path()
+ self._parse_trace_events()
+ self._check_result_data()
+ return self._result_data
+
+ def _check_trace_path(self):
+ if os.path.isfile(self._prof_data_path):
+ (split_file_path, split_file_name) = os.path.split(self._prof_data_path)
+ (shot_name, extension) = os.path.splitext(split_file_name)
+ if extension != ".json":
+ msg = f"Invalid profiling path suffix: {self._prof_data_path}. " \
+ f"You should input in a json file path, such as trace_view.json."
+ raise RuntimeError(msg)
+ self._trace_path = self._prof_data_path
+ return
+ ascend_output = os.path.join(self._prof_data_path, "ASCEND_PROFILER_OUTPUT")
+ profiler_output = ascend_output if os.path.isdir(ascend_output) else self._prof_data_path
+ json_path = os.path.join(profiler_output, "trace_view.json")
+ if not os.path.isfile(json_path):
+ msg = f"Invalid profiling path: {self._prof_data_path}. The data path should be the " \
+ f"folder that ends with the ascend_pt collected by the Ascend PyTorch Profiler."
+ raise RuntimeError(msg)
+ self._trace_path = json_path
+
+ def _parse_trace_events(self):
+ trace_data = FileReader.read_json_file(self._trace_path)
+ self._check_trace_data(trace_data)
+ iter_trace_data = iter(trace_data)
+ for event in iter_trace_data:
+ bean = TraceEventBean(event)
+ if bean.is_optimizer():
+ self._result_data[Constant.MODULE_EVENT].append(bean)
+ elif bean.is_cpu_op():
+ if not bean.is_step():
+ self._result_data[Constant.CPU_OP_EVENT].append(bean)
+ elif bean.is_nn_module():
+ self._result_data[Constant.MODULE_EVENT].append(bean)
+ elif bean.is_torch_to_npu():
+ if bean.is_flow_start():
+ self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(bean.id, {})["start"] = bean
+ else:
+ self._result_data[Constant.TORCH_TO_NPU_FLOW].setdefault(bean.id, {})["end"] = bean
+ elif bean.is_fwd_bwd_flow():
+ if bean.is_flow_start():
+ self._result_data[Constant.FWD_BWD_FLOW].setdefault(bean.id, {})["start"] = bean
+ else:
+ self._result_data[Constant.FWD_BWD_FLOW].setdefault(bean.id, {})["end"] = bean
+ elif bean.is_kernel_event(self._kernel_pid):
+ self._result_data[Constant.KERNEL_EVENT].append(bean)
+
+ def _check_trace_data(self, trace_data):
+ if not isinstance(trace_data, list):
+ msg = f"Invalid profiling data path, this feature only supports performance data " \
+ f"collected by Ascend PyTorch Profiler."
+ raise RuntimeError(msg)
+ iter_trace_data = iter(trace_data)
+ for event in iter_trace_data:
+ bean = TraceEventBean(event)
+ if bean.is_npu_process():
+ self._kernel_pid = bean.pid
+ break
+ if self._kernel_pid is None:
+ msg = f"There is no operator on the NPU side for this data, please check whether the NPU switch is enabled."
+ raise RuntimeError(msg)
+
+ def _check_result_data(self):
+ if not self._result_data.get(Constant.CPU_OP_EVENT):
+ msg = f"This data does not have any aten operator, please make sure to enable the CPU switch."
+ raise RuntimeError(msg)
+ if not self._result_data.get(Constant.MODULE_EVENT):
+ msg = f"This data does not collect any modules, please make sure to turn on the with_stack switch."
+ raise RuntimeError(msg)
diff --git a/profiler/prof_common/base_node.py b/profiler/prof_common/base_node.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7cd6780003f9e0e5c58495ac43a893214e68beb
--- /dev/null
+++ b/profiler/prof_common/base_node.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from math import ceil
+from queue import Queue
+
+from decimal import Decimal
+
+from profiler.prof_common.constant import Constant
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class BaseNode:
+ def __init__(self, event: TraceEventBean, parent_node=None):
+ self._event = event
+ self._parent_node = parent_node
+ self._child_nodes = []
+
+ @property
+ def parent_node(self):
+ return self._parent_node
+
+ @property
+ def child_nodes(self):
+ return self._child_nodes
+
+ @property
+ def name(self):
+ return self._event.name
+
+ @property
+ def start_time(self) -> Decimal:
+ return self._event.start_time
+
+ @property
+ def end_time(self) -> Decimal:
+ return self._event.end_time
+
+ def update_child_nodes(self, node):
+ self._child_nodes.append(node)
+
+ def binary_search(self, ts_time):
+ if not self.child_nodes:
+ return Constant.INVALID_RETURN
+ right = len(self.child_nodes) - 1
+ left = 0
+ while right > left:
+ mid = left + ceil((right - left) / 2)
+ if ts_time >= self.child_nodes[mid].start_time:
+ left = mid
+ else:
+ right = mid - 1
+ if self.child_nodes[left].start_time < ts_time < self.child_nodes[left].end_time:
+ return self.child_nodes[left]
+ return Constant.INVALID_RETURN
+
+ def find_all_child_nodes(self) -> list:
+ result_data = []
+ node_queue = Queue()
+ for child_node in self.child_nodes:
+ node_queue.put(child_node)
+ while not node_queue.empty():
+ tree_node = node_queue.get()
+ result_data.append(tree_node)
+ for child_node in tree_node.child_nodes:
+ node_queue.put(child_node)
+ return result_data
diff --git a/profiler/prof_common/constant.py b/profiler/prof_common/constant.py
index 5789b89cb1a248977b64839339395acc5288b2ab..87bc51b56bc71c2a70e35a6b08aa4de7bd521f1d 100644
--- a/profiler/prof_common/constant.py
+++ b/profiler/prof_common/constant.py
@@ -15,4 +15,17 @@
class Constant(object):
COLLECTION_PATH = "collection_path"
ANALYSIS_MODE = "analysis_mode"
- CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'])
\ No newline at end of file
+ CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'])
+
+ MAX_FILE_SIZE_5_GB = 1024 * 1024 * 1024 * 5
+
+ MODULE_EVENT = "module_event"
+ CPU_OP_EVENT = "op_event"
+ TORCH_TO_NPU_FLOW = "torch_to_device"
+ KERNEL_EVENT = "kernel_event"
+ FWD_BWD_FLOW = "fwd_to_bwd"
+ NPU_ROOT_ID = "NPU"
+
+ FWD_OR_OPT = 0
+ BACKWARD = 1
+ INVALID_RETURN = -1
diff --git a/profiler/prof_common/file_reader.py b/profiler/prof_common/file_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8a9c8fb4d6599edf46973f8e93aa708903ff007
--- /dev/null
+++ b/profiler/prof_common/file_reader.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+import os
+
+from profiler.prof_common.path_manager import PathManager
+from profiler.prof_common.constant import Constant
+
+
+class FileReader:
+ DATA_FILE_AUTHORITY = 0o640
+ DATA_DIR_AUTHORITY = 0o750
+
+    @classmethod
+    def read_json_file(cls, file_path: str) -> any:
+        """Load a UTF-8 JSON file; return [] for an empty file, raise RuntimeError if oversized or unparsable."""
+        PathManager.check_path_readable(file_path)
+        if not os.path.isfile(file_path):
+            raise FileNotFoundError("File not exists.")
+        file_size = os.path.getsize(file_path)
+        if file_size <= 0:
+            return []
+        if file_size > Constant.MAX_FILE_SIZE_5_GB:
+            msg = f"The file({file_path}) size exceeds the preset max value, failed to read the file."
+            raise RuntimeError(msg)
+        try:
+            with open(file_path, "rt", encoding="utf-8") as file:  # JSON is UTF-8; ignore the locale default
+                json_data = json.loads(file.read())
+        except Exception as e:
+            raise RuntimeError(f"Can't read file: {file_path}") from e
+        return json_data
+
+    @classmethod
+    def write_json_file(cls, output_path: str, data: dict, file_name: str, format_json: bool = False) -> None:
+        """Serialise data as JSON into output_path/file_name; empty data writes nothing. Raises RuntimeError."""
+        if not data:
+            return
+        output_file = os.path.join(output_path, file_name)
+        PathManager.check_path_writeable(output_path)
+        try:
+            with os.fdopen(  # O_TRUNC: an existing, longer file must not leave stale bytes after the new JSON
+                    os.open(output_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, cls.DATA_FILE_AUTHORITY), 'w'
+            ) as file:
+                file.write(json.dumps(data, indent=4 if format_json else None))
+        except Exception as e:
+            raise RuntimeError(f"Can't create the file: {output_path}") from e
diff --git a/profiler/prof_common/path_manager.py b/profiler/prof_common/path_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e41b8b50aca42ba33071b2661966d221102e106
--- /dev/null
+++ b/profiler/prof_common/path_manager.py
@@ -0,0 +1,191 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+import shutil
+import platform
+
+
+class PathManager:
+ MAX_PATH_LENGTH = 4096
+ MAX_FILE_NAME_LENGTH = 255
+ DATA_FILE_AUTHORITY = 0o640
+ DATA_DIR_AUTHORITY = 0o750
+ WINDOWS = "windows"
+
+ @classmethod
+ def check_input_directory_path(cls, path: str):
+ """
+ Function Description:
+ check whether the path is valid, some businesses can accept a path that does not exist,
+ so the function do not verify whether the path exists
+ Parameter:
+ path: the path to check, whether the incoming path is absolute or relative depends on the business
+ Exception Description:
+ when invalid data throw exception
+ """
+ cls.input_path_common_check(path)
+ base_name = os.path.basename(path)
+ if os.path.isfile(path):
+ msg = f"Invalid input path which is a file path: {base_name}"
+ raise RuntimeError(msg)
+
+ @classmethod
+ def check_input_file_path(cls, path: str):
+ """
+ Function Description:
+ check whether the file path is valid, some businesses can accept a path that does not exist,
+ so the function do not verify whether the path exists
+ Parameter:
+ path: the file path to check, whether the incoming path is absolute or relative depends on the business
+ Exception Description:
+ when invalid data throw exception
+ """
+ cls.input_path_common_check(path)
+ base_name = os.path.basename(path)
+ if os.path.isdir(path):
+ msg = f"Invalid input path which is a directory path: {base_name}"
+ raise RuntimeError(msg)
+
+    @classmethod
+    def check_path_length(cls, path: str):
+        """Raise RuntimeError when the whole path, or any '/'- or backslash-separated component, is too long."""
+        if len(path) > cls.MAX_PATH_LENGTH:
+            raise RuntimeError("Length of input path exceeds the limit.")
+        # Split on both separator styles so POSIX and Windows paths are checked component by component.
+        for segment in path.split("/"):
+            for name in segment.split("\\"):
+                if len(name) > cls.MAX_FILE_NAME_LENGTH:
+                    raise RuntimeError("Length of input path exceeds the limit.")
+
+ @classmethod
+ def input_path_common_check(cls, path: str):
+ cls.check_path_length(path)
+
+ if os.path.islink(path):
+ msg = f"Invalid input path which is a soft link."
+ raise RuntimeError(msg)
+
+ if platform.system().lower() == cls.WINDOWS:
+ pattern = r'(\.|:|\\|/|_|-|\s|[~0-9a-zA-Z\u4e00-\u9fa5])+'
+ else:
+ pattern = r'(\.|/|_|-|\s|[~0-9a-zA-Z])+'
+ if not re.fullmatch(pattern, path):
+ msg = f"Invalid input path."
+ raise RuntimeError(msg)
+
+ @classmethod
+ def check_path_owner_consistent(cls, path: str):
+ """
+ Function Description:
+ check whether the path belong to process owner
+ Parameter:
+ path: the path to check
+ Exception Description:
+ when invalid path, prompt the user
+ """
+ base_name = os.path.basename(path)
+ if not os.path.exists(path):
+ msg = f"Invalid path: {base_name}"
+ raise RuntimeError(msg)
+ if platform.system().lower() == cls.WINDOWS:
+ return
+ if os.stat(path).st_uid != os.getuid():
+ check_msg = input("The path does not belong to you, do you want to continue? [y/n]")
+ if check_msg.lower() != "y":
+ raise RuntimeError("The user choose not to continue.")
+
+ @classmethod
+ def check_path_writeable(cls, path):
+ """
+ Function Description:
+ check whether the path is writable
+ Parameter:
+ path: the path to check
+ Exception Description:
+ when invalid data throw exception
+ """
+ cls.check_path_owner_consistent(path)
+ if os.path.islink(path):
+ msg = f"Invalid path which is a soft link."
+ raise RuntimeError(msg)
+ base_name = os.path.basename(path)
+ if not os.access(path, os.W_OK):
+ msg = f"The path permission check failed: {base_name}"
+ raise RuntimeError(msg)
+
+    @classmethod
+    def check_path_readable(cls, path):
+        """
+        Function Description:
+            check whether the path is readable
+        Parameter:
+            path: the path to check
+        Exception Description:
+            when invalid data throw exception
+        """
+        cls.check_path_owner_consistent(path)
+        if os.path.islink(path):
+            msg = f"Invalid path which is a soft link."
+            raise RuntimeError(msg)
+        base_name = os.path.basename(path)
+        if not os.access(path, os.R_OK):
+            msg = f"The path permission check failed: {base_name}"
+            raise RuntimeError(msg)
+
+ @classmethod
+ def remove_path_safety(cls, path: str):
+ base_name = os.path.basename(path)
+ msg = f"Failed to remove path: {base_name}"
+ if os.path.islink(path):
+ raise RuntimeError(msg)
+ if os.path.exists(path):
+ try:
+ shutil.rmtree(path)
+ except Exception as err:
+ raise RuntimeError(msg) from err
+
+ @classmethod
+ def make_dir_safety(cls, path: str):
+ base_name = os.path.basename(path)
+ msg = f"Failed to make directory: {base_name}"
+ if os.path.islink(path):
+ raise RuntimeError(msg)
+ if os.path.exists(path):
+ return
+ try:
+ os.makedirs(path, mode=cls.DATA_DIR_AUTHORITY)
+ except Exception as err:
+ raise RuntimeError(msg) from err
+
+ @classmethod
+ def create_file_safety(cls, path: str):
+ base_name = os.path.basename(path)
+ msg = f"Failed to create file: {base_name}"
+ if os.path.islink(path):
+ raise RuntimeError(msg)
+ if os.path.exists(path):
+ return
+ try:
+ os.close(os.open(path, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY))
+ except Exception as err:
+ raise RuntimeError(msg) from err
+
+ @classmethod
+ def get_realpath(cls, path: str) -> str:
+ if os.path.islink(path):
+ msg = f"Invalid input path which is a soft link."
+ raise RuntimeError(msg)
+ return os.path.realpath(path)
diff --git a/profiler/prof_common/trace_event_bean.py b/profiler/prof_common/trace_event_bean.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d4b96e4f6aa84ce225531da89085ba4a07335a5
--- /dev/null
+++ b/profiler/prof_common/trace_event_bean.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from decimal import Decimal
+
+from profiler.prof_common.utils import convert_to_decimal
+from profiler.prof_common.analyze_dict import AnalyzeDict
+
+
+class TraceEventBean(AnalyzeDict):
+    """
+    A single trace event built on AnalyzeDict, with timing properties and
+    classification predicates.
+
+    The event fields used here (ts, dur, cat, name, ph, pid, args) are read as
+    attributes; how those attributes are resolved is left to AnalyzeDict.
+    """
+
+    def __init__(self, data: dict, unique_id: int = None):
+        """Initialise the base class with ``data`` and remember ``unique_id``."""
+        super().__init__(data)
+        self._id = unique_id
+
+    @property
+    def unique_id(self):
+        """Identifier given at construction or through set_id (may be None)."""
+        return self._id
+
+    @property
+    def start_time(self) -> Decimal:
+        """The ``ts`` field passed through convert_to_decimal."""
+        return convert_to_decimal(self.ts)
+
+    @property
+    def end_time(self) -> Decimal:
+        """start_time plus the ``dur`` field passed through convert_to_decimal."""
+        begin = self.start_time
+        duration = convert_to_decimal(self.dur)
+        return begin + duration
+
+    def set_id(self, name_id):
+        """Replace the stored identifier with ``name_id``."""
+        self._id = name_id
+
+    def is_cpu_op(self):
+        """True when ``cat`` equals "cpu_op"."""
+        return self.cat == "cpu_op"
+
+    def is_optimizer(self):
+        """True for a "cpu_op" event whose lower-cased name starts with "optimizer"."""
+        if self.cat == "cpu_op":
+            return self.name.lower().startswith("optimizer")
+        return False
+
+    def is_nn_module(self):
+        """True for a "python_function" event whose lower-cased name starts with "nn.module"."""
+        if self.cat == "python_function":
+            return self.name.lower().startswith("nn.module")
+        return False
+
+    def is_step(self):
+        """True when the lower-cased name starts with "profilerstep#"."""
+        lowered_name = self.name.lower()
+        return lowered_name.startswith("profilerstep#")
+
+    def is_torch_to_npu(self):
+        """True when ``cat`` equals "async_npu"."""
+        return self.cat == "async_npu"
+
+    def is_fwd_bwd_flow(self):
+        """True when ``cat`` equals "fwdbwd"."""
+        return self.cat == "fwdbwd"
+
+    def is_flow_start(self):
+        """True when ``ph`` equals "s"."""
+        return self.ph == "s"
+
+    def is_flow_end(self):
+        """True when ``ph`` equals "f"."""
+        return self.ph == "f"
+
+    def is_kernel_event(self, kernel_pid):
+        """True when ``ph`` equals "X" and ``pid`` equals ``kernel_pid``."""
+        if self.ph == "X":
+            return self.pid == kernel_pid
+        return False
+
+    def is_npu_process(self):
+        """
+        True for a "process_name" metadata event (``ph`` == "M") whose
+        ``args["name"]`` is "Ascend Hardware".
+        """
+        if self.ph == "M" and self.name == "process_name":
+            return self.args.get("name", "") == "Ascend Hardware"
+        return False
diff --git a/profiler/prof_common/tree_builder.py b/profiler/prof_common/tree_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d3e1baf6aa48c480124056ced422178f8fe7a2
--- /dev/null
+++ b/profiler/prof_common/tree_builder.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from profiler.prof_common.trace_event_bean import TraceEventBean
+
+
+class TreeBuilder:
+    """Builds a parent/child tree of nodes out of a list of timed events."""
+
+    @staticmethod
+    def build_tree(event_list: list, node_class: any, root_bean: any):
+        """
+        Nest each event under the most recently created node that has not
+        yet ended, and return the root node.
+
+        ``event_list`` is sorted in place by ``start_time``. ``node_class`` is
+        called as ``node_class(root_bean)`` for the root and as
+        ``node_class(event, parent)`` for every other node; nodes must expose
+        ``end_time``, ``parent_node`` and ``update_child_nodes``.
+        """
+        root_node = node_class(root_bean)
+        event_list.sort(key=lambda item: item.start_time)
+        current = root_node
+        for event in event_list:
+            # Climb towards the root while the event starts after the current
+            # node has ended; the root itself is never climbed past.
+            while current and current != root_node and event.start_time > current.end_time:
+                current = current.parent_node
+            if not current:
+                # Nothing left to attach to, so this event is skipped.
+                continue
+            child = node_class(event, current)
+            current.update_child_nodes(child)
+            current = child
+        return root_node
diff --git a/profiler/prof_common/utils.py b/profiler/prof_common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9db41ad0b8d9dd91132959fd5b583f5711d88db
--- /dev/null
+++ b/profiler/prof_common/utils.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from decimal import Decimal
+
+
+def convert_to_decimal(data: any) -> Decimal:
+    """
+    Convert ``data`` to a Decimal, falling back to Decimal zero on failure.
+
+    Any exception raised by ``Decimal(data)`` is logged as an error and
+    swallowed. The fallback is ``Decimal(0)`` rather than a float so the
+    result always matches the return annotation and can be added to other
+    Decimal values; adding a float to a Decimal raises TypeError.
+    """
+    try:
+        decimal_value = Decimal(data)
+    except Exception:
+        # Deliberate best-effort: bad profiling data must not abort the caller.
+        logging.error('Invalid profiling data which failed to convert data to decimal.')
+        return Decimal(0)
+    return decimal_value