diff --git a/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template
new file mode 100644
index 0000000000000000000000000000000000000000..7630839aa937c6d0419629b5e93c34b51b71f295
--- /dev/null
+++ b/debug/accuracy_tools/api_accuracy_checker/generate_op_script/operator_replication.template
@@ -0,0 +1,325 @@
+import json
+import os
+import math
+from enum import Enum, auto
+import torch
+try:
+    import torch_npu
+except ImportError:
+    pass
+
+
+TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"]
+TORCH_BOOL_TYPE = ["torch.bool"]
+TORCH_INT_TYPE = ["torch.uint8", "torch.int8", "torch.int16", "torch.short", "torch.int32", "torch.int",
+                  "torch.int64", "torch.long"]
+TORCH_FLOAT_TYPE = ["torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.float",
+                    "torch.float64", "torch.double"]
+TORCH_COMPLEX_TYPE = ["torch.complex32", "torch.chalf", "torch.complex64", "torch.cfloat", "torch.complex128", "torch.cdouble"]
+RAISE_PRECISION = {{
+    "torch.float16": torch.float32,
+    "torch.half": torch.float32,
+    "torch.bfloat16": torch.float32,
+    "torch.float32": torch.float64,
+    "torch.float": torch.float64
+}}
+
+
+class CompareStandard(Enum):
+    BINARY_EQUALITY_STANDARD = auto()
+    ABSOLUTE_THRESHOLD_STANDARD = auto()
+    ULP_ERROR_STANDARD = auto()
+    BENCHMARK_STANDARD = auto()
+
+
+def get_device():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    # torch_npu may be absent; guard the name so the error below is raised instead of a NameError.
+    elif "torch_npu" in globals() and torch_npu.npu.is_available():
+        device = torch.device("npu")
+    else:
+        raise Exception("Error: no NPU or GPU device is available!")
+    return device
+
+
+def generate_bool_tensor(low, high, shape):
+    low, high = int(low), int(high)
+    tensor = torch.randint(low, high + 1, shape)
+    bool_tensor = torch.gt(tensor, 0)
+    return bool_tensor
+
+
+def generate_numerical_tensor(low, high, shape, data_dtype):
+    if data_dtype in TORCH_FLOAT_TYPE:
+        scale = high - low
+        rand01 = torch.rand(shape, dtype=eval(data_dtype))
+        tensor = rand01 * scale + low
+    elif data_dtype in TORCH_INT_TYPE:
+        low, high = int(low), int(high)
+        tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype))
+    else:
+        raise NotImplementedError(f"{{data_dtype}} is not supported!")
+    if torch.numel(tensor) == 0:
+        return tensor
+    tmp_tensor = tensor.reshape(-1)
+    tmp_tensor[0] = low
+    tmp_tensor[-1] = high
+    data = tmp_tensor.reshape(shape)
+    return data
+
+
+def generate_random_tensor(info):
+    low, high = info.get('Min'), info.get('Max')
+    data_dtype = info.get('dtype')
+    shape = tuple(info.get('shape'))
+    if data_dtype == "torch.bool":
+        data = generate_bool_tensor(low, high, shape)
+    else:
+        data = generate_numerical_tensor(low, high, shape, data_dtype)
+    return data
+
+
+def generate_real_tensor(data_path):
+    data_path = os.path.realpath(data_path)
+    data = torch.load(data_path)
+    return data
+
+
+def generate_data(info):
+    data_type = info.get("type")
+    data_path = info.get("datapath")
+    if data_type in TENSOR_DATA_LIST:
+        if data_path:
+            data = generate_real_tensor(data_path)
+        else:
+            data = generate_random_tensor(info)
+    else:
+        data = info.get("value")
+    return data
+
+
+def get_input():
+{args_element_assignment}
+    args_device = [{args_list_generator_device}]
+    args_bench = [{args_list_generator_bench}]
+{kwargs_value_assignment}
+    kwargs_device = {{{kwargs_dict_generator_device}}}
+    kwargs_bench = {{{kwargs_dict_generator_bench}}}
+    return args_device, kwargs_device, args_bench, kwargs_bench
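+
+
+# Note: names wrapped in single braces throughout this file (args_element_assignment,
+# args_list_generator_device, kwargs_value_assignment, api_type, api_name,
+# compare_standard, random_seed, iter_times, ...) are presumably str.format-style
+# placeholders filled in by the op-script generator when it renders this template;
+# doubled braces escape literal braces so the generated file is valid Python.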
+
+
+def exec_api_device(args, kwargs):
+    output_device = {api_type}.{api_name}(*args, **kwargs)
+    return output_device
+
+
+def exec_api_bench(args, kwargs):
+    output_bench = {api_type}.{api_name}(*args, **kwargs)
+    return output_bench
+
+
+def compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol):
+    out_bench = out_bench.to(out_device.dtype)
+    dtype_min = torch.finfo(out_device.dtype).min
+    dtype_max = torch.finfo(out_device.dtype).max
+    bench_clip = torch.clamp(out_bench, min=dtype_min, max=dtype_max)
+    device_clip = torch.clamp(out_device, min=dtype_min, max=dtype_max)
+    clipped_abs_ae = torch.abs(device_clip - bench_clip)
+    clipped_re = clipped_abs_ae / abs_bench_with_eps
+    pass_mask = torch.less_equal(clipped_re, rtol)
+    both_nan_mask = torch.logical_and(torch.isnan(out_device), torch.isnan(bench_clip))
+    pass_mask = torch.logical_or(pass_mask, both_nan_mask)
+    not_pass_mask = torch.logical_not(pass_mask)
+    not_pass_mask = torch.logical_and(not_pass_mask, inf_nan_mask)
+    inf_nan_err_cnt = torch.sum(not_pass_mask)
+    return 0 if torch.sum(inf_nan_mask) == 0 else inf_nan_err_cnt / torch.sum(inf_nan_mask)
+
+
+def compute_rmse(abs_err, normal_value_mask):
+    if torch.sum(normal_value_mask) == 0:
+        return 0
+    else:
+        masked_ae = torch.where(normal_value_mask, abs_err, 0)
+        mse = torch.sum(torch.square(masked_ae)) / torch.sum(normal_value_mask)
+        rmse = torch.sqrt(mse)
+        return rmse
+
+
+def compute_error_balance(out_device, out_bench):
+    larger_count = torch.sum(torch.greater(out_device - out_bench.to(out_device.dtype), 0))
+    smaller_count = torch.sum(torch.less(out_device - out_bench.to(out_device.dtype), 0))
+    total_count = torch.numel(out_bench)
+    error_balance = abs(larger_count - smaller_count) / total_count
+    return error_balance
+
+
+def compare_tensor(out_device, out_bench, api_name):
+    if out_device.shape != out_bench.shape:
+        print("ERROR: shapes of out_device and out_bench are not equal!")
+        return None
+    if torch.numel(out_bench) == 0:
+        print("Both out_device and out_bench have zero elements.")
+        return None
+    print(f"shape is {{out_bench.shape}}")
+    print(f"dtype of out_device is {{out_device.dtype}}")
+    print(f"dtype of out_bench is {{out_bench.dtype}}")
+    dtype_device = out_device.dtype
+    dtype_bench = out_bench.dtype
+    if str(dtype_device) in TORCH_FLOAT_TYPE and str(dtype_bench) in TORCH_FLOAT_TYPE \
+            or str(dtype_device) in TORCH_INT_TYPE and str(dtype_bench) in TORCH_INT_TYPE \
+            or str(dtype_device) in TORCH_BOOL_TYPE and str(dtype_bench) in TORCH_BOOL_TYPE:
+        out_device = out_device.to(torch.device("cpu"))
+        if str(dtype_device) in TORCH_BOOL_TYPE or str(dtype_device) in TORCH_INT_TYPE or compare_standard == CompareStandard.BINARY_EQUALITY_STANDARD:
+            print("compare standard: binary equality standard")
+            error_number = torch.sum(out_device != out_bench).item()
+            error_rate = error_number / torch.numel(out_bench)
+            print(f"error rate is {{error_rate}}.")
+        else:
+            abs_err = torch.abs(out_device - out_bench)
+            abs_bench = torch.abs(out_bench)
+            # eps is the unit roundoff of the benchmark dtype; default to the float32
+            # value so eps is always bound even for other dtypes.
+            eps = 2 ** -23
+            if dtype_bench == torch.float64:
+                eps = 2 ** -52
+            abs_bench_with_eps = abs_bench + eps
+            rel_err = torch.abs(abs_err / abs_bench_with_eps)
+            device_finite_mask = torch.isfinite(out_device)
+            bench_finite_mask = torch.isfinite(out_bench.to(dtype_device))
+            both_finite_mask = torch.logical_and(device_finite_mask, bench_finite_mask)
+            inf_nan_mask = torch.logical_not(both_finite_mask)
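+            # The non-binary standards below share one pattern: elements are split into
+            # "small values" (|bench| <= small_value, judged by absolute error) and
+            # normal values (judged by relative error), while inf/nan elements are
+            # handled separately via inf_nan_mask.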
+            if compare_standard == CompareStandard.ABSOLUTE_THRESHOLD_STANDARD:
+                if dtype_device == torch.float16:
+                    rtol, small_value, small_value_atol = 1.0e-3, 1.0e-3, 1.0e-5
+                elif dtype_device == torch.bfloat16:
+                    rtol, small_value, small_value_atol = 4.0e-3, 1.0e-3, 1.0e-5
+                else:
+                    rtol, small_value, small_value_atol = 1.0e-6, 1.0e-6, 1.0e-9
+                small_value_mask = torch.less_equal(abs_bench, small_value)
+                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+                inf_nan_proportion = compute_inf_nan_proportion(inf_nan_mask, out_device, out_bench, abs_bench_with_eps, rtol)
+                rel_err_mask = torch.greater(rel_err, rtol)
+                rel_err_mask = torch.logical_and(rel_err_mask, normal_value_mask)
+                if torch.sum(normal_value_mask) == 0:
+                    rel_err_proportion = 0
+                else:
+                    rel_err_proportion = torch.sum(rel_err_mask) / torch.sum(normal_value_mask)
+                abs_err_mask = torch.greater(abs_err, small_value_atol)
+                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+                if torch.sum(small_value_mask) == 0:
+                    abs_err_proportion = 0
+                else:
+                    abs_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+                print("compare standard: absolute threshold standard")
+                print(f"inf/nan error proportion is {{inf_nan_proportion}}")
+                print(f"relative error ratio is {{rel_err_proportion}}")
+                print(f"absolute error ratio is {{abs_err_proportion}}")
+            elif compare_standard == CompareStandard.ULP_ERROR_STANDARD:
+                if dtype_device == torch.float16:
+                    min_eb, exponent_num = -14, 10
+                elif dtype_device == torch.bfloat16:
+                    min_eb, exponent_num = -126, 7
+                else:
+                    min_eb, exponent_num = -126, 23
+                eb = torch.where(abs_bench == 0, torch.zeros(out_bench.shape), torch.floor(torch.log2(abs_bench)))
+                eb = torch.maximum(eb, min_eb * torch.ones(out_bench.shape))
+                if dtype_device == torch.float32:
+                    ulp_err = (out_device.to(torch.float64) - out_bench).to(torch.float64) * torch.exp2(-eb + exponent_num).to(torch.float64)
+                else:
+                    ulp_err = (out_device.to(torch.float32) - out_bench).to(torch.float32) * torch.exp2(-eb + exponent_num).to(torch.float32)
+                ulp_err = torch.abs(ulp_err)
+                max_ulp_err = torch.max(ulp_err)
+                mean_ulp_err = torch.mean(ulp_err)
+                if dtype_device == torch.float32:
+                    ulp_err_proportion = torch.sum(ulp_err > 32) / torch.numel(out_bench)
+                else:
+                    ulp_err_proportion = torch.sum(ulp_err > 1) / torch.numel(out_bench)
+                print("compare standard: ulp error standard")
+                print(f"maximum ulp error is {{max_ulp_err}}")
+                print(f"mean ulp error is {{mean_ulp_err}}")
+                print(f"ulp error proportion is {{ulp_err_proportion}}")
+            else:
+                if dtype_device == torch.float16:
+                    small_value, small_value_atol = 1.0e-3, 1.0e-5
+                elif dtype_device == torch.bfloat16:
+                    small_value, small_value_atol = 1.0e-3, 1.0e-5
+                else:
+                    small_value, small_value_atol = 1.0e-6, 1.0e-9
+                small_value_mask = torch.less_equal(abs_bench, small_value)
+                small_value_mask = torch.logical_and(small_value_mask, both_finite_mask)
+                normal_value_mask = torch.logical_and(both_finite_mask, torch.logical_not(small_value_mask))
+                abs_err_mask = torch.greater(abs_err, small_value_atol)
+                abs_err_mask = torch.logical_and(abs_err_mask, small_value_mask)
+                if torch.sum(small_value_mask) == 0:
+                    small_value_err_proportion = 0
+                else:
+                    small_value_err_proportion = torch.sum(abs_err_mask) / torch.sum(small_value_mask)
+                rel_err = torch.where(normal_value_mask, rel_err, -1 * torch.ones(out_device.shape))
+                if torch.max(rel_err) >= 0:
+                    max_rel_err = torch.max(rel_err)
+                else:
+                    max_rel_err = 0
+                if torch.sum(normal_value_mask) == 0:
+                    mean_rel_err = 0
+                else:
+                    mean_rel_err = torch.sum(torch.clamp(rel_err, min=0)) / torch.sum(normal_value_mask)
+                rmse = compute_rmse(abs_err, normal_value_mask)
+                error_balance = compute_error_balance(out_device, out_bench)
+                print("compare standard: benchmark standard")
+                print(f"small value error proportion is {{small_value_err_proportion}}")
+                print(f"maximum relative error is {{max_rel_err}}")
+                print(f"mean relative error is {{mean_rel_err}}")
+                print(f"root mean squared error is {{rmse}}")
+                print(f"error balance is {{error_balance}}")
+    else:
+        print(f"ERROR: out_device dtype is {{dtype_device}}, out_bench dtype is {{dtype_bench}}, not comparable.")
+    return None
+
+
+def compare_element(out_device, out_bench, api_name):
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench are not of the same type!")
+        return None
+    if isinstance(out_bench, torch.Tensor):
+        print(f"data type: {{type(out_bench)}}")
+        compare_tensor(out_device, out_bench, api_name)
+    elif isinstance(out_bench, (bool, int, float, str)):
+        print(f"data type: {{type(out_bench)}}")
+        if out_device == out_bench:
+            print("PASS: out_device and out_bench are equal.")
+        else:
+            print("ERROR: out_device and out_bench are not equal!")
+    else:
+        print(f"ERROR: comparison of type {{type(out_bench)}} is not supported.")
+    return None
+
+
+def compare(out_device, out_bench, api_name):
+    print("Compare result:")
+    if type(out_device) != type(out_bench):
+        print("ERROR: out_device and out_bench are not of the same type!")
+        print("Compare finished.")
+        return None
+    if isinstance(out_bench, (list, tuple)):
+        print(f"data type: {{type(out_bench)}}")
+        if len(out_device) != len(out_bench):
+            print("ERROR: lengths of out_device and out_bench are different!")
+            print("Compare finished.")
+            return None
+        for index, _ in enumerate(out_bench):
+            print(f"index {{index}}:")
+            compare_element(out_device[index], out_bench[index], api_name)
+    else:
+        compare_element(out_device, out_bench, api_name)
+    print("Compare finished.")
+
+
+device = get_device()
+api_name = "{api_name}"
+compare_standard = {compare_standard}
+torch.manual_seed({random_seed})
+for i in range({iter_times}):
+    print(f"iter: {{i}}:")
+    args_device, kwargs_device, args_bench, kwargs_bench = get_input()
+    output_device = exec_api_device(args_device, kwargs_device)
+    output_bench = exec_api_bench(args_bench, kwargs_bench)
+    compare(output_device, output_bench, api_name)
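
Reviewer note: the template above is presumably rendered with str.format by the op-script generator. A minimal sketch of that rendering is below; the file names, the torch.add example, and the placeholder values are illustrative assumptions, not part of this patch:

    # render_example.py -- hypothetical driver, for illustration only
    with open("operator_replication.template") as f:
        template = f.read()

    fill = {
        "api_type": "torch",
        "api_name": "add",
        # Injected verbatim into get_input(); the generator must supply the indentation.
        "args_element_assignment": (
            '    args_info0 = {"type": "torch.Tensor", "dtype": "torch.float16",'
            ' "shape": [8, 8], "Min": -1.0, "Max": 1.0, "datapath": None}\n'
            '    arg0 = generate_data(args_info0)\n'
            '    arg1 = generate_data(args_info0)'
        ),
        # Same underlying data on both sides; the bench copy runs at raised precision.
        "args_list_generator_device": "arg0.to(device), arg1.to(device)",
        "args_list_generator_bench": 'arg0.to(RAISE_PRECISION["torch.float16"]), arg1.to(RAISE_PRECISION["torch.float16"])',
        "kwargs_value_assignment": "",
        "kwargs_dict_generator_device": "",
        "kwargs_dict_generator_bench": "",
        "compare_standard": "CompareStandard.BENCHMARK_STANDARD",
        "random_seed": 1234,
        "iter_times": 1,
    }

    with open("torch.add_replication.py", "w") as f:
        f.write(template.format(**fill))
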
diff --git a/debug/accuracy_tools/monitor/monitor/module_hook.py b/debug/accuracy_tools/monitor/monitor/module_hook.py
index 771230f8b4b48e22b5f60ab1a9973eb31bfa3956..3b3e25ea3f22c212dcb36450961e45b46591a0b8 100644
--- a/debug/accuracy_tools/monitor/monitor/module_hook.py
+++ b/debug/accuracy_tools/monitor/monitor/module_hook.py
@@ -1,8 +1,10 @@
+import inspect
 import os
 import uuid
 import json
 from collections import defaultdict
 from functools import partial
+from copy import deepcopy
 from datetime import datetime
 import torch
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
@@ -10,18 +12,17 @@
 if not torch_version_above_or_equal_2:
     raise ValueError("msmonitor require torch>=2.0")
 import torch.distributed as dist
-from torch import Stream
+from torch.utils.hooks import BackwardHook
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 from monitor.module_spec_verifier import validate_config_spec
 from monitor.optimizer_collect import OptimizerMon, print_rank_0, OptimizerMonFactory
 from monitor.features import eff_rank, get_sign_matches
 from monitor.visualizer import HeatmapVisualizer
 from monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD
-from monitor.anomaly_inform import AnomalyInformFactory
 from monitor.anomaly_analyse import AnomalyDataWriter
-from monitor.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, get_summary_writer_tag_name, TensorMetrics, squash_param_name
+from monitor.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, get_summary_writer_tag_name, TensorMetrics, squash_param_name, sqrt_norm_metric, reorder_metric
 from monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group
-from monitor.utils import print_warn_log, print_info_log, print_error_log, get_param_struct
+from monitor.utils import print_warn_log, print_info_log, print_error_log, get_param_struct, validate_config, validate_ops
 from monitor.const import Const
 from monitor.file_check import FileOpen
@@ -44,9 +45,10 @@
     def __init__(self, module_name) -> None:
         self.step = 0
         self.micro_step = 0
-        self.actv = []
+        self.actv = defaultdict(dict)
         self.actvgrad = []
         self.module_name = module_name
+        self.struct = {}
        self.format_by_arg = {}
        self.verified = False
        self.focused_in_col = 0
@@ -54,8 +56,13 @@
         self.ignore_in = False  # no need to care when no key 'input' or 'input_grad' found
 
     def set_format_by_arg(self, key_name:str, target_config:dict):
-        if key_name in target_config[self.module_name]:
-            self.format_by_arg[key_name] = target_config[self.module_name][key_name]
+        cared = target_config.get(self.module_name, self.struct)
+        if key_name in cared:
+            if isinstance(cared[key_name], dict):  # cared = self.struct
+                config = cared[key_name].get('config')
+                self.format_by_arg[key_name] = config
+            else:  # cared = target_config[self.module_name]
+                self.format_by_arg[key_name] = cared[key_name]
         elif key_name in ['input', 'input_grad']:
             self.ignore_in = True
@@ -68,7 +75,9 @@
         self.param_adam_ratio = defaultdict()
         self.param_weight_grad = defaultdict()
         self.param_exp_avg = defaultdict()
+        self.exp_avg_metric = []
         self.param_exp_avg_sq = defaultdict()
+        self.exp_avg_sq_metric = []
         self.metric_list = []
@@ -97,7 +106,7 @@
         self.post = []
         self.acc_metric = []
         self.acc = {}
-        self.actv = defaultdict(dict)
+        self.actv = {}
 
     def reset(self):
         self.pre.clear()
@@ -111,6 +120,9 @@
 class TrainerMon:
     tensor_metrics = TensorMetrics()
 
     def __init__(self, config_file_path, process_group=None, params_have_main_grad=True, opt_ty=None) -> None:
+        """
+        opt_ty: "Megatron_Float16OptimizerWithFloat16Params" or "Megatron_DistributedOptimizer"
+        """
         self.module_fwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.module_bwd_hook_context_by_module = defaultdict(ModuleHookContext)
         self.optimizer_context = defaultdict(OptimizerContext)
@@ -118,13 +130,16 @@
         self.grad_context = GradContext()
         self.process_group = get_process_group(process_group)
         self.params_have_main_grad = params_have_main_grad
+        self.opt_ty = opt_ty
         with FileOpen(config_file_path, 'r') as f:
             self.config = json.load(f)
+        validate_config(self.config)
         self.module_rank_list = self.config.get("module_ranks", [])
         self.format = self.config.get('format', 'tensorboard')
         self.eps = self.config.get('eps', 1e-8)
         self.ops = self.config.get('ops', [])
         self.ndigits = self.config.get('ndigits', 6)
+        self.all_xy = self.config.get('all_xy', False)
         self.xy_distribution = self.config.get('xy_distribution', False)
         if not self.xy_distribution:
             print_rank_0("> module input/output input_grad/output_grad is not monitored. ")
@@ -161,7 +176,6 @@
 
         alert_setting = self.config.get('alert', {"rules":[]})
         self.alert_rules = AnomalyScanner.load_rules(alert_setting["rules"])
-        anomaly_inform = AnomalyInformFactory.create_informer(**alert_setting["inform"]) if "inform" in alert_setting else None
 
         output_base_dir = os.getenv('MONITOR_OUTPUT_DIR', './monitor_output')
         cur_time = datetime.now().strftime('%b%d_%H-%M-%S')
@@ -198,7 +212,7 @@
                     tensorboard_dir,
                     self.alert_rules,
                     unique_id,
-                    anomaly_inform,
+                    None,
                     self.anomaly_data_factory,
                     self.ndigits
                 )
@@ -213,6 +227,7 @@
         self.ratio_heatmap_visualizer = defaultdict(HeatmapVisualizer)
 
         self.micro_batch_number = 1
+        self.model = None
 
         self.weight_hooked = False
         self.optimizer_hooked = False
         self.param_registered = False
@@ -221,8 +236,13 @@
         self.tp_group = None
 
         self.param2name = defaultdict(str)
+        self.name2index = defaultdict()
+        self.name2indices = defaultdict()
+        self.name2param = {}
         self.param_name_call_id = {}
         self.call_id = 0
+        self.grad_accs = []
+        self.handles = defaultdict(list)
 
         self.mix_precision_optimizer_mon = OptimizerMonFactory.create_optimizer_mon(opt_ty)
         if opt_ty is None:
@@ -235,13 +255,21 @@
         if self.print_struct:
             self.verbose = True
         self.struct_printed = False
-        self.module_struct = defaultdict(dict)
+        self.module_struct = {}
         return
 
     def __del__(self):
         if hasattr(self, "summary_writer"):
             self.summary_writer.close()
+
+    @property
+    def ops(self):
+        return self._ops
+
+    @ops.setter
+    def ops(self, value):
+        self._ops = validate_ops(value)
 
     @staticmethod
     def set_wrapped_optimizer(_wrapped_optimizer):
@@ -261,7 +289,7 @@
         metrics = {}
         rank = dist.get_rank() if dist.is_initialized() else None
         key = get_summary_writer_tag_name(module_name, tag, rank)
-        if tensor is not None:
+        if torch.is_tensor(tensor):
             metrics[key] = tensor
         return metrics
@@ -282,20 +310,59 @@
 
         if not isinstance(model, list):
             model = [model]
-
+        self.model = model
         self._register_param_name(model)
         self.micro_batch_number = grad_acc_steps
+
+        targets = self.config['targets']
+        module_in_all_stage = [key for key in targets.keys() if Const.VPP_SEP not in key]
+        for key in module_in_all_stage:
+            struct = targets.pop(key)
+            targets.update({f'{vpp_stage}{Const.VPP_SEP}{key}': struct for vpp_stage in range(len(model))})
+
+        hooked_count = 0
         for vpp_stage, model_chunk in enumerate(model):
-            vpp_stage = f'{vpp_stage}{Const.vpp_sep}' if self.vpp else ''
+            vpp_stage = f'{vpp_stage}{Const.VPP_SEP}'
             targets = [x for x, _ in model_chunk.named_modules()] if self.print_struct else self.config['targets'].keys()
-            hooked_count = self._hook_module(targets, model_chunk, vpp_stage)
-            print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
+            hooked_count += self._hook_module(targets, model_chunk, vpp_stage)
+
+        print_rank_0(f"> {hooked_count} out of {len(self.config['targets'])} are monitored.")
+
+        def clone_if_tensor(args):
+            if isinstance(args, tuple):
+                return tuple([clone_if_tensor(arg) for arg in args])
+            elif isinstance(args, torch.Tensor):
+                return args.clone()
+            else:
+                return args
+
+        @torch.no_grad
+        def wrap_hook_setup(setup):
+            def wrapped_setup(*args, **kwargs):
+                args = setup(*args, **kwargs)
+                args = clone_if_tensor(args)
+                return args
+
+            return wrapped_setup
+
+        BackwardHook.setup_output_hook = wrap_hook_setup(BackwardHook.setup_output_hook)
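+        # NOTE: wrapping BackwardHook.setup_output_hook so that every module output is
+        # cloned is presumably a workaround for full backward hooks misbehaving when an
+        # output tensor is later modified in place; the clone decouples the saved
+        # output from such mutations.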
 
         if not self.optimizer_hooked:
             self.hook_optimizer()
         return
 
+    def generate_mv_metrics(self, opt_context):
+        if not self.mv_distribution:
+            return
+        opt_context.exp_avg_metric = {}
+        opt_context.exp_avg_sq_metric = {}
+        m_tag_tensor_map = self.generate_param_metrics('exp_avg', opt_context.param_exp_avg)
+        v_tag_tensor_map = self.generate_param_metrics('exp_avg_sq', opt_context.param_exp_avg_sq)
+        for metric_name in self.ops:
+            opt_context.exp_avg_metric[metric_name] = get_metrics(metric_name, m_tag_tensor_map, self.eps)
+            opt_context.exp_avg_sq_metric[metric_name] = get_metrics(metric_name, v_tag_tensor_map, self.eps)
+
     def generate_wgrad_metrics(self):
         if not self.wg_distribution:
             return {}, {}
@@ -304,7 +371,9 @@
         if self.weight_hooked:
             for metric_name in self.ops:
                 unreduced[metric_name] = get_metrics(metric_name, self.grad_context.acc, self.eps)
-            self.grad_context.acc_metric = [unreduced]
+            self.grad_context.acc_metric = [unreduced.copy()]
+            sqrt_norm_metric(unreduced)
+            unreduced = reorder_metric(unreduced)
 
         grad_dict = {}
         for param, name in self.param2name.items():
@@ -317,19 +386,19 @@
                 print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
                 continue
             key = get_summary_writer_tag_name(name, 'post_grad', self.rank)
-            grad_dict[key] = grad
+            grad_dict[key] = grad
 
-        reduced = {op:get_metrics(op, grad_dict, self.eps) for op in self.ops}
-        self.grad_context.post = [reduced]
+        reduced = {op: get_metrics(op, grad_dict, self.eps) for op in self.ops}
+        self.grad_context.post = [reduced.copy()]
+        sqrt_norm_metric(reduced)
+        reduced = reorder_metric(reduced)
         return reduced, unreduced
 
-
     def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
         print_info_log(f'grad acc steps {grad_acc_steps}')
         self.hook_optimizer(optimizer)
         self.micro_batch_number = grad_acc_steps
-        self.backward_only = True
 
         self.dp_group = dp_group
         self.tp_group = tp_group
@@ -341,13 +410,39 @@
     def generate_param_metrics(self, tag, param_tensor):
         metrics = {}
         rank = dist.get_rank() if dist.is_initialized() else None
-        for _, name in self.param2name.items():
+        for name in self.param2name.values():
             key = get_summary_writer_tag_name(name, tag, rank)
             if name not in param_tensor or param_tensor[name] is None:
                 continue
             metrics[key] = param_tensor[name]
         return metrics
 
+    def generate_xy_metrics(self):
+        actv = {}
+        for fwd_context in self.module_fwd_hook_context_by_module.values():
+            for op in self.ops:
+                if op not in actv:
+                    actv[op] = {}
+                actv[op].update(fwd_context.actv[op])
+        sqrt_norm_metric(actv)
+        actv = reorder_metric(actv)
+
+        actv_grad = deepcopy(self.grad_context.actv)
+        sqrt_norm_metric(actv_grad)
+        actv_grad = reorder_metric(actv_grad)
+
+        return actv, actv_grad
+
+    def reload_xy(self, xy_distribution=False):
+        self.xy_distribution = xy_distribution
+
+        for handle in self.handles['xy']:
+            handle.remove()
+        self.handles['xy'].clear()
+        self.hook_modules(self.model, self.micro_batch_number)
+        for _, fwd_context in self.module_fwd_hook_context_by_module.items():
+            fwd_context.actv.clear()
+
     def write_adhoc_check(self, step):
         TrainerMon.tensor_metrics.flush(self.summary_writer)
 
@@ -357,12 +452,16 @@
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            if not len(fwd_context.actv) == self.micro_batch_number:
-                print_warn_log(f"fwd_context.actv not equal to micro_batch_number: {len(fwd_context.actv)}, {self.micro_batch_number}")
-            self.write_metrics(self.ops, self.summary_writer, fwd_context.actv, step, 'actv')
+            self.write_metrics(self.ops, self.summary_writer, [fwd_context.actv], step, 'actv')
             fwd_context.actv.clear()
+        if self.grad_context.actv:
+            self.write_metrics(self.ops, self.summary_writer, [self.grad_context.actv], step, 'actv_grad')
 
-        self.write_metrics(self.ops, self.summary_writer, [self.grad_context.actv], step, 'grad_actv')
+    def write_mv_tb(self, opt_context):
+        if not self.mv_distribution:
+            return
+        self.write_metrics(self.ops, self.summary_writer, [opt_context.exp_avg_metric], opt_context.step, 'exp_avg')
+        self.write_metrics(self.ops, self.summary_writer, [opt_context.exp_avg_sq_metric], opt_context.step, 'exp_avg_sq')
 
     def write_grad_tb(self, step):
         if not self.wg_distribution:
@@ -375,6 +474,20 @@
         # in DDP by default use params_have_main_grad
         def optimizer_pre_step_hook(optimizer, args, kwargs):
             context = self.optimizer_context[optimizer]
+            if self.opt_ty in Const.DEEPSPEED_OPT_TY:
+                if context.step == 0:
+                    return
+                elif context.step == 1:
+                    self.name2indices = self.mix_precision_optimizer_mon.get_param_index(self.param2name, self.name2index)
+                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name, self.name2indices)
+                self.param2name = mv_result.grad
+            else:
+                mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
+            context.param_exp_avg = mv_result.exp_avg
+            context.param_exp_avg_sq = mv_result.exp_avg_sq
+            context.param_adam_update = mv_result.update
+            context.param_adam_ratio = mv_result.ratio
+
             if self.print_struct and not all(value == {} for value in self.module_struct.values()) and not self.struct_printed:
                 self._smallest_rank_print("> module struct:")
                 self._smallest_rank_print(json.dumps(self.module_struct, indent=4))
@@ -384,37 +497,24 @@
                 self._smallest_rank_print("> Used communication ops and corresponding stack")
                 self._smallest_rank_print(json.dumps({k:[i.split(';') for i in v] for k,v in self.cc_logged_stack.items()}, indent=4))
                 raise Exception("exit after first step when print cc stack")
-
-
-            self.generate_wgrad_metrics()
-
-            mv_result = self.mix_precision_optimizer_mon.fetch_mv(self, optimizer, self.param2name)
-            context.param_exp_avg = mv_result.exp_avg
-            context.param_exp_avg_sq = mv_result.exp_avg_sq
-            context.param_adam_update = mv_result.update
-            context.param_adam_ratio = mv_result.ratio
-
-            for param, name in self.param2name.items():
-                if "params_effrank" in self.config and name in self.config["params_effrank"]:
-                    context.param_effective_rank[name] = eff_rank(param.detach())
-                grad = param.main_grad if self.params_have_main_grad else param.grad
-                if grad is None:
-                    print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
-                    continue
+            self.generate_wgrad_metrics()
+            self.generate_mv_metrics(context)
 
-                if self.mg_direction:
+            tbtag_tensor_map = {}
+            if self.mg_direction:
+                for param, name in self.param2name.items():
+                    grad = param.main_grad if self.params_have_main_grad else param.grad
+                    if grad is None:
+                        print_warn_log(f"grad is None for {name}; maybe something went wrong.")
+                        continue
                     if context.step == 0:
                         same_direction_ratio = torch.tensor(1.)
                 else:
+                    same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name])
+                context.param_mg_direction[name] = same_direction_ratio
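+                    # NOTE: same_direction_ratio is presumably the fraction of elements whose
+                    # sign agrees between the current grad and exp_avg (momentum); values near
+                    # 1 indicate a stable update direction across steps.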
else: same_direction_ratio = get_sign_matches(grad, context.param_exp_avg[name]) context.param_mg_direction[name] = same_direction_ratio - - tbtag_tensor_map = {} - if self.mv_distribution: - tbtag_tensor_map.update(self.generate_param_metrics('exp_avg', context.param_exp_avg)) - tbtag_tensor_map.update(self.generate_param_metrics('exp_avg_sq', context.param_exp_avg_sq)) - if self.mg_direction: tbtag_tensor_map.update(self.generate_param_metrics('mg_direction', context.param_mg_direction)) + metric_dict = {} for metric_name in self.ops: metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) @@ -431,12 +531,16 @@ class TrainerMon: def optimizer_post_step_hook(optimizer, args, kwargs): context = self.optimizer_context[optimizer] + if (self.opt_ty in Const.DEEPSPEED_OPT_TY and context.step == 0): + context.step += 1 + return rank = dist.get_rank() if dist.is_initialized() else None if self.anomaly_data_factory: self.anomaly_data_factory.set_call_id(self.param_name_call_id) self.write_xy_tb(context.step) self.write_grad_tb(context.step) + self.write_mv_tb(context) self.write_adhoc_check(context.step) if self.ur_distribution: @@ -491,70 +595,139 @@ class TrainerMon: print_info_log(msg) else: print_info_log(msg) + + def _is_target_param(self, param_name, param, prefix): + squash_name = prefix + squash_param_name(param_name) + name = prefix + param_name + for target in self.config['targets'].keys(): + if param_name.startswith(target) or squash_name.startswith(target) or name.startswith(target): + setattr(param, "zero_out_wgrad", True) + return True + + return False + def _register_chunk(self, model_chunk, prefix): + for index, (param_name, param) in enumerate(model_chunk.named_parameters()): + if not param.requires_grad: + continue + if self._is_target_param(param_name, param, prefix): + name = prefix + squash_param_name(param_name) + if name in self.param2name.values(): + print_error_log(f'same name {name} for different param. Current param is {param_name}. \ + May be error of squash_param_name') + raise Exception("param with same name will be overwritten.") + self.param2name[param] = name + self.name2param[name] = param + self.name2index[name] = index + def _register_param_name(self, model): if self.param_registered: return + if not isinstance(model, list): model = [model] if len(model) > 1: self.vpp = True self._smallest_rank_print('vpp enabled') - - for vpp_stage, model_chunk in enumerate(model): - prefix = f'{Const.vpp}{vpp_stage}{Const.vpp_sep}' if self.vpp else '' - for param_name, param in model_chunk.named_parameters(): - name = prefix + squash_param_name(param_name) - for target in self.config['targets'].keys(): - if param_name.startswith(target) and param.requires_grad: - self._smallest_rank_print(f'>> monitoring: {name}') - setattr(param, "zero_out_wgrad", True) - if name in self.param2name.values() or name == '': - print_error_log(f'same name {name} for different param. Current param is {param_name}. 
-                                May be error of squash_param_name')
-                            raise Exception("param with same name will be overwriten.")
-                        self.param2name[param] = name
-                        break
+        for vpp_stage, model_chunk in enumerate(model):
+            prefix = f'{vpp_stage}{Const.VPP_SEP}'
+            self._register_chunk(model_chunk, prefix)
+
         self.param_registered = True
 
+    def _is_target_module(self, module_name, targets, vpp_stage):
+        if self.all_xy or self.print_struct:
+            return vpp_stage + squash_param_name(module_name)
+        for pattern in [
+            vpp_stage + squash_param_name(module_name),
+            vpp_stage + module_name,
+        ]:
+            if pattern in targets:
+                return pattern
+        return ""
+
     def _hook_module(self, target_names, module: torch.nn.Module, vpp_stage=''):
         if '_modules' not in module.__dict__:  # nothing to hook
             return 0
 
+        def _is_recomputation():
+            """Check whether the current operation is executed in the recomputation phase.
+
+            This function inspects the call stack to decide whether the current operation is part of
+            activation recomputation. A blacklist mechanism is used; the Megatron and MindSpeed
+            frameworks are currently supported:
+            megatron: the 'backward' function is called from the 'torch/autograd/function.py' file.
+            mindspeed: the 'checkpoint_function_backward' function is called from the
+            'torch/autograd/function.py' file, or a custom module (using CheckpointWithoutOutput)
+            executes its 'backward' function within the 'torch/_tensor.py' file.
+
+            Returns:
+                bool: True if in the recomputation phase, False otherwise.
+            """
+            backward_function_indices = []
+            call_stack = inspect.stack()
+
+            # Identify whether the function 'backward' is being executed within the 'torch/_tensor.py' file.
+            for frame_info in call_stack:
+                if frame_info.function == 'backward' and frame_info.filename.endswith('torch/_tensor.py'):
+                    del call_stack
+                    return True
+
+            # Identify indices in the call stack where the specific function is being executed
+            for idx, frame_info in enumerate(call_stack):
+                if frame_info.function == 'backward' or frame_info.function == 'checkpoint_function_backward':
+                    backward_function_indices.append(idx)
+
+            # Check whether the execution is within the 'torch/autograd/function.py' file
+            for idx in backward_function_indices:
+                if idx + 1 < len(call_stack) and call_stack[idx + 1].filename.endswith('torch/autograd/function.py'):
+                    del call_stack
+                    return True
+
+            del call_stack
+            return False
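+
+        # NOTE: inspect.stack() walks every frame and is comparatively expensive, and it runs
+        # on each monitored forward call; call_stack is deleted explicitly above to avoid
+        # keeping frame references (and their locals) alive longer than necessary.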
+
         def fwd_hook_fun(module, module_input, module_output):
+            if _is_recomputation():
+                return
             context: ModuleHookContext = self.module_fwd_hook_context_by_module[module]
+            if not context.struct:
+                context.struct = {Const.ACTV_IN: get_param_struct(module_input), Const.ACTV_OUT: get_param_struct(module_output)}
             if self.print_struct:
-                self.module_struct[context.module_name].update(
-                    {"input": f"{get_param_struct(module_input)}", "output": f"{get_param_struct(module_output)}"})
+                if context.module_name not in self.module_struct:
+                    self.module_struct[context.module_name] = {}
+                self.module_struct[context.module_name].update(context.struct)
                 return
+            if not module.training:
+                return
+            if not context.format_by_arg:
+                context.set_format_by_arg(Const.ACTV_IN, self.config['targets'])
+                context.set_format_by_arg(Const.ACTV_OUT, self.config['targets'])
             if not context.format_by_arg:
-                context.set_format_by_arg('input', self.config['targets'])
-                context.set_format_by_arg('output', self.config['targets'])
+                return
             if not context.verified:
                 if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg['input'], module_input, context.module_name, 'input')
-                context.focused_out_col = validate_config_spec(context.format_by_arg['output'], module_output, context.module_name, 'output')
+                    context.focused_in_col = validate_config_spec(context.format_by_arg[Const.ACTV_IN], module_input, context.module_name, Const.ACTV_IN)
+                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.ACTV_OUT], module_output, context.module_name, Const.ACTV_OUT)
                 context.verified = True
             # expect output be tensor type
             tbtag_tensor_map = {}
             if not context.ignore_in:
                 cared_input = module_input if context.focused_in_col is None else module_input[context.focused_in_col]
-                tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'input', cared_input))
+                tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTV_IN, cared_input))
             cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col]
-            tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name, 'output', cared_output))
-            metric_dict = {}
-            for metric_name in self.ops:
-                metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps)
-            if context.micro_step == 0 and context.actv:
-                print_warn_log(
-                    f"actv context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.")
-                context.actv.clear()
-            context.actv.append(metric_dict)
+            tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTV_OUT, cared_output))
+            for metric_name in self.ops:
+                if context.micro_step == 0 and context.actv.get(metric_name, []):
+                    print_warn_log(
+                        f"actv context of {context.module_name} is not empty at the first micro_step; maybe something went wrong. Clearing it now.")
+                    context.actv.clear()
+                context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps))
+
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
@@ -563,35 +736,40 @@
         def bwd_hook_fun(module, input_grad, output_grad):
             context: ModuleHookContext = self.module_bwd_hook_context_by_module[module]
+            if not context.struct:
+                context.struct = {Const.ACTVGRAD_IN: get_param_struct(input_grad), Const.ACTVGRAD_OUT: get_param_struct(output_grad)}
             if self.print_struct:
-                self.module_struct[context.module_name].update(
-                    {"input_grad": f"{get_param_struct(input_grad)}", "output_grad": f"{get_param_struct(output_grad)}"})
+                if context.module_name not in self.module_struct:
+                    self.module_struct[context.module_name] = {}
+                self.module_struct[context.module_name].update(context.struct)
                 return
             if not context.format_by_arg:
-                context.set_format_by_arg('input_grad', self.config['targets'])
-                context.set_format_by_arg('output_grad', self.config['targets'])
+                context.set_format_by_arg(Const.ACTVGRAD_IN, self.config['targets'])
+                context.set_format_by_arg(Const.ACTVGRAD_OUT, self.config['targets'])
             if not context.format_by_arg:
                 return
             if not context.verified:
                 if not context.ignore_in:
-                    context.focused_in_col = validate_config_spec(context.format_by_arg['input_grad'], input_grad, context.module_name, 'input_grad')
-                context.focused_out_col = validate_config_spec(context.format_by_arg['output_grad'], output_grad, context.module_name, 'output_grad')
+                    context.focused_in_col = validate_config_spec(context.format_by_arg[Const.ACTVGRAD_IN], input_grad, context.module_name, Const.ACTVGRAD_IN)
+                context.focused_out_col = validate_config_spec(context.format_by_arg[Const.ACTVGRAD_OUT], output_grad, context.module_name, Const.ACTVGRAD_OUT)
                 context.verified = True
 
             tbtag_tensor_map = {}
             if not context.ignore_in:
                 cared_input_grad = input_grad if context.focused_in_col is None else input_grad[context.focused_in_col]
-                tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name+f'_{context.micro_step}', f'input_grad', cared_input_grad))
+                tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTVGRAD_IN, cared_input_grad))
             cared_output_grad = output_grad if context.focused_out_col is None else output_grad[context.focused_out_col]
-            tbtag_tensor_map.update(self.build_tbtag_tensor_map(context.module_name+f'_{context.micro_step}', f'output_grad', cared_output_grad))
+            tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTVGRAD_OUT, cared_output_grad))
 
             if context.micro_step == 0 and context.actvgrad:
                 print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.")
                 context.actvgrad.clear()
 
             for metric_name in self.ops:
+                if metric_name not in self.grad_context.actv:
+                    self.grad_context.actv[metric_name] = {}
                 self.grad_context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps))
-
+
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
                 context.micro_step = 0
@@ -604,17 +782,19 @@
         hooked_count = 0
         if self.xy_distribution or self.print_struct:
             for module_name, submodule in module.named_modules():
-                name = vpp_stage + module_name
-                self.module_struct[name] = {}
-                if name in target_names or module_name in target_names:
-                    if not self.backward_only:
-                        submodule.register_forward_hook(fwd_hook_fun)
-                        self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name)
-                    if not self.forward_only:
-                        submodule.register_full_backward_hook(bwd_hook_fun)
-                        self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
-                    print_rank_0(f"> {name} is monitored successfully")
-                    hooked_count += 1
+                name = self._is_target_module(module_name, target_names, vpp_stage)
+                if not name:
+                    continue
+                if not self.backward_only:
+                    handle = submodule.register_forward_hook(fwd_hook_fun)
+                    self.handles['xy'].append(handle)
+                    self.module_fwd_hook_context_by_module[submodule] = ModuleHookContext(name)
+                if not self.forward_only:
+                    handle = submodule.register_full_backward_hook(bwd_hook_fun)
+                    self.handles['xy'].append(handle)
+                    self.module_bwd_hook_context_by_module[submodule] = ModuleHookContext(name)
+                print_rank_0(f"> {name} is monitored successfully")
+                hooked_count += 1
         return hooked_count
 
     def _hook_weights(self):
@@ -637,6 +817,8 @@
                 setattr(param, 'micro_step', 0)
                 param_tmp = param.expand_as(param)
                 grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
+                handle = grad_acc.register_hook(partial(param_hook, context_dict=context.acc, param=param, key=key, name=name))
+                self.grad_accs.append(grad_acc)
+                self.handles['wgrads'].append(handle)
         self.weight_hooked = True
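
Reviewer note: taken together, the new config reads above (validate_config, the validate_ops-backed ops setter, all_xy, and per-vpp-stage targets) define the shape of the monitor config file. A minimal sketch of such a config follows; the module name and spec strings are illustrative assumptions, not values taken from this patch:

    import json

    # Hypothetical monitor config for TrainerMon(config_file_path).
    # Keys mirror the self.config.get(...) reads in module_hook.py; values are examples only.
    config = {
        "targets": {
            # Keys without the vpp separator are expanded per vpp stage in hook_modules;
            # the spec strings are consumed by validate_config_spec.
            "language_model.encoder.layers.0": {"input": "tuple[2]:0", "output": "tensor"}
        },
        "module_ranks": [0],
        "format": "tensorboard",        # or "csv"
        "ops": ["min", "max", "norm"],  # checked by validate_ops via the new ops property setter
        "eps": 1e-8,
        "ndigits": 6,
        "xy_distribution": True,
        "all_xy": False,                # True hooks every named module regardless of targets
        "alert": {"rules": []}
    }

    with open("monitor_config.json", "w") as f:
        json.dump(config, f, indent=4)
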