diff --git a/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py b/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py index 70c4b826973a01805ac8db9b1105dc8d9dcd7923..61095de62cf693d53c4088593f644617476186bb 100644 --- a/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py +++ b/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import re import multiprocessing from dataclasses import dataclass @@ -70,6 +71,9 @@ class SingleComparator: 比较两个NumPy数组,计算最大绝对误差、最大相对误差和相同元素的百分比 """ # 计算每个维度上的最小尺寸 + if array1.ndim != array2.ndim: + array1 = array1.flatten() + array2 = array2.flatten() min_shape = [min(s1, s2) for s1, s2 in zip(array1.shape, array2.shape)] # 截取数组到相同的形状 sliced_array1 = array1[tuple(slice(0, s) for s in min_shape)] @@ -176,9 +180,18 @@ class SingleComparator: continue for step, step_path in cls.get_steps(tag_path): for rank, rank_path in cls.get_ranks(step_path): - for micro_step, micro_step_path in cls.get_micro_steps(rank_path): - for array_id, array_path in cls.get_arrays(micro_step_path): - array_paths.setdefault(tag, []).append((step, rank, micro_step, array_id, array_path)) + for item in os.listdir(rank_path): + next_path = os.path.join(rank_path, item) + if re.match(r"micro_step(\d+)", item): + micro_step = re.match(r"micro_step(\d+)", item).group(1) + for array_id, array_path in cls.get_arrays(next_path): + array_paths.setdefault(tag, []).append( + (step, rank, int(micro_step), array_id, array_path)) + elif re.match(r"\w{1,100}_(\d{1,100})\.npy", item): + array_id = re.match(r"\w{1,100}_(\d{1,100})\.npy", item).group(1) + array_paths.setdefault(tag, []).append((step, rank, 0, int(array_id), next_path)) + else: + array_paths.setdefault(tag, []).append((step, rank, 0, 0, next_path)) return array_paths @classmethod