diff --git a/debug/accuracy_tools/monitor/monitor/anomaly_detect.py b/debug/accuracy_tools/monitor/monitor/anomaly_detect.py index cb8d6ee82cdc0bf8668b0531ea6ce06e071e0da0..073f220e82d6eea3ddae7501cde458fa0260091e 100644 --- a/debug/accuracy_tools/monitor/monitor/anomaly_detect.py +++ b/debug/accuracy_tools/monitor/monitor/anomaly_detect.py @@ -160,8 +160,10 @@ class BaseWriterWithAD: self.anomalies.clear() def add_scalar(self, tag, scalar_value, global_step=None): - avg = self._update_tag2scalars(tag, scalar_value) - detected, rule_name = self._ad(scalar_value, history=avg) + detected = False + if self.ad_rules: + avg = self._update_tag2scalars(tag, scalar_value) + detected, rule_name = self._ad(scalar_value, history=avg) if detected: exception_message = f"Rule {rule_name} reports anomaly signal in {tag} at step {global_step}." print_info_log(f"{bcolors.WARNING}> {exception_message}{bcolors.ENDC}") @@ -222,15 +224,15 @@ class CSVWriterWithAD(BaseWriterWithAD): new_data.append([name] + metric_value) else: new_data.append(name.split(Const.VPP_SEP) + metric_value) - new_data = pd.DataFrame(new_data) + new_data = pd.DataFrame(new_data).round(self.ndigits) new_data.to_csv(filepath, mode='a+', header=False, index=False) self.context_dict = defaultdict(list) def add_scalar(self, tag, scalar_value, global_step): super().add_scalar(tag, scalar_value, global_step) - name = tag.split('/')[0] - self.context_dict[name].append(round(scalar_value, self.ndigits)) + name = tag[0].split('/')[0] + self.context_dict[name].append(scalar_value.item()) def close(self): pass @@ -243,5 +245,6 @@ class SummaryWriterWithAD(SummaryWriter, BaseWriterWithAD): def add_scalar(self, tag, scalar_value, global_step): super(SummaryWriter, self).add_scalar(tag, scalar_value, global_step) + tag = f'{tag[0]}_{tag[1]}' return super().add_scalar(tag, scalar_value, global_step) \ No newline at end of file diff --git a/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py b/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py index 1b82c5704fc62ece14ef30f67b45b4a41433d875..a77852420b91b41027b4dc2d1de6c4a7ddc01c90 100644 --- a/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py +++ b/debug/accuracy_tools/monitor/monitor/distributed/wrap_distributed.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn import torch.distributed as dist -from ..module_metric import get_metrics +from monitor.module_metric import get_metrics, get_summary_writer_tag_name try: import torch_npu @@ -15,6 +15,8 @@ except ImportError: pass PREFIX_POST = "post" +PREFIX_PRE = "pre" +RANK = None OpsPath = os.path.join(os.path.dirname(__file__), "distributed_ops.yaml") with open(OpsPath) as f: @@ -135,26 +137,31 @@ def op_aggregate(op, tensorlist): if isinstance(tensorlist, torch.Tensor): return tensorlist if not tensorlist: - return torch.nan + return torch.tensor(torch.nan) if op == 'min': return min(tensorlist) if op == 'max': return max(tensorlist) if op == 'norm': return sum(tensorlist) - if op == 'zeros': # TODO wrong - return sum(tensorlist) / len(tensorlist) if len(tensorlist) != 0 else 0 - return torch.nan + if op == 'zeros': + return sum(tensorlist) / len(tensorlist) + if op == 'nans': + return sum(tensorlist) + if op == 'mean': + return sum(tensorlist) / len(tensorlist) + return torch.tensor(torch.nan) def update_data(old, new): - for op, tag2tensorlist in new.items(): - if op not in old: - old[op] = {} - for tag, tensor in tag2tensorlist.items(): - if tag not in old[op]: - old[op][tag] = [tensor] + for 
tag, op2tensor in new.items():
+        if tag not in old:
+            old[tag] = {}
+        for op, tensor in op2tensor.items():
+            if op not in old[tag]:
+                old[tag][op] = [tensor]
             else:
-                old[op][tag].append(tensor)
+                old[tag][op].append(tensor)
+    return old
@@ -169,29 +176,30 @@ def is_target_line(codeline):
     return False
 
 @torch.no_grad()
-def catch_data(cc_context, ops, args, prefix):
+def catch_data(cc_context, cc_name, ops, args, prefix):
     tensor_args = {}
     for arg in args:
         if isinstance(arg, torch.Tensor):
-            tensor_args[f'{prefix}_{len(tensor_args)}'] = arg
+            key = get_summary_writer_tag_name(cc_name, f'{prefix}_{len(tensor_args)}', RANK)
+            tensor_args[key] = arg
         elif isinstance(arg, list):
             if isinstance(arg[0], torch.Tensor):
                 stacked_arg = torch.stack(arg)
             elif isinstance(arg[0], dist.P2POp):
                 stacked_arg = torch.stack([op.tensor for op in arg])
-            tensor_args[f'{prefix}_{len(tensor_args)}'] = stacked_arg
+            key = get_summary_writer_tag_name(cc_name, f'{prefix}_{len(tensor_args)}', RANK)
+            tensor_args[key] = stacked_arg
 
-    new_data = {op: get_metrics(op, tensor_args, 1e-8) for op in ops}
+    new_data = get_metrics(ops, tensor_args, 1e-8)
     cc_context.data=update_data(cc_context.data, new_data)
 
-def create_async_callback_func(context, ops, args, prefix):
+def create_async_callback_func(context, cc_name, ops, args, prefix):
     def store_data():
-        catch_data(context, ops, args, prefix)
+        catch_data(context, cc_name, ops, args, prefix)
     return store_data
 
 def create_hooks(context, monitor):
-
     def cc_log_hook(module, args, kwargs):
         stack = ';'.join(get_callstack())
         monitor.cc_logged_stack[module.op_name_].add(stack)
@@ -201,7 +209,7 @@ def create_hooks(context, monitor):
         if not is_target_line(monitor.cc_codeline):
             return
         args = args + tuple(kwargs.values())
-        catch_data(context[module.op_name_], monitor.ops, args, 'pre')
+        catch_data(context[module.op_name_], module.op_name_, monitor.ops, args, PREFIX_PRE)
         return
 
     def cc_hook(module, args, kwargs, out=None):
@@ -210,17 +218,19 @@ def create_hooks(context, monitor):
         args = args + tuple(kwargs.values())
         if out: # async
             if isinstance(out, dist.Work):
-                PENDING_ASYNC_CC_BY_HANDLE[out] = create_async_callback_func(context[module.op_name_], monitor.ops, args, PREFIX_POST)
+                PENDING_ASYNC_CC_BY_HANDLE[out] = create_async_callback_func(context[module.op_name_], module.op_name_, monitor.ops, args, PREFIX_POST)
             elif isinstance(out, list): # batch_isend_irecv
                 for o in out:
-                    PENDING_ASYNC_CC_BY_HANDLE[o] = create_async_callback_func(context[module.op_name_], monitor.ops, args, PREFIX_POST)
+                    PENDING_ASYNC_CC_BY_HANDLE[o] = create_async_callback_func(context[module.op_name_], module.op_name_, monitor.ops, args, PREFIX_POST)
             return out
-        catch_data(context[module.op_name_], monitor.ops, args, PREFIX_POST)
+        catch_data(context[module.op_name_], module.op_name_, monitor.ops, args, PREFIX_POST)
         return out
 
+    global RANK
     pre_hooks = []
     hooks = []
-    if (dist.is_initialized() and dist.get_rank() not in monitor.module_rank_list and monitor.module_rank_list != []):
+    RANK = dist.get_rank() if dist.is_initialized() else None
+    if (dist.is_initialized() and RANK not in monitor.module_rank_list and monitor.module_rank_list != []):
         return [pre_hooks, hooks]
 
     if monitor.cc_log_only:
diff --git a/debug/accuracy_tools/monitor/monitor/module_hook.py b/debug/accuracy_tools/monitor/monitor/module_hook.py
index 3b3e25ea3f22c212dcb36450961e45b46591a0b8..7a5ec440499633cc51b15a29967bf1c323d6c214 100644
--- a/debug/accuracy_tools/monitor/monitor/module_hook.py
+++ b/debug/accuracy_tools/monitor/monitor/module_hook.py
@@ -20,7 +20,7 @@ from 
monitor.features import eff_rank, get_sign_matches from monitor.visualizer import HeatmapVisualizer from monitor.anomaly_detect import AnomalyScanner, AnomalyDataFactory, SummaryWriterWithAD, CSVWriterWithAD, BaseWriterWithAD from monitor.anomaly_analyse import AnomalyDataWriter -from monitor.module_metric import get_metrics, write_metrics_tensorboard, write_metrics_csv, get_summary_writer_tag_name, TensorMetrics, squash_param_name, sqrt_norm_metric, reorder_metric +from monitor.module_metric import get_metrics, write_metrics_base, write_metrics_csv, get_summary_writer_tag_name, TensorMetrics, squash_param_name from monitor.distributed.wrap_distributed import api_register, create_hooks, op_aggregate, get_process_group from monitor.utils import print_warn_log, print_info_log, print_error_log, get_param_struct, validate_config, validate_ops from monitor.const import Const @@ -78,7 +78,7 @@ class OptimizerContext: self.exp_avg_metric = [] self.param_exp_avg_sq = defaultdict() self.exp_avg_sq_metric = [] - self.metric_list = [] + self.metric_dict = {} class CommunicationContext: @@ -88,10 +88,10 @@ class CommunicationContext: @staticmethod def _agg(data): aggregated_data = {} - for op, tag2tensorlist in data.items(): - aggregated_data[op] = {} - for tag, tensorlist in tag2tensorlist.items(): - aggregated_data[op][tag] = op_aggregate(op, tensorlist) + for tag, op2tensorlist in data.items(): + aggregated_data[tag] = {} + for op, tensorlist in op2tensorlist.items(): + aggregated_data[tag][op] = op_aggregate(op, tensorlist) return aggregated_data def reset(self): @@ -102,9 +102,9 @@ class CommunicationContext: class GradContext: def __init__(self) -> None: - self.pre = [] - self.post = [] - self.acc_metric = [] + self.pre = {} + self.post = {} + self.acc_metric = {} self.acc = {} self.actv = {} @@ -198,13 +198,13 @@ class TrainerMon: if self.format == 'tensorboard': writer = SummaryWriterWithAD - self.write_metrics = write_metrics_tensorboard + self.write_metrics = write_metrics_base elif self.format == 'csv': writer = CSVWriterWithAD self.write_metrics = write_metrics_csv elif self.format == 'api': writer = BaseWriterWithAD - self.write_metrics = write_metrics_tensorboard + self.write_metrics = write_metrics_base if (rank in self.module_rank_list) or len(self.module_rank_list) == 0: @@ -234,11 +234,14 @@ class TrainerMon: self.vpp = False self.dp_group = None self.tp_group = None + self.enable_megatron = False self.param2name = defaultdict(str) self.name2index = defaultdict() self.name2indices = defaultdict() self.name2param = {} + self.duplicate_param = {} + self.name2tag = {} self.param_name_call_id = {} self.call_id = 0 self.grad_accs = [] @@ -293,16 +296,6 @@ class TrainerMon: metrics[key] = tensor return metrics - @staticmethod - def generate_cc_metrics(cc_name, cc_tensor): - metrics = defaultdict(dict) - rank = dist.get_rank() if dist.is_initialized() else None - for op, tag2tensor in cc_tensor.data.items(): - for tag, tensor in tag2tensor.items(): - key = get_summary_writer_tag_name(cc_name, tag, rank) - metrics[op].update({key: tensor}) - cc_tensor.reset() - return metrics def hook_modules(self, model:torch.nn.Module, grad_acc_steps): if self.module_rank_list and (self.rank not in self.module_rank_list): @@ -358,42 +351,33 @@ class TrainerMon: opt_context.exp_avg_metric = {} opt_context.exp_avg_sq_metric = {} m_tag_tensor_map = self.generate_param_metrics('exp_avg', opt_context.param_exp_avg) - v_tag_tensor_map = self.generate_param_metrics('exp_avg_sq', opt_context.param_exp_avg_sq) - 
for metric_name in self.ops:
-            opt_context.exp_avg_metric[metric_name] = get_metrics(metric_name, m_tag_tensor_map, self.eps)
-            opt_context.exp_avg_sq_metric[metric_name] = get_metrics(metric_name, v_tag_tensor_map, self.eps)
+        v_tag_tensor_map = self.generate_param_metrics('exp_avg_sq', opt_context.param_exp_avg_sq)
+        get_metrics(self.ops, m_tag_tensor_map, self.eps, opt_context.exp_avg_metric)
+        get_metrics(self.ops, v_tag_tensor_map, self.eps, opt_context.exp_avg_sq_metric)
 
     def generate_wgrad_metrics(self):
         if not self.wg_distribution:
             return {}, {}
 
-        unreduced = {}
         if self.weight_hooked:
-            for metric_name in self.ops:
-                unreduced[metric_name] = get_metrics(metric_name, self.grad_context.acc, self.eps)
-            self.grad_context.acc_metric = [unreduced.copy()]
-            sqrt_norm_metric(unreduced)
-            unreduced = reorder_metric(unreduced)
+            get_metrics(self.ops, self.grad_context.acc, self.eps, self.grad_context.acc_metric)
 
         grad_dict = {}
         for param, name in self.param2name.items():
-            if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group):
+            if self.duplicate_param.get(name, False):
                 continue
-            if self.dp_group and param_is_data_parallel_duplicate(self.dp_group):
-                continue
             grad = param.main_grad if self.params_have_main_grad else param.grad
             if grad is None:
                 print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
                 continue
-            key = get_summary_writer_tag_name(name, 'post_grad', self.rank)
-            grad_dict[key] = grad
+            tag = self.name2tag.get(name, {}).get('post')
+            if tag is None:
+                continue
+            grad_dict[tag] = grad
 
-        reduced = {op: get_metrics(op, grad_dict, self.eps) for op in self.ops}
-        self.grad_context.post = [reduced.copy()]
-        sqrt_norm_metric(reduced)
-        reduced = reorder_metric(reduced)
+        get_metrics(self.ops, grad_dict, self.eps, self.grad_context.post)
 
-        return reduced, unreduced
+        return self.grad_context.post, self.grad_context.acc_metric
 
     def monitor_gnorm_with_ad(self, model, grad_acc_steps=1, optimizer=None, tp_group=None, dp_group=None):
         print_info_log(f'grad acc steps {grad_acc_steps}')
@@ -404,7 +388,7 @@ class TrainerMon:
         self.tp_group = tp_group
 
         self._register_param_name(model)
-        self._hook_weights()
+        self._patch_grad_sync()
         self.hook_modules(model, grad_acc_steps)
 
     def generate_param_metrics(self, tag, param_tensor):
@@ -420,16 +404,9 @@ class TrainerMon:
     def generate_xy_metrics(self):
         actv = {}
         for fwd_context in self.module_fwd_hook_context_by_module.values():
-            for op in self.ops:
-                if op not in actv:
-                    actv[op] = {}
-                actv[op].update(fwd_context.actv[op])
-        sqrt_norm_metric(actv)
-        actv = reorder_metric(actv)
-
-        actv_grad = deepcopy(self.grad_context.actv)
-        sqrt_norm_metric(actv_grad)
-        actv_grad = reorder_metric(actv_grad)
+            actv.update(fwd_context.actv)
+
+        actv_grad = self.grad_context.actv
 
         return actv, actv_grad
 
@@ -449,26 +426,29 @@ class TrainerMon:
     def write_xy_tb(self, step):
         if not self.xy_distribution:
             return
+        actv = {}
         for _, fwd_context in self.module_fwd_hook_context_by_module.items():
             if len(fwd_context.actv) == 0:
                 continue
-            self.write_metrics(self.ops, self.summary_writer, [fwd_context.actv], step, 'actv')
+            actv.update(fwd_context.actv)
             fwd_context.actv.clear()
+
+        self.write_metrics(self.ops, self.summary_writer, actv, step, 'actv')
         if self.grad_context.actv:
-            self.write_metrics(self.ops, self.summary_writer, [self.grad_context.actv], step, 'actv_grad')
+            self.write_metrics(self.ops, self.summary_writer, self.grad_context.actv, step, 'actv_grad')
 
     def write_mv_tb(self, opt_context):
         if not self.mv_distribution:
             return
-        
self.write_metrics(self.ops, self.summary_writer, [opt_context.exp_avg_metric], opt_context.step, 'exp_avg') - self.write_metrics(self.ops, self.summary_writer, [opt_context.exp_avg_sq_metric], opt_context.step, 'exp_avg_sq') + self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_metric, opt_context.step, 'exp_avg') + self.write_metrics(self.ops, self.summary_writer, opt_context.exp_avg_sq_metric, opt_context.step, 'exp_avg_sq') def write_grad_tb(self, step): if not self.wg_distribution: return self.write_metrics(self.ops, self.summary_writer, self.grad_context.post, step, 'grad_reduced') - self.write_metrics(self.ops, self.summary_writer, self.grad_context.acc_metric, step, 'grad_unreduced') + self.write_metrics(self.ops, self.summary_writer, self.grad_context.pre, step, 'grad_unreduced') def hook_optimizer(self, optimizer=None): # in DDP by default use params_have_main_grad @@ -516,17 +496,15 @@ class TrainerMon: tbtag_tensor_map.update(self.generate_param_metrics('mg_direction', context.param_mg_direction)) metric_dict = {} - for metric_name in self.ops: - metric_dict[metric_name] = get_metrics(metric_name, tbtag_tensor_map, self.eps) - for k, c in self.cc_context.items(): - c.aggregate() - cc_metrics = self.generate_cc_metrics(k, c) - for op, m in cc_metrics.items(): - metric_dict[op].update(m) + get_metrics(self.ops, tbtag_tensor_map, self.eps, metric_dict) + for cc in self.cc_context.values(): + cc.aggregate() + metric_dict.update(cc.data) + cc.reset() if not metric_dict: return - context.metric_list.append(metric_dict) + context.metric_dict = metric_dict return def optimizer_post_step_hook(optimizer, args, kwargs): @@ -549,16 +527,14 @@ class TrainerMon: for param_name, _ in context.param_adam_ratio.items(): self.ratio_heatmap_visualizer[param_name].visualize(get_summary_writer_tag_name(param_name, 'adam_ratio', rank), context.step, self.summary_writer) - if context.metric_list: - self.write_metrics(self.ops, self.summary_writer, context.metric_list, context.step, 'other') - context.metric_list.clear() + if context.metric_dict: + self.write_metrics(self.ops, self.summary_writer, context.metric_dict, context.step, 'other') + context.metric_dict.clear() context.step += 1 - self.grad_context.reset() if self.anomaly_data_factory: self.anomaly_data_writer.write_detected_json(self.summary_writer.get_anomalies()) self.summary_writer.clear_anomalies() self.call_id = 0 - self.param_name_call_id.clear() return def patch_step(func, optimizer): @@ -619,6 +595,14 @@ class TrainerMon: self.param2name[param] = name self.name2param[name] = param self.name2index[name] = index + + if self.tp_group and not param_is_not_tensor_parallel_duplicate(param, self.tp_group): + self.duplicate_param[name] = True + if self.dp_group and param_is_data_parallel_duplicate(self.dp_group): + self.duplicate_param[name] = True + self.name2tag[name] = {} + self.name2tag[name]['pre'] = get_summary_writer_tag_name(name, 'pre_grad', self.rank) + self.name2tag[name]['post'] = get_summary_writer_tag_name(name, 'post_grad', self.rank) def _register_param_name(self, model): if self.param_registered: @@ -721,13 +705,8 @@ class TrainerMon: cared_output = module_output if context.focused_out_col is None else module_output[context.focused_out_col] tbtag_tensor_map.update(self.build_tbtag_tensor_map(f'{context.module_name}_{context.micro_step}', Const.ACTV_OUT, cared_output)) - for metric_name in self.ops: - if context.micro_step == 0 and context.actv.get(metric_name, []): - print_warn_log( - f"actv context of 
{context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.")
-                context.actv.clear()
-            context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps))
-
+        get_metrics(self.ops, tbtag_tensor_map, self.eps, context.actv)
+
         context.micro_step += 1
         if context.micro_step == self.micro_batch_number:
             context.micro_step = 0
@@ -765,10 +744,7 @@ class TrainerMon:
                 print_warn_log(f"actvgrad context of {context.module_name} is not empty when first micro_step, maybe something wrong happened. Now clear it.")
                 context.actvgrad.clear()
 
-            for metric_name in self.ops:
-                if metric_name not in self.grad_context.actv:
-                    self.grad_context.actv[metric_name] = {}
-                self.grad_context.actv[metric_name].update(get_metrics(metric_name, tbtag_tensor_map, self.eps))
+            get_metrics(self.ops, tbtag_tensor_map, self.eps, self.grad_context.actv)
 
             context.micro_step += 1
             if context.micro_step == self.micro_batch_number:
@@ -797,14 +773,43 @@ class TrainerMon:
                 hooked_count += 1
         return hooked_count
 
