From c8eeb38b0eb9a20347a23df9ed41bd6e03fd9323 Mon Sep 17 00:00:00 2001
From: xumingqian
Date: Tue, 26 Mar 2024 15:45:32 +0800
Subject: [PATCH] evict

---
 .../npu_bridge/embedding/embedding_service.py | 84 ++++++++++++++++---
 1 file changed, 70 insertions(+), 14 deletions(-)

diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
index 767bf6167..2aebe8cd7 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
@@ -71,6 +71,13 @@ class CounterFilter:
         self.default_value = default_value
         self.default_key_or_value = default_key_or_value
 
+
+class EvictOption:
+    """Evict option for embedding table."""
+
+    def __init__(self, steps_to_live):
+        self.steps_to_live = steps_to_live
+
 
 class EsInitializer:
     """Initializer for embedding service table."""
@@ -222,6 +229,9 @@ class ESWorker:
         self._default_key = None
         self._default_value = None
 
+        # use for evict option
+        self._steps_to_live = 0
+
     # 提供 embedding_service table initializer method
     # table_id embedding 表索引, int 类型
     # min 下限值, float 类型
@@ -266,11 +276,14 @@ class ESWorker:
     # 暂时只使用特征准入option
     def embedding_variable_option(self, filter_option=None, evict_option=None, storage_option=None,
                                   feature_freezing_option=None, communication_option=None):
-        if filter_option is None:
-            raise ValueError("Now filter_option can't be None.")
-        if not isinstance(filter_option, CounterFilter):
-            raise TypeError("If filter_option isn't None, it must be CounterFilter type.")
-        self._use_counter_filter = True
+        if (filter_option is not None) and (not isinstance(filter_option, CounterFilter)):
+            raise TypeError("When filter_option is not None, it must be CounterFilter type, please check.")
+        if (evict_option is not None) and (not isinstance(evict_option, EvictOption)):
+            raise TypeError("When evict_option is not None, it must be EvictOption type, please check.")
+        if (filter_option is None) and (evict_option is None):
+            raise ValueError("Now filter_option and evict_option can't be None simultaneously.")
+        if filter_option is not None:
+            self._use_counter_filter = True
         return EmbeddingVariableOption(filter_option=filter_option, evict_option=evict_option,
                                        storage_option=storage_option, feature_freezing_option=feature_freezing_option,
                                        communication_option=communication_option)
@@ -628,6 +641,15 @@ class ESWorker:
         return CounterFilter(filter_freq=filter_freq, default_key_or_value=True,
                              default_key=default_key, default_value=default_value)
 
+    def evict_option(self, steps_to_live):
+        """Create an EvictOption after checking steps_to_live is a positive int."""
+        # bool is a subclass of int, so reject it explicitly.
+        if isinstance(steps_to_live, bool) or not isinstance(steps_to_live, int):
+            raise TypeError("steps_to_live must be int, please check.")
+        if steps_to_live <= 0:
+            raise ValueError("steps_to_live must be greater than 0.")
+        return EvictOption(steps_to_live=steps_to_live)
+
     def data_parallel_embedding(self, max_vocabulary_size, embedding_dim, multihot_lens, allow_merge=True,
                                 initializer=tf.random_uniform_initializer(minval=-0.01, maxval=0.01, seed=1234)):
         if (max_vocabulary_size is None) or (embedding_dim is None) or (multihot_lens is None):
@@ -793,7 +815,8 @@ class ESWorker:
                                                        export_mode="all",
                                                        only_var_flag=True,
                                                        file_type="bin",
-                                                       table_name=[name])
+                                                       table_name=[name],
+                                                       step_threshold=self._steps_to_live)
             return tf.group([embedding_table_export])
 
     def save_embeddings(self, path: str):
@@ -817,7 +840,8 @@ class ESWorker:
                                                        export_mode="all",
                                                        only_var_flag=True,
                                                        file_type="bin",
-                                                       table_name=self._ps_table_name_list)
+                                                       table_name=self._ps_table_name_list,
+                                                       step_threshold=self._steps_to_live)
             return tf.group([embedding_table_export])
 
     def restore_embedding(self, name: str, path: str):
@@ -880,7 +904,8 @@ class ESWorker:
                                                        only_var_flag=False,
                                                        file_type="bin",
                                                        table_name=[name],
