From cf95bb62ec32e6f981eb75440ab5a88d45e8dcfd Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Tue, 19 Aug 2025 22:10:44 +0800 Subject: [PATCH 1/7] Support stream into Dynamo charts --- test/dynamo/test_stream.py | 29 +++++++++++++++++++++++++++++ torch_npu/__init__.py | 4 ++++ torch_npu/dynamo/__init__.py | 1 + torch_npu/dynamo/trace_rule.py | 15 +++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 test/dynamo/test_stream.py create mode 100644 torch_npu/dynamo/trace_rule.py diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py new file mode 100644 index 00000000000..0b19cca8d2d --- /dev/null +++ b/test/dynamo/test_stream.py @@ -0,0 +1,29 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch_npu + +import torch._dynamo.test_case + +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + +class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): + + @requires_npu() + def test_stream(self): + def model_1(x): + a = x * x + s = torch.npu.stream() + s.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(s): + b = x + a + return b + inp = torch.randn(2,8).npu() + m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + output = m(inp) + output1 = model_1(inp) + torch.allclose(output,, output1) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 9e7cb7b2508..36e3b9a6738 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -94,6 +94,7 @@ from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs +from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() @@ -299,6 +300,9 @@ if 'TORCH_NPU_SANITIZER' in os.environ: # register npu device op overrides for inductor _inductor_register_device_op_overrides() +# Support stream into Dynamo charts +_patch_npu_trace_rules() + if hasattr(sys, 'ps1'): os.environ["TASK_QUEUE_ENABLE"] = '0' warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \ diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py index a6c23570873..95be98be633 100644 --- a/torch_npu/dynamo/__init__.py +++ b/torch_npu/dynamo/__init__.py @@ -10,6 +10,7 @@ from torch.library import Library, impl from torch_npu.utils._error_code import ErrCode, pta_error from torch_npu.utils.utils import _should_print_warning +from .trace_rule import _patch_npu_trace_rules _global_npu_backend = None __all__ = [] diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py new file mode 100644 index 00000000000..01e8992fa1a --- /dev/null +++ b/torch_npu/dynamo/trace_rule.py @@ -0,0 +1,15 @@ +import torch +from torch._dynamo.variables import TorchInGraphFunctionVariable + +torch_c_binding_in_graph_functions_npu = dict.fromkeys( + [ + "torch.npu.current_stream", + "torch.npu.default_stream", + "torch.npu.stream", + "torch.npu.set_stream", + ] +) + +def _patch_npu_trace_rules(): + torch._dynamo.trace_rules.clear_lru_cache() + torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From f91d77eb4f34be0a7a2b633215ab02badba841e2 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 14:56:26 +0800 Subject: [PATCH 2/7] cleancode --- test/dynamo/test_stream.py | 2 +- torch_npu/__init__.py | 2 +- torch_npu/dynamo/trace_rule.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 0b19cca8d2d..4b9e30bfb91 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -21,7 +21,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): m = torch.compile(model_1,backend="aot_eager",fullgraph=True) output = m(inp) output1 = model_1(inp) - torch.allclose(output,, output1) + torch.allclose(output, output1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 36e3b9a6738..4822897a6d2 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -92,9 +92,9 @@ from torch_npu.asd.asd import _asd_patch from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations +from torch_npu.dynamo import _patch_npu_trace_rules from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs -from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 01e8992fa1a..ca6cd41bfc4 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -7,7 +7,8 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( "torch.npu.default_stream", "torch.npu.stream", "torch.npu.set_stream", - ] + ], + TorchInGraphFunctionVariable, ) def _patch_npu_trace_rules(): -- Gitee From bf4e4b51e327a54df5a50e59e2312af957afe813 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 15:31:58 +0800 Subject: [PATCH 3/7] cleancode --- test/dynamo/test_stream.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 4b9e30bfb91..cdf614ed413 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -1,11 +1,13 @@ # Owner(s): ["module: dynamo"] +import functools +import unittest import torch -import torch_npu - import torch._dynamo.test_case +import torch_npu requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): @requires_npu() @@ -17,12 +19,13 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): with torch.npu.stream(s): b = x + a return b - inp = torch.randn(2,8).npu() - m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + inp = torch.randn(2, 8).npu() + m = torch.compile(model_1, backend="aot_eager", fullgraph=True) output = m(inp) output1 = model_1(inp) torch.allclose(output, output1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests -- Gitee From 0d82f218ac3f7d1d248ee8ebfc19929c6b9036d3 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 16:58:30 +0800 Subject: [PATCH 4/7] cleancode --- torch_npu/dynamo/trace_rule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index ca6cd41bfc4..3f3593dddc0 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -11,6 +11,7 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( TorchInGraphFunctionVariable, ) + def _patch_npu_trace_rules(): torch._dynamo.trace_rules.clear_lru_cache() torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From 1e51334d64143556091a3db083aefd39923a61dd Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 18:52:47 +0800 Subject: [PATCH 5/7] fix test --- test/dynamo/test_stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index cdf614ed413..13bd0ec26af 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -14,7 +14,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): def test_stream(self): def model_1(x): a = x * x - s = torch.npu.stream() + s = torch.npu.Stream() s.wait_stream(torch.npu.current_stream()) with torch.npu.stream(s): b = x + a -- Gitee From d7568ce07b51fe33d7c8f91e5ea5c761c7821f65 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 19:08:30 +0800 Subject: [PATCH 6/7] fix test --- test/allowlist_for_publicAPI.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 982a8e02f9e..6d46cb6bb33 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2877,5 +2877,8 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" + ], + "torch_npu.dynamo.trace_rule": [ + "TorchInGraphFunctionVariable" ] } -- Gitee From 3e46ea8e4662b5c728fbe11b8a8b7f6edebbaf55 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Thu, 21 Aug 2025 09:41:46 +0800 Subject: [PATCH 7/7] fix test --- test/allowlist_for_publicAPI.json | 3 --- torch_npu/dynamo/trace_rule.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 6d46cb6bb33..982a8e02f9e 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2877,8 +2877,5 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" - ], - "torch_npu.dynamo.trace_rule": [ - "TorchInGraphFunctionVariable" ] } diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 3f3593dddc0..856aa214b6d 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -1,6 +1,8 @@ import torch from torch._dynamo.variables import TorchInGraphFunctionVariable +__all__ = [] + torch_c_binding_in_graph_functions_npu = dict.fromkeys( [ "torch.npu.current_stream", -- Gitee