From 8908ddb50b2b28f0aab1579637f8ba4aca0a4ce1 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Tue, 19 Aug 2025 22:10:44 +0800 Subject: [PATCH 1/8] Support stream into Dynamo charts --- test/dynamo/test_stream.py | 29 +++++++++++++++++++++++++++++ torch_npu/__init__.py | 4 ++++ torch_npu/dynamo/__init__.py | 1 + torch_npu/dynamo/trace_rule.py | 15 +++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 test/dynamo/test_stream.py create mode 100644 torch_npu/dynamo/trace_rule.py diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py new file mode 100644 index 00000000000..0b19cca8d2d --- /dev/null +++ b/test/dynamo/test_stream.py @@ -0,0 +1,29 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch_npu + +import torch._dynamo.test_case + +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + +class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): + + @requires_npu() + def test_stream(self): + def model_1(x): + a = x * x + s = torch.npu.stream() + s.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(s): + b = x + a + return b + inp = torch.randn(2,8).npu() + m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + output = m(inp) + output1 = model_1(inp) + torch.allclose(output,, output1) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 9612c47c4ff..1d7b73cf9fc 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -94,6 +94,7 @@ from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs +from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() @@ -298,6 +299,9 @@ if 'TORCH_NPU_SANITIZER' in os.environ: # register npu device op overrides for inductor _inductor_register_device_op_overrides() +# Support stream into Dynamo charts +_patch_npu_trace_rules() + if hasattr(sys, 'ps1'): os.environ["TASK_QUEUE_ENABLE"] = '0' warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \ diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py index a6c23570873..95be98be633 100644 --- a/torch_npu/dynamo/__init__.py +++ b/torch_npu/dynamo/__init__.py @@ -10,6 +10,7 @@ from torch.library import Library, impl from torch_npu.utils._error_code import ErrCode, pta_error from torch_npu.utils.utils import _should_print_warning +from .trace_rule import _patch_npu_trace_rules _global_npu_backend = None __all__ = [] diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py new file mode 100644 index 00000000000..01e8992fa1a --- /dev/null +++ b/torch_npu/dynamo/trace_rule.py @@ -0,0 +1,15 @@ +import torch +from torch._dynamo.variables import TorchInGraphFunctionVariable + +torch_c_binding_in_graph_functions_npu = dict.fromkeys( + [ + "torch.npu.current_stream", + "torch.npu.default_stream", + "torch.npu.stream", + "torch.npu.set_stream", + ] +) + +def _patch_npu_trace_rules(): + torch._dynamo.trace_rules.clear_lru_cache() + torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From df242506eb509201b170357330b48c649f38ef9b Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 09:11:02 +0800 Subject: [PATCH 2/8] cleancode --- test/dynamo/test_stream.py | 2 +- torch_npu/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 0b19cca8d2d..4b9e30bfb91 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -21,7 +21,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): m = torch.compile(model_1,backend="aot_eager",fullgraph=True) output = m(inp) output1 = model_1(inp) - torch.allclose(output,, output1) + torch.allclose(output, output1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index 1d7b73cf9fc..e47043fa298 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -90,11 +90,11 @@ from torch_npu.npu.utils import _erase_stream as erase_stream from torch_npu.utils._error_code import ErrCode, pta_error, _except_handler from torch_npu.asd.asd import _asd_patch from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum +from torch_npu.dynamo import _patch_npu_trace_rules from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs -from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() -- Gitee From 0d0ee461280c4d86bb80a3d8bf88f3ac37733245 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 14:53:00 +0800 Subject: [PATCH 3/8] cleancode --- torch_npu/dynamo/trace_rule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 01e8992fa1a..ca6cd41bfc4 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -7,7 +7,8 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( "torch.npu.default_stream", "torch.npu.stream", "torch.npu.set_stream", - ] + ], + TorchInGraphFunctionVariable, ) def _patch_npu_trace_rules(): -- Gitee From 79eddff7eb85b476a16b400c2c130650bb6455a7 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 15:32:55 +0800 Subject: [PATCH 4/8] cleancode --- test/dynamo/test_stream.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 4b9e30bfb91..cdf614ed413 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -1,11 +1,13 @@ # Owner(s): ["module: dynamo"] +import functools +import unittest import torch -import torch_npu - import torch._dynamo.test_case +import torch_npu requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): @requires_npu() @@ -17,12 +19,13 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): with torch.npu.stream(s): b = x + a return b - inp = torch.randn(2,8).npu() - m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + inp = torch.randn(2, 8).npu() + m = torch.compile(model_1, backend="aot_eager", fullgraph=True) output = m(inp) output1 = model_1(inp) torch.allclose(output, output1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests -- Gitee From 504fa31f956e611fab644f91f4411083e40edf45 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 16:57:14 +0800 Subject: [PATCH 5/8] cleancode --- torch_npu/dynamo/trace_rule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index ca6cd41bfc4..3f3593dddc0 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -11,6 +11,7 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( TorchInGraphFunctionVariable, ) + def _patch_npu_trace_rules(): torch._dynamo.trace_rules.clear_lru_cache() torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From f361a46af1863d0aa08b1da155f1cdb5a4fe2a45 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 18:53:37 +0800 Subject: [PATCH 6/8] fix test --- test/dynamo/test_stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index cdf614ed413..13bd0ec26af 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -14,7 +14,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): def test_stream(self): def model_1(x): a = x * x - s = torch.npu.stream() + s = torch.npu.Stream() s.wait_stream(torch.npu.current_stream()) with torch.npu.stream(s): b = x + a -- Gitee From 7a6ccd46c57ae7a5bc893df6e34f740428a159fe Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 19:07:17 +0800 Subject: [PATCH 7/8] fix test --- test/allowlist_for_publicAPI.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 97e775a73b7..9dae977e07a 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2881,5 +2881,8 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" + ], + "torch_npu.dynamo.trace_rule": [ + "TorchInGraphFunctionVariable" ] } -- Gitee From e7a7a501a87a6923a6a3592a6fc9c073c3298a9b Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Thu, 21 Aug 2025 09:42:59 +0800 Subject: [PATCH 8/8] fix test --- test/allowlist_for_publicAPI.json | 3 --- torch_npu/dynamo/trace_rule.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 9dae977e07a..97e775a73b7 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2881,8 +2881,5 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" - ], - "torch_npu.dynamo.trace_rule": [ - "TorchInGraphFunctionVariable" ] } diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 3f3593dddc0..856aa214b6d 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -1,6 +1,8 @@ import torch from torch._dynamo.variables import TorchInGraphFunctionVariable +__all__ = [] + torch_c_binding_in_graph_functions_npu = dict.fromkeys( [ "torch.npu.current_stream", -- Gitee