From 42ce2490ce8ec69f216897955a6947879fb69ccb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=BC=BA?= Date: Tue, 30 Jul 2024 10:35:13 +0800 Subject: [PATCH] es evict new --- tf_adapter/python/npu_bridge/embedding/embedding_service.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index f16998687..2467366b9 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -827,6 +827,8 @@ class ESWorker: raise ValueError("steps_to_live must be greater than zero.") self._steps_to_live = steps_to_live table_id_list = [] + env_dist = os.environ + rank_size = int(env_dist.get("RANK_SIZE")) with specified_ps_engine_scope(): for table_id in self._ps_table_id_list: table_id_list.append(table_id) @@ -834,7 +836,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_evict(var_handle=ops.convert_to_tensor(table_id_list), global_step=1, steps_to_live=self._steps_to_live) - return tf.group([embedding_table_evict]) + evict_out = allgather(tensor=embedding_table_evict, rank_size=rank_size, group="user_group_evict") + return tf.group([evict_out]) def _update_config_params(self): env_dist = os.environ -- Gitee