diff --git a/debug/accuracy_tools/grad_tool/common/constant.py b/debug/accuracy_tools/grad_tool/common/constant.py index 38d33e9886490bba65205eff6a8d080070213acc..7904c1d424ec2d935a50b2630ebe50918f72f088 100644 --- a/debug/accuracy_tools/grad_tool/common/constant.py +++ b/debug/accuracy_tools/grad_tool/common/constant.py @@ -39,7 +39,7 @@ class GradConst: DIRECTORY_LENGTH = 4096 FILE_NAME_LENGTH = 255 FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" - PARAM_VALID_PATTERN = r"^[a-zA-Z0-9.]+$" + PARAM_VALID_PATTERN = r"^[a-zA-Z0-9_.:-]+$" DIR = "dir" FILE = "file" diff --git a/debug/accuracy_tools/grad_tool/common/utils.py b/debug/accuracy_tools/grad_tool/common/utils.py index fceda8ce0f2683d1e390a045c4b9fc28fe931cff..f40f8688c2458fa17a5dc2db1ac999c9dc9ab878 100644 --- a/debug/accuracy_tools/grad_tool/common/utils.py +++ b/debug/accuracy_tools/grad_tool/common/utils.py @@ -7,7 +7,6 @@ import yaml import pandas as pd from grad_tool.common.constant import GradConst -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen def _print_log(level, msg, end='\n'): @@ -115,7 +114,7 @@ class ListCache(list): def get_config(filepath): - with FileOpen(filepath, 'r') as file: + with open(filepath, 'r') as file: config = yaml.safe_load(file) return config diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py index c843df3884e34525fc725cf0eb1fc06fe68c96c5..fa794a681a86d040510d9dd92e039f292189c8b7 100644 --- a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py +++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py @@ -16,7 +16,6 @@ from grad_tool.common.utils import ListCache, print_warn_log from grad_tool.common.utils import create_directory, check_file_or_directory_path, write_csv from grad_tool.grad_ms.global_context import grad_context from grad_tool.grad_ms.global_context import GlobalContext -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import 
FileCheckConst, FileChecker def get_rank_id(): @@ -170,8 +169,6 @@ class CSVGenerator(Process): stat_data = None max_try = 10 while max_try: - file_path_checker = FileChecker(file_path, FileCheckConst.DIR,FileCheckConst.READ_ABLE) - file_path = file_path_checker.common_check() try: stat_data = np.load(file_path) return stat_data diff --git a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py b/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py index 80f978c94c91dc103c9cda6092742ab92cabcd38..77fd7924f937487feca6a300fa9a9023a26b3c4b 100644 --- a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py @@ -142,7 +142,7 @@ def op_aggregate(op, tensorlist): return max(tensorlist) if op == 'norm': return sum(tensorlist) - if op == 'zeros': # TODO wrong + if op == 'zeros': return sum(tensorlist) / len(tensorlist) if len(tensorlist) != 0 else 0 return torch.nan diff --git a/debug/accuracy_tools/msprobe/README.md b/debug/accuracy_tools/msprobe/README.md index 42743c50781501d17796aa0b214bdc034a8cf2ec..a89592499e9978fafeef306f9ce0949091bbfac4 100644 --- a/debug/accuracy_tools/msprobe/README.md +++ b/debug/accuracy_tools/msprobe/README.md @@ -21,7 +21,9 @@ Successfully installed mindstudio_probe-{version} ``` ### 下载whl包安装 -1. 使用pip命令安装numpy、openpyxl、pandas、PyYAML、rich、torch、tqdm依赖。 +1. 使用pip命令安装numpy、openpyxl、pandas、PyYAML、rich、tqdm、matplotlib依赖。 + + 根据自己的环境选择安装 torch、mindspore。 若环境中已安装部分依赖,不需要重复安装。 @@ -177,6 +179,14 @@ Required-by: MindSpore场景:暂不支持。 +6. 
执行梯度采集和比对。 + + 用于采集梯度数据并进行梯度相似度比对。可以精准定位问题出现的step。 + + 详见[梯度状态监测工具](./doc/grad_probe/grad_probe.md)。 + + + 上述流程中的工具均为msprobe工具的子工具,使用相同的命令行,格式如下: 精度预检工具 diff --git a/debug/accuracy_tools/msprobe/config/config.json b/debug/accuracy_tools/msprobe/config/config.json index ef0283ca27a4102fd1a53a842e0dc056b21827ae..bc9789a38e62c85822ab55232de85d9b802fc29c 100644 --- a/debug/accuracy_tools/msprobe/config/config.json +++ b/debug/accuracy_tools/msprobe/config/config.json @@ -31,11 +31,20 @@ "error_data_path": "./" }, "grad_probe": { - "level": "L1", + "grad_level": "L1", "param_list": [], - "rank": [], - "step": [], - "bounds": [-1, 0, 1], - "output_path": "./grad_output" + "bounds": [-1, 0, 1] + }, + "free_benchmark": { + "scope": [], + "list": [], + "fuzz_device": "npu", + "pert_mode": "improve_precision", + "handler_type": "check", + "fuzz_level": "L1", + "fuzz_stage": "forward", + "if_preheat": false, + "preheat_step": 15, + "max_sample": 20 } } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py b/debug/accuracy_tools/msprobe/core/advisor/advisor.py similarity index 96% rename from debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py rename to debug/accuracy_tools/msprobe/core/advisor/advisor.py index b178664d9e37f7d6cafdca58218b75909ab9cfcc..9824ac22a036252b3aad4256af7a46f85356f770 100644 --- a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/msprobe/core/advisor/advisor.py @@ -17,9 +17,9 @@ import os -from msprobe.pytorch.advisor.advisor_result import AdvisorResult -from msprobe.pytorch.advisor.advisor_const import AdvisorConst -from msprobe.pytorch.common.log import logger +from msprobe.core.advisor.advisor_result import AdvisorResult +from msprobe.core.advisor.advisor_const import AdvisorConst +from msprobe.core.common.log import logger from msprobe.core.common.utils import CompareException from msprobe.core.common.file_check import FileChecker from msprobe.core.common.const import 
Const, CompareConst, FileCheckConst diff --git a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_const.py b/debug/accuracy_tools/msprobe/core/advisor/advisor_const.py similarity index 100% rename from debug/accuracy_tools/msprobe/pytorch/advisor/advisor_const.py rename to debug/accuracy_tools/msprobe/core/advisor/advisor_const.py diff --git a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/msprobe/core/advisor/advisor_result.py similarity index 95% rename from debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py rename to debug/accuracy_tools/msprobe/core/advisor/advisor_result.py index 456f542e1f5bf867aa3db6a88e36dd03f8b581dc..2bfea2eb9576166658ff6b3fe34c8cb54cd86fdf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/advisor/advisor_result.py +++ b/debug/accuracy_tools/msprobe/core/advisor/advisor_result.py @@ -17,8 +17,8 @@ import os import time -from msprobe.pytorch.advisor.advisor_const import AdvisorConst -from msprobe.pytorch.common.log import logger +from msprobe.core.advisor.advisor_const import AdvisorConst +from msprobe.core.common.log import logger from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.file_check import change_mode diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 35946ca7c07a27eb7cab56d89b5868040c9b9ac3..3337570825206097d1a1f2dbe6201c9f4272d0fe 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -1,5 +1,6 @@ import os import stat + import numpy as np @@ -16,10 +17,12 @@ class Const: OFF = 'OFF' BACKWARD = 'backward' FORWARD = 'forward' + PRIMITIVE_PREFIX = 'Primitive' DEFAULT_LIST = [] DEFAULT_PATH = './' WHITE_LIST = 'white_list' BLACK_LIST = 'black_list' + DUMP_TENSOR_DATA = 'dump_tensor_data' # dump mode ALL = "all" @@ -254,17 +257,3 @@ class OverflowConst: OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" 
OVERFLOW_ORIGINAL_MODE = 0 OVERFLOW_DEBUG_MODE = 1 - - -class MsConst: - CELL = "cell" - API = "api" - KERNEL = "kernel" - TOOL_LEVEL_DICT = { - "L0": CELL, - "L1": API, - "L2": KERNEL - } - PYNATIVE_MODE = "pynative" - GRAPH_GE_MODE = "graph_ge" - GRAPH_KBYK_MODE = "graph_kbyk" diff --git a/debug/accuracy_tools/msprobe/core/common/exceptions.py b/debug/accuracy_tools/msprobe/core/common/exceptions.py index ea61f8cd58fe057ba6836dd1ed368d52adedeb18..eb314c7c645e50b300fd12c91669c24dfa914583 100644 --- a/debug/accuracy_tools/msprobe/core/common/exceptions.py +++ b/debug/accuracy_tools/msprobe/core/common/exceptions.py @@ -85,4 +85,4 @@ class DistributedNotInitializedError(Exception): self.msg = msg def __str__(self): - return self.msg + return self.msg \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/common/file_check.py b/debug/accuracy_tools/msprobe/core/common/file_check.py index 36896cfbc19b29f1fcaef04228aac37dc29c8416..c567f94545e2ed946e3335ce34ae1757046e2efa 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_check.py +++ b/debug/accuracy_tools/msprobe/core/common/file_check.py @@ -262,4 +262,22 @@ def change_mode(path, mode): def path_len_exceeds_limit(file_path): return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH \ No newline at end of file + len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + + +def check_file_type(path): + """ + Function Description: + determine if it is a file or a directory + Parameter: + path: path + Exception Description: + when neither a file nor a directory throw exception + """ + if os.path.isdir(path): + return FileCheckConst.DIR + elif os.path.isfile(path): + return FileCheckConst.FILE + else: + logger.error('Neither a file nor a directory.') + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py 
b/debug/accuracy_tools/msprobe/core/common/utils.py index 67ca7a8edead6d3abb0f8e7a9552e63c61c2f173..7a34a24118a9d9c4935194a96004bf2953c9c8dd 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -27,7 +27,7 @@ from datetime import datetime, timezone from pathlib import Path import numpy as np -from msprobe.core.common.file_check import FileOpen, FileChecker +from msprobe.core.common.file_check import FileOpen, FileChecker, change_mode from msprobe.core.common.const import Const, FileCheckConst, CompareConst, OverflowConst from msprobe.core.common.log import logger @@ -149,21 +149,21 @@ def check_summary_only_valid(summary_only): return summary_only -def check_compare_param(input_parma, output_path, summary_compare=False, md5_compare=False): - if not (isinstance(input_parma, dict) and isinstance(output_path, str)): +def check_compare_param(input_param, output_path, summary_compare=False, md5_compare=False): + if not (isinstance(input_param, dict) and isinstance(output_path, str)): logger.error("Invalid input parameters") raise CompareException(CompareException.INVALID_PARAM_ERROR) - check_file_or_directory_path(input_parma.get("npu_json_path"), False) - check_file_or_directory_path(input_parma.get("bench_json_path"), False) - check_file_or_directory_path(input_parma.get("stack_json_path"), False) + check_file_or_directory_path(input_param.get("npu_path"), False) + check_file_or_directory_path(input_param.get("bench_path"), False) + check_file_or_directory_path(input_param.get("stack_path"), False) if not summary_compare and not md5_compare: - check_file_or_directory_path(input_parma.get("npu_dump_data_dir"), True) - check_file_or_directory_path(input_parma.get("bench_dump_data_dir"), True) + check_file_or_directory_path(input_param.get("npu_dump_data_dir"), True) + check_file_or_directory_path(input_param.get("bench_dump_data_dir"), True) check_file_or_directory_path(output_path, True) - with 
FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \ - FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \ - FileOpen(input_parma.get("stack_json_path"), "r") as stack_json: - check_json_file(input_parma, npu_json, bench_json, stack_json) + with FileOpen(input_param.get("npu_path"), "r") as npu_json, \ + FileOpen(input_param.get("bench_path"), "r") as bench_json, \ + FileOpen(input_param.get("stack_path"), "r") as stack_json: + check_json_file(input_param, npu_json, bench_json, stack_json) def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False): @@ -202,9 +202,9 @@ def _check_json(json_file_handle, file_name): def check_json_file(input_param, npu_json, bench_json, stack_json): - _check_json(npu_json, input_param.get("npu_json_path")) - _check_json(bench_json, input_param.get("bench_json_path")) - _check_json(stack_json, input_param.get("stack_json_path")) + _check_json(npu_json, input_param.get("npu_path")) + _check_json(bench_json, input_param.get("bench_path")) + _check_json(stack_json, input_param.get("stack_path")) def check_file_size(input_file, max_size): @@ -258,6 +258,17 @@ def remove_path(path): raise CompareException(CompareException.INVALID_PATH_ERROR) from err +def move_file(src_path, dst_path): + check_file_or_directory_path(src_path) + check_path_before_create(dst_path) + try: + shutil.move(src_path, dst_path) + except Exception as e: + logger.error(f"move file {src_path} to {dst_path} failed") + raise RuntimeError(f"move file {src_path} to {dst_path} failed") from e + change_mode(dst_path, FileCheckConst.DATA_FILE_AUTHORITY) + + def get_dump_data_path(dump_dir): """ Function Description: @@ -464,14 +475,14 @@ def md5_find(data): def task_dumppath_get(input_param): - npu_json_path = input_param.get("npu_json_path", None) - bench_json_path = input_param.get("bench_json_path", None) - if not npu_json_path or not bench_json_path: + npu_path = input_param.get("npu_path", None) + bench_path = 
input_param.get("bench_path", None) + if not npu_path or not bench_path: logger.error(f"Please check the json path is valid.") raise CompareException(CompareException.INVALID_PATH_ERROR) - with FileOpen(npu_json_path, 'r') as npu_f: + with FileOpen(npu_path, 'r') as npu_f: npu_json_data = json.load(npu_f) - with FileOpen(bench_json_path, 'r') as bench_f: + with FileOpen(bench_path, 'r') as bench_f: bench_json_data = json.load(bench_f) if npu_json_data['task'] != bench_json_data['task']: logger.error(f"Please check the dump task is consistent.") @@ -488,8 +499,8 @@ def task_dumppath_get(input_param): else: logger.error(f"Compare is not required for overflow_check or free_benchmark.") raise CompareException(CompareException.INVALID_TASK_ERROR) - input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir'] - input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir'] + input_param['npu_dump_data_dir'] = os.path.join(os.path.dirname(npu_path), Const.DUMP_TENSOR_DATA) + input_param['bench_dump_data_dir'] = os.path.join(os.path.dirname(bench_path), Const.DUMP_TENSOR_DATA) return summary_compare, md5_compare @@ -515,10 +526,19 @@ def write_csv(data, filepath): def load_npy(filepath): - filepath = os.path.realpath(filepath) check_file_or_directory_path(filepath) try: npy = np.load(filepath) except Exception as e: raise RuntimeError(f"load npy file {filepath} failed") from e return npy + + +def save_npy(data, filepath): + filepath = os.path.realpath(filepath) + check_path_before_create(filepath) + try: + npy = np.save(filepath, data) + except Exception as e: + raise RuntimeError(f"save npy file {filepath} failed") from e + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) diff --git a/debug/accuracy_tools/msprobe/core/common_config.py b/debug/accuracy_tools/msprobe/core/common_config.py index d6c15e101e7f3ca7dc41025fbf73671980e6dea2..688734be8acbfae23ff453fea67787a577effa13 100644 --- a/debug/accuracy_tools/msprobe/core/common_config.py +++ 
b/debug/accuracy_tools/msprobe/core/common_config.py @@ -50,6 +50,14 @@ class BaseConfig: self.summary_mode = json_config.get("summary_mode") self.overflow_nums = json_config.get("overflow_nums") self.check_mode = json_config.get("check_mode") + self.fuzz_device = json_config.get("fuzz_device") + self.pert_mode = json_config.get("pert_mode") + self.handler_type = json_config.get("handler_type") + self.fuzz_level = json_config.get("fuzz_level") + self.fuzz_stage = json_config.get("fuzz_stage") + self.if_preheat = json_config.get("if_preheat") + self.preheat_step = json_config.get("preheat_step") + self.max_sample = json_config.get("max_sample") def check_config(self): if self.scope is not None and not isinstance(self.scope, list): diff --git a/debug/accuracy_tools/msprobe/core/compare/acc_compare.py b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2016247ff2bc0ddcf606e39fc801fbc1f0b4e6 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/acc_compare.py @@ -0,0 +1,131 @@ +import multiprocessing +import pandas as pd +from msprobe.core.common.const import CompareConst +from msprobe.core.compare.npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, \ + get_error_message +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.compare.utils import read_op, merge_tensor, CompareException +from msprobe.core.compare.multiprocessing_compute import _handle_multi_process +from msprobe.core.common.log import logger +from msprobe.core.compare.check import check_graph_mode, check_struct_match, fuzzy_check_op + +class Comparator: + + def __init__(self): + pass + + @classmethod + def make_result_table(cls,result,md5_compare,summary_compare,stack_mode): + header = [] + if md5_compare: + header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] + elif summary_compare: + header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] + else: + header = 
CompareConst.COMPARE_RESULT_HEADER[:] + + all_mode_bool = not (summary_compare or md5_compare) + if stack_mode: + if all_mode_bool: + header.append(CompareConst.STACK) + header.append(CompareConst.DATA_NAME) + else: + header.append(CompareConst.STACK) + else: + if all_mode_bool: + for row in result: + del row[-2] + header.append(CompareConst.DATA_NAME) + else: + for row in result: + del row[-1] + result_df = pd.DataFrame(result, columns=header) + return result_df + + @classmethod + def gen_merge_list(self,json_data,op_name,stack_json_data,summary_compare,md5_compare): + op_data = json_data['data'][op_name] + op_parsed_list = read_op(op_data, op_name) + if op_name in stack_json_data: + op_parsed_list.append({'full_op_name': op_name, 'full_info': stack_json_data[op_name]}) + else: + op_parsed_list.append({'full_op_name': op_name, 'full_info': None}) + + merge_list = merge_tensor(op_parsed_list, summary_compare, md5_compare) + return merge_list + + def check_op(self, npu_dict, bench_dict, fuzzy_match): + a_op_name = npu_dict["op_name"] + b_op_name = bench_dict["op_name"] + graph_mode = check_graph_mode(a_op_name[0], b_op_name[0]) + + frame_name = getattr(self,"frame_name") + if frame_name == "PTComparator": + from msprobe.pytorch.compare.match import graph_mapping + if graph_mode: + return graph_mapping.match(a_op_name[0], b_op_name[0]) + struct_match = check_struct_match(npu_dict, bench_dict) + if not fuzzy_match: + return a_op_name == b_op_name and struct_match + is_match = True + try: + is_match = fuzzy_check_op(a_op_name, b_op_name) + except Exception as err: + logger.warning("%s and %s can not fuzzy match." 
% (a_op_name, b_op_name)) + is_match = False + return is_match and struct_match + + def match_op(self, npu_queue, bench_queue, fuzzy_match): + for b_index, b_op in enumerate(bench_queue[0: -1]): + if self.check_op(npu_queue[-1], b_op, fuzzy_match): + return len(npu_queue) - 1, b_index + if self.check_op(npu_queue[-1], bench_queue[-1], fuzzy_match): + return len(npu_queue) - 1, len(bench_queue) - 1 + for n_index, n_op in enumerate(npu_queue[0: -1]): + if self.check_op(n_op, bench_queue[-1], fuzzy_match): + return n_index, len(bench_queue) - 1 + return -1, -1 + + def compare_by_op(self,op_name, op_name_mapping_dict, input_parma): + npu_bench_name_list = op_name_mapping_dict[op_name] + data_name = npu_bench_name_list[1] + error_file, relative_err, error_flag = None, None, False + if data_name == '-1' or data_name == -1: # 没有真实数据路径 + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + else: + try: + read_npy_data = getattr(self,"read_npy_data") + n_value = read_npy_data(input_parma.get("npu_dump_data_dir"), npu_bench_name_list[0]) + b_value = read_npy_data(input_parma.get("bench_dump_data_dir"), npu_bench_name_list[1]) + except IOError as error: + error_file = error.filename + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + except FileCheckException: + error_file = data_name + n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE + error_flag = True + + n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag) + if not error_flag: + relative_err = get_relative_err(n_value, b_value) + n_value, b_value = reshape_value(n_value, b_value) + + err_msg = get_error_message(n_value, b_value, op_name, error_flag, error_file=error_file) + result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err) + + if npu_bench_name_list[0] != npu_bench_name_list[1]: + err_msg += " Fuzzy matching data, the comparison accuracy may be affected." 
+ result_list.append(err_msg) + return result_list + + def _do_multi_process(self,input_parma, result_df): + try: + compare_ops = getattr(self,"compare_ops") + result_df = _handle_multi_process(compare_ops, input_parma, result_df, multiprocessing.Manager().RLock()) + return result_df + except ValueError as e: + logger.error('result dataframe is not found.') + raise CompareException(CompareException.INVALID_DATA_ERROR) from e + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/compare/check.py b/debug/accuracy_tools/msprobe/core/compare/check.py new file mode 100644 index 0000000000000000000000000000000000000000..c243c0910d5cc61b5a713503fbfdddfc6a359194 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/compare/check.py @@ -0,0 +1,72 @@ +from msprobe.core.common.log import logger +from msprobe.core.compare.utils import rename_api + + +def check_struct_match(npu_dict, bench_dict): + npu_struct_in = npu_dict.get("input_struct") + bench_struct_in = bench_dict.get("input_struct") + npu_struct_out = npu_dict.get("output_struct") + bench_struct_out = bench_dict.get("output_struct") + is_match = npu_struct_in == bench_struct_in and npu_struct_out == bench_struct_out + if not is_match: + if len(npu_struct_in) == 0 or len(bench_struct_in) == 0 or len(npu_struct_in) != len(bench_struct_in): + return False + struct_in_is_match = check_type_shape_match(npu_struct_in, bench_struct_in) + struct_out_is_match = check_type_shape_match(npu_struct_out, bench_struct_out) + is_match = struct_in_is_match and struct_out_is_match + return is_match + + +def check_type_shape_match(npu_struct, bench_struct): + shape_type_match = False + for npu_type_shape, bench_type_shape in zip(npu_struct, bench_struct): + npu_type = npu_type_shape[0] + npu_shape = npu_type_shape[1] + bench_type = bench_type_shape[0] + bench_shape = bench_type_shape[1] + shape_match = npu_shape == bench_shape + type_match = npu_type == bench_type + if not type_match: + ms_type=[["Float16", 
"Float32"], ["Float32", "Float16"],["Float16", "BFloat16"],["BFloat16", "Float16"]] + torch_type=[["torch.float16", "torch.float32"], ["torch.float32", "torch.float16"], + ["torch.float16", "torch.bfloat16"], ["torch.bfloat16", "torch.float16"]] + if ([npu_type, bench_type] in ms_type)or ([npu_type, bench_type] in torch_type): + type_match = True + else: + type_match = False + shape_type_match = shape_match and type_match + if not shape_type_match: + return False + return shape_type_match + + +def check_graph_mode(a_op_name, b_op_name): + if "Aten" in a_op_name and "Aten" not in b_op_name: + return True + if "Aten" not in a_op_name and "Aten" in b_op_name: + return True + return False + + +def fuzzy_check_op(npu_name_list, bench_name_list): + if len(npu_name_list) == 0 or len(bench_name_list) == 0 or len(npu_name_list) != len(bench_name_list): + return False + is_match = True + for npu_name, bench_name in zip(npu_name_list, bench_name_list): + is_match = fuzzy_check_name(npu_name, bench_name) + if not is_match: + break + return is_match + + +def fuzzy_check_name(npu_name, bench_name): + if "forward" in npu_name and "forward" in bench_name: + is_match = rename_api(npu_name, "forward") == rename_api(bench_name, "forward") + elif "backward" in npu_name and "backward" in bench_name: + is_match = rename_api(npu_name, "backward") == rename_api(bench_name, "backward") + else: + is_match = npu_name == bench_name + return is_match + + + diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py b/debug/accuracy_tools/msprobe/core/compare/highlight.py similarity index 44% rename from debug/accuracy_tools/msprobe/pytorch/compare/highlight.py rename to debug/accuracy_tools/msprobe/core/compare/highlight.py index 82f0022f8b5d4a0c6472b749d4937bfe39ef8a86..ef35fd06165a212af0bd43bd57d1877aec8983d4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/msprobe/core/compare/highlight.py @@ -1,8 +1,13 @@ import math import abc 
+from collections import namedtuple import numpy as np -from msprobe.core.common.utils import get_header_index -from msprobe.core.common.const import CompareConst +import openpyxl +from openpyxl.styles import PatternFill +from msprobe.core.common.utils import get_header_index, CompareException +from msprobe.core.common.log import logger +from msprobe.core.common.file_check import change_mode +from msprobe.core.common.const import CompareConst, FileCheckConst class HighlightCheck(abc.ABC): @@ -98,3 +103,125 @@ class HighlightRules: "check_order_magnitude": CheckOrderMagnitude(), "check_max_relative_diff": CheckMaxRelativeDiff(), } + + +def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False): + """找到单个API中需要高亮的行""" + if md5_compare: + return + npu_max_index = get_header_index('NPU max', summary_compare) + bench_max_index = get_header_index('Bench max', summary_compare) + max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) + + red_lines, yellow_lines = [], [] + LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) + ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) + ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) + color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) + + # 对单行API的输入或输出进行误差判断 + for i, line in enumerate(result): + num = last_len + i + line_info = LineInfo(line_data=line, num_pointer=num) + for rule in HighlightRules.basic_rules.values(): + rule.apply(line_info, color_columns, summary_compare) + + # 对API的输出与输入比较,进行误差判断 + for n, api_out in enumerate(result[n_num_input:len(result)]): + num = last_len + n_num_input + n + if num in red_lines: + continue + if not isinstance(api_out[npu_max_index], (float, int)) \ + or not isinstance(api_out[bench_max_index], (float, int)) \ + or not isinstance(api_out[max_diff_index], (float, int)): + continue + for _, api_in in enumerate(result[0:n_num_input]): + if 
not isinstance(api_in[npu_max_index], (float, int)) \ + or not isinstance(api_in[bench_max_index], (float, int)) \ + or not isinstance(api_in[max_diff_index], (float, int)): + continue + + api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num) + if summary_compare: + for rule in HighlightRules.summary_compare_rules.values(): + rule.apply(api_info, color_columns, summary_compare) + else: + for rule in HighlightRules.compare_rules.values(): + rule.apply(api_info, color_columns, summary_compare) + + highlight_dict.get('red_rows', []).extend(list(set(red_lines))) + highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines))) + + +def get_name_and_state(name): + """Get api/module name and state""" + if "input" in name: + api_name = name.split("input")[0] + state = "input" + else: + api_name = name.split("output")[0] + state = "output" + return api_name, state + + +def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): + """将dataframe根据API分组,并找到有误差的算子用于高亮""" + result = result_df.values + start, input_num, output_num, end = 0, 0, 0, len(result_df) + last_api_name, last_state = None, None + num, last_len = 0, 0 + for res_i in result: + api_name, state = get_name_and_state(res_i[0]) + if last_api_name: + if api_name == last_api_name: + if state == last_state: + num += 1 + else: + input_num = num + num, last_state = 1, state + else: + output_num = num + find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, + summary_compare, md5_compare) + num, last_api_name, last_state = 1, api_name, state + start += input_num + output_num + input_num, output_num = 1, 0 + else: + num, last_api_name, last_state = 1, api_name, state + if state: + if state == "input": + input_num = num + else: + output_num = num + find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare) + + +def 
import multiprocessing
from dataclasses import dataclass

import pandas as pd

from msprobe.core.common.log import logger
from msprobe.core.common.utils import CompareException
from msprobe.core.common.const import CompareConst


def _handle_multi_process(func, input_parma, result_df, lock):
    """Run `func` over chunks of `result_df` in a process pool and merge the results.

    Args:
        func: worker callable invoked as func(idx, op_name_mapping_dict, df_chunk, lock, input_parma)
        input_parma: dict of comparison parameters, forwarded unchanged to every worker
        result_df: full comparison result DataFrame to be processed
        lock: multiprocessing lock shared by all workers

    Returns:
        pd.DataFrame: concatenation of all worker result chunks (index reset).
    """
    # use roughly half of the available cores (at least 1, since cpu_count() >= 1)
    process_num = int((multiprocessing.cpu_count() + 1) / 2)
    op_name_mapping_dict = read_dump_data(result_df)

    df_chunk_size = len(result_df) // process_num
    if df_chunk_size > 0:
        df_chunks = [result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)]
    else:
        # fewer rows than processes: hand the whole frame to a single worker
        df_chunks = [result_df]

    results = []
    pool = multiprocessing.Pool(process_num)

    def err_call(args):
        logger.error('multiprocess compare failed! Reason: {}'.format(args))
        try:
            pool.terminate()
        except OSError:
            logger.error("pool terminate failed")

    try:
        for process_idx, df_chunk in enumerate(df_chunks):
            idx = df_chunk_size * process_idx
            result = pool.apply_async(func,
                                      args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma),
                                      error_callback=err_call)
            results.append(result)
        final_results = [r.get() for r in results]
    finally:
        # always reap the worker processes, even when a worker raised and
        # r.get() re-raises here -- otherwise the pool leaks child processes
        pool.close()
        pool.join()
    return pd.concat(final_results, ignore_index=True)


def read_dump_data(result_df):
    """Build {op_name: [dump_tensor, dump_tensor]} from the first and last columns of result_df.

    Raises:
        CompareException: INVALID_DATA_ERROR when the frame is malformed,
            INDEX_OUT_OF_BOUNDS_ERROR when its columns cannot be accessed.
    """
    try:
        npu_dump_name_list = result_df.iloc[0:, 0].tolist()
        npu_dump_tensor_list = result_df.iloc[0:, -1].tolist()
        op_name_mapping_dict = {}
        for npu_dump_name, npu_dump_tensor in zip(npu_dump_name_list, npu_dump_tensor_list):
            # both slots hold the same tensor name; slot 1 may be remapped by fuzzy matching later
            op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor]
        return op_name_mapping_dict
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e


@dataclass
class ComparisonResult:
    # per-op comparison metrics produced by one worker, parallel lists indexed by row
    cos_result: list
    max_err_result: list
    max_relative_err_result: list
    err_msgs: list
    one_thousand_err_ratio_result: list
    five_thousand_err_ratio_result: list


def _save_cmp_result(offset, result: ComparisonResult, result_df, lock):
    """
    Save comparison results into the result DataFrame with thread safety.
    Args:
        offset: offset for index
        result: data struct of ComparisonResult
        result_df: result of DataFrame
        lock: thread lock

    Returns:
        comparison results in DataFrame
    """
    lock.acquire()
    try:
        for i, _ in enumerate(result.cos_result):
            process_index = i + offset
            result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i]
            result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i]
            result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i]
            result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i]
            result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(
                result.cos_result[i], result.max_err_result[i])
            result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = \
                result.one_thousand_err_ratio_result[i]
            result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = \
                result.five_thousand_err_ratio_result[i]
        return result_df
    except ValueError as e:
        logger.error('result dataframe is not found.')
        raise CompareException(CompareException.INVALID_DATA_ERROR) from e
    except IndexError as e:
        logger.error('result dataframe elements can not be access.')
        raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e
    finally:
        lock.release()


def check_accuracy(cos, max_abs_err):
    """Map cosine similarity and max abs error onto an accuracy verdict constant.

    Sentinel inputs (SHAPE_UNMATCH, NONE, "N/A") short-circuit before the
    numeric thresholds are applied.
    """
    if cos == CompareConst.SHAPE_UNMATCH:
        return CompareConst.ACCURACY_CHECK_UNMATCH
    if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE:
        return CompareConst.NONE
    if cos == "N/A" or max_abs_err == "N/A":
        return CompareConst.ACCURACY_CHECK_NO
    try:
        cos, max_abs_err = float(cos), float(max_abs_err)
    except ValueError:
        logger.warning("Cosine or MaxAbsErr can not get float value.")
        return CompareConst.NONE
    if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD:
        return CompareConst.ACCURACY_CHECK_NO
    return CompareConst.ACCURACY_CHECK_YES
import os
import re

import numpy as np

from msprobe.core.common.const import Const, CompareConst
from msprobe.core.common.utils import CompareException, check_file_or_directory_path, \
    check_regex_prefix_format_valid, logger


def extract_json(dirname, stack_json=False):
    """Return the path of a dump json (or stack json) inside `dirname`.

    "construct.json" is always skipped. With stack_json=False the scan stops at
    the first json whose path does not contain 'stack'; with stack_json=True it
    stops at the first one that does. If the loop finishes without an early
    break, the last json seen is returned.

    Raises:
        CompareException: NO_DUMP_FILE_ERROR when no json file is found.
    """
    json_path = ''
    for fname in os.listdir(dirname):
        if fname == "construct.json":
            continue
        full_path = os.path.join(dirname, fname)
        if full_path.endswith('.json'):
            json_path = full_path
            if not stack_json and 'stack' not in json_path:
                break
            if stack_json and 'stack' in json_path:
                break

    # Provide robustness on invalid directory inputs
    if not json_path:
        logger.error(f'No file is found in dump dir {dirname}. ')
        raise CompareException(CompareException.NO_DUMP_FILE_ERROR)
    return json_path


def check_and_return_dir_contents(dump_dir, prefix):
    """
    check the given dump dir and validate files in dump dir by using the given prefix patterns to build a
    pattern: ^{prefix}(?:0|[1-9][0-9]*)?$

    Args:
        dump_dir (str): dump dir
        prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only

    Returns:
        content [list]: dir contents
    Raises:
        CompareException: invalid path
        ValueError: prefix not match the patterns

    """
    check_regex_prefix_format_valid(prefix)
    check_file_or_directory_path(dump_dir, True)
    contents = os.listdir(dump_dir)
    # bugfix: the previous pattern `[0-9][1-9]*` rejected any id containing a
    # zero after the first digit (e.g. rank10); `[1-9][0-9]*` accepts every
    # decimal id without a leading zero.
    pattern = re.compile(rf'^{prefix}(?:0|[1-9][0-9]*)?$')
    for name in contents:
        if not pattern.match(name):
            logger.error(
                f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump "
                f"output. Please check and delete irrelevant files in {dump_dir} and try again."
            )
            raise CompareException(CompareException.INVALID_PATH_ERROR)
    return contents


def read_op(op_data, op_name):
    """Flatten one op's dump record into a list of parsed tensor-info dicts.

    Forward ops contribute input_args, input_kwargs and output entries;
    backward ops contribute input and output entries.
    """
    # bugfix: copy instead of aliasing Const.DEFAULT_LIST -- the `+=` below
    # mutates the list in place and previously corrupted the shared constant
    # (e.g. forward op with input_kwargs but no input_args).
    op_parsed_list = list(Const.DEFAULT_LIST)
    if 'forward' in op_name:
        if 'input_args' in op_data:
            input_item = op_data['input_args']
            input_parsed_list = op_item_parse(input_item, op_name + '_input', None)
            op_parsed_list = input_parsed_list.copy()
            input_parsed_list.clear()
        if 'input_kwargs' in op_data:
            kwargs_item = op_data['input_kwargs']
            if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list):
                kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '_input', None)
                op_parsed_list += kwarg_parsed_list
                kwarg_parsed_list.clear()
            elif kwargs_item:
                for kwarg in kwargs_item:
                    kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '_input.' + kwarg, None)
                    op_parsed_list += kwarg_parsed_list
                    kwarg_parsed_list.clear()
        if 'output' in op_data:
            output_item = op_data['output']
            output_parsed_list = op_item_parse(output_item, op_name + '_output', None)
            op_parsed_list += output_parsed_list
            output_parsed_list.clear()
    if 'backward' in op_name:
        if 'input' in op_data:
            input_item = op_data['input']
            input_parsed_list = op_item_parse(input_item, op_name + '_input', None)
            op_parsed_list = input_parsed_list.copy()
            input_parsed_list.clear()
        if 'output' in op_data:
            output_item = op_data['output']
            output_parsed_list = op_item_parse(output_item, op_name + '_output', None)
            op_parsed_list += output_parsed_list
            output_parsed_list.clear()
    return op_parsed_list


def op_item_parse(item, op_name, index, item_list=None, top_bool=True):
    """Recursively parse one dumped value (dict / list / scalar) into flat entries.

    Args:
        item: dumped value; may be a stats dict, a typed dict, a nested dict,
            a list, or None.
        op_name: accumulated dotted name prefix for the entry.
        index: positional index within a list parent, or None at a dict level.
        item_list: accumulator list; created on the first call.
        top_bool: True only at the outermost call, controls the placeholder
            name used for empty/None leaves.

    Returns:
        list: the accumulated entries (same object as item_list).
    """
    if item_list is None:
        item_list = []
    if item is None or (isinstance(item, dict) and not item):
        # empty leaf: emit a placeholder entry so row counts stay aligned
        if not top_bool:
            tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None,
                   'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'}
        else:
            tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None,
                   'shape': None, 'md5': None, 'data_name': '-1'}
        item_list.append(tmp)
        return item_list
    if index is None:
        if isinstance(item, dict):
            full_op_name = op_name + '.0'
        else:
            full_op_name = op_name
    else:
        full_op_name = op_name + Const.SEP + str(index)
    if isinstance(item, dict):
        if 'type' not in item:
            # untyped dict: recurse into each keyword entry
            for kwarg in item:
                kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None)
                item_list += kwarg_parsed_list
                kwarg_parsed_list.clear()
        elif 'dtype' in item:
            # already a full stats record; just attach its resolved name
            parsed_item = item
            parsed_item['full_op_name'] = full_op_name
            item_list.append(parsed_item)
        elif 'type' in item:
            parsed_item = {}
            if item['type'] == 'torch.Size':
                parsed_item['full_op_name'] = full_op_name
                parsed_item['dtype'] = 'torch.Size'
                parsed_item['shape'] = str(item['value'])
                parsed_item['md5'] = None
                parsed_item['Max'] = None
                parsed_item['Min'] = None
                parsed_item['Mean'] = None
                parsed_item['Norm'] = None
                parsed_item['data_name'] = '-1'
                item_list.append(parsed_item)
            elif item['type'] == 'slice':
                parsed_item['full_op_name'] = full_op_name
                parsed_item['dtype'] = 'slice'
                parsed_item['shape'] = str(np.shape(np.array(item['value'])))
                parsed_item['md5'] = None
                parsed_item['Max'] = None
                parsed_item['Min'] = None
                parsed_item['Mean'] = None
                parsed_item['Norm'] = None
                parsed_item['data_name'] = '-1'
                item_list.append(parsed_item)
            else:
                # plain scalar: stats collapse to the value itself
                parsed_item['full_op_name'] = full_op_name
                parsed_item['dtype'] = str(type(item['value']))
                parsed_item['shape'] = '[]'
                parsed_item['md5'] = None
                parsed_item['Max'] = item['value']
                parsed_item['Min'] = item['value']
                parsed_item['Mean'] = item['value']
                parsed_item['Norm'] = item['value']
                parsed_item['data_name'] = '-1'
                item_list.append(parsed_item)
        else:
            resolve_api_special_parameters(item, full_op_name, item_list)
    else:
        for j, item_spec in enumerate(item):
            op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False)
    return item_list


def resolve_api_special_parameters(data_dict, full_op_name, item_list):
    """
    Function Description:
        Parse data in the following format, a special form of API parameters:
        {
            "last_hidden_state": {
                "type": "torch.Tensor",
                "dtype": "torch.bfloat16",
                ...
            },
            "loss": {
                "type": "torch.Tensor",
                "dtype": "torch.float32",
                ...
            }
        }
    Parameter:
        data_dict: the data as a dict
        full_op_name: full name string of the parameter
        item_list: accumulator of parameter-info entries
    """
    for key, value in data_dict.items():
        if isinstance(value, dict):
            parsed_item = value
            # splice the key in just before the trailing index segment
            parts = full_op_name.split(".")
            parts.insert(-1, key)
            full_op_name_new = ".".join(parts)
            parsed_item['full_op_name'] = full_op_name_new
            item_list.append(parsed_item)
[result_df.iloc[i:i + df_chunk_size] for i in range(0, len(result_df), df_chunk_size)] - else: - df_chunks = [result_df] - - results = [] - pool = multiprocessing.Pool(process_num) - - def err_call(args): - logger.error('multiprocess compare failed! Reason: {}'.format(args)) - try: - pool.terminate() - except OSError as e: - logger.error("pool terminate failed") - - for process_idx, df_chunk in enumerate(df_chunks): - idx = df_chunk_size * process_idx - result = pool.apply_async(func, - args=(idx, op_name_mapping_dict, df_chunk, lock, input_parma), - error_callback=err_call) - results.append(result) - final_results = [r.get() for r in results] - pool.close() - pool.join() - return pd.concat(final_results, ignore_index=True) - - -def compare_ops(idx, dump_path_dict, result_df, lock, input_parma): - cos_result = [] - max_err_result = [] - max_relative_err_result = [] - err_mess = [] - one_thousand_err_ratio_result = [] - five_thousand_err_ratio_result = [] - is_print_compare_log = input_parma.get("is_print_compare_log") - for i in range(len(result_df)): - op_name = result_df.iloc[i, 0] - if is_print_compare_log: - logger.info("start compare: {}".format(op_name)) - cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = compare_by_op( - op_name, dump_path_dict, input_parma) - if is_print_compare_log: - logger.info( - "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, " - "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg, - one_thousand_err_ratio, five_thousand_err_ratio)) - cos_result.append(cos_sim) - max_err_result.append(max_abs_err) - max_relative_err_result.append(max_relative_err) - err_mess.append(err_msg) - one_thousand_err_ratio_result.append(one_thousand_err_ratio) - five_thousand_err_ratio_result.append(five_thousand_err_ratio) - - cr = ComparisonResult( - cos_result=cos_result, - max_err_result=max_err_result, - 
max_relative_err_result=max_relative_err_result, - err_msgs=err_mess, - one_thousand_err_ratio_result=one_thousand_err_ratio_result, - five_thousand_err_ratio_result=five_thousand_err_ratio_result - ) - - return _save_cmp_result(idx, cr, result_df, lock) - - -@dataclass -class ComparisonResult: - cos_result: list - max_err_result: list - max_relative_err_result: list - err_msgs: list - one_thousand_err_ratio_result: list - five_thousand_err_ratio_result: list - - -def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): - """ - Save comparison results into the result DataFrame with thread safety. - Args: - offset: offset for index - result: data struct of ComparisonResult - result_df: result of DataFrame - lock: thread lock - - Returns: - comparison results in DataFrame - """ - - lock.acquire() - try: - for i, _ in enumerate(result.cos_result): - process_index = i + offset - result_df.loc[process_index, CompareConst.COSINE] = result.cos_result[i] - result_df.loc[process_index, CompareConst.MAX_ABS_ERR] = result.max_err_result[i] - result_df.loc[process_index, CompareConst.MAX_RELATIVE_ERR] = result.max_relative_err_result[i] - result_df.loc[process_index, CompareConst.ERROR_MESSAGE] = result.err_msgs[i] - result_df.loc[process_index, CompareConst.ACCURACY] = check_accuracy(result.cos_result[i], result.max_err_result[i]) - result_df.loc[process_index, CompareConst.ONE_THOUSANDTH_ERR_RATIO] = result.one_thousand_err_ratio_result[i] - result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[i] - return result_df - except ValueError as e: - logger.error('result dataframe is not found.') - raise CompareException(CompareException.INVALID_DATA_ERROR) from e - except IndexError as e: - logger.error('result dataframe elements can not be access.') - raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e - finally: - lock.release() - - -def check_accuracy(cos, max_abs_err): - if cos == 
CompareConst.SHAPE_UNMATCH: - return CompareConst.ACCURACY_CHECK_UNMATCH - if cos == CompareConst.NONE or max_abs_err == CompareConst.NONE: - return CompareConst.NONE - if cos == "N/A" or max_abs_err == "N/A": - return CompareConst.ACCURACY_CHECK_NO - try: - cos, max_abs_err = float(cos), float(max_abs_err) - except ValueError: - logger.warning("Cosine or MaxAbsErr can not get float value.") - return CompareConst.NONE - if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD: - return CompareConst.ACCURACY_CHECK_NO - if cos < CompareConst.COS_MAX_THRESHOLD or max_abs_err > CompareConst.MAX_ABS_ERR_MAX_THRESHOLD: - return CompareConst.ACCURACY_CHECK_NO - return CompareConst.ACCURACY_CHECK_YES - - -def read_npy_data(dir_path, file_name): - data_path = os.path.join(dir_path, file_name) - path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, - FileCheckConst.PT_SUFFIX, False) - data_path = path_checker.common_check() - data_value = torch.load(data_path, map_location=torch.device('cpu')).detach() # detach for less memory - if data_value.dtype == torch.bfloat16: - data_value = data_value.to(torch.float32) - data_value = data_value.numpy() - return data_value - - -def compare_by_op(op_name, op_name_mapping_dict, input_parma): - npu_bench_name_list = op_name_mapping_dict[op_name] - data_name = npu_bench_name_list[1] - error_file, relative_err, error_flag = None, None, False - if data_name == '-1' or data_name == -1: # 没有真实数据路径 - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - else: - try: - n_value = read_npy_data(input_parma.get("npu_dump_data_dir"), npu_bench_name_list[0]) - b_value = read_npy_data(input_parma.get("bench_dump_data_dir"), npu_bench_name_list[1]) - except IOError as error: - error_file = error.filename - n_value, b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - except FileCheckException: - error_file = data_name - n_value, 
b_value = CompareConst.READ_NONE, CompareConst.READ_NONE - error_flag = True - - n_value, b_value, error_flag = get_error_type(n_value, b_value, error_flag) - if not error_flag: - relative_err = get_relative_err(n_value, b_value) - n_value, b_value = reshape_value(n_value, b_value) - - err_msg = get_error_message(n_value, b_value, op_name, error_flag, error_file=error_file) - result_list, err_msg = compare_ops_apply(n_value, b_value, error_flag, err_msg, relative_err=relative_err) - - if npu_bench_name_list[0] != npu_bench_name_list[1]: - err_msg += " Fuzzy matching data, the comparison accuracy may be affected." - result_list.append(err_msg) - return result_list - - -def handle_inf_nan(n_value, b_value): - n_inf = np.isinf(n_value) - b_inf = np.isinf(b_value) - n_nan = np.isnan(n_value) - b_nan = np.isnan(b_value) - - # merge boolean expressions - any_inf = np.any(n_inf) or np.any(b_inf) - any_nan = np.any(n_nan) or np.any(b_nan) - if any_inf or any_nan: - if np.array_equal(n_inf, b_inf) and np.array_equal(n_nan, b_nan): - n_value[n_inf] = 0 - b_value[b_inf] = 0 - n_value[n_nan] = 0 - b_value[b_nan] = 0 - else: - return CompareConst.NAN, CompareConst.NAN - return n_value, b_value - - -def find_error_rows(result, last_len, n_num_input, highlight_dict, summary_compare=False, md5_compare=False): - """找到单个API中需要高亮的行""" - if md5_compare: - return - npu_max_index = get_header_index('NPU max', summary_compare) - bench_max_index = get_header_index('Bench max', summary_compare) - max_diff_index = get_header_index('Max diff' if summary_compare else 'MaxAbsErr', summary_compare) - - red_lines, yellow_lines = [], [] - LineInfo = namedtuple('LineInfo', ['line_data', 'num_pointer']) - ApiInfo = namedtuple('ApiInfo', ['api_input', 'api_output', 'num_pointer']) - ColorColumns = namedtuple('ColorColumns', ['red', 'yellow']) - color_columns = ColorColumns(red=red_lines, yellow=yellow_lines) - - # 对单行API的输入或输出进行误差判断 - for i, line in enumerate(result): - num = last_len + i - 
line_info = LineInfo(line_data=line, num_pointer=num) - for rule in HighlightRules.basic_rules.values(): - rule.apply(line_info, color_columns, summary_compare) - - # 对API的输出与输入比较,进行误差判断 - for n, api_out in enumerate(result[n_num_input:len(result)]): - num = last_len + n_num_input + n - if num in red_lines: - continue - if not isinstance(api_out[npu_max_index], (float, int)) \ - or not isinstance(api_out[bench_max_index], (float, int)) \ - or not isinstance(api_out[max_diff_index], (float, int)): - continue - for _, api_in in enumerate(result[0:n_num_input]): - if not isinstance(api_in[npu_max_index], (float, int)) \ - or not isinstance(api_in[bench_max_index], (float, int)) \ - or not isinstance(api_in[max_diff_index], (float, int)): - continue - - api_info = ApiInfo(api_input=api_in, api_output=api_out, num_pointer=num) - if summary_compare: - for rule in HighlightRules.summary_compare_rules.values(): - rule.apply(api_info, color_columns, summary_compare) - else: - for rule in HighlightRules.compare_rules.values(): - rule.apply(api_info, color_columns, summary_compare) - - highlight_dict.get('red_rows', []).extend(list(set(red_lines))) - highlight_dict.get('yellow_rows', []).extend(list(set(yellow_lines) - set(red_lines))) - - -def get_name_and_state(name): - """Get api/module name and state""" - if "input" in name: - api_name = name.split("input")[0] - state = "input" - else: - api_name = name.split("output")[0] - state = "output" - return api_name, state - - -def find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare): - """将dataframe根据API分组,并找到有误差的算子用于高亮""" - result = result_df.values - start, input_num, output_num, end = 0, 0, 0, len(result_df) - last_api_name, last_state = None, None - num, last_len = 0, 0 - for res_i in result: - api_name, state = get_name_and_state(res_i[0]) - if last_api_name: - if api_name == last_api_name: - if state == last_state: - num += 1 - else: - input_num = num - num, last_state = 1, state - else: 
- output_num = num - find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, - summary_compare, md5_compare) - num, last_api_name, last_state = 1, api_name, state - start += input_num + output_num - input_num, output_num = 1, 0 - else: - num, last_api_name, last_state = 1, api_name, state - if state: - if state == "input": - input_num = num - else: - output_num = num - find_error_rows(result[start:start + input_num + output_num], start, input_num, highlight_dict, summary_compare, md5_compare) - - -def highlight_rows_xlsx(result_df, highlight_dict, file_path): - """Write and highlight results in Excel""" - logger.info('Compare result is %s' % file_path) - - wb = openpyxl.Workbook() - ws = wb.active - - # write header - for j, col_name in enumerate(result_df.columns, start=1): - ws.cell(row=1, column=j, value=col_name) - - for i, row in enumerate(result_df.iterrows(), start=2): - for j, value in enumerate(row[1], start=1): - if not isinstance(value, (float, int)): - value = f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else str(value) - ws.cell(row=i, column=j, value=f'{str(value)}\t' if str(value) in ('inf', '-inf', 'nan') else value) - - if (i - 2) in highlight_dict['red_rows']: - ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.RED, - end_color=CompareConst.RED, fill_type="solid") - elif (i - 2) in highlight_dict['yellow_rows']: - ws.cell(row=i, column=j).fill = PatternFill(start_color=CompareConst.YELLOW, - end_color=CompareConst.YELLOW, fill_type="solid") - try: - wb.save(file_path) - except Exception as e: - logger.error('Save result file failed') - raise CompareException(CompareException.WRITE_FILE_ERROR) from e - change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - - -def compare(input_parma, output_path, stack_mode=False, auto_analyze=True, - fuzzy_match=False): - try: - summary_compare, md5_compare = task_dumppath_get(input_parma) - check_configuration_param(stack_mode, 
auto_analyze, fuzzy_match) - create_directory(output_path) - check_compare_param(input_parma, output_path, summary_compare, md5_compare) - except (CompareException, FileCheckException) as error: - logger.error('Compare failed. Please check the arguments and do it again!') - sys.exit(error.code) - compare_core(input_parma, output_path, stack_mode=stack_mode, - auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, - md5_compare=md5_compare) - - -def compare_core(input_parma, output_path, **kwargs): - """ - Compares data from multiple JSON files and generates a comparison report. - - Args: - input_parma (dict): A dictionary containing paths to JSON files ("npu_json_path", "bench_json_path", - "stack_json_path"). - output_path (str): The path where the output Excel report will be saved. - **kwargs: Additional keyword arguments including: - - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. - - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. - - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. - - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. - - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False. - - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False. - - Returns: - """ - # get kwargs or set default value - stack_mode = kwargs.get('stack_mode', False) - auto_analyze = kwargs.get('auto_analyze', True) - suffix = kwargs.get('suffix', '') - fuzzy_match = kwargs.get('fuzzy_match', False) - summary_compare = kwargs.get('summary_compare', False) - md5_compare = kwargs.get('md5_compare', False) - - logger.info("Please check whether the input data belongs to you. 
If not, there may be security risks.") - file_name = add_time_with_xlsx("compare_result" + suffix) - file_path = os.path.join(os.path.realpath(output_path), file_name) - check_file_not_exists(file_path) - highlight_dict = {'red_rows': [], 'yellow_rows': []} - - with FileOpen(input_parma.get("npu_json_path"), "r") as npu_json, \ - FileOpen(input_parma.get("bench_json_path"), "r") as bench_json, \ - FileOpen(input_parma.get("stack_json_path"), "r") as stack_json: - result_df = compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, - summary_compare, md5_compare) - - if not md5_compare and not summary_compare: - result_df = _do_multi_process(input_parma, result_df) - find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) - highlight_rows_xlsx(result_df, highlight_dict, file_path) - if auto_analyze: - advisor = Advisor(result_df, output_path) - advisor.analysis() - - -def parse(pkl_file, module_name_prefix): - if not isinstance(module_name_prefix, str): - logger.error("The parameter:module_name_prefix is not a string.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - with FileOpen(pkl_file, "r") as f: - done = False - title_printed = False - while not done: - pkl_line = f.readline() - if pkl_line == '\n': - continue - if len(pkl_line) == 0: - done = True - break - - msg = json.loads(pkl_line) - info_prefix = msg[0] - if not info_prefix.startswith(module_name_prefix): - continue - - if info_prefix.find("stack_info") != -1: - logger.info("\nTrace back({}):".format(msg[0])) - for item in reversed(msg[1]): - logger.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) - logger.info(" {}".format(item[3])) - continue - if len(msg) > 5: - summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \ - .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2]) - if not title_printed: - logger.info("\nStatistic Info:") - title_printed = True - logger.info(summary_info) - 
- -def op_item_parse(item, op_name, index, item_list=None, top_bool=True): - if item_list is None: - item_list = [] - if item is None or (isinstance(item, dict) and not item): - if not top_bool: - tmp = {'full_op_name': op_name + '.' + str(index), 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, - 'dtype': None, 'shape': None, 'md5': None, 'data_name': '-1'} - else: - tmp = {'full_op_name': op_name + '.0', 'Max': None, 'Min': None, 'Mean': None, 'Norm': None, 'dtype': None, - 'shape': None, 'md5': None, 'data_name': '-1'} - item_list.append(tmp) - return item_list - if index is None: - if isinstance(item, dict): - full_op_name = op_name + '.0' - else: - full_op_name = op_name - else: - full_op_name = op_name + Const.SEP + str(index) - if isinstance(item, dict): - if 'type' not in item: - for kwarg in item: - kwarg_parsed_list = op_item_parse(item[kwarg], op_name + Const.SEP + kwarg, None) - item_list += kwarg_parsed_list - kwarg_parsed_list.clear() - elif 'dtype' in item: - parsed_item = item - parsed_item['full_op_name'] = full_op_name - item_list.append(parsed_item) - elif 'type' in item: - parsed_item = {} - if item['type'] == 'torch.Size': - parsed_item['full_op_name'] = full_op_name - parsed_item['dtype'] = 'torch.Size' - parsed_item['shape'] = str(item['value']) - parsed_item['md5'] = None - parsed_item['Max'] = None - parsed_item['Min'] = None - parsed_item['Mean'] = None - parsed_item['Norm'] = None - parsed_item['data_name'] = '-1' - item_list.append(parsed_item) - elif item['type'] == 'slice': - parsed_item['full_op_name'] = full_op_name - parsed_item['dtype'] = 'slice' - parsed_item['shape'] = str(np.shape(np.array(item['value']))) - parsed_item['md5'] = None - parsed_item['Max'] = None - parsed_item['Min'] = None - parsed_item['Mean'] = None - parsed_item['Norm'] = None - parsed_item['data_name'] = '-1' - item_list.append(parsed_item) - else: - parsed_item['full_op_name'] = full_op_name - parsed_item['dtype'] = str(type(item['value'])) - 
parsed_item['shape'] = '[]' - parsed_item['md5'] = None - parsed_item['Max'] = item['value'] - parsed_item['Min'] = item['value'] - parsed_item['Mean'] = item['value'] - parsed_item['Norm'] = item['value'] - parsed_item['data_name'] = '-1' - item_list.append(parsed_item) - else: - resolve_api_special_parameters(item, full_op_name, item_list) - else: - for j, item_spec in enumerate(item): - op_item_parse(item_spec, full_op_name, j, item_list=item_list, top_bool=False) - return item_list - - -def resolve_api_special_parameters(data_dict, full_op_name, item_list): - """ - Function Description: - 解析下面格式的数据, 是api参数的一种特殊格式 - { - "last_hidden_state": { - "type": "torch.Tensor", - "dtype": "torch.bfloat16", - ... - }, - "loss": { - "type": "torch.Tensor", - "dtype": "torch.float32", - ... - } - } - Parameter: - data_dict: 字典格式的数据 - full_op_name: 参数的全名字符串 - item_list: 参数信息集合 - """ - for key, value in data_dict.items(): - if isinstance(value, dict): - parsed_item = value - parts = full_op_name.split(".") - parts.insert(-1, key) - full_op_name_new = ".".join(parts) - parsed_item['full_op_name'] = full_op_name_new - item_list.append(parsed_item) - - -def read_op(op_data, op_name): - op_parsed_list = [] - if 'forward' in op_name: - if 'input_args' in op_data: - input_item = op_data['input_args'] - input_parsed_list = op_item_parse(input_item, op_name + '_input', None) - op_parsed_list = input_parsed_list.copy() - input_parsed_list.clear() - if 'input_kwargs' in op_data: - kwargs_item = op_data['input_kwargs'] - if isinstance(kwargs_item, dict) and "type" in kwargs_item or isinstance(kwargs_item, list): - kwarg_parsed_list = op_item_parse(kwargs_item, op_name + '_input', None) - op_parsed_list += kwarg_parsed_list - kwarg_parsed_list.clear() - elif kwargs_item: - for kwarg in kwargs_item: - kwarg_parsed_list = op_item_parse(kwargs_item[kwarg], op_name + '_input.' 
+ kwarg, None) - op_parsed_list += kwarg_parsed_list - kwarg_parsed_list.clear() - if 'output' in op_data: - output_item = op_data['output'] - output_parsed_list = op_item_parse(output_item, op_name + '_output', None) - op_parsed_list += output_parsed_list - output_parsed_list.clear() - if 'backward' in op_name: - if 'input' in op_data: - input_item = op_data['input'] - input_parsed_list = op_item_parse(input_item, op_name + '_input', None) - op_parsed_list = input_parsed_list.copy() - input_parsed_list.clear() - if 'output' in op_data: - output_item = op_data['output'] - output_parsed_list = op_item_parse(output_item, op_name + '_output', None) - op_parsed_list += output_parsed_list - output_parsed_list.clear() - return op_parsed_list - - -def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False): - npu_json_handle, bench_json_handle, stack_json_handle = file_handles - npu_json_data = json.load(npu_json_handle) - bench_json_data = json.load(bench_json_handle) - stack_json_data = json.load(stack_json_handle) - - if fuzzy_match: - logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.") - - npu_ops_queue = [] - bench_ops_queue = [] - result = [] - - ops_npu_iter = iter(npu_json_data['data']) - ops_bench_iter = iter(bench_json_data['data']) - read_err_npu = True - read_err_bench = True - last_npu_ops_len = 0 - last_bench_ops_len = 0 - - while True: - if not read_err_npu and not read_err_bench: - break - try: - last_npu_ops_len = len(npu_ops_queue) - op_name_npu = next(ops_npu_iter) - read_err_npu = True - - npu_op_data = npu_json_data['data'][op_name_npu] - npu_op_parsed_list = read_op(npu_op_data, op_name_npu) - if op_name_npu in stack_json_data: - npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': stack_json_data[op_name_npu]}) - else: - npu_op_parsed_list.append({'full_op_name': op_name_npu, 'full_info': None}) - - npu_merge_list = merge_tensor(npu_op_parsed_list, 
summary_compare, md5_compare) - if npu_merge_list: - npu_ops_queue.append(npu_merge_list) - except StopIteration: - read_err_npu = False - try: - last_bench_ops_len = len(bench_ops_queue) - op_name_bench = next(ops_bench_iter) - - bench_op_data = bench_json_data['data'][op_name_bench] - bench_op_parsed_list = read_op(bench_op_data, op_name_bench) - if op_name_bench in stack_json_data: - bench_op_parsed_list.append( - {'full_op_name': op_name_bench, 'full_info': stack_json_data[op_name_bench]}) - else: - bench_op_parsed_list.append({'full_op_name': op_name_bench, 'full_info': None}) - - bench_merge_list = merge_tensor(bench_op_parsed_list, summary_compare, md5_compare) - if bench_merge_list: - bench_ops_queue.append(bench_merge_list) - except StopIteration: - read_err_bench = False - - # merge all boolean expressions - both_empty = not npu_ops_queue and not bench_ops_queue - no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len) - if both_empty or no_change: - continue - - n_match_point, b_match_point = match_op(npu_ops_queue, bench_ops_queue, fuzzy_match) - if n_match_point == -1 and b_match_point == -1: - continue - n_match_data = npu_ops_queue[n_match_point] - b_match_data = bench_ops_queue[b_match_point] - un_match_data = npu_ops_queue[0: n_match_point] - for npu_data in un_match_data: - get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) - get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare) - del npu_ops_queue[0: n_match_point + 1] - del bench_ops_queue[0: b_match_point + 1] - if npu_ops_queue: - for npu_data in npu_ops_queue: - get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) - - header = [] - if md5_compare: - header = CompareConst.MD5_COMPARE_RESULT_HEADER[:] - elif summary_compare: - header = CompareConst.SUMMARY_COMPARE_RESULT_HEADER[:] - else: - header = CompareConst.COMPARE_RESULT_HEADER[:] - - all_mode_bool = not (summary_compare or 
md5_compare) - if stack_mode: - if all_mode_bool: - header.append(CompareConst.STACK) - header.append(CompareConst.DATA_NAME) - else: - header.append(CompareConst.STACK) - else: - if all_mode_bool: - for row in result: - del row[-2] - header.append(CompareConst.DATA_NAME) - else: - for row in result: - del row[-1] - - result_df = pd.DataFrame(result, columns=header) - return result_df - - def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): index_out = 0 npu_stack_info = n_dict.get("stack_info", None) @@ -1036,3 +361,63 @@ def get_un_match_accuracy(result, n_dict, md5_compare, summary_compare): else: result_item.extend([CompareConst.NONE, "-1"]) result.append(result_item) + + +def merge_tensor(tensor_list, summary_compare, md5_compare): + op_dict = {} + op_dict["op_name"] = [] + op_dict["input_struct"] = [] + op_dict["kwargs_struct"] = [] + op_dict["output_struct"] = [] + op_dict["summary"] = [] + op_dict["stack_info"] = [] + + all_mode_bool = not (summary_compare or md5_compare) + if all_mode_bool: + op_dict["data_name"] = [] + + for tensor in tensor_list: + if len(tensor) == 2: + op_dict['stack_info'].append(tensor['full_info']) + break + op_dict["op_name"].append(tensor['full_op_name']) + if not md5_compare: + if tensor['full_op_name'].find("input") != -1: + op_dict["input_struct"].append((tensor['dtype'], tensor['shape'])) + elif tensor['full_op_name'].find("kwarg") != -1: + op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'])) + elif tensor['full_op_name'].find("output") != -1: + op_dict["output_struct"].append((tensor['dtype'], tensor['shape'])) + else: + if tensor['full_op_name'].find("input") != -1: + op_dict["input_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + elif tensor['full_op_name'].find("kwarg") != -1: + op_dict["kwargs_struct"].append((tensor['dtype'], tensor['shape'], tensor['md5'])) + elif tensor['full_op_name'].find("output") != -1: + op_dict["output_struct"].append((tensor['dtype'], 
tensor['shape'], tensor['md5'])) + + op_dict["summary"].append([tensor['Max'], tensor['Min'], tensor['Mean'], tensor['Norm']]) + + if all_mode_bool: + op_dict["data_name"].append(tensor['data_name']) + + if not op_dict["kwargs_struct"]: + del op_dict["kwargs_struct"] + return op_dict if op_dict["op_name"] else {} + + +def _compare_parser(parser): + parser.add_argument("-i", "--input_path", dest="input_path", type=str, + help=" The compare input path, a dict json.", required=True) + parser.add_argument("-o", "--output_path", dest="output_path", type=str, + help=" The compare task result out path.", required=True) + parser.add_argument("-s", "--stack_mode", dest="stack_mode", action="store_true", + help=" Whether to save stack info.", required=False) + parser.add_argument("-a", "--auto_analyze", dest="auto_analyze", action="store_false", + help=" Whether to give advisor.", required=False) + parser.add_argument("-f", "--fuzzy_match", dest="fuzzy_match", action="store_true", + help=" Whether to perform a fuzzy match on the api name.", required=False) + + + + diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index db437539afeb98050ce59aad87a1e79d98b84085..2ac077dca67964110ba2e5bfc4151a5cbaa86a6a 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -90,7 +90,7 @@ class DataCollector: if self.config.level == "L2": return self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) - if self.data_processor.stop_run(): + if self.data_processor.is_terminated: self.handle_data(name, data_info, use_buffer=False) raise Exception("[msprobe] exit") self.handle_data(name, data_info) @@ -101,18 +101,34 @@ class DataCollector: return data_info = self.data_processor.analyze_backward(name, module, module_input_output) - if self.data_processor.stop_run(): + if 
self.data_processor.is_terminated: self.handle_data(name, data_info, use_buffer=False) raise Exception("[msprobe] exit") self.handle_data(name, data_info) + def backward_input_data_collect(self, name, module, pid, module_input_output): + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + + data_info = self.data_processor.analyze_backward_input(name, module, module_input_output) + self.handle_data(name, data_info) + + def backward_output_data_collect(self, name, module, pid, module_input_output): + self.update_construct(name) + if not self.check_scope_and_pid(self.scope, name, pid): + return + + data_info = self.data_processor.analyze_backward_output(name, module, module_input_output) + self.handle_data(name, data_info) + def update_construct(self, name): if self.config.level not in DataCollector.level_without_construct: self.data_writer.update_construct({name: self.module_processor.api_parent_node}) self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info, use_buffer=True): - msg = f"msProbe is collecting data on {name}. " + msg = f"msprobe is collecting data on {name}. 
" if data_info: msg = self.update_data(data_info, msg) logger.info(msg) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 2fbc86b5656c3bcfe14b2fe9fe6bb295451e9466..679d985bb020751d71feb2c809977165866833d4 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -41,6 +41,24 @@ class ModuleBackwardInputsOutputs: return convert_tuple(self.grad_output) +@dataclass +class ModuleBackwardInputs: + grad_input: Optional[Tuple] + + @property + def grad_input_tuple(self): + return convert_tuple(self.grad_input) + + +@dataclass +class ModuleBackwardOutputs: + grad_output: Optional[Tuple] + + @property + def grad_output_tuple(self): + return convert_tuple(self.grad_output) + + class TensorStatInfo: def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): self.max = max_val @@ -69,6 +87,10 @@ class BaseDataProcessor: @property def data_path(self): return self.data_writer.dump_tensor_data_dir + + @property + def is_terminated(self): + return False @staticmethod def analyze_api_call_stack(name): @@ -228,12 +250,31 @@ class BaseDataProcessor: return api_info_struct + def analyze_backward_input(self, name, module, + module_input_output: ModuleBackwardInputs): + api_info_struct = {} + if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT): + api_info_struct[name] = {} + self.api_data_category = Const.INPUT + + input_info_list = self.analyze_element(module_input_output.grad_input_tuple) + api_info_struct[name][Const.INPUT] = input_info_list + return api_info_struct + + def analyze_backward_output(self, name, module, + module_input_output: ModuleBackwardOutputs): + api_info_struct = {} + if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): + api_info_struct[name] = {} + self.api_data_category = Const.OUTPUT + + output_info_list = 
self.analyze_element(module_input_output.grad_output_tuple) + api_info_struct[name][Const.OUTPUT] = output_info_list + return api_info_struct + def get_save_file_path(self, suffix): file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + suffix + file_format) file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) return dump_data_name, file_path - - def stop_run(self): - return False diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c208df7d900683197fc24081b42835716ce7605f..8d09669096108bf9446d6df698c28f32c8069a19 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -74,8 +74,9 @@ class MindsporeDataProcessor(BaseDataProcessor): if data.numel() == 0: return tensor_stat elif data.dtype == ms.bool_: - tensor_stat.max = self.mint_ops_func["max"](data).item() - tensor_stat.min = self.mint_ops_func["min"](data).item() + data_np = data.asnumpy() + tensor_stat.max = np.max(data_np) + tensor_stat.min = np.min(data_np) elif not data.shape: tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() elif data.dtype == ms.complex64 or data.dtype == ms.complex128: @@ -154,9 +155,18 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) self.cached_tensors_and_file_paths = {} - self.real_overflow_dump_times = 0 + self.real_overflow_nums = 0 self.overflow_nums = config.overflow_nums + @property + def is_terminated(self): + if self.overflow_nums == -1: + return False + if self.real_overflow_nums >= self.overflow_nums: + 
logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}") + return True + return False + def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_forward(name, module, module_input_output) @@ -175,17 +185,9 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): tensor = convert_bf16_to_fp32(tensor) np.save(file_path, tensor.asnumpy()) change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - self.real_overflow_dump_times += 1 + self.real_overflow_nums += 1 self.cached_tensors_and_file_paths = {} - def stop_run(self): - if self.overflow_nums == -1: - return False - if self.real_overflow_dump_times >= self.overflow_nums: - logger.warning(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}") - return True - return False - def _analyze_maybe_overflow_tensor(self, tensor_json): if tensor_json['Max'] is None: return diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 007fec80964e300315c59f3d7fa4166b9d10fa70..922f3e7006fc4b47f428c858e60ceb301101c426 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -5,6 +5,7 @@ from typing import List import numpy as np import torch +import torch.distributed as dist from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_check import path_len_exceeds_limit, change_mode from msprobe.core.common.log import logger @@ -12,6 +13,8 @@ from msprobe.core.common.const import Const, OverflowConst, FileCheckConst from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo from msprobe.pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow +from 
msprobe.pytorch.common.utils import save_pt + try: import torch_npu @@ -21,7 +24,8 @@ except ImportError: class PytorchDataProcessor(BaseDataProcessor): - pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor) + pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, \ + dist.ProcessGroup) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -69,6 +73,12 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_stat.min = False not in data_clone elif not data_clone.shape: tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data_clone.item() + elif torch.is_complex(data_clone): + data_np = data_clone.cpu().numpy() + data_abs = np.abs(data_np) + tensor_stat.max = np.max(data_abs).item() + tensor_stat.min = np.min(data_abs).item() + tensor_stat.mean = np.mean(data_abs).item() else: if not data_clone.is_floating_point() or data_clone.dtype == torch.float64: data_clone = data_clone.float() @@ -113,6 +123,10 @@ class PytorchDataProcessor(BaseDataProcessor): @staticmethod def _analyze_torch_size(arg): return {"type": "torch.Size", "value": list(arg)} + + @staticmethod + def _analyze_process_group_ranks(arg): + return dist.get_process_group_ranks(arg) @classmethod def get_special_types(cls): @@ -130,6 +144,8 @@ class PytorchDataProcessor(BaseDataProcessor): return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) if isinstance(element, (bool, int, float, str, slice)): return self._analyze_builtin(element) + if isinstance(element, dist.ProcessGroup): + return self._analyze_process_group_ranks(element) return {} def analyze_element(self, element): @@ -167,11 +183,8 @@ class StatisticsDataProcessor(PytorchDataProcessor): class TensorDataProcessor(PytorchDataProcessor): def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - if not path_len_exceeds_limit(file_path): - torch.save(tensor, file_path) - change_mode(file_path, 
FileCheckConst.DATA_FILE_AUTHORITY) - else: - logger.warning(f'The file path {file_path} length exceeds limit.') + saved_tensor = tensor.contiguous().detach() + save_pt(saved_tensor, file_path) single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) return single_arg @@ -183,10 +196,19 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) self.cached_tensors_and_file_paths = {} - self.real_overflow_dump_times = 0 - self.overflow_nums = config.overflow_nums self.bits_for_overflow = 8 - + self.real_overflow_nums = 0 + self.overflow_nums = config.overflow_nums + + @property + def is_terminated(self): + if self.overflow_nums == -1: + return False + if self.real_overflow_nums >= self.overflow_nums: + logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_nums}") + return True + return False + @staticmethod def overflow_debug_mode_enable(): overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) @@ -209,16 +231,9 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): for file_path, tensor in self.cached_tensors_and_file_paths.items(): torch.save(tensor, file_path) change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - self.inc_and_check_overflow_times() + self.real_overflow_nums += 1 self.cached_tensors_and_file_paths = {} - def inc_and_check_overflow_times(self): - self.real_overflow_dump_times += 1 - if self.overflow_nums == -1: - return - if self.real_overflow_dump_times >= self.overflow_nums: - raise MsprobeException(MsprobeException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times)) - def check_overflow_npu(self): if self.overflow_debug_mode_enalbe(): float_status = torch.zeros(self.bits_for_overflow).npu() @@ -303,7 +318,7 @@ class FreeBenchmarkDataProcessor(PytorchDataProcessor): self._forward_new_output = new_output def analyze_backward(self, name, module, module_input_output: 
ModuleBackwardInputsOutputs): - self.checker.backward(name, module, module_input_output.grad_output) + self.checker.backward(name, module, module_input_output.grad_input) class KernelDumpDataProcessor(PytorchDataProcessor): diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/constant.py b/debug/accuracy_tools/msprobe/core/grad_probe/constant.py index 38d33e9886490bba65205eff6a8d080070213acc..189ec2d11b275b28d4577ff5ea9baca21b71d3ad 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/constant.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/constant.py @@ -39,7 +39,7 @@ class GradConst: DIRECTORY_LENGTH = 4096 FILE_NAME_LENGTH = 255 FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" - PARAM_VALID_PATTERN = r"^[a-zA-Z0-9.]+$" + PARAM_VALID_PATTERN = r"^[a-zA-Z0-9_.]+$" DIR = "dir" FILE = "file" @@ -53,4 +53,19 @@ class GradConst: SHAPE = "shape" MAX = "max" MIN = "min" - NORM = "norm" \ No newline at end of file + NORM = "norm" + +level_adp = { + "L0": { + "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE], + "have_grad_direction": False + }, + "L1": { + "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE], + "have_grad_direction": True + }, + "L2": { + "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE], + "have_grad_direction": True + }, + } \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py index 26cba34f0786cc62bf9b7760ee96f9671878be30..22acdf2fbe6fadb6c6c60b8f46e09104f8026a34 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/grad_compare.py @@ -10,7 +10,6 @@ from msprobe.core.common.file_check import create_directory from msprobe.core.common.log import logger from msprobe.core.common.utils import remove_path, write_csv, load_npy from 
msprobe.core.grad_probe.constant import GradConst -from msprobe.pytorch.common.utils import load_pt class GradComparator: @@ -163,12 +162,8 @@ class GradComparator: @classmethod def _load_grad_files(cls, grad_file1: str, grad_file2: str): - if grad_file1.endswith('pt'): - grad1 = load_pt(grad_file1).numpy() - grad2 = load_pt(grad_file2).numpy() - else: - grad1 = load_npy(grad_file1) - grad2 = load_npy(grad_file2) + grad1 = load_npy(grad_file1) + grad2 = load_npy(grad_file2) if grad1.shape != grad2.shape: raise RuntimeError(f"tensor shape is not equal: {grad_file1}, {grad_file2}") if grad1.dtype != bool: diff --git a/debug/accuracy_tools/msprobe/core/grad_probe/utils.py b/debug/accuracy_tools/msprobe/core/grad_probe/utils.py index 05dd9a568e6e1078927b42e0e569df92d3c8b25d..f5db74baafd3aef775ff211060ff05d4d15f17b5 100644 --- a/debug/accuracy_tools/msprobe/core/grad_probe/utils.py +++ b/debug/accuracy_tools/msprobe/core/grad_probe/utils.py @@ -1,3 +1,8 @@ +import re +from msprobe.core.grad_probe.constant import GradConst +from msprobe.core.common.log import logger +from msprobe.core.common.utils import write_csv + def data_in_list_target(data, lst): return not lst or len(lst) == 0 or data in lst @@ -7,3 +12,41 @@ def check_numeral_list_ascend(lst): raise Exception("The input list should only contain numbers") if lst != sorted(lst): raise Exception("The input list should be ascending") + + +def check_param(param_name): + if not re.match(GradConst.PARAM_VALID_PATTERN, param_name): + raise RuntimeError("The parameter name contains special characters.") + + +def check_str(string, variable_name): + if not isinstance(string, str): + raise ValueError(f'The variable: "{variable_name}" is not a string.') + + +class ListCache(list): + threshold = 1000 + + def __init__(self, *args): + super().__init__(*args) + self._output_file = None + + def __del__(self): + self.flush() + + def flush(self): + if len(self) == 0: + return + if not self._output_file: + logger.warning("dumpfile 
path is not setted") + write_csv(self, self._output_file) + logger.info(f"write {len(self)} items to {self._output_file}.") + self.clear() + + def append(self, data): + list.append(self, data) + if len(self) >= ListCache.threshold: + self.flush() + + def set_output_file(self, output_file): + self._output_file = output_file diff --git a/debug/accuracy_tools/msprobe/doc/grad_probe/grad_probe.md b/debug/accuracy_tools/msprobe/doc/grad_probe/grad_probe.md new file mode 100644 index 0000000000000000000000000000000000000000..fcbd2f123d1610c275049cc4f523c87364f53a6a --- /dev/null +++ b/debug/accuracy_tools/msprobe/doc/grad_probe/grad_probe.md @@ -0,0 +1,207 @@ +# Ascend模型梯度状态监测工具 + +梯度状态监测工具提供了两种能力: + +- 将模型权重的梯度数据导出。这种功能可以将模型权重的梯度值以统计量的形式采集出来,用以分析问题。 +- 将两份梯度数据进行相似度对比。在有标杆问题中,可以确认训练过程中精度问题出现的step,以及抓取反向过程中的问题。 + +工具支持PyTorch版本:2.0/2.1/2.2;支持MindSpore版本:r2.3。 + +## 工具特性 + +- 使用便捷,无需在训练流程里插入代码 +- 可以精准定位问题出现的step + +## 使用方式 + +### 梯度数据导出 + +1. 创建配置文件config.json,样例如下: + + ```json + { + "task": "grad_probe", + "dump_path": "./dump_path", + "rank": [], + "step": [], + "grad_probe": { + "grad_level": "L1", + "param_list": [], + "bounds": [-1, 0, 1] + } + } + ``` + > step指的是优化器被调用的次数(并非模型跑的step,某些step,例如loss为nan时,不会调用优化器) + + **参数说明** + + | 参数 | 说明 | 输入类型 | 是否必选 | + |--------------------------------|-----------------------------------|-----------------|----------| + | task | 填为"grad_probe"。 | str | 是 | + | grad_level | 输出级别。决定导出数据的详细程度,级别越大导出数据越详细。可取值:L0, L1, L2|str | 是 | + | param_list | 权重名称列表,表示需要监控的权重。列表为空就表示监控所有权重。 | List[str] | 是 | + | rank | rank id列表,在多卡场景下,表示需要导出梯度数据的进程的rank id。列表为空就表示导出所有rank的数据。(MindSpore静态图模式下,当前暂不支持指定rank功能) | List[int] | 是 | + | step | step列表,表示需要导出数据的step列表。列表为空就表示导出所有step的数据。(MindSpore静态图模式下,当前暂不支持指定step功能) | List[int] | 是 | + | bounds | 区间列表,用来划分区间以统计数值的分布。需要保证由数据小到大排列。可以使用默认值[-1, 0, 1] | List[float] | 是 | + | dump_path | 输出目录。如果不存在就会创建一个新目录。 | str | 是 | + + **不同级别的level的导出数据** + + + | 级别 | 特征数据表头 | 是否有方向数据 | + | ---- | 
------------------------------------------------------------ | -------------- | + | L0 | ("param_name", "MD5", "max", "min", "norm", "shape") | 否 | + | L1 | ("param_name", "max", "min", "norm", "shape") | 是 | + | L2 | ("param_name", *intervals, "=0", "max", "min", "norm", "shape") | 是 | + + intervals就是根据值分布bounds划分出的区间。 + MindSpore静态图模式下,L0级别中暂不支持"MD5" + + **方向数据解释** + + 因为模型的参数往往非常大,所以存储真实数据是不可接受的,这里折衷一下,只存储梯度数据的正负号(一个布尔值),也就是方向。 + + **bounds和值分布解释** + + + 值分布:梯度数据落在各个区间的元素个数占总元素个数的比例。 + + bounds:一个列表,用来划分出区间以统计值分布。例如传入bounds = [-10, 0, 10],此时有一个 grad_value: Tensor = [9.3 , 5.4, -1.0, -12.3],依据 bounds 划分出 (-inf, -10]、(-10, 0]、(0, 10]、(10, inf) 四个区间,然后统计grad_value里的数据落在每个区间内的个数,得到 1、1、2、0。如下图所示: + ![Alt text](img/image-1.png) + +2. 插入代码。示例代码如下: + +- PyTorch框架:模型构造完成后,传入config.json的路径实例化一个GradientMonitor对象,然后调用gm.monitor并将`模型`作为参数传入。 +```python +from msprobe.pytorch import PrecisionDebugger +debugger = PrecisionDebugger("config_json_path") +debugger.monitor(model) +``` +- MindSpore框架:优化器构造完成后,传入config.json的路径实例化一个GradientMonitor对象,然后调用gm.monitor并将`优化器`作为参数传入。 +```python +from msprobe.mindspore import PrecisionDebugger +debugger = PrecisionDebugger("config_json_path") +debugger.monitor(optimizer) +``` + +3. 
结束监控(MindSpore静态图模式下需要)
+
+ 在训练结束之后,调用stop接口
+
+```python
+debugger.stop()
+```
+
+### 输出结果
+**输出目录结构**(以level配置L2为例)
+
+```bash
+{dump_path}
+ ├── rank{rank_id}
+ │ ├── grad_summary_{step}.csv
+ │ ├── step{step}
+ │ │ ├── {param_name}.npy
+```
++ {timestamp}:梯度工具导出数据的时候会在output_path下生成一个时间戳目录,然后在这个时间戳目录下输出结果。
++ rank{rank_id}:在分布式场景下,会记录卡的rank_id。非分布式场景下,如果是CPU则记录进程号,如果是NPU或GPU则记录卡号
++ grad_summary_{step}.csv:会分step记录每一步的梯度数据统计值。
++ step{step}:这个目录下会存放该step的梯度的方向数据。
++ {param_name}.pt(npy):模型参数的梯度方向数据,PyTorch保存的是pt文件,MindSpore是npy文件。
+
+**grad_summary_{step}.csv**
+
+样例如下:
+
+![Alt text](img/image.png)
+
+| 字段 | 含义 |
+| --------------------- | ------------------------------------------------------------|
+| Param_name | 模型参数名称。 |
+| MD5 | 梯度数据的MD5值。 |
+| (-inf, -0.01]...[0.01, inf) | 梯度值落在区间内的元素个数占总元素的比例。 |
+| =0 | 梯度为0的元素个数占总元素的比例。 |
+| Max | 最大值。 |
+| Min | 最小值。 |
+| Norm | L2norm值。 |
+| Shape | 形状。 |
+
+### 梯度相似度比对
+
+会根据所导出的权重,分step比对梯度相似度,输出每个权重的梯度相似度和总的梯度相似度。单个权重的梯度相似度为两份方向数据的重合度,总的梯度相似度为每个权重的梯度相似度按元素个数加权。
+
+#### 前提条件
+
+- 相同配置下,以Level为L1或L2分别采集npu和gpu环境下的梯度数据。
+- 将两份梯度数据传到同一环境下。
+
+#### 使用方式
+
+新建如下Python脚本,传入npu和gpu的dump_path以及输出目录,比对结果输出目录不存在的话会新建:
+
+```python
+from msprobe import *
+GradComparator.compare_distributed("配置文件里写的dump_path",
+ "配置文件里写的dump_path",
+ "比对结果输出目录")
+```
+
+### 比对结果
+
+**输出目录结构**
+
+如下为多卡比对结果,单卡则没有rank{rank_id}这一级目录。
+
+```bash
+比对结果输出目录
+ ├── rank{rank_id}
+ │ ├── similarities.csv
+ │ └── similarities_picture
+ │ ├── {param_name}.png
+ │ └── summary_similarities.png
+```
+
+**问题界定**
+
+原则:对于任意权重,第0步的梯度相似度低于0.97,或者某一步的梯度相似度下降超过0.03,认为这一步存在精度问题。例子如下:
+
+- 第0步相似度低于0.97
+
+![Alt text](img/image-3.png)
+
+- 第3步相似度下降超过0.03
+
+![Alt text](img/image-4.png)
+
+- 正常情况
+
+![Alt text](img/image-2.png)
+
+这个原则是一个经验性的指标,并不是严格的标准,还需要结合实际情况具体分析。
+
+## 公开接口
+
+**接口说明**
+
+```python
+PrecisionDebugger.monitor(module)
+```
+
+| 参数 | 说明 | 是否必选 |
+| ----- | -------------------- | -------- |
+| module 
|Pytorch框架下传入模型,必须是torch.nn.Module;MindSpore框架下传入优化器。 | 是 | + + +**接口说明** + +```python +GradComparator.compare_distributed(dump_path1, dump_path2, output_path) +``` + +| 参数 | 说明 | 是否必选 | +| ----- | -------------------- | -------- | +| dump_path1 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path。 | 是 | +| dump_path2 |需要比对的其中一个dump目录,也就是配置文件里写的dump_path,与dump_path1可以互换。 | 是 | +| output_path |输出结果目录,不存在会新建。 | 是 | + + +# FAQ diff --git a/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-1.png b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-1.png new file mode 100644 index 0000000000000000000000000000000000000000..bee75b8b42e4d63137c554cb703ddd7f70d8c1ce Binary files /dev/null and b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-1.png differ diff --git a/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-2.png b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-2.png new file mode 100644 index 0000000000000000000000000000000000000000..587ffc560fadcfb6600fd0b528845753fca53c82 Binary files /dev/null and b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-2.png differ diff --git a/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-3.png b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-3.png new file mode 100644 index 0000000000000000000000000000000000000000..4c280ac00f8972b15e06b3528a17c6d7600c96f7 Binary files /dev/null and b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-3.png differ diff --git a/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-4.png b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-4.png new file mode 100644 index 0000000000000000000000000000000000000000..095e402f7c46235cc7dac1b0bf5194d42aa32811 Binary files /dev/null and b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image-4.png differ diff --git a/debug/accuracy_tools/msprobe/doc/grad_probe/img/image.png b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image.png new file mode 100644 index 
0000000000000000000000000000000000000000..5a498f5d2a77d8808d36d5c12e52c87809137bd0 Binary files /dev/null and b/debug/accuracy_tools/msprobe/doc/grad_probe/img/image.png differ diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py new file mode 100644 index 0000000000000000000000000000000000000000..08bb976493396b1c8bb2d902552d6836af51e856 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -0,0 +1,85 @@ +import numpy as np +import mindspore as ms + + +class Const: + CELL = "cell" + API = "api" + KERNEL = "kernel" + TOOL_LEVEL_DICT = { + "L0": CELL, + "L1": API, + "L2": KERNEL + } + PYNATIVE_MODE = "pynative" + GRAPH_GE_MODE = "graph_ge" + GRAPH_KBYK_MODE = "graph_kbyk" + + +class FreeBenchmarkConst: + DEFAULT_DEVICE = "npu" + DEFAULT_STAGE = "forward" + DEFAULT_DUMP_LEVEL = "L1" + DEFAULT_PERT_TYPE = "improve_precision" + DEFAULT_HANDLER_TYPE = "check" + FIX_HANDLER_MODE = "fix" + ADD_NOISE = "add_noise" + BIT_NOISE = "bit_noise" + NO_CHANGE = "no_change" + IMPROVE_PRECISION = "improve_precision" + CHECK = "check" + FIX = "fix" + DEVICE_LIST = ["npu"] + STAGE_LIST = ["forward"] + DUMP_LEVEL_LIST = ["L1"] + PERT_TYPE_LIST = [IMPROVE_PRECISION, ADD_NOISE, BIT_NOISE, NO_CHANGE] + HANDLER_TYPE_LIST = [CHECK, FIX] + COMMUNICATION_API_LIST = [ + "mindspore.communication.comm_func.all_gather_into_tensor", + "mindspore.communication.comm_func.gather_into_tensor", + "mindspore.communication.comm_func.all_reduce", + "mindspore.communication.comm_func.reduce", + "mindspore.communication.comm_func.reduce_scatter_tensor" + ] + NO_CHANGE_ERROR_THRESHOLD = 1.0 + SYMBOL_FLIPPING_RATIO = 8.0 + OPS_PREFIX = "mindspore.ops." + Tensor_PREFIX = "mindspore.Tensor." + MINT_PREFIX = "mindspore.mint." + MINT_NN_FUNC_PREFIX = "mindspore.mint.nn.functional." + COMM_PREFIX = "mindspore.communication.comm_func." 
+ + API_PREFIX_DICT = { + "ops": OPS_PREFIX, + "Tensor": Tensor_PREFIX, + "mint": MINT_PREFIX, + "mint.nn.functional": MINT_NN_FUNC_PREFIX, + "communication": COMM_PREFIX + } + + PERT_VALUE_DICT = { + ms.bfloat16: 1e-4, + ms.float16: 1e-6, + ms.float32: 1e-8, + ms.float64: 1e-16 + } + + ERROR_THRESHOLD = { + ms.float16: 1.002, + ms.float32: 1.0002 + } + + PERT_BIT_DICT = { + ms.float16: np.int16, + ms.float32: np.int32, + ms.float64: np.int64 + } + + MS_NUMPY_DTYPE_DICT = { + ms.int16: np.int16, + ms.int32: np.int32, + ms.int64: np.int64, + ms.float16: np.float16, + ms.float32: np.float32, + ms.float64: np.float64 + } diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/compare_cli.py b/debug/accuracy_tools/msprobe/mindspore/compare/compare_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8149657385de3b9471c86ba7eb6e5d6874f46b --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/compare_cli.py @@ -0,0 +1,23 @@ +import json +from msprobe.core.common.file_check import FileOpen, check_file_type +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.utils import CompareException +from msprobe.core.common.log import logger +from msprobe.mindspore.compare.ms_compare import ms_compare +from msprobe.mindspore.compare.distributed_compare import compare_distributed + +def compare_cli_ms(args): + with FileOpen(args.input_path, "r") as file: + input_param = json.load(file) + npu_path = input_param.get("npu_path", None) + bench_path = input_param.get("bench_path", None) + + if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE: + ms_compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=args.auto_analyze, + fuzzy_match=args.fuzzy_match) + elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR: + kwargs = {"stack_mode": args.stack_mode, "auto_analyze": args.auto_analyze, 
"fuzzy_match": args.fuzzy_match} + compare_distributed(npu_path, bench_path, args.output_path, **kwargs) + else: + logger.error("The npu_path and bench_path need to be of the same type.") + raise CompareException(CompareException.INVALID_COMPARE_MODE) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..08f0a03ec70d410b93c001f01e8c8f87c6f1c80b --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/distributed_compare.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import os +from msprobe.core.common.utils import CompareException, check_compare_param, \ + check_configuration_param, task_dumppath_get +from msprobe.core.common.file_check import create_directory +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.common.log import logger +from msprobe.mindspore.compare.ms_compare import MSComparator +from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json + + +def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): + if kwargs.get('suffix'): + logger.error("Argument 'suffix' is not supported for compare_distributed.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + stack_mode = kwargs.get('stack_mode', False) + auto_analyze = kwargs.get('auto_analyze', True) + fuzzy_match = kwargs.get('fuzzy_match', False) + # get the ranks and match by order + npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank')) + bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank')) + if len(npu_ranks) != len(bench_ranks): + logger.error('The number of ranks in the two runs are different. ' + 'Unable to match the ranks. 
Please use another folder to compare ' + 'or use compare() api and manually match the ranks.') + raise CompareException(CompareException.INVALID_PATH_ERROR) + for nr, br in zip(npu_ranks, bench_ranks): + npu_data_dir = os.path.join(npu_dump_dir, nr) + bench_data_dir = os.path.join(bench_dump_dir, br) + npu_path = extract_json(npu_data_dir, stack_json=False) + bench_path = extract_json(bench_data_dir, stack_json=False) + stack_path = extract_json(npu_data_dir, stack_json=True) + + dump_result_param = { + 'npu_path': npu_path, + 'bench_path': bench_path, + 'stack_path': stack_path, + 'is_print_compare_log': True + } + try: + summary_compare, md5_compare = task_dumppath_get(dump_result_param) + check_configuration_param(stack_mode, auto_analyze, fuzzy_match) + create_directory(output_path) + check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare) + except (CompareException, FileCheckException) as error: + logger.error('Compare failed. Please check the arguments and do it again!') + raise CompareException(error.code) from error + msComparator = MSComparator() + msComparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, + md5_compare=md5_compare, **kwargs) diff --git a/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..a4736a91bb839b9375a3dfee140ee9c41e0c9ca2 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/compare/ms_compare.py @@ -0,0 +1,191 @@ +import json +import os.path +import numpy as np +from msprobe.core.advisor.advisor import Advisor +from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ + check_file_not_exists, check_configuration_param, task_dumppath_get +from msprobe.core.common.file_check import FileChecker, FileOpen, create_directory +from 
msprobe.core.common.const import FileCheckConst +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.compare.utils import get_un_match_accuracy, get_accuracy +from msprobe.core.compare.multiprocessing_compute import ComparisonResult, _save_cmp_result +from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx +from msprobe.core.compare.acc_compare import Comparator + + +class MSComparator (Comparator): + + def __init__(self): + self.frame_name = MSComparator.__name__ + + def compare_ops(self,idx, dump_path_dict, result_df, lock, input_parma): + cos_result = [] + max_err_result = [] + max_relative_err_result = [] + err_mess = [] + one_thousand_err_ratio_result = [] + five_thousand_err_ratio_result = [] + is_print_compare_log = input_parma.get("is_print_compare_log") + for i in range(len(result_df)): + op_name = result_df.iloc[i, 0] + if is_print_compare_log: + logger.info("start compare: {}".format(op_name)) + cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = self.compare_by_op( + op_name, dump_path_dict, input_parma) + if is_print_compare_log: + logger.info( + "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, " + "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg, + one_thousand_err_ratio, five_thousand_err_ratio)) + cos_result.append(cos_sim) + max_err_result.append(max_abs_err) + max_relative_err_result.append(max_relative_err) + err_mess.append(err_msg) + one_thousand_err_ratio_result.append(one_thousand_err_ratio) + five_thousand_err_ratio_result.append(five_thousand_err_ratio) + + cr = ComparisonResult( + cos_result = cos_result, + max_err_result = max_err_result, + max_relative_err_result = max_relative_err_result, + err_msgs = err_mess, + one_thousand_err_ratio_result = one_thousand_err_ratio_result, + 
five_thousand_err_ratio_result = five_thousand_err_ratio_result + ) + + return _save_cmp_result(idx, cr, result_df, lock) + + def compare_process(self,file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False): + npu_json_handle, bench_json_handle, stack_json_handle = file_handles + npu_json_data = json.load(npu_json_handle) + bench_json_data = json.load(bench_json_handle) + stack_json_data = json.load(stack_json_handle) + + if fuzzy_match: + logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.") + + npu_ops_queue = [] + bench_ops_queue = [] + result = [] + + ops_npu_iter = iter(npu_json_data['data']) + ops_bench_iter = iter(bench_json_data['data']) + read_err_npu = True + read_err_bench = True + last_npu_ops_len = 0 + last_bench_ops_len = 0 + + while True: + if not read_err_npu and not read_err_bench: + break + try: + last_npu_ops_len = len(npu_ops_queue) + op_name_npu = next(ops_npu_iter) + read_err_npu = True + npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,summary_compare, md5_compare) + if npu_merge_list: + npu_ops_queue.append(npu_merge_list) + except StopIteration: + read_err_npu = False + try: + last_bench_ops_len = len(bench_ops_queue) + op_name_bench = next(ops_bench_iter) + bench_merge_list = self.gen_merge_list(bench_json_data,op_name_bench,stack_json_data,summary_compare, md5_compare) + if bench_merge_list: + bench_ops_queue.append(bench_merge_list) + except StopIteration: + read_err_bench = False + + # merge all boolean expressions + both_empty = not npu_ops_queue and not bench_ops_queue + no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len) + if both_empty or no_change: + continue + + n_match_point, b_match_point = super().match_op(npu_ops_queue, bench_ops_queue, fuzzy_match) + if n_match_point == -1 and b_match_point == -1: + continue + n_match_data = npu_ops_queue[n_match_point] + b_match_data = 
bench_ops_queue[b_match_point] + un_match_data = npu_ops_queue[0: n_match_point] + for npu_data in un_match_data: + get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare) + del npu_ops_queue[0: n_match_point + 1] + del bench_ops_queue[0: b_match_point + 1] + if npu_ops_queue: + for npu_data in npu_ops_queue: + get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + result_df = self.make_result_table(result,md5_compare,summary_compare,stack_mode) + return result_df + + def read_npy_data(self,dir_path, file_name): + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.NUMPY_SUFFIX, False) + data_path = path_checker.common_check() + data_value = np.load(data_path) # detach for less memory + if data_value.dtype == np.float16: + data_value = data_value.astype(np.float32) + + return data_value + + def compare_core(self,input_parma, output_path, **kwargs): + """ + Compares data from multiple JSON files and generates a comparison report. + + Args: + input_parma (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path", + "stack_path"). + output_path (str): The path where the output Excel report will be saved. + **kwargs: Additional keyword arguments including: + - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. + - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. + - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. + - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. + - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False. + - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False. 
+ + Returns: + """ + # get kwargs or set default value + stack_mode = kwargs.get('stack_mode', False) + auto_analyze = kwargs.get('auto_analyze', True) + suffix = kwargs.get('suffix', '') + fuzzy_match = kwargs.get('fuzzy_match', False) + summary_compare = kwargs.get('summary_compare', False) + md5_compare = kwargs.get('md5_compare', False) + + logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") + file_name = add_time_with_xlsx("compare_result" + suffix) + file_path = os.path.join(os.path.realpath(output_path), file_name) + check_file_not_exists(file_path) + highlight_dict = {'red_rows': [], 'yellow_rows': []} + with FileOpen(input_parma.get("npu_path"), "r") as npu_json, \ + FileOpen(input_parma.get("bench_path"), "r") as bench_json, \ + FileOpen(input_parma.get("stack_path"), "r") as stack_json: + result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, + summary_compare, md5_compare) + + if not md5_compare and not summary_compare: + result_df = self._do_multi_process(input_parma, result_df) + find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) + highlight_rows_xlsx(result_df, highlight_dict, file_path) + if auto_analyze: + advisor = Advisor(result_df, output_path) + advisor.analysis() + +def ms_compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False): + try: + summary_compare, md5_compare = task_dumppath_get(input_param) + check_configuration_param(stack_mode, auto_analyze, fuzzy_match) + create_directory(output_path) + check_compare_param(input_param, output_path, summary_compare, md5_compare) + except (CompareException, FileCheckException) as error: + logger.error('Compare failed. 
Please check the arguments and do it again!') + raise CompareException(error.code) from error + msComparator = MSComparator() + msComparator.compare_core(input_param, output_path, stack_mode=stack_mode, + auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, + md5_compare=md5_compare) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 23cb7294b8dc0d64e012f5fac2b863bdfe871bbe..54f640703c8e24c7cb8e7275d97928b9f554fa20 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -1,11 +1,15 @@ import os +from pathlib import Path -from msprobe.core.common.utils import Const -from msprobe.core.common.const import MsConst +from msprobe.core.common.const import Const +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.core.common.file_check import FileChecker, FileCheckConst, check_path_before_create class DebuggerConfig: def __init__(self, common_config, task_config): + self.execution_mode = None self.dump_path = common_config.dump_path self.task = common_config.task self.rank = [] if not common_config.rank else common_config.rank @@ -23,6 +27,19 @@ class DebuggerConfig: self.framework = Const.MS_FRAMEWORK self.summary_mode = task_config.summary_mode self.check() + self._make_dump_path_if_not_exists() + + if self.task == Const.FREE_BENCHMARK: + self.pert_type = (FreeBenchmarkConst.DEFAULT_PERT_TYPE + if not task_config.pert_mode else task_config.pert_mode) + self.handler_type = (FreeBenchmarkConst.DEFAULT_HANDLER_TYPE + if not task_config.handler_type else task_config.handler_type) + if self.handler_type == FreeBenchmarkConst.FIX_HANDLER_MODE and \ + self.pert_type != FreeBenchmarkConst.DEFAULT_PERT_TYPE: + raise 
ValueError("pert_mode must be improve_precision or empty when handler_type is fix, " + f"but got {self.pert_type}.") + self.dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL + self.stage = FreeBenchmarkConst.DEFAULT_STAGE def check(self): if not self.dump_path: @@ -50,3 +67,10 @@ class DebuggerConfig: for s in self.step: if not isinstance(s, int): raise ValueError(f"step element {s} should be int") + + def _make_dump_path_if_not_exists(self): + check_path_before_create(self.dump_path) + if not os.path.exists(self.dump_path): + Path(self.dump_path).mkdir(mode=0o750, exist_ok=True) + file_check = FileChecker(self.dump_path, FileCheckConst.DIR) + file_check.common_check() diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 5475dc3586c35687fec63b51f265ac83c0d33a87..04cc3345c5ecc740850e323e2f56e01011d6699c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -6,13 +6,18 @@ from msprobe.mindspore.service import Service from msprobe.mindspore.ms_config import parse_json_config from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.task_handler_factory import TaskHandlerFactory -from msprobe.core.common.const import MsConst +from msprobe.core.common.const import Const +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.mindspore.runtime import Runtime + +from msprobe.mindspore.grad_probe.grad_monitor import GradientMonitor class PrecisionDebugger: _instance = None + task_not_need_service = [Const.GRAD_PROBE] - def __new__(cls, config_path=None): + def __new__(cls, config_path=None, opt=None): if not cls._instance: cls._instance = super().__new__(cls) cls._instance.initialized = False @@ -24,11 +29,18 @@ class PrecisionDebugger: def __init__(self, config_path=None): if self.initialized: return + 
self.initialized = True if not config_path: config_path = os.path.join(os.path.dirname(__file__), "../../config/config.json") common_config, task_config = parse_json_config(config_path) + self.task = common_config.task + if self.task == Const.GRAD_PROBE: + self.gm = GradientMonitor(common_config, task_config) + return self.config = DebuggerConfig(common_config, task_config) - self.initialized = True + + Runtime.step_count = 0 + Runtime.is_running = False @staticmethod def _get_execution_mode(): @@ -41,35 +53,56 @@ class PrecisionDebugger: return MsConst.PYNATIVE_MODE @classmethod - def start(cls): + def start(cls, target=None): instance = cls._instance if not instance: raise Exception("No instance of PrecisionDebugger found.") + if instance.task in PrecisionDebugger.task_not_need_service: + return instance.config.execution_mode = instance._get_execution_mode() - if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API: + if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API and \ + instance.config.task != Const.FREE_BENCHMARK: if not instance.service: instance.service = Service(instance.config) - instance.service.start() + instance.service.start(target) else: if not instance.first_start: handler = TaskHandlerFactory.create(instance.config) handler.handle() instance.first_start = True + Runtime.is_running = True @classmethod def stop(cls): instance = cls._instance if not instance: raise Exception("PrecisionDebugger instance is not created.") + if instance.task == Const.GRAD_PROBE: + instance.gm.stop() + if instance.task in PrecisionDebugger.task_not_need_service: + return if instance.service: instance.service.stop() + Runtime.is_running = False @classmethod def step(cls): instance = cls._instance if not instance: raise Exception("PrecisionDebugger instance is not created.") + if instance.task in PrecisionDebugger.task_not_need_service: + return if instance.service: 
instance.service.step() + Runtime.step_count += 1 + + @classmethod + def monitor(cls, opt): + instance = cls._instance + if not instance: + raise Exception("PrecisionDebugger instance is not created.") + if instance.task != Const.GRAD_PROBE: + return + instance.gm.monitor(opt) diff --git a/debug/accuracy_tools/msprobe/mindspore/doc/dump.md b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md index 425d0683a268ebdcaf54a4f70b5e448bb1233f3c..ef2431b9c1cf878705bc5e2d7fe5fdbd55a33d47 100644 --- a/debug/accuracy_tools/msprobe/mindspore/doc/dump.md +++ b/debug/accuracy_tools/msprobe/mindspore/doc/dump.md @@ -35,10 +35,18 @@ PrecisionDebugger(config_path=None) **原型** ```Python -debugger.start() +debugger.start(model = None) ``` -该函数为类函数,可以使用debugger.start()也可以使用PrecisionDebugger.start()。 +该函数为类函数,可以使用debugger.start(model = None)也可以使用PrecisionDebugger.start(model = None) + + +**参数说明** + +| 参数名 | 说明 | 是否必选 | +| ----------- |---------------------------------------------------------------------------------------| -------- | +| model | 指具体的mindspore.nn.Cell,默认未配置,L1级别下传入model可以使能对primitive op的dump,否则无法dump primitive op。 | 否 | + ## 示例代码 diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/__init__.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfa31520d660c35b734b182d1bdace3e999c915 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/api_pynative_self_check.py @@ -0,0 +1,116 @@ +import os +import inspect +import importlib + +import yaml +import mindspore as ms +from mindspore.communication import comm_func + +from msprobe.core.common.const import Const +from 
msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.mindspore.free_benchmark.common.config import Config +from msprobe.core.common.file_check import check_path_length, FileOpen +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.free_benchmark.decorator.decorator_factory import decorate_forward_function + + +class ApiPyNativeSelFCheck: + def __init__(self, config: DebuggerConfig): + Config.is_enable = True + Config.handler_type = config.handler_type + Config.pert_type = config.pert_type + Config.stage = config.stage + Config.dump_level = config.dump_level + Config.steps = config.step + Config.ranks = config.rank + Config.dump_path = os.path.join(config.dump_path, "free_benchmark.csv") + check_path_length(Config.dump_path) + + self.api_list = config.list + all_api = get_supported_ops() + if not self.api_list: + self.api_list = all_api + else: + self.api_list = set(self.api_list) & all_api + + def handle(self): + for api_name in self.api_list: + hijack(api_name) + + +def get_supported_ops(): + supported_ops = [] + cur_path = os.path.dirname(os.path.realpath(__file__)) + yaml_path = os.path.join(cur_path, "data", "support_wrap_ops.yaml") + + for k, v in FreeBenchmarkConst.API_PREFIX_DICT.items(): + with FileOpen(yaml_path, 'r') as f: + ops = yaml.safe_load(f).get(k) + if ops: + ops = [v + i for i in ops] + supported_ops += ops + + _all_functional_ops = [] + ms_ops = dir(ms.ops) + ms_ops = [FreeBenchmarkConst.OPS_PREFIX + i for i in ms_ops] + _all_functional_ops += ms_ops + + ms_tensor = dir(ms.Tensor) + ms_tensor = [FreeBenchmarkConst.Tensor_PREFIX + i for i in ms_tensor] + _all_functional_ops += ms_tensor + + ms_mint = dir(ms.mint) + ms_mint = [FreeBenchmarkConst.MINT_PREFIX + i for i in ms_mint] + _all_functional_ops += ms_mint + + ms_mint_nn_func = dir(ms.mint.nn.functional) + ms_mint_nn_func = [FreeBenchmarkConst.MINT_NN_FUNC_PREFIX + i for i in 
ms_mint_nn_func] + _all_functional_ops += ms_mint_nn_func + + ms_communication = dir(comm_func) + ms_communication = [FreeBenchmarkConst.COMM_PREFIX + i for i in ms_communication] + _all_functional_ops += ms_communication + + return set(supported_ops) & set(_all_functional_ops) + + +def get_decorate_func(): + return decorate_forward_function + + +def is_func_support_decorate(orig_func): + return not inspect.isclass(orig_func) and callable(orig_func) + + +def get_wrapper_obj(orig_func, api_name): + if is_func_support_decorate(orig_func): + wrapped_obj = get_decorate_func()(orig_func, api_name) + else: + wrapped_obj = orig_func + return wrapped_obj + + +def get_module(api_name): + func_name_list = api_name.split(Const.SEP) + func_name = func_name_list[-1] + module_obj = importlib.import_module(func_name_list[0]) + for i, module_name in enumerate(func_name_list[1:-1]): + if not hasattr(module_obj, module_name): + importlib.import_module(f"{Const.SEP.join(func_name_list[:i+2])}") + module_obj = getattr(module_obj, module_name) + orig_func = getattr(module_obj, func_name) + + return module_obj, orig_func + + +def hijack(api_name): + if not api_name.strip(): + return + try: + func_name = api_name.split(Const.SEP)[-1] + module_obj, origin_func = get_module(api_name) + wrapped_obj = get_wrapper_obj(origin_func, api_name) + setattr(module_obj, func_name, wrapped_obj) + except Exception as e: + logger.error(f"Failed decorator {api_name}: {e}") diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/__init__.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/config.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..85f684d8164903d91a9b0a703749fcf83b9e5249 --- 
/dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/config.py @@ -0,0 +1,12 @@ +from msprobe.mindspore.common.const import FreeBenchmarkConst + + +class Config: + is_enable: bool = False + handler_type = FreeBenchmarkConst.DEFAULT_HANDLER_TYPE + pert_type = FreeBenchmarkConst.DEFAULT_PERT_TYPE + stage = FreeBenchmarkConst.DEFAULT_STAGE + dump_level = FreeBenchmarkConst.DEFAULT_DUMP_LEVEL + steps: list = [] + ranks: list = [] + dump_path: str = "" diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/handler_params.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/handler_params.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1733b986062d3540ef36624b36c59806a8976e --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/handler_params.py @@ -0,0 +1,17 @@ +from typing import Optional, Any, Tuple, Dict, Callable + + +class HandlerParams: + """ + 参数结合体 + + """ + args: Optional[Tuple] = None + kwargs: Optional[Dict] = None + index: Optional[int] = None + original_result: Optional[Any] = None + fuzzed_result: Optional[Any] = None + is_consistent: Optional[bool] = True + save_flag: Optional[bool] = True + fuzzed_value: Optional[Any] = None + original_func: Optional[Callable] = None diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb062800b9f930bfbdf5c3f8b3a2799b0d95566 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/common/utils.py @@ -0,0 +1,71 @@ +from typing import Any +from typing import Optional +from dataclasses import dataclass + +import mindspore as ms +from mindspore import Tensor + +from msprobe.mindspore.runtime import Runtime +from msprobe.mindspore.common.const import FreeBenchmarkConst +from .config import Config +from .handler_params import 
class Tools:
    """Small dtype helpers shared by the free-benchmark components."""

    @staticmethod
    def get_first_tensor_dtype(tensor_seq: Any):
        """Return the dtype of tensor_seq itself, or of its first Tensor
        element when it is a list/tuple; raise if no Tensor is found."""
        if isinstance(tensor_seq, Tensor):
            return tensor_seq.dtype
        if isinstance(tensor_seq, (list, tuple)):
            for item in tensor_seq:
                if isinstance(item, Tensor):
                    return item.dtype
        raise Exception("The sequence does not contain tensors.")

    @staticmethod
    def get_default_error_threshold(dtype):
        """Relative-error threshold for dtype under the configured
        perturbation type (NO_CHANGE uses its own fixed threshold)."""
        if Config.pert_type == FreeBenchmarkConst.NO_CHANGE:
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
        return FreeBenchmarkConst.ERROR_THRESHOLD.get(
            dtype, FreeBenchmarkConst.ERROR_THRESHOLD.get(ms.float32))


@dataclass
class UnequalRow:
    """One CSV row describing an original/fuzzed output mismatch."""
    rank: Optional[int] = None
    pert_type: Optional[str] = None
    stage: Optional[str] = None
    step: Optional[int] = None
    api_name: Optional[str] = None
    max_rel: Optional[float] = None
    dtype: Optional[str] = None
    shape: Optional[str] = None
    output_index: Optional[int] = None


def make_unequal_row(
        api_name: str,
        params: 'HandlerParams',
        ratio: float = None,
        index: int = None,
):
    """Build an UnequalRow for api_name from params.

    index selects one element of a multi-output result; None means the
    result itself is a single tensor.
    """
    row = UnequalRow(
        api_name=api_name,
        pert_type=Config.pert_type,
        output_index=index,
        stage=Config.stage,
        step=Runtime.step_count
    )
    if isinstance(ratio, float):
        row.max_rel = ratio - 1
    original_tensor = params.original_result
    fuzzed_tensor = params.fuzzed_result
    # Bugfix: the original tested `if index:`, which wrongly skipped indexing
    # for output index 0; compare against None explicitly.
    if index is not None:
        original_tensor = original_tensor[index]
        fuzzed_tensor = fuzzed_tensor[index]
        row.output_index = index
    if isinstance(original_tensor, Tensor):
        row.dtype = original_tensor.dtype
        row.shape = original_tensor.shape
    row.rank = Runtime.rank_id if Runtime.rank_id != -1 else None
    return row
b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/data/support_wrap_ops.yaml @@ -0,0 +1,842 @@ +# List of apis that support self check + +communication: + - all_gather_into_tensor + - gather_into_tensor + - all_reduce + - reduce + - reduce_scatter_tensor + +ops: + - adaptive_avg_pool1d + - adaptive_avg_pool2d + - adaptive_avg_pool3d + - adaptive_max_pool1d + - adaptive_max_pool2d + - avg_pool1d + - avg_pool2d + - avg_pool3d + - batch_norm + - bias_add + - ctc_greedy_decoder + - conv1d + - conv2d + - conv3d + - deformable_conv2d + - dense + - dropout + - dropout1d + - dropout2d + - dropout3d + - flatten + - fold + - fractional_max_pool3d + - lp_pool1d + - lp_pool2d + - lrn + - max_pool2d + - max_pool3d + - max_unpool1d + - max_unpool2d + - max_unpool3d + - unfold + - binary_cross_entropy + - binary_cross_entropy_with_logits + - cosine_embedding_loss + - cross_entropy + - ctc_loss + - gaussian_nll_loss + - hinge_embedding_loss + - huber_loss + - kl_div + - l1_loss + - margin_ranking_loss + - mse_loss + - multi_margin_loss + - multilabel_margin_loss + - multilabel_soft_margin_loss + - nll_loss + - smooth_l1_loss + - triplet_margin_loss + - elu + - fast_gelu + - gelu + - glu + - gumbel_softmax + - hardshrink + - hardsigmoid + - hardswish + - hardtanh + - leaky_relu + - log_softmax + - logsigmoid + - mish + - prelu + - relu + - relu6 + - rrelu + - selu + - sigmoid + - silu + - softmax + - softmin + - softshrink + - softsign + - tanh + - threshold + - cdist + - dist + - pdist + - choice_with_mask + - random_categorical + - log_uniform_candidate_sampler + - uniform_candidate_sampler + - affine_grid + - bounding_box_decode + - bounding_box_encode + - col2im + - check_valid + - crop_and_resize + - grid_sample + - interpolate + - iou + - pad + - padding + - pixel_shuffle + - pixel_unshuffle + - upsample + - abs + - absolute + - accumulate_n + - acos + - arccos + - acosh + - add + - addcdiv + - addcmul + - addmv + - addn + - angle + - arccosh + - arcsin + - arcsinh + - 
arctan + - arctanh + - arctan2 + - asin + - asinh + - atan + - atan2 + - atanh + - atleast_1d + - atleast_2d + - atleast_3d + - bessel_i0 + - bessel_i0e + - bessel_i1 + - bessel_i1e + - bessel_j0 + - bessel_j1 + - bessel_k0 + - bessel_k0e + - bessel_k1 + - bessel_k1e + - bessel_y0 + - bessel_y1 + - bitwise_and + - bitwise_left_shift + - bitwise_or + - bitwise_right_shift + - bitwise_xor + - ceil + - clamp + - clip + - combinations + - copysign + - cos + - cosh + - cosine_similarity + - cov + - diag_embed + - diff + - deg2rad + - digamma + - div + - divide + - erf + - erfc + - erfinv + - exp + - exp2 + - expm1 + - floor + - floor_div + - floor_mod + - float_power + - fmod + - frac + - gcd + - hypot + - igamma + - igammac + - imag + - i0 + - inv + - invert + - lcm + - ldexp + - lerp + - log + - log2 + - log10 + - log1p + - logaddexp + - logaddexp2 + - logical_and + - logical_not + - logical_or + - logical_xor + - logit + - mul + - multiply + - mvlgamma + - neg + - negative + - nextafter + - polar + - polygamma + - positive + - pow + - rad2deg + - ravel + - real + - reciprocal + - remainder + - rot90 + - round + - rsqrt + - sgn + - sign + - signbit + - sin + - sinc + - sinh + - sqrt + - square + - sub + - subtract + - t + - tan + - tanhshrink + - trapz + - tril_indices + - triu_indices + - true_divide + - trunc + - truncate_div + - truncate_mod + - xdivy + - xlogy + - zeta + - all + - amax + - amin + - aminmax + - any + - argmax + - argmin + - cummax + - cummin + - cumprod + - cumsum + - fmax + - histc + - logsumexp + - max + - mean + - median + - min + - norm + - prod + - std + - std_mean + - var + - var_mean + - argsort + - approximate_equal + - equal + - ge + - greater + - greater_equal + - gt + - intopk + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - isreal + - le + - less + - less_equal + - lt + - maximum + - minimum + - msort + - ne + - not_equal + - searchsorted + - topk + - bmm + - addbmm + - addmm + - baddbmm + - addr + - adjoint + 
- cholesky + - cholesky_solve + - batch_dot + - dot + - eig + - inner + - inverse + - geqrf + - ger + - kron + - lu_solve + - lu_unpack + - matmul + - matrix_solve + - matrix_band_part + - matrix_diag + - matrix_diag_part + - matrix_set_diag + - mm + - mv + - outer + - orgqr + - ormqr + - pinv + - svd + - tensor_dot + - logdet + - slogdet + - qr + - trace + - bartlett_window + - blackman_window + - hamming_window + - hann_window + - kaiser_window + - eye + - fill + - full + - full_like + - linspace + - logspace + - one_hot + - arange + - range + - heaviside + - bernoulli + - gamma + - laplace + - multinomial + - multinomial_with_replacement + - rand + - rand_like + - randint + - randint_like + - randn + - randn_like + - random_gamma + - random_poisson + - randperm + - standard_laplace + - standard_normal + - uniform + - argwhere + - batch_to_space_nd + - bincount + - block_diag + - broadcast_to + - cat + - channel_shuffle + - chunk + - column_stack + - concat + - conj + - count_nonzero + - deepcopy + - diag + - diagflat + - diagonal + - dyn_shape + - dsplit + - dstack + - einsum + - expand + - expand_dims + - flip + - fliplr + - flipud + - gather_d + - gather_elements + - gather_nd + - hsplit + - hstack + - masked_fill + - masked_select + - meshgrid + - moveaxis + - movedim + - narrow + - nan_to_num + - nansum + - normal + - nonzero + - population_count + - rank + - repeat_elements + - repeat_interleave + - reshape + - reverse + - reverse_sequence + - roll + - select + - sequence_mask + - shuffle + - size + - slice + - sort + - space_to_batch_nd + - sparse_segment_mean + - split + - squeeze + - stack + - strided_slice + - sum + - swapaxes + - swapdims + - tensor_split + - tile + - tril + - triu + - transpose + - unbind + - unique + - unique_consecutive + - unique_with_pad + - unsorted_segment_max + - unsorted_segment_min + - unsorted_segment_prod + - unsorted_segment_sum + - unsqueeze + - unstack + - view_as_real + - vsplit + - vstack + - where + - cross + - renorm 
+ - tuple_to_array + - clip_by_global_norm + - clip_by_value + - derivative + - jet + +Tensor: + - __abs__ + - __add__ + - __and__ + - __iadd__ + - __ifloordiv__ + - __imatmul__ + - __imod__ + - __imul__ + - __isub__ + - __matmul__ + - __mod__ + - __mul__ + - __neg__ + - __or__ + - __pow__ + - __radd__ + - __rmatmul__ + - __rmod__ + - __rmul__ + - __rpow__ + - __rsub__ + - __sub__ + - __truediv__ + - __xor__ + - abs + - absolute + - acos + - acosh + - add + - addbmm + - addcdiv + - addcmul + - addmm + - addmv + - addr + - all + - amax + - amin + - any + - arccos + - arccosh + - argmax + - angle + - arcsin + - arcsinh + - arctan + - arctanh + - argmin + - argsort + - asin + - asinh + - atan + - atan2 + - atanh + - baddbmm + - bernoulli + - bincount + - bitwise_and + - bitwise_or + - bitwise_xor + - bmm + - broadcast_to + - ceil + - cholesky_solve + - cholesky + - clamp + - clip + - conj + - copysign + - cos + - cosh + - cross + - cummax + - cummin + - cumprod + - cumsum + - deg2rad + - diag + - diagflat + - diff + - digamma + - div + - divide + - equal + - erf + - erfc + - erfinv + - exp + - expand_as + - expm1 + - flip + - fliplr + - flipud + - float_power + - floor + - fmod + - frac + - gather_elements + - geqrf + - ger + - greater + - greater_equal + - half + - hardshrink + - heaviside + - histc + - hypot + - i0 + - igamma + - igammac + - imag + - index_add + - index_fill + - index_put + - index_select + - inner + - int + - inverse + - item + - lcm + - ldexp + - lerp + - log + - log10 + - log1p + - log2 + - logaddexp + - logaddexp2 + - logdet + - logical_and + - logical_not + - logical_or + - logical_xor + - logit + - logsumexp + - long + - masked_fill + - masked_scatter + - masked_select + - matmul + - max + - maximum + - mean + - median + - min + - minimum + - moveaxis + - movedim + - msort + - multinomial + - multiply + - mvlgamma + - nan_to_num + - nansum + - narrow + - neg + - negative + - nelement + - new_ones + - new_zeros + - nextafter + - norm + - 
nonzero + - not_equal + - ormqr + - permute + - pow + - prod + - qr + - ravel + - real + - reciprocal + - remainder + - renorm + - rad2deg + - tile + - repeat_interleave + - reshape + - reshape + - round + - rot90 + - rsqrt + - sum_to_size + - scatter + - sgn + - short + - sigmoid + - sign + - signbit + - sin + - sinc + - sinh + - slogdet + - sort + - split + - sqrt + - square + - squeeze + - std + - subtract + - subtract + - svd + - swapaxes + - swapdims + - t + - take + - tan + - tanh + - trace + - swapaxes + - tile + - topk + - tril + - tensor_split + - transpose + - true_divide + - trunc + - unbind + - unique_consecutive + - unsqueeze + - var + - view + - where + - xlogy + - from_numpy + - std + - take + - var + - all + - any + - copy + - diagonal + - flatten + - resize + - sum + +mint: + - abs + - absolute_import + - add + - add_ex + - all + - any + - any_ex + - arange + - argmax + - avg_pool2d + - baddbmm + - baddbmm_ex + - batch_norm + - binary_cross_entropy_with_logits + - bitwise_and + - bitwise_or + - bitwise_xor + - bmm + - broadcast_to + - cat + - cat_ex + - ceil + - chunk + - clamp + - conv2d + - conv_transpose2d + - cos + - cross + - cummax + - cummin + - cumsum + - div + - divide + - dropout + - embedding + - eq + - erf + - erfinv + - exp + - flatten + - flip + - flip_ex + - fold + - full + - gather + - gelu + - greater + - grid_sample + - group_norm + - gt + - index_select + - interpolate + - isclose + - isfinite + - layer_norm + - le + - leaky_relu + - less + - less_equal + - linear + - linspace + - log + - logical_and + - logical_not + - logical_or + - lt + - masked_select + - matmul + - max + - max_pool2d + - maximum + - mean + - mean_ex + - min + - minimum + - mul + - ne + - neg + - negative + - nonzero + - normal + - one_hot + - ones + - ones_ex + - ones_like + - pad + - permute + - permute_ex + - pow + - prod + - reciprocal + - relu + - remainder + - repeat_interleave + - rsqrt + - searchsorted + - sigmoid + - silu + - sin + - softmax + - 
softplus + - sort + - split + - sqrt + - sqrt_ex + - square + - stack + - sub + - sub_ex + - sum + - tanh + - tile + - topk + - tril + - triu + - unfold + - unique + - where + - xlogy + - zeros + - zeros_ex + - zeros_like + +mint.nn.functional: + - absolute_import + - avg_pool2d + - batch_norm + - batch_norm_ex + - bce_with_logits + - binary_cross_entropy_with_logits + - conv_transpose2d + - dense + - dropout + - embedding + - fold + - gelu + - grid_sample + - group_norm + - interpolate + - layer_norm + - leaky_relu + - linear + - max_pool2d + - max_pool2d_ex + - normal + - one_hot + - one_hot_ext + - pad + - relu + - sigmoid + - silu + - softmax + - softmax_ex + - softplus + - tanh + - unfold diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/decorator/__init__.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/decorator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/decorator/dec_forward.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/decorator/dec_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..78661d7fca6b40247eff15bcb232e997db626dd3 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/decorator/dec_forward.py @@ -0,0 +1,42 @@ +from msprobe.mindspore.free_benchmark.common.config import Config +from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams +from msprobe.mindspore.free_benchmark.handler.handler_factory import HandlerFactory +from msprobe.mindspore.free_benchmark.perturbation.perturbation_factory import PerturbationFactory + + +class ForwardSelfChecker: + + def __init__(self, api_name: str): + self.api_name = api_name + + def handle(self, params: HandlerParams): + """ + 装饰器实际执行逻辑 + + """ + perturbation = 
class ForwardSelfChecker:
    """Runs a forward api twice (original and perturbed inputs) and hands both
    results to the configured handler for comparison."""

    def __init__(self, api_name: str):
        self.api_name = api_name

    def handle(self, params: 'HandlerParams'):
        """Actual decorator logic: perturb, run the original, then compare.

        Perturbations return False when they could not fuzz the input, in
        which case the original result is passed through untouched.
        """
        perturbation = PerturbationFactory.create(self.api_name)
        params.fuzzed_result = perturbation.handle(params)
        params.original_result = params.original_func(*params.args, **params.kwargs)
        if params.fuzzed_result is not False:
            return self.deal_fuzzed_and_original_result(params)
        return params.original_result

    def get_compare_data(self, params: 'HandlerParams'):
        """For communication apis, rebind the pair of values to compare."""
        if self.api_name not in FreeBenchmarkConst.COMMUNICATION_API_LIST:
            return
        # Communication-api specific handling below.
        params.fuzzed_result = params.fuzzed_value
        if Config.pert_type == FreeBenchmarkConst.IMPROVE_PRECISION:
            params.original_result = params.args
        else:
            params.original_result = params.args[params.index]

    def deal_fuzzed_and_original_result(self, params: 'HandlerParams'):
        """Run the handler; communication apis still return the original
        result to the caller regardless of the handler's output."""
        original_result = params.original_result
        self.get_compare_data(params)
        handler = HandlerFactory.create(self.api_name)
        result = handler.handle(params)
        if self.api_name in FreeBenchmarkConst.COMMUNICATION_API_LIST:
            result = original_result
        return result


def decorate(original_func, decorate_func, api_name=None):
    """Wrap original_func so decorate_func runs when checking is enabled.

    Any failure inside the checking path is logged and the original function
    is executed instead, so user training is never broken by the tool.
    """
    @wraps(original_func)
    def fuzz_wrapper(*args, **kwargs):

        def __exec_decorate_func():
            params = data_pre_deal(api_name, original_func, *args, **kwargs)
            return decorate_func(params)

        try:
            if Runtime.rank_id == -1:
                # Bugfix: os.environ values are strings; keep rank_id an int
                # so later `rank_id not in Config.ranks` comparisons work.
                Runtime.rank_id = int(os.environ.get("RANK_ID", -1))
            if need_wrapper_func():
                logger.info(f"[{api_name}] is checking.")
                return __exec_decorate_func()
        except Exception as e:
            logger.error(f"[{api_name}] Error: {str(e)}")
            logger.error(f"[{api_name}] Error detail: {traceback.format_exc()}")

        return original_func(*args, **kwargs)

    return fuzz_wrapper


def decorate_forward_function(func, api_name=None):
    """Decorator for forward functions; defaults api_name to the function name."""
    if not api_name:
        api_name = func.__name__

    def forward_func(params: 'HandlerParams'):
        return ForwardSelfChecker(api_name).handle(params)

    return decorate(func, forward_func, api_name)


def stack_depth_check() -> bool:
    """Return False when we are already inside an enclosing fuzz_wrapper call
    (prevents recursive self-checking)."""
    nested_depth = 1
    frame = sys._getframe(1)
    while frame:
        if frame.f_code.co_name == "fuzz_wrapper":
            nested_depth -= 1
            if nested_depth < 0:
                return False
        frame = frame.f_back
    return True


def get_target_arg_index(args: tuple) -> int:
    """Index of the first argument a perturbation can act on.

    Floating-point tensors and python containers qualify; returns -1 when
    nothing qualifies. Bugfix: isinstance checks now use the builtin classes
    instead of the deprecated typing.List/Tuple/Dict generics.
    """
    for i, arg in enumerate(args):
        # Container check first: cheap, and a container is never a Tensor.
        if isinstance(arg, (list, tuple, dict)):
            return i
        if ops.is_tensor(arg):
            if not ops.is_floating_point(arg):
                continue
            return i
    return -1


def data_pre_deal(api_name, func, *args, **kwargs):
    """Build HandlerParams for one call; raise when no input can be perturbed."""
    params = HandlerParams()
    params.args = args
    params.kwargs = kwargs
    params.original_func = func
    index = get_target_arg_index(args)
    if index == -1:
        raise Exception(f"{api_name} has no supported input type")
    params.index = index
    return params


def need_wrapper_func():
    """Whether checking should run for the current step/rank."""
    if not (Runtime.is_running and Config.is_enable):
        return False
    if not stack_depth_check():
        return False
    if Config.steps and Runtime.step_count not in Config.steps:
        return False
    if Config.ranks and Runtime.rank_id != -1 and Runtime.rank_id not in Config.ranks:
        return False
    return True
class BaseHandler(ABC):
    """Common comparison logic shared by the concrete result handlers."""

    def __init__(self, api_name: str):
        self.api_name = api_name

    @staticmethod
    def pre_calculate(original_output, fuzzed_output):
        """Cast the original output to the fuzzed dtype and pick the absolute
        tolerance configured for that dtype."""
        abs_tol = FreeBenchmarkConst.PERT_VALUE_DICT.get(
            fuzzed_output.dtype,
            FreeBenchmarkConst.PERT_VALUE_DICT.get(ms.float32))
        return original_output.to(fuzzed_output.dtype), fuzzed_output, abs_tol

    @staticmethod
    def get_threshold(dtype):
        """Relative-error threshold for dtype."""
        return Tools.get_default_error_threshold(dtype)

    @staticmethod
    def convert_overflow_ratio_to_consistent(ratio):
        """Map nan/inf ratios onto the 'no change' threshold; pass others through."""
        if math.isnan(ratio) or math.isinf(ratio):
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
        return ratio

    @staticmethod
    def get_endless_norm(first_tensor, second_tensor, abs_tol):
        """Max elementwise ratio between the two tensors (both directions).

        A negative minimum ratio signals a sign flip and returns the dedicated
        flipping ratio instead.
        """
        if first_tensor.dtype != ms.bfloat16 and second_tensor.dtype != ms.bfloat16:
            ratio_tensor1 = ops.where(ops.abs(second_tensor) > abs_tol,
                                      ops.div(first_tensor, second_tensor), 1)
            ratio_tensor2 = ops.where(ops.abs(first_tensor) > abs_tol,
                                      ops.div(second_tensor, first_tensor), 1)
        else:
            # bfloat16 lacks precision for the ratio math; do it in float32.
            ratio_tensor1 = ops.where(ops.abs(second_tensor).to(ms.float32) > abs_tol,
                                      ops.div(first_tensor.to(ms.float32),
                                              second_tensor.to(ms.float32)), 1)
            ratio_tensor2 = ops.where(ops.abs(first_tensor).to(ms.float32) > abs_tol,
                                      ops.div(second_tensor.to(ms.float32),
                                              first_tensor.to(ms.float32)), 1)
        to_consistent = BaseHandler.convert_overflow_ratio_to_consistent
        norm1 = to_consistent(ops.max(ratio_tensor1)[0].to(ms.float32).item())
        norm2 = to_consistent(ops.max(ratio_tensor2)[0].to(ms.float32).item())
        norm3 = to_consistent(ops.min(ratio_tensor1)[0].to(ms.float32).item())
        if norm3 < 0:
            return FreeBenchmarkConst.SYMBOL_FLIPPING_RATIO
        return max(norm1, norm2)

    @staticmethod
    def ratio_calculate(original_output, fuzzed_output) -> float:
        """Overall relative ratio between original and fuzzed outputs."""
        try:
            original_output, fuzzed_output, abs_tol = BaseHandler.pre_calculate(
                original_output, fuzzed_output)
        except Exception as e:
            logger.error(f"When computing ratio, y1 or y2 dtype is not supported {str(e)}")
            return FreeBenchmarkConst.NO_CHANGE_ERROR_THRESHOLD
        return BaseHandler.get_endless_norm(original_output, fuzzed_output, abs_tol ** 0.5)

    @staticmethod
    def npu_compare(original_output, fuzzed_output):
        """Return (is_consistent, ratio) for one output pair.

        Unsupported output types are logged and reported as consistent.
        """
        if not isinstance(fuzzed_output, Tensor):
            logger.error(f"The compare for output type `{type(fuzzed_output)}` is not supported")
            return True, 1.0
        err_thd = BaseHandler.get_threshold(original_output.dtype)
        ratio = BaseHandler.ratio_calculate(original_output, fuzzed_output)
        return err_thd >= ratio >= 1.0 / err_thd, ratio

    @staticmethod
    def is_float_tensor(output) -> bool:
        """True if output is (or contains, for list/tuple) a floating-point Tensor."""
        if isinstance(output, Tensor) and ops.is_floating_point(output):
            return True
        if isinstance(output, (list, tuple)):
            return any(isinstance(item, Tensor) and ops.is_floating_point(item)
                       for item in output)
        return False

    @abstractmethod
    def handle(self, params: 'HandlerParams') -> Any:
        pass
class FixHandler:
    """Handler that substitutes the (higher-precision) fuzzed output for the
    original one, cast back to the original dtypes."""

    def __init__(self, api_name: str):
        self.api_name = api_name

    @staticmethod
    def use_fuzzed_result(original_result, fuzzed_result):
        """Recursively replace original values with fuzzed ones, preserving
        container structure and tensor dtypes; non-tensor leaves pass through."""
        if isinstance(original_result, Tensor):
            return fuzzed_result.to(original_result.dtype)
        if isinstance(original_result, dict):
            return {
                key: FixHandler.use_fuzzed_result(value, fuzzed_result[key])
                for key, value in original_result.items()
            }
        if isinstance(original_result, (tuple, list)):
            fixed_items = [
                FixHandler.use_fuzzed_result(value, fuzzed_result[pos])
                for pos, value in enumerate(original_result)
            ]
            return type(original_result)(fixed_items)
        return original_result

    def handle(self, params: 'HandlerParams'):
        """Return the fixed result; fall back to the original on any error."""
        try:
            return FixHandler.use_fuzzed_result(params.original_result,
                                                params.fuzzed_result)
        except Exception as e:
            logger.error(f"{self.api_name} failed to fix.")
            logger.error(str(e))
            return params.original_result
class HandlerFactory:
    """Creates the result handler selected by Config.handler_type."""

    result_handlers = {
        FreeBenchmarkConst.CHECK: CheckHandler,
        FreeBenchmarkConst.FIX: FixHandler,
    }

    @staticmethod
    def create(api_name: str):
        """Instantiate the configured handler for api_name.

        Raises with a descriptive message for unknown handler types (bugfix:
        the original raised a bare ``Exception`` with no arguments).
        """
        handler = HandlerFactory.result_handlers.get(Config.handler_type)
        if handler is None:
            logger.error(f"{Config.handler_type} is not supported.")
            raise Exception(f"{Config.handler_type} is not supported.")
        return handler(api_name)
""" + if isinstance(inputs, Tensor): + noise = self._get_noise(inputs) + if noise is not False: + result = ops.where(ops.abs(inputs) > self.perturbation_value ** 0.5, + ops.add(noise, inputs), inputs) + result = result.type(dtype=inputs.dtype) + self.is_fuzzed = True + return result + + if isinstance(inputs, dict): + return {k: self.add_noise(v) for k, v in inputs.items()} + + if isinstance(inputs, (list, tuple)): + return [self.add_noise(v) for v in inputs] + + return inputs + + def _get_noise(self, input): + """ + 得到要添加的噪声值 + + """ + if self.is_fuzzed: + return False + if not ops.is_floating_point(input) or ops.numel(input) == 0: + return False + + pert_value = FreeBenchmarkConst.PERT_VALUE_DICT.get(input.dtype) + if not pert_value: + return False + else: + self.perturbation_value = pert_value + + max_val = ops.max(ops.abs(input))[0].item() + if max_val < pert_value: + return False + + noise = ops.full(input.shape, self.perturbation_value, dtype=input.dtype) + return noise diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py new file mode 100644 index 0000000000000000000000000000000000000000..becfe2964a3e8f258d57ff9539ff89c03a817150 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/base_perturbation.py @@ -0,0 +1,21 @@ +from typing import Any + +from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams + + +class BasePerturbation: + + def __init__(self, api_name: str): + self.api_name = api_name + self.is_fuzzed = False + self.perturbation_value = None + + @staticmethod + def get_fuzzed_result(params: HandlerParams): + args_front = params.args[:params.index] + args_rear = params.args[params.index + 1:] + fuzzed_result = params.original_func(*args_front, params.fuzzed_value, *args_rear, **params.kwargs) + return fuzzed_result + + def handler(self, params: HandlerParams) 
-> Any: + pass diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/bit_noise.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/bit_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..65202e0f66f7f57b71bffa1115c9f20260123edc --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/perturbation/bit_noise.py @@ -0,0 +1,63 @@ +from typing import Any + +import numpy as np +from mindspore import Tensor, ops + +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.mindspore.free_benchmark.common.handler_params import HandlerParams +from msprobe.mindspore.free_benchmark.perturbation.base_perturbation import BasePerturbation + + +class BitNoisePerturbation(BasePerturbation): + + def add_bit_noise(self, inputs) -> Any: + if isinstance(inputs, Tensor): + bit_len_type = self._get_bit_len_type(inputs) + if bit_len_type is not False: + sub_normal_np = np.finfo(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.get(inputs.dtype)).smallest_normal + sub_normal = Tensor(sub_normal_np) + noise_type = list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.keys())[ + list(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.values()).index(bit_len_type)] + noise = ops.full(inputs.shape, 1, dtype=noise_type) + input_np = inputs.asnumpy() + input_np_int = input_np.view(bit_len_type) + result = Tensor(input_np_int) + result = ops.where(ops.abs(inputs) > sub_normal, + ops.bitwise_xor(result, noise), result) + result_np = result.asnumpy() + result_np_float = result_np.view(FreeBenchmarkConst.MS_NUMPY_DTYPE_DICT.get(inputs.dtype)) + self.is_fuzzed = True + return Tensor(result_np_float) + + if isinstance(inputs, dict): + return {k: self.add_bit_noise(v) for k, v in inputs.items()} + if isinstance(inputs, (tuple, list)): + return type(inputs)([self.add_bit_noise(v) for v in inputs]) + return inputs + + def handle(self, params: HandlerParams) -> any: + args = 
class ImprovePrecisionPerturbation(BasePerturbation):
    """Perturbation that re-runs the api with float tensors below float32
    precision promoted to float32."""

    def improve_tensor_precision(self, target_tensor):
        """Recursively promote sub-float32 floating tensors to float32;
        everything else is returned unchanged."""
        is_low_precision_float = (
            isinstance(target_tensor, Tensor)
            and ops.is_floating_point(target_tensor)
            and target_tensor.dtype not in [ms.float64, ms.float32]
        )
        if is_low_precision_float:
            self.is_fuzzed = True
            return target_tensor.to(ms.float32)
        if isinstance(target_tensor, dict):
            return {key: self.improve_tensor_precision(value)
                    for key, value in target_tensor.items()}
        if isinstance(target_tensor, (tuple, list)):
            promoted = [self.improve_tensor_precision(value)
                        for value in target_tensor]
            return type(target_tensor)(promoted)
        return target_tensor

    def handle(self, params: 'HandlerParams') -> Any:
        """Run the original function on the precision-improved inputs, or
        return False when nothing could be promoted."""
        promoted_args = self.improve_tensor_precision(params.args)
        promoted_kwargs = self.improve_tensor_precision(params.kwargs)
        if self.api_name in FreeBenchmarkConst.COMMUNICATION_API_LIST:
            params.fuzzed_value = promoted_args
        if not self.is_fuzzed:
            logger.warning(f"{self.api_name} can not improve precision.")
            return False
        return params.original_func(*promoted_args, **promoted_kwargs)


class NoChangePerturbation(BasePerturbation):
    """Baseline perturbation: re-runs the api with the argument unchanged."""

    def handle(self, params: 'HandlerParams') -> Any:
        params.fuzzed_value = params.args[params.index]
        self.is_fuzzed = True
        return self.get_fuzzed_result(params)
AddNoisePerturbation +from .bit_noise import BitNoisePerturbation +from .no_change import NoChangePerturbation +from .improve_precision import ImprovePrecisionPerturbation + + +class PerturbationFactory: + """ + 扰动工厂类 + + """ + perturbations = { + FreeBenchmarkConst.IMPROVE_PRECISION: ImprovePrecisionPerturbation, + FreeBenchmarkConst.ADD_NOISE: AddNoisePerturbation, + FreeBenchmarkConst.BIT_NOISE: BitNoisePerturbation, + FreeBenchmarkConst.NO_CHANGE: NoChangePerturbation, + } + + @staticmethod + def create(api_name: str): + perturbation = PerturbationFactory.perturbations.get(Config.pert_type) + if perturbation: + return perturbation(api_name) + else: + raise Exception(f'{Config.pert_type} is a invalid perturbation type') diff --git a/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..e485887ce6ede7fbe9d3e2ac22b62455924d0730 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/free_benchmark/self_check_tool_factory.py @@ -0,0 +1,33 @@ +from msprobe.mindspore.common.const import Const +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.mindspore.free_benchmark.api_pynative_self_check import ApiPyNativeSelFCheck + + +class SelfCheckToolFactory: + tools = { + Const.CELL: { + Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_GE_MODE: None, + Const.PYNATIVE_MODE: None + }, + Const.API: { + Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_GE_MODE: None, + Const.PYNATIVE_MODE: ApiPyNativeSelFCheck + }, + Const.KERNEL: { + Const.GRAPH_KBYK_MODE: None, + Const.GRAPH_GE_MODE: None, + Const.PYNATIVE_MODE: None + } + } + + @staticmethod + def create(config: DebuggerConfig): + tool = SelfCheckToolFactory.tools.get(config.level) + if not tool: + raise Exception(f"{config.level} is not supported.") + tool = tool.get(config.execution_mode) + if not tool: + raise 
Exception(f"Task free_benchmark is not supported in this mode: {config.execution_mode}.") + return tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/__init__.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py new file mode 100644 index 0000000000000000000000000000000000000000..16d0bd0b86298242b8aca1a57318290d11df9b76 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/global_context.py @@ -0,0 +1,91 @@ +import os +import threading +from typing import Dict, Union + +from msprobe.core.grad_probe.utils import check_str +from msprobe.core.grad_probe.constant import GradConst +from msprobe.core.common.log import logger +from msprobe.core.common.file_check import create_directory +from msprobe.core.common.utils import check_path_before_create + + +class GlobalContext: + + _instance = None + _instance_lock = threading.Lock() + _setting = { + GradConst.LEVEL: None, + GradConst.PARAM_LIST: None, + GradConst.STEP: None, + GradConst.RANK: None, + GradConst.CURRENT_STEP: 0, + GradConst.BOUNDS: [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10], + GradConst.OUTPUT_PATH: None + } + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance_lock.acquire() + cls._instance = object.__new__(cls) + cls._instance_lock.release() + return cls._instance + + def init_context(self, config_dict: Dict): + level = config_dict.get(GradConst.LEVEL) + check_str(level, variable_name = "level in yaml") + if level in GradConst.SUPPORTED_LEVEL: + self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL) + else: + raise ValueError("Invalid level set in config yaml file, level option: L0, L1, L2") + + self._set_input_list(config_dict, 
GradConst.PARAM_LIST, str) + self._set_input_list(config_dict, GradConst.BOUNDS, float) + self._set_input_list(config_dict, GradConst.STEP, int) + self._set_input_list(config_dict, GradConst.RANK, int) + + output_path = config_dict.get(GradConst.OUTPUT_PATH) + check_str(output_path, variable_name = "output_path in yaml") + try: + check_path_before_create(output_path) + except RuntimeError as err: + raise ValueError(f"Invalid output_path: {output_path}. The error message is {err}.") from err + self._setting[GradConst.OUTPUT_PATH] = output_path + if not os.path.isdir(self._setting.get(GradConst.OUTPUT_PATH)): + create_directory(self._setting.get(GradConst.OUTPUT_PATH)) + else: + logger.warning("The output_path exists, the data will be covered.") + + def get_context(self, key: str): + if key not in self._setting: + logger.warning(f"Unrecognized {key}.") + return self._setting.get(key) + + def update_step(self): + self._setting[GradConst.CURRENT_STEP] += 1 + + def step_need_dump(self, step): + dump_step_list = self.get_context(GradConst.STEP) + return (not dump_step_list) or (step in dump_step_list) + + def rank_need_dump(self, rank): + dump_rank_list = self.get_context(GradConst.RANK) + return (not dump_rank_list) or (rank in dump_rank_list) + + def _set_input_list(self, config_dict: Dict, name: str, dtype: Union[int, str, float]): + value = config_dict.get(name) + if dtype == int: + type_str = "integer" + elif dtype == float: + type_str = "float" + else: + type_str = "string" + if value and isinstance(value, list): + for val in value: + if not isinstance(val, dtype): + logger.warning(f"Invalid {name} which must be None or list of {type_str}") + return + self._setting[name] = value + else: + logger.warning(f"{name} is None or not a list with valid items, use default value.") + +grad_context = GlobalContext() diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py new file mode 
100644 index 0000000000000000000000000000000000000000..2bdc11114c37fa1378d72d9243d0dec48213795d --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_analyzer.py @@ -0,0 +1,231 @@ +import os +import time +from typing import List, Tuple +import multiprocessing +from multiprocessing import Process + +import numpy as np +import mindspore as ms +from mindspore.communication import get_rank +from mindspore.ops import operations as P +from mindspore.common.parameter import Parameter + +from msprobe.core.grad_probe.utils import ListCache +from msprobe.core.grad_probe.constant import GradConst +from msprobe.core.common.log import logger +from msprobe.core.common.file_check import create_directory +from msprobe.core.common.utils import check_file_or_directory_path, write_csv, remove_path, move_file +from msprobe.mindspore.grad_probe.global_context import grad_context, GlobalContext + + +def get_rank_id(): + try: + rank_id = get_rank() + except Exception as err: + rank_id = 0 + return rank_id + + +@ms.jit +def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List): + ''' + Dump gradient statistic data. 
+ level0: [step, max, min, norm, shape_dim, shape] + level1: [step, max, min, norm, shape_dim, shape] + grad_bool_data + level2: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data + ''' + dump_path = os.path.join(dump_dir, g_name) + dump_dir_path = dump_path + "_dir" + save_op = ms.ops.TensorDump() + + grad_flat = grad.reshape(-1) + max_val = grad_flat.max(axis=0).float() + min_val = grad_flat.min(axis=0).float() + norm_val = grad_flat.norm(ord=2).float() + shape = grad.shape + extrem_list = [dump_step[0].float(), max_val, min_val, norm_val] + extrem_stat = ms.ops.stack(extrem_list) + shape_list = [len(shape)] + list(shape) + shape_stat = ms.Tensor(shape_list).float() + level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0) + level_stat = level0_stat + + if level == GradConst.LEVEL2: + zero_grad = (grad == 0).sum() + dist_dim = ms.Tensor([len(bounds) + 2]).float() + bucket_result = ms.ops.bucketize(grad.float(), bounds) + bucket_result = bucket_result.astype(ms.int8) + dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)] + dist_stat.append(zero_grad) + dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty + dist_stat = ms.ops.stack(dist_stat, axis=0).float() + level2_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0) + level_stat = level2_stat + + save_op(dump_path, level_stat) + if level == GradConst.LEVEL1 or level == GradConst.LEVEL2: + grad_direction = grad > 0 + save_op(dump_dir_path, grad_direction) + + +class CSVGenerator(Process): + + def __init__(self) -> None: + super().__init__() + self.dump_dir = None + self.save_dir = None + self.level = GradConst.LEVEL0 + self.cache_list = ListCache() + self.current_step = None + self.stop_event = None + self.last_finish = False + self.bounds = [-0.1, 0.0, 0.1], + + def init(self, context: GlobalContext): + rank_id = get_rank_id() + output_path = context.get_context(GradConst.OUTPUT_PATH) + self.level = 
context.get_context(GradConst.LEVEL) + self.bounds = context.get_context(GradConst.BOUNDS) + self.dump_dir = f"{output_path}/rank{rank_id}/Dump/" + self.save_dir = f"{output_path}/rank{rank_id}/" + self.current_step = None + self.stop_event = multiprocessing.Event() + self.last_finish = False + + def run(self): + while True: + if not os.path.exists(self.dump_dir): + time.sleep(0.1) + if self.stop_event.is_set(): + break + continue + npy_files = os.listdir(self.dump_dir) + npy_files.sort(key=lambda x: int(x.split("_")[0])) + self.traverse_files(npy_files) + empty = len(os.listdir(self.dump_dir)) == 0 + if self.stop_event.is_set() and empty and self.last_finish: + break + if os.path.exists(self.dump_dir): + remove_path(self.dump_dir) + + def stop(self): + self.stop_event.set() + + def traverse_files(self, npy_files: List): + for npy_file in npy_files: + file_path = os.path.join(self.dump_dir, npy_file) + while not os.path.exists(file_path): + time.sleep(0.01) + check_file_or_directory_path(file_path) + if GradConst.STEP_FINISH in npy_file: + self.cache_list.flush() + remove_path(file_path) + self.last_finish = True + elif file_path.split("_")[-1] == GradConst.DIR_SUFFIX: + prefix_idx = len(npy_file.split("_")[0]) + new_name = npy_file[prefix_idx + 1:].replace("_" + GradConst.DIR_SUFFIX, "." 
+ GradConst.NPY_SUFFIX) + if not new_name: + raise RuntimeError("Invalid dump data name.") + if self.current_step is None: + raise RuntimeError("Current record step is None.") + step_dir = os.path.join(self.save_dir, f"step{self.current_step}") + if not os.path.exists(step_dir): + create_directory(step_dir) + dst_file = os.path.join(step_dir, new_name) + move_file(file_path, dst_file) + self.last_finish = False + elif file_path.split(".")[-1] == GradConst.NPY_SUFFIX: + stat_data = self.load_npy_data(file_path) + if stat_data is None: + continue + if not self.check_valid(stat_data): + os.remove(file_path) + continue + step = int(stat_data[GradConst.STEP_IDX]) + update_step = self.current_step is None or step != self.current_step + self.current_step = step + if update_step: + self.create_csv_file() + self.gen_csv_line(file_path, stat_data) + os.remove(file_path) + self.last_finish = False + + def check_valid(self, stat_data): + level = grad_context.get_context(GradConst.LEVEL) + try: + shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX]) + if level == GradConst.LEVEL2: + dist_dim = int(stat_data[shape_dim + GradConst.SHAPE_DIM_IDX + 1]) + length = shape_dim + dist_dim + 7 + else: + length = shape_dim + 5 + except IndexError as err: + return False + if length != len(stat_data): + return False + return True + + def load_npy_data(self, file_path: str): + stat_data = None + max_try = 10 + while max_try: + try: + stat_data = np.load(file_path) + return stat_data + except Exception as err: + logger.warning(f"load numpy file failed, retry...") + max_try -= 1 + time.sleep(0.1) + return stat_data + + def gen_csv_line(self, file_path: str, stat_data) -> None: + shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX]) + file_name = os.path.basename(file_path) + prefix_idx = len(file_name.split("_")[0]) + param_name = file_name[(prefix_idx + 1) : -(len(GradConst.NPY_SUFFIX) + 1)] + if not param_name: + raise RuntimeError("Invalid gradient statistic file name.") + csv_line = 
[param_name] + if self.level == GradConst.LEVEL2: + csv_line.extend(self.get_dist_data(shape_dim, stat_data)) + csv_line.extend(self.get_extrem_data(shape_dim, stat_data)) + self.cache_list.append(csv_line) + + def get_dist_data(self, shape_dim: int, stat_data: np.ndarray): + dist_data = stat_data[(shape_dim + GradConst.SHAPE_DIM_IDX + 2):-1] + element_num = dist_data.sum() - dist_data[-1] + if element_num != 0: + dist_data = dist_data / element_num + return list(dist_data) + + def get_extrem_data(self, shape_dim: int, stat_data: np.ndarray): + extrem_data = list(stat_data[(GradConst.STEP_IDX + 1):(GradConst.STEP_IDX + 4)]) + shape_data = stat_data[(GradConst.SHAPE_DIM_IDX + 1):(GradConst.SHAPE_DIM_IDX + shape_dim + 1)] + shape_data = list(shape_data.astype(int)) + extrem_data.append(shape_data) + return extrem_data + + def create_csv_file(self): + headers = ["Param_name"] + if self.level == GradConst.LEVEL2: + headers.extend(self.get_dist_header()) + headers.extend(self.get_extrem_headers()) + output_path = f"{self.save_dir}/grad_summary_{self.current_step}.csv" + write_csv([headers], output_path) + self.cache_list.set_output_file(output_path) + self.cache_list.clear() + + def get_extrem_headers(self) -> List[str]: + return ["Max", "Min", "Norm", "Shape"] + + def get_dist_header(self) -> List[str]: + intervals = [] + for i, _ in enumerate(self.bounds): + if i == 0: + intervals.append(f"(-inf, {self.bounds[i]}]") + else: + intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]") + intervals.extend([f"({self.bounds[-1]}, inf)", "=0"]) + return intervals + +csv_generator = CSVGenerator() diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_monitor.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..f1e082688a6569e2dfa18bbbca42c4836858a923 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_monitor.py @@ -0,0 +1,27 @@ +from 
msprobe.mindspore.grad_probe.global_context import grad_context +from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator +from msprobe.mindspore.grad_probe.hook import hook_optimizer +from msprobe.core.grad_probe.constant import GradConst + + +class GradientMonitor: + + def __init__(self, common_dict, task_config): + config = {} + config[GradConst.OUTPUT_PATH] = common_dict.dump_path + config[GradConst.STEP] = common_dict.step + config[GradConst.RANK] = common_dict.rank + config[GradConst.PARAM_LIST] = task_config.param_list + config[GradConst.LEVEL] = task_config.grad_level + config[GradConst.BOUNDS] = task_config.bounds + self.config = config + grad_context.init_context(self.config) + + @staticmethod + def monitor(opt): + csv_generator.init(grad_context) + hook_optimizer(opt) + + @staticmethod + def stop(): + csv_generator.stop() diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2b0ee3bf3398a6e68107edce1558ad478a0cbd --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/grad_stat_csv.py @@ -0,0 +1,132 @@ +from abc import ABC, abstractmethod +import hashlib + +import mindspore +from mindspore import ops, Tensor +from msprobe.core.grad_probe.constant import GradConst + + +class CsvInput: + def __init__(self, param_name, grad, bounds): + self.param_name = param_name + self.grad = grad + self.bounds = bounds + +class GradStatCsv: + csv = {} + + @staticmethod + def get_csv_header(level, csv_input): + header = ["param_name"] + for key in level["header"]: + header.extend(GradStatCsv.csv[key].generate_csv_header(csv_input)) + return header + + @staticmethod + def get_csv_line(level, csv_input): + line = [csv_input.param_name] + for key in level["header"]: + line.extend(GradStatCsv.csv[key].generate_csv_content(csv_input)) + return line + + +def register_csv_item(key, 
cls=None): + if cls is None: + # 无参数时,返回装饰器函数 + return lambda cls: register_csv_item(key, cls) + GradStatCsv.csv[key] = cls + return cls + + +class CsvItem(ABC): + @staticmethod + @abstractmethod + def generate_csv_header(csv_input): + pass + + @staticmethod + @abstractmethod + def generate_csv_content(csv_input): + pass + + +@register_csv_item(GradConst.MD5) +class CsvMd5(CsvItem): + def generate_csv_header(csv_input): + return ["MD5"] + + def generate_csv_content(csv_input): + grad = csv_input.grad + tensor_bytes = grad.float().numpy().tobytes() + md5_hash = hashlib.md5(tensor_bytes) + return [md5_hash.hexdigest()] + + +@register_csv_item(GradConst.DISTRIBUTION) +class CsvDistribution(CsvItem): + def generate_csv_header(csv_input): + bounds = csv_input.bounds + intervals = [] + if bounds: + intervals.append(f"(-inf, {bounds[0]}]") + for i in range(1, len(bounds)): + intervals.append(f"({bounds[i-1]}, {bounds[i]}]") + if intervals: + intervals.append(f"({bounds[-1]}, inf)") + intervals.append("=0") + + return intervals + + def generate_csv_content(csv_input): + grad = csv_input.grad + bounds = csv_input.bounds + if grad.dtype == mindspore.bfloat16: + grad = grad.to(mindspore.float32) + element_num = grad.numel() + grad_equal_0_num = (grad == 0).sum().item() + bucketsize_result = ops.bucketize(grad.float(), bounds) + bucketsize_result = bucketsize_result.astype(mindspore.int8) + interval_nums = [(bucketsize_result == i).sum().item() for i in range(len(bounds) + 1)] + interval_nums.append(grad_equal_0_num) + return_list = [x / element_num if element_num != 0 else 0 for x in interval_nums] + return return_list + + +@register_csv_item(GradConst.MAX) +class CsvMax(CsvItem): + def generate_csv_header(csv_input): + return ["max"] + + def generate_csv_content(csv_input): + grad = csv_input.grad + return [ops.amax(grad).float().numpy().tolist()] + + +@register_csv_item(GradConst.MIN) +class CsvMin(CsvItem): + def generate_csv_header(csv_input): + return ["min"] + + def 
generate_csv_content(csv_input): + grad = csv_input.grad + return [ops.amin(grad).float().numpy().tolist()] + + +@register_csv_item(GradConst.NORM) +class CsvNorm(CsvItem): + def generate_csv_header(csv_input): + return ["norm"] + + def generate_csv_content(csv_input): + grad = csv_input.grad + return [ops.norm(grad).float().numpy().tolist()] + + +@register_csv_item(GradConst.SHAPE) +class CsvShape(CsvItem): + def generate_csv_header(csv_input): + return ["shape"] + + def generate_csv_content(csv_input): + grad = csv_input.grad + return [list(grad.shape)] \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py new file mode 100644 index 0000000000000000000000000000000000000000..243fb33de1c0b7b23d4b9d32c467fc80e80b9d32 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/hook.py @@ -0,0 +1,92 @@ + +import os + +import mindspore +import mindspore as ms +from mindspore.common.api import jit +from mindspore.nn.optim.optimizer import Optimizer +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer + +from msprobe.core.grad_probe.constant import GradConst +from msprobe.core.common.log import logger + +from msprobe.core.common.utils import write_csv, remove_path +from msprobe.mindspore.grad_probe.global_context import grad_context +from msprobe.mindspore.grad_probe.grad_analyzer import grad_dump, get_rank_id +from msprobe.mindspore.grad_probe.grad_analyzer import csv_generator +from msprobe.mindspore.grad_probe.grad_stat_csv import GradStatCsv, CsvInput +from msprobe.mindspore.grad_probe.utils import save_grad_direction, get_adapted_level + +class HookInput: + + ''' + HookInput is a class wrapping all the variables used for hooking optimizer + ''' + + def __init__(self, opt) -> None: + self.func = opt.construct + self.g_names = [param.name for param in opt._parameters] + self.param_list = 
grad_context.get_context(GradConst.PARAM_LIST) + self.rank_id = get_rank_id() + output_path = grad_context.get_context(GradConst.OUTPUT_PATH) + self.dump_dir = os.path.join(output_path, f"rank{self.rank_id}", "Dump") + self.save_dir = os.path.join(output_path, f"rank{self.rank_id}") + self.step_finish_flag = os.path.join(self.dump_dir, GradConst.STEP_FINISH) + if os.path.exists(self.save_dir): + logger.warning(f"Delete existing path {self.save_dir}.") + remove_path(self.save_dir) + self.level = grad_context.get_context(GradConst.LEVEL) + self.bounds = grad_context.get_context(GradConst.BOUNDS) + self.mode = mindspore.get_context("mode") + +def hook_graph_mode_optimizer(opt, hook_input): + @jit + def new_construct(self, gradients): + for index, grad_value in enumerate(gradients): + if hook_input.param_list and hook_input.g_names[index] not in hook_input.param_list: + continue + grad_dump(hook_input.dump_dir, hook_input.g_names[index], self.dump_step, + grad_value, hook_input.level, hook_input.bounds) + ms.ops.TensorDump()(hook_input.step_finish_flag, self.dump_step) + self.assignadd(self.dump_step, self.global_step_increase_tensor) + out = hook_input.func(gradients) + return out + + opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step") + opt.construct = new_construct.__get__(opt, type(opt)) + csv_generator.start() + +def hook_pynative_optimizer(opt, hook_input): + level_adapted = get_adapted_level(hook_input.level) + + def hook_fn(cell, input): + gradients, = input + cur_step = grad_context.get_context(GradConst.CURRENT_STEP) + if grad_context.step_need_dump(cur_step) and grad_context.rank_need_dump(hook_input.rank_id): + output_lines = [] + for index, grad_value in enumerate(gradients): + param_name = hook_input.g_names[index] + if hook_input.param_list and param_name not in hook_input.param_list: + continue + csv_input = CsvInput(param_name, grad_value, hook_input.bounds) + grad_info = GradStatCsv.get_csv_line(level_adapted, csv_input) + 
output_lines.append(grad_info) + if level_adapted["have_grad_direction"]: + save_grad_direction(param_name, grad_value, os.path.join(hook_input.save_dir, f'step{cur_step}')) + output_csv_path = os.path.join(hook_input.save_dir, f"grad_summary_{cur_step}.csv") + dummy_csv_input = CsvInput(None, None, hook_input.bounds) + output_lines.insert(0, GradStatCsv.get_csv_header(level_adapted, dummy_csv_input)) + write_csv(output_lines, output_csv_path) + grad_context.update_step() + + opt.register_forward_pre_hook(hook_fn) + + +def hook_optimizer(opt: Optimizer): + hook_input = HookInput(opt) + + if hook_input.mode == mindspore.GRAPH_MODE: + hook_graph_mode_optimizer(opt, hook_input) + else: + hook_pynative_optimizer(opt, hook_input) diff --git a/debug/accuracy_tools/msprobe/mindspore/grad_probe/utils.py b/debug/accuracy_tools/msprobe/mindspore/grad_probe/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db0a36a022615c5560922b5c708dabc77fafbd2f --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/grad_probe/utils.py @@ -0,0 +1,29 @@ +import os + +import numpy as np +import mindspore +from msprobe.core.grad_probe.constant import GradConst, level_adp +from msprobe.core.grad_probe.utils import check_param +from msprobe.core.common.file_check import create_directory +from msprobe.core.common.utils import check_path_before_create, change_mode, check_file_or_directory_path, save_npy + + +def save_grad_direction(param_name, grad, save_path): + if not os.path.exists(save_path): + create_directory(save_path) + check_file_or_directory_path(save_path, isdir=True) + check_param(param_name) + save_filepath = os.path.join(save_path, f"{param_name}.npy") + check_path_before_create(save_filepath) + + if grad.dtype == mindspore.bfloat16: + grad = grad.to(mindspore.float32) + grad_direction_tensor = grad > 0 + grad_direction_ndarray = grad_direction_tensor.numpy() + + save_npy(grad_direction_ndarray, save_filepath) + + +def get_adapted_level(level: str): + 
level_adapted = level_adp.get(level) + return level_adapted \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/ms_config.py b/debug/accuracy_tools/msprobe/mindspore/ms_config.py index c0ef6bb6c00aab426fd42a11c3bc2436440a4a6a..4109a867915f13cbdb469acb987ddad84e6fc247 100644 --- a/debug/accuracy_tools/msprobe/mindspore/ms_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/ms_config.py @@ -1,7 +1,12 @@ import json + from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.file_check import FileOpen from msprobe.core.common.const import Const +from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.mindspore.common.log import logger +from msprobe.core.grad_probe.constant import level_adp +from msprobe.core.grad_probe.utils import check_numeral_list_ascend class TensorConfig(BaseConfig): @@ -51,10 +56,48 @@ class OverflowCheckConfig(BaseConfig): raise Exception("check_mode is invalid") +class FreeBenchmarkConfig(BaseConfig): + def __init__(self, task_config): + super().__init__(task_config) + self._check_config() + + def _check_config(self): + if self.fuzz_device and self.fuzz_device not in FreeBenchmarkConst.DEVICE_LIST: + raise Exception("fuzz_device must be npu or empty") + if self.pert_mode and self.pert_mode not in FreeBenchmarkConst.PERT_TYPE_LIST: + raise Exception("pert_mode must be improve_precision, add_noise, bit_noise, no_change or empty") + if self.handler_type and self.handler_type not in FreeBenchmarkConst.HANDLER_TYPE_LIST: + raise Exception("handler_type must be check, fix or empty") + if self.fuzz_level and self.fuzz_level not in FreeBenchmarkConst.DUMP_LEVEL_LIST: + raise Exception("fuzz_level must be L1 or empty") + if self.fuzz_stage and self.fuzz_stage not in FreeBenchmarkConst.STAGE_LIST: + raise Exception("fuzz_stage must be forward or empty") + if self.if_preheat or self.preheat_step or self.max_sample: + logger.warning("'if_preheat', 'preheat_step' and 
'max_sample' settings " + "are not supported for mindspore free benchmark task.") + + +class GradProbeConfig(BaseConfig): + def __init__(self, json_config): + super().__init__(json_config) + self.grad_level = json_config.get("grad_level", "L1") + self.param_list = json_config.get("param_list", []) + self.bounds = json_config.get("bounds", []) + + def _check_config(self): + if self.grad_level not in level_adp.keys(): + raise Exception(f"grad_level must be one of {level_adp.keys()}") + if not isinstance(self.param_list, list): + raise Exception(f"param_list must be a list") + check_numeral_list_ascend(self.bounds) + + TaskDict = { Const.TENSOR: TensorConfig, Const.STATISTICS: StatisticsConfig, Const.OVERFLOW_CHECK: OverflowCheckConfig, + Const.FREE_BENCHMARK: FreeBenchmarkConfig, + Const.GRAD_PROBE: GradProbeConfig, } diff --git a/debug/accuracy_tools/msprobe/mindspore/runtime.py b/debug/accuracy_tools/msprobe/mindspore/runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..380b30d9785a27a717e3d61739a14bdc28f73d33 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/runtime.py @@ -0,0 +1,4 @@ +class Runtime: + step_count: int = 0 + rank_id: int = -1 + is_running: bool = False diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 50776aaf1097339e7c6d98944db7ddf2d2238c5f..29881e738dfba66e87d28e96cf64e66d9441909b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -19,6 +19,10 @@ from pathlib import Path import functools from collections import defaultdict +from mindspore.common.tensor import Tensor +from mindspore import ops +from mindspore import nn + from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.common.utils import get_rank_if_initialized @@ -27,7 +31,9 @@ from msprobe.mindspore.common.log import logger 
from msprobe.core.common.utils import Const from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \ + ModuleBackwardInputs, ModuleBackwardOutputs +from msprobe.core.common.exceptions import MsprobeException from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell @@ -41,9 +47,18 @@ class Service: self.current_iter = 0 self.first_start = True self.current_rank = None + self.primitive_counters = {} self.dump_iter_dir = None self.start_call = False + @staticmethod + def check_model_valid(model): + if not model or isinstance(model, nn.Cell): + return model + raise MsprobeException( + MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。" + ) + def build_hook(self, module_type, name): def forward_hook(api_or_module_name, module, input, output): self.data_collector.visit_and_clear_overflow_status(api_or_module_name) @@ -79,13 +94,139 @@ class Service: return wrap_forward_hook, wrap_backward_hook + def wrap_primitive(self, origin_func, primitive_name): + service_instance = self + + def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type): + def backward_hook(grad): + captured_grads.append(grad) + backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}" + try: + if len(captured_grads) == num_tensors and hook_type == Const.INPUT: + service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name) + new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads)) + service_instance.data_collector.backward_output_data_collect( + backward_primitive_name, service_instance, os.getpid(), new_module_input_output + ) + captured_grads.clear() + elif 
len(captured_grads) == num_tensors and hook_type == Const.OUTPUT: + service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name) + new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads)) + service_instance.data_collector.backward_input_data_collect( + backward_primitive_name, service_instance, os.getpid(), new_module_input_output + ) + captured_grads.clear() + + except Exception as exception: + raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception}," + f" updated_primitive_name: {updated_primitive_name}") from exception + + return backward_hook + + def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name): + hooked_inputs = [] + num_tensors = sum(isinstance(arg, Tensor) for arg in args) + input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, + Const.INPUT) + for _, arg in enumerate(args): + if isinstance(arg, Tensor): + arg_hooked = ops.HookBackward(input_backward_hook)(arg) + hooked_inputs.append(arg_hooked) + else: + hooked_inputs.append(arg) + return hooked_inputs + + def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name): + if isinstance(out, tuple): + num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out) + else: + num_output_tensors = 1 + output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors, + updated_primitive_name, Const.OUTPUT) + + if isinstance(out, Tensor): + return ops.HookBackward(output_backward_hook)(out) + elif isinstance(out, tuple): + hooked_outputs = [] + for tensor in out: + if isinstance(tensor, Tensor): + hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor)) + else: + hooked_outputs.append(tensor) + return tuple(hooked_outputs) + return out + + def wrapped_primitive_call(instance_self, *args, **kwargs): + service_instance.update_primitive_counters(primitive_name) + current_count = 
service_instance.primitive_counters.get(primitive_name, 0) + updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}" + + if not service_instance.switch: + return origin_func(*args, **kwargs) + + captured_grads_input, captured_grads_output = [], [] + + try: + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) + except Exception as exception: + raise Exception("This is a primitive op dump error during input hooking: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + try: + out = origin_func(*hooked_inputs, **kwargs) + except Exception as exception: + raise Exception("This is a primitive op dump error during function call: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}" + service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name) + if service_instance.data_collector: + module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out) + try: + service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self, + os.getpid(), module_input_output) + except Exception as exception: + raise Exception("This is a primitive op dump error during forward data collection: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + if service_instance.data_collector.if_return_forward_new_output(): + out = service_instance.data_collector.get_forward_new_output() + + try: + out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) + except Exception as exception: + raise Exception("This is a primitive op dump error during output hooking: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + return out + + return wrapped_primitive_call + + def update_primitive_counters(self, primitive_name): + if primitive_name not in 
self.primitive_counters: + self.primitive_counters[primitive_name] = 0 + else: + self.primitive_counters[primitive_name] += 1 + + def register_hooks(self): + primitive_set = set() + for _, cell in self.model.cells_and_names(): + for pname, primitive in cell._primitives.items(): + primitive_set.add((pname, primitive)) + + for pname, primitive in primitive_set: + NewPrimitive = type('NewPrimitive', (primitive.__class__,), + {'__call__': self.wrap_primitive(primitive.__call__, pname)}) + primitive.__class__ = NewPrimitive + def step(self): self.current_iter += 1 self.data_collector.update_iter(self.current_iter) HOOKCell.cell_count = defaultdict(int) + self.primitive_counters.clear() def start(self, model=None): - self.model = model + self.model = Service.check_model_valid(model) self.start_call = True logger.info("msprobe: debugger.start() is set successfully") if self.config.step and self.current_iter > max(self.config.step): @@ -150,3 +291,5 @@ class Service: if self.config.level == "L1": api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) api_register.api_set_hook_func() + if self.model: + self.register_hooks() diff --git a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py index 7b7e6fd889c775a4491e824c1f73e6021cb99350..dfe2fbe2cdfa29f832d3b4a0c920f2616f12af32 100644 --- a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py @@ -1,17 +1,23 @@ +from msprobe.core.common.const import Const +from msprobe.mindspore.common.const import Const as MsConst from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory +from msprobe.mindspore.free_benchmark.self_check_tool_factory import SelfCheckToolFactory 
class TaskHandlerFactory: tasks = { - "tensor": DumpToolFactory, - "statistics": DumpToolFactory, - "overflow_check": OverflowCheckToolFactory + Const.TENSOR: DumpToolFactory, + Const.STATISTICS: DumpToolFactory, + Const.OVERFLOW_CHECK: OverflowCheckToolFactory, + Const.FREE_BENCHMARK: SelfCheckToolFactory } @staticmethod def create(config: DebuggerConfig): + if config.execution_mode == MsConst.PYNATIVE_MODE and config.task != Const.FREE_BENCHMARK: + raise Exception("Current Task can't run in pynative mode.") task = TaskHandlerFactory.tasks.get(config.task) if not task: raise Exception("valid task is needed.") diff --git a/debug/accuracy_tools/msprobe/msprobe.py b/debug/accuracy_tools/msprobe/msprobe.py index 698165b6150eabf63457cc23f4d80e7f58a5b423..54b4a12d01b7ff9a3cf24bd8773c1a65177fed6a 100644 --- a/debug/accuracy_tools/msprobe/msprobe.py +++ b/debug/accuracy_tools/msprobe/msprobe.py @@ -15,13 +15,14 @@ import argparse import sys -from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command -from msprobe.pytorch.parse_tool.cli import parse as cli_parse -from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut -from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \ - _api_precision_compare_command -from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \ - _run_overflow_check_command +import importlib.util +from msprobe.core.compare.utils import _compare_parser +from msprobe.core.common.log import logger + + +def is_module_available(module_name): + spec =importlib.util.find_spec(module_name) + return spec is not None def main(): @@ -31,37 +32,64 @@ def main(): "Providing one-site accuracy difference debugging toolkit for training on Ascend Devices.\n" f"For any issue, refer README.md first", ) + parser.set_defaults(print_help=parser.print_help) - parser.add_argument('-f', 
'--framework', required=True, choices=['pytorch'], + parser.add_argument('-f', '--framework', required=True, choices=['pytorch', 'mindspore'], help='Deep learning framework.') subparsers = parser.add_subparsers() subparsers.add_parser('parse') + compare_cmd_parser = subparsers.add_parser('compare') run_ut_cmd_parser = subparsers.add_parser('run_ut') multi_run_ut_cmd_parser = subparsers.add_parser('multi_run_ut') api_precision_compare_cmd_parser = subparsers.add_parser('api_precision_compare') run_overflow_check_cmd_parser = subparsers.add_parser('run_overflow_check') - _run_ut_parser(run_ut_cmd_parser) - _run_ut_parser(multi_run_ut_cmd_parser) - multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, - help='Number of splits for parallel processing. Range: 1-64') - _api_precision_compare_parser(api_precision_compare_cmd_parser) - _run_overflow_check_parser(run_overflow_check_cmd_parser) + _compare_parser(compare_cmd_parser) + is_torch_available=is_module_available("torch") + if is_torch_available: + from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command + from msprobe.pytorch.parse_tool.cli import parse as cli_parse + from msprobe.pytorch.api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut + from msprobe.pytorch.api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, \ + _api_precision_compare_command + from msprobe.pytorch.api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, \ + _run_overflow_check_command + from msprobe.pytorch.compare.compare_cli import compare_cli + _run_ut_parser(run_ut_cmd_parser) + _run_ut_parser(multi_run_ut_cmd_parser) + multi_run_ut_cmd_parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, + help='Number of splits for parallel processing. 
Range: 1-64') + _api_precision_compare_parser(api_precision_compare_cmd_parser) + _run_overflow_check_parser(run_overflow_check_cmd_parser) + if len(sys.argv) == 1: parser.print_help() sys.exit(0) args = parser.parse_args(sys.argv[1:]) - if sys.argv[3] == "run_ut": - run_ut_command(args) - elif sys.argv[3] == "parse": - cli_parse() - elif sys.argv[3] == "multi_run_ut": - config = prepare_config(args) - run_parallel_ut(config) - elif sys.argv[3] == "api_precision_compare": - _api_precision_compare_command(args) - elif sys.argv[3] == "run_overflow_check": - _run_overflow_check_command(args) - + if sys.argv[2] == "pytorch": + if not is_torch_available: + logger.error("PyTorch does not exit, please install PyTorch library") + raise Exception("PyTorch does not exit, please install PyTorch library") + if sys.argv[3] == "run_ut": + run_ut_command(args) + elif sys.argv[3] == "parse": + cli_parse() + elif sys.argv[3] == "multi_run_ut": + config = prepare_config(args) + run_parallel_ut(config) + elif sys.argv[3] == "api_precision_compare": + _api_precision_compare_command(args) + elif sys.argv[3] == "run_overflow_check": + _run_overflow_check_command(args) + elif sys.argv[3] == "compare": + compare_cli(args) + else: + if is_module_available("mindspore"): + from msprobe.mindspore.compare.compare_cli import compare_cli_ms + else: + logger.error("MindSpore does not exit, please install MindSpore library") + raise Exception("MindSpore does not exit, please install MindSpore library") + if sys.argv[3] == "compare": + compare_cli_ms(args) if __name__ == "__main__": main() diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index 482e850f7baa845bd831e0d4728e841661b9345b..c4e426772670212382addb9b855b4bdf69810d3d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -1,4 +1,4 @@ from .debugger.precision_debugger import PrecisionDebugger from .common.utils 
import seed_all -from .compare.acc_compare import compare from .compare.distributed_compare import compare_distributed +from .compare.pt_compare import compare \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py index 760e7c862dba5440412f5ee27d0345d1a17d2c5d..cf8af8d2cd3d4ea8aab9aad1d0e92cc09875d90f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/common/config.py @@ -24,7 +24,13 @@ class Config: 'white_list': list, 'black_list': list, 'error_data_path': str, - 'precision': int + 'precision': int, + 'is_online': bool, + 'nfs_path': str, + 'host': str, + 'port': int, + 'rank_list': list, + 'tls_path': str } if key not in validators: raise ValueError(f"{key} must be one of {validators.keys()}") @@ -38,6 +44,10 @@ class Config: RunUTConfig.check_filter_list_config(key, value) if key == 'error_data_path': RunUTConfig.check_error_data_path_config(value) + if key == 'nfs_path': + RunUTConfig.check_nfs_path_config(value) + if key == 'tls_path': + RunUTConfig.check_tls_path_config(value) return value diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index ee49588288efc0a33c086913cc5624059de82272..20f04b0cd7bea196f6ff6f785433592c9c4315a4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -33,16 +33,30 @@ class Comparator: COLUMN_BACKWARD_SUCCESS = "Backward Test Success" COLUMN_STACK_INFO = "Traceback callstack info" - def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None): - self.save_path = result_csv_path - self.detail_save_path = details_csv_path 
- if not is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path): + def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, stack_info_json_path=None, config=None): + self.save_path_str = result_csv_path + self.detail_save_path_str = details_csv_path + self.save_path_list = [result_csv_path] + self.detail_save_path_list = [details_csv_path] + + if config and config.online_config.is_online: + self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv") + self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv") + self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list] + self.detail_save_path_list = \ + [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list] + + if not is_continue_run_ut: self.write_csv_title() if stack_info_json_path: self.stack_info = get_json_contents(stack_info_json_path) else: self.stack_info = None + @staticmethod + def get_path_from_rank(rank, path_list, path_pattern): + return path_list[-1] if len(path_list) == 1 else path_pattern.format(rank) + @staticmethod def print_pretest_result(): logger.info("Successfully completed run_ut/multi_run_ut.") @@ -86,10 +100,11 @@ class Comparator: def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS, "Message"]] - if not os.path.exists(self.save_path): - write_csv(summary_test_rows, self.save_path) - if not os.path.exists(self.detail_save_path): - write_csv(DETAIL_TEST_ROWS, self.detail_save_path) + for save_path, detail_save_path in zip(self.save_path_list, self.detail_save_path_list): + if not os.path.exists(save_path): + write_csv(summary_test_rows, save_path) + if not os.path.exists(detail_save_path): + write_csv(DETAIL_TEST_ROWS, detail_save_path) def write_summary_csv(self, test_result): test_rows = [] @@ -104,7 +119,8 @@ class Comparator: stack_info = 
"\n".join(self.stack_info[name]) df_row.append(stack_info) test_rows.append(df_row) - write_csv(test_rows, self.save_path) + save_path = self.get_path_from_rank(test_result[-1], self.save_path_list, self.save_path_str) + write_csv(test_rows, save_path) def write_detail_csv(self, test_result): test_rows = [] @@ -125,7 +141,10 @@ class Comparator: if isinstance(item, float) else item for item in test_subject] test_rows.append([subject] + list(test_subject)) - write_csv(test_rows, self.detail_save_path) + detail_save_path = self.get_path_from_rank(test_result[-1], + self.detail_save_path_list, + self.detail_save_path_str) + write_csv(test_rows, detail_save_path) def record_results(self, args): self.write_summary_csv(args) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml index 2dac535dc0501f6e47f0cdcc48bd88e1d73ab0dd..49f8a726de8b7348491a37c9c16eea7f535c1270 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/config.yaml @@ -2,4 +2,9 @@ white_list: [] black_list: [] error_data_path: './' precision: 14 - \ No newline at end of file +is_online: False +nfs_path: "" +host: "" +port: -1 +rank_list: [0] +tls_path: "" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py index 559dfdc0f14f191fc7142f6b2f9d735c51d6a738..7e5891b5a3c4c6458ae872a23cc35a5b645a7d62 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -36,14 +36,20 @@ from msprobe.core.common.file_check import FileOpen, FileChecker, \ from msprobe.pytorch.common.log import logger from msprobe.pytorch.pt_config import parse_json_config from msprobe.core.common.const import Const, FileCheckConst, 
CompareConst +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTL, ATTLConfig, ApiData, move2device_exec +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.device_dispatch import ConsumerDispatcher + current_time = time.strftime("%Y%m%d%H%M%S") UT_ERROR_DATA_DIR = 'ut_error_data' + current_time RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', - 'black_list', 'error_data_path']) + 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'white_list', + 'black_list', 'error_data_path', 'online_config']) + +OnlineConfig = namedtuple('OnlineConfig', ['is_online', 'nfs_path', 'host', 'port', 'rank_list', 'tls_path']) + not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} not_raise_dtype_set = {'type_as'} @@ -140,7 +146,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: arg_in = deal_detach(raise_bench_data_dtype( - api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_() + api_name, arg_in.clone(), raise_dtype=raise_dtype), to_detach).requires_grad_() temp_arg_in = arg_in * 1 arg_in = temp_arg_in.type_as(arg_in) arg_in.retain_grad() @@ -187,21 +193,33 @@ def run_ut(config): logger.info(f"UT task details will be saved in {config.details_csv_path}") if config.save_error_data: logger.info(f"UT task error_datas will be saved in {config.error_data_path}") - compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut) - with FileOpen(config.result_csv_path, 'r') as file: - csv_reader = 
csv.reader(file) - next(csv_reader) - api_name_set = {row[0] for row in csv_reader} + compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut, config=config) + + if config.online_config.is_online: + run_api_online(config, compare) + else: + with FileOpen(config.result_csv_path, 'r') as file: + csv_reader = csv.reader(file) + next(csv_reader) + api_name_set = {row[0] for row in csv_reader} + run_api_offline(config, compare, api_name_set) + for result_csv_path, details_csv_path in zip(compare.save_path_list, compare.detail_save_path_list): + change_mode(result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + change_mode(details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + logger.info(f"UT task result csv is saved in {result_csv_path}") + logger.info(f"UT task details csv is saved in {details_csv_path}") + compare.print_pretest_result() + + +def run_api_offline(config, compare, api_name_set): for _, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): if api_full_name in api_name_set: continue - if is_unsupported_api(api_full_name): # TODO run_ut does not support to the npu fusion api and distributed api + if is_unsupported_api(api_full_name): continue [_, api_name, _] = api_full_name.split(Const.SEP) try: - if config.black_list and api_name in config.black_list: - continue - if config.white_list and api_name not in config.white_list: + if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) @@ -223,9 +241,71 @@ def run_ut(config): else: torch.npu.empty_cache() gc.collect() - change_mode(compare.save_path, FileCheckConst.DATA_FILE_AUTHORITY) - change_mode(compare.detail_save_path, FileCheckConst.DATA_FILE_AUTHORITY) - compare.print_pretest_result() + + +def 
run_api_online(config, compare): + attl = init_attl(config.online_config) + dispatcher = ConsumerDispatcher(compare=compare) + dispatcher.start(handle_func=run_torch_api_online, config=config) + + def tcp_communication_flow(): + while True: + api_data = attl.recv() + if api_data == 'STOP_': + continue + if api_data == 'KILL_': + time.sleep(1) + logger.info("==========接收到STOP信号==========") + dispatcher.stop() + attl.stop_serve() + time.sleep(1) + break + if not isinstance(api_data, ApiData): + continue + api_full_name = api_data.name + [_, api_name, _] = api_full_name.split(Const.SEP) + if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): + continue + dispatcher.update_consume_queue(api_data) + + def shared_storage_communication_flow(): + flag_num = -1 + while True: + api_data = attl.download() + if api_data == "start": + if flag_num == -1: + flag_num += 1 + flag_num += 1 + if api_data == "end": + flag_num -= 1 + if flag_num == 0: + dispatcher.stop() + break + if not isinstance(api_data, ApiData): + continue + api_full_name = api_data.name + [_, api_name, _] = api_full_name.split(Const.SEP) + if blacklist_and_whitelist_filter(api_name, config.black_list, config.white_list): + continue + dispatcher.update_consume_queue(api_data) + + if config.online_config.nfs_path: + shared_storage_communication_flow() + else: + tcp_communication_flow() + + +def blacklist_and_whitelist_filter(api_name, black_list, white_list): + """ + run api(api_name) if api_name not in black_list and in white_list. + If api is both in black_list and black_list, black_list first. 
+ return: False for exec api, True for not exec + """ + if black_list and api_name in black_list: + return True + if white_list and api_name not in white_list: + return True + return False def is_unsupported_api(api_name): @@ -294,6 +374,20 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list, backward_message) +def run_torch_api_online(api_full_name, api_data, backward_content): + in_fwd_data_list = [] + [api_type, api_name, _] = api_full_name.split(Const.SEP) + args, kwargs, out = api_data.args, api_data.kwargs, api_data.result + in_fwd_data_list.append(args) + in_fwd_data_list.append(kwargs) + if kwargs.get("device"): + del kwargs["device"] + + device_out = exec_api(api_type, api_name, args, kwargs) + device_out = move2device_exec(device_out, "cpu") + return UtDataInfo(None, None, out, device_out, None, in_fwd_data_list, None, rank=api_data.rank) + + def get_api_info(api_info_dict, api_name, real_data_path): convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True @@ -357,11 +451,21 @@ def get_validated_details_csv_path(validated_result_csv_path): return validated_details_csv_path +def init_attl(config): + """config: OnlineConfig""" + attl = ATTL('gpu', ATTLConfig(is_benchmark_device=True, + connect_ip=config.host, + connect_port=config.port, + nfs_path=config.nfs_path, + tls_path=config.tls_path)) + return attl + + def _run_ut_parser(parser): parser.add_argument("-api_info", "--api_info_file", dest="api_info_file", default="", type=str, - help=" The api param tool result file: generate from api param tool, " + help=" The api param tool result file: generate from api param tool, " "a json file.", - required=True) + required=False) parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) @@ -451,20 +555,26 @@ def 
run_ut_command(args): except Exception as error: logger.error(f"Set device id failed. device id is: {args.device_id}") raise NotImplementedError from error - check_link(args.api_info_file) - api_info = os.path.realpath(args.api_info_file) - check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX) + + # 在线预检场景下,不需要外出输出api信息,forward_content, backward_content, real_data_path设置为None + # 离线场景下,forward_content, backward_content, real_data_path从api_info_file中解析 + forward_content, backward_content, real_data_path = None, None, None + if args.api_info_file: + check_link(args.api_info_file) + api_info = os.path.realpath(args.api_info_file) + check_file_suffix(api_info, FileCheckConst.JSON_SUFFIX) + forward_content, backward_content, real_data_path = parse_json_info_forward_backward(api_info) + if args.filter_api: + logger.info("Start filtering the api in the forward_input_file.") + forward_content = preprocess_forward_content(forward_content) + logger.info("Finish filtering the api in the forward_input_file.") + out_path = os.path.realpath(args.out_path) if args.out_path else "./" check_path_before_create(out_path) create_directory(out_path) out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() save_error_data = args.save_error_data - forward_content, backward_content, real_data_path = parse_json_info_forward_backward(api_info) - if args.filter_api: - logger.info("Start filtering the api in the forward_input_file.") - forward_content = preprocess_forward_content(forward_content) - logger.info("Finish filtering the api in the forward_input_file.") result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) @@ -474,24 +584,39 @@ def run_ut_command(args): white_list = msCheckerConfig.white_list black_list = msCheckerConfig.black_list error_data_path = msCheckerConfig.error_data_path + is_online = msCheckerConfig.is_online + nfs_path = 
msCheckerConfig.nfs_path + host = msCheckerConfig.host + port = msCheckerConfig.port + rank_list = msCheckerConfig.rank_list + tls_path = msCheckerConfig.tls_path if args.config_path: _, task_config = parse_json_config(args.config_path, Const.RUN_UT) white_list = task_config.white_list black_list = task_config.black_list error_data_path = task_config.error_data_path + is_online = task_config.is_online + nfs_path = task_config.nfs_path + host = task_config.host + port = task_config.port + rank_list = task_config.rank_list + tls_path = task_config.tls_path + if save_error_data: if args.result_csv_path: time_info = result_csv_path.split('.')[0].split('_')[-1] global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info error_data_path = initialize_save_error_data(error_data_path) + online_config = OnlineConfig(is_online, nfs_path, host, port, rank_list, tls_path) run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path) + args.result_csv_path, real_data_path, set(white_list), set(black_list), error_data_path, + online_config) run_ut(run_ut_config) class UtDataInfo: - def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, + def __init__(self, bench_grad, device_grad, device_output, bench_output, grad_in, in_fwd_data_list, backward_message, rank=0): self.bench_grad = bench_grad self.device_grad = device_grad diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py 
b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff0ad703c34c142ef5c59931e34a9b9497beff4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/attl.py @@ -0,0 +1,202 @@ +import io +import os.path +import time +import re +from pathlib import Path +from multiprocessing import Queue +from typing import Optional, Union, Dict, Any +from collections import namedtuple +from dataclasses import dataclass + +import torch + +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.client import TCPClient +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.server import TCPServer +from msprobe.pytorch.common.utils import logger +from msprobe.pytorch.common.utils import save_pt +from msprobe.core.common.utils import remove_path + + +ApiData = namedtuple('ApiData', ['name', 'args', 'kwargs', 'result', 'step', 'rank'], + defaults=['unknown', None, None, None, 0, 0]) +BufferType = Union[ApiData, Dict[str, Any], str] # Union[Tensor, Tuple[Optional[Tensor]]] + + +@dataclass +class ATTLConfig: + is_benchmark_device: bool + connect_ip: str + connect_port: int + # storage_config + nfs_path: str = None + tls_path: str = None + check_sum: bool = True + queue_size: int = 50 + + +class ATTL: + def __init__(self, session_id: str, session_config: ATTLConfig, need_dump=True) -> None: + self.session_id = session_id + self.session_config = session_config + self.logger = logger + self.socket_manager = None + self.data_queue = Queue(maxsize=50) + self.dequeue_list = [] + self.message_end = False + self.kill_progress = False + self.check_attl_config() + if self.session_config.nfs_path: + self.nfs_path = Path(self.session_config.nfs_path) + elif self.session_config.is_benchmark_device: + + self.socket_manager = TCPServer(self.session_config.connect_port, + self.data_queue, + self.session_config.check_sum, + 
self.session_config.tls_path) + self.socket_manager.start() + elif need_dump: + self.socket_manager = TCPClient(self.session_config.connect_ip, + self.session_config.connect_port, + self.session_config.check_sum, + self.session_config.tls_path) + self.socket_manager.start() + + def check_attl_config(self): + if self.session_config.nfs_path: + if os.path.exists(self.session_config.nfs_path): + return + else: + raise Exception(f"nfs path {self.session_config.nfs_path} doesn't exists.") + ipv4_pattern = "([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])(\.([1-9]?\d|1\d{2}|2[0-4]\d|25[0-5])){3}$" + if not re.match(ipv4_pattern, self.session_config.connect_ip): + raise Exception(f"host {self.session_config.connect_ip} is invalid.") + if not (0 < self.session_config.connect_port <= 65535): + raise Exception(f"port {self.session_config.connect_port} is invalid.") + + def stop_serve(self): + if isinstance(self.socket_manager, TCPServer): + self.socket_manager.stop() + + def send(self, buffer: BufferType) -> None: + """ + npu major in 'send' (client) + """ + # know receiver receive and go next + if isinstance(buffer, ApiData): + buffer = move2target_device(buffer, torch.device('cpu')) + + if 'device' in buffer.kwargs: + buffer.kwargs.pop('device') + rank = buffer.rank if hasattr(buffer, "rank") and buffer.rank is not None else 0 + step = buffer.step if hasattr(buffer, "step") else 0 + io_buff = io.BytesIO() + try: + torch.save(buffer, io_buff) + except Exception as e: + self.logger.info(f"{buffer.name} can not be saved, skip: {e}") + return + data = io_buff.getvalue() + self.socket_manager.add_to_sending_queue(data, rank=rank, step=step) + + def recv(self, timeout_ms=0) -> Optional[BufferType]: + buffer = None + while buffer is None: + if timeout_ms > 0: + time.sleep(timeout_ms / 1000.0) + if buffer is None and not self.data_queue.empty(): + buffer = self.data_queue.get() + break + if buffer is None and timeout_ms > 0: # timeout is the only case we give up and return None + break + if 
self.message_end and self.data_queue.empty(): + buffer = b"KILL_CONFIRM" + self.kill_progress = True + break + time.sleep(0.1) # waiting outside the lock before next attempt + if buffer is None: + # this is a result of a timeout + self.logger.info(f"RECEIVE API DATA TIMED OUT") + else: + if buffer == b"STOP_": + return "STOP_" + if buffer == b"KILL_": + self.message_end = True + return "STOP_" + if buffer == b"KILL_CONFIRM": + self.kill_progress = True + return "KILL_" + buffer = io.BytesIO(buffer) + try: + buffer = torch.load(buffer, map_location="cpu") + except Exception as e: + self.logger.warning("there is something error. please check it. %s", e) + if isinstance(buffer, bytes): + return None + if isinstance(buffer, str): + return buffer + + return buffer + + def upload(self, buffer: BufferType): + if isinstance(buffer, ApiData): + buffer = move2target_device(buffer, torch.device('cpu')) + file_path = os.path.join(self.session_config.nfs_path, buffer.name + ".pt") + else: + file_path = os.path.join(self.session_config.nfs_path, buffer + f"_{int(time.time())}") + + try: + save_pt(buffer, file_path) + except Exception as e: + self.logger.warning("there is something error in save_pt. please check it. %s", e) + + def download(self): + for file_type in ("start*", "*.pt", "end*"): + cur_file = next(self.nfs_path.glob(file_type), None) + if cur_file is not None: + break + + if cur_file is None: + return None + else: + buffer = None + try: + buffer = torch.load(cur_file) + except Exception as e: + self.logger.warning("there is something error. please check it. 
%s", e) + remove_path(cur_file) + return buffer + + +def move2device_exec(obj, device): + if isinstance(obj, (tuple, list)): + data_list = [move2device_exec(val, device) for val in obj] + return data_list if isinstance(obj, list) else tuple(data_list) + if isinstance(obj, dict): + return {key: move2device_exec(val, device) for key, val in obj.items()} + elif isinstance(obj, torch.Tensor): + obj = obj.detach() + if obj.device.type != device: + obj = obj.to(device) + return obj + elif "return_types" in str(type(obj)): + return move2device_exec(tuple(obj), device) + elif isinstance(obj, torch._C.device): + return torch.device(device) + else: + return obj + + +def move2target_device(buffer: ApiData, target_device): + # handle args + new_args = move2device_exec(buffer.args, target_device) + + # handle kwargs + new_kwargs = move2device_exec(buffer.kwargs, target_device) + + # handle result + new_results = move2device_exec(buffer.result, target_device) + + if target_device == torch.device('cpu') or target_device == "cpu": + return ApiData(buffer.name, tuple(new_args), new_kwargs, new_results, buffer.step, buffer.rank) + else: + return ApiData(buffer.name, tuple(new_args), new_kwargs, buffer.result, buffer.step, buffer.rank) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py new file mode 100644 index 0000000000000000000000000000000000000000..df7abc188ddf0b1c18c2269ae2d3e811bb4425f0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/client.py @@ -0,0 +1,321 @@ +import hashlib +import io +import struct +import time +import os +import signal +import sys +from queue import Queue +from threading import Thread +from typing import Union + +from OpenSSL import SSL +from twisted.internet import ssl, reactor, protocol, endpoints +from twisted.protocols.basic import FileSender + +from 
msprobe.pytorch.common.utils import logger +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list + + +class TCPDataItem: + def __init__(self, data, + sequence_number: int, + rank: int = 0, + step: int = 0): + self.raw_data = data + self.sequence_number = sequence_number + self.rank = rank + self.step = step + self.retry_times = 0 + self.pending_time = 0 + self.busy_time = 0 + + +class TCPClient: + MAX_SENDING_QUEUE_SIZE = 20 + ACK_SUCCESS = b"OK___" + ACK_ERROR = b"ERROR" + ACK_BUSY = b"BUSY_" + ACK_STOP = b"STOP_" + ACK_STOP_CONFIRM = b"OVER_" + ACK_KILL_PROCESS = b"KILL_" + + QUEUE_PENDING_TIME = 600 # 队列10分钟都处于阻塞状态,则终止sending进程 + RESEND_RETRY_TIMES = 2 # 最大重传数 + RESEND_TIMER_TIME = 5 # 接收ACK超时定时器 + RESEND_PENDING_TIME = 60 # 连续pending时间超过1分钟则放弃该数据 + + def __init__(self, host="localhost", port=8000, check_sum=False, tls_path=None): + self.send_queue = Queue(self.MAX_SENDING_QUEUE_SIZE) + self.resend_dict = dict() + self.host = host + self.port = port + self.tls_path = tls_path + self.factory = None + self.sequence_number = 0 + self.signal_exit = False + self.tcp_manager = ClientProtocol(ack_queue_size=100, + chunk_size=655360, + check_sum=check_sum) + self.send_thread = Thread(target=self._sending_queue_data) + self.send_thread.setDaemon(True) + self.send_thread.start() + self.destroy_thread = Thread(target=self._destroy_queue_data) + self.destroy_thread.setDaemon(True) + self.destroy_thread.start() + + @staticmethod + def run_reactor(): + reactor.run(installSignalHandlers=False) + + def start(self): + def conn_callback(cur_protocol): + if cur_protocol.transport and cur_protocol.transport.getPeer().host == self.host: + logger.debug(f"Process: {os.getpid()} connects to server successfully.") + else: + logger.warning(f"Process: {os.getpid()} fails to connect to server. 
") + raise ConnectionError(f"Failed to connect to {self.host}.") + + def conn_err_callback(failure): + self.signal_exit = True + time.sleep(1) + reactor.stop() + logger.error(f"Failed to connected {self.host} {self.port}. Reason is {failure.getErrorMessage()}") + os.kill(os.getpid(), signal.SIGKILL) + os.kill(os.getppid(), signal.SIGKILL) + + def cur_protocol(): + return self.tcp_manager + + self.factory = MessageClientFactory() + self.factory.protocol = cur_protocol + if self.tls_path: + client_key = os.path.join(self.tls_path, "client.key") + client_crt = os.path.join(self.tls_path, "client.crt") + client_context_factory = ssl.DefaultOpenSSLContextFactory(client_key, client_crt, SSL.TLSv1_2_METHOD) + client_context_ = client_context_factory.getContext() + client_context_.set_cipher_list(cipher_list) + client_context_.set_options(SSL.OP_NO_RENEGOTIATION) + endpoint = endpoints.SSL4ClientEndpoint(reactor, self.host, self.port, client_context_factory) + else: + endpoint = endpoints.TCP4ClientEndpoint(reactor, self.host, self.port) + d = endpoint.connect(self.factory) + d.addCallback(conn_callback) + d.addErrback(conn_err_callback) + + reactor_thread = Thread(target=self.run_reactor, daemon=True) + reactor_thread.start() + + def send_after_queue_empty(self, data): + while not self._ready_to_exit(): + self.add_to_sending_queue(data) + time.sleep(2) + + def check_client_alive(self): + return self.factory.num_connections > 0 + + def stop(self): + self.tcp_manager.connection_timeout() + + def send_stop_signal(self): + self.send_after_queue_empty(self.ACK_STOP) + while not self._ready_to_exit(): + if not self.check_client_alive(): + break + time.sleep(1) + while not self.tcp_manager.kill_process: + time.sleep(1) + + def add_to_sending_queue(self, data: Union[bytes, TCPDataItem], rank: int = 0, step: int = 0): + if self._ready_to_exit(): + return + + send_data = data + if not isinstance(data, TCPDataItem): + send_data = TCPDataItem(data=data, + 
sequence_number=self.sequence_number, + rank=rank, + step=step) + self.sequence_number += 1 + + self.send_queue.put(send_data, block=True, timeout=self.QUEUE_PENDING_TIME) + + def _send_data(self, data: TCPDataItem): + self.tcp_manager.send_wrapped_data(data.raw_data, + sequence_number=data.sequence_number, + rank=data.rank, + step=data.step + ) + + def _sending_queue_data(self): + while True: + if not self.tcp_manager.is_connected: + continue + + while self.send_queue.qsize() > 0: + if self._ready_to_exit(): + break + if len(self.resend_dict) < self.MAX_SENDING_QUEUE_SIZE: + data_obj = self.send_queue.get() + self._send_data(data_obj) + resend_key = str(data_obj.sequence_number) + "_" + str(data_obj.rank) + "_" + str(data_obj.step) + if resend_key not in self.resend_dict.keys(): + # Send data for the first time + self.resend_dict[resend_key] = data_obj + else: + time.sleep(0.1) + + if self._ready_to_exit(): + logger.debug("Successfully close sending process.") + break + time.sleep(0.1) + + def _destroy_queue_data(self): + while True: + if self._ready_to_exit(): + break + + while len(self.resend_dict) > 0 and self.tcp_manager.ack_queue.qsize() > 0: + ack_info, seq_number, rank, step = self.tcp_manager.ack_queue.get() + obj_key = str(seq_number) + "_" + str(rank) + "_" + str(step) + current_item = self.resend_dict.get(obj_key) + + if current_item is None: + continue + + if ack_info == self.ACK_SUCCESS: + self.resend_dict.pop(obj_key) + elif ack_info == self.ACK_BUSY: + logger.debug("RECV BUSY ACK") + if current_item.busy_time > 5: + self._resend_data(current_item) + else: + current_item.busy_time += 1 + elif ack_info == self.ACK_ERROR: + logger.debug("RECV ERROR ACK") + self._resend_data(current_item) + elif ack_info == self.ACK_STOP_CONFIRM: + logger.debug("RECV STOP ACK") + self.factory.num_connections -= 1 + + break + + time.sleep(0.1) + + def _resend_data(self, data: TCPDataItem): + if data.retry_times < self.RESEND_RETRY_TIMES: + data.retry_times += 1 + 
logger.debug(f"Resend data seq number: {data.sequence_number}") + self.add_to_sending_queue(data) + else: + self.resend_dict.pop(data.sequence_number) + logger.debug(f"SKIP send sequence number {data.sequence_number} after retry {data.retry_times} times!") + + def _pending_data(self, data: TCPDataItem): + if data.pending_time >= self.RESEND_PENDING_TIME: + self.resend_dict.pop(data.sequence_number) + logger.debug(f"SKIP send sequence number {data.sequence_number} after pending {data.pending_time} times!") + return + + # wait time is 100MB per second + pending_time = max(1, len(data.raw_data) // (2 ** 20 * 50)) + data.pending_time += pending_time + time.sleep(pending_time) + + def _ready_to_exit(self): + return self.signal_exit or self.tcp_manager.signal_exit + + +class ClientProtocol(protocol.Protocol): + TIMEOUT = 60 * 10 + + def __init__(self, ack_queue_size=100, chunk_size=65536, check_sum=False): + self.buffer = io.BytesIO() + self.is_connected = False + self.check_sum = check_sum + self.tell = 0 + self.ack_queue = Queue(maxsize=ack_queue_size) + self.file_sender = FileSender() + self.file_sender.CHUNK_SIZE = chunk_size + self.signal_exit = False + self.defer = None + self.kill_process = False + + def dataReceived(self, data): + if self.timeout_call.active(): + self.timeout_call.reset(self.TIMEOUT) + + self.buffer.seek(0, 2) + self.buffer.write(data) + self.buffer.seek(self.tell) + while True: + if len(self.buffer.getvalue()) >= 29: # 5 + 8 * 3 + ack = self.buffer.read(5) + seq_number = struct.unpack('!Q', self.buffer.read(8))[0] + rank = struct.unpack('!Q', self.buffer.read(8))[0] + step = struct.unpack('!Q', self.buffer.read(8))[0] + if ack == b"KILL_": + self.kill_process = True + logger.debug(f"接收到KILL信号, PID {os.getpid()}") + if ack == b"OVER_": + self.factory.num_connections -= 1 + self.tell += 29 + if not self.ack_queue.full(): + self.ack_queue.put((ack, seq_number, rank, step)) + self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:]) + self.tell = 
0 + else: + time.sleep(0.1) + else: + break + + def send_wrapped_data(self, data, sequence_number: int = 0, rank: int = 0, step: int = 0): + length = len(data) + md5_hash = hashlib.md5(data).hexdigest() if self.check_sum else "" + while True: + if self.defer is None or self.defer.called: + self.defer = self.send_large_data( + length.to_bytes(8, byteorder='big') + + sequence_number.to_bytes(8, byteorder='big') + + rank.to_bytes(8, byteorder='big') + + step.to_bytes(8, byteorder='big') + + md5_hash.encode() + + data) + break + time.sleep(0.01) + + def send_large_data(self, data): + d = self.file_sender.beginFileTransfer(io.BytesIO(data), self.transport) + return d + + def connection_timeout(self): + if self.factory.num_connections <= 0: + return + + self.factory.num_connections -= 1 + logger.debug(f"超时退出{self.transport.addr}, PID {os.getpid()}") + self.transport.loseConnection() + + def connectionMade(self): + self.timeout_call = reactor.callLater(self.TIMEOUT, self.connection_timeout) + self.is_connected = True + self.factory.num_connections += 1 + logger.info("successfully connect server") + + def connectionLost(self, reason): + self.signal_exit = True + self.factory.num_connections -= 1 + logger.info("Lost connection with server") + + +class MessageClientFactory(protocol.ClientFactory): + def __init__(self): + self.num_connections = 0 + + def clientConnectionFailed(self, connector, reason): + logger.info(f"Fail to connection with server: {reason.getErrorMessage()}") + reactor.stop() + + def clientConnectionLost(self, connector, reason): + logger.info(f"Client lost connection with server: {reason.getErrorMessage()}") + reactor.stop() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5462203423f93308b5d1f947661ebb71583cd4 --- 
/dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/device_dispatch.py @@ -0,0 +1,115 @@ +import time + +import torch +import torch.multiprocessing as mp + +from msprobe.core.common.const import Const +from msprobe.pytorch.common.utils import logger +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import move2target_device + + +def run_ut_process(xpu_id, compare, consumer_queue, func, config): + """ When consumer_queue(shared with ConsumerDispatcher) is not empty, consume api data from consumer_queue. + :param xpu_id: int + :param compare: instance of Comparator + :param consumer_queue: shared queues of ConsumerDispatcher + :param func: run_touch_api_online + :param config: run_ut_config + :return: + """ + device = torch.device(f'cuda:{xpu_id}') + + while True: + if consumer_queue.empty(): + time.sleep(0.1) + continue + + api_data = consumer_queue.get() + if api_data == "KILL_": + # current consumer finish + return + + api_full_name = api_data.name + api_data = move2target_device(api_data, device) + try: + data_info = func(api_full_name, api_data, config.backward_content) + logger.debug(f"success exec in device {api_full_name}") + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) + logger.info(f"running api_full_name {api_full_name} ut, " + f"is_fwd_success: {is_fwd_success}, " + f"is_bwd_success: {is_bwd_success}") + except Exception as err: + [api_type, api_name, _] = api_full_name.split(Const.SEP) + if "expected scalar type Long" in str(err): + logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " + f"'int32_to_int64' list in accuracy_tools/msprobe/core/common/const.py file.") + elif api_type in [Const.DISTRIBUTED]: + logger.info(f"{api_full_name} is not supported for run ut. 
SKIP.") + else: + logger.error(f"Run {api_full_name} UT Error: {str(err)}") + + compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err), api_data.rank)) + + finally: + torch.cuda.empty_cache() + + +class ConsumerDispatcher: + def __init__(self, compare, capacity=10, num_workers=8, device: str = "gpu") -> None: + self.num_workers = num_workers + self.capacity = capacity + self.compare = compare + self.queues = [] + self.processes = [] + self.reverse_sort = False + self.pool = None + self.device = device + self.data_id = 0 + self.lock = mp.Lock() + self.result_queue = mp.Queue() + mp.set_start_method("spawn", force=True) + + def start(self, handle_func, config): + self.queues = [mp.Queue(maxsize=self.capacity) for _ in range(self.num_workers)] + for xpu_id, q in enumerate(self.queues): + p = mp.Process(name="run_ut_process", target=run_ut_process, + args=(xpu_id, self.compare, q, handle_func, config)) + + p.start() + self.processes.append(p) + logger.info("Successfully start unittest process.") + + def stop(self): + for q in self.queues: + while q.full(): + time.sleep(0.1) + q.put("KILL_") + + for p in self.processes: + p.join() + logger.info("Successfully stop unittest process.") + + def update_consume_queue(self, api_data): + while True: + index = self._choose_max_empty_site_strategy() + if index != -1: + q = self.queues[index] + q.put(api_data) + logger.debug(f"将{api_data.name}调度给第{index}个GPU") + break + logger.debug("所有的UT队列都已满, 阻塞中") + time.sleep(0.1) + + def _choose_max_empty_site_strategy(self): + maximum = 0 + index = -1 + # 充分利用多卡资源,防止任务过多分配给前面的卡 + _reverse = 1 if not self.reverse_sort else -1 + for i, q in enumerate(self.queues[::_reverse]): + empty_site = self.capacity - q.qsize() + if empty_site > maximum: + maximum = empty_site + index = i + index = len(self.queues) - index - 1 if index != -1 and self.reverse_sort else index + self.reverse_sort = not self.reverse_sort + return index diff --git 
a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py new file mode 100644 index 0000000000000000000000000000000000000000..521f8d37f6f2c01b4ce4abc3b726155561a5d1dc --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/server.py @@ -0,0 +1,218 @@ +import os.path +import struct +import hashlib +import time +import io +from threading import Thread + +from OpenSSL import SSL +from twisted.internet import ssl, reactor, protocol, endpoints + +from msprobe.pytorch.common.utils import logger +from msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.ssl_config import cipher_list + + +class TCPServer: + def __init__(self, port, shared_queue, check_sum=False, tls_path=None) -> None: + self.port = port + self.shared_queue = shared_queue + self.check_sum = check_sum + self.tls_path = tls_path + self.factory = MessageServerFactory() + self.reactor_thread = None + + @staticmethod + def run_reactor(): + reactor.run(installSignalHandlers=False) + + def start(self): + self.factory.protocol = self.build_protocol + + if self.tls_path: + server_key = os.path.join(self.tls_path, "server.key") + server_crt = os.path.join(self.tls_path, "server.crt") + server_context_factory = ssl.DefaultOpenSSLContextFactory(server_key, server_crt, SSL.TLSv1_2_METHOD) + server_context_ = server_context_factory.getContext() + server_context_.set_cipher_list(cipher_list) + server_context_.set_options(SSL.OP_NO_RENEGOTIATION) + endpoint = endpoints.SSL4ServerEndpoint(reactor, self.port, server_context_factory) + else: + endpoint = endpoints.TCP4ServerEndpoint(reactor, self.port) + endpoint.listen(self.factory) + self.reactor_thread = Thread(target=self.run_reactor, daemon=True) + self.reactor_thread.start() + + def is_running(self): + return not self.factory.is_all_connection_closed() + + def stop(self): + 
self.factory.doStop() + reactor.callFromThread(reactor.sigInt, 2) + self.reactor_thread.join() + + def build_protocol(self): + return ServerProtocol(self.shared_queue, self.check_sum) + + +class ServerProtocol(protocol.Protocol): + ACK_SUCCESS = b"OK___" + ACK_ERROR = b"ERROR" + ACK_BUSY = b"BUSY_" + ACK_STOP = b"STOP_" + ACK_STOP_CONFIRM = b"OVER_" + ACK_KILL_PROCESS = b"KILL_" + + def __init__(self, shared_queue, check_sum=False): + self.start_time = None + self.buffer = io.BytesIO() + self.consumer_queue = shared_queue + self.check_sum = check_sum + self.length_width = 8 + self.md5_width = 32 + self.obj_length = None + self.tell = 0 + self.obj_md5 = None + self.obj_body = None + self.sequence_number = -1 + self.rank = -1 + self.step = -1 + self.sequence_number_dict = dict() + + def connectionMade(self): + self.buffer = io.BytesIO() + self.obj_length = None + self.tell = 0 + self.obj_md5 = None + self.obj_body = None + self.factory.transport_dict[self.transport] = 1 + self.factory.transport_list.append(self.transport) + logger.info(f"Connected to {self.transport.getPeer()} successfully.") + + def connectionLost(self, reason): + self.factory.transport_dict.pop(self.transport, None) + if len(self.factory.transport_dict) == 0: + self.consumer_queue.put(self.ACK_KILL_PROCESS) + + logger.info(f"Lost connection with {self.transport.getPeer()}. 
Reason is: {reason} 与客户端 断开连接, " + f"current connection number is: {len(self.factory.transport_dict)}") + + def send_ack(self, ack_info): + ack_message = b"".join([ + ack_info, + self.sequence_number.to_bytes(8, byteorder='big'), + self.rank.to_bytes(8, byteorder='big'), + self.step.to_bytes(8, byteorder='big') + ]) + self.transport.write(ack_message) + + def post_process(self): + send_busy_ack = False + while self.consumer_queue.full(): + if not send_busy_ack: + self.send_ack(self.ACK_BUSY) + logger.debug("sending BUSY ACK") + send_busy_ack = True + time.sleep(0.1) + + obj_key = str(self.sequence_number) + "_" + str(self.rank) + "_" + str(self.step) + + recv_md5 = hashlib.md5(self.obj_body).hexdigest() + if self.check_sum and recv_md5 != self.obj_md5: + # when needs check md5 and check no pass, indicates received data error, send b"ERROR" to client. + logger.debug(f"Error:接收数据有问题,流水号{self.sequence_number}, expected {self.obj_md5}, but get {recv_md5}") + self.send_ack(self.ACK_ERROR) + else: + if self.obj_body == self.ACK_STOP: + self.handle_with_stop() + else: + self.send_ack(self.ACK_SUCCESS) + if obj_key in self.sequence_number_dict: + logger.debug(f"这是一次异常的重传,可以忽略。 {obj_key}, {self.sequence_number_dict}") + else: + self.sequence_number_dict[obj_key] = self.obj_md5 + self.consumer_queue.put(self.obj_body, block=True) + + self.reset_env() + finish_time = time.time() + logger.debug(f"finish_time: {finish_time - self.start_time}") + + def handle_with_stop(self): + logger.debug(f"接收到停止传输信号 TCP{self.transport.getPeer()}") + self.send_ack(self.ACK_STOP_CONFIRM) + if len(self.factory.transport_dict) == 0: + _rank, _step, _sequence_number = 0, 0, 100000000 + ack_kill = self.ACK_KILL_PROCESS + \ + _sequence_number.to_bytes(8, byteorder='big') + \ + _rank.to_bytes(8, byteorder='big') + \ + _step.to_bytes(8, byteorder='big') + for trans in self.factory.transport_list: + trans.write(ack_kill) + logger.debug(f"发送KILL信息给{self.transport.getPeer()}") + 
self.consumer_queue.put(self.ACK_KILL_PROCESS) + time.sleep(2) + + def reset_env(self): + self.obj_length = None + self.sequence_number = -1 + self.rank = -1 + self.step = -1 + self.obj_md5 = None + self.obj_body = None + + def dataReceived(self, data): + self.buffer.seek(0, 2) + self.buffer.write(data) + self.buffer.seek(self.tell) + + # The first data packet is packet header, it contains obj_length, sequence_number, rank, step + if self.obj_length is None and len(self.buffer.getvalue()) >= self.length_width * 4: + self.start_time = time.time() + self.obj_length = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.sequence_number = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.rank = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.step = struct.unpack('!Q', self.buffer.read(self.length_width))[0] + self.tell += self.length_width * 4 + logger.debug( + f"流水号: {self.sequence_number}; RANK: {self.rank}; STEP: {self.step}; Length: {self.obj_length}") + + # If needs check md5 but not parse md5 yet, read 32b md5 values + check_sum_and_md5 = (self.check_sum + and self.obj_length is not None + and self.obj_md5 is None + and len(self.buffer.getvalue()) - self.tell >= self.md5_width) + if check_sum_and_md5: + self.obj_md5 = self.buffer.read(self.md5_width).decode() + self.tell += self.md5_width + logger.debug(f"MD5: {self.obj_md5}") + + current_length = len(self.buffer.getvalue()) - self.tell + if self.obj_length is not None and 0 < self.obj_length <= current_length: + # Current api data receive finished + self.obj_body = self.buffer.read(self.obj_length) + + self.tell += self.obj_length + self.buffer = io.BytesIO(self.buffer.getvalue()[self.tell:]) + self.buffer.seek(0) + self.tell = 0 + recv_data_time = time.time() + logger.debug(f"self.sequence_number {self.sequence_number} " + f"recv_data_time {recv_data_time - self.start_time}") + + if self.obj_body == self.ACK_STOP: + # Indicates the current TCP link receives a 
STOP signal and remove from the transport_dict + _transport = self.factory.transport_dict.pop(self.transport, None) + logger.debug(f"接收到b'STOP_' self.sequence_number {self.sequence_number} ") + self.post_process() + + +class MessageServerFactory(protocol.ServerFactory): + def __init__(self) -> None: + """ + transport_dict: links that have not completed data transmission. + transport_list: Records all TCP links. Appends TCP link to the transport list when a new TCP link is established. + """ + self.transport_dict = {} + self.transport_list = [] + + def is_all_connection_closed(self): + return len(self.transport_dict) == 0 diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py new file mode 100644 index 0000000000000000000000000000000000000000..8e29cafd22b564868e0cbbab4181b65132db7839 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/tensor_transport_layer/ssl_config.py @@ -0,0 +1,10 @@ +cipher_list = ":".join([ + 'ECDHE-ECDSA-AES128-GCM-SHA256', + 'ECDHE-RSA-AES128-GCM-SHA256', + 'ECDHE-ECDSA-AES256-GCM-SHA384', + 'ECDHE-RSA-AES256-GCM-SHA384', + 'ECDHE-ECDSA-CHACHA20-POLY1305', + 'ECDHE-RSA-CHACHA20-POLY1305', + 'DHE-RSA-AES128-GCM-SHA256', + 'DHE-RSA-AES256-GCM-SHA384' +]) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/compare_cli.py b/debug/accuracy_tools/msprobe/pytorch/compare/compare_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..b344d4efbf40b7f0223089778e78cb8f36d7044d --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/compare_cli.py @@ -0,0 +1,24 @@ +import json +from msprobe.core.common.file_check import FileOpen, check_file_type +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.utils import CompareException +from msprobe.core.common.log import logger +from msprobe.pytorch.compare.pt_compare import 
compare +from msprobe.pytorch.compare.distributed_compare import compare_distributed + + +def compare_cli(args): + with FileOpen(args.input_path, "r") as file: + input_param = json.load(file) + npu_path = input_param.get("npu_path", None) + bench_path = input_param.get("bench_path", None) + + if check_file_type(npu_path) == FileCheckConst.FILE and check_file_type(bench_path) == FileCheckConst.FILE: + compare(input_param, args.output_path, stack_mode=args.stack_mode, auto_analyze=args.auto_analyze, + fuzzy_match=args.fuzzy_match) + elif check_file_type(npu_path) == FileCheckConst.DIR and check_file_type(bench_path) == FileCheckConst.DIR: + kwargs = {"stack_mode": args.stack_mode, "auto_analyze": args.auto_analyze, "fuzzy_match": args.fuzzy_match} + compare_distributed(npu_path, bench_path, args.output_path, **kwargs) + else: + logger.error("The npu_path and bench_path need to be of the same type.") + raise CompareException(CompareException.INVALID_COMPARE_MODE) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py index caac139580751a9b9a36d0f73fbf163263d85a51..22d0598ed5ae0130a9e1562436d7c6913ad5411f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/distributed_compare.py @@ -15,63 +15,16 @@ # limitations under the License. 
""" import os -import sys -import re from msprobe.core.common.utils import CompareException, check_compare_param, \ - check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid -from msprobe.pytorch.compare.acc_compare import compare_core + check_configuration_param, task_dumppath_get from msprobe.core.common.file_check import create_directory from msprobe.core.common.exceptions import FileCheckException -from msprobe.pytorch.common.log import logger +from msprobe.core.common.log import logger +from msprobe.pytorch.compare.pt_compare import PTComparator +from msprobe.core.compare.utils import check_and_return_dir_contents, extract_json def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): - def check_and_return_dir_contents(dump_dir, prefix): - """ - check the given dump dir and validate files in dump dir by using the given prefix patterns to build a - pattern: ^{prefix}(?:0|[0-9][1-9]*)?$ - - Args: - dump_dir (str): dump dir - prefix (str): prefix for the patterns, prefix should be less than 20 characters and alphanumeric/-/_ only - - Returns: - content [list]: dir contents - Raises: - CompareException: invalid path - ValueError: prefix not match the patterns - - """ - check_regex_prefix_format_valid(prefix) - check_file_or_directory_path(dump_dir, True) - contents = os.listdir(dump_dir) - pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$') - for name in contents: - if not pattern.match(name): - logger.error( - f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump " - f"output. Please check and delete irrelevant files in {dump_dir} and try again." 
- ) - raise CompareException(CompareException.INVALID_PATH_ERROR) - return contents - - def extract_json(dirname, stack_json=False): - json_path = '' - for fname in os.listdir(dirname): - full_path = os.path.join(dirname, fname) - if full_path.endswith('.json'): - json_path = full_path - if not stack_json and 'stack' not in json_path: - break - if stack_json and 'stack' in json_path: - break - - # Provide robustness on invalid directory inputs - if not json_path: - logger.error(f'No file is found in dump dir {dirname}. ') - raise CompareException(CompareException.NO_DUMP_FILE_ERROR) - return json_path - if kwargs.get('suffix'): logger.error("Argument 'suffix' is not supported for compare_distributed.") raise CompareException(CompareException.INVALID_PARAM_ERROR) @@ -89,14 +42,14 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): for nr, br in zip(npu_ranks, bench_ranks): npu_data_dir = os.path.join(npu_dump_dir, nr) bench_data_dir = os.path.join(bench_dump_dir, br) - npu_json_path = extract_json(npu_data_dir, stack_json=False) - bench_json_path = extract_json(bench_data_dir, stack_json=False) - stack_json_path = extract_json(npu_data_dir, stack_json=True) + npu_path = extract_json(npu_data_dir, stack_json=False) + bench_path = extract_json(bench_data_dir, stack_json=False) + stack_path = extract_json(npu_data_dir, stack_json=True) dump_result_param = { - 'npu_json_path': npu_json_path, - 'bench_json_path': bench_json_path, - 'stack_json_path': stack_json_path, + 'npu_path': npu_path, + 'bench_path': bench_path, + 'stack_path': stack_path, 'is_print_compare_log': True } try: @@ -106,6 +59,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): check_compare_param(dump_result_param, output_path, summary_compare=summary_compare, md5_compare=md5_compare) except (CompareException, FileCheckException) as error: logger.error('Compare failed. 
Please check the arguments and do it again!') - sys.exit(error.code) - compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, + raise CompareException(error.code) from error + ptComparator = PTComparator() + ptComparator.compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, md5_compare=md5_compare, **kwargs) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/match.py b/debug/accuracy_tools/msprobe/pytorch/compare/match.py index 6347d8887c85427fcb556eecb5cd4a7302166969..2a46105bdfd37b8c88c8cce9aa42f441279362b4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/compare/match.py +++ b/debug/accuracy_tools/msprobe/pytorch/compare/match.py @@ -10,7 +10,7 @@ class AtenIrMapping(): yaml_path = os.path.join(cur_path, "mapping.yaml") with FileOpen(yaml_path, 'r') as f: self.aten_mapping = yaml.safe_load(f) - + def match(self, op1, op2): if "Aten" in op1 and "Aten" not in op2: return self.match_op(op1, op2) diff --git a/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc6301c53223bd94c94b98d2d52994f97b52c6a --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/compare/pt_compare.py @@ -0,0 +1,204 @@ +import json +import os.path +import torch +from msprobe.core.advisor.advisor import Advisor +from msprobe.core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, \ + check_file_not_exists, check_configuration_param, task_dumppath_get +from msprobe.core.common.file_check import FileChecker, FileOpen, create_directory +from msprobe.core.common.const import FileCheckConst +from msprobe.core.common.log import logger +from msprobe.core.common.exceptions import FileCheckException +from msprobe.core.compare.utils import get_un_match_accuracy, get_accuracy +from msprobe.core.compare.multiprocessing_compute import 
ComparisonResult, _save_cmp_result +from msprobe.core.compare.highlight import find_compare_result_error_rows, highlight_rows_xlsx +from msprobe.core.compare.acc_compare import Comparator + + +class PTComparator (Comparator): + def __init__(self): + self.frame_name=PTComparator.__name__ + + def compare_ops(self,idx, dump_path_dict, result_df, lock, input_parma): + cos_result = [] + max_err_result = [] + max_relative_err_result = [] + err_mess = [] + one_thousand_err_ratio_result = [] + five_thousand_err_ratio_result = [] + is_print_compare_log = input_parma.get("is_print_compare_log") + for i in range(len(result_df)): + op_name = result_df.iloc[i, 0] + if is_print_compare_log: + logger.info("start compare: {}".format(op_name)) + cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = self.compare_by_op( + op_name, dump_path_dict, input_parma) + if is_print_compare_log: + logger.info( + "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, " + "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg, + one_thousand_err_ratio, five_thousand_err_ratio)) + cos_result.append(cos_sim) + max_err_result.append(max_abs_err) + max_relative_err_result.append(max_relative_err) + err_mess.append(err_msg) + one_thousand_err_ratio_result.append(one_thousand_err_ratio) + five_thousand_err_ratio_result.append(five_thousand_err_ratio) + + cr = ComparisonResult( + cos_result = cos_result, + max_err_result = max_err_result, + max_relative_err_result=max_relative_err_result, + err_msgs = err_mess, + one_thousand_err_ratio_result = one_thousand_err_ratio_result, + five_thousand_err_ratio_result = five_thousand_err_ratio_result + ) + + return _save_cmp_result(idx, cr, result_df, lock) + + def compare_process(self,file_handles, stack_mode, fuzzy_match, summary_compare=False, md5_compare=False): + npu_json_handle, bench_json_handle, stack_json_handle = file_handles 
+ npu_json_data = json.load(npu_json_handle) + bench_json_data = json.load(bench_json_handle) + stack_json_data = json.load(stack_json_handle) + + if fuzzy_match: + logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.") + + npu_ops_queue = [] + bench_ops_queue = [] + result = [] + + ops_npu_iter = iter(npu_json_data['data']) + ops_bench_iter = iter(bench_json_data['data']) + read_err_npu = True + read_err_bench = True + last_npu_ops_len = 0 + last_bench_ops_len = 0 + + while True: + if not read_err_npu and not read_err_bench: + break + try: + last_npu_ops_len = len(npu_ops_queue) + op_name_npu = next(ops_npu_iter) + read_err_npu = True + npu_merge_list = self.gen_merge_list(npu_json_data,op_name_npu,stack_json_data,summary_compare,md5_compare) + if npu_merge_list: + npu_ops_queue.append(npu_merge_list) + except StopIteration: + read_err_npu = False + try: + last_bench_ops_len = len(bench_ops_queue) + op_name_bench = next(ops_bench_iter) + bench_merge_list = self.gen_merge_list(bench_json_data,op_name_bench,stack_json_data,summary_compare,md5_compare) + if bench_merge_list: + bench_ops_queue.append(bench_merge_list) + except StopIteration: + read_err_bench = False + + # merge all boolean expressions + both_empty = not npu_ops_queue and not bench_ops_queue + no_change = (len(npu_ops_queue) == last_npu_ops_len) and (len(bench_ops_queue) == last_bench_ops_len) + if both_empty or no_change: + continue + + n_match_point, b_match_point = super().match_op(npu_ops_queue, bench_ops_queue, fuzzy_match) + if n_match_point == -1 and b_match_point == -1: + continue + n_match_data = npu_ops_queue[n_match_point] + b_match_data = bench_ops_queue[b_match_point] + un_match_data = npu_ops_queue[0: n_match_point] + for npu_data in un_match_data: + get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + get_accuracy(result, n_match_data, b_match_data, summary_compare, md5_compare) + del npu_ops_queue[0: n_match_point + 1] + del 
bench_ops_queue[0: b_match_point + 1] + if npu_ops_queue: + for npu_data in npu_ops_queue: + get_un_match_accuracy(result, npu_data, md5_compare, summary_compare) + + result_df = self.make_result_table(result,md5_compare,summary_compare,stack_mode) + return result_df + + def read_npy_data(self,dir_path, file_name): + data_path = os.path.join(dir_path, file_name) + path_checker = FileChecker(data_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, + FileCheckConst.PT_SUFFIX, False) + data_path = path_checker.common_check() + data_value = torch.load(data_path, map_location=torch.device('cpu')).detach() # detach for less memory + if data_value.dtype == torch.bfloat16: + data_value = data_value.to(torch.float32) + data_value = data_value.numpy() + return data_value + + def compare_core(self,input_parma, output_path, **kwargs): + """ + Compares data from multiple JSON files and generates a comparison report. + + Args: + input_parma (dict): A dictionary containing paths to JSON files ("npu_path", "bench_path", + "stack_path"). + output_path (str): The path where the output Excel report will be saved. + **kwargs: Additional keyword arguments including: + - stack_mode (bool, optional): Enables stack mode comparison. Defaults to False. + - auto_analyze (bool, optional): If True, triggers automatic analysis after comparison. Defaults to True. + - suffix (str, optional): Suffix to append to the output file name. Defaults to ''. + - fuzzy_match (bool, optional): Enables fuzzy matching during comparison. Defaults to False. + - summary_compare (bool, optional): Enables summary comparison mode. Defaults to False. + - md5_compare (bool, optional): Enables MD5 comparison. Defaults to False. 
+ + Returns: + """ + # get kwargs or set default value + stack_mode = kwargs.get('stack_mode', False) + auto_analyze = kwargs.get('auto_analyze', True) + suffix = kwargs.get('suffix', '') + fuzzy_match = kwargs.get('fuzzy_match', False) + summary_compare = kwargs.get('summary_compare', False) + md5_compare = kwargs.get('md5_compare', False) + + logger.info("Please check whether the input data belongs to you. If not, there may be security risks.") + file_name = add_time_with_xlsx("compare_result" + suffix) + file_path = os.path.join(os.path.realpath(output_path), file_name) + check_file_not_exists(file_path) + highlight_dict = {'red_rows': [], 'yellow_rows': []} + + with FileOpen(input_parma.get("npu_path"), "r") as npu_json, \ + FileOpen(input_parma.get("bench_path"), "r") as bench_json, \ + FileOpen(input_parma.get("stack_path"), "r") as stack_json: + result_df = self.compare_process([npu_json, bench_json, stack_json], stack_mode, fuzzy_match, + summary_compare, md5_compare) + + if not md5_compare and not summary_compare: + result_df = self._do_multi_process(input_parma, result_df) + find_compare_result_error_rows(result_df, highlight_dict, summary_compare, md5_compare) + highlight_rows_xlsx(result_df, highlight_dict, file_path) + if auto_analyze: + advisor = Advisor(result_df, output_path) + advisor.analysis() + + + +def compare(input_param, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False): + try: + summary_compare, md5_compare = task_dumppath_get(input_param) + check_configuration_param(stack_mode, auto_analyze, fuzzy_match) + create_directory(output_path) + check_compare_param(input_param, output_path, summary_compare, md5_compare) + except (CompareException, FileCheckException) as error: + logger.error('Compare failed. 
Please check the arguments and do it again!') + raise CompareException(error.code) from error + ptComparator = PTComparator() + ptComparator.compare_core(input_param, output_path, stack_mode=stack_mode, + auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, + md5_compare=md5_compare) + + + + + + + + + + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py index f1289e9b013b7f1558e6a2339ecd5e84804c592f..7c32be7cc34e1f90334655a1ac03fbd1980b2228 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/debugger_config.py @@ -35,7 +35,16 @@ class DebuggerConfig: "preheat_step": task_config.preheat_step if task_config.preheat_step else 15, "max_sample": task_config.max_sample if task_config.max_sample else 20, } - + + self.online_run_ut = False + if self.task == Const.TENSOR: + # dump api tensor and collaborate with online run_ut + self.online_run_ut = task_config.online_run_ut if task_config.online_run_ut else False + self.nfs_path = task_config.nfs_path if task_config.nfs_path else "" + self.tls_path = task_config.tls_path if task_config.tls_path else "" + self.host = task_config.host if task_config.host else "" + self.port = task_config.port if task_config.port else -1 + self.check() if self.step: self.step.sort() diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 012d42fafeaf0909fcbc72cf925c93fda51d0cc6..8433f0af695bf7afb08f32d99fc6f58788ab8b97 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -36,7 +36,7 @@ class PrecisionDebugger: common_config, task_config = parse_json_config(config_path, task) self.task = common_config.task if 
self.task == Const.GRAD_PROBE: - GradientMonitor(task_config, model) + self.gm = GradientMonitor(common_config, task_config) return if step: common_config.step = step @@ -102,6 +102,14 @@ class PrecisionDebugger: raise Exception("PrecisionDebugger instance is not created.") cls._instance.service.step() + @classmethod + def monitor(cls, model): + if not cls._instance: + raise Exception("PrecisionDebugger instance is not created.") + if cls._instance.task != Const.GRAD_PROBE: + return + cls._instance.gm.monitor(model) + def iter_tracer(func): def func_wrapper(*args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py index 6781a1c2fc4c7fd348d330a49352e2f6195e8a71..e58223e597ed487363f9270e7ad81c24c03f3e75 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/compare/grad_saver.py @@ -2,7 +2,7 @@ import torch from msprobe.core.common.exceptions import FreeBenchmarkException from msprobe.pytorch.free_benchmark import logger from msprobe.pytorch.free_benchmark.common.constant import CommonField -from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams +from msprobe.pytorch.free_benchmark.common.params import DataParams, HandlerParams, data_pre_deal from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, @@ -16,7 +16,6 @@ class GradSaver: self.handler_params = handler_params self.api_name = handler_params.api_name self.origin_func = origin_func - self.data_params = DataParams() self.is_compare = True self.kwargs = dict() self.perturbed_grad_input = tuple() @@ -61,28 +60,25 @@ class GradSaver: _index += 1 def compare_grad_results(self, handler, origin_grad, perturbed_grad, index): - # TODO get dtype? 
- self.data_params.original_result = origin_grad - self.data_params.perturbed_result = perturbed_grad - self.data_params.grad_unequal_flag = False - self.data_params.valid_input_index = index + data_params = DataParams() + data_params.original_result = origin_grad + data_params.perturbed_result = perturbed_grad + data_params.grad_unequal_flag = False + data_params.valid_input_index = index try: - handler.handle(self.data_params) - if not self.data_params.is_consistent: + handler.handle(data_params) + if not data_params.is_consistent: self.is_compare = False - self.data_params.grad_unequal_flag = True - self.data_params.is_consistent = True - self.data_params.perturbed_result = self.perturbed_grad_input - self.data_params.original_result = self.origin_grad_input - handler.handle(self.data_params) + data_params.grad_unequal_flag = True + data_params.is_consistent = True + data_params.perturbed_result = self.perturbed_grad_input + data_params.original_result = self.origin_grad_input + handler.handle(data_params) except Exception as e: logger.warning_on_rank_0( f"[msprobe] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." 
f"{e}" ) - # 在扰动前后输出对比后释放输出的引用 - self.data_params.perturbed_result = None - self.data_params.original_result = None def check_grad_input(self, origin_grad, new_grad_index): if self.perturbed_grad_input is None: @@ -164,20 +160,20 @@ class GradSaver: return grad_input def calculate_perturbed_grad_input(self, grad_output, need_grad_tensors, inner_args): - self.data_params.args = [need_grad_tensors, grad_output, inner_args] - self.data_params.kwargs = {} - self.data_params.valid_input_index = 0 - self.data_params.origin_func = self.get_grad_input_from_vjp + data_params = data_pre_deal( + self.handler_params.api_name, + self.get_grad_input_from_vjp, + [need_grad_tensors, grad_output, inner_args], + {} + ) layer = LayerFactory.create( self.handler_params.api_name, self.handler_params.fuzz_device, self.handler_params.pert_mode, ) - layer.handle(self.data_params) - # 在计算扰动输出之后,释放输入的引用 - self.data_params.args = None + layer.handle(data_params) # 确定扰动成功后,才会暂存 - if self.data_params.perturbed_result: + if data_params.perturbed_result: self.perturbed_grad_input = tuple( - [x.cpu() for x in self.data_params.perturbed_result] + [x.cpu() for x in data_params.perturbed_result] ) diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py index 971776d1326409c8878849e7b09a4614ffbc16f5..69ece0a0c6a7a58fe8904bde470b16bb32c0d404 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/main.py @@ -10,7 +10,10 @@ from msprobe.pytorch.free_benchmark.common.enums import ( HandlerType, PerturbationMode, ) -from msprobe.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params +from msprobe.pytorch.free_benchmark.common.params import ( + data_pre_deal, + make_handler_params, +) from msprobe.pytorch.free_benchmark.compare.grad_saver import GradSaver from msprobe.pytorch.free_benchmark.perturbed_layers.layer_factory import 
LayerFactory from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( @@ -70,9 +73,9 @@ class FreeBenchmarkCheck(ABC): layer.handle(data_params) handler_params = make_handler_params(name, self.config, self.current_iter) handler = FuzzHandlerFactory.create(handler_params) - handler.handle(data_params) - return data_params.perturbed_result, handler.get_unequal_rows() - + perturbed_output = handler.handle(data_params) + return perturbed_output, handler.get_unequal_rows() + def backward(self, name, module, grad_output): if not self.config.fuzz_stage == Const.BACKWARD: diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index a18ef1c51bd342c9b3ab5ffecf14c307e9be5527..2ccc2bfcf7a54079436a9da0bcf74771aa644964 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -32,7 +32,7 @@ class AddNoiseLayer(NpuBaseLayer): return type(tensor_obj)([self.add_noise(value) for value in tensor_obj]) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index 45dea7b93a5c7628b24bf0470af10af355a7742f..a0ac216917a1bc7a35fba2b0ccf5ecd3b79854fa 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -48,7 +48,7 @@ class BitNoiseLayer(NpuBaseLayer): return type(tensor_obj)([self.add_bit_noise(value) for value in tensor_obj]) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def 
handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index 91085d57a68b4841b2e04453c05c41a2903477c3..ae5bf9f03be3dd389920d353b0364cf8560499c1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -39,7 +39,7 @@ class ChangeValueLayer(NpuBaseLayer): return type(tensor_obj)([self.change_value(value) for value in tensor_obj]) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index ad6d8b8989d6983f81a9a2d58798a26d4ccc45c1..b5a106dacb099e105add9f8858a059099077aec7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -17,7 +17,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): and torch.is_floating_point(tensor_obj) and tensor_obj.dtype not in [torch.float32, torch.float64] ): - self._set_improve_valus(tensor_obj) + self._set_improve_values(tensor_obj) tensor_obj = self._change_dtype(tensor_obj) self.is_added = True return tensor_obj @@ -32,7 +32,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): ) return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): logger.info_on_rank_0( f"[msprobe] Free benchmark: Perturbation is " f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}." 
@@ -50,7 +50,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): params.perturbed_result = params.origin_func(*new_args, **new_kwargs) return params.perturbed_result - def _set_improve_valus(self, inputs): + def _set_improve_values(self, inputs): if inputs.dtype in [torch.float16, torch.bfloat16]: self.perturbed_value = torch.float32 diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py index a69c56002a205a518a6929835591859f63b800ff..fa775e00edb51f26f9ed59a518a31e81d577bd77 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -16,7 +16,7 @@ class NoChangeLayer(NpuBaseLayer): self.is_added = True return tensor_obj - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): """ 对输入添加扰动并返回 """ diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py index d34ac976537d794a05255a32de8d54de2dbac5d3..376f4ee3ea5f77453e4938a92b24ab751ac6dbcf 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/perturbed_layers/run_cpu.py @@ -8,7 +8,7 @@ from msprobe.pytorch.free_benchmark.perturbed_layers.base_layer import BaseLayer class CpuLayer(BaseLayer): - def handle(self, params: DataParams) -> torch.Any: + def handle(self, params: DataParams): logger.info_on_rank_0( f"[msprobe] Free benchmark: Perturbation is to_cpu of {self.api_name}." 
diff --git a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py index 5ee968c6a86728786f526660594fbb6de4ce18ee..46efd8283cc5f03d3dd70a5d09358b0303f92121 100644 --- a/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py +++ b/debug/accuracy_tools/msprobe/pytorch/free_benchmark/result_handlers/handler_factory.py @@ -22,7 +22,6 @@ class FuzzHandlerFactory: handler = FuzzHandlerFactory.result_handlers.get(params.handler_type) else: handler = FuzzHandlerFactory.result_handlers.get(HandlerType.PREHEAT) - # TODO if not handler: raise FreeBenchmarkException( FreeBenchmarkException.UnsupportedType, diff --git a/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py index efb95c3369f6cda2f883d70a86261e0232535f86..5d2e8d9856c59bf715e1e5f7ab01c39dd7de73ed 100644 --- a/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/functional/dump_module.py @@ -24,7 +24,7 @@ def module_dump(module, dump_name): dump_name = dump_name + Const.SEP + str(module_count.get(dump_name)) + Const.SEP pdg = PrecisionDebugger() - _, forward_hook, backward_hook = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name) + _, forward_hook, backward_hook, _ = pdg.service.build_hook(BaseScope.Module_Type_Module, dump_name) module.register_forward_hook(forward_hook, with_kwargs=True) module.register_full_backward_hook(backward_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py index edd28635da6446d83a876237bfd115c25cae43f3..36aec34e0425601125f57b4b9468e7f6cef67b47 100644 --- a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py +++ b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_monitor.py @@ -2,54 +2,38 
@@ import os from collections import defaultdict import torch -from torch.optim.optimizer import register_optimizer_step_pre_hook +if int(torch.__version__.split('.')[0]) >= 2: + from torch.optim.optimizer import register_optimizer_step_pre_hook from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv from msprobe.core.grad_probe.utils import check_numeral_list_ascend, data_in_list_target -from msprobe.core.grad_probe.constant import GradConst +from msprobe.core.grad_probe.constant import GradConst, level_adp from msprobe.core.common.file_check import create_directory from msprobe.core.common.log import logger -from msprobe.core.common.utils import remove_path, write_csv +from msprobe.core.common.utils import remove_path, write_csv, save_npy from msprobe.pytorch.common.utils import get_rank_id, print_rank_0, save_pt class GradientMonitor: - level_adp = { - "L0": { - "header": [GradConst.MD5, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE], - "have_grad_direction": False - }, - "L1": { - "header": [GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE], - "have_grad_direction": True - }, - "L2": { - "header": [GradConst.DISTRIBUTION, GradConst.MAX, GradConst.MIN, GradConst.NORM, GradConst.SHAPE], - "have_grad_direction": True - }, - } - def __init__(self, config, model): - self._config = config._config - self._model = model - level = self._config.get("level") - if level not in GradientMonitor.level_adp: - raise Exception(f"level is valid, not in {GradientMonitor.level_adp.keys()}") - self._level_adp = GradientMonitor.level_adp[level] - self._param_list = self._config.get('param_list') - self._target_ranks = self._config.get("rank") + def __init__(self, common_config, task_config): + level = task_config.grad_level + if level not in level_adp: + raise Exception(f"level is valid, not in {level_adp.keys()}") + self._level_adp = level_adp[level] + self._param_list = task_config.param_list + self._target_ranks = common_config.rank 
logger.info(f"target rank {self._target_ranks}") - self._target_step = self._config.get("step") + self._target_step = common_config.step logger.info(f"target step {self._target_step}") - self._bounds = self._config.get("bounds") + self._bounds = task_config.bounds check_numeral_list_ascend(self._bounds) - self._output_path = self._config.get("output_path") + self._output_path = common_config.dump_path if not os.path.exists(self._output_path): create_directory(self._output_path) else: logger.warning(f"the file in {self._output_path} will be recoverd") self._step = -1 self._param2name = defaultdict(str) - self._monitor() @property def output_path(self): @@ -61,12 +45,12 @@ class GradientMonitor: create_directory(save_path) param_grad = grad.clone().detach() is_positive = param_grad > 0 - save_filepath = os.path.join(save_path, f"{param_name}.pt") - save_pt(is_positive, save_filepath) + save_filepath = os.path.join(save_path, f"{param_name}.npy") + save_npy(is_positive.numpy(), save_filepath) - def _monitor(self): + def monitor(self, model): print_rank_0("> parameter names:") - for name, param in self._model.named_parameters(): + for name, param in model.named_parameters(): self._param2name[param] = name print_rank_0(f"\t{name}") setattr(self, "_rank", get_rank_id()) @@ -102,5 +86,5 @@ class GradientMonitor: header_result = GradStatCsv.generate_csv_header(self._level_adp, self._bounds) output_lines.insert(0, header_result) write_csv(output_lines, output_path) - - register_optimizer_step_pre_hook(optimizer_pre_step_hook) + if int(torch.__version__.split('.')[0]) >= 2: + register_optimizer_step_pre_hook(optimizer_pre_step_hook) diff --git a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py index ae01b75ee10d4cba5537830cf3280e8c69eb78ad..757a1aebf7a1b109516771c6f7096dfcb1f4baa7 100644 --- a/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py +++ 
b/debug/accuracy_tools/msprobe/pytorch/grad_probe/grad_stat_csv.py @@ -63,13 +63,15 @@ class CSV_distribution(CsvItem): def generate_csv_header(csv_header_input): bounds = csv_header_input.bounds intervals = [] - for i, _ in enumerate(bounds): - if i == 0: - intervals.append(f"(-inf, {bounds[i]}]") - else: + if bounds: + intervals.append(f"(-inf, {bounds[0]}]") + for i in range(1, len(bounds)): intervals.append(f"({bounds[i-1]}, {bounds[i]}]") - intervals.extend([f"({bounds[-1]}, inf)", "=0"]) - return intervals + if intervals: + intervals.append(f"({bounds[-1]}, inf)") + intervals.append("=0") + + return intervals def generate_csv_content(csv_content_input): grad = csv_content_input.grad diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py index ff6427e51e5c6bc6b715991979890f759ab955cf..aa724b50fd43e6699f9cd1e5fb8b3d1efe59c778 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/hook_module.py @@ -23,6 +23,7 @@ import torch.nn as nn import torch.utils.hooks as full_hooks from msprobe.core.common.const import Const +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' class HOOKModule(nn.Module): @@ -48,9 +49,13 @@ class HOOKModule(nn.Module): else: HOOKModule.module_count[self.prefix] += 1 self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP - forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix) - self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) - self.register_forward_hook(forward_hook, with_kwargs=True) + forward_pre_hook, forward_hook, backward_hook, _ = build_hook(self.prefix) + if torch_version_above_or_equal_2: + self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + self.register_forward_hook(forward_hook, with_kwargs=True) + else: + self.register_forward_pre_hook(forward_pre_hook) + 
self.register_forward_hook(forward_hook) self.register_backward_hook(backward_hook) def __call__(self, *input, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py index 6cf425441cc381652ddca4b203ac7a2b4161a116..3ca1db0f507356db095ef040099524d9937dab90 100644 --- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_distributed.py @@ -57,7 +57,12 @@ class DistributedOPTemplate(HOOKModule): @torch_device_guard def forward(self, *args, **kwargs): - return distributed_func.get(self.op_name_)(*args, **kwargs) + if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: + handle = distributed_func.get(self.op_name_)(*args, **kwargs) + handle.wait() + return handle + else: + return distributed_func.get(self.op_name_)(*args, **kwargs) def wrap_distributed_op(op_name, hook): diff --git a/debug/accuracy_tools/msprobe/pytorch/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/module_processer.py index 3e9969d32d9147a6f26b7a6a4219364368a6116b..e6d2125e421ba07547d3e32cadd13981d89b5763 100644 --- a/debug/accuracy_tools/msprobe/pytorch/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/module_processer.py @@ -5,6 +5,7 @@ from torch.utils.hooks import BackwardHook from msprobe.core.common.const import Const from msprobe.core.data_dump.scope import ModuleRangeScope +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' class ModuleProcesser: @@ -109,7 +110,29 @@ class ModuleProcesser: if self.scope: self.scope.end_module(module.mindstudio_reserved_name) - if Const.START in start_or_stop: - return pre_hook + def backward_hook(module, input, output=None): + try: + index = ModuleProcesser.module_count_func(name_prefix) + except IndexError as e: + index = None + pass + module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index) + 
forward_full_name = full_name.replace(Const.BACKWARD, Const.FORWARD) + ModuleProcesser.module_node[full_name] = ModuleProcesser.module_node[forward_full_name].replace( + Const.FORWARD, Const.BACKWARD) if ModuleProcesser.module_node[forward_full_name] else None + ModuleProcesser.api_parent_node = None + if self.scope: + self.scope.begin_module(full_name) + + if torch_version_above_or_equal_2: + if Const.START in start_or_stop: + return pre_hook + else: + return end_hook else: - return end_hook + if Const.FORWARD in name_prefix and Const.START in start_or_stop: + return pre_hook + elif Const.BACKWARD in name_prefix: + return backward_hook + else: + return end_hook diff --git a/debug/accuracy_tools/msprobe/pytorch/pt_config.py b/debug/accuracy_tools/msprobe/pytorch/pt_config.py index daba5476cab5ca0844f646e56d2b4d6ec6a2f1b4..68225f8a820e06afc235dacbf76f085bbb953a0c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/pt_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/pt_config.py @@ -5,18 +5,35 @@ from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.core.common.file_check import FileOpen from msprobe.core.common.const import Const from msprobe.pytorch.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps +from msprobe.core.grad_probe.constant import level_adp +from msprobe.core.grad_probe.utils import check_numeral_list_ascend class TensorConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) + self.online_run_ut = json_config.get("online_run_ut", False) + self.nfs_path = json_config.get("nfs_path", "") + self.host = json_config.get("host", "") + self.port = json_config.get("port", -1) + self.tls_path = json_config.get("tls_path", "") self.check_config() self._check_file_format() + self._check_tls_path_config() def _check_file_format(self): if self.file_format is not None and self.file_format not in ["npy", "bin"]: raise Exception("file_format is invalid") + def _check_tls_path_config(self): + if 
self.tls_path: + if not os.path.exists(self.tls_path): + raise Exception("tls_path: %s does not exist" % self.tls_path) + if not os.path.exists(os.path.join(self.tls_path, "client.key")): + raise Exception("tls_path does not contain client.key") + if not os.path.exists(os.path.join(self.tls_path, "client.crt")): + raise Exception("tls_path does not contain client.crt") + class StatisticsConfig(BaseConfig): def __init__(self, json_config): @@ -65,11 +82,18 @@ class FreeBenchmarkCheckConfig(BaseConfig): class RunUTConfig(BaseConfig): WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) + def __init__(self, json_config): super().__init__(json_config) self.white_list = json_config.get("white_list", Const.DEFAULT_LIST) self.black_list = json_config.get("black_list", Const.DEFAULT_LIST) self.error_data_path = json_config.get("error_data_path", Const.DEFAULT_PATH) + self.is_online = json_config.get("is_online", False) + self.nfs_path = json_config.get("nfs_path", "") + self.host = json_config.get("host", "") + self.port = json_config.get("port", -1) + self.rank_list = json_config.get("rank_list", Const.DEFAULT_LIST) + self.tls_path = json_config.get("tls_path", "") self.check_run_ut_config() @classmethod @@ -86,17 +110,43 @@ class RunUTConfig(BaseConfig): def check_error_data_path_config(cls, error_data_path): if not os.path.exists(error_data_path): raise Exception("error_data_path: %s does not exist" % error_data_path) - + + @classmethod + def check_nfs_path_config(cls, nfs_path): + if nfs_path and not os.path.exists(nfs_path): + raise Exception("nfs_path: %s does not exist" % nfs_path) + + @classmethod + def check_tls_path_config(cls, tls_path): + if tls_path: + if not os.path.exists(tls_path): + raise Exception("tls_path: %s does not exist" % tls_path) + if not os.path.exists(os.path.join(tls_path, "server.key")): + raise Exception("tls_path does not contain server.key") + if not os.path.exists(os.path.join(tls_path, "server.crt")): + raise 
Exception("tls_path does not contain server.crt") + def check_run_ut_config(self): RunUTConfig.check_filter_list_config(Const.WHITE_LIST, self.white_list) RunUTConfig.check_filter_list_config(Const.BLACK_LIST, self.black_list) RunUTConfig.check_error_data_path_config(self.error_data_path) + RunUTConfig.check_nfs_path_config(self.nfs_path) + RunUTConfig.check_tls_path_config(self.tls_path) class GradToolConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) - self._config = json_config + self.grad_level = json_config.get("grad_level", "L1") + self.param_list = json_config.get("param_list", []) + self.bounds = json_config.get("bounds", []) + + def _check_config(self): + if self.grad_level not in level_adp.keys(): + raise Exception(f"grad_level must be one of {level_adp.keys()}") + if not isinstance(self.param_list, list): + raise Exception(f"param_list must be a list") + check_numeral_list_ascend(self.bounds) def parse_task_config(task, json_config): diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 6b8d67abc9fa1f32e1353ae23485393302b66628..980c7d840cae208aae5c137d6a93e89279b580ca 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -1,7 +1,10 @@ import functools import os +import time from pathlib import Path +from collections import namedtuple +import torch from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.exceptions import DistributedNotInitializedError, MsprobeException from msprobe.core.common.file_check import FileChecker, check_path_before_create @@ -14,6 +17,10 @@ from msprobe.pytorch.hook_module import remove_dropout from msprobe.pytorch.hook_module.api_registry import api_register from msprobe.pytorch.hook_module.hook_module import HOOKModule from msprobe.pytorch.module_processer import ModuleProcesser +from 
msprobe.pytorch.api_accuracy_checker.tensor_transport_layer.attl import ATTLConfig, ATTL, ApiData +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' + +HookFn = namedtuple('hookFn', ['pre_hook', 'forward_hook', 'backward_hook', 'forward_hook_torch_version_below_2']) class Service: @@ -27,6 +34,7 @@ class Service: self.first_start = True self.current_rank = None self.dump_iter_dir = None + self.attl = None @staticmethod def forward_backward_dump_end(): @@ -41,6 +49,8 @@ class Service: if not self.switch: return args, kwargs + if self.config.online_run_ut: + return None, None if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output) @@ -53,6 +63,14 @@ class Service: if not self.switch: return None + + if self.config.online_run_ut: + if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name): + return None + api_data = ApiData(name[:-1], args, kwargs, output, self.current_iter, self.current_rank) + self.attl_send(api_data) + return None + if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output) @@ -60,6 +78,9 @@ class Service: return self.data_collector.get_forward_new_output() return output + def forward_hook_torch_version_below_2(api_or_module_name, module, args, output): + return forward_hook(api_or_module_name, module, args, {}, output) + def backward_hook(api_or_module_name, module, grad_input, grad_output): if module_type == BaseScope.Module_Type_Module: api_or_module_name = module.mindstudio_reserved_name @@ -67,6 +88,14 @@ class Service: if not self.switch: return + + if self.config.online_run_ut: + if self.data_collector.scope and not self.data_collector.scope.check(api_or_module_name): + return + api_data 
= ApiData(name[:-1], grad_input, {}, grad_output, self.current_iter, self.current_rank) + self.attl_send(api_data) + return + if self.data_collector: # 此处获取到的grad_input实际为反向过程的输出数据,grad_output为反向过程的输入数据,因此传入时调换顺序 module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) @@ -75,10 +104,11 @@ class Service: pid = os.getpid() forward_name_template = name + Const.FORWARD backward_name_template = name + Const.BACKWARD - pre_forward_hook = functools.partial(pre_hook, forward_name_template) - forward_hook = functools.partial(forward_hook, forward_name_template) - backward_hook = functools.partial(backward_hook, backward_name_template) - return pre_forward_hook, forward_hook, backward_hook + pre_forward_hook_fn = functools.partial(pre_hook, forward_name_template) + forward_hook_fn = functools.partial(forward_hook, forward_name_template) + backward_hook_fn = functools.partial(backward_hook, backward_name_template) + forward_hook_torch_version_below_2_fn = functools.partial(forward_hook_torch_version_below_2, forward_name_template) + return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) def step(self): self.current_iter += 1 @@ -90,6 +120,9 @@ class Service: def start(self, model, api_origin=False): self.model = model if self.config.step and self.current_iter > max(self.config.step): + if self.config.online_run_ut: + # send stop signal if online_run_ut + self.attl_stop() self.stop() raise Exception("msprobe: exit after iteration {}".format(max(self.config.step))) if self.config.step and self.current_iter not in self.config.step: @@ -99,6 +132,7 @@ class Service: self.current_rank = get_rank_if_initialized() except DistributedNotInitializedError: self.current_rank = None + self.attl_init() if self.config.rank and self.current_rank not in self.config.rank: return @@ -108,7 +142,7 @@ class Service: api_register.api_modularity() self.switch = True logger.info_on_rank_0(f"Dump switch is 
turned on at step {self.current_iter}. ") - if self.config.level != "L2": + if self.config.level != "L2" and not self.config.online_run_ut: self.create_dirs() logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") @@ -120,6 +154,8 @@ class Service: if self.config.rank and self.current_rank not in self.config.rank: return self.switch = False + if self.config.online_run_ut: + return self.data_collector.write_json() def create_dirs(self): @@ -158,22 +194,59 @@ class Service: prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \ module.__class__.__name__ + Const.SEP - pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix) - module.register_forward_hook(forward_hook, with_kwargs=True) + pre_forward_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 \ + = self.build_hook(BaseScope.Module_Type_Module, prefix) + if torch_version_above_or_equal_2: + module.register_forward_hook(forward_hook, with_kwargs=True) + else: + module.register_full_backward_hook( + self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) + module.register_forward_hook(forward_hook_torch_version_below_2) module.register_full_backward_hook(backward_hook) module.register_forward_pre_hook( self.module_processor.node_hook(prefix + Const.FORWARD, Const.START)) module.register_forward_hook( self.module_processor.node_hook(prefix + Const.FORWARD, Const.STOP)) - module.register_full_backward_pre_hook( - self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START)) - module.register_full_backward_hook( - self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) + if torch_version_above_or_equal_2: + module.register_full_backward_pre_hook( + self.module_processor.node_hook(prefix + Const.BACKWARD, Const.START)) + module.register_full_backward_hook( + self.module_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) if self.config.level in ["mix", "L1", "L2"]: 
api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) api_register.api_modularity() if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task: - remove_dropout() \ No newline at end of file + remove_dropout() + + def attl_init(self): + if self.config.online_run_ut: + attl_config = ATTLConfig(is_benchmark_device=False, + connect_ip=self.config.host, + connect_port=self.config.port, + nfs_path=self.config.nfs_path, + tls_path=self.config.tls_path) + need_dump = len(self.config.rank) == 0 or self.current_rank in self.config.rank + self.attl = ATTL('npu', attl_config, need_dump=need_dump) + if self.config.nfs_path: + self.attl.upload("start") + + def attl_send(self, api_data): + logger.info(f"tools is dumping api: {api_data.name}, rank: {self.current_rank}") + api_type, _, _ = api_data.name.split(Const.SEP) + if api_type in [Const.DISTRIBUTED]: + logger.info(f"api {api_data.name} is not supported, skip") + return + if self.config.nfs_path: + self.attl.upload(api_data) + else: + self.attl.send(api_data) + + def attl_stop(self): + if self.config.nfs_path: + self.attl.upload("end") + elif self.attl.socket_manager is not None: + logger.info(f"pid: {os.getpid()} finished, start send STOP signal.") + self.attl.socket_manager.send_stop_signal() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py index edd3eb53dccf453f1d3efde7189dfadcd6dee000..a02a402f6e27c0464c1e748df06a55cc5e1a3b3f 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/common/test_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. 
# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -23,24 +23,24 @@ from unittest.mock import patch, MagicMock, mock_open from msprobe.core.common.log import logger from msprobe.core.common.const import Const from msprobe.core.common.utils import (CompareException, - check_seed_all, - check_inplace_op, - make_dump_path_if_not_exists, - check_mode_valid, - check_switch_valid, - check_dump_mode_valid, - check_summary_mode_valid, - check_summary_only_valid, - check_file_or_directory_path, - check_compare_param, - check_configuration_param, - is_starts_with, - _check_json, - check_json_file, - check_file_size, - check_regex_prefix_format_valid, - get_dump_data_path, - task_dumppath_get) + check_seed_all, + check_inplace_op, + make_dump_path_if_not_exists, + check_mode_valid, + check_switch_valid, + check_dump_mode_valid, + check_summary_mode_valid, + check_summary_only_valid, + check_file_or_directory_path, + check_compare_param, + check_configuration_param, + is_starts_with, + _check_json, + check_json_file, + check_file_size, + check_regex_prefix_format_valid, + get_dump_data_path, + task_dumppath_get) from msprobe.core.common.file_check import FileCheckConst @@ -189,28 +189,28 @@ class TestUtils(TestCase): @patch.object(logger, "error") def test_check_compare_param(self, mock_error): params = { - "npu_json_path": "npu_json_path", - "bench_json_path": "bench_json_path", - "stack_json_path": "stack_json_path", + "npu_path": "npu_path", + "bench_path": "bench_path", + "stack_path": "stack_path", "npu_dump_data_dir": "npu_dump_data_dir", "bench_dump_data_dir": "bench_dump_data_dir" } call_args = [ - ("npu_json_path", False), - ("bench_json_path", False), - ("stack_json_path", False), + ("npu_path", False), + ("bench_path", False), + ("stack_path", False), ("npu_dump_data_dir", True), ("bench_dump_data_dir", True), ("output_path", True), - 
("npu_json_path", False), - ("bench_json_path", False), - ("stack_json_path", False), + ("npu_path", False), + ("bench_path", False), + ("stack_path", False), ("output_path", True) ] with self.assertRaises(CompareException) as context: - check_compare_param("npu_json_path", "output_path") + check_compare_param("npu_path", "output_path") self.assertEqual(context.exception.code, CompareException.INVALID_PARAM_ERROR) mock_error.assert_called_with("Invalid input parameters") @@ -264,14 +264,14 @@ class TestUtils(TestCase): @patch("msprobe.core.common.utils._check_json") def test_check_json_file(self, _mock_check_json): input_param = { - "npu_json_path": "npu_json_path", - "bench_json_path": "bench_json_path", - "stack_json_path": "stack_json_path" + "npu_path": "npu_path", + "bench_path": "bench_path", + "stack_path": "stack_path" } check_json_file(input_param, "npu_json", "bench_json", "stack_json") - self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_json_path")) - self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_json_path")) - self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_json_path")) + self.assertEqual(_mock_check_json.call_args_list[0][0], ("npu_json", "npu_path")) + self.assertEqual(_mock_check_json.call_args_list[1][0], ("bench_json", "bench_path")) + self.assertEqual(_mock_check_json.call_args_list[2][0], ("stack_json", "stack_path")) @patch.object(logger, "error") def test_check_file_size(self, mock_error): @@ -307,8 +307,8 @@ class TestUtils(TestCase): @patch.object(logger, "error") def test_task_dumppath_get(self, mock_error): input_param = { - "npu_json_path": None, - "bench_json_path": "bench_json_path" + "npu_path": None, + "bench_path": "bench_path" } npu_json = { "task": Const.TENSOR, @@ -321,7 +321,7 @@ class TestUtils(TestCase): self.assertEqual(context.exception.code, CompareException.INVALID_PATH_ERROR) mock_error.assert_called_with("Please check the json path is 
valid.") - input_param["npu_json_path"] = "npu_json_path" + input_param["npu_path"] = "npu_path" with patch("msprobe.core.common.utils.FileOpen", mock_open(read_data="")), \ patch("msprobe.core.common.utils.json.load", return_value=npu_json): summary_compare, md5_compare = task_dumppath_get(input_param) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_factory.py similarity index 38% rename from debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py rename to debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_factory.py index 5187d3951c0cbf2bdeb5db6f402933c8bf08e94d..2f4f25300143392e3f35a34d1d320c935bcecdf3 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_debugger_config.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_factory.py @@ -14,29 +14,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" -from unittest import TestCase +import unittest +from unittest.mock import patch +from msprobe.core.data_dump.data_processor.factory import DataProcessorFactory from msprobe.core.common.const import Const -from msprobe.core.common_config import CommonConfig, BaseConfig -from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from msprobe.core.data_dump.data_processor.mindspore_processor import ( + StatisticsDataProcessor as MindsporeStatisticsDataProcessor, + TensorDataProcessor as MindsporeTensorDataProcessor, + OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor +) -class TestDebuggerConfig(TestCase): - def test_init(self): - json_config = { - "dump_path": "/absolute_path", - "rank": [], - "step": [], - "level": "L0" - } - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - debugger_config = DebuggerConfig(common_config, task_config) - self.assertEqual(debugger_config.task, Const.STATISTICS) - self.assertEqual(debugger_config.file_format, "npy") - self.assertEqual(debugger_config.check_mode, "all") - - common_config.dump_path = "./path" - with self.assertRaises(Exception) as context: - DebuggerConfig(common_config, task_config) - self.assertEqual(str(context.exception), "Dump path must be absolute path.") +class TestDataProcessorFactory(unittest.TestCase): + def test_register_processors(self): + with patch.object(DataProcessorFactory, "register_processor") as mock_register: + DataProcessorFactory.register_processors(Const.MS_FRAMEWORK) + self.assertEqual(mock_register.call_args_list[0][0], + (Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor)) + self.assertEqual(mock_register.call_args_list[1][0], + (Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor)) + self.assertEqual(mock_register.call_args_list[2][0], + (Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor)) diff --git 
a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..046388741b7d535342bf3e8ded69aa3617bd761e --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_mindspore_processor.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import unittest +from unittest.mock import patch + +from mindspore import Tensor +import numpy as np + +from msprobe.core.data_dump.data_processor.base import BaseDataProcessor +from msprobe.core.data_dump.data_processor.mindspore_processor import MindsporeDataProcessor, OverflowCheckDataProcessor +from msprobe.core.common.const import FileCheckConst + + +class TestOverflowCheckDataProcessor(unittest.TestCase): + def setUp(self): + class Config: + def __init__(self): + self.overflow_nums = 1 + self.data_processor = OverflowCheckDataProcessor(Config(), None) + + def test___init__(self): + self.assertEqual(self.data_processor.cached_tensors_and_file_paths, {}) + self.assertEqual(self.data_processor.real_overflow_nums, 0) + self.assertEqual(self.data_processor.overflow_nums, 1) + + def test_analyze_forward(self): + def func(_): + self.data_processor.has_overflow = True + with patch.object(BaseDataProcessor, "analyze_forward", return_value={"min", 0}): + with patch.object(OverflowCheckDataProcessor, "maybe_save_overflow_data"): + api_info = self.data_processor.analyze_forward("name", "module", "module_input_output") + self.assertFalse(self.data_processor.has_overflow) + self.assertIsNone(api_info) + with patch.object(OverflowCheckDataProcessor, "maybe_save_overflow_data", new=func): + api_info = self.data_processor.analyze_forward("name", "module", "module_input_output") + self.assertTrue(self.data_processor.has_overflow) + self.assertEqual(api_info, {"min", 0}) + + def test_analyze_backward(self): + def func(_): + self.data_processor.has_overflow = True + with patch.object(BaseDataProcessor, "analyze_backward", return_value={"min", 0}): + with patch.object(OverflowCheckDataProcessor, "maybe_save_overflow_data"): + api_info = self.data_processor.analyze_backward("name", "module", "module_input_output") + self.assertFalse(self.data_processor.has_overflow) + self.assertIsNone(api_info) + with patch.object(OverflowCheckDataProcessor, "maybe_save_overflow_data", 
new=func): + api_info = self.data_processor.analyze_backward("name", "module", "module_input_output") + self.assertTrue(self.data_processor.has_overflow) + self.assertEqual(api_info, {"min", 0}) + + @patch("msprobe.core.data_dump.data_processor.mindspore_processor.np.save") + @patch("msprobe.core.data_dump.data_processor.mindspore_processor.change_mode") + def test_maybe_save_overflow_data(self, mock_change_mode, mock_save): + self.data_processor.has_overflow = True + tensor1 = Tensor(1) + tensor2 = Tensor(2) + self.data_processor.cached_tensors_and_file_paths = {"tensor1": tensor1, "tensor2": tensor2} + with patch("mindspore.Tensor.asnumpy", return_value="npy"): + self.data_processor.maybe_save_overflow_data() + self.assertEqual(mock_save.call_args_list[0][0], + ("tensor1", "npy")) + self.assertEqual(mock_save.call_args_list[1][0], + ("tensor2", "npy")) + self.assertEqual(mock_change_mode.call_args_list[0][0], + ("tensor1", FileCheckConst.DATA_FILE_AUTHORITY)) + self.assertEqual(mock_change_mode.call_args_list[1][0], + ("tensor2", FileCheckConst.DATA_FILE_AUTHORITY)) + + @patch("msprobe.core.data_dump.data_processor.mindspore_processor.logger.info") + def test_is_terminated(self, mock_info): + self.data_processor.overflow_nums = -1 + self.assertFalse(self.data_processor.is_terminated) + self.data_processor.real_overflow_nums = 2 + self.data_processor.overflow_nums = 2 + self.assertTrue(self.data_processor.is_terminated) + mock_info.assert_called_with("[msprobe] 超过预设溢出次数 当前溢出次数: 2") + self.data_processor.overflow_nums = 3 + self.assertFalse(self.data_processor.is_terminated) + + def test__analyze_maybe_overflow_tensor(self): + self.data_processor.has_overflow = False + tensor_json = {"Max": None, "Min": 0} + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertFalse(self.data_processor.has_overflow) + tensor_json.update({"Max": -np.inf}) + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + 
self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False + tensor_json.update({"Max": np.inf}) + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False + tensor_json.update({"Max": np.nan}) + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertTrue(self.data_processor.has_overflow) + tensor_json.update({"Max": 0}) + self.data_processor.has_overflow = False + tensor_json.update({"Min": -np.inf}) + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False + tensor_json.update({"Min": np.inf}) + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertTrue(self.data_processor.has_overflow) + self.data_processor.has_overflow = False + tensor_json.update({"Min": np.nan}) + self.data_processor._analyze_maybe_overflow_tensor(tensor_json) + self.assertTrue(self.data_processor.has_overflow) + + @patch("msprobe.core.data_dump.data_processor.mindspore_processor.logger.warning") + @patch.object(OverflowCheckDataProcessor, "get_save_file_path") + @patch.object(MindsporeDataProcessor, "_analyze_tensor") + def test__analyze_tensor(self, mock_super, mock_get_file_path, mock_warning): + mock_get_file_path.return_value = ("dump_data_name", "file_path") + single_arg = {"Max": None} + mock_super.return_value = single_arg + + with patch("msprobe.core.data_dump.data_processor.mindspore_processor.path_len_exceeds_limit", + return_value=False): + ret = self.data_processor._analyze_tensor("tensor", "suffix") + self.assertEqual(self.data_processor.cached_tensors_and_file_paths, {"file_path": "tensor"}) + mock_warning.assert_not_called() + mock_super.assert_called_with("tensor", "suffix") + self.assertEqual(ret.get("Max"), None) + self.assertEqual(ret.get("data_name"), "dump_data_name") + + with 
patch("msprobe.core.data_dump.data_processor.mindspore_processor.path_len_exceeds_limit", + return_value=True): + self.data_processor._analyze_tensor("tensor", "suffix") + mock_warning.assert_called_with("The file path file_path length exceeds limit.") diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py index eedbe5be7e0360d7874439357419510cbde73b71..15a0883f5b7b15efa5423280e94dd0c4200d965d 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/test_data_collector.py @@ -1,3 +1,20 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + import unittest from unittest.mock import patch, mock_open, MagicMock @@ -5,6 +22,9 @@ from msprobe.core.common.utils import Const from msprobe.core.data_dump.data_collector import DataCollector from msprobe.pytorch.debugger.debugger_config import DebuggerConfig from msprobe.pytorch.pt_config import parse_json_config +from msprobe.core.data_dump.json_writer import DataWriter +from msprobe.core.data_dump.data_processor.base import BaseDataProcessor +from msprobe.core.data_dump.data_processor.pytorch_processor import StatisticsDataProcessor class TestDataCollector(unittest.TestCase): @@ -45,3 +65,54 @@ class TestDataCollector(unittest.TestCase): self.data_collector.pre_forward_data_collect(name, None, pid, None) self.data_collector.check_scope_and_pid.assert_called_once_with( self.data_collector.scope, "TestModule.backward", 123) + + def test_handle_data(self): + with patch.object(DataCollector, "update_data", return_value="msg") as mock_update_data, \ + patch.object(DataCollector, "write_json") as mock_write_json, \ + patch("msprobe.core.data_dump.data_collector.logger.info") as mock_info, \ + patch("msprobe.core.data_dump.json_writer.DataWriter.flush_data_when_buffer_is_full") as mock_flush: + self.data_collector.handle_data("Tensor.add", {"min": 0}) + msg = "msprobe is collecting data on Tensor.add. 
" + mock_update_data.assert_called_with({"min": 0}, msg) + mock_info.assert_called_with("msg") + mock_flush.assert_called() + mock_write_json.assert_not_called() + + mock_update_data.reset_mock() + mock_info.reset_mock() + mock_flush.reset_mock() + self.data_collector.handle_data("Tensor.add", {}, use_buffer=False) + mock_update_data.assert_not_called() + mock_info.assert_not_called() + mock_write_json.assert_called() + + @patch.object(DataCollector, "update_construct") + @patch.object(DataWriter, "update_stack") + @patch.object(BaseDataProcessor, "analyze_api_call_stack") + @patch.object(DataCollector, "handle_data") + def test_forward_data_collect(self, mock_handle_data, _, __, ___): + with patch.object(DataCollector, "check_scope_and_pid", return_value=True), \ + patch.object(DataCollector, "is_inplace", return_value=False), \ + patch.object(StatisticsDataProcessor, "analyze_forward", return_value={}): + with patch.object(StatisticsDataProcessor, "is_terminated", return_value=True), \ + self.assertRaises(Exception) as context: + self.data_collector.forward_data_collect("name", "module", "pid", "module_input_output") + mock_handle_data.assert_called_with("name", {}, use_buffer=False) + self.assertEqual(str(context.exception), "[msprobe] exit") + + self.data_collector.forward_data_collect("name", "module", "pid", "module_input_output") + mock_handle_data.assert_called_with("name", {}) + + @patch.object(DataCollector, "update_construct") + @patch.object(DataCollector, "handle_data") + def test_backward_data_collect(self, mock_handle_data, _): + with patch.object(DataCollector, "check_scope_and_pid", return_value=True), \ + patch.object(StatisticsDataProcessor, "analyze_backward", return_value={}): + with patch.object(StatisticsDataProcessor, "is_terminated", return_value=True), \ + self.assertRaises(Exception) as context: + self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") + mock_handle_data.assert_called_with("name", {}, 
use_buffer=False) + self.assertEqual(str(context.exception), "[msprobe] exit") + + self.data_collector.backward_data_collect("name", "module", "pid", "module_input_output") + mock_handle_data.assert_called_with("name", {}) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..96f2daf2033434a711e11c7f7a614c4d0cc45f74 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/common/test_ms_utils.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import unittest + +from msprobe.mindspore.common.utils import MsprobeStep + + +class TestMsprobeStep(unittest.TestCase): + def setUp(self): + class Debugger: + def __init__(self): + self.start_called = False + self.stop_called = False + self.step_called = False + self.stop_called_first = False + + def start(self): + self.start_called = True + + def stop(self): + self.stop_called = True + + def step(self): + if self.stop_called: + self.stop_called_first = True + self.step_called = True + debugger = Debugger() + self.msprobe_step = MsprobeStep(debugger) + + def test_on_train_step_begin(self): + self.msprobe_step.on_train_step_begin("run_context") + self.assertTrue(self.msprobe_step.debugger.start_called) + self.assertFalse(self.msprobe_step.debugger.stop_called) + self.assertFalse(self.msprobe_step.debugger.step_called) + + def test_on_train_step_end(self): + self.msprobe_step.on_train_step_end("run_context") + self.assertFalse(self.msprobe_step.debugger.start_called) + self.assertTrue(self.msprobe_step.debugger.stop_called) + self.assertTrue(self.msprobe_step.debugger.step_called) + self.assertTrue(self.msprobe_step.debugger.stop_called_first) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_debugger_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_debugger_config.py new file mode 100644 index 0000000000000000000000000000000000000000..9806632370248e991e98314d9ab650af28a96359 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_debugger_config.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import unittest +from unittest.mock import patch + +from msprobe.core.common.const import Const, FileCheckConst +from msprobe.mindspore.common.const import FreeBenchmarkConst +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig + + +class TestDebuggerConfig(unittest.TestCase): + @patch.object(DebuggerConfig, "_make_dump_path_if_not_exists") + def test_init(self, _): + json_config = { + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L2" + } + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + debugger_config = DebuggerConfig(common_config, task_config) + self.assertEqual(debugger_config.task, Const.STATISTICS) + self.assertEqual(debugger_config.file_format, "npy") + self.assertEqual(debugger_config.check_mode, "all") + self.assertEqual(debugger_config.overflow_nums, 1) + + common_config.dump_path = "./path" + with self.assertRaises(Exception) as context: + DebuggerConfig(common_config, task_config) + self.assertEqual(str(context.exception), "Dump path must be absolute path.") + + common_config.dump_path = "./path" + with self.assertRaises(Exception) as context: + DebuggerConfig(common_config, task_config) + self.assertEqual(str(context.exception), "Dump path must be absolute path.") + + common_config.level = "L1" + common_config.task = Const.FREE_BENCHMARK + debugger_config = DebuggerConfig(common_config, task_config) + self.assertEqual(debugger_config.pert_type, FreeBenchmarkConst.DEFAULT_PERT_TYPE) + 
self.assertEqual(debugger_config.handler_type, FreeBenchmarkConst.DEFAULT_HANDLER_TYPE) + self.assertEqual(debugger_config.dump_level, FreeBenchmarkConst.DEFAULT_DUMP_LEVEL) + self.assertEqual(debugger_config.stage, FreeBenchmarkConst.DEFAULT_STAGE) + + task_config.handler_type = FreeBenchmarkConst.FIX + task_config.pert_mode = FreeBenchmarkConst.ADD_NOISE + with self.assertRaises(Exception) as context: + DebuggerConfig(common_config, task_config) + self.assertEqual(str(context.exception), + "pert_mode must be improve_precision or empty when handler_type is fix, " + f"but got {FreeBenchmarkConst.ADD_NOISE}.") + + @patch("msprobe.mindspore.debugger.debugger_config.os.path.exists", return_value=False) + def test__make_dump_path_if_not_exists(self, _): + json_config = {"dump_path": "/absolute_path"} + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + with patch("msprobe.mindspore.debugger.debugger_config.check_path_before_create") as mock_check_path, \ + patch("msprobe.mindspore.debugger.debugger_config.Path.mkdir") as mock_mkdir, \ + patch("msprobe.mindspore.debugger.debugger_config.FileChecker") as mock_checker: + DebuggerConfig(common_config, task_config) + mock_check_path.assert_called_with(json_config.get("dump_path")) + mock_mkdir.assert_called_with(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) + mock_checker.assert_called_with(common_config.dump_path, FileCheckConst.DIR) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_precision_debugger.py similarity index 37% rename from debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_precision_debugger.py index 425ed3040dcc829927a8f4cbb25024f1b567a48f..ee9970f510b7f7c8f30294adf7fa4e516c7184c1 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_precision_debugger.py +++ 
b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_precision_debugger.py @@ -14,16 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -from unittest import TestCase -from unittest.mock import patch +import unittest +from unittest.mock import patch, MagicMock from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.debugger.precision_debugger import PrecisionDebugger +from msprobe.mindspore.runtime import Runtime +from msprobe.mindspore.common.const import Const as MsConst +from msprobe.core.common.const import Const -class TestPrecisionDebugger(TestCase): - def test_start(self): +class TestPrecisionDebugger(unittest.TestCase): + + @patch.object(DebuggerConfig, "_make_dump_path_if_not_exists") + def test_start(self, _): class Handler: called = False @@ -35,22 +40,68 @@ class TestPrecisionDebugger(TestCase): "dump_path": "/absolute_path", "rank": [], "step": [], - "level": "L0" + "level": "L1" } common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) handler = Handler() - with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", - return_value=[common_config, task_config]), \ + mock_get_mode = MagicMock() + mock_parse_json_config = MagicMock() + with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", new=mock_parse_json_config), \ + patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler): + mock_get_mode.return_value = MsConst.GRAPH_GE_MODE + mock_parse_json_config.return_value = [common_config, task_config] debugger = PrecisionDebugger() + self.assertEqual(Runtime.step_count, 0) + self.assertFalse(Runtime.is_running) debugger.start() - self.assertTrue(isinstance(debugger.config, DebuggerConfig)) - 
self.assertTrue(Handler.called) + self.assertTrue(Runtime.is_running) + self.assertTrue(isinstance(debugger.config, DebuggerConfig)) + self.assertTrue(Handler.called) + + mock_get_mode.return_value = MsConst.PYNATIVE_MODE + with patch("msprobe.mindspore.debugger.precision_debugger.Service") as mock_Service: + debugger = PrecisionDebugger() + debugger.start() + service = mock_Service.return_value + mock_Service.assert_called_with(debugger.config) + service.start.assert_called_with(None) PrecisionDebugger._instance = None with self.assertRaises(Exception) as context: debugger.start() self.assertEqual(str(context.exception), "No instance of PrecisionDebugger found.") + + with patch("msprobe.mindspore.debugger.precision_debugger.parse_json_config", new=mock_parse_json_config), \ + patch.object(PrecisionDebugger, "_get_execution_mode", new=mock_get_mode), \ + patch("msprobe.mindspore.debugger.precision_debugger.TaskHandlerFactory.create", return_value=handler): + common_config.task = Const.FREE_BENCHMARK + mock_get_mode.return_value = MsConst.PYNATIVE_MODE + mock_parse_json_config.return_value = [common_config, task_config] + Handler.called = False + debugger = PrecisionDebugger() + debugger.start() + self.assertTrue(Handler.called) + + def test_stop_step(self): + class MockPrecisionDebugger: + def __init__(self): + self.task = Const.TENSOR + self.service = None + PrecisionDebugger._instance = None + with self.assertRaises(Exception) as context: + PrecisionDebugger.stop() + self.assertEqual(str(context.exception), "PrecisionDebugger instance is not created.") + with self.assertRaises(Exception) as context: + PrecisionDebugger.step() + self.assertEqual(str(context.exception), "PrecisionDebugger instance is not created.") + PrecisionDebugger._instance = MockPrecisionDebugger() + Runtime.is_running = True + PrecisionDebugger.stop() + self.assertFalse(Runtime.is_running) + Runtime.step_count = 0 + PrecisionDebugger.step() + self.assertEqual(Runtime.step_count, 1) diff --git 
a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py index 30212d95e621bea516b888e7e61d042990a2c93a..4954acc116c2a864e3f36b8a35cb8a35c68615e4 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_config.py @@ -14,15 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -from unittest import TestCase +import unittest from unittest.mock import patch, mock_open from msprobe.core.common.const import Const from msprobe.mindspore.ms_config import (parse_json_config, parse_task_config, - TensorConfig, StatisticsConfig, OverflowCheckConfig) + TensorConfig, StatisticsConfig, OverflowCheckConfig, FreeBenchmarkConfig) -class TestMsConfig(TestCase): +class TestMsConfig(unittest.TestCase): def test_parse_json_config(self): mock_json_data = { "dump_path": "./dump/", @@ -64,6 +64,25 @@ class TestMsConfig(TestCase): task_config = parse_task_config("overflow_check", mock_json_config) self.assertTrue(isinstance(task_config, OverflowCheckConfig)) + mock_json_config.update({"overflow_check": {"overflow_nums": "1"}}) with self.assertRaises(Exception) as context: - parse_task_config("free_benchmark", mock_json_config) + task_config = parse_task_config("overflow_check", mock_json_config) + self.assertEqual(str(context.exception), "overflow_nums is invalid, it should be an integer") + + mock_json_config.update({"overflow_check": {"overflow_nums": 0}}) + with self.assertRaises(Exception) as context: + task_config = parse_task_config("overflow_check", mock_json_config) + self.assertEqual(str(context.exception), "overflow_nums should be -1 or positive integer") + + mock_json_config.update({"overflow_check": {"overflow_nums": 1}}) + mock_json_config.update({"overflow_check": {"check_mode": "core"}}) + with self.assertRaises(Exception) as context: + task_config = 
parse_task_config("overflow_check", mock_json_config) + self.assertEqual(str(context.exception), "check_mode is invalid") + + task_config = parse_task_config("free_benchmark", mock_json_config) + self.assertTrue(isinstance(task_config, FreeBenchmarkConfig)) + + with self.assertRaises(Exception) as context: + parse_task_config("unsupported_task", mock_json_config) self.assertEqual(str(context.exception), "task is invalid.") diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..25189a9b65da209ffc212493dfe3a3b6a8b98a62 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import os + +import unittest +from unittest.mock import Mock, patch +import copy +from msprobe.core.common.utils import Const +from msprobe.mindspore.service import Service +import mindspore +from mindspore.common.tensor import Tensor +from mindspore import ops +from mindspore import nn +from msprobe.core.common.exceptions import MsprobeException +from msprobe.core.common_config import CommonConfig, BaseConfig +from msprobe.mindspore.debugger.debugger_config import DebuggerConfig +from unittest.mock import MagicMock +import numpy as np + + +class DummyModel(nn.Cell): + def __init__(self): + super(DummyModel, self).__init__() + self.dense = nn.Dense(2, 2) + + def construct(self, x): + return self.dense(x) +class TestService(unittest.TestCase): + def setUp(self): + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [0, 2], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + self.service = Service(config) + self.service.model = Mock() + self.service.data_collector = Mock() + self.service.switch = True # Make sure the switch is on for testing + + def test_check_model_valid_none(self): + model = None + self.assertIsNone(self.service.check_model_valid(model)) + + def test_check_model_valid_valid_model(self): + model = DummyModel() + self.assertEqual(self.service.check_model_valid(model), model) + + def test_check_model_valid_invalid_model(self): + model = "invalid_model" + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(model) + + # For the purpose of the test, let's also verify the expected exception message + expected_message = "[msprobe] 无效参数: model 参数必须是 mindspore.nn.Cell 类型。" + self.assertEqual(str(context.exception), expected_message) + + def test_update_primitive_counters(self): + primitive_name = "test_primitive" + 
self.service.update_primitive_counters(primitive_name) + self.assertEqual(self.service.primitive_counters[primitive_name], 0) + self.service.update_primitive_counters(primitive_name) + self.assertEqual(self.service.primitive_counters[primitive_name], 1) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py index 41be7b1db6c7d723aaeec1607f564ac3d772b404..cdc88a3beb4918ad60482bffc7e87926178498c8 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py @@ -21,6 +21,7 @@ from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump from msprobe.mindspore.task_handler_factory import TaskHandlerFactory +from msprobe.mindspore.common.const import Const class TestTaskHandlerFactory(TestCase): @@ -43,6 +44,7 @@ class TestTaskHandlerFactory(TestCase): common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) config = DebuggerConfig(common_config, task_config) + config.execution_mode = Const.GRAPH_GE_MODE handler = TaskHandlerFactory.create(config) self.assertTrue(isinstance(handler, KernelGraphDump)) @@ -52,7 +54,7 @@ class TestTaskHandlerFactory(TestCase): TaskHandlerFactory.create(config) self.assertEqual(str(context.exception), "Can not find task handler") - config.task = "free_benchmark" + config.task = "Free_benchmark" with self.assertRaises(Exception) as context: TaskHandlerFactory.create(config) self.assertEqual(str(context.exception), "valid task is needed.") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py index 
176b80068f70e60a06a6eed77b23b35e8b48a50d..e140f8263826367edd4f0a58c0de06812c10219f 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/advisor/test_advisor.py @@ -7,8 +7,8 @@ from unittest.mock import patch import pandas -from msprobe.pytorch.advisor.advisor import Advisor -from msprobe.pytorch.advisor.advisor_const import AdvisorConst +from msprobe.core.advisor.advisor import Advisor +from msprobe.core.advisor.advisor_const import AdvisorConst class TestAdvisor(unittest.TestCase): diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py index 288e259c0aae104a62054af3813b7831ec7722f7..b08b09c8529928470899dbb9cd3ffc0e40be5f24 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/compare/test_acc_compare.py @@ -1,7 +1,10 @@ # coding=utf-8 import unittest import pandas as pd -from msprobe.pytorch.compare import acc_compare as compare +from msprobe.core.compare.check import check_graph_mode +from msprobe.core.compare.utils import merge_tensor, read_op, get_accuracy, rename_api +from msprobe.core.compare.highlight import find_error_rows,find_compare_result_error_rows +from msprobe.pytorch.compare.pt_compare import PTComparator npu_dict = {'op_name': ['Functional_conv2d_0_forward_input.0', 'Functional_conv2d_0_forward_input.1', 'Functional_conv2d_0_forward_input.2', 'Functional_conv2d_0_forward_output'], @@ -208,60 +211,62 @@ class TestUtilsMethods(unittest.TestCase): def test_check_graph_mode(self): op1 = "Aten" op2 = "torch" - self.assertTrue(compare.check_graph_mode(op1, op2)) - self.assertTrue(compare.check_graph_mode(op2, op1)) - self.assertFalse(compare.check_graph_mode(op1, op1)) - self.assertFalse(compare.check_graph_mode(op2, op2)) + self.assertTrue(check_graph_mode(op1, op2)) + 
self.assertTrue(check_graph_mode(op2, op1)) + self.assertFalse(check_graph_mode(op1, op1)) + self.assertFalse(check_graph_mode(op2, op2)) def test_check_op(self): fuzzy_match = False - result = compare.check_op(npu_dict, bench_dict, fuzzy_match) + ptComparator=PTComparator() + result = ptComparator.check_op(npu_dict, bench_dict, fuzzy_match) self.assertEqual(result, True) def test_merge_tensor(self): - op_dict = compare.merge_tensor(tensor_list, True, False) + op_dict = merge_tensor(tensor_list, True, False) self.assertEqual(op_dict, result_op_dict) def test_read_op(self): - result = compare.read_op(op_data, op_name) + result = read_op(op_data, op_name) self.assertEqual(result, op_result) def test_match_op(self): fuzzy_match = False - a, b = compare.match_op([npu_dict], [bench_dict], fuzzy_match) + ptComparator=PTComparator() + a, b = ptComparator.match_op([npu_dict], [bench_dict], fuzzy_match) self.assertEqual(a, 0) self.assertEqual(b, 0) def test_get_accuracy(self): result = [] - compare.get_accuracy(result, npu_dict, bench_dict, highlight_dict) + get_accuracy(result, npu_dict, bench_dict, highlight_dict) self.assertEqual(result, o_result) def test_get_accuracy_graph_mode(self): result = [] - compare.get_accuracy(result, npu_dict_aten, bench_dict_functional, highlight_dict) + get_accuracy(result, npu_dict_aten, bench_dict_functional, highlight_dict) self.assertEqual(result, aten_result) def test_find_error_rows(self): summary_result = [summary_line_input, summary_line_1, summary_line_2, summary_line_3] highlight_dict = {'red_rows': [], 'yellow_rows': []} - compare.find_error_rows(summary_result, 0, 1, highlight_dict, summary_compare=True) + find_error_rows(summary_result, 0, 1, highlight_dict, summary_compare=True) self.assertEqual(highlight_dict, {'red_rows': [], 'yellow_rows': []}) def test_find_compare_result_error_rows(self): result = [line_input, line_1, line_2, line_3] result_df = pd.DataFrame(result) highlight_dict = {'red_rows': [], 'yellow_rows': []} - 
compare.find_compare_result_error_rows(result_df, highlight_dict, False, False) + find_compare_result_error_rows(result_df, highlight_dict, False, False) self.assertEqual(highlight_dict, {'red_rows': [num_1, num_3], 'yellow_rows': [num_2]}) def test_rename_api(self): test_name_1 = "Distributed.broadcast.0.forward.input.0" expect_name_1 = "Distributed.broadcast.input.0" - actual_name_1 = compare.rename_api(test_name_1, "forward") + actual_name_1 = rename_api(test_name_1, "forward") self.assertEqual(actual_name_1, expect_name_1) test_name_2 = "Torch.sum.0.backward.output.0" expect_name_2 = "Torch.sum.output.0" - actual_name_2 = compare.rename_api(test_name_2, "backward") + actual_name_2 = rename_api(test_name_2, "backward") self.assertEqual(actual_name_2, expect_name_2) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py index 399efeb42d7cd7e7e34dd472cd8a9d82c26a5b5e..8be3be413fbd2fdc43127cff1a915b7d1bb14a12 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/result_handlers/test_result_handler.py @@ -15,6 +15,7 @@ from msprobe.pytorch.free_benchmark.common.params import DataParams, make_handle from msprobe.pytorch.free_benchmark.result_handlers.handler_factory import ( FuzzHandlerFactory, ) +from msprobe.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler class Config(ABC): @@ -119,3 +120,21 @@ class TestFuzzHandler(TestCase): api_threshld, ThresholdConfig.DTYPE_PER_THD[torch.float16] ) + + def test_tensor_split_for_error_calculate(self): + # 设置模拟的张量的大小 + tensor_size = 256 * 1024 * 1024 + origin_output = torch.randn(tensor_size, dtype=torch.float32) + perturbed_output = torch.randn(tensor_size, dtype=torch.float32) + + # 调用tensor_split_for_error_calculate方法 + 
origin_output_chunks, perturbed_output_chunks = FuzzHandler.tensor_split_for_error_calculate( + origin_output, perturbed_output) + + # 验证返回的chunks数量和形状是否正确 + self.assertEqual(len(origin_output_chunks), 64) + self.assertEqual(len(perturbed_output_chunks), 64) + for chunk in origin_output_chunks: + self.assertEqual(chunk.shape, (4 * 1024 * 1024,)) + for chunk in perturbed_output_chunks: + self.assertEqual(chunk.shape, (4 * 1024 * 1024,)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py index 4498a2af7054edd89aa6fae6a057a489216794b6..3fe3da9a00f6eb610a8831a205e21850c0a521d7 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/free_benchmark/test_main.py @@ -61,6 +61,7 @@ class TestInterface(TestCase): def testForwardFix(self): # 对于前向接口,在forward钩子中开启FIX,返回结果给hook的输出 + # 为了与下一层的输入对齐、应该转换为扰动前输出的dtype,否则可能报错 config = Config(Const.FORWARD, HandlerType.FIX) checker = FreeBenchmarkCheck(config) # 执行算子前向 @@ -76,7 +77,7 @@ class TestInterface(TestCase): kwargs={}, output=out, ) - self.assertEqual(result.dtype, torch.float32) + self.assertEqual(result.dtype, torch.float16) def testBackwardCheck(self): # 对于反向接口,在pre forward时暂存input, 然后在backwrad后进行对比 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py index bd569f5a29c96e3484f4982b72aed99f4b059129..f39d3f091faf8d57f80cccbadc15259ee54269f0 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_csv.py @@ -4,6 +4,7 @@ import os import torch from msprobe.pytorch.grad_probe.grad_stat_csv import GradStatCsv from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor +from msprobe.core.grad_probe.constant import level_adp 
grad_tensor = torch.tensor([[-2, 2], [0.2, 0.3]]) @@ -11,27 +12,27 @@ grad_tensor = torch.tensor([[-2, 2], [0.2, 0.3]]) class TestGradCSV(unittest.TestCase): def test_level_L0_header(self): self.assertEqual(['param_name', 'MD5', 'max', 'min', 'norm', 'shape'], - GradStatCsv.generate_csv_header(GradientMonitor.level_adp["L0"], [-1, 0, 1])) + GradStatCsv.generate_csv_header(level_adp["L0"], [-1, 0, 1])) def test_level_L1_header(self): self.assertEqual(['param_name', 'max', 'min', 'norm', 'shape'], - GradStatCsv.generate_csv_header(GradientMonitor.level_adp["L1"], [-1, 0, 1])) + GradStatCsv.generate_csv_header(level_adp["L1"], [-1, 0, 1])) def test_level_L2_header(self): self.assertEqual(['param_name', '(-inf, -1]', '(-1, 0]', '(0, 1]', '(1, inf)', '=0', 'max', 'min', 'norm', 'shape'], - GradStatCsv.generate_csv_header(GradientMonitor.level_adp["L2"], [-1, 0, 1])) + GradStatCsv.generate_csv_header(level_adp["L2"], [-1, 0, 1])) def test_level_L0_content(self): - generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", GradientMonitor.level_adp["L0"], grad_tensor, [-1, 0, 1]) + generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", level_adp["L0"], grad_tensor, [-1, 0, 1]) self.assertEqual(['model.conv2d', '678a6c7d9d9716682b56fda097d0936c', 2.0, -2.0, 2.851315498352051, [2, 2]], generated_csv_line) def test_level_L1_content(self): - generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", GradientMonitor.level_adp["L1"], grad_tensor, [-1, 0, 1]) + generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", level_adp["L1"], grad_tensor, [-1, 0, 1]) self.assertEqual(['model.conv2d', 2.0, -2.0, 2.851315498352051, [2, 2]], generated_csv_line) def test_level_L2_content(self): - generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", GradientMonitor.level_adp["L2"], grad_tensor, [-1, 0, 1]) + generated_csv_line = GradStatCsv.generate_csv_line("model.conv2d", level_adp["L2"], grad_tensor, [-1, 0, 1]) 
self.assertEqual(['model.conv2d', 0.25, 0.0, 0.5, 0.25, 0.0, 2.0, -2.0, 2.851315498352051, [2, 2]], generated_csv_line) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_monitor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_monitor.py index d79cca50287aa1bcba4496809cea221155f45187..607addd69b2217285569049ff44003af544c69eb 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_monitor.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/grad_probe/test_grad_monitor.py @@ -10,15 +10,24 @@ from msprobe.core.grad_probe.grad_compare import GradComparator from msprobe.pytorch.grad_probe.grad_monitor import GradientMonitor from msprobe.pytorch.pt_config import GradToolConfig +class config: + def __init__(self, config_dict): + for key, value in config_dict.items(): + setattr(self, key, value) -config_dict = { - "level": "L1", - "param_list": "", +common_config_dict = { "rank": [], "step": [], - "bounds": [-1,0,1], - "output_path": "./grad_output" + "dump_path": "./grad_output" +} +common_config = config(common_config_dict) + +task_config_dict = { + "grad_level": "L1", + "param_list": "", + "bounds": [-1,0,1] } +task_config = config(task_config_dict) def seed_all(seed=1234, mode=False): random.seed(seed) @@ -53,7 +62,8 @@ def get_grad_monitor(): nn.init.constant_(test_module.linear.bias, 1.0) optimizer = torch.optim.SGD(test_module.parameters(), lr=1e-2) - gm = GradientMonitor(GradToolConfig(config_dict), test_module) + gm = GradientMonitor(common_config, task_config) + gm.monitor(test_module) for input_data, label in zip(inputs, labels): output = test_module(input_data) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py index 50783e5d736c024b03f20008ad6b72882eddcd87..96f4b4df2905c0ce6a3a3dd72dac40596f89cf8a 100644 --- 
a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py @@ -7,13 +7,18 @@ class TestHookModule(unittest.TestCase): def test_call_1(self): def forward_pre_hook(): return "result_input", "result_kwargs" + def forward_hook(): return 2 + def backward_hook(): pass + def forward_hook_torch_version_below_2(): + pass + def hook(prefix): - return forward_pre_hook, forward_hook, backward_hook + return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 HOOKModule.prefix_op_name_ = "123" test = HOOKModule(hook) test._call_func = Mock(return_value=1) @@ -23,13 +28,18 @@ class TestHookModule(unittest.TestCase): def test_call_2(self): def forward_pre_hook(nope, input, kwargs): return input, kwargs + def forward_hook(nope, input, kwargs, result): return input + def backward_hook(): pass + def forward_hook_torch_version_below_2(): + pass + def hook(prefix): - return forward_pre_hook, forward_hook, backward_hook + return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 HOOKModule.prefix_op_name_ = "123" input = 2 test = HOOKModule(hook) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py index 4940b07cb0d8e9d2283db3daebc910a7fdcd6ce9..f219e22e86402f5f0f0f9fa4fb6f95d7c8e88eac 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py @@ -6,12 +6,17 @@ from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, AtenOPPacketTe def hook(name): def forward_pre_hook(nope, input, kwargs): return input, kwargs + def forward_hook(nope, input, kwargs, result): return 2 + def backward_hook(): pass + + def forward_hook_torch_version_below_2(): + pass - return forward_pre_hook, forward_hook, 
backward_hook + return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py index 9a375e45bfcdc93ac36fb9d44a79f50fea7932d5..246feb56becf9942de9214f5b24b8471e9b4024a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_distributed.py @@ -6,11 +6,17 @@ class TestWrapDistributed(unittest.TestCase): def hook(name, prefix): def forward_pre_hook(nope, input, kwargs): return input, kwargs + def forward_hook(nope, input, kwargs, result): return 2 + def backward_hook(): pass - return forward_pre_hook, forward_hook, backward_hook + + def forward_hook_torch_version_below_2(): + pass + + return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 def test_get_distributed_ops(self): ops = get_distributed_ops() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py index 61f76b0ca0a59ee680ff40991fa9cba5e42d869d..2aadc358a93c0330bf281d0c3b243d3bb57da713 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_tensor.py @@ -8,11 +8,17 @@ class TestWrapTensor(unittest.TestCase): def hook(name, prefix): def forward_pre_hook(nope, input, kwargs): return input, kwargs + def forward_hook(nope, input, kwargs, result): return 2 + def backward_hook(): pass - return forward_pre_hook, forward_hook, backward_hook + + def forward_hook_torch_version_below_2(): + pass + + return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 def test_get_tensor_ops(self): result = get_tensor_ops() diff --git 
a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py index e1a3e77983d80e7c0519e30afbb592311550e794..14b156e3b6534d9b1ddd8450a7cb14df60362eed 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_torch.py @@ -8,11 +8,17 @@ class TestWrapTorch(unittest.TestCase): def hook(name, prefix): def forward_pre_hook(nope, input, kwargs): return input, kwargs + def forward_hook(nope, input, kwargs, result): return 2 + def backward_hook(): pass - return forward_pre_hook, forward_hook, backward_hook + + def forward_hook_torch_version_below_2(): + pass + + return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 def setUp(self): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index 80798ff4151cff5f2e7abdb75e6fc972e49975e5..583829074a20849a3241c276bbf5076d1b63063c 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -21,7 +21,7 @@ range_begin_flag, range_end_flag = False, False def check_list_or_acl_mode(name_prefix): global dump_count for item in DumpUtil.dump_switch_scope: - if PRE_FORWARD in name_prefix: + if Const.PRE_FORWARD in name_prefix: rename = item.rsplit(Const.DOT, 1)[0] if name_prefix.startswith(rename): return True diff --git a/debug/accuracy_tools/setup.py b/debug/accuracy_tools/setup.py index afbf8feb3a0f400bb2864fa1beb3ebb584fb72ff..70a69e9de941110a36ea12deee131f5addd90771 100644 --- a/debug/accuracy_tools/setup.py +++ b/debug/accuracy_tools/setup.py @@ -14,7 +14,7 @@ import setuptools -__version__ = '1.0.1' +__version__ = '1.0.2' INSTALL_REQUIRED = [ "wheel", diff --git 
"a/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" "b/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" index b7a8bf1fd0e7eec640e46af76e16c6a228f335ba..fbe5a354ffba8619d9e93012d6fa3715e1f50e19 100644 Binary files "a/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" and "b/plugins/tensorboard-plugins/tb_plugin/docs/\345\205\254\347\275\221URL\350\257\264\346\230\216.xlsx" differ diff --git a/profiler/advisor/README.md b/profiler/advisor/README.md index 77027110559de578d9339c3f5a3d6c762e72a6b5..0f6a038077114368cd8dec715848e38c5f10d5a7 100644 --- a/profiler/advisor/README.md +++ b/profiler/advisor/README.md @@ -62,19 +62,22 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 #### 命令功能介绍 -| dimension | mode | 参数释义 | -| ---------- | -------------------------- | ---------------------------------------- | -| overall | overall_summary | 计算、通信、空闲等维度对性能数据进行拆解 | -| cluster | slow_rank | 慢卡识别 | -| | slow_link | 慢链路识别 | -| computing | aicpu | AI CPU调优 | -| | dynamic_shape_analysis | 识别动态Shape算子 | -| | block_dim_analysis | block dim算子调优 | -| | operator_no_bound_analysis | operator no bound | -| | graph | 融合算子图调优 | -| | freq_analysis | AI Core算子降频分析 | -| scheduling | timeline_fusion_ops | 亲和API替换调优 | -| | timeline_op_dispatch | 识别算子下发问题(路径3/路径5) | +| dimension | mode | 参数释义 | +| ---------- |---------------------------------------| ------------------------------------ | +| overall | overall_summary | 计算、通信、空闲等维度对性能数据进行拆解 | +| | environment_variable_analysis | 环境变量设置推荐 | +| cluster | slow_rank | 慢卡识别 | +| | slow_link | 慢链路识别 | +| | communication_retransmission_analysis |通信重传检测 | +| computing | aicpu | AI CPU调优 | +| | dynamic_shape_analysis | 识别动态Shape算子 | +| | block_dim_analysis | block dim算子调优 | +| | operator_no_bound_analysis | operator no bound | +| | graph | 融合算子图调优 | +| | freq_analysis | AI Core算子降频分析 | 
+|communication| packet_analysis |通信小包检测 | +| scheduling | timeline_fusion_ops | 亲和API替换调优 | +| | timeline_op_dispatch | 识别算子下发问题(路径3/路径5) | - all @@ -126,17 +129,23 @@ msprof-analyze的advisor功能是将Ascend PyTorch Profiler或者msprof采集的 ![输入图片说明](./img/cluster.png) -cluster模块的分析包含快慢卡和快慢链路分析,仅识别问题,不提供调优建议。 +cluster模块的分析 +1. 包含快慢卡和快慢链路分析,仅识别问题,不提供调优建议。 +2. 通信重传检测分析,识别发生重传的通信域并提供调优建议。 如下图示例,识别到当前训练任务的通信和下发(free较多说明存在任务下发存在问题)存在问题。 ![cluster_1](./img/cluster_1.png) - +如下图所示,识别到当前训练任务存在通信重传问题,并提供调优建议 +![cluster_2](./img/cluster_2.png) overall模块的分析包含当前训练任务慢卡的性能拆解,按照计算、通信和下发三个维度进行耗时的统计,可以基于该分析识别到训练性能瓶颈是计算、通信还是下发问题,同样不提供调优建议。 ![输入图片说明](./img/overall_0.png) ![输入图片说明](./img/overall.png) +overall模块的environment_variable_analysis是对环境变量的设置做出推荐 +![env_var.png](img%2Fenv_var.png) + schedule模块包含亲和API、aclOpCompile、syncBatchNorm、SynchronizeStream等多项检测。 如下图示例,Operator Dispatch Issues提示需要在运行脚本的最开头添加如下代码用于消除aclOpCompile: @@ -159,6 +168,9 @@ computation模块从device计算性能维度进行分析,能够识别AI CPU、 ![computation_1](./img/computation_1.png) +communication模块从通信维度进行分析,目前支持通信小算子检测。 +![communication](./img/communication.png) + ## 工具使用(Jupyter Notebook方式) Jupyter Notebook使用方式如下: diff --git a/profiler/advisor/analyzer/base_analyzer.py b/profiler/advisor/analyzer/base_analyzer.py index ada1b0bf4f4c8344c8830fe446c8d05dd583eac5..80368e1d60a14020637ba60bb41c5536dcf2e081 100644 --- a/profiler/advisor/analyzer/base_analyzer.py +++ b/profiler/advisor/analyzer/base_analyzer.py @@ -81,7 +81,11 @@ class BaseAnalyzer(VersionControl, metaclass=ABCMeta): for dataset_cls in dataset_cls_list: if dataset_cls and callable(dataset_cls): - dataset = dataset_cls(collection_path=self.collection_path, data=self.dataset_list, **self.kwargs) + try: + dataset = dataset_cls(collection_path=self.collection_path, data=self.dataset_list, **self.kwargs) + except Exception as e: + logger.error(e) + continue key = dataset_cls.get_key() if key not in self.dataset_list: self.dataset_list[key] = [] diff --git 
a/profiler/advisor/analyzer/cluster/Communication_retransmission_analyzer.py b/profiler/advisor/analyzer/cluster/Communication_retransmission_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..3683ef1b44f8b6c571dd4d8fdce0d39882d342af --- /dev/null +++ b/profiler/advisor/analyzer/cluster/Communication_retransmission_analyzer.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.cluster.Communication_retransmission_checker import CommunicationRetransmissionChecker +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset + +logger = logging.getLogger() + + +class RDMARetransmissionAnalyzer(BaseAnalyzer): + dataset_cls_list = [ClusterCommunicationDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = ClusterCommunicationDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + + @BaseAnalyzer.check_data((ClusterCommunicationDataset.get_key(),)) + def optimize(self, **kwargs): + add_render_list = kwargs.get("add_render_list", True) + rdma_checker = CommunicationRetransmissionChecker(**kwargs) + rdma_checker.check_retransmission(self.dataset) + if not rdma_checker.rdma_issues: + return self.result + rdma_checker.make_record(self.result) + self.html = rdma_checker.make_render(self.html_render, add_render_list) + return self.result diff --git a/profiler/advisor/analyzer/cluster/Communication_retransmission_checker.py b/profiler/advisor/analyzer/cluster/Communication_retransmission_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..cc0f688e843cdc75681827e5599572d5dd42c3cc --- /dev/null +++ b/profiler/advisor/analyzer/cluster/Communication_retransmission_checker.py @@ -0,0 +1,128 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Dict, List +from collections import defaultdict +from profiler.advisor.dataset.cluster.cluster_dataset import ClusterCommunicationDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo + +logger = logging.getLogger() + + +class GroupStatistic: + def __init__(self, min_transmission_time): + self.retransmission_issue = False + self.abnormal_op_dict: Dict[str, List] = dict() + + def add_op(self, op_name: str, hccl_info: HcclInfo): + if self.abnormal_op_dict.get(op_name) is None: + self.abnormal_op_dict.setdefault(op_name, []) + self.abnormal_op_dict.get(op_name).append([hccl_info.group, op_name, hccl_info.step, hccl_info.rank, + hccl_info.get_rdma_transit_size(), + hccl_info.get_rdma_transmit_time(), hccl_info.get_rdma_bandwidth()]) + + +class CommunicationRetransmissionChecker: + def __init__(self, **kwargs): + self.rdma_issues = False + self.desc = "" + self.sdma_desc = "" + self.rdma_desc = "" + self.suggestions = [] + self.abnormal_group_count = 0 + self.abnormal_rdma_list = [] + self.step_id = kwargs.get("step") + self.stage = None + self.group_statistics = defaultdict(GroupStatistic) + self.headers = ["Communication group", "Op name", "Step id", "Rank id", "RDMA transmit size(MB)", + "RDMA transmit time(ms)", "RDMA bandwidth"] + self._init_rule() + + def 
check_possible_retransmission_occurrence(self, hccl_list: List[HcclInfo]): + min_elapse_time = min(hccl.elapse_time for hccl in hccl_list) + max_transit_time = max(hccl.rdma_info.get('Transit Time(ms)', 0) for hccl in hccl_list) + if min_elapse_time < self.min_retransmission_time: # 检测是否是卡间不同步问题,而不是重传 + return False + return max_transit_time > self.min_retransmission_time + + def check_retransmission(self, hccl_dataset: ClusterCommunicationDataset): + """ + :Param event_dataset: dataset of timeline event + """ + for group_name, hccl_group_dict in hccl_dataset.hccl_dict.items(): + for op_name, hccl_op_dict in hccl_group_dict.items(): + for step_id, hccl_list in hccl_op_dict.items(): + if self.step_id and step_id != self.step_id: # 传输指定step(self.step_id)情况下,非目标step跳过 + continue + if not self.check_possible_retransmission_occurrence(hccl_list): + continue + self.rdma_issues = True + if self.group_statistics.get(group_name) is None: + self.group_statistics.setdefault(group_name, GroupStatistic(self.min_retransmission_time)) + self.abnormal_group_count += 1 + for hccl_info in hccl_list: + if hccl_info.rdma_info.get('Transit Size(MB)', 0): + transit_time = hccl_info.rdma_info.get('Transit Time(ms)', 0) + if transit_time > self.min_retransmission_time: + self.group_statistics.get(group_name).add_op(op_name, hccl_info) + if self.rdma_issues: + self.desc = self.desc.format(group_count=self.abnormal_group_count) + for _, group_statistic in self.group_statistics.items(): + for _, op_list in group_statistic.abnormal_op_dict.items(): + for op in op_list: + self.abnormal_rdma_list.append(op) + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem("Communication retransmission analysis", self.desc, self.suggestions) + result.add(OptimizeRecord(optimization_item)) + + sub_table_name = "Comm Retransmission Analysis" if not self.stage else f"Stage-{self.stage}: Comm Retransmission Analysis" + 
result.add_detail(sub_table_name, headers=self.headers) + + for row in self.abnormal_rdma_list: + result.add_detail(sub_table_name, detail=row) + + def make_render(self, html_render, add_render_list=True): + return html_render.render_template(key="cluster", + template_dir="templates", + template_name="communication_retransmission_analysis.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + data=self.abnormal_rdma_list + ) + + def _init_rule(self): + syncbn_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "rdma_analysis.yaml" + ) + + syncbn_rule = FileManager.read_yaml_file(syncbn_rule_path) + self.desc = syncbn_rule.get("problem") + self.min_retransmission_time = syncbn_rule.get("min_retransmission_time") + + self.solutions = syncbn_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/communication/packet_analyzer.py b/profiler/advisor/analyzer/communication/packet_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..73e5bc2bc99bf3a2c7e11ef55ae279e8ddeb5ef5 --- /dev/null +++ b/profiler/advisor/analyzer/communication/packet_analyzer.py @@ -0,0 +1,46 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.analyzer.communication.packet_checker import PacketChecker +from profiler.advisor.display.html.render import HTMLRender +from profiler.advisor.dataset.communication.communication_dataset import CommunicationDataset + +logger = logging.getLogger() + + +class PacketAnalyzer(BaseAnalyzer): + dataset_cls_list = [CommunicationDataset] + + def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None: + super().__init__(collection_path, n_processes, **kwargs) + key = CommunicationDataset.get_key() + self.dataset = self.get_first_data_by_key(self.dataset_list, key) + self.result = OptimizeResult() + self.html_render = HTMLRender() + self.html = None + + @BaseAnalyzer.check_data((CommunicationDataset.get_key(),)) + def optimize(self, **kwargs): + add_render_list = kwargs.get("add_render_list", True) + packet_checker = PacketChecker(**kwargs) + packet_checker.check_packet(self.dataset) + if not packet_checker.packet_issues: + return self.result + packet_checker.make_record(self.result) + self.html = packet_checker.make_render(self.html_render, add_render_list) + return self.result diff --git a/profiler/advisor/analyzer/communication/packet_checker.py b/profiler/advisor/analyzer/communication/packet_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9ac81ffdb9cc049e6b82d01570f2f041d3ff68 --- /dev/null +++ b/profiler/advisor/analyzer/communication/packet_checker.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from profiler.advisor.dataset.communication.communication_dataset import CommunicationDataset +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem, OptimizeRecord +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.utils.utils import convert_to_float + +logger = logging.getLogger() + + +class Statistic: + def __init__(self, min_ratio, min_size, desc, type_): + self.issue = False + self.count = 0 + self.abnormal_count = 0 + self.abnormal_duration = 0 + self.abnormal_ratio = 0 + self.min_ratio = min_ratio + self.min_size = min_size + self.desc = desc + self.type = type_ + + def check_threshold(self): + if self.count and self.abnormal_count: + self.abnormal_ratio = self.abnormal_count / self.count + if self.abnormal_ratio > self.min_ratio: + self.issue = True + return self.issue + + def process(self, hccl_info): + info = dict() + if self.type == "SDMA": + info = hccl_info.sdma_info + elif self.type == "RDMA": + info = hccl_info.rdma_info + if info.get('Transit Size(MB)', 0): + packet_size = info.get('Transit Size(MB)', 0) + if packet_size < self.min_size: + self.abnormal_count += 1 + self.abnormal_duration += info.get('Transit Time(ms)', 0) + self.count += 1 + + def adapt(self, dst_headers: list, src_headers, datas: list): + if not self.issue: + return False + dst_headers.extend(src_headers) + datas.extend([self.count, self.abnormal_count, self.abnormal_ratio, self.abnormal_duration]) + self.desc = self.desc.format( + 
abnormal_sdma_ratio=f"{round(self.abnormal_ratio, 4):.2%}", + min_sdma_size=self.min_size, + abnormal_sdma_time=round(self.abnormal_duration, 4)) + return True + + +class PacketChecker: + def __init__(self, **kwargs): + self.packet_issues = False + self.desc = "" + self.sdma_desc = "" + self.rdma_desc = "" + self.suggestions = [] + self.min_sdma_size = 0 + self.min_rdma_size = 0 + self.min_sdma_ratio = 0 + self.min_rdma_ratio = 0 + self.step_id = kwargs.get("step") + self.stage = None + self.packet_issues = False + self._init_rule() + self.sdma_statistic = Statistic(self.min_sdma_ratio, self.min_sdma_size, self.sdma_desc, "SDMA") + self.rdma_statistic = Statistic(self.min_rdma_ratio, self.min_rdma_size, self.rdma_desc, "RDMA") + self.small_packet_detail = [] + self.headers = [] + self.sdma_headers = ["SDMA total count", "Small SDMA count", "Small SDMA ratio", "Small SDMA duration(ms)"] + self.rdma_headers = ["RDMA total count", "Small RDMA count", "Small RDMA ratio", "Small RDMA duration(ms)"] + + def check_packet(self, hccl_dataset: CommunicationDataset): + for step_id, hccl_list in hccl_dataset.hccl_dict.items(): + if self.step_id and step_id != self.step_id: + continue + for hccl_info in hccl_list: + self.sdma_statistic.process(hccl_info) + self.rdma_statistic.process(hccl_info) + self.sdma_statistic.check_threshold() + self.rdma_statistic.check_threshold() + if self.sdma_statistic.adapt(self.headers, self.sdma_headers, self.small_packet_detail): + self.packet_issues = True + self.desc += self.sdma_statistic.desc + if self.rdma_statistic.adapt(self.headers, self.rdma_headers, self.small_packet_detail): + self.packet_issues = True + self.desc += self.rdma_statistic.desc + + def make_record(self, result: OptimizeResult): + """ + make record for what and how to optimize + """ + optimization_item = OptimizeItem("Packet analysis", self.desc, self.suggestions) + result.add(OptimizeRecord(optimization_item)) + + sub_table_name = "Packet Analysis" if not self.stage else 
f"Stage-{self.stage}: Packet Analysis" + result.add_detail(sub_table_name, headers=self.headers) + result.add_detail(sub_table_name, detail=self.small_packet_detail) + + def make_render(self, html_render, add_render_list=True): + return html_render.render_template(key="communication", + template_dir="templates", + template_name="packet_analysis.html", + desc=self.desc, + solutions=self.solutions, + headers=self.headers, + data=self.small_packet_detail + ) + + def _init_rule(self): + syncbn_rule_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "packet.yaml" + ) + + syncbn_rule = FileManager.read_yaml_file(syncbn_rule_path) + self.desc = syncbn_rule.get("problem") + self.sdma_desc = syncbn_rule.get("sdma_problem") + self.rdma_desc = syncbn_rule.get("rdma_problem") + self.min_sdma_size = convert_to_float(syncbn_rule.get("min_sdma_size")) + self.min_rdma_size = convert_to_float(syncbn_rule.get("min_rdma_size")) + self.min_sdma_ratio = convert_to_float(syncbn_rule.get("min_sdma_ratio")) + self.min_rdma_ratio = convert_to_float(syncbn_rule.get("min_rdma_ratio")) + + self.solutions = syncbn_rule.get("solutions") + for solution in self.solutions: + for key, val in solution.items(): + self.suggestions.append(f"{key}, {val.get('desc')}") diff --git a/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py b/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py index 5ea4dbd7542750469967b05ab9a738f2d70600e4..7afa09cca48fd9939c4fcbfdf2a9fb5f29e3b468 100644 --- a/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py +++ b/profiler/advisor/analyzer/computation/ai_core_freq/ai_core_freq_checker.py @@ -49,7 +49,7 @@ class AICoreFreqChecker: max_freq = max(self.DEFAULT_FREQ, convert_to_float(Config().get_config("aic_frequency"))) decrease_freq_ratio = sum(max_freq - freq for freq in freq_list) / (max_freq * len(freq_list)) - if decrease_freq_ratio >= 
self.DECREASE_FREQ_RATIO: + if decrease_freq_ratio >= Config().get_config("frequency_threshold"): self.ai_core_freq_issues = True self.decrease_freq_ops.append([op_name, op_count, op_total_duration, f"{round(decrease_freq_ratio, 4):.2%}", diff --git a/profiler/advisor/analyzer/overall/environment_variable_analyzer.py b/profiler/advisor/analyzer/overall/environment_variable_analyzer.py new file mode 100644 index 0000000000000000000000000000000000000000..3daaa3460912620795294faa1266a34c858918e8 --- /dev/null +++ b/profiler/advisor/analyzer/overall/environment_variable_analyzer.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer +from profiler.prof_common.path_manager import PathManager +from profiler.advisor.dataset.environment_variable_dataset import EnvironmentVariableDataset +from profiler.advisor.analyzer.overall.environment_variable_checker import EnvironmentVariabelChecker + + +class EnvironmentVariabelAnalyzer(BaseAnalyzer): + dataset_cls_list = [EnvironmentVariableDataset] + + def __init__(self, collection_path: str, n_processes: int = 1, **kwargs): + super().__init__(collection_path, n_processes, **kwargs) + self.dataset = self.get_first_data_by_key(self.dataset_list, EnvironmentVariableDataset.get_key()) + + def optimize(self, **kwargs): + try: + PathManager.check_input_directory_path(self.collection_path) + except RuntimeError as e: + logging.error("Invalid path: %s", str(e)) + return self.result + self.collection_path = PathManager.get_realpath(self.collection_path) + checker = EnvironmentVariabelChecker() + checker.format_env_suggest(self.dataset) + checker.make_record(self.result) + checker.make_render(self.html_render) + return self.result + + def make_record(self): + pass + + def make_render(self): + pass diff --git a/profiler/advisor/analyzer/overall/environment_variable_checker.py b/profiler/advisor/analyzer/overall/environment_variable_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..ca316530d706e6cc5ec4d303687133a96a67727e --- /dev/null +++ b/profiler/advisor/analyzer/overall/environment_variable_checker.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.result.result import OptimizeResult +from profiler.advisor.result.item import OptimizeItem +from profiler.advisor.result.item import OptimizeRecord +from profiler.advisor.common.analyzer_scopes import SupportedScopes +from profiler.advisor.display.html.render import HTMLRender + + +class EnvironmentVariabelChecker: + ENV_SUGGEST_CONDITION = { + "ASCEND_GLOBAL_LOG_LEVEL": lambda x: x != "" and x != 3, + "HCCL_RDAM_TC": lambda x: x != "", + "HCCL_RDMA_SL": lambda x: x != "", + "ACLNN_CACHE_LIMIT": lambda x: x == "" or (isinstance(x, int) and x < 10000), + "HOST_CACHE_CAPACITY": lambda x: x == "" or x == 0, + "ASCEND_ENHANCE_ENABLE": lambda x: x == 0, + "PYTORCH_NPU_ALLOC_CONF": lambda x: "expandable_segments:True" not in x, + "ASCEND_LAUNCH_BLOCKING": lambda x: x != 1, + } + + HEADERS = ["Environment", "Value", "Description", "Suggestion"] + + def __init__(self): + self.environment_info = self.read_environment_info() + self.env_suggest_csv = [] + self.env_suggest_html = [] + + @staticmethod + def read_environment_info(): + environment_variable_info_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))), + "rules", + "environment_variable_info.yaml" + ) + return FileManager.read_yaml_file(environment_variable_info_path) + + def format_env_suggest(self, data): + data = data.env_data.get('ENV_VARIABLES', {}) + for env, value in data.items(): + if not self.ENV_SUGGEST_CONDITION.get(env, lambda x: 
False)(value): + continue + desc = self.environment_info.get(env, {}).get("desc", "") + suggest = self.environment_info.get(env, {}).get("suggest", "") + self.env_suggest_csv += [ + [ + env, + value, + desc, + suggest, + ] + ] + self.env_suggest_html += [ + [ + env, + value, + desc.replace('\n', '
'), + self.environment_info.get(env, {}).get("suggest_html", suggest), + ] + ] + + def make_record(self, result: OptimizeResult): + if not self.env_suggest_csv: + return + desc = f"Describe and suggest the optimal environment variable settings" + suggestion = "Please set the optimal environment variable" + + optimization_item = OptimizeItem( + SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS, + desc, + [suggestion] + ) + result.add(OptimizeRecord(optimization_item)) + result.add_detail(SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS, headers=self.HEADERS) + for env_suggest in self.env_suggest_csv: + result.add_detail(SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS, detail=env_suggest) + + def make_render(self, html_render: HTMLRender): + if not self.env_suggest_html: + return + html_render.render_template(key="overall", + template_dir="templates", + template_name="environment_variable.html", + result={ + "headers": self.HEADERS, + "data": self.env_suggest_html, + }) diff --git a/profiler/advisor/common/analyzer_scopes.py b/profiler/advisor/common/analyzer_scopes.py index 52e3e07554f354deb62222ee0de6e66ef8b07e2e..b947798c9e6bb708301d6c02c1df93e4665c4132 100644 --- a/profiler/advisor/common/analyzer_scopes.py +++ b/profiler/advisor/common/analyzer_scopes.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class SupportedScopes: # used for specify fourth-level commands and define the key of the result dict @@ -6,7 +20,10 @@ class SupportedScopes: GRAPH = "graph" SLOW_RANK = "slow_rank" SLOW_LINK = "slow_link" + COMMUNICATION_RETRANSMISSION_DETECTION = "communication_retransmission_analysis" + PACKET = "packet_analysis" OVER_ALL = "over_all" + ENVIRONMENT_VARIABLE_ANALYSIS = "environment_variable_analysis" DYNAMIC_SHAPE_ANALYSIS = "dynamic_shape_analysis" AICPU_ANALYSIS = "aicpu_analysis" BLOCK_DIM_ANALYSIS = "block_dim_analysis" diff --git a/profiler/advisor/common/constant.py b/profiler/advisor/common/constant.py index 06186080d1701ac06ffc80d3b83a892a84ec7255..c97cfbfd11e27a3d83ea2f9a25ea7870899bcfd1 100644 --- a/profiler/advisor/common/constant.py +++ b/profiler/advisor/common/constant.py @@ -75,6 +75,7 @@ CANN_VERSION_C17 = '8.0.RC1' SUPPORTED_CANN_VERSION = [CANN_VERSION_C30, CANN_VERSION_C13, CANN_VERSION_C15, CANN_VERSION_C17] DEFAULT_CANN_VERSION = CANN_VERSION_C17 ASCEND_PYTORCH_PROFILER = "ascend_pytorch_profiler" +PROFILER_METADATA = "profiler_metadata.json" MSLITE = "mslite" MSPROF = "msprof" SUPPORTED_PROFILING_TYPE = [ASCEND_PYTORCH_PROFILER, MSLITE, MSPROF] @@ -123,6 +124,20 @@ MAX_RETRIES = 3 TIMEOUT = 3 ADVISOR_RULE_PATH = "ADVISOR_RULE_PATH" +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
CLOUD_RULE_PATH = "rules/cloud/" DEFAULT_RULE_PATH = "./rules/" @@ -137,6 +152,7 @@ CLUSTER_ANALYSIS_OUTPUT = "cluster_analysis_output" KERNEL_DETAILS_CSV = "kernel_details.csv" CLUSTER_STEP_TIME_CSV = "cluster_step_trace_time.csv" CLUSTER_COMM_JSON = "cluster_communication.json" +COMMUNICATION_JSON = "communication.json" BOTTLENECK = "bottleneck" DATA = "data" diff --git a/profiler/advisor/dataset/cluster/cluster_dataset.py b/profiler/advisor/dataset/cluster/cluster_dataset.py index e1163f1cdd84265eb5cc5e356753cad5fa663339..b4956139c58436f6998ea8ce94a56fc280c038c3 100644 --- a/profiler/advisor/dataset/cluster/cluster_dataset.py +++ b/profiler/advisor/dataset/cluster/cluster_dataset.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging import os @@ -10,6 +24,7 @@ from profiler.cluster_analyse.common_func.constant import Constant from collections import defaultdict from profiler.cluster_analyse.cluster_analysis import Interface from profiler.advisor.dataset.cluster.cluster_step_trace_time_bean import ClusterStepTraceTimeBean +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo logger = logging.getLogger() @@ -114,6 +129,7 @@ class ClusterCommunicationDataset(ClusterDataset): self.SDMA_TIME_MS: 0, self.SDMA_SIZE_MB: 0, }) + self.hccl_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) super().__init__(collection_path, data) @staticmethod @@ -136,9 +152,26 @@ class ClusterCommunicationDataset(ClusterDataset): def process(self, communication_json: dict): for comm_group, group_dict in communication_json.items(): + if self.hccl_dict.get(comm_group) is None: + self.hccl_dict.setdefault(comm_group, defaultdict(lambda: defaultdict(list))) for step, step_dict in group_dict.items(): for op, op_dict in step_dict.items(): self.compute_bandwidth(op_dict) + self.process_hccl_info(comm_group, step, op, op_dict) + + def process_hccl_info(self, group, step, op, op_dict): + op_name = op.split("@")[0] + for rank_id, rank_dict in op_dict.items(): + try: + hccl_info = HcclInfo(group, step, rank_id, op, rank_dict) + if self.hccl_dict[group].get(op_name) is None: + self.hccl_dict[group].setdefault(op_name, defaultdict(list)) + if self.hccl_dict[group][op_name].get(step) is None: + self.hccl_dict[group][op_name].setdefault(step, list()) + self.hccl_dict[group][op_name][step].append(hccl_info) + except ValueError as e: + msg = "[ERROR] Cluster_communication.json has invalid structure." 
+ raise ValueError(msg) from e def compute_bandwidth(self, op_dict: dict): for rank_id, rank_dict in op_dict.items(): diff --git a/profiler/advisor/dataset/cluster/hccl_collection.py b/profiler/advisor/dataset/cluster/hccl_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..a9fa536efd85b13300cd0b45f28b1c54c21ffa64 --- /dev/null +++ b/profiler/advisor/dataset/cluster/hccl_collection.py @@ -0,0 +1,78 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +hccl info +""" +import logging + +logger = logging.getLogger() + + +class HcclInfo(): + def __init__(self, group: str, step: str, rank: str, op: str, rank_dict: dict) -> None: + self._group = group + self._step = step + self._rank = rank + self._name = op.split("@")[0] + self._elapse_time = self.get_elapse_time(rank_dict, "Elapse Time(ms)") + self._sdma_info = self.get_communication_info(rank_dict, "SDMA") + self._rdma_info = self.get_communication_info(rank_dict, "RDMA") + + @property + def group(self): + return self._group + + @property + def step(self): + return self._step + + @property + def rank(self): + return self._rank + + @property + def name(self): + return self._name + + @property + def rdma_info(self): + return self._rdma_info + + @property + def sdma_info(self): + return self._sdma_info + + @property + def elapse_time(self): + return self._elapse_time + + @staticmethod + def get_communication_info(rank_dict: dict, name: str): + communication_bandwidth_info = rank_dict.get('Communication Bandwidth Info', dict()) + return communication_bandwidth_info.get(name, dict()) + + @staticmethod + def get_elapse_time(rank_dict: dict, name: str): + communication_time_info = rank_dict.get('Communication Time Info', dict()) + return communication_time_info.get(name, "") + + def get_rdma_transmit_time(self): + return self.rdma_info.get('Transit Time(ms)', 0) + + def get_rdma_transit_size(self): + return self.rdma_info.get('Transit Size(MB)', 0) + + def get_rdma_bandwidth(self): + return self.rdma_info.get('Bandwidth(GB/s)', 0) diff --git a/profiler/advisor/dataset/communication/__init__.py b/profiler/advisor/dataset/communication/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/advisor/dataset/communication/communication_dataset.py b/profiler/advisor/dataset/communication/communication_dataset.py new file mode 100644 index 
0000000000000000000000000000000000000000..6cfc8708362a356042114b4970b1b1dfe8a0ca24 --- /dev/null +++ b/profiler/advisor/dataset/communication/communication_dataset.py @@ -0,0 +1,109 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from collections import defaultdict +from profiler.advisor.utils.utils import singleton +from profiler.advisor.common import constant as const +from profiler.cluster_analyse.common_func.file_manager import FileManager +from profiler.advisor.dataset.cluster.hccl_collection import HcclInfo +from profiler.advisor.utils.utils import CheckPathAccess + +logger = logging.getLogger() + + +@singleton +class CommunicationDataset: + RANK = "rank" + + def __init__(self, collection_path, data: dict, **kwargs) -> None: + self.timeline_dir = collection_path + self.timeline_data_list = self.get_file_path_from_directory(self.timeline_dir, + lambda file: file.endswith(const.COMMUNICATION_JSON)) + self.hccl_dict = defaultdict(list) + self.step = kwargs.get("step") + if self.parse(): + key = self.get_key() + if key not in data: + data[key] = [] + data[key].append(self) + + @staticmethod + def load_json_data(json_path): + if not os.path.exists(json_path): + msg = "[ERROR] cluster_communication.json doesn't exist, terminate analysis." 
+ raise RuntimeError(msg) + data = FileManager.read_json_file(json_path) + return data + + @staticmethod + @CheckPathAccess + def get_file_path_from_directory(path, check_func): + """ + get file from directory + """ + file_list = [] + + if not path: + return file_list + + if not os.path.isdir(path): + logger.warning("Expected existed directory, but got %s", path) + + for root, _, files in os.walk(path): + if root.endswith("cluster_analysis_output"): + continue + for filename in files: + filepath = os.path.join(root, filename) + if check_func(filename): + file_list.append(filepath) + return file_list + + @classmethod + def get_key(cls): + """ + get key of dataset + :return: key + """ + return cls.__module__.rsplit('.', maxsplit=1)[-1] + + def parse(self): + if len(self.timeline_data_list) == 0: + logger.warning("Please ensure communication.json in %s, skip timeline analysis.", self.timeline_dir) + return False + + if len(self.timeline_data_list) > 1: + logger.warning("Found multiple communication.json in %s, load the file of device 0 for analysis .", + self.timeline_dir) + + json_data = self.load_json_data(sorted(self.timeline_data_list)[0]) + self.process(json_data) + return True + + def process(self, communication_json: dict): + for step, step_dict in communication_json.items(): + for group, group_dict in step_dict.items(): + for op, op_dict in group_dict.items(): + self.process_hccl_info(group, step, op, op_dict) + + def process_hccl_info(self, group, step, op, op_dict): + try: + hccl_info = HcclInfo(group, step, "None", op, op_dict) + if self.hccl_dict.get(step) is None: + self.hccl_dict.setdefault(step, list()) + self.hccl_dict[step].append(hccl_info) + except ValueError as e: + msg = "[ERROR] Cluster_communication.json has invalid structure." 
+ raise ValueError(msg) from e diff --git a/profiler/advisor/dataset/environment_variable_dataset.py b/profiler/advisor/dataset/environment_variable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..577273ffe8ae955ae8b33e1d871ef2f867aa3f71 --- /dev/null +++ b/profiler/advisor/dataset/environment_variable_dataset.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import logging + +from profiler.advisor.common import constant +from profiler.cluster_analyse.common_func.file_manager import FileManager + + +class EnvironmentVariableDataset: + def __init__(self, collection_path, data: dict, **kwargs): + self.collection_path = collection_path + self.env_data = {} + self.read_data() + + @staticmethod + def get_env_data_file(collection_path: str) -> str: + for root, _, files in os.walk(collection_path): + for file_name in files: + if file_name == constant.PROFILER_METADATA: + return os.path.join(root, file_name) + return "" + + @classmethod + def get_key(cls): + return cls.__module__.rsplit('.', maxsplit=1)[-1] + + def read_data(self): + data_path = self.get_env_data_file(self.collection_path) + if not data_path: + return + try: + self.env_data = FileManager.read_json_file(data_path) + except RuntimeError as e: + logging.error("Read json failed. 
%s", str(e)) diff --git a/profiler/advisor/display/html/templates/communication_retransmission_analysis.html b/profiler/advisor/display/html/templates/communication_retransmission_analysis.html new file mode 100644 index 0000000000000000000000000000000000000000..75754fde72467934ac92166ebec7ad5440e55896 --- /dev/null +++ b/profiler/advisor/display/html/templates/communication_retransmission_analysis.html @@ -0,0 +1,40 @@ +
+

Communication Retransmission Analysis

+
+ {{ desc }} + + + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + + {% endfor %} + {% endfor %} +
Suggestions
{{ rowloop.index }}. {{ value.desc }}
+

+ {{ desc }} + + + {% for header in headers %} + + {% endfor %} + + + {% for row in data %} + + {% for element in row %} + + {% endfor %} + + {% endfor %} +
{{ header }}
{{ element|safe }}
+ +
+
diff --git a/profiler/advisor/display/html/templates/environment_variable.html b/profiler/advisor/display/html/templates/environment_variable.html new file mode 100644 index 0000000000000000000000000000000000000000..ab95096393910e4c1c3a79d5a57640ccddc57928 --- /dev/null +++ b/profiler/advisor/display/html/templates/environment_variable.html @@ -0,0 +1,21 @@ +
+

Environment Variable Issues

+
+ + + {% for header in result.get("headers") %} + + {% endfor %} + + + {% for row in result.get("data") %} + + {% for value in row %} + + {% endfor %} + + {% endfor %} + +
{{ header }}
{{ value|safe }}
+
+
\ No newline at end of file diff --git a/profiler/advisor/display/html/templates/packet_analysis.html b/profiler/advisor/display/html/templates/packet_analysis.html new file mode 100644 index 0000000000000000000000000000000000000000..07189a92631157b800c79d4f76ef4ab72b3e254e --- /dev/null +++ b/profiler/advisor/display/html/templates/packet_analysis.html @@ -0,0 +1,23 @@ +
+

Packet Analysis

+
+ {{ desc }} + + + + + + + {% for item in solutions %} + {% set rowloop = loop %} + {% for key, value in item.items() %} + + + + + {% endfor %} + {% endfor %} +
Suggestions
{{ rowloop.index }}. {{ value.desc }}
+ +
+
diff --git a/profiler/advisor/img/cluster_2.png b/profiler/advisor/img/cluster_2.png new file mode 100644 index 0000000000000000000000000000000000000000..5cb7bd3ff9dbcc6ada325001f4fbe7cd79a6c51d Binary files /dev/null and b/profiler/advisor/img/cluster_2.png differ diff --git a/profiler/advisor/img/communication.png b/profiler/advisor/img/communication.png new file mode 100644 index 0000000000000000000000000000000000000000..ba7c753f6de93cdd483b04c16bcb41002a1432ef Binary files /dev/null and b/profiler/advisor/img/communication.png differ diff --git a/profiler/advisor/img/env_var.png b/profiler/advisor/img/env_var.png new file mode 100644 index 0000000000000000000000000000000000000000..a2c9b6f20e67600f09cff6f5269a464dd0010115 Binary files /dev/null and b/profiler/advisor/img/env_var.png differ diff --git a/profiler/advisor/interface/interface.py b/profiler/advisor/interface/interface.py index 1d3872a1783111af7b1f543241da6b23fb14a632..4908c275d05034666440e6c5e478dc8a68f1dad4 100644 --- a/profiler/advisor/interface/interface.py +++ b/profiler/advisor/interface/interface.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os from collections import OrderedDict import sys @@ -11,12 +25,15 @@ from profiler.advisor.analyzer.graph_fusion.graph_fusion_analyzer import FusionO from profiler.advisor.common.analyzer_scopes import SupportedScopes from profiler.advisor.analyzer.cluster.slow_rank_analyser import SlowRankAnalyzer from profiler.advisor.analyzer.cluster.slow_link_analyser import SlowLinkAnalyzer +from profiler.advisor.analyzer.cluster.Communication_retransmission_analyzer import RDMARetransmissionAnalyzer from profiler.advisor.analyzer.overall.overall_summary_analyzer import OverallSummaryAnalyzer +from profiler.advisor.analyzer.overall.environment_variable_analyzer import EnvironmentVariabelAnalyzer from profiler.advisor.analyzer.schedule.dispatch.timeline_op_dispatch_analyzer import OpDispatchAnalyzer from profiler.advisor.analyzer.schedule.syncbn.syncbn_analyzer import SyncBNAnalyzer from profiler.advisor.analyzer.schedule.synchronize_stream.synchronize_stream_analyzer import SynchronizeStreamAnalyzer from profiler.advisor.analyzer.dataloader.dataloader_analyzer import DataloaderAnalyzer from profiler.advisor.analyzer.computation.ai_core_freq.ai_core_freq_analyzer import AICoreFreqAnalyzer +from profiler.advisor.analyzer.communication.packet_analyzer import PacketAnalyzer class Interface: @@ -35,10 +52,16 @@ class Interface: SupportedScopes.GRAPH: FusionOPAnalyzer, SupportedScopes.FREQ_ANALYSIS: AICoreFreqAnalyzer }), - "communication": OrderedDict(), - "overall": OrderedDict({SupportedScopes.OVER_ALL: OverallSummaryAnalyzer}), + "communication": OrderedDict({ + SupportedScopes.PACKET: PacketAnalyzer + }), + "overall": OrderedDict({ + SupportedScopes.ENVIRONMENT_VARIABLE_ANALYSIS: EnvironmentVariabelAnalyzer, + SupportedScopes.OVER_ALL: OverallSummaryAnalyzer, + }), "dataloader": OrderedDict({SupportedScopes.DATALOADER: DataloaderAnalyzer}), "cluster": OrderedDict({ + SupportedScopes.COMMUNICATION_RETRANSMISSION_DETECTION: RDMARetransmissionAnalyzer, 
SupportedScopes.SLOW_RANK: SlowRankAnalyzer, SupportedScopes.SLOW_LINK: SlowLinkAnalyzer }) diff --git a/profiler/advisor/rules/environment_variable_info.yaml b/profiler/advisor/rules/environment_variable_info.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b91f827ef47cb6c0894f321e617e787e138016f9 --- /dev/null +++ b/profiler/advisor/rules/environment_variable_info.yaml @@ -0,0 +1,42 @@ +ASCEND_GLOBAL_LOG_LEVEL: + desc: "log level: 0-debug, 1-info, 2-warning, 3-error.\nDefault is error level." + suggest: "Debug or info level may lead to training performance degradation,\n + recommended setting error level by execute command 'export ASCEND_GLOBAL_LOG_LEVEL=3'." +HCCL_RDMA_TC: + desc: "Configure the DSCP value of RoCE packets sent by the network port.\n + In the DS field of IP datagram header, the rightmost 6 bits are DSCP, and leftmost 2 bits are 0.\n + It should be set to DSCP * 4. Default value is 132, that is, DSCP is 33 (132=33*4)." + suggest: "Please refer to https://support.huawei.com/enterprise/zh/doc/EDOC1100371278/5eeeed85?idPath=23710424" + suggest_html: "Please refer to LINK" +HCCL_RDMA_SL: + desc: "Specify the priority of the RDMA NIC.\n + The value must be the same as the PFC priority for the NIC.\n + Otherwise, the performance may deteriorate.\n + The value range is [0, 7], and default value is 4." + suggest: "Please refer to https://support.huawei.com/enterprise/zh/doc/EDOC1100371278/5eeeed85?idPath=23710424" + suggest_html: "Please refer to LINK" +ACLNN_CACHE_LIMIT: + desc: "Number of cached aclnn operators." 
+ suggest: "Setting a large number when aclnn and host bound, such as 'export ACLNN_CACHE_LIMIT=100000'" +HOST_CACHE_CAPACITY: + desc: "Enable dynamic shape cache.\n + The default value is 0, indicating that the data cache is disabled.\n + If it is set to a non-zero positive integer, for example, 10, the system caches the execution data of 10 input shapes that frequently occur recently.\n + When the cached shapes appear again, the host execution performance will be improved, but the host memory usage increases.\n + The specific increase is proportional to the value of the HOST_CACHE_CAPACITY and size of the model." + suggest: "Setting a non-zero number, such as 'export HOST_CACHE_CAPACITY=20'" +ASCEND_ENHANCE_ENABLE: + desc: "Enable hccl ffts+ mode. 0-disable, 1-enable" + suggest: "Recommend enable hccl ffts+ mode by execute command 'export ASCEND_ENHANCE_ENABLE=1'" +PYTORCH_NPU_ALLOC_CONF: + desc: "Controlling cache allocator behavior.\n + The optional parameter is max_split_size_mb, garbage_collection_threshold and expandable_segments.\n + 1. max_split_size_mb:v--the memory block that is greater than v will be not split.\n + 2. garbage_collection_threshold:t--after the threshold is set, if the NPU memory usage exceed threshold, the cached allocator starts to reclaim memory block. The range of t is (0.0, 1.0).\n + 3. expandable_segments:True/False--The default value is False. If True, this setting instructs cache allocator to create specific memory blocks that can be expanded later to better handle frequent changes in memory usage." + suggest: "export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True" +ASCEND_LAUNCH_BLOCKING: + desc: "Whether to enable the synchronization mode during operation execution.\n + When set to 1, force the operator to run in synchronous mode, making it easier to debug and track down problems in the code.\n + If it is set to 0, the task is executed in asynchronous mode." 
+ suggest: "export ASCEND_LAUNCH_BLOCKING=1" \ No newline at end of file diff --git a/profiler/advisor/rules/packet.yaml b/profiler/advisor/rules/packet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..521822af2788457eb6b0a54d9a0123d8bef4247b --- /dev/null +++ b/profiler/advisor/rules/packet.yaml @@ -0,0 +1,14 @@ +problem: "Excessive small communication packets may cause host delivery bottlenecks.\n" +sdma_problem: "In the SDMA communication, {abnormal_sdma_ratio} of the communication data volume is less than {min_sdma_size} MB, and the total time is {abnormal_sdma_time} ms.\n" +rdma_problem: "In the RDMA communication, {abnormal_rdma_ratio} of the communication data volume is less than {min_rdma_size} MB, and the total time is {abnormal_rdma_time} ms." +min_sdma_size: 16 #M +min_rdma_size: 1 #M +min_sdma_ratio: 0.2 +min_rdma_ratio: 0.2 +solutions: + - data parallelism suggestion: + desc: "If abnormal communication is centralized in data parallelism domain, please 1.increase batch size; 2.increase gradient accumulation" + - check the memory optimization policy: + desc: "If the memory optimization policy is Zero3, it is recommended to set it to Zero2/Zero1 if memory conditions allow." + - adopt fusion operators of affinity optimizers: + desc: "using the affinity optimizers or fusion operators may reduce the number of communication operators." \ No newline at end of file diff --git a/profiler/advisor/rules/rdma_analysis.yaml b/profiler/advisor/rules/rdma_analysis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c6062775763089b04a29ac4e16f5f1c9e106ca0 --- /dev/null +++ b/profiler/advisor/rules/rdma_analysis.yaml @@ -0,0 +1,9 @@ +problem: "RDMA communication retransmission occurs. A single retransmission takes more than 4s. Retransmission problems +are detected in {group_count} communication domains. 
\n +Advised to perform the following suggestions" +min_retransmission_time: 4000 #ms +solutions: + - check RDMA transmission time: + desc: "Check whether the transmission time of the RDMA operator that is suspected to be retransmitted is correct." + - Check the network configuration.: + desc: "Check the network configuration of the switch and compute node server." \ No newline at end of file diff --git a/profiler/cli/compare_cli.py b/profiler/cli/compare_cli.py index f9add948ea9da115ab785877a26b890329771f1b..b18099897b251c4e25474b282a3c0b7104c5e92f 100644 --- a/profiler/cli/compare_cli.py +++ b/profiler/cli/compare_cli.py @@ -42,9 +42,9 @@ from profiler.compare_tools.compare_backend.comparison_generator import Comparis required=False) @click.option('--use_input_shape', is_flag=True) @click.option('--gpu_flow_cat', type=str, default='', help="Identifier of the GPU connection.") +@click.option('--base_step', type=str, default='', help="基准性能数据指定比对step") +@click.option('--comparison_step', type=str, default='', help="比较性能数据指定比对step") + def compare_cli(**kwargs) -> None: args = AnalyzeDict(kwargs) - try: - ComparisonGenerator(args).run() - except RuntimeError as e: - print(f"[ERROR] {e}") + ComparisonGenerator(args).run() diff --git a/profiler/cluster_analyse/README.md b/profiler/cluster_analyse/README.md index 4a394e09a48bf815ab88d4201e918fa5deb5c540..785056252c286abafcbef180d47f6b71b5a650f3 100644 --- a/profiler/cluster_analyse/README.md +++ b/profiler/cluster_analyse/README.md @@ -98,6 +98,12 @@ K列:Communication(Not Overlapped and Exclude Receive)指剔除recieve算 L列:Preparing,指迭代开始到首个计算或通信算子运行的时间。 +M列:DP Index,指集群数据按照并行策略切分后所属DP组的索引, 如果没有采集则不显示。 + +N列:PP Index,指集群数据按照并行策略切分后所属PP组的索引,如果没有采集则不显示。 + +O列:TP Index,指集群数据按照并行策略切分后所属TP组的索引,如果没有采集则不显示。 + **Tips**:先筛选B列type为stage, 看stage间是否有问题,再筛选B列type为rank,看rank是否有问题,根据以下几点排查。 * 根据Computing的时间差异判断是否有慢卡,或者有负载不均衡的现象。 diff --git a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py 
b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py index 6a886fffa97b142e8267066117f561154d85b162..617c0aafcb8b35ce1561b7ed5dd5449c3bb43cc8 100644 --- a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py +++ b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py @@ -19,11 +19,14 @@ from common_func.db_manager import DBManager from common_func.constant import Constant from common_func.file_manager import FileManager from prof_bean.step_trace_time_bean import StepTraceTimeBean +from cluster_utils.parallel_strategy_calculator import ParallelStrategyCalculator class StepTraceTimeAnalysis: CLUSTER_TRACE_TIME_CSV = "cluster_step_trace_time.csv" CLUSTER_TRACE_TIME_TABLE = "ClusterStepTraceTime" + PROFILER_METADATA_JSON = "profiler_metadata.json" + PARALLEL_HEADERS = ["DP Index", "PP Index", "TP Index"] def __init__(self, param: dict): self.collection_path = param.get(Constant.COLLECTION_PATH) @@ -32,6 +35,7 @@ class StepTraceTimeAnalysis: self.step_time_dict = {} self.step_data_list = [] self.data_type = param.get(Constant.DATA_TYPE) + self.distributed_args = None @staticmethod def get_max_data_row(data_group_list: list): @@ -48,8 +52,35 @@ class StepTraceTimeAnalysis: def run(self): self.load_step_trace_time_data() self.analyze_step_time() + self.partition_ranks_data() self.dump_data() + def partition_ranks_data(self): + if not self.distributed_args: + return + + calculator = ParallelStrategyCalculator(**self.distributed_args) + parallelism_map = calculator.run() + + if len(parallelism_map) > len(self.step_time_dict): + missing_rank_ids = [rank_id for rank_id in range(len(parallelism_map)) + if rank_id not in self.step_time_dict] + print(f"[WARNING] Step trace data length should equal to real rank numbers, " + f"but get step data length = {len(self.step_time_dict)}, real rank numbers = {len(parallelism_map)}, " + f"maybe lost some rank ids = {missing_rank_ids}, please check your profiling data.") + + if len(parallelism_map) < 
len(self.step_time_dict): + print(f"[ERROR] Step trace data length should equal to real rank numbers, " + f"but get step data length = {len(self.step_time_dict)}, real rank numbers = {len(parallelism_map)}, " + f"maybe parallel params in profiler_metadata.json is error, please check your metadata data.") + self.distributed_args = None + return + + for step_data in self.step_data_list: + rank_id = step_data[2] + step_data.extend(list(parallelism_map[rank_id]) + if parallelism_map[rank_id] else ['NA'] * len(self.PARALLEL_HEADERS)) + def dump_data(self): if not self.step_data_list: print("[WARNING] Can't get step time info!") @@ -74,6 +105,10 @@ class StepTraceTimeAnalysis: def load_step_trace_time_data(self): for rank_id, profiling_dir_path in self.data_map.items(): + metadata_path = os.path.join(profiling_dir_path, self.PROFILER_METADATA_JSON) + if not self.distributed_args and os.path.exists(metadata_path): + metadata = FileManager.read_json_file(metadata_path) + self.distributed_args = metadata.get(Constant.DISTRIBUTED_ARGS, None) if metadata else None if self.data_type == Constant.TEXT: step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) if os.path.exists(step_time_file): @@ -121,6 +156,8 @@ class StepTraceTimeAnalysis: def get_headers(self): if self.step_time_dict: for rank in self.step_time_dict: - if self.step_time_dict.get(rank): + if self.step_time_dict.get(rank) and self.distributed_args: + return self.step_time_dict[rank][0].all_headers + self.PARALLEL_HEADERS + elif self.step_time_dict.get(rank): return self.step_time_dict[rank][0].all_headers return [] diff --git a/profiler/cluster_analyse/cluster_utils/parallel_algorithm.py b/profiler/cluster_analyse/cluster_utils/parallel_algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..9da829bbd0feb975bc82a55adc374a5fcd1a92f6 --- /dev/null +++ b/profiler/cluster_analyse/cluster_utils/parallel_algorithm.py @@ -0,0 +1,120 @@ +# Copyright (c) 
2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod + + +class ParallelAlgorithm(ABC): + @abstractmethod + def partition(self): + pass + + +class MegatronAlgorithm(ParallelAlgorithm): + def __init__(self, + world_size: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + data_parallel_size: int = 1, + context_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + **kwargs): + + if data_parallel_size % expert_model_parallel_size != 0: + raise RuntimeError( + f"data_parallel_size is not divisible by " + f"expert_model_parallel_size, get data_parallel_size = {data_parallel_size}, " + f"expert_model_parallel_size = {expert_model_parallel_size}" + ) + + if data_parallel_size * context_parallel_size % expert_model_parallel_size != 0: + raise RuntimeError( + f"data_parallel_size * context_parallel_size {data_parallel_size * context_parallel_size} " + f"is not divisible by expert_model_parallel_size " + ) + + if world_size != tensor_model_parallel_size * pipeline_model_parallel_size * data_parallel_size: + raise RuntimeError( + f"world_size must be equal to tensor_model_parallel_size * " + f"pipeline_model_parallel_size * data_parallel_size, but get world_size = {world_size}, " + f"tensor_model_parallel_size = {tensor_model_parallel_size}, " + f"pipeline_model_parallel_size = {pipeline_model_parallel_size}, " + f"data_parallel_size = 
{data_parallel_size}" + ) + + self.world_size = world_size + self.tensor_model_parallel_size = tensor_model_parallel_size + self.pipeline_model_parallel_size = pipeline_model_parallel_size + self.data_parallel_size = data_parallel_size + self.context_parallel_size = context_parallel_size + self.expert_model_parallel_size = expert_model_parallel_size + + self.num_tensor_model_parallel_groups = self.world_size // tensor_model_parallel_size + self.num_pipeline_model_parallel_groups = self.world_size // pipeline_model_parallel_size + self.num_data_parallel_groups = self.world_size // data_parallel_size + + self.all_data_parallel_group_ranks = [] + self.all_data_parallel_group_ranks_with_cp = [] + self.all_model_parallel_group_ranks = [] + self.all_tensor_model_parallel_ranks = [] + self.all_expert_parallel_ranks = [] + self.all_pipeline_model_parallel_ranks = [] + + def partition(self): + self._build_dp_group() + self._build_tp_group() + self._build_pp_group() + self._build_ep_group() + + def _build_dp_group(self): + # Build the data-parallel groups + for i in range(self.pipeline_model_parallel_size): + begin_rank = self.num_pipeline_model_parallel_groups * i + end_rank = self.num_pipeline_model_parallel_groups * (i + 1) + for k in range(self.tensor_model_parallel_size * self.context_parallel_size): + ranks = range(begin_rank + k, + end_rank, self.tensor_model_parallel_size * self.context_parallel_size) + self.all_data_parallel_group_ranks.append(list(ranks)) + + for k in range(self.tensor_model_parallel_size): + ranks_with_cp = range(begin_rank + k, + end_rank, self.tensor_model_parallel_size) + self.all_data_parallel_group_ranks_with_cp.append(list(ranks_with_cp)) + + # Build the model-parallel groups + for i in range(self.data_parallel_size): + ranks = [data_parallel_group_ranks[i] + for data_parallel_group_ranks in self.all_data_parallel_group_ranks] + self.all_model_parallel_group_ranks.append(list(ranks)) + + def _build_tp_group(self): + # Build the tensor 
model-parallel groups. + for i in range(self.num_tensor_model_parallel_groups): + ranks = range(i * self.tensor_model_parallel_size, + (i + 1) * self.tensor_model_parallel_size) + self.all_tensor_model_parallel_ranks.append(list(ranks)) + + def _build_pp_group(self): + # Build the pipeline model-parallel groups. + for p in range(self.num_pipeline_model_parallel_groups): + ranks = range(p, self.world_size, + self.num_pipeline_model_parallel_groups) + self.all_pipeline_model_parallel_ranks.append(list(ranks)) + + def _build_ep_group(self): + # Build the expert model-parallel groups. + for dp_cp_ranks in self.all_data_parallel_group_ranks_with_cp: + for i in range(0, len(dp_cp_ranks), self.expert_model_parallel_size): + ranks = dp_cp_ranks[i:i + self.expert_model_parallel_size] + self.all_expert_parallel_ranks.append(list(ranks)) diff --git a/profiler/cluster_analyse/cluster_utils/parallel_strategy_calculator.py b/profiler/cluster_analyse/cluster_utils/parallel_strategy_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..0f0a1809d99e5a6226e7d88955286eed8bf4132c --- /dev/null +++ b/profiler/cluster_analyse/cluster_utils/parallel_strategy_calculator.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from enum import Enum +from dataclasses import dataclass + +from .parallel_algorithm import MegatronAlgorithm + + +class ParallelAlgorithmType(Enum): + Megatron = 0 + + +@dataclass +class RankMetrics: + computing: float = 0.0 + communication: float = 0.0 + free: float = 0.0 + + +class RankNode: + def __init__(self, + index: int, + rank_ids: list, + category: str, + metrics: RankMetrics): + self.index = index + self.rank_ids = rank_ids + self.category = category + self.metrics = metrics + self.children = [] + + def add_child(self, child_node): + if isinstance(child_node, RankNode): + self.children.append(child_node) + else: + raise TypeError("Child must be an instance of TreeNode") + + +class ParallelStrategyCalculator: + ROOT_LABEL = "ROOT" + TP_LABEL = "TP" + PP_LABEL = "PP" + DP_LABEL = "DP" + + parallel_algorithms = { + ParallelAlgorithmType.Megatron: MegatronAlgorithm + } + + def __init__(self, + algorithm_type: ParallelAlgorithmType = ParallelAlgorithmType.Megatron, + **kwargs): + + self.algorithm = self.parallel_algorithms.get(algorithm_type, MegatronAlgorithm)(**kwargs) + + # result of partition rank id to DP Index, PP Index, TP Index + self.ranks_ptd_map = [None] * self.algorithm.world_size + self.root_node = None + + def run(self): + self.algorithm.partition() + self._build_tree() + self._dfs(self.root_node) + return self.ranks_ptd_map + + def _build_tree(self): + if not self.algorithm.all_model_parallel_group_ranks: + return + + self.root_node = RankNode(-1, self.algorithm.all_model_parallel_group_ranks, + ParallelStrategyCalculator.ROOT_LABEL, RankMetrics()) + + # DP Level + for i, dp_group in enumerate(self.algorithm.all_model_parallel_group_ranks): + dp_node = RankNode(i, dp_group, ParallelStrategyCalculator.DP_LABEL, RankMetrics()) + + # PP Level + for pp_idx, j in enumerate(range(0, len(dp_group), self.algorithm.tensor_model_parallel_size)): + pp_group = dp_group[j:j + self.algorithm.tensor_model_parallel_size] + pp_node = RankNode(pp_idx, pp_group, 
ParallelStrategyCalculator.PP_LABEL, RankMetrics()) + + # TP Level + for k, tp_rank in enumerate(pp_group): + tp_node = RankNode(k, [tp_rank], + ParallelStrategyCalculator.TP_LABEL, RankMetrics()) + pp_node.add_child(tp_node) + + dp_node.add_child(pp_node) + self.root_node.add_child(dp_node) + + def _dfs(self, + rank_node: RankNode, + parent_node: RankNode = None, + grandparent_node: RankNode = None): + + if rank_node is None: + return + + if not rank_node.children: + if rank_node.rank_ids: + self.ranks_ptd_map[rank_node.rank_ids[0]] = ( + grandparent_node.index, # DP Index + parent_node.index, # PP Index + rank_node.index # TP Index + ) + + for child in rank_node.children: + self._dfs(child, rank_node, parent_node) diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index 2922d6a900fbbf243b61a73a13cf9caf945ec1c1..a5b93b0caaddffebdc66eaf150a04d8671bcba08 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -106,3 +106,6 @@ class Constant(object): CONFIG = "config" EXPER_CONFIG = "experimental_config" EXPORT_TYPE = "_export_type" + + # metadata key + DISTRIBUTED_ARGS = "distributed_args" diff --git a/profiler/cluster_analyse/common_func/tables_config.py b/profiler/cluster_analyse/common_func/tables_config.py index f010014519f864e627f83b99ad0df26af98af3f9..7122d6461fcbeda6b3d3a641ead1c508758d681a 100644 --- a/profiler/cluster_analyse/common_func/tables_config.py +++ b/profiler/cluster_analyse/common_func/tables_config.py @@ -59,7 +59,10 @@ class TablesConfig: ("stage", "NUMERIC, null"), ("bubble", "NUMERIC, null"), ("communication_not_overlapped_and_exclude_receive", "NUMERIC, null"), - ("preparing", "NUMERIC, null") + ("preparing", "NUMERIC, null"), + ("dp_index", "INTEGER, null"), + ("pp_index", "INTEGER, null"), + ("tp_index", "INTEGER, null") ], "HostInfoMap": [ ("hostUid", "INTEGER, null"), diff --git a/profiler/compare_tools/README.md 
b/profiler/compare_tools/README.md index b40f19e92fa130e896b69c6f59889756c518d9ff..97dcf5b19c7265f91750d62e0804fc8594b191e2 100644 --- a/profiler/compare_tools/README.md +++ b/profiler/compare_tools/README.md @@ -196,7 +196,9 @@ MindSpore场景仅支持**总体性能**和**通信性能**的对比。 | Lccl Time(Num) | Lccl算子耗时,Num表示计算的次数。 | | Computing Time | 计算流耗时,计算流所有event耗时总和。如果有多条并发计算,计算流耗时对重叠部分只会计算一次。 | | Mem Usage | 内存使用。GPU上的内存使用可以使用nvidia-smi查看,NPU上的内存使用可以使用npu-smi查看,Profiling信息采集时打开profile_memory=True开关,mem usage显示的是memory_record里面的最大resevered值,一般来说是进程级内存。 | -| Uncovered Communication Time(Wait Time) | 通信未掩盖耗时,包含Wait Time(只有采集性能数据的Level等级为L1以上并且采集NPU数据时才会存在)为同步时间。 | +| Uncovered Communication Time(Wait Time) | 通信未掩盖耗时。Wait Time为卡间等待时间(Wait Time仅NPU场景才会存在)。 | +| RDMA Bandwidth(GB/s) | RDMA带宽,单位GB/s。 | +| SDMA Bandwidth(GB/s) | SDMA带宽,单位GB/s。 | | SDMA Time(Num) | 拷贝类任务耗时,Num表示计算的次数。 | | Free Time | 调度耗时 = E2E耗时 - 算子耗时 - 通信不可掩盖耗时。Free的定义为Device侧既不在通信又不在计算的时间,因此包含拷贝时间(SDMA Time)。 | | E2E Time(Not minimal profiling) | E2E总耗时,计算流端到端耗时。当存在Not minimal profiling时,表示该时间存在性能膨胀,会影响通信和调度耗时。 | @@ -221,7 +223,7 @@ Index列字段说明: | 字段 | | | 说明 | | ---------------------------- | ------------------ | ----------------------------------- | ------------------------------------------------------------ | -| Computing Time | | | 计算流耗时,计算流所有event耗时总和。如果有多条并发计算,计算流耗时对重叠部分只会计算一次。 | +| Computing Time | | | 计算流耗时,计算流所有event耗时总和。如果有多条并发计算,计算流耗时对重叠部分只会计算一次。
NPU场景下,仅当采集性能数据的Level等级为L1及以上且aic_metrics取值为PipeUtilization时才可拆分出Computing Time的二级字段Flash Attention、Conv等。 | | | Flash Attention | | Flash Attention算子。 | | | | Flash Attention (Forward) (Cube) | Flash Attention前向算子下发的所有Cube类Kernel的总耗时,一般为执行该算子核心计算的算子。 | | | | Flash Attention (Forward) (Vector) | Flash Attention前向算子下发的所有Vector类Kernel的总耗时,一般为插入的转换类算子,如TransData。 | diff --git a/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py b/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py index 7283c17b47dea78058d0541c1332df0fa45e90d9..09d8688cf231ba713a2f731c25e1da7d54aa5ddb 100644 --- a/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py +++ b/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py @@ -64,6 +64,14 @@ class OverallPerformanceComparator(BaseComparator): else: comp_col.extend( [f'{comp_profiling_info.communication_not_overlapped: .3f}s({comp_profiling_info.wait_time:.3f}s)']) + if base_profiling_info.RDMA_bandwidth or comp_profiling_info.RDMA_bandwidth: + self._headers.extend(['RDMA Bandwidth']) + base_col.append(f'{base_profiling_info.RDMA_bandwidth:.3f}GB/s') + comp_col.append(f'{comp_profiling_info.RDMA_bandwidth:.3f}GB/s') + if base_profiling_info.SDMA_bandwidth or comp_profiling_info.SDMA_bandwidth: + self._headers.extend(['SDMA Bandwidth']) + base_col.append(f'{base_profiling_info.SDMA_bandwidth:.3f}GB/s') + comp_col.append(f'{comp_profiling_info.SDMA_bandwidth:.3f}GB/s') if base_profiling_info.sdma_time or comp_profiling_info.sdma_time: self._headers.append('SDMA Time(Num)') base_col.append(f'{base_profiling_info.sdma_time:.3f}s({base_profiling_info.sdma_num})') diff --git a/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py b/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py index 
c15396e9c597b67089acc1afd11c9f351e47b379..f29839724a64078e86eeedc59e14e50e2cf2655d 100644 --- a/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py +++ b/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py @@ -18,6 +18,7 @@ class KernelDetailsBean: self._mac_time = 0.0 self._duration = 0.0 self._start_time = Decimal("0") + self._step_id = "" self.init() @property @@ -65,6 +66,10 @@ class KernelDetailsBean: @property def end_time(self) -> Decimal: return self.start_time + convert_to_decimal(self._duration) + + @property + def step_id(self) -> int: + return int(self._step_id) if self._step_id else Constant.VOID_STEP def is_hide_op_pmu(self): if "mac_time(us)" in self._data.keys() or "aiv_vec_time(us)" in self._data.keys(): @@ -119,4 +124,5 @@ class KernelDetailsBean: self._aicore_time = self._data.get("aicore_time(us)", "") self._mac_time = self._data.get('mac_time(us)', "") self._duration = self._data.get('Duration(us)', 0) + self._step_id = self._data.get('Step Id', "") self._start_time = Decimal(self._data.get("Start Time(us)", "0")) diff --git a/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py b/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py index e0a80a4d30d0feda38d4290667df6620855d8562..c639aba5c09bb9aa531a745f2132143e23334aaf 100644 --- a/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py +++ b/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py @@ -8,31 +8,15 @@ class ProfilingInfo: def __init__(self, profiling_type: str): self.profiling_type = profiling_type - self.cube_time = 0.0 self.other_time = 0.0 - self.vec_time = 0.0 - self.cube_num = 0 - self.vec_num = 0 - self.sdma_num = 0 - self.fa_num_fwd = 0 - self.fa_num_bwd = 0 - self.pa_num = 0 self.lccl_num = 0 - self.conv_time_fwd = 0.0 - self.conv_time_bwd = 0.0 - self.conv_num_fwd = 0 - self.conv_num_bwd = 0 self.compute_time = 0.0 
self.communication_not_overlapped = 0.0 self.wait_time = 0.0 self.memory_used = 0.0 self.e2e_time = 0.0 - self.sdma_time = 0.0 self.scheduling_time = 0.0 - self.fa_time_bwd = 0.0 - self.pa_time = 0.0 self.lccl_time = 0.0 - self.fa_time_fwd = 0.0 self.minimal_profiling = False self.hide_op_details = False self.is_level0 = False @@ -76,6 +60,8 @@ class ProfilingInfo: self.other_cube_time = 0.0 self.other_cube_num = 0 + self.RDMA_bandwidth = 0.0 + self.SDMA_bandwidth = 0.0 @property def e2e_time_ms(self): @@ -136,61 +122,78 @@ class ProfilingInfo: def vector_total_num(self): return sum((self.vector_num_trans, self.vector_num_notrans)) - def trans_time_to_s(self): - self.cube_time = self.cube_time / 10 ** 6 - self.other_time = self.other_time / 10 ** 6 - self.vec_time = self.vec_time / 10 ** 6 - self.compute_time = self.compute_time / 10 ** 6 - self.communication_not_overlapped = self.communication_not_overlapped / 10 ** 6 - self.wait_time = self.wait_time / 10 ** 6 - self.e2e_time = self.e2e_time / 10 ** 6 - self.sdma_time = self.sdma_time / 10 ** 6 - self.scheduling_time = self.scheduling_time / 10 ** 6 - self.fa_time_bwd = self.fa_time_bwd / 10 ** 6 - self.fa_time_fwd = self.fa_time_fwd / 10 ** 6 - self.pa_time = self.pa_time / 10 ** 6 - self.lccl_time = self.lccl_time / 10 ** 6 - self.conv_time_fwd = self.conv_time_fwd / 10 ** 6 - self.conv_time_bwd = self.conv_time_bwd / 10 ** 6 + @property + def cube_time(self): + return ( + self.matmul_time_cube + self.matmul_time_vector + self.other_cube_time) / Constant.MILLISECONDS_TO_SECONDS - # 新指标单位为ms - self.fa_time_fwd_cube /= 10 ** 3 - self.fa_time_bwd_cube /= 10 ** 3 - self.fa_time_fwd_vector /= 10 ** 3 - self.fa_time_bwd_vector /= 10 ** 3 - self.conv_time_fwd_cube /= 10 ** 3 - self.conv_time_bwd_cube /= 10 ** 3 - self.conv_time_fwd_vector /= 10 ** 3 - self.conv_time_bwd_vector /= 10 ** 3 - self.matmul_time_cube /= 10 ** 3 - self.matmul_time_vector /= 10 ** 3 - self.vector_time_trans /= 10 ** 3 - 
self.vector_time_notrans /= 10 ** 3 - self.sdma_time_tensor_move /= 10 ** 3 - self.sdma_time_stream /= 10 ** 3 - self.page_attention_time /= 10 ** 3 - self.other_cube_time /= 10 ** 3 + @property + def vec_time(self): + return (self.vector_time_trans + self.vector_time_notrans) / Constant.MILLISECONDS_TO_SECONDS + + @property + def cube_num(self): + return self.matmul_num_cube + self.matmul_num_vector + self.other_cube_num + + @property + def vec_num(self): + return self.vector_num_trans + self.vector_num_notrans + + @property + def sdma_num(self): + return self.sdma_num_tensor_move + self.sdma_num_stream + + @property + def fa_num_fwd(self): + return self.fa_num_fwd_cube + self.fa_num_fwd_vector + @property + def fa_num_bwd(self): + return self.fa_num_bwd_cube + self.fa_num_bwd_vector + + @property + def pa_num(self): + return self.page_attention_num + + @property + def pa_time(self): + return self.page_attention_time / Constant.MILLISECONDS_TO_SECONDS + + @property + def conv_time_fwd(self): + return (self.conv_time_fwd_cube + self.conv_time_fwd_vector) / Constant.MILLISECONDS_TO_SECONDS + + @property + def conv_time_bwd(self): + return (self.conv_time_bwd_cube + self.conv_time_bwd_vector) / Constant.MILLISECONDS_TO_SECONDS + + @property + def conv_num_fwd(self): + return self.conv_num_fwd_cube + self.conv_num_fwd_vector + + @property + def conv_num_bwd(self): + return self.conv_num_bwd_cube + self.conv_num_bwd_vector + + @property + def sdma_time(self): + return (self.sdma_time_tensor_move + self.sdma_time_stream) / Constant.MILLISECONDS_TO_SECONDS + + @property + def fa_time_fwd(self): + return (self.fa_time_fwd_cube + self.fa_time_fwd_vector) / Constant.MILLISECONDS_TO_SECONDS + + @property + def fa_time_bwd(self): + return (self.fa_time_bwd_cube + self.fa_time_bwd_vector) / Constant.MILLISECONDS_TO_SECONDS def calculate_other_time(self): self.other_time = max( [0, self.compute_time - self.cube_time - self.fa_time_fwd - self.fa_time_bwd - self.pa_time - 
self.vec_time - self.conv_time_fwd - self.conv_time_bwd]) - def calculate_vec_time(self): - self.vec_time = self.compute_time - self.cube_time - self.fa_time_fwd - self.fa_time_bwd \ - - self.conv_time_fwd - self.conv_time_bwd - def calculate_schedule_time(self): self.scheduling_time = (self.e2e_time - self.compute_time - self.lccl_time - self.communication_not_overlapped) - def update_fa_fwd_info(self, time: float): - self.fa_time_fwd += time - self.fa_num_fwd += 1 - - def update_fa_bwd_info(self, time: float): - self.fa_time_bwd += time - self.fa_num_bwd += 1 - def update_fa_fwd_cube_info(self, time: float): self.fa_time_fwd_cube += time self.fa_num_fwd_cube += 1 @@ -215,22 +218,10 @@ class ProfilingInfo: self.sdma_time_stream += time self.sdma_num_stream += num - def update_pa_info(self, time: float): - self.pa_time += time - self.pa_num += 1 - def update_lccl_info(self, time: float): self.lccl_time += time self.lccl_num += 1 - def update_conv_fwd_info(self, time: float): - self.conv_time_fwd += time - self.conv_num_fwd += 1 - - def update_conv_bwd_info(self, time: float): - self.conv_time_bwd += time - self.conv_num_bwd += 1 - def update_conv_bwd_cube_info(self, time: float): self.conv_time_bwd_cube += time self.conv_num_bwd_cube += 1 @@ -267,18 +258,6 @@ class ProfilingInfo: self.vector_time_notrans += time self.vector_num_notrans += 1 - def update_sdma_info(self, time: float, num: int = 1): - self.sdma_time += time - self.sdma_num += num - - def update_cube_info(self, time: float): - self.cube_time += time - self.cube_num += 1 - - def update_vec_info(self, time: float): - self.vec_time += time - self.vec_num += 1 - def update_other_cube_info(self, time: float): self.other_cube_time += time self.other_cube_num += 1 @@ -306,3 +285,35 @@ class ProfilingInfo: def is_not_minimal_profiling(self) -> bool: return self.profiling_type == Constant.NPU and not self.minimal_profiling + + def set_RDMA_bandwidth(self, bandwidth: float): + self.RDMA_bandwidth = bandwidth + + 
def set_SDMA_bandwidth(self, bandwidth: float): + self.SDMA_bandwidth = bandwidth + + def trans_time_to_s(self): + # 新指标单位为ms + self.fa_time_fwd_cube /= Constant.MILLISECONDS_TO_SECONDS + self.fa_time_bwd_cube /= Constant.MILLISECONDS_TO_SECONDS + self.fa_time_fwd_vector /= Constant.MILLISECONDS_TO_SECONDS + self.fa_time_bwd_vector /= Constant.MILLISECONDS_TO_SECONDS + self.conv_time_fwd_cube /= Constant.MILLISECONDS_TO_SECONDS + self.conv_time_bwd_cube /= Constant.MILLISECONDS_TO_SECONDS + self.conv_time_fwd_vector /= Constant.MILLISECONDS_TO_SECONDS + self.conv_time_bwd_vector /= Constant.MILLISECONDS_TO_SECONDS + self.matmul_time_cube /= Constant.MILLISECONDS_TO_SECONDS + self.matmul_time_vector /= Constant.MILLISECONDS_TO_SECONDS + self.vector_time_trans /= Constant.MILLISECONDS_TO_SECONDS + self.vector_time_notrans /= Constant.MILLISECONDS_TO_SECONDS + self.sdma_time_tensor_move /= Constant.MILLISECONDS_TO_SECONDS + self.sdma_time_stream /= Constant.MILLISECONDS_TO_SECONDS + self.page_attention_time /= Constant.MILLISECONDS_TO_SECONDS + self.other_cube_time /= Constant.MILLISECONDS_TO_SECONDS + self.other_time /= Constant.MICROSECONDS_TO_SECONDS + self.compute_time /= Constant.MICROSECONDS_TO_SECONDS + self.communication_not_overlapped /= Constant.MICROSECONDS_TO_SECONDS + self.wait_time /= Constant.MICROSECONDS_TO_SECONDS + self.e2e_time /= Constant.MICROSECONDS_TO_SECONDS + self.scheduling_time /= Constant.MICROSECONDS_TO_SECONDS + self.lccl_time /= Constant.MICROSECONDS_TO_SECONDS diff --git a/profiler/compare_tools/compare_backend/comparison_generator.py b/profiler/compare_tools/compare_backend/comparison_generator.py index b07170b648c44f8061fb1482bdd5d2d417cbcfaf..bfbc1bb7bd22f2db4806dc4d377d77ccb8028feb 100644 --- a/profiler/compare_tools/compare_backend/comparison_generator.py +++ b/profiler/compare_tools/compare_backend/comparison_generator.py @@ -12,19 +12,32 @@ class ComparisonGenerator: INTERFACE_DICT = {Constant.OVERALL_COMPARE: OverallInterface} 
def __init__(self, args): - self._args_manager = ArgsManager() - self._args_manager.init(args) + self._args_manager = ArgsManager(args) self._data_dict = {} def run(self): - self.load_data() - self.generate_compare_result() + try: + self._args_manager.init() + self.load_data() + self.generate_compare_result() + except NotImplementedError as e: + print(f"[ERROR] {e}") + except RuntimeError as e: + print(f"[ERROR] {e}") + except FileNotFoundError as e: + print(f"[ERROR] {e}") + except Exception as e: + print(f"[ERROR] {e}") def load_data(self): self._data_dict[Constant.BASE_DATA] = self.PARSER_DICT.get(self._args_manager.base_profiling_type)( - self._args_manager.args, self._args_manager.base_path_dict).load_data() + self._args_manager.args, + self._args_manager.base_path_dict, + self._args_manager.base_step).load_data() self._data_dict[Constant.COMPARISON_DATA] = self.PARSER_DICT.get(self._args_manager.comparison_profiling_type)( - self._args_manager.args, self._args_manager.comparison_path_dict).load_data() + self._args_manager.args, + self._args_manager.comparison_path_dict, + self._args_manager.comparison_step).load_data() def generate_compare_result(self): overall_data = {Constant.BASE_DATA: self._data_dict.get(Constant.BASE_DATA).overall_metrics, @@ -37,8 +50,18 @@ class ComparisonGenerator: generator.join() def run_interface(self, compare_type: str) -> dict: - self.load_data() - interface = self.INTERFACE_DICT.get(compare_type) - if interface: - return interface(self._data_dict).run() + try: + self._args_manager.init() + self.load_data() + interface = self.INTERFACE_DICT.get(compare_type) + if interface: + return interface(self._data_dict).run() + except NotImplementedError as e: + print(f"[ERROR] {e}") + except RuntimeError as e: + print(f"[ERROR] {e}") + except FileNotFoundError as e: + print(f"[ERROR] {e}") + except Exception as e: + print(f"[ERROR] {e}") return {} diff --git a/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py 
b/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py index 3106527c41997287c0457d0e2f555537c79e9a50..2df9ae43e957e9f91b1192a6413c88f524f89650 100644 --- a/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py +++ b/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py @@ -1,36 +1,48 @@ from compare_backend.profiling_parser.base_profiling_parser import ProfilingResult from compare_backend.utils.tree_builder import TreeBuilder - +from compare_backend.utils.constant import Constant class OperatorDataPrepare: - def __init__(self, profiling_data: ProfilingResult): + def __init__(self, profiling_data: ProfilingResult, specified_step_id: int = Constant.VOID_STEP): self.profiling_data = profiling_data + self._all_nodes = self._build_tree() + self._root_node = self._all_nodes[0] + self._specified_step_id = specified_step_id def get_top_layer_ops(self) -> any: - root_node = TreeBuilder.build_tree(self.profiling_data.torch_op_data, self.profiling_data.kernel_dict, - self.profiling_data.memory_list) - level1_child_nodes = root_node.child_nodes - result_data = [] - for level1_node in level1_child_nodes: - if level1_node.is_step_profiler(): - result_data.extend(level1_node.child_nodes) - else: - result_data.append(level1_node) - return result_data + if len(self._all_nodes) < 1: + return [] + return self._get_top_layers_ops_from_root_node(self._root_node.child_nodes) def get_all_layer_ops(self) -> any: - root_node = TreeBuilder.build_tree(self.profiling_data.torch_op_data, [], []) - level1_child_nodes = root_node.child_nodes - node_queue = [] result_data = [] - for level1_node in level1_child_nodes: - if level1_node.is_step_profiler(): - node_queue.extend(level1_node.child_nodes) - else: - node_queue.append(level1_node) + if len(self._all_nodes) < 1: + return result_data + if self._specified_step_id == Constant.VOID_STEP: + return list(filter(lambda x: not x.is_step_profiler(), self._all_nodes[1:])) + 
node_queue = self._get_top_layers_ops_from_root_node(self._root_node.child_nodes) while len(node_queue) > 0: node = node_queue.pop(0) result_data.append(node) if node.child_nodes: node_queue.extend(node.child_nodes) + return result_data + + def _build_tree(self): + return TreeBuilder.build_tree(self.profiling_data.torch_op_data, self.profiling_data.kernel_dict, + self.profiling_data.memory_list) + + def _get_top_layers_ops_from_root_node(self, top_layers_nodes: list) -> list: + result_data = [] + for level1_node in top_layers_nodes: + if self._specified_step_id == Constant.VOID_STEP: + if level1_node.is_step_profiler(): + result_data.extend(level1_node.child_nodes) + else: + result_data.append(level1_node) + elif level1_node.is_step_profiler() and level1_node.get_step_id() == self._specified_step_id: + result_data.extend(level1_node.child_nodes) + if not result_data and self._specified_step_id != Constant.VOID_STEP: + print(f"[WARNING] There is no operator information for step {self._specified_step_id}, " \ + "please check whether the data contains this step.") return result_data \ No newline at end of file diff --git a/profiler/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py b/profiler/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..c04d4c2b699e7dc5ac6d76e6aa4ef40229558c30 --- /dev/null +++ b/profiler/compare_tools/compare_backend/data_prepare/sequence_pre_matching.py @@ -0,0 +1,162 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque + +from compare_backend.utils.name_function import NameFunction +from compare_backend.utils.common_func import longest_common_subsequence_matching +from compare_backend.utils.torch_op_node import TorchOpNode +from compare_backend.utils.module_node import ModuleNode + +from compare_backend.utils.constant import Constant + + +class SequencePreMatching: + OP_TYPE = 1 + MODULE_TYPE = 2 + + def __init__(self, args, base_bwd_tid=None, comparison_bwd_tid=None): + self._args = args + self._base_bwd_tid = base_bwd_tid + self._comparison_bwd_tid = comparison_bwd_tid + + @staticmethod + def _match_none_subsequence(base_ops: list, comparison_ops: list) -> list: + op_compare_result = [[op, None] for op in iter(base_ops)] + op_compare_result.extend([[None, op] for op in iter(comparison_ops)]) + return op_compare_result + + @staticmethod + def _split_operator_data(data_list, bwd_tid): + split_result = [] + if not data_list: + return split_result + data_list.sort(key=lambda x: x.start_time) + pre_tid = data_list[0].tid + part_data_dict = {Constant.IS_BWD: pre_tid == bwd_tid, Constant.OPS: []} + for op in data_list: + if op.tid == pre_tid or (pre_tid != bwd_tid and op.tid != bwd_tid): + part_data_dict[Constant.OPS].append(op) + else: + split_result.append(part_data_dict) + part_data_dict = {Constant.IS_BWD: op.tid == bwd_tid, Constant.OPS: [op]} + pre_tid = op.tid + split_result.append(part_data_dict) + return split_result + + def run(self, matching_type, base_data, comparison_data): + if matching_type == self.MODULE_TYPE: + return 
self._match_nn_module(base_data, comparison_data) + + if self._base_bwd_tid is None or self._comparison_bwd_tid is None: + return self._match_torch_op(base_data, comparison_data) + + base_data = self._split_operator_data(base_data, self._base_bwd_tid) + comparison_data = self._split_operator_data(comparison_data, self._comparison_bwd_tid) + if not base_data: + comparison_data_list = [] + for data in comparison_data: + comparison_data_list.extend(data.get(Constant.OPS, [])) + return self._match_torch_op([], comparison_data_list) + if not comparison_data: + base_data_list = [] + for data in base_data: + base_data_list.extend(data.get(Constant.OPS, [])) + return self._match_torch_op(base_data_list, []) + + result_data = [] + base_data_len, comparison_data_len = len(base_data), len(comparison_data) + if base_data[0].get(Constant.IS_BWD) == comparison_data[0].get(Constant.IS_BWD): + base_index, comparison_index = 0, 0 + elif base_data_len > comparison_data_len: + result_data.extend(self._match_torch_op(base_data[0].get(Constant.OPS, []), [])) + base_index, comparison_index = 1, 0 + else: + result_data.extend(self._match_torch_op([], comparison_data[0].get(Constant.OPS, []))) + base_index, comparison_index = 0, 1 + while base_index < base_data_len: + comparison_ops = [] if comparison_index >= comparison_data_len else comparison_data[ + comparison_index].get(Constant.OPS, []) + result_data.extend(self._match_torch_op(base_data[base_index].get(Constant.OPS, []), comparison_ops)) + base_index += 1 + comparison_index += 1 + while comparison_index < comparison_data_len: + result_data.extend(self._match_torch_op([], comparison_data[0].get(Constant.OPS, []))) + comparison_index += 1 + return result_data + + def _match_torch_op(self, base_ops, comparison_ops) -> list: + if not base_ops and not comparison_ops: + return [] + name_func = NameFunction(self._args).get_name_func() + op_compare_result = longest_common_subsequence_matching(base_ops, comparison_ops, name_func) \ + if not 
self._args.disable_details else self._match_none_subsequence(base_ops, comparison_ops) + if self._args.max_kernel_num is not None: + op_compare_result = self._drill_down(op_compare_result, name_func) + return op_compare_result + + def _drill_down(self, compare_result_data: list, name_func: any) -> list: + drill_down_result = [] + compare_result_data.reverse() + op_deque = deque(compare_result_data) + while op_deque: + match_data = op_deque.pop() + base_op = match_data[0] if match_data[0] else TorchOpNode() + comparison_op = match_data[1] if match_data[1] else TorchOpNode() + if not base_op.child_nodes or not comparison_op.child_nodes: + drill_down_result.append(match_data) + continue + if max(base_op.kernel_num, comparison_op.kernel_num) <= self._args.max_kernel_num: + drill_down_result.append(match_data) + continue + match_list = longest_common_subsequence_matching(base_op.child_nodes, + comparison_op.child_nodes, + name_func) \ + if not self._args.disable_details else self._match_none_subsequence(base_op.child_nodes, + comparison_op.child_nodes) + match_list.reverse() + op_deque.extend(match_list) + + return drill_down_result + + def _match_nn_module(self, base_root_node, comparison_root_node) -> list: + module_compare_result = [] + for index, base_node in enumerate(base_root_node): + comparison_node = comparison_root_node[index] if index < len(comparison_root_node) else None + if not base_node or not comparison_node: + continue + module_compare_result.extend(self._matching_all_modules(base_node, comparison_node)) + return module_compare_result + + def _matching_all_modules(self, base_node: ModuleNode, comparison_node: ModuleNode): + all_matched_modules = [] + matched_queue = deque() + matched_queue.append([base_node, comparison_node]) + while matched_queue: + matched_base_node, matched_comparison_node = matched_queue.popleft() + matched_node_list = self._matching_common_subsequence(matched_base_node, matched_comparison_node) + 
all_matched_modules.extend(matched_node_list) + for matched_node in matched_node_list: + matched_queue.append(matched_node) + return all_matched_modules + + def _matching_common_subsequence(self, base_node: ModuleNode, comparison_node: ModuleNode): + base_modules = base_node.child_nodes if base_node else [] + comparison_modules = comparison_node.child_nodes if comparison_node else [] + if not base_modules and not comparison_modules: + return [] + name_func = NameFunction(self._args).get_module_name + result = longest_common_subsequence_matching(base_modules, comparison_modules, name_func) \ + if not self._args.disable_details else self._match_none_subsequence(base_modules, comparison_modules) + return result diff --git a/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py b/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py index 7bac2b0335329a2ef9b5e13e21feeedcf569246d..65524664ee0cde85a4ac045475b2442c8a7da396 100644 --- a/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py +++ b/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py @@ -15,9 +15,18 @@ class OverallPerfInterface: self._result_data = {} def run(self): - self._check_path() - self._load_data() - self._generate_result() + try: + self._check_path() + self._load_data() + self._generate_result() + except NotImplementedError as e: + print(f"[ERROR] {e}") + except RuntimeError as e: + print(f"[ERROR] {e}") + except FileNotFoundError as e: + print(f"[ERROR] {e}") + except Exception as e: + print(f"[ERROR] {e}") return self._result_data def _check_path(self): diff --git a/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py b/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py index 6fe693fb0675cc91820f859c476c61999054ec25..916c426c63010fcfd7cad9da52b046c34d66e477 100644 --- 
a/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py +++ b/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py @@ -1,7 +1,5 @@ import os -from collections import deque from datetime import datetime -from queue import Queue from compare_backend.comparator.communication_comparator import CommunicationComparator from compare_backend.comparator.module_comparetor import ModuleComparator @@ -24,39 +22,27 @@ from compare_backend.compare_bean.overall_metrics_bean import OverallMetricsBean from compare_backend.data_prepare.module_data_prepare import ModuleDataPrepare from compare_backend.data_prepare.operator_data_prepare import OperatorDataPrepare from compare_backend.generator.base_generator import BaseGenerator -from compare_backend.utils.common_func import longest_common_subsequence_matching from compare_backend.utils.constant import Constant -from compare_backend.utils.module_node import ModuleNode -from compare_backend.utils.name_function import NameFunction -from compare_backend.utils.torch_op_node import TorchOpNode from compare_backend.view.excel_view import ExcelView +from compare_backend.data_prepare.sequence_pre_matching import SequencePreMatching + class DetailPerformanceGenerator(BaseGenerator): def __init__(self, profiling_data_dict: dict, args: any): super().__init__(profiling_data_dict, args) - - @classmethod - def _match_none_subsequence(cls, base_ops: list, comparison_ops: list) -> list: - op_compare_result = [[op, None] for op in iter(base_ops)] - op_compare_result.extend([[None, op] for op in iter(comparison_ops)]) - return op_compare_result + self._base_step_id = int(args.base_step) if args.base_step else Constant.VOID_STEP + self._comparison_step_id = int(args.comparison_step) if args.comparison_step else Constant.VOID_STEP def compare(self): enable_compare = [self._args.enable_operator_compare, self._args.enable_memory_compare, self._args.enable_communication_compare, 
self._args.enable_api_compare, - self._args.enable_kernel_compare] + self._args.enable_kernel_compare, self._args.enable_profiling_compare] if any(enable_compare): print("[INFO] Start to compare performance detail data, please wait.") comparator_list = self._create_comparator() else: comparator_list = [] - if self._args.enable_profiling_compare: - overall_data = {Constant.BASE_DATA: self._profiling_data_dict.get(Constant.BASE_DATA).overall_metrics, - Constant.COMPARISON_DATA: self._profiling_data_dict.get( - Constant.COMPARISON_DATA).overall_metrics} - # overall 数据在最前面 - comparator_list.insert(0, OverallMetricsComparator(overall_data, OverallMetricsBean)) for comparator in comparator_list: self._result_data.update(comparator.generate_data()) @@ -71,45 +57,62 @@ class DetailPerformanceGenerator(BaseGenerator): def _create_comparator(self): comparator_list = [] - - op_compare_result = [] - - if self._args.enable_operator_compare: - module_compare_result = self.match_nn_module() if self._profiling_data_dict.get( - Constant.BASE_DATA).python_function_data and self._profiling_data_dict.get( - Constant.COMPARISON_DATA).python_function_data else [] - if not module_compare_result: - op_compare_result = self.match_torch_op() - - if self._args.enable_memory_compare and not op_compare_result: - op_compare_result = self.match_torch_op() - + # 总体性能拆解 + if self._args.enable_profiling_compare: + overall_data = { + Constant.BASE_DATA: self._profiling_data_dict.get(Constant.BASE_DATA).overall_metrics, + Constant.COMPARISON_DATA: self._profiling_data_dict.get(Constant.COMPARISON_DATA).overall_metrics + } + comparator_list.append(OverallMetricsComparator(overall_data, OverallMetricsBean)) + # 通信性能比对 if self._args.enable_communication_compare: communication_data = { Constant.BASE_DATA: self._profiling_data_dict.get(Constant.BASE_DATA).communication_dict, Constant.COMPARISON_DATA: self._profiling_data_dict.get(Constant.COMPARISON_DATA).communication_dict} 
comparator_list.append(CommunicationComparator(communication_data, CommunicationBean)) + # 算子性能比对-module级 + enable_operator_compare = False if self._args.enable_operator_compare: + module_compare_result = self._module_match() if module_compare_result: comparator_list.append(ModuleStatisticComparator(module_compare_result, ModuleStatisticBean)) if not self._args.disable_details: comparator_list.append(ModuleComparator(module_compare_result, ModuleCompareBean)) else: - comparator_list.append(OperatorStatisticComparator(op_compare_result, OperatorStatisticBean)) - if not self._args.disable_details: - comparator_list.append(OperatorComparator(op_compare_result, OperatorCompareBean)) + enable_operator_compare = True + + # build tree for operator_compare memory_compare and api_compare + base_op_prepare, comparison_op_prepare = None, None + if self._args.enable_memory_compare or self._args.enable_api_compare or enable_operator_compare: + base_op_prepare = OperatorDataPrepare(self._profiling_data_dict.get(Constant.BASE_DATA), + self._base_step_id) + comparison_op_prepare = OperatorDataPrepare(self._profiling_data_dict.get(Constant.COMPARISON_DATA), + self._comparison_step_id) + + # 算子性能比对-operator级 + op_compare_result = [] + if enable_operator_compare: + op_compare_result = self._operator_match(base_op_prepare.get_top_layer_ops(), + comparison_op_prepare.get_top_layer_ops()) + comparator_list.append(OperatorStatisticComparator(op_compare_result, OperatorStatisticBean)) + if not self._args.disable_details: + comparator_list.append(OperatorComparator(op_compare_result, OperatorCompareBean)) + # 算子内存比对 if self._args.enable_memory_compare: + if not op_compare_result: + op_compare_result = self._operator_match(base_op_prepare.get_top_layer_ops(), + comparison_op_prepare.get_top_layer_ops()) comparator_list.append(OperatorStatisticComparator(op_compare_result, MemoryStatisticBean)) if not self._args.disable_details: comparator_list.append(OperatorComparator(op_compare_result, 
MemoryCompareBean)) + # host api比对 if self._args.enable_api_compare: api_compare_result = { - Constant.BASE_DATA: OperatorDataPrepare( - self._profiling_data_dict.get(Constant.BASE_DATA)).get_all_layer_ops(), - Constant.COMPARISON_DATA: OperatorDataPrepare( - self._profiling_data_dict.get(Constant.COMPARISON_DATA)).get_all_layer_ops()} + Constant.BASE_DATA: base_op_prepare.get_all_layer_ops(), + Constant.COMPARISON_DATA: comparison_op_prepare.get_all_layer_ops()} comparator_list.append(ApiCompareComparator(api_compare_result, ApiCompareBean)) + # kernel比对 if self._args.enable_kernel_compare: kernel_compare_result = { Constant.BASE_DATA: self._profiling_data_dict.get(Constant.BASE_DATA).kernel_details, @@ -117,74 +120,19 @@ class DetailPerformanceGenerator(BaseGenerator): comparator_list.append(KernelCompareComparator(kernel_compare_result, KernelCompareBean)) return comparator_list - def match_torch_op(self) -> list: - base_ops = OperatorDataPrepare(self._profiling_data_dict.get(Constant.BASE_DATA)).get_top_layer_ops() - comparison_ops = OperatorDataPrepare( - self._profiling_data_dict.get(Constant.COMPARISON_DATA)).get_top_layer_ops() - if not base_ops and not comparison_ops: + def _module_match(self): + if not self._profiling_data_dict.get(Constant.BASE_DATA).python_function_data or not \ + self._profiling_data_dict.get(Constant.COMPARISON_DATA).python_function_data: return [] - name_func = NameFunction(self._args).get_name_func() - op_compare_result = longest_common_subsequence_matching(base_ops, comparison_ops, name_func) \ - if not self._args.disable_details else self._match_none_subsequence(base_ops, comparison_ops) - if self._args.max_kernel_num is not None: - op_compare_result = self._drill_down(op_compare_result, name_func) - return op_compare_result - - def _drill_down(self, compare_result_data: list, name_func: any) -> list: - drill_down_result = [] - compare_result_data.reverse() - op_deque = deque(compare_result_data) - while op_deque: - match_data = 
op_deque.pop() - base_op = match_data[0] if match_data[0] else TorchOpNode() - comparison_op = match_data[1] if match_data[1] else TorchOpNode() - if not base_op.child_nodes or not comparison_op.child_nodes: - drill_down_result.append(match_data) - continue - if max(base_op.kernel_num, comparison_op.kernel_num) <= self._args.max_kernel_num: - drill_down_result.append(match_data) - continue - match_list = longest_common_subsequence_matching(base_op.child_nodes, - comparison_op.child_nodes, - name_func) \ - if not self._args.disable_details else self._match_none_subsequence(base_op.child_nodes, - comparison_op.child_nodes) - match_list.reverse() - for data in match_list: - op_deque.append(data) - - return drill_down_result - - def match_nn_module(self) -> list: - module_compare_result = [] - base_root_node = ModuleDataPrepare(self._profiling_data_dict.get(Constant.BASE_DATA)).build_module_tree() + base_root_node = ModuleDataPrepare( + self._profiling_data_dict.get(Constant.BASE_DATA)).build_module_tree() comparison_root_node = ModuleDataPrepare( self._profiling_data_dict.get(Constant.COMPARISON_DATA)).build_module_tree() - for index, base_node in enumerate(base_root_node): - comparison_node = comparison_root_node[index] if index < len(comparison_root_node) else None - if not base_node or not comparison_node: - continue - module_compare_result.extend(self._matching_all_modules(base_node, comparison_node)) - return module_compare_result - - def _matching_all_modules(self, base_node: ModuleNode, comparison_node: ModuleNode): - all_matched_modules = [] - matched_queue = Queue() - matched_queue.put([base_node, comparison_node]) - while not matched_queue.empty(): - matched_base_node, matched_comparison_node = matched_queue.get() - matched_node_list = self._matching_common_subsequence(matched_base_node, matched_comparison_node) - all_matched_modules.extend(matched_node_list) - for matched_node in matched_node_list: - matched_queue.put(matched_node) - return 
all_matched_modules - - def _matching_common_subsequence(self, base_node: ModuleNode, comparison_node: ModuleNode): - base_modules = base_node.child_nodes if base_node else [] - comparison_modules = comparison_node.child_nodes if comparison_node else [] - if not base_modules and not comparison_modules: - return [] - name_func = NameFunction(self._args).get_module_name - result = longest_common_subsequence_matching(base_modules, comparison_modules, name_func) \ - if not self._args.disable_details else self._match_none_subsequence(base_modules, comparison_modules) - return result + return SequencePreMatching(self._args).run(SequencePreMatching.MODULE_TYPE, base_root_node, + comparison_root_node) + + def _operator_match(self, base_ops, comparison_ops): + base_bwd_tid = self._profiling_data_dict.get(Constant.BASE_DATA).bwd_tid + comparison_bwd_tid = self._profiling_data_dict.get(Constant.COMPARISON_DATA).bwd_tid + return SequencePreMatching(self._args, base_bwd_tid, comparison_bwd_tid).run(SequencePreMatching.OP_TYPE, + base_ops, comparison_ops) diff --git a/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py b/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py index 9daaa55ef163b157d4f200cbe039a562a865d72f..6afc52ff9523e39e92cbdfe397967a3c01add418 100644 --- a/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py +++ b/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py @@ -21,6 +21,7 @@ class ProfilingResult: self.python_function_data = [] self.fwdbwd_dict = {} self.kernel_details = {} + self.bwd_tid = None def update_torch_op_data(self, event: TraceEventBean): event.is_torch_op = True @@ -44,14 +45,17 @@ class ProfilingResult: def update_comm_task_data(self, comm_name: str, task_event: TraceEventBean): self.communication_dict.setdefault(comm_name, {}).setdefault("comm_task", {}).setdefault( task_event.name, []).append(task_event.dur) - + def 
update_kernel_details(self, kernels: dict): self.kernel_details = kernels + def update_bwd_tid(self, bwd_tid): + self.bwd_tid = bwd_tid + class BaseProfilingParser(ABC): - def __init__(self, args: any, path_dict: dict): + def __init__(self, args: any, path_dict: dict, step_id: int = Constant.VOID_STEP): self._args = args self._profiling_type = path_dict.get(Constant.PROFILING_TYPE) self._profiling_path = path_dict.get(Constant.PROFILING_PATH) @@ -76,6 +80,7 @@ class BaseProfilingParser(ABC): self._categorize_performance_index = 0 self._cpu_cube_op = None self._bwd_tid = None + self._step_id = step_id @property def cpu_cube_op(self): @@ -115,6 +120,10 @@ class BaseProfilingParser(ABC): raise NotImplementedError("Function _get_dispatch_func need to be implemented.") def load_data(self) -> ProfilingResult: + self._result_data.update_bwd_tid(self._bwd_tid) + if self._step_id != Constant.VOID_STEP and self._profiling_type == Constant.GPU: + msg = "[WARNING] step id is invalid in GPU data, please use this when comparing between NPU data." 
+ raise RuntimeError(msg) self._dispatch_events() self._update_kernel_dict() self._update_communication_dict() diff --git a/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py b/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py index 0aeeba83efb1ec62b0cf53ced7084dcccb7aa6c8..65fcc092f9b1d414d6294f965a91d67d167121fb 100644 --- a/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py +++ b/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py @@ -13,8 +13,8 @@ class GPUProfilingParser(BaseProfilingParser): FLOW_CAT = ("async_gpu", "async_cpu_to_gpu", "ac2g", "async") TORCH_OP_CAT = ("cpu_op", "user_annotation", "cuda_runtime", "operator", "runtime") - def __init__(self, args: any, path_dict: dict): - super().__init__(args, path_dict) + def __init__(self, args: any, path_dict: dict, step_id: int = Constant.VOID_STEP): + super().__init__(args, path_dict, step_id) self._trace_events = [TraceEventBean(event) for event in self._trace_events.get("traceEvents", [])] self._flow_cat = (args.gpu_flow_cat,) if args.gpu_flow_cat else self.FLOW_CAT self._compute_stream_id = self._infer_compute_stream_id() @@ -61,7 +61,6 @@ class GPUProfilingParser(BaseProfilingParser): def _update_overall_metrics(self): self._calculate_performance_time() self.__parse_memory_reserved() - self._result_data.overall_metrics.calculate_vec_time() self._result_data.overall_metrics.calculate_schedule_time() self._result_data.overall_metrics.trans_time_to_s() @@ -76,7 +75,6 @@ class GPUProfilingParser(BaseProfilingParser): min_ts = min(event.start_time, min_ts) max_ts = max(event.end_time, max_ts) if event.stream == self._compute_stream_id and self.__is_sdma_time(event.name): - self._result_data.overall_metrics.update_sdma_info(event.dur) self._result_data.overall_metrics.update_sdma_stream_info(event.dur) continue if not event.is_kernel_cat(): @@ -84,7 +82,6 @@ class 
GPUProfilingParser(BaseProfilingParser): self.__add_marks(event) if event.is_nccl_name(): continue - self.__add_compute_time(event, aten_events, flow_dict_new) self.categorize_computing_performance_data(event, flow_dict_new) self._aten_events = None self._result_data.overall_metrics.set_e2e_time(float(max_ts - min_ts)) @@ -104,23 +101,6 @@ class GPUProfilingParser(BaseProfilingParser): for timestep in range(int(event.start_time + 1), int(event.end_time + 1)): self._marks[str(timestep)] += -100 # mark this timestep in compute stream - def __add_compute_time(self, event: TraceEventBean, aten_events: list, flow_dict_new: dict): - if self.__is_flash_attention(event.name): - if event.is_backward(): - self._result_data.overall_metrics.update_fa_bwd_info(event.dur) - else: - self._result_data.overall_metrics.update_fa_fwd_info(event.dur) - elif any(cube_mark in event.lower_name for cube_mark in self.CUBE_MARK): - is_conv = self.__check_is_conv(event, aten_events, flow_dict_new) - if is_conv == "conv_fwd": - self._result_data.overall_metrics.update_conv_fwd_info(event.dur) - elif is_conv == "conv_bwd": - self._result_data.overall_metrics.update_conv_bwd_info(event.dur) - else: - self._result_data.overall_metrics.update_cube_info(event.dur) - else: - self._result_data.overall_metrics.update_vec_info(event.dur) - def __check_is_conv(self, event: TraceEventBean, aten_events: list, flow_dict_new: dict) -> str: flow_start_time = flow_dict_new.get(event.start_time) if not flow_start_time: diff --git a/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py b/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py index cb25c252c6c825cb22fea63a4c1ecc82f9c61e57..b763d8c9b5febc9483038236873b702a2662c07e 100644 --- a/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py +++ b/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py @@ -17,11 +17,12 @@ class 
NPUProfilingParser(BaseProfilingParser): ACTIVE_CPU = "ProfilerActivity.CPU" LEVEL_0 = "Level0" - def __init__(self, args: any, path_dict: dict): - super().__init__(args, path_dict) + def __init__(self, args: any, path_dict: dict, step_id: int = Constant.VOID_STEP): + super().__init__(args, path_dict, step_id) self._operator_memory_path = os.path.join(path_dict.get(Constant.ASCEND_OUTPUT_PATH, ""), "operator_memory.csv") self._memory_record_path = os.path.join(path_dict.get(Constant.ASCEND_OUTPUT_PATH, ""), "memory_record.csv") self._kernel_detail_path = os.path.join(path_dict.get(Constant.ASCEND_OUTPUT_PATH, ""), "kernel_details.csv") + self._communication_path = os.path.join(path_dict.get(Constant.ASCEND_OUTPUT_PATH, ""), "communication.json") self._info_json_path = path_dict.get(Constant.INFO_JSON_PATH, "") self._trace_events = [TraceEventBean(event) for event in self._trace_events] self._hccl_pid = None @@ -71,11 +72,17 @@ class NPUProfilingParser(BaseProfilingParser): for kernel in kernel_details: if kernel.is_invalid(): continue + if self._step_id != Constant.VOID_STEP and kernel.step_id != self._step_id: + continue input_shapes = kernel.input_shapes if kernel.input_shapes else 'N/A' kernels_dict.setdefault(kernel.op_type, {}).setdefault(input_shapes, []).append( [kernel.name, kernel.duration]) - if len(kernels_dict) == 1: - print("[ERROR] Failed to enable enable_kernel_compare, type of kernel_details.csv is null.") + if not kernels_dict: + if self._step_id != Constant.VOID_STEP: + print(f"[ERROR] There is no kernel details information for step {self._step_id}," \ + " please check whether the data contains this step.") + else: + print("[ERROR] Failed to enable enable_kernel_compare, type of kernel_details.csv is null.") return self._result_data.update_kernel_details(kernels_dict) @@ -121,6 +128,35 @@ class NPUProfilingParser(BaseProfilingParser): return self._dequeue_data[left].corr_id if self._dequeue_data[left].start_time <= ts_time <= \ 
self._dequeue_data[left].end_time else Constant.INVALID_VALUE + def _update_bandwidth(self): + try: + communication_json = FileReader.read_trace_file(self._communication_path) + except FileNotFoundError: + print("[WARNING] The file communication.json does not exist.") + return + except Exception: + print("[ERROR] Failed to read communication.json.") + return + if not communication_json: + print("[WARNING] The communication.json file is empty.") + return + for _, group_dict in communication_json.items(): + step_dict = group_dict.get("collective", {}) + total_op_info = step_dict.get("Total Op Info", {}) + rdma_size_mb = rdma_time_ms = sdma_size_mb = sdma_time_ms = 0 + if "Communication Bandwidth Info" in total_op_info: + bandwidth_info = total_op_info["Communication Bandwidth Info"] + if "RDMA" in bandwidth_info: + rdma_info = bandwidth_info["RDMA"] + rdma_size_mb += rdma_info.get("Transit Size(MB)", 0) # 单位为 MB + rdma_time_ms += rdma_info.get("Transit Time(ms)", 0) # 单位为 MS + if "SDMA" in bandwidth_info: + sdma_info = bandwidth_info["SDMA"] + sdma_size_mb += sdma_info.get("Transit Size(MB)", 0) # 单位为 MB + sdma_time_ms += sdma_info.get("Transit Time(ms)", 0) # 单位为 MS + rdma_bandwidth = rdma_size_mb / rdma_time_ms if rdma_time_ms > 0 else 0 + sdma_bandwidth = sdma_size_mb / sdma_time_ms if sdma_time_ms > 0 else 0 + self._result_data.overall_metrics.set_RDMA_bandwidth(rdma_bandwidth) + self._result_data.overall_metrics.set_SDMA_bandwidth(sdma_bandwidth) def _update_overall_metrics(self): self.__parse_info_json() self.__parse_mem_csv() @@ -133,7 +169,7 @@ class NPUProfilingParser(BaseProfilingParser): self._result_data.overall_metrics.calculate_other_time() self._result_data.overall_metrics.calculate_schedule_time() self._result_data.overall_metrics.trans_time_to_s() - + self._update_bandwidth() def _picking_notify_wait_event_and_not_overlap_event(self): self.notify_event_cache = [] self._not_overlaped_commu_event = [] 
NPUProfilingParser(BaseProfilingParser): self._result_data.overall_metrics.update_lccl_info(event.dur) def __parse_kernel_csv(self): - def __screen_data(kernel: KernelDetailsBean): - if kernel.is_flash_attention(): - if kernel.is_fa_bwd(): - self._result_data.overall_metrics.update_fa_bwd_info(kernel.duration) - else: - self._result_data.overall_metrics.update_fa_fwd_info(kernel.duration) - elif kernel.is_conv(): - if kernel.is_conv_bwd(): - self._result_data.overall_metrics.update_conv_bwd_info(kernel.duration) - else: - self._result_data.overall_metrics.update_conv_fwd_info(kernel.duration) - elif kernel.is_matmul(): - self._result_data.overall_metrics.update_cube_info(kernel.duration) - elif kernel.is_sdma(): - self._result_data.overall_metrics.update_sdma_info(kernel.duration) - elif kernel.is_page_attention(): - self._result_data.overall_metrics.update_pa_info(kernel.duration) - elif kernel.is_vector(): - self._result_data.overall_metrics.update_vec_info(kernel.duration) - else: - self._result_data.overall_metrics.update_cube_info(kernel.duration) - try: kernel_details = FileReader.read_csv_file(self._kernel_detail_path, KernelDetailsBean) except Exception: @@ -306,7 +320,6 @@ class NPUProfilingParser(BaseProfilingParser): for kernel in kernel_details: if kernel.is_invalid(): continue - __screen_data(kernel) self.categorize_computing_performance_data(kernel, flow_dict_new) def __parse_mem_csv(self): @@ -353,5 +366,4 @@ class NPUProfilingParser(BaseProfilingParser): compute_stream = event_wait_stream & ai_core_stream if event_wait_stream else ai_core_stream for stream in compute_stream: dur_list = sdma_dict.get(stream, []) - self._result_data.overall_metrics.update_sdma_info(sum(dur_list), len(dur_list)) self._result_data.overall_metrics.update_sdma_stream_info(sum(dur_list), len(dur_list)) diff --git a/profiler/compare_tools/compare_backend/utils/args_manager.py b/profiler/compare_tools/compare_backend/utils/args_manager.py index 
ab9fb43a9681a5c5a15d280f1052f493ae8dfcde..69136c4d7e8e3b826c5503bc400b24198037db9a 100644 --- a/profiler/compare_tools/compare_backend/utils/args_manager.py +++ b/profiler/compare_tools/compare_backend/utils/args_manager.py @@ -11,19 +11,21 @@ class Singleton(object): self._cls = cls self._instance = {} - def __call__(self): + def __call__(self, args): if self._cls not in self._instance: - self._instance[self._cls] = self._cls() + self._instance[self._cls] = self._cls(args) return self._instance[self._cls] @Singleton class ArgsManager: - def __init__(self): - self._args = None + def __init__(self, args: any): + self._args = args self._base_path_dict = {} self._comparison_path_dict = {} + self._base_step = Constant.VOID_STEP + self._comparison_step = Constant.VOID_STEP @property def args(self): @@ -53,6 +55,14 @@ class ArgsManager: def comparison_path_dict(self): return self._comparison_path_dict + @property + def base_step(self): + return self._base_step + + @property + def comparison_step(self): + return self._comparison_step + @property def enable_profiling_compare(self): return self._args.enable_profiling_compare @@ -88,6 +98,18 @@ class ArgsManager: PathManager.make_dir_safety(output_path) PathManager.check_path_writeable(output_path) + def get_step_args_with_validating(self): + if self._args.base_step and self._args.comparison_step: + if all([self._args.base_step.isdigit(), self._args.comparison_step.isdigit()]): + self._base_step = int(self._args.base_step) + self._comparison_step = int(self._args.comparison_step) + else: + msg = "Invalid param, base_step and comparison_step must be a number." + raise RuntimeError(msg) + elif any([self._args.base_step, self._args.comparison_step]): + msg = "Invalid param, base_step and comparison_step must be set at the same time." 
+ raise RuntimeError(msg) + def parse_profiling_path(self, file_path: str): self.check_profiling_path(file_path) if os.path.isfile(file_path): @@ -114,8 +136,7 @@ class ArgsManager: path_dict.update({Constant.INFO_JSON_PATH: os.path.join(file_path, dir_name)}) return path_dict - def init(self, args: any): - self._args = args + def init(self): if self._args.max_kernel_num is not None and self._args.max_kernel_num <= Constant.LIMIT_KERNEL: msg = f"Invalid param, --max_kernel_num has to be greater than {Constant.LIMIT_KERNEL}" raise RuntimeError(msg) @@ -135,7 +156,8 @@ class ArgsManager: self._args.enable_communication_compare = True self._args.enable_api_compare = True self._args.enable_kernel_compare = True - + + self.get_step_args_with_validating() base_profiling_path = PathManager.get_realpath(self._args.base_profiling_path) self.check_profiling_path(base_profiling_path) self._base_path_dict = self.parse_profiling_path(base_profiling_path) diff --git a/profiler/compare_tools/compare_backend/utils/common_func.py b/profiler/compare_tools/compare_backend/utils/common_func.py index 68a1ab584f1514980bc784f4a55152efffe698cf..1ced3c0f8d0730d2d05fc3f89cbc10fbc995da28 100644 --- a/profiler/compare_tools/compare_backend/utils/common_func.py +++ b/profiler/compare_tools/compare_backend/utils/common_func.py @@ -41,6 +41,11 @@ def longest_common_subsequence_matching(base_ops: list, comparison_ops: list, na for index, value in enumerate(base_ops): result_data[index] = [value, None] return result_data + if not base_ops: + result_data = [None] * len(comparison_ops) + for index, value in enumerate(comparison_ops): + result_data[index] = [None, value] + return result_data comparison_len, base_len = len(comparison_ops), len(base_ops) if comparison_len * base_len > 50 * 10 ** 8: @@ -51,12 +56,12 @@ def longest_common_subsequence_matching(base_ops: list, comparison_ops: list, na cur_list = [0] * (base_len + 1) comparison_index = 1 - iter_comparison_data = iter(comparison_ops) - for 
comparison_data in iter_comparison_data: + all_base_data = [hash(name_func(op)) for op in base_ops] + all_comparison_data = [hash(name_func(op)) for op in comparison_ops] + for comparison_data in iter(all_comparison_data): base_index = 1 - iter_base_data = iter(base_ops) - for base_data in iter_base_data: - if name_func(comparison_data) == name_func(base_data): + for base_data in all_base_data: + if comparison_data == base_data: cur_list[base_index] = pre_list[base_index - 1] + 1 else: only_base = cur_list[base_index - 1] @@ -75,7 +80,7 @@ def longest_common_subsequence_matching(base_ops: list, comparison_ops: list, na while comparison_index > 0 and base_index > 0: base_data = base_ops[base_index - 1] comparison_data = comparison_ops[comparison_index - 1] - if name_func(base_data) == name_func(comparison_data): + if all_base_data[base_index - 1] == all_comparison_data[comparison_index - 1]: matched_op.append([base_data, comparison_data]) comparison_index -= 1 base_index -= 1 diff --git a/profiler/compare_tools/compare_backend/utils/compare_args.py b/profiler/compare_tools/compare_backend/utils/compare_args.py index 9e6291e89e0d8273073d1a2f4d8ec06a2dad79c1..36199b5b0ddaab91c5381b6c6c8bc2f362b9af2c 100644 --- a/profiler/compare_tools/compare_backend/utils/compare_args.py +++ b/profiler/compare_tools/compare_backend/utils/compare_args.py @@ -12,7 +12,9 @@ class Args: max_kernel_num: int = None, op_name_map: dict = {}, use_input_shape: bool = False, - gpu_flow_cat: str = ""): + gpu_flow_cat: str = "", + base_step: str = "", + comparison_step: str = ""): self.base_profiling_path = base_profiling_path self.comparison_profiling_path = comparison_profiling_path self.enable_profiling_compare = enable_profiling_compare @@ -26,3 +28,5 @@ class Args: self.op_name_map = op_name_map self.use_input_shape = use_input_shape self.gpu_flow_cat = gpu_flow_cat + self.base_step = base_step + self.comparison_step = comparison_step \ No newline at end of file diff --git 
a/profiler/compare_tools/compare_backend/utils/constant.py b/profiler/compare_tools/compare_backend/utils/constant.py index 252aa536e1c73d58f86071bbefab5286004bb6f9..08eb1792a88637835e17ba897846a159d31c7309 100644 --- a/profiler/compare_tools/compare_backend/utils/constant.py +++ b/profiler/compare_tools/compare_backend/utils/constant.py @@ -6,6 +6,7 @@ class Constant(object): MAX_PATH_LENGTH = 4096 MAX_FLOW_CAT_LEN = 20 MAX_FILE_SIZE = 1024 * 1024 * 1024 * 5 + MAX_JSON_SIZE = 1024 * 1024 * 1024 * 10 BYTE_TO_KB = 1024 YELLOW_COLOR = "FFFF00" GREEN_COLOR = "00FF00" @@ -15,6 +16,8 @@ class Constant(object): US_TO_MS = 1000 KB_TO_MB = 1024 INVALID_VALUE = -1 + MILLISECONDS_TO_SECONDS = 10 ** 3 + MICROSECONDS_TO_SECONDS = 10 ** 6 # epsilon EPS = 1e-15 @@ -91,3 +94,8 @@ class Constant(object): CPU_OP_MATMUL_MASK = ("aten::addmm", "aten::bmm", "aten::mm", "aten::matmul") KERNEL_CUBE_MASK = ("gemm", "conv", "cutlass", "wgrad") KERNEL_TRANS_MASK = ("cast", "transdata", "transpose") + + IS_BWD = "is_bwd" + OPS = "ops" + + VOID_STEP = -1 \ No newline at end of file diff --git a/profiler/compare_tools/compare_backend/utils/file_reader.py b/profiler/compare_tools/compare_backend/utils/file_reader.py index b4ae786388b2f1bed6ad50cfb39ac8621c1ea1f1..263888a3ecf6ffecff1c3343c8b87a133f70bff9 100644 --- a/profiler/compare_tools/compare_backend/utils/file_reader.py +++ b/profiler/compare_tools/compare_backend/utils/file_reader.py @@ -7,7 +7,6 @@ from compare_backend.utils.constant import Constant class FileReader: - @classmethod def read_trace_file(cls, file_path: str) -> any: PathManager.check_path_readable(file_path) diff --git a/profiler/compare_tools/compare_backend/utils/torch_op_node.py b/profiler/compare_tools/compare_backend/utils/torch_op_node.py index 69ee92d1232e5808ed428896a03718230559d12f..06479462cf350529d88b81f7666b55e568400245 100644 --- a/profiler/compare_tools/compare_backend/utils/torch_op_node.py +++ b/profiler/compare_tools/compare_backend/utils/torch_op_node.py 
@@ -24,6 +24,10 @@ class TorchOpNode: def name(self): return self._event.name + @property + def tid(self): + return self._event.tid + @property def input_shape(self): return str(self._event.args.get("Input Dims", Constant.NA)) @@ -67,7 +71,7 @@ class TorchOpNode: @property def api_dur(self): return self._event.dur - + @property def api_self_time(self): return self.api_dur - sum(child.api_dur for child in self._child_nodes) @@ -96,5 +100,10 @@ class TorchOpNode: def is_step_profiler(self) -> bool: return self._event.is_step_profiler() + def get_step_id(self) -> int: + if self.is_step_profiler(): + return int(self._event.name.split("#")[1]) + return Constant.VOID_STEP + def get_op_info(self) -> list: return [self.name, self.input_shape, self.input_type, self.call_stack] diff --git a/profiler/compare_tools/compare_backend/utils/tree_builder.py b/profiler/compare_tools/compare_backend/utils/tree_builder.py index d5aa787ac2cff1ba4a714b8522b839b0dc83bfd2..b770115795dde5d9877d769483aed9e003a14030 100644 --- a/profiler/compare_tools/compare_backend/utils/tree_builder.py +++ b/profiler/compare_tools/compare_backend/utils/tree_builder.py @@ -9,11 +9,13 @@ class TreeBuilder: @classmethod def build_tree(cls, event_list: list, kernel_dict: dict, memory_list: list) -> TorchOpNode: root_node = TorchOpNode() + all_nodes = [root_node] + ([None] * len(event_list)) all_event_list = [] all_event_list.extend(event_list) all_event_list.extend(memory_list) all_event_list.sort(key=lambda x: x.start_time) last_node = root_node + index = 1 for event in all_event_list: while last_node: if last_node != root_node and event.start_time > last_node.end_time: @@ -21,6 +23,8 @@ class TreeBuilder: continue if event.is_torch_op: tree_node = TorchOpNode(event, last_node) + all_nodes[index] = tree_node + index += 1 last_node.add_child_node(tree_node) last_node = tree_node if kernel_dict: @@ -29,7 +33,7 @@ class TreeBuilder: event.set_name(last_node.name) last_node.set_memory_allocated(event) break - 
return root_node + return all_nodes[:index] @classmethod def get_total_kernels(cls, root_node: TorchOpNode) -> list: diff --git a/profiler/compare_tools/compare_interface/comparison_interface.py b/profiler/compare_tools/compare_interface/comparison_interface.py index 919095b310126f2ce0c9c3e6912fb10f24d149e9..68bbcc026e5d14c7d1d3ae2e3c45a2bc173d68ce 100644 --- a/profiler/compare_tools/compare_interface/comparison_interface.py +++ b/profiler/compare_tools/compare_interface/comparison_interface.py @@ -12,16 +12,18 @@ from compare_backend.utils.constant import Constant class ComparisonInterface: - def __init__(self, base_profiling_path: str, comparison_profiling_path: str = ""): + def __init__(self, base_profiling_path: str, comparison_profiling_path: str = "", + base_step: str = "", comparison_step: str = ""): self.base_profiling_path = base_profiling_path if comparison_profiling_path: self._args = Args(base_profiling_path=base_profiling_path, - comparison_profiling_path=comparison_profiling_path) + comparison_profiling_path=comparison_profiling_path, + base_step=base_step, + comparison_step=comparison_step) def compare(self, compare_type: str) -> dict: if compare_type == Constant.OVERALL_COMPARE: self._args.enable_profiling_compare = True - return ComparisonGenerator(self._args).run_interface(compare_type) def disaggregate_perf(self, compare_type: str) -> dict: diff --git a/profiler/compare_tools/performance_compare.py b/profiler/compare_tools/performance_compare.py index 7c9d60aac0af38c3fe3dd5f2ca9c96380438d4c9..dff87db2fb7db9f2ccb073c8f37037ef088e822f 100644 --- a/profiler/compare_tools/performance_compare.py +++ b/profiler/compare_tools/performance_compare.py @@ -27,11 +27,12 @@ def main(): help="配置GPU与NPU等价的算子名称映射关系,以字典的形式传入") parser.add_argument("--use_input_shape", default=False, action='store_true', help="开启算子的精准匹配") parser.add_argument("--gpu_flow_cat", type=str, default='', help="gpu flow event的分类标识") + parser.add_argument("--base_step", type=str, 
default='', help="基准性能数据指定比对step") + parser.add_argument("--comparison_step", type=str, default='', help="比较性能数据指定比对step") args = parser.parse_args() ComparisonGenerator(args).run() - if __name__ == "__main__": start_time = datetime.datetime.now() main() diff --git a/profiler/test/ut/advisor/cluster_advice/test_rdma_retransmission_advice.py b/profiler/test/ut/advisor/cluster_advice/test_rdma_retransmission_advice.py new file mode 100644 index 0000000000000000000000000000000000000000..eb383a65991899a4510bdafa05224325cb4eb190 --- /dev/null +++ b/profiler/test/ut/advisor/cluster_advice/test_rdma_retransmission_advice.py @@ -0,0 +1,170 @@ +import os +import shutil +import stat +import json + +import unittest +from profiler.advisor.interface.interface import Interface +from profiler.advisor.common.analyzer_scopes import SupportedScopes + + +class TestRdmaAdvice(unittest.TestCase): + TMP_DIR = "./tmp/" + OUTPUT_DIR = "./tmp/cluster_analysis_output" + interface = None + err_interface = None + + def tearDown(self): + if os.path.exists(TestRdmaAdvice.TMP_DIR): + shutil.rmtree(TestRdmaAdvice.TMP_DIR) + if os.path.exists(TestRdmaAdvice.OUTPUT_DIR): + shutil.rmtree(TestRdmaAdvice.OUTPUT_DIR) + self.clear_htmls() + + def setUp(self): + if os.path.exists(TestRdmaAdvice.TMP_DIR): + shutil.rmtree(TestRdmaAdvice.TMP_DIR) + if not os.path.exists(TestRdmaAdvice.TMP_DIR): + os.makedirs(TestRdmaAdvice.TMP_DIR) + if not os.path.exists(TestRdmaAdvice.OUTPUT_DIR): + os.makedirs((TestRdmaAdvice.OUTPUT_DIR)) + self.clear_htmls() + + @classmethod + def clear_htmls(cls): + current_path = os.path.dirname(os.path.abspath(__file__)) + for filename in os.listdir(current_path): + # 检查文件是否以“mstt”开头 + if filename.startswith("mstt"): + # 构建文件的完整路径 + file_path = os.path.join(current_path, filename) + # 删除文件 + os.remove(file_path) + + @classmethod + def get_cluster_communication_view(cls): + data = {"p2p":{"step1" : { + "hcom_broadcast__844_0_1@13681369207305868844": { + "0": { + "Communication Time 
Info": { + "Start Timestamp(us)": 1713174287354248.0, + "Elapse Time(ms)": 4688, + "Transit Time(ms)": 0, + "Wait Time(ms)": 0.01162, + "Synchronization Time(ms)": 0.01162, + "Idle Time(ms)": 39.0606, + "Wait Time Ratio": 1.0, + "Synchronization Time Ratio": 1.0 + }, + "Communication Bandwidth Info": { + "RDMA": { + "Transit Size(MB)": 80, + "Transit Time(ms)": 4600, + "Bandwidth(GB/s)": 0.003, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "HCCS": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "PCIE": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SDMA": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SIO": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + } + } + }, + "16": { + "Communication Time Info": { + "Start Timestamp(us)": 1713174287186619.8, + "Elapse Time(ms)": 4788, + "Transit Time(ms)": 0.0013, + "Wait Time(ms)": 39.037240000000004, + "Synchronization Time(ms)": 39.03034, + "Idle Time(ms)": 167.66008000000002, + "Wait Time Ratio": 1.0, + "Synchronization Time Ratio": 1.0 + }, + "Communication Bandwidth Info": { + "RDMA": { + "Transit Size(MB)": 80, + "Transit Time(ms)": 4700, + "Bandwidth(GB/s)": 0.0033, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "HCCS": { + "Transit Size(MB)": 4e-05, + "Transit Time(ms)": 0.0013, + "Bandwidth(GB/s)": 0.0308, + "Large Packet Ratio": 0.0, + "Size Distribution": { + "4e-05": [ + 1, + 0.0013 + ] + } + }, + "PCIE": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SDMA": { + "Transit Size(MB)": 4e-05, + "Transit Time(ms)": 0.0013, + "Bandwidth(GB/s)": 0.0308, + 
"Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SIO": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + } + } + }, + } + }}} + return data + + @classmethod + def create_communicaton_json(cls): + raw_data = cls.get_cluster_communication_view() + with os.fdopen(os.open(f"{TestRdmaAdvice.OUTPUT_DIR}/cluster_communication.json", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + fp.write(json.dumps(raw_data)) + + def test_run_should_run_success_when_contain_cluster_communication_json(self): + self.create_communicaton_json() + interface = Interface(profiling_path=self.TMP_DIR) + dimension = "cluster" + scope = SupportedScopes.COMMUNICATION_RETRANSMISSION_DETECTION + result = interface.get_result(dimension, scope, render_html=1, output_dict=False, profiling_path=self.TMP_DIR) + self.assertEqual(2, len(result.data.get("Comm Retransmission Analysis", []))) + self.assertEqual(2, len(result.data.get("Comm Retransmission Analysis", []).get('data'))) + result.clear() diff --git a/profiler/test/ut/advisor/communication_advice/test_packet_advice.py b/profiler/test/ut/advisor/communication_advice/test_packet_advice.py new file mode 100644 index 0000000000000000000000000000000000000000..a8fd4549ecd90cbe09acb482f6fbc3cff8a935eb --- /dev/null +++ b/profiler/test/ut/advisor/communication_advice/test_packet_advice.py @@ -0,0 +1,175 @@ +import os +import shutil +import stat +import json + +import unittest +from profiler.advisor.interface.interface import Interface +from profiler.advisor.common.analyzer_scopes import SupportedScopes + + +class TestPacketAdvice(unittest.TestCase): + TMP_DIR = "./ascend_pt" + OUTPUT_DIR = "./ascend_pt/ASCEND_PROFILER_OUTPUT" + interface = None + err_interface = None + + def tearDown(self): + if os.path.exists(TestPacketAdvice.TMP_DIR): + shutil.rmtree(TestPacketAdvice.TMP_DIR) + self.clear_htmls() + + def setUp(self): + if 
os.path.exists(TestPacketAdvice.TMP_DIR): + shutil.rmtree(TestPacketAdvice.TMP_DIR) + if not os.path.exists(TestPacketAdvice.TMP_DIR): + os.makedirs(TestPacketAdvice.TMP_DIR) + if not os.path.exists(TestPacketAdvice.OUTPUT_DIR): + os.makedirs(TestPacketAdvice.OUTPUT_DIR) + self.clear_htmls() + + @classmethod + def clear_htmls(cls): + current_path = os.path.dirname(os.path.abspath(__file__)) + for filename in os.listdir(current_path): + # 检查文件是否以“att”开头 + if filename.startswith("mstt"): + # 构建文件的完整路径 + file_path = os.path.join(current_path, filename) + # 删除文件 + os.remove(file_path) + + @classmethod + def get_communication_view(cls): + data = {"step1":{"collective" : { + "hcom_broadcast__844_1_1@13681369207305868844": { + "Communication Time Info": { + "Start Timestamp(us)": 1713174287407957.0, + "Elapse Time(ms)": 0.06086, + "Transit Time(ms)": 0.00126, + "Wait Time(ms)": 0.014939999999999998, + "Synchronization Time(ms)": 0.00714, + "Idle Time(ms)": 0.044660000000000005, + "Wait Time Ratio": 0.9222, + "Synchronization Time Ratio": 0.85 + }, + "Communication Bandwidth Info": { + "RDMA": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "HCCS": { + "Transit Size(MB)": 0.028575999999999997, + "Transit Time(ms)": 0.008620000000000001, + "Bandwidth(GB/s)": 3.3151, + "Large Packet Ratio": 0.0, + "Size Distribution": { + "0.004224": [ + 6, + 0.00736 + ], + "0.003232": [ + 1, + 0.00126 + ] + } + }, + "PCIE": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SDMA": { + "Transit Size(MB)": 0.028575999999999997, + "Transit Time(ms)": 0.008620000000000001, + "Bandwidth(GB/s)": 3.3151, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SIO": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + } + } + }, + 
"hcom_allReduce__844_2_1@13681369207305868844": { + "Communication Time Info": { + "Start Timestamp(us)": 1713174287432401.2, + "Elapse Time(ms)": 2.9042, + "Transit Time(ms)": 1.35236, + "Wait Time(ms)": 1.47632, + "Synchronization Time(ms)": 1.44524, + "Idle Time(ms)": 0.07551999999999981, + "Wait Time Ratio": 0.5219, + "Synchronization Time Ratio": 0.5166 + }, + "Communication Bandwidth Info": { + "RDMA": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "HCCS": { + "Transit Size(MB)": 176.16076799999996, + "Transit Time(ms)": 9.55658, + "Bandwidth(GB/s)": 18.4335, + "Large Packet Ratio": 0.0, + "Size Distribution": { + "12.582912": [ + 14, + 9.55658 + ] + } + }, + "PCIE": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SDMA": { + "Transit Size(MB)": 176.16076799999996, + "Transit Time(ms)": 9.55658, + "Bandwidth(GB/s)": 18.4335, + "Large Packet Ratio": 0, + "Size Distribution": {} + }, + "SIO": { + "Transit Size(MB)": 0, + "Transit Time(ms)": 0, + "Bandwidth(GB/s)": 0, + "Large Packet Ratio": 0, + "Size Distribution": {} + } + } + }, + }}} + return data + + @classmethod + def create_communicaton_json(cls): + raw_data = cls.get_communication_view() + with os.fdopen(os.open(f"{TestPacketAdvice.OUTPUT_DIR}/communication.json", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + fp.write(json.dumps(raw_data)) + + def test_run_should_run_success_when_ascend_pt_contain_communication_json(self): + self.create_communicaton_json() + interface = Interface(profiling_path=self.TMP_DIR) + dimension = "communication" + scope = SupportedScopes.PACKET + result = interface.get_result(dimension, scope, render_html=1, output_dict=False, profiling_path=self.TMP_DIR) + self.assertEqual(2, len(result.data.get("Packet Analysis", []))) + self.assertEqual(1, len(result.data.get("Packet 
Analysis", []).get('data'))) + result.clear() diff --git a/profiler/test/ut/cluster_analyse/cluster_utils/test_parallel_strategy_calculator.py b/profiler/test/ut/cluster_analyse/cluster_utils/test_parallel_strategy_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb8b300ab5faee326448712e5eb35d6725f466f --- /dev/null +++ b/profiler/test/ut/cluster_analyse/cluster_utils/test_parallel_strategy_calculator.py @@ -0,0 +1,46 @@ +import unittest + +from cluster_utils.parallel_strategy_calculator import ParallelStrategyCalculator + + +class TestParallelStrategyCalculator(unittest.TestCase): + def test_parallel_strategy_calculator_should_raise_runtime_error_when_dp4_ep3(self): + with self.assertRaises(RuntimeError): + calculator = ParallelStrategyCalculator( + world_size=16, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=4, + data_parallel_size=4, + context_parallel_size=1, + expert_model_parallel_size=3) + + calculator.run() + + def test_parallel_strategy_calculator_should_raise_runtime_error_when_dp1_pp4_tp2_world_size16(self): + with self.assertRaises(RuntimeError): + calculator = ParallelStrategyCalculator( + world_size=16, + tensor_model_parallel_size=2, + pipeline_model_parallel_size=4, + data_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=1) + + calculator.run() + + def test_parallel_strategy_calculator_dp2_pp4_tp2(self): + calculator = ParallelStrategyCalculator( + world_size=16, + tensor_model_parallel_size=2, + pipeline_model_parallel_size=4, + data_parallel_size=2, + context_parallel_size=1, + expert_model_parallel_size=1) + + # dp index, pp index, tp index + expected_res = [ + (0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), + (0, 2, 0), (0, 2, 1), (1, 2, 0), (1, 2, 1), (0, 3, 0), (0, 3, 1), (1, 3, 0), (1, 3, 1) + ] + res = calculator.run() + self.assertEqual(res, expected_res) diff --git 
a/profiler/test/ut/compare_tools/compare_bean/test_profiling_info.py b/profiler/test/ut/compare_tools/compare_bean/test_profiling_info.py index dc85b0af0ab6403122036d4a41f87e35c903dfb5..59525f18f96236a7e0383d08721629318f690f1b 100644 --- a/profiler/test/ut/compare_tools/compare_bean/test_profiling_info.py +++ b/profiler/test/ut/compare_tools/compare_bean/test_profiling_info.py @@ -4,28 +4,6 @@ from compare_backend.compare_bean.profiling_info import ProfilingInfo class TestProfilingInfo(unittest.TestCase): - def test_calculate_other_time(self): - info = ProfilingInfo("NPU") - info.compute_time = 10 - info.cube_time = 1 - info.fa_time_fwd = 2 - info.fa_time_bwd = 2 - info.vec_time = 3 - info.calculate_other_time() - self.assertEqual(info.other_time, 2) - info.vec_time = 7 - info.calculate_other_time() - self.assertEqual(info.other_time, 0) - - def test_calculate_vec_time(self): - info = ProfilingInfo("NPU") - info.compute_time = 10 - info.cube_time = 1 - info.fa_time_fwd = 2 - info.fa_time_bwd = 2 - info.calculate_vec_time() - self.assertEqual(info.vec_time, 5) - def test_calculate_schedule_time(self): info = ProfilingInfo("NPU") info.e2e_time = 10 @@ -36,41 +14,50 @@ class TestProfilingInfo(unittest.TestCase): def test_update_fa_fwd_info(self): info = ProfilingInfo("NPU") - info.update_fa_fwd_info(5) - info.update_fa_fwd_info(5) - self.assertEqual(info.fa_time_fwd, 10) + info.fa_time_fwd_cube = 5 + info.fa_time_fwd_vector = 5 + info.fa_num_fwd_cube = 1 + info.fa_num_fwd_vector = 1 + self.assertEqual(info.fa_time_fwd, 0.01) self.assertEqual(info.fa_num_fwd, 2) def test_update_fa_bwd_info(self): info = ProfilingInfo("NPU") - info.update_fa_bwd_info(5) - info.update_fa_bwd_info(5) - self.assertEqual(info.fa_time_bwd, 10) + info.fa_time_bwd_cube = 5 + info.fa_time_bwd_vector = 5 + info.fa_num_bwd_cube = 1 + info.fa_num_bwd_vector = 1 + self.assertEqual(info.fa_time_bwd, 0.01) self.assertEqual(info.fa_num_bwd, 2) def test_update_sdma_info(self): info = 
ProfilingInfo("NPU") - info.update_sdma_info(5) - self.assertEqual(info.sdma_time, 5) - self.assertEqual(info.sdma_num, 1) - info.update_sdma_info(5, 5) - self.assertEqual(info.sdma_time, 10) - self.assertEqual(info.sdma_num, 6) + info.sdma_time_tensor_move = 5 + info.sdma_time_stream = 5 + info.sdma_num_tensor_move = 5 + info.sdma_num_stream = 5 + self.assertEqual(info.sdma_time, 0.01) + self.assertEqual(info.sdma_num, 10) def test_update_cube_info(self): info = ProfilingInfo("NPU") - info.update_cube_info(5) - info.update_cube_info(5) - self.assertEqual(info.cube_time, 10) - self.assertEqual(info.cube_num, 2) + info.matmul_time_cube = 1 + info.matmul_time_vector = 1 + info.other_cube_time = 1 + info.matmul_num_cube = 5 + info.matmul_num_vector = 5 + info.other_cube_num = 5 + self.assertEqual(info.cube_time, 0.003) + self.assertEqual(info.cube_num, 15) def test_update_vec_info(self): info = ProfilingInfo("NPU") - info.update_vec_info(5) - info.update_vec_info(5) - self.assertEqual(info.vec_time, 10) - self.assertEqual(info.vec_num, 2) - + info.vector_time_trans = 1 + info.vector_time_notrans = 1 + info.vector_num_trans = 2 + info.vector_num_notrans = 2 + self.assertEqual(info.vec_time, 0.002) + self.assertEqual(info.vec_num, 4) def test_set_compute_time(self): info = ProfilingInfo("NPU") info.update_compute_time(1) diff --git a/profiler/test/ut/compare_tools/profiling_parser/test_base_profiling_parser.py b/profiler/test/ut/compare_tools/profiling_parser/test_base_profiling_parser.py index 80734635929597fff2f5a1bbbe79582817ba2858..b78c59f1f70634a4aa63efdbe5d83f6692d9efae 100644 --- a/profiler/test/ut/compare_tools/profiling_parser/test_base_profiling_parser.py +++ b/profiler/test/ut/compare_tools/profiling_parser/test_base_profiling_parser.py @@ -26,6 +26,8 @@ class ProfilingParser(BaseProfilingParser): self._enable_communication_compare = True self._enable_kernel_compare = True self._enable_api_compare = True + self._bwd_tid = 1 + self._step_id = -1 def 
_update_kernel_details(self): pass diff --git a/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py b/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py index d7cb3d0588a3e13097d2429a92f283b6c3eaf4b8..25293d64a2c371002e6c9624f4fa6c10c592c13b 100644 --- a/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py +++ b/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py @@ -76,16 +76,12 @@ class TestGpuProfilingParser(unittest.TestCase): res._marks = defaultdict(int) res._calculate_performance_time() self.assertEqual(res._result_data.overall_metrics.e2e_time, 98) - self.assertEqual(res._result_data.overall_metrics.sdma_time, 4) + self.assertEqual(res._result_data.overall_metrics.sdma_time, 0.004) self.assertEqual(res._result_data.overall_metrics.sdma_num, 4) - self.assertEqual(res._result_data.overall_metrics.cube_time, 1) + self.assertEqual(res._result_data.overall_metrics.cube_time, 0.001) self.assertEqual(res._result_data.overall_metrics.cube_num, 1) - self.assertEqual(res._result_data.overall_metrics.fa_time_fwd, 2) - self.assertEqual(res._result_data.overall_metrics.fa_num_fwd, 2) - self.assertEqual(res._result_data.overall_metrics.fa_time_bwd, 2) - self.assertEqual(res._result_data.overall_metrics.fa_num_bwd, 2) - self.assertEqual(res._result_data.overall_metrics.vec_time, 2) - self.assertEqual(res._result_data.overall_metrics.vec_num, 2) # cun yi + self.assertEqual(res._result_data.overall_metrics.vec_time, 0.006) + self.assertEqual(res._result_data.overall_metrics.vec_num, 6) # cun yi self.assertEqual(res._result_data.overall_metrics.communication_not_overlapped, 2) self.assertEqual(res._result_data.overall_metrics.compute_time, 7) diff --git a/profiler/test/ut/compare_tools/utils/test_tree_builder.py b/profiler/test/ut/compare_tools/utils/test_tree_builder.py index b9565b45ed5a049fbb23a246cafa888f75dc7102..326a424d3dd9a36d158816ba73ffcf260ac583d9 100644 --- 
a/profiler/test/ut/compare_tools/utils/test_tree_builder.py +++ b/profiler/test/ut/compare_tools/utils/test_tree_builder.py @@ -18,11 +18,11 @@ class TestUtils(unittest.TestCase): for event in event_list: event.is_torch_op = True tree = TreeBuilder.build_tree(event_list, flow_kernel_dict, memory_allocated_list) - child_nodes = tree.child_nodes - self.assertEqual(len(tree._child_nodes), 2) + child_nodes = tree[0].child_nodes + self.assertEqual(len(tree[0].child_nodes), 2) self.assertEqual(child_nodes[0].start_time, 0) self.assertEqual(child_nodes[0].end_time, 1) self.assertEqual(child_nodes[0].kernel_num, 2) self.assertEqual(child_nodes[1].kernel_num, 0) - self.assertEqual(len(TreeBuilder.get_total_kernels(tree)), 2) - self.assertEqual(TreeBuilder.get_total_memory(tree)[0].size, 1) + self.assertEqual(len(TreeBuilder.get_total_kernels(tree[0])), 2) + self.assertEqual(TreeBuilder.get_total_memory(tree[0])[0].size, 1) diff --git "a/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" "b/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" index abf8e10555980a3b85d26d331919cbf51c9a42ed..2db9b87864c905d8df93f46e8232c24defe48568 100644 --- "a/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" +++ "b/\345\205\254\347\275\221URL\350\257\264\346\230\216.md" @@ -7,4 +7,5 @@ | 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/advisor/common/constant.py | 公网地址 | ["https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/devtools/auxiliarydevtool/aoe_16_045.html"] | Advisor优化手段参考示例 | | 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/advisor/common/constant.py | 公网地址 | ["https://www.mindspore.cn/lite/docs/en/master/use/cloud_infer/converter_tool_ascend.html#aoe-auto-tuning"] | Advisor优化手段参考示例 | | 开源软件 | MindStudio Training Tools - msprof-analyze advisor | /profiler/advisor/common/constant.py | 公网地址 | ["https://www.hiascend.com/document/detail/zh/canncommercial/70RC1/modeldevpt/ptmigr/AImpug_0059.html"] | Advisor优化手段参考示例 
| +| 文档 | MindStudio Training Tools - msprof-analyze advisor | /profiler/advisor/common/constant.py | 公网地址 | ["https://support.huawei.com/enterprise/zh/doc/EDOC1100371278/5eeeed85?idPath=23710424"] | Advisor优化手段参考示例 |