diff --git a/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-1p.py b/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-1p.py
index 9c4316c2b8971bc4218a298a5a981befa27d7fbe..558434aaf2103704eb144c790a3b4e35423a6cd0 100755
--- a/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-1p.py
+++ b/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-1p.py
@@ -51,6 +51,8 @@ from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
 #from ghostnet.ghostnet_pytorch.ghostnet import ghostnet
 
+if torch.__version__ >= '1.8.1':
+    import torch_npu
 import torch.npu
 import apex
 
@@ -273,7 +275,7 @@ parser.add_argument('--distributed',action='store_true',
                     'fastest way to use PyTorch for either single node or '
                     'multi node data parallel training')
 parser.add_argument("--world_size", default=1, type=int)
-parser.add_argument("--loss-scale", default=1024, type=int)
+parser.add_argument("--loss-scale", default='dynamic', type=str)
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
diff --git a/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-8p.py b/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-8p.py
index 608f3795f4ca5aefb69b44c28b810a7e56d5ded5..591bc1ee488c6527615cb051d385dadeb39b7995 100755
--- a/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-8p.py
+++ b/PyTorch/contrib/cv/classification/csp_resnext50-mish/train-8p.py
@@ -51,6 +51,8 @@ from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
 #from ghostnet.ghostnet_pytorch.ghostnet import ghostnet
 
+if torch.__version__ >= '1.8.1':
+    import torch_npu
 import torch.npu
 import apex
 
@@ -273,7 +275,7 @@ parser.add_argument('--distributed',action='store_true',
                     'fastest way to use PyTorch for either single node or '
                     'multi node data parallel training')
 parser.add_argument("--world_size", default=1, type=int)
-parser.add_argument("--loss-scale", default=1024, type=int)
+parser.add_argument("--loss-scale", default='dynamic', type=str)
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',