diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index 9b1d869d8584aecacf5129bf74c67d17a2df5c1c..34ff2fb2788664f2241c95d3d18cbb9d16defd5d 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -1,5 +1,6 @@
 __all__ = ["erase_stream", "matmul_checksum"]
 
+import builtins
 import os
 import sys
 import types
@@ -270,4 +271,37 @@ if 'TORCH_NPU_SANITIZER' in os.environ:
 if hasattr(sys, 'ps1'):
     os.environ["TASK_QUEUE_ENABLE"] = '0'
     warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \
-                  Do not set it to 1 to prevent some unknown errors")
\ No newline at end of file
+                  Do not set it to 1 to prevent some unknown errors")
+
+
+def _replace_cuda_to_npu_in_list(args_list):
+    for idx, arg in enumerate(args_list):
+        if not isinstance(arg, builtins.bool) and isinstance(arg, builtins.int):
+            args_list[idx] = f'npu:{arg}'
+    return args_list
+
+
+def _wrapper_cuda(fn):
+    @wraps(fn)
+    def decorated(*args, **kwargs):
+        if args:
+            args_new = list(args)
+            args = _replace_cuda_to_npu_in_list(args_new)
+        if kwargs:
+            device_ids = kwargs.get('device_ids', None)
+            if type(device_ids) == list:
+                device_ids = _replace_cuda_to_npu_in_list(device_ids)
+        return fn(*args, **kwargs)
+
+    return decorated
+
+
+def _device_wrapper(enter_fn, white_list):
+    for fn_name in white_list:
+        fn = getattr(enter_fn, fn_name, None)
+        if fn:
+            setattr(enter_fn, fn_name, _wrapper_cuda(fn))
+
+
+_device_wrapper(torch.Tensor, ["to"])
+_device_wrapper(torch.nn.Module, ["to", "to_empty"])
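
For context, here is a minimal usage sketch of what the monkey-patching above enables. It is not part of the patch, and it assumes torch_npu is installed and NPU device 0 is available. With `torch.Tensor.to` and `torch.nn.Module.to` wrapped, a bare integer device index in CUDA-style code is rewritten to an `'npu:N'` device string before the original method runs:

```python
# Illustrative only, not part of the patch. Assumes torch_npu is installed
# and that NPU device 0 is available.
import torch
import torch_npu  # importing torch_npu applies the wrappers at module load

t = torch.randn(2, 2)
t_npu = t.to(0)  # the wrapper rewrites the positional index 0 to 'npu:0'

model = torch.nn.Linear(4, 4)
model_npu = model.to(0)  # the same rewrite applies to torch.nn.Module.to
```

The `not isinstance(arg, builtins.bool)` guard exists because `bool` is a subclass of `int` in Python; without it, a positional flag such as `non_blocking=True` in `Tensor.to` would be misread as a device index. The `device_ids` branch relies on `_replace_cuda_to_npu_in_list` mutating the list in place, so a `device_ids` keyword argument is rewritten even though the local variable is never written back into `kwargs`.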