diff --git a/PyTorch/built-in/cv/classification/MobileNetV3-Large_ID1784_for_PyTorch/main.py b/PyTorch/built-in/cv/classification/MobileNetV3-Large_ID1784_for_PyTorch/main.py index f1703a2140000f4e959449fbe7a3fce4f157f368..8904e76d4baf97e68753585bfd2c58fe78bbe53f 100644 --- a/PyTorch/built-in/cv/classification/MobileNetV3-Large_ID1784_for_PyTorch/main.py +++ b/PyTorch/built-in/cv/classification/MobileNetV3-Large_ID1784_for_PyTorch/main.py @@ -217,23 +217,14 @@ def main(): criterion = nn.CrossEntropyLoss().to(device) # vision optimizer - #optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, - #weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) - # prepare for new version, significant improvement - optimizer = apex.optimizers.NpuFusedRMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, - weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) - - if args.apex: - model, optimizer = amp.initialize(model, optimizer, - opt_level='O2', - loss_scale=args.loss_scale_value, - combine_grad=True) - - if args.distributed: - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.device_id]) + optimizer = apex.optimizers.NpuFusedRMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + else: + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) - # optionally resume from a checkpoint + # optionally resume from a checkpoint if args.resume: if os.path.isfile(args.resume): print("=> loading checkpoint '{}'".format(args.resume)) @@ -247,6 +238,15 @@ def main(): else: print("=> no checkpoint found at '{}'".format(args.resume)) + if args.apex: + model, optimizer = amp.initialize(model, optimizer, + opt_level='O2', + loss_scale=args.loss_scale_value, + combine_grad=True) + + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.device_id]) + cudnn.benchmark = True # Data loading code