+    def _patch_grad_sync(self):
+        def patch_sync(sync_grad_func):
+            def wrapper(bucket):
+                grad_dict = {}
+                for param, name in self.param2name.items():
+                    if param not in bucket.params_list:
+                        continue
+                    grad = param.main_grad if self.params_have_main_grad else param.grad
+                    if grad is None:
+                        print_warn_log(f"grad is None: {name}, maybe something wrong happened.")
+                        continue
+                    tag = self.name2tag.get(name, {}).get('pre')
+                    if tag is None:
+                        continue
+                    grad_dict[tag] = grad
+                get_metrics(self.ops, grad_dict, self.eps, self.grad_context.pre)
+                out = sync_grad_func(bucket)
+                return out
+            return wrapper
+        try:
+            from megatron.core.distributed.param_and_grad_buffer import Bucket
+            self.enable_megatron = True
+        except ImportError:
+            self.enable_megatron = False
+
+        if self.enable_megatron:
+            Bucket.start_grad_sync = patch_sync(Bucket.start_grad_sync)  # the hook point differs across Megatron versions
+
     def _hook_weights(self):
         context = self.grad_context
 
         @torch.no_grad
         def param_hook(*args, context_dict, param, key, name):
             param.micro_step += 1
-            self.param_name_call_id[name] = self.call_id
-            self.call_id += 1
+            if self.anomaly_data_factory:
+                self.param_name_call_id[name] = self.call_id
+                self.call_id += 1
             if param.micro_step == self.micro_batch_number:
                 param.micro_step = 0
                 if self.params_have_main_grad:
diff --git a/debug/accuracy_tools/monitor/monitor/module_metric.py b/debug/accuracy_tools/monitor/monitor/module_metric.py
index 3c99afed187ea99dcaca25caed710091f1230d58..915a403604d493a6b5f598a970a9bb32223a684d 100644
--- a/debug/accuracy_tools/monitor/monitor/module_metric.py
+++ b/debug/accuracy_tools/monitor/monitor/module_metric.py
@@ -35,7 +35,7 @@ def register_config_metric(key, cls=None):
     if cls is None:
         # with no arguments, return a decorator function
         return lambda cls: register_config_metric(key, cls)
-    config_metric_registry[key] = cls
+    config_metric_registry[key] = cls()
     return cls
 
 class TensorMetrics:
@@ -65,22 +65,13 @@ class TensorMetrics:
 class Metric(object):
     @staticmethod
     def get_metric_value(tensor, eps):
-        pass
+        raise NotImplementedError
 
