From 4ff2051a08989e41d53abe0ff7de7a6eec7ca3fb Mon Sep 17 00:00:00 2001 From: pxp1 <958876660@qq.com> Date: Wed, 3 Sep 2025 16:28:01 +0800 Subject: [PATCH] =?UTF-8?q?=E9=80=82=E9=85=8Drlhf=E6=AF=94=E5=AF=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/single_save/single_comparator.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py b/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py index 70c4b8269..61095de62 100644 --- a/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py +++ b/debug/accuracy_tools/msprobe/core/single_save/single_comparator.py @@ -14,6 +14,7 @@ # limitations under the License. import os +import re import multiprocessing from dataclasses import dataclass @@ -70,6 +71,9 @@ class SingleComparator: 比较两个NumPy数组,计算最大绝对误差、最大相对误差和相同元素的百分比 """ # 计算每个维度上的最小尺寸 + if array1.ndim != array2.ndim: + array1 = array1.flatten() + array2 = array2.flatten() min_shape = [min(s1, s2) for s1, s2 in zip(array1.shape, array2.shape)] # 截取数组到相同的形状 sliced_array1 = array1[tuple(slice(0, s) for s in min_shape)] @@ -176,9 +180,18 @@ class SingleComparator: continue for step, step_path in cls.get_steps(tag_path): for rank, rank_path in cls.get_ranks(step_path): - for micro_step, micro_step_path in cls.get_micro_steps(rank_path): - for array_id, array_path in cls.get_arrays(micro_step_path): - array_paths.setdefault(tag, []).append((step, rank, micro_step, array_id, array_path)) + for item in os.listdir(rank_path): + next_path = os.path.join(rank_path, item) + if re.match(r"micro_step(\d+)", item): + micro_step = re.match(r"micro_step(\d+)", item).group(1) + for array_id, array_path in cls.get_arrays(next_path): + array_paths.setdefault(tag, []).append( + (step, rank, int(micro_step), array_id, array_path)) + elif re.match(r"\w{1,100}_(\d{1,100})\.npy", item): + array_id = re.match(r"\w{1,100}_(\d{1,100})\.npy", item).group(1) + array_paths.setdefault(tag, []).append((step, rank, 0, int(array_id), next_path)) + else: + array_paths.setdefault(tag, []).append((step, rank, 0, 0, next_path)) return array_paths @classmethod -- Gitee