From e152a55ce185558c448dbc5e2d12253487ac6daf Mon Sep 17 00:00:00 2001
From: zhoufan37
Date: Thu, 10 Feb 2022 11:26:16 +0800
Subject: [PATCH] Add Swhere Operator

---
 test/test_network_ops/test_where.py           | 123 ++++++++++++++++++
 torch_npu/csrc/aten/npu_native_functions.yaml |   3 -
 torch_npu/csrc/aten/ops/WhereKernelNpu.cpp    | 123 ++++++++++++++++++
 3 files changed, 246 insertions(+), 3 deletions(-)
 create mode 100644 test/test_network_ops/test_where.py
 create mode 100644 torch_npu/csrc/aten/ops/WhereKernelNpu.cpp

diff --git a/test/test_network_ops/test_where.py b/test/test_network_ops/test_where.py
new file mode 100644
index 00000000000..cc58c37f7f2
--- /dev/null
+++ b/test/test_network_ops/test_where.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestWhere(TestCase):
+    def cpu_op_exec(self, input1):
+        output = torch.where(input1)
+        output = list(output)
+        output[0] = output[0].numpy().astype(np.int32)
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch.where(input1)
+        output = list(output)
+        output[0] = output[0].to("cpu").numpy().astype(np.int32)
+        return output
+
+    def cpu_op_exec_condition(self, input1, ones):
+        output = torch.where(input1 > 0, input1, ones)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_condition(self, input1, ones):
+        output = torch.where(input1 > 0, input1, ones)
+        output = output.to("cpu").numpy()
+        return output
+
+    def cpu_op_exec_s(self, input1, ones):
+        output = torch._s_where(input1 > 0, input1, ones)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_s(self, input1, ones):
+        output = torch._s_where(input1 > 0, input1, ones)
+        output = output.to("cpu").numpy()
+        return output
+
+    def where_result(self, shape_format):
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item, -100, 100)
+            cpu_ones = torch.ones_like(cpu_input1)
+            npu_ones = cpu_ones.to("npu")
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+                cpu_ones = cpu_ones.to(torch.float32)
+
+            cpu_output = self.cpu_op_exec(cpu_input1)
+            npu_output = self.npu_op_exec(npu_input1)
+
+            cpu_output_cond = self.cpu_op_exec_condition(cpu_input1, cpu_ones)
+            npu_output_cond = self.npu_op_exec_condition(npu_input1, npu_ones)
+            cpu_output_cond = cpu_output_cond.astype(npu_output_cond.dtype)
+
+            cpu_output_s = self.cpu_op_exec_s(cpu_input1, cpu_ones)
+            npu_output_s = self.npu_op_exec_s(npu_input1, npu_ones)
+            cpu_output_s = cpu_output_s.astype(npu_output_s.dtype)
+
+            cpu_output[0] = cpu_output[0].astype(npu_output[0].dtype)
+            self.assertRtolEqual(cpu_output[0], npu_output[0])
+            self.assertRtolEqual(cpu_output_cond, npu_output_cond)
+            self.assertRtolEqual(cpu_output_s, npu_output_s)
+
+    def test_where_shape_format_fp32_1d(self, device):
+        format_list = [0, 3]
+        shape_format = [[np.float32, i, [18]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp32_2d(self, device):
+        format_list = [0]
+        shape_format = [[np.float32, i, [5, 256]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp32_3d(self, device):
+        format_list = [0]
+        shape_format = [[np.float32, i, [32, 3, 3]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp32_4d(self, device):
+        format_list = [0, 3]
+        shape_format = [[np.float32, i, [64, 112, 7, 7]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp16_1d(self, device):
+        format_list = [0, 3]
+        shape_format = [[np.float16, i, [18]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp16_2d(self, device):
+        format_list = [0, 3, 4, 29]
+        shape_format = [[np.float16, i, [5, 256]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp16_3d(self, device):
+        format_list = [0, 3, 4, 29]
+        shape_format = [[np.float16, i, [32, 3, 3]] for i in format_list]
+        self.where_result(shape_format)
+
+    def test_where_shape_format_fp16_4d(self, device):
+        format_list = [0, 3, 4, 29]
+        shape_format = [[np.float16, i, [64, 112, 7, 7]] for i in format_list]
+        self.where_result(shape_format)
+
+instantiate_device_type_tests(TestWhere, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 1f08a26654f..7fe3eed7d5b 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -898,9 +898,6 @@ supported:
   - var_mean.names_dim
   - view_as
   - where.self
-  - where.ScalarSelf
-  - where.ScalarOther
-  - where.Scalar
   - where
   - _s_where
   - norm_except_dim
diff --git a/torch_npu/csrc/aten/ops/WhereKernelNpu.cpp b/torch_npu/csrc/aten/ops/WhereKernelNpu.cpp
new file mode 100644
index 00000000000..149b23b25c2
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/WhereKernelNpu.cpp
@@ -0,0 +1,123 @@
+// Copyright (c) 2020, Huawei Technologies.
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_expand_outplace(
+    const at::Tensor &to_expand1,
+    const at::Tensor &to_expand2,
+    const at::Tensor &to_expand3,
+    const char *api_name) {
+  for (auto& t : {to_expand1, to_expand2, to_expand3}) {
+    if (!t.defined()) {
+      AT_ERROR(api_name, "(...) called with an undefined Tensor");
+    }
+  }
+
+  if (to_expand1.sizes().equals(to_expand2.sizes()) && to_expand1.sizes().equals(to_expand3.sizes())) {
+    return std::make_tuple(to_expand1, to_expand2, to_expand3);
+  }
+
+  auto expanded_size12 = broadcast_ops_npu_output_size(to_expand1, to_expand2);
+  auto expanded_size = broadcast_ops_npu_output_size(expanded_size12, to_expand3.sizes());
+
+  return std::make_tuple(
+      to_expand1.expand(expanded_size, true),
+      to_expand2.expand(expanded_size, true),
+      to_expand3.expand(expanded_size, true));
+}
+
+at::Tensor NPUNativeFunctions::_s_where(
+    const at::Tensor& condition,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+
+  OpCommand cmd;
+  cmd.Name("Select")
+      .Input(condition)
+      .Input(self)
+      .Input(other)
+      .Output(result)
+      .Run();
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::where(
+    const at::Tensor& condition,
+    const at::Tensor& self,
+    const at::Tensor& other) {
+  TORCH_CHECK(condition.device() == self.device() && self.device() == other.device(),
+      "expected condition, x and y to be on the same device, but condition is on ",
+      condition.device(), " and x and y are on ", self.device(), " and ", other.device(),
+      " respectively");
+  if (condition.scalar_type() != at::ScalarType::Byte && condition.scalar_type() != at::ScalarType::Bool) {
+    AT_ERROR("Expected condition to have ScalarType Byte or Bool, but got ScalarType ",
+        toString(condition.scalar_type()));
+  }
+  at::Tensor b_condition, b_self, b_other;
+  std::tie(b_condition, b_self, b_other) = npu_expand_outplace(condition, self, other, "where_npu");
+  return at::_s_where(b_condition, b_self, b_other);
+}
+
+c10::SmallVector<int64_t, SIZE> where_npu_output_size(const at::Tensor& condition) {
+  int64_t dim = condition.dim();
+  at::Tensor boolSelf = NPUNativeFunctions::npu_dtype_cast(condition, at::ScalarType::Bool);
+  at::Tensor intSelf = NPUNativeFunctions::npu_dtype_cast(boolSelf, at::ScalarType::Int);
+  at::Tensor countNonzeroSelf = at::sum(intSelf, at::ScalarType::Int);
+  int64_t nonzeroNum = countNonzeroSelf.item().toInt();
+  c10::SmallVector<int64_t, SIZE> outputSize = {nonzeroNum, dim};
+  return outputSize;
+}
+
+std::vector<at::Tensor> NPUNativeFunctions::where(const at::Tensor& condition) {
+  at::Tensor formatCastOfCondition = condition;
+  if (condition.storage().unsafeGetStorageImpl()->npu_desc_.npu_format_ !=
+      ACL_FORMAT_ND) {
+    formatCastOfCondition = NPUNativeFunctions::npu_format_cast(formatCastOfCondition, ACL_FORMAT_ND);
+  }
+  if (condition.scalar_type() == at::ScalarType::Half) {
+    formatCastOfCondition = NPUNativeFunctions::npu_dtype_cast(formatCastOfCondition, at::ScalarType::Float);
+  }
+
+  // calculate the output size
+  auto outputSize = where_npu_output_size(formatCastOfCondition);
+
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(
+      outputSize, formatCastOfCondition.options().dtype(at::kLong), ACL_FORMAT_ND);
+
+  OpCommand cmd;
+  cmd.Name("NonZero")
+      .Input(formatCastOfCondition)
+      .Output(result)
+      .Run();
+  // NonZero yields a [nonzeroNum, dim] index matrix; transpose and chunk it so
+  // each dimension's indices become one 1-D tensor, matching torch.where(cond).
+  result = result.transpose(1, 0);
+  std::vector<at::Tensor> chunkResult = result.chunk(result.size(0), 0);
+  std::vector<at::Tensor> squeezeResult;
+  for (size_t i = 0; i < chunkResult.size(); i++) {
+    squeezeResult.push_back(chunkResult[i].squeeze(0));
+  }
+
+  return squeezeResult;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee
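
Note for reviewers: below is a minimal smoke test for the three code paths this
patch wires up. It is a sketch, assuming torch_npu has been built with this
patch applied and an NPU device is available; the tensor values are
illustrative only.

    import torch
    import torch_npu

    x = torch.tensor([[-1.0, 2.0], [3.0, -4.0]]).to("npu")
    ones = torch.ones_like(x)

    # where.self path: elementwise select, lowered to the NPU "Select" op
    y = torch.where(x > 0, x, ones)

    # _s_where path: the non-broadcasting variant exercised by the new tests
    y_s = torch._s_where(x > 0, x, ones)

    # where(Tensor) path: index lookup lowered to the NPU "NonZero" op,
    # returning one 1-D index tensor per dimension
    idx = torch.where(x > 0)

    print(y.cpu(), y_s.cpu(), [i.cpu() for i in idx])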