From e2085bb782bfab698b0b58610e0f0cbfa6c99ea7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E4=BC=9F=E6=A0=B9?= <1101204667@qq.com> Date: Wed, 30 Mar 2022 08:48:13 +0000 Subject: [PATCH] update train.py. --- .../cv/image_classification/VGG16_ID0467_for_PyTorch/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorch/dev/cv/image_classification/VGG16_ID0467_for_PyTorch/train.py b/PyTorch/dev/cv/image_classification/VGG16_ID0467_for_PyTorch/train.py index 4e9f258d82..f9045e4110 100644 --- a/PyTorch/dev/cv/image_classification/VGG16_ID0467_for_PyTorch/train.py +++ b/PyTorch/dev/cv/image_classification/VGG16_ID0467_for_PyTorch/train.py @@ -62,7 +62,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri for image, target in metric_logger.log_every(data_loader, print_freq, header): start_time = time.time() #image, target = image.to(device), target.to(device) - image, target = image.to(device), target.to(torch.int).to(device) + image, target = image.to(device, non_blocking=True), target.to(torch.int).to(device, non_blocking=True) output = model(image) loss = criterion(output, target) -- Gitee