diff --git a/tf_adapter/kernels/aicore/dynamic_rnn_v2_ops.cc b/tf_adapter/kernels/aicore/dynamic_rnn_v2_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d6cde04608f8d4e0cd865a47005df4cc00a4e2c1
--- /dev/null
+++ b/tf_adapter/kernels/aicore/dynamic_rnn_v2_ops.cc
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+template <typename T>
+class DynamicRnnV2OP : public OpKernel {
+public:
+  explicit DynamicRnnV2OP(OpKernelConstruction *ctx) : OpKernel(ctx) { LOG(INFO) << "new DynamicRnnV2OP"; }
+  ~DynamicRnnV2OP() { LOG(INFO) << "del DynamicRnnV2OP"; }
+  void Compute(OpKernelContext *ctx) override { LOG(INFO) << "in DynamicRnnV2OP"; }
+  bool IsExpensive() override { return false; }
+};
+
+REGISTER_KERNEL_BUILDER(Name("DynamicRnnV2").Device(DEVICE_CPU), DynamicRnnV2OP<float>);
+}  // namespace tensorflow
\ No newline at end of file
diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc
index c56e4f9e422f9165bf11872af67b05570f5f9101..e76449704b570b7c9ef4b6ca8a0425a0c63a57e2 100644
--- a/tf_adapter/ops/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc
@@ -219,6 +219,72 @@ REGISTER_OP("DynamicRnn")
     return Status::OK();
   });
 
+REGISTER_OP("DynamicRnnV2")
+    .Input("x: T")
+    .Input("w: T")
+    .Input("b: T")
+    .Input("init_h: T")
+    .Input("init_c: T")
+    .Output("y: T")
+    .Output("output_h: T")
+    .Output("output_c: T")
+    .Output("i: T")
+    .Output("j: T")
+    .Output("f: T")
+    .Output("o: T")
+    .Output("tanhc: T")
+    .Attr("T: {float16, float32}")
+    .Attr("cell_type: string")
+    .Attr("direction: string")
+    .Attr("cell_depth: int = 1")
+    .Attr("use_peephole: bool = false")
+    .Attr("keep_prob: float = 1.0")
+    .Attr("cell_clip: float = -1.0")
+    .Attr("num_proj: int = 0")
+    .Attr("time_major: bool = true")
+    .Attr("activation: string")
+    .Attr("forget_bias: float = 0.0")
+    .Attr("is_training: bool = true")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      auto input_shape = c->input(0);
+      auto num_step = c->Dim(input_shape, 0);
+      auto batch_size = c->Dim(input_shape, 1);
+      auto input_size = c->Dim(input_shape, 2);
+      auto w = c->input(1);
+      auto hidden_size_total = c->Dim(w, 0);
+      DimensionHandle hidden_size;
+      TF_RETURN_IF_ERROR(c->Subtract(hidden_size_total, input_size, &hidden_size));
+      int num_proj = 0;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_proj", &num_proj));
+      ShapeHandle output_y_shape;
+      if (num_proj == 0) {
+        output_y_shape = c->MakeShape({num_step, batch_size, hidden_size});
+      } else {
+        std::vector<DimensionHandle> num_projs;
+        num_projs.reserve(num_proj);
+        auto num_proj_shape = c->MakeShape(num_projs);
+        DimensionHandle num_proj_size = c->Dim(num_proj_shape, 0);
+        DimensionHandle output_hidden_size;
+        TF_RETURN_IF_ERROR(c->Min(num_proj_size, hidden_size, &output_hidden_size));
+        output_y_shape =
+            c->MakeShape({num_step, batch_size, output_hidden_size});
+      }
+      auto output_h_shape =
+          c->MakeShape({num_step, batch_size, hidden_size});
+      auto output_c_shape =
+          c->MakeShape({num_step, batch_size, hidden_size});
+
+      c->set_output(0, output_y_shape);
+      c->set_output(1, output_h_shape);
+      c->set_output(2, output_c_shape);
+      c->set_output(3, c->UnknownShape());
+      c->set_output(4, c->UnknownShape());
+      c->set_output(5, c->UnknownShape());
+      c->set_output(6, c->UnknownShape());
+      c->set_output(7, c->UnknownShape());
+      return Status::OK();
+    });
+
 REGISTER_OP("DynamicRnnGrad")
     .Input("x: T")
     .Input("w: T")
diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py
index 5cf30e3683d4784d724b0a126b801b0bf93c4351..ef9fe0d0121efa986b814b3f4d74b0544576ecc4 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py
@@ -290,34 +290,37 @@ class DynamicRNN(_DynamicBasic):
         shape=[4 * self._hidden_size],
         dtype=self._dtype,
         initializer=init_ops.zeros_initializer(dtype=self._dtype))
-    self._init_h = array_ops.zeros([1, batch_size, self._hidden_size], dtype=self._dtype)
-    self._init_c = array_ops.zeros([1, batch_size, self._hidden_size], dtype=self._dtype)
     super(DynamicRNN, self).build(input_shape)
 
   def call(self,
           x,
+          weight=None,
+          bias=None,
           seq_length=None,
           init_h=None,
           init_c=None):
    """Dynamic GRU.
    """
    super(DynamicRNN, self).call(x, seq_length=seq_length)
+    batch_size = array_ops.shape(x)[1]
+
    if init_h is None:
+      self._init_h = array_ops.zeros([1, batch_size, self._hidden_size], dtype=self._dtype)
      init_h = self._init_h
-    else:
-      init_h_shape = tensor_shape.TensorShape(init_h)
-      if init_h_shape.ndims == 2:
-        init_h = tf.reshape(init_h, [1, init_h_shape[0], init_h_shape[1]])
-    if init_c is None:
-      init_c = self._init_c
-    else:
-      init_c_shape = tensor_shape.TensorShape(init_c)
-      if init_c_shape.ndims == 2:
-        init_c = tf.reshape(init_c, [1, init_c_shape[0], init_c_shape[1]])
    if init_c is None:
+      self._init_c = array_ops.zeros([1, batch_size, self._hidden_size], dtype=self._dtype)
      init_c = self._init_c
-    self._args["w"] = self._rnn_w
-    self._args["b"] = self._rnn_b
+
+    if weight is None:
+      weight = self._rnn_w
+    if bias is None:
+      bias = self._rnn_b
+    self._args["w"] = weight
+    self._args["b"] = bias
    self._args["init_h"] = init_h
    self._args["init_c"] = init_c
-    return gen_npu_ops.dynamic_rnn(**self._args)
+    if seq_length is None:
+      self._args.pop("seq_length")
+      return gen_npu_ops.dynamic_rnn_v2(**self._args)
+    else:
+      return gen_npu_ops.dynamic_rnn(**self._args)
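Note: for orientation, here is a minimal, hedged sketch of how the reworked call() is exercised. The constructor arguments and the time-major input layout ([num_step, batch_size, input_size]) are assumptions, not part of this diff; only DynamicRNN and its dispatch rule come from the change.

# Hypothetical driver for the new dispatch path.
import tensorflow as tf
from npu_bridge.estimator.npu.npu_dynamic_rnn import DynamicRNN

x = tf.random.uniform([16, 8, 32], dtype=tf.float32)  # [num_step, batch, input_size]
rnn = DynamicRNN(hidden_size=64, dtype=tf.float32)    # ctor signature assumed

# Without seq_length, call() drops "seq_length" from self._args and invokes
# gen_npu_ops.dynamic_rnn_v2; initial states default to zeros built from the
# runtime batch size, and w/b fall back to the layer's own variables unless
# the new weight/bias arguments are supplied.
y, output_h, output_c, i, j, f, o, tanhc = rnn(x)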
diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py
index 43b32674283f72faea7b003de769aa66ad3d23bb..a5cd2fd5e39f1549ba4efdcefd107c02c2b65c08 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu_ops.py
@@ -269,6 +269,25 @@ def dynamic_rnn_grad(op, dy, dh, dc, di, dj, df, do, dtanhc):
   return (dx, dw, db, seq_length, dh_prev, dc_prev)
 
+
+@ops.RegisterGradient("DynamicRnnV2")
+def dynamic_rnn_v2_grad(op, dy, dh, dc, di, dj, df, do, dtanhc):
+  (x, w, b, init_h, init_c) = op.inputs
+  (y, output_h, output_c, i, j, f, o, tanhc) = op.outputs
+  (dw, db, dx, dh_prev, dc_prev) = gen_npu_ops.dynamic_rnn_grad(x, w, b, y, init_h[-1], init_c[-1], output_h,
+                                                                output_c, dy, dh[-1], dc[-1], i, j, f, o, tanhc,
+                                                                cell_type=op.get_attr("cell_type"),
+                                                                direction=op.get_attr("direction"),
+                                                                cell_depth=op.get_attr("cell_depth"),
+                                                                use_peephole=op.get_attr("use_peephole"),
+                                                                keep_prob=op.get_attr("keep_prob"),
+                                                                cell_clip=op.get_attr("cell_clip"),
+                                                                num_proj=op.get_attr("num_proj"),
+                                                                time_major=op.get_attr("time_major"),
+                                                                forget_bias=op.get_attr("forget_bias"))
+
+  return (dx, dw, db, dh_prev, dc_prev)
+
 
 def scatter_elements(data, indices, updates, axis=0, name=None):
   data = ops.convert_to_tensor(data, name="data")
   indices = ops.convert_to_tensor(indices, name="indices")
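Note: the registered gradient reuses the existing DynamicRnnGrad kernel, feeding it the last time step of the initial states and incoming state gradients (init_h[-1], dh[-1], dc[-1]); since DynamicRnnV2 takes no seq_length input, only five gradients are returned. A hedged sketch of triggering it, with the same assumed DynamicRNN driver as above:

# Hypothetical gradient check; shapes and ctor signature are assumptions.
import tensorflow as tf
from npu_bridge.estimator.npu.npu_dynamic_rnn import DynamicRNN

x = tf.random.uniform([16, 8, 32], dtype=tf.float32)
rnn = DynamicRNN(hidden_size=64, dtype=tf.float32)
y = rnn(x)[0]                  # forward pass through DynamicRnnV2

loss = tf.reduce_sum(y)
# Differentiating through y invokes dynamic_rnn_v2_grad, which lowers to the
# existing DynamicRnnGrad kernel with the last-step slices shown above.
dx, = tf.gradients(loss, [x])  # same shape as x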
diff --git a/tf_adapter/tests/st/kernels/testcase/dynamic_rnn_v2_test.cc b/tf_adapter/tests/st/kernels/testcase/dynamic_rnn_v2_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7654dcc089502b8571c30cdb85bc379928c6a28c
--- /dev/null
+++ b/tf_adapter/tests/st/kernels/testcase/dynamic_rnn_v2_test.cc
@@ -0,0 +1,72 @@
+#include "tf_adapter/kernels/aicore/dynamic_rnn_v2_ops.cc"
+#include <memory>
+#include "gtest/gtest.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+PartialTensorShape TShape(std::initializer_list<int64> dims) {
+  return PartialTensorShape(dims);
+}
+
+FakeInputFunctor FakeInputStub(DataType dt) {
+  return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+              NodeDefBuilder* builder) {
+    char c = 'a' + (in_index % 26);
+    string in_node = string(&c, 1);
+    builder->Input(in_node, 0, dt);
+    return Status::OK();
+  };
+}
+
+TEST(DynamicRnnV2OpTest, TestDynamicRnnV2) {
+  DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
+                              DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types,
+                                                           output_memory_types, 1, nullptr);
+  DynamicRnnV2OP<float> dynamic_rnn_v2(context);
+  OpKernelContext *ctx = nullptr;
+  dynamic_rnn_v2.Compute(ctx);
+  dynamic_rnn_v2.IsExpensive();
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(DynamicRnnV2OpTest, TestDynamicRnnV2ShapeInference) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("DynamicRnnV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("direction", "BIDIRECTIONAL")
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(0, &def, op_def, {TShape({1, 16, 16}), TShape({32, 64}), TShape({64}),
+                                      TShape({1, 16, 16}), TShape({1, 16, 16})}, {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+}  // namespace
+}  // namespace tensorflow
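Note: as a hand check of the shape function on the inference test's inputs (my arithmetic, not asserted by the test itself), x is {1,16,16} and w is {32,64}, so hidden_size is w's dim 0 minus input_size:

# Hand evaluation of the SetShapeFn above on the test's input shapes.
num_step, batch_size, input_size = 1, 16, 16   # x: TShape({1,16,16})
hidden_size_total = 32                         # w: TShape({32,64}), dim 0
hidden_size = hidden_size_total - input_size   # 32 - 16 = 16
# num_proj defaults to 0, so the first branch applies and y, output_h and
# output_c all infer to [num_step, batch_size, hidden_size] = [1, 16, 16];
# outputs 3..7 (i, j, f, o, tanhc) are set to UnknownShape.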
diff --git a/tf_adapter/tests/st/kernels/testcase/layer_norm_grad_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/layer_norm_grad_ops_test.cc
index 3ee42c2fcc9d4401d2867844ab0e120df676112f..532abc7d40a83b343feeda689e92d83930a2ba02 100644
--- a/tf_adapter/tests/st/kernels/testcase/layer_norm_grad_ops_test.cc
+++ b/tf_adapter/tests/st/kernels/testcase/layer_norm_grad_ops_test.cc
@@ -21,6 +21,8 @@ FakeInputFunctor FakeInputStub(DataType dt) {
   };
 }
 
+
+
 TEST(LayerNormGradOpTest, TestLayerNormGrad) {
   DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
   MemoryTypeSlice input_memory_types;
diff --git a/tf_adapter/tests/ut/kernels/testcase/dynamic_rnn_v2_test.cc b/tf_adapter/tests/ut/kernels/testcase/dynamic_rnn_v2_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7654dcc089502b8571c30cdb85bc379928c6a28c
--- /dev/null
+++ b/tf_adapter/tests/ut/kernels/testcase/dynamic_rnn_v2_test.cc
@@ -0,0 +1,72 @@
+#include "tf_adapter/kernels/aicore/dynamic_rnn_v2_ops.cc"
+#include <memory>
+#include "gtest/gtest.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+PartialTensorShape TShape(std::initializer_list<int64> dims) {
+  return PartialTensorShape(dims);
+}
+
+FakeInputFunctor FakeInputStub(DataType dt) {
+  return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+              NodeDefBuilder* builder) {
+    char c = 'a' + (in_index % 26);
+    string in_node = string(&c, 1);
+    builder->Input(in_node, 0, dt);
+    return Status::OK();
+  };
+}
+
+TEST(DynamicRnnV2OpTest, TestDynamicRnnV2) {
+  DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT,
+                              DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types,
+                                                           output_memory_types, 1, nullptr);
+  DynamicRnnV2OP<float> dynamic_rnn_v2(context);
+  OpKernelContext *ctx = nullptr;
+  dynamic_rnn_v2.Compute(ctx);
+  dynamic_rnn_v2.IsExpensive();
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(DynamicRnnV2OpTest, TestDynamicRnnV2ShapeInference) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("DynamicRnnV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("direction", "BIDIRECTIONAL")
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Input(FakeInputStub(DT_FLOAT))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(0, &def, op_def, {TShape({1, 16, 16}), TShape({32, 64}), TShape({64}),
+                                      TShape({1, 16, 16}), TShape({1, 16, 16})}, {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+}  // namespace
+}  // namespace tensorflow
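Note: for completeness, a hedged sketch of driving the new op through the generated Python wrapper. The helper import mirrors how npu_bridge exposes generated ops elsewhere (an assumption here), and every shape and attr value below is an illustrative choice; only the op name and its attr list come from the REGISTER_OP("DynamicRnnV2") definition above.

# Hypothetical direct call; cell_type, direction and activation have no
# defaults in the op definition, so they must be supplied.
import tensorflow as tf
from npu_bridge.helper import helper  # assumed accessor for generated ops

gen_npu_ops = helper.get_gen_ops()

hidden, batch, inp, steps = 64, 8, 32, 16
x = tf.random.uniform([steps, batch, inp])         # time-major input
w = tf.random.uniform([inp + hidden, 4 * hidden])  # fused weight: dim 0 = input + hidden
b = tf.zeros([4 * hidden])
init_h = tf.zeros([1, batch, hidden])
init_c = tf.zeros([1, batch, hidden])

y, output_h, output_c, i, j, f, o, tanhc = gen_npu_ops.dynamic_rnn_v2(
    x=x, w=w, b=b, init_h=init_h, init_c=init_c,
    cell_type="LSTM", direction="UNIDIRECTIONAL", activation="tanh")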