diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c6ab0293cf3edafab06a5bf03e1a429d86e92720..7b71e97c78d2f7b01ef1d7e6ca0c18859b9fa663 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -61,11 +61,10 @@ class MindsporeDataProcessor(BaseDataProcessor): def get_stat_info_sync(data): tensor_stat = TensorStatInfo() if data.dtype == ms.bool_: - data_np = data.asnumpy() - tensor_stat.max = np.max(data_np).item() - tensor_stat.min = np.min(data_np).item() + tensor_stat.max = mint.any(data) + tensor_stat.min = mint.all(data) elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data elif data.dtype == ms.complex64 or data.dtype == ms.complex128: data_abs = np.abs(data.asnumpy()) tensor_stat.max = np.max(data_abs).item() @@ -76,10 +75,10 @@ class MindsporeDataProcessor(BaseDataProcessor): if not ops.is_floating_point(data) or data.dtype == ms.float64: data = data.to(ms.float32) get_norm_value = mint.norm if hasattr(mint, "norm") else ops.norm - tensor_stat.max = mint.max(data).item() - tensor_stat.min = mint.min(data).item() - tensor_stat.mean = mint.mean(data).item() - tensor_stat.norm = get_norm_value(data).item() + tensor_stat.max = mint.max(data) + tensor_stat.min = mint.min(data) + tensor_stat.mean = mint.mean(data) + tensor_stat.norm = get_norm_value(data) return tensor_stat @staticmethod @@ -147,10 +146,16 @@ class MindsporeDataProcessor(BaseDataProcessor): } if tensor_stat.stack_tensor_stat is None: - tensor_json.update({'Max': self.transfer_type(tensor_stat.max)}) - tensor_json.update({'Min': self.transfer_type(tensor_stat.min)}) - tensor_json.update({'Mean': self.transfer_type(tensor_stat.mean)}) - tensor_json.update({'Norm': self.transfer_type(tensor_stat.norm)}) + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({"tensor_stat_index": placeholder_index}) else: tensor_json.update({'tensor_stat': tensor_stat.stack_tensor_stat}) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: @@ -231,7 +236,7 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data() return api_info_struct if self.has_overflow else None - + def analyze_params(self, name, param_name, grad): self.has_overflow = False api_info_struct = super().analyze_params(name, param_name, grad) @@ -249,11 +254,13 @@ class OverflowCheckDataProcessor(MindsporeDataProcessor): self.cached_tensors_and_file_paths = {} def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None: + max_tensor = self.data_writer.get_buffer_values_max(tensor_json['tensor_stat_index']) + min_tensor = self.data_writer.get_buffer_values_min(tensor_json['tensor_stat_index']) + if max_tensor is None or min_tensor is None: return - if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + if mint.isinf(max_tensor) or mint.isnan(max_tensor): self.has_overflow = True - if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + if mint.isinf(min_tensor) or mint.isnan(min_tensor): self.has_overflow = True def _analyze_tensor(self, tensor, suffix): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 4c56419dcb17b78918e4d46a3aaa50b12ef32777..6b146203c4f5ba11fa3bfffd9b22b1f1ab4e2a29 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -126,17 +126,17 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_stat.min = np.min(data_abs).item() tensor_stat.mean = np.mean(data_abs).item() elif data.dtype == torch.bool: - tensor_stat.max = torch.any(data).item() - tensor_stat.min = torch.all(data).item() + tensor_stat.max = torch.any(data) + tensor_stat.min = torch.all(data) elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data else: if not data.is_floating_point() or data.dtype == torch.float64: data = data.float() - tensor_stat.max = torch.max(data).item() - tensor_stat.min = torch.min(data).item() - tensor_stat.mean = torch.mean(data).item() - tensor_stat.norm = torch.norm(data).item() + tensor_stat.max = torch.max(data) + tensor_stat.min = torch.min(data) + tensor_stat.mean = torch.mean(data) + tensor_stat.norm = torch.norm(data) return tensor_stat @staticmethod @@ -271,18 +271,16 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_json.update({'dtype': str(tensor.dtype)}) tensor_json.update({"shape": tensor.shape}) if tensor_stat.stack_tensor_stat is None: - tensor_json.update({"Max": tensor_stat.max}) - tensor_json.update({"Min": tensor_stat.min}) - tensor_json.update({"Mean": tensor_stat.mean}) - tensor_json.update({"Norm": tensor_stat.norm}) + stat_values = [ + tensor_stat.max, + tensor_stat.min, + tensor_stat.mean, + tensor_stat.norm + ] + placeholder_index = self.data_writer.append_stat_to_buffer(stat_values) + + tensor_json.update({"tensor_stat_index": placeholder_index}) tensor_json.update({"requires_grad": tensor.requires_grad}) - if tensor_stat.max is not None: - if np.isinf(tensor_stat.max) or np.isnan(tensor_stat.max): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") - if tensor_stat.min is not None: - if np.isinf(tensor_stat.min) or np.isnan(tensor_stat.min): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "min") - else: tensor_json.update({"requires_grad": tensor.requires_grad}) tensor_json.update({"tensor_stat": tensor_stat.stack_tensor_stat}) @@ -411,10 +409,16 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): raise RuntimeError(f"overflow check failed") from e def _analyze_maybe_overflow_tensor(self, tensor_json): - if tensor_json['Max'] is None or tensor_json['Min'] is None: + max_tensor = self.data_writer.get_buffer_values_max(tensor_json['tensor_stat_index']) + min_tensor = self.data_writer.get_buffer_values_min(tensor_json['tensor_stat_index']) + + if max_tensor is None or min_tensor is None : return - self.has_overflow = np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']) or \ - np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']) + if torch.isinf(max_tensor) or torch.isnan(max_tensor): + self.has_overflow = True + + if torch.isinf(min_tensor) or torch.isnan(min_tensor): + self.has_overflow = True def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index b1e26d16f9741765c1c9600a64efb112aa0f42d7..364237fd3c974816849f23330d253baba9e459d1 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -38,6 +38,7 @@ class DataWriter: self.cache_stack = {} self.cache_construct = {} self.cache_debug = {} + self.stat_stack_list = [] @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -126,7 +127,43 @@ class DataWriter: def write_debug_info_json(self, file_path): save_json(file_path, self.cache_debug, indent=1) + def append_stat_to_buffer(self, stat_vector): + """ + 直接使用 Python list 存储 stat_vector, + 将 stat_vector 存入 self.stat_stack_list 的方式 + """ + self.stat_stack_list.append(stat_vector) + return len(self.stat_stack_list) - 1 + + def get_buffer_values_max(self, index): + return self.stat_stack_list[index][0] + + def get_buffer_values_min(self, index): + return self.stat_stack_list[index][1] + + def flush_stat_stack(self): + """ + 在 flush 阶段,将所有存储的统计值从设备搬到 CPU, + 这里返回一个列表,每个元素是 [Max, Min, Mean, Norm] 的数值列表 + """ + if not self.stat_stack_list: + return [] + result = [ + [ + x.item() if hasattr(x, "item") else x + for x in stat_values + ] + for stat_values in self.stat_stack_list + ] + self.stat_stack_list = [] + return result + def write_json(self): + # 在写 JSON 前,统一获取统计值 + stat_result = self.flush_stat_stack() + # 遍历 cache_data,将占位符替换为最终统计值 + if stat_result: + self._replace_stat_placeholders(self.cache_data, stat_result) if self.cache_data: self.write_data_json(self.dump_file_path) if self.cache_stack: @@ -136,6 +173,43 @@ class DataWriter: if self.cache_debug: self.write_debug_info_json(self.debug_file_path) + def _replace_stat_placeholders(self, data, stat_result): + if isinstance(data, dict): + keys = list(data.keys()) # 获取当前所有键 + for key in keys: # 避免遍历时修改字典 + value = data[key] + if key == "tensor_stat_index" and isinstance(value, int): + idx = value + stat_values = stat_result[idx] if idx < len(stat_result) else [None] * 4 + + # 构建新字段并删除旧键 + new_entries = { + "type": data["type"], + "dtype": data["dtype"], + "shape": data["shape"], + "Max": stat_values[0], + "Min": stat_values[1], + "Mean": stat_values[2], + "Norm": stat_values[3] + } + del data[key] + + # 重构字典顺序 + updated_dict = {} + # 先插入统计字段 + updated_dict.update(new_entries) + # 保留原字典其他字段(排除已删除的tensor_stat_index) + for k in data: + if k not in new_entries: + updated_dict[k] = data[k] + data.clear() + data.update(updated_dict) + else: + self._replace_stat_placeholders(value, stat_result) + elif isinstance(data, list): + for item in data: + self._replace_stat_placeholders(item, stat_result) + def fill_stack_tensor_data(self): self.process_stat_data_recursive(self.cache_data)