diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
index 7d6eb1536e922f9c2a4c3b59000cef677d619d4c..3880b631f6fdbafacdc580d4c0d5d6bf3afded7c 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
@@ -618,6 +618,9 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
         # var ref input
         self._beta1_power_v = tf.Variable(initial_value=0.9, name="beta1_power_" + str(_SMALL_ADAMW_INDEX))
         self._beta2_power_v = tf.Variable(initial_value=0.9, name="beta2_power_" + str(_SMALL_ADAMW_INDEX))
+        self.m_v = None
+        self.v_v = None
+        self.max_grad_norm_v = None
         # attr
         self._amsgrad = amsgrad
         self._maximize = maximize
@@ -632,13 +635,6 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
         self.bucket_size = -1
 
     def _prepare(self):
-        self._m_v = tf.Variable(tf.random_uniform([self.bucket_size, self.embedding_dim], minval=1.0, maxval=1.0),
-                                name="m_" + str(_SMALL_ADAMW_INDEX))
-        self._v_v = tf.Variable(tf.random_uniform([self.bucket_size, self.embedding_dim], minval=1.0, maxval=1.0),
-                                name="v_" + str(_SMALL_ADAMW_INDEX))
-        self._max_grad_norm_v = \
-            tf.Variable(tf.random_uniform([self.bucket_size, self.embedding_dim], minval=1.0, maxval=1.0),
-                        name="max_grad_norm_" + str(_SMALL_ADAMW_INDEX))
         lr = self._call_if_callable(self._lr)
         weight_decay = self._call_if_callable(self._weight_decay)
         beta1 = self._call_if_callable(self._beta1)
@@ -654,8 +650,8 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
     def _resource_apply_sparse(self, grad, var, indices):
         if isinstance(var, NpuEmbeddingResourceV2):
             result = gen_npu_cpu_ops.embedding_hash_table_apply_adam_w(table_handle=var.handle,
-                                                                       m=self._m_v,
-                                                                       v=self._v_v,
+                                                                       m=self.m_v,
+                                                                       v=self.v_v,
                                                                        beta1_power=self._beta1_power_v,
                                                                        beta2_power=self._beta2_power_v,
                                                                        lr=math_ops.cast(self._lr_t, grad.dtype),
@@ -667,7 +663,7 @@ class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
                                                                        math_ops.cast(self._epsilon_t, grad.dtype),
                                                                        grad=grad,
                                                                        keys=indices,
-                                                                       max_grad_norm=self._max_grad_norm_v,
+                                                                       max_grad_norm=self.max_grad_norm_v,
                                                                        embedding_dim=self.embedding_dim,
                                                                        bucket_size=self.bucket_size,
                                                                        amsgrad=self._amsgrad,
diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
index db0c0c0da8ec1087c01252673c024337c8059d1a..1d2bf1bdd5747427e401e7c3ad4180a0a43343fa 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py
@@ -1119,6 +1119,15 @@ class ESWorker:
         self._small_hash_table_to_optimizer[table_id] = optimizer
         self._small_hash_table_to_optimizer[table_id].embedding_dim = embedding_dim
         self._small_hash_table_to_optimizer[table_id].bucket_size = init_vocabulary_size
+        self._small_hash_table_to_optimizer[table_id].m_v = \
+            tf.Variable(tf.random_uniform([init_vocabulary_size, embedding_dim], minval=1.0, maxval=1.0),
+                        name="m_" + str(_SMALL_ADAMW_INDEX))
+        self._small_hash_table_to_optimizer[table_id].v_v = \
+            tf.Variable(tf.random_uniform([init_vocabulary_size, embedding_dim], minval=1.0, maxval=1.0),
+                        name="v_" + str(_SMALL_ADAMW_INDEX))
+        self._small_hash_table_to_optimizer[table_id].max_grad_norm_v = \
+            tf.Variable(tf.random_uniform([init_vocabulary_size, embedding_dim], minval=1.0, maxval=1.0),
+                        name="max_grad_norm_" + str(_SMALL_ADAMW_INDEX))
 
     def _check_and_update_small_init_params(self, name,
                                             init_vocabulary_size, embedding_dim, multihot_lens, key_dtype, value_dtype, allow_merge, initializer):
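
Note on the pattern this diff applies: the optimizer can no longer allocate its AdamW state in _prepare(), because at that point bucket_size and embedding_dim are still the -1 placeholders set in __init__; only ESWorker knows the real table shape, so allocation moves there and the optimizer just declares the slots as None. Below is a minimal sketch of that deferred-allocation flow under TF1-compat APIs; the names SketchAdamW and attach_state are hypothetical stand-ins, not the adapter's actual classes. (Since tf.random_uniform with minval == maxval == 1.0 always yields ones, the sketch uses tf.ones for the same initial value.)

# Minimal sketch of the deferred-allocation pattern in the diff above,
# assuming only TF1-compat APIs. SketchAdamW / attach_state are
# hypothetical names, not the adapter's real API.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


class SketchAdamW:
    def __init__(self):
        # At construction time the table shape is unknown, so the state
        # buffers are declared empty and filled in by the owner later,
        # mirroring the new __init__ code in embedding_optimizer.py.
        self.embedding_dim = -1
        self.bucket_size = -1
        self.m_v = None              # first-moment buffer
        self.v_v = None              # second-moment buffer
        self.max_grad_norm_v = None  # amsgrad max-grad-norm buffer


def attach_state(opt, vocab_size, dim):
    # Mirrors the new ESWorker code in embedding_service.py: once the real
    # vocabulary size and embedding dim are known, allocate [vocab, dim]
    # variables and hand them to the optimizer.
    opt.embedding_dim = dim
    opt.bucket_size = vocab_size
    opt.m_v = tf.Variable(tf.ones([vocab_size, dim]), name="m_sketch")
    opt.v_v = tf.Variable(tf.ones([vocab_size, dim]), name="v_sketch")
    opt.max_grad_norm_v = tf.Variable(tf.ones([vocab_size, dim]),
                                      name="max_grad_norm_sketch")


opt = SketchAdamW()
attach_state(opt, vocab_size=1024, dim=16)
# From here, _resource_apply_sparse-style code can safely read opt.m_v,
# opt.v_v and opt.max_grad_norm_v instead of private _prepare()-time slots.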