From 8377eb1dd69f9727d0271d03f3a5209d85e43b44 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Thu, 13 Mar 2025 14:35:25 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E5=86=92=E7=83=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../data_processor/mindspore_processor.py | 40 ++++++---- .../data_processor/pytorch_processor.py | 43 ++++++---- .../msprobe/core/data_dump/json_writer.py | 78 +++++++++++++++++++ 3 files changed, 129 insertions(+), 32 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c6ab0293cf3..125f68643ab 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -43,6 +43,7 @@ class MindsporeDataProcessor(BaseDataProcessor): self.mindspore_object_key = { "dtype": self.analyze_dtype_in_kwargs } + self.stat_stack_list = [] self._async_dump_cache = {} self.api_register = get_api_register() @@ -61,11 +62,10 @@ class MindsporeDataProcessor(BaseDataProcessor): def get_stat_info_sync(data): tensor_stat = TensorStatInfo() if data.dtype == ms.bool_: - data_np = data.asnumpy() - tensor_stat.max = np.max(data_np).item() - tensor_stat.min = np.min(data_np).item() + tensor_stat.max = mint.any(data) + tensor_stat.min = mint.all(data) elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data elif data.dtype == ms.complex64 or data.dtype == ms.complex128: data_abs = np.abs(data.asnumpy()) tensor_stat.max = np.max(data_abs).item() @@ -76,10 +76,10 @@ class MindsporeDataProcessor(BaseDataProcessor): if not ops.is_floating_point(data) or data.dtype == ms.float64: data = data.to(ms.float32) get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm - tensor_stat.max = mint.max(data).item() - tensor_stat.min = mint.min(data).item() - tensor_stat.mean = mint.mean(data).item() - tensor_stat.norm = get_norm_value(data).item() + tensor_stat.max = mint.max(data) + tensor_stat.min = mint.min(data) + tensor_stat.mean = mint.mean(data) + tensor_stat.norm = get_norm_value(data) return tensor_stat @staticmethod @@ -147,10 +147,16 @@ class MindsporeDataProcessor(BaseDataProcessor): } if tensor_stat.stack_tensor_stat is None: - tensor_json.update({'Max': self.transfer_type(tensor_stat.max)}) - tensor_json.update({'Min': self.transfer_type(tensor_stat.min)}) - tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)}) - tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)}) + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({"tensor_stat_index": placeholder_index}) else: tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat}) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: @@ -231,7 +237,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data() return api_info_struct if self.has_overflow else None - + def analyze_params(self, name, param_name, grad): self.has_overflow = False api_info_struct = super().analyze_params(name, param_name, grad) @@ -249,11 +255,13 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self.cached_tensors_and_file_paths = {} def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None: + max_tensor = self.data_writer.get_buffer_values_max(tensor_json['tensor_stat_index']) + min_tensor = self.data_writer.get_buffer_values_min(tensor_json['tensor_stat_index']) + if max_tensor is None or min_tensor is None: return - if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + if mint.isinf(max_tensor) or mint.isnan(max_tensor): self.has_overflow = True - if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + if mint.isinf(min_tensor) or mint.isnan(min_tensor): self.has_overflow = True def _analyze_tensor(self, tensor, suffix): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 4c56419dcb1..b4e6551ca34 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -126,17 +126,17 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_stat.min = np.min(data_abs).item() tensor_stat.mean = np.mean(data_abs).item() elif data.dtype == torch.bool: - tensor_stat.max = torch.any(data).item() - tensor_stat.min = torch.all(data).item() + tensor_stat.max = torch.any(data) + tensor_stat.min = torch.all(data) elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data else: if not data.is_floating_point() or data.dtype == torch.float64: data = data.float() - tensor_stat.max = torch.max(data).item() - tensor_stat.min = torch.min(data).item() - tensor_stat.mean = torch.mean(data).item() - tensor_stat.norm = torch.norm(data).item() + tensor_stat.max = torch.max(data) + tensor_stat.min = torch.min(data) + tensor_stat.mean = torch.mean(data) + tensor_stat.norm = torch.norm(data) return tensor_stat @staticmethod @@ -271,16 +271,21 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_json.update({'dtype': str(tensor.dtype)}) tensor_json.update({"shape": tensor.shape}) if tensor_stat.stack_tensor_stat is None: - tensor_json.update({"Max": tensor_stat.max}) - tensor_json.update({"Min": tensor_stat.min}) - tensor_json.update({"Mean": tensor_stat.mean}) - tensor_json.update({"Norm": tensor_stat.norm}) + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({"tensor_stat_index": placeholder_index}) tensor_json.update({"requires_grad": tensor.requires_grad}) if tensor_stat.max is not None: - if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): + if torch.isinf(tensor_stat.max) or torch.isnan(tensor_stat.max): tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") if tensor_stat.min is not None: - if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): + if torch.isinf(tensor_stat.min) or torch.isnan(tensor_stat.min): tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") else: @@ -411,10 +416,16 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): raise RuntimeError(f"overflow check failed") from e def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None or tensor_json['Min'] is None: + max_tensor = self.data_writer.get_buffer_values_max(tensor_json['tensor_stat_index']) + min_tensor = self.data_writer.get_buffer_values_min(tensor_json['tensor_stat_index']) + + if max_tensor is None or min_tensor is None : return - self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \ - np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']) + if torch.isinf(max_tensor) or torch.isnan(max_tensor): + self.has_overflow = True + + if torch.isinf(min_tensor) or torch.isnan(min_tensor): + self.has_overflow = True def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index b1e26d16f97..33c95d686b4 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -38,6 +38,7 @@ class DataWriter: self.cache_stack = {} self.cache_construct = {} self.cache_debug = {} + self.stat_stack_list = [] @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -126,7 +127,47 @@ class DataWriter: def write_debug_info_json(self, file_path): save_json(file_path, self.cache_debug, indent=1) + def append_stat_to_buffer(self, stat_vector): + """ + 直接使用 Python list 存储 stat_vector, + 将 stat_vector 存入 self.stat_stack_list 的方式 + """ + self.stat_stack_list.append(stat_vector) + return len(self.stat_stack_list) - 1 + + def get_buffer_values_max(self, index): + return self.stat_stack_list[index][0] + + def get_buffer_values_min(self, index): + print(f"self.stat_stack_list[index]:{self.stat_stack_list[index]}") + return self.stat_stack_list[index][1] + + def flush_stat_stack(self): + """ + 在 flush 阶段,将所有存储的统计值从设备搬到 CPU, + 这里返回一个列表,每个元素是 [Max, Min, Mean, Norm] 的数值列表 + """ + if not self.stat_stack_list: + return [] + result = [ + [ + x.cpu().detach().numpy().tolist() if hasattr(x, "cpu") else + x.asnumpy().tolist() if hasattr(x, "asnumpy") else x + for x in stat_values + ] + for stat_values in self.stat_stack_list + ] + self.stat_stack_list = [] + return result + def write_json(self): + # 在写 JSON 前,统一获取统计值 + stat_result = self.flush_stat_stack() + print(f"before:{self.cache_data}") + # 遍历 cache_data,将占位符替换为最终统计值 + if stat_result: + self._replace_stat_placeholders(self.cache_data, stat_result) + print(f"after:{self.cache_data}") if self.cache_data: self.write_data_json(self.dump_file_path) if self.cache_stack: @@ -136,6 +177,43 @@ class DataWriter: if self.cache_debug: self.write_debug_info_json(self.debug_file_path) + def _replace_stat_placeholders(self, data, stat_result): + if isinstance(data, dict): + keys = list(data.keys()) # 获取当前所有键 + for key in keys: # 避免遍历时修改字典 + value = data[key] + if key == "tensor_stat_index" and isinstance(value, int): + idx = value + stat_values = stat_result[idx] if idx < len(stat_result) else [None] * 4 + + # 构建新字段并删除旧键 + new_entries = { + "type": data["type"], + "dtype": data["dtype"], + "shape": data["shape"], + "Max": stat_values[0], + "Min": stat_values[1], + "Mean": stat_values[2], + "Norm": stat_values[3] + } + del data[key] + + # 重构字典顺序 + updated_dict = {} + # 先插入统计字段 + updated_dict.update(new_entries) + # 保留原字典其他字段(排除已删除的tensor_stat_index) + for k in data: + if k not in new_entries: + updated_dict[k] = data[k] + data.clear() + data.update(updated_dict) + else: + self._replace_stat_placeholders(value, stat_result) + elif isinstance(data, list): + for item in data: + self._replace_stat_placeholders(item, stat_result) + def fill_stack_tensor_data(self): self.process_stat_data_recursive(self.cache_data) -- Gitee From e1c00fb30760f29271a0527c8a5aefb4318a8f85 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Thu, 13 Mar 2025 14:41:34 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E5=86=92=E7=83=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py index a9a543a8fac..48fac8d1317 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -102,7 +102,7 @@ class JitDump(_MindsporeFunctionExecutor): return False return True - def grad(self, obj, grad, weights, grad_position, *args, **kwargs): + def grad(self, obj, grad, weights, grad_position, False, *args, **kwargs): if JitDump.jit_dump_switch and JitDump.jit_enable: _api_register.restore_all_api() output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) -- Gitee From 90a10b47f0bb2ba920e260dfa65d4b9824dfc191 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Thu, 13 Mar 2025 14:45:18 +0800 Subject: [PATCH 3/8] Update jit_dump.py --- debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py index 48fac8d1317..715a06a1e96 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -102,10 +102,10 @@ class JitDump(_MindsporeFunctionExecutor): return False return True - def grad(self, obj, grad, weights, grad_position, False, *args, **kwargs): + def grad(self, obj, grad, weights, grad_position, *args, **kwargs): if JitDump.jit_dump_switch and JitDump.jit_enable: _api_register.restore_all_api() - output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) + output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values())) if JitDump.jit_dump_switch and JitDump.jit_enable: dump_jit(obj, args, None, False) _api_register.register_all_api() -- Gitee From 42ae4f5d87f34d6954a1dba78c1ae43029ee1e5d Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Thu, 13 Mar 2025 16:32:15 +0800 Subject: [PATCH 4/8] Update jit_dump.py --- debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py index 715a06a1e96..a9a543a8fac 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -105,7 +105,7 @@ class JitDump(_MindsporeFunctionExecutor): def grad(self, obj, grad, weights, grad_position, *args, **kwargs): if JitDump.jit_dump_switch and JitDump.jit_enable: _api_register.restore_all_api() - output = self._executor.grad(grad, obj, weights, grad_position, False, *args, *(kwargs.values())) + output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) if JitDump.jit_dump_switch and JitDump.jit_enable: dump_jit(obj, args, None, False) _api_register.register_all_api() -- Gitee From 61143517e1a4da742c5a61e30d6433e260662f0d Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 17 Mar 2025 17:28:29 +0800 Subject: [PATCH 5/8] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86?= =?UTF-8?q?=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/data_dump/data_processor/pytorch_processor.py | 7 ------- debug/accuracy_tools/msprobe/core/data_dump/json_writer.py | 3 +-- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index b4e6551ca34..6b146203c4f 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -281,13 +281,6 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_json.update({"tensor_stat_index": placeholder_index}) tensor_json.update({"requires_grad": tensor.requires_grad}) - if tensor_stat.max is not None: - if torch.isinf(tensor_stat.max) or torch.isnan(tensor_stat.max): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") - if tensor_stat.min is not None: - if torch.isinf(tensor_stat.min) or torch.isnan(tensor_stat.min): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") - else: tensor_json.update({"requires_grad": tensor.requires_grad}) tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat}) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 33c95d686b4..c4af30fe933 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -151,8 +151,7 @@ class DataWriter: return [] result = [ [ - x.cpu().detach().numpy().tolist() if hasattr(x, "cpu") else - x.asnumpy().tolist() if hasattr(x, "asnumpy") else x + x.item() if hasattr(x, "item") else x for x in stat_values ] for stat_values in self.stat_stack_list -- Gitee From d676820b179c743ce02400fa3b3c58ef57489b3b Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 17 Mar 2025 19:13:39 +0800 Subject: [PATCH 6/8] Update json_writer.py --- debug/accuracy_tools/msprobe/core/data_dump/json_writer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index c4af30fe933..96d79de2787 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -162,11 +162,9 @@ class DataWriter: def write_json(self): # 在写 JSON 前,统一获取统计值 stat_result = self.flush_stat_stack() - print(f"before:{self.cache_data}") # 遍历 cache_data,将占位符替换为最终统计值 if stat_result: self._replace_stat_placeholders(self.cache_data, stat_result) - print(f"after:{self.cache_data}") if self.cache_data: self.write_data_json(self.dump_file_path) if self.cache_stack: -- Gitee From 714d9d5c7823cb92d2201809554588680ffd8403 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 18 Mar 2025 15:28:01 +0800 Subject: [PATCH 7/8] Update mindspore_processor.py --- .../msprobe/core/data_dump/data_processor/mindspore_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 125f68643ab..7b71e97c78d 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -43,7 +43,6 @@ class MindsporeDataProcessor(BaseDataProcessor): self.mindspore_object_key = { "dtype": self.analyze_dtype_in_kwargs } - self.stat_stack_list = [] self._async_dump_cache = {} self.api_register = get_api_register() -- Gitee From 715447d24685741aba12bf1087381bbb4e0fe5ed Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 18 Mar 2025 15:44:12 +0800 Subject: [PATCH 8/8] Update json_writer.py --- debug/accuracy_tools/msprobe/core/data_dump/json_writer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index 96d79de2787..364237fd3c9 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -139,7 +139,6 @@ class DataWriter: return self.stat_stack_list[index][0] def get_buffer_values_min(self, index): - print(f"self.stat_stack_list[index]:{self.stat_stack_list[index]}") return self.stat_stack_list[index][1] def flush_stat_stack(self): -- Gitee