From cf015e624c885a815a6d66f5c615e91ec34e2f71 Mon Sep 17 00:00:00 2001 From: huangyunlong Date: Wed, 2 Jul 2025 11:28:23 +0800 Subject: [PATCH] add patch for to(int) --- torch_npu/__init__.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 9b1d869d858..34ff2fb2788 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -1,5 +1,6 @@ __all__ = ["erase_stream", "matmul_checksum"] +import builtins import os import sys import types @@ -270,4 +271,37 @@ if 'TORCH_NPU_SANITIZER' in os.environ: if hasattr(sys, 'ps1'): os.environ["TASK_QUEUE_ENABLE"] = '0' warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \ - Do not set it to 1 to prevent some unknown errors") \ No newline at end of file + Do not set it to 1 to prevent some unknown errors") + + +def _replace_cuda_to_npu_in_list(args_list): + for idx, arg in enumerate(args_list): + if not isinstance(arg, builtins.bool) and isinstance(arg, builtins.int): + args_list[idx] = f'npu:{arg}' + return args_list + + +def _wrapper_cuda(fn): + @wraps(fn) + def decorated(*args, **kwargs): + if args: + args_new = list(args) + args = _replace_cuda_to_npu_in_list(args_new) + if kwargs: + device_ids = kwargs.get('device_ids', None) + if type(device_ids) == list: + device_ids = _replace_cuda_to_npu_in_list(device_ids) + return fn(*args, **kwargs) + + return decorated + + +def _device_wrapper(enter_fn, white_list): + for fn_name in white_list: + fn = getattr(enter_fn, fn_name, None) + if fn: + setattr(enter_fn, fn_name, _wrapper_cuda(fn)) + + +_device_wrapper(torch.Tensor, ["to"]) +_device_wrapper(torch.nn.Module, ["to", "to_empty"]) -- Gitee