From 008f2b88b8b4a0cc0e5e363a40ea9ede99ef1b38 Mon Sep 17 00:00:00 2001 From: makai Date: Sun, 29 Sep 2024 17:38:00 +0800 Subject: [PATCH] check split scope --- .../msprobe/core/common/const.py | 2 ++ .../msprobe/core/common/utils.py | 6 ++--- .../msprobe/core/data_dump/scope.py | 27 ++++++++++++++----- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 389ba2013..6e18cda53 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -125,6 +125,8 @@ class Const: NAME_FIRST_POSSIBLE_INDEX = -4 NAME_SECOND_POSSIBLE_INDEX = -5 + + MIN_SPLIT_SCOPE_LENGTH = 1 INPLACE_LIST = [ "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", diff --git a/debug/accuracy_tools/msprobe/core/common/utils.py b/debug/accuracy_tools/msprobe/core/common/utils.py index 910d3a769..77f301948 100644 --- a/debug/accuracy_tools/msprobe/core/common/utils.py +++ b/debug/accuracy_tools/msprobe/core/common/utils.py @@ -287,10 +287,10 @@ def print_tools_ends_info(): def get_step_or_rank_from_string(step_or_rank, obj): - splited = step_or_rank.split(Const.HYPHEN) - if len(splited) == 2: + split = step_or_rank.split(Const.HYPHEN) + if len(split) == 2: try: - borderlines = int(splited[0]), int(splited[1]) + borderlines = int(split[0]), int(split[1]) except (ValueError, IndexError) as e: raise MsprobeException(MsprobeException.INVALID_PARAM_ERROR, "The hyphen(-) must start and end with decimal numbers.") from e diff --git a/debug/accuracy_tools/msprobe/core/data_dump/scope.py b/debug/accuracy_tools/msprobe/core/data_dump/scope.py index 00df77870..edd62244d 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/scope.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/scope.py @@ -120,15 +120,30 @@ class RangeScope(BaseScope, ABC): class APIRangeScope(RangeScope): + + @staticmethod + def is_length_greater_than_or_equal_one(split_scope): + if len(split_scope) >= Const.MIN_SPLIT_SCOPE_LENGTH: + return True + return False + def check_scope_is_valid(self): if not self.scope: return True - scope_start_type = self.scope[0].split(Const.SEP)[0] - if scope_start_type in BaseScope.module_type: - return False - scope_stop_type = self.scope[1].split(Const.SEP)[0] - if scope_stop_type in BaseScope.module_type: - return False + if self.is_length_greater_than_or_equal_one(self.scope[0].split(Const.SEP)): + scope_start_type = self.scope[0].split(Const.SEP)[0] + if scope_start_type in BaseScope.module_type: + return False + else: + raise ScopeException(ScopeException.InvalidScope, + f"The split scope list must have a length greater than or equal to 1.") + if self.is_length_greater_than_or_equal_one(self.scope[1].split(Const.SEP)): + scope_stop_type = self.scope[1].split(Const.SEP)[0] + if scope_stop_type in BaseScope.module_type: + return False + else: + raise ScopeException(ScopeException.InvalidScope, + f"The split scope list must have a length greater than or equal to 1.") return True def check(self, api_name): -- Gitee