From 92200bac89c836ffbe40f3806f95ced5ddcca510 Mon Sep 17 00:00:00 2001
From: MooYeh
Date: Mon, 13 May 2024 09:34:04 +0800
Subject: [PATCH] [bugfix] ms_grad_tool

---
 debug/accuracy_tools/grad_tool/README.md      | 24 ++--
 .../accuracy_tools/grad_tool/grad_monitor.py  | 4 +
 .../grad_tool/grad_ms/global_context.py       | 2 +-
 .../grad_tool/grad_ms/grad_analyzer.py        | 112 ++++++++----------
 .../grad_tool/grad_ms/grad_monitor.py         | 3 +
 .../accuracy_tools/grad_tool/grad_ms/hook.py  | 22 ++--
 6 files changed, 88 insertions(+), 79 deletions(-)

diff --git a/debug/accuracy_tools/grad_tool/README.md b/debug/accuracy_tools/grad_tool/README.md
index e282ba4d7..5b95ba254 100644
--- a/debug/accuracy_tools/grad_tool/README.md
+++ b/debug/accuracy_tools/grad_tool/README.md
@@ -40,7 +40,8 @@ bounds: [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10]
 output_path: /home/pxp1/code/train_test_msft_multi/test/npu_grad_output4
 ```
 
-  > Under the MindSpore framework, step must be a range list or left unspecified
+  > Under the MindSpore framework, the rank and step settings are not supported yet; data is collected for all ranks and all steps by default
+  > In MindSpore, step refers to the number of times the optimizer has been called
 
 **Parameter description**
 
@@ -92,15 +93,24 @@
 - PyTorch framework
 
   Before training starts, call gm.monitor and pass in the model as the argument.
 
-  ```python
-  gm.monitor(model)
-  ```
+```python
+gm.monitor(model)
+```
+
 - MindSpore framework
 
   Before training starts, call gm.monitor and pass in the optimizer as the argument.
 
-  ```python
-  gm.monitor(optimizer)
-  ```
+```python
+gm.monitor(optimizer)
+```
+
+4. Stop monitoring (required for MindSpore)
+
+   After training ends, call the stop interface.
+
+```python
+gm.stop()
+```
 
 ### Output
 **Output directory structure** (with level set to L2 as an example)
diff --git a/debug/accuracy_tools/grad_tool/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_monitor.py
index 8b6c6bb4a..64afba746 100644
--- a/debug/accuracy_tools/grad_tool/grad_monitor.py
+++ b/debug/accuracy_tools/grad_tool/grad_monitor.py
@@ -16,3 +16,7 @@ class GradientMonitor:
 
     def monitor(self, module):
         self.grad_monitor.monitor(module)
+
+    def stop(self):
+        if self.framework == GradConst.MindSpore:
+            self.grad_monitor.stop()
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/global_context.py b/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
index 91806ee6a..b988d3503 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
@@ -17,7 +17,7 @@ class GlobalContext:
         GradConst.RANK: None,
         GradConst.STEP: [0, 0],
         GradConst.CURRENT_STEP: 0,
-        GradConst.BOUNDS: [-10., -1., -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1., 10.],
+        GradConst.BOUNDS: [-1., 0., 1.],
         GradConst.OUTPUT_PATH: "./grad_stat"
     }
 
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
index 64af57ad8..1b11ce63c 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
@@ -2,6 +2,7 @@ import os
 import shutil
 import time
 from typing import List, Tuple
+import multiprocessing
 from multiprocessing import Process
 
 import numpy as np
@@ -25,63 +26,49 @@ def get_rank_id():
     return rank_id
 
 
-class GradAnalyzer:
-
-    @staticmethod
-    def dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor):
-        '''
-        Dump gradient statistic data.
-        level0: [step, max, min, norm, shape_dim, shape]
-        level1: [step, max, min, norm, shape_dim, shape, dist_dim, dist]
-        level2: [step, max, min, norm, shape_dim, shape] + grad_bool_data
-        level3: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
-        '''
-        dump_path = os.path.join(dump_dir, g_name)
-        dump_dir_path = dump_path + "_dir"
-        save_op = ms.ops.TensorDump()
-        level = grad_context.get_context(GradConst.LEVEL)
-
-        if level == GradConst.LEVEL0 or level == GradConst.LEVEL2:
-            level_stat = GradAnalyzer.calculate_level0(dump_step, grad)
-        else:
-            level_stat = GradAnalyzer.calculate_level1(dump_step, grad)
-
-        save_op(dump_path, level_stat)
-        if level == GradConst.LEVEL2 or level == GradConst.LEVEL3:
-            grad_direction = GradAnalyzer.calculate_direction(grad)
-            save_op(dump_dir_path, grad_direction)
-
-    @staticmethod
-    def calculate_level0(dump_step: Parameter, grad: ms.Tensor):
-        is_bf16 = grad.dtype == ms.bfloat16
-        max_val = grad.max().float() if is_bf16 else grad.max()
-        min_val = grad.min().float() if is_bf16 else grad.min()
-        norm_val = grad.norm().float() if is_bf16 else grad.norm()
-        shape = grad.shape
-        extrem_stat = ms.ops.stack([dump_step[0].astype(max_val.dtype), max_val, min_val, norm_val])
-        shape_stat = ms.Tensor([len(shape)] + list(shape)).astype(max_val.dtype)
-        level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
-        return level0_stat
-
-    @staticmethod
-    def calculate_level1(dump_step: Parameter, grad: ms.Tensor):
-        level0_stat = GradAnalyzer.calculate_level0(dump_step, grad)
-        bounds = grad_context.get_context(GradConst.BOUNDS)
+@ms.jit
+def grad_dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor, level: str, bounds: List):
+    '''
+    Dump gradient statistic data.
+    level0: [step, max, min, norm, shape_dim, shape]
+    level1: [step, max, min, norm, shape_dim, shape, dist_dim, dist]
+    level2: [step, max, min, norm, shape_dim, shape] + grad_bool_data
+    level3: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
+    '''
+    dump_path = dump_dir + g_name
+    dump_dir_path = dump_path + "_dir"
+    save_op = ms.ops.TensorDump()
+
+    grad_flat = grad.reshape(-1)
+    max_val = grad_flat.max(axis=0).float()
+    min_val = grad_flat.min(axis=0).float()
+    norm_val = grad_flat.norm(ord=2).float()
+    shape = grad.shape
+    extrem_list = [dump_step[0].float(), max_val, min_val, norm_val]
+    extrem_stat = ms.ops.stack(extrem_list)
+    shape_list = [len(shape)] + list(shape)
+    shape_stat = ms.Tensor(shape_list).float()
+    level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
+    level_stat = level0_stat
+
+    if level == "L1" or level == "L3":
         zero_grad = (grad == 0).sum()
-        dist_dim = ms.Tensor([len(bounds) + 2]).astype(level0_stat.dtype)
-        bucket_result = ms.ops.bucketize(grad, bounds).astype(ms.int8)
+        dist_dim = ms.Tensor([len(bounds) + 2]).float()
+        bucket_result = ms.ops.bucketize(grad.float(), bounds)
+        bucket_result = bucket_result.astype(ms.int8)
         dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
         dist_stat.append(zero_grad)
-        dist_stat = ms.ops.stack(dist_stat, axis=0).astype(level0_stat.dtype)
+        dist_stat = ms.ops.stack(dist_stat, axis=0).float()
         element_num = dist_stat.sum() - dist_stat[-1]
         if element_num != 0:
             dist_stat = dist_stat / element_num
         level1_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0)
-        return level1_stat
+        level_stat = level1_stat
 
-    @staticmethod
-    def calculate_direction(grad: ms.Tensor):
-        return grad > 0
+    save_op(dump_path, level_stat)
+    if level == "L2" or level == "L3":
"L3": + grad_direction = grad > 0 + save_op(dump_dir_path, grad_direction) class CSVGenerator(Process): @@ -93,32 +80,35 @@ class CSVGenerator(Process): self.level = GradConst.LEVEL0 self.cache_list = ListCache() self.current_step = None - self.bounds = [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10], + self.bounds = [-0.1, 0.0, 0.1], def init(self, context: GlobalContext): rank_id = get_rank_id() output_path = context.get_context(GradConst.OUTPUT_PATH) self.level = context.get_context(GradConst.LEVEL) self.bounds = context.get_context(GradConst.BOUNDS) - step_range = context.get_context(GradConst.STEP) - self.step_end = 0 if step_range is None else step_range[1] self.dump_dir = f"{output_path}/rank_{rank_id}/Dump/" self.save_dir = f"{output_path}/rank_{rank_id}/" self.current_step = None - self.finish_flag = False + self.stop_event = multiprocessing.Event() + self.last_finish = False def run(self): - while not self.finish_flag: + while True: if not os.path.exists(self.dump_dir): time.sleep(0.1) continue npy_files = os.listdir(self.dump_dir) npy_files.sort(key=lambda x: int(x.split("_")[0])) - if not npy_files: - continue self.traverse_files(npy_files) + empty = len(os.listdir(self.dump_dir)) == 0 + if self.stop_event.is_set() and empty and self.last_finish: + break shutil.rmtree(self.dump_dir) + def stop(self): + self.stop_event.set() + def traverse_files(self, npy_files: List): for npy_file in npy_files: file_path = os.path.join(self.dump_dir, npy_file) @@ -128,8 +118,7 @@ class CSVGenerator(Process): if GradConst.STEP_FINISH in npy_file: self.cache_list.flush() os.remove(file_path) - if self.current_step == self.step_end: - self.finish_flag = True + self.last_finish = True elif file_path.split("_")[-1] == GradConst.DIR_SUFFIX: prefix_idx = len(npy_file.split("_")[0]) new_name = npy_file[prefix_idx + 1:].replace("_" + GradConst.DIR_SUFFIX, "." 
@@ -142,16 +131,19 @@ class CSVGenerator(Process):
                 create_directory(step_dir)
                 dst_file = os.path.join(step_dir, new_name)
                 shutil.move(file_path, dst_file)
+                self.last_finish = False
             elif file_path.split(".")[-1] == GradConst.NPY_SUFFIX:
                 stat_data = self.load_npy_data(file_path)
                 if stat_data is None:
                     continue
                 step = int(stat_data[GradConst.STEP_IDX])
-                if self.current_step is None or step != self.current_step:
-                    self.current_step = step
+                update_step = self.current_step is None or step != self.current_step
+                self.current_step = step
+                if update_step:
                     self.create_csv_file()
                 self.gen_csv_line(file_path, stat_data)
                 os.remove(file_path)
+                self.last_finish = False
 
     def load_npy_data(self, file_path: str):
         stat_data = None
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py
index a7ab6f453..f822fa923 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py
@@ -14,3 +14,6 @@ class MsGradientMonitor(BaseMonitor):
 
     def monitor(self, module):
         hook_optimizer(module)
+
+    def stop(self):
+        csv_generator.stop()
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/hook.py b/debug/accuracy_tools/grad_tool/grad_ms/hook.py
index 42aae0d82..d69a299ba 100644
--- a/debug/accuracy_tools/grad_tool/grad_ms/hook.py
+++ b/debug/accuracy_tools/grad_tool/grad_ms/hook.py
@@ -12,7 +12,7 @@ from mindspore.common.initializer import initializer
 from grad_tool.common.constant import GradConst
 from grad_tool.common.utils import print_warn_log
 from grad_tool.grad_ms.global_context import grad_context
-from grad_tool.grad_ms.grad_analyzer import GradAnalyzer, get_rank_id
+from grad_tool.grad_ms.grad_analyzer import grad_dump, get_rank_id
 from grad_tool.grad_ms.grad_analyzer import csv_generator
 
 
@@ -32,20 +32,20 @@ def hook_optimizer(opt: Optimizer):
     if os.path.exists(save_dir):
         print_warn_log(f"Delete existing path {save_dir}.")
         shutil.rmtree(save_dir)
+    level = grad_context.get_context(GradConst.LEVEL)
+    bounds = grad_context.get_context(GradConst.BOUNDS)
 
     @jit
     def new_construct(self, gradients):
-        if step_start <= self.dump_step[0] <= step_end:
-            for index, grad_value in enumerate(gradients):
-                if param_list and g_names[index] not in param_list:
-                    continue
-                GradAnalyzer.dump(dump_dir, g_names[index], self.dump_step, grad_value)
-            ms.ops.TensorDump()(step_finish_flag, self.dump_step)
+        for index, grad_value in enumerate(gradients):
+            if param_list and g_names[index] not in param_list:
+                continue
+            grad_dump(dump_dir, g_names[index], self.dump_step, grad_value, level, bounds)
+        ms.ops.TensorDump()(step_finish_flag, self.dump_step)
         self.assignadd(self.dump_step, self.global_step_increase_tensor)
         out = func(gradients)
         return out
 
-    if rank_list is None or rank_id in rank_list:
-        opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step")
-        opt.construct = new_construct.__get__(opt, type(opt))
-        csv_generator.start()
+    opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step")
+    opt.construct = new_construct.__get__(opt, type(opt))
+    csv_generator.start()
-- 
Gitee
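
For reviewers, the end-to-end flow that the patched README describes (configure the tool, hook the optimizer, train, then stop the collector) is sketched below. This is a minimal illustration under stated assumptions rather than part of the patch: the GradientMonitor constructor arguments, the config file path, and the toy network and optimizer are placeholders, while gm.monitor(optimizer) and gm.stop() come from the README changes above.

```python
# Minimal sketch of the MindSpore usage flow described in the patched README.
# Assumptions (not defined by this patch): the GradientMonitor constructor
# signature, the config file path, and the toy network/optimizer below.
import mindspore as ms
from mindspore import nn

from grad_tool.grad_monitor import GradientMonitor

net = nn.Dense(16, 4)
loss_fn = nn.MSELoss()
optimizer = nn.SGD(net.trainable_params(), learning_rate=1e-2)

gm = GradientMonitor("./grad_config.yaml", framework="MindSpore")  # assumed signature
gm.monitor(optimizer)  # hook the optimizer before training starts

def forward_fn(data, label):
    return loss_fn(net(data), label)

# Gradients with respect to the optimizer's parameters only
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

for _ in range(3):  # each optimizer call counts as one monitored "step"
    data = ms.ops.ones((8, 16), ms.float32)
    label = ms.ops.zeros((8, 4), ms.float32)
    _, grads = grad_fn(data, label)
    optimizer(grads)

gm.stop()  # required on MindSpore: tells the CSV generator to drain and exit
```

Per the grad_analyzer.py changes above, gm.stop() only sets a stop event; the CSVGenerator process keeps consuming dumped files and exits once the dump directory is empty and the last step-finish flag has been processed, so statistics from the final step are not lost.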