From 20c807ca7fc1192850584ee78bf85090ab4b604a Mon Sep 17 00:00:00 2001 From: TAJh Date: Tue, 1 Jul 2025 15:33:00 +0800 Subject: [PATCH 1/4] add param checker --- .../msprobe/mindspore/monitor/module_hook.py | 1 - .../msprobe/mindspore/monitor/utils.py | 34 ++++++++++++++++--- .../msprobe/pytorch/monitor/utils.py | 10 ++++++ 3 files changed, 40 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 0354ab53368..1e22ec14cea 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -250,7 +250,6 @@ class TrainerMon: self.has_collect_times = 0 # 重设采集计数器 self.print_struct = self.config.get("print_struct", False) self.targets = self.config.get("targets", None) - self.is_select = self.config.get("is_select", False) self.module_rank_list = self.config.get("module_ranks", []) self.format = self.config.get('format', MonitorConst.CSV) # only csv supported in mindspore self.eps = self.config.get('eps', 1e-8) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py index e0817eb2a4e..5baf3761629 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py @@ -153,6 +153,15 @@ def validate_param_distribution(param_distribution): raise TypeError('param_distribution should be a bool') +def validate_ndigits(ndigits): + if not ndigits: + return + if not is_int(ndigits) or ndigits <= 0: + raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") + if ndigits > MonitorConst.MAX_NDIGITS: + raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") + + def validate_cc_distribution(cc_distribution): if not isinstance(cc_distribution, dict): raise TypeError('cc_distribution should be a dictionary') @@ -235,9 +244,24 @@ def validate_monitor_mbs_grad(monitor_mbs_grad): return monitor_mbs_grad +def validate_squash_name(squash_name): + if not isinstance(squash_name, bool): + raise TypeError('squash_name should be a bool') + + +def validate_append_output(append_output): + if not isinstance(append_output, list): + raise TypeError('append_output should be a list') + if len(append_output) > 0 and len(append_output) != 2: + raise ValueError('append_output should be empty or contain exactly 2 elements') + + def validate_config(config): config['ops'] = validate_ops(config.get('ops', [])) + ndigits = config.get('ndigits') + validate_ndigits(ndigits) + eps = config.get('eps', 1e-8) if not isinstance(eps, float): raise TypeError("eps should be a float") @@ -277,13 +301,18 @@ def validate_config(config): start_step = config.get('start_step', 0) validate_start_step(start_step) - step_interval = config.get('step_interval', 1) validate_step_interval(step_interval) collect_times = config.get('collect_times', int(1e8)) validate_collect_times(collect_times) + squash_name = config.get('squash_name', True) + validate_squash_name(squash_name) + + time_tags = config.get("append_output", []) + validate_append_output(time_tags) + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) dynamic_on = config.get('dynamic_on', False) @@ -293,9 +322,6 @@ def validate_config(config): if xy_distribution: config["all_xy"] = True config["targets"] = {"": {}} - config["is_select"] = False - else: - config["is_select"] = True def time_str2time_digit(time_str): diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 76770747971..22157dae983 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -231,6 +231,13 @@ def validate_monitor_mbs_grad(monitor_mbs_grad): return monitor_mbs_grad +def validate_append_output(append_output): + if not isinstance(append_output, list): + raise TypeError('append_output should be a list') + if len(append_output) > 0 and len(append_output) != 2: + raise ValueError('append_output should be empty or contain exactly 2 elements') + + def validate_config(config): config['ops'] = validate_ops(config.get('ops', [])) @@ -285,6 +292,9 @@ def validate_config(config): squash_name = config.get('squash_name', True) validate_squash_name(squash_name) + time_tags = config.get("append_output", []) + validate_append_output(time_tags) + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) dynamic_on = config.get('dynamic_on', False) -- Gitee From 2924a778541a418bae86c76e8e7fe955adfec3e9 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 2 Jul 2025 11:43:19 +0800 Subject: [PATCH 2/4] bugfix --- .../msprobe/core/monitor/utils.py | 343 ++++++++++++++++++ .../msprobe/mindspore/monitor/common_func.py | 2 +- .../msprobe/mindspore/monitor/module_hook.py | 5 +- .../msprobe/mindspore/monitor/utils.py | 278 -------------- .../msprobe/pytorch/monitor/csv2tb.py | 3 +- .../msprobe/pytorch/monitor/module_hook.py | 5 +- .../pytorch/monitor/optimizer_collect.py | 2 +- .../msprobe/pytorch/monitor/utils.py | 329 +---------------- 8 files changed, 354 insertions(+), 613 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/core/monitor/utils.py diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py new file mode 100644 index 00000000000..24cdbaa1ba4 --- /dev/null +++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py @@ -0,0 +1,343 @@ +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from datetime import timezone, timedelta +from functools import wraps +from datetime import datetime +import os +import re + +from msprobe.core.common.const import MonitorConst +from msprobe.core.common.log import logger +from msprobe.core.common.utils import is_int +from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod + + +NAN_TENSOR_ON_DEVICE = None +FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024 +FILE_NAME_MAX_LENGTH = 255 +DIRECTORY_MAX_LENGTH = 4096 + +beijing_tz = timezone(timedelta(hours=8)) +MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) + + +class MsgConst: + """ + Class for log messages const + """ + SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] + + +def get_output_base_dir(): + return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) + + +def filter_special_chars(func): + @wraps(func) + def func_level(msg): + for char in MsgConst.SPECIAL_CHAR: + msg = msg.replace(char, '_') + return func(msg) + + return func_level + + +def validate_ops(ops): + if not isinstance(ops, list): + raise TypeError("ops should be a list") + valid_ops = [] + for op in ops: + if op not in MonitorConst.OP_LIST: + logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") + continue + valid_ops.append(op) + if not valid_ops: + default_op = MonitorConst.OP_LIST[0] + valid_ops.append(default_op) + logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") + # 增加默认shape和dtype参数 + if "shape" not in valid_ops: + valid_ops.append("shape") + if "dtype" not in valid_ops: + valid_ops.append("dtype") + return valid_ops + + +def validate_ndigits(ndigits): + if not ndigits: + return + if not is_int(ndigits) or ndigits <= 0: + raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") + if ndigits > MonitorConst.MAX_NDIGITS: + raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") + + +def validate_ranks(ranks): + if not isinstance(ranks, list): + raise TypeError("module_ranks should be a list") + for rank in ranks: + if not isinstance(rank, int) or isinstance(rank, bool): + raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") + + +def validate_targets(targets): + if not isinstance(targets, dict): + raise TypeError('targets in config.json should be a dict') + for module_name, field in targets.items(): + if not isinstance(module_name, str): + raise TypeError('key of targets should be module_name[str] in config.json') + if not isinstance(field, dict): + raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') + + +def validate_print_struct(print_struct): + if not isinstance(print_struct, bool): + raise TypeError("print_struct should be a bool") + + +def validate_ur_distribution(ur_distribution): + if not isinstance(ur_distribution, bool): + raise TypeError('ur_distribution should be a bool') + + +def validate_xy_distribution(xy_distribution): + if not isinstance(xy_distribution, bool): + raise TypeError('xy_distribution should be a bool') + + +def validate_wg_distribution(wg_distribution): + if not isinstance(wg_distribution, bool): + raise TypeError('wg_distribution should be a bool') + + +def validate_mg_distribution(mg_distribution): + if not isinstance(mg_distribution, bool): + raise TypeError('mg_distribution should be a bool') + + +def validate_param_distribution(param_distribution): + if not isinstance(param_distribution, bool): + raise TypeError('param_distribution should be a bool') + + +def validate_cc_distribution(cc_distribution): + if not isinstance(cc_distribution, dict): + raise TypeError('cc_distribution should be a dictionary') + for key, value in cc_distribution.items(): + if key == 'enable': + if not isinstance(value, bool): + raise TypeError('cc_distribution enable should be a bool') + elif key == 'cc_codeline': + if not isinstance(value, list): + raise TypeError('cc_distribution cc_codeline should be a list') + elif key == 'cc_pre_hook': + if not isinstance(value, bool): + raise TypeError('cc_distribution cc_pre_hook should be a bool') + elif key == 'cc_log_only': + if not isinstance(value, bool): + raise TypeError('cc_distribution cc_log_only should be a bool') + else: + raise TypeError(f'{key} of cc_distribution is not supported.') + + +def validate_squash_name(squash_name): + if not isinstance(squash_name, bool): + raise TypeError('squash_name should be a bool') + + +def validate_alert(alert): + if not isinstance(alert, dict): + raise TypeError('alert should be a dictionary') + rules = alert.get('rules') + if rules and isinstance(rules, list): + for rule in rules: + rule_name = rule.get("rule_name") + if rule_name and rule_name not in MonitorConst.RULE_NAME: + raise TypeError(f"{rule_name} is not supported") + args = rule.get("args") + if args and isinstance(args, dict): + threshold = args.get("threshold") + if not isinstance(threshold, (float, int)) or threshold < 0: + raise TypeError('threshold must be float and not less than 0') + dump = alert.get('dump') + if dump and not isinstance(dump, bool): + raise TypeError('dump must be bool.') + + +def validate_step_count_per_record(step_count_per_record): + if not is_int(step_count_per_record): + raise TypeError('step_count_per_record must be int.') + if step_count_per_record < 1: + raise ValueError("step_count_per_record must greater than 0") + if step_count_per_record > 1e6: + raise ValueError("step_count_per_record must smaller than 1e6") + + +def validate_dynamic_on(dynamic_on): + if not isinstance(dynamic_on, bool): + raise TypeError('dynamic_on should be a bool') + + +def validate_monitor_mbs_grad(monitor_mbs_grad): + if not isinstance(monitor_mbs_grad, bool): + logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') + return False + return monitor_mbs_grad + + +def validate_append_output(append_output): + if not isinstance(append_output, list): + raise TypeError('append_output should be a list') + if len(append_output) > 0 and len(append_output) != 2: + raise ValueError('append_output should be empty or contain exactly 2 elements') + + +def validate_config(config): + config['ops'] = validate_ops(config.get('ops', [])) + + ndigits = config.get('ndigits') + validate_ndigits(ndigits) + + eps = config.get('eps', 1e-8) + if not isinstance(eps, float): + raise TypeError("eps should be a float") + + ranks = config.get("module_ranks", []) + validate_ranks(ranks) + + targets = config.get("targets", {}) + validate_targets(targets) + + print_struct = config.get('print_struct', False) + validate_print_struct(print_struct) + + ur_distribution = config.get('ur_distribution', False) + validate_ur_distribution(ur_distribution) + + xy_distribution = config.get('xy_distribution', False) + validate_xy_distribution(xy_distribution) + + wg_distribution = config.get('wg_distribution', False) + validate_wg_distribution(wg_distribution) + + mg_distribution = config.get('mg_distribution', False) + validate_mg_distribution(mg_distribution) + + param_distribution = config.get('param_distribution', False) + validate_param_distribution(param_distribution) + + cc_distribution = config.get('cc_distribution', {}) + validate_cc_distribution(cc_distribution) + + alert = config.get('alert', {}) + validate_alert(alert) + + step_count_per_record = config.get('step_count_per_record', 1) + validate_step_count_per_record(step_count_per_record) + + config["start_step"] = validate_int_arg(config.get("start_step"), "start_step", + MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP) + config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times", + MonitorConst.DEFAULT_MIN_COLLECT_TIMES, + MonitorConst.DEFAULT_MAX_COLLECT_TIMES) + config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval", + MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL) + + squash_name = config.get('squash_name', True) + validate_squash_name(squash_name) + + time_tags = config.get("append_output", []) + validate_append_output(time_tags) + + config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) + + dynamic_on = config.get('dynamic_on', False) + validate_dynamic_on(dynamic_on) + + if not targets: + if xy_distribution: + config["all_xy"] = True + config["targets"] = {"": {}} + + +def time_str2time_digit(time_str): + time_format = '%b%d_%H-%M-%S' + if not isinstance(time_str, str): + raise TypeError(f"time_str:{time_str} should be a str") + try: + time_digit = datetime.strptime(time_str, time_format) + except Exception as e: + raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ + of existing output dirpath, like 'Dec03_21-34-40'.") from e + return time_digit + + +def get_target_output_dir(monitor_path, time_start, time_end): + check_file_or_directory_path(monitor_path, isdir=True) + time_start = time_str2time_digit(time_start) if time_start is not None else time_start + time_end = time_str2time_digit(time_end) if time_end is not None else time_end + if time_start and time_end and time_start > time_end: + raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") + result = {} + for dirname in os.listdir(monitor_path): + match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) + if not match: + continue + time_tag = match.group(1) + rank = match.group(2) + target_time = time_str2time_digit(time_tag) + start_ok = time_start is None or target_time >= time_start + end_ok = time_end is None or target_time <= time_end + if start_ok and end_ok: + result[rank] = os.path.join(monitor_path, dirname) + return result + + +def chmod_tensorboard_dir(path): + """ + format配置为tensorboard时,需要补充文件权限设置 + """ + try: + recursive_chmod(path) + except Exception as e: + logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!") + + +def validate_set_monitor(grad_acc_steps, start_iteration): + """ + validate parameters of set_monitor. + """ + grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps", + MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS) + + start_iteration = validate_int_arg(start_iteration, "start_iteration", + MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION) + return grad_acc_steps, start_iteration + + +def validate_int_arg(value, name, minimum, default_value): + """Validate int args, if any exception occurs, use the default value.""" + if value is None: + return default_value + try: + if not is_int(value): + raise TypeError(f"{name} must be int") + if value < minimum: + raise ValueError(f"{name} must greater than {minimum}") + except Exception as e: + value = default_value + logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.") + return value diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py index 52db10d0546..5880d2284f3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/common_func.py @@ -16,7 +16,7 @@ from mindspore import nn from mindspore import communication -from msprobe.mindspore.monitor.utils import logger +from msprobe.core.common.log import logger from msprobe.mindspore.common.utils import is_mindtorch if is_mindtorch(): import torch diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 1e22ec14cea..0d74fca7808 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -28,11 +28,12 @@ from mindspore import nn, _no_grad from msprobe.core.common.log import logger from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.monitor.utils import validate_config, get_output_base_dir, get_target_output_dir from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter from msprobe.mindspore.common.utils import is_mindtorch from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank -from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \ - is_skip_step, get_metrics, get_target_output_dir +from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, step_accumulates_one, is_skip_step, \ + get_metrics from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py index 5baf3761629..a6cc09fb3c1 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/utils.py @@ -12,16 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import re -from datetime import datetime from mindspore import dtype as mstype, Tensor from msprobe.mindspore.monitor.features import FUNC_MAP -from msprobe.core.common.const import MonitorConst -from msprobe.core.common.utils import is_int -from msprobe.core.common.log import logger -from msprobe.core.common.file_utils import check_file_or_directory_path def get_single_metrics(op_list, tag, tensor, eps=1e-8, output=None): @@ -82,274 +75,3 @@ def is_skip_step(step, start_step, step_interval, has_collect_times=0, collect_t :return: whether skip or not, bool """ return step < start_step or (step - start_step) % step_interval != 0 or has_collect_times >= collect_times - - -def validate_ops(ops): - if not isinstance(ops, list): - raise TypeError("ops should be a list") - valid_ops = [] - for op in ops: - if op not in MonitorConst.OP_LIST: - logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") - continue - valid_ops.append(op) - if not valid_ops: - default_op = MonitorConst.OP_LIST[0] - valid_ops.append(default_op) - logger.info(f"There is no valid ops, default op {default_op} is used") - # 增加默认shape和dtype参数 - if "shape" not in valid_ops: - valid_ops.append("shape") - if "dtype" not in valid_ops: - valid_ops.append("dtype") - return valid_ops - - -def validate_ranks(ranks): - if not isinstance(ranks, list): - raise TypeError("module_ranks should be a list") - for rank in ranks: - if not isinstance(rank, int): - raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") - - -def validate_targets(targets): - if not isinstance(targets, dict): - raise TypeError('targets in config.json should be a dict') - for module_name, field in targets.items(): - if not isinstance(module_name, str): - raise TypeError('key of targets should be module_name[str] in config.json') - if not isinstance(field, dict): - raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') - - -def validate_print_struct(print_struct): - if not isinstance(print_struct, bool): - raise TypeError("print_struct should be a bool") - - -def validate_ur_distribution(ur_distribution): - if not isinstance(ur_distribution, bool): - raise TypeError('ur_distribution should be a bool') - - -def validate_xy_distribution(xy_distribution): - if not isinstance(xy_distribution, bool): - raise TypeError('xy_distribution should be a bool') - - -def validate_wg_distribution(wg_distribution): - if not isinstance(wg_distribution, bool): - raise TypeError('wg_distribution should be a bool') - - -def validate_mg_distribution(mg_distribution): - if not isinstance(mg_distribution, bool): - raise TypeError('mg_distribution should be a bool') - - -def validate_param_distribution(param_distribution): - if not isinstance(param_distribution, bool): - raise TypeError('param_distribution should be a bool') - - -def validate_ndigits(ndigits): - if not ndigits: - return - if not is_int(ndigits) or ndigits <= 0: - raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") - if ndigits > MonitorConst.MAX_NDIGITS: - raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") - - -def validate_cc_distribution(cc_distribution): - if not isinstance(cc_distribution, dict): - raise TypeError('cc_distribution should be a dictionary') - expected_keys = { - 'enable': bool, - 'cc_codeline': list, - 'cc_pre_hook': bool, - 'cc_log_only': bool - } - for key, value in cc_distribution.items(): - if key in expected_keys: - if not isinstance(value, expected_keys[key]): - raise TypeError(f'cc_distribution {key} should be a {expected_keys[key].__name__}') - else: - raise TypeError(f'{key} of cc_distribution is not supported.') - - -def validate_alert(alert): - if not isinstance(alert, dict): - raise TypeError('alert should be a dictionary') - rules = alert.get('rules') - if rules and isinstance(rules, list): - for rule in rules: - rule_name = rule.get("rule_name") - if rule_name and rule_name not in MonitorConst.RULE_NAME: - raise TypeError(f"{rule_name} is not supported") - args = rule.get("args") - if args and isinstance(args, dict): - threshold = args.get("threshold") - if not isinstance(threshold, (float, int)) or threshold < 0: - raise TypeError('threshold must be float and not less than 0') - dump = alert.get('dump') - if dump and not isinstance(dump, bool): - raise TypeError('dump must be bool.') - - -def validate_step_count_per_record(step_count_per_record): - if not is_int(step_count_per_record): - raise TypeError('step_count_per_record must be int.') - if step_count_per_record < 1: - raise ValueError("step_count_per_record must greater than 0") - if step_count_per_record > 1e6: - raise ValueError("step_count_per_record must smaller than 1e6") - - -def validate_start_step(start_step): - if not is_int(start_step): - raise TypeError('start_step must be int.') - if start_step < 0: - raise ValueError("start_step must greater than 0") - if start_step > 1e8: - raise ValueError("start_step must smaller than 1e8") - - -def validate_step_interval(step_interval): - if not is_int(step_interval): - raise TypeError('step_interval must be int.') - if step_interval < 1: - raise ValueError("step_interval must greater than 1") - if step_interval > 1e8: - raise ValueError("step_interval must smaller than 1e8") - - -def validate_collect_times(collect_times): - if not is_int(collect_times): - raise TypeError('collect_times must be int.') - if collect_times < 1: - raise ValueError("collect_times must greater than 1") - - -def validate_dynamic_on(dynamic_on): - if not isinstance(dynamic_on, bool): - raise TypeError('dynamic_on should be a bool') - - -def validate_monitor_mbs_grad(monitor_mbs_grad): - if not isinstance(monitor_mbs_grad, bool): - logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') - return False - return monitor_mbs_grad - - -def validate_squash_name(squash_name): - if not isinstance(squash_name, bool): - raise TypeError('squash_name should be a bool') - - -def validate_append_output(append_output): - if not isinstance(append_output, list): - raise TypeError('append_output should be a list') - if len(append_output) > 0 and len(append_output) != 2: - raise ValueError('append_output should be empty or contain exactly 2 elements') - - -def validate_config(config): - config['ops'] = validate_ops(config.get('ops', [])) - - ndigits = config.get('ndigits') - validate_ndigits(ndigits) - - eps = config.get('eps', 1e-8) - if not isinstance(eps, float): - raise TypeError("eps should be a float") - - ranks = config.get("module_ranks", []) - validate_ranks(ranks) - - targets = config.get("targets", {}) - validate_targets(targets) - - print_struct = config.get('print_struct', False) - validate_print_struct(print_struct) - - ur_distribution = config.get('ur_distribution', False) - validate_ur_distribution(ur_distribution) - - xy_distribution = config.get('xy_distribution', False) - validate_xy_distribution(xy_distribution) - - wg_distribution = config.get('wg_distribution', False) - validate_wg_distribution(wg_distribution) - - mg_distribution = config.get('mg_distribution', False) - validate_mg_distribution(mg_distribution) - - param_distribution = config.get('param_distribution', False) - validate_param_distribution(param_distribution) - - cc_distribution = config.get('cc_distribution', {}) - validate_cc_distribution(cc_distribution) - - alert = config.get('alert', {}) - validate_alert(alert) - - step_count_per_record = config.get('step_count_per_record', 1) - validate_step_count_per_record(step_count_per_record) - - start_step = config.get('start_step', 0) - validate_start_step(start_step) - step_interval = config.get('step_interval', 1) - validate_step_interval(step_interval) - - collect_times = config.get('collect_times', int(1e8)) - validate_collect_times(collect_times) - - squash_name = config.get('squash_name', True) - validate_squash_name(squash_name) - - time_tags = config.get("append_output", []) - validate_append_output(time_tags) - - config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) - - dynamic_on = config.get('dynamic_on', False) - validate_dynamic_on(dynamic_on) - - if not targets: - if xy_distribution: - config["all_xy"] = True - config["targets"] = {"": {}} - - -def time_str2time_digit(time_str): - time_format = '%b%d_%H-%M-%S' - try: - time_digit = datetime.strptime(time_str, time_format) - except Exception as e: - raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ - of existing output dirpath, like 'Dec03_21-34-40'.") from e - return time_digit - - -def get_target_output_dir(monitor_path, time_start, time_end): - check_file_or_directory_path(monitor_path, isdir=True) - time_start = time_str2time_digit(time_start) if time_start is not None else time_start - time_end = time_str2time_digit(time_end) if time_end is not None else time_end - if time_start and time_end and time_start > time_end: - raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") - result = {} - for dirname in os.listdir(monitor_path): - match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) - if not match: - continue - time_tag = match.group(1) - rank = match.group(2) - target_time = time_str2time_digit(time_tag) - start_ok = time_start is None or target_time >= time_start - end_ok = time_end is None or target_time <= time_end - if start_ok and end_ok: - result[rank] = os.path.join(monitor_path, dirname) - return result diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py index d2a6d95cb3e..a807f7e2798 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/csv2tb.py @@ -25,8 +25,9 @@ from msprobe.core.common.const import MonitorConst from msprobe.core.common.file_utils import read_csv, create_directory, remove_path, recursive_chmod from msprobe.core.common.utils import check_process_num from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.monitor.utils import get_target_output_dir from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import get_target_output_dir + all_data_type_list = [ "actv", "actv_grad", "exp_avg", "exp_avg_sq", diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 042ccf651c9..0158b371375 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -31,8 +31,11 @@ from msprobe.core.common.decorator import recursion_depth_decorator from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter from msprobe.core.common.file_utils import write_df_to_csv from msprobe.core.common.utils import analyze_api_call_stack +from msprobe.core.monitor.utils import validate_config, validate_ops, \ + get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor +from msprobe.pytorch.monitor.utils import get_param_struct from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ get_process_group @@ -40,8 +43,6 @@ from msprobe.pytorch.monitor.features import get_sign_matches from msprobe.pytorch.monitor.module_metric import get_metrics, get_summary_writer_tag_name, \ TensorMetrics, squash_param_name from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory -from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \ - get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py index 08827fef178..8a63eaef9c3 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/optimizer_collect.py @@ -17,7 +17,7 @@ from abc import abstractmethod import torch from msprobe.pytorch.common.log import logger -from msprobe.pytorch.monitor.utils import MVResult +from msprobe.core.monitor.utils import MVResult from msprobe.core.common.const import MonitorConst diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 22157dae983..31cdd53f42b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -12,20 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from collections import namedtuple -from datetime import timezone, timedelta -from functools import wraps -from datetime import datetime -import os -import re - import torch -from msprobe.core.common.const import MonitorConst from msprobe.pytorch.common.log import logger -from msprobe.core.common.utils import is_int -from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod device = "cpu" @@ -37,24 +26,6 @@ except ImportError: device = "cuda" NAN_TENSOR_ON_DEVICE = None -FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024 -FILE_NAME_MAX_LENGTH = 255 -DIRECTORY_MAX_LENGTH = 4096 - -beijing_tz = timezone(timedelta(hours=8)) -MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) - - -class MsgConst: - """ - Class for log messages const - """ - SPECIAL_CHAR = ["\n", "\r", "\u007F", "\b", "\f", "\t", "\u000B", "%08", "%0a", "%0b", "%0c", "%0d", "%7f"] - - -def get_output_base_dir(): - return os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) - def get_nan_tensor(): global NAN_TENSOR_ON_DEVICE @@ -63,16 +34,6 @@ def get_nan_tensor(): return NAN_TENSOR_ON_DEVICE -def filter_special_chars(func): - @wraps(func) - def func_level(msg): - for char in MsgConst.SPECIAL_CHAR: - msg = msg.replace(char, '_') - return func(msg) - - return func_level - - def get_param_struct(param): res = {} if isinstance(param, (tuple, list)): @@ -85,292 +46,4 @@ def get_param_struct(param): else: res['config'] = f'{type(param)}' logger.warning(f'Not support type({type(param)}) now, please check the type of param {param}') - return res - - -def validate_ops(ops): - if not isinstance(ops, list): - raise TypeError("ops should be a list") - valid_ops = [] - for op in ops: - if op not in MonitorConst.OP_LIST: - logger.warning(f"op {op} is not supported. Optional ops: {MonitorConst.OP_LIST}") - continue - valid_ops.append(op) - if not valid_ops: - default_op = MonitorConst.OP_LIST[0] - valid_ops.append(default_op) - logger.info_on_rank_0(f"There is no valid ops, default op {default_op} is used") - # 增加默认shape和dtype参数 - if "shape" not in valid_ops: - valid_ops.append("shape") - if "dtype" not in valid_ops: - valid_ops.append("dtype") - return valid_ops - - -def validate_ndigits(ndigits): - if not ndigits: - return - if not is_int(ndigits) or ndigits <= 0: - raise ValueError(f"ndigits({ndigits}) is not a positive integer, current is: {ndigits}.") - if ndigits > MonitorConst.MAX_NDIGITS: - raise ValueError(f"The maximum supported ndigits is {MonitorConst.MAX_NDIGITS}, current value: {ndigits}.") - - -def validate_ranks(ranks): - if not isinstance(ranks, list): - raise TypeError("module_ranks should be a list") - for rank in ranks: - if not isinstance(rank, int) or isinstance(rank, bool): - raise TypeError(f"element in module_ranks should be a int, get {type(rank)}") - - -def validate_targets(targets): - if not isinstance(targets, dict): - raise TypeError('targets in config.json should be a dict') - for module_name, field in targets.items(): - if not isinstance(module_name, str): - raise TypeError('key of targets should be module_name[str] in config.json') - if not isinstance(field, dict): - raise TypeError('values of targets should be cared filed e.g. {"input": "tensor"} in config.json') - - -def validate_print_struct(print_struct): - if not isinstance(print_struct, bool): - raise TypeError("print_struct should be a bool") - - -def validate_ur_distribution(ur_distribution): - if not isinstance(ur_distribution, bool): - raise TypeError('ur_distribution should be a bool') - - -def validate_xy_distribution(xy_distribution): - if not isinstance(xy_distribution, bool): - raise TypeError('xy_distribution should be a bool') - - -def validate_wg_distribution(wg_distribution): - if not isinstance(wg_distribution, bool): - raise TypeError('wg_distribution should be a bool') - - -def validate_mg_distribution(mg_distribution): - if not isinstance(mg_distribution, bool): - raise TypeError('mg_distribution should be a bool') - - -def validate_param_distribution(param_distribution): - if not isinstance(param_distribution, bool): - raise TypeError('param_distribution should be a bool') - - -def validate_cc_distribution(cc_distribution): - if not isinstance(cc_distribution, dict): - raise TypeError('cc_distribution should be a dictionary') - for key, value in cc_distribution.items(): - if key == 'enable': - if not isinstance(value, bool): - raise TypeError('cc_distribution enable should be a bool') - elif key == 'cc_codeline': - if not isinstance(value, list): - raise TypeError('cc_distribution cc_codeline should be a list') - elif key == 'cc_pre_hook': - if not isinstance(value, bool): - raise TypeError('cc_distribution cc_pre_hook should be a bool') - elif key == 'cc_log_only': - if not isinstance(value, bool): - raise TypeError('cc_distribution cc_log_only should be a bool') - else: - raise TypeError(f'{key} of cc_distribution is not supported.') - - -def validate_squash_name(squash_name): - if not isinstance(squash_name, bool): - raise TypeError('squash_name should be a bool') - - -def validate_alert(alert): - if not isinstance(alert, dict): - raise TypeError('alert should be a dictionary') - rules = alert.get('rules') - if rules and isinstance(rules, list): - for rule in rules: - rule_name = rule.get("rule_name") - if rule_name and rule_name not in MonitorConst.RULE_NAME: - raise TypeError(f"{rule_name} is not supported") - args = rule.get("args") - if args and isinstance(args, dict): - threshold = args.get("threshold") - if not isinstance(threshold, (float, int)) or threshold < 0: - raise TypeError('threshold must be float and not less than 0') - dump = alert.get('dump') - if dump and not isinstance(dump, bool): - raise TypeError('dump must be bool.') - - -def validate_step_count_per_record(step_count_per_record): - if not is_int(step_count_per_record): - raise TypeError('step_count_per_record must be int.') - if step_count_per_record < 1: - raise ValueError("step_count_per_record must greater than 0") - if step_count_per_record > 1e6: - raise ValueError("step_count_per_record must smaller than 1e6") - - -def validate_dynamic_on(dynamic_on): - if not isinstance(dynamic_on, bool): - raise TypeError('dynamic_on should be a bool') - - -def validate_monitor_mbs_grad(monitor_mbs_grad): - if not isinstance(monitor_mbs_grad, bool): - logger.warning(f'monitor_mbs_grad should be a bool, actual value is {monitor_mbs_grad}.') - return False - return monitor_mbs_grad - - -def validate_append_output(append_output): - if not isinstance(append_output, list): - raise TypeError('append_output should be a list') - if len(append_output) > 0 and len(append_output) != 2: - raise ValueError('append_output should be empty or contain exactly 2 elements') - - -def validate_config(config): - config['ops'] = validate_ops(config.get('ops', [])) - - ndigits = config.get('ndigits') - validate_ndigits(ndigits) - - eps = config.get('eps', 1e-8) - if not isinstance(eps, float): - raise TypeError("eps should be a float") - - ranks = config.get("module_ranks", []) - validate_ranks(ranks) - - targets = config.get("targets", {}) - validate_targets(targets) - - print_struct = config.get('print_struct', False) - validate_print_struct(print_struct) - - ur_distribution = config.get('ur_distribution', False) - validate_ur_distribution(ur_distribution) - - xy_distribution = config.get('xy_distribution', False) - validate_xy_distribution(xy_distribution) - - wg_distribution = config.get('wg_distribution', False) - validate_wg_distribution(wg_distribution) - - mg_distribution = config.get('mg_distribution', False) - validate_mg_distribution(mg_distribution) - - param_distribution = config.get('param_distribution', False) - validate_param_distribution(param_distribution) - - cc_distribution = config.get('cc_distribution', {}) - validate_cc_distribution(cc_distribution) - - alert = config.get('alert', {}) - validate_alert(alert) - - step_count_per_record = config.get('step_count_per_record', 1) - validate_step_count_per_record(step_count_per_record) - - config["start_step"] = validate_int_arg(config.get("start_step"), "start_step", - MonitorConst.DEFAULT_START_STEP, MonitorConst.DEFAULT_START_STEP) - config["collect_times"] = validate_int_arg(config.get("collect_times"), "collect_times", - MonitorConst.DEFAULT_MIN_COLLECT_TIMES, - MonitorConst.DEFAULT_MAX_COLLECT_TIMES) - config["step_interval"] = validate_int_arg(config.get("step_interval"), "step_interval", - MonitorConst.DEFAULT_STEP_INTERVAL, MonitorConst.DEFAULT_STEP_INTERVAL) - - squash_name = config.get('squash_name', True) - validate_squash_name(squash_name) - - time_tags = config.get("append_output", []) - validate_append_output(time_tags) - - config["monitor_mbs_grad"] = validate_monitor_mbs_grad(config.get('monitor_mbs_grad', False)) - - dynamic_on = config.get('dynamic_on', False) - validate_dynamic_on(dynamic_on) - - if not targets: - if xy_distribution: - config["all_xy"] = True - config["targets"] = {"": {}} - - -def time_str2time_digit(time_str): - time_format = '%b%d_%H-%M-%S' - if not isinstance(time_str, str): - raise TypeError(f"time_str:{time_str} should be a str") - try: - time_digit = datetime.strptime(time_str, time_format) - except Exception as e: - raise RuntimeError(f"illegal timestamp: {time_str}, timestamp should be prefix \ - of existing output dirpath, like 'Dec03_21-34-40'.") from e - return time_digit - - -def get_target_output_dir(monitor_path, time_start, time_end): - check_file_or_directory_path(monitor_path, isdir=True) - time_start = time_str2time_digit(time_start) if time_start is not None else time_start - time_end = time_str2time_digit(time_end) if time_end is not None else time_end - if time_start and time_end and time_start > time_end: - raise ValueError(f"time_start({time_start}) greater than time_end({time_end})") - result = {} - for dirname in os.listdir(monitor_path): - match = re.match(MonitorConst.OUTPUT_DIR_PATTERN, dirname) - if not match: - continue - time_tag = match.group(1) - rank = match.group(2) - target_time = time_str2time_digit(time_tag) - start_ok = time_start is None or target_time >= time_start - end_ok = time_end is None or target_time <= time_end - if start_ok and end_ok: - result[rank] = os.path.join(monitor_path, dirname) - return result - - -def chmod_tensorboard_dir(path): - """ - format配置为tensorboard时,需要补充文件权限设置 - """ - try: - recursive_chmod(path) - except Exception as e: - logger.warning(f"chmod tensorboard dir wrong because {e}, not updated, please check!!!") - - -def validate_set_monitor(grad_acc_steps, start_iteration): - """ - validate parameters of set_monitor. - """ - grad_acc_steps = validate_int_arg(grad_acc_steps, "grad_acc_steps", - MonitorConst.DEFAULT_GRAD_ACC_STEPS, MonitorConst.DEFAULT_GRAD_ACC_STEPS) - - start_iteration = validate_int_arg(start_iteration, "start_iteration", - MonitorConst.DEFAULT_START_ITERATION, MonitorConst.DEFAULT_START_ITERATION) - return grad_acc_steps, start_iteration - - -def validate_int_arg(value, name, minimum, default_value): - """Validate int args, if any exception occurs, use the default value.""" - if value is None: - return default_value - try: - if not is_int(value): - raise TypeError(f"{name} must be int") - if value < minimum: - raise ValueError(f"{name} must greater than {minimum}") - except Exception as e: - value = default_value - logger.warning(f"Validate {name} failed, {e}, replaced with default value {value}.") - return value + return res \ No newline at end of file -- Gitee From 0c80904e58d25d04f337eb07f41b81dc3f7fbc42 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 2 Jul 2025 14:39:20 +0800 Subject: [PATCH 3/4] fix monitor UT --- .../test/mindspore_ut/ms_monitor/test_mon_utils.py | 12 ++++++------ .../test/pytorch_ut/monitor/test_monitor_utils.py | 5 +++-- .../pytorch_ut/monitor/test_optimizer_collect.py | 2 +- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py index 9299d37d077..01005fadf56 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_utils.py @@ -82,7 +82,7 @@ class TestMonitorUtils(unittest.TestCase): self.assertTrue(is_skip_step(12, 10, 2, has_collect_times=5, collect_times=5)) def test_validate_ops(self): - from msprobe.mindspore.monitor.utils import validate_ops + from msprobe.core.monitor.utils import validate_ops # 测试输入不是list的情况 with self.assertRaises(TypeError): @@ -105,7 +105,7 @@ class TestMonitorUtils(unittest.TestCase): self.assertIn("dtype", result) def test_validate_ranks(self): - from msprobe.mindspore.monitor.utils import validate_ranks + from msprobe.core.monitor.utils import validate_ranks # 测试输入不是list的情况 with self.assertRaises(TypeError): @@ -122,7 +122,7 @@ class TestMonitorUtils(unittest.TestCase): self.fail(f"validate_ranks raised unexpected exception: {e}") def test_validate_targets(self): - from msprobe.mindspore.monitor.utils import validate_targets + from msprobe.core.monitor.utils import validate_targets # 测试输入不是dict的情况 with self.assertRaises(TypeError): @@ -143,7 +143,7 @@ class TestMonitorUtils(unittest.TestCase): self.fail(f"validate_targets raised unexpected exception: {e}") def test_validate_config(self): - from msprobe.mindspore.monitor.utils import validate_config + from msprobe.core.monitor.utils import validate_config # 测试基本配置验证 config = { @@ -185,7 +185,7 @@ class TestMonitorUtils(unittest.TestCase): validate_config(invalid_config) def test_time_str2time_digit(self): - from msprobe.mindspore.monitor.utils import time_str2time_digit + from msprobe.core.monitor.utils import time_str2time_digit # 测试有效时间字符串 time_str = "Dec03_21-34-40" @@ -201,7 +201,7 @@ class TestMonitorUtils(unittest.TestCase): time_str2time_digit(invalid_time_str) def test_get_target_output_dir(self): - from msprobe.mindspore.monitor.utils import get_target_output_dir + from msprobe.core.monitor.utils import get_target_output_dir # 测试不带时间范围的情况 result = get_target_output_dir(self.temp_dir, None, None) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py index 87822ab0503..83e8217c894 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_monitor_utils.py @@ -5,10 +5,11 @@ from unittest.mock import patch, MagicMock import torch from msprobe.core.common.const import MonitorConst -from msprobe.pytorch.monitor.utils import filter_special_chars, MsgConst, get_param_struct, validate_ops, \ - validate_ranks, validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ +from msprobe.core.monitor.utils import filter_special_chars, MsgConst, validate_ops, validate_ranks, \ + validate_targets, validate_print_struct, validate_ur_distribution, validate_xy_distribution, \ validate_mg_distribution, validate_wg_distribution, validate_cc_distribution, validate_alert, validate_config, \ get_output_base_dir +from msprobe.pytorch.monitor.utils import get_param_struct from msprobe.pytorch.common.utils import is_recomputation diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py index e32e4f860ee..c7cbd86bbcc 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_optimizer_collect.py @@ -9,7 +9,7 @@ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMon, \ MegatronChainedDistributedOptimizerMon, MegatronChainedMixPrecisionOptimizerMon, \ DeepSpeedZeroOptimizerMon, DeepSpeedZeroOptimizerStage0Mon, \ DeepSpeedZeroOptimizerStage1or2Mon, DeepSpeedZeroOptimizerStage3Mon -from msprobe.pytorch.monitor.utils import MVResult +from msprobe.core.monitor.utils import MVResult def setup_param_groups(num_groups=2, params_per_group=5): -- Gitee From 7613ab63ce635b9dd3621d133517209de79c0c19 Mon Sep 17 00:00:00 2001 From: TAJh Date: Wed, 2 Jul 2025 14:43:56 +0800 Subject: [PATCH 4/4] delete unused --- debug/accuracy_tools/msprobe/core/monitor/utils.py | 5 ----- debug/accuracy_tools/msprobe/pytorch/monitor/utils.py | 1 + 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/monitor/utils.py b/debug/accuracy_tools/msprobe/core/monitor/utils.py index 24cdbaa1ba4..f19e14d89e6 100644 --- a/debug/accuracy_tools/msprobe/core/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/core/monitor/utils.py @@ -25,11 +25,6 @@ from msprobe.core.common.utils import is_int from msprobe.core.common.file_utils import check_file_or_directory_path, recursive_chmod -NAN_TENSOR_ON_DEVICE = None -FILE_MAX_SIZE = 10 * 1024 * 1024 * 1024 -FILE_NAME_MAX_LENGTH = 255 -DIRECTORY_MAX_LENGTH = 4096 - beijing_tz = timezone(timedelta(hours=8)) MVResult = namedtuple('MVResult', ("exp_avg", "exp_avg_sq", "update", "ratio")) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py index 31cdd53f42b..ca339ad6482 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/utils.py @@ -27,6 +27,7 @@ except ImportError: NAN_TENSOR_ON_DEVICE = None + def get_nan_tensor(): global NAN_TENSOR_ON_DEVICE if not NAN_TENSOR_ON_DEVICE: -- Gitee