diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py index 302ab3f8c550260ca1cca0fa1b67e965e3c90160..d8ed4521b950c775cf6e474d1e5ae9bccf1b77f4 100644 --- a/debug/accuracy_tools/kj600/kj600/features.py +++ b/debug/accuracy_tools/kj600/kj600/features.py @@ -80,6 +80,7 @@ def lambda_max_subsample(module: torch.nn.Module, x: torch.tensor, num_iteration def cal_histc(tensor_cal, bins_total, min_val, max_val): return torch.histc(tensor_cal, bins=bins_total, min=min_val, max=max_val) + @torch.no_grad() def get_nans(t): - return torch.isnan(t).sum() + return torch.isnan(t).float().norm(p=1)