diff --git a/test/test_network_ops/test_erf.py b/test/test_network_ops/test_erf.py new file mode 100644 index 0000000000000000000000000000000000000000..8501eebf566c56a0a8ed9a70ed098d431a888d93 --- /dev/null +++ b/test/test_network_ops/test_erf.py @@ -0,0 +1,104 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# coding: utf-8 + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestErf(TestCase): + + def cpu_op_exec(self,input1): + output = torch.erf(input1) + output = output.numpy() + return output + + def npu_op_exec(self,input1): + output = torch.erf(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_(self,input1): + torch.erf_(input1) + output = input1.numpy() + return output + + def npu_op_exec_(self,input1): + torch.erf_(input1) + output = input1.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_out(self,input1,cpu_out): + torch.erf(input1, out = cpu_out) + output = cpu_out.numpy() + return output + + def npu_op_exec_out(self,input1,npu_out): + torch.erf(input1, out = npu_out) + output = npu_out.to("cpu") + output = output.numpy() + return output + + def test_erf_float32_common_shape_format(self, device="npu"): + shape_format = [ + [np.float32, 0 , (4, 3)], + [np.float32, -1, (2,4, 3)], + [np.float32, 3, (20, 13)], + [np.float32, 4, (20, 13)], + [np.float32, 29, (20, 13)] + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_erf_float321_common_shape_format(self, device="npu"): + shape_format = [ + [np.float32, 0 , (4, 3)], + [np.float32, -1, (2,4, 3)], + [np.float32, 3, (20, 13)], + [np.float32, 4, (20, 13)], + [np.float32, 29, (20, 13)] + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_exec_(cpu_input1) + npu_output = self.npu_op_exec_(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_erf_out_float32_common_shape_format(self, device="npu"): + shape_format = [ + [np.float32, 0 , (4, 3)], + [np.float32, -1, (2,4, 3)], + [np.float32, 3, (20, 13)], + [np.float32, 4, (20, 13)], + [np.float32, 29, (20, 13)] + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 1, 100) + cpu_out, npu_out = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_exec_out(cpu_input1, cpu_out) + npu_output = self.npu_op_exec_out(npu_input1, npu_out) + self.assertRtolEqual(cpu_output, npu_output) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_ger.py b/test/test_network_ops/test_ger.py new file mode 100644 index 0000000000000000000000000000000000000000..ae3a67d12694b5592cb52dd72d9991c56fa170f6 --- /dev/null +++ b/test/test_network_ops/test_ger.py @@ -0,0 +1,99 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + + +class TestGer(TestCase): + def cpu_op_exec(self,input1, input2): + output = torch.ger(input1, input2) + output = output.numpy() + + return output + + def npu_op_exec(self,input1, input2): + output = torch.ger(input1, input2) + output = output.to("cpu").numpy() + + return output + + def npu_op_exec_out(self,input1, input2, output): + torch.ger(input1, input2, out=output) + output = output.to("cpu").numpy() + + return output + + def ger_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], -100, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + if cpu_input2.dtype == torch.float16: + cpu_input2 = cpu_input2.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def ger_out_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -100, 100) + cpu_input2, npu_input2 = create_common_tensor(item[1], -100, 100) + cpu_input3, npu_input3 = create_common_tensor(item[2], -100, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + if cpu_input2.dtype == torch.float16: + cpu_input2 = cpu_input2.to(torch.float32) + if cpu_input3.dtype == torch.float16: + cpu_input3 = cpu_input3.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3) + cpu_output = cpu_output.astype(npu_output_out.dtype) + self.assertRtolEqual(cpu_output, npu_output_out) + + def test_ger_result(self, device="npu"): + shape_format = [ + [[np.float16, 0, [128]], [np.float16, 0, [256]]], + [[np.float16, 0, [128]], [np.float16, 0, [58]]], + [[np.float16, 0, [128]], [np.float16, 0, [3]]], + [[np.float16, 0, [128]], [np.float16, 0, [116]]], + [[np.float32, 0, [256]], [np.float32, 0, [128]]], + [[np.float32, 0, [256]], [np.float32, 0, [3]]], + [[np.float32, 0, [2]], [np.float32, 0, [3]]], + [[np.float32, 0, [128]], [np.float32, 0, [232]]], + ] + self.ger_result(shape_format) + + def test_ger_out_result(self, device="npu"): + shape_format = [ + [[np.float16, 0, [128]], [np.float16, 0, [256]], [np.float16, 0, [256, 116]]], + [[np.float16, 0, [128]], [np.float16, 0, [58]], [np.float16, 0, [58, 58, 1, 1]]], + [[np.float16, 0, [128]], [np.float16, 0, [3]], [np.float16, 0, [3, 3]]], + [[np.float16, 0, [128]], [np.float16, 0, [116]], [np.float16, 0, [128, 116]]], + [[np.float32, 0, [256]], [np.float32, 0, [128]], [np.float32, 0, [128, 128, 3, 3]]], + [[np.float32, 0, [256]], [np.float32, 0, [3]], [np.float32, 0, [256, 3]]], + [[np.float32, 0, [2]], [np.float32, 0, [3]], [np.float32, 0, [3, 1, 3, 3]]], + [[np.float32, 0, [128]], [np.float32, 0, [232]], [np.float32, 0, [232, 232]]], + ] + self.ger_out_result(shape_format) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_grid_sampler_3d.py b/test/test_network_ops/test_grid_sampler_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f64df15f0967cb7be0e1649f745b542b11558cd8 --- /dev/null +++ b/test/test_network_ops/test_grid_sampler_3d.py @@ -0,0 +1,80 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestGridSampler3D(TestCase): + def exec_grid_sampler3d_fp32(self, interpolation_mode, padding_mode, align_corners): + format_list = [2] + shape_list = [[2, 100, 1, 28, 28], [2, 100, 64, 32, 28]] + shape_format = [ + [np.float32, j, k] for j in format_list for k in shape_list + ] + sample_format = [np.float32, 2, [2, 100, 1, 1, 3]] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + cpu_output = self.op_exec_com(0, + cpu_input, cpu_sample, interpolation_mode, padding_mode, align_corners) + npu_output = self.op_exec_com(1, + npu_input, npu_sample, interpolation_mode, padding_mode, align_corners) + self.assertRtolEqual(cpu_output, npu_output) + + def test_grid_sampler3d_fp32(self, device="npu"): + self.exec_grid_sampler3d_fp32(0, 0, True) + self.exec_grid_sampler3d_fp32(0, 1, True) + self.exec_grid_sampler3d_fp32(1, 0, True) + self.exec_grid_sampler3d_fp32(1, 1, True) + self.exec_grid_sampler3d_fp32(0, 0, False) + self.exec_grid_sampler3d_fp32(0, 1, False) + self.exec_grid_sampler3d_fp32(1, 0, False) + self.exec_grid_sampler3d_fp32(1, 1, False) + + def test_grid_sampler3d_fp16(self, device="npu"): + format_list = [2] + shape_list = [[2, 1, 1, 3, 3], [2, 1, 2, 3, 4]] + shape_format = [ + [np.float16, j, k] for j in format_list for k in shape_list + ] + sample_format = [np.float16, 2, [2, 1, 2, 2, 3]] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 10) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + cpu_output = self.cpu_op_fp16_exec(cpu_input, cpu_sample, 0, 0, True) + npu_output = self.op_exec_com(1, npu_input, npu_sample, 0, 0, True) + self.assertRtolEqual(cpu_output, npu_output) + + def op_exec_com(self, npu_flag, input1, sample, interpolation_mode, padding_mode, align_corners): + output = torch.grid_sampler_3d(input1, sample, interpolation_mode, padding_mode, align_corners) + if npu_flag: + return output.to("cpu").numpy() + return output.numpy() + + def cpu_op_fp16_exec(self, input1, sample, interpolation_mode, padding_mode, align_corners): + input1 = input1.to(torch.float32) + sample = sample.to(torch.float32) + output = torch.grid_sampler_3d(input1, sample, interpolation_mode, padding_mode, align_corners) + output = output.numpy().astype(np.float16) + return output + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_grid_sampler_3d_backward.py b/test/test_network_ops/test_grid_sampler_3d_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..e97f8957641683d86500cbaff38f9b610fe212bf --- /dev/null +++ b/test/test_network_ops/test_grid_sampler_3d_backward.py @@ -0,0 +1,95 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestGridSampler3DBackward(TestCase): + def exec_grid_sampler3d_bk_fp32(self, interpolation_mode, padding_mode, align_corners): + format_list = [2] + shape_list = [[2, 100, 1, 28, 28], [2, 100, 64, 32, 28]] + shape_format = [ + [np.float32, j, k] for j in format_list for k in shape_list + ] + sample_format = [np.float32, 2, [2, 100, 1, 1, 3]] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + cpu_output, cpu_dx, cpu_dgrad = self.op_exec_com(0, + cpu_input, cpu_sample, interpolation_mode, padding_mode, align_corners) + npu_output, npu_dx, npu_dgrad = self.op_exec_com(1, + npu_input, npu_sample, interpolation_mode, padding_mode, align_corners) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_dx, npu_dx) + self.assertRtolEqual(cpu_dgrad, npu_dgrad) + + def test_grid_sampler3d_bk_fp32(self, device="npu"): + self.exec_grid_sampler3d_bk_fp32(0, 0, True) + self.exec_grid_sampler3d_bk_fp32(0, 1, True) + self.exec_grid_sampler3d_bk_fp32(1, 0, True) + self.exec_grid_sampler3d_bk_fp32(1, 1, True) + self.exec_grid_sampler3d_bk_fp32(0, 0, False) + self.exec_grid_sampler3d_bk_fp32(0, 1, False) + self.exec_grid_sampler3d_bk_fp32(1, 0, False) + self.exec_grid_sampler3d_bk_fp32(1, 1, False) + + def test_grid_sampler3d_bk_fp16(self, device="npu"): + format_list = [2] + shape_list = [[2, 1, 1, 3, 3], [2, 1, 2, 3, 4]] + shape_format = [ + [np.float16, j, k] for j in format_list for k in shape_list + ] + sample_format = [np.float16, 2, [2, 1, 2, 2, 3]] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 10) + cpu_sample, npu_sample = create_common_tensor(sample_format, -1, 1) + cpu_output, cpu_dx, cpu_dgrad = self.cpu_op_fp16_exec( + cpu_input.to(torch.float32), cpu_sample.to(torch.float32), 0, 0, True) + npu_output, npu_dx, npu_dgrad = self.op_exec_com(1, npu_input, npu_sample, 0, 0, True) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_dx, npu_dx) + self.assertRtolEqual(cpu_dgrad, npu_dgrad) + + def op_exec_com(self, npu_flag, input1, sample, interpolation_mode, padding_mode, align_corners): + input1.requires_grad = True + output = torch.grid_sampler_3d(input1, sample, interpolation_mode, padding_mode, align_corners) + output.backward(torch.ones_like(output)) + dx, dgrad = input1.grad + if npu_flag: + output = output.detach().to("cpu").numpy() + return output, dx.to("cpu").numpy(), dgrad.to("cpu").numpy() + output = output.detach().numpy() + return output, dx.numpy(), dgrad.numpy() + + def cpu_op_fp16_exec(self, input1, sample, interpolation_mode, padding_mode, align_corners): + input1 = input1.to(torch.float32) + sample = sample.to(torch.float32) + input1.requires_grad = True + output = torch.grid_sampler_3d(input1, sample, interpolation_mode, padding_mode, align_corners) + output.backward(torch.ones_like(output)) + dx, dgrad = input1.grad + dx = dx.numpy().astype(np.float16) + dgrad = dgrad.numpy().astype(np.float16) + output = output.detach().numpy().astype(np.float16) + return output, dx, dgrad + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_sin.py b/test/test_network_ops/test_sin.py new file mode 100644 index 0000000000000000000000000000000000000000..4252cafb90e6abc3a84b0609261d04b5fd0ebd36 --- /dev/null +++ b/test/test_network_ops/test_sin.py @@ -0,0 +1,74 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + +class TestSin(TestCase): + def cpu_op_exec(self, input1): + output = torch.sin(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.sin(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self, input1, input2): + torch.sin(input1, out=input2) + output = input2.to("cpu") + output = output.numpy() + return output + + def test_sin_common_shape_format(self, device="npu"): + shape_format = [ + [[np.float32, 0, (5,3)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_sin_out_common_shape_format(self, device="npu"): + shape_format = [ + [[np.float16, -1, (4, 3, 128, 128)], [np.float16, -1, (4, 3, 128, 128)]], + [[np.float16, 0, (4, 3, 128, 128)], [np.float16, 0, (10, 3, 64, 128)]], + [[np.float16, 0, (4, 3, 128, 128)], [np.float16, 0, (2, 3, 256, 128)]], + [[np.float32, 0, (4, 3, 128, 128)], [np.float32, 0, (4, 3, 128, 128)]], + [[np.float32, 0, (4, 3, 128, 128)], [np.float32, 0, (8, 3, 64, 128)]], + [[np.float32, -1, (4, 3, 128, 128)], [np.float32, -1, (4, 3, 256, 64)]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10) + cpu_input2, npu_input2 = create_common_tensor(item[0], -10, 10) + cpu_input3, npu_input3 = create_common_tensor(item[1], -10, 10) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output_out1 = self.npu_op_exec_out(npu_input1, npu_input2) + npu_output_out2 = self.npu_op_exec_out(npu_input1, npu_input3) + cpu_output = cpu_output.astype(npu_output_out1.dtype) + self.assertRtolEqual(cpu_output, npu_output_out1) + self.assertRtolEqual(cpu_output, npu_output_out2) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_std.py b/test/test_network_ops/test_std.py new file mode 100644 index 0000000000000000000000000000000000000000..bfe8b8a95d41a04d0bde59b99b0e133a41201374 --- /dev/null +++ b/test/test_network_ops/test_std.py @@ -0,0 +1,234 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + + +class TestStd(TestCase): + def cpu_op_exec(self, input1, unbiased=True): + output = torch.std(input1, unbiased=unbiased) + output = output.numpy() + return output + + def npu_op_exec(self, input1, unbiased=True): + output = torch.std(input1, unbiased=unbiased) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.std(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.numpy() + return output + + def npu_op_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.std(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_dim_out_exec(self, input1, dim, output1, unbiased=True, keepdim=False): + torch.std(input1, dim, unbiased=unbiased, keepdim=keepdim,out=output1) + output1 = output1.numpy() + return output1 + + def npu_op_dim_out_exec(self, input1, dim, output1, unbiased=True, keepdim=False): + torch.std(input1, dim, unbiased=unbiased, keepdim=keepdim,out=output1) + output1 = output1.to("cpu") + output1 = output1.numpy() + return output1 + + def output_shape(self, inputshape, dim, unbiased=True, keepdim=False): + shape = list(inputshape) + if dim < len(inputshape): + if keepdim: + shape[dim] = 1 + else: + shape.pop(dim) + return shape + + + def create_output_tensor(self, minvalue,maxvalue,shape,npuformat,dtype): + input1 = np.random.uniform(minvalue, maxvalue, shape).astype(dtype) + cpu_input = torch.from_numpy(input1) + npu_input = torch.from_numpy(input1).npu() + if npuformat != -1: + npu_input = torch_npu.npu_format_cast(npu_input, npuformat) + return cpu_input, npu_input + + def test_std_shape_format_fp16(self, device="npu"): + format_list = [0] + shape_list = [[16], [32, 1024], [32, 8, 1024], [128, 32, 8, 1024]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output1 = self.cpu_op_exec(cpu_input1, item[3]) + cpu_output1 = cpu_output1.astype(np.float16) + npu_output1 = self.npu_op_exec(npu_input1, item[3]) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_std_shape_format_fp32(self, device="npu"): + format_list = [0] + shape_list = [[1024], [32, 1024], [32, 8, 1024], [128, 32, 8, 1024]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_exec(cpu_input1, item[3]) + npu_output = self.npu_op_exec(npu_input1, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_std_dim_shape_format_fp16(self, device="npu"): + format_list = [0] + shape_list = [[1024], [32, 1024], [32, 8, 1024], [128, 32, 8, 1024]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output1 = self.cpu_op_dim_exec(cpu_input1, item[3], item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + npu_output1 = self.npu_op_dim_exec(npu_input1, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_std_dim_shape_format_fp32(self, device="npu"): + format_list = [0] + shape_list = [[1024], [32, 1024], [32, 8, 1024], [128, 32, 8, 1024]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_output1 = self.cpu_op_dim_exec(cpu_input1, item[3], item[4], item[5]) + npu_output1 = self.npu_op_dim_exec(npu_input1, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_std_dim_out_shape_format_fp16(self, device="npu"): + format_list = [0] + shape_list = [[1024], [32, 24], [32, 8, 24], [12, 32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + outputshape = self.output_shape(item[2],item[3],item[4],item[5]) + cpu_output,npu_output = self.create_output_tensor(0,1,outputshape,item[1],item[0]) + if item[0] == np.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = cpu_output.to(torch.float32) + cpu_output1 = self.cpu_op_dim_out_exec(cpu_input1, item[3], cpu_output, item[4], item[5]) + npu_output1 = self.npu_op_dim_out_exec(npu_input1, item[3], npu_output, item[4], item[5]) + if item[0] == np.float16: + cpu_output1 = cpu_output1.astype(np.float16) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_std_dim_out_shape_format_fp32(self, device="npu"): + format_list = [0] + shape_list = [[1024], [32, 24], [32, 8, 24], [12, 32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + outputshape = self.output_shape(item[2],item[3],item[4],item[5]) + cpu_output,npu_output = self.create_output_tensor(0,1,outputshape,item[1],item[0]) + cpu_output1 = self.cpu_op_dim_out_exec(cpu_input1, item[3], cpu_output, item[4], item[5]) + npu_output1 = self.npu_op_dim_out_exec(npu_input1, item[3], npu_output, item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_std_dim_name_fp16(self, device="npu"): + shape = (1024, 8, 32) + cpu_input = torch.rand(shape, dtype=torch.float32) + npu_input = cpu_input.npu().to(torch.float16) + cpu_input.names = ['N','C','H'] + npu_input.names = ['N','C','H'] + dim = np.random.choice(['N', 'C', 'H']) + cpu_output = torch.std(cpu_input, dim=dim) + npu_output = torch.std(npu_input, dim=dim) + self.assertRtolEqual(cpu_output.to(torch.float16).numpy(), npu_output.cpu().numpy()) + + def test_std_dim_name_fp32(self, device="npu"): + shape = (1024, 8, 32) + cpu_input = torch.rand(shape, dtype=torch.float32, names=('N', 'C', 'H')) + npu_input = cpu_input.npu() + dim = np.random.choice(['N', 'C', 'H']) + cpu_output = torch.std(cpu_input, dim=dim) + npu_output = torch.std(npu_input, dim=dim) + self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy()) + + def test_std_dim_out_name_fp16(self, device="npu"): + shape = (1024, 8, 32) + dimlist = ['N', 'C', 'H'] + cpu_input = torch.rand(shape, dtype=torch.float32) + npu_input = cpu_input.npu() + dim = np.random.choice(dimlist) + dims = dimlist.index(dim) + outputshape = self.output_shape(shape, dims) + cpu_output,npu_output = self.create_output_tensor(0, 1, outputshape, -1, np.float32) + npu_input = npu_input.to(torch.float16) + npu_output = npu_output.to(torch.float16) + cpu_input.names = ['N','C','H'] + npu_input.names = ['N','C','H'] + + cpu_output = torch.std(cpu_input, dim=dim,out=cpu_output) + npu_output = torch.std(npu_input, dim=dim,out=npu_output) + cpu_output = cpu_output.to(torch.float16) + self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy()) + + def test_std_dim_out_name_fp32(self, device="npu"): + shape = (1024, 8, 32) + dimlist = ['N', 'C', 'H'] + cpu_input = torch.rand(shape, dtype=torch.float32, names=('N', 'C', 'H')) + npu_input = cpu_input.npu() + dim = np.random.choice(dimlist) + dims = dimlist.index(dim) + outputshape = self.output_shape(shape, dims) + cpu_output,npu_output = self.create_output_tensor(0, 1, outputshape, -1, np.float32) + cpu_output = torch.std(cpu_input, dim=dim,out=cpu_output) + npu_output = torch.std(npu_input, dim=dim,out=npu_output) + self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy()) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_std_mean.py b/test/test_network_ops/test_std_mean.py new file mode 100644 index 0000000000000000000000000000000000000000..d49444561cbbb4f0309465771ec49d280fb06083 --- /dev/null +++ b/test/test_network_ops/test_std_mean.py @@ -0,0 +1,146 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor + + +class TestStdMean(TestCase): + def cpu_op_mean_exec(self, input1, unbiased=True): + output = torch.std_mean(input1, unbiased=unbiased) + result = [] + result.append(output[0].numpy()) + result.append(output[1].numpy()) + return result + + def npu_op_mean_exec(self, input1, unbiased=True): + output = torch.std_mean(input1, unbiased=unbiased) + result = [] + result.append(output[0].to("cpu").numpy()) + result.append(output[1].to("cpu").numpy()) + return result + + def cpu_op_dim_mean_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.std_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + result = [] + result.append(output[0].numpy()) + result.append(output[1].numpy()) + return result + + def npu_op_dim_mean_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.std_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + result = [] + result.append(output[0].to("cpu").numpy()) + result.append(output[1].to("cpu").numpy()) + return result + + def test_std_mean_shape_format_fp16(self, device="npu"): + format_list = [0, 3, 4] + shape_list = [[2], [1, 2], [1, 1, 2], [1, 1, 1, 2]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input1,npu_input1 = create_common_tensor(item, 0, 100) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output1 = self.cpu_op_mean_exec(cpu_input1, item[3]) + cpu_output1[0] = cpu_output1[0].astype(np.float16) + cpu_output1[1] = cpu_output1[1].astype(np.float16) + npu_output1 = self.npu_op_mean_exec(npu_input1, item[3]) + self.assertRtolEqual(cpu_output1[0], npu_output1[0]) + self.assertRtolEqual(cpu_output1[1], npu_output1[1]) + + def test_std_mean_shape_format_fp32(self, device="npu"): + format_list = [0, 3, 4] + shape_list = [[2], [1, 2], [1, 1, 2], [1, 1, 1, 2]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input1,npu_input1 = create_common_tensor(item, 0, 100) + cpu_output1 = self.cpu_op_mean_exec(cpu_input1, item[3]) + npu_output1 = self.npu_op_mean_exec(npu_input1, item[3]) + self.assertRtolEqual(cpu_output1[0], npu_output1[0]) + self.assertRtolEqual(cpu_output1[1], npu_output1[1]) + + def test_std_dim_mean_shape_format_fp16(self, device="npu"): + format_list = [0, 3, 4] + shape_list = [[2], [1, 2], [1, 1, 2], [1, 1, 1, 2]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1,npu_input1 = create_common_tensor(item, 0, 100) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output1 = self.npu_op_dim_mean_exec(cpu_input1, item[3], item[4], item[5]) + cpu_output1[0] = cpu_output1[0].astype(np.float16) + cpu_output1[1] = cpu_output1[1].astype(np.float16) + npu_output1 = self.npu_op_dim_mean_exec(npu_input1, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1[0], npu_output1[0]) + self.assertRtolEqual(cpu_output1[1], npu_output1[1]) + + def test_std_dim_mean_shape_format_fp32(self, device="npu"): + format_list = [0, 3, 4] + shape_list = [[2], [1, 2], [1, 1, 2], [1, 1, 1, 2]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1,npu_input1 = create_common_tensor(item, 0, 100) + cpu_output1 = self.npu_op_dim_mean_exec(cpu_input1, item[3], item[4], item[5]) + npu_output1 = self.npu_op_dim_mean_exec(npu_input1, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1[0], npu_output1[0]) + self.assertRtolEqual(cpu_output1[1], npu_output1[1]) + + def test_std_dim_mean_name_fp32(self, device="npu"): + shape = (1, 1, 2) + cpu_input = torch.rand(shape, dtype=torch.float32, names=('N', 'C', 'H')) + npu_input = cpu_input.npu() + dim = np.random.choice(['N', 'C', 'H']) + cpu_output = torch.std_mean(cpu_input, dim=dim) + npu_output = torch.std_mean(npu_input, dim=dim) + self.assertRtolEqual(cpu_output[0].numpy(), npu_output[0].cpu().numpy()) + self.assertRtolEqual(cpu_output[1].numpy(), npu_output[1].cpu().numpy()) + + def test_std_dim_mean_name_fp16(self, device="npu"): + shape = (1, 1, 2) + cpu_input = torch.rand(shape, dtype=torch.float32) + npu_input = cpu_input.to(torch.float16).npu() + cpu_input.names = ['N', 'C', 'H'] + npu_input.names = ['N', 'C', 'H'] + dim = np.random.choice(['N', 'C', 'H']) + cpu_output = torch.std_mean(cpu_input, dim=dim) + npu_output = torch.std_mean(npu_input, dim=dim) + self.assertRtolEqual(cpu_output[0].to(torch.float16).numpy(), npu_output[0].cpu().numpy()) + self.assertRtolEqual(cpu_output[1].to(torch.float16).numpy(), npu_output[1].cpu().numpy()) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/ErfKernelNpu.cpp b/torch_npu/csrc/aten/ops/ErfKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef46051274e3ff692cf61bd19d6a09f08da346e4 --- /dev/null +++ b/torch_npu/csrc/aten/ops/ErfKernelNpu.cpp @@ -0,0 +1,59 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& erf_npu_nocheck(const at::Tensor& self, at::Tensor& out) { + OpCommand cmd; + cmd.Name("Erf") + .Input(self) + .Output(out) + .Run(); + return out; +} + +at::Tensor& NPUNativeFunctions::erf_out(const at::Tensor& self, at::Tensor& out) { + OpPreparation::CheckOut( + {self}, + out, + self); + + if (!NpuUtils::check_match(&out)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(out); + at::Tensor newResult = erf_npu_nocheck(self, contiguousResult); + NpuUtils::format_fresh_view(out, newResult); + } else { + erf_npu_nocheck(self, out); + } + return out; +} + +at::Tensor NPUNativeFunctions::erf(const at::Tensor& self) { + auto outputSize = input_same_output_size(self); + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + erf_npu_nocheck(self, result); + return result; +} + +at::Tensor& NPUNativeFunctions::erf_(at::Tensor& self) { + NPUNativeFunctions::erf_out(self, self); + return self; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GerKernelNpu.cpp b/torch_npu/csrc/aten/ops/GerKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6256544e3121dadc48534bfe662dc61154182596 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GerKernelNpu.cpp @@ -0,0 +1,84 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +c10::SmallVector ger_npu_output_size( + const at::Tensor& self, + const at::Tensor& vec2) { + int64_t outputsize_0 = self.size(0); + int64_t outputsize_1 = vec2.size(0); + c10::SmallVector outputsize = {outputsize_0, outputsize_1}; + + return outputsize; +} + +at::Tensor& ger_out_npu_nocheck(const at::Tensor& self , const at::Tensor& vec2, at::Tensor& result) { + OpCommand cmd; + cmd.Name("Ger") + .Input(self) + .Input(vec2) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::ger_out(const at::Tensor& self , const at::Tensor& vec2, at::Tensor& result) { + // check shape + TORCH_CHECK( + self.dim() == 1, "Input1 must have only1 dims."); + TORCH_CHECK( + vec2.dim() == 1, "Input2 must have only1 dims."); + + // calculate the output size + auto outputSize = ger_npu_output_size(self, vec2); + + OpPreparation::CheckOut( + {self}, + result, + self, + outputSize); + + OpPipeWithDefinedOut pipe; + return pipe.Func([&self, &vec2](at::Tensor& result){ger_out_npu_nocheck(self, vec2, result);}) + .Call(result); +} + +at::Tensor NPUNativeFunctions::ger(const at::Tensor& self, const at::Tensor& vec2) { + // check shape + TORCH_CHECK( + self.dim() == 1, "Input1 must have only1 dims."); + TORCH_CHECK( + vec2.dim() == 1, "Input2 must have only1 dims."); + + // calculate the output size + auto outputSize = ger_npu_output_size(self, vec2); + + // construct the output Tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + ger_out_npu_nocheck(self, vec2, result); + + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridSampler3dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSampler3dBackwardKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f94572ae2b5b241ff94ba1bb4cfa7f1cf6ed61c8 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GridSampler3dBackwardKernelNpu.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +std::tuple grid_sampler_3d_backward_npu_nocheck( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& grid, + std::string interMode, + std::string paddingMode, + bool align_corners, + at::Tensor& dx, + at::Tensor& dgrid) { + OpCommand cmd; + cmd.Name("GridSampler3DGrad") + .Input(grad) + .Input(input) + .Input(grid) + .Output(dx) + .Output(dgrid) + .Attr("interpolation_mode", interMode) + .Attr("padding_mode", paddingMode) + .Attr("align_corners", align_corners) + .Run(); + return std::tie(dx, dgrid); +} + +std::tuple NPUNativeFunctions::grid_sampler_3d_backward( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + TORCH_CHECK( + (0 <= interpolation_mode && interpolation_mode <= 2), + "interpolation_mode must be in range [0~2].") + TORCH_CHECK( + (0 <= padding_mode && padding_mode <= 2), + "padding_mode must be in range [0~2].") + at::Tensor formatCastOfGrad = grad; + at::Tensor formatCastOfInput = input; + at::Tensor formatCastOfGrid = grid; + if (formatCastOfGrad.scalar_type() == at::ScalarType::Half) { + formatCastOfGrad = NPUNativeFunctions::npu_dtype_cast(formatCastOfGrad, at::ScalarType::Float); + } + if (formatCastOfInput.scalar_type() == at::ScalarType::Half) { + formatCastOfInput = NPUNativeFunctions::npu_dtype_cast(formatCastOfInput, at::ScalarType::Float); + } + if (formatCastOfGrid.scalar_type() == at::ScalarType::Half) { + formatCastOfGrid = NPUNativeFunctions::npu_dtype_cast(formatCastOfGrid, at::ScalarType::Float); + } + + // construct the output tensor of the NPU + at::Tensor dx = OpPreparation::ApplyTensor(formatCastOfInput); + at::Tensor dgrid = OpPreparation::ApplyTensor(formatCastOfGrid); + std::string interMode[] = {"bilinear", "nearest", "bicubic"}; + std::string paddingMode[] = {"zeros", "border", "reflection"}; + + // calculate the output result of the NPU + grid_sampler_3d_backward_npu_nocheck( + formatCastOfGrad, + formatCastOfInput, + formatCastOfGrid, + interMode[interpolation_mode], + paddingMode[padding_mode], + align_corners, + dx, + dgrid); + return std::tie(dx, dgrid); +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/GridSampler3dKernelNpu.cpp b/torch_npu/csrc/aten/ops/GridSampler3dKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..baf0771573b5e0a5a9232fbb082a50e83aaba7c7 --- /dev/null +++ b/torch_npu/csrc/aten/ops/GridSampler3dKernelNpu.cpp @@ -0,0 +1,92 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& grid_sampler_3d_npu_nocheck( + const at::Tensor& self, + const at::Tensor& grid, + std::string interMode, + std::string paddingMode, + bool align_corners, + at::Tensor& result) { + OpCommand cmd; + cmd.Name("GridSampler3D") + .Input(self) + .Input(grid) + .Output(result) + .Attr("interpolation_mode", interMode) + .Attr("padding_mode", paddingMode) + .Attr("align_corners", align_corners) + .Run(); + return result; +} + +at::Tensor NPUNativeFunctions::grid_sampler_3d( + const at::Tensor& self, + const at::Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners) { + TORCH_CHECK( + (0 <= interpolation_mode && interpolation_mode <= 2), + "interpolation_mode must be in range [0~2].") + TORCH_CHECK( + (0 <= padding_mode && padding_mode <= 2), + "padding_mode must be in range [0~2].") + at::Tensor formatCastOfSelf = self; + at::Tensor formatCastOfGrid = grid; + if (formatCastOfSelf.scalar_type() == at::ScalarType::Half) { + formatCastOfSelf = NPUNativeFunctions::npu_dtype_cast(formatCastOfSelf, at::ScalarType::Float); + } + if (formatCastOfGrid.scalar_type() == at::ScalarType::Half) { + formatCastOfGrid = NPUNativeFunctions::npu_dtype_cast(formatCastOfGrid, at::ScalarType::Float); + } + + // calculate the output size + c10::SmallVector outputSize = {formatCastOfSelf.size(0), + formatCastOfSelf.size(1), + formatCastOfGrid.size(1), + formatCastOfGrid.size(2), + formatCastOfGrid.size(3)}; + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, formatCastOfSelf.options(), ACL_FORMAT_ND); + std::string interMode[] = {"bilinear", "nearest", "bicubic"}; + std::string paddingMode[] = {"zeros", "border", "reflection"}; + + // calculate the output result of the NPU + grid_sampler_3d_npu_nocheck( + formatCastOfSelf, + formatCastOfGrid, + interMode[interpolation_mode], + paddingMode[padding_mode], + align_corners, + result); + + at::ScalarType selfScalarType(self.scalar_type()); + if (result.scalar_type() != selfScalarType) { + result = NPUNativeFunctions::npu_dtype_cast(result, selfScalarType); + } + return result; +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/SinKernelNpu.cpp b/torch_npu/csrc/aten/ops/SinKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9b6b9a87dfd2a6e9e7a5abed89d07082059e160 --- /dev/null +++ b/torch_npu/csrc/aten/ops/SinKernelNpu.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& sin_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) { + OpCommand cmd; + cmd.Name("Sin") + .Input(self) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::sin_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut({self}, result, self); + + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + at::Tensor newResult = sin_out_npu_nocheck(contiguousResult, self); + NpuUtils::format_fresh_view(result, newResult); + } else { + sin_out_npu_nocheck(result, self); + } + + return result; +} + +at::Tensor NPUNativeFunctions::sin(const at::Tensor& self) { + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithSizes(self.sizes(), self.options()); + + // calculate the output result of the NPU + sin_out_npu_nocheck(result, self); + + return result; +} + +at::Tensor& NPUNativeFunctions::sin_(at::Tensor& self) { + NPUNativeFunctions::sin_out(self, self); + return self; +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/StdKernelNpu.cpp b/torch_npu/csrc/aten/ops/StdKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dee968daba393f8f8d6f17d31483c5ad0d79e877 --- /dev/null +++ b/torch_npu/csrc/aten/ops/StdKernelNpu.cpp @@ -0,0 +1,150 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +tuple std_mean_out_npu_nocheck( + at::Tensor& resultStd, + at::Tensor& resultMean, + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + OpCommand cmd1; + cmd1.Name("ReduceMeanD") + .Input(self) + .Output(resultMean) + .Attr("axes", dim) + .Attr("keep_dims", keepdim) + .Run(); + + at::Tensor resultMeanCopy = resultMean; + if (resultMean.dim() != 0 && keepdim == false) { + auto dimVector = array_to_small_vector(dim); + std::sort(dimVector.begin(), dimVector.end()); + for (int64_t i = 0; i < dimVector.size(); i++) { + resultMeanCopy = resultMeanCopy.unsqueeze(dimVector[i]); + } + } + resultMeanCopy = resultMeanCopy.expand(self.sizes()); + OpCommand cmd2; + cmd2.Name("ReduceStdWithMean") + .Input(self) + .Input(resultMeanCopy) + .Output(resultStd) + .Attr("dim", dim) + .Attr("unbiased", unbiased) + .Attr("keepdim", keepdim) + .Run(); + + return std::tie(resultStd, resultMean); +} + +at::Tensor& NPUNativeFunctions::std_out( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim, + at::Tensor& result) { + auto outputSize = reduce_ops_npu_output_size(self, dim, keepdim); + at::Tensor meanResult = OpPreparation::ApplyTensor(self, outputSize); + + OpPreparation::CheckOut( + {self}, + result, + ACL_FORMAT_ND, + self.scalar_type(), + outputSize); + + std_mean_out_npu_nocheck(result, meanResult, self, dim, unbiased, keepdim); + + return result; +} + +at::Tensor& NPUNativeFunctions::std_out( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim, + at::Tensor& result) { + return NPUNativeFunctions::std_out(self, dimnames_to_positions(self, dim), unbiased, keepdim, result); +} + +at::Tensor NPUNativeFunctions::std( + const at::Tensor & self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto outputSize = reduce_ops_npu_output_size(self, dim, keepdim); + + at::Tensor result1 = OpPreparation::ApplyTensor(self, outputSize); + at::Tensor result2 = OpPreparation::ApplyTensor(self, outputSize); + + std_mean_out_npu_nocheck(result1, result2, self, dim, unbiased, keepdim); + return result1; +} + +at::Tensor NPUNativeFunctions::std( + const at::Tensor & self, + bool unbiased) { + c10::SmallVector dims = CalcuOpUtil::get_dimlist_for_tensor(self); + return NPUNativeFunctions::std(self, dims, unbiased, false); +} + +tuple NPUNativeFunctions::std_mean( + const at::Tensor & self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto outputSize = reduce_ops_npu_output_size(self, dim, keepdim); + + at::Tensor result1 = OpPreparation::ApplyTensor(self, outputSize); + at::Tensor result2 = OpPreparation::ApplyTensor(self, outputSize); + + std_mean_out_npu_nocheck(result1, result2, self, dim, unbiased, keepdim); + + return std::tie(result1, result2); +} + +tuple NPUNativeFunctions::std_mean( + const at::Tensor & self, + bool unbiased) { + c10::SmallVector dims = CalcuOpUtil::get_dimlist_for_tensor(self); + return NPUNativeFunctions::std_mean(self, dims, unbiased, false); +} + +tuple NPUNativeFunctions::std_mean( + const at::Tensor & self, + at::DimnameList dim, + bool unbiased, + bool keepdim) { + return NPUNativeFunctions::std_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim); +} + +at::Tensor NPUNativeFunctions::std( + const at::Tensor & self, + at::DimnameList dim, + bool unbiased, + bool keepdim) { + return NPUNativeFunctions::std(self, dimnames_to_positions(self, dim), unbiased, keepdim); +} +} // namespace native +} // namespace at_npu \ No newline at end of file