diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index ba66116478ff30931f7a065ae841b8b38b486054..3d230d237cea024a269d01c0f4992ebe50861fc8 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -6,6 +6,7 @@ from typing import List import numpy as np import torch +import torch.distributed as dist from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode from msprobe.core.common.log import logger from msprobe.core.common.const import Const, OverflowConst, FileCheckConst @@ -22,7 +23,7 @@ except ImportError: class PytorchDataProcessor(BaseDataProcessor): - pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor) + pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, dist.ProcessGroup) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -105,6 +106,16 @@ class PytorchDataProcessor(BaseDataProcessor): def _analyze_torch_size(arg): return {"type": "torch.Size", "value": list(arg)} + @staticmethod + def _analyze_process_group(arg): + group_info = {"type": "torch.distributed.ProcessGroup", "group_id": id(arg)} + try: + group_ranks = dist.get_process_group_ranks(arg) + group_info.update({"group_ranks": group_ranks}) + except Exception as e: + logger.warning(f"Failed to get process group(id: {id(arg)}) ranks info, the error info: {e}.") + return group_info + @classmethod def get_special_types(cls): return super().get_special_types() + cls.pytorch_special_type @@ -114,6 +125,8 @@ class PytorchDataProcessor(BaseDataProcessor): return self.torch_object_key[suffix_stack[-1]](element) if isinstance(element, torch.Size): return self._analyze_torch_size(element) + if isinstance(element, dist.ProcessGroup): + return self._analyze_process_group(element) converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) if converted_numpy is not element: return self._analyze_numpy(converted_numpy, numpy_type) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index 7d44c49e2ccfe714b6dfc67e6991e4b6ad2109e6..6f28c27dc2fc74d4336ff9d97423755d8d0afef8 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -131,6 +131,26 @@ class TestPytorchDataProcessor(unittest.TestCase): expected = {'type': 'torch.Size', 'value': [3, 4, 5]} self.assertEqual(result, expected) + @patch('msprobe.core.data_dump.data_processor.pytorch_processor.dist.get_process_group_ranks') + def test_analyze_process_group_success(self, mock_get_process_group_ranks): + mock_get_process_group_ranks.return_value = [0, 1, 2] + arg = MagicMock() + group_info = self.processor._analyze_process_group(arg) + + self.assertEqual(group_info["type"], "torch.distributed.ProcessGroup") + self.assertEqual(group_info["group_id"], id(arg)) + self.assertEqual(group_info["group_ranks"], [0, 1, 2]) + + @patch('msprobe.core.data_dump.data_processor.pytorch_processor.dist.get_process_group_ranks', + side_effect=Exception('Error message')) + def test_analyze_process_group_exception(self, _): + arg = MagicMock() + group_info = self.processor._analyze_process_group(arg) + + self.assertEqual(group_info["type"], "torch.distributed.ProcessGroup") + self.assertEqual(group_info["group_id"], id(arg)) + self.assertNotIn("group_ranks", group_info) + def test_get_special_types(self): special_types = self.processor.get_special_types() self.assertIn(torch.Tensor, special_types)