From 1dab7c1e6c44560126fc0fb5d92834535efba3a5 Mon Sep 17 00:00:00 2001 From: pxp1 <958876660@qq.com> Date: Mon, 15 Jul 2024 15:50:03 +0800 Subject: [PATCH] codeclean --- .../grad_tool/common/base_comparator.py | 36 +++++++++---------- .../grad_tool/grad_ms/grad_analyzer.py | 2 ++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py index b5dc45b20..f940ef513 100644 --- a/debug/accuracy_tools/grad_tool/common/base_comparator.py +++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py @@ -12,6 +12,24 @@ from grad_tool.common.utils import write_csv, check_file_or_directory_path, prin class BaseComparator(ABC): + @staticmethod + def _get_grad_weight_order(path1, path2): + for summary_file in os.listdir(path1): + if not summary_file.endswith(".csv"): + continue + if not os.path.exists(os.path.join(path2, summary_file)): + continue + summary_csv = pd.read_csv(os.path.join(path1, summary_file)) + return summary_csv["param_name"] + raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration") + + @staticmethod + def _get_name_matched_grad_file(param_name, grad_files): + for grad_file in grad_files: + if param_name == grad_file[:grad_file.rfind('.')]: + return grad_file + raise RuntimeError("no matched grad_file for comparison, please dump data in same configuration") + @classmethod def compare_distributed(cls, path1: str, path2: str, output_dir: str): ranks = cls._get_matched_dirs(path1, path2, "rank") @@ -72,24 +90,6 @@ class BaseComparator(ABC): head_tuple = tuple(['step'] + [str(step) for step in steps]) write_csv(os.path.join(output_dir, "similarities.csv"), [[key] + value], head_tuple) - @staticmethod - def _get_grad_weight_order(path1, path2): - for summary_file in os.listdir(path1): - if not summary_file.endswith(".csv"): - continue - if not os.path.exists(os.path.join(path2, summary_file)): - continue - summary_csv = pd.read_csv(os.path.join(path1, summary_file)) - return summary_csv["param_name"] - raise RuntimeError("no matched grad_summary.csv for comparison, please dump data in same configuration") - - @staticmethod - def _get_name_matched_grad_file(param_name, grad_files): - for grad_file in grad_files: - if param_name == grad_file[:grad_file.rfind('.')]: - return grad_file - raise RuntimeError("no matched grad_file for comparison, please dump data in same configuration") - @classmethod def _calculate_separated_similarities(cls, path1, path2, steps): similarities = {} diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py index 963a37f86..75280b319 100644 --- a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py +++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py @@ -78,6 +78,8 @@ class CSVGenerator(Process): self.level = GradConst.LEVEL0 self.cache_list = ListCache() self.current_step = None + self.stop_event = None + self.last_finish = False self.bounds = [-0.1, 0.0, 0.1], def init(self, context: GlobalContext): -- Gitee