From 09c9c21e934b4761bfc16ace1265199f9ee1927a Mon Sep 17 00:00:00 2001 From: "rrrr.cao@hotmail.com" Date: Wed, 6 Apr 2022 11:14:07 +0800 Subject: [PATCH 1/7] add resnet50 8p graph mode --- .../DistributedResnet50/main_apex_d76_npu.py | 33 +++- .../train_ID3071_ResNet50_performance_8p.sh | 141 ++++++++++++++++++ 2 files changed, 167 insertions(+), 7 deletions(-) create mode 100644 PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py index 7886ada80d..d05fe4ed88 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py @@ -682,6 +682,10 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar if args.benchmark == 1 : optimizer.zero_grad() for i, (images, target) in enumerate(train_loader): + # 图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.enable_graph_mode() # measure data loading time data_time.update(time.time() - end) @@ -689,8 +693,15 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar if args.device == 'npu': loc = 'npu:{}'.format(args.gpu) - images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std) - target = target.to(torch.int32).to(loc, non_blocking=True) + # 图模式 + if args.graph_mode: + images = images.to(loc, non_blocking=True) + target = target.to(loc, non_blocking=True) + images = images.to(torch.float).sub(mean).div(std) + target = target.to(torch.int32) + else: + images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std) + target = target.to(torch.int32).to(loc, non_blocking=True) else: images = images.cuda(args.gpu, 
non_blocking=True) target = target.cuda(args.gpu, non_blocking=True) @@ -701,10 +712,12 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar loss = criterion(output, target) # measure accuracy and record loss - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - losses.update(loss.item(), images.size(0)) - top1.update(acc1[0], images.size(0)) - top5.update(acc5[0], images.size(0)) + # 图模式 + if not args.graph_mode: + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) # compute gradient and do SGD step if args.benchmark == 0 : @@ -727,7 +740,13 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar optimizer.zero_grad() torch.npu.synchronize() - + + # 图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.launch_graph() + if i == 100: + torch.npu.synchronize() # measure elapsed time batch_time.update(time.time() - end) end = time.time() diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh new file mode 100644 index 0000000000..ffae349489 --- /dev/null +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh @@ -0,0 +1,141 @@ +#!/bin/bash + +################基础配置参数,需要模型审视修改################## +# 必选字段(必须在此处定义的参数): Network batch_size RANK_SIZE +# 网络名称,同目录名称 +Network="ResNet50_ID3071_for_PyTorch" +# 训练batch_size +batch_size=4096 +# 训练使用的npu卡数 +export RANK_SIZE=8 +# 数据集路径,保持为空,不需要修改 +data_path="" + +# 训练epoch 90 +train_epochs=3 +# 加载数据进程数 +workers=128 + +# 参数校验,data_path为必传参数,其他参数的增删由模型自身决定;此处新增参数需在上面有定义并赋值 +for para in $* +do + if [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + fi +done + + +# 校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo 
"[Error] para \"data_path\" must be configured"
+    exit 1
+fi
+
+###############Specify the training script execution path###############
+# cd to the directory that holds the test folder before running, for compatibility; test_path_dir is the path of the test folder
+cur_path=`pwd`
+cur_path_last_dirname=${cur_path##*/}
+if [ x"${cur_path_last_dirname}" == x"test" ];then
+    test_path_dir=${cur_path}
+    cd ..
+    cur_path=`pwd`
+else
+    test_path_dir=${cur_path}/test
+fi
+
+
+#################Create the log output directory (paths quoted so rm -rf cannot be word-split)#################
+ASCEND_DEVICE_ID=0
+if [ -d "${test_path_dir}/output/${ASCEND_DEVICE_ID}" ];then
+    rm -rf "${test_path_dir}/output/${ASCEND_DEVICE_ID}"
+    mkdir -p "${test_path_dir}/output/$ASCEND_DEVICE_ID"
+else
+    mkdir -p "${test_path_dir}/output/$ASCEND_DEVICE_ID"
+fi
+
+
+#################Launch the training script#################
+# Training start time; no need to modify
+start_time=$(date +%s)
+# Source the environment variables when not running on the platform
+check_etp_flag=`env | grep etp_running_flag`
+etp_flag=`echo ${check_etp_flag#*=}`
+if [ x"${etp_flag}" != x"true" ];then
+    source "${test_path_dir}/env_npu.sh"
+fi
+
+python3.7 ./DistributedResnet50/main_apex_d76_npu.py \
+    --data "${data_path}" \
+    --addr=$(hostname -I |awk '{print $1}') \
+    --seed=49 \
+    --workers=${workers} \
+    --learning-rate=1.6 \
+    --warmup=8 \
+    --label-smoothing=0.0 \
+    --mom=0.9 \
+    --weight-decay=1.0e-04 \
+    --static-loss-scale=128 \
+    --print-freq=1 \
+    --dist-url='tcp://127.0.0.1:50000' \
+    --dist-backend='hccl' \
+    --multiprocessing-distributed \
+    --world-size=1 \
+    --rank=0 \
+    --benchmark=0 \
+    --device='npu' \
+    --graph_mode \
+    --epochs=${train_epochs} \
+
+    --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 &
+
+wait
+
+
+##################Collect the training results################
+# Training end time; no need to modify
+end_time=$(date +%s)
+e2e_time=$(( $end_time - $start_time ))
+
+# Test case information; no need to modify
+BatchSize=${batch_size}
+DeviceType=`uname -m`
+CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf'
+
+# Print results; no need to modify
+echo "------------------ Final result ------------------"
+# Extract the performance FPS; review and adjust per model
+grep "FPS@all" 
${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log | awk '{print $11}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_fps.log
+FPS=`cat ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${CaseName}_fps.log | awk '{a+=$1} END {if (NR != 0) printf("%.3f",a/NR)}'`
+# Print; no need to modify
+echo "Final Performance images/sec : $FPS"
+
+# Extract the training accuracy; review and adjust per model
+train_accuracy=`grep -a '* Acc@1' ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk 'END {print}'|awk -F "Acc@1" '{print $NF}'|awk -F " " '{print $1}'`
+# Print; no need to modify
+echo "Final Train Accuracy : ${train_accuracy}"
+echo "E2E Training Duration sec : $e2e_time"
+
+# Performance monitoring summary
+# Collect performance data; no need to modify
+# Throughput
+ActualFPS=${FPS}
+# Training time per iteration in ms; values go in via -v so an empty FPS (no "FPS@all" line logged) yields an empty result instead of an awk syntax error
+TrainingTime=`awk -v bs="${batch_size}" -v fps="${FPS}" 'BEGIN{if (fps > 0) printf "%.2f\n", bs*1000/fps}'`
+
+# Extract Loss from train_$ASCEND_DEVICE_ID.log into train_${CaseName}_loss.txt; review per model
+grep Epoch: ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|grep -v Test|awk -F "Loss" '{print $NF}' | awk -F " " '{print $1}' >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt
+
+# Loss value of the last iteration; no need to modify
+ActualLoss=`awk 'END {print}' ${test_path_dir}/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt`
+
+# Write the key information to ${CaseName}.log; no need to modify
+echo "Network = ${Network}" > ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "RankSize = ${RANK_SIZE}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "BatchSize = ${BatchSize}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "DeviceType = ${DeviceType}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "CaseName = ${CaseName}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "ActualFPS = ${ActualFPS}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "TrainingTime = ${TrainingTime}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log
+echo "TrainAccuracy = ${train_accuracy}" >> 
${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> ${test_path_dir}/output/$ASCEND_DEVICE_ID/${CaseName}.log \ No newline at end of file -- Gitee From 08bc4d217b72e0abfc72f4521950121bcb2f3f36 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 6 Apr 2022 03:23:22 +0000 Subject: [PATCH 2/7] update train_ID3071_ResNet50_performance_8p.sh. --- .../test/train_ID3071_ResNet50_performance_8p.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh index ffae349489..0013d69590 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/test/train_ID3071_ResNet50_performance_8p.sh @@ -85,7 +85,6 @@ python3.7 ./DistributedResnet50/main_apex_d76_npu.py \ --device='npu' \ --graph_mode \ --epochs=${train_epochs} \ - --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait -- Gitee From f03f8a2a1098570324bf8b1e93352d1f04fff7a1 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 6 Apr 2022 05:56:53 +0000 Subject: [PATCH 3/7] update DistributedResnet50/main_apex_d76_npu.py. 
--- .../DistributedResnet50/main_apex_d76_npu.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py index d05fe4ed88..ab85e9fff0 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py @@ -259,6 +259,10 @@ parser.add_argument('-t', '--fine-tuning', action='store_true', help='transfer learning + fine tuning - train only the last FC layer.') +# 图模式 +parser.add_argument('--graph_mode', + action='store_true', + help='whether to enable graph mode.') best_acc1 = 0 def nvidia_model_config(args): -- Gitee From adc2125b2981e42339b9c53085ad2a07dea7e56f Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 6 Apr 2022 06:33:41 +0000 Subject: [PATCH 4/7] update DistributedResnet50/main_apex_d76_npu.py. 
--- .../DistributedResnet50/main_apex_d76_npu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py index ab85e9fff0..f9b8ad14ac 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py @@ -749,7 +749,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar if args.graph_mode: print("args.graph_mode") torch.npu.launch_graph() - if i == 100: + if i == 312: torch.npu.synchronize() # measure elapsed time batch_time.update(time.time() - end) -- Gitee From 983fdef4a880f746bc3533badf7f39556ed66e27 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 6 Apr 2022 06:38:12 +0000 Subject: [PATCH 5/7] update DistributedResnet50/main_apex_d76_npu.py. 
--- .../DistributedResnet50/main_apex_d76_npu.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py index f9b8ad14ac..6eefaa9e64 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py @@ -759,7 +759,10 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): progress.display(i) - + # 图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.disable_graph_mode() if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): print("[npu id:",args.gpu,"]", "batch_size:", ngpus_per_node*args.batch_size, 'Time: {:.3f}'.format(batch_time.avg), '* FPS@all {:.3f}'.format( -- Gitee From b6451e2fe649be357fb5c9ad9e39323b47251679 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 6 Apr 2022 07:04:25 +0000 Subject: [PATCH 6/7] update main_apex_d76_npu.py. 
--- .../DistributedResnet50/main_apex_d76_npu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py index 6eefaa9e64..5798a195f1 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py @@ -688,7 +688,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar for i, (images, target) in enumerate(train_loader): # 图模式 if args.graph_mode: - print("args.graph_mode") + print("graph mode on") torch.npu.enable_graph_mode() # measure data loading time data_time.update(time.time() - end) @@ -747,7 +747,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar # 图模式 if args.graph_mode: - print("args.graph_mode") + print("graph mode launch") torch.npu.launch_graph() if i == 312: torch.npu.synchronize() @@ -761,7 +761,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar progress.display(i) # 图模式 if args.graph_mode: - print("args.graph_mode") + print("graph mode off") torch.npu.disable_graph_mode() if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): -- Gitee From c8f5580ea21a3683ea9746d6a4469a110694d3f9 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 6 Apr 2022 07:14:12 +0000 Subject: [PATCH 7/7] update main_apex_d76_npu.py. 
---
 .../DistributedResnet50/main_apex_d76_npu.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py
index 5798a195f1..26edd676ce 100644
--- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py
+++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py
@@ -749,7 +749,8 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar
         if args.graph_mode:
             print("graph mode launch")
             torch.npu.launch_graph()
-            if i == 312:
+            # enumerate() yields i in [0, len(train_loader) - 1]; synchronize on the last batch
+            if i == len(train_loader) - 1:
                 torch.npu.synchronize()
         # measure elapsed time
         batch_time.update(time.time() - end)
-- 
Gitee