diff --git a/test/test_network_ops/test_min_v1.py b/test/test_network_ops/test_min_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..b27e3652eacc82433332bd8aff09c5c03c5c1ced --- /dev/null +++ b/test/test_network_ops/test_min_v1.py @@ -0,0 +1,38 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestMinV1(TestCase): + def cpu_op_exec(self, data, dim): + outputs, indices = torch.min(data, dim) + return outputs.detach() + + def npu_op_exec(self, data, dim): + data = data.to("npu") + outputs, indices = torch_npu.npu_min(data, dim) + return outputs.detach().cpu() + + def test_min_v1_fp32(self, device="npu"): + data = torch.randn(2, 2, 2, 2, dtype = torch.float32) + npu_data = data.clone() + cpu_out = self.cpu_op_exec(data, 2) + npu_out = self.npu_op_exec(npu_data, 2) + self.assertRtolEqual(cpu_out, npu_out) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_min_v1_backward.py b/test/test_network_ops/test_min_v1_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..be24e14f0520b14daa208c42da619d74f5efb885 --- /dev/null +++ b/test/test_network_ops/test_min_v1_backward.py @@ -0,0 +1,46 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests + +class TestMinV1Backward(TestCase): + def op_exec(self, npu_flag, data, dim): + if npu_flag: + data = data.to("npu") + data.requires_grad = True + if npu_flag: + outputs, indices = torch_npu.npu_min(data, dim) + else: + outputs, indices = torch.min(data, dim) + outputs.backward(torch.ones_like(outputs)) + gradoutput = data.grad + out = outputs.detach() + if npu_flag: + out = out.cpu() + gradoutput = gradoutput.cpu() + return out, gradoutput + + def test_min_v1_backward_fp32(self, device="npu"): + data = torch.randn(2, 2, 2, 2, dtype = torch.float32) + npu_data = data.clone() + cpu_out, cpu_grad_out = self.op_exec(0, data, 2) + npu_out, npu_grad_out = self.op_exec(1, npu_data, 2) + self.assertRtolEqual(cpu_grad_out, npu_grad_out) + self.assertRtolEqual(cpu_out, npu_out) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_mish.py b/test/test_network_ops/test_mish.py new file mode 100644 index 0000000000000000000000000000000000000000..e52d5d46f71450ff544865e0024e5049faa35dd3 --- /dev/null +++ b/test/test_network_ops/test_mish.py @@ -0,0 +1,59 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import torch.nn.functional as F +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestMish(TestCase): + def npu_op_exec(self, input1): + output = torch_npu.npu_mish(input1) + output = output.cpu().numpy() + return output + + def cpu_op_exec(self, input1): + output = input1 * (torch.tanh(F.softplus(input1))) + output = output.numpy() + return output + + def test_mish_fp32(self, device="npu"): + shape_format = [ + [[np.float32, -1, [10,30,10]]], + [[np.float32, -1, [20,30,20]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + + def test_mish_fp16(self, device="npu"): + shape_format = [ + [[np.float16, -1, [10,30,10]]], + [[np.float16, -1, [20,30,20]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 0, 100) + cpu_output = self.cpu_op_exec(cpu_input.float()).astype(np.float16) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_mish_backward.py b/test/test_network_ops/test_mish_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..e0b4081ee93e10d8353c276a80b8a6beb39908e0 --- /dev/null +++ b/test/test_network_ops/test_mish_backward.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import torch.nn.functional as F +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestMishBackward(TestCase): + def npu_op_exec(self, input1): + input1.requires_grad = True + output = torch.npu_mish(input1) + output.backward(torch.ones_like(output)) + output_grad = input1.grad + output_grad = output_grad.to("cpu") + output_grad = output_grad.detach().numpy() + output = output.cpu().detach().numpy() + return output_grad, output + + def cpu_op_exec(self, input1): + input1.requires_grad = True + output = input1 * (torch.tanh(F.softplus(input1))) + output.backward(torch.ones_like(output)) + output_grad = input1.grad + output_grad = output_grad.to("cpu") + output_grad = output_grad.detach().numpy() + output = output.detach().numpy() + return output_grad, output + + def test_mish_fp32(self, device="npu"): + npu_input = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]).npu() + cpu_input = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) + output_grad, npu_output = self.npu_op_exec(npu_input) + ep_output_grad, ep_npu_output = self.cpu_op_exec(cpu_input) + self.assertRtolEqual(ep_output_grad, output_grad) + self.assertRtolEqual(ep_npu_output, npu_output) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_reshape.py b/test/test_network_ops/test_reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..f600ddad26dc56cf673b35fe241de91f8cae640c --- /dev/null +++ b/test/test_network_ops/test_reshape.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestReshape(TestCase): + def cpu_op_exec(self, input1, shape): + output = input1.reshape(shape) + output = output.numpy() + return output + + def npu_op_exec(self, input1, shape): + output = input1.reshape(shape) + output = output.to("cpu") + output = output.numpy() + return output + + def test_reshape_shape_format(self, device="npu"): + dtype_list = [np.float16, np.float32, np.int32, np.bool_] + format_list = [0] + shape_list = [[8,8], [2,4,8], [2,4,4,2]] + shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + shape = [4,16] + cpu_output = self.cpu_op_exec(cpu_input, shape) + npu_output = self.npu_op_exec(npu_input, shape) + self.assertRtolEqual(cpu_output, npu_output) + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 4ca5faea1b49bb82f9527cc67d66f78b8bfc4086..4aefe54d9e62ab2f7bceb27929abea654ccf79b2 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1928,6 +1928,10 @@ custom: - func: npu_silu_backward(Tensor grad_output, Tensor x0, Tensor x1) -> Tensor - func: npu_rotated_iou(Tensor self, Tensor query_boxes, bool trans=False, int mode=0, bool is_cross=True, float v_threshold=0.0, float e_threshold=0.0) -> Tensor - func: npu_gru_backward(Tensor? grady, Tensor? gradh, Tensor input, Tensor weight_input, Tensor weight_hidden, Tensor bias_input, Tensor bias_hidden, Tensor seq_length, Tensor hx, Tensor y_output, Tensor h_output, Tensor output_updata, Tensor output_reset, Tensor output_new, Tensor hidden_new) -> Tensor[] + - func: npu_mish_backward(Tensor grad, Tensor input) -> Tensor + - func: npu_min_backward(Tensor grad, int dim, Tensor indices, int[] sizes, bool keepdim=False) -> Tensor + - func: npu_reshape(Tensor self, int[] shape, bool can_refresh=False) -> Tensor + - func: npu_reshape.out(Tensor self, int[] shape, bool can_refresh=False, *, Tensor(a!) out) -> Tensor(a!) custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor @@ -1950,4 +1954,8 @@ custom_autograd: variants: function, method - func: npu_silu(Tensor self) -> Tensor - func: npu_silu_(Tensor(a!) self) -> Tensor(a!) - - func: npu_gru(Tensor input, Tensor hx, Tensor weight_input, Tensor weight_hidden, Tensor bias_input, Tensor bias_hidden, Tensor seq_length, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> Tensor[] \ No newline at end of file + - func: npu_gru(Tensor input, Tensor hx, Tensor weight_input, Tensor weight_hidden, Tensor bias_input, Tensor bias_hidden, Tensor seq_length, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> Tensor[] + - func: npu_mish(Tensor self) -> Tensor + variants: function, method + - func: npu_min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + - func: npu_min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp b/torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ef8d59fc2febe53aa89954699db33d03d6d96ba --- /dev/null +++ b/torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using torch::autograd::Function; +using torch::autograd::AutogradContext; +using tensor_list = std::vector; + +tuple min_v1_out_npu_nocheck( + at::Tensor& output, + at::Tensor& indices, + const at::Tensor& self, + int64_t dim, + bool keepdim) { + OpCommand cmd; + cmd.Name("ArgMinWithValue") + .Input(self) + .Output(indices) + .Output(output) + .Attr("dimension", dim) + .Attr("keep_dims", keepdim) + .Run(); + + return std::tie(output, indices); +} + +tuple min_v1_npu(const at::Tensor& self, int64_t dim, bool keepdim) { + c10::SmallVector dims = {dim}; + c10::SmallVector outputSize = + reduce_ops_npu_output_size(self, dims, keepdim); + c10::SmallVector indicesSize = + reduce_ops_npu_output_size(self, dims, keepdim); + + int64_t npuFormat = CalcuOpUtil::get_tensor_npu_format(self); + if (outputSize.empty()) { + npuFormat = ACL_FORMAT_NCHW; + } + + at::Tensor outputs = OpPreparation::ApplyTensorWithFormat(outputSize, self.options(), npuFormat); + at::Tensor indices = OpPreparation::ApplyTensorWithFormat(indicesSize, self.options().dtype(at::kInt), npuFormat); + + min_v1_out_npu_nocheck(outputs, indices, self, dim, keepdim); + return std::tie(outputs, indices); +} + +tuple NPUNativeFunctions::npu_min(const at::Tensor& self, at::Dimname dim, bool keepdim) { + return min_v1_npu(self, dimname_to_position(self, dim), keepdim); +} + +at::Tensor NPUNativeFunctions::npu_min_backward( + const at::Tensor& grad, + int64_t dim, + const at::Tensor& indices, + at::IntArrayRef sizes, + bool keepdim) { + at::Tensor newGrad = grad; + at::Tensor newIndices = indices; + if (keepdim && sizes.size() > 0) { + newGrad = grad.squeeze(dim); + newIndices = indices.squeeze(dim); + } + auto gradInput = NPUNativeFunctions::npu_scatter( + at::native::zeros(sizes, newGrad.options()), newIndices, newGrad, dim); + return gradInput; +} + +class NPUMinFunction : public torch::autograd::Function { +public: + static tensor_list forward(AutogradContext *ctx, + const at::Tensor& self, + int64_t dim, + bool keepdim) { + ctx->saved_data["dim"] = dim; + ctx->saved_data["keepdim"] = keepdim; + ctx->saved_data["size"] = self.sizes(); + + auto result = min_v1_npu(self, dim, keepdim); + auto result1 = std::get<1>(result); + ctx->saved_data["indices"] = result1; + at::AutoNonVariableTypeMode g; + tensor_list result_list = {std::get<0>(result), result1}; + return result_list; + } + + static tensor_list backward(AutogradContext *ctx, + tensor_list grad_outputs) { + auto dim = ctx->saved_data["dim"].toInt(); + auto keepdim = ctx->saved_data["keepdim"].toBool(); + auto size = ctx->saved_data["size"].toIntVector(); + auto indices = ctx->saved_data["indices"].toTensor(); + + at::Tensor result = NPUNativeFunctions::npu_min_backward( + grad_outputs[0], dim, indices, size, keepdim); + tensor_list output = {result, at::Tensor(), at::Tensor()}; + return output; + } +}; + +tuple NPUNativeFunctions::npu_min(const at::Tensor& self, int64_t dim, bool keepdim) { + auto result = NPUMinFunction::apply(self, dim, keepdim); + std::tuple output(result[0], result[1]); + return output; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/MishKernelNpu.cpp b/torch_npu/csrc/aten/ops/MishKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..76a9b22421f7f172fba0b36fc078a1eb85a4e7bd --- /dev/null +++ b/torch_npu/csrc/aten/ops/MishKernelNpu.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using torch::autograd::Function; +using torch::autograd::AutogradContext; +using tensor_list = std::vector; + +at::Tensor mish_npu(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + OpCommand cmd; + cmd.Name("Mish") + .Input(self) + .Output(result) + .Run(); + + return result; +} + +at::Tensor NPUNativeFunctions::npu_mish_backward(const at::Tensor& grad, const at::Tensor& input) { + at::Tensor result = OpPreparation::ApplyTensor(input); + OpCommand cmd; + cmd.Name("MishGrad") + .Input(grad) + .Input(input) + .Output(result) + .Run(); + + return result; +} + +class NPUMishFunction : public torch::autograd::Function { +public: + static at::Tensor forward(AutogradContext *ctx, + const at::Tensor& self) { + at::AutoNonVariableTypeMode g; + ctx->save_for_backward({self}); + return mish_npu(self); + } + + static tensor_list backward(AutogradContext *ctx, + tensor_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + + at::Tensor result = NPUNativeFunctions::npu_mish_backward(grad_outputs[0], input); + tensor_list output = {result}; + return output; + } +}; + +at::Tensor NPUNativeFunctions::npu_mish(const at::Tensor& self) { + return NPUMishFunction::apply(self); +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/ReshapeKernelNpu.cpp b/torch_npu/csrc/aten/ops/ReshapeKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1588635c4c712753d36c83fb9eafab00d680d426 --- /dev/null +++ b/torch_npu/csrc/aten/ops/ReshapeKernelNpu.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2021 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/StorageDescHelper.h" +#include "torch_npu/csrc/aten/common/InnerNpuNativeFunction.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& reshape_out_nocheck( + at::Tensor& result, + const at::Tensor& self, + at::IntArrayRef shape, + bool can_refresh) { + if (can_refresh) { + StorageDescHelper::SetDesc( + result, + array_to_small_vector(result.sizes()), + array_to_small_vector(result.strides())); + } else { + copy_d2d_by_memcpy( + result, + self, + at::prod_intlist(result.storage().get_npu_desc().storage_sizes_)); + } + return result; +} + +at::Tensor& NPUNativeFunctions::npu_reshape_out( + const at::Tensor& self, + at::IntArrayRef shape, + bool can_refresh, + at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self, + shape); + return reshape_out_nocheck(result, self, shape, can_refresh); +} + +at::Tensor NPUNativeFunctions::npu_reshape(const at::Tensor& self, at::IntArrayRef shape, bool can_refresh) { + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + shape, self.options(), CalcuOpUtil::get_tensor_npu_format(self)); + reshape_out_nocheck(result, self, shape, can_refresh); + + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file