diff --git a/accuracy_tools/msprobe/utils/log.py b/accuracy_tools/msprobe/utils/log.py new file mode 100644 index 0000000000000000000000000000000000000000..965f229563685b5b9403175b82fd0918a79e4d0a --- /dev/null +++ b/accuracy_tools/msprobe/utils/log.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.lib.msprobe_c import log + +_STAR = "*" +_DEBUG = "DEBUG" +_INFO = "INFO" +_WARNING = "WARNING" +_ERROR = "ERROR" +_TOTAL_CHAR_LENGTH = 80 +LOG_LEVEL = [_DEBUG, _INFO, _WARNING, _ERROR] + + +class Logger: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(Logger, cls).__new__(cls) + return cls._instance + + @staticmethod + def get_level_id(level: str): + if level.upper() in LOG_LEVEL: + return LOG_LEVEL.index(level.upper()) + else: + return LOG_LEVEL.index(LOG_LEVEL[1]) + + @staticmethod + def error(msg): + log.print_log(LOG_LEVEL.index(_ERROR), msg) + + @staticmethod + def warning(msg): + log.print_log(LOG_LEVEL.index(_WARNING), msg) + + @staticmethod + def info(msg): + log.print_log(LOG_LEVEL.index(_INFO), msg) + + @staticmethod + def debug(msg): + log.print_log(LOG_LEVEL.index(_DEBUG), msg) + + def set_level(self, level: str): + level_id = self.get_level_id(level) + log.set_log_level(level_id) + + +logger = Logger() + + +def print_log_with_star(info_message: str): + total_length = _TOTAL_CHAR_LENGTH + logger.info(_STAR * total_length) + logger.info(f"{_STAR}{info_message.center(total_length - 2)}{_STAR}") + logger.info(_STAR * total_length) diff --git a/accuracy_tools/msprobe/utils/path.py b/accuracy_tools/msprobe/utils/path.py new file mode 100644 index 0000000000000000000000000000000000000000..173e30a4ceb18cf261557d3b18cc126b9468a140 --- /dev/null +++ b/accuracy_tools/msprobe/utils/path.py @@ -0,0 +1,340 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from enum import Enum +from pathlib import Path +from shutil import disk_usage +from stat import S_IMODE, S_IRUSR, S_IWGRP, S_IWOTH, S_IWUSR, S_IXUSR + +from msprobe.utils.constants import MsgConst, PathConst +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.toolkits import check_int_border + +_MAX_PATH_LENGTH = 4096 +_MAX_LAST_NAME_LENGTH = 255 +_VALID_PATH_PATTERN = r"^(?!.*\.\.)[a-zA-Z0-9_./-]+$" + +_MODE_READ = {"r", "rb"} +_MODE_WRITE = {"w", "wb", "a", "ab", "a+"} +_MODE_EXEC = {"e"} +_MODE = _MODE_READ | _MODE_WRITE | _MODE_EXEC +_MAX_DIR_DEPTH = 32 +AUTHORITY_DIR = 0o750 +AUTHORITY_FILE = 0o640 + + +class SoftLinkLevel(Enum): + IGNORE = 0 + WARNING = 1 + STRICT = 2 + + +def is_file(path: str): + return os.path.isfile(path) + + +def is_dir(path: str): + return os.path.isdir(path) + + +def get_basename_from_path(path: str): + return os.path.basename(path.rstrip("/")) + + +def get_file_size(path: str): + return os.path.getsize(path) + + +def get_abs_path(path: str): + return os.path.abspath(path) + + +def get_name_and_ext(model_path): + basename = get_basename_from_path(model_path) + # Always returns (name, ext). + return os.path.splitext(basename) + + +def join_path(*args, max_depth=_MAX_DIR_DEPTH): + check_int_border(max_depth, tag="max value of directory depth") + + def flatten(items, depth=0): + if depth > max_depth: + raise MsprobeException(MsgConst.RISK_ALERT, f"Maximum recursion depth {max_depth} exceeded") + for item in items: + if isinstance(item, str): + yield item + elif isinstance(item, (list, tuple)): + yield from flatten(item, depth + 1) + else: + pass + + return os.path.join(*flatten(args)) + + +def is_saved_model_scene(model_path): + saved_model_pb = join_path(model_path, "saved_model.pb") + if not is_file(saved_model_pb): + return False + variables_dir = join_path(model_path, "variables") + return is_dir(variables_dir) + + +def convert_bytes(bytes_size: int) -> str: + if bytes_size < 1024: + return f"{bytes_size} Bytes" + elif bytes_size < 1_048_576: # 1024 * 1024 + return f"{bytes_size / 1024:.2f} KB" + elif bytes_size < 1_073_741_824: # 1024 * 1024 * 1024 + return f"{bytes_size / (1_048_576):.2f} MB" + else: + return f"{bytes_size / (1_073_741_824):.2f} GB" + + +class SafePath: + def __init__( + self, + path: str, + path_type: str, + mode: str, + size_limitation: int = None, + suffix: str = None, + max_dir_depth: int = _MAX_DIR_DEPTH, + ): + self.path = self._check_path(path) + self.path_type = self._check_path_type(path_type) + self.mode = self._check_mode(mode) + self.size_limitation = self._check_int(size_limitation) if size_limitation else None + self.suffix = suffix + self.max_dir_depth = self._check_int(max_dir_depth) + + @staticmethod + def _check_path(path): + if not isinstance(path, str): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"path" must be string.') + return path + + @staticmethod + def _check_path_type(path_type): + if path_type not in [PathConst.FILE, PathConst.DIR]: + raise MsprobeException( + MsgConst.INVALID_ARGU, + f"The path type must be one of {[PathConst.FILE, PathConst.DIR]}, " f"currently: {path_type}.", + ) + return path_type + + @staticmethod + def _check_mode(mode): + if mode not in _MODE: + raise MsprobeException(MsgConst.INVALID_ARGU, f"Mode must be one of {_MODE}, currently: {mode}.") + return mode + + @staticmethod + def _check_int(value): + if not isinstance(value, int): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, f"Value must be an integer, currently: {value}.") + return value + + @staticmethod + def _check_path_exist(path): + if not os.path.exists(path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"Path not found: {path}.") + + @staticmethod + def _check_soft_link(path: str, level: SoftLinkLevel) -> str: + if not os.path.islink(path): + return path + real_path = os.path.realpath(path) + if not isinstance(level, SoftLinkLevel): + raise MsprobeException( + MsgConst.INVALID_ARGU, f"The validation level of symbolic links must be a SoftLinkLevel enum value." + ) + if level == SoftLinkLevel.STRICT: + raise MsprobeException(MsgConst.RISK_ALERT, f"Path {path} is a symlink. Usage prohibited.") + elif level == SoftLinkLevel.WARNING: + logger.warning(f"Found a symlink, path {path} -> {real_path}.") + else: + pass + return real_path + + @staticmethod + def _check_write_permission_for_group_others(path, permission): + if bool(permission & (S_IWGRP | S_IWOTH)): + raise MsprobeException( + MsgConst.RISK_ALERT, + f"The path {path} is writable by group and others. " + "Permissions for files (or directories) should not exceed 0o755 (rwxr-xr-x).", + ) + + @classmethod + def _check_permission(cls, path, mode): + path_stat = os.stat(path) + owner_id = path_stat.st_uid + current_uid = os.geteuid() + if owner_id not in {current_uid, 0}: + raise MsprobeException(MsgConst.RISK_ALERT, f"The owner of {path} must be root or the current user.") + permission = S_IMODE(path_stat.st_mode) + if current_uid == 0: + logger.warning(f"Running as root: Skipping permission checks for {path}, but this is a potential risk.") + else: + cls._check_write_permission_for_group_others(path, permission) + if mode in _MODE_READ and not (permission & S_IRUSR): + raise MsprobeException( + MsgConst.NO_PERMISSION, f"The current user is not authorized to read the path: {path}." + ) + if mode in _MODE_WRITE and not (permission & S_IWUSR): + raise MsprobeException( + MsgConst.NO_PERMISSION, f"The current user is not authorized to write the path: {path}." + ) + if mode in _MODE_EXEC and not (permission & S_IXUSR): + raise MsprobeException( + MsgConst.NO_PERMISSION, f"The current user is not authorized to execute the path: {path}." + ) + + def check(self, path_exist=True, soft_link_level=SoftLinkLevel.STRICT): + self.path = get_abs_path(os.path.normpath(self.path)) + if self.mode in _MODE_WRITE and not path_exist: + parent_dir = get_abs_path(join_path(self.path, os.pardir)) + self._check_path_exist(parent_dir) # The current path doesn't exist, but the parent directory does. + parent_dir = self._check_soft_link(parent_dir, soft_link_level) + if not is_dir(parent_dir): + raise MsprobeException(MsgConst.INVALID_ARGU, f"The parent directory {parent_dir} is not valid.") + self._check_special_chars() + self._check_path_length() + self._check_permission(parent_dir, self.mode) + else: + self._check_path_exist(self.path) + self.path = self._check_soft_link(self.path, soft_link_level) + self._check_special_chars() + self._check_path_length() + if self.path_type == PathConst.FILE: + if not is_file(self.path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"The path {self.path} is not a file.") + self._check_file_suffix() + self._check_file_size() + elif self.path_type == PathConst.DIR: + if not is_dir(self.path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"The path {self.path} is not a directory.") + self._check_dir_size() + self._check_permission(self.path, self.mode) + if self.path_type == PathConst.DIR and not self.path.endswith("/"): + self.path += "/" + return self.path + + def _check_special_chars(self): + if not re.match(_VALID_PATH_PATTERN, self.path): + raise MsprobeException(MsgConst.INVALID_ARGU, f"Path {self.path} contains special characters.") + + def _check_path_length(self): + if len(self.path) > _MAX_PATH_LENGTH: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Current path length ({len(self.path)}) exceeds the limit ({_MAX_PATH_LENGTH})." + ) + dir_depth = 0 + for dir_name in self.path.split("/"): + dir_depth += 1 + if dir_depth > self.max_dir_depth: + raise MsprobeException(MsgConst.RISK_ALERT, f"Exceeded max directory depth ({self.max_dir_depth}).") + if len(dir_name) > _MAX_LAST_NAME_LENGTH: + raise MsprobeException( + MsgConst.RISK_ALERT, + f"Current {self.path_type} length ({len(dir_name)}) exceeds the limit ({_MAX_LAST_NAME_LENGTH}).", + ) + + def _check_file_suffix(self): + if self.suffix and not self.path.endswith(self.suffix): + raise MsprobeException(MsgConst.INVALID_ARGU, f"{self.path} is not a {self.suffix} file.") + + def _check_file_size(self): + if self.size_limitation and os.path.getsize(self.path) > self.size_limitation: + raise MsprobeException( + MsgConst.RISK_ALERT, f"File size exceeds the limit ({convert_bytes(self.size_limitation)})." + ) + + def _check_dir_size(self): + if self.size_limitation and get_dir_size(self.path, self.max_dir_depth) > self.size_limitation: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Directory size exceeds the limit ({convert_bytes(self.size_limitation)})." + ) + + +def get_dir_size(dir_path, max_dir_depth=_MAX_DIR_DEPTH): + total_size = 0 + for root, _, files in os.walk(dir_path): + # fmt: off + current_depth = root[len(dir_path):].count(os.sep) + # fmt: on + if current_depth > max_dir_depth: + raise MsprobeException( + MsgConst.RISK_ALERT, + f"Calculated size of {dir_path}, but exceeded max depth ({max_dir_depth}). Current size: {total_size}.", + ) + for file_name in files: + total_size += os.path.getsize(join_path(root, file_name)) + return total_size + + +def make_dirs(dir_path: str): + normalized_path = os.path.normpath(dir_path) + depth_parts = normalized_path.strip(os.sep).split(os.sep) + depth = len([p for p in depth_parts if p]) + if depth > _MAX_DIR_DEPTH: + raise MsprobeException( + MsgConst.RISK_ALERT, f"Directory depth exceeds the limit of {_MAX_DIR_DEPTH}: {dir_path} has depth {depth}." + ) + + try: + Path(dir_path).mkdir(mode=AUTHORITY_DIR, exist_ok=True, parents=True) + except OSError as e: + raise MsprobeException( + MsgConst.IO_FAILURE, + f"Failed to create {dir_path}, please Check if the parent directory of the current " + f"path exists, and verify permissions or disk space.", + ) from e + + +def change_permission(path, permission): + if not os.path.exists(path) or os.path.islink(path): + return + try: + os.chmod(path, permission) + except PermissionError as e: + raise MsprobeException(MsgConst.NO_PERMISSION, f"Failed to set permissions ({permission}) for {path}.") from e + + +def is_enough_disk_space(path, required_space): + return disk_usage(path).free >= required_space + + +class DirSafeHandler: + @staticmethod + def ensure_dir_exists(path: str): + if not is_dir(path): + make_dirs(path) + + @staticmethod + def get_or_raise(path: str, error_msg: str): + if path: + return path + else: + raise MsprobeException(MsgConst.PATH_NOT_FOUND, error_msg) + + @staticmethod + def join_and_create(*args): + path = join_path(args) + DirSafeHandler.ensure_dir_exists(path) + return path diff --git a/accuracy_tools/msprobe/utils/toolkits.py b/accuracy_tools/msprobe/utils/toolkits.py new file mode 100644 index 0000000000000000000000000000000000000000..e36a3efd60c96c6cbdd8be4c33dbc676127ac696 --- /dev/null +++ b/accuracy_tools/msprobe/utils/toolkits.py @@ -0,0 +1,424 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from enum import Enum +from functools import wraps +from random import seed +from subprocess import PIPE, CalledProcessError, Popen, run +from time import perf_counter, time + +import numpy as np + +from msprobe.utils.constants import MsgConst +from msprobe.utils.dependencies import dependent +from msprobe.utils.env import evars +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger + +_MALICIOUS_CSV_PATTERN = re.compile(r"^[+-=%@\+\-=%@]|;[+-=%@\+\-=%@]") + + +class CsvCheckLevel(Enum): + IGNORE = 0 + REPLACE = 1 + STRICT = 2 + + +_POSITIVE_INT_BORDER = [0, 4_294_967_295] # [0, 2**32 - 1] + + +def get_pid(): + return os.getpid() + + +def get_current_timestamp(microsecond=True): + if microsecond: + return round(perf_counter() * 1e6) % 10**10 + else: + timestamp = int(time()) + return timestamp + + +def filter_cmd(paras): + whitelist_pattern = re.compile(r"^[a-zA-Z0-9_\-./=:,\[\] ]+$") + filtered = [] + for arg in paras: + arg_str = str(arg) + if whitelist_pattern.fullmatch(arg_str): + filtered.append(arg_str) + else: + raise MsprobeException( + MsgConst.RISK_ALERT, + f'The command contains invalid characters. Only the "{whitelist_pattern}" pattern is allowed.', + ) + return filtered + + +def register(name, tmp_map): + @wraps(name) + def wrapper(comp_type): + tmp_map[name] = comp_type + return comp_type + + return wrapper + + +def safely_compute(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.warning(f"Calculation failed via {func.__name__}: {e}") + return None + + return wrapper + + +def get_valid_name(name: str): + if name and name[0] == "/": + name = name.lstrip("/") + return name.replace(".", "_").replace("/", "_").replace(":", "_") + + +def run_subprocess(cmd: list, capture_output=False): + if not isinstance(cmd, list): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "`cmd` must be a list of strings.") + cmd = filter_cmd(cmd) + logger.warning("Please ensure the executed command is correct.") + logger.info(f'Running command: {" ".join(cmd)}.') + if capture_output: + process = Popen(cmd, stdout=PIPE, stderr=PIPE, text=True, bufsize=1, shell=False) + stdout, stderr = process.communicate() + stderr_lines = stderr.splitlines() + if process.returncode != 0: + logger.error(f"Sub-process failed with error: {stderr_lines}.") + process.terminate() + raise MsprobeException(MsgConst.CALL_FAILED, f"Failed to execute command: {' '.join(cmd)}.") + return stdout + else: + try: + run(cmd, text=True, shell=False, check=True) + return None + except CalledProcessError as e: + raise MsprobeException(MsgConst.CALL_FAILED, f"Command failed: {' '.join(cmd)}") from e + + +class DistBackend: + torch = dependent.get("torch") + dist_map = {"cuda": "nccl", "npu": "hccl", "cpu": "gloo"} + + @staticmethod + def _get_visible_device(device_type) -> int: + try: + return int(evars.get(device_type, "0").split(",")[0]) + except Exception as e: + raise MsprobeException( + MsgConst.INVALID_DATA_TYPE, + f"Please check the value of the environment variable {device_type}, " + f'currently: {evars.get(device_type, "0")}.', + ) from e + + @classmethod + def get(cls): + return cls.dist_map.get(cls._get_global_device(), "cpu") + + @classmethod + def _is_device_available(cls, device_name, device_type): + if device_name == "npu" and hasattr(cls.torch, "npu") and cls.torch.npu.is_available(): + return cls._get_visible_device(device_type) >= 0 + elif device_name == "cuda" and hasattr(cls.torch, "cuda") and cls.torch.cuda.is_available(): + return cls._get_visible_device(device_type) >= 0 + elif device_name == "cpu": + return True + return False + + @classmethod + def _get_global_device(cls): + if cls._is_device_available("npu", "ASCEND_VISIBLE_DEVICES"): + return "npu" + elif cls._is_device_available("cuda", "CUDA_VISIBLE_DEVICES"): + return "cuda" + else: + return "cpu" + + +def timestamp_sync(timestamp: int): + torch = dependent.get("torch") + world_size = evars.get("LOCAL_WORLD_SIZE", "1", int) + if world_size < 2: + return timestamp + if torch: + timestamp = torch.tensor(timestamp) + if not torch.distributed.is_initialized(): + rank = evars.get("LOCAL_RANK", "0", int) + torch.distributed.init_process_group(backend=DistBackend.get(), rank=rank, world_size=world_size) + torch.distributed.all_reduce(timestamp, op=torch.distributed.ReduceOp.MAX) + return timestamp.item() + return timestamp + + +def get_current_rank() -> str: + torch = dependent.get("torch") + if torch and torch.distributed.is_initialized(): + return str(torch.distributed.get_rank()) + return "" + + +def check_int_border(*args, border: list = None, tag: str = None): + if not border: + border = _POSITIVE_INT_BORDER + if len(border) != 2: + raise MsprobeException(MsgConst.INVALID_ARGU, "The border must be a list of two integers.") + for num in args: + if not isinstance(num, int): + msg = f"Expected int type, but got {type(num).__name__}." + if tag: + msg += f" Context: {tag}." + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, msg) + if not (border[0] <= num <= border[1]): + msg = f"The integer range is limited to {border}, currently: {num}." + if tag: + msg += f" Context: {tag}." + raise MsprobeException(MsgConst.INVALID_ARGU, msg) + + +class DropoutHandler: + @staticmethod + def remove_for_pt(): + torch = dependent.get("torch") + if not torch or torch.__version__ <= "1.8": + return + logger.info("For precision comparison, the probability p in the dropout method is set to 0.") + _f = torch.nn.functional + _vf = torch._C._VariableFunctions + has_torch_function_unary = torch.overrides.has_torch_function_unary + handle_torch_function = torch.overrides.handle_torch_function + + def function_dropout(input_tensor, p: float = 0.5, training: bool = True, inplace: bool = False): + if has_torch_function_unary(input_tensor): + return handle_torch_function( + function_dropout, (input_tensor,), input_tensor, p=0.0, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise MsprobeException( + MsgConst.INVALID_ARGU, f"dropout probability has to be between 0 and 1, but got {p}." + ) + return _vf.dropout_(input_tensor, 0.0, training) if inplace else _vf.dropout(input_tensor, 0.0, training) + + def function_dropout2d(input_tensor, p: float = 0.5, training: bool = True, inplace: bool = False): + if has_torch_function_unary(input_tensor): + return handle_torch_function( + function_dropout2d, (input_tensor,), input_tensor, p=0.0, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise MsprobeException( + MsgConst.INVALID_ARGU, f"dropout probability has to be between 0 and 1, but got {p}." + ) + return ( + _vf.feature_dropout_(input_tensor, 0.0, training) + if inplace + else _vf.feature_dropout(input_tensor, 0.0, training) + ) + + def function_dropout3d(input_tensor, p: float = 0.5, training: bool = True, inplace: bool = False): + if has_torch_function_unary(input_tensor): + return handle_torch_function( + function_dropout3d, (input_tensor,), input_tensor, p=0.0, training=training, inplace=inplace + ) + if p < 0.0 or p > 1.0: + raise MsprobeException( + MsgConst.INVALID_ARGU, f"dropout probability has to be between 0 and 1, but got {p}." + ) + return ( + _vf.feature_dropout_(input_tensor, 0.0, training) + if inplace + else _vf.feature_dropout(input_tensor, 0.0, training) + ) + + _f.dropout = function_dropout + _f.dropout2d = function_dropout2d + _f.dropout3d = function_dropout3d + + @staticmethod + def remove_for_ms(): + ms = dependent.get("mindspore") + if not ms: + return + ops = ms.ops + nn = ms.mint.nn + + class Dropout(ops.Dropout): + def __init__(self, keep_prob=0.5, seed0=0, seed1=1): + super().__init__(1.0, seed0, seed1) + + class Dropout2D(ops.Dropout2D): + def __init__(self, keep_prob=0.5): + super().__init__(1.0) + + class Dropout3D(ops.Dropout3D): + def __init__(self, keep_prob=0.5): + super().__init__(1.0) + + class DropoutExt(nn.Dropout): + def __init__(self, p=0.5): + super().__init__(0) + + def dropout_ext(input_tensor, p=0.5, training=True): + return input_tensor + + ops.Dropout = Dropout + ops.operations.Dropout = Dropout + ops.Dropout2D = Dropout2D + ops.operations.Dropout2D = Dropout2D + ops.Dropout3D = Dropout3D + ops.operations.Dropout3D = Dropout3D + nn.Dropout = DropoutExt + nn.functional.dropout = dropout_ext + + +class SetSeed: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super(SetSeed, cls).__new__(cls) + return cls._instance + + def __init__(self, seed_num: int, mode: bool, rm_dropout: bool): + self.seed_num = seed_num + self.mode = mode + self.rm_dropout = rm_dropout + self._check_param() + + @classmethod + def all(cls): + cls._focus_on_native() + cls._focus_on_torch() + cls._focus_on_torch_npu() + cls._focus_on_ascend() + cls._focus_on_mindspore() + + def _check_param(self): + check_int_border(self.seed_num) + if not isinstance(self.mode, bool): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "`mode` must be a boolean.") + if not isinstance(self.rm_dropout, bool): + raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "`rm_dropout` must be a boolean.") + + def _focus_on_native(self): + evars.set("PYTHONHASHSEED", str(self.seed_num)) + seed(self.seed_num) + np.random.seed(self.seed_num) + + def _focus_on_ascend(self): + evars.set("LCCL_DETERMINISTIC", "1") + evars.set("HCCL_DETERMINISTIC", "true" if self.mode else "false") + evars.set("ATB_MATMUL_SHUFFLE_K_ENABLE", "0") + evars.set("ATB_LLM_LCOC_ENABLE", "0") + + def _focus_on_torch(self): + torch = dependent.get("torch") + if not torch: + return + torch.manual_seed(self.seed_num) + torch.use_deterministic_algorithms(mode=self.mode) + if hasattr(torch, "cuda"): + torch.cuda.manual_seed(self.seed_num) + torch.cuda.manual_seed_all(self.seed_num) + if hasattr(torch, "backends"): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enable = False + torch.backends.cudnn.benchmark = False + if hasattr(torch, "version"): + cuda_version = torch.version.cuda + if cuda_version: + major, minor = map(int, cuda_version.split(".")[:2]) + if (major, minor) >= (10, 2): + evars.set("CUBLAS_WORKSPACE_CONFIG", ":4096:8") + if self.rm_dropout: + DropoutHandler.remove_for_pt() + + def _focus_on_torch_npu(self): + torch_npu = dependent.get("torch_npu") + if not torch_npu: + return + torch_npu.npu.manual_seed(self.seed_num) + torch_npu.npu.manual_seed_all(self.seed_num) + + def _focus_on_mindspore(self): + ms = dependent.get("mindspore") + if not ms: + return + ms.set_seed(self.seed_num) + ms.set_context(deterministic="ON" if self.mode else "OFF") + if self.rm_dropout: + DropoutHandler.remove_for_ms() + + +def seed_all(seed_num=666, mode=False, rm_dropout=True): + try: + SetSeed.all(seed_num, mode, rm_dropout) + except Exception as e: + raise MsprobeException(MsgConst.CALL_FAILED, f"Failed to set seed: {e}") from e + logger.info(f"Enable deterministic computation sucess! current seed is {seed_num}.") + + +def sanitize_csv_value(value: str, errors=CsvCheckLevel.STRICT): + if errors == CsvCheckLevel.IGNORE or not isinstance(value, str): + return value + sanitized_value = value + try: + float(value) + except Exception as e: + if not _MALICIOUS_CSV_PATTERN.search(value): + pass + elif errors == CsvCheckLevel.REPLACE: + sanitized_value = "" + logger.warning(f'Malicious CSV value detected and replaced: "{value}" -> "{sanitized_value}".') + else: + msg = f"Malicious value detected: {value}, please check the value written to the csv." + raise MsprobeException(MsgConst.RISK_ALERT, msg) from e + return sanitized_value + + +def get_net_output_nodes_from_graph_def(graph_def): + all_nodes = {node.name for node in graph_def.node} + input_nodes = set() + for node in graph_def.node: + for inp in node.input: + input_nodes.add(inp) + output_nodes = all_nodes - input_nodes + return list(output_nodes) + + +def is_input_yes(prompt): + confirm_pattern = re.compile(r"^\s*y(?:es)?\s*$", re.IGNORECASE) + try: + user_action = input(prompt).strip() + except (EOFError, KeyboardInterrupt): + logger.info('Input interrupted. Defaulting to "no".') + return False + return bool(confirm_pattern.fullmatch(user_action)) + + +def set_ld_preload(so_path): + ld_preload = evars.get("LD_PRELOAD", required=False) + if ld_preload: + evars.set("LD_PRELOAD", f"{so_path}:{ld_preload}") + else: + evars.set("LD_PRELOAD", so_path) + logger.info(f"Environment updated with .so library: {so_path}.")