From 2de26e313a815c6834403f507eed8b275bc348a3 Mon Sep 17 00:00:00 2001 From: hxf12345677 Date: Wed, 16 Feb 2022 09:16:40 +0800 Subject: [PATCH] =?UTF-8?q?npu=5Fone=5Fhot1.8.1=E7=AE=97=E5=AD=90=E7=A7=BB?= =?UTF-8?q?=E6=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_custom_ops/test_npu_one_hot.py | 52 ++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 2 + torch_npu/csrc/aten/ops/OnehotNpu.cpp | 71 +++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 test/test_custom_ops/test_npu_one_hot.py create mode 100644 torch_npu/csrc/aten/ops/OnehotNpu.cpp diff --git a/test/test_custom_ops/test_npu_one_hot.py b/test/test_custom_ops/test_npu_one_hot.py new file mode 100644 index 00000000000..d7aaaf597e3 --- /dev/null +++ b/test/test_custom_ops/test_npu_one_hot.py @@ -0,0 +1,52 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + + +class TestNpuOneHot(TestCase): + + def create_target_lable(self, num_classes, size): + label = torch.randint(0, num_classes, size) + return label + + def cpu_op_exec(self, input1, num_classes, on_value=1, off_value=0): + output = torch.nn.functional.one_hot(input1, num_classes=num_classes).float() + output[output == 1] = on_value + output[output == 0] = off_value + output = output.numpy() + return output + + def npu_op_exec(self, input1, num_classes, on_value=1, off_value=0): + output = torch_npu.npu_one_hot(input1, -1, num_classes, on_value, off_value) + output = output.cpu().numpy() + return output + + def test_one_hot_1(self, device): + target = self.create_target_lable(10, (64, )) + cpu_output = self.cpu_op_exec(target, 10, 0.9, 0.1) + npu_output = self.npu_op_exec(target.npu(), 10, 0.9, 0.1) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestNpuOneHot, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index fac8c89f70e..f1237894453 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1919,6 +1919,8 @@ custom: - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!) - func: npu_confusion_transpose_backward(Tensor grad, int[] perm, int[] shape, bool transpose_first) -> Tensor + - func: npu_one_hot(Tensor self, int num_classes=-1, int depth=1, Scalar on_value=1, Scalar off_value=0) -> Tensor + variants: function, method custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/aten/ops/OnehotNpu.cpp b/torch_npu/csrc/aten/ops/OnehotNpu.cpp new file mode 100644 index 00000000000..88e1df181a5 --- /dev/null +++ b/torch_npu/csrc/aten/ops/OnehotNpu.cpp @@ -0,0 +1,71 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& one_hot_out_npu( + at::Tensor& result, + const at::Tensor& self, + int64_t axis, + int64_t depth, + at::Scalar on_value, + at::Scalar off_value) { + at::Tensor selfCp = NPUNativeFunctions::npu_dtype_cast(self, at::kInt); + at::Tensor on_tmp = OpPreparation::ApplyTensor( + {1}, + selfCp.options().dtype(at::ScalarType::Float), + selfCp) + .fill_(on_value); + at::Tensor off_tmp = OpPreparation::ApplyTensor( + {1}, + selfCp.options().dtype(at::ScalarType::Float), + selfCp) + .fill_(off_value); + OpCommand cmd; + cmd.Name("OneHotD") + .Input(selfCp) + .Input(on_tmp) + .Input(off_tmp) + .Output(result) + .Attr("axis", axis) + .Attr("depth", depth) + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::npu_one_hot( + const at::Tensor& self, + int64_t axis, + int64_t depth, + at::Scalar on_value, + at::Scalar off_value) { + auto outputSize = array_to_small_vector(self.sizes()); + outputSize.emplace_back(depth); + + at::Tensor result = OpPreparation::ApplyTensor( + outputSize, + self.options().dtype(at::ScalarType::Float), + self); + one_hot_out_npu(result, self, axis, depth, on_value, off_value); + + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee