From 64daec30337699402ecb11648cbc06c2308a8e45 Mon Sep 17 00:00:00 2001 From: kai-ma Date: Fri, 4 Jul 2025 12:19:09 +0800 Subject: [PATCH] add config_dump and valid_params --- .../msprobe/core/config_initiator/__init__.py | 15 + .../core/config_initiator/config_dump.py | 113 ++++++ .../core/config_initiator/validate_params.py | 325 ++++++++++++++++++ 3 files changed, 453 insertions(+) create mode 100644 accuracy_tools/msprobe/core/config_initiator/__init__.py create mode 100644 accuracy_tools/msprobe/core/config_initiator/config_dump.py create mode 100644 accuracy_tools/msprobe/core/config_initiator/validate_params.py diff --git a/accuracy_tools/msprobe/core/config_initiator/__init__.py b/accuracy_tools/msprobe/core/config_initiator/__init__.py new file mode 100644 index 00000000000..492a36cd5db --- /dev/null +++ b/accuracy_tools/msprobe/core/config_initiator/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.core.config_initiator.config_dump import DumpConfig diff --git a/accuracy_tools/msprobe/core/config_initiator/config_dump.py b/accuracy_tools/msprobe/core/config_initiator/config_dump.py new file mode 100644 index 00000000000..80df48e6cf6 --- /dev/null +++ b/accuracy_tools/msprobe/core/config_initiator/config_dump.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. 
class DumpConfig(BaseConfig):
    """Checks the dump-task section of a configuration.

    Each supported dump option is read from ``self.task_config`` (with a
    fallback default when the key is absent) and handed, together with its
    validator from ``validate_params``, to ``self._update_config``, which is
    defined outside this class body.
    """

    def check_config(self, dump_path: str = None):
        """Validate the dump options and store them under the task name.

        Args:
            dump_path: Optional dump directory. When truthy it takes
                precedence over the ``dump_path`` entry of the task config.

        Returns:
            ``self.config``, whose entry keyed by the configured task name is
            set to the checked task configuration.
        """
        # The options are checked before the task name is looked up, which is
        # the evaluation order of a subscript assignment.
        checked = self._check_dump_dic(dump_path)
        task_name = self.config.get(CfgConst.TASK)
        self.config[task_name] = checked
        return self.config

    def _check_dump_dic(self, dump_path: str = None):
        """Run every dump option through its validator, in a fixed order.

        Args:
            dump_path: Optional override for the configured dump path.

        Returns:
            ``self.task_config`` after all options have been processed.
        """
        # The first two options need more than a plain "get with default".
        self._update_config(
            self.task_config,
            DumpConst.DUMP_PATH,
            valid_dump_path,
            dump_path or self.task_config.get(DumpConst.DUMP_PATH, "./"),
        )
        self._update_config(
            self.task_config,
            DumpConst.LIST,
            valid_list,
            (self.task_config.get(DumpConst.LIST, []), self.config.get(CfgConst.LEVEL)),
        )

        # (option key, validator, default). The table is rebuilt on every call
        # so that each mutable default is a fresh object.
        simple_options = (
            (DumpConst.DATA_MODE, valid_data_mode, ["all"]),
            (DumpConst.SUMMARY_MODE, valid_summary_mode, CfgConst.TASK_STAT),
            (DumpConst.DUMP_EXTRA, valid_dump_extra, []),
            (DumpConst.OP_ID, valid_op_id, []),
            (DumpConst.DUMP_GE_GRAPH, valid_dump_ge_graph, "2"),
            (DumpConst.DUMP_GRAPH_LEVEL, valid_dump_graph_level, "3"),
            (DumpConst.FUSION_SWITCH_FILE, valid_fusion_switch_file, None),
            (DumpConst.DEVICE, valid_device, None),
            (DumpConst.INPUT, valid_input, []),
            (DumpConst.ONNX_FUSION_switch, valid_onnx_fusion_switch, True),
            (DumpConst.SAVED_MODEL_TAG, valid_saved_model_tag, ["serve"]),
            (DumpConst.SAVED_MODEL_SIGN, valid_saved_model_signature, "serving_default"),
            (DumpConst.WEIGHT_PATH, valid_weight_path, None),
        )
        for option, validator, fallback in simple_options:
            self._update_config(
                self.task_config,
                option,
                validator,
                self.task_config.get(option, fallback),
            )
        return self.task_config
