diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 59c3a354f6d36b62728b2b0844b148f99c075b2e..7b90110f29ef09ad88b0fa748f2e718f8f32c1b0 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -238,6 +238,7 @@ class Const: NORM = 'Norm' DATA_NAME = 'data_name' TENSOR_STAT_INDEX = 'tensor_stat_index' + SUMMARY_METRICS_LIST = [MAX, MIN, MEAN, NORM] CODE_STACK = 'Code Stack' OP_NAME = 'Op Name' diff --git a/debug/accuracy_tools/msprobe/core/compare/utils.py b/debug/accuracy_tools/msprobe/core/compare/utils.py index 183f8dfb6b78ba7a88c9682fdba6cea14b1e3652..1e67a8020d4af209b91fc7481c5eac4cee164a69 100644 --- a/debug/accuracy_tools/msprobe/core/compare/utils.py +++ b/debug/accuracy_tools/msprobe/core/compare/utils.py @@ -228,7 +228,10 @@ def merge_tensor(tensor_list, dump_mode): op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE], tensor[Const.MD5])) else: op_dict.get(struct_key).append((tensor[Const.DTYPE], tensor[Const.SHAPE])) - op_dict[Const.SUMMARY].append([tensor[Const.MAX], tensor[Const.MIN], tensor[Const.MEAN], tensor[Const.NORM]]) + + # 当统计量为None时,转成字符串None,避免后续操作list放到pd中时None被默认转成NaN + op_dict[Const.SUMMARY].append( + [str(tensor[key]) if tensor[key] is None else tensor[key] for key in Const.SUMMARY_METRICS_LIST]) if dump_mode == Const.ALL: op_dict["data_name"].append(tensor['data_name'])