From 06829ae7233ca9c4ed4136e864a80c6596821140 Mon Sep 17 00:00:00 2001 From: pengxiaopeng Date: Mon, 1 Apr 2024 10:29:34 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=A2=AF=E5=BA=A6=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=B1=BB=E5=9E=8B=E4=B8=8D=E4=B8=BAtorch.float32?= =?UTF-8?q?=E6=97=B6=E4=BF=9D=E5=AD=98=E6=96=B9=E5=90=91=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/grad_tool/level_adapter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/grad_tool/level_adapter.py b/debug/accuracy_tools/grad_tool/level_adapter.py index 51e6717d94..64bb1e92a9 100644 --- a/debug/accuracy_tools/grad_tool/level_adapter.py +++ b/debug/accuracy_tools/grad_tool/level_adapter.py @@ -34,7 +34,7 @@ class LevelOps: def save_grad_direction(param_name, grad, save_path): if not os.path.exists(save_path): os.makedirs(save_path) - param_grad = torch.Tensor(grad.clone().cpu()) + param_grad = grad.clone().detach() is_positive = param_grad > 0 torch.save(is_positive, f'{save_path}/{param_name}.pt') print_info_log(f'Save {param_name} bool tensor, it has {is_positive.sum()}/{is_positive.numel()} positive elements') -- Gitee