From d2eb1ffcc53f1dae1d7d45a0f8020198f2beba27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=87=8C?= Date: Tue, 17 Jan 2023 03:19:25 +0000 Subject: [PATCH] update PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 王凌 --- PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py b/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py index 9c1366b7dc..7cd849ac03 100644 --- a/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py +++ b/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py @@ -115,6 +115,9 @@ class Trainer: if self.experiment.amp: amp.register_float_function(torch, "sigmoid") model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=1024, combine_grad=True) + if self.experiment.amp and self.experiment.opt_level == 'O0': + amp.register_float_function(torch, "sigmoid") + model, optimizer = amp.initialize(model, optimizer, opt_level="O0", combine_grad=False) self.logger.report_time('Init') batch_time = AverageMeter() data_time = AverageMeter() -- Gitee