diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 9734c7ba77302fe338e4488198c4814f6b6c278b..c72bf9c05f61bad1e7a82fe49a9c0def11948494 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -113,6 +113,7 @@ REGISTER_OP("DynamicGruV2Grad") .Input("reset: T") .Input("new: T") .Input("hidden_new: T") + .Input("seq_length: int32") .Output("dw_input: T") .Output("dw_hidden: T") .Output("db_input: T") @@ -226,6 +227,7 @@ REGISTER_OP("DynamicAUGRUGrad") .Input("reset: T") .Input("new: T") .Input("hidden_new: T") + .Input("seq_length: int32") .Output("dw_input: T") .Output("dw_hidden: T") .Output("db_input: T") diff --git a/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py b/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py index 3804be9cdb4e632a9f655a8a9ec12ec1bd031b0a..5d17344ee8925c2aa0da7cdba327db4b250b5d49 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py +++ b/tf_adapter/python/npu_bridge/estimator/npu/npu_dynamic_rnn.py @@ -231,6 +231,8 @@ class DynamicGRUV2(_DynamicBasic): self._args["weight_hidden"] = self._gruv2_weight_hidden self._args["bias_input"] = self._bias_input self._args["bias_hidden"] = self._bias_hidden + if seq_length is not None: + self._args["seq_length"] = seq_length return gen_npu_ops.dynamic_gru_v2(**self._args) @@ -326,6 +328,8 @@ class DynamicAUGRU(_DynamicBasic): self._args["weight_att"] = weight_att self._args["bias_input"] = self._bias_input self._args["bias_hidden"] = self._bias_hidden + if seq_length is not None: + self._args["seq_length"] = seq_length return gen_npu_ops.dynamic_augru(**self._args) diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py index 93a3b18036f78e69141ca18c409deccd552dac81..907479c76217bedab72ee60d0dd34833016a613f 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py +++ 
b/tf_adapter/python/npu_bridge/estimator/npu_ops.py @@ -230,9 +230,11 @@ def dynamic_gru_v2_grad(op, dy, doutput_h, dupdate, dreset, dnew, dhidden_new): (y, output_h, update, reset, new, hidden_new) = op.outputs (dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev) = gen_npu_ops.dynamic_gru_v2_grad(x, weight_input, weight_hidden, y, init_h, - output_h, dy, doutput_h, + output_h, dy, + doutput_h[-1], update, reset, new, hidden_new, + seq_length, direction=op.get_attr( "direction"), cell_depth=op.get_attr( @@ -268,10 +270,11 @@ def dynamic_augru_grad(op, dy, doutput_h, dupdate, dupdate_att, dreset, dnew, dh weight_hidden, weight_att, y, init_h, output_h, - dy, doutput_h, + dy, doutput_h[-1], update, update_att, reset, new, hidden_new, + seq_length, direction=op.get_attr( "direction"), cell_depth=op.get_attr( diff --git a/tf_adapter/tests/st/kernels/testcase/dynamic_augru_grad_test.cc b/tf_adapter/tests/st/kernels/testcase/dynamic_augru_grad_test.cc index 6cf752c5ed5c087f705de27c68e1c110a076c011..71febcbd1f18bd1a2dcc47680789973e14f8f01c 100644 --- a/tf_adapter/tests/st/kernels/testcase/dynamic_augru_grad_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/dynamic_augru_grad_test.cc @@ -29,7 +29,7 @@ FakeInputFunctor FakeInputStub(DataType dt) { TEST(DynamicAUGRUGradTest, TestDynamicAUGRUGrad) { DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_FLOAT, DT_FLOAT}); + DT_FLOAT, DT_FLOAT, DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types( {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}); @@ -72,13 +72,14 @@ TEST(DynamicAUGRUGradOpTest, TestDynamicAUGRUGradShapeInference) { .Input(FakeInputStub(DT_FLOAT)) .Input(FakeInputStub(DT_FLOAT)) .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_INT32)) .Finalize(&def)); shape_inference::InferenceContext c( 0, &def, op_def, {TShape({1, 16, 16}), TShape({16, 48}), TShape({16, 
48}), TShape({1, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), - TShape({16, 16})}, + TShape({16, 16}), TShape({16})}, {}, {}, {}); TF_CHECK_OK(reg->shape_inference_fn(&c)); } diff --git a/tf_adapter/tests/st/kernels/testcase/dynamic_gruv2_grad_test.cc b/tf_adapter/tests/st/kernels/testcase/dynamic_gruv2_grad_test.cc index d53e4d0bcc41381e3f572014f99a3770632a5fc1..b60069e51264d600bdd9be93e83204a03ba33f19 100644 --- a/tf_adapter/tests/st/kernels/testcase/dynamic_gruv2_grad_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/dynamic_gruv2_grad_test.cc @@ -29,7 +29,7 @@ FakeInputFunctor FakeInputStub(DataType dt) { TEST(DynamicGruV2GradTest, TestDynamicGruV2Grad) { DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_FLOAT, DT_FLOAT}); + DT_FLOAT, DT_FLOAT, DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types( {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}); @@ -70,13 +70,14 @@ TEST(DynamicGruV2GradOpTest, TestDynamicGruV2GradShapeInference) { .Input(FakeInputStub(DT_FLOAT)) .Input(FakeInputStub(DT_FLOAT)) .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_INT32)) .Finalize(&def)); shape_inference::InferenceContext c( 0, &def, op_def, {TShape({1, 16, 16}), TShape({16, 48}), TShape({16, 48}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), - TShape({16, 16})}, + TShape({16, 16}), TShape({16})}, {}, {}, {}); TF_CHECK_OK(reg->shape_inference_fn(&c)); } diff --git a/tf_adapter/tests/ut/kernels/testcase/dynamic_augru_grad_test.cc b/tf_adapter/tests/ut/kernels/testcase/dynamic_augru_grad_test.cc index 7f9dcc4daa88dfa256c3f91eb07f3fa8f5d6a467..1cec26db5d0b4866f46ea97d1e99e098335e33d3 100644 --- 
a/tf_adapter/tests/ut/kernels/testcase/dynamic_augru_grad_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/dynamic_augru_grad_test.cc @@ -27,60 +27,58 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(DynamicAUGRUGradTest, TestDynamicAUGRUGrad) { -DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_FLOAT, DT_FLOAT}); -MemoryTypeSlice input_memory_types; -DataTypeSlice output_types( - {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}); -MemoryTypeSlice output_memory_types; -DeviceBase *device = new DeviceBase(Env::Default()); -NodeDef *node_def = new NodeDef(); -OpDef *op_def = new OpDef(); -OpKernelConstruction *context = new OpKernelConstruction( - DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, input_types, - input_memory_types, output_types, output_memory_types, 1, nullptr); -DynamicAUGRUGradOP dynamic_augru_grad(context); -OpKernelContext *ctx = nullptr; -dynamic_augru_grad.Compute(ctx); -dynamic_augru_grad.IsExpensive(); -delete device; -delete node_def; -delete op_def; -delete context; + DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, + DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = + new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, input_types, input_memory_types, + output_types, output_memory_types, 1, nullptr); + DynamicAUGRUGradOP dynamic_augru_grad(context); + OpKernelContext *ctx = nullptr; + dynamic_augru_grad.Compute(ctx); + dynamic_augru_grad.IsExpensive(); + delete 
device; + delete node_def; + delete op_def; + delete context; } TEST(DynamicAUGRUGradOpTest, TestDynamicAUGRUGradShapeInference) { -const OpRegistrationData *reg; -TF_CHECK_OK(OpRegistry::Global()->LookUp("DynamicAUGRUGrad", ®)); -OpDef op_def = reg->op_def; -NodeDef def; -TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) -.Attr("T", DT_FLOAT) -.Attr("direction", "BIDIRECTIONAL") -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Input(FakeInputStub(DT_FLOAT)) -.Finalize(&def)); -shape_inference::InferenceContext c( - 0, &def, op_def, - {TShape({1, 16, 16}), TShape({16, 48}), TShape({16, 48}), TShape({1, 16}), - TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), - TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), - TShape({16, 16})}, - {}, {}, {}); -TF_CHECK_OK(reg->shape_inference_fn(&c)); + const OpRegistrationData *reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("DynamicAUGRUGrad", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("T", DT_FLOAT) + .Attr("direction", "BIDIRECTIONAL") + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + 
.Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, + {TShape({1, 16, 16}), TShape({16, 48}), TShape({16, 48}), TShape({1, 16}), + TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), + TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), + TShape({16, 16}), TShape({16, 16}), TShape({16})}, + {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); } -} // namespace -} // namespace tensorflow \ No newline at end of file +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/dynamic_gruv2_grad_test.cc b/tf_adapter/tests/ut/kernels/testcase/dynamic_gruv2_grad_test.cc index d53e4d0bcc41381e3f572014f99a3770632a5fc1..b60069e51264d600bdd9be93e83204a03ba33f19 100644 --- a/tf_adapter/tests/ut/kernels/testcase/dynamic_gruv2_grad_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/dynamic_gruv2_grad_test.cc @@ -29,7 +29,7 @@ FakeInputFunctor FakeInputStub(DataType dt) { TEST(DynamicGruV2GradTest, TestDynamicGruV2Grad) { DataTypeSlice input_types({DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, - DT_FLOAT, DT_FLOAT}); + DT_FLOAT, DT_FLOAT, DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types( {DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT}); @@ -70,13 +70,14 @@ TEST(DynamicGruV2GradOpTest, TestDynamicGruV2GradShapeInference) { .Input(FakeInputStub(DT_FLOAT)) .Input(FakeInputStub(DT_FLOAT)) .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_INT32)) .Finalize(&def)); shape_inference::InferenceContext c( 0, &def, op_def, {TShape({1, 16, 16}), TShape({16, 48}), TShape({16, 48}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), TShape({16, 16}), - TShape({16, 16})}, + TShape({16, 16}), TShape({16})}, {}, {}, {}); TF_CHECK_OK(reg->shape_inference_fn(&c)); }