diff --git a/test/test_network_ops/test__ilshift__.py b/test/test_network_ops/test__ilshift__.py new file mode 100644 index 0000000000000000000000000000000000000000..debdd59f0ba5766997ec0908b1b682adfad47ed8 --- /dev/null +++ b/test/test_network_ops/test__ilshift__.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestiLshift(TestCase): + def cpu_op_exec(self, input1, input2): + input1.__ilshift__(input2) + output = input1.numpy() + return output + + def npu_op_exec(self, input1, input2): + input1.__ilshift__(input2) + output = input1.to("cpu") + output = output.numpy() + return output + + def test_ilshift_tensor(self, device): + format_list = [0] + shape_list = [(256, 32, 56)] + shape_format = [ + [np.int32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input2 = torch.tensor([1]).to(torch.int32) + npu_input2 = cpu_input2.npu() + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_ilshift_scalar(self, device): + format_list = [0] + shape_list = [(256, 32, 56)] + shape_format = [ + [np.int32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input2 = torch.tensor(1).to(torch.int32) + npu_input2 = cpu_input2.npu() + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestiLshift, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_grid_assign_positive.py b/test/test_network_ops/test_grid_assign_positive.py new file mode 100644 index 0000000000000000000000000000000000000000..bacd6558a264e908f4e00d22882c606ecdbd62a2 --- /dev/null +++ b/test/test_network_ops/test_grid_assign_positive.py @@ -0,0 +1,50 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests + +class TestGridAssignPositive(TestCase): + def npu_op_exec(self, *args): + out = torch_npu.npu_grid_assign_positive(*args) + out = out.to("cpu") + return out.detach().numpy() + + def test_grid_assign_positive(self, device): + assigned_gt_inds = torch.rand((4,), dtype=torch.float32).to("npu") + overlaps = torch.rand((2,4), dtype=torch.float32).to("npu") + box_responsible_flags = torch.tensor([1,1,1,0], dtype=torch.uint8).to("npu") + max_overlap = torch.rand((4,), dtype=torch.float32).to("npu") + argmax_overlap = torch.tensor([1,0,1,0], dtype=torch.int32).to("npu") + gt_max_overlaps = torch.rand((2,), dtype=torch.float32).to("npu") + gt_argmax_overlaps = torch.tensor([1,0],dtype=torch.int32).to("npu") + inputs = [assigned_gt_inds,overlaps,box_responsible_flags,max_overlap, + argmax_overlap,gt_max_overlaps,gt_argmax_overlaps] + num_gts = 128 + pos_iou_thr = .5 + min_pos_iou = .0 + gt_max_assign_all = True + attrs = [num_gts, pos_iou_thr, min_pos_iou, gt_max_assign_all] + + params = inputs + attrs + expect_cpu = torch.tensor([2., 1., 0.25984418, 0.36664134], dtype=torch.float32) + npu_output = self.npu_op_exec(*params) + self.assertRtolEqual(expect_cpu.numpy(), npu_output) + +instantiate_device_type_tests(TestGridAssignPositive, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_ifmr.py b/test/test_network_ops/test_ifmr.py new file mode 100644 index 0000000000000000000000000000000000000000..9b73f45a5485546deabd5419b82a04015929224e --- /dev/null +++ b/test/test_network_ops/test_ifmr.py @@ -0,0 +1,141 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import reduce + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestIFMR(TestCase): + def cpu_op_exec(self, + input_data, + with_offset, + bins_num=128, + min_percentile=0.999999, + max_percentile=0.999999, + search_range1=0.7, + search_range2=1.3, + search_step=0.01): + pre_mode = np.float32 + input_data = input_data.numpy().astype(pre_mode) + data_min = input_data.min() + data_max = input_data.max() + data_num = reduce(lambda x, y: x * y, input_data.shape) + data_num = np.array(data_num, pre_mode) + + bins, threshold = np.histogram(input_data, bins_num) + cumsum = np.cumsum(bins).astype(np.int32) + + bins_num = np.array(bins_num, pre_mode) + cdf = cumsum.astype(pre_mode) / data_num + max_index = np.where(cdf >= np.array(max_percentile, pre_mode), 0, + 1).sum().astype(pre_mode) + min_index = np.where(cdf >= np.array(1 - min_percentile, pre_mode), 0, + 1).sum().astype(pre_mode) + max_init = max_index / bins_num * (data_max - data_min) + data_min + min_init = min_index / bins_num * (data_max - data_min) + data_min + + step = np.arange(search_range1, + search_range2, + search_step, + dtype=pre_mode) + if with_offset: + if max_init < 0: + max_init = np.array(0, pre_mode) + if min_init > 0: + min_init = np.array(0, pre_mode) + min_list = min_init * np.ones(step.shape, dtype=pre_mode) + else: + max_init = np.max([np.abs(max_init), np.abs(min_init)]) + max_list = max_init * step + + if with_offset: + scale = (max_list - min_list) / 255 + scale = np.where(scale < 1.192092896e-07, 1, scale) + offset = np.round(min_list / scale) + offset = -(offset + 128) + else: + scale = max_list / 127 + offset = np.round(scale * 0) + + loss_list = np.zeros(step.shape, dtype=pre_mode) + for i in range(step.size): + quant_data = np.round(input_data / scale[i]) + offset[i] + np.clip(quant_data, -128, 127, out=quant_data) + quant_data = (quant_data - offset[i]) * scale[i] + loss_list[i] = np.sum(np.square(quant_data - input_data)) + index = np.argmin(loss_list) + return scale[index], offset[index] + + def npu_op_exec(self, input_data, with_offset): + min_value = torch.min(input_data) + max_value = torch.max(input_data) + min_value = torch.reshape(min_value, (1, )) + max_value = torch.reshape(max_value, (1, )) + hist = torch.histc(input_data.to('cpu'), + bins=128, + min=min_value[0].to('cpu'), + max=max_value[0].to('cpu')) + cdf = torch.cumsum(hist, dim=0).int() + + cdf = cdf.to('npu') + scale, offset = torch_npu.npu_ifmr(input_data, + min_value, + max_value, + cdf, + min_percentile=0.999999, + max_percentile=0.999999, + search_start=0.7, + search_end=1.3, + search_step=0.01, + with_offset=with_offset) + + return scale, offset + + def test_ifrm_with_offset(self, device): + format_list = [0, 3] + shape_list = [(2, 2, 3, 4), (5, 5)] + shape_format = [[np.float32, i, j] for i in format_list + for j in shape_list] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -1, 1) + scale_cpu, offset_cpu = self.cpu_op_exec(cpu_input, + with_offset=True) + scale_npu, offset_npu = self.npu_op_exec(npu_input, + with_offset=True) + self.assertTrue((scale_cpu - scale_npu[0]) / scale_cpu < 0.0001) + self.assertEqual(offset_cpu, offset_npu[0]) + + def test_ifrm_without_offset(self, device): + format_list = [0, 3] + shape_list = [(2, 2, 3, 4), (5, 5)] + shape_format = [[np.float32, i, j] for i in format_list + for j in shape_list] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -1, 1) + scale_cpu, offset_cpu = self.cpu_op_exec(cpu_input, + with_offset=False) + scale_npu, offset_npu = self.npu_op_exec(npu_input, + with_offset=False) + self.assertTrue((scale_cpu - scale_npu[0]) / scale_cpu < 0.0001) + self.assertEqual(offset_cpu, offset_npu[0]) + +instantiate_device_type_tests(TestIFMR, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 51ecc22c349d71aed2fe6eb9e6a773855f8732de..73a80f0c67b0f0a0c7fdfd675702023691e1185e 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1936,5 +1936,6 @@ custom_autograd: - func: npu_ps_roi_pooling(Tensor self, Tensor rois, float spatial_scale, int group_size, int output_dim) -> Tensor - func: npu_linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor - func: _npu_dropout(Tensor self, float p) -> (Tensor, Tensor) - - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor - variants: function, method \ No newline at end of file + - func: npu_ifmr(Tensor data, Tensor data_min, Tensor data_max, Tensor cumsum, float min_percentile, float max_percentile, float search_start, float search_end, float search_step, bool with_offset) -> (Tensor, Tensor) + - func: npu_grid_assign_positive(Tensor self, Tensor overlaps, Tensor box_responsible_flags, Tensor max_overlaps, Tensor argmax_overlaps, Tensor gt_max_overlaps, Tensor gt_argmax_overlaps, int num_gts, float pos_iou_thr, float min_pos_iou, bool gt_max_assign_all) -> Tensor + - func: npu_softmax_cross_entropy_with_logits(Tensor self, Tensor labels) -> Tensor \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridAssignPositiveKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridAssignPositiveKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..46ceebef0d4f6531a80a84954b62b57495f61488 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GridAssignPositiveKernelNpu.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using namespace torch::autograd; + +static inline void grid_assign_positive_check( + const at::Tensor& argmax_overlaps, + const at::Tensor& gt_argmax_overlaps){ + TORCH_CHECK( + at::isIntegralType(argmax_overlaps.scalar_type(), true) && argmax_overlaps.scalar_type() != at::ScalarType::Long, + "int32 argmax_overlaps tensor expected but got a tensor with dtype: ", + argmax_overlaps.scalar_type()); + TORCH_CHECK( + at::isIntegralType(gt_argmax_overlaps.scalar_type(), true) && gt_argmax_overlaps.scalar_type() != at::ScalarType::Long, + "int32 gt_argmax_overlaps tensor expected but got a tensor with dtype: ", + gt_argmax_overlaps.scalar_type()); +} + +at::Tensor NPUNativeFunctions::npu_grid_assign_positive( + const at::Tensor& assigned_gt_inds, + const at::Tensor& overlaps, + const at::Tensor& box_responsible_flags, + const at::Tensor& max_overlaps, + const at::Tensor& argmax_overlaps, + const at::Tensor& gt_max_overlaps, + const at::Tensor& gt_argmax_overlaps, + int64_t num_gts, + double pos_iou_thr, + double min_pos_iou, + bool gt_max_assign_all){ + grid_assign_positive_check(argmax_overlaps, gt_argmax_overlaps); + at::Tensor result = OpPreparation::ApplyTensor(assigned_gt_inds); + auto option = assigned_gt_inds.options().dtype(at::kInt); + at::Scalar s(num_gts); + at::Tensor numOfGts = at::empty({}, option).fill_(s); + + at::Tensor argmaxOverLaps = NPUNativeFunctions::npu_dtype_cast(argmax_overlaps, at::ScalarType::Int); + at::Tensor gtArgmaxOverLaps = NPUNativeFunctions::npu_dtype_cast(gt_argmax_overlaps, at::ScalarType::Int); + + OpCommand cmd; + cmd.Name("GridAssignPositive") + .Input(assigned_gt_inds) + .Input(overlaps) + .Input(box_responsible_flags) + .Input(max_overlaps) + .Input(argmaxOverLaps) + .Input(gt_max_overlaps) + .Input(gtArgmaxOverLaps) + .Input(numOfGts) + .Output(result) + .Attr("pos_iou_thr", (float) pos_iou_thr) + .Attr("min_pos_iou", (float) min_pos_iou) + .Attr("gt_max_assign_all", gt_max_assign_all) + .Run(); + return result; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp b/torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed4796729c4bfbfae7b42b615de91e76bd15b0e1 --- /dev/null +++ b/torch_npu/csrc/aten/ops/IfmrKernelNpu.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +using namespace torch::autograd; + +tuple NPUNativeFunctions::npu_ifmr( + const at::Tensor& data, + const at::Tensor& data_min, + const at::Tensor& data_max, + const at::Tensor& cumsum, + const double min_percentile=0.999999, + const double max_percentile=0.999999, + const double search_start=0.7, + const double search_end=1.3, + const double search_step=0.01, + const bool with_offset=true) { + at::Tensor scale = OpPreparation::ApplyTensorWithFormat(data_min, ACL_FORMAT_NCHW); + at::Tensor offset = OpPreparation::ApplyTensorWithFormat(data_min, ACL_FORMAT_NCHW); + + std::vector tmp; + tmp.push_back(static_cast(search_start)); + tmp.push_back(static_cast(search_end)); + at::ArrayRef searchRange(tmp); + + OpCommand cmd; + cmd.Name("IFMR") + .Input(data) + .Input(data_min) + .Input(data_max) + .Input(cumsum) + .Attr("min_percentile", static_cast(min_percentile)) + .Attr("max_percentile", static_cast(max_percentile)) + .Attr("search_range", searchRange) + .Attr("search_step", static_cast(search_step)) + .Attr("with_offset", with_offset) + .Output(scale) + .Output(offset) + .Run(); + + return std::tie(scale, offset); +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/__iLshift__KernelNpu.cpp b/torch_npu/csrc/aten/ops/__iLshift__KernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..baeb07553fc5ebce68b1cc89ffc0b017f481bea7 --- /dev/null +++ b/torch_npu/csrc/aten/ops/__iLshift__KernelNpu.cpp @@ -0,0 +1,78 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& ilshift_out_npu( + at::Tensor& result, + at::Tensor& self, + at::Scalar other) { + at::Tensor otherBroadcast = OpPreparation::ApplyTensor(self).fill_(other); + OpCommand cmd; + cmd.Name("LeftShift") + .Input(self) + .Input(otherBroadcast) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& ilshift_out_npu( + at::Tensor& result, + at::Tensor& self, + const at::Tensor& other) { + at::Tensor otherBroadcast = other.expand(self.sizes()); + OpCommand cmd; + cmd.Name("LeftShift") + .Input(self) + .Input(otherBroadcast) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::__ilshift__(at::Tensor& self, const at::Tensor& other) { + if(!NpuUtils::check_match(&self)){ + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + ilshift_out_npu(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, contiguousSelf); + } else { + ilshift_out_npu(self, self, other); + } + + return self; +} + +at::Tensor& NPUNativeFunctions::__ilshift__(at::Tensor& self, at::Scalar other) { + if(!NpuUtils::check_match(&self)){ + at::Tensor contiguousSelf = NpuUtils::format_contiguous(self); + ilshift_out_npu(contiguousSelf, contiguousSelf, other); + NpuUtils::format_fresh_view(self, contiguousSelf); + } else { + ilshift_out_npu(self, self, other); + } + + return self; +} + +} // namespace native +} // namespace at_npu