-    @staticmethod
-    def metric_tensorboard(metric_name, summary_writer, metric_value, step):
-        pass
-
-    def get_metrics(self, tag2tensor: dict, eps):
-        metrics_dict = {}
-        for tag, tensor in tag2tensor.items():
-            try:
-                metrics_dict[tag] = self.get_metric_value(tensor, eps)
-                if torch.isnan(metrics_dict[tag]):
-                    print_warn_log(f'nan when calculate metric for {tag}')
-            except RuntimeError as e:
-                metrics_dict[tag] = torch.tensor(torch.nan)
-        return metrics_dict
+    def 
get_metric(self, tensor, eps): + try: + return self.get_metric_value(tensor, eps) + except RuntimeError as e: + return torch.tensor(torch.nan).to(tensor.device) @register_config_metric("min") class MinMetric(Metric): @@ -88,12 +79,6 @@ class MinMetric(Metric): def get_metric_value(tensor, eps): return get_min(tensor) - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - for key in metric_value[0][metric_name].keys(): - min_value = min([item[metric_name][key].item() for item in metric_value]) - summary_writer.add_scalar(f'{key}_min', min_value, step) - @register_config_metric("mean") class MeanMetric(Metric): @@ -101,12 +86,6 @@ class MeanMetric(Metric): def get_metric_value(tensor, eps): return get_mean(tensor) - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - for key in metric_value[0][metric_name].keys(): - mean_value = sum([item[metric_name][key].item() for item in metric_value]) / len(metric_value) - summary_writer.add_scalar(f'{key}_mean', mean_value, step) - @register_config_metric("max") class MaxMetric(Metric): @@ -114,24 +93,13 @@ class MaxMetric(Metric): def get_metric_value(tensor, eps): return get_max(tensor) - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - for key in metric_value[0][metric_name].keys(): - max_value = max([item[metric_name][key].item() for item in metric_value]) - summary_writer.add_scalar(f'{key}_max', max_value, step) - @register_config_metric("norm") class NormMetric(Metric): @staticmethod def get_metric_value(tensor, eps): - return square_sum(tensor) + return get_norm(tensor) - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - for key in metric_value[0][metric_name].keys(): - norm_value = math.sqrt(sum([item[metric_name][key].item() for item in metric_value])) - summary_writer.add_scalar(f'{key}_norm', norm_value, step) @register_config_metric("zeros") class ZerosMetric(Metric): @@ -139,11 +107,6 @@ class ZerosMetric(Metric): def get_metric_value(tensor, eps): return get_zeros(tensor, eps) - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - for key in metric_value[0][metric_name].keys(): - zeros_value = statistics.mean([item[metric_name][key].item() for item in metric_value]) - summary_writer.add_scalar(f'{key}_zeros', zeros_value, step) @register_config_metric("nans") class NaNsMetric(Metric): @@ -151,11 +114,6 @@ class NaNsMetric(Metric): def get_metric_value(t, eps): return get_nans(t) - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, step): - for key in metric_value[0][metric_name].keys(): - nans_value = sum([v[metric_name][key].item() for v in metric_value]) - summary_writer.add_scalar(f'{key}_nans', nans_value, step) @register_config_metric("id") class IdentMetric(Metric): @@ -165,51 +123,34 @@ class IdentMetric(Metric): return None return tensor - @staticmethod - def metric_tensorboard(metric_name, summary_writer, metric_value, context): #metric_value is a dict, key is parameter name and value is a list of scalar tensor - if len(metric_value) == 1: - for key, value in metric_value[0][metric_name].items(): - if not value: - continue - summary_writer.add_scalar(f'{key}_identical', value.item(), context) - -def reorder_metric(metrics): - new_metrics = {} - for op, tag2metric in metrics.items(): - for tag, metric in tag2metric.items(): - if tag not in new_metrics: - new_metrics[tag] = {} - new_metrics[tag][op] = 
metric
-    return new_metrics
-
-def sqrt_norm_metric(metrics):
-    if 'norm' in metrics:
-        metrics["norm"] = {tag:metric**0.5 for tag, metric in metrics["norm"].items()}
-
-def get_metrics(metric_name, tag2tensor, eps):
-    try:
-        fun_metric = config_metric_registry[metric_name]
-        return fun_metric().get_metrics(tag2tensor, eps)
-    except KeyError as e:
-        raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e
+
+def get_metrics(ops, tag2tensor, eps, out_dict=None):
+    if out_dict is None:
+        out_dict = {}
+    for tag, tensor in tag2tensor.items():
+        if tag not in out_dict:
+            out_dict[tag] = {}
+        for metric_name in ops:
+            fun_metric = config_metric_registry.get(metric_name)
+            out_dict[tag][metric_name] = fun_metric.get_metric(tensor, eps)
+    return out_dict
+
 
-def write_metrics_tensorboard(ops, summary_writer, metric_value, step, prefix=''):
-    for metric_name in ops:
-        try:
-            fun_metric = config_metric_registry[metric_name]
-            fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step)
-        except KeyError as e:
-            raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e
+def write_metrics_base(ops, summary_writer, metric_value, step, prefix=''):
+    if not metric_value:
+        return
+    tensors = []
+    tags = list(itertools.product(metric_value.keys(), ops))
+    for op2tensor in metric_value.values():
+        tensors.extend(op2tensor.values())
+    with torch.no_grad():
+        metric_list = torch.stack(tensors).reshape(-1).cpu()  # keep 1-D so zip pairs correctly even with a single metric
+    for tag, metric in zip(tags, metric_list):
+        summary_writer.add_scalar(tag, metric, step)
 
 def write_metrics_csv(ops, summary_writer, metric_value, step, prefix=''):
-    for metric_name in ops:
-        try:
-            fun_metric = config_metric_registry[metric_name]
-            fun_metric.metric_tensorboard(metric_name, summary_writer, metric_value, step)
-
-        except KeyError as e:
-            raise ValueError(f"Not supported this metric, expected metric: {config_metric_registry.keys()}, actual metric: {metric_name}") from e
-
+    write_metrics_base(ops, summary_writer, metric_value, step, prefix)
+
     if not summary_writer.header:
         if prefix == 'actv':
             summary_writer.header = ['module_name']
@@ -221,7 +162,7 @@
         else:
             summary_writer.header.extend(ops)
 
-    for key in metric_value[0][ops[0]].keys():
+    for key in metric_value.keys():
         if Const.VPP_SEP in key:
             summary_writer.header.insert(0, 'vpp_stage')
             break