diff --git a/accuracy_tools/msprobe/core/base/__init__.py b/accuracy_tools/msprobe/core/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83608d033346d9a536a7c5afab770cd206a3ad0c --- /dev/null +++ b/accuracy_tools/msprobe/core/base/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.base.dump_actuator import OfflineModelActuator +from msprobe.core.base.dump_dumper import BaseDumper +from msprobe.core.base.dump_writer import RankDirFile, SaveBinTensor, SaveNpyTensor, SaveTensor diff --git a/accuracy_tools/msprobe/core/base/dump_actuator.py b/accuracy_tools/msprobe/core/base/dump_actuator.py new file mode 100644 index 0000000000000000000000000000000000000000..e399b0216c1326de8f1ee928bf194fd78fa5e96b --- /dev/null +++ b/accuracy_tools/msprobe/core/base/dump_actuator.py @@ -0,0 +1,191 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def get_tf_type2dtype_map():
    """Build the TensorFlow-dtype to NumPy-dtype lookup table.

    Returns:
        dict: Maps ``tf.<dtype>`` objects to NumPy scalar types. Empty when any
        element returned by ``dependent.get_tensorflow()`` is ``None``
        (presumably TensorFlow is not installed -- confirm against
        ``msprobe.utils.dependencies``).
    """
    pons = dependent.get_tensorflow()
    if None in pons:
        return {}
    tf, _, _ = pons
    return {
        tf.float16: np.float16,
        tf.float32: np.float32,
        tf.float64: np.float64,
        tf.int8: np.int8,
        tf.int16: np.int16,
        tf.int32: np.int32,
        tf.int64: np.int64,
    }


