From 56c3e98ba4eba60e19c8ed2a0ab4431b70ef052a Mon Sep 17 00:00:00 2001 From: pipihugh Date: Thu, 24 Feb 2022 12:11:19 +0800 Subject: [PATCH] =?UTF-8?q?masked=5Ffill=5Frange=E8=87=AA=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=E7=AE=97=E5=AD=90=E7=A7=BB=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_masked_fill_range.py | 84 +++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 2 + .../aten/ops/MaskedFillRangeKernelNpu.cpp | 74 ++++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 test/test_network_ops/test_masked_fill_range.py create mode 100644 torch_npu/csrc/aten/ops/MaskedFillRangeKernelNpu.cpp diff --git a/test/test_network_ops/test_masked_fill_range.py b/test/test_network_ops/test_masked_fill_range.py new file mode 100644 index 00000000000..8da2a1a2110 --- /dev/null +++ b/test/test_network_ops/test_masked_fill_range.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestMaskedFillRange(TestCase):
+    def cpu_op_exec(self, input1, start, end, value, axis, dim):
+        """CPU reference: fill a copy of input1 with value[i] over [start, end) along axis."""
+        # The op schema defaults to axis=-1, so fold negative axes into [0, dim).
+        axis = axis + dim if axis < 0 else axis
+        out = input1.clone()
+        start_shape = start.shape
+        for i in range(0, start_shape[0]):
+            for j in range(0, start_shape[1]):
+                for k in range(start[i, j], end[i, j]):
+                    if dim == 1:
+                        out[k] = value[i]
+                    elif dim == 2:
+                        if axis == 0:
+                            out[k, :] = value[i]
+                        else:
+                            out[j, k] = value[i]
+                    elif dim == 3:
+                        if axis == 0:
+                            out[k, :, :] = value[i]
+                        elif axis == 1:
+                            out[:, k, :] = value[i]
+                        else:
+                            out[j, :, k] = value[i]
+        return out
+
+    def npu_op_exec(self, input1, start, end, value, axis):
+        """Run npu_masked_fill_range on the NPU and return the result as a numpy array."""
+        out = torch_npu.npu_masked_fill_range(input1, start, end, value, axis)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def test_normalize_batch(self, device):
+        """Compare the NPU op with the CPU reference for every dtype in the table."""
+        dtype_pairs = [(np.float32, torch.float32), (np.float16, torch.float16),
+                       (np.int32, torch.int32), (np.int8, torch.int8)]
+        # Item layout: [input spec, start rows, end rows, [value data, value dtype], axis].
+        shape_format = [
+            [[np_dtype, -1, [32, 64, 1688]], [list(range(0, 32))],
+             [list(range(6, 38))], [[1], torch_dtype], 2]
+            for np_dtype, torch_dtype in dtype_pairs
+        ]
+        for item in shape_format:
+            axis = item[-1]
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100)
+            shape = item[0][-1]
+            cpu_start = torch.tensor(item[1], dtype=torch.int32)
+            npu_start = cpu_start.npu()
+            cpu_end = torch.tensor(item[2], dtype=torch.int32)
+            npu_end = cpu_end.npu()
+            cpu_value = torch.tensor(item[3][0], dtype=item[3][1])
+            npu_value = cpu_value.npu()
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_start, cpu_end, cpu_value, axis, len(shape))
+            npu_output = self.npu_op_exec(npu_input1, npu_start, npu_end, npu_value, axis)
+            cpu_output = cpu_output.numpy()
+            # Cast the reference to the dtype the NPU returned before comparing.
+            cpu_output = cpu_output.astype(npu_output.dtype)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestMaskedFillRange, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 51ecc22c349..5e46a4b8f53 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1927,6 +1927,8 @@ custom:
   - func: npu_dropout_backward(Tensor grad_output, Tensor mask, float p) -> Tensor
   - func: npu_nms_rotated(Tensor self, Tensor scores, float iou_threshold, float scores_threshold=0, int max_output_size=-1, int mode=0) -> (Tensor, Tensor)
     variants: function, method
+  - func: npu_masked_fill_range(Tensor self, Tensor start, Tensor end, Tensor value, int axis=-1) -> Tensor
+    variants: function, method
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
diff --git a/torch_npu/csrc/aten/ops/MaskedFillRangeKernelNpu.cpp b/torch_npu/csrc/aten/ops/MaskedFillRangeKernelNpu.cpp
new file mode 100644
index 00000000000..4ba56bab85a
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/MaskedFillRangeKernelNpu.cpp
@@ -0,0 +1,74 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+// Validates the arguments of npu_masked_fill_range and raises through TORCH_CHECK
+// on the first violated constraint. Despite the "_nocheck" suffix, this helper is
+// the only argument validation this file performs before launching the operator.
+void mask_fill_range_nocheck(
+    const at::Tensor& self,
+    const at::Tensor& start,
+    const at::Tensor& end,
+    const at::Tensor& value,
+    int64_t axis) {
+  const int64_t x_dim = self.dim();
+  TORCH_CHECK(
+      axis >= -x_dim && axis < x_dim,
+      "axis overflows the range, expected in range [", -x_dim, ", ", x_dim - 1, "], but got ", axis);
+  TORCH_CHECK(
+      start.ndimension() == 2 && start.sizes() == end.sizes(),
+      "Expected a 2D start tensor and start's sizes() should be equal to end's sizes() ");
+  TORCH_CHECK(
+      start.size(0) == value.size(0),
+      "Expected value.length to be equal to start loop num ");
+  TORCH_CHECK(
+      self.scalar_type() == value.scalar_type(),
+      "value dtype should be equal to self dtype, but value dtype is ",
+      value.scalar_type(),
+      " and self dtype is ",
+      self.scalar_type());
+}
+
+// Checks the arguments, obtains the output tensor from OpPreparation::ApplyTensor(self)
+// and runs the "MaskedFillRange" operator on self, start, end and value with the
+// "axis" attribute. Returns that output tensor; self is only passed as an input.
+at::Tensor NPUNativeFunctions::npu_masked_fill_range(
+    const at::Tensor& self,
+    const at::Tensor& start,
+    const at::Tensor& end,
+    const at::Tensor& value,
+    int64_t axis) {
+  mask_fill_range_nocheck(self, start, end, value, axis);
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  OpCommand cmd;
+  cmd.Name("MaskedFillRange")
+      .Input(self)
+      .Input(start)
+      .Input(end)
+      .Input(value)
+      .Output(result)
+      .Attr("axis", axis)
+      .Run();
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
-- 
Gitee