From 017c5a54c07dd2862f16b20c7b66a3486f3db486 Mon Sep 17 00:00:00 2001 From: kai-ma Date: Fri, 4 Jul 2025 15:40:23 +0800 Subject: [PATCH] add utils for depen,env,excep --- accuracy_tools/msprobe/utils/__init__.py | 13 ++ accuracy_tools/msprobe/utils/constants.py | 201 +++++++++++++++++++ accuracy_tools/msprobe/utils/dependencies.py | 96 +++++++++ accuracy_tools/msprobe/utils/env.py | 80 ++++++++ accuracy_tools/msprobe/utils/exceptions.py | 22 ++ 5 files changed, 412 insertions(+) create mode 100644 accuracy_tools/msprobe/utils/__init__.py create mode 100644 accuracy_tools/msprobe/utils/constants.py create mode 100644 accuracy_tools/msprobe/utils/dependencies.py create mode 100644 accuracy_tools/msprobe/utils/env.py create mode 100644 accuracy_tools/msprobe/utils/exceptions.py diff --git a/accuracy_tools/msprobe/utils/__init__.py b/accuracy_tools/msprobe/utils/__init__.py new file mode 100644 index 00000000000..53529bc8d31 --- /dev/null +++ b/accuracy_tools/msprobe/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accuracy_tools/msprobe/utils/constants.py b/accuracy_tools/msprobe/utils/constants.py new file mode 100644 index 00000000000..6872a8999b5 --- /dev/null +++ b/accuracy_tools/msprobe/utils/constants.py @@ -0,0 +1,201 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class CmdConst:
    """
    Command-line constants: sub-command names and their help texts.
    """

    DUMP = "dump"
    COMPARE = "compare"

    # Help text keyed by the sub-command names defined above.
    HELP_SERVICE_MAP = {DUMP: "Data collection for Ascend device.", COMPARE: "Accuracy compare for dump task."}
    # Empty in this module; NOTE(review): presumably filled in elsewhere -- confirm against callers.
    HELP_TASK_MAP = {}


class PathConst:
    """
    File/directory path constants: path kinds, size constants and file suffix groups.
    """

    FILE = "file"
    DIR = "dir"

    # Each value equals the product shown in its trailing comment
    # (presumably a byte count -- confirm against the code that consumes them).
    SIZE_20M = 20_971_520  # 20 * 1024 * 1024
    SIZE_500M = 524_288_000  # 500 * 1024 * 1024
    SIZE_2G = 2_147_483_648  # 2 * 1024 * 1024 * 1024
    SIZE_4G = 4_294_967_296  # 4 * 1024 * 1024 * 1024
    SIZE_10G = 10_737_418_240  # 10 * 1024 * 1024 * 1024
    SIZE_30G = 32_212_254_720  # 30 * 1024 * 1024 * 1024
    SIZE_50G = 53_687_091_200  # 50 * 1024 * 1024 * 1024

    SUFFIX_ONLINE_SCRIPT = (".py", ".sh")
    SUFFIX_OFFLINE_MODEL = (".pb", ".onnx", ".om", ".prototxt")


class MsgConst:
    """
    Log/error message constants, used as the leading "error group" text of MsprobeException.
    """

    INVALID_ARGU = "[ERROR] invalid argument."
    INVALID_DATA_TYPE = "[ERROR] invalid data type."
    REQUIRED_ARGU_MISSING = "[ERROR] Required argument missing."
    RISK_ALERT = "[ERROR] Risk alert."
    NO_PERMISSION = "[ERROR] No permission."
    IO_FAILURE = "[ERROR] I/O failure."
    PATH_NOT_FOUND = "[ERROR] Path not found."
    VALUE_NOT_FOUND = "[ERROR] Value not found."
    PARSING_FAILED = "[ERROR] Parsing failed."
    CANN_FAILED = "[ERROR] CANN enabling failed."
    ATTRIBUTE_ERROR = "[ERROR] Attribute not found."
    CALL_FAILED = "[ERROR] Call failed."
    CONVERSION_FAILED = "[ERROR] Conversion failed."
    # Not a message: a numeric limit that happens to live in this class.
    MAX_RECURSION_DEPTH = 5


