diff --git a/test/test_network_ops/test_scatterv1.py b/test/test_network_ops/test_scatterv1.py new file mode 100644 index 0000000000000000000000000000000000000000..38d66bbbd3956fbb70950f99f2a82b24ab12f603 --- /dev/null +++ b/test/test_network_ops/test_scatterv1.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestScatterV1(TestCase): + def npu_op_exec(self, input1, indices, updates, dim): + output = torch_npu.npu_scatter(input1, indices, updates, dim) + output = output.to("cpu") + output = output.numpy() + return output + + def test_scatterv1(self, device): + input1_list = [[[1.6279, 0.1226], [0.9041, 1.0980]]] + indices_list = [[0, 1]] + updates_list = [[-1.1993, -1.5247]] + dim_list = [0] + exoutput_list = [[[-1.1993, 0.1226], [0.9041, -1.5247]]] + + shape_format = [[i, j, k, h, f] for i in input1_list + for j in indices_list for k in updates_list for h in dim_list for f in exoutput_list] + + for item in shape_format: + input1_tensor = torch.tensor(item[0]).npu() + indices_tensor = torch.tensor(item[1]).npu().to(torch.int32) + updates_tensor = torch.tensor(item[2]).npu() + dim = item[3] + exoutput_tensor = torch.tensor(item[4]) + output = self.npu_op_exec(input1_tensor, indices_tensor, updates_tensor, dim) + self.assertRtolEqual(exoutput_tensor.numpy(), output) + + +instantiate_device_type_tests(TestScatterV1, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_sub_sample.py b/test/test_network_ops/test_sub_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d95510ee74f69b21ec9db5a4cfcdecf086fded --- /dev/null +++ b/test/test_network_ops/test_sub_sample.py @@ -0,0 +1,77 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestSubSample(TestCase): + def get_num(self, input1, output): + input_num1 = 0 + input_num0 = 0 + output_num1 = 0 + output_num0 = 0 + for i in range(input1.size()[0]): + if input1[i] == 1: + input_num1 = input_num1 + 1 + if input1[i] == 0: + input_num0 = input_num0 + 1 + for i in range(output.size()[0]): + if output[i] == 1: + output_num1 = output_num1 + 1 + if output[i] == 0: + output_num0 = output_num0 + 1 + list1 = [input_num1, input_num0, output_num1, output_num0] + + return list1 + + def numless_equal(self, input_num1, input_num0, output_num1, output_num0, size, fraction): + error_name = "result error" + if input_num1 < size * fraction: + if output_num1 != input_num1: + self.fail(error_name) + if input_num0 < size - input_num1 and output_num0 != input_num0: + self.fail(error_name) + if input_num0 >= size - input_num1 and output_num0 != size - input_num1: + self.fail(error_name) + + def nummore_equal(self, input_num1, input_num0, output_num1, output_num0, size, fraction): + error_name = "result error" + if input_num1 >= size * fraction: + if output_num1 != size * fraction: + self.fail(error_name) + if input_num0 < size - size * fraction and output_num0 != input_num0: + self.fail(error_name) + if input_num0 >= size - size * fraction and output_num0 != size - size * fraction: + self.fail(error_name) + + def test_subsample(self, device): + for _ in range(20): + input1 = np.random.randint(-1, 2, size = (10)) + npu_input = torch.from_numpy(input1).to("npu") + #input only suport int32 + npu_input = npu_input.to(torch.int32) + npu_output1 = torch_npu.npu_sub_sample(npu_input, 5, 0.6) + getlist = self.get_num(npu_input, npu_output1) + self.numless_equal(getlist[0], getlist[1], getlist[2], getlist[3], 5, 0.6) + self.nummore_equal(getlist[0], getlist[1], getlist[2], getlist[3], 5, 0.6) + + +instantiate_device_type_tests(TestSubSample, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_yolo_boxes_encode.py b/test/test_network_ops/test_yolo_boxes_encode.py new file mode 100644 index 0000000000000000000000000000000000000000..2755be59381cc31b6a534676f6baa9d6a5606b70 --- /dev/null +++ b/test/test_network_ops/test_yolo_boxes_encode.py @@ -0,0 +1,50 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestYoloBoxesEncode(TestCase): + def npu_op_exec(self, anchor_boxes, gt_bboxes, stride, impl_mode=False): + out = torch_npu.npu_yolo_boxes_encode(anchor_boxes, gt_bboxes, stride, impl_mode) + out = out.to("cpu") + return out.detach().numpy() + + def test_yolo_boxes_encode(self, device): + anchor_boxes = [(2, 4)] + gt_bboxes = [(2 ,4)] + stride = [[2, 2]] + exoutput_list = [[[0.7921727, 0.5314963, -0.74224466, -13.815511], + [0.7360072, 0.58343244, 4.3334002, -0.51378196]]] + + shape_format = [[i, j, k, h] for i in anchor_boxes + for j in gt_bboxes for k in stride for h in exoutput_list] + + for item in shape_format: + anchor_boxes_tensor = torch.rand(item[0], dtype=torch.float32).to("npu") + gt_bboxes_tensor = torch.rand(item[1], dtype=torch.float32).to("npu") + stride_tensor = torch.tensor(item[2], dtype=torch.float32).to("npu") + exoutput_cpu_tensor = torch.tensor(item[3], dtype=torch.float32) + npu_output = self.npu_op_exec(anchor_boxes_tensor, gt_bboxes_tensor, stride_tensor, False) + + self.assertRtolEqual(exoutput_cpu_tensor.numpy(), npu_output) + + +instantiate_device_type_tests(TestYoloBoxesEncode, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index c14f2194adbefe83351df768d70dafda4da8732c..b4c04c619f7991a453360e2260fc216d6566ea5a 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1929,6 +1929,9 @@ custom: variants: function, method - func: npu_masked_fill_range(Tensor self, Tensor start, Tensor end, Tensor value, int axis=-1) -> Tensor variants: function, method + - func: npu_sub_sample(Tensor self, int per_images, float positive_fraction) -> Tensor + - func: npu_yolo_boxes_encode(Tensor self, Tensor gt_bboxes, Tensor stride, bool performance_mode=False) -> Tensor + - func: npu_scatter(Tensor self, Tensor indices, Tensor updates, int dim) -> Tensor custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f142a37884e644b5bef99b6a27480daafcf6552c --- /dev/null +++ b/torch_npu/csrc/aten/ops/ScatterV1KernelNpu.cpp @@ -0,0 +1,48 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& scatter_out_npu( + at::Tensor& output, + const at::Tensor& self, + const at::Tensor& indices, + const at::Tensor& updates, + int64_t dim) { + OpCommand cmd; + cmd.Name("ArgMaxGrad") + .Input(self) + .Input(indices) + .Input(updates) + .Output(output) + .Attr("dimension", dim) + .Run(); + + return output; +} + +at::Tensor NPUNativeFunctions::npu_scatter(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim) { + at::Tensor outputs = OpPreparation::ApplyTensor(self); + scatter_out_npu(outputs, self, indices, updates, dim); + + return outputs; +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp b/torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..adf201614d05bbfa8aaed4e56319cdfb9055e58e --- /dev/null +++ b/torch_npu/csrc/aten/ops/SubSampleKernelNpu.cpp @@ -0,0 +1,37 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor NPUNativeFunctions::npu_sub_sample( + const at::Tensor &self, + int64_t per_images, + double positive_fraction) { + + at::Tensor result = OpPreparation::ApplyTensor(self); + OpCommand cmd; + cmd.Name("SubSample") + .Input(self) + .Output(result) + .Attr("batch_size_per_images", per_images) + .Attr("positive_fraction", (float)positive_fraction) + .Run(); + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp b/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..877c0a79c07ecfa413ddabb05c580332f59c034c --- /dev/null +++ b/torch_npu/csrc/aten/ops/YoloBoxesEncodeKernelNpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +static inline void yolo_boxes_encode_check( + const at::Tensor& anchor_boxes, + const at::Tensor& gt_bboxes, + const at::Tensor& stride){ + TORCH_CHECK( + anchor_boxes.dim() == 2 && anchor_boxes.size(1) == 4, + "Non-empty 2D anchor_boxes tensor expected but got a tensor with sizes ", + anchor_boxes.sizes()); + TORCH_CHECK( + anchor_boxes.size(0) <= 20480, + "anchor_boxes only support max [20480] num, but got num ", + anchor_boxes.size(0)); + TORCH_CHECK( + gt_bboxes.dim() == 2 && gt_bboxes.size(1) == 4, + "Non-empty 2D gt_bboxes tensor expected but got a tensor with sizes ", + gt_bboxes.sizes()); + TORCH_CHECK( + stride.dim() == 1, + "Non-empty 1D stride tensor expected but got a tensor with sizes ", + stride.sizes()); + TORCH_CHECK( + stride.size(0) == gt_bboxes.size(0), + "stride's length should be equal gt_bboxes' num, but got stride length ", + stride.size(0), + "gt_bboxes num ", + gt_bboxes.size(0)); + TORCH_CHECK( + at::isIntegralType(stride.scalar_type(), true) && stride.scalar_type() != at::ScalarType::Long, + "int32 strdie tensor expected but got a tensor with dtype: ", + stride.scalar_type()); +} + +at::Tensor NPUNativeFunctions::npu_yolo_boxes_encode( + const at::Tensor& anchor_boxes, + const at::Tensor& gt_bboxes, + const at::Tensor& stride, + bool performance_mode){ + yolo_boxes_encode_check(anchor_boxes, gt_bboxes, stride); + at::Tensor result = OpPreparation::ApplyTensor(gt_bboxes); + string implModeStr = performance_mode ? "high_performance" : "high_precision"; + at::Tensor strideCp = NPUNativeFunctions::npu_dtype_cast(stride, at::ScalarType::Int); + OpCommand cmd; + cmd.Name("YoloBoxesEncode") + .Input(anchor_boxes) + .Input(gt_bboxes) + .Input(strideCp) + .Output(result) + .Attr("performance_mode", implModeStr) + .Run(); + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file