From 76ced10eb35af2147f720e7f2976823bf0017221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Tue, 7 Sep 2021 07:15:40 +0000 Subject: [PATCH 1/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/test/train?= =?UTF-8?q?=5Feval=5F8p.sh.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/train_eval_8p.sh" | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" index 7f09f2f..d34f47b 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" @@ -8,11 +8,10 @@ Network="WideResNet50_2_ID1627_for_PyTorch" batch_size=4096 # 训练使用的npu卡数 export RANK_SIZE=8 -# checkpoint文件路径,以实际路径为准 -resume=/home/checkpoint.pth.tar # 数据集路径,保持为空,不需要修改 data_path="" - +# checkpoint文件路径,以实际路径为准 +pth_path="" # 训练epoch train_epochs=200 # 学习率 @@ -28,6 +27,8 @@ do workers=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --pth_path* ]];then + pth_path=`echo ${para#*=}` fi done @@ -37,6 +38,11 @@ if [[ $data_path == "" ]];then exit 1 fi +# 校验是否传入 pth_path , 验证脚本需要传入此参数 +if [[ $pth_path == "" ]];then + echo "[Error] para \"pth_path\" must be confing" + exit 1 +fi ###############指定训练脚本执行路径############### # cd到与test文件夹同层级目录下执行脚本,提高兼容性;test_path_dir为包含test文件夹的路径 @@ -73,7 +79,7 @@ fi python3.7 ./main.py \ ${data_path} \ --evaluate \ - --resume ${resume} \ + --resume ${pth_path} \ --addr=$(hostname -I |awk '{print $1}') \ --seed=49 \ --workers=${workers} \ -- Gitee From 2313dfa642c3c98f1945b08c722bd4c66069e551 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Tue, 7 Sep 2021 07:16:08 +0000 Subject: [PATCH 2/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/main.py.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../WideResNet50_2_ID1627_for_PyTorch/main.py" | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" index a99b07e..e34755a 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" @@ -65,6 +65,8 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') +parser.add_argument('--pth_path', default='', type=str, metavar='PATH', + help='path to pretrained checkpoint (default: none)') parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') parser.add_argument('--rank', default=-1, type=int, @@ -82,7 +84,7 @@ parser.add_argument('--multiprocessing-distributed', action='store_true', 'N processes per node, which has N GPUs. This is the ' 'fastest way to use PyTorch for either single node or ' 'multi node data parallel training') -parser.add_argument('--num-classes', default=1000, type=int, +parser.add_argument('--num_classes', default=1000, type=int, help='The number of classes.') ## for ascend 910 parser.add_argument('--device', default='npu', type=str, help='npu or gpu') @@ -190,9 +192,15 @@ def main_worker(gpu, ngpus_per_node, args): print("=> using pre-trained model wide_resnet50_2") model = resnet_0_6_0.wide_resnet50_2(num_classes=args.num_classes) print("loading model of yours...") - pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] + if args.pth_path: + print("load pth you give") + pretrained_dict = torch.load(args.pth_path, map_location="cpu")["state_dict"] + else: + pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] + model.load_state_dict({k.replace('module.',''):v for k, v in pretrained_dict.items()}) if "fc.weight" in pretrained_dict: + print("pop out fc layer") pretrained_dict.pop('fc.weight') pretrained_dict.pop('fc.bias') model.load_state_dict(pretrained_dict, strict=False) -- Gitee From e45a94c909621489ce140c994f25f97e79fd0561 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Tue, 7 Sep 2021 07:16:32 +0000 Subject: [PATCH 3/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/test/train?= =?UTF-8?q?=5Ffinetune=5F1p.sh.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/train_finetune_1p.sh" | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" index 30dac86..b2febb1 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" @@ -10,7 +10,8 @@ batch_size=256 export RANK_SIZE=1 # 数据集路径,保持为空,不需要修改 data_path="" - +# checkpoint文件路径,以实际路径为准 +pth_path="" # 训练epoch train_epochs=200 # 指定训练所使用的npu device卡id @@ -25,6 +26,8 @@ do device_id=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --pth_path* ]];then + pth_path=`echo ${para#*=}` fi done @@ -33,6 +36,13 @@ if [[ $data_path == "" ]];then echo "[Error] para \"data_path\" must be confing" exit 1 fi + +# 校验是否传入 pth_path , 验证脚本需要传入此参数 +if [[ $pth_path == "" ]];then + echo "[Error] para \"pth_path\" must be confing" + exit 1 +fi + # 校验是否指定了device_id,分动态分配device_id与手动指定device_id,此处不需要修改 if [ $ASCEND_DEVICE_ID ];then echo "device id is ${ASCEND_DEVICE_ID}" @@ -95,11 +105,13 @@ python3.7 ./main.py \ --loss-scale=32 \ --amp \ --pretrained \ - --num_classes=1200 \ + --pth_path=${pth_path} \ + --num_classes=1000 \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait +#注意 num_classes参数 根据加载的 pth_path 填写实际的 classes 数目 ##################获取训练数据################ #训练结束时间,不需要修改 -- Gitee From 4e7233e10ff086957553fe3d6c092127e2eba379 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Tue, 7 Sep 2021 07:18:06 +0000 Subject: [PATCH 4/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/README.md.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../WideResNet50_2_ID1627_for_PyTorch/README.md" | 6 ++++++ 1 file changed, 6 insertions(+) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" index e504fa9..d7ddc8a 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" @@ -30,6 +30,12 @@ bash ./test/train_full_8p.sh --data_path=real_data_path # training 8p performance bash ./test/train_performance_8p.sh --data_path=real_data_path + +#test 8p accuracy +bash test/train_eval_8p.sh --data_path=real_data_path --pth_path=real_pre_train_model_path + +# finetuning 1p +bash test/train_finetune_1p.sh --data_path=real_data_path --pth_path=real_pre_train_model_path ``` Log path: -- Gitee From 2baa9daf64602c0005c5c26d294e783a69b49b11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Tue, 7 Sep 2021 08:25:32 +0000 Subject: [PATCH 5/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/test/train?= =?UTF-8?q?=5Ffinetune=5F1p.sh.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test/train_finetune_1p.sh" | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" index b2febb1..1821a33 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" @@ -106,12 +106,11 @@ python3.7 ./main.py \ --amp \ --pretrained \ --pth_path=${pth_path} \ - --num_classes=1000 \ + --num_classes=1200 \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait -#注意 num_classes参数 根据加载的 pth_path 填写实际的 classes 数目 ##################获取训练数据################ #训练结束时间,不需要修改 -- Gitee From 8dc5453b1f192a27d11e3851326175ea8141b1cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Tue, 7 Sep 2021 12:24:31 +0000 Subject: [PATCH 6/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/main.py.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../WideResNet50_2_ID1627_for_PyTorch/main.py" | 1 - 1 file changed, 1 deletion(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" index e34755a..87a8b62 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" @@ -198,7 +198,6 @@ def main_worker(gpu, ngpus_per_node, args): else: pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] - model.load_state_dict({k.replace('module.',''):v for k, v in pretrained_dict.items()}) if "fc.weight" in pretrained_dict: print("pop out fc layer") pretrained_dict.pop('fc.weight') -- Gitee From 0d26125b12a1df1d7429fcb26d295678d7782412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=BA=90=E6=98=8E?= <1021955028@qq.com> Date: Wed, 8 Sep 2021 00:57:45 +0000 Subject: [PATCH 7/7] =?UTF-8?q?update=20Pytorch=E8=AE=AD=E7=BB=83=E7=A4=BA?= =?UTF-8?q?=E4=BE=8B/WideResNet50=5F2=5FID1627=5Ffor=5FPyTorch/main.py.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../WideResNet50_2_ID1627_for_PyTorch/main.py" | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" index 87a8b62..32cb29f 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" @@ -199,7 +199,7 @@ def main_worker(gpu, ngpus_per_node, args): pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] if "fc.weight" in pretrained_dict: - print("pop out fc layer") + print("pop fc layer weight") pretrained_dict.pop('fc.weight') pretrained_dict.pop('fc.bias') model.load_state_dict(pretrained_dict, strict=False) -- Gitee