diff --git a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py index 3176314e251c76349e8c07ad7663a48af06da306..63cff615783b3760927f227921968c5cebd2e6f1 100644 --- a/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py +++ b/PyTorch/built-in/cv/classification/ResNet50_for_PyTorch/DistributedResnet50/main_apex_d76_npu.py @@ -416,7 +416,7 @@ def main_worker(gpu, ngpus_per_node, args): args.rank = args.rank * ngpus_per_node + gpu if args.device == 'npu': - args.rank = int(os.environ["NODE_RANK"]) * 8 + args.rank + args.rank = int(os.getenv("NODE_RANK", 0)) * 8 + args.rank print("the global_rank is :", args.rank) dist.init_process_group(backend=args.dist_backend, world_size=args.world_size, rank=args.rank)