From 573082e24e528d05586afdf8d479da9deaa1dbfb Mon Sep 17 00:00:00 2001 From: zhengying Date: Mon, 6 Mar 2023 11:08:52 +0800 Subject: [PATCH] =?UTF-8?q?dropout=E6=B7=BB=E5=8A=A0v4=E7=9A=84tf=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=92=8C=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../npu_bridge/estimator/npu_aicore_ops.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py index 2098dd510..bc5b8a42d 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py +++ b/tf_adapter/python/npu_bridge/estimator/npu_aicore_ops.py @@ -125,6 +125,37 @@ def _DropOutDoMaskV3Grad(op, grad): return [result, None, None] +def dropout_v4(x, keep_prob, noise_shape=None, seed=None, output_dtype=dtypes.bool, name=None): + """The gradient for `gelu`. + + Args: + x: A tensor with type is float. + keep_prob: A tensor, float, rate of every element reserved. + noise_shape: A 1-D tensor, with type int32, shape of keep/drop what random + generated. + seed: Random seed. + output_dtype: dtype of output tensor, default is bool. + name: Layer name. + + Returns: + A tensor. + """ + x = ops.convert_to_tensor(x, name="x") + if not x.dtype.is_floating: + raise ValueError("x must be a floating point tensor." + " Got a %s tensor instead." % x.dtype) + if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1.0: + raise ValueError("keep_prob must be a float value or a scalar tensor in the " + "range (0, 1], got %g" % keep_prob) + if isinstance(keep_prob, float) and keep_prob == 1.0: + return x + seed, seed2 = random_seed.get_seed(seed) + noise_shape = _get_noise_shape(x, noise_shape) + gen_out = npu_aicore_ops.drop_out_gen_mask_v3(noise_shape, keep_prob, seed, seed2, output_dtype, name) + result = npu_aicore_ops.drop_out_do_mask_v3(x, gen_out, keep_prob, name) + return result + + def lru_cache_v2(index_list, data, cache, tag, is_last_call, pre_route_count, name=None): """ LRUCacheV2 op -- Gitee