diff --git a/test/test_network_ops/test_cos.py b/test/test_network_ops/test_cos.py new file mode 100644 index 0000000000000000000000000000000000000000..e3b45fc1e793c654e92c4d9859fe844bb51c3fbd --- /dev/null +++ b/test/test_network_ops/test_cos.py @@ -0,0 +1,87 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestCos(TestCase): + + def cpu_op_exec(self, input1): + output = torch.cos(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.cos(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self, input1, input2): + torch.cos(input1, out=input2) + output = input2.to("cpu") + output = output.numpy() + return output + + def cpu_inp_op_exec(self, input1): + output = torch.cos_(input1) + output = output.numpy() + return output + + def npu_inp_op_exec(self, input1): + output = torch.cos_(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def test_cos_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (5,3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_cos_out_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (4,3)], [np.float32, 0, (4,3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10) + cpu_input2, npu_input2 = create_common_tensor(item[1], -10, 10) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_cos_common_shape_format(self, device): + shape_format = [ + [[np.float32, 0, (5,3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10) + cpu_output = self.cpu_inp_op_exec(cpu_input1) + npu_output = self.npu_inp_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestCos, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/CosKernelNpu.cpp b/torch_npu/csrc/aten/ops/CosKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..79611211c61b106b4569aa60ecdc97e56cc5eafd --- /dev/null +++ b/torch_npu/csrc/aten/ops/CosKernelNpu.cpp @@ -0,0 +1,60 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& cos_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self) { + OpCommand cmd; + cmd.Name("Cos") + .Input(self) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::cos_out( + const at::Tensor& self, + at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + at::Tensor checkResult = cos_out_npu_nocheck(contiguousResult, self); + NpuUtils::format_fresh_view(result, checkResult); + } else { + cos_out_npu_nocheck(result, self); + } + return result; +} + +at::Tensor NPUNativeFunctions::cos(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + cos_out_npu_nocheck(result, self); + return result; +} + +at::Tensor& NPUNativeFunctions::cos_(at::Tensor& self) { + return cos_out(self, self); +} +} // namespace native +} // namespace at_npu \ No newline at end of file