diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index ef03bae87ea1eda89de9131b7d5ad79647bbf4df..659d96f437f9bec83df87fbd1850540ab6dea3c7 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -26,8 +26,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.core.framework import attr_value_pb2 from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops -from npu_bridge.hccl.hccl_ops import allgather -from hccl.manage.api import create_group from npu_bridge.npu_cpu.npu_cpu_ops import host_feature_mapping_export from npu_bridge.npu_cpu.npu_cpu_ops import host_feature_mapping_import from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource @@ -178,7 +176,6 @@ class ESWorker: self._init_table_flag = False self._small_table_name_list = [] - self._small_table_variable_list = [] self._ps_table_count = 0 self._table_name_to_id = {} self._table_id_to_name = {} @@ -209,7 +206,6 @@ class ESWorker: self.table_map_policy = None self.table_create_infos = [] self.total_variable_table = [] - self._small_table_embedding_dim = -1 # if all small table do not merge self._small_table_to_variable = {} self._small_table_to_multihot_lens = {} @@ -307,7 +303,6 @@ class ESWorker: multihot_lens=multihot_lens, allow_merge=allow_merge, initializer=initializer) - self._small_table_embedding_dim = embedding_dim self.user_defined_table_infos.append(new_small_table_info) return new_small_table_info elif embedding_type == "PS": @@ -661,12 +656,6 @@ class ESWorker: if len(self.user_defined_table_infos) == 0: raise ValueError("small table has not been created.") self.total_embedding_count = 0 - if (os.environ.get("RANK_SIZE") is not None) and (int(os.environ.get("RANK_SIZE")) > 1): - rank_size = int(os.environ.get("RANK_SIZE")) - rank_list = [] - for i in range(rank_size): - rank_list.append(i) - create_group("user_group_fm", rank_size, rank_list) if not self._need_table_merge: for user_table_info in self.user_defined_table_infos: self._small_table_to_variable[user_table_info['name']] =\ @@ -674,7 +663,6 @@ class ESWorker: user_table_info['embedding_dim']], initializer=user_table_info['initializer'], dtype=tf.float32) self._small_table_to_multihot_lens[self.total_embedding_count] = user_table_info['multihot_lens'] - self._small_table_variable_list.append(user_table_info['name'] + ":0") self.total_embedding_count += 1 else: self.total_variable_table = [] @@ -691,50 +679,12 @@ class ESWorker: dtype=tf.float32 )) self._npu_table_to_embedding_dim[self.total_embedding_count] = table_info_['embedding_dim'] - self._small_table_variable_list.append('ES' + str(self.total_embedding_count) + ":0") self.total_embedding_count += 1 self.user_defined_table_infos = [] self._small_table_name_list = [] # new version def embeddings_lookup(self, ids_list, name=None): - if ids_list is None: - raise ValueError("ids_list can not be None.") - env_dist = os.environ - rank_size = int(env_dist.get("RANK_SIZE")) - rank_id = int(env_dist.get("RANK_ID")) - if rank_size < 1: - raise ValueError('Rank size from env must be at least 1, 'f' Received: {rank_size}.') - if rank_id < 0 or rank_id >= rank_size: - raise ValueError('Rank id from env must be at least 0, and smaller than Rank Size.' - 'But Rank id 'f' Received: {rank_id}.') - - if not self._need_table_merge: - return self._small_table_lookup_v1(name, rank_id, rank_size, ids_list) - - if self.total_embedding_count != len(self.table_create_infos) or self.total_embedding_count == 0: - raise ValueError("Must init_table() first!") - (in_slot_size_group, slot_to_table, table_to_input_group, \ - table_to_slot, table_to_output_slots) = \ - (self.table_map_policy.in_slot_size_group, self.table_map_policy.slot_to_table, \ - self.table_map_policy.table_to_input_groups, self.table_map_policy.table_to_slot, \ - self.table_map_policy.table_to_output_slots) - - ids_list_shape_list = ids_list.get_shape().as_list() - total_in_slot_num = 0 - for in_slot_size in in_slot_size_group: - total_in_slot_num += in_slot_size - if ids_list_shape_list[1] != total_in_slot_num: - raise ValueError("size of ids_list is not the same as all small tables.") - - if self.total_embedding_count == 1: - return self._small_table_lookup_v2(rank_id, rank_size, in_slot_size_group, - ids_list, table_to_output_slots, table_to_slot) - - return self._small_table_lookup_v3(rank_id, rank_size, ids_list, in_slot_size_group, slot_to_table, - table_to_input_group, table_to_output_slots, table_to_slot) - - def embeddings_lookup1(self, ids_list, name=None): if ids_list is None: raise ValueError("ids_list can not be None.") if ids_list.dtype != tf.int64: @@ -886,12 +836,7 @@ class ESWorker: return tf.group([embedding_table_import]) def restore_embeddings(self, path: str): - if len(self._ps_table_name_list) != 0: - self._check_save_or_restore_params_v2(path=path, save_flag=False) - if len(self._small_table_variable_list) != 0: - feature_mapping_import_list = self._call_feature_mapping_import_op(path=path) - if self._ps_table_count == 0: - return feature_mapping_import_list + self._check_save_or_restore_params_v2(path=path, save_flag=False) with specified_ps_engine_scope(): table_id_list = [] embedding_dim_list = [] @@ -907,9 +852,7 @@ class ESWorker: only_var_flag=True, file_type="bin", table_name=self._ps_table_name_list) - if len(self._small_table_variable_list) == 0: - return tf.group([embedding_table_import]) - return embedding_table_import, feature_mapping_import_list + return tf.group([embedding_table_import]) def save_checkpoint(self, name: str, path: str, save_filtered_features=False): """ Operator for save values and optimizer params in table_id embedding table. """ @@ -948,14 +891,15 @@ class ESWorker: def save_checkpoints(self, path: str, save_filtered_features=False, export_feature_mapping=False): """ Operator for save values and optimizer params in all embedding tables. """ - if len(self._ps_table_name_list) != 0: + if export_feature_mapping is False: self._check_save_or_restore_params_v2(path=path, save_flag=True) if not isinstance(save_filtered_features, bool): raise TypeError("save_filtered_features must be bool.") - if export_feature_mapping or len(self._small_table_variable_list) != 0: - feature_mapping_export_list = self._call_feature_mapping_export_op(path) + if export_feature_mapping: + feature_mapping_export = host_feature_mapping_export(path=path, + table_name_list=self._feature_mapping_name_list) if self._ps_table_count == 0: - return feature_mapping_export_list + return feature_mapping_export with specified_ps_engine_scope(): table_id_list = [] embedding_dim_list = [] @@ -985,9 +929,9 @@ class ESWorker: ps_id=ps_id_tensor, table_id=table_id_tensor, table_name=self._ps_table_name_list) - if len(self._small_table_variable_list) == 0: + if export_feature_mapping is False: return tf.group([embedding_compute_var_export]) - return embedding_compute_var_export, feature_mapping_export_list + return embedding_compute_var_export, feature_mapping_export def restore_checkpoint(self, name: str, path: str): """ Operator for restore values and optimizer params in table_id embedding table. """ @@ -1018,12 +962,12 @@ class ESWorker: def restore_checkpoints(self, path: str, import_feature_mapping=False): """ Operator for restore values and optimizer params in all embedding tables. """ - if len(self._ps_table_name_list) != 0: + if import_feature_mapping is False: self._check_save_or_restore_params_v2(path=path, save_flag=False) - if import_feature_mapping or len(self._small_table_variable_list) != 0: - feature_mapping_import_list = self._call_feature_mapping_import_op(path=path) + if import_feature_mapping: + feature_mapping_import = host_feature_mapping_import(path=path) if self._ps_table_count == 0: - return feature_mapping_import_list + return feature_mapping_import with specified_ps_engine_scope(): table_id_list = [] embedding_dim_list = [] @@ -1051,9 +995,9 @@ class ESWorker: ps_id=ps_id_tensor, table_id=table_id_tensor, table_name=self._ps_table_name_list) - if len(self._small_table_variable_list) == 0: + if import_feature_mapping is False: return tf.group([embedding_compute_var_import]) - return embedding_compute_var_import, feature_mapping_import_list + return embedding_compute_var_import, feature_mapping_import def save_incremental_embedding(self, name: str, path: str): """ Operator for save incremental values in table_id embedding table. """ @@ -1688,116 +1632,3 @@ class ESWorker: result.op._set_attr("_execute_times", attr_value_pb2.AttrValue(i=2)) return result - - def _small_table_lookup_v1(self, name, rank_id, rank_size, ids_list): - if not isinstance(name, str): - raise TypeError("embedding table name must be string.") - if self.total_embedding_count == 0: - raise ValueError("Must init_table() first!") - hash_key_shape = ids_list.get_shape().as_list() - if rank_size > 1: - hash_key = allgather(tensor=ids_list, rank_size=rank_size, group="user_group") - non_hash_key = gen_npu_cpu_ops.embedding_feature_mapping_v2(feature_id=hash_key, table_name=name) - recovery_matrix = [] - for i in range(hash_key_shape[0]): - recovery_matrix.append(rank_id * hash_key_shape[0] + i) - local_non_hash_keys = tf.gather(non_hash_key, recovery_matrix) - else: - hash_key = ids_list - local_non_hash_keys = gen_npu_cpu_ops.embedding_feature_mapping_v2(feature_id=hash_key, table_name=name) - return tf.nn.embedding_lookup(self._small_table_to_variable[name], local_non_hash_keys) - - def _small_table_lookup_v2(self, rank_id, rank_size, in_slot_size_group, - ids_list, table_to_output_slots, table_to_slot): - # all small table merge to One table - hash_key_shape = ids_list.get_shape().as_list() - if rank_size > 1: - hash_key = allgather(tensor=ids_list, rank_size=rank_size, group="user_group") - non_hash_key = gen_npu_cpu_ops.embedding_feature_mapping_v2( - feature_id=hash_key, table_name=self._small_table_variable_list[0][:-2]) - recovery_matrix = [] - for i in range(hash_key_shape[0]): - recovery_matrix.append(rank_id * hash_key_shape[0] + i) - local_non_hash_keys = tf.gather(non_hash_key, recovery_matrix) - else: - hash_key = ids_list - local_non_hash_keys = gen_npu_cpu_ops.embedding_feature_mapping_v2( - feature_id=hash_key, table_name=self._small_table_variable_list[0][:-2]) - - output_slots = [None for _ in in_slot_size_group] - tid = 0 - table_embedding = tf.nn.embedding_lookup(self.total_variable_table[tid], local_non_hash_keys) - out_embedding_splited = tf.split(table_embedding, table_to_output_slots[0], axis=1) - for out_emb, sid in zip(out_embedding_splited, table_to_slot[0]): - output_slots[sid] = out_emb - return output_slots - - def _small_table_lookup_v3(self, rank_id, rank_size, ids_list, in_slot_size_group, slot_to_table, - table_to_input_group, table_to_output_slots, table_to_slot): - # All small tables merge to two or more tables - indices_split = tf.split(ids_list, in_slot_size_group, axis=1) - for tid in range(self.total_embedding_count): - table_to_input_group[tid] = [] - for sid, indices in enumerate(indices_split): - tid = slot_to_table[sid] - table_to_input_group[tid].append(indices) - - output_slots = [None for _ in in_slot_size_group] - for tid, table_input_group in enumerate(table_to_input_group): - table_input_hash = tf.concat(table_input_group, axis=1) - hash_key_shape = table_input_hash.get_shape().as_list() - if rank_size > 1: - hash_key = allgather(tensor=table_input_hash, rank_size=rank_size, group="user_group") - non_hash_key = gen_npu_cpu_ops.embedding_feature_mapping_v2( - feature_id=hash_key, table_name=self._small_table_variable_list[tid][:-2]) - recovery_matrix = [] - for i in range(hash_key_shape[0]): - recovery_matrix.append(rank_id * hash_key_shape[0] + i) - local_non_hash_keys = tf.gather(non_hash_key, recovery_matrix) - else: - hash_key = table_input_hash - local_non_hash_keys = gen_npu_cpu_ops.embedding_feature_mapping_v2( - feature_id=hash_key, table_name=self._small_table_variable_list[tid][:-2]) - table_embedding = tf.nn.embedding_lookup(self.total_variable_table[tid], local_non_hash_keys) - out_embedding_splited = tf.split(table_embedding, table_to_output_slots[tid], axis=1) - for out_emb, sid in zip(out_embedding_splited, table_to_slot[tid]): - output_slots[sid] = out_emb - return output_slots - - def _call_feature_mapping_export_op(self, path): - feature_mapping_export_list = [] - tvar = tf.trainable_variables() - for x in tvar: - if x.name not in self._small_table_variable_list: - continue - feature_size = gen_npu_cpu_ops.embedding_feature_mapping_table_size(table_name=x.name[:-2]) - feature_id, offset_id = gen_npu_cpu_ops.embedding_feature_mapping_find(table_name=x.name[:-2], - feature_size=feature_size) - values = tf.gather(x, offset_id) - feature_mapping_export = gen_npu_cpu_ops.embedding_feature_mapping_export(file_path=path, - feature_id=feature_id, - offset_id=offset_id, - values=values, - table_name=x.name[:-2]) - feature_mapping_export_list.append(feature_mapping_export) - return feature_mapping_export_list - - def _call_feature_mapping_import_op(self, path): - feature_mapping_import_list = [] - for x in self._small_table_variable_list: - feature_size = \ - gen_npu_cpu_ops.embedding_feature_mapping_file_size(file_path=path, - table_name=x[:-2], - embedding_dim=self._small_table_embedding_dim) - feature_id, offset_id, values = \ - gen_npu_cpu_ops.embedding_feature_mapping_import(file_path=path, - table_name=x[:-2], - feature_size=feature_size, - embedding_dim=self._small_table_embedding_dim) - feature_mapping_insert = \ - gen_npu_cpu_ops.embedding_feature_mapping_insert(table_name=x[:-2], - feature_id=feature_id, - offset_id=offset_id) - feature_mapping_import_list.append(feature_mapping_insert) - return feature_mapping_import_list -