diff --git a/test/contrib/test_transfer_to_npu.py b/test/contrib/test_transfer_to_npu.py
index a4efff90741f932e36d5d4834d037da42df97e82..cdd3f6298cc081d2a55ff53e94c6522bbaf56bbe 100644
--- a/test/contrib/test_transfer_to_npu.py
+++ b/test/contrib/test_transfer_to_npu.py
@@ -12,6 +12,27 @@ from torch_npu.contrib import transfer_to_npu
 
 
 class TestTransferToNpu(TestCase):
+    def test_generator(self):
+        g0 = torch.Generator()
+        self.assertTrue(isinstance(g0, torch.Generator))
+        self.assertEqual(g0.device.type, 'cpu')
+
+        g1 = torch.Generator('cuda')
+        self.assertTrue(isinstance(g1, torch.Generator))
+        self.assertEqual(g1.device.type, 'npu')
+
+        g2 = torch.Generator(torch.device('cuda'))
+        self.assertTrue(isinstance(g2, torch.Generator))
+        self.assertEqual(g2.device.type, 'npu')
+
+        g3 = torch.Generator(device='cuda')
+        self.assertTrue(isinstance(g3, torch.Generator))
+        self.assertEqual(g3.device.type, 'npu')
+
+        g4 = torch.Generator(device=torch.device('cuda'))
+        self.assertTrue(isinstance(g4, torch.Generator))
+        self.assertEqual(g4.device.type, 'npu')
+
     def test_wrap_isinstance(self):
         # check builtins isinstance grammar
         self.assertTrue(isinstance(1, int))
diff --git a/torch_npu/contrib/transfer_to_npu.py b/torch_npu/contrib/transfer_to_npu.py
index bf3dcf6e37c3dff3bd18223e6b3df88a07550543..38fd3046507122cdbd398da68c2b941c0b1465a3 100644
--- a/torch_npu/contrib/transfer_to_npu.py
+++ b/torch_npu/contrib/transfer_to_npu.py
@@ -28,7 +28,7 @@ torch_fn_white_list = ['logspace', 'randint', 'hann_window', 'rand', 'full_like'
                        'eye', '_sparse_csr_tensor_unsafe', 'empty', '_sparse_coo_tensor_unsafe', 'blackman_window',
                        'zeros_like', 'range', 'sparse_csr_tensor', 'randn_like', 'from_file',
                        '_cudnn_init_dropout_state', '_empty_affine_quantized', 'linspace', 'hamming_window',
-                       'empty_quantized', '_pin_memory', 'autocast', 'load', "Generator", 'set_default_device']
+                       'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device']
 torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to',
                               'pin_memory']
 torch_module_fn_white_list = ['to', 'to_empty']
@@ -45,6 +45,14 @@ cur_path = os.path.dirname(os.path.realpath(__file__))
 config_path = os.path.join(cur_path, 'apis_config.json')
 
 
+class GeneratorProxy(torch.Generator):
+
+    def __new__(cls, device='cpu'):
+        device = _replace_cuda_to_npu_in_list([device], None)[0]
+        instance = super().__new__(cls, device)
+        return instance
+
+
 def _get_function_from_string(attribute_string):
     try:
         module_path, _, attr_name = attribute_string.rpartition('.')
@@ -332,6 +340,7 @@ def _init():
     # torch.*
     _device_wrapper(torch, torch_fn_white_list)
     torch.UntypedStorage.__new__ = _wrapper_cuda(torch.UntypedStorage.__new__)
+    torch.Generator = GeneratorProxy
 
     # torch.Tensor.*
     _device_wrapper(torch.Tensor, torch_tensor_fn_white_list)