diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index db292eb3c9a342c5d9fed34ab23a96b374a58c54..e55321ef6abf5765cefea562a86c01d1e4579ce7 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -493,8 +493,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_export(file_path=file_path_tensor, ps_id=ps_id_tensor, table_id=table_id_tensor, - embedding_dim=embedding_dim_list[0], - value_total_len=embedding_dim_list[0], + embedding_dim=embedding_dim_list, + value_total_len=embedding_dim_list, export_mode="all", only_var_flag=True, file_type="bin") @@ -535,8 +535,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_import(ps_id=ops.convert_to_tensor(-1), file_path=ops.convert_to_tensor(path), table_id=ops.convert_to_tensor(table_id_list), - embedding_dim=embedding_dim_list[0], - value_total_len=embedding_dim_list[0], + embedding_dim=embedding_dim_list, + value_total_len=embedding_dim_list, only_var_flag=True, file_type="bin") return tf.group([embedding_table_import]) @@ -602,8 +602,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_export(file_path=file_path_tensor, ps_id=ps_id_tensor, table_id=table_id_tensor, - embedding_dim=embedding_dim_list[0], - value_total_len=value_total_len_list[0], + embedding_dim=embedding_dim_list, + value_total_len=value_total_len_list, export_mode="all", only_var_flag=False, file_type="bin") @@ -666,8 +666,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_import(ps_id=ps_id_tensor, file_path=file_path_tensor, table_id=table_id_tensor, - embedding_dim=embedding_dim_list[0], - value_total_len=value_total_len_list[0], + embedding_dim=embedding_dim_list, + value_total_len=value_total_len_list, only_var_flag=False, file_type="bin") with tf.control_dependencies([embedding_table_import]): @@ -729,8 +729,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_export(file_path=file_path_tensor, ps_id=ps_id_tensor, table_id=table_id_tensor, - embedding_dim=embedding_dim_list[0], - value_total_len=embedding_dim_list[0], + embedding_dim=embedding_dim_list, + value_total_len=embedding_dim_list, export_mode="new", only_var_flag=True, file_type="bin") @@ -771,8 +771,8 @@ class ESWorker: gen_npu_cpu_ops.embedding_table_import(ps_id=ops.convert_to_tensor(-1), file_path=ops.convert_to_tensor(path), table_id=ops.convert_to_tensor(table_id_list), - embedding_dim=embedding_dim_list[0], - value_total_len=embedding_dim_list[0], + embedding_dim=embedding_dim_list, + value_total_len=embedding_dim_list, only_var_flag=True, file_type="bin") return tf.group([embedding_table_import])