From 6c3db39b461c769a3e716d32cca70df8d021fa4b Mon Sep 17 00:00:00 2001
From: hxf12345677
Date: Mon, 21 Feb 2022 10:15:06 +0800
Subject: [PATCH] repeat_interleave 1.8.1 operator port
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
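
Port torch.repeat_interleave (the repeat_interleave.self_int variant:
scalar repeats with an optional dim) from PyTorch 1.8.1 to the NPU
backend, mapping it onto the TileWithAxis operator. When dim is not
given, the input is flattened first, matching CPU semantics.

A minimal usage sketch (assuming an available NPU device; the shapes
are only illustrative):

    import torch
    import torch_npu

    x = torch.arange(6.0).reshape(2, 3).to("npu")
    y = torch.repeat_interleave(x, 2, dim=1)  # each element repeated twice along dim 1 -> (2, 6)
    z = torch.repeat_interleave(x, 2)         # no dim: flattened first -> (12,)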
---
 .../test_repeat_interleave.py                 | 87 +++++++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |  2 -
 .../aten/ops/RepeatInterLeaveKernelNpu.cpp    | 60 +++++++++++++
 3 files changed, 147 insertions(+), 2 deletions(-)
 create mode 100644 test/test_network_ops/test_repeat_interleave.py
 create mode 100644 torch_npu/csrc/aten/ops/RepeatInterLeaveKernelNpu.cpp

diff --git a/test/test_network_ops/test_repeat_interleave.py b/test/test_network_ops/test_repeat_interleave.py
new file mode 100644
index 00000000000..f0673b66d3f
--- /dev/null
+++ b/test/test_network_ops/test_repeat_interleave.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestRepeatInterleave(TestCase):
+
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input1 = torch.from_numpy(input1)
+        return npu_input1
+
+    def cpu_op_exec(self, input1, input2, input3):
+        output = torch.repeat_interleave(input1, input2, dim=input3)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, input2, input3):
+        input1 = input1.to("npu")  # run the operator on the NPU device
+        output = torch.repeat_interleave(input1, input2, dim=input3)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def cpu_op_exec_without_dim(self, input1, input2):
+        output = torch.repeat_interleave(input1, input2)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_without_dim(self, input1, input2):
+        input1 = input1.to("npu")  # run the operator on the NPU device
+        output = torch.repeat_interleave(input1, input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_repeat_interleave_float16(self, device):
+        npu_input1 = self.generate_data(0, 100, (3, 3, 3), np.float16)
+        npu_input2 = np.random.randint(1, 100)
+        npu_input3 = np.random.randint(0, 2)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2, npu_input3)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2, npu_input3)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_repeat_interleave_float32(self, device):
+        npu_input1 = self.generate_data(0, 100, (3, 3, 3), np.float32)
+        npu_input2 = np.random.randint(1, 100)
+        npu_input3 = np.random.randint(0, 2)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2, npu_input3)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2, npu_input3)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_repeat_interleave_int32(self, device):
+        npu_input1 = self.generate_data(0, 100, (3, 3, 3), np.int32)
+        npu_input2 = np.random.randint(1, 100)
+        npu_input3 = np.random.randint(0, 2)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2, npu_input3)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2, npu_input3)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_repeat_interleave_int32_without_dim(self, device):
+        npu_input1 = self.generate_data(0, 100, (3, 3, 3), np.int32)
+        npu_input2 = np.random.randint(1, 100)
+        cpu_output = self.cpu_op_exec_without_dim(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec_without_dim(npu_input1, npu_input2)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestRepeatInterleave, globals(), except_for='cpu')
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 1aaf07af260..c6dae62c539 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -727,8 +727,6 @@ supported:
   - negative_
   - negative.out
   - repeat
-  - repeat_interleave.Tensor
-  - repeat_interleave.self_Tensor
   - repeat_interleave.self_int
   - reshape
   - _mkldnn_reshape
diff --git a/torch_npu/csrc/aten/ops/RepeatInterLeaveKernelNpu.cpp b/torch_npu/csrc/aten/ops/RepeatInterLeaveKernelNpu.cpp
new file mode 100644
index 00000000000..db0f0d80f9d
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RepeatInterLeaveKernelNpu.cpp
@@ -0,0 +1,60 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& repeat_interleave_out_npu(at::Tensor& result, const at::Tensor& self, int64_t repeats, int64_t dim) {
+  OpCommand cmd;  // dispatch to the TileWithAxis NPU operator
+  cmd.Name("TileWithAxis")
+      .Input(self)
+      .Output(result)
+      .Attr("tiles", repeats)
+      .Attr("axis", dim)
+      .Run();
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::repeat_interleave(
+    const at::Tensor& self,
+    int64_t repeats,
+    c10::optional<int64_t> dim) {
+  int64_t realDim = dim.value_or(0);
+
+  int64_t self_dim = self.dim();
+  if ((realDim < -self_dim) || (realDim > self_dim - 1)) {
+    AT_ERROR("dim value should be in the range [-x, x-1], where x is the number of dimensions of the input tensor.");
+  }
+
+  at::Tensor selfTensor = self;
+  if (!dim.has_value()) {
+    selfTensor = at::flatten(selfTensor);  // no dim given: flatten first, as torch.repeat_interleave does
+  }
+
+  auto outputSize = repeat_interleave_npu_output_size(selfTensor, repeats, realDim);
+  at::Tensor result = OpPreparation::ApplyTensorWithSizes(
+      outputSize,
+      selfTensor.options());
+  repeat_interleave_out_npu(result, selfTensor, repeats, realDim);
+
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee