From 1ada1fda3947072887790362c041807c93b9f619 Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Tue, 8 Feb 2022 20:17:46 +0800 Subject: [PATCH] =?UTF-8?q?MaskedSelect=E7=AE=97=E5=AD=90=E8=BF=81?= =?UTF-8?q?=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit output非连续型问题 code check code check问题 code check --- test/test_network_ops/test_masked_select.py | 128 ++++++++++++++++++ .../csrc/aten/ops/MaskedSelectKernelNpu.cpp | 114 ++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 test/test_network_ops/test_masked_select.py create mode 100644 torch_npu/csrc/aten/ops/MaskedSelectKernelNpu.cpp diff --git a/test/test_network_ops/test_masked_select.py b/test/test_network_ops/test_masked_select.py new file mode 100644 index 0000000000..9587eadd13 --- /dev/null +++ b/test/test_network_ops/test_masked_select.py @@ -0,0 +1,128 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import makeSuite +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestMaskedSelect(TestCase): + def get_mask(self): + mask = torch.tensor([[ + [ True, False, True, True, False], + [ True, False, False, True, False], + [False, False, False, False, False], + [ True, False, False, False, False]], + + [[ True, False, False, False, True], + [False, True, False, True, True], + [False, True, False, True, True], + [False, False, False, False, False]], + + [[False, True, True, False, True], + [False, True, True, True, True], + [False, True, False, True, False], + [False, True, True, False, False]]]) + return mask + + def cpu_op_exec(self, input1, mask): + output = torch.masked_select(input1, mask) + output = output.numpy() + return output + + def npu_op_exec(self, input1, mask): + mask = mask.to("npu") + output = torch.masked_select(input1, mask) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self, input1, mask, output): + output = torch.masked_select(input1, mask, out=output) + return output.detach().to("cpu").numpy() + + def test_maskedselect_out_result(self, device): + shape_format = [ + [[np.float16, 2, [15, 15, 15, 16]], [np.float16, 2, [15, 15, 15, 16]]], + [[np.float16, 2, [15, 15, 15, 16]], [np.float16, 2, [3, 3, 7, 7]]], + [[np.float16, 0, [15, 15, 15, 16]], [np.float16, 0, [15, 15, 15, 16]]], + [[np.float16, 0, [15, 15, 15, 16]], [np.float16, 0, [116, 116, 1, 1]]], + [[np.float32, 2, [15, 15, 15, 16]], [np.float32, 2, [15, 15, 15, 16]]], + [[np.float32, 2, [15, 15, 15, 16]], [np.float32, 2, [3, 3, 7, 7]]], + [[np.float32, 0, [15, 15, 15, 16]], [np.float32, 0, [15, 15, 15, 16]]], + [[np.float32, 0, [15, 15, 15, 16]], [np.float32, 0, [232, 232, 1, 1]]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -2, 2) + cpu_input2, npu_input2 = create_common_tensor(item[0], -2, 2) + cpu_input3, npu_input3 = create_common_tensor(item[1], -2, 2) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2.to(torch.int32) > 0) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2.to(torch.int32) > 0, npu_input3) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_maskedselect_shape_format_maskdiff(self, device): + dtype_list = [np.int64, np.int32, np.float32] + format_list = [0] + shape_list = [[3, 4, 5]] + shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + mask_cpu, mask_npu = create_common_tensor((np.int32, 0, (3, 4, 1)), 0, 100) + cpu_output = self.cpu_op_exec(cpu_input, mask_cpu > 50) + npu_output = self.npu_op_exec(npu_input, mask_npu > 50) + self.assertRtolEqual(cpu_output, npu_output) + + def test_maskedselect_shape_format_fp32(self, device): + format_list = [0, 3] + shape_list = [[3, 4, 5]] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + mask = self.get_mask() + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_exec(cpu_input, mask) + npu_output = self.npu_op_exec(npu_input, mask) + self.assertRtolEqual(cpu_output, npu_output) + + def test_maskedselect_shape_format_int(self, device): + dtype_list = [np.int32, np.int64] + format_list = [0] + shape_list = [[3, 4, 5]] + shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in shape_list + ] + mask = self.get_mask() + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_exec(cpu_input, mask) + npu_output = self.npu_op_exec(npu_input, mask) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestMaskedSelect, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/MaskedSelectKernelNpu.cpp b/torch_npu/csrc/aten/ops/MaskedSelectKernelNpu.cpp new file mode 100644 index 0000000000..6952fb2b12 --- /dev/null +++ b/torch_npu/csrc/aten/ops/MaskedSelectKernelNpu.cpp @@ -0,0 +1,114 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::SmallVector masked_select_npu_output_size( + const at::Tensor& self, + const at::Tensor& mask) { + int64_t shape; + shape = mask.sum().item().toInt(); + return {shape}; +} + +at::Tensor& masked_select_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + const at::Tensor& mask) { + at::Tensor maskBool = mask; + if (!(mask.dtype() == at::kBool)) { + maskBool = NPUNativeFunctions::npu_dtype_cast(mask, at::kBool); + } + + OpCommand cmd; + cmd.Name("MaskedSelect") + .Input(self) + .Input(maskBool) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::masked_select_out( + const at::Tensor& self, + const at::Tensor& mask, + at::Tensor& result) { + at::Tensor dtypeCastOfSelf = self; + at::Tensor maskCast = mask; + if (maskCast.sizes() != dtypeCastOfSelf.sizes()) { + maskCast = NPUNativeFunctions::npu_broadcast(mask, dtypeCastOfSelf.sizes()); + } + if (dtypeCastOfSelf.scalar_type() == at::ScalarType::Half) { + dtypeCastOfSelf = NPUNativeFunctions::npu_dtype_cast(dtypeCastOfSelf, at::ScalarType::Float); + result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Float); + } + auto outputSize = masked_select_npu_output_size(dtypeCastOfSelf, maskCast); + + OpPreparation::CheckOut( + {dtypeCastOfSelf}, + result, + dtypeCastOfSelf, + outputSize); + + OpPipeWithDefinedOut pipe; + result = pipe.CheckMemory({dtypeCastOfSelf, maskCast}, {result}) + .Func([&dtypeCastOfSelf, &maskCast](at::Tensor& result) + {masked_select_out_npu_nocheck(result, dtypeCastOfSelf, maskCast);}) + .Call(result); + + if (result.scalar_type() != self.scalar_type()) { + result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Half); + } + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + masked_select_out_npu_nocheck(contiguousResult, self, mask); + NpuUtils::format_fresh_view(result, contiguousResult); + } else { + masked_select_out_npu_nocheck(result, self, mask); + } + return result; +} + +at::Tensor NPUNativeFunctions::masked_select( + const at::Tensor& self, + const at::Tensor& mask) { + at::Tensor dtypeCastOfSelf = self; + at::Tensor maskCast = mask; + if (maskCast.sizes() != dtypeCastOfSelf.sizes()) { + maskCast = NPUNativeFunctions::npu_broadcast(mask, dtypeCastOfSelf.sizes()); + } + if (dtypeCastOfSelf.scalar_type() == at::ScalarType::Half) { + dtypeCastOfSelf = NPUNativeFunctions::npu_dtype_cast(dtypeCastOfSelf, at::ScalarType::Float); + } + auto outputSize = masked_select_npu_output_size(dtypeCastOfSelf, maskCast); + + at::Tensor result = OpPreparation::ApplyTensor(dtypeCastOfSelf, outputSize); + + masked_select_out_npu_nocheck(result, dtypeCastOfSelf, maskCast); + + if (result.scalar_type() != self.scalar_type()) { + result = NPUNativeFunctions::npu_dtype_cast(result, at::ScalarType::Half); + } + return result; +} + +} // namespace native +} // namespace at_npu -- Gitee