diff --git a/torch_npu/npu/amp/grad_scaler.py b/torch_npu/npu/amp/grad_scaler.py index 831e9b66cb2e0a74fafab1ef1107656e9ae8239c..f56a39b9df154a9c4d0b50059146813bbf337541 100644 --- a/torch_npu/npu/amp/grad_scaler.py +++ b/torch_npu/npu/amp/grad_scaler.py @@ -456,7 +456,11 @@ class GradScaler(BaseGradScaler): "from a disabled instance of GradScaler." + pta_error(ErrCode.VALUE)) super(GradScaler, self).load_state_dict(state_dict) - self._dynamic = state_dict["dynamic"] + + if 'dynamic' in state_dict: + self._dynamic = state_dict["dynamic"] + else: + warnings.warn("`dynamic` variable not found in loss_scaler's state_dict, default value will be used") @staticmethod def get_npu_overflow_flag():