diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index b619d0cf493ff63529f4cb1d70c53e8abe7a45a8..8595bc2389478a4b94638570277b84865a9b2606 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -4,6 +4,7 @@ import shutil import sys from pathlib import Path import torch +import torch.distributed as dist from ..dump import dump from ..common.utils import print_error_log, CompareException, DumpException, Const, get_time, print_info_log, \ @@ -156,6 +157,9 @@ def set_dump_path(fpath=None, dump_tag='ptdbg_dump'): def get_tensor_rank(in_feat, out_feat): + if dist.is_initialized(): + return dist.get_rank() + def get_tensor_rank_single(x): if isinstance(x, (list, tuple)): if len(x) > 0: