From 9cd9e3cd5ac431517ed72f020f2da04e5f23ea37 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Mon, 25 Mar 2024 15:34:40 +0800 Subject: [PATCH] fix async issue --- .../ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py | 7 +++++-- .../python/ptdbg_ascend/hook_module/wrap_distributed.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 1e82f62fde..f3342dc6b4 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -114,8 +114,11 @@ class Const: MAX_SEED_VALUE = 2**32 - 1 - INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", - "_reduce_scatter_base", "_all_gather_base"] + INPLACE_LIST = [ + "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", + "reduce_scatter", "_reduce_scatter_base", "_all_gather_base", "send", + "recv", "isend", "irecv" + ] class CompareConst: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py index 48e92faa1b..87d16636fa 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py @@ -56,7 +56,12 @@ class DistributedOPTemplate(HOOKModule): @torch_device_guard def forward(self, *args, **kwargs): - return distributed_func.get(self.op_name_)(*args, **kwargs) + if ("async_op" in kwargs and kwargs["async_op"]) or self.op_name_ in ["isend", "irecv"]: + handle = distributed_func.get(self.op_name_)(*args, **kwargs) + handle.wait() + return handle + else: + return distributed_func.get(self.op_name_)(*args, **kwargs) def wrap_distributed_op(op_name, hook): -- Gitee