# accuracy_tools/msprobe/utils/hijack.py
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
from importlib.abc import Loader, MetaPathFinder
from importlib.util import module_from_spec, spec_from_loader
from uuid import uuid4

from msprobe.utils.constants import MsgConst
from msprobe.utils.exceptions import MsprobeException
from msprobe.utils.toolkits import check_int_border


class ActionType(Enum):
    """Ways a stub can be attached to a hijacked target."""

    REPLACE = 0  # the stub is called instead of the original function
    PRE_HOOK = 1  # the stub is called before the target
    POST_HOOK = 2  # the stub is called after the target


class HijackHandler:
    """Handle returned by ``hijacker``; pass it to ``release`` to cancel the hijacking."""

    def __init__(self, unit):
        self.unit = unit  # the HijackerUnit this handle controls
        self.call_count = 0  # incremented by the function wrapper on every call
        self.call_data = defaultdict(dict)  # call index -> {"args", "kwargs", "return"}
        self.released = False  # set to True by release()


def hijacker(
    *,
    stub: callable,
    module: str,
    cls: str = "",
    function: str = "",
    action: ActionType = ActionType.REPLACE,
    priority: int = 100,
) -> HijackHandler:
    """
    Hijack module-import process or function execution process.
    Support attaching pre/post hooks to the process, or replacing function implementations.

    .. target::
        When only set "module": module
        When set "module" and "function": function in module
        When set "module", "cls" and "function": function in class

    .. warning::
        The pre-hook of the module-import process will only take effect if it is set before the module is imported.
        If the module is modified in its post-hook, the impact cannot be restored even if the hijacking is released.

    Parameters
    ----------
    stub: Callable object.
        Follow different format under different target and action.
        ---------------------------------------------------------------------------------------------------------------
        | target   | action    | format                       | description                                           |
        |-------------------------------------------------------------------------------------------------------------|
        | module   | pre-hook  | callable()                   | Called before module import.                          |
        |-------------------------------------------------------------------------------------------------------------|
        | module   | post-hook | callable(m)                  | Called after module import. "m" is the module.        |
        |-------------------------------------------------------------------------------------------------------------|
        | function | replace   | ret = callable(*args, **kws) | Replace original object.                              |
        |-------------------------------------------------------------------------------------------------------------|
        | function | pre-hook  | args, kws =                  | Called before function execution, and the return will |
        |          |           | callable(*args, **kws)       | replace original input of the target function.        |
        |-------------------------------------------------------------------------------------------------------------|
        | function | post-hook | ret = callable(ret, *args,   | Called after function execution, and the return will  |
        |          |           | **kws)                       | replace original return of target function.           |
        ---------------------------------------------------------------------------------------------------------------
    module: str
        Full name of target module.
    cls: str, optional
        Full name of target class.
    function: str, optional
        Name of target function.
    action: enum, optional
        Choose between REPLACE, PRE_HOOK, and POST_HOOK.
    priority: int, optional
        The smaller the value is, the higher the priority is. When multiple hooks are set on the same target, they will
        be executed by priority.

    Returns
    -------
    handler:
        Handler to a hijacking. E.g., handler.unit, handler.call_data, handler.released.

    Raises
    ------
    MsprobeException
        If any argument fails the validation done by ``HijackerUnit``.
    """
    HiJackerManager.initialize()
    unit = HijackerUnit(stub, module, cls, function, action, priority)
    handler = HijackHandler(unit)
    unit.handler = handler
    HiJackerManager.add_unit(unit)
    return handler


def release(handler):
    """
    Cancel a hijacking. "handler" is returned by function "hijacker".

    Raises MsprobeException when "handler" is not a HijackHandler. Releasing the same handler twice is a no-op.
    """
    if not isinstance(handler, HijackHandler):
        raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Handler must be an instance of HijackHandler.")
    handler.released = True
    HiJackerManager.remove_unit(handler.unit)