class CompConst:
    """
    Component name constants.
    """

    DUMP_WRITER_COMP = "DumpWriterComp"
    ACL_DUMPER_COMP = "ACLDumperComp"
    ONNX_ACTUATOR_COMP = "OnnxActuatorComp"
    ONNX_DUMPER_COMP = "OnnxDumperComp"
    FROZEN_GRAPH_ACTUATOR_COMP_CPU = "FrozenGraphActuatorCompCPU"
    FROZEN_GRAPH_DUMPER_COMP_CPU = "FrozenGraphDumperCompCPU"
    FROZEN_GRAPH_ACTUATOR_COMP_NPU = "FrozenGraphActuatorCompNPU"
    FROZEN_GRAPH_SET_GE_COMP_NPU = "FrozenGraphSetGECompNPU"
    CAFFE_ACTUATOR_COMP = "CaffeActuatorComp"
    CAFFE_DUMPER_COMP = "CaffeDumperComp"
    ATB_ACTUATOR_COMP = "ATBActuatorComp"
    OM_ACTUATOR_COMP = "OmActuatorComp"
    ACL_COMPATIBLE_COMP = "ACLCompatibleComp"


class CfgConst:
    """
    Configuration item names and the sets of values accepted for some of them.
    """

    CONFIG_PATH = "config_path"
    TASK = "task"
    TASK_STAT = "statistics"
    TASK_TENSOR = "tensor"
    ALL_TASK = {TASK_STAT, TASK_TENSOR}
    EXEC = "exec"
    FRAMEWORK = "framework"
    FRAMEWORK_MINDIE_LLM = "mindie_llm"
    FRAMEWORK_TORCH_AIR = "torch_air"
    FRAMEWORK_MINDIE_TORCH = "mindie_torch"
    FRAMEWORK_PT = "pytorch"
    FRAMEWORK_MS = "mindspore"
    FRAMEWORK_ONNX = "ONNX"
    FRAMEWORK_TF = "TensorFlow"
    FRAMEWORK_OM = "Ascend OM"
    FRAMEWORK_CAFFE = "Caffe"
    # NOTE(review): ONNX / TensorFlow / Ascend OM / Caffe are not members of this set;
    # looks deliberate, but confirm before using it for validation.
    ALL_FRAMEWORK = {FRAMEWORK_MINDIE_LLM, FRAMEWORK_TORCH_AIR, FRAMEWORK_MINDIE_TORCH, FRAMEWORK_PT, FRAMEWORK_MS}
    RANK = "rank"
    STEP = "step"
    LEVEL = "level"
    LEVEL_MODULE = "L0"
    LEVEL_API = "L1"
    LEVEL_KERNEL = "L2"
    ALL_LEVEL = {LEVEL_MODULE, LEVEL_API, LEVEL_KERNEL}
    LOG_LEVEL = "log_level"
    SEED = "seed"
    BUFFER_SIZE = "buffer_size"


