From 0c27ccf1d52b9c2242b118b47bf25a1821e29645 Mon Sep 17 00:00:00 2001 From: liuqiang Date: Mon, 15 May 2023 16:34:18 +0800 Subject: [PATCH] embedding support set train or evaluate --- .../npu_bridge/embedding/embedding_service.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index 95a396145..c32cf571a 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -90,6 +90,7 @@ class ESWorker: self._initializer = None self._init_flag = False self._table_init = False + self._set_train_or_infer_mode = False self._table_has_init = [] self.user_defined_table_infos = [] self.table_map_policy = None @@ -133,6 +134,31 @@ class ESWorker: mu=mu, sigma=sigma) + # Set inference mode + def enable_inference_mode(self): + self._train_mode = False + + # Set train mode + def enable_train_mode(self): + self._train_mode = True + + # Return the model is train mode or inference mode + def get_model_mode(self): + if self._train_mode: + return "CURRENT_SETTING = TRAIN" + else: + return "CURRENT_SETTING = INFERENCE" + + def set_train_or_evaluate(self, flag): + """ Set whether it is train or evaluate. """ + if flag != "train" and flag != "evaluate": + raise ValueError("Flag can only be following two values: train, evaluate.") + if flag == "train": + self._train_mode = True + else: + self._train_mode = False + self._set_train_or_infer_mode = True + # 提供embedding init功能 # @param vocabulary_size 表的初始大小, int 类型 # @param table_id, int32 类型 @@ -171,9 +197,13 @@ class ESWorker: mu=0.0, sigma=1.0) if optimizer is None: + if self._set_train_or_infer_mode and self._train_mode: + raise ValueError("If set train mode, optimizer parameter can not be None.") self._train_mode = False self.slot_vars_num = 0 else: + if self._set_train_or_infer_mode and (not self._train_mode): + raise ValueError("If set evaluate mode, optimizer parameter must be None.") if (not isinstance(optimizer, embedding_optimizer.AdamOptimizer)) and \ (not isinstance(optimizer, embedding_optimizer.AdagradOptimizer)) and \ (not isinstance(optimizer, embedding_optimizer.AdamWOptimizer)): -- Gitee