diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py b/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py index 963ec0a06256ceb212a3aabc08edfbf69900480e..d29eb5e4cbf807d9700acb41d6bea069c5575341 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py +++ b/debug/accuracy_tools/kj600/kj600/anomaly_analyse.py @@ -24,7 +24,8 @@ import os from pathlib import Path import sys -from kj600.utils import print_info_log, print_warn_log +from kj600.const import Const +from kj600.utils import print_log_with_rank from kj600.anomaly_detect import GradAnomalyData from kj600.file_check import ( change_mode, @@ -84,7 +85,8 @@ class AnomalyDataWriter: self.json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE ) file_check.common_check() - print_warn_log(f"The existing file will be deleted: {self.json_path}.") + print_log_with_rank(f"The existing file will be deleted: {self.json_path}.", + Const.CURRENT_RANK, Const.WARNING) os.remove(self.json_path) Path(self.json_path).touch() change_mode(self.json_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -96,7 +98,7 @@ class AnomalyDataWriter: anomalies: GradAnomalyData对象列表 """ anomalies_json = self.get_anomaly_dict(anomalies) - print_info_log(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.") + print_log_with_rank(f"{ANOMALY_JSON} is at {self.dump_rank_dir}.", Const.CURRENT_RANK, Const.INFO) if Path(self.json_path).exists() and os.path.getsize(self.json_path) > 0: with FileOpen(self.json_path, "r+") as f: fcntl.flock(f, fcntl.LOCK_EX) @@ -119,10 +121,10 @@ class AnomalyDataLoader: try: instances.append(GradAnomalyData(**values)) except KeyError as e: - print_warn_log(f"Missing key in anomaly data: {e}") + print_log_with_rank(f"Missing key in anomaly data: {e}", Const.CURRENT_RANK, Const.WARNING) except ValueError as e: - print_warn_log( - f"Value error when creating a GradAnomalyData instance: {e}" + print_log_with_rank( + f"Value error when creating a GradAnomalyData instance: {e}", Const.CURRENT_RANK, Const.WARNING ) return instances @@ -178,14 +180,15 @@ class AnomalyAnalyse: file_check.common_check() sorted_data = AnomalyDataWriter.get_anomaly_dict(self.sorted_anomalies) - print_info_log(f"{ANALYSE_JSON} is at {output_path}.") + print_log_with_rank(f"{ANALYSE_JSON} is at {output_path}.", Const.CURRENT_RANK, Const.INFO) json_path = os.path.join(output_path, ANALYSE_JSON) if os.path.exists(json_path): file_check = FileChecker( json_path, FileCheckConst.FILE, FileCheckConst.WRITE_ABLE ) file_check.common_check() - print_warn_log(f"The existing file will be deleted: {json_path}.") + print_log_with_rank(f"The existing file will be deleted: {json_path}.", + Const.CURRENT_RANK, Const.WARNING) os.remove(json_path) Path(json_path).touch() change_mode(json_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -238,11 +241,11 @@ def _anomaly_analyse(): args.out_path if args.out_path else args.data_path_dir ) - print_info_log(f"Top {top_k_number} anomalies are listed as follows:") + print_log_with_rank(f"Top {top_k_number} anomalies are listed as follows:", Const.CURRENT_RANK, Const.INFO) for index, anomaly in enumerate(top_anomalies): - print_info_log(f"{index}: {anomaly.message}") + print_log_with_rank(f"{index}: {anomaly.message}", Const.CURRENT_RANK, Const.INFO) if __name__ == "__main__": _anomaly_analyse() - print_info_log("Analyse task completed.") + print_log_with_rank("Analyse task completed.", Const.CURRENT_RANK, Const.INFO) diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py index 36768ac973d67d6d87765939637f167147e54f5c..a14b50ad1084c7622f167b616f370f5b0ad0eaa4 100644 --- a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py +++ b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py @@ -7,7 +7,7 @@ from collections import defaultdict from dataclasses import dataclass, field import pandas as pd from torch.utils.tensorboard import SummaryWriter -from kj600.utils import print_info_log, check_file_valid_writable, make_file_safety, create_directory +from kj600.utils import print_log_with_rank, check_file_valid_writable, make_file_safety, create_directory from kj600.const import Const from kj600.file_check import change_mode, FileCheckConst @@ -164,7 +164,8 @@ class BaseWriterWithAD: detected, rule_name = self._ad(scalar_value, history=avg) if detected: exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." - print_info_log(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}") + print_log_with_rank(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}", + Const.CURRENT_RANK, Const.INFO) if self.anomaly_inform: self.anomaly_inform.run(exception_message, self.job_id) diff --git a/debug/accuracy_tools/kj600/kj600/const.py b/debug/accuracy_tools/kj600/kj600/const.py index 095356631d43e163c500178197250bcc5bbf3f7a..9d9e68458f06a29f4080a318c24d6c70d72ad291 100644 --- a/debug/accuracy_tools/kj600/kj600/const.py +++ b/debug/accuracy_tools/kj600/kj600/const.py @@ -9,4 +9,9 @@ class Const: OP_LIST = ['min', 'max', 'norm', 'mean', 'id', 'zeros', 'nans'] DEEPSPEED_OPT_TY = ("DeepSpeedZeroOptimizer_Stage1_or_2", "DeepSpeedZeroOptimizer_Stage3") - \ No newline at end of file + + # Used for print log + INFO = 'INFO' + WARNING = 'WARNING' + ERROR = 'ERROR' + CURRENT_RANK = -1 \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py index 302ab3f8c550260ca1cca0fa1b67e965e3c90160..5b075f9850d8f26a1c589719dd0a8a1472c9417e 100644 --- a/debug/accuracy_tools/kj600/kj600/features.py +++ b/debug/accuracy_tools/kj600/kj600/features.py @@ -1,6 +1,8 @@ import torch from torch.autograd.functional import jacobian -from kj600.utils import print_info_log + +from kj600.const import Const +from kj600.utils import print_log_with_rank @torch.no_grad() @@ -34,7 +36,7 @@ def get_sign_matches(x: torch.tensor, y:torch.tensor): try: same_direction_ratio = ((xs * ys).sum()/ys.numel() + 1)/2 except RuntimeError as e: - print_info_log(f"RuntimeError: {e}") + print_log_with_rank(f"RuntimeError: {e}", Const.CURRENT_RANK, Const.INFO) same_direction_ratio = torch.tensor(0.) return same_direction_ratio diff --git a/debug/accuracy_tools/kj600/kj600/file_check.py b/debug/accuracy_tools/kj600/kj600/file_check.py index 80f456a6287c015d1894fc75eefb32119dff4428..a12ed089bc7f74512b1c0428d4de91294f0c76fc 100644 --- a/debug/accuracy_tools/kj600/kj600/file_check.py +++ b/debug/accuracy_tools/kj600/kj600/file_check.py @@ -17,7 +17,8 @@ import os import re -from kj600.utils import print_error_log +from kj600.const import Const +from kj600.utils import print_log_with_rank class CodedException(Exception): @@ -94,8 +95,8 @@ class FileChecker: @staticmethod def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - print_error_log( - f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}." + print_log_with_rank( + f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.", Const.CURRENT_RANK, Const.ERROR ) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type @@ -167,7 +168,7 @@ class FileOpen: + self.SUPPORT_READ_WRITE_MODE ) if self.mode not in support_mode: - print_error_log(f"File open not support {self.mode} mode") + print_log_with_rank(f"File open not support {self.mode} mode", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) @@ -194,45 +195,45 @@ class FileOpen: def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): - print_error_log(f"The file path {path} is a soft link.") + print_log_with_rank(f"The file path {path} is a soft link.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) def check_path_length(path): if path_len_exceeds_limit(path): - print_error_log("The file path length exceeds limit.") + print_log_with_rank("The file path length exceeds limit.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): - print_error_log(f"The file path {path} does not exist.") + print_log_with_rank(f"The file path {path} does not exist.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): - print_error_log(f"The file path {path} is not readable.") + print_log_with_rank(f"The file path {path} is not readable.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): - print_error_log(f"The file path {path} is not writable.") + print_log_with_rank(f"The file path {path} is not writable.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): - print_error_log(f"The file path {path} is not executable.") + print_log_with_rank(f"The file path {path} is not executable.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: - print_error_log( - f"The file path {path} may be insecure because other users have write permissions. " + print_log_with_rank( + f"The file path {path} may be insecure because other users have write permissions. ", Const.CURRENT_RANK, Const.ERROR ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) @@ -240,22 +241,22 @@ def check_other_user_writable(path): def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): - print_error_log( - f"The file path {path} may be insecure because is does not belong to you." + print_log_with_rank( + f"The file path {path} may be insecure because is does not belong to you.", Const.CURRENT_RANK, Const.ERROR ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - print_error_log(f"The file path {path} contains special characters.") + print_log_with_rank(f"The file path {path} contains special characters.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): file_size = os.path.getsize(file_path) if file_size >= max_size: - print_error_log(f"The size of file path {file_path} exceeds {max_size} bytes.") + print_log_with_rank(f"The size of file path {file_path} exceeds {max_size} bytes.", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) @@ -270,18 +271,18 @@ def check_common_file_size(file_path): def check_file_suffix(file_path, file_suffix): if file_suffix: if not file_path.endswith(file_suffix): - print_error_log(f"The {file_path} should be a {file_suffix} file!") + print_log_with_rank(f"The {file_path} should be a {file_suffix} file!", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def check_path_type(file_path, file_type): if file_type == FileCheckConst.FILE: if not os.path.isfile(file_path): - print_error_log(f"The {file_path} should be a file!") + print_log_with_rank(f"The {file_path} should be a file!", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) if file_type == FileCheckConst.DIR: if not os.path.isdir(file_path): - print_error_log(f"The {file_path} should be a dictionary!") + print_log_with_rank(f"The {file_path} should be a dictionary!", Const.CURRENT_RANK, Const.ERROR) raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py index c26750f78725abe3d7e4e73f46f191656f646564..696140e0fea05eb50a559b998a002cb2e2fb1871 100644 --- a/debug/accuracy_tools/kj600/kj600/module_hook.py +++ b/debug/accuracy_tools/kj600/kj600/module_hook.py @@ -14,14 +14,14 @@ import torch.distributed as dist from torch.utils.hooks import BackwardHook from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook from kj600.module_spec_verifier import validate_config_spec -from kj600.optimizer_collect import OptimizerMon, print_rank_0, OptimizerMonFactory +from kj600.optimizer_collect import OptimizerMon, OptimizerMonFactory from kj600.features import eff_rank, get_sign_matches from kj600.visualizer import HeatmapVisualizer from kj600.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD from kj600.anomaly_analyse import AnomalyDataWriter from kj600.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, get_summary_writer_tag_name, TensorMetrics, squash_param_name, sqrt_norm_metric, reorder_metric from kj600.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group -from kj600.utils import print_warn_log, print_info_log, print_error_log, get_param_struct, validate_config, validate_ops +from kj600.utils import print_log_with_rank, get_param_struct, validate_config, validate_ops from kj600.const import Const from kj600.file_check import FileOpen @@ -136,29 +136,29 @@ class TrainerMon: self.all_xy = self.config.get('all_xy', False) self.xy_distribution = self.config.get('xy_distribution', False) if not self.xy_distribution: - print_rank_0("> module input/output input_grad/output_grad is not monitored. ") + print_log_with_rank("> module input/output input_grad/output_grad is not monitored. ", 0, Const.INFO) # backward hook cause megatron-lm pipeline parallel schedule assert exception. # TBD: backward hook cause output tensor is view of some base tensor. root cause invesigation pending. self.forward_only = self.config.get('forward_only', False) if self.forward_only: - print_rank_0("> only module forward is monitored. ") + print_log_with_rank("> only module forward is monitored. ", 0, Const.INFO) self.backward_only = self.config.get('backward_only', False) self.ur_distribution = self.config.get('ur_distribution', False) if not self.ur_distribution: - print_rank_0("> update vector and ratio vector of adam is not monitored. ") + print_log_with_rank("> update vector and ratio vector of adam is not monitored. ", 0, Const.INFO) self.mv_distribution = self.config.get("mv_distribution", False) if not self.mv_distribution: - print_rank_0("> momentum and variance of adam is not monitored. ") + print_log_with_rank("> momentum and variance of adam is not monitored. ", 0, Const.INFO) self.wg_distribution = self.config.get("wg_distribution", False) if not self.wg_distribution: - print_rank_0("> weight grad of specified module is not monitored. ") + print_log_with_rank("> weight grad of specified module is not monitored. ", 0, Const.INFO) self.mg_direction = self.config.get('mg_direction', False) if not self.mg_direction: - print_rank_0('> grad and momentum direction will not be compared.') + print_log_with_rank('> grad and momentum direction will not be compared.', 0, Const.INFO) self.cc_distribution = self.config.get("cc_distribution", {}) if not self.cc_distribution.get('enable', False): - print_rank_0("> cc operator is not monitored.") + print_log_with_rank("> cc operator is not monitored.", 0, Const.INFO) self.cc_log_only = False else: self.cc_codeline = self.cc_distribution.get('cc_codeline', []) @@ -244,10 +244,7 @@ class TrainerMon: raise Exception("ur_distribution cannot be enabled with unknown optimizer.") if self.mv_distribution: raise Exception("mv_distribution cannot be enabled with unknown optimizer.") - self.verbose = False self.print_struct = self.config.get("print_struct", False) - if self.print_struct: - self.verbose = True self.struct_printed = False self.module_struct = {} @@ -300,8 +297,8 @@ class TrainerMon: vpp_stage = f'{vpp_stage}{Const.VPP_SEP}' targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config['targets'].keys() hooked_count += self._hook_module(targets, model_chunk, vpp_stage) - - print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") + + print_log_with_rank(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.", 0, Const.INFO) def clone_if_tensor(args): if isinstance(args, tuple): @@ -357,7 +354,7 @@ class TrainerMon: continue grad = param.main_grad if self.params_have_main_grad else param.grad if grad is None: - print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, Const.WARNING) continue key = get_summary_writer_tag_name(name, 'post_grad', self.rank) grad_dict[key] = grad @@ -370,7 +367,7 @@ class TrainerMon: return reduced, unreduced def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None): - print_info_log(f'grad acc steps {grad_acc_steps}') + print_log_with_rank(f'grad acc steps {grad_acc_steps}', self.rank, Const.INFO) self.hook_optimizer(optimizer) self.micro_batch_number = grad_acc_steps @@ -480,14 +477,15 @@ class TrainerMon: context.param_adam_update = mv_result.update context.param_adam_ratio = mv_result.ratio + smallest_rank = min(self.module_rank_list) if self.module_rank_list else 0 if self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed: - self._smallest_rank_print("> module struct:") - self._smallest_rank_print(json.dumps(self.module_struct, indent=4)) + print_log_with_rank("> module struct:", smallest_rank, Const.INFO) + print_log_with_rank(json.dumps(self.module_struct, indent=4), smallest_rank, Const.INFO) if not self.cc_log_only: raise Exception("exit after first step when print model struct") if self.cc_log_only and context.step > 0: - self._smallest_rank_print("> Used communication ops and corresponding stack") - self._smallest_rank_print(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4)) + print_log_with_rank("> Used communication ops and corresponding stack", smallest_rank, Const.INFO) + print_log_with_rank(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4), smallest_rank, Const.INFO) raise Exception("exit after first step when print cc stack") self.generate_wgrad_metrics() @@ -498,7 +496,7 @@ class TrainerMon: for param, name in self.param2name.items(): grad = param.main_grad if self.params_have_main_grad else param.grad if grad is None: - print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + print_log_with_rank(f"grad is None: {name}, maybe something wrong happened.", self.rank, Const.WARNING) continue if context.step == 0: same_direction_ratio = torch.tensor(1.) @@ -575,19 +573,6 @@ class TrainerMon: self.optimizer_hooked = True return - def _smallest_rank_print(self, msg): - if not self.verbose: - return - if dist.is_initialized(): - if self.module_rank_list: - if dist.get_rank() == min(self.module_rank_list): - print_info_log(msg) - else: - if dist.get_rank() == 0: - print_info_log(msg) - else: - print_info_log(msg) - def _is_target_param(self, param_name, param, prefix): squash_name = prefix + squash_param_name(param_name) name = prefix + param_name @@ -605,8 +590,8 @@ class TrainerMon: if self._is_target_param(param_name, param, prefix): name = prefix + squash_param_name(param_name) if name in self.param2name.values(): - print_error_log(f'same name {name} for different param. Current param is {param_name}. \ - May be error of squash_param_name') + print_log_with_rank(f'same name {name} for different param. Current param is {param_name}. \ + May be error of squash_param_name', self.rank, Const.ERROR) raise Exception("param with same name will be overwritten.") self.param2name[param] = name self.name2param[name] = param @@ -621,7 +606,8 @@ class TrainerMon: if len(model) > 1: self.vpp = True - self._smallest_rank_print('vpp enabled') + smallest_rank = min(self.module_rank_list) if self.module_rank_list else 0 + print_log_with_rank('vpp enabled', smallest_rank, Const.INFO) for vpp_stage, model_chunk in enumerate(model): prefix = f'{vpp_stage}{Const.VPP_SEP}' @@ -677,8 +663,9 @@ class TrainerMon: for metric_name in self.ops: if context.micro_step == 0 and context.actv.get(metric_name, []): - print_warn_log( - f"actv context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + print_log_with_rank( + f"actv context of {context.module_name} is not empty when first micro_step, maybe " + f"something wrong happened. Now clear it.", self.rank, Const.WARNING) context.actv.clear() context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps)) @@ -716,7 +703,8 @@ class TrainerMon: tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTVGRAD_OUT, cared_output_grad)) if context.micro_step == 0 and context.actvgrad: - print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + print_log_with_rank(f"actvgrad context of {context.module_name} is not empty when first micro_step, " + f"maybe something wrong happened. Now clear it.", self.rank, Const.WARNING) context.actvgrad.clear() for metric_name in self.ops: @@ -731,7 +719,7 @@ class TrainerMon: return if self.backward_only and self.forward_only: - print_warn_log('not enable backward_only and forward_only simultaneously') + print_log_with_rank('not enable backward_only and forward_only simultaneously', self.rank, Const.ERROR) hooked_count = 0 if self.xy_distribution or self.print_struct: @@ -747,7 +735,7 @@ class TrainerMon: handle = submodule.register_full_backward_hook(bwd_hook_fun) self.handles['xy'].append(handle) self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) - print_rank_0(f"> {name} is monitored successfully") + print_log_with_rank(f"> {name} is monitored successfully", 0, Const.INFO) hooked_count += 1 return hooked_count diff --git a/debug/accuracy_tools/kj600/kj600/module_metric.py b/debug/accuracy_tools/kj600/kj600/module_metric.py index 179a9eea2a84879d592967b561c54fcb4cd67316..b825750bd420585c12cc9438596a3112992287a1 100644 --- a/debug/accuracy_tools/kj600/kj600/module_metric.py +++ b/debug/accuracy_tools/kj600/kj600/module_metric.py @@ -6,7 +6,7 @@ import torch from kj600.const import Const from kj600.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean -from kj600.utils import print_warn_log +from kj600.utils import print_log_with_rank def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank): @@ -77,7 +77,7 @@ class Metric(object): try: metrics_dict[tag] = self.get_metric_value(tensor, eps) if torch.isnan(metrics_dict[tag]): - print_warn_log(f'nan when calculate metric for {tag}') + print_log_with_rank(f'nan when calculate metric for {tag}', Const.CURRENT_RANK, Const.WARNING) except RuntimeError as e: metrics_dict[tag] = torch.tensor(torch.nan) return metrics_dict diff --git a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py b/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py index 69f15afdbc9e88e7873179e8e740ed12c0f3a068..fe9c2c57d2abd5de534f71e38a739d8a94c0845d 100644 --- a/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py +++ b/debug/accuracy_tools/kj600/kj600/module_spec_verifier.py @@ -1,9 +1,9 @@ -import json import re import abc import torch -from kj600.utils import print_warn_log +from kj600.const import Const +from kj600.utils import print_log_with_rank # 用于存储所有validator实现类的注册表 config_validator_registry = {} @@ -67,7 +67,8 @@ def validate_config_spec(config_spec:str, actual_data, module_name:str, data_typ try: focused_col = config_validator.validate(actual_data, module_name, data_type, pattern_match) except ValueError as e: - print_warn_log(str(e)) + print_log_with_rank(str(e), Const.CURRENT_RANK, Const.WARNING) return focused_col - print_warn_log(f"config spec in {module_name} {data_type} not supported, expected spec:'tuple\[(\d+)\]:(\d+)' or 'tensor', actual spec: {config_spec}.") + print_log_with_rank(f"config spec in {module_name} {data_type} not supported, expected spec:'tuple\[(\d+)\]:(\d+)' " + f"or 'tensor', actual spec: {config_spec}.", Const.CURRENT_RANK, Const.WARNING) return focused_col diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py index 4e55d38c30cd31e0cc0f0463cf1ae046c4ce4280..95f6f5e652b9725dff323783613eed12168f0f1e 100644 --- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py +++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py @@ -3,15 +3,8 @@ from collections import defaultdict, namedtuple import torch import torch.distributed as dist -from kj600.utils import print_warn_log, print_error_log - - -def print_rank_0(message): - if dist.is_initialized(): - if dist.get_rank() == 0: - print(message) - else: - print(message) +from kj600.const import Const +from kj600.utils import print_log_with_rank MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) @@ -48,7 +41,8 @@ class OptimizerMon(ABC): exp_avg = state_param.get("exp_avg", None) exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: - print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"exp_avg or exp_avg_sq of {name} is None, " + f"maybe something wrong happened.", Const.CURRENT_RANK, Const.WARNING) continue if monitor.mv_distribution: exp_avg_dict[name] = exp_avg @@ -61,7 +55,8 @@ class OptimizerMon(ABC): elif 'step' in torch_opt.param_groups[0]: step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: - print_warn_log(f"step of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", + Const.CURRENT_RANK, Const.WARNING) continue exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) @@ -90,7 +85,8 @@ class OptimizerMon(ABC): exp_avg = state_param.get("exp_avg", None) exp_avg_sq = state_param.get("exp_avg_sq", None) if exp_avg is None or exp_avg_sq is None: - print_warn_log(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"exp_avg or exp_avg_sq of {name} is None, maybe something wrong happened.", + Const.CURRENT_RANK, Const.WARNING) continue exp_avg = exp_avg[start_idx: end_idx] exp_avg_sq = exp_avg_sq[start_idx: end_idx] @@ -105,7 +101,8 @@ class OptimizerMon(ABC): elif 'step' in torch_opt.param_groups[0]: step = torch_opt.param_groups[0]['step'] # AdamW from mindspeed else: - print_warn_log(f"step of {name} is None, maybe something wrong happened.") + print_log_with_rank(f"step of {name} is None, maybe something wrong happened.", + Const.CURRENT_RANK, Const.WARNING) continue exp_avg_hat = exp_avg / (1 - torch_opt.defaults['betas'][0] ** step) exp_avg_sq_hat = exp_avg_sq / (1 - torch_opt.defaults['betas'][1] ** step) diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/kj600/kj600/utils.py index a5c0b44bb3660401d25a20fa68bbbeeafbc46459..1c4994e717ccd599ad62147d91b1acc8ffe96592 100644 --- a/debug/accuracy_tools/kj600/kj600/utils.py +++ b/debug/accuracy_tools/kj600/kj600/utils.py @@ -4,6 +4,7 @@ import sys import re import warnings import torch +import torch.distributed as dist from kj600.const import Const @@ -14,6 +15,7 @@ FILE_NAME_MAX_LENGTH = 255 DIRECTORY_MAX_LENGTH = 4096 FILE_NAME_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + def _print_log(level, msg, end='\n'): current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) pid = os.getgid() @@ -21,7 +23,7 @@ def _print_log(level, msg, end='\n'): sys.stdout.flush() -def print_info_log(info_msg, end='\n'): +def _print_info_log(info_msg, end='\n'): """ Function Description: print info log. @@ -31,7 +33,7 @@ def print_info_log(info_msg, end='\n'): _print_log("INFO", info_msg, end=end) -def print_error_log(error_msg): +def _print_error_log(error_msg): """ Function Description: print error log. @@ -41,7 +43,7 @@ def print_error_log(error_msg): _print_log("ERROR", error_msg) -def print_warn_log(warn_msg): +def _print_warn_log(warn_msg): """ Function Description: print warn log. @@ -51,6 +53,29 @@ def print_warn_log(warn_msg): _print_log("WARNING", warn_msg) +print_method_map = { + Const.INFO: _print_info_log, + Const.WARNING: _print_warn_log, + Const.ERROR: _print_error_log +} + + +def print_log_with_rank(msg: str, rank: int, level: str): + """Logs a message with a specified rank and log level. + + Args: + msg (str): The message to log. + rank (int): The rank to check against. If -1, logs from current rank. + level (str): The log level ('INFO', 'ERROR', 'WARNING'). + """ + print_method = print_method_map.get(level, _print_info_log) + if dist.is_initialized(): + if rank == dist.get_rank() or rank == Const.CURRENT_RANK: + print_method(f'[RANK{dist.get_rank()}]{msg}') + else: + print_method(msg) + + def get_param_struct(param): res = {} if isinstance(param, (tuple, list)): @@ -62,7 +87,8 @@ def get_param_struct(param): res['tensor'] = f'size={tuple(param.shape)}, dtype={param.dtype}' else: res['config'] = f'{type(param)}' - print_warn_log(f'Not support type({type(param)}) now, please check the type of param {param}') + print_log_with_rank(f'Not support type({type(param)}) now, please check the type of param {param}', + Const.CURRENT_RANK, Const.WARNING) return res @@ -157,7 +183,8 @@ def validate_ops(ops): valid_ops = [] for op in ops: if op not in Const.OP_LIST: - print_warn_log(f"given op {op} is not supported. Optional ops: {Const.OP_LIST}") + print_log_with_rank(f"given op {op} is not supported. Optional ops: {Const.OP_LIST}", + Const.CURRENT_RANK, Const.WARNING) else: valid_ops.append(op) return valid_ops