class DumpConst:
    """
    Dump constants: config keys, accepted values, output file names and environment variable names.
    """

    DEVICE = "device"
    INPUT_ARGS = "input_args"
    OUTPUT_ARGS = "output_args"
    INPUT = "input"
    OUTPUT = "output"
    INPUT_ALL = [INPUT, "all"]
    OUTPUT_ALL = [OUTPUT, "all"]
    ALL_DATA_MODE = [INPUT, OUTPUT, "all"]

    DUMP_PATH = "dump_path"
    LIST = "list"
    DATA_MODE = "data_mode"
    SUMMARY_MODE = "summary_mode"
    SUMMARY_MD5 = "md5"
    ALL_SUMMARY_MODE = {CfgConst.TASK_STAT, SUMMARY_MD5}
    DUMP_EXTRA = "dump_extra"
    ALL_DUMP_EXTRA = {"tiling", "cpu_profiling", "kernel_info", "op_info"}
    OP_ID = "op_id"
    DUMP_LAST_LOGITS = "dump_last_logits"
    DUMP_WEIGHT = "dump_weight"
    DUMP_GE_GRAPH = "dump_ge_graph"
    ALL_DUMP_GE_GRAPH = {"1", "2", "3"}
    DUMP_GRAPH_LEVEL = "dump_graph_level"
    ALL_DUMP_GRAPH_LEVEL = {"1", "2", "3", "4"}
    FUSION_SWITCH_FILE = "fusion_switch_file"
    # Original mixed-case spelling, kept so existing references keep working.
    ONNX_FUSION_switch = "onnx_fusion_switch"
    # UPPER_SNAKE_CASE alias, consistent with every other constant in this class.
    ONNX_FUSION_SWITCH = ONNX_FUSION_switch
    SAVED_MODEL_TAG = "saved_model_tag"
    SAVED_MODEL_SIGN = "saved_model_signature"
    WEIGHT_PATH = "weight_path"

    DUMP_DATA_DIR = "dump_data_dir"
    DATA = "data"
    DUMP_JSON = "dump.json"
    STACK_JSON = "stack.json"
    NPY_FORMAT = "npy_format"
    BIN_FORMAT = "bin_format"
    NET_OUTPUT_NODES_JSON = "net_output_nodes.json"

    ENVVAR_DUMP_GE_GRAPH = "DUMP_GE_GRAPH"
    ENVVAR_DUMP_GRAPH_LEVEL = "DUMP_GRAPH_LEVEL"
    ENVVAR_DUMP_GRAPH_PATH = "DUMP_GRAPH_PATH"
    ENVVAR_ASCEND_WORK_PATH = "ASCEND_WORK_PATH"

    ENVVAR_LINK_DUMP_PATH = "LINK_DUMP_PATH"
    ENVVAR_LINK_DUMP_TASK = "LINK_DUMP_TASK"
    ENVVAR_LINK_DUMP_LEVEL = "LINK_DUMP_LEVEL"
    ENVVAR_LINK_STEP = "LINK_STEP"
    ENVVAR_LINK_RANK = "LINK_RANK"
    ENVVAR_LINK_LOG_LEVEL = "LINK_LOG_LEVEL"
    ENVVAR_LINK_SUMMARY_MODE = "LINK_SUMMARY_MODE"
    ENVVAR_LINK_BUFFER_SIZE = "LINK_BUFFER_SIZE"
    ENVVAR_LINK_DATA_MODE = "LINK_DATA_MODE"
    ENVVAR_LINK_SAVE_TILING = "LINK_SAVE_TILING"
    ENVVAR_LINK_SAVE_CPU_PROFILING = "LINK_SAVE_CPU_PROFILING"
    ENVVAR_LINK_SAVE_ONNX = "LINK_SAVE_ONNX"
    ENVVAR_LINK_SAVE_KERNEL_INFO = "LINK_SAVE_KERNEL_INFO"
    ENVVAR_LINK_SAVE_OP_INFO = "LINK_SAVE_OP_INFO"
    ENVVAR_LINK_SAVE_PARAM = "LINK_SAVE_PARAM"
    ENVVAR_LINK_SAVE_TENSOR_IDS = "LINK_SAVE_TENSOR_IDS"
    ENVVAR_LINK_SAVE_TENSOR_RUNNER = "LINK_SAVE_TENSOR_RUNNER"


class ACLConst:
    """
    AscendCL constants.
    """

    SUCCESS = 0
    MEMCPY_HOST_TO_DEVICE = 1
    MEMCPY_DEVICE_TO_HOST = 2
    IS_LAST_CHUNK = "is_last_chunk"
    BUF_LEN = "buf_len"
    FILE_NAME = "file_name"
    DATA_BUF = "data_buf"


# ---------------------------------------------------------------------------
# Patch file boundary: accuracy_tools/msprobe/utils/dependencies.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import wraps
from importlib import import_module

from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger

# Names for which an import-failure warning has already been logged (warn once per name).
import_warnings_shown = set()


def safely_import(func):
    """
    Decorate an import helper so that any failure yields ``None`` instead of an exception.

    The wrapped callable is expected to be a method taking the package name as its
    first argument after ``self``. On failure a warning is logged once per package name.

    Args:
        func: The import helper to wrap.

    Returns:
        A wrapper returning ``func``'s result, or ``None`` if ``func`` raised.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as err:
            # The name may arrive positionally or by keyword; never let the lookup
            # itself raise from inside the handler.
            if len(args) > 1:
                dependency = args[1]
            else:
                dependency = kwargs.get("package_name", func.__name__)
            if dependency not in import_warnings_shown:
                if isinstance(err, ImportError):
                    logger.warning(f"{dependency} is not installed. Please install it if needed.")
                else:
                    # e.g. an unsupported version: the package exists, so say what went wrong.
                    logger.warning(f"Failed to import {dependency}: {err}")
                import_warnings_shown.add(dependency)
            return None

    return wrapper


def temporary_tf_log_level(func):
    """
    Decorate ``func`` so it runs with ``TF_CPP_MIN_LOG_LEVEL`` set to ``"2"``.

    The previous state of the variable is restored afterwards, including when
    ``func`` raises and when the variable was not set at all.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper returning whatever ``func`` returns.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        env_key = "TF_CPP_MIN_LOG_LEVEL"
        previous = os.environ.get(env_key)
        # NOTE(review): the original comment said this keeps "warning and error" output;
        # TensorFlow documents level 2 as hiding INFO and WARNING -- confirm the intent.
        os.environ[env_key] = "2"
        try:
            return func(*args, **kwargs)
        finally:
            if previous is None:
                os.environ.pop(env_key, None)
            else:
                os.environ[env_key] = previous

    return wrapper


