diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 560d939b345e169a84dd6a06f58749115e93333b..d446be479f51d87bd4676ffdc4d5f4d3e454de2d 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -266,6 +266,10 @@ class Const: TENSOR_STAT_LEN = 2 + TENSOR_TYPE = "torch.Tensor" + DTENSOR_TYPE = "torch.distributed.tensor.DTensor" + FAKE_TENSOR_TYPE = "torch._subclasses.fake_tensor.FakeTensor" + SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml" PT_API_TYPE_FUNCTIONAL = "functional" diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 3f57d095a666b5122ac9de264e8ade941bb36c5e..cc3196e41ea674443f42ab55abb5bfa1310a76dd 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -559,7 +559,7 @@ def check_token_range(token_range): raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) start, end = token_range - if not isinstance(start, int) or not isinstance(end, int): + if not is_int(start) or not is_int(end): logger.error("Start and end in token_range must be integer.") raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR) if start > end: diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 398419fea058e81280d10e458cb67cb512f11051..e3d5b5bdd23e320beb8c0ce688f7fb24b4fcd305 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -14,6 +14,7 @@ # limitations under the License. import zlib +from collections.abc import Iterable from dataclasses import asdict from typing import List @@ -23,11 +24,11 @@ from torch import distributed as dist from torch.distributed.distributed_c10d import _get_default_group from msprobe.core.common.const import Const +from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.core.common.log import logger -from msprobe.core.common.utils import convert_tuple -from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.common.utils import convert_tuple, is_int from msprobe.core.data_dump.data_processor.base import BaseDataProcessor, ModuleBackwardInputsOutputs, \ ModuleForwardInputsOutputs, TensorStatInfo from msprobe.pytorch.common.utils import save_pt @@ -40,6 +41,57 @@ except ImportError: is_gpu = True +class TensorHandler: + def __init__(self): + self.has_dtensor = hasattr(dist, "tensor") and hasattr(dist.tensor, "DTensor") + self.has_fake_tensor = hasattr(torch, "_subclasses") and hasattr(torch._subclasses, "fake_tensor") + + def is_dtensor(self, tensor): + return self.has_dtensor and isinstance(tensor, torch.distributed.tensor.DTensor) + + def is_fake_tensor(self, tensor): + return self.has_fake_tensor and isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor) + + def is_empty_data(self, tensor): + return tensor.is_meta or self.is_fake_tensor(tensor) + + def convert_common_tensor(self, tensor): + if self.is_dtensor(tensor): + return tensor.to_local() + if self.is_fake_tensor(tensor): + logger.debug("FakeTensor cannot be converted to torch.Tensor type.") + return tensor + return tensor + + def get_tensor_type(self, tensor): + if self.is_dtensor(tensor): + return Const.DTENSOR_TYPE + if self.is_fake_tensor(tensor): + return Const.FAKE_TENSOR_TYPE + return Const.TENSOR_TYPE + + def get_dtensor_info(self, tensor): + dtensor_info = {} + if not self.is_dtensor(tensor): + return dtensor_info + if hasattr(tensor, "device_mesh") and tensor.device_mesh: + dtensor_info.update({"device_mesh": tensor.device_mesh.mesh.tolist()}) + + placements = [] + if hasattr(tensor, "placements") and isinstance(tensor.placements, Iterable): + for placement in tensor.placements: + if placement.is_shard() and is_int(placement.dim): + placements.append({"Shard": {"dim": placement.dim}}) + continue + if placement.is_replicate(): + placements.append({"Replicate": {}}) + continue + if placement.is_partial() and isinstance(placement.reduce_op, str): + placements.append({"Partial": {"reduce_op": placement.reduce_op}}) + dtensor_info.update({"placements": placements}) + return dtensor_info + + class PytorchDataProcessor(BaseDataProcessor): pytorch_special_type = ( torch.device, @@ -65,6 +117,7 @@ class PytorchDataProcessor(BaseDataProcessor): "dtype": self.analyze_dtype_in_kwargs } self._async_dump_cache = {} + self.tensor_handler = TensorHandler() @staticmethod def get_md5_for_tensor(x): @@ -94,54 +147,6 @@ class PytorchDataProcessor(BaseDataProcessor): def analyze_dtype_in_kwargs(element): return {"type": "torch.dtype", "value": str(element)} - @staticmethod - def get_stat_info(data, async_dump=False, precision=Const.DUMP_PRECISION_HIGH): - tensor_stat = TensorStatInfo() - if data.is_meta: - return tensor_stat - data_clone = data.detach() - if not data_clone.numel() or not data_clone.data_ptr(): - return tensor_stat - if torch.is_complex(data): - if async_dump: - logger.warning("Async dump do not support complex data!") - return tensor_stat - data_np = data.cpu().numpy() - data_abs = np.abs(data_np) - tensor_stat.max = np.max(data_abs).item() - tensor_stat.min = np.min(data_abs).item() - tensor_stat.mean = np.mean(data_abs).item() - elif data.dtype == torch.bool: - tensor_stat.max = torch.any(data) - tensor_stat.min = torch.all(data) - elif not data.shape: - tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.clone() - else: - if precision == Const.DUMP_PRECISION_HIGH or data.dtype == torch.float64 or not data.is_floating_point(): - data = data.float() - tensor_stat.max = torch.max(data) - tensor_stat.min = torch.min(data) - tensor_stat.mean = torch.mean(data) - tensor_stat.norm = torch.norm(data) - return tensor_stat - - @staticmethod - def handle_tensor_extremum_nan_inf(tensor, operator): - data_clone = tensor.detach() - data_nan = torch.isnan(data_clone) - if int(torch.sum(data_nan)) == data_clone.numel(): - return float('nan') - - finite_mask = torch.isfinite(data_clone) - if int(torch.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch.max(finite_values).item() if operator == 'max' else \ - torch.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch.max(data_no_nan).item() if operator == 'max' else \ - torch.min(data_no_nan).item() - @staticmethod def process_group_hash(arg): group_ranks = dist.get_process_group_ranks(arg) @@ -188,6 +193,36 @@ class PytorchDataProcessor(BaseDataProcessor): def get_special_types(cls): return super().get_special_types() + cls.pytorch_special_type + def get_stat_info(self, data, async_dump=False, precision=Const.DUMP_PRECISION_HIGH): + tensor_stat = TensorStatInfo() + if self.tensor_handler.is_empty_data(data): + return tensor_stat + data_clone = data.detach() + if not data_clone.numel() or not data_clone.data_ptr(): + return tensor_stat + if torch.is_complex(data): + if async_dump: + logger.warning("Async dump do not support complex data!") + return tensor_stat + data_np = data.cpu().numpy() + data_abs = np.abs(data_np) + tensor_stat.max = np.max(data_abs).item() + tensor_stat.min = np.min(data_abs).item() + tensor_stat.mean = np.mean(data_abs).item() + elif data.dtype == torch.bool: + tensor_stat.max = torch.any(data) + tensor_stat.min = torch.all(data) + elif not data.shape: + tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.clone() + else: + if precision == Const.DUMP_PRECISION_HIGH or data.dtype == torch.float64 or not data.is_floating_point(): + data = data.float() + tensor_stat.max = torch.max(data) + tensor_stat.min = torch.min(data) + tensor_stat.mean = torch.mean(data) + tensor_stat.norm = torch.norm(data) + return tensor_stat + def dump_async_data(self): for file_path, tensor in self._async_dump_cache.items(): save_pt(tensor.contiguous(), file_path) @@ -230,9 +265,10 @@ class PytorchDataProcessor(BaseDataProcessor): return p2pop_info def _analyze_tensor(self, tensor, suffix): - tensor_stat = self.get_stat_info(tensor, self.config.async_dump, self.config.precision) + common_tensor = self.tensor_handler.convert_common_tensor(tensor) + tensor_stat = self.get_stat_info(common_tensor, self.config.async_dump, self.config.precision) tensor_json = {} - tensor_json.update({'type': 'torch.Tensor'}) + tensor_json.update({'type': self.tensor_handler.get_tensor_type(tensor)}) tensor_json.update({'dtype': str(tensor.dtype)}) tensor_json.update({"shape": tensor.shape}) @@ -246,15 +282,25 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_json.update({Const.TENSOR_STAT_INDEX: placeholder_index}) tensor_json.update({"requires_grad": tensor.requires_grad}) + if self.tensor_handler.is_dtensor(tensor): + dtensor_info = self.tensor_handler.get_dtensor_info(tensor) + tensor_json.update(dtensor_info) if self.config.summary_mode == Const.MD5 and not self.config.async_dump: - tensor_md5 = self.get_md5_for_tensor(tensor) + tensor_md5 = None + if not self.tensor_handler.is_empty_data(tensor): + logger.debug("Calculating the md5 value of fake tensor or meta tensor is not supported.") + tensor_md5 = self.get_md5_for_tensor(common_tensor) tensor_json.update({Const.MD5: tensor_md5}) return tensor_json def _analyze_and_save_tensor(self, tensor, suffix): dump_data_name, file_path = self.get_save_file_path(suffix) single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix) + if self.tensor_handler.is_empty_data(tensor): + logger.debug("Collecting real data of fake tensor or meta tensor is not supported.") + return single_arg + single_arg.update({"data_name": dump_data_name}) if self.config.async_dump: self._async_dump_cache[file_path] = tensor.clone().detach() diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index 33071c0e01c800858a892061927d470c06bcff73..2cd997f08f68d369f7e8371a9b732540ddc985ef 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -60,13 +60,13 @@ class TestPytorchDataProcessor(unittest.TestCase): def test_get_stat_info_with_meta_tensor(self): mock_data = self.mock_tensor(is_meta=True) - result = PytorchDataProcessor.get_stat_info(mock_data) + result = self.processor.get_stat_info(mock_data) self.assertIsInstance(result, TensorStatInfo) def test_get_stat_info_with_fake_tensor(self): with FakeTensorMode() as fake_tensor_mode: fake_tensor = fake_tensor_mode.from_tensor(torch.randn(1, 2, 3)) - result = PytorchDataProcessor.get_stat_info(fake_tensor) + result = self.processor.get_stat_info(fake_tensor) self.assertIsNone(result.max) self.assertIsNone(result.min) self.assertIsNone(result.mean) @@ -107,7 +107,7 @@ class TestPytorchDataProcessor(unittest.TestCase): def test_get_stat_info_with_scalar_tensor(self): scalar_tensor = torch.tensor(42.0) - result = PytorchDataProcessor.get_stat_info(scalar_tensor) + result = self.processor.get_stat_info(scalar_tensor) self.assertIsInstance(result, TensorStatInfo) self.assertEqual(result.max, 42.0) self.assertEqual(result.min, 42.0) @@ -116,7 +116,7 @@ class TestPytorchDataProcessor(unittest.TestCase): def test_get_stat_info_with_complex_tensor(self): complex_tensor = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64) - result = PytorchDataProcessor.get_stat_info(complex_tensor) + result = self.processor.get_stat_info(complex_tensor) expected_max = np.abs(np.array([1 + 2j, 3 + 4j])).max().item() expected_min = np.abs(np.array([1 + 2j, 3 + 4j])).min().item() expected_mean = np.abs(np.array([1 + 2j, 3 + 4j])).mean().item() @@ -125,49 +125,6 @@ class TestPytorchDataProcessor(unittest.TestCase): self.assertAlmostEqual(result.min, expected_min, places=6) self.assertAlmostEqual(result.mean, expected_mean, places=6) - def test_handle_tensor_extremum_nan_inf_all_nan(self): - tensor = torch.tensor([float('nan'), float('nan')]) - result = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - self.assertTrue(np.isnan(result)) - - def test_handle_tensor_extremum_nan_inf_all_inf(self): - tensor = torch.tensor([float('inf'), float('inf')]) - result = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - self.assertTrue(np.isinf(result)) - - def test_handle_tensor_extremum_nan_inf_all_negative_inf(self): - tensor = torch.tensor([float('-inf'), float('-inf')]) - result = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertTrue(np.isinf(result) and result < 0) - - def test_handle_tensor_extremum_nan_inf_mixed(self): - tensor = torch.tensor([1.0, float('nan'), 3.0, float('-inf'), 2.0]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertEqual(result_max, 3.0) - self.assertEqual(result_min, 1.0) - - def test_handle_tensor_extremum_nan_inf_mixed_with_inf(self): - tensor = torch.tensor([1.0, float('nan'), 3.0, float('inf'), 2.0]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertEqual(result_max, 3.0) - self.assertEqual(result_min, 1.0) - - def test_handle_tensor_extremum_nan_inf_no_inf_nan(self): - tensor = torch.tensor([1.0, 2.0, 3.0]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertEqual(result_max, 3.0) - self.assertEqual(result_min, 1.0) - - def test_handle_tensor_extremum_nan_inf_all_inf_nan(self): - tensor = torch.tensor([float('nan'), float('inf'), float('-inf')]) - result_max = self.processor.handle_tensor_extremum_nan_inf(tensor, 'max') - result_min = self.processor.handle_tensor_extremum_nan_inf(tensor, 'min') - self.assertTrue(np.isinf(result_max)) - self.assertTrue(np.isinf(result_min)) - def test_analyze_builtin(self): result = self.processor._analyze_builtin(slice(1, torch.tensor(10, dtype=torch.int32), np.int64(2))) expected = {'type': 'slice', 'value': [1, 10, 2]}