From 7ff94f6d16c698e59e876a7f290c51d2e718d420 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Thu, 23 Nov 2023 20:43:22 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BC=98=E5=8C=96summary=5Fonly=E6=A8=A1?= =?UTF-8?q?=E5=BC=8F=E4=B8=8Bdump=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/python/ptdbg_ascend/dump/dump.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 6e9a7f7b77..9e193e1b0a 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -49,8 +49,7 @@ multi_output_apis = ["_sort_", "npu_flash_attention"] class DataInfo(object): - def __init__(self, data, save_data, summary_data, dtype, shape): - self.data = data + def __init__(self, save_data, summary_data, dtype, shape): self.save_data = save_data self.summary_data = summary_data self.dtype = dtype @@ -75,7 +74,7 @@ def get_not_float_tensor_info(data): def get_scalar_data_info(data): summary_data = [data, data, data] - return DataInfo(data, data, summary_data, str(type(data)), str([])) + return DataInfo(data, summary_data, str(type(data)), str([])) def get_float_tensor_info(data): @@ -87,13 +86,15 @@ def get_float_tensor_info(data): def get_tensor_data_info(data, tensor_max, tensor_min, tensor_mean): summary_data = [] - saved_tensor = data.contiguous().cpu().detach() - if data.dtype == torch.bfloat16: - saved_numpy = saved_tensor.to(torch.float32).numpy() - else: - saved_numpy = saved_tensor.numpy() summary_data.extend([tensor_max, tensor_min, tensor_mean]) - return DataInfo(data, saved_numpy, summary_data, str(data.dtype), tuple(data.shape)) + if not DumpUtil.summary_only: + saved_tensor = data.contiguous().cpu().detach() + if data.dtype == torch.bfloat16: + saved_numpy = saved_tensor.to(torch.float32).numpy() + else: + saved_numpy = saved_tensor.numpy() + return DataInfo(saved_numpy, summary_data, str(data.dtype), tuple(data.shape)) + return DataInfo([], summary_data, str(data.dtype), tuple(data.shape)) def json_dump_condition(prefix): -- Gitee From 4e1c23aa090e3adfaae18f3fe7441433ad686317 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Thu, 23 Nov 2023 22:15:37 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9ut=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py index 804f8ff541..5b8cb1d9c0 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_dump.py @@ -21,7 +21,6 @@ class TestDump(unittest.TestCase): def test_get_scalar_data_info(self): data_info = get_scalar_data_info(self.scalar) - self.assertEqual(data_info.data, self.scalar) self.assertEqual(data_info.save_data, self.scalar) self.assertEqual(data_info.summary_data, [self.scalar, self.scalar, self.scalar]) self.assertEqual(data_info.dtype, '') -- Gitee