From 2f73f943029823e65b1c5ec53d7d46dbb7ed03df Mon Sep 17 00:00:00 2001 From: songmingyang001 Date: Tue, 26 Aug 2025 20:29:42 +0800 Subject: [PATCH] add embedding hash table api --- .../python/npu_bridge/npu_cpu/npu_cpu_ops.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index e5abba177..64c0a221e 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -25,6 +25,29 @@ from npu_bridge.helper import helper gen_npu_cpu_ops = helper.get_gen_ops() +## 提供device侧FeatureMapping LookupOrInsert功能 +# @param table_handle int64 类型 +# @param keys int64 类型 +# @param bucket_size int 类型 +# @param embedding_dim int 类型 +# @param filter_mode string 类型 +# @param filter_freq int 类型 +# @param default_key_or_value bool 类型 +# @param default_key int 类型 +# @param default_value float 类型 +# @param filter_key_flag bool 类型 +# @param filter_key int 类型 +# @return values float 类型 +def embedding_hashtable_lookup_or_insert(table_handle, keys, bucket_size, embedding_dim, filter_mode, filter_freq, + default_key_or_value, default_key, default_value, filter_key_flag, filter_key): + """ device embedding feature mapping lookup or insert. """ + result = gen_npu_cpu_ops.EmbeddingHashTableLookupOrInsert( + table_handle=table_handle, keys=keys, bucket_size=bucket_size, embedding_dim=embedding_dim, + filter_mode=filter_mode, filter_freq=filter_freq, default_key_or_value=default_key_or_value, + default_key=default_key, default_value=default_value, filter_key_flag=filter_key_flag, filter_key=filter_key) + return result + + ## 提供embeddingrankid功能 # @param addr_tensor tensorflow的tensor类型,embeddingrankid操作的输入; # @param index tensorflow的tensor类型,embeddingrankid操作的输入; @@ -467,4 +490,82 @@ def embedding_hashmap_import_v2(file_path, table_ids, table_sizes, table_names, result = gen_npu_cpu_ops.EmbeddingHashmapImport( file_path=file_path, table_ids=table_ids, table_sizes=table_sizes, table_names=table_names, global_step=global_step, embedding_dims=embedding_dims, num=num) + return result + + +## EmbeddingHashTable Init功能 +# @param table_handle int64 类型 +# @param sampled_values float 类型 +# @param bucket_size int 类型 +# @param embedding_dim int 类型 +# @param initializer_mode string 类型 +# @param constant_value int 类型 +def init_embedding_hashtable(table_handle, sampled_values, bucket_size, embedding_dim, initializer_mode, + constant_value): + """ device init embedding hashtable. """ + result = gen_npu_cpu_ops.InitEmbeddingHashTable( + table_handle=table_handle, sampled_values=sampled_values, bucket_size=bucket_size, embedding_dim=embedding_dim, + initializer_mode=initializer_mode, constant_value=constant_value) + return result + + +## 提供host侧hashTable导入功能 +# @param table_handles int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param keys int64 类型 +# @param counters uint64 类型 +# @param filter_flags uint8 类型 +# @param values float 类型 +def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, keys, counters, filter_flags, values): + """ host embedding feature hash table import. """ + result = gen_npu_cpu_ops.EmbeddingHashTableImport( + table_handles=table_handles, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + keys=keys, counters=counters, filter_flags=filter_flags, values=values) + return result + + +## 提供host侧hashTable导出功能 +# @param table_handles int64 类型 +# @param table_sizes int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param export_mode string 类型 +# @param filtered_export_flag bool 类型 +def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, bucket_sizes, export_mode='all', + filter_export_flag=False): + """ host embedding feature hash table export. """ + result = gen_npu_cpu_ops.EmbeddingHashTableExport( + table_handles=table_handles, table_sizes=table_sizes, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + export_mode=export_mode, filter_export_flag=filter_export_flag) + return result + + +## EmbeddingHashTableApplyAdamW AdamW 更新功能 +# @param table_handle int64 类型 +# @param keys int64 类型 +# @param m float16, float32 类型 +# @param v float16, float32 类型 +# @param beta1_power float16, float32 类型 +# @param beta2_power float16, float32 类型 +# @param lr float16, float32 类型 +# @param weight_decay float16, float32 类型 +# @param beta1 float16, float32 类型 +# @param beta2 float16, float32 类型 +# @param epsilon float16, float32 类型 +# @param grad float16, float32 类型 +# @param max_grad_norm float16, float32 类型 +# @param embedding_dim int 类型 +# @param bucket_size int 类型 +# @param amsgrad bool 类型 +# @param maximize bool 类型 +def embedding_hashtable_apply_adam_w(table_handle, keys, m, v, beta1_power, beta2_power, lr, weight_decay, + beta1, beta2, epsilon, grad, max_grad_norm, embedding_dim, + bucket_size, amsgrad, maximize): + """ device update embedding hashtable using AdamW. """ + result = gen_npu_cpu_ops.EmbeddingHashTableApplyAdamW( + table_handle=table_handle, keys=keys, m=m, v=v, beta1_power=beta1_power, beta2_power=beta2_power, + lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2, epsilon=epsilon, grad=grad, + max_grad_norm=max_grad_norm, embedding_dim=embedding_dim, bucket_size=bucket_size, + amsgrad=amsgrad, maximize=maximize) return result \ No newline at end of file -- Gitee