diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..13bd0ec26afe4dc32d926c302d4f72620d695be0
--- /dev/null
+++ b/test/dynamo/test_stream.py
@@ -0,0 +1,32 @@
+# Owner(s): ["module: dynamo"]
+import functools
+import unittest
+import torch
+import torch._dynamo.test_case
+import torch_npu
+
+requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu")
+
+
+class StreamIntoDynamoTests(torch._dynamo.test_case.TestCase):
+
+    @requires_npu()
+    def test_stream(self):
+        def model_1(x):
+            a = x * x
+            s = torch.npu.Stream()
+            s.wait_stream(torch.npu.current_stream())
+            with torch.npu.stream(s):
+                b = x + a
+            return b
+        inp = torch.randn(2, 8).npu()
+        m = torch.compile(model_1, backend="aot_eager", fullgraph=True)
+        output = m(inp)
+        output1 = model_1(inp)
+        self.assertTrue(torch.allclose(output, output1))
+
+
+if __name__ == "__main__":
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index c4e663975387f26db735aa1874146ef52cd58718..e5145061aa6b9b8d5216c8a574d3ecc4ac16ad7f 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -92,6 +92,7 @@ from torch_npu.asd.asd import _asd_patch
 from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum
 from torch_npu._C._distributed_c10d import ParallelStore
 from torch_npu.op_plugin.meta import _meta_registrations
+from torch_npu.dynamo import _patch_npu_trace_rules
 from torch_npu.version import __version__ as __version__
 from torch_npu import _op_plugin_docs
 del _op_plugin_docs
@@ -299,6 +300,9 @@ if 'TORCH_NPU_SANITIZER' in os.environ:
 # register npu device op overrides for inductor
 _inductor_register_device_op_overrides()
 
+# Support streams in Dynamo graphs
+_patch_npu_trace_rules()
+
 if hasattr(sys, 'ps1'):
     os.environ["TASK_QUEUE_ENABLE"] = '0'
     warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \
diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py
index a6c235708733bcc8b1fe28e285e47bd821d41a6d..95be98be633ffb47d92b380b095e7cd04173b205 100644
--- a/torch_npu/dynamo/__init__.py
+++ b/torch_npu/dynamo/__init__.py
@@ -10,6 +10,7 @@ from torch.library import Library, impl
 
 from torch_npu.utils._error_code import ErrCode, pta_error
 from torch_npu.utils.utils import _should_print_warning
+from .trace_rule import _patch_npu_trace_rules
 
 _global_npu_backend = None
 __all__ = []
diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py
new file mode 100644
index 0000000000000000000000000000000000000000..856aa214b6dc249b9ddd4211b1723e97ffc59356
--- /dev/null
+++ b/torch_npu/dynamo/trace_rule.py
@@ -0,0 +1,21 @@
+import torch
+from torch._dynamo.variables import TorchInGraphFunctionVariable
+
+__all__ = []
+
+torch_c_binding_in_graph_functions_npu = dict.fromkeys(
+    [
+        "torch.npu.current_stream",
+        "torch.npu.default_stream",
+        "torch.npu.stream",
+        "torch.npu.set_stream",
+    ],
+    TorchInGraphFunctionVariable,
+)
+
+
+def _patch_npu_trace_rules():
+    # Invalidate Dynamo's cached rule lookups, then register the torch.npu
+    # stream APIs above so Dynamo can trace them into the graph.
+    torch._dynamo.trace_rules.clear_lru_cache()
+    torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu)