diff --git a/msprobe/core/config_initiator/__init__.py b/msprobe/core/config_initiator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..492a36cd5dbc0292bf326560d547c074e2458c3e --- /dev/null +++ b/msprobe/core/config_initiator/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.config_initiator.config_dump import DumpConfig diff --git a/msprobe/core/config_initiator/config_dump.py b/msprobe/core/config_initiator/config_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..1f360cd0761f1e2226b960f1191f8bf6d382fe9f --- /dev/null +++ b/msprobe/core/config_initiator/config_dump.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class DumpConfig(BaseConfig):
    """Configuration checker for the dump task.

    Every supported dump option is read from ``self.task_config`` (falling back to a
    default when the key is absent) and handed to the inherited ``_update_config``
    together with the validator responsible for that option.
    """

    def check_config(self, dump_path: str = None):
        """Validate the dump options and return the whole configuration.

        Args:
            dump_path: Optional dump directory. When truthy it is used instead of the
                path stored in the task configuration.

        Returns:
            ``self.config``, in which the entry whose key is the value stored under
            ``CfgConst.TASK`` has been replaced by the checked task configuration.
        """
        checked_task = self._check_dump_dic(dump_path)
        self.config[self.config.get(CfgConst.TASK)] = checked_task
        return self.config

    def _check_dump_dic(self, dump_path: str = None):
        """Run every dump option through its validator, in a fixed order.

        Args:
            dump_path: Optional override for the dump path option.

        Returns:
            ``self.task_config`` after all options have been passed to ``_update_config``.
        """
        # The dump path is special: an explicitly supplied path wins over the configured one.
        chosen_path = dump_path or self.task_config.get(DumpConst.DUMP_PATH, "./")
        self._update_config(self.task_config, DumpConst.DUMP_PATH, valid_dump_path, chosen_path)

        # (option key, validator, default used when the key is absent), checked top to bottom.
        remaining_options = (
            (DumpConst.LIST, valid_list, []),
            (DumpConst.DATA_MODE, valid_data_mode, ["all"]),
            (DumpConst.SUMMARY_MODE, valid_summary_mode, CfgConst.TASK_STAT),
            (DumpConst.DUMP_EXTRA, valid_dump_extra, []),
            (DumpConst.OP_ID, valid_op_id, []),
            (DumpConst.DUMP_GE_GRAPH, valid_dump_ge_graph, "2"),
            (DumpConst.DUMP_GRAPH_LEVEL, valid_dump_graph_level, "3"),
            (DumpConst.FUSION_SWITCH_FILE, valid_fusion_switch_file, None),
            (DumpConst.DEVICE, valid_device, None),
            (DumpConst.INPUT, valid_input, []),
            (DumpConst.ONNX_FUSION_switch, valid_onnx_fusion_switch, True),
            (DumpConst.SAVED_MODEL_TAG, valid_saved_model_tag, ["serve"]),
            (DumpConst.SAVED_MODEL_SIGN, valid_saved_model_signature, "serving_default"),
        )
        for option_key, validator, fallback in remaining_options:
            current = self.task_config.get(option_key, fallback)
            self._update_config(self.task_config, option_key, validator, current)
        return self.task_config
# An op id string is one to ten groups of 1-10 digits joined by underscores, e.g. "2", "3_1", "3_1_2".
_OP_ID_PATTERN = r"^\d{1,10}(_\d{1,10}){0,9}$"
# Device names accepted by valid_device.
_ALL_DEVICE = {"cpu", "npu"}


def valid_dump_path(value: str):
    """Return the result of ``SafePath(value, PathConst.DIR, "w").check()``."""
    return SafePath(value, PathConst.DIR, "w").check()


def _check_common_list(value: list, tag: str = None, limit=None):
    """Validate a list-typed option.

    Args:
        value: The option value. ``None`` means "not set" and is returned unchanged.
        tag: Option name used in error messages.
        limit: Optional collection of allowed elements; when falsy, elements are not restricted.

    Returns:
        ``value`` itself once it has passed all checks.

    Raises:
        MsprobeException: If ``value`` is not a list, or an element is not in ``limit``.
    """
    if value is None:
        return value
    # Type check comes before any truthiness shortcut so that "", 0 or {} cannot slip through.
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"'{tag}' must be a list, got {type(value)} instead.")
    for item in value:
        check_special_char(item)
        if limit and item not in limit:
            raise MsprobeException(MsgConst.INVALID_ARGU, f"'{tag}' must be one of {limit}, currently: {item}.")
    return value


def valid_list(value: list):
    """Validate the ``list`` option as an unrestricted list."""
    return _check_common_list(value, tag="list")


def valid_data_mode(value: list):
    """Validate the ``data_mode`` option against ``DumpConst.ALL_DATA_MODE``."""
    return _check_common_list(value, tag="data_mode", limit=DumpConst.ALL_DATA_MODE)


def valid_dump_extra(values: list):
    """Validate the ``dump_extra`` option against ``DumpConst.ALL_DUMP_EXTRA``."""
    return _check_common_list(values, tag="dump_extra", limit=DumpConst.ALL_DUMP_EXTRA)


def valid_op_id(value: list):
    """Validate the ``op_id`` option.

    Each element must be an integer (checked with ``check_int_border``) or a string
    that matches ``_OP_ID_PATTERN`` in full.

    Args:
        value: The option value. ``None`` and an empty list are returned unchanged.

    Returns:
        A new list holding the accepted elements in their original order.

    Raises:
        MsprobeException: If ``value`` is not a list or an element has an unsupported form.
    """
    if value is None:
        return value
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"op_id" must be a list.')
    if not value:
        return value
    res = []
    for element in value:
        # bool is a subclass of int, but True/False are not meaningful op ids.
        if isinstance(element, int) and not isinstance(element, bool):
            check_int_border(element, tag="the integer part of op_id")
            res.append(element)
        # fullmatch: with re.match, "$" would also accept a trailing newline ("3_1\n").
        elif isinstance(element, str) and re.fullmatch(_OP_ID_PATTERN, element):
            res.append(element)
        else:
            raise MsprobeException(
                MsgConst.INVALID_DATA_TYPE,
                '"op_id" is only supported in the ATB dump scenario, '
                f"with formats like 2, 3_1, or 3_1_2, currently: {element}.",
            )
    return res


def valid_dump_ge_graph(value: str):
    """Validate ``dump_ge_graph``: ``None`` or a string listed in ``DumpConst.ALL_DUMP_GE_GRAPH``."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_ge_graph" must be a string.')
    if value not in DumpConst.ALL_DUMP_GE_GRAPH:
        raise MsprobeException(
            MsgConst.INVALID_ARGU, f'"dump_ge_graph" must be one of {DumpConst.ALL_DUMP_GE_GRAPH}, currently: {value}.'
        )
    return value


def valid_dump_graph_level(value: str):
    """Validate ``dump_graph_level``: ``None`` or a string listed in ``DumpConst.ALL_DUMP_GRAPH_LEVEL``."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_graph_level" must be a string.')
    if value not in DumpConst.ALL_DUMP_GRAPH_LEVEL:
        raise MsprobeException(
            MsgConst.INVALID_ARGU,
            f'"dump_graph_level" must be one of {DumpConst.ALL_DUMP_GRAPH_LEVEL}, currently: {value}.',
        )
    return value


def valid_fusion_switch_file(value: str):
    """Return ``None`` unchanged, otherwise the result of the SafePath check for a .json/.cfg file."""
    if value is None:
        return value
    return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_500M, (".json", ".cfg")).check()


def valid_device(value: str):
    """Validate ``device``: ``None`` or one of the names in ``_ALL_DEVICE``."""
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"device" must be a string.')
    if value not in _ALL_DEVICE:
        raise MsprobeException(MsgConst.INVALID_ARGU, f'"device" must be one of {_ALL_DEVICE}, currently: {value}.')
    return value


def valid_input(value: list):
    """Validate the ``input`` option of an offline model.

    Args:
        value: ``None`` or an empty list (both returned unchanged), or a list of
            input descriptions handled by ``OfflineModelInput``.

    Returns:
        The untouched ``value`` when it is ``None`` or an empty list, otherwise the
        ``(shapes, paths)`` pair produced by ``OfflineModelInput.parse``.

    Raises:
        MsprobeException: If the value or one of its entries is malformed.
    """
    if value is None or (isinstance(value, list) and not value):
        return value
    # Anything else, including falsy non-lists such as "" or 0, goes through the form check.
    return OfflineModelInput(value).parse()


