diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index b46144e5c94482210751fc170d3eafe57c3e9f5e..5f27bf467b1e592cf7c6aea9cd1dd34e2280b96f 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -743,6 +743,11 @@ class MonitorConst: DEFAULT_STEP_INTERVAL = 1 OP_LIST = ["norm", "min", "max", "zeros", "nans", "id", "mean", "shape", "dtype"] + OP_MONVIS_SUPPORTED = [ + "norm", "min", "max", "zeros", "nans", "mean", + "entropy", "softmax_max", "sr", "kernel_norm", "std_x", "jacobian", + "proxy", "token_similarity" + ] MONITOR_OUTPUT_DIR = "MONITOR_OUTPUT_DIR" DEFAULT_MONITOR_OUTPUT_DIR = "./monitor_output" DATABASE = "database" diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py index ade9ea6d13aba7aa8e7811e48e367a1023039be4..32873b6682b45785e95303bcb81f845a7d7e9873 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2db.py @@ -37,7 +37,8 @@ from tqdm import tqdm all_data_type_list = [ "actv", "actv_grad", "exp_avg", "exp_avg_sq", - "grad_unreduced", "grad_reduced", "param_origin", "param_updated" + "grad_unreduced", "grad_reduced", "param_origin", "param_updated", + "linear_hook", "norm_hook", "proxy_model", "token_hook", "attention_hook" ] DEFAULT_INT_VALUE = 0 MAX_PROCESS_NUM = 128 @@ -83,7 +84,7 @@ def update_with_order_dict(main_dict, new_list): def get_ordered_stats(stats): if not isinstance(stats, Iterable): return [] - return [stat for stat in MonitorConst.OP_LIST if stat in stats] + return [stat for stat in MonitorConst.OP_MONVIS_SUPPORTED if stat in stats] def pre_scan_single_rank(rank, files): @@ -106,7 +107,7 @@ def pre_scan_single_rank(rank, files): max_step = step_end if max_step < step_end else max_step data = read_csv(file_path) - stats = [k for k in data.keys() if k in MonitorConst.OP_LIST] + stats = [k for k in data.keys() if k in MonitorConst.OP_MONVIS_SUPPORTED] metric_stats[metric_name].update(stats) for _, row in data.iterrows():