diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f19e14d89e6b7cb29d8bdc756a5d258081b106ab --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py @@ -0,0 +1,338 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from datetime import timezone, timedelta +from functools import wraps +from datetime import datetime +import os +import re + +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.log import logger +from msprobe.core.common.utils import is_int +from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod + + +beijing_tz = timezone(timedelta(hours=8)) +MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) + + +class MsgConst: + """ + Class for log messages const + """ + SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] + + +def get_output_base_dir(): + return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) + + +def filter_special_chars(func): + @wraps(func) + def func_level(msg): + for char in MsgConst.SPECIAL_CHAR: + msg = msg.replace(char, '_') + return func(msg) + + return func_level + + +def validate_ops(ops): + if not isinstance(ops, list): + raise TypeError("ops should be a list") + valid_ops = [] + for op in ops: + if op not in MonitorConst.OP_LIST: + logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") + continue + valid_ops.append(op) + if not valid_ops: + default_op = MonitorConst.OP_LIST[0] + valid_ops.append(default_op) + logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") + # 增加默认shape和dtype参数 + if "shape" not in valid_ops: + valid_ops.append("shape") + if "dtype" not in valid_ops: + valid_ops.append("dtype") + return valid_ops + + +def validate_ndigits(ndigits): + if not ndigits: + return + if not is_int(ndigits) or ndigits <= 0: + raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") + if ndigits > MonitorConst.MAX_NDIGITS: + raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") + + +def validate_ranks(ranks): + if not isinstance(ranks, list): + raise TypeError("module_ranks should be a list") + for rank in ranks: + if not isinstance(rank, int) or isinstance(rank, bool): + raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") + + +def validate_targets(targets): + if not isinstance(targets, dict): + raise TypeError('targets in config.json should be a dict') + for module_name, field in targets.items(): + if not isinstance(module_name, str): + raise TypeError('key of targets should be module_name[str] in config.json') + if not isinstance(field, dict): + raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') + + +def validate_print_struct(print_struct): + if not isinstance(print_struct, bool): + raise TypeError("print_struct should be a bool") + + +def validate_ur_distribution(ur_distribution): + if not isinstance(ur_distribution, bool): + raise TypeError('ur_distribution should be a bool') + + +def validate_xy_distribution(xy_distribution): + if not isinstance(xy_distribution, bool): + raise TypeError('xy_distribution should be a bool') + + +def validate_wg_distribution(wg_distribution): + if not isinstance(wg_distribution, bool): + raise TypeError('wg_distribution should be a bool') + + +def validate_mg_distribution(mg_distribution): + if not isinstance(mg_distribution, bool): + raise TypeError('mg_distribution should be a bool') + + +def validate_param_distribution(param_distribution): + if not isinstance(param_distribution, bool): + raise TypeError('param_distribution should be a bool') + + +def validate_cc_distribution(cc_distribution): + if not isinstance(cc_distribution, dict): + raise TypeError('cc_distribution should be a dictionary') + for key, value in cc_distribution.items(): + if key == 'enable': + if not isinstance(value, bool): + raise TypeError('cc_distribution enable should be a bool') + elif key == 'cc_codeline': + if not isinstance(value, list): + raise TypeError('cc_distribution cc_codeline should be a list') + elif key == 'cc_pre_hook': + if not isinstance(value, bool): + raise TypeError('cc_distribution cc_pre_hook should be a bool') + elif key == 'cc_log_only': + if not isinstance(value, bool): + raise TypeError('cc_distribution cc_log_only should be a bool') + else: + raise TypeError(f'{key} of cc_distribution is not supported.') + + +def validate_squash_name(squash_name): + if not isinstance(squash_name, bool): + raise TypeError('squash_name should be a bool') + + +def validate_alert(alert): + if not isinstance(alert, dict): + raise TypeError('alert should be a dictionary') + rules = alert.get('rules') + if rules and isinstance(rules, list): + for rule in rules: + rule_name = rule.get("rule_name") + if rule_name and rule_name not in MonitorConst.RULE_NAME: + raise TypeError(f"{rule_name} is not supported") + args = rule.get("args") + if args and isinstance(args, dict): + threshold = args.get("threshold") + if not isinstance(threshold, (float, int)) or threshold < 0: + raise TypeError('threshold must be float and not less than 0') + dump = alert.get('dump') + if dump and not isinstance(dump, bool): + raise TypeError('dump must be bool.') + + +def validate_step_count_per_record(step_count_per_record): + if not is_int(step_count_per_record): + raise TypeError('step_count_per_record must be int.') + if step_count_per_record < 1: + raise ValueError("step_count_per_record must greater than 0") + if step_count_per_record > 1e6: + raise ValueError("step_count_per_record must smaller than 1e6") + + +def validate_dynamic_on(dynamic_on): + if not isinstance(dynamic_on, bool): + raise TypeError('dynamic_on should be a bool') + + +def validate_monitor_mbs_grad(monitor_mbs_grad): + if not isinstance(monitor_mbs_grad, bool): + logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') + return False + return monitor_mbs_grad + + +def validate_append_output(append_output): + if not isinstance(append_output, list): + raise TypeError('append_output should be a list') + if len(append_output) > 0 and len(append_output) != 2: + raise ValueError('append_output should be empty or contain exactly 2 elements') + + +def validate_config(config): + config['ops'] = validate_ops(config.get('ops', [])) + + ndigits = config.get('ndigits') + validate_ndigits(ndigits) + + eps = config.get('eps', 1e-8) + if not isinstance(eps, float): + raise TypeError("eps should be a float") + + ranks = config.get("module_ranks", []) + validate_ranks(ranks) + + targets = config.get("targets", {}) + validate_targets(targets) + + print_struct = config.get('print_struct', False) + validate_print_struct(print_struct) + + ur_distribution = config.get('ur_distribution', False) + validate_ur_distribution(ur_distribution) + + xy_distribution = config.get('xy_distribution', False) + validate_xy_distribution(xy_distribution) + + wg_distribution = config.get('wg_distribution', False) + validate_wg_distribution(wg_distribution) + + mg_distribution = config.get('mg_distribution', False) + validate_mg_distribution(mg_distribution) + + param_distribution = config.get('param_distribution', False) + validate_param_distribution(param_distribution) + + cc_distribution = config.get('cc_distribution', {}) + validate_cc_distribution(cc_distribution) + + alert = config.get('alert', {}) + validate_alert(alert) + + step_count_per_record = config.get('step_count_per_record', 1) + validate_step_count_per_record(step_count_per_record) + + config["start_step"] = validate_int_arg(config.get("start_step"), "start_step", + MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP) + config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times", + MonitorConst.DEFAULT_MIN_COLLECT_TIMES, + MonitorConst.DEFAULT_MAX_COLLECT_TIMES) + config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval", + MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL) + + squash_name = config.get('squash_name', True) + validate_squash_name(squash_name) + + time_tags = config.get("append_output", []) + validate_append_output(time_tags) + + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) + + dynamic_on = config.get('dynamic_on', False) + validate_dynamic_on(dynamic_on) + + if not targets: + if xy_distribution: + config["all_xy"] = True + config["targets"] = {"": {}} + + +def time_str2time_digit(time_str): + time_format = '%b%d_%H-%M-%S' + if not isinstance(time_str, str): + raise TypeError(f"time_str:{time_str} should be a str") + try: + time_digit = datetime.strptime(time_str, time_format) + except Exception as e: + raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ + of existing output dirpath, like 'Dec03_21-34-40'.") from e + return time_digit + + +def get_target_output_dir(monitor_path, time_start, time_end): + check_file_or_directory_path(monitor_path, isdir=True) + time_start = time_str2time_digit(time_start) if time_start is not None else time_start + time_end = time_str2time_digit(time_end) if time_end is not None else time_end + if time_start and time_end and time_start > time_end: + raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") + result = {} + for dirname in os.listdir(monitor_path): + match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) + if not match: + continue + time_tag = match.group(1) + rank = match.group(2) + target_time = time_str2time_digit(time_tag) + start_ok = time_start is None or target_time >= time_start + end_ok = time_end is None or target_time <= time_end + if start_ok and end_ok: + result[rank] = os.path.join(monitor_path, dirname) + return result + + +def chmod_tensorboard_dir(path): + """ + format配置为tensorboard时,需要补充文件权限设置 + """ + try: + recursive_chmod(path) + except Exception as e: + logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!") + + +def validate_set_monitor(grad_acc_steps, start_iteration): + """ + validate parameters of set_monitor. + """ + grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps", + MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS) + + start_iteration = validate_int_arg(start_iteration, "start_iteration", + MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION) + return grad_acc_steps, start_iteration + + +def validate_int_arg(value, name, minimum, default_value): + """Validate int args, if any exception occurs, use the default value.""" + if value is None: + return default_value + try: + if not is_int(value): + raise TypeError(f"{name} must be int") + if value < minimum: + raise ValueError(f"{name} must greater than {minimum}") + except Exception as e: + value = default_value + logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.") + return value diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py index 52db10d05469e83d2992f7844a2db049a6c8c5af..5880d2284f33f8a9a406cabb298e921e85e6b6b5 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py @@ -16,7 +16,7 @@ from mindspore import nn from mindspore import communication -from msprobe.mindspore.monitor.utils import logger +from msprobe.core.common.log import logger from msprobe.mindspore.common.utils import is_mindtorch if is_mindtorch(): import torch diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 0354ab533683c2ac7efc44c2914581e980fae4ff..0d74fca7808e30a5a071aeb9ff556be0cd116bc1 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -28,11 +28,12 @@ from mindspore import nn, _no_grad from msprobe.core.common.log import logger from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter from msprobe.mindspore.common.utils import is_mindtorch from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank -from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \ - is_skip_step, get_metrics, get_target_output_dir +from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \ + get_metrics from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate @@ -250,7 +251,6 @@ class TrainerMon: self.has_collect_times = 0 # 重设采集计数器 self.print_struct = self.config.get("print_struct", False) self.targets = self.config.get("targets", None) - self.is_select = self.config.get("is_select", False) self.module_rank_list = self.config.get("module_ranks", []) self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore self.eps = self.config.get('eps', 1e-8) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py index e0817eb2a4efc11000a96c6b328f6fbd07145060..a6cc09fb3c1dde73caa355a85d68ed016ba7d59f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py @@ -12,16 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -from datetime import datetime from mindspore import dtype as mstype, Tensor from msprobe.mindspore.monitor.features import FUNC_MAP -from msprobe.core.common.const import MonitorConst -from msprobe.core.common.utils import is_int -from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import check_file_or_directory_path def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None): @@ -82,248 +75,3 @@ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_t :return: whether skip or not, bool """ return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times - - -def validate_ops(ops): - if not isinstance(ops, list): - raise TypeError("ops should be a list") - valid_ops = [] - for op in ops: - if op not in MonitorConst.OP_LIST: - logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") - continue - valid_ops.append(op) - if not valid_ops: - default_op = MonitorConst.OP_LIST[0] - valid_ops.append(default_op) - logger.info(f"There is no valid ops, default op {default_op} is used") - # 增加默认shape和dtype参数 - if "shape" not in valid_ops: - valid_ops.append("shape") - if "dtype" not in valid_ops: - valid_ops.append("dtype") - return valid_ops - - -def validate_ranks(ranks): - if not isinstance(ranks, list): - raise TypeError("module_ranks should be a list") - for rank in ranks: - if not isinstance(rank, int): - raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") - - -def validate_targets(targets): - if not isinstance(targets, dict): - raise TypeError('targets in config.json should be a dict') - for module_name, field in targets.items(): - if not isinstance(module_name, str): - raise TypeError('key of targets should be module_name[str] in config.json') - if not isinstance(field, dict): - raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') - - -def validate_print_struct(print_struct): - if not isinstance(print_struct, bool): - raise TypeError("print_struct should be a bool") - - -def validate_ur_distribution(ur_distribution): - if not isinstance(ur_distribution, bool): - raise TypeError('ur_distribution should be a bool') - - -def validate_xy_distribution(xy_distribution): - if not isinstance(xy_distribution, bool): - raise TypeError('xy_distribution should be a bool') - - -def validate_wg_distribution(wg_distribution): - if not isinstance(wg_distribution, bool): - raise TypeError('wg_distribution should be a bool') - - -def validate_mg_distribution(mg_distribution): - if not isinstance(mg_distribution, bool): - raise TypeError('mg_distribution should be a bool') - - -def validate_param_distribution(param_distribution): - if not isinstance(param_distribution, bool): - raise TypeError('param_distribution should be a bool') - - -def validate_cc_distribution(cc_distribution): - if not isinstance(cc_distribution, dict): - raise TypeError('cc_distribution should be a dictionary') - expected_keys = { - 'enable': bool, - 'cc_codeline': list, - 'cc_pre_hook': bool, - 'cc_log_only': bool - } - for key, value in cc_distribution.items(): - if key in expected_keys: - if not isinstance(value, expected_keys[key]): - raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}') - else: - raise TypeError(f'{key} of cc_distribution is not supported.') - - -def validate_alert(alert): - if not isinstance(alert, dict): - raise TypeError('alert should be a dictionary') - rules = alert.get('rules') - if rules and isinstance(rules, list): - for rule in rules: - rule_name = rule.get("rule_name") - if rule_name and rule_name not in MonitorConst.RULE_NAME: - raise TypeError(f"{rule_name} is not supported") - args = rule.get("args") - if args and isinstance(args, dict): - threshold = args.get("threshold") - if not isinstance(threshold, (float, int)) or threshold < 0: - raise TypeError('threshold must be float and not less than 0') - dump = alert.get('dump') - if dump and not isinstance(dump, bool): - raise TypeError('dump must be bool.') - - -def validate_step_count_per_record(step_count_per_record): - if not is_int(step_count_per_record): - raise TypeError('step_count_per_record must be int.') - if step_count_per_record < 1: - raise ValueError("step_count_per_record must greater than 0") - if step_count_per_record > 1e6: - raise ValueError("step_count_per_record must smaller than 1e6") - - -def validate_start_step(start_step): - if not is_int(start_step): - raise TypeError('start_step must be int.') - if start_step < 0: - raise ValueError("start_step must greater than 0") - if start_step > 1e8: - raise ValueError("start_step must smaller than 1e8") - - -def validate_step_interval(step_interval): - if not is_int(step_interval): - raise TypeError('step_interval must be int.') - if step_interval < 1: - raise ValueError("step_interval must greater than 1") - if step_interval > 1e8: - raise ValueError("step_interval must smaller than 1e8") - - -def validate_collect_times(collect_times): - if not is_int(collect_times): - raise TypeError('collect_times must be int.') - if collect_times < 1: - raise ValueError("collect_times must greater than 1") - - -def validate_dynamic_on(dynamic_on): - if not isinstance(dynamic_on, bool): - raise TypeError('dynamic_on should be a bool') - - -def validate_monitor_mbs_grad(monitor_mbs_grad): - if not isinstance(monitor_mbs_grad, bool): - logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') - return False - return monitor_mbs_grad - - -def validate_config(config): - config['ops'] = validate_ops(config.get('ops', [])) - - eps = config.get('eps', 1e-8) - if not isinstance(eps, float): - raise TypeError("eps should be a float") - - ranks = config.get("module_ranks", []) - validate_ranks(ranks) - - targets = config.get("targets", {}) - validate_targets(targets) - - print_struct = config.get('print_struct', False) - validate_print_struct(print_struct) - - ur_distribution = config.get('ur_distribution', False) - validate_ur_distribution(ur_distribution) - - xy_distribution = config.get('xy_distribution', False) - validate_xy_distribution(xy_distribution) - - wg_distribution = config.get('wg_distribution', False) - validate_wg_distribution(wg_distribution) - - mg_distribution = config.get('mg_distribution', False) - validate_mg_distribution(mg_distribution) - - param_distribution = config.get('param_distribution', False) - validate_param_distribution(param_distribution) - - cc_distribution = config.get('cc_distribution', {}) - validate_cc_distribution(cc_distribution) - - alert = config.get('alert', {}) - validate_alert(alert) - - step_count_per_record = config.get('step_count_per_record', 1) - validate_step_count_per_record(step_count_per_record) - - start_step = config.get('start_step', 0) - validate_start_step(start_step) - - step_interval = config.get('step_interval', 1) - validate_step_interval(step_interval) - - collect_times = config.get('collect_times', int(1e8)) - validate_collect_times(collect_times) - - config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) - - dynamic_on = config.get('dynamic_on', False) - validate_dynamic_on(dynamic_on) - - if not targets: - if xy_distribution: - config["all_xy"] = True - config["targets"] = {"": {}} - config["is_select"] = False - else: - config["is_select"] = True - - -def time_str2time_digit(time_str): - time_format = '%b%d_%H-%M-%S' - try: - time_digit = datetime.strptime(time_str, time_format) - except Exception as e: - raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ - of existing output dirpath, like 'Dec03_21-34-40'.") from e - return time_digit - - -def get_target_output_dir(monitor_path, time_start, time_end): - check_file_or_directory_path(monitor_path, isdir=True) - time_start = time_str2time_digit(time_start) if time_start is not None else time_start - time_end = time_str2time_digit(time_end) if time_end is not None else time_end - if time_start and time_end and time_start > time_end: - raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") - result = {} - for dirname in os.listdir(monitor_path): - match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) - if not match: - continue - time_tag = match.group(1) - rank = match.group(2) - target_time = time_str2time_digit(time_tag) - start_ok = time_start is None or target_time >= time_start - end_ok = time_end is None or target_time <= time_end - if start_ok and end_ok: - result[rank] = os.path.join(monitor_path, dirname) - return result diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py index d2a6d95cb3efa359ea4eb6df8c3bbbe8d451fce2..a807f7e27987713f5c3f3488088239745b4aa0df 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py @@ -25,8 +25,9 @@ from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod from msprobe.core.common.utils import check_process_num from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.monitor.utils import get_target_output_dir from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import get_target_output_dir + all_data_type_list = [ "actv", "actv_grad", "exp_avg", "exp_avg_sq", diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 042ccf651c91f1596d17b520a41200a9cfb97f91..0158b3713758a10ff7470793198d823c4eed773b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -31,8 +31,11 @@ from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter from msprobe.core.common.file_utils import write_df_to_csv from msprobe.core.common.utils import analyze_api_call_stack +from msprobe.core.monitor.utils import validate_config, validate_ops, \ + get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor +from msprobe.pytorch.monitor.utils import get_param_struct from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ get_process_group @@ -40,8 +43,6 @@ from msprobe.pytorch.monitor.features import get_sign_matches from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \ TensorMetrics, squash_param_name from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory -from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \ - get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 08827fef1788803dc0280c0ebd392012e5404d9e..8a63eaef9c348663c0bc1084b6415050dc90935e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -17,7 +17,7 @@ from abc import abstractmethod import torch from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import MVResult +from msprobe.core.monitor.utils import MVResult from msprobe.core.common.const import MonitorConst diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 767707479719159ae2806b54f1f706ed6faa5a20..ca339ad64823c20382910a339a1e4136008d675a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -12,20 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from collections import namedtuple -from datetime import timezone, timedelta -from functools import wraps -from datetime import datetime -import os -import re - import torch -from msprobe.core.common.const import MonitorConst from msprobe.pytorch.common.log import logger -from msprobe.core.common.utils import is_int -from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod device = "cpu" @@ -37,23 +26,6 @@ except ImportError: device = "cuda" NAN_TENSOR_ON_DEVICE = None -FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024 -FILE_NAME_MAX_LENGTH = 255 -DIRECTORY_MAX_LENGTH = 4096 - -beijing_tz = timezone(timedelta(hours=8)) -MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) - - -class MsgConst: - """ - Class for log messages const - """ - SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] - - -def get_output_base_dir(): - return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) def get_nan_tensor(): @@ -63,16 +35,6 @@ def get_nan_tensor(): return NAN_TENSOR_ON_DEVICE -def filter_special_chars(func): - @wraps(func) - def func_level(msg): - for char in MsgConst.SPECIAL_CHAR: - msg = msg.replace(char, '_') - return func(msg) - - return func_level - - def get_param_struct(param): res = {} if isinstance(param, (tuple, list)): @@ -85,282 +47,4 @@ def get_param_struct(param): else: res['config'] = f'{type(param)}' logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}') - return res - - -def validate_ops(ops): - if not isinstance(ops, list): - raise TypeError("ops should be a list") - valid_ops = [] - for op in ops: - if op not in MonitorConst.OP_LIST: - logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") - continue - valid_ops.append(op) - if not valid_ops: - default_op = MonitorConst.OP_LIST[0] - valid_ops.append(default_op) - logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") - # 增加默认shape和dtype参数 - if "shape" not in valid_ops: - valid_ops.append("shape") - if "dtype" not in valid_ops: - valid_ops.append("dtype") - return valid_ops - - -def validate_ndigits(ndigits): - if not ndigits: - return - if not is_int(ndigits) or ndigits <= 0: - raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") - if ndigits > MonitorConst.MAX_NDIGITS: - raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") - - -def validate_ranks(ranks): - if not isinstance(ranks, list): - raise TypeError("module_ranks should be a list") - for rank in ranks: - if not isinstance(rank, int) or isinstance(rank, bool): - raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") - - -def validate_targets(targets): - if not isinstance(targets, dict): - raise TypeError('targets in config.json should be a dict') - for module_name, field in targets.items(): - if not isinstance(module_name, str): - raise TypeError('key of targets should be module_name[str] in config.json') - if not isinstance(field, dict): - raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') - - -def validate_print_struct(print_struct): - if not isinstance(print_struct, bool): - raise TypeError("print_struct should be a bool") - - -def validate_ur_distribution(ur_distribution): - if not isinstance(ur_distribution, bool): - raise TypeError('ur_distribution should be a bool') - - -def validate_xy_distribution(xy_distribution): - if not isinstance(xy_distribution, bool): - raise TypeError('xy_distribution should be a bool') - - -def validate_wg_distribution(wg_distribution): - if not isinstance(wg_distribution, bool): - raise TypeError('wg_distribution should be a bool') - - -def validate_mg_distribution(mg_distribution): - if not isinstance(mg_distribution, bool): - raise TypeError('mg_distribution should be a bool') - - -def validate_param_distribution(param_distribution): - if not isinstance(param_distribution, bool): - raise TypeError('param_distribution should be a bool') - - -def validate_cc_distribution(cc_distribution): - if not isinstance(cc_distribution, dict): - raise TypeError('cc_distribution should be a dictionary') - for key, value in cc_distribution.items(): - if key == 'enable': - if not isinstance(value, bool): - raise TypeError('cc_distribution enable should be a bool') - elif key == 'cc_codeline': - if not isinstance(value, list): - raise TypeError('cc_distribution cc_codeline should be a list') - elif key == 'cc_pre_hook': - if not isinstance(value, bool): - raise TypeError('cc_distribution cc_pre_hook should be a bool') - elif key == 'cc_log_only': - if not isinstance(value, bool): - raise TypeError('cc_distribution cc_log_only should be a bool') - else: - raise TypeError(f'{key} of cc_distribution is not supported.') - - -def validate_squash_name(squash_name): - if not isinstance(squash_name, bool): - raise TypeError('squash_name should be a bool') - - -def validate_alert(alert): - if not isinstance(alert, dict): - raise TypeError('alert should be a dictionary') - rules = alert.get('rules') - if rules and isinstance(rules, list): - for rule in rules: - rule_name = rule.get("rule_name") - if rule_name and rule_name not in MonitorConst.RULE_NAME: - raise TypeError(f"{rule_name} is not supported") - args = rule.get("args") - if args and isinstance(args, dict): - threshold = args.get("threshold") - if not isinstance(threshold, (float, int)) or threshold < 0: - raise TypeError('threshold must be float and not less than 0') - dump = alert.get('dump') - if dump and not isinstance(dump, bool): - raise TypeError('dump must be bool.') - - -def validate_step_count_per_record(step_count_per_record): - if not is_int(step_count_per_record): - raise TypeError('step_count_per_record must be int.') - if step_count_per_record < 1: - raise ValueError("step_count_per_record must greater than 0") - if step_count_per_record > 1e6: - raise ValueError("step_count_per_record must smaller than 1e6") - - -def validate_dynamic_on(dynamic_on): - if not isinstance(dynamic_on, bool): - raise TypeError('dynamic_on should be a bool') - - -def validate_monitor_mbs_grad(monitor_mbs_grad): - if not isinstance(monitor_mbs_grad, bool): - logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') - return False - return monitor_mbs_grad - - -def validate_config(config): - config['ops'] = validate_ops(config.get('ops', [])) - - ndigits = config.get('ndigits') - validate_ndigits(ndigits) - - eps = config.get('eps', 1e-8) - if not isinstance(eps, float): - raise TypeError("eps should be a float") - - ranks = config.get("module_ranks", []) - validate_ranks(ranks) - - targets = config.get("targets", {}) - validate_targets(targets) - - print_struct = config.get('print_struct', False) - validate_print_struct(print_struct) - - ur_distribution = config.get('ur_distribution', False) - validate_ur_distribution(ur_distribution) - - xy_distribution = config.get('xy_distribution', False) - validate_xy_distribution(xy_distribution) - - wg_distribution = config.get('wg_distribution', False) - validate_wg_distribution(wg_distribution) - - mg_distribution = config.get('mg_distribution', False) - validate_mg_distribution(mg_distribution) - - param_distribution = config.get('param_distribution', False) - validate_param_distribution(param_distribution) - - cc_distribution = config.get('cc_distribution', {}) - validate_cc_distribution(cc_distribution) - - alert = config.get('alert', {}) - validate_alert(alert) - - step_count_per_record = config.get('step_count_per_record', 1) - validate_step_count_per_record(step_count_per_record) - - config["start_step"] = validate_int_arg(config.get("start_step"), "start_step", - MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP) - config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times", - MonitorConst.DEFAULT_MIN_COLLECT_TIMES, - MonitorConst.DEFAULT_MAX_COLLECT_TIMES) - config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval", - MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL) - - squash_name = config.get('squash_name', True) - validate_squash_name(squash_name) - - config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) - - dynamic_on = config.get('dynamic_on', False) - validate_dynamic_on(dynamic_on) - - if not targets: - if xy_distribution: - config["all_xy"] = True - config["targets"] = {"": {}} - - -def time_str2time_digit(time_str): - time_format = '%b%d_%H-%M-%S' - if not isinstance(time_str, str): - raise TypeError(f"time_str:{time_str} should be a str") - try: - time_digit = datetime.strptime(time_str, time_format) - except Exception as e: - raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ - of existing output dirpath, like 'Dec03_21-34-40'.") from e - return time_digit - - -def get_target_output_dir(monitor_path, time_start, time_end): - check_file_or_directory_path(monitor_path, isdir=True) - time_start = time_str2time_digit(time_start) if time_start is not None else time_start - time_end = time_str2time_digit(time_end) if time_end is not None else time_end - if time_start and time_end and time_start > time_end: - raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") - result = {} - for dirname in os.listdir(monitor_path): - match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) - if not match: - continue - time_tag = match.group(1) - rank = match.group(2) - target_time = time_str2time_digit(time_tag) - start_ok = time_start is None or target_time >= time_start - end_ok = time_end is None or target_time <= time_end - if start_ok and end_ok: - result[rank] = os.path.join(monitor_path, dirname) - return result - - -def chmod_tensorboard_dir(path): - """ - format配置为tensorboard时,需要补充文件权限设置 - """ - try: - recursive_chmod(path) - except Exception as e: - logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!") - - -def validate_set_monitor(grad_acc_steps, start_iteration): - """ - validate parameters of set_monitor. - """ - grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps", - MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS) - - start_iteration = validate_int_arg(start_iteration, "start_iteration", - MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION) - return grad_acc_steps, start_iteration - - -def validate_int_arg(value, name, minimum, default_value): - """Validate int args, if any exception occurs, use the default value.""" - if value is None: - return default_value - try: - if not is_int(value): - raise TypeError(f"{name} must be int") - if value < minimum: - raise ValueError(f"{name} must greater than {minimum}") - except Exception as e: - value = default_value - logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.") - return value + return res \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py index 9299d37d07792ab243adc05659a820f33330c308..01005fadf563a92e9fac68dbde4363fff70410fb 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py @@ -82,7 +82,7 @@ class TestMonitorUtils(unittest.TestCase): self.assertTrue(is_skip_step(12, 10, 2, has_collect_times=5, collect_times=5)) def test_validate_ops(self): - from msprobe.mindspore.monitor.utils import validate_ops + from msprobe.core.monitor.utils import validate_ops # 测试输入不是list的情况 with self.assertRaises(TypeError): @@ -105,7 +105,7 @@ class TestMonitorUtils(unittest.TestCase): self.assertIn("dtype", result) def test_validate_ranks(self): - from msprobe.mindspore.monitor.utils import validate_ranks + from msprobe.core.monitor.utils import validate_ranks # 测试输入不是list的情况 with self.assertRaises(TypeError): @@ -122,7 +122,7 @@ class TestMonitorUtils(unittest.TestCase): self.fail(f"validate_ranks raised unexpected exception: {e}") def test_validate_targets(self): - from msprobe.mindspore.monitor.utils import validate_targets + from msprobe.core.monitor.utils import validate_targets # 测试输入不是dict的情况 with self.assertRaises(TypeError): @@ -143,7 +143,7 @@ class TestMonitorUtils(unittest.TestCase): self.fail(f"validate_targets raised unexpected exception: {e}") def test_validate_config(self): - from msprobe.mindspore.monitor.utils import validate_config + from msprobe.core.monitor.utils import validate_config # 测试基本配置验证 config = { @@ -185,7 +185,7 @@ class TestMonitorUtils(unittest.TestCase): validate_config(invalid_config) def test_time_str2time_digit(self): - from msprobe.mindspore.monitor.utils import time_str2time_digit + from msprobe.core.monitor.utils import time_str2time_digit # 测试有效时间字符串 time_str = "Dec03_21-34-40" @@ -201,7 +201,7 @@ class TestMonitorUtils(unittest.TestCase): time_str2time_digit(invalid_time_str) def test_get_target_output_dir(self): - from msprobe.mindspore.monitor.utils import get_target_output_dir + from msprobe.core.monitor.utils import get_target_output_dir # 测试不带时间范围的情况 result = get_target_output_dir(self.temp_dir, None, None) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py index 87822ab0503bd21e0546d8c846d69f56204eb048..83e8217c894d38b1d8506cb0e1cd241ffcbcb759 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py @@ -5,10 +5,11 @@ from unittest.mock import patch, MagicMock import torch from msprobe.core.common.const import MonitorConst -from msprobe.pytorch.monitor.utils import filter_special_chars, MsgConst, get_param_struct, validate_ops, \ - validate_ranks, validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ +from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \ + validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \ get_output_base_dir +from msprobe.pytorch.monitor.utils import get_param_struct from msprobe.pytorch.common.utils import is_recomputation diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py index e32e4f860ee40a2bb3198ee30fd522b98ae2e36e..c7cbd86bbcc8671b88aa8a7c97e24181c5a42379 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py @@ -9,7 +9,7 @@ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMon, \ MegatronChainedDistributedOptimizerMon, MegatronChainedMixPrecisionOptimizerMon, \ DeepSpeedZeroOptimizerMon, DeepSpeedZeroOptimizerStage0Mon, \ DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon -from msprobe.pytorch.monitor.utils import MVResult +from msprobe.core.monitor.utils import MVResult def setup_param_groups(num_groups=2, params_per_group=5):