diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_helper.py b/tf_adapter/python/npu_bridge/embedding/embedding_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..e532a2ac4c06c88f5a37a9192b0e104871bbb51a --- /dev/null +++ b/tf_adapter/python/npu_bridge/embedding/embedding_helper.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +class EmbeddingVariableOption: + """ option for embedding service table. """ + + def __init__(self, filter_option=None, + evict_option=None, + storage_option=None, + feature_freezing_option=None, + communication_option=None): + self.filter_option = filter_option + self.evict_option = evict_option + self.storage_option = storage_option + self.feature_freezing_option = feature_freezing_option + self.communication_option = communication_option + + +class CounterFilter: + """ Counter filter for embedding table. """ + + def __init__(self, + use_counter_filter=1, + filter_mode="counter", + filter_freq=None, + default_key_or_value=None, + default_key=None, + default_value=None): + self.use_counter_filter = use_counter_filter + self.filter_mode = filter_mode + self.filter_freq = filter_freq + self.default_key = default_key + self.default_value = default_value + self.default_key_or_value = default_key_or_value + + +class EsInitializer: + """Initializer for embedding service table.""" + + def __init__(self, initializer_mode, min=-0.01, max=0.01, constant_value=1.0, mu=0.0, sigma=1.0, seed=0): + self.initializer_mode = initializer_mode + self.min = min + self.max = max + self.constant_value = constant_value + self.mu = mu + self.sigma = sigma + self.seed = seed diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index c3a7e936e44ce893ade8f21ae0e5f27a2457455c..0646f587ac3ce8ab233df0a1887283ca22c0088e 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -29,6 +29,7 @@ from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops from npu_bridge.npu_cpu.npu_cpu_ops import host_feature_mapping_export from npu_bridge.npu_cpu.npu_cpu_ops import host_feature_mapping_import from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource +from npu_bridge.embedding.embedding_helper import EmbeddingVariableOption, CounterFilter, EsInitializer from npu_bridge.embedding import embedding_optimizer from npu_bridge.embedding.embedding_table_map_policy import NoneTableMapPolicy, AutoMergeTableMapPolicy @@ -47,44 +48,6 @@ def specified_ps_engine_scope(): yield -class EmbeddingVariableOption: - """ option for embedding service table. 
""" - - def __init__(self, filter_option=None, - evict_option=None, - storage_option=None, - feature_freezing_option=None, - communication_option=None): - self.filter_option = filter_option - self.evict_option = evict_option - self.storage_option = storage_option - self.feature_freezing_option = feature_freezing_option - self.communication_option = communication_option - - -class CounterFilter: - """ Counter filter for embedding table. """ - - def __init__(self, filter_freq, default_key_or_value, default_key=None, default_value=None): - self.filter_freq = filter_freq - self.default_key = default_key - self.default_value = default_value - self.default_key_or_value = default_key_or_value - - -class EsInitializer: - """Initializer for embedding service table.""" - - def __init__(self, initializer_mode, min=-0.01, max=0.01, constant_value=1.0, mu=0.0, sigma=1.0, seed=0): - self.initializer_mode = initializer_mode - self.min = min - self.max = max - self.constant_value = constant_value - self.mu = mu - self.sigma = sigma - self.seed = seed - - # 提供 embedding_service table initializer method # min 下限值, float 类型 # max 上限值, float 类型 @@ -172,7 +135,6 @@ class ESWorker: self._table_to_optimizer = {} self._table_to_initializer = {} self._table_to_slot_var_num = {} - self._table_to_counter_filter = {} self._train_mode = True self._train_level = False self._optimizer = None @@ -219,12 +181,6 @@ class ESWorker: # use for counter filter self._table_use_counter_filter = {} - self._use_counter_filter = False - self._default_key_or_value = True - self._filter_freq = None - self._default_key = None - self._default_value = None - # 提供 embedding_service table initializer method # table_id embedding 表索引, int 类型 # min 下限值, float 类型 @@ -273,7 +229,6 @@ class ESWorker: raise ValueError("Now filter_option can't be None.") if not isinstance(filter_option, CounterFilter): raise TypeError("If filter_option isn't None, it must be CounterFilter type.") - self._use_counter_filter = True return EmbeddingVariableOption(filter_option=filter_option, evict_option=evict_option, storage_option=storage_option, feature_freezing_option=feature_freezing_option, communication_option=communication_option) @@ -295,129 +250,16 @@ class ESWorker: check_common_init_params(name=name, init_vocabulary_size=init_vocabulary_size, embedding_dim=embedding_dim, embedding_type=embedding_type, mask_zero=mask_zero) if embedding_type == "data_parallel": - self._check_and_update_small_init_params(name=name, init_vocabulary_size=init_vocabulary_size, - embedding_dim=embedding_dim, multihot_lens=multihot_lens, - key_dtype=key_dtype, value_dtype=value_dtype, - allow_merge=allow_merge, initializer=initializer) - new_small_table_info = dict( - name=name, - max_vocabulary_size=init_vocabulary_size, - embedding_dim=embedding_dim, - multihot_lens=multihot_lens, - allow_merge=allow_merge, - initializer=initializer) - self.user_defined_table_infos.append(new_small_table_info) + new_small_table_info = self._init_small_table(name=name, init_vocabulary_size=init_vocabulary_size, + embedding_dim=embedding_dim, multihot_lens=multihot_lens, + key_dtype=key_dtype, value_dtype=value_dtype, + allow_merge=allow_merge, initializer=initializer) return new_small_table_info elif embedding_type == "PS": - table_id = self._check_and_update_ps_init_params(name=name, init_vocabulary_size=init_vocabulary_size, - embedding_dim=embedding_dim, - max_feature_count=max_feature_count, ev_option=ev_option) - self._ps_lookup_index = self._ps_table_count - 
self._table_to_embedding_dim[table_id] = embedding_dim - self._table_to_max_num[table_id] = max_feature_count - # storage the table id for embedding PS table - self._ps_table_id_list.append(table_id) - self._ps_table_name_list.append(name) - if len(self._ps_table_id_list) > 10: - raise ValueError("Now only 10 PS embedding tables can be init.") - bucket_size = math.ceil(init_vocabulary_size / self._ps_num) - if optimizer is None: - self._train_mode = False - self._table_to_slot_var_num[table_id] = 0 - else: - self._check_ps_opt_and_initializer(optimizer=optimizer, initializer=initializer, table_id=table_id) - self._optimizer = optimizer - self._optimizer._embedding_dims = embedding_dim - self._optimizer._max_nums = max_feature_count - self._optimizer.mask_zero = mask_zero - self._table_to_optimizer[table_id] = self._optimizer - self._ps_table_id_to_optimizer_params[table_id] = [] - self._update_optimizer_slot_var_num(table_id=table_id) - # new train or continue train from a checkpoint - if initializer is not None: - self._train_level = True - with specified_ps_engine_scope(): - self._init_partition_maps[table_id] = \ - gen_npu_cpu_ops.init_partition_map(ps_num=ops.convert_to_tensor(self._ps_num), - ps_ids=ops.convert_to_tensor(self._ps_ids), - partition_num=65537) - self._init_partition_maps.get(table_id)._set_attr("_embedding_dim", - attr_value_pb2.AttrValue(i=embedding_dim)) - self._init_partition_maps.get(table_id)._set_attr("_max_key_num", - attr_value_pb2.AttrValue(i=max_feature_count)) - return self._init_hashmap_and_table_import(bucket_size, table_id, embedding_dim, ev_option) - - # old version - def embedding_init(self, vocabulary_size, table_id, max_batch_size, embedding_dim, optimizer=None, - initializer=None, ev_option=None): - """ Operator for init embedding table. 
""" - if vocabulary_size is None or table_id is None or max_batch_size is None or embedding_dim is None: - raise ValueError("vocabulary_size or table_id or max_batch_size or embedding_dim is None.") - if (ev_option is not None) and (not isinstance(ev_option, EmbeddingVariableOption)): - raise TypeError("ev_option must be EmbeddingVariableOption type.") - if (not isinstance(vocabulary_size, int)) or (not isinstance(table_id, int)) or \ - (not isinstance(max_batch_size, int)) or (not isinstance(embedding_dim, int)): - raise ValueError("vocabulary_size, table_id, max_batch_size and embedding_dim must be int.") - if vocabulary_size < 0 or table_id < 0: - raise ValueError("vocabulary_size and table_id can not be smaller than zero.") - if vocabulary_size >= _INT32_MAX_VALUE or table_id >= _INT32_MAX_VALUE: - raise ValueError("vocabulary_size or table_id exceed int32 max value.") - if embedding_dim <= 0 or max_batch_size <= 0: - raise ValueError("embedding_dim and max_batch_size must be greater than zero.") - if table_id in self._ps_table_id_list: - raise ValueError("this table has already initialized.") - - self._table_to_embedding_dim[table_id] = embedding_dim - self._table_to_max_num[table_id] = max_batch_size - self._table_id_to_name[table_id] = str(table_id) - self._ps_table_id_list.append(table_id) - self._ps_table_name_list.append(str(table_id)) - if len(self._ps_table_id_list) > 10: - raise ValueError("Now only 10 embedding tables can be init.") - bucket_size = math.ceil(vocabulary_size / self._ps_num) - if (self._table_id_to_initializer.get(table_id) is None) and (initializer is not None): - self._table_id_to_initializer[table_id] = EsInitializer(min=-2, - max=2, - initializer_mode=initializer, - constant_value=0, - mu=0.0, - sigma=1.0) - if optimizer is None: - self._train_mode = False - self._table_to_slot_var_num[table_id] = 0 - else: - if (not isinstance(optimizer, embedding_optimizer.AdamOptimizer)) and \ - (not isinstance(optimizer, embedding_optimizer.AdagradOptimizer)) and \ - (not isinstance(optimizer, embedding_optimizer.AdamWOptimizer)): - raise ValueError( - "optimizer should be embedding_optimizer AdamOptimizer, AdagradOptimizer or AdamWOptimizer.") - if (initializer is not None) and (initializer != 'random_uniform') and \ - (initializer != 'truncated_normal') and (initializer != 'constant'): - raise ValueError("initializer must be random_uniform or truncated_normal or constant.") - self._optimizer = optimizer - self._optimizer._embedding_dims = embedding_dim - self._optimizer._max_nums = max_batch_size - self._optimizer._es_cluster_configs = self._es_cluster_conf - self._table_to_optimizer[table_id] = self._optimizer - self._ps_table_id_to_optimizer_params[table_id] = [] - # adam include m and v, 2 slots; adagrad include accumulator, 1 slot - if isinstance(self._optimizer, embedding_optimizer.AdagradOptimizer): - self._table_to_slot_var_num[table_id] = 1 - else: - self._table_to_slot_var_num[table_id] = 2 - if (initializer is not None) or (self._table_to_initializer.get(table_id) is not None): - self._train_level = True - - with specified_ps_engine_scope(): - self._init_partition_maps[table_id] = \ - gen_npu_cpu_ops.init_partition_map(ps_num=ops.convert_to_tensor(self._ps_num), - ps_ids=ops.convert_to_tensor(self._ps_ids), - partition_num=65537) - self._init_partition_maps.get(table_id)._set_attr("_embedding_dim", - attr_value_pb2.AttrValue(i=embedding_dim)) - self._init_partition_maps.get(table_id)._set_attr("_max_key_num", - attr_value_pb2.AttrValue(i=max_batch_size)) - 
return self._init_hashmap_and_table_import(bucket_size, table_id, embedding_dim, ev_option) + return self._init_big_table(name=name, init_vocabulary_size=init_vocabulary_size, + embedding_dim=embedding_dim, initializer=initializer, + optimizer=optimizer, mask_zero=mask_zero, + max_feature_count=max_feature_count, ev_option=ev_option) # new version # 提供embedding lookup功能 @@ -427,15 +269,6 @@ class ESWorker: def embedding_lookup(self, name: str, ids: typing.Any, actual_keys_input=None, unique_indices=None, key_count=None): """ Operator for look up in embedding table. """ table_id = self._check_ps_lookup_params(name=name, ids=ids) - if self._table_to_counter_filter.get(table_id) is not None: - filter_mode = "counter" - self._filter_freq = self._table_to_counter_filter.get(table_id).filter_freq - self._default_key_or_value = self._table_to_counter_filter.get(table_id).default_key_or_value - self._default_key = self._table_to_counter_filter.get(table_id).default_key - self._default_value = self._table_to_counter_filter.get(table_id).default_value - else: - filter_mode = "no_filter" - self._default_value = -1 # whether to use host unique to improve performance self.use_host_unique = False use_counter_filter = False @@ -445,88 +278,9 @@ class ESWorker: use_counter_filter = True result = self._call_lookup_op(table_id=table_id, ids=ids, actual_keys_input=actual_keys_input, - unique_indices=unique_indices, filter_mode=filter_mode, - use_counter_filter=use_counter_filter, key_count=key_count) - - self._filter_freq = None - self._default_key_or_value = True - self._default_key = None - self._default_value = None - if (self._ps_lookup_index != 0) or (self._existing_lookup_table_ids.count(table_id) != 0): - self._ps_table_has_lookup.append(table_id) - self._ps_table_lookup_key.append(ids) - self._ps_table_lookup_result.append(result) - self._ps_lookup_index = self._ps_lookup_index - 1 - if self.use_host_unique: - self.key_recovery_matrix.append(unique_indices) - # restore table id that has called lookup, if this table call lookup again, key and values must be stored. - self._existing_lookup_table_ids.append(table_id) - return result - - # old version - # 提供embedding lookup功能 - # @param table_id int32 类型 - # @param input_ids int64 类型 - # @return values float32 类型 - def embedding_lookup_v1(self, table_id: int, input_ids: typing.Any): - """ Operator for look up in embedding table. """ - if (table_id is None) or (input_ids is None): - raise ValueError("table_id or input_ids must be specified.") - if not isinstance(table_id, int): - raise ValueError("type of table_id must be int.") - if input_ids.dtype != tf.int64: - raise ValueError("dtype of input_ids must be tf.int64.") - if table_id < 0: - raise ValueError("table_id can not be smaller than zero.") - if not self._init_table_flag: - raise ValueError("embedding must init first!") - if table_id not in self._ps_table_id_list: - raise ValueError("this table has not yet initialized.") - if self._train_mode: - seed1, seed2 = random_seed.get_seed(None) - if self._table_to_counter_filter.get(table_id) is not None: - filter_mode = "counter" - self._filter_freq = self._table_to_counter_filter.get(table_id).filter_freq - self._default_key_or_value = self._table_to_counter_filter.get(table_id).default_key_or_value - self._default_key = self._table_to_counter_filter.get(table_id).default_key - self._default_value = self._table_to_counter_filter.get(table_id).default_value - else: - filter_mode = "no_filter" - result = gen_npu_cpu_ops. 
\ - embedding_table_find_and_init(table_id=ops.convert_to_tensor(table_id), - keys=input_ids, - embedding_dim=self._table_to_embedding_dim.get(table_id), - initializer_mode=self._table_id_to_initializer.get(table_id) - .initializer_mode, - constant_value=self._table_id_to_initializer.get(table_id).constant_value, - min=self._table_id_to_initializer.get(table_id).min, - max=self._table_id_to_initializer.get(table_id).max, - mu=self._table_id_to_initializer.get(table_id).mu, - sigma=self._table_id_to_initializer.get(table_id).sigma, - seed=seed1, - seed2=seed2, - value_total_len=self._table_to_embedding_dim.get(table_id) * - (self._table_to_slot_var_num.get(table_id) + 1), - filter_mode=filter_mode, - filter_freq=self._filter_freq, - default_key_or_value=self._default_key_or_value, - default_key=self._default_key, - default_value=self._default_value, - optimizer_mode=self._ps_table_id_to_optimizer_mode.get(table_id), - optimizer_params=self._ps_table_id_to_optimizer_params.get(table_id) - ) - self._filter_freq = None - self._default_key_or_value = True - self._default_key = None - self._default_value = None - else: - result = gen_npu_cpu_ops.embedding_table_find(table_id=ops.convert_to_tensor(table_id), - keys=input_ids, - embedding_dim=self._table_to_embedding_dim.get(table_id)) - result.op._set_attr("_embedding_dim", attr_value_pb2.AttrValue(i=self._table_to_embedding_dim.get(table_id))) - result.op._set_attr("_max_key_num", attr_value_pb2.AttrValue(i=self._table_to_max_num.get(table_id))) - result.op._set_attr("_use_counter_filter", - attr_value_pb2.AttrValue(i=self._table_use_counter_filter.get(table_id))) + unique_indices=unique_indices, use_counter_filter=use_counter_filter, + key_count=key_count) + self._store_lookup_data_and_refresh_state(table_id, ids, result, unique_indices) return result # new version @@ -537,18 +291,7 @@ class ESWorker: params = self._ps_table_lookup_result input_ids_list = self._ps_table_lookup_key table_ids = self._ps_table_has_lookup - self._check_update_params(params, input_ids_list, table_ids, loss) - if (not isinstance(params, (list, tuple)) and not isinstance(table_ids, (list, tuple)) - and not isinstance(input_ids_list, (list, tuple))): - params = [params] - table_ids = [table_ids] - input_ids_list = [input_ids_list] - for table_id in table_ids: - if table_id not in self._ps_table_id_list: - raise ValueError("this table has not yet initialized.") - if (len(params) != len(table_ids)) or (len(params) != len(input_ids_list)) \ - or (len(table_ids) != len(input_ids_list)): - raise ValueError("The length of params, table_ids, input_ids_list should be equal.") + params, table_ids, input_ids_list = self._check_update_params(params, input_ids_list, table_ids, loss) embedding_grads = tf.gradients(loss, params) update_op = [] self._ps_table_lookup_result = [] @@ -574,56 +317,8 @@ class ESWorker: self._table_to_optimizer.get(table_ids[i]).apply_gradients(list(zip(params_grads, var_refs)))) return update_op - # old version - # 提供embedding update功能 - # @param loss 类型 - # @param params float32 类型 - # @param table_ids int32 类型 - # @param input_ids_list int64 类型 - def embedding_update_v1(self, loss, params, table_ids, input_ids_list): - """ Operator for update in embedding table. 
""" - if (loss is None) or (params is None) or (table_ids is None) or (input_ids_list is None): - raise ValueError("loss or params or table_ids or input_ids_list is None.") - if (isinstance(loss, str)) or (isinstance(params, str)) or isinstance(table_ids, str) or \ - isinstance(input_ids_list, str): - raise ValueError("loss, params, table_ids and input_ids_list can not be str.") - if not self._init_table_flag: - raise ValueError("embedding must init first!") - if (not isinstance(params, (list, tuple)) and not isinstance(table_ids, (list, tuple)) - and not isinstance(input_ids_list, (list, tuple))): - params = [params] - table_ids = [table_ids] - input_ids_list = [input_ids_list] - for table_id in table_ids: - if table_id not in self._ps_table_id_list: - raise ValueError("this table has not yet initialized.") - if (len(params) != len(table_ids)) or (len(params) != len(input_ids_list)) \ - or (len(table_ids) != len(input_ids_list)): - raise ValueError("The length of params, table_ids, input_ids_list should be equal.") - embedding_grads = tf.gradients(loss, params) - update_op = [] - with specified_ps_engine_scope(): - for i in range(len(table_ids)): - params_grads = [tf.IndexedSlices(embedding_grads[i], input_ids_list[i], dense_shape=params[i].shape)] - var_refs = [NpuEmbeddingResource(table_ids[i])] - update_op.append( - self._table_to_optimizer.get(table_ids[i]).apply_gradients(list(zip(params_grads, var_refs)))) - return update_op - def counter_filter(self, filter_freq, default_key=None, default_value=None): - if not isinstance(filter_freq, int): - raise TypeError("filter_freq must be int, please check.") - if filter_freq < 0: - raise ValueError("filter_freq must can not be smaller than 0.") - if (default_key is None) and (default_value is None): - raise ValueError("default_key and default_value can not be both None.") - if (default_key is not None) and (default_value is not None): - raise ValueError("default_key and default_value can not be both set.") - if default_key is None and (not isinstance(default_value, (int, float))): - raise TypeError("When default_value is not None, it must be float or int, please check.") - if default_value is None and (not isinstance(default_key, int)): - raise TypeError("When default_key is not None, it must be int, please check.") - self._use_counter_filter = True + self._check_counter_filter_params(filter_freq, default_key, default_value) if default_key is None: return CounterFilter(filter_freq=filter_freq, default_key_or_value=False, default_key=default_key, default_value=default_value) @@ -631,30 +326,6 @@ class ESWorker: return CounterFilter(filter_freq=filter_freq, default_key_or_value=True, default_key=default_key, default_value=default_value) - def data_parallel_embedding(self, max_vocabulary_size, embedding_dim, multihot_lens, allow_merge=True, - initializer=tf.random_uniform_initializer(minval=-0.01, maxval=0.01, seed=1234)): - if (max_vocabulary_size is None) or (embedding_dim is None) or (multihot_lens is None): - raise ValueError("max_vocabulary_size or embedding_dim or multihot_lens can not be None.") - if (not isinstance(max_vocabulary_size, int)) or (not isinstance(embedding_dim, int)) or \ - (not isinstance(multihot_lens, int)) or (not isinstance(allow_merge, bool)): - raise TypeError("max_vocabulary_size, embedding_dim, multihot_lens must be int, allow_merge must be bool.") - if max_vocabulary_size <= 0 or embedding_dim <= 0 or multihot_lens <= 0: - raise ValueError("max_vocabulary_size, embedding_dim, multihot_lens must be greater than 
zero.") - if initializer is None: - raise ValueError("Initializer can not be None.") - if initializer is not None and not callable(initializer): - init_dtype = ops.convert_to_tensor(initializer).dtype.base_dtype - if init_dtype != tf.float32: - raise ValueError("Initializer type '%s' and explict dtype tf.float32 don't match." % init_dtype) - new_table_info = dict( - max_vocabulary_size=max_vocabulary_size, - embedding_dim=embedding_dim, - multihot_lens=multihot_lens, - allow_merge=allow_merge, - initializer=initializer - ) - self.user_defined_table_infos.append(new_table_info) - def init_table(self, table_map_policy=AutoMergeTableMapPolicy()): if len(self.user_defined_table_infos) == 0: raise ValueError("small table has not been created.") @@ -738,43 +409,6 @@ class ESWorker: output_slots[sid] = out_emb return output_slots - # old version - def embeddings_look_up(self, tf_indices): - if self.total_embedding_count != len(self.table_create_infos) or self.total_embedding_count == 0: - raise ValueError("Must init_table() first!") - if tf_indices is None: - raise ValueError("tf_indices can not be None.") - if tf_indices.dtype != tf.int64: - raise TypeError("dtype of tf_indices must be tf.int64.") - (in_slot_size_group, slot_to_table, table_to_input_group, \ - table_to_slot, table_to_output_slots) = \ - (self.table_map_policy.in_slot_size_group, self.table_map_policy.slot_to_table, \ - self.table_map_policy.table_to_input_groups, self.table_map_policy.table_to_slot, \ - self.table_map_policy.table_to_output_slots) - - tf_indices_shape_list = tf_indices.get_shape().as_list() - total_in_slot_num = 0 - for in_slot_size in in_slot_size_group: - total_in_slot_num += in_slot_size - if tf_indices_shape_list[1] != total_in_slot_num: - raise ValueError("size of tf_indices is not the same as all small tables.") - - indices_split = tf.split(tf_indices, in_slot_size_group, axis=1) - for tid in range(self.total_embedding_count): - table_to_input_group[tid] = [] - for sid, indices in enumerate(indices_split): - tid = slot_to_table[sid] - table_to_input_group[tid].append(indices) - - output_slots = [None for _ in in_slot_size_group] - for tid, table_input_group in enumerate(table_to_input_group): - table_input_hash = tf.concat(table_input_group, axis=1) - table_embedding = tf.nn.embedding_lookup(self.total_variable_table[tid], table_input_hash) - out_embedding_splited = tf.split(table_embedding, table_to_output_slots[tid], axis=1) - for out_emb, sid in zip(out_embedding_splited, table_to_slot[tid]): - output_slots[sid] = out_emb - return output_slots - def save_embedding(self, name: str, path: str): """ Operator for save values in table_id embedding table. 
""" self._check_save_or_restore_params(name=name, path=path) @@ -1377,6 +1011,81 @@ class ESWorker: raise ValueError("this ps table has not yet initialized.") return table_id + + @staticmethod + def _check_counter_filter_params(filter_freq, default_key, default_value): + if not isinstance(filter_freq, int): + raise TypeError("filter_freq must be int, please check.") + if filter_freq < 0: + raise ValueError("filter_freq must can not be smaller than 0.") + if (default_key is None) and (default_value is None): + raise ValueError("default_key and default_value can not be both None.") + if (default_key is not None) and (default_value is not None): + raise ValueError("default_key and default_value can not be both set.") + if default_key is None and (not isinstance(default_value, (int, float))): + raise TypeError("When default_value is not None, it must be float or int, please check.") + if default_value is None and (not isinstance(default_key, int)): + raise TypeError("When default_key is not None, it must be int, please check.") + return + + def _init_big_table(self, name, init_vocabulary_size, + embedding_dim, initializer, + optimizer, mask_zero, + max_feature_count, ev_option): + table_id = self._check_and_update_ps_init_params(name=name, init_vocabulary_size=init_vocabulary_size, + embedding_dim=embedding_dim, + max_feature_count=max_feature_count, ev_option=ev_option) + self._ps_lookup_index = self._ps_table_count + self._table_to_embedding_dim[table_id] = embedding_dim + self._table_to_max_num[table_id] = max_feature_count + # storage the table id for embedding PS table + self._ps_table_id_list.append(table_id) + self._ps_table_name_list.append(name) + if len(self._ps_table_id_list) > 10: + raise ValueError("Now only 10 PS embedding tables can be init.") + bucket_size = math.ceil(init_vocabulary_size / self._ps_num) + if optimizer is None: + self._train_mode = False + self._table_to_slot_var_num[table_id] = 0 + else: + self._check_ps_opt_and_initializer(optimizer=optimizer, initializer=initializer, table_id=table_id) + self._optimizer = optimizer + self._optimizer._embedding_dims = embedding_dim + self._optimizer._max_nums = max_feature_count + self._optimizer.mask_zero = mask_zero + self._table_to_optimizer[table_id] = self._optimizer + self._ps_table_id_to_optimizer_params[table_id] = [] + self._update_optimizer_slot_var_num(table_id=table_id) + # new train or continue train from a checkpoint + if initializer is not None: + self._train_level = True + with specified_ps_engine_scope(): + self._init_partition_maps[table_id] = \ + gen_npu_cpu_ops.init_partition_map(ps_num=ops.convert_to_tensor(self._ps_num), + ps_ids=ops.convert_to_tensor(self._ps_ids), + partition_num=65537) + self._init_partition_maps.get(table_id)._set_attr("_embedding_dim", + attr_value_pb2.AttrValue(i=embedding_dim)) + self._init_partition_maps.get(table_id)._set_attr("_max_key_num", + attr_value_pb2.AttrValue(i=max_feature_count)) + return self._init_hashmap_and_table_import(bucket_size, table_id, embedding_dim, ev_option) + + def _init_small_table(self, name, init_vocabulary_size, embedding_dim, multihot_lens, + key_dtype, value_dtype, allow_merge, initializer): + self._check_and_update_small_init_params(name=name, init_vocabulary_size=init_vocabulary_size, + embedding_dim=embedding_dim, multihot_lens=multihot_lens, + key_dtype=key_dtype, value_dtype=value_dtype, + allow_merge=allow_merge, initializer=initializer) + new_small_table_info = dict( + name=name, + max_vocabulary_size=init_vocabulary_size, + 
embedding_dim=embedding_dim, + multihot_lens=multihot_lens, + allow_merge=allow_merge, + initializer=initializer) + self.user_defined_table_infos.append(new_small_table_info) + return new_small_table_info + def _check_update_params(self, params, input_ids_list, table_ids, loss): if (loss is None) or (params is None) or (table_ids is None) or (input_ids_list is None): raise ValueError("loss or params or table_ids or input_ids_list is None.") @@ -1385,6 +1094,18 @@ class ESWorker: raise ValueError("loss, params, table_ids and input_ids_list can not be str.") if not self._init_table_flag: raise ValueError("embedding must init first!") + if (not isinstance(params, (list, tuple)) and not isinstance(table_ids, (list, tuple)) + and not isinstance(input_ids_list, (list, tuple))): + params = [params] + table_ids = [table_ids] + input_ids_list = [input_ids_list] + for table_id in table_ids: + if table_id not in self._ps_table_id_list: + raise ValueError("this table has not yet initialized.") + if (len(params) != len(table_ids)) or (len(params) != len(input_ids_list)) \ + or (len(table_ids) != len(input_ids_list)): + raise ValueError("The length of params, table_ids, input_ids_list should be equal.") + return params, table_ids, input_ids_list def _check_save_or_restore_params(self, name, path): if path is None or name is None: @@ -1416,13 +1137,17 @@ class ESWorker: def _init_counter_filter(self, table_id, ev_option): if (ev_option is not None) and (ev_option.filter_option is not None): - filter_mode = "counter" - self._table_to_counter_filter[table_id] = ev_option.filter_option - self._table_use_counter_filter[table_id] = 1 + self._table_use_counter_filter[table_id] =\ + CounterFilter(use_counter_filter=1, + filter_mode="counter", + filter_freq=ev_option.filter_option.filter_freq, + default_key_or_value=ev_option.filter_option.default_key_or_value, + default_key=ev_option.filter_option.default_key, + default_value=ev_option.filter_option.default_value) else: - filter_mode = "no_filter" - self._table_use_counter_filter[table_id] = 0 - return filter_mode + self._table_use_counter_filter[table_id] = CounterFilter(use_counter_filter=0, filter_mode="no_filter", + default_value=-1) + return def _init_optimizer_mode_and_params(self, table_id): if isinstance(self._table_to_optimizer.get(table_id), embedding_optimizer.AdagradOptimizer): @@ -1447,7 +1172,7 @@ class ESWorker: self._table_to_optimizer.get(table_id).mom) def _init_hashmap_and_table_import(self, bucket_size, table_id, embedding_dim, ev_option): - filter_mode = self._init_counter_filter(table_id, ev_option) + self._init_counter_filter(table_id, ev_option) self._init_optimizer_mode_and_params(table_id) with tf.control_dependencies([self._init_partition_maps.get(table_id)]): @@ -1471,7 +1196,8 @@ class ESWorker: sigma=self._table_id_to_initializer.get(table_id).sigma, seed=self._table_id_to_initializer.get(table_id).seed, seed2=self._table_id_to_initializer.get(table_id).seed, - filter_mode=filter_mode, + filter_mode= + self._table_use_counter_filter.get(table_id).filter_mode, optimizer_mode= self._ps_table_id_to_optimizer_mode.get(table_id), optimizer_params= @@ -1485,7 +1211,9 @@ class ESWorker: embedding_dim=embedding_dim, initializer_mode=None, constant_value=None, min=None, max=None, mu=None, sigma=None, - seed=None, seed2=None, filter_mode=filter_mode, + seed=None, seed2=None, + filter_mode= + self._table_use_counter_filter.get(table_id).filter_mode, optimizer_mode= self._ps_table_id_to_optimizer_mode.get(table_id), optimizer_params= @@ 
-1498,7 +1226,9 @@ class ESWorker: embedding_dim=embedding_dim, initializer_mode=None, constant_value=None, min=None, max=None, mu=None, sigma=None, - seed=None, seed2=None, filter_mode=filter_mode, + seed=None, seed2=None, + filter_mode= + self._table_use_counter_filter.get(table_id).filter_mode, optimizer_mode= self._ps_table_id_to_optimizer_mode.get(table_id), optimizer_params= @@ -1512,7 +1242,7 @@ class ESWorker: return tf.group([self._init_embedding_hash_maps.get(table_id)]) def _call_lookup_op(self, table_id, ids, actual_keys_input=None, unique_indices=None, - filter_mode=None, use_counter_filter=False, key_count=None): + use_counter_filter=False, key_count=None): if self._train_mode: if self.use_host_unique: if use_counter_filter: @@ -1538,11 +1268,12 @@ class ESWorker: seed2=self._table_id_to_initializer.get(table_id).seed, value_total_len=self._table_to_embedding_dim .get(table_id) * (self._table_to_slot_var_num.get(table_id) + 1), - filter_mode=filter_mode, - filter_freq=self._filter_freq, - default_key_or_value=self._default_key_or_value, - default_key=self._default_key, - default_value=self._default_value, + filter_mode=self._table_use_counter_filter.get(table_id).filter_mode, + filter_freq=self._table_use_counter_filter.get(table_id).filter_freq, + default_key_or_value= + self._table_use_counter_filter.get(table_id).default_key_or_value, + default_key=self._table_use_counter_filter.get(table_id).default_key, + default_value=self._table_use_counter_filter.get(table_id).default_value, optimizer_mode=self._ps_table_id_to_optimizer_mode.get(table_id), optimizer_params=self._ps_table_id_to_optimizer_params.get(table_id) ) @@ -1563,11 +1294,13 @@ class ESWorker: seed2=self._table_id_to_initializer.get(table_id).seed, value_total_len=self._table_to_embedding_dim.get(table_id) * (self._table_to_slot_var_num.get(table_id) + 1), - filter_mode=filter_mode, - filter_freq=self._filter_freq, - default_key_or_value=self._default_key_or_value, - default_key=self._default_key, - default_value=self._default_value, + filter_mode=self._table_use_counter_filter.get(table_id).filter_mode, + filter_freq=self._table_use_counter_filter.get(table_id).filter_freq, + default_key_or_value= + self._table_use_counter_filter.get(table_id).default_key_or_value, + default_key=self._table_use_counter_filter.get(table_id).default_key, + default_value= + self._table_use_counter_filter.get(table_id).default_value, optimizer_mode=self._ps_table_id_to_optimizer_mode.get(table_id), optimizer_params=self._ps_table_id_to_optimizer_params.get(table_id) ) @@ -1580,5 +1313,17 @@ class ESWorker: result.op._set_attr("_embedding_dim", attr_value_pb2.AttrValue(i=self._table_to_embedding_dim.get(table_id))) result.op._set_attr("_max_key_num", attr_value_pb2.AttrValue(i=self._table_to_max_num.get(table_id))) result.op._set_attr("_use_counter_filter", - attr_value_pb2.AttrValue(i=self._table_use_counter_filter.get(table_id))) + attr_value_pb2.AttrValue(i=self._table_use_counter_filter.get(table_id).use_counter_filter)) return result + + def _store_lookup_data_and_refresh_state(self, table_id, ids, result, unique_indices): + if (self._ps_lookup_index != 0) or (self._existing_lookup_table_ids.count(table_id) != 0): + self._ps_table_has_lookup.append(table_id) + self._ps_table_lookup_key.append(ids) + self._ps_table_lookup_result.append(result) + self._ps_lookup_index = self._ps_lookup_index - 1 + if self.use_host_unique: + self.key_recovery_matrix.append(unique_indices) + # restore table id that has called lookup, if this table 
calls lookup again, its keys and values must also be stored. + self._existing_lookup_table_ids.append(table_id) + return
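
Reviewer note: below is a minimal, hypothetical usage sketch showing how the pieces touched by this change fit together; it is not part of the diff. Only names visible in the change are used (ESWorker, EmbeddingVariableOption, counter_filter, embedding_init, embedding_lookup, the embedding_optimizer module). The ESWorker and optimizer constructor arguments, the initializer value, and the table sizes are illustrative assumptions.

# Hypothetical usage sketch (TF1-style graph code); assumptions noted above.
import tensorflow as tf
from npu_bridge.embedding.embedding_helper import EmbeddingVariableOption
from npu_bridge.embedding.embedding_service import ESWorker
from npu_bridge.embedding import embedding_optimizer

es = ESWorker()  # assumption: cluster/config constructor arguments omitted here

# Counter filter arguments are validated by the new _check_counter_filter_params helper.
filter_opt = es.counter_filter(filter_freq=5, default_value=0.0)
ev_opt = EmbeddingVariableOption(filter_option=filter_opt)

# embedding_init with embedding_type="PS" now dispatches to _init_big_table;
# _init_counter_filter stores a per-table CounterFilter that later supplies
# filter_mode / filter_freq / default_* to the lookup ops.
init_op = es.embedding_init(name="user_table",
                            init_vocabulary_size=1000000,
                            embedding_dim=16,
                            max_feature_count=4096,
                            initializer="random_uniform",  # assumed accepted value
                            optimizer=embedding_optimizer.AdagradOptimizer(learning_rate=0.01),  # assumed ctor args
                            ev_option=ev_opt,
                            embedding_type="PS")

# Lookup; per-table bookkeeping is handled by _store_lookup_data_and_refresh_state.
ids = tf.placeholder(tf.int64, shape=[None], name="ids")
values = es.embedding_lookup(name="user_table", ids=ids)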