diff --git a/OWNERS b/OWNERS index 7b721dd643e3399d29ff649e2f76182de72421a4..71b6708f14855f8c00990f95c5117e3c38958885 100644 --- a/OWNERS +++ b/OWNERS @@ -44,4 +44,6 @@ reviewers: - binghamhuang - wjchuee - zhou-xianqi -- stby11 \ No newline at end of file +- stby11 +- TAJh +- jiandaobao \ No newline at end of file diff --git a/debug/OWNERS b/debug/OWNERS index 84a4493dd246c3f894eebdfaed65d20b03cb6308..7212741b0f1e1a34fb15ffb5549d9d391756919d 100644 --- a/debug/OWNERS +++ b/debug/OWNERS @@ -5,8 +5,10 @@ approvers: - kun_8 - binghamhuang - brightlyking +- litian_drinksnow reviewers: - lv-kaimeng -- litian_drinksnow - binghamhuang - xiangsen2 +- TAJh +- jiandaobao \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index c1a453a21a6c2f8f30f22812214e2a6e4fc53932..84c5ac1abe4fe44402caf3a1877d99f582626111 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -1,5 +1,6 @@ import os import stat + import numpy as np @@ -16,10 +17,17 @@ class Const: OFF = 'OFF' BACKWARD = 'backward' FORWARD = 'forward' + PRIMITIVE_PREFIX = 'Primitive' DEFAULT_LIST = [] DEFAULT_PATH = './' WHITE_LIST = 'white_list' BLACK_LIST = 'black_list' + DUMP_TENSOR_DATA = 'dump_tensor_data' + NONE = None + THREE_SEGMENT = 3 + FOUR_SEGMENT = 4 + SIX_SEGMENT = 6 + SEVEN_SEGMENT = 7 # dump mode ALL = "all" @@ -61,13 +69,18 @@ class Const: ENV_ENABLE = "1" ENV_DISABLE = "0" MAX_SEED_VALUE = 4294967295 # 2**32 - 1 - TASK_LIST = ["tensor", "statistics", "overflow_check", "free_benchmark", "run_ut"] - LEVEL_LIST = ["L0", "L1", "L2", "mix"] STATISTICS = "statistics" TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" FREE_BENCHMARK = "free_benchmark" RUN_UT = "run_ut" + GRAD_PROBE = "grad_probe" + TASK_LIST = [TENSOR, STATISTICS, OVERFLOW_CHECK, FREE_BENCHMARK, RUN_UT, GRAD_PROBE] + LEVEL_L0 = "L0" + LEVEL_L1 = "L1" + LEVEL_L2 = "L2" + LEVEL_MIX = "mix" 
+ LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX] ATTR_NAME_PREFIX = "wrap_" ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX) KERNEL_DUMP = "kernel_dump" @@ -80,11 +93,15 @@ class Const: BOOL_TYPE = [bool, np.uint8] INT_TYPE = [np.int32, np.int64] NPU = 'NPU' + NPU_LOWERCASE = 'npu' + CPU_LOWERCASE = 'cpu' + CUDA_LOWERCASE = 'cuda' DISTRIBUTED = 'Distributed' INPLACE_LIST = [ "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all" + "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all", + "all_gather_into_tensor", "reduce_scatter_tensor" ] CONVERT = { @@ -135,7 +152,12 @@ class CompareConst: NPU_MD5 = "NPU MD5" BENCH_MD5 = "BENCH MD5" RESULT = "Result" - + MAGNITUDE = 0.5 + OP_NAME = "op_name" + INPUT_STRUCT = "input_struct" + OUTPUT_STRUCT = "output_struct" + SUMMARY = "summary" + COMPARE_RESULT_HEADER = [ NPU_NAME, BENCH_NAME, NPU_DTYPE, BENCH_DTYPE, NPU_SHAPE, BENCH_SHAPE, COSINE, MAX_ABS_ERR, MAX_RELATIVE_ERR, ONE_THOUSANDTH_ERR_RATIO, FIVE_THOUSANDTHS_ERR_RATIO, @@ -172,6 +194,7 @@ class CompareConst: WARNING = 'Warning' ERROR = 'error' SKIP = 'SKIP' + N_A = 'N/A' BFLOAT16_MIN = -3.3895313892515355e+38 BFLOAT16_MAX = 3.3895313892515355e+38 BFLOAT16_EPS = 3.90625e-3 # 2 ** -8 @@ -179,6 +202,7 @@ class CompareConst: # accuracy standards COS_THRESHOLD = 0.99 MAX_ABS_ERR_THRESHOLD = 0.001 + MAX_RELATIVE_ERR_THRESHOLD = 0.001 COS_MAX_THRESHOLD = 0.9 MAX_ABS_ERR_MAX_THRESHOLD = 1 ACCURACY_CHECK_YES = "Yes" @@ -195,6 +219,10 @@ class CompareConst: RED = "FFFF0000" YELLOW = "FFFF00" BLUE = "0000FF" + + # run_ut const + MAX_TOKENS = 65536 + SPECIAL_SPARSE_MOED = 4 # highlight rules const OVERFLOW_LIST = ['nan\t', 'inf\t', '-inf\t', 'nan', 'inf', '-inf'] @@ -207,6 +235,19 @@ class CompareConst: MAX_RELATIVE_OUT_RED = 0.5 MAX_RELATIVE_OUT_YELLOW = 0.1 
MAX_RELATIVE_IN_YELLOW = 0.01 + MS_GRAPH_BASE = { + NPU_NAME: None, BENCH_NAME: None, NPU_DTYPE: None, BENCH_DTYPE: None, NPU_SHAPE: None, BENCH_SHAPE: None, + NPU_MAX: None, NPU_MIN: None, NPU_MEAN: None, NPU_NORM: None, BENCH_MAX: None, BENCH_MIN: None, + BENCH_MEAN: None, BENCH_NORM: None, ACCURACY: '', ERROR_MESSAGE: '' + } + MS_GRAPH_NPY = { + COSINE: None, MAX_ABS_ERR: None, MAX_RELATIVE_ERR: None, ONE_THOUSANDTH_ERR_RATIO: None, + FIVE_THOUSANDTHS_ERR_RATIO: None + } + MS_GRAPH_STATISTIC = { + MAX_DIFF: None, MIN_DIFF: None, MEAN_DIFF: None, NORM_DIFF: None, MAX_RELATIVE_ERR: None, + MIN_RELATIVE_ERR: None, MEAN_RELATIVE_ERR: None, NORM_RELATIVE_ERR: None + } class FileCheckConst: @@ -254,16 +295,43 @@ class OverflowConst: OVERFLOW_ORIGINAL_MODE = 0 OVERFLOW_DEBUG_MODE = 1 +class MsCompareConst: + # api_info field + MINT = "Mint" + MINT_FUNCTIONAL = "MintFunctional" -class MsConst: - CELL = "cell" - API = "api" - KERNEL = "kernel" - TOOL_LEVEL_DICT = { - "L0": CELL, - "L1": API, - "L2": KERNEL - } - PYNATIVE_MODE = "pynative" - GRAPH_GE_MODE = "graph_ge" - GRAPH_KBYK_MODE = "graph_kbyk" + TASK_FIELD = "task" + STATISTICS_TASK = "statistics" + TENSOR_TASK = "tensor" + DUMP_DATA_DIR_FIELD = "dump_data_dir" + DATA_FIELD = "data" + + #detail_csv + DETAIL_CSV_API_NAME = "API Name" + DETAIL_CSV_BENCH_DTYPE = "Bench Dtype" + DETAIL_CSV_TESTED_DTYPE = "Tested Dtype" + DETAIL_CSV_SHAPE = "Shape" + DETAIL_CSV_PASS_STATUS = "Status" + DETAIL_CSV_MESSAGE = "Message" + DETAIL_CSV_FILE_NAME = "accuracy_checking_details" + + #result_csv + RESULT_CSV_FORWARD_TEST_SUCCESS = "Forward Test Success" + RESULT_CSV_BACKWARD_TEST_SUCCESS = "Backward Test Success" + RESULT_CSV_FILE_NAME = "accuracy_checking_result" + + EPSILON = 1e-8 + +class MsgConst: + """ + Class for log messages const + """ + CLEAR_SYMBOL = "\033[K" + LEVEL = ["INFO", "WARNING", "ERROR"] + SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] + + +class 
GraphMode: + NPY_MODE = "NPY_MODE" + STATISTIC_MODE = "STATISTIC_MODE" + ERROR_MODE = "ERROR_MODE" diff --git a/debug/accuracy_tools/msprobe/core/common/file_check.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py similarity index 56% rename from debug/accuracy_tools/msprobe/core/common/file_check.py rename to debug/accuracy_tools/msprobe/core/common/file_utils.py index 36896cfbc19b29f1fcaef04228aac37dc29c8416..0976323b5a760c1e6d250923ab7bfdbc166a0bef 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_check.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -14,8 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """ +import csv +import fcntl import os +import json import re +import shutil +import yaml +import numpy as np from msprobe.core.common.log import logger from msprobe.core.common.exceptions import FileCheckException @@ -32,6 +38,7 @@ class FileChecker: ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability file_type(str): The correct file type for file """ + def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): self.file_path = file_path self.path_type = self._check_path_type(path_type) @@ -187,14 +194,18 @@ def check_path_owner_consistent(path): def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - logger.error('The file path %s contains special characters.' %(path)) + logger.error('The file path %s contains special characters.' % (path)) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): - file_size = os.path.getsize(file_path) + try: + file_size = os.path.getsize(file_path) + except OSError as os_error: + logger.error(f'Failed to open "{file_path}". 
{str(os_error)}') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from os_error if file_size >= max_size: - logger.error(f'The size of file path {file_path} exceeds {max_size} bytes.') + logger.error(f'The size ({file_size}) of {file_path} exceeds ({max_size}) bytes, tools not support.') raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) @@ -224,21 +235,36 @@ def check_path_type(file_path, file_type): raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) +def make_dir(dir_path): + dir_path = os.path.realpath(dir_path) + check_path_before_create(dir_path) + if os.path.isdir(dir_path): + return + try: + os.mkdir(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) + except OSError as ex: + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + f"Failed to create {dir_path}. " + f"Please check the path permission or disk space. {str(ex)}") from ex + file_check = FileChecker(dir_path, FileCheckConst.DIR) + file_check.common_check() + + def create_directory(dir_path): """ Function Description: - creating a directory with specified permissions + creating a safe directory with specified permissions Parameter: dir_path: directory path Exception Description: when invalid data throw exception """ dir_path = os.path.realpath(dir_path) - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - except OSError as ex: - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - 'Failed to create {}. 
Please check the path permission or disk space .{}'.format(dir_path, str(ex))) from ex + check_path_before_create(dir_path) + parent_dir = os.path.dirname(dir_path) + if not os.path.isdir(parent_dir): + create_directory(parent_dir) + make_dir(dir_path) def check_path_before_create(path): @@ -250,6 +276,23 @@ def check_path_before_create(path): 'The file path {} contains special characters.'.format(path)) +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + path_checker = FileChecker(path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE) + else: + path_checker = FileChecker(path, FileCheckConst.FILE, FileCheckConst.READ_ABLE) + path_checker.common_check() + + def change_mode(path, mode): if not os.path.exists(path) or os.path.islink(path): return @@ -262,4 +305,174 @@ def change_mode(path, mode): def path_len_exceeds_limit(file_path): return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH \ No newline at end of file + len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + + +def check_file_type(path): + """ + Function Description: + determine if it is a file or a directory + Parameter: + path: path + Exception Description: + when neither a file nor a directory throw exception + """ + if os.path.isdir(path): + return FileCheckConst.DIR + elif os.path.isfile(path): + return FileCheckConst.FILE + else: + logger.error('Neither a file nor a directory.') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + + +def load_yaml(yaml_path): + path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX) + checked_path = path_checker.common_check() + try: + with FileOpen(checked_path, "r") as f: + 
yaml_data = yaml.safe_load(f) + except Exception as e: + logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.") + raise RuntimeError(f"Load yaml file {checked_path} failed.") from e + return yaml_data + + +def load_npy(filepath, enable_pickle=False): + check_file_or_directory_path(filepath) + try: + npy = np.load(filepath, allow_pickle=enable_pickle) + except Exception as e: + logger.error(f"The numpy file failed to load. Please check the path: {filepath}.") + raise RuntimeError(f"Load numpy file {filepath} failed.") from e + return npy + + +def load_json(json_path): + try: + with FileOpen(json_path, "r") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'load json file "{os.path.basename(json_path)}" failed.') + raise RuntimeError(f"Load json file {json_path} failed.") from e + return data + + +def save_json(json_path, data, indent=None): + json_path = os.path.realpath(json_path) + check_path_before_create(json_path) + try: + with FileOpen(json_path, 'w') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(data, f, indent=indent) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save json file "{os.path.basename(json_path)}" failed.') + raise RuntimeError(f"Save json file {json_path} failed.") from e + change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def move_file(src_path, dst_path): + check_file_or_directory_path(src_path) + check_path_before_create(dst_path) + try: + shutil.move(src_path, dst_path) + except Exception as e: + logger.error(f"move file {src_path} to {dst_path} failed") + raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e + change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def save_npy(data, filepath): + filepath = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + np.save(filepath, data) + except Exception as e: + logger.error(f"The numpy file 
failed to save. Please check the path: {filepath}.")
+        raise RuntimeError(f"Save numpy file {filepath} failed.") from e
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def save_npy_to_txt(self, data, dst_file='', align=0):
+    """
+    Function Description:
+        write an array to a text file with `align` values per row. Nothing is written when
+        dst_file already exists. The array is flattened first; with align == 0 the row width
+        is the last dimension of the original shape (1 for a 0-d array), otherwise the
+        flattened data is zero-padded up to a multiple of align.
+    Parameter:
+        self: object whose `log` attribute provides info() and error()
+        data: array exposing shape, flatten() and size (presumably a numpy.ndarray)
+        dst_file: path of the text file to create
+        align: number of values per row, 0 to derive it from the array shape
+    Exception Description:
+        a failure inside np.savetxt is logged through self.log and is not re-raised
+    """
+    if os.path.exists(dst_file):
+        self.log.info("Dst file %s exists, will not save new one.", dst_file)
+        return
+    shape = data.shape
+    data = data.flatten()
+    if align == 0:
+        align = 1 if len(shape) == 0 else shape[-1]
+    elif data.size % align != 0:
+        pad_array = np.zeros((align - data.size % align,))
+        data = np.append(data, pad_array)
+    check_path_before_create(dst_file)
+    try:
+        np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+    except Exception as e:
+        # Hand both values to the logger as arguments. Formatting eagerly with `%` and a single
+        # value for two placeholders raises TypeError here and hides the real error.
+        self.log.error("An unexpected error occurred: %s when savetxt to %s", str(e), dst_file)
+    change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def save_workbook(workbook, file_path):
+    """
+    Save a workbook to the given file path.
+    workbook: the workbook object to save
+    file_path: path the workbook is written to
+    """
+    file_path = os.path.realpath(file_path)
+    check_path_before_create(file_path)
+    try:
+        workbook.save(file_path)
+    except Exception as e:
+        logger.error(f'Save result file "{os.path.basename(file_path)}" failed')
+        raise RuntimeError(f"Save result file {file_path} failed.") from e
+    change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def write_csv(data, filepath, mode="a+"):
+    """
+    Function Description:
+        write the rows in data to a csv file opened with `mode` and utf-8-sig encoding,
+        then apply FileCheckConst.DATA_FILE_AUTHORITY to it
+    Parameter:
+        data: iterable of rows handed to csv.writer.writerows
+        filepath: path of the csv file
+        mode: file open mode, appending by default
+    Exception Description:
+        any failure while writing is logged and re-raised as RuntimeError
+    """
+    file_path = os.path.realpath(filepath)
+    check_path_before_create(filepath)
+    try:
+        with FileOpen(filepath, mode, encoding='utf-8-sig') as f:
+            writer = csv.writer(f)
+            writer.writerows(data)
+    except Exception as e:
+        logger.error(f'Save csv file "{os.path.basename(file_path)}" failed')
+        raise RuntimeError(f"Save csv file {file_path} failed.") from e
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+
+def remove_path(path):
+    if not os.path.exists(path):
+        return
+    try:
+        if os.path.islink(path) or os.path.isfile(path):
+            os.remove(path)
+        else:
+            shutil.rmtree(path)
+    except PermissionError
as err: + logger.error("Failed to delete {}. Please check the permission.".format(path)) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) from err + except Exception as e: + logger.error("Failed to delete {}. Please check.".format(path)) + raise RuntimeError(f"Delete {path} failed.") from e + + +def get_json_contents(file_path): + ops = get_file_content_bytes(file_path) + try: + json_obj = json.loads(ops) + except ValueError as error: + logger.error('Failed to load json.') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) from error + if not isinstance(json_obj, dict): + logger.error('Json file content is not a dictionary!') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + return json_obj + + +def get_file_content_bytes(file): + with FileOpen(file, 'rb') as file_handle: + return file_handle.read() diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index aac5a054819e424ebd2fcb222871139e5a1c2024..038fb779fc7f4a0a5665476883ee6121c9200640 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -15,19 +15,20 @@ # limitations under the License. 
""" import collections +import fcntl import os import re import shutil -import stat import subprocess import time import json +import csv from datetime import datetime, timezone -from pathlib import Path +import yaml import numpy as np -from msprobe.core.common.file_check import FileOpen, FileChecker -from msprobe.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst +from msprobe.core.common.file_utils import FileOpen, FileChecker, change_mode +from msprobe.core.common.const import Const, FileCheckConst, CompareConst from msprobe.core.common.log import logger @@ -60,6 +61,8 @@ class CompareException(Exception): OVER_SIZE_FILE_ERROR = 18 INVALID_SUMMARY_MODE = 19 INVALID_TASK_ERROR = 20 + DETACH_ERROR = 21 + def __init__(self, code, error_info: str = ""): super(CompareException, self).__init__() @@ -91,19 +94,6 @@ class DumpException(CompareException): pass -def make_dump_path_if_not_exists(dump_path): - if not os.path.exists(dump_path): - try: - Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True) - except OSError as ex: - logger.error( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex - else: - if not os.path.isdir(dump_path): - logger.error('{} already exists and is not a directory.'.format(dump_path)) - - def check_mode_valid(mode, scope=None, api_list=None): if scope is None: scope = [] @@ -165,21 +155,24 @@ def check_summary_only_valid(summary_only): return summary_only -def check_compare_param(input_parma, output_path, summary_compare=False, md5_compare=False): - if not (isinstance(input_parma, dict) and isinstance(output_path, str)): +def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False): + if not (isinstance(input_param, dict) and isinstance(output_path, str)): logger.error("Invalid input parameters") raise CompareException(CompareException.INVALID_PARAM_ERROR) - 
check_file_or_directory_path(input_parma.get("npu_json_path"), False) - check_file_or_directory_path(input_parma.get("bench_json_path"), False) - check_file_or_directory_path(input_parma.get("stack_json_path"), False) + + check_file_or_directory_path(input_param.get("npu_json_path"), False) + check_file_or_directory_path(input_param.get("bench_json_path"), False) + check_file_or_directory_path(input_param.get("stack_json_path"), False) if not summary_compare and not md5_compare: - check_file_or_directory_path(input_parma.get("npu_dump_data_dir"), True) - check_file_or_directory_path(input_parma.get("bench_dump_data_dir"), True) + check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True) + check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True) check_file_or_directory_path(output_path, True) - with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \ - FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \ - FileOpen(input_parma.get("stack_json_path"), "r") as stack_json: - check_json_file(input_parma, npu_json, bench_json, stack_json) + + with FileOpen(input_param.get("npu_json_path"), "r") as npu_json, \ + FileOpen(input_param.get("bench_json_path"), "r") as bench_json, \ + FileOpen(input_param.get("stack_json_path"), "r") as stack_json: + check_json_file(input_param, npu_json, bench_json, stack_json) + def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False): @@ -274,6 +267,17 @@ def remove_path(path): raise CompareException(CompareException.INVALID_PATH_ERROR) from err +def move_file(src_path, dst_path): + check_file_or_directory_path(src_path) + check_path_before_create(dst_path) + try: + shutil.move(src_path, dst_path) + except Exception as e: + logger.error(f"move file {src_path} to {dst_path} failed") + raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e + change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def get_dump_data_path(dump_dir): """ 
Function Description: @@ -296,24 +300,6 @@ def get_dump_data_path(dump_dir): return dump_data_path, file_is_exist -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - if not os.path.exists(dir_path): - try: - os.makedirs(dir_path, mode=0o700) - except OSError as ex: - logger.error( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex - - def execute_command(cmd): """ Function Description: @@ -480,14 +466,14 @@ def md5_find(data): def task_dumppath_get(input_param): - npu_json_path = input_param.get("npu_json_path", None) - bench_json_path = input_param.get("bench_json_path", None) - if not npu_json_path or not bench_json_path: + npu_path = input_param.get("npu_json_path", None) + bench_path = input_param.get("bench_json_path", None) + if not npu_path or not bench_path: logger.error(f"Please check the json path is valid.") raise CompareException(CompareException.INVALID_PATH_ERROR) - with FileOpen(npu_json_path, 'r') as npu_f: + with FileOpen(npu_path, 'r') as npu_f: npu_json_data = json.load(npu_f) - with FileOpen(bench_json_path, 'r') as bench_f: + with FileOpen(bench_path, 'r') as bench_f: bench_json_data = json.load(bench_f) if npu_json_data['task'] != bench_json_data['task']: logger.error(f"Please check the dump task is consistent.") @@ -504,8 +490,8 @@ def task_dumppath_get(input_param): else: logger.error(f"Compare is not required for overflow_check or free_benchmark.") raise CompareException(CompareException.INVALID_TASK_ERROR) - input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir'] - input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir'] + input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) + 
input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) return summary_compare, md5_compare @@ -522,3 +508,122 @@ def get_header_index(header_name, summary_compare=False): def convert_tuple(data): return data if isinstance(data, tuple) else (data, ) + + +def write_csv(data, filepath, mode="a+"): + exist = os.path.exists(filepath) + with FileOpen(filepath, mode, encoding='utf-8-sig') as f: + writer = csv.writer(f) + writer.writerows(data) + if not exist: + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + +def load_npy(filepath, enable_pickle=False): + check_file_or_directory_path(filepath) + try: + npy = np.load(filepath, allow_pickle=enable_pickle) + except Exception as e: + logger.error(f"The numpy file failed to load. Please check the path: {filepath}.") + raise RuntimeError(f"Load numpy file {filepath} failed.") from e + return npy + + +def save_npy(data, filepath): + filepath = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + np.save(filepath, data) + except Exception as e: + logger.error(f"The numpy file failed to save. 
Please check the path: {filepath}.")
+        raise RuntimeError(f"Save numpy file {filepath} failed.") from e
+    change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY)
+
+def save_npy_to_txt(self, data, dst_file='', align=0):
+    """
+    Function Description:
+        write an array to a text file with `align` values per row. Nothing is written when
+        dst_file already exists. The array is flattened first; with align == 0 the row width
+        is the last dimension of the original shape (1 for a 0-d array), otherwise the
+        flattened data is zero-padded up to a multiple of align.
+    Parameter:
+        self: object whose `log` attribute provides info() and error()
+        data: array exposing shape, flatten() and size (presumably a numpy.ndarray)
+        dst_file: path of the text file to create
+        align: number of values per row, 0 to derive it from the array shape
+    Exception Description:
+        a failure inside np.savetxt is logged through self.log and is not re-raised
+    """
+    if os.path.exists(dst_file):
+        self.log.info("Dst file %s exists, will not save new one.", dst_file)
+        return
+    shape = data.shape
+    data = data.flatten()
+    if align == 0:
+        align = 1 if len(shape) == 0 else shape[-1]
+    elif data.size % align != 0:
+        pad_array = np.zeros((align - data.size % align,))
+        data = np.append(data, pad_array)
+    check_path_before_create(dst_file)
+    try:
+        np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+    except Exception as e:
+        # Hand both values to the logger as arguments. Formatting eagerly with `%` and a single
+        # value for two placeholders raises TypeError here and hides the real error.
+        self.log.error("An unexpected error occurred: %s when savetxt to %s", str(e), dst_file)
+    change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)
+
+def get_json_contents(file_path):
+    """
+    Function Description:
+        read a file as bytes and parse it as json, accepting only a top-level object
+    Parameter:
+        file_path: path of the json file
+    Exception Description:
+        CompareException(INVALID_FILE_ERROR) when the content is not valid json or the
+        parsed value is not a dict
+    """
+    ops = get_file_content_bytes(file_path)
+    try:
+        json_obj = json.loads(ops)
+    except ValueError as error:
+        logger.error('Failed to load json.')
+        raise CompareException(CompareException.INVALID_FILE_ERROR) from error
+    if not isinstance(json_obj, dict):
+        logger.error('Json file content is not a dictionary!')
+        raise CompareException(CompareException.INVALID_FILE_ERROR)
+    return json_obj
+
+
+def get_file_content_bytes(file):
+    """Return the whole content of `file`, opened in binary mode through FileOpen."""
+    with FileOpen(file, 'rb') as file_handle:
+        return file_handle.read()
+
+
+def load_yaml(yaml_path):
+    path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX)
+    checked_path = path_checker.common_check()
+    try:
+        with FileOpen(checked_path, "r") as f:
+            yaml_data = yaml.safe_load(f)
+    except Exception as e:
+        logger.error(f"The yaml file failed to load.
Please check the path: {checked_path}.") + raise RuntimeError(f"Load yaml file {checked_path} failed.") from e + return yaml_data + + +def save_workbook(workbook, file_path): + """ + 保存工作簿到指定的文件路径 + workbook: 要保存的工作簿对象 + file_path: 文件保存路径 + """ + file_path = os.path.realpath(file_path) + check_path_before_create(file_path) + try: + workbook.save(file_path) + except Exception as e: + logger.error(f'Save result file "{os.path.basename(file_path)}" failed') + raise CompareException(CompareException.WRITE_FILE_ERROR) from e + change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def load_json(json_path): + try: + with FileOpen(json_path, "r") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'load json file "{os.path.basename(json_path)}" failed.') + raise DumpException(DumpException.WRITE_FILE_ERROR) from e + return data + + +def save_json(json_path, data, indent=None): + json_path = os.path.realpath(json_path) + check_path_before_create(json_path) + try: + with FileOpen(json_path, 'w') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(data, f, indent=indent) + fcntl.flock(f, fcntl.LOCK_UN) + except Exception as e: + logger.error(f'Save json file "{os.path.basename(json_path)}" failed.') + raise DumpException(DumpException.WRITE_FILE_ERROR) from e + change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9cf02b953f6898e753139b3e0801f4dcf3799db1 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -0,0 +1,430 @@ + +import os +import re +import numpy as np +from msprobe.core.common.const import Const, CompareConst +from msprobe.core.common.utils import CompareException, check_regex_prefix_format_valid, logger +from msprobe.core.common.file_utils import 
check_file_or_directory_path
+
+
+def extract_json(dirname, stack_json=False):
+    """
+    Pick one json file from dirname.
+
+    construct.json is always ignored. The scan stops at the first .json path that contains
+    'stack' (stack_json=True) or does not contain it (stack_json=False); when no path
+    satisfies that, the last .json seen is returned. The 'stack' test is applied to the
+    joined path, so it also looks at dirname itself.
+
+    Args:
+        dirname (str): directory to scan
+        stack_json (bool): whether the stack json is wanted
+
+    Returns:
+        json_path (str): path of the selected json file
+    Raises:
+        CompareException: NO_DUMP_FILE_ERROR when dirname holds no usable .json file
+    """
+    json_path = ''
+    for fname in os.listdir(dirname):
+        if fname == "construct.json":
+            continue
+        full_path = os.path.join(dirname, fname)
+        if full_path.endswith('.json'):
+            json_path = full_path
+            if not stack_json and 'stack' not in json_path:
+                break
+            if stack_json and 'stack' in json_path:
+                break
+
+    # Provide robustness on invalid directory inputs
+    if not json_path:
+        logger.error(f'No file is found in dump dir {dirname}. ')
+        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
+    return json_path
+
+
+def check_and_return_dir_contents(dump_dir, prefix):
+    """
+    check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
+    pattern: ^{prefix}(?:0|[1-9][0-9]*)?$
+
+    Every entry must be the bare prefix or the prefix followed by a non-negative integer
+    written without leading zeros.
+
+    Args:
+        dump_dir (str): dump dir
+        prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only
+
+    Returns:
+        content [list]: dir contents
+    Raises:
+        CompareException: invalid path
+        ValueError: prefix not match the patterns
+
+    """
+    check_regex_prefix_format_valid(prefix)
+    check_file_or_directory_path(dump_dir, True)
+    contents = os.listdir(dump_dir)
+    # The first digit is 1-9 and any digit may follow, so suffixes containing a zero
+    # (10, 20, 100, ...) are accepted; a lone 0 is covered by the first alternative.
+    pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
+    for name in contents:
+        if not pattern.match(name):
+            logger.error(
+                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
+                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
+            )
+            raise CompareException(CompareException.INVALID_PATH_ERROR)
+    return contents
+
+
+def rename_api(npu_name, process):
+    """
+    Rebuild an api name without the two segments that precede `process`.
+
+    npu_name is split on `process`; the last two Const.SEP-separated segments of the part
+    before it are dropped and the part after it is appended unchanged.
+    NOTE(review): npu_name must contain `process`, otherwise indexing the split result
+    raises IndexError - confirm callers guarantee this.
+    """
+    npu_split = npu_name.split(process)
+    torch_func_index, in_out = npu_split[0], npu_split[1]
+    torch_func_split = torch_func_index.rsplit(Const.SEP, 2)
+    torch_func = str(torch_func_split[0]) + str(in_out)
+    return torch_func
+
+
+def read_op(op_data, op_name):
+    """
+    Flatten the dumped data of one op into a list of parsed items.
+
+    For a forward op (Const.FORWARD in op_name) the input args, input kwargs and output
+    entries of op_data are parsed with op_item_parse; for a backward op (Const.BACKWARD in
+    op_name) the input and output entries are parsed. Entries missing from op_data are skipped.
+
+    Args:
+        op_data (dict): dumped data of the op
+        op_name (str): full op name, used as prefix of every parsed item name
+
+    Returns:
+        op_parsed_list (list): parsed items, empty when nothing matched
+    """
+    # Start from a fresh list. Const.DEFAULT_LIST is one list object shared by the whole
+    # process, and the `+=` below would extend it in place whenever input args are absent.
+    op_parsed_list = []
+    if Const.FORWARD in op_name:
+        if Const.INPUT_ARGS in op_data:
+            input_item = op_data[Const.INPUT_ARGS]
+            input_parsed_list = op_item_parse(input_item, op_name + '.input', None)
+            op_parsed_list = input_parsed_list.copy()
+            input_parsed_list.clear()
+        if Const.INPUT_KWARGS in op_data:
+            kwargs_item = op_data[Const.INPUT_KWARGS]
+            if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
+                kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '.input', None)
+                op_parsed_list += kwarg_parsed_list
+                kwarg_parsed_list.clear()
+            elif kwargs_item:
+                for kwarg in kwargs_item:
+                    kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '.input.'
+ kwarg, None) + op_parsed_list += kwarg_parsed_list + kwarg_parsed_list.clear() + if Const.OUTPUT in op_data: + output_item = op_data[Const.OUTPUT] + output_parsed_list = op_item_parse(output_item, op_name + '.output', None) + op_parsed_list += output_parsed_list + output_parsed_list.clear() + if Const.BACKWARD in op_name: + if Const.INPUT in op_data: + input_item = op_data[Const.INPUT] + input_parsed_list = op_item_parse(input_item, op_name + '.input', None) + op_parsed_list = input_parsed_list.copy() + input_parsed_list.clear() + if Const.OUTPUT in op_data: + output_item = op_data[Const.OUTPUT] + output_parsed_list = op_item_parse(output_item, op_name + '.output', None) + op_parsed_list += output_parsed_list + output_parsed_list.clear() + return op_parsed_list + + +def op_item_parse(item, op_name, index, item_list=None, top_bool=True): + if item_list is None: + item_list = [] + if item is None or (isinstance(item, dict) and not item): + if not top_bool: + tmp = {'full_op_name': op_name + '.' 
+ str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, + 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'} + else: + tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None, + 'shape': None, 'md5': None, 'data_name': '-1'} + item_list.append(tmp) + return item_list + if index is None: + if isinstance(item, dict): + full_op_name = op_name + '.0' + else: + full_op_name = op_name + else: + full_op_name = op_name + Const.SEP + str(index) + if isinstance(item, dict): + if 'type' not in item: + for kwarg in item: + kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None) + item_list += kwarg_parsed_list + kwarg_parsed_list.clear() + elif 'dtype' in item: + parsed_item = item + parsed_item['full_op_name'] = full_op_name + item_list.append(parsed_item) + elif 'type' in item: + parsed_item = {} + if item['type'] == 'torch.Size': + parsed_item['full_op_name'] = full_op_name + parsed_item['dtype'] = 'torch.Size' + parsed_item['shape'] = str(item['value']) + parsed_item['md5'] = None + parsed_item['Max'] = None + parsed_item['Min'] = None + parsed_item['Mean'] = None + parsed_item['Norm'] = None + parsed_item['data_name'] = '-1' + item_list.append(parsed_item) + elif item['type'] == 'slice': + parsed_item['full_op_name'] = full_op_name + parsed_item['dtype'] = 'slice' + parsed_item['shape'] = str(np.shape(np.array(item['value']))) + parsed_item['md5'] = None + parsed_item['Max'] = None + parsed_item['Min'] = None + parsed_item['Mean'] = None + parsed_item['Norm'] = None + parsed_item['data_name'] = '-1' + item_list.append(parsed_item) + else: + parsed_item['full_op_name'] = full_op_name + parsed_item['dtype'] = str(type(item['value'])) + parsed_item['shape'] = '[]' + parsed_item['md5'] = None + parsed_item['Max'] = item['value'] + parsed_item['Min'] = item['value'] + parsed_item['Mean'] = item['value'] + parsed_item['Norm'] = item['value'] + parsed_item['data_name'] = '-1' + 
item_list.append(parsed_item) + else: + resolve_api_special_parameters(item, full_op_name, item_list) + else: + for j, item_spec in enumerate(item): + op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False) + return item_list + + +def resolve_api_special_parameters(data_dict, full_op_name, item_list): + """ + Function Description: + 解析下面格式的数据, 是api参数的一种特殊格式 + { + "last_hidden_state": { + "type": "torch.Tensor", + "dtype": "torch.bfloat16", + ... + }, + "loss": { + "type": "torch.Tensor", + "dtype": "torch.float32", + ... + } + } + Parameter: + data_dict: 字典格式的数据 + full_op_name: 参数的全名字符串 + item_list: 参数信息集合 + """ + for key, value in data_dict.items(): + if isinstance(value, dict): + parsed_item = value + parts = full_op_name.split(Const.SEP) + parts.insert(-1, key) + full_op_name_new = ".".join(parts) + parsed_item['full_op_name'] = full_op_name_new + item_list.append(parsed_item) + + +def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=False): + def get_accuracy_core(n_start, n_len, b_start, b_len, key): + min_len = min(n_len, b_len) + npu_stack_info = n_dict.get("stack_info", None) + bench_stack_info = b_dict.get("stack_info", None) + has_stack = npu_stack_info and bench_stack_info + + all_mode_bool = not (summary_compare or md5_compare) + if all_mode_bool: + npu_data_name = n_dict.get("data_name", None) + bench_data_name = b_dict.get("data_name", None) + + for index in range(min_len): + + n_name = n_dict['op_name'][n_start + index] + b_name = b_dict['op_name'][b_start + index] + n_struct = n_dict[key][index] + b_struct = b_dict[key][index] + err_msg = "" + if md5_compare: + result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], + n_struct[2], b_struct[2], + CompareConst.PASS if n_struct[2] == b_struct[2] else CompareConst.DIFF] + if has_stack and index == 0 and key == "input_struct": + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + 
result.append(result_item) + continue + + if summary_compare: + result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], + " ", " ", " ", " ", " ", " ", " ", " "] + else: + result_item = [n_name, b_name, n_struct[0], b_struct[0], n_struct[1], b_struct[1], + " ", " ", " ", " ", " "] + + npu_summary_data = n_dict.get("summary")[n_start + index] + result_item.extend(npu_summary_data) + bench_summary_data = b_dict.get("summary")[b_start + index] + result_item.extend(bench_summary_data) + + if summary_compare: + start_idx = CompareConst.SUMMARY_COMPARE_RESULT_HEADER.index(CompareConst.MAX_DIFF) + warning_flag = False + for i, (npu_val, bench_val) in enumerate(zip(npu_summary_data, bench_summary_data)): + if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)): + diff = npu_val - bench_val + if bench_val != 0: + relative = str(abs((diff / bench_val) * 100)) + '%' + else: + relative = "N/A" + result_item[start_idx + i] = diff + result_item[start_idx + i + 4] = relative + magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10) + if magnitude_diff > 0.5: + warning_flag = True + else: + result_item[start_idx + i] = CompareConst.NONE + accuracy_check = CompareConst.WARNING if warning_flag else "" + err_msg += "Need double check api accuracy." 
if warning_flag else "" + for i in range(start_idx, len(result_item)): + if str(result_item[i]) in ('inf', '-inf', 'nan'): + result_item[i] = f'{result_item[i]}\t' + + result_item.append(accuracy_check if summary_compare else CompareConst.ACCURACY_CHECK_YES) + result_item.append(err_msg) + if has_stack and index == 0 and key == "input_struct": + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + if all_mode_bool: + result_item.append(npu_data_name[n_start + index]) + + result.append(result_item) + + if n_len > b_len: + for index in range(b_len, n_len): + n_name = n_dict['op_name'][n_start + index] + n_struct = n_dict[key][index] + if md5_compare: + result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, + n_struct[1], CompareConst.NAN, n_struct[2], CompareConst.NAN, CompareConst.NAN] + result.append(result_item) + continue + result_item = [n_name, CompareConst.NAN, n_struct[0], CompareConst.NAN, + n_struct[1], CompareConst.NAN, " ", " ", " ", " ", " "] + summary_data = n_dict.get("summary")[n_start + index] + result_item.extend(summary_data) + summary_data = [CompareConst.NAN for _ in range(len(n_dict.get("summary")[0]))] + result_item.extend(summary_data) + + err_msg = "" + result_item.append(CompareConst.ACCURACY_CHECK_YES) + result_item.append(err_msg) + + if has_stack and index == 0 and key == "input_struct": + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + if all_mode_bool: + result_item.append(npu_data_name[n_start + index]) + + result.append(result_item) + + n_num = len(n_dict['op_name']) + b_num = len(b_dict['op_name']) + n_num_input = len([name for name in n_dict['op_name'] if Const.INPUT in name]) + b_num_input = len([name for name in b_dict['op_name'] if Const.INPUT in name]) + n_num_kwarg = len([name for name in n_dict['op_name'] if 'kwarg' in name]) + b_num_kwarg = len([name for name in b_dict['op_name'] if 'kwarg' in name]) + n_num_output = n_num - n_num_input - 
n_num_kwarg + b_num_output = b_num - b_num_input - b_num_kwarg + get_accuracy_core(0, n_num_input, 0, b_num_input, 'input_struct') + get_accuracy_core(n_num_input, n_num_kwarg, b_num_input, b_num_kwarg, "kwargs_struct") + get_accuracy_core(n_num_input + n_num_kwarg, n_num_output, b_num_input + b_num_kwarg, b_num_output, 'output_struct') + + +def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): + index_out = 0 + npu_stack_info = n_dict.get("stack_info", None) + bench_name, bench_type, bench_shape = CompareConst.N_A, CompareConst.N_A, CompareConst.N_A + err_msg = CompareConst.NO_BENCH + accuracy_check_res = CompareConst.N_A + for index, n_name in enumerate(n_dict["op_name"]): + if n_name.find("input") != -1: + n_struct = n_dict["input_struct"][index] + else: + n_struct = n_dict["output_struct"][index_out] + index_out += 1 + + result_item = [n_name, bench_name, n_struct[0], bench_type, n_struct[1], bench_shape] + if md5_compare: + result_item.extend([CompareConst.N_A] * 3) + if npu_stack_info and index == 0: + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + result.append(result_item) + continue + if summary_compare: + result_item.extend([CompareConst.N_A] * 8) + else: + result_item.extend([CompareConst.N_A] * 5) + npu_summary_data = n_dict.get("summary")[index] + result_item.extend(npu_summary_data) + bench_summary_data = [CompareConst.N_A] * 4 + result_item.extend(bench_summary_data) + result_item.append(accuracy_check_res) + result_item.append(err_msg) + if npu_stack_info and index == 0: + result_item.extend(npu_stack_info) + else: + result_item.append(CompareConst.NONE) + if not md5_compare and not summary_compare and result_item[1] == CompareConst.N_A: + result_item.extend(["-1"]) + result.append(result_item) + + +def merge_tensor(tensor_list, summary_compare, md5_compare): + op_dict = {} + op_dict["op_name"] = [] + op_dict["input_struct"] = [] + op_dict["kwargs_struct"] = [] + op_dict["output_struct"] = [] 
+ op_dict["summary"] = [] + op_dict["stack_info"] = [] + + all_mode_bool = not (summary_compare or md5_compare) + if all_mode_bool: + op_dict["data_name"] = [] + + for tensor in tensor_list: + if len(tensor) == 2: + op_dict['stack_info'].append(tensor['full_info']) + break + op_dict["op_name"].append(tensor['full_op_name']) + if not md5_compare: + if tensor['full_op_name'].find("input") != -1: + op_dict["input_struct"].append((tensor['dtype'], tensor['shape'])) + elif tensor['full_op_name'].find("kwarg") != -1: + op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'])) + elif tensor['full_op_name'].find("output") != -1: + op_dict["output_struct"].append((tensor['dtype'], tensor['shape'])) + else: + if tensor['full_op_name'].find("input") != -1: + op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + elif tensor['full_op_name'].find("kwarg") != -1: + op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + elif tensor['full_op_name'].find("output") != -1: + op_dict["output_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + + op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']]) + + if all_mode_bool: + op_dict["data_name"].append(tensor['data_name']) + + if not op_dict["kwargs_struct"]: + del op_dict["kwargs_struct"] + return op_dict if op_dict["op_name"] else {} + + +def _compare_parser(parser): + parser.add_argument("-i", "--input_path", dest="input_path", type=str, + help=" The compare input path, a dict json.", required=True) + parser.add_argument("-o", "--output_path", dest="output_path", type=str, + help=" The compare task result out path.", required=True) + parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true", + help=" Whether to save stack info.", required=False) + parser.add_argument("-c", "--compare_only", dest="compare_only", action="store_true", + help=" Whether to give advisor.", required=False) + 
parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", + help=" Whether to perform a fuzzy match on the api name.", required=False) + parser.add_argument("-cm", "--cell_mapping", dest="cell_mapping", type=str, nargs='?', const=True, + help=" The cell mapping file path.", required=False) + parser.add_argument("-am", "--api_mapping", dest="api_mapping", type=str, nargs='?', const=True, + help=" The api mapping file path.", required=False) + + + + + diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c208df7d900683197fc24081b42835716ce7605f..0cfb27cda2afa8794943dceb23b4085f0305c50b 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -21,7 +21,7 @@ import numpy as np from msprobe.core.common.const import Const from msprobe.core.data_dump.data_processor.base import (BaseDataProcessor, TensorStatInfo, ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs) -from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst +from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode, FileCheckConst from msprobe.mindspore.dump.hook_cell.wrap_functional import load_ops_functions from msprobe.mindspore.common.utils import convert_bf16_to_fp32 from msprobe.mindspore.common.log import logger diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 007fec80964e300315c59f3d7fa4166b9d10fa70..0986ccf826330f9783a5dde2cc1d07a87c2e57f4 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -6,7 +6,7 
@@ from typing import List import numpy as np import torch from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode +from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode from msprobe.core.common.log import logger from msprobe.core.common.const import Const, OverflowConst, FileCheckConst from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 112e45171efb23535a5496e776918d5ae07ca4ab..99cc5f3159ee038133e874109a78b2a7a85d9413 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -4,7 +4,7 @@ import fcntl import json from pathlib import Path -from msprobe.core.common.file_check import change_mode, FileOpen +from msprobe.core.common.file_utils import change_mode, FileOpen from msprobe.core.common.log import logger from msprobe.core.common.const import Const, FileCheckConst diff --git a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py index b178664d9e37f7d6cafdca58218b75909ab9cfcc..fe5bf1efb1f0e520c9cf77cd1cb808650e66e548 100644 --- a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py @@ -21,7 +21,7 @@ from msprobe.pytorch.advisor.advisor_result import AdvisorResult from msprobe.pytorch.advisor.advisor_const import AdvisorConst from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import CompareException -from msprobe.core.common.file_check import FileChecker +from msprobe.core.common.file_utils import FileChecker from msprobe.core.common.const import Const, CompareConst, FileCheckConst class Advisor: diff --git 
a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py index 456f542e1f5bf867aa3db6a88e36dd03f8b581dc..58b76d3c8479321eb354e27c785a38d4ca3d8aaa 100644 --- a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py +++ b/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py @@ -20,7 +20,7 @@ import time from msprobe.pytorch.advisor.advisor_const import AdvisorConst from msprobe.pytorch.common.log import logger from msprobe.core.common.const import Const, FileCheckConst -from msprobe.core.common.file_check import change_mode +from msprobe.core.common.file_utils import change_mode class AdvisorResult: diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index 760e7c862dba5440412f5ee27d0345d1a17d2c5d..ca6bb1627e736ded4050bc3fd0bab9f5d64a1d51 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -1,7 +1,7 @@ import os import yaml from msprobe.pytorch.api_accuracy_checker.common.utils import check_file_or_directory_path -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.pytorch.pt_config import RunUTConfig diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py index b6e8932960c6ce15c65c83874720bd7d24f19909..7855a51e4b472fd3d22b5ad9ee2f7e4d4a0d39e5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/utils.py @@ -29,7 +29,7 @@ else: IS_GPU = False from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_check import FileChecker, FileOpen, change_mode, 
create_directory +from msprobe.core.common.file_utils import FileChecker, FileOpen, change_mode, create_directory from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.utils import CompareException diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index 73bf7c2b8ebd59af8e31b9a7e9ad534f11717340..d85abfe9ddaf24ce19ecfef965886db93e16761c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -16,7 +16,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import API_PRECI check_inf_or_nan from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from msprobe.core.common.file_check import FileChecker, change_mode, check_path_before_create, create_directory +from msprobe.core.common.file_utils import FileChecker, change_mode, check_path_before_create, create_directory from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.core.common.const import CompareConst, FileCheckConst diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 5c7e86ff36cbe027efdc20e4dc6cbdbf4b98b808..f8450b64b56a24fde76f0fc5a7be46479b5059ea 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -8,7 +8,7 @@ import yaml from msprobe.core.common.utils import CompareException from msprobe.core.common.const import Const 
from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen current_time = time.strftime("%Y%m%d%H%M%S") diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py index b2eec691af0e7161e8f53a607c92acf29ad71ceb..51477ddd46efe607b4407814c753d07a6f714ad9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -23,7 +23,7 @@ import numpy from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import hf_32_standard_api from msprobe.pytorch.api_accuracy_checker.common.utils import check_object_type, get_full_data_path, \ CompareException -from msprobe.core.common.file_check import FileChecker +from msprobe.core.common.file_utils import FileChecker from msprobe.pytorch.common.log import logger from msprobe.core.common.const import Const, FileCheckConst diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 9acb5ee64981b400ea00579b91db0c35320cfe08..049f6e9de5f590d5eb78a9ff2614a570c8bb88cb 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -13,7 +13,7 @@ from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, g get_validated_details_csv_path, preprocess_forward_content from msprobe.pytorch.api_accuracy_checker.compare.compare import Comparator from msprobe.pytorch.common import parse_json_info_forward_backward -from msprobe.core.common.file_check import FileChecker, check_file_suffix, check_link, FileOpen, \ +from msprobe.core.common.file_utils import 
FileChecker, check_file_suffix, check_link, FileOpen, \ check_path_before_create, create_directory from msprobe.pytorch.common.log import logger from msprobe.core.common.const import FileCheckConst diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 732745ee8ca14b0665e2beb22da72cfd856164d1..1b9b26f9c0e3b954c577b68d35d8c5e086660b73 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -12,7 +12,7 @@ import torch from tqdm import tqdm from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info from msprobe.pytorch.api_accuracy_checker.common.utils import get_json_contents -from msprobe.core.common.file_check import check_link +from msprobe.core.common.file_utils import check_link from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward from msprobe.core.common.const import Const diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 559dfdc0f14f191fc7142f6b2f9d735c51d6a738..2fb709127d893320988b64db634c92b1eda46d60 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -31,7 +31,7 @@ from msprobe.pytorch.hook_module.wrap_npu_custom import NpuOPTemplate from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig from msprobe.pytorch.common.parse_json import parse_json_info_forward_backward -from msprobe.core.common.file_check import FileOpen, FileChecker, \ +from 
msprobe.core.common.file_utils import FileOpen, FileChecker, \ change_mode, check_file_suffix, check_link, check_path_before_create, create_directory from msprobe.pytorch.common.log import logger from msprobe.pytorch.pt_config import parse_json_config diff --git a/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py index ccad903724c42bbc48819c1f64634ec57d4244d2..89edd834cf67c8006f577f1d33fee32b3b6dc751 100644 --- a/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py +++ b/debug/accuracy_tools/msprobe/pytorch/common/parse_json.py @@ -1,7 +1,7 @@ import json from msprobe.core.common.exceptions import ParseJsonException -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen def parse_json_info_forward_backward(json_path): diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py index 27d555f5d2718eb31fdb1ce54cc6c9ef56a08cb6..16533accb26680e341a71e8f40f30e2ed899301f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/acc_compare.py @@ -35,7 +35,7 @@ from msprobe.pytorch.advisor.advisor import Advisor from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ format_value, check_file_not_exists, check_configuration_param, task_dumppath_get -from msprobe.core.common.file_check import FileChecker, change_mode, FileOpen, create_directory +from msprobe.core.common.file_utils import FileChecker, change_mode, FileOpen, create_directory from msprobe.core.common.const import Const, CompareConst, FileCheckConst from msprobe.core.common.exceptions import FileCheckException diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index 
caac139580751a9b9a36d0f73fbf163263d85a51..b2b3a6672c49f38ed8c27ef42f13b52bbdb051f5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -20,7 +20,7 @@ import re from msprobe.core.common.utils import CompareException, check_compare_param, \ check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid from msprobe.pytorch.compare.acc_compare import compare_core -from msprobe.core.common.file_check import create_directory +from msprobe.core.common.file_utils import create_directory from msprobe.core.common.exceptions import FileCheckException from msprobe.pytorch.common.log import logger diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/match.py b/debug/accuracy_tools/msprobe/pytorch/compare/match.py index 6347d8887c85427fcb556eecb5cd4a7302166969..ca335f7d8ec9bda9b594ff9cd2c9d5ed1f7a053f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/match.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/match.py @@ -1,6 +1,6 @@ import os import yaml -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.utils import CompareException diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py index c1e581675fa0995549fcfcd5521cf9759180c3d9..d991445db091f74fbe9f5f7d5f68dd4bdcf4b868 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/utils.py @@ -18,7 +18,7 @@ import os import yaml -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py 
b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py index a02abbe5f4b7e551faf2c4ff465271ae9bebffde..a99d669beb227c0576d95991d5e89811d229c4c3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py @@ -23,7 +23,7 @@ import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.pytorch.function_factory import npu_custom_grad_functions cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py index 6cf425441cc381652ddca4b203ac7a2b4161a116..54afecb9a663db8aa2db998247fca707a8aa2ced 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py @@ -23,7 +23,7 @@ import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py index fd7610ca8fc8089f427a91bed174055882e7207f..96d6986a0784cf98e42fd04d45a52f0be6bd39ed 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_functional.py @@ -24,7 +24,7 @@ from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import 
torch_device_guard from msprobe.core.common.const import Const from msprobe.pytorch.common.log import logger -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen def remove_dropout(): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py index 8a67ed94290d9ba02e947f5806daece13f041e9e..607b4ed3cd54a7e745fa1eca715b1ef8ba1e04b2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_npu_custom.py @@ -22,7 +22,7 @@ import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard, torch_without_guard_version from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.pytorch.function_factory import npu_custom_functions cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py index 3e26ae3beda5341df76eb1f3fdea68e43193f983..aba6f86148967f0f5d0af4d9ac2a85150e79ff47 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_tensor.py @@ -23,7 +23,7 @@ import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard, parameter_adapter from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py 
b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py index 486ddda4919b1abefa35aec8ed21659c06c4588d..3f9518b7f1a3b9532de0ad016b21246be5bd883d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_torch.py @@ -23,7 +23,7 @@ import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py index d78beb2a6ad790ab3ad897bf819e74a234520e8c..351820fd6c11f71b7465e1e5e456b582161ff462 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_vf.py @@ -21,7 +21,7 @@ import torch import yaml from msprobe.pytorch.hook_module.hook_module import HOOKModule -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.pytorch.common.utils import torch_device_guard from msprobe.core.common.const import Const diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py index 048ab3f901c49870c706958da5cdd5d549c475cf..4e3d574cd84118dda2e32667be09d8695312a584 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/compare.py @@ -6,10 +6,9 @@ import json from collections import namedtuple from rich.table import Table from rich.console import Console -from .single_compare import single_benchmark_compare_wrap -from .utils import DispatchException 
+from msprobe.pytorch.online_dispatch.single_compare import single_benchmark_compare_wrap from msprobe.core.common.const import CompareConst -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.pytorch.common.log import logger from msprobe.core.common.utils import CompareException diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py index 898df30b99d0fa5ebffb46e05ff7247d19d1f859..2251fa6cb6aabb0054598cbeb0354d3ac9552c7a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dispatch.py @@ -16,12 +16,12 @@ except ImportError: else: is_npu = True -from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ +from msprobe.pytorch.online_dispatch.dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ DispatchRunParam, DisPatchDataInfo -from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \ +from msprobe.pytorch.online_dispatch.utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \ DispatchException -from .compare import Comparator -from msprobe.core.common.file_check import FileOpen +from msprobe.pytorch.online_dispatch.compare import Comparator +from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create from msprobe.core.common.const import Const, CompareConst diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py index f83b6fc9f00b7b1fa9ac4baa89632c9d43a04e4c..5e8bf4f1117b0a86f67902df723ca08ad19db73e 100644 --- 
a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/dump_compare.py @@ -5,9 +5,9 @@ from datetime import datetime, timezone import pandas as pd import torch -from .utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ +from msprobe.pytorch.online_dispatch.utils import np_save_data, logger_debug, logger_error, logger_warn, logger_user, COLOR_RED, COLOR_GREEN, \ COLOR_RESET, CSV_COLUMN_NAME -from msprobe.core.common.file_check import FileOpen, change_mode +from msprobe.core.common.file_utils import FileOpen, change_mode from msprobe.core.common.const import CompareConst, FileCheckConst, Const from msprobe.pytorch.common.log import logger diff --git a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py index fec3e0b00746c653089d763c052d2b1c350a6886..c1d1e841a40137508c1fd09d617d57c83e9306a9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/online_dispatch/utils.py @@ -13,7 +13,7 @@ else: pta_cpu_device = torch.device("cpu") from msprobe.core.common.const import CompareConst, FileCheckConst -from msprobe.core.common.file_check import change_mode +from msprobe.core.common.file_utils import change_mode cpu_device = torch._C.device("cpu") COLOR_RED = '\033[31m' diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py index 17a01f20fb0630a3f260e065631c7056ecf436a3..a8abec2d15634d08429e51899a83d4938c4a8869 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/utils.py @@ -28,10 +28,10 @@ from collections import namedtuple from msprobe.pytorch.parse_tool.lib.config import Const from msprobe.pytorch.parse_tool.lib.file_desc import DumpDecodeFileDesc, 
FileDesc from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException -from msprobe.core.common.file_check import change_mode, check_other_user_writable,\ +from msprobe.core.common.file_utils import change_mode, check_other_user_writable,\ check_path_executable, check_path_owner_consistent from msprobe.core.common.const import FileCheckConst -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.utils import check_file_or_directory_path, check_path_before_create from msprobe.pytorch.common.log import logger diff --git a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py index 5e37b58d0b9fae8ad4e69ec58a4498dae9bd33b3..a10c7a447fd1cc5b18202bfb382190ed7a663ccc 100644 --- a/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py +++ b/debug/accuracy_tools/msprobe/pytorch/parse_tool/lib/visualization.py @@ -20,7 +20,7 @@ import numpy as np from msprobe.pytorch.parse_tool.lib.config import Const from msprobe.pytorch.parse_tool.lib.utils import Util from msprobe.pytorch.parse_tool.lib.parse_exception import ParseException -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen class Visualization: diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index ceec92a633af86fec8ddc200378e5ef42dfd4600..92145ee2c70fe8667726977236afecb4c14c4d3b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -2,7 +2,7 @@ import json import os from msprobe.core.common_config import CommonConfig, BaseConfig -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import Const from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, 
WrapTensorOps, WrapTorchOps diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index aa22cd0b3e549b4ebc7729c5fd691fa118ae98e5..3238d11b2b96689f0185d71d6c8f1c6a04e3f270 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -5,7 +5,7 @@ import torch from packaging import version from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException -from msprobe.core.common.file_check import FileChecker, check_path_before_create +from msprobe.core.common.file_utils import FileChecker, check_path_before_create from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs from msprobe.core.data_dump.scope import BaseScope diff --git a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_acc_cmp_mstt.py b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_acc_cmp_mstt.py deleted file mode 100644 index 564d13d75b8b17b3df0d6cd7a8a0907a2330a033..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_acc_cmp_mstt.py +++ /dev/null @@ -1,198 +0,0 @@ -import os -import re -import datetime -import pandas as pd - -from msit_llm.common.log import logger -from msit_llm.compare.torchair_acc_cmp import set_msaccucmp_path_from_cann, parse_pbtxt_to_dict, \ - compare_ge_with_fx, compare_ge_with_ge, GE_GRAPH_FILE_PREFIX, MAX_TOKEN_LEN, DUMP_FILE_FILTER_SUFIX -from msit_llm.common.constant import CSV_GOLDEN_HEADER, MY_DTYPE, GOLDEN_DTYPE - - -def save_compare_reault_to_csv(gathered_row_data, output_path=".", columns=CSV_GOLDEN_HEADER): - if not os.path.exists(output_path): - os.makedirs(output_path) - - local_timezone = datetime.datetime.now().astimezone().tzinfo - cur_time = 
datetime.datetime.now().astimezone(local_timezone).strftime('%Y%m%d%H%M%S') - csv_save_path = os.path.join(output_path, f"mstt_cmp_report_{cur_time}.csv") - - # 过滤不宜展示的数据,int8建议只与int8比较 - for row_data in gathered_row_data: - if GOLDEN_DTYPE in row_data and MY_DTYPE in row_data: - if (row_data[GOLDEN_DTYPE] == 'torch.int8') ^ (row_data[MY_DTYPE] == 'torch.int8'): - gathered_row_data.remove(row_data) - - data_frame = pd.DataFrame(gathered_row_data, columns=columns) - data_frame.fillna(value="", inplace=True) - data_frame.dropna(axis=0, how="all", inplace=True) - data_frame.to_csv(csv_save_path, index=False) - logger.info(f"Saved comparing results: {csv_save_path}") - - -def get_torchair_ge_graph_path(my_path): - if not os.path.isdir(my_path): - return [] - - ge_graph_files = [] - my_path_depth = len(my_path.split(os.sep)) - timestamp_pattern = re.compile(r'(\d+)') - for cur_path, _, file_names in os.walk(my_path): - for file_name in file_names: - if file_name.startswith(GE_GRAPH_FILE_PREFIX) and file_name.endswith(".txt"): - match = timestamp_pattern.search(file_name) - if match: - full_path = os.path.join(cur_path, file_name) - timestamp = int(match.group(1)) - ge_graph_files.append((full_path, timestamp)) - - cur_depth = len(cur_path.split(os.sep)) - my_path_depth - if cur_depth > 5: # Avoid going too deep - break - - if ge_graph_files: - sorted_ge_graph_files = [file for file, timestamp in sorted(ge_graph_files, key=lambda x: x[1])] - return sorted_ge_graph_files - else: - return [] - - -def gather_data_with_token_id(data_path, fx=False): - token_dirs, cur_token_id = [], 0 - # Detect the deepest dir level where sub dirs are all digits, and regard as tokens level. 
- if fx: - for cur_path, dirs, _ in os.walk(data_path): - if len(dirs) == 0: - continue - if all([len(ii) < MAX_TOKEN_LEN and str.isdigit(ii) for ii in dirs]): - dirs = sorted(dirs, key=lambda xx: int(xx)) - token_dirs = [os.path.join(cur_path, dir_name) for dir_name in dirs] # Search till deepest level - else: - token_dirs = [] - for cur_path, dirs, _ in sorted(os.walk(data_path), key=lambda x: x[0]): - if not dirs: - token_dirs.append(cur_path) - - if len(token_dirs) == 0: - token_dirs.append(data_path) # Just use data_path if found no token like dirs - - gathered_files_list = [] - for token_dir in token_dirs: - gathered_files = {} - cur_basename = os.path.basename(token_dir) - cur_token_id = int(cur_basename) if str.isdigit(cur_basename) else 0 - for cur_path, _, file_names in os.walk(token_dir): - if gathered_files: - gathered_files = {} - file_names = [os.path.join(cur_path, file_name) for file_name in file_names] - gathered_files.setdefault(cur_token_id, []).extend(file_names) - if gathered_files.get(cur_token_id, None): - gathered_files_list.append(gathered_files) - return gathered_files_list - - -def init_ge_dump_data_from_bin_path(ge_dump_path): - gathered_files_list = gather_data_with_token_id(ge_dump_path) - if not gathered_files_list: - logger.error("can not get ge dump data") - raise Exception - - dump_data_with_token_id_list = [] - for gathered_files in gathered_files_list: - dump_data_with_token_id = {} - for token_id, file_list in gathered_files.items(): - cur_dump_data = {} - for file_name in sorted(file_list): - if os.path.splitext(file_name)[-1] in DUMP_FILE_FILTER_SUFIX: - continue - split_name = os.path.basename(file_name).split(".") - if len(split_name) < 5: - logger.warning(f"invalid file name: {file_name}, should contain at least 4 '.'") - continue - - cur_op_name = ".".join(split_name[1:-3]) - if cur_op_name in cur_dump_data: - exists_file = cur_dump_data[cur_op_name] - exists_file_size = os.path.getsize(exists_file) - cur_file_size = 
os.path.getsize(file_name) - keep_one = file_name if cur_file_size > exists_file_size else exists_file - cur_dump_data[cur_op_name] = keep_one - logger.warning(f"duplicated op name: {cur_op_name}." - f" [{os.path.basename(file_name)}, {os.path.basename(exists_file)}]." - f" Will keep the larger one {os.path.basename(keep_one)}." - ) - else: - cur_dump_data[cur_op_name] = file_name - dump_data_with_token_id[token_id] = cur_dump_data - dump_data_with_token_id_list.append(dump_data_with_token_id) - return dump_data_with_token_id_list - - -def init_fx_dump_data_from_path(fx_dump_path): - gathered_files_list = gather_data_with_token_id(fx_dump_path, fx=True) - if not gathered_files_list: - logger.error("can not get fx dump data") - raise Exception - - dump_data_with_token_id_list = [] - for gathered_files in gathered_files_list: - dump_data_with_token_id = {} - for token_id, file_list in gathered_files.items(): - cur_dump_data = {} - for file_path in sorted(file_list): - if not file_path.endswith("npy"): - continue - file_name = os.path.basename(file_path) - split_name = file_name.split(".") - is_input = ".INPUT." in file_name - cur_op_name = file_name.split('.INPUT.' 
if is_input else ".OUTPUT.")[0] - cur_op_map = cur_dump_data.get(cur_op_name, {}) - cur_op_map.setdefault("input" if is_input else "output", []).append(file_path) - cur_dump_data[cur_op_name] = cur_op_map - if len(cur_dump_data) > 0: - dump_data_with_token_id[token_id - 1] = cur_dump_data # For FX data, token starts from 1, while GE is 0 - dump_data_with_token_id_list.append(dump_data_with_token_id) - return dump_data_with_token_id_list - - -def acc_compare(golden_path, my_path, output_path=".", ge_graph_path=None): - logger.info(f"[compare_torchair], golden_path: {golden_path}, my_path: {my_path}, ge_graph_path: {ge_graph_path}") - set_msaccucmp_path_from_cann() - - if ge_graph_path is None: - ge_graph_path = get_torchair_ge_graph_path(my_path) - graph_map_list = [] - for path in ge_graph_path: - graph_map_list.append(parse_pbtxt_to_dict(path)) # 解析 GE 图,并将其转化为字典 graph_map - - my_dump_data_list = init_ge_dump_data_from_bin_path(my_path) # 从ge路径中提取ge数据 - - is_golden_fx = get_torchair_ge_graph_path(golden_path) is None # 判断是fx图还是ge图,fx图返回None,ge图不是None - if is_golden_fx: - logger.info("Comparing GE with FX") - golden_dump_data_list = init_fx_dump_data_from_path(golden_path) # 从fx路径中提取ge数据 - else: - logger.info("Comparing GE with GE") - golden_dump_data_list = init_ge_dump_data_from_bin_path(golden_path) # 从ge路径中提取ge数据 - - logger.info(f"All token ids in my_dump_data: {my_dump_data_list[0].keys()}") - logger.info(f"All token ids in golden_dump_data: {my_dump_data_list[0].keys()}") - - graph_map_list_len = len(graph_map_list) - for i in range(graph_map_list_len): - graph_map = graph_map_list[i] - my_dump_data = my_dump_data_list[i] - golden_dump_data = golden_dump_data_list[i] - - gathered_row_data = [] - for token_id in my_dump_data: - if token_id not in golden_dump_data: - logger.warning(f"My token_id {token_id} not found in golden dump data") - continue - logger.info(f"Comparing token_id: {token_id}") - if is_golden_fx: - row_data = compare_ge_with_fx(graph_map, 
my_dump_data[token_id], golden_dump_data[token_id], token_id) - else: - row_data = compare_ge_with_ge(graph_map, my_dump_data[token_id], golden_dump_data[token_id], token_id) - gathered_row_data.extend(row_data) - save_compare_reault_to_csv(gathered_row_data, output_path) diff --git a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py index ecbfc93db43849aeaaff4c3b2c2e9b9f249a66f0..c46e3f04456db0f63eb7c19ba94a5252f99f26b2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py +++ b/debug/accuracy_tools/msprobe/pytorch/torchair_compare/torchair_compare_cli.py @@ -1,4 +1,4 @@ -from msprobe.pytorch.torchair_compare.torchair_acc_cmp_mstt import acc_compare +from msit_llm.compare.torchair_acc_cmp import acc_compare def torchair_compare_parser(parser): diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py index 7d5c156030df580cb1beb10b21bb03c593a9348e..61f845a2811fd39854f72146b8587f43dda423f6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import re from msprobe.pytorch.visualization.graph.graph import Graph, BaseNode from msprobe.pytorch.visualization.graph.node_op import NodeOp from msprobe.pytorch.visualization.utils import load_json_file, load_data_json_file, save_json_file, GraphConst @@ -56,10 +57,27 @@ class GraphBuilder: if tool_tip: result[GraphConst.JSON_TIP_KEY] = tool_tip save_json_file(filename, result) - + + @staticmethod + def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id): + """ + 如果backward节点的父级节点是null,则尝试从同名的forward节点寻找父级节点 + """ + # 匹配以.backward.后跟一个或多个数字结尾的模式 + backward_pattern = r"(\.backward\.)(\d+)$" + forward_pattern = r"(\.forward\.)(\d+)$" + if re.search(backward_pattern, subnode_id) and not upnode_id: + forward_upnode_id = construct_dict.get(re.sub(backward_pattern, r".forward.\2", subnode_id)) + if forward_upnode_id: + new_upnode_id = re.sub(forward_pattern, r".backward.\2", forward_upnode_id) + if new_upnode_id in construct_dict: + return new_upnode_id + return upnode_id + @staticmethod def _init_nodes(graph, construct_dict, data_dict): for subnode_id, upnode_id in construct_dict.items(): + upnode_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id) if upnode_id: upnode_op = NodeOp.get_node_op(upnode_id) upnode = GraphBuilder._create_or_get_node(graph, data_dict, upnode_op, upnode_id) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py index b82412c7dc3d3414133eafee21640268da90aeea..92af6d67325c2b2ff2b4dbd34bf799b3219445de 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py @@ -97,6 +97,23 @@ def compare_data(data_dict_list1, data_dict_list2): return True +def compare_mapping_data(data_dict_list1, data_dict_list2): + """ + 
node1映射node2,可能node1参数多于或少于node2参数,个别参数的shape的维度顺序不同,node1参数null对应node2参数其他值 + 工具要尽可能保证node的数据能够比对,进行数据的弱校验,仅校验参数的shape维度数值是否相同 + """ + for x, y in zip(data_dict_list1.values(), data_dict_list2.values()): + x_shape = x.get('shape') + y_shape = y.get('shape') + if x_shape is None or y_shape is None: + continue + x_shape = sorted(x_shape) if isinstance(x_shape, list) else x_shape + y_shape = sorted(y_shape) if isinstance(y_shape, list) else y_shape + if x_shape != y_shape: + return False + return True + + def format_node_data(data_dict): """ 批量进行节点数据的输出 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index 86ba1b1f869983daaccbc7642dfa560d9d92f75c..4bbf2909ede80e0b329bf0e76cb95227adf96f39 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -20,10 +20,11 @@ from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter class GraphComparator: - def __init__(self, graphs, data_paths, stack_path, output_path): + def __init__(self, graphs, data_paths, stack_path, output_path, mapping_config=None): self.graph_n = graphs[0] self.graph_b = graphs[1] self._parse_param(data_paths, stack_path, output_path) + self.mapping_config = mapping_config def compare(self): """ @@ -46,6 +47,8 @@ class GraphComparator: compare_out_dict = {} # input和output对比数据分开 for item in compare_result_list: + if not node.stack_info and node.id in item[0]: + node.stack_info = item[-1] if 'output' in item[0]: compare_out_dict[item[0]] = item else: @@ -98,12 +101,23 @@ class GraphComparator: node.data[GraphConst.JSON_INDEX_KEY] = precision_index def _compare_nodes(self, node_n): - #递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 - #这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 - node_b, ancestors = Graph.match(self.graph_n, 
node_n, self.graph_b) + """ + 递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 + 这里采用先序遍历,好处在于当这个节点被比较时,他的先序已经被匹配,这可以为后续的模糊匹配提供重要信息 + """ + if self.mapping_config: + node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_config) + if node_b: + ancestors_n.append(node_n.id) + ancestors_b.append(node_b.id) + node_n.matched_node_link = ancestors_b + node_b.matched_node_link = ancestors_n + else: + node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b) + if node_b: + ancestors.append(node_b.id) + node_n.add_link(node_b, ancestors) if node_b: - ancestors.append(node_b.id) - node_n.add_link(node_b, ancestors) # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 compare_result_list = compare_node([node_n.id, node_b.id], [self.data_n_dict, self.data_b_dict], diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py index e8c86e243e0ce2f34363f55e39ab344e989e205a..82b3a4bfe56b384868ec0404ec887e9bf9003421 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py @@ -15,7 +15,7 @@ from msprobe.pytorch.visualization.graph.node_op import NodeOp from msprobe.pytorch.visualization.utils import Suggestions, GraphConst -from msprobe.pytorch.visualization.builder.msprobe_adapter import format_node_data, compare_data +from msprobe.pytorch.visualization.builder.msprobe_adapter import format_node_data, compare_data, compare_mapping_data class BaseNode: @@ -30,6 +30,7 @@ class BaseNode: self.subnodes = [] self.matched_node_link = [] self.suggestions = {} + self.stack_info = [] def __str__(self): info = f'id:\t{self.id}' @@ -45,6 +46,13 @@ class BaseNode: return False return True + def compare_mapping_node(self, other): + if not compare_mapping_data(self.input_data, other.input_data): + return False + if not compare_mapping_data(self.output_data, 
other.output_data): + return False + return True + def get_suggestions(self): """ 精度疑似有问题时,提供一些建议 @@ -93,6 +101,7 @@ class BaseNode: result['subnodes'] = [node.id for node in self.subnodes] result['matched_node_link'] = self.matched_node_link result['suggestions'] = self.suggestions + result['stack_info'] = self.stack_info return result def get_ancestors(self): diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py index a6ccd308873f3764cd169daf83d2989ca0befebe..17a8d24ac642f4ac72f8f8769cbf41d38c26401d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py @@ -24,12 +24,12 @@ class Graph: self.node_id_map = {} self.add_node(NodeOp.module, model_name) self.root = self.get_node(model_name) - + def __str__(self): infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map] info = "\n".join(infos) return info - + @staticmethod def match(graph_n, node_n, graph_b): """ @@ -47,14 +47,26 @@ class Graph: if ancestors_n != ancestors_b: return None, [] return node_b, ancestors_n - + + @staticmethod + def mapping_match(node_n, graph_b, mapping_config): + """ + 根据映射配置对节点进行匹配 + """ + node_b = graph_b.node_map.get(mapping_config.get_mapping_string(node_n.id)) + if not node_b or not node_n.compare_mapping_node(node_b): + return None, [], [] + ancestors_n = node_n.get_ancestors() + ancestors_b = node_b.get_ancestors() + return node_b, ancestors_n, ancestors_b + @staticmethod def dfs(node, result): info = node.to_dict() result[node.id] = info for subnode in node.subnodes: Graph.dfs(subnode, result) - + def add_node(self, node_op, node_id, up_node=None, id_accumulation=False): """ 在graph中进行节点的添加 @@ -78,13 +90,13 @@ class Graph: node = BaseNode(node_op, node_id, up_node) self.node_map[node_id] = node return node_id - + def get_node(self, node_id): """ 返回节点,不存在返回None """ return 
self.node_map.get(node_id, None) - + def to_dict(self): """ 用于数据输出 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py index 0e4fdab7cf832725ae53978c0049e52a6e08ee13..c2b77ac44e2a50b32c601df29e83d048ea345b3f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py @@ -19,11 +19,12 @@ from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparat from msprobe.pytorch.visualization.utils import GraphConst from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder from msprobe.core.common.log import logger +from msprobe.pytorch.visualization.mapping_config import MappingConfig current_time = time.strftime("%Y%m%d%H%M%S") -def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model'): +def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model', mapping_file=None): logger.info('Start building model graphs...') # 对两个数据进行构图 construct_path_n = os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE) @@ -35,7 +36,8 @@ def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model'): logger.info('Model graphs built successfully, start Comparing graphs...') # 基于graph、stack和data进行比较 stack_path = os.path.join(dump_path_n, GraphConst.STACK_FILE) - graph_comparator = GraphComparator([graph_n, graph_b], [data_path_n, data_path_b], stack_path, out_path) + graph_comparator = GraphComparator([graph_n, graph_b], [data_path_n, data_path_b], stack_path, out_path, + mapping_config=MappingConfig(mapping_file) if mapping_file else None) graph_comparator.compare() output_path = os.path.join(out_path, f'compare_{current_time}.vis') GraphBuilder.to_json(output_path, graph_n, graph_b, graph_comparator.ma.get_tool_tip()) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py 
b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d986493513a345483bd6c32f808cca002b3cba5b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py @@ -0,0 +1,77 @@ +import re +import yaml +from msprobe.core.common.file_utils import FileOpen +from msprobe.core.common.const import Const +from msprobe.pytorch.visualization.utils import GraphConst + + +class MappingConfig: + MAX_STRING_LEN = 10000 + + def __init__(self, yaml_file): + with FileOpen(yaml_file, 'r') as file: + config = yaml.safe_load(file) + try: + self.config = {key: self.validate(key, value) for data in config for key, value in data.items()} + except Exception as e: + raise RuntimeError("Line of yaml contains content that is not '- key: value'.") from e + self.classify_config = self._classify_and_sort_keys() + + @staticmethod + def validate(key, value): + if not isinstance(key, str): + raise ValueError(f"{key} must be a string.") + if not isinstance(value, str): + raise ValueError(f"{value} must be a string.") + return value + + @staticmethod + def convert_to_regex(s): + """ + 字符串转换为正则表达式, {}替换为d+以匹配一个或多个数字, 开始和结束添加.*以匹配任意前缀和后缀 + Args: + s: 字符串 + Returns: 正则表达式 + """ + escaped_pattern = re.escape(s) + pattern = re.sub(r'\\\{\\\}', r'\\d+', escaped_pattern) + pattern = f'.*{pattern}.*' + return pattern + + @staticmethod + def _replace_parts(origin_string, mapping_key, mapping_value): + if GraphConst.BRACE in mapping_key: + parts = mapping_key.split(GraphConst.BRACE) + m_parts = mapping_value.split(GraphConst.BRACE) + return origin_string.replace(parts[0], m_parts[0]).replace(parts[1], m_parts[1]) + else: + return origin_string.replace(mapping_key, mapping_value) + + def get_mapping_string(self, origin_string: str): + if len(origin_string) > MappingConfig.MAX_STRING_LEN: + return origin_string + for category, items in self.classify_config.items(): + if category in origin_string: + for 
key, value in items: + if re.match(MappingConfig.convert_to_regex(key), origin_string): + return MappingConfig._replace_parts(origin_string, key, value) + return origin_string + + def _classify_and_sort_keys(self): + categorized_dict = {} + for key, value in self.config.items(): + parts = key.split(Const.SEP) + # 获取第一个部分作为新的分类key + category_key = parts[0] + + if category_key not in categorized_dict: + categorized_dict[category_key] = [] + + # 将原始的key-value对添加到对应的分类中 + categorized_dict[category_key].append((key, value)) + + # 对每个分类中的项按key中的.数量进行排序, .数量越多排越靠前, 优先匹配 + for category in categorized_dict: + categorized_dict[category].sort(key=lambda x: -x[0].count(Const.SEP)) + + return categorized_dict diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py index 4699215866784ab01d61aae0e35e8650236979ef..4357961ee8891225128454a8a8abed05aaf77e34 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py @@ -14,7 +14,7 @@ # limitations under the License. 
import json -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.core.common.const import CompareConst from msprobe.pytorch.compare.acc_compare import result_to_csv @@ -131,3 +131,4 @@ class GraphConst: NULL = 'null' NONE = 'None' VALUE = 'value' + BRACE = '{}' diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index edd3eb53dccf453f1d3efde7189dfadcd6dee000..27ffe411586707e384a404edfa39b6ac5a013464 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -41,7 +41,7 @@ from msprobe.core.common.utils import (CompareException, check_regex_prefix_format_valid, get_dump_data_path, task_dumppath_get) -from msprobe.core.common.file_check import FileCheckConst +from msprobe.core.common.file_utils import FileCheckConst class TestUtils(TestCase): diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py index cfb1b3d551aa225d165b6620b2bb3de906ce70ea..4161377a6067c53fa4735a13641c5609bd0d06c9 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_json_writer.py @@ -3,7 +3,7 @@ from msprobe.core.data_dump.json_writer import DataWriter import os import csv -from msprobe.core.common.file_check import FileOpen +from msprobe.core.common.file_utils import FileOpen from msprobe.core.common import utils from pathlib import Path import json diff --git a/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py b/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py index ecdf3da9fedfcc283e3a7887c74b4e46c3d0aae5..c3f7836bf177c283b246e061d2df61f682175c38 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py +++ 
b/debug/accuracy_tools/msprobe/test/core_ut/test_file_check.py @@ -22,7 +22,7 @@ from unittest.mock import patch, MagicMock from msprobe.core.common.log import logger from msprobe.core.common.const import FileCheckConst from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.file_check import (check_link, +from msprobe.core.common.file_utils import (check_link, check_path_length, check_path_exists, check_path_readability, @@ -40,7 +40,7 @@ from msprobe.core.common.file_check import (check_link, class TestFileCheckUtil(TestCase): @patch.object(logger, "error") def test_check_link(self, mock_logger_error): - with patch("msprobe.core.common.file_check.os.path.islink", return_value=True): + with patch("msprobe.core.common.file_utils.os.path.islink", return_value=True): with self.assertRaises(FileCheckException) as context: check_link("link_path") self.assertEqual(str(context.exception), @@ -72,7 +72,7 @@ class TestFileCheckUtil(TestCase): @patch.object(logger, "error") def test_check_path_exists(self, mock_logger_error): - with patch("msprobe.core.common.file_check.os.path.exists", return_value=False): + with patch("msprobe.core.common.file_utils.os.path.exists", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_exists("file_path") self.assertEqual(str(context.exception), @@ -82,7 +82,7 @@ class TestFileCheckUtil(TestCase): @patch.object(logger, "error") def test_check_path_readability(self, mock_logger_error): path = "file_path" - with patch("msprobe.core.common.file_check.os.access", return_value=False): + with patch("msprobe.core.common.file_utils.os.access", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_readability(path) self.assertEqual(str(context.exception), @@ -91,14 +91,14 @@ class TestFileCheckUtil(TestCase): mock_access = MagicMock() mock_access.return_value = True - with patch("msprobe.core.common.file_check.os.access", new=mock_access): + with 
patch("msprobe.core.common.file_utils.os.access", new=mock_access): check_path_readability(path) self.assertEqual(mock_access.call_args[0], (path, os.R_OK)) @patch.object(logger, "error") def test_check_path_writability(self, mock_logger_error): path = "file_path" - with patch("msprobe.core.common.file_check.os.access", return_value=False): + with patch("msprobe.core.common.file_utils.os.access", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_writability(path) self.assertEqual(str(context.exception), @@ -107,14 +107,14 @@ class TestFileCheckUtil(TestCase): mock_access = MagicMock() mock_access.return_value = True - with patch("msprobe.core.common.file_check.os.access", new=mock_access): + with patch("msprobe.core.common.file_utils.os.access", new=mock_access): check_path_writability(path) self.assertEqual(mock_access.call_args[0], (path, os.W_OK)) @patch.object(logger, "error") def test_check_path_executable(self, mock_logger_error): path = "file_path" - with patch("msprobe.core.common.file_check.os.access", return_value=False): + with patch("msprobe.core.common.file_utils.os.access", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_executable(path) self.assertEqual(str(context.exception), @@ -123,7 +123,7 @@ class TestFileCheckUtil(TestCase): mock_access = MagicMock() mock_access.return_value = True - with patch("msprobe.core.common.file_check.os.access", new=mock_access): + with patch("msprobe.core.common.file_utils.os.access", new=mock_access): check_path_executable(path) self.assertEqual(mock_access.call_args[0], (path, os.X_OK)) @@ -135,7 +135,7 @@ class TestFileCheckUtil(TestCase): path = "file_path" mock_stat = TestStat(0o002) - with patch("msprobe.core.common.file_check.os.stat", return_value=mock_stat): + with patch("msprobe.core.common.file_utils.os.stat", return_value=mock_stat): with self.assertRaises(FileCheckException) as context: check_other_user_writable(path) 
self.assertEqual(str(context.exception), @@ -147,7 +147,7 @@ class TestFileCheckUtil(TestCase): def test_check_path_owner_consistent(self, mock_logger_error): file_path = os.path.realpath(__file__) file_owner = os.stat(file_path).st_uid - with patch("msprobe.core.common.file_check.os.getuid", return_value=file_owner+1): + with patch("msprobe.core.common.file_utils.os.getuid", return_value=file_owner+1): with self.assertRaises(FileCheckException) as context: check_path_owner_consistent(file_path) self.assertEqual(str(context.exception), @@ -160,7 +160,7 @@ class TestFileCheckUtil(TestCase): path = "path" mock_re_match = MagicMock() mock_re_match.return_value = False - with patch("msprobe.core.common.file_check.re.match", new=mock_re_match): + with patch("msprobe.core.common.file_utils.re.match", new=mock_re_match): with self.assertRaises(FileCheckException) as context: check_path_pattern_vaild(path) self.assertEqual(str(context.exception), @@ -181,8 +181,8 @@ class TestFileCheckUtil(TestCase): def test_check_common_file_size(self): mock_check_file_size = MagicMock() - with patch("msprobe.core.common.file_check.os.path.isfile", return_value=True), \ - patch("msprobe.core.common.file_check.check_file_size", new=mock_check_file_size): + with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=True), \ + patch("msprobe.core.common.file_utils.check_file_size", new=mock_check_file_size): for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): check_common_file_size(suffix) mock_check_file_size.assert_called_with(suffix, max_size) @@ -201,16 +201,16 @@ class TestFileCheckUtil(TestCase): def test_check_path_type(self, mock_logger_error): file_path = "file_path" - with patch("msprobe.core.common.file_check.os.path.isfile", return_value=False), \ - patch("msprobe.core.common.file_check.os.path.isdir", return_value=True): + with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=False), \ + 
patch("msprobe.core.common.file_utils.os.path.isdir", return_value=True): with self.assertRaises(FileCheckException) as context: check_path_type(file_path, FileCheckConst.FILE) self.assertEqual(str(context.exception), FileCheckException.err_strs.get(FileCheckException.INVALID_FILE_ERROR)) mock_logger_error.assert_called_with(f"The {file_path} should be a file!") - with patch("msprobe.core.common.file_check.os.path.isfile", return_value=True), \ - patch("msprobe.core.common.file_check.os.path.isdir", return_value=False): + with patch("msprobe.core.common.file_utils.os.path.isfile", return_value=True), \ + patch("msprobe.core.common.file_utils.os.path.isdir", return_value=False): with self.assertRaises(FileCheckException) as context: check_path_type(file_path, FileCheckConst.DIR) self.assertEqual(str(context.exception), diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index a794e61d0f558291fcbe8b85b2b2e38fa7c2c9b5..549c5c02ba738c030b907393d30f153ed9c6b628 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -21,8 +21,9 @@ from setuptools import setup repository_url = "https://gitee.com/ascend/msit.git" target_dir = "./msprobe/msit" +branch_name = "master" -subprocess.check_call(["git", "submodule", "add", repository_url, target_dir]) +subprocess.check_call(["git", "submodule", "add", "-b", branch_name, repository_url, target_dir]) subprocess.check_call(["git", "submodule", "init"]) subprocess.check_call(["git", "submodule", "update"]) @@ -37,7 +38,7 @@ EXCLUDE_PKGS = [ setup( name='mindstudio-probe', - version='1.0.3', + version='1.0.4', description='This is a pytorch precision comparison tools', long_description='This is a pytorch precision comparison tools, include ptdbg and api accuracy checker', packages=setuptools.find_namespace_packages(exclude=EXCLUDE_PKGS, include=["msprobe", "msprobe*"]), diff --git a/profiler/cluster_analyse/analysis/mstx_sum/mstx_sum.py 
b/profiler/cluster_analyse/analysis/mstx_sum/mstx_sum.py index 46a0e18abeee5cdd6b058d71e3a1bd2b97e7c29d..ccf2ac7bbdd60e1ef1325d3cd434e8e7d0caeb1a 100644 --- a/profiler/cluster_analyse/analysis/mstx_sum/mstx_sum.py +++ b/profiler/cluster_analyse/analysis/mstx_sum/mstx_sum.py @@ -19,7 +19,7 @@ from collections import namedtuple from analysis.base_analysis import BaseRecipeAnalysis from common_func.constant import Constant from common_func.utils import describe_duration -from cluster_statistics_export.mstx_mark_export import MstxMarkExport +from cluster_statistics_export.mstx_event_export import MstxMarkExport, MstxRangeExport from cluster_statistics_export.mstx_step_export import MstxStepExport @@ -40,16 +40,28 @@ def format_mark_info(df: pd.DataFrame, start_idx, stop_idx, name) -> MarkInfo: ) -def rename_mark_msg_name(mark_stats_df: pd.DataFrame): +def format_range_info(df: pd.DataFrame, idx, name) -> MarkInfo: + range_series = df.iloc[idx] + return MarkInfo( + name=name, + framework_duration=float(0), + cann_duration=float(range_series["cann_end_ts"]- range_series["cann_start_ts"]), + device_duration=float(range_series["device_end_ts"]- range_series["device_start_ts"]), + tid=range_series["tid"], + start_ns=range_series["cann_start_ts"] + ) + + +def rename_mark_msg_name(mstx_stats_df: pd.DataFrame): msg_idx_counter = {} - for idx, mark_info in enumerate(mark_stats_df.itertuples(index=False)): + for idx, mark_info in enumerate(mstx_stats_df.itertuples(index=False)): msg_idx_counter.setdefault(mark_info.step_id, {}).setdefault(mark_info.name, []).append(idx) for msg_dict in msg_idx_counter.values(): for msg, idx_list in msg_dict.items(): if len(idx_list) <= 1: continue for i, idx in enumerate(idx_list): - mark_stats_df.loc[idx, 'name'] = f"{msg}_{i}" + mstx_stats_df.loc[idx, 'name'] = f"{msg}_{i}" def compute_step_id(mark_stat, step_stats_df: pd.DataFrame): @@ -77,6 +89,45 @@ def format_columns(df: pd.DataFrame): return formatted_df[cols] +def 
handle_mark_data(mark_df: pd.DataFrame, data_map: list) -> list: + res = [] + mark_df["framework_ts"] = mark_df["framework_ts"].astype("int64") + mark_info = {} + mismatch_msg = [] + for idx, row in enumerate(mark_df.itertuples(index=False)): + if row.msg.endswith(MstxSum.START_SUFFIX): + msg = row.msg[:-len(MstxSum.START_SUFFIX)] + mark_info.setdefault(row.tid, {}).setdefault(msg, []).append(idx) + elif row.msg.endswith(MstxSum.STOP_SUFFIX): + msg = row.msg[:-len(MstxSum.STOP_SUFFIX)] + idx_list = mark_info.get(row.tid, {}).get(msg, []) + if not idx_list: + mismatch_msg.append((row.msg, idx)) + continue + start_idx = idx_list.pop() + res.append(format_mark_info(mark_df, start_idx, idx, msg)) + + # 统计未匹配上的mark信息 + for msg_info in mark_info.values(): + for msg, idx_list in msg_info.items(): + if not idx_list: + continue + mismatch_msg.extend((msg + MstxSum.START_SUFFIX, idx) for idx in idx_list) + if mismatch_msg: + mismatch_msg.sort(key=lambda msg: msg[1]) + print(f"[WARNING] The following mark messages do not match anyone in " + f"rank {data_map[0]}: {','.join(msg[0] for msg in mismatch_msg)}.") + + return res + + +def handle_range_data(range_df: pd.DataFrame) -> list: + res = [] + for idx, row in enumerate(range_df.itertuples(index=False)): + res.append(format_range_info(range_df, idx, row.msg)) + return res + + class MstxSum(BaseRecipeAnalysis): TABLE_FRAMEWORK_STATS = "MSTXAllFrameworkStats" @@ -105,43 +156,22 @@ class MstxSum(BaseRecipeAnalysis): if step_df is None or step_df.empty: step_df = pd.DataFrame({"start_ns": [0], "end_ns": [float("inf")], "step_id": [0]}) mark_df = MstxMarkExport(data_map[1], analysis_class).read_export_db() - if mark_df is None or mark_df.empty: - print(f"[WARNING] There is no mark data in {data_map[1]}.") + range_df = MstxRangeExport(data_map[1], analysis_class).read_export_db() + + mstx_res = [] + if not mark_df.empty: + mstx_res += handle_mark_data(mark_df, data_map) + if not range_df.empty: + mstx_res += 
handle_range_data(range_df) + if not mstx_res: + print(f"[WARNING] There is no mstx data in {data_map[1]}.") return None - mark_df["framework_ts"] = mark_df["framework_ts"].astype("int64") - - mark_info = {} - mark_res = [] - mismatch_msg = [] - for idx, row in enumerate(mark_df.itertuples(index=False)): - if row.msg.endswith(MstxSum.START_SUFFIX): - msg = row.msg[:-len(MstxSum.START_SUFFIX)] - mark_info.setdefault(row.tid, {}).setdefault(msg, []).append(idx) - elif row.msg.endswith(MstxSum.STOP_SUFFIX): - msg = row.msg[:-len(MstxSum.STOP_SUFFIX)] - idx_list = mark_info.get(row.tid, {}).get(msg, []) - if not idx_list: - mismatch_msg.append((row.msg, idx)) - continue - start_idx = idx_list.pop() - mark_res.append(format_mark_info(mark_df, start_idx, idx, msg)) - - # 统计未匹配上的mark信息 - for msg_info in mark_info.values(): - for msg, idx_list in msg_info.items(): - if not idx_list: - continue - mismatch_msg.extend((msg + MstxSum.START_SUFFIX, idx) for idx in idx_list) - if mismatch_msg: - mismatch_msg.sort(key=lambda msg: msg[1]) - print(f"[WARNING] The following mark messages do not match anyone in " - f"rank {data_map[0]}: {','.join(msg[0] for msg in mismatch_msg)}.") - - mark_stats_df = pd.DataFrame(mark_res).assign(Rank=data_map[0]) - mark_stats_df["step_id"] = mark_stats_df.apply(compute_step_id, axis=1, step_stats_df=step_df) - rename_mark_msg_name(mark_stats_df) - mark_stats_df = format_columns(mark_stats_df).set_index("Name", drop=True) - return mark_stats_df + + mstx_stats_df = pd.DataFrame(mstx_res).assign(Rank=data_map[0]) + mstx_stats_df["step_id"] = mstx_stats_df.apply(compute_step_id, axis=1, step_stats_df=step_df) + rename_mark_msg_name(mstx_stats_df) + mstx_stats_df = format_columns(mstx_stats_df).set_index("Name", drop=True) + return mstx_stats_df def mapper_func(self, context): return context.wait( diff --git a/profiler/cluster_analyse/cluster_statistics_export/mstx_mark_export.py b/profiler/cluster_analyse/cluster_statistics_export/mstx_event_export.py 
similarity index 66% rename from profiler/cluster_analyse/cluster_statistics_export/mstx_mark_export.py rename to profiler/cluster_analyse/cluster_statistics_export/mstx_event_export.py index ac5355c020042d474963296242b79eb3fd6a8c38..b70a51fa4bbf88c832d24d26e23ade5cdfad3913 100644 --- a/profiler/cluster_analyse/cluster_statistics_export/mstx_mark_export.py +++ b/profiler/cluster_analyse/cluster_statistics_export/mstx_event_export.py @@ -16,7 +16,7 @@ from cluster_statistics_export.stats_export import StatsExport -QUERY = """ +MARK_QUERY = """ WITH FRAMEWORK_API AS ( SELECT @@ -45,6 +45,8 @@ LEFT JOIN LEFT JOIN STRING_IDS AS MSG_IDS ON MSTX_EVENTS.message == MSG_IDS.id +WHERE + MSTX_EVENTS.eventType == 3 ORDER BY MSTX_EVENTS.startNs """ @@ -54,4 +56,36 @@ class MstxMarkExport(StatsExport): def __init__(self, db_path, recipe_name): super().__init__(db_path, recipe_name) - self._query = QUERY + self._query = MARK_QUERY + + +RANGE_QUERY = ''' +SELECT + MSG_IDS.value AS "msg", + MSTX_EVENTS.startNs AS "cann_start_ts", + MSTX_EVENTS.endNs AS "cann_end_ts", + TASK.startNs AS "device_start_ts", + TASK.endNs AS "device_end_ts", + MSTX_EVENTS.globalTid AS "tid" +FROM + MSTX_EVENTS +LEFT JOIN + TASK + ON MSTX_EVENTS.connectionId == TASK.connectionId +LEFT JOIN + STRING_IDS AS MSG_IDS + ON MSTX_EVENTS.message == MSG_IDS.id +WHERE + MSTX_EVENTS.eventType == 2 +AND + MSTX_EVENTS.connectionId != 4294967295 +ORDER BY + MSTX_EVENTS.startNs + ''' + + +class MstxRangeExport(StatsExport): + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = RANGE_QUERY \ No newline at end of file diff --git a/profiler/osrt_trace/README.md b/profiler/osrt_trace/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a3f3a7d59a82fb38588e10939902197b7d369fdd --- /dev/null +++ b/profiler/osrt_trace/README.md @@ -0,0 +1,157 @@ +# MSOSRT Trace系统库函数耗时检测 + +OSRT(OS runtime libraries 
trace)是根据Linux操作系统运行时库采集用户层库函数API的调用信息。MSOSRT(MindStudio OSRT)则是采集Linux C库函数和POSIX线程(pthread)库中典型的高耗时接口,即可能阻塞用户进程的函数(如read、ioctl,pthread_mutex_lock等),统计其耗时信息,帮助用户分析进程阻塞的原因。 + +## 使用方法 + +1. 约束条件:仅支持Linux系统,拥有g++编译环境和glibc、pthread等标准库。 +2. 将mstt代码仓下载到本地,进入到profiler/osrt_trace目录,执行`bash build.sh`,生成`libmsosrt_trace.so`。 +3. 执行`export LD_PRELOAD=./libmsosrt_trace.so:$LD_PRELOAD`,将`libmsosrt_trace.so`加入到LD_PRELOAD环境变量中。 +4. 设置检测阈值和导出目录的环境变量: + + ```bash + # 检测阈值,正整数,只统计超过阈值的库函数,单位:ns,默认为10000000 + export MSOSRT_TRACE_THRESHOLD=10000000 + # 导出目录,字符串,设置检测结果导出的目录,默认为当前目录 + export MSOSRT_EXPORT_PATH="./osrt_trace_result" + ``` + +5. 执行用户进程,如`python main.py` + +6. 用户进程执行结束后,在MSOSRT_EXPORT_PATH路径下会生成检测结果,生成结果文件:msosrt_trace\_{进程号}\_{进程名}.csv,如`msosrt_trace_2328177_python3.csv`,文件内容包含pid、tid、函数名、开始执行时间和耗时等信息,如下所示: + + | Pid | Tid | Function | StartTime(ns) | Duration(ns) | + | ------: | ------: | ----------------: | ------------------: | -----------: | + | 2328177 | 2328280 | pthread_cond_wait | 1725398310787080000 | 3088062410 | + | 2328177 | 2328282 | pthread_cond_wait | 1725398310787170000 | 3087994240 | + | 2328177 | 2328480 | read | 1725398318916180000 | 100509970 | + | 2328177 | 2328440 | ioctl | 1725398319218640000 | 512040720 | + | 2328177 | 2328177 | free | 1725398330504550000 | 56386880 | + +## 检测接口 + +MSOSRT支持检测如下操作系统库函数: + +- 内存操作 + + ```c + malloc + realloc + free + mmap + munmap + mremap + msync + mprotect + brk + ``` + +- 文件操作 + + ```c + dup + dup2 + dup3 + tee + splice + fallocate + fdatasync + fsync + fcntl + flock + lockf + truncate + ftruncate + ioctl + open + openat + pipe + pipe2 + mkfifo + mkfifoat + read + pread + readv + preadv + preadv2 + write + pwrite + writev + pwritev + pwritev2 + copy_file_range + sync + syncfs + sync_file_range + vmsplice + process_vm_readv + process_vm_writev + fclose + fcloseall + fflush + fgetc + fgets + fputc + fputs + flockfile + ftrylockfile + funlockfile + fopen + freopen + fread + fwrite + getdelim + getline + getc + 
putc + getc_unlocked + putc_unlocked + fflush_unlocked + fgetc_unlocked + fputc_unlocked + fread_unlocked + fwrite_unlocked + fgets_unlocked + fputs_unlocked + ``` + +- 网络操作 + + ```c + socket + socketpair + epoll_ctl + epoll_wait + epoll_pwait + select + listen + accept + accept4 + bind + poll + ppoll + send + sendto + sendmsg + sendmmsg + sendfile + recv + recvfrom + recvmsg + recvmmsg + ``` + +- 线程操作 + + ```c + pthread_mutex_lock + pthread_mutex_timedlock + pthread_cond_signal + pthread_cond_broadcast + pthread_cond_wait + pthread_cond_timedwait + pthread_rwlock_rdlock + pthread_rwlock_timedrdlock + pthread_rwlock_wrlock + pthread_rwlock_timedwrlock + ``` \ No newline at end of file diff --git a/profiler/osrt_trace/build.sh b/profiler/osrt_trace/build.sh new file mode 100644 index 0000000000000000000000000000000000000000..bb153e6247122c922dc5cea247be43bfec3d5430 --- /dev/null +++ b/profiler/osrt_trace/build.sh @@ -0,0 +1 @@ +g++ ./src/*.cpp -std=c++11 -fPIC -fstack-protector-all -fno-strict-aliasing -fno-common -fvisibility=hidden -fvisibility-inlines-hidden -Wfloat-equal -Wextra -O2 -shared -lpthread -ldl -o libmsosrt_trace.so \ No newline at end of file diff --git a/profiler/osrt_trace/src/file_func.cpp b/profiler/osrt_trace/src/file_func.cpp new file mode 100644 index 0000000000000000000000000000000000000000..319dcb227b139adf158d55fe762f97afdfa5fdd8 --- /dev/null +++ b/profiler/osrt_trace/src/file_func.cpp @@ -0,0 +1,664 @@ +#include "file_func.h" + +#include + +#include "msosrt_trace.h" + +void FileFuncProxy::loadFunc() +{ + LOAD_FUNC(dup, DupFunc); + LOAD_FUNC(dup2, Dup2Func); + LOAD_FUNC(dup3, Dup3Func); + LOAD_FUNC(tee, TeeFunc); + LOAD_FUNC(splice, SpliceFunc); + LOAD_FUNC(fallocate, FallocateFunc); + LOAD_FUNC(fdatasync, FdatasyncFunc); + LOAD_FUNC(fsync, FsyncFunc); + LOAD_FUNC(fcntl, FcntlFunc); + LOAD_FUNC(flock, FlockFunc); + LOAD_FUNC(lockf, LockfFunc); + LOAD_FUNC(truncate, TruncateFunc); + LOAD_FUNC(ftruncate, FtruncateFunc); + LOAD_FUNC(ioctl, 
IoctlFunc); + LOAD_FUNC(open, OpenFunc); + LOAD_FUNC(openat, OpenatFunc); + LOAD_FUNC(pipe, PipeFunc); + LOAD_FUNC(pipe2, Pipe2Func); + LOAD_FUNC(mkfifo, MkfifoFunc); + LOAD_FUNC(mkfifoat, MkfifoatFunc); + LOAD_FUNC(read, ReadFunc); + LOAD_FUNC(pread, PreadFunc); + LOAD_FUNC(readv, ReadvFunc); + LOAD_FUNC(preadv, PreadvFunc); + LOAD_FUNC(preadv2, Preadv2Func); + LOAD_FUNC(write, WriteFunc); + LOAD_FUNC(pwrite, PwriteFunc); + LOAD_FUNC(writev, WritevFunc); + LOAD_FUNC(pwritev, PwritevFunc); + LOAD_FUNC(pwritev2, Pwritev2Func); + LOAD_FUNC(copy_file_range, CopyFileRangeFunc); + LOAD_FUNC(sync, SyncFunc); + LOAD_FUNC(syncfs, SyncfsFunc); + LOAD_FUNC(sync_file_range, SyncFileRangeFunc); + LOAD_FUNC(vmsplice, VmspliceFunc); + LOAD_FUNC(process_vm_readv, ProcessVmReadvFunc); + LOAD_FUNC(process_vm_writev, ProcessVmWritevFunc); + LOAD_FUNC(fclose, FcloseFunc); + LOAD_FUNC(fcloseall, FcloseallFunc); + LOAD_FUNC(fflush, FflushFunc); + LOAD_FUNC(fgetc, FgetcFunc); + LOAD_FUNC(fgets, FgetsFunc); + LOAD_FUNC(fputc, FputcFunc); + LOAD_FUNC(fputs, FputsFunc); + LOAD_FUNC(flockfile, FlockfileFunc); + LOAD_FUNC(ftrylockfile, FtrylockfileFunc); + LOAD_FUNC(funlockfile, FunlockfileFunc); + LOAD_FUNC(fopen, FopenFunc); + LOAD_FUNC(freopen, FreopenFunc); + LOAD_FUNC(fread, FreadFunc); + LOAD_FUNC(fwrite, FwriteFunc); + LOAD_FUNC(getdelim, GetdelimFunc); + LOAD_FUNC(getline, GetlineFunc); + LOAD_FUNC(getc, GetcFunc); + LOAD_FUNC(putc, PutcFunc); + LOAD_FUNC(getc_unlocked, GetcUnlockedFunc); + LOAD_FUNC(putc_unlocked, PutcUnlockedFunc); + LOAD_FUNC(fflush_unlocked, FflushUnlockedFunc); + LOAD_FUNC(fgetc_unlocked, FgetcUnlockedFunc); + LOAD_FUNC(fputc_unlocked, FputcUnlockedFunc); + LOAD_FUNC(fread_unlocked, FreadUnlockedFunc); + LOAD_FUNC(fwrite_unlocked, FwriteUnlockedFunc); + LOAD_FUNC(fgets_unlocked, FgetsUnlockedFunc); + LOAD_FUNC(fputs_unlocked, FputsUnlockedFunc); +} + +int dup(int oldfd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = 
global_osrt_func.file_func.real_dup(oldfd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int dup2(int oldfd, int newfd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_dup2(oldfd, newfd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int dup3(int oldfd, int newfd, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_dup3(oldfd, newfd, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t tee(int fd_in, int fd_out, size_t len, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_tee(fd_in, fd_out, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t splice(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_splice(fd_in, off_in, fd_out, off_out, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fallocate(int fd, int mode, off_t offset, off_t len) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fallocate(fd, mode, offset, len); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fdatasync(int fildes) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fdatasync(fildes); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fsync(int fd) +{ + 
global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fsync(fd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fcntl(int fd, int op, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, op); + void* arg = va_arg(args, void*); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fcntl(fd, op, arg); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int flock(int fd, int op) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_flock(fd, op); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int lockf(int fd, int op, off_t len) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_lockf(fd, op, len); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int truncate(const char* path, off_t length) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_truncate(path, length); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int ftruncate(int fildes, off_t length) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_ftruncate(fildes, length); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int ioctl(int fd, int op, ...) 
+{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, op); + void* arg = va_arg(args, void*); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_ioctl(fd, op, arg); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int open(const char* pathname, int flags, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, flags); + mode_t arg = va_arg(args, mode_t); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = arg ? global_osrt_func.file_func.real_open(pathname, flags, arg) : global_osrt_func.file_func.real_open(pathname, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int openat(int dirfd, const char *pathname, int flags, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, flags); + mode_t arg = va_arg(args, mode_t); + va_end(args); + uint64_t start_time = nsec_now(); + auto ret = arg ? 
global_osrt_func.file_func.real_openat(dirfd, pathname, flags, arg) : global_osrt_func.file_func.real_openat(dirfd, pathname, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pipe(int pipefd[2]) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pipe(pipefd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pipe2(int pipefd[2], int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pipe2(pipefd, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int mkfifo(const char* pathname, mode_t mode) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_mkfifo(pathname, mode); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int mkfifoat(int dirfd, const char* pathname, mode_t mode) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_mkfifoat(dirfd, pathname, mode); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t read(int fd, void* buf, size_t count) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_read(fd, buf, count); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pread(int fd, void* buf, size_t count, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pread(fd, buf, count, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t readv(int fd, const struct iovec* 
iov, int iovcnt) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_readv(fd, iov, iovcnt); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t preadv(int fd, const struct iovec* iov, int iovcnt, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_preadv(fd, iov, iovcnt, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t preadv2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_preadv2(fd, iov, iovcnt, offset, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t write(int fd, const void* buf, size_t count) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_write(fd, buf, count); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pwrite(int fd, const void* buf, size_t count, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pwrite(fd, buf, count, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t writev(int fd, const struct iovec* iov, int iovcnt) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_writev(fd, iov, iovcnt); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pwritev(int fd, const struct iovec* iov, int iovcnt, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = 
global_osrt_func.file_func.real_pwritev(fd, iov, iovcnt, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t pwritev2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_pwritev2(fd, iov, iovcnt, offset, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t copy_file_range(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_copy_file_range(fd_in, off_in, fd_out, off_out, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void sync(void) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.file_func.real_sync(); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +int syncfs(int fd) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_syncfs(fd); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int sync_file_range(int fd, off_t offset, off_t nbytes, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_sync_file_range(fd, offset, nbytes, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t vmsplice(int fd, const struct iovec* iov, size_t nr_segs, unsigned int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_vmsplice(fd, iov, nr_segs, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - 
start_time, __FUNCTION__); + return ret; +} + +ssize_t process_vm_readv(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_process_vm_readv(pid, local_iov, liovcnt, remote_iov, riovcnt, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t process_vm_writev(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_process_vm_writev(pid, local_iov, liovcnt, remote_iov, riovcnt, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fclose(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fclose(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fcloseall(void) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fcloseall(); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fflush(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fflush(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fgetc(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fgetc(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +char* 
fgets(char* s, int size, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + char* ret = global_osrt_func.file_func.real_fgets(s, size, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputc(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputc(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputs(const char* s, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputs(s, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void flockfile(FILE* filehandle) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.file_func.real_flockfile(filehandle); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +int ftrylockfile(FILE* filehandle) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_ftrylockfile(filehandle); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void funlockfile(FILE* filehandle) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.file_func.real_funlockfile(filehandle); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +FILE* fopen(const char* pathname, const char* mode) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fopen(pathname, mode); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +FILE* freopen(const char* pathname, const char* mode, FILE* stream) +{ + global_osrt_func.loadFunc(); + 
uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_freopen(pathname, mode, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fread(void* ptr, size_t size, size_t nmemb, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fread(ptr, size, nmemb, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fwrite(const void* ptr, size_t size, size_t nitems, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fwrite(ptr, size, nitems, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t getdelim(char** lineptr, size_t* n, int delimiter, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getdelim(lineptr, n, delimiter, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t getline(char** lineptr, size_t* n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getline(lineptr, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int getc(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getc(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int putc(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_putc(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; 
+} + +int getc_unlocked(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_getc_unlocked(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int putc_unlocked(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_putc_unlocked(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fflush_unlocked(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fflush_unlocked(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fgetc_unlocked(FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fgetc_unlocked(stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputc_unlocked(int c, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputc_unlocked(c, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fread_unlocked(void* ptr, size_t size, size_t n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fread_unlocked(ptr, size, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +size_t fwrite_unlocked(const void* ptr, size_t size, size_t n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fwrite_unlocked(ptr, size, n, stream); + global_osrt_func.recordFunc(start_time, 
nsec_now() - start_time, __FUNCTION__); + return ret; +} + +char* fgets_unlocked(char* s, int n, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + char* ret = global_osrt_func.file_func.real_fgets_unlocked(s, n, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int fputs_unlocked(const char* s, FILE* stream) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.file_func.real_fputs_unlocked(s, stream); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} diff --git a/profiler/osrt_trace/src/file_func.h b/profiler/osrt_trace/src/file_func.h new file mode 100644 index 0000000000000000000000000000000000000000..23c6a25eeeddd734a1ab10ecfcb7d3035d2f9a6a --- /dev/null +++ b/profiler/osrt_trace/src/file_func.h @@ -0,0 +1,144 @@ +#pragma once + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include + +using DupFunc = int(*)(int); +using Dup2Func = int(*)(int, int); +using Dup3Func = int(*)(int, int, int); +using TeeFunc = ssize_t(*)(int, int, size_t, unsigned int); +using SpliceFunc = ssize_t(*)(int, off_t*, int, off_t*, size_t, unsigned int); +using FallocateFunc = int(*)(int, int, off_t, off_t); +using FdatasyncFunc = int(*)(int); +using FsyncFunc = int(*)(int); +using FcntlFunc = int(*)(int, int, ...); +using FlockFunc = int(*)(int, int); +using LockfFunc = int(*)(int, int, off_t); +using TruncateFunc = int(*)(const char*, off_t); +using FtruncateFunc = int(*)(int, off_t); +using IoctlFunc = int(*)(int, int, ...); +using OpenFunc = int(*)(const char*, int, ...); +using OpenatFunc = int(*)(int, const char*, int, ...); +using PipeFunc = int(*)(int*); +using Pipe2Func = int(*)(int*, int); +using MkfifoFunc = int(*)(const char*, mode_t); +using MkfifoatFunc = int(*)(int, const char*, mode_t); +using ReadFunc = ssize_t(*)(int, void*, size_t); +using PreadFunc = 
ssize_t(*)(int, void*, size_t, off_t); +using ReadvFunc = ssize_t(*)(int, const struct iovec*, int); +using PreadvFunc = ssize_t(*)(int, const struct iovec*, int, off_t); +using Preadv2Func = ssize_t(*)(int, const struct iovec*, int, off_t, int); +using WriteFunc = ssize_t(*)(int, const void*, size_t); +using PwriteFunc = ssize_t(*)(int, const void*, size_t, off_t); +using WritevFunc = ssize_t(*)(int, const struct iovec*, int); +using PwritevFunc = ssize_t(*)(int, const struct iovec*, int, off_t); +using Pwritev2Func = ssize_t(*)(int, const struct iovec*, int, off_t, int); +using CopyFileRangeFunc = ssize_t(*)(int, off_t*, int, off_t*, size_t, unsigned int); +using SyncFunc = void(*)(void); +using SyncfsFunc = int(*)(int); +using SyncFileRangeFunc = int(*)(int, off_t, off_t, unsigned int); +using VmspliceFunc = ssize_t(*)(int, const struct iovec*, size_t, unsigned int); +using ProcessVmReadvFunc = ssize_t(*)(pid_t, const struct iovec*, unsigned long, const struct iovec*, unsigned long, unsigned long); +using ProcessVmWritevFunc = ssize_t(*)(pid_t, const struct iovec*, unsigned long, const struct iovec*, unsigned long, unsigned long); +using FcloseFunc = int(*)(FILE*); +using FcloseallFunc = int(*)(void); +using FflushFunc = int(*)(FILE*); +using FgetcFunc = int(*)(FILE*); +using FgetsFunc = char*(*)(char*, int, FILE*); +using FputcFunc = int(*)(int, FILE*); +using FputsFunc = int(*)(const char*, FILE*); +using FlockfileFunc = void(*)(FILE*); +using FtrylockfileFunc = int(*)(FILE*); +using FunlockfileFunc = void(*)(FILE*); +using FopenFunc = FILE*(*)(const char*, const char*); +using FreopenFunc = FILE*(*)(const char*, const char*, FILE*); +using FreadFunc = size_t(*)(void*, size_t, size_t, FILE*); +using FwriteFunc = size_t(*)(const void*, size_t, size_t, FILE*); +using GetdelimFunc = ssize_t(*)(char**, size_t*, int, FILE*); +using GetlineFunc = ssize_t(*)(char**, size_t*, FILE*); +using GetcFunc = int(*)(FILE*); +using PutcFunc = int(*)(int, FILE*); +using 
GetcUnlockedFunc = int(*)(FILE*); +using PutcUnlockedFunc = int(*)(int, FILE*); +using FflushUnlockedFunc = int(*)(FILE*); +using FgetcUnlockedFunc = int(*)(FILE*); +using FputcUnlockedFunc = int(*)(int, FILE*); +using FreadUnlockedFunc = size_t(*)(void*, size_t, size_t, FILE*); +using FwriteUnlockedFunc = size_t(*)(const void*, size_t, size_t, FILE*); +using FgetsUnlockedFunc = char*(*)(char*, int, FILE*); +using FputsUnlockedFunc = int(*)(const char*, FILE*); + +struct FileFuncProxy +{ + DupFunc real_dup = nullptr; + Dup2Func real_dup2 = nullptr; + Dup3Func real_dup3 = nullptr; + TeeFunc real_tee = nullptr; + SpliceFunc real_splice = nullptr; + FallocateFunc real_fallocate = nullptr; + FdatasyncFunc real_fdatasync = nullptr; + FsyncFunc real_fsync = nullptr; + FcntlFunc real_fcntl = nullptr; + FlockFunc real_flock = nullptr; + LockfFunc real_lockf = nullptr; + TruncateFunc real_truncate = nullptr; + FtruncateFunc real_ftruncate = nullptr; + IoctlFunc real_ioctl = nullptr; + OpenFunc real_open = nullptr; + OpenatFunc real_openat = nullptr; + PipeFunc real_pipe = nullptr; + Pipe2Func real_pipe2 = nullptr; + MkfifoFunc real_mkfifo = nullptr; + MkfifoatFunc real_mkfifoat = nullptr; + ReadFunc real_read = nullptr; + PreadFunc real_pread = nullptr; + ReadvFunc real_readv = nullptr; + PreadvFunc real_preadv = nullptr; + Preadv2Func real_preadv2 = nullptr; + WriteFunc real_write = nullptr; + PwriteFunc real_pwrite = nullptr; + WritevFunc real_writev = nullptr; + PwritevFunc real_pwritev = nullptr; + Pwritev2Func real_pwritev2 = nullptr; + CopyFileRangeFunc real_copy_file_range = nullptr; + SyncFunc real_sync = nullptr; + SyncfsFunc real_syncfs = nullptr; + SyncFileRangeFunc real_sync_file_range = nullptr; + VmspliceFunc real_vmsplice = nullptr; + ProcessVmReadvFunc real_process_vm_readv = nullptr; + ProcessVmWritevFunc real_process_vm_writev = nullptr; + FcloseFunc real_fclose = nullptr; + FcloseallFunc real_fcloseall = nullptr; + FflushFunc real_fflush = nullptr; + 
FgetcFunc real_fgetc = nullptr; + FgetsFunc real_fgets = nullptr; + FputcFunc real_fputc = nullptr; + FputsFunc real_fputs = nullptr; + FlockfileFunc real_flockfile = nullptr; + FtrylockfileFunc real_ftrylockfile = nullptr; + FunlockfileFunc real_funlockfile = nullptr; + FopenFunc real_fopen = nullptr; + FreopenFunc real_freopen = nullptr; + FreadFunc real_fread = nullptr; + FwriteFunc real_fwrite = nullptr; + GetdelimFunc real_getdelim = nullptr; + GetlineFunc real_getline = nullptr; + GetcFunc real_getc = nullptr; + PutcFunc real_putc = nullptr; + GetcUnlockedFunc real_getc_unlocked = nullptr; + PutcUnlockedFunc real_putc_unlocked = nullptr; + FflushUnlockedFunc real_fflush_unlocked = nullptr; + FgetcUnlockedFunc real_fgetc_unlocked = nullptr; + FputcUnlockedFunc real_fputc_unlocked = nullptr; + FreadUnlockedFunc real_fread_unlocked = nullptr; + FwriteUnlockedFunc real_fwrite_unlocked = nullptr; + FgetsUnlockedFunc real_fgets_unlocked = nullptr; + FputsUnlockedFunc real_fputs_unlocked = nullptr; + + void loadFunc(); +}; diff --git a/profiler/osrt_trace/src/msosrt_trace.cpp b/profiler/osrt_trace/src/msosrt_trace.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3a88b05480193ce9bee6c26480e214f69e4ddf0 --- /dev/null +++ b/profiler/osrt_trace/src/msosrt_trace.cpp @@ -0,0 +1,476 @@ +#include "msosrt_trace.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#if !defined (__linux__) || !defined(__GLIBC__) +#error "This tool only works on Linux!" 
+#endif + +#ifdef __cplusplus +extern "C" { +#endif +static void setup_trace() __attribute ((constructor)); +static void end_trace() __attribute ((destructor)); +#ifdef __cplusplus +} +#endif + +// Special handling exit func +static void (*real_exit)(int status) __attribute__((noreturn)) = nullptr; +static void (*real__exit)(int status) __attribute__((noreturn)) = nullptr; +static void (*real__Exit)(int status) __attribute__((noreturn)) = nullptr; + +static __thread bool RECURSIVE = false; +static volatile bool INITIALIZED = false; + +namespace { +pid_t GetPid() +{ + static thread_local pid_t pid = getpid(); + return pid; +} + +pid_t GetTid() +{ + static thread_local pid_t tid = gettid(); + return tid; +} + +const char* DUMP_FILE = "msosrt_trace_"; +char EXPORT_PATH[PATH_MAX]; + +const size_t RECORD_LENGTH = 512 * 1024; // Default number of trace data records +struct { + OSRTRecord data_[RECORD_LENGTH]; + std::atomic index_{0}; + bool is_full_ = false; + + void recordData(const char* function, uint64_t start_time, uint64_t duration) + { + size_t index = index_.load(std::memory_order_relaxed); + if (index + 1 >= RECORD_LENGTH) { + index_.store(0, std::memory_order_relaxed); + is_full_ = true; + } else { + index_.fetch_add(1, std::memory_order_relaxed); + } + auto& record = data_[index]; + record.pid = GetPid(); + record.tid = GetTid(); + record.function = function; + record.start_time = start_time; + record.duration = duration; + } + + size_t size() + { + return is_full_ ? 
RECORD_LENGTH : index_.load(std::memory_order_relaxed); + } + + bool hasValidData() + { + pid_t pid = getpid(); + for (size_t i = 0, len = size(); i < len; ++i) { + if (data_[i].pid == pid && data_[i].function != nullptr) { + return true; + } + } + return false; + } +} OSRT_RECORD_QUEUE; +} + +OSRTFunc global_osrt_func; + +void OSRTFunc::loadFunc() +{ + static volatile bool loaded = false; + if (LIKELY(loaded)) { + return; + } + RECURSIVE = true; + LOAD_FUNC(malloc, MallocFunc); + LOAD_FUNC(realloc, ReallocFunc); + LOAD_FUNC(free, FreeFunc); + LOAD_FUNC(mmap, MmapFunc); + LOAD_FUNC(munmap, MunmapFunc); + LOAD_FUNC(mremap, MremapFunc); + LOAD_FUNC(msync, MsyncFunc); + LOAD_FUNC(mprotect, MprotectFunc); + LOAD_FUNC(brk, BrkFunc); + + LOAD_FUNC(pthread_mutex_lock, PthreadMutexLockFunc); + LOAD_FUNC(pthread_mutex_timedlock, PthreadMutexTimedlockFunc); + LOAD_FUNC(pthread_cond_signal, PthreadCondSignalFunc); + LOAD_FUNC(pthread_cond_broadcast, PthreadCondBroadcastFunc); + LOAD_FUNC(pthread_cond_wait, PthreadCondWaitFunc); + LOAD_FUNC(pthread_cond_timedwait, PthreadCondTimedwaitFunc); + LOAD_FUNC(pthread_rwlock_rdlock, PthreadRwlockRdlockFunc); + LOAD_FUNC(pthread_rwlock_timedrdlock, PthreadRwlockTimedrdlockFunc); + LOAD_FUNC(pthread_rwlock_wrlock, PthreadRwlockWrlockFunc); + LOAD_FUNC(pthread_rwlock_timedwrlock, PthreadRwlockTimedwrlockFunc); + + real_exit = reinterpret_cast(dlsym(RTLD_NEXT, "exit")); + real__exit = reinterpret_cast(dlsym(RTLD_NEXT, "_exit")); + real__Exit = reinterpret_cast(dlsym(RTLD_NEXT, "_Exit")); + + file_func.loadFunc(); + socket_func.loadFunc(); + + loaded = true; + RECURSIVE = false; +} + +void OSRTFunc::recordFunc(uint64_t start_time, uint64_t duration, const char* name) +{ + if (UNLIKELY(!INITIALIZED || RECURSIVE)) { + return; + } + if (UNLIKELY(duration >= threshold_)) { + RECURSIVE = true; + OSRT_RECORD_QUEUE.recordData(name, start_time, duration); + RECURSIVE = false; + } +} + +void OSRTFunc::dumpFunc() +{ + if (!INITIALIZED) { + return; + 
} + static std::mutex dump_mutex; + static bool dumped = false; + + std::lock_guard lock(dump_mutex); + if (!dumped) { + RECURSIVE = true; + if (OSRT_RECORD_QUEUE.hasValidData()) { + std::string dump_file; + pid_t pid = getpid(); + // The glibc program_invocation_short_name contains the basename that was used to invoke the calling program + if (program_invocation_short_name != nullptr) { + dump_file = std::string(EXPORT_PATH) + "/" + DUMP_FILE + std::to_string(pid) + "_" + program_invocation_short_name + ".csv"; + } else { + dump_file = std::string(EXPORT_PATH) + "/" + DUMP_FILE + std::to_string(pid) + ".csv"; + } + if (!PathUtils::IsFileExist(dump_file) && !PathUtils::CreateFile(dump_file)) { + fprintf(stderr, "[ERROR] Create msosrt trace file failed.\n"); + RECURSIVE = false; + return; + } + auto fd = fopen(dump_file.c_str(), "ab"); + if (fd == nullptr) { + RECURSIVE = false; + return; + } + fprintf(fd, "%s\n", "Pid,Tid,Function,StartTime(ns),Duration(ns)"); + for (size_t i = 0, len = OSRT_RECORD_QUEUE.size(); i < len; ++i) { + if (OSRT_RECORD_QUEUE.data_[i].pid == pid && OSRT_RECORD_QUEUE.data_[i].function != nullptr) { + fprintf(fd, "%" PRIdMAX ",%" PRIdMAX ",%s,%" PRIu64 ",%" PRIu64 "\n", + static_cast(pid), + static_cast(OSRT_RECORD_QUEUE.data_[i].tid), + OSRT_RECORD_QUEUE.data_[i].function, + OSRT_RECORD_QUEUE.data_[i].start_time, + OSRT_RECORD_QUEUE.data_[i].duration); + } + } + fclose(fd); + } + RECURSIVE = false; + } + dumped = true; +} + +static void setup_trace() +{ + if (LIKELY(INITIALIZED)) { + return; + } + global_osrt_func.loadFunc(); + INITIALIZED = true; + + RECURSIVE = true; + const char* threshold_env_val = getenv("MSOSRT_TRACE_THRESHOLD"); + int64_t threshold = 0; + if (threshold_env_val == nullptr || str_to_i64(threshold_env_val, threshold) != 0) { + fprintf(stderr, "[WARNING] Parse MSOSRT_TRACE_THRESHOLD failed, use default value\n"); + } else { + if (threshold > 0) { + global_osrt_func.threshold_ = threshold; + } else { + fprintf(stderr, 
"[WARNING] MSOSRT_TRACE_THRESHOLD must be a positive integer, use default value\n"); + } + } + + const char* export_path_env_val = getenv("MSOSRT_EXPORT_PATH"); + std::string dump_path; + if (export_path_env_val != nullptr) { + dump_path = export_path_env_val; + } + if (dump_path.empty()) { + fprintf(stderr, "[WARNING] MSOSRT_EXPORT_PATH is not set, data will export to current working directory\n"); + char cwd_path[PATH_MAX] = {0}; + if (getcwd(cwd_path, PATH_MAX) != nullptr) { + dump_path = cwd_path; + } + } + std::string abs_path = PathUtils::RelativeToAbsPath(dump_path); + if (PathUtils::DirPathCheck(abs_path)) { + std::string real_path = PathUtils::RealPath(abs_path); + strncpy(EXPORT_PATH, real_path.c_str(), real_path.size() < PATH_MAX ? real_path.size() : PATH_MAX); + fprintf(stderr, "[INFO] MSOSRT result export path is: %s\n", real_path.c_str()); + } else { + fprintf(stderr, "[ERROR] Invalid export path, data will not be exported.\n"); + } + RECURSIVE = false; +} + +static void end_trace() +{ + global_osrt_func.dumpFunc(); +} + +void* malloc(size_t size) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return (void*)global_osrt_func.real_malloc(size); + } + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_malloc(size); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void* realloc(void* ptr, size_t size) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return (void*)global_osrt_func.real_realloc(ptr, size); + } + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_realloc(ptr, size); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void free(void* ptr) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + global_osrt_func.real_free(ptr); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); +} + +void* mmap(void* addr, size_t length, int 
prot, int flags, int fd, off_t offset) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_mmap(addr, length, prot, flags, fd, offset); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void* mremap(void* old_address, size_t old_size, size_t new_size, int flags, ...) +{ + global_osrt_func.loadFunc(); + va_list args; + va_start(args, flags); + void* arg = va_arg(args, void*); + va_end(args); + uint64_t start_time = nsec_now(); + void* ret = global_osrt_func.real_mremap(old_address, old_size, new_size, flags, arg); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int munmap(void* addr, size_t length) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_munmap(addr, length); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int msync(void* addr, size_t length, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_msync(addr, length, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int mprotect(void* addr, size_t len, int prot) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_mprotect(addr, len, prot); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int brk(void* addr) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_brk(addr); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_mutex_lock(pthread_mutex_t* mutex) +{ + if (UNLIKELY(!INITIALIZED && RECURSIVE)) { + // During the initialization phase we might be called inside of dlsym(). 
+ // Since we'd enter an endless loop if we tried to resolved the real + // pthread_mutex_lock() here then we simply fake the lock which should + // be safe since no thread can be running yet. + return 0; + } + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_mutex_lock(mutex); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_mutex_timedlock(pthread_mutex_t* mutex, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_mutex_timedlock(mutex, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_mutex_timedlock(mutex, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_signal(pthread_cond_t* cond) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_signal(cond); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_signal(cond); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_broadcast(pthread_cond_t* cond) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_broadcast(cond); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_broadcast(cond); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_cond_wait(pthread_cond_t* cond, pthread_mutex_t* mutex) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_wait(cond, mutex); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_wait(cond, mutex); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, 
__FUNCTION__); + return ret; +} + +int pthread_cond_timedwait(pthread_cond_t* cond, pthread_mutex_t* mutex, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_cond_timedwait(cond, mutex, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_cond_timedwait(cond, mutex, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_rdlock(pthread_rwlock_t* rwlock) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_rwlock_rdlock(rwlock); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_rdlock(rwlock); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_timedrdlock(pthread_rwlock_t* rwlock, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_rwlock_timedrdlock(rwlock, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_timedrdlock(rwlock, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_wrlock(pthread_rwlock_t* rwlock) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + global_osrt_func.real_pthread_rwlock_wrlock(rwlock); + } + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.real_pthread_rwlock_wrlock(rwlock); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int pthread_rwlock_timedwrlock(pthread_rwlock_t* rwlock, const struct timespec* abstime) +{ + global_osrt_func.loadFunc(); + if (UNLIKELY(RECURSIVE)) { + return global_osrt_func.real_pthread_rwlock_timedwrlock(rwlock, abstime); + } + uint64_t start_time = nsec_now(); + auto ret = 
global_osrt_func.real_pthread_rwlock_timedwrlock(rwlock, abstime); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +void exit(int status) +{ + if (LIKELY(INITIALIZED)) { + global_osrt_func.dumpFunc(); + } + real_exit(status); +} + +void _exit(int status) +{ + if (LIKELY(INITIALIZED)) { + global_osrt_func.dumpFunc(); + } + real__exit(status); +} + +void _Exit(int status) +{ + if (LIKELY(INITIALIZED)) { + global_osrt_func.dumpFunc(); + } + real__Exit(status); +} diff --git a/profiler/osrt_trace/src/msosrt_trace.h b/profiler/osrt_trace/src/msosrt_trace.h new file mode 100644 index 0000000000000000000000000000000000000000..e153ef5138883cd597c0a5a524adc5ec5b555ea4 --- /dev/null +++ b/profiler/osrt_trace/src/msosrt_trace.h @@ -0,0 +1,207 @@ +#pragma once + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" +#include "file_func.h" +#include "socket_func.h" + +#define TRACE_API __attribute__((visibility("default"))) +#define LOAD_FUNC(name, func_type) \ + do { \ + (real_##name) = reinterpret_cast(dlsym(RTLD_NEXT, #name)); \ + } while (false) + +#ifdef __cplusplus +extern "C" { +#endif +// memory func +TRACE_API void* malloc(size_t size); +TRACE_API void* realloc(void* ptr, size_t size); +TRACE_API void free(void* ptr); +TRACE_API void* mmap(void* addr, size_t length, int prot, int flags, int fd, off_t offset); +TRACE_API int munmap(void* addr, size_t length); +TRACE_API void* mremap(void* old_address, size_t old_size, size_t new_size, int flags, ... 
/* void *new_address */); +TRACE_API int msync(void* addr, size_t length, int flags); +TRACE_API int mprotect(void* addr, size_t len, int prot); +TRACE_API int brk(void* addr); +// pthread func +TRACE_API int pthread_mutex_lock(pthread_mutex_t* mutex); +TRACE_API int pthread_mutex_timedlock(pthread_mutex_t* mutex, const struct timespec* abstime); +TRACE_API int pthread_cond_signal(pthread_cond_t* cond); +TRACE_API int pthread_cond_broadcast(pthread_cond_t* cond); +TRACE_API int pthread_cond_wait(pthread_cond_t* cond, pthread_mutex_t* mutex); +TRACE_API int pthread_cond_timedwait(pthread_cond_t* cond, pthread_mutex_t* mutex, const struct timespec* abstime); +TRACE_API int pthread_rwlock_rdlock(pthread_rwlock_t* rwlock); +TRACE_API int pthread_rwlock_timedrdlock(pthread_rwlock_t* rwlock, const struct timespec* abstime); +TRACE_API int pthread_rwlock_wrlock(pthread_rwlock_t* rwlock); +TRACE_API int pthread_rwlock_timedwrlock(pthread_rwlock_t* rwlock, const struct timespec* abstime); +// exit func +TRACE_API void exit(int status) __attribute__((noreturn)); +TRACE_API void _exit(int status) __attribute__((noreturn)); +TRACE_API void _Exit(int status) __attribute__((noreturn)); +// file func +TRACE_API int dup(int oldfd); +TRACE_API int dup2(int oldfd, int newfd); +TRACE_API int dup3(int oldfd, int newfd, int flags); +TRACE_API ssize_t tee(int fd_in, int fd_out, size_t len, unsigned int flags); +TRACE_API ssize_t splice(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags); +TRACE_API int fallocate(int fd, int mode, off_t offset, off_t len); +TRACE_API int fdatasync(int fildes); +TRACE_API int fsync(int fd); +TRACE_API int fcntl(int fd, int op, ...); +TRACE_API int flock(int fd, int op); +TRACE_API int lockf(int fd, int op, off_t len); +TRACE_API int truncate(const char* path, off_t length); +TRACE_API int ftruncate(int fildes, off_t length); +TRACE_API int ioctl(int fd, int op, ...); +TRACE_API int open(const char* pathname, int flags, 
... /* mode_t mode */ ); +TRACE_API int openat(int dirfd, const char* pathname, int flags, ... /* mode_t mode */ ); +TRACE_API int pipe(int pipefd[2]); +TRACE_API int pipe2(int pipefd[2], int flags); +TRACE_API int mkfifo(const char* pathname, mode_t mode); +TRACE_API int mkfifoat(int dirfd, const char* pathname, mode_t mode); +TRACE_API ssize_t read(int fd, void* buf, size_t count); +TRACE_API ssize_t pread(int fd, void* buf, size_t count, off_t offset); +TRACE_API ssize_t readv(int fd, const struct iovec* iov, int iovcnt); +TRACE_API ssize_t preadv(int fd, const struct iovec* iov, int iovcnt, off_t offset); +TRACE_API ssize_t preadv2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags); +TRACE_API ssize_t write(int fd, const void* buf, size_t count); +TRACE_API ssize_t pwrite(int fd, const void* buf, size_t count, off_t offset); +TRACE_API ssize_t writev(int fd, const struct iovec* iov, int iovcnt); +TRACE_API ssize_t pwritev(int fd, const struct iovec* iov, int iovcnt, off_t offset); +TRACE_API ssize_t pwritev2(int fd, const struct iovec* iov, int iovcnt, off_t offset, int flags); +TRACE_API ssize_t copy_file_range(int fd_in, off_t* off_in, int fd_out, off_t* off_out, size_t len, unsigned int flags); +TRACE_API void sync(void); +TRACE_API int syncfs(int fd); +TRACE_API int sync_file_range(int fd, off_t offset, off_t nbytes, unsigned int flags); +TRACE_API ssize_t vmsplice(int fd, const struct iovec* iov, size_t nr_segs, unsigned int flags); +TRACE_API ssize_t process_vm_readv(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags); +TRACE_API ssize_t process_vm_writev(pid_t pid, const struct iovec* local_iov, unsigned long liovcnt, + const struct iovec* remote_iov, unsigned long riovcnt, unsigned long flags); +TRACE_API int fclose(FILE* stream); +TRACE_API int fcloseall(void); +TRACE_API int fflush(FILE* stream); +TRACE_API int fgetc(FILE* stream); +TRACE_API 
char* fgets(char* s, int size, FILE* stream); +TRACE_API int fputc(int c, FILE* stream); +TRACE_API int fputs(const char* s, FILE* stream); +TRACE_API void flockfile(FILE* filehandle); +TRACE_API int ftrylockfile(FILE* filehandle); +TRACE_API void funlockfile(FILE* filehandle); +TRACE_API FILE* fopen(const char* pathname, const char* mode); +TRACE_API FILE* freopen(const char* pathname, const char* mode, FILE* stream); +TRACE_API size_t fread(void* ptr, size_t size, size_t nmemb, FILE* stream); +TRACE_API size_t fwrite(const void* ptr, size_t size, size_t nitems, FILE* stream); +TRACE_API ssize_t getdelim(char** lineptr, size_t* n, int delimiter, FILE* stream); +TRACE_API ssize_t getline(char** lineptr, size_t* n, FILE* stream); +TRACE_API int getc(FILE* stream); +TRACE_API int putc(int c, FILE* stream); +TRACE_API int getc_unlocked(FILE* stream); +TRACE_API int putc_unlocked(int c, FILE* stream); +TRACE_API int fflush_unlocked(FILE* stream); +TRACE_API int fgetc_unlocked(FILE* stream); +TRACE_API int fputc_unlocked(int c, FILE* stream); +TRACE_API size_t fread_unlocked(void* ptr, size_t size, size_t n, FILE* stream); +TRACE_API size_t fwrite_unlocked(const void* ptr, size_t size, size_t n, FILE* stream); +TRACE_API char* fgets_unlocked(char* s, int n, FILE* stream); +TRACE_API int fputs_unlocked(const char* s, FILE* stream); +// socket func +TRACE_API int socket(int domain, int type, int protocol); +TRACE_API int socketpair(int domain, int type, int protocol, int sv[2]); +TRACE_API int epoll_ctl(int epfd, int op, int fd, struct epoll_event* event); +TRACE_API int epoll_wait(int epfd, struct epoll_event* events, int maxevents, int timeout); +TRACE_API int epoll_pwait(int epfd, struct epoll_event* events, int maxevents, int timeout, const sigset_t* sigmask); +TRACE_API int select(int nfds, fd_set* readfds, fd_set* writefds, fd_set* exceptfds, struct timeval* timeout); +TRACE_API int listen(int sockfd, int backlog); +TRACE_API int accept(int sockfd, struct sockaddr* 
addr, socklen_t* addrlen); +TRACE_API int accept4(int sockfd, struct sockaddr* addr, socklen_t* addrlen, int flags); +TRACE_API int bind(int sockfd, const struct sockaddr* addr, socklen_t addrlen); +TRACE_API int poll(struct pollfd* fds, nfds_t nfds, int timeout); +TRACE_API int ppoll(struct pollfd* fds, nfds_t nfds, const struct timespec* tmo_p, const sigset_t* sigmask); +TRACE_API ssize_t send(int sockfd, const void* buf, size_t len, int flags); +TRACE_API ssize_t sendto(int sockfd, const void* buf, size_t len, int flags, const struct sockaddr* dest_addr, socklen_t addrlen); +TRACE_API ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags); +TRACE_API int sendmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags); +TRACE_API ssize_t sendfile(int out_fd, int in_fd, off_t* offset, size_t count); +TRACE_API ssize_t recv(int sockfd, void* buf, size_t len, int flags); +TRACE_API ssize_t recvfrom(int sockfd, void* buf, size_t len, int flags, struct sockaddr* src_addr, socklen_t* addrlen); +TRACE_API ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags); +TRACE_API int recvmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags, struct timespec* timeout); +#ifdef __cplusplus +} +#endif + +using MallocFunc = void*(*)(size_t); +using ReallocFunc = void*(*)(void*, size_t); +using FreeFunc = void(*)(void*); +using MmapFunc = void*(*)(void*, size_t, int, int, int, off_t); +using MunmapFunc = int(*)(void*, size_t); +using MremapFunc = void*(*)(void*, size_t, size_t, int, ...); +using MsyncFunc = int(*)(void*, size_t, int); +using MprotectFunc = int(*)(void*, size_t, int); +using BrkFunc = int(*)(void*); +using PthreadMutexLockFunc = int(*)(pthread_mutex_t*); +using PthreadMutexTimedlockFunc = int(*)(pthread_mutex_t*, const struct timespec*); +using PthreadCondSignalFunc = int(*)(pthread_cond_t*); +using PthreadCondBroadcastFunc = int(*)(pthread_cond_t*); +using PthreadCondWaitFunc = int(*)(pthread_cond_t*, pthread_mutex_t*); +using 
// Holds the pointers to the real memory and pthread functions that the
// interposed wrappers forward to, the file/socket proxies, and the
// load/record/dump entry points the wrappers call.
struct OSRTFunc {
    // Starts at DEFAULT_THRESHOLD (annotated "10ms" at its definition).
    // NOTE(review): presumably the minimum duration for a call to be recorded;
    // confirm in recordFunc.
    uint64_t threshold_ = DEFAULT_THRESHOLD;

    // Real memory-management entry points; null until assigned.
    MallocFunc real_malloc = nullptr;
    ReallocFunc real_realloc = nullptr;
    FreeFunc real_free = nullptr;
    MmapFunc real_mmap = nullptr;
    MunmapFunc real_munmap = nullptr;
    MremapFunc real_mremap = nullptr;
    MsyncFunc real_msync = nullptr;
    MprotectFunc real_mprotect = nullptr;
    BrkFunc real_brk = nullptr;
    // Real pthread mutex / condition-variable / rwlock entry points; null until assigned.
    PthreadMutexLockFunc real_pthread_mutex_lock = nullptr;
    PthreadMutexTimedlockFunc real_pthread_mutex_timedlock = nullptr;
    PthreadCondSignalFunc real_pthread_cond_signal = nullptr;
    PthreadCondBroadcastFunc real_pthread_cond_broadcast = nullptr;
    PthreadCondWaitFunc real_pthread_cond_wait = nullptr;
    PthreadCondTimedwaitFunc real_pthread_cond_timedwait = nullptr;
    PthreadRwlockRdlockFunc real_pthread_rwlock_rdlock = nullptr;
    PthreadRwlockTimedrdlockFunc real_pthread_rwlock_timedrdlock = nullptr;
    PthreadRwlockWrlockFunc real_pthread_rwlock_wrlock = nullptr;
    PthreadRwlockTimedwrlockFunc real_pthread_rwlock_timedwrlock = nullptr;

    // Pointer tables for the file and socket wrappers.
    FileFuncProxy file_func;
    SocketFuncProxy socket_func;

    // Called at the top of every wrapper before any real_* pointer is used.
    void loadFunc();
    // Called by wrappers after the real call with its start timestamp, elapsed
    // time (both taken from nsec_now()) and the wrapper's __FUNCTION__ name.
    void recordFunc(uint64_t start_time, uint64_t duration, const char* name);
    // Called by the exit/_exit/_Exit wrappers before the real exit function runs.
    void dumpFunc();
};
// Fills every real_* member of this proxy. Each LOAD_FUNC(name, type) expands
// to an assignment of real_<name> from dlsym(RTLD_NEXT, "name"); the result of
// the lookup is not checked here, so a member stays null if dlsym fails.
void SocketFuncProxy::loadFunc()
{
    // socket creation and readiness/multiplexing calls
    LOAD_FUNC(socket, SocketFunc);
    LOAD_FUNC(socketpair, SocketpairFunc);
    LOAD_FUNC(epoll_ctl, EpollCtlFunc);
    LOAD_FUNC(epoll_wait, EpollWaitFunc);
    LOAD_FUNC(epoll_pwait, EpollPwaitFunc);
    LOAD_FUNC(select, SelectFunc);
    // connection setup
    LOAD_FUNC(listen, ListenFunc);
    LOAD_FUNC(accept, AcceptFunc);
    LOAD_FUNC(accept4, Accept4Func);
    LOAD_FUNC(bind, BindFunc);
    LOAD_FUNC(poll, PollFunc);
    LOAD_FUNC(ppoll, PpollFunc);
    // send side
    LOAD_FUNC(send, SendFunc);
    LOAD_FUNC(sendto, SendtoFunc);
    LOAD_FUNC(sendmsg, SendmsgFunc);
    LOAD_FUNC(sendmmsg, SendmmsgFunc);
    LOAD_FUNC(sendfile, SendfileFunc);
    // receive side
    LOAD_FUNC(recv, RecvFunc);
    LOAD_FUNC(recvfrom, RecvfromFunc);
    LOAD_FUNC(recvmsg, RecvmsgFunc);
    LOAD_FUNC(recvmmsg, RecvmmsgFunc);
}
timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_epoll_wait(epfd, events, maxevents, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int epoll_pwait(int epfd, struct epoll_event* events, int maxevents, int timeout, const sigset_t* sigmask) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_epoll_pwait(epfd, events, maxevents, timeout, sigmask); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int select(int nfds, fd_set* readfds, fd_set* writefds, fd_set* exceptfds, struct timeval* timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_select(nfds, readfds, writefds, exceptfds, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int listen(int sockfd, int backlog) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_listen(sockfd, backlog); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int accept(int sockfd, struct sockaddr* addr, socklen_t* addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_accept(sockfd, addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int accept4(int sockfd, struct sockaddr* addr, socklen_t* addrlen, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_accept4(sockfd, addr, addrlen, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int bind(int sockfd, const struct 
sockaddr* addr, socklen_t addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_bind(sockfd, addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int poll(struct pollfd* fds, nfds_t nfds, int timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_poll(fds, nfds, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int ppoll(struct pollfd* fds, nfds_t nfds, const struct timespec* tmo_p, const sigset_t* sigmask) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_ppoll(fds, nfds, tmo_p, sigmask); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t send(int sockfd, const void* buf, size_t len, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_send(sockfd, buf, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t sendto(int sockfd, const void* buf, size_t len, int flags, const struct sockaddr* dest_addr, socklen_t addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendto(sockfd, buf, len, flags, dest_addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t sendmsg(int sockfd, const struct msghdr* msg, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendmsg(sockfd, msg, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int sendmmsg(int sockfd, struct mmsghdr* 
msgvec, unsigned int vlen, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendmmsg(sockfd, msgvec, vlen, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t sendfile(int out_fd, int in_fd, off_t* offset, size_t count) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_sendfile(out_fd, in_fd, offset, count); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t recv(int sockfd, void* buf, size_t len, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recv(sockfd, buf, len, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t recvfrom(int sockfd, void* buf, size_t len, int flags, struct sockaddr* src_addr, socklen_t* addrlen) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recvfrom(sockfd, buf, len, flags, src_addr, addrlen); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +ssize_t recvmsg(int sockfd, struct msghdr* msg, int flags) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recvmsg(sockfd, msg, flags); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} + +int recvmmsg(int sockfd, struct mmsghdr* msgvec, unsigned int vlen, int flags, struct timespec* timeout) +{ + global_osrt_func.loadFunc(); + uint64_t start_time = nsec_now(); + auto ret = global_osrt_func.socket_func.real_recvmmsg(sockfd, msgvec, vlen, flags, timeout); + global_osrt_func.recordFunc(start_time, nsec_now() - start_time, __FUNCTION__); + return ret; +} diff 
--git a/profiler/osrt_trace/src/socket_func.h b/profiler/osrt_trace/src/socket_func.h new file mode 100644 index 0000000000000000000000000000000000000000..361ce1d6382eada6cd942d74c2f3e0e7cd8621a0 --- /dev/null +++ b/profiler/osrt_trace/src/socket_func.h @@ -0,0 +1,60 @@ +#pragma once + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#include +#include +#include +#include +#include + +using SocketFunc = int(*)(int, int, int); +using SocketpairFunc = int(*)(int, int, int, int* sv); +using EpollCtlFunc = int(*)(int, int, int, struct epoll_event*); +using EpollWaitFunc = int(*)(int, struct epoll_event*, int, int); +using EpollPwaitFunc = int(*)(int, struct epoll_event*, int, int, const sigset_t*); +using SelectFunc = int(*)(int, fd_set*, fd_set*, fd_set*, struct timeval*); +using ListenFunc = int(*)(int, int); +using AcceptFunc = int(*)(int, struct sockaddr*, socklen_t*); +using Accept4Func = int(*)(int, struct sockaddr*, socklen_t*, int); +using BindFunc = int(*)(int, const struct sockaddr*, socklen_t); +using PollFunc = int(*)(struct pollfd*, nfds_t, int); +using PpollFunc = int(*)(struct pollfd*, nfds_t, const struct timespec*, const sigset_t*); +using SendFunc = ssize_t(*)(int, const void*, size_t, int); +using SendtoFunc = ssize_t(*)(int, const void*, size_t, int, const struct sockaddr*, socklen_t); +using SendmsgFunc = ssize_t(*)(int, const struct msghdr*, int); +using SendmmsgFunc = int(*)(int, struct mmsghdr*, unsigned int, int); +using SendfileFunc = ssize_t(*)(int, int, off_t*, size_t); +using RecvFunc = ssize_t(*)(int, void*, size_t, int); +using RecvfromFunc = ssize_t(*)(int, void*, size_t, int, struct sockaddr*, socklen_t*); +using RecvmsgFunc = ssize_t(*)(int, struct msghdr*, int); +using RecvmmsgFunc = int(*)(int, struct mmsghdr*, unsigned int, int, struct timespec*); + +struct SocketFuncProxy +{ + SocketFunc real_socket = nullptr; + SocketpairFunc real_socketpair = nullptr; + EpollCtlFunc real_epoll_ctl = nullptr; + EpollWaitFunc real_epoll_wait = 
// Parses `str` as a signed integer with std::stoll.
// Returns 0 when the whole string was consumed, -1 when the string is empty,
// std::stoll throws, or characters remain after the number.
// Note: `num` is assigned as soon as std::stoll succeeds, so it is also
// overwritten in the "trailing characters" failure case.
int str_to_i64(const std::string& str, int64_t& num)
{
    constexpr int kParsed = 0;
    constexpr int kRejected = -1;

    if (str.empty()) {
        return kRejected;
    }

    size_t consumed = 0;
    try {
        num = std::stoll(str, &consumed);
    } catch (...) {
        return kRejected;
    }
    return (consumed == str.size()) ? kParsed : kRejected;
}
true : false; +} + +bool PathUtils::CreateDir(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return false; + } + if (IsFileExist(path)) { + return IsDir(path) ? true : false; + } + size_t pos = 0; + while ((pos = path.find_first_of('/', pos)) != std::string::npos) { + std::string base_dir = path.substr(0, ++pos); + if (IsFileExist(base_dir)) { + if (IsDir(base_dir)) { + continue; + } else { + return false; + } + } + if (mkdir(base_dir.c_str(), DATA_DIR_AUTHORITY) != 0) { + return false; + } + } + return (mkdir(path.c_str(), DATA_DIR_AUTHORITY) == 0) ? true : false; +} + +std::string PathUtils::RealPath(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return ""; + } + char realPath[PATH_MAX] = {0}; + if (realpath(path.c_str(), realPath) == nullptr) { + return ""; + } + return std::string(realPath); +} + +std::string PathUtils::RelativeToAbsPath(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX) { + return ""; + } + if (path[0] != '/') { + char pwd_path[PATH_MAX] = {0}; + if (getcwd(pwd_path, PATH_MAX) != nullptr) { + return std::string(pwd_path) + "/" + path; + } + return ""; + } + return std::string(path); +} + +std::string PathUtils::DirName(const std::string &path) +{ + if (path.empty()) { + return ""; + } + char temp_path[PATH_MAX] = {0}; + strncpy(temp_path, path.c_str(), path.size() < PATH_MAX ? path.size() : PATH_MAX); + char* path_c = dirname(temp_path); + return path_c ? std::string(path_c) : ""; +} + +bool PathUtils::CreateFile(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX || !CreateDir(DirName(path))) { + return false; + } + int fd = creat(path.c_str(), DATA_FILE_AUTHORITY); + return (fd < 0 || close(fd) != 0) ? 
false : true; +} + +bool PathUtils::IsSoftLink(const std::string &path) +{ + if (path.empty() || path.size() > PATH_MAX || !IsFileExist(path)) { + return false; + } + struct stat st{}; + if (lstat(path.c_str(), &st) != 0) { + return false; + } + return S_ISLNK(st.st_mode); +} + +bool PathUtils::DirPathCheck(const std::string& abs_path) +{ + if (abs_path.empty() || abs_path.size() > PATH_MAX) { + fprintf(stderr, "[ERROR] The length of Path %s is invalid.\n", abs_path.c_str()); + return false; + } + if (IsSoftLink(abs_path)) { + fprintf(stderr, "[ERROR] Path %s is soft link.\n", abs_path.c_str()); + return false; + } + if (!IsFileExist(abs_path) && !CreateDir(abs_path)) { + fprintf(stderr, "[ERROR] Path %s not exist and create failed.\n", abs_path.c_str()); + return false; + } + if (!IsDir(abs_path) || !IsFileWritable(abs_path)) { + fprintf(stderr, "[ERROR] %s is not a directory or is not writable.\n", abs_path.c_str()); + return false; + } + return true; +} diff --git a/profiler/osrt_trace/src/utils.h b/profiler/osrt_trace/src/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..129c062d5f2898d0b33db33f4716ae497c6ad8d1 --- /dev/null +++ b/profiler/osrt_trace/src/utils.h @@ -0,0 +1,50 @@ +/** + * Copyright 2024 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
// Returns the current CLOCK_REALTIME reading in nanoseconds.
// The wrappers use it both as a call's start timestamp and, by subtraction,
// for the call's duration.
inline uint64_t nsec_now()
{
    constexpr uint64_t S_TO_NS = 1000ULL * 1000 * 1000;
    // Zero-initialised so a failed clock_gettime() yields 0 rather than
    // whatever happened to be on the stack.
    struct timespec ts{};
    clock_gettime(CLOCK_REALTIME, &ts);
    return static_cast<uint64_t>(ts.tv_sec) * S_TO_NS + static_cast<uint64_t>(ts.tv_nsec);
}