From 419d1420393f776ad0e715ee98ceafe51514c82e Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Fri, 25 Feb 2022 17:13:27 +0800 Subject: [PATCH] =?UTF-8?q?npu=5Fmin=E7=AE=97=E5=AD=901.8.1=E7=A7=BB?= =?UTF-8?q?=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_min_v1.py | 40 ++++++ test/test_network_ops/test_min_v1_backward.py | 48 +++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 5 +- torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp | 126 ++++++++++++++++++ 4 files changed, 218 insertions(+), 1 deletion(-) create mode 100644 test/test_network_ops/test_min_v1.py create mode 100644 test/test_network_ops/test_min_v1_backward.py create mode 100644 torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp diff --git a/test/test_network_ops/test_min_v1.py b/test/test_network_ops/test_min_v1.py new file mode 100644 index 0000000000..42af3c50ec --- /dev/null +++ b/test/test_network_ops/test_min_v1.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +class TestMinV1(TestCase): + def cpu_op_exec(self, data, dim): + outputs, indices = torch.min(data, dim) + return outputs.detach() + + def npu_op_exec(self, data, dim): + data = data.to("npu") + outputs, indices = torch_npu.npu_min(data, dim) + return outputs.detach().cpu() + + def test_min_v1_fp32(self, device): + data = torch.randn(2, 2, 2, 2, dtype = torch.float32) + npu_data = data.clone() + cpu_out = self.cpu_op_exec(data, 2) + npu_out = self.npu_op_exec(npu_data, 2) + self.assertRtolEqual(cpu_out, npu_out) + +instantiate_device_type_tests(TestMinV1, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_min_v1_backward.py b/test/test_network_ops/test_min_v1_backward.py new file mode 100644 index 0000000000..08bbaa439f --- /dev/null +++ b/test/test_network_ops/test_min_v1_backward.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +class TestMinV1Backward(TestCase): + def op_exec(self, npu_flag, data, dim): + if npu_flag: + data = data.to("npu") + data.requires_grad = True + if npu_flag: + outputs, indices = torch_npu.npu_min(data, dim) + else: + outputs, indices = torch.min(data, dim) + outputs.backward(torch.ones_like(outputs)) + gradoutput = data.grad + out = outputs.detach() + if npu_flag: + out = out.cpu() + gradoutput = gradoutput.cpu() + return out, gradoutput + + def test_min_v1_backward_fp32(self, device): + data = torch.randn(2, 2, 2, 2, dtype = torch.float32) + npu_data = data.clone() + cpu_out, cpu_grad_out = self.op_exec(0, data, 2) + npu_out, npu_grad_out = self.op_exec(1, npu_data, 2) + self.assertRtolEqual(cpu_grad_out, npu_grad_out) + self.assertRtolEqual(cpu_out, npu_out) + +instantiate_device_type_tests(TestMinV1Backward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index b4c04c619f..266e91179f 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1932,6 +1932,7 @@ custom: - func: npu_sub_sample(Tensor self, int per_images, float positive_fraction) -> Tensor - func: npu_yolo_boxes_encode(Tensor self, Tensor gt_bboxes, Tensor stride, bool performance_mode=False) -> Tensor - func: npu_scatter(Tensor self, Tensor indices, Tensor updates, int dim) -> Tensor + - func: npu_min_backward(Tensor grad, int dim, Tensor indices, int[] sizes, bool keepdim=False) -> Tensor custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor @@ -1943,4 +1944,6 @@ custom_autograd: - func: _npu_dropout(Tensor self, float p) -> (Tensor, Tensor) - func: npu_ifmr(Tensor data, Tensor data_min, Tensor data_max, Tensor cumsum, float min_percentile, float max_percentile, float search_start, float search_end, float search_step, bool with_offset) -> (Tensor, Tensor) - func: npu_grid_assign_positive(Tensor self, Tensor overlaps, Tensor box_responsible_flags, Tensor max_overlaps, Tensor argmax_overlaps, Tensor gt_max_overlaps, Tensor gt_argmax_overlaps, int num_gts, float pos_iou_thr, float min_pos_iou, bool gt_max_assign_all) -> Tensor - - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor \ No newline at end of file + - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor + - func: npu_min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + - func: npu_min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp b/torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp new file mode 100644 index 0000000000..9ef8d59fc2 --- /dev/null +++ b/torch_npu/csrc/aten/ops/MinV1KernelNpu.cpp @@ -0,0 +1,126 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using torch::autograd::Function; +using torch::autograd::AutogradContext; +using tensor_list = std::vector; + +tuple min_v1_out_npu_nocheck( + at::Tensor& output, + at::Tensor& indices, + const at::Tensor& self, + int64_t dim, + bool keepdim) { + OpCommand cmd; + cmd.Name("ArgMinWithValue") + .Input(self) + .Output(indices) + .Output(output) + .Attr("dimension", dim) + .Attr("keep_dims", keepdim) + .Run(); + + return std::tie(output, indices); +} + +tuple min_v1_npu(const at::Tensor& self, int64_t dim, bool keepdim) { + c10::SmallVector dims = {dim}; + c10::SmallVector outputSize = + reduce_ops_npu_output_size(self, dims, keepdim); + c10::SmallVector indicesSize = + reduce_ops_npu_output_size(self, dims, keepdim); + + int64_t npuFormat = CalcuOpUtil::get_tensor_npu_format(self); + if (outputSize.empty()) { + npuFormat = ACL_FORMAT_NCHW; + } + + at::Tensor outputs = OpPreparation::ApplyTensorWithFormat(outputSize, self.options(), npuFormat); + at::Tensor indices = OpPreparation::ApplyTensorWithFormat(indicesSize, self.options().dtype(at::kInt), npuFormat); + + min_v1_out_npu_nocheck(outputs, indices, self, dim, keepdim); + return std::tie(outputs, indices); +} + +tuple NPUNativeFunctions::npu_min(const at::Tensor& self, at::Dimname dim, bool keepdim) { + return min_v1_npu(self, dimname_to_position(self, dim), keepdim); +} + +at::Tensor NPUNativeFunctions::npu_min_backward( + const at::Tensor& grad, + int64_t dim, + const at::Tensor& indices, + at::IntArrayRef sizes, + bool keepdim) { + at::Tensor newGrad = grad; + at::Tensor newIndices = indices; + if (keepdim && sizes.size() > 0) { + newGrad = grad.squeeze(dim); + newIndices = indices.squeeze(dim); + } + auto gradInput = NPUNativeFunctions::npu_scatter( + at::native::zeros(sizes, newGrad.options()), newIndices, newGrad, dim); + return gradInput; +} + +class NPUMinFunction : public torch::autograd::Function { +public: + static tensor_list forward(AutogradContext *ctx, + const at::Tensor& self, + int64_t dim, + bool keepdim) { + ctx->saved_data["dim"] = dim; + ctx->saved_data["keepdim"] = keepdim; + ctx->saved_data["size"] = self.sizes(); + + auto result = min_v1_npu(self, dim, keepdim); + auto result1 = std::get<1>(result); + ctx->saved_data["indices"] = result1; + at::AutoNonVariableTypeMode g; + tensor_list result_list = {std::get<0>(result), result1}; + return result_list; + } + + static tensor_list backward(AutogradContext *ctx, + tensor_list grad_outputs) { + auto dim = ctx->saved_data["dim"].toInt(); + auto keepdim = ctx->saved_data["keepdim"].toBool(); + auto size = ctx->saved_data["size"].toIntVector(); + auto indices = ctx->saved_data["indices"].toTensor(); + + at::Tensor result = NPUNativeFunctions::npu_min_backward( + grad_outputs[0], dim, indices, size, keepdim); + tensor_list output = {result, at::Tensor(), at::Tensor()}; + return output; + } +}; + +tuple NPUNativeFunctions::npu_min(const at::Tensor& self, int64_t dim, bool keepdim) { + auto result = NPUMinFunction::apply(self, dim, keepdim); + std::tuple output(result[0], result[1]); + return output; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee