diff --git a/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py b/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py index 9c1366b7dc1b91e444186f683cd5253eb719e8ed..7cd849ac0313ef94f0be0373c049af56e6a0f8ce 100644 --- a/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py +++ b/PyTorch/built-in/cv/detection/DB_ID0706_for_PyTorch/trainer.py @@ -115,6 +115,9 @@ class Trainer: if self.experiment.amp: amp.register_float_function(torch, "sigmoid") model, optimizer = amp.initialize(model, optimizer, opt_level="O1", loss_scale=1024, combine_grad=True) + if self.experiment.amp and self.experiment.opt_level == 'O0': + amp.register_float_function(torch, "sigmoid") + model, optimizer = amp.initialize(model, optimizer, opt_level="O0", combine_grad=False) self.logger.report_time('Init') batch_time = AverageMeter() data_time = AverageMeter()