diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index c3a7e936e44ce893ade8f21ae0e5f27a2457455c..93b3a135980529844eca85ab2562a7f64b9b3b7e 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -26,6 +26,8 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.core.framework import attr_value_pb2 from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops +from npu_bridge.hccl.hccl_ops import allgather +from hccl.manage.api import create_group from npu_bridge.npu_cpu.npu_cpu_ops import host_feature_mapping_export from npu_bridge.npu_cpu.npu_cpu_ops import host_feature_mapping_import from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource @@ -179,6 +181,7 @@ class ESWorker: self._init_table_flag = False self._small_table_name_list = [] + self._small_table_variable_list = [] self._ps_table_count = 0 self._table_name_to_id = {} self._table_id_to_name = {} @@ -209,6 +212,7 @@ class ESWorker: self.table_map_policy = None self.table_create_infos = [] self.total_variable_table = [] + self._small_table_embedding_dim = -1 # if all small table do not merge self._small_table_to_variable = {} self._small_table_to_multihot_lens = {} @@ -306,6 +310,7 @@ class ESWorker: multihot_lens=multihot_lens, allow_merge=allow_merge, initializer=initializer) + self._small_table_embedding_dim = embedding_dim self.user_defined_table_infos.append(new_small_table_info) return new_small_table_info elif embedding_type == "PS": @@ -659,6 +664,12 @@ class ESWorker: if len(self.user_defined_table_infos) == 0: raise ValueError("small table has not been created.") self.total_embedding_count = 0 + if (os.environ.get("RANK_SIZE") is not None) and (int(os.environ.get("RANK_SIZE")) > 1): + rank_size = int(os.environ.get("RANK_SIZE")) + rank_list = [] + for i in range(rank_size): + rank_list.append(i) + create_group("user_group_fm", rank_size, rank_list) if not self._need_table_merge: for user_table_info in self.user_defined_table_infos: self._small_table_to_variable[user_table_info['name']] =\ @@ -666,6 +677,7 @@ class ESWorker: user_table_info['embedding_dim']], initializer=user_table_info['initializer'], dtype=tf.float32) self._small_table_to_multihot_lens[self.total_embedding_count] = user_table_info['multihot_lens'] + self._small_table_variable_list.append(user_table_info['name'] + ":0") self.total_embedding_count += 1 else: self.total_variable_table = [] @@ -682,12 +694,50 @@ class ESWorker: dtype=tf.float32 )) self._npu_table_to_embedding_dim[self.total_embedding_count] = table_info_['embedding_dim'] + self._small_table_variable_list.append('ES' + str(self.total_embedding_count) + ":0") self.total_embedding_count += 1 self.user_defined_table_infos = [] self._small_table_name_list = [] # new version def embeddings_lookup(self, ids_list, name=None): + if ids_list is None: + raise ValueError("ids_list can not be None.") + env_dist = os.environ + rank_size = int(env_dist.get("RANK_SIZE")) + rank_id = int(env_dist.get("RANK_ID")) + if rank_size < 1: + raise ValueError('Rank size from env must be at least 1, 'f' Received: {rank_size}.') + if rank_id < 0 or rank_id >= rank_size: + raise ValueError('Rank id from env must be at least 0, and smaller than Rank Size.' 
+                             f' But received rank id: {rank_id}.')
+
+        if not self._need_table_merge:
+            return self._small_table_lookup_v1(name, rank_id, rank_size, ids_list)
+
+        if self.total_embedding_count != len(self.table_create_infos) or self.total_embedding_count == 0:
+            raise ValueError("Must init_table() first!")
+        (in_slot_size_group, slot_to_table, table_to_input_group, \
+         table_to_slot, table_to_output_slots) = \
+            (self.table_map_policy.in_slot_size_group, self.table_map_policy.slot_to_table, \
+             self.table_map_policy.table_to_input_groups, self.table_map_policy.table_to_slot, \
+             self.table_map_policy.table_to_output_slots)
+
+        ids_list_shape_list = ids_list.get_shape().as_list()
+        total_in_slot_num = 0
+        for in_slot_size in in_slot_size_group:
+            total_in_slot_num += in_slot_size
+        if ids_list_shape_list[1] != total_in_slot_num:
+            raise ValueError("the second dim of ids_list must equal the total slot size of all small tables.")
+
+        if self.total_embedding_count == 1:
+            return self._small_table_lookup_v2(rank_id, rank_size, in_slot_size_group,
+                                               ids_list, table_to_output_slots, table_to_slot)
+
+        return self._small_table_lookup_v3(rank_id, rank_size, ids_list, in_slot_size_group, slot_to_table,
+                                           table_to_input_group, table_to_output_slots, table_to_slot)
+
+    def embeddings_lookup1(self, ids_list, name=None):
         if ids_list is None:
             raise ValueError("ids_list can not be None.")
         if ids_list.dtype != tf.int64:
@@ -839,7 +889,12 @@ class ESWorker:
         return tf.group([embedding_table_import])
 
     def restore_embeddings(self, path: str):
-        self._check_save_or_restore_params_v2(path=path, save_flag=False)
+        if len(self._ps_table_name_list) != 0:
+            self._check_save_or_restore_params_v2(path=path, save_flag=False)
+        if len(self._small_table_variable_list) != 0:
+            feature_mapping_import_list = self._call_feature_mapping_import_op(path=path)
+        if self._ps_table_count == 0:
+            return feature_mapping_import_list
         with specified_ps_engine_scope():
             table_id_list = []
             embedding_dim_list = []
@@ -855,7 +910,9 @@ class ESWorker:
                                                                             only_var_flag=True,
                                                                             file_type="bin",
                                                                             table_name=self._ps_table_name_list)
-            return tf.group([embedding_table_import])
+            if len(self._small_table_variable_list) == 0:
+                return tf.group([embedding_table_import])
+            return embedding_table_import, feature_mapping_import_list
 
     def save_checkpoint(self, name: str, path: str, save_filtered_features=False):
         """ Operator for save values and optimizer params in table_id embedding table. """
@@ -894,15 +951,14 @@ class ESWorker:
 
     def save_checkpoints(self, path: str, save_filtered_features=False, export_feature_mapping=False):
         """ Operator for save values and optimizer params in all embedding tables. 
""" - if export_feature_mapping is False: + if len(self._ps_table_name_list) != 0: self._check_save_or_restore_params_v2(path=path, save_flag=True) if not isinstance(save_filtered_features, bool): raise TypeError("save_filtered_features must be bool.") - if export_feature_mapping: - feature_mapping_export = host_feature_mapping_export(path=path, - table_name_list=self._feature_mapping_name_list) + if export_feature_mapping or len(self._small_table_variable_list) != 0: + feature_mapping_export_list = self._call_feature_mapping_export_op(path) if self._ps_table_count == 0: - return feature_mapping_export + return feature_mapping_export_list with specified_ps_engine_scope(): table_id_list = [] embedding_dim_list = [] @@ -932,9 +988,9 @@ class ESWorker: ps_id=ps_id_tensor, table_id=table_id_tensor, table_name=self._ps_table_name_list) - if export_feature_mapping is False: + if len(self._small_table_variable_list) == 0: return tf.group([embedding_compute_var_export]) - return embedding_compute_var_export, feature_mapping_export + return embedding_compute_var_export, feature_mapping_export_list def restore_checkpoint(self, name: str, path: str): """ Operator for restore values and optimizer params in table_id embedding table. """ @@ -965,12 +1021,12 @@ class ESWorker: def restore_checkpoints(self, path: str, import_feature_mapping=False): """ Operator for restore values and optimizer params in all embedding tables. """ - if import_feature_mapping is False: + if len(self._ps_table_name_list) != 0: self._check_save_or_restore_params_v2(path=path, save_flag=False) - if import_feature_mapping: - feature_mapping_import = host_feature_mapping_import(path=path) + if import_feature_mapping or len(self._small_table_variable_list) != 0: + feature_mapping_import_list = self._call_feature_mapping_import_op(path=path) if self._ps_table_count == 0: - return feature_mapping_import + return feature_mapping_import_list with specified_ps_engine_scope(): table_id_list = [] embedding_dim_list = [] @@ -998,9 +1054,9 @@ class ESWorker: ps_id=ps_id_tensor, table_id=table_id_tensor, table_name=self._ps_table_name_list) - if import_feature_mapping is False: + if len(self._small_table_variable_list) == 0: return tf.group([embedding_compute_var_import]) - return embedding_compute_var_import, feature_mapping_import + return embedding_compute_var_import, feature_mapping_import_list def save_incremental_embedding(self, name: str, path: str): """ Operator for save incremental values in table_id embedding table. 
""" @@ -1582,3 +1638,116 @@ class ESWorker: result.op._set_attr("_use_counter_filter", attr_value_pb2.AttrValue(i=self._table_use_counter_filter.get(table_id))) return result + + def _small_table_lookup_v1(self, name, rank_id, rank_size, ids_list): + if not isinstance(name, str): + raise TypeError("embedding table name must be string.") + if self.total_embedding_count == 0: + raise ValueError("Must init_table() first!") + hash_key_shape = ids_list.get_shape().as_list() + if rank_size > 1: + hash_key = allgather(tensor=ids_list, rank_size=rank_size, group="user_group") + non_hash_key = gen_npu_cpu_ops.embedding_feature_mapping_v2(feature_id=hash_key, table_name=name) + recovery_matrix = [] + for i in range(hash_key_shape[0]): + recovery_matrix.append(rank_id * hash_key_shape[0] + i) + local_non_hash_keys = tf.gather(non_hash_key, recovery_matrix) + else: + hash_key = ids_list + local_non_hash_keys = gen_npu_cpu_ops.embedding_feature_mapping_v2(feature_id=hash_key, table_name=name) + return tf.nn.embedding_lookup(self._small_table_to_variable[name], local_non_hash_keys) + + def _small_table_lookup_v2(self, rank_id, rank_size, in_slot_size_group, + ids_list, table_to_output_slots, table_to_slot): + # all small table merge to One table + hash_key_shape = ids_list.get_shape().as_list() + if rank_size > 1: + hash_key = allgather(tensor=ids_list, rank_size=rank_size, group="user_group") + non_hash_key = gen_npu_cpu_ops.embedding_feature_mapping_v2( + feature_id=hash_key, table_name=self._small_table_variable_list[0][:-2]) + recovery_matrix = [] + for i in range(hash_key_shape[0]): + recovery_matrix.append(rank_id * hash_key_shape[0] + i) + local_non_hash_keys = tf.gather(non_hash_key, recovery_matrix) + else: + hash_key = ids_list + local_non_hash_keys = gen_npu_cpu_ops.embedding_feature_mapping_v2( + feature_id=hash_key, table_name=self._small_table_variable_list[0][:-2]) + + output_slots = [None for _ in in_slot_size_group] + tid = 0 + table_embedding = tf.nn.embedding_lookup(self.total_variable_table[tid], local_non_hash_keys) + out_embedding_splited = tf.split(table_embedding, table_to_output_slots[0], axis=1) + for out_emb, sid in zip(out_embedding_splited, table_to_slot[0]): + output_slots[sid] = out_emb + return output_slots + + def _small_table_lookup_v3(self, rank_id, rank_size, ids_list, in_slot_size_group, slot_to_table, + table_to_input_group, table_to_output_slots, table_to_slot): + # All small tables merge to two or more tables + indices_split = tf.split(ids_list, in_slot_size_group, axis=1) + for tid in range(self.total_embedding_count): + table_to_input_group[tid] = [] + for sid, indices in enumerate(indices_split): + tid = slot_to_table[sid] + table_to_input_group[tid].append(indices) + + output_slots = [None for _ in in_slot_size_group] + for tid, table_input_group in enumerate(table_to_input_group): + table_input_hash = tf.concat(table_input_group, axis=1) + hash_key_shape = table_input_hash.get_shape().as_list() + if rank_size > 1: + hash_key = allgather(tensor=table_input_hash, rank_size=rank_size, group="user_group") + non_hash_key = gen_npu_cpu_ops.embedding_feature_mapping_v2( + feature_id=hash_key, table_name=self._small_table_variable_list[tid][:-2]) + recovery_matrix = [] + for i in range(hash_key_shape[0]): + recovery_matrix.append(rank_id * hash_key_shape[0] + i) + local_non_hash_keys = tf.gather(non_hash_key, recovery_matrix) + else: + hash_key = table_input_hash + local_non_hash_keys = gen_npu_cpu_ops.embedding_feature_mapping_v2( + feature_id=hash_key, 
table_name=self._small_table_variable_list[tid][:-2]) + table_embedding = tf.nn.embedding_lookup(self.total_variable_table[tid], local_non_hash_keys) + out_embedding_splited = tf.split(table_embedding, table_to_output_slots[tid], axis=1) + for out_emb, sid in zip(out_embedding_splited, table_to_slot[tid]): + output_slots[sid] = out_emb + return output_slots + + def _call_feature_mapping_export_op(self, path): + feature_mapping_export_list = [] + tvar = tf.trainable_variables() + for x in tvar: + if x.name not in self._small_table_variable_list: + continue + feature_size = gen_npu_cpu_ops.embedding_feature_mapping_table_size(table_name=x.name[:-2]) + feature_id, offset_id = gen_npu_cpu_ops.embedding_feature_mapping_find(table_name=x.name[:-2], + feature_size=feature_size) + values = tf.gather(x, offset_id) + feature_mapping_export = gen_npu_cpu_ops.embedding_feature_mapping_export(file_path=path, + feature_id=feature_id, + offset_id=offset_id, + values=values, + table_name=x.name[:-2]) + feature_mapping_export_list.append(feature_mapping_export) + return feature_mapping_export_list + + def _call_feature_mapping_import_op(self, path): + feature_mapping_import_list = [] + for x in self._small_table_variable_list: + feature_size = \ + gen_npu_cpu_ops.embedding_feature_mapping_file_size(file_path=path, + table_name=x[:-2], + embedding_dim=self._small_table_embedding_dim) + feature_id, offset_id, values = \ + gen_npu_cpu_ops.embedding_feature_mapping_import(file_path=path, + table_name=x[:-2], + feature_size=feature_size, + embedding_dim=self._small_table_embedding_dim) + feature_mapping_insert = \ + gen_npu_cpu_ops.embedding_feature_mapping_insert(table_name=x[:-2], + feature_id=feature_id, + offset_id=offset_id) + feature_mapping_import_list.append(feature_mapping_insert) + return feature_mapping_import_list +
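
A minimal usage sketch of the small-table path above, assuming a TF1-style NPU program with `RANK_SIZE`/`RANK_ID` exported by the launcher and an already configured `ESWorker` instance; the helper name `build_small_table_graph`, the shape constants, and the checkpoint path are illustrative placeholders, not part of this change.

```python
import os

import tensorflow as tf


def build_small_table_graph(es, batch_size=16, total_slot_size=8, ckpt_path="./es_ckpt"):
    """Sketch only: wire up lookup and checkpoint ops for a configured ESWorker.

    `es` is assumed to already have its small tables created and init_table()
    called; shapes and the checkpoint path are placeholders.
    """
    # embeddings_lookup() reads the rank layout from the environment, so fail
    # early if the launcher did not export it.
    if os.environ.get("RANK_SIZE") is None or os.environ.get("RANK_ID") is None:
        raise RuntimeError("RANK_SIZE and RANK_ID must be set before lookup.")

    # For merged small tables the second dimension must equal the sum of
    # in_slot_size_group; the lookup returns one embedding tensor per slot.
    ids = tf.placeholder(tf.int64, shape=[batch_size, total_slot_size], name="feature_ids")
    output_slots = es.embeddings_lookup(ids)
    # The un-merged path requires the small-table name instead:
    # output_slots = es.embeddings_lookup(ids, name="my_small_table")

    # Checkpointing now also exports/imports the small-table feature mappings;
    # when no PS tables exist, only the feature-mapping op list is returned.
    save_ops = es.save_checkpoints(path=ckpt_path)
    restore_ops = es.restore_checkpoints(path=ckpt_path)
    return ids, output_slots, save_ops, restore_ops
```

When `RANK_SIZE` is greater than 1, the lookup helpers allgather the raw ids, run the shared feature mapping over the full gathered batch so every rank assigns identical offsets, and then use `tf.gather` with `recovery_matrix` to keep only the rows that belong to the local rank.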