diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py index 79433ffff5ebabd99fde8fc5f4b24511efd59c1c..f41a96768de20b83868311b18a243f2c119bb650 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py +++ b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py @@ -278,7 +278,7 @@ class FullTensorDataProcessor(DataProcessor): return single_arg -class OverflowTensorDataProcessor(FullTensorDataProcessor): +class OverflowTensorDataProcessor(DataProcessor): __slots__ = ["cached_tensors_and_file_paths"] def __init__(self, config, data_writer): @@ -293,26 +293,29 @@ class OverflowTensorDataProcessor(FullTensorDataProcessor): self.cached_tensors_and_file_paths.update({file_path: tensor}) single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) + return single_arg def analyze_forward(self, name, module_input_output: ModuleForwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_forward(name, module_input_output) + self.maybe_save_overflow_data() if self.has_overflow: - self.save_overflow_data() return api_info_struct - return None + else: + return None def analyze_backward(self, name, module_input_output: ModuleBackwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_backward(name, module_input_output) + self.maybe_save_overflow_data() if self.has_overflow: - self.save_overflow_data() return api_info_struct return None - def save_overflow_data(self): - for file_path, tensor in self.cached_tensors_and_file_paths.items(): - torch.save(tensor, file_path) + def maybe_save_overflow_data(self): + if self.has_overflow: + for file_path, tensor in self.cached_tensors_and_file_paths.items(): + torch.save(tensor, file_path) self.cached_tensors_and_file_paths = {}