diff --git a/test/test_network_ops/test_eq.py b/test/test_network_ops/test_eq.py
new file mode 100644
index 0000000000000000000000000000000000000000..771c9623351ff071a27e803225a8d41fede7b245
--- /dev/null
+++ b/test/test_network_ops/test_eq.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestEqual(TestCase):
+    def cpu_op_exec(self, input1, input2):
+        output = torch.eq(input1, input2)
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, input2):
+        output = torch.eq(input1, input2)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec_out(self, input1, input2):
+        input3 = torch.empty(0).bool().npu()
+        torch.eq(input1, input2, out=input3)
+        output = input3.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_equal_shape_format_fp32(self, device):
+        dtype_list = [np.float32]
+        format_list = [0, 3]
+        shape_list = [[1024], [8, 128], [2, 8, 128], [2, 8, 128, 512]]
+        shape_format = [
+            [d, i, j] for d in dtype_list for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item, 1, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item, 1, 100)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            self.assertEqual(cpu_output, npu_output)
+
+    def test_equal_shape_format_fp16(self, device):
+        dtype_list = [np.float16]
+        format_list = [0, 3]
+        shape_list = [[1024], [8, 128], [2, 8, 128], [2, 8, 128, 512]]
+        shape_format = [
+            [d, i, j] for d in dtype_list for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item, 1, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item, 1, 100)
+            if cpu_input1.dtype == torch.float16:
+                cpu_input1 = cpu_input1.to(torch.float32)
+            if cpu_input2.dtype == torch.float16:
+                cpu_input2 = cpu_input2.to(torch.float32)
+            cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            cpu_output = cpu_output.astype(np.float16)
+            self.assertEqual(cpu_output, npu_output)
+
+    def test_equal_out_shape_format_fp32(self, device):
+        dtype_list = [np.float32]
+        format_list = [0]
+        shape_list = [[1024], [8, 128], [2, 8, 128], [2, 8, 128, 512]]
+        shape_format = [
+            [[d, i, j]] for d in dtype_list for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10)
+            cpu_input2, npu_input2 = create_common_tensor(item[0], -10, 10)
+            npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2)
+            self.assertEqual(npu_output_out, npu_output)
+
+    def test_equal_scalar_out_shape_format_fp32(self, device):
+        dtype_list = [np.float32]
+        format_list = [0]
+        shape_list = [[1024], [8, 128], [2, 8, 128], [2, 8, 128, 512]]
+        shape_format = [
+            [[d, i, j]] for d in dtype_list for i in format_list for j in shape_list
+        ]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], -10, 10)
+            npu_output_out = self.npu_op_exec_out(npu_input1, 5)
+            npu_output = self.npu_op_exec(npu_input1, 5)
+            self.assertEqual(npu_output_out, npu_output)
+
+    def test_equal_mix_dtype(self, device):
+        cpu_input1, npu_input1 = create_common_tensor([np.float16, 0, (2, 3)], 1, 100)
+        cpu_input2, npu_input2 = create_common_tensor([np.float32, 0, (2, 3)], 1, 100)
+        cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2)
+        npu_output = self.npu_op_exec(npu_input1, npu_input2)
+        self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestEqual, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_network_ops/test_floordivide.py b/test/test_network_ops/test_floordivide.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c477d15ae711f2b884d96c66793e4e763fb51d2
--- /dev/null
+++ b/test/test_network_ops/test_floordivide.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestFloorDivide(TestCase): + + def generate_data(self, min1, max1, shape, dtype): + input1 = np.random.uniform(min1, max1, shape).astype(dtype) + input2 = np.random.uniform(min1, max1, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + + return npu_input1, npu_input2 + + def generate_three_data(self, min1, max1, shape, dtype): + input1 = np.random.uniform(min1, max1, shape).astype(dtype) + input2 = np.random.uniform(min1, max1, shape).astype(dtype) + input3 = np.random.uniform(min1, max1, shape).astype(dtype) + + #modify from numpy.ndarray to torch.tensor + npu_input1 = torch.from_numpy(input1) + npu_input2 = torch.from_numpy(input2) + npu_input3 = torch.from_numpy(input3) + + return npu_input1, npu_input2, npu_input3 + + def cpu_op_exec(self, input1, input2): + output = torch.floor_divide(input1,input2) + output = output.numpy() + return output + + def npu_op_exec(self, input1, input2): + input1 = input1.to("npu") + input2 = input2.to("npu") + output = torch.floor_divide(input1,input2) + output = output.to("cpu") + output = output.numpy() + return output + + + def npu_op_exec_scalar(self, input1, input2): # + input1 = input1.to("npu") + output = torch.floor_divide(input1,input2) + output = output.to("cpu") + output = output.numpy() + return output + + + def npu_op_exec_out(self, input1, input2, input3): # + input1 = input1.to("npu") + input2 = input2.to("npu") + output = input3.to("npu") + torch.floor_divide(input1, input2, out=output) + output = output.to("cpu") + output = output.numpy() + return output + + + def test_floor_divide_float32(self, device): + npu_input1, npu_input2 = self.generate_data(1, 100, (1, 2), np.float32) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + + def test_floor_divide_float32_out(self, device): + npu_input1, npu_input2, npu_input3 = self.generate_three_data(1, 100, (1,2), np.float32) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec_out(npu_input1, npu_input2, npu_input3) + self.assertRtolEqual(cpu_output, npu_output) + + + def test_floor_divide_int32(self, device): + npu_input1, npu_input2 = self.generate_data(1, 100, (1,2), np.int32) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_divide_int8(self, device): + npu_input1, npu_input2 = self.generate_data(1, 100, (1,2), np.int8) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_divide_uint8(self, device): + npu_input1, npu_input2 = self.generate_data(1, 100, (1,3), np.uint8) + cpu_output = self.cpu_op_exec(npu_input1, npu_input2) + npu_output = self.npu_op_exec(npu_input1, npu_input2) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_divide_scalar_float32(self, device): + npu_input1, _ = self.generate_data(1, 100, (1,3), np.float32) + cpu_output = self.cpu_op_exec(npu_input1, 1) + npu_output = 
self.npu_op_exec_scalar(npu_input1, 1) + self.assertRtolEqual(cpu_output, npu_output) + + def test_floor_divide_scalar_bool(self, device): + npu_input1, _ = self.generate_data(1, 10, (2, 5), np.float32) + cpu_output = self.cpu_op_exec(npu_input1 > 5, 1.0) + npu_output = self.npu_op_exec_scalar(npu_input1 > 5, 1.0) + self.assertRtolEqual(cpu_output, npu_output) + + def npu_uncontiguous_op_exec_scalar(self, input1, input2): # + input1 = input1.to("npu") + input1 = input1.as_strided([2,2], [1,2], 1) + output = torch.floor_divide(input1, input2) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_uncontiguous_op_exec_scalar(self, input1, input2): # + input1 = input1.as_strided([2,2], [1,2], 1) + output = torch.floor_divide(input1, input2) + output = output.numpy() + return output + + def test_floor_divide_uncontiguous_float32_scalar(self, device): + npu_input1, npu_input2 = self.generate_data(1, 100, (4,3), np.float32) + cpu_input1 = copy.deepcopy(npu_input1) + cpu_output = self.cpu_uncontiguous_op_exec_scalar(cpu_input1, 2) + npu_output = self.npu_uncontiguous_op_exec_scalar(npu_input1, 2) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestFloorDivide, globals(), except_for='cpu') +if __name__ == '__main__': + run_tests() + diff --git a/test/test_network_ops/test_sigmoid.py b/test/test_network_ops/test_sigmoid.py new file mode 100644 index 0000000000000000000000000000000000000000..41cdaf60abd6cbb89654e0c1a75d4fc39e087236 --- /dev/null +++ b/test/test_network_ops/test_sigmoid.py @@ -0,0 +1,123 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestSigmoid(TestCase): + @Dtypes(torch.float) + def test_sigmoid(self, device, dtype): + # TODO: why not simulate math.sigmoid like with rsqrt? 
+ inputValues = [-1000, -1, 0, 0.5, 1, 2, 1000] + expectedOutput = [0.0000, 0.2689, 0.5, 0.6225, 0.7311, 0.8808, 1.000] + precision_4dps = 0.0002 + + self.assertEqual( + torch.tensor( + inputValues, dtype=dtype, device=device).sigmoid().cpu(), + torch.tensor( + expectedOutput, + dtype=dtype, device=device).cpu(), precision_4dps) + + def cpu_op_exec(self, input1): + output = torch.sigmoid(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = torch.sigmoid(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_out_exec(self, input1, output): + torch.sigmoid(input1, out = output) + output = output.to("cpu").numpy() + return output + + def test_sigmoid_shape_format_fp16(self, device): + format_list = [0] + shape_list = [1, (64, 10), (32, 3, 3), (256, 2048, 7, 7)] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 1, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + def test_sigmoid_shape_format_fp32(self, device): + format_list = [0, 3, 4, 29] + shape_list = [1, (32, 32, 3, 3), (256, 2048, 7, 7)] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 1, 100) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + self.assertRtolEqual(cpu_output, npu_output) + + def test_sigmoid_out_float32_shape_format(self, device): + shape_format = [ + [[np.float32, 0, [1024, 32, 7, 7]], [np.float32, 0, [1024, 32, 7, 7]]], + [[np.float32, 0, [1024, 32, 7]], [np.float32, 0, [128, 32, 8, 10]]], + [[np.float32, 0, [512, 32]], [np.float32, 0, [1024, 20]]], + [[np.float32, 0, [1024]], [np.float32, 0, [1024, 1]]], + [[np.float32, 3, [1024, 32, 7, 7]], [np.float32, 3, [1024, 32, 7, 7]]], + [[np.float32, 3, [1024, 32, 7]], [np.float32, 3, [1024, 32]]], + [[np.float32, 3, [1024, 32]], [np.float32, 3, [1024, 20]]], + [[np.float32, 3, [1024]], [np.float32, 3, [1024]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 1, 100) + cpu_output, npu_output = create_common_tensor(item[1], -1, 1) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_out_exec(npu_input, npu_output) + self.assertRtolEqual(cpu_output, npu_output) + + def test_sigmoid_out_float16_shape_format(self, device): + shape_format = [ + [[np.float16, 0, [1024, 32, 7, 7]], [np.float16, 0, [1024, 32, 7, 7]]], + [[np.float16, 0, [1024, 32, 7]], [np.float16, 0, [128, 32, 8, 10]]], + [[np.float16, 0, [510, 32]], [np.float16, 0, [1024, 20]]], + [[np.float16, 0, [1024]], [np.float16, 0, [1024, 1]]], + [[np.float16, 3, [1024, 32, 7, 7]], [np.float16, 3, [1024, 32, 7, 7]]], + [[np.float16, 3, [1024, 32, 7]], [np.float16, 3, [1024, 32]]], + [[np.float16, 3, [1024, 32]], [np.float16, 3, [1024, 20]]], + [[np.float16, 3, [1024]], [np.float16, 3, [128]]], + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item[0], 1, 100) + cpu_output, npu_output = create_common_tensor(item[1], -1, 1) + if item[0][0] == np.float16: + cpu_input = cpu_input.to(torch.float32) + cpu_output = cpu_output.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = 
self.npu_op_out_exec(npu_input, npu_output) + if item[0][0] == np.float16: + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestSigmoid, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_sigmoid_backward.py b/test/test_network_ops/test_sigmoid_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..98f6257b6b5a444a734d7f1598cd4f1479c010fa --- /dev/null +++ b/test/test_network_ops/test_sigmoid_backward.py @@ -0,0 +1,97 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +def input_grad_hook(grad): + global input_grad + input_grad = grad + input_grad = input_grad.numpy() + + +def npu_input_grad_hook(grad): + global npu_input_grad + npu_input_grad = grad.to("cpu") + npu_input_grad = npu_input_grad.numpy() + + +class TestSigmoidBackward(TestCase): + def cpu_op_exec(self, input1, is_contiguous = True): + if is_contiguous is False : + input1 = input1.as_strided([2,2], [1,2], 1) + input1.requires_grad = True + input1.register_hook(input_grad_hook) + output = torch.sigmoid(input1) + z = output.sum() + z.backward() + + def npu_op_exec(self, input1, is_contiguous = True): + if is_contiguous is False : + input1 = input1.as_strided([2,2], [1,2], 1) + input1.requires_grad = True + input1.register_hook(npu_input_grad_hook) + + output = torch.sigmoid(input1) + z = output.sum() + z.backward() + input1 = input1.cpu() + + def test_sigmoid_backward_shape_format_fp16(self, device): + format_list = [0] + shape_list = [5,(64, 10),(32, 3, 3),(256, 2048, 7, 7)] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + input1, npu_input1 = create_common_tensor(item, 1, 100) + input2, npu_input2 = create_common_tensor(item, 1, 100) + input1 = input1.to(torch.float32) + input2 = input2.to(torch.float32) + self.cpu_op_exec(input1) + self.npu_op_exec(npu_input1) + global input_grad + input_grad = input_grad.astype(npu_input_grad.dtype) + self.assertRtolEqual(input_grad, npu_input_grad) + + self.cpu_op_exec(input2, False) + self.npu_op_exec(npu_input2, False) + input_grad = input_grad.astype(np.float16) + self.assertRtolEqual(input_grad, npu_input_grad) + + def test_sigmoid_backward_shape_format_fp32(self, device): + format_list = [0, 3, 4, 29] + shape_list = [(256, 2048, 7, 7)] + shape_format = [ + [np.float32, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + input1, npu_input1 = create_common_tensor(item, 1, 100) + input2, npu_input2 = create_common_tensor(item, 1, 100) + self.cpu_op_exec(input1) + self.npu_op_exec(npu_input1) + 
self.assertRtolEqual(input_grad, npu_input_grad) + + self.cpu_op_exec(input2, False) + self.npu_op_exec(npu_input2, False) + self.assertRtolEqual(input_grad, npu_input_grad) + + +instantiate_device_type_tests(TestSigmoidBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_stack.py b/test/test_network_ops/test_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..14b32b8a681732245d49845fb42b8ea9f0843713 --- /dev/null +++ b/test/test_network_ops/test_stack.py @@ -0,0 +1,159 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestStack(TestCase): + def cpu_op_exec(self, input1, input2, dim): + cpu_output = torch.stack((input1, input2), dim) + cpu_output = cpu_output.numpy() + return cpu_output + + def npu_op_exec(self, input1, input2, dim): + output = torch.stack((input1, input2), dim) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_exec_out(self, input1, input2, dim, input3): + torch.stack((input1, input2), dim, out=input3) + output = input3.numpy() + return output + + def npu_op_exec_out(self, input1, input2, dim, input3): + torch.stack((input1, input2), dim, out=input3) + output = input3.to("cpu") + output = output.numpy() + return output + + def npu_output_size(self, inputs, dim = 0): + shape = [] + for i in range(dim): + shape.append(inputs[0].size(i)) + shape.append(len(inputs)) + for i in range(dim, inputs[0].dim()): + shape.append(inputs[0].size(i)) + + return shape + + def stack_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 100) + shape = self.npu_output_size([npu_input1,npu_input2], item[1]) + npu_input3 = torch.ones(shape, dtype = cpu_input1.dtype).npu() + cpu_input3 = torch.ones(shape, dtype = cpu_input1.dtype) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + cpu_input2 = cpu_input2.to(torch.float32) + cpu_input3 = cpu_input3.to(torch.float32) + + cpu_output = self.cpu_op_exec(cpu_input1, cpu_input2, item[1]) + npu_output = self.npu_op_exec(npu_input1, npu_input2, item[1]) + cpu_output_out = self.cpu_op_exec_out(cpu_input1, cpu_input2, item[1], cpu_input3) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, item[1], npu_input3) + + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_output, npu_output_out) + + def test_stack_shape_format_fp16_1d(self, device): + format_list = [0, 3] + shape_format = [[[np.float16, i, [18]], np.random.randint(0, 1)] for i in format_list] + 
self.stack_result(shape_format) + + def test_stack_shape_format_fp16_2d(self, device): + format_list = [0, 3, 29] + shape_format = [[[np.float16, i, [5, 256]], np.random.randint(0, 2)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_fp16_3d(self, device): + format_list = [0, 3, 29] + shape_format = [[[np.float16, i, [32, 3, 3]], np.random.randint(0, 3)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_fp16_4d(self, device): + format_list = [0, 3, 29] + shape_format = [[[np.float16, i, [32, 32, 3, 3]], np.random.randint(0, 4)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_fp32_1d(self, device): + format_list = [0, 3] + shape_format = [[[np.float32, i, [18]], np.random.randint(0, 1)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_fp32_2d(self, device): + format_list = [0, 3, 29] + shape_format = [[[np.float32, i, [5, 256]], np.random.randint(0, 2)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_fp32_3d(self, device): + format_list = [0, 3, 29] + shape_format = [[[np.float32, i, [32, 3, 3]], np.random.randint(0, 3)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_fp32_4d(self, device): + format_list = [0, 3, 29] + shape_format = [[[np.float32, i, [32, 32, 3, 3]], np.random.randint(0, 4)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_int32_1d(self, device): + format_list = [0] + shape_format = [[[np.int32, i, [18]], np.random.randint(0, 1)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_int32_2d(self, device): + format_list = [0] + shape_format = [[[np.int32, i, [5, 256]], np.random.randint(0, 2)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_int32_3d(self, device): + format_list = [0] + shape_format = [[[np.int32, i, [32, 3, 3]], np.random.randint(0, 3)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_shape_format_int32_4d(self, device): + format_list = [-1] + shape_format = [[[np.int32, i, [32, 32, 3, 3]], np.random.randint(0, 4)] for i in format_list] + self.stack_result(shape_format) + + def test_stack_size_dim(self, device): + def cpu_op_exec(input1): + output = torch.stack((input1, input1, input1, input1, input1, input1, input1, input1, input1)) + return output.numpy() + + def npu_op_exec(input1): + output = torch.stack((input1, input1, input1, input1, input1, input1, input1, input1, input1)) + output = output.to("cpu") + return output.numpy() + shape_format = [ + [[np.int32, 0, ()]], + [[np.float32, 0, ()]], + [[np.float16, 0, ()]], + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 1, 100) + cpu_output = cpu_op_exec(cpu_input1) + npu_output = npu_op_exec(npu_input1) + self.assertRtolEqual(cpu_output, npu_output) + + +instantiate_device_type_tests(TestStack, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_zero.py b/test/test_network_ops/test_zero.py new file mode 100644 index 0000000000000000000000000000000000000000..a8466832968c82af2a789dd09bda8d29f641ae41 --- /dev/null +++ b/test/test_network_ops/test_zero.py @@ -0,0 +1,109 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. 
+# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestZero(TestCase): + def cpu_op_exec(self, input1): + torch.zero_(input1) + output = input1.numpy() + return output + + def npu_op_exec(self, input1): + torch.zero_(input1) + output = input1.to("cpu") + output = output.numpy() + return output + + def zero_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + + cpu_output = self.cpu_op_exec(cpu_input1) + npu_output = self.npu_op_exec(npu_input1) + + cpu_output = cpu_output.astype(npu_output.dtype) + self.assertRtolEqual(cpu_output, npu_output) + + def test_zero_shape_format_fp16_1d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [18]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp16_2d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [5, 256]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp16_3d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [32, 3, 3]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp16_4d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float16, i, [64, 112, 7, 7]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp32_1d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [18]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp32_2d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [5, 256]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp32_3d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [32, 3, 3]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_fp32_4d(self, device): + format_list = [0, 3, 29] + shape_format = [[np.float32, i, [64, 112, 7, 7]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_int32_1d(self, device): + format_list = [-1, 0] + shape_format = [[np.int32, i, [18]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_int32_2d(self, device): + format_list = [-1, 0] + shape_format = [[np.int32, i, [5, 256]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_int32_3d(self, device): + format_list = [-1, 0] + shape_format = [[np.int32, i, [32, 3, 3]] for i in format_list] + self.zero_result(shape_format) + + def test_zero_shape_format_int32_4d(self, device): + format_list = [-1, 0] + shape_format = 
[[np.int32, i, [64, 112, 7, 7]] for i in format_list] + self.zero_result(shape_format) + + +instantiate_device_type_tests(TestZero, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/test/test_network_ops/test_zeros.py b/test/test_network_ops/test_zeros.py new file mode 100644 index 0000000000000000000000000000000000000000..d49d942ab4b5140eda631960790ec4a5a9cf75f9 --- /dev/null +++ b/test/test_network_ops/test_zeros.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestZeros(TestCase): + def cpu_op_exec(self, input1, dtype): + output = torch.zeros(input1.size(), dtype=dtype, device="cpu") + output = output.numpy() + return output + + def npu_op_exec(self, input1, dtype): + output = torch.zeros(input1.size(), dtype=dtype, device="npu") + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec_out(self, input1, input2, dtype): + torch.zeros(input1.size(), dtype=dtype, device="npu", out=input2) + output = input2.to("cpu") + output = output.numpy() + return output + + def zeros_result(self, shape_format): + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + npu_input2 = copy.deepcopy(cpu_input1) + npu_input2 = npu_input2.to(item[1]).to('npu') + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + + cpu_output = self.cpu_op_exec(cpu_input1, item[1]) + npu_output = self.npu_op_exec(npu_input1, item[1]) + npu_output_out = self.npu_op_exec_out(npu_input1, npu_input2, item[1]) + cpu_output = cpu_output.astype(npu_output.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + self.assertRtolEqual(cpu_output, npu_output_out) + + def test_zeros_shape_format_names(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float32, i, [18, 24, 8, 8]], j] for i in format_list for j in dtype_list] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100) + if cpu_input1.dtype == torch.float16: + cpu_input1 = cpu_input1.to(torch.float32) + + cpu_output = torch.zeros(cpu_input1.size(), names=('N', 'C', 'H', 'W'), dtype=item[1], device="cpu") + cpu_output = cpu_output.numpy() + npu_output = torch.zeros(cpu_input1.size(), names=('N', 'C', 'H', 'W'), dtype=item[1], device="npu") + npu_output = npu_output.to("cpu") + npu_output = npu_output.numpy() + cpu_output = cpu_output.astype(npu_output.dtype) + + self.assertRtolEqual(cpu_output, npu_output) + + def test_zeros_shape_format_fp16_1d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float16, 
i, [18]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp16_2d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float16, i, [5, 256]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp16_3d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float16, i, [32, 3, 3]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp16_4d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float16, i, [64, 112, 7, 7]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp32_1d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float32, i, [18]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp32_2d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float32, i, [5, 256]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp32_3d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float32, i, [32, 3, 3]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_fp32_4d(self, device): + format_list = [0, 3, 29] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.float32, i, [64, 112, 7, 7]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_int32_1d(self, device): + format_list = [-1, 0] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.int32, i, [18]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_int32_2d(self, device): + format_list = [-1, 0] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.int32, i, [5, 256]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_int32_3d(self, device): + format_list = [-1, 0] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.int32, i, [32, 3, 3]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + def test_zeros_shape_format_int32_4d(self, device): + format_list = [-1, 0] + dtype_list = [torch.float16, torch.float32, torch.int32] + shape_format = [[[np.int32, i, [64, 112, 7, 7]], j] for i in format_list for j in dtype_list] + self.zeros_result(shape_format) + + +instantiate_device_type_tests(TestZeros, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/csrc/aten/ops/EqKernelNpu.cpp b/torch_npu/csrc/aten/ops/EqKernelNpu.cpp index fa27bc376a46b8306ca6d9b5974dea18bfb83658..0b05e8361dd9b7288cfb4e8f39594168a9cf2f01 100644 --- a/torch_npu/csrc/aten/ops/EqKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/EqKernelNpu.cpp @@ -123,7 +123,7 @@ namespace at_npu return result; } - at::Tensor &eq_npu_(at::Tensor &self, const at::Tensor &other) + at::Tensor& 
NPUNativeFunctions::eq_(at::Tensor &self, const at::Tensor &other)
   {
     OpPreparation::CastBackToOriFormat(self);
     at::Tensor formatCastOfOther = OpPreparation::CastBackToOriFormat(other);
@@ -152,7 +152,7 @@ namespace at_npu
     return self;
   }
 
-  at::Tensor &eq_npu_(at::Tensor &self, at::Scalar other)
+  at::Tensor& NPUNativeFunctions::eq_(at::Tensor &self, at::Scalar other)
   {
     OpPreparation::CastBackToOriFormat(self);
     c10::SmallVector<at::Tensor, N_SIZE> inputs = {self};
diff --git a/torch_npu/csrc/aten/ops/FloorDivideKernelNpu.cpp b/torch_npu/csrc/aten/ops/FloorDivideKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..0e8bc5d1c2709e0412e675a7d4e1d0629d12f9db
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/FloorDivideKernelNpu.cpp
@@ -0,0 +1,152 @@
+// Copyright (c) 2022 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& floor_divide_out_npu_nocheck(at::Tensor& result, const at::Tensor& self, at::Scalar other) {
+  OpCommand cmd;
+  cmd.Name("FloorDiv")
+      .Input(self)
+      .Input(other, self.scalar_type())
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& floor_divide_out_scalar_npu(const at::Tensor& self, at::Scalar other, at::Tensor& result) {
+  at::Tensor formatCastOfSelf = OpPreparation::CastBackToOriFormat(self);
+  auto outputSize = formatCastOfSelf.sizes();
+  OpPreparation::CheckOut(
+      {self},
+      result,
+      ACL_FORMAT_ND,
+      result.scalar_type(),
+      outputSize);
+
+  floor_divide_out_npu_nocheck(result, formatCastOfSelf, other);
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::floor_divide_out(const at::Tensor& self, const at::Tensor& other, at::Tensor& result) {
+  at::Tensor formatCastOfSelf = OpPreparation::CastBackToOriFormat(self);
+  auto outputSize = formatCastOfSelf.sizes();
+  OpPreparation::CheckOut(
+      {self, other},
+      result,
+      ACL_FORMAT_ND,
+      result.scalar_type(),
+      outputSize);
+  // executing the NPU operator
+  if (other.dim() == 0) {
+    floor_divide_out_npu_nocheck(result, self, other.item());
+  } else {
+    OpCommand cmd;
+    cmd.Name("FloorDiv")
+        .Input(self)
+        .Input(other)
+        .Output(result)
+        .Run();
+  }
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::floor_divide(const at::Tensor& self, const at::Tensor& other) {
+  at::Tensor selfCast = self;
+  if (self.dtype() == at::ScalarType::Bool) {
+    selfCast = selfCast.to(at::ScalarType::Float);
+  }
+  at::Tensor otherCast = other;
+  if (other.scalar_type() == at::ScalarType::Double) {
+    otherCast = otherCast.to(at::ScalarType::Float);
+  }
+  if (other.scalar_type() == at::ScalarType::Long) {
+    otherCast = otherCast.to(at::ScalarType::Int);
+  }
+
+  // calculate the output size
+  bool isSelfWrapped = CalcuOpUtil::is_scalar_wrapped_to_tensor(selfCast);
+  at::Tensor outputTensor = isSelfWrapped ? otherCast : selfCast;
+
+  auto outputSize = broadcast_ops_npu_output_size(selfCast, otherCast);
+
+  // construct the output tensor of the NPU
+
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(
+      outputSize,
+      outputTensor.options(),
+      CalcuOpUtil::get_tensor_npu_format(selfCast));
+
+  // calculate the output result of the NPU
+  NPUNativeFunctions::floor_divide_out(selfCast, otherCast, result);
+
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::floor_divide(const at::Tensor& self, at::Scalar other) {
+
+  // calculate the output size
+  auto outputSize = input_same_output_size(self);
+
+  // construct the output tensor of the NPU
+  at::Tensor result = OpPreparation::ApplyTensor(self, outputSize);
+
+  // calculate the output result of the NPU
+  floor_divide_out_scalar_npu(self, other, result);
+
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::floor_divide_(at::Tensor& self, const at::Tensor& other) {
+  at::Tensor otherCast = other;
+  if (other.scalar_type() == at::ScalarType::Double) {
+    otherCast = otherCast.to(at::ScalarType::Float);
+  }
+  if (other.scalar_type() == at::ScalarType::Long) {
+    otherCast = otherCast.to(at::ScalarType::Int);
+  }
+  c10::SmallVector<at::Tensor, N_SIZE> inputs = {self, otherCast};
+  c10::SmallVector<at::Tensor, N_SIZE> outputs = {self};
+  CalcuOpUtil::check_memory_over_laps(inputs, outputs);
+
+  if (!NpuUtils::check_match(&self)) {
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(self);
+    at::Tensor result = NPUNativeFunctions::floor_divide_out(contiguousSelf, otherCast, contiguousSelf);
+    NpuUtils::format_fresh_view(self, result);
+  } else {
+    NPUNativeFunctions::floor_divide_out(self, otherCast, self);
+  }
+
+  return self;
+}
+
+at::Tensor& NPUNativeFunctions::floor_divide_(at::Tensor& self, at::Scalar other) {
+  if (!NpuUtils::check_match(&self)) {
+    at::Tensor contiguousSelf = NpuUtils::format_contiguous(self);
+    floor_divide_out_scalar_npu(contiguousSelf, other, contiguousSelf);
+    NpuUtils::format_fresh_view(self, contiguousSelf);
+  } else {
+    floor_divide_out_scalar_npu(self, other, self);
+  }
+  return self;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/SigmoidBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/SigmoidBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fdd644c05bd87f93b863cae8a8b90d5628e0a274
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/SigmoidBackwardKernelNpu.cpp
@@ -0,0 +1,56 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& sigmoid_backward_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& grad_output, + const at::Tensor& output) { + auto unified_result = OpPreparation::binary_op_check(result, output, grad_output, true); + OpCommand cmd; + cmd.Name("SigmoidGrad") + .Expect(unified_result) + .Input(output) + .Input(grad_output) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::sigmoid_backward_out( + const at::Tensor& grad_output, + const at::Tensor& output, + at::Tensor& result) { + OpPreparation::CheckOut({grad_output, output}, result, grad_output); + sigmoid_backward_out_npu_nocheck(result, grad_output, output); + return result; +} + +at::Tensor NPUNativeFunctions::sigmoid_backward(const at::Tensor& grad_output, const at::Tensor& output) { + at::Tensor grad_input = OpPreparation::ApplyTensor(grad_output); + sigmoid_backward_out_npu_nocheck(grad_input, grad_output, output); + + return grad_input; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/SigmoidKernelNpu.cpp b/torch_npu/csrc/aten/ops/SigmoidKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..19f6ffb5e1f91d826fcbc4dfc1755e0098ae4bb4 --- /dev/null +++ b/torch_npu/csrc/aten/ops/SigmoidKernelNpu.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& sigmoid_out_npu_nocheck(at::Tensor& result, const at::Tensor& self) { + OpCommand cmd; + cmd.Name("Sigmoid") + .Input(self) + .Output(result) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::sigmoid_out(const at::Tensor& self, at::Tensor& result) { + OpPreparation::CheckOut( + {self}, + result, + self); + + OpPipeWithDefinedOut pipe; + return pipe.CheckMemory({self}, {result}) + .Func([&self](at::Tensor& result){sigmoid_out_npu_nocheck(result, self);}) + .Call(result); +} + +at::Tensor& NPUNativeFunctions::sigmoid_(at::Tensor& self) { + NPUNativeFunctions::sigmoid_out(self, self); + + return self; +} + +at::Tensor NPUNativeFunctions::sigmoid(const at::Tensor& self) { + at::Tensor result = OpPreparation::ApplyTensor(self); + sigmoid_out_npu_nocheck(result, self); + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/StackKernelNpu.cpp b/torch_npu/csrc/aten/ops/StackKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..cf54ed20af77ed8bf324c1d2e5e5ea604726654a --- /dev/null +++ b/torch_npu/csrc/aten/ops/StackKernelNpu.cpp @@ -0,0 +1,85 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::SmallVector stack_npu_output_size( + at::TensorList tensors, + int64_t dim) { + dim = CalcuOpUtil::make_wrap_dim(dim, tensors[0].dim() + 1); + at::SmallVector shape; + for (int i = 0; i < dim; i++) { + shape.emplace_back(tensors[0].size(i)); + } + shape.emplace_back(tensors.size()); + for (int i = dim; i < tensors[0].dim(); i++) { + shape.emplace_back(tensors[0].size(i)); + } + + return shape; +} + +at::Tensor& stack_out_npu_nocheck(at::TensorList tensors, int64_t dim, at::Tensor& result) { + auto inputTensors = CalcuOpUtil::ConvertTensorListToSmallVector(tensors); + + OpCommand cmd; + cmd.Name("Pack"); + for (int i = 0; i < inputTensors.size(); i++) { + string inputName = "x" + std::to_string(i); + cmd.Input(inputTensors[i],inputName); + } + cmd.Output(result) + .Attr("N", (int64_t)tensors.size()) + .Attr("axis", dim) + .Run(); + + return result; +} + +at::Tensor& NPUNativeFunctions::stack_out(at::TensorList tensors, int64_t dim, at::Tensor& result) { + auto outputSize = stack_npu_output_size(tensors, dim); + + OpPreparation::CheckOut( + {tensors[0]}, + result, + ACL_FORMAT_ND, + tensors[0].scalar_type(), + outputSize); + + stack_out_npu_nocheck(tensors, dim, result); + + return result; +} + +at::Tensor NPUNativeFunctions::stack(at::TensorList tensors, int64_t dim) { + auto outputSize = stack_npu_output_size(tensors, dim); + + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, + tensors[0].options(), + ACL_FORMAT_ND); + + stack_out_npu_nocheck(tensors, dim, result); + + return result; +} + +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/ZerosKernelNpu.cpp b/torch_npu/csrc/aten/ops/ZerosKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..173ec561023c3199d888f05c5d8bc07f129be78a --- /dev/null +++ b/torch_npu/csrc/aten/ops/ZerosKernelNpu.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +at::Tensor& NPUNativeFunctions::zeros_out(at::IntArrayRef size, at::Tensor& result) { + result.resize_(size); + return result.zero_(); +} + +at::Tensor NPUNativeFunctions::zeros(at::IntArrayRef size, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + at::TensorOptions option = option.dtype(dtype_opt) + .layout(layout_opt) + .device(device_opt) + .pinned_memory(pin_memory_opt); + at::Tensor result = OpPreparation::ApplyTensorWithFormat(size, option, ACL_FORMAT_ND); + return result.zero_(); +} + +at::Tensor NPUNativeFunctions::zeros( + at::IntArrayRef size, + c10::optional names, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt) { + return NPUNativeFunctions::zeros(size, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +} // namespace native +} // namespace at_npu \ No newline at end of file