From cca5288dd9cd0998f5331bb0d3fe1a1b03ff1458 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Tue, 23 Apr 2024 17:17:57 +0800 Subject: [PATCH] fix async issue --- .../ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py | 6 ++++-- .../python/ptdbg_ascend/hook_module/wrap_distributed.py | 7 ++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index db092a8bc..a264f89c7 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -115,8 +115,10 @@ class Const: MAX_SEED_VALUE = 2**32 - 1 - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base"] + INPLACE_LIST = [ + "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", + "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend" + ] class CompareConst: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py index 48e92faa1..8d5140206 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py @@ -56,7 +56,12 @@ class DistributedOPTemplate(HOOKModule): @torch_device_guard def forward(self, *args, **kwargs): - return distributed_func.get(self.op_name_)(*args, **kwargs) + if kwargs.get("async_op") or self.op_name_ in ["isend", "irecv"]: + handle = distributed_func.get(self.op_name_)(*args, **kwargs) + handle.wait() + return handle + else: + return distributed_func.get(self.op_name_)(*args, **kwargs) def wrap_distributed_op(op_name, hook): -- Gitee