diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py
index 10e15a15224e216ee116060b6838936b4a8d2712..60a12df4249de3519331282573b47cc726372cfe 100644
--- a/torch_npu/__init__.py
+++ b/torch_npu/__init__.py
@@ -190,6 +190,8 @@ def _apply_distributed_methods_patch():
     torch.distributed.reinit_process_group = torch_npu.distributed.reinit_process_group
     torch.distributed.distributed_c10d.rendezvous = torch_npu.distributed.distributed_c10d._trigger_rendezvous_decorator(torch.distributed.distributed_c10d.rendezvous)
     torch.distributed.launcher.api._get_addr_and_port = torch_npu.distributed.distributed_c10d._trigger__get_addr_and_port_decorator(torch.distributed.launcher.api._get_addr_and_port)
+    torch._C._distributed_c10d.ProcessGroup._get_sequence_number_for_group = (
+        torch_npu.distributed.distributed_c10d._hccl_get_sequence_number_for_group)


 torch.utils.rename_privateuse1_backend("npu")
diff --git a/torch_npu/distributed/distributed_c10d.py b/torch_npu/distributed/distributed_c10d.py
index 690b82dfaa9d62e530d5fc8ab48169ba9dddc612..09b0ee8931300a43f1f569c1b66dead6e1e47ac0 100644
--- a/torch_npu/distributed/distributed_c10d.py
+++ b/torch_npu/distributed/distributed_c10d.py
@@ -16,11 +16,13 @@ from torch.distributed.distributed_c10d import _get_default_group, get_group_ran
     _get_pg_default_device, _object_to_tensor, get_world_size, _tensor_to_object, all_gather, Backend, \
     get_backend, GatherOptions, _update_default_pg, _world, _unregister_all_process_groups, _pg_map, \
     ProcessGroup, default_pg_timeout, ReduceScatterOptions, _unregister_process_group
+from torch._C._distributed_c10d import ProcessGroup

 from torch_npu.utils._error_code import ErrCode, dist_error

 logger = logging.getLogger("torch.distributed")

+origin_get_sequence_number_for_group = ProcessGroup._get_sequence_number_for_group


 def _batch_isend_irecv(p2p_op_list):
@@ -343,4 +345,12 @@ def _destructor_process_group():
     _world.pg_coalesce_state.clear()
     _world.pg_default_device.clear()
     _unregister_all_process_groups()
-    _world.group_count = 0
\ No newline at end of file
+    _world.group_count = 0
+
+
+def _hccl_get_sequence_number_for_group(self):
+    backend = torch.distributed.get_backend_config(self)
+    if backend == "hccl" or backend == "npu:hccl":
+        return self._get_backend(torch.device("npu"))._get_sequence_number_for_group()
+    else:
+        return origin_get_sequence_number_for_group(self)
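
For reference, the patch above replaces ProcessGroup._get_sequence_number_for_group so that groups whose backend config is "hccl" or "npu:hccl" (per torch.distributed.get_backend_config) read the sequence number from the NPU backend object, while all other backends fall through to the original binding saved in origin_get_sequence_number_for_group. Below is a minimal sketch of exercising the patched path, assuming an Ascend NPU with the HCCL backend is available and that importing torch_npu installs the patch (via _apply_distributed_methods_patch); the single-rank setup and MASTER_ADDR/MASTER_PORT values are placeholders for illustration, not part of the patch.

import os
import torch.distributed as dist
import torch_npu  # assumption: importing torch_npu applies _apply_distributed_methods_patch()

# Placeholder rendezvous settings for a single local rank.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

dist.init_process_group(backend="hccl", rank=0, world_size=1)
pg = dist.distributed_c10d._get_default_group()

# With the patch applied, this call is routed to the HCCL backend's own counter
# via pg._get_backend(torch.device("npu")); a non-HCCL group would instead hit
# origin_get_sequence_number_for_group.
seq = pg._get_sequence_number_for_group()
print("HCCL sequence number:", seq)

dist.destroy_process_group()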