class OfflineModelActuator:
    """Prepare input tensors for running an offline model.

    Resolves dynamic model input shapes against user-provided shapes and
    either loads input data from files or generates random input data.
    """

    # Framework-independent type names -> NumPy scalar types. Built once at
    # class creation instead of on every lookup.
    _BASE_TYPE2DTYPE_MAP = {
        "tensor(int)": np.int32,
        "tensor(int8)": np.int8,
        "tensor(int16)": np.int16,
        "tensor(int32)": np.int32,
        "tensor(int64)": np.int64,
        "tensor(uint8)": np.uint8,
        "tensor(uint16)": np.uint16,
        "tensor(uint32)": np.uint32,
        "tensor(uint64)": np.uint64,
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(double)": np.double,
        "tensor(bool)": np.bool_,
        "tensor(complex64)": np.complex64,
        # np.complex_ was removed in NumPy 2.0; np.complex128 is the same type
        # on NumPy 1.x and still exists on 2.x.
        "tensor(complex128)": np.complex128,
        "float32": np.float32,
        "float16": np.float16,
    }

    def __init__(self, model_path: str, input_shape: dict, input_path: str, **kwargs):
        """Store the model location and the user-provided input description.

        Args:
            model_path: Path of the offline model.
            input_shape: Mapping of tensor name to its fixed shape; a falsy
                value is stored as ``{}``.
            input_path: Input data location(s); a falsy value is stored as ``""``.
            **kwargs: Extra options. ``dir_pool`` is read from here and is
                ``None`` when not supplied.
        """
        self.model_path = model_path
        self.input_shape = input_shape or {}
        self.input_path = input_path or ""
        self.kwargs = kwargs
        self.dir_pool = kwargs.get("dir_pool")

    @staticmethod
    def _is_dynamic_shape(tensor_shape):
        """Return True when any dimension of *tensor_shape* is not an ``int``.

        ``None`` and symbolic (e.g. string) dimensions therefore count as dynamic.
        """
        return any(not isinstance(dim, int) for dim in tensor_shape)

    @staticmethod
    def _tensor2numpy_for_type(tensor_type):
        """Translate a model tensor type into the matching NumPy scalar type.

        Args:
            tensor_type: A type name such as ``"tensor(float)"`` / ``"float32"``
                or a TensorFlow dtype object.

        Returns:
            The NumPy scalar type registered for *tensor_type*.

        Raises:
            MsprobeException: If *tensor_type* is in neither lookup table.
        """
        numpy_data_type = OfflineModelActuator._BASE_TYPE2DTYPE_MAP.get(tensor_type)
        if numpy_data_type is None:
            # Only consult TensorFlow when the framework-independent table has
            # no entry, so plain string types never trigger the TF lookup.
            numpy_data_type = get_tf_type2dtype_map().get(tensor_type)
        if numpy_data_type is None:
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Tensor type {tensor_type} not provided.")
        return numpy_data_type

    @staticmethod
    def _generate_random_input_data(save_dir, names, shapes, dtypes, is_byte_data=False):
        """Create random input arrays, save each as ``.npy`` and return them.

        Args:
            save_dir: Directory the generated files are written to.
            names: Tensor names.
            shapes: Tensor shapes, aligned with *names*.
            dtypes: NumPy dtypes, aligned with *names*.
            is_byte_data: When True, a flat ``uint8`` array holding
                ``prod(shape)`` random values in ``[0, 256)`` is generated
                instead of a tensor of the requested shape and dtype.

        Returns:
            dict: Tensor name -> generated ``numpy.ndarray``.
        """
        input_map = {}
        for tensor_name, tensor_shape, tensor_dtype in zip(names, shapes, dtypes):
            if is_byte_data:
                input_data = np.random.randint(0, 256, int(np.prod(tensor_shape))).astype(np.uint8)
            else:
                # Samples are drawn from [0, 1) and then cast to the target dtype.
                input_data = np.random.random(tensor_shape).astype(tensor_dtype)
            input_map[tensor_name] = input_data
            shape_str = "_".join(map(str, tensor_shape))
            # The ".npy" suffix is joined with "_" as well, so names look like
            # "<name>_shape_1_3_.npy"; kept unchanged to preserve file names.
            file_name = "_".join([tensor_name, "shape", shape_str, ".npy"])
            file_path = join_path(save_dir, file_name)
            save_npy(input_data, file_path)
            logger.info(
                f"Save input file path: {file_path}, "
                f"shape: {input_data.shape}, dtype: {input_data.dtype}."
            )
        return input_map

    @staticmethod
    def _read_input_data(input_paths, names, shapes, dtypes, is_byte_data=False):
        """Load input tensors from ``.bin`` / ``.npy`` files.

        Args:
            input_paths: Iterable of file paths, aligned with *names*.
            names: Tensor names.
            shapes: Expected tensor shapes.
            dtypes: NumPy dtypes used when decoding ``.bin`` files.
            is_byte_data: When True, the element-count check and the reshape
                to the expected shape are skipped.

        Returns:
            dict: Tensor name -> loaded ``numpy.ndarray``.

        Raises:
            MsprobeException: If a path has an unsupported extension, or the
                number of loaded elements does not match the expected shape.
        """
        input_map = {}
        for input_path, name, shape, dtype in zip(input_paths, names, shapes, dtypes):
            if input_path.endswith(".bin"):
                input_data = load_bin_data(input_path, dtype, shape, is_byte_data)
            elif input_path.endswith(".npy"):
                input_data = load_npy(input_path)
            else:
                # Without this branch input_data would be unbound on the first
                # iteration, or silently reuse the previous tensor afterwards.
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    f"Unsupported input file type, only .bin and .npy are supported, input path: {input_path}.",
                )
            if not is_byte_data:
                if np.prod(input_data.shape) != np.prod(shape):
                    raise MsprobeException(
                        MsgConst.INVALID_ARGU,
                        "The shape of the input data does not match the model's shape, "
                        f"input path: {input_path}, input shape: {input_data.shape}, "
                        f"model's shape: {shape}.",
                    )
                input_data = input_data.reshape(shape)
            input_map[name] = input_data
            logger.info(f"Load input file path: {input_path}, shape: {input_data.shape}, dtype: {input_data.dtype}.")
        return input_map

    @staticmethod
    def _check_input_shape(op_name, model_shape, input_shape):
        """Validate a user-provided shape against the model's declared shape.

        Dimensions that are ``None`` or a string in *model_shape* are treated
        as free and are not compared.

        Raises:
            MsprobeException: If *input_shape* is missing, has a different
                rank, or differs from the model in a fixed dimension.
        """
        if not input_shape:
            raise MsprobeException(
                MsgConst.REQUIRED_ARGU_MISSING,
                f"{op_name}'s input_shape is missing. "
                f'Please set `shape: [xxx]` in "input" according to {model_shape}.',
            )
        if len(model_shape) != len(input_shape):
            raise MsprobeException(
                MsgConst.INVALID_ARGU,
                f"Unequal lengths for the shape of {op_name}. "
                f"Model shape: {model_shape}, input shape: {input_shape}.",
            )
        for index, value in enumerate(model_shape):
            if value is None or isinstance(value, str):
                continue
            if input_shape[index] != value:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    "The input shape does not match the model shape. "
                    f"Tensor name: {op_name}, {str(input_shape)} v.s. {str(model_shape)}.",
                )

    @classmethod
    def _get_input_shape_info(cls, tensor_name, tensor_shape, input_shape, tensor_type):
        """Validate *input_shape* and return the tensor description that uses it.

        Returns:
            dict: ``{"name": ..., "shape": input_shape, "type": ...}``.

        Raises:
            MsprobeException: Propagated from ``_check_input_shape``.
        """
        cls._check_input_shape(tensor_name, tensor_shape, input_shape)
        tensor_shape_info = {"name": tensor_name, "shape": input_shape, "type": tensor_type}
        logger.info(f"The dynamic shape of {tensor_name} has been fixed to {input_shape}.")
        return tensor_shape_info

    def get_inputs_data(self, inputs_tensor_info, is_byte_data=False):
        """Return the input tensors for the model, loaded or randomly generated.

        Args:
            inputs_tensor_info: Iterable of dicts with ``"name"``, ``"shape"``
                and ``"type"`` keys.
            is_byte_data: Treat the inputs as raw bytes; the dtype passed on is
                then ``np.int8`` regardless of the tensor type.

        Returns:
            dict: Tensor name -> ``numpy.ndarray``.
        """
        names, shapes, dtypes = [], [], []
        for x in inputs_tensor_info:
            names.append(x["name"])
            shapes.append(x["shape"])
            # read raw byte data (memory) regardless of type; defaults to int8.
            dtypes.append(self._tensor2numpy_for_type(x["type"]) if not is_byte_data else np.int8)
        if not self.input_path:
            self.dir_pool.make_input_dir()
            return self._generate_random_input_data(
                self.dir_pool.get_input_dir(), names, shapes, dtypes, is_byte_data
            )
        # NOTE(review): input_path is annotated as ``str`` in __init__, but it
        # is consumed here as an iterable of paths -- confirm callers pass a list.
        return self._read_input_data(self.input_path, names, shapes, dtypes, is_byte_data)

    def process_tensor_shape(self, tensor_name, tensor_type, tensor_shape):
        """Describe one model input, fixing a dynamic shape from ``input_shape``.

        Args:
            tensor_name: Name of the model input.
            tensor_type: Type of the model input, stored unchanged.
            tensor_shape: Shape declared by the model.

        Returns:
            list: A single-element list holding a dict with ``"name"``,
            ``"shape"`` and ``"type"``. For a static shape the model's own
            shape is used; for a dynamic one the user-provided shape is used.

        Raises:
            MsprobeException: If the shape is dynamic and no usable shape was
                provided for *tensor_name*.
        """
        if not self._is_dynamic_shape(tensor_shape):
            return [{"name": tensor_name, "shape": tensor_shape, "type": tensor_type}]
        if not self.input_shape:
            raise MsprobeException(
                MsgConst.INVALID_ARGU,
                f"The dynamic shape {tensor_shape} are not supported. Please "
                f'set "shape" of {tensor_name} in "input" to fix the dynamic shape.',
            )
        if tensor_name not in self.input_shape:
            raise MsprobeException(
                MsgConst.INVALID_ARGU,
                f'{tensor_name} has a dynamic shape, but its shape is not defined in the "input".',
            )
        inshape = self.input_shape[tensor_name]
        return [self._get_input_shape_info(tensor_name, tensor_shape, inshape, tensor_type)]
