From 22cf8bead23b488d707135fa216c31048ceea7e7 Mon Sep 17 00:00:00 2001 From: wuxingpeng Date: Fri, 8 Apr 2022 02:00:37 +0000 Subject: [PATCH 1/5] =?UTF-8?q?update=20=E3=80=90PyTorch=E3=80=91=E3=80=90?= =?UTF-8?q?dev=E3=80=91=E3=80=90FixMatch=5FID0989=5Ffor=5FPyTorch=E3=80=91?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=A8=A1=E7=B3=8A=E7=BC=96=E8=AF=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../image_classification/FixMatch_ID0989_for_PyTorch/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/PyTorch/dev/cv/image_classification/FixMatch_ID0989_for_PyTorch/train.py b/PyTorch/dev/cv/image_classification/FixMatch_ID0989_for_PyTorch/train.py index 3f8baad87a..48cc598d7d 100644 --- a/PyTorch/dev/cv/image_classification/FixMatch_ID0989_for_PyTorch/train.py +++ b/PyTorch/dev/cv/image_classification/FixMatch_ID0989_for_PyTorch/train.py @@ -57,7 +57,7 @@ import apex logger = logging.getLogger(__name__) best_acc = 0 - +torch.npu.set_start_fuzz_compile_step(3) def save_checkpoint(state, is_best, checkpoint, filename='checkpoint.pth.tar'): filepath = os.path.join(checkpoint, filename) torch.save(state, filepath) @@ -365,6 +365,8 @@ def train(args, labeled_trainloader, unlabeled_trainloader, test_loader, p_bar = tqdm(range(args.eval_step), disable=args.local_rank not in [-1, 0]) for batch_idx in range(args.eval_step): + #模糊编译 + torch.npu.global_step_inc() try: inputs_x, targets_x = labeled_iter.next() except: -- Gitee From 3dd03fab7cbb188f6b3e2642ef5a5e7ea024cd60 Mon Sep 17 00:00:00 2001 From: wuxingpeng Date: Fri, 8 Apr 2022 08:25:45 +0000 Subject: [PATCH 2/5] =?UTF-8?q?=E3=80=90PyTorch=E3=80=91=E3=80=90buit-in?= =?UTF-8?q?=E3=80=91=E3=80=90MobileNetV2=5Ffor=5FPyTorch=E3=80=91MobileNet?= =?UTF-8?q?V2=5Ffor=5FPyTorch=E7=BD=91=E7=BB=9C=E6=B7=BB=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../train/mobilenetv2_8p_main_anycard.py | 72 +++++++++++++------ 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py index b38bdd8a85..8985b80cc2 100644 --- a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py +++ b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py @@ -121,6 +121,10 @@ parser.add_argument('--opt-level', default='O2', type=str, help='loss scale using in amp, default -1 means dynamic') parser.add_argument('--class-nums', default=1000, type=int, help='class-nums only for pretrain') +# 图模式 +parser.add_argument('--graph_mode', + action='store_true', + help='whether to enable graph mode.') warnings.filterwarnings('ignore') best_acc1 = 0 @@ -341,6 +345,10 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar steps_per_epoch = train_loader_len print('==========step per epoch======================', steps_per_epoch) for i, (images, target) in enumerate(train_loader): + #图模式 + if args.graph_mode: + print("graph mode on") + torch.npu.enable_graph_mode() if i > 200 : pass # measure data loading time @@ -348,25 +356,34 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar global_step = epoch * steps_per_epoch + i lr = adjust_learning_rate(optimizer, global_step, steps_per_epoch, args) + #图模式 + if args.gradph_mode: + images = images.to(loc, non_blocking = True) + target = target.to(loc, non_blocking = True) + images = images.to(torch.float).sub(mean).div(std) + target = target.to(torch.int32) + # compute output + output = model(images) + loss = criterion(output, target) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + else: + target = target.to(torch.int32) + images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std) + target = target.to(loc, non_blocking=True) + # compute output + output = model(images) + stream = torch.npu.current_stream() + stream.synchronize() - target = target.to(torch.int32) - images = images.to(loc, non_blocking=True).to(torch.float).sub(mean).div(std) - target = target.to(loc, non_blocking=True) - - # compute output - output = model(images) - stream = torch.npu.current_stream() - stream.synchronize() - - loss = criterion(output, target) - stream = torch.npu.current_stream() - stream.synchronize() + loss = criterion(output, target) + stream = torch.npu.current_stream() + stream.synchronize() - # measure accuracy and record loss - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - losses.update(loss.item(), images.size(0)) - top1.update(acc1[0], images.size(0)) - top5.update(acc5[0], images.size(0)) + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1[0], images.size(0)) + top5.update(acc5[0], images.size(0)) # compute gradient and do SGD step if args.benchmark == 0: @@ -377,9 +394,10 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar scaled_loss.backward() else: loss.backward() - - stream = torch.npu.current_stream() - stream.synchronize() + #图模式 + if not args.graph_mode: + stream = torch.npu.current_stream() + stream.synchronize() if args.benchmark == 0: optimizer.step() @@ -392,8 +410,14 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar param.grad /= batch_size_multiplier optimizer.step() optimizer.zero_grad() - stream = torch.npu.current_stream() - stream.synchronize() + #图模式 + if args.graph_mode: + torch.npu.launch_graph() + if i == len(train_loader): + torch.npu.synchronize() + else: + stream = torch.npu.current_stream() + stream.synchronize() # measure elapsed time batch_time.update(time.time() - end) @@ -404,6 +428,10 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar and args.rank % ngpus_per_node == 0): progress.display(i) + #图模式 + if args.graph_mode: + print("args.graph_mode") + torch.npu.disable_graph_mode() if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0): print("[npu id:", args.gpu, "]", '* FPS@all {:.3f}'.format(ngpus_per_node * args.batch_size / batch_time.avg)) -- Gitee From ba703027dde2d86eee0f2e76fbefb1c629c018ee Mon Sep 17 00:00:00 2001 From: wuxingpeng Date: Fri, 8 Apr 2022 08:54:03 +0000 Subject: [PATCH 3/5] =?UTF-8?q?=E3=80=90PyTorch=E3=80=91=E3=80=90buit-in?= =?UTF-8?q?=E3=80=91=E3=80=90MobileNetV2=5Ffor=5FPyTorch=E3=80=91MobileNet?= =?UTF-8?q?V2=5Ffor=5FPyTorch=E7=BD=91=E7=BB=9C=E6=B7=BB=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../train/mobilenetv2_8p_main_anycard.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py index 8985b80cc2..4fac8b49dc 100644 --- a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py +++ b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/train/mobilenetv2_8p_main_anycard.py @@ -357,7 +357,7 @@ def train(train_loader, train_loader_len, model, criterion, optimizer, epoch, ar global_step = epoch * steps_per_epoch + i lr = adjust_learning_rate(optimizer, global_step, steps_per_epoch, args) #图模式 - if args.gradph_mode: + if args.graph_mode: images = images.to(loc, non_blocking = True) target = target.to(loc, non_blocking = True) images = images.to(torch.float).sub(mean).div(std) -- Gitee From e7006f830ea3724859e64274198b1de04abf63fc Mon Sep 17 00:00:00 2001 From: wuxingpeng Date: Fri, 8 Apr 2022 08:56:08 +0000 Subject: [PATCH 4/5] =?UTF-8?q?=E3=80=90PyTorch=E3=80=91=E3=80=90buit-in?= =?UTF-8?q?=E3=80=91=E3=80=90MobileNetV2=5Ffor=5FPyTorch=E3=80=91MobileNet?= =?UTF-8?q?V2=5Ffor=5FPyTorch=E7=BD=91=E7=BB=9C=E6=B7=BB=E5=8A=A0=E5=9B=BE?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E5=8D=95P=E6=80=A7=E8=83=BD=E6=89=A7?= =?UTF-8?q?=E8=A1=8C=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...train_ID3072_MobileNetV2_performance_1p.py | 164 ++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.py diff --git a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.py b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.py new file mode 100644 index 0000000000..86b6ec9d5e --- /dev/null +++ b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.py @@ -0,0 +1,164 @@ +#!/bin/bash + +#当前路径,不需要修改 +cur_path=`pwd` + + +#集合通信参数,不需要修改 +export HCCL_WHITELIST_DISABLE=1 +export RANK_SIZE=1 +export JOB_ID=10087 +RANK_ID_START=0 +# source env.sh +#RANK_SIZE=8 +# 数据集路径,保持为空,不需要修改 +data_path="" + +#设置默认日志级别,不需要修改 +# export ASCEND_GLOBAL_LOG_LEVEL_ETP=3 + +#基础参数,需要模型审视修改 +#网络名称,同目录名称 +Network="MobileNetV2_ID3072_for_PyTorch" +#训练epoch +train_epochs=1 +#训练batch_size +batch_size=512 +#训练step +train_steps=`expr 1281167 / ${batch_size}` +#学习率 +learning_rate=0.045 + +#维测参数,precision_mode需要模型审视修改 +precision_mode="allow_mix_precision" +#维持参数,以下不需要修改 +over_dump=False +data_dump_flag=False +data_dump_step="10" +profiling=False + + +if [[ $1 == --help || $1 == --h ]];then + echo "usage:./train_performance_1p.sh --data_path=data_dir --batch_size=1024 --learning_rate=0.04" + exit 1 +fi + +for para in $* +do + if [[ $para == --data_path* ]];then + data_path=`echo ${para#*=}` + elif [[ $para == --batch_size* ]];then + batch_size=`echo ${para#*=}` + elif [[ $para == --learning_rate* ]];then + learning_rate=`echo ${para#*=}` + elif [[ $para == --precision_mode* ]];then + precision_mode=`echo ${para#*=}` + fi +done + +PREC="" +if [[ $precision_mode == "amp" ]];then + PREC="--amp" +fi + +#校验是否传入data_path,不需要修改 +if [[ $data_path == "" ]];then + echo "[Error] para \"data_path\" must be confing" + exit 1 +fi + +cd $cur_path + +#设置环境变量,不需要修改 +echo "Device ID: $ASCEND_DEVICE_ID" +export RANK_ID=$RANK_ID + +if [ -d $cur_path/output ];then + rm -rf $cur_path/output/* + mkdir -p $cur_path/output/$ASCEND_DEVICE_ID +else + mkdir -p $cur_path/output/$ASCEND_DEVICE_ID +fi +wait + +#参数修改 +sed -i "s|pass|break|g" ${cur_path}/../train/mobilenetv2_8p_main_anycard.py +wait + +#训练开始时间,不需要修改 +start_time=$(date +%s) + +# 绑核,不需要的绑核的模型删除,需要模型审视修改 +python3.7 ${cur_path}/../train/mobilenetv2_8p_main_anycard.py \ + --addr=$(hostname -I |awk '{print $1}') \ + --seed 49 \ + --workers 128 \ + --lr 0.05 \ + --print-freq 1 \ + --eval-freq 1 \ + --dist-url 'tcp://127.0.0.1:50002' \ + --dist-backend 'hccl' \ + --multiprocessing-distributed \ + --world-size 1 \ + --class-nums 1000 \ + --batch-size $batch_size \ + --epochs $train_epochs \ + --rank 0 \ + --device-list $ASCEND_DEVICE_ID \ + --amp \ + --benchmark 0 \ + --graph_mode \ + --data $data_path > $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log 2>&1 & +wait + +#训练结束时间,不需要修改 +end_time=$(date +%s) +e2e_time=$(( $end_time - $start_time )) + +#参数改回 +sed -i "s|break|pass|g" ${cur_path}/../train/mobilenetv2_8p_main_anycard.py +wait + +#结果打印,不需要修改 +echo "------------------ Final result ------------------" +#输出性能FPS,需要模型审视修改 +FPS=`grep FPS ${cur_path}/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk '{print $NF}'|awk '{sum+=$1} END {print sum/NR}'` + +#打印,不需要修改 +echo "Final Performance images/sec : $FPS" + +#输出训练精度,需要模型审视修改 +#train_accuracy=`grep -a '* Acc@1' $cur_path/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log|awk 'END {print}'|awk -F "Acc@1" '{print $NF}'|awk -F " " '{print $1}'` + +#打印,不需要修改 +#echo "Final Train Accuracy : ${train_accuracy}" +echo "E2E Training Duration sec : $e2e_time" + +#性能看护结果汇总 +#训练用例信息,不需要修改 +BatchSize=${batch_size} +DeviceType=`uname -m` +CaseName=${Network}_bs${BatchSize}_${RANK_SIZE}'p'_'perf' + +##获取性能数据,不需要修改 +#吞吐量 +ActualFPS=${FPS} +#单迭代训练时长 +TrainingTime=`awk 'BEGIN{printf "%.2f\n", '${batch_size}'*1000/'${FPS}'}'` + +#从train_$ASCEND_DEVICE_ID.log提取Loss到train_${CaseName}_loss.txt中,需要根据模型审视 +grep Epoch $cur_path/output/$ASCEND_DEVICE_ID/train_$ASCEND_DEVICE_ID.log|awk -F 'Loss' '{print $2}' |awk '{print $1}' > $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt + +#最后一个迭代loss值,不需要修改 +ActualLoss=`awk 'END {print}' $cur_path/output/$ASCEND_DEVICE_ID/train_${CaseName}_loss.txt` + +#关键信息打印到${CaseName}.log中,不需要修改 +echo "Network = ${Network}" > $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "RankSize = ${RANK_SIZE}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "BatchSize = ${BatchSize}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "DeviceType = ${DeviceType}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "CaseName = ${CaseName}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualFPS = ${ActualFPS}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "TrainingTime = ${TrainingTime}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "ActualLoss = ${ActualLoss}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log +echo "E2ETrainingTime = ${e2e_time}" >> $cur_path/output/$ASCEND_DEVICE_ID/${CaseName}.log -- Gitee From 1b089b616594de04e4d070e80bc7abaa361e3df3 Mon Sep 17 00:00:00 2001 From: wuxingpeng Date: Fri, 8 Apr 2022 09:08:17 +0000 Subject: [PATCH 5/5] =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=20PyTorch/built?= =?UTF-8?q?-in/cv/classification/MobileNetV2=5Ffor=5FPyTorch/test/train=5F?= =?UTF-8?q?ID3072=5FMobileNetV2=5Fperformance=5F1p.py=20=E4=B8=BA=20PyTorc?= =?UTF-8?q?h/built-in/cv/classification/MobileNetV2=5Ffor=5FPyTorch/test/t?= =?UTF-8?q?rain=5FID3072=5FMobileNetV2=5Fperformance=5F1p.sh?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...rformance_1p.py => train_ID3072_MobileNetV2_performance_1p.sh} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/{train_ID3072_MobileNetV2_performance_1p.py => train_ID3072_MobileNetV2_performance_1p.sh} (100%) diff --git a/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.py b/PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.sh similarity index 100% rename from PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.py rename to PyTorch/built-in/cv/classification/MobileNetV2_for_PyTorch/test/train_ID3072_MobileNetV2_performance_1p.sh -- Gitee