From 7a520897c08b33f425ce3afc6487aab2fa76c7b4 Mon Sep 17 00:00:00 2001 From: Ryan Date: Thu, 31 Mar 2022 10:50:11 +0000 Subject: [PATCH] =?UTF-8?q?SSD-MobileNet=5FID1936=5Ffor=5FPyTorch=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=A8=A1=E7=B3=8A=E7=BC=96=E8=AF=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../SSD-MobileNet_ID1936_for_PyTorch/train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/PyTorch/dev/cv/image_classification/SSD-MobileNet_ID1936_for_PyTorch/train.py b/PyTorch/dev/cv/image_classification/SSD-MobileNet_ID1936_for_PyTorch/train.py index 5b78a89b3e..bc8ca196dc 100644 --- a/PyTorch/dev/cv/image_classification/SSD-MobileNet_ID1936_for_PyTorch/train.py +++ b/PyTorch/dev/cv/image_classification/SSD-MobileNet_ID1936_for_PyTorch/train.py @@ -40,6 +40,7 @@ Created on Sat Jun 10 15:45:16 2019 @author: viswanatha """ +import torch import time import torch.backends.cudnn as cudnn import torch.optim @@ -58,6 +59,8 @@ try: except ImportError: amp = None import apex +torch.npu.set_start_fuzz_compile_step(3) + NPU_CALCULATE_DEVICE = 0 if os.getenv('NPU_CALCULATE_DEVICE') and str.isdigit(os.getenv('NPU_CALCULATE_DEVICE')): NPU_CALCULATE_DEVICE = int(os.getenv('NPU_CALCULATE_DEVICE')) @@ -86,6 +89,7 @@ def train(train_loader, model, criterion, optimizer, epoch, grad_clip, args): # Batches for i, (images, boxes, labels, _) in enumerate(train_loader): + torch.npu.global_step_inc() data_time.update(time.time() - start) start_time = time.time() -- Gitee