From e9d5987e43a9f1ab2231809d16f73647dddb53fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=BC=BA?= Date: Thu, 31 Oct 2024 10:50:41 +0800 Subject: [PATCH] es cleancode for tfa --- .../npu_bridge/embedding/embedding_service.py | 61 ++----------------- .../embedding/embedding_table_map_policy.py | 42 +++++++------ .../npu_bridge/embedding/embedding_utils.py | 49 ++++++++++++++- 3 files changed, 78 insertions(+), 74 deletions(-) diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_service.py b/tf_adapter/python/npu_bridge/embedding/embedding_service.py index b6ed48c12..dddb10f6b 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_service.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_service.py @@ -34,7 +34,7 @@ from npu_bridge.embedding.embedding_resource import NpuEmbeddingResource from npu_bridge.embedding import embedding_optimizer from npu_bridge.embedding.embedding_table_map_policy import NoneTableMapPolicy, AutoMergeTableMapPolicy from npu_bridge.embedding.embedding_utils import EmbeddingVariableOption, CounterFilter, PaddingParamsOption, \ - CompletionKeyOption + CompletionKeyOption, check_common_init_params, check_each_initializer, check_init_params_type from npu_bridge.embedding.embedding_utils import EvictOption _INT32_MAX_VALUE = 2147483647 @@ -104,26 +104,6 @@ def es_initializer(initializer_mode, min=-2.0, max=2.0, constant_value=0.0, mu=0 seed=seed) -def check_common_init_params(name, init_vocabulary_size, embedding_dim, embedding_type, mask_zero): - if (name is None) or (init_vocabulary_size is None) or (embedding_dim is None): - raise ValueError("table name, init_vocabulary_size and embedding_dim can not be None.") - if not isinstance(name, str): - raise TypeError("embedding table name must be string.") - regex = re.compile('[@!#$%^&*()<>?/\|}{~:]') - if regex.search(name) is not None: - raise ValueError("table name contains illegal character.") - if (not isinstance(init_vocabulary_size, int)) or (not 
isinstance(embedding_dim, int)): - raise ValueError("init_vocabulary_size and embedding_dim must be int.") - if init_vocabulary_size < 0: - raise ValueError("init_vocabulary_size can not be smaller than zero.") - if embedding_dim <= 0: - raise ValueError("embedding_dim must be greater than zero.") - if (embedding_type != "PS") and (embedding_type != "data_parallel"): - raise TypeError("embedding_type must be PS or data_parallel") - if not isinstance(mask_zero, bool): - raise TypeError("mask zero must be bool") - - class ESWorker: """ Embedding service class. """ @@ -219,19 +199,8 @@ class ESWorker: """Operator for init initializer.""" if (table_id is None) or (initializer_mode is None): raise ValueError("table_id and initializer_mode can not be None.") - if initializer_mode == 'random_uniform': - if (min is None) or (max is None) or \ - (not isinstance(min, (float, int))) or (not isinstance(max, (float, int))): - raise ValueError("If initializer is random_uniform, min and max can not be None, must be int or float.") - if initializer_mode == 'truncated_normal': - if (min is None) or (max is None) or (mu is None) or (sigma is None) or \ - (not isinstance(min, (float, int))) or (not isinstance(max, (float, int))) or \ - (not isinstance(mu, (float, int))) or (not isinstance(sigma, (float, int))): - raise ValueError("If initializer is truncated_normal, min, max, mu and sigma can not be None," - "and they must be int or float.") - if initializer_mode == 'constant': - if (constant_value is None) or (not isinstance(constant_value, (float, int))): - raise ValueError("If initializer is constant, constant_value can not be None, must be float or int.") + check_each_initializer(initializer_mode=initializer_mode, min_value=min, max_value=max, + constant_value=constant_value, mu=mu, sigma=sigma) if (not isinstance(table_id, int)) or (table_id < 0) or (table_id >= _INT32_MAX_VALUE): raise ValueError("table_id value is false, must be [0, 2147483647) and int type, please check.") if 
min > max: @@ -927,12 +896,9 @@ class ESWorker: raise ValueError("max_vocabulary_size or embedding_dim or multihot_lens can not be None.") if (key_dtype is None) or (value_dtype is None): raise ValueError("key_dtype and value_dtype can not be None.") - if (key_dtype is not tf.int64) or (value_dtype is not tf.float32): - raise TypeError("key_dtype only support tf.int64, value_dtype only support tf.float32 now.") - if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)) or \ - (not isinstance(multihot_lens, int)) or (not isinstance(allow_merge, bool)): - raise TypeError("init_vocabulary_size, embedding_dim, multihot_lens must be int," - "allow_merge must be bool.") + check_init_params_type(key_dtype=key_dtype, value_dtype=value_dtype, + init_vocabulary_size=init_vocabulary_size, embedding_dim=embedding_dim, + multihot_lens=multihot_lens, allow_merge=allow_merge) if init_vocabulary_size <= 0 or embedding_dim <= 0 or multihot_lens <= 0: raise ValueError("init_vocabulary_size, embedding_dim, multihot_lens must be greater than zero.") if initializer is None: @@ -940,21 +906,6 @@ class ESWorker: if allow_merge: raise ValueError("allow_merge do not support now.") self._need_table_merge = True - if isinstance(initializer, EsInitializer): - if initializer.initializer_mode == "random_uniform": - self._table_id_to_initializer[table_id] = \ - tf.random_uniform_initializer(minval=initializer.min, maxval=initializer.max, - seed=initializer.seed, dtype=value_dtype) - elif initializer.initializer_mode == "truncated_normal": - self._table_id_to_initializer[table_id] = \ - tf.truncated_normal_initializer(stddev=initializer.stddev, mean=initializer.mean, - seed=initializer.seed, dtype=value_dtype) - elif initializer.initializer_mode == "constant": - self._table_id_to_initializer[table_id] = \ - tf.constant_initializer(value=initializer.value, dtype=value_dtype) - elif not callable(initializer): - if ops.convert_to_tensor(initializer).dtype.base_dtype != 
tf.float32: - raise ValueError("Initializer type '%s' and explict dtype tf.float32 don't match." % init_dtype) def _check_and_update_ps_init_params(self, name, init_vocabulary_size, embedding_dim, max_feature_count, ev_option): steps_to_live = 0 diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py b/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py index 96a317c67..d3a788632 100644 --- a/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py +++ b/tf_adapter/python/npu_bridge/embedding/embedding_table_map_policy.py @@ -20,29 +20,35 @@ import tensorflow as tf from tensorflow.python.platform import tf_logging as logging -def compare_initializer(init1, init2): - if isinstance(init1, tf.initializers.truncated_normal): - if isinstance(init1, tf.initializers.truncated_normal): - if (init1.stddev != init2.stddev) or (init1.seed != init2.seed) or (init1.mean != init2.mean) or \ - (init1.dtype != init2.dtype): - return False - else: - return True - else: +def compare_for_truncated_normal(init1, init2): + if isinstance(init2, tf.initializers.truncated_normal): + if (init1.stddev != init2.stddev) or (init1.seed != init2.seed) or (init1.mean != init2.mean) or \ + (init1.dtype != init2.dtype): return False - - if isinstance(init1, tf.initializers.random_uniform): - if isinstance(init1, tf.initializers.random_uniform): - if (init1.minval != init2.minval) or (init1.maxval != init2.maxval) or (init1.seed != init2.seed) or \ - (init1.dtype != init2.dtype): - return False - else: - return True else: + return True + else: + return False + + +def compare_for_random_uniform(init1, init2): + if isinstance(init2, tf.initializers.random_uniform): + if (init1.minval != init2.minval) or (init1.maxval != init2.maxval) or (init1.seed != init2.seed) or \ + (init1.dtype != init2.dtype): return False + else: + return True + else: + return False + +def compare_initializer(init1, init2): + if isinstance(init1, 
tf.initializers.truncated_normal):
+        return compare_for_truncated_normal(init1, init2)
+    if isinstance(init1, tf.initializers.random_uniform):
+        return compare_for_random_uniform(init1, init2)
     if isinstance(init1, tf.initializers.constant):
