diff --git a/test/torch_npu_schema.json b/test/torch_npu_schema.json index 44b695d355e0a0d5a8d9e0b8ae1d162ea5b78793..47688abd281d021add50989c338d3e70613b0d85 100644 --- a/test/torch_npu_schema.json +++ b/test/torch_npu_schema.json @@ -1268,6 +1268,15 @@ "torch_npu.npu.synchronize": { "signature": "(device=None)" }, + "torch_npu.npu.obfuscation_initialize": { + "signature": "(hidden_size, tp_rank, cmd, *, data_type=None, model_obf_seed_id=0, data_obf_seed_id=0, thread_num=4, obf_coefficient=1.0)" + }, + "torch_npu.npu.obfuscation_finalize": { + "signature": "(fd_to_close)" + }, + "torch_npu.npu.obfuscation_calculate": { + "signature": "(fd, x, param, *, obf_coefficient=1.0)" + }, "torch_npu.npu.utilization": { "signature": "(device=None)" }, @@ -1667,6 +1676,15 @@ "torch_npu.npu.utils.synchronize": { "signature": "(device=None)" }, + "torch_npu.npu.utils.obfuscation_initialize": { + "signature": "(hidden_size, tp_rank, cmd, *, data_type=None, model_obf_seed_id=0, data_obf_seed_id=0, thread_num=4, obf_coefficient=1.0)" + }, + "torch_npu.npu.utils.obfuscation_finalize": { + "signature": "(fd_to_close)" + }, + "torch_npu.npu.utils.obfuscation_calculate": { + "signature": "(fd, x, param, *, obf_coefficient=1.0)" + }, "torch_npu.npu.utils.stress_detect": { "signature": "(detect_type='aic')" }, diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 682d13f6df8b9c6af350bfb9dc65eff0e0c8617e..b00e1b14fd1454868b2c36ec2b63023d5a643d60 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -120,7 +120,10 @@ __all__ = [ "get_device_limit", "set_stream_limit", "reset_stream_limit", - "get_stream_limit" + "get_stream_limit", + "obfuscation_initialize", + "obfuscation_finalize", + "obfuscation_calculate" ] from typing import Tuple, Union, List, cast, Optional @@ -135,7 +138,8 @@ from torch_npu.utils import _should_print_warning import torch_npu from torch_npu.utils._error_code import ErrCode, pta_error, prof_error -from .utils import (synchronize, set_device, current_device, _get_device_index, +from .utils import (obfuscation_initialize, obfuscation_calculate, obfuscation_finalize, + synchronize, set_device, current_device, _get_device_index, device, device_of, StreamContext, stream, set_stream, current_stream, default_stream, set_sync_debug_mode, get_sync_debug_mode, init_dump, current_blas_handle, is_bf16_supported, finalize_dump, set_dump, get_npu_overflow_flag, clear_npu_overflow_flag, diff --git a/torch_npu/npu/utils.py b/torch_npu/npu/utils.py index 3e2fe1d929f2223d7a441ab7ec96b2b859f24874..4c2c04d17b69dc9eb5428f0e760642d8d31752bc 100644 --- a/torch_npu/npu/utils.py +++ b/torch_npu/npu/utils.py @@ -13,13 +13,27 @@ from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler from torch_npu.npu._backends import get_soc_version -__all__ = ["synchronize", "set_device", "current_device", "device", "device_of", "StreamContext", +__all__ = ["obfuscation_initialize", "obfuscation_finalize", "obfuscation_calculate", + "synchronize", "set_device", "current_device", "device", "device_of", "StreamContext", "stream", "set_stream", "current_stream", "default_stream", "set_sync_debug_mode", "get_sync_debug_mode", "init_dump", "set_dump", "finalize_dump", "is_support_inf_nan", "is_bf16_supported", "get_npu_overflow_flag", "npu_check_overflow", "clear_npu_overflow_flag", "current_blas_handle", "check_uce_in_memory", "stress_detect", "get_cann_version"] +def obfuscation_initialize(hidden_size, tp_rank, cmd, *, data_type=None, model_obf_seed_id=0, data_obf_seed_id=0, thread_num=4, obf_coefficient=1.0): + return torch_npu.obfuscation_initialize(hidden_size, tp_rank, cmd, data_type=data_type, model_obf_seed_id=model_obf_seed_id, + data_obf_seed_id=data_obf_seed_id, thread_num=thread_num, obf_coefficient=obf_coefficient) + + +def obfuscation_finalize(fd_to_close): + return torch_npu.obfuscation_finalize(fd_to_close) + + +def obfuscation_calculate(fd, x, param, *, obf_coefficient=1.0): + return torch_npu.obfuscation_calculate(fd, x, param, obf_coefficient=obf_coefficient) + + def get_cann_version(module="CANN"): r""" Args: