diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index bb5e99e7f50ebb25804b2a9c598c1fcfa554043f..91913fd56cc367f55872e5bcf376ee36107dd8c5 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -527,6 +527,8 @@ REGISTER_OP("EmbeddingHashTableLookupOrInsert") .Attr("default_key_or_value:bool = false") .Attr("default_key: int = 0") .Attr("default_value: float = 0.0") + .Attr("filter_key_flag: bool = false") + .Attr("filter_key: int = -1") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { int64 num = 0; diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index e7c5fa61fc2d1410d1f819c7b0217d7bc56542a3..8b79fb817ae0e86670d57f382928e41c63cb8911 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -37,12 +37,12 @@ gen_npu_cpu_ops = helper.get_gen_ops() # @param default_value float 类型 # @return values float 类型 def embedding_hashtable_lookup_or_insert(table_handle, keys, bucket_size, embedding_dim, filter_mode, filter_freq, - default_key_or_value, default_key, default_value): + default_key_or_value, default_key, default_value, filter_key_flag, filter_key): """ device embedding feature mapping lookup or insert. """ result = gen_npu_cpu_ops.EmbeddingHashTableLookupOrInsert( table_handle=table_handle, keys=keys, bucket_size=bucket_size, embedding_dim=embedding_dim, filter_mode=filter_mode, filter_freq=filter_freq, default_key_or_value=default_key_or_value, - default_key=default_key, default_value=default_value) + default_key=default_key, default_value=default_value, filter_key_flag=filter_key_flag, filter_key=filter_key) return result