diff --git a/test/test_network_ops/test_gru.py b/test/test_network_ops/test_gru.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fdc2e17ef55329f06c07b5b1cbd4566a875e7fc
--- /dev/null
+++ b/test/test_network_ops/test_gru.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+class TestGru(TestCase):
+    def test_gru(self, device="npu"):
+        shape_format = [
+            [[np.float16, (1, 3, 2)], [np.float16, (1, 3, 2)], 2, 2, 1, False, True, False],
+            [[np.float32, (2, 1, 1)], [np.float32, (1, 2, 2)], 1, 2, 1, False, False, True],
+            [[np.float16, (1, 3, 1)], [np.float16, (2, 3, 2)], 1, 2, 2, False, True, False],
+            [[np.float32, (1, 1, 2)], [np.float32, (1, 1, 3)], 2, 3, 1, False, False, False],
+            [[np.float16, (1, 1, 1)], [np.float16, (3, 1, 1)], 1, 1, 3, False, True, True],
+        ]
+
+        for item in shape_format:
+            cpu_gru = torch.nn.GRU(input_size=item[2], hidden_size=item[3], num_layers=item[4],
+                                   bidirectional=item[5], bias=item[-2], batch_first=item[-1])
+            npu_gru = copy.deepcopy(cpu_gru).npu()
+
+            input1 = np.random.uniform(0, 1, item[0][1]).astype(item[0][0])
+            if item[0][0] == np.float16:
+                cpu_input1 = torch.from_numpy(input1.astype(np.float32))
+            else:
+                cpu_input1 = torch.from_numpy(input1)
+            npu_input1 = torch.from_numpy(input1).npu()
+
+            h0 = np.random.uniform(0, 1, item[1][1]).astype(item[1][0])
+            if item[1][0] == np.float16:
+                cpu_h0 = torch.from_numpy(h0.astype(np.float32))
+            else:
+                cpu_h0 = torch.from_numpy(h0)
+            npu_h0 = torch.from_numpy(h0).npu()
+
+            cpu_output_y, cpu_output_h = cpu_gru(cpu_input1, cpu_h0)
+            npu_output_y, npu_output_h = npu_gru(npu_input1, npu_h0)
+
+            if item[0][0] == np.float16:
+                self.assertRtolEqual(cpu_output_y.detach().numpy().astype(np.float16),
+                                     npu_output_y.cpu().detach().numpy())
+                self.assertRtolEqual(cpu_output_h.detach().numpy().astype(np.float16),
+                                     npu_output_h.cpu().detach().numpy())
+            else:
+                # Ascend: fp32 precision is insufficient here; temporarily relax the precision requirement
+                self.assertRtolEqual(cpu_output_y.detach().numpy(), npu_output_y.cpu().detach().numpy(), prec=1.e-1)
+                self.assertRtolEqual(cpu_output_h.detach().numpy(), npu_output_h.cpu().detach().numpy(), prec=1.e-1)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_gru_backward.py b/test/test_network_ops/test_gru_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a1fa437172f4b62b7e03885268940503798e6d1
--- /dev/null
+++ b/test/test_network_ops/test_gru_backward.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+class TestGruBackward(TestCase):
+    def test_gru_backward(self, device="npu"):
+        shape_format = [
+            [[np.float16, (16, 32, 64)], 64, 32],
+            [[np.float16, (5, 32, 64)], 64, 32],
+            [[np.float32, (5, 32, 64)], 64, 32],
+            [[np.float32, (5, 32, 64)], 64, 64],
+        ]
+
+        for item in shape_format:
+            cpu_gru = torch.nn.GRU(input_size=item[1], hidden_size=item[2], num_layers=1, bidirectional=False)
+            cpu_gru.weight_ih_l0.requires_grad_(True)
+            cpu_gru.weight_hh_l0.requires_grad_(True)
+            cpu_gru.bias_ih_l0.requires_grad_(True)
+            cpu_gru.bias_hh_l0.requires_grad_(True)
+            npu_gru = copy.deepcopy(cpu_gru).npu()
+
+            input1 = np.random.uniform(0, 1, item[0][1]).astype(item[0][0])
+            cpu_input1 = torch.from_numpy(input1.astype(np.float32))
+            cpu_input1.requires_grad_(True)
+            npu_input1 = torch.from_numpy(input1).npu()
+            npu_input1.requires_grad_(True)
+
+            cpu_output_y, cpu_output_h = cpu_gru(cpu_input1)
+            npu_output_y, npu_output_h = npu_gru(npu_input1)
+
+            self.assertRtolEqual(cpu_output_y.detach().numpy(), npu_output_y.cpu().detach().numpy().astype(np.float32),
+                                 prec=1.e-1)
+            self.assertRtolEqual(cpu_output_h.detach().numpy(), npu_output_h.cpu().detach().numpy().astype(np.float32),
+                                 prec=1.e-1)
+
+            cpu_input1.retain_grad()
+            cpu_output_y.backward(torch.ones(cpu_output_y.size(), dtype=torch.float))
+            cpu_dx = cpu_input1.grad
+            cpu_dw_ih = cpu_gru.weight_ih_l0.grad
+            cpu_dw_hh = cpu_gru.weight_hh_l0.grad
+            cpu_db_ih = cpu_gru.bias_ih_l0.grad
+            cpu_db_hh = cpu_gru.bias_hh_l0.grad
+
+            npu_input1.retain_grad()
+            npu_output_y.backward(torch.ones(npu_output_y.size(), dtype=torch.float).npu())
+            npu_dx = npu_input1.grad
+            npu_dw_ih = npu_gru.weight_ih_l0.grad
+            npu_dw_hh = npu_gru.weight_hh_l0.grad
+            npu_db_ih = npu_gru.bias_ih_l0.grad
+            npu_db_hh = npu_gru.bias_hh_l0.grad
+
+            self.assertRtolEqual(cpu_dx.numpy(), npu_dx.cpu().numpy().astype(np.float32), prec=1.e-1)
+            self.assertRtolEqual(cpu_dw_ih.numpy(), npu_dw_ih.cpu().numpy().astype(np.float32), prec=1.e-1)
+            self.assertRtolEqual(cpu_dw_hh.numpy(), npu_dw_hh.cpu().numpy().astype(np.float32), prec=1.e-1)
+            self.assertRtolEqual(cpu_db_ih.numpy(), npu_db_ih.cpu().numpy().astype(np.float32), prec=1.e1)
+            self.assertRtolEqual(cpu_db_hh.numpy(), npu_db_hh.cpu().numpy().astype(np.float32), prec=1.e1)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_network_ops/test_gru_true.py b/test/test_network_ops/test_gru_true.py
new file mode 100644
index 0000000000000000000000000000000000000000..521fa5d57cf7df56cc7063086e040fa7e1893421
--- /dev/null
+++ b/test/test_network_ops/test_gru_true.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+class TestGru(TestCase):
+    def test_gru(self, device="npu"):
+        shape_format = [
+            [[np.float32, (2, 3, 2)], [np.float32, (2, 2, 1)], 2, 1, 1, True, False, True],
+            [[np.float32, (1, 1, 1)], [np.float32, (6, 1, 1)], 1, 1, 3, True, True, False],
+            [[np.float32, (2, 1, 1)], [np.float32, (4, 1, 1)], 1, 1, 2, True, True, False],
+            [[np.float16, (1, 2, 3)], [np.float16, (4, 1, 2)], 3, 2, 2, True, False, True],
+            [[np.float32, (2, 2, 1)], [np.float32, (2, 2, 2)], 1, 2, 1, True, True, False],
+            [[np.float16, (1, 2, 1)], [np.float16, (4, 1, 2)], 1, 2, 2, True, False, True],
+        ]
+
+        for item in shape_format:
+            cpu_gru = torch.nn.GRU(input_size=item[2], hidden_size=item[3], num_layers=item[4],
+                                   bidirectional=item[5], bias=item[-2], batch_first=item[-1])
+            npu_gru = copy.deepcopy(cpu_gru).npu()
+
+            input2 = np.random.uniform(0, 1, item[0][1]).astype(item[0][0])
+            if item[0][0] == np.float16:
+                cpu_input2 = torch.from_numpy(input2.astype(np.float32))
+            else:
+                cpu_input2 = torch.from_numpy(input2)
+            npu_input2 = torch.from_numpy(input2).npu()
+
+            h0 = np.random.uniform(0, 1, item[1][1]).astype(item[1][0])
+            if item[1][0] == np.float16:
+                cpu_h0 = torch.from_numpy(h0.astype(np.float32))
+            else:
+                cpu_h0 = torch.from_numpy(h0)
+            npu_h0 = torch.from_numpy(h0).npu()
+
+            cpu_output_y1, cpu_output_h1 = cpu_gru(cpu_input2, cpu_h0)
+            npu_output_y1, npu_output_h1 = npu_gru(npu_input2, npu_h0)
+
+            if item[0][0] == np.float16:
+                self.assertRtolEqual(cpu_output_y1.detach().numpy().astype(np.float16),
+                                     npu_output_y1.cpu().detach().numpy())
+                self.assertRtolEqual(cpu_output_h1.detach().numpy().astype(np.float16),
+                                     npu_output_h1.cpu().detach().numpy())
+            else:
+                # Ascend: fp32 precision is insufficient here; temporarily relax the precision requirement
+                self.assertRtolEqual(cpu_output_y1.detach().numpy(), npu_output_y1.cpu().detach().numpy(), prec=1.e-1)
+                self.assertRtolEqual(cpu_output_h1.detach().numpy(), npu_output_h1.cpu().detach().numpy(), prec=1.e-1)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index f794be4e55b7e46b46793010b04f51eb2ae36f72..4ca5faea1b49bb82f9527cc67d66f78b8bfc4086 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1048,7 +1048,6 @@ supported:
   - lstm.input
   - lstm.data
   - gru.input
-  - gru.data
   - rnn_tanh.input
   - rnn_tanh.data
   - rnn_relu.input
@@ -1928,6 +1927,8 @@ custom:
   - func: npu_rotated_overlaps(Tensor self, Tensor query_boxes, bool trans=False) -> Tensor
   - func: npu_silu_backward(Tensor grad_output, Tensor x0, Tensor x1) -> Tensor
   - func: npu_rotated_iou(Tensor self, Tensor query_boxes, bool trans=False, int mode=0, bool is_cross=True, float v_threshold=0.0, float e_threshold=0.0) -> Tensor
+  - func: npu_gru_backward(Tensor? grady, Tensor? gradh, Tensor input, Tensor weight_input, Tensor weight_hidden, Tensor bias_input, Tensor bias_hidden, Tensor seq_length, Tensor hx, Tensor y_output, Tensor h_output, Tensor output_updata, Tensor output_reset, Tensor output_new, Tensor hidden_new) -> Tensor[]
+
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
@@ -1948,4 +1949,5 @@ custom_autograd:
   - func: npu_dtype_cast(Tensor self, ScalarType dtype) -> Tensor
     variants: function, method
   - func: npu_silu(Tensor self) -> Tensor
-  - func: npu_silu_(Tensor(a!) self) -> Tensor(a!)
\ No newline at end of file
+  - func: npu_silu_(Tensor(a!) self) -> Tensor(a!)
+  - func: npu_gru(Tensor input, Tensor hx, Tensor weight_input, Tensor weight_hidden, Tensor bias_input, Tensor bias_hidden, Tensor seq_length, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> Tensor[]
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/GruKernelNpu.cpp b/torch_npu/csrc/aten/ops/GruKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..171d3e3461350990d1a65a36d9902c3c045586b3
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/GruKernelNpu.cpp
@@ -0,0 +1,565 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_list = std::vector<at::Tensor>;
+
+struct CellParams {
+  CellParams(const at::Tensor& _w_ih, const at::Tensor& _w_hh)
+      : w_ih(_w_ih), w_hh(_w_hh), b_ih({}), b_hh({}) {};
+  CellParams(const at::Tensor& _w_ih, const at::Tensor& _w_hh, const at::Tensor& _b_ih, const at::Tensor& _b_hh)
+      : w_ih(_w_ih), w_hh(_w_hh), b_ih(_b_ih), b_hh(_b_hh) {};
+
+  const at::Tensor& w_ih;
+  const at::Tensor& w_hh;
+  const at::Tensor& b_ih; /* optional */
+  const at::Tensor& b_hh; /* optional */
+};
+
+using BidirectCellParams = std::pair<CellParams, CellParams>;
+using pair_of = std::pair<at::Tensor, at::Tensor>;
+static std::vector<pair_of> make_pair_vec(const std::vector<at::Tensor>& vals) {
+  TORCH_CHECK(vals.size() % 2 == 0, "Odd number of params or hiddens given to a bidirectional RNN");
+  std::vector<pair_of> result;
+  result.reserve(vals.size() / 2);
+  for (size_t i = 0; i < vals.size(); i += 2) {
+    result.emplace_back(vals[i], vals[i + 1]);
+  }
+  return result;
+}
+
+std::vector<at::Tensor> gru_npu(
+    const at::Tensor& input,
+    const at::Tensor& hx,
+    const at::Tensor& weight_input,
+    const at::Tensor& weight_hidden,
+    const at::Tensor& bias_input,
+    const at::Tensor& bias_hidden,
+    const at::Tensor& seq_length,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  int64_t numStep = input.size(0);
+  int64_t batchSize = input.size(1);
+  int64_t hiddenSize = bias_input.size(0) / 3;
+  c10::SmallVector<int64_t, SIZE> outputSize = {numStep, batchSize, hiddenSize};
+  int64_t npu_format = ACL_FORMAT_FRACTAL_NZ;
+
+  at::Tensor output_y = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      bias_input.options(),
+      npu_format);
+  at::Tensor output_h = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      bias_input.options(),
+      ACL_FORMAT_ND);
+  at::Tensor output_updata = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      bias_input.options(),
+      npu_format);
+  at::Tensor output_reset = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      bias_input.options(),
+      npu_format);
+  at::Tensor output_new = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      bias_input.options(),
+      npu_format);
+  at::Tensor hidden_new = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      bias_input.options(),
+      npu_format);
+
+  OpCommand cmd;
+  cmd.Name("DynamicGRUV2")
+      .Input(input)
+      .Input(weight_input)
+      .Input(weight_hidden)
+      .Input(bias_input)
+      .Input(bias_hidden)
+      .Input()
+      .Input(hx)
+      .Output(output_y)
+      .Output(output_h)
+      .Output(output_updata)
+      .Output(output_reset)
+      .Output(output_new)
+      .Output(hidden_new)
+      .Attr("direction", (string)"UNIDIRECTIONAL")
+      .Attr("cell_depth", (int64_t)1)
+      .Attr("keep_prob", (float)1.0)
+      .Attr("cell_clip", (float)-1.0)
+      .Attr("num_proj", (int64_t)0)
+      .Attr("time_major", true)
+      .Attr("activation", (string)"tanh")
+      .Attr("gate_order", (string)"rzh")
+      .Attr("reset_after", true)
+      .Attr("is_training", true)
+      .Run();
+  tensor_list results = {output_y, output_h, output_updata, output_reset, output_new, hidden_new};
+  return results;
+}
+
+std::vector<at::Tensor> NPUNativeFunctions::npu_gru_backward(
+    const c10::optional<at::Tensor>& grady_opt,
+    const c10::optional<at::Tensor>& gradh_opt,
+    const at::Tensor& input,
+    const at::Tensor& weight_input,
+    const at::Tensor& weight_hidden,
+    const at::Tensor& bias_input,
+    const at::Tensor& bias_hidden,
+    const at::Tensor& seq_length,
+    const at::Tensor& init_h,
+    const at::Tensor& output_y,
+    const at::Tensor& output_h,
+    const at::Tensor& output_updata,
+    const at::Tensor& output_reset,
+    const at::Tensor& output_new,
+    const at::Tensor& hidden_new) {
+  const at::Tensor& grady = c10::value_or_else(grady_opt, [] { return at::Tensor(); });
+  const at::Tensor& gradh = c10::value_or_else(gradh_opt, [] { return at::Tensor(); });
+  at::Tensor inh = at::squeeze(init_h, 0);
+  auto grad_y = grady.defined()
+      ? grady
+      : OpPreparation::ApplyTensorWithFormat(output_y.sizes(), output_y.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+  auto grad_h = gradh.defined()
+      ? gradh[input.size(0) - 1]
+      : OpPreparation::ApplyTensorWithFormat(inh.sizes(), output_h.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+
+  at::Tensor mask = at::zeros({}, input.options().dtype(at::kByte)); // uint8
+  at::Tensor seq_lengths = at::zeros({}, input.options());
+
+  int64_t npu_format = ACL_FORMAT_ND;
+
+  at::Tensor grad_w_input = OpPreparation::ApplyTensorWithFormat(weight_input.sizes(), input.options(), npu_format);
+  at::Tensor grad_w_hidden = OpPreparation::ApplyTensorWithFormat(weight_hidden.sizes(), input.options(), npu_format);
+  at::Tensor grad_x = OpPreparation::ApplyTensorWithFormat(input.sizes(), input.options(), npu_format);
+  at::Tensor grad_b_input = OpPreparation::ApplyTensorWithFormat(bias_input.sizes(), input.options(), npu_format);
+  at::Tensor grad_b_hidden = OpPreparation::ApplyTensorWithFormat(bias_hidden.sizes(), input.options(), npu_format);
+  at::Tensor grad_h_prev = OpPreparation::ApplyTensorWithFormat(init_h.sizes(), input.options(), npu_format);
+
+  OpCommand cmd;
+  cmd.Name("DynamicGRUV2Grad")
+      .Input(input)
+      .Input(weight_input)
+      .Input(weight_hidden)
+      .Input(output_y)
+      .Input(inh)
+      .Input(output_h)
+      .Input(grad_y)
+      .Input(grad_h)
+      .Input(output_updata)
+      .Input(output_reset)
+      .Input(output_new)
+      .Input(hidden_new)
+      .Input(seq_lengths)
+      .Input(mask)
+      .Output(grad_w_input)
+      .Output(grad_w_hidden)
+      .Output(grad_b_input)
+      .Output(grad_b_hidden)
+      .Output(grad_x)
+      .Output(grad_h_prev)
+      .Attr("direction", (string)"UNIDIRECTIONAL")
+      .Attr("cell_depth", (int64_t)1)
+      .Attr("keep_prob", (float)1.0)
+      .Attr("cell_clip", (float)-1.0)
+      .Attr("num_proj", (int64_t)0)
+      .Attr("time_major", (bool)true)
+      .Attr("bias_type", (string)"no_bias")
+      .Attr("gate_order", (string)"rzh")
+      .Attr("reset_after", (bool)true)
+      .Run();
+  tensor_list results = {grad_x, grad_h_prev, grad_w_input, grad_w_hidden, grad_b_input, grad_b_hidden};
+  return results;
+}
+
+class NPUGruFunction : public torch::autograd::Function<NPUGruFunction> {
+public:
+  static tensor_list forward(AutogradContext *ctx,
+      const at::Tensor& input,
+      const at::Tensor& hx,
+      const at::Tensor& weight_input,
+      const at::Tensor& weight_hidden,
+      const at::Tensor& bias_input,
+      const at::Tensor& bias_hidden,
+      const at::Tensor& seq_length,
+      bool has_biases,
+      int64_t num_layers,
+      double dropout,
+      bool train,
+      bool bidirectional,
+      bool batch_first) {
+    auto result = gru_npu(input, hx, weight_input, weight_hidden,
+        bias_input, bias_hidden, seq_length, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+    auto result0 = result[0];
+    ctx->saved_data["result0"] = result0;
+    auto result1 = result[1];
+    ctx->saved_data["result1"] = result1;
+    auto result2 = result[2];
+    ctx->saved_data["result2"] = result2;
+    auto result3 = result[3];
+    ctx->saved_data["result3"] = result3;
+    auto result4 = result[4];
+    ctx->saved_data["result4"] = result4;
+    auto result5 = result[5];
+    ctx->saved_data["result5"] = result5;
+    ctx->saved_data["seq_length"] = seq_length;
+
+    at::AutoNonVariableTypeMode g;
+    ctx->save_for_backward({weight_input, weight_hidden, input, bias_input, bias_hidden, hx});
+    return result;
+  }
+
+  static tensor_list backward(AutogradContext *ctx,
+      tensor_list grad_outputs) {
+    auto result0 = ctx->saved_data["result0"].toTensor();
+    auto result1 = ctx->saved_data["result1"].toTensor();
+    auto result2 = ctx->saved_data["result2"].toTensor();
+    auto result3 = ctx->saved_data["result3"].toTensor();
+    auto result4 = ctx->saved_data["result4"].toTensor();
+    auto result5 = ctx->saved_data["result5"].toTensor();
+    auto seq_length = ctx->saved_data["seq_length"].toTensor();
+
+    auto saved = ctx->get_saved_variables();
+    auto weight_input = saved[0];
+    auto weight_hidden = saved[1];
+    auto input = saved[2];
+    auto bias_input = saved[3];
+    auto bias_hidden = saved[4];
+    auto hx = saved[5];
+
+    tensor_list result = NPUNativeFunctions::npu_gru_backward(
+        grad_outputs[0],
+        grad_outputs[1],
+        input,
+        weight_input,
+        weight_hidden,
+        bias_input,
+        bias_hidden,
+        seq_length,
+        hx,
+        result0,
+        result1,
+        result2,
+        result3,
+        result4,
+        result5);
+
+    tensor_list output = {
+        result[0],
+        result[1],
+        result[2],
+        result[3],
+        result[4],
+        result[5],
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor(),
+        at::Tensor()};
+    return output;
+  }
+};
+
+std::vector<at::Tensor> NPUNativeFunctions::npu_gru(
+    const at::Tensor& input,
+    const at::Tensor& hx,
+    const at::Tensor& weight_input,
+    const at::Tensor& weight_hidden,
+    const at::Tensor& bias_input,
+    const at::Tensor& bias_hidden,
+    const at::Tensor& seq_length,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  return NPUGruFunction::apply(input, hx, weight_input, weight_hidden,
+      bias_input, bias_hidden, seq_length, has_biases, num_layers, dropout, train, bidirectional, batch_first);
+}
+
+tuple<at::Tensor, at::Tensor> gru_single_layer_bidirec_npu(
+    const at::Tensor& input,
+    pair_of& hx,
+    BidirectCellParams params,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  at::Tensor fw_weight_input = params.first.w_ih.t();
+  at::Tensor fw_weight_hidden = params.first.w_hh.t();
+  at::Tensor rev_weight_input = params.second.w_ih.t();
+  at::Tensor rev_weight_hidden = params.second.w_hh.t();
+  at::Tensor fw_bias_input;
+  at::Tensor fw_bias_hidden;
+  at::Tensor rev_bias_input;
+  at::Tensor rev_bias_hidden;
+  if (has_biases) {
+    fw_bias_input = params.first.b_ih.to(input.dtype());
+    fw_bias_hidden = params.first.b_hh.to(input.dtype());
+    rev_bias_input = params.second.b_ih.to(input.dtype());
+    rev_bias_hidden = params.second.b_hh.to(input.dtype());
+  } else {
+    fw_bias_input = OpPreparation::ApplyTensorWithFormat(fw_weight_input.size(1), input.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+    fw_bias_hidden = OpPreparation::ApplyTensorWithFormat(fw_weight_hidden.size(1), input.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+    rev_bias_input = OpPreparation::ApplyTensorWithFormat(rev_weight_input.size(1), input.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+    rev_bias_hidden = OpPreparation::ApplyTensorWithFormat(rev_weight_hidden.size(1), input.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+  }
+  at::Tensor seq_length = OpPreparation::ApplyTensorWithFormat({}, input.options(), ACL_FORMAT_ND);
+  auto results = NPUNativeFunctions::npu_gru(
+      input,
+      hx.first,
+      fw_weight_input,
+      fw_weight_hidden,
+      fw_bias_input,
+      fw_bias_hidden,
+      seq_length,
+      has_biases,
+      num_layers,
+      dropout,
+      train,
+      bidirectional,
+      batch_first);
+  int64_t numStep = input.size(0);
+  at::Tensor fw_output_hy = at::unsqueeze(results[1][numStep - 1], 0);
+  at::Tensor fw_output = results[0];
+  auto rev_inputs = at::flip(input, {0}); // reverse input
+  auto rev_results = NPUNativeFunctions::npu_gru(
+      rev_inputs,
+      hx.second,
+      rev_weight_input,
+      rev_weight_hidden,
+      rev_bias_input,
+      rev_bias_hidden,
+      seq_length,
+      has_biases,
+      num_layers,
+      dropout,
+      train,
+      bidirectional,
+      batch_first);
+  at::Tensor rev_output_hy = at::unsqueeze(rev_results[1][numStep - 1], 0);
+  at::Tensor rev_output = at::flip(rev_results[0], {0});
+  return std::make_tuple(at::cat({fw_output, rev_output}, -1),
+                         at::cat({fw_output_hy, rev_output_hy}));
+}
+
+tuple<at::Tensor, at::Tensor> gru_single_layer_direc_npu(
+    const at::Tensor& input,
+    const at::Tensor& hx,
+    CellParams params,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  // get weight (fp16)
+  at::Tensor weight_input = params.w_ih.t();
+  at::Tensor weight_hidden = params.w_hh.t();
+
+  // get bias (fp16 / fp32)
+  at::Tensor bias_input;
+  at::Tensor bias_hidden;
+  if (has_biases) {
+    bias_input = params.b_ih.to(input.dtype());
+    bias_hidden = params.b_hh.to(input.dtype());
+  } else {
+    bias_input = OpPreparation::ApplyTensorWithFormat(weight_input.size(1), input.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+    bias_hidden = OpPreparation::ApplyTensorWithFormat(weight_hidden.size(1), input.options(), ACL_FORMAT_FRACTAL_NZ).mul(0);
+  }
+
+  at::Tensor seq_length = OpPreparation::ApplyTensorWithFormat({}, input.options(), ACL_FORMAT_ND);
+
+  auto results = NPUNativeFunctions::npu_gru(
+      input,
+      hx,
+      weight_input,
+      weight_hidden,
+      bias_input,
+      bias_hidden,
+      seq_length,
+      has_biases,
+      num_layers,
+      dropout,
+      train,
+      bidirectional,
+      batch_first);
+  int64_t numStep = input.size(0);
+  at::Tensor output_hy = at::unsqueeze(results[1][numStep - 1], 0);
+  return std::tuple<at::Tensor, at::Tensor>(results[0], output_hy);
+}
+
+tuple<at::Tensor, at::Tensor> apply_layer_stack(
+    const at::Tensor& input,
+    std::vector<pair_of> hx,
+    std::vector<pair_of> params,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  auto layer_input = input;
+  auto hidden_it = hx.begin();
+  auto params_size = params.size();
+
+  std::vector<BidirectCellParams> weights;
+  std::vector<pair_of>::iterator params_it = params.begin();
+  if (has_biases) {
+    for (int64_t i = 0; i < params_size; i = i + 4) {
+      weights.emplace_back(CellParams((*(params_it + i)).first, (*(params_it + i)).second,
+                                      (*(params_it + i + 1)).first, (*(params_it + i + 1)).second),
+                           CellParams((*(params_it + i + 2)).first, (*(params_it + i + 2)).second,
+                                      (*(params_it + i + 3)).first, (*(params_it + i + 3)).second));
+    }
+  } else {
+    for (int64_t i = 0; i < params_size; i = i + 2) {
+      weights.emplace_back(CellParams((*(params_it + i)).first, (*(params_it + i)).second),
+                           CellParams((*(params_it + i + 1)).first, (*(params_it + i + 1)).second));
+    }
+  }
+  auto weights_it = weights.begin();
+  std::vector<at::Tensor> final_hiddens;
+  for (int64_t l = 0; l < num_layers; ++l) {
+    auto layer_output = gru_single_layer_bidirec_npu(
+        layer_input,
+        *(hidden_it++),
+        *(weights_it++),
+        has_biases,
+        num_layers,
+        dropout,
+        train,
+        bidirectional,
+        batch_first);
+    final_hiddens.push_back(std::move(std::get<1>(layer_output)));
+    layer_input = std::get<0>(layer_output);
+  }
+  return std::make_tuple(layer_input, at::cat(final_hiddens, 0));
+}
+
+tuple<at::Tensor, at::Tensor> apply_layer_stack(
+    const at::Tensor& input,
+    std::vector<at::Tensor>& hx,
+    std::vector<at::Tensor>& params,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  auto layer_input = input;
+  auto hidden_it = hx.begin();
+
+  auto params_size = params.size();
+  std::vector<CellParams> weights;
+  std::vector<at::Tensor>::iterator params_it = params.begin();
+  if (has_biases) {
+    for (int64_t i = 0; i < params_size; i = i + 4) {
+      weights.emplace_back(CellParams(*(params_it + i), *(params_it + i + 1),
+                                      *(params_it + i + 2), *(params_it + i + 3)));
+    }
+  } else {
+    for (int64_t i = 0; i < params_size; i = i + 2) {
+      weights.emplace_back(CellParams(*(params_it + i), *(params_it + i + 1)));
+    }
+  }
+  auto weights_it = weights.begin();
+  std::vector<at::Tensor> final_hiddens;
+
+  for (int64_t l = 0; l < num_layers; ++l) {
+    auto layer_output = gru_single_layer_direc_npu(
+        layer_input,
+        *(hidden_it++),
+        *(weights_it++),
+        has_biases,
+        num_layers,
+        dropout,
+        train,
+        bidirectional,
+        batch_first);
+    final_hiddens.push_back(std::move(std::get<1>(layer_output)));
+    layer_input = std::get<0>(layer_output);
+  }
+  auto hidden_state = at::cat(final_hiddens, 0);
+  return std::make_tuple(layer_input, hidden_state);
+}
+
+tuple<at::Tensor, at::Tensor> NPUNativeFunctions::gru(
+    const at::Tensor& input_,
+    const at::Tensor& hx,
+    at::TensorList params,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout,
+    bool train,
+    bool bidirectional,
+    bool batch_first) {
+  // The DynamicGRUV2 operator only supports the time axis (T) as the first axis.
+  auto input = batch_first ? input_.transpose(0, 1) : input_;
+
+  auto layer_hx = hx.unbind(0);
+  int64_t total_layers = layer_hx.size();
+  std::vector<at::Tensor> hiddens;
+  for (int64_t i = 0; i < total_layers; ++i) {
+    hiddens.emplace_back(std::move(layer_hx[i]));
+  }
+  std::vector<at::Tensor> paramsVec;
+  for (int64_t i = 0; i < params.size(); ++i) {
+    paramsVec.emplace_back(std::move(params[i]));
+  }
+  tuple<at::Tensor, at::Tensor> result;
+  if (bidirectional) {
+    result = apply_layer_stack(
+        input,
+        make_pair_vec(hiddens),
+        make_pair_vec(paramsVec),
+        has_biases,
+        num_layers,
+        dropout,
+        train,
+        bidirectional,
+        batch_first);
+  } else {
+    result = apply_layer_stack(
+        input,
+        hiddens,
+        paramsVec,
+        has_biases,
+        num_layers,
+        dropout,
+        train,
+        bidirectional,
+        batch_first);
+  }
+  std::get<0>(result) = batch_first ? std::get<0>(result).transpose(0, 1) : std::get<0>(result);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
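
---
Usage sketch (not part of the patch; shapes chosen arbitrarily): with this change, a stock torch.nn.GRU moved to the NPU should route through NPUNativeFunctions::gru above, so the forward pass runs DynamicGRUV2 and the backward pass runs DynamicGRUV2Grad, as the tests in this patch exercise:

    import torch
    import torch_npu  # registers the NPU backend and the kernels added above

    gru = torch.nn.GRU(input_size=4, hidden_size=8, num_layers=1).npu()
    x = torch.randn(5, 3, 4).npu()   # (seq_len, batch, input_size); batch_first=False
    h0 = torch.randn(1, 3, 8).npu()  # (num_layers * num_directions, batch, hidden_size)
    y, hn = gru(x, h0)               # forward: DynamicGRUV2
    y.backward(torch.ones_like(y))   # backward: DynamicGRUV2Grad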