From ac93568aaf4de20905f6d46e7ecf127fc84e5bfd Mon Sep 17 00:00:00 2001 From: shenpengcheng Date: Sat, 5 Mar 2022 16:33:13 +0800 Subject: [PATCH] =?UTF-8?q?bernoulli=E7=AE=97=E5=AD=90=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit code fix1 --- test/test_network_ops/test_bernoulli.py | 173 ++++++++++++++++++ .../csrc/aten/ops/BernoulliKernelNpu.cpp | 121 ++++++++++++ 2 files changed, 294 insertions(+) create mode 100644 test/test_network_ops/test_bernoulli.py create mode 100644 torch_npu/csrc/aten/ops/BernoulliKernelNpu.cpp diff --git a/test/test_network_ops/test_bernoulli.py b/test/test_network_ops/test_bernoulli.py new file mode 100644 index 0000000000..1f2ed2573b --- /dev/null +++ b/test/test_network_ops/test_bernoulli.py @@ -0,0 +1,173 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestBernoulli(TestCase): + def cpu_op_exec(self, input1): + output = torch.bernoulli(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.bernoulli(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_inplace_tensor_exec(self, input1, p): + output = input1.bernoulli_(p) + output = output.numpy() + return output + + def npu_op_inplace_tensor_exec(self, input1, p): + output = input1.bernoulli_(p) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_inplace_float_exec(self, input1): + output = input1.bernoulli_(0.5) + output = output.numpy() + return output + + def npu_op_inplace_float_exec(self, input1): + output = input1.bernoulli_(0.5) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_out_exec(self, input1, output1): + torch.bernoulli(input1, out=output1) + output1 = output1.numpy() + return output1 + + def npu_op_out_exec(self, input1, output1): + torch.bernoulli(input1, out=output1) + output1 = output1.to("cpu") + output1 = output1.numpy() + return output1 + + def cpu_op_p_exec(self, input1, p): + input1 = input1.bernoulli_(p) + input1 = input1.numpy() + return input1 + + def npu_op_p_exec(self, input1, p): + input1 = input1.bernoulli_(p) + input1 = input1.to("cpu") + input1 = input1.numpy() + return input1 + + def test_bernoulli_p(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + p_list = [0.2, 0.6] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list + for k in p_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_output = self.cpu_op_p_exec(cpu_input, item[3]) + npu_output = self.npu_op_p_exec(npu_input, item[3]) + print(cpu_output, npu_output) + + + def test_bernoulli_out_float32(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_output1, npu_output1 = create_common_tensor(item, 0, 1) + cpu_output = self.cpu_op_out_exec(cpu_input, cpu_output1) + npu_output = self.npu_op_out_exec(npu_input, npu_output1) + print(cpu_output, npu_output) + + def test_bernoulli_out_float16(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_output1, npu_output1 = create_common_tensor(item, 0, 1) + cpu_input = cpu_input.to(torch.float32) + cpu_output1 = cpu_output1.to(torch.float32) + cpu_output = self.cpu_op_out_exec(cpu_input, cpu_output1) + npu_output = self.npu_op_out_exec(npu_input, npu_output1) + cpu_output = cpu_output.astype(np.float16) + print(cpu_output, npu_output) + + def test_bernoulli_float32(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + print(cpu_output, npu_output) + + def test_bernoulli_float16(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + cpu_output = cpu_output.astype(np.float16) + print(cpu_output, npu_output) + + def test_bernoulli_tensor_p(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_input_p, npu_input_p = create_common_tensor(item, 0, 1) + cpu_output = self.cpu_op_inplace_tensor_exec(cpu_input, cpu_input_p) + npu_output = self.npu_op_inplace_tensor_exec(npu_input, npu_input_p) + print(cpu_output, npu_output) + + def test_bernoulli_float_p(self): + format_list = [0, 3] + shape_list = [(2, 3, 4)] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 1) + cpu_output = self.cpu_op_inplace_float_exec(cpu_input) + npu_output = self.npu_op_inplace_float_exec(npu_input) + print(cpu_output, npu_output) + +if __name__ == '__main__': + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/BernoulliKernelNpu.cpp b/torch_npu/csrc/aten/ops/BernoulliKernelNpu.cpp new file mode 100644 index 0000000000..49470de998 --- /dev/null +++ b/torch_npu/csrc/aten/ops/BernoulliKernelNpu.cpp @@ -0,0 +1,121 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/core/npu/SecondaryStreamGuard.h" +#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& bernoulli_npu_nocheck(at::Tensor& result, const at::Tensor& self, double p) { + auto original_stream = c10::npu::getCurrentNPUStream(); + { + torch_npu::SecondaryStreamGuard guard(c10::npu::getCurrentSecondaryStream()); + OpCommand cmd; + cmd.Name("Bernoulli") + .Input(self) + .Input(p, at::ScalarType::Float) + .Output(result) + .Run(); + } + c10_npu::NPUCachingAllocator::recordStream(self.storage().data_ptr(), original_stream); + + return result; +} + +at::Tensor& bernoulli_npu_nocheck(at::Tensor& result, const at::Tensor& self, const at::Tensor& p) { + auto original_stream = c10::npu::getCurrentNPUStream(); + { + torch_npu::SecondaryStreamGuard guard(c10::npu::getCurrentSecondaryStream()); + OpCommand cmd; + cmd.Name("Bernoulli") + .Input(self) + .Input(p) + .Output(result) + .Run(); + } + c10_npu::NPUCachingAllocator::recordStream(self.storage().data_ptr(), original_stream); + c10_npu::NPUCachingAllocator::recordStream(p.storage().data_ptr(), original_stream); + + return result; +} + +at::Tensor& NPUNativeFunctions::bernoulli_(at::Tensor& self, double p, c10::optional gen) { + at::ScalarType selfType = self.scalar_type(); + at::Tensor selfFp32 = self; + if (self.scalar_type() == at::ScalarType::Half) { + selfFp32 = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float); + } + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(selfFp32); + at::Tensor result = bernoulli_npu_nocheck(contiguousSelf, contiguousSelf, p); + NpuUtils::format_fresh_view(self, result); + } else { + bernoulli_npu_nocheck(selfFp32, selfFp32, p); + self.copy_(selfFp32); + } + + if(self.scalar_type() != selfType){ + self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Half); + } + return self; +} + +at::Tensor& NPUNativeFunctions::bernoulli_(at::Tensor& self, const at::Tensor& p, c10::optional gen) { + at::ScalarType selfType = self.scalar_type(); + at::Tensor selfFp32 = self; + at::Tensor pFp32 = OpPreparation::CastBackToOriFormat(p); + if (self.scalar_type() == at::ScalarType::Half) { + selfFp32 = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float); + pFp32 = NPUNativeFunctions::npu_dtype_cast(p, at::ScalarType::Float); + } + + if (!NpuUtils::check_match(&self)) { + at::Tensor contiguousSelf = NpuUtils::format_contiguous(selfFp32); + at::Tensor result = bernoulli_npu_nocheck(contiguousSelf, contiguousSelf, pFp32); + NpuUtils::format_fresh_view(self, result); + } else { + bernoulli_npu_nocheck(selfFp32, selfFp32, pFp32); + self.copy_(selfFp32); + } + + if(self.scalar_type() != selfType){ + self = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Half); + } + return self; +} + +at::Tensor NPUNativeFunctions::bernoulli(const at::Tensor& self, c10::optional gen) { + at::Tensor selfCopy = OpPreparation::ApplyTensorWithFormat(self.sizes(), self.options(), ACL_FORMAT_ND); + selfCopy.copy_(self); + return NPUNativeFunctions::bernoulli_(selfCopy, self, gen); +} + +at::Tensor NPUNativeFunctions::bernoulli(const at::Tensor& self, double p, c10::optional gen) { + return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT).bernoulli_(p, gen); +} + +at::Tensor& NPUNativeFunctions::bernoulli_out(const at::Tensor& self, c10::optional gen, at::Tensor& result) { + result.resize_(self.sizes()).bernoulli_(self, gen); + at::namedinference::propagate_names(result, self); + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file -- Gitee