diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index db437539afeb98050ce59aad87a1e79d98b84085..aa93a12996a72d6996abfa707ff8884a679cdfe7 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -90,7 +90,7 @@ class DataCollector: if self.config.level == "L2": return self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) - if self.data_processor.stop_run(): + if self.data_processor.is_terminated: self.handle_data(name, data_info, use_buffer=False) raise Exception("[msprobe] exit") self.handle_data(name, data_info) @@ -101,7 +101,7 @@ class DataCollector: return data_info = self.data_processor.analyze_backward(name, module, module_input_output) - if self.data_processor.stop_run(): + if self.data_processor.is_terminated: self.handle_data(name, data_info, use_buffer=False) raise Exception("[msprobe] exit") self.handle_data(name, data_info) @@ -112,7 +112,7 @@ class DataCollector: self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info, use_buffer=True): - msg = f"msProbe is collecting data on {name}. " + msg = f"msprobe is collecting data on {name}. " if data_info: msg = self.update_data(data_info, msg) logger.info(msg) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 2fbc86b5656c3bcfe14b2fe9fe6bb295451e9466..80db0104bdd85a5b4986dabe82075222adb8d933 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -65,10 +65,21 @@ class BaseDataProcessor: self.current_iter = 0 self._return_forward_new_output = False self._forward_new_output = None + self.real_overflow_dump_times = 0 + self.overflow_nums = config.overflow_nums @property def data_path(self): return self.data_writer.dump_tensor_data_dir + + @property + def is_terminated(self): + if self.overflow_nums == -1: + return False + if self.real_overflow_dump_times >= self.overflow_nums: + logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}") + return True + return False @staticmethod def analyze_api_call_stack(name): @@ -234,6 +245,3 @@ class BaseDataProcessor: suffix + file_format) file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) return dump_data_name, file_path - - def stop_run(self): - return False diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c208df7d900683197fc24081b42835716ce7605f..8eeb0c2500d5b90d942dba2f41ecbfcc2a64fd94 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -154,8 +154,6 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) self.cached_tensors_and_file_paths = {} - self.real_overflow_dump_times = 0 - self.overflow_nums = config.overflow_nums def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): self.has_overflow = False @@ -178,14 +176,6 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self.real_overflow_dump_times += 1 self.cached_tensors_and_file_paths = {} - def stop_run(self): - if self.overflow_nums == -1: - return False - if self.real_overflow_dump_times >= self.overflow_nums: - logger.warning(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}") - return True - return False - def _analyze_maybe_overflow_tensor(self, tensor_json): if tensor_json['Max'] is None: return diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 007fec80964e300315c59f3d7fa4166b9d10fa70..191a33f9f7bfa1d8f503b5be704ade5a00e97132 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -183,8 +183,6 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) self.cached_tensors_and_file_paths = {} - self.real_overflow_dump_times = 0 - self.overflow_nums = config.overflow_nums self.bits_for_overflow = 8 @staticmethod @@ -209,16 +207,9 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): for file_path, tensor in self.cached_tensors_and_file_paths.items(): torch.save(tensor, file_path) change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - self.inc_and_check_overflow_times() + self.real_overflow_dump_times += 1 self.cached_tensors_and_file_paths = {} - def inc_and_check_overflow_times(self): - self.real_overflow_dump_times += 1 - if self.overflow_nums == -1: - return - if self.real_overflow_dump_times >= self.overflow_nums: - raise MsprobeException(MsprobeException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times)) - def check_overflow_npu(self): if self.overflow_debug_mode_enalbe(): float_status = torch.zeros(self.bits_for_overflow).npu()