diff --git a/test/test_network_ops/test_index_select.py b/test/test_network_ops/test_index_select.py
new file mode 100644
index 0000000000000000000000000000000000000000..5567f96c811086cd107bb51a5727d9b74286af21
--- /dev/null
+++ b/test/test_network_ops/test_index_select.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import torch
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestIndexSelect(TestCase):
+    def cpu_op_exec(self, input1, axis, indices):
+        '''supported input dtypes: float16, float32, int8, uint8, int32, uint32, int16, uint16, int64, uint64'''
+        output = torch.index_select(input1, dim=axis, index=indices)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, axis, indices):
+        output = torch.index_select(input1, dim=axis, index=indices)
+        output = output.to('cpu')
+        output = output.numpy()
+        return output
+
+    def cpu_op_out_exec(self, input1, axis, indices, output):
+        '''supported input dtypes: float16, float32, int8, uint8, int32, uint32, int16, uint16, int64, uint64'''
+        torch.index_select(input1, dim=axis, index=indices, out=output)
+        output = output.numpy()
+        return output
+
+    def npu_op_out_exec(self, input1, axis, indices, output):
+        torch.index_select(input1, dim=axis, index=indices, out=output)
+        output = output.to('cpu')
+        output = output.numpy()
+        return output
+
+    def test_index_select(self, device):
+        shape_format = [
+            [[np.float32, 0, (3, )], torch.tensor(0, dtype=torch.int64), 0],
+            [[np.float32, 0, (3, )], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.float32, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.float32, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.float32, 3, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.float32, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+
+            [[np.int8, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.int8, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.int8, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.int8, 0, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.int8, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+
+            [[np.uint8, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.uint8, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.uint8, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.uint8, 0, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.uint8, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+
+            [[np.int32, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.int32, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.int32, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.int32, 0, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.int32, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+
+            [[np.uint8, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.uint8, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.uint8, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.uint8, 0, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.uint8, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+
+            [[np.uint8, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.uint8, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.uint8, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.uint8, 0, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.uint8, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+
+            [[np.int16, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.int16, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.int16, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.int16, 0, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.int16, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+        ]
+        for item in shape_format:
+            input1, npu_input = create_common_tensor(item[0], 1, 100)
+            _, npu_out = create_common_tensor(item[0], 1, 100)
+            cpu_output = self.cpu_op_exec(input1, item[2], item[1])
+            npu_output = self.npu_op_exec(npu_input, item[2], item[1].to('npu'))
+            npu_output_out = self.npu_op_out_exec(npu_input, item[2], item[1].to('npu'), npu_out)
+            self.assertRtolEqual(cpu_output, npu_output)
+            self.assertRtolEqual(cpu_output, npu_output_out)
+
+    def test_index_select_fp16(self, device):
+        shape_format = [
+            [[np.float16, 0, (3,)], torch.tensor([0, 1], dtype=torch.int64), 0],
+            [[np.float16, 0, (2, 4)], torch.tensor([0, 1, 2], dtype=torch.int64), 1],
+            [[np.float16, 0, (3, 4, 6)], torch.tensor([1, 2, 4], dtype=torch.int64), 2],
+            [[np.float16, 3, (4, 5, 6, 7)], torch.tensor([3, 5, 6], dtype=torch.int64), 3],
+            [[np.float16, -1, (3, 4, 8, 9, 12)], torch.tensor([2, 3, 5, 6], dtype=torch.int64), 4],
+            [[np.float16, 0, (3, )], torch.tensor(0, dtype=torch.int64), 0],
+        ]
+        for item in shape_format:
+            input1, npu_input = create_common_tensor(item[0], 1, 100)
+            input1 = input1.to(torch.float32)
+            cpu_output = self.cpu_op_exec(input1, item[2], item[1])
+            npu_output = self.npu_op_exec(npu_input, item[2], item[1].to('npu'))
+            cpu_output = cpu_output.astype(np.float16)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestIndexSelect, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/ops/IndexSelectKernelNpu.cpp b/torch_npu/csrc/aten/ops/IndexSelectKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b080182b42e5bb747a8782dcd18b8ebb4579e1fa
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/IndexSelectKernelNpu.cpp
@@ -0,0 +1,124 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& index_select_out_npu_nocheck(
+    const at::Tensor& self,
+    int64_t dim,
+    const at::Tensor& index,
+    at::Tensor& result) {
+  if (self.scalar_type() == at::kLong) {
+    TORCH_WARN_ONCE("index_select is running on a 64-bit input; the 64-bit path is high-accuracy but low-performance. "
+        "Cast the input to a 32-bit dtype on the Python side for better performance.");
+  }
+  c10::SmallVector<int64_t, N> dimVec = {dim};
+  OpCommand cmd;
+  cmd.Name("GatherV2")
+      .Input(self)
+      .Input(index)
+      .Input(dimVec, at::kInt)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::index_select_out(
+    const at::Tensor& self,
+    int64_t dim,
+    const at::Tensor& index,
+    at::Tensor& result) {
+  at::Tensor indexTmp(index);
+  if (indexTmp.ndimension() == 0) {
+    indexTmp = index.unsqueeze(0);
+  }
+  auto outputSize = index_select_npu_output_size(self, dim, indexTmp);
+  int64_t npu_format = CalcuOpUtil::get_tensor_npu_format(self);
+  if (outputSize.empty()) {
+    npu_format = ACL_FORMAT_ND;
+  }
+  at::Tensor input = self;
+  if (self.dtype() == at::kBool) {
+    input = NPUNativeFunctions::npu_dtype_cast(input, at::kInt);
+  }
+  OpPreparation::CheckOut(
+      {input},
+      result,
+      npu_format,
+      input.scalar_type(),
+      outputSize);
+  OpPipeWithDefinedOut pipe;
+  result = pipe.CheckMemory({input, indexTmp}, {result})
+      .Func([&input, &dim, &indexTmp](at::Tensor& result)
+          {index_select_out_npu_nocheck(input, dim, indexTmp, result);})
+      .Call(result);
+  if (self.dtype() == at::kBool) {
+    result = result.to(at::kBool);
+  }
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::index_select(
+    const at::Tensor& self,
+    int64_t dim,
+    const at::Tensor& index) {
+  at::Tensor indexTmp(index);
+  if (indexTmp.ndimension() == 0) {
+    indexTmp = index.unsqueeze(0);
+  }
+  auto outputSize = index_select_npu_output_size(self, dim, indexTmp);
+  int64_t npu_format = CalcuOpUtil::get_tensor_npu_format(self);
+  if (outputSize.empty()) {
+    npu_format = ACL_FORMAT_ND;
+  }
+  at::Tensor input = self;
+  if (self.dtype() == at::kBool) {
+    input = NPUNativeFunctions::npu_dtype_cast(input, at::kInt);
+  }
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(input, outputSize, npu_format);
+  index_select_out_npu_nocheck(input, dim, indexTmp, result);
+  if (self.dtype() == at::kBool) {
+    result = NPUNativeFunctions::npu_dtype_cast(result, at::kBool);
+  }
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::index_select_out(
+    const at::Tensor& self,
+    at::Dimname dim,
+    const at::Tensor& index,
+    at::Tensor& result) {
+  at::Tensor indexTmp(index);
+  if (indexTmp.ndimension() == 0) {
+    indexTmp = index.unsqueeze(0);
+  }
+  return index_select_out(
+      self, dimname_to_position(self, dim), indexTmp, result);
+}
+
+at::Tensor NPUNativeFunctions::index_select(
+    const at::Tensor& self,
+    at::Dimname dim,
+    const at::Tensor& index) {
+  return index_select(self, dimname_to_position(self, dim), index);
+}
+} // namespace native
+} // namespace at_npu