From 9bb95a9c0a32e5d7a62848fbedc416045efe76e3 Mon Sep 17 00:00:00 2001 From: shibo19 Date: Tue, 8 Feb 2022 15:50:17 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AE=80=E5=8C=96aoe=E4=BD=BF=E8=83=BD?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_aoe/test_aoe.py | 3 +-- torch_npu/npu/__init__.py | 4 ++-- torch_npu/npu/npu_frontend_enhance.py | 14 ++++++++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/test/test_aoe/test_aoe.py b/test/test_aoe/test_aoe.py index c6310ba546..d465e04368 100644 --- a/test/test_aoe/test_aoe.py +++ b/test/test_aoe/test_aoe.py @@ -48,8 +48,7 @@ class TestAoe(TestCase): @classmethod def enable_aoe(cls): - option = {"autotune": "enable", "autotunegraphdumppath": TestAoe.results_path} - torch.npu.set_option(option) + torch.npu.set_aoe(TestAoe.results_path) def test_aoe_dumpgraph(self): def train(): diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 758c2abd34..d2f0179f38 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -26,7 +26,7 @@ __all__ = [ "reset_peak_memory_stats", "reset_max_memory_allocated", "reset_max_memory_cached", "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved", "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", - "Stream", "Event", "profiler", "set_option" + "Stream", "Event", "profiler", "set_option", "set_aoe" ] @@ -44,4 +44,4 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del memory_cached, max_memory_cached, memory_snapshot, memory_summary) from .streams import Stream, Event from . import profiler -from .npu_frontend_enhance import set_option \ No newline at end of file +from .npu_frontend_enhance import set_option, set_aoe \ No newline at end of file diff --git a/torch_npu/npu/npu_frontend_enhance.py b/torch_npu/npu/npu_frontend_enhance.py index af120c713a..bb3b14bfa7 100644 --- a/torch_npu/npu/npu_frontend_enhance.py +++ b/torch_npu/npu/npu_frontend_enhance.py @@ -14,11 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from logging import exception import os import torch_npu._C # this file is used to enhance the npu frontend API by set_option or other. -__all__ = ["set_option", "global_step_inc", "set_start_fuzz_compile_step"] +__all__ = ["set_option", "global_step_inc", "set_start_fuzz_compile_step", "set_aoe"] def set_option(option): if not isinstance(option, dict): @@ -34,7 +35,7 @@ def init_dump(): def set_dump(cfg_file): if not os.path.exists(cfg_file): - raise AssertionError("cfg_file %s path not exists."%(cfg_file)) + raise AssertionError("cfg_file %s path does not exists."%(cfg_file)) cfg_file = os.path.abspath(cfg_file) option = {"mdldumpconfigpath": cfg_file} torch_npu._C._npu_setOption(option) @@ -62,3 +63,12 @@ def set_start_fuzz_compile_step(step): option = {"fuzzycompileswitch": "disable"} torch_npu._C._npu_setOption(option) +def set_aoe(dump_path): + if os.path.exists(dump_path): + option = {"autotune": "enable", "autotunegraphdumppath": dump_path} + torch_npu._C._npu_setOption(option) + else: + try: + os.makedirs(dump_path) + except Exception: + raise ValueError("the path of '%s' is invaild."%(dump_path)) -- Gitee