From 32c2e5ddd4658be5301a942c2e01b5dd82d501d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=A4=A9?= <1063185601@qq.com> Date: Thu, 30 May 2024 12:38:21 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=E6=94=AF=E6=8C=81=E6=BA=A2=E5=87=BA?= =?UTF-8?q?=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 李天 <1063185601@qq.com> --- debug/accuracy_tools/atat/pytorch/functional/data_processor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py index 79433ffff5..68b84f6b29 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py @@ -278,7 +278,7 @@ class FullTensorDataProcessor(DataProcessor): return single_arg -class OverflowTensorDataProcessor(FullTensorDataProcessor): +class OverflowTensorDataProcessor(DataProcessor): __slots__ = ["cached_tensors_and_file_paths"] def __init__(self, config, data_writer): @@ -293,6 +293,7 @@ class OverflowTensorDataProcessor(FullTensorDataProcessor): self.cached_tensors_and_file_paths.update({file_path: tensor}) single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) + return single_arg def analyze_forward(self, name, module_input_output: ModuleForwardInputsOutputs): -- Gitee From 4b8e7d4aa2d4c79126295aa78d26b1b3421bd10d Mon Sep 17 00:00:00 2001 From: litian_drinksnow <1063185601@qq.com> Date: Fri, 31 May 2024 11:12:23 +0800 Subject: [PATCH 2/2] fix bug in overflowprocessor --- .../atat/pytorch/functional/data_processor.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py index 68b84f6b29..f41a96768d 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py @@ -299,21 +299,23 @@ class OverflowTensorDataProcessor(DataProcessor): module_input_output: ModuleForwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_forward(name, module_input_output) + self.maybe_save_overflow_data() if self.has_overflow: - self.save_overflow_data() return api_info_struct - return None + else: + return None def analyze_backward(self, name, module_input_output: ModuleBackwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_backward(name, module_input_output) + self.maybe_save_overflow_data() if self.has_overflow: - self.save_overflow_data() return api_info_struct return None - def save_overflow_data(self): - for file_path, tensor in self.cached_tensors_and_file_paths.items(): - torch.save(tensor, file_path) + def maybe_save_overflow_data(self): + if self.has_overflow: + for file_path, tensor in self.cached_tensors_and_file_paths.items(): + torch.save(tensor, file_path) self.cached_tensors_and_file_paths = {} -- Gitee