class DependencyManager:
    """
    Process-wide singleton that imports optional packages on demand and caches them.

    Successfully imported modules are cached by name. Failed imports return ``None``
    (see ``safely_import``) and are not cached, so a later call tries again.
    """

    _instance = None
    # The only TensorFlow version accepted by _import_tensorflow.
    _SUPPORTED_TF_VERSION = "2.6.5"

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(DependencyManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every DependencyManager() call even though __new__ returns
        # the shared instance; only create the cache the first time.
        if not hasattr(self, "_dependencies"):
            self._dependencies = {}

    def get(self, package_name):
        """
        Return the cached or freshly imported object registered under ``package_name``.

        Args:
            package_name: Importable module name, or one of the ``"tensorflow/..."``
                cache keys filled in by the TensorFlow import.

        Returns:
            The module/object, or ``None`` if it could not be imported.
        """
        # Check the cache first: a dict.get() default would be evaluated eagerly and
        # run the import path on every call.
        if package_name in self._dependencies:
            return self._dependencies[package_name]
        return self._import_package(package_name)

    def get_tensorflow(self):
        """
        Return ``(tensorflow, RewriterConfig, convert_variables_to_constants)``.

        Returns:
            A 3-tuple; all three items are ``None`` when TensorFlow is unavailable
            or has an unsupported version.
        """
        tf = self.get("tensorflow")
        if tf is None:
            # The two helpers are only cached by a successful TensorFlow import, and
            # their keys are not importable module names, so do not try to import them.
            return None, None, None
        re_writer_config = self.get("tensorflow/RewriterConfig")
        sm2pb = self.get("tensorflow/convert_variables_to_constants")
        return tf, re_writer_config, sm2pb

    @safely_import
    def _import_package(self, package_name):
        """Import ``package_name`` and cache it; TensorFlow goes through its dedicated helper."""
        if package_name in self._dependencies:
            return self._dependencies[package_name]
        if package_name == "tensorflow":
            return self._import_tensorflow()
        module = import_module(package_name)
        self._dependencies[package_name] = module
        return module

    @temporary_tf_log_level
    def _import_tensorflow(self):
        """
        Import TensorFlow plus the two helper objects and cache all three.

        Raises:
            MsprobeException: If the installed TensorFlow is not the supported version.
        """
        module = import_module("tensorflow")
        if module.__version__ != self._SUPPORTED_TF_VERSION:
            raise MsprobeException(
                f"[ERROR] Incompatible versions. Currently only supports TensorFlow v{self._SUPPORTED_TF_VERSION}."
            )
        from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig
        from tensorflow.python.framework.graph_util import convert_variables_to_constants

        self._dependencies["tensorflow/convert_variables_to_constants"] = convert_variables_to_constants
        self._dependencies["tensorflow/RewriterConfig"] = RewriterConfig
        # Cached last, so a cached "tensorflow" entry implies both helpers are cached too.
        self._dependencies["tensorflow"] = module
        return module


dependent = DependencyManager()


# ---------------------------------------------------------------------------
# Patch file boundary: accuracy_tools/msprobe/utils/env.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from msprobe.utils.constants import MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.log import logger


class EnvVarManager:
    """
    Process-wide singleton wrapping ``os.environ`` with debug logging and type casting.

    NOTE(review): ``get`` and ``list_all`` write environment variable *values* to the
    debug log. Environments often hold credentials -- confirm this is acceptable.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(EnvVarManager, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every EnvVarManager() call even though __new__ returns the
        # shared instance; do not wipe a prefix that was already configured.
        if not hasattr(self, "prefix"):
            self.prefix = ""

    @staticmethod
    def _log(msg):
        logger.debug(msg)

    def set_prefix(self, prefix):
        """Set the name prefix used by ``list_all`` to filter variables."""
        self.prefix = prefix

    def get(self, key, default=None, cast_type=None, required=True):
        """
        Read an environment variable, optionally converting it.

        Args:
            key: Variable name.
            default: Value used when the variable is not set.
            cast_type: Optional callable applied to the value (e.g. ``int``).
            required: If true, a missing variable with no default raises.

        Returns:
            The (possibly converted) value, or ``None`` when missing and not required.

        Raises:
            MsprobeException: If a required variable is missing, or ``cast_type`` fails.
        """
        value = os.environ.get(key, default)
        self._log(f"Accessed environment variable {key}, Value: {value}.")
        if value is None:
            if required:
                raise MsprobeException(
                    MsgConst.REQUIRED_ARGU_MISSING,
                    f"Environment variable {key} is required but not set. "
                    f"Please check the current environment configuration by `echo ${key}`.",
                )
            return None
        if not cast_type:
            return value
        # Only the conversion belongs in the try: a failure while formatting the log
        # line must not be reported as a cast failure.
        try:
            converted = cast_type(value)
        except Exception as e:
            raise MsprobeException(
                MsgConst.INVALID_DATA_TYPE, f"Failed to cast environment variable {key} to {cast_type}."
            ) from e
        # Not every callable has __name__ (e.g. functools.partial objects).
        type_name = getattr(cast_type, "__name__", repr(cast_type))
        self._log(f"Casted {key} to {type_name}, Result: {converted}.")
        return converted

    def set(self, key, value):
        """Set an environment variable; ``value`` is stored as ``str(value)``."""
        os.environ[key] = str(value)
        self._log(f"Set environment variable {key} to {value}.")

    def delete(self, key):
        """Remove an environment variable if present; a missing one is only logged."""
        # Environment values are always strings, so None reliably means "was not set".
        if os.environ.pop(key, None) is not None:
            self._log(f"Deleted environment variable {key}.")
        else:
            self._log(f"{key} not found to delete.")

    def list_all(self):
        """
        Return a snapshot dict of environment variables.

        Returns:
            Only the variables whose names start with ``self.prefix`` when a prefix is
            set; otherwise a copy of the whole environment.
        """
        if self.prefix:
            filtered_env = {k: v for k, v in os.environ.items() if k.startswith(self.prefix)}
            self._log(f"Listed environment variables with prefix {self.prefix}: {filtered_env}.")
            return filtered_env
        all_env = dict(os.environ)
        self._log(f"Listed all environment variables: {all_env}.")
        return all_env


evars = EnvVarManager()


# ---------------------------------------------------------------------------
# Patch file boundary: accuracy_tools/msprobe/utils/exceptions.py (new file)
# ---------------------------------------------------------------------------
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class MsprobeException(Exception):
    """
    Exception whose message is an error group followed by an optional detail text.

    Args:
        error_group: Leading part of the message (e.g. a ``MsgConst`` value).
        error_msg: Optional detail appended after a single space.
    """

    def __init__(self, error_group, error_msg=""):
        # Forward both arguments so that ``args``/``repr`` are informative and the
        # exception can be rebuilt by pickle/copy (which call ``cls(*self.args)``).
        super().__init__(error_group, error_msg)
        # Skip empty parts so a missing detail does not leave a trailing space.
        self.error_msg = " ".join(part for part in (error_group, error_msg) if part)

    def __str__(self):
        return self.error_msg


# Patch trailer: -- Gitee