From cb81e38bedc783ed642ee13dae34d804b69ef30d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Mon, 6 Sep 2021 10:51:03 +0000 Subject: [PATCH 1/4] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/README.md.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../WideResNet50_2_ID1627_for_PyTorch/README.md" | 6 ++++++ 1 file changed, 6 insertions(+) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" index e504fa9..cd3b369 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" @@ -30,6 +30,12 @@ bash ./test/train_full_8p.sh --data_path=real_data_path # training 8p performance bash ./test/train_performance_8p.sh --data_path=real_data_path + +#test 8p accuracy +bash test/train_eval_8p.sh --data_path=real_data_path --weight_path=real_pre_train_model_path + +# finetuning 1p +bash test/train_finetune_1p.sh --data_path=real_data_path --weight_path=real_pre_train_model_path ``` Log path: -- Gitee From 4726fd93ae11ff16cb07d2be8d5ba13a7087a8b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Mon, 6 Sep 2021 10:51:32 +0000 Subject: [PATCH 2/4] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/main.py.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../WideResNet50_2_ID1627_for_PyTorch/main.py" | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" index a99b07e..2cd1f50 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" @@ -65,6 +65,8 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') +parser.add_argument('--weight_path', default='', help='pretrained weight path') +parser.add_argument('--num_classes', default=None,type=int, help='num_classes') parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') parser.add_argument('--rank', default=-1, type=int, @@ -189,9 +191,16 @@ def main_worker(gpu, ngpus_per_node, args): if args.pretrained: print("=> using pre-trained model wide_resnet50_2") model = resnet_0_6_0.wide_resnet50_2(num_classes=args.num_classes) - print("loading model of yours...") - pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] + + if os.path.exists(args.weight_path): + print("loading model of your give") + pretrained_dict = torch.load(args.weight_path, map_location="cpu")["state_dict"] + else: + print("loading model default ") + pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] + model.load_state_dict({k.replace('module.',''):v for k, v in pretrained_dict.items()}) + if "fc.weight" in pretrained_dict: pretrained_dict.pop('fc.weight') pretrained_dict.pop('fc.bias') @@ -359,7 +368,10 @@ def main_worker(gpu, ngpus_per_node, args): num_workers=args.workers, pin_memory=False, drop_last=True) if args.evaluate: + print("===eval start===") validate(val_loader, model, criterion, args, ngpus_per_node) + print("===eval over===") + return if args.prof: @@ -389,6 +401,7 @@ def main_worker(gpu, ngpus_per_node, args): and args.rank % ngpus_per_node == 0): ############## npu modify begin ############# + print("=========save model======") if args.amp: save_checkpoint({ 'epoch': epoch + 1, -- Gitee From 2bce91e044170b2147704c097078710a0603f7dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Mon, 6 Sep 2021 10:52:11 +0000 Subject: [PATCH 3/4] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/test/train?= =?UTF-8?q?=5Feval=5F8p.sh.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/train_eval_8p.sh" | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" index 7f09f2f..0afbaaf 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" @@ -9,8 +9,8 @@ batch_size=4096 # 训练使用的npu卡数 export RANK_SIZE=8 # checkpoint文件路径,以实际路径为准 -resume=/home/checkpoint.pth.tar -# 数据集路径,保持为空,不需要修改 +# resume=/home/checkpoint.pth.tar +# 数据集路径,需要根据实际情况填入 data_path="" # 训练epoch @@ -28,6 +28,8 @@ do workers=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --weight_path* ]];then + weight_path=`echo ${para#*=}` fi done @@ -73,7 +75,6 @@ fi python3.7 ./main.py \ ${data_path} \ --evaluate \ - --resume ${resume} \ --addr=$(hostname -I |awk '{print $1}') \ --seed=49 \ --workers=${workers} \ @@ -91,6 +92,8 @@ python3.7 ./main.py \ --warm_up_epochs=5 \ --loss-scale=32 \ --amp \ + --pretrained \ + --weight_path=${weight_path} \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait -- Gitee From d838be206527b7328a7169d56bc281b837d6db18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Mon, 6 Sep 2021 10:52:33 +0000 Subject: [PATCH 4/4] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/test/train?= =?UTF-8?q?=5Ffinetune=5F1p.sh.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/train_finetune_1p.sh" | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" index 30dac86..f87422d 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" @@ -25,6 +25,8 @@ do device_id=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --weight_path* ]];then + weight_path=`echo ${para#*=}` fi done @@ -94,8 +96,8 @@ python3.7 ./main.py \ --warm_up_epochs=5 \ --loss-scale=32 \ --amp \ - --pretrained \ --num_classes=1200 \ + --resume=${weight_path} \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait -- Gitee