diff --git a/README.md b/README.md index 87a1a03725d6778b24ab322b7ee8a3725c303e4a..a51c3a258d8b67289056cd90cdefbd3704b8f0a2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # ATT -Ascend Training Tools,昇腾训练工具链。针对训练&大模型场景,提供端到端命令行&可视化调试调优工具,帮助用户快速提高模型开发效率。 +Ascend Training Tools,昇腾训练工具链。【Powered by MindStudio】 + +针对训练&大模型场景,提供端到端命令行&可视化调试调优工具,帮助用户快速提高模型开发效率。 ## 模型训练迁移全流程 ![输入图片说明](debug/resources/model_training_migration_process.png) @@ -45,6 +47,10 @@ Ascend Training Tools,昇腾训练工具链。针对训练&大模型场景, 提供多机多卡的集群分析能力(基于通信域的通信分析和迭代耗时分析), 当前需要配合Ascend Insight的集群分析功能使用。 +3. [affinity_cpu_bind (亲和性cpu绑核工具) ](https://gitee.com/ascend/att/tree/master/profiler/affinity_cpu_bind) + + 提供亲和性CPU绑核能力,改善host_bound调度问题。 + ### [Tensorboard](https://gitee.com/ascend/att/tree/master/plugins/tensorboard-plugins/tb_plugin) Tensorboard支持NPU性能数据可视化插件PyTorch Profiler TensorBoard NPU Plugin。 diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py deleted file mode 100644 index f3e3fe66364169f8d1617acfd378905e225a52d2..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from api_accuracy_checker.dump.dump import set_dump_switch -import api_accuracy_checker.dump.dump_scope -from api_accuracy_checker.common.config import msCheckerConfig - -__all__ = ['set_dump_switch'] diff --git a/debug/accuracy_tools/msacc/__init__.py b/debug/accuracy_tools/msacc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06d63e90feca1bd4f114292798fc9a763058511a --- /dev/null +++ b/debug/accuracy_tools/msacc/__init__.py @@ -0,0 +1 @@ +from .debugger.precision_debugger import PrecisionDebugger diff --git a/debug/accuracy_tools/atat/atat.py b/debug/accuracy_tools/msacc/atat.py similarity index 86% rename from debug/accuracy_tools/atat/atat.py rename to debug/accuracy_tools/msacc/atat.py index 
4f69afd2349f211d1c6e17ab9386de3c8fcd6909..890349afa069fb0466c35a43eaf659e8bb5ae30d 100644 --- a/debug/accuracy_tools/atat/atat.py +++ b/debug/accuracy_tools/msacc/atat.py @@ -15,11 +15,11 @@ import argparse import sys -from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, run_ut_command +from calibrator.pytorch.api_accuracy_checker import _run_ut_parser, run_ut_command from ptdbg_ascend.src.python.ptdbg_ascend.parse_tool.cli import parse as cli_parse -from api_accuracy_checker.run_ut.multi_run_ut import prepare_config, run_parallel_ut -from api_accuracy_checker.compare.api_precision_compare import _api_precision_compare_parser, _api_precision_compare_command -from api_accuracy_checker.run_ut.run_overflow_check import _run_overflow_check_parser, _run_overflow_check_command +from calibrator.pytorch.api_accuracy_checker import prepare_config, run_parallel_ut +from calibrator.pytorch.api_accuracy_checker import _api_precision_compare_parser, _api_precision_compare_command +from calibrator.pytorch.api_accuracy_checker import _run_overflow_check_parser, _run_overflow_check_command def main(): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/__init__.py b/debug/accuracy_tools/msacc/pytorch/__init__.py similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/compare/__init__.py rename to debug/accuracy_tools/msacc/pytorch/__init__.py diff --git a/debug/accuracy_tools/api_accuracy_checker/.keep b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/.keep similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/.keep rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/.keep diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/README.md similarity index 97% rename from debug/accuracy_tools/api_accuracy_checker/README.md rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/README.md index 
7738501db87b1cacbc9eb96687bf09aed3a5ed68..6b5bc966a4bb27fe4da45af3abe74df3e01baaf3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/README.md @@ -62,16 +62,17 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 在训练代码中添加数据dump操作如下: ```Python - import api_accuracy_checker.dump as DP + +from calibrator.pytorch import api_accuracy_checker as DP # 需要先修改enable_dataloader参数值为False # 关闭torch.utils.data.dataloader加载数据时,下列代码须在训练step代码内添加 - DP.dump.start() # 开启工具dump模块 + DP.dump.start() # 开启工具dump模块 ... - DP.dump.stop() # 控制dump结束 - DP.dump.step() # 在DP.dump.stop()后加入DP.dump.step()即可指定需要dump的step + DP.dump.stop() # 控制dump结束 + DP.dump.step() # 在DP.dump.stop()后加入DP.dump.step()即可指定需要dump的step ``` 上述代码要添加在迭代内,如对于[ModelLink](https://gitee.com/ascend/ModelLink)的LLAMA2-7B可以添加在training.py中train函数的iteration循环内。之后工具会适配这个场景开关的自动打开。 @@ -89,7 +90,8 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 其次,在训练脚本中加入以下代码导入工具dump模块,启动训练即可自动抓取网络所有API信息。 ```python - import api_accuracy_checker.dump + +from calibrator.pytorch import api_accuracy_checker ``` 工具默认抓取训练的**第二个迭代**并且在第二个迭代后会报错退出训练进程,可通过target_iter参数配置。 @@ -117,7 +119,8 @@ forward_info与stack_info中的key值一一对应,用户可根据forward_info 预检工具默认为随机数据模式,如果想要完全复刻整网的API运行情况,可以使用真实数据模式,添加以下代码即可: ```python -from api_accuracy_checker.dump import msCheckerConfig +from calibrator.pytorch.api_accuracy_checker.dump import msCheckerConfig + msCheckerConfig.update_config(real_data=True) ``` @@ -126,15 +129,16 @@ msCheckerConfig.update_config(real_data=True) 精度预检工具可以对指定API进行预检操作,可以在dump时的训练脚本中直接添加白名单参数,只dump指定的API数据,示例代码如下: ```python -from api_accuracy_checker.dump import msCheckerConfig +from calibrator.pytorch.api_accuracy_checker.dump import msCheckerConfig + msCheckerConfig.update_config(white_list=["conv1d", "conv2d"]) ``` -配置的API名称须存在于[support_wrap_ops.yaml](./hook_module/support_wrap_ops.yaml)文件下。 
+配置的API名称须存在于[support_wrap_ops.yaml](hook_module/support_wrap_ops.yaml)文件下。 #### 工具支持的API列表 -预检工具维护固定的API支持列表,若需要删除或增加dump的API,可以在[support_wrap_ops.yaml](./hook_module/support_wrap_ops.yaml)文件内手动修改,如下示例: +预检工具维护固定的API支持列表,若需要删除或增加dump的API,可以在[support_wrap_ops.yaml](hook_module/support_wrap_ops.yaml)文件内手动修改,如下示例: ```bash functional: # functional为算子类别,找到对应的类别,在该类别下按照下列格式删除或添加API @@ -154,7 +158,7 @@ functional: # functional为算子类别,找到对应的类别,在该类别 **函数原型** ```python -msCheckerConfig.update_config(dump_path="./", real_data=False, target_iter=[1], white_list=[], enable_dataloader=False) +msCheckerConfig.update_config(dump_path="./", real_data=False, target_iter=[1], white_list=[], enable_dataloader=False) ``` **参数说明** diff --git a/debug/accuracy_tools/api_accuracy_checker/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/__init__.py similarity index 91% rename from debug/accuracy_tools/api_accuracy_checker/__init__.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/__init__.py index 22c3634838883ae8d91737ce3b2788b43dc1712b..7a89d8d6a18d8c32aaac83b10d2004d0b22dc010 100644 --- a/debug/accuracy_tools/api_accuracy_checker/__init__.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/__init__.py @@ -1,21 +1,21 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -""" - -from api_accuracy_checker.common.utils import seed_all -seed_all() -__all__ = [] +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" + +from calibrator.pytorch.api_accuracy_checker.common import seed_all +seed_all() +__all__ = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/common/.keep b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/.keep similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/common/.keep rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/.keep diff --git a/debug/accuracy_tools/atat/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/__init__.py similarity index 100% rename from debug/accuracy_tools/atat/__init__.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/__init__.py diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/config.py similarity index 93% rename from debug/accuracy_tools/api_accuracy_checker/common/config.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/config.py index 
57f59b0785c5cbb2b3e41903ce7350c414fadf27..eb54bb59e830210883e3a1939df7262eb2ea7afe 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/config.py @@ -1,76 +1,76 @@ -import os -import yaml -from api_accuracy_checker.common.utils import check_file_or_directory_path -from api_accuracy_checker.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen - -WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) - - -class Config: - def __init__(self, yaml_file): - check_file_or_directory_path(yaml_file, False) - with FileOpen(yaml_file, 'r') as file: - config = yaml.safe_load(file) - self.config = {key: self.validate(key, value) for key, value in config.items()} - - def validate(self, key, value): - validators = { - 'dump_path': str, - 'real_data': bool, - 'enable_dataloader': bool, - 'target_iter': list, - 'white_list': list, - 'error_data_path': str, - 'jit_compile': bool, - 'precision': int - } - if key not in validators: - raise ValueError(f"{key} must be one of {validators.keys()}") - if not isinstance(value, validators.get(key)): - raise ValueError(f"{key} must be {validators[key].__name__} type") - if key == 'target_iter': - if not isinstance(value, list): - raise ValueError("target_iter must be a list type") - if any(isinstance(i, bool) for i in value): - raise ValueError("target_iter cannot contain boolean values") - if not all(isinstance(i, int) for i in value): - raise ValueError("All elements in target_iter must be of int type") - if any(i < 0 for i in value): - raise ValueError("All elements in target_iter must be greater than or equal to 0") - if key == 'precision' and value < 0: - raise ValueError("precision must be greater than 0") - if key == 'white_list': - if not isinstance(value, list): - raise ValueError("white_list must be a list type") - if not 
all(isinstance(i, str) for i in value): - raise ValueError("All elements in white_list must be of str type") - invalid_api = [i for i in value if i not in WrapApi] - if invalid_api: - raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") - return value - - def __getattr__(self, item): - return self.config[item] - - def __str__(self): - return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - - def update_config(self, dump_path=None, real_data=None, target_iter=None, white_list=None, enable_dataloader=None): - args = { - "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), - "real_data": real_data if real_data else self.config.get("real_data", False), - "target_iter": target_iter if target_iter else self.config.get("target_iter", [1]), - "white_list": white_list if white_list else self.config.get("white_list", []), - "enable_dataloader": enable_dataloader if enable_dataloader else self.config.get("enable_dataloader", False) - } - for key, value in args.items(): - if key in self.config: - self.config[key] = self.validate(key, value) - else: - raise ValueError(f"Invalid key '{key}'") - - -cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -yaml_path = os.path.join(cur_path, "config.yaml") +import os +import yaml +from calibrator.pytorch.api_accuracy_checker.common import check_file_or_directory_path +from calibrator.pytorch.api_accuracy_checker import WrapFunctionalOps, WrapTensorOps, WrapTorchOps +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + +WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) + + +class Config: + def __init__(self, yaml_file): + check_file_or_directory_path(yaml_file, False) + with FileOpen(yaml_file, 'r') as file: + config = yaml.safe_load(file) + self.config = {key: self.validate(key, value) for key, value in config.items()} + + def validate(self, key, value): + validators = 
{ + 'dump_path': str, + 'real_data': bool, + 'enable_dataloader': bool, + 'target_iter': list, + 'white_list': list, + 'error_data_path': str, + 'jit_compile': bool, + 'precision': int + } + if key not in validators: + raise ValueError(f"{key} must be one of {validators.keys()}") + if not isinstance(value, validators.get(key)): + raise ValueError(f"{key} must be {validators[key].__name__} type") + if key == 'target_iter': + if not isinstance(value, list): + raise ValueError("target_iter must be a list type") + if any(isinstance(i, bool) for i in value): + raise ValueError("target_iter cannot contain boolean values") + if not all(isinstance(i, int) for i in value): + raise ValueError("All elements in target_iter must be of int type") + if any(i < 0 for i in value): + raise ValueError("All elements in target_iter must be greater than or equal to 0") + if key == 'precision' and value < 0: + raise ValueError("precision must be greater than 0") + if key == 'white_list': + if not isinstance(value, list): + raise ValueError("white_list must be a list type") + if not all(isinstance(i, str) for i in value): + raise ValueError("All elements in white_list must be of str type") + invalid_api = [i for i in value if i not in WrapApi] + if invalid_api: + raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") + return value + + def __getattr__(self, item): + return self.config[item] + + def __str__(self): + return '\n'.join(f"{key}={value}" for key, value in self.config.items()) + + def update_config(self, dump_path=None, real_data=None, target_iter=None, white_list=None, enable_dataloader=None): + args = { + "dump_path": dump_path if dump_path else self.config.get("dump_path", './'), + "real_data": real_data if real_data else self.config.get("real_data", False), + "target_iter": target_iter if target_iter else self.config.get("target_iter", [1]), + "white_list": white_list if white_list else self.config.get("white_list", []), + 
"enable_dataloader": enable_dataloader if enable_dataloader else self.config.get("enable_dataloader", False) + } + for key, value in args.items(): + if key in self.config: + self.config[key] = self.validate(key, value) + else: + raise ValueError(f"Invalid key '{key}'") + + +cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +yaml_path = os.path.join(cur_path, "config.yaml") msCheckerConfig = Config(yaml_path) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/utils.py similarity index 96% rename from debug/accuracy_tools/api_accuracy_checker/common/utils.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/utils.py index eee58ef7ae5e278993ff8eade41052de7b9deec1..05ee8bb4a6bb5b73e6c88729647c019c03da19b9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/common/utils.py @@ -1,651 +1,652 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -import collections -import json -import os -import random -import re -import stat -import subprocess -import sys -import time -import csv -from datetime import datetime, timezone - -import numpy as np -import torch - -try: - import torch_npu -except ImportError: - IS_GPU = True -else: - IS_GPU = False - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, FileOpen -from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util - -torch_without_guard_version_list = ['2.1'] -for version in torch_without_guard_version_list: - if torch.__version__.startswith(version): - torch_without_guard_version = True - break - else: - torch_without_guard_version = False -if not IS_GPU and not torch_without_guard_version: - from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard - - -class Const: - """ - Class for const - """ - DIRECTORY_LENGTH = 4096 - FILE_NAME_LENGTH = 255 - FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' - MODEL_TYPE = ['.onnx', '.pb', '.om'] - SEMICOLON = ";" - COLON = ":" - EQUAL = "=" - COMMA = "," - DOT = "." 
- DUMP_RATIO_MAX = 100 - SUMMERY_DATA_NUMS = 256 - ONE_HUNDRED_MB = 100 * 1024 * 1024 - FLOAT_EPSILON = np.finfo(float).eps - SUPPORT_DUMP_MODE = ['api', 'acl'] - ON = 'ON' - OFF = 'OFF' - BACKWARD = 'backward' - FORWARD = 'forward' - FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] - BOOL_TYPE = [bool, np.uint8] - INT_TYPE = [np.int32, np.int64] - - # dump mode - ALL = "all" - LIST = "list" - RANGE = "range" - STACK = "stack" - ACL = "acl" - API_LIST = "api_list" - API_STACK = "api_stack" - DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] - - WRITE_FLAGS = os.O_WRONLY | os.O_CREAT - WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR - - RAISE_PRECISION = { - torch.float16: torch.float32, - torch.bfloat16: torch.float32, - torch.float32: torch.float64 - } - CONVERT = { - "int32_to_int64": ["torch.int32", "torch.int64"], - } - - CONVERT_API = { - "int32_to_int64": ["cross_entropy"] - } - - -class CompareConst: - """ - Class for compare module const - """ - # compare result column name - NPU_NAME = "NPU Name" - BENCH_NAME = "Bench Name" - NPU_DTYPE = "NPU Tensor Dtype" - BENCH_DTYPE = "Bench Tensor Dtype" - NPU_SHAPE = "NPU Tensor Shape" - BENCH_SHAPE = "Bench Tensor Shape" - NPU_MAX = "NPU max" - NPU_MIN = "NPU min" - NPU_MEAN = "NPU mean" - BENCH_MAX = "Bench max" - BENCH_MIN = "Bench min" - BENCH_MEAN = "Bench mean" - COSINE = "Cosine" - MAX_ABS_ERR = "MaxAbsErr" - ACCURACY = "Accuracy Reached or Not" - STACK = "NPU_Stack_Info" - ERROR_MESSAGE = "Err_message" - - # compare result data - NAN = 'Nan' - SHAPE_UNMATCH = 'shape unmatched' - DTYPE_UNMATCH = 'dtype unmatched' - - # accuracy standards - COS_THRESHOLD = 0.99 - MAX_ABS_ERR_THRESHOLD = 0.001 - COS_MAX_THRESHOLD = 0.9 - MAX_ABS_ERR_MAX_THRESHOLD = 1 - ACCURACY_CHECK_YES = "Yes" - ACCURACY_CHECK_NO = "No" - ACCURACY_CHECK_UNMATCH = "Unmatched" - - # error message - NO_BENCH = "No bench data matched." 
- - -class VersionCheck: - """ - Class for TorchVersion - """ - V1_8 = "1.8" - V1_11 = "1.11" - - @staticmethod - def check_torch_version(version): - torch_version = torch.__version__ - if torch_version.startswith(version): - return True - else: - return False - - -class CompareException(Exception): - """ - Class for Accuracy Compare Exception - """ - NONE_ERROR = 0 - INVALID_PATH_ERROR = 1 - OPEN_FILE_ERROR = 2 - CLOSE_FILE_ERROR = 3 - READ_FILE_ERROR = 4 - WRITE_FILE_ERROR = 5 - INVALID_FILE_ERROR = 6 - PERMISSION_ERROR = 7 - INDEX_OUT_OF_BOUNDS_ERROR = 8 - NO_DUMP_FILE_ERROR = 9 - INVALID_DATA_ERROR = 10 - INVALID_PARAM_ERROR = 11 - INVALID_DUMP_RATIO = 12 - INVALID_DUMP_FILE = 13 - UNKNOWN_ERROR = 14 - INVALID_DUMP_MODE = 15 - PARSE_FILE_ERROR = 16 - INVALID_COMPARE_MODE = 17 - - def __init__(self, code, error_info: str = ""): - super(CompareException, self).__init__() - self.code = code - self.error_info = error_info - - def __str__(self): - return self.error_info - - -class DumpException(CompareException): - pass - - -def read_json(file): - with FileOpen(file, 'r') as f: - obj = json.load(f) - return obj - - -def write_csv(data, filepath): - with FileOpen(filepath, 'a', encoding='utf-8-sig') as f: - writer = csv.writer(f) - writer.writerows(data) - - -def _print_log(level, msg, end='\n'): - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getgid() - print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) - sys.stdout.flush() - - -def print_info_log(info_msg, end='\n'): - """ - Function Description: - print info log. - Parameter: - info_msg: the info message. - """ - _print_log("INFO", info_msg, end=end) - - -def print_error_log(error_msg): - """ - Function Description: - print error log. - Parameter: - error_msg: the error message. - """ - _print_log("ERROR", error_msg) - - -def print_warn_log(warn_msg): - """ - Function Description: - print warn log. - Parameter: - warn_msg: the warning message. 
- """ - _print_log("WARNING", warn_msg) - - -def check_mode_valid(mode): - if mode not in Const.DUMP_MODE: - msg = "Current mode '%s' is not supported. Please use the field in %s" % \ - (mode, Const.DUMP_MODE) - raise CompareException(CompareException.INVALID_DUMP_MODE, msg) - - -def check_object_type(check_object, allow_type): - """ - Function Description: - Check if the object belongs to a certain data type - Parameter: - check_object: the object to be checked - allow_type: legal data type - Exception Description: - when invalid data throw exception - """ - if not isinstance(check_object, allow_type): - print_error_log(f"{check_object} not of {allow_type} type") - raise CompareException(CompareException.INVALID_DATA_ERROR) - - -def check_file_or_directory_path(path, isdir=False): - """ - Function Description: - check whether the path is valid - Parameter: - path: the path to check - isdir: the path is dir or file - Exception Description: - when invalid data throw exception - """ - if isdir: - if not os.path.exists(path): - print_error_log('The path {} is not exist.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.path.isdir(path): - print_error_log('The path {} is not a directory.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.access(path, os.W_OK): - print_error_log( - 'The path {} does not have permission to write. Please check the path permission'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - else: - if not os.path.isfile(path): - print_error_log('{} is an invalid file or non-exist.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.access(path, os.R_OK): - print_error_log( - 'The path {} does not have permission to read. 
Please check the path permission'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - -def _check_pkl(pkl_file_handle, file_name): - tensor_line = pkl_file_handle.readline() - if len(tensor_line) == 0: - print_error_log("dump file {} have empty line!".format(file_name)) - raise CompareException(CompareException.INVALID_DUMP_FILE) - pkl_file_handle.seek(0, 0) - - -def check_file_mode(npu_pkl, bench_pkl, stack_mode): - npu_pkl_name = os.path.split(npu_pkl)[-1] - bench_pkl_name = os.path.split(bench_pkl)[-1] - - if not npu_pkl_name.startswith("api_stack") and not bench_pkl_name.startswith("api_stack"): - if stack_mode: - print_error_log("The current file does not contain stack information, please turn off the stack_mode") - raise CompareException(CompareException.INVALID_COMPARE_MODE) - elif npu_pkl_name.startswith("api_stack") and bench_pkl_name.startswith("api_stack"): - if not stack_mode: - print_error_log("The current file contains stack information, please turn on the stack_mode") - raise CompareException(CompareException.INVALID_COMPARE_MODE) - else: - print_error_log("The dump mode of the two files is not same, please check the dump files") - raise CompareException(CompareException.INVALID_COMPARE_MODE) - - -def check_file_size(input_file, max_size): - try: - file_size = os.path.getsize(input_file) - except OSError as os_error: - print_error_log('Failed to open "%s". %s' % (input_file, str(os_error))) - raise CompareException(CompareException.INVALID_FILE_ERROR) from os_error - if file_size > max_size: - print_error_log('The size (%d) of %s exceeds (%d) bytes, tools not support.' 
- % (file_size, input_file, max_size)) - raise CompareException(CompareException.INVALID_FILE_ERROR) - - -def get_dump_data_path(dump_dir): - """ - Function Description: - traverse directories and obtain the absolute path of dump data - Parameter: - dump_dir: dump data directory - Return Value: - dump data path,file is exist or file is not exist - """ - dump_data_path = None - file_is_exist = False - - check_file_or_directory_path(dump_dir, True) - for dir_path, sub_paths, files in os.walk(dump_dir): - if len(files) != 0: - dump_data_path = dir_path - file_is_exist = True - break - dump_data_path = dir_path - return dump_data_path, file_is_exist - - -def modify_dump_path(dump_path, mode): - if mode == Const.ALL: - return dump_path - file_name = os.path.split(dump_path) - mode_file_name = mode + "_" + file_name[-1] - return os.path.join(file_name[0], mode_file_name) - - -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions in a thread-safe manner - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - except OSError as ex: - print_error_log( - 'Failed to create {}. Please check the path permission or disk space. 
{}'.format(dir_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) from ex - - -def execute_command(cmd): - """ - Function Description: - run the following command - Parameter: - cmd: command - Exception Description: - when invalid command throw exception - """ - print_info_log('Execute command:%s' % cmd) - process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - while process.poll() is None: - line = process.stdout.readline() - line = line.strip() - if line: - print(line) - if process.returncode != 0: - print_error_log('Failed to execute command:%s' % " ".join(cmd)) - raise CompareException(CompareException.INVALID_DATA_ERROR) - - -def save_numpy_data(file_path, data): - """ - save_numpy_data - """ - if not os.path.exists(os.path.dirname(file_path)): - os.makedirs(os.path.dirname(file_path)) - np.save(file_path, data) - - -def parse_arg_value(values): - """ - parse dynamic arg value of atc cmdline - """ - value_list = [] - for item in values.split(Const.SEMICOLON): - value_list.append(parse_value_by_comma(item)) - return value_list - - -def parse_value_by_comma(value): - """ - parse value by comma, like '1,2,4,8' - """ - value_list = [] - value_str_list = value.split(Const.COMMA) - for value_str in value_str_list: - value_str = value_str.strip() - if value_str.isdigit() or value_str == '-1': - value_list.append(int(value_str)) - else: - print_error_log("please check your input shape.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - return value_list - - -def get_data_len_by_shape(shape): - data_len = 1 - for item in shape: - if item == -1: - print_error_log("please check your input shape, one dim in shape is -1.") - return -1 - data_len = data_len * item - return data_len - - -def add_time_as_suffix(name): - return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) - - -def get_time(): - return 
datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - - -def format_value(value): - return '{:.6f}'.format(value) - - -def torch_device_guard(func): - if IS_GPU or torch_without_guard_version: - return func - # Parse args/kwargs matched torch.device objects - - @torch_npu_device_guard - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - return wrapper - - -def seed_all(seed=1234, mode=False): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.use_deterministic_algorithms(mode) - if IS_GPU: - torch.cuda.manual_seed_all(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.enable = False - torch.backends.cudnn.benchmark = False - else: - torch_npu.npu.manual_seed_all(seed) - torch_npu.npu.manual_seed(seed) - - -def get_process_rank(model): - print_info_log("Rank id is not provided. Trying to get the rank id of the model.") - try: - device = next(model.parameters()).device - except StopIteration: - print_warn_log('There is no parameter in the model. Fail to get rank id.') - return 0, False - if device.type == 'cpu': - print_warn_log("Warning: the debugger is unable to get the rank id. " - "This may cause the dumpped data to be corrupted in the " - "case of distributed training. (You may ignore this if you are using only one card.) " - "Transfer the model to npu or gpu before register_hook() to avoid this warning.") - return 0, False - else: - return device.index, True - - -def get_json_contents(file_path): - ops = get_file_content_bytes(file_path) - try: - json_obj = json.loads(ops) - except ValueError as error: - print_error_log('Failed to load "%s". %s' % (file_path, str(error))) - raise CompareException(CompareException.INVALID_FILE_ERROR) from error - if not isinstance(json_obj, dict): - print_error_log('Json file %s, content is not a dictionary!' 
% file_path) - raise CompareException(CompareException.INVALID_FILE_ERROR) - return json_obj - - -def get_file_content_bytes(file): - with FileOpen(file, 'rb') as file_handle: - return file_handle.read() - - -def islink(path): - path = os.path.abspath(path) - return os.path.islink(path) - - -class SoftlinkCheckException(Exception): - pass - - -MAX_JSON_FILE_SIZE = 10 * 1024 ** 2 -LINUX_FILE_NAME_LENGTH_LIMIT = 200 - - -def check_path_length_valid(path): - path = os.path.realpath(path) - return len(os.path.basename(path)) <= LINUX_FILE_NAME_LENGTH_LIMIT - - -def check_path_pattern_valid(path): - pattern = re.compile(r'(\.|/|:|_|-|\s|[~0-9a-zA-Z])+') - if not pattern.fullmatch(path): - raise ValueError('Only the following characters are allowed in the path: A-Z a-z 0-9 - _ . / :') - - -def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): - if islink(input_path): - raise SoftlinkCheckException("Input path doesn't support soft link.") - - input_path = os.path.realpath(input_path) - if not os.path.exists(input_path): - raise ValueError('Input file %s does not exist!' % input_path) - - if not os.access(input_path, os.R_OK): - raise PermissionError('Input file %s is not readable!' % input_path) - - if not check_path_length_valid(input_path): - raise ValueError("The real path or file_name of input is too long.") - - check_path_pattern_valid(input_path) - - if os.path.getsize(input_path) > max_file_size: - raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') - - -def check_need_convert(api_name): - convert_type = None - for key, value in Const.CONVERT_API.items(): - if api_name not in value: - continue - else: - convert_type = key - return convert_type - - -def api_info_preprocess(api_name, api_info_dict): - """ - Function Description: - Preprocesses the API information. - Parameter: - api_name: Name of the API. - api_info_dict: argument of the API. - Return api_info_dict: - convert_type: Type of conversion. 
- api_info_dict: Processed argument of the API. - """ - convert_type = check_need_convert(api_name) - if api_name == 'cross_entropy': - api_info_dict = cross_entropy_process(api_info_dict) - return convert_type, api_info_dict - - -def cross_entropy_process(api_info_dict): - """ - Function Description: - Preprocesses the cross_entropy API information. - Parameter: - api_info_dict: argument of the API. - Return api_info_dict: - api_info_dict: Processed argument of the API. - """ - if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: - if api_info_dict['args'][1]['Min'] <= 0: - # The second argument in cross_entropy should be -100 or not less than 0 - api_info_dict['args'][1]['Min'] = 0 - return api_info_dict - - -def initialize_save_path(save_path, dir_name): - data_path = os.path.join(save_path, dir_name) - if os.path.exists(data_path): - print_warn_log(f"{data_path} already exists, it will be overwritten") - else: - os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) - data_path_checker = FileChecker(data_path, FileCheckConst.DIR) - data_path_checker.common_check() - - -def write_pt(file_path, tensor): - if os.path.exists(file_path): - raise ValueError(f"File {file_path} already exists") - torch.save(tensor, file_path) - full_path = os.path.realpath(file_path) - file_check_util.change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY) - return full_path - - -def get_real_data_path(file_path): - targets = ['forward_real_data', 'backward_real_data', 'ut_error_data\d+'] - pattern = re.compile(r'({})'.format('|'.join(targets))) - match = pattern.search(file_path) - if match: - target_index = match.start() - target_path = file_path[target_index:] - return target_path - else: - raise DumpException(DumpException.INVALID_PATH_ERROR) - - -def get_full_data_path(data_path, real_data_path): - if not data_path: - return data_path - full_data_path = os.path.join(real_data_path, data_path) - return 
os.path.realpath(full_data_path) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import collections +import json +import os +import random +import re +import stat +import subprocess +import sys +import time +import csv +from datetime import datetime, timezone + +import numpy as np +import torch + +try: + import torch_npu +except ImportError: + IS_GPU = True +else: + IS_GPU = False + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, FileOpen +from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util + +torch_without_guard_version_list = ['2.1'] +for version in torch_without_guard_version_list: + if torch.__version__.startswith(version): + torch_without_guard_version = True + break + else: + torch_without_guard_version = False +if not IS_GPU and not torch_without_guard_version: + from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard + + +class Const: + """ + Class for const + """ + SEP = '.' + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + MODEL_TYPE = ['.onnx', '.pb', '.om'] + SEMICOLON = ";" + COLON = ":" + EQUAL = "=" + COMMA = "," + DOT = "." 
+ DUMP_RATIO_MAX = 100 + SUMMERY_DATA_NUMS = 256 + ONE_HUNDRED_MB = 100 * 1024 * 1024 + FLOAT_EPSILON = np.finfo(float).eps + SUPPORT_DUMP_MODE = ['api', 'acl'] + ON = 'ON' + OFF = 'OFF' + BACKWARD = 'backward' + FORWARD = 'forward' + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] + BOOL_TYPE = [bool, np.uint8] + INT_TYPE = [np.int32, np.int64] + + # dump mode + ALL = "all" + LIST = "list" + RANGE = "range" + STACK = "stack" + ACL = "acl" + API_LIST = "api_list" + API_STACK = "api_stack" + DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] + + WRITE_FLAGS = os.O_WRONLY | os.O_CREAT + WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR + + RAISE_PRECISION = { + torch.float16: torch.float32, + torch.bfloat16: torch.float32, + torch.float32: torch.float64 + } + CONVERT = { + "int32_to_int64": ["torch.int32", "torch.int64"], + } + + CONVERT_API = { + "int32_to_int64": ["cross_entropy"] + } + + +class CompareConst: + """ + Class for compare module const + """ + # compare result column name + NPU_NAME = "NPU Name" + BENCH_NAME = "Bench Name" + NPU_DTYPE = "NPU Tensor Dtype" + BENCH_DTYPE = "Bench Tensor Dtype" + NPU_SHAPE = "NPU Tensor Shape" + BENCH_SHAPE = "Bench Tensor Shape" + NPU_MAX = "NPU max" + NPU_MIN = "NPU min" + NPU_MEAN = "NPU mean" + BENCH_MAX = "Bench max" + BENCH_MIN = "Bench min" + BENCH_MEAN = "Bench mean" + COSINE = "Cosine" + MAX_ABS_ERR = "MaxAbsErr" + ACCURACY = "Accuracy Reached or Not" + STACK = "NPU_Stack_Info" + ERROR_MESSAGE = "Err_message" + + # compare result data + NAN = 'Nan' + SHAPE_UNMATCH = 'shape unmatched' + DTYPE_UNMATCH = 'dtype unmatched' + + # accuracy standards + COS_THRESHOLD = 0.99 + MAX_ABS_ERR_THRESHOLD = 0.001 + COS_MAX_THRESHOLD = 0.9 + MAX_ABS_ERR_MAX_THRESHOLD = 1 + ACCURACY_CHECK_YES = "Yes" + ACCURACY_CHECK_NO = "No" + ACCURACY_CHECK_UNMATCH = "Unmatched" + + # error message + NO_BENCH = "No bench data matched." 
+ + +class VersionCheck: + """ + Class for TorchVersion + """ + V1_8 = "1.8" + V1_11 = "1.11" + + @staticmethod + def check_torch_version(version): + torch_version = torch.__version__ + if torch_version.startswith(version): + return True + else: + return False + + +class CompareException(Exception): + """ + Class for Accuracy Compare Exception + """ + NONE_ERROR = 0 + INVALID_PATH_ERROR = 1 + OPEN_FILE_ERROR = 2 + CLOSE_FILE_ERROR = 3 + READ_FILE_ERROR = 4 + WRITE_FILE_ERROR = 5 + INVALID_FILE_ERROR = 6 + PERMISSION_ERROR = 7 + INDEX_OUT_OF_BOUNDS_ERROR = 8 + NO_DUMP_FILE_ERROR = 9 + INVALID_DATA_ERROR = 10 + INVALID_PARAM_ERROR = 11 + INVALID_DUMP_RATIO = 12 + INVALID_DUMP_FILE = 13 + UNKNOWN_ERROR = 14 + INVALID_DUMP_MODE = 15 + PARSE_FILE_ERROR = 16 + INVALID_COMPARE_MODE = 17 + + def __init__(self, code, error_info: str = ""): + super(CompareException, self).__init__() + self.code = code + self.error_info = error_info + + def __str__(self): + return self.error_info + + +class DumpException(CompareException): + pass + + +def read_json(file): + with FileOpen(file, 'r') as f: + obj = json.load(f) + return obj + + +def write_csv(data, filepath): + with FileOpen(filepath, 'a', encoding='utf-8-sig') as f: + writer = csv.writer(f) + writer.writerows(data) + + +def _print_log(level, msg, end='\n'): + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) + pid = os.getgid() + print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) + sys.stdout.flush() + + +def print_info_log(info_msg, end='\n'): + """ + Function Description: + print info log. + Parameter: + info_msg: the info message. + """ + _print_log("INFO", info_msg, end=end) + + +def print_error_log(error_msg): + """ + Function Description: + print error log. + Parameter: + error_msg: the error message. + """ + _print_log("ERROR", error_msg) + + +def print_warn_log(warn_msg): + """ + Function Description: + print warn log. + Parameter: + warn_msg: the warning message. 
+ """ + _print_log("WARNING", warn_msg) + + +def check_mode_valid(mode): + if mode not in Const.DUMP_MODE: + msg = "Current mode '%s' is not supported. Please use the field in %s" % \ + (mode, Const.DUMP_MODE) + raise CompareException(CompareException.INVALID_DUMP_MODE, msg) + + +def check_object_type(check_object, allow_type): + """ + Function Description: + Check if the object belongs to a certain data type + Parameter: + check_object: the object to be checked + allow_type: legal data type + Exception Description: + when invalid data throw exception + """ + if not isinstance(check_object, allow_type): + print_error_log(f"{check_object} not of {allow_type} type") + raise CompareException(CompareException.INVALID_DATA_ERROR) + + +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + if not os.path.exists(path): + print_error_log('The path {} is not exist.'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + if not os.path.isdir(path): + print_error_log('The path {} is not a directory.'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + if not os.access(path, os.W_OK): + print_error_log( + 'The path {} does not have permission to write. Please check the path permission'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + else: + if not os.path.isfile(path): + print_error_log('{} is an invalid file or non-exist.'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + if not os.access(path, os.R_OK): + print_error_log( + 'The path {} does not have permission to read. 
Please check the path permission'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + +def _check_pkl(pkl_file_handle, file_name): + tensor_line = pkl_file_handle.readline() + if len(tensor_line) == 0: + print_error_log("dump file {} have empty line!".format(file_name)) + raise CompareException(CompareException.INVALID_DUMP_FILE) + pkl_file_handle.seek(0, 0) + + +def check_file_mode(npu_pkl, bench_pkl, stack_mode): + npu_pkl_name = os.path.split(npu_pkl)[-1] + bench_pkl_name = os.path.split(bench_pkl)[-1] + + if not npu_pkl_name.startswith("api_stack") and not bench_pkl_name.startswith("api_stack"): + if stack_mode: + print_error_log("The current file does not contain stack information, please turn off the stack_mode") + raise CompareException(CompareException.INVALID_COMPARE_MODE) + elif npu_pkl_name.startswith("api_stack") and bench_pkl_name.startswith("api_stack"): + if not stack_mode: + print_error_log("The current file contains stack information, please turn on the stack_mode") + raise CompareException(CompareException.INVALID_COMPARE_MODE) + else: + print_error_log("The dump mode of the two files is not same, please check the dump files") + raise CompareException(CompareException.INVALID_COMPARE_MODE) + + +def check_file_size(input_file, max_size): + try: + file_size = os.path.getsize(input_file) + except OSError as os_error: + print_error_log('Failed to open "%s". %s' % (input_file, str(os_error))) + raise CompareException(CompareException.INVALID_FILE_ERROR) from os_error + if file_size > max_size: + print_error_log('The size (%d) of %s exceeds (%d) bytes, tools not support.' 
+ % (file_size, input_file, max_size)) + raise CompareException(CompareException.INVALID_FILE_ERROR) + + +def get_dump_data_path(dump_dir): + """ + Function Description: + traverse directories and obtain the absolute path of dump data + Parameter: + dump_dir: dump data directory + Return Value: + dump data path,file is exist or file is not exist + """ + dump_data_path = None + file_is_exist = False + + check_file_or_directory_path(dump_dir, True) + for dir_path, sub_paths, files in os.walk(dump_dir): + if len(files) != 0: + dump_data_path = dir_path + file_is_exist = True + break + dump_data_path = dir_path + return dump_data_path, file_is_exist + + +def modify_dump_path(dump_path, mode): + if mode == Const.ALL: + return dump_path + file_name = os.path.split(dump_path) + mode_file_name = mode + "_" + file_name[-1] + return os.path.join(file_name[0], mode_file_name) + + +def create_directory(dir_path): + """ + Function Description: + creating a directory with specified permissions in a thread-safe manner + Parameter: + dir_path: directory path + Exception Description: + when invalid data throw exception + """ + try: + os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) + except OSError as ex: + print_error_log( + 'Failed to create {}. Please check the path permission or disk space. 
{}'.format(dir_path, str(ex))) + raise CompareException(CompareException.INVALID_PATH_ERROR) from ex + + +def execute_command(cmd): + """ + Function Description: + run the following command + Parameter: + cmd: command + Exception Description: + when invalid command throw exception + """ + print_info_log('Execute command:%s' % cmd) + process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + while process.poll() is None: + line = process.stdout.readline() + line = line.strip() + if line: + print(line) + if process.returncode != 0: + print_error_log('Failed to execute command:%s' % " ".join(cmd)) + raise CompareException(CompareException.INVALID_DATA_ERROR) + + +def save_numpy_data(file_path, data): + """ + save_numpy_data + """ + if not os.path.exists(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) + np.save(file_path, data) + + +def parse_arg_value(values): + """ + parse dynamic arg value of atc cmdline + """ + value_list = [] + for item in values.split(Const.SEMICOLON): + value_list.append(parse_value_by_comma(item)) + return value_list + + +def parse_value_by_comma(value): + """ + parse value by comma, like '1,2,4,8' + """ + value_list = [] + value_str_list = value.split(Const.COMMA) + for value_str in value_str_list: + value_str = value_str.strip() + if value_str.isdigit() or value_str == '-1': + value_list.append(int(value_str)) + else: + print_error_log("please check your input shape.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + return value_list + + +def get_data_len_by_shape(shape): + data_len = 1 + for item in shape: + if item == -1: + print_error_log("please check your input shape, one dim in shape is -1.") + return -1 + data_len = data_len * item + return data_len + + +def add_time_as_suffix(name): + return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) + + +def get_time(): + return 
datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + + +def format_value(value): + return '{:.6f}'.format(value) + + +def torch_device_guard(func): + if IS_GPU or torch_without_guard_version: + return func + # Parse args/kwargs matched torch.device objects + + @torch_npu_device_guard + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + + +def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + if IS_GPU: + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enable = False + torch.backends.cudnn.benchmark = False + else: + torch_npu.npu.manual_seed_all(seed) + torch_npu.npu.manual_seed(seed) + + +def get_process_rank(model): + print_info_log("Rank id is not provided. Trying to get the rank id of the model.") + try: + device = next(model.parameters()).device + except StopIteration: + print_warn_log('There is no parameter in the model. Fail to get rank id.') + return 0, False + if device.type == 'cpu': + print_warn_log("Warning: the debugger is unable to get the rank id. " + "This may cause the dumpped data to be corrupted in the " + "case of distributed training. (You may ignore this if you are using only one card.) " + "Transfer the model to npu or gpu before register_hook() to avoid this warning.") + return 0, False + else: + return device.index, True + + +def get_json_contents(file_path): + ops = get_file_content_bytes(file_path) + try: + json_obj = json.loads(ops) + except ValueError as error: + print_error_log('Failed to load "%s". %s' % (file_path, str(error))) + raise CompareException(CompareException.INVALID_FILE_ERROR) from error + if not isinstance(json_obj, dict): + print_error_log('Json file %s, content is not a dictionary!' 
% file_path) + raise CompareException(CompareException.INVALID_FILE_ERROR) + return json_obj + + +def get_file_content_bytes(file): + with FileOpen(file, 'rb') as file_handle: + return file_handle.read() + + +def islink(path): + path = os.path.abspath(path) + return os.path.islink(path) + + +class SoftlinkCheckException(Exception): + pass + + +MAX_JSON_FILE_SIZE = 10 * 1024 ** 2 +LINUX_FILE_NAME_LENGTH_LIMIT = 200 + + +def check_path_length_valid(path): + path = os.path.realpath(path) + return len(os.path.basename(path)) <= LINUX_FILE_NAME_LENGTH_LIMIT + + +def check_path_pattern_valid(path): + pattern = re.compile(r'(\.|/|:|_|-|\s|[~0-9a-zA-Z])+') + if not pattern.fullmatch(path): + raise ValueError('Only the following characters are allowed in the path: A-Z a-z 0-9 - _ . / :') + + +def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): + if islink(input_path): + raise SoftlinkCheckException("Input path doesn't support soft link.") + + input_path = os.path.realpath(input_path) + if not os.path.exists(input_path): + raise ValueError('Input file %s does not exist!' % input_path) + + if not os.access(input_path, os.R_OK): + raise PermissionError('Input file %s is not readable!' % input_path) + + if not check_path_length_valid(input_path): + raise ValueError("The real path or file_name of input is too long.") + + check_path_pattern_valid(input_path) + + if os.path.getsize(input_path) > max_file_size: + raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') + + +def check_need_convert(api_name): + convert_type = None + for key, value in Const.CONVERT_API.items(): + if api_name not in value: + continue + else: + convert_type = key + return convert_type + + +def api_info_preprocess(api_name, api_info_dict): + """ + Function Description: + Preprocesses the API information. + Parameter: + api_name: Name of the API. + api_info_dict: argument of the API. + Return api_info_dict: + convert_type: Type of conversion. 
+ api_info_dict: Processed argument of the API. + """ + convert_type = check_need_convert(api_name) + if api_name == 'cross_entropy': + api_info_dict = cross_entropy_process(api_info_dict) + return convert_type, api_info_dict + + +def cross_entropy_process(api_info_dict): + """ + Function Description: + Preprocesses the cross_entropy API information. + Parameter: + api_info_dict: argument of the API. + Return api_info_dict: + api_info_dict: Processed argument of the API. + """ + if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: + if api_info_dict['args'][1]['Min'] <= 0: + # The second argument in cross_entropy should be -100 or not less than 0 + api_info_dict['args'][1]['Min'] = 0 + return api_info_dict + + +def initialize_save_path(save_path, dir_name): + data_path = os.path.join(save_path, dir_name) + if os.path.exists(data_path): + print_warn_log(f"{data_path} already exists, it will be overwritten") + else: + os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) + data_path_checker = FileChecker(data_path, FileCheckConst.DIR) + data_path_checker.common_check() + + +def write_pt(file_path, tensor): + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + torch.save(tensor, file_path) + full_path = os.path.realpath(file_path) + file_check_util.change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY) + return full_path + + +def get_real_data_path(file_path): + targets = ['forward_real_data', 'backward_real_data', 'ut_error_data\d+'] + pattern = re.compile(r'({})'.format('|'.join(targets))) + match = pattern.search(file_path) + if match: + target_index = match.start() + target_path = file_path[target_index:] + return target_path + else: + raise DumpException(DumpException.INVALID_PATH_ERROR) + + +def get_full_data_path(data_path, real_data_path): + if not data_path: + return data_path + full_data_path = os.path.join(real_data_path, data_path) + return 
os.path.realpath(full_data_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/.keep b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/__init__.py similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/dump/.keep rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/__init__.py diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/algorithm.py similarity index 98% rename from debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/algorithm.py index c92eff25a701c5f0c228d3225fbbb22959d5f929..d94c7fb08384241fd68f11caaefd331a66a793a8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/algorithm.py @@ -1,8 +1,7 @@ # 定义比对算法及比对标准 import torch import numpy as np -from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable -from api_accuracy_checker.common.utils import Const +from calibrator.api_accuracy_checker.compare.compare_utils import CompareConst #cos diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_compare.py similarity index 94% rename from debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_compare.py index ad38cb9561e3e14dc332790b3c06c9cb78716c4e..c39a7776761779816a74bf3639ad7783b57ecc24 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -1,393 +1,391 @@ -import argparse -import os -import sys -import csv -import math -from collections import 
namedtuple -import pandas as pd - -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, print_error_log, write_csv, \ - CompareException, create_directory -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ -API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ - ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \ - convert_str_to_float, CompareMessage -from api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn -from api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, change_mode -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create - - -CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) -unsupported_message = 'This data type does not support benchmark compare.' 
- - -benchmark_algorithms_thresholds = { - 'small_value' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 - }, - 'rmse' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 - }, - 'max_rel_err' : { - 'error_threshold' : 10, - 'warning_threshold' : 1 - }, - 'mean_rel_err' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 - }, - 'eb' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 - } -} - - -benchmark_message = { - "small_value_err_status": { - CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n", - CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n" - }, - "rmse_status": { - CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n", - CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n" - }, - "max_rel_err_status": { - CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n", - CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n" - }, - "mean_rel_err_status": { - CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n", - CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n" - } -} - - -class BenchmarkStandard: - def __init__(self, api_name, npu_precision, gpu_precision): - self.api_name = api_name - self.npu_precision = npu_precision - self.gpu_precision = gpu_precision - self.small_value_err_ratio = 1 - self.rmse_ratio = 1 - self.max_rel_err_ratio = 1 - self.mean_rel_err_ratio = 1 - self.eb_ratio = 1 - self.small_value_err_status = CompareConst.PASS - self.rmse_status = CompareConst.PASS - self.max_rel_err_status = CompareConst.PASS - self.mean_rel_err_status = CompareConst.PASS - self.eb_status = CompareConst.PASS - self.check_result_list = [] - self.final_result = CompareConst.PASS - - def __str__(self): - return "%s" % (self.api_name) - - def get_result(self): - self._compare_ratio() - self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') - self.check_result_list.append(self.small_value_err_status) - self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') - self.check_result_list.append(self.rmse_status) - self.max_rel_err_status = 
self._get_status(self.max_rel_err_ratio, 'max_rel_err') - self.check_result_list.append(self.max_rel_err_status) - self.mean_rel_err_status = self._get_status(self.mean_rel_err_ratio, 'mean_rel_err') - self.check_result_list.append(self.mean_rel_err_status) - self.eb_status = self._get_status(self.eb_ratio, 'eb') - if CompareConst.ERROR in self.check_result_list: - self.final_result = CompareConst.ERROR - elif CompareConst.WARNING in self.check_result_list: - self.final_result = CompareConst.WARNING - - def _compare_ratio(self): - self.small_value_err_ratio = self._calc_ratio( - self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), - self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0) - self.rmse_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.RMSE), - self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0) - self.max_rel_err_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), - self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0) - self.mean_rel_err_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), - self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0) - self.eb_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.EB), - self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0) - - def to_column_value(self): - return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, - self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, - self.mean_rel_err_status, self.eb_ratio, self.eb_status] - - @staticmethod - def _get_status(ratio, algorithm): - error_threshold = benchmark_algorithms_thresholds.get(algorithm).get('error_threshold') - warning_threshold = benchmark_algorithms_thresholds.get(algorithm).get('warning_threshold') - if ratio > error_threshold: - return CompareConst.ERROR - elif 
ratio > warning_threshold: - return CompareConst.WARNING - return CompareConst.PASS - - @staticmethod - def _calc_ratio(x, y, default_value=1.0): - x, y = convert_str_to_float(x), convert_str_to_float(y) - if math.isclose(y, 0.0): - return 1.0 if math.isclose(x, 0.0) else default_value - else: - return abs(x / y) - - -def write_detail_csv(content, save_path): - rows = [] - content = ["{:.{}f}".format(item, msCheckerConfig.precision) \ - if isinstance(item, float) else item for item in content] - rows.append(content) - write_csv(rows, save_path) - - -def api_precision_compare(config): - print_info_log("Start compare task") - print_info_log(f"Compare task result will be saved in {config.result_csv_path}") - print_info_log(f"Compare task detail will be saved in {config.details_csv_path}") - try: - npu_data = pd.read_csv(config.npu_csv_path) - except Exception as err: - print_error_log(f"Open npu csv Error: %s" % str(err)) - check_csv_columns(npu_data.columns, "npu_csv") - try: - gpu_data = pd.read_csv(config.gpu_csv_path) - except Exception as err: - print_error_log(f"Open gpu csv Error: %s" % str(err)) - check_csv_columns(gpu_data.columns, "gpu_csv") - detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()] - result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()] - write_csv(result_csv_title, config.result_csv_path) - write_csv(detail_csv_title, config.details_csv_path) - try: - analyse_csv(npu_data, gpu_data, config) - except Exception as err: - print_error_log(f"Analyse csv Error: %s" % str(err)) - change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) - change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) - - -def analyse_csv(npu_data, gpu_data, config): - forward_status, backward_status = [], [] - full_last_api_name, last_api_dtype = None, None - for _, row_npu in npu_data.iterrows(): - message = '' - compare_column = ApiPrecisionOutputColumn() - full_api_name_with_direction_status = 
row_npu[ApiPrecisionCompareColumn.API_NAME] - row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] - full_api_name, direction_status, _, _ = full_api_name_with_direction_status.split(".") - if row_gpu.empty: - print_warn_log(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.') - continue - if len(row_gpu) > 1: - msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.' - raise CompareException(CompareException.INVALID_DATA_ERROR, msg) - row_gpu = row_gpu.iloc[0] - #当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 - if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace(): - continue - _, api_name, _ = full_api_name.split("*") - new_status = CompareConst.SPACE - compare_column.api_name = full_api_name_with_direction_status - if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in BinaryStandardApi: - new_status = record_binary_consistency_result(api_name, compare_column, row_npu) - elif api_name in AbsoluteStandardApi: - new_status = record_absolute_threshold_result(compare_column, row_npu) - elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST: - bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu) - new_status = record_benchmark_compare_result(compare_column, bs) - write_detail_csv(compare_column.to_column_value(), config.details_csv_path) - - if full_last_api_name is not None and full_api_name != full_last_api_name: - if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: - message = unsupported_message - write_csv([[full_last_api_name, "skip", "skip", message]], config.result_csv_path) - forward_status, backward_status = [], [] - message = '' - else: - forward_result = get_api_checker_result(forward_status) - backward_result = get_api_checker_result(backward_status) - _, last_api_name, _ = full_last_api_name.split("*") - message += 
CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" - write_csv([[full_last_api_name, forward_result, backward_result, message]], config.result_csv_path) - forward_status, backward_status = [], [] - message = '' - - is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST - full_last_api_name = full_api_name - - last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] - if not is_supported: - continue - - if direction_status == 'forward': - forward_status.append(new_status) - elif direction_status == 'backward': - backward_status.append(new_status) - else: - print_error_log(f"Invalid direction status: {direction_status}") - - if full_last_api_name is not None: - if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: - message = unsupported_message - write_csv([[full_last_api_name, "skip", "skip", message]], config.result_csv_path) - else: - forward_result = get_api_checker_result(forward_status) - backward_result = get_api_checker_result(backward_status) - _, last_api_name, _ = full_last_api_name.split("*") - message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" - write_csv([[full_last_api_name, forward_result, backward_result, message]], config.result_csv_path) - - -def check_error_rate(npu_error_rate): - return CompareConst.PASS if convert_str_to_float(npu_error_rate) == 0 else CompareConst.ERROR - - -def get_absolute_threshold_result(row_npu): - inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO]) - rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO]) - abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO]) - - inf_nan_result = CompareConst.PASS if inf_nan_error_ratio == 0 else CompareConst.ERROR - rel_err_result = CompareConst.PASS if rel_err_ratio == 0 else CompareConst.ERROR - abs_err_result = CompareConst.PASS 
if abs_err_ratio == 0 else CompareConst.ERROR - - if CompareConst.ERROR in [inf_nan_result, rel_err_result, abs_err_result]: - absolute_threshold_result = CompareConst.ERROR - else: - absolute_threshold_result = CompareConst.PASS - - return { - "inf_nan_error_ratio": inf_nan_error_ratio, - "inf_nan_result": inf_nan_result, - "rel_err_ratio": rel_err_ratio, - "rel_err_result": rel_err_result, - "abs_err_ratio": abs_err_ratio, - "abs_err_result": abs_err_result, - "absolute_threshold_result": absolute_threshold_result, - } - - -def get_api_checker_result(status): - if not status: - return CompareConst.SPACE - for const in (CompareConst.ERROR, CompareConst.WARNING): - if const in status: - return const - return CompareConst.PASS - - -def check_csv_columns(columns, csv_type): - required_columns = ApiPrecisionCompareColumn.to_required_columns() - missing_columns = [column for column in required_columns if column not in columns] - if missing_columns: - msg = f"The followint columns {','.join(missing_columns)} are missing in{csv_type}" - raise CompareException(CompareException.INVALID_DATA_ERROR, msg) - - -def record_binary_consistency_result(api_name, compare_column, row_npu): - new_status = check_error_rate(row_npu[ApiPrecisionCompareColumn.ERROR_RATE]) - compare_column.error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE] - compare_column.error_rate_status = new_status - compare_column.compare_result = new_status - compare_column.compare_algorithm = "二进制一致法" - message = '' - if compare_column.error_rate_status == CompareConst.ERROR: - message += "ERROR: 二进制一致错误率超过阈值\n" - message += CompareMessage.get(api_name, "") - compare_column.compare_message = message - return new_status - - -def record_absolute_threshold_result(compare_column, row_npu): - absolute_threshold_result = get_absolute_threshold_result(row_npu) - compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio") - compare_column.inf_nan_error_ratio_status = 
absolute_threshold_result.get("inf_nan_result") - compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio") - compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result") - compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio") - compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result") - compare_column.compare_result = absolute_threshold_result.get("absolute_threshold_result") - compare_column.compare_algorithm = "绝对阈值法" - message = '' - if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR: - message += "ERROR: inf/nan错误率超过阈值\n" - if compare_column.rel_err_ratio_status == CompareConst.ERROR: - message += "ERROR: 相对误差错误率超过阈值\n" - if compare_column.abs_err_ratio_status == CompareConst.ERROR: - message += "ERROR: 绝对误差错误率超过阈值\n" - compare_column.compare_message = message - return compare_column.compare_result - - -def record_benchmark_compare_result(compare_column, bs): - bs.get_result() - compare_column.small_value_err_ratio = bs.small_value_err_ratio - compare_column.small_value_err_status = bs.small_value_err_status - compare_column.rmse_ratio = bs.rmse_ratio - compare_column.rmse_status = bs.rmse_status - compare_column.max_rel_err_ratio = bs.max_rel_err_ratio - compare_column.max_rel_err_status = bs.max_rel_err_status - compare_column.mean_rel_err_ratio = bs.mean_rel_err_ratio - compare_column.mean_rel_err_status = bs.mean_rel_err_status - compare_column.eb_ratio = bs.eb_ratio - compare_column.eb_status = bs.eb_status - compare_column.compare_result = bs.final_result - compare_column.compare_algorithm = "标杆比对法" - message = '' - for status_attr, messages in benchmark_message.items(): - status_value = getattr(compare_column, status_attr) - if status_value in messages: - message += messages[status_value] - compare_column.compare_message = message - return compare_column.compare_result - - -def _api_precision_compare(parser=None): - if not parser: - parser = 
argparse.ArgumentParser() - _api_precision_compare_parser(parser) - args = parser.parse_args(sys.argv[1:]) - _api_precision_compare_command(args) - - -def _api_precision_compare_command(args): - npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail') - gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail') - out_path = os.path.realpath(args.out_path) if args.out_path else "./" - check_path_before_create(out_path) - create_directory(out_path) - out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) - out_path = out_path_checker.common_check() - result_csv_path = os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME) - details_csv_path = os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME) - compare_config = CompareConfig(npu_csv_path, gpu_csv_path, result_csv_path, details_csv_path) - api_precision_compare(compare_config) - - -def _api_precision_compare_parser(parser): - parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str, - help=" , Accuracy_checking_details.csv generated on the NPU by using the " - "api_accuracy_checker tool.", - required=True) - parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str, - help=" Accuracy_checking_details.csv generated on the GPU by using the " - "api_accuracy_checker tool.", - required=False) - parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, - help=" The api precision compare task result out path.", - required=False) - - -if __name__ == '__main__': - _api_precision_compare() - print_info_log("Compare task completed.") +import argparse +import os +import sys +import math +from collections import namedtuple +import pandas as pd + +from calibrator.pytorch.api_accuracy_checker.common import print_info_log, print_warn_log, print_error_log, write_csv, \ + CompareException, create_directory +from calibrator.pytorch.api_accuracy_checker.common 
import msCheckerConfig +from calibrator.api_accuracy_checker.compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ +API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ + ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \ + convert_str_to_float, CompareMessage +from calibrator.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn +from calibrator.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path +from calibrator.common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create + + +CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) +unsupported_message = 'This data type does not support benchmark compare.' + + +benchmark_algorithms_thresholds = { + 'small_value' : { + 'error_threshold' : 2, + 'warning_threshold' : 1 + }, + 'rmse' : { + 'error_threshold' : 2, + 'warning_threshold' : 1 + }, + 'max_rel_err' : { + 'error_threshold' : 10, + 'warning_threshold' : 1 + }, + 'mean_rel_err' : { + 'error_threshold' : 2, + 'warning_threshold' : 1 + }, + 'eb' : { + 'error_threshold' : 2, + 'warning_threshold' : 1 + } +} + + +benchmark_message = { + "small_value_err_status": { + CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n", + CompareConst.WARNING: "WARNING: 小值域错误比值超过阈值\n" + }, + "rmse_status": { + CompareConst.ERROR: "ERROR: 均方根误差比值超过阈值\n", + CompareConst.WARNING: "WARNING: 均方根误差比值超过阈值\n" + }, + "max_rel_err_status": { + CompareConst.ERROR: "ERROR: 相对误差最大值比值超过阈值\n", + CompareConst.WARNING: "WARNING: 相对误差最大值比值超过阈值\n" + }, + "mean_rel_err_status": { + CompareConst.ERROR: "ERROR: 相对误差平均值比值超过阈值\n", + CompareConst.WARNING: "WARNING: 相对误差平均值比值超过阈值\n" + } +} + + +class BenchmarkStandard: + def __init__(self, api_name, npu_precision, gpu_precision): + self.api_name = api_name + self.npu_precision = 
npu_precision + self.gpu_precision = gpu_precision + self.small_value_err_ratio = 1 + self.rmse_ratio = 1 + self.max_rel_err_ratio = 1 + self.mean_rel_err_ratio = 1 + self.eb_ratio = 1 + self.small_value_err_status = CompareConst.PASS + self.rmse_status = CompareConst.PASS + self.max_rel_err_status = CompareConst.PASS + self.mean_rel_err_status = CompareConst.PASS + self.eb_status = CompareConst.PASS + self.check_result_list = [] + self.final_result = CompareConst.PASS + + def __str__(self): + return "%s" % (self.api_name) + + def get_result(self): + self._compare_ratio() + self.small_value_err_status = self._get_status(self.small_value_err_ratio, 'small_value') + self.check_result_list.append(self.small_value_err_status) + self.rmse_status = self._get_status(self.rmse_ratio, 'rmse') + self.check_result_list.append(self.rmse_status) + self.max_rel_err_status = self._get_status(self.max_rel_err_ratio, 'max_rel_err') + self.check_result_list.append(self.max_rel_err_status) + self.mean_rel_err_status = self._get_status(self.mean_rel_err_ratio, 'mean_rel_err') + self.check_result_list.append(self.mean_rel_err_status) + self.eb_status = self._get_status(self.eb_ratio, 'eb') + if CompareConst.ERROR in self.check_result_list: + self.final_result = CompareConst.ERROR + elif CompareConst.WARNING in self.check_result_list: + self.final_result = CompareConst.WARNING + + def _compare_ratio(self): + self.small_value_err_ratio = self._calc_ratio( + self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), + self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0) + self.rmse_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.RMSE), + self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0) + self.max_rel_err_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), + self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0) + self.mean_rel_err_ratio = 
self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), + self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0) + self.eb_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.EB), + self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0) + + def to_column_value(self): + return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, + self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, + self.mean_rel_err_status, self.eb_ratio, self.eb_status] + + @staticmethod + def _get_status(ratio, algorithm): + error_threshold = benchmark_algorithms_thresholds.get(algorithm).get('error_threshold') + warning_threshold = benchmark_algorithms_thresholds.get(algorithm).get('warning_threshold') + if ratio > error_threshold: + return CompareConst.ERROR + elif ratio > warning_threshold: + return CompareConst.WARNING + return CompareConst.PASS + + @staticmethod + def _calc_ratio(x, y, default_value=1.0): + x, y = convert_str_to_float(x), convert_str_to_float(y) + if math.isclose(y, 0.0): + return 1.0 if math.isclose(x, 0.0) else default_value + else: + return abs(x / y) + + +def write_detail_csv(content, save_path): + rows = [] + content = ["{:.{}f}".format(item, msCheckerConfig.precision) \ + if isinstance(item, float) else item for item in content] + rows.append(content) + write_csv(rows, save_path) + + +def api_precision_compare(config): + print_info_log("Start compare task") + print_info_log(f"Compare task result will be saved in {config.result_csv_path}") + print_info_log(f"Compare task detail will be saved in {config.details_csv_path}") + try: + npu_data = pd.read_csv(config.npu_csv_path) + except Exception as err: + print_error_log(f"Open npu csv Error: %s" % str(err)) + check_csv_columns(npu_data.columns, "npu_csv") + try: + gpu_data = pd.read_csv(config.gpu_csv_path) + except Exception as err: + print_error_log(f"Open gpu csv Error: %s" % str(err)) 
+ check_csv_columns(gpu_data.columns, "gpu_csv") + detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()] + result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()] + write_csv(result_csv_title, config.result_csv_path) + write_csv(detail_csv_title, config.details_csv_path) + try: + analyse_csv(npu_data, gpu_data, config) + except Exception as err: + print_error_log(f"Analyse csv Error: %s" % str(err)) + change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) + + +def analyse_csv(npu_data, gpu_data, config): + forward_status, backward_status = [], [] + full_last_api_name, last_api_dtype = None, None + for _, row_npu in npu_data.iterrows(): + message = '' + compare_column = ApiPrecisionOutputColumn() + full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME] + row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] + full_api_name, direction_status, _, _ = full_api_name_with_direction_status.split(".") + if row_gpu.empty: + print_warn_log(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.') + continue + if len(row_gpu) > 1: + msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.' 
+ raise CompareException(CompareException.INVALID_DATA_ERROR, msg) + row_gpu = row_gpu.iloc[0] + #当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 + if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace(): + continue + _, api_name, _ = full_api_name.split("*") + new_status = CompareConst.SPACE + compare_column.api_name = full_api_name_with_direction_status + if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in BinaryStandardApi: + new_status = record_binary_consistency_result(api_name, compare_column, row_npu) + elif api_name in AbsoluteStandardApi: + new_status = record_absolute_threshold_result(compare_column, row_npu) + elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST: + bs = BenchmarkStandard(full_api_name_with_direction_status, row_npu, row_gpu) + new_status = record_benchmark_compare_result(compare_column, bs) + write_detail_csv(compare_column.to_column_value(), config.details_csv_path) + + if full_last_api_name is not None and full_api_name != full_last_api_name: + if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: + message = unsupported_message + write_csv([[full_last_api_name, "skip", "skip", message]], config.result_csv_path) + forward_status, backward_status = [], [] + message = '' + else: + forward_result = get_api_checker_result(forward_status) + backward_result = get_api_checker_result(backward_status) + _, last_api_name, _ = full_last_api_name.split("*") + message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" + write_csv([[full_last_api_name, forward_result, backward_result, message]], config.result_csv_path) + forward_status, backward_status = [], [] + message = '' + + is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST + full_last_api_name = full_api_name + + last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] + if not is_supported: + continue 
+ + if direction_status == 'forward': + forward_status.append(new_status) + elif direction_status == 'backward': + backward_status.append(new_status) + else: + print_error_log(f"Invalid direction status: {direction_status}") + + if full_last_api_name is not None: + if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: + message = unsupported_message + write_csv([[full_last_api_name, "skip", "skip", message]], config.result_csv_path) + else: + forward_result = get_api_checker_result(forward_status) + backward_result = get_api_checker_result(backward_status) + _, last_api_name, _ = full_last_api_name.split("*") + message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" + write_csv([[full_last_api_name, forward_result, backward_result, message]], config.result_csv_path) + + +def check_error_rate(npu_error_rate): + return CompareConst.PASS if convert_str_to_float(npu_error_rate) == 0 else CompareConst.ERROR + + +def get_absolute_threshold_result(row_npu): + inf_nan_error_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.INF_NAN_ERROR_RATIO]) + rel_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.REL_ERR_RATIO]) + abs_err_ratio = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.ABS_ERR_RATIO]) + + inf_nan_result = CompareConst.PASS if inf_nan_error_ratio == 0 else CompareConst.ERROR + rel_err_result = CompareConst.PASS if rel_err_ratio == 0 else CompareConst.ERROR + abs_err_result = CompareConst.PASS if abs_err_ratio == 0 else CompareConst.ERROR + + if CompareConst.ERROR in [inf_nan_result, rel_err_result, abs_err_result]: + absolute_threshold_result = CompareConst.ERROR + else: + absolute_threshold_result = CompareConst.PASS + + return { + "inf_nan_error_ratio": inf_nan_error_ratio, + "inf_nan_result": inf_nan_result, + "rel_err_ratio": rel_err_ratio, + "rel_err_result": rel_err_result, + "abs_err_ratio": abs_err_ratio, + "abs_err_result": abs_err_result, + "absolute_threshold_result": 
absolute_threshold_result, + } + + +def get_api_checker_result(status): + if not status: + return CompareConst.SPACE + for const in (CompareConst.ERROR, CompareConst.WARNING): + if const in status: + return const + return CompareConst.PASS + + +def check_csv_columns(columns, csv_type): + required_columns = ApiPrecisionCompareColumn.to_required_columns() + missing_columns = [column for column in required_columns if column not in columns] + if missing_columns: + msg = f"The followint columns {','.join(missing_columns)} are missing in{csv_type}" + raise CompareException(CompareException.INVALID_DATA_ERROR, msg) + + +def record_binary_consistency_result(api_name, compare_column, row_npu): + new_status = check_error_rate(row_npu[ApiPrecisionCompareColumn.ERROR_RATE]) + compare_column.error_rate = row_npu[ApiPrecisionCompareColumn.ERROR_RATE] + compare_column.error_rate_status = new_status + compare_column.compare_result = new_status + compare_column.compare_algorithm = "二进制一致法" + message = '' + if compare_column.error_rate_status == CompareConst.ERROR: + message += "ERROR: 二进制一致错误率超过阈值\n" + message += CompareMessage.get(api_name, "") + compare_column.compare_message = message + return new_status + + +def record_absolute_threshold_result(compare_column, row_npu): + absolute_threshold_result = get_absolute_threshold_result(row_npu) + compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio") + compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result") + compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio") + compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result") + compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio") + compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result") + compare_column.compare_result = absolute_threshold_result.get("absolute_threshold_result") + compare_column.compare_algorithm = 
"绝对阈值法" + message = '' + if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR: + message += "ERROR: inf/nan错误率超过阈值\n" + if compare_column.rel_err_ratio_status == CompareConst.ERROR: + message += "ERROR: 相对误差错误率超过阈值\n" + if compare_column.abs_err_ratio_status == CompareConst.ERROR: + message += "ERROR: 绝对误差错误率超过阈值\n" + compare_column.compare_message = message + return compare_column.compare_result + + +def record_benchmark_compare_result(compare_column, bs): + bs.get_result() + compare_column.small_value_err_ratio = bs.small_value_err_ratio + compare_column.small_value_err_status = bs.small_value_err_status + compare_column.rmse_ratio = bs.rmse_ratio + compare_column.rmse_status = bs.rmse_status + compare_column.max_rel_err_ratio = bs.max_rel_err_ratio + compare_column.max_rel_err_status = bs.max_rel_err_status + compare_column.mean_rel_err_ratio = bs.mean_rel_err_ratio + compare_column.mean_rel_err_status = bs.mean_rel_err_status + compare_column.eb_ratio = bs.eb_ratio + compare_column.eb_status = bs.eb_status + compare_column.compare_result = bs.final_result + compare_column.compare_algorithm = "标杆比对法" + message = '' + for status_attr, messages in benchmark_message.items(): + status_value = getattr(compare_column, status_attr) + if status_value in messages: + message += messages[status_value] + compare_column.compare_message = message + return compare_column.compare_result + + +def _api_precision_compare(parser=None): + if not parser: + parser = argparse.ArgumentParser() + _api_precision_compare_parser(parser) + args = parser.parse_args(sys.argv[1:]) + _api_precision_compare_command(args) + + +def _api_precision_compare_command(args): + npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail') + gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail') + out_path = os.path.realpath(args.out_path) if args.out_path else "./" + check_path_before_create(out_path) + create_directory(out_path) + out_path_checker = 
FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + result_csv_path = os.path.join(out_path, API_PRECISION_COMPARE_RESULT_FILE_NAME) + details_csv_path = os.path.join(out_path, API_PRECISION_COMPARE_DETAILS_FILE_NAME) + compare_config = CompareConfig(npu_csv_path, gpu_csv_path, result_csv_path, details_csv_path) + api_precision_compare(compare_config) + + +def _api_precision_compare_parser(parser): + parser.add_argument("-npu", "--npu_csv_path", dest="npu_csv_path", default="", type=str, + help=" , Accuracy_checking_details.csv generated on the NPU by using the " + "api_accuracy_checker tool.", + required=True) + parser.add_argument("-gpu", "--gpu_csv_path", dest="gpu_csv_path", default="", type=str, + help=" Accuracy_checking_details.csv generated on the GPU by using the " + "api_accuracy_checker tool.", + required=False) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The api precision compare task result out path.", + required=False) + + +if __name__ == '__main__': + _api_precision_compare() + print_info_log("Compare task completed.") \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml similarity index 94% rename from debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml index 4033538b73e9b6a094bbf2c05e0c02bbed607c24..efba9c5c02bbcc094b75ce2497d830789744b143 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_standard.yaml +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml @@ -1,108 +1,108 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd -# All rights reserved. 
-# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -AbsoluteThreshStandard: - - mul - - mul_ - - __mul__ - - __imul__ - - __rmul__ - - add - - add_ - - __add__ - - __iadd__ - - __radd__ - - div - - div_ - - __div__ - - __idiv__ - - divide - - divide_ - - leaky_relu - - leaky_relu_ - - prelu - - reciprocal - - reciprocal_ - - rsqrt - - rsqrt_ - - square - - square_ - - sub - - sub_ - - rsub - - __isub__ - - __sub__ - -BinaryCompareStandard: - - abs - - abs_ - - absolute - - absolute_ - - argmin - - bitwise_and - - bitwise_and_ - - broadcast_to - - ceil - - ceil_ - - equal - - fill_ - - flatten - - floor - - floor_ - - gather - - greater - - greater_ - - greater_equal - - greater_equal_ - - isfinite - - isnan - - less - - less_ - - less_equal - - less_equal_ - - logical_and - - logical_and_ - - logical_not - - logical_not_ - - logical_or - - logical_or_ - - masked_fill - - masked_fill_ - - max_pool3d - - maximum - - minimum - - neg - - neg_ - - nonzero - - not_equal - - not_equal_ - - one_hot - - pad - - relu - - reshape - - round - - round_ - - select - - sign - - sign_ - - sort - - tile - - topk - - transpose - - transpose_ - - tril - - tril_ - - triu - - triu_ +# Copyright (c) 2024 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +AbsoluteThreshStandard: + - mul + - mul_ + - __mul__ + - __imul__ + - __rmul__ + - add + - add_ + - __add__ + - __iadd__ + - __radd__ + - div + - div_ + - __div__ + - __idiv__ + - divide + - divide_ + - leaky_relu + - leaky_relu_ + - prelu + - reciprocal + - reciprocal_ + - rsqrt + - rsqrt_ + - square + - square_ + - sub + - sub_ + - rsub + - __isub__ + - __sub__ + +BinaryCompareStandard: + - abs + - abs_ + - absolute + - absolute_ + - argmin + - bitwise_and + - bitwise_and_ + - broadcast_to + - ceil + - ceil_ + - equal + - fill_ + - flatten + - floor + - floor_ + - gather + - greater + - greater_ + - greater_equal + - greater_equal_ + - isfinite + - isnan + - less + - less_ + - less_equal + - less_equal_ + - logical_and + - logical_and_ + - logical_not + - logical_not_ + - logical_or + - logical_or_ + - masked_fill + - masked_fill_ + - max_pool3d + - maximum + - minimum + - neg + - neg_ + - nonzero + - not_equal + - not_equal_ + - one_hot + - pad + - relu + - reshape + - round + - round_ + - select + - sign + - sign_ + - sort + - tile + - topk + - transpose + - transpose_ + - tril + - tril_ + - triu + - triu_ diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_threshold.yaml b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml similarity index 95% rename from debug/accuracy_tools/api_accuracy_checker/compare/api_precision_threshold.yaml rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml index 
7565112da3122c636ac1e67a8494bbaed51d17c7..0684bd8e9129653b6b69afcf43ab19207006801f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_threshold.yaml +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/api_precision_threshold.yaml @@ -1,390 +1,390 @@ -mul: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -mul_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__mul__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__imul__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__rmul__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -add: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -add_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: 
- rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__add__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__iadd__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__radd__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -div: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -div_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__div__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__idiv__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -divide: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - 
small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -divide_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -leaky_relu: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -leaky_relu_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -prelu: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -reciprocal: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -reciprocal_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -rsqrt: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -rsqrt_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - 
torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -square: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -square_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -sub: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -sub_: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -rsub: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__isub__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 -__sub__: - torch.float32: - rtol: 0.000001 - small_value: 0.000001 - small_value_atol: 0.000001 - torch.float16: - rtol: 0.001 - small_value: 0.001 - small_value_atol: 0.001 - torch.bfloat16: - rtol: 0.004 - small_value: 0.001 - small_value_atol: 0.001 +mul: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + 
small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +mul_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__mul__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__imul__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__rmul__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +add: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +add_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__add__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__iadd__: + torch.float32: + rtol: 0.000001 
+ small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__radd__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +div: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +div_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__div__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__idiv__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +divide: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +divide_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +leaky_relu: + 
torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +leaky_relu_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +prelu: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +reciprocal: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +reciprocal_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +rsqrt: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +rsqrt_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +square: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 
+ small_value_atol: 0.001 +square_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +sub: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +sub_: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +rsub: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__isub__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 +__sub__: + torch.float32: + rtol: 0.000001 + small_value: 0.000001 + small_value_atol: 0.000001 + torch.float16: + rtol: 0.001 + small_value: 0.001 + small_value_atol: 0.001 + torch.bfloat16: + rtol: 0.004 + small_value: 0.001 + small_value_atol: 0.001 diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare.py similarity index 97% rename from debug/accuracy_tools/api_accuracy_checker/compare/compare.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare.py index bd10f77976642331fa8e7bce28a703f0922c1411..e3ab9eda431a3a03243546ca4b36053e2e0f18bb 100644 --- 
a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare.py @@ -5,16 +5,16 @@ import torch import numpy as np from rich.table import Table from rich.console import Console -from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log -from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ +from calibrator.pytorch.api_accuracy_checker.common import get_json_contents, write_csv, print_warn_log, Const +from calibrator.pytorch.api_accuracy_checker import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, apis_threshold -from api_accuracy_checker.compare.compare_column import CompareColumn -from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err, \ +from calibrator.pytorch.api_accuracy_checker import CompareColumn +from calibrator.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err, \ get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps -from api_accuracy_checker.common.config import msCheckerConfig -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig +from calibrator.common.file_check import FileOpen class Comparator: @@ -159,7 +159,7 @@ class Comparator: self.write_detail_csv(args) def compare_output(self, full_api_name, bench_output, device_output, bench_grad=None, npu_grad=None): - _, api_name, _ = full_api_name.split("*") + _, api_name, _ = full_api_name.split(Const.SEP) 
compare_func = self._compare_dropout if "dropout" in full_api_name else self._compare_core_wrapper fwd_success_status, fwd_compare_alg_results = compare_func(api_name, bench_output, device_output) if not (bench_grad and npu_grad): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare_column.py similarity index 97% rename from debug/accuracy_tools/api_accuracy_checker/compare/compare_column.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare_column.py index 961fce6811efd34789cb06f19d894244da681c33..7d42547909232bde4f01164480f7dc6fe0a87a84 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare_column.py @@ -1,4 +1,4 @@ -from api_accuracy_checker.compare.compare_utils import CompareConst +from calibrator.pytorch.api_accuracy_checker import CompareConst class CompareColumn: diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare_utils.py similarity index 97% rename from debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare_utils.py index bce185a9cb9427610022443bfb48dc2e9803089a..4eb7fefac9d5d795464bd57debad509a97c9d248 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -3,8 +3,8 @@ import os import numpy as np import torch import yaml -from api_accuracy_checker.common.utils import Const, print_warn_log, CompareException -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from calibrator.pytorch.api_accuracy_checker.common import Const, print_warn_log, CompareException +from 
calibrator.common.file_check import FileOpen current_time = time.strftime("%Y%m%d%H%M%S") diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/config.yaml similarity index 94% rename from debug/accuracy_tools/api_accuracy_checker/config.yaml rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/config.yaml index a6e70c57ebaec9141434499cfebe2aed6c21a7be..e2582c4539c9408102d3496242651cedeeefeb22 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/config.yaml @@ -1,9 +1,9 @@ -dump_path: './' -real_data: False -enable_dataloader: False -target_iter: [1] -white_list: [] -error_data_path: './' -jit_compile: True -precision: 14 +dump_path: './' +real_data: False +enable_dataloader: False +target_iter: [1] +white_list: [] +error_data_path: './' +jit_compile: True +precision: 14 \ No newline at end of file diff --git "a/debug/accuracy_tools/api_accuracy_checker/doc/API Accuracy Checker\351\242\204\346\243\200\345\267\245\345\205\267\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" "b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/doc/API Accuracy Checker\351\242\204\346\243\200\345\267\245\345\205\267\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" similarity index 100% rename from "debug/accuracy_tools/api_accuracy_checker/doc/API Accuracy Checker\351\242\204\346\243\200\345\267\245\345\205\267\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" rename to "debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/doc/API Accuracy Checker\351\242\204\346\243\200\345\267\245\345\205\267\346\240\207\345\207\206\346\200\247\350\203\275\345\237\272\347\272\277\346\212\245\345\221\212.md" diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/.keep 
b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/.keep similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/.keep rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/.keep diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9602292b85f753fd132634b98c74c76460997b0 --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/__init__.py @@ -0,0 +1 @@ +__all__ = ['set_dump_switch'] diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/api_info.py similarity index 94% rename from debug/accuracy_tools/api_accuracy_checker/dump/api_info.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/api_info.py index 50ad39166fb03ae210af92217f424cfdbe6e1eb4..9ffadd6f63a75e7630b3b501442aaa342920d93e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/api_info.py @@ -1,237 +1,237 @@ -# 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 -import os -import inspect -import torch -import numpy as np -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory, DumpException, \ - get_real_data_path -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create - - -def get_tensor_extremum(data, operator): - if data.dtype is torch.bool: - if data.numel() == 0: - return False, False - if operator == 'max': - return True in data, True in data - elif operator == 'min': - return False not in data, False not in data - data_clone = data.float().clone().detach() - if operator == 'max': - max_result = 
torch._C._VariableFunctionsClass.max(data_clone).item() - if np.isinf(max_result) or np.isnan(max_result): - return handle_tensor_extremum_nan_inf(data_clone, operator), max_result - else: - return max_result, max_result - else: - min_result = torch._C._VariableFunctionsClass.min(data_clone).item() - if np.isinf(min_result) or np.isnan(min_result): - return handle_tensor_extremum_nan_inf(data_clone, operator), min_result - else: - return min_result, min_result - - -def handle_tensor_extremum_nan_inf(data_clone, operator): - data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) - if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): - return float('nan') - finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) - if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(data_no_nan).item() - - -def get_type_name(name): - left = name.index("'") - right = name.rindex("'") - return name[left + 1: right] - - -def transfer_types(data, dtype): - if 'int' in dtype or 'bool' in dtype: - return int(data) - else: - return float(data) - - -def is_builtin_class(element): - return element is None or isinstance(element, (bool, int, float, str, slice)) - - -def analyze_device_in_kwargs(element): - single_arg = {} - single_arg.update({'type': 'torch.device'}) - if not isinstance(element, str): - if hasattr(element, "index"): - device_value = element.type + ":" + str(element.index) - else: - device_value = element.type - single_arg.update({'value': device_value}) - else: - single_arg.update({'value': element}) - return single_arg - - -def 
analyze_dtype_in_kwargs(element): - single_arg = {} - single_arg.update({'type': 'torch.dtype'}) - single_arg.update({'value': str(element)}) - return single_arg - - -class APIInfo: - def __init__(self, api_name, save_path, is_save_data=False): - self.api_name = api_name - self.torch_object_key = {'device': analyze_device_in_kwargs, 'dtype': analyze_dtype_in_kwargs} - self.rank = os.getpid() - self.is_save_data = is_save_data - self.save_path = save_path - self.args_num = 0 - - @staticmethod - def get_full_save_path(save_path, dir_name, contain_step=False): - if contain_step: - from api_accuracy_checker.dump.dump import DumpUtil - step_dir = "step" + str(DumpUtil.call_num - 1 if msCheckerConfig.enable_dataloader else DumpUtil.call_num) - rank_dir = f"rank{os.getpid()}" - return os.path.join(save_path, step_dir, dir_name, rank_dir) - else: - return os.path.join(save_path, dir_name) - - def analyze_element(self, element): - if isinstance(element, (list, tuple)): - out = [] - for item in element: - out.append(self.analyze_element(item)) - return out - - if isinstance(element, dict): - out_dict = {} - for key, value in element.items(): - if key in self.torch_object_key.keys(): - fun = self.torch_object_key[key] - out_dict[key] = fun(value) - else: - out_dict[key] = self.analyze_element(value) - return out_dict - - converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) - if converted_numpy is not element: - return self._analyze_numpy(converted_numpy, numpy_type) - - if isinstance(element, torch.Tensor): - return self._analyze_tensor(element) - - if is_builtin_class(element): - return self._analyze_builtin(element) - - msg = f"Type {type(element)} is unsupported at analyze_element" - print_error_log(msg) - raise DumpException(DumpException.INVALID_DATA_ERROR) - - def _analyze_tensor(self, arg): - single_arg = {} - if not self.is_save_data: - single_arg.update({'type': 'torch.Tensor'}) - single_arg.update({'dtype': str(arg.dtype)}) - 
single_arg.update({'shape': arg.shape}) - max_handle, max_origin = get_tensor_extremum(arg, 'max') - single_arg.update({'Max': transfer_types(max_handle, str(arg.dtype))}) - single_arg.update({'Max_origin': transfer_types(max_origin, str(arg.dtype))}) - min_handle, min_origin = get_tensor_extremum(arg, 'min') - single_arg.update({'Min': transfer_types(min_handle, str(arg.dtype))}) - single_arg.update({'Min_origin': transfer_types(min_origin, str(arg.dtype))}) - single_arg.update({'requires_grad': arg.requires_grad}) - else: - api_args = self.api_name + '.' + str(self.args_num) - check_path_before_create(self.save_path) - create_directory(self.save_path) - file_path = os.path.join(self.save_path, f'{api_args}.pt') - pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) - self.args_num += 1 - real_data_path = get_real_data_path(pt_path) - single_arg.update({'type': 'torch.Tensor'}) - single_arg.update({'datapath': real_data_path}) - single_arg.update({'requires_grad': arg.requires_grad}) - return single_arg - - def _analyze_builtin(self, arg): - single_arg = {} - if self.is_save_data: - self.args_num += 1 - if isinstance(arg, slice): - single_arg.update({'type': "slice"}) - single_arg.update({'value': [arg.start, arg.stop, arg.step]}) - else: - single_arg.update({'type': get_type_name(str(type(arg)))}) - single_arg.update({'value': arg}) - return single_arg - - def _analyze_numpy(self, value, numpy_type): - single_arg = {} - if self.is_save_data: - self.args_num += 1 - single_arg.update({'type': numpy_type}) - single_arg.update({'value': value}) - return single_arg - - def _convert_numpy_to_builtin(self, arg): - type_mapping = { - np.integer: int, - np.floating: float, - np.bool_: bool, - np.complexfloating: complex, - np.str_: str, - np.bytes_: bytes, - np.unicode_: str - } - for numpy_type, builtin_type in type_mapping.items(): - if isinstance(arg, numpy_type): - return builtin_type(arg), get_type_name(str(type(arg))) - return arg, '' - - -class 
ForwardAPIInfo(APIInfo): - def __init__(self, name, args, kwargs): - super().__init__(name, - self.get_full_save_path(msCheckerConfig.dump_path, 'forward_real_data', contain_step=True), - is_save_data=msCheckerConfig.real_data) - self.api_info_struct = {} - self.stack_info_struct = {} - self.analyze_api_input(args, kwargs) - self.analyze_api_call_stack() - - def analyze_api_input(self, args, kwargs): - args_info_list = self.analyze_element(args) - kwargs_info_dict = self.analyze_element(kwargs) - self.api_info_struct = {self.api_name: {"args": args_info_list, "kwargs": kwargs_info_dict}} - - def analyze_api_call_stack(self): - stack_str = [] - for (_, path, line, func, code, _) in inspect.stack()[3:]: - if not code: - continue - stack_line = " ".join([ - "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), - " ".join(["\n", code[0].strip()])])]) - stack_str.append(stack_line) - self.stack_info_struct = {self.api_name: stack_str} - - -class BackwardAPIInfo(APIInfo): - def __init__(self, name, grads): - super().__init__(name, - self.get_full_save_path(msCheckerConfig.dump_path, 'backward_real_data', contain_step=True), - is_save_data=msCheckerConfig.real_data) - self.grad_info_struct = {} - self.analyze_api_input(grads) - - def analyze_api_input(self, grads): - grads_info_list = self.analyze_element(grads) - self.grad_info_struct = {self.api_name: grads_info_list} +# 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 +import os +import inspect +import torch +import numpy as np +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig +from calibrator.pytorch.api_accuracy_checker.common import print_error_log, write_pt, create_directory, DumpException, \ + get_real_data_path +from calibrator.common.file_check import check_path_before_create + + +def get_tensor_extremum(data, operator): + if data.dtype is torch.bool: + if data.numel() == 0: + return False, False + if operator == 'max': + return True in data, True in data + 
elif operator == 'min': + return False not in data, False not in data + data_clone = data.float().clone().detach() + if operator == 'max': + max_result = torch._C._VariableFunctionsClass.max(data_clone).item() + if np.isinf(max_result) or np.isnan(max_result): + return handle_tensor_extremum_nan_inf(data_clone, operator), max_result + else: + return max_result, max_result + else: + min_result = torch._C._VariableFunctionsClass.min(data_clone).item() + if np.isinf(min_result) or np.isnan(min_result): + return handle_tensor_extremum_nan_inf(data_clone, operator), min_result + else: + return min_result, min_result + + +def handle_tensor_extremum_nan_inf(data_clone, operator): + data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) + if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): + return float('nan') + finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) + if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: + finite_values = data_clone[finite_mask] + return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(finite_values).item() + else: + data_no_nan = data_clone[~data_nan] + return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(data_no_nan).item() + + +def get_type_name(name): + left = name.index("'") + right = name.rindex("'") + return name[left + 1: right] + + +def transfer_types(data, dtype): + if 'int' in dtype or 'bool' in dtype: + return int(data) + else: + return float(data) + + +def is_builtin_class(element): + return element is None or isinstance(element, (bool, int, float, str, slice)) + + +def analyze_device_in_kwargs(element): + single_arg = {} + single_arg.update({'type': 'torch.device'}) + if not isinstance(element, str): + if hasattr(element, "index"): + device_value = element.type + ":" + str(element.index) + else: + device_value = 
element.type + single_arg.update({'value': device_value}) + else: + single_arg.update({'value': element}) + return single_arg + + +def analyze_dtype_in_kwargs(element): + single_arg = {} + single_arg.update({'type': 'torch.dtype'}) + single_arg.update({'value': str(element)}) + return single_arg + + +class APIInfo: + def __init__(self, api_name, save_path, is_save_data=False): + self.api_name = api_name + self.torch_object_key = {'device': analyze_device_in_kwargs, 'dtype': analyze_dtype_in_kwargs} + self.rank = os.getpid() + self.is_save_data = is_save_data + self.save_path = save_path + self.args_num = 0 + + @staticmethod + def get_full_save_path(save_path, dir_name, contain_step=False): + if contain_step: + from calibrator.pytorch.api_accuracy_checker.dump.dump import DumpUtil + step_dir = "step" + str(DumpUtil.call_num - 1 if msCheckerConfig.enable_dataloader else DumpUtil.call_num) + rank_dir = f"rank{os.getpid()}" + return os.path.join(save_path, step_dir, dir_name, rank_dir) + else: + return os.path.join(save_path, dir_name) + + def analyze_element(self, element): + if isinstance(element, (list, tuple)): + out = [] + for item in element: + out.append(self.analyze_element(item)) + return out + + if isinstance(element, dict): + out_dict = {} + for key, value in element.items(): + if key in self.torch_object_key.keys(): + fun = self.torch_object_key[key] + out_dict[key] = fun(value) + else: + out_dict[key] = self.analyze_element(value) + return out_dict + + converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) + if converted_numpy is not element: + return self._analyze_numpy(converted_numpy, numpy_type) + + if isinstance(element, torch.Tensor): + return self._analyze_tensor(element) + + if is_builtin_class(element): + return self._analyze_builtin(element) + + msg = f"Type {type(element)} is unsupported at analyze_element" + print_error_log(msg) + raise DumpException(DumpException.INVALID_DATA_ERROR) + + def _analyze_tensor(self, arg): + 
single_arg = {} + if not self.is_save_data: + single_arg.update({'type': 'torch.Tensor'}) + single_arg.update({'dtype': str(arg.dtype)}) + single_arg.update({'shape': arg.shape}) + max_handle, max_origin = get_tensor_extremum(arg, 'max') + single_arg.update({'Max': transfer_types(max_handle, str(arg.dtype))}) + single_arg.update({'Max_origin': transfer_types(max_origin, str(arg.dtype))}) + min_handle, min_origin = get_tensor_extremum(arg, 'min') + single_arg.update({'Min': transfer_types(min_handle, str(arg.dtype))}) + single_arg.update({'Min_origin': transfer_types(min_origin, str(arg.dtype))}) + single_arg.update({'requires_grad': arg.requires_grad}) + else: + api_args = self.api_name + '.' + str(self.args_num) + check_path_before_create(self.save_path) + create_directory(self.save_path) + file_path = os.path.join(self.save_path, f'{api_args}.pt') + pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) + self.args_num += 1 + real_data_path = get_real_data_path(pt_path) + single_arg.update({'type': 'torch.Tensor'}) + single_arg.update({'datapath': real_data_path}) + single_arg.update({'requires_grad': arg.requires_grad}) + return single_arg + + def _analyze_builtin(self, arg): + single_arg = {} + if self.is_save_data: + self.args_num += 1 + if isinstance(arg, slice): + single_arg.update({'type': "slice"}) + single_arg.update({'value': [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({'type': get_type_name(str(type(arg)))}) + single_arg.update({'value': arg}) + return single_arg + + def _analyze_numpy(self, value, numpy_type): + single_arg = {} + if self.is_save_data: + self.args_num += 1 + single_arg.update({'type': numpy_type}) + single_arg.update({'value': value}) + return single_arg + + def _convert_numpy_to_builtin(self, arg): + type_mapping = { + np.integer: int, + np.floating: float, + np.bool_: bool, + np.complexfloating: complex, + np.str_: str, + np.bytes_: bytes, + np.unicode_: str + } + for numpy_type, builtin_type in 
type_mapping.items(): + if isinstance(arg, numpy_type): + return builtin_type(arg), get_type_name(str(type(arg))) + return arg, '' + + +class ForwardAPIInfo(APIInfo): + def __init__(self, name, args, kwargs): + super().__init__(name, + self.get_full_save_path(msCheckerConfig.dump_path, 'forward_real_data', contain_step=True), + is_save_data=msCheckerConfig.real_data) + self.api_info_struct = {} + self.stack_info_struct = {} + self.analyze_api_input(args, kwargs) + self.analyze_api_call_stack() + + def analyze_api_input(self, args, kwargs): + args_info_list = self.analyze_element(args) + kwargs_info_dict = self.analyze_element(kwargs) + self.api_info_struct = {self.api_name: {"args": args_info_list, "kwargs": kwargs_info_dict}} + + def analyze_api_call_stack(self): + stack_str = [] + for (_, path, line, func, code, _) in inspect.stack()[3:]: + if not code: + continue + stack_line = " ".join([ + "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), + " ".join(["\n", code[0].strip()])])]) + stack_str.append(stack_line) + self.stack_info_struct = {self.api_name: stack_str} + + +class BackwardAPIInfo(APIInfo): + def __init__(self, name, grads): + super().__init__(name, + self.get_full_save_path(msCheckerConfig.dump_path, 'backward_real_data', contain_step=True), + is_save_data=msCheckerConfig.real_data) + self.grad_info_struct = {} + self.analyze_api_input(grads) + + def analyze_api_input(self, grads): + grads_info_list = self.analyze_element(grads) + self.grad_info_struct = {self.api_name: grads_info_list} diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/dump.py similarity index 86% rename from debug/accuracy_tools/api_accuracy_checker/dump/dump.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/dump.py index d8b317aa282bb74f7b35c6a4b6216446959fb30e..01e788a9a51837396cfb79b2a3639b3e7f190464 100644 --- 
a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/dump.py @@ -1,109 +1,109 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json -from api_accuracy_checker.common.utils import print_error_log, CompareException, print_info_log -from api_accuracy_checker.hook_module.register_hook import initialize_hook -from api_accuracy_checker.common.config import msCheckerConfig - - -def set_dump_switch(switch): - if switch not in ["ON", "OFF"]: - print_error_log("Please set switch with 'ON' or 'OFF'.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if switch == "ON": - initialize_hook(pretest_hook) - initialize_output_json() - DumpUtil.set_dump_switch(switch) - - -def check_dataloader_status(): - if msCheckerConfig.enable_dataloader: - error_info = ("If you want to use this function, set enable_dataloader " - "in the accuracy_tools/api_accuracy_check/config.yaml " - "to False first") - raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) - - -def start(): - check_dataloader_status() - if not DumpUtil.get_dump_switch(): - DumpUtil.incr_iter_num_maybe_exit() - - -def stop(): 
- check_dataloader_status() - DumpUtil.set_dump_switch("OFF") - - -def step(): - check_dataloader_status() - DumpUtil.call_num += 1 - - -class DumpUtil(object): - dump_switch = None - call_num = 0 - - @staticmethod - def set_dump_switch(switch): - DumpUtil.dump_switch = switch - - @staticmethod - def get_dump_switch(): - return DumpUtil.dump_switch == "ON" - - @staticmethod - def incr_iter_num_maybe_exit(): - if DumpUtil.call_num in msCheckerConfig.target_iter: - set_dump_switch("ON") - elif DumpUtil.call_num > max(msCheckerConfig.target_iter): - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) - else: - set_dump_switch("OFF") - - -class DumpConst: - delimiter = '*' - forward = 'forward' - backward = 'backward' - - -def pretest_info_dump(name, out_feat, module, phase): - if not DumpUtil.get_dump_switch(): - return - if phase == DumpConst.forward: - api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) - elif phase == DumpConst.backward: - api_info = BackwardAPIInfo(name, out_feat) - else: - msg = "Unexpected training phase {}.".format(phase) - print_error_log(msg) - raise NotImplementedError(msg) - print_info_log(f"tools is dumping api: {name}" + " " * 10, end='\r') - write_api_info_json(api_info) - - -def pretest_hook(name, phase): - def pretest_info_dump_hook(module, in_feat, out_feat): - pretest_info_dump(name, out_feat, module, phase) - if hasattr(module, "input_args"): - del module.input_args - if hasattr(module, "input_kwargs"): - del module.input_kwargs - return pretest_info_dump_hook +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from calibrator.pytorch.api_accuracy_checker import ForwardAPIInfo, BackwardAPIInfo +from calibrator.pytorch.api_accuracy_checker import write_api_info_json, initialize_output_json +from calibrator.pytorch.api_accuracy_checker.common import print_error_log, CompareException, print_info_log +from calibrator.pytorch.api_accuracy_checker import initialize_hook +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig + + +def set_dump_switch(switch): + if switch not in ["ON", "OFF"]: + print_error_log("Please set switch with 'ON' or 'OFF'.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + if switch == "ON": + initialize_hook(pretest_hook) + initialize_output_json() + DumpUtil.set_dump_switch(switch) + + +def check_dataloader_status(): + if msCheckerConfig.enable_dataloader: + error_info = ("If you want to use this function, set enable_dataloader " + "in the accuracy_tools/api_accuracy_check/config.yaml " + "to False first") + raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) + + +def start(): + check_dataloader_status() + if not DumpUtil.get_dump_switch(): + DumpUtil.incr_iter_num_maybe_exit() + + +def stop(): + check_dataloader_status() + DumpUtil.set_dump_switch("OFF") + + +def step(): + check_dataloader_status() + DumpUtil.call_num += 1 + + +class DumpUtil(object): + dump_switch = None + call_num = 0 + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + @staticmethod + def 
incr_iter_num_maybe_exit(): + if DumpUtil.call_num in msCheckerConfig.target_iter: + set_dump_switch("ON") + elif DumpUtil.call_num > max(msCheckerConfig.target_iter): + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.call_num - 1)) + else: + set_dump_switch("OFF") + + +class DumpConst: + delimiter = '*' + forward = 'forward' + backward = 'backward' + + +def pretest_info_dump(name, out_feat, module, phase): + if not DumpUtil.get_dump_switch(): + return + if phase == DumpConst.forward: + api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) + elif phase == DumpConst.backward: + api_info = BackwardAPIInfo(name, out_feat) + else: + msg = "Unexpected training phase {}.".format(phase) + print_error_log(msg) + raise NotImplementedError(msg) + print_info_log(f"tools is dumping api: {name}" + " " * 10, end='\r') + write_api_info_json(api_info) + + +def pretest_hook(name, phase): + def pretest_info_dump_hook(module, in_feat, out_feat): + pretest_info_dump(name, out_feat, module, phase) + if hasattr(module, "input_args"): + del module.input_args + if hasattr(module, "input_kwargs"): + del module.input_kwargs + return pretest_info_dump_hook diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/dump_scope.py similarity index 72% rename from debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/dump_scope.py index 1f65dbc9c8a7e482d8ac85e3d06cffc3b11b406a..4fe12f372c5d6205ebf8630496430c80b88d422d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/dump_scope.py @@ -1,8 +1,7 @@ # dump范围控制 -import torch -from torch.utils.data.dataloader import _BaseDataLoaderIter -from api_accuracy_checker.dump.dump import DumpUtil -from api_accuracy_checker.common.config import msCheckerConfig +from 
torch.utils.data.dataloader import _BaseDataLoaderIter +from calibrator.pytorch.api_accuracy_checker.dump.dump import DumpUtil +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig def iter_tracer(original_next): diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/info_dump.py similarity index 83% rename from debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/info_dump.py index c73058e4f3058a9d1bf10b0a14046845f78440ee..64163de74112bf97dcd1333f729e682df205c726 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/dump/info_dump.py @@ -4,20 +4,19 @@ import os import threading import multiprocessing -from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create -from api_accuracy_checker.common.config import msCheckerConfig +from calibrator.pytorch.api_accuracy_checker import ForwardAPIInfo, BackwardAPIInfo +from calibrator.pytorch.api_accuracy_checker.common import check_file_or_directory_path, create_directory +from calibrator.common.file_check import check_path_before_create +from calibrator.common.file_check import FileOpen, FileCheckConst, FileChecker, change_mode +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, change_mode - lock = threading.Lock() proc_lock = multiprocessing.Lock() def write_api_info_json(api_info): - from api_accuracy_checker.dump.dump import DumpUtil + from calibrator.pytorch.api_accuracy_checker.dump.dump import DumpUtil 
dump_path = msCheckerConfig.dump_path dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/.keep b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/.keep similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/run_ut/.keep rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/.keep diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/hook_module.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/hook_module.py similarity index 97% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/hook_module.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/hook_module.py index 64df179e4b980d443dc249e19442e3096f300ff1..02d5fa5500e470a158b980ff889ab4d7a7ec25bf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/hook_module.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/hook_module.py @@ -1,113 +1,113 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - - -import functools - -import torch -import torch.nn as nn -import torch.utils.hooks as full_hooks - -module_count = {} -g_stop_hook = False - - -class HOOKModule(nn.Module): - - def __init__(self, hook) -> None: - super(HOOKModule, self).__init__() - self.has_overflow = False - self.input_args = tuple() - self.input_kwargs = dict() - self._enable_hook = True - prefix = "" - if hasattr(self, "prefix_op_name_"): - prefix = self.prefix_op_name_ - - if prefix not in module_count: - module_count[prefix] = 1 - prefix += '0' - else: - module_count[prefix] += 1 - prefix = prefix + str(module_count[prefix] - 1) - - self.register_forward_hook(hook(prefix, "forward")) - self.register_backward_hook(hook(prefix, "backward")) - - def __call__(self, *inputs, **kwargs): - changed = False - global g_stop_hook - if g_stop_hook: - self._enable_hook = False - else: - g_stop_hook = True - changed = True - result = self._call_func(*inputs, **kwargs) - if changed: - g_stop_hook = False - return result - - def _call_func(self, *inputs, **kwargs): - if self._enable_hook: - full_backward_hooks, non_full_backward_hooks = [], [] - if len(self._backward_hooks) > 0: - full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() - for hook in self._forward_pre_hooks.values(): - result = hook(self, inputs) - if result is not None: - if not isinstance(result, tuple): - result = (result,) - inputs = result - bw_hook = None - if len(full_backward_hooks) > 0: - bw_hook = full_hooks.BackwardHook(self, full_backward_hooks) - inputs = bw_hook.setup_input_hook(inputs) - 
self.input_args = inputs - self.input_kwargs = kwargs - if torch._C._get_tracing_state(): - result = self._slow_forward(*inputs, **kwargs) - else: - result = self.forward(*inputs, **kwargs) - for hook in self._forward_hooks.values(): - hook_result = hook(self, inputs, result) - if hook_result is not None: - result = hook_result - if bw_hook: - result = bw_hook.setup_output_hook(result) - if len(non_full_backward_hooks) > 0: - var = result - while not isinstance(var, torch.Tensor): - if isinstance(var, dict): - var = next((v for v in var.values() if isinstance(v, torch.Tensor))) - elif isinstance(var, (list, tuple)): - if var: - var = var[0] - else: - return result - else: - return result - grad_fn = var.grad_fn - if grad_fn is not None: - for hook in non_full_backward_hooks: - wrapper = functools.partial(hook, self) - functools.update_wrapper(wrapper, hook) - grad_fn.register_hook(wrapper) - self._maybe_warn_non_full_backward_hook(inputs, result, grad_fn) - return result - else: - forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) - return forward_call(*inputs, **kwargs) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + + +import functools + +import torch +import torch.nn as nn +import torch.utils.hooks as full_hooks + +module_count = {} +g_stop_hook = False + + +class HOOKModule(nn.Module): + + def __init__(self, hook) -> None: + super(HOOKModule, self).__init__() + self.has_overflow = False + self.input_args = tuple() + self.input_kwargs = dict() + self._enable_hook = True + prefix = "" + if hasattr(self, "prefix_op_name_"): + prefix = self.prefix_op_name_ + + if prefix not in module_count: + module_count[prefix] = 1 + prefix += '0' + else: + module_count[prefix] += 1 + prefix = prefix + str(module_count[prefix] - 1) + + self.register_forward_hook(hook(prefix, "forward")) + self.register_backward_hook(hook(prefix, "backward")) + + def __call__(self, *inputs, **kwargs): + changed = False + global g_stop_hook + if g_stop_hook: + self._enable_hook = False + else: + g_stop_hook = True + changed = True + result = self._call_func(*inputs, **kwargs) + if changed: + g_stop_hook = False + return result + + def _call_func(self, *inputs, **kwargs): + if self._enable_hook: + full_backward_hooks, non_full_backward_hooks = [], [] + if len(self._backward_hooks) > 0: + full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() + for hook in self._forward_pre_hooks.values(): + result = hook(self, inputs) + if result is not None: + if not isinstance(result, tuple): + result = (result,) + inputs = result + bw_hook = None + if len(full_backward_hooks) > 0: + bw_hook = full_hooks.BackwardHook(self, full_backward_hooks) + inputs = bw_hook.setup_input_hook(inputs) + self.input_args = inputs + self.input_kwargs = kwargs + if torch._C._get_tracing_state(): + result = self._slow_forward(*inputs, **kwargs) + else: + result = self.forward(*inputs, **kwargs) + for hook in self._forward_hooks.values(): + hook_result = hook(self, inputs, result) + if hook_result is not None: + result = hook_result + if bw_hook: + result = bw_hook.setup_output_hook(result) + if 
len(non_full_backward_hooks) > 0: + var = result + while not isinstance(var, torch.Tensor): + if isinstance(var, dict): + var = next((v for v in var.values() if isinstance(v, torch.Tensor))) + elif isinstance(var, (list, tuple)): + if var: + var = var[0] + else: + return result + else: + return result + grad_fn = var.grad_fn + if grad_fn is not None: + for hook in non_full_backward_hooks: + wrapper = functools.partial(hook, self) + functools.update_wrapper(wrapper, hook) + grad_fn.register_hook(wrapper) + self._maybe_warn_non_full_backward_hook(inputs, result, grad_fn) + return result + else: + forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward) + return forward_call(*inputs, **kwargs) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/register_hook.py similarity index 92% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/register_hook.py index b355e029b6b74e2accc9241b42deebe31cb8e5ca..231a272e8a4cbea7491800106b5b6a3e8a7943f8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/register_hook.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/register_hook.py @@ -1,37 +1,37 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -import torch - -from api_accuracy_checker.hook_module import wrap_torch, wrap_functional, wrap_tensor - - -def initialize_hook(hook): - wrap_tensor.wrap_tensor_ops_and_bind(hook) - for attr_name in dir(wrap_tensor.HOOKTensor): - if attr_name.startswith("wrap_"): - setattr(torch.Tensor, attr_name[5:], getattr(wrap_tensor.HOOKTensor, attr_name)) - - wrap_torch.wrap_torch_ops_and_bind(hook) - for attr_name in dir(wrap_torch.HOOKTorchOP): - if attr_name.startswith("wrap_"): - setattr(torch, attr_name[5:], getattr(wrap_torch.HOOKTorchOP, attr_name)) - - wrap_functional.wrap_functional_ops_and_bind(hook) - for attr_name in dir(wrap_functional.HOOKFunctionalOP): - if attr_name.startswith("wrap_"): - setattr(torch.nn.functional, attr_name[5:], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) - +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import torch + +from calibrator.pytorch.api_accuracy_checker import wrap_functional, wrap_torch, wrap_tensor + + +def initialize_hook(hook): + wrap_tensor.wrap_tensor_ops_and_bind(hook) + for attr_name in dir(wrap_tensor.HOOKTensor): + if attr_name.startswith("wrap_"): + setattr(torch.Tensor, attr_name[5:], getattr(wrap_tensor.HOOKTensor, attr_name)) + + wrap_torch.wrap_torch_ops_and_bind(hook) + for attr_name in dir(wrap_torch.HOOKTorchOP): + if attr_name.startswith("wrap_"): + setattr(torch, attr_name[5:], getattr(wrap_torch.HOOKTorchOP, attr_name)) + + wrap_functional.wrap_functional_ops_and_bind(hook) + for attr_name in dir(wrap_functional.HOOKFunctionalOP): + if attr_name.startswith("wrap_"): + setattr(torch.nn.functional, attr_name[5:], getattr(wrap_functional.HOOKFunctionalOP, attr_name)) + diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/support_wrap_ops.yaml similarity index 93% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/support_wrap_ops.yaml index c7ed0a1f81cf7b6b2e17ce0e6c37965567f5e42a..acd4cc0e6e658dd4278f6a67c4f0e8fc288efde6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/support_wrap_ops.yaml @@ -1,999 +1,999 @@ -# Copyright (c) 2023 Huawei Technologies Co., Ltd -# All rights reserved. -# -# Licensed under the BSD 3-Clause License (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# https://opensource.org/licenses/BSD-3-Clause -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# List of ops that register hooks - -functional: - - conv1d - - conv2d - - conv3d - - conv_transpose1d - - conv_transpose2d - - conv_transpose3d - - conv_tbc - - avg_pool1d - - avg_pool2d - - avg_pool3d - - fractional_max_pool2d_with_indices - - fractional_max_pool2d - - fractional_max_pool3d_with_indices - - fractional_max_pool3d - - max_pool1d_with_indices - - max_pool1d - - max_pool2d_with_indices - - max_pool2d - - max_pool3d_with_indices - - max_pool3d - - max_unpool1d - - max_unpool2d - - max_unpool3d - - lp_pool2d - - lp_pool1d - - adaptive_max_pool1d_with_indices - - adaptive_max_pool1d - - adaptive_max_pool2d_with_indices - - adaptive_max_pool2d - - adaptive_max_pool3d_with_indices - - adaptive_max_pool3d - - adaptive_avg_pool1d - - adaptive_avg_pool2d - - adaptive_avg_pool3d - - dropout - - alpha_dropout - - dropout2d - - dropout3d - - feature_alpha_dropout - - threshold - - threshold_ - - relu - - relu_ - - glu - - hardtanh - - hardtanh_ - - relu6 - - elu - - elu_ - - selu - - selu_ - - celu - - celu_ - - leaky_relu - - leaky_relu_ - - prelu - - rrelu - - rrelu_ - - logsigmoid - - gelu - - hardshrink - - tanhshrink - - softsign - - softplus - - softmin - - softmax - - gumbel_softmax - - log_softmax - - softshrink - - tanh - - sigmoid - - hardsigmoid - - linear - - bilinear - - silu - - hardswish - - embedding - - embedding_bag - - batch_norm - - instance_norm - - layer_norm - - group_norm - - local_response_norm - - ctc_loss - - nll_loss - - poisson_nll_loss - - gaussian_nll_loss - - kl_div - - cross_entropy - - binary_cross_entropy - - 
binary_cross_entropy_with_logits - - smooth_l1_loss - - l1_loss - - mse_loss - - margin_ranking_loss - - hinge_embedding_loss - - multilabel_margin_loss - - soft_margin_loss - - multilabel_soft_margin_loss - - cosine_embedding_loss - - multi_margin_loss - - pixel_shuffle - - pixel_unshuffle - - channel_shuffle - - upsample - - interpolate - - upsample_nearest - - upsample_bilinear - - grid_sample - - affine_grid - - pad - - pairwise_distance - - pdist - - cosine_similarity - - one_hot - - triplet_margin_loss - - triplet_margin_with_distance_loss - - normalize - - unfold - - fold - - multi_head_attention_forward - -tensor: - - __add__ - - __and__ - - __bool__ - - __div__ - - __eq__ - - __ge__ - - __gt__ - - __iadd__ - - __iand__ - - __idiv__ - - __ifloordiv__ - - __ilshift__ - - __imod__ - - __imul__ - - __ior__ - - __irshift__ - - __isub__ - - __ixor__ - - __lshift__ - - __matmul__ - - __mod__ - - __mul__ - - __nonzero__ - - __or__ - - __radd__ - - __rmul__ - - __rshift__ - - __sub__ - - __truediv__ - - __xor__ - - abs - - abs_ - - absolute - - absolute_ - - acos - - acos_ - - acosh - - acosh_ - - add - - add_ - - addbmm - - addbmm_ - - addcdiv - - addcdiv_ - - addcmul - - addcmul_ - - addmm - - addmm_ - - addmv - - addmv_ - - addr - - addr_ - - align_as - - align_to - - all - - allclose - - amax - - amin - - angle - - any - - arccos - - arccos_ - - arccosh - - arccosh_ - - arcsin - - arcsin_ - - arcsinh - - arcsinh_ - - arctan - - arctan_ - - arctanh - - arctanh_ - - argmax - - argmin - - argsort - - asin - - asin_ - - asinh - - asinh_ - - atan - - atan2 - - atan2_ - - atan_ - - atanh - - atanh_ - - baddbmm - - baddbmm_ - - bernoulli - - bernoulli_ - - bincount - - bitwise_and - - bitwise_and_ - - bitwise_not - - bitwise_not_ - - bitwise_or - - bitwise_or_ - - bitwise_xor - - bitwise_xor_ - - bmm - - broadcast_to - - cauchy_ - - ceil - - ceil_ - - cholesky - - chunk - - clamp - - cholesky_solve - - cholesky_inverse - - clamp_ - - clamp_max - - clamp_max_ - - clip 
- - clamp_min - - clamp_min_ - - clip_ - - copysign - - copysign_ - - cos - - cos_ - - cosh - - cosh_ - - count_nonzero - - cummax - - cummin - - cumprod - - cumprod_ - - cumsum - - cumsum_ - - deg2rad - - deg2rad_ - - det - - diag - - diag_embed - - diagflat - - diagonal - - diff - - dist - - digamma - - digamma_ - - div - - div_ - - divide - - divide_ - - dot - - eig - - eq - - eq_ - - erf - - equal - - erf_ - - erfc - - erfc_ - - erfinv - - erfinv_ - - exp - - exp2 - - exp2_ - - expm1 - - exp_ - - expm1_ - - exponential_ - - fill_ - - fix - - fill_diagonal_ - - fix_ - - flip - - fliplr - - flatten - - flipud - - float_power - - float_power_ - - floor - - floor_ - - floor_divide - - floor_divide_ - - fmax - - fmin - - fmod - - fmod_ - - frac - - frac_ - - gather - - gcd - - gcd_ - - ge - - ge_ - - geometric_ - - geqrf - - ger - - greater - - greater_ - - gt - - gt_ - - greater_equal - - greater_equal_ - - hardshrink - - heaviside - - heaviside_ - - histc - - hypot - - hypot_ - - igamma - - igamma_ - - igammac - - igammac_ - - index_add - - index_add_ - - inverse - - index_copy - - index_copy_ - - index_fill - - index_fill_ - - index_put - - index_put_ - - inner - - index_select - - isclose - - isfinite - - isinf - - isnan - - isneginf - - isposinf - - isreal - - kron - - kthvalue - - lcm - - lcm_ - - ldexp - - ldexp_ - - le - - le_ - - lerp - - lerp_ - - where - - less - - less_ - - less_equal - - less_equal_ - - lgamma - - lgamma_ - - log - - log10 - - log10_ - - log1p - - log1p_ - - log2 - - log2_ - - log_ - - log_normal_ - - log_softmax - - logcumsumexp - - logdet - - logaddexp - - logaddexp2 - - logical_and - - logical_and_ - - logical_not - - logit - - logical_not_ - - logical_or - - logical_or_ - - logical_xor - - logical_xor_ - - logit_ - - logsumexp - - lstsq - - lt - - lt_ - - lu_solve - - map2_ - - map_ - - masked_fill - - matmul - - masked_fill_ - - masked_scatter - - masked_scatter_ - - masked_select - - matrix_exp - - max - - maximum - - mean - - 
matrix_power - - median - - min - - minimum - - mm - - mode - - msort - - mul - - mul_ - - multinomial - - multiply - - multiply_ - - mv - - mvlgamma - - mvlgamma_ - - nansum - - narrow - - narrow_copy - - ne - - ne_ - - neg - - neg_ - - negative - - negative_ - - nonzero - - normal_ - - not_equal - - not_equal_ - - permute - - pinverse - - polygamma - - pow - - pow_ - - polygamma_ - - prelu - - prod - - put_ - - rad2deg - - rad2deg_ - - ravel - - real - - reciprocal - - reciprocal_ - - relu - - relu_ - - remainder - - repeat_interleave - - reshape - - remainder_ - - renorm - - renorm_ - - repeat - - reshape_as - - resize_ - - resize_as_ - - roll - - rot90 - - round - - round_ - - rsqrt - - rsqrt_ - - scatter - - scatter_ - - scatter_add - - scatter_add_ - - select - - sgn - - sgn_ - - sigmoid - - sigmoid_ - - sign - - sign_ - - signbit - - sin - - sin_ - - sinc - - sinc_ - - sinh - - sinh_ - - slogdet - - smm - - softmax - - solve - - sort - - split_with_sizes - - sqrt - - sqrt_ - - square - - square_ - - squeeze - - squeeze_ - - sspaddmm - - std - - sub - - sub_ - - sum - - sum_to_size - - svd - - symeig - - t - - t_ - - take - - tan - - tan_ - - tanh - - tanh_ - - tensor_split - - tile - - topk - - transpose - - transpose_ - - triangular_solve - - tril - - tril_ - - triu - - true_divide - - triu_ - - true_divide_ - - trunc - - trunc_ - - type_as - - unbind - - unflatten - - unfold - - unsafe_chunk - - unsqueeze - - unsafe_split - - unsafe_split_with_sizes - - var - - vdot - - unsqueeze_ - - view_as - - xlogy - - xlogy_ - -torch: - - _adaptive_avg_pool2d - - _add_relu - - _add_relu_ - - _aminmax - - _batch_norm_impl_index - - _convolution - - abs - - abs_ - - absolute - - acos - - acos_ - - acosh - - acosh_ - - adaptive_avg_pool1d - - adaptive_max_pool1d - - add - - addbmm - - addcdiv - - addcmul - - addmm - - addmv - - addmv_ - - addr - - amax - - affine_grid_generator - - align_tensors - - all - - alpha_dropout - - amin - - alpha_dropout_ - - angle - - any - - 
arange - - arccos - - arccos_ - - arccosh - - arccosh_ - - arcsin - - arcsin_ - - arcsinh - - arcsinh_ - - arctan - - arctan_ - - arctanh - - arctanh_ - - argmax - - argmin - - argsort - - asin - - asin_ - - asinh - - asinh_ - - atan - - atan2 - - atan_ - - atanh - - atanh_ - - atleast_1d - - atleast_2d - - atleast_3d - - avg_pool1d - - baddbmm - - bartlett_window - - batch_norm_backward_elemt - - batch_norm_backward_reduce - - batch_norm_elemt - - batch_norm_gather_stats - - batch_norm_gather_stats_with_counts - - bernoulli - - batch_norm_stats - - batch_norm_update_stats - - bilinear - - bincount - - binomial - - binary_cross_entropy_with_logits - - bitwise_and - - bitwise_not - - bitwise_or - - bitwise_xor - - blackman_window - - block_diag - - bmm - - broadcast_tensors - - broadcast_to - - cartesian_prod - - cat - - cdist - - ceil - - ceil_ - - celu - - celu_ - - chain_matmul - - channel_shuffle - - cholesky - - cholesky_inverse - - cholesky_solve - - choose_qparams_optimized - - chunk - - clamp - - clamp_ - - clamp_max - - clamp_max_ - - clamp_min - - clamp_min_ - - clip - - clip_ - - clone - - column_stack - - combinations - - constant_pad_nd - - conv1d - - conv2d - - conv3d - - conv_tbc - - conv_transpose1d - - conv_transpose2d - - conv_transpose3d - - cos - - convolution - - copysign - - cos_ - - cosh - - cosh_ - - cosine_embedding_loss - - cosine_similarity - - count_nonzero - - cross - - ctc_loss - - cummax - - cummin - - cumprod - - cumsum - - deg2rad - - deg2rad_ - - det - - diag - - diag_embed - - diff - - diagflat - - diagonal - - digamma - - dist - - div - - divide - - dot - - dropout - - dropout_ - - dsmm - - dstack - - eig - - einsum - - embedding - - embedding_bag - - embedding_renorm_ - - eq - - equal - - erf - - erf_ - - erfc - - erfc_ - - erfinv - - exp - - exp2 - - exp2_ - - exp_ - - expm1 - - expm1_ - - eye - - feature_dropout - - feature_alpha_dropout - - feature_alpha_dropout_ - - feature_dropout_ - - fix - - fill_ - - fix_ - - flatten - - 
flip - - fliplr - - flipud - - float_power - - floor - - floor_ - - floor_divide - - fmax - - fmin - - fmod - - frac - - frac_ - - full - - frobenius_norm - - full_like - - gather - - gcd - - gcd_ - - ge - - geqrf - - ger - - greater - - greater_equal - - grid_sampler - - grid_sampler_2d - - group_norm - - grid_sampler_3d - - gru - - gru_cell - - gt - - hamming_window - - hann_window - - hardshrink - - heaviside - - hinge_embedding_loss - - histc - - hsmm - - hspmm - - hstack - - hypot - - igamma - - igammac - - index_add - - index_copy - - inner - - index_fill - - index_put - - index_put_ - - index_select - - instance_norm - - isclose - - isfinite - - isinf - - isnan - - isneginf - - isposinf - - istft - - kaiser_window - - kl_div - - kron - - kthvalue - - layer_norm - - lcm - - lcm_ - - ldexp - - ldexp_ - - le - - lerp - - less - - less_equal - - lgamma - - linspace - - log - - log10 - - log10_ - - log1p - - log1p_ - - log2 - - log2_ - - log_softmax - - log_ - - logaddexp - - logaddexp2 - - logcumsumexp - - logdet - - logical_and - - logical_not - - logical_or - - logical_xor - - logit - - logit_ - - logspace - - logsumexp - - lstm - - lstm_cell - - lstsq - - lt - - lu_solve - - masked_fill - - margin_ranking_loss - - masked_scatter - - masked_select - - matrix_exp - - matmul - - matrix_power - - matrix_rank - - max - - max_pool1d - - max_pool2d - - max_pool1d_with_indices - - max_pool3d - - maximum - - mean - - median - - min - - minimum - - mm - - mode - - moveaxis - - movedim - - msort - - mul - - multinomial - - multiply - - mv - - mvlgamma - - nan_to_num - - nan_to_num_ - - nanmedian - - nansum - - narrow - - native_batch_norm - - native_group_norm - - narrow_copy - - native_layer_norm - - native_norm - - ne - - neg - - negative - - neg_ - - negative_ - - nextafter - - nonzero - - norm_except_dim - - normal - - not_equal - - nuclear_norm - - pairwise_distance - - pdist - - pinverse - - pixel_shuffle - - pixel_unshuffle - - poisson - - poisson_nll_loss - - 
polar - - polygamma - - pow - - prelu - - prod - - rad2deg - - promote_types - - rad2deg_ - - range - - ravel - - real - - reciprocal - - relu - - reciprocal_ - - relu_ - - remainder - - renorm - - repeat_interleave - - reshape - - resize_as_ - - roll - - rot90 - - round - - round_ - - rrelu - - rrelu_ - - rsqrt - - row_stack - - rsqrt_ - - rsub - - saddmm - - scalar_tensor - - scatter - - select - - scatter_add - - searchsorted - - selu - - selu_ - - sgn - - sigmoid - - sigmoid_ - - sign - - signbit - - sin - - sin_ - - sinc - - sinc_ - - sinh - - sinh_ - - slogdet - - smm - - softmax - - solve - - sort - - sparse_coo_tensor - - square - - split_with_sizes - - spmm - - sqrt - - sqrt_ - - square_ - - squeeze - - sspaddmm - - stack - - std - - std_mean - - sub - - subtract - - sum - - svd - - swapaxes - - swapdims - - symeig - - t - - take - - tan - - tan_ - - tanh - - tanh_ - - tensordot - - tensor_split - - threshold - - threshold_ - - tile - - topk - - transpose - - trapz - - triangular_solve - - tril - - tril_indices - - triplet_margin_loss - - triu - - triu_indices - - true_divide - - trunc - - trunc_ - - unique_consecutive - - xlogy - - unbind - - unique_dim - - unsafe_chunk - - unsafe_split - - vander - - var - - vdot - - unsafe_split_with_sizes - - unsqueeze - - var_mean - - vstack - - where - - xlogy_ +# Copyright (c) 2023 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# List of ops that register hooks + +functional: + - conv1d + - conv2d + - conv3d + - conv_transpose1d + - conv_transpose2d + - conv_transpose3d + - conv_tbc + - avg_pool1d + - avg_pool2d + - avg_pool3d + - fractional_max_pool2d_with_indices + - fractional_max_pool2d + - fractional_max_pool3d_with_indices + - fractional_max_pool3d + - max_pool1d_with_indices + - max_pool1d + - max_pool2d_with_indices + - max_pool2d + - max_pool3d_with_indices + - max_pool3d + - max_unpool1d + - max_unpool2d + - max_unpool3d + - lp_pool2d + - lp_pool1d + - adaptive_max_pool1d_with_indices + - adaptive_max_pool1d + - adaptive_max_pool2d_with_indices + - adaptive_max_pool2d + - adaptive_max_pool3d_with_indices + - adaptive_max_pool3d + - adaptive_avg_pool1d + - adaptive_avg_pool2d + - adaptive_avg_pool3d + - dropout + - alpha_dropout + - dropout2d + - dropout3d + - feature_alpha_dropout + - threshold + - threshold_ + - relu + - relu_ + - glu + - hardtanh + - hardtanh_ + - relu6 + - elu + - elu_ + - selu + - selu_ + - celu + - celu_ + - leaky_relu + - leaky_relu_ + - prelu + - rrelu + - rrelu_ + - logsigmoid + - gelu + - hardshrink + - tanhshrink + - softsign + - softplus + - softmin + - softmax + - gumbel_softmax + - log_softmax + - softshrink + - tanh + - sigmoid + - hardsigmoid + - linear + - bilinear + - silu + - hardswish + - embedding + - embedding_bag + - batch_norm + - instance_norm + - layer_norm + - group_norm + - local_response_norm + - ctc_loss + - nll_loss + - poisson_nll_loss + - gaussian_nll_loss + - kl_div + - cross_entropy + - binary_cross_entropy + - binary_cross_entropy_with_logits + - smooth_l1_loss + - l1_loss + - mse_loss + - margin_ranking_loss + - hinge_embedding_loss + - multilabel_margin_loss + - soft_margin_loss + - multilabel_soft_margin_loss + - cosine_embedding_loss + - multi_margin_loss + - pixel_shuffle + - pixel_unshuffle + - channel_shuffle + - upsample + - interpolate + - upsample_nearest + - upsample_bilinear + - grid_sample + - affine_grid + - 
pad + - pairwise_distance + - pdist + - cosine_similarity + - one_hot + - triplet_margin_loss + - triplet_margin_with_distance_loss + - normalize + - unfold + - fold + - multi_head_attention_forward + +tensor: + - __add__ + - __and__ + - __bool__ + - __div__ + - __eq__ + - __ge__ + - __gt__ + - __iadd__ + - __iand__ + - __idiv__ + - __ifloordiv__ + - __ilshift__ + - __imod__ + - __imul__ + - __ior__ + - __irshift__ + - __isub__ + - __ixor__ + - __lshift__ + - __matmul__ + - __mod__ + - __mul__ + - __nonzero__ + - __or__ + - __radd__ + - __rmul__ + - __rshift__ + - __sub__ + - __truediv__ + - __xor__ + - abs + - abs_ + - absolute + - absolute_ + - acos + - acos_ + - acosh + - acosh_ + - add + - add_ + - addbmm + - addbmm_ + - addcdiv + - addcdiv_ + - addcmul + - addcmul_ + - addmm + - addmm_ + - addmv + - addmv_ + - addr + - addr_ + - align_as + - align_to + - all + - allclose + - amax + - amin + - angle + - any + - arccos + - arccos_ + - arccosh + - arccosh_ + - arcsin + - arcsin_ + - arcsinh + - arcsinh_ + - arctan + - arctan_ + - arctanh + - arctanh_ + - argmax + - argmin + - argsort + - asin + - asin_ + - asinh + - asinh_ + - atan + - atan2 + - atan2_ + - atan_ + - atanh + - atanh_ + - baddbmm + - baddbmm_ + - bernoulli + - bernoulli_ + - bincount + - bitwise_and + - bitwise_and_ + - bitwise_not + - bitwise_not_ + - bitwise_or + - bitwise_or_ + - bitwise_xor + - bitwise_xor_ + - bmm + - broadcast_to + - cauchy_ + - ceil + - ceil_ + - cholesky + - chunk + - clamp + - cholesky_solve + - cholesky_inverse + - clamp_ + - clamp_max + - clamp_max_ + - clip + - clamp_min + - clamp_min_ + - clip_ + - copysign + - copysign_ + - cos + - cos_ + - cosh + - cosh_ + - count_nonzero + - cummax + - cummin + - cumprod + - cumprod_ + - cumsum + - cumsum_ + - deg2rad + - deg2rad_ + - det + - diag + - diag_embed + - diagflat + - diagonal + - diff + - dist + - digamma + - digamma_ + - div + - div_ + - divide + - divide_ + - dot + - eig + - eq + - eq_ + - erf + - equal + - erf_ + - 
erfc + - erfc_ + - erfinv + - erfinv_ + - exp + - exp2 + - exp2_ + - expm1 + - exp_ + - expm1_ + - exponential_ + - fill_ + - fix + - fill_diagonal_ + - fix_ + - flip + - fliplr + - flatten + - flipud + - float_power + - float_power_ + - floor + - floor_ + - floor_divide + - floor_divide_ + - fmax + - fmin + - fmod + - fmod_ + - frac + - frac_ + - gather + - gcd + - gcd_ + - ge + - ge_ + - geometric_ + - geqrf + - ger + - greater + - greater_ + - gt + - gt_ + - greater_equal + - greater_equal_ + - hardshrink + - heaviside + - heaviside_ + - histc + - hypot + - hypot_ + - igamma + - igamma_ + - igammac + - igammac_ + - index_add + - index_add_ + - inverse + - index_copy + - index_copy_ + - index_fill + - index_fill_ + - index_put + - index_put_ + - inner + - index_select + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - isreal + - kron + - kthvalue + - lcm + - lcm_ + - ldexp + - ldexp_ + - le + - le_ + - lerp + - lerp_ + - where + - less + - less_ + - less_equal + - less_equal_ + - lgamma + - lgamma_ + - log + - log10 + - log10_ + - log1p + - log1p_ + - log2 + - log2_ + - log_ + - log_normal_ + - log_softmax + - logcumsumexp + - logdet + - logaddexp + - logaddexp2 + - logical_and + - logical_and_ + - logical_not + - logit + - logical_not_ + - logical_or + - logical_or_ + - logical_xor + - logical_xor_ + - logit_ + - logsumexp + - lstsq + - lt + - lt_ + - lu_solve + - map2_ + - map_ + - masked_fill + - matmul + - masked_fill_ + - masked_scatter + - masked_scatter_ + - masked_select + - matrix_exp + - max + - maximum + - mean + - matrix_power + - median + - min + - minimum + - mm + - mode + - msort + - mul + - mul_ + - multinomial + - multiply + - multiply_ + - mv + - mvlgamma + - mvlgamma_ + - nansum + - narrow + - narrow_copy + - ne + - ne_ + - neg + - neg_ + - negative + - negative_ + - nonzero + - normal_ + - not_equal + - not_equal_ + - permute + - pinverse + - polygamma + - pow + - pow_ + - polygamma_ + - prelu + - prod + - put_ + - 
rad2deg + - rad2deg_ + - ravel + - real + - reciprocal + - reciprocal_ + - relu + - relu_ + - remainder + - repeat_interleave + - reshape + - remainder_ + - renorm + - renorm_ + - repeat + - reshape_as + - resize_ + - resize_as_ + - roll + - rot90 + - round + - round_ + - rsqrt + - rsqrt_ + - scatter + - scatter_ + - scatter_add + - scatter_add_ + - select + - sgn + - sgn_ + - sigmoid + - sigmoid_ + - sign + - sign_ + - signbit + - sin + - sin_ + - sinc + - sinc_ + - sinh + - sinh_ + - slogdet + - smm + - softmax + - solve + - sort + - split_with_sizes + - sqrt + - sqrt_ + - square + - square_ + - squeeze + - squeeze_ + - sspaddmm + - std + - sub + - sub_ + - sum + - sum_to_size + - svd + - symeig + - t + - t_ + - take + - tan + - tan_ + - tanh + - tanh_ + - tensor_split + - tile + - topk + - transpose + - transpose_ + - triangular_solve + - tril + - tril_ + - triu + - true_divide + - triu_ + - true_divide_ + - trunc + - trunc_ + - type_as + - unbind + - unflatten + - unfold + - unsafe_chunk + - unsqueeze + - unsafe_split + - unsafe_split_with_sizes + - var + - vdot + - unsqueeze_ + - view_as + - xlogy + - xlogy_ + +torch: + - _adaptive_avg_pool2d + - _add_relu + - _add_relu_ + - _aminmax + - _batch_norm_impl_index + - _convolution + - abs + - abs_ + - absolute + - acos + - acos_ + - acosh + - acosh_ + - adaptive_avg_pool1d + - adaptive_max_pool1d + - add + - addbmm + - addcdiv + - addcmul + - addmm + - addmv + - addmv_ + - addr + - amax + - affine_grid_generator + - align_tensors + - all + - alpha_dropout + - amin + - alpha_dropout_ + - angle + - any + - arange + - arccos + - arccos_ + - arccosh + - arccosh_ + - arcsin + - arcsin_ + - arcsinh + - arcsinh_ + - arctan + - arctan_ + - arctanh + - arctanh_ + - argmax + - argmin + - argsort + - asin + - asin_ + - asinh + - asinh_ + - atan + - atan2 + - atan_ + - atanh + - atanh_ + - atleast_1d + - atleast_2d + - atleast_3d + - avg_pool1d + - baddbmm + - bartlett_window + - batch_norm_backward_elemt + - 
batch_norm_backward_reduce + - batch_norm_elemt + - batch_norm_gather_stats + - batch_norm_gather_stats_with_counts + - bernoulli + - batch_norm_stats + - batch_norm_update_stats + - bilinear + - bincount + - binomial + - binary_cross_entropy_with_logits + - bitwise_and + - bitwise_not + - bitwise_or + - bitwise_xor + - blackman_window + - block_diag + - bmm + - broadcast_tensors + - broadcast_to + - cartesian_prod + - cat + - cdist + - ceil + - ceil_ + - celu + - celu_ + - chain_matmul + - channel_shuffle + - cholesky + - cholesky_inverse + - cholesky_solve + - choose_qparams_optimized + - chunk + - clamp + - clamp_ + - clamp_max + - clamp_max_ + - clamp_min + - clamp_min_ + - clip + - clip_ + - clone + - column_stack + - combinations + - constant_pad_nd + - conv1d + - conv2d + - conv3d + - conv_tbc + - conv_transpose1d + - conv_transpose2d + - conv_transpose3d + - cos + - convolution + - copysign + - cos_ + - cosh + - cosh_ + - cosine_embedding_loss + - cosine_similarity + - count_nonzero + - cross + - ctc_loss + - cummax + - cummin + - cumprod + - cumsum + - deg2rad + - deg2rad_ + - det + - diag + - diag_embed + - diff + - diagflat + - diagonal + - digamma + - dist + - div + - divide + - dot + - dropout + - dropout_ + - dsmm + - dstack + - eig + - einsum + - embedding + - embedding_bag + - embedding_renorm_ + - eq + - equal + - erf + - erf_ + - erfc + - erfc_ + - erfinv + - exp + - exp2 + - exp2_ + - exp_ + - expm1 + - expm1_ + - eye + - feature_dropout + - feature_alpha_dropout + - feature_alpha_dropout_ + - feature_dropout_ + - fix + - fill_ + - fix_ + - flatten + - flip + - fliplr + - flipud + - float_power + - floor + - floor_ + - floor_divide + - fmax + - fmin + - fmod + - frac + - frac_ + - full + - frobenius_norm + - full_like + - gather + - gcd + - gcd_ + - ge + - geqrf + - ger + - greater + - greater_equal + - grid_sampler + - grid_sampler_2d + - group_norm + - grid_sampler_3d + - gru + - gru_cell + - gt + - hamming_window + - hann_window + - hardshrink 
+ - heaviside + - hinge_embedding_loss + - histc + - hsmm + - hspmm + - hstack + - hypot + - igamma + - igammac + - index_add + - index_copy + - inner + - index_fill + - index_put + - index_put_ + - index_select + - instance_norm + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - istft + - kaiser_window + - kl_div + - kron + - kthvalue + - layer_norm + - lcm + - lcm_ + - ldexp + - ldexp_ + - le + - lerp + - less + - less_equal + - lgamma + - linspace + - log + - log10 + - log10_ + - log1p + - log1p_ + - log2 + - log2_ + - log_softmax + - log_ + - logaddexp + - logaddexp2 + - logcumsumexp + - logdet + - logical_and + - logical_not + - logical_or + - logical_xor + - logit + - logit_ + - logspace + - logsumexp + - lstm + - lstm_cell + - lstsq + - lt + - lu_solve + - masked_fill + - margin_ranking_loss + - masked_scatter + - masked_select + - matrix_exp + - matmul + - matrix_power + - matrix_rank + - max + - max_pool1d + - max_pool2d + - max_pool1d_with_indices + - max_pool3d + - maximum + - mean + - median + - min + - minimum + - mm + - mode + - moveaxis + - movedim + - msort + - mul + - multinomial + - multiply + - mv + - mvlgamma + - nan_to_num + - nan_to_num_ + - nanmedian + - nansum + - narrow + - native_batch_norm + - native_group_norm + - narrow_copy + - native_layer_norm + - native_norm + - ne + - neg + - negative + - neg_ + - negative_ + - nextafter + - nonzero + - norm_except_dim + - normal + - not_equal + - nuclear_norm + - pairwise_distance + - pdist + - pinverse + - pixel_shuffle + - pixel_unshuffle + - poisson + - poisson_nll_loss + - polar + - polygamma + - pow + - prelu + - prod + - rad2deg + - promote_types + - rad2deg_ + - range + - ravel + - real + - reciprocal + - relu + - reciprocal_ + - relu_ + - remainder + - renorm + - repeat_interleave + - reshape + - resize_as_ + - roll + - rot90 + - round + - round_ + - rrelu + - rrelu_ + - rsqrt + - row_stack + - rsqrt_ + - rsub + - saddmm + - scalar_tensor + - scatter + - select + - 
scatter_add + - searchsorted + - selu + - selu_ + - sgn + - sigmoid + - sigmoid_ + - sign + - signbit + - sin + - sin_ + - sinc + - sinc_ + - sinh + - sinh_ + - slogdet + - smm + - softmax + - solve + - sort + - sparse_coo_tensor + - square + - split_with_sizes + - spmm + - sqrt + - sqrt_ + - square_ + - squeeze + - sspaddmm + - stack + - std + - std_mean + - sub + - subtract + - sum + - svd + - swapaxes + - swapdims + - symeig + - t + - take + - tan + - tan_ + - tanh + - tanh_ + - tensordot + - tensor_split + - threshold + - threshold_ + - tile + - topk + - transpose + - trapz + - triangular_solve + - tril + - tril_indices + - triplet_margin_loss + - triu + - triu_indices + - true_divide + - trunc + - trunc_ + - unique_consecutive + - xlogy + - unbind + - unique_dim + - unsafe_chunk + - unsafe_split + - vander + - var + - vdot + - unsafe_split_with_sizes + - unsqueeze + - var_mean + - vstack + - where + - xlogy_ diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/utils.py similarity index 90% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/utils.py index 7d16ac993ed45faa0f9b48bb64050592e15ef4d2..0a03bdd5a9098343e2b343415c2ea9ce90af768d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/utils.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/utils.py @@ -1,29 +1,29 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import os -import yaml - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with FileOpen(yaml_path, 'r') as f: - Ops = yaml.safe_load(f) - WrapFunctionalOps = Ops.get('functional') - WrapTensorOps = Ops.get('tensor') +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os +import yaml + +from calibrator.common.file_check import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + Ops = yaml.safe_load(f) + WrapFunctionalOps = Ops.get('functional') + WrapTensorOps = Ops.get('tensor') WrapTorchOps = Ops.get('torch') \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_functional.py similarity index 81% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_functional.py index 967e9efc84123533422240c1f1fe530e73800bf0..583ef3404d03faf741ebd65dcc82a46de9b6eadd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_functional.py @@ -1,68 +1,63 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" - -import os - -import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.hook_module.utils import WrapFunctionalOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen - -for f in dir(torch.nn.functional): - locals().update({f: getattr(torch.nn.functional, f)}) - - -def get_functional_ops(): - global WrapFunctionalOps - _all_functional_ops = dir(torch.nn.functional) - if msCheckerConfig.white_list: - return set(WrapFunctionalOps) & set(_all_functional_ops) & set(msCheckerConfig.white_list) - else: - return set(WrapFunctionalOps) & set(_all_functional_ops) - - -class HOOKFunctionalOP(object): - pass - - -class FunctionalOPTemplate(HOOKModule): - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Functional*" + str(op_name) + "*" - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return eval(self.op_name_)(*args, **kwargs) - - -def wrap_functional_op(op_name, hook): - def functional_op_template(*args, **kwargs): - return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) - - return functional_op_template - - -def wrap_functional_ops_and_bind(hook): - _functional_ops = get_functional_ops() - for op_name in _functional_ops: - setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import torch + +from calibrator.pytorch.api_accuracy_checker import HOOKModule +from calibrator.pytorch.api_accuracy_checker.common import torch_device_guard +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig + +for f in dir(torch.nn.functional): + locals().update({f: getattr(torch.nn.functional, f)}) + + +def get_functional_ops(): + global WrapFunctionalOps + _all_functional_ops = dir(torch.nn.functional) + if msCheckerConfig.white_list: + return set(WrapFunctionalOps) & set(_all_functional_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapFunctionalOps) & set(_all_functional_ops) + + +class HOOKFunctionalOP(object): + pass + + +class FunctionalOPTemplate(HOOKModule): + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "Functional*" + str(op_name) + "*" + if need_hook: + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return eval(self.op_name_)(*args, **kwargs) + + +def wrap_functional_op(op_name, hook): + def functional_op_template(*args, **kwargs): + return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) + + return functional_op_template + + +def wrap_functional_ops_and_bind(hook): + _functional_ops = get_functional_ops() + for op_name in _functional_ops: + setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py similarity 
index 77% rename from debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py index d60cac74baf15872854d71089df7cdd81925746e..22411b0535121bd14586b1cc785b67f01dc72a76 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py @@ -1,69 +1,64 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" - -import os - -import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.hook_module.utils import WrapTensorOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import parameter_adapter - - -def get_tensor_ops(): - global WrapTensorOps - _tensor_ops = dir(torch._C._TensorBase) - if msCheckerConfig.white_list: - return set(WrapTensorOps) & set(_tensor_ops) & set(msCheckerConfig.white_list) - else: - return set(WrapTensorOps) & set(_tensor_ops) - - -class HOOKTensor(object): - pass - - -class TensorOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Tensor*" + str(op_name) + "*" - if need_hook: - super().__init__(hook) - - @torch_device_guard - @parameter_adapter - def forward(self, *args, **kwargs): - return getattr(torch._C._TensorBase, str(self.op_name_))(*args, **kwargs) - - -def wrap_tensor_op(op_name, hook): - - def tensor_op_template(*args, **kwargs): - return TensorOPTemplate(op_name, hook)(*args, **kwargs) - - return tensor_op_template - - -def wrap_tensor_ops_and_bind(hook): - _tensor_ops = get_tensor_ops() - for op_name in _tensor_ops: - setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import torch + +from calibrator.pytorch.api_accuracy_checker import HOOKModule +from calibrator.pytorch.api_accuracy_checker.common import torch_device_guard +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig +from calibrator.common.utils import parameter_adapter + + +def get_tensor_ops(): + global WrapTensorOps + _tensor_ops = dir(torch._C._TensorBase) + if msCheckerConfig.white_list: + return set(WrapTensorOps) & set(_tensor_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapTensorOps) & set(_tensor_ops) + + +class HOOKTensor(object): + pass + + +class TensorOPTemplate(HOOKModule): + + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "Tensor*" + str(op_name) + "*" + if need_hook: + super().__init__(hook) + + @torch_device_guard + @parameter_adapter + def forward(self, *args, **kwargs): + return getattr(torch._C._TensorBase, str(self.op_name_))(*args, **kwargs) + + +def wrap_tensor_op(op_name, hook): + + def tensor_op_template(*args, **kwargs): + return TensorOPTemplate(op_name, hook)(*args, **kwargs) + + return tensor_op_template + + +def wrap_tensor_ops_and_bind(hook): + _tensor_ops = get_tensor_ops() + for op_name in _tensor_ops: + setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_torch.py similarity index 88% rename from 
debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_torch.py index 9fbe343d374364a5904f36d0d82189ff856e14b3..efee3b090ef9bc4e0ffbe2731af1f5756e86870c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/hook_module/wrap_torch.py @@ -1,110 +1,105 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" - -import os - -import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.hook_module.utils import WrapTorchOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen - - -def get_torch_ops(): - global WrapTorchOps - _torch_ops = dir(torch._C._VariableFunctionsClass) - if msCheckerConfig.white_list: - return set(WrapTorchOps) & set(_torch_ops) & set(msCheckerConfig.white_list) - else: - return set(WrapTorchOps) & set(_torch_ops) - - -class HOOKTorchOP(object): - pass - - -class TorchOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Torch*" + str(op_name) + "*" - if need_hook: - super().__init__(hook) - - def input_param_need_adapt(self): - special_op_list = ["broadcast_tensors", "block_diag"] - for item in special_op_list: - if item in self.op_name_: - return True - return False - - def einsum_adapt(self, *args): - if len(args) < 2: - raise ValueError('einsum(): must specify the equation string and at least one operand, ' - 'or at least one operand and its subscripts list') - equation = None - operands = None - if isinstance(args[0], torch.Tensor): - def parse_subscript(n: int) -> str: - if n == Ellipsis: - return '...' 
- if n >= 0 and n < 26: - return chr(ord('A') + n) - if n >= 26 and n < 52: - return chr(ord('a') + n - 26) - raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52]') - equation = ','.join(''.join(parse_subscript(script) for script in arg) for arg in args[1::2]) - - if len(args) % 2 == 1: - equation += '->' + ''.join(parse_subscript(script) for script in args[-1]) - operands = args[:-1:2] - else: - operands = args[::2] - else: - equation = args[0] - operands = args[1:] - - if len(operands) == 1 and isinstance(operands[0], (list, tuple)): - _operands = operands[0] - return self.einsum_adapt(equation, *_operands) - return equation, operands - - @torch_device_guard - def forward(self, *args, **kwargs): - if self.input_param_need_adapt(): - return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(args, **kwargs) - else: - if self.op_name_ == 'einsum': - args = self.einsum_adapt(*args) - return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_torch_op(op_name, hook): - - def torch_op_template(*args, **kwargs): - return TorchOPTemplate(op_name, hook)(*args, **kwargs) - - return torch_op_template - - -def wrap_torch_ops_and_bind(hook): - _torch_ops = get_torch_ops() - for op_name in _torch_ops: - setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import torch + +from calibrator.pytorch.api_accuracy_checker import HOOKModule +from calibrator.pytorch.api_accuracy_checker.common import torch_device_guard +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig + + +def get_torch_ops(): + global WrapTorchOps + _torch_ops = dir(torch._C._VariableFunctionsClass) + if msCheckerConfig.white_list: + return set(WrapTorchOps) & set(_torch_ops) & set(msCheckerConfig.white_list) + else: + return set(WrapTorchOps) & set(_torch_ops) + + +class HOOKTorchOP(object): + pass + + +class TorchOPTemplate(HOOKModule): + + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "Torch*" + str(op_name) + "*" + if need_hook: + super().__init__(hook) + + def input_param_need_adapt(self): + special_op_list = ["broadcast_tensors", "block_diag"] + for item in special_op_list: + if item in self.op_name_: + return True + return False + + def einsum_adapt(self, *args): + if len(args) < 2: + raise ValueError('einsum(): must specify the equation string and at least one operand, ' + 'or at least one operand and its subscripts list') + equation = None + operands = None + if isinstance(args[0], torch.Tensor): + def parse_subscript(n: int) -> str: + if n == Ellipsis: + return '...' 
+ if n >= 0 and n < 26: + return chr(ord('A') + n) + if n >= 26 and n < 52: + return chr(ord('a') + n - 26) + raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52]') + equation = ','.join(''.join(parse_subscript(script) for script in arg) for arg in args[1::2]) + + if len(args) % 2 == 1: + equation += '->' + ''.join(parse_subscript(script) for script in args[-1]) + operands = args[:-1:2] + else: + operands = args[::2] + else: + equation = args[0] + operands = args[1:] + + if len(operands) == 1 and isinstance(operands[0], (list, tuple)): + _operands = operands[0] + return self.einsum_adapt(equation, *_operands) + return equation, operands + + @torch_device_guard + def forward(self, *args, **kwargs): + if self.input_param_need_adapt(): + return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(args, **kwargs) + else: + if self.op_name_ == 'einsum': + args = self.einsum_adapt(*args) + return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) + + +def wrap_torch_op(op_name, hook): + + def torch_op_template(*args, **kwargs): + return TorchOPTemplate(op_name, hook)(*args, **kwargs) + + return torch_op_template + + +def wrap_torch_ops_and_bind(hook): + _torch_ops = get_torch_ops() + for op_name in _torch_ops: + setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/img/accuracy_checking_details.png b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/accuracy_checking_details.png similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/img/accuracy_checking_details.png rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/accuracy_checking_details.png diff --git a/debug/accuracy_tools/api_accuracy_checker/img/accuracy_checking_result.png b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/accuracy_checking_result.png similarity index 100% rename from 
debug/accuracy_tools/api_accuracy_checker/img/accuracy_checking_result.png rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/accuracy_checking_result.png diff --git a/debug/accuracy_tools/api_accuracy_checker/img/api_precision_compare_details.png b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/api_precision_compare_details.png similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/img/api_precision_compare_details.png rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/api_precision_compare_details.png diff --git a/debug/accuracy_tools/api_accuracy_checker/img/api_precision_compare_result.png b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/api_precision_compare_result.png similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/img/api_precision_compare_result.png rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/img/api_precision_compare_result.png diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/.keep b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/data_generate.py similarity index 96% rename from debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/data_generate.py index 51fcfedefbb8a6ef079a2d26cdc7d9ca841092bf..4bc82de7173329bdee851d8055cf0ae5be313a4b 100644 --- 
a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,7 +20,7 @@ import math import torch import numpy -from api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, \ +from calibrator.pytorch.api_accuracy_checker.common import Const, check_file_or_directory_path, check_object_type, print_warn_log, \ print_error_log, get_full_data_path, CompareException TORCH_TYPE = ["torch.device", "torch.dtype"] @@ -226,6 +226,8 @@ def gen_args(args_info, need_grad=True, convert_type=None, real_data_path=None): data = gen_args(arg, need_grad, convert_type, real_data_path) elif isinstance(arg, dict): data = gen_data(arg, need_grad, convert_type, real_data_path) + elif arg is None: + data = None else: print_warn_log(f'Warning: {arg} is not supported') raise NotImplementedError() @@ -243,10 +245,12 @@ def gen_kwargs(api_info, convert_type=None, real_data_path=None): real_data_path: the root directory for storing real data. """ check_object_type(api_info, dict) - kwargs_params = api_info.get("kwargs") + kwargs_params = api_info.get("input_kwargs") for key, value in kwargs_params.items(): if isinstance(value, (list, tuple)): kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path) + elif value is None: + kwargs_params[key] = None elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"): kwargs_params[key] = gen_data(value, True, convert_type, real_data_path) elif value.get('type') in TORCH_TYPE: @@ -293,8 +297,8 @@ def gen_api_params(api_info, need_grad=True, convert_type=None, real_data_path=N error_info = f"convert_type params not support {convert_type}." 
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) kwargs_params = gen_kwargs(api_info, convert_type, real_data_path) - if api_info.get("args"): - args_params = gen_args(api_info.get("args"), need_grad, convert_type, real_data_path) + if api_info.get("input_args"): + args_params = gen_args(api_info.get("input_args"), need_grad, convert_type, real_data_path) else: print_warn_log(f'Warning: No args in {api_info} ') args_params = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py similarity index 91% rename from debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index 760e088eb38a26ba01fd25ac579130240a76a9e8..4f747a9a91330d34fc9046dfb4cc61377817dc01 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -1,182 +1,182 @@ -import subprocess -import json -import os -import sys -import argparse -import time -import signal -import threading -from collections import namedtuple -from itertools import cycle -from tqdm import tqdm -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, \ - check_file_suffix, check_link, FileOpen -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path, preprocess_forward_content -from api_accuracy_checker.common.utils import print_error_log, print_warn_log, print_info_log, create_directory -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create - - -def split_json_file(input_file, num_splits, filter_api): - with FileOpen(input_file, 'r') as file: - data = json.load(file) - if filter_api: - data = 
preprocess_forward_content(data) - - items = list(data.items()) - total_items = len(items) - chunk_size = total_items // num_splits - split_files = [] - - for i in range(num_splits): - start = i * chunk_size - end = (i + 1) * chunk_size if i < num_splits - 1 else total_items - split_filename = f"temp_part{i}.json" - with FileOpen(split_filename, 'w') as split_file: - json.dump(dict(items[start:end]), split_file) - split_files.append(split_filename) - - return split_files, total_items - - -def signal_handler(signum, frame): - print_warn_log(f'Signal handler called with signal {signum}') - raise KeyboardInterrupt() - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - - -ParallelUTConfig = namedtuple('ParallelUTConfig', ['forward_files', 'backward_files', 'out_path', 'num_splits', 'save_error_data_flag', 'jit_compile_flag', 'device_id', 'result_csv_path', 'total_items', 'real_data_path']) - - -def run_parallel_ut(config): - processes = [] - device_id_cycle = cycle(config.device_id) - if config.save_error_data_flag: - print_info_log("UT task error datas will be saved") - print_info_log(f"Starting parallel UT with {config.num_splits} processes") - progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items") - - def create_cmd(fwd, bwd, dev_id): - cmd = [ - sys.executable, 'run_ut.py', - '-forward', fwd, - *(['-backward', bwd] if bwd else []), - *(['-o', config.out_path] if config.out_path else []), - '-d', str(dev_id), - *(['-j'] if config.jit_compile_flag else []), - *(['-save_error_data'] if config.save_error_data_flag else []), - '-csv_path', config.result_csv_path, - *(['-real_data_path', config.real_data_path] if config.real_data_path else []) - ] - return cmd - - def read_process_output(process): - try: - while True: - if process.poll() is not None: - break - output = process.stdout.readline() - if output == '': - break - if '[ERROR]' in output: - print(output, end='') - sys.stdout.flush() - except 
ValueError as e: - print_warn_log(f"An error occurred while reading subprocess output: {e}") - - def update_progress_bar(progress_bar, result_csv_path): - while any(process.poll() is None for process in processes): - try: - with open(result_csv_path, 'r') as result_file: - completed_items = len(result_file.readlines()) - 1 - progress_bar.update(completed_items - progress_bar.n) - except FileNotFoundError: - print_warn_log(f"Result CSV file not found: {result_csv_path}.") - except Exception as e: - print_error_log(f"An unexpected error occurred while reading result CSV: {e}") - time.sleep(1) - - for fwd, bwd in zip(config.forward_files, config.backward_files): - cmd = create_cmd(fwd, bwd, next(device_id_cycle)) - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1) - processes.append(process) - threading.Thread(target=read_process_output, args=(process,), daemon=True).start() - - progress_bar_thread = threading.Thread(target=update_progress_bar, args=(progress_bar, config.result_csv_path)) - progress_bar_thread.start() - - def clean_up(): - progress_bar.close() - for process in processes: - try: - process.terminate() - process.wait(timeout=1) - except subprocess.TimeoutExpired: - process.kill() - for file in config.forward_files: - check_link(file) - try: - os.remove(file) - except FileNotFoundError: - print_warn_log(f"File not found and could not be deleted: {file}") - - try: - for process in processes: - process.communicate(timeout=None) - except KeyboardInterrupt: - print_warn_log("Interrupted by user, terminating processes and cleaning up...") - except Exception as e: - print_error_log(f"An unexpected error occurred: {e}") - finally: - if progress_bar.n < config.total_items: - print_warn_log("The UT task has not been completed. 
The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.") - clean_up() - progress_bar_thread.join() - try: - comparator = Comparator(config.result_csv_path, config.result_csv_path, False) - comparator.print_pretest_result() - except FileNotFoundError as e: - print_error_log(f"Error: {e}") - except Exception as e: - print_error_log(f"An unexpected error occurred: {e}") - - -def prepare_config(args): - check_link(args.forward_input_file) - check_link(args.backward_input_file) if args.backward_input_file else None - forward_file = os.path.realpath(args.forward_input_file) - backward_file = os.path.realpath(args.backward_input_file) if args.backward_input_file else None - check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) - out_path = os.path.realpath(args.out_path) if args.out_path else "./" - check_path_before_create(out_path) - create_directory(out_path) - out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) - out_path = out_path_checker.common_check() - forward_splits, total_items = split_json_file(args.forward_input_file, args.num_splits, args.filter_api) - backward_splits = [backward_file] * args.num_splits if backward_file else [None] * args.num_splits - result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") - if not args.result_csv_path: - details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv") - comparator = Comparator(result_csv_path, details_csv_path, False) - print_info_log(f"UT task result will be saved in {result_csv_path}") - print_info_log(f"UT task details will be saved in {details_csv_path}") - else: - result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') - details_csv_path = get_validated_details_csv_path(result_csv_path) - print_info_log(f"UT task result will be saved in {result_csv_path}") - 
print_info_log(f"UT task details will be saved in {details_csv_path}") - return ParallelUTConfig(forward_splits, backward_splits, out_path, args.num_splits, args.save_error_data, args.jit_compile, args.device_id, result_csv_path, total_items, args.real_data_path) - - -def main(): - parser = argparse.ArgumentParser(description='Run UT in parallel') - _run_ut_parser(parser) - parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64') - args = parser.parse_args() - config = prepare_config(args) - run_parallel_ut(config) - -if __name__ == '__main__': - main() +import subprocess +import json +import os +import sys +import argparse +import time +import signal +import threading +from collections import namedtuple +from itertools import cycle +from tqdm import tqdm +from calibrator.common.file_check import FileCheckConst, FileChecker, \ + check_file_suffix, check_link, FileOpen +from calibrator.pytorch.api_accuracy_checker import Comparator +from calibrator.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path, preprocess_forward_content +from calibrator.pytorch.api_accuracy_checker.common import print_error_log, print_warn_log, print_info_log, create_directory +from calibrator.common.file_check import check_path_before_create + + +def split_json_file(input_file, num_splits, filter_api): + with FileOpen(input_file, 'r') as file: + data = json.load(file) + if filter_api: + data = preprocess_forward_content(data) + + items = list(data.items()) + total_items = len(items) + chunk_size = total_items // num_splits + split_files = [] + + for i in range(num_splits): + start = i * chunk_size + end = (i + 1) * chunk_size if i < num_splits - 1 else total_items + split_filename = f"temp_part{i}.json" + with FileOpen(split_filename, 'w') as split_file: + json.dump(dict(items[start:end]), split_file) + 
split_files.append(split_filename) + + return split_files, total_items + + +def signal_handler(signum, frame): + print_warn_log(f'Signal handler called with signal {signum}') + raise KeyboardInterrupt() + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + + +ParallelUTConfig = namedtuple('ParallelUTConfig', ['forward_files', 'backward_files', 'out_path', 'num_splits', 'save_error_data_flag', 'jit_compile_flag', 'device_id', 'result_csv_path', 'total_items', 'real_data_path']) + + +def run_parallel_ut(config): + processes = [] + device_id_cycle = cycle(config.device_id) + if config.save_error_data_flag: + print_info_log("UT task error datas will be saved") + print_info_log(f"Starting parallel UT with {config.num_splits} processes") + progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items") + + def create_cmd(fwd, bwd, dev_id): + cmd = [ + sys.executable, 'run_ut.py', + '-forward', fwd, + *(['-backward', bwd] if bwd else []), + *(['-o', config.out_path] if config.out_path else []), + '-d', str(dev_id), + *(['-j'] if config.jit_compile_flag else []), + *(['-save_error_data'] if config.save_error_data_flag else []), + '-csv_path', config.result_csv_path, + *(['-real_data_path', config.real_data_path] if config.real_data_path else []) + ] + return cmd + + def read_process_output(process): + try: + while True: + if process.poll() is not None: + break + output = process.stdout.readline() + if output == '': + break + if '[ERROR]' in output: + print(output, end='') + sys.stdout.flush() + except ValueError as e: + print_warn_log(f"An error occurred while reading subprocess output: {e}") + + def update_progress_bar(progress_bar, result_csv_path): + while any(process.poll() is None for process in processes): + try: + with open(result_csv_path, 'r') as result_file: + completed_items = len(result_file.readlines()) - 1 + progress_bar.update(completed_items - progress_bar.n) + except FileNotFoundError: + 
print_warn_log(f"Result CSV file not found: {result_csv_path}.") + except Exception as e: + print_error_log(f"An unexpected error occurred while reading result CSV: {e}") + time.sleep(1) + + for fwd, bwd in zip(config.forward_files, config.backward_files): + cmd = create_cmd(fwd, bwd, next(device_id_cycle)) + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, text=True, bufsize=1) + processes.append(process) + threading.Thread(target=read_process_output, args=(process,), daemon=True).start() + + progress_bar_thread = threading.Thread(target=update_progress_bar, args=(progress_bar, config.result_csv_path)) + progress_bar_thread.start() + + def clean_up(): + progress_bar.close() + for process in processes: + try: + process.terminate() + process.wait(timeout=1) + except subprocess.TimeoutExpired: + process.kill() + for file in config.forward_files: + check_link(file) + try: + os.remove(file) + except FileNotFoundError: + print_warn_log(f"File not found and could not be deleted: {file}") + + try: + for process in processes: + process.communicate(timeout=None) + except KeyboardInterrupt: + print_warn_log("Interrupted by user, terminating processes and cleaning up...") + except Exception as e: + print_error_log(f"An unexpected error occurred: {e}") + finally: + if progress_bar.n < config.total_items: + print_warn_log("The UT task has not been completed. 
The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.") + clean_up() + progress_bar_thread.join() + try: + comparator = Comparator(config.result_csv_path, config.result_csv_path, False) + comparator.print_pretest_result() + except FileNotFoundError as e: + print_error_log(f"Error: {e}") + except Exception as e: + print_error_log(f"An unexpected error occurred: {e}") + + +def prepare_config(args): + check_link(args.forward_input_file) + check_link(args.backward_input_file) if args.backward_input_file else None + forward_file = os.path.realpath(args.forward_input_file) + backward_file = os.path.realpath(args.backward_input_file) if args.backward_input_file else None + check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) + out_path = os.path.realpath(args.out_path) if args.out_path else "./" + check_path_before_create(out_path) + create_directory(out_path) + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + forward_splits, total_items = split_json_file(args.forward_input_file, args.num_splits, args.filter_api) + backward_splits = [backward_file] * args.num_splits if backward_file else [None] * args.num_splits + result_csv_path = args.result_csv_path or os.path.join(out_path, f"accuracy_checking_result_{time.strftime('%Y%m%d%H%M%S')}.csv") + if not args.result_csv_path: + details_csv_path = os.path.join(out_path, f"accuracy_checking_details_{time.strftime('%Y%m%d%H%M%S')}.csv") + comparator = Comparator(result_csv_path, details_csv_path, False) + print_info_log(f"UT task result will be saved in {result_csv_path}") + print_info_log(f"UT task details will be saved in {details_csv_path}") + else: + result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') + details_csv_path = get_validated_details_csv_path(result_csv_path) + print_info_log(f"UT task result will be saved in {result_csv_path}") + 
print_info_log(f"UT task details will be saved in {details_csv_path}") + return ParallelUTConfig(forward_splits, backward_splits, out_path, args.num_splits, args.save_error_data, args.jit_compile, args.device_id, result_csv_path, total_items, args.real_data_path) + + +def main(): + parser = argparse.ArgumentParser(description='Run UT in parallel') + _run_ut_parser(parser) + parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 65), default=8, help='Number of splits for parallel processing. Range: 1-64') + args = parser.parse_args() + config = prepare_config(args) + run_parallel_ut(config) + +if __name__ == '__main__': + main() diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py similarity index 94% rename from debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 2e8a12231ed31c499648f12ee93486dfed47e00c..5a908b2a8ebd5d0586a3487fa4fcdeedb6fb72bb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -4,9 +4,9 @@ import sys import torch_npu import torch from tqdm import tqdm -from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, print_error_log -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import check_link +from calibrator.pytorch.api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info +from calibrator.pytorch.api_accuracy_checker.common import print_info_log, print_warn_log, get_json_contents, print_error_log +from calibrator.common.file_check import check_link def check_tensor_overflow(x): diff --git 
a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/run_ut.py similarity index 93% rename from debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/run_ut.py index 856cb237ca7e1ce71da45938be2284c2e6133a2b..34bafbad6d69dff106f8a10d7d9a9c10575ed154 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -16,19 +16,18 @@ else: current_device = "npu" import torch from tqdm import tqdm -from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ +from calibrator.pytorch.api_accuracy_checker import gen_api_params, gen_args +from calibrator.pytorch.api_accuracy_checker.common import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log, initialize_save_path, Const, create_directory -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate -from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate -from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.dump.api_info import APIInfo -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create - - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, \ +from calibrator.pytorch.api_accuracy_checker import Comparator +from calibrator.pytorch.api_accuracy_checker import TensorOPTemplate +from calibrator.pytorch.api_accuracy_checker import FunctionalOPTemplate +from calibrator.pytorch.api_accuracy_checker import 
TorchOPTemplate +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig +from calibrator.pytorch.api_accuracy_checker import APIInfo +from calibrator.pytorch.common.parse_json import parse_json_info_forward_backward +from calibrator.common.file_check import check_path_before_create +from calibrator.common.file_check import FileOpen, FileCheckConst, FileChecker, \ change_mode, check_file_suffix, check_link current_time = time.strftime("%Y%m%d%H%M%S") @@ -165,7 +164,7 @@ def run_ut(config): continue try: if msCheckerConfig.white_list: - [_, api_name, _] = api_full_name.split("*") + [_, api_name, _] = api_full_name.split(Const.SEP) if api_name not in set(msCheckerConfig.white_list): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) @@ -177,7 +176,7 @@ def run_ut(config): if config.save_error_data: do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: - [_, api_name, _] = api_full_name.split("*") + [_, api_name, _] = api_full_name.split(Const.SEP) if "expected scalar type Long" in str(err): print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") @@ -197,7 +196,6 @@ def run_ut(config): def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") for element in data_info.in_fwd_data_list: UtAPIInfo(api_full_name + '.forward.input', element) UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) @@ -209,7 +207,7 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): in_fwd_data_list = [] - [api_type, api_name, _] = api_full_name.split("*") + [api_type, api_name, _] = 
api_full_name.split(Const.SEP) args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) @@ -239,7 +237,7 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict grad_index = grad_input_index.get('grad_index') if need_backward: - backward_args = backward_content[api_full_name] + backward_args = backward_content[api_full_name].get("grad_output") grad = gen_args(backward_args, real_data_path=real_data_path)[0] bench_grad, _ = generate_cpu_params(grad, {}, False, api_name) bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out) @@ -255,14 +253,13 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict def get_api_info(api_info_dict, api_name, real_data_path): convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True - if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): + if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"): need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type, real_data_path) return args, kwargs, need_grad def run_backward(args, grad, grad_index, out): - if grad_index is not None: out[grad_index].backward(grad) elif isinstance(out, (list, tuple)): @@ -358,7 +355,7 @@ def preprocess_forward_content(forward_content): processed_content = {} base_keys_variants = {} for key, value in forward_content.items(): - base_key = key.rsplit('*', 1)[0] + base_key = key.rsplit(Const.SEP, 1)[0] new_args = value['args'] new_kwargs = value['kwargs'] filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] @@ -412,15 +409,10 @@ def run_ut_command(args): out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() save_error_data = args.save_error_data - 
forward_content = get_json_contents(forward_file) + forward_content, backward_content, real_data_path = parse_json_info_forward_backward(forward_file) if args.filter_api: forward_content = preprocess_forward_content(forward_content) - backward_content = {} - if args.backward_input_file: - check_link(args.backward_input_file) - backward_file = os.path.realpath(args.backward_input_file) - check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) - backward_content = get_json_contents(backward_file) + result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) if args.result_csv_path: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/torch_ut_setting.json b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/run_ut/torch_ut_setting.json rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/run_ut/torch_ut_setting.json diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/test/resources/forward.json b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/resources/forward.json similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/test/resources/forward.json rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/resources/forward.json diff --git a/debug/accuracy_tools/api_accuracy_checker/test/run_test.sh b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/run_test.sh similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/test/run_test.sh rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/run_test.sh diff --git 
a/debug/accuracy_tools/api_accuracy_checker/test/run_ut.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/run_ut.py similarity index 100% rename from debug/accuracy_tools/api_accuracy_checker/test/run_ut.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/run_ut.py diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_common_utils.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/test_common_utils.py similarity index 98% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_common_utils.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/test_common_utils.py index 5f25e81c09783eeb8c682fd33d3178b99352f6e0..febd2f1a03a0ef258094926a2db49d935f3b4f71 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_common_utils.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/test_common_utils.py @@ -1,8 +1,7 @@ import unittest import os -import numpy as np import torch -from api_accuracy_checker.common.utils import * + class TestUtils(unittest.TestCase): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/test_config.py similarity index 91% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py rename to 
debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/test_config.py index a68057dfb41ca38ba79e1daa992a8f51ce4d64e4..d7c6618e58983df683c66c27a4b982517b0d2e80 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/common/test_config.py @@ -1,6 +1,6 @@ import unittest import os -from api_accuracy_checker.common.config import Config +from calibrator.pytorch.api_accuracy_checker.common import Config class TestConfig(unittest.TestCase): def setUp(self): diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_algorithm.py similarity index 90% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_algorithm.py index 90e18d166f56f98b8c1e1f80f2ae28dab7db67d3..0b707795bdc5ccbf81c526f3c9662f05f01a4413 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_algorithm.py @@ -1,8 +1,7 @@ import unittest import numpy as np -import torch -from api_accuracy_checker.compare import compare as cmp -from api_accuracy_checker.compare import algorithm as alg +from calibrator.pytorch.api_accuracy_checker.compare import algorithm as alg + class TestAlgorithmMethods(unittest.TestCase): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py 
b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_compare.py similarity index 97% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_compare.py index 4ce73ce550dfc5d5cd21246dbc2756a6024f6fea..f9a2ffdca15df80ee27bc6b72f891e8388cf94c1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_compare.py @@ -7,8 +7,8 @@ import unittest import numpy as np import torch.nn.functional -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.compare.compare_column import CompareColumn +from calibrator.pytorch.api_accuracy_checker import Comparator +from calibrator.pytorch.api_accuracy_checker import CompareColumn current_time = time.strftime("%Y%m%d%H%M%S") RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare_utils.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_compare_utils.py similarity index 91% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare_utils.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_compare_utils.py index 4e83c0643ef452c28d11c02bbbc2fee359a1ea2e..3950d7eaefb696fd4e626b72e20a7a9d08662eba 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_compare_utils.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/compare/test_compare_utils.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable +from calibrator.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable class TestCompareUtils(unittest.TestCase): def
test_check_dtype_comparable(self): diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_api_info.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_api_info.py similarity index 96% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_api_info.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_api_info.py index 2c03d56e722decc424052367dfe9700ba3df94ce..ef71f57734aa8211e8db4e1ee5cdb0354efd09d4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_api_info.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_api_info.py @@ -3,9 +3,9 @@ import shutil import unittest import torch import numpy as np -from api_accuracy_checker.dump.api_info import APIInfo, ForwardAPIInfo, BackwardAPIInfo, transfer_types, \ +from calibrator.pytorch.api_accuracy_checker import APIInfo, ForwardAPIInfo, BackwardAPIInfo, transfer_types, \ get_tensor_extremum, get_type_name, is_builtin_class, analyze_device_in_kwargs, analyze_dtype_in_kwargs -from api_accuracy_checker.common.config import msCheckerConfig +from calibrator.pytorch.api_accuracy_checker.common import msCheckerConfig class TestAPIInfo(unittest.TestCase): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_dump.py similarity index 95% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_dump.py index 655e624e809a5cceb406b9fce9df4e4f89efb4ee..634ca16c1a7debd36dda80a0fe18f256ef1a0e22 100644 
--- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_dump.py @@ -1,5 +1,5 @@ import unittest -from api_accuracy_checker.dump.dump import * + class TestDumpUtil(unittest.TestCase): def test_set_dump_switch(self): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scope.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_dump_scope.py similarity index 81% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scope.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_dump_scope.py index 7712552abe49d757a07bcbbd746038ed22d4027b..2b91601213242695c3b1e6e204cf1d88cdf59848 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scope.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_dump_scope.py @@ -1,6 +1,6 @@ import unittest -from api_accuracy_checker.dump.dump_scope import iter_tracer -from api_accuracy_checker.dump.dump import DumpUtil +from calibrator.pytorch.api_accuracy_checker.dump.dump_scope import iter_tracer +from calibrator.pytorch.api_accuracy_checker.dump.dump import DumpUtil class TestDumpScope(unittest.TestCase): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_info_dump.py similarity index 87% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_info_dump.py index 45e57f2c389292e9226039f56b83966941c603ca..82292805d99515df006dfb6307195d4325f429fb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_info_dump.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/dump/test_info_dump.py @@ -1,8 +1,8 @@ import unittest import os from
unittest.mock import patch -from api_accuracy_checker.dump.api_info import APIInfo, BackwardAPIInfo -from api_accuracy_checker.dump.info_dump import write_api_info_json +from calibrator.pytorch.api_accuracy_checker import APIInfo, BackwardAPIInfo +from calibrator.pytorch.api_accuracy_checker import write_api_info_json class TestInfoDump(unittest.TestCase): diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_functional.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_functional.py similarity index 84% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_functional.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_functional.py index 37058e77fd87e697b7dd7fde5e94b78d01a2cb89..882457326352ba5fb271d07728b68969f85e7c6e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_functional.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_functional.py @@ -1,7 +1,7 @@ # coding=utf-8 import unittest -import torch -from api_accuracy_checker.hook_module import wrap_functional as wf +from calibrator.pytorch.api_accuracy_checker import wrap_functional as wf + class TestWrapFunctional(unittest.TestCase): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_tensor.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_tensor.py similarity index 82% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_tensor.py rename to 
debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_tensor.py index bfae3c72771510b141abf9204723bfe48bfa8de3..238e1d48d9b5b6c3d69f403e00e3d9c4c19302da 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_tensor.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_tensor.py @@ -1,8 +1,6 @@ # coding=utf-8 import unittest -import torch -import yaml -from api_accuracy_checker.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind +from calibrator.pytorch.api_accuracy_checker.hook_module.wrap_tensor import get_tensor_ops, HOOKTensor, TensorOPTemplate, wrap_tensor_op, wrap_tensor_ops_and_bind class TestWrapTensor(unittest.TestCase): def hook(self, a, b): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_torch.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_torch.py similarity index 94% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_torch.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_torch.py index 40cef939adfd06158eb543c07b3d682e29d6cdab..f07ffaea7ea77204e338dea60a68124f6e41610c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/hook_module/test_wrap_torch.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/hook_module/test_wrap_torch.py @@ -1,8 +1,7 @@ # coding=utf-8 import unittest import torch -import yaml -from api_accuracy_checker.hook_module.wrap_torch import * + class TestWrapTorch(unittest.TestCase): diff --git a/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/__init__.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git
a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_data_generate.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py similarity index 96% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_data_generate.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py index b98f84d516404665b5c3284f1e03f14eedddac55..6560045bd45159e7e30b4fa48bd928a519c9e431 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_data_generate.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py @@ -1,10 +1,8 @@ # coding=utf-8 import unittest -import numpy as np -import os import copy -from api_accuracy_checker.run_ut.data_generate import * -from api_accuracy_checker.common.utils import get_json_contents +from calibrator.pytorch.api_accuracy_checker import * +from calibrator.pytorch.api_accuracy_checker.common import get_json_contents base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) forward_file = os.path.join(base_dir, "../resources/forward.json") diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py similarity index 93% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py index 18293a4bc1fc899191bde35252034962f8312f3c..f6fb0b8129efc82da81728c37becd83334d23d3f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py @@ -1,103 +1,103 @@ -import unittest -from unittest.mock import patch, mock_open, MagicMock -import json -import signal -from api_accuracy_checker.run_ut.multi_run_ut 
import split_json_file, signal_handler, run_parallel_ut, prepare_config, main, ParallelUTConfig - - -class TestMultiRunUT(unittest.TestCase): - - def setUp(self): - self.test_json_file = 'test_file.json' - self.test_data = {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'} - self.test_json_content = json.dumps(self.test_data) - self.forward_split_files_content = [ - {'key1': 'TRUE', 'key2': 'TRUE'}, - {'key3': 'TRUE', 'key4': 'TRUE'} - ] - - @patch('api_accuracy_checker.run_ut.multi_run_ut.FileOpen') - def test_split_json_file(self, mock_FileOpen): - mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value - num_splits = 2 - split_files, total_items = split_json_file(self.test_json_file, num_splits, False) - self.assertEqual(len(split_files), num_splits) - self.assertEqual(total_items, len(self.test_data)) - - @patch('api_accuracy_checker.run_ut.multi_run_ut.print_warn_log') - def test_signal_handler(self, mock_print_warn_log): - with self.assertRaises(KeyboardInterrupt): - signal_handler(signal.SIGINT, None) - mock_print_warn_log.assert_called() - - @patch('subprocess.Popen') - @patch('os.path.exists', return_value=True) - @patch('builtins.open', new_callable=mock_open) - @patch('json.load', side_effect=lambda f: {'key1': 'TRUE', 'key2': 'TRUE'}) - def test_run_parallel_ut(self, mock_json_load, mock_file, mock_exists, mock_popen): - mock_process = MagicMock() - mock_process.poll.side_effect = [None, None, 1] - mock_process.stdout.readline.side_effect = ['[ERROR] Test Error Message\n', ''] - mock_popen.return_value = mock_process - - config = ParallelUTConfig( - forward_files=['forward_split1.json', 'forward_split2.json'], - backward_files=[None, None], - out_path='./', - num_splits=2, - save_error_data_flag=True, - jit_compile_flag=False, - device_id=[0, 1], - result_csv_path='result.csv', - total_items=2, - real_data_path=None - ) - - mock_file.side_effect = [ - 
mock_open(read_data=json.dumps(self.forward_split_files_content[0])).return_value, - mock_open(read_data=json.dumps(self.forward_split_files_content[1])).return_value - ] - - run_parallel_ut(config) - - mock_popen.assert_called() - mock_exists.assert_called() - - @patch('os.remove') - @patch('os.path.realpath', side_effect=lambda x: x) - @patch('api_accuracy_checker.run_ut.multi_run_ut.check_link') - @patch('api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix') - @patch('api_accuracy_checker.run_ut.multi_run_ut.FileChecker') - @patch('api_accuracy_checker.run_ut.multi_run_ut.split_json_file', return_value=(['forward_split1.json', 'forward_split2.json'], 2)) - def test_prepare_config(self, mock_split_json_file, mock_FileChecker, mock_check_file_suffix, mock_check_link, mock_realpath, mock_remove): - mock_FileChecker_instance = MagicMock() - mock_FileChecker_instance.common_check.return_value = './' - mock_FileChecker.return_value = mock_FileChecker_instance - args = MagicMock() - args.forward_input_file = 'forward.json' - args.backward_input_file = None - args.out_path = './' - args.num_splits = 2 - args.save_error_data = True - args.jit_compile = False - args.device_id = [0, 1] - args.result_csv_path = None - args.real_data_path = None - - config = prepare_config(args) - - self.assertEqual(config.num_splits, 2) - self.assertTrue(config.save_error_data_flag) - self.assertFalse(config.jit_compile_flag) - self.assertEqual(config.device_id, [0, 1]) - self.assertEqual(len(config.forward_files), 2) - self.assertEqual(config.total_items, 2) - - @patch('argparse.ArgumentParser.parse_args') - @patch('api_accuracy_checker.run_ut.multi_run_ut.prepare_config') - @patch('api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut') - def test_main(self, mock_run_parallel_ut, mock_prepare_config, mock_parse_args): - main() - mock_parse_args.assert_called() - mock_prepare_config.assert_called() +import unittest +from unittest.mock import patch, mock_open, MagicMock +import 
json +import signal +from calibrator.pytorch.api_accuracy_checker import split_json_file, signal_handler, run_parallel_ut, prepare_config, main, ParallelUTConfig + + +class TestMultiRunUT(unittest.TestCase): + + def setUp(self): + self.test_json_file = 'test_file.json' + self.test_data = {'key1': 'TRUE', 'key2': 'TRUE', 'key3': 'TRUE'} + self.test_json_content = json.dumps(self.test_data) + self.forward_split_files_content = [ + {'key1': 'TRUE', 'key2': 'TRUE'}, + {'key3': 'TRUE', 'key4': 'TRUE'} + ] + + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileOpen') + def test_split_json_file(self, mock_FileOpen): + mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value + num_splits = 2 + split_files, total_items = split_json_file(self.test_json_file, num_splits, False) + self.assertEqual(len(split_files), num_splits) + self.assertEqual(total_items, len(self.test_data)) + + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.print_warn_log') + def test_signal_handler(self, mock_print_warn_log): + with self.assertRaises(KeyboardInterrupt): + signal_handler(signal.SIGINT, None) + mock_print_warn_log.assert_called() + + @patch('subprocess.Popen') + @patch('os.path.exists', return_value=True) + @patch('builtins.open', new_callable=mock_open) + @patch('json.load', side_effect=lambda f: {'key1': 'TRUE', 'key2': 'TRUE'}) + def test_run_parallel_ut(self, mock_json_load, mock_file, mock_exists, mock_popen): + mock_process = MagicMock() + mock_process.poll.side_effect = [None, None, 1] + mock_process.stdout.readline.side_effect = ['[ERROR] Test Error Message\n', ''] + mock_popen.return_value = mock_process + + config = ParallelUTConfig( + forward_files=['forward_split1.json', 'forward_split2.json'], + backward_files=[None, None], + out_path='/', + num_splits=2, + save_error_data_flag=True, + jit_compile_flag=False, + device_id=[0, 1], + result_csv_path='result.csv', + total_items=2, + real_data_path=None + ) + +
mock_file.side_effect = [ + mock_open(read_data=json.dumps(self.forward_split_files_content[0])).return_value, + mock_open(read_data=json.dumps(self.forward_split_files_content[1])).return_value + ] + + run_parallel_ut(config) + + mock_popen.assert_called() + mock_exists.assert_called() + + @patch('os.remove') + @patch('os.path.realpath', side_effect=lambda x: x) + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_link') + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.check_file_suffix') + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.FileChecker') + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.split_json_file', return_value=(['forward_split1.json', 'forward_split2.json'], 2)) + def test_prepare_config(self, mock_split_json_file, mock_FileChecker, mock_check_file_suffix, mock_check_link, mock_realpath, mock_remove): + mock_FileChecker_instance = MagicMock() + mock_FileChecker_instance.common_check.return_value = './' + mock_FileChecker.return_value = mock_FileChecker_instance + args = MagicMock() + args.forward_input_file = 'forward.json' + args.backward_input_file = None + args.out_path = '/' + args.num_splits = 2 + args.save_error_data = True + args.jit_compile = False + args.device_id = [0, 1] + args.result_csv_path = None + args.real_data_path = None + + config = prepare_config(args) + + self.assertEqual(config.num_splits, 2) + self.assertTrue(config.save_error_data_flag) + self.assertFalse(config.jit_compile_flag) + self.assertEqual(config.device_id, [0, 1]) + self.assertEqual(len(config.forward_files), 2) + self.assertEqual(config.total_items, 2) + + @patch('argparse.ArgumentParser.parse_args') + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.prepare_config') + @patch('calibrator.pytorch.api_accuracy_checker.run_ut.multi_run_ut.run_parallel_ut') + def test_main(self, mock_run_parallel_ut, mock_prepare_config, mock_parse_args): + main() + mock_parse_args.assert_called() + mock_prepare_config.assert_called() mock_run_parallel_ut.assert_called() \ No newline at
end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_run_ut.py similarity index 96% rename from debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py rename to debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_run_ut.py index fdcc1cfddeb38d4fca0d2a67a09147b571b35def..ac318e9583cc7c4b9e0498583fb4fb556f52f98a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/msacc/pytorch/api_accuracy_checker/test/ut/run_ut/test_run_ut.py @@ -4,8 +4,7 @@ import copy import unittest from unittest.mock import patch, DEFAULT import torch -from api_accuracy_checker.run_ut.run_ut import * -from api_accuracy_checker.common.utils import get_json_contents +from calibrator.pytorch.api_accuracy_checker.common import get_json_contents base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) forward_file = os.path.join(base_dir, "../resources/forward.json") diff --git a/debug/accuracy_tools/msacc/pytorch/common/__init__.py b/debug/accuracy_tools/msacc/pytorch/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05628161eeada91f2854bd0c3330978791bc779d --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/common/__init__.py @@ -0,0 +1,3 @@ +from .recursive import recursive_apply_transform +from .log import print_error_log_rank_0, print_info_log_rank_0, print_warn_log_rank_0 +from .parse_json import parse_json_info_forward_backward \ No newline at end of file diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/compare_script.template b/debug/accuracy_tools/msacc/pytorch/common/compare_script.template similarity index 100% rename from debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/compare_script.template rename to debug/accuracy_tools/msacc/pytorch/common/compare_script.template 
diff --git a/debug/accuracy_tools/msacc/pytorch/common/exceptions.py b/debug/accuracy_tools/msacc/pytorch/common/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..607f6eb2be4f573df4a472b7e0774b069422856a --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/common/exceptions.py @@ -0,0 +1,67 @@ + +class CodedException(Exception): + def __init__(self, code, error_info=''): + self.error_info = self.err_strs.get(code) + error_info + + def __str__(self): + return self.error_info + + +class MsaccException(CodedException): + INVALID_PARAM_ERROR = 0 + + err_strs = { + INVALID_PARAM_ERROR: "[msacc] 无效参数: " + } + + +class FileCheckException(CodedException): + INVALID_FILE_ERROR = 0 + FILE_PERMISSION_ERROR = 1 + SOFT_LINK_ERROR = 2 + ILLEGAL_PATH_ERROR = 3 + ILLEGAL_PARAM_ERROR = 4 + FILE_TOO_LARGE_ERROR = 5 + + err_strs = { + SOFT_LINK_ERROR: "[msacc] 检测到软链接: ", + FILE_PERMISSION_ERROR: "[msacc] 文件权限错误: ", + INVALID_FILE_ERROR: "[msacc] 无效文件: ", + ILLEGAL_PATH_ERROR: "[msacc] 非法文件路径: ", + ILLEGAL_PARAM_ERROR: "[msacc] 非法打开方式: ", + FILE_TOO_LARGE_ERROR: "[msacc] 文件过大: " + } + + +class ParseJsonException(CodedException): + UnexpectedNameStruct = 0 + InvalidDumpJson = 1 + err_strs = { + UnexpectedNameStruct: "[msacc] Unexpected name in json: ", + InvalidDumpJson: "[msacc] json格式不正确: ", + } + + +class ScopeException(CodedException): + InvalidApiStr = 0 + InvalidScope = 1 + ArgConflict = 2 + err_strs = { + InvalidApiStr: "[msacc] Invalid api_list: ", + InvalidScope: "[msacc] Invalid scope: ", + ArgConflict: "[msacc] Scope and api_list conflict: ", + } + + +class RepairException(CodedException): + InvalidRepairType = 0 + err_strs = { + InvalidRepairType: "[msacc] Invalid repair_type: " + } + + +class StepException(CodedException): + InvalidPostProcess = 0 + err_strs = { + InvalidPostProcess: "[msacc] 错误的step后处理配置: ", + } diff --git a/debug/accuracy_tools/msacc/pytorch/common/file_check.py b/debug/accuracy_tools/msacc/pytorch/common/file_check.py 
new file mode 100644 index 0000000000000000000000000000000000000000..3bcd850ee77fc8e5730b124c22fbf8131ec640d7 --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/common/file_check.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import re + +from .log import print_error_log +from .exceptions import FileCheckException +from calibrator.common.utils import Const + + +class FileCheckConst: + """ + Class for file check const + """ + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + JSON_SUFFIX = ".json" + PT_SUFFIX = ".pt" + CSV_SUFFIX = ".csv" + YAML_SUFFIX = ".yaml" + MAX_PKL_SIZE = 1 * 1024 * 1024 * 1024 + MAX_NUMPY_SIZE = 10 * 1024 * 1024 * 1024 + MAX_JSON_SIZE = 1 * 1024 * 1024 * 1024 + MAX_PT_SIZE = 10 * 1024 * 1024 * 1024 + MAX_CSV_SIZE = 1 * 1024 * 1024 * 1024 + MAX_YAML_SIZE = 10 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + PKL_SUFFIX: MAX_PKL_SIZE, + NUMPY_SUFFIX: MAX_NUMPY_SIZE, + JSON_SUFFIX: MAX_JSON_SIZE, + PT_SUFFIX: MAX_PT_SIZE, + CSV_SUFFIX: MAX_CSV_SIZE, + YAML_SUFFIX: MAX_YAML_SIZE + } + + +class FileChecker: + """ + The class for check 
file. + + Attributes: + file_path: The file or dictionary path to be verified. + path_type: file or dictionary + ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability + file_type(str): The correct file type for file + """ + def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + self.file_path = file_path + self.path_type = self._check_path_type(path_type) + self.ability = ability + self.file_type = file_type + self.is_script = is_script + + @staticmethod + def _check_path_type(path_type): + if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: + print_error_log(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) + return path_type + + def common_check(self): + """ + 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 + 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 + """ + check_path_exists(self.file_path) + check_link(self.file_path) + self.file_path = os.path.realpath(self.file_path) + check_path_length(self.file_path) + check_path_type(self.file_path, self.path_type) + self.check_path_ability() + if self.is_script: + check_path_owner_consistent(self.file_path) + check_path_pattern_vaild(self.file_path) + check_common_file_size(self.file_path) + check_file_suffix(self.file_path, self.file_type) + return self.file_path + + def check_path_ability(self): + if self.ability == FileCheckConst.WRITE_ABLE: + check_path_writability(self.file_path) + if self.ability == FileCheckConst.READ_ABLE: + check_path_readability(self.file_path) + if self.ability == FileCheckConst.READ_WRITE_ABLE: + check_path_readability(self.file_path) + check_path_writability(self.file_path) + + +class FileOpen: + """ + The class for open file by a safe way. + + Attributes: + file_path: The file or dictionary path to be opened. 
class FileOpen:
    """Context manager that opens a file only after safety checks pass.

    Attributes:
        file_path: the file path to be opened.
        mode(str): the open mode; must be one of the SUPPORT_* lists.
        encoding: text encoding, ignored for binary modes.
    """

    SUPPORT_READ_MODE = ["r", "rb"]
    SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"]
    SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"]

    def __init__(self, file_path, mode, encoding='utf-8'):
        self.file_path = file_path
        self.mode = mode
        self.encoding = encoding
        self._handle = None

    def __enter__(self):
        self.check_file_path()
        if "b" in self.mode:
            # binary modes must not carry an encoding
            self._handle = open(self.file_path, self.mode)
        else:
            self._handle = open(self.file_path, self.mode, encoding=self.encoding)
        return self._handle

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._handle:
            self._handle.close()

    def check_file_path(self):
        """Validate mode, link status, length, access and size before opening."""
        supported = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE
        if self.mode not in supported:
            print_error_log("File open not support %s mode" % self.mode)
            raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR)
        check_link(self.file_path)
        self.file_path = os.path.realpath(self.file_path)
        check_path_length(self.file_path)
        self.check_ability_and_owner()
        check_path_pattern_vaild(self.file_path)
        if os.path.exists(self.file_path):
            check_common_file_size(self.file_path)

    def check_ability_and_owner(self):
        """Check access rights/ownership for the mode category.

        The three SUPPORT_* lists are disjoint, so exactly one branch applies.
        Write/read-write checks are skipped for not-yet-existing files.
        """
        if self.mode in self.SUPPORT_READ_MODE:
            check_path_exists(self.file_path)
            check_path_readability(self.file_path)
            check_path_owner_consistent(self.file_path)
        elif self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path):
            check_path_writability(self.file_path)
            check_path_owner_consistent(self.file_path)
        elif self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path):
            check_path_readability(self.file_path)
            check_path_writability(self.file_path)
            check_path_owner_consistent(self.file_path)


def check_link(path):
    """Reject soft links: raise FileCheckException if *path* is a symlink."""
    if os.path.islink(os.path.abspath(path)):
        print_error_log('The file path {} is a soft link.'.format(path))
        raise FileCheckException(FileCheckException.SOFT_LINK_ERROR)
def check_path_length(path, name_length=None):
    """Raise if the path or its basename exceeds the configured length limits.

    name_length overrides FileCheckConst.FILE_NAME_LENGTH when truthy.
    """
    file_max_name_length = name_length or FileCheckConst.FILE_NAME_LENGTH
    if len(path) > FileCheckConst.DIRECTORY_LENGTH or \
            len(os.path.basename(path)) > file_max_name_length:
        print_error_log('The file path length exceeds limit.')
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)


def check_path_exists(path):
    """Raise FileCheckException if *path* does not exist."""
    if not os.path.exists(path):
        print_error_log(f'The file path {path} does not exist.')
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)


def _check_access(path, access_flag, ability):
    # Shared body of the readability/writability/executability checks.
    if not os.access(path, access_flag):
        print_error_log(f'The file path {path} is not {ability}.')
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_readability(path):
    """Raise if the current user cannot read *path*."""
    _check_access(path, os.R_OK, 'readable')


def check_path_writability(path):
    """Raise if the current user cannot write *path*."""
    _check_access(path, os.W_OK, 'writable')


def check_path_executable(path):
    """Raise if the current user cannot execute *path*."""
    _check_access(path, os.X_OK, 'executable')


def check_other_user_writable(path):
    """Raise if *path* is world-writable (mode has the o+w bit set)."""
    st = os.stat(path)
    if st.st_mode & 0o002:
        print_error_log(f'The file path {path} may be insecure because other users have write permissions. ')
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_owner_consistent(path):
    """Raise if *path* is not owned by the current (effective) user."""
    file_owner = os.stat(path).st_uid
    if file_owner != os.getuid():
        # message fixed: original read "because is does not belong to you"
        print_error_log(f'The file path {path} may be insecure because it does not belong to you.')
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)


def check_path_pattern_vaild(path):
    """Raise if *path* contains characters outside the allowed whitelist."""
    if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
        print_error_log(f'The file path {path} contains special characters.')
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)


def check_file_size(file_path, max_size):
    """Raise if the size of *file_path* reaches or exceeds *max_size* bytes."""
    file_size = os.path.getsize(file_path)
    if file_size >= max_size:
        print_error_log(f'The size of file path {file_path} exceeds {max_size} bytes.')
        raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)


def check_common_file_size(file_path):
    """Apply the per-suffix size limit from FileCheckConst.FILE_SIZE_DICT.

    Non-files and unknown suffixes are silently accepted.
    """
    if os.path.isfile(file_path):
        for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items():
            if file_path.endswith(suffix):
                check_file_size(file_path, max_size)
                break


def check_file_suffix(file_path, file_suffix):
    """Raise unless *file_path* ends with *file_suffix* (no-op when falsy)."""
    if file_suffix and not file_path.endswith(file_suffix):
        print_error_log(f"The {file_path} should be a {file_suffix} file!")
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)


def check_path_type(file_path, file_type):
    """Raise unless *file_path* matches the requested type (FILE or DIR)."""
    if file_type == FileCheckConst.FILE and not os.path.isfile(file_path):
        print_error_log(f"The {file_path} should be a file!")
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
    if file_type == FileCheckConst.DIR and not os.path.isdir(file_path):
        print_error_log(f"The {file_path} should be a dictionary!")
        raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)


def create_directory(dir_path):
    """
    Function Description:
        creating a directory with specified permissions
    Parameter:
        dir_path: directory path
    Exception Description:
        when invalid data throw exception
    """
    dir_path = os.path.realpath(dir_path)
    try:
        os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
    except OSError as ex:
        raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR,
                                 'Failed to create {}. Please check the path permission or disk space .{}'
                                 .format(dir_path, str(ex))) from ex
def check_path_before_create(path):
    """Validate length and character set of *path* before creating it.

    NOTE(review): this raises FileCheckException.INVALID_PATH_ERROR while the
    sibling checks use ILLEGAL_PATH_ERROR — confirm both attributes exist on
    FileCheckException. Limits come from Const rather than FileCheckConst.
    """
    real_path = os.path.realpath(path)
    too_long = (len(real_path) > Const.DIRECTORY_LENGTH
                or len(os.path.basename(path)) > Const.FILE_NAME_LENGTH)
    if too_long:
        raise FileCheckException(FileCheckException.INVALID_PATH_ERROR, 'The file path length exceeds limit.')

    if not re.match(Const.FILE_PATTERN, real_path):
        raise FileCheckException(FileCheckException.INVALID_PATH_ERROR,
                                 'The file path {} contains special characters.'.format(path))


def change_mode(path, mode):
    """Chmod *path* to *mode*; silently skip missing paths and symlinks."""
    if not os.path.exists(path) or os.path.islink(path):
        return
    try:
        os.chmod(path, mode)
    except PermissionError as ex:
        raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR,
                                 'Failed to change {} authority. {}'.format(path, str(ex))) from ex


def on_rank_0(func):
    """Decorator: run *func* only on rank 0, or when torch.distributed is not
    initialized; other ranks get None back."""
    def func_rank_0(*args, **kwargs):
        rank = get_rank_if_initialized()
        if rank is None or rank == 0:
            return func(*args, **kwargs)

    return func_rank_0


def _print_log(level, msg, end='\n'):
    """Emit one log line: `<time>(<pid>)-[<level>]<msg>`, rank-prefixed when
    torch.distributed is initialized, and flush stdout immediately."""
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time())))
    line = f"{timestamp}({os.getpid()})-[{level}]{msg}"
    rank = get_rank_if_initialized()
    if rank is not None:
        line = f"[rank {rank}]-{line}"
    print(line, end=end)
    sys.stdout.flush()


def print_info_log(info_msg, end='\n'):
    """
    Function Description:
        print info log.
    Parameter:
        info_msg: the info message.
    """
    _print_log("INFO", info_msg, end=end)


def print_error_log(error_msg):
    """
    Function Description:
        print error log.
    Parameter:
        error_msg: the error message.
    """
    _print_log("ERROR", error_msg)


def print_warn_log(warn_msg):
    """
    Function Description:
        print warn log.
    Parameter:
        warn_msg: the warning message.
    """
    _print_log("WARNING", warn_msg)


# rank-0-only variants for use in distributed runs
print_info_log_rank_0 = on_rank_0(print_info_log)
print_warn_log_rank_0 = on_rank_0(print_warn_log)
print_error_log_rank_0 = on_rank_0(print_error_log)


def parse_json_info_forward_backward(json_path):
    """Split a dump json into forward and backward api entries.

    Returns (forward_data, backward_data, real_data_path) where the data dicts
    map api names (with the trailing ".forward"/".backward" stripped) to their
    dump items. Entries containing "Module" are skipped. Raises
    ParseJsonException for a missing "data" field or unexpected entry names.
    """
    def strip_trailing(data_name, pattern):
        # "api.forward" -> "api"; anything else is a malformed name
        parts = data_name.split('.')
        if parts[-1] != pattern:
            raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
                                     f"{data_name} in file {json_path}")
        return '.'.join(parts[:-1])

    with open(json_path, 'r') as f:
        dump_json = json.load(f)

    real_data_path = dump_json.get("dump_path")
    dump_data = dump_json.get("data")
    if not dump_data:
        raise ParseJsonException(ParseJsonException.InvalidDumpJson, "dump数据中没有data字段")

    forward_data = {}
    backward_data = {}
    for data_name, data_item in dump_data.items():
        if "Module" in data_name:
            continue
        if "forward" in data_name:
            forward_data[strip_trailing(data_name, "forward")] = data_item
        elif "backward" in data_name:
            backward_data[strip_trailing(data_name, "backward")] = data_item
        else:
            raise ParseJsonException(ParseJsonException.UnexpectedNameStruct,
                                     f"{data_name} in file {json_path}.")

    return forward_data, backward_data, real_data_path
try:
    import torch_npu
except ImportError:
    is_gpu = True
else:
    is_gpu = False


# Torch versions whose matching torch_npu no longer needs the device-guard
# wrapper.
torch_without_guard_version_list = ['2.1']  # TODO: confirm whether 2.2 belongs here
# Bug fix: the original for/if loop left the flag undefined when the list was
# empty; any() is equivalent for non-empty lists and defaults to False.
torch_without_guard_version = any(
    torch.__version__.startswith(version) for version in torch_without_guard_version_list
)

if not is_gpu and not torch_without_guard_version:
    from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard

# distributed ops that need special NPU handling elsewhere in the package
npu_distributed_api = ['isend', 'irecv']


def recursive_apply_transform(args, transform):
    """Apply *transform* to every leaf of a nested list/tuple/dict structure.

    ``transform`` is called as ``transform(leaf, key_stack)`` where
    ``key_stack`` is the list of container keys/indices (stringified) leading
    to the leaf. Containers are rebuilt with the same type; leaves are replaced
    by the transform's return value.

    Bug fix: the key stack used to be a module-level global, so concurrent or
    re-entrant calls corrupted each other's paths; it is now local per call.
    """
    def _walk(node, key_stack):
        if isinstance(node, (list, tuple)):
            transformed = []
            for index, item in enumerate(node):
                key_stack.append(str(index))
                transformed.append(_walk(item, key_stack))
                key_stack.pop()
            # NOTE: type(node)(iterable) works for list/tuple; namedtuples would
            # break here, same as in the original implementation.
            return type(node)(transformed)
        if isinstance(node, dict):
            transformed = {}
            for key, item in node.items():
                key_stack.append(str(key))
                transformed[key] = _walk(item, key_stack)
                key_stack.pop()
            return transformed
        return transform(node, key_stack)

    return _walk(args, [])


def parameter_adapter(func):
    """Decorator fixing up arguments of hooked tensor operators.

    For ``__getitem__`` with a tensor index it reproduces advanced-indexing
    semantics (bool masks via masked_select/nonzero, integer index tensors via
    tolist/stack); for ``__eq__`` against None it returns False directly.
    """

    @wraps(func)
    def inner(self, *args, **kwargs):
        if self.op_name_ == "__getitem__" and len(args) > 1 and isinstance(args[1], torch.Tensor):
            input_tensor = args[0]
            indices = args[1]
            if indices.dtype == torch.uint8:
                # uint8 indexing is deprecated torch behavior; treat as bool
                indices = indices.bool()
            if indices.dtype == torch.bool:
                if indices.shape == input_tensor.shape:
                    return getattr(torch._C._VariableFunctionsClass, "masked_select")(input_tensor, indices)
                else:
                    indices = getattr(torch._C._VariableFunctionsClass, "nonzero")(indices, as_tuple=True)
                    return getattr(torch._C._TensorBase, "__getitem__")(input_tensor, indices)
            elif indices.dtype != torch.bool:
                if not indices.shape or len(indices.shape) == 1:
                    return func(self, input_tensor, indices.tolist())
                elif len(indices.shape) == 2:
                    result = [func(self, input_tensor, index) for index in indices.tolist()]
                    return getattr(torch._C._VariableFunctionsClass, "stack")(result, 0)
                else:
                    res = [input_tensor[tensor_index] for tensor_index in indices]
                    return getattr(torch._C._VariableFunctionsClass, "stack")(res, 0)
        if self.op_name_ == "__eq__" and args[1] is None:
            return False
        return func(self, *args, **kwargs)
    return inner


def torch_device_guard(func):
    """Wrap *func* with torch_npu's device guard when required by the
    installed torch/torch_npu combination; otherwise return it unchanged."""
    if is_gpu or torch_without_guard_version:
        return func
    # Parse args/kwargs matched torch.device objects

    @torch_npu_device_guard
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper


def get_rank_if_initialized():
    """Return the distributed rank, or None when torch.distributed is not
    initialized (single-process runs)."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return None
class Const:
    """
    Class for const
    """
    # separator used when composing/splitting dotted api names
    SEP = "."
    MODEL_TYPE = ['.onnx', '.pb', '.om']
    DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*"
    SEMICOLON = ";"
    COLON = ":"
    EQUAL = "="
    COMMA = ","
    DOT = "."
    DUMP_RATIO_MAX = 100
    SUMMERY_DATA_NUMS = 256
    FLOAT_EPSILON = np.finfo(float).eps
    SUPPORT_DUMP_MODE = ['api', 'acl']
    ON = 'ON'
    OFF = 'OFF'
    BACKWARD = 'backward'
    FORWARD = 'forward'
    PRE_FORWARD = "pre_forward"

    # dump mode
    ALL = "all"
    LIST = "list"
    RANGE = "range"
    STACK = "stack"
    ACL = "acl"
    API_LIST = "api_list"
    API_STACK = "api_stack"
    DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK]
    AUTO = "auto"
    ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF]
    SUMMARY = "summary"
    MD5 = "md5"
    SUMMARY_MODE = [ALL, SUMMARY, MD5]

    # flags/permissions used when creating dump output files
    WRITE_FLAGS = os.O_WRONLY | os.O_CREAT
    WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR

    PKL_SUFFIX = ".pkl"
    NUMPY_SUFFIX = ".npy"
    ONE_GB = 1 * 1024 * 1024 * 1024
    TEN_GB = 10 * 1024 * 1024 * 1024
    # path validation limits (see also FileCheckConst in file_check.py)
    FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$'
    FILE_NAME_LENGTH = 255
    DIRECTORY_LENGTH = 4096
    DISTRIBUTED_PREFIX_LENGTH = 60
    SUMMARY_COLUMN_NUM = 6
    STACK_COLUMN_NUM = 2
    # env dump path
    ASCEND_WORK_PATH = "ASCEND_WORK_PATH"
    DUMP_DIR = "dump_data"

    ENV_ENABLE = "1"
    ENV_DISABLE = "0"

    # largest value accepted by torch.manual_seed (2**32 - 1)
    MAX_SEED_VALUE = 2**32 - 1

    # distributed collectives that mutate their input in place
    INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter",
                    "_reduce_scatter_base", "_all_gather_base"]
class DebuggerConfig:
    """Validated container for the user-facing precision-debugger options.

    Parameters:
        dump_path: directory the dump artifacts are written to.
        task: dump task name (e.g. summary/tensor/overflow).
        level: dump granularity selector (opaque here; consumed by Service).
        scope / api_list: optional filters limiting which apis are dumped.
        on_step_end: optional callback invoked at step end.
        rank: dump only this rank when given; must be a non-negative int.
        step: list of step indices to dump; sorted ascending after validation.
        repair_type / repair_scope / repair_api_str: repair-feature options.
        task_config: task-specific sub-configuration object.

    Raises:
        ValueError: when rank or step fail validation.
    """

    def __init__(self, dump_path, task, level=None, scope=None, api_list=None, on_step_end=None,
                 rank=None, step=None, repair_type=None, repair_scope=None, repair_api_str=None,
                 task_config=None):
        # Bug fix: scope/api_list used mutable [] defaults, so every instance
        # created without them shared (and could corrupt) the same list.
        self.task_config = task_config
        self.dump_path = dump_path
        self.task = task
        self.rank = rank
        self.step = step if step is not None else []
        self.scope = scope if scope is not None else []
        self.level = level
        self.api_list = api_list if api_list is not None else []
        self.repair_type = repair_type
        self.repair_scope = repair_scope
        self.repair_api_str = repair_api_str
        self.on_step_end = on_step_end

        self.check()
        if self.step:
            self.step.sort()

    def check(self):
        """Validate rank and step; returns True when the config is acceptable."""
        # self._check_hook_name()  # disabled upstream; see note below
        self._check_rank()
        self._check_step()
        return True

    def _check_hook_name(self):
        # NOTE(review): dead code — self.hook_name is never assigned anywhere in
        # this class, so calling this would raise AttributeError. Kept (still
        # unused) for interface compatibility; confirm before re-enabling.
        if self.hook_name not in ["dump", "overflow_check"]:
            raise ValueError(f"hook_name should be in ['dump', 'overflow_check'], got {self.hook_name}")

    def _check_rank(self):
        if self.rank is not None:
            if not isinstance(self.rank, int) or self.rank < 0:
                # message fixed: rank 0 is accepted, so the requirement is
                # "non-negative", not "positive"
                raise ValueError(f"rank {self.rank} must be a non-negative integer.")
            else:
                print_warn_log_rank_0(f"Rank argument is provided. "
                                      f"Only rank {self.rank} data will be dumped.")

    def _check_step(self):
        if not isinstance(self.step, list):
            raise ValueError(f"step {self.step} should be list")
        for s in self.step:
            if not isinstance(s, int):
                raise ValueError(f"step element {s} should be int")


class PrecisionDebugger:
    """Singleton entry point wiring DebuggerConfig to the dump Service.

    The first construction creates the config and service; later constructions
    return the same instance unchanged.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(PrecisionDebugger, cls).__new__(cls)
            cls._instance.config = None
            cls._instance.model = None
            cls._instance.enable_dataloader = False
        return cls._instance

    def __init__(self, *args, **kwargs):
        # Guard so repeated construction of the singleton does not re-run setup.
        if not hasattr(self, 'initialized'):
            self.initialized = True
            self.config = DebuggerConfig(*args, **kwargs)
            # NOTE(review): reconstructed placement — Service creation is assumed
            # to belong inside this first-init guard; confirm against upstream.
            self.service = Service(self.config)  # todo: enable dataloader support

    @classmethod
    def start(cls, model):
        """Begin collecting on *model*; no-op (with a warning) in dataloader mode."""
        instance = cls._instance
        if not instance:
            raise Exception("No instance of PrecisionDebugger found.")
        if instance.enable_dataloader:
            print_warn_log_rank_0("DataLoader is enabled, start() skipped.")
        else:
            instance.service.start(model)

    @classmethod
    def stop(cls):
        """Stop collecting; no-op (with a warning) in dataloader mode."""
        instance = cls._instance
        if not instance:
            raise Exception("PrecisionDebugger instance is not created.")
        if instance.enable_dataloader:
            print_warn_log_rank_0("DataLoader is enabled, stop() skipped.")
        else:
            instance.service.stop()

    @classmethod
    def step(cls):
        """Advance the service to the next training step."""
        if not cls._instance:
            raise Exception("PrecisionDebugger instance is not created.")
        cls._instance.service.step()
def build_collect_data(config):
    """Factory: build the DataCollector for this debugger config."""
    return DataCollector(config)


class DataCollector:
    """Routes hooked module/api inputs and outputs into the dump writers.

    Holds a DataProcessor (statistics/tensor/overflow, per config.task), a
    DataWriter for the json outputs, and an optional scope filter restricting
    which api/module names are collected.
    """

    overflow_task = "overflow"
    tasks_need_tensor_data = ["overflow", "tensor"]
    # at this level no module-construct hierarchy is recorded
    level_without_construct = "API"

    def __init__(self, config):
        self.config = config
        self.data_writer = DataWriter()
        self.data_processor = build_data_processor(config.task, config.task_config, self.data_writer)
        # per-module forward/backward pairing counters, see module_count_func
        self.module_count = {}
        self.scope = build_scope(None, self.config.scope, self.config.api_list)

    @property
    def dump_data_dir(self):
        # directory real tensor data is written to (None for summary task)
        return self.data_writer.dump_tensor_data_dir

    @property
    def dump_file_path(self):
        return self.data_writer.dump_file_path

    def write_json(self):
        """Flush all cached dump/stack/construct data to their json files."""
        self.data_writer.write_json()

    def __call__(self, name_template, module_type, module, pid, module_input_output):
        """Hook entry: analyze one forward/backward event and cache the result.

        Only events from the original process (pid match) are analyzed; for the
        overflow task, data is kept only when an overflow was detected.
        """
        if module_type == BaseScope.Module_Type_Module:
            # modules carry their resolved dump name on the instance
            name = module.mindstudio_reserved_name
        else:
            name = name_template

        if self.config.level != DataCollector.level_without_construct:
            # record the module hierarchy for non-API levels
            self.data_writer.update_construct({name: ModuleProcesser.api_parent_node})
            self.data_writer.update_construct(ModuleProcesser.module_node)
        if not self.scope or self.scope.check(name):
            msg = f"Calibrator is collecting data on {name}. "
            if pid == os.getpid():
                if "forward" in name:
                    data_info = self.data_processor.analyze_forward(name, module_input_output)
                    # call stacks are only meaningful on the forward pass
                    self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name))
                else:
                    data_info = self.data_processor.analyze_backward(name, module_input_output)
                if self.config.task == DataProcessor.overflow:
                    # overflow task: analyze_* returns None when no overflow occurred
                    if data_info:
                        self.data_writer.update_data(data_info)
                        msg += "Overflow detected."
                    else:
                        msg += "No Overflow, OK."
                else:
                    self.data_writer.update_data(data_info)
                print_info_log(msg)

    def module_count_func(self, name, name_template):
        """Return the invocation index pairing a backward event with its forward.

        Forward events push indices onto a per-module stack; backward events pop
        them (LIFO, matching autograd order). Returns "abnormal" when a backward
        arrives with no matching forward recorded.
        """
        module_name = name.split(Const.SEP)[-3]
        if "forward" in name_template:
            if module_name not in self.module_count:
                self.module_count[module_name] = [0, [0]]
            else:
                if self.module_count[module_name][-1] and \
                        self.module_count[module_name][0] != self.module_count[module_name][-1][-1]:
                    self.module_count[module_name][-1].pop()
                self.module_count[module_name][0] += 1
                self.module_count[module_name][-1].append(self.module_count[module_name][0])
            index = self.module_count[module_name][0]
        else:
            backward_stack = self.module_count[module_name][-1] if module_name in self.module_count else []
            if not backward_stack:
                index = "abnormal"
            else:
                index = backward_stack.pop()
        return index

    def update_dump_paths(self, *args):
        """Forward the four output paths to the writer and (re)initialize files."""
        self.data_writer.update_dump_paths(*args)
        self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level)

    def generate_compare_script(self):
        """Render compare_data.py next to the dump file from the bundled template.

        NOTE(review): compares config.task against Const.API_STACK, but task
        values elsewhere are tensor/summary/overflow — confirm this condition.
        """
        template_path = os.path.join(os.path.dirname(__file__), '../..', 'common', "compare_script.template")
        pkl_dir = os.path.dirname(self.dump_file_path)
        compare_script_path = os.path.join(pkl_dir, "compare_data.py")
        is_api_stack = "True" if self.config.task == Const.API_STACK else "False"

        try:
            with FileOpen(template_path, 'r') as ftemp, \
                    os.fdopen(os.open(compare_script_path, Const.WRITE_FLAGS, Const.WRITE_MODES), 'w+') as fout:
                code_temp = ftemp.read()
                fout.write(code_temp % (self.dump_file_path, self.dump_data_dir, is_api_stack))
        except OSError:
            print_error_log_rank_0(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.")

        print_info_log_rank_0(f"Generate compare script successfully which is {compare_script_path}.")


def build_data_processor(task, task_config, data_writer):
    """Factory mapping a task name to its DataProcessor implementation."""
    if task == DataProcessor.full:
        return FullTensorDataProcessor(task_config, data_writer)
    elif task == DataProcessor.summary:
        return DataProcessor(task_config, data_writer)
    elif task == DataProcessor.overflow:
        return OverflowTensorDataProcessor(task_config, data_writer)
    else:
        raise MsaccException(MsaccException.INVALID_PARAM_ERROR,
                             "task should be in [{}, {}, {}]".format(
                                 DataProcessor.full,
                                 DataProcessor.summary,
                                 DataProcessor.overflow
                             ))
@dataclass
class ModuleForwardInputsOutputs:
    """Normalized forward-hook payload; args/output are always tuples."""
    # NOTE: the hand-written __init__ overrides the dataclass-generated one on
    # purpose, to coerce single values into 1-tuples.
    args: Optional[Tuple]
    kwargs: Optional[Dict]
    output: Union[Tuple, torch.Tensor]

    def __init__(self, args, kwargs, output):
        if not isinstance(args, tuple):
            args = (args, )
        if not isinstance(output, tuple):
            output = (output, )
        self.args = args
        self.kwargs = kwargs
        self.output = output


@dataclass
class ModuleBackwardInputsOutputs:
    """Normalized backward-hook payload; both grads are always tuples."""
    grad_output: Optional[Tuple]
    grad_input: Optional[Tuple]

    def __init__(self, grad_input, grad_output):
        if not isinstance(grad_input, tuple):
            grad_input = (grad_input, )
        if not isinstance(grad_output, tuple):
            grad_output = (grad_output,)
        self.grad_input = grad_input
        self.grad_output = grad_output


class DataProcessor:
    """Summary-statistics processor: converts hooked api arguments into
    json-friendly descriptors (type/dtype/shape/Max/Min/Mean/Norm).

    Subclasses add full tensor dumping (FullTensorDataProcessor) and
    overflow-gated dumping (OverflowTensorDataProcessor).
    """

    # task names understood by build_data_processor
    full = "tensor"
    summary = "summary"
    overflow = "overflow"

    def __init__(self, task_config, data_writer):
        self.data_writer = data_writer
        self.api_info_struct = {}
        self.stack_info_struct = {}
        # kwargs keys that need special (non-tensor) serialization
        self.torch_object_key = {
            "device": self.analyze_device_in_kwargs,
            "dtype": self.analyze_dtype_in_kwargs
        }
        self.api_name = None
        self.task_config = task_config
        self.api_data_category = None
        self.has_overflow = False

    @staticmethod
    def get_md5_for_tensor(x):
        """CRC32 checksum (hex) of the tensor's raw bytes; bfloat16 is widened
        to float32 first because numpy cannot represent it."""
        if x.dtype == torch.bfloat16:
            x = x.float()
        tensor_bytes = x.cpu().detach().numpy().tobytes()
        crc32_hash = zlib.crc32(tensor_bytes)
        return f"{crc32_hash:08x}"

    @staticmethod
    def analyze_device_in_kwargs(element):
        """Serialize a torch.device (or device string) kwarg."""
        single_arg = {}
        single_arg.update({'type': "torch.device"})
        if not isinstance(element, str):
            if hasattr(element, "index"):
                device_value = element.type + ":" + str(element.index)
            else:
                device_value = element.type
            single_arg.update({"value": device_value})
        else:
            single_arg.update({"value": element})
        return single_arg

    @staticmethod
    def analyze_dtype_in_kwargs(element):
        """Serialize a torch.dtype kwarg as its string form."""
        single_arg = {}
        single_arg.update({"type": "torch.dtype"})
        single_arg.update({"value": str(element)})
        return single_arg

    @staticmethod
    def _convert_numpy_to_builtin(arg):
        """Convert a numpy scalar to the matching Python builtin.

        Returns (converted_value, numpy_type_name); non-numpy values come back
        unchanged with an empty type name.

        Fixes: np.unicode_ removed (deleted in NumPy 2.0 and an alias of
        np.str_ anyway); np.byte (an int8 alias, shadowed by np.integer and
        wrongly mapped to bytes) replaced by np.bytes_.
        """
        type_mapping = {
            np.integer: int,
            np.floating: float,
            np.bool_: bool,
            np.complexfloating: complex,
            np.str_: str,
            np.bytes_: bytes,
        }
        for numpy_type, builtin_type in type_mapping.items():
            if isinstance(arg, numpy_type):
                return builtin_type(arg), type(arg).__name__
        return arg, ''

    def _analyze_numpy(self, value, numpy_type):
        single_arg = {}
        single_arg.update({"type": numpy_type})
        single_arg.update({"value": value})
        return single_arg

    def get_stat_info(self, data):
        """Return (max, min, mean, norm) of a tensor.

        Empty tensors yield all-None; bool tensors report presence of
        True/absence of False as max/min; 0-dim tensors report the scalar for
        all four slots; integer tensors are cast to float for the statistics.
        """
        if data.is_meta:
            # Bug fix: meta tensors carry no storage; the original returned a
            # bare None, which crashed the 4-tuple unpack in _analyze_tensor.
            return None, None, None, None
        data_clone = data.detach()
        if data_clone.numel() == 0:
            tensor_max = None
            tensor_min = None
            tensor_mean = None
            tensor_norm = None
        elif data_clone.dtype == torch.bool:
            tensor_max = True in data_clone
            tensor_min = False not in data_clone
            tensor_mean = None
            tensor_norm = None
        elif not len(data_clone.shape):
            tensor_max = data_clone.item()
            tensor_min = tensor_max
            tensor_mean = tensor_max
            tensor_norm = tensor_max
        else:
            if not data_clone.is_floating_point():
                data_clone = data_clone.float()
            tensor_max = torch._C._VariableFunctionsClass.max(data_clone).item()
            tensor_min = torch._C._VariableFunctionsClass.min(data_clone).item()
            tensor_mean = torch._C._VariableFunctionsClass.mean(data_clone).item()
            tensor_norm = torch._C._VariableFunctionsClass.norm(data_clone).item()

        return tensor_max, tensor_min, tensor_mean, tensor_norm

    def _analyze_builtin(self, arg):
        """Serialize a Python builtin (slice gets [start, stop, step])."""
        single_arg = {}
        if isinstance(arg, slice):
            single_arg.update({"type": "slice"})
            single_arg.update({"value": [arg.start, arg.stop, arg.step]})
        else:
            single_arg.update({"type": type(arg).__name__})
            single_arg.update({"value": arg})
        return single_arg

    @staticmethod
    def handle_tensor_extremum_nan_inf(data_clone, operator):
        """Extremum ('max'/'min') ignoring inf/nan where possible: prefer
        finite values, fall back to non-nan, and nan when everything is nan."""
        data_nan = torch._C._VariableFunctionsClass.isnan(data_clone)
        if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel():
            return float('nan')
        finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone)
        if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0:
            finite_values = data_clone[finite_mask]
            return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \
                torch._C._VariableFunctionsClass.min(finite_values).item()
        else:
            data_no_nan = data_clone[~data_nan]
            return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \
                torch._C._VariableFunctionsClass.min(data_no_nan).item()

    def _analyze_maybe_overflow_tensor(self, tensor_json, tensor):
        """Flag overflow (inf/nan extrema) and record the inf/nan-free extrema."""
        if tensor_json['Max'] is None and tensor_json['Min'] is None:
            # Bug fix: empty tensors have None extrema; np.isinf(None) raised.
            return
        if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']):
            tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max")
            self.has_overflow = True
        if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']):
            tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min")
            self.has_overflow = True

    def _analyze_tensor(self, tensor, suffix):
        """Serialize one tensor to its statistics descriptor (md5 optional)."""
        tensor_max, tensor_min, tensor_mean, tensor_norm = self.get_stat_info(tensor)

        tensor_json = {}
        tensor_json.update({'type': 'torch.Tensor'})
        tensor_json.update({'dtype': str(tensor.dtype)})
        tensor_json.update({"shape": tensor.shape})
        tensor_json.update({"Max": tensor_max})
        tensor_json.update({"Min": tensor_min})
        self._analyze_maybe_overflow_tensor(tensor_json, tensor)
        tensor_json.update({"Mean": tensor_mean})
        tensor_json.update({"Norm": tensor_norm})
        tensor_json.update({"requires_grad": tensor.requires_grad})
        if self.task_config.md5:
            tensor_md5 = self.get_md5_for_tensor(tensor)
            tensor_json.update({"md5": tensor_md5})

        return tensor_json

    def analyze_single_element(self, element, suffix_stack):
        """Dispatch one leaf value to the matching serializer.

        Returns None (dropped from the dump) for unrecognized types.
        """
        if suffix_stack and suffix_stack[-1] in self.torch_object_key:
            return self.torch_object_key[suffix_stack[-1]](element)

        converted_numpy, numpy_type = self._convert_numpy_to_builtin(element)
        if converted_numpy is not element:
            return self._analyze_numpy(converted_numpy, numpy_type)

        if isinstance(element, torch.Tensor):
            return self._analyze_tensor(element, Const.SEP.join(suffix_stack))

        if isinstance(element, (bool, int, float, str, slice)):
            return self._analyze_builtin(element)

    def analyze_element(self, element):
        """Recursively serialize a nested args/kwargs structure."""
        return recursive_apply_transform(element, self.analyze_single_element)

    @staticmethod
    def analyze_api_call_stack(name):
        """Capture the user-side call stack for *name*.

        The [5:] slice skips the hook/dispatch frames above the user code —
        NOTE(review): depends on the exact hook call depth; confirm if the
        hooking machinery changes.
        """
        stack_str = []
        for (_, path, line, func, code, _) in inspect.stack()[5:]:
            if not code:
                continue
            stack_line = " ".join([
                "File", ", ".join([
                    path,
                    " ".join(["line", str(line)]),
                    " ".join(["in", func]),
                    " ".join(["\n", code[0].strip()])
                ])
            ])
            stack_str.append(stack_line)
        stack_info_struct = {name: stack_str}
        return stack_info_struct

    def analyze_forward(self, name,
                        module_input_output: ModuleForwardInputsOutputs):
        """Serialize a forward event into {name: {input_args, input_kwargs, output}}."""
        self.api_name = name
        self.api_data_category = "input"
        args_info_list = self.analyze_element(module_input_output.args)
        self.api_data_category = "kwargs"
        kwargs_info_list = self.analyze_element(module_input_output.kwargs)
        self.api_data_category = "output"
        output_info_list = self.analyze_element(module_input_output.output)
        api_info_struct = {name: {"input_args": args_info_list,
                                  "input_kwargs": kwargs_info_list,
                                  "output": output_info_list}}
        return api_info_struct

    def analyze_backward(self, name,
                         module_input_output: ModuleBackwardInputsOutputs):
        """Serialize a backward event into {name: {grad_input, grad_output}}."""
        self.api_name = name
        self.api_data_category = "output"
        input_info_list = self.analyze_element(module_input_output.grad_input)
        self.api_data_category = "input"
        output_info_list = self.analyze_element(module_input_output.grad_output)
        api_info_struct = {name: {"grad_input": input_info_list, "grad_output": output_info_list}}  # TODO: magic str
        return api_info_struct
class OverflowTensorDataProcessor(FullTensorDataProcessor):
    """Tensor dumper that persists tensors only for steps containing an overflow.

    Tensors are cached in memory during forward/backward analysis; once
    self.has_overflow is raised (by _analyze_maybe_overflow_tensor), the cache
    is flushed to disk via save_overflow_data(). Steps without overflow write
    nothing and return None.
    """
    # NOTE(review): __slots__ has no memory effect here because the base
    # classes still provide a per-instance __dict__; kept for compatibility.
    __slots__ = ["cached_tensors_and_file_paths"]

    def __init__(self, task_config, data_writer):
        super().__init__(task_config, data_writer)
        # Maps target file path -> tensor awaiting the overflow decision.
        self.cached_tensors_and_file_paths = {}

    def _analyze_tensor(self, tensor, suffix):
        """Record tensor stats and cache the tensor for deferred dumping."""
        self.data_path = self.data_writer.dump_tensor_data_dir
        dump_data_name = (self.api_name + Const.SEP + self.api_data_category + Const.SEP
                          + suffix + ".pt")
        file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name)
        self.cached_tensors_and_file_paths[file_path] = tensor
        # Bug fix: call DataProcessor's stats-only analysis directly. Going
        # through FullTensorDataProcessor._analyze_tensor would torch.save()
        # immediately, defeating this class's "persist only on overflow" contract.
        single_arg = DataProcessor._analyze_tensor(self, tensor, suffix)
        single_arg.update({"data_name": dump_data_name})
        # Bug fix: the stats dict was previously not returned (callers got None).
        return single_arg

    def analyze_forward(self, name,
                        module_input_output: ModuleForwardInputsOutputs):
        """Analyze one forward call; return its info only if an overflow occurred."""
        self.has_overflow = False
        api_info_struct = super().analyze_forward(name, module_input_output)
        if self.has_overflow:
            self.save_overflow_data()
            return api_info_struct
        # Bug fix: discard cached tensors of overflow-free calls so a later
        # overflow does not flush stale data from earlier calls.
        self.cached_tensors_and_file_paths = {}
        return None

    def analyze_backward(self, name,
                         module_input_output: ModuleBackwardInputsOutputs):
        """Analyze one backward call; return its info only if an overflow occurred."""
        self.has_overflow = False
        api_info_struct = super().analyze_backward(name, module_input_output)
        if self.has_overflow:
            self.save_overflow_data()
            return api_info_struct
        self.cached_tensors_and_file_paths = {}
        return None

    def save_overflow_data(self):
        """Persist every cached tensor to its pre-computed path and reset the cache."""
        for file_path, tensor in self.cached_tensors_and_file_paths.items():
            torch.save(tensor, file_path)
        self.cached_tensors_and_file_paths = {}
class DataWriter:
    """Buffers dump/stack/construct records in memory and persists them as JSON.

    Records accumulate in three caches and are flushed either explicitly
    (write_json) or automatically once the data cache reaches batch_size
    entries. Disk writes are guarded with fcntl advisory locks so concurrent
    writers do not interleave output.
    """

    def __init__(self, init_json=None) -> None:
        self.dump_count = 0
        self.init_json = init_json           # template used for a fresh dump.json
        self.dump_file_path = None           # all paths set via update_dump_paths()
        self.stack_file_path = None
        self.construct_file_path = None
        self.dump_tensor_data_dir = None
        self.batch_size = 1000               # auto-flush threshold for update_data()
        self.cache_data = {"data": {}}
        self.cache_stack = {}
        self.cache_construct = {}

    def initialize_json_file(self, **kwargs):
        """Seed dump.json with *kwargs* plus dump metadata; truncate the stack/construct files."""
        kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, "data": {}})
        with open(self.dump_file_path, 'w') as dump_file:
            json.dump(kwargs, dump_file)
        for side_path in (self.stack_file_path, self.construct_file_path):
            if os.path.exists(side_path):
                os.remove(side_path)
            Path(side_path).touch()

    def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir):
        """Record where the three JSON files and the tensor directory live."""
        self.dump_file_path = dump_file_path
        self.stack_file_path = stack_file_path
        self.construct_file_path = construct_file_path
        self.dump_tensor_data_dir = dump_data_dir

    def update_data(self, new_data):
        """Merge *new_data* into the data cache; flush once batch_size entries accumulate."""
        cached = self.cache_data["data"]
        cached.update(new_data)
        if len(cached) >= self.batch_size:
            self.write_data_json(self.dump_file_path)

    def update_stack(self, new_data):
        self.cache_stack.update(new_data)

    def update_construct(self, new_data):
        self.cache_construct.update(new_data)

    def write_data_json(self, file_path):
        """Read-modify-write the cached data into *file_path* under a lock, then clear the cache."""
        import fcntl
        print_info_log_rank_0(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ")
        if Path(file_path).exists() and os.path.getsize(file_path) > 0:
            with open(file_path, "r+") as dump_file:
                fcntl.flock(dump_file, fcntl.LOCK_EX)
                payload = json.load(dump_file)
                fcntl.flock(dump_file, fcntl.LOCK_UN)
        else:
            # First write: start from the init template.
            self.init_json['data_path'] = self.dump_tensor_data_dir
            payload = self.init_json
        payload['data'].update(self.cache_data['data'])
        with open(file_path, 'w+') as dump_file:
            fcntl.flock(dump_file, fcntl.LOCK_EX)
            json.dump(payload, dump_file, indent=1)
            fcntl.flock(dump_file, fcntl.LOCK_UN)
        self.cache_data["data"].clear()

    def write_stack_info_json(self, file_path):
        """Dump the cached call-stack records to *file_path* under an advisory lock."""
        import fcntl
        with open(file_path, 'w+') as stack_file:
            fcntl.flock(stack_file, fcntl.LOCK_EX)
            json.dump(self.cache_stack, stack_file, indent=1)
            fcntl.flock(stack_file, fcntl.LOCK_UN)

    def write_construct_info_json(self, file_path):
        """Dump the cached module-construct records to *file_path* under an advisory lock."""
        import fcntl
        with open(file_path, 'w+') as construct_file:
            fcntl.flock(construct_file, fcntl.LOCK_EX)
            json.dump(self.cache_construct, construct_file, indent=1)
            fcntl.flock(construct_file, fcntl.LOCK_UN)

    def write_json(self):
        """Flush all three caches to their respective files."""
        self.write_data_json(self.dump_file_path)
        self.write_stack_info_json(self.stack_file_path)
        self.write_construct_info_json(self.construct_file_path)
class RepairAPI(ABC):
    """Base class for precision-repair strategies applied around targeted APIs.

    Subclasses provide fx (a transform applied to the API's inputs) and
    inv_fx (the inverse transform applied to its outputs).
    """

    ToCPU = "cpu"
    RaisePrecision = "raise"

    def __init__(self, config):
        self.config = config
        self.scope = build_scope(ListScope, config.repair_scope, config.repair_api_str)
        # Placeholders for the device/dtype recorded by fx and restored by inv_fx.
        self.saved, self.towards = "None", "None"

    def check_name_and_module_type(self, name, module_type):
        """Repair applies only to API-level hooks whose name is inside the scope."""
        if module_type == BaseScope.Module_Type_Module:
            return False
        return bool(self.scope.check(name))

    def convert(self, name, module_type, args, kwargs):
        """Apply fx to args/kwargs when *name* is a targeted API call."""
        if self.check_name_and_module_type(name, module_type):
            args = recursive_apply_transform(args, self.fx)
            kwargs = recursive_apply_transform(kwargs, self.fx)
            print_info_log_rank_0(f"[calibrator] convert inputs of {name} to "
                                  f"{self.towards}.")
        return args, kwargs

    def invert(self, name, module_type, out_feat):
        """Apply inv_fx to the outputs when *name* is a targeted API call."""
        if self.check_name_and_module_type(name, module_type):
            out_feat = recursive_apply_transform(out_feat, self.inv_fx)
            print_info_log_rank_0(f"[calibrator] convert outputs of {name} back to "
                                  f"{self.saved}.")
        return out_feat


class RepairAPI_toCPU(RepairAPI):
    """Repair strategy: run the targeted API on CPU, then restore the original device."""

    def fx(self, arg, _):
        # Forward transform: remember the source device, hand back a CPU copy.
        if not isinstance(arg, torch.Tensor):
            return arg
        self.saved = arg.device
        self.towards = torch.device("cpu")
        return arg.cpu()

    def inv_fx(self, arg, _):
        # Inverse transform: move tensors back to the device recorded by fx.
        return arg.to(self.saved) if isinstance(arg, torch.Tensor) else arg
def build_scope(scope_class, scope=None, api_list=None):
    """Build a scope object restricting which modules/APIs get dumped.

    Bug fix: the defaults were mutable lists (scope=[], api_list=[]), shared
    across calls; they are now None and normalized locally.

    Returns None when neither a scope nor an api_list was configured; otherwise
    an instance of *scope_class* (when given) or the range scope matching the
    scope names.
    """
    scope = [] if scope is None else scope
    api_list = [] if api_list is None else api_list
    if not scope and not api_list:
        return None
    if scope_class:
        return scope_class(scope, api_list)
    return build_range_scope_according_to_scope_name(scope, api_list)


def build_range_scope_according_to_scope_name(scope, api_list):
    """Pick APIRangeScope or ModuleRangeScope based on which one the scope names validate for."""
    api_range_scope = APIRangeScope(scope, api_list)
    module_range_scope = ModuleRangeScope(scope, api_list)
    if not scope:  # without a scope argument both range-scope kinds behave identically
        return api_range_scope
    if api_range_scope.is_valid and module_range_scope.is_valid:
        # Ambiguous: the endpoints validate as both API and Module names.
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
    elif api_range_scope.is_valid:
        return api_range_scope
    elif module_range_scope.is_valid:
        return module_range_scope
    else:
        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")


class BaseScope(ABC):
    """Common validation and api_list matching for all scope kinds."""

    Module_Type_Module = "Module"
    Module_Type_API = "api"

    @staticmethod
    def rectify_args(scope, api_list):
        """Validate and normalize (scope, api_list); raises ScopeException on bad types."""
        if not isinstance(api_list, list):
            raise ScopeException(ScopeException.InvalidApiStr,
                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
        # Bug fix: the loop variable used to be named api_list, shadowing and
        # rebinding the parameter to its last element.
        for api_name in api_list:
            if not isinstance(api_name, str):
                raise ScopeException(ScopeException.InvalidApiStr,
                                     f"api_list中的元素须配置为字符串,实际类型为{type(api_name)}.")
        if isinstance(scope, str):
            scope = [scope]
            return scope, api_list
        if not isinstance(scope, list):
            raise ScopeException(ScopeException.InvalidScope,
                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
        for s in scope:
            if not isinstance(s, str):
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
        return scope, api_list

    def __init__(self, scope, api_list):
        scope, api_list = self.rectify_args(scope, api_list)
        self.scope = scope
        self.api_list = api_list

    def check_api_list(self, api_name):
        """True when api_list is empty or any configured substring occurs in *api_name*."""
        if not self.api_list:
            return True
        for api_str in self.api_list:
            if api_str in api_name:
                return True
        # Bug fix: previously fell through returning None instead of False.
        return False

    @abstractmethod
    def check(self, name):
        pass


class ListScope(BaseScope):
    """Scope given as an explicit list of module/API names (or an api_list filter)."""

    @staticmethod
    def rectify_args(scope, api_list):
        if scope and api_list:
            raise ScopeException(ScopeException.ArgConflict,
                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
        return super(ListScope, ListScope).rectify_args(scope, api_list)

    def check(self, module_name):
        """Match when the name is listed (or no list was given) and passes the api_list filter."""
        if not self.scope or module_name in self.scope:
            return self.check_api_list(module_name)
        return False


class RangeScope(BaseScope, ABC):
    """Scope given as a [start, stop] interval of names; subclasses define validity."""

    @staticmethod
    def rectify_args(scope, api_list):
        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
        if isinstance(scope, list):
            if len(scope) == 1:
                # A single endpoint means a degenerate [x, x] interval.
                scope.append(scope[0])
            elif len(scope) > 2:
                raise ScopeException(ScopeException.InvalidScope,
                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
        return scope, api_list

    @abstractmethod
    def check_scope_is_valid(self):
        pass

    def __init__(self, *args):
        super().__init__(*args)
        self.in_scope = False
        self.is_valid = self.check_scope_is_valid()

    def begin_module(self, module_name):
        pass

    def end_module(self, module_name):
        pass
class ModuleRangeScope(RangeScope):
    """Range scope for nn.Module dumps.

    Unlike APIs, modules contain sub-structures that must also be dumped, so
    the open/close of the interval is driven precisely by module pre-hooks and
    full-backward hooks calling begin_module / end_module.
    """

    def check_scope_is_valid(self):
        """Valid when both endpoints name a Module (or no scope was given)."""
        if not self.scope:
            return True
        start_is_module = self.scope[0].split(Const.SEP)[0] == BaseScope.Module_Type_Module
        stop_is_module = self.scope[1].split(Const.SEP)[0] == BaseScope.Module_Type_Module
        return start_is_module and stop_is_module

    def begin_module(self, module_name):
        # Open the interval when entering the start module.
        if self.scope and module_name == self.scope[0]:
            self.in_scope = True

    def end_module(self, module_name):
        # Close the interval when leaving the stop module.
        if self.scope and module_name == self.scope[1]:
            self.in_scope = False

    def check(self, module_name):
        if not self.scope or self.in_scope:
            return self.check_api_list(module_name)
        return False
class StepPostProcess(ABC):
    """Namespace for the supported step-end post-processing modes."""
    SingleAPICheck = 'single_api_check'
    Compare = 'compare'


class SingleAPICheck:
    """Step-end action: re-run this step's dumped APIs as single-op UT checks."""

    def __init__(self, config):
        self.config = config

    def run(self):
        run_parallel_ut(self.config)


class AutoCompare:
    """Step-end action: compare this step's dump against a benchmark dump."""

    def __init__(self, config):
        self.config = config

    def run(self):
        # NOTE(review): compare_distrbuted (sic) is declared above as a
        # one-argument stub but is called here with two arguments — confirm
        # the intended signature once the stub is implemented.
        compare_distrbuted(self.config.bench_dump_path, self.config.dump_path)
class ApiRegistry:
    """Registry that swaps whole torch / torch_npu API families between their
    original implementations and hook-wrapped versions.

    For every family "<f>" two dicts are kept: "<f>_ori_attr" with the pristine
    callables and "<f>_hook_attr" with the wrapped replacements.
    """

    _FAMILIES = ("tensor", "torch", "functional", "distributed",
                 "npu_distributed", "vf", "aten", "torch_npu")

    def __init__(self):
        for family in ApiRegistry._FAMILIES:
            setattr(self, family + "_ori_attr", {})
            setattr(self, family + "_hook_attr", {})

    @staticmethod
    def store_ori_attr(ori_api_group, api_list, api_ori_attr):
        """Snapshot the original attributes named in *api_list* (one dotted level allowed)."""
        for api in api_list:
            if '.' not in api:
                api_ori_attr[api] = getattr(ori_api_group, api)
                continue
            owner_name, attr_name = api.rsplit('.', 1)
            owner = getattr(ori_api_group, owner_name)
            api_ori_attr[api] = getattr(owner, attr_name)

    @staticmethod
    def set_api_attr(api_group, attr_dict):
        """Install every attribute from *attr_dict* onto *api_group*.

        Dotted names are resolved one level deep; when the owner sub-module is
        missing the entry is skipped silently.
        """
        for api, api_attr in attr_dict.items():
            if '.' not in api:
                setattr(api_group, api, api_attr)
                continue
            owner_name, attr_name = api.rsplit('.', 1)
            owner = getattr(api_group, owner_name, None)
            if owner is not None:
                setattr(owner, attr_name, api_attr)

    def _apply_family_attrs(self, suffix):
        """Push one attribute family ('_hook_attr' or '_ori_attr') onto every wrapped namespace."""
        def family(name):
            return getattr(self, name + suffix)

        self.set_api_attr(torch.Tensor, family("tensor"))
        self.set_api_attr(torch, family("torch"))
        self.set_api_attr(torch.nn.functional, family("functional"))
        self.set_api_attr(dist, family("distributed"))
        self.set_api_attr(dist.distributed_c10d, family("distributed"))
        if not is_gpu and not torch_without_guard_version:
            self.set_api_attr(torch_npu.distributed, family("npu_distributed"))
            self.set_api_attr(torch_npu.distributed.distributed_c10d, family("npu_distributed"))
        if torch_version_above_2:
            self.set_api_attr(torch.ops.aten, family("aten"))
        self.set_api_attr(torch._VF, family("vf"))
        if not is_gpu:
            self.set_api_attr(torch_npu, family("torch_npu"))

    def api_modularity(self):
        """Replace the real APIs with their hook-wrapped versions."""
        self._apply_family_attrs("_hook_attr")

    def api_originality(self):
        """Restore the original, unwrapped APIs."""
        self._apply_family_attrs("_ori_attr")

    def initialize_hook(self, hook):
        """Wrap every supported API family with *hook*, recording originals and wrappers."""
        def collect_wrapped(hook_holder, dest):
            # Wrapped implementations are exposed as "wrap_<name>" attributes.
            for attr_name in dir(hook_holder):
                if attr_name.startswith("wrap_"):
                    dest[attr_name[5:]] = getattr(hook_holder, attr_name)

        self.store_ori_attr(torch.Tensor, get_tensor_ops(), self.tensor_ori_attr)
        wrap_tensor.wrap_tensor_ops_and_bind(hook)
        collect_wrapped(wrap_tensor.HOOKTensor, self.tensor_hook_attr)

        self.store_ori_attr(torch, get_torch_ops(), self.torch_ori_attr)
        wrap_torch.wrap_torch_ops_and_bind(hook)
        collect_wrapped(wrap_torch.HOOKTorchOP, self.torch_hook_attr)

        self.store_ori_attr(torch.nn.functional, get_functional_ops(), self.functional_ori_attr)
        wrap_functional.wrap_functional_ops_and_bind(hook)
        collect_wrapped(wrap_functional.HOOKFunctionalOP, self.functional_hook_attr)

        self.store_ori_attr(dist, get_distributed_ops(), self.distributed_ori_attr)
        wrap_distributed.wrap_distributed_ops_and_bind(hook)
        if not is_gpu and not torch_without_guard_version:
            self.store_ori_attr(torch_npu.distributed, npu_distributed_api, self.npu_distributed_ori_attr)
        for attr_name in dir(wrap_distributed.HOOKDistributedOP):
            if attr_name.startswith("wrap_"):
                op_name = attr_name[5:]
                wrapped = getattr(wrap_distributed.HOOKDistributedOP, attr_name)
                self.distributed_hook_attr[op_name] = wrapped
                # NPU collectives share the same wrapped implementation.
                if not is_gpu and not torch_without_guard_version and op_name in npu_distributed_api:
                    self.npu_distributed_hook_attr[op_name] = wrapped

        if torch_version_above_2:
            self.store_ori_attr(torch.ops.aten, get_aten_ops(), self.aten_ori_attr)
            wrap_aten.wrap_aten_ops_and_bind(hook)
            collect_wrapped(wrap_aten.HOOKAtenOP, self.aten_hook_attr)

        self.store_ori_attr(torch._VF, get_vf_ops(), self.vf_ori_attr)
        wrap_vf.wrap_vf_ops_and_bind(hook)
        collect_wrapped(wrap_vf.HOOKVfOP, self.vf_hook_attr)

        if not is_gpu:
            self.store_ori_attr(torch_npu, get_npu_ops(), self.torch_npu_ori_attr)
            wrap_npu_custom.wrap_npu_ops_and_bind(hook)
            collect_wrapped(wrap_npu_custom.HOOKNpuOP, self.torch_npu_hook_attr)


api_register = ApiRegistry()
class HOOKModule(nn.Module):
    """nn.Module shim wrapped around a single torch API call so dump hooks can
    observe its inputs/outputs, with per-thread re-entrancy protection so a
    hooked call cannot recursively trigger further hooks.
    """

    module_count = {}       # prefix -> number of instances seen, for unique naming
    inner_stop_hook = {}    # thread ident -> True while inside a hooked call

    def __init__(self, build_hook) -> None:
        super(HOOKModule, self).__init__()
        self.has_overflow = False
        self.prefix = ""
        self.current_thread = threading.current_thread().ident
        if self.current_thread not in HOOKModule.inner_stop_hook:
            HOOKModule.inner_stop_hook[self.current_thread] = False
        self.stop_hook = HOOKModule.inner_stop_hook.get(self.current_thread, False)
        if self.stop_hook:
            return
        if hasattr(self, "prefix_op_name_"):
            self.prefix = self.prefix_op_name_
        # Append a per-prefix ordinal so repeated calls get distinct dump names.
        seen = HOOKModule.module_count.get(self.prefix, 0)
        HOOKModule.module_count[self.prefix] = seen + 1
        self.prefix = self.prefix + str(seen) + Const.SEP
        pre_hook, fwd_hook, bwd_hook = build_hook(self.prefix)
        self.register_forward_pre_hook(pre_hook, with_kwargs=True)
        self.register_forward_hook(fwd_hook, with_kwargs=True)
        self.register_backward_hook(bwd_hook)

    def __call__(self, *input, **kwargs):
        entered = False
        if not self.stop_hook:
            # Block nested HOOKModule hook firing on this thread while we run.
            HOOKModule.inner_stop_hook[self.current_thread] = True
            entered = True
        result = self._call_func(*input, **kwargs)
        if entered:
            HOOKModule.inner_stop_hook[self.current_thread] = False
        return result

    def _call_func(self, *input, **kwargs):
        """Re-implementation of nn.Module's call path that forwards kwargs to
        the forward pre-hooks and forward hooks."""
        full_backward_hooks, non_full_backward_hooks = [], []
        if len(self._backward_hooks) > 0:
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
        for hook in self._forward_pre_hooks.values():
            hook_args, hook_kwargs = hook(self, input, kwargs)
            if hook_args is not None:
                input = hook_args if isinstance(hook_args, tuple) else (hook_args,)
            if hook_kwargs is not None:
                kwargs = hook_kwargs
        bw_hook = None
        if len(full_backward_hooks) > 0:
            bw_hook = full_hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        # NOTE(review): input_list is built but never used afterwards — kept to
        # stay faithful to the original implementation.
        input_list = list(input)
        input_list.extend(kwargs.values())
        for hook in self._forward_hooks.values():
            hook_result = hook(self, input, kwargs, result)
            if hook_result is not None:
                result = hook_result
        if bw_hook:
            result = bw_hook.setup_output_hook(result)
        if len(non_full_backward_hooks) > 0:
            probe = result
            # Walk into the result structure until a Tensor is found to attach
            # the legacy (non-full) backward hooks to.
            while not isinstance(probe, torch.Tensor):
                if isinstance(probe, dict):
                    probe = next((v for v in probe.values() if isinstance(v, torch.Tensor)))
                elif isinstance(probe, (list, tuple)):
                    if not probe:
                        return result
                    probe = probe[0]
                else:
                    return result
            grad_fn = probe.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)
        return result
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# List of ops that register hooks + +functional: + - conv1d + - conv2d + - conv3d + - conv_transpose1d + - conv_transpose2d + - conv_transpose3d + - conv_tbc + - avg_pool1d + - avg_pool2d + - avg_pool3d + - fractional_max_pool2d_with_indices + - fractional_max_pool2d + - fractional_max_pool3d_with_indices + - fractional_max_pool3d + - max_pool1d_with_indices + - max_pool1d + - max_pool2d_with_indices + - max_pool2d + - max_pool3d_with_indices + - max_pool3d + - max_unpool1d + - max_unpool2d + - max_unpool3d + - lp_pool2d + - lp_pool1d + - adaptive_max_pool1d_with_indices + - adaptive_max_pool1d + - adaptive_max_pool2d_with_indices + - adaptive_max_pool2d + - adaptive_max_pool3d_with_indices + - adaptive_max_pool3d + - adaptive_avg_pool1d + - adaptive_avg_pool2d + - adaptive_avg_pool3d + - dropout + - alpha_dropout + - dropout2d + - dropout3d + - feature_alpha_dropout + - threshold + - threshold_ + - relu + - relu_ + - glu + - hardtanh + - hardtanh_ + - relu6 + - elu + - elu_ + - selu + - selu_ + - celu + - celu_ + - leaky_relu + - leaky_relu_ + - prelu + - rrelu + - rrelu_ + - logsigmoid + - gelu + - hardshrink + - tanhshrink + - softsign + - softplus + - softmin + - softmax + - gumbel_softmax + - log_softmax + - softshrink + - tanh + - sigmoid + - hardsigmoid + - linear + - bilinear + - silu + - hardswish + - embedding + - embedding_bag + - batch_norm + - instance_norm + - layer_norm + - group_norm + - local_response_norm + - ctc_loss + - nll_loss + - poisson_nll_loss + - gaussian_nll_loss + - kl_div + - cross_entropy + - binary_cross_entropy + - 
binary_cross_entropy_with_logits + - smooth_l1_loss + - l1_loss + - mse_loss + - margin_ranking_loss + - hinge_embedding_loss + - multilabel_margin_loss + - soft_margin_loss + - multilabel_soft_margin_loss + - cosine_embedding_loss + - multi_margin_loss + - pixel_shuffle + - pixel_unshuffle + - channel_shuffle + - upsample + - interpolate + - upsample_nearest + - upsample_bilinear + - grid_sample + - affine_grid + - pad + - pairwise_distance + - pdist + - cosine_similarity + - one_hot + - triplet_margin_loss + - triplet_margin_with_distance_loss + - normalize + - unfold + - fold + - multi_head_attention_forward + - scaled_dot_product_attention + +tensor: + - __add__ + - __and__ + - __bool__ + - __div__ + - __eq__ + - __ge__ + - __gt__ + - __getitem__ + - __iadd__ + - __iand__ + - __idiv__ + - __ifloordiv__ + - __ilshift__ + - __imod__ + - __imul__ + - __ior__ + - __irshift__ + - __isub__ + - __ixor__ + - __lshift__ + - __matmul__ + - __mod__ + - __mul__ + - __nonzero__ + - __or__ + - __radd__ + - __rmul__ + - __rshift__ + - __setitem__ + - __sub__ + - __truediv__ + - __xor__ + - abs + - abs_ + - absolute + - absolute_ + - acos + - acos_ + - acosh + - acosh_ + - add + - add_ + - addbmm + - addbmm_ + - addcdiv + - addcdiv_ + - addcmul + - addcmul_ + - addmm + - addmm_ + - addmv + - addmv_ + - addr + - addr_ + - align_as + - align_to + - all + - allclose + - amax + - amin + - angle + - any + - arccos + - arccos_ + - arccosh + - arccosh_ + - arcsin + - arcsin_ + - arcsinh + - arcsinh_ + - arctan + - arctan_ + - arctanh + - arctanh_ + - argmax + - argmin + - argsort + - asin + - asin_ + - asinh + - asinh_ + - atan + - atan2 + - atan2_ + - atan_ + - atanh + - atanh_ + - baddbmm + - baddbmm_ + - bernoulli + - bernoulli_ + - bincount + - bitwise_and + - bitwise_and_ + - bitwise_not + - bitwise_not_ + - bitwise_or + - bitwise_or_ + - bitwise_xor + - bitwise_xor_ + - bmm + - broadcast_to + - cauchy_ + - ceil + - ceil_ + - cholesky + - chunk + - clamp + - cholesky_solve + - 
cholesky_inverse + - clamp_ + - clamp_max + - clamp_max_ + - clip + - clamp_min + - clamp_min_ + - clip_ + - copysign + - copysign_ + - cos + - cos_ + - cosh + - cosh_ + - count_nonzero + - cummax + - cummin + - cumprod + - cumprod_ + - cumsum + - cumsum_ + - deg2rad + - deg2rad_ + - det + - diag + - diag_embed + - diagflat + - diagonal + - diff + - dist + - digamma + - digamma_ + - div + - div_ + - divide + - divide_ + - dot + - eig + - eq + - eq_ + - erf + - equal + - erf_ + - erfc + - erfc_ + - erfinv + - erfinv_ + - exp + - exp2 + - exp2_ + - expm1 + - exp_ + - expm1_ + - exponential_ + - fill_ + - fix + - fill_diagonal_ + - fix_ + - flip + - fliplr + - flatten + - flipud + - float_power + - float_power_ + - floor + - floor_ + - floor_divide + - floor_divide_ + - fmax + - fmin + - fmod + - fmod_ + - frac + - frac_ + - gather + - gcd + - gcd_ + - ge + - ge_ + - geometric_ + - geqrf + - ger + - greater + - greater_ + - gt + - gt_ + - greater_equal + - greater_equal_ + - hardshrink + - heaviside + - heaviside_ + - histc + - hypot + - hypot_ + - igamma + - igamma_ + - igammac + - igammac_ + - index_add + - index_add_ + - inverse + - index_copy + - index_copy_ + - index_fill + - index_fill_ + - index_put + - index_put_ + - inner + - index_select + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - isreal + - kron + - kthvalue + - lcm + - lcm_ + - ldexp + - ldexp_ + - le + - le_ + - lerp + - lerp_ + - where + - less + - less_ + - less_equal + - less_equal_ + - lgamma + - lgamma_ + - log + - log10 + - log10_ + - log1p + - log1p_ + - log2 + - log2_ + - log_ + - log_normal_ + - log_softmax + - logcumsumexp + - logdet + - logaddexp + - logaddexp2 + - logical_and + - logical_and_ + - logical_not + - logit + - logical_not_ + - logical_or + - logical_or_ + - logical_xor + - logical_xor_ + - logit_ + - logsumexp + - lstsq + - lt + - lt_ + - lu_solve + - map2_ + - map_ + - masked_fill + - matmul + - masked_fill_ + - masked_scatter + - masked_scatter_ + - 
masked_select + - matrix_exp + - max + - maximum + - mean + - matrix_power + - median + - min + - minimum + - mm + - mode + - msort + - mul + - mul_ + - multinomial + - multiply + - multiply_ + - mv + - mvlgamma + - mvlgamma_ + - nansum + - narrow + - narrow_copy + - ne + - ne_ + - neg + - neg_ + - negative + - negative_ + - nonzero + - norm + - normal_ + - not_equal + - not_equal_ + - permute + - pinverse + - polygamma + - pow + - pow_ + - polygamma_ + - prelu + - prod + - put_ + - rad2deg + - rad2deg_ + - ravel + - real + - reciprocal + - reciprocal_ + - relu + - relu_ + - remainder + - repeat_interleave + - reshape + - remainder_ + - renorm + - renorm_ + - repeat + - reshape_as + - resize_ + - resize_as_ + - roll + - rot90 + - round + - round_ + - rsqrt + - rsqrt_ + - scatter + - scatter_ + - scatter_add + - scatter_add_ + - select + - sgn + - sgn_ + - sigmoid + - sigmoid_ + - sign + - sign_ + - signbit + - sin + - sin_ + - sinc + - sinc_ + - sinh + - sinh_ + - slogdet + - smm + - softmax + - solve + - sort + - split_with_sizes + - sqrt + - sqrt_ + - square + - square_ + - squeeze + - squeeze_ + - sspaddmm + - std + - sub + - sub_ + - sum + - sum_to_size + - svd + - symeig + - t + - t_ + - take + - tan + - tan_ + - tanh + - tanh_ + - tensor_split + - tile + - topk + - transpose + - transpose_ + - triangular_solve + - tril + - tril_ + - triu + - true_divide + - triu_ + - true_divide_ + - trunc + - trunc_ + - type_as + - unbind + - unflatten + - unfold + - unsafe_chunk + - unsqueeze + - unsafe_split + - unsafe_split_with_sizes + - var + - vdot + - unsqueeze_ + - view_as + - xlogy + - xlogy_ + +torch: + - linalg.norm + - linalg.vector_norm + - linalg.matrix_norm + - linalg.diagonal + - linalg.det + - linalg.slogdet + - linalg.cond + - linalg.matrix_rank + - linalg.qr + - linalg.lu + - linalg.lu_factor + - linalg.svd + - linalg.svdvals + - linalg.solve + - linalg.lstsq + - linalg.inv + - linalg.pinv + - linalg.matrix_exp + - linalg.matrix_power + - linalg.cross + - 
linalg.matmul + - linalg.vecdot + - linalg.multi_dot + - linalg.householder_product + - linalg.tensorsolve + - linalg.vander + - linalg.cholesky_ex + - linalg.inv_ex + - linalg.solve_ex + - linalg.lu_factor_ex + - linalg.ldl_factor + - linalg.ldl_factor_ex + - _adaptive_avg_pool2d + - _add_relu + - _add_relu_ + - _aminmax + - _batch_norm_impl_index + - _convolution + - _foreach_norm + - _softmax_backward_data + - abs + - abs_ + - absolute + - acos + - acos_ + - acosh + - acosh_ + - adaptive_avg_pool1d + - adaptive_max_pool1d + - add + - addbmm + - addcdiv + - addcmul + - addmm + - addmv + - addmv_ + - addr + - amax + - affine_grid_generator + - align_tensors + - all + - alpha_dropout + - amin + - alpha_dropout_ + - angle + - any + - arange + - arccos + - arccos_ + - arccosh + - arccosh_ + - arcsin + - arcsin_ + - arcsinh + - arcsinh_ + - arctan + - arctan_ + - arctanh + - arctanh_ + - argmax + - argmin + - argsort + - asin + - asin_ + - asinh + - asinh_ + - atan + - atan2 + - atan_ + - atanh + - atanh_ + - atleast_1d + - atleast_2d + - atleast_3d + - avg_pool1d + - baddbmm + - bartlett_window + - batch_norm_backward_elemt + - batch_norm_backward_reduce + - batch_norm_elemt + - batch_norm_gather_stats + - batch_norm_gather_stats_with_counts + - bernoulli + - batch_norm_stats + - batch_norm_update_stats + - bilinear + - bincount + - binomial + - binary_cross_entropy_with_logits + - bitwise_and + - bitwise_not + - bitwise_or + - bitwise_xor + - blackman_window + - block_diag + - bmm + - broadcast_tensors + - broadcast_to + - bucketize + - cartesian_prod + - cat + - cdist + - ceil + - ceil_ + - celu + - celu_ + - chain_matmul + - channel_shuffle + - cholesky + - cholesky_inverse + - cholesky_solve + - choose_qparams_optimized + - chunk + - clamp + - clamp_ + - clamp_max + - clamp_max_ + - clamp_min + - clamp_min_ + - clip + - clip_ + - clone + - column_stack + - combinations + - concat + - concatenate + - constant_pad_nd + - conv1d + - conv2d + - conv3d + - conv_tbc + 
- conv_transpose1d + - conv_transpose2d + - conv_transpose3d + - cos + - convolution + - copysign + - cos_ + - cosh + - cosh_ + - cosine_embedding_loss + - cosine_similarity + - count_nonzero + - cov + - cross + - ctc_loss + - cummax + - cummin + - cumprod + - cumsum + - deg2rad + - deg2rad_ + - det + - diag + - diag_embed + - diff + - diagflat + - diagonal + - digamma + - dist + - div + - divide + - dot + - dropout + - dropout_ + - dsmm + - dstack + - eig + - einsum + - embedding + - embedding_bag + - embedding_renorm_ + - eq + - equal + - erf + - erf_ + - erfc + - erfc_ + - erfinv + - exp + - exp2 + - exp2_ + - exp_ + - expm1 + - expm1_ + - eye + - feature_dropout + - feature_alpha_dropout + - feature_alpha_dropout_ + - feature_dropout_ + - fix + - fill_ + - fix_ + - flatten + - flip + - fliplr + - flipud + - float_power + - floor + - floor_ + - floor_divide + - fmax + - fmin + - fmod + - frac + - frac_ + - full + - frobenius_norm + - full_like + - gather + - gcd + - gcd_ + - ge + - geqrf + - ger + - greater + - greater_equal + - grid_sampler + - grid_sampler_2d + - group_norm + - grid_sampler_3d + - gru + - gru_cell + - gt + - hamming_window + - hann_window + - hardshrink + - heaviside + - hinge_embedding_loss + - histc + - hsmm + - hspmm + - hstack + - hypot + - igamma + - igammac + - index_add + - index_copy + - inner + - index_fill + - index_put + - index_put_ + - index_select + - instance_norm + - inverse + - isclose + - isfinite + - isinf + - isnan + - isneginf + - isposinf + - istft + - kaiser_window + - kl_div + - kron + - kthvalue + - layer_norm + - lcm + - lcm_ + - ldexp + - ldexp_ + - le + - lerp + - less + - less_equal + - lgamma + - linspace + - log + - log10 + - log10_ + - log1p + - log1p_ + - log2 + - log2_ + - log_softmax + - log_ + - logaddexp + - logaddexp2 + - logcumsumexp + - logdet + - logical_and + - logical_not + - logical_or + - logical_xor + - logit + - logit_ + - logspace + - logsumexp + - lstm + - lstm_cell + - lstsq + - lt + - lu_solve 
+ - lu_unpack + - masked_fill + - margin_ranking_loss + - masked_scatter + - masked_select + - matrix_exp + - matmul + - matrix_power + - matrix_rank + - max + - max_pool1d + - max_pool2d + - max_pool1d_with_indices + - max_pool3d + - maximum + - mean + - median + - min + - minimum + - mm + - mode + - moveaxis + - movedim + - msort + - mul + - multinomial + - multiply + - mv + - mvlgamma + - nan_to_num + - nan_to_num_ + - nanmedian + - nansum + - narrow + - native_batch_norm + - native_group_norm + - narrow_copy + - native_layer_norm + - native_norm + - ne + - neg + - negative + - neg_ + - negative_ + - nextafter + - nonzero + - norm + - norm_except_dim + - normal + - not_equal + - nuclear_norm + - ones_like + - outer + - pairwise_distance + - pdist + - permute + - pinverse + - pixel_shuffle + - pixel_unshuffle + - poisson + - poisson_nll_loss + - polar + - polygamma + - pow + - prelu + - prod + - qr + - quantile + - rad2deg + - rad2deg_ + - range + - ravel + - real + - reciprocal + - relu + - reciprocal_ + - relu_ + - remainder + - renorm + - repeat_interleave + - reshape + - resize_as_ + - roll + - rot90 + - round + - round_ + - rrelu + - rrelu_ + - rsqrt + - row_stack + - rsqrt_ + - rsub + - saddmm + - scalar_tensor + - scatter + - select + - scatter_add + - searchsorted + - selu + - selu_ + - sgn + - sigmoid + - sigmoid_ + - sign + - signbit + - sin + - sin_ + - sinc + - sinc_ + - sinh + - sinh_ + - slogdet + - smm + - softmax + - solve + - sort + - sparse_coo_tensor + - square + - split + - split_with_sizes + - spmm + - sqrt + - sqrt_ + - square_ + - squeeze + - sspaddmm + - stack + - std + - std_mean + - stft + - sub + - subtract + - sum + - svd + - swapaxes + - swapdims + - symeig + - t + - take + - take_along_dim + - tan + - tan_ + - tanh + - tanh_ + - tensordot + - tensor_split + - threshold + - threshold_ + - tile + - topk + - transpose + - trapz + - triangular_solve + - tril + - tril_indices + - triplet_margin_loss + - triu + - triu_indices + - 
true_divide + - trunc + - trunc_ + - unique_consecutive + - xlogy + - unbind + - unsafe_chunk + - unsafe_split + - vander + - var + - vdot + - unsafe_split_with_sizes + - unsqueeze + - var_mean + - vstack + - where + - xlogy_ + +_VF: + - lstm + +torch_npu: + - one_ + - npu_sort_v2 + - npu_transpose + - npu_broadcast + - npu_dtype_cast + - empty_with_format + - npu_one_hot + - npu_stride_add + - npu_ps_roi_pooling + - npu_roi_align + - npu_nms_v4 + - npu_iou + - npu_nms_with_mask + - npu_pad + - npu_bounding_box_encode + - npu_bounding_box_decode + - npu_batch_nms + - npu_slice + - _npu_dropout + - npu_indexing + - npu_ifmr + - npu_max + - npu_scatter + - npu_layer_norm_eval + - npu_alloc_float_status + - npu_confusion_transpose + - npu_bmmV2 + - fast_gelu + - npu_sub_sample + - npu_deformable_conv2d + - npu_mish + - npu_anchor_response_flags + - npu_yolo_boxes_encode + - npu_grid_assign_positive + - npu_normalize_batch + - npu_masked_fill_range + - npu_linear + - npu_bert_apply_adam + - npu_giou + - npu_ciou + - npu_diou + - npu_sign_bits_pack + - npu_sign_bits_unpack + - npu_flash_attention + - npu_scaled_masked_softmax + - npu_rotary_mul + - npu_roi_align + - npu_roi_alignbk + - npu_ptiou + - npu_fusion_attention + - npu_dropout_with_add_softmax + - npu_random_choice_with_mask + - npu_rotated_iou + - npu_conv2d + - npu_conv3d + - npu_softmax_cross_entropy_with_logits + - npu_all_gather_base_mm + - npu_swiglu + - npu_rms_norm + - npu_mm_reduce_scatter_base + - npu_mm_all_reduce_base + - npu_conv_transpose2d + - npu_convolution + - npu_convolution_transpose + - npu_min + - npu_nms_rotated + - npu_reshape + - npu_rotated_box_decode + - npu_rotated_box_encode + - npu_rotated_overlaps + - npu_silu + - npu_fused_attention_score + - npu_multi_head_attention + - npu_gru + - npu_incre_flash_attention + - npu_prompt_flash_attention + - npu_lstm + - npu_apply_adam + +aten: + - signbit + - logical_not_ + - _foreach_copy_ + - clamp + - hardswish_ + - arcsin_ + - logsumexp + - 
native_group_norm + - special_i1e + - bitwise_and + - new_full + - fft_ihfft + - _adaptive_avg_pool2d + - scatter_add + - abs + - selu + - exponential + - silu + - _native_batch_norm_legit_functional + - special_hermite_polynomial_h + - tanh_ + - log_sigmoid_forward + - _fft_c2c + - heaviside_ + - sigmoid_backward + - zeros_like + - as_strided_scatter + - trace + - _assert_async + - avg_pool2d_backward + - exp2 + - binary_cross_entropy_backward + - geometric + - fft_ihfftn + - smooth_l1_loss + - multiply + - __lshift__ + - binary_cross_entropy_with_logits + - _embedding_bag + - arange + - linalg_qr + - _embedding_bag_forward_only + - _unsafe_view + - remainder + - cholesky_inverse + - sub_ + - zero + - fix + - xlogy + - __doc__ + - rsqrt_ + - cummin + - __xor__ + - eye + - _fused_adam + - ceil + - nll_loss2d_backward + - replication_pad3d_backward + - fill_ + - logaddexp2 + - _thnn_fused_lstm_cell_backward_impl + - native_dropout + - fft_ifft + - expand + - _cdist_backward + - avg_pool3d_backward + - round_ + - topk + - max_unpool3d + - xlogy_ + - reflection_pad2d_backward + - addcdiv_ + - relu6 + - multilabel_margin_loss_forward + - prelu + - logaddexp + - _cholesky_solve_helper + - _foreach_addcdiv + - arctan_ + - fft_irfftn + - logical_or + - bitwise_or_ + - hardtanh_backward + - uniform + - less_equal + - _foreach_sub + - linalg_cholesky_ex + - hardswish + - fft_fft2 + - sign + - min + - norm + - asin + - addcmul_ + - stft + - col2im + - special_chebyshev_polynomial_u + - adaptive_max_pool3d + - __ilshift__ + - _resize_output + - gather + - lu_unpack + - native_batch_norm_backward + - sigmoid + - sqrt + - new_empty_strided + - _foreach_lerp_ + - mean + - scatter_add_ + - _fft_c2r + - rand_like + - true_divide_ + - gcd_ + - multinomial + - permute + - index_put_ + - arcsinh_ + - log1p_ + - index_add + - atan + - glu_backward + - searchsorted + - fill + - _unsafe_index + - index_reduce_ + - replication_pad2d + - expm1_ + - hardsigmoid + - addmm + - fft_fftn + - 
fft_ifftshift + - special_modified_bessel_k1 + - fft_rfft + - ge + - _adaptive_avg_pool2d_backward + - argmin + - linalg_lu_factor_ex + - atanh_ + - addmv + - _foreach_sqrt_ + - huber_loss_backward + - empty_like + - softshrink + - subtract_ + - bitwise_left_shift_ + - special_modified_bessel_i0 + - _nested_tensor_from_tensor_list + - slice_backward + - special_modified_bessel_i1 + - special_chebyshev_polynomial_t + - conj_physical + - _cdist_forward + - margin_ranking_loss + - max_pool3d_with_indices_backward + - _foreach_reciprocal_ + - lcm + - transpose_ + - cudnn_batch_norm_backward + - reciprocal + - copysign_ + - _foreach_pow + - rad2deg + - _foreach_sqrt + - negative + - replication_pad3d + - atanh + - _linalg_eigh + - igamma_ + - special_i0e + - linalg_ldl_factor_ex + - special_ndtri + - logit + - diagonal_copy + - triu + - silu_ + - polygamma + - square_ + - nextafter_ + - special_scaled_modified_bessel_k0 + - bitwise_not + - var + - mkldnn_rnn_layer_backward + - upsample_bilinear2d + - arctan2 + - clone + - arcsin + - new_ones + - soft_margin_loss + - nan_to_num + - huber_loss + - linalg_lu_solve + - elu_backward + - acosh + - __ior__ + - _unsafe_index_put + - __or__ + - _linalg_slogdet + - arcsinh + - select_scatter + - less_ + - reflection_pad1d + - istft + - reflection_pad2d + - diagonal_backward + - special_entr + - _softmax_backward_data + - randn + - celu + - embedding + - igammac_ + - new_zeros + - native_layer_norm_backward + - nonzero_static + - diagonal_scatter + - grid_sampler_2d + - smooth_l1_loss_backward + - _to_copy + - fft_irfft2 + - relu_ + - fmod + - log1p + - i0 + - mse_loss_backward + - copy + - special_laguerre_polynomial_l + - addmv_ + - quantized_gru + - diag_embed + - acos + - fmod_ + - linalg_cross + - mvlgamma_ + - _foreach_mul + - cummax + - less_equal_ + - ne + - to + - _pdist_forward + - special_xlog1py + - digamma + - lgamma + - mv + - softplus + - special_bessel_y1 + - pin_memory + - logical_xor_ + - cat + - 
grid_sampler_2d_backward + - frac_ + - dropout + - unsafe_chunk + - masked_fill_ + - log + - negative_ + - _scaled_dot_product_flash_attention + - _amp_foreach_non_finite_check_and_unscale_ + - randn_like + - add + - roll + - threshold + - gcd + - asinh + - round + - t_ + - unfold_backward + - scatter_reduce + - softplus_backward + - bitwise_right_shift_ + - pdist + - select_backward + - relu + - special_bessel_j1 + - asinh_ + - pow + - fft_fftshift + - clamp_max_ + - logical_xor + - index_reduce + - _foreach_add_ + - adaptive_max_pool2d + - adaptive_max_pool3d_backward + - tan + - addbmm_ + - cosh_ + - __rshift__ + - _foreach_maximum + - fft_ifftn + - special_spherical_bessel_j0 + - split_with_sizes + - divide_ + - neg_ + - nll_loss + - _euclidean_dist + - pairwise_distance + - _adaptive_avg_pool3d + - slice + - absolute_ + - gelu_backward + - arccos + - sin + - tril_ + - triu_ + - fft_irfft + - flip + - _foreach_sign + - linalg_householder_product + - _list_to_tensor + - cumprod + - randint_like + - item + - narrow_copy + - tanh + - linalg_vector_norm + - _cudnn_rnn + - _scaled_dot_product_efficient_attention + - _reshape_alias + - _linalg_det + - constant_pad_nd + - _linalg_svd + - sinh_ + - view + - nll_loss_backward + - greater + - sqrt_ + - avg_pool3d + - arctan + - le_ + - _pdist_backward + - _adaptive_avg_pool3d_backward + - log_ + - logical_or_ + - mse_loss + - rrelu_with_noise_backward + - _native_batch_norm_legit + - log10 + - scatter_ + - atan2_ + - greater_equal + - index_select + - __iand__ + - digamma_ + - eq + - divide + - cholesky_solve + - _prelu_kernel + - fft_ifft2 + - _foreach_neg_ + - alias + - erfc_ + - not_equal + - mul + - gru + - _dir + - glu + - clip + - lt + - rsqrt + - avg_pool2d + - conj_physical_ + - quantized_lstm + - erfinv_ + - log10_ + - float_power_ + - _functional_assert_async + - hardtanh + - logical_and_ + - _resize_output_ + - clamp_min + - _functional_sym_constrain_range_for_size + - _addmm_activation + - bucketize + - 
_thnn_fused_lstm_cell + - zeros + - reflection_pad1d_backward + - tan_ + - bitwise_not_ + - addmm_ + - absolute + - as_strided + - special_ndtr + - gt_ + - baddbmm + - special_log_ndtr + - hardshrink + - fft_hfft + - hypot + - native_layer_norm + - _scaled_dot_product_flash_attention_backward + - floor_divide + - is_same_size + - std + - floor_divide_ + - clamp_min_ + - _foreach_sign_ + - std_mean + - tanh_backward + - _foreach_addcmul + - binary_cross_entropy + - threshold_backward + - deg2rad_ + - masked_fill + - linspace + - reflection_pad3d + - mish + - index_copy + - scatter_reduce_ + - _sparse_coo_tensor_with_dims_and_tensors + - __loader__ + - _foreach_div_ + - cosh + - _foreach_maximum_ + - neg + - lift_fresh + - logspace + - selu_ + - leaky_relu_ + - matmul + - _foreach_sub_ + - bitwise_or + - unfold + - fmin + - convolution + - argmax + - maximum + - reflection_pad3d_backward + - fft_fft + - mode + - remainder_ + - _foreach_neg + - erf_ + - special_zeta + - index_add_ + - arccos_ + - lgamma_ + - unsqueeze_ + - gelu_ + - bmm + - _add_relu + - unfold_copy + - not_equal_ + - subtract + - true_divide + - max_pool2d_with_indices_backward + - _native_batch_norm_legit_no_training + - replication_pad1d + - name + - greater_ + - log_normal + - minimum + - alpha_dropout + - rnn_tanh + - _functional_sym_constrain_range + - sum + - _prelu_kernel_backward + - cumsum_ + - ne_ + - _linalg_solve_ex + - native_batch_norm + - igammac + - hypot_ + - exp + - leaky_relu + - new_empty + - cudnn_batch_norm + - resize_as_ + - mm + - triangular_solve + - sign_ + - clamp_max + - bitwise_right_shift + - logical_and + - special_i0 + - index_copy_ + - arctanh_ + - elu + - index + - isposinf + - linalg_solve_triangular + - logcumsumexp + - arccosh + - nan_to_num_ + - nll_loss_forward + - convolution_backward + - sub + - special_scaled_modified_bessel_k1 + - mish_ + - diagonal + - median + - tril + - sgn + - native_group_norm_backward + - stack + - take + - linalg_lu + - log2 + - 
hardsigmoid_ + - erfc + - max + - native_dropout_backward + - logit_ + - addr + - clip_ + - _foreach_minimum_ + - atan_ + - repeat + - cumprod_ + - bitwise_xor_ + - less + - index_put + - rrelu_with_noise + - addbmm + - special_bessel_y0 + - __and__ + - bernoulli_ + - uniform_ + - log2_ + - mul_ + - adaptive_max_pool2d_backward + - _foreach_addcmul_ + - slice_scatter + - isneginf + - pow_ + - renorm_ + - arccosh_ + - replication_pad1d_backward + - bitwise_and_ + - heaviside + - renorm + - special_modified_bessel_k0 + - le + - is_pinned + - __ixor__ + - leaky_relu_backward + - count_nonzero + - _fused_adam_ + - repeat_interleave + - upsample_bicubic2d + - rsub + - arctan2_ + - frac + - scalar_tensor + - rrelu_with_noise_ + - rot90 + - erf + - lerp_ + - expm1 + - full + - sym_constrain_range_for_size + - prod + - normal_ + - elu_ + - special_airy_ai + - nextafter + - split + - addcdiv + - fft_rfft2 + - max_pool3d_with_indices + - positive + - transpose + - mish_backward + - clamp_ + - exp_ + - _foreach_reciprocal + - linalg_matrix_exp + - unsqueeze + - upsample_nearest2d + - sinc_ + - select + - rad2deg_ + - trunc_ + - _make_dep_token + - nanmedian + - fft_hfftn + - hardtanh_ + - sym_constrain_range + - index_fill_ + - deg2rad + - rand + - sinc + - pixel_shuffle + - tril_indices + - copy_ + - _int_mm + - greater_equal_ + - celu_ + - div + - igamma + - exp2_ + - cos + - log_normal_ + - _log_softmax_backward_data + - im2col + - reciprocal_ + - amax + - broadcast_tensors + - erfinv + - __spec__ + - _fused_dropout + - special_hermite_polynomial_he + - aminmax + - rnn_relu + - meshgrid + - var_mean + - eq_ + - upsample_nearest3d + - dot + - zero_ + - floor_ + - fft_rfftn + - special_erfcx + - _foreach_div + - fft_hfft2 + - _upsample_bilinear2d_aa + - sort + - log_sigmoid_backward + - add_ + - copysign + - bernoulli + - special_bessel_j0 + - max_pool2d_with_indices + - _scaled_dot_product_efficient_attention_backward + - t + - _softmax + - arctanh + - hinge_embedding_loss 
+ - hardswish_backward + - fmax + - multiply_ + - floor + - lstm + - i0_ + - cholesky + - where + - __irshift__ + - addcmul + - embedding_dense_backward + - sigmoid_ + - fix_ + - ormqr + - exponential_ + - __name__ + - fft_ihfft2 + - logical_not + - ones + - sgn_ + - sinh + - any + - _foreach_addcdiv_ + - asin_ + - gt + - lift + - squeeze + - grid_sampler_3d_backward + - atan2 + - _fft_r2c + - angle + - silu_backward + - acosh_ + - abs_ + - lerp + - special_i1 + - complex + - ceil_ + - _foreach_minimum + - hardsigmoid_backward + - upsample_nearest1d + - mvlgamma + - acos_ + - lt_ + - grid_sampler_3d + - max_unpool2d + - ones_like + - soft_margin_loss_backward + - _fused_moving_avg_obs_fq_helper + - isnan + - nansum + - baddbmm_ + - amin + - isinf + - bitwise_left_shift + - unsafe_split_with_sizes + - full_like + - sin_ + - bitwise_xor + - linalg_ldl_solve + - cos_ + - div_ + - polar + - randint + - trunc + - __package__ + - nll_loss2d_forward + - diag + - argsort + - _foreach_mul_ + - square + - detach + - affine_grid_generator + - _pin_memory + - geometric_ + - unbind + - randperm + - upsample_nearest2d_backward + - all + - threshold_ + - unsafe_split + - cauchy + - normal + - linalg_inv_ex + - multi_margin_loss + - cumsum + - gelu + - index_fill + - scatter + - mkldnn_rnn_layer + - ge_ + - dist + - _foreach_add + - logit_backward + - triu_indices + - lcm_ + - empty_strided + - replication_pad2d_backward + - cauchy_ + - _log_softmax + - vdot + +distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - _reduce_scatter_base + - _all_gather_base \ No newline at end of file diff --git a/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_aten.py new file mode 100644 index 0000000000000000000000000000000000000000..8666287095bbe12f7e9d5f314cff1db75d74a108 --- /dev/null +++ 
b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_aten.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2023-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +import torch + +import yaml + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen + + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapAtenOps = yaml.safe_load(f).get('aten') + + +aten_func = {} +for f in dir(torch.ops.aten): + aten_func[f] = getattr(torch.ops.aten, f) + + +def get_aten_ops(): + global WrapAtenOps + _all_aten_ops = dir(torch.ops.aten) + return set(WrapAtenOps) & set(_all_aten_ops) + + +class HOOKAtenOP(object): + pass + + +class AtenOPTemplate(HOOKModule): + def __init__(self, op, hook): + if isinstance(op, torch._ops.OpOverloadPacket): + op_name_ = op._qualified_op_name.split("::")[-1] + else: + op_name_ = op.name().split("::")[-1] + overload_name = op._overloadname + if not '.' + overload_name in op_name_: + op_name_ = op_name_ + '.' 
+ overload_name + self.op = op + self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return self.op(*args, **kwargs) + + +class AtenOPPacketTemplate(): + def __init__(self, opPacket, hook): + self.opPacket = opPacket + self.hook = hook + + def __getattr__(self, key): + try: + attr = getattr(self.opPacket, key) + except AttributeError as e: + raise AttributeError(f"AtenOPPacketTemplate or OpOverloadPacket does not have attribute '{key}'.") from e + if isinstance(attr, torch._ops.OpOverload): + return AtenOPTemplate(attr, self.hook) + else: + return attr + + def overloads(self): + return self.opPacket.overloads() + + @torch_device_guard + def __call__(self, *args, **kwargs): + return AtenOPTemplate(self.opPacket, self.hook)(*args, **kwargs) + + +def wrap_aten_op(op, hook): + return AtenOPPacketTemplate(op, hook) + + +def wrap_aten_ops_and_bind(hook): + _aten_ops = get_aten_ops() + for op_name in _aten_ops: + if not isinstance(aten_func.get(op_name), torch._ops.OpOverloadPacket): + continue + setattr(HOOKAtenOP, "wrap_" + str(op_name), wrap_aten_op(aten_func.get(op_name), hook)) diff --git a/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..64ce06c33e8fe45966b900bcb9748d798e1b6e84 --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_distributed.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os +from functools import wraps +import torch.distributed as dist +import yaml + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen + + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapDistributedOps = yaml.safe_load(f).get('distributed') + + +distributed_func = {} +for f in dir(dist): + distributed_func[f] = getattr(dist, f) + + +def get_distributed_ops(): + global WrapDistributedOps + _all_distributed_ops = dir(dist) + return set(WrapDistributedOps) & set(_all_distributed_ops) + + +class HOOKDistributedOP(object): + pass + + +class DistributedOPTemplate(HOOKModule): + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP + super().__init__(hook) + if self.op_name_ in Const.INPLACE_LIST: + self.register_forward_pre_hook(hook(self.prefix + Const.PRE_FORWARD)) + + @torch_device_guard + def forward(self, *args, **kwargs): + return distributed_func.get(self.op_name_)(*args, **kwargs) + + +def wrap_distributed_op(op_name, hook): + @wraps(DistributedOPTemplate) + def distributed_op_template(*args, **kwargs): + return DistributedOPTemplate(op_name, hook)(*args, **kwargs) + + distributed_op_template.__name__ = op_name + return distributed_op_template + + +def wrap_distributed_ops_and_bind(hook): + _distributed_ops = get_distributed_ops() + for op_name in _distributed_ops: + 
setattr(HOOKDistributedOP, "wrap_" + str(op_name), wrap_distributed_op(op_name, hook)) diff --git a/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_functional.py new file mode 100644 index 0000000000000000000000000000000000000000..46f25efe664fca2bff917b93e3e0632398bdc74e --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_functional.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os + +import torch +import yaml + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard, Const +from ..common.log import print_info_log_rank_0 +from ..common.file_check import FileOpen + + +def remove_dropout(): + if torch.__version__ > "1.8": + print_info_log_rank_0("For precision comparison, the probability p in the dropout method is set to 0.") + import torch.nn.functional as F + from torch import _VF + from torch.overrides import has_torch_function_unary, handle_torch_function + + def function_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True, + inplace: bool = False) -> torch.Tensor: + if has_torch_function_unary(input): + return handle_torch_function(function_dropout, (input,), input, p=0., training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.dropout_(input, 0., training) if inplace else _VF.dropout(input, 0., training) + + + def function_dropout2d(input: torch.Tensor, p: float = 0.5, training: bool = True, + inplace: bool = False) -> torch.Tensor: + if has_torch_function_unary(input): + return handle_torch_function(function_dropout2d, (input,), input, p=0., training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, 0., training) if inplace else _VF.feature_dropout(input, 0., training) + + + def function_dropout3d(input: torch.Tensor, p: float = 0.5, training: bool = True, + inplace: bool = False) -> torch.Tensor: + if has_torch_function_unary(input): + return handle_torch_function(function_dropout3d, (input,), input, p=0., training=training, inplace=inplace) + if p < 0.0 or p > 1.0: + raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p)) + return _VF.feature_dropout_(input, 0., training) if inplace else 
_VF.feature_dropout(input, 0., training) + + F.dropout = function_dropout + F.dropout2d = function_dropout2d + F.dropout3d = function_dropout3d + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapFunctionalOps = yaml.safe_load(f).get('functional') + + +def get_functional_ops(): + global WrapFunctionalOps + _all_functional_ops = dir(torch.nn.functional) + return set(WrapFunctionalOps) & set(_all_functional_ops) + + +TorchFunctions = {func: getattr(torch.nn.functional, func) for func in get_functional_ops()} + + +class HOOKFunctionalOP(object): + pass + + +class FunctionalOPTemplate(HOOKModule): + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return TorchFunctions[str(self.op_name_)](*args, **kwargs) + + +def wrap_functional_op(op_name, hook): + def functional_op_template(*args, **kwargs): + return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) + + return functional_op_template + + +def wrap_functional_ops_and_bind(hook): + _functional_ops = get_functional_ops() + for op_name in _functional_ops: + setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) diff --git a/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_npu_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..e910e609c8379e0c66239755c3ec2a44953ef1ec --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_npu_custom.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os

import torch
import torch_npu
import yaml

from .hook_module import HOOKModule
from ..common.utils import torch_device_guard, torch_without_guard_version, Const
from ..common.file_check import FileOpen

cur_path = os.path.dirname(os.path.realpath(__file__))
yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml")
with FileOpen(yaml_path, 'r') as f:
    WrapNpuOps = yaml.safe_load(f).get('torch_npu')


def _npu_op_namespace():
    # The location of the raw NPU ops depends on the torch version:
    # newer builds expose them on torch.ops.npu, older ones on torch_npu._C.
    if torch_without_guard_version:
        return torch.ops.npu
    return torch_npu._C._VariableFunctionsClass


def get_npu_ops():
    """Return the yaml-listed torch_npu ops available in the current build."""
    return set(WrapNpuOps) & set(dir(_npu_op_namespace()))


class HOOKNpuOP(object):
    """Namespace object that receives the wrapped NPU ops as attributes."""
    pass


class NpuOPTemplate(HOOKModule):
    """HOOKModule wrapper that forwards to the original torch_npu op."""

    def __init__(self, op_name, hook):
        self.op_name_ = op_name
        self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP
        super().__init__(hook)

    @torch_device_guard
    def forward(self, *args, **kwargs):
        return getattr(_npu_op_namespace(), str(self.op_name_))(*args, **kwargs)


def wrap_npu_op(op_name, hook):
    """Return a callable that routes op_name through NpuOPTemplate."""
    def npu_op_template(*args, **kwargs):
        return NpuOPTemplate(op_name, hook)(*args, **kwargs)

    return npu_op_template


def wrap_npu_ops_and_bind(hook):
    """Attach a wrapped version of every supported NPU op to HOOKNpuOP."""
    for name in get_npu_ops():
        setattr(HOOKNpuOP, "wrap_" + str(name), wrap_npu_op(name, hook))
+""" + +import os + +import torch +import yaml + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard, parameter_adapter, Const +from ..common.file_check import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapTensorOps = yaml.safe_load(f).get('tensor') + + +def get_tensor_ops(): + global WrapTensorOps + _tensor_ops = dir(torch.Tensor) + return set(WrapTensorOps) & set(_tensor_ops) + + +TensorOps = {op: getattr(torch.Tensor, op) for op in get_tensor_ops()} + + +class HOOKTensor(object): + pass + + +class TensorOPTemplate(HOOKModule): + + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP + super().__init__(hook) + + @torch_device_guard + @parameter_adapter + def forward(self, *args, **kwargs): + return TensorOps[str(self.op_name_)](*args, **kwargs) + + +def wrap_tensor_op(op_name, hook): + + def tensor_op_template(*args, **kwargs): + return TensorOPTemplate(op_name, hook)(*args, **kwargs) + + return tensor_op_template + + +def wrap_tensor_ops_and_bind(hook): + _tensor_ops = get_tensor_ops() + for op_name in _tensor_ops: + setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) diff --git a/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..889512e9c0c64d9d05dc19cbc30e542c6e5b577c --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_torch.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os + +import torch +import yaml + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapTorchOps = yaml.safe_load(f).get('torch') + + +def get_torch_ops(): + global WrapTorchOps + _torch_ops = [] + for operation in WrapTorchOps: + if '.' in operation: + operation_sub_module_name, operation_sub_op = operation.rsplit('.', 1) + operation_sub_module = getattr(torch, operation_sub_module_name) + if operation_sub_op in dir(operation_sub_module): + _torch_ops.append(operation) + else: + if hasattr(torch, operation): + _torch_ops.append(operation) + return set(_torch_ops) + + +TorchOps = {} +for op in get_torch_ops(): + if '.' 
in op: + sub_module_name, sub_op = op.rsplit('.', 1) + sub_module = getattr(torch, sub_module_name) + TorchOps[op] = getattr(sub_module, sub_op) + else: + TorchOps[op] = getattr(torch, op) + + + +class HOOKTorchOP(object): + pass + + +class TorchOPTemplate(HOOKModule): + + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "Torch" + Const.SEP + str(op_name) + Const.SEP + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return TorchOps[str(self.op_name_)](*args, **kwargs) + + +def wrap_torch_op(op_name, hook): + + def torch_op_template(*args, **kwargs): + return TorchOPTemplate(op_name, hook)(*args, **kwargs) + + return torch_op_template + + +def wrap_torch_ops_and_bind(hook): + _torch_ops = get_torch_ops() + for op_name in _torch_ops: + setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) diff --git a/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_vf.py new file mode 100644 index 0000000000000000000000000000000000000000..08d47308e077981e65193eea71874d4f9432c6c0 --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/hook_module/wrap_vf.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os + +import torch +import yaml + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapVfOps = yaml.safe_load(f).get('_VF') + + +def get_vf_ops(): + global WrapVfOps + # _all_functional_ops = dir(torch.nn.functional) + # assert set(WrapFunctionalOps) <= set(_all_functional_ops) + return WrapVfOps + + +class HOOKVfOP(object): + pass + + +class VfOPTemplate(HOOKModule): + def __init__(self, op_name, hook): + self.op_name_ = op_name + self.prefix_op_name_ = "VF" + Const.SEP + str(op_name) + Const.SEP + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) + + +def wrap_vf_op(op_name, hook): + def vf_op_template(*args, **kwargs): + return VfOPTemplate(op_name, hook)(*args, **kwargs) + + return vf_op_template + + +def wrap_vf_ops_and_bind(hook): + _vf_ops = get_vf_ops() + for op_name in _vf_ops: + setattr(HOOKVfOP, "wrap_" + op_name, wrap_vf_op(op_name, hook)) diff --git a/debug/accuracy_tools/msacc/pytorch/module_processer.py b/debug/accuracy_tools/msacc/pytorch/module_processer.py new file mode 100644 index 0000000000000000000000000000000000000000..ecd360732a8a43fc0944cd5354a0eec144c30752 --- /dev/null +++ b/debug/accuracy_tools/msacc/pytorch/module_processer.py @@ -0,0 +1,76 @@ +from functools import wraps +import torch +from torch.utils.hooks import BackwardHook +from calibrator.functional.scope import ModuleRangeScope +from calibrator.common.utils import Const + + +class ModuleProcesser: + module_stack = [] + api_parent_node = "" + module_node = {} + current_module_name = "" + + def __init__(self, scope): + if isinstance(scope, ModuleRangeScope): + self.scope = scope + else: + self.scope = None + 
class ModuleProcesser:
    """Maintains the nn.Module call stack during forward/backward and records
    each module's parent in ``module_node`` so dumps can be organized as a tree.

    Class-level state is intentionally shared across instances: hooks from
    every module report into one global stack / node map.
    """
    module_stack = []          # current chain of active module full-names
    api_parent_node = ""       # full-name of the module an API call runs under
    module_node = {}           # full-name -> parent full-name (or None for roots)
    current_module_name = ""

    def __init__(self, scope):
        # Only a ModuleRangeScope can be driven by module begin/end events;
        # any other scope type is ignored here.
        self.scope = scope if isinstance(scope, ModuleRangeScope) else None
        # BackwardHook may hand back views of its inputs; clone them so later
        # in-place modification by the wrapped module cannot corrupt dumps.
        BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook)
        BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook)
        self.module_count = {}

    @staticmethod
    def clone_return_value(func):
        """Decorator: clone tensor results, including tensors inside tuples.

        Fix: the original tuple branch called ``.clone()`` on every element,
        which raises AttributeError for tuples containing non-tensor items;
        now only tensor elements are cloned and others pass through.
        """
        @wraps(func)
        def clone_return_value_func(*args, **kwargs):
            result = func(*args, **kwargs)
            if isinstance(result, torch.Tensor):
                return result.clone()
            if isinstance(result, tuple):
                return tuple(r.clone() if isinstance(r, torch.Tensor) else r for r in result)
            return result

        return clone_return_value_func

    def node_hook(self, name_prefix, start_or_stop, **kwargs):
        """Build a hook that maintains the module stack.

        Returns the "enter" hook when ``start_or_stop`` contains "start",
        otherwise the "exit" hook.
        """

        def pre_hook(module, input, output=None):
            # module_count_func cannot fail for a fresh name, so no exception
            # guard is needed here (the original had a dead except IndexError).
            index = self.module_count_func(name_prefix)
            module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index)
            ModuleProcesser.module_node[full_name] = self.module_stack[-1] if self.module_stack else None
            ModuleProcesser.module_stack.append(full_name)
            ModuleProcesser.api_parent_node = self.module_stack[-1]
            if self.scope:
                self.scope.begin_module(full_name)

        def end_hook(module, input, output=None):
            if self.module_stack:
                ModuleProcesser.module_stack.pop()
            ModuleProcesser.api_parent_node = self.module_stack[-1] if self.module_stack else None
            if self.scope:
                self.scope.end_module(module.mindstudio_reserved_name)

        return pre_hook if "start" in start_or_stop else end_hook

    def module_count_func(self, module_name):
        """Return 0 on the first call for ``module_name``, then 1, 2, ..."""
        self.module_count[module_name] = self.module_count.get(module_name, -1) + 1
        return self.module_count[module_name]
0000000000000000000000000000000000000000..033ef1113d84375d69a125ca0fb4d5ac18821a02 --- /dev/null +++ b/debug/accuracy_tools/msacc/requirements.txt @@ -0,0 +1,2 @@ +torch +pyyaml diff --git a/debug/accuracy_tools/pytorch/service.py b/debug/accuracy_tools/pytorch/service.py new file mode 100644 index 0000000000000000000000000000000000000000..944bdf22de2be27769e7eee9261e7361baa692fa --- /dev/null +++ b/debug/accuracy_tools/pytorch/service.py @@ -0,0 +1,165 @@ +import os +from pathlib import Path +import functools +import torch +from .functional import build_repair, build_collect_data, build_step_post_process +from .functional.scope import BaseScope +from .common.utils import get_rank_if_initialized, is_gpu, Const +from .common.file_check import FileChecker, FileCheckConst, check_path_before_create +from .common import print_info_log_rank_0 +from .common.exceptions import CalibratorException +from .hook_module.api_registry import api_register +from .hook_module import remove_dropout +from .functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from calibrator.pytorch.module_processer import ModuleProcesser + + +class Service: + make_dir_flag = True + REGISTER_HOOK_KWARGS = ["overflow_nums", "dump_mode", "dump_config"] + + def __init__(self, config): + self.model = None + self.config = config + self.collect_data = build_collect_data(config) + self.module_processor = ModuleProcesser(self.collect_data.scope) + self.repair = build_repair(config) + self.step_post_process = build_step_post_process(config) + self.switch = False + self.current_iter = 0 + self.first_start = True + self.current_rank = None + self.first_touch_dir = True + + def build_hook(self, module_type, name): + def pre_hook(repair, name_template, module, args, kwargs): + if repair: + args, kwargs = repair.convert(name_template, module_type, args, kwargs) + return args, kwargs + + def forward_hook(repair, name_template, module, args, kwargs, output): + nonlocal
class Service:
    """Drives the accuracy dump: registers forward/backward hooks on the model,
    manages per-step / per-rank dump directories, and forwards captured data to
    the collector built from ``config``.
    """
    make_dir_flag = True
    REGISTER_HOOK_KWARGS = ["overflow_nums", "dump_mode", "dump_config"]

    def __init__(self, config):
        self.model = None
        self.config = config
        self.collect_data = build_collect_data(config)
        self.module_processor = ModuleProcesser(self.collect_data.scope)
        self.repair = build_repair(config)
        self.step_post_process = build_step_post_process(config)
        self.switch = False        # hooks only record while True
        self.current_iter = 0
        self.first_start = True    # hooks are registered once, lazily, in start()
        self.current_rank = None
        self.first_touch_dir = True

    def build_hook(self, module_type, name):
        """Return (pre_forward, forward, backward) hooks for one module or API.

        ``module_type`` distinguishes nn.Module hooks (name template carries a
        per-instance index placeholder) from API hooks.
        """
        pid = os.getpid()  # captured once; collector drops events from forked children
        if module_type == BaseScope.Module_Type_Module:
            forward_name_template = name + Const.SEP + "{}" + Const.SEP + "forward"
            backward_name_template = name + Const.SEP + "{}" + Const.SEP + "backward"
        else:
            forward_name_template = name + "forward"
            backward_name_template = name + "backward"

        def pre_hook(repair, name_template, module, args, kwargs):
            # Optionally rewrite inputs (device/dtype repair) before forward.
            if repair:
                args, kwargs = repair.convert(name_template, module_type, args, kwargs)
            return args, kwargs

        def forward_hook(repair, name_template, module, args, kwargs, output):
            if not self.switch:
                return
            if self.collect_data:
                module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output)
                self.collect_data(name_template, module_type, module, pid, module_input_output)
            if repair:
                output = repair.invert(name_template, module_type, output)
            return output

        def backward_hook(repair, name_template, module, grad_input, grad_output):
            if not self.switch:
                return
            if self.collect_data:
                module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
                self.collect_data(name_template, module_type, module, pid, module_input_output)

        # Distinct names for the bound partials: the original rebound
        # `forward_hook` over the inner function, shadowing it.
        pre_forward_fn = functools.partial(pre_hook, self.repair, forward_name_template)
        forward_fn = functools.partial(forward_hook, self.repair, forward_name_template)
        backward_fn = functools.partial(backward_hook, None, backward_name_template)
        return pre_forward_fn, forward_fn, backward_fn

    def step(self):
        """Advance to the next training iteration."""
        self.current_iter += 1
        if self.step_post_process:
            self.step_post_process()

    @staticmethod
    def check_model_valid(model):
        """Return ``model`` if it is an nn.Module, else raise CalibratorException."""
        if isinstance(model, torch.nn.Module):
            return model
        raise CalibratorException(CalibratorException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。")

    def start(self, model):
        """Enable dumping for the current iteration, honoring step/rank filters.

        Raises once the configured max step is passed so training stops early.
        """
        if self.config.step and self.current_iter > max(self.config.step):
            self.stop()
            raise Exception("ptdbg: exit after iteration {}".format(max(self.config.step)))
        if self.config.step and self.current_iter not in self.config.step:
            return
        self.model = self.check_model_valid(model)
        if self.first_start:
            self.current_rank = get_rank_if_initialized()
            if self.config.rank and self.current_rank not in self.config.rank:
                return
            self.register_hook_new()
            self.first_start = False
        self.switch = True
        self.create_dirs()
        print_info_log_rank_0(f"Dump switch is turned on at step {self.current_iter}. "
                              f"Dump data will be saved in {self.dump_iter_dir}.")

    def stop(self):
        """Disable dumping and flush collected data to the json files."""
        self.switch = False
        self.collect_data.write_json()
        # NOTE(review): multi-threaded dump flush is still a TODO upstream.
        if not is_gpu:
            self.collect_data.generate_compare_script()

    def create_dirs(self):
        """Create step{N}/rank{R} dump dirs and register the file paths with the collector."""
        check_path_before_create(self.config.dump_path)
        if not os.path.exists(self.config.dump_path):
            Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True)
        file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR)
        file_check.common_check()
        self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}")
        dump_dir = os.path.join(self.dump_iter_dir, f"rank{self.current_rank}")
        if not os.path.exists(dump_dir):
            Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True)
        if self.config.task in self.collect_data.tasks_need_tensor_data:
            dump_data_dir = os.path.join(dump_dir, "dump_tensor_data")
            Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True)
        else:
            dump_data_dir = None

        dump_file_path = os.path.join(dump_dir, "dump.json")
        stack_file_path = os.path.join(dump_dir, "stack.json")
        construct_file_path = os.path.join(dump_dir, "construct.json")
        self.collect_data.update_dump_paths(dump_file_path, stack_file_path, construct_file_path, dump_data_dir)

    def register_hook_new(self):
        """Mount forward/backward hooks per the configured task and level."""
        hook_name = self.config.task

        if "overflow_check" in hook_name and not is_gpu:
            pass  # TODO: clear NPU overflow state before checking

        print_info_log_rank_0("The {} hook function is successfully mounted to the model.".format(hook_name))
        if self.config.level in ["L1", "L2"]:
            print_info_log_rank_0("The init dump mode is enabled, and the module dump function will not be available")
            # Explicit raise instead of `assert`: asserts are stripped under -O.
            if not isinstance(self.model, torch.nn.Module):
                raise CalibratorException(CalibratorException.INVALID_PARAM_ERROR,
                                          "The argument model must be an object of torch.nn.Module")
            for name, module in self.model.named_modules():
                if module == self.model:
                    continue  # skip the root module itself
                prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP + \
                    module.__class__.__name__ + Const.SEP

                pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix)
                module.register_forward_hook(forward_hook, with_kwargs=True)
                module.register_full_backward_hook(backward_hook)

                # Separate node hooks maintain the module stack / parent graph.
                module.register_forward_pre_hook(self.module_processor.node_hook(prefix + "forward", "start"))
                module.register_forward_hook(self.module_processor.node_hook(prefix + "forward", "stop"))
                module.register_full_backward_pre_hook(self.module_processor.node_hook(prefix + "backward", "start"))
                module.register_full_backward_hook(self.module_processor.node_hook(prefix + "backward", "stop"))

        if self.config.level in ["L2", "API"]:
            api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
            api_register.api_modularity()

        if "acc_cmp_dump" in hook_name:
            remove_dropout()
import os
from unittest import TestCase
from calibrator.common import parse_json_info_forward_backward


def _tensor_info(shape, max_v, min_v, mean, norm):
    """Build the summary dict the dump json stores for one torch.Tensor."""
    return {
        "type": "torch.Tensor", "shape": shape, "dtype": "torch.float32",
        "Max": max_v, "Min": min_v, "Mean": mean, "Norm": norm, "requires_grad": True,
    }


class TestParseJsonIntoForwardAndBackward(TestCase):
    def test_forward_backward_module_real_data(self):
        dump_json_path = os.path.join(os.path.dirname(__file__), '..', 'test_data', 'dump.json')
        forward, backward, real_data_path = parse_json_info_forward_backward(dump_json_path)

        lhs = _tensor_info([3, 5], 5.0, -3.0, 1.5, 200)
        rhs = _tensor_info([5, 3], 3.3, -3.1, 0.5, 2020)
        out = _tensor_info([3, 5], 3.3, -3.1, 0.5, 202)

        self.assertEqual(forward, {"Torch.matmul.0": {
            "input_args": [lhs, rhs], "input_kwargs": {}, "output": [out]}})
        self.assertEqual(backward, {"Torch.matmul.0": {
            "input_args": [out], "output": [lhs, rhs]}})
        self.assertIsNone(real_data_path)
from unittest import TestCase
import torch
from calibrator.common import recursive_apply_transform


class TestRecursiveApplyTransform(TestCase):
    """Checks recursive_apply_transform over scalars, tuples, lists and dicts."""

    def setUp(self) -> None:
        def transform(x, stack=None):
            return x ** 2

        self.transform = transform
        self.index = 0
        self.nested_data = {"a": [0, 1], "b": (2, 3)}
        # Expected (value, key-path) pairs in visitation order.
        self.expected_recursive_stack = [["a", "0"], ["a", "1"], ["b", "0"], ["b", "1"]]
        self.expected_data_item = [0, 1, 2, 3]

    def check_transform_inputs(self, x, stack):
        self.assertEqual(x, self.expected_data_item[self.index])
        self.assertEqual(stack, self.expected_recursive_stack[self.index])
        self.index += 1

    def testRecursiveStack(self):
        recursive_apply_transform(self.nested_data, self.check_transform_inputs)

    def testElement(self):
        self.assertEqual(recursive_apply_transform(torch.tensor(3), self.transform), torch.tensor(9))

    def testTupleAndList(self):
        for ctor in (tuple, list):
            arg = ctor((torch.tensor(1), torch.tensor(4)))
            expected = ctor((torch.tensor(1), torch.tensor(16)))
            self.assertEqual(recursive_apply_transform(arg, self.transform), expected)

    def testDict(self):
        arg = {"a": torch.tensor(4), "b": torch.tensor(10)}
        expected = {"a": torch.tensor(16), "b": torch.tensor(100)}
        self.assertEqual(recursive_apply_transform(arg, self.transform), expected)
import os
from unittest import TestCase
from unittest.mock import Mock, patch
from calibrator.pytorch.functional.data_collector import build_collect_data
import calibrator.functional.data_processor as data_processor
from calibrator.functional.data_processor import \
    ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs


class TestDataCollector(TestCase):
    """Verifies what the collector feeds into DataProcessor.analyze_element."""

    @patch.object(data_processor.DataProcessor, "analyze_element")
    def test_data_collector_should_call_3_times_analyze_element_when_forward(self, mock_analyze_element: Mock):
        collector = build_collect_data(Mock(task='summary', scope=['forward', 'backward'], api_list=['a']))

        mock_args, mock_kwargs, mock_output = Mock(), Mock(), Mock()
        io_pack = ModuleForwardInputsOutputs(args=mock_args, kwargs=mock_kwargs, output=mock_output)
        collector("forward", Mock(), Mock(), os.getpid(), io_pack)
        # args, kwargs and output each go through analyze_element once.
        mock_analyze_element.assert_any_call((mock_args, ))
        mock_analyze_element.assert_any_call(mock_kwargs)
        mock_analyze_element.assert_any_call((mock_output, ))

    @patch.object(data_processor.DataProcessor, "analyze_element")
    def test_data_collector_should_call_twice_analyze_element_when_backward(self, mock_analyze_element: Mock):
        collector = build_collect_data(Mock(task='summary', scope=['backward'], api_list=[]))

        grad_input, grad_output = Mock(), Mock()
        io_pack = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output)
        collector("backward", Mock(), Mock(), os.getpid(), io_pack)
        mock_analyze_element.assert_any_call((grad_input, ))
        mock_analyze_element.assert_any_call((grad_output, ))
import zlib
from unittest import TestCase
from unittest.mock import Mock, patch
import torch
from calibrator.functional.data_processor import FullTensorDataProcessor, DataProcessor, \
    ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs


class TestFullTensorDataProcessor(TestCase):
    @patch.object(torch, 'save')
    def test_full_tensor_data_collector_should_save_data(self, mock_save):
        # Persisting the grads is delegated to torch.save.
        grads = ModuleBackwardInputsOutputs(torch.randn(10), torch.randn(10))
        proc = FullTensorDataProcessor(Mock(md5=False), Mock(dump_tensor_data_dir='111'))
        proc.analyze_backward('backward', grads)
        mock_save.assert_called()


class TestDataProcessor(TestCase):
    def setUp(self):
        self.data_processor = DataProcessor(Mock(md5=False), Mock(dump_tensor_data_dir='111'))

    def test_analyze_element(self):
        expected = [{'type': 'int', 'value': v} for v in (1, 2, 3)]
        self.assertEqual(self.data_processor.analyze_element([1, 2, 3]), expected)

    def test_analyze_tensor(self):
        tensor = torch.tensor([1, 2, 3], dtype=torch.float32, requires_grad=True)
        result = self.data_processor._analyze_tensor(tensor, "")
        self.assertEqual(result.get('type'), 'torch.Tensor')
        self.assertTrue(result.get('requires_grad'))

    def test_analyze_builtin(self):
        self.assertEqual(self.data_processor._analyze_builtin(slice(1, 10, 2)),
                         {'type': 'slice', 'value': [1, 10, 2]})

    def test_analyze_element_when_device_in_kwargs(self):
        result = self.data_processor.analyze_element({'device': torch.device('cuda:0')})
        self.assertEqual(result, {'device': {'type': 'torch.device', 'value': 'cuda:0'}})

    def test_analyze_element_when_dtype_in_kwargs(self):
        result = self.data_processor.analyze_element({"dtype": torch.float32})
        self.assertEqual(result, {"dtype": {'type': 'torch.dtype', 'value': 'torch.float32'}})

    def test_analyze_tensor_when_md5_True(self):
        self.data_processor.task_config.md5 = True
        tensor = torch.ones(10)
        expected_md5 = f"{zlib.crc32(tensor.numpy().tobytes()):08x}"
        result = self.data_processor._analyze_tensor(tensor, [])
        self.assertEqual(result['md5'], expected_md5)
        self.data_processor.task_config.md5 = False
repair_type, repair_scope=[], repair_api_str=''): +# self.repair_type = repair_type +# self.repair_scope = repair_scope +# self.repair_api_str = repair_api_str +# +# +# torch.distributed.get_rank = Mock(return_value=0) +# +# +# class TestNoRepair(TestCase): +# def testNoRepair(self): +# config = Config(None) +# self.assertIsNone(build_repair(config)) +# +# def testMisc(self): +# config = Config('na') +# self.assertRaises(RepairException, build_repair, config) +# +# +# class TestRepairCpuWithScope(TestCase): +# def setUp(self): +# config = Config("cpu", ['a']) +# self.repair = build_repair(config) +# torch_npu.npu.set_device(0) +# +# def testConvert(self): +# arg = torch.randn(10).npu() +# args = (arg,) +# kwargs = {'k': arg} +# args, kwargs = self.repair.convert('a', BaseScope.Module_Type_API, args, kwargs) +# self.assertEqual(args[0].device, torch.device("cpu")) +# self.assertEqual(kwargs['k'].device, torch.device("cpu")) +# self.assertEqual(self.repair.saved, torch.device("npu:0")) +# self.assertEqual(self.repair.towards, torch.device("cpu")) +# +# def testInvert(self): +# self.repair.saved = torch.device("npu:0") +# out = torch.randn(10) +# out = self.repair.invert("a", BaseScope.Module_Type_API, out) +# self.assertEqual(out.device, torch.device("npu:0")) +# +# def testModule(self): +# out = torch.randn(10) +# out = self.repair.invert('a', BaseScope.Module_Type_Module, out) +# self.assertEqual(out.device, torch.device("cpu")) +# +# +# class TestRepairCpuWithApiStr(TestCase): +# def setUp(self): +# config = Config("cpu", [], "a") +# self.repair = build_repair(config) +# torch_npu.npu.set_device(0) +# +# def testConvert(self): +# arg = torch.randn(10).npu() +# args = (arg,) +# kwargs = {'k': arg} +# args, kwargs = self.repair.convert('ab', BaseScope.Module_Type_API, args, kwargs) +# self.assertEqual(args[0].device, torch.device("cpu")) +# self.assertEqual(kwargs['k'].device, torch.device("cpu")) +# self.assertEqual(self.repair.saved, torch.device("npu:0")) +# 
self.assertEqual(self.repair.towards, torch.device("cpu")) +# +# def testInvert(self): +# self.repair.saved = torch.device("npu:0") +# out = torch.randn(10) +# out = self.repair.invert('ab', BaseScope.Module_Type_API, out) +# self.assertEqual(out.device, torch.device("npu:0")) +# +# def testModule(self): +# self.repair.saved = torch.device("npu:0") +# out = torch.randn(10) +# out = self.repair.invert('ab', BaseScope.Module_Type_Module, out) +# self.assertEqual(out.device, torch.device("cpu")) +# +# +# class TestRepairDtypeWithScope(TestCase): +# def setUp(self): +# config = Config("raise", ['a']) +# self.repair = build_repair(config) +# +# def testConvert(self): +# arg = torch.randn(10, dtype=torch.float16) +# args = (arg,) +# kwargs = {"k": arg} +# args, kwargs = self.repair.convert("a", BaseScope.Module_Type_API, args, kwargs) +# self.assertEqual(args[0].dtype, torch.float32) +# self.assertEqual(kwargs['k'].dtype, torch.float32) +# self.assertEqual(self.repair.saved, torch.float16) +# self.assertEqual(self.repair.towards, torch.float32) +# +# def testInvert(self): +# self.repair.saved = torch.float16 +# out = torch.randn(10) +# out = self.repair.invert('a', BaseScope.Module_Type_API, out) +# self.assertEqual(out.dtype, torch.float16) +# +# def testModule(self): +# self.repair.saved = torch.float16 +# out = torch.randn(10) +# out = self.repair.invert('a', BaseScope.Module_Type_Module, out) +# self.assertEqual(out.dtype, torch.float32) +# +# +# class TestRepairDtypeWithApiStr(TestCase): +# def setUp(self): +# config = Config("raise", [], 'a') +# self.repair = build_repair(config) +# +# def testConvert(self): +# arg = torch.randn(10, dtype=torch.float16) +# args = (arg,) +# kwargs = {"k": arg} +# args, kwargs = self.repair.convert("a", BaseScope.Module_Type_API, args, kwargs) +# self.assertEqual(args[0].dtype, torch.float32) +# self.assertEqual(kwargs['k'].dtype, torch.float32) +# self.assertEqual(self.repair.saved, torch.float16) +# 
"""Unit tests for calibrator.functional.scope factories.

Covers ListScope, APIRangeScope and ModuleRangeScope built through
build_scope, in every supported configuration (string scope, list scope,
api_list, and combinations).

NOTE(review): the assertion patterns below imply the range scopes are
stateful — check() appears to open the range on the first name and close
it on the last — TODO confirm against the scope implementation.
"""
from unittest import TestCase
import pytest
import torch
from calibrator.common.exceptions import ScopeException
from calibrator.functional.scope import build_scope, ListScope, ModuleRangeScope, APIRangeScope


class TestNoScope(TestCase):
    """build_scope returns None when neither a scope nor an api_list is given."""

    def testNoScope(self):
        self.assertIsNone(build_scope(ListScope, [], ""))
        self.assertIsNone(build_scope(APIRangeScope, [], ""))
        self.assertIsNone(build_scope(ListScope))
        self.assertIsNone(build_scope(APIRangeScope))


class TestListScopeWithListScope(TestCase):
    """ListScope built from a list matches exactly the listed names."""

    def setUp(self):
        self.scope = build_scope(ListScope, ['a', 'b', 'c'])

    def testCheck(self):
        self.assertTrue(self.scope.check('a'))
        self.assertTrue(self.scope.check('b'))
        self.assertTrue(self.scope.check('c'))
        self.assertFalse(self.scope.check('d'))


class TestListScopeWithStrScope(TestCase):
    """ListScope built from a single string behaves like a one-element list."""

    def setUp(self):
        self.scope = build_scope(ListScope, 'a')

    def testCheck(self):
        self.assertTrue(self.scope.check('a'))
        self.assertFalse(self.scope.check('d'))


class TestListScopeWithApiStr(TestCase):
    """ListScope built from api_list matches by name prefix ('a' matches 'ab')."""

    def setUp(self):
        self.scope = build_scope(ListScope, api_list=['a'])

    def testCheck(self):
        self.assertTrue(self.scope.check('ab'))
        self.assertFalse(self.scope.check('db'))


class TestListScopeMisc(TestCase):
    """Invalid scope / api_list contents must raise ScopeException."""

    def testMisc(self):
        with self.assertRaises(ScopeException):
            scope = build_scope(ListScope, 'a', 'b')
        with self.assertRaises(ScopeException):
            scope = build_scope(ListScope, [torch.tensor(1)])
        with self.assertRaises(ScopeException):
            scope = build_scope(ListScope, torch.nn.ReLU())
        with self.assertRaises(ScopeException):
            scope = build_scope(ListScope, api_list=[torch.tensor(1)])


class TestAPIRangeScopeWithStrScope(TestCase):
    """A single-name range degenerates to matching only that name."""

    def setUp(self):
        self.scope = build_scope(APIRangeScope, 'a')

    def testCheck(self):
        self.assertTrue(self.scope.check('a'))
        self.assertFalse(self.scope.check('c'))


class TestAPIRangeScopeWith2ListScope(TestCase):
    """['a', 'b'] opens the range at 'a' and closes it at 'b' (stateful)."""

    def setUp(self):
        self.scope = build_scope(APIRangeScope, ['a', 'b'])

    def testCheck(self):
        self.assertTrue(self.scope.check('a'))   # 'a' opens the range
        self.assertTrue(self.scope.check('c'))   # inside the range
        self.assertTrue(self.scope.check('b'))   # 'b' closes the range
        self.assertFalse(self.scope.check('c'))  # outside again


# BUGFIX: this class was previously also named TestAPIRangeScopeWith2ListScope,
# shadowing the class above so its tests never ran.
class TestAPIRangeScopeWithApiList(TestCase):
    """api_list alone filters by API-name prefix without a range."""

    def setUp(self):
        self.scope = build_scope(APIRangeScope, api_list=['a'])

    def testCheck(self):
        self.assertTrue(self.scope.check('ab'))
        self.assertFalse(self.scope.check('c'))


class TestAPIRangeScopeWith2ListScopeAndApiStr(TestCase):
    """Range and api_list combined: both filters must agree."""

    def setUp(self):
        self.scope1 = build_scope(APIRangeScope, ['a', 'b'], api_list=['c'])
        self.scope2 = build_scope(APIRangeScope, ['ac', 'bc'], api_list=['c'])

    def testCheckWhenScopeNotContainApiStr(self):
        self.assertFalse(self.scope1.check('a'))
        self.assertTrue(self.scope1.check('cd'))
        self.assertFalse(self.scope1.check('dd'))
        self.assertFalse(self.scope1.check('b'))
        self.assertFalse(self.scope1.check('cd'))

    def testCheckWhenScopeContainApiStr(self):
        self.assertTrue(self.scope2.check('ac'))
        self.assertTrue(self.scope2.check('cd'))
        self.assertFalse(self.scope2.check('dd'))
        self.assertTrue(self.scope2.check('bc'))
        self.assertFalse(self.scope2.check('cd'))


class TestModuleRangeScopeWithStrScope(TestCase):
    """ModuleRangeScope is driven by begin_module/end_module callbacks."""

    def setUp(self):
        self.scope = build_scope(ModuleRangeScope, 'a')

    def testCheck(self):
        self.assertFalse(self.scope.check('a'))
        self.scope.begin_module('a')
        self.assertTrue(self.scope.check('a'))
        self.assertTrue(self.scope.check('b'))
        self.assertTrue(self.scope.check('c'))
        self.scope.end_module('a')
        self.assertFalse(self.scope.check('a'))
        self.assertFalse(self.scope.check('b'))
        self.assertFalse(self.scope.check('c'))


class TestModuleRangeScopeWith2ListScope(TestCase):
    """['a', 'b']: active from begin_module('a') until end_module('b')."""

    def testCheck(self):
        self.assertFalse(self.scope.check('a'))
        self.scope.begin_module('a')
        self.assertTrue(self.scope.check('c'))
        self.scope.end_module('a')
        self.assertTrue(self.scope.check('c'))
        self.scope.begin_module('b')
        self.assertTrue(self.scope.check('c'))
        self.scope.end_module('b')
        self.assertFalse(self.scope.check('c'))

    def setUp(self):
        self.scope = build_scope(ModuleRangeScope, ['a', 'b'])


class TestModuleRangeScopeWith2ListScopeAndApiStr(TestCase):
    """Module range combined with api_list: both filters must agree."""

    def setUp(self):
        self.scope = build_scope(ModuleRangeScope, ['a', 'b'], api_list=['c'])

    def testCheck(self):
        self.assertFalse(self.scope.check('a'))
        self.scope.begin_module('a')
        self.assertTrue(self.scope.check('c'))
        self.assertFalse(self.scope.check('d'))
        self.scope.end_module('a')
        self.assertTrue(self.scope.check('c'))
        self.assertFalse(self.scope.check('d'))
        self.scope.begin_module('b')
        self.assertTrue(self.scope.check('c'))
        self.assertFalse(self.scope.check('d'))
        self.scope.end_module('b')
        self.assertFalse(self.scope.check('c'))
        self.assertFalse(self.scope.check('d'))


class TestRangeScopeWithOnlyApiList(TestCase):
    def test_range_scope_with_only_api_list(self):
        # Smoke test: building with only api_list must not raise.
        scope = build_scope(APIRangeScope, api_list=['a'])
new file mode 100644 index 0000000000000000000000000000000000000000..758064d4b9da62b904ea1d3d2eef949251e3f324 --- /dev/null +++ b/debug/accuracy_tools/test/functional/test_step_post_process.py @@ -0,0 +1,18 @@ +from unittest import TestCase +from calibrator.pytorch.functional.step_post_process import build_step_post_process +from calibrator.common.exceptions import StepException + + +class Config(): + def __init__(self, on_step_end): + self.on_step_end = on_step_end + + +class TestNoStep(TestCase): + def testNoStep(self): + config = Config(None) + self.assertIsNone(build_step_post_process(config)) + + def testMisc(self): + config = Config('na') + self.assertRaises(StepException, build_step_post_process, config) \ No newline at end of file diff --git a/debug/accuracy_tools/test/hook_module/test_wrap_torch.py b/debug/accuracy_tools/test/hook_module/test_wrap_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..f700b8a0b44f7954dbaba9e17a4d4727f01f3922 --- /dev/null +++ b/debug/accuracy_tools/test/hook_module/test_wrap_torch.py @@ -0,0 +1,28 @@ +from unittest import TestCase +from unittest.mock import Mock, patch, DEFAULT +import torch +from calibrator.hook_module.api_registry import api_register + + +def forward_pre_hook(module, args, kwargs): + return args, kwargs + + +def forward_hook(module, args, kwargs, output): + return output + + +def backward_hook(grad_input, grad_output): + return grad_input + + +class TestInitializeHook(TestCase): + # @patch.multiple('test_wrap_torch', forward_pre_hook=DEFAULT, forward_hook=DEFAULT, backward_hook=DEFAULT) + def test_initialize_hook_should_call_hook(self): + build_hook = Mock(return_value=(forward_pre_hook, forward_hook, backward_hook)) + api_register.initialize_hook(build_hook) + api_register.api_modularity() + + x1 = torch.randn(10) + x2 = torch.randn(20) + y = torch.concat((x1, x2)) diff --git a/debug/accuracy_tools/test/test_data/dump.json b/debug/accuracy_tools/test/test_data/dump.json new file 
mode 100644 index 0000000000000000000000000000000000000000..19cd7f2daaaf9dfbf3b4f4c74f15f9009a91d14b --- /dev/null +++ b/debug/accuracy_tools/test/test_data/dump.json @@ -0,0 +1,45 @@ +{ + "level": "API", + "task": "md5", + "data_path": null, + "data":{ + "Torch.matmul.0.forward":{ + "input_args":[ + { + "type": "torch.Tensor", "shape": [3, 5], "dtype": "torch.float32", "Max": 5.0, + "Min": -3.0, "Mean": 1.5, "Norm": 200, "requires_grad": true + }, + { + "type": "torch.Tensor", "shape": [5, 3], "dtype": "torch.float32", "Max": 3.3, + "Min": -3.1, "Mean": 0.5, "Norm": 2020, "requires_grad": true + } + ], + "input_kwargs": {}, + "output": [ + { + "type": "torch.Tensor", "shape": [3, 5], "dtype": "torch.float32", "Max": 3.3, + "Min": -3.1, "Mean": 0.5, "Norm": 202, "requires_grad": true + } + ] + }, + "Torch.matmul.0.backward": { + "input_args": [ + { + "type": "torch.Tensor", "shape": [3, 5], "dtype": "torch.float32", "Max": 3.3, + "Min": -3.1, "Mean": 0.5, "Norm": 202, "requires_grad": true + } + ], + "output": [ + { + "type": "torch.Tensor", "shape": [3, 5], "dtype": "torch.float32", "Max": 5.0, + "Min": -3.0, "Mean": 1.5, "Norm": 200, "requires_grad": true + }, + { + "type": "torch.Tensor", "shape": [5, 3], "dtype": "torch.float32", "Max": 3.3, + "Min": -3.1, "Mean": 0.5, "Norm": 2020, "requires_grad": true + } + ] + + } + } +} \ No newline at end of file diff --git a/debug/accuracy_tools/test/test_service.py b/debug/accuracy_tools/test/test_service.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2e3208f43b2570f86f35824c3e03b42543a247 --- /dev/null +++ b/debug/accuracy_tools/test/test_service.py @@ -0,0 +1,62 @@ +from unittest import TestCase +from unittest.mock import Mock, patch, DEFAULT +import torch +from pytorch.service import Service +from calibrator.hook_module.api_registry import api_register +from calibrator.debugger.debugger_config import DebuggerConfig +from calibrator.pytorch.functional.data_collector import DataWriter + + 
+class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(self.linear(x)) + + +class SampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.model = MyModel() + + def forward(self, x): + return self.linear(self.model(x)) + + +class TestService(TestCase): + def setUp(self): + config = DebuggerConfig(dump_path="./dump_path", task="summary", + task_config=Mock(md5=False), level="API") + self.service_api = Service(Mock(), config) + config = DebuggerConfig(dump_path="./dump_path", task="summary", + task_config=Mock(md5=False), level="L1") + self.model = SampleModel() + self.service_L1 = Service(self.model, config) + + def test_start_when_level_is_L1(self): + x = torch.randn(10, 10) + self.service_L1.start() + y = self.model(x).sum() + # self.service_L1.stop() + + @patch.multiple(api_register, initialize_hook=DEFAULT, api_modularity=DEFAULT) + def test_start_when_level_is_API(self, initialize_hook: Mock, api_modularity: Mock): + self.service_api.start() + # self.assertIsInstance(torch.cat, ) + initialize_hook.assert_called_once() + api_modularity.assert_called_once() + + @patch.multiple(DataWriter, write_json=DEFAULT, initialize_json_file=DEFAULT) + def test_stop_when_level_is_API(self, write_json, initialize_json_file): + dump_file_path = Mock() + other_file_path = Mock() + self.service_api.collect_data.update_dump_paths(dump_file_path, other_file_path, other_file_path, other_file_path) + self.service_api.stop() + initialize_json_file.assert_called_once() + write_json.assert_called_once() + + # def test_build_hook \ No newline at end of file diff --git a/profiler/advisor/advisor_backend/advice_factory/compute_advice_factory.py b/profiler/advisor/advisor_backend/advice_factory/compute_advice_factory.py index 
2b6e5270f278276521b20eae225b0c004a77a2f7..336bef7dd8553eb82586d52260443a7d01e84ab0 100644 --- a/profiler/advisor/advisor_backend/advice_factory/compute_advice_factory.py +++ b/profiler/advisor/advisor_backend/advice_factory/compute_advice_factory.py @@ -15,11 +15,13 @@ from common_func_advisor.constant import Constant from advice_factory.advice_factory import AdviceFactory from compute_advice.npu_fused_advice import NpuFusedAdvice +from compute_advice.npu_slow_advice import NpuSlowAdvice class ComputeAdviceFactory(AdviceFactory): ADVICE_LIB = { Constant.NPU_FUSED: NpuFusedAdvice, + Constant.NPU_SLOW: NpuSlowAdvice, } def __init__(self, collection_path: str): diff --git a/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py b/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py index 8cd9acab4c43cc5eff89b6e8c3bdd3ab4a72fc4b..e9be4675963a9cd48da3b4cd91ee646f8e82468b 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py +++ b/profiler/advisor/advisor_backend/cluster_advice/cluster_advice_base.py @@ -46,7 +46,8 @@ class ClusterAdviceBase(AdviceBase): def cluster_analyze(self): parameter = { - Constant.COLLECTION_PATH: self.collection_path + Constant.COLLECTION_PATH: self.collection_path, + Constant.ANALYSIS_MODE: "all" } try: Interface(parameter).run() diff --git a/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py b/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py index e2ca914a79451b5bf5fdbcbba14e1f2606cc7cd5..6fa83c765f5fe1f4ac20dcc62895fe0450e338ce 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py +++ b/profiler/advisor/advisor_backend/cluster_advice/kernel_cluster_advice.py @@ -12,7 +12,7 @@ class KernelClusterAdvice(ClusterAdviceBase): COLUMNS_TO_CAL = ["Duration(us)"] CAL_FUN = ['mean', 'var', 'max', 'min', 'count', 'sum'] - def __init__(self, collection_path: str): + def __init__(self, collection_path: str, kwargs: 
dict = None): super().__init__(collection_path) self.all_kernel_data = pd.DataFrame() diff --git a/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py b/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py index e350e08f39c087198962f0317926787acbceb406..f8a625242f3939602cbb7b8391cd8062e21fe01b 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py +++ b/profiler/advisor/advisor_backend/cluster_advice/slow_link_advice.py @@ -33,7 +33,7 @@ class SlowLinkAdvice(ClusterAdviceBase): SDMA = "SDMA" RDMA = "RDMA" - def __init__(self, collection_path: str): + def __init__(self, collection_path: str, kwargs: dict = None): super().__init__(collection_path) self.rank_bw_dict = defaultdict(lambda: { self.RDMA_TIME_MS: 0, diff --git a/profiler/advisor/advisor_backend/cluster_advice/slow_rank_advice.py b/profiler/advisor/advisor_backend/cluster_advice/slow_rank_advice.py index 516554583240878f211dba01d4f92c0a17a79cdc..4e789fb7fb688626df7e8f5b25b84e4955d6c2a3 100644 --- a/profiler/advisor/advisor_backend/cluster_advice/slow_rank_advice.py +++ b/profiler/advisor/advisor_backend/cluster_advice/slow_rank_advice.py @@ -26,7 +26,7 @@ class SlowRankAdvice(ClusterAdviceBase): RATIO_THRESHOLD = 0.05 BOTTLENECK_LIST = ['Computing', 'Communication', "Free"] - def __init__(self, collection_path: str): + def __init__(self, collection_path: str, kwargs: dict = None): super().__init__(collection_path) def load_step_time(self): diff --git a/profiler/advisor/advisor_backend/common_func_advisor/constant.py b/profiler/advisor/advisor_backend/common_func_advisor/constant.py index 34879db9f2c078854aab6cfe658fc46865b885df..46a7fb24c2dade75c157f18118f29233eb924b88 100644 --- a/profiler/advisor/advisor_backend/common_func_advisor/constant.py +++ b/profiler/advisor/advisor_backend/common_func_advisor/constant.py @@ -15,11 +15,104 @@ from enum import Enum +class CsvTitle: + MODEL_NAME = "Model Name" + MODEL_ID = "Model ID" + TASK_ID = "Task ID" + 
STREAM_ID = "Stream ID" + INFER_ID = "Infer ID" + TASK_START_TIME = "Task Start Time(us)" + TASK_WAIT_TIME = "Task Wait Time(us)" + BLOCK_DIM = "Block Dim" + MIX_BLOCK_DIM = "Mix Block Dim" + HF32_ELIGIBLE = "HF32 Eligible" + INPUT_SHAPES = "Input Shapes" + INPUT_DATA_TYPES = "Input Data Types" + INPUT_FORMATS = "Input Formats" + OUTPUT_SHAPES = "Output Shapes" + OUTPUT_DATA_TYPES = "Output Data Types" + OUTPUT_FORMATS = "Output Formats" + CONTEXT_ID = "Context ID" + AICORE_TIME = "aicore_time(us)" + AIC_TOTAL_CYCLES = "aic_total_cycles" + AIC_MAC_TIME = "aic_mac_time(us)" + AIC_MAC_RATIO = "aic_mac_ratio" + AIC_SCALAR_TIME = "aic_scalar_time(us)" + AIC_SCALAR_RATIO = "aic_scalar_ratio" + AIC_MTE1_TIME = "aic_mte1_time(us)" + AIC_MTE1_RATIO = "aic_mte1_ratio" + AIC_MTE2_TIME = "aic_mte2_time(us)" + AIC_MTE2_RATIO = "aic_mte2_ratio" + AIC_FIXPIPE_TIME = "aic_fixpipe_time(us)" + AIC_FIXPIPE_RATIO = "aic_fixpipe_ratio" + AIC_ICACHE_MISS_RATE = "aic_icache_miss_rate" + AIV_TIME = "aiv_time(us)" + AIV_TOTAL_CYCLES = "aiv_total_cycles" + AIV_VEC_TIME = "aiv_vec_time(us)" + AIV_VEC_RATIO = "aiv_vec_ratio" + AIV_SCALAR_TIME = "aiv_scalar_time(us)" + AIV_SCALAR_RATIO = "aiv_scalar_ratio" + AIV_MTE2_TIME = "aiv_mte2_time(us)" + AIV_MTE2_RATIO = "aiv_mte2_ratio" + AIV_MTE3_TIME = "aiv_mte3_time(us)" + AIV_MTE3_RATIO = "aiv_mte3_ratio" + AIV_ICACHE_MISS_RATE = "aiv_icache_miss_rate" + CUBE_UTILIZATION = "cube_utilization( %)" + TASK_DURATION_SUM = "Task Duration Sum(us)" + TASK_DURATION_MEAN = "Task Duration Mean(us)" + TASK_DURATION_STD = "Task Duration Std(us)" + TASK_DURATION_RATIO = "Task Duration Ratio(100%)" + SIZE = "size(MB)" + THROUGHPUT = "throughput(GB/s)" + COLOR = "color" + GAP = "Gap(us)" + DURATION_SUM = "Duration Sum(us)" + COUNT = "Count" + MAX_DURATION = "Max Duration(us)" + MIN_DURATION = "Min Duration(us)" + AVG_DURATION = "Avg Duration(us)" + DURATION_RATIO = "Duration Ratio" + INDEX = "Index" + + +# 定义CSV_TITILE_V1类,继承自CSV_TITILE类, 适配旧版csv +class 
CsvTitleV1(CsvTitle): + OP_NAME = "Op Name" + OP_TYPE = "OP Type" + TASK_TYPE = "Task Type" + TASK_DURATION = "Task Duration(us)" + + +# 定义CSV_TITILE_V2类,继承自CSV_TITILE类, 适配新版csv +class CsvTitleV2(CsvTitle): + OP_NAME = "Name" + OP_TYPE = "Type" + TASK_TYPE = "Accelerator Core" + TASK_DURATION = "Duration(us)" + + + class Constant: + DTYPE_SIZE_MAP = {"int8": 1, "uint8": 1, + "int16": 2, "uint16": 2, + "int32": 4, "uint32": 4, + "int64": 8, "uint64": 8, + "float16": 2, + "bfloat16": 2, + "bf16": 2, + "dt_bf16": 2, + "float32": 4, + "float": 4, + "float64": 8, + "complex64": 8, + "complex128": 16, + "bool": 1} + TP_THRESHOLD = 1150 MAX_INPUT_MODE_LEN = 30 MAX_INPUT_ADVICE_LEN = 30 SMALL_OP_DUR_RATIO = 0.2 SMALL_OP_NUM_RATIO = 0.2 + BYTE_UNIT_TRANS = 1024 + UNIT_TRANS = 1000 # mode list COMPUTE = "compute" @@ -35,6 +128,7 @@ class Constant: # compute NPU_FUSED = "npu_fused" + NPU_SLOW = "npu_slow" # timeline OPTIM = "optimizer" @@ -108,3 +202,24 @@ class Constant: ("Cast", "Mul", "MaskedFill", "SoftmaxV2", "Cast"): "torch_npu.npu_scaled_masked_softmax", ("Mul", "Slice", "Neg", "Slice", "ConcatD", "Mul"): "torch_npu.npu_rotary_mul", ("Cast", "Square", "ReduceMeanD", "Add", "Rsqrt", "Mul", "Cast", "Mul"): "torch_npu.npu_rms_norm"} + TITLE = CsvTitleV2 + + @classmethod + def update_title(cls): + cls.TITLE = CsvTitleV1 + + +class CoreType: + AIV = "AI_VECTOR_CORE" + AIC = "AI_CORE" + AICPU = "AI_CPU" + MIX_AIV = "MIX_AIV" + MIX_AIC = "MIX_AIC" + HCCL = "HCCL" + + +class PerfColor(Enum): + WHITE = 0 + GREEN = 1 + YELLOW = 2 + RED = 3 diff --git a/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py b/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py index 08ef02876561001b9721e365c9aa6934057674de..8171f06ee235fc02da715044b4d310087c36c102 100644 --- a/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py +++ b/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py @@ -12,13 +12,15 @@ # WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from abc import abstractmethod from dataclasses import dataclass from dataclasses import field from typing import Dict from typing import List +import pandas as pd + from common_func.file_manager import FileManager @@ -89,9 +91,34 @@ class TraceViewJson: self.cann_dur_events: Dict[str, DurationEvent] = dict() self.ascend_hardware_dur_events: Dict[str, DurationEvent] = dict() self.torch_2_npu_flow_events: Dict[str, FlowEvent] = dict() - traces = FileManager.read_json_file(path) self._load_obj(traces) + + def get_call_stack(self, data: pd.DataFrame, index_id: int, ts_col: str) -> str: + if ts_col not in data.columns.tolist(): + print("[ERROR] No {} col found in data columns.".format(ts_col)) + return "" + row = data.loc[index_id] + timestamp = row[ts_col] + flow_event = self.get_torch_2_npu_flow_event(timestamp) + if not flow_event.valid(): + print("[ERROR] Get flow event failed for pattern {}.".format(row['pattern'])) + return "" + flow_event_s_key = flow_event.s_point_ts + python_dur_events = self.get_python_dur_events_contain_ts(flow_event_s_key) + if not python_dur_events: + print("[ERROR] No python dur event found for pattern {}.".format(row['pattern'])) + return "" + # 保持新老版本callstack兼容性 + if python_dur_events[0].args.get("Call stack"): + # 旧版本 + call_stack_list = python_dur_events[0].args.get("Call stack").split(";") + else: + python_dur_events.sort(key=lambda e: e.ts) + # 新版本 + call_stack_list = [event.name for event in python_dur_events if event.cat == "python_function"] + call_stack = "\n".join(call_stack_list) + return call_stack def get_torch_2_npu_flow_event(self, end_time) -> FlowEvent: if not self.torch_2_npu_flow_events or not self.torch_2_npu_flow_events.get(end_time): diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py 
b/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py index 5411610a7f4229c6f01c04e352d380f3a2864784..c85c14d618ceda199c9c376abc27a3581eed97b8 100644 --- a/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py +++ b/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py @@ -28,18 +28,10 @@ class CSVAnalyzer: def process(self): df = pd.read_csv(self._path, dtype={"Start Time(us)": str}) - - - pool = multiprocessing.Pool(multiprocessing.cpu_count()) - # 数据预解析 - result = pool.map(self.update_op_row, df.iterrows()) - pool.close() - - preparse_df = pd.DataFrame(result) # 分析是否存在可融合的算子 - op_type_list = preparse_df["Type"].tolist() - duration_list = preparse_df["Duration(us)"].tolist() - start_times = preparse_df["Start Time(us)"].tolist() + op_type_list = df["Type"].tolist() + duration_list = df["Duration(us)"].tolist() + start_times = df["Start Time(us)"].tolist() # 去除末尾的\t分隔符 start_times = [start_time[:-1] for start_time in start_times] result_list = [] @@ -50,10 +42,6 @@ class CSVAnalyzer: "index", "first_timestamp"] return data_frame - @staticmethod - def update_op_row(row): - return OpPerfFactory.build(row[1]).update() - @staticmethod def find_all_sub_lists(op_type_list, duration_list, start_times, expect_sub_list): # 创建一个空字典,用来存储子列表和它们的出现次数和起始位置 diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_fused/op_perf.py b/profiler/advisor/advisor_backend/compute_advice/npu_fused/op_perf.py index 2442807fd10b7942177990d2283ad34c369659bd..7bcbed5a75807b57a55787c743cfaaff55a68589 100644 --- a/profiler/advisor/advisor_backend/compute_advice/npu_fused/op_perf.py +++ b/profiler/advisor/advisor_backend/compute_advice/npu_fused/op_perf.py @@ -12,19 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import functools from typing import Dict + from common_func_advisor.constant import Constant +from common_func_advisor.constant import CoreType +from common_func_advisor.constant import PerfColor class OpPerfFactory: @classmethod def build(cls, op_row: Dict): - return OpPerf(op_row) + if op_row.get(Constant.TITLE.TASK_TYPE) == CoreType.AIV: + return VecOpPerf(op_row) + elif op_row.get(Constant.TITLE.TASK_TYPE) == CoreType.AIC: + return CubeOpPerf(op_row) + else: + return OpPerf(op_row) class OpPerf: def __init__(self, op_row: Dict): + if "OP Type" in op_row.keys(): + Constant.update_title() self.row = op_row self.model_name = op_row.get("Model Name") self.model_id = op_row.get("Model ID") @@ -75,6 +85,112 @@ class OpPerf: self.aiv_mte3_ratio = op_row.get("aiv_mte3_ratio") self.aiv_icache_miss_rate = op_row.get("aiv_icache_miss_rate") self.cube_utilization = op_row.get("cube_utilization( %)") + + @staticmethod + def get_dtype_size(dtype_str: str): + return Constant.DTYPE_SIZE_MAP.get(dtype_str.lower(), 0) + + @staticmethod + def get_element_count(shape: list): + return functools.reduce(lambda x, y: int(x) * int(y), shape) + + @staticmethod + def shape_to_tuple(shape_str: str) -> tuple: + if not isinstance(shape_str, str): + return [] + shape_str = shape_str.strip('"') + split_shape = shape_str.strip(';') + if not split_shape: + return [] + pairs = split_shape.split(';') + shape_result = [] + for pair in pairs: + pair = pair.strip(";") + elements = pair.split(',') + elements = tuple(int(element) if "" != element else 0 for element in elements) + shape_result.append(elements) + return tuple(shape_result) + + @staticmethod + def dtype_to_tuple(dtypes_str: str) -> tuple: + if not isinstance(dtypes_str, str): + return [] + dtypes_str = dtypes_str.strip('"') + split_dtypes = dtypes_str.strip(';') + if not split_dtypes: + return [] + pairs = split_dtypes.split(';') + return tuple(pairs) + + def get_mac_ratio(self): + return self.aic_mac_ratio + + def get_size(self, 
shapes_str, dtypes_str): + shapes = self.shape_to_tuple(shapes_str) + dtypes = self.dtype_to_tuple(dtypes_str) + if len(shapes) > len(dtypes): + print(f"[ERROR] The size of shape is greater than that of dtypes.") + return 0 + if len(shapes) < len(dtypes): + shapes = list(shapes) + shapes.extend([(1,)] * (len(dtypes) - len(shapes))) + all_size = 0 + for index, shape in enumerate(shapes): + element_count = self.get_element_count(shape) + dtype_size = self.get_dtype_size(dtypes[index]) + all_size += element_count * dtype_size + return all_size + + def get_calc_size(self): + # input and output bytes (MB) + if not self.input_shapes or not self.output_shapes: + print("[ERROR] There is no tensor data, do not assess vector op performance.") + return 0 + intput_size = self.get_size(self.input_shapes, self.input_data_types) + output_size = self.get_size(self.output_shapes, self.output_data_types) + return (intput_size + output_size) / (Constant.BYTE_UNIT_TRANS * Constant.BYTE_UNIT_TRANS) + + def get_throughput(self): + # throughput(GB/s) + if not self.task_duration or abs(self.task_duration) < 1e-6: + print("[ERROR] There is no task_duration, do not assess vector op performance.") + return 0 + return self.row[Constant.TITLE.SIZE] / Constant.BYTE_UNIT_TRANS / self.task_duration * Constant.UNIT_TRANS * Constant.UNIT_TRANS + + def get_perf_color(self): + return PerfColor.WHITE def update(self): + self.row[Constant.TITLE.SIZE] = self.get_calc_size() + self.row[Constant.TITLE.THROUGHPUT] = self.get_throughput() + self.row[Constant.TITLE.COLOR] = self.get_perf_color().name return self.row + + +class VecOpPerf(OpPerf): + def get_perf_color(self) -> PerfColor: + throughput = self.row[Constant.TITLE.THROUGHPUT] + op_duration = self.task_duration + tp_threshold = Constant.TP_THRESHOLD + if throughput == 0: + return PerfColor.WHITE + if throughput < tp_threshold / 2 and op_duration > 20: + return PerfColor.RED + elif tp_threshold / 2 <= throughput < tp_threshold: + return 
PerfColor.YELLOW + else: + return PerfColor.GREEN + + +class CubeOpPerf(OpPerf): + def get_perf_color(self) -> PerfColor: + aic_mac_ratio = self.get_mac_ratio() + if not aic_mac_ratio: + print("[WARNING] There is no aic_mac_ratio, do not assess cube op performance.") + return PerfColor.WHITE + elif aic_mac_ratio < 0.6: + return PerfColor.RED + elif 0.6 <= aic_mac_ratio < 0.8: + return PerfColor.YELLOW + else: + return PerfColor.GREEN diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py b/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py new file mode 100644 index 0000000000000000000000000000000000000000..caff1c792c2171c33a4dd876b0741d6c215c5766 --- /dev/null +++ b/profiler/advisor/advisor_backend/compute_advice/npu_slow_advice.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from abc import ABC
import multiprocessing

import pandas as pd

from compute_advice.compute_advice_base import ComputeAdviceBase
from compute_advice.npu_fused.op_perf import OpPerfFactory
from common_func_advisor.constant import Constant
from common_func_advisor.constant import PerfColor
from advisor_backend.common_func_advisor.trace_view_json import TraceViewJson


class NpuSlowAdvice(ComputeAdviceBase, ABC):
    """Grade every kernel in kernel_details.csv via OpPerfFactory and expose
    the result as a DataFrame; optionally export a color-coded Excel sheet."""

    OP_PERF_SHEET = "op_perf"

    def __init__(self, collection_path: str):
        super().__init__(collection_path)
        # Path to kernel_details.csv; presumably filled in by the base class's
        # path_check() — TODO confirm against ComputeAdviceBase.
        self.kernel_details_path = ""
        self.data = pd.DataFrame()

    @staticmethod
    def save_to_excel(data: pd.DataFrame, file_path: str) -> None:
        """Write `data` to `file_path`, coloring each row by its perf grade."""
        writer = pd.ExcelWriter(file_path, engine="xlsxwriter", mode="w")
        data.index.name = Constant.TITLE.INDEX
        data.to_excel(writer, index=True, sheet_name=NpuSlowAdvice.OP_PERF_SHEET)
        NpuSlowAdvice.color_sheet(data, writer.book, writer.sheets[NpuSlowAdvice.OP_PERF_SHEET])
        writer.sheets[NpuSlowAdvice.OP_PERF_SHEET].freeze_panes = "A2"
        writer.close()

    @staticmethod
    def color_sheet(data: pd.DataFrame, workbook, worksheet):
        """Apply a background fill to each worksheet row matching its COLOR cell."""
        color_rgb = {
            PerfColor.GREEN.name: workbook.add_format({'bg_color': '#C6EFCE'}),
            PerfColor.YELLOW.name: workbook.add_format({'bg_color': '#FFEB9C'}),
            PerfColor.RED.name: workbook.add_format({'bg_color': '#FFC7CE'}),
        }
        for row in data.iterrows():
            color = row[1][Constant.TITLE.COLOR]
            fill_format = color_rgb.get(color)
            if not fill_format:
                continue  # WHITE / unknown grades keep the default background
            # +1 skips the header row written by to_excel
            worksheet.set_row(row[0] + 1, None, fill_format)

    @staticmethod
    def update_op_row(row: tuple):
        # `row` is an (index, Series) pair from DataFrame.iterrows().
        return OpPerfFactory.build(row[1]).update()

    def get_call_stack(self, data: pd.DataFrame, index_id: int, ts_col: str) -> str:
        """Return the python call stack for row `index_id`, or "" when the
        profile was collected without with_stack=True."""
        if not self.has_callstack():
            print("There is no call stack info, please set 'with_stack=True'")
            return ""
        trace_json = TraceViewJson(self.trace_view_path)
        return trace_json.get_call_stack(data, index_id, ts_col)

    def run(self):
        """Entry point: validate input paths, then parse and grade the kernels."""
        if not self.path_check():
            return self.data
        self.process()
        return self.data

    def process(self):
        """Read kernel_details.csv and grade every op row in parallel."""
        self.data = pd.read_csv(self.kernel_details_path, dtype={"Start Time(us)": str})
        # Strip the trailing "\t" separator appended to every start time.
        self.data["Start Time(us)"] = self.data["Start Time(us)"].apply(lambda x: x[:-1])
        # BUGFIX: the pool was previously created without a context manager and
        # leaked its workers if pool.map raised; `with` guarantees cleanup.
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            result = pool.map(self.update_op_row, self.data.iterrows())
        self.data = pd.DataFrame(result)
overall_data: return - e2e_time = sum([data for data in overall_data.values()]) + e2e_time = '%.3f' % sum([data for data in overall_data.values()]) overall_bottleneck = f"The Model E2E Time is {e2e_time}s.\n" comparison_bottleneck = "" for time_type, time_value in overall_data.items(): @@ -160,7 +161,9 @@ class OverallSummaryAdvice(AdviceBase): if not self._has_base_collection: continue # add comparison bottleneck - base_duration, _ = self.split_duration_and_num(self.get_time_value(time_type, self._base_data)) + time_type_origin = "Uncovered Communication Time(Wait Time)" \ + if time_type == "Uncovered Communication Time" else time_type + base_duration, _ = self.split_duration_and_num(self.get_time_value(time_type_origin, self._base_data)) if time_value > base_duration: ratio = "{:.2%}".format(self.calculate_ratio(time_value - base_duration, base_duration)) comparison_bottleneck += f"{time_type} exceeds the benchmark by {ratio}\n" diff --git a/profiler/advisor/compute_perf_analysis.ipynb b/profiler/advisor/compute_perf_analysis.ipynb index 27c9caf37bf43871f319a9418294953f54f9cafd..e7a663130c8da335129513a5ca1a99cf28fe48b7 100644 --- a/profiler/advisor/compute_perf_analysis.ipynb +++ b/profiler/advisor/compute_perf_analysis.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2024-02-21T09:19:13.937531900Z", @@ -11,6 +11,7 @@ }, "outputs": [], "source": [ + "import os\n", "import pandas as pd\n", "\n", "from advisor_backend.interface import Interface\n", @@ -24,15 +25,18 @@ "# 算子调优分析\n", "## 1. 算子分析的数据准备\n", "当前算子分析工具支持分析Ascend Pyorch Profiler方式生成的ascend_pt目录\n", - "## 2. 算子分析解决的问题\n", + "## 2. 
融合算子分析\n", "当前支持分析模型中存在可融合的小算子,并给出优化建议。\n", "\n", - "\"更多融合算子信息,请查阅 https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/700alpha003/processormodel/hardwaredesc_0001.html" + "\"更多融合算子信息,请查阅 https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/700alpha003/processormodel/hardwaredesc_0001.html\n", + "\n", + "## 3. 异常性能算子分析\n", + "支持分析模型中性能异常的计算算子" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2024-02-22T08:41:17.455567500Z", @@ -44,18 +48,75 @@ "name": "stdout", "output_type": "stream", "text": [ - "[INFO] Start to analyse the target file: C:\\data\\ascend_pt\\ASCEND_PROFILER_OUTPUT\\kernel_details.csv\n", - " pattern_name pattern len count duration sum(us) op durations(us) index\n", - "18 torch_npu.npu_swiglu (Slice, Slice, Swish, Mul) 4 1 12.56 [3.14, 3.14, 3.14, 3.14] [0]\n", + "[INFO] Start to analyse the target file: D:\\work\\ascend_pt\\ASCEND_PROFILER_OUTPUT\\kernel_details.csv\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
pattern_namepatternlencountduration sum(us)op durations(us)index
18torch_npu.npu_swiglu(Slice, Slice, Swish, Mul)4127.53[21.2, 0.05, 3.14, 3.14][0]
\n", + "
" + ], + "text/plain": [ + " pattern_name pattern len count duration sum(us) op durations(us) index\n", + "18 torch_npu.npu_swiglu (Slice, Slice, Swish, Mul) 4 1 27.53 [21.2, 0.05, 3.14, 3.14] [0]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\n", "\n", - "The computing time of fusable op is 12.56 ms.\n", + "The computing time of fusable op is 27.53 ms.\n", "\n", "\n", "Advice 0:\n", "Replace [Slice, Slice, Swish, Mul] with torch_npu.npu_swiglu. This pattern first happened in: \n", - "torch/nn/modules/module.py(1513): _call_impl\n", - "profiler_main.py(116):forward\n" + "/root/torch/module.py\n", + "/root/test/slice.py(116)\n" ] } ], @@ -66,7 +127,7 @@ "data = interface.get_data('compute', 'npu_fused')\n", "pd.set_option('display.max_columns', None)\n", "pd.set_option('display.width', 900)\n", - "print(data['data'].iloc[:, :-2])\n", + "display(data['data'].iloc[:, :-2])\n", "print('\\n')\n", "print(data['bottleneck'])\n", "print('\\n')\n", @@ -75,21 +136,217 @@ }, { "cell_type": "code", - "outputs": [], + "execution_count": 5, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INFO] Start to analyse the target file: D:\\work\\ascend_pt\\ASCEND_PROFILER_OUTPUT\\kernel_details.csv\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Step IdModel IDTask IDStream IDNameTypeAccelerator CoreStart Time(us)Duration(us)Wait Time(us)Block DimMix Block DimInput ShapesInput Data TypesInput FormatsOutput ShapesOutput Data TypesOutput FormatsContext IDaicore_time(us)aic_total_cyclesaic_mac_ratioaic_mac_int8_ratioaic_cube_fopsaic_vector_fopsaiv_time(us)aiv_total_cyclesaiv_vec_fp32_ratioaiv_vec_fp16_ratioaiv_vec_int32_ratioaiv_vec_misc_ratioaiv_cube_fopsaiv_vector_fopssize(MB)throughput(GB/s)color
014294967295126516Slice1SliceAI_VECTOR_CORE169952962310675021.20261.56904,1025INT64FORMAT_ND4,1025INT32FORMAT_NDNaN0.00.00.00.00.00.01.7729508.00.00.00.00620.00.05856.00.0469212.161371RED
414294967295126516Add1AddAI_CORE16995296231067543.14261.56904,1025INT64FORMAT_ND4,1025INT32FORMAT_NDNaN2.328888.00.20.10.10.70.000.00.00.00.00000.00.00.00.04692114.592698RED
\n", + "
" + ], + "text/plain": [ + " Step Id Model ID Task ID Stream ID Name Type Accelerator Core Start Time(us) Duration(us) Wait Time(us) Block Dim Mix Block Dim Input Shapes Input Data Types Input Formats Output Shapes Output Data Types Output Formats Context ID aicore_time(us) aic_total_cycles aic_mac_ratio aic_mac_int8_ratio aic_cube_fops aic_vector_fops aiv_time(us) aiv_total_cycles aiv_vec_fp32_ratio aiv_vec_fp16_ratio aiv_vec_int32_ratio aiv_vec_misc_ratio aiv_cube_fops aiv_vector_fops size(MB) throughput(GB/s) color\n", + "0 1 4294967295 1265 16 Slice1 Slice AI_VECTOR_CORE 1699529623106750 21.20 261.56 9 0 4,1025 INT64 FORMAT_ND 4,1025 INT32 FORMAT_ND NaN 0.0 0.0 0.0 0.0 0.0 0.0 1.77 29508.0 0.0 0.0 0.0062 0.0 0.0 5856.0 0.046921 2.161371 RED\n", + "4 1 4294967295 1265 16 Add1 Add AI_CORE 1699529623106754 3.14 261.56 9 0 4,1025 INT64 FORMAT_ND 4,1025 INT32 FORMAT_ND NaN 2.3 28888.0 0.2 0.1 0.1 0.7 0.00 0.0 0.0 0.0 0.0000 0.0 0.0 0.0 0.046921 14.592698 RED" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ + "# 异常性能算子识别\n", + "from advisor_backend.compute_advice.npu_slow_advice import NpuSlowAdvice\n", "\n", - "\n" + "npu_slow_advice = NpuSlowAdvice(compute_path)\n", + "data = interface.get_data('compute', 'npu_slow')\n", + "slow_op_data = data[data[\"color\"] == \"RED\"]\n", + "display(slow_op_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "NpuSlowAdvice.save_to_excel(data, file_path=os.path.join(compute_path, \"slow_op.xlsx\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "call stack: \n", + "/root/torch/module.py\n", + "/root/test/slice.py(116)\n" + ] + } ], - "metadata": { - "collapsed": false - } + "source": [ + "# 异常性能算子call stack\n", + "call_stack = npu_slow_advice.get_call_stack(data, index_id=0, ts_col=\"Start Time(us)\")\n", + "print(\"call 
stack: \")\n", + "print(call_stack)" + ] } ], "metadata": { "kernelspec": { - "name": "python3", + "display_name": "Python 3 (ipykernel)", "language": "python", - "display_name": "Python 3 (ipykernel)" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -101,7 +358,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.18" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/profiler/affinity_cpu_bind/README.md b/profiler/affinity_cpu_bind/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8c3b47ed5183fd2dbade8fc316e0319b8feea880 --- /dev/null +++ b/profiler/affinity_cpu_bind/README.md @@ -0,0 +1,40 @@ +# 昇腾亲和性CPU绑核工具 + +昇腾亲和性CPU绑核工具支持用户无需修改代码,直接运行工具即可实现按CPU亲和性策略绑核,提升推理或训练性能。 + +绑核工具用户arm服务器环境,对于训练或推理任务因为CPU资源调度等出现host_bound问题时使用,可改善该问题;对于非host_bound的场景无明显改善效果。 + +## 使用须知 + +使用绑核工具前手动执行npu-smi info -t topo,出现以下类似信息,说明环境支持绑核,否则请将环境HDK包升级到Ascend HDK 23.0.RC2及以上版本。 + + NPU0 NPU1 NPU2 NPU3 NPU4 NPU5 NPU6 NPU7 NPUx CPU Affinity + NPU0 X HCCS HCCS HCCS HCCS HCCS HCCS HCCS ... xx-xx + NPU1 HCCS X HCCS HCCS HCCS HCCS HCCS HCCS ... xx-xx + NPU2 HCCS HCCS X HCCS HCCS HCCS HCCS HCCS ... xx-xx + NPU3 HCCS HCCS HCCS X HCCS HCCS HCCS HCCS ... xx-xx + NPU4 HCCS HCCS HCCS HCCS X HCCS HCCS HCCS ... xx-xx + NPU5 HCCS HCCS HCCS HCCS HCCS X HCCS HCCS ... xx-xx + NPU6 HCCS HCCS HCCS HCCS HCCS HCCS X HCCS ... xx-xx + NPU7 HCCS HCCS HCCS HCCS HCCS HCCS HCCS X ... xx-xx + NPUx ... ... ... ... ... ... ... ... ... ... 
+ +## 使用方式 + +1.执行以下命令实施绑核: + + - 直接执行绑核命令 +```bash +python3 bind_core.py -app/--application="inferenec/train cmd" +``` +该方式会自动拉起训练或推理任务,检测任务进程,并实施绑核。 + + - 手动拉起训练或推理任务后再执行绑核 +```bash +python3 bind_core.py +``` +该方式会循环查找(循环5次,每次10s,若找不到进程,则直接退出)使用到NPU的任务进程,并实施绑核。 + +2.绑核运行过程的日志会保存到当前路径的bind_core_时间戳.log。 + +3.如果推理或训练进程拉起后需要一定时间预处理,才会真正执行任务,可在执行绑核命令时设置-t/--time参数(单位秒),绑核工具会在延迟配置的时间后,再实施绑核动作。例如:python3 bind_core.py -app="cmd" -t=10,配置后工具会在10秒后执行绑核操作。 \ No newline at end of file diff --git a/profiler/affinity_cpu_bind/__init__.py b/profiler/affinity_cpu_bind/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/affinity_cpu_bind/bind_core.py b/profiler/affinity_cpu_bind/bind_core.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd9720f9a9fab7a4a6950ef7e0174c786c95a45 --- /dev/null +++ b/profiler/affinity_cpu_bind/bind_core.py @@ -0,0 +1,214 @@ +import subprocess +import argparse +import os +import time +import logging +from datetime import datetime +from datetime import timezone + + +class PathManager: + DATA_FILE_AUTHORITY = 0o640 + + @classmethod + def create_file_safety(cls, path: str): + base_name = os.path.basename(path) + msg = f"Failed to create file: {base_name}" + if os.path.islink(path): + raise RuntimeError(msg) + if os.path.exists(path): + return + try: + os.close(os.open(path, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY)) + except Exception as err: + raise RuntimeError(msg) from err + + +class BindCoreManager(): + DEFAULT_FIND_RUNNING_PID_TIMES = 5 + + def __init__(self): + self.npu_id_list = [] + self.running_pid_on_npu = {} + self.find_running_pid_times = self.DEFAULT_FIND_RUNNING_PID_TIMES + self.npu_affinity_cpu_dict = {} + self.log_file = '' + self._init_log_file() + + + def _init_log_file(self): + now_time = datetime.now(tz=timezone.utc) + time_stamp = str(now_time.year) + '_' + \ + str(now_time.month) + '_' + \ + str(now_time.day) 
+ '_' + \ + str(now_time.hour) + '_' + \ + str(now_time.minute) + '_' + \ + str(now_time.second) + log_file_name = 'bind_core_' + time_stamp + '.log' + msg = f"Failed to create file: {log_file_name}" + try: + PathManager.create_file_safety(os.path.join(os.getcwd(), log_file_name)) + except RuntimeError as err: + raise RuntimeError(msg) from err + self.log_file = log_file_name + logging.basicConfig(filename=self.log_file, + level=logging.INFO, + filemode='w', + format='%(asctime)s-%(name)s-%(levelname)s-%(message)s') + + def _get_all_npu_id(self) -> None: + get_npu_info_cmd = 'npu-smi info -l' + get_npu_info_process = subprocess.run(get_npu_info_cmd.split(), shell=False, capture_output=True) + get_npu_id_cmd = 'grep ID' + get_npu_id_process = subprocess.run(get_npu_id_cmd.split(), shell=False, input=get_npu_info_process.stdout, capture_output=True) + res = get_npu_id_process.stdout.decode('utf-8').split() + for i in res: + if i.isdigit(): + self.npu_id_list.append(int(i)) + logging.info(f'NPU total id list: {self.npu_id_list}') + + def _get_npu_affinity(self) -> bool: + cpu_num = os.cpu_count() + cpu_num_for_each_npu = cpu_num // len(self.npu_id_list) + get_npu_topo_cmd = 'npu-smi info -t topo' + p = subprocess.run(get_npu_topo_cmd.split(), shell=False, capture_output=True) + res = p.stdout.decode('utf-8').split() + if not res: + print('[ERROR] Failed to run get npu affinity info, please check if driver version support cmd npu-smi info -t topo') + return False + + index = 0 + for v in res: + if '-' in v: + affinity_cpus = [] + cpu_lists = v.split(',') + for cpu_list in cpu_lists: + cpus = cpu_list.split('-') + if len(cpus) != 2: + continue + if int(cpus[1]) - int(cpus[0]) == cpu_num_for_each_npu - 1: + cpus[1] = str(int(cpus[1]) + cpu_num_for_each_npu) + affinity_cpus.append(cpus[0] + '-' + cpus[1]) + if index < len(self.npu_id_list): + self.npu_affinity_cpu_dict[self.npu_id_list[index]] = ','.join(affinity_cpu for affinity_cpu in affinity_cpus) + index += 1 + else: 
+ print('[ERROR] Get affinity_cpu_list for {} npus, more than real npu num: {}'.format(index + 1, len(self.npu_id_list))) + return False + + for k in self.npu_affinity_cpu_dict.keys(): + logging.info(f'Affinity CPU list {self.npu_affinity_cpu_dict[k]} for NPU {k}') + return True + + def get_running_pid_on_npu(self) -> bool: + no_running_pids_on_npu_msg = '[INFO] Now there is no running process on all NPUs, stop bind cores' + logging.info('Begin to find running process on all NPUs') + # get running process on NPUs + for times in range(self.find_running_pid_times): + running_pid_on_npu = {} + for npu_id in self.npu_id_list: + get_npu_pids_cmd = 'npu-smi info -t proc-mem -i {} -c 0'.format(npu_id) + get_npu_pids_process = subprocess.run(get_npu_pids_cmd.split(), shell=False, capture_output=True) + res = get_npu_pids_process.stdout.decode('utf-8').split() + pid_list = [] + for value in res: + if value.startswith('id:'): + pid = value.split(':')[1] + pid_list.append(pid) + if pid_list: + running_pid_on_npu[npu_id] = list(set(pid_list)) + + if len(self.running_pid_on_npu.keys()) == len(running_pid_on_npu.keys()) and running_pid_on_npu: + self.running_pid_on_npu = running_pid_on_npu + break + + self.running_pid_on_npu = running_pid_on_npu + time.sleep(5) + + # delete repeat pid + for npu_id in self.npu_id_list: + if npu_id not in self.running_pid_on_npu: + continue + pids_on_npu = self.running_pid_on_npu[npu_id] + for pid in pids_on_npu: + for npu_id_with_pids, pids in self.running_pid_on_npu.items(): + if npu_id == npu_id_with_pids: + continue + if pid in pids: + pids_on_npu.remove(pid) + + if_running_process = False + for npu_id, pids in self.running_pid_on_npu.items(): + if not pids: + logging.info(f'There is no running process on NPU {npu_id}') + else: + logging.info(f'Succeed to find running process {pids} on NPU {npu_id}') + if_running_process = True + if not if_running_process: + print(no_running_pids_on_npu_msg) + return if_running_process + + def 
get_npu_info(self) -> bool: + try: + self._get_all_npu_id() + if not self._get_npu_affinity(): + return False + except subprocess.CalledProcessError: + return False + return True + + def run_bind_core(self): + if not self.running_pid_on_npu: + return + for npu, pid_list in self.running_pid_on_npu.items(): + if npu not in self.npu_affinity_cpu_dict.keys(): + logging.warning(f'Cannot find affinity cpu for npu: {npu}') + continue + affinity_cpu = self.npu_affinity_cpu_dict.get(npu) + for pid in pid_list: + try: + logging.info(f'Begin to bind cores for process {pid} on NPU {npu}') + set_affinity_cpu_cmd = 'taskset -pc {} {}'.format(affinity_cpu, pid) + p = subprocess.run(set_affinity_cpu_cmd.split(), shell=False, capture_output=True) + logging.info(p.stdout.decode('utf-8')) + except subprocess.CalledProcessError: + print('[ERROR] Failed to bind process {} on NPU {} with cpu cores list {}'.format(pid, npu, affinity_cpu)) + + logging.info(f'Succeed to bind process {pid} on NPU {npu} with cpu cores list {affinity_cpu}') + + def args_parse(self): + parser = argparse.ArgumentParser(description='This is a affinity cpu core bind script.') + parser.add_argument('-t', '--time', type=int, metavar='', help='Wait time before bind cores that you want to set. 
The unit is \'s\'.') + parser.add_argument('-app', '--application', metavar='', nargs='+', help='Training or inference command that you want to run.') + args = parser.parse_args() + if args.application: + application_cmd = ' '.join(args.application) + self.launch_process(application_cmd) + time.sleep(2) + # if time is set, wait for setting time before bind cores + if args.time: + time.sleep(args.time) + + def launch_process(self, cmd: list): + logging.info(f'Start to execute cmd: {cmd}') + try: + subprocess.Popen(cmd.split(), shell=False) + except subprocess.CalledProcessError as e: + raise RuntimeError(f'Failed to run cmd: {cmd}') from e + + +if __name__ == '__main__': + print('[INFO] Begin to run bind-cores script...') + bind_core_manager = BindCoreManager() + bind_core_manager.args_parse() + + if not bind_core_manager.get_npu_info(): + print('[ERROR] Failed to get current npus info') + exit() + + if not bind_core_manager.get_running_pid_on_npu(): + exit() + bind_core_manager.run_bind_core() + print('[INFO] End to run bind-cores script, the log is saved in {}'.format(bind_core_manager.log_file)) + + diff --git a/profiler/cluster_analyse/README.md b/profiler/cluster_analyse/README.md index 7cdb2d2c1e68da2cbbc00629ddf06b2ae48a28c2..f7646f67c40c53d1f82aecb4ae9dc0bfa810a77f 100644 --- a/profiler/cluster_analyse/README.md +++ b/profiler/cluster_analyse/README.md @@ -21,6 +21,12 @@ experimental_config = torch_npu.profiler._ExperimentalConfig( - ./ASCEND_PROFILER_OUTPUT/communication.json, - ./ASCEND_PROFILER_OUTPUT/communication_matrix.json +或者具备: + +- analysis.db + +以上csv、json文件与db文件只能存在一类,否则集群分析工具解析异常。 + 确认这几个文件生成后,继续下面的集群分析。 ## 数据汇聚与集群解析 @@ -37,11 +43,11 @@ python3 cluster_analysis.py -d {cluster profiling data path} -m {mode} | --collection_path或-d | 性能数据汇集目录,运行分析脚本之后会在该目录下自动创建cluster_analysis_output文件夹,保存分析数据。 | 是 | | --mode或-m | 数据解析模式。取值为:communication_matrix(解析通信矩阵数据)、communication_time(解析通信耗时数据)、all(同时解析通信矩阵和通信耗时数据),默认值为all。 | 否 | -## 交付件 +### 交付件 
集群分析工具的交付件通过Ascend Insight工具展示,详见《MindStudio 可视化调优工具指南(Ascend Insight)》。 -### cluster_step_trace_time.csv +#### cluster_step_trace_time.csv 数据解析模式为communication_matrix、communication_time或all时均生成。 @@ -79,7 +85,7 @@ K列:Communication(Not Overlapped and Exclude Receive)指剔除recieve算 以上时间理论上都应该处于持平状态,即最大值小于最小值5%,否则就可能出现慢卡。 -### cluster_communication_matrix.json +#### cluster_communication_matrix.json 数据解析模式为communication_matrix或all时生成。 @@ -99,8 +105,21 @@ K列:Communication(Not Overlapped and Exclude Receive)指剔除recieve算 - “HCCS”或“PCIE”是节点内片间拷贝,速度在18GB左右或以上比较正常。 - “RDMA”是节点间拷贝,910A速度在12GB左右或以上。 -### cluster_communication.json +#### cluster_communication.json 数据解析模式为communication_time或all时生成。 主要为通信耗时数据。 + +#### cluster_analysis.db + +解析analysis.db生成的交付件,当前解析通信类数据,主要包含下面数据: + +- ClusterCommAnalyzerTime:集群通信时间信息。 +- ClusterCommAnalyzerBandwidth:集群通信带宽信息。 +- ClusterCommAnalyzerMatrix:集群通信矩阵数据。 +- CommunicationGroup:通信组信息。 +- ClusterStepTraceTime:集群迭代轨迹数据。 + + + diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index b383a704df27d18e0191b2b251efd9de61beee55..06be6002e1e075645dd21cd1328505829a9b3305 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -14,10 +14,10 @@ # limitations under the License. 
from multiprocessing import Process -from common_func.constant import Constant + from analysis.communication_analysis import CommunicationAnalysis +from analysis.comm_matrix_analysis import CommMatrixAnalysis from analysis.step_trace_time_analysis import StepTraceTimeAnalysis -from analysis.communication_analysis import CommMatrixAnalysis class AnalysisFacade: diff --git a/profiler/cluster_analyse/analysis/base_analysis.py b/profiler/cluster_analyse/analysis/base_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..cc803813dda4a535c529a935c1b42dae197855c9 --- /dev/null +++ b/profiler/cluster_analyse/analysis/base_analysis.py @@ -0,0 +1,77 @@ +from abc import abstractmethod +from common_func.constant import Constant +from utils.data_transfer_adapter import DataTransferAdapter +from common_func.file_manager import FileManager + + +class BaseAnalysis: + + def __init__(self, param: dict): + self.collection_path = param.get(Constant.COLLECTION_PATH) + self.data_map = param.get(Constant.DATA_MAP) + self.data_type = param.get(Constant.DATA_TYPE) + self.communication_ops = [] + self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP) + self.comm_ops_struct = {} + self.adapter = DataTransferAdapter() + + @staticmethod + def compute_ratio(dividend: float, divisor: float): + if abs(divisor) < Constant.EPS: + return 0 + else: + return round(dividend / divisor, 4) + + @staticmethod + def check_add_op(op_name: str): + """ + 兼容2个版本,判断是否需要将此算子信息相加 + """ + stat_list = ["middle", "top", "bottom", "total"] + total = "total" + for stat_name in stat_list: + if stat_name in op_name: + if stat_name != total: + return False + return True + + @abstractmethod + def run(self): + pass + + def dump_data(self): + if not self.comm_ops_struct: + print("[WARNING] There is no final comm ops data generated") + return + if self.data_type == Constant.TEXT: + self.dump_json() + else: + self.dump_db() + + @abstractmethod + def 
dump_db(self): + pass + + def dump_json(self): + output_comm_data = {} + for key in self.comm_ops_struct: + output_comm_data[str(key)] = self.comm_ops_struct.get(key) + FileManager.create_json_file(self.collection_path, output_comm_data, self.SAVED_JSON) + + def split_op_by_group(self): + for single_op in self.communication_ops: + if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P: + rank_tup = Constant.P2P + else: + rank_tup = tuple(self.collective_group_dict.get(single_op.get(Constant.GROUP_NAME), [])) + rank_id = single_op.get(Constant.RANK_ID, 'N/A') + step_id = single_op.get(Constant.STEP_ID, 'N/A') + op_name = single_op.get(Constant.COMM_OP_NAME, 'N/A') + op_info = single_op.get(Constant.COMM_OP_INFO) + self.comm_ops_struct.setdefault(rank_tup, {}).setdefault(step_id, {}).\ + setdefault(op_name, {}).setdefault(rank_id, op_info) + + def combine_ops_total_info(self): + for rank_tup, group_dict in self.comm_ops_struct.items(): + for step_id, communication_ops in group_dict.items(): + self.compute_total_info(communication_ops) diff --git a/profiler/cluster_analyse/analysis/comm_matrix_analysis.py b/profiler/cluster_analyse/analysis/comm_matrix_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc04471fe0a164fc859e51597d41028523f7a32 --- /dev/null +++ b/profiler/cluster_analyse/analysis/comm_matrix_analysis.py @@ -0,0 +1,106 @@ +import os +from collections import defaultdict + +from analysis.base_analysis import BaseAnalysis +from common_func.constant import Constant +from common_func.db_manager import DBManager + + +class CommMatrixAnalysis(BaseAnalysis): + SAVED_JSON = "cluster_communication_matrix.json" + COMMUNICATION_MATRIX_TABLE = "ClusterCommAnalyzerMatrix" + + def __init__(self, param: dict): + super().__init__(param) + self.communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS) + + @staticmethod + def combine_link(link_info_dict: dict, single_link_dict: dict): + 
link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE) + link_info_dict[Constant.OP_NAME] = single_link_dict.get(Constant.OP_NAME, '') + link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0) + link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0) + + def run(self): + if not self.communication_ops: + return + self.split_op_by_group() + self.combine_ops_total_info() + self.dump_data() + + def dump_db(self): + res_comm_matrix = self.adapter.transfer_matrix_from_json_to_db(self.comm_ops_struct) + output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + DBManager.create_tables(result_db, self.COMMUNICATION_MATRIX_TABLE) + conn, cursor = DBManager.create_connect_db(result_db) + if res_comm_matrix: + res_matrix_value = [list(data.values()) for data in res_comm_matrix] + sql = "insert into {} values ({value})".format(self.COMMUNICATION_MATRIX_TABLE, + value="?," * (len(res_matrix_value[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, res_matrix_value) + DBManager.destroy_db_connect(conn, cursor) + + def compute_total_info(self, step_dict: dict): + self.merge_same_links(step_dict) + self.combine_link_info(step_dict) + + def merge_same_links(self, step_dict: dict): + def process_link_key(): + for link_key in rank_dict: + if '-' not in link_key: + print(f"[WARNING] {op_name} has an invalid link key {link_key}!") + break + src_rank = link_key.split('-')[0] + dst_rank = link_key.split('-')[1] + if src_rank == dst_rank: + if src_rank not in project_local_global_rank_map: + project_local_global_rank_map[src_rank] = rank_id + elif project_local_global_rank_map.get(src_rank) != rank_id: + print(f"[WARNING] In the same communication group, local ranks projecting to global ranks " + f"repeat!") + self.combine_link(link_info[link_key], rank_dict[link_key]) + 
+ def convert_local_to_global_rank(): + tmp_link = {} + for link_key, link_dict in link_info.items(): + src_rank = link_key.split('-')[0] + dst_rank = link_key.split('-')[1] + src_rank = project_local_global_rank_map[src_rank] \ + if src_rank in project_local_global_rank_map else src_rank + dst_rank = project_local_global_rank_map[dst_rank] \ + if dst_rank in project_local_global_rank_map else dst_rank + link_dict[Constant.BANDWIDTH_GB_S] = \ + self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), + link_dict.get(Constant.TRANSIT_TIME_MS, 0)) + tmp_link[f"{src_rank}-{dst_rank}"] = link_dict + return tmp_link + + project_local_global_rank_map = dict() + for op_name, op_dict in step_dict.items(): + link_info = defaultdict(lambda: { + Constant.TRANSPORT_TYPE: '', + Constant.TRANSIT_TIME_MS: 0, + Constant.TRANSIT_SIZE_MB: 0, + Constant.OP_NAME: '' + }) + for rank_id, rank_dict in op_dict.items(): + process_link_key() + step_dict[op_name] = convert_local_to_global_rank() + + def combine_link_info(self, step_dict: dict): + total_op_info = defaultdict(lambda: { + Constant.TRANSPORT_TYPE: '', + Constant.TRANSIT_TIME_MS: 0, + Constant.TRANSIT_SIZE_MB: 0, + Constant.OP_NAME: '' + }) + for op_name, op_dict in step_dict.items(): + if self.check_add_op(op_name): + for link_key, link_dict in op_dict.items(): + self.combine_link(total_op_info[link_key], link_dict) + for link_key, link_dict in total_op_info.items(): + link_dict[Constant.BANDWIDTH_GB_S] = \ + self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), + link_dict.get(Constant.TRANSIT_TIME_MS, 0)) + step_dict[Constant.TOTAL_OP_INFO] = total_op_info diff --git a/profiler/cluster_analyse/analysis/communication_analysis.py b/profiler/cluster_analyse/analysis/communication_analysis.py index 88ac073a9cc899ecfb32378a8aca662de2bfe879..3f0a9b417e211b124b052cb5c5534f2fdbe5302e 100644 --- a/profiler/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/cluster_analyse/analysis/communication_analysis.py 
@@ -1,75 +1,15 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +import os from collections import defaultdict -from abc import abstractmethod +from analysis.base_analysis import BaseAnalysis from common_func.constant import Constant -from common_func.file_manager import FileManager - - -class BaseCommAnalysis: - - def __init__(self, param: dict): - self.collection_path = param.get(Constant.COLLECTION_PATH) - self.data_map = param.get(Constant.DATA_MAP) - self.communication_ops = [] - self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP) - self.comm_ops_struct = {} +from common_func.db_manager import DBManager - @staticmethod - def compute_ratio(dividend: float, divisor: float): - if abs(divisor) < Constant.EPS: - return 0 - else: - return round(dividend / divisor, 4) - - @abstractmethod - def run(self): - pass - - def dump_data(self): - if not self.comm_ops_struct: - print("[WARNING] There is no final comm ops data generated") - return - output_comm_data = {} - for key in self.comm_ops_struct: - output_comm_data[str(key)] = self.comm_ops_struct.get(key) - FileManager.create_json_file(self.collection_path, output_comm_data, self.SAVED_JSON) - def split_op_by_group(self): - for single_op in self.communication_ops: - if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P: - rank_tup = Constant.P2P - else: - rank_tup = 
tuple(self.collective_group_dict.get(single_op.get(Constant.GROUP_NAME), [])) - rank_id = single_op.get(Constant.RANK_ID, 'N/A') - step_id = single_op.get(Constant.STEP_ID, 'N/A') - op_name = single_op.get(Constant.COMM_OP_NAME, 'N/A') - op_info = single_op.get(Constant.COMM_OP_INFO) - self.comm_ops_struct.setdefault(rank_tup, {}).setdefault(step_id, {}).\ - setdefault(op_name, {}).setdefault(rank_id, op_info) - - def combine_ops_total_info(self): - for rank_tup, group_dict in self.comm_ops_struct.items(): - for step_id, communication_ops in group_dict.items(): - self.compute_total_info(communication_ops) - - -class CommunicationAnalysis(BaseCommAnalysis): +class CommunicationAnalysis(BaseAnalysis): SAVED_JSON = "cluster_communication.json" + COMMUNICATION_BANDWIDTH_TABLE = "ClusterCommAnalyzerBandwidth" + COMMUNICATION_TIME_TABLE = "ClusterCommAnalyzerTime" def __init__(self, param: dict): super().__init__(param) @@ -88,6 +28,23 @@ class CommunicationAnalysis(BaseCommAnalysis): self.combine_ops_total_info() self.dump_data() + def dump_db(self): + res_comm_time, res_comm_bandwidth = self.adapter.transfer_comm_from_json_to_db(self.comm_ops_struct) + output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + DBManager.create_tables(result_db, self.COMMUNICATION_TIME_TABLE, self.COMMUNICATION_BANDWIDTH_TABLE) + conn, cursor = DBManager.create_connect_db(result_db) + self.execute(conn, res_comm_time, self.COMMUNICATION_TIME_TABLE) + self.execute(conn, res_comm_bandwidth, self.COMMUNICATION_BANDWIDTH_TABLE) + DBManager.destroy_db_connect(conn, cursor) + + @staticmethod + def execute(conn, res_data, table_name): + if res_data: + res_value = [list(data.values()) for data in res_data] + sql = "insert into {} values ({value})".format(table_name, value="?," * (len(res_value[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, res_value) + def 
compute_total_info(self, comm_ops: dict): if not comm_ops: return @@ -144,100 +101,3 @@ class CommunicationAnalysis(BaseCommAnalysis): bandwidth_dict[Constant.BANDWIDTH_GB_S] = \ self.compute_ratio(bandwidth_dict.get(Constant.TRANSIT_SIZE_MB, 0), bandwidth_dict.get(Constant.TRANSIT_TIME_MS, 0)) - - -class CommMatrixAnalysis(BaseCommAnalysis): - SAVED_JSON = "cluster_communication_matrix.json" - STAT_LIST = ['middle', 'top', 'bottom', 'total'] - TOTAL = 'total' - - def __init__(self, param: dict): - super().__init__(param) - self.communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS) - - @staticmethod - def combine_link(link_info_dict: dict, single_link_dict: dict): - link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE) - link_info_dict[Constant.OP_NAME] = single_link_dict.get(Constant.OP_NAME, '') - link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0) - link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0) - - def run(self): - if not self.communication_ops: - return - self.split_op_by_group() - self.combine_ops_total_info() - self.dump_data() - - def compute_total_info(self, step_dict: dict): - self.merge_same_links(step_dict) - self.combine_link_info(step_dict) - - def merge_same_links(self, step_dict: dict): - def process_link_key(): - for link_key in rank_dict: - if '-' not in link_key: - print(f"[WARNING] {op_name} has an invalid link key {link_key}!") - break - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - if src_rank == dst_rank: - if src_rank not in project_local_global_rank_map: - project_local_global_rank_map[src_rank] = rank_id - elif project_local_global_rank_map.get(src_rank) != rank_id: - print(f"[WARNING] In the same communication group, local ranks projecting to global ranks repeat!") - self.combine_link(link_info[link_key], rank_dict[link_key]) - - def 
convert_local_to_global_rank(): - tmp_link = {} - for link_key, link_dict in link_info.items(): - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - src_rank = project_local_global_rank_map[src_rank] \ - if src_rank in project_local_global_rank_map else src_rank - dst_rank = project_local_global_rank_map[dst_rank] \ - if dst_rank in project_local_global_rank_map else dst_rank - link_dict[Constant.BANDWIDTH_GB_S] = \ - self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), - link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - tmp_link[f"{src_rank}-{dst_rank}"] = link_dict - return tmp_link - - project_local_global_rank_map = dict() - for op_name, op_dict in step_dict.items(): - link_info = defaultdict(lambda: { - Constant.TRANSPORT_TYPE: '', - Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0, - Constant.OP_NAME: '' - }) - for rank_id, rank_dict in op_dict.items(): - process_link_key() - step_dict[op_name] = convert_local_to_global_rank() - - def combine_link_info(self, step_dict: dict): - total_op_info = defaultdict(lambda: { - Constant.TRANSPORT_TYPE: '', - Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0, - Constant.OP_NAME: '' - }) - for op_name, op_dict in step_dict.items(): - if self.check_add_op(op_name): - for link_key, link_dict in op_dict.items(): - self.combine_link(total_op_info[link_key], link_dict) - for link_key, link_dict in total_op_info.items(): - link_dict[Constant.BANDWIDTH_GB_S] = \ - self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), - link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - step_dict[Constant.TOTAL_OP_INFO] = total_op_info - - def check_add_op(self: any, op_name: str): - """ - 兼容2个版本,判断是否需要将此算子信息相加 - """ - for stat_name in self.STAT_LIST: - if stat_name in op_name: - if stat_name != self.TOTAL: - return False - return True diff --git a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py index 
d24a7f1fe635e62c0857e276578463539a61ee76..f570deee1c9ac53f7bbe65be9660d9e014576d04 100644 --- a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py +++ b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py @@ -14,8 +14,8 @@ # limitations under the License. import os -from collections import defaultdict +from common_func.db_manager import DBManager from common_func.constant import Constant from common_func.file_manager import FileManager from prof_bean.step_trace_time_bean import StepTraceTimeBean @@ -23,6 +23,7 @@ from prof_bean.step_trace_time_bean import StepTraceTimeBean class StepTraceTimeAnalysis: CLUSTER_TRACE_TIME_CSV = "cluster_step_trace_time.csv" + CLUSTER_TRACE_TIME_TABLE = "ClusterStepTraceTime" def __init__(self, param: dict): self.collection_path = param.get(Constant.COLLECTION_PATH) @@ -30,6 +31,7 @@ class StepTraceTimeAnalysis: self.communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP) self.step_time_dict = {} self.step_data_list = [] + self.data_type = param.get(Constant.DATA_TYPE) @staticmethod def get_max_data_row(data_group_list: list): @@ -51,21 +53,46 @@ class StepTraceTimeAnalysis: def dump_data(self): if not self.step_data_list: print("[WARNING] Can't get step time info!") - headers = self.get_headers() - FileManager.create_csv_file(self.collection_path, self.step_data_list, self.CLUSTER_TRACE_TIME_CSV, headers) + return + if self.data_type == Constant.TEXT: + headers = self.get_headers() + FileManager.create_csv_file(self.collection_path, self.step_data_list, self.CLUSTER_TRACE_TIME_CSV, headers) + else: + output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + DBManager.create_tables(result_db, self.CLUSTER_TRACE_TIME_TABLE) + conn, cursor = DBManager.create_connect_db(result_db) + sql = "insert into {} values ({value})".format(self.CLUSTER_TRACE_TIME_TABLE, + 
value="?," * (len(self.step_data_list[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, self.step_data_list) + DBManager.destroy_db_connect(conn, cursor) def load_step_trace_time_data(self): for rank_id, profiling_dir_path in self.data_map.items(): - step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) - if step_time_file: - self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) + if self.data_type == Constant.TEXT: + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) + if os.path.exists(step_time_file): + self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) + else: + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, + Constant.DB_COMMUNICATION_ANALYZER) + if (os.path.exists(step_time_file) and + DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE)): + conn, cursor = DBManager.create_connect_db(step_time_file) + sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE) + data = DBManager.fetch_all_data(cursor, sql, is_dict=False) + self.step_time_dict[rank_id] = data + DBManager.destroy_db_connect(conn, cursor) if not self.step_time_dict.get(rank_id): - print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time.json.") + print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time data in {self.data_type} file.") def analyze_step_time(self): for rank_id, data_bean_list in self.step_time_dict.items(): for data_bean in data_bean_list: - self.step_data_list.append([data_bean.step, Constant.RANK, rank_id] + data_bean.row) + if self.data_type == Constant.TEXT: + self.step_data_list.append([data_bean.step, Constant.RANK, rank_id] + data_bean.row) + else: + self.step_data_list.append([data_bean[0], Constant.RANK, rank_id] + list(data_bean[1:])) stage_list = self.communication_group.get(Constant.P2P) if not stage_list: return @@ -80,7 
+107,11 @@ class StepTraceTimeAnalysis: step_group_dict.setdefault(key, []).append(data_list[3:]) for key, data_group_list in step_group_dict.items(): - self.step_data_list.append([key[0], Constant.STAGE, key[1]] + self.get_max_data_row(data_group_list)) + if self.data_type == Constant.TEXT: + self.step_data_list.append([key[0], Constant.STAGE, key[1]] + self.get_max_data_row(data_group_list)) + else: + index = "(" + ",".join(str(i) for i in key[1]) + ")" + self.step_data_list.append([key[0], Constant.STAGE, index] + self.get_max_data_row(data_group_list)) def get_headers(self): if self.step_time_dict: diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index e07cac170300650bbf735f7e302b33377dd30a5e..24454622119acbb223c70dfea65d3b792b00444c 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -47,25 +47,31 @@ class Interface: ascend_pt_dirs.append(os.path.join(root, dir_name)) if dir_name.endswith(self.ASCEND_MS): ascend_ms_dirs.append(os.path.join(root, dir_name)) - pt_data_map = PytorchDataPreprocessor(ascend_pt_dirs).get_data_map() + pytorch_processor = PytorchDataPreprocessor(ascend_pt_dirs) + pt_data_map = pytorch_processor.get_data_map() + data_type = pytorch_processor.get_data_type() ms_data_map = MindsporeDataPreprocessor(ascend_ms_dirs).get_data_map() if pt_data_map and ms_data_map: print("[ERROR] Can not analyze pytorch and mindspore meantime.") - return[] - return pt_data_map if pt_data_map else ms_data_map + return [] + return (pt_data_map, data_type) if pt_data_map else (ms_data_map, Constant.TEXT) def run(self): PathManager.check_input_directory_path(self.collection_path) PathManager.check_path_owner_consistent(self.collection_path) FileManager.create_output_dir(self.collection_path) - data_map = self.allocate_prof_data() + data_map, data_type = self.allocate_prof_data() if not data_map: print("[WARNING] Can not get rank info or profiling 
data.") return + if data_type == Constant.INVALID: + print("[ERROR] The current folder contains both DB and other files. Please check.") + return params = { Constant.COLLECTION_PATH: self.collection_path, Constant.DATA_MAP: data_map, - Constant.ANALYSIS_MODE: self.analysis_mode + Constant.ANALYSIS_MODE: self.analysis_mode, + Constant.DATA_TYPE: data_type } comm_data_dict = CommunicationGroupGenerator(params).generate() params[Constant.COMM_DATA_DICT] = comm_data_dict diff --git a/profiler/cluster_analyse/cluster_data_preprocess/data_preprocessor.py b/profiler/cluster_analyse/cluster_data_preprocess/data_preprocessor.py index ebc9647c208b05f51698563b8dabb7d13c28c7ec..72d65ae6571e68564e46f43463843d1f46a3a69e 100644 --- a/profiler/cluster_analyse/cluster_data_preprocess/data_preprocessor.py +++ b/profiler/cluster_analyse/cluster_data_preprocess/data_preprocessor.py @@ -12,15 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import os from abc import abstractmethod class DataPreprocessor: - def __init__(self, collection_path: str): - self.collection_path = collection_path + PROFILER_INFO_HEAD = 'profiler_info_' + PROFILER_INFO_EXTENSION = '.json' + + def __init__(self, path_list: list): + self.path_list = path_list self.data_map = {} @abstractmethod - def input_data(self): + def get_data_map(self): pass + + def get_rank_id(self, dir_name: str) -> int: + files = os.listdir(dir_name) + for file_name in files: + if file_name.startswith(self.PROFILER_INFO_HEAD) and file_name.endswith(self.PROFILER_INFO_EXTENSION): + rank_id_str = file_name[len(self.PROFILER_INFO_HEAD): -1 * len(self.PROFILER_INFO_EXTENSION)] + try: + rank_id = int(rank_id_str) + except ValueError: + rank_id = -1 + return rank_id + return -1 diff --git a/profiler/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py b/profiler/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py index 85debdd31bb07cf96b91c12eb731cc00b00fcaa3..a3e09983ddb54b972a9e343c1661b5c8b2cbb8c8 100644 --- a/profiler/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py +++ b/profiler/cluster_analyse/cluster_data_preprocess/mindspore_data_preprocessor.py @@ -14,17 +14,14 @@ # limitations under the License. 
from collections import defaultdict -import os -from common_func.file_manager import FileManager -from common_func.path_manager import PathManager +from cluster_data_preprocess.data_preprocessor import DataPreprocessor -class MindsporeDataPreprocessor: - PROFILER_INFO_HEAD = 'profiler_info_' - PROFILER_INFO_EXTENSION = '.json' - def __init__(self, path_list: str): - self.path_list = path_list +class MindsporeDataPreprocessor(DataPreprocessor): + + def __init__(self, path_list: list): + super().__init__(path_list) def get_data_map(self) -> dict: rank_id_map = defaultdict(list) @@ -35,23 +32,10 @@ class MindsporeDataPreprocessor: continue rank_id_map[rank_id].append(dir_name) - ret_dict = dict() try: for (rank_id, dir_list) in rank_id_map.items(): dir_list.sort(key=lambda x: x.split('_')[-3]) - ret_dict[rank_id] = dir_list[0] + self.data_map[rank_id] = dir_list[0] except Exception as e: raise RuntimeError("Found invalid directory name!") from e - return ret_dict - - def get_rank_id(self, dir_name: str) -> int: - files = os.listdir(dir_name) - for file_name in files: - if file_name.startswith(self.PROFILER_INFO_HEAD) and file_name.endswith(self.PROFILER_INFO_EXTENSION): - rank_id_str = file_name[len(self.PROFILER_INFO_HEAD): -1 * len(self.PROFILER_INFO_EXTENSION)] - try: - rank_id = int(rank_id_str) - except ValueError: - rank_id = -1 - return rank_id - return -1 + return self.data_map diff --git a/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py b/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py index f1e4c062a7c05656980f0767a3180154e91942ae..55c3d03958b97c427fe8fde0625e72ea4dee8997 100644 --- a/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py +++ b/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py @@ -12,19 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - +import glob from collections import defaultdict import os + +from cluster_data_preprocess.data_preprocessor import DataPreprocessor +from common_func.constant import Constant from common_func.file_manager import FileManager -from common_func.path_manager import PathManager -class PytorchDataPreprocessor: - PROFILER_INFO_HEAD = 'profiler_info_' - PROFILER_INFO_EXTENSION = '.json' +class PytorchDataPreprocessor(DataPreprocessor): - def __init__(self, path_list: str): - self.path_list = path_list + def __init__(self, path_list: list): + super().__init__(path_list) + self.data_type = set() def get_data_map(self) -> dict: rank_id_map = defaultdict(list) @@ -33,25 +34,23 @@ class PytorchDataPreprocessor: if rank_id < 0: print('[Error]fail to get rankid or rankid invalid.') continue + for file_name in os.listdir(dir_name): + if file_name.startswith(self.PROFILER_INFO_HEAD) and file_name.endswith(self.PROFILER_INFO_EXTENSION): + file_path = os.path.join(dir_name, file_name) + config = FileManager.read_json_file(file_path) + self.data_type.add(config.get(Constant.CONFIG, {}).get(Constant.EXPER_CONFIG, {}). 
+ get(Constant.EXPORT_TYPE, Constant.TEXT)) rank_id_map[rank_id].append(dir_name) - ret_dict = dict() try: for (rank_id, dir_list) in rank_id_map.items(): dir_list.sort(key=lambda x: x.split('_')[-3]) - ret_dict[rank_id] = dir_list[0] + self.data_map[rank_id] = dir_list[0] except Exception as e: raise RuntimeError("Found invalid directory name!") from e - return ret_dict + return self.data_map - def get_rank_id(self, dir_name: str) -> int: - files = os.listdir(dir_name) - for file_name in files: - if file_name.startswith(self.PROFILER_INFO_HEAD) and file_name.endswith(self.PROFILER_INFO_EXTENSION): - rank_id_str = file_name[len(self.PROFILER_INFO_HEAD): -1 * len(self.PROFILER_INFO_EXTENSION)] - try: - rank_id = int(rank_id_str) - except ValueError: - rank_id = -1 - return rank_id - return -1 + def get_data_type(self): + if len(self.data_type) == 1: + return self.data_type.pop() + return Constant.INVALID diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index e426a9d22567ae9e70411f709c1c09ce02cbdeca..3b4126de792357b6a7a0d4d0d4dbce40067c4651 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -30,6 +30,7 @@ class Constant(object): MAX_JSON_SIZE = 1024 * 1024 * 1024 * 10 MAX_CSV_SIZE = 1024 * 1024 * 1024 * 5 MAX_PATH_LENGTH = 4096 + MAX_READ_DB_FILE_BYTES = 1024 * 1024 * 1024 * 8 # communication P2P = "p2p" @@ -56,6 +57,9 @@ class Constant(object): OP_NAME = "Op Name" BANDWIDTH_GB_S = "Bandwidth(GB/s)" COMMUNICATION = "communication.json" + ELAPSE_TIME_MS = "Elapse Time(ms)" + IDLE_TIME_MS = "Idle Time(ms)" + LARGE_PACKET_RATIO = "Large Packet Ratio" # params DATA_MAP = "data_map" @@ -66,11 +70,12 @@ class Constant(object): COMMUNICATION_GROUP = "communication_group" TRANSPORT_TYPE = "Transport Type" COMM_DATA_DICT = "comm_data_dict" + DATA_TYPE = "data_type" ANALYSIS_MODE = "analysis_mode" # step time - RANK = 'rank' - STAGE = 'stage' + 
RANK = "rank" + STAGE = "stage" # epsilon EPS = 1e-15 @@ -78,3 +83,23 @@ class Constant(object): # file suffix JSON_SUFFIX = ".json" CSV_SUFFIX = ".csv" + + # result files type + TEXT = "text" + DB = "db" + INVALID = "invalid" + + # db name + DB_COMMUNICATION_ANALYZER = "analysis.db" + DB_CLUSTER_COMMUNICATION_ANALYZER = "cluster_analysis.db" + + # db tables + TABLE_COMM_ANALYZER_BANDWIDTH = "CommAnalyzerBandwidth" + TABLE_COMM_ANALYZER_TIME = "CommAnalyzerTime" + TABLE_COMM_ANALYZER_MATRIX = "CommAnalyzerMatrix" + TABLE_STEP_TRACE = "StepTraceTime" + + # data config key + CONFIG = "config" + EXPER_CONFIG = "experimental_config" + EXPORT_TYPE = "_export_type" diff --git a/profiler/cluster_analyse/common_func/db_manager.py b/profiler/cluster_analyse/common_func/db_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..7b1d641d745d957aca0a83d9c1d62e77575228f4 --- /dev/null +++ b/profiler/cluster_analyse/common_func/db_manager.py @@ -0,0 +1,211 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sqlite3 + +from common_func.constant import Constant +from common_func.empty_class import EmptyClass +from common_func.file_manager import check_db_path_valid +from common_func.tables_config import TablesConfig + + +class DBManager: + """ + class to manage DB operation + """ + FETCH_SIZE = 10000 + INSERT_SIZE = 10000 + MAX_ROW_COUNT = 100000000 + + @staticmethod + def create_connect_db(db_path: str) -> tuple: + """ + create and connect database + """ + if check_db_path_valid(db_path, is_create=True): + try: + conn = sqlite3.connect(db_path) + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return EmptyClass("empty conn"), EmptyClass("empty curs") + try: + if isinstance(conn, sqlite3.Connection): + curs = conn.cursor() + os.chmod(db_path, Constant.FILE_AUTHORITY) + return conn, curs + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return EmptyClass("empty conn"), EmptyClass("empty curs") + return EmptyClass("empty conn"), EmptyClass("empty curs") + + @staticmethod + def destroy_db_connect(conn: any, curs: any) -> None: + """ + destroy db connection + """ + try: + if isinstance(curs, sqlite3.Cursor): + curs.close() + except sqlite3.Error as err: + print(f"[ERROR] {err}") + try: + if isinstance(conn, sqlite3.Connection): + conn.close() + except sqlite3.Error as err: + print(f"[ERROR] {err}") + + @staticmethod + def judge_table_exists(curs: any, table_name: str) -> any: + """ + judge table exists + """ + if not isinstance(curs, sqlite3.Cursor): + return False + try: + curs.execute("select count(*) from sqlite_master where type='table' and name=?", (table_name,)) + return curs.fetchone()[0] + except sqlite3.Error as err: + print("[ERROR] {}".format(err)) + return False + + @staticmethod + def sql_generate_table(table_map: str): + header_with_type_begin = "(" + header_with_type_end = ")" + header_with_type_list = [] + if table_map in TablesConfig.DATA: + items = TablesConfig.DATA[table_map] + for item in items: + if item[0] == 
"index": + header_with_type_list.append('"' + item[0] + '" ' + item[1].split(",")[0]) + else: + header_with_type_list.append(item[0] + ' ' + item[1].split(",")[0]) + header_with_type_begin += ",".join(header_with_type_list) + header_with_type_begin += header_with_type_end + return header_with_type_begin + return "" + + @classmethod + def check_tables_in_db(cls, db_path: any, *tables: any) -> bool: + if check_db_path_valid(db_path): + conn, curs = cls.create_connect_db(db_path) + if not (conn and curs): + return False + res = True + for table in tables: + if not cls.judge_table_exists(curs, table): + res = False + break + cls.destroy_db_connect(conn, curs) + return res + return False + + @classmethod + def create_tables(cls, db_path: any, *tables: any): + conn, curs = cls.create_connect_db(db_path) + if not (conn and curs): + return + for table_name in tables: + if cls.judge_table_exists(curs, table_name): + drop_sql = "drop table {0}".format(table_name) + cls.execute_sql(conn, drop_sql) + table_map = "{0}Map".format(table_name) + header_with_type = cls.sql_generate_table(table_map) + sql = "CREATE TABLE IF NOT EXISTS " + table_name + header_with_type + cls.execute_sql(conn, sql) + cls.destroy_db_connect(conn, curs) + + @staticmethod + def execute_sql(conn: any, sql: str, params: any = None) -> bool: + """ + execute sql + """ + try: + if isinstance(conn, sqlite3.Connection): + if params: + conn.cursor().execute(sql, params) + else: + conn.cursor().execute(sql) + conn.commit() + return True + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return False + print("[ERROR] conn is invalid param") + return False + + @staticmethod + def executemany_sql(conn: any, sql: str, params: any) -> bool: + """ + execute many sql once + """ + try: + if isinstance(conn, sqlite3.Connection): + conn.cursor().executemany(sql, params) + conn.commit() + return True + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return False + print("[ERROR] conn is invalid param") + 
return False + + @classmethod + def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None, is_dict: bool = True) -> list: + """ + fetch 10000 num of data from db each time to get all data + """ + if not isinstance(curs, sqlite3.Cursor): + return [] + data = [] + try: + if param: + res = curs.execute(sql, param) + else: + res = curs.execute(sql) + except sqlite3.Error as err: + print(f"[ERROR] {err}") + curs.row_factory = None + return [] + try: + description = res.description + while True: + res = curs.fetchmany(cls.FETCH_SIZE) + if is_dict: + data += CustomizedDictFactory.generate_dict_from_db(res, description) + else: + data += res + if len(data) > cls.MAX_ROW_COUNT: + print("[WARRING] The records count in the table exceeds the limit!") + if len(res) < cls.FETCH_SIZE: + break + return data + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return [] + finally: + curs.row_factory = None + + +class CustomizedDictFactory: + @staticmethod + def generate_dict_from_db(data_result: any, description: any) -> any: + description_set = [i[0] for i in description] + res = [] + for data in data_result: + data_dict = dict(zip(description_set, data)) + res.append(data_dict) + return res diff --git a/profiler/cluster_analyse/common_func/empty_class.py b/profiler/cluster_analyse/common_func/empty_class.py new file mode 100644 index 0000000000000000000000000000000000000000..df100d156fa064cca4514260db0b2e843e217d09 --- /dev/null +++ b/profiler/cluster_analyse/common_func/empty_class.py @@ -0,0 +1,20 @@ +class EmptyClass: + + def __init__(self: any, info: str = "") -> None: + self._info = info + + @classmethod + def __bool__(cls: any) -> bool: + return False + + @classmethod + def __str__(cls: any) -> str: + return "" + + @property + def info(self: any) -> str: + return self._info + + @staticmethod + def is_empty() -> bool: + return True diff --git a/profiler/cluster_analyse/common_func/file_manager.py b/profiler/cluster_analyse/common_func/file_manager.py index 
3853c806f92de1d8da14e32105fcc869789a9a40..28ecbeaaf16ec5461660f414df03728b36b521d7 100644 --- a/profiler/cluster_analyse/common_func/file_manager.py +++ b/profiler/cluster_analyse/common_func/file_manager.py @@ -115,3 +115,13 @@ class FileManager: file_size = os.path.getsize(file_path) if file_size > limit_size: raise RuntimeError(f"The file({base_name}) size exceeds the preset max value.") + + +def check_db_path_valid(path: str, is_create: bool = False, max_size: int = Constant.MAX_READ_DB_FILE_BYTES) -> bool: + if os.path.islink(path): + print(f'[ERROR] The db file path: {path} is link. Please check the path') + return False + if not is_create and os.path.exists(path) and os.path.getsize(path) > max_size: + print(f'[ERROR] The db file: {path} is too large to read. Please check the file') + return False + return True diff --git a/profiler/cluster_analyse/common_func/table_constant.py b/profiler/cluster_analyse/common_func/table_constant.py new file mode 100644 index 0000000000000000000000000000000000000000..de6d47e97e5683493905de5353a9978195e87b70 --- /dev/null +++ b/profiler/cluster_analyse/common_func/table_constant.py @@ -0,0 +1,27 @@ +class TableConstant: + + RANK_SET = "rank_set" + STEP = "step" + RANK_ID = "rank_id" + TYPE = "type" + HCCL_OP_NAME = "hccl_op_name" + GROUP_NAME = "group_name" + START_TIMESTAMP = "start_timestamp" + ELAPSED_TIME = "elapse_time" + TRANSIT_TIME = "transit_time" + WAIT_TIME = "wait_time" + SYNCHRONIZATION_TIME = "synchronization_time" + IDLE_TIME = "idle_time" + SYNCHRONIZATION_TIME_RATIO = "synchronization_time_ratio" + WAIT_TIME_RATIO = "wait_time_ratio" + BAND_TYPE = "band_type" + TRANSIT_SIZE = "transit_size" + BANDWIDTH = "bandwidth" + LARGE_PACKET_RATIO = "large_packet_ratio" + PACKAGE_SIZE = "package_size" + COUNT = "count" + TOTAL_DURATION = "total_duration" + SRC_RANK = "src_rank" + DST_RANK = "dst_rank" + TRANSPORT_TYPE = "transport_type" + OPNAME = "op_name" diff --git 
a/profiler/cluster_analyse/common_func/tables_config.py b/profiler/cluster_analyse/common_func/tables_config.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe7d07ed12f67cf2c03181004ddb91f9a47c4d6 --- /dev/null +++ b/profiler/cluster_analyse/common_func/tables_config.py @@ -0,0 +1,64 @@ +class TablesConfig: + DATA = { + "ClusterCommAnalyzerTimeMap": [ + ("rank_set", "TEXT, null"), + ("step", "TEXT, null"), + ("rank_id", "INTEGER, null"), + ("hccl_op_name", "TEXT, null"), + ("group_name", "TEXT, null"), + ("start_timestamp", "NUMERIC, null"), + ("elapsed_time", "NUMERIC, null"), + ("transit_time", "NUMERIC, null"), + ("wait_time", "NUMERIC, null"), + ("synchronization_time", "NUMERIC, null"), + ("idle_time", "NUMERIC, null"), + ("synchronization_time_ratio", "NUMERIC, null"), + ("wait_time_ratio", "NUMERIC, null") + ], + "CommunicationGroupMap": [ + ("type", "TEXT, null"), + ("rank_set", "TEXT, null") + ], + "ClusterCommAnalyzerBandwidthMap": [ + ("rank_set", "TEXT, null"), + ("step", "TEXT, null"), + ("rank_id", "INTEGER, null"), + ("hccl_op_name", "TEXT, null"), + ("group_name", "TEXT, null"), + ("band_type", "TEXT, null"), + ("transit_size", "NUMERIC, null"), + ("transit_time", "NUMERIC, null"), + ("bandwidth", "NUMERIC, null"), + ("large_packet_ratio", "NUMERIC, null"), + ("package_size", "NUMERIC, null"), + ("count", "NUMERIC, null"), + ("total_duration", "NUMERIC, null") + ], + "ClusterCommAnalyzerMatrixMap": [ + ("rank_set", "TEXT, null"), + ("step", "TEXT, null"), + ("hccl_op_name", "TEXT, null"), + ("group_name", "TEXT, null"), + ("src_rank", "TEXT, null"), + ("dst_rank", "TEXT, null"), + ("transit_size", "NUMERIC, null"), + ("transit_time", "NUMERIC, null"), + ("bandwidth", "NUMERIC, null"), + ("transport_type", "TEXT, null"), + ("op_name", "TEXT, null") + ], + "ClusterStepTraceTimeMap": [ + ("step", "TEXT, null"), + ("type", "TEXT, null"), + ("index", "TEXT, null"), + ("computing", "NUMERIC, null"), + 
("communication_not_overlapped", "NUMERIC, null"), + ("overlapped", "NUMERIC, null"), + ("communication", "NUMERIC, null"), + ("free", "NUMERIC, null"), + ("stage", "NUMERIC, null"), + ("bubble", "NUMERIC, null"), + ("communication_not_overlapped_and_exclude_receive", "NUMERIC, null"), + ("preparing", "NUMERIC, null") + ] + } diff --git a/profiler/cluster_analyse/communication_group/base_communication_group.py b/profiler/cluster_analyse/communication_group/base_communication_group.py new file mode 100644 index 0000000000000000000000000000000000000000..923d479ee736edabc4c9e7a137664f3426b593e8 --- /dev/null +++ b/profiler/cluster_analyse/communication_group/base_communication_group.py @@ -0,0 +1,227 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from abc import abstractmethod +from collections import defaultdict +from copy import deepcopy +from multiprocessing import Pool + +from common_func.constant import Constant +from utils.data_transfer_adapter import DataTransferAdapter + + +class BaseCommunicationGroup: + def __init__(self, params: dict): + self.collection_path = params.get(Constant.COLLECTION_PATH) + self.data_map = params.get(Constant.DATA_MAP) + self.data_type = params.get(Constant.DATA_TYPE) + self.analysis_mode = params.get(Constant.ANALYSIS_MODE) + self.rank_comm_dir_dict = {} + self.p2p_link = [] + self.collective_group_dict = defaultdict(set) + self.p2p_comm_group = [] + self.communication_group = {} + self.communication_ops = [] + self.matrix_ops = [] + self.adapter = DataTransferAdapter() + + def load_communication_data(self): + comm_op_dirs = [] + for rank_id, profiling_dir_path in self.data_map.items(): + if self.data_type == Constant.TEXT: + comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON) + matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON) + else: + comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER) + matrix_dir = comm_dir + if os.path.exists(comm_dir) or os.path.exists(matrix_dir): + comm_op_dirs.append((rank_id, comm_dir, matrix_dir)) + else: + print( + f"[WARNING] Rank {rank_id} does not have valid communication data and communication_matrix data.") + with Pool() as p: + self.rank_comm_dir_dict = p.map(self.read_communication_func, comm_op_dirs) + + def set_p2p_groups(self): + self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x)) + while self.p2p_link: + union_set = deepcopy(self.p2p_link[0]) + rm_list = [self.p2p_link[0]] + for idx, link_rank_set_x in enumerate(self.p2p_link[1:]): + if UnionFind.is_connected(link_rank_set_x, union_set): + union_set = union_set.union(link_rank_set_x) + rm_list.append(link_rank_set_x) + 
self.p2p_comm_group.append(union_set) + self.p2p_link = [element for element in self.p2p_link if element not in rm_list] + + def generate_collective_communication_group(self): + self.communication_group[Constant.COLLECTIVE] = \ + [list(group) for group_name, group in self.collective_group_dict.items()] + + def generate_p2p_communication_group(self): + stage_group = {} + for group_name, rank_set in self.collective_group_dict.items(): + if not self.whether_valid_comm_group(rank_set): + continue + unioned_set = set() + remove_key = [] + for first_rank, stage in stage_group.items(): + if UnionFind.is_connected(rank_set, stage): + unioned_set = UnionFind.union(rank_set, stage, unioned_set) + remove_key.append(first_rank) + if unioned_set: + for key in remove_key: + del stage_group[key] + stage_group[min(unioned_set)] = unioned_set + else: + stage_group[min(rank_set)] = rank_set + first_rank_sort_list = sorted([first_rank for first_rank in stage_group]) + self.communication_group[Constant.P2P] = \ + [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list] + + def whether_valid_comm_group(self, rank_set: set): + """ + while distinguish which communication group should be used to infer stage info, these group should be ignored: + 1. 
group can not include more than 1 rank in every single p2p group + """ + for p2p_rank_set in self.p2p_comm_group: + if len(rank_set.intersection(p2p_rank_set)) > 1: + return False + return True + + @abstractmethod + def read_communication_func(self, params: tuple): + pass + + def analyze_communication_data(self): + for rank_id, rank_id_comm_dict, rank_id_matrix_dict in self.rank_comm_dir_dict: + for step_id, step_id_dict in rank_id_comm_dict.items(): + if not isinstance(step_id_dict, dict): + print(f"[WARNING] rank{rank_id}'s communication.json has a wrong data struct.") + continue + self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + for comm_op_type, comm_op_dict in step_id_dict.items(): + self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) + + for step_id, step_id_dict in rank_id_matrix_dict.items(): + if not isinstance(step_id_dict, dict): + print(f"[WARNING] rank{rank_id}'s communication_matrix.json has a wrong data struct.") + continue + self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) + self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + + @abstractmethod + def dump_data(self): + pass + + def collect_comm_data(self): + comm_data_dict = { + Constant.COLLECTIVE_GROUP: self.collective_group_dict, + Constant.COMMUNICATION_OPS: self.communication_ops, + Constant.MATRIX_OPS: self.matrix_ops, + Constant.COMMUNICATION_GROUP: self.communication_group + } + return comm_data_dict + + def generate(self): + self.load_communication_data() + self.analyze_communication_data() + self.set_p2p_groups() + self.generate_collective_communication_group() + self.generate_p2p_communication_group() + self.dump_data() + return self.collect_comm_data() + + def set_p2p_link(self, rank_id: int, step_id: str, rank_id_matrix_dict: dict): + ops = rank_id_matrix_dict.get(step_id, {}) + self.add_matrix_ops(rank_id, step_id, ops) + if not ops: + print(f"[WARNING] rank{rank_id} {step_id} do not have communication 
matrix ops data.") + return + p2p_ops = ops.get(Constant.P2P, {}) + for op_name, link_dict in p2p_ops.items(): + self.append_p2p_link(op_name, link_dict) + + def append_p2p_link(self, op_name, link_dict): + for link in link_dict: + if '-' not in link: + print(f"[WARNING] {op_name} has an invalid link key {link}!") + break + src_rank = int(link.split('-')[0]) + dst_rank = int(link.split('-')[1]) + if src_rank != dst_rank: + rank_set = {src_rank, dst_rank} + if rank_set in self.p2p_link: + continue + self.p2p_link.append(rank_set) + + def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict): + for comm_op in comm_op_dict: + if comm_op.startswith('Total'): + continue + group_name = comm_op.split('@')[-1] + self.collective_group_dict[group_name].add(rank_id) + + def add_communication_ops(self, rank_id: str, step_id: str, comm_op_type: str, comm_op_dict: dict): + for comm_op in comm_op_dict: + if comm_op.startswith('Total'): + continue + group_name = comm_op.split('@')[-1] + self.communication_ops.append({ + Constant.RANK_ID: rank_id, + Constant.STEP_ID: step_id, + Constant.COMM_OP_TYPE: comm_op_type, + Constant.COMM_OP_NAME: comm_op, + Constant.GROUP_NAME: group_name, + Constant.COMM_OP_INFO: comm_op_dict.get(comm_op) + }) + + def add_matrix_ops(self, rank_id: int, step_id: str, step_id_dict: dict): + for comm_op_type, comm_dict in step_id_dict.items(): + if comm_op_type != Constant.COLLECTIVE and comm_op_type != Constant.P2P: + print(f"[WARNING] Unknown communication operators type!") + continue + for op_name, op_link_info in comm_dict.items(): + if op_name.startswith('Total'): + continue + group_name = op_name.split('@')[-1] + self.matrix_ops.append({ + Constant.RANK_ID: rank_id, + Constant.STEP_ID: step_id, + Constant.COMM_OP_TYPE: comm_op_type, + Constant.COMM_OP_NAME: op_name, + Constant.GROUP_NAME: group_name, + Constant.COMM_OP_INFO: op_link_info + }) + + +class UnionFind(object): + """Disjoint Set Union""" + + @classmethod + def union(cls, first_set: 
set, second_set: set, third_set: set): + """make p and q the same set""" + return first_set | second_set | third_set + + @classmethod + def is_connected(cls, first_set: set, second_set: set): + """ + check whether set p and set q are connected + """ + if first_set & second_set: + return True + else: + return False diff --git a/profiler/cluster_analyse/communication_group/communication_db_group.py b/profiler/cluster_analyse/communication_group/communication_db_group.py new file mode 100644 index 0000000000000000000000000000000000000000..510dcd971357dfb4798e4d284a72fbb3f3a21859 --- /dev/null +++ b/profiler/cluster_analyse/communication_group/communication_db_group.py @@ -0,0 +1,57 @@ +import os + +from common_func.db_manager import DBManager +from common_func.constant import Constant +from communication_group.base_communication_group import BaseCommunicationGroup + + +class CommunicationDBGroup(BaseCommunicationGroup): + COMMUNICATION_GROUP_TABLE = "CommunicationGroup" + + def __init__(self, params: dict): + super().__init__(params) + + def read_communication_func(self, params: tuple): + if len(params) < 3: + return -1, ({}, {}, {}) + rank_id = params[0] + db_path = params[1] + time_data = [] + bandwidth_data = [] + matrix_data = [] + if os.path.exists(db_path): + conn, cursor = DBManager.create_connect_db(db_path) + time_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_TIME) + bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH) + matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX) + if (DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME, + Constant.TABLE_COMM_ANALYZER_BANDWIDTH) + and self.analysis_mode in ["all", "communication_time"]): + time_data = DBManager.fetch_all_data(cursor, time_info_sql) + bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql) + if (DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_MATRIX) + and 
self.analysis_mode in ["all", "communication_matrix"]): + matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql) + DBManager.destroy_db_connect(conn, cursor) + comm_data = self.adapter.transfer_comm_from_db_to_json(time_data, bandwidth_data) + comm_matrix_data = self.adapter.transfer_matrix_from_db_to_json(matrix_data) + return rank_id, comm_data, comm_matrix_data + + def dump_data(self): + output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + res = [] + for data_type, data_list in self.communication_group.items(): + for data in data_list: + rank_set = "(" + ",".join(str(i) for i in data) + ")" + data = [data_type, rank_set] + res.append(data) + if res: + DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE) + conn, cursor = DBManager.create_connect_db(result_db) + sql = "insert into {} values ({value})".format(self.COMMUNICATION_GROUP_TABLE, + value="?," * (len(res[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, res) + DBManager.destroy_db_connect(conn, cursor) + else: + print("[WARNING] The CommunicationGroup table won't be created because no data has been calculated.") diff --git a/profiler/cluster_analyse/communication_group/communication_group_generator.py b/profiler/cluster_analyse/communication_group/communication_group_generator.py index 4963bf95399fea29edf31be324a49801e7f485d1..3dca90454b608fe3ffb1c365854c2aa3950b6cee 100644 --- a/profiler/cluster_analyse/communication_group/communication_group_generator.py +++ b/profiler/cluster_analyse/communication_group/communication_group_generator.py @@ -13,211 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os -from copy import deepcopy -from multiprocessing import Pool -from collections import defaultdict from common_func.constant import Constant -from common_func.file_manager import FileManager +from communication_group.communication_db_group import CommunicationDBGroup +from communication_group.communication_json_group import CommunicationJsonGroup class CommunicationGroupGenerator: - COMMUNICATION_GROUP_JSON = "communication_group.json" + + GROUP_MAP = { + Constant.DB: CommunicationDBGroup, + Constant.TEXT: CommunicationJsonGroup + } def __init__(self, params: dict): - self.collection_path = params.get(Constant.COLLECTION_PATH) - self.data_map = params.get(Constant.DATA_MAP) - self.analysis_mode = params.get(Constant.ANALYSIS_MODE) - self.communication_group = {} - self.collective_group_dict = defaultdict(set) - self.p2p_group_dict = defaultdict(list) - self.rank_comm_dir_dict = {} - self.communication_ops = [] - self.p2p_comm_group = [] - self.p2p_link = [] - self.matrix_ops = [] + self.processor = self.GROUP_MAP.get(params.get(Constant.DATA_TYPE))(params) def generate(self): - self.load_communication_json() - self.analyze_communication_ops() - self.set_p2p_groups() - self.generate_collective_communication_group() - self.generate_p2p_communication_group() - FileManager.create_json_file(self.collection_path, self.communication_group, self.COMMUNICATION_GROUP_JSON) - comm_data_dict = { - Constant.COLLECTIVE_GROUP: self.collective_group_dict, - Constant.COMMUNICATION_OPS: self.communication_ops, - Constant.MATRIX_OPS: self.matrix_ops, - Constant.COMMUNICATION_GROUP: self.communication_group - } - return comm_data_dict - - def analyze_communication_ops(self): - for rank_id, rank_id_comm_dict, rank_id_matrix_dict in self.rank_comm_dir_dict: - for step_id, step_id_dict in rank_id_comm_dict.items(): - if not isinstance(step_id_dict, dict): - print(f"[WARNING] rank{rank_id}'s communication.json has a wrong data struct.") - continue - 
self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) - for comm_op_type, comm_op_dict in step_id_dict.items(): - self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) - - for step_id, step_id_dict in rank_id_matrix_dict.items(): - if not isinstance(step_id_dict, dict): - print(f"[WARNING] rank{rank_id}'s communication_matrix.json has a wrong data struct.") - continue - self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) - self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) - - def read_comm_json_func(self: any, params: tuple): - if len(params) < 3: - return -1, {}, {} - rank_id = params[0] - comm_json_path = params[1] - matrix_json_path = params[2] - comm_data = {} - matrix_data = {} - if os.path.exists(comm_json_path) and self.analysis_mode in ['all', 'communication_time']: - comm_data = FileManager.read_json_file(comm_json_path) - if os.path.exists(matrix_json_path) and self.analysis_mode in ['all', 'communication_matrix']: - matrix_data = FileManager.read_json_file(matrix_json_path) - return rank_id, comm_data, matrix_data - - def load_communication_json(self): - comm_op_dirs = [] - for rank_id, profiling_dir_path in self.data_map.items(): - comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON) - matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON) - if comm_dir and matrix_dir: - comm_op_dirs.append((rank_id, comm_dir, matrix_dir)) - else: - print(f"[WARNING] Rank {rank_id} does not have a valid communication.json or communication_matrix.json.") - with Pool() as p: - self.rank_comm_dir_dict = p.map(self.read_comm_json_func, comm_op_dirs) - - def generate_collective_communication_group(self): - self.communication_group[Constant.COLLECTIVE] = \ - [list(group) for group_name, group in self.collective_group_dict.items()] - - def whether_valid_comm_group(self, rank_set: set): - """ - while distinguish which 
communication group should be used to infer stage info, these group should be ignored: - 1. group can not include more than 1 rank in every single p2p group - """ - for p2p_rank_set in self.p2p_comm_group: - if len(rank_set.intersection(p2p_rank_set)) > 1: - return False - return True - - def generate_p2p_communication_group(self): - stage_group = {} - for group_name, rank_set in self.collective_group_dict.items(): - if not self.whether_valid_comm_group(rank_set): - continue - unioned_set = set() - remove_key = [] - for first_rank, stage in stage_group.items(): - if UnionFind.is_connected(rank_set, stage): - unioned_set = UnionFind.union(rank_set, stage, unioned_set) - remove_key.append(first_rank) - if unioned_set: - for key in remove_key: - del stage_group[key] - stage_group[min(unioned_set)] = unioned_set - else: - stage_group[min(rank_set)] = rank_set - first_rank_sort_list = sorted([first_rank for first_rank in stage_group]) - self.communication_group[Constant.P2P] = \ - [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list] - - def set_p2p_groups(self): - self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x)) - while self.p2p_link: - union_set = deepcopy(self.p2p_link[0]) - rm_list = [self.p2p_link[0]] - for idx, link_rank_set_x in enumerate(self.p2p_link[1:]): - if UnionFind.is_connected(link_rank_set_x, union_set): - union_set = union_set.union(link_rank_set_x) - rm_list.append(link_rank_set_x) - self.p2p_comm_group.append(union_set) - self.p2p_link = [element for element in self.p2p_link if element not in rm_list] - - def set_p2p_link(self, rank_id: int, step_id: str, rank_id_matrix_dict: dict): - ops = rank_id_matrix_dict.get(step_id, {}) - self.add_matrix_ops(rank_id, step_id, ops) - if not ops: - print(f"[WARNING] rank{rank_id} {step_id} do not have communication matrix ops data.") - return - p2p_ops = ops.get(Constant.P2P, {}) - for op_name, link_dict in p2p_ops.items(): - self.append_p2p_link(op_name, link_dict) - - def 
append_p2p_link(self, op_name, link_dict): - for link in link_dict: - if '-' not in link: - print(f"[WARNING] {op_name} has an invalid link key {link}!") - break - src_rank = int(link.split('-')[0]) - dst_rank = int(link.split('-')[1]) - if src_rank != dst_rank: - rank_set = set([src_rank, dst_rank]) - if rank_set in self.p2p_link: - continue - self.p2p_link.append(rank_set) - - def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict): - for comm_op in comm_op_dict: - if comm_op.startswith('Total'): - continue - group_name = comm_op.split('@')[-1] - self.collective_group_dict[group_name].add(rank_id) - - def add_communication_ops(self, rank_id: str, step_id: str, comm_op_type: str, comm_op_dict: dict): - for comm_op in comm_op_dict: - if comm_op.startswith('Total'): - continue - group_name = comm_op.split('@')[-1] - self.communication_ops.append({ - Constant.RANK_ID: rank_id, - Constant.STEP_ID: step_id, - Constant.COMM_OP_TYPE: comm_op_type, - Constant.COMM_OP_NAME: comm_op, - Constant.GROUP_NAME: group_name, - Constant.COMM_OP_INFO: comm_op_dict.get(comm_op) - }) - - def add_matrix_ops(self, rank_id: int, step_id: str, step_id_dict: dict): - for comm_op_type, comm_dict in step_id_dict.items(): - if comm_op_type != Constant.COLLECTIVE and comm_op_type != Constant.P2P: - print(f"[WARNING] Unknown communication operators type!") - continue - for op_name, op_link_info in comm_dict.items(): - if op_name.startswith('Total'): - continue - group_name = op_name.split('@')[-1] - self.matrix_ops.append({ - Constant.RANK_ID: rank_id, - Constant.STEP_ID: step_id, - Constant.COMM_OP_TYPE: comm_op_type, - Constant.COMM_OP_NAME: op_name, - Constant.GROUP_NAME: group_name, - Constant.COMM_OP_INFO: op_link_info - }) - - -class UnionFind(object): - """Disjoint Set Union""" - @classmethod - def union(cls, p: set, q: set, o: set): - """make p and q the same set""" - return p | q | o - - @classmethod - def is_connected(cls, p: set, q: set): - """ - check whether set p and 
set q are connected - """ - if p & q: - return True - else: - return False + return self.processor.generate() diff --git a/profiler/cluster_analyse/communication_group/communication_json_group.py b/profiler/cluster_analyse/communication_group/communication_json_group.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e01e3abfde4d8f180043a5bf9a50c6b5a4964c --- /dev/null +++ b/profiler/cluster_analyse/communication_group/communication_json_group.py @@ -0,0 +1,44 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from common_func.constant import Constant +from common_func.file_manager import FileManager +from communication_group.base_communication_group import BaseCommunicationGroup + + +class CommunicationJsonGroup(BaseCommunicationGroup): + COMMUNICATION_GROUP_JSON = "communication_group.json" + + def __init__(self, params: dict): + super().__init__(params) + + def dump_data(self): + FileManager.create_json_file(self.collection_path, self.communication_group, self.COMMUNICATION_GROUP_JSON) + + def read_communication_func(self: any, params: tuple): + if len(params) < 3: + return -1, {}, {} + rank_id = params[0] + comm_json_path = params[1] + matrix_json_path = params[2] + comm_data = {} + matrix_data = {} + if os.path.exists(comm_json_path) and self.analysis_mode in ["all", "communication_time"]: + comm_data = FileManager.read_json_file(comm_json_path) + if os.path.exists(matrix_json_path) and self.analysis_mode in ["all", "communication_matrix"]: + matrix_data = FileManager.read_json_file(matrix_json_path) + return rank_id, comm_data, matrix_data diff --git a/profiler/cluster_analyse/utils/__init__.py b/profiler/cluster_analyse/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/cluster_analyse/utils/data_transfer_adapter.py b/profiler/cluster_analyse/utils/data_transfer_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..1f306415fa789ae0dab7d8751b1c240b3433de0d --- /dev/null +++ b/profiler/cluster_analyse/utils/data_transfer_adapter.py @@ -0,0 +1,142 @@ +import copy + +from common_func.constant import Constant +from common_func.table_constant import TableConstant + + +class DataTransferAdapter(object): + COMM_TIME_TABLE_COLUMN = [TableConstant.START_TIMESTAMP, TableConstant.ELAPSED_TIME, TableConstant.TRANSIT_TIME, + TableConstant.WAIT_TIME, TableConstant.SYNCHRONIZATION_TIME, TableConstant.IDLE_TIME, + 
TableConstant.SYNCHRONIZATION_TIME_RATIO, TableConstant.WAIT_TIME_RATIO] + COMM_TIME_JSON_COLUMN = [Constant.START_TIMESTAMP, Constant.ELAPSE_TIME_MS, Constant.TRANSIT_TIME_MS, + Constant.WAIT_TIME_MS, Constant.SYNCHRONIZATION_TIME_MS, Constant.IDLE_TIME_MS, + Constant.SYNCHRONIZATION_TIME_RATIO, Constant.WAIT_TIME_RATIO] + MATRIX_TABLE_COLUMN = [TableConstant.TRANSIT_SIZE, TableConstant.TRANSIT_TIME, TableConstant.BANDWIDTH, + TableConstant.TRANSPORT_TYPE, TableConstant.OPNAME] + MATRIX_JSON_COLUMN = [Constant.TRANSIT_SIZE_MB, Constant.TRANSIT_TIME_MS, Constant.BANDWIDTH_GB_S, + Constant.TRANSPORT_TYPE, Constant.OP_NAME] + COMM_BD_TABLE_COLUMN = [TableConstant.TRANSIT_SIZE, TableConstant.TRANSIT_TIME, TableConstant.BANDWIDTH, + TableConstant.LARGE_PACKET_RATIO] + COMM_BD_JSON_COLUMN = [Constant.TRANSIT_SIZE_MB, Constant.TRANSIT_TIME_MS, Constant.BANDWIDTH_GB_S, + Constant.LARGE_PACKET_RATIO] + + def __init__(self): + super().__init__() + + def transfer_comm_from_db_to_json(self, time_info: list, bandwidth_info: list): + result = {} + if not time_info and not bandwidth_info: + return result + for time_data in time_info: + comm_time = dict() + hccl_name = time_data[TableConstant.HCCL_OP_NAME] + "@" + time_data[TableConstant.GROUP_NAME] + for key, value in dict(zip(self.COMM_TIME_JSON_COLUMN, self.COMM_TIME_TABLE_COLUMN)).items(): + if not key.endswith("ratio"): + comm_time[key] = time_data.get(value, 0) + result.setdefault(time_data[TableConstant.STEP], {}).setdefault(time_data[TableConstant.TYPE], {}). 
\ + setdefault(hccl_name, {})[Constant.COMMUNICATION_TIME_INFO] = comm_time + hccl_set = set() + for bd_data in bandwidth_info: + hccl_name = bd_data[TableConstant.HCCL_OP_NAME] + "@" + bd_data[TableConstant.GROUP_NAME] + hccl_set.add(hccl_name) + for hccl in hccl_set: + comm_bd = dict() + for bd_data in bandwidth_info: + if hccl == (bd_data[TableConstant.HCCL_OP_NAME] + "@" + bd_data[TableConstant.GROUP_NAME]): + temp_dict = dict() + key_dict = dict(zip(self.COMM_BD_JSON_COLUMN, self.COMM_BD_TABLE_COLUMN)) + self.set_value_by_key(temp_dict, bd_data, key_dict) + comm_bd.setdefault(bd_data[TableConstant.TRANSPORT_TYPE], temp_dict).setdefault( + Constant.SIZE_DISTRIBUTION, {})[bd_data[TableConstant.PACKAGE_SIZE]] = \ + [bd_data[TableConstant.COUNT], bd_data[TableConstant.TOTAL_DURATION]] + result.setdefault(bd_data[TableConstant.STEP], {}).setdefault(bd_data[TableConstant.TYPE], {}). \ + setdefault(hccl, {})[Constant.COMMUNICATION_BANDWIDTH_INFO] = comm_bd + return result + + def transfer_comm_from_json_to_db(self, res_data: dict): + res_comm_data, res_bd_data = list(), list() + + def split_comm_time(): + for rank_id, comm_data in op_data.items(): + time_data = comm_data.get(Constant.COMMUNICATION_TIME_INFO) + res_time = set_only_value(rank_id) + for key, value in dict(zip(self.COMM_TIME_TABLE_COLUMN, self.COMM_TIME_JSON_COLUMN)).items(): + res_time[key] = time_data.get(value, 0) + res_comm_data.append(res_time) + bd_data = comm_data.get(Constant.COMMUNICATION_BANDWIDTH_INFO, {}) + for transport_type, data in bd_data.items(): + res_bandwidth = set_only_value(rank_id) + key_dict = dict(zip(self.COMM_BD_TABLE_COLUMN, self.COMM_BD_JSON_COLUMN)) + res_bandwidth[TableConstant.TRANSPORT_TYPE] = transport_type + self.set_value_by_key(res_bandwidth, data, key_dict) + for key, value in data.get(Constant.SIZE_DISTRIBUTION, {}).items(): + res_bandwidth[TableConstant.PACKAGE_SIZE] = key + res_bandwidth[TableConstant.COUNT] = value[0] + res_bandwidth[TableConstant.TOTAL_DURATION] 
= value[1] + temp_dict = copy.deepcopy(res_bandwidth) + res_bd_data.append(temp_dict) + + def set_only_value(rank_id): + res_dict = dict() + res_dict[TableConstant.RANK_SET] = str(rank_set) + res_dict[TableConstant.STEP] = step + res_dict[TableConstant.RANK_ID] = rank_id + res_dict[TableConstant.HCCL_OP_NAME] = op_name.split("@")[0] if "@" in op_name else op_name + res_dict[TableConstant.GROUP_NAME] = op_name.split("@")[1] if "@" in op_name else "" + return res_dict + + for rank_set, step_dict in res_data.items(): + for step, op_dict in step_dict.items(): + for op_name, op_data in op_dict.items(): + split_comm_time() + return res_comm_data, res_bd_data + + def set_value_by_key(self, src_dict, dst_dict, key_dict): + for key, value in key_dict.items(): + src_dict[key] = dst_dict.get(value, 0) + + def transfer_matrix_from_db_to_json(self, matrix_data: list): + result = {} + if not matrix_data: + return result + hccl_set = set() + for data in matrix_data: + hccl = data[TableConstant.HCCL_OP_NAME] + "@" + data[TableConstant.GROUP_NAME] + hccl_set.add(hccl) + for hccl in hccl_set: + for data in matrix_data: + if hccl == (data[TableConstant.HCCL_OP_NAME] + "@" + data[TableConstant.GROUP_NAME]): + key = data[TableConstant.SRC_RANK] + '-' + data[TableConstant.DST_RANK] + temp_dict = dict() + key_dict = dict(zip(self.MATRIX_JSON_COLUMN, self.MATRIX_TABLE_COLUMN)) + self.set_value_by_key(temp_dict, data, key_dict) + result.setdefault(data[TableConstant.STEP], {}).setdefault(data[TableConstant.TYPE], {}). 
\ + setdefault(hccl, {}).setdefault(key, temp_dict) + return result + + def transfer_matrix_from_json_to_db(self, res_data: dict): + result = list() + + def split_matrix_data(): + for op_name, op_data in op_dict.items(): + for link_key, link_data in op_data.items(): + if "@" in op_name: + hccl_op_name, group_name = op_name.split("@")[0], op_name.split("@")[1] + else: + hccl_op_name, group_name = op_name, "" + matrix_data = { + TableConstant.RANK_SET: str(rank_set), + TableConstant.STEP: step, + TableConstant.HCCL_OP_NAME: hccl_op_name, + TableConstant.GROUP_NAME: group_name, + TableConstant.SRC_RANK: link_key.split("-")[0], + TableConstant.DST_RANK: link_key.split("-")[1] + } + key_dict = dict(zip(self.MATRIX_TABLE_COLUMN, self.MATRIX_JSON_COLUMN)) + self.set_value_by_key(matrix_data, link_data, key_dict) + result.append(matrix_data) + + for rank_set, step_dict in res_data.items(): + for step, op_dict in step_dict.items(): + split_matrix_data() + return result diff --git a/profiler/compare_tools/README.md b/profiler/compare_tools/README.md index 17d26d07e2074b4b50ddb2c27371770bc92da144..0283be8f27ea6bb726d2cb7b343229813abf7989 100644 --- a/profiler/compare_tools/README.md +++ b/profiler/compare_tools/README.md @@ -32,11 +32,11 @@ pip3 install numpy 采集样例代码参考一: ```Python -with torch.profiler.profile( +with torch_npu.profiler.profile( profile_memory=True, # 内存数据采集的开关 record_shapes=True, # 算子input shape信息采集的开关 - schedule=torch.profiler.schedule(wait=10, warmup=0, active=1, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler("./result_dir") + schedule=torch_npu.profiler.schedule(wait=10, warmup=0, active=1, repeat=1), + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result_dir") ) as prof: for step in ranges(step_number): train_one_step() @@ -46,10 +46,10 @@ with torch.profiler.profile( 采集样例代码参考二: ```Python -prof = torch.profiler.profile( +prof = torch_npu.profiler.profile( profile_memory=True, # 内存数据采集的开关 record_shapes=True, # 算子input 
shape信息采集的开关 - on_trace_ready=torch.profiler.tensorboard_trace_handler("./result_dir")) + on_trace_ready=torch_npu.profiler.tensorboard_trace_handler("./result_dir")) for step in range(step_number): if step == 11: prof.start() @@ -97,7 +97,7 @@ python performance_compare.py [基准性能数据文件] [比对性能数据文 - 比对性能数据文件(必选):可以指定以“ascend_pt”结尾的目录、ASCEND_PROFILER_OUTPUT目录或trace_view.json文件,指定trace_view.json无法显示算子的内存占用。 - --output_path(可选):性能比对结果存放的路径,默认保存在当前目录。 -工具将总体性能拆解为训练耗时和内存占用,其中训练耗时可拆分为算子、通信、调度三个维度,以打屏的形式输出总体指标,帮助用户定界劣化的方向。与此同时,工具还会生成performance_comparison_result_*.xlsl,展示每个算子在执行耗时、通信耗时、内存占用的优劣,可通过DIFF列大于0筛选出劣化算子。详细介绍请参见“**比对结果说明**”。 +工具将总体性能拆解为训练耗时和内存占用,其中训练耗时可拆分为算子(包括算子和nn.Module)、通信、调度三个维度,以打屏的形式输出总体指标,帮助用户定界劣化的方向。与此同时,工具还会生成performance_comparison_result_*.xlsx,展示每个算子在执行耗时、通信耗时、内存占用的优劣,可通过DIFF列大于0筛选出劣化算子。详细介绍请参见“**比对结果说明**”。 #### 通用参数说明 @@ -120,10 +120,10 @@ python performance_compare.py [基准性能数据文件] [比对性能数据文 | 参数名 | 说明 | 是否必选 | | ----------------- | ------------------------------------------------------------ | -------- | -| --gpu_flow_cat | 配置GPU trace中cpu侧算子与device kernel的连线标识,当GPU的kernel均为空时设置。根据timeline的json文件在chrome://tracing上的Flow events的选项配置。使用示例:--gpu_flow_cat=async_gpu | 否 | +| --gpu_flow_cat | 配置GPU trace中CPU侧算子与device kernel的连线标识,当GPU的Device Duration(us)均为0时设置。使用chrome://tracing打开GPU的json,右上角Flow events找到连线标识,将标识配置进该参数。使用示例:--gpu_flow_cat=async_gpu | 否 | | --use_input_shape | 开启算子精准匹配,默认关闭。使用示例:--use_input_shape | 否 | -| --max_kernel_num | 设置CPU侧算子下发的最大kernel数量,当超过设定值时工具会自动往下找子算子,直至满足条件,默认仅比对最上层算子。使用示例:--max_kernel_num=10 | 否 | -| --op_name_map | 设置GPU与NPU等价的算子名称的映射关系,以字典形式存入。使用示例:--op_name_map='{"Optimizer.step#SGD.step":"Optimizer.step#NpuFusedSGD.step"}' | 否 | +| --max_kernel_num | 设置CPU侧算子下发的最大kernel数量,当超过设定值时工具会自动往下找子算子,直至满足条件。默认仅比对最上层算子,粒度较粗;若想要更细粒度的算子比对,可设置该参数,参数值不得小于4,参数值设置越小,比对粒度越细。使用示例:--max_kernel_num=10 | 否 | +| --op_name_map | 设置GPU与NPU等价的算子名称的映射关系,以字典形式存入。使用示例:--op_name_map={'Optimizer.step#SGD.step':'Optimizer.step#NpuFusedSGD.step'} | 否 | ## 比对结果说明 
@@ -131,19 +131,20 @@ python performance_compare.py [基准性能数据文件] [比对性能数据文 总体性能比对结果以打屏的形式呈现。 -| 字段 | 说明 | -| ------------------------------- | ------------------------------------------------------------ | -| Cube Time(Num) | Cube算子总耗时,Num表示计算的次数。 | -| Vector Time(Num) | Vector算子总耗时,Num表示计算的次数。 | -| Other Time | AI CPU、DSA等其他非cube vector算子耗时。 | -| Flash Attention Time(Forward) | Flash Attention算子前向耗时。 | -| Flash Attention Time(Backward) | Flash Attention算子反向耗时。 | -| Computing Time | 计算流耗时,计算流所有event耗时总和。如果有多条并发计算,计算流耗时对重叠部分只会计算一次。 | -| Mem Usage | 内存使用。gpu上的内存使用可以使用nvidia-smi查看,npu上的内存使用可以使用npu-smi查看,Profiling信息采集时打开profile_memory=True开关,mem usage显示的是memory_record里面的最大resevered值,一般来说是进程级内存。 | -| Uncovered Communication Time | 通信未掩盖耗时。 | -| SDMA Time(Num) | 拷贝类任务耗时,Num表示计算的次数。 | -| Free Time | 调度耗时 = E2E耗时 - 算子耗时 - 通信不可掩盖耗时。Free的定义为Device侧既不在通信又不在计算的时间,因此包含拷贝时间(SDMA Time)。 | -| E2E Time(Not minimal profiling) | E2E总耗时,计算流端到端耗时。当存在Not minimal profiling时,表示该时间存在性能膨胀,会影响通信和调度耗时。 | +| 字段 | 说明 | +| --------------------------------------- | ------------------------------------------------------------ | +| Cube Time(Num) | Cube算子总耗时,Num表示计算的次数。 | +| Vector Time(Num) | Vector算子总耗时,Num表示计算的次数。 | +| Other Time | AI CPU、DSA等其他非cube vector算子耗时。 | +| Flash Attention Time(Forward) | Flash Attention算子前向耗时。 | +| Flash Attention Time(Backward) | Flash Attention算子反向耗时。 | +| Computing Time | 计算流耗时,计算流所有event耗时总和。如果有多条并发计算,计算流耗时对重叠部分只会计算一次。 | +| Mem Usage | 内存使用。GPU上的内存使用可以使用nvidia-smi查看,NPU上的内存使用可以使用npu-smi查看,Profiling信息采集时打开profile_memory=True开关,mem usage显示的是memory_record里面的最大resevered值,一般来说是进程级内存。 | +| Uncovered Communication Time(Wait Time) | 通信未掩盖耗时,包含Wait Time(只有采集性能数据的Level等级为L1以上并且采集NPU数据时才会存在)为同步时间。 | +| SDMA Time(Num) | 拷贝类任务耗时,Num表示计算的次数。 | +| Free Time | 调度耗时 = E2E耗时 - 算子耗时 - 通信不可掩盖耗时。Free的定义为Device侧既不在通信又不在计算的时间,因此包含拷贝时间(SDMA Time)。 | +| E2E Time(Not minimal profiling) | E2E总耗时,计算流端到端耗时。当存在Not minimal profiling时,表示该时间存在性能膨胀,会影响通信和调度耗时。 | +| Other Time | AI 
CPU、DSA、TensorMove等其他算子耗时。 | 可以采取最简性能数据采集的方式来减少E2E耗时的性能膨胀,示例代码如下: @@ -160,40 +161,63 @@ with torch_npu.profiler.profile( activities配置仅采集NPU数据,不配置experimental_config参数以及其他可选开关。 +- 当Computing Time耗时增大,分析**算子性能**。 +- 当Uncovered Communication Time耗时增大,分析**通信性能**,若通信性能分析没有劣化的通信算子,代表通信与计算的并行度较差,继续进行NPU的集群性能分析。 +- 当Mem Usage增大,分析**算子内存**,若没有明显占用较大的算子,则代表算子内存申请量没有差异,问题在于内存的释放(持有时间过久),可以使用tensorboard或ascend insight继续进行NPU内存的分析。 + ### 算子性能 -算子性能比对结果在performance_comparison_result_*.xlsl中OperatorCompare和OperatorCompare(TOP)的sheet页呈现。 +算子性能比对结果在performance_comparison_result_*.xlsx中OperatorCompare和OperatorCompareStatistic的sheet页呈现。 -- OperatorCompare(TOP):算子为粒度的统计呈现,按照算子在device上的总耗时与基准算子的差距值(Diff Duration(ms)列)进行逆序。 +- OperatorCompareStatistic:算子为粒度的统计呈现,按照算子在device上的总耗时与基准算子的差距值(Diff Duration(ms)列)进行逆序。 - OperatorCompare:算子比对的明细展示,可以查看每一个算子对应的kernel详情。 - Diff Ratio:比较算子在device上执行总耗时 / 基准算子在device上执行总耗时,红色代表劣化。 +- Device Duration(us):该算子下发到device上执行的所有kernel耗时的总和。 -#### Device Duration(us) +步骤1:查看OperatorCompareStatistic页,找出耗时差距TOP的算子。 +步骤2:查看OperatorCompare页,搜索耗时差距TOP的算子,查看具体执行的kernel耗时,寻找可优化点。 -``` -该算子下发到device上执行的所有kernel耗时的总和 -``` +### nn.Module性能 + +nn.Module是所有神经网络模块的基类,使用PyTorch构建神经网络需要继承nn.Module类来实现,性能比对工具支持模块级的比对(包含优化器和nn.Module),帮助优化模型结构。 + +当用户采集时开启with_stack开关,会上报python function事件,当比对的双方数据都存在python function的事件时,可进行模块级别的比对。 + +nn.Module性能比对结果在performance_comparison_result_*.xlsx中ModuleCompareStatistic的sheet页呈现。 + +- Module Class:Module名,如nn.Module: Linear。 +- Module Level:Module的层级。 +- Module Name:Module唯一标识名,如/ DynamicNet_0/ Linear_0。 +- Operator Name:框架侧算子名,如aten::add。字段为[ TOTAL ]代表该module的总体情况。 +- Kernel Detail:算子详细信息。 +- Device Self Time(ms):该模块调用的算子(排除子模块)在device侧执行的总耗时,单位ms。 +- Number:该Module或算子被调用的次数。 +- Device Total Time(ms):该模块调用的算子(包含子模块)在device侧执行的总耗时,单位ms。 +- Device Total Time Diff(ms):GPU与NPU的Device Total Time(ms)差值。 +- Device Self Time Diff(ms):GPU与NPU的Device Self Time(ms)差值。 +- Total Time Ratio:GPU与NPU的Device Total Time(ms)比值。 +- Base Call 
Stack:基准文件模块的调用栈。 +- Comparison Call Stack:比较文件模块的调用栈。 ### 通信性能 -通信性能比对结果在performance_comparison_result_*.xlsl中CommunicationCompare的sheet页呈现。 +通信性能比对结果在performance_comparison_result_*.xlsx中CommunicationCompare的sheet页呈现。 -- 淡蓝色背景的记录行:通信算子的summary信息,包括通信算子名称、调用总次数、通信算子总耗时(单位:us)、通信算子平均耗时(单位:us)、通信算子最大耗时(单位:us)、通信算子最小耗时(单位:us)。 +- 第二行表头:通信算子的summary信息,包括通信算子名称、调用总次数、通信算子总耗时(单位:us)、通信算子平均耗时(单位:us)、通信算子最大耗时(单位:us)、通信算子最小耗时(单位:us)。 - 无背景色的记录行:通信算子的detail信息,仅支持NPU,包含了该通信算子下的所有Task信息,包括Task名称、Task调用次数、Task总耗时(单位:us)、Task平均耗时(单位:us)、Task最大耗时(单位:us)、Task最小耗时(单位:us)。 - Diff Ratio: 比较通信算子的总耗时 / 基准通信算子的总耗时,红色代表劣化。 ### 算子内存 -算子内存比对结果在performance_comparison_result_*.xlsl中MemoryCompare和MemoryCompare(TOP)的sheet页呈现。 +算子内存比对结果在performance_comparison_result_*.xlsx中MemoryCompare和MemoryCompareStatistic的sheet页呈现。 -- MemoryCompare(TOP):算子为粒度的统计呈现,按照算子占用的总内存与基准算子的差距值(Diff Memory(MB))进行逆序。 +- MemoryCompareStatistic:算子为粒度的统计呈现,按照算子占用的总内存与基准算子的差距值(Diff Memory(MB))进行逆序。 - MemoryCompare:算子内存比对的明细展示,可以查看每一个算子申请内存的详情。 - Diff Ratio: 比较算子占用的总内存 / 基准算子占用的总内存,红色代表劣化。 -#### Size(KB) +- Size(KB):该算子占用的device内存大小,单位KB。 -``` -该算子占用的device内存大小,单位KB -``` \ No newline at end of file +步骤1:查看MemoryCompareStatistic页,找出内存占用差距TOP的算子。 +步骤2:查看MemoryCompare页,搜索内存占用差距TOP的算子,查看具体占用的子算子。 diff --git a/profiler/compare_tools/compare_backend/comparator/module_comparetor.py b/profiler/compare_tools/compare_backend/comparator/module_comparetor.py new file mode 100644 index 0000000000000000000000000000000000000000..49c50b53c5a1b00bd17b7281d80b61d5011cb59a --- /dev/null +++ b/profiler/compare_tools/compare_backend/comparator/module_comparetor.py @@ -0,0 +1,36 @@ +from compare_backend.comparator.base_comparator import BaseComparator +from compare_backend.utils.common_func import update_order_id +from compare_backend.utils.constant import Constant + + +class ModuleComparator(BaseComparator): + def __init__(self, origin_data: any, bean: any): + super().__init__(origin_data, bean) + + def _compare(self): + if not 
self._origin_data: + return + base_all_data = [data for data in self._origin_data if data[0]] # index 0 for base module + base_all_data.sort(key=lambda x: x[0].start_time) + base_none_data = [data for data in self._origin_data if not data[0]] # index 0 for base module + base_none_data.sort(key=lambda x: x[1].start_time) + index = 0 + for base_module, comparison_module in base_all_data: + if not comparison_module: + self._rows.extend(self._bean(base_module, comparison_module).rows) + continue + while index < len(base_none_data): + module = base_none_data[index][1] # index 1 for comparison module + if module.start_time < comparison_module.start_time: + self._rows.extend(self._bean(None, module).rows) + index += 1 + else: + break + self._rows.extend(self._bean(base_module, comparison_module).rows) + while index < len(base_none_data): + module = base_none_data[index][1] # index 1 for comparison module + self._rows.extend(self._bean(None, module).rows) + index += 1 + update_order_id(self._rows) + if not any(row[-1] != Constant.NA for row in self._rows): + print(f"[WARNING] If you want to see the operator's call stack, you must enable with_stack switch.") diff --git a/profiler/compare_tools/compare_backend/comparator/module_statistic_comparator.py b/profiler/compare_tools/compare_backend/comparator/module_statistic_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..e09108f3cbe3744068daf6c5316dc318aea53177 --- /dev/null +++ b/profiler/compare_tools/compare_backend/comparator/module_statistic_comparator.py @@ -0,0 +1,45 @@ +from collections import OrderedDict + +from compare_backend.comparator.base_comparator import BaseComparator +from compare_backend.utils.common_func import update_order_id + + +class ModuleStatisticComparator(BaseComparator): + def __init__(self, origin_data: list, bean: any): + super().__init__(origin_data, bean) + + def _compare(self): + if not self._origin_data: + return + base_module_dict, comparison_module_dict = 
self._group_by_module_name() + for module_name, base_data in base_module_dict.items(): + comparison_data = comparison_module_dict.pop(module_name, []) + self._rows.extend(self._bean(module_name, base_data, comparison_data).rows) + for module_name, comparison_data in comparison_module_dict.items(): + self._rows.extend(self._bean(module_name, [], comparison_data).rows) + update_order_id(self._rows) + + def _group_by_module_name(self): + base_module_dict, comparison_module_dict = OrderedDict(), OrderedDict() + base_all_data = [data for data in self._origin_data if data[0]] # index 0 for base module + base_all_data.sort(key=lambda x: x[0].start_time) + base_none_data = [data for data in self._origin_data if not data[0]] # index 0 for base module + base_none_data.sort(key=lambda x: x[1].start_time) + index = 0 + for base_module, comparison_module in base_all_data: + base_module_dict.setdefault(base_module.module_name, []).append(base_module) + if not comparison_module: + continue + while index < len(base_none_data): + module = base_none_data[index][1] # index 1 for comparison module + if module.start_time < comparison_module.start_time: + comparison_module_dict.setdefault(module.module_name, []).append(module) + index += 1 + else: + break + comparison_module_dict.setdefault(comparison_module.module_name, []).append(comparison_module) + while index < len(base_none_data): + module = base_none_data[index][1] # index 1 for comparison module + comparison_module_dict.setdefault(module.module_name, []).append(module) + index += 1 + return base_module_dict, comparison_module_dict diff --git a/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py b/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py index bfc631c66c86f061b10445e117e9f947d7ebdbc5..803f953630a12e9e2bca1780ee6ec4cd751d233f 100644 --- a/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py +++ 
b/profiler/compare_tools/compare_backend/comparator/overall_performance_comparator.py @@ -18,10 +18,14 @@ class OverallPerformanceComparator(BaseComparator): f'{base_profiling_info.vec_time:.3f}s({base_profiling_info.vec_num})']) comp_col.extend([f'{comp_profiling_info.cube_time:.3f}s({comp_profiling_info.cube_num})', f'{comp_profiling_info.vec_time:.3f}s({comp_profiling_info.vec_num})']) - if base_profiling_info.other_time or comp_profiling_info.other_time: - self._headers.append('Other Time') - base_col.append(f'{base_profiling_info.other_time:.3f}s') - comp_col.append(f'{comp_profiling_info.other_time:.3f}s') + if base_profiling_info.conv_time_fwd or comp_profiling_info.conv_time_fwd: + self._headers.append('Conv Time(Forward)(Num)') + base_col.append(f'{base_profiling_info.conv_time_fwd:.3f}s({base_profiling_info.conv_num_fwd})') + comp_col.append(f'{comp_profiling_info.conv_time_fwd:.3f}s({comp_profiling_info.conv_num_fwd})') + if base_profiling_info.conv_time_bwd or comp_profiling_info.conv_time_bwd: + self._headers.append('Conv Time(Backward)(Num)') + base_col.append(f'{base_profiling_info.conv_time_bwd:.3f}s({base_profiling_info.conv_num_bwd})') + comp_col.append(f'{comp_profiling_info.conv_time_bwd:.3f}s({comp_profiling_info.conv_num_bwd})') if base_profiling_info.fa_time_fwd or comp_profiling_info.fa_time_fwd: self._headers.append('Flash Attention Time(Forward)(Num)') base_col.append(f'{base_profiling_info.fa_time_fwd:.3f}s({base_profiling_info.fa_num_fwd})') @@ -30,6 +34,10 @@ class OverallPerformanceComparator(BaseComparator): self._headers.append('Flash Attention Time(Backward)(Num)') base_col.append(f'{base_profiling_info.fa_time_bwd:.3f}s({base_profiling_info.fa_num_bwd})') comp_col.append(f'{comp_profiling_info.fa_time_bwd:.3f}s({comp_profiling_info.fa_num_bwd})') + if base_profiling_info.other_time or comp_profiling_info.other_time: + self._headers.append('Other Time') + base_col.append(f'{base_profiling_info.other_time:.3f}s') + 
comp_col.append(f'{comp_profiling_info.other_time:.3f}s') self._headers.extend(['Computing Time']) base_col.extend([f'{base_profiling_info.compute_time:.3f}s']) comp_col.extend([f'{comp_profiling_info.compute_time:.3f}s']) @@ -37,9 +45,17 @@ class OverallPerformanceComparator(BaseComparator): self._headers.append('Mem Usage') base_col.append(f'{base_profiling_info.memory_used:.2f}G') comp_col.append(f'{comp_profiling_info.memory_used:.2f}G') - self._headers.extend(['Uncovered Communication Time']) - base_col.extend([f'{base_profiling_info.communication_not_overlapped: .3f}s']) - comp_col.extend([f'{comp_profiling_info.communication_not_overlapped: .3f}s']) + self._headers.extend(['Uncovered Communication Time(Wait Time)']) + if base_profiling_info.wait_time: + base_col.extend( + [f'{base_profiling_info.communication_not_overlapped: .3f}s({base_profiling_info.wait_time:.3f}s)']) + else: + base_col.extend([f'{base_profiling_info.communication_not_overlapped: .3f}s( / )']) + if comp_profiling_info.is_level0: + comp_col.extend([f'{comp_profiling_info.communication_not_overlapped: .3f}s( / )']) + else: + comp_col.extend( + [f'{comp_profiling_info.communication_not_overlapped: .3f}s({comp_profiling_info.wait_time:.3f}s)']) if base_profiling_info.sdma_time or comp_profiling_info.sdma_time: self._headers.append('SDMA Time(Num)') base_col.append(f'{base_profiling_info.sdma_time:.3f}s({base_profiling_info.sdma_num})') diff --git a/profiler/compare_tools/compare_backend/compare_bean/module_compare_bean.py b/profiler/compare_tools/compare_backend/compare_bean/module_compare_bean.py new file mode 100644 index 0000000000000000000000000000000000000000..abfce00d83d6c1a914aa71481277e2dc1c195f17 --- /dev/null +++ b/profiler/compare_tools/compare_backend/compare_bean/module_compare_bean.py @@ -0,0 +1,83 @@ +from compare_backend.utils.common_func import longest_common_subsequence_matching, calculate_diff_ratio +from compare_backend.utils.constant import Constant +from 
compare_backend.utils.excel_config import ExcelConfig +from compare_backend.utils.module_node import ModuleNode +from compare_backend.utils.name_function import NameFunction +from compare_backend.utils.torch_op_node import TorchOpNode + + +class ModuleCompareBean: + TABLE_NAME = Constant.MODULE_TABLE + HEADERS = ExcelConfig.HEADERS.get(TABLE_NAME) + OVERHEAD = ExcelConfig.OVERHEAD.get(TABLE_NAME) + + def __init__(self, base_module: ModuleNode, comparison_module: ModuleNode): + self._base_module = ModuleInfo(base_module) + self._comparison_module = ModuleInfo(comparison_module) + self.module_class = self._base_module.module_class if base_module else self._comparison_module.module_class + self.module_level = self._base_module.module_level if base_module else self._comparison_module.module_level + self.module_name = self._base_module.module_name if base_module else self._comparison_module.module_name + + @property + def rows(self): + return [self.get_total_row(), *self.get_detail_rows()] + + def get_total_row(self): + total_diff, total_ratio = calculate_diff_ratio(self._base_module.device_total_time, + self._comparison_module.device_total_time) + self_diff, _ = calculate_diff_ratio(self._base_module.device_self_time, + self._comparison_module.device_self_time) + return [None, self.module_class, self.module_level, self.module_name, "TOTAL", None, + self._base_module.device_self_time, self._base_module.device_total_time, "TOTAL", None, + self._comparison_module.device_self_time, self._comparison_module.device_total_time, total_diff, + self_diff, total_ratio, self._base_module.call_stack, self._comparison_module.call_stack] + + def get_detail_rows(self): + rows = [] + matched_ops = longest_common_subsequence_matching(self._base_module.top_layer_ops, + self._comparison_module.top_layer_ops, NameFunction.get_name) + for base_op, comparison_op in matched_ops: + base_op = OpInfo(base_op) + comparison_op = OpInfo(comparison_op) + self_diff, self_ratio = 
calculate_diff_ratio(base_op.device_self_time, comparison_op.device_self_time) + base_call_stack = base_op.call_stack if self_diff > 0 else None + comparison_call_stack = comparison_op.call_stack if self_diff > 0 else None + rows.append( + [None, self.module_class, self.module_level, self.module_name, base_op.operator_name, + base_op.kernel_details, base_op.device_self_time, None, comparison_op.operator_name, + comparison_op.kernel_details, comparison_op.device_self_time, None, None, self_diff, self_ratio, + base_call_stack, comparison_call_stack]) + return rows + + +class ModuleInfo: + def __init__(self, module: ModuleNode): + self.module_class = "" + self.module_level = "" + self.module_name = "" + self.device_self_time = 0 + self.device_total_time = 0 + self.top_layer_ops = [] + self.call_stack = "" + if module: + self.module_class = module.module_class + self.module_level = module.module_level + self.module_name = module.module_name.replace("nn.Module:", "") + self.device_self_time = module.device_self_dur + self.device_total_time = module.device_total_dur + self.top_layer_ops = module.toy_layer_api_list + self.call_stack = module.call_stack + + +class OpInfo: + def __init__(self, operator: TorchOpNode): + self.operator_name = "" + self.kernel_details = "" + self.device_self_time = 0 + self.call_stack = "" + if operator: + self.operator_name = operator.name + for kernel in operator.kernel_list: + self.device_self_time += kernel.device_dur + self.kernel_details += kernel.kernel_details + self.call_stack = operator.call_stack diff --git a/profiler/compare_tools/compare_backend/compare_bean/module_statistic_bean.py b/profiler/compare_tools/compare_backend/compare_bean/module_statistic_bean.py new file mode 100644 index 0000000000000000000000000000000000000000..97fc98bdd354e1ebe1fbb3fc44def4eaf3059235 --- /dev/null +++ b/profiler/compare_tools/compare_backend/compare_bean/module_statistic_bean.py @@ -0,0 +1,98 @@ +import re + +from compare_backend.utils.common_func 
import calculate_diff_ratio +from compare_backend.utils.constant import Constant +from compare_backend.utils.excel_config import ExcelConfig + + +class ModuleStatisticBean: + TABLE_NAME = Constant.MODULE_TOP_TABLE + HEADERS = ExcelConfig.HEADERS.get(TABLE_NAME) + OVERHEAD = ExcelConfig.OVERHEAD.get(TABLE_NAME) + + def __init__(self, name: str, base_data: list, comparison_data: list): + self._module_name = name.replace("nn.Module:", "") + pattern = re.compile('_[0-9]+$') + self._module_class = pattern.sub('', name.split("/")[-1]) + self._module_level = name.count("/") + self._base_info = ModuleStatisticInfo(base_data) + self._comparison_info = ModuleStatisticInfo(comparison_data) + + @property + def rows(self): + rows = [self.get_total_row()] + rows.extend(self.get_detail_rows()) + return rows + + @staticmethod + def _get_kernel_detail_rows(base_kernel_dict, com_kernel_dict): + base_kernel_detals = "" + com_kernel_details = "" + for kernel_name, base_dur_list in base_kernel_dict.items(): + base_dur = "%.3f" % sum(base_dur_list) + base_kernel_detals += f"{kernel_name}, [number: {len(base_dur_list)}], [duration_ms: {base_dur}]\n" + for kernel_name, com_dur_list in com_kernel_dict.items(): + com_dur = "%.3f" % sum(com_dur_list) + com_kernel_details += f"{kernel_name}, [number: {len(com_dur_list)}], [duration_ms: {com_dur}]\n" + return [base_kernel_detals, com_kernel_details] + + def get_total_row(self): + total_diff, total_ratio = calculate_diff_ratio(self._base_info.device_total_dur_ms, + self._comparison_info.device_total_dur_ms) + self_diff, _ = calculate_diff_ratio(self._base_info.device_self_dur_ms, + self._comparison_info.device_self_dur_ms) + row = [None, self._module_class, self._module_level, self._module_name, "[ TOTAL ]", None, + self._base_info.device_self_dur_ms, self._base_info.number, self._base_info.device_total_dur_ms, + None, self._comparison_info.device_self_dur_ms, self._comparison_info.number, + self._comparison_info.device_total_dur_ms, 
total_diff, self_diff, + total_ratio, self._base_info.call_stack, self._comparison_info.call_stack] + return row + + def get_detail_rows(self): + rows = [] + for op_name, base_dur_dict in self._base_info.api_dict.items(): + base_dur_list = base_dur_dict.get("total", []) + com_dur_dict = self._comparison_info.api_dict.pop(op_name, {}) + com_dur_list = com_dur_dict.get("total", []) + base_kernel_detals, com_kernel_details = self._get_kernel_detail_rows(base_dur_dict.get("detail", {}), + com_dur_dict.get("detail", {})) + self_diff, self_ratio = calculate_diff_ratio(sum(base_dur_list), sum(com_dur_list)) + row = [None, self._module_class, self._module_level, self._module_name, op_name, base_kernel_detals, + sum(base_dur_list), len(base_dur_list), None, com_kernel_details, sum(com_dur_list), + len(com_dur_list), None, None, self_diff, self_ratio, None, None] + rows.append(row) + + for op_name, com_dur_dict in self._comparison_info.api_dict.items(): + com_dur_list = com_dur_dict.get("total", []) + base_kernel_detals, com_kernel_details = self._get_kernel_detail_rows({}, com_dur_dict.get("detail", {})) + self_diff, self_ratio = calculate_diff_ratio(0, sum(com_dur_list)) + row = [None, self._module_class, self._module_level, self._module_name, op_name, base_kernel_detals, 0, 0, + None, com_kernel_details, sum(com_dur_list), len(com_dur_list), None, None, self_diff, + self_ratio, None, None] + rows.append(row) + return rows + + +class ModuleStatisticInfo: + def __init__(self, data_list: list): + self._data_list = data_list + self.device_self_dur_ms = 0 + self.device_total_dur_ms = 0 + self.call_stack = "" + self.number = len(data_list) + self.api_dict = {} + self._get_info() + + def _get_info(self): + if self._data_list: + self.call_stack = self._data_list[0].call_stack + for module in self._data_list: + self.device_self_dur_ms += module.device_self_dur / Constant.US_TO_MS + self.device_total_dur_ms += module.device_total_dur / Constant.US_TO_MS + for torch_op in 
module.toy_layer_api_list: + self.api_dict.setdefault(torch_op.name, {}).setdefault("total", []).append( + torch_op.device_dur / Constant.US_TO_MS) + for kernel in torch_op.kernel_list: + self.api_dict.setdefault(torch_op.name, {}).setdefault("detail", {}).setdefault(kernel.kernel_name, + []).append( + kernel.device_dur / Constant.US_TO_MS) diff --git a/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py b/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py index ef5e59c555507a9e97d6a2b0c7824110c4b3fce7..84b8eae7c64b3f9d2ca261ded52052290fa50fa7 100644 --- a/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py +++ b/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/kernel_details_bean.py @@ -3,6 +3,7 @@ import math import pandas as pd from compare_backend.utils.common_func import convert_to_float +from compare_backend.utils.constant import Constant class KernelDetailsBean: @@ -68,6 +69,13 @@ class KernelDetailsBean: def is_cube(self): return "matmul" in self.op_type.lower() + def is_conv(self): + return self.op_type.lower().startswith("conv") + + def is_conv_bwd(self): + lower_op_type = self.op_type.lower() + return any(bwd in lower_op_type for bwd in Constant.BWD_LIST) + def init(self): self._op_type = self._data.get('Type', "") self._name = self._data.get('Name', "") diff --git a/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py b/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py index 6ce91ba53c8f2a9286319f35f76b62773743bc49..7f51a2d80ae5a3e0d8644922eef8fa45a45c3d73 100644 --- a/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py +++ b/profiler/compare_tools/compare_backend/compare_bean/origin_data_bean/trace_event_bean.py @@ -181,11 +181,25 @@ class TraceEventBean: return self.task_type == 'EVENT_WAIT_SQE' def 
is_backward(self): - bwd_list = ["bwd", "backward"] - for bwd in bwd_list: - if bwd in self.lower_name: - return True - return False + return any(bwd in self.lower_name for bwd in Constant.BWD_LIST) + + def is_python_function(self): + return self.lower_cat == "python_function" + + def is_optimizer(self): + return self.lower_name.startswith("optimizer") + + def is_fwdbwd(self): + return self.lower_cat == "fwdbwd" + + def is_step_profiler(self): + return self.name.find("ProfilerStep#") != -1 + + def reset_name(self, name): + self._name = name + + def is_conv(self): + return self.name.lower().startswith("aten::conv") def init(self): if isinstance(self._event, dict): diff --git a/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py b/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py index 9184c790b7ea59246b602442a13e7e533d921bc8..44b277141b7e2ae24dfb68569445aaf92d86be34 100644 --- a/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py +++ b/profiler/compare_tools/compare_backend/compare_bean/profiling_info.py @@ -16,8 +16,13 @@ class ProfilingInfo: self.sdma_num = 0 self.fa_num_fwd = 0 self.fa_num_bwd = 0 + self.conv_time_fwd = 0.0 + self.conv_time_bwd = 0.0 + self.conv_num_fwd = 0 + self.conv_num_bwd = 0 self.compute_time = 0.0 self.communication_not_overlapped = 0.0 + self.wait_time = 0.0 self.memory_used = 0.0 self.e2e_time = 0.0 self.sdma_time = 0.0 @@ -26,6 +31,7 @@ class ProfilingInfo: self.fa_time_fwd = 0.0 self.minimal_profiling = False self.hide_op_details = False + self.is_level0 = False def trans_time_to_s(self): self.cube_time = self.cube_time / 10 ** 6 @@ -33,18 +39,23 @@ class ProfilingInfo: self.vec_time = self.vec_time / 10 ** 6 self.compute_time = self.compute_time / 10 ** 6 self.communication_not_overlapped = self.communication_not_overlapped / 10 ** 6 + self.wait_time = self.wait_time / 10 ** 6 self.e2e_time = self.e2e_time / 10 ** 6 self.sdma_time = self.sdma_time / 10 ** 6 self.scheduling_time = 
self.scheduling_time / 10 ** 6 self.fa_time_bwd = self.fa_time_bwd / 10 ** 6 self.fa_time_fwd = self.fa_time_fwd / 10 ** 6 + self.conv_time_fwd = self.conv_time_fwd / 10 ** 6 + self.conv_time_bwd = self.conv_time_bwd / 10 ** 6 def calculate_other_time(self): self.other_time = max( - [0, self.compute_time - self.cube_time - self.fa_time_fwd - self.fa_time_bwd - self.vec_time]) + [0, self.compute_time - self.cube_time - self.fa_time_fwd - self.fa_time_bwd - + self.vec_time - self.conv_time_fwd - self.conv_time_bwd]) def calculate_vec_time(self): - self.vec_time = self.compute_time - self.cube_time - self.fa_time_fwd - self.fa_time_bwd + self.vec_time = self.compute_time - self.cube_time - self.fa_time_fwd - self.fa_time_bwd \ + - self.conv_time_fwd - self.conv_time_bwd def calculate_schedule_time(self): self.scheduling_time = self.e2e_time - self.compute_time - self.communication_not_overlapped @@ -57,6 +68,14 @@ class ProfilingInfo: self.fa_time_bwd += time self.fa_num_bwd += 1 + def update_conv_fwd_info(self, time: float): + self.conv_time_fwd += time + self.conv_num_fwd += 1 + + def update_conv_bwd_info(self, time: float): + self.conv_time_bwd += time + self.conv_num_bwd += 1 + def update_sdma_info(self, time: float, num: int = 1): self.sdma_time += time self.sdma_num += num @@ -84,6 +103,9 @@ class ProfilingInfo: def update_comm_not_overlap(self, time: float): self.communication_not_overlapped += time + def update_comm_not_overlap_wait_time(self, time: float): + self.wait_time = time + def set_memory_used(self, memory: float): self.memory_used = memory diff --git a/profiler/compare_tools/compare_backend/data_prepare/__init__.py b/profiler/compare_tools/compare_backend/data_prepare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/compare_tools/compare_backend/data_prepare/module_data_prepare.py b/profiler/compare_tools/compare_backend/data_prepare/module_data_prepare.py 
new file mode 100644 index 0000000000000000000000000000000000000000..6d45b98dd700117d01d8f55a6a8de66983f25f8a --- /dev/null +++ b/profiler/compare_tools/compare_backend/data_prepare/module_data_prepare.py @@ -0,0 +1,97 @@ +from queue import Queue + +from compare_backend.compare_bean.origin_data_bean.trace_event_bean import TraceEventBean +from compare_backend.profiling_parser.base_profiling_parser import ProfilingResult +from compare_backend.utils.constant import Constant +from compare_backend.utils.module_node import ModuleNode +from compare_backend.utils.tree_builder import TreeBuilder + + +class ModuleDataPrepare: + def __init__(self, profiling_data: ProfilingResult): + self.profiling_data = profiling_data + self._nn_module_list = [] + self._call_function = [] + for event in profiling_data.python_function_data: + if event.lower_name.startswith("nn.module:"): + self._nn_module_list.append(event) + else: + self._call_function.append(event) + self._bwd_dict = {} + self._bwd_pid = self._get_bwd_pid() + + @staticmethod + def update_module_node_info(fwd_root_node, bwd_root_node, func_root_node): + queue = Queue() + queue.put(fwd_root_node) + queue.put(bwd_root_node) + while not queue.empty(): + module_node = queue.get() + module_node.update_torch_op_kernel_list() + call_function = func_root_node.find_module_call(module_node.start_time) + if call_function: + module_node.reset_call_stack(call_function.call_stack) + for sub_module_node in module_node.child_nodes: + queue.put(sub_module_node) + + def build_module_tree(self): + if not self._nn_module_list: + return [None, None] + self._dispatch_torch_op() + event_list = [TraceEventBean({"ts": ts}) for ts in self.profiling_data.kernel_dict.keys()] + self._nn_module_list.extend(event_list) + root_node = TreeBuilder.build_module_tree(self._nn_module_list, self.profiling_data.kernel_dict) + func_root_node = TreeBuilder.build_module_tree(self._call_function, {}) + bwd_module_list = self.get_bwd_module(root_node) + if 
bwd_module_list: + bwd_module_list.extend(event_list) + bwd_root_node = TreeBuilder.build_module_tree(bwd_module_list, self.profiling_data.kernel_dict) + self.match_torch_op(root_node, bwd_root_node) + self.update_module_node_info(root_node, bwd_root_node, func_root_node) + return [root_node, bwd_root_node] + + def get_bwd_module(self, root_node: ModuleNode): + bwd_module_list = [] + for flow in self.profiling_data.fwdbwd_dict.values(): + start_point = flow.get("start") + end_point = flow.get("end") + if not start_point or not end_point: + continue + end_event = self._bwd_dict.get(end_point.start_time) + if not end_event: + continue + call_module = root_node.find_module_call(start_point.start_time) + if call_module: + end_event.reset_name(f"[ BACKWARD ]{call_module.module_name}") + bwd_module_list.append(end_event) + return bwd_module_list + + def match_torch_op(self, fwd_root_node, bwd_root_node): + torch_op_list = sorted(self.profiling_data.torch_op_data, key=lambda x: x.start_time) + for torch_op in torch_op_list: + if torch_op.is_optimizer(): + continue + if torch_op.is_step_profiler(): + continue + matched_module = fwd_root_node.find_module_call(torch_op.start_time) + if matched_module: + matched_module.find_torch_op_call(torch_op) + continue + matched_module = bwd_root_node.find_module_call(torch_op.start_time) + if matched_module: + matched_module.find_torch_op_call(torch_op) + + def _dispatch_torch_op(self): + for torch_op in self.profiling_data.torch_op_data: + if torch_op.is_optimizer(): + self._nn_module_list.append(torch_op) + continue + if torch_op.pid == self._bwd_pid: + self._bwd_dict[torch_op.start_time] = torch_op + + def _get_bwd_pid(self): + for flow in self.profiling_data.fwdbwd_dict.values(): + end_point = flow.get("end") + if end_point: + return end_point.pid + return Constant.INVALID_VALUE diff --git a/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py 
b/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py new file mode 100644 index 0000000000000000000000000000000000000000..fdce23c6ab4ff7f9f6f7d6bc1442063c57cb6098 --- /dev/null +++ b/profiler/compare_tools/compare_backend/data_prepare/operator_data_prepare.py @@ -0,0 +1,19 @@ +from compare_backend.profiling_parser.base_profiling_parser import ProfilingResult +from compare_backend.utils.tree_builder import TreeBuilder + + +class OperatorDataPrepare: + def __init__(self, profiling_data: ProfilingResult): + self.profiling_data = profiling_data + + def get_top_layer_ops(self) -> any: + root_node = TreeBuilder.build_tree(self.profiling_data.torch_op_data, self.profiling_data.kernel_dict, + self.profiling_data.memory_list) + level1_child_nodes = root_node.child_nodes + result_data = [] + for level1_node in level1_child_nodes: + if level1_node.is_step_profiler(): + result_data.extend(level1_node.child_nodes) + else: + result_data.append(level1_node) + return result_data diff --git a/profiler/compare_tools/compare_backend/disaggregate/__init__.py b/profiler/compare_tools/compare_backend/disaggregate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py b/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..c89e84519302781a590523bc7fdaaf9e1254acf5 --- /dev/null +++ b/profiler/compare_tools/compare_backend/disaggregate/overall_perf_interface.py @@ -0,0 +1,34 @@ +from common_func.path_manager import PathManager +from compare_backend.profiling_parser.gpu_profiling_parser import GPUProfilingParser +from compare_backend.profiling_parser.npu_profiling_parser import NPUProfilingParser +from compare_backend.utils.args_manager import ArgsManager +from compare_backend.utils.compare_args import Args +from 
compare_backend.utils.constant import Constant + + +class OverallPerfInterface: + PARSER_DICT = {Constant.NPU: NPUProfilingParser, Constant.GPU: GPUProfilingParser} + + def __init__(self, profiling_path: str): + self._profiling_path = profiling_path + self._profiling_path_dict = {} + self._result_data = {} + + def run(self): + self._check_path() + self._load_data() + self._generate_result() + return self._result_data + + def _check_path(self): + profiling_path = PathManager.get_realpath(self._profiling_path) + self._profiling_path_dict = ArgsManager().parse_profiling_path(profiling_path) + + def _load_data(self): + args = Args(enable_profiling_compare=True) + profiling_type = self._profiling_path_dict.get(Constant.PROFILING_TYPE, Constant.NPU) + self._profiling_data = self.PARSER_DICT.get(profiling_type)(args, self._profiling_path_dict).load_data() + + def _generate_result(self): + overall_data = self._profiling_data.overall_metrics + self._result_data = getattr(overall_data, "__dict__", {}) diff --git a/profiler/compare_tools/compare_backend/generator/base_generator.py b/profiler/compare_tools/compare_backend/generator/base_generator.py index c472bc9922e6febf118f62a66424056243156c07..e77071b5998a9915d09c54f8b4c811d434555167 100644 --- a/profiler/compare_tools/compare_backend/generator/base_generator.py +++ b/profiler/compare_tools/compare_backend/generator/base_generator.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from collections import OrderedDict from multiprocessing import Process @@ -7,7 +8,7 @@ class BaseGenerator(Process, ABC): super(BaseGenerator, self).__init__() self._profiling_data_dict = profiling_data_dict self._args = args - self._result_data = {} + self._result_data = OrderedDict() def run(self): self.compare() diff --git a/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py b/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py index 
72ce3ba86893b08ffdd8deff5c586731db4b84f5..a7d95228d47023f78434529f79ece2c59edd543c 100644 --- a/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py +++ b/profiler/compare_tools/compare_backend/generator/detail_performance_generator.py @@ -1,23 +1,28 @@ import os from collections import deque from datetime import datetime - -import numpy as np +from queue import Queue from compare_backend.comparator.communication_comparator import CommunicationComparator +from compare_backend.comparator.module_comparetor import ModuleComparator +from compare_backend.comparator.module_statistic_comparator import ModuleStatisticComparator from compare_backend.comparator.operator_comparator import OperatorComparator from compare_backend.comparator.operator_statistic_comparator import OperatorStatisticComparator from compare_backend.compare_bean.communication_bean import CommunicationBean from compare_backend.compare_bean.memory_compare_bean import MemoryCompareBean from compare_backend.compare_bean.memory_statistic_bean import MemoryStatisticBean +from compare_backend.compare_bean.module_compare_bean import ModuleCompareBean +from compare_backend.compare_bean.module_statistic_bean import ModuleStatisticBean from compare_backend.compare_bean.operator_compare_bean import OperatorCompareBean from compare_backend.compare_bean.operator_statistic_bean import OperatorStatisticBean +from compare_backend.data_prepare.module_data_prepare import ModuleDataPrepare +from compare_backend.data_prepare.operator_data_prepare import OperatorDataPrepare from compare_backend.generator.base_generator import BaseGenerator -from compare_backend.profiling_parser.base_profiling_parser import ProfilingResult +from compare_backend.utils.common_func import longest_common_subsequence_matching from compare_backend.utils.constant import Constant +from compare_backend.utils.module_node import ModuleNode from compare_backend.utils.name_function import NameFunction from 
compare_backend.utils.torch_op_node import TorchOpNode -from compare_backend.utils.tree_builder import TreeBuilder from compare_backend.view.excel_view import ExcelView @@ -44,7 +49,14 @@ class DetailPerformanceGenerator(BaseGenerator): def _create_comparator(self): comparator_list = [] - if self._args.enable_operator_compare or self._args.enable_memory_compare: + + op_compare_result = [] + if self._args.enable_operator_compare: + module_compare_result = self.match_nn_module() + if not module_compare_result: + op_compare_result = self.match_torch_op() + + if self._args.enable_memory_compare and not op_compare_result: op_compare_result = self.match_torch_op() if self._args.enable_communication_compare: @@ -54,89 +66,28 @@ class DetailPerformanceGenerator(BaseGenerator): comparator_list.append(CommunicationComparator(communication_data, CommunicationBean)) if self._args.enable_operator_compare: - comparator_list.append(OperatorComparator(op_compare_result, OperatorCompareBean)) - comparator_list.append(OperatorStatisticComparator(op_compare_result, OperatorStatisticBean)) - + if module_compare_result: + comparator_list.append(ModuleStatisticComparator(module_compare_result, ModuleStatisticBean)) + comparator_list.append(ModuleComparator(module_compare_result, ModuleCompareBean)) + else: + comparator_list.append(OperatorStatisticComparator(op_compare_result, OperatorStatisticBean)) + comparator_list.append(OperatorComparator(op_compare_result, OperatorCompareBean)) if self._args.enable_memory_compare: - comparator_list.append(OperatorComparator(op_compare_result, MemoryCompareBean)) comparator_list.append(OperatorStatisticComparator(op_compare_result, MemoryStatisticBean)) + comparator_list.append(OperatorComparator(op_compare_result, MemoryCompareBean)) return comparator_list def match_torch_op(self) -> list: - base_ops = self._get_top_layer_ops(self._profiling_data_dict.get(Constant.BASE_DATA)) - comparison_ops = 
self._get_top_layer_ops(self._profiling_data_dict.get(Constant.COMPARISON_DATA)) + base_ops = OperatorDataPrepare(self._profiling_data_dict.get(Constant.BASE_DATA)).get_top_layer_ops() + comparison_ops = OperatorDataPrepare( + self._profiling_data_dict.get(Constant.COMPARISON_DATA)).get_top_layer_ops() if not base_ops and not comparison_ops: return [] name_func = NameFunction(self._args).get_name_func() - compare_result_data = self._matching_op(base_ops, comparison_ops, name_func) + op_compare_result = longest_common_subsequence_matching(base_ops, comparison_ops, name_func) if self._args.max_kernel_num is not None: - compare_result_data = self._drill_down(compare_result_data, name_func) - return compare_result_data - - @classmethod - def _matching_op(cls, base_ops: list, comparison_ops: list, name_func: any) -> list: - if not comparison_ops: - result_data = [None] * len(base_ops) - for index, value in enumerate(base_ops): - result_data[index] = [value, None] - return result_data - - result_data = [] - comparison_len, base_len = len(comparison_ops), len(base_ops) - dp = [[0] * (base_len + 1) for _ in range(comparison_len + 1)] - for comparison_index in range(1, comparison_len + 1): - for base_index in range(1, base_len + 1): - if name_func(base_ops[base_index - 1]) == name_func( - comparison_ops[comparison_index - 1]): - dp[comparison_index][base_index] = dp[comparison_index - 1][base_index - 1] + 1 - else: - dp[comparison_index][base_index] = max(dp[comparison_index][base_index - 1], - dp[comparison_index - 1][base_index]) - matched_op = [] - comparison_index, base_index = comparison_len, base_len - while comparison_index > 0 and base_index > 0: - if name_func(base_ops[base_index - 1]) == name_func( - comparison_ops[comparison_index - 1]): - matched_op.append([comparison_index - 1, base_index - 1]) - comparison_index -= 1 - base_index -= 1 - continue - if dp[comparison_index][base_index - 1] > dp[comparison_index - 1][base_index]: - base_index -= 1 - else: - 
comparison_index -= 1 - if not matched_op: - matched_base_index_list = [] - else: - matched_op.reverse() - matched_op = np.array(matched_op) - matched_base_index_list = list(matched_op[:, 1]) - curr_comparison_index = 0 - for base_index, base_api_node in enumerate(base_ops): - if base_index not in matched_base_index_list: - result_data.append([base_api_node, None]) - continue - matched_comparison_index = matched_op[matched_base_index_list.index(base_index), 0] - for comparison_index in range(curr_comparison_index, matched_comparison_index): - result_data.append([None, comparison_ops[comparison_index]]) - result_data.append([base_api_node, comparison_ops[matched_comparison_index]]) - curr_comparison_index = matched_comparison_index + 1 - if curr_comparison_index < len(comparison_ops): - for comparison_index in range(curr_comparison_index, len(comparison_ops)): - result_data.append([None, comparison_ops[comparison_index]]) - return result_data - - def _get_top_layer_ops(self, profiling_data: ProfilingResult) -> any: - root_node = TreeBuilder.build_tree(profiling_data.torch_op_data, profiling_data.kernel_dict, - profiling_data.memory_list) - level1_child_nodes = root_node.child_nodes - result_data = [] - for level1_node in level1_child_nodes: - if level1_node.is_step_profiler(): - result_data.extend(level1_node.child_nodes) - else: - result_data.append(level1_node) - return result_data + op_compare_result = self._drill_down(op_compare_result, name_func) + return op_compare_result def _drill_down(self, compare_result_data: list, name_func: any) -> list: drill_down_result = [] @@ -152,9 +103,41 @@ class DetailPerformanceGenerator(BaseGenerator): if max(base_op.kernel_num, comparison_op.kernel_num) <= self._args.max_kernel_num: drill_down_result.append(match_data) continue - match_list = self._matching_op(base_op.child_nodes, comparison_op.child_nodes, name_func) + match_list = longest_common_subsequence_matching(base_op.child_nodes, comparison_op.child_nodes, name_func) 
match_list.reverse() for data in match_list: op_deque.append(data) return drill_down_result + + def match_nn_module(self) -> list: + module_compare_result = [] + base_root_node = ModuleDataPrepare(self._profiling_data_dict.get(Constant.BASE_DATA)).build_module_tree() + comparison_root_node = ModuleDataPrepare( + self._profiling_data_dict.get(Constant.COMPARISON_DATA)).build_module_tree() + for index, base_node in enumerate(base_root_node): + comparison_node = comparison_root_node[index] if index < len(comparison_root_node) else None + if not base_node or not comparison_node: + continue + module_compare_result.extend(self._matching_all_modules(base_node, comparison_node)) + return module_compare_result + + def _matching_all_modules(self, base_node: ModuleNode, comparison_node: ModuleNode): + all_matched_modules = [] + matched_queue = Queue() + matched_queue.put([base_node, comparison_node]) + while not matched_queue.empty(): + matched_base_node, matched_comparison_node = matched_queue.get() + matched_node_list = self._matching_common_subsequence(matched_base_node, matched_comparison_node) + all_matched_modules.extend(matched_node_list) + for matched_node in matched_node_list: + matched_queue.put(matched_node) + return all_matched_modules + + def _matching_common_subsequence(self, base_node: ModuleNode, comparison_node: ModuleNode): + base_modules = base_node.child_nodes if base_node else [] + comparison_modules = comparison_node.child_nodes if comparison_node else [] + if not base_modules and not comparison_modules: + return [] + name_func = NameFunction(self._args).get_module_name + return longest_common_subsequence_matching(base_modules, comparison_modules, name_func) diff --git a/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py b/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py index 4c0b51272b0bbf71f6632a7b28005bae2298d056..2127ff5e75e23e98f0debb0dfdafbeb01930c082 100644 --- 
a/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py +++ b/profiler/compare_tools/compare_backend/profiling_parser/base_profiling_parser.py @@ -4,7 +4,6 @@ from decimal import Decimal from compare_backend.compare_bean.origin_data_bean.compare_event import KernelEvent, MemoryEvent from compare_backend.compare_bean.origin_data_bean.trace_event_bean import TraceEventBean from compare_backend.compare_bean.profiling_info import ProfilingInfo -from compare_backend.utils.args_manager import ArgsManager from compare_backend.utils.constant import Constant from compare_backend.utils.file_reader import FileReader @@ -18,11 +17,19 @@ class ProfilingResult: self.memory_list = [] self.communication_dict = {} self.overall_metrics = ProfilingInfo(profiling_type) + self.python_function_data = [] + self.fwdbwd_dict = {} def update_torch_op_data(self, event: TraceEventBean): event.is_torch_op = True self.torch_op_data.append(event) + def update_python_function_data(self, event: TraceEventBean): + self.python_function_data.append(event) + + def update_fwdbwd_data(self, flow_type: str, event: TraceEventBean): + self.fwdbwd_dict.setdefault(event.id, {})[flow_type] = event + def update_kernel_dict(self, start_time: Decimal, kernel_event: TraceEventBean): self.kernel_dict.setdefault(start_time, []).append(KernelEvent(kernel_event, self._profiling_type)) @@ -45,14 +52,15 @@ class BaseProfilingParser(ABC): self._profiling_path = path_dict.get(Constant.PROFILING_PATH) self._json_path = path_dict.get(Constant.TRACE_PATH) self._trace_events = [] if self._profiling_path == Constant.NPU else {} - self._enable_profiling_compare = ArgsManager().enable_profiling_compare - self._enable_operator_compare = ArgsManager().enable_operator_compare - self._enable_memory_compare = ArgsManager().enable_memory_compare - self._enable_communication_compare = ArgsManager().enable_communication_compare + self._enable_profiling_compare = args.enable_profiling_compare + 
self._enable_operator_compare = args.enable_operator_compare + self._enable_memory_compare = args.enable_memory_compare + self._enable_communication_compare = args.enable_communication_compare self._dispatch_func = self._get_dispatch_func() self._result_data = ProfilingResult(self._profiling_type) self._memory_events = [] self._flow_dict = {} + self._fwdbwd_dict = {} self._all_kernels = {} self._comm_task_list = [] self._comm_list = [] @@ -134,6 +142,21 @@ class BaseProfilingParser(ABC): return True return False + def _picking_python_function_event(self, event: TraceEventBean): + if event.is_python_function(): + self._result_data.update_python_function_data(event) + return True + return False + + def _picking_fwdbwd_flow_event(self, event: TraceEventBean): + if event.is_fwdbwd(): + if event.is_flow_start(): + self._result_data.update_fwdbwd_data("start", event) + elif event.is_flow_end(): + self._result_data.update_fwdbwd_data("end", event) + return True + return False + def _update_kernel_dict(self): if self._profiling_type == Constant.NPU: for comm in self._comm_list: diff --git a/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py b/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py index 2ad2e1a557fad7095bea642892c64f32363182e9..c4089aec9bdcb35b80ae9ff9121fcd75bde3a63e 100644 --- a/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py +++ b/profiler/compare_tools/compare_backend/profiling_parser/gpu_profiling_parser.py @@ -3,23 +3,23 @@ from collections import defaultdict, Counter from compare_backend.compare_bean.origin_data_bean.trace_event_bean import TraceEventBean from compare_backend.profiling_parser.base_profiling_parser import BaseProfilingParser -from compare_backend.utils.args_manager import ArgsManager from compare_backend.utils.constant import Constant class GPUProfilingParser(BaseProfilingParser): - CUBE_MARK = 'gemm' - FA_MARK_LIST = [['fmha', 'kernel'], ['flash', 
'kernel']] + CUBE_MARK = ['gemm', 'conv', 'cutlass', 'wgrad'] + FA_MARK_LIST = [['fmha', 'kernel'], ['flash', 'kernel'], ['attention', 'kernel']] SDMA_MARK_LIST = ['htod', 'dtod', 'dtoh', 'memset (device)'] FLOW_CAT = ("async_gpu", "async_cpu_to_gpu", "ac2g", "async") - TORCH_OP_CAT = ("cpu_op", "user_annotation", "cuda_runtime", "operator") + TORCH_OP_CAT = ("cpu_op", "user_annotation", "cuda_runtime", "operator", "runtime") def __init__(self, args: any, path_dict: dict): super().__init__(args, path_dict) self._trace_events = [TraceEventBean(event) for event in self._trace_events.get("traceEvents", [])] - self._flow_cat = (ArgsManager().args.gpu_flow_cat,) if ArgsManager().args.gpu_flow_cat else self.FLOW_CAT + self._flow_cat = (args.gpu_flow_cat,) if args.gpu_flow_cat else self.FLOW_CAT self._compute_stream_id = self._infer_compute_stream_id() self._marks = defaultdict(int) + self._aten_index = 0 @classmethod def __is_flash_attention(cls, name: str): @@ -67,6 +67,14 @@ class GPUProfilingParser(BaseProfilingParser): def _calculate_performance_time(self): min_ts = sys.float_info.max max_ts = sys.float_info.min + self._trace_events.sort(key=lambda x: x.start_time) + aten_events = list(filter(lambda x: x.name.startswith("aten::"), self._trace_events)) + flow_dict_new = {} + for flow_event in self._flow_dict.values(): + start_event = flow_event.get("start") + end_event = flow_event.get("end") + if start_event and end_event: + flow_dict_new[end_event.start_time] = start_event.start_time for event in self._trace_events: if event.stream: min_ts = min(event.start_time, min_ts) @@ -79,7 +87,8 @@ class GPUProfilingParser(BaseProfilingParser): self.__add_marks(event) if event.is_nccl_name(): continue - self.__add_compute_time(event) + self.__add_compute_time(event, aten_events, flow_dict_new) + self._aten_events = None self._result_data.overall_metrics.set_e2e_time(float(max_ts - min_ts)) self.__add_compute_and_overlap_time() @@ -97,17 +106,38 @@ class 
GPUProfilingParser(BaseProfilingParser): for timestep in range(int(event.start_time + 1), int(event.end_time + 1)): self._marks[str(timestep)] += -100 # mark this timestep in compute stream - def __add_compute_time(self, event: TraceEventBean): + def __add_compute_time(self, event: TraceEventBean, aten_events: list, flow_dict_new: dict): if self.__is_flash_attention(event.name): if event.is_backward(): self._result_data.overall_metrics.update_fa_bwd_info(event.dur) else: self._result_data.overall_metrics.update_fa_fwd_info(event.dur) - elif self.CUBE_MARK in event.lower_name: - self._result_data.overall_metrics.update_cube_info(event.dur) + elif any(cube_mark in event.lower_name for cube_mark in self.CUBE_MARK): + is_conv = self.__check_is_conv(event, aten_events, flow_dict_new) + if is_conv == "conv_fwd": + self._result_data.overall_metrics.update_conv_fwd_info(event.dur) + elif is_conv == "conv_bwd": + self._result_data.overall_metrics.update_conv_bwd_info(event.dur) + else: + self._result_data.overall_metrics.update_cube_info(event.dur) else: self._result_data.overall_metrics.update_vec_info(event.dur) + def __check_is_conv(self, event: TraceEventBean, aten_events: list, flow_dict_new: dict) -> str: + flow_start_time = flow_dict_new.get(event.start_time) + if not flow_start_time: + return "" + aten_len = len(aten_events) + while self._aten_index < aten_len: + cur_aten = aten_events[self._aten_index] + if cur_aten.end_time < flow_start_time: + self._aten_index += 1 + continue + if cur_aten.start_time < flow_start_time: + if cur_aten.is_conv(): + return "conv_bwd" if cur_aten.is_backward() else "conv_fwd" + return "" + def _picking_memory_event(self, event: TraceEventBean): if event.is_memory_event(): self._memory_events.append(event) @@ -136,6 +166,9 @@ class GPUProfilingParser(BaseProfilingParser): func_set.add(self._picking_torch_op_event) if self._enable_communication_compare: func_set.add(self._picking_kernel_event) + if self._enable_operator_compare: + 
func_set.add(self._picking_python_function_event) + func_set.add(self._picking_fwdbwd_flow_event) if self._enable_operator_compare or self._args.max_kernel_num: func_set.add(self._picking_kernel_event) func_set.add(self._picking_flow_event) diff --git a/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py b/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py index f872e52a5314a40dbc2e0d4ff7868e875986b809..b068366c96e388acd18fe899663a31d013a9cd9c 100644 --- a/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py +++ b/profiler/compare_tools/compare_backend/profiling_parser/npu_profiling_parser.py @@ -41,6 +41,9 @@ class NPUProfilingParser(BaseProfilingParser): if self._enable_operator_compare or self._args.max_kernel_num: func_list.add(self._picking_kernel_event) func_list.add(self._picking_flow_event) + if self._enable_operator_compare: + func_list.add(self._picking_python_function_event) + func_list.add(self._picking_fwdbwd_flow_event) if self._enable_memory_compare: func_list.add(self._picking_task_queue_data) if self._enable_communication_compare: @@ -48,6 +51,7 @@ class NPUProfilingParser(BaseProfilingParser): if self._enable_profiling_compare: func_list.add(self._picking_overlap_analysis_data) func_list.add(self._picking_kernel_event) + func_list.add(self._picking_hccl_event) return list(func_list) def _update_memory_list(self): @@ -98,10 +102,73 @@ class NPUProfilingParser(BaseProfilingParser): self.__parse_kernel_csv() self.__add_sdma_time() self.__add_overlap_analysis_time() + self._picking_notify_wait_event_and_not_overlap_event() + self.__add_overlap_wait_time() self._result_data.overall_metrics.calculate_other_time() self._result_data.overall_metrics.calculate_schedule_time() self._result_data.overall_metrics.trans_time_to_s() + def _picking_notify_wait_event_and_not_overlap_event(self): + self.notify_event_cache = [] + self._not_overlaped_commu_event = [] + for event in 
self._comm_task_list: + if event.name == 'Notify_Wait' and event.args.get('rdma_type', 0) != 'RDMA_PAYLOAD_CHECK' \ + and event.args.get('rdma_type', 0) != 'RDMA_PAYLOAD_ACK': + self.notify_event_cache.append(event) + for event in self._overlap_analysis: + if event.is_comm_not_overlap(): + self._not_overlaped_commu_event.append(event) + self._not_overlaped_commu_event.sort(key=lambda x: x.start_time) + + def __add_overlap_wait_time(self): + notify_wait_event_dict = dict() + for notify_event in self.notify_event_cache: + if notify_event.tid in notify_wait_event_dict: + notify_wait_event_dict[notify_event.tid].append(notify_event) + else: + notify_wait_event_dict[notify_event.tid] = [notify_event] + + if self._result_data.overall_metrics.is_level0: + return + + total_time = 0 + for commu_event in self._not_overlaped_commu_event: + wait_time_list = [0] + commu_event_start_time = float(commu_event.start_time) + commu_event_end_time = float(commu_event.start_time) + commu_event.dur + + for plane_id, events in notify_wait_event_dict.items(): + wait_time = 0 + idx = 0 + for notify_event in events: + notify_event_start_time = float(notify_event.start_time) + notify_event_end_time = float(notify_event.start_time) + notify_event.dur + if notify_event_start_time < commu_event_start_time and notify_event_end_time > \ + commu_event_end_time: + wait_time = commu_event_end_time - commu_event_start_time + break + elif notify_event_start_time < commu_event_start_time <= notify_event_end_time <= \ + commu_event_end_time: + wait_time += notify_event_end_time - commu_event_start_time + idx += 1 + elif commu_event_start_time <= notify_event_start_time <= commu_event_end_time < \ + notify_event_end_time: + wait_time += commu_event_end_time - notify_event_start_time + break + elif notify_event_start_time >= commu_event_start_time and notify_event_end_time <= \ + commu_event_end_time: + wait_time += notify_event_end_time - notify_event_start_time + idx += 1 + elif notify_event_end_time < 
commu_event_start_time: + idx += 1 + else: + break + + wait_time_list.append(wait_time) + notify_wait_event_dict[plane_id] = notify_wait_event_dict[plane_id][idx:] + total_time += max(wait_time_list) + self._result_data.overall_metrics.update_comm_not_overlap_wait_time(total_time) + def _picking_hccl_event(self, event: TraceEventBean): if event.pid != self._hccl_pid or not event.is_x_mode(): return False @@ -162,9 +229,11 @@ class NPUProfilingParser(BaseProfilingParser): if not isinstance(json_data, dict) or not json_data: print('[WARNING] Invalid profiler info.') return - if self.ACTIVE_CPU in json_data.get('config', {}).get('common_config', {}).get('activities', []): + level = json_data.get('config', {}).get('experimental_config', {}).get('_profiler_level', '') + if self.LEVEL_0 != level: return - if self.LEVEL_0 != json_data.get('config', {}).get('experimental_config', {}).get('_profiler_level', ''): + self._result_data.overall_metrics.is_level0 = True + if self.ACTIVE_CPU in json_data.get('config', {}).get('common_config', {}).get('activities', []): return self._result_data.overall_metrics.minimal_profiling = True @@ -185,6 +254,11 @@ class NPUProfilingParser(BaseProfilingParser): self._result_data.overall_metrics.update_fa_bwd_info(kernel.duration) else: self._result_data.overall_metrics.update_fa_fwd_info(kernel.duration) + elif kernel.is_conv(): + if kernel.is_conv_bwd(): + self._result_data.overall_metrics.update_conv_bwd_info(kernel.duration) + else: + self._result_data.overall_metrics.update_conv_fwd_info(kernel.duration) elif kernel.is_cube(): self._result_data.overall_metrics.update_cube_info(kernel.duration) elif kernel.is_sdma(): @@ -235,7 +309,7 @@ class NPUProfilingParser(BaseProfilingParser): sdma_dict.setdefault(stream_id, []).append(event.dur) elif event.is_compute_event(): ai_core_stream.add(stream_id) - compute_stream = event_wait_stream & ai_core_stream + compute_stream = event_wait_stream & ai_core_stream if event_wait_stream else 
ai_core_stream for stream in compute_stream: dur_list = sdma_dict.get(stream, []) self._result_data.overall_metrics.update_sdma_info(sum(dur_list), len(dur_list)) diff --git a/profiler/compare_tools/compare_backend/utils/common_func.py b/profiler/compare_tools/compare_backend/utils/common_func.py index 26584626cd1786d32d4e7f5fcaef1a09d8726852..a3cab286e33a9d474e85d0b51023d73edc22ca56 100644 --- a/profiler/compare_tools/compare_backend/utils/common_func.py +++ b/profiler/compare_tools/compare_backend/utils/common_func.py @@ -1,5 +1,7 @@ from decimal import Decimal +import numpy + def calculate_diff_ratio(base_value: float, comparison_value: float): if not base_value and not comparison_value: @@ -31,3 +33,60 @@ def convert_to_decimal(data: any) -> Decimal: print('[ERROR] Invalid profiling data which failed to convert data to decimal.') return 0.0 return decimal_value + + +def longest_common_subsequence_matching(base_ops: list, comparison_ops: list, name_func: any) -> list: + if not comparison_ops: + result_data = [None] * len(base_ops) + for index, value in enumerate(base_ops): + result_data[index] = [value, None] + return result_data + + comparison_len, base_len = len(comparison_ops), len(base_ops) + dp_flag = numpy.zeros(shape=(comparison_len + 1, base_len + 1), dtype=int) + pre_list = [0] * (base_len + 1) + cur_list = [0] * (base_len + 1) + + comparison_index = 1 + iter_comparison_data = iter(comparison_ops) + for comparison_data in iter_comparison_data: + base_index = 1 + iter_base_data = iter(base_ops) + for base_data in iter_base_data: + if name_func(comparison_data) == name_func(base_data): + cur_list[base_index] = pre_list[base_index - 1] + 1 + else: + only_base = cur_list[base_index - 1] + only_comparison = pre_list[base_index] + if only_base < only_comparison: + dp_flag[comparison_index][base_index] = 1 # 1 for only comparison op + cur_list[base_index] = only_comparison + else: + cur_list[base_index] = only_base + base_index += 1 + pre_list = cur_list + 
comparison_index += 1 + + matched_op = [] + comparison_index, base_index = comparison_len, base_len + while comparison_index > 0 and base_index > 0: + base_data = base_ops[base_index - 1] + comparison_data = comparison_ops[comparison_index - 1] + if name_func(base_data) == name_func(comparison_data): + matched_op.append([base_data, comparison_data]) + comparison_index -= 1 + base_index -= 1 + elif dp_flag[comparison_index][base_index] == 1: # 1 for only comparison op + matched_op.append([None, comparison_data]) + comparison_index -= 1 + else: + matched_op.append([base_data, None]) + base_index -= 1 + while comparison_index > 0: + matched_op.append([None, comparison_ops[comparison_index - 1]]) + comparison_index -= 1 + while base_index > 0: + matched_op.append([base_ops[base_index - 1], None]) + base_index -= 1 + matched_op.reverse() + return matched_op diff --git a/profiler/compare_tools/compare_backend/utils/compare_args.py b/profiler/compare_tools/compare_backend/utils/compare_args.py new file mode 100644 index 0000000000000000000000000000000000000000..ab9bc364f440ca8412a6e40d67ca74b7c897cbd9 --- /dev/null +++ b/profiler/compare_tools/compare_backend/utils/compare_args.py @@ -0,0 +1,24 @@ +class Args: + def __init__(self, + base_profiling_path: str = "", + comparison_profiling_path: str = "", + enable_profiling_compare: bool = False, + enable_operator_compare: bool = False, + enable_memory_compare: bool = False, + enable_communication_compare: bool = False, + output_path: str = "", + max_kernel_num: int = None, + op_name_map: dict = {}, + use_input_shape: bool = False, + gpu_flow_cat: str = ""): + self.base_profiling_path = base_profiling_path + self.comparison_profiling_path = comparison_profiling_path + self.enable_profiling_compare = enable_profiling_compare + self.enable_operator_compare = enable_operator_compare + self.enable_memory_compare = enable_memory_compare + self.enable_communication_compare = enable_communication_compare + self.output_path = 
output_path + self.max_kernel_num = max_kernel_num + self.op_name_map = op_name_map + self.use_input_shape = use_input_shape + self.gpu_flow_cat = gpu_flow_cat diff --git a/profiler/compare_tools/compare_backend/utils/constant.py b/profiler/compare_tools/compare_backend/utils/constant.py index d44f9fea93649f5301fa436a1dcac6a39702112a..1b77b214c85f6733e36298e119e43a778fd7969f 100644 --- a/profiler/compare_tools/compare_backend/utils/constant.py +++ b/profiler/compare_tools/compare_backend/utils/constant.py @@ -53,6 +53,8 @@ class Constant(object): MEMORY_TOP_TABLE = "MemoryCompareStatistic" COMMUNICATION_TABLE = "CommunicationCompare" PERFORMANCE_TABLE = "Model Profiling Time Distribution" + MODULE_TABLE = "ModuleCompare" + MODULE_TOP_TABLE = "ModuleCompareStatistic" # memory SIZE = "Size(KB)" @@ -74,3 +76,5 @@ class Constant(object): #compare type OVERALL_COMPARE = "overall" + + BWD_LIST = ["bwd", "backward", "back"] diff --git a/profiler/compare_tools/compare_backend/utils/excel_config.py b/profiler/compare_tools/compare_backend/utils/excel_config.py index 50b2e6329e3b450fc85caca1c0b0d8ab8895a522..306abcdfec6e62f24977b989258ad190a90c9bd7 100644 --- a/profiler/compare_tools/compare_backend/utils/excel_config.py +++ b/profiler/compare_tools/compare_backend/utils/excel_config.py @@ -14,6 +14,10 @@ class CellFormatType: 'bold': True} # 字符串,无背景色,字体加粗 BLUE_BOLD = {"font_name": "Arial", 'font_size': 11, 'fg_color': Constant.BLUE_COLOR, 'align': 'left', 'valign': 'vcenter', 'bold': True, 'border': True} # 蓝色背景,加粗 + GREEN_BOLD = {"font_name": "Arial", 'font_size': 11, 'fg_color': Constant.GREEN_COLOR, 'align': 'left', + 'valign': 'vcenter', 'bold': True, 'border': True} # 绿色背景,加粗 + YELLOW_BOLD = {"font_name": "Arial", 'font_size': 11, 'fg_color': Constant.YELLOW_COLOR, 'align': 'left', + 'valign': 'vcenter', 'bold': True, 'border': True} # 黄色背景,加粗 class ExcelConfig(object): @@ -46,6 +50,21 @@ class ExcelConfig(object): AVG_DURATION = "Avg Duration(us)" MAX_DURATION = "Max 
Duration(us)" MIN_DURATION = "Min Duration(us)" + MODULE_CLASS = "Module Class" + MODULE_NAME = "Module Name" + DEVICE_SELF_TIME = "Device Self Time(ms)" + DEVICE_TOTAL_TIME = "Device Total Time(ms)" + DIFF_SELF_TIME = "Device Self Time Diff(ms)" + DIFF_TOTAL_RATIO = "Total Diff Ratio" + DIFF_TOTAL_TIME = "Device Total Time Diff(ms)" + DEVICE_SELF_TIME_US = "Device Self Time(us)" + DEVICE_TOTAL_TIME_US = "Device Total Time(us)" + DIFF_SELF_TIME_US = "Device Self Time Diff(us)" + DIFF_TOTAL_TIME_US = "Device Total Time Diff(us)" + NUMBER = "Number" + MODULE_LEVEL = "Module Level" + BASE_CALL_STACK = "Base Call Stack" + COMPARISON_CALL_STACK = "Comparison Call Stack" HEADERS = { Constant.OPERATOR_TABLE: [ @@ -118,9 +137,49 @@ class ExcelConfig(object): {"name": MIN_DURATION, "type": CellFormatType.DEFAULT_FLOAT, "width": 17}, {"name": DIFF_DUR, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, {"name": DIFF_RATIO, "type": CellFormatType.DEFAULT_RATIO, "width": 20} + ], + Constant.MODULE_TOP_TABLE: [ + {"name": ORDER, "type": CellFormatType.DEFAULT, "width": 10}, + {"name": MODULE_CLASS, "type": CellFormatType.DEFAULT, "width": 20}, + {"name": MODULE_LEVEL, "type": CellFormatType.DEFAULT, "width": 15}, + {"name": MODULE_NAME, "type": CellFormatType.DEFAULT, "width": 35}, + {"name": OPERATOR_NAME, "type": CellFormatType.DEFAULT, "width": 25}, + {"name": KERNEL_DETAILS, "type": CellFormatType.DEFAULT, "width": 20}, + {"name": DEVICE_SELF_TIME, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": NUMBER, "type": CellFormatType.DEFAULT, "width": 10}, + {"name": DEVICE_TOTAL_TIME, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": KERNEL_DETAILS, "type": CellFormatType.DEFAULT, "width": 20}, + {"name": DEVICE_SELF_TIME, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": NUMBER, "type": CellFormatType.DEFAULT, "width": 10}, + {"name": DEVICE_TOTAL_TIME, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DIFF_TOTAL_TIME, 
"type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DIFF_SELF_TIME, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DIFF_TOTAL_RATIO, "type": CellFormatType.DEFAULT_RATIO, "width": 15}, + {"name": BASE_CALL_STACK, "type": CellFormatType.DEFAULT, "width": 30}, + {"name": COMPARISON_CALL_STACK, "type": CellFormatType.DEFAULT, "width": 30} + ], + Constant.MODULE_TABLE: [ + {"name": ORDER, "type": CellFormatType.DEFAULT, "width": 10}, + {"name": MODULE_CLASS, "type": CellFormatType.DEFAULT, "width": 20}, + {"name": MODULE_LEVEL, "type": CellFormatType.DEFAULT, "width": 15}, + {"name": MODULE_NAME, "type": CellFormatType.DEFAULT, "width": 35}, + {"name": OPERATOR_NAME, "type": CellFormatType.DEFAULT, "width": 25}, + {"name": KERNEL_DETAILS, "type": CellFormatType.DEFAULT, "width": 20}, + {"name": DEVICE_SELF_TIME_US, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DEVICE_TOTAL_TIME_US, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": OPERATOR_NAME, "type": CellFormatType.DEFAULT, "width": 25}, + {"name": KERNEL_DETAILS, "type": CellFormatType.DEFAULT, "width": 20}, + {"name": DEVICE_SELF_TIME_US, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DEVICE_TOTAL_TIME_US, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DIFF_TOTAL_TIME_US, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DIFF_SELF_TIME_US, "type": CellFormatType.DEFAULT_FLOAT, "width": 20}, + {"name": DIFF_TOTAL_RATIO, "type": CellFormatType.DEFAULT_RATIO, "width": 15}, + {"name": BASE_CALL_STACK, "type": CellFormatType.DEFAULT, "width": 30}, + {"name": COMPARISON_CALL_STACK, "type": CellFormatType.DEFAULT, "width": 30} ] } OVERHEAD = {Constant.OPERATOR_TABLE: ["B1:F1", "G1:K1"], Constant.MEMORY_TABLE: ["B1:F1", "G1:K1"], Constant.COMMUNICATION_TABLE: ["B1:H1", "I1:O1"], Constant.OPERATOR_TOP_TABLE: ["C1:D1", "E1:F1"], - Constant.MEMORY_TOP_TABLE: ["C1:E1", "F1:H1"]} + Constant.MEMORY_TOP_TABLE: ["C1:E1", 
"F1:H1"], Constant.MODULE_TOP_TABLE: ["F1:I1", "J1:M1"], + Constant.MODULE_TABLE: ["E1:H1", "I1:L1"]} diff --git a/profiler/compare_tools/compare_backend/utils/module_node.py b/profiler/compare_tools/compare_backend/utils/module_node.py new file mode 100644 index 0000000000000000000000000000000000000000..f85606094ede7abc378c1b3d017b4a98c8800107 --- /dev/null +++ b/profiler/compare_tools/compare_backend/utils/module_node.py @@ -0,0 +1,171 @@ +import re +from math import ceil + +from compare_backend.compare_bean.origin_data_bean.trace_event_bean import TraceEventBean +from compare_backend.utils.torch_op_node import TorchOpNode + + +class ModuleNode: + ts = "ts" + kernels = "kernels" + + def __init__(self, event: TraceEventBean, parent_node=None): + self._event = event + self._parent_node = parent_node + self._child_nodes = [] + self._module_name = f"{parent_node.module_name}/{event.name}" if parent_node else event.name + self._module_level = parent_node.module_level + 1 if parent_node else 1 + self._kernel_self_list = [] + self._kernel_total_list = [] + self._call_stack = f"{parent_node.call_stack};\n{event.name}" if parent_node and parent_node.call_stack \ + else event.name + self._root_torch_op_node = TorchOpNode() + self._cur_torch_op_node = self._root_torch_op_node + + @property + def module_name(self): + return self._module_name + + @property + def module_class(self): + pattern = re.compile('_[0-9]+$') + return pattern.sub('', self.name.split("/")[-1]) + + @property + def module_level(self): + return self._module_level + + @property + def name(self): + return self._event.name + + @property + def parent_node(self): + return self._parent_node + + @property + def child_nodes(self): + return self._child_nodes + + @property + def dur(self): + return self._event.dur + + @property + def start_time(self): + return self._event.start_time + + @property + def end_time(self): + return self._event.end_time + + @property + def host_self_dur(self): + return self.dur - 
sum([node.dur for node in self.child_nodes]) + + @property + def device_self_dur(self): + dur = 0 + for kernel_dict in self._kernel_self_list: + kernel_list = kernel_dict.get(self.kernels, []) + dur += sum([kernel.device_dur for kernel in kernel_list]) + return dur + + @property + def device_total_dur(self): + dur = 0 + for kernel_dict in self._kernel_total_list: + kernel_list = kernel_dict.get(self.kernels, []) + dur += sum([kernel.device_dur for kernel in kernel_list]) + return dur + + @property + def kernel_details(self): + kernel_details = "" + for kernel_dict in self._kernel_self_list: + kernel_list = kernel_dict.get(self.kernels, []) + for kernel in kernel_list: + kernel_details += kernel.kernel_details + return kernel_details + + @property + def toy_layer_api_list(self): + return self._root_torch_op_node.child_nodes + + @property + def call_stack(self): + return self._call_stack + + @staticmethod + def _binary_search(ts_time, parent_node): + if not parent_node.child_nodes: + return None + right = len(parent_node.child_nodes) - 1 + left = 0 + while right > left: + mid = left + ceil((right - left) / 2) + if ts_time >= parent_node.child_nodes[mid].start_time: + left = mid + else: + right = mid - 1 + if parent_node.child_nodes[left].start_time < ts_time < parent_node.child_nodes[left].end_time: + return parent_node.child_nodes[left] + return None + + def reset_call_stack(self, call_stack): + self._call_stack = call_stack + + def update_child_nodes(self, node): + self._child_nodes.append(node) + + def update_kernel_list(self, ts, kernel_list: list): + self._update_kernel_self_list(ts, kernel_list) + node = self + while node.parent_node: + node._update_kernel_total_list(ts, kernel_list) + node = node.parent_node + + def _update_kernel_self_list(self, ts, kernel_list: list): + self._kernel_self_list.append({self.ts: ts, self.kernels: kernel_list}) + + def _update_kernel_total_list(self, ts, kernel_list: list): + self._kernel_total_list.append({self.ts: ts, 
self.kernels: kernel_list}) + + def find_module_call(self, ts_time): + call_module = self._binary_search(ts_time, self) + while call_module: + module = self._binary_search(ts_time, call_module) + if not module: + return call_module + call_module = module + return call_module + + def find_torch_op_call(self, event): + while self._cur_torch_op_node: + if self._cur_torch_op_node != self._root_torch_op_node and \ + event.start_time > self._cur_torch_op_node.end_time: + self._cur_torch_op_node = self._cur_torch_op_node.parent + continue + tree_node = TorchOpNode(event, self._cur_torch_op_node) + self._cur_torch_op_node.add_child_node(tree_node) + self._cur_torch_op_node = tree_node + break + + def update_torch_op_kernel_list(self): + top_node_list = self._root_torch_op_node.child_nodes + if not top_node_list: + return + top_node_list.sort(key=lambda x: x.start_time) + cur_index = 0 + self._kernel_self_list.sort(key=lambda x: x.get(self.ts, 0)) + for kernel_dict in self._kernel_self_list: + ts = kernel_dict.get(self.ts, 0) + kernel_list = kernel_dict.get(self.kernels, []) + while cur_index < len(top_node_list): + if ts > top_node_list[cur_index].end_time: + cur_index += 1 + continue + if ts < top_node_list[cur_index].start_time: + break + top_node_list[cur_index].update_kernel_list(kernel_list) + break diff --git a/profiler/compare_tools/compare_backend/utils/name_function.py b/profiler/compare_tools/compare_backend/utils/name_function.py index d83f9e4291c9c1afbcbc1e398741d2bdbedd8df8..cd79e8a03fa7a970ce97ad59f14fae12766f096b 100644 --- a/profiler/compare_tools/compare_backend/utils/name_function.py +++ b/profiler/compare_tools/compare_backend/utils/name_function.py @@ -1,3 +1,4 @@ +from compare_backend.utils.module_node import ModuleNode from compare_backend.utils.torch_op_node import TorchOpNode @@ -41,3 +42,11 @@ class NameFunction: input_shape = ';\r\n'.join(data) return f'{self.args.op_name_map.get(op_node.name, op_node.name)}{input_shape}' return 
f'{self.args.op_name_map.get(op_node.name, op_node.name)}{op_node.input_shape}' + + def get_module_name(self, module: ModuleNode) -> str: + if not self.args.op_name_map: + return module.module_name + module = module.module_name + for old_name, new_name in self.args.op_name_map.items(): + module.replace(old_name, new_name) + return module diff --git a/profiler/compare_tools/compare_backend/utils/torch_op_node.py b/profiler/compare_tools/compare_backend/utils/torch_op_node.py index 45b9299ba0a23fcc0072546f73cec125890d2e21..690c46cd51c1e2991b0bfaf44e9af431cdad5151 100644 --- a/profiler/compare_tools/compare_backend/utils/torch_op_node.py +++ b/profiler/compare_tools/compare_backend/utils/torch_op_node.py @@ -60,6 +60,10 @@ class TorchOpNode: def memory_allocated(self): return self._memory_allocated_list + @property + def device_dur(self): + return sum([kernel.device_dur for kernel in self._kernel_list]) + def add_child_node(self, child_node): self._child_nodes.append(child_node) @@ -73,11 +77,16 @@ class TorchOpNode: cur_node._kernel_num += kernel_num cur_node = cur_node._parent_node + def update_kernel_list(self, kernel_list: list): + if not kernel_list: + return + self._kernel_list.extend(kernel_list) + def set_memory_allocated(self, memory_allocated: MemoryEvent): self._memory_allocated_list.append(memory_allocated) def is_step_profiler(self) -> bool: - return self.name.find("ProfilerStep#") != -1 + return self._event.is_step_profiler() def get_op_info(self) -> list: return [self.name, self.input_shape, self.input_type, self.call_stack] diff --git a/profiler/compare_tools/compare_backend/utils/tree_builder.py b/profiler/compare_tools/compare_backend/utils/tree_builder.py index f621453d1a5a2281425a01e93b3f89b012f35b88..34c1fe1a1f4046d1e60af107f5ee74484424174a 100644 --- a/profiler/compare_tools/compare_backend/utils/tree_builder.py +++ b/profiler/compare_tools/compare_backend/utils/tree_builder.py @@ -1,5 +1,7 @@ from queue import Queue +from 
compare_backend.compare_bean.origin_data_bean.trace_event_bean import TraceEventBean +from compare_backend.utils.module_node import ModuleNode from compare_backend.utils.torch_op_node import TorchOpNode @@ -7,10 +9,12 @@ class TreeBuilder: @classmethod def build_tree(cls, event_list: list, kernel_dict: dict, memory_list: list) -> TorchOpNode: root_node = TorchOpNode() - event_list.extend(memory_list) - event_list.sort(key=lambda x: x.start_time) + all_event_list = [] + all_event_list.extend(event_list) + all_event_list.extend(memory_list) + all_event_list.sort(key=lambda x: x.start_time) last_node = root_node - for event in event_list: + for event in all_event_list: while last_node: if last_node != root_node and event.start_time > last_node.end_time: last_node = last_node.parent @@ -53,3 +57,26 @@ class TreeBuilder: for child_node in tree_node.child_nodes: node_queue.put(child_node) return result_list + + @classmethod + def build_module_tree(cls, event_list: list, kernel_dict: dict): + root_node = ModuleNode(TraceEventBean({})) + event_list.sort(key=lambda x: x.start_time) + last_node = root_node + for event in event_list: + while last_node: + if last_node != root_node and event.start_time > last_node.end_time: + last_node = last_node.parent_node + continue + if event.is_x_mode(): + tree_node = ModuleNode(event, last_node) + last_node.update_child_nodes(tree_node) + last_node = tree_node + break + if last_node == root_node: + break + kernel_list = kernel_dict.get(event.start_time, []) + if kernel_list: + last_node.update_kernel_list(event.start_time, kernel_list) + break + return root_node diff --git a/profiler/compare_tools/compare_backend/view/work_sheet_creator.py b/profiler/compare_tools/compare_backend/view/work_sheet_creator.py index c5e56c2f8b9a7ae0c1d1a596dbe81e3541f6ce73..7a33168da377ae77ab64fff0886e09eef065b4e2 100644 --- a/profiler/compare_tools/compare_backend/view/work_sheet_creator.py +++ 
b/profiler/compare_tools/compare_backend/view/work_sheet_creator.py @@ -23,20 +23,28 @@ class WorkSheetCreator: self._write_data() def _write_headers(self): - header_format = self._work_book.add_format(CellFormatType.BLUE_BOLD) + base_header_format = self._work_book.add_format(CellFormatType.GREEN_BOLD) + com_header_format = self._work_book.add_format(CellFormatType.YELLOW_BOLD) + com_index_range = [-1, -1] overhead = self._data.get("overhead", []) if overhead: base_path = f"Base Profiling: {self._args.base_profiling_path}" - self._work_sheet.merge_range(overhead[0], base_path, header_format) + self._work_sheet.merge_range(overhead[0], base_path, base_header_format) + com_index_range = [self._col_ids.index(overhead[1].split(":")[0][0]), + self._col_ids.index(overhead[1].split(":")[1][0])] comparison_path = f"Comparison Profiling: {self._args.comparison_profiling_path}" - self._work_sheet.merge_range(overhead[1], comparison_path, header_format) + self._work_sheet.merge_range(overhead[1], comparison_path, com_header_format) self._row_id += 2 for index, header in enumerate(self._data.get("headers")): + if index in range(com_index_range[0], com_index_range[1] + 1): + header_format = com_header_format + else: + header_format = base_header_format col_id = self._col_ids[index] self._work_sheet.set_column(f"{col_id}:{col_id}", header.get("width")) self._work_sheet.write(f"{col_id}{self._row_id}", header.get("name"), header_format) self._field_format[index] = self._work_book.add_format(header.get("type")) - if header.get("name") == ExcelConfig.DIFF_RATIO: + if header.get("name") in (ExcelConfig.DIFF_RATIO, ExcelConfig.DIFF_TOTAL_RATIO): self._diff_ratio_index = index self._row_id += 1 diff --git a/profiler/compare_tools/compare_interface/comparison_interface.py b/profiler/compare_tools/compare_interface/comparison_interface.py index b3ba5f63c8260aabe132955746184c7c28edbc2a..919095b310126f2ce0c9c3e6912fb10f24d149e9 100644 --- 
a/profiler/compare_tools/compare_interface/comparison_interface.py +++ b/profiler/compare_tools/compare_interface/comparison_interface.py @@ -1,39 +1,31 @@ -from compare_backend.comparison_generator import ComparisonGenerator -from compare_backend.utils.constant import Constant +import sys +import os +sys.path.append( + os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "cluster_analyse")) +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -class Args: - def __init__(self, - base_profiling_path: str, - comparison_profiling_path: str, - enable_profiling_compare: bool = False, - enable_operator_compare: bool = False, - enable_memory_compare: bool = False, - enable_communication_compare: bool = False, - output_path: str = "", - max_kernel_num: int = None, - op_name_map: dict = None, - use_input_shape: bool = False, - gpu_flow_cat: str = ""): - self.base_profiling_path = base_profiling_path - self.comparison_profiling_path = comparison_profiling_path - self.enable_profiling_compare = enable_profiling_compare - self.enable_operator_compare = enable_operator_compare - self.enable_memory_compare = enable_memory_compare - self.enable_communication_compare = enable_communication_compare - self.output_path = output_path - self.max_kernel_num = max_kernel_num - self.op_name_map = op_name_map or {} - self.use_input_shape = use_input_shape - self.gpu_flow_cat = gpu_flow_cat +from compare_backend.comparison_generator import ComparisonGenerator +from compare_backend.disaggregate.overall_perf_interface import OverallPerfInterface +from compare_backend.utils.compare_args import Args +from compare_backend.utils.constant import Constant class ComparisonInterface: - def __init__(self, base_profiling_path: str, comparison_profiling_path: str): - self._args = Args(base_profiling_path, comparison_profiling_path) + def __init__(self, base_profiling_path: str, comparison_profiling_path: str = ""): + self.base_profiling_path 
= base_profiling_path + if comparison_profiling_path: + self._args = Args(base_profiling_path=base_profiling_path, + comparison_profiling_path=comparison_profiling_path) def compare(self, compare_type: str) -> dict: if compare_type == Constant.OVERALL_COMPARE: self._args.enable_profiling_compare = True return ComparisonGenerator(self._args).run_interface(compare_type) + + def disaggregate_perf(self, compare_type: str) -> dict: + if compare_type != Constant.OVERALL_COMPARE: + print('[ERROR] Invalid compare_type value: {compare_type} which not supported.') + return {} + return OverallPerfInterface(self.base_profiling_path).run() diff --git a/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npu_slow_advice.py b/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npu_slow_advice.py new file mode 100644 index 0000000000000000000000000000000000000000..8830d495992cfcd2c26024863f8b644d5b4c6902 --- /dev/null +++ b/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npu_slow_advice.py @@ -0,0 +1,223 @@ +import json +import os +import shutil +import stat +import csv +import unittest + +from advisor_backend.interface import Interface +from advisor_backend.compute_advice.npu_slow_advice import NpuSlowAdvice + + +class TestNpuSlowAdvice(unittest.TestCase): + ASCEND_PT_DIR = "./ascend_pt" + OUTPUT_DIR = "./ascend_pt/ASCEND_PROFILER_OUTPUT" + interface = None + err_interface = None + + def tearDown(self): + if os.path.exists(TestNpuSlowAdvice.ASCEND_PT_DIR): + shutil.rmtree(TestNpuSlowAdvice.ASCEND_PT_DIR) + + def setUp(self): + if os.path.exists(TestNpuSlowAdvice.ASCEND_PT_DIR): + shutil.rmtree(TestNpuSlowAdvice.ASCEND_PT_DIR) + if not os.path.exists(TestNpuSlowAdvice.ASCEND_PT_DIR): + os.makedirs(TestNpuSlowAdvice.ASCEND_PT_DIR) + if not os.path.exists(TestNpuSlowAdvice.OUTPUT_DIR): + os.makedirs(TestNpuSlowAdvice.OUTPUT_DIR) + + @classmethod + def get_basic_trace_view(cls): + # Python pid + py_pid_data = {"ph": "M", "name": "process_name", "tid": 
0, "pid": 1, "args": {"name": "Python"}} + # ascend pid + ascend_pid_data = {"ph": "M", "name": "process_name", "tid": 0, "pid": 4, "args": {"name": "Ascend Hardware"}} + # ascend pid + cann_pid_data = {"ph": "M", "name": "process_name", "tid": 0, "pid": 5, "args": {"name": "CANN"}} + # ascend hardware ops + ah_event1 = {"ph": "X", "name": "Slice1", "ts": "1699529623106750", "dur": 100, "tid": 3, "pid": 4, "args": {}} + ah_event2 = {"ph": "X", "name": "Slice2", "ts": "1699529623106751", "dur": 80, "tid": 3, "pid": 4, "args": {}} + # flow event + flow_event_s = {"ph": "s", "name": "link1", "id": 1, "tid": 3, "pid": 1, "ts": "200", "args": {}} + flow_event_e = {"ph": "f", "name": "link1", "id": 1, "tid": 3, "pid": 1, "ts": "1699529623106750", "args": {}} + return [py_pid_data, ascend_pid_data, cann_pid_data, ah_event1, ah_event2, flow_event_s, flow_event_e] + + @classmethod + def create_profiler_info_json(cls): + info = { + "config": { + "common_config": { + "with_stack": True, + "activities": ["ProfilerActivity.CPU", "ProfilerActivity.NPU"] + } + } + } + with os.fdopen(os.open(f"{TestNpuSlowAdvice.ASCEND_PT_DIR}/profiler_info_0.json", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + fp.write(json.dumps(info)) + + @classmethod + def create_old_version_trace_view(cls): + basic_info = cls.get_basic_trace_view() + + # python ops + py_event1 = {"ph": "X", "cat": "python_function", "name": "aten::slice", "ts": "200", "dur": 100, "tid": 2, + "pid": 1, + "args": {"Call stack": "/root/test/slice.py(116);\r\n/root/torch/module.py"}} + py_event2 = {"ph": "X", "cat": "python_function", "name": "slice", "ts": "199", "dur": 200, "tid": 2, "pid": 1, + "args": {"Call stack": "/root/test/slice.py(116);\r\n/root/torch/module.py"}} + raw_data = [ + *basic_info, py_event1, py_event2 + ] + + with os.fdopen(os.open(f"{TestNpuSlowAdvice.OUTPUT_DIR}/trace_view.json", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + 
fp.write(json.dumps(raw_data)) + + @classmethod + def create_new_version_trace_view(cls): + basic_info = cls.get_basic_trace_view() + # python ops + py_event1 = {"ph": "X", "name": "aten::slice", "ts": "200", "dur": 100, "tid": 2, "pid": 1, "args": {}} + py_event2 = {"ph": "X", "name": "slice", "ts": "199", "dur": 105, "tid": 2, "pid": 1, "args": {}} + py_event3 = {"ph": "X", "cat": "python_function", "name": "/root/test/slice.py(116)", "ts": "198", "dur": 120, + "tid": 2, "pid": 1, + "args": {}} + py_event4 = {"ph": "X", "cat": "python_function", "name": "/root/torch/module.py", "ts": "197", "dur": 150, + "tid": 2, "pid": 1, "args": {}} + + raw_data = [ + *basic_info, py_event1, py_event2, py_event3, py_event4 + ] + + with os.fdopen(os.open(f"{TestNpuSlowAdvice.OUTPUT_DIR}/trace_view.json", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + fp.write(json.dumps(raw_data)) + + @classmethod + def create_kernel_details(cls): + # create csv files + csv_header = ['Step Id', 'Model ID', 'Task ID', 'Stream ID', 'Name', 'Type', 'Accelerator Core', + 'Start Time(us)', + 'Duration(us)', 'Wait Time(us)', 'Block Dim', 'Mix Block Dim', 'Input Shapes', 'Input Data Types', + 'Input Formats', 'Output Shapes', 'Output Data Types', 'Output Formats', 'Context ID', + 'aicore_time(us)', + 'aic_total_cycles', 'aic_mac_ratio', 'aic_mac_int8_ratio', 'aic_cube_fops', + 'aic_vector_fops', + 'aiv_time(us)', 'aiv_total_cycles', 'aiv_vec_fp32_ratio', 'aiv_vec_fp16_ratio', + 'aiv_vec_int32_ratio', + 'aiv_vec_misc_ratio', 'aiv_cube_fops', 'aiv_vector_fops'] + # RED: size=0.0492 MB, throughput=2.32 GB/s, task_duration=21.2us + csv_row1 = [1, 4294967295, 1265, 16, 'Slice1', 'Slice', 'AI_VECTOR_CORE', "1699529623106750\t", 21.2, 261.56, 9, + 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 0, 0, 0, 0, 0, 0, + 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856] + # YELLOW: size=0.0492 MB, throughput=984 GB/s, task_duration=0.05us + csv_row2 = [1, 4294967295, 
1265, 16, 'Slice2', 'Slice', 'AI_VECTOR_CORE', "1699529623106751\t", 0.05, 261.56, 9, + 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 0, 0, 0, 0, 0, 0, + 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856] + # WHITE: AI_CPU + csv_row3 = [1, 4294967295, 1265, 16, 'Swish1', 'Swish', 'AI_CPU', "1699529623106752\t", 3.14, 261.56, 9, + 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', + 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A'] + # GREEN: size=0.0492 MB, throughput=15.67 GB/s, task_duration = 3.14us + csv_row4 = [1, 4294967295, 1265, 16, 'Mul1', 'Mul', 'AI_VECTOR_CORE', "1699529623106753\t", 3.14, 261.56, 9, 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 0, 0, 0, 0, 0, 0, + 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856] + # RED: aic_mac_ratio=0.2 + csv_row5 = [1, 4294967295, 1265, 16, 'Add1', 'Add', 'AI_CORE', "1699529623106754\t", 3.14, 261.56, 9, 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 2.3, 28888, 0.2, 0.1, 0.1, 0.7, + 0, 0, 0, 0, 0, 0, 0, 0] + # GREEN: aic_mac_ratio=0.85 + csv_row6 = [1, 4294967295, 1265, 16, 'Add1', 'Add', 'AI_CORE', "1699529623106754\t", 3.14, 261.56, 9, 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 2.3, 38888, 0.85, 0.1, 0.1, 0.7, + 0, 0, 0, 0, 0, 0, 0, 0] + # YELLOW: aic_mac_ratio=0.64 + csv_row7 = [1, 4294967295, 1265, 16, 'Add1', 'Add', 'AI_CORE', "1699529623106754\t", 3.14, 261.56, 9, 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 2.3, 48888, 0.64, 0.1, 0.1, 0.7, + 0, 0, 0, 0, 0, 0, 0, 0] + # WHITE: MIX_AIC + csv_row8 = [1, 4294967295, 1265, 16, 'Slice2', 'Slice', 'MIX_AIC', "1699529623106751\t", 0.05, 261.56, 9, + 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 2.3, 28888, 0.4, 0.1, 0.1, 0.7, + 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856] + # WHITE: MIX_AIV + csv_row9 = [1, 4294967295, 1265, 16, 
'Slice2', 'Slice', 'MIX_AIV', "1699529623106751\t", 0.05, 261.56, 9, + 0, + '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', + 2.3, 28888, 0.4, 0.1, 0.1, 0.7, + 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856] + with os.fdopen(os.open(f"{TestNpuSlowAdvice.OUTPUT_DIR}/kernel_details.csv", + os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp: + csv_writer = csv.writer(fp) + csv_writer.writerow(csv_header) + csv_writer.writerow(csv_row1) + csv_writer.writerow(csv_row2) + csv_writer.writerow(csv_row3) + csv_writer.writerow(csv_row4) + csv_writer.writerow(csv_row5) + csv_writer.writerow(csv_row6) + csv_writer.writerow(csv_row7) + csv_writer.writerow(csv_row8) + csv_writer.writerow(csv_row9) + + def test_run_should_return_empty_when_ascend_pt_path_not_exist(self): + interface = Interface("") + data = interface.get_data('compute', 'npu_slow') + self.assertEqual(0, len(data)) + + def test_run_should_return_empty_when_there_is_no_kernel_details(self): + interface = Interface(self.ASCEND_PT_DIR) + data = interface.get_data('compute', 'npu_slow') + self.assertEqual(0, len(data)) + + def test_run_should_return_7_data_without_call_stack_when_json_not_exist(self): + self.create_kernel_details() + interface = Interface(self.ASCEND_PT_DIR) + data = interface.get_data('compute', 'npu_slow') + call_stack = NpuSlowAdvice(self.ASCEND_PT_DIR).get_call_stack(data, index_id=0, ts_col="Start Time(us)") + self.assertEqual(9, len(data)) + self.assertEqual("", call_stack) + + def test_run_should_return_7_data_with_call_stack_when_new_trace_view_exists(self): + self.create_profiler_info_json() + self.create_kernel_details() + self.create_new_version_trace_view() + interface = Interface(self.ASCEND_PT_DIR) + data = interface.get_data('compute', 'npu_slow') + slow_op_data = data[data["color"] == "RED"] + NpuSlowAdvice.save_to_excel(data, file_path=os.path.join(self.ASCEND_PT_DIR, "slow_op.xlsx")) + call_stack = NpuSlowAdvice(self.ASCEND_PT_DIR).get_call_stack(data, 
index_id=0, ts_col="Start Time(us)") + self.assertEqual(9, len(data)) + self.assertEqual(2, len(slow_op_data)) + print(call_stack) + call_stack_res = "/root/torch/module.py\n" \ + "/root/test/slice.py(116)" + self.assertEqual(call_stack_res, call_stack) + + def test_run_should_return_7_data_with_call_stack_when_old_trace_view_exists(self): + self.create_profiler_info_json() + self.create_kernel_details() + self.create_old_version_trace_view() + interface = Interface(self.ASCEND_PT_DIR) + data = interface.get_data('compute', 'npu_slow') + slow_op_data = data[data["color"] == "RED"] + NpuSlowAdvice.save_to_excel(data, file_path=os.path.join(self.ASCEND_PT_DIR, "slow_op.xlsx")) + call_stack = NpuSlowAdvice(self.ASCEND_PT_DIR).get_call_stack(data, index_id=0, ts_col="Start Time(us)") + self.assertEqual(9, len(data)) + self.assertEqual(2, len(slow_op_data)) + print(call_stack) + call_stack_res = "/root/test/slice.py(116)\n\r\n" \ + "/root/torch/module.py" + self.assertEqual(call_stack_res, call_stack) diff --git a/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py b/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py index 388b92ec4167821aeae03799b173ac226d4dd1d9..04468721504b1e1133b659a4d497c4ef86ed0414 100644 --- a/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py +++ b/profiler/test/ut/compare_tools/profiling_parser/test_gpu_profiling_parser.py @@ -71,6 +71,7 @@ class TestGpuProfilingParser(unittest.TestCase): res._trace_events = [TraceEventBean(event) for event in self.trace_events] res._result_data = ProfilingResult("GPU") res._compute_stream_id = 3 + res._flow_dict = {} res._marks = defaultdict(int) res._calculate_performance_time() self.assertEqual(res._result_data.overall_metrics.e2e_time, 98) diff --git a/sample/README.md b/sample/README.md index 167b1a01cbd87c75eb6a6479a39fc198360a402f..6bd55a2f83422b2f0c8424c9687a38f1698aa6fb 100644 --- a/sample/README.md +++ b/sample/README.md @@ 
-5,12 +5,61 @@ 如果考虑商用集成,推荐使用CANN软件包中的AscendC样例工程,比如:ascendc_kernel_cmake目录。本项目中的工程就是基于其进行简化仅用于快速验证。 +说明:该sample目录中,每个最小目录就是一个完整的样例工程。这些样例工程本身可能以为依赖的不同存在差异。 + ## 依赖说明 安装CANN包,并使能环境变量,并确保```ASCEND_HOME_PATH```生效,可以在CANN包安装目录下使能: ``` source set_env.sh ``` +## 目录介绍 +整体目录结构如下: +``` +- sample + |- build # 编译并运行所有样例内容(建议按需使用,此处命令可以参考 + |- normal_sample # 纯C/C++的AscendC单算子极简工程,可配合msdebug和msprof工具 + |- cube_only # 仅含aic的AscendC单算子极简工程 + |- mix # mix算子的AscendC单算子极简工程 + |- vec_only # 仅含aiv的AscendC单算子极简工程 + |- pytorch_adapter # 适配pytorch的AscendC单算子极简工程,可配合msdebug和msprof工具 + |- jit_compile # jit模式,运行时编译使用 + |- with_setuptools # 编译成wheel包安装使用 + |- sanitizer_sample # 异常样例,用于配合mssanitizer工具 + |- racecheck # 含竞争问题的样例 + |- xx # 其他异常样例 +``` + +如果你关注自定义算子的pytorch框架适配,详见[此处](./pytorch_adapter/README.md) + + +## 算子调试 msdebug +若使用msdebug进行上板调试,还需要额外调整,具体如下: +1. 编译阶段:在```sample\normal_sample\vec_only```相对路径下的```Makefile```文件中修改如下内容: + + 调试信息增强,并扩大栈空间: + ``` + COMPILER_FLAG := -xcce -O2 -std=c++17 + 修改为: + COMPILER_FLAG := -xcce -O0 -std=c++17 -g -mllvm -cce-aicore-function-stack-size=0x8000 -mllvm -cce-aicore-stack-size=0x8000 -mllvm -cce-aicore-jump-expand=true + ``` + +2. 运行阶段: +``` +msdebug ./*.fatbin +``` + +## 内存检测 sanitizer +1. 编译阶段:在编译过程中添加```--cce-enable-sanitizer -g```参数, 在链接过程中添加```--cce-enable-sanitizer```参数。(现样例中已在Makefile中添加),执行如下命令: +``` +make +``` + +2. 运行阶段: +``` +mssanitizer ./*.fatbin # 默认进行memcheck检查 +``` + + ## 算子调优 算子调优工具可以支持上板和仿真算子的调优,下面将以vec_only中的算子为例,进行工具使用的实战命令讲解 @@ -84,30 +133,3 @@ source set_env.sh └── trace.json # 算子所有核的流水图 ``` 4. 更多指标信息请参考算子开发工具使用手册。 - -## 算子调试msdebug -若使用msdebug进行上板调试,还需要额外调整,具体如下: -1. 
编译阶段:在```sample\normal_sample\vec_only```相对路径下的```Makefile```文件中修改如下内容: - + 调试信息增强,并扩大栈空间: - ``` - COMPILER_FLAG := -xcce -O2 -std=c++17 - 修改为: - COMPILER_FLAG := -xcce -O0 -std=c++17 -g -mllvm -cce-aicore-function-stack-size=0x8000 -mllvm -cce-aicore-stack-size=0x8000 -mllvm -cce-aicore-jump-expand=true - -## 内存检测 sanitizer -### sanitizer_sample目录介绍 - -此目录下为sanitizer对应的样例库,包含竞争检测和内存检测相关的样例。 - -#### Racecheck目录介绍 - -Racecheck为竞争检测相关的样例。 - -raw_error_kernel.cpp文件为UB上先读后写竞争和GM上先写后读竞争问题的样例。 - - -运行阶段: - -``` -/usr/local/Ascend/ascend-toolkit/latest/tools/mssanitizer/bin/mssanitizer --tool=racecheck ./raw_error.fatbin -``` \ No newline at end of file diff --git a/sample/pytorch_adapter/README.md b/sample/pytorch_adapter/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a2b1ba63570058ac954a121f4b14b396f5dace81 --- /dev/null +++ b/sample/pytorch_adapter/README.md @@ -0,0 +1,53 @@ +# 自定义算子的pytorch框架适配说明 + +## 简介 +昇腾提供丰富的算子接入框架的方式,此处将介绍最简单的一种,每个目录中都是一个独立的可使用的工程 + +## 依赖 +与业内pytorch的算子介入方式相同,算子接入框架需要保障设备上有正确的pytorch版本(我们还依赖torch_npu版本) + +pytorch版本可由pip安装,torch_npu版本详见[此处](https://gitee.com/ascend/pytorch/releases),请选择与pytorch适配的torch_npu版本。 + +## 工程介绍 +整体工程目录如下: +``` +- pytorch_adapter + |- jit_compile # 实时编译的接入方式 + |- add_adapter.cpp # 使用算子动态库接口完成算子在pytorch框架的适配 + |- add_kernel.cpp # 昇腾算子实现,并提供host侧的动态库接口 + |- main.py # python的入口,实现整体集成 + |- Makefile # 用以生成昇腾算子的host侧动态库的编译脚本 + |- with_setuptools # wheel包的接入方式 + |- add_adapter.cpp + |- add_kernel.cpp + |- Makefile + |- setup.py # setuptools的入口,支持编译并打包生成wheel包 + |- test.py # 测试wheel包功能的入口 +``` + +## 工程使用 + +### jit_compile工程 +执行如下命令,就会在运行过程中,现场生成python模块并使用: +``` +python main.py +``` + +### setuptools工程 +针对with_setuptools工程,可以编译出可安装的wheel包,便于多机部署使用。 + + +1. 执行如下命令可以编译出软件包(setuptools可以支持多种方式,比如:build,install等,此处不一一展示): +``` +pytorch setup.py bdist_wheel # 编译出wheel包,在dist目录下 +``` + +2. 到```dist```目录下用pip命令安装对应软件包。 + +3. 执行测试脚本 +``` +python test.py +``` + +## 其他 +1. 
此处样例使用的是静态tiling,如果使用动态tiling,则可以在adapter.cpp中对Tensor的shape进行分析,选择合适tiling。(这部分是流程中必须的,只是可能在不同位置,比如aclnn中,这部分在接口实现;此处,我们本身也可以对add_custom_do进行封装,将tiling内置。) \ No newline at end of file diff --git a/sample/pytorch_adapter/jit_compile/Makefile b/sample/pytorch_adapter/jit_compile/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..ec9115f377a578677470b89f365583dfcf246515 --- /dev/null +++ b/sample/pytorch_adapter/jit_compile/Makefile @@ -0,0 +1,20 @@ +# Location of the CANN, 主要基于${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake中内容简化 +ASCEND_HOME_PATH ?= /usr/local/Ascend/ascend-toolkit/latest + +COMPILER := $(ASCEND_HOME_PATH)/compiler/ccec_compiler/bin/ccec # 参考device_config.cmake中CMAKE_C_COMPILER配置 +COMPILER_FLAG := -xcce -O2 -std=c++17 +DYNAMIC_LIB_FLAG := -fPIC -shared +DAV_FLAG := --cce-aicore-arch=dav-c220-vec +ASCENDC_INC_FLAG := -I${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw -I${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/impl -I${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/interface -I${ASCEND_HOME_PATH}/include # 参考device_intf.cmake的配置简化 + +all: build + +build: libcustom_kernels.so + +# 后续如果要扩展,把多个kernel的cpp都加到后面 +libcustom_kernels.so: add_kernel.cpp + $(COMPILER) $(DYNAMIC_LIB_FLAG) $(COMPILER_FLAG) $(DAV_FLAG) $(ASCENDC_INC_FLAG) -o $@ $^ + +.PHONY: clean +clean: + rm *.so \ No newline at end of file diff --git a/sample/pytorch_adapter/jit_compile/add_adapter.cpp b/sample/pytorch_adapter/jit_compile/add_adapter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c65e60ec596fe8b5627e06f678549b5f2f05660 --- /dev/null +++ b/sample/pytorch_adapter/jit_compile/add_adapter.cpp @@ -0,0 +1,128 @@ +#include +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" + +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using tensor_list = std::vector; +using namespace at; + +extern "C" void add_custom_do(uint32_t blockDim, void *stream, uint8_t *x, uint8_t 
*y, uint8_t *z); + +// 为NPU设备注册前向实现 +at::Tensor my_add_impl_npu(const at::Tensor &self, const at::Tensor &other) +{ + // 创建输出内存 + at::Tensor result = at::Tensor(self); + // 将pytorch中的结构翻译成为CANN认识的数据类型和结构 + // 1. (重要)通过对tensor的shape分析,选择合适的tiling(该算子为了简化,固定了tiling,只有特定shape下计算才正确) + // 2. 对数据类型和格式转换 -- 此处无需数据格式处理,直接使用 + auto stream = c10_npu::getCurrentNPUStream().stream(false); + auto x = self.storage().data(); + auto y = other.storage().data(); + auto z = result.storage().data(); + + uint32_t blockDim = 8; + auto callback = [stream, blockDim, x, y, z]() -> int { + add_custom_do(blockDim, stream, (uint8_t *)x, (uint8_t *)y, (uint8_t *)z); + return 0; // 此处可以通过某种方式获取算子执行结果,还未实现 + }; + // 下发算子 + at_npu::native::OpCommand cmd; + cmd.Name("my_add").SetCustomHandler(callback).Run(); + return result; +} + +// 为NPU设备注册反向实现 +std::tuple my_add_backward_impl_npu(const at::Tensor &self) +{ + at::Tensor result = at::Tensor(self); // 创建输出内存 + + return {result, result}; +} + +// 为Meta设备注册前向实现 +at::Tensor my_add_impl_meta(const at::Tensor &self, const at::Tensor &other) +{ + return empty_like(self); +} + +// 为Meta设备注册反向实现 +std::tuple my_add_backward_impl_meta(const at::Tensor &self) +{ + auto result = empty_like(self); + return std::make_tuple(result, result); +} + +// 寻找注册在该op上的不同设备的实现 +at::Tensor my_add_impl(const at::Tensor &self, const at::Tensor &other) +{ + static auto op = + torch::Dispatcher::singleton().findSchemaOrThrow("myaten::my_add", "").typed(); + return op.call(self, other); +} +// 寻找注册在该op上的不同设备的实现 +std::tuple my_add_backward_impl(const at::Tensor &self) +{ + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("myaten::my_add_backward", "") + .typed(); + return op.call(self); +} + +// 在myaten命名空间里注册my_add和my_add_backward两个schema +TORCH_LIBRARY(myaten, m) +{ + m.def("my_add(Tensor self, Tensor other) -> Tensor"); + m.def("my_add_backward(Tensor self) -> (Tensor, Tensor)"); +} + +// 通过继承torch::autograd::Function类实现前反向绑定 +class MyAddFunction : 
public torch::autograd::Function { +public: + static at::Tensor forward(AutogradContext *ctx, at::Tensor self, at::Tensor other) + { + at::AutoDispatchBelowADInplaceOrView guard; + return my_add_impl(self, other); + } + + static tensor_list backward(AutogradContext *ctx, tensor_list grad_outputs) + { + auto grad_output = grad_outputs[0]; + auto result = my_add_backward_impl(grad_output); + return {std::get<0>(result), std::get<1>(result)}; + } +}; + +at::Tensor my_add_impl_autograd(const at::Tensor &self, const at::Tensor &other) +{ + return MyAddFunction::apply(self, other); +} + +// 给op绑定NPU的自动求导实现 +// 如果是pytorch 2.1以下的版本,AutogradPrivateUse1需要改成AutogradXLA +TORCH_LIBRARY_IMPL(myaten, AutogradPrivateUse1, m) +{ + m.impl("my_add", &my_add_impl_autograd); +} + +// 为NPU设备注册前反向实现 +// NPU设备在pytorch 2.1及以上版本使用的设备名称是PrivateUse1,在2.1以下版本用的是XLA,如果是2.1以下版本PrivateUse1需要改成XLA +TORCH_LIBRARY_IMPL(myaten, PrivateUse1, m) +{ + m.impl("my_add", &my_add_impl_npu); + m.impl("my_add_backward", &my_add_backward_impl_npu); +} + +// 为Meta设备注册前反向实现 +TORCH_LIBRARY_IMPL(myaten, Meta, m) +{ + m.impl("my_add", &my_add_impl_meta); + m.impl("my_add_backward", &my_add_backward_impl_meta); +} + +// 通过pybind将c++接口和python接口绑定 +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("my_add", &my_add_impl_autograd, "x + y"); +} diff --git a/sample/pytorch_adapter/jit_compile/add_kernel.cpp b/sample/pytorch_adapter/jit_compile/add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9aa62e093633de1f5bddc8d9b7f80fb58831bdb9 --- /dev/null +++ b/sample/pytorch_adapter/jit_compile/add_kernel.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * + * Function : z = x + y + * This sample is a very basic sample that implements vector add on Ascend plaform. + * In this sample: + * Length of x / y / z is 8*2048. + * Num of vector core used in sample is 8. + * Length for each core to compute is 2048. 
+ * Tiles for each core is 8 which means we add 2048/8=256 elements in one loop. + * + */ +#include "kernel_operator.h" +using namespace AscendC; +constexpr int32_t TOTAL_LENGTH = 8 * 2048; // total length of data +constexpr int32_t USE_CORE_NUM = 8; // num of core used +constexpr int32_t BLOCK_LENGTH = TOTAL_LENGTH / USE_CORE_NUM; // length computed of each core +constexpr int32_t TILE_NUM = 8; // split data into 8 tiles for each core +constexpr int32_t BUFFER_NUM = 2; // tensor num for each queue +constexpr int32_t TILE_LENGTH = BLOCK_LENGTH / TILE_NUM / BUFFER_NUM; // seperate to 2 parts, due to double buffer + +class KernelAdd { +public: + __aicore__ inline KernelAdd() + {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z) + { + // get start index for current core, core parallel + xGm.SetGlobalBuffer((__gm__ half *)x + BLOCK_LENGTH * GetBlockIdx(), BLOCK_LENGTH); + yGm.SetGlobalBuffer((__gm__ half *)y + BLOCK_LENGTH * GetBlockIdx(), BLOCK_LENGTH); + zGm.SetGlobalBuffer((__gm__ half *)z + BLOCK_LENGTH * GetBlockIdx(), BLOCK_LENGTH); + // pipe alloc memory to queue, the unit is Bytes + pipe.InitBuffer(inQueueX, BUFFER_NUM, TILE_LENGTH * sizeof(half)); + pipe.InitBuffer(inQueueY, BUFFER_NUM, TILE_LENGTH * sizeof(half)); + pipe.InitBuffer(outQueueZ, BUFFER_NUM, TILE_LENGTH * sizeof(half)); + } + __aicore__ inline void Process() + { + // loop count need to be doubled, due to double buffer + constexpr int32_t loopCount = TILE_NUM * BUFFER_NUM; + // tiling strategy, pipeline parallel + for (int32_t i = 0; i < loopCount; i++) { + CopyIn(i); + Compute(i); + CopyOut(i); + } + } + +private: + __aicore__ inline void CopyIn(int32_t progress) + { + // alloc tensor from queue memory + LocalTensor xLocal = inQueueX.AllocTensor(); + LocalTensor yLocal = inQueueY.AllocTensor(); + // copy progress_th tile from global tensor to local tensor + DataCopy(xLocal, xGm[progress * TILE_LENGTH], TILE_LENGTH); + DataCopy(yLocal, yGm[progress * TILE_LENGTH], TILE_LENGTH); + // 
enque input tensors to VECIN queue + inQueueX.EnQue(xLocal); + inQueueY.EnQue(yLocal); + } + __aicore__ inline void Compute(int32_t progress) + { + // deque input tensors from VECIN queue + LocalTensor xLocal = inQueueX.DeQue(); + LocalTensor yLocal = inQueueY.DeQue(); + LocalTensor zLocal = outQueueZ.AllocTensor(); + // call Add instr for computation + Add(zLocal, xLocal, yLocal, TILE_LENGTH); + // enque the output tensor to VECOUT queue + outQueueZ.EnQue(zLocal); + // free input tensors for reuse + inQueueX.FreeTensor(xLocal); + inQueueY.FreeTensor(yLocal); + } + __aicore__ inline void CopyOut(int32_t progress) + { + // deque output tensor from VECOUT queue + LocalTensor zLocal = outQueueZ.DeQue(); + // copy progress_th tile from local tensor to global tensor + DataCopy(zGm[progress * TILE_LENGTH], zLocal, TILE_LENGTH); + // free output tensor for reuse + outQueueZ.FreeTensor(zLocal); + } + +private: + TPipe pipe; + // create queues for input, in this case depth is equal to buffer num + TQue inQueueX, inQueueY; + // create queue for output, in this case depth is equal to buffer num + TQue outQueueZ; + GlobalTensor xGm, yGm, zGm; +}; +// implementation of kernel function +extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z) +{ + KernelAdd op; + op.Init(x, y, z); + op.Process(); +} + +// 包裹核函数,使得普通编译器能认识这个符号 +extern "C" void add_custom_do(uint32_t blockDim, void *stream, uint8_t *x, uint8_t *y, uint8_t *z) +{ + add_custom<<>>(x, y, z); +} \ No newline at end of file diff --git a/sample/pytorch_adapter/jit_compile/main.py b/sample/pytorch_adapter/jit_compile/main.py new file mode 100644 index 0000000000000000000000000000000000000000..847a51f1c4787dcf353759d1115f352c1c760353 --- /dev/null +++ b/sample/pytorch_adapter/jit_compile/main.py @@ -0,0 +1,70 @@ +import os +import subprocess +import torch +import torch_npu +import torch.utils.cpp_extension +from torch_npu.testing.testcase import TestCase, run_tests + +PYTORCH_NPU_INSTALL_PATH = 
os.path.dirname(os.path.abspath(torch_npu.__file__)) +CUR_PATH = os.path.abspath(os.path.dirname(__file__)) + + +def compile_kernels(): + # 由于pytorch中没有昇腾device编译的扩展,所以此处人工加make + subprocess.run("make") + + +def compile_host(): + extra_ldflags = [] + extra_ldflags.append(f"-L{PYTORCH_NPU_INSTALL_PATH}/lib") + extra_ldflags.append("-ltorch_npu") + extra_ldflags.append(f"-L{CUR_PATH}/") + extra_ldflags.append("-lcustom_kernels") + extra_include_paths = [] + extra_include_paths.append("./") + extra_include_paths.append(os.path.join( + PYTORCH_NPU_INSTALL_PATH, "include")) + extra_include_paths.append(os.path.join(os.path.join(os.path.join(os.path.join( + PYTORCH_NPU_INSTALL_PATH, "include"), "third_party"), "acl"), "inc")) + + module = torch.utils.cpp_extension.load( + name="jit_extension", + sources=[ + "add_adapter.cpp" + ], + extra_include_paths=extra_include_paths, + extra_ldflags=extra_ldflags, + verbose=True) + return module + + +class TestCustomAdd(TestCase): + def test_add(self): + module = compile_host() + # 由于kernel现在是静态tiling,所以此处尺寸需要匹配 + # 因为add是elementwise的,现有算子支持8*2048(详见kernel实现),所以,小于这个应该都可以 + length = [8, 2048] + x = torch.rand(length, device='cpu', dtype=torch.float16) + y = torch.rand(length, device='cpu', dtype=torch.float16) + + x_npu = x.npu() + y_npu = y.npu() + x_npu.requires_grad = True + y_npu.requires_grad = True + output = module.my_add(x_npu, y_npu) + # 反向能力验证 + output.backward(output) + + x.requires_grad = True + y.requires_grad = True + cpuout = torch.add(x, y) + cpuout.backward(cpuout) + + self.assertRtolEqual(output, cpuout) + self.assertRtolEqual(x_npu.grad, x.grad) + self.assertRtolEqual(y_npu.grad, y.grad) + + +if __name__ == '__main__': + compile_kernels() + run_tests() diff --git a/sample/pytorch_adapter/with_setuptools/Makefile b/sample/pytorch_adapter/with_setuptools/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..ec9115f377a578677470b89f365583dfcf246515 --- /dev/null +++ 
b/sample/pytorch_adapter/with_setuptools/Makefile @@ -0,0 +1,20 @@ +# Location of the CANN, 主要基于${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake中内容简化 +ASCEND_HOME_PATH ?= /usr/local/Ascend/ascend-toolkit/latest + +COMPILER := $(ASCEND_HOME_PATH)/compiler/ccec_compiler/bin/ccec # 参考device_config.cmake中CMAKE_C_COMPILER配置 +COMPILER_FLAG := -xcce -O2 -std=c++17 +DYNAMIC_LIB_FLAG := -fPIC -shared +DAV_FLAG := --cce-aicore-arch=dav-c220-vec +ASCENDC_INC_FLAG := -I${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw -I${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/impl -I${ASCEND_HOME_PATH}/compiler/tikcpp/tikcfw/interface -I${ASCEND_HOME_PATH}/include # 参考device_intf.cmake的配置简化 + +all: build + +build: libcustom_kernels.so + +# 后续如果要扩展,把多个kernel的cpp都加到后面 +libcustom_kernels.so: add_kernel.cpp + $(COMPILER) $(DYNAMIC_LIB_FLAG) $(COMPILER_FLAG) $(DAV_FLAG) $(ASCENDC_INC_FLAG) -o $@ $^ + +.PHONY: clean +clean: + rm *.so \ No newline at end of file diff --git a/sample/pytorch_adapter/with_setuptools/add_adapter.cpp b/sample/pytorch_adapter/with_setuptools/add_adapter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6c65e60ec596fe8b5627e06f678549b5f2f05660 --- /dev/null +++ b/sample/pytorch_adapter/with_setuptools/add_adapter.cpp @@ -0,0 +1,128 @@ +#include +#include "torch_npu/csrc/core/npu/NPUStream.h" +#include "torch_npu/csrc/framework/OpCommand.h" + +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using tensor_list = std::vector; +using namespace at; + +extern "C" void add_custom_do(uint32_t blockDim, void *stream, uint8_t *x, uint8_t *y, uint8_t *z); + +// 为NPU设备注册前向实现 +at::Tensor my_add_impl_npu(const at::Tensor &self, const at::Tensor &other) +{ + // 创建输出内存 + at::Tensor result = at::Tensor(self); + // 将pytorch中的结构翻译成为CANN认识的数据类型和结构 + // 1. (重要)通过对tensor的shape分析,选择合适的tiling(该算子为了简化,固定了tiling,只有特定shape下计算才正确) + // 2. 
对数据类型和格式转换 -- 此处无需数据格式处理,直接使用 + auto stream = c10_npu::getCurrentNPUStream().stream(false); + auto x = self.storage().data(); + auto y = other.storage().data(); + auto z = result.storage().data(); + + uint32_t blockDim = 8; + auto callback = [stream, blockDim, x, y, z]() -> int { + add_custom_do(blockDim, stream, (uint8_t *)x, (uint8_t *)y, (uint8_t *)z); + return 0; // 此处可以通过某种方式获取算子执行结果,还未实现 + }; + // 下发算子 + at_npu::native::OpCommand cmd; + cmd.Name("my_add").SetCustomHandler(callback).Run(); + return result; +} + +// 为NPU设备注册反向实现 +std::tuple my_add_backward_impl_npu(const at::Tensor &self) +{ + at::Tensor result = at::Tensor(self); // 创建输出内存 + + return {result, result}; +} + +// 为Meta设备注册前向实现 +at::Tensor my_add_impl_meta(const at::Tensor &self, const at::Tensor &other) +{ + return empty_like(self); +} + +// 为Meta设备注册反向实现 +std::tuple my_add_backward_impl_meta(const at::Tensor &self) +{ + auto result = empty_like(self); + return std::make_tuple(result, result); +} + +// 寻找注册在该op上的不同设备的实现 +at::Tensor my_add_impl(const at::Tensor &self, const at::Tensor &other) +{ + static auto op = + torch::Dispatcher::singleton().findSchemaOrThrow("myaten::my_add", "").typed(); + return op.call(self, other); +} +// 寻找注册在该op上的不同设备的实现 +std::tuple my_add_backward_impl(const at::Tensor &self) +{ + static auto op = torch::Dispatcher::singleton() + .findSchemaOrThrow("myaten::my_add_backward", "") + .typed(); + return op.call(self); +} + +// 在myaten命名空间里注册my_add和my_add_backward两个schema +TORCH_LIBRARY(myaten, m) +{ + m.def("my_add(Tensor self, Tensor other) -> Tensor"); + m.def("my_add_backward(Tensor self) -> (Tensor, Tensor)"); +} + +// 通过继承torch::autograd::Function类实现前反向绑定 +class MyAddFunction : public torch::autograd::Function { +public: + static at::Tensor forward(AutogradContext *ctx, at::Tensor self, at::Tensor other) + { + at::AutoDispatchBelowADInplaceOrView guard; + return my_add_impl(self, other); + } + + static tensor_list backward(AutogradContext *ctx, tensor_list 
grad_outputs) + { + auto grad_output = grad_outputs[0]; + auto result = my_add_backward_impl(grad_output); + return {std::get<0>(result), std::get<1>(result)}; + } +}; + +at::Tensor my_add_impl_autograd(const at::Tensor &self, const at::Tensor &other) +{ + return MyAddFunction::apply(self, other); +} + +// 给op绑定NPU的自动求导实现 +// 如果是pytorch 2.1以下的版本,AutogradPrivateUse1需要改成AutogradXLA +TORCH_LIBRARY_IMPL(myaten, AutogradPrivateUse1, m) +{ + m.impl("my_add", &my_add_impl_autograd); +} + +// 为NPU设备注册前反向实现 +// NPU设备在pytorch 2.1及以上版本使用的设备名称是PrivateUse1,在2.1以下版本用的是XLA,如果是2.1以下版本PrivateUse1需要改成XLA +TORCH_LIBRARY_IMPL(myaten, PrivateUse1, m) +{ + m.impl("my_add", &my_add_impl_npu); + m.impl("my_add_backward", &my_add_backward_impl_npu); +} + +// 为Meta设备注册前反向实现 +TORCH_LIBRARY_IMPL(myaten, Meta, m) +{ + m.impl("my_add", &my_add_impl_meta); + m.impl("my_add_backward", &my_add_backward_impl_meta); +} + +// 通过pybind将c++接口和python接口绑定 +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("my_add", &my_add_impl_autograd, "x + y"); +} diff --git a/sample/pytorch_adapter/with_setuptools/add_kernel.cpp b/sample/pytorch_adapter/with_setuptools/add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9aa62e093633de1f5bddc8d9b7f80fb58831bdb9 --- /dev/null +++ b/sample/pytorch_adapter/with_setuptools/add_kernel.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * + * Function : z = x + y + * This sample is a very basic sample that implements vector add on Ascend plaform. + * In this sample: + * Length of x / y / z is 8*2048. + * Num of vector core used in sample is 8. + * Length for each core to compute is 2048. + * Tiles for each core is 8 which means we add 2048/8=256 elements in one loop. 
+ * + */ +#include "kernel_operator.h" +using namespace AscendC; +constexpr int32_t TOTAL_LENGTH = 8 * 2048; // total length of data +constexpr int32_t USE_CORE_NUM = 8; // num of core used +constexpr int32_t BLOCK_LENGTH = TOTAL_LENGTH / USE_CORE_NUM; // length computed of each core +constexpr int32_t TILE_NUM = 8; // split data into 8 tiles for each core +constexpr int32_t BUFFER_NUM = 2; // tensor num for each queue +constexpr int32_t TILE_LENGTH = BLOCK_LENGTH / TILE_NUM / BUFFER_NUM; // seperate to 2 parts, due to double buffer + +class KernelAdd { +public: + __aicore__ inline KernelAdd() + {} + __aicore__ inline void Init(GM_ADDR x, GM_ADDR y, GM_ADDR z) + { + // get start index for current core, core parallel + xGm.SetGlobalBuffer((__gm__ half *)x + BLOCK_LENGTH * GetBlockIdx(), BLOCK_LENGTH); + yGm.SetGlobalBuffer((__gm__ half *)y + BLOCK_LENGTH * GetBlockIdx(), BLOCK_LENGTH); + zGm.SetGlobalBuffer((__gm__ half *)z + BLOCK_LENGTH * GetBlockIdx(), BLOCK_LENGTH); + // pipe alloc memory to queue, the unit is Bytes + pipe.InitBuffer(inQueueX, BUFFER_NUM, TILE_LENGTH * sizeof(half)); + pipe.InitBuffer(inQueueY, BUFFER_NUM, TILE_LENGTH * sizeof(half)); + pipe.InitBuffer(outQueueZ, BUFFER_NUM, TILE_LENGTH * sizeof(half)); + } + __aicore__ inline void Process() + { + // loop count need to be doubled, due to double buffer + constexpr int32_t loopCount = TILE_NUM * BUFFER_NUM; + // tiling strategy, pipeline parallel + for (int32_t i = 0; i < loopCount; i++) { + CopyIn(i); + Compute(i); + CopyOut(i); + } + } + +private: + __aicore__ inline void CopyIn(int32_t progress) + { + // alloc tensor from queue memory + LocalTensor xLocal = inQueueX.AllocTensor(); + LocalTensor yLocal = inQueueY.AllocTensor(); + // copy progress_th tile from global tensor to local tensor + DataCopy(xLocal, xGm[progress * TILE_LENGTH], TILE_LENGTH); + DataCopy(yLocal, yGm[progress * TILE_LENGTH], TILE_LENGTH); + // enque input tensors to VECIN queue + inQueueX.EnQue(xLocal); + 
inQueueY.EnQue(yLocal); + } + __aicore__ inline void Compute(int32_t progress) + { + // deque input tensors from VECIN queue + LocalTensor xLocal = inQueueX.DeQue(); + LocalTensor yLocal = inQueueY.DeQue(); + LocalTensor zLocal = outQueueZ.AllocTensor(); + // call Add instr for computation + Add(zLocal, xLocal, yLocal, TILE_LENGTH); + // enque the output tensor to VECOUT queue + outQueueZ.EnQue(zLocal); + // free input tensors for reuse + inQueueX.FreeTensor(xLocal); + inQueueY.FreeTensor(yLocal); + } + __aicore__ inline void CopyOut(int32_t progress) + { + // deque output tensor from VECOUT queue + LocalTensor zLocal = outQueueZ.DeQue(); + // copy progress_th tile from local tensor to global tensor + DataCopy(zGm[progress * TILE_LENGTH], zLocal, TILE_LENGTH); + // free output tensor for reuse + outQueueZ.FreeTensor(zLocal); + } + +private: + TPipe pipe; + // create queues for input, in this case depth is equal to buffer num + TQue inQueueX, inQueueY; + // create queue for output, in this case depth is equal to buffer num + TQue outQueueZ; + GlobalTensor xGm, yGm, zGm; +}; +// implementation of kernel function +extern "C" __global__ __aicore__ void add_custom(GM_ADDR x, GM_ADDR y, GM_ADDR z) +{ + KernelAdd op; + op.Init(x, y, z); + op.Process(); +} + +// 包裹核函数,使得普通编译器能认识这个符号 +extern "C" void add_custom_do(uint32_t blockDim, void *stream, uint8_t *x, uint8_t *y, uint8_t *z) +{ + add_custom<<>>(x, y, z); +} \ No newline at end of file diff --git a/sample/pytorch_adapter/with_setuptools/setup.py b/sample/pytorch_adapter/with_setuptools/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..92ab1d3c78c7866b4bd53d9531bf0674c8b2987e --- /dev/null +++ b/sample/pytorch_adapter/with_setuptools/setup.py @@ -0,0 +1,51 @@ +import os +import subprocess +import torch +import torch_npu +from setuptools import setup, find_packages +from torch.utils.cpp_extension import BuildExtension +from torch_npu.utils.cpp_extension import NpuExtension + 
+PYTORCH_NPU_INSTALL_PATH = os.path.dirname(os.path.abspath(torch_npu.__file__)) +CUR_PATH = os.path.abspath(os.path.dirname(__file__)) + + +def compile_kernels(): + # 由于pytorch中没有昇腾device编译的扩展,所以此处人工加make + subprocess.run("make") + return "libcustom_kernels.so" # 这个make出来的库名字 + + +def compile_adapter(): + ext = NpuExtension( + name="ascend_custom_kernels_lib", # import的库的名字 + # 如果还有其他cpp文件参与编译,需要在这里添加 + sources=[f"{CUR_PATH}/add_adapter.cpp"], + extra_compile_args=[ + '-I' + os.path.join(os.path.join(os.path.join(os.path.join( + PYTORCH_NPU_INSTALL_PATH, "include"), "third_party"), "acl"), "inc"), + ], + library_dirs=[f"{CUR_PATH}"], # 编译时需要依赖的库文件的路径,相当于g++编译时的-L选项 + libraries=["custom_kernels"], # 编译时依赖的库文件,相当于-l选项 + ) + return [ext] + + +if __name__ == "__main__": + # 编译出含有算子的库,并以so的方式提供 + kernel_so = compile_kernels() + + # 编译出pytorch适配层的库,支持被框架集成 + exts = compile_adapter() + + # 将整体打包成wheel包 + setup( + name="ascend_custom_kernels", # package的名字 + version='1.0', + keywords='ascend_custom_kernels', + ext_modules=exts, + packages=find_packages(), + cmdclass={"build_ext": BuildExtension}, + data_files=[(".", [kernel_so])], + include_package_data=True, + ) diff --git a/sample/pytorch_adapter/with_setuptools/test.py b/sample/pytorch_adapter/with_setuptools/test.py new file mode 100644 index 0000000000000000000000000000000000000000..896eef2c0fbb1a113377fb7dc770f45fd99832f4 --- /dev/null +++ b/sample/pytorch_adapter/with_setuptools/test.py @@ -0,0 +1,34 @@ +import torch +import torch_npu +import ascend_custom_kernels_lib +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestCustomAdd(TestCase): + def test_add(self): + # 由于kernel现在是静态tiling,所以此处尺寸需要匹配 + # 因为add是elementwise的,现有算子支持8*2048(详见kernel实现),所以,小于这个应该都可以 + length = [8, 2048] + x = torch.rand(length, device='cpu', dtype=torch.float16) + y = torch.rand(length, device='cpu', dtype=torch.float16) + + x_npu = x.npu() + y_npu = y.npu() + x_npu.requires_grad = True + y_npu.requires_grad = True + 
output = ascend_custom_kernels_lib.my_add(x_npu, y_npu) + # 反向能力验证 + output.backward(output) + + x.requires_grad = True + y.requires_grad = True + cpuout = torch.add(x, y) + cpuout.backward(cpuout) + + self.assertRtolEqual(output, cpuout) + self.assertRtolEqual(x_npu.grad, x.grad) + self.assertRtolEqual(y_npu.grad, y.grad) + + +if __name__ == "__main__": + run_tests() diff --git a/sample/third_party/lib/libruntime.so.aarch64 b/sample/third_party/lib/libruntime.so.aarch64 deleted file mode 100644 index 2c686dc3e0ab56768ec8c45cfac9f1fbb107888f..0000000000000000000000000000000000000000 Binary files a/sample/third_party/lib/libruntime.so.aarch64 and /dev/null differ diff --git a/sample/third_party/lib/libruntime.so.x86 b/sample/third_party/lib/libruntime.so.x86 deleted file mode 100644 index 6da21687dc7655cc6745003cfcbb6c3c0a8ceb34..0000000000000000000000000000000000000000 Binary files a/sample/third_party/lib/libruntime.so.x86 and /dev/null differ diff --git a/sample/third_party/lib/libruntime_camodel.so.aarch64 b/sample/third_party/lib/libruntime_camodel.so.aarch64 deleted file mode 100644 index 2c686dc3e0ab56768ec8c45cfac9f1fbb107888f..0000000000000000000000000000000000000000 Binary files a/sample/third_party/lib/libruntime_camodel.so.aarch64 and /dev/null differ diff --git a/sample/third_party/lib/libruntime_camodel.so.x86 b/sample/third_party/lib/libruntime_camodel.so.x86 deleted file mode 100644 index 6da21687dc7655cc6745003cfcbb6c3c0a8ceb34..0000000000000000000000000000000000000000 Binary files a/sample/third_party/lib/libruntime_camodel.so.x86 and /dev/null differ