From 07ad7020ff3852ca0eb6a171bef2e4f04a1bfd8f Mon Sep 17 00:00:00 2001 From: yuhaiyan Date: Wed, 10 Sep 2025 11:54:50 +0800 Subject: [PATCH] Fixed test_set_snapshort_fn --- .../test_pluggable_allocator_extensions.py | 12 --- test/allocator/test_set_snapshot.py | 79 +++++++++++++++++++ 2 files changed, 79 insertions(+), 12 deletions(-) create mode 100644 test/allocator/test_set_snapshot.py diff --git a/test/allocator/test_pluggable_allocator_extensions.py b/test/allocator/test_pluggable_allocator_extensions.py index d63051220d..4b015dcdf5 100644 --- a/test/allocator/test_pluggable_allocator_extensions.py +++ b/test/allocator/test_pluggable_allocator_extensions.py @@ -98,18 +98,6 @@ class TestPluggableAllocator(TestCase): torch.npu.reset_peak_memory_stats() self.assertEqual(torch.npu.max_memory_allocated(), 0) - def test_set_snapshot_fn(self): - os_path = os.path.join(TestPluggableAllocator.build_directory, 'pluggable_allocator_extensions.so') - myallocator = ctypes.CDLL(os_path) - snapshot_fn = ctypes.cast(getattr(myallocator, "my_snapshot"), ctypes.c_void_p).value - - msg = "snapshot_fn_ is not define, please set by set_snapshot_fn" - with self.assertRaisesRegex(RuntimeError, msg): - torch.npu.memory_snapshot() - - TestPluggableAllocator.new_alloc.allocator().set_snapshot_fn(snapshot_fn) - self.assertEqual(torch.npu.memory_snapshot(), []) - def test_pluggable_allocator_after_init(self): os_path = os.path.join(TestPluggableAllocator.build_directory, 'pluggable_allocator_extensions.so') # Do an initial memory allocator diff --git a/test/allocator/test_set_snapshot.py b/test/allocator/test_set_snapshot.py new file mode 100644 index 0000000000..bf30fd3b55 --- /dev/null +++ b/test/allocator/test_set_snapshot.py @@ -0,0 +1,79 @@ +import os +import sys +import shutil +import subprocess +import ctypes +import torch +import torch.utils.cpp_extension + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests + + +PYTORCH_INSTALL_PATH = os.path.dirname(os.path.realpath(torch.__file__)) +PYTORCH_NPU_INSTALL_PATH = os.path.dirname(os.path.realpath(torch_npu.__file__)) + + +def create_build_path(build_directory): + if os.path.exists(build_directory): + shutil.rmtree(build_directory, ignore_errors=True) + os.makedirs(build_directory, exist_ok=True) + + +def build_stub(base_dir): + build_stub_cmd = ["sh", os.path.join(base_dir, 'third_party/acl/libs/build_stub.sh')] + if subprocess.call(build_stub_cmd) != 0: + raise RuntimeError('Failed to build stub: {}'.format(build_stub_cmd)) + + +class TestSnapshot(TestCase): + module = None + build_directory = "allocator/build" + + @classmethod + def setUpClass(cls): + # Build Extension + BASE_DIR = os.path.abspath("./../") + build_stub(BASE_DIR) + create_build_path(cls.build_directory) + CANN_LIB_PATH = os.path.join(BASE_DIR, 'third_party/acl/libs') + extra_ldflags = [] + extra_ldflags.append("-lascendcl") + extra_ldflags.append(f"-L{CANN_LIB_PATH}") + extra_ldflags.append("-lc10") + extra_ldflags.append(f"-L{PYTORCH_INSTALL_PATH}") + extra_include_paths = ["cpp_extensions"] + extra_include_paths.append(os.path.join(PYTORCH_NPU_INSTALL_PATH, 'include')) + extra_include_paths.append(os.path.join(PYTORCH_NPU_INSTALL_PATH, 'include/third_party/hccl/inc')) + extra_include_paths.append(os.path.join(PYTORCH_NPU_INSTALL_PATH, 'include/third_party/acl/inc')) + + cls.module = torch.utils.cpp_extension.load( + name="pluggable_allocator_extensions", + sources=[ + "cpp_extensions/pluggable_allocator_extensions.cpp" + ], + extra_include_paths=extra_include_paths, + extra_cflags=["-g"], + extra_ldflags=extra_ldflags, + build_directory=cls.build_directory, + verbose=True, + ) + + def test_set_snapshot_fn(self): + os_path = os.path.join(TestSnapshot.build_directory, 'pluggable_allocator_extensions.so') + myallocator = ctypes.CDLL(os_path) + snapshot_fn = ctypes.cast(getattr(myallocator, "my_snapshot"), ctypes.c_void_p).value + # Load the allocator + new_alloc = torch_npu.npu.memory.NPUPluggableAllocator(os_path, 'my_malloc', 'my_free') + # Swap the current allocator + torch_npu.npu.memory.change_current_allocator(new_alloc) + msg = "snapshot_fn_ is not define, please set by set_snapshot_fn" + with self.assertRaisesRegex(RuntimeError, msg): + torch.npu.memory_snapshot() + + new_alloc.allocator().set_snapshot_fn(snapshot_fn) + self.assertEqual(torch.npu.memory_snapshot(), []) + + +if __name__ == "__main__": + run_tests() -- Gitee