From 8076f197a6af5b1c4e2ead2ee6ad161dc97fda07 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=82=B5=E9=9D=9E=E5=87=A1?=
Date: Mon, 25 Aug 2025 17:14:59 +0800
Subject: [PATCH] fix _get_sequence_number_for_group

---
 torch_npu/__init__.py                     |  2 ++
 torch_npu/distributed/distributed_c10d.py | 12 +++++++++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index 7d68a6a6e04..97ebf2066bf 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -195,6 +195,8 @@ def _apply_distributed_methods_patch():
     torch.distributed.reinit_process_group = torch_npu.distributed.reinit_process_group
     torch.distributed.distributed_c10d.rendezvous = torch_npu.distributed.distributed_c10d._trigger_rendezvous_decorator(torch.distributed.distributed_c10d.rendezvous)
     torch.distributed.launcher.api._get_addr_and_port = torch_npu.distributed.distributed_c10d._trigger__get_addr_and_port_decorator(torch.distributed.launcher.api._get_addr_and_port)
+    torch._C._distributed_c10d.ProcessGroup._get_sequence_number_for_group = (
+        torch_npu.distributed.distributed_c10d._hccl_get_sequence_number_for_group)
     torch.serialization.add_safe_globals([torch_npu.npu._format.Format])
 
 
diff --git a/torch_npu/distributed/distributed_c10d.py b/torch_npu/distributed/distributed_c10d.py
index 53bbc0ba74a..4e609e8c0bc 100644
--- a/torch_npu/distributed/distributed_c10d.py
+++ b/torch_npu/distributed/distributed_c10d.py
@@ -16,11 +16,13 @@ from torch.distributed.distributed_c10d import _get_default_group, get_group_ran
     _get_object_coll_device, _object_to_tensor, get_world_size, _tensor_to_object, all_gather, Backend, \
     get_backend, GatherOptions, _update_default_pg, _world, _unregister_all_process_groups, _pg_map, \
     ProcessGroup, default_pg_timeout, ReduceScatterOptions, _unregister_process_group
+from torch._C._distributed_c10d import ProcessGroup
 
 from torch_npu.utils._error_code import ErrCode, dist_error
 
 logger = logging.getLogger("torch.distributed")
 
+origin_get_sequence_number_for_group = ProcessGroup._get_sequence_number_for_group
 
 
 def _batch_isend_irecv(p2p_op_list):
@@ -340,4 +342,12 @@ def _destructor_process_group():
     _world.tags_to_pg.clear()
     _world.pg_coalesce_state.clear()
     _unregister_all_process_groups()
-    _world.group_count = 0
\ No newline at end of file
+    _world.group_count = 0
+
+
+def _hccl_get_sequence_number_for_group(self):
+    backend = torch.distributed.get_backend_config(self)
+    if backend == "hccl" or backend == "npu:hccl":
+        return self._get_backend(torch.device("npu"))._get_sequence_number_for_group()
+    else:
+        return origin_get_sequence_number_for_group(self)
-- 
Gitee
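
Note (not part of the patch): a minimal sketch of what the override changes in practice, assuming an already-configured HCCL environment (MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE set, one NPU per process) and that importing torch_npu runs _apply_distributed_methods_patch(). The variable names below are illustrative, not from the patch.

    # Sketch: after `import torch_npu`, ProcessGroup._get_sequence_number_for_group
    # dispatches on the group's backend config string.
    import torch
    import torch.distributed as dist
    import torch_npu  # applies the distributed method patches at import time

    dist.init_process_group(backend="hccl")  # rank/world size read from the env

    pg = dist.group.WORLD
    # With an "hccl" / "npu:hccl" backend config, the sequence number is now
    # fetched from the HCCL backend object via _get_backend(torch.device("npu"));
    # any other backend falls through to the original binding.
    seq = pg._get_sequence_number_for_group()
    print(seq)

    dist.destroy_process_group()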