From 69fcfb48bdf8d24baca567e94534934a88180543 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Tue, 19 Aug 2025 22:10:44 +0800 Subject: [PATCH 1/6] Support stream into Dynamo charts --- test/dynamo/test_stream.py | 29 +++++++++++++++++++++++++++++ torch_npu/__init__.py | 4 ++++ torch_npu/dynamo/__init__.py | 1 + torch_npu/dynamo/trace_rule.py | 15 +++++++++++++++ 4 files changed, 49 insertions(+) create mode 100644 test/dynamo/test_stream.py create mode 100644 torch_npu/dynamo/trace_rule.py diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py new file mode 100644 index 0000000000..0b19cca8d2 --- /dev/null +++ b/test/dynamo/test_stream.py @@ -0,0 +1,29 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch_npu + +import torch._dynamo.test_case + +requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + +class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): + + @requires_npu() + def test_stream(self): + def model_1(x): + a = x * x + s = torch.npu.stream() + s.wait_stream(torch.npu.current_stream()) + with torch.npu.stream(s): + b = x + a + return b + inp = torch.randn(2,8).npu() + m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + output = m(inp) + output1 = model_1(inp) + torch.allclose(output,, output1) + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index f571d55240..b8807797e5 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -94,6 +94,7 @@ from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs +from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() @@ -300,6 +301,9 @@ if 'TORCH_NPU_SANITIZER' in os.environ: # register npu device op overrides for inductor _inductor_register_device_op_overrides() +# Support stream into Dynamo charts +_patch_npu_trace_rules() + if hasattr(sys, 'ps1'): os.environ["TASK_QUEUE_ENABLE"] = '0' warnings.warn("On the interactive interface, the value of TASK_QUEUE_ENABLE is set to 0 by default. \ diff --git a/torch_npu/dynamo/__init__.py b/torch_npu/dynamo/__init__.py index a6c2357087..95be98be63 100644 --- a/torch_npu/dynamo/__init__.py +++ b/torch_npu/dynamo/__init__.py @@ -10,6 +10,7 @@ from torch.library import Library, impl from torch_npu.utils._error_code import ErrCode, pta_error from torch_npu.utils.utils import _should_print_warning +from .trace_rule import _patch_npu_trace_rules _global_npu_backend = None __all__ = [] diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py new file mode 100644 index 0000000000..01e8992fa1 --- /dev/null +++ b/torch_npu/dynamo/trace_rule.py @@ -0,0 +1,15 @@ +import torch +from torch._dynamo.variables import TorchInGraphFunctionVariable + +torch_c_binding_in_graph_functions_npu = dict.fromkeys( + [ + "torch.npu.current_stream", + "torch.npu.default_stream", + "torch.npu.stream", + "torch.npu.set_stream", + ] +) + +def _patch_npu_trace_rules(): + torch._dynamo.trace_rules.clear_lru_cache() + torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From c353260eabfab190e7fe19684746b01f01815b4e Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 14:54:21 +0800 Subject: [PATCH 2/6] cleancode --- test/dynamo/test_stream.py | 2 +- torch_npu/__init__.py | 2 +- torch_npu/dynamo/trace_rule.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 0b19cca8d2..4b9e30bfb9 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -21,7 +21,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): m = torch.compile(model_1,backend="aot_eager",fullgraph=True) output = m(inp) output1 = model_1(inp) - torch.allclose(output,, output1) + torch.allclose(output, output1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index b8807797e5..a6be46a08d 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -92,9 +92,9 @@ from torch_npu.asd.asd import _asd_patch from torch_npu.asd.checksum import _matmul_checksum as matmul_checksum from torch_npu._C._distributed_c10d import ParallelStore from torch_npu.op_plugin.meta import _meta_registrations +from torch_npu.dynamo import _patch_npu_trace_rules from torch_npu.version import __version__ as __version__ from torch_npu import _op_plugin_docs -from torch_npu.dynamo import _patch_npu_trace_rules del _op_plugin_docs _cann_package_check() diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index 01e8992fa1..ca6cd41bfc 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -7,7 +7,8 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( "torch.npu.default_stream", "torch.npu.stream", "torch.npu.set_stream", - ] + ], + TorchInGraphFunctionVariable, ) def _patch_npu_trace_rules(): -- Gitee From 7f6bfaf7bd90df8404bdecc87f9064f254916f02 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 15:32:15 +0800 Subject: [PATCH 3/6] cleancode --- test/dynamo/test_stream.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index 4b9e30bfb9..cdf614ed41 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -1,11 +1,13 @@ # Owner(s): ["module: dynamo"] +import functools +import unittest import torch -import torch_npu - import torch._dynamo.test_case +import torch_npu requires_npu = functools.partial(unittest.skipIf, not torch.npu.is_available(), "requires npu") + class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): @requires_npu() @@ -17,12 +19,13 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): with torch.npu.stream(s): b = x + a return b - inp = torch.randn(2,8).npu() - m = torch.compile(model_1,backend="aot_eager",fullgraph=True) + inp = torch.randn(2, 8).npu() + m = torch.compile(model_1, backend="aot_eager", fullgraph=True) output = m(inp) output1 = model_1(inp) torch.allclose(output, output1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests -- Gitee From 8826ada1690b66d84ef8af52d648085202a1d46e Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 16:58:13 +0800 Subject: [PATCH 4/6] cleancode --- torch_npu/dynamo/trace_rule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_npu/dynamo/trace_rule.py b/torch_npu/dynamo/trace_rule.py index ca6cd41bfc..3f3593dddc 100644 --- a/torch_npu/dynamo/trace_rule.py +++ b/torch_npu/dynamo/trace_rule.py @@ -11,6 +11,7 @@ torch_c_binding_in_graph_functions_npu = dict.fromkeys( TorchInGraphFunctionVariable, ) + def _patch_npu_trace_rules(): torch._dynamo.trace_rules.clear_lru_cache() torch._dynamo.trace_rules.torch_name_rule_map.append(torch_c_binding_in_graph_functions_npu) -- Gitee From dfb6ed79ec5349c936abae9d364d3439e9d57304 Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 18:53:03 +0800 Subject: [PATCH 5/6] fix test --- test/dynamo/test_stream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/dynamo/test_stream.py b/test/dynamo/test_stream.py index cdf614ed41..13bd0ec26a 100644 --- a/test/dynamo/test_stream.py +++ b/test/dynamo/test_stream.py @@ -14,7 +14,7 @@ class StreamintoDynamoTests(torch._dynamo.test_case.TestCase): def test_stream(self): def model_1(x): a = x * x - s = torch.npu.stream() + s = torch.npu.Stream() s.wait_stream(torch.npu.current_stream()) with torch.npu.stream(s): b = x + a -- Gitee From b56d1e31e970e99d70c1242503257ff65b9b623c Mon Sep 17 00:00:00 2001 From: GuoGuanghao Date: Wed, 20 Aug 2025 19:08:15 +0800 Subject: [PATCH 6/6] fix test --- test/allowlist_for_publicAPI.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 356c2a5680..bb3d9caaad 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -2876,5 +2876,8 @@ "torch_npu.utils.profiler": [ "Singleton", "Profile" + ], + "torch_npu.dynamo.trace_rule": [ + "TorchInGraphFunctionVariable" ] } -- Gitee