diff --git a/debug/accuracy_tools/kj600/README.md b/debug/accuracy_tools/kj600/README.md index 05fcb4a215b0e124b2227b1f64519819c9fa5b22..b39db6a007357581327c358e856016b737b63c2a 100644 --- a/debug/accuracy_tools/kj600/README.md +++ b/debug/accuracy_tools/kj600/README.md @@ -1,8 +1,8 @@ -# kj600 模型训练状态监控工具 +# TensorProbe (codename:kj600) 模型训练状态监控工具 ## 简介 -本项目开发了名为kj600的模型训练状态监控工具,能够收集和聚合模型训练过程中的层和优化器的中间状态,帮助诊断模型训练过程中出现的异常情况。 +本项目开发了一个模型训练状态监控工具,能够收集和聚合模型训练过程中的网络层,优化器, 通信算子的中间值,帮助诊断模型训练过程中计算, 通信,优化器各部分出现的异常情况。 ## 安装 @@ -41,16 +41,27 @@ pip install -e . "targets": { "language_model.encoder.layers.0": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"} }, - "module_ranks": "1,2,3,4", - "ur_distribution": true + "print_struct": false, + "module_ranks": [1,2,3,4], + "ur_distribution": true, + "xy_distribution": true, + "mv_distribution": true, + "wg_distribution": true, + "mg_direction": true, + "cc_distribution": {"enable":true, "cc_codeline":[]}, + "alert": { + "rules": [{"rule_name": "AnomalyTurbulence", "args": {"threshold": 0.5}}] + }, + "ops": ["min", "max", "norm", "zeros", "id"], + "eps": 1e-8 } ``` -每个要监控的module有特定的输入输出格式(依赖于模型实现),所以我们需要指定前向输入输出格式和反向计算时输入张量的梯度和输出张量的梯度格式。 如果不清楚的话可以先猜测, 格式规范与实际输入不同时会报详细错误。 我们也会随时更新更多常用module的格式规范。 +每个要监控的module有特定的输入输出格式(依赖于模型实现),所以我们需要指定前向输入输出格式和反向计算时输入张量的梯度和输出张量的梯度格式。 如果不清楚的话可以填空字段("targets":{}),然后将 "print_struct" 字段设置为 true, 之后工具会打印详细的模型结构。 我们也会随时更新更多常用module的格式规范。 下面详细解释各个字段: -"targets":必选字段,指定需要监控的大模型层, 例如transformer的第0层language_model.encoder.layers.0。如果不清楚层命名, 可以使用空的json配置文件, 之后监控工具会打印模型中torch module的名字, 你可以从中选择你关心的module。 +"targets":必选字段,指定需要监控的大模型层, 例如transformer的第0层language_model.encoder.layers.0。如果不清楚模型结构, 可以填空字段("targets":{}),然后将 "print_struct" 字段设置为 true, 之后监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出, 你可以从中选择你关心的module。 "input":可选字段,"tuple[2]:0"的意思是目标module的前向input参数为长度为2的tuple, 我们关心的是tuple第0个元素。 @@ -62,9 +73,25 @@ pip install -e . "module_ranks":可选字段,用于在分布式训练场景中希望控制在哪些rank开启module监控。如果不填,则默认在所有rank开启。 -"ur_distribution": 可选字段,若为true则会统计adam优化器的update和ratio的数值分布,并展示在heatmap里,默认为false。 +"print_struct":可选字段,设置为true后监控工具会打印模型中torch module的名字和详细结构,并在第1个step后退出。不填默认为false。 + +"ur_distribution": 可选字段,若为true则会统计adam优化器指定模块(targets中指定)参数的update和ratio向量的数值分布,并展示在heatmap里,默认为false。 + +"xy_distribution": 可选字段, 若为true则会监控指定module(targets中指定)的输入输出张量。 默认为false。 + +"mv_distribution": 可选字段, 若为true则会监控指定模块中的参数的优化器状态, 默认为false。 + +"wg_distribution": 可选字段, 若为true则会监控指定模块的参数梯度, 默认为false。 + +"alert": 必选字段。 指定自动报警的异常检测机制及其相应的阈值。目前实现的异常检测是AnomalyTurbulence。 如果统计标量超出历史均值的指定浮动范围(threshold指定, 0.5意味着上浮或者下浮50%)。 目前报警是在控制台打印, 未来会实现发邮件和写数据库。 + +"mg_direction": 可选字段,若为true则会统计adam优化器的一阶矩($m_{t-1}$)和当前梯度($g_t$)符号一致的参数比例。 + +"cc_distribution": 可选字段, 其中“enable”字段控制开关;“code_line”字段指定监控的代码行,默认为空列表,不特别指定。!!开启后, 会在监控过程让异步通信同步。 + +"ops": 可选字段,与ur_distribution、xy_distribution、mv_distribution、wg_distribution、mg_direction、cc_distribution配合,监控所选张量的min、max、norm、zeros值。其中,zeros代表监控所选张量的元素小于eps的比例,id代表监控所选的非张量本身,默认为[]。 -"mg_direction": 可选字段,若为true则会统计adam优化器的动量与当前梯度方向一致的参数比例。 +"eps": 可选字段,若ops里包含"zeros"则需要配置,默认为1e-8。 下面给出transformer架构模型中常见的module的前向计算的输入输出和反向计算输入张量的梯度和输出张量的梯度格式,以供参考: @@ -98,11 +125,14 @@ pip install -e . ``` from kj600.module_hook import TrainerMon - hooker = TrainerMon("./llama2_config.json") - hooker.hook_modules(model=model, global_batch_size=args.global_batch_size, dp=args.data_parallel_size, micro_batch_size=args.micro_batch_size, fwd_or_bkd=0) + hooker = TrainerMon("./llama2_config.json", params_have_main_grad=True, opt_ty="Megatron_DistributedOptimizer") # or opt_ty=Megatron_Float16OptimizerWithFloat16Params + hooker.hook_modules(model=model, grad_acc_steps=args.global_batch_size//args.data_parallel_size//args.micro_batch_size) ``` + params_have_main_grad: 若为True则参数权重梯度为main_grad,否则为grad,默认为True。 + + 如果不是Megatron-LM的训练框架, 可以设置对应的梯度累积步数grad_acc_steps。 - 如果要监控混合精度优化器的动量和方差, 需要在混合精度优化器构造后加入如下代码: + 如果要监控混合精度优化器的动量和方差, 需要在混合精度优化器构造后加入如下代码。 目前只支持Megatron_DistributedOptimizer, 使用bf16或者fp16混合精度时开启分布式优化器。 或者Megatron_Float16OptimizerWithFloat16Params, 使用bf16或者fp16混合精度选项并且不开启分布式优化器。 ``` model, optimizer, opt_param_scheduler = setup_model_and_optimizer( diff --git a/debug/accuracy_tools/kj600/img/cpu_info.png b/debug/accuracy_tools/kj600/img/cpu_info.png new file mode 100644 index 0000000000000000000000000000000000000000..c69eb61b11be5901428fd20b3d5f69909efffafb Binary files /dev/null and b/debug/accuracy_tools/kj600/img/cpu_info.png differ diff --git a/debug/accuracy_tools/kj600/img/train.png b/debug/accuracy_tools/kj600/img/train.png new file mode 100644 index 0000000000000000000000000000000000000000..2dde057196dd934d29fb766cd9fea3ff527a696a Binary files /dev/null and b/debug/accuracy_tools/kj600/img/train.png differ diff --git a/debug/accuracy_tools/kj600/img/train_with_kj600.png b/debug/accuracy_tools/kj600/img/train_with_kj600.png new file mode 100644 index 0000000000000000000000000000000000000000..b64a6d1f48004ac246b53f381f863348f58d196c Binary files /dev/null and b/debug/accuracy_tools/kj600/img/train_with_kj600.png differ diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_detect.py b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce34fe22aaf13f02ccc0a5333a161186d77cd0f --- /dev/null +++ b/debug/accuracy_tools/kj600/kj600/anomaly_detect.py @@ -0,0 +1,86 @@ +import statistics as st +from abc import ABC +from typing import List +import sys +from torch.utils.tensorboard import SummaryWriter +from collections import defaultdict + +class ScanRule(ABC): + def apply(self, history, cur): + raise NotImplemented("abstract method apply is not implemented") + +class AnomalyTurbulence(ScanRule): + name = "AnomalyTurbulence" + def __init__(self, threshold) -> None: + self.threshold = threshold + def apply(self, history, cur): + baseline = st.mean(history) if isinstance(history, list) else history + + up_bound = baseline + baseline * self.threshold + if baseline > 0: + return cur > up_bound + else: + return cur < up_bound + +class AnomalyScanner: + + @staticmethod + def load_rules(specs: List[dict]): + if specs is None: + return [] + alert_rules = [] + for spec in specs: + rule_cls_name = spec["rule_name"] + rule_args = spec["args"] + cur_module = sys.modules[__name__] + rule_cls = getattr(cur_module, rule_cls_name) + rule_instance = rule_cls(**rule_args) + alert_rules.append(rule_instance) + return alert_rules + + @staticmethod + def scan(scan_rules: List[ScanRule], history, cur): + anomaly = False + for rule in scan_rules: + anomaly = rule.apply(history, cur) + if anomaly: + return anomaly, rule.name + return anomaly, None + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +class SummaryWriterWithAD(SummaryWriter): + def __init__(self, path, ad_rules, anomaly_inform=False): + super().__init__(path) + self.tag2scalars = defaultdict(list) + self.ad_rules = ad_rules + self.anomaly_inform = anomaly_inform + + def _ad(self, scalar_value, history): + + return AnomalyScanner.scan(self.ad_rules, history, cur=scalar_value) + + def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False, double_precision=False): + new_avg = avg = scalar_value + if tag in self.tag2scalars: + N = len(self.tag2scalars[tag]) + _, avg = self.tag2scalars[tag][-1] + new_avg = (avg*N + scalar_value)/(N + 1) + self.tag2scalars[tag].append((scalar_value, new_avg)) + detected, rule_name = self._ad(scalar_value, history=avg) + if detected: + print(f"{bcolors.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}.{bcolors.ENDC}") + exception_message = f"{bcolors.WARNING}> Rule {rule_name} reports anomaly signal in {tag} at step {global_step}.{bcolors.ENDC}" + if self.anomaly_inform: + self.anomaly_inform.run(exception_message) + return super().add_scalar(tag, scalar_value, global_step, walltime, new_style, double_precision) +# if __name__ == "__main__": diff --git a/debug/accuracy_tools/kj600/kj600/anomaly_inform.py b/debug/accuracy_tools/kj600/kj600/anomaly_inform.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdafdaf827e5ac6658bccc0de83294d9f313602 --- /dev/null +++ b/debug/accuracy_tools/kj600/kj600/anomaly_inform.py @@ -0,0 +1,75 @@ +import smtplib +from email.mime.text import MIMEText +import sqlite3 +from datetime import datetime, timedelta + +# define class InformRegistry to get inform_sub_class +class AnomalyInformFactory: + @staticmethod + def create_informer(**kwargs): + if kwargs['recipient'] == "database": + return DatabaseInform(**kwargs) + elif kwargs['recipient'] == "email": + return EmailInform(**kwargs) + else: + raise ValueError("Invaild recipient specified") + +# define class AnomalyInform to inform with database or email +class AnomalyInform: + def __init__(self, **kwargs): + self.inform_args = kwargs + self.exception_message_list = [] + self.time = 0 + self.current_time = 0 + + def inform_fun(self, exception_message_list): + pass + + def run(self, exception_message): + if self.time != 0 and self.current_time == 0: + self.current_time = datetime.now() + if self.time == 0 or ((self.current_time - self.time) > timedelta(minutes=self.interval_time)): + self.exception_message_list.append(exception_message) + self.inform_fun(self.exception_message_list) + self.exception_message_list = [] + self.time = datetime.now() + elif (self.current_time - self.time) <= timedelta(minutes=self.interval_time): + self.exception_message_list.append(exception_message) + self.current_time = datetime.now() + +class DatabaseInform(AnomalyInform): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.interval_time = 2 + + def inform_fun(self, exception_message_list): + with sqlite3.connect(self.inform_args['connection_str']) as conn: + cursor = conn.cursor() + cursor.execute('''CREATE TABLE IF NOT EXISTS exceptions( + id INTEGER PRIMARY KEY, + message TEXT + )''') + now_time = datetime.now() + for exception_message in exception_message_list: + exception_message = f"Current time is :{now_time}" + exception_message + cursor.execute("INSERT INTO exceptions (message) VALUES (?)",(exception_message,)) + +class EmailInform(AnomalyInform): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.interval_time = 10 + + def inform_fun(self, exception_message_list): + subject = "Exception Detected in Your Program" + text = f"{len(exception_message_list)} exception was detected in your program:\n\n" + for exception_message in exception_message_list: + text += exception_message + '\n' + message = MIMEText(text, "plain") + message["Subject"] = subject + message["From"] = self.inform_args['email'] + message["To"] = self.inform_args['email'] + + with smtplib.SMTP(self.inform_args['smtp_server_name'], self.inform_args.get('smtp_number', 587)) as server: + server.starttls() + server.login(self.inform_args['id'], self.inform_args['password']) + server.sendmail(self.inform_args['email'], self.inform_args['email'], message.as_string()) diff --git a/debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml b/debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51c803eb0b075a9afe0dbe68b4c84535ae867c60 --- /dev/null +++ b/debug/accuracy_tools/kj600/kj600/distributed/distributed_ops.yaml @@ -0,0 +1,14 @@ +distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - _reduce_scatter_base + - _all_gather_base \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py b/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..ba559baa4480ead844b99d3709f245af247c9138 --- /dev/null +++ b/debug/accuracy_tools/kj600/kj600/distributed/wrap_distributed.py @@ -0,0 +1,156 @@ +import os +from functools import wraps +from collections import defaultdict +import yaml +import re +import inspect +import functools +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.utils.hooks as full_hooks + +from ..module_metric import get_metrics + +try: + import torch_npu +except ImportError: + is_gpu = True +else: + is_gpu = False + + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "distributed_ops.yaml") +with open(yaml_path) as f: + WrapDistributedOps = yaml.safe_load(f).get('distributed') + +npu_distributed_api = ['isend', 'irecv'] + +distributed_func = {} +for f in dir(dist): + distributed_func[f] = getattr(dist, f) + + +def get_distributed_ops(): + global WrapDistributedOps + _all_distributed_ops = dir(dist) + return set(WrapDistributedOps) & set(_all_distributed_ops) + + +class DistributedOPTemplate(nn.Module): + def __init__(self, op_name, hook): + super(DistributedOPTemplate, self).__init__() + self.op_name_ = op_name + self.prefix_op_name_ = str(op_name) + self.register_forward_hook(hook(), with_kwargs=True) + + def forward(self, *args, **kwargs): + return distributed_func.get(self.op_name_)(*args, **kwargs) + + +class ApiRegistry: + def __init__(self): + self.distributed_attr_origin = {} + self.distributed_attr_hooked = {} + + @staticmethod + def store_ori_attr(ori_api_group, api_list, api_ori_attr): + for api in api_list: + if '.' in api: + sub_module_name, sub_op = api.rsplit('.', 1) + sub_module = getattr(ori_api_group, sub_module_name) + api_ori_attr[api] = getattr(sub_module, sub_op) + else: + api_ori_attr[api] = getattr(ori_api_group, api) + + @staticmethod + def set_api_attr(api_group, attr_dict): + for cc_api_name, cc_api_entry_func in attr_dict.items(): + if '.' in cc_api_name: + sub_module_name, sub_op = cc_api_name.rsplit('.', 1) + sub_module = getattr(api_group, sub_module_name, None) + if sub_module is not None: + setattr(sub_module, sub_op, cc_api_entry_func) + else: + setattr(api_group, cc_api_name, cc_api_entry_func) + + def redirect_api(self): + self.set_api_attr(dist, self.distributed_attr_hooked) + self.set_api_attr(dist.distributed_c10d, self.distributed_attr_hooked) + + def restore_api(self): + self.set_api_attr(dist, self.distributed_attr_origin) + self.set_api_attr(dist.distributed_c10d, self.distributed_attr_origin) + + def initialize_hook(self, hook): + self.store_ori_attr(dist, get_distributed_ops(), self.distributed_attr_origin) + for op_name in get_distributed_ops(): + self.distributed_attr_hooked[op_name] = DistributedOPTemplate(op_name, hook) + + +def get_callstack(): + callstack = [] + for (_, path, line, func, code, _) in inspect.stack(): + stack_line = f'{path}[{line}]' + callstack.append(stack_line) + return callstack + +def op_aggregate(op, t1, t2): + if op == 'min': + return min(t1, t2) + if op == 'max': + return max(t1, t2) + if op == 'norm': + return (t1**2+t2**2)**0.5 + if op == 'zeros': # TODO wrong + return (t1+t2)/2 + +def update_data(old, new): + updated = {op:{} for op in new.keys()} + if old: + for op, tag2tensor in old.items(): + for tag, t_old in tag2tensor.items(): + t_new = new[op][tag] + updated[op][tag] = op_aggregate(op, t_old, t_new) + else: + updated = new + return updated + +def create_hook(context, monitor): + def cc_hook(module, args, kwargs, out=None): + args = args + tuple(kwargs.values()) + if out: + out.wait() + if (dist.is_initialized() and dist.get_rank() not in monitor.module_rank_list): + return out + stack = get_callstack() + whole_stack = ';'.join(stack) + is_target = monitor.cc_codeline == [] + for pattern in monitor.cc_codeline: + if re.search(pattern, whole_stack): + is_target = True + break + if not is_target: + return out + tensor_args = {} + for arg in args: + if isinstance(arg, torch.Tensor): + tensor_args[f'input_{len(tensor_args)}'] = arg + elif isinstance(arg, list): + arg = torch.stack(arg) + tensor_args[f'input_{len(tensor_args)}'] = arg + new_data = {op: get_metrics(op, tensor_args, 1e-8) for op in monitor.ops} + context[module.prefix_op_name_].indata=update_data(context[module.prefix_op_name_].indata, new_data) + if out and isinstance(out, dist.Work): + tensor_res = {} + for res in out.result(): + if isinstance(res, torch.Tensor): + tensor_res[f'output_{len(tensor_res)}'] = res + new_data = {op: get_metrics(op, tensor_res, 1e-8) for op in monitor.ops} + context[module.prefix_op_name_].outdata=update_data(context[module.prefix_op_name_].outdata, new_data) + return out + return cc_hook + +api_register = ApiRegistry() + diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py index b4fc8f3085ee606972ea6fcdf32a56196a713877..be54215241fb9d1a384ee0944b62a59ea28a0afa 100644 --- a/debug/accuracy_tools/kj600/kj600/features.py +++ b/debug/accuracy_tools/kj600/kj600/features.py @@ -6,6 +6,26 @@ from torch.autograd.functional import jacobian def square_sum(x: torch.tensor): return (x * x).sum() +@torch.no_grad() +def get_min(x: torch.tensor): + return torch.min(x) + + +@torch.no_grad() +def get_max(x: torch.tensor): + return torch.max(x) + + +@torch.no_grad() +def get_zeros(x: torch.tensor, eps: float): + return torch.sum(torch.abs(x) < eps) / x.numel() + +@torch.no_grad() +def get_sign_matches(x: torch.tensor, y:torch.tensor): + xs = x.sign() + ys = y.sign() + same_direction_ratio = ((xs * ys).sum()/ys.numel() + 1)/2 + return same_direction_ratio @torch.no_grad() def eff_rank(param: torch.tensor, threshold=1e-10): diff --git a/debug/accuracy_tools/kj600/kj600/module_hook.py b/debug/accuracy_tools/kj600/kj600/module_hook.py index 233b000f88bbca32b381a2dc7e97922a3776537e..8a35c646f53ea3f62ab232dbcde21de3b83a7399 100644 --- a/debug/accuracy_tools/kj600/kj600/module_hook.py +++ b/debug/accuracy_tools/kj600/kj600/module_hook.py @@ -1,25 +1,21 @@ import os import uuid +import json from collections import defaultdict -from typing import List from datetime import datetime +from functools import partial import torch -from torch.nn.modules.module import register_module_forward_hook import torch.distributed as dist from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook -from torch.utils.tensorboard import SummaryWriter -from kj600.features import square_sum from kj600.module_spec_verifier import get_config, validate_config_spec -from kj600.optimizer_collect import MixPrecsionOptimizerMon, print_rank_0 -from kj600.features import eff_rank +from kj600.optimizer_collect import MixPrecsionOptimizerMon, print_rank_0, OptimizerMonFactory, MegatronDistributedOptimizerMon +from kj600.features import eff_rank, get_sign_matches from kj600.visualizer import HeatmapVisualizer - - -def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank): - if rank is None: - return f"{module_or_param_name}/{tag}" - else: - return f"{module_or_param_name}/{rank}/{tag}" +from kj600.anomaly_detect import AnomalyScanner, SummaryWriterWithAD +from kj600.anomaly_inform import AnomalyInformFactory +from kj600.module_metric import get_metrics, write_metrics_tensorboard, get_summary_writer_tag_name +from kj600.distributed.wrap_distributed import api_register, create_hook +from kj600.utils import print_warn_log, print_info_log, get_param_struct class ModuleHookContext: @@ -47,41 +43,80 @@ class ModuleHookContext: class OptimizerContext: def __init__(self) -> None: self.step = 0 - self.param_gnorm = defaultdict(float) # norm of grad - self.param_exp_avg_norm = defaultdict(float) # norm of expection of gradient average (m_{t-1}) - self.param_exp_avg_sign = defaultdict(int) # sign of expection of gradient average (m_{t-1}) - self.param_mg_direction = defaultdict(float) # ratio of parameters in same direction between g_{t} and m_{t-1} - self.param_exp_avg_sq_norm = defaultdict(float) # norm of expection of gradient square (v_{t-1}) - self.param_effective_rank = defaultdict(float) # ratio of parameters above a threshold - self.param_adam_update = defaultdict() # distribution of update (m_t/(v_t**0.5+eps)) - self.param_adam_ratio = defaultdict() # distribution of ratio (m_t/v_t**0.5) + self.param_effective_rank = defaultdict(float) + self.param_mg_direction = defaultdict(float) + self.param_adam_update = defaultdict() + self.param_adam_ratio = defaultdict() + self.param_weight_grad = defaultdict() + self.param_exp_avg = defaultdict() + self.param_exp_avg_sq = defaultdict() + self.metric_list = [] + +class CommunicationContext: + def __init__(self) -> None: + self.indata = {} + self.outdata = {} + + def reset(self): + self.indata = {} + self.outdata = {} class TrainerMon: - + @staticmethod def set_wrapped_optimizer(_wrapped_optimizer): MixPrecsionOptimizerMon.set_wrapped_optimizer(_wrapped_optimizer) - def __init__(self, config_file_path) -> None: + # opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer" + def __init__(self, config_file_path, params_have_main_grad=True, opt_ty=None) -> None: self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext) self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext) self.optimizer_context = defaultdict(OptimizerContext) - self.params_have_main_grad = True + self.cc_context = defaultdict(CommunicationContext) + self.params_have_main_grad = params_have_main_grad self.config = get_config(config_file_path) - self.module_rank_list = [int(rank) for rank in self.config.get("module_ranks", "").split(',') if rank.strip()] + self.module_rank_list = self.config.get("module_ranks", []) + self.eps = self.config.get('eps', 1e-8) + self.ops = self.config.get('ops', []) + self.xy_distribution = self.config.get('xy_distribution', False) + if not self.xy_distribution: + print_rank_0("> module input/output input_grad/output_grad is not monitored. ") self.ur_distribution = self.config.get('ur_distribution', False) + if not self.ur_distribution: + print_rank_0("> update vector and ratio vector of adam is not monitored. ") + self.mv_distribution = self.config.get("mv_distribution", False) + if not self.mv_distribution: + print_rank_0("> momentum and variance of adam is not monitored. ") + self.wg_distribution = self.config.get("wg_distribution", False) + if not self.wg_distribution: + print_rank_0("> weight grad of specified module is not monitored. ") self.mg_direction = self.config.get('mg_direction', False) + if not self.mg_direction: + print_rank_0('> grad and momentum direction will not be compared.') + self.cc_distribution = self.config.get("cc_distribution", {}) + if not self.cc_distribution.get('enable', False): + print_rank_0("> cc operator is not monitored.") + else: + self.cc_codeline = self.cc_distribution.get('cc_codeline', []) + api_register.initialize_hook(partial(create_hook, context=self.cc_context, monitor=self)) + api_register.redirect_api() + alert_setting = self.config.get('alert', {"rules":[]}) + self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"]) + + anomaly_inform = AnomalyInformFactory.create_informer(**alert_setting["inform"]) if "inform" in alert_setting else None + self.optimizer_hooked = False output_base_dir = os.getenv('KJ600_OUTPUT_DIR', './kj600_output') cur_time = datetime.now().strftime('%b%d_%H-%M-%S') unique_id = str(uuid.uuid4())[:8] if dist.is_initialized(): if (dist.get_rank() in self.module_rank_list) or len(self.module_rank_list) == 0: - self.summary_writer = SummaryWriter(os.path.join(output_base_dir, f"{cur_time}-rank{dist.get_rank()}-{unique_id}")) + self.summary_writer = SummaryWriterWithAD( + os.path.join(output_base_dir, f"{cur_time}-rank{dist.get_rank()}-{unique_id}"), self.alert_rules, anomaly_inform) else: - self.summary_writer = SummaryWriter(os.path.join(output_base_dir, f"{cur_time}-{unique_id}")) + self.summary_writer = SummaryWriterWithAD(os.path.join(output_base_dir, f"{cur_time}-{unique_id}"), self.alert_rules, anomaly_inform) # A HeatmapVisualizer instance is associated with an image self.update_heatmap_visualizer = defaultdict(HeatmapVisualizer) self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer) @@ -90,21 +125,38 @@ class TrainerMon: self.param_name_list = [] self.param2name = defaultdict(str) - self.mix_precision_optimizer_mon = MixPrecsionOptimizerMon() + self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty) + if opt_ty is None: + assert not self.ur_distribution, "ur_distribution cannot be enabled with unknown optimizer." + assert not self.mv_distribution, "mv_distribution cannot be enabled with unknown optimizer." + self.print_struct = self.config.get("print_struct", False) + self.module_struct = {} return - + def __del__(self): if hasattr(self, "summary_writer"): self.summary_writer.close() - def _hook_module(self, target_name:str, module: torch.nn.Module, fwd_or_bkd): - paths = target_name.split('.') + def _smallest_rank_print(self, msg): + if dist.is_initialized(): + if dist.get_rank() == min(self.module_rank_list): + print_info_log(msg) + else: + print_info_log(msg) + + def _hook_module(self, target_names, module: torch.nn.Module, fwd_or_bkd): if '_modules' not in module.__dict__: # nothing to hook return 0 - + def fwd_hook_fun(module, module_input, module_output): - context = self.module_fwd_hook_context_by_module[module] + context: ModuleHookContext = self.module_fwd_hook_context_by_module[module] + if self.print_struct: + self.module_struct[context.module_name].update( + {"input": f"{get_param_struct(module_input)}", "output": f"{get_param_struct(module_output)}"}) + return + if not self.xy_distribution: + return if not context.format_by_arg: context.set_format_by_arg('input', self.config['targets']) context.set_format_by_arg('output', self.config['targets']) @@ -114,22 +166,35 @@ class TrainerMon: context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output, context.module_name, 'output') context.verified = True # expect output be tensor type + tbtag_tensor_map = {} if not context.ignore_in: cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col] - cared_input_cal_result = square_sum(cared_input) - else: - cared_input_cal_result = None + tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'input', cared_input)) cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] - context.actv.append((cared_input_cal_result, square_sum(cared_output))) + tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output', cared_output)) + metric_dict = {} + for metric_name in self.ops: + metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) + if context.micro_step == 0 and context.actv: + print_warn_log( + f"actv context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + context.actv.clear() + context.actv.append(metric_dict) context.micro_step += 1 if context.micro_step == self.micro_batch_number: context.micro_step = 0 context.step += 1 return - + def bwd_hook_fun(module, input_grad, output_grad): - context = self.module_bwd_hook_context_by_module[module] + context: ModuleHookContext = self.module_bwd_hook_context_by_module[module] + if self.print_struct: + self.module_struct[context.module_name].update( + {"input_grad": f"{get_param_struct(input_grad)}", "output_grad": f"{get_param_struct(output_grad)}"}) + return + if not self.xy_distribution: + return if not context.format_by_arg: context.set_format_by_arg('input_grad', self.config['targets']) context.set_format_by_arg('output_grad', self.config['targets']) @@ -138,44 +203,53 @@ class TrainerMon: context.focused_in_col = validate_config_spec(context.format_by_arg['input_grad'], input_grad, context.module_name, 'input_grad') context.focused_out_col = validate_config_spec(context.format_by_arg['output_grad'], output_grad, context.module_name, 'output_grad') context.verified = True + + tbtag_tensor_map = {} if not context.ignore_in: cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col] - cared_input_grad_cal_result = square_sum(cared_input_grad) if cared_input_grad is not None else torch.tensor(0.) - else: - cared_input_grad_cal_result = None + tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'input_grad', cared_input_grad)) cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col] - context.actvgrad.append((cared_input_grad_cal_result, square_sum(cared_output_grad))) + tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output_grad', cared_output_grad)) + metric_dict = {} + for metric_name in self.ops: + metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) + if context.micro_step == 0 and context.actvgrad: + print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.") + context.actvgrad.clear() + context.actvgrad.append(metric_dict) + context.micro_step += 1 if context.micro_step == self.micro_batch_number: context.micro_step = 0 context.step += 1 return - + + hooked_count = 0 for name, submodule in module.named_modules(): - if name == target_name: + self.module_struct[name] = {} + if name in target_names: submodule.register_forward_hook(fwd_hook_fun) self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name) submodule.register_full_backward_hook(bwd_hook_fun) self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name) print_rank_0(f"> {name} is monitored successfully") - return 1 - return 0 + hooked_count += 1 + return hooked_count - def hook_modules(self, model:torch.nn.Module, global_batch_size, dp, micro_batch_size, fwd_or_bkd, params_have_main_grad=True): + def hook_modules(self, model:torch.nn.Module, grad_acc_steps): # fwd=0, bkd=1 # targets is module name list like ["xx.xxx1", "xxx.xxx2"] which can be obtained when first run. print_rank_0("> module names:") for name, _ in model.named_modules(): print_rank_0(f"\t{name}") - self.micro_batch_number = global_batch_size // dp // micro_batch_size - + self.micro_batch_number = grad_acc_steps + if not self.module_rank_list or (dist.is_initialized() and dist.get_rank() in self.module_rank_list): - hooked = 0 - for target, _ in self.config['targets'].items(): - hooked += self._hook_module(target, model, fwd_or_bkd=0) - print_rank_0(f"> {hooked} out of {len(self.config['targets'])} are monitored.") + targets = [x for x, _ in model.named_modules()] if self.print_struct else self.config['targets'].keys() + hooked_count = self._hook_module(targets, model, fwd_or_bkd=0) + print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.") else: - return + return if not self.optimizer_hooked: self.optimizer_hooked = True @@ -187,72 +261,125 @@ class TrainerMon: self.param_name_list.append(name) self.param2name[param] = name self.hook_optimizer() - self.params_have_main_grad = params_have_main_grad return + + def build_tbtag_tensor_map(self, module_name, tag, tensor): + metrics = {} + rank = dist.get_rank() if dist.is_initialized() else None + key = get_summary_writer_tag_name(module_name, tag, rank) + if tensor is not None: + metrics[key] = tensor + return metrics + + def generate_param_metrics(self, tag, param_tensor): + metrics = {} + rank = dist.get_rank() if dist.is_initialized() else None + for param, name in self.param2name.items(): + key = get_summary_writer_tag_name(name, tag, rank) + if name not in param_tensor or param_tensor[name] is None: + continue + metrics[key] = param_tensor[name] + return metrics + def generate_cc_metrics(self, cc_name, cc_tensor): + metrics = defaultdict(dict) + rank = dist.get_rank() if dist.is_initialized() else None + for op, tag2tensor in cc_tensor.indata.items(): + for tag, tensor in tag2tensor.items(): + key = get_summary_writer_tag_name(cc_name, tag, rank) + metrics[op].update({key: tensor}) + for op, tag2tensor in cc_tensor.outdata.items(): + for tag, tensor in tag2tensor.items(): + key = get_summary_writer_tag_name(cc_name, tag, rank) + metrics[op].update({key: tensor}) + cc_tensor.reset() + return metrics + + def write_xy_tb(self, step): + if not self.xy_distribution: + return + for _, fwd_context in self.module_fwd_hook_context_by_module.items(): + if not len(fwd_context.actv) == self.micro_batch_number: + print_warn_log(f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}") + for metric_name in self.ops: + write_metrics_tensorboard(metric_name, self.summary_writer, fwd_context.actv, step) + fwd_context.actv.clear() + + for _, bwd_context in self.module_bwd_hook_context_by_module.items(): + if not len(bwd_context.actvgrad) == self.micro_batch_number: + print_warn_log(f"bwd_context.actvgrad not equal to micro_batch_number: {len(bwd_context.actvgrad)}, {self.micro_batch_number}") + for metric_name in self.ops: + write_metrics_tensorboard(metric_name, self.summary_writer, bwd_context.actvgrad, step) + bwd_context.actvgrad.clear() + def hook_optimizer(self): # in DDP by default use params_have_main_grad def optimizer_pre_step_hook(optimizer, args, kwargs): + if self.print_struct and not all(value == {} for value in self.module_struct.values()): + self._smallest_rank_print("> module struct:") + self._smallest_rank_print(json.dumps(self.module_struct, indent=4)) + raise Exception("exit after first step when print model struct") context = self.optimizer_context[optimizer] - context.param_exp_avg_norm, context.param_exp_avg_sign, context.param_exp_avg_sq_norm, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv( - optimizer, self.param2name, self.update_heatmap_visualizer, self.ratio_heatmap_visualizer, self.ur_distribution, self.mg_direction) + context.param_exp_avg, context.param_exp_avg_sq, context.param_adam_update, context.param_adam_ratio = self.mix_precision_optimizer_mon.fetch_mv(self, + optimizer, self.param2name) for param, name in self.param2name.items(): - grad = param.main_grad if self.params_have_main_grad else param.grad - context.param_gnorm[name] = grad.detach().norm() if "params_effrank" in self.config and name in self.config["params_effrank"]: context.param_effective_rank[name] = eff_rank(param.detach()) - + grad = param.main_grad if self.params_have_main_grad else param.grad + if grad is None: + print_warn_log(f"grad is None: {name}, maybe something wrong happened.") + continue + if self.wg_distribution: + context.param_weight_grad[name] = grad if self.mg_direction: - if name in context.param_exp_avg_sign: - g_sign = grad.detach().sign() - m_sign = context.param_exp_avg_sign.pop(name) - same_direction_ratio = ((m_sign * g_sign).sum().item()/m_sign.numel() + 1)/2 + if context.step == 0: + same_direction_ratio = torch.tensor(1.) else: - same_direction_ratio = 1 + same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name]) context.param_mg_direction[name] = same_direction_ratio + tbtag_tensor_map = {} + if self.wg_distribution: + tbtag_tensor_map.update(self.generate_param_metrics('weight_grad', context.param_weight_grad)) + if self.mv_distribution: + tbtag_tensor_map.update(self.generate_param_metrics('exp_avg', context.param_exp_avg)) + tbtag_tensor_map.update(self.generate_param_metrics('exp_avg_sq', context.param_exp_avg_sq)) + if self.mg_direction: + tbtag_tensor_map.update(self.generate_param_metrics('mg_direction', context.param_mg_direction)) + # if not tbtag_tensor_map: + # return + metric_dict = {} + for metric_name in self.ops: + metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) + if self.cc_distribution: + for k, c in self.cc_context.items(): + cc_metrics = self.generate_cc_metrics(k, c) + for op, m in cc_metrics.items(): + metric_dict[op].update(m) + if not metric_dict: + return + context.metric_list.append(metric_dict) return - + def optimizer_post_step_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] rank = dist.get_rank() if dist.is_initialized() else None - for _, fwd_context in self.module_fwd_hook_context_by_module.items(): - if not len(fwd_context.actv) == self.micro_batch_number: - raise Exception(f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}") - if not fwd_context.ignore_in: - x_norm = sum([x.item() for x, _ in fwd_context.actv]) - self.summary_writer.add_scalar(get_summary_writer_tag_name(fwd_context.module_name, 'input', rank), x_norm, context.step) - y_norm = sum([y.item() for _, y in fwd_context.actv]) - self.summary_writer.add_scalar(get_summary_writer_tag_name(fwd_context.module_name, 'output', rank), y_norm, context.step) - fwd_context.actv.clear() - - for _, bwd_context in self.module_bwd_hook_context_by_module.items(): - if not len(bwd_context.actvgrad) == self.micro_batch_number: - raise Exception(f"fwd_context.actvgrad not equal to micro_batch_number: {len(fwd_context.actvgrad)}, {self.micro_batch_number}") - if not bwd_context.ignore_in: - x_grad_norm = sum([x.item() for x, _ in bwd_context.actvgrad]) - self.summary_writer.add_scalar(get_summary_writer_tag_name(bwd_context.module_name, 'input_grad', rank), x_grad_norm, context.step) - y_grad_norm = sum([y.item() for _, y in bwd_context.actvgrad]) - self.summary_writer.add_scalar(get_summary_writer_tag_name(bwd_context.module_name, 'output_grad', rank), y_grad_norm, context.step) - bwd_context.actvgrad.clear() - - for param_name, grad_norm in context.param_gnorm.items(): - self.summary_writer.add_scalar(get_summary_writer_tag_name(param_name, 'weight_grad', rank), grad_norm.item(), context.step) - - for param_name, exp_avg_norm in context.param_exp_avg_norm.items(): - self.summary_writer.add_scalar(get_summary_writer_tag_name(param_name, 'exp_avg_norm', rank), exp_avg_norm.item(), context.step) - for param_name, exp_avg_sq_norm in context.param_exp_avg_sq_norm.items(): - self.summary_writer.add_scalar(get_summary_writer_tag_name(param_name, 'exp_avg_sq_norm', rank), exp_avg_sq_norm.item(), context.step) + + self.write_xy_tb(context.step) + if self.ur_distribution: for param_name, _ in context.param_adam_update.items(): self.update_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_update', rank), context.step, self.summary_writer) for param_name, _ in context.param_adam_ratio.items(): self.ratio_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer) - if self.mg_direction: - for param_name, mg_direction in context.param_mg_direction.items(): - self.summary_writer.add_scalar(get_summary_writer_tag_name(param_name, 'adam_mg_direction', rank), mg_direction, context.step) + + for metric_name in self.ops: + if not context.metric_list: + break + write_metrics_tensorboard(metric_name, self.summary_writer, context.metric_list, context.step) + context.metric_list.clear() context.step += 1 return diff --git a/debug/accuracy_tools/kj600/kj600/module_metric.py b/debug/accuracy_tools/kj600/kj600/module_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..d42749d2be3a230c56aa28a4e983dd6d989c003e --- /dev/null +++ b/debug/accuracy_tools/kj600/kj600/module_metric.py @@ -0,0 +1,125 @@ +import math +import statistics + +from kj600.features import square_sum, get_max, get_min, get_zeros + + +def get_summary_writer_tag_name(module_or_param_name:str, tag:str, rank): + if rank is None: + return f"{module_or_param_name}/{tag}" + else: + return f"{module_or_param_name}/{rank}/{tag}" + + +# 用于存储所有metric实现类的注册表 +config_metric_registry = {} + + +def register_config_metric(key, cls=None): + """装饰器 用于注册Metric的实现类""" + if cls is None: + # 无参数时,返回装饰器函数 + return lambda cls: register_config_metric(key, cls) + config_metric_registry[key] = cls + return cls + + +class Metric(object): + @staticmethod + def get_metric_value(tensor, eps): + pass + + def get_metrics(self, tag2tensor: dict, eps): + metrics_dict = {} + for tag, tensor in tag2tensor.items(): + metrics_dict[tag] = self.get_metric_value(tensor, eps) + return metrics_dict + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + pass + + +@register_config_metric("min") +class MinMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return get_min(tensor) + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + for key in metric_value[0][metric_name].keys(): + min_value = min([item[metric_name][key].item() for item in metric_value]) + summary_writer.add_scalar(f'{key}_min', min_value, step) + + +@register_config_metric("max") +class MaxMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return get_max(tensor) + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + for key in metric_value[0][metric_name].keys(): + max_value = max([item[metric_name][key].item() for item in metric_value]) + summary_writer.add_scalar(f'{key}_max', max_value, step) + + +@register_config_metric("norm") +class NormMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return square_sum(tensor) + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + for key in metric_value[0][metric_name].keys(): + norm_value = math.sqrt(sum([item[metric_name][key].item() for item in metric_value])) + summary_writer.add_scalar(f'{key}_norm', norm_value, step) + + +@register_config_metric("zeros") +class ZerosMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + return get_zeros(tensor, eps) + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): + for key in metric_value[0][metric_name].keys(): + zeros_value = statistics.mean([item[metric_name][key].item() for item in metric_value]) + summary_writer.add_scalar(f'{key}_zeros', zeros_value, step) + + +@register_config_metric("id") +class IdentMetric(Metric): + @staticmethod + def get_metric_value(tensor, eps): + if tensor.dim() != 0: + return None + return tensor + + @staticmethod + def metric_tensorboard(metric_name, summary_writer, metric_value, step): #metric_value is a dict, key is parameter name and value is a list of scalar tensor + if len(metric_value) == 1: + for key, value in metric_value[0][metric_name].items(): + if not value: + continue + summary_writer.add_scalar(f'{key}_identical', value.item(), step) + + +def get_metrics(metric_name, tag2tensor, eps): + try: + fun_metric = config_metric_registry[metric_name] + return fun_metric().get_metrics(tag2tensor, eps) + except KeyError as e: + raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") + + +def write_metrics_tensorboard(metric_name, summary_writer, metric_value, step): + try: + fun_metric = config_metric_registry[metric_name] + return fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step) + except KeyError as e: + raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") diff --git a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py index 44f478416cc30054d566d2167426095eed941210..dfb473ca074f135d809e32c85a8fc9b4047da4d3 100644 --- a/debug/accuracy_tools/kj600/kj600/optimizer_collect.py +++ b/debug/accuracy_tools/kj600/kj600/optimizer_collect.py @@ -22,20 +22,10 @@ class MixPrecsionOptimizerMon: def __init__(self) -> None: self.fp16_to_fp32_param = {} - - # parameter tensors we want to monitor and their names are in params2name_dict - # base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer - def fetch_mv(self, torch_opt, params2name, update_heatmap_visualizer, ratio_heatmap_visualizer, ur_distribution, mg_direction): - mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer - - if not self.fp16_to_fp32_param and mix_prec_opt is not None: - for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups): - for fp16_param, fp32_param in zip(fp16_group, fp32_group): - self.fp16_to_fp32_param[fp16_param] = fp32_param - exp_avg_norm_dict = defaultdict(float) - exp_avg_sign_dict = defaultdict(int) - exp_avg_sq_norm_dict = defaultdict(float) + def _fetch_mv_in_adam(self, params2name, torch_opt, monitor): + exp_avg_dict = defaultdict(float) + exp_avg_sq_dict = defaultdict(float) update_dict = defaultdict() ratio_dict = defaultdict() @@ -46,16 +36,52 @@ class MixPrecsionOptimizerMon: if param in torch_opt.state: exp_avg = torch_opt.state[param]["exp_avg"] exp_avg_sq = torch_opt.state[param]["exp_avg_sq"] - exp_avg_norm = exp_avg.detach().norm() - exp_avg_sq_norm = exp_avg_sq.detach().norm() - exp_avg_norm_dict[name] = exp_avg_norm - exp_avg_sq_norm_dict[name] = exp_avg_sq_norm - if mg_direction: - exp_avg_sign_dict[name] = exp_avg.detach().sign() - if ur_distribution: + if monitor.mv_distribution: + exp_avg_dict[name] = exp_avg + exp_avg_sq_dict[name] = exp_avg_sq + if monitor.mg_direction: + exp_avg_dict[name] = exp_avg + if monitor.ur_distribution: update_dict[name] = exp_avg / (torch.sqrt(exp_avg_sq) + torch_opt.defaults['eps']) ratio_dict[name] = exp_avg / torch.sqrt(exp_avg_sq) - update_heatmap_visualizer[name].pre_cal(update_dict[name]) - ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) - - return exp_avg_norm_dict, exp_avg_sign_dict, exp_avg_sq_norm_dict, update_dict, ratio_dict + monitor.update_heatmap_visualizer[name].pre_cal(update_dict[name]) + monitor.ratio_heatmap_visualizer[name].pre_cal(ratio_dict[name]) + return exp_avg_dict, exp_avg_sq_dict, update_dict, ratio_dict + + # parameter tensors we want to monitor and their names are in params2name_dict + # base_optimizer is pytorch optimizer, wrapped_optimizer is a normal object with base_optimizer + def fetch_mv(self, monitor, torch_opt, params2name): + mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer + + if not self.fp16_to_fp32_param and mix_prec_opt is not None: + for fp16_group, fp32_group in zip(mix_prec_opt.float16_groups, mix_prec_opt.fp32_from_float16_groups): + for fp16_param, fp32_param in zip(fp16_group, fp32_group): + self.fp16_to_fp32_param[fp16_param] = fp32_param + return self._fetch_mv_in_adam(params2name, torch_opt, monitor) + +class MegatronDistributedOptimizerMon(MixPrecsionOptimizerMon): + def fetch_mv(self, monitor, torch_opt, params2name): + mix_prec_opt = MixPrecsionOptimizerMon.wrapped_optimizer + assert hasattr(mix_prec_opt, "model_float16_groups") and hasattr(mix_prec_opt, "shard_fp32_from_float16_groups"), \ + "megatron distributed optimizer should have model_float16_groups and shard_fp32_from_float16_groups, if not, please check megatron-lm version" + if not self.fp16_to_fp32_param and mix_prec_opt is not None: + for fp16_group, shard_fp32_group in zip(mix_prec_opt.model_float16_groups, mix_prec_opt.shard_fp32_from_float16_groups): + for fp16_param, shard_fp32_param in zip(fp16_group, shard_fp32_group): + self.fp16_to_fp32_param[fp16_param] = shard_fp32_param + + return self._fetch_mv_in_adam(params2name, torch_opt, monitor) + +class DummyOptimizerMon(MixPrecsionOptimizerMon): + def fetch_mv(self, monitor, torch_opt, params2name): + return None, None, None, None + +class OptimizerMonFactory: + @staticmethod + def create_optimizer_mon(opt_ty:str): + if opt_ty == "Megatron_Float16OptimizerWithFloat16Params": + return MixPrecsionOptimizerMon() + if opt_ty == "Megatron_DistributedOptimizer": + return MegatronDistributedOptimizerMon() + if opt_ty == None or opt_ty == "unknown": + return DummyOptimizerMon() + assert opt_ty != None, "opt_ty should be Megatron_Float16OptimizerWithFloat16Params or Megatron_DistributedOptimizer or None or unknown" \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/unittest/config_1.json b/debug/accuracy_tools/kj600/kj600/unittest/config_1.json deleted file mode 100644 index a3b10f731d10b64a8b2df703079f9c56080876eb..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/kj600/kj600/unittest/config_1.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "targets": { - "fc": {"input": "tuple[1]:0", "output": "tensor", "input_grad":"tuple[1]:0", "output_grad":"tuple[1]:0"}, - "relu": {"input": "tuple[1]:0", "output": "tensor", "input_grad":"tuple[1]:0", "output_grad":"tuple[1]:0"} - }, - "ur_distribution": true, - "mg_direction": true -} \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/unittest/test_features.py b/debug/accuracy_tools/kj600/kj600/unittest/test_features.py deleted file mode 100644 index bc8c6dd71ab4e0bf708cf3d97d02dab3a2ded9cc..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/kj600/kj600/unittest/test_features.py +++ /dev/null @@ -1,33 +0,0 @@ -import unittest -import torch -import torch.nn as nn -import torch_npu -from kj600.features import eff_rank - - -class TestFeatureCalculation(unittest.TestCase): - def test_effective_rank(self): - param = torch.randn(10, 10).npu() - rank = eff_rank(param) - self.assertTrue(rank.item() >= 1) - - def test_lambda_max(self): - pass - # input_dim = 10 - # hidden_dim = 100 - # output_dim = 1 - # num_samples = 100 - # X = torch.randn(num_samples, input_dim) - # network = nn.Sequential( - # nn.Linear(input_dim, hidden_dim), - # nn.ReLU(), - # nn.Linear(hidden_dim, output_dim) - # ) - # Y = network(X) - # Y.backward() - # for name, param in network.named_parameters(): - # lm = lambda_max(param) - - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/debug/accuracy_tools/kj600/kj600/unittest/test_module_hook.py b/debug/accuracy_tools/kj600/kj600/unittest/test_module_hook.py deleted file mode 100644 index f077fc7004dafddc0836300d5e0ffc19d1ed3d06..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/kj600/kj600/unittest/test_module_hook.py +++ /dev/null @@ -1,78 +0,0 @@ -import argparse -import torch_npu -import torch -import torch.nn.functional as F -from kj600.module_hook import TrainerMon # Modify PYTHONPATH to import TrainerMon -#from hook_api import reg_grad_hook, reg_grad_one_hook, reg_module_backward_hook, reg_module_forward_hook -#from torch.cuda.amp import GradScaler - -from torch.npu.amp import GradScaler - - -# from ptdbg_ascend import PrecisionDebugger as PD -# from monitor import GradientMonitor - -print(torch_npu.__version__) - -#debugger = PD(dump_path="./dump/", hook_name="dump", step=[1, 2, 3], enable_dataloader=False) -#debugger.configure_hook(mode="list", scope=["optim_Adam_step"], ) - -parser = argparse.ArgumentParser(prog="kj600 debug", description="kj600 sample code", epilog="") -parser.add_argument("-o", "--out_dir", type=str, default=".") -args = parser.parse_args() -DTYPE = torch.float32 - - -class Model(torch.nn.Module): - def __init__(self): - super().__init__() - self.fc = torch.nn.Linear(784, 10, dtype=DTYPE) - self.relu = torch.nn.ReLU() - - def forward(self, x): - return self.relu(self.fc(x).type(DTYPE)) - -npu = torch.device('npu:0') -net = Model().to(device=npu) - -config = { - "targets": { - "fc": {"input": "tuple[2]:0", "output": "tensor::"}, - "relu": {"input": "..", "output": ".."} - } -} -# reg_grad_hook(net, hook_factory=hook_factory, config=config) -# reg_grad_one_hook(net, hook=monitor_hook, config=config) -# net.fc.register_forward_hook(get_actv_hook("fc")) -# reg_module_forward_hook(net, module_fwd_hook, config) -# reg_module_backward_hook(net, module_bwd_hook, config) -optimizer = torch.optim.Adam(net.parameters(), lr=0.0001) - -hooker = TrainerMon('./kj600/unittest/config_1.json') -hooker.hook_modules(model=net, global_batch_size=2, dp=1, micro_batch_size=2, fwd_or_bkd=0, params_have_main_grad=False) -# hooker.hook_optimizer(optimizer) - - -class ToyDataset(torch.utils.data.Dataset): - def __init__(self): - self.data = torch.randn(16, 784, dtype=DTYPE, requires_grad=True) - self.labels = torch.randint(low=0, high=9, size=(16,)) - - def __len__(self): - return len(self.labels) - - def __getitem__(self, idx): - return self.data[idx].to(npu), self.labels[idx].to(npu) - -train_ds = ToyDataset() -train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=2) - - -# scaler = GradScaler() -for (inputs, labels) in train_loader: - optimizer.zero_grad() - outputs = net(inputs) - loss = F.cross_entropy(outputs, labels) - - loss.backward() - optimizer.step() diff --git a/debug/accuracy_tools/kj600/kj600/utils.py b/debug/accuracy_tools/kj600/kj600/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fae87693e025e775891673cab2a8873a1c103086 --- /dev/null +++ b/debug/accuracy_tools/kj600/kj600/utils.py @@ -0,0 +1,47 @@ +import os +import time +import sys + + +def _print_log(level, msg, end='\n'): + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) + pid = os.getgid() + print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) + sys.stdout.flush() + + +def print_info_log(info_msg, end='\n'): + """ + Function Description: + print info log. + Parameter: + info_msg: the info message. + """ + _print_log("INFO", info_msg, end=end) + + +def print_error_log(error_msg): + """ + Function Description: + print error log. + Parameter: + error_msg: the error message. + """ + _print_log("ERROR", error_msg) + + +def print_warn_log(warn_msg): + """ + Function Description: + print warn log. + Parameter: + warn_msg: the warning message. + """ + _print_log("WARNING", warn_msg) + +def get_param_struct(param): + if isinstance(param, tuple): + return f"tuple[{len(param)}]" + if isinstance(param, list): + return f"list[{len(param)}]" + return "tensor" diff --git "a/debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" "b/debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" new file mode 100644 index 0000000000000000000000000000000000000000..90461fa5c86a822f0b3db9b984b7598eb681259c --- /dev/null +++ "b/debug/accuracy_tools/kj600/\350\256\255\347\273\203\347\212\266\346\200\201\347\233\221\346\216\247\345\267\245\345\205\267\346\200\247\350\203\275\345\237\272\347\272\277.md" @@ -0,0 +1,52 @@ +# ptdbg_ascend精度工具标准性能基线报告 + +## 环境信息 + +NPU:Atlas A2 训练系列产品 + +CPU: + +![输入图片说明](img/cpu_info.png) + +Torch:2.1.0 + +CANN:8.0.RC2 + +除上述环境信息影响性能外,被检控的模块的数量和结构会对性能产生影响,因此本次选取典型网络进行测试,并且选取耗时稳定后的步数进行测试。工具输出键小,对内存无要求。 + +## 模型信息和性能基线 + +以下场景的性能基线测试数据均为多次测试后取平均值,因此实际运行时性能数据可能会根据环境状态稍有浮动。 + +### LLAMA2-13B + +主要数据类型:BFLOAT16 + +模型层数:40 + +配置文件(采了10层): +``` +{ + "targets": { + "language_model.encoder.layers.0": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.1": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.2": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.3": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.4": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.5": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.6": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.7": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.8": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"}, + "language_model.encoder.layers.9": {"input": "tuple[2]:0", "output": "tensor", "input_grad":"tuple[2]:0", "output_grad":"tuple[1]:0"} + }, + "module_ranks": "0" +} +``` + +启动命令参数:python3 -u pretrain_gpt.py --local-rank=1 --tensor-model-parallel-size 8 --pipeline-model-parallel-size 1 --sequence-parallel --num-layers 40 --hidden-size 5120 --ffn-hidden-size 13824 --num-attention-heads 40 --tokenizer-type Llama2Tokenizer --tokenizer-model /new_data/LLM/checkpoint_origin/llama2-13b-hf/tokenizer.model --seq-length 4096 --max-position-embeddings 4096 --micro-batch-size 2 --global-batch-size 16 --make-vocab-size-divisible-by 1 --lr 1e-6 --train-iters 5000 --lr-decay-style cosine --untie-embeddings-and-output-weights --disable-bias-linear --attention-dropout 0.0 --init-method-std 0.01 --hidden-dropout 0.0 --position-embedding-type rope --normalization RMSNorm --use-fused-rmsnorm --swiglu --use-flash-attn --no-masked-softmax-fusion --attention-softmax-in-fp32 --min-lr 1e-8 --weight-decay 1e-1 --lr-warmup-fraction 0.01 --clip-grad 1.0 --adam-beta1 0.9 --initial-loss-scale 4096 --adam-beta2 0.95 --no-gradient-accumulation-fusion --load /data/LLM/checkpoint_magatron/llama2_13b_tp1_pp8 --no-load-optim --no-load-rng --use-fused-swiglu --use-fused-rotary-pos-emb --use-mc2 --bf16 --data-path /data/LLM/data_modellink/llama2_13b/alpaca_text_document --split 949,50,1 --log-interval 1 --save-interval 10000 --eval-interval 1000 --eval-iters 10 --distributed-backend nccl --save ./ckpt + +不加工具原始耗时:**4s** + +加工具后单卡耗时:**4.25s** + +加工具后多卡耗时:**4.35s**