diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index db437539afeb98050ce59aad87a1e79d98b84085..aa93a12996a72d6996abfa707ff8884a679cdfe7 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -90,7 +90,7 @@ class DataCollector: if self.config.level == "L2": return self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) - if self.data_processor.stop_run(): + if self.data_processor.is_terminated: self.handle_data(name, data_info, use_buffer=False) raise Exception("[msprobe] exit") self.handle_data(name, data_info) @@ -101,7 +101,7 @@ class DataCollector: return data_info = self.data_processor.analyze_backward(name, module, module_input_output) - if self.data_processor.stop_run(): + if self.data_processor.is_terminated: self.handle_data(name, data_info, use_buffer=False) raise Exception("[msprobe] exit") self.handle_data(name, data_info) @@ -112,7 +112,7 @@ class DataCollector: self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info, use_buffer=True): - msg = f"msProbe is collecting data on {name}. " + msg = f"msprobe is collecting data on {name}. " if data_info: msg = self.update_data(data_info, msg) logger.info(msg) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index cb7d31c6061f10ee6a3898f668c06870e1e9e93f..ecca712082e4f79de8f341314ef3f3cb2f2c781f 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -235,5 +235,6 @@ class BaseDataProcessor: file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) return dump_data_name, file_path - def stop_run(self): + @property + def is_terminated(self): return False diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c208df7d900683197fc24081b42835716ce7605f..457abb6976fb071984587a8a63a2b3cd0380c06e 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -178,11 +178,12 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self.real_overflow_dump_times += 1 self.cached_tensors_and_file_paths = {} - def stop_run(self): + @property + def is_terminated(self): if self.overflow_nums == -1: return False if self.real_overflow_dump_times >= self.overflow_nums: - logger.warning(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}") + logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数: {self.real_overflow_dump_times}") return True return False diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 007fec80964e300315c59f3d7fa4166b9d10fa70..f528e39de7dfb31f14ece05208c168b7199afa76 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -209,16 +209,18 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): for file_path, tensor in self.cached_tensors_and_file_paths.items(): torch.save(tensor, file_path) change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) - self.inc_and_check_overflow_times() + self.real_overflow_dump_times += 1 self.cached_tensors_and_file_paths = {} - def inc_and_check_overflow_times(self): - self.real_overflow_dump_times += 1 + @property + def is_terminated(self): if self.overflow_nums == -1: - return + return False if self.real_overflow_dump_times >= self.overflow_nums: - raise MsprobeException(MsprobeException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times)) - + logger.info(f"[msprobe] 超过预设溢出次数 当前溢出次数:{self.real_overflow_dump_times}") + return True + return False + def check_overflow_npu(self): if self.overflow_debug_mode_enalbe(): float_status = torch.zeros(self.bits_for_overflow).npu()