From 3c93731d28d5f408d60152f53c257b33f08821d3 Mon Sep 17 00:00:00 2001 From: curry3 <485078529@qq.com> Date: Tue, 10 Sep 2024 20:27:40 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90feature=E3=80=91msprobe=20dump?= =?UTF-8?q?=E6=94=AF=E6=8C=81dist.ProcessGroup=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../data_processor/pytorch_processor.py | 15 +++++++++++++- .../data_processor/test_pytorch_processor.py | 20 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index ba66116478..3d230d237c 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -6,6 +6,7 @@ from typing import List import numpy as np import torch +import torch.distributed as dist from msprobe.core.common.file_utils import path_len_exceeds_limit, change_mode from msprobe.core.common.log import logger from msprobe.core.common.const import Const, OverflowConst, FileCheckConst @@ -22,7 +23,7 @@ except ImportError: class PytorchDataProcessor(BaseDataProcessor): - pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor) + pytorch_special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, dist.ProcessGroup) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -105,6 +106,16 @@ class PytorchDataProcessor(BaseDataProcessor): def _analyze_torch_size(arg): return {"type": "torch.Size", "value": list(arg)} + @staticmethod + def _analyze_process_group(arg): + group_info = {"type": "torch.distributed.ProcessGroup", "group_id": id(arg)} + try: + group_ranks = dist.get_process_group_ranks(arg) + group_info.update({"group_ranks": group_ranks}) + except Exception as e: + logger.warning(f"Failed to get process group(id: {id(arg)}) ranks info, the error info: {e}.") + return group_info + @classmethod def get_special_types(cls): return super().get_special_types() + cls.pytorch_special_type @@ -114,6 +125,8 @@ class PytorchDataProcessor(BaseDataProcessor): return self.torch_object_key[suffix_stack[-1]](element) if isinstance(element, torch.Size): return self._analyze_torch_size(element) + if isinstance(element, dist.ProcessGroup): + return self._analyze_process_group(element) converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) if converted_numpy is not element: return self._analyze_numpy(converted_numpy, numpy_type) diff --git a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py index 7d44c49e2c..6f28c27dc2 100644 --- a/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/data_dump/data_processor/test_pytorch_processor.py @@ -131,6 +131,26 @@ class TestPytorchDataProcessor(unittest.TestCase): expected = {'type': 'torch.Size', 'value': [3, 4, 5]} self.assertEqual(result, expected) + @patch('msprobe.core.data_dump.data_processor.pytorch_processor.dist.get_process_group_ranks') + def test_analyze_process_group_success(self, mock_get_process_group_ranks): + mock_get_process_group_ranks.return_value = [0, 1, 2] + arg = MagicMock() + group_info = self.processor._analyze_process_group(arg) + + self.assertEqual(group_info["type"], "torch.distributed.ProcessGroup") + self.assertEqual(group_info["group_id"], id(arg)) + self.assertEqual(group_info["group_ranks"], [0, 1, 2]) + + @patch('msprobe.core.data_dump.data_processor.pytorch_processor.dist.get_process_group_ranks', + side_effect=Exception('Error message')) + def test_analyze_process_group_exception(self, _): + arg = MagicMock() + group_info = self.processor._analyze_process_group(arg) + + self.assertEqual(group_info["type"], "torch.distributed.ProcessGroup") + self.assertEqual(group_info["group_id"], id(arg)) + self.assertNotIn("group_ranks", group_info) + def test_get_special_types(self): special_types = self.processor.get_special_types() self.assertIn(torch.Tensor, special_types) -- Gitee