diff --git a/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py b/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py index a73edf64048cf1ab5f3db50e2bdc2b43080894d0..600ee7a85685322159fb530512523ce20e851a80 100644 --- a/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py +++ b/PyTorch/dev/cv/image_segmentation/DeepLabV3+_ID0458_for_PyTorch/main.py @@ -56,6 +56,7 @@ import torch.npu import apex from apex import amp,optimizers import sys +from torch.contrib.npu.optimized_lib.module.prefetcher import Prefetcher as Prefetcher import os NPU_CALCULATE_DEVICE = 0 @@ -397,7 +398,11 @@ def main(): model.train() cur_epochs += 1 end = time.time() - for (images, labels) in train_loader: + + data_prefetcher_stream = torch.npu.Stream() + prefetcher = Prefetcher(train_loader, stream=data_prefetcher_stream) + images,labels = prefetcher.next() + while images is not None: cur_itrs += 1 images = images.to(f'npu:{NPU_CALCULATE_DEVICE}', dtype=torch.float32, non_blocking=True) @@ -455,7 +460,8 @@ def main(): if cur_itrs >= opts.total_itrs: return - + images,labels = prefetcher.next() + class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self):