class HijackerUnit:
    """One validated hijacking request: a stub plus the target and the way it is attached."""

    def __init__(self, stub, module, cls, function, action, priority):
        self.stub = stub
        self.module = module
        self.cls = cls
        self.function = function
        self.action = action
        self.priority = priority
        # The target key is split on "-" again by HiJackerWrapperObj, so the three parts must not contain "-".
        self.target = f"{module}-{cls}-{function}"
        self.handler = None
        self._check_para_valid()

    def _check_para_valid(self):
        """Raise MsprobeException if any constructor argument is missing or has the wrong type."""
        if not callable(self.stub):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"stub" should be callable.')
        if not self.module:
            raise MsprobeException(MsgConst.REQUIRED_ARGU_MISSING, '"module" is required.')
        if not isinstance(self.module, str):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"module" should be a str.')
        if self.cls and not isinstance(self.cls, str):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"cls" should be a str.')
        if self.cls and not self.function:
            raise MsprobeException(MsgConst.REQUIRED_ARGU_MISSING, '"function" should be used when "cls" used.')
        if self.function and not isinstance(self.function, str):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"function" should be a str.')
        if not isinstance(self.action, ActionType):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, '"action" should be an ActionType.')
        if not self.cls and not self.function and self.action == ActionType.REPLACE:
            raise MsprobeException(MsgConst.INVALID_ARGU, "replacement of a module is not supported")
        check_int_border(self.priority, tag="priority of HijackerUnit")


class HiJackerWrapperObj(ABC):
    """Base class holding every unit registered on one target, grouped by action and sorted by priority."""

    def __init__(self, name):
        self.name = name
        self.ori_obj = None  # the un-hijacked object (module or function), once known
        self.replacement = []
        self.pre_hooks = []
        self.post_hooks = []
        # "name" is a HijackerUnit.target, i.e. "<module>-<class>-<function>".
        self.mod_name, self.class_name, self.func_name = name.split("-")

    @property
    def is_empty(self):
        """True when no unit of any kind is registered any more."""
        return not self.replacement and not self.pre_hooks and not self.post_hooks

    @abstractmethod
    def activate(self):
        """Install the hijacking on the target."""

    @abstractmethod
    def deactivate(self):
        """Undo what ``activate`` installed."""

    def add_unit(self, unit):
        """Register "unit" in the list matching its action, keeping that list ordered by priority."""
        if unit.action == ActionType.REPLACE:
            self.replacement.append(unit)
            self.replacement.sort(key=lambda x: x.priority)
        elif unit.action == ActionType.PRE_HOOK:
            self.pre_hooks.append(unit)
            self.pre_hooks.sort(key=lambda x: x.priority)
        else:
            self.post_hooks.append(unit)
            self.post_hooks.sort(key=lambda x: x.priority)

    def remove_unit(self, unit):
        """Remove "unit" from the list matching its action (ValueError if it is not registered)."""
        if unit.action == ActionType.REPLACE:
            self.replacement.remove(unit)
        elif unit.action == ActionType.PRE_HOOK:
            self.pre_hooks.remove(unit)
        else:
            self.post_hooks.remove(unit)

    def set_ori_obj(self, obj):
        self.ori_obj = obj


class HiJackerWrapperModule(HiJackerWrapperObj):
    """Hooks around the import of one module; driven by HiJackerLoader.exec_module."""

    def exec_pre_hook(self):
        for unit in self.pre_hooks:
            unit.stub()

    def exec_post_hook(self, m):
        self.set_ori_obj(m)
        for unit in self.post_hooks:
            unit.stub(m)

    def add_unit(self, unit):
        super().add_unit(unit)
        # A post-hook registered after the module was imported would otherwise never fire, so run it right away.
        if unit.action == ActionType.POST_HOOK:
            m = sys.modules.get(self.mod_name)
            if m:
                unit.stub(m)

    def activate(self):
        HiJackerPathFinder.add_mod(self.mod_name)

    def deactivate(self):
        HiJackerPathFinder.remove_mod(self.mod_name)