-        if isinstance(init1, tf.initializers.constant):
+        if isinstance(init2, tf.initializers.constant):
             if (init1.value != init2.value) or (init1.dtype != init2.dtype):
                 return False
             else:
diff --git a/tf_adapter/python/npu_bridge/embedding/embedding_utils.py b/tf_adapter/python/npu_bridge/embedding/embedding_utils.py
index 88982e691..9d8680f9a 100644
--- a/tf_adapter/python/npu_bridge/embedding/embedding_utils.py
+++ b/tf_adapter/python/npu_bridge/embedding/embedding_utils.py
@@ -15,6 +15,11 @@
 # limitations under the License.
 # ==============================================================================
 
+import re
+
+import tensorflow as tf
+
+
 class EmbeddingVariableOption:
     """ option for embedding service table. """
 
@@ -65,4 +70,48 @@ class EvictOption:
     """ Evict option for embedding table.
""" def __init__(self, steps_to_live): - self.steps_to_live = steps_to_live \ No newline at end of file + self.steps_to_live = steps_to_live + + +def check_common_init_params(name, init_vocabulary_size, embedding_dim, embedding_type, mask_zero): + if (name is None) or (init_vocabulary_size is None) or (embedding_dim is None): + raise ValueError("table name, init_vocabulary_size and embedding_dim can not be None.") + if not isinstance(name, str): + raise TypeError("embedding table name must be string.") + regex = re.compile('[@!#$%^&*()<>?/\|}{~:]') + if regex.search(name) is not None: + raise ValueError("table name contains illegal character.") + if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)): + raise ValueError("init_vocabulary_size and embedding_dim must be int.") + if init_vocabulary_size < 0: + raise ValueError("init_vocabulary_size can not be smaller than zero.") + if embedding_dim <= 0: + raise ValueError("embedding_dim must be greater than zero.") + if (embedding_type != "PS") and (embedding_type != "data_parallel"): + raise TypeError("embedding_type must be PS or data_parallel") + if not isinstance(mask_zero, bool): + raise TypeError("mask zero must be bool") + + +def check_each_initializer(initializer_mode, min_value, max_value, constant_value, mu, sigma): + if initializer_mode == 'random_uniform': + if (min_value is None) or (max_value is None) or \ + (not isinstance(min_value, (float, int))) or (not isinstance(max_value, (float, int))): + raise ValueError("If initializer is random_uniform, min and max can not be None, must be int or float.") + if initializer_mode == 'truncated_normal': + if (min_value is None) or (max_value is None) or (mu is None) or (sigma is None) or \ + (not isinstance(min_value, (float, int))) or (not isinstance(max_value, (float, int))) or \ + (not isinstance(mu, (float, int))) or (not isinstance(sigma, (float, int))): + raise ValueError("If initializer is truncated_normal, min, max, mu and 
sigma can not be None," + "and they must be int or float.") + if initializer_mode == 'constant': + if (constant_value is None) or (not isinstance(constant_value, (float, int))): + raise ValueError("If initializer is constant, constant_value can not be None, must be float or int.") + + +def check_init_params_type(key_dtype, value_dtype, init_vocabulary_size, embedding_dim, multihot_lens, allow_merge): + if (key_dtype is not tf.int64) or (value_dtype is not tf.float32): + raise TypeError("key_dtype only support tf.int64, value_dtype only support tf.float32 now.") + if (not isinstance(init_vocabulary_size, int)) or (not isinstance(embedding_dim, int)) or \ + (not isinstance(multihot_lens, int)) or (not isinstance(allow_merge, bool)): + raise TypeError("init_vocabulary_size, embedding_dim, multihot_lens must be int, allow_merge must be bool.") -- Gitee