diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 1e82f62fde16206f64229df1ada275f3ca630771..f3342dc6b44e3275d79e1ea4a6fa374e5aefa0f0 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -114,8 +114,11 @@ class Const: MAX_SEED_VALUE = 2**32 - 1 - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base"] + INPLACE_LIST = [ + "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", + "reduce_scatter", "_reduce_scatter_base", "_all_gather_base", "send", + "recv", "isend", "irecv" + ] class CompareConst: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py index 48e92faa1b9294c7905be86e391fa54a1f7b153f..87d16636fa9b9fe6beac6cf22f2bb827213080ad 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py @@ -56,7 +56,12 @@ class DistributedOPTemplate(HOOKModule): @torch_device_guard def forward(self, *args, **kwargs): - return distributed_func.get(self.op_name_)(*args, **kwargs) + if ("async_op" in kwargs and kwargs["async_op"]) or self.op_name_ in ["isend", "irecv"]: + handle = distributed_func.get(self.op_name_)(*args, **kwargs) + handle.wait() + return handle + else: + return distributed_func.get(self.op_name_)(*args, **kwargs) def wrap_distributed_op(op_name, hook):