diff --git a/accuracy_tools/msprobe/common/__init__.py b/accuracy_tools/msprobe/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53529bc8d3158c537ae7970cf531b33ba6acd57a --- /dev/null +++ b/accuracy_tools/msprobe/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/accuracy_tools/msprobe/common/ascend.py b/accuracy_tools/msprobe/common/ascend.py new file mode 100644 index 0000000000000000000000000000000000000000..61c5e4a5957f0628f6deb7354ea0f8a8dfddef88 --- /dev/null +++ b/accuracy_tools/msprobe/common/ascend.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.utils.constants import MsgConst, PathConst +from msprobe.utils.env import evars +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.path import SafePath, SoftLinkLevel, join_path +from msprobe.utils.toolkits import run_subprocess + +_ENVVAR_ASCEND_TOOLKIT_HOME = "ASCEND_TOOLKIT_HOME" +_DEFAULT_ASCEND_TOOLKIT_HOME = "/usr/local/Ascend/ascend-toolkit/latest" +_ENVVAR_ATB_HOME_PATH = "ATB_HOME_PATH" +_SUFFIX_CONVERT_MODEL = (".om", ".txt") +_ATC_BIN_PATH = "compiler/bin/atc" +_OLD_ATC_BIN_PATH = "atc/bin/atc" +_ATC_MODE_OM2JSON = "1" +_ATC_MODE_GETXT2JSON = "5" + + +class CANN: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(CANN, cls).__new__(cls) + return cls._instance + + def __init__(self): + self.cann_home = evars.get(_ENVVAR_ASCEND_TOOLKIT_HOME, _DEFAULT_ASCEND_TOOLKIT_HOME) + + @property + def lib_atb_path(self): + atb_home_path = evars.get(_ENVVAR_ATB_HOME_PATH) + return SafePath(join_path(atb_home_path, "lib", "libatb.so"), PathConst.FILE, "r", PathConst.SIZE_20M).check( + soft_link_level=SoftLinkLevel.IGNORE + ) + + @property + def probe_symbols(self): + output = run_subprocess(["nm", "-D", self.lib_atb_path], capture_output=True) + res = [] + for line in (output or "").splitlines(): + parts = line.strip().split() + if len(parts) != 3: + continue + symbol_type = parts[1] + symbol_name = parts[2] + if symbol_type == "T" and "Probe" in symbol_name: + res.append(symbol_name) + return res + + def model2json(self, model_path: str, json_path: str): + model_path = SafePath(model_path, PathConst.FILE, "r", PathConst.SIZE_30G, _SUFFIX_CONVERT_MODEL).check() + json_path = SafePath(json_path, PathConst.FILE, "w", suffix=".json").check(path_exist=False) + atc = self._get_atc_path() + if model_path.endswith(".om"): + mode_type = _ATC_MODE_OM2JSON + else: + mode_type = _ATC_MODE_GETXT2JSON + atc_cmd = [atc, "--mode=" + mode_type, "--om=" + model_path, "--json=" + json_path] + logger.info("Start converting the model format to JSON.") + run_subprocess(atc_cmd) + logger.info(f"The model has been converted to a JSON file, located at {json_path}.") + + def _get_atc_path(self): + try: + atc_path = SafePath( + join_path(self.cann_home, _ATC_BIN_PATH), PathConst.FILE, "e", PathConst.SIZE_20M + ).check(soft_link_level=SoftLinkLevel.IGNORE) + except Exception as e1: + logger.error(str(e1)) + try: + atc_path = SafePath( + join_path(self.cann_home, _OLD_ATC_BIN_PATH), PathConst.FILE, "e", PathConst.SIZE_20M + ).check(soft_link_level=SoftLinkLevel.IGNORE) + except Exception as e2: + raise MsprobeException(MsgConst.CANN_FAILED) from e2 + return atc_path + + +cann = CANN() diff --git a/accuracy_tools/msprobe/common/cli.py b/accuracy_tools/msprobe/common/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..1861d34d5ddaaef377c9fdc7c9a99ea18b156f82 --- /dev/null +++ b/accuracy_tools/msprobe/common/cli.py @@ -0,0 +1,99 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from argparse import ArgumentParser +from pathlib import Path +from sys import argv + +from msprobe.base import BaseCommand, Command, Service +from msprobe.utils.constants import CfgConst, CmdConst, MsgConst, PathConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.path import SafePath +from msprobe.utils.toolkits import run_subprocess, set_ld_preload + +_DESCRIPTION = """ + _ + _ __ ___ ___ _ __ _ __ ___ | |__ ___ + | '_ ` _ \/ __| '_ \| '__/ _ \| '_ \ / _ \ + | | | | | \__ \ |_) | | | (_) | |_) | __/ + |_| |_| |_|___/ .__/|_| \___/|_.__/ \___| + |_| + +msprobe (MindStudio-Probe), [Powered by MindStudio]. +A set of tools for diagnosing and improving model accuracy on Ascend NPU, +including API accuracy, args checker, grad tool etc. +""" +_L2COMMAND = "L2command" +_ROOT_LEVEL = 1 +_SECEND_LEVEL = 2 + + +class MainCommand(BaseCommand): + def __init__(self): + super().__init__() + self.parser = ArgumentParser(prog="msprobe", description=_DESCRIPTION, formatter_class=self.formatter_class) + self.subparser = self.parser.add_subparsers(dest=_L2COMMAND) + self.second_commands = Command.get("msprobe") + self.subcommand_level = _ROOT_LEVEL + + @property + def _msprobe_so_path(self): + current_file = Path(__file__).resolve() + lib_path = str(current_file.parent.parent / "lib" / "msprobe_c.so") + return SafePath(lib_path, PathConst.FILE, "r", PathConst.SIZE_500M, ".so").check() + + def add_arguments(self, parse): + pass + + def register(self): + for name, cmd_class in self.second_commands.items(): + cmd_parser = self.subparser.add_parser( + name=name, help=CmdConst.HELP_SERVICE_MAP.get(name), formatter_class=self.formatter_class + ) + if self.service_key in self.second_commands: + cmd_class.add_arguments(cmd_parser) + self.subcommand_level = _SECEND_LEVEL + self.build_parser(cmd_parser, cmd_class) + + def parse(self): + return self.parser.parse_args() + + def execute(self, args): + if len(argv) <= self.subcommand_level: + self.parser.print_help() + return + serv_name = argv[self.subcommand_level - 1] + if Service.get(serv_name): + logger.info(f"Preparing to launch {serv_name} service.") + if args.framework: + self._set_env(args.framework) + if not args.msprobex: + serv = Service(cmd_namespace=args, serv_name=serv_name) + serv.run_cli() + else: + run_subprocess(args.exec) + else: + raise MsprobeException(MsgConst.CALL_FAILED, f"The {serv_name} service is not registered. Please check it.") + + def _set_env(self, framework): + env_func = {CfgConst.FRAMEWORK_MINDIE_LLM: self._set_mindie_llm_env} + frame_init = env_func.get(framework) + if frame_init: + frame_init() + else: + raise MsprobeException(MsgConst.CALL_FAILED, f"The {framework} framework is not supported.") + + def _set_mindie_llm_env(self): + set_ld_preload(self._msprobe_so_path) diff --git a/accuracy_tools/msprobe/common/dirs.py b/accuracy_tools/msprobe/common/dirs.py new file mode 100644 index 0000000000000000000000000000000000000000..3907c8681ba7d12e3d4b5b1f42198dcbd728c084 --- /dev/null +++ b/accuracy_tools/msprobe/common/dirs.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from msprobe.utils.log import logger +from msprobe.utils.path import DirSafeHandler +from msprobe.utils.toolkits import get_current_rank, get_current_timestamp, timestamp_sync + + +class DirPool: + msprobe_path = None + model_dir = None + + def __init__(self): + self.step_dir = None + self.rank_dir = None + self.input_dir = None + self.tensor_dir = None + + @classmethod + def make_msprobe_dir(cls, path: str): + timestamp = get_current_timestamp(microsecond=False) + timestamp = timestamp_sync(timestamp) + formatted_date = datetime.fromtimestamp(timestamp).strftime("%Y%m%d_%H%M%S") + cls.msprobe_path = DirSafeHandler.join_and_create(path, f"msprobe_{formatted_date}/") + + @classmethod + def get_msprobe_dir(cls): + return DirSafeHandler.get_or_raise(cls.msprobe_path, "Dump dir has not been set.") + + @classmethod + def make_model_dir(cls): + cls.model_dir = DirSafeHandler.join_and_create(cls.get_msprobe_dir(), "model") + + @classmethod + def get_model_dir(cls): + return DirSafeHandler.get_or_raise(cls.model_dir, "Model dir has not been set.") + + def make_step_dir(self, current_step: int): + self.step_dir = DirSafeHandler.join_and_create(self.get_msprobe_dir(), f"step{current_step}") + + def get_step_dir(self): + logger.info(f"Step dir has switched to {self.step_dir}.") + return DirSafeHandler.get_or_raise(self.step_dir, "Step dir has not been set.") + + def make_rank_dir(self): + self.rank_dir = DirSafeHandler.join_and_create(self.get_step_dir(), f"rank{get_current_rank()}") + + def get_rank_dir(self): + return DirSafeHandler.get_or_raise(self.rank_dir, "Rank dir has not been set.") + + def make_input_dir(self): + self.input_dir = DirSafeHandler.join_and_create(self.get_rank_dir(), "input") + + def get_input_dir(self): + return DirSafeHandler.get_or_raise(self.input_dir, "Input dir has not been set.") + + def make_tensor_dir(self): + self.tensor_dir = DirSafeHandler.join_and_create(self.get_rank_dir(), "dump_tensor_data") + + def get_tensor_dir(self): + return DirSafeHandler.get_or_raise(self.tensor_dir, "dump_tensor_data dir has not been set.") diff --git a/accuracy_tools/msprobe/common/stat.py b/accuracy_tools/msprobe/common/stat.py new file mode 100644 index 0000000000000000000000000000000000000000..558adb4c60519b77857508e2f900a414fc8f3208 --- /dev/null +++ b/accuracy_tools/msprobe/common/stat.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from zlib import crc32 + +import numpy as np + +from msprobe.utils.constants import DumpConst, MsgConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.toolkits import safely_compute + + +class DataStat: + @staticmethod + def get_valid_type(np_data): + try: + module_name = np_data.__class__.__module__ + class_name = np_data.__class__.__name__ + return f"{module_name}.{class_name}" + except Exception: + logger.warning(f"Unrecognized type pattern: {type(np_data)}.") + return None + + @staticmethod + @safely_compute + def get_dtype(npy): + return npy.dtype + + @staticmethod + @safely_compute + def get_shape(npy): + return npy.shape + + @staticmethod + @safely_compute + def get_max(npy): + return float(npy.max()) + + @staticmethod + @safely_compute + def get_min(npy): + return float(npy.min()) + + @staticmethod + @safely_compute + def get_mean(npy): + return float(npy.mean()) + + @staticmethod + @safely_compute + def get_norm(npy): + return float(np.linalg.norm(npy)) + + @staticmethod + @safely_compute + def get_crc32_hash(npy): + npy_bytes = npy.tobytes() + crc32_hash = crc32(npy_bytes) + return f"{crc32_hash:08x}" + + @classmethod + def collect_stats_for_numpy(cls, npy: np.ndarray, summary_mode: str): + try: + npy = np.asarray(npy) + except Exception as e: + raise MsprobeException(MsgConst.CONVERSION_FAILED, f"Failed to convert to numpy array.") from e + stat_dict = {} + stat_dict["type"] = cls.get_valid_type(npy) + stat_dict["dtype"] = cls.get_dtype(npy) + stat_dict["shape"] = cls.get_shape(npy) + stat_dict["Max"] = cls.get_max(npy) + stat_dict["Min"] = cls.get_min(npy) + stat_dict["Mean"] = cls.get_mean(npy) + stat_dict["Norm"] = cls.get_norm(npy) + if summary_mode == DumpConst.SUMMARY_MD5: + stat_dict["md5"] = cls.get_crc32_hash(npy) + return stat_dict diff --git a/accuracy_tools/msprobe/common/validation.py b/accuracy_tools/msprobe/common/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..ac617aa51589f4e4e94c7942ec199a872b7c18d6 --- /dev/null +++ b/accuracy_tools/msprobe/common/validation.py @@ -0,0 +1,181 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from argparse import Action + +from msprobe.utils.constants import CfgConst, MsgConst, PathConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import LOG_LEVEL +from msprobe.utils.path import SafePath, is_dir, is_file +from msprobe.utils.toolkits import check_int_border + +_HYPHEN_NUM_PATTERN = r"^(?:\d+-\d+|\d+-\d+-\d+)$" + + +def valid_task(value: str): + if not isinstance(value, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"task" must be a string.') + if value not in CfgConst.ALL_TASK: + raise MsprobeException(MsgConst.INVALID_ARGU, f'"task" must be one of {CfgConst.ALL_TASK}, currently: {value}.') + return value + + +def _valid_suffix_for_exec(value: str, extension: str, error_msg: str): + try: + if not value.endswith(extension): + raise MsprobeException(MsgConst.INVALID_ARGU, error_msg) + except Exception as e: + raise MsprobeException(MsgConst.PARSING_FAILED) from e + _ = SafePath(value, PathConst.FILE, "r", PathConst.SIZE_30G).check() + + +def valid_exec(values: str): + if values is None: + return values + if not isinstance(values, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"exec" must be a string.') + values = values.split(" ") + first_keyword = values[0] + if is_dir(first_keyword): + _ = SafePath(first_keyword, PathConst.DIR, "r", PathConst.SIZE_50G).check() + elif first_keyword == "bash": + _valid_suffix_for_exec(values[1], ".sh", "The interpreter must start with bash when the script ends with .sh.") + elif first_keyword in {"python", "python3"}: + _valid_suffix_for_exec( + values[1], ".py", "The interpreter must start with python, python3 when the script ends with .py." + ) + elif is_file(first_keyword): + _valid_suffix_for_exec( + first_keyword, + (PathConst.SUFFIX_OFFLINE_MODEL + PathConst.SUFFIX_ONLINE_SCRIPT), + "A single readable or executable file must end with " + f"{PathConst.SUFFIX_OFFLINE_MODEL + PathConst.SUFFIX_ONLINE_SCRIPT}.", + ) + else: + raise MsprobeException(MsgConst.INVALID_ARGU, f"Please check the `--exec (-e)`, currently: {values}.") + return values + + +class CheckExec(Action): + def __call__(self, parser, namespace, values, option_string=None): + values = valid_exec(values) + setattr(namespace, self.dest, values) + + +def valid_config_path(value: str): + return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_2G, ".json").check() + + +class CheckConfigPath(Action): + def __call__(self, parser, namespace, values, option_string=None): + values = valid_config_path(values) + setattr(namespace, self.dest, values) + + +def valid_framework(value: str): + if not value: + return value + if not isinstance(value, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"framework" must be a string.') + if value not in CfgConst.ALL_FRAMEWORK: + raise MsprobeException( + MsgConst.INVALID_ARGU, f'"framework" must be one of {CfgConst.ALL_FRAMEWORK}, currently: {value}.' + ) + return value + + +class CheckFramework(Action): + def __call__(self, parser, namespace, values, option_string=None): + values = valid_framework(values) + setattr(namespace, self.dest, values) + + +def parse_hyphen(element, tag=None): + if not re.match(_HYPHEN_NUM_PATTERN, element): + msg = 'Only accepts numbers or a range like "123-456", "123-456-2".' + if tag: + msg += f" Context: {tag}." + raise MsprobeException(MsgConst.INVALID_ARGU, msg) + split_ele = element.split("-") + start = int(split_ele[0]) + end = int(split_ele[1]) + check_int_border(start, end, tag="Hyphen-connected integer") + if start > end: + msg = f"The left value must be smaller than the right, currently: {start} v.s. {end}." + if msg: + msg += f" Context: {tag}." + raise MsprobeException(MsgConst.INVALID_ARGU, msg) + step = int(split_ele[2]) if len(split_ele) == 3 else 1 + ranges = [i for i in range(start, end + 1, step)] + return ranges + + +def valid_step_or_rank(values: list): + if not values: + return values + if not isinstance(values, list): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"rank" or "step" must be a list.') + res = [] + for element in values: + if isinstance(element, str): + res.extend(parse_hyphen(element, tag="strp or rank")) + elif isinstance(element, int): + check_int_border(element, tag="Element in the 'rank' or 'step' list") + res.append(element) + else: + raise MsprobeException( + MsgConst.INVALID_DATA_TYPE, 'Elements in the "rank" or "step" support only strings and integers.' + ) + res = list(set(res)) + res.sort() + return res + + +def valid_level(values: list): + if not values: + return values + if not isinstance(values, list): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"level" must be a list.') + for value in values: + if value not in CfgConst.ALL_LEVEL: + raise MsprobeException( + MsgConst.INVALID_ARGU, f'"level" must be one of {CfgConst.ALL_LEVEL}, currently: {value}.' + ) + return values + + +def valid_log_level(value: str): + if value is None: + return value + if not isinstance(value, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"log_level" must be a string.') + log_level = {level.lower() for level in LOG_LEVEL} + if value not in log_level: + raise MsprobeException(MsgConst.INVALID_ARGU, f'"log_level" must be one of {log_level}, currently: {value}.') + return value + + +def valid_seed(value: int): + if value is None: + return value + check_int_border(value, tag="seed number") + return value + + +def valid_buffer_size(value: int): + if value is None: + return value + check_int_border(value, tag="buffer size") + return value