diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 7e98eb341833aa5c910f09c013b8deca1bf9562e..1a806f9c1744f4693a72144120c71ae402a6f1a2 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -234,6 +234,7 @@ class Const: NORM = 'Norm' DATA_NAME = 'data_name' TENSOR_STAT_INDEX = 'tensor_stat_index' + SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM] CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index c3a93e15792cc274a3518552c8c8c33c149a7ade..ecec24859a7113b31e93785b7093778daf61e70f 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -215,7 +215,10 @@ def merge_tensor(tensor_list, dump_mode): op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) else: op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) - op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) + + # 当统计量为None时,转成字符串None,避免后续操作list放到pd中时None被默认转成NaN + op_dict[Const.SUMMARY].append( + [str(tensor[key]) if tensor[key] is None else tensor[key] for key in Const.SUMMARY_METRICS_LIST]) if dump_mode == Const.ALL: op_dict["data_name"].append(tensor['data_name'])