From 72f215c7027f20776b89644d39cc53779ed0d72e Mon Sep 17 00:00:00 2001
From: zhoufan37
Date: Wed, 16 Feb 2022 11:10:00 +0800
Subject: [PATCH] Add Roll Operator

---
 test/test_network_ops/test_roll.py        |  79 ++++++++++++
 torch_npu/csrc/aten/ops/RollKernelNpu.cpp | 112 ++++++++++++++++++
 2 files changed, 191 insertions(+)
 create mode 100644 test/test_network_ops/test_roll.py
 create mode 100644 torch_npu/csrc/aten/ops/RollKernelNpu.cpp

diff --git a/test/test_network_ops/test_roll.py b/test/test_network_ops/test_roll.py
new file mode 100644
index 0000000000..207ebf2349
--- /dev/null
+++ b/test/test_network_ops/test_roll.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestRoll(TestCase):
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input_x = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input = torch.from_numpy(input_x)
+        return npu_input
+
+    def cpu_op_exec(self, input_x, shifts, dims):
+        output = torch.roll(input_x, shifts, dims).numpy()
+        return output
+
+    def npu_op_exec(self, input_x, shifts, dims):
+        input1 = input_x.to("npu")
+        output = torch.roll(input1, shifts, dims)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_roll_3_4_5_float32(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 4, 5), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, [2, 1], [0, 1])
+        npu_output1 = self.npu_op_exec(input_x1, [2, 1], [0, 1])
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_roll_3_4_5_float16(self, device):
+        input_x1 = self.generate_data(-1, 1, (3, 4, 5), np.float16)
+        input_cpu = input_x1.float()
+        cpu_output1 = self.cpu_op_exec(input_cpu, [2, 1], [0, 1]).astype(np.float16)
+        npu_output1 = self.npu_op_exec(input_x1, [2, 1], [0, 1])
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_roll_30_40_50_int32(self, device):
+        input_x1 = self.generate_data(-1, 1, (30, 40, 50), np.int32)
+        cpu_output1 = self.cpu_op_exec(input_x1, [20], [])
+        npu_output1 = self.npu_op_exec(input_x1, [20], [])
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_roll_10_10_10_10_10_10_int8(self, device):
+        input_x1 = self.generate_data(-1, 1, (10, 10, 10, 10, 10, 10), np.int8)
+        cpu_output1 = self.cpu_op_exec(input_x1, [-20, 30, 5], [-3, -4, -5])
+        npu_output1 = self.npu_op_exec(input_x1, [-20, 30, 5], [-3, -4, -5])
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_roll_20_30_40_50_uint8(self, device):
+        input_x1 = self.generate_data(-1, 1, (20, 30, 40, 50), np.uint8)
+        cpu_output1 = self.cpu_op_exec(input_x1, [-20, 30], [-1, 0])
+        npu_output1 = self.npu_op_exec(input_x1, [-20, 30], [-1, 0])
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_roll_20_30_40_50_float32(self, device):
+        input_x1 = self.generate_data(-1, 1, (20, 30, 40, 50), np.float32)
+        cpu_output1 = self.cpu_op_exec(input_x1, [30], [3])
+        npu_output1 = self.npu_op_exec(input_x1, [30], [3])
+        self.assertRtolEqual(cpu_output1, npu_output1)
+
+instantiate_device_type_tests(TestRoll, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/RollKernelNpu.cpp b/torch_npu/csrc/aten/ops/RollKernelNpu.cpp
new file mode 100644
index 0000000000..9dafaff230
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/RollKernelNpu.cpp
@@ -0,0 +1,112 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& roll_out_npu_no_transpose(
+    at::Tensor& result,
+    const at::Tensor& self,
+    at::IntArrayRef shifts,
+    at::IntArrayRef dims) {
+
+  // executing the NPU operator
+  OpCommand cmd;
+  cmd.Name("Roll")
+      .Input(self)
+      .Output(result)
+      .Attr("shifts", shifts)
+      .Attr("dims", dims)
+      .Run();
+
+  return result;
+}
+
+at::Tensor& roll_transpose(
+    at::Tensor& result,
+    const at::Tensor& self,
+    int64_t axis,
+    int64_t firstDim,
+    at::IntArrayRef shifts,
+    int64_t id) {
+
+  c10::SmallVector<int64_t, SIZE> perm;
+  for (int64_t i = 0; i < self.dim(); i++) {
+    perm.emplace_back(i);
+  }
+  std::swap(perm[axis], perm[firstDim]);
+  at::Tensor transposeSelf = NPUNativeFunctions::npu_transpose(self, perm);
+  auto outputSize = transpose_npu_output_size(result, perm);
+  at::Tensor transposeResult = OpPreparation::ApplyTensorWithSizes(
+      outputSize,
+      self.options());
+  c10::SmallVector<int64_t, SIZE> dim = {firstDim};
+  c10::SmallVector<int64_t, SIZE> shift_bak = {shifts[id]};
+  at::IntArrayRef dim_now = at::IntArrayRef(dim);
+  at::IntArrayRef shift_now = at::IntArrayRef(shift_bak);
+  roll_out_npu_no_transpose(transposeResult, transposeSelf, shift_now, dim_now);
+  NPUNativeFunctions::npu_transpose_out(transposeResult, perm, result);
+  return result;
+}
+
+at::Tensor& roll_out_npu(
+    at::Tensor& result,
+    const at::Tensor& self,
+    at::IntArrayRef shifts,
+    at::IntArrayRef dims) {
+
+  if (dims.size() == 0) {
+    roll_out_npu_no_transpose(result, self, shifts, dims);
+  } else {
+    TORCH_CHECK(dims.size() == shifts.size(),
+        "The size of shifts and dims should be the same when the size of dims is not 0.");
+    int64_t firstDim = CalcuOpUtil::make_wrap_dim(0, self.dim());
+    for (int i = 0; i < dims.size(); i++) {
+      int64_t axis = CalcuOpUtil::make_wrap_dim(dims[i], self.dim());
+      if (i == 0) {
+        if (axis == firstDim) {
+          c10::SmallVector<int64_t, SIZE> dim = {firstDim};
+          c10::SmallVector<int64_t, SIZE> shift_bak = {shifts[i]};
+          at::IntArrayRef dim_now = at::IntArrayRef(dim);
+          at::IntArrayRef shift_now = at::IntArrayRef(shift_bak);
+          roll_out_npu_no_transpose(result, self, shift_now, dim_now);
+        } else {
+          roll_transpose(result, self, axis, firstDim, shifts, i);
+        }
+      } else {
+        roll_transpose(result, result, axis, firstDim, shifts, i);
+      }
+    }
+  }
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::roll(
+    const at::Tensor& self,
+    at::IntArrayRef shifts,
+    at::IntArrayRef dims) {
+
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  // calculate the output result of the NPU
+  roll_out_npu(result, self, shifts, dims);
+  return result;
+}
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
--
Gitee