diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000000000000000000000000000000000..21ce891d8cf6fce50abd343d9c18fbc01eaae474 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "files.associations": { + "chrono": "cpp" + } +} \ No newline at end of file diff --git a/tf_adapter/interface_spec/api_npu_config.pyh b/tf_adapter/interface_spec/api_npu_config.pyh index c89bc3c6bc63a016b3ccfa1f9abc353a98e12088..815cd35a9ab02fdf6a1c827f25a53b67bcff2099 100644 --- a/tf_adapter/interface_spec/api_npu_config.pyh +++ b/tf_adapter/interface_spec/api_npu_config.pyh @@ -12,14 +12,14 @@ class NPURunConfig(run_config_lib.RunConfig): enable_exception_dump=0, op_select_implmode=None, optypelist_for_implmode=None, dynamic_input_config=None, aoe_mode=None, work_path=None, buffer_optimize="l2_optimize", enable_small_channel=0, fusion_switch_file=None, enable_compress_weight=False, compress_weight_conf=None, - op_compiler_cache_mode=None, op_compiler_cache_dir=None, debug_dir=None, hcom_multi_mode=False, dynamic_input=False, + op_compiler_cache_mode=None, op_compiler_cache_dir=None, debug_dir=None, hcom_multi_mode=False, dynamic_input=None, dynamic_graph_execute_mode="dynamic_execute", dynamic_inputs_shape_range=None, train_distribute=None, eval_distribute=None, local_rank_id=None, local_device_list=None, session_device_id=None, distribute_config=None, modify_mixlist=None, op_precision_mode=None, device_type="default_device_type", soc_config=None, hccl_timeout=None, op_wait_timeout=None, op_execute_timeout=None, HCCL_algorithm=None, customize_dtypes=None, op_debug_config=None, memory_config=None, experimental_config=None, topo_sorting_mode=None, aoe_config_file=None, insert_op_file=None, stream_sync_timeout=-1, - event_sync_timeout=-1, external_weight=False, es_cluster_config=None, deterministic=0, + event_sync_timeout=-1, external_weight=False, deterministic=0, frozen_variable=False, variable_placement="Device"): class ProfilingConfig(): diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc index b46c6e0e889834f6dbfe293e1124a353cd0ac55d..f661ffdc5fb200125370940d2cfdcd10d54e6763 100644 --- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc +++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc @@ -597,8 +597,8 @@ class HostQueueDatasetOp : public DatasetOpKernel { ADP_LOG(INFO) << "Slave SendDataThread exit."; } - void RecordMbufQueueBytes(const bool is_hold_type, const uint64_t args_total_bytes) { - if (!is_hold_type) { return; } + void RecordMbufQueueBytes(const bool is_hold, const uint64_t args_total_bytes) { + if (!is_hold) { return; } mbuf_queue_rear_ = (mbuf_queue_rear_ + 1) % kStringTypeDepth; mbuf_queue_bytes_[mbuf_queue_rear_] = args_total_bytes; } @@ -629,7 +629,7 @@ class HostQueueDatasetOp : public DatasetOpKernel { Status status = Status::OK(); bool is_need_resend = false; - while(!finish_send_) { + while (!finish_send_) { if (IsHoldDataTrans()) { auto start = std::chrono::high_resolution_clock::now(); auto end = start + std::chrono::microseconds(kSleepDuration); diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index e2fdb4301a837108be98a6cb05d7f30a4597aebe..fc8538f00695a7934b6f5816fa10bdaae01b8553 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -326,10 +326,6 @@ void GeOp::Initialize(OpKernelConstruction *ctx) { } ctx->GetAttr("_recompute_mode", &recompute_mode_); - ctx->GetAttr("_deploy_inject_config", &deploy_inject_config_); - ctx->GetAttr("_execute_times", &execute_times_); - ctx->GetAttr("_max_num", &max_num_); - ctx->GetAttr("_embedding_dim", &embedding_dim_); ctx->GetAttr("_dynamic_input", &dynamic_input_); if (!dynamic_input_.empty() && dynamic_input_ == "1") { jit_compile_ = true; @@ -349,9 +345,7 @@ void GeOp::Initialize(OpKernelConstruction *ctx) { << ", getnext_inputs_shape_range: " << getnext_inputs_shape_range_ << ", data_inputs_shape_range: " << data_inputs_shape_range_ << ", is_train_graph: " << is_train_graph_ << ", is_dynamic_getnext: " << is_dynamic_getnext_ << ", placeholder_index: " << placeholder_index_ - << ", is_var_init_graph: " << is_var_init_graph_ << ", deploy_inject_config: " << deploy_inject_config_ - << ", execute_times: " << execute_times_ << ", max_num: " << max_num_ - << ", embedding_dim: " << embedding_dim_; + << ", is_var_init_graph: " << is_var_init_graph_; // global environment Initialize, invoke once for each process std::string sess_config = ""; @@ -865,18 +859,6 @@ void GeOp::ComputeAsync(OpKernelContext *ctx, DoneCallback done) { if (!recompute_mode_.empty()) { graph_options_["ge.recompute"] = recompute_mode_; } - if (!deploy_inject_config_.empty()) { - graph_options_["ge.exec.clusterSpec"] = deploy_inject_config_; - } - if (!execute_times_.empty()) { - graph_options_["ge.execute_times"] = execute_times_; - } - if (!max_num_.empty()) { - graph_options_["ge.max_num"] = max_num_; - } - if (!embedding_dim_.empty()) { - graph_options_["ge.embedding_dim"] = embedding_dim_; - } SetDynamicInput(); graph_options_["ge.exec.isVarInitGraph"] = is_var_init_graph_; graph_options_["ge.jit_compile"] = jit_compile_ ? "1" : "0"; diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index 93a4681462b4a9d5357a7742a4385b0faca72672..f2b970c1b57bc4036d2536047b1b597e405f1ddd 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -190,10 +190,6 @@ private: std::atomic_flag tuned_flag_; std::vector> remove_index_; std::string is_var_init_graph_; - std::string deploy_inject_config_; - std::string execute_times_; - std::string max_num_; - std::string embedding_dim_; std::string recompute_mode_; std::vector> input_shapes_vec_; bool jit_compile_; diff --git a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc index 13588e8c9b048a3a7a49f2c3cfcb7ef87f11dd09..66fd524ad2083f60c1533076b4dac1df719fa85b 100644 --- a/tf_adapter/optimizers/om_partition_subgraphs_pass.cc +++ b/tf_adapter/optimizers/om_partition_subgraphs_pass.cc @@ -2000,10 +2000,6 @@ void OMPartitionSubgraphsPass::GetGraphConfig(const Node &node, bool enable_dp, const std::string kDynamicInputsShapeRange = "_graph_dynamic_inputs_shape_range"; const std::string kIsTrainGraph = "_is_train_graph"; const std::string kRecomputeMode = "_recompute_mode"; - const std::string kDeployInjectConfig = "_deploy_inject_config"; - const std::string kExecuteTimes = "_execute_times"; - const std::string kMaxNum = "_max_num"; - const std::string kEmbeddingDim = "_embedding_dim"; if (node_attrs.find(kDynamicInput) != node_attrs.end()) { bool dynamic_input = node_attrs.at(kDynamicInput).b(); graph_options["dynamic_input"] = std::to_string(static_cast(dynamic_input)); @@ -2024,21 +2020,6 @@ void OMPartitionSubgraphsPass::GetGraphConfig(const Node &node, bool enable_dp, std::string recompute_mode = node_attrs.at(kRecomputeMode).s(); graph_options["recompute_mode"] = recompute_mode; } - if (node_attrs.find(kDeployInjectConfig) != node_attrs.end()) { - graph_options["deploy_inject_config"] = node_attrs.at(kDeployInjectConfig).s(); - } - if (node_attrs.find(kExecuteTimes) != node_attrs.end()) { - const auto execute_times = node_attrs.at(kExecuteTimes).i(); - graph_options["execute_times"] = std::to_string(static_cast(execute_times)); - } - if (node_attrs.find(kMaxNum) != node_attrs.end()) { - const auto max_num = node_attrs.at(kMaxNum).i(); - graph_options["max_num"] = std::to_string(static_cast(max_num)); - } - if (node_attrs.find(kEmbeddingDim) != node_attrs.end()) { - const auto embedding_dim = node_attrs.at(kEmbeddingDim).i(); - graph_options["embedding_dim"] = std::to_string(static_cast(embedding_dim)); - } } Status OMPartitionSubgraphsPass::ProcessGetNext(Node &node, const std::string enable_dp, @@ -2263,7 +2244,9 @@ Status OMPartitionSubgraphsPass::ProcessGraph(std::unique_ptr *graph, Fun return Status::OK(); } if (mix_compile_mode) { - TF_RETURN_IF_ERROR(CopyVarsBetweenGeOp(graph_in)); + if (pass_options["variable_location"] != "Host") { + TF_RETURN_IF_ERROR(CopyVarsBetweenGeOp(graph_in)); + } TF_RETURN_IF_ERROR(CopyConstBetweenGeOp(graph_in)); } diff --git a/tf_adapter/python/npu_bridge/embedding/__init__.py b/tf_adapter/python/npu_bridge/embedding/__init__.py deleted file mode 100644 index f4cdb646bdf64b19f3293b49e2a9ca71b53a2be4..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/embedding/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - - -from npu_bridge.embedding.embedding_optimizer import AdamOptimizer as EmbeddingAdamOptimizer -from npu_bridge.embedding.embedding_optimizer import AdagradOptimizer as EmbeddingAdagradOptimizer -from npu_bridge.embedding.embedding_service import ESWorker as EmbeddingService -from npu_bridge.embedding.tf_path import path_on_tf -path_on_tf() \ No newline at end of file diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py b/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py deleted file mode 100644 index 0d596fed7eb104f104bd63e076d57531ad4d0bfc..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/embedding/embedding_optimizer.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from tensorflow.python.framework import ops -from tensorflow.python.eager import context -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import adam -from tensorflow.python.training import adagrad -from tensorflow.python.training import training_ops -from tensorflow.python.training import training_util -from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource -from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops - -_GLOBAL_STEP_VALUE = 1 - - -class AdamOptimizer(adam.AdamOptimizer): - @property - def embedding_dims(self): - return self._embedding_dims - - @embedding_dims.setter - def embedding_dims(self, val): - self._embedding_dims = val - - def _prepare(self): - lr = self._call_if_callable(self._lr) - epsilon = self._call_if_callable(self._epsilon) - self._beta1_t_list = [] - self._beta2_t_list = [] - self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") - self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") - - def _resource_apply_sparse(self, grad, var, indices): - if isinstance(var, NpuEmbeddingResource): - beta1 = self._call_if_callable(self._beta1) - beta2 = self._call_if_callable(self._beta2) - self._beta1_t = ops.convert_to_tensor(beta1, name="beta1" + str(self.table_idx)) - self._beta2_t = ops.convert_to_tensor(beta2, name="beta2" + str(self.table_idx)) - self._beta1_t_list.append(self._beta1_t) - self._beta2_t_list.append(self._beta2_t) - beta1_power, beta2_power = self._get_beta_accumulators() - self.table_idx += 1 - return gen_npu_cpu_ops.embedding_apply_adam(var.handle, beta1_power, beta2_power, - math_ops.cast(self._lr_t, grad.dtype), - math_ops.cast(self._beta1_t, grad.dtype), - math_ops.cast(self._beta2_t, grad.dtype), - math_ops.cast(self._epsilon_t, grad.dtype), - grad, - indices, - ops.convert_to_tensor(_GLOBAL_STEP_VALUE), - self._embedding_dims) - else: - return self._apply_sparse_shared(grad, var, indices, self._resource_scatter_add) - - def _create_slots(self, var_list): - self.table_num = 0 - self.table_idx = 0 - first_var = min(var_list, key=lambda x: x.name) - for idx in range(len(var_list)): - self._create_non_slot_variable( - initial_value=self._beta1, name="beta1_power" + str(idx), colocate_with=first_var) - self._create_non_slot_variable( - initial_value=self._beta2, name="beta2_power" + str(idx), colocate_with=first_var) - self.table_num += 1 - - for v in var_list: - if not isinstance(v, NpuEmbeddingResource): - self._zeros_slot(v, "m", self._name) - self._zeros_slot(v, "v", self._name) - - def _get_beta_accumulators(self): - with ops.init_scope(): - if context.executing_eagerly(): - graph = None - else: - graph = ops.get_default_graph() - return (self._get_non_slot_variable("beta1_power" + str(self.table_idx), graph=graph), - self._get_non_slot_variable("beta2_power" + str(self.table_idx), graph=graph)) - - def _finish(self, update_ops, name_scope): - # Update the power accumulators. - self.table_num = 0 - self.table_idx = 0 - finish_output = [] - with ops.control_dependencies(update_ops): - beta1_power_list = [] - beta2_power_list = [] - for k in update_ops: - beta1_power, beta2_power = self._get_beta_accumulators() - beta1_power_list.append(beta1_power) - beta2_power_list.append(beta2_power) - self.table_idx += 1 - for idx in range(len(update_ops)): - beta1_power = beta1_power_list[idx] - beta2_power = beta2_power_list[idx] - with ops.colocate_with(beta1_power): - update_beta1 = beta1_power.assign( - beta1_power * self._beta1_t_list[idx], use_locking=self._use_locking) - update_beta2 = beta2_power.assign( - beta2_power * self._beta2_t_list[idx], use_locking=self._use_locking) - new_update_op = [] - new_update_op.append(update_ops[idx]) - finish_output.append(control_flow_ops.group( - *new_update_op + [update_beta1, update_beta2], name=name_scope + str(idx))) - return finish_output - - -class AdagradOptimizer(adagrad.AdagradOptimizer): - @property - def embedding_dims(self): - return self._embedding_dims - - @embedding_dims.setter - def embedding_dims(self, val): - self._embedding_dims = val - - def _resource_apply_sparse(self, grad, var, indices): - if isinstance(var, NpuEmbeddingResource): - return gen_npu_cpu_ops.embedding_apply_ada_grad(var.handle, - math_ops.cast(self._learning_rate_tensor, grad.dtype), - grad, - indices, - ops.convert_to_tensor(_GLOBAL_STEP_VALUE), - self._embedding_dims) - else: - return self.training_ops.resource_sparse_apply_adagrad(var.handle, grad.handle, - math_ops.cast(self._learning_rate_tensor, - grad.dtype), - grad, indices, - use_locking=self._use_locking) - - def _create_slots(self, var_list): - for v in var_list: - if not isinstance(v, NpuEmbeddingResource): - dtype = v.dtype.base_dtype - if v.get_shape().is_fully_defined(): - init = init_ops.constant_initializer(self._initial_accumulator_value, - dtype=dtype) - else: - init = self._init_constant_op(v, dtype) - self._get_or_make_slot_with_initializer(v, init, v.get_shape(), dtype, - "accumulator", self._name) diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_resource.py b/tf_adapter/python/npu_bridge/embedding/embedding_resource.py deleted file mode 100644 index fdfc288bc52c88d1367d32e0efa4344bb6f12e4e..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/embedding/embedding_resource.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from tensorflow.python.framework import ops -from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops - - -class NpuEmbeddingResource: - - def __init__(self, table_id): - self.name = table_id - self._tensor = gen_npu_cpu_ops.table_to_resource(ops.convert_to_tensor(table_id)) - - @property - def handle(self): - return self._tensor - - @property - def graph(self): - return self._tensor.graph - - @property - def op(self): - return self._tensor.op - diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py deleted file mode 100644 index a661e11c99bb6ca81f36700cc1e5b9c270e2181b..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ /dev/null @@ -1,363 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import json -import contextlib -import os -import math -import tensorflow as tf -from tensorflow.python.framework import ops -from tensorflow.python.framework import random_seed -from tensorflow.core.framework import attr_value_pb2 -from npu_bridge.npu_cpu.npu_cpu_ops import gen_npu_cpu_ops -from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource -from npu_bridge.embedding import embedding_optimizer -from npu_bridge.embedding.embedding_table_map_policy import NoneTableMapPolicy, AutoMergeTableMapPolicy - -_INT32_MAX_VALUE = 2147483647 - - -@contextlib.contextmanager -def specified_ps_engine_scope(): - """ - Enable the non npu compilation of operators within the scope. - """ - attrs = { - "_process_node_engine_id": attr_value_pb2.AttrValue(s=tf.compat.as_bytes("PS")) - } - with ops.get_default_graph()._attr_scope(attrs): - yield - - -class ESWorker: - """ Embedding service class. """ - - def __init__(self, config_from_param=None): - env_dist = os.environ - cluster_config_from_env = env_dist.get("ESCLUSTER_CONFIG_PATH") - if cluster_config_from_env is None: - if config_from_param is None: - raise ValueError("EsClusterConfig and env variable are both null.") - es_cluster_config = config_from_param - else: - es_cluster_config = cluster_config_from_env - with open(es_cluster_config, encoding='utf-8') as a: - es_cluster_config_json = json.load(a) - self._es_cluster_conf = json.dumps(es_cluster_config_json) - self._ps_num = int(es_cluster_config_json["psNum"]) - self._embedding_dim = -1 - self._max_num = -1 - self._ps_ids = [] - self._ps_ids_list = es_cluster_config_json["psCluster"] - self._init_embedding_hash_maps = {} - self._init_partition_maps = {} - self._table_to_embedding_dim = {} - for each_ps in self._ps_ids_list: - self._ps_ids.append(each_ps["id"]) - self._train_mode = True - self._train_level = False - self._optimizer = None - self.slot_vars_num = None - self._initializer = None - self._init_flag = False - self._table_has_init = [] - self.user_defined_table_infos = [] - self.table_map_policy = None - self.table_create_infos = [] - self.total_variable_table = [] - self.total_embedding_count = 0 - config = tf.ConfigProto() - custom_op = config.graph_options.rewrite_options.custom_optimizers.add() - custom_op.name = "NpuOptimizer" - custom_op.parameter_map["es_cluster_config"].s = tf.compat.as_bytes(self._es_cluster_conf) - self.es_all_config = config - - # 提供embedding init功能 - # @param vocabulary_size int 类型 - # @param file_path string 类型 - # @param file_name string 类型 - # @param table_id int32 类型 - # @param max_batch_size int32 类型 - # @param optimizer 类型 - # @param initializer string 类型 - # @param embedding_dim int32 类型 - # @param only_var bool 类型 - # @param mode string 类型 - # @param partition_num int 类型 - def embedding_init(self, vocabulary_size, file_path, file_name, table_id, max_batch_size, optimizer=None, - initializer=None, embedding_dim=-1, only_var=False, mode="bin", partition_num=65537): - """ Operator for init embedding table. """ - if vocabulary_size is None or table_id is None or max_batch_size is None or embedding_dim is None: - raise ValueError("vocabulary_size or table_id or max_batch_size or embedding_dim is None.") - if (not isinstance(vocabulary_size, int)) or (not isinstance(table_id, int)) or \ - (not isinstance(max_batch_size, int)) or (not isinstance(embedding_dim, int)): - raise ValueError("vocabulary_size, table_id, max_batch_size and embedding_dim must be int.") - if vocabulary_size < 0 or table_id < 0: - raise ValueError("vocabulary_size and table_id can not be smaller than zero.") - if vocabulary_size >= _INT32_MAX_VALUE or table_id >= _INT32_MAX_VALUE: - raise ValueError("vocabulary_size or table_id exceed int32 max value.") - if embedding_dim <= 0 or partition_num <= 0 or max_batch_size <= 0: - raise ValueError("embedding_dim, partition_num and max_batch_size must be greater than zero.") - if table_id in self._table_has_init: - raise ValueError("this table has already initialized.") - self._embedding_dim = embedding_dim - self._max_num = max_batch_size - self._table_to_embedding_dim[table_id] = embedding_dim - self._initializer = initializer - self._table_has_init.append(table_id) - bucket_size = math.ceil(vocabulary_size / self._ps_num) - if optimizer is None: - if file_path is None or file_name is None or (not tf.gfile.Exists(os.path.join(file_path, file_name))): - raise ValueError("embedding table file not exist.") - self._train_mode = False - self.slot_vars_num = 0 - else: - if (not isinstance(optimizer, embedding_optimizer.AdamOptimizer) and - not isinstance(optimizer, embedding_optimizer.AdagradOptimizer)): - raise ValueError( - "optimizer should be embedding_optimizer.AdamOptimizer or embedding_optimizer.AdagradOptimizer") - if (initializer is not None) and (initializer is not 'random_uniform') and \ - (initializer is not 'truncated_normal'): - raise ValueError("initializer must be random_uniform or truncated_normal.") - self._optimizer = optimizer - self._optimizer._embedding_dims = embedding_dim - # adam include m and v, 2 slots; adagrad include accumulator, 1 slot - self.slot_vars_num = 2 if isinstance(self._optimizer, embedding_optimizer.AdamOptimizer) else 1 - if (file_path is None) or (file_name is None) or (not tf.gfile.Exists(os.path.join(file_path, file_name))): - if initializer is None: - raise ValueError("In new embedding training, initializer can not be None.") - self._train_level = True - with specified_ps_engine_scope(): - self._init_partition_maps[table_id] = \ - gen_npu_cpu_ops.init_partition_map(ps_num=ops.convert_to_tensor(self._ps_num), - ps_ids=ops.convert_to_tensor(self._ps_ids), - partition_num=partition_num) - self._init_partition_maps.get(table_id)._set_attr("_execute_times", attr_value_pb2.AttrValue(i=1)) - self._init_partition_maps.get(table_id)._set_attr("_embedding_dim", - attr_value_pb2.AttrValue(i=self._embedding_dim)) - self._init_partition_maps.get(table_id)._set_attr("_max_num", attr_value_pb2.AttrValue(i=self._max_num)) - self._init_partition_maps.get(table_id)._set_attr("_deploy_inject_config", - attr_value_pb2.AttrValue( - s=tf.compat.as_bytes(self._es_cluster_conf))) - return self._init_hashmap_and_table_import(bucket_size, file_path, file_name, table_id, - initializer, embedding_dim, only_var, mode) - - # 提供embedding lookup功能 - # @param table_id int32 类型 - # @param input_ids int64 类型 - # @return values float32 类型 - def embedding_lookup(self, table_id, input_ids): - """ Operator for look up in embedding table. """ - if (table_id is None) or (input_ids is None): - raise ValueError("table_id or input_ids must be specified.") - if not isinstance(table_id, int): - raise ValueError("type of table_id must be int.") - if input_ids.dtype != tf.int64: - raise ValueError("dtype of input_ids must be tf.int64.") - if table_id < 0: - raise ValueError("table_id can not be smaller than zero.") - if not self._init_flag: - raise ValueError("embedding must init first!") - if table_id not in self._table_has_init: - raise ValueError("this table has not yet initialized.") - if self._train_mode: - seed1, seed2 = random_seed.get_seed(None) - result = gen_npu_cpu_ops.embedding_table_find_and_init(table_id=ops.convert_to_tensor(table_id), - keys=input_ids, - embedding_dim= - self._table_to_embedding_dim.get(table_id), - random_alg=self._initializer, - seed=seed1, seed2=seed2, - value_total_len= - self._table_to_embedding_dim.get(table_id) * - (self.slot_vars_num + 1) - ) - else: - result = gen_npu_cpu_ops.embedding_table_find(table_id=ops.convert_to_tensor(table_id), - keys=input_ids, - embedding_dim=self._table_to_embedding_dim.get(table_id)) - result.op._set_attr("_embedding_dim", attr_value_pb2.AttrValue(i=self._embedding_dim)) - result.op._set_attr("_max_num", attr_value_pb2.AttrValue(i=self._max_num)) - result.op._set_attr("_deploy_inject_config", - attr_value_pb2.AttrValue(s=tf.compat.as_bytes(self._es_cluster_conf))) - return result - - # 提供embedding update功能 - # @param loss 类型 - # @param params float32 类型 - # @param table_ids int32 类型 - # @param input_ids_list int64 类型 - def embedding_update(self, loss, params, table_ids, input_ids_list): - """ Operator for update in embedding table. """ - if (loss is None) or (params is None) or (table_ids is None) or (input_ids_list is None): - raise ValueError("loss or params or table_ids or input_ids_list is None.") - if (isinstance(loss, str)) or (isinstance(params, str)) or isinstance(table_ids, str) or \ - isinstance(input_ids_list, str): - raise ValueError("loss, params, table_ids and input_ids_list can not be str.") - if not self._init_flag: - raise ValueError("embedding must init first!") - if (not isinstance(params, (list, tuple)) and not isinstance(table_ids, (list, tuple)) - and not isinstance(input_ids_list, (list, tuple))): - params = [params] - table_ids = [table_ids] - input_ids_list = [input_ids_list] - for table_id in table_ids: - if table_id not in self._table_has_init: - raise ValueError("this table has not yet initialized.") - if (len(params) != len(table_ids)) or (len(params) != len(input_ids_list)) \ - or (len(table_ids) != len(input_ids_list)): - raise ValueError("The length of params, table_ids, input_ids_list should be equal.") - embedding_grads = tf.gradients(loss, params) - params_grads = [] - for i in range(len(embedding_grads)): - params_grads.append(tf.IndexedSlices(embedding_grads[i], input_ids_list[i], dense_shape=params[i].shape)) - with specified_ps_engine_scope(): - var_refs = [NpuEmbeddingResource(table_id) for table_id in table_ids] - update_op = self._optimizer.apply_gradients(list(zip(params_grads, var_refs))) - return update_op - - # 提供训练好的embedding values save功能 - # @param file_path string 类型 - # @param file_name string 类型 - # @param table_id int32 类型 - # @param mode string 类型 - def embedding_save(self, file_path, file_name, table_id, mode="bin"): - """ Operator for save values in embedding table. """ - if file_path is None or file_name is None or table_id is None: - raise ValueError("table_id, embedding table file_name and file_path can not be None.") - if table_id not in self._table_has_init: - raise ValueError("this table has not yet initialized.") - if not os.path.exists(file_path): - os.mkdir(file_path) - with specified_ps_engine_scope(): - embedding_dim = self._table_to_embedding_dim.get(table_id) - return gen_npu_cpu_ops.embedding_table_export(file_path, file_name, ops.convert_to_tensor(-1), table_id, - embedding_dim, embedding_dim, True, mode) - - # 提供训练好的embedding values + 调优参数 save功能 - # @param file_path string 类型 - # @param file_name string 类型 - # @param table_id int32 类型 - # @param mode string 类型 - def embedding_ckpt_save(self, file_path, file_name, table_id, mode="bin"): - """ Operator for save values and optimizer params in embedding table. """ - if file_path is None or file_name is None or table_id is None: - raise ValueError("table_id, embedding table file_name and file_path can not be None.") - if table_id not in self._table_has_init: - raise ValueError("this table has not yet initialized.") - if not os.path.exists(file_path): - os.mkdir(file_path) - with specified_ps_engine_scope(): - embedding_dim = self._table_to_embedding_dim.get(table_id) - return gen_npu_cpu_ops.embedding_table_export(file_path, file_name, ops.convert_to_tensor(-1), table_id, - embedding_dim, embedding_dim * (self.slot_vars_num + 1), - False, mode) - - def data_parallel_embedding(self, max_vocabulary_size, embedding_dim, multihot_lens, allow_merge=True): - if not isinstance(multihot_lens, list): - raise ValueError("multihot_lens must be list.") - new_table_info = dict( - max_vocabulary_size=max_vocabulary_size, - embedding_dim=embedding_dim, - multihot_lens=multihot_lens, - allow_merge=allow_merge - ) - self.user_defined_table_infos.append(new_table_info) - - def init_table(self, table_map_policy=AutoMergeTableMapPolicy()): - self.table_map_policy = table_map_policy - self.table_create_infos = self.table_map_policy.map_table_infos(self.user_defined_table_infos) - for table_info_ in self.table_create_infos: - self.total_variable_table.append(tf.Variable( - tf.random_normal([table_info_['max_vocabulary_size'], table_info_['embedding_dim']], mean=0.0, - stddev=1.0, dtype=tf.float32, seed=1234) - )) - self.total_embedding_count += 1 - - def embeddings_look_up(self, tf_indices): - if self.total_embedding_count != len(self.table_create_infos): - raise ValueError("Must init_table() first!") - (in_slot_size_group, slot_to_table, table_to_input_group, \ - table_to_slot, table_to_output_slots) = \ - (self.table_map_policy.in_slot_size_group, self.table_map_policy.slot_to_table, \ - self.table_map_policy.table_to_input_groups, self.table_map_policy.table_to_slot, \ - self.table_map_policy.table_to_output_slots) - - indices_split = tf.split(tf_indices, in_slot_size_group, axis=1) - for tid in range(self.total_embedding_count): - table_to_input_group[tid] = [] - for sid, indices in enumerate(indices_split): - tid = slot_to_table[sid] - table_to_input_group[tid].append(indices) - - output_slots = [None for _ in in_slot_size_group] - for tid, table_input_group in enumerate(table_to_input_group): - table_input_after_mapping = \ - gen_npu_cpu_ops.embedding_feature_mapping(feature_id=tf.concat(table_input_group, axis=1)) - table_to_input_group[tid] = table_input_after_mapping - table_embedding = tf.nn.embedding_lookup(self.total_variable_table[tid], table_input_after_mapping) - out_embedding_splited = tf.split(table_embedding, table_to_output_slots[tid], axis=1) - for out_emb, sid in zip(out_embedding_splited, table_to_slot[tid]): - output_slots[sid] = out_emb - return output_slots - - def _init_hashmap_and_table_import(self, bucket_size, file_path, file_name, table_id, - initializer, embedding_dim, only_var, mode): - with tf.control_dependencies([self._init_partition_maps.get(table_id)]): - if self._train_mode: - if self._train_level: - seed1, seed2 = random_seed.get_seed(None) - self._init_embedding_hash_maps[table_id] = \ - gen_npu_cpu_ops.init_embedding_hashmap(table_id=ops.convert_to_tensor(table_id), - bucket_size=bucket_size, - value_total_len=embedding_dim * (self.slot_vars_num + 1), - embedding_dim=embedding_dim, - random_alg=initializer, seed=seed1, seed2=seed2) - else: - self._init_embedding_hash_maps[table_id] = \ - gen_npu_cpu_ops.init_embedding_hashmap(table_id=ops.convert_to_tensor(table_id), - bucket_size=bucket_size, - value_total_len=embedding_dim * (self.slot_vars_num + 1), - embedding_dim=embedding_dim, - random_alg=None, seed=None, seed2=None) - else: - self._init_embedding_hash_maps[table_id] = \ - gen_npu_cpu_ops.init_embedding_hashmap(table_id=ops.convert_to_tensor(table_id), - bucket_size=bucket_size, - value_total_len=embedding_dim, - embedding_dim=embedding_dim, - random_alg=None, seed=None, seed2=None) - self._init_flag = True - return self._init_or_restore(file_path, file_name, table_id, embedding_dim, only_var, mode) - - def _init_or_restore(self, file_path, file_name, table_id, embedding_dim, only_var, mode): - if self._train_mode and self._train_level: - return tf.group( - [tf.initializers.variables(self._optimizer.variables()), self._init_embedding_hash_maps.get(table_id)]) - # restore embedding table - with tf.control_dependencies([self._init_embedding_hash_maps.get(table_id)]): - embedding_table_import = gen_npu_cpu_ops.embedding_table_import( - file_path=ops.convert_to_tensor(file_path), - file_name=ops.convert_to_tensor(file_name), - # ps_id will be changed in executor, so can not be optimized in graph - ps_id=ops.convert_to_tensor(-1), - table_id=ops.convert_to_tensor(table_id), - embedding_dim=embedding_dim, - value_total_len=embedding_dim * (self.slot_vars_num + 1), - only_var_flag=only_var, - file_type=mode) - return tf.group([embedding_table_import]) diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py b/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py deleted file mode 100644 index c9cf6be1c032054e9c128525c6763678f53b6d3f..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from functools import reduce - - -class BaseTableMapPolicy(): - def __init__(self, assign_groups=None): - self.table_create_infos = [] - if assign_groups is None: - self.assign_groups = [] - else: - self.assign_groups = assign_groups - self.in_slot_size_group = [] - self.slot_to_table = [] - self.table_to_output_slots = [] - self.table_to_input_groups = [] - self.table_to_slot = [] - - @staticmethod - def _is_equal_table_info(info1, info2): - if info1['embedding_dim'] != info2['embedding_dim']: # dim of table is the same or not - print('embedding dim different!, value is %d and %d' % (info1['embedding_dim'], info2['embedding_dim'])) - return False - return True - - def map_table_infos(self, user_defined_table_infos): - raise NotImplementedError() - - def _register_new_table_info(self, new_table_info): - self.table_create_infos.append(new_table_info) - self.table_to_output_slots.append([]) - self.table_to_input_groups.append([]) - self.table_to_slot.append([]) - - def _merge_new_table_info(self, new_table_info, assign_tabld_id): - main_table_info = self.table_create_infos[assign_tabld_id] - main_table_info['multihot_lens'] += new_table_info['multihot_lens'] - main_table_info['max_vocabulary_size'] += new_table_info['max_vocabulary_size'] - - def _register_table_info(self, new_table_info, assign_tid=-1): - multihot_lens = new_table_info['multihot_lens'] - in_slot_size = sum(multihot_lens) - out_slot_size = len(multihot_lens) - - tid = assign_tid - if tid == -1: - tid = len(self.table_create_infos) - self._register_new_table_info(new_table_info) - else: - self._merge_new_table_info(new_table_info, tid) - - self.table_to_slot[tid].append(len(self.in_slot_size_group)) - self.table_to_output_slots[tid].append(in_slot_size) - self.in_slot_size_group.append(in_slot_size) - self.slot_to_table.append(tid) - - def _map_table_infos(self, user_defined_table_infos, assign_groups): - self.table_create_infos = [] - assign_groups_flat = reduce(lambda a, b: a+b, assign_groups, []) - sid_to_gid = reduce(lambda a, b: {**a, **b}, - [{sid: gid for sid in group} - for gid, group in enumerate(assign_groups)], {}) - gid_to_tid = dict() - for sid, table_info in enumerate(user_defined_table_infos): - if sid in assign_groups_flat: - gid = sid_to_gid.get(sid) - if gid in gid_to_tid: - self._register_table_info(table_info, assign_tid=gid_to_tid.get(gid)) - else: - tid = len(self.table_create_infos) - self._register_table_info(table_info, assign_tid=-1) - gid_to_tid[gid] = tid - else: - self._register_table_info(table_info, assign_tid=-1) - return self.table_create_infos - - -# no slot merge -class NoneTableMapPolicy(BaseTableMapPolicy): - def map_table_infos(self, user_defined_table_infos): - return self._map_table_infos(user_defined_table_infos, self.assign_groups) - - -# merge slot by user's assign_groups -class AutoMergeTableMapPolicy(BaseTableMapPolicy): - def map_table_infos(self, user_defined_table_infos): - assign_groups_flat = reduce(lambda a, b: a+b, self.assign_groups, []) - new_assign_groups = [] - for sid, table_info in enumerate(user_defined_table_infos): - if sid in assign_groups_flat: - continue - gid = -1 - if user_defined_table_infos[sid]['allow_merge']: - for ngid, group in enumerate(new_assign_groups): - if self._is_equal_table_info(user_defined_table_infos[group[0]], table_info) \ - and user_defined_table_infos[group[0]]['allow_merge']: - gid = ngid - break - if gid == -1: - gid = len(new_assign_groups) - new_assign_groups.append([]) - new_assign_groups[gid].append(sid) - new_assign_groups = self.assign_groups + new_assign_groups - return self._map_table_infos(user_defined_table_infos, new_assign_groups) diff --git a/tf_adapter/python/npu_bridge/embedding/tf_path.py b/tf_adapter/python/npu_bridge/embedding/tf_path.py deleted file mode 100644 index a5717e652ec960fa3a849471e9d02d74ff0c58da..0000000000000000000000000000000000000000 --- a/tf_adapter/python/npu_bridge/embedding/tf_path.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import variables -from tensorflow.python.training import optimizer as embeddingOptimizer -from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource - - -class _NpuEmbeddingResourceProcessor(embeddingOptimizer._OptimizableVariable): - """Processor for dense NpuEmbeddingResourceProcessor.""" - - def __init__(self, v): - self._v = v - - def target(self): - return self._v - - def update_op(self, optimizer, g): - return optimizer._resource_apply_sparse(g.values, self._v, g.indices) - - -def _get_processor(v): - """The processor of v.""" - if context.executing_eagerly(): - if isinstance(v, ops.Tensor): - return embeddingOptimizer._TensorProcessor(v) - else: - return embeddingOptimizer._DenseResourceVariableProcessor(v) - if isinstance(v, NpuEmbeddingResource): - return _NpuEmbeddingResourceProcessor(v) - if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode: # pylint: disable=protected-access - # True if and only if `v` was initialized eagerly. - return embeddingOptimizer._DenseResourceVariableProcessor(v) - if v.op.type == "VarHandleOp": - return embeddingOptimizer._DenseResourceVariableProcessor(v) - if isinstance(v, variables.Variable): - return embeddingOptimizer._RefVariableProcessor(v) - if isinstance(v, ops.Tensor): - return embeddingOptimizer._TensorProcessor(v) - - raise NotImplementedError("Trying to optimize unsupported type ", v) - - -def path_on_tf(): - embeddingOptimizer._get_processor = _get_processor - - diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py index 6004fecb7b0090fedc69afca15d3e3463ccd4fdb..50536b343ae6775c1868b2cb11ed85faac949c6a 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_config.py @@ -78,7 +78,7 @@ class NPURunConfig(run_config_lib.RunConfig): op_compiler_cache_dir=None, debug_dir=None, hcom_multi_mode=False, - dynamic_input=False, + dynamic_input=None, dynamic_graph_execute_mode="dynamic_execute", dynamic_inputs_shape_range=None, train_distribute=None, @@ -105,7 +105,6 @@ class NPURunConfig(run_config_lib.RunConfig): stream_sync_timeout=-1, event_sync_timeout=-1, external_weight=False, - es_cluster_config=None, deterministic=0, frozen_variable=False, variable_placement="Device" @@ -166,7 +165,6 @@ class NPURunConfig(run_config_lib.RunConfig): experimental_config: The experimental configuration. topo_sorting_mode: Provides an interface for users to customize topology sorting. external_weight: Whether convert const to fileconstant and save weight to file. - es_cluster_config: esClusterConfig from user input in embedding service. frozen_variable: Whether folding constant variables variable_placement: Process variable on host or device """ @@ -256,7 +254,6 @@ class NPURunConfig(run_config_lib.RunConfig): self.stream_sync_timeout = stream_sync_timeout self.event_sync_timeout = event_sync_timeout self._external_weight = external_weight - self.es_cluster_config = es_cluster_config super(NPURunConfig, self).__init__( model_dir=model_dir, diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py index c3199d8fbe5b92d51223c9e631ab329b7de27353..eff80548b5a0cab482b5d21221ab2e39b31f5e0a 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_estimator.py @@ -740,7 +740,8 @@ class NPUEstimator(estimator_lib.Estimator): if config._debug_dir is not None: custom_op.parameter_map["debug_dir"].s = tf.compat.as_bytes(config._debug_dir) custom_op.parameter_map["hcom_multi_mode"].b = config._hcom_multi_mode - custom_op.parameter_map["dynamic_input"].b = config._dynamic_input + if config._dynamic_input is not None: + custom_op.parameter_map["dynamic_input"].b = config._dynamic_input custom_op.parameter_map["dynamic_graph_execute_mode"].s = tf.compat.as_bytes(config._dynamic_graph_execute_mode) if config._dynamic_inputs_shape_range is not None: custom_op.parameter_map["dynamic_inputs_shape_range"].s = tf.compat.as_bytes( @@ -768,8 +769,6 @@ class NPUEstimator(estimator_lib.Estimator): custom_op.parameter_map["topo_sorting_mode"].i = config.topo_sorting_mode if config.insert_op_file is not None: custom_op.parameter_map["insert_op_file"].s = config.insert_op_file - if config.es_cluster_config is not None: - custom_op.parameter_map["es_cluster_config"].s = tf.compat.as_bytes(config.es_cluster_config) custom_op.parameter_map["stream_sync_timeout"].i = config.stream_sync_timeout custom_op.parameter_map["event_sync_timeout"].i = config.event_sync_timeout custom_op.parameter_map["external_weight"].b = config._external_weight diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index cf43c5536b6d14ca5c50b7cdc14d2ebd07aadfae..1e91cbb8e168ae68eb64df2ded895e03b702306e 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -283,102 +283,3 @@ def non_zero_with_value_shape(value, index, count): index=index, count=count) return result - - -class ESWorker: - """ Embedding service class. """ - def __init__(self, es_cluster_config): - with open(es_cluster_config, encoding='utf-8') as a: - es_cluster_config_json = json.load(a) - self._es_cluster_conf = json.dumps(es_cluster_config_json) - self._ps_num = int(es_cluster_config_json["psNum"]) - self._embedding_dim = -1 - self._max_num = -1 - self._ps_ids = [] - self._ps_ids_list = es_cluster_config_json["psCluster"] - for each_ps in self._ps_ids_list: - self._ps_ids.append(each_ps["id"]) - - config = tf.ConfigProto() - custom_op = config.graph_options.rewrite_options.custom_optimizers.add() - custom_op.name = "NpuOptimizer" - custom_op.parameter_map["es_cluster_config"].s = tf.compat.as_bytes(self._es_cluster_conf) - self.es_all_config = config - - ## 提供embedding init功能 - # @param bucket_size int 类型 - # @param file_path string 类型 - # @param file_name string 类型 - # @param table_id uint32 类型 - # @param embedding_dim uint32 类型 - # @param max_batch_size uint32 类型 - def embedding_init(self, bucket_size, file_path, file_name, table_id, embedding_dim, max_batch_size): - """ Operator for init embedding table. """ - self._embedding_dim = embedding_dim - self._max_num = max_batch_size - ps_num = tf.constant(self._ps_num, dtype=tf.uint32, name='ps_num') - ps_ids = tf.constant(self._ps_ids, dtype=tf.uint32, name='ps_ids') - ps_engine_values = "PS" - ps_engine = attr_value_pb2.AttrValue(s=compat.as_bytes(ps_engine_values)) - ps_num.op._set_attr("_process_node_engine_id", ps_engine) - ps_ids.op._set_attr("_process_node_engine_id", ps_engine) - init_partition_map = gen_npu_cpu_ops.init_partition_map(ps_num=ps_num, - ps_ids=ps_ids) - table_id = tf.constant(table_id, dtype=tf.uint32, name='table_id') - table_id.op._set_attr("_process_node_engine_id", ps_engine) - with tf.control_dependencies([init_partition_map]): - init_embedding_hash_map = gen_npu_cpu_ops.init_embedding_hashmap(table_id=table_id, bucket_size=bucket_size) - file_name = tf.constant(file_name, dtype=tf.string, name='file_name') - file_path = tf.constant(file_path, dtype=tf.string, name='file_path') - embedding_dim = tf.constant(embedding_dim, dtype=tf.uint32, name="embedding_dim") - ps_id = -1 - ps_id = tf.constant(ps_id, dtype=tf.uint32, name='ps_id') - ps_id.op._set_attr("_process_node_engine_id", ps_engine) - file_name.op._set_attr("_process_node_engine_id", ps_engine) - file_path.op._set_attr("_process_node_engine_id", ps_engine) - embedding_dim.op._set_attr("_process_node_engine_id", ps_engine) - with tf.control_dependencies([init_embedding_hash_map]): - embedding_table_import = gen_npu_cpu_ops.embedding_table_import(file_path=file_path, - file_name=file_name, - ps_id=ps_id, - table_id=table_id, - embedding_dim=embedding_dim) - init_partition_map._set_attr("_process_node_engine_id", ps_engine) - init_embedding_hash_map._set_attr("_process_node_engine_id", ps_engine) - embedding_table_import._set_attr("_process_node_engine_id", ps_engine) - execute_times_value = 1 - execute_times = attr_value_pb2.AttrValue(i=execute_times_value) - embedding_table_import._set_attr("_execute_times", execute_times) - embedding_dim_value = 1 - embedding_dim = attr_value_pb2.AttrValue(i=embedding_dim_value) - embedding_table_import._set_attr("_embedding_dim", embedding_dim) - max_num_value = self._max_num - max_num = attr_value_pb2.AttrValue(i=max_num_value) - embedding_table_import._set_attr("_max_num", max_num) - deploy_inject_config_value = self._es_cluster_conf - deploy_inject_config = attr_value_pb2.AttrValue(s=compat.as_bytes(deploy_inject_config_value)) - embedding_table_import._set_attr("_deploy_inject_config", deploy_inject_config) - result = embedding_table_import - return result - - # 提供embedding lookup功能 - # @param table_id uint32 类型 - # @param input_ids uint64 类型 - # @return values float32 类型 - def embedding_look_up(self, table_id, input_ids): - """ Operator for look up in embedding table. """ - table_id = tf.constant(table_id, dtype=tf.uint32, name="table_id") - result = gen_npu_cpu_ops.embedding_table_find(table_id=table_id, - keys=input_ids, - embedding_dim=self._embedding_dim) - max_num_value = self._max_num - max_num = attr_value_pb2.AttrValue(i=max_num_value) - result.op._set_attr("_max_num", max_num) - if self._embedding_dim == -1: - self._embedding_dim = 4 - embedding_dim_value = attr_value_pb2.AttrValue(i=self._embedding_dim) - result.op._set_attr("_embedding_dim", embedding_dim_value) - deploy_inject_config_value = self._es_cluster_conf - deploy_inject_config = attr_value_pb2.AttrValue(s=compat.as_bytes(deploy_inject_config_value)) - result.op._set_attr("_deploy_inject_config", deploy_inject_config) - return result \ No newline at end of file diff --git a/tf_adapter/tests/st/kernels/pbtxt/geop.pbtxt b/tf_adapter/tests/st/kernels/pbtxt/geop.pbtxt index 5fd4165770dbcab69f207df8129b0d5772c4b936..58575959fcc4e83066f0231f17413e0ac2d7c80d 100644 --- a/tf_adapter/tests/st/kernels/pbtxt/geop.pbtxt +++ b/tf_adapter/tests/st/kernels/pbtxt/geop.pbtxt @@ -113,24 +113,6 @@ node { s: "dynamic_execute" } } - attr { - key: "_execute_times" - value { - s: "2" - } - } - attr { - key: "_max_num" - value { - s: "1" - } - } - attr { - key: "_embedding_dim" - value { - s: "1" - } - } attr { key: "_dynamic_input" value { diff --git a/tf_adapter/tests/st/optimizers/testcase/get_attr_optimize_pass_test.cc b/tf_adapter/tests/st/optimizers/testcase/get_attr_optimize_pass_test.cc index a6fee094fcc36a9198c581080607f4d21471a2e8..0649fc74ac85db2d62e8b960079f5c735c4ab505 100644 --- a/tf_adapter/tests/st/optimizers/testcase/get_attr_optimize_pass_test.cc +++ b/tf_adapter/tests/st/optimizers/testcase/get_attr_optimize_pass_test.cc @@ -210,9 +210,6 @@ TEST_F(GetAttrOptimizationPassTest, SetAttrTest) { AttrValue insert_op_file = AttrValue(); insert_op_file.set_s("aipp.cfg"); (*custom_config->mutable_parameter_map())["insert_op_file"] = insert_op_file; - AttrValue es_cluster_config = AttrValue(); - es_cluster_config.set_s("es"); - (*custom_config->mutable_parameter_map())["es_cluster_config"] = es_cluster_config; AttrValue external_weight = AttrValue(); external_weight.set_b(true); (*custom_config->mutable_parameter_map())["external_weight"] = external_weight; diff --git a/tf_adapter/tests/ut/kernels/pbtxt/geop.pbtxt b/tf_adapter/tests/ut/kernels/pbtxt/geop.pbtxt index 0409c9ba4a9be074d2213500989df7a729154203..58575959fcc4e83066f0231f17413e0ac2d7c80d 100644 --- a/tf_adapter/tests/ut/kernels/pbtxt/geop.pbtxt +++ b/tf_adapter/tests/ut/kernels/pbtxt/geop.pbtxt @@ -113,30 +113,6 @@ node { s: "dynamic_execute" } } - attr { - key: "_deploy_inject_config" - value { - s: "deploy_inject_config" - } - } - attr { - key: "_execute_times" - value { - s: "2" - } - } - attr { - key: "_max_num" - value { - s: "1" - } - } - attr { - key: "_embedding_dim" - value { - s: "1" - } - } attr { key: "_dynamic_input" value { diff --git a/tf_adapter/tests/ut/optimizers/testcase/get_attr_optimize_pass_test.cc b/tf_adapter/tests/ut/optimizers/testcase/get_attr_optimize_pass_test.cc index 151838b587ccc1997b128917ad6d316577cdf6ef..0649fc74ac85db2d62e8b960079f5c735c4ab505 100644 --- a/tf_adapter/tests/ut/optimizers/testcase/get_attr_optimize_pass_test.cc +++ b/tf_adapter/tests/ut/optimizers/testcase/get_attr_optimize_pass_test.cc @@ -210,9 +210,6 @@ TEST_F(GetAttrOptimizationPassTest, SetAttrTest) { AttrValue insert_op_file = AttrValue(); insert_op_file.set_s("aipp.cfg"); (*custom_config->mutable_parameter_map())["insert_op_file"] = insert_op_file; - AttrValue es_cluster_config = AttrValue(); - es_cluster_config.set_s("esclusterconfig.json"); - (*custom_config->mutable_parameter_map())["es_cluster_config"] = es_cluster_config; AttrValue external_weight = AttrValue(); external_weight.set_b(true); (*custom_config->mutable_parameter_map())["external_weight"] = external_weight; diff --git a/tf_adapter/util/npu_attrs.cc b/tf_adapter/util/npu_attrs.cc index 45de04adf9e432915dc7c9f2024860d369868999..a071d809197717eec62751e562cf81d6fe9ca3b6 100644 --- a/tf_adapter/util/npu_attrs.cc +++ b/tf_adapter/util/npu_attrs.cc @@ -576,7 +576,6 @@ std::map NpuAttrs::GetInitOptions(const OpKernelConstr std::string aoe_config_file; std::string stream_sync_timeout = "-1"; std::string event_sync_timeout = "-1"; - std::string es_cluster_config; if (ctx != nullptr && ctx->GetAttr("_NpuOptimizer", &npuOptimizer) == Status::OK()) { (void) ctx->GetAttr("_precision_mode", &precision_mode); @@ -615,7 +614,6 @@ std::map NpuAttrs::GetInitOptions(const OpKernelConstr (void) ctx->GetAttr("_aoe_config_file", &aoe_config_file); (void) ctx->GetAttr("_stream_sync_timeout", &stream_sync_timeout); (void) ctx->GetAttr("_event_sync_timeout", &event_sync_timeout); - (void) ctx->GetAttr("_es_cluster_config", &es_cluster_config); } if (precision_mode.empty()) { @@ -666,7 +664,6 @@ std::map NpuAttrs::GetInitOptions(const OpKernelConstr init_options_["ge.aoe_config_file"] = aoe_config_file; init_options_["stream_sync_timeout"] = stream_sync_timeout; init_options_["event_sync_timeout"] = event_sync_timeout; - init_options_["ge.esClusterConfig"] = es_cluster_config; return init_options_; } @@ -1067,7 +1064,6 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & std::string stream_sync_timeout = "-1"; std::string event_sync_timeout = "-1"; std::string external_weight = "0"; - std::string es_cluster_config; std::string graph_parallel_option_path; std::string enable_graph_parallel; @@ -1147,7 +1143,6 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & auto model_deploy_devicelist_value = attrs.Find("_model_deploy_devicelist"); auto topo_sorting_mode_value = attrs.Find("_topo_sorting_mode"); auto insert_op_file_value = attrs.Find("_insert_op_file"); - auto es_cluster_config_value = attrs.Find("_es_cluster_config"); auto resource_config_path_value = attrs.Find("_resource_config_path"); auto aoe_config_file_value = attrs.Find("_aoe_config_file"); auto stream_sync_timeout_value = attrs.Find("_stream_sync_timeout"); @@ -1428,9 +1423,6 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & if (external_weight_value != nullptr) { external_weight = external_weight_value->s(); } - if (es_cluster_config_value != nullptr) { - es_cluster_config = es_cluster_config_value->s(); - } } all_options["variable_format_optimize"] = variable_format_optimize; @@ -1518,8 +1510,6 @@ std::map NpuAttrs::GetAllAttrOptions(const AttrSlice & all_options["ge.topoSortingMode"] = topo_sorting_mode; all_options["insert_op_file"] = insert_op_file; all_options["ge.insertOpFile"] = insert_op_file; - all_options["es_cluster_config"] = es_cluster_config; - all_options["ge.esClusterConfig"] = es_cluster_config; all_options["resource_config_path"] = resource_config_path; all_options["ge.aoe_config_file"] = aoe_config_file; all_options["aoe_config_file"] = aoe_config_file; @@ -2063,11 +2053,6 @@ Status NpuAttrs::SetNpuOptimizerAttr(const GraphOptimizationPassOptions &options if (params.count("external_weight") > 0) { external_weight = params.at("external_weight").b(); } - if (params.count("es_cluster_config") > 0) { - std::string es_cluster_config = params.at("es_cluster_config").s(); - init_options_["es_cluster_config"] = es_cluster_config; - init_options_["ge.esClusterConfig"] = es_cluster_config; - } if (params.count("frozen_variable") > 0) { frozen_variable = params.at("frozen_variable").b(); } diff --git a/tf_adapter/util/util.cc b/tf_adapter/util/util.cc index f7a1e8c81a99221fcda2005c248bd7b786f60a62..8cb0b1fc33b35d654ef03828514c4db51655d44c 100644 --- a/tf_adapter/util/util.cc +++ b/tf_adapter/util/util.cc @@ -115,7 +115,6 @@ bool IsVariableOrResourceVariable(const Node * const node) { bool IsVariableExecuteOnHost(const Node * const node, const std::string &variable_location) { if (variable_location == "Host" && IsVariableOrResourceVariable(node)) { - ADP_LOG(INFO) << "Node : " << node->name() << " op name : " << node->type_string() << "is execute on host"; return true; } return false;