_OP_ID_PATTERN = r"^\d{1,10}(_\d{1,10}){0,9}$"
_ALL_DEVICE = {"cpu", "npu"}
# The hyphen is written last so that it is a literal character. Written as
# ".-:" it formed a range from "." to ":" which rejected "-" and silently
# accepted "/". The slash is listed explicitly so that values accepted by the
# earlier pattern keep validating.
_VALID_CHAR = r"^[a-zA-Z0-9_./:-]+$"


def check_special_char(value: str):
    """Raise unless *value* is a non-empty string of whitelisted characters.

    Allowed characters are ASCII letters, digits and ``_ . / : -``.

    Raises:
        MsprobeException: If *value* is not a string or contains any other
            character (including a trailing newline).
    """
    # fullmatch is used because "$" with re.match also matches just before a
    # trailing newline.
    if not (isinstance(value, str) and re.fullmatch(_VALID_CHAR, value)):
        raise MsprobeException(MsgConst.RISK_ALERT, f"Invalid input: contains unsafe characters: {value}.")


def valid_dump_path(value: str):
    """Return the result of checking *value* as a writable directory via ``SafePath``."""
    return SafePath(value, PathConst.DIR, "w").check()


def valid_list(value: tuple):
    """Normalise the ``list`` option into a ``{level: [names]}`` mapping.

    Args:
        value: A ``(names, levels)`` pair. ``names`` is either a list of names
            (allowed only when exactly one level is configured), a dict
            mapping a level to a list of names, or empty/``None``.

    Returns:
        A dict keyed by level. An empty ``names`` is attached unchanged to
        every configured level (``None`` is treated as an empty list).

    Raises:
        MsprobeException: On an unknown level key, a non-list dict value, an
            unsafe name, or a list combined with zero or several levels.
    """
    names, levels = value[0], value[1]
    if names is None:
        names = []
    if levels is None:
        levels = []

    if not names or (isinstance(names, list) and len(levels) == 1):
        # The names are identical for every level, so validate them once.
        for name in names:
            check_special_char(name)
        return {level: names for level in levels}

    if isinstance(names, dict):
        for key, entries in names.items():
            if key not in levels:
                raise MsprobeException(MsgConst.INVALID_ARGU, f"Key not in allowed list {levels}, currently: {key}.")
            if not isinstance(entries, list):
                raise MsprobeException(
                    MsgConst.INVALID_DATA_TYPE, f"Value must be a list, got {type(entries)} instead."
                )
            for entry in entries:
                check_special_char(entry)
        return names

    raise MsprobeException(
        MsgConst.INVALID_DATA_TYPE,
        """The list parameter supports two types:
        1. List, which requires "level" to be set with only one element.
        2. Dictionary, which allows "level" to be set with multiple elements.""",
    )


def valid_data_mode(value: list):
    """Return *value* if it is empty or a one-item list of a known data mode.

    Raises:
        MsprobeException: If *value* is not a list, has more than one item, or
            its item is not in ``DumpConst.ALL_DATA_MODE``.
    """
    if not value:
        return value
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"data_mode" must be a list.')
    if len(value) != 1:
        raise MsprobeException(MsgConst.INVALID_ARGU, '"data_mode" only accepts a single-item list.')
    if value[0] not in DumpConst.ALL_DATA_MODE:
        raise MsprobeException(
            MsgConst.INVALID_ARGU, f'"data_mode" must be one of {DumpConst.ALL_DATA_MODE}, currently: {value[0]}.'
        )
    return value


def valid_dump_extra(values: list):
    """Return *values* if it is empty or a list drawn from ``DumpConst.ALL_DUMP_EXTRA``.

    Raises:
        MsprobeException: If *values* is not a list or holds an unknown item.
    """
    if not values:
        return values
    if not isinstance(values, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_extra" must be a list.')
    for value in values:
        if value not in DumpConst.ALL_DUMP_EXTRA:
            raise MsprobeException(
                MsgConst.INVALID_ARGU, f'"dump_extra" must be one of {DumpConst.ALL_DUMP_EXTRA}, currently: {value}.'
            )
    return values


def valid_op_id(value: list):
    """Validate a list of operator ids.

    Each element is either an integer (checked with ``check_int_border``) or a
    string of up to ten underscore-separated numbers such as ``"3_1_2"``.

    Returns:
        A new list with the accepted elements, or *value* itself when empty.

    Raises:
        MsprobeException: If *value* is not a list or an element is malformed.
    """
    if not value:
        return value
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"op_id" must be a list.')
    res = []
    for element in value:
        if isinstance(element, int):
            check_int_border(element, tag="the integer part of op_id")
        elif not (isinstance(element, str) and re.fullmatch(_OP_ID_PATTERN, element)):
            raise MsprobeException(
                MsgConst.INVALID_DATA_TYPE,
                '"op_id" is only supported in the ATB dump scenario, '
                f"with formats like 2, 3_1, or 3_1_2, currently: {element}.",
            )
        res.append(element)
    return res


def valid_dump_ge_graph(value: str):
    """Return *value* if it is ``None`` or a string in ``DumpConst.ALL_DUMP_GE_GRAPH``.

    Raises:
        MsprobeException: If *value* is not a string or not an allowed value.
    """
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_ge_graph" must be a string.')
    if value not in DumpConst.ALL_DUMP_GE_GRAPH:
        raise MsprobeException(
            MsgConst.INVALID_ARGU, f'"dump_ge_graph" must be one of {DumpConst.ALL_DUMP_GE_GRAPH}, currently: {value}.'
        )
    return value


def valid_dump_graph_level(value: str):
    """Return *value* if it is ``None`` or a string in ``DumpConst.ALL_DUMP_GRAPH_LEVEL``.

    Raises:
        MsprobeException: If *value* is not a string or not an allowed value.
    """
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"dump_graph_level" must be a string.')
    if value not in DumpConst.ALL_DUMP_GRAPH_LEVEL:
        raise MsprobeException(
            MsgConst.INVALID_ARGU,
            f'"dump_graph_level" must be one of {DumpConst.ALL_DUMP_GRAPH_LEVEL}, currently: {value}.',
        )
    return value


def valid_fusion_switch_file(value: str):
    """Return ``None`` unchanged, otherwise the ``SafePath`` check result for a readable .json/.cfg file."""
    if value is None:
        return value
    return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_500M, (".json", ".cfg")).check()


def valid_device(value: str):
    """Return *value* if it is ``None`` or one of ``"cpu"`` / ``"npu"``.

    Raises:
        MsprobeException: If *value* is not a string or not a known device.
    """
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"device" must be a string.')
    if value not in _ALL_DEVICE:
        raise MsprobeException(MsgConst.INVALID_ARGU, f'"device" must be one of {_ALL_DEVICE}, currently: {value}.')
    return value


def valid_input(value: list):
    """Return an empty *value* unchanged, otherwise ``OfflineModelInput(value).parse()``."""
    if not value:
        return value
    return OfflineModelInput(value).parse()


