From e112dd899990c8a4d2c7d50dbc6925e3db581fd1 Mon Sep 17 00:00:00 2001 From: fanglanyue Date: Thu, 4 Sep 2025 11:40:42 +0800 Subject: [PATCH] extract rank_id from pattern match --- .../prof_data_allocate.py | 37 ++++++++++++++----- .../test_prof_data_allocate.py | 15 ++++++-- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/prof_data_allocate.py b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/prof_data_allocate.py index 9db3e8137..e4f9e7af2 100644 --- a/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/prof_data_allocate.py +++ b/profiler/msprof_analyze/cluster_analyse/cluster_data_preprocess/prof_data_allocate.py @@ -30,10 +30,9 @@ logger = get_logger() class ProfDataAllocate: - DB_PATTERNS = { - Constant.PYTORCH: re.compile(r'^ascend_pytorch_profiler(?:_\d+)?\.db$'), - Constant.MINDSPORE: re.compile(r'^ascend_mindspore_profiler(?:_\d+)?\.db$'), + Constant.PYTORCH: re.compile(r'^ascend_pytorch_profiler(?:_(\d+))?\.db$'), + Constant.MINDSPORE: re.compile(r'^ascend_mindspore_profiler(?:_(\d+))?\.db$'), Constant.MSPROF: re.compile(r'^msprof_\d{14}\.db$'), Constant.MSMONITOR: re.compile(r'^msmonitor_(\d+)_\d{17}_(-1|\d+)\.db$') } @@ -41,6 +40,7 @@ class ProfDataAllocate: ASCEND_PT = "ascend_pt" ASCEND_MS = "ascend_ms" PROF = "PROF_" + DEFAULT_RANK_ID = -1 def __init__(self, profiling_path): self.profiling_path = profiling_path @@ -60,13 +60,33 @@ class ProfDataAllocate: @staticmethod def _extract_rank_id_from_profiler_db(file_name: str, prof_type: str): """从profiler_db文件名中提取rank_id,传入的file_name已经过正则匹配""" + if prof_type not in [Constant.PYTORCH, Constant.MINDSPORE, Constant.MSMONITOR]: + logger.error(f"Unsupported prof_type {prof_type}. Can not extract rank_id from profile db.") + return None + + pattern = ProfDataAllocate.DB_PATTERNS[prof_type] + match = pattern.match(file_name) + + if not match: + return None + try: - if prof_type in [Constant.PYTORCH, Constant.MINDSPORE, Constant.MSMONITOR]: - return int(file_name.strip(".db").split("_")[-1]) + if prof_type == Constant.MSMONITOR: + # msmonitor格式:第二个捕获组是rank_id + rank_str = match.group(2) else: - logger.error(f"Unsupported prof_type {prof_type}. Can not extract rank_id from profile db.") - return None - except (IndexError, ValueError): + # pytorch和mindspore格式:第一个捕获组是rank_id + rank_str = match.group(1) + + # 处理特殊情况:ascend_pytorch_profiler.db(捕获组为None) + if rank_str is None: + logger.warning(f"No rank_id for {file_name}. Using default value {ProfDataAllocate.DEFAULT_RANK_ID}.") + return ProfDataAllocate.DEFAULT_RANK_ID + + return int(rank_str) + + except (IndexError, ValueError) as e: + logger.error(f"Failed to extract rank_id from {file_name}: {str(e)}") return None @staticmethod @@ -204,4 +224,3 @@ class ProfDataAllocate: self.data_type = data_type self.data_map = data_map - diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_prof_data_allocate.py b/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_prof_data_allocate.py index aa749b856..53c82a01a 100644 --- a/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_prof_data_allocate.py +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_prof_data_allocate.py @@ -77,7 +77,7 @@ class TestProfDataAllocate(unittest.TestCase): result = ProfDataAllocate.match_file_pattern_in_dir(test_dir, pattern) self.assertEqual(result, "") - + def test_extract_rank_id_from_profiler_db_when_pytorch_file_then_return_rank_id(self): """Test extracting rank ID from PyTorch profiler DB filename""" file_name = "ascend_pytorch_profiler_1.db" @@ -104,15 +104,24 @@ class TestProfDataAllocate(unittest.TestCase): result = ProfDataAllocate._extract_rank_id_from_profiler_db(file_name, prof_type) self.assertEqual(result, 1) - + def test_extract_rank_id_from_profiler_db_when_invalid_format_then_return_none(self): """Test extracting rank ID from invalid filename format""" file_name = "invalid_filename.db" prof_type = Constant.PYTORCH + + result = ProfDataAllocate._extract_rank_id_from_profiler_db(file_name, prof_type) + + self.assertIsNone(result) + + def test_extract_rank_id_from_profiler_db_when_no_rank_id_then_return_minus_one(self): + """Test extracting rank ID from invalid filename format""" + file_name = "ascend_pytorch_profiler.db" + prof_type = Constant.PYTORCH result = ProfDataAllocate._extract_rank_id_from_profiler_db(file_name, prof_type) - self.assertIsNone(result) + self.assertEqual(result, -1) def test_extract_rank_id_from_profiler_db_when_unsupported_prof_type_then_return_none(self): """Test extracting rank ID from unsupported profiler type""" -- Gitee