diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc
index 1fde1ee6ff2048a008d817999938ea45d19ee2db..73c2bc600d9c04c950a93a36f6acb08ab39c86bc 100644
--- a/tf_adapter/ops/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc
@@ -17,6 +17,7 @@
 #include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include
 
 namespace tensorflow {
 using shape_inference::DimensionHandle;
diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py
index 5d17344ee8925c2aa0da7cdba327db4b250b5d49..4a8158635aa1cf9e602322573307ff52a2cd6c91 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py
@@ -17,6 +17,7 @@
 """NPU impletemented RNN"""
 import math
+import tensorflow as tf
 from tensorflow.python.framework import dtypes
 from tensorflow.python.layers import base as base_layer
 from tensorflow.python.ops import init_ops
@@ -27,6 +28,7 @@ gen_npu_ops = helper.get_gen_ops()
 
 DYNAMIC_RNN_UNIDIRECTION = "UNIDIRECTIONAL"
 DYNAMIC_RNN_BIDIRECTION = "BIDIRECTIONAL"
+DYNAMIC_RNN_REDIRECTIONAL = "REDIRECTIONAL"
 
 
 class _DynamicBasic(base_layer.Layer):
@@ -109,10 +111,15 @@
     def check_direction(self):
         """Check validity of direction."""
-        if self._direction not in (DYNAMIC_RNN_UNIDIRECTION, DYNAMIC_RNN_BIDIRECTION):
-            raise ValueError("Invalid direction: %s, expecting %s or %s" %
-                             (self._direction, DYNAMIC_RNN_UNIDIRECTION,
-                              DYNAMIC_RNN_BIDIRECTION))
+        if self._direction not in (DYNAMIC_RNN_UNIDIRECTION, DYNAMIC_RNN_BIDIRECTION, DYNAMIC_RNN_REDIRECTIONAL):
+            raise ValueError("Invalid direction: %s, expecting %s, %s or %s" %
+                             (self._direction, DYNAMIC_RNN_UNIDIRECTION,
+                              DYNAMIC_RNN_BIDIRECTION, DYNAMIC_RNN_REDIRECTIONAL))
 
+    def reshape(self, shape):
+        if DYNAMIC_RNN_BIDIRECTION == self._direction:
+            return [2] + shape
+        return shape
+
     def build(self, input_shape):
         """Build class"""
         time_size = input_shape[0].value
@@ -193,46 +200,92 @@
             raise ValueError("Expected input_shape[2] to be known, saw shape: input_size.")
         input_size = input_shape[2].value
         batch_size = input_shape[1].value
+        if batch_size is None:
+            batch_size = 16
         stdv = 1.0 / math.sqrt(self._hidden_size)
         self._gruv2_weight_input = self.add_variable(
             "dynamicgruv2/weight_input",
-            shape=[input_size, 3 * self._hidden_size],
+            shape=self.reshape([input_size, 3 * self._hidden_size]),
             dtype=self._dtype,
             initializer=init_ops.random_uniform_initializer(-stdv, stdv))
         self._gruv2_weight_hidden = self.add_variable(
             "dynamicgruv2/weight_hidden",
-            shape=[self._hidden_size, 3 * self._hidden_size],
+            shape=self.reshape([self._hidden_size, 3 * self._hidden_size]),
             dtype=self._dtype,
             initializer=init_ops.random_uniform_initializer(-stdv, stdv))
         self._bias_input = self.add_variable(
             "dynamicgruv2/bias_input",
-            shape=[3 * self._hidden_size],
+            shape=self.reshape([3 * self._hidden_size]),
             dtype=self._dtype,
             initializer=init_ops.random_uniform_initializer(-stdv, stdv))
         self._bias_hidden = self.add_variable(
             "dynamicgruv2/bias_hidden",
-            shape=[3 * self._hidden_size],
+            shape=self.reshape([3 * self._hidden_size]),
             dtype=self._dtype,
             initializer=init_ops.random_uniform_initializer(-stdv, stdv))
-        self._init_h = array_ops.zeros([batch_size, self._hidden_size], dtype=self._dtype)
+        self._init_h = array_ops.zeros(self.reshape([batch_size, self._hidden_size]), dtype=self._dtype)
         super(DynamicGRUV2, self).build(input_shape)
 
     def call(self,
              x,
             seq_length=None,
-             init_h=None):
+             init_h=None,
+             weight_input=None,
+             weight_hidden=None,
+             bias_input=None,
+             bias_hidden=None):
         """Dynamic GRU.
         """
         super(DynamicGRUV2, self).call(x, seq_length=seq_length)
         if init_h is None:
             init_h = self._init_h
-        self._args["init_h"] = init_h
-        self._args["weight_input"] = self._gruv2_weight_input
-        self._args["weight_hidden"] = self._gruv2_weight_hidden
-        self._args["bias_input"] = self._bias_input
-        self._args["bias_hidden"] = self._bias_hidden
+        if weight_input is None:
+            weight_input = self._gruv2_weight_input
+        if weight_hidden is None:
+            weight_hidden = self._gruv2_weight_hidden
+        if bias_input is None:
+            bias_input = self._bias_input
+        if bias_hidden is None:
+            bias_hidden = self._bias_hidden
         if seq_length is not None:
             self._args["seq_length"] = seq_length
+
+        if DYNAMIC_RNN_BIDIRECTION == self._direction:
+            init_h, init_h_2 = tf.split(init_h, 2, 0)
+            weight_input, weight_input_2 = tf.split(weight_input, 2, 0)
+            weight_hidden, weight_hidden_2 = tf.split(weight_hidden, 2, 0)
+            bias_input, bias_input_2 = tf.split(bias_input, 2, 0)
+            bias_hidden, bias_hidden_2 = tf.split(bias_hidden, 2, 0)
+
+            self._args["direction"] = DYNAMIC_RNN_UNIDIRECTION
+            self._args["init_h"] = tf.reshape(init_h, init_h.shape[1:])
+            self._args["weight_input"] = tf.reshape(weight_input, weight_input.shape[1:])
+            self._args["weight_hidden"] = tf.reshape(weight_hidden, weight_hidden.shape[1:])
+            self._args["bias_input"] = tf.reshape(bias_input, bias_input.shape[1:])
+            self._args["bias_hidden"] = tf.reshape(bias_hidden, bias_hidden.shape[1:])
+            forward = gen_npu_ops.dynamic_gru_v2(**self._args)
+
+            self._args["direction"] = DYNAMIC_RNN_REDIRECTIONAL
+            self._args["init_h"] = tf.reshape(init_h_2, init_h_2.shape[1:])
+            self._args["weight_input"] = tf.reshape(weight_input_2, weight_input_2.shape[1:])
+            self._args["weight_hidden"] = tf.reshape(weight_hidden_2, weight_hidden_2.shape[1:])
+            self._args["bias_input"] = tf.reshape(bias_input_2, bias_input_2.shape[1:])
+            self._args["bias_hidden"] = tf.reshape(bias_hidden_2, bias_hidden_2.shape[1:])
+            reverse = gen_npu_ops.dynamic_gru_v2(**self._args)
+
+            concat_y = tf.concat([tf.expand_dims(forward[0], 0),
+                                  tf.expand_dims(reverse[0], 0)],
+                                 0)
+            concat_yh = tf.concat([tf.expand_dims(forward[1], 0),
+                                   tf.expand_dims(reverse[1], 0)],
+                                  0)
+            return [concat_y, concat_yh]
+
+        self._args["init_h"] = init_h
+        self._args["weight_input"] = weight_input
+        self._args["weight_hidden"] = weight_hidden
+        self._args["bias_input"] = bias_input
+        self._args["bias_hidden"] = bias_hidden
         return gen_npu_ops.dynamic_gru_v2(**self._args)
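
Note on the bidirectional path added to DynamicGRUV2: build() now packs the forward and reverse parameters on a leading axis of size 2 (via self.reshape), and call() splits them, runs one dynamic_gru_v2 pass per direction, and re-stacks the two results with expand_dims + concat. The sketch below illustrates only that packing/splitting convention, outside the adapter: the toy recurrence is a hypothetical stand-in for gen_npu_ops.dynamic_gru_v2, and the explicit tf.reverse emulates what the NPU op presumably does internally when direction is REDIRECTIONAL. All names and shapes here are illustrative assumptions, not the adapter's API.

import tensorflow as tf

TIME, BATCH, INPUT, HIDDEN = 5, 2, 3, 4

def toy_gru(x, weight_input, weight_hidden, bias_input, bias_hidden, init_h):
    # Hypothetical stand-in for gen_npu_ops.dynamic_gru_v2: a crude tanh
    # recurrence with the same parameter shapes, returning (y, output_h).
    h, ys = init_h, []
    for t in range(TIME):
        gates = (tf.matmul(x[t], weight_input) + bias_input
                 + tf.matmul(h, weight_hidden) + bias_hidden)  # [batch, 3*hidden]
        h = tf.tanh(gates[:, :HIDDEN])
        ys.append(h)
    return tf.stack(ys), h  # y: [time, batch, hidden], h: [batch, hidden]

# Parameters for both directions packed on a leading axis of size 2,
# mirroring what the patched build() produces via self.reshape([...]).
wi = tf.random.uniform([2, INPUT, 3 * HIDDEN])
wh = tf.random.uniform([2, HIDDEN, 3 * HIDDEN])
bi = tf.random.uniform([2, 3 * HIDDEN])
bh = tf.random.uniform([2, 3 * HIDDEN])
h0 = tf.zeros([2, BATCH, HIDDEN])
x = tf.random.uniform([TIME, BATCH, INPUT])

# Select each direction's slice from the packed leading axis; the patch does
# the equivalent with tf.split followed by tf.reshape(t, t.shape[1:]).
fwd = toy_gru(x, wi[0], wh[0], bi[0], bh[0], h0[0])
rev = toy_gru(tf.reverse(x, [0]), wi[1], wh[1], bi[1], bh[1], h0[1])

# Re-stack both directions on a new leading axis, as the patched call()
# does with tf.expand_dims + tf.concat.
y = tf.concat([tf.expand_dims(fwd[0], 0), tf.expand_dims(rev[0], 0)], 0)
h = tf.concat([tf.expand_dims(fwd[1], 0), tf.expand_dims(rev[1], 0)], 0)
print(y.shape, h.shape)  # (2, 5, 2, 4) (2, 2, 4)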