From c43e495effbbf3ed4036973831d264c848d4042a Mon Sep 17 00:00:00 2001 From: curry3 <485078529@qq.com> Date: Sat, 6 Sep 2025 15:36:06 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E4=BF=AE=E5=A4=8Dfsdp2?= =?UTF-8?q?=E5=9C=BA=E6=99=AFdump=E6=98=BE=E5=AD=98=E8=86=A8=E8=83=80?= =?UTF-8?q?=E4=B8=8E=E6=8A=A5=E9=94=99=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../data_processor/pytorch_processor.py | 63 +++++++++++++------ 1 file changed, 43 insertions(+), 20 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 9bef9ad2d8..eb277502a1 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import ctypes import os import zlib -import ctypes from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict from typing import List -from concurrent.futures import ThreadPoolExecutor import numpy as np import torch @@ -29,7 +29,6 @@ from torch.distributed.distributed_c10d import _get_default_group from msprobe.core.common.const import Const from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.core.common.log import logger from msprobe.core.common.utils import convert_tuple, is_int from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ @@ -49,6 +48,16 @@ class TensorHandler: self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor") self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor") + @staticmethod + def free_tensor(tensor, tensor_name): + try: + tensor.untyped_storage().resize_(0) + except Exception as e: + logger.warning( + f"Failed to clear {tensor_name} tensor cache, which may lead to increased device memory usage. " + f"The reason: {str(e)}." 
+ ) + def is_dtensor(self, tensor): return self.has_dtensor and isinstance(tensor, torch.distributed.tensor.DTensor) @@ -94,6 +103,24 @@ class TensorHandler: dtensor_info.update({"placements": placements}) return dtensor_info + def save_tensor(self, tensor, file_path): + common_tensor = self.convert_common_tensor(tensor) + if self.is_empty_data(common_tensor): + logger.debug(f"Saving fake tensor or meta tensor is not supported, the current tensor is {file_path}.") + return + if common_tensor.untyped_storage().data_ptr() == 0: + logger.debug(f"Saving null-pointer tensor is not supported, the current tensor is {file_path}.") + return + + try: + saved_tensor = common_tensor.contiguous().detach() + save_pt(saved_tensor, file_path) + except Exception as e: + logger.debug(f"Failed to save {file_path} tensor, trying clone before save. The failed reason: {str(e)}.") + saved_tensor = common_tensor.clone().contiguous().detach() + save_pt(saved_tensor, file_path) + self.free_tensor(saved_tensor, file_path) + class PytorchDataProcessor(BaseDataProcessor): pytorch_special_type = ( @@ -288,7 +315,7 @@ class PytorchDataProcessor(BaseDataProcessor): def dump_async_data(self): for file_path, tensor in self._async_dump_cache.items(): - save_pt(tensor.contiguous(), file_path) + self.tensor_handler.save_tensor(tensor, file_path) self._async_dump_cache.clear() def analyze_single_element(self, element, suffix_stack): @@ -385,24 +412,23 @@ class PytorchDataProcessor(BaseDataProcessor): def _analyze_and_save_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix) - if self.tensor_handler.is_empty_data(tensor) or tensor.untyped_storage().data_ptr() == 0: - logger.debug( - "Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0, " - f"the current api/module name is {self.current_api_or_module_name}." 
- ) - return single_arg - single_arg.update({"data_name": dump_data_name}) if self.config.async_dump: - self._async_dump_cache[file_path] = tensor.clone().detach() + common_tensor = self.tensor_handler.convert_common_tensor(tensor) + if self.tensor_handler.is_empty_data(common_tensor) or common_tensor.untyped_storage().data_ptr() == 0: + logger.warning( + "Collecting real data of fake tensor, meta tensor or null-pointer tensor is not supported, " + f"the current tensor is {file_path}." + ) + return single_arg + self._async_dump_cache[file_path] = common_tensor.clone().detach() else: - saved_tensor = tensor.clone().contiguous().detach() - save_pt(saved_tensor, file_path) + self.tensor_handler.save_tensor(tensor, file_path) return single_arg def _analyze_and_save_ndarray(self, ndarray, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - save_pt(torch.tensor(ndarray), file_path) + self.tensor_handler.save_tensor(torch.tensor(ndarray), file_path) ndarray_json = PytorchDataProcessor._analyze_ndarray(ndarray, suffix) ndarray_json.update({"data_name": dump_data_name}) return ndarray_json @@ -493,7 +519,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): self._analyze_maybe_overflow_flag() if self.has_overflow: for file_path, tensor in self.cached_tensors_and_file_paths.items(): - save_pt(tensor.clone().contiguous().detach(), file_path) + self.tensor_handler.save_tensor(tensor, file_path) self.real_overflow_nums += 1 if self.overflow_nums != -1 and self.real_overflow_nums >= self.overflow_nums: logger.info(f"[{Const.TOOL_NAME}] Reached the preset overflow times, " @@ -538,10 +564,7 @@ class OverflowCheckDataProcessor(PytorchDataProcessor): def _analyze_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) - if not path_len_exceeds_limit(file_path): - self.cached_tensors_and_file_paths.update({file_path: tensor}) - else: - logger.warning(f'The file path {file_path} length exceeds limit.') + 
self.cached_tensors_and_file_paths.update({file_path: tensor}) single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) if not self.has_overflow and self.support_inf_nan: -- Gitee