From e0f31bb978a1156391f0a41569e1359295081ddc Mon Sep 17 00:00:00 2001 From: fengjk12138 Date: Sat, 3 Aug 2024 14:20:31 +0000 Subject: [PATCH] =?UTF-8?q?Increase=20load=20state=20dict=20robustness=20?= =?UTF-8?q?=E5=A6=82=E6=9E=9C=E4=BB=8Envidia=E7=AD=89=E5=85=B6=E4=BB=96?= =?UTF-8?q?=E5=B9=B3=E5=8F=B0=E8=BF=81=E7=A7=BBcheckpoint=E8=BF=87?= =?UTF-8?q?=E6=9D=A5=E7=BB=A7=E7=BB=AD=E8=AE=AD=E7=BB=83=EF=BC=8C=E9=9C=80?= =?UTF-8?q?=E8=A6=81load=20optimizer=E7=9A=84=E7=8A=B6=E6=80=81=EF=BC=8C?= =?UTF-8?q?=E4=BD=86=E6=98=AFcuda=E5=B9=B3=E5=8F=B0=E4=B8=8A=E6=B2=A1?= =?UTF-8?q?=E6=9C=89`self.=5Fdynamic`=E8=BF=99=E4=B8=AA=E5=8F=98=E9=87=8F?= =?UTF-8?q?=E3=80=82=E9=98=B2=E6=AD=A2=E5=87=BA=E9=94=99=EF=BC=8C=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E4=B8=80=E4=B8=AAif=E8=AF=AD=E5=8F=A5=E8=BF=9B?= =?UTF-8?q?=E8=A1=8C=E5=88=A4=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: fengjk12138 --- torch_npu/npu/amp/grad_scaler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch_npu/npu/amp/grad_scaler.py b/torch_npu/npu/amp/grad_scaler.py index 831e9b66cb2..f56a39b9df1 100644 --- a/torch_npu/npu/amp/grad_scaler.py +++ b/torch_npu/npu/amp/grad_scaler.py @@ -456,7 +456,11 @@ class GradScaler(BaseGradScaler): "from a disabled instance of GradScaler." + pta_error(ErrCode.VALUE)) super(GradScaler, self).load_state_dict(state_dict) - self._dynamic = state_dict["dynamic"] + + if 'dynamic' in state_dict: + self._dynamic = state_dict["dynamic"] + else: + warnings.warn("`dynamic` variable not found in loss_scaler's state_dict, default value will be used") @staticmethod def get_npu_overflow_flag(): -- Gitee