From ffbab6d9796e235a09e5e7b4318fab5cff929f61 Mon Sep 17 00:00:00 2001 From: kai-ma Date: Fri, 4 Jul 2025 12:31:04 +0800 Subject: [PATCH] add dump for acl,caffe,om,onnx,tf --- accuracy_tools/msprobe/core/dump/__init__.py | 19 ++ .../msprobe/core/dump/acl_manager.py | 109 +++++++++++ .../msprobe/core/dump/caffe_model.py | 59 ++++++ accuracy_tools/msprobe/core/dump/om_model.py | 171 ++++++++++++++++++ .../msprobe/core/dump/onnx_model.py | 68 +++++++ accuracy_tools/msprobe/core/dump/tf_model.py | 171 ++++++++++++++++++ 6 files changed, 597 insertions(+) create mode 100644 accuracy_tools/msprobe/core/dump/__init__.py create mode 100644 accuracy_tools/msprobe/core/dump/acl_manager.py create mode 100644 accuracy_tools/msprobe/core/dump/caffe_model.py create mode 100644 accuracy_tools/msprobe/core/dump/om_model.py create mode 100644 accuracy_tools/msprobe/core/dump/onnx_model.py create mode 100644 accuracy_tools/msprobe/core/dump/tf_model.py diff --git a/accuracy_tools/msprobe/core/dump/__init__.py b/accuracy_tools/msprobe/core/dump/__init__.py new file mode 100644 index 000000000..6267a4ba6 --- /dev/null +++ b/accuracy_tools/msprobe/core/dump/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.dump.acl_manager import acl_device_manager +from msprobe.core.dump.caffe_model import CaffeModelActuator +from msprobe.core.dump.om_model import OmModelActuator +from msprobe.core.dump.onnx_model import OnnxModelActuator +from msprobe.core.dump.tf_model import FrozenGraphActuatorCPU, FrozenGraphActuatorNPU diff --git a/accuracy_tools/msprobe/core/dump/acl_manager.py b/accuracy_tools/msprobe/core/dump/acl_manager.py new file mode 100644 index 000000000..59e1785c2 --- /dev/null +++ b/accuracy_tools/msprobe/core/dump/acl_manager.py @@ -0,0 +1,109 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
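+
+# ACLDeviceManager hands out one ACLResourceManager per device rank and is
+# exposed below as the module-level singleton `acl_device_manager`. A typical
+# lifecycle (a sketch; actual call sites live elsewhere in msprobe) is:
+#
+#     manager = acl_device_manager.get_acl_resource_manager(rank=0)
+#     manager.initialize()                  # acl.init, set device, create context
+#     manager.set_dump(cfg_path, callback)  # register dump callback, enable dump
+#     ...                                   # run inference (see om_model.py)
+#     manager.destroy_resource()            # finalize dump, free context/device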
+
+from msprobe.lib.msprobe_c import acl
+from msprobe.utils.constants import ACLConst, MsgConst
+from msprobe.utils.exceptions import MsprobeException
+from msprobe.utils.log import logger
+
+
+class ACLDeviceManager:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if not cls._instance:
+            cls._instance = super(ACLDeviceManager, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        # __init__ runs on every ACLDeviceManager() call even though __new__
+        # returns the shared instance; guard so the map is created only once
+        # and previously created resource managers are not discarded.
+        if not hasattr(self, "acl_device_manager_map"):
+            self.acl_device_manager_map = {}
+
+    def get_acl_resource_manager(self, rank):
+        if rank not in self.acl_device_manager_map:
+            self.acl_device_manager_map[rank] = ACLResourceManager(rank)
+        return self.acl_device_manager_map[rank]
+
+
+class ACLResourceManager:
+    def __init__(self, rank=0):
+        self.ptr_context = None
+        self.is_acl_initialized = False
+        self.is_set_dump = False
+        self.rank = rank
+
+    def initialize(self):
+        if self.is_acl_initialized:
+            return
+        ret = acl.init()
+        if ret == ACLConst.SUCCESS:
+            logger.info("Acl init success!")
+        else:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl init failed! ErrorCode = {ret}.")
+        ret = acl.rt_set_device(self.rank)
+        if ret == ACLConst.SUCCESS:
+            logger.info(f"Set device:{self.rank} success!")
+        else:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl set device:{self.rank} failed! ErrorCode = {ret}.")
+        self.ptr_context, ret = acl.rt_create_context(self.rank)
+        if ret == ACLConst.SUCCESS:
+            logger.info("Create new context success!")
+        else:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl create context failed! ErrorCode = {ret}.")
+        self.is_acl_initialized = True
+
+    def set_dump(self, dump_cfg_path, message_call_back):
+        if not self.is_acl_initialized or self.is_set_dump:
+            return
+        ret = acl.init_dump()
+        if ret != ACLConst.SUCCESS:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl init dump failed! ErrorCode = {ret}.")
+        ret = acl.dump_reg_callback(message_call_back, 0)
+        if ret != ACLConst.SUCCESS:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl dump reg callback failed! ErrorCode = {ret}.")
+        ret = acl.set_dump(dump_cfg_path)
+        if ret != ACLConst.SUCCESS:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Acl set dump failed! ErrorCode = {ret}.")
+        self.is_set_dump = True
+
+    def destroy_resource(self):
+        if not self.is_acl_initialized:
+            return
+        self._finalize_dump()
+        if self.ptr_context is not None:
+            ret = acl.rt_destroy_context(self.ptr_context)
+            if ret != ACLConst.SUCCESS:
+                logger.error(f"Destroy context failed! ErrorCode = {ret}.")
+            self.ptr_context = None
+        ret = acl.rt_reset_device(self.rank)
+        if ret != ACLConst.SUCCESS:
+            logger.error(f"Reset device failed! DeviceId = {self.rank}, ErrorCode = {ret}.")
+        else:
+            logger.info(f"Finished resetting device:{self.rank}.")
+        ret = acl.finalize()
+        if ret != ACLConst.SUCCESS:
+            logger.error(f"Finalize failed! ErrorCode = {ret}.")
+        else:
+            logger.info("Finished finalizing ACL.")
+        self.is_acl_initialized = False
+
+    def _finalize_dump(self):
+        if not self.is_set_dump:
+            return
+        acl.dump_unreg_callback()
+        ret = acl.finalize_dump()
+        if ret != ACLConst.SUCCESS:
+            logger.error(f"Finalize dump failed! ErrorCode = {ret}.")
+        self.is_set_dump = False
+
+
+acl_device_manager = ACLDeviceManager()
diff --git a/accuracy_tools/msprobe/core/dump/caffe_model.py b/accuracy_tools/msprobe/core/dump/caffe_model.py
new file mode 100644
index 000000000..ad395e75e
--- /dev/null
+++ b/accuracy_tools/msprobe/core/dump/caffe_model.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from msprobe.core.base import OfflineModelActuator +from msprobe.utils.constants import MsgConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.io import load_caffe_model +from msprobe.utils.log import logger + + +class CaffeModelActuator(OfflineModelActuator): + def __init__(self, model_path, input_shape, input_path, **kwargs): + super().__init__(model_path, input_shape, input_path, **kwargs) + self.weight_path = kwargs.get("weight_path", "") + if not self.weight_path: + raise MsprobeException( + MsgConst.REQUIRED_ARGU_MISSING, + "When using Caffe for inference, a weight file (.caffemodel) is required.", + ) + + def load_model(self): + self.model = load_caffe_model(self.model_path, self.weight_path) + + def get_input_tensor_info(self): + inputs_tensor_info = [] + input_blob_names = list(self.model.blobs.keys())[: len(self.model.inputs)] + for input_name in input_blob_names: + tensor_data = self.model.blobs[input_name].data + tensor_info = {"name": input_name, "shape": tuple(tensor_data.shape), "type": str(tensor_data.dtype)} + inputs_tensor_info.append(tensor_info) + logger.warning( + "Caffe model doesn't support dynamic shapes and " + "will use the input shape defined in the model for inference." + ) + logger.info(f"Model input tensor info: {inputs_tensor_info}.") + return inputs_tensor_info + + def infer(self, input_map): + try: + for input_name, input_data in input_map.items(): + np.copyto(self.model.blobs[input_name].data, input_data) + return self.model.forward() + except Exception as e: + raise MsprobeException( + MsgConst.CALL_FAILED, "Please check if the input shape or data matches the model requirements." + ) from e diff --git a/accuracy_tools/msprobe/core/dump/om_model.py b/accuracy_tools/msprobe/core/dump/om_model.py new file mode 100644 index 000000000..ef0caa824 --- /dev/null +++ b/accuracy_tools/msprobe/core/dump/om_model.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
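+
+# OmModelActuator drives an offline OM model through the ACL runtime. The
+# infer() flow is: _create_data_buffer (device malloc plus dataset buffers),
+# _copy_data_from_host_to_device (rt_memcpy of each input's raw bytes),
+# _run (acl.execute), then _destroy_data_buffer/_destroy_resource cleanup.
+# Buffer bookkeeping lives in input_ptr_size/output_ptr_size as lists of
+# {"buffer": ptr, "size": n}, one entry per model input or output. infer()
+# itself returns nothing; tensor data is expected to be captured through the
+# ACL dump callback registered in acl_manager.py.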
+
+from msprobe.common.ascend import cann
+from msprobe.core.base import OfflineModelActuator
+from msprobe.lib.msprobe_c import acl
+from msprobe.utils.constants import ACLConst, MsgConst
+from msprobe.utils.exceptions import MsprobeException
+from msprobe.utils.io import load_om_model
+from msprobe.utils.log import logger
+from msprobe.utils.path import get_name_and_ext, join_path
+
+_BUFFER_METHOD_MAP = {"input": acl.get_input_size_by_index, "output": acl.get_output_size_by_index}
+
+
+class OmModelActuator(OfflineModelActuator):
+    def __init__(self, model_path, input_shape, input_path, **kwargs):
+        super().__init__(model_path, input_shape, input_path, **kwargs)
+        self.ptr_model_desc = None
+        self.ptr_input_dataset = None
+        self.ptr_output_dataset = None
+        self.model_id = None
+        self.input_size = 0
+        self.output_size = 0
+        self.input_ptr_size = []
+        self.output_ptr_size = []
+
+    def load_model(self):
+        self.model_id = load_om_model(self.model_path)
+
+    def get_input_tensor_info(self):
+        inputs_tensor_info = []
+        self._get_model_info()
+        for index in range(self.input_size):
+            name = acl.get_input_name_by_index(self.ptr_model_desc, index)
+            if name is None:
+                raise MsprobeException(MsgConst.CALL_FAILED, f"Get input name by index:{index} failed!")
+            shape, ret = acl.get_input_dims(self.ptr_model_desc, index)
+            if shape is None or ret != ACLConst.SUCCESS:
+                raise MsprobeException(MsgConst.CALL_FAILED, f"Get input shape by index:{index} failed!")
+            dtype = acl.get_input_data_type(self.ptr_model_desc, index)
+            if dtype is None:
+                raise MsprobeException(MsgConst.CALL_FAILED, f"Get input type by index:{index} failed!")
+            inputs_tensor_info.append({"name": name, "shape": shape["dims"], "type": dtype})
+        logger.info(f"Model input tensor info: {inputs_tensor_info}.")
+        return inputs_tensor_info
+
+    def infer(self, input_map):
+        self._create_data_buffer()
+        self._copy_data_from_host_to_device(input_map)
+        self._run()
+        self._destroy_data_buffer()
+        self._destroy_resource()
+
+    def convert_om2json(self):
+        name, _ = get_name_and_ext(self.model_path)
+        json_path = join_path(self.dir_pool.get_model_dir(), name + ".json")
+        cann.model2json(self.model_path, json_path)
+
+    def _run(self):
+        ret = acl.execute(self.model_id, self.ptr_input_dataset, self.ptr_output_dataset)
+        if ret != ACLConst.SUCCESS:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Model execute failed! ErrorCode = {ret}.")
+        else:
+            logger.info("Model executed successfully.")
+
+    def _get_model_info(self):
+        self.ptr_model_desc = acl.create_desc()
+        if self.ptr_model_desc is None:
+            raise MsprobeException(MsgConst.CALL_FAILED, "Create model description failed!")
+        ret = acl.get_desc(self.ptr_model_desc, self.model_id)
+        if ret != ACLConst.SUCCESS:
+            raise MsprobeException(MsgConst.CALL_FAILED, f"Get model description failed! ErrorCode = {ret}.")
+        self.input_size = acl.get_num_inputs(self.ptr_model_desc)
+        if self.input_size is None:
+            raise MsprobeException(MsgConst.CALL_FAILED, "Get input count failed!")
+        self.output_size = acl.get_num_outputs(self.ptr_model_desc)
+        if self.output_size is None:
+            raise MsprobeException(MsgConst.CALL_FAILED, "Get output count failed!")
+        logger.info("Model description created successfully.")
+
+    def _create_data_buffer(self):
+        for mode in ["input", "output"]:
+            data_size = getattr(self, f"{mode}_size", 0)
+            ptr_size_map = getattr(self, f"{mode}_ptr_size", [])
+            ptr_dataset = acl.create_dataset()
+            if ptr_dataset is None:
+                raise MsprobeException(MsgConst.CALL_FAILED, f"Create {mode} dataset failed!")
+            for index in range(data_size):
+                temp_buffer_size = _BUFFER_METHOD_MAP.get(mode)(self.ptr_model_desc, index)
+                if temp_buffer_size is None:
+                    raise MsprobeException(MsgConst.CALL_FAILED, f"Get {mode} size by index:{index} failed!")
+                temp_ptr, ret = acl.rt_malloc(temp_buffer_size)
+                if ret != ACLConst.SUCCESS:
+                    raise MsprobeException(MsgConst.CALL_FAILED, f"{mode.title()} malloc failed! ErrorCode = {ret}.")
+                ptr_size_map.append({"buffer": temp_ptr, "size": temp_buffer_size})
+                temp_buffer = acl.create_databuffer(temp_ptr, temp_buffer_size)
+                if temp_buffer is None:
+                    acl.rt_free(temp_ptr)
+                    raise MsprobeException(MsgConst.CALL_FAILED, f"Create {mode} buffer failed!")
+                ret = acl.add_dataset_buffer(ptr_dataset, temp_buffer)
+                if ret != ACLConst.SUCCESS:
+                    acl.rt_free(temp_ptr)
+                    raise MsprobeException(
+                        MsgConst.CALL_FAILED, f"Add {mode} buffer to dataset failed! ErrorCode = {ret}."
+                    )
+            setattr(self, f"ptr_{mode}_dataset", ptr_dataset)
+            # Persist the buffer bookkeeping under the attribute it was read
+            # from, so the copy and free steps see the allocated buffers.
+            setattr(self, f"{mode}_ptr_size", ptr_size_map)
+
+    def _destroy_resource(self):
+        ret = acl.unload(self.model_id)
+        if ret != ACLConst.SUCCESS:
+            logger.error(f"Unload model failed! ErrorCode = {ret}.")
+        else:
+            logger.info("Finished unloading model.")
+        if self.ptr_model_desc is not None:
+            ret = acl.destroy_desc(self.ptr_model_desc)
+            if ret != ACLConst.SUCCESS:
+                logger.error(f"Destroy model description failed! ErrorCode = {ret}.")
+
+    def _destroy_data_buffer(self):
+        for mode in ["input", "output"]:
+            dataset = getattr(self, f"ptr_{mode}_dataset", None)
+            ptr_size_map = getattr(self, f"{mode}_ptr_size", [])
+            if dataset is None or not ptr_size_map:
+                continue
+            buffer_nums = acl.get_dataset_num_buffers(dataset)
+            if buffer_nums is None:
+                logger.error("Get dataset num buffers failed!")
+                continue
+            for index in range(buffer_nums):
+                data_buffer = acl.get_dataset_buffer(dataset, index)
+                if data_buffer is None:
+                    logger.error(f"Get data buffer from {mode} dataset failed!")
+                    continue
+                ret = acl.destroy_databuffer(data_buffer)
+                if ret != ACLConst.SUCCESS:
+                    logger.error(f"Destroy data buffer failed! ErrorCode = {ret}.")
+            ret = acl.destroy_dataset(dataset)
+            if ret != ACLConst.SUCCESS:
+                logger.error(f"Destroy {mode} dataset failed! ErrorCode = {ret}.")
+            for items in ptr_size_map:
+                ptr = items.get("buffer", None)
+                ret = acl.rt_free(ptr)
+                if ret != ACLConst.SUCCESS:
+                    logger.error(f"Free device memory failed! ErrorCode = {ret}.")
+            # Clear the bookkeeping so a repeated cleanup cannot double-free.
+            ptr_size_map.clear()
+            setattr(self, f"ptr_{mode}_dataset", None)
+
+    def _copy_data_from_host_to_device(self, input_map):
+        if len(input_map) != len(self.input_ptr_size):
+            logger.warning(
+                f"Input map size {len(input_map)} does not match the number of "
+                f"input buffers {len(self.input_ptr_size)}."
+            )
+            return
+        for index, (_, input_data) in enumerate(input_map.items()):
+            dest_ptr = self.input_ptr_size[index].get("buffer", None)
+            dest_size = self.input_ptr_size[index].get("size", 0)
+            byte_data = input_data.tobytes()
+            ret = acl.rt_memcpy(dest_ptr, dest_size, byte_data, len(byte_data), ACLConst.MEMCPY_HOST_TO_DEVICE)
+            if ret != ACLConst.SUCCESS:
+                logger.error(f"Memcpy input data from host to device failed! ErrorCode = {ret}.")
+                return
diff --git a/accuracy_tools/msprobe/core/dump/onnx_model.py b/accuracy_tools/msprobe/core/dump/onnx_model.py
new file mode 100644
index 000000000..64afe1dc9
--- /dev/null
+++ b/accuracy_tools/msprobe/core/dump/onnx_model.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from msprobe.core.base import OfflineModelActuator
+from msprobe.utils.constants import MsgConst, PathConst
+from msprobe.utils.dependencies import dependent
+from msprobe.utils.exceptions import MsprobeException
+from msprobe.utils.io import load_onnx_model, load_onnx_session, save_onnx_model
+from msprobe.utils.log import logger
+from msprobe.utils.path import convert_bytes, get_basename_from_path, is_file, join_path
+
+
+class OnnxModelActuator(OfflineModelActuator):
+    def __init__(self, model_path, input_shape, input_path, **kwargs):
+        super().__init__(model_path, input_shape, input_path, **kwargs)
+
+    @staticmethod
+    def infer(uninfer_model_path, input_map):
+        model_session = load_onnx_session(uninfer_model_path)
+        output_name = [node.name for node in model_session.get_outputs()]
+        try:
+            return model_session.run(output_name, input_map)
+        except Exception as e:
+            raise MsprobeException(
+                MsgConst.CALL_FAILED, "Please check if the input shape or data matches the model requirements."
+ ) from e + + def load_model(self): + self.origin_model = load_onnx_model(self.model_path) + self.model_session = load_onnx_session(self.model_path, self.kwargs.get("onnx_fusion_switch", True)) + + def get_input_tensor_info(self): + inputs_tensor_info = [] + for input_item in self.model_session.get_inputs(): + tensor_name, tensor_type, tensor_shape = (input_item.name, input_item.type, tuple(input_item.shape)) + tensor_shape_info = self.process_tensor_shape(tensor_name, tensor_type, tensor_shape) + inputs_tensor_info.extend(tensor_shape_info) + logger.info(f"Model input tensor info: {inputs_tensor_info}.") + return inputs_tensor_info + + def export_uninfer_model(self): + model_name = "inferential_" + get_basename_from_path(self.model_path) + uninfer_model_path = join_path(self.dir_pool.get_model_dir(), model_name) + if not is_file(uninfer_model_path): + onnx = dependent.get("onnx") + del self.origin_model.graph.output[:] + for node in self.origin_model.graph.node: + for tensor_name in node.output: + value_info = onnx.ValueInfoProto(name=tensor_name) + self.origin_model.graph.output.append(value_info) + model_size = self.origin_model.ByteSize() + logger.info(f"The size of the modified ONNX model to be saved is {convert_bytes(model_size)}.") + if model_size < 0 or model_size > PathConst.SIZE_2G: + logger.warning("The modified ONNX model size has exceeded 2GB, posing a risk of numerical overflow.") + save_onnx_model(self.origin_model, uninfer_model_path) + logger.info(f"The modified ONNX model has been successfully saved to {uninfer_model_path}.") + return uninfer_model_path diff --git a/accuracy_tools/msprobe/core/dump/tf_model.py b/accuracy_tools/msprobe/core/dump/tf_model.py new file mode 100644 index 000000000..bf51a554a --- /dev/null +++ b/accuracy_tools/msprobe/core/dump/tf_model.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
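+
+# FrozenGraphActuator runs a frozen-graph (.pb) model in a TF1-style session.
+# FrozenGraphActuatorCPU opens a plain tf.compat.v1.Session, while
+# FrozenGraphActuatorNPU enables the CANN "NpuOptimizer" rewriter so the
+# Ascend plugin dumps per-operator data under dir_pool.get_rank_dir().
+# Fetch tensors are derived from node names by appending ":0" (output 0).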
+ +from glob import glob + +from msprobe.common.ascend import cann +from msprobe.core.base import OfflineModelActuator +from msprobe.utils.constants import MsgConst +from msprobe.utils.dependencies import dependent +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.io import load_pb_frozen_graph_model +from msprobe.utils.log import logger +from msprobe.utils.path import get_name_and_ext, join_path +from msprobe.utils.toolkits import get_net_output_nodes_from_graph_def + + +class FrozenGraphActuator(OfflineModelActuator): + def __init__(self, model_path, input_shape, input_path, **kwargs): + super().__init__(model_path, input_shape, input_path, **kwargs) + self.tf, self.rewriter_config = self._import_tf() + self.sess = None + self.graph_def = None + self.all_node_names = [] + + @staticmethod + def _import_tf(): + pons = dependent.get_tensorflow() + if None not in pons: + tf, rewriter_config, _ = pons + tf.compat.v1.disable_eager_execution() + return tf, rewriter_config + return None, None + + @staticmethod + def _get_tensor_name(name: str): + return name.split(":")[0] + + @staticmethod + def _tf_shape_to_list(tensor_shape): + shape_list = [] + for dim in tensor_shape.dim: + if dim.size == -1: + shape_list.append(None) + else: + shape_list.append(dim.size) + return shape_list + + def close(self): + if self.sess is not None: + try: + self.sess.close() + except AttributeError: + pass + self.sess = None + + def get_input_tensor_info(self): + inputs_tensor_info = [] + for node in self.graph_def.node: + if node.op == "Placeholder": + tensor_name = node.name + tensor_dtype = self.tf.dtypes.as_dtype(node.attr["dtype"].type) + tensor_shape = self._tf_shape_to_list(node.attr["shape"].shape) + inputs_tensor_info.extend(self.process_tensor_shape(tensor_name, tensor_dtype, tensor_shape)) + self.all_node_names.append(node.name) + logger.info(f"Model input tensor info: {inputs_tensor_info}.") + return inputs_tensor_info + + def load_model(self): + self.graph_def = load_pb_frozen_graph_model(self.model_path) + + def infer(self, input_map: dict): + self.sess = self._open_session() + self._renew_all_node_names() + tf_ops = self._get_tf_ops() + feed_dict = self._build_feed(input_map) + try: + outputs = self.sess.run(tf_ops, feed_dict=feed_dict) + except Exception as e: + raise MsprobeException( + MsgConst.CALL_FAILED, "Please check if the input shape or data matches the model requirements." + ) from e + self.close() + return outputs + + def _open_session(self): + return + + def _renew_all_node_names(self): + pass + + def _get_tf_ops(self): + tf_ops = [] + for name in self.all_node_names: + try: + tf_ops.append(self.sess.graph.get_tensor_by_name(name + ":0")) + except Exception as e: + raise MsprobeException( + MsgConst.CALL_FAILED, f'The model lacks the {name + ":0"} node. Please check your model.' + ) from e + return tf_ops + + def _build_feed(self, input_map: dict): + feed_dict = {} + for name, input_data in input_map.items(): + tensor_name = name + ":0" if ":" not in name else name + try: + feed_dict[self.sess.graph.get_tensor_by_name(tensor_name)] = input_data + except Exception as e: + raise MsprobeException( + MsgConst.CALL_FAILED, f"The model lacks the {tensor_name} node. Please check your model." 
+ ) from e + return feed_dict + + +class FrozenGraphActuatorCPU(FrozenGraphActuator): + def __init__(self, model_path, input_shape, input_path, **kwargs): + super().__init__(model_path, input_shape, input_path, **kwargs) + + def _open_session(self): + return self.tf.compat.v1.Session() + + +class FrozenGraphActuatorNPU(FrozenGraphActuator): + def __init__(self, model_path, input_shape, input_path, **kwargs): + super().__init__(model_path, input_shape, input_path, **kwargs) + self.data_mode = kwargs.get("data_mode", ["all"]) + self.fusion_switch_file = kwargs.get("fsf", "") + + def convert_txt2json(self): + model_path = sorted(glob(join_path(self.dir_pool.get_model_dir(), "*", "*_Build.txt"))) + if model_path: + name, _ = get_name_and_ext(model_path[-1]) + cann.model2json(model_path[-1], join_path(self.dir_pool.get_model_dir(), name + ".json")) + else: + raise MsprobeException( + MsgConst.PATH_NOT_FOUND, "No TXT format graph file found in the TensorFlow framework." + ) + + def _open_session(self): + npu_device = dependent.get("npu_device") + if not npu_device: + raise MsprobeException( + MsgConst.ATTRIBUTE_ERROR, "Please ensure that the TF plugin npu_device is properly installed." + ) + npu_device.compat.enable_v1() + config_proto = self.tf.compat.v1.ConfigProto() + custom_op = config_proto.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op.parameter_map["enable_dump"].b = True + custom_op.parameter_map["dump_path"].s = self.tf.compat.as_bytes(self.dir_pool.get_rank_dir()) + custom_op.parameter_map["dump_step"].s = self.tf.compat.as_bytes("0") + custom_op.parameter_map["data_mode"].s = self.tf.compat.as_bytes(self.data_mode[0]) + if self.fusion_switch_file: + logger.info(f"Fusion switch settings read from {self.fusion_switch_file}.") + custom_op.parameter_map["fusion_switch_file"].s = self.tf.compat.as_bytes(self.fusion_switch_file) + config_proto.graph_options.rewrite_options.remapping = self.rewriter_config.OFF + return self.tf.compat.v1.Session(config=config_proto) + + def _renew_all_node_names(self): + self.all_node_names = get_net_output_nodes_from_graph_def(self.graph_def) -- Gitee
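
A minimal usage sketch, assuming the modules above are importable and that None
placeholders for input_shape and input_path are accepted by the
OfflineModelActuator base class (not included in this patch):

    import numpy as np

    from msprobe.core.dump import OnnxModelActuator

    actuator = OnnxModelActuator("model.onnx", input_shape=None, input_path=None)
    actuator.load_model()
    print(actuator.get_input_tensor_info())
    # "input" and the shape are placeholders; take real names and shapes from
    # the tensor info printed above.
    outputs = OnnxModelActuator.infer("model.onnx", {"input": np.zeros((1, 3, 224, 224), np.float32)})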