diff --git a/tf_adapter/python/npu_bridge/embedding/__init__.py b/tf_adapter/python/npu_bridge/embedding/__init__.py
index 29a8ca0bb9b5fcb6fa12325400f9b99e22f5bc85..fa8b91fd5cbffda44daa736a7459f1bdc3978e48 100644
--- a/tf_adapter/python/npu_bridge/embedding/__init__.py
+++ b/tf_adapter/python/npu_bridge/embedding/__init__.py
@@ -22,6 +22,7 @@ from npu_bridge.embedding.embedding_optimizer import AdamWOptimizer as EmbeddingAdamWOptimizer
 from npu_bridge.embedding.embedding_optimizer import SgdOptimizer as EmbeddingSgdOptimizer
 from npu_bridge.embedding.embedding_optimizer import RmspropOptimizer as EmbeddingRmspropOptimizer
 from npu_bridge.embedding.embedding_optimizer import FtrlOptimizer as EmbeddingFtrlOptimizer
+from npu_bridge.embedding.embedding_optimizer import EmbeddingHashTableAdamWOptimizer as AdamWOptimizer
 from npu_bridge.embedding.embedding_optimizer import exponential_decay_lr as exponential_decay_lr
 from npu_bridge.embedding.embedding_service import ESWorker as EmbeddingService
 from npu_bridge.embedding.embedding_service import es_initializer as es_initializer
diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
index 111117acccaebb58c7dcd7821bb2d2c568631e60..7d6eb1536e922f9c2a4c3b59000cef677d619d4c 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py
@@ -26,7 +26,7 @@ from tensorflow.python.training import adam
 from tensorflow.python.training import adagrad
 from tensorflow.python.training import training_ops
 from tensorflow.python.training import training_util
-from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource
+from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource, NpuEmbeddingResourceV2
 from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops
 
 _GLOBAL_STEP_VALUE = 1
@@ -34,6 +34,7 @@ _ADAM_BEAT1_POWER_VALUE = 0.9
 _ADAM_BEAT2_POWER_VALUE = 0.99
 _ADAMW_BEAT1_POWER_VALUE = 0.9
 _ADAMW_BEAT2_POWER_VALUE = 0.99
+_SMALL_ADAMW_INDEX = 0
 
 
 class AdamOptimizer(adam.AdamOptimizer):
@@ -590,6 +591,93 @@ class FtrlOptimizer(optimizer.Optimizer):
         raise TypeError("Variable is not NpuEmbeddingResource type, please check.")
 
 
+class EmbeddingHashTableAdamWOptimizer(optimizer.Optimizer):
+    """An Adam optimizer with decoupled ("correct") weight decay for embedding hash tables."""
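+
+    # Illustrative usage (a sketch, not part of the public docs): embedding_dim and
+    # bucket_size must be set to the target table's shape before gradients are
+    # applied, as ESWorker._update_small_hash_table_dict does in embedding_service.py:
+    #   opt = EmbeddingHashTableAdamWOptimizer(learning_rate=0.001, weight_decay=0.01)
+    #   opt.embedding_dim = 16
+    #   opt.bucket_size = 1024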
+
+    def __init__(self,
+                 learning_rate=0.01,
+                 weight_decay=0.004,
+                 beta_1=0.9,
+                 beta_2=0.999,
+                 epsilon=1e-6,
+                 amsgrad: bool = False,
+                 maximize: bool = False,
+                 name="EmbeddingHashTableAdamWOptimizer"):
+        """Construct an AdamW optimizer."""
+        super(EmbeddingHashTableAdamWOptimizer, self).__init__(False, name)
+        if (learning_rate is None) or (weight_decay is None) or (beta_1 is None) or (beta_2 is None):
+            raise ValueError("learning_rate, weight_decay, beta_1 and beta_2 can not be None.")
+        if (epsilon is None) or (amsgrad is None) or (maximize is None):
+            raise ValueError("epsilon, amsgrad and maximize can not be None.")
+        # const input
+        self._lr = learning_rate
+        self._weight_decay = weight_decay
+        self._beta1 = beta_1
+        self._beta2 = beta_2
+        self._epsilon = epsilon
+        # var ref input: beta1^t / beta2^t power accumulators updated during training
+        self._beta1_power_v = tf.Variable(initial_value=0.9, name="beta1_power_" + str(_SMALL_ADAMW_INDEX))
+        self._beta2_power_v = tf.Variable(initial_value=0.9, name="beta2_power_" + str(_SMALL_ADAMW_INDEX))
+        # attr
+        self._amsgrad = amsgrad
+        self._maximize = maximize
+        # Tensor versions of the constructor arguments, created in _prepare()
+        self._lr_t = None
+        self._weight_decay_t = None
+        self._beta1_t = None
+        self._beta2_t = None
+        self._epsilon_t = None
+        # shape of the optimizer state; must be set by the caller before apply_gradients
+        self.embedding_dim = -1
+        self.bucket_size = -1
+
+    def _prepare(self):
+        # m/v moments and max_grad_norm state, one row per bucket, constant-initialized to 1.0
+        self._m_v = tf.Variable(tf.ones([self.bucket_size, self.embedding_dim]),
+                                name="m_" + str(_SMALL_ADAMW_INDEX))
+        self._v_v = tf.Variable(tf.ones([self.bucket_size, self.embedding_dim]),
+                                name="v_" + str(_SMALL_ADAMW_INDEX))
+        self._max_grad_norm_v = tf.Variable(tf.ones([self.bucket_size, self.embedding_dim]),
+                                            name="max_grad_norm_" + str(_SMALL_ADAMW_INDEX))
+        lr = self._call_if_callable(self._lr)
+        weight_decay = self._call_if_callable(self._weight_decay)
+        beta1 = self._call_if_callable(self._beta1)
+        beta2 = self._call_if_callable(self._beta2)
+        epsilon = self._call_if_callable(self._epsilon)
+
+        self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
+        self._weight_decay_t = ops.convert_to_tensor(weight_decay, name="weight_decay")
+        self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
+        self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
+        self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
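+
+    # The sparse update is delegated to the fused gen_npu_cpu_ops op below: the
+    # table itself is addressed through the NpuEmbeddingResourceV2 handle, while
+    # the m/v moments and beta-power accumulators live in the variables created
+    # in __init__/_prepare above.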
""" diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_resource.py b/tf_adapter/python/npu_bridge/embedding/embedding_resource.py index 74fe6087c6613c1d6cc16c97144e61be51b66d0a..6c9368116e10286bfb937fb32e9f44c6029c3991 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_resource.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_resource.py @@ -41,3 +41,26 @@ class NpuEmbeddingResource: def device(self): return self._tensor.op.device + +class NpuEmbeddingResourceV2: + + def __init__(self, table_id): + self.name = table_id + self._tensor = gen_npu_cpu_ops.table_to_resource_v2(ops.convert_to_tensor([table_id])) + + @property + def handle(self): + return self._tensor + + @property + def graph(self): + return self._tensor.graph + + @property + def op(self): + return self._tensor.op + + @property + def device(self): + return self._tensor.op.device + diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index 6084ac94d0c12869bf41dee332326623c2c8b2e2..c0cdc79128fef2a56b73bef93cdfd092df910dc9 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -30,7 +30,7 @@ from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops from npu_bridge.hccl.hccl_ops import allgather from hccl.manage.api import create_group from hccl.manage.api import set_ps_table_num -from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource +from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource, NpuEmbeddingResourceV2 from npu_bridge.embedding import embedding_optimizer from npu_bridge.embedding.embedding_table_map_policy import NoneTableMapPolicy, AutoMergeTableMapPolicy from npu_bridge.embedding.embedding_utils import EmbeddingVariableOption, CounterFilter, PaddingParamsOption, \ @@ -104,6 +104,23 @@ def es_initializer(initializer_mode, min=-2.0, max=2.0, constant_value=0.0, mu=0 seed=seed) +def check_small_hashtable_init_params(name, init_vocabulary_size, embedding_dim, max_feature_count, initializer_mode): + if (name is None) or (init_vocabulary_size is None) or (embedding_dim is None): + raise ValueError("table name, init_vocabulary_size and embedding_dim can not be None.") + if not isinstance(name, str): + raise TypeError("embedding table name must be string.") + regex = re.compile('[@!#$%^&*()<>?/\|}{~:]') + if regex.search(name) is not None: + raise ValueError("table name contains illegal character.") + if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)) or \ + (not isinstance(max_feature_count, int)): + raise ValueError("init_vocabulary_size, embedding_dim and max_feature_count must be int.") + if init_vocabulary_size <= 0 or max_feature_count <= 0 or embedding_dim <= 0: + raise ValueError("init_vocabulary_size and max_feature_count and embedding_dim must be greater than zero.") + if (initializer_mode is not None) and (initializer_mode != "random") and (initializer_mode != "constant"): + raise TypeError("initializer_mode must be random or constant") + + class ESWorker: """ Embedding service class. 
""" @@ -186,6 +203,8 @@ class ESWorker: self._use_completion_key = False self._table_id_to_completion_option = {} self._user_group_set = set() + # test for david small table + self._init_small_hashtable_params() # 提供 embedding_service table initializer method # table_id embedding 表索引, int 类型 @@ -862,9 +881,88 @@ class ESWorker: steps_to_live=self._steps_to_live) return tf.group([embedding_table_evict]) + def get_embedding_small_variable(self, name, init_vocabulary_size, embedding_dim, max_feature_count, + initializer_mode="constant", constant_value=1.0, load_factor=0.8, + optimizer=None, ev_option=None): + if name not in self._small_hash_table_has_init: + table_id = self._small_hash_table_count + self._small_hash_table_name_to_id[name] = table_id + self._small_hash_table_id_to_name[table_id] = name + self._small_hash_table_count += 1 + self._small_hash_table_has_init.append(name) + else: + raise ValueError("This small hashtable has been initialized.") + check_small_hashtable_init_params(name=name, init_vocabulary_size=init_vocabulary_size, + embedding_dim=embedding_dim, max_feature_count=max_feature_count, + initializer_mode=initializer_mode) + self._update_small_hash_table_dict(table_id=table_id, embedding_dim=embedding_dim, + max_feature_count=max_feature_count, + init_vocabulary_size=init_vocabulary_size, optimizer=optimizer, + ev_option=ev_option) + self._init_embedding_hashmap_v2[table_id] = \ + gen_npu_cpu_ops.init_embedding_hashmap_v2(table_id=table_id, + bucket_size=init_vocabulary_size, + embedding_dim=embedding_dim, + load_factor=load_factor, + dtype=tf.float32) + init_constant_value = constant_value + if initializer_mode is "constant": + sampled_values = ops.convert_to_tensor(1.0, tf.float32) + else: + sampled_values = tf.random.stateless_uniform(shape=[init_vocabulary_size, embedding_dim], + seed=[42, 1234], + minval=0.0, + maxval=1.0, + dtype=tf.float32) + self._init_embedding_hash_table[table_id] = \ + gen_npu_cpu_ops.init_embedding_hash_table(table_handle=self._init_embedding_hashmap_v2.get(table_id), + sampled_values=sampled_values, + bucket_size=init_vocabulary_size, + embedding_dim=embedding_dim, + initializer_mode=initializer_mode, + constant_value=init_constant_value) + return self._init_embedding_hash_table[table_id] + + def forward_lookup(self, name, key): + table_id = self._small_hash_table_name_to_id[name] + if table_id not in self._small_hash_table_id_list: + raise ValueError("This hash table hash not yet initialized.") + table_handle = gen_npu_cpu_ops.table_to_resource_v2(table_id=[table_id]) + result = gen_npu_cpu_ops.embedding_hash_table_lookup_or_insert(table_handle=table_handle, + keys=key, + bucket_size= + self._small_hash_table_to_bucket_size + .get(table_id), + embedding_dim= + self._small_hash_table_to_embedding_dim + .get(table_id), + filter_mode= + self._small_hash_table_to_filter_mode + .get(table_id), + filter_freq= + self._small_hash_table_to_counter_filter + .get(table_id).filter_freq, + default_key_or_value= + self._small_hash_table_to_counter_filter + .get(table_id).default_key_or_value, + default_key= + self._small_hash_table_to_counter_filter + .get(table_id).default_key, + default_value= + self._small_hash_table_to_counter_filter + .get(table_id).default_value) + self._small_hash_table_to_lookup_key[table_id] = key + self._small_hash_table_to_lookup_result[table_id] = result + self._small_hash_table_has_lookup.append(table_id) + return result + def _update_config_params(self): env_dist = os.environ + rank_size = env_dist.get("RANK_SIZE") + 
 rank_id continued above
         cluster_config_file = env_dist.get("ESCLUSTER_CONFIG_PATH")
+        # no ES cluster env configured: nothing to parse, run without cluster config
+        if (cluster_config_file is None) and (rank_size is None) and (rank_id is None):
+            return
         if cluster_config_file is None:
             raise ValueError("EsClusterConfig env is null, check your env config.")
         with open(cluster_config_file, encoding='utf-8') as b:
@@ -885,6 +983,47 @@
         if self._server_ip_to_ps_num[each_server_ps_num] > 4:
             raise ValueError("PS num of one server can not exceed 4, please check config params.")
 
+    def _init_small_hashtable_params(self):
+        self._small_hash_table_name_to_id = {}
+        self._small_hash_table_id_to_name = {}
+        self._small_hash_table_count = 0
+        self._small_hash_table_has_init = []
+        self._small_hash_table_id_list = []
+        self._small_hash_table_to_embedding_dim = {}
+        self._small_hash_table_to_key_num = {}
+        self._small_hash_table_to_bucket_size = {}
+        self._small_hash_table_to_optimizer = {}
+        self._small_hash_table_to_filter_mode = {}
+        self._small_hash_table_to_counter_filter = {}
+        # op
+        self._init_embedding_hashmap_v2 = {}
+        self._init_embedding_hash_table = {}
+        # for lookup
+        self._small_hash_table_lookup_result = {}
+        self._small_hash_table_to_lookup_result = {}
+        self._small_hash_table_to_lookup_key = {}
+        self._small_hash_table_has_lookup = []
+
+    def _update_small_hash_table_dict(self, table_id, embedding_dim, max_feature_count,
+                                      init_vocabulary_size, optimizer, ev_option):
+        self._small_hash_table_to_embedding_dim[table_id] = embedding_dim
+        self._small_hash_table_to_key_num[table_id] = max_feature_count
+        self._small_hash_table_to_bucket_size[table_id] = init_vocabulary_size
+        self._small_hash_table_id_list.append(table_id)
+
+        if (ev_option is not None) and (ev_option.filter_option is not None):
+            self._small_hash_table_to_filter_mode[table_id] = "counter"
+            self._small_hash_table_to_counter_filter[table_id] = ev_option.filter_option
+        else:
+            # no filtering requested: fall back to a pass-through CounterFilter
+            self._small_hash_table_to_filter_mode[table_id] = "no_filter"
+            self._small_hash_table_to_counter_filter[table_id] = CounterFilter(filter_freq=1,
+                                                                               default_key_or_value=False,
+                                                                               default_key=1,
+                                                                               default_value=1.0)
+        self._small_hash_table_to_optimizer[table_id] = optimizer
+        if optimizer is not None:
+            # propagate the table shape so the optimizer can build its state in _prepare()
+            optimizer.embedding_dim = embedding_dim
+            optimizer.bucket_size = init_vocabulary_size
 
     def _check_and_update_small_init_params(self, name, init_vocabulary_size, embedding_dim, multihot_lens,
                                             key_dtype, value_dtype, allow_merge, initializer):
         if name not in self._small_table_name_list: