diff --git a/PyTorch/contrib/cv/classification/RepVGG/train.py b/PyTorch/contrib/cv/classification/RepVGG/train.py index 16966a1e2003f921c98f0526854e1b9dc23eabf3..20a7903807032ed0f636c4a7e07fcd5231b21908 100644 --- a/PyTorch/contrib/cv/classification/RepVGG/train.py +++ b/PyTorch/contrib/cv/classification/RepVGG/train.py @@ -37,6 +37,8 @@ import time import warnings import torch +if torch.__version__ >= "1.8.1": + import torch_npu import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn