From 97e37518461a7ad2b335e7880c1a31789432a559 Mon Sep 17 00:00:00 2001 From: pipihugh Date: Sat, 29 Jan 2022 17:53:38 +0800 Subject: [PATCH] =?UTF-8?q?gather=E7=AE=97=E5=AD=90=E7=A7=BB=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_gather.py | 56 +++++++++++ torch_npu/csrc/aten/ops/GatherKernelNpu.cpp | 102 ++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 test/test_network_ops/test_gather.py create mode 100644 torch_npu/csrc/aten/ops/GatherKernelNpu.cpp diff --git a/test/test_network_ops/test_gather.py b/test/test_network_ops/test_gather.py new file mode 100644 index 0000000000..9cc3a761db --- /dev/null +++ b/test/test_network_ops/test_gather.py @@ -0,0 +1,56 @@ +# Copyright (c) 2022 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestGather(TestCase): + def cpu_op_exec(self, input1, dim, index): + output = torch.gather(input1, dim, index) + output = output.numpy() + return output + + def npu_op_exec(self, input1, dim, index): + output = torch.gather(input1, dim, index) + output = output.to("cpu") + output = output.numpy() + return output + + def test_gather_shape_format(self, device): + shape_format = [ + [[np.int32, 0, (4, 3)], 0, torch.LongTensor([[0, 1, 1], [2, 0, 1]])], + [[np.int64, 0, (2, 3)], 1, torch.LongTensor([[0, 1, 1], [0, 0, 1]])], + [[np.float16, 0, (2, 3, 5)], 2, torch.LongTensor([[[0, 1, 2, 0, 2], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], + [[1, 1, 1, 1, 1], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]]])], + [[np.float32, 0, (3, 3, 5)], -3, torch.LongTensor([[[0, 1, 2, 0, 2], [0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], + [[1, 2, 2, 2, 2], [0, 0, 0, 0, 0], [2, 2, 2, 2, 2]]])] + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + cpu_output = self.cpu_op_exec(cpu_input1, item[1], item[2]) + npu_idx = item[2].to("npu") + npu_output = self.npu_op_exec(npu_input1, item[1], npu_idx) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestGather, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/GatherKernelNpu.cpp b/torch_npu/csrc/aten/ops/GatherKernelNpu.cpp new file mode 100644 index 0000000000..92d3c95012 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GatherKernelNpu.cpp @@ -0,0 +1,102 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& gather_out_npu_nocheck( + const at::Tensor& self, + int64_t dim, + const at::Tensor& index, + bool sparse_grad, + at::Tensor& result) { + at::Tensor dtypeCastOfSelf = self; + at::Tensor dtypeCastOfResult = result; + if (self.scalar_type() == at::ScalarType::Half) { + dtypeCastOfSelf = NPUNativeFunctions::npu_dtype_cast(dtypeCastOfSelf, at::ScalarType::Float); + dtypeCastOfResult = NPUNativeFunctions::npu_dtype_cast(dtypeCastOfResult, at::ScalarType::Float); + } + + if (self.scalar_type() == at::kLong) { + TORCH_WARN_ONCE("The oprator of gather is executed, Currently High Accuracy but Low Performance OP" + "with 64-bit has been used,Please Do Some Cast at Python Functions with 32-bit for Better Performance!"); + } + + OpCommand cmd; + cmd.Name("GatherElements") + .Input(dtypeCastOfSelf) + .Input(index) + .Attr("dim", dim) + .Output(dtypeCastOfResult) + .Run(); + result.copy_(dtypeCastOfResult); + return result; +} + +at::Tensor& NPUNativeFunctions::gather_out( + const at::Tensor& self, + int64_t dim, + const at::Tensor& index, + bool sparse_grad, + at::Tensor& result) { + auto outputSize = index.sizes(); + OpPreparation::CheckOut( + {self}, + result, + self, + outputSize); + return gather_out_npu_nocheck(self, dim, index, sparse_grad, result); +} + +at::Tensor& NPUNativeFunctions::gather_out( + const at::Tensor& self, + at::Dimname dim, + const at::Tensor& index, + bool sparse_grad, + at::Tensor& result) { + auto outputSize = index.sizes(); + OpPreparation::CheckOut( + {self}, + result, + self, + outputSize); + return gather_out_npu_nocheck(self, dimname_to_position(self, dim), index, sparse_grad, result); +} + +at::Tensor NPUNativeFunctions::gather( + const at::Tensor& self, + int64_t dim, + const at::Tensor& index, + bool sparse_grad) { + auto outputSize = input_same_output_size(index); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + gather_out_npu_nocheck(self, dim, index, sparse_grad, result); + return result; +} + +at::Tensor NPUNativeFunctions::gather( + const at::Tensor& self, + at::Dimname dim, + const at::Tensor& index, + bool sparse_grad) { + return gather(self, dimname_to_position(self, dim), index, sparse_grad); +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee