diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 7d749883819e8dc7137e73a53439c2f5f9c024ff..fa9fe97ba2f82036fb789d0d6d39d3be62523819 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -799,5 +799,20 @@ class MonitorConst: CSV = "csv" API = "api" HEADER_NAME = 'name' - MAX_NDIGITS = 20 + + DEFAULT_STAGE = -1 + FORWARD_STAGE = 0 + BACKWARD_STAGE = 1 + OPTIMIZER_STAGE = 2 + FORWARD_KEY = [ACTV] + BACKWARD_KEY = [ACTVGRAD, PRE_GRAD, POST_GRAD, ACC_GRAD] + OPTIMIZER_KEY = [EXP_AVG, EXP_AVG_SQ] + + TRAIN_STAGE = {} + for key in FORWARD_KEY: + TRAIN_STAGE[key] = FORWARD_STAGE + for key in BACKWARD_KEY: + TRAIN_STAGE[key] = BACKWARD_STAGE + for key in OPTIMIZER_KEY: + TRAIN_STAGE[key] = OPTIMIZER_STAGE diff --git a/debug/accuracy_tools/msprobe/core/monitor/__init__.py b/debug/accuracy_tools/msprobe/core/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py similarity index 50% rename from debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py rename to debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py index f1bdaa35ef7dea1471c5d54fbefa513c410126b3..8c50ad761682c05533d525acaa39e6f830cc4e48 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/core/monitor/anomaly_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,18 +12,205 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
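The `TRAIN_STAGE` mapping added to `MonitorConst` above is what `GradAnomalyData.get_train_stage` (introduced later in this patch) uses to turn a metric tag into a training stage. A minimal, self-contained sketch of that lookup — the literal key strings ("actv", "post_grad", ...) are assumptions inferred from the constant names and from the tag examples in the docstrings and tests:

```python
# Illustrative re-implementation of the tag -> train-stage lookup; the real
# keys live on MonitorConst (ACTV, ACTVGRAD, PRE_GRAD, POST_GRAD, EXP_AVG, ...).
DEFAULT_STAGE, FORWARD_STAGE, BACKWARD_STAGE, OPTIMIZER_STAGE = -1, 0, 1, 2

TRAIN_STAGE = {
    "actv": FORWARD_STAGE,
    "actv_grad": BACKWARD_STAGE, "pre_grad": BACKWARD_STAGE,
    "post_grad": BACKWARD_STAGE, "acc_grad": BACKWARD_STAGE,
    "exp_avg": OPTIMIZER_STAGE, "exp_avg_sq": OPTIMIZER_STAGE,
}

def get_train_stage(tag_name: str) -> int:
    # "0:fc1.weight/rank0/post_grad" -> key "post_grad" -> BACKWARD_STAGE
    key = tag_name.split("/")[-1]
    return TRAIN_STAGE.get(key, DEFAULT_STAGE)

assert get_train_stage("0:fc2.input:0/rank0/actv") == FORWARD_STAGE
assert get_train_stage("0:fc1.weight/rank0/post_grad") == BACKWARD_STAGE
assert get_train_stage("0:some.param/rank0/unknown") == DEFAULT_STAGE
```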
- import os import sys +import math import argparse import ast import heapq +from abc import ABC +from dataclasses import dataclass, field +from typing import List -from msprobe.pytorch.common.log import logger from msprobe.core.common.const import MonitorConst +from msprobe.core.common.log import logger from msprobe.core.common.file_utils import save_json, create_directory, remove_path, \ check_file_or_directory_path, load_json -from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData + + +class ScanRule(ABC): + name = "ScanRule" + + def apply(self, cur, history=None): + raise NotImplementedError("abstract method apply is not implemented") + + +class AnomalyTurbulence(ScanRule): + name = "AnomalyTurbulence" + + def __init__(self, threshold) -> None: + self.threshold = threshold + + def apply(self, cur, history=None): + """ + :param cur: float, current metric value + :param history: float, history weighted average + :return: bool, whether the current value deviates from the historical average value of current metric + """ + up_bound = history * (1 + self.threshold) + return abs(cur) > up_bound + + +class AnomalyNan(ScanRule): + name = "AnomalyNan" + + def __init__(self, threshold=None) -> None: + self.threshold = threshold + + def apply(self, cur, history=None): + return math.isnan(cur) or (self.threshold is not None and abs(cur) > self.threshold) + + +class AnomalyScanner: + + @staticmethod + def load_rules(specs: List[dict]): + """ + specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] + """ + if specs is None: + return [] + alert_rules = [] + for spec in specs: + # 使用get方法获取键值,如果键不存在则返回None + rule_cls_name = spec.get("rule_name") + rule_args = spec.get("args") + + # 检查必要的键是否存在 + if rule_cls_name is None or (rule_cls_name == "AnomalyTurbulence" and rule_args is None): + logger.warning(f"Spec is missing required keys: {spec}") + continue + + cur_module = sys.modules.get(__name__) + try: + rule_cls = getattr(cur_module, rule_cls_name) + except AttributeError: + logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") + continue + + try: + rule_instance = rule_cls(**rule_args) if rule_args is not None else rule_cls() + alert_rules.append(rule_instance) + except Exception as e: + logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") + continue + + return alert_rules + + @staticmethod + def scan(scan_rules: List[ScanRule], history, cur): + anomaly = False + for rule in scan_rules: + anomaly = rule.apply(cur, history=history) + if anomaly: + return anomaly, rule.name + return anomaly, None + + +class AnomalyDataFactory(ABC): + def __init__(self, rank, pp_stage, group_mates): + super().__init__() + self.rank = rank + self.pp_stage = pp_stage + self.group_mates = group_mates + self.micro_step = 0 + self.name2callid = {} + + def set_call_id(self, name2callid): + """根据当前GradContext信息更新call_id vpp_stage等信息 + """ + self.name2callid = name2callid + + def create(self, tag, message, step): + """如果检查出异常, 调用当前接口生成GradAnomalyData实例 + tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') + message (str): anomaly detect message + step (int): training step + """ + if not isinstance(tag, tuple) or len(tag) != 2: + raise ValueError("tag must be a tuple with length 2") + tag_name = tag[0] + param_name = tag_name.split('/')[0] + call_id = self.name2callid.get(tag_name, -1) + if MonitorConst.NAME_SEP in param_name: + vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) + else: + vpp_stage = 0 + + return 
GradAnomalyData( + self.rank, + step, + self.micro_step, + self.pp_stage, + vpp_stage, + call_id, + tag_name, + message, + self.group_mates + ) + + +@dataclass(eq=True) +class GradAnomalyData: + rank: int = 0 + step: int = 0 + micro_step: int = 0 + pp_stage: int = 0 + vpp_stage: int = 0 + call_id: int = 0 + tag_name: str = field(default=None, compare=False) + message: str = field(default="", compare=False) + group_mates: list = field(default=None, compare=False) + + def __lt__(self, other): + """ + 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 + 比较规则为: + step 和 micro_step 值越小优先级越高; + vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; + call_id 值越小优先级越高。 + """ + if not isinstance(other, GradAnomalyData): + return NotImplemented + + self_train_stage = self.get_train_stage(self.tag_name) + other_train_stage = self.get_train_stage(other.tag_name) + + def vpp_pp_comparator(anomaly): + """ + Determine the priority rule for vpp and pp based on train stage + Forward stage prefers smaller vpp and pp + Other stages prefer larger vpp and pp + """ + if self_train_stage == MonitorConst.FORWARD_STAGE: + return anomaly.vpp_stage, anomaly.pp_stage + else: + return -anomaly.vpp_stage, -anomaly.pp_stage + + self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] + other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] + return self_cmp < other_cmp + + def __le__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + return self == other or self < other + + @staticmethod + def get_train_stage(tag_name): + """ + :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" + :return: int, if forward return 0; if backward return 1; if optimizer return 2 + """ + key_ = tag_name.split("/")[-1] + return MonitorConst.TRAIN_STAGE.get(key_, MonitorConst.DEFAULT_STAGE) + + def to_dict(self): + return self.__dict__ + + def get_key(self): + # 0:1.self_attention.core_attention_flash_0/rank0/input_grad + return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) class AnomalyDataWriter: @@ -61,11 +248,12 @@ class AnomalyDataWriter: anomalies: GradAnomalyData对象列表 """ anomalies_json = self.get_anomaly_dict(anomalies) - logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") + if anomalies_json: + logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") - data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} - data_to_write.update(anomalies_json) - save_json(self.json_path, data_to_write, indent=1) + data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} + data_to_write.update(anomalies_json) + save_json(self.json_path, data_to_write, indent=1) class AnomalyDataLoader: @@ -140,27 +328,6 @@ class AnomalyAnalyse: save_json(json_path, sorted_data, indent=1) -def _get_parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, - help=" The anomaly detect result dictionary: generate from monitor tool.", - required=True, - ) - parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, - help=" The analyse task result out path.", - required=False, - ) - parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, - help=" Top K number of earliest anomalies.", - required=False, - ) - parser.add_argument("-s", "--step", dest="step_list", default="[]", 
type=str, - help=" Analyse which steps.", - required=False, - ) - return parser.parse_args(sys.argv[1:]) - - def _get_step_and_stop(args): try: step_list = ast.literal_eval(args.step_list) @@ -191,6 +358,27 @@ def _anomaly_analyse(): logger.info(f"{index}: {anomaly.message}") +def _get_parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--data_path", dest="data_path_dir", default="./", type=str, + help=" The anomaly detect result dictionary: generate from monitor tool.", + required=True, + ) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The analyse task result out path.", + required=False, + ) + parser.add_argument("-k", "--topk", dest="top_k_number", default=8, type=int, + help=" Top K number of earliest anomalies.", + required=False, + ) + parser.add_argument("-s", "--step", dest="step_list", default="[]", type=str, + help=" Analyse which steps.", + required=False, + ) + return parser.parse_args(sys.argv[1:]) + + if __name__ == "__main__": _anomaly_analyse() logger.info("Analyse task completed.") diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_analyse.py b/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_analyse.py deleted file mode 100644 index d9331d2ba9e2f8ae16d33a7daa5b0335faf39e9c..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_analyse.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
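Before the now-redundant MindSpore copy of `AnomalyDataWriter` is removed below, note how the relocated CLI above is meant to surface the top-K earliest anomalies: `heapq` is imported by the module and `GradAnomalyData.__lt__` defines "earlier" (smaller step, then micro_step, then stage-aware vpp/pp, then call_id). A hedged sketch of that selection, assuming only the public import path introduced by this patch; the exact code inside `AnomalyAnalyse` is not reproduced here:

```python
import heapq
from msprobe.core.monitor.anomaly_processor import GradAnomalyData

anomalies = [
    GradAnomalyData(step=2, micro_step=0, call_id=3, tag_name="0:fc2.input:0/rank0/actv"),
    GradAnomalyData(step=1, micro_step=1, call_id=0, tag_name="0:fc1.weight/rank0/post_grad"),
    GradAnomalyData(step=1, micro_step=0, call_id=5, tag_name="0:fc2.weight/rank0/exp_avg_sq"),
]

# nsmallest relies on GradAnomalyData.__lt__, so the "earliest" anomalies
# (lowest step, then lowest micro_step, ...) come back in priority order.
earliest_two = heapq.nsmallest(2, anomalies)
assert [(a.step, a.micro_step) for a in earliest_two] == [(1, 0), (1, 1)]
```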
- -import os - -from msprobe.core.common.log import logger -from msprobe.core.common.const import MonitorConst -from msprobe.core.common.file_utils import save_json, create_directory, remove_path, \ - check_file_or_directory_path, load_json - - -class AnomalyDataWriter: - """ - 异常数据写入类,负责将异常数据写入到JSON文件中。 - """ - - def __init__(self, dump_path, rank) -> None: - self.dump_path = dump_path - self.dump_rank_dir = os.path.join(self.dump_path, f"rank{rank}") - self.json_path = os.path.join(self.dump_rank_dir, MonitorConst.ANOMALY_JSON) - - @staticmethod - def get_anomaly_dict(anomalies): - """将GradAnomalyData列表转换为json""" - anomalies_json = {} - for anomaly in anomalies: - anomalies_json.update({anomaly.get_key(): anomaly.to_dict()}) - return anomalies_json - - def init_detected_json(self): - """初始化落盘文件""" - create_directory(self.dump_rank_dir) - - if os.path.exists(self.json_path): - check_file_or_directory_path(self.json_path, isdir=False) - logger.warning(f"The existing file will be deleted: {self.json_path}.") - remove_path(self.json_path) - save_json(self.json_path, {}, indent=1) - - def write_detected_json(self, anomalies): - """ - 落盘异常数据 - Args: - anomalies: GradAnomalyData对象列表 - """ - anomalies_json = self.get_anomaly_dict(anomalies) - logger.info(f"{MonitorConst.ANOMALY_JSON} is at {self.dump_rank_dir}.") - - data_to_write = load_json(self.json_path) if os.path.exists(self.json_path) else {} - data_to_write.update(anomalies_json) - save_json(self.json_path, data_to_write, indent=1) diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py similarity index 54% rename from debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py rename to debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py index 5d551ee70a67acdb3e4ef28bc1672c9bb9534cc8..85c1096123c337a123b16f18236655bfe6e49c5e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/data_writers.py @@ -15,104 +15,20 @@ import itertools import os -import sys -import math -import statistics as st -from abc import ABC -from dataclasses import dataclass, field -from typing import List +from dataclasses import dataclass from collections import defaultdict import pandas as pd - from mindspore import ops from mindspore import Tensor from mindspore import _no_grad + from msprobe.core.common.log import logger from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner from msprobe.core.common.const import FileCheckConst, MonitorConst -class ScanRule(ABC): - name = "ScanRule" - - def apply(self, cur, history=None): - raise NotImplementedError("abstract method apply is not implemented") - - -class AnomalyTurbulence(ScanRule): - name = "AnomalyTurbulence" - - def __init__(self, threshold) -> None: - self.threshold = threshold - - def apply(self, cur, history=None): - """ - :param cur: float, current metric value - :param history: float, history weighted average - :return: bool, whether the current value deviates from the historical average value of current metric - """ - baseline = st.mean(history) if isinstance(history, list) else history - up_bound = baseline * (1 + self.threshold) - return abs(cur) > up_bound - - -class AnomalyNan(ScanRule): - name = "AnomalyNan" - - def __init__(self, threshold=None) -> None: - self.threshold = threshold - 
- def apply(self, cur, history=None): - return math.isnan(cur) or (self.threshold is not None and abs(cur) > self.threshold) - - -class AnomalyScanner: - - @staticmethod - def load_rules(specs: List[dict]): - """ - specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] - """ - if specs is None: - return [] - alert_rules = [] - for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or (rule_cls_name == "AnomalyTurbulence" and rule_args is None): - logger.warning(f"Spec is missing required keys: {spec}") - continue - - cur_module = sys.modules.get(__name__) - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) if rule_args is not None else rule_cls() - alert_rules.append(rule_instance) - except Exception as e: - logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - - return alert_rules - - @staticmethod - def scan(scan_rules: List[ScanRule], history, cur): - anomaly = False - for rule in scan_rules: - anomaly = rule.apply(cur, history=history) - if anomaly: - return anomaly, rule.name - return anomaly, None - - class BCOLORS: HEADER = '\033[95m' OKBLUE = '\033[94m' @@ -125,129 +41,6 @@ class BCOLORS: UNDERLINE = '\033[4m' -class AnomalyDataFactory(ABC): - def __init__(self, rank, pp_stage, group_mates): - super().__init__() - self.rank = rank - self.pp_stage = pp_stage - self.group_mates = group_mates - self.micro_step = 0 - self.name2callid = {} - - def set_call_id(self, name2callid): - """根据当前GradContext信息更新call_id vpp_stage等信息 - """ - self.name2callid = name2callid - - def create(self, tag, message, step): - """如果检查出异常, 调用当前接口生成GradAnomalyData实例 - tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') - message (str): anomaly detect message - step (int): training step - """ - if not isinstance(tag, tuple) or len(tag) != 2: - raise ValueError("tag must be a tuple with length 2") - tag_name = tag[0] - param_name = tag_name.split('/')[0] - call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.NAME_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) - else: - vpp_stage = 0 - - return GradAnomalyData( - self.rank, - step, - self.micro_step, - self.pp_stage, - vpp_stage, - call_id, - tag_name, - message, - self.group_mates - ) - - -class TrainStage: - DEFAULT_STAGE = -1 - FORWARD_STAGE = 0 - BACKWARD_STAGE = 1 - OPTIMIZER_STAGE = 2 - - -FORWARD_KEY = [MonitorConst.ACTV] -BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD, MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] -TRAIN_STAGE = { - **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, - **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, - **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} -} - - -@dataclass(eq=True) -class GradAnomalyData: - rank: int = 0 - step: int = 0 - micro_step: int = 0 - pp_stage: int = 0 - vpp_stage: int = 0 - call_id: int = 0 - tag_name: str = field(default=None, compare=False) - message: str = field(default="", compare=False) - group_mates: list = field(default=None, compare=False) - - def __lt__(self, other): - """ - 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 - 比较规则为: - step 和 micro_step 值越小优先级越高; - vpp 和 pp 
在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; - call_id 值越小优先级越高。 - """ - if not isinstance(other, GradAnomalyData): - return NotImplemented - - self_train_stage = self.get_train_stage(self.tag_name) - other_train_stage = self.get_train_stage(other.tag_name) - - def vpp_pp_comparator(anomaly): - """ - Determine the priority rule for vpp and pp based on train stage - Forward stage prefers smaller vpp and pp - Other stages prefer larger vpp and pp - """ - if self_train_stage == TrainStage.FORWARD_STAGE: - return anomaly.vpp_stage, anomaly.pp_stage - else: - return -anomaly.vpp_stage, -anomaly.pp_stage - - self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] - other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] - return self_cmp < other_cmp - - def __le__(self, other): - if not isinstance(other, GradAnomalyData): - return NotImplemented - return self == other or self < other - - @staticmethod - def get_train_stage(tag_name): - """ - :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" - :return: int, if forward return 0; if backward return 1; if optimizer return 2 - """ - key_ = tag_name.split("/")[-1] - return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) - - def to_dict(self): - return self.__dict__ - - def get_key(self): - # 0:1.self_attention.core_attention_flash_0/rank0/input_grad - return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) - - @dataclass class WriterInput: path: str diff --git a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py index 945b8ef9eccab09e8a46fa315de07df5730cc71a..0354ab533683c2ac7efc44c2914581e980fae4ff 100644 --- a/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/mindspore/monitor/module_hook.py @@ -21,20 +21,20 @@ from datetime import datetime import pytz import pandas as pd +import mindspore from mindspore import Tensor, mint from mindspore import nn, _no_grad from msprobe.core.common.log import logger from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json +from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter from msprobe.mindspore.common.utils import is_mindtorch from msprobe.mindspore.monitor.common_func import is_valid_instance, get_parameters, get_submodules, get_rank from msprobe.mindspore.monitor.utils import get_summary_writer_tag_name, validate_config, step_accumulates_one, \ is_skip_step, get_metrics, get_target_output_dir from msprobe.mindspore.monitor.optimizer_collect import OptimizerMonFactory -from msprobe.mindspore.monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, \ - CSVWriterWithAD, BaseWriterWithAD, WriterInput -from msprobe.mindspore.monitor.anomaly_analyse import AnomalyDataWriter +from msprobe.mindspore.monitor.data_writers import CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.mindspore.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate from msprobe.core.common.file_utils import write_df_to_csv from msprobe.core.common.utils import analyze_api_call_stack @@ -563,9 +563,9 @@ class TrainerMon: v_dict = {} for name, param in get_parameters(common_opt): if MonitorConst.EXP_AVG_SQ in name: - m_dict[name] = param - elif MonitorConst.EXP_AVG in name: v_dict[name] = param + 
elif MonitorConst.EXP_AVG in name: + m_dict[name] = param return m_dict, v_dict def generate_mv_metrics(self, opt_context): @@ -784,10 +784,16 @@ class TrainerMon: step_accumulates_one(context, self.micro_batch_number) return - def fwd_hook_fun_wrapper(fwd_hook_fun, name): - def wrapper(module, args, kwargs, module_output): - return fwd_hook_fun(module, args, kwargs, module_output, name) - return wrapper + def fwd_hook_register(module, fwd_hook_fun, name): + if mindspore.__version__ >= '2.6.0': + def wrapper(module, args, kwargs, module_output): + return fwd_hook_fun(module, args, kwargs, module_output, name) + return module.register_forward_hook(wrapper, with_kwargs=True) + + else: + def wrapper(module, args, module_output): + return fwd_hook_fun(module, args, None, module_output, name) + return module.register_forward_hook(wrapper) def stack_hook(module, args, kwargs, module_output, name): if module not in self.module_fwd_hook_context_by_module: @@ -803,15 +809,14 @@ class TrainerMon: for module_name, submodule in get_submodules(module): if self.stack_info: name = vpp_stage + squash_param_name(module_name) - handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(stack_hook, name=name), with_kwargs=True) + handle = fwd_hook_register(submodule, stack_hook, name=name) self.handles["stack"].append(handle) name = self._is_target_module(module_name, target_names, vpp_stage) if not name: continue if self.xy_distribution or self.print_struct: if not self.backward_only: - handle = submodule.register_forward_hook(fwd_hook_fun_wrapper(fwd_hook_fun, name=name), - with_kwargs=True) + handle = fwd_hook_register(submodule, fwd_hook_fun, name=name) self.handles['xy'].append(handle) if not self.forward_only: handle = submodule.register_backward_hook(bwd_hook_fun) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py similarity index 57% rename from debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py rename to debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py index bcc694ca7e4f132ff435c27e160ef64a5c0a10e3..bd6bde7e9f6ede789f520acc2138492e99bac509 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/data_writers.py @@ -14,13 +14,8 @@ # limitations under the License. 
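Both the MindSpore and PyTorch `data_writers` modules now import the anomaly rules from the consolidated core module (the PyTorch hunk follows). A minimal usage sketch of that shared interface — the spec format is taken from the `load_rules` docstring; the metric values here are purely illustrative:

```python
from msprobe.core.monitor.anomaly_processor import AnomalyScanner

specs = [
    {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}},
    {"rule_name": "AnomalyNan"},  # args are optional for AnomalyNan
]
rules = AnomalyScanner.load_rules(specs)

# history: running average of the metric; cur: the value just collected.
# With threshold=0.5 the turbulence rule flags |cur| > history * 1.5.
detected, rule_name = AnomalyScanner.scan(rules, history=1.0, cur=2.0)
assert detected and rule_name == "AnomalyTurbulence"
```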
import itertools import os -import math -import statistics as st -import sys -from abc import ABC from collections import defaultdict -from dataclasses import dataclass, field -from typing import List +from dataclasses import dataclass import pandas as pd import torch @@ -28,89 +23,10 @@ from torch.utils.tensorboard import SummaryWriter from msprobe.core.common.const import FileCheckConst, MonitorConst from msprobe.core.common.file_utils import change_mode, create_directory, write_df_to_csv +from msprobe.core.monitor.anomaly_processor import AnomalyDataFactory, AnomalyTurbulence, AnomalyScanner from msprobe.pytorch.common.log import logger -class ScanRule(ABC): - name = "ScanRule" - - def apply(self, cur, history=None): - raise NotImplementedError("abstract method apply is not implemented") - - -class AnomalyTurbulence(ScanRule): - name = "AnomalyTurbulence" - - def __init__(self, threshold) -> None: - self.threshold = threshold - - def apply(self, cur, history=None): - """ - :param cur: float, current metric value - :param history: float, history weighted average - :return: bool, whether the current value deviates from the historical average value of current metric - """ - baseline = st.mean(history) if isinstance(history, list) else history - up_bound = baseline * (1 + self.threshold) - return abs(cur) > up_bound - - -class AnomalyNan(ScanRule): - name = "AnomalyNan" - - def __init__(self, threshold=None) -> None: - self.threshold = threshold - - def apply(self, cur, history=None): - return math.isnan(cur) or (self.threshold is not None and abs(cur) > self.threshold) - - -class AnomalyScanner: - - @staticmethod - def load_rules(specs: List[dict]): - """ - specs: [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] - """ - if specs is None: - return [] - alert_rules = [] - for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or (rule_cls_name == "AnomalyTurbulence" and rule_args is None): - logger.warning(f"Spec is missing required keys: {spec}") - continue - - cur_module = sys.modules.get(__name__) - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - logger.error(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) if rule_args is not None else rule_cls() - alert_rules.append(rule_instance) - except Exception as e: - logger.error(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - - return alert_rules - - @staticmethod - def scan(scan_rules: List[ScanRule], history, cur): - anomaly = False - for rule in scan_rules: - anomaly = rule.apply(cur, history=history) - if anomaly: - return anomaly, rule.name - return anomaly, None - - class BCOLORS: HEADER = '\033[95m' OKBLUE = '\033[94m' @@ -123,130 +39,6 @@ class BCOLORS: UNDERLINE = '\033[4m' -class AnomalyDataFactory(ABC): - def __init__(self, rank, pp_stage, group_mates): - super().__init__() - self.rank = rank - self.pp_stage = pp_stage - self.group_mates = group_mates - self.micro_step = 0 - self.name2callid = {} - - def set_call_id(self, name2callid): - """根据当前GradContext信息更新call_id vpp_stage等信息 - """ - self.name2callid = name2callid - - def create(self, tag, message, step): - """如果检查出异常, 调用当前接口生成GradAnomalyData实例 - tag (tuple): metric tag ('0:1.post_attention_norm.weight/rank0/pre_grad', 'min') - message (str): anomaly detect message - step (int): training step - """ - if not isinstance(tag, 
tuple) or len(tag) != 2: - raise ValueError("tag must be a tuple with length 2") - tag_name = tag[0] - param_name = tag_name.split('/')[0] - call_id = self.name2callid.get(tag_name, -1) - if MonitorConst.NAME_SEP in param_name: - vpp_stage = int(param_name.split(MonitorConst.NAME_SEP)[0]) - else: - vpp_stage = 0 - - return GradAnomalyData( - self.rank, - step, - self.micro_step, - self.pp_stage, - vpp_stage, - call_id, - tag_name, - message, - self.group_mates - ) - - -class TrainStage: - DEFAULT_STAGE = -1 - FORWARD_STAGE = 0 - BACKWARD_STAGE = 1 - OPTIMIZER_STAGE = 2 - - -FORWARD_KEY = [MonitorConst.ACTV] -BACKWARD_KEY = [MonitorConst.ACTVGRAD, MonitorConst.PRE_GRAD, - MonitorConst.POST_GRAD, MonitorConst.ACC_GRAD] -OPTIMIZER_KEY = [MonitorConst.EXP_AVG, MonitorConst.EXP_AVG_SQ] -TRAIN_STAGE = { - **{key_: TrainStage.FORWARD_STAGE for key_ in FORWARD_KEY}, - **{key_: TrainStage.BACKWARD_STAGE for key_ in BACKWARD_KEY}, - **{key_: TrainStage.OPTIMIZER_STAGE for key_ in OPTIMIZER_KEY} -} - - -@dataclass(eq=True) -class GradAnomalyData: - rank: int = 0 - step: int = 0 - micro_step: int = 0 - pp_stage: int = 0 - vpp_stage: int = 0 - call_id: int = 0 - tag_name: str = field(default=None, compare=False) - message: str = field(default="", compare=False) - group_mates: list = field(default=None, compare=False) - - def __lt__(self, other): - """ - 自定义比较函数,用于确定 GradAnomalyData 实例之间的顺序。 - 比较规则为: - step 和 micro_step 值越小优先级越高; - vpp 和 pp 在前向阶段值越小优先级越高,在非前向阶段值越大优先级越高; - call_id 值越小优先级越高。 - """ - if not isinstance(other, GradAnomalyData): - return NotImplemented - - self_train_stage = self.get_train_stage(self.tag_name) - other_train_stage = self.get_train_stage(other.tag_name) - - def vpp_pp_comparator(anomaly): - """ - Determine the priority rule for vpp and pp based on train stage - Forward stage prefers smaller vpp and pp - Other stages prefer larger vpp and pp - """ - if self_train_stage == TrainStage.FORWARD_STAGE: - return anomaly.vpp_stage, anomaly.pp_stage - else: - return -anomaly.vpp_stage, -anomaly.pp_stage - - self_cmp = [self.step, self.micro_step, self_train_stage, *vpp_pp_comparator(self), self.call_id] - other_cmp = [other.step, other.micro_step, other_train_stage, *vpp_pp_comparator(other), other.call_id] - return self_cmp < other_cmp - - def __le__(self, other): - if not isinstance(other, GradAnomalyData): - return NotImplemented - return self == other or self < other - - @staticmethod - def get_train_stage(tag_name): - """ - :param tag_name: "0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq" - :return: int, if forward return 0; if backward return 1; if optimizer return 2 - """ - key_ = tag_name.split("/")[-1] - return TRAIN_STAGE.get(key_, TrainStage.DEFAULT_STAGE) - - def to_dict(self): - return self.__dict__ - - def get_key(self): - # 0:1.self_attention.core_attention_flash_0/rank0/input_grad - return ''.join([str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id)]) - - @dataclass class WriterInput: path: str @@ -416,7 +208,6 @@ class CSVWriterWithAD(BaseWriterWithAD): new_line = name.split(MonitorConst.NAME_SEP) + metric_value new_line.insert(2, step) new_data.append(new_line) - new_data = pd.DataFrame(new_data).round(self.ndigits).fillna("nan") write_df_to_csv(new_data, filepath, mode='a+', header=False) self.context_dict = defaultdict(list) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index 
33299b52633e82c5b7a472c59cb0b8ed346461db..f3750d1c94994e8466810a61b2109df45d69ba22 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -28,11 +28,12 @@ from torch.utils.hooks import BackwardHook from msprobe.core.common.const import MonitorConst, Const from msprobe.core.common.file_utils import load_json, save_json from msprobe.core.common.decorator import recursion_depth_decorator +from msprobe.core.monitor.anomaly_processor import AnomalyScanner, AnomalyDataFactory, AnomalyDataWriter +from msprobe.core.common.file_utils import write_df_to_csv +from msprobe.core.common.utils import analyze_api_call_stack from msprobe.pytorch.common.log import logger from msprobe.pytorch.common.utils import is_recomputation, is_float8_tensor -from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter -from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD, AnomalyDataFactory, \ - CSVWriterWithAD, BaseWriterWithAD, WriterInput +from msprobe.pytorch.monitor.data_writers import SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD, WriterInput from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, \ get_process_group from msprobe.pytorch.monitor.features import get_sign_matches @@ -42,8 +43,7 @@ from msprobe.pytorch.monitor.optimizer_collect import OptimizerMonFactory from msprobe.pytorch.monitor.utils import get_param_struct, validate_config, validate_ops, \ get_output_base_dir, get_target_output_dir, chmod_tensorboard_dir, validate_set_monitor from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer -from msprobe.core.common.file_utils import write_df_to_csv -from msprobe.core.common.utils import analyze_api_call_stack + torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' if not torch_version_above_or_equal_2: diff --git a/debug/accuracy_tools/msprobe/test/core_ut/monitor/__init__.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py similarity index 46% rename from debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py rename to debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py index ad4a97acaa9940e807e4023b9745bd210a827501..2511d60caa823366c778761ccb8fb9bca747d2f5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_analyse.py +++ b/debug/accuracy_tools/msprobe/test/core_ut/monitor/test_anomaly_processor.py @@ -1,14 +1,282 @@ import os import unittest +from unittest import TestCase from unittest.mock import patch, MagicMock -from msprobe.pytorch.monitor.anomaly_detect import GradAnomalyData +from msprobe.core.monitor.anomaly_processor import ScanRule, AnomalyTurbulence, AnomalyNan, AnomalyScanner, \ + AnomalyDataFactory, GradAnomalyData, AnomalyDataWriter, AnomalyDataLoader, AnomalyAnalyse, \ + _get_step_and_stop, _anomaly_analyse, _get_parse_args -from msprobe.pytorch.monitor.anomaly_analyse import AnomalyDataWriter, AnomalyDataLoader, AnomalyAnalyse, \ - _get_parse_args, _get_step_and_stop, _anomaly_analyse +class TestScanRule(TestCase): + def test_apply_not_implemented(self): + scan_rule = ScanRule() + with 
self.assertRaises(Exception) as context: + scan_rule.apply(None, None) + + self.assertEqual(str(context.exception), "abstract method apply is not implemented") + + +class TestAnomalyTurbulence(TestCase): + + def setUp(self) -> None: + self.threshold = 0.2 + self.rule = AnomalyTurbulence(self.threshold) + + def test_apply_with_positive_baseline(self): + history = 12 + cur = 16 + result = self.rule.apply(cur, history=history) + self.assertTrue(result) + + def test_apply_with_non_positive_baseline(self): + history = 0 + cur = -1 + result = self.rule.apply(cur, history=history) + self.assertTrue(result) + + def test_apply_with_valid_value(self): + history = 0 + cur = 0 + result = self.rule.apply(cur, history=history) + self.assertFalse(result) + + +class TestAnomalyNan(TestCase): + + def setUp(self) -> None: + self.threshold = 1e10 + self.rule = AnomalyNan(self.threshold) + + def test_apply_with_nan(self): + cur = float("nan") + result = self.rule.apply(cur) + self.assertTrue(result) + + def test_apply_with_big_value(self): + cur = float("1e30") + result = self.rule.apply(cur) + self.assertTrue(result) + + def test_apply_with_valid_value(self): + cur = 0.5 + result = self.rule.apply(cur) + self.assertFalse(result) + + +class TestAnomalyScanner(TestCase): + + def test_load_rules_with_valied_spec(self): + specs = [ + {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.2}} + ] + rules = AnomalyScanner.load_rules(specs) + + self.assertEqual(len(rules), 1) + self.assertIsInstance(rules[0], AnomalyTurbulence) + self.assertEqual(rules[0].threshold, 0.2) + + rules = AnomalyScanner.load_rules(None) + self.assertEqual(len(rules), 0) + + @patch("msprobe.core.monitor.anomaly_processor.logger") + def test_load_rules_with_missing_keys(self, mock_logger): + specs = [ + {"rule_name": "AnomalyTurbulence"} + ] + rules = AnomalyScanner.load_rules(specs) -class TestAnomalyDataWriter(unittest.TestCase): + self.assertEqual(len(rules), 0) + mock_logger.warning.assert_called_once_with(f"Spec is missing required keys: {specs[0]}") + + def test_load_rules_with_invalid_rule(self): + # test invalid rule_name + specs = [{"rule_name": "InvalidRule", "args": {"threshold": 0.2}}] + rules = AnomalyScanner.load_rules(specs) + self.assertEqual(len(rules), 0) + + # test invalid args + specs = [{"rule_name": "AnomalyTurbulence", "args": "invalid args"}] + rules = AnomalyScanner.load_rules(specs) + self.assertEqual(len(rules), 0) + + def test_scan(self): + ad_rules = [AnomalyTurbulence(0.2)] + # test scan with anomaly + expected = True, "AnomalyTurbulence" + self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 2.0), expected) + # test scan with no anomaly + expected = False, None + self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 1.0), expected) + + +class TestAnomalyDataFactory(TestCase): + + def setUp(self) -> None: + rank = 0 + pp_stage = 0 + group_mates = [0] + self.AnomalyDataFactory = AnomalyDataFactory(rank, pp_stage, group_mates) + + def test_set_call_id(self): + name2callid = {'param_name': 0} + self.AnomalyDataFactory.set_call_id(name2callid) + + self.assertEqual(self.AnomalyDataFactory.name2callid, {'param_name': 0}) + + def test_create_success(self): + tag = ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." 
+ step = 2 + result = self.AnomalyDataFactory.create(tag, message, step) + + self.assertEqual(result.step, step) + self.assertEqual(result.tag_name, tag[0]) + self.assertEqual(result.message, message) + self.assertEqual(result.vpp_stage, 0) + + # test no vpp_stage + tag = ('1.self_attention.core_attention_flash_0/rank0/output', 'min') + result = self.AnomalyDataFactory.create(tag, message, step) + self.assertEqual(result.vpp_stage, 0) + + def test_create_failed(self): + error_tag = '0:1.self_attention.core_attention_flash_0/rank0/output' + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2." + step = 2 + with self.assertRaises(Exception) as context: + self.AnomalyDataFactory.create(error_tag, message, step) + self.assertEqual(str(context.exception), "tag must be a tuple with length 2") + + +class TestGradAnomalyData(TestCase): + + def setUp(self) -> None: + tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv" + message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2." + group_mates = [0] + self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates) + + def test_get_train_stage(self): + tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""] + expected_train_stage_list = [0, 1, 2, -1] + for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list): + train_stage = GradAnomalyData.get_train_stage(tag_name) + self.assertEqual(train_stage, expected_train_stage) + + def test_to_dict(self): + expected = { + 'rank': 0, + 'step': 0, + 'micro_step': 0, + 'pp_stage': 0, + 'vpp_stage': 0, + 'call_id': 0, + 'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv", + 'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.", + 'group_mates': [0] + } + + self.assertEqual(self.GradAnomalyData.to_dict(), expected) + + def test_get_key(self): + expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0" + + self.assertEqual(self.GradAnomalyData.get_key(), expected) + + def test_lt_different_step(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=2, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_different_micro_step(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=1, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_different_vpp_stage(self): + # same forward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + # same backward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + 
self.assertLess(data2, data1) + self.assertGreater(data1, data2) + + # diff train stage + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self): + # same forward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + # same backward + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data2, data1) + self.assertGreater(data1, data2) + + # diff train stage + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_step_same_micro_step_same_vpp_stage_same_pp_stage_different_call_id(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=1, tag_name="") + self.assertLess(data1, data2) + self.assertGreater(data2, data1) + + def test_lt_same_data(self): + data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="") + self.assertGreaterEqual(data1, data2) + self.assertLessEqual(data1, data2) + + def test_lt_not_instance(self): + data = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0) + not_instance = "not an instance of GradAnomalyData" + self.assertEqual(data.__lt__(not_instance), NotImplemented) + + def test_le_same_instance(self): + # 测试相同实例的情况 + data1 = GradAnomalyData() + self.assertTrue(data1 <= data1) + + def test_le_different_instance(self): + # 测试不同实例的情况 + data1 = GradAnomalyData() + data2 = GradAnomalyData() + self.assertTrue(data1 <= data2) + + def test_le_not_instance(self): + # 测试非GradAnomalyData实例的情况 + data = GradAnomalyData() + not_instance = "Not an instance of GradAnomalyData" + self.assertEqual(data.__le__(not_instance), NotImplemented) + + def test_le_different_instance_not_equal(self): + # 测试不同实例且不相等的情况 + data1 = GradAnomalyData() + data2 = GradAnomalyData() + data2.some_attribute = "some value" + self.assertTrue(data1 <= data2) + + +class TestAnomalyDataWriter(TestCase): def test_get_anomaly_dict(self): # 测试 get_anomaly_dict 方法 @@ -29,9 +297,9 @@ class TestAnomalyDataWriter(unittest.TestCase): } self.assertEqual(result, expected) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.create_directory') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.create_directory') + @patch('msprobe.core.monitor.anomaly_processor.save_json') def test_init_detected_json(self, 
mock_save_json, mock_create_directory, mock_exists): # 模拟路径检查 mock_exists.side_effect = [False, False, False] # dump_path, dump_rank_dir, json_path @@ -47,10 +315,10 @@ class TestAnomalyDataWriter(unittest.TestCase): # 检查是否初始化了 JSON 文件 mock_save_json.assert_called_once_with(writer.json_path, {}, indent=1) - @patch('msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.remove_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.check_file_or_directory_path') + @patch('msprobe.core.monitor.anomaly_processor.remove_path') + @patch('msprobe.core.monitor.anomaly_processor.save_json') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_init_detected_json_existing_file(self, mock_logger, mock_save_json, mock_remove_path, mock_check_path): # 设置测试参数 dump_path = 'test/dump_path' @@ -71,9 +339,9 @@ class TestAnomalyDataWriter(unittest.TestCase): mock_logger.warning.assert_called_once_with(f"The existing file will be deleted: {writer.json_path}.") mock_save_json.assert_called_once_with(writer.json_path, {}, indent=1) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.load_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.load_json') + @patch('msprobe.core.monitor.anomaly_processor.save_json') def test_write_detected_json(self, mock_save_json, mock_load_json, mock_exists): mock_exists.side_effect = [True, True] # json_path 存在 @@ -100,9 +368,9 @@ class TestAnomalyDataWriter(unittest.TestCase): mock_save_json.assert_called_once_with(writer.json_path, expected_data, indent=1) -class TestAnomalyDataLoader(unittest.TestCase): +class TestAnomalyDataLoader(TestCase): - @patch('msprobe.pytorch.monitor.anomaly_analyse.GradAnomalyData') # 替换为 GradAnomalyData 的实际导入路径 + @patch('msprobe.core.monitor.anomaly_processor.GradAnomalyData') # 替换为 GradAnomalyData 的实际导入路径 def test_create_instances_from_dict(self, mock_GradAnomalyData): # 模拟 GradAnomalyData 的构造函数 def mock_constructor(**kwargs): @@ -121,11 +389,11 @@ class TestAnomalyDataLoader(unittest.TestCase): # 确保创建了两个实例,第三个因缺少 key2 被捕获 self.assertEqual(len(instances), 2) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.listdir') - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.load_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path') - @patch('msprobe.pytorch.monitor.anomaly_analyse.GradAnomalyData') + @patch('msprobe.core.monitor.anomaly_processor.os.listdir') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.load_json') + @patch('msprobe.core.monitor.anomaly_processor.check_file_or_directory_path') + @patch('msprobe.core.monitor.anomaly_processor.GradAnomalyData') def test_get_anomalies_from_jsons(self, mock_GradAnomalyData, mock_check_path, mock_load_json, mock_exists, mock_listdir): mock_check_path.return_value = None @@ -145,7 +413,7 @@ class TestAnomalyDataLoader(unittest.TestCase): mock_GradAnomalyData.side_effect = mock_constructor # 假设构造成功 loader = AnomalyDataLoader('/tmp/data') - with patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.isdir', return_value=True): + with 
patch('msprobe.core.monitor.anomaly_processor.os.path.isdir', return_value=True): anomalies = loader.get_anomalies_from_jsons() # 确保从 rank0 读取了异常数据 @@ -154,7 +422,7 @@ class TestAnomalyDataLoader(unittest.TestCase): mock_load_json.assert_called_once_with('/tmp/data/rank0/anomaly.json') -class TestAnomalyAnalyse(unittest.TestCase): +class TestAnomalyAnalyse(TestCase): def setUp(self): self.anomaly_analyse = AnomalyAnalyse() @@ -188,10 +456,10 @@ class TestAnomalyAnalyse(unittest.TestCase): self.assertEqual(len(result), 3) self.assertEqual(result, [anomalies[1], anomalies[0], anomalies[2]]) - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyDataWriter.get_anomaly_dict') - @patch('msprobe.pytorch.monitor.anomaly_analyse.save_json') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.AnomalyDataWriter.get_anomaly_dict') + @patch('msprobe.core.monitor.anomaly_processor.save_json') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_rewrite_sorted_anomalies(self, mock_logger, mock_save_json, mock_get_anomaly_dict, mock_exists): # 设置 mock mock_exists.return_value = False @@ -201,7 +469,7 @@ class TestAnomalyAnalyse(unittest.TestCase): # 调用方法 self.anomaly_analyse.sorted_anomalies = self.anomalies - with patch("msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path", return_value=None): + with patch("msprobe.core.monitor.anomaly_processor.check_file_or_directory_path", return_value=None): self.anomaly_analyse.rewrite_sorted_anomalies(output_path) # 验证调用 @@ -213,17 +481,17 @@ class TestAnomalyAnalyse(unittest.TestCase): ) mock_logger.info.assert_called_once_with("anomaly_analyse.json is at output_path.") - @patch('msprobe.pytorch.monitor.anomaly_analyse.os.path.exists') - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') + @patch('msprobe.core.monitor.anomaly_processor.os.path.exists') + @patch('msprobe.core.monitor.anomaly_processor.logger') def test_rewrite_sorted_anomalies_file_exists(self, mock_logger, mock_exists): # 模拟文件已经存在的情况 mock_exists.return_value = True output_path = 'output_path' # 调用方法 - with patch("msprobe.pytorch.monitor.anomaly_analyse.check_file_or_directory_path", return_value=None), \ - patch("msprobe.pytorch.monitor.anomaly_analyse.remove_path", return_value=None), \ - patch("msprobe.pytorch.monitor.anomaly_analyse.save_json", return_value=None): + with patch("msprobe.core.monitor.anomaly_processor.check_file_or_directory_path", return_value=None), \ + patch("msprobe.core.monitor.anomaly_processor.remove_path", return_value=None), \ + patch("msprobe.core.monitor.anomaly_processor.save_json", return_value=None): self.anomaly_analyse.rewrite_sorted_anomalies(output_path) # 验证日志警告 @@ -231,35 +499,7 @@ class TestAnomalyAnalyse(unittest.TestCase): f"The existing file will be deleted: output_path/anomaly_analyse.json.") -class TestParseArgs(unittest.TestCase): - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', - new=['script_name', '-d', 'path/to/data', '-o', 'path/to/output', '-k', '5', '-s', '[1,2,3]']) - def test_parse_args_with_all_arguments(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, 'path/to/output') - self.assertEqual(args.top_k_number, 5) - self.assertEqual(args.step_list, '[1,2,3]') - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', 
new=['script_name', '-d', 'path/to/data']) - def test_parse_args_with_required_argument_only(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, '') - self.assertEqual(args.top_k_number, 8) # 默认值 - self.assertEqual(args.step_list, '[]') # 默认值 - - @patch('msprobe.pytorch.monitor.anomaly_analyse.sys.argv', new=['script_name', '-d', 'path/to/data', '-k', '10']) - def test_parse_args_with_topk_only(self): - args = _get_parse_args() - self.assertEqual(args.data_path_dir, 'path/to/data') - self.assertEqual(args.out_path, '') - self.assertEqual(args.top_k_number, 10) # 提供的值 - self.assertEqual(args.step_list, '[]') # 默认值 - - -class TestGetStepAndStop(unittest.TestCase): +class TestGetStepAndStop(TestCase): def test_valid_step_list_and_top_k(self): # 构造有效的 args 对象 @@ -317,13 +557,13 @@ class TestGetStepAndStop(unittest.TestCase): self.assertEqual(str(context.exception), "The top k number must be greater than 0.") -class TestAnomalyAnalyseFunction(unittest.TestCase): +class TestAnomalyAnalyseFunction(TestCase): - @patch('msprobe.pytorch.monitor.anomaly_analyse._get_parse_args') # 模拟命令行参数解析 - @patch('msprobe.pytorch.monitor.anomaly_analyse._get_step_and_stop') # 模拟步骤和顶级数字解析 - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyDataLoader') # 模拟数据加载器 - @patch('msprobe.pytorch.monitor.anomaly_analyse.AnomalyAnalyse') # 模拟异常分析器 - @patch('msprobe.pytorch.monitor.anomaly_analyse.logger') # 模拟日志记录 + @patch('msprobe.core.monitor.anomaly_processor._get_parse_args') # 模拟命令行参数解析 + @patch('msprobe.core.monitor.anomaly_processor._get_step_and_stop') # 模拟步骤和顶级数字解析 + @patch('msprobe.core.monitor.anomaly_processor.AnomalyDataLoader') # 模拟数据加载器 + @patch('msprobe.core.monitor.anomaly_processor.AnomalyAnalyse') # 模拟异常分析器 + @patch('msprobe.core.monitor.anomaly_processor.logger') # 模拟日志记录 def test_anomaly_analyse(self, mock_logger, mock_anomaly_analyse, mock_anomaly_data_loader, mock_get_step_and_stop, mock_get_parse_args): # 模拟命令行参数 @@ -375,5 +615,33 @@ class TestAnomalyAnalyseFunction(unittest.TestCase): mock_logger.info.assert_any_call("1: Top Anomaly 2") +class TestParseArgs(TestCase): + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', + new=['script_name', '-d', 'path/to/data', '-o', 'path/to/output', '-k', '5', '-s', '[1,2,3]']) + def test_parse_args_with_all_arguments(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, 'path/to/output') + self.assertEqual(args.top_k_number, 5) + self.assertEqual(args.step_list, '[1,2,3]') + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', new=['script_name', '-d', 'path/to/data']) + def test_parse_args_with_required_argument_only(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, '') + self.assertEqual(args.top_k_number, 8) # 默认值 + self.assertEqual(args.step_list, '[]') # 默认值 + + @patch('msprobe.core.monitor.anomaly_processor.sys.argv', new=['script_name', '-d', 'path/to/data', '-k', '10']) + def test_parse_args_with_topk_only(self): + args = _get_parse_args() + self.assertEqual(args.data_path_dir, 'path/to/data') + self.assertEqual(args.out_path, '') + self.assertEqual(args.top_k_number, 10) # 提供的值 + self.assertEqual(args.step_list, '[]') # 默认值 + + if __name__ == '__main__': unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/ms_monitor/test_mon_module_hook.py 
+class MyMomentum(nn.Optimizer):
+    def __init__(self, params, learning_rate, momentum=0.9):
+        super(MyMomentum, self).__init__(learning_rate, params)
+        self.moments = self.parameters.clone(prefix="exp_avg", init="zeros")
+        self.momentum = momentum
+        self.opt = ops.ApplyMomentum()
+
+    def construct(self, gradients):
+        params = self.parameters
+        lr = self.get_lr()
+        gradients = self.flatten_gradients(gradients)
+        gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
+
+        success = None
+        for param, mom, grad in zip(params, self.moments, gradients):
+            success = self.opt(param, mom, lr, grad, self.momentum)
+        return success
+
+
+class TestContext(unittest.TestCase):
+    def test_communication_context(self):
+        cc_ctx = CommunicationContext()
+        cc_ctx.reset()
+        cc_ctx.data = {'tag1': {'min': [1, 2, 3], 'max': [10, 11, 12]},
+                       'tag2': {'min': [16, 17, 18], 'max': [22, 23, 24]}}
+        cc_ctx.aggregate()
+        expected_aggregated_data = {'tag1': {'max': 12, 'min': 1}, 'tag2': {'max': 24, 'min': 16}}
+        self.assertEqual(cc_ctx.data, expected_aggregated_data)
+
+    def test_grad_context(self):
+        grad_ctx = GradContext()
+        grad_ctx.reset()
+        self.assertEqual(grad_ctx.pre, {})
+        self.assertEqual(grad_ctx.post, {})
+
+    def test_module_hook_context_initialization(self):
+        """Verify the initial state of ModuleHookContext"""
+        ctx = ModuleHookContext(module_name="test_module")
+
+        # Verify basic attributes
+        self.assertEqual(ctx.step, 0)
+        self.assertEqual(ctx.micro_step, 0)
+        self.assertEqual(ctx.module_name, "test_module")
+        self.assertEqual(ctx.stack, "")
+
+        # Verify the data structure types
+        self.assertIsInstance(ctx.actv, defaultdict)
+        self.assertEqual(len(ctx.actv), 0)  # should be an empty dict
+
+        self.assertIsInstance(ctx.actvgrad, list)
+        self.assertEqual(len(ctx.actvgrad), 0)  # should be an empty list
+
+        self.assertIsInstance(ctx.struct, dict)
+        self.assertEqual(len(ctx.struct), 0)  # should be an empty dict
+
+    def test_module_hook_context_reset(self):
+        """Verify ModuleHookContext reset behaviour"""
+        ctx = ModuleHookContext(module_name="test")
+
+        # Populate test data
+        ctx.step = 5
+        ctx.micro_step = 3
+        ctx.actv['layer1']['weight'] = [1.2, 3.4]
+        ctx.actvgrad.append('grad_data')
+        ctx.stack = "test_stack"
+        ctx.struct['meta'] = {'size': 10}
+
+        # Perform the reset
+        ctx.reset()
+
+        # Verify the state after reset
+        self.assertEqual(ctx.step, 5)  # should not be reset
+        self.assertEqual(ctx.micro_step, 3)  # should not be reset
+        self.assertEqual(len(ctx.actv), 0)  # dict should be cleared
+        self.assertEqual(len(ctx.actvgrad), 0)  # list should be cleared
+        self.assertEqual(ctx.stack, "test_stack")  # should not be reset
+        self.assertEqual(len(ctx.struct), 1)  # should not be reset
+
+    def test_optimizer_context_initialization(self):
+        """Verify the initial state of OptimizerContext"""
+        ctx = OptimizerContext()
+
+        # Verify basic attributes
+        self.assertEqual(ctx.step, 0)
+
+        # Verify that every dict structure is empty
+        self.assertIsInstance(ctx.param_mg_direction, defaultdict)
+        self.assertEqual(len(ctx.param_mg_direction), 0)
+
+        self.assertIsInstance(ctx.param_adam_update, defaultdict)
+        self.assertEqual(len(ctx.param_adam_update), 0)
+
+        self.assertIsInstance(ctx.param_adam_ratio, defaultdict)
+        self.assertEqual(len(ctx.param_adam_ratio), 0)
+
+        self.assertIsInstance(ctx.param_weight_grad, defaultdict)
+        self.assertEqual(len(ctx.param_weight_grad), 0)
+
+        self.assertIsInstance(ctx.param_exp_avg, defaultdict)
+        self.assertEqual(len(ctx.param_exp_avg), 0)
+
+        self.assertIsInstance(ctx.param_exp_avg_sq, defaultdict)
+        self.assertEqual(len(ctx.param_exp_avg_sq), 0)
+
+        self.assertIsInstance(ctx.exp_avg_metric, dict)
+        self.assertEqual(len(ctx.exp_avg_metric), 0)
+
+        self.assertIsInstance(ctx.exp_avg_sq_metric, dict)
+        self.assertEqual(len(ctx.exp_avg_sq_metric), 0)
+
+        self.assertIsInstance(ctx.metric_dict, dict)
+        self.assertEqual(len(ctx.metric_dict), 0)
+
+        self.assertIsInstance(ctx.param_metric, dict)
+        self.assertEqual(len(ctx.param_metric), 0)
+
+    def test_optimizer_context_reset(self):
+        """Verify OptimizerContext reset behaviour"""
+        ctx = OptimizerContext()
+
+        # Populate test data
+        ctx.step = 100
+        ctx.param_mg_direction['weight'] = 0.5
+        ctx.param_adam_update['bias'] = (0.1, 0.2)
+        ctx.param_adam_ratio['embed'] = 0.8
+        ctx.param_weight_grad['linear'] = [-0.4, 0.6]
+        ctx.param_exp_avg['conv'] = [0.9]
+        ctx.param_exp_avg_sq['norm'] = [0.99]
+        ctx.exp_avg_metric['acc'] = 0.75
+        ctx.exp_avg_sq_metric['loss'] = 0.25
+        ctx.metric_dict['f1'] = 0.9
+        ctx.param_metric['weight_metric'] = 1.0
+
+        # Perform the reset
+        ctx.reset()
+
+        # Verify the state after reset
+        self.assertEqual(ctx.step, 100)  # should not be reset
+
+        # All dicts / defaultdicts should be empty
+        self.assertEqual(len(ctx.param_mg_direction), 0)
+        self.assertEqual(len(ctx.param_adam_update), 0)
+        self.assertEqual(len(ctx.param_adam_ratio), 0)
+        self.assertEqual(len(ctx.param_weight_grad), 0)
+        self.assertEqual(len(ctx.param_exp_avg), 0)
+        self.assertEqual(len(ctx.param_exp_avg_sq), 0)
+        self.assertEqual(len(ctx.exp_avg_metric), 0)
+        self.assertEqual(len(ctx.exp_avg_sq_metric), 0)
+        self.assertEqual(len(ctx.metric_dict), 0)
+        self.assertEqual(len(ctx.param_metric), 0)
+
+
+class TestTrainerMonWithRealNetwork:
+    @classmethod
+    def setup_class(cls):
+        """Setup once for all tests in this class"""
+        cls.mock_config = {
+            "start_step": 0,
+            "collect_times": 10,
+            "step_interval": 1,
+            "format": "csv",
+            "ops": ["norm"],
+            "alert": {"rules": [], "dump": False},
+            "xy_distribution": True,
+            "mv_distribution": True,
+            "forward_only": True
+        }
+        cls.config_file = "test_config.json"
+        with open(cls.config_file, 'w') as f:
+            json.dump(cls.mock_config, f)
+
+        # Setup real network components
+        cls.net = nn.Dense(2, 3)
+        cls.loss_fn = nn.MAELoss()
+        cls.opt = MyMomentum(cls.net.trainable_params(), 0.01)
+
+    @classmethod
+    def teardown_class(cls):
+        """Clean up after all tests"""
+        if os.path.exists(cls.config_file):
+            os.remove(cls.config_file)
+
+    def setup_method(self):
+        """Setup before each test"""
+        self.trainer = TrainerMon(self.config_file)
+        self.trainer.set_monitor(self.net, self.opt)
+
+    def test_monitor_with_real_training_step_when_valid_then_pass(self):
+
+        # Create test data
+        data = Tensor(np.random.rand(1, 10, 2), ms.float32)
+        label = Tensor(np.random.rand(1, 10, 3), ms.float32)
+
+        # Define forward function
+        def forward_fn(data, label):
+            logits = self.net(data)
+            loss = self.loss_fn(logits, label)
+            return loss, logits
+
+        # Define grad function
+        grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True)
+
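+        # grad_fn returns ((loss, logits), grads) because has_aux=True; the optimizer call
+        # inside train_step then applies the gradients to the network parameters.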
+        # Define training step
+        def train_step(data, label):
+            (loss, _), grads = grad_fn(data, label)
+            self.opt(grads)
+            return loss
+
+        # Execute training step
+        loss = train_step(data, label)
+
+        # Verify monitoring results
+        assert isinstance(loss, Tensor)
+        assert len(self.trainer.module_fwd_hook_context_by_module) > 0
+        assert len(self.trainer.optimizer_context) > 0
+
+    def test_monitor_with_multiple_training_steps_when_valid_then_pass(self):
+
+        # Create test data
+        data = Tensor(np.random.rand(1, 10, 2), ms.float32)
+        label = Tensor(np.random.rand(1, 10, 3), ms.float32)
+
+        # Define forward function
+        def forward_fn(data, label):
+            logits = self.net(data)
+            loss = self.loss_fn(logits, label)
+            return loss, logits
+
+        # Define grad function
+        grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True)
+
+        # Define training step
+        def train_step(data, label):
+            (loss, _), grads = grad_fn(data, label)
+            self.opt(grads)
+            return loss
+
+        # Execute multiple training steps
+        for step in range(3):
+            loss = train_step(data, label)
+
+        # Verify monitoring results
+        assert isinstance(loss, Tensor)
+        assert len(self.trainer.module_fwd_hook_context_by_module) > 0
+        assert len(self.trainer.optimizer_context) > 0
+        assert self.trainer.optimizer_context[self.opt].step == step + 1
+
+    def test_monitor_with_parameter_updates_when_valid_then_pass(self):
+        # Get initial parameters
+        initial_params = [param.value() for param in self.net.get_parameters()]
+
+        # Create test data
+        data = Tensor(np.random.rand(1, 10, 2), ms.float32)
+        label = Tensor(np.random.rand(1, 10, 3), ms.float32)
+
+        # Define forward function
+        def forward_fn(data, label):
+            logits = self.net(data)
+            loss = self.loss_fn(logits, label)
+            return loss, logits
+
+        # Define grad function
+        grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True)
+
+        # Define training step
+        def train_step(data, label):
+            (loss, _), grads = grad_fn(data, label)
+            self.opt(grads)
+            return loss
+
+        # Execute training step
+        loss = train_step(data, label)
+
+        # Get updated parameters
+        updated_params = [param.value() for param in self.net.get_parameters()]
+
+        # Verify parameters have changed
+        for init_param, updated_param in zip(initial_params, updated_params):
+            assert not np.array_equal(init_param.asnumpy(), updated_param.asnumpy())
+
+    def test_monitor_with_gradient_collection_when_valid_then_pass(self):
+        # Enable gradient monitoring
+        self.trainer.wg_distribution = True
+        self.trainer.monitor_mbs_grad = True
+        self.trainer._hook_weights()
+
+        # Create test data
+        data = Tensor(np.random.rand(1, 10, 2), ms.float32)
+        label = Tensor(np.random.rand(1, 10, 3), ms.float32)
+
+        # Define forward function
+        def forward_fn(data, label):
+            logits = self.net(data)
+            loss = self.loss_fn(logits, label)
+            return loss, logits
+
+        # Define grad function
+        grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True)
+
+        # Define training step
+        def train_step(data, label):
+            (loss, _), grads = grad_fn(data, label)
+            # Assign to main_grad
+            for param, grad in zip(self.opt.parameters, grads):
+                param.main_grad = grad
+            self.opt(grads)
+            return loss
+
+        # Execute training step
+        loss = train_step(data, label)
+
+        # Verify gradients were collected
+        assert len(self.trainer.grad_context.post) > 0
+
+    def test_monitor_with_momentum_collection_when_valid_then_pass(self):
+        # Enable momentum monitoring
+        self.trainer.mv_distribution = True
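+        # With mv_distribution enabled, the optimizer hooks are expected to populate
+        # exp_avg_metric from MyMomentum's "exp_avg"-prefixed moment buffers (asserted below).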
+
+        # Create test data
+        data = Tensor(np.random.rand(1, 10, 2), ms.float32)
+        label = Tensor(np.random.rand(1, 10, 3), ms.float32)
+
+        # Define forward function
+        def forward_fn(data, label):
+            logits = self.net(data)
+            loss = self.loss_fn(logits, label)
+            return loss, logits
+
+        # Define grad function
+        grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True)
+
+        # Define training step
+        def train_step(data, label):
+            (loss, _), grads = grad_fn(data, label)
+            self.opt(grads)
+            return loss
+
+        # Execute training step
+        loss = train_step(data, label)
+
+        # Verify momentum was collected
+        opt_context = self.trainer.optimizer_context[self.opt]
+        assert len(opt_context.exp_avg_metric) > 0
+
+    def test_dynamic_monitor_when_change_then_pass(self):
+        self.trainer.dynamic_enable = True
+
+        # Create test data
+        data = Tensor(np.random.rand(1, 10, 2), ms.float32)
+        label = Tensor(np.random.rand(1, 10, 3), ms.float32)
+
+        # Define forward function
+        def forward_fn(data, label):
+            logits = self.net(data)
+            loss = self.loss_fn(logits, label)
+            return loss, logits
+
+        # Define grad function
+        grad_fn = ms.value_and_grad(forward_fn, None, self.opt.parameters, has_aux=True)
+
+        # Define training step
+        def train_step(data, label):
+            (loss, _), grads = grad_fn(data, label)
+            self.opt(grads)
+            return loss
+
+        for step in range(3):
+            loss = train_step(data, label)
+            if step == 0:
+                self.mock_config['start_step'] = 2  # switch collection on at step 2
+                self.mock_config["collect_times"] = 1
+                self.mock_config['dynamic_on'] = True
+                with open(self.config_file, 'w') as f:
+                    json.dump(self.mock_config, f)
+        assert len(self.trainer.module_fwd_hook_context_by_module) > 0
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py
deleted file mode 100644
index 6e416de8c689e6df7642dd52c60021a7e1b58baf..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_anomaly_detect.py
+++ /dev/null
@@ -1,320 +0,0 @@
-import unittest
-from unittest import TestCase
-from unittest.mock import patch
-
-from msprobe.pytorch.monitor.anomaly_detect import AnomalyTurbulence, AnomalyNan, AnomalyScanner, \
-    AnomalyDataFactory, GradAnomalyData, BaseWriterWithAD, ScanRule, WriterInput
-
-
-class TestScanRule(TestCase):
-    def test_apply_not_implemented(self):
-        scan_rule = ScanRule()
-        with self.assertRaises(Exception) as context:
-            scan_rule.apply(None, None)
-
-        self.assertEqual(str(context.exception), "abstract method apply is not implemented")
-
-
-class TestAnomalyTurbulence(TestCase):
-
-    def setUp(self) -> None:
-        self.threshold = 0.2
-        self.rule = AnomalyTurbulence(self.threshold)
-
-    def test_apply_with_positive_baseline(self):
-        history = [10, 12, 14]
-        cur = 16
-        result = self.rule.apply(cur, history=history)
-        self.assertTrue(result)
-
-    def test_apply_with_non_positive_baseline(self):
-        history = [0, 0, 0]
-        cur = -1
-        result = self.rule.apply(cur, history=history)
-        self.assertTrue(result)
-
-    def test_apply_with_valid_value(self):
-        history = [0, 0, 0]
-        cur = 0
-        result = self.rule.apply(cur, history=history)
-        self.assertFalse(result)
-
-
-class TestAnomalyNan(TestCase):
-
-    def setUp(self) -> None:
-        self.threshold = 1e10
-        self.rule = AnomalyNan(self.threshold)
-
-    def test_apply_with_nan(self):
-        cur = float("nan")
-        result = self.rule.apply(cur)
-        self.assertTrue(result)
-
-    def test_apply_with_big_value(self):
-        cur = float("1e30")
-        result = self.rule.apply(cur)
-        self.assertTrue(result)
-
-    def test_apply_with_valid_value(self):
-        cur = 0.5
-        result = self.rule.apply(cur)
-        self.assertFalse(result)
-
-
-class TestAnomalyScanner(TestCase):
-
-    def test_load_rules_with_valied_spec(self):
-        specs = [
-            {"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.2}}
-        ]
-        rules = AnomalyScanner.load_rules(specs)
-
-        self.assertEqual(len(rules), 1)
-        self.assertIsInstance(rules[0], AnomalyTurbulence)
-        self.assertEqual(rules[0].threshold, 0.2)
-
-        rules = AnomalyScanner.load_rules(None)
-        self.assertEqual(len(rules), 0)
-
-    @patch("msprobe.pytorch.monitor.anomaly_detect.logger")
-    def test_load_rules_with_missing_keys(self, mock_logger):
-        specs = [
-            {"rule_name": "AnomalyTurbulence"}
-        ]
-        rules = AnomalyScanner.load_rules(specs)
-
-        self.assertEqual(len(rules), 0)
-        mock_logger.warning.assert_called_once_with(f"Spec is missing required keys: {specs[0]}")
-
-    def test_load_rules_with_invalid_rule(self):
-        # test invalid rule_name
-        specs = [{"rule_name": "InvalidRule", "args": {"threshold": 0.2}}]
-        rules = AnomalyScanner.load_rules(specs)
-        self.assertEqual(len(rules), 0)
-
-        # test invalid args
-        specs = [{"rule_name": "AnomalyTurbulence", "args": "invalid args"}]
-        rules = AnomalyScanner.load_rules(specs)
-        self.assertEqual(len(rules), 0)
-
-    def test_scan(self):
-        ad_rules = [AnomalyTurbulence(0.2)]
-        # test scan with anomaly
-        expected = True, "AnomalyTurbulence"
-        self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 2.0), expected)
-        # test scan with no anomaly
-        expected = False, None
-        self.assertEqual(AnomalyScanner.scan(ad_rules, 1.0, 1.0), expected)
-
-
-class TestAnomalyDataFactory(TestCase):
-
-    def setUp(self) -> None:
-        rank = 0
-        pp_stage = 0
-        group_mates = [0]
-        self.AnomalyDataFactory = AnomalyDataFactory(rank, pp_stage, group_mates)
-
-    def test_set_call_id(self):
-        name2callid = {'param_name': 0}
-        self.AnomalyDataFactory.set_call_id(name2callid)
-
-        self.assertEqual(self.AnomalyDataFactory.name2callid, {'param_name': 0})
-
-    def test_create_success(self):
-        tag = ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min')
-        message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2."
-        step = 2
-        result = self.AnomalyDataFactory.create(tag, message, step)
-
-        self.assertEqual(result.step, step)
-        self.assertEqual(result.tag_name, tag[0])
-        self.assertEqual(result.message, message)
-        self.assertEqual(result.vpp_stage, 0)
-
-        # test no vpp_stage
-        tag = ('1.self_attention.core_attention_flash_0/rank0/output', 'min')
-        result = self.AnomalyDataFactory.create(tag, message, step)
-        self.assertEqual(result.vpp_stage, 0)
-
-    def test_create_failed(self):
-        error_tag = '0:1.self_attention.core_attention_flash_0/rank0/output'
-        message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash_0/rank0/output', 'min') at step 2."
-        step = 2
-        with self.assertRaises(Exception) as context:
-            self.AnomalyDataFactory.create(error_tag, message, step)
-        self.assertEqual(str(context.exception), "tag must be a tuple with length 2")
-
-
-class TestGradAnomalyData(TestCase):
-
-    def setUp(self) -> None:
-        tag_name = "0:1.self_attention.core_attention_flash.output:0/rank0/actv"
-        message = "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2."
-        group_mates = [0]
-        self.GradAnomalyData = GradAnomalyData(tag_name=tag_name, message=message, group_mates=group_mates)
-
-    def test_get_train_stage(self):
-        tag_name_list = ["0:fc2.input:0/rank0/actv", "0:fc1.weight/rank0/post_grad", "0:fc2.weight/rank0/exp_avg_sq", ""]
-        expected_train_stage_list = [0, 1, 2, -1]
-        for tag_name, expected_train_stage in zip(tag_name_list, expected_train_stage_list):
-            train_stage = GradAnomalyData.get_train_stage(tag_name)
-            self.assertEqual(train_stage, expected_train_stage)
-
-    def test_to_dict(self):
-        expected = {
-            'rank': 0,
-            'step': 0,
-            'micro_step': 0,
-            'pp_stage': 0,
-            'vpp_stage': 0,
-            'call_id': 0,
-            'tag_name': "0:1.self_attention.core_attention_flash.output:0/rank0/actv",
-            'message': "Rule AnomalyTurbulence reports anomaly signal in ('0:1.self_attention.core_attention_flash.output:0/rank0/actv', 'min') at step 2.",
-            'group_mates': [0]
-        }
-
-        self.assertEqual(self.GradAnomalyData.to_dict(), expected)
-
-    def test_get_key(self):
-        expected = "0:1.self_attention.core_attention_flash.output:0/rank0/actv_step_0_call_0"
-
-        self.assertEqual(self.GradAnomalyData.get_key(), expected)
-
-    def test_lt_different_step(self):
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        data2 = GradAnomalyData(step=2, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-    def test_lt_same_step_different_micro_step(self):
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        data2 = GradAnomalyData(step=1, micro_step=1, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-    def test_lt_same_step_same_micro_step_different_vpp_stage(self):
-        # same forward
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/actv")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-        # same backward
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad")
-        self.assertLess(data2, data1)
-        self.assertGreater(data1, data2)
-
-        # diff train stage
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=1, pp_stage=0, call_id=0, tag_name="xxx/post_grad")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-    def test_lt_same_step_same_micro_step_same_vpp_stage_different_pp_stage(self):
-        # same forward
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/actv")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/actv")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-        # same backward
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/post_grad")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad")
-        self.assertLess(data2, data1)
-        self.assertGreater(data1, data2)
-
-        # diff train stage
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="xxx/input")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=1, call_id=0, tag_name="xxx/post_grad")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-    def test_lt_same_step_same_micro_step_same_vpp_stage_same_pp_stage_different_call_id(self):
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=1, tag_name="")
-        self.assertLess(data1, data2)
-        self.assertGreater(data2, data1)
-
-    def test_lt_same_data(self):
-        data1 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        data2 = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0, tag_name="")
-        self.assertGreaterEqual(data1, data2)
-        self.assertLessEqual(data1, data2)
-
-    def test_lt_not_instance(self):
-        data = GradAnomalyData(step=1, micro_step=0, vpp_stage=0, pp_stage=0, call_id=0)
-        not_instance = "not an instance of GradAnomalyData"
-        self.assertEqual(data.__lt__(not_instance), NotImplemented)
-
-    def test_le_same_instance(self):
-        # 测试相同实例的情况
-        data1 = GradAnomalyData()
-        self.assertTrue(data1 <= data1)
-
-    def test_le_different_instance(self):
-        # 测试不同实例的情况
-        data1 = GradAnomalyData()
-        data2 = GradAnomalyData()
-        self.assertTrue(data1 <= data2)
-
-    def test_le_not_instance(self):
-        # 测试非GradAnomalyData实例的情况
-        data = GradAnomalyData()
-        not_instance = "Not an instance of GradAnomalyData"
-        self.assertEqual(data.__le__(not_instance), NotImplemented)
-
-    def test_le_different_instance_not_equal(self):
-        # 测试不同实例且不相等的情况
-        data1 = GradAnomalyData()
-        data2 = GradAnomalyData()
-        data2.some_attribute = "some value"
-        self.assertTrue(data1 <= data2)
-
-
-class TestBaseWriterWithAD(TestCase):
-
-    def setUp(self) -> None:
-        self.BaseWriter = BaseWriterWithAD(WriterInput('', None, None))
-
-    def test_get_anomalies(self):
-        expected = []
-
-        self.assertEqual(self.BaseWriter.get_anomalies(), expected)
-
-    def test_clear_anomalies(self):
-        self.BaseWriter.anomalies = ['anomaly1', 'anomaly2']
-        self.BaseWriter.clear_anomalies()
-
-        self.assertEqual(self.BaseWriter.anomalies, [])
-
-    @patch("msprobe.pytorch.monitor.anomaly_detect.logger")
-    def test_add_scalar(self, mock_logger):
-        AnomalyTurbulence_obj = AnomalyTurbulence(0.2)
-        self.BaseWriter.ad_rules = [AnomalyTurbulence_obj]
-        tag = ('0:1.post_attention_norm.weight/rank0/pre_grad', 'mean')
-        self.BaseWriter.tag2scalars = {tag: {'avg': 1.0, 'count': 1}}
-        self.BaseWriter.add_scalar(tag, 2.0)
-
-        mock_logger.info.assert_called_once()
-
-    def test_ad(self):
-        AnomalyTurbulence_obj = AnomalyTurbulence(0.2)
-        self.BaseWriter.ad_rules = [AnomalyTurbulence_obj]
-        expected = True, "AnomalyTurbulence"
-
-        self.assertEqual(self.BaseWriter._ad(2.0, 1.0), expected)
-
-    def test_update_tag2scalars(self):
-        self.BaseWriter._update_tag2scalars('tag1', 1.0)
-        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.0)
-        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 1)
-        self.BaseWriter._update_tag2scalars('tag1', 2.0)
-        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.01)
-        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 2)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py
new file mode 100644
index 0000000000000000000000000000000000000000..34204267935cd7691f5bcccce6c1af5451a2c34f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/monitor/test_data_writers.py
@@ -0,0 +1,52 @@
+import unittest
+from unittest import TestCase
+from unittest.mock import patch
+
+from msprobe.core.monitor.anomaly_processor import AnomalyTurbulence
+from msprobe.pytorch.monitor.data_writers import BaseWriterWithAD, WriterInput
+
+
+class TestBaseWriterWithAD(TestCase):
+
+    def setUp(self) -> None:
+        self.BaseWriter = BaseWriterWithAD(WriterInput('', None, None))
+
+    def test_get_anomalies(self):
+        expected = []
+
+        self.assertEqual(self.BaseWriter.get_anomalies(), expected)
+
+    def test_clear_anomalies(self):
+        self.BaseWriter.anomalies = ['anomaly1', 'anomaly2']
+        self.BaseWriter.clear_anomalies()
+
+        self.assertEqual(self.BaseWriter.anomalies, [])
+
+    @patch("msprobe.pytorch.monitor.data_writers.logger")
+    def test_add_scalar(self, mock_logger):
+        AnomalyTurbulence_obj = AnomalyTurbulence(0.2)
+        self.BaseWriter.ad_rules = [AnomalyTurbulence_obj]
+        tag = ('0:1.post_attention_norm.weight/rank0/pre_grad', 'mean')
+        self.BaseWriter.tag2scalars = {tag: {'avg': 1.0, 'count': 1}}
+        self.BaseWriter.add_scalar(tag, 2.0)
+
+        mock_logger.info.assert_called_once()
+
+    def test_ad(self):
+        AnomalyTurbulence_obj = AnomalyTurbulence(0.2)
+        self.BaseWriter.ad_rules = [AnomalyTurbulence_obj]
+        expected = True, "AnomalyTurbulence"
+
+        self.assertEqual(self.BaseWriter._ad(2.0, 1.0), expected)
+
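+    # The expected average of 1.01 after seeing 1.0 and then 2.0 is consistent with an
+    # EMA-style running average (avg = 0.99 * avg + 0.01 * new); the exact update rule
+    # lives in BaseWriterWithAD._update_tag2scalars.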
+    def test_update_tag2scalars(self):
+        self.BaseWriter._update_tag2scalars('tag1', 1.0)
+        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.0)
+        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 1)
+        self.BaseWriter._update_tag2scalars('tag1', 2.0)
+        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['avg'], 1.01)
+        self.assertEqual(self.BaseWriter.tag2scalars['tag1']['count'], 2)
+
+
+if __name__ == '__main__':
+    unittest.main()