From 47638a3dcb0658ce651ca0289a4893a0d5281606 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 09:27:01 +0800 Subject: [PATCH 01/13] =?UTF-8?q?=E7=A8=8B=E5=BA=8F=E6=89=93=E5=8D=B0?= =?UTF-8?q?=E7=BB=9F=E4=B8=80=E8=87=B3print=5Flog=5Fwith=5Frank=E5=87=BD?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/kj600/kj600/module_hook.py | 69 ++++++++----------- debug/accuracy_tools/kj600/kj600/utils.py | 18 +++++ 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py index c26750f7872..6bbcee832f4 100644 --- a/debug/accuracy_tools/kj600/kj600/module_hook.py +++ b/debug/accuracy_tools/kj600/kj600/module_hook.py @@ -14,14 +14,14 @@ import torch.distributed as dist from torch.utils.hooks import BackwardHook from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook from kj600.module_spec_verifier import validate_config_spec -from kj600.optimizer_collect import OptimizerMon, print_rank_0, OptimizerMonFactory +from kj600.optimizer_collect import OptimizerMon, OptimizerMonFactory from kj600.features import eff_rank, get_sign_matches from kj600.visualizer import HeatmapVisualizer from kj600.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD from kj600.anomaly_analyse import AnomalyDataWriter from kj600.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, get_summary_writer_tag_name, TensorMetrics, squash_param_name, sqrt_norm_metric, reorder_metric from kj600.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group -from kj600.utils import print_warn_log, print_info_log, print_error_log, get_param_struct, validate_config, validate_ops +from kj600.utils import print_log_with_rank, get_param_struct, validate_config, validate_ops from kj600.const import Const from kj600.file_check import FileOpen @@ -136,29 +136,29 @@ class TrainerMon: self.all_xy = self.config.get('all_xy', False) self.xy_distribution = self.config.get('xy_distribution', False) if not self.xy_distribution: - print_rank_0("> module input/output input_grad/output_grad is not monitored. ") + print_log_with_rank("> module input/output input_grad/output_grad is not monitored. ", 0, 'INFO') # backward hook cause megatron-lm pipeline parallel schedule assert exception. # TBD: backward hook cause output tensor is view of some base tensor. root cause invesigation pending. self.forward_only = self.config.get('forward_only', False) if self.forward_only: - print_rank_0("> only module forward is monitored. ") + print_log_with_rank("> only module forward is monitored. ", 0, 'INFO') self.backward_only = self.config.get('backward_only', False) self.ur_distribution = self.config.get('ur_distribution', False) if not self.ur_distribution: - print_rank_0("> update vector and ratio vector of adam is not monitored. ") + print_log_with_rank("> update vector and ratio vector of adam is not monitored. ", 0, 'INFO') self.mv_distribution = self.config.get("mv_distribution", False) if not self.mv_distribution: - print_rank_0("> momentum and variance of adam is not monitored. ") + print_log_with_rank("> momentum and variance of adam is not monitored. 
", 0, 'INFO') self.wg_distribution = self.config.get("wg_distribution", False) if not self.wg_distribution: - print_rank_0("> weight grad of specified module is not monitored. ") + print_log_with_rank("> weight grad of specified module is not monitored. ", 0, 'INFO') self.mg_direction = self.config.get('mg_direction', False) if not self.mg_direction: - print_rank_0('> grad and momentum direction will not be compared.') + print_log_with_rank('> grad and momentum direction will not be compared.', 0, 'INFO') self.cc_distribution = self.config.get("cc_distribution", {}) if not self.cc_distribution.get('enable', False): - print_rank_0("> cc operator is not monitored.") + print_log_with_rank("> cc operator is not monitored.", 0, 'INFO') self.cc_log_only = False else: self.cc_codeline = self.cc_distribution.get('cc_codeline', []) @@ -300,8 +300,8 @@ class TrainerMon: vpp_stage = f'{vpp_stage}{Const.VPP_SEP}' targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config['targets'].keys() hooked_count += self._hook_module(targets, model_chunk, vpp_stage) - - print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") + + print_log_with_rank(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.", 0, 'INFO') def clone_if_tensor(args): if isinstance(args, tuple): @@ -357,7 +357,7 @@ class TrainerMon: continue grad = param.main_grad if self.params_have_main_grad else param.grad if grad is None: - print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, 'WARNING') continue key = get_summary_writer_tag_name(name, 'post_grad', self.rank) grad_dict[key] = grad @@ -370,7 +370,7 @@ class TrainerMon: return reduced, unreduced def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None): - print_info_log(f'grad acc steps {grad_acc_steps}') + print_log_with_rank(f'grad acc steps {grad_acc_steps}', self.rank, 'INFO') self.hook_optimizer(optimizer) self.micro_batch_number = grad_acc_steps @@ -480,14 +480,15 @@ class TrainerMon: context.param_adam_update = mv_result.update context.param_adam_ratio = mv_result.ratio + smallest_rank = min(self.module_rank_list) if self.module_rank_list else 0 if self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed: - self._smallest_rank_print("> module struct:") - self._smallest_rank_print(json.dumps(self.module_struct, indent=4)) + print_log_with_rank("> module struct:", smallest_rank, 'INFO') + print_log_with_rank(json.dumps(self.module_struct, indent=4), smallest_rank, 'INFO') if not self.cc_log_only: raise Exception("exit after first step when print model struct") if self.cc_log_only and context.step > 0: - self._smallest_rank_print("> Used communication ops and corresponding stack") - self._smallest_rank_print(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4)) + print_log_with_rank("> Used communication ops and corresponding stack", smallest_rank, 'INFO') + print_log_with_rank(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4), smallest_rank, 'INFO') raise Exception("exit after first step when print cc stack") self.generate_wgrad_metrics() @@ -498,7 +499,7 @@ class TrainerMon: for param, name in self.param2name.items(): grad = param.main_grad if self.params_have_main_grad else param.grad if grad is None: - 
print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, 'WARNING') continue if context.step == 0: same_direction_ratio = torch.tensor(1.) @@ -575,19 +576,6 @@ class TrainerMon: self.optimizer_hooked = True return - def _smallest_rank_print(self, msg): - if not self.verbose: - return - if dist.is_initialized(): - if self.module_rank_list: - if dist.get_rank() == min(self.module_rank_list): - print_info_log(msg) - else: - if dist.get_rank() == 0: - print_info_log(msg) - else: - print_info_log(msg) - def _is_target_param(self, param_name, param, prefix): squash_name = prefix + squash_param_name(param_name) name = prefix + param_name @@ -605,8 +593,8 @@ class TrainerMon: if self._is_target_param(param_name, param, prefix): name = prefix + squash_param_name(param_name) if name in self.param2name.values(): - print_error_log(f'same name {name} for different param. Current param is {param_name}. \ - May be error of squash_param_name') + print_log_with_rank(f'same name {name} for different param. Current param is {param_name}. \ + May be error of squash_param_name', self.rank, 'ERROR') raise Exception("param with same name will be overwritten.") self.param2name[param] = name self.name2param[name] = param @@ -621,7 +609,8 @@ class TrainerMon: if len(model) > 1: self.vpp = True - self._smallest_rank_print('vpp enabled') + smallest_rank = min(self.module_rank_list) if self.module_rank_list else 0 + print_log_with_rank('vpp enabled', smallest_rank, 'INFO') for vpp_stage, model_chunk in enumerate(model): prefix = f'{vpp_stage}{Const.VPP_SEP}' @@ -677,8 +666,9 @@ class TrainerMon: for metric_name in self.ops: if context.micro_step == 0 and context.actv.get(metric_name, []): - print_warn_log( - f"actv context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + print_log_with_rank( + f"actv context of {context.module_name} is not empty when first micro_step, maybe " + f"something wrong happened. Now clear it.", self.rank, 'WARNING') context.actv.clear() context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps)) @@ -716,7 +706,8 @@ class TrainerMon: tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTVGRAD_OUT, cared_output_grad)) if context.micro_step == 0 and context.actvgrad: - print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + print_log_with_rank(f"actvgrad context of {context.module_name} is not empty when first micro_step, " + f"maybe something wrong happened. 
Now clear it.", self.rank, 'WARNING') context.actvgrad.clear() for metric_name in self.ops: @@ -731,7 +722,7 @@ class TrainerMon: return if self.backward_only and self.forward_only: - print_warn_log('not enable backward_only and forward_only simultaneously') + print_log_with_rank('not enable backward_only and forward_only simultaneously', self.rank, 'ERROR') hooked_count = 0 if self.xy_distribution or self.print_struct: @@ -747,7 +738,7 @@ class TrainerMon: handle = submodule.register_full_backward_hook(bwd_hook_fun) self.handles['xy'].append(handle) self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) - print_rank_0(f"> {name} is monitored successfully") + print_log_with_rank(f"> {name} is monitored successfully", 0, 'INFO') hooked_count += 1 return hooked_count diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/kj600/kj600/utils.py index a5c0b44bb36..432bf3e321f 100644 --- a/debug/accuracy_tools/kj600/kj600/utils.py +++ b/debug/accuracy_tools/kj600/kj600/utils.py @@ -4,6 +4,7 @@ import sys import re import warnings import torch +import torch.distributed as dist from kj600.const import Const @@ -14,6 +15,7 @@ FILE_NAME_MAX_LENGTH = 255 DIRECTORY_MAX_LENGTH = 4096 FILE_NAME_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + def _print_log(level, msg, end='\n'): current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) pid = os.getgid() @@ -51,6 +53,22 @@ def print_warn_log(warn_msg): _print_log("WARNING", warn_msg) +print_method_map = { + "INFO": print_info_log, + "ERROR": print_error_log, + "WARNING": print_warn_log +} + + +def print_log_with_rank(msg: str, rank: int, level: str): + print_method = print_method_map.get(level, print_info_log) + if dist.is_initialized(): + if dist.get_rank() == rank: + print_method(f'[RANK{rank}]{msg}') + else: + print_method(msg) + + def get_param_struct(param): res = {} if isinstance(param, (tuple, list)): -- Gitee From ef0eb1a021a11de4d3c348a946da1a1bb5fd6fde Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 09:43:53 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E6=89=93=E5=8D=B0=E9=80=9A=E4=BF=A1=E8=B0=83=E7=94=A8=E6=A0=88?= =?UTF-8?q?=E7=9A=84bug=EF=BC=9A=E5=88=A0=E9=99=A4self.verbose?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/kj600/kj600/module_hook.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py index 6bbcee832f4..061bfaf3293 100644 --- a/debug/accuracy_tools/kj600/kj600/module_hook.py +++ b/debug/accuracy_tools/kj600/kj600/module_hook.py @@ -244,10 +244,7 @@ class TrainerMon: raise Exception("ur_distribution cannot be enabled with unknown optimizer.") if self.mv_distribution: raise Exception("mv_distribution cannot be enabled with unknown optimizer.") - self.verbose = False self.print_struct = self.config.get("print_struct", False) - if self.print_struct: - self.verbose = True self.struct_printed = False self.module_struct = {} -- Gitee From 6ff0fe89350721680d73653351835c8ba4cab0b1 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:10:29 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E5=A2=9E=E5=8A=A0rank=3D-1=E6=97=B6?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E6=89=93=E5=8D=B0=E5=BD=93=E5=89=8Drank?= =?UTF-8?q?=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 
8bit --- debug/accuracy_tools/kj600/kj600/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/kj600/kj600/utils.py index 432bf3e321f..d7f8a05e027 100644 --- a/debug/accuracy_tools/kj600/kj600/utils.py +++ b/debug/accuracy_tools/kj600/kj600/utils.py @@ -63,8 +63,8 @@ print_method_map = { def print_log_with_rank(msg: str, rank: int, level: str): print_method = print_method_map.get(level, print_info_log) if dist.is_initialized(): - if dist.get_rank() == rank: - print_method(f'[RANK{rank}]{msg}') + if dist.get_rank() == rank or rank == -1: + print_method(f'[RANK{dist.get_rank()}]{msg}') else: print_method(msg) -- Gitee From 3c29deadf849beda799364243d5cbe7ad4dc1e84 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:11:40 +0800 Subject: [PATCH 04/13] =?UTF-8?q?=E5=88=A0=E9=99=A4print=5Frank=5F0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/kj600/kj600/optimizer_collect.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py index 4e55d38c30c..04232cd8172 100644 --- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py +++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py @@ -3,15 +3,7 @@ from collections import defaultdict, namedtuple import torch import torch.distributed as dist -from kj600.utils import print_warn_log, print_error_log - - -def print_rank_0(message): - if dist.is_initialized(): - if dist.get_rank() == 0: - print(message) - else: - print(message) +from kj600.utils import print_warn_log MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) -- Gitee From c84c27f08fe5e5faea09b35b1783d67ea8fe611f Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:30:48 +0800 Subject: [PATCH 05/13] =?UTF-8?q?=E5=90=88=E5=B9=B6optimizer=5Fcollect.py?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/kj600/kj600/optimizer_collect.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py index 04232cd8172..29b48b447bc 100644 --- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py +++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py @@ -3,7 +3,7 @@ from collections import defaultdict, namedtuple import torch import torch.distributed as dist -from kj600.utils import print_warn_log +from kj600.utils import print_log_with_rank MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) @@ -40,7 +40,8 @@ class OptimizerMon(ABC): exp_avg = state_param.get("exp_avg", None) exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: - print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"exp_avg or exp_avg_sq of {name} is None, " + f"maybe something wrong happened.", -1, "WARNING") continue if monitor.mv_distribution: exp_avg_dict[name] = exp_avg @@ -53,7 +54,8 @@ class OptimizerMon(ABC): elif 'step' in torch_opt.param_groups[0]: step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: - print_warn_log(f"step of {name} is 
None, maybe something wrong happened.") + print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", + -1, "WARNING") continue exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) @@ -82,7 +84,8 @@ class OptimizerMon(ABC): exp_avg = state_param.get("exp_avg", None) exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: - print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.", + -1, "WARNING") continue exp_avg = exp_avg[start_idx: end_idx] exp_avg_sq = exp_avg_sq[start_idx: end_idx] @@ -97,7 +100,7 @@ class OptimizerMon(ABC): elif 'step' in torch_opt.param_groups[0]: step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: - print_warn_log(f"step of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", -1, "WARNING") continue exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) -- Gitee From 827201bf604550d495a14966343d0f3691e00ae1 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:33:18 +0800 Subject: [PATCH 06/13] =?UTF-8?q?=E5=90=88=E5=B9=B6module=5Fspec=5Fverifie?= =?UTF-8?q?r.py=E4=B8=AD=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/kj600/kj600/module_spec_verifier.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py b/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py index 69f15afdbc9..c84acc4006f 100644 --- a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py +++ b/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py @@ -1,9 +1,8 @@ -import json import re import abc import torch -from kj600.utils import print_warn_log +from kj600.utils import print_log_with_rank # 用于存储所有validator实现类的注册表 config_validator_registry = {} @@ -67,7 +66,8 @@ def validate_config_spec(config_spec:str, actual_data, module_name:str, data_typ try: focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) except ValueError as e: - print_warn_log(str(e)) + print_log_with_rank(str(e), -1, 'WARNING') return focused_col - print_warn_log(f"config spec in {module_name} {data_type} not supported, expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.") + print_log_with_rank(f"config spec in {module_name} {data_type} not supported, expected spec:'tuple\[(\d+)\]:(\d+)' " + f"or 'tensor', actual spec: {config_spec}.", -1, 'WARNING') return focused_col -- Gitee From 09d106c5b3c1f91a79be5a19abf7b2ea34af487a Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:36:21 +0800 Subject: [PATCH 07/13] =?UTF-8?q?=E5=90=88=E5=B9=B6module=5Fmetric.py?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/kj600/kj600/module_metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/module_metric.py b/debug/accuracy_tools/kj600/kj600/module_metric.py index 
179a9eea2a8..d5041aaf5e6 100644 --- a/debug/accuracy_tools/kj600/kj600/module_metric.py +++ b/debug/accuracy_tools/kj600/kj600/module_metric.py @@ -6,7 +6,7 @@ import torch from kj600.const import Const from kj600.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean -from kj600.utils import print_warn_log +from kj600.utils import print_log_with_rank def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank): @@ -77,7 +77,7 @@ class Metric(object): try: metrics_dict[tag] = self.get_metric_value(tensor, eps) if torch.isnan(metrics_dict[tag]): - print_warn_log(f'nan when calculate metric for {tag}') + print_log_with_rank(f'nan when calculate metric for {tag}', -1, 'WARNING') except RuntimeError as e: metrics_dict[tag] = torch.tensor(torch.nan) return metrics_dict -- Gitee From f51ec7e9368bd197a3399e5eba63a728a2b76288 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:40:58 +0800 Subject: [PATCH 08/13] =?UTF-8?q?=E5=90=88=E5=B9=B6file=5Fcheck.py?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/kj600/kj600/file_check.py | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/file_check.py b/debug/accuracy_tools/kj600/kj600/file_check.py index 80f456a6287..6aef1e48eba 100644 --- a/debug/accuracy_tools/kj600/kj600/file_check.py +++ b/debug/accuracy_tools/kj600/kj600/file_check.py @@ -17,7 +17,7 @@ import os import re -from kj600.utils import print_error_log +from kj600.utils import print_log_with_rank class CodedException(Exception): @@ -94,8 +94,8 @@ class FileChecker: @staticmethod def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - print_error_log( - f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}." 
+ print_log_with_rank( + f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.", -1, 'ERROR' ) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type @@ -167,7 +167,7 @@ class FileOpen: + self.SUPPORT_READ_WRITE_MODE ) if self.mode not in support_mode: - print_error_log(f"File open not support {self.mode} mode") + print_log_with_rank(f"File open not support {self.mode} mode", -1, 'ERROR') raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) @@ -194,45 +194,45 @@ class FileOpen: def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): - print_error_log(f"The file path {path} is a soft link.") + print_log_with_rank(f"The file path {path} is a soft link.", -1, 'ERROR') raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) def check_path_length(path): if path_len_exceeds_limit(path): - print_error_log("The file path length exceeds limit.") + print_log_with_rank("The file path length exceeds limit.", -1, 'ERROR') raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): - print_error_log(f"The file path {path} does not exist.") + print_log_with_rank(f"The file path {path} does not exist.", -1, 'ERROR') raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): - print_error_log(f"The file path {path} is not readable.") + print_log_with_rank(f"The file path {path} is not readable.", -1, 'ERROR') raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): - print_error_log(f"The file path {path} is not writable.") + print_log_with_rank(f"The file path {path} is not writable.", -1, 'ERROR') raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): - print_error_log(f"The file path {path} is not executable.") + print_log_with_rank(f"The file path {path} is not executable.", -1, 'ERROR') raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: - print_error_log( - f"The file path {path} may be insecure because other users have write permissions. " + print_log_with_rank( + f"The file path {path} may be insecure because other users have write permissions. ", -1, 'ERROR' ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) @@ -240,22 +240,22 @@ def check_other_user_writable(path): def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): - print_error_log( - f"The file path {path} may be insecure because is does not belong to you." 
+        print_log_with_rank(
+            f"The file path {path} may be insecure because it does not belong to you.", -1, 'ERROR'
         )
         raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR)
 
 
 def check_path_pattern_vaild(path):
     if not re.match(FileCheckConst.FILE_VALID_PATTERN, path):
-        print_error_log(f"The file path {path} contains special characters.")
+        print_log_with_rank(f"The file path {path} contains special characters.", -1, 'ERROR')
         raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR)
 
 
 def check_file_size(file_path, max_size):
     file_size = os.path.getsize(file_path)
     if file_size >= max_size:
-        print_error_log(f"The size of file path {file_path} exceeds {max_size} bytes.")
+        print_log_with_rank(f"The size of file path {file_path} exceeds {max_size} bytes.", -1, 'ERROR')
         raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR)
 
 
@@ -270,18 +270,18 @@ def check_common_file_size(file_path):
 def check_file_suffix(file_path, file_suffix):
     if file_suffix:
         if not file_path.endswith(file_suffix):
-            print_error_log(f"The {file_path} should be a {file_suffix} file!")
+            print_log_with_rank(f"The {file_path} should be a {file_suffix} file!", -1, 'ERROR')
             raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
 
 
 def check_path_type(file_path, file_type):
     if file_type == FileCheckConst.FILE:
         if not os.path.isfile(file_path):
-            print_error_log(f"The {file_path} should be a file!")
+            print_log_with_rank(f"The {file_path} should be a file!", -1, 'ERROR')
             raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
     if file_type == FileCheckConst.DIR:
         if not os.path.isdir(file_path):
-            print_error_log(f"The {file_path} should be a dictionary!")
+            print_log_with_rank(f"The {file_path} should be a directory!", -1, 'ERROR')
             raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
 
 
-- 
Gitee

From acee53f6eeccd12ac410e656704158c9c0ef10b3 Mon Sep 17 00:00:00 2001
From: heweidong7 <511650494@qq.com>
Date: Fri, 18 Oct 2024 10:41:51 +0800
Subject: [PATCH 09/13] =?UTF-8?q?=E5=90=88=E5=B9=B6features.py=E4=B8=AD?=
 =?UTF-8?q?=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 debug/accuracy_tools/kj600/kj600/features.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py
index 302ab3f8c55..77f46fb8e59 100644
--- a/debug/accuracy_tools/kj600/kj600/features.py
+++ b/debug/accuracy_tools/kj600/kj600/features.py
@@ -1,6 +1,6 @@
 import torch
 from torch.autograd.functional import jacobian
-from kj600.utils import print_info_log
+from kj600.utils import print_log_with_rank
 
 
 @torch.no_grad()
@@ -34,7 +34,7 @@ def get_sign_matches(x: torch.tensor, y:torch.tensor):
     try:
         same_direction_ratio = ((xs * ys).sum()/ys.numel() + 1)/2
     except RuntimeError as e:
-        print_info_log(f"RuntimeError: {e}")
+        print_log_with_rank(f"RuntimeError: {e}", -1, 'INFO')
         same_direction_ratio = torch.tensor(0.) 
return same_direction_ratio -- Gitee From 53b7f9d8a071984ba6388993fd6080f7cd618670 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:44:18 +0800 Subject: [PATCH 10/13] =?UTF-8?q?=E5=90=88=E5=B9=B6anomaly=5Fdetect.py?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/kj600/kj600/anomaly_detect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py index 36768ac973d..d13f8b61b4d 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py +++ b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py @@ -7,7 +7,7 @@ from collections import defaultdict from dataclasses import dataclass, field import pandas as pd from torch.utils.tensorboard import SummaryWriter -from kj600.utils import print_info_log, check_file_valid_writable, make_file_safety, create_directory +from kj600.utils import print_log_with_rank, check_file_valid_writable, make_file_safety, create_directory from kj600.const import Const from kj600.file_check import change_mode, FileCheckConst @@ -164,7 +164,7 @@ class BaseWriterWithAD: detected, rule_name = self._ad(scalar_value, history=avg) if detected: exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." - print_info_log(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}") + print_log_with_rank(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}", -1, 'INFO') if self.anomaly_inform: self.anomaly_inform.run(exception_message, self.job_id) -- Gitee From 17ac4a8180b9df60edd67009c408c06945939779 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:47:53 +0800 Subject: [PATCH 11/13] =?UTF-8?q?=E5=90=88=E5=B9=B6anomaly=5Fanalyse.py?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../kj600/kj600/anomaly_analyse.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py b/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py index 963ec0a0625..b44e5d36eba 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py +++ b/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py @@ -24,7 +24,7 @@ import os from pathlib import Path import sys -from kj600.utils import print_info_log, print_warn_log +from kj600.utils import print_log_with_rank from kj600.anomaly_detect import GradAnomalyData from kj600.file_check import ( change_mode, @@ -84,7 +84,7 @@ class AnomalyDataWriter: self.json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE ) file_check.common_check() - print_warn_log(f"The existing file will be deleted: {self.json_path}.") + print_log_with_rank(f"The existing file will be deleted: {self.json_path}.", -1, 'WARNING') os.remove(self.json_path) Path(self.json_path).touch() change_mode(self.json_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -96,7 +96,7 @@ class AnomalyDataWriter: anomalies: GradAnomalyData对象列表 """ anomalies_json = self.get_anomaly_dict(anomalies) - print_info_log(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.") + print_log_with_rank(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.", -1, 'INFO') if Path(self.json_path).exists() and os.path.getsize(self.json_path) > 0: with FileOpen(self.json_path, "r+") as f: 
fcntl.flock(f, fcntl.LOCK_EX) @@ -119,10 +119,10 @@ class AnomalyDataLoader: try: instances.append(GradAnomalyData(**values)) except KeyError as e: - print_warn_log(f"Missing key in anomaly data: {e}") + print_log_with_rank(f"Missing key in anomaly data: {e}", -1, 'WARNING') except ValueError as e: - print_warn_log( - f"Value error when creating a GradAnomalyData instance: {e}" + print_log_with_rank( + f"Value error when creating a GradAnomalyData instance: {e}", -1, 'WARNING' ) return instances @@ -178,14 +178,14 @@ class AnomalyAnalyse: file_check.common_check() sorted_data = AnomalyDataWriter.get_anomaly_dict(self.sorted_anomalies) - print_info_log(f"{ANALYSE_JSON} is at {output_path}.") + print_log_with_rank(f"{ANALYSE_JSON} is at {output_path}.", -1, 'INFO') json_path = os.path.join(output_path, ANALYSE_JSON) if os.path.exists(json_path): file_check = FileChecker( json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE ) file_check.common_check() - print_warn_log(f"The existing file will be deleted: {json_path}.") + print_log_with_rank(f"The existing file will be deleted: {json_path}.", -1, 'WARNING') os.remove(json_path) Path(json_path).touch() change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -238,11 +238,11 @@ def _anomaly_analyse(): args.out_path if args.out_path else args.data_path_dir ) - print_info_log(f"Top {top_k_number} anomalies are listed as follows:") + print_log_with_rank(f"Top {top_k_number} anomalies are listed as follows:", -1, 'INFO') for index, anomaly in enumerate(top_anomalies): - print_info_log(f"{index}: {anomaly.message}") + print_log_with_rank(f"{index}: {anomaly.message}", -1, 'INFO') if __name__ == "__main__": _anomaly_analyse() - print_info_log("Analyse task completed.") + print_log_with_rank("Analyse task completed.", -1, 'INFO') -- Gitee From ac6aa1f2f3ba46422216a9c13534c700522e2f82 Mon Sep 17 00:00:00 2001 From: heweidong7 <511650494@qq.com> Date: Fri, 18 Oct 2024 10:52:23 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E5=90=88=E5=B9=B6utils.py=E4=B8=AD?= =?UTF-8?q?=E7=9A=84=E6=89=93=E5=8D=B0=E6=96=B9=E6=B3=95=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/kj600/kj600/utils.py | 25 +++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/kj600/kj600/utils.py index d7f8a05e027..f8f7440cbc9 100644 --- a/debug/accuracy_tools/kj600/kj600/utils.py +++ b/debug/accuracy_tools/kj600/kj600/utils.py @@ -23,7 +23,7 @@ def _print_log(level, msg, end='\n'): sys.stdout.flush() -def print_info_log(info_msg, end='\n'): +def _print_info_log(info_msg, end='\n'): """ Function Description: print info log. @@ -33,7 +33,7 @@ def print_info_log(info_msg, end='\n'): _print_log("INFO", info_msg, end=end) -def print_error_log(error_msg): +def _print_error_log(error_msg): """ Function Description: print error log. @@ -43,7 +43,7 @@ def print_error_log(error_msg): _print_log("ERROR", error_msg) -def print_warn_log(warn_msg): +def _print_warn_log(warn_msg): """ Function Description: print warn log. 
@@ -54,14 +54,21 @@
 
 
 print_method_map = {
-    "INFO": print_info_log,
-    "ERROR": print_error_log,
-    "WARNING": print_warn_log
+    "INFO": _print_info_log,
+    "ERROR": _print_error_log,
+    "WARNING": _print_warn_log
 }
 
 
 def print_log_with_rank(msg: str, rank: int, level: str):
-    print_method = print_method_map.get(level, print_info_log)
+    """Logs a message with a specified rank and log level.
+
+    Args:
+        msg (str): The message to log.
+        rank (int): The rank that should log. If -1, every rank logs, each tagged with its own rank.
+        level (str): The log level ('INFO', 'ERROR', 'WARNING').
+    """
+    print_method = print_method_map.get(level, _print_info_log)
     if dist.is_initialized():
         if dist.get_rank() == rank or rank == -1:
             print_method(f'[RANK{dist.get_rank()}]{msg}')
@@ -80,7 +87,7 @@ def get_param_struct(param):
         res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}'
     else:
         res['config'] = f'{type(param)}'
-        print_warn_log(f'Not support type({type(param)}) now, please check the type of param {param}')
+        print_log_with_rank(f'Not support type({type(param)}) now, please check the type of param {param}', -1, 'WARNING')
     return res
 
 
@@ -175,7 +182,7 @@ def validate_ops(ops):
     valid_ops = []
     for op in ops:
         if op not in Const.OP_LIST:
-            print_warn_log(f"given op {op} is not supported. Optional ops: {Const.OP_LIST}")
+            print_log_with_rank(f"given op {op} is not supported. Optional ops: {Const.OP_LIST}", -1, 'WARNING')
         else:
             valid_ops.append(op)
     return valid_ops
-- 
Gitee

From 8a4d083fa299d950ac897beebc6d09382a21154c Mon Sep 17 00:00:00 2001
From: heweidong7 <511650494@qq.com>
Date: Fri, 18 Oct 2024 11:11:39 +0800
Subject: [PATCH 13/13] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=B8=BA=E5=B8=B8?=
 =?UTF-8?q?=E9=87=8F?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../kj600/kj600/anomaly_analyse.py            | 21 ++++++----
 .../kj600/kj600/anomaly_detect.py             |  3 +-
 debug/accuracy_tools/kj600/kj600/const.py     |  7 +++-
 debug/accuracy_tools/kj600/kj600/features.py  |  4 +-
 .../accuracy_tools/kj600/kj600/file_check.py  | 31 +++++++-------
 .../accuracy_tools/kj600/kj600/module_hook.py | 42 +++++++++----------
 .../kj600/kj600/module_metric.py              |  2 +-
 .../kj600/kj600/module_spec_verifier.py       |  5 ++-
 .../kj600/kj600/optimizer_collect.py          | 12 +++---
 debug/accuracy_tools/kj600/kj600/utils.py     | 14 ++++---
 10 files changed, 79 insertions(+), 62 deletions(-)

diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py b/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py
index b44e5d36eba..d29eb5e4cbf 100644
--- a/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py
+++ b/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py
@@ -24,6 +24,7 @@ import os
 from pathlib import Path
 import sys
 
+from kj600.const import Const
 from kj600.utils import print_log_with_rank
 from kj600.anomaly_detect import GradAnomalyData
 from kj600.file_check import (
@@ -84,7 +85,8 @@ class AnomalyDataWriter:
             self.json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE
         )
         file_check.common_check()
-        print_log_with_rank(f"The existing file will be deleted: {self.json_path}.", -1, 'WARNING')
+        print_log_with_rank(f"The existing file will be deleted: {self.json_path}.",
+                            Const.CURRENT_RANK, Const.WARNING)
         os.remove(self.json_path)
         Path(self.json_path).touch()
         change_mode(self.json_path, FileCheckConst.DATA_FILE_AUTHORITY)
@@ -96,7 +98,7 @@ class AnomalyDataWriter:
             anomalies: GradAnomalyData对象列表
         """
         anomalies_json = self.get_anomaly_dict(anomalies)
-        print_log_with_rank(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.", -1, 
'INFO') + print_log_with_rank(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.", Const.CURRENT_RANK, Const.INFO) if Path(self.json_path).exists() and os.path.getsize(self.json_path) > 0: with FileOpen(self.json_path, "r+") as f: fcntl.flock(f, fcntl.LOCK_EX) @@ -119,10 +121,10 @@ class AnomalyDataLoader: try: instances.append(GradAnomalyData(**values)) except KeyError as e: - print_log_with_rank(f"Missing key in anomaly data: {e}", -1, 'WARNING') + print_log_with_rank(f"Missing key in anomaly data: {e}", Const.CURRENT_RANK, Const.WARNING) except ValueError as e: print_log_with_rank( - f"Value error when creating a GradAnomalyData instance: {e}", -1, 'WARNING' + f"Value error when creating a GradAnomalyData instance: {e}", Const.CURRENT_RANK, Const.WARNING ) return instances @@ -178,14 +180,15 @@ class AnomalyAnalyse: file_check.common_check() sorted_data = AnomalyDataWriter.get_anomaly_dict(self.sorted_anomalies) - print_log_with_rank(f"{ANALYSE_JSON} is at {output_path}.", -1, 'INFO') + print_log_with_rank(f"{ANALYSE_JSON} is at {output_path}.", Const.CURRENT_RANK, Const.INFO) json_path = os.path.join(output_path, ANALYSE_JSON) if os.path.exists(json_path): file_check = FileChecker( json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE ) file_check.common_check() - print_log_with_rank(f"The existing file will be deleted: {json_path}.", -1, 'WARNING') + print_log_with_rank(f"The existing file will be deleted: {json_path}.", + Const.CURRENT_RANK, Const.WARNING) os.remove(json_path) Path(json_path).touch() change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -238,11 +241,11 @@ def _anomaly_analyse(): args.out_path if args.out_path else args.data_path_dir ) - print_log_with_rank(f"Top {top_k_number} anomalies are listed as follows:", -1, 'INFO') + print_log_with_rank(f"Top {top_k_number} anomalies are listed as follows:", Const.CURRENT_RANK, Const.INFO) for index, anomaly in enumerate(top_anomalies): - print_log_with_rank(f"{index}: {anomaly.message}", -1, 'INFO') + print_log_with_rank(f"{index}: {anomaly.message}", Const.CURRENT_RANK, Const.INFO) if __name__ == "__main__": _anomaly_analyse() - print_log_with_rank("Analyse task completed.", -1, 'INFO') + print_log_with_rank("Analyse task completed.", Const.CURRENT_RANK, Const.INFO) diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py index d13f8b61b4d..a14b50ad108 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py +++ b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py @@ -164,7 +164,8 @@ class BaseWriterWithAD: detected, rule_name = self._ad(scalar_value, history=avg) if detected: exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." 
- print_log_with_rank(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}", -1, 'INFO') + print_log_with_rank(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}", + Const.CURRENT_RANK, Const.INFO) if self.anomaly_inform: self.anomaly_inform.run(exception_message, self.job_id) diff --git a/debug/accuracy_tools/kj600/kj600/const.py b/debug/accuracy_tools/kj600/kj600/const.py index 095356631d4..9d9e68458f0 100644 --- a/debug/accuracy_tools/kj600/kj600/const.py +++ b/debug/accuracy_tools/kj600/kj600/const.py @@ -9,4 +9,9 @@ class Const: OP_LIST = ['min', 'max', 'norm', 'mean', 'id', 'zeros', 'nans'] DEEPSPEED_OPT_TY = ("DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3") - \ No newline at end of file + + # Used for print log + INFO = 'INFO' + WARNING = 'WARNING' + ERROR = 'ERROR' + CURRENT_RANK = -1 \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py index 77f46fb8e59..5b075f9850d 100644 --- a/debug/accuracy_tools/kj600/kj600/features.py +++ b/debug/accuracy_tools/kj600/kj600/features.py @@ -1,5 +1,7 @@ import torch from torch.autograd.functional import jacobian + +from kj600.const import Const from kj600.utils import print_log_with_rank @@ -34,7 +36,7 @@ def get_sign_matches(x: torch.tensor, y:torch.tensor): try: same_direction_ratio = ((xs * ys).sum()/ys.numel() + 1)/2 except RuntimeError as e: - print_log_with_rank(f"RuntimeError: {e}", -1, 'INFO') + print_log_with_rank(f"RuntimeError: {e}", Const.CURRENT_RANK, Const.INFO) same_direction_ratio = torch.tensor(0.) return same_direction_ratio diff --git a/debug/accuracy_tools/kj600/kj600/file_check.py b/debug/accuracy_tools/kj600/kj600/file_check.py index 6aef1e48eba..a12ed089bc7 100644 --- a/debug/accuracy_tools/kj600/kj600/file_check.py +++ b/debug/accuracy_tools/kj600/kj600/file_check.py @@ -17,6 +17,7 @@ import os import re +from kj600.const import Const from kj600.utils import print_log_with_rank @@ -95,7 +96,7 @@ class FileChecker: def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: print_log_with_rank( - f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.", -1, 'ERROR' + f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.", Const.CURRENT_RANK, Const.ERROR ) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type @@ -167,7 +168,7 @@ class FileOpen: + self.SUPPORT_READ_WRITE_MODE ) if self.mode not in support_mode: - print_log_with_rank(f"File open not support {self.mode} mode", -1, 'ERROR') + print_log_with_rank(f"File open not support {self.mode} mode", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) @@ -194,37 +195,37 @@ class FileOpen: def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): - print_log_with_rank(f"The file path {path} is a soft link.", -1, 'ERROR') + print_log_with_rank(f"The file path {path} is a soft link.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) def check_path_length(path): if path_len_exceeds_limit(path): - print_log_with_rank("The file path length exceeds limit.", -1, 'ERROR') + print_log_with_rank("The file path length exceeds limit.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): - 
print_log_with_rank(f"The file path {path} does not exist.", -1, 'ERROR') + print_log_with_rank(f"The file path {path} does not exist.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): - print_log_with_rank(f"The file path {path} is not readable.", -1, 'ERROR') + print_log_with_rank(f"The file path {path} is not readable.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): - print_log_with_rank(f"The file path {path} is not writable.", -1, 'ERROR') + print_log_with_rank(f"The file path {path} is not writable.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): - print_log_with_rank(f"The file path {path} is not executable.", -1, 'ERROR') + print_log_with_rank(f"The file path {path} is not executable.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) @@ -232,7 +233,7 @@ def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: print_log_with_rank( - f"The file path {path} may be insecure because other users have write permissions. ", -1, 'ERROR' + f"The file path {path} may be insecure because other users have write permissions. ", Const.CURRENT_RANK, Const.ERROR ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) @@ -241,21 +242,21 @@ def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): print_log_with_rank( - f"The file path {path} may be insecure because is does not belong to you.", -1, 'ERROR' + f"The file path {path} may be insecure because is does not belong to you.", Const.CURRENT_RANK, Const.ERROR ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - print_log_with_rank(f"The file path {path} contains special characters.", -1, 'ERROR') + print_log_with_rank(f"The file path {path} contains special characters.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): file_size = os.path.getsize(file_path) if file_size >= max_size: - print_log_with_rank(f"The size of file path {file_path} exceeds {max_size} bytes.", -1, 'ERROR') + print_log_with_rank(f"The size of file path {file_path} exceeds {max_size} bytes.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) @@ -270,18 +271,18 @@ def check_common_file_size(file_path): def check_file_suffix(file_path, file_suffix): if file_suffix: if not file_path.endswith(file_suffix): - print_log_with_rank(f"The {file_path} should be a {file_suffix} file!", -1, 'ERROR') + print_log_with_rank(f"The {file_path} should be a {file_suffix} file!", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def check_path_type(file_path, file_type): if file_type == FileCheckConst.FILE: if not os.path.isfile(file_path): - print_log_with_rank(f"The {file_path} should be a file!", -1, 'ERROR') + print_log_with_rank(f"The {file_path} should be a file!", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) if file_type == FileCheckConst.DIR: if not 
os.path.isdir(file_path):
-            print_log_with_rank(f"The {file_path} should be a directory!", -1, 'ERROR')
+            print_log_with_rank(f"The {file_path} should be a directory!", Const.CURRENT_RANK, Const.ERROR)
             raise FileCheckException(FileCheckException.INVALID_FILE_ERROR)
 
 
diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py
index 061bfaf3293..696140e0fea 100644
--- a/debug/accuracy_tools/kj600/kj600/module_hook.py
+++ b/debug/accuracy_tools/kj600/kj600/module_hook.py
@@ -136,29 +136,29 @@ class TrainerMon:
         self.all_xy = self.config.get('all_xy', False)
         self.xy_distribution = self.config.get('xy_distribution', False)
         if not self.xy_distribution:
-            print_log_with_rank("> module input/output input_grad/output_grad is not monitored. ", 0, 'INFO')
+            print_log_with_rank("> module input/output input_grad/output_grad is not monitored. ", 0, Const.INFO)
 
         # backward hook cause megatron-lm pipeline parallel schedule assert exception.
         # TBD: backward hook cause output tensor is view of some base tensor. root cause invesigation pending.
         self.forward_only = self.config.get('forward_only', False)
         if self.forward_only:
-            print_log_with_rank("> only module forward is monitored. ", 0, 'INFO')
+            print_log_with_rank("> only module forward is monitored. ", 0, Const.INFO)
         self.backward_only = self.config.get('backward_only', False)
 
         self.ur_distribution = self.config.get('ur_distribution', False)
         if not self.ur_distribution:
-            print_log_with_rank("> update vector and ratio vector of adam is not monitored. ", 0, 'INFO')
+            print_log_with_rank("> update vector and ratio vector of adam is not monitored. ", 0, Const.INFO)
         self.mv_distribution = self.config.get("mv_distribution", False)
         if not self.mv_distribution:
-            print_log_with_rank("> momentum and variance of adam is not monitored. ", 0, 'INFO')
+            print_log_with_rank("> momentum and variance of adam is not monitored. 
", 0, Const.INFO) self.mg_direction = self.config.get('mg_direction', False) if not self.mg_direction: - print_log_with_rank('> grad and momentum direction will not be compared.', 0, 'INFO') + print_log_with_rank('> grad and momentum direction will not be compared.', 0, Const.INFO) self.cc_distribution = self.config.get("cc_distribution", {}) if not self.cc_distribution.get('enable', False): - print_log_with_rank("> cc operator is not monitored.", 0, 'INFO') + print_log_with_rank("> cc operator is not monitored.", 0, Const.INFO) self.cc_log_only = False else: self.cc_codeline = self.cc_distribution.get('cc_codeline', []) @@ -298,7 +298,7 @@ class TrainerMon: targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config['targets'].keys() hooked_count += self._hook_module(targets, model_chunk, vpp_stage) - print_log_with_rank(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.", 0, 'INFO') + print_log_with_rank(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.", 0, Const.INFO) def clone_if_tensor(args): if isinstance(args, tuple): @@ -354,7 +354,7 @@ class TrainerMon: continue grad = param.main_grad if self.params_have_main_grad else param.grad if grad is None: - print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, 'WARNING') + print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, Const.WARNING) continue key = get_summary_writer_tag_name(name, 'post_grad', self.rank) grad_dict[key] = grad @@ -367,7 +367,7 @@ class TrainerMon: return reduced, unreduced def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None): - print_log_with_rank(f'grad acc steps {grad_acc_steps}', self.rank, 'INFO') + print_log_with_rank(f'grad acc steps {grad_acc_steps}', self.rank, Const.INFO) self.hook_optimizer(optimizer) self.micro_batch_number = grad_acc_steps @@ -479,13 +479,13 @@ class TrainerMon: smallest_rank = min(self.module_rank_list) if self.module_rank_list else 0 if self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed: - print_log_with_rank("> module struct:", smallest_rank, 'INFO') - print_log_with_rank(json.dumps(self.module_struct, indent=4), smallest_rank, 'INFO') + print_log_with_rank("> module struct:", smallest_rank, Const.INFO) + print_log_with_rank(json.dumps(self.module_struct, indent=4), smallest_rank, Const.INFO) if not self.cc_log_only: raise Exception("exit after first step when print model struct") if self.cc_log_only and context.step > 0: - print_log_with_rank("> Used communication ops and corresponding stack", smallest_rank, 'INFO') - print_log_with_rank(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4), smallest_rank, 'INFO') + print_log_with_rank("> Used communication ops and corresponding stack", smallest_rank, Const.INFO) + print_log_with_rank(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4), smallest_rank, Const.INFO) raise Exception("exit after first step when print cc stack") self.generate_wgrad_metrics() @@ -496,7 +496,7 @@ class TrainerMon: for param, name in self.param2name.items(): grad = param.main_grad if self.params_have_main_grad else param.grad if grad is None: - print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, 'WARNING') + print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, 
Const.WARNING) continue if context.step == 0: same_direction_ratio = torch.tensor(1.) @@ -591,7 +591,7 @@ class TrainerMon: name = prefix + squash_param_name(param_name) if name in self.param2name.values(): print_log_with_rank(f'same name {name} for different param. Current param is {param_name}. \ - May be error of squash_param_name', self.rank, 'ERROR') + May be error of squash_param_name', self.rank, Const.ERROR) raise Exception("param with same name will be overwritten.") self.param2name[param] = name self.name2param[name] = param @@ -607,7 +607,7 @@ class TrainerMon: if len(model) > 1: self.vpp = True smallest_rank = min(self.module_rank_list) if self.module_rank_list else 0 - print_log_with_rank('vpp enabled', smallest_rank, 'INFO') + print_log_with_rank('vpp enabled', smallest_rank, Const.INFO) for vpp_stage, model_chunk in enumerate(model): prefix = f'{vpp_stage}{Const.VPP_SEP}' @@ -665,7 +665,7 @@ class TrainerMon: if context.micro_step == 0 and context.actv.get(metric_name, []): print_log_with_rank( f"actv context of {context.module_name} is not empty when first micro_step, maybe " - f"something wrong happened. Now clear it.", self.rank, 'WARNING') + f"something wrong happened. Now clear it.", self.rank, Const.WARNING) context.actv.clear() context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps)) @@ -704,7 +704,7 @@ class TrainerMon: if context.micro_step == 0 and context.actvgrad: print_log_with_rank(f"actvgrad context of {context.module_name} is not empty when first micro_step, " - f"maybe something wrong happened. Now clear it.", self.rank, 'WARNING') + f"maybe something wrong happened. Now clear it.", self.rank, Const.WARNING) context.actvgrad.clear() for metric_name in self.ops: @@ -719,7 +719,7 @@ class TrainerMon: return if self.backward_only and self.forward_only: - print_log_with_rank('not enable backward_only and forward_only simultaneously', self.rank, 'ERROR') + print_log_with_rank('not enable backward_only and forward_only simultaneously', self.rank, Const.ERROR) hooked_count = 0 if self.xy_distribution or self.print_struct: @@ -735,7 +735,7 @@ class TrainerMon: handle = submodule.register_full_backward_hook(bwd_hook_fun) self.handles['xy'].append(handle) self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) - print_log_with_rank(f"> {name} is monitored successfully", 0, 'INFO') + print_log_with_rank(f"> {name} is monitored successfully", 0, Const.INFO) hooked_count += 1 return hooked_count diff --git a/debug/accuracy_tools/kj600/kj600/module_metric.py b/debug/accuracy_tools/kj600/kj600/module_metric.py index d5041aaf5e6..b825750bd42 100644 --- a/debug/accuracy_tools/kj600/kj600/module_metric.py +++ b/debug/accuracy_tools/kj600/kj600/module_metric.py @@ -77,7 +77,7 @@ class Metric(object): try: metrics_dict[tag] = self.get_metric_value(tensor, eps) if torch.isnan(metrics_dict[tag]): - print_log_with_rank(f'nan when calculate metric for {tag}', -1, 'WARNING') + print_log_with_rank(f'nan when calculate metric for {tag}', Const.CURRENT_RANK, Const.WARNING) except RuntimeError as e: metrics_dict[tag] = torch.tensor(torch.nan) return metrics_dict diff --git a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py b/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py index c84acc4006f..fe9c2c57d2a 100644 --- a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py +++ b/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py @@ -2,6 +2,7 @@ import re import abc import torch +from kj600.const import Const from 
kj600.utils import print_log_with_rank # 用于存储所有validator实现类的注册表 @@ -66,8 +67,8 @@ def validate_config_spec(config_spec:str, actual_data, module_name:str, data_typ try: focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) except ValueError as e: - print_log_with_rank(str(e), -1, 'WARNING') + print_log_with_rank(str(e), Const.CURRENT_RANK, Const.WARNING) return focused_col print_log_with_rank(f"config spec in {module_name} {data_type} not supported, expected spec:'tuple\[(\d+)\]:(\d+)' " - f"or 'tensor', actual spec: {config_spec}.", -1, 'WARNING') + f"or 'tensor', actual spec: {config_spec}.", Const.CURRENT_RANK, Const.WARNING) return focused_col diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py index 29b48b447bc..95f6f5e652b 100644 --- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py +++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py @@ -3,6 +3,7 @@ from collections import defaultdict, namedtuple import torch import torch.distributed as dist +from kj600.const import Const from kj600.utils import print_log_with_rank @@ -41,7 +42,7 @@ class OptimizerMon(ABC): exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: print_log_with_rank(f"exp_avg or exp_avg_sq of {name} is None, " - f"maybe something wrong happened.", -1, "WARNING") + f"maybe something wrong happened.", Const.CURRENT_RANK, Const.WARNING) continue if monitor.mv_distribution: exp_avg_dict[name] = exp_avg @@ -54,8 +55,8 @@ class OptimizerMon(ABC): elif 'step' in torch_opt.param_groups[0]: step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: - print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", - -1, "WARNING") + print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", + Const.CURRENT_RANK, Const.WARNING) continue exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) @@ -85,7 +86,7 @@ class OptimizerMon(ABC): exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: print_log_with_rank(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.", - -1, "WARNING") + Const.CURRENT_RANK, Const.WARNING) continue exp_avg = exp_avg[start_idx: end_idx] exp_avg_sq = exp_avg_sq[start_idx: end_idx] @@ -100,7 +101,8 @@ class OptimizerMon(ABC): elif 'step' in torch_opt.param_groups[0]: step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: - print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", -1, "WARNING") + print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", + Const.CURRENT_RANK, Const.WARNING) continue exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/kj600/kj600/utils.py index f8f7440cbc9..1c4994e717c 100644 --- a/debug/accuracy_tools/kj600/kj600/utils.py +++ b/debug/accuracy_tools/kj600/kj600/utils.py @@ -54,9 +54,9 @@ def _print_warn_log(warn_msg): print_method_map = { - "INFO": _print_info_log, - "ERROR": _print_error_log, - "WARNING": _print_warn_log + Const.INFO: _print_info_log, + Const.WARNING: _print_warn_log, + Const.ERROR: _print_error_log } @@ -70,7 +70,7 @@ def print_log_with_rank(msg: str, rank: int, level: str): """ 
print_method = print_method_map.get(level, _print_info_log) if dist.is_initialized(): - if dist.get_rank() == rank or rank == -1: + if rank == dist.get_rank() or rank == Const.CURRENT_RANK: print_method(f'[RANK{dist.get_rank()}]{msg}') else: print_method(msg) @@ -87,7 +87,8 @@ def get_param_struct(param): res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}' else: res['config'] = f'{type(param)}' - print_log_with_rank(f'Not support type({type(param)}) now, please check the type of param {param}', -1, 'WARNING') + print_log_with_rank(f'Not support type({type(param)}) now, please check the type of param {param}', + Const.CURRENT_RANK, Const.WARNING) return res @@ -182,7 +183,8 @@ def validate_ops(ops): valid_ops = [] for op in ops: if op not in Const.OP_LIST: - print_log_with_rank(f"given op {op} is not supported. Optional ops: {Const.OP_LIST}", -1, 'WARNING') + print_log_with_rank(f"given op {op} is not supported. Optional ops: {Const.OP_LIST}", + Const.CURRENT_RANK, Const.WARNING) else: valid_ops.append(op) return valid_ops -- Gitee
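
Taken together, the series funnels all console output through the single print_log_with_rank(msg, rank, level) entry point. Below is a minimal usage sketch of the resulting API; it assumes the kj600 package from this repository is importable and that a launcher such as torchrun has set the environment torch.distributed needs (the messages and rank values are illustrative, not taken from the patches):

    import torch.distributed as dist

    from kj600.const import Const
    from kj600.utils import print_log_with_rank

    # Before the process group exists, the rank filter is skipped and the
    # message is printed without a [RANK<n>] prefix.
    print_log_with_rank("loading monitor config", 0, Const.INFO)

    dist.init_process_group(backend="gloo")  # normally done by the launcher

    # A non-negative rank prints only on that rank, prefixed with its id,
    # e.g. "[RANK0]> 4 out of 4 are monitored."
    print_log_with_rank("> 4 out of 4 are monitored.", 0, Const.INFO)

    # Const.CURRENT_RANK (-1) makes every rank print, each tagged with its
    # own rank; the series uses this for per-rank warnings and errors.
    print_log_with_rank("grad is None, maybe something wrong happened.",
                        Const.CURRENT_RANK, Const.WARNING)

    # An unrecognized level falls back to INFO via print_method_map.get().
    print_log_with_rank("unknown level falls back to INFO", 0, "DEBUG")

Because the rank filter lives in one place, call sites only decide which rank should speak: rank 0 for one-time configuration notes, the smallest monitored rank for structure dumps, and Const.CURRENT_RANK for anything diagnosed per process.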