diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py index 1a9ecc3ad1f05a01f6457ddf8a5530d1c7d48f78..286297acd0db4c8c8f804925bec03224f4963254 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_detect.py @@ -1,31 +1,15 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import os import sys import statistics as st from abc import ABC from typing import List from collections import defaultdict - +from dataclasses import dataclass, field +import pandas as pd from torch.utils.tensorboard import SummaryWriter - -from msprobe.pytorch.monitor.utils import print_info_log, print_warn_log, print_error_log -from msprobe.pytorch.monitor.file_check import check_path_before_create, change_mode, FileCheckConst, create_directory +from kj600.utils import print_info_log, check_file_valid_writable, make_file_safety, create_directory +from kj600.const import Const +from kj600.file_check import change_mode, FileCheckConst class ScanRule(ABC): @@ -57,29 +41,12 @@ class AnomalyScanner: return [] alert_rules = [] for spec in specs: - # 使用get方法获取键值,如果键不存在则返回None - rule_cls_name = spec.get("rule_name") - rule_args = spec.get("args") - - # 检查必要的键是否存在 - if rule_cls_name is None or rule_args is None: - print_warn_log(f"Spec is missing required keys: {spec}") - continue - + rule_cls_name = spec["rule_name"] + rule_args = spec["args"] cur_module = sys.modules[__name__] - try: - rule_cls = getattr(cur_module, rule_cls_name) - except AttributeError: - print_error_log(f"Rule class '{rule_cls_name}' not found in the current module.") - continue - - try: - rule_instance = rule_cls(**rule_args) - alert_rules.append(rule_instance) - except Exception as e: - print_error_log(f"Error creating instance of rule '{rule_cls_name}': {e}") - continue - + rule_cls = getattr(cur_module, rule_cls_name) + rule_instance = rule_cls(**rule_args) + alert_rules.append(rule_instance) return alert_rules @staticmethod @@ -92,7 +59,7 @@ class AnomalyScanner: return anomaly, None -class BCOLORS: +class bcolors: HEADER = '\033[95m' OKBLUE = '\033[94m' OKCYAN = '\033[96m' @@ -104,40 +71,185 @@ class BCOLORS: UNDERLINE = '\033[4m' -class SummaryWriterWithAD(SummaryWriter): - def __init__(self, path, ad_rules, job_id, anomaly_inform=False): - check_path_before_create(path) - create_directory(path) - try: - super().__init__(path) - except Exception as e: - print_error_log(f'error when init summary writer at {path}: {e}') - raise ValueError("Init summary writer error.") from e - for event in os.listdir(path): - change_mode(os.path.join(path, event), FileCheckConst.DATA_FILE_AUTHORITY) - self.tag2scalars = defaultdict(list) +class AnomalyDataFactory(ABC): + def __init__(self, rank, pp_stage, group_mates): + super().__init__() + self.rank = rank + self.pp_stage = pp_stage + self.group_mates = group_mates + self.micro_step = 0 + self.vpp_stage = 0 + self.name2callid = {} + + def set_call_id(self, name2callid): + """根据当前GradContext信息更新call_id vpp_stage等信息 + """ + self.name2callid = name2callid + + def create(self, tag_name, message, step): + """如果检查出异常, 调用当前接口生成GradAnomalyData实例 + """ + param_name = tag_name.split('/')[0] + call_id = self.name2callid.get(param_name, -1) + if Const.VPP_SEP in param_name: + vpp_stage = int(param_name.split(Const.VPP_SEP)[0]) + else: + vpp_stage = 0 + + return GradAnomalyData( + self.rank, + step, + self.micro_step, + self.pp_stage, + self.vpp_stage, + call_id, + tag_name, + message, + self.group_mates + ) + + +@dataclass(eq=True) +class GradAnomalyData: + rank: int = 0 + step: int = 0 + micro_step: int = 0 + pp_stage: int = 0 + vpp_stage: int = 0 + call_id: int = 0 + tag_name: str = field(default=None, compare=False) + message: str = field(default="", compare=False) + group_mates: list = field(default=None, compare=False) + + def __lt__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + if self.step != other.step: + return self.step < other.step + if self.micro_step != other.micro_step: + return self.micro_step < other.micro_step + if self.pp_stage != other.pp_stage: + return self.pp_stage > other.pp_stage + if self.vpp_stage != other.vpp_stage: + return self.vpp_stage > other.vpp_stage + if self.call_id != other.call_id: + return self.call_id < other.call_id + return False + + def __le__(self, other): + if not isinstance(other, GradAnomalyData): + return NotImplemented + return self == other or self < other + + def to_dict(self): + return self.__dict__ + + def get_key(self): + return ''.join( + (str(self.tag_name), "_step_", str(self.step), "_call_", str(self.call_id))) + + +class BaseWriterWithAD: + def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6): + self.tag2scalars = {} self.ad_rules = ad_rules self.job_id = job_id self.anomaly_inform = anomaly_inform + self.anomaly_factory = anomaly_factory + self.anomalies = [] + self.ndigits = ndigits + + def get_anomalies(self): + """返回已检测到的异常列表 + """ + return self.anomalies + + def clear_anomalies(self): + self.anomalies.clear() - def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False): - new_avg = avg = scalar_value - if tag in self.tag2scalars: - n = len(self.tag2scalars[tag]) - _, avg = self.tag2scalars[tag][-1] - new_avg = (avg * n + scalar_value) / (n + 1) - self.tag2scalars[tag].append((scalar_value, new_avg)) + def add_scalar(self, tag, scalar_value, global_step=None): + avg = self._update_tag2scalars(tag, scalar_value) detected, rule_name = self._ad(scalar_value, history=avg) if detected: - print_info_log( - f"{BCOLORS.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." - f"{BCOLORS.ENDC}") - exception_message = (f"{BCOLORS.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step " - f"{global_step}.{BCOLORS.ENDC}") + exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." + print_info_log(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}") if self.anomaly_inform: self.anomaly_inform.run(exception_message, self.job_id) - args = [tag, scalar_value, global_step, walltime, new_style, double_precision] - return super().add_scalar(*args) + + if self.anomaly_factory: + self.anomalies.append(self.anomaly_factory.create(tag, exception_message, global_step)) def _ad(self, scalar_value, history): return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value) + + def _update_tag2scalars(self, tag, scalar_value): + """Update the average and count of a scalar value associated with a tag. + + This method is used to maintain a running average of scalar values for each tag. + + + Args: + tag (str): The tag identifier. + scalar_value (float): The scalar value to be added. + + Returns: + float: The average value before update. + """ + if tag not in self.tag2scalars: + self.tag2scalars[tag] = {'avg': scalar_value, 'count': 0} + avg = self.tag2scalars[tag]['avg'] + new_avg = (avg * self.tag2scalars[tag]['count'] + scalar_value) / (self.tag2scalars[tag]['count'] + 1) + self.tag2scalars[tag]['avg'] = new_avg + self.tag2scalars[tag]['count'] += 1 + return avg + + +class CSVWriterWithAD(BaseWriterWithAD): + def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6): + super().__init__(path, ad_rules, job_id, anomaly_inform, anomaly_factory, ndigits) + + self.log_dir = path + create_directory(path) + self.context_dict = defaultdict(list) + self.header = [] + + def write_csv(self, prefix, step): + if len(self.context_dict) == 0: + return + filepath = os.path.join(self.log_dir, f'{prefix}_{step}.csv') + if not os.path.exists(filepath): + make_file_safety(filepath) + data_frame = pd.DataFrame(columns=self.header) + data_frame.to_csv(filepath, index=False) + change_mode(filepath, FileCheckConst.DATA_FILE_AUTHORITY) + + check_file_valid_writable(filepath) + new_data = [] + for name, metric_value in self.context_dict.items(): + if Const.VPP_SEP not in name: + new_data.append([name] + metric_value) + else: + new_data.append(name.split(Const.VPP_SEP) + metric_value) + new_data = pd.DataFrame(new_data) + new_data.to_csv(filepath, mode='a+', header=False, index=False) + self.context_dict = defaultdict(list) + + def add_scalar(self, tag, scalar_value, global_step): + super().add_scalar(tag, scalar_value, global_step) + + name = tag.split('/')[0] + self.context_dict[name].append(round(scalar_value, self.ndigits)) + + def close(self): + pass + + +class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD): + def __init__(self, path, ad_rules, job_id, anomaly_inform=False, anomaly_factory=None, ndigits=6): + super(SummaryWriter, self).__init__(path, ad_rules, job_id, anomaly_inform, anomaly_factory, ndigits) + super().__init__(path) + change_mode(path, FileCheckConst.DATA_DIR_AUTHORITY) + + def add_scalar(self, tag, scalar_value, global_step): + super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step) + return super().add_scalar(tag, scalar_value, global_step) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_inform.py b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_inform.py index 21e4e3a84fdf947787b80275c34d7e384f77f1b2..51c554e1eeca0950f70d8f813096af4b42776696 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_inform.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/anomaly_inform.py @@ -1,39 +1,20 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import smtplib from email.mime.text import MIMEText from datetime import datetime, timedelta -from msprobe.core.common.const import MonitorConst -from msprobe.pytorch.monitor.database import Database, ExceptionMessage -from msprobe.pytorch.monitor.utils import beijing_tz +from kj600.database import Database, ExceptionMessage # define class InformRegistry to get inform_sub_class class AnomalyInformFactory: @staticmethod def create_informer(**kwargs): - recipient = kwargs.get("recipient") - if recipient == MonitorConst.DATABASE: + if kwargs['recipient'] == "database": return DatabaseInform(**kwargs) - elif recipient == MonitorConst.EMAIL: + elif kwargs['recipient'] == "email": return EmailInform(**kwargs) - raise ValueError("Invalid recipient specified") + else: + raise ValueError("Invaild recipient specified") # define class AnomalyInform to inform with database or email @@ -49,15 +30,15 @@ class AnomalyInform: def run(self, exception_message, job_id): if self.time != 0 and self.current_time == 0: - self.current_time = datetime.now(tz=beijing_tz) + self.current_time = datetime.now() if self.time == 0 or ((self.current_time - self.time) > timedelta(minutes=self.interval_time)): self.exception_message_list.append(exception_message) self.inform_fun(self.exception_message_list, job_id) self.exception_message_list = [] - self.time = datetime.now(tz=beijing_tz) + self.time = datetime.now() elif (self.current_time - self.time) <= timedelta(minutes=self.interval_time): self.exception_message_list.append(exception_message) - self.current_time = datetime.now(tz=beijing_tz) + self.current_time = datetime.now() class DatabaseInform(AnomalyInform): @@ -70,11 +51,7 @@ class DatabaseInform(AnomalyInform): def inform_fun(self, exception_message_list, job_id): save_list = [] for exception_message in exception_message_list: - item = { - 'job_id': job_id, - 'message': exception_message, - 'create_time': datetime.now(tz=beijing_tz) - } + item = {'job_id': job_id, 'message': exception_message, 'create_time': datetime.now()} save_list.append(ExceptionMessage(**item)) self.database.insert_batch(save_list) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/file_check.py b/debug/accuracy_tools/msprobe/pytorch/monitor/file_check.py index 9ce5d7aed15daac139f18b8d172e14648717e045..e481838de19e61c8709d06e7c676746b368ed141 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/file_check.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/file_check.py @@ -1,9 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); +""" +# Copyright (C) 2022-2024. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -14,13 +13,62 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +""" import os import re -from msprobe.core.common.log import logger -from msprobe.core.common.exceptions import FileCheckException -from msprobe.core.common.const import FileCheckConst +from kj600.utils import print_error_log + + +class CodedException(Exception): + def __init__(self, code, error_info=""): + super().__init__() + self.code = code + self.error_info = self.err_strs.get(code) + error_info + + def __str__(self): + return self.error_info + + +class FileCheckException(CodedException): + INVALID_FILE_ERROR = 0 + FILE_PERMISSION_ERROR = 1 + SOFT_LINK_ERROR = 2 + ILLEGAL_PATH_ERROR = 3 + ILLEGAL_PARAM_ERROR = 4 + FILE_TOO_LARGE_ERROR = 5 + + err_strs = { + SOFT_LINK_ERROR: "[kj600] 检测到软链接: ", + FILE_PERMISSION_ERROR: "[kj600] 文件权限错误: ", + INVALID_FILE_ERROR: "[kj600] 无效文件: ", + ILLEGAL_PATH_ERROR: "[kj600] 非法文件路径: ", + ILLEGAL_PARAM_ERROR: "[kj600] 非法打开方式: ", + FILE_TOO_LARGE_ERROR: "[kj600] 文件过大: ", + } + + +class FileCheckConst: + """ + Class for file check const + """ + + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + FILE_PATTERN = r"^[a-zA-Z0-9_./-]+$" + JSON_SUFFIX = ".json" + MAX_JSON_SIZE = 1073741824 # 1 * 1024 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + FILE_SIZE_DICT = { + JSON_SUFFIX: MAX_JSON_SIZE, + } class FileChecker: @@ -34,7 +82,9 @@ class FileChecker: file_type(str): The correct file type for file """ - def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): + def __init__( + self, file_path, path_type, ability=None, file_type=None, is_script=True + ): self.file_path = file_path self.path_type = self._check_path_type(path_type) self.ability = ability @@ -44,7 +94,9 @@ class FileChecker: @staticmethod def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + print_error_log( + f"The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}." + ) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type @@ -61,7 +113,7 @@ class FileChecker: self.check_path_ability() if self.is_script: check_path_owner_consistent(self.file_path) - check_path_pattern_valid(self.file_path) + check_path_pattern_vaild(self.file_path) check_common_file_size(self.file_path) check_file_suffix(self.file_path, self.file_type) return self.file_path @@ -84,11 +136,12 @@ class FileOpen: file_path: The file or dictionary path to be opened. mode(str): The file open mode """ + SUPPORT_READ_MODE = ["r", "rb"] SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] - def __init__(self, file_path, mode, encoding='utf-8'): + def __init__(self, file_path, mode, encoding="utf-8"): self.file_path = file_path self.mode = mode self.encoding = encoding @@ -108,15 +161,19 @@ class FileOpen: self._handle.close() def check_file_path(self): - support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE + support_mode = ( + self.SUPPORT_READ_MODE + + self.SUPPORT_WRITE_MODE + + self.SUPPORT_READ_WRITE_MODE + ) if self.mode not in support_mode: - logger.error("File open not support %s mode" % self.mode) + print_error_log(f"File open not support {self.mode} mode") raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) check_path_length(self.file_path) self.check_ability_and_owner() - check_path_pattern_valid(self.file_path) + check_path_pattern_vaild(self.file_path) if os.path.exists(self.file_path): check_common_file_size(self.file_path) @@ -137,66 +194,68 @@ class FileOpen: def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): - logger.error('The file path {} is a soft link.'.format(path)) + print_error_log(f"The file path {path} is a soft link.") raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) -def check_path_length(path, name_length=None): - file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH - if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(path)) > file_max_name_length: - logger.error('The file path length exceeds limit.') +def check_path_length(path): + if path_len_exceeds_limit(path): + print_error_log("The file path length exceeds limit.") raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): - logger.error('The file path %s does not exist.' % path) + print_error_log(f"The file path {path} does not exist.") raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): - logger.error('The file path %s is not readable.' % path) + print_error_log(f"The file path {path} is not readable.") raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): - logger.error('The file path %s is not writable.' % path) + print_error_log(f"The file path {path} is not writable.") raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): - logger.error('The file path %s is not executable.' % path) + print_error_log(f"The file path {path} is not executable.") raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: - logger.error('The file path %s may be insecure because other users have write permissions. ' % path) + print_error_log( + f"The file path {path} may be insecure because other users have write permissions. " + ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): - logger.error('The file path %s may be insecure because is does not belong to you.' % path) + print_error_log( + f"The file path {path} may be insecure because is does not belong to you." + ) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) -def check_path_pattern_valid(path): +def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - logger.error('The file path %s contains special characters.' % path) + print_error_log(f"The file path {path} contains special characters.") raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): file_size = os.path.getsize(file_path) if file_size >= max_size: - logger.error(f'The size of file path {file_path} exceeds {max_size} bytes.') + print_error_log(f"The size of file path {file_path} exceeds {max_size} bytes.") raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) @@ -211,46 +270,32 @@ def check_common_file_size(file_path): def check_file_suffix(file_path, file_suffix): if file_suffix: if not file_path.endswith(file_suffix): - logger.error(f"The {file_path} should be a {file_suffix} file!") + print_error_log(f"The {file_path} should be a {file_suffix} file!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def check_path_type(file_path, file_type): if file_type == FileCheckConst.FILE: if not os.path.isfile(file_path): - logger.error(f"The {file_path} should be a file!") + print_error_log(f"The {file_path} should be a file!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) if file_type == FileCheckConst.DIR: if not os.path.isdir(file_path): - logger.error(f"The {file_path} should be a dictionary!") + print_error_log(f"The {file_path} should be a dictionary!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - dir_path = os.path.realpath(dir_path) - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - except OSError as ex: - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - 'Failed to create {}. Please check the path permission or disk space .{}'.format( - dir_path, str(ex))) from ex - - def check_path_before_create(path): if path_len_exceeds_limit(path): - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.') + raise FileCheckException( + FileCheckException.ILLEGAL_PATH_ERROR, "The file path length exceeds limit." + ) if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, - 'The file path {} contains special characters.'.format(path)) + raise FileCheckException( + FileCheckException.ILLEGAL_PATH_ERROR, + f"The file path {path} contains special characters." + ) def change_mode(path, mode): @@ -259,28 +304,14 @@ def change_mode(path, mode): try: os.chmod(path, mode) except PermissionError as ex: - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, - 'Failed to change {} authority. {}'.format(path, str(ex))) from ex + raise FileCheckException( + FileCheckException.FILE_PERMISSION_ERROR, + f"Failed to change {path} authority. {str(ex)}", + ) from ex def path_len_exceeds_limit(file_path): - return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH - - -def check_file_type(path): - """ - Function Description: - determine if it is a file or a directory - Parameter: - path: path - Exception Description: - when neither a file nor a directory throw exception - """ - if os.path.isdir(path): - return FileCheckConst.DIR - elif os.path.isfile(path): - return FileCheckConst.FILE - else: - logger.error('Neither a file nor a directory.') - raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) + return ( + len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH + or len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + ) diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py index f6e7a7de67a7e50133c0f0a6a4d7ee849954686a..a26ffb554db9c653053c136766b32f36e7e0a63e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_hook.py @@ -1,58 +1,58 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging +import inspect import os import uuid import json from collections import defaultdict +from functools import partial +from copy import deepcopy from datetime import datetime import torch +torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' +if not torch_version_above_or_equal_2: + raise ValueError("msmonitor require torch>=2.0") + import torch.distributed as dist +from torch.utils.hooks import BackwardHook from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook -from msprobe.pytorch.monitor.module_spec_verifier import validate_config_spec -from msprobe.pytorch.monitor.optimizer_collect import MixPrecsionOptimizerMon, OptimizerMonFactory -from msprobe.pytorch.monitor.features import eff_rank, get_sign_matches -from msprobe.pytorch.monitor.visualizer import HeatmapVisualizer -from msprobe.pytorch.monitor.anomaly_detect import AnomalyScanner, SummaryWriterWithAD -from msprobe.pytorch.monitor.anomaly_inform import AnomalyInformFactory -from msprobe.pytorch.monitor.module_metric import get_metrics, write_metrics_tensorboard, get_summary_writer_tag_name, \ - TensorMetrics -from msprobe.pytorch.monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate -from msprobe.pytorch.monitor.utils import print_warn_log, print_info_log, print_rank_0, get_param_struct, \ - check_path_length, check_path_pattern_valid, change_mode, FileCheckConst, validate_config, beijing_tz -from msprobe.pytorch.monitor.file_check import FileOpen -from msprobe.core.common.const import MonitorConst +from kj600.module_spec_verifier import validate_config_spec +from kj600.optimizer_collect import OptimizerMon, print_rank_0, OptimizerMonFactory +from kj600.features import eff_rank, get_sign_matches +from kj600.visualizer import HeatmapVisualizer +from kj600.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, \ + BaseWriterWithAD +from kj600.anomaly_analyse import AnomalyDataWriter +from kj600.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, get_summary_writer_tag_name, \ + TensorMetrics, squash_param_name, sqrt_norm_metric, reorder_metric +from kj600.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group +from kj600.utils import print_warn_log, print_info_log, print_error_log, get_param_struct, validate_config, validate_ops +from kj600.const import Const +from kj600.file_check import FileOpen -torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0' -if not torch_version_above_or_equal_2: - raise ValueError("monitor require torch>=2.0") +try: + import torch_npu +except ImportError: + pass -output_base_dir = os.getenv(MonitorConst.MONITOR_OUTPUT_DIR, MonitorConst.DEFAULT_MONITOR_OUTPUT_DIR) + +def param_is_not_tensor_parallel_duplicate(param, tp_group): + return (hasattr(param, 'tensor_model_parallel') and param.tensor_model_parallel) or ( + torch.distributed.get_rank(group=tp_group) == 0 + ) + + +def param_is_data_parallel_duplicate(dp_group): + return torch.distributed.get_rank(group=dp_group) != 0 class ModuleHookContext: def __init__(self, module_name) -> None: self.step = 0 self.micro_step = 0 - self.actv = [] + self.actv = defaultdict(dict) self.actvgrad = [] self.module_name = module_name + self.struct = {} self.format_by_arg = {} self.verified = False self.focused_in_col = 0 @@ -60,12 +60,12 @@ class ModuleHookContext: self.ignore_in = False # no need to care when no key 'input' or 'input_grad' found def set_format_by_arg(self, key_name: str, target_config: dict): - if key_name in target_config[self.module_name]: - self.format_by_arg[key_name] = target_config[self.module_name][key_name] + cared = target_config.get(self.module_name, self.struct) + if key_name in cared: + config = cared[key_name].get('config') + self.format_by_arg[key_name] = config if config else cared[key_name] elif key_name in ['input', 'input_grad']: self.ignore_in = True - else: - raise KeyError(f"Missing key: {key_name} of {self.module_name} in config.json") class OptimizerContext: @@ -77,7 +77,9 @@ class OptimizerContext: self.param_adam_ratio = defaultdict() self.param_weight_grad = defaultdict() self.param_exp_avg = defaultdict() + self.exp_avg_metric = [] self.param_exp_avg_sq = defaultdict() + self.exp_avg_sq_metric = [] self.metric_list = [] @@ -101,35 +103,53 @@ class CommunicationContext: self.data = self._agg(self.data) +class GradContext: + def __init__(self) -> None: + self.pre = [] + self.post = [] + self.acc_metric = [] + self.acc = {} + self.actv = {} + + def reset(self): + self.pre.clear() + self.post.clear() + self.acc_metric.clear() + self.acc.clear() + self.actv.clear() + + class TrainerMon: tensor_metrics = TensorMetrics() - def __init__(self, config_file_path, params_have_main_grad=True, opt_ty=None) -> None: - """ - config_file_path: str, monitor config path - params_have_main_grad: bool, whether param has attribution main_grad - opt_ty: str, Megatron_Float16OptimizerWithFloat16Params or Megatron_DistributedOptimizer - """ + # opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer" + def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None: self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) self.optimizer_context = defaultdict(OptimizerContext) self.cc_context = defaultdict(CommunicationContext) + self.grad_context = GradContext() + self.process_group = get_process_group(process_group) self.params_have_main_grad = params_have_main_grad + self.opt_ty = opt_ty with FileOpen(config_file_path, 'r') as f: self.config = json.load(f) validate_config(self.config) self.module_rank_list = self.config.get("module_ranks", []) + self.format = self.config.get('format', 'tensorboard') self.eps = self.config.get('eps', 1e-8) self.ops = self.config.get('ops', []) + self.ndigits = self.config.get('ndigits', 6) + self.all_xy = self.config.get('all_xy', False) self.xy_distribution = self.config.get('xy_distribution', False) if not self.xy_distribution: print_rank_0("> module input/output input_grad/output_grad is not monitored. ") - - # backward hook cause megatron-lm pipeline parallel schedule assert exception. + # backward hook cause megatron-lm pipeline parallel schedule assert exception. # TBD: backward hook cause output tensor is view of some base tensor. root cause invesigation pending. self.forward_only = self.config.get('forward_only', False) if self.forward_only: print_rank_0("> only module forward is monitored. ") + self.backward_only = self.config.get('backward_only', False) self.ur_distribution = self.config.get('ur_distribution', False) if not self.ur_distribution: @@ -158,35 +178,73 @@ class TrainerMon: alert_setting = self.config.get('alert', {"rules": []}) self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"]) - anomaly_inform = AnomalyInformFactory.create_informer( - **alert_setting["inform"]) if "inform" in alert_setting else None - - self.optimizer_hooked = False - cur_time = datetime.now(beijing_tz).strftime('%b%d_%H-%M-%S') + output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output') + cur_time = datetime.now().strftime('%b%d_%H-%M-%S') unique_id = str(uuid.uuid4())[:8] + if dist.is_initialized(): - cur_path = os.path.join(output_base_dir, f"{cur_time}-rank{dist.get_rank()}-{unique_id}") - if (dist.get_rank() in self.module_rank_list) or len(self.module_rank_list) == 0: - check_path_length(cur_path) - check_path_pattern_valid(cur_path) - self.summary_writer = SummaryWriterWithAD( - cur_path, self.alert_rules, unique_id, anomaly_inform) + rank = dist.get_rank() + tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-rank{rank}-{unique_id}") + pp_stage = dist.get_group_rank(self.process_group, rank) + group_mates = dist.get_process_group_ranks(self.process_group) else: - cur_path = os.path.join(output_base_dir, f"{cur_time}-{unique_id}") - check_path_length(cur_path) - check_path_pattern_valid(cur_path) - self.summary_writer = SummaryWriterWithAD(cur_path, self.alert_rules, unique_id, anomaly_inform) - - full_path = os.path.realpath(cur_path) - change_mode(full_path, FileCheckConst.DATA_DIR_AUTHORITY) + rank = 0 + tensorboard_dir = os.path.join(output_base_dir, f"{cur_time}-{unique_id}") + pp_stage = 0 + group_mates = [0] + self.rank = rank + + # 初始化AnomalyData工厂 + self.anomaly_data_factory = AnomalyDataFactory(rank, pp_stage, group_mates) if alert_setting.get('dump', + False) else None + + if self.format == 'tensorboard': + writer = SummaryWriterWithAD + self.write_metrics = write_metrics_tensorboard + elif self.format == 'csv': + writer = CSVWriterWithAD + self.write_metrics = write_metrics_csv + elif self.format == 'api': + writer = BaseWriterWithAD + self.write_metrics = write_metrics_tensorboard + + if (rank in self.module_rank_list) or len(self.module_rank_list) == 0: + + self.summary_writer = writer( + tensorboard_dir, + self.alert_rules, + unique_id, + None, + self.anomaly_data_factory, + self.ndigits + ) + # 初始化anomaly deteted文件目录 + if self.anomaly_data_factory: + self.anomaly_data_writer = AnomalyDataWriter( + os.path.join(output_base_dir, "anomaly_detected"), rank) + self.anomaly_data_writer.init_detected_json() # A HeatmapVisualizer instance is associated with an image self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) - self.micro_batch_number = 0 + self.micro_batch_number = 1 + + self.model = None + self.weight_hooked = False + self.optimizer_hooked = False + self.param_registered = False + self.vpp = False + self.dp_group = None + self.tp_group = None - self.param_name_list = [] self.param2name = defaultdict(str) + self.name2index = defaultdict() + self.name2indices = defaultdict() + self.name2param = {} + self.param_name_call_id = {} + self.call_id = 0 + self.grad_accs = [] + self.handles = defaultdict(list) self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty) if opt_ty is None: @@ -194,18 +252,30 @@ class TrainerMon: raise Exception("ur_distribution cannot be enabled with unknown optimizer.") if self.mv_distribution: raise Exception("mv_distribution cannot be enabled with unknown optimizer.") + self.verbose = False self.print_struct = self.config.get("print_struct", False) + if self.print_struct: + self.verbose = True self.struct_printed = False - self.module_struct = defaultdict(dict) + self.module_struct = {} + return def __del__(self): if hasattr(self, "summary_writer"): self.summary_writer.close() + @property + def ops(self): + return self._ops + + @ops.setter + def ops(self, value): + self._ops = validate_ops(value) + @staticmethod def set_wrapped_optimizer(_wrapped_optimizer): - MixPrecsionOptimizerMon.set_wrapped_optimizer(_wrapped_optimizer) + OptimizerMon.set_wrapped_optimizer(_wrapped_optimizer) @staticmethod def adhoc_check(target_tensor: torch.tensor, module_name: str, tensor_name: str, rank_list, ops_list): @@ -216,69 +286,163 @@ class TrainerMon: return TrainerMon.tensor_metrics.stat_insert(target_tensor, ops_list, module_name, tensor_name, rank) - @staticmethod - def build_tbtag_tensor_map(module_name, tag, tensor): + def hook_modules(self, model: torch.nn.Module, grad_acc_steps): + if self.module_rank_list and (self.rank not in self.module_rank_list): + return + + if not isinstance(model, list): + model = [model] + self.model = model + self._register_param_name(model) + + self.micro_batch_number = grad_acc_steps + + targets = self.config['targets'] + module_in_all_stage = [key for key in targets.keys() if Const.VPP_SEP not in key] + for key in module_in_all_stage: + struct = targets.pop(key) + targets.update({f'{vpp_stage}{Const.VPP_SEP}{key}': struct for vpp_stage in range(len(model))}) + + hooked_count = 0 + for vpp_stage, model_chunk in enumerate(model): + vpp_stage = f'{vpp_stage}{Const.VPP_SEP}' + targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config[ + 'targets'].keys() + hooked_count += self._hook_module(targets, model_chunk, vpp_stage) + + print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") + + def clone_if_tensor(args): + if isinstance(args, tuple): + return tuple([clone_if_tensor(arg) for arg in args]) + elif isinstance(args, torch.Tensor): + return args.clone() + else: + return args + + @torch.no_grad + def wrap_hook_setup(setup): + def wrapped_setup(*args, **kwargs): + args = setup(*args, **kwargs) + args = clone_if_tensor(args) + return args + + return wrapped_setup + + BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook) + + if not self.optimizer_hooked: + self.hook_optimizer() + return + + def generate_mv_metrics(self, opt_context): + if not self.mv_distribution: + return + opt_context.exp_avg_metric = {} + opt_context.exp_avg_sq_metric = {} + m_tag_tensor_map = self.generate_param_metrics('exp_avg', opt_context.param_exp_avg) + v_tag_tensor_map = self.generate_param_metrics('exp_avg_sq', opt_context.param_exp_avg_sq) + for metric_name in self.ops: + opt_context.exp_avg_metric[metric_name] = get_metrics(metric_name, m_tag_tensor_map, self.eps) + opt_context.exp_avg_sq_metric[metric_name] = get_metrics(metric_name, v_tag_tensor_map, self.eps) + + def generate_wgrad_metrics(self): + if not self.wg_distribution: + return {}, {} + + unreduced = {} + if self.weight_hooked: + for metric_name in self.ops: + unreduced[metric_name] = get_metrics(metric_name, self.grad_context.acc, self.eps) + self.grad_context.acc_metric = [unreduced.copy()] + sqrt_norm_metric(unreduced) + unreduced = reorder_metric(unreduced) + + grad_dict = {} + for param, name in self.param2name.items(): + if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): + continue + if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): + continue + grad = param.main_grad if self.params_have_main_grad else param.grad + if grad is None: + print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + continue + key = get_summary_writer_tag_name(name, 'post_grad', self.rank) + grad_dict[key] = grad + + reduced = {op: get_metrics(op, grad_dict, self.eps) for op in self.ops} + self.grad_context.post = [reduced.copy()] + sqrt_norm_metric(reduced) + reduced = reorder_metric(reduced) + + return reduced, unreduced + + def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None): + print_info_log(f'grad acc steps {grad_acc_steps}') + self.hook_optimizer(optimizer) + self.micro_batch_number = grad_acc_steps + + self.dp_group = dp_group + self.tp_group = tp_group + + self._register_param_name(model) + self._hook_weights() + self.hook_modules(model, grad_acc_steps) + + def build_tbtag_tensor_map(self, module_name, tag, tensor): metrics = {} rank = dist.get_rank() if dist.is_initialized() else None key = get_summary_writer_tag_name(module_name, tag, rank) - if tensor is not None: + if torch.is_tensor(tensor): metrics[key] = tensor return metrics - @staticmethod - def generate_cc_metrics(cc_name, cc_tensor): - metrics = defaultdict(dict) - rank = dist.get_rank() if dist.is_initialized() else None - for op, tag2tensor in cc_tensor.data.items(): - for tag, tensor in tag2tensor.items(): - key = get_summary_writer_tag_name(cc_name, tag, rank) - metrics[op].update({key: tensor}) - cc_tensor.reset() - return metrics - def generate_param_metrics(self, tag, param_tensor): metrics = {} rank = dist.get_rank() if dist.is_initialized() else None - for _, name in self.param2name.items(): + for name in self.param2name.values(): key = get_summary_writer_tag_name(name, tag, rank) if name not in param_tensor or param_tensor[name] is None: continue metrics[key] = param_tensor[name] return metrics - def hook_modules(self, model: torch.nn.Module, grad_acc_steps): - # fwd=0, bkd=1 - # targets is module name list like ["xx.xxx1", "xxx.xxx2"] which can be obtained when first run. - if not isinstance(model, torch.nn.Module): - raise TypeError("model should be a nn.Module") - if not isinstance(grad_acc_steps, int) or isinstance(grad_acc_steps, bool): - raise TypeError("grad_acc_steps should be int") - print_rank_0("> module names:") - for name, _ in model.named_modules(): - print_rank_0(f"\t{name}") - - self.micro_batch_number = grad_acc_steps - - if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): - targets = [x for x, _ in model.named_modules()] if self.print_struct else self.config['targets'].keys() - hooked_count = self._hook_module(targets, model, fwd_or_bkd=0) - print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") - else: - return + def generate_cc_metrics(self, cc_name, cc_tensor): + metrics = defaultdict(dict) + rank = dist.get_rank() if dist.is_initialized() else None + for op, tag2tensor in cc_tensor.data.items(): + for tag, tensor in tag2tensor.items(): + key = get_summary_writer_tag_name(cc_name, tag, rank) + metrics[op].update({key: tensor}) + cc_tensor.reset() + return metrics - if not self.optimizer_hooked: - self.optimizer_hooked = True - print_rank_0("> parameter names:") - for name, param in model.named_parameters(): - print_rank_0(f"\t{name}") - for target_module, _ in self.config['targets'].items(): - if name.startswith(target_module): - # name : language_model.encoder.layers.0.mlp.weight - # target_module:language_model.encoder.layers.0 - self.param_name_list.append(name) - self.param2name[param] = name - self.hook_optimizer() - return + def generate_xy_metrics(self): + actv = {} + for fwd_context in self.module_fwd_hook_context_by_module.values(): + for op in self.ops: + if op not in actv: + actv[op] = {} + actv[op].update(fwd_context.actv[op]) + sqrt_norm_metric(actv) + actv = reorder_metric(actv) + + actv_grad = deepcopy(self.grad_context.actv) + sqrt_norm_metric(actv_grad) + actv_grad = reorder_metric(actv_grad) + + return actv, actv_grad + + def reload_xy(self, xy_distribution=False): + self.xy_distribution = xy_distribution + + for handle in self.handles['xy']: + handle.remove() + self.handles['xy'].clear() + self.hook_modules(self.model, self.micro_batch_number) + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + fwd_context.actv.clear() def write_adhoc_check(self, step): TrainerMon.tensor_metrics.flush(self.summary_writer) @@ -287,32 +451,51 @@ class TrainerMon: if not self.xy_distribution: return for _, fwd_context in self.module_fwd_hook_context_by_module.items(): - if not len(fwd_context.actv) == self.micro_batch_number: - print_warn_log( - f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, " - f"{self.micro_batch_number}") - for metric_name in self.ops: - write_metrics_tensorboard(metric_name, self.summary_writer, fwd_context.actv, step) + if len(fwd_context.actv) == 0: + continue + self.write_metrics(self.ops, self.summary_writer, [fwd_context.actv], step, 'actv') fwd_context.actv.clear() + if self.grad_context.actv: + self.write_metrics(self.ops, self.summary_writer, [self.grad_context.actv], step, 'actv_grad') - for _, bwd_context in self.module_bwd_hook_context_by_module.items(): - if not len(bwd_context.actvgrad) == self.micro_batch_number: - print_warn_log( - f"bwd_context.actvgrad not equal to micro_batch_number: {len(bwd_context.actvgrad)}, " - f"{self.micro_batch_number}") - for metric_name in self.ops: - write_metrics_tensorboard(metric_name, self.summary_writer, bwd_context.actvgrad, step) - bwd_context.actvgrad.clear() + def write_mv_tb(self, opt_context): + if not self.mv_distribution: + return + self.write_metrics(self.ops, self.summary_writer, [opt_context.exp_avg_metric], opt_context.step, 'exp_avg') + self.write_metrics(self.ops, self.summary_writer, [opt_context.exp_avg_sq_metric], opt_context.step, + 'exp_avg_sq') + + def write_grad_tb(self, step): + if not self.wg_distribution: + return + + self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced') + self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced') - def hook_optimizer(self): + def hook_optimizer(self, optimizer=None): # in DDP by default use params_have_main_grad def optimizer_pre_step_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] - if (self.print_struct and not all(value == {} for value in self.module_struct.values()) - and not self.struct_printed): + if self.opt_ty in Const.DEEPSPEED_OPT_TY: + if context.step == 0: + return + elif context.step == 1: + self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name, + self.name2index) + mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name, + self.name2indices) + self.param2name = mv_result.grad + else: + mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name) + context.param_exp_avg = mv_result.exp_avg + context.param_exp_avg_sq = mv_result.exp_avg_sq + context.param_adam_update = mv_result.update + context.param_adam_ratio = mv_result.ratio + + if self.print_struct and not all( + value == {} for value in self.module_struct.values()) and not self.struct_printed: self._smallest_rank_print("> module struct:") - self._smallest_rank_print(json.dumps(self.module_struct)) - self.struct_printed = True + self._smallest_rank_print(json.dumps(self.module_struct, indent=4)) if not self.cc_log_only: raise Exception("exit after first step when print model struct") if self.cc_log_only and context.step > 0: @@ -321,33 +504,23 @@ class TrainerMon: json.dumps({k: [i.split(';') for i in v] for k, v in self.cc_logged_stack.items()}, indent=4)) raise Exception("exit after first step when print cc stack") - context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, context.param_adam_ratio = \ - self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name) + self.generate_wgrad_metrics() + self.generate_mv_metrics(context) - for param, name in self.param2name.items(): - if "params_effrank" in self.config and name in self.config["params_effrank"]: - context.param_effective_rank[name] = eff_rank(param.detach()) - grad = param.main_grad if self.params_have_main_grad else param.grad - if grad is None: - print_warn_log(f"grad is None: {name}, maybe something wrong happened.") - continue - if self.wg_distribution: - context.param_weight_grad[name] = grad - if self.mg_direction: + tbtag_tensor_map = {} + if self.mg_direction: + for param, name in self.param2name.items(): + grad = param.main_grad if self.params_have_main_grad else param.grad + if grad is None: + print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + continue if context.step == 0: same_direction_ratio = torch.tensor(1.) else: same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name]) context.param_mg_direction[name] = same_direction_ratio - - tbtag_tensor_map = {} - if self.wg_distribution: - tbtag_tensor_map.update(self.generate_param_metrics('weight_grad', context.param_weight_grad)) - if self.mv_distribution: - tbtag_tensor_map.update(self.generate_param_metrics('exp_avg', context.param_exp_avg)) - tbtag_tensor_map.update(self.generate_param_metrics('exp_avg_sq', context.param_exp_avg_sq)) - if self.mg_direction: tbtag_tensor_map.update(self.generate_param_metrics('mg_direction', context.param_mg_direction)) + metric_dict = {} for metric_name in self.ops: metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) @@ -356,6 +529,7 @@ class TrainerMon: cc_metrics = self.generate_cc_metrics(k, c) for op, m in cc_metrics.items(): metric_dict[op].update(m) + if not metric_dict: return context.metric_list.append(metric_dict) @@ -363,36 +537,63 @@ class TrainerMon: def optimizer_post_step_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] + if (self.opt_ty in Const.DEEPSPEED_OPT_TY and context.step == 0): + context.step += 1 + return rank = dist.get_rank() if dist.is_initialized() else None + if self.anomaly_data_factory: + self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) + self.write_grad_tb(context.step) + self.write_mv_tb(context) self.write_adhoc_check(context.step) if self.ur_distribution: for param_name, _ in context.param_adam_update.items(): self.update_heatmap_visualizer[param_name].visualize( - get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, - self.summary_writer) + get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer) for param_name, _ in context.param_adam_ratio.items(): self.ratio_heatmap_visualizer[param_name].visualize( - get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, - self.summary_writer) + get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer) - for metric_name in self.ops: - if not context.metric_list: - break - write_metrics_tensorboard(metric_name, self.summary_writer, context.metric_list, context.step) + if context.metric_list: + self.write_metrics(self.ops, self.summary_writer, context.metric_list, context.step, 'other') context.metric_list.clear() context.step += 1 + self.grad_context.reset() + if self.anomaly_data_factory: + self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) + self.summary_writer.clear_anomalies() + self.call_id = 0 + self.param_name_call_id.clear() + return + + def patch_step(func, optimizer): + def wrapper(*args, **kwargs): + optimizer_pre_step_hook(optimizer, args, kwargs) + out = func(*args, **kwargs) + optimizer_post_step_hook(optimizer, args, kwargs) + return out + return wrapper + + if self.optimizer_hooked: return - if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): - register_optimizer_step_pre_hook(optimizer_pre_step_hook) - register_optimizer_step_post_hook(optimizer_post_step_hook) + if optimizer: + optimizer.__class__.step = patch_step(optimizer.__class__.step, optimizer) + + else: + if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): + register_optimizer_step_pre_hook(optimizer_pre_step_hook) + register_optimizer_step_post_hook(optimizer_post_step_hook) + self.optimizer_hooked = True return def _smallest_rank_print(self, msg): + if not self.verbose: + return if dist.is_initialized(): if self.module_rank_list: if dist.get_rank() == min(self.module_rank_list): @@ -403,48 +604,132 @@ class TrainerMon: else: print_info_log(msg) - def _hook_module(self, target_names, module: torch.nn.Module, fwd_or_bkd): + def _is_target_param(self, param_name, param, prefix): + squash_name = prefix + squash_param_name(param_name) + name = prefix + param_name + for target in self.config['targets'].keys(): + if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target): + setattr(param, "zero_out_wgrad", True) + return True + + return False + + def _register_chunk(self, model_chunk, prefix): + for index, (param_name, param) in enumerate(model_chunk.named_parameters()): + if not param.requires_grad: + continue + if self._is_target_param(param_name, param, prefix): + name = prefix + squash_param_name(param_name) + if name in self.param2name.values(): + print_error_log(f'same name {name} for different param. Current param is {param_name}. \ + May be error of squash_param_name') + raise Exception("param with same name will be overwritten.") + self.param2name[param] = name + self.name2param[name] = param + self.name2index[name] = index + + def _register_param_name(self, model): + if self.param_registered: + return + + if not isinstance(model, list): + model = [model] + + if len(model) > 1: + self.vpp = True + self._smallest_rank_print('vpp enabled') + + for vpp_stage, model_chunk in enumerate(model): + prefix = f'{vpp_stage}{Const.VPP_SEP}' + self._register_chunk(model_chunk, prefix) + + self.param_registered = True + + def _is_target_module(self, module_name, targets, vpp_stage): + if self.all_xy or self.print_struct: + return vpp_stage + squash_param_name(module_name) + for pattern in [ + vpp_stage + squash_param_name(module_name), + vpp_stage + module_name, + ]: + if pattern in targets: + return pattern + return "" + + def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''): if '_modules' not in module.__dict__: # nothing to hook return 0 + def _is_recomputation(): + """Check if the current operation is in the recomputation phase. + + This function inspects the current call stack to determine if the 'backward' function is being + executed and if the execution is taking place within the 'torch/autograd/function.py' file. + If both conditions are met, it indicates that the current operation is in the recomputation phase. + + Returns: + bool: True if in the recomputation phase, False otherwise. + """ + backward_function_indices = [] + call_stack = inspect.stack() + + # Identify indices in the call stack where the 'backward' function is being executed + for idx, frame_info in enumerate(call_stack): + if frame_info.function == 'backward': + backward_function_indices.append(idx) + + # Check if the execution is within 'torch/autograd/function.py' file + for idx in backward_function_indices: + if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'): + return True + + return False + def fwd_hook_fun(module, module_input, module_output): + if _is_recomputation(): + return context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + if not context.struct: + context.struct = {Const.ACTV_IN: get_param_struct(module_input), + Const.ACTV_OUT: get_param_struct(module_output)} if self.print_struct: if context.module_name not in self.module_struct: self.module_struct[context.module_name] = {} - self.module_struct[context.module_name].update({ - "input": f"{get_param_struct(module_input)}", - "output": f"{get_param_struct(module_output)}" - }) + self.module_struct[context.module_name].update(context.struct) return - if not self.xy_distribution: + if not module.training: return if not context.format_by_arg: - context.set_format_by_arg('input', self.config['targets']) - context.set_format_by_arg('output', self.config['targets']) + context.set_format_by_arg(Const.ACTV_IN, self.config['targets']) + context.set_format_by_arg(Const.ACTV_OUT, self.config['targets']) + if not context.format_by_arg: + return if not context.verified: if not context.ignore_in: - context.focused_in_col = validate_config_spec(context.format_by_arg['input'], module_input, - context.module_name, 'input') - context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output, - context.module_name, 'output') + context.focused_in_col = validate_config_spec(context.format_by_arg[Const.ACTV_IN], module_input, + context.module_name, Const.ACTV_IN) + context.focused_out_col = validate_config_spec(context.format_by_arg[Const.ACTV_OUT], module_output, + context.module_name, Const.ACTV_OUT) context.verified = True # expect output be tensor type tbtag_tensor_map = {} if not context.ignore_in: cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] - tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'input', cared_input)) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTV_IN, + cared_input)) cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] - tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output', cared_output)) - metric_dict = {} + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTV_OUT, + cared_output)) + for metric_name in self.ops: - metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) - if context.micro_step == 0 and context.actv: - print_warn_log(f"actv context of {context.module_name} is not empty when first micro_step, " - f"maybe something wrong happened. Now clear it.") - context.actv.clear() - context.actv.append(metric_dict) + if context.micro_step == 0 and context.actv.get(metric_name, []): + print_warn_log( + f"actv context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + context.actv.clear() + context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps)) context.micro_step += 1 if context.micro_step == self.micro_batch_number: @@ -454,41 +739,47 @@ class TrainerMon: def bwd_hook_fun(module, input_grad, output_grad): context: ModuleHookContext = self.module_bwd_hook_context_by_module[module] + if not context.struct: + context.struct = {Const.ACTVGRAD_IN: get_param_struct(input_grad), + Const.ACTVGRAD_OUT: get_param_struct(output_grad)} if self.print_struct: - self.module_struct[context.module_name].update({ - "input_grad": f"{get_param_struct(input_grad)}", - "output_grad": f"{get_param_struct(output_grad)}" - }) - return - if not self.xy_distribution: + if context.module_name not in self.module_struct: + self.module_struct[context.module_name] = {} + self.module_struct[context.module_name].update(context.struct) return if not context.format_by_arg: - context.set_format_by_arg('input_grad', self.config['targets']) - context.set_format_by_arg('output_grad', self.config['targets']) + context.set_format_by_arg(Const.ACTVGRAD_IN, self.config['targets']) + context.set_format_by_arg(Const.ACTVGRAD_OUT, self.config['targets']) + if not context.format_by_arg: + return if not context.verified: if not context.ignore_in: - context.focused_in_col = validate_config_spec( - context.format_by_arg['input_grad'], input_grad, context.module_name, 'input_grad') - context.focused_out_col = validate_config_spec( - context.format_by_arg['output_grad'], output_grad, context.module_name, 'output_grad') + context.focused_in_col = validate_config_spec(context.format_by_arg[Const.ACTVGRAD_IN], input_grad, + context.module_name, Const.ACTVGRAD_IN) + context.focused_out_col = validate_config_spec(context.format_by_arg[Const.ACTVGRAD_OUT], output_grad, + context.module_name, Const.ACTVGRAD_OUT) context.verified = True tbtag_tensor_map = {} if not context.ignore_in: cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] tbtag_tensor_map.update( - self.build_tbtag_tensor_map(context.module_name, 'input_grad', cared_input_grad)) + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTVGRAD_IN, + cared_input_grad)) cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] - tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output_grad', - cared_output_grad)) - metric_dict = {} - for metric_name in self.ops: - metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) + tbtag_tensor_map.update( + self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTVGRAD_OUT, + cared_output_grad)) + if context.micro_step == 0 and context.actvgrad: - print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, " - f"maybe something wrong happened. Now clear it.") + print_warn_log( + f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") context.actvgrad.clear() - context.actvgrad.append(metric_dict) + + for metric_name in self.ops: + if metric_name not in self.grad_context.actv: + self.grad_context.actv[metric_name] = {} + self.grad_context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps)) context.micro_step += 1 if context.micro_step == self.micro_batch_number: @@ -496,15 +787,50 @@ class TrainerMon: context.step += 1 return + if self.backward_only and self.forward_only: + print_warn_log('not enable backward_only and forward_only simultaneously') + hooked_count = 0 - for name, submodule in module.named_modules(): - self.module_struct[name] = {} - if name in target_names: - submodule.register_forward_hook(fwd_hook_fun) - self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name) + if self.xy_distribution or self.print_struct: + for module_name, submodule in module.named_modules(): + name = self._is_target_module(module_name, target_names, vpp_stage) + if not name: + continue + if not self.backward_only: + handle = submodule.register_forward_hook(fwd_hook_fun) + self.handles['xy'].append(handle) + self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name) if not self.forward_only: - submodule.register_full_backward_hook(bwd_hook_fun) + handle = submodule.register_full_backward_hook(bwd_hook_fun) + self.handles['xy'].append(handle) self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) print_rank_0(f"> {name} is monitored successfully") hooked_count += 1 return hooked_count + + def _hook_weights(self): + context = self.grad_context + + @torch.no_grad + def param_hook(*args, context_dict, param, key, name): + param.micro_step += 1 + self.param_name_call_id[name] = self.call_id + self.call_id += 1 + if param.micro_step == self.micro_batch_number: + param.micro_step = 0 + if self.params_have_main_grad: + context_dict[key] = param.main_grad.clone() + else: + context_dict[key] = param.grad.clone() + + for param, name in self.param2name.items(): + key = get_summary_writer_tag_name(name, 'acc_grad', self.rank) + setattr(param, 'micro_step', 0) + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + handle = grad_acc.register_hook( + partial(param_hook, context_dict=context.acc, param=param, key=key, name=name)) + self.grad_accs.append(grad_acc) + self.handles['wgrads'].append(handle) + + self.weight_hooked = True diff --git a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py index e840c306a697a2459e21679a2d880a149c5294fd..0de2fed9bcf94d3758ba655dc8b72bb48d397a94 100644 --- a/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py +++ b/debug/accuracy_tools/msprobe/pytorch/monitor/module_metric.py @@ -1,32 +1,31 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import math +import re import statistics +import itertools +import torch -from msprobe.pytorch.monitor.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm -from msprobe.pytorch.monitor.utils import print_error_log +from kj600.const import Const +from kj600.features import square_sum, get_max, get_min, get_zeros, get_nans, get_norm, get_mean +from kj600.utils import print_warn_log def get_summary_writer_tag_name(module_or_param_name: str, tag: str, rank): if rank is None: return f"{module_or_param_name}/{tag}" else: - return f"{module_or_param_name}/{rank}/{tag}" + return f"{module_or_param_name}/rank{rank}/{tag}" + + +def squash_param_name(param_name): + name = '' + for pattern in ['layers?\.(.*)', 'embeddings?\.(.*)', 'final.*', 'output.*', 'norm.*']: + match = re.findall(pattern, param_name) + if match: + name += match[0] + break + if name == '': + name = param_name + return name # 用于存储所有metric实现类的注册表 @@ -44,14 +43,13 @@ def register_config_metric(key, cls=None): class TensorMetrics: def __init__(self) -> None: - # tensor_tag --> [] - self.metrics = {} + self.metrics = {} # tensor_tag --> [] self.cur_idx = {} - fun_map = {"norm": get_norm, "max": get_max, "min": get_min} + fun_map = {"norm": get_norm, "max": get_max, "min": get_min, "mean": get_mean} # get stats and insert into metrics dictionary - def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank): + def stat_insert(self, tensor, stat_ops, module_name, tensor_name, rank, eps=1e-8): prefix = get_summary_writer_tag_name(module_name, tensor_name, rank) for stat_op in stat_ops: y = TensorMetrics.fun_map[stat_op](tensor) @@ -81,7 +79,12 @@ class Metric(object): def get_metrics(self, tag2tensor: dict, eps): metrics_dict = {} for tag, tensor in tag2tensor.items(): - metrics_dict[tag] = self.get_metric_value(tensor, eps) + try: + metrics_dict[tag] = self.get_metric_value(tensor, eps) + if torch.isnan(metrics_dict[tag]): + print_warn_log(f'nan when calculate metric for {tag}') + except RuntimeError as e: + metrics_dict[tag] = torch.tensor(torch.nan) return metrics_dict @@ -93,12 +96,22 @@ class MinMetric(Metric): @staticmethod def metric_tensorboard(metric_name, summary_writer, metric_value, step): - try: - for key in metric_value[0][metric_name].keys(): - min_value = min([item[metric_name][key].item() for item in metric_value]) - summary_writer.add_scalar(f'{key}_min', min_value, step) - except Exception as e: - print_error_log(f"min metric metric_tensorboard error: {e}") + for key in metric_value[0][metric_name].keys(): + min_value = min([item[metric_name][key].item() for item in metric_value]) + summary_writer.add_scalar(f'{key}_min', min_value, step) + + +@register_config_metric("mean") +class MeanMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return get_mean(tensor) + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + for key in metric_value[0][metric_name].keys(): + mean_value = sum([item[metric_name][key].item() for item in metric_value]) / len(metric_value) + summary_writer.add_scalar(f'{key}_mean', mean_value, step) @register_config_metric("max") @@ -109,12 +122,9 @@ class MaxMetric(Metric): @staticmethod def metric_tensorboard(metric_name, summary_writer, metric_value, step): - try: - for key in metric_value[0][metric_name].keys(): - max_value = max([item[metric_name][key].item() for item in metric_value]) - summary_writer.add_scalar(f'{key}_max', max_value, step) - except Exception as e: - print_error_log(f"max metric metric_tensorboard error: {e}") + for key in metric_value[0][metric_name].keys(): + max_value = max([item[metric_name][key].item() for item in metric_value]) + summary_writer.add_scalar(f'{key}_max', max_value, step) @register_config_metric("norm") @@ -125,12 +135,9 @@ class NormMetric(Metric): @staticmethod def metric_tensorboard(metric_name, summary_writer, metric_value, step): - try: - for key in metric_value[0][metric_name].keys(): - norm_value = math.sqrt(sum([item[metric_name][key].item() for item in metric_value])) - summary_writer.add_scalar(f'{key}_norm', norm_value, step) - except Exception as e: - print_error_log(f"norm metric metric_tensorboard error: {e}") + for key in metric_value[0][metric_name].keys(): + norm_value = math.sqrt(sum([item[metric_name][key].item() for item in metric_value])) + summary_writer.add_scalar(f'{key}_norm', norm_value, step) @register_config_metric("zeros") @@ -141,28 +148,22 @@ class ZerosMetric(Metric): @staticmethod def metric_tensorboard(metric_name, summary_writer, metric_value, step): - try: - for key in metric_value[0][metric_name].keys(): - zeros_value = statistics.mean([item[metric_name][key].item() for item in metric_value]) - summary_writer.add_scalar(f'{key}_zeros', zeros_value, step) - except Exception as e: - print_error_log(f"zeros metric metric_tensorboard error: {e}") + for key in metric_value[0][metric_name].keys(): + zeros_value = statistics.mean([item[metric_name][key].item() for item in metric_value]) + summary_writer.add_scalar(f'{key}_zeros', zeros_value, step) @register_config_metric("nans") class NaNsMetric(Metric): @staticmethod - def get_metric_value(tensor, eps): - return get_nans(tensor) + def get_metric_value(t, eps): + return get_nans(t) @staticmethod def metric_tensorboard(metric_name, summary_writer, metric_value, step): - try: - for key in metric_value[0][metric_name].keys(): - nans_value = sum([v[metric_name][key].item() for v in metric_value]) - summary_writer.add_scalar(f'{key}_nans', nans_value, step) - except Exception as e: - print_error_log(f"nans metric metric_tensorboard error: {e}") + for key in metric_value[0][metric_name].keys(): + nans_value = sum([v[metric_name][key].item() for v in metric_value]) + summary_writer.add_scalar(f'{key}_nans', nans_value, step) @register_config_metric("id") @@ -174,16 +175,28 @@ class IdentMetric(Metric): return tensor @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - # metric_value is a dict, key is parameter name and value is a list of scalar tensor - try: - if len(metric_value) == 1: - for key, value in metric_value[0][metric_name].items(): - if not value: - continue - summary_writer.add_scalar(f'{key}_identical', value.item(), step) - except Exception as e: - print_error_log(f"id metric metric_tensorboard error: {e}") + def metric_tensorboard(metric_name, summary_writer, metric_value, + context): # metric_value is a dict, key is parameter name and value is a list of scalar tensor + if len(metric_value) == 1: + for key, value in metric_value[0][metric_name].items(): + if not value: + continue + summary_writer.add_scalar(f'{key}_identical', value.item(), context) + + +def reorder_metric(metrics): + new_metrics = {} + for op, tag2metric in metrics.items(): + for tag, metric in tag2metric.items(): + if tag not in new_metrics: + new_metrics[tag] = {} + new_metrics[tag][op] = metric + return new_metrics + + +def sqrt_norm_metric(metrics): + if 'norm' in metrics: + metrics["norm"] = {tag: metric ** 0.5 for tag, metric in metrics["norm"].items()} def get_metrics(metric_name, tag2tensor, eps): @@ -192,15 +205,43 @@ def get_metrics(metric_name, tag2tensor, eps): return fun_metric().get_metrics(tag2tensor, eps) except KeyError as e: raise ValueError( - f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: " - f"{metric_name}") from e + f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e -def write_metrics_tensorboard(metric_name, summary_writer, metric_value, step): - try: - fun_metric = config_metric_registry[metric_name] - return fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) - except KeyError as e: - raise ValueError( - f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: " - f"{metric_name}") from e +def write_metrics_tensorboard(ops, summary_writer, metric_value, step, prefix=''): + for metric_name in ops: + try: + fun_metric = config_metric_registry[metric_name] + fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) + except KeyError as e: + raise ValueError( + f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e + + +def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''): + for metric_name in ops: + try: + fun_metric = config_metric_registry[metric_name] + fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) + + except KeyError as e: + raise ValueError( + f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e + + if not summary_writer.header: + if prefix == 'actv': + summary_writer.header = ['module_name'] + else: + summary_writer.header = ['param_name'] + + if prefix in ['actv', 'actv_grad']: + summary_writer.header.extend([''.join(i) for i in itertools.product(ops, ['_input', '_output'])]) + else: + summary_writer.header.extend(ops) + + for key in metric_value[0][ops[0]].keys(): + if Const.VPP_SEP in key: + summary_writer.header.insert(0, 'vpp_stage') + break + summary_writer.write_csv(prefix, step) + summary_writer.header = []