class HiJackerWrapperFunction(HiJackerWrapperObj):
    """Hooks around one function (optionally nested in a class chain) of a module.

    NOTE(review): the wrapper is installed with plain ``setattr`` and the original is read with ``getattr``;
    targets declared as ``staticmethod``/``classmethod`` are presumably not handled correctly - confirm.
    """

    def __init__(self, name):
        super().__init__(name)
        self.mod_hijacker = None  # handler of the internal module post-hook that patches the function

    def activate(self):
        """Register a module post-hook that swaps the target function for the wrapper."""
        wrapper = self._get_wrapper()

        def modify_module(m):
            parent_obj = self._resolve_parent(m)
            if parent_obj is None or not hasattr(parent_obj, self.func_name):
                return
            current = getattr(parent_obj, self.func_name)
            if current is wrapper:
                # Already patched: storing the wrapper as its own original would recurse forever.
                return
            self.set_ori_obj(current)
            setattr(parent_obj, self.func_name, wrapper)

        self.mod_hijacker = hijacker(
            stub=modify_module,
            module=self.mod_name,
            action=ActionType.POST_HOOK,
            priority=0,
        )

    def deactivate(self):
        """Release the internal module hook and put the original function back if the module is loaded."""
        if self.mod_hijacker:
            release(self.mod_hijacker)
            self.mod_hijacker = None
        mod = sys.modules.get(self.mod_name)
        if mod is not None and self.ori_obj is not None:
            parent_obj = self._resolve_parent(mod)
            if parent_obj is not None and hasattr(parent_obj, self.func_name):
                setattr(parent_obj, self.func_name, self.ori_obj)
            self.ori_obj = None

    def _resolve_parent(self, module):
        """Walk the dotted class chain starting at "module"; return the owner of the function or None."""
        parent_obj = module
        class_chain = self.class_name.split(".") if self.class_name else []
        for c in class_chain:
            if not hasattr(parent_obj, c):
                return None
            parent_obj = getattr(parent_obj, c)
        return parent_obj

    def _get_wrapper(self):
        """Build the callable installed in place of the target function."""

        def wrapper(*args, **kws):
            if self.ori_obj is None:
                raise MsprobeException(
                    MsgConst.VALUE_NOT_FOUND,
                    "Original function object not found. Ensure activate() was called successfully.",
                )
            # Snapshot the lists so a hook released during the call cannot disturb the iteration.
            pre_hooks = list(self.pre_hooks)
            replacement = list(self.replacement)
            post_hooks = list(self.post_hooks)
            # Each handler has its own call counter, so remember the record created for each one
            # instead of sharing a single index between handlers.
            records = []
            for unit in pre_hooks + replacement + post_hooks:
                if unit.handler:
                    unit.handler.call_count += 1
                    record = {"args": args, "kwargs": kws}
                    unit.handler.call_data[unit.handler.call_count] = record
                    records.append(record)
            for unit in pre_hooks:
                result = unit.stub(*args, **kws)
                if not (isinstance(result, tuple) and len(result) == 2):
                    raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Pre-hook must return a tuple of (args, kws)")
                args, kws = result
            # Only the highest-priority replacement is used; the rest stay registered as fallbacks.
            f = replacement[0].stub if replacement else self.ori_obj
            ret = f(*args, **kws)
            for unit in post_hooks:
                ret = unit.stub(ret, *args, **kws)
            for record in records:
                record["return"] = ret
            return ret

        return wrapper