class OfflineModelInput:
    """Checks the ``input`` descriptions of an offline model and extracts shapes and paths.

    Each description is a dict with a mandatory ``"name"`` and optional ``"shape"``,
    ``"path"`` and ``"dyn_shape"`` entries.
    """

    def __init__(self, input_list):
        """Store ``input_list`` after verifying that it is a list of dicts."""
        self.input_list = input_list
        self._check_form()
        # Becomes True as soon as one input declares a "dyn_shape".
        self.is_need_expand_shape = False

    @staticmethod
    def _check_name(infile: dict):
        """Return the input's name, raising MsprobeException when it is missing or empty."""
        name = infile.get("name")
        if not name:
            raise MsprobeException(MsgConst.PARSING_FAILED, "Each input must have a name.")
        return name

    @staticmethod
    def _check_input_shape(infile: dict, name):
        """Check that a provided ``"shape"`` is a list whose elements pass ``check_int_border``."""
        inshape = infile.get("shape")
        if not inshape:
            return
        if not isinstance(inshape, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"shape" of the input {name} must be a list.')
        for dim in inshape:
            check_int_border(dim, tag=f'Elements in "shape" of the input {name}')

    @staticmethod
    def _check_input_path(infile: dict, name):
        """Check that a provided ``"path"`` is a .npy/.bin string accepted by the SafePath check."""
        inpath = infile.get("path")
        if not inpath:
            return
        if not isinstance(inpath, str):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"path" of the input {name} must be a string.')
        if not inpath.endswith((".bin", ".npy")):
            raise MsprobeException(
                MsgConst.INVALID_ARGU, f'"path" of {name} can only accept .npy or .bin files, currently: {inpath}.'
            )
        SafePath(inpath, PathConst.FILE, "r", PathConst.SIZE_10G).check()

    @staticmethod
    def _parse_shape_range_for_str(shape):
        """Turn one string element of a ``dyn_shape`` into a list of candidate values.

        A string containing a hyphen is delegated to ``parse_hyphen``; a string with
        exactly one comma is split into two integers.

        Raises:
            MsprobeException: If the string has neither form or a comma part is not an integer.
        """
        if "-" in shape:
            return parse_hyphen(shape, tag="Elements in a dynamic shape")
        if shape.count(",") == 1:
            try:
                return [int(part) for part in shape.split(",")]
            except ValueError as e:
                # The original message blamed the hyphen although this is the comma form.
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    f"Both sides of the comma (,) in the input must be numbers, currently: {shape}.",
                ) from e
        raise MsprobeException(
            MsgConst.INVALID_ARGU, 'The "dyn_shape" of the input can only contain hyphen (-) or a comma (,).'
        )

    def parse(self):
        """Check every input description and collect shapes and paths.

        Returns:
            A ``(shapes, paths)`` tuple. Without any ``dyn_shape``, ``shapes`` maps each
            input name to its ``"shape"`` and ``paths`` lists the non-empty ``"path"``
            values. With ``dyn_shape``, ``shapes`` is a list of name-to-shape dicts
            (one per expanded combination) and ``paths`` is ``None``.
        """
        logger.info("Start parsing the input list.")
        modify_file = []
        for infile in self.input_list:
            name = self._check_name(infile)
            self._check_input_shape(infile, name)
            self._check_input_path(infile, name)
            modify_file.append(self._check_dyn_shape(infile, name))
        return self._draw_shape_and_path(modify_file)

    def _parse_dyn_shape_range(self, shapes, name):
        """Expand a ``dyn_shape`` description into every concrete shape it stands for.

        Args:
            shapes: List whose elements are integers (a fixed dimension) or strings
                understood by ``_parse_shape_range_for_str``.
            name: Input name, used in error messages.

        Returns:
            The Cartesian product of the per-dimension candidates, as a list of lists.

        Raises:
            MsprobeException: If ``shapes`` is not a list or an element has an unsupported type.
        """
        if not isinstance(shapes, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"dyn_shape" of the input {name} must be a list.')
        shapes_list = []
        for shape in shapes:
            if isinstance(shape, str):
                ranges = self._parse_shape_range_for_str(shape)
            # bool is a subclass of int; True/False are not valid dimensions.
            elif isinstance(shape, int) and not isinstance(shape, bool):
                check_int_border(shape, tag="Integer in a dynamic shape")
                ranges = [shape]
            else:
                raise MsprobeException(
                    MsgConst.INVALID_DATA_TYPE,
                    f'Elements in "dyn_shape" of the input support only string and integers, currently: {shape}.',
                )
            shapes_list.append(ranges)
        return [list(combo) for combo in product(*shapes_list)]

    def _check_form(self):
        """Raise MsprobeException unless ``self.input_list`` is a list of dicts."""
        if not isinstance(self.input_list, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "The input must be a list.")
        for entry in self.input_list:
            if not isinstance(entry, dict):
                raise MsprobeException(
                    MsgConst.INVALID_DATA_TYPE, "Each element in the input must be a dictionary."
                )

    def _check_dyn_shape(self, infile: dict, name: str):
        """Expand ``"dyn_shape"`` in place when present and neutralize ``"shape"``/``"path"``.

        Returns:
            The same ``infile`` dict, possibly modified.
        """
        if infile.get("dyn_shape"):
            self.is_need_expand_shape = True
            infile["dyn_shape"] = self._parse_dyn_shape_range(infile["dyn_shape"], name)
            infile["shape"] = []
            if infile.get("path"):
                infile["path"] = ""
            logger.warning('Since "dyn_shape" is used, "shape" and "path" will not take effect.')
        return infile

    def _draw_shape_and_path(self, modify_file):
        """Build the ``(shapes, paths)`` result from the checked input descriptions.

        Raises:
            MsprobeException: In dynamic mode, if an input has no ``"dyn_shape"`` or the
                inputs expand to different numbers of shapes.
        """
        if not self.is_need_expand_shape:
            shapes, paths = {}, []
            for item in modify_file:
                shapes[item["name"]] = item.get("shape")
                if item.get("path"):
                    paths.append(item["path"])
            return shapes, paths

        # Mixing inputs with and without "dyn_shape" used to surface as a bare KeyError.
        missing = [item["name"] for item in modify_file if item.get("dyn_shape") is None]
        if missing:
            raise MsprobeException(
                MsgConst.INVALID_ARGU,
                f'Once any input uses "dyn_shape", every input must define it, missing for: {missing}.',
            )
        dyn_shapes = [item["dyn_shape"] for item in modify_file]
        expected = len(dyn_shapes[0])
        if any(len(expanded) != expected for expanded in dyn_shapes):
            raise MsprobeException(
                MsgConst.INVALID_ARGU, "Ensure all inputs have the same expanded dynamic shape length."
            )
        names = [item["name"] for item in modify_file]
        # The i-th expanded shape of every input is combined into the i-th name-to-shape dict.
        shapes = [dict(zip(names, combo)) for combo in zip(*dyn_shapes)]
        return shapes, None


def valid_onnx_fusion_switch(value: bool):
    """Validate ``onnx_fusion_switch``: ``None`` or a real boolean.

    Raises:
        MsprobeException: If ``value`` is neither ``None`` nor a ``bool`` (this includes
            falsy non-booleans such as ``0`` or ``""``).
    """
    if value is None:
        return value
    if not isinstance(value, bool):
        raise MsprobeException(
            MsgConst.INVALID_DATA_TYPE, f'"onnx_fusion_switch" must be a boolean, currently: {value}.'
        )
    return value


def valid_saved_model_tag(value: list):
    """Validate the ``saved_model_tag`` option as an unrestricted list."""
    return _check_common_list(value, tag="saved_model_tag")


def valid_saved_model_signature(value: str):
    """Return ``None`` unchanged, otherwise ``value`` after it passed ``check_special_char``."""
    if value is None:
        return value
    check_special_char(value)
    return value


def valid_summary_mode(value: str):
    """Validate ``summary_mode``: ``None`` or a member of ``DumpConst.ALL_SUMMARY_MODE``."""
    if value is None:
        return value
    if value not in DumpConst.ALL_SUMMARY_MODE:
        raise MsprobeException(
            MsgConst.INVALID_ARGU,
            f'"summary_mode" must be one of {DumpConst.ALL_SUMMARY_MODE}, currently: {value}.',
        )
    return value