diff --git a/test/test_network_ops/test_hardsigmoid.py b/test/test_network_ops/test_hardsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4b5444503d173f92e5b8ecb9e559f6d8cccf977
--- /dev/null
+++ b/test/test_network_ops/test_hardsigmoid.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+class TestHardsigmoid(TestCase):
+    def generate_single_data(self, min_d, max_d, shape, dtype):
+        input1 = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        npu_input1 = torch.from_numpy(input1)
+        return npu_input1
+
+    def cpu_op_exec(self, input1, input2):
+        output = input2
+        h = torch.nn.Hardsigmoid()
+        output = h(input1)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, input2):
+        input1 = input1.to("npu")
+        output = input2
+        h = torch.nn.Hardsigmoid()
+        output = h(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_hardsigmoid_int32(self, device="npu"):
+        def cpu_op_exec_int32(input1):
+            input1 = input1.to(torch.float32)
+            h = torch.nn.Hardsigmoid()
+            output = h(input1)
+            output = output.numpy()
+            output = output.astype(np.int32)
+            return output
+        npu_input1 = self.generate_single_data(-6, 6, (3,6), np.int32)
+        npu_input2 = self.generate_single_data(-6, 6, (3,6), np.int32)
+        cpu_output = cpu_op_exec_int32(npu_input1)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_hardsigmoid_float32(self, device="npu"):
+        npu_input1 = self.generate_single_data(-6, 6, (9,3), np.float32)
+        npu_input2 = self.generate_single_data(-6, 6, (9,3), np.float32)
+        cpu_output = self.cpu_op_exec(npu_input1, npu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_hardsigmoid_float16(self, device="npu"):
+        def cpu_op_exec_float16(input1):
+            input1 = input1.to(torch.float32)
+            h = torch.nn.Hardsigmoid()
+            output = h(input1)
+            output = output.numpy()
+            output = output.astype(np.float16)
+            return output
+        npu_input1 = self.generate_single_data(-6, 6, (2,7), np.float16)
+        npu_input2 = self.generate_single_data(-6, 6, (2,7), np.float16)
+        cpu_output = cpu_op_exec_float16(npu_input1)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
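Reviewer note: Hardsigmoid is the piecewise-linear approximation hardsigmoid(x) = clamp(x / 6 + 1 / 2, 0, 1). A minimal CPU-only sketch of the value the tests above compare against (the helper name and shape are illustrative, not part of this patch):

```python
import numpy as np
import torch

def hardsigmoid_reference(x):
    # 0 for x <= -3, 1 for x >= 3, linear (x / 6 + 1 / 2) in between.
    return np.clip(x / 6.0 + 0.5, 0.0, 1.0)

x = np.random.uniform(-6, 6, (3, 6)).astype(np.float32)
expected = hardsigmoid_reference(x)
actual = torch.nn.Hardsigmoid()(torch.from_numpy(x)).numpy()
assert np.allclose(expected, actual, rtol=1e-6, atol=1e-6)
```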
diff --git a/test/test_network_ops/test_hardsigmoid_backward.py b/test/test_network_ops/test_hardsigmoid_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd17ba8d857adc127d7c56927866887c816c654
--- /dev/null
+++ b/test/test_network_ops/test_hardsigmoid_backward.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+def cpu_input_grad_hook(grad):
+    global cpu_input_grad
+    cpu_input_grad = grad
+
+def npu_input_grad_hook(grad):
+    global npu_input_grad
+    npu_input_grad = grad.cpu()
+
+class TestHardSigmoidBackward(TestCase):
+    def generate_data(self, min_d, max_d, shape, dtype):
+        input_grad = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        input_x = np.random.uniform(min_d, max_d, shape).astype(dtype)
+        input_grad = torch.from_numpy(input_grad)
+        input_x = torch.from_numpy(input_x)
+        return input_grad, input_x
+
+    def cpu_op_exec(self, input_x, input_grad):
+        input_x.requires_grad_(True)
+        input_x.register_hook(cpu_input_grad_hook)
+        h = torch.nn.Hardsigmoid()
+        output = h(input_x)
+        output.backward(input_grad)
+
+    def npu_op_exec(self, input_x, input_grad):
+        input_x = input_x.to("npu")
+        input_grad = input_grad.to("npu")
+        input_x.requires_grad_(True)
+        input_x.register_hook(npu_input_grad_hook)
+        h = torch.nn.Hardsigmoid()
+        output = h(input_x)
+        output.backward(input_grad)
+
+    def test_hardsigmoidbackward_6_6_float32(self, device="npu"):
+        input_grad, input_x = self.generate_data(-6, 6, (6, 6), np.float32)
+        self.cpu_op_exec(input_x, input_grad)
+        self.npu_op_exec(input_x, input_grad)
+        self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+    def test_hardsigmoidbackward_10_10_float32(self, device="npu"):
+        input_grad, input_x = self.generate_data(-6, 6, (10, 10), np.float32)
+        self.cpu_op_exec(input_x, input_grad)
+        self.npu_op_exec(input_x, input_grad)
+        self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+    def test_hardsigmoidbackward_100_100_float32(self, device="npu"):
+        input_grad, input_x = self.generate_data(-6, 6, (100, 100), np.float32)
+        self.cpu_op_exec(input_x, input_grad)
+        self.npu_op_exec(input_x, input_grad)
+        self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+    def test_hardsigmoidbackward_10_10_10_10_float32(self, device="npu"):
+        input_grad, input_x = self.generate_data(-6, 6, (10, 10, 10, 10), np.float32)
+        self.cpu_op_exec(input_x, input_grad)
+        self.npu_op_exec(input_x, input_grad)
+        self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+    def test_hardsigmoidbackward_6_6_float16(self, device="npu"):
+        input_grad1, input_x1 = self.generate_data(-6, 6, (6, 6), np.float16)
+        input_grad1 = input_grad1.to(torch.float32)
+        input_x1 = input_x1.to(torch.float32)
+        self.cpu_op_exec(input_x1, input_grad1)
+        self.npu_op_exec(input_x1, input_grad1)
+        self.assertRtolEqual(cpu_input_grad.detach().numpy().astype(npu_input_grad.detach().numpy().dtype),
+                             npu_input_grad.detach().numpy())
+
+    def test_hardsigmoidbackward_10_10_float16(self, device="npu"):
+        input_grad1, input_x1 = self.generate_data(-6, 6, (10, 10), np.float16)
+        input_grad1 = input_grad1.to(torch.float32)
+        input_x1 = input_x1.to(torch.float32)
+        self.cpu_op_exec(input_x1, input_grad1)
+        self.npu_op_exec(input_x1, input_grad1)
+        self.assertRtolEqual(cpu_input_grad.detach().numpy().astype(npu_input_grad.detach().numpy().dtype),
+                             npu_input_grad.detach().numpy())
+
+    def test_hardsigmoidbackward_100_100_float16(self, device="npu"):
+        input_grad1, input_x1 = self.generate_data(-6, 6, (100, 100), np.float16)
+        input_grad1 = input_grad1.to(torch.float32)
+        input_x1 = input_x1.to(torch.float32)
+        self.cpu_op_exec(input_x1, input_grad1)
+        self.npu_op_exec(input_x1, input_grad1)
+        self.assertRtolEqual(cpu_input_grad.detach().numpy().astype(npu_input_grad.detach().numpy().dtype),
+                             npu_input_grad.detach().numpy())
+
+    def test_hardsigmoidbackward_10_10_10_10_float16(self, device="npu"):
+        input_grad1, input_x1 = self.generate_data(-6, 6, (10, 10, 10, 10), np.float16)
+        input_grad1 = input_grad1.to(torch.float32)
+        input_x1 = input_x1.to(torch.float32)
+        self.cpu_op_exec(input_x1, input_grad1)
+        self.npu_op_exec(input_x1, input_grad1)
+        self.assertRtolEqual(cpu_input_grad.detach().numpy().astype(npu_input_grad.detach().numpy().dtype),
+                             npu_input_grad.detach().numpy())
+
+if __name__ == '__main__':
+    run_tests()
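Reviewer note: the gradient of Hardsigmoid is 1/6 on (-3, 3) and 0 elsewhere, so the backward tests above can be cross-checked with a one-line NumPy reference (a sketch assuming float32 inputs; helper names and shapes are illustrative, not part of this patch):

```python
import numpy as np
import torch

def hardsigmoid_grad_reference(grad_out, x):
    # d/dx hardsigmoid(x) = 1/6 for -3 < x < 3, 0 otherwise.
    return grad_out * ((x > -3.0) & (x < 3.0)).astype(x.dtype) / 6.0

x = np.random.uniform(-6, 6, (6, 6)).astype(np.float32)
grad_out = np.random.uniform(-6, 6, (6, 6)).astype(np.float32)

t = torch.from_numpy(x).requires_grad_(True)
torch.nn.Hardsigmoid()(t).backward(torch.from_numpy(grad_out))
assert np.allclose(t.grad.numpy(), hardsigmoid_grad_reference(grad_out, x))
```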
diff --git a/test/test_network_ops/test_logsigmoid.py b/test/test_network_ops/test_logsigmoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e9c19ffa51ef18948950cbb24c857631adb6195
--- /dev/null
+++ b/test/test_network_ops/test_logsigmoid.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestLogsigmoid(TestCase):
+
+    def cpu_op_exec(self, input1):
+        output = torch.nn.functional.logsigmoid(input1)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1):
+        output = torch.nn.functional.logsigmoid(input1)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_logsigmoid_shape_format(self, device="npu"):
+        shape_format = [
+            [[np.float32, 0, (6, 4)]],
+            [[np.float32, 3, (2, 4, 5)]],
+            [[np.float32, 4, (1, 2, 3, 3)]],
+            [[np.float32, 29, (1, 2, 3, 3)]]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_logsigmoid_shape_format_float16(self, device="npu"):
+        shape_format1 = [
+            [[np.float16, 0, (6, 4)]],
+            [[np.float16, 3, (2, 4, 5)]],
+            [[np.float16, 4, (1, 2, 3, 3)]],
+            [[np.float16, 29, (1, 2, 3, 3)]]
+        ]
+        for item in shape_format1:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            cpu_output1 = self.cpu_op_exec(cpu_input1)
+            npu_output1 = self.npu_op_exec(npu_input1)
+            cpu_output1 = cpu_output1.astype(npu_output1.dtype)
+            self.assertRtolEqual(cpu_output1, npu_output1)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
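Reviewer note: logsigmoid(x) = log(1 / (1 + exp(-x))), usually evaluated in the numerically stable form min(x, 0) - log1p(exp(-|x|)). A small CPU-only sketch of the expected forward values (shapes and helper name are illustrative, not part of this patch):

```python
import numpy as np
import torch

def logsigmoid_reference(x):
    # Stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|)).
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

x = np.random.uniform(1, 100, (6, 4)).astype(np.float32)
expected = logsigmoid_reference(x)
actual = torch.nn.functional.logsigmoid(torch.from_numpy(x)).numpy()
assert np.allclose(expected, actual, rtol=1e-5, atol=1e-6)
```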
diff --git a/test/test_network_ops/test_logsigmoid_backward.py b/test/test_network_ops/test_logsigmoid_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..0419ed4732e691532da42ff5b0d8c3e6400af8ab
--- /dev/null
+++ b/test/test_network_ops/test_logsigmoid_backward.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+cpu_input_grad = None
+npu_input_grad = None
+
+def cpu_input_grad_hook(grad):
+    global cpu_input_grad
+    cpu_input_grad = grad.numpy()
+
+def cpu_float16_input_grad_hook(grad):
+    global cpu_input_grad
+    cpu_input_grad = grad.numpy()
+    cpu_input_grad = cpu_input_grad.astype(np.float16)
+
+def npu_input_grad_hook(grad):
+    global npu_input_grad
+    npu_input_grad = grad.cpu().numpy()
+
+class TestLogSigmoidBackward(TestCase):
+    def cpu_op_exec(self, input1):
+        input1.requires_grad = True
+        input1.register_hook(cpu_input_grad_hook)
+        output = torch.nn.functional.logsigmoid(input1)
+        z = output.sum()
+        z.backward()
+
+    def npu_op_exec(self, input1):
+        input1.requires_grad = True
+        input1.register_hook(npu_input_grad_hook)
+        output = torch.nn.functional.logsigmoid(input1)
+        z = output.sum()
+        z.backward()
+
+    def test_log_sigmoid_backward_shape_format(self, device="npu"):
+        shape_format = [
+            [[np.float32, 0, (6, 4)]],
+            [[np.float32, 3, (2, 4, 5)]],
+            [[np.float32, 4, (1, 2, 3, 3)]],
+            [[np.float32, 29, (10, 3, 5, 3)]]
+        ]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], -50, 50)
+            self.cpu_op_exec(cpu_input)
+            self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+    def test_log_sigmoid_backward_float16_shape_format(self, device="npu"):
+        def cpu_op_exec_fp16(input1):
+            input1.requires_grad = True
+            input1.register_hook(cpu_float16_input_grad_hook)
+            input1 = input1.to(torch.float32)
+            output = torch.nn.functional.logsigmoid(input1)
+            z = output.sum()
+            z.backward()
+
+        shape_format = [
+            [[np.float16, 0, (6, 4)]],
+            [[np.float16, 3, (2, 4, 5)]],
+            [[np.float16, 4, (1, 2, 3, 3)]],
+            [[np.float16, 29, (10, 3, 5, 3)]],
+        ]
+
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], -50, 50)
+            cpu_op_exec_fp16(cpu_input1)
+            self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(cpu_input_grad, npu_input_grad)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
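Reviewer note: d/dx logsigmoid(x) = sigmoid(-x) = 1 - sigmoid(x). Because the tests above reduce the output with sum(), the upstream gradient is all ones and the expected input gradient is simply sigmoid(-x). A CPU-only sketch (shapes are illustrative, not part of this patch):

```python
import numpy as np
import torch

x = np.random.uniform(-50, 50, (6, 4)).astype(np.float32)

t = torch.from_numpy(x).requires_grad_(True)
torch.nn.functional.logsigmoid(t).sum().backward()

# d/dx log(sigmoid(x)) = sigmoid(-x) = 1 / (1 + exp(x)).
expected = 1.0 / (1.0 + np.exp(x))
assert np.allclose(t.grad.numpy(), expected, rtol=1e-5, atol=1e-6)
```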
diff --git a/test/test_network_ops/test_logsigmoidforward.py b/test/test_network_ops/test_logsigmoidforward.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f95bd5ac2954c87616e7e2b501d375811f17621
--- /dev/null
+++ b/test/test_network_ops/test_logsigmoidforward.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+class TestLogsigmoidForward(TestCase):
+
+    def cpu_op_exec(self, input1):
+        m = torch.nn.LogSigmoid()
+        output = m.forward(input1)
+        return output.numpy()
+
+    def npu_op_exec(self, input1):
+        m = torch.nn.LogSigmoid().to("npu")
+        output = m.forward(input1)
+        output = output.to("cpu")
+        return output.numpy()
+
+    def test_logsigmoid_forward_shape_format(self, device="npu"):
+        shape_format1 = [
+            [[np.float32, 0, (6, 4)]],
+            [[np.float32, 3, (2, 4, 5)]],
+            [[np.float32, 4, (1, 2, 3, 3)]],
+            [[np.float32, 29, (1, 2, 3, 3)]]
+        ]
+        for item in shape_format1:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100)
+            cpu_output1 = self.cpu_op_exec(cpu_input1)
+            npu_output1 = self.npu_op_exec(npu_input1)
+            self.assertRtolEqual(cpu_output1, npu_output1)
+
+    def test_logsigmoid_forward_fp16_shape_format(self, device="npu"):
+        shape_format = [
+            [[np.float16, 0, (6, 4)]],
+            [[np.float16, 3, (2, 4, 5)]],
+            [[np.float16, 4, (1, 2, 3, 3)]],
+            [[np.float16, 29, (1, 2, 3, 3)]]
+        ]
+        def cpu_op_fp16_exec(input1):
+            input1 = input1.to(torch.float32)
+            m = torch.nn.LogSigmoid()
+            output = m.forward(input1)
+            return output.numpy().astype(np.float16)
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 1, 100)
+            cpu_output = cpu_op_fp16_exec(cpu_input)
+            npu_output = self.npu_op_exec(npu_input)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/HardsigmoidBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/HardsigmoidBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ae79c90a7c5d7e9b5939790c77b8a3f233ce2344
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/HardsigmoidBackwardKernelNpu.cpp
@@ -0,0 +1,48 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+namespace {
+
+at::Tensor& hardsigmoid_backward_nocheck(
+    at::Tensor& grad_input,
+    const at::Tensor& grad_output,
+    const at::Tensor& self) {
+  OpCommand cmd;
+  cmd.Name("HardSigmoidGrad")
+      .Input(grad_output)
+      .Input(self)
+      .Output(grad_input)
+      .Run();
+
+  return grad_input;
+}
+} // namespace
+
+at::Tensor NPUNativeFunctions::hardsigmoid_backward(
+    const at::Tensor& grad_output,
+    const at::Tensor& self) {
+  at::Tensor grad_input = OpPreparation::ApplyTensor(grad_output);
+  // calculate the output result of the NPU
+  hardsigmoid_backward_nocheck(grad_input, grad_output, self);
+
+  return grad_input;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/HardsigmoidKernelNpu.cpp b/torch_npu/csrc/aten/ops/HardsigmoidKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..da0c09cb23085773bbc0b77098e3eba8881b33c4
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/HardsigmoidKernelNpu.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& hardsigmoid_out_nocheck(
+    const at::Tensor& self,
+    at::Tensor& result) {
+  OpCommand cmd;
+  cmd.Name("HardSigmoid")
+      .Input(self)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::hardsigmoid_out(
+    const at::Tensor& self,
+    at::Tensor& result) {
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      self);
+  if (!NpuUtils::check_match(&result)) {
+    at::Tensor contiguousResult = NpuUtils::format_contiguous(result);
+    at::Tensor checkResult = hardsigmoid_out_nocheck(self, contiguousResult);
+    NpuUtils::format_fresh_view(result, checkResult);
+  } else {
+    hardsigmoid_out_nocheck(self, result);
+  }
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::hardsigmoid(const at::Tensor& self) {
+  at::Tensor result = OpPreparation::ApplyTensor(self);
+  // calculate the output result of the NPU
+  hardsigmoid_out_nocheck(self, result);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::hardsigmoid_(at::Tensor& self) {
+  OpPreparation::CheckMemory({self}, {self});
+  NPUNativeFunctions::hardsigmoid_out(self, self);
+  return self;
+}
+
+} // namespace native
+} // namespace at_npu
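Reviewer note: the two files above cover the out-of-place, out=, in-place, and backward entry points for hardsigmoid on the NPU. A quick way to exercise the forward and backward kernels together from Python, assuming an NPU device is visible (device handling, shapes, and tolerances are illustrative, not part of this patch):

```python
import torch
import torch_npu  # noqa: F401  (registers the "npu" device and these kernels)
import torch.nn.functional as F

x_cpu = torch.randn(4, 5, requires_grad=True)
x_npu = x_cpu.detach().to("npu").requires_grad_(True)

y_cpu = F.hardsigmoid(x_cpu)   # CPU reference
y_npu = F.hardsigmoid(x_npu)   # dispatches to the HardSigmoid NPU kernel

y_cpu.sum().backward()
y_npu.sum().backward()         # dispatches to HardSigmoidGrad

print(torch.allclose(y_cpu, y_npu.cpu(), atol=1e-6))
print(torch.allclose(x_cpu.grad, x_npu.grad.cpu(), atol=1e-6))
```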
diff --git a/torch_npu/csrc/aten/ops/LogSigmoidBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSigmoidBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8f6ab4c337d9226c8071bfd495db057bab5106c4
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LogSigmoidBackwardKernelNpu.cpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& NPUNativeFunctions::log_sigmoid_backward_out(
+    const at::Tensor& grad_output,
+    const at::Tensor& self,
+    const at::Tensor& buffer,
+    at::Tensor& grad_input) {
+  OpPreparation::CheckMemory({grad_output, self, buffer}, {grad_input});
+  OpCommand cmd;
+  cmd.Name("LogSigmoidGrad")
+      .Input(grad_output)
+      .Input(self)
+      .Output(grad_input)
+      .Run();
+  return grad_input;
+}
+
+at::Tensor NPUNativeFunctions::log_sigmoid_backward(
+    const at::Tensor& grad_output,
+    const at::Tensor& self,
+    const at::Tensor& buffer) {
+  // construct the output tensor of the NPU
+  at::Tensor grad_input = OpPreparation::ApplyTensor(grad_output);
+  // calculate the output result of the NPU
+  log_sigmoid_backward_out(grad_output, self, buffer, grad_input);
+
+  return grad_input;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/LogSigmoidKernelNpu.cpp b/torch_npu/csrc/aten/ops/LogSigmoidKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0deb27b73d86bce5808f085bd740c83b152af327
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/LogSigmoidKernelNpu.cpp
@@ -0,0 +1,51 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor&, at::Tensor&> NPUNativeFunctions::log_sigmoid_forward_out(
+    const at::Tensor& self,
+    at::Tensor& output,
+    at::Tensor& buffer) {
+  OpCommand cmd;
+  cmd.Name("LogSigmoid")
+      .Input(self)
+      .Output(output)
+      .Run();
+  return std::tie(output, buffer);
+}
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::log_sigmoid_forward(const at::Tensor& self) {
+  at::Tensor output = OpPreparation::ApplyTensor(self);
+  at::Tensor buffer = OpPreparation::ApplyTensorWithSizes({0}, self.options());
+  // calculate the output result of the NPU
+  log_sigmoid_forward_out(self, output, buffer);
+  return std::tuple<at::Tensor, at::Tensor>(output, buffer);
+}
+
+at::Tensor& NPUNativeFunctions::log_sigmoid_out(const at::Tensor& self, at::Tensor& result) {
+  at::Tensor buffer = OpPreparation::ApplyTensorWithSizes({0}, self.options());
+  return std::get<0>(at::log_sigmoid_forward_out(result, buffer, self));
+}
+
+at::Tensor NPUNativeFunctions::log_sigmoid(const at::Tensor& self) {
+  return std::get<0>(at::log_sigmoid_forward(self));
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
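Reviewer note: unlike the CPU implementation, log_sigmoid_forward above returns an empty placeholder buffer and LogSigmoidGrad recomputes what it needs from self, so the buffer is never read on the NPU path. A hedged end-to-end check that gradients still agree with the CPU reference (assumes an NPU device is available; shapes and tolerance are illustrative):

```python
import torch
import torch_npu  # noqa: F401

x_cpu = torch.randn(6, 4, requires_grad=True)
x_npu = x_cpu.detach().to("npu").requires_grad_(True)

# Forward uses the LogSigmoid kernel; backward uses LogSigmoidGrad.
torch.nn.functional.logsigmoid(x_cpu).sum().backward()
torch.nn.functional.logsigmoid(x_npu).sum().backward()

print(torch.allclose(x_cpu.grad, x_npu.grad.cpu(), atol=1e-5))
```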