diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index b109301e70b05c79f8334d588d6457a2315fddf9..79fb15aaefa8e381b530b7ba199a45a5f089e6fa 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -469,7 +469,7 @@ class ESWorker: # @param name str 类型 # @param ids int64 类型 # @return values float32 类型 - def embedding_lookup_v2(self, name: str, ids: typing.Any): + def embedding_lookup(self, name: str, ids: typing.Any): """ Operator for look up in embedding table. """ if (name is None) or (ids is None): raise ValueError("table name or ids must be specified.") @@ -536,7 +536,7 @@ class ESWorker: # @param table_id int32 类型 # @param input_ids int64 类型 # @return values float32 类型 - def embedding_lookup(self, table_id: int, input_ids: typing.Any): + def embedding_lookup_v1(self, table_id: int, input_ids: typing.Any): """ Operator for look up in embedding table. """ if (table_id is None) or (input_ids is None): raise ValueError("table_id or input_ids must be specified.") @@ -598,7 +598,7 @@ class ESWorker: # new version # 提供embedding update功能 # @param loss 类型 - def embedding_update_v2(self, loss): + def embedding_update(self, loss): """ Operator for update in embedding table. """ params = self._ps_table_lookup_result input_ids_list = self._ps_table_lookup_key @@ -640,7 +640,7 @@ class ESWorker: # @param params float32 类型 # @param table_ids int32 类型 # @param input_ids_list int64 类型 - def embedding_update(self, loss, params, table_ids, input_ids_list): + def embedding_update_v1(self, loss, params, table_ids, input_ids_list): """ Operator for update in embedding table. """ if (loss is None) or (params is None) or (table_ids is None) or (input_ids_list is None): raise ValueError("loss or params or table_ids or input_ids_list is None.") @@ -809,7 +809,7 @@ class ESWorker: output_slots[sid] = out_emb return output_slots - def save_embedding_v2(self, name: str, path: str): + def save_embedding(self, name: str, path: str): """ Operator for save values in table_id embedding table. """ if path is None or name is None: raise ValueError("table name, embedding table path can not be None.") @@ -873,7 +873,7 @@ class ESWorker: table_name=self._ps_table_name_list) return tf.group([embedding_table_export]) - def restore_embedding_v2(self, name: str, path: str): + def restore_embedding(self, name: str, path: str): if path is None or name is None: raise ValueError("table name, embedding table path can not be None.") if path[-1] == '/': @@ -919,7 +919,7 @@ class ESWorker: table_name=self._ps_table_name_list) return tf.group([embedding_table_import]) - def save_checkpoint_v2(self, name: str, path: str): + def save_checkpoint(self, name: str, path: str): """ Operator for save values and optimizer params in table_id embedding table. """ if path is None or name is None: raise ValueError("table name, embedding table path can not be None.") @@ -1000,7 +1000,7 @@ class ESWorker: table_name=self._ps_table_name_list) return tf.group([embedding_compute_var_export]) - def restore_checkpoint_v2(self, name: str, path: str): + def restore_checkpoint(self, name: str, path: str): """ Operator for restore values and optimizer params in table_id embedding table. """ if path is None or name is None: raise ValueError("name, embedding table path can not be None.") @@ -1069,7 +1069,7 @@ class ESWorker: table_name=self._ps_table_name_list) return tf.group([embedding_compute_var_import]) - def save_incremental_embedding_v2(self, name: str, path: str): + def save_incremental_embedding(self, name: str, path: str): """ Operator for save incremental values in table_id embedding table. """ if path is None or name is None: raise ValueError("table name, embedding table path can not be None.") @@ -1133,7 +1133,7 @@ class ESWorker: table_name=self._ps_table_name_list) return tf.group([embedding_table_export]) - def restore_incremental_embedding_v2(self, name: str, path: str): + def restore_incremental_embedding(self, name: str, path: str): if path is None or name is None: raise ValueError("table name, embedding table path can not be None.") if path[-1] == '/': @@ -1180,7 +1180,7 @@ class ESWorker: return tf.group([embedding_table_import]) # old version - def save_embedding(self, path: str, table_id: int): + def save_embedding_v1(self, path: str, table_id: int): """ Operator for save values in table_id embedding table. """ if path is None or table_id is None: raise ValueError("table_id, embedding table path can not be None.") @@ -1210,7 +1210,7 @@ class ESWorker: table_name=[self._table_id_to_name.get(table_id)]) return tf.group([embedding_table_export]) - def restore_embedding(self, path: str, table_id: int): + def restore_embedding_v1(self, path: str, table_id: int): if path is None or table_id is None: raise ValueError("table_id, embedding table path can not be None.") if path[-1] == '/': @@ -1231,7 +1231,7 @@ class ESWorker: table_name=[self._table_id_to_name.get(table_id)]) return tf.group([embedding_table_import]) - def save_checkpoint(self, path: str, table_id: int): + def save_checkpoint_v1(self, path: str, table_id: int): """ Operator for save values and optimizer params in table_id embedding table. """ if path is None or table_id is None: raise ValueError("table_id, embedding table path can not be None.") @@ -1269,7 +1269,7 @@ class ESWorker: table_name=[self._table_id_to_name.get(table_id)]) return tf.group([embedding_compute_var_export]) - def restore_checkpoint(self, path: str, table_id: int): + def restore_checkpoint_v1(self, path: str, table_id: int): """ Operator for restore values and optimizer params in table_id embedding table. """ if path is None or table_id is None: raise ValueError("table_id, embedding table path can not be None.") @@ -1302,7 +1302,7 @@ class ESWorker: table_name=[self._table_id_to_name.get(table_id)]) return tf.group([embedding_compute_var_import]) - def save_incremental_embedding(self, path: str, table_id: int): + def save_incremental_embedding_v1(self, path: str, table_id: int): """ Operator for save incremental values in table_id embedding table. """ if path is None or table_id is None: raise ValueError("table_id, embedding table path can not be None.") @@ -1332,7 +1332,7 @@ class ESWorker: table_name=[self._table_id_to_name.get(table_id)]) return tf.group([embedding_table_export]) - def restore_incremental_embedding(self, path: str, table_id: int): + def restore_incremental_embedding_v1(self, path: str, table_id: int): if path is None or table_id is None: raise ValueError("table_id, embedding table path can not be None.") if path[-1] == '/':