diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..13bd0ec26afe4dc32d926c302d4f72620d695be0
--- /dev/null
+++ b/test/dynamo/test_stream.py
@@ -0,0 +1,32 @@
+# Owner(s): ["module: dynamo"]
+import functools
+import unittest
+import torch
+import torch._dynamo.test_case
+import torch_npu
+
+requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu")
+
+
+class StreamintoDynamoTests(torch._dynamo.test_case.TestCase):
+
+    @requires_npu()
+    def test_stream(self):
+        def model_1(x):
+            a = x * x
+            s = torch.npu.Stream()
+            s.wait_stream(torch.npu.current_stream())
+            with torch.npu.stream(s):
+                b = x + a
+            return b
+        inp = torch.randn(2, 8).npu()
+        m = torch.compile(model_1, backend="aot_eager", fullgraph=True)
+        output = m(inp)
+        output1 = model_1(inp)
+        self.assertTrue(torch.allclose(output, output1))
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index 9612c47c4ffbe6b4a8ff6fff2408a2f4bb051ee7..e47043fa2984ab44ed4669237e41100579d7148d 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -90,6 +90,7 @@ from torch_npu.npu.utils import _erase_stream as erase_stream
 from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler
 from torch_npu.asd.asd import _asd_patch
 from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum
+from torch_npu.dynamo import _patch_npu_trace_rules
 from torch_npu._C._distributed_c10d import ParallelStore
 from torch_npu.op_plugin.meta import _meta_registrations
 from torch_npu.version import __version__ as __version__
@@ -298,6 +299,9 @@ if 'TORCH_NPU_SANITIZER' in os.environ:
 # register npu device op overrides for inductor
 _inductor_register_device_op_overrides()
 
+# Register torch.npu stream APIs with Dynamo so they are captured in the graph
+_patch_npu_trace_rules()
+
 if hasattr(sys, 'ps1'):
     os.environ["TASK_QUEUE_ENABLE"] = '0'
warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \ diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py index a6c235708733bcc8b1fe28e285e47bd821d41a6d..95be98be633ffb47d92b380b095e7cd04173b205 100644 --- a/torch_npu/dynamo/__init__.py +++ b/torch_npu/dynamo/__init__.py @@ -10,6 +10,7 @@ from torch.library import Library, impl from torch_npu.utils._error_code import ErrCode, pta_error from torch_npu.utils.utils import _should_print_warning +from .trace_rule import _patch_npu_trace_rules _global_npu_backend = None __all__ = [] diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py new file mode 100644 index 0000000000000000000000000000000000000000..856aa214b6dc249b9ddd4211b1723e97ffc59356 --- /dev/null +++ b/torch_npu/dynamo/trace_rule.py @@ -0,0 +1,19 @@ +import torch +from torch._dynamo.variables import TorchInGraphFunctionVariable + +__all__ = [] + +torch_c_binding_in_graph_functions_npu = dict.fromkeys( + [ + "torch.npu.current_stream", + "torch.npu.default_stream", + "torch.npu.stream", + "torch.npu.set_stream", + ], + TorchInGraphFunctionVariable, +) + + +def _patch_npu_trace_rules(): + torch._dynamo.trace_rules.clear_lru_cache() + torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu)