diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
index fe650e028910fc73b11319db482208847eef73e0..c13bd1da08e26c617844f705b5d3be70f2c17301 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
@@ -616,9 +616,6 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
         self._beta2 = beta_2
         self._epsilon = epsilon
         # var ref input
-        self._m_v = tf.Variable(initial_value=1.0, name="m_" + str(_SMALL_ADAMW_INDEX))
-        self._v_v = tf.Variable(initial_value=1.0, name="v_" + str(_SMALL_ADAMW_INDEX))
-        self._max_grad_norm_v = tf.Variable(initial_value=0.5, name="max_grad_norm_" + str(_SMALL_ADAMW_INDEX))
         self._beta1_power_v = tf.Variable(initial_value=0.9, name="beta1_power_" + str(_SMALL_ADAMW_INDEX))
         self._beta2_power_v = tf.Variable(initial_value=0.9, name="beta2_power_" + str(_SMALL_ADAMW_INDEX))
         # attr
@@ -635,6 +632,12 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
         self.bucket_size = -1
 
     def _prepare(self):
+        self._m_v = tf.Variable(tf.random_uniform([self.bucket_size, self.embedding_dim], minval=1.0, maxval=1.0),
+                                name="m_" + str(_SMALL_ADAMW_INDEX))
+        self._v_v = tf.Variable(tf.random_uniform([self.bucket_size, self.embedding_dim], minval=1.0, maxval=1.0),
+                                name="v_" + str(_SMALL_ADAMW_INDEX))
+        self._max_grad_norm_v = tf.Variable(tf.random_uniform([10, self.embedding_dim], minval=1.0, maxval=1.0),
+                                            name="max_grad_norm_" + str(_SMALL_ADAMW_INDEX))
         lr = self._call_if_callable(self._lr)
         weight_decay = self._call_if_callable(self._weight_decay)
         beta1 = self._call_if_callable(self._beta1)
@@ -649,30 +652,26 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
 
     def _resource_apply_sparse(self, grad, var, indices):
         if isinstance(var, NpuEmbeddingResourceV2):
-            result = gen_npu_cpu_ops.embedding_apply_adam_w(table_handle=var.handle,
-                                                            m=
-                                                            math_ops.cast(self._m_v, grad.dtype),
-                                                            v=
-                                                            math_ops.cast(self._v_v, grad.dtype),
-                                                            beta1_power=
-                                                            math_ops.cast(self._beta1_power_v, grad.dtype),
-                                                            beta2_power=
-                                                            math_ops.cast(self._beta2_power_v, grad.dtype),
-                                                            lr=math_ops.cast(self._lr_t, grad.dtype),
-                                                            weight_decay=
-                                                            math_ops.cast(self._weight_decay_t, grad.dtype),
-                                                            beta1=math_ops.cast(self._beta1_t, grad.dtype),
-                                                            beta2=math_ops.cast(self._beta2_t, grad.dtype),
-                                                            epsilon=math_ops.cast(self._epsilon_t, grad.dtype),
-                                                            grad=grad,
-                                                            keys=indices,
-                                                            max_grad_norm=
-                                                            math_ops.cast(self._max_grad_norm_v, grad.dtype),
-                                                            embedding_dim=self.embedding_dim,
-                                                            bucket_size=self.bucket_size,
-                                                            amsgrad=self._amsgrad,
-                                                            maximize=self._maximize
-                                                            )
+            result = gen_npu_cpu_ops.embedding_hash_table_apply_adam_w(table_handle=var.handle,
+                                                                       m=self._m_v,
+                                                                       v=self._v_v,
+                                                                       beta1_power=self._beta1_power_v,
+                                                                       beta2_power=self._beta2_power_v,
+                                                                       lr=math_ops.cast(self._lr_t, grad.dtype),
+                                                                       weight_decay=
+                                                                       math_ops.cast(self._weight_decay_t, grad.dtype),
+                                                                       beta1=math_ops.cast(self._beta1_t, grad.dtype),
+                                                                       beta2=math_ops.cast(self._beta2_t, grad.dtype),
+                                                                       epsilon=
+                                                                       math_ops.cast(self._epsilon_t, grad.dtype),
+                                                                       grad=grad,
+                                                                       keys=indices,
+                                                                       max_grad_norm=self._max_grad_norm_v,
+                                                                       embedding_dim=self.embedding_dim,
+                                                                       bucket_size=self.bucket_size,
+                                                                       amsgrad=self._amsgrad,
+                                                                       maximize=self._maximize
+                                                                       )
             return result
         else:
             raise TypeError("Variable is not NpuEmbeddingResourceV2 type, please check.")
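The hunks above move the AdamW moment buffers `m`/`v` (and `max_grad_norm`) out of `__init__`, where `bucket_size` and `embedding_dim` are still -1, into `_prepare`, where they can be allocated at the full table shape that `embedding_hash_table_apply_adam_w` consumes directly. A minimal sketch of that shape contract, with illustrative values (`bucket_size`, `embedding_dim`, and the variable names below are not taken from the patch):

```python
import tensorflow as tf

bucket_size, embedding_dim = 1024, 8  # illustrative; set per table before _prepare runs

# tf.random_uniform with minval == maxval acts as a shaped constant
# initializer: every moment slot starts at exactly 1.0.
m = tf.Variable(tf.random_uniform([bucket_size, embedding_dim], minval=1.0, maxval=1.0), name="m_0")
v = tf.Variable(tf.random_uniform([bucket_size, embedding_dim], minval=1.0, maxval=1.0), name="v_0")

# The new op takes m/v as-is (no per-step math_ops.cast), so their shape must
# match the [bucket_size, embedding_dim] layout of the hash table itself.
```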
diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_resource.py b/tf_adapter/python/npu_bridge/embedding/embedding_resource.py
index 9e47637d7d738a2544214373e05ac7e54ea3d3ff..6c9368116e10286bfb937fb32e9f44c6029c3991 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_resource.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_resource.py
@@ -46,7 +46,7 @@ class NpuEmbeddingResourceV2:
 
     def __init__(self, table_id):
         self.name = table_id
-        self._tensor = gen_npu_cpu_ops.table_to_resource_v2(ops.convert_to_tensor(table_id))
+        self._tensor = gen_npu_cpu_ops.table_to_resource_v2(ops.convert_to_tensor([table_id]))
 
     @property
     def handle(self):
diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
index 75b7d17fe4d562ce7c6b6a79b4a81f9bd5f018e8..d66a8008dfeefce291ad54948fabd29cf1541d4c 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
@@ -396,7 +396,7 @@ class ESWorker:
         table_id = self._small_hash_table_name_to_id[name]
         if table_id not in self._small_hash_table_id_list:
             raise ValueError("This hash table hash not yet initialized.")
-        table_handle = gen_npu_cpu_ops.table_to_resource_v2(table_id=table_id)
+        table_handle = gen_npu_cpu_ops.table_to_resource_v2(table_id=[table_id])
         result = gen_npu_cpu_ops.embedding_hash_table_lookup_or_insert(table_handle=table_handle,
                                                                        keys=key,
                                                                        bucket_size=
@@ -424,6 +424,128 @@ class ESWorker:
         self._small_hash_table_has_lookup.append(table_id)
         return result
 
+    def backward_update(self, loss):
+        lookup_values = []
+        key_values = []
+        for value in self._small_hash_table_lookup_result.values():
+            lookup_values.append(value)
+        for key in self._small_hash_table_to_lookup_key.values():
+            key_values.append(key)
+        table_id_values = self._small_hash_table_has_lookup
+        if (not isinstance(lookup_values, (list, tuple)) and not isinstance(table_id_values, (list, tuple))
+                and not isinstance(key_values, (list, tuple))):
+            lookup_values = [lookup_values]
+            table_id_values = [table_id_values]
+            key_values = [key_values]
+        if (len(lookup_values) != len(table_id_values)) or (len(lookup_values) != len(key_values)) \
+                or (len(table_id_values) != len(key_values)):
+            raise ValueError("The lengths of lookup_values, table_id_values and key_values must be equal.")
+        embedding_grads = tf.gradients(loss, lookup_values)
+        update_op = []
+        self._small_hash_table_lookup_result = {}
+        self._small_hash_table_to_lookup_key = {}
+        self._small_hash_table_has_lookup = []
+
+        for i in range(len(table_id_values)):
+            if embedding_grads[i] is None:
+                continue
+            params_grads = [tf.IndexedSlices(embedding_grads[i], key_values[i], dense_shape=lookup_values[i].shape)]
+            var_refs = [NpuEmbeddingResourceV2(table_id_values[i])]
+            update_op.append(
+                self._small_hash_table_to_optimizer.get(table_id_values[i])
+                .apply_gradients(list(zip(params_grads, var_refs))))
+        return update_op
+
+    def small_hash_table_export(self, file_path, table_name=None, step=0):
+        table_ids = []
+        if table_name is not None:
+            table_id = self._small_hash_table_name_to_id[table_name]
+            if table_id not in self._small_hash_table_id_list:
+                raise ValueError("This hash table has not yet been initialized.")
+            table_ids.append(table_id)
+        else:
+            table_ids = self._small_hash_table_id_list
+        embedding_dims = []
+        bucket_sizes = []
+        table_names = []
+        for table_id in table_ids:
+            embedding_dims.append(self._small_hash_table_to_embedding_dim.get(table_id))
+            bucket_sizes.append(self._small_hash_table_to_optimizer[table_id].bucket_size)
+            table_names.append(self._small_hash_table_id_to_name.get(table_id))
+        table_handles = gen_npu_cpu_ops.table_to_resource_v2(table_id=table_ids)
+        table_sizes = gen_npu_cpu_ops.embedding_hashmap_size(table_ids=
+                                                             ops.convert_to_tensor(table_ids, name="table_ids"),
+                                                             export_mode="all",
+                                                             filter_export_flag=True)
+        dynamic_keys, dynamic_counters, dynamic_filter_flags, dynamic_values = \
+            gen_npu_cpu_ops.embedding_hash_table_export(table_handles=table_handles,
+                                                        table_sizes=table_sizes,
+                                                        embedding_dims=
+                                                        ops.convert_to_tensor(embedding_dims,
+                                                                              name="embedding_dims", dtype=tf.int64),
+                                                        bucket_sizes=
+                                                        ops.convert_to_tensor(bucket_sizes,
+                                                                              name="bucket_sizes", dtype=tf.int64),
+                                                        export_mode="all",
+                                                        filter_export_flag=True,
+                                                        num=len(table_names))
+        result = gen_npu_cpu_ops.embedding_hashmap_export(file_path=ops.convert_to_tensor(file_path, name="file_path"),
+                                                          table_ids=ops.convert_to_tensor(table_ids, name="table_ids"),
+                                                          table_names=
+                                                          ops.convert_to_tensor(table_names, name="table_names"),
+                                                          global_step=ops.convert_to_tensor(step, name="export_step"),
+                                                          keys=dynamic_keys,
+                                                          counters=dynamic_counters,
+                                                          filter_flags=dynamic_filter_flags,
+                                                          values=dynamic_values)
+        return result
+
+    def small_hash_table_import(self, file_path, table_name=None, step=0):
+        table_ids = []
+        if table_name is not None:
+            table_id = self._small_hash_table_name_to_id[table_name]
+            if table_id not in self._small_hash_table_id_list:
+                raise ValueError("This hash table has not yet been initialized.")
+            table_ids.append(table_id)
+        else:
+            table_ids = self._small_hash_table_id_list
+        embedding_dims = []
+        bucket_sizes = []
+        table_names = []
+        for table_id in table_ids:
+            embedding_dims.append(self._small_hash_table_to_embedding_dim.get(table_id))
+            bucket_sizes.append(self._small_hash_table_to_optimizer[table_id].bucket_size)
+            table_names.append(self._small_hash_table_id_to_name.get(table_id))
+        table_handles = gen_npu_cpu_ops.table_to_resource_v2(table_id=table_ids)
+        table_sizes = \
+            gen_npu_cpu_ops.embedding_hashmap_file_size(file_path=ops.convert_to_tensor(file_path, name="file_path"),
+                                                        table_ids=ops.convert_to_tensor(table_ids, name="table_ids"),
+                                                        table_names=
+                                                        ops.convert_to_tensor(table_names, name="table_names"),
+                                                        global_step=ops.convert_to_tensor(step, name="import_step"),
+                                                        embedding_dims=embedding_dims)
+        dynamic_keys, dynamic_counters, dynamic_filter_flags, dynamic_values = \
+            gen_npu_cpu_ops.embedding_hashmap_import(file_path=ops.convert_to_tensor(file_path, name="file_path"),
+                                                     table_ids=ops.convert_to_tensor(table_ids, name="table_ids"),
+                                                     table_sizes=table_sizes,
+                                                     table_names=ops.convert_to_tensor(table_names, name="table_names"),
+                                                     global_step=ops.convert_to_tensor(step, name="import_step"),
+                                                     embedding_dims=embedding_dims,
+                                                     num=len(table_names))
+        result = \
+            gen_npu_cpu_ops.embedding_hash_table_import(table_handles=table_handles,
+                                                        embedding_dims=
+                                                        ops.convert_to_tensor(embedding_dims,
+                                                                              name="embedding_dims", dtype=tf.int64),
+                                                        bucket_sizes=
+                                                        ops.convert_to_tensor(bucket_sizes,
+                                                                              name="bucket_sizes", dtype=tf.int64),
+                                                        keys=dynamic_keys,
+                                                        counters=dynamic_counters,
+                                                        filter_flags=dynamic_filter_flags,
+                                                        values=dynamic_values)
+        return result
+
     # new version
     # 提供embedding lookup功能
     # @param name str 类型
diff --git a/tf_adapter/python/npu_bridge/embedding/tf_path.py b/tf_adapter/python/npu_bridge/embedding/tf_path.py
index a5717e652ec960fa3a849471e9d02d74ff0c58da..3e71bf9d9bd5b028bf9a3a81d3df3f60d2ace3fe 100644
--- a/tf_adapter/python/npu_bridge/embedding/tf_path.py
+++ b/tf_adapter/python/npu_bridge/embedding/tf_path.py
@@ -15,12 +15,14 @@
 # limitations under the License.
 # ==============================================================================
 
+import tensorflow as tf
 from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops import math_ops
 from tensorflow.python.training import optimizer as embeddingOptimizer
-from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource
+from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource, NpuEmbeddingResourceV2
 
 
 class _NpuEmbeddingResourceProcessor(embeddingOptimizer._OptimizableVariable):
@@ -36,6 +38,21 @@ class _NpuEmbeddingResourceProcessor(embeddingOptimizer._OptimizableVariable):
         return optimizer._resource_apply_sparse(g.values, self._v, g.indices)
 
 
+class _NpuEmbeddingResourceProcessorV2(embeddingOptimizer._OptimizableVariable):
+    """Processor for NpuEmbeddingResourceV2."""
+
+    def __init__(self, v):
+        self._v = v
+
+    def target(self):
+        return self._v
+
+    def update_op(self, optimizer, g):
+        uniqued_key, key_position = tf.unique(g.indices)
+        summed_grad = math_ops.unsorted_segment_sum(g.values, key_position, tf.shape(uniqued_key)[0])
+        return optimizer._resource_apply_sparse(summed_grad, self._v, uniqued_key)
+
+
 def _get_processor(v):
     """The processor of v."""
     if context.executing_eagerly():
@@ -45,6 +62,8 @@ def _get_processor(v):
         return embeddingOptimizer._DenseResourceVariableProcessor(v)
     if isinstance(v, NpuEmbeddingResource):
         return _NpuEmbeddingResourceProcessor(v)
+    if isinstance(v, NpuEmbeddingResourceV2):
+        return _NpuEmbeddingResourceProcessorV2(v)
     if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
         # True if and only if `v` was initialized eagerly.
         return embeddingOptimizer._DenseResourceVariableProcessor(v)
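Taken together, these hunks route NpuEmbeddingResourceV2 gradients through the stock TF1 optimizer machinery: ESWorker records each small-hash-table lookup, backward_update differentiates the loss against the recorded lookup results, _NpuEmbeddingResourceProcessorV2 sums the gradients of duplicate keys, and the per-table optimizer ends in embedding_hash_table_apply_adam_w. A rough training-loop sketch under these assumptions; the ESWorker construction arguments and the lookup entry point are not shown in this patch, so both are hypothetical here:

```python
import tensorflow as tf
from npu_bridge.embedding.embedding_service import ESWorker

es = ESWorker()  # hypothetical construction; real init args are outside this patch

keys = tf.constant([3, 7, 3], dtype=tf.int64)
# Hypothetical lookup entry point: whichever ESWorker method runs
# embedding_hash_table_lookup_or_insert and records the result and keys.
emb = es.small_hash_table_lookup("user_table", keys)  # method name is illustrative
loss = tf.reduce_sum(emb)

# Gradients w.r.t. each recorded lookup become IndexedSlices; duplicate keys
# (key 3 above) are merged by _NpuEmbeddingResourceProcessorV2 before AdamW.
train_op = es.backward_update(loss)

# Checkpointing of all initialized small hash tables, per the patch's API.
save_op = es.small_hash_table_export(file_path="/tmp/es_ckpt", step=100)
restore_op = es.small_hash_table_import(file_path="/tmp/es_ckpt", step=100)
```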