class HiJackerManager:
    """Process-wide registry of hijacking units and of the wrapper object built for each target."""

    _initialized = False
    _hijacker_units = {}  # registry key (uuid hex) -> HijackerUnit
    _hijacker_wrappers = {}  # HijackerUnit.target -> HiJackerWrapperObj

    @classmethod
    def initialize(cls):
        """Install the import hook at the front of sys.meta_path (once)."""
        if cls._initialized:
            return
        sys.meta_path.insert(0, HiJackerPathFinder())
        cls._initialized = True

    @classmethod
    def add_unit(cls, unit):
        """Register "unit", creating and activating the wrapper of its target on first use.

        Returns the registry key under which the unit is stored.
        """
        handler = uuid4().hex
        cls._hijacker_units[handler] = unit
        wrapper_obj = cls._hijacker_wrappers.get(unit.target)
        if not wrapper_obj:
            wrapper_obj = cls._build_wrapper_obj(unit.target)
            cls._hijacker_wrappers[unit.target] = wrapper_obj
            wrapper_obj.activate()
        wrapper_obj.add_unit(unit)
        return handler

    @classmethod
    def remove_unit(cls, handler):
        """Unregister a unit and deactivate its wrapper when no unit is left on the target.

        "handler" may be the registry key returned by ``add_unit`` or the HijackerUnit itself
        (``release`` passes the unit). Unknown values are ignored.
        """
        key = cls._find_key(handler)
        if key is None:
            return
        unit = cls._hijacker_units.pop(key)
        wrapper_obj = cls._hijacker_wrappers.get(unit.target)
        if wrapper_obj is None:
            return
        wrapper_obj.remove_unit(unit)
        if wrapper_obj.is_empty:
            wrapper_obj.deactivate()
            cls._hijacker_wrappers.pop(unit.target, None)

    @classmethod
    def get_module_wrapper(cls, name):
        """Return the module-level wrapper registered for module "name", or None."""
        return cls._hijacker_wrappers.get(f"{name}--")

    @classmethod
    def _find_key(cls, handler):
        """Map a registry key or a unit object to its registry key; None when it is not registered."""
        if isinstance(handler, str):
            return handler if handler in cls._hijacker_units else None
        for key, unit in cls._hijacker_units.items():
            if unit is handler:
                return key
        return None

    @classmethod
    def _build_wrapper_obj(cls, name):
        _, _, f = name.split("-")
        if f:
            return HiJackerWrapperFunction(name)
        else:
            return HiJackerWrapperModule(name)


class HiJackerPathFinder(MetaPathFinder):
    """Meta-path finder that wraps the loader of every module registered through ``add_mod``."""

    _modules_of_insterest = set()

    @classmethod
    def add_mod(cls, name):
        cls._modules_of_insterest.add(name)

    @classmethod
    def remove_mod(cls, name):
        cls._modules_of_insterest.discard(name)

    def find_spec(self, fullname, path, target=None):
        if fullname not in self._modules_of_insterest:
            return None
        for finder in sys.meta_path:
            if isinstance(finder, HiJackerPathFinder):
                continue
            # Legacy finders may not implement find_spec at all.
            find_spec = getattr(finder, "find_spec", None)
            if find_spec is None:
                continue
            spec = find_spec(fullname, path, target)
            if not spec:
                continue
            if spec.loader is None:
                # Namespace packages can have no loader; there is nothing to wrap, so import them untouched.
                return spec
            return spec_from_loader(fullname, HiJackerLoader(spec))
        return None

    def find_module(self, fullname, path=None):
        if fullname not in self._modules_of_insterest:
            return None
        for finder in sys.meta_path:
            if isinstance(finder, HiJackerPathFinder):
                continue
            # find_module is a legacy API that newer finders no longer provide.
            find_module = getattr(finder, "find_module", None)
            if find_module is None:
                continue
            loader = find_module(fullname, path)
            if not loader:
                continue
            return HiJackerLoader(spec_from_loader(fullname, loader))
        return None


class HiJackerLoader(Loader):
    """Loader that delegates to the original spec's loader and runs the module hooks around execution."""

    def __init__(self, ori_spec):
        self.ori_spec = ori_spec

    def create_module(self, spec):
        module = module_from_spec(self.ori_spec)
        return module

    def load_module(self, fullname):
        module = self.ori_spec.loader.load_module(fullname)
        return module

    def exec_module(self, module):
        wrapper = HiJackerManager.get_module_wrapper(module.__name__)
        if wrapper:
            wrapper.exec_pre_hook()
        self.ori_spec.loader.exec_module(module)
        if wrapper:
            wrapper.exec_post_hook(module)