class BaseDumper(ABC):
    """Abstract base for dumpers: holds hook handles and the data collected for saving."""

    def __init__(self, data_mode):
        """Remember *data_mode* and start with empty data maps and no hook handles."""
        self.data_mode = data_mode
        self.data_for_save = {}
        self.input_map, self.output_map = {}, {}
        self.handler = []
        self._data_iter = None

    @staticmethod
    def through_nodes(nodes, node_name, in_or_out, data_map):
        """Lazily yield one record per entry of *nodes*.

        Each record is ``[node_name, in_or_out, arg_name, index, data]`` where
        ``arg_name`` is the entry itself (for a ``str``) or its ``name``
        attribute, and ``data`` is ``data_map.get(arg_name)``.

        Raises:
            MsprobeException: On reaching an entry that is neither a ``str``
                nor an object with a ``name`` attribute.
        """
        for position, node in enumerate(nodes):
            if not isinstance(node, str) and not hasattr(node, "name"):
                raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Unsupported node type.")
            arg_name = node if isinstance(node, str) else node.name
            yield [node_name, in_or_out, arg_name, position, data_map.get(arg_name)]

    @abstractmethod
    def register_hook(self):
        """Install the dump hooks; must be provided by subclasses."""

    def release_hook(self):
        """Pass every stored handle in ``self.handler`` to ``release``."""
        for hook_handle in self.handler:
            release(hook_handle)
