diff --git a/torch_npu/contrib/transfer_to_npu.py b/torch_npu/contrib/transfer_to_npu.py
index b899c0ecec371271bdeb25d03f03cac95916bf01..9a0f0946ec76eb3f021aae5b3f6e0a58d43eebec 100644
--- a/torch_npu/contrib/transfer_to_npu.py
+++ b/torch_npu/contrib/transfer_to_npu.py
@@ -30,7 +30,7 @@ torch_fn_white_list = ['logspace', 'randint', 'hann_window', 'rand', 'full_like'
                        'zeros_like', 'range', 'sparse_csr_tensor', 'randn_like', 'from_file',
                        '_cudnn_init_dropout_state', '_empty_affine_quantized', 'linspace', 'hamming_window',
                        'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device']
-torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to',
+torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to',
                               'pin_memory']
 torch_module_fn_white_list = ['to', 'to_empty']
 torch_cuda_fn_white_list = [
@@ -263,6 +263,12 @@ def _patch_has_triton():
     return False
 
 
+def _patch_get_available_device_type():
+    if torch.npu.is_available():
+        return 'npu'
+    return None
+
+
 def _patch_cuda():
     patchs = [
         ['cuda', torch_npu.npu], ['cuda.amp', torch_npu.npu.amp],
@@ -373,6 +379,8 @@ def _init():
 
     setattr(torch._inductor.utils, 'has_triton', _patch_has_triton)
 
+    setattr(torch._utils, '_get_available_device_type', _patch_get_available_device_type)
+
     _replace_to_method_in_allowed_methods()
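
For context, a minimal sketch (not part of the diff) of what the new hook changes at runtime. It assumes a working torch_npu installation with an NPU device visible; importing transfer_to_npu applies the patches as a side effect of _init().

    import torch
    import torch_npu
    # Importing the module runs _init(), which installs the patches above.
    from torch_npu.contrib import transfer_to_npu

    # torch._utils._get_available_device_type() is consulted by stock PyTorch
    # helpers (e.g. torch.nn.DataParallel's default device discovery). With the
    # patch installed it reports 'npu' instead of probing for CUDA.
    print(torch._utils._get_available_device_type())  # 'npu' when torch.npu.is_available()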