diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 5f27bf467b1e592cf7c6aea9cd1dd34e2280b96f..45f4c44008c46f4d0267080b9058993e39628fab 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -233,6 +233,7 @@ class Const: NORM = 'Norm' DATA_NAME = 'data_name' TENSOR_STAT_INDEX = 'tensor_stat_index' + SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM] CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 809c17d603b1704c215bfa524ae5b13f8cd8ed84..4da1adfa0908b66679b11f9f3ffadaaf237c1177 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -215,7 +215,10 @@ def merge_tensor(tensor_list, dump_mode): op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) else: op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) - op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) + + # 当统计量为None时,转成字符串None,避免后续操作list放到pd中时None被默认转成NaN + op_dict[Const.SUMMARY].append( + [str(tensor[key]) if tensor[key] is None else tensor[key] for key in Const.SUMMARY_METRICS_LIST]) if dump_mode == Const.ALL: op_dict["data_name"].append(tensor['data_name'])