-                                                       filter_export_flag=save_filtered_features)
+                                                       filter_export_flag=save_filtered_features,
+                                                       step_threshold=self._steps_to_live)
             with tf.control_dependencies([embedding_table_export]):
                 embedding_compute_var_export = \
                     gen_npu_cpu_ops.embedding_compute_var_export(file_path=file_path_tensor,
@@ -922,7 +947,8 @@ class ESWorker:
                                                        only_var_flag=False,
                                                        file_type="bin",
                                                        table_name=self._ps_table_name_list,
-                                                       filter_export_flag=save_filtered_features)
+                                                       filter_export_flag=save_filtered_features,
+                                                       step_threshold=self._steps_to_live)
             with tf.control_dependencies([embedding_table_export]):
                 embedding_compute_var_export = \
                     gen_npu_cpu_ops.embedding_compute_var_export(file_path=file_path_tensor,
@@ -1020,7 +1046,8 @@ class ESWorker:
                                                        export_mode="new",
                                                        only_var_flag=True,
                                                        file_type="bin",
-                                                       table_name=[name])
+                                                       table_name=[name],
+                                                       step_threshold=self._steps_to_live)
             return tf.group([embedding_table_export])
 
     def save_incremental_embeddings(self, path: str):
@@ -1044,7 +1071,8 @@ class ESWorker:
                                                        export_mode="new",
                                                        only_var_flag=True,
                                                        file_type="bin",
-                                                       table_name=self._ps_table_name_list)
+                                                       table_name=self._ps_table_name_list,
+                                                       step_threshold=self._steps_to_live)
             return tf.group([embedding_table_export])
 
     def restore_incremental_embedding(self, name: str, path: str):
@@ -1109,7 +1137,8 @@ class ESWorker:
                                                        export_mode="all",
                                                        only_var_flag=True,
                                                        file_type="bin",
-                                                       table_name=[self._table_id_to_name.get(table_id)])
+                                                       table_name=[self._table_id_to_name.get(table_id)],
+                                                       step_threshold=self._steps_to_live)
             return tf.group([embedding_table_export])
 
     def restore_embedding_v1(self, path: str, table_id: int):
@@ -1162,7 +1191,8 @@ class ESWorker:
                                                        export_mode="all",
                                                        only_var_flag=False,
                                                        file_type="bin",
-                                                       table_name=[self._table_id_to_name.get(table_id)])
+                                                       table_name=[self._table_id_to_name.get(table_id)],
+                                                       step_threshold=self._steps_to_live)
             with tf.control_dependencies([embedding_table_export]):
                 embedding_compute_var_export = \
                     gen_npu_cpu_ops.embedding_compute_var_export(file_path=file_path_tensor,
@@ -1231,7 +1261,8 @@ class ESWorker:
                                                        export_mode="new",
                                                        only_var_flag=True,
                                                        file_type="bin",
-                                                       table_name=[self._table_id_to_name.get(table_id)])
+                                                       table_name=[self._table_id_to_name.get(table_id)],
+                                                       step_threshold=self._steps_to_live)
             return tf.group([embedding_table_export])
 
     def restore_incremental_embedding_v1(self, path: str, table_id: int):
@@ -1255,6 +1286,31 @@ class ESWorker:
                                                        table_name=[self._table_id_to_name.get(table_id)])
             return tf.group([embedding_table_import])
 
+    def embedding_evict(self, steps_to_live: int):
+        """Build the op that evicts values in all embedding tables.
+
+        steps_to_live is stored on the worker and passed as step_threshold to
+        every table's evict op; later export ops reuse the stored value.
+
+        Raises TypeError if steps_to_live is not an int and ValueError if it
+        is not greater than zero.
+        """
+        if isinstance(steps_to_live, bool) or not isinstance(steps_to_live, int):
+            raise TypeError("steps_to_live must be int, please check.")
+        if steps_to_live <= 0:
+            raise ValueError("steps_to_live must be greater than zero.")
+        self._steps_to_live = steps_to_live
+        # Collect one evict op per table; grouping only the loop variable would
+        # keep just the last table and fail when there are no tables.
+        evict_ops = []
+        with specified_ps_engine_scope():
+            for table_id in self._ps_table_id_list:
+                var = NpuEmbeddingResource(table_id)
+                evict_ops.append(
+                    gen_npu_cpu_ops.embedding_table_evict(var_handle=var.handle,
+                                                          step_threshold=self._steps_to_live))
+            return tf.group(evict_ops)
+
     def _update_config_params(self):
         env_dist = os.environ
         rank_id = env_dist.get("RANK_ID")
-- 
Gitee