# accuracy_tools/msprobe/utils/io.py
# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import json +import pickle +from functools import wraps + +import numpy as np +import pandas as pd +import yaml + +from msprobe.utils.constants import MsgConst, PathConst +from msprobe.utils.dependencies import dependent +from msprobe.utils.exceptions import MsprobeException +from msprobe.utils.log import logger +from msprobe.utils.path import ( + AUTHORITY_DIR, + AUTHORITY_FILE, + SafePath, + change_permission, + get_basename_from_path, + get_file_size, + join_path, +) +from msprobe.utils.toolkits import CsvCheckLevel, is_input_yes, sanitize_csv_value + +_LOAD_ERROR = 'Failed to load the path "{}" using <{}>.' +_SAVE_ERROR = 'Failed to save {} to "{}" using <{}>. Please check permissions or disk space.' 
class SafelyOpen:
    """Context manager that validates a file path with SafePath before opening it.

    The path check runs in the constructor, so an invalid path fails before ``with`` is entered.
    Text modes are opened with "encoding"; binary modes (a "b" in "mode") ignore it.
    """

    def __init__(self, file_path, mode, file_size_limitation=None, suffix=None, path_exist=True, encoding="utf-8"):
        self.file_path = SafePath(file_path, PathConst.FILE, mode, file_size_limitation, suffix).check(
            path_exist=path_exist
        )
        self.mode = mode
        self.encoding = encoding
        self._file = None

    def __enter__(self):
        if "b" not in self.mode:
            self._file = open(self.file_path, self.mode, encoding=self.encoding)
        else:
            self._file = open(self.file_path, self.mode)
        return self._file

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        """Close the underlying file if it is open; safe to call more than once."""
        if self._file and not self._file.closed:
            self._file.close()


def _load_file(mode, file_size, file_suffix, use_safely_open: bool, encoding="utf-8"):
    """Decorator factory for loaders taking a file path as first argument.

    With ``use_safely_open`` the decorated function receives an open file object instead of the path;
    otherwise it receives the path returned by the SafePath check. Extra positional and keyword
    arguments are forwarded in both cases. Any exception is re-raised as MsprobeException(IO_FAILURE).
    """

    def decorator(func):
        @wraps(func)
        def wrapper(path, *args, **kwargs):
            try:
                if use_safely_open:
                    # "encoding" must go by keyword: the fifth positional parameter of SafelyOpen is path_exist.
                    with SafelyOpen(path, mode, file_size, file_suffix, encoding=encoding) as f:
                        return func(f, *args, **kwargs)
                else:
                    path = SafePath(path, PathConst.FILE, mode, file_size, file_suffix).check()
                    return func(path, *args, **kwargs)
            except Exception as e:
                raise MsprobeException(MsgConst.IO_FAILURE, _LOAD_ERROR.format(path, func.__name__)) from e

        return wrapper

    return decorator


def _load_dir(dir_size):
    """Decorator factory for loaders taking a directory path as first argument.

    The directory is checked with SafePath before the call; failures inside the decorated function
    are re-raised as MsprobeException(IO_FAILURE).
    """

    def decorator(func):
        @wraps(func)
        def wrapper(path, *args, **kwargs):
            path = SafePath(path, PathConst.DIR, "r", dir_size).check()
            try:
                return func(path, *args, **kwargs)
            except Exception as e:
                raise MsprobeException(MsgConst.IO_FAILURE, _LOAD_ERROR.format(path, func.__name__)) from e

        return wrapper

    return decorator


def _save_file(mode, file_size, file_suffix, use_safely_open: bool):
    """Decorator factory for savers called as ``func(data, path, ...)``.

    With ``use_safely_open`` the decorated function receives an open file object in place of the path.
    After a successful save the file permission is set to AUTHORITY_FILE. Any exception is re-raised
    as MsprobeException(IO_FAILURE).
    """

    def decorator(func):
        @wraps(func)
        def wrapper(data, path, *args, **kwargs):
            try:
                if use_safely_open:
                    with SafelyOpen(path, mode, file_size, file_suffix, path_exist=False) as f:
                        func(data, f, *args, **kwargs)
                else:
                    path = SafePath(path, PathConst.FILE, mode, file_size, file_suffix).check(path_exist=False)
                    func(data, path, *args, **kwargs)
            except Exception as e:
                raise MsprobeException(
                    MsgConst.IO_FAILURE, _SAVE_ERROR.format(data.__class__.__name__, path, func.__name__)
                ) from e
            change_permission(path, AUTHORITY_FILE)

        return wrapper

    return decorator


