diff --git a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py
index b38bdd8a852b40659a960b4f57424f7f25cc377f..d631020f0dd1649a9288de4092ea2840e128cf87 100644
--- a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py
+++ b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py
@@ -341,61 +341,69 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
     steps_per_epoch = train_loader_len
     print('==========step per epoch======================', steps_per_epoch)
     for i, (images, target) in enumerate(train_loader):
-        if i > 200 :
-            pass
-        # measure data loading time
-        data_time.update(time.time() - end)
-
-        global_step = epoch * steps_per_epoch + i
-        lr = adjust_learning_rate(optimizer, global_step, steps_per_epoch, args)
-
-        target = target.to(torch.int32)
-        images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std)
-        target = target.to(loc, non_blocking=True)
-
-        # compute output
-        output = model(images)
-        stream = torch.npu.current_stream()
-        stream.synchronize()
-
-        loss = criterion(output, target)
-        stream = torch.npu.current_stream()
-        stream.synchronize()
-
-        # measure accuracy and record loss
-        acc1, acc5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss.item(), images.size(0))
-        top1.update(acc1[0], images.size(0))
-        top5.update(acc5[0], images.size(0))
-
-        # compute gradient and do SGD step
-        if args.benchmark == 0:
-            optimizer.zero_grad()
-
-        if args.amp:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            loss.backward()
-
-        stream = torch.npu.current_stream()
-        stream.synchronize()
-
-        if args.benchmark == 0:
-            optimizer.step()
-        elif args.benchmark == 1:
-            batch_size_multiplier = int(OPTIMIZER_BATCH_SIZE / args.batch_size)
-            bm_optimizer_step = ((i + 1) % batch_size_multiplier) == 0
-            if bm_optimizer_step:
-                for param_group in optimizer.param_groups:
-                    for param in param_group['params']:
-                        param.grad /= batch_size_multiplier
-                optimizer.step()
+        #with torch.autograd.profiler.profile(use_npu=False) as prof:
+        if True:
+            torch.npu.enable_graph_mode()
+            if i > 200 :
+                pass
+            # measure data loading time
+            data_time.update(time.time() - end)
+
+            global_step = epoch * steps_per_epoch + i
+            lr = adjust_learning_rate(optimizer, global_step, steps_per_epoch, args)
+
+            images = images.to(loc, non_blocking=True)
+            target = target.to(loc, non_blocking=True)
+            images = images.to(torch.float).sub(mean).div(std)
+            target = target.to(torch.int32)
+
+            # compute output
+            output = model(images)
+            # stream = torch.npu.current_stream()
+            # stream.synchronize()
+
+            loss = criterion(output, target)
+            # stream = torch.npu.current_stream()
+            # stream.synchronize()
+
+            # measure accuracy and record loss
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            # losses.update(loss.item(), images.size(0))
+            # top1.update(acc1[0], images.size(0))
+            # top5.update(acc5[0], images.size(0))
+
+            # compute gradient and do SGD step
+            if args.benchmark == 0:
                 optimizer.zero_grad()
-        stream = torch.npu.current_stream()
-        stream.synchronize()
-        # measure elapsed time
+
+            if args.amp:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+            else:
+                loss.backward()
+
+            # stream = torch.npu.current_stream()
+            # stream.synchronize()
+
+            if args.benchmark == 0:
+                optimizer.step()
+            elif args.benchmark == 1:
+                batch_size_multiplier = int(OPTIMIZER_BATCH_SIZE / args.batch_size)
+                bm_optimizer_step = ((i + 1) % batch_size_multiplier) == 0
+                if bm_optimizer_step:
+                    for param_group in optimizer.param_groups:
+                        for param in param_group['params']:
+                            param.grad /= batch_size_multiplier
+                    optimizer.step()
+                    optimizer.zero_grad()
+            # stream = torch.npu.current_stream()
+            # stream.synchronize()
+            torch.npu.launch_graph()
+            if i == 200:
+                torch.npu.synchronize()
+        # measure elapsed time
+        #prof.export_chrome_trace('./npu_profile_%d.json'%i)
+
         batch_time.update(time.time() - end)
         end = time.time()
@@ -403,7 +411,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                     and args.rank % ngpus_per_node == 0):
             progress.display(i)
-
+    torch.npu.disable_graph_mode()
     if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                 and args.rank % ngpus_per_node == 0):
         print("[npu id:", args.gpu, "]", '* FPS@all {:.3f}'.format(ngpus_per_node * args.batch_size / batch_time.avg))
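Note on the MobileNetV2 hunks above: the whole training step is wrapped in enable_graph_mode()/launch_graph(), the per-op stream synchronizations are commented out, and the host-side metric updates (loss.item(), top-k bookkeeping) are disabled, since any device-to-host read would force a synchronize and defeat the deferred launch. A minimal sketch of the same pattern, assuming an Ascend build of PyTorch that exposes torch.npu with the enable_graph_mode/launch_graph/disable_graph_mode API used in this patch; model, criterion, and loader are placeholders:

import torch

def train_one_epoch(train_loader, model, criterion, optimizer, loc='npu:0'):
    # Assumes torch.npu is provided by the Ascend adapter, as in the patch.
    for i, (images, target) in enumerate(train_loader):
        torch.npu.enable_graph_mode()      # queue this step's ops into a graph
        images = images.to(loc, non_blocking=True).to(torch.float)
        target = target.to(loc, non_blocking=True).to(torch.int32)

        output = model(images)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        torch.npu.launch_graph()           # replay the captured step on the NPU
        if i == 200:
            torch.npu.synchronize()        # drain once so timings reflect device work, not launch overhead
    torch.npu.disable_graph_mode()         # back to eager before validation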
diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py
index 7886ada80dc0c113df4dd2c5b2dba3f34d5fb809..d05fe4ed88ab2754b3b154db9f65547c83d620f1 100644
--- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py
+++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py
@@ -682,6 +682,10 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
     if args.benchmark == 1 :
         optimizer.zero_grad()
     for i, (images, target) in enumerate(train_loader):
+        # graph mode
+        if args.graph_mode:
+            print("args.graph_mode")
+            torch.npu.enable_graph_mode()
         # measure data loading time
         data_time.update(time.time() - end)
 
@@ -689,8 +693,15 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
 
         if args.device == 'npu':
             loc = 'npu:{}'.format(args.gpu)
-            images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std)
-            target = target.to(torch.int32).to(loc, non_blocking=True)
+            # graph mode
+            if args.graph_mode:
+                images = images.to(loc, non_blocking=True)
+                target = target.to(loc, non_blocking=True)
+                images = images.to(torch.float).sub(mean).div(std)
+                target = target.to(torch.int32)
+            else:
+                images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std)
+                target = target.to(torch.int32).to(loc, non_blocking=True)
         else:
             images = images.cuda(args.gpu, non_blocking=True)
             target = target.cuda(args.gpu, non_blocking=True)
@@ -701,10 +712,12 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
         loss = criterion(output, target)
 
         # measure accuracy and record loss
-        acc1, acc5 = accuracy(output, target, topk=(1, 5))
-        losses.update(loss.item(), images.size(0))
-        top1.update(acc1[0], images.size(0))
-        top5.update(acc5[0], images.size(0))
+        # graph mode
+        if not args.graph_mode:
+            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            losses.update(loss.item(), images.size(0))
+            top1.update(acc1[0], images.size(0))
+            top5.update(acc5[0], images.size(0))
 
         # compute gradient and do SGD step
         if args.benchmark == 0 :
@@ -727,7 +740,13 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
                 optimizer.zero_grad()
                 torch.npu.synchronize()
-
+
+        # graph mode
+        if args.graph_mode:
+            print("args.graph_mode")
+            torch.npu.launch_graph()
+            if i == 100:
+                torch.npu.synchronize()
         # measure elapsed time
         batch_time.update(time.time() - end)
         end = time.time()
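The benchmark==1 branch that both classification scripts keep is plain gradient accumulation: backward() sums gradients over consecutive micro-batches, and every batch_size_multiplier-th iteration the sums are rescaled once and applied in a single optimizer step, emulating a large OPTIMIZER_BATCH_SIZE. A self-contained sketch of that arithmetic (names mirror the diff; maybe_step is a hypothetical helper called right after loss.backward()):

OPTIMIZER_BATCH_SIZE = 512

def maybe_step(optimizer, i, batch_size):
    # Autograd has already summed the last `multiplier` micro-batch gradients
    # into param.grad; step only on every multiplier-th iteration.
    multiplier = OPTIMIZER_BATCH_SIZE // batch_size
    if (i + 1) % multiplier == 0:
        for group in optimizer.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    param.grad /= multiplier   # turn the sum into a mean
        optimizer.step()
        optimizer.zero_grad()                  # open the next accumulation window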
print("args.graph_mode") + torch.npu.launch_graph() + if i == 100: + torch.npu.synchronize() # measure elapsed time batch_time.update(time.time() - end) end = time.time() diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/pytorch_resnet50_apex.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/pytorch_resnet50_apex.py index 695f9de34e2fa641aea579dc49504ab9e4d42a5c..2932965ffd972a5c05a597dc18c64fbe6751d15b 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/pytorch_resnet50_apex.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/pytorch_resnet50_apex.py @@ -181,10 +181,14 @@ parser.add_argument('-t', '--fine-tuning', action='store_true', help='transfer learning + fine tuning - train only the last FC layer.') +# 图模式 +parser.add_argument('--graph_mode', + action='store_true', + help='whether to enable graph mode.') best_acc1 = 0 - +args = parser.parse_args() def main(): - args = parser.parse_args() + if args.npu is None: args.npu = 0 global CALCULATE_DEVICE @@ -428,6 +432,11 @@ def train(train_loader, model, criterion, optimizer, epoch, args): optimizer.zero_grad() end = time.time() for i, (images, target) in enumerate(train_loader): + # 图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.enable_graph_mode() + if i > 100: pass # measure data loading time @@ -438,20 +447,34 @@ def train(train_loader, model, criterion, optimizer, epoch, args): images = images.to(CALCULATE_DEVICE, non_blocking=True) if args.label_smoothing == 0.0: - target = target.to(torch.int32).to(CALCULATE_DEVICE, non_blocking=True) - + # 图模式 + if args.graph_mode: + print("args.graph_mode") + target = target.to(CALCULATE_DEVICE, non_blocking=True).to(torch.int32) + else: + target = target.to(torch.int32).to(CALCULATE_DEVICE, non_blocking=True) # compute output output = model(images) loss = criterion(output, target) if args.label_smoothing > 0.0: - target = target.to(torch.int32).to(CALCULATE_DEVICE, non_blocking=True) + # 图模式 + if args.graph_mode: + print("args.graph_mode") + target = target.to(CALCULATE_DEVICE, non_blocking=True).to(torch.int32) + else: + target = target.to(torch.int32).to(CALCULATE_DEVICE, non_blocking=True) + + # measure accuracy and record loss - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - losses.update(loss.item(), images.size(0)) - top1.update(acc1[0], images.size(0)) - top5.update(acc5[0], images.size(0)) + # 图模式 + if not args.graph_mode: + # print("args.graph_mode====================") + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) # compute gradient and do SGD step with amp.scale_loss(loss, optimizer) as scaled_loss: @@ -464,6 +487,13 @@ def train(train_loader, model, criterion, optimizer, epoch, args): param.grad /= batch_size_multiplier optimizer.step() optimizer.zero_grad() + + # 图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.launch_graph() + if i == 100: + torch.npu.synchronize() # measure elapsed time batch_time.update(time.time() - end) @@ -474,6 +504,10 @@ def train(train_loader, model, criterion, optimizer, epoch, args): if i == TRAIN_STEP: break + # 图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.disable_graph_mode() print("batch_size:", args.batch_size, 'Time: {:.3f}'.format(batch_time.avg), '* FPS@all {:.3f}'.format( args.batch_size/batch_time.avg)) @@ -615,12 +649,20 @@ class LabelSmoothing(nn.Module): self.smoothing = smoothing def forward(self, 
diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_1p.sh b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_1p.sh
index 37fd0fd4b8c3f01b3406578561c597fde64c190b..ce1bf775b06238893b7083eddf4dae4a294c89e1 100644
--- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_1p.sh
+++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_1p.sh
@@ -25,9 +25,17 @@ do
         device_id=`echo ${para#*=}`
     elif [[ $para == --data_path* ]];then
         data_path=`echo ${para#*=}`
+    elif [[ $para == --graph_mode* ]];then
+        graph_mode=`echo ${para#*=}`
     fi
 done
 
+# graph mode
+graph=""
+if [[ x"${graph_mode}" == x"true" ]];then
+    graph="--graph_mode"
+fi
+
 # Verify that data_path is passed in; no modification needed
 if [[ $data_path == "" ]];then
     echo "[Error] para \"data_path\" must be confing"
@@ -86,6 +94,7 @@ python3.7 ./pytorch_resnet50_apex.py \
     --warmup 5 \
     --label-smoothing=0.1 \
     --epochs ${train_epochs} \
+    ${graph} \
     --optimizer-batch-size 512 > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 &
 wait
diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_8p.sh b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_8p.sh
index 56b6147f44e9750c79ada63e6b70b5804fcda011..169217917cf62fadd92a99d26503722ab197d14d 100644
--- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_8p.sh
+++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_full_8p.sh
@@ -22,9 +22,17 @@ for para in $*
 do
     if [[ $para == --data_path* ]];then
         data_path=`echo ${para#*=}`
+    elif [[ $para == --graph_mode* ]];then
+        graph_mode=`echo ${para#*=}`
     fi
 done
 
+# graph mode
+graph=""
+if [[ x"${graph_mode}" == x"true" ]];then
+    graph="--graph_mode"
+fi
+
 # Verify that data_path is passed in; no modification needed
 if [[ $data_path == "" ]];then
     echo "[Error] para \"data_path\" must be confing"
@@ -98,6 +106,7 @@ python3.7 ./DistributedResnet50/main_apex_d76_npu.py \
     --benchmark=0 \
     --device='npu' \
     --epochs=${train_epochs} \
+    ${graph} \
     --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 &
 wait
diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_1p.sh b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_1p.sh
index 96226ecf321f592c0a6be9f4a02b38046c9eb37f..8e3b8c9d2af520cad5c1a0997a20e198643b8054 100644
--- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_1p.sh
+++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_1p.sh
@@ -25,9 +25,17 @@ do
         device_id=`echo ${para#*=}`
     elif [[ $para == --data_path* ]];then
         data_path=`echo ${para#*=}`
+    elif [[ $para == --graph_mode* ]];then
+        graph_mode=`echo ${para#*=}`
     fi
 done
 
+# graph mode
+graph=""
+if [[ x"${graph_mode}" == x"true" ]];then
+    graph="--graph_mode"
+fi
+
 # Verify that data_path is passed in; no modification needed
 if [[ $data_path == "" ]];then
     echo "[Error] para \"data_path\" must be confing"
@@ -79,6 +87,7 @@ etp_flag=`echo ${check_etp_flag#*=}`
 if [ x"${etp_flag}" != x"true" ];then
     source ${test_path_dir}/env_npu.sh
 fi
+
 python3.7 ./pytorch_resnet50_apex.py \
     --data ${data_path} \
     --npu ${ASCEND_DEVICE_ID} \
@@ -88,6 +97,7 @@ python3.7 ./pytorch_resnet50_apex.py \
     --warmup 5 \
     --label-smoothing=0.1 \
     --epochs ${train_epochs} \
+    ${graph} \
     --optimizer-batch-size 512 > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 &
 wait
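All four ResNet50 test scripts wire the option through identically: passing --graph_mode=true on the shell command line sets graph="--graph_mode", which the launch line expands into the store_true Python flag, and into nothing when the option is absent. A hypothetical invocation: bash test/train_performance_1p.sh --data_path=/data/imagenet --graph_mode=true.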
[[ $data_path == "" ]];then echo "[Error] para \"data_path\" must be confing" @@ -79,6 +87,7 @@ etp_flag=`echo ${check_etp_flag#*=}` if [ x"${etp_flag}" != x"true" ];then source ${test_path_dir}/env_npu.sh fi + python3.7 ./pytorch_resnet50_apex.py \ --data ${data_path} \ --npu ${ASCEND_DEVICE_ID} \ @@ -88,6 +97,7 @@ python3.7 ./pytorch_resnet50_apex.py \ --warmup 5 \ --label-smoothing=0.1 \ --epochs ${train_epochs} \ + ${graph} \ --optimizer-batch-size 512 > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_8p.sh b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_8p.sh index 850d3dda9e0f13c3bfedcdcd0c819f3c2143c3f9..4c1704fbc4eb1a86032ee2f8c270e35e1aba595d 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_8p.sh +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_performance_8p.sh @@ -21,9 +21,17 @@ for para in $* do if [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --graph_mode* ]];then + graph_mode=`echo ${para#*=}` fi done +#圖模式 +graph="" +if [[ x"${graph_mode}" == x"true" ]];then + graph="--graph_mode" +fi + # 校验是否传入data_path,不需要修改 if [[ $data_path == "" ]];then echo "[Error] para \"data_path\" must be confing" @@ -83,6 +91,7 @@ python3.7 ./DistributedResnet50/main_apex_d76_npu.py \ --benchmark=0 \ --device='npu' \ --epochs=${train_epochs} \ + ${graph} \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait diff --git a/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py b/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py index a5203e848d4513cb9969823dcd6115fc5c179f71..2327cbcce4920422e34172229f2a668ff6e33013 100644 --- a/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py +++ b/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/run_squad.py @@ -1162,7 +1162,6 @@ def main(): "step_loss": round(final_loss, 4), "learning_rate": round(optimizer.param_groups[0]['lr'], 10)}) step_start_time = time.time() - time_to_train = time.time() - train_start if args.do_train and is_main_process() and not args.skip_checkpoint: diff --git a/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/test/train_performance_1p.sh b/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/test/train_performance_1p.sh index 1ab48782e32e2054a00a85905b9c3e02e6e6782b..37d3c8363753d81588865f4b871b2839822530fa 100644 --- a/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/test/train_performance_1p.sh +++ b/PyTorch/built-in/nlp/Bert-Squad_ID0470_for_PyTorch/test/train_performance_1p.sh @@ -9,7 +9,6 @@ export RANK_SIZE=1 export JOB_ID=10087 RANK_ID_START=0 - # 数据集路径,保持为空,不需要修改 data_path="" ckpt_path=""