diff --git a/PyTorch/dev/cv/image_classification/CycleGAN_ID0521_for_PyTorch/train.py b/PyTorch/dev/cv/image_classification/CycleGAN_ID0521_for_PyTorch/train.py index 0539810524ada82ff1d8c2e6b63ed9dda521fef5..ad942495bba0e244ea462cee484f57f068dfb4c1 100644 --- a/PyTorch/dev/cv/image_classification/CycleGAN_ID0521_for_PyTorch/train.py +++ b/PyTorch/dev/cv/image_classification/CycleGAN_ID0521_for_PyTorch/train.py @@ -227,8 +227,8 @@ for epoch in range(0, args.epochs): break start_time = time.time() # get batch size data - real_image_A = data["A"].to(device) - real_image_B = data["B"].to(device) + real_image_A = data["A"].to(device,non_blocking=True) + real_image_B = data["B"].to(device,non_blocking=True) batch_size = real_image_A.size(0) # real data label is 1, fake data label is 0.