diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py
index c9a87b9c09a6afc0cfd311cda83aac6735b95453..3ccda401dacd486f5ae5a94b3398d22361474536 100644
--- a/torch_npu/npu/__init__.py
+++ b/torch_npu/npu/__init__.py
@@ -15,6 +15,7 @@ __all__ = [
     "device",
     "device_of",
     "stream",
+    "StreamContext",
     "set_stream",
     "current_stream",
     "default_stream",
diff --git a/torch_npu/npu/utils.py b/torch_npu/npu/utils.py
index f5bb5e68c40b37c982eae385c1199ef129200830..bc212c65a1e9f8ec62ad7fb0541b39eac760d1a2 100644
--- a/torch_npu/npu/utils.py
+++ b/torch_npu/npu/utils.py
@@ -4,6 +4,7 @@ from functools import lru_cache
 import warnings
 import contextlib
 from enum import Enum
+from typing import Any, Optional
 
 import torch
 from torch._utils import _get_device_index as _torch_get_device_index
@@ -182,39 +183,72 @@ class device_of(device):
         super(device_of, self).__init__(idx)
 
 
-@contextlib.contextmanager
-def stream(stream):
+class StreamContext:
     r"""Context-manager that selects a given stream.
 
     All NPU kernels queued within its context will be enqueued on a selected
     stream.
 
+    Args:
+        stream (Stream): selected stream. This manager is a no-op if it's
+            ``None``.
+    .. note:: Streams are per-device.
+    """
+    cur_stream: Optional["torch_npu.npu.Stream"]
+
+    def __init__(self, stream: Optional["torch_npu.npu.Stream"]):
+        self.stream = stream
+        self.idx = _get_device_index(None, True)
+        if not torch.jit.is_scripting():
+            if self.idx is None:
+                self.idx = -1
+
+        self.src_prev_stream = (
+            None if not torch.jit.is_scripting() else torch.npu.default_stream()
+        )
+        self.dst_prev_stream = (
+            None if not torch.jit.is_scripting() else torch.npu.default_stream()
+        )
+
+    def __enter__(self):
+        # Local cur_stream variable for type refinement
+        cur_stream = self.stream
+        # Return if stream is None or no NPU device is available
+        if cur_stream is None or self.idx == -1:
+            return
+        self.src_prev_stream = torch.npu.current_stream()
+
+        # If the stream is not on the current device, then
+        # set the current stream on the device
+        if self.src_prev_stream.device != cur_stream.device:
+            with device(cur_stream.device):
+                self.dst_prev_stream = torch.npu.current_stream(cur_stream.device)
+        torch.npu.set_stream(cur_stream)
+
+    def __exit__(self, type: Any, value: Any, traceback: Any):
+        # Local cur_stream variable for type refinement
+        cur_stream = self.stream
+        # If stream is None or no NPU device available, return
+        if cur_stream is None or self.idx == -1:
+            return
+
+        # Reset the stream on the original device
+        # and destination device
+        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
+            torch.npu.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
+        torch.npu.set_stream(self.src_prev_stream)  # type: ignore[arg-type]
+
+
+def stream(stream):
+    r"""Wrap around the Context-manager StreamContext that selects a given stream.
+
     Arguments:
         stream (Stream): selected stream. This manager is a no-op if it's
             ``None``.
-
-    .. note:: Streams are per-device.  If the selected stream is not on the
-        current device, this function will also change the current device to
-        match the stream.
+    .. note:: In eager mode stream is of type Stream class while in JIT it is
+        an object of the custom class ``torch.classes.npu.Stream``.
     """
-    if stream is None:
-        yield
-        return
-    src_prev_stream = current_stream()
-
-    if src_prev_stream.device != stream.device:
-        # The given stream is on a different device; have to restore the
-        # current_stream on that device on exit as well
-        with device(stream.device):
-            dst_prev_stream = current_stream()
-
-    torch.npu.set_stream(stream)
-    try:
-        yield
-    finally:
-        if src_prev_stream.device != stream.device:
-            torch.npu.set_stream(dst_prev_stream)
-        torch.npu.set_stream(src_prev_stream)
+    return StreamContext(stream)
 
 
 def set_stream(stream):