From 1d462e3f6d6c622564769b4770f714ec9fa43d64 Mon Sep 17 00:00:00 2001 From: Linwei-Ying Date: Tue, 1 Jul 2025 19:23:11 +0800 Subject: [PATCH] compare None to Nan bugfix compare None to Nan bugfix --- debug/accuracy_tools/msprobe/core/common/const.py | 1 + debug/accuracy_tools/msprobe/core/compare/utils.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 59c3a354f..7b90110f2 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -238,6 +238,7 @@ class Const: NORM = 'Norm' DATA_NAME = 'data_name' TENSOR_STAT_INDEX = 'tensor_stat_index' + SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM] CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 183f8dfb6..1e67a8020 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -228,7 +228,10 @@ def merge_tensor(tensor_list, dump_mode): op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) else: op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) - op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) + + # 当统计量为None时,转成字符串None,避免后续操作list放到pd中时None被默认转成NaN + op_dict[Const.SUMMARY].append( + [str(tensor[key]) if tensor[key] is None else tensor[key] for key in Const.SUMMARY_METRICS_LIST]) if dump_mode == Const.ALL: op_dict["data_name"].append(tensor['data_name']) -- Gitee