diff --git a/torch_npu/contrib/transfer_to_npu.py b/torch_npu/contrib/transfer_to_npu.py
index 2c37fcb17b52f4c91afb05a9e337d1d839511e82..40f0bcc65cfbc55716f715697ef803d544129943 100644
--- a/torch_npu/contrib/transfer_to_npu.py
+++ b/torch_npu/contrib/transfer_to_npu.py
@@ -1,14 +1,17 @@
 import os
 import warnings
 import json
+import collections
 import importlib.metadata
 import logging as logger
 import functools
 from functools import wraps
+from typing import Callable, cast, Optional
 import torch
 from torch.utils._device import _device_constructors
 from torch.utils._triton import has_triton
 from torch.nn.parameter import UninitializedTensorMixin
+from torch._utils import _get_device_module
 import torch_npu
 
 try:
@@ -29,7 +32,7 @@ torch_fn_white_list = ['logspace', 'randint', 'hann_window', 'rand', 'full_like'
                       'zeros_like', 'range', 'sparse_csr_tensor', 'randn_like', 'from_file',
                       '_cudnn_init_dropout_state', '_empty_affine_quantized', 'linspace', 'hamming_window',
                       'empty_quantized', '_pin_memory', 'autocast', 'load', 'set_default_device']
-torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to', 
+torch_tensor_fn_white_list = ['new_empty', 'new_empty_strided', 'new_full', 'new_ones', 'new_tensor', 'new_zeros', 'to',
                               'pin_memory']
 torch_module_fn_white_list = ['to', 'to_empty']
 torch_cuda_fn_white_list = [
@@ -261,6 +264,56 @@ def _patch_has_triton():
     return False
 
 
+def _patch_get_available_device_type():
+    """Return ``'npu'`` when an NPU device is available, otherwise ``None``.
+
+    ``_init`` installs this as ``torch._utils._get_available_device_type``.
+    """
+    if torch.npu.is_available():
+        return 'npu'
+    return None
+
+
+def _patch_OverlappingCpuLoader_init_(self, resolve_fun: Callable, stream: Optional[torch.Stream] = None,
+                                      inflight_threshhold: int = 1_000_000) -> None:
+    """Initialize an ``_OverlappingCpuLoader`` using the NPU-aware device lookup.
+
+    ``_init`` installs this as ``_OverlappingCpuLoader.__init__``.
+
+    Args:
+        resolve_fun: Callable stored on the instance as ``resolve_fun``.
+        stream: Stream to use. When omitted, the device type comes from
+            ``_patch_get_available_device_type`` and that device module's current stream is used.
+        inflight_threshhold: Value stored on the instance as ``inflight_threshhold``.
+
+    Raises:
+        RuntimeError: If the resolved device type is ``None`` (for example, no stream is given
+            and no NPU device is available).
+    """
+    self.resolve_fun = resolve_fun
+    self.items: list[tuple[int, object]] = []
+    self.inflight_threshhold = inflight_threshhold
+    self.in_flight_data = 0
+    self.current_items: collections.deque = collections.deque()
+    self.idx = 0
+    self.started = False
+    self.device_type = (
+        stream.device_type if stream else _patch_get_available_device_type()
+    )
+    if self.device_type is None:
+        # Fail here with a clear message instead of handing None on as a device type name.
+        raise RuntimeError(
+            "No NPU device is available and no stream was provided; "
+            "cannot select a device module for _OverlappingCpuLoader."
+        )
+    self.device_module = _get_device_module(self.device_type)
+    self.stream = cast(
+        torch.cuda.Stream, stream or self.device_module.current_stream()
+    )
+    if self.stream != self.device_module.current_stream():
+        # A stream other than the current one first waits on the device's current stream.
+        self.stream.wait_stream(self.device_module.current_stream())
+
+
 def _patch_cuda():
     patchs = [
         ['cuda', torch_npu.npu], ['cuda.amp', torch_npu.npu.amp],
@@ -379,6 +432,10 @@ def _init():
 
     setattr(torch.utils._triton, 'has_triton', _patch_has_triton)
 
+    setattr(torch._utils, '_get_available_device_type', _patch_get_available_device_type)
+    setattr(torch.distributed.checkpoint.filesystem._OverlappingCpuLoader, '__init__',
+            _patch_OverlappingCpuLoader_init_)
+
     _replace_to_method_in_allowed_methods()
 
 