class RankDirFile(ABC):
    """Size-bounded cache that is flushed through ``_save`` once it grows past a limit."""

    def __init__(self, buffer_size):
        """Create an empty cache that flushes once *buffer_size* is reached."""
        self.rank_dir = None
        self.cache_file = None
        self.cache_file_size = 0
        self.max_cache_size = buffer_size

    @abstractmethod
    def _save(self):
        """Persist the cached content; must be provided by subclasses."""

    def add_rank_dir(self, rank_dir):
        """Set the directory attribute ``rank_dir``."""
        self.rank_dir = rank_dir

    def cover(self, data):
        """Account for *data* and flush when the accumulated size reaches the limit.

        The size is measured with ``sys.getsizeof``, which is shallow for
        container objects.
        """
        self.cache_file_size += sys.getsizeof(data)
        if self.cache_file_size >= self.max_cache_size:
            self._save()
            self.cache_file_size = 0

    def clear_cache(self):
        """Flush any pending content and empty a ``dict`` cache; no-op when nothing is pending."""
        if self.cache_file_size != 0:
            self._save()
            self.cache_file_size = 0
            if isinstance(self.cache_file, dict):
                self.cache_file.clear()


class SaveTensorStrategy(ABC):
    """Base strategy for writing one tensor per file into a tensor directory."""

    def __init__(self):
        """Start without a target directory, path or file suffix."""
        self.suffix = None
        self.tensor_path = None
        self.tensor_dir = None

    @abstractmethod
    def _save(self, data):
        """Write *data* to ``self.tensor_path``; must be provided by subclasses."""

    def add_tensor_dir(self, tensor_dir):
        """Set the directory tensors are written to."""
        self.tensor_dir = tensor_dir

    def save_tensor_data(self, node_name, args_name, data):
        """Derive a file path from the node/argument names and save *data* there."""
        self.tensor_path = self._generate_path(self._generate_file_name(node_name, args_name))
        self._save(data)

    def _generate_path(self, tensor_name):
        """Join *tensor_name* onto the tensor directory and return the ``SafePath`` check result."""
        candidate = join_path(self.tensor_dir, tensor_name)
        checker = SafePath(candidate, PathConst.FILE, "w", suffix=self.suffix)
        return checker.check(path_exist=False)

    def _generate_file_name(self, node_name, args_name):
        """Return ``<timestamp>.<node name>.<argument name><suffix>``.

        Both names are passed through ``get_valid_name`` first.
        """
        stamp = str(get_current_timestamp(microsecond=True))
        node_part = get_valid_name(node_name)
        arg_part = get_valid_name(args_name) + f"{self.suffix}"
        return ".".join([stamp, node_part, arg_part])