class OfflineModelInput:
    """Parses the ``input`` option: a list of dicts describing model inputs.

    Each dict must carry a ``name`` and may carry ``shape`` (list of ints),
    ``path`` (a .bin/.npy file) or ``dym_shape`` (list of ints and/or strings
    describing candidate values per dimension).
    """

    def __init__(self, input_list):
        """Store *input_list* after checking that it is a list of dicts."""
        self.input_list = input_list
        self._check_form()
        # Set while parsing as soon as one input declares "dym_shape".
        self.is_need_expand_shape = False

    @staticmethod
    def _check_name(infile: dict):
        """Return the input's ``name``; raise if it is missing or empty."""
        name = infile.get("name")
        if not name:
            raise MsprobeException(MsgConst.PARSING_FAILED, "Each input must have a name.")
        return name

    @staticmethod
    def _check_input_shape(infile: dict, name):
        """Check that an optional ``shape`` is a list whose items pass ``check_int_border``."""
        inshape = infile.get("shape")
        if not inshape:
            return
        if not isinstance(inshape, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"shape" of the input {name} must be a list.')
        for dim in inshape:
            check_int_border(dim, tag=f'Elements in "shape" of the input {name}')

    @staticmethod
    def _check_input_path(infile: dict, name):
        """Check that an optional ``path`` is a string naming a .bin/.npy file accepted by ``SafePath``."""
        inpath = infile.get("path")
        if not inpath:
            return
        if not isinstance(inpath, str):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"path" of the input {name} must be a string.')
        if not inpath.endswith((".bin", ".npy")):
            raise MsprobeException(
                MsgConst.INVALID_ARGU, f'"path" of {name} can only accept .npy or .bin files, currently: {inpath}.'
            )
        SafePath(inpath, PathConst.FILE, "r", PathConst.SIZE_10G).check()

    @staticmethod
    def _parse_shape_range_for_str(shape):
        """Turn one string element of ``dym_shape`` into a list of candidate values.

        A string containing ``-`` is delegated to ``parse_hyphen``; a string
        with exactly one comma is split into its two integers.

        Raises:
            MsprobeException: If the string has neither form or the comma
                separated parts are not integers.
        """
        if "-" in shape:
            return parse_hyphen(shape, tag="Elements in a dynamic shape")
        if shape.count(",") == 1:
            try:
                return [int(part) for part in shape.split(",")]
            except ValueError as e:
                raise MsprobeException(
                    MsgConst.INVALID_ARGU,
                    f"Both sides of the comma (,) in the input must be integers, currently: {shape}.",
                ) from e
        raise MsprobeException(
            MsgConst.INVALID_ARGU, 'The "dym_shape" of the input can only contain hyphen (-) or a comma (,).'
        )

    def parse(self):
        """Validate every input and return ``(shapes, paths)``.

        Returns:
            Without dynamic shapes: a ``{name: shape}`` dict and the list of
            non-empty paths. With dynamic shapes: a list of ``{name: shape}``
            dicts (one per expanded combination) and ``None``.
        """
        logger.info("Start parsing the input list.")
        # Start from a clean state so that parse() can be called repeatedly.
        self.is_need_expand_shape = False
        checked_inputs = []
        for infile in self.input_list:
            # Work on a shallow copy: expanding "dym_shape" rewrites entries
            # and must not alter the caller's configuration.
            entry = dict(infile)
            name = self._check_name(entry)
            self._check_input_shape(entry, name)
            self._check_input_path(entry, name)
            checked_inputs.append(self._check_dym_shape(entry, name))
        return self._draw_shape_and_path(checked_inputs)

    def _parse_dym_shape_range(self, shapes, name):
        """Expand a ``dym_shape`` list into every concrete shape it describes.

        Each element contributes a list of candidate values; the result is
        the Cartesian product of those lists, as a list of lists.

        Raises:
            MsprobeException: If *shapes* is not a list or an element is
                neither a string nor an integer.
        """
        if not isinstance(shapes, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f'"dym_shape" of the input {name} must be a list.')
        candidates = []
        for shape in shapes:
            if isinstance(shape, str):
                ranges = self._parse_shape_range_for_str(shape)
            elif isinstance(shape, int):
                check_int_border(shape, tag="Integer in a dynamic shape")
                ranges = [shape]
            else:
                raise MsprobeException(
                    MsgConst.INVALID_DATA_TYPE,
                    f'Elements in "dym_shape" of the input support only string and integers, currently: {shape}.',
                )
            candidates.append(ranges)
        return [list(combo) for combo in product(*candidates)]

    def _check_form(self):
        """Raise unless ``self.input_list`` is a list whose elements are all dicts."""
        if not isinstance(self.input_list, list):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "The input must be a list.")
        for element in self.input_list:
            if not isinstance(element, dict):
                raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Each element in the input must be a dictionary.")

    def _check_dym_shape(self, infile: dict, name: str):
        """Expand ``dym_shape`` in place when present and blank out ``shape``/``path``.

        Returns:
            The same *infile* dict.
        """
        if infile.get("dym_shape"):
            self.is_need_expand_shape = True
            infile["dym_shape"] = self._parse_dym_shape_range(infile["dym_shape"], name)
            infile["shape"] = []
            if infile.get("path"):
                infile["path"] = ""
            logger.warning('Since "dym_shape" is used, "shape" and "path" will not take effect.')
        return infile

    def _draw_shape_and_path(self, modify_file):
        """Collect the shapes and paths of the checked inputs.

        Raises:
            MsprobeException: In dynamic-shape mode, if an input has no
                ``dym_shape`` or the inputs expand to different numbers of
                shapes.
        """
        if not self.is_need_expand_shape:
            shapes, paths = {}, []
            for item in modify_file:
                shapes[item["name"]] = item.get("shape")
                if item.get("path"):
                    paths.append(item["path"])
            return shapes, paths

        # Mixing dynamic and static inputs used to surface as a bare KeyError.
        missing = [item["name"] for item in modify_file if "dym_shape" not in item]
        if missing:
            raise MsprobeException(
                MsgConst.INVALID_ARGU,
                f'When "dym_shape" is used, every input must provide it, missing for: {missing}.',
            )
        dym_shapes = [item["dym_shape"] for item in modify_file]
        expected = len(dym_shapes[0])
        if any(len(expanded) != expected for expanded in dym_shapes):
            raise MsprobeException(
                MsgConst.INVALID_ARGU, "Ensure all inputs have the same expanded dynamic shape length."
            )
        names = [item["name"] for item in modify_file]
        # The i-th expanded shape of every input forms the i-th combination.
        shapes = [dict(zip(names, combo)) for combo in zip(*dym_shapes)]
        return shapes, None


def valid_onnx_fusion_switch(value: bool):
    """Return *value* if it is ``None`` or a boolean.

    Raises:
        MsprobeException: For any other value, including falsy non-booleans
            such as ``0`` or ``""``.
    """
    if value is None:
        return value
    if not isinstance(value, bool):
        raise MsprobeException(
            MsgConst.INVALID_DATA_TYPE, f'"onnx_fusion_switch" must be a boolean, currently: {value}.'
        )
    return value


def valid_saved_model_tag(value: list):
    """Return *value* if it is empty or a list of strings made of safe characters.

    Raises:
        MsprobeException: If *value* is not a list or an item is unsafe.
    """
    if not value:
        return value
    if not isinstance(value, list):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "saved_model_tag must be a list.")
    for tag in value:
        check_special_char(tag)
    return value


def valid_saved_model_signature(value: str):
    """Return *value* if it is ``None`` or a string made of safe characters."""
    if value is None:
        return value
    check_special_char(value)
    return value


def valid_weight_path(value: str):
    """Return ``None`` unchanged, otherwise the ``SafePath`` check result for a readable .caffemodel file."""
    if value is None:
        return value
    return SafePath(value, PathConst.FILE, "r", PathConst.SIZE_50G, ".caffemodel").check()


def valid_summary_mode(value: str):
    """Return *value* if it is ``None`` or a string in ``DumpConst.ALL_SUMMARY_MODE``.

    Raises:
        MsprobeException: If *value* is not a string or not an allowed mode.
    """
    if value is None:
        return value
    if not isinstance(value, str):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"summary_mode" must be a string.')
    if value not in DumpConst.ALL_SUMMARY_MODE:
        raise MsprobeException(
            MsgConst.INVALID_ARGU,
            f'"summary_mode" must be one of {DumpConst.ALL_SUMMARY_MODE}, currently: {value}.',
        )
    return value