def _save_dir(dir_size):
    """Decorator factory for savers called as ``func(data, dir_path, ...)``.

    The directory path is checked with SafePath first; after a successful save its permission is set
    to AUTHORITY_DIR. Failures inside the decorated function become MsprobeException(IO_FAILURE).
    """

    def decorator(func):
        @wraps(func)
        def wrapper(data, path, *args, **kwargs):
            path = SafePath(path, PathConst.DIR, "w", dir_size).check(path_exist=False)
            try:
                func(data, path, *args, **kwargs)
            except Exception as e:
                raise MsprobeException(
                    MsgConst.IO_FAILURE, _SAVE_ERROR.format(data.__class__.__name__, path, func.__name__)
                ) from e
            change_permission(path, AUTHORITY_DIR)

        return wrapper

    return decorator


def _map_elementwise(frame, func):
    """Apply "func" to every cell of a DataFrame.

    ``DataFrame.applymap`` is deprecated since pandas 2.1 in favour of ``DataFrame.map``. The lookup is
    done on the class because attribute access on an instance could resolve to a column named "map".
    """
    mapper = getattr(pd.DataFrame, "map", None) or pd.DataFrame.applymap
    return mapper(frame, func)


@_load_file("r", PathConst.SIZE_30G, ".onnx", use_safely_open=False)
def load_onnx_model(model_path):
    """Load an ONNX model from a checked ".onnx" path."""
    onnx = dependent.get("onnx")
    return onnx.load_model(model_path)


@_load_file("r", PathConst.SIZE_30G, ".onnx", use_safely_open=False)
def load_onnx_session(model_path, onnx_fusion_switch=True, provider="CPUExecutionProvider"):
    """Create an onnxruntime InferenceSession; graph optimizations are disabled when the switch is off."""
    ort = dependent.get("onnxruntime")
    options = ort.SessionOptions()
    if not onnx_fusion_switch:
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    return ort.InferenceSession(model_path, sess_options=options, providers=[provider])


@_load_file("r", PathConst.SIZE_30G, ".om", use_safely_open=False)
def load_om_model(model_path):
    """Load an ".om" model through the msprobe C extension and return its model id."""
    cmsprobe = dependent.get("msprobe.lib.msprobe_c")
    model_id, ret = cmsprobe.acl.load_from_file(model_path)
    if ret != 0:
        raise MsprobeException(MsgConst.IO_FAILURE, f"Load model: {model_path} failed! ErrorCode = {ret}.")
    logger.info(f"Load model: {model_path} success!")
    return model_id


@_save_file("w", None, ".onnx", use_safely_open=False)
def save_onnx_model(onnx_model, save_path):
    """Save an ONNX model; weights go to external data when the model is larger than SIZE_2G."""
    onnx = dependent.get("onnx")
    model_size = onnx_model.ByteSize()
    save_external_flag = model_size > PathConst.SIZE_2G
    onnx.save_model(onnx_model, save_path, save_as_external_data=save_external_flag)


@_load_file("r", PathConst.SIZE_30G, ".prototxt", use_safely_open=False)
def load_caffe_model(model_path, weight_path):
    """Build a caffe Net in CPU/TEST mode; returns None when caffe is unavailable."""
    caffe = dependent.get("caffe")
    if caffe:
        caffe.set_mode_cpu()
        return caffe.Net(model_path, weight_path, caffe.TEST)
    return None


@_load_file("r", PathConst.SIZE_10G, ".npy", use_safely_open=False)
def load_npy(npy_path):
    """Load a ".npy" file with pickle support disabled."""
    return np.load(npy_path, allow_pickle=False)


def load_npy_from_buffer(raw_data, dtype, shape):
    """Interpret a bytes-like buffer as an array of "dtype" reshaped to "shape".

    Raises MsprobeException(IO_FAILURE) when the buffer does not match dtype/shape.
    """
    try:
        return np.frombuffer(raw_data, dtype=dtype).reshape(shape)
    except Exception as e:
        raise MsprobeException(MsgConst.IO_FAILURE, "Failed to load npy data from buffer.") from e


@_save_file("w", None, ".npy", use_safely_open=False)
def save_npy(npy_data, save_path):
    """Save an array to a ".npy" file."""
    np.save(save_path, npy_data)


@_save_file("wb", None, ".bin", use_safely_open=False)
def save_bin_from_ndarray(numpy_data: np.ndarray, save_path):
    """Dump the raw bytes of an ndarray to a ".bin" file."""
    numpy_data.tofile(save_path)


