From 22fbdd131c315a9374deb5bb27cda4a31f015e3d Mon Sep 17 00:00:00 2001 From: xiao-daipeng Date: Fri, 10 Jun 2022 16:44:22 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E7=AE=97=E5=AD=90SignBitsPac?= =?UTF-8?q?k=E9=80=82=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/test_network_ops/test_sign_bits_pack.py | 55 +++++++++++++++++++ torch_npu/csrc/aten/npu_native_functions.yaml | 1 + .../csrc/aten/ops/SignBitsPackKernelNpu.cpp | 53 ++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 test/test_network_ops/test_sign_bits_pack.py create mode 100644 torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp diff --git a/test/test_network_ops/test_sign_bits_pack.py b/test/test_network_ops/test_sign_bits_pack.py new file mode 100644 index 0000000000..b3acdbd849 --- /dev/null +++ b/test/test_network_ops/test_sign_bits_pack.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestLess(TestCase): + def cpu_op_exec(self, input1, size): + sign_data = np.sign(input1) + sign_data = sign_data + 1 + bool_data = np.bool_(sign_data) + pack_bit = np.packbits(bool_data, bitorder="little") + return pack_bit.reshape(size, pack_bit.shape[0]//size) + + def npu_op_exec(self,input1,size): + output = torch_npu.npu_sign_bits_pack(input1, size) + output = output.to("cpu") + output = output.numpy() + return output + + def test_sign_bits_pack(self): + shape_format = [ + [[np.float16, (16,)], 2], + [[np.float32, (8,)], 1], + [[np.float32, (32,)], 1], + [[np.float32, (32,)], 2], + [[np.float32, (16,)], 2] + ] + + for item in shape_format: + input1 = np.random.uniform(-10, 10, item[0][1]).astype(item[0][0]) + npu_input1 = torch.from_numpy(input1).npu() + cpu_output = self.cpu_op_exec(input1, item[1]) + npu_output = self.npu_op_exec(npu_input1, item[1]) + + self.assertRtolEqual(cpu_output.astype(item[0][0]), npu_output.astype(item[0][0])) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 8a7eb6ec1a..fd1f4d48f4 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1883,6 +1883,7 @@ custom: - func: fast_gelu_backward(Tensor grad, Tensor self) -> Tensor variants: function, method - func: _amp_foreach_non_finite_check_(Tensor[] scaled_grads) -> bool + - func: npu_sign_bits_pack(Tensor self, int size) -> Tensor - func: npu_bert_apply_adam(Scalar lr, Scalar beta1, Scalar beta2, Scalar epsilon, Tensor grad, Scalar max_grad_norm, Scalar global_grad_norm, Scalar weight_decay, Scalar? step_size=None, int adam_mode=0) -> (Tensor var, Tensor m, Tensor v) - func: npu_bert_apply_adam.out(Scalar lr, Scalar beta1, Scalar beta2, Scalar epsilon, Tensor grad, Scalar max_grad_norm, Scalar global_grad_norm, Scalar weight_decay, Scalar? step_size=None, int adam_mode=0, *, Tensor(a!) var, Tensor(b!) m, Tensor(c!) v) -> (Tensor(a!), Tensor(b!), Tensor(c!)) - func: npu_conv_transpose2d_backward(Tensor input, Tensor grad_output, Tensor weight, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) diff --git a/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp b/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp new file mode 100644 index 0000000000..a79b93797c --- /dev/null +++ b/torch_npu/csrc/aten/ops/SignBitsPackKernelNpu.cpp @@ -0,0 +1,53 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& sign_bits_pack_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + int64_t size) { + OpCommand cmd; + cmd.Name("SignBitsPack") + .Input(self) + .Output(result) + .Attr("size", size) + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::npu_sign_bits_pack(const at::Tensor& self, int64_t size) { + TORCH_CHECK(self.dim() == 1 && self.numel()%(size*8) == 0, + "input must be one-dimensional and size of input can be divisible by size*8"); + TORCH_CHECK(self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::Float, + "all only supports torch.float16 and torch.float32 dtypes"); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor({size, self.numel()/(size*8)}, self.options().dtype(at::kByte), self); + + // calculate the output result of the NPU + sign_bits_pack_npu_nocheck(result, self, size); + + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee