diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" index e504fa9f28690c0519334a088035f61c7167602b..cd3b369a37557ba6a160e89886c188826f3433ca 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/README.md" @@ -30,6 +30,12 @@ bash ./test/train_full_8p.sh --data_path=real_data_path # training 8p performance bash ./test/train_performance_8p.sh --data_path=real_data_path + +# test 8p accuracy +bash test/train_eval_8p.sh --data_path=real_data_path --weight_path=real_pre_train_model_path + +# finetuning 1p +bash test/train_finetune_1p.sh --data_path=real_data_path --weight_path=real_pre_train_model_path ``` Log path: diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" index a99b07efc19ac24399e8c4b512e67653d42a4315..2cd1f50ed96be48de00bcacf85e36bb81bb7573a 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/main.py" @@ -65,6 +65,8 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') +parser.add_argument('--weight_path', default='', help='pretrained weight path') +parser.add_argument('--num_classes', default=None,type=int, help='num_classes') parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training') 
parser.add_argument('--rank', default=-1, type=int, @@ -189,9 +191,16 @@ def main_worker(gpu, ngpus_per_node, args): if args.pretrained: print("=> using pre-trained model wide_resnet50_2") model = resnet_0_6_0.wide_resnet50_2(num_classes=args.num_classes) - print("loading model of yours...") - pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] + + if os.path.exists(args.weight_path): + print("loading model of your give") + pretrained_dict = torch.load(args.weight_path, map_location="cpu")["state_dict"] + else: + print("loading model default ") + pretrained_dict = torch.load("./model_best.pth.tar", map_location="cpu")["state_dict"] + model.load_state_dict({k.replace('module.',''):v for k, v in pretrained_dict.items()}) + if "fc.weight" in pretrained_dict: pretrained_dict.pop('fc.weight') pretrained_dict.pop('fc.bias') @@ -359,7 +368,10 @@ def main_worker(gpu, ngpus_per_node, args): num_workers=args.workers, pin_memory=False, drop_last=True) if args.evaluate: + print("===eval start===") validate(val_loader, model, criterion, args, ngpus_per_node) + print("===eval over===") + return if args.prof: @@ -389,6 +401,7 @@ def main_worker(gpu, ngpus_per_node, args): and args.rank % ngpus_per_node == 0): ############## npu modify begin ############# + print("=========save model======") if args.amp: save_checkpoint({ 'epoch': epoch + 1, diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" index 7f09f2fbfe86ed7dd5925d55228290cfb1d9dc1f..0afbaaf4580816fe0b2911daa1a569838337641d 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_eval_8p.sh" @@ -9,8 +9,8 @@ batch_size=4096 # 
训练使用的npu卡数 export RANK_SIZE=8 # checkpoint文件路径,以实际路径为准 -resume=/home/checkpoint.pth.tar -# 数据集路径,保持为空,不需要修改 +# resume=/home/checkpoint.pth.tar +# 数据集路径,需要根据实际情况填入 data_path="" # 训练epoch @@ -28,6 +28,8 @@ do workers=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --weight_path* ]];then + weight_path=`echo ${para#*=}` fi done @@ -73,7 +75,6 @@ fi python3.7 ./main.py \ ${data_path} \ --evaluate \ - --resume ${resume} \ --addr=$(hostname -I |awk '{print $1}') \ --seed=49 \ --workers=${workers} \ @@ -91,6 +92,8 @@ python3.7 ./main.py \ --warm_up_epochs=5 \ --loss-scale=32 \ --amp \ + --pretrained \ + --weight_path=${weight_path} \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait diff --git "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" index 30dac86bf9162a5421cb525f0d8363fec42009f0..f87422ddcb6c197f55ce9a20da315d8cfb09eae6 100644 --- "a/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" +++ "b/Pytorch\350\256\255\347\273\203\347\244\272\344\276\213/WideResNet50_2_ID1627_for_PyTorch/test/train_finetune_1p.sh" @@ -25,6 +25,8 @@ do device_id=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --weight_path* ]];then + weight_path=`echo ${para#*=}` fi done @@ -94,8 +96,8 @@ python3.7 ./main.py \ --warm_up_epochs=5 \ --loss-scale=32 \ --amp \ - --pretrained \ --num_classes=1200 \ + --resume=${weight_path} \ --batch-size=${batch_size} > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & wait