@_save_file("wb", None, ".bin", use_safely_open=True)
def save_bin_from_bytes(bytes_data, f):
    """Write raw bytes to a ".bin" file (the decorator passes the opened file as "f")."""
    f.write(bytes_data)


@_load_file("r", PathConst.SIZE_10G, ".bin", use_safely_open=False)
def load_bin_data(bin_path, dtype=np.float16, shape=None, is_byte_data=False):
    """Read a raw ".bin" file into a flat array.

    "is_byte_data" reads the file as int8. When float32 is requested together with a shape and the
    file size equals two bytes per element, the file is read as float16 and widened to float32.
    """
    if is_byte_data:
        return np.fromfile(bin_path, dtype=np.int8)
    # Without a shape the element count is unknown, so the float16 detection is skipped.
    if dtype == np.float32 and shape is not None and get_file_size(bin_path) == np.prod(shape) * 2:
        return np.fromfile(bin_path, dtype=np.float16).astype(np.float32)
    else:
        return np.fromfile(bin_path, dtype=dtype)


@_load_dir(PathConst.SIZE_30G)
def load_saved_model(model_path, tag):
    """Load a TensorFlow SavedModel into a fresh v1 graph/session.

    "tag" is a single tag string or an iterable of tag strings. Returns ``(meta_graph_def, session)``,
    or ``(None, None)`` when TensorFlow is unavailable. The caller owns the session and must close it.
    """
    pons = dependent.get_tensorflow()
    if None not in pons:
        tf, _, _ = pons
        tf.compat.v1.reset_default_graph()
        graph = tf.compat.v1.Graph()
        sess = tf.compat.v1.Session(graph=graph)
        # set("serve") would split the string into characters, so wrap a single tag instead.
        tags = {tag} if isinstance(tag, str) else set(tag)
        saved_model = tf.compat.v1.saved_model.loader.load(sess, tags, model_path)
        return saved_model, sess
    return None, None


@_load_file("rb", PathConst.SIZE_30G, ".pb", use_safely_open=False)
def load_pb_frozen_graph_model(model_path):
    """Parse a frozen ".pb" graph and import it into the default graph.

    Returns the GraphDef, or None when TensorFlow is unavailable.
    """
    pons = dependent.get_tensorflow()
    if None not in pons:
        tf, _, _ = pons
        # Use the file as a context manager so the handle is closed even if reading fails.
        with tf.compat.v1.gfile.GFile(model_path, "rb") as f:
            data = f.read()
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(data)
        tf.compat.v1.import_graph_def(graph_def, name="")
        return graph_def
    return None


@_save_file("wb", PathConst.SIZE_30G, ".pb", use_safely_open=False)
def save_pb_frozen_graph_model(frozen_graph, model_path):
    """Write a serialized frozen graph to a ".pb" file; does nothing when TensorFlow is unavailable."""
    pons = dependent.get_tensorflow()
    if None not in pons:
        tf, _, _ = pons
        with tf.io.gfile.GFile(model_path, "wb") as f:
            f.write(frozen_graph)


def savedmodel2pb(model_path, tag, serve, pb_save_dir):
    """
    Converts a TensorFlow 1.x SavedModel to a frozen PB file.

    :param model_path: Path to the saved TensorFlow SavedModel directory
    :param tag: Tag used for loading the model
    :param serve: Signature key (e.g., "serving_default")
    :param pb_save_dir: Directory to save the PB file
    :return: Path to the converted PB file, or "" when TensorFlow is unavailable
    :raises MsprobeException: When the signature "serve" does not exist in the model
    """
    pons = dependent.get_tensorflow()
    if None in pons:
        return ""
    _, _, sm2pb = pons
    meta_graph_def, sess = load_saved_model(model_path, tag)
    try:
        signature_def = meta_graph_def.signature_def.get(serve)
        if signature_def is None:
            raise MsprobeException(MsgConst.VALUE_NOT_FOUND, f'Signature "{serve}" not found in the model.')
        input_tensor_names = [t.name for t in signature_def.inputs.values()]
        output_tensor_names = [t.name for t in signature_def.outputs.values()]
        logger.info(f"Saved model input tensors: {input_tensor_names}.")
        logger.info(f"Saved model output tensors: {output_tensor_names}.")
        # Tensor names look like "<node>:<index>"; freezing needs the node names only.
        output_node_names = [t.split(":")[0] for t in output_tensor_names]
        frozen_graph_def = sm2pb(sess, sess.graph.as_graph_def(), output_node_names)
        pb_file_name = get_basename_from_path(model_path) + ".pb"
        pb_file_path = join_path(pb_save_dir, pb_file_name)
        save_pb_frozen_graph_model(frozen_graph_def.SerializeToString(), pb_file_path)
    finally:
        # Release the session on every path, including a missing signature or a failed save.
        sess.close()
    logger.info(f"SavedModel has been successfully converted to a frozen PB file at {pb_file_path}.")
    return pb_file_path


@_load_file("r", PathConst.SIZE_500M, ".yaml", use_safely_open=True)
def load_yaml(f):
    """Parse a YAML file with the safe loader (the decorator passes the opened file as "f")."""
    return yaml.safe_load(f)


@_save_file("w", None, ".yaml", use_safely_open=True)
def save_yaml(yaml_data, f):
    """Dump data as YAML (the decorator passes the opened file as "f")."""
    yaml.dump(yaml_data, f)


@_load_file("r", PathConst.SIZE_2G, ".json", use_safely_open=True)
def load_json(f):
    """Parse a JSON file (the decorator passes the opened file as "f")."""
    return json.load(f)


@_save_file("w", None, ".json", use_safely_open=True)
def save_json(json_data, f, indent: int = None):
    """Dump data as JSON; objects json cannot serialize are written as ``str(obj)``."""
    json.dump(json_data, f, indent=indent, default=str)


@_load_file("r", PathConst.SIZE_500M, ".csv", use_safely_open=True, encoding="utf-8-sig")
def load_csv_by_builtin(f, sep=",", check=CsvCheckLevel.STRICT):
    """Read a CSV file into a list of rows, passing every cell through sanitize_csv_value."""
    csv_reader = csv.reader(f, delimiter=sep)
    return [[sanitize_csv_value(value, check) for value in row] for row in csv_reader]


@_load_file("r", PathConst.SIZE_500M, ".csv", use_safely_open=False)
def load_csv_by_pandas(csv_path, sep=",", check=CsvCheckLevel.STRICT):
    """Read a CSV file into a DataFrame of strings, passing every cell through sanitize_csv_value."""
    df = pd.read_csv(csv_path, sep=sep, dtype=str)
    df = _map_elementwise(df, lambda value: sanitize_csv_value(value, check))
    return df


@_save_file("w", None, ".csv", use_safely_open=False)
def save_csv_by_pandas(csv_data: pd.DataFrame, csv_path, sep=",", check=CsvCheckLevel.STRICT):
    """Write a DataFrame to CSV without the index, passing every cell through sanitize_csv_value."""
    sanitized_data = _map_elementwise(csv_data, lambda value: sanitize_csv_value(value, check))
    sanitized_data.to_csv(csv_path, sep=sep, index=False)


@_load_file("r", PathConst.SIZE_30G, None, use_safely_open=False)
def load_torch_obj(path, **kwargs):
    """Load an object with ``torch.load``, defaulting to ``weights_only=True``.

    If the weights-only load fails with an unpickling error, the user is asked to confirm before
    retrying with ``weights_only=False``. Returns None when the user declines, or when unpickling
    fails although "weights_only" was already False.
    """
    kwargs.setdefault("weights_only", True)
    torch = dependent.get("torch")
    try:
        return torch.load(path, **kwargs)
    except pickle.UnpicklingError:
        if kwargs["weights_only"]:
            # SECURITY: the fallback unpickles arbitrary objects and can execute code from the file,
            # so it only runs after explicit confirmation.
            prompt = """
            Weights only load failed. Re-running with `weights_only` set to `False` will likely succeed,
            but it can result in arbitrary code execution. Do it only if you get the file from a trusted source. \n
            Please confirm your awareness of the risks associated with this action ([y]/n): """
            if not is_input_yes(prompt):
                return None
            kwargs["weights_only"] = False
            return torch.load(path, **kwargs)
        else:
            return None