diff --git a/torch_npu/contrib/transfer_to_npu.py b/torch_npu/contrib/transfer_to_npu.py
index e7e6014c4f824a5d3d98d5fb1a73b27748957520..bc78c3147f26a677e1a2dba263f4e9969a5ebe96 100644
--- a/torch_npu/contrib/transfer_to_npu.py
+++ b/torch_npu/contrib/transfer_to_npu.py
@@ -29,7 +29,7 @@ torch_fn_white_list = ['logspace', 'randint', 'hann_window', 'rand', 'full_like'
                        'zeros_like', 'range', 'sparse_csr_tensor', 'randn_like', 'from_file',
                        '_cudnn_init_dropout_state', '_empty_affine_quantized', 'linspace', 'hamming_window',
                        'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device']
-torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to', 
+torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to',
                               'pin_memory']
 torch_module_fn_white_list = ['to', 'to_empty']
 torch_cuda_fn_white_list = [
@@ -261,6 +261,12 @@ def _patch_has_triton():
     return False
 
 
+def _patch_get_available_device_type():
+    if torch.npu.is_available():
+        return 'npu'
+    return None
+
+
 def _patch_cuda():
     patchs = [
         ['cuda', torch_npu.npu], ['cuda.amp', torch_npu.npu.amp],
@@ -378,6 +384,8 @@ def _init():
 
     setattr(torch.utils._triton, 'has_triton', _patch_has_triton)
 
+    setattr(torch._utils, '_get_available_device_type', _patch_get_available_device_type)
+
     _replace_to_method_in_allowed_methods()
 
 
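
Reviewer note: `torch._utils._get_available_device_type()` is the internal PyTorch helper that device-agnostic code paths consult to pick an accelerator backend, so after this patch it reports `'npu'` rather than `'cuda'`. A minimal sketch of the observable effect, assuming torch_npu is installed on an Ascend machine and that importing `transfer_to_npu` still triggers `_init()` as it does in this file (the printed value is illustrative):

```python
import torch
import torch_npu
# Importing the module runs _init(), which applies the monkey patches in this
# file, including the new setattr on torch._utils._get_available_device_type.
from torch_npu.contrib import transfer_to_npu

# With an NPU present, the patched helper reports 'npu'; with no NPU it
# returns None, mirroring _patch_get_available_device_type in the diff.
print(torch._utils._get_available_device_type())  # expected: 'npu'
```

Note the patch returns `None` (not `'cuda'`) when no NPU is available, so CUDA-only fallbacks stay disabled once the module is imported.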