From 087c0e0095bb5cb6e6183012bbc880ee61cde09a Mon Sep 17 00:00:00 2001 From: MooYeh Date: Mon, 20 May 2024 18:23:46 +0800 Subject: [PATCH] [bugfix] In pipeline parallel mode, filter invalid grad data. --- .../grad_tool/grad_ms/grad_analyzer.py | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py index d710b8eeb..963a37f86 100644 --- a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py +++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py @@ -58,10 +58,8 @@ def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, bucket_result = bucket_result.astype(ms.int8) dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)] dist_stat.append(zero_grad) + dist_stat.append(ms.Tensor(1, dtype=ms.int64)) # make sure dist_stat is not empty dist_stat = ms.ops.stack(dist_stat, axis=0).float() - element_num = dist_stat.sum() - dist_stat[-1] - if element_num != 0: - dist_stat = dist_stat / element_num level1_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0) level_stat = level1_stat @@ -139,6 +137,9 @@ class CSVGenerator(Process): stat_data = self.load_npy_data(file_path) if stat_data is None: continue + if not self.check_valid(stat_data): + os.remove(file_path) + continue step = int(stat_data[GradConst.STEP_IDX]) update_step = self.current_step is None or step != self.current_step self.current_step = step @@ -148,6 +149,21 @@ class CSVGenerator(Process): os.remove(file_path) self.last_finish = False + def check_valid(self, stat_data): + level = grad_context.get_context(GradConst.LEVEL) + try: + shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX]) + if level in [GradConst.LEVEL1, GradConst.LEVEL3]: + dist_dim = int(stat_data[shape_dim + GradConst.SHAPE_DIM_IDX + 1]) + length = shape_dim + dist_dim + 7 + else: + length = shape_dim + 5 + except IndexError as err: + return 
False + if length != len(stat_data): + return False + return True + def load_npy_data(self, file_path: str): stat_data = None max_try = 10 @@ -175,7 +191,11 @@ class CSVGenerator(Process): self.cache_list.append(csv_line) def get_dist_data(self, shape_dim: int, stat_data: np.ndarray): - return list(stat_data[(shape_dim + GradConst.SHAPE_DIM_IDX + 2):]) + dist_data = stat_data[(shape_dim + GradConst.SHAPE_DIM_IDX + 2):-1] + element_num = dist_data.sum() - dist_data[-1] + if element_num != 0: + dist_data = dist_data / element_num + return list(dist_data) def get_extrem_data(self, shape_dim: int, stat_data: np.ndarray): extrem_data = list(stat_data[(GradConst.STEP_IDX + 1):(GradConst.STEP_IDX + 4)]) -- Gitee