diff --git a/debug/accuracy_tools/kj600/kj600/features.py b/debug/accuracy_tools/kj600/kj600/features.py index 302ab3f8c550260ca1cca0fa1b67e965e3c90160..0dfd2d4be994ddd29eab42a2677de3c56fc1e442 100644 --- a/debug/accuracy_tools/kj600/kj600/features.py +++ b/debug/accuracy_tools/kj600/kj600/features.py @@ -82,4 +82,4 @@ def cal_histc(tensor_cal, bins_total, min_val, max_val): @torch.no_grad() def get_nans(t): - return torch.isnan(t).sum() + return torch.isnan(t).float().norm(p=1)