class SaveTensor:
    """Registry of tensor-saving strategies, keyed by format name."""

    _fmt_map = {}

    @classmethod
    def register(cls, name):
        """Return the result of ``register(name, cls._fmt_map)``, which is used as a class decorator.

        Presumably the decorated class is stored under *name* in the registry;
        confirm against ``msprobe.utils.toolkits.register``.
        """
        return register(name, cls._fmt_map)

    @classmethod
    def get(cls, name):
        """Return the registry entry for *name*, or ``None`` when absent."""
        return cls._fmt_map.get(name)


@SaveTensor.register(DumpConst.NPY_FORMAT)
class SaveNpyTensor(SaveTensorStrategy):
    """Strategy that writes tensors as ``.npy`` files."""

    def __init__(self):
        super().__init__()
        self.suffix = ".npy"

    def _save(self, data):
        """Write *data* to the current tensor path with ``save_npy``."""
        save_npy(data, self.tensor_path)


@SaveTensor.register(DumpConst.BIN_FORMAT)
class SaveBinTensor(SaveTensorStrategy):
    """Strategy that writes tensors as ``.bin`` files."""

    def __init__(self):
        super().__init__()
        self.suffix = ".bin"

    def _save(self, data):
        """Write an ``ndarray`` or ``bytes`` payload to the current tensor path.

        Raises:
            MsprobeException: If *data* is neither ``numpy.ndarray`` nor ``bytes``.
        """
        if isinstance(data, np.ndarray):
            save_bin_from_ndarray(data, self.tensor_path)
            return
        if isinstance(data, bytes):
            save_bin_from_bytes(data, self.tensor_path)
            return
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Unsupported data type: {type(data).__name__}.")
@Command.register("msprobe", CmdConst.DUMP)
class DumpCommand(BaseCommand):
    """Argument definitions for the dump command, registered under ``("msprobe", CmdConst.DUMP)``."""

    @staticmethod
    def add_required_arguments(parser):
        """Add the mandatory ``-e/--exec`` option to *parser* in its own argument group."""
        required_group = parser.add_argument_group("Required arguments")
        required_group.add_argument(
            "-e",
            "--exec",
            required=True,
            action=CheckExec,
            dest=CfgConst.EXEC,
            help=f""" Supports two input types:
                1. An offline model file with {("saved_model",) + PathConst.SUFFIX_OFFLINE_MODEL} extension;
                2. An executable CLI scripts enclosed in quotes end with {PathConst.SUFFIX_ONLINE_SCRIPT}. Default: None""",
        )

    @staticmethod
    def add_optional_arguments(parser):
        """Add the optional ``-cfg``, ``-f`` and ``-x`` options to *parser* in their own group."""
        optional_group = parser.add_argument_group("Optional arguments")
        optional_group.add_argument(
            "-cfg",
            "--config",
            action=CheckConfigPath,
            dest=CfgConst.CONFIG_PATH,
            help=""" A config JSON file for storing data dump settings. Default: None""",
        )
        optional_group.add_argument(
            "-f",
            "--framework",
            action=CheckFramework,
            dest=CfgConst.FRAMEWORK,
            help=f""" Required when using: {CfgConst.ALL_FRAMEWORK}. Default: None""",
        )
        optional_group.add_argument(
            "-x",
            "--msprobex",
            action="store_true",
            default=False,
            dest="msprobex",
            help=""" Use msprobe extended API. Default: False""",
        )

    @classmethod
    def add_arguments(cls, parser):
        """Register the required options first, then the optional ones, on *parser*."""
        cls.add_required_arguments(parser)
        cls.add_optional_arguments(parser)