diff --git a/test/test_network_ops/test_affine_grid_generator_backward.py b/test/test_network_ops/test_affine_grid_generator_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6159f75dd1ccd64546d1ad6dea99ffc55fb8f2 --- /dev/null +++ b/test/test_network_ops/test_affine_grid_generator_backward.py @@ -0,0 +1,73 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np +from torch.nn import functional as F + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE + +class TestAffineGridGeneratorBackward(TestCase): + def test_affine_grid_generator_backward_common_shape(self, device): + shape_list = [[100, 2, 3], [10, 2, 3]] + shape_format = [ + [np.float32, -1, j] for j in shape_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 1) + size = torch.Size((item[2][0], 2, 28, 2)) + cpu_input1.requires_grad = True + cpu_output = self.cpu_op_exec(cpu_input1, size) + npu_input1.requires_grad = True + npu_output = self.npu_op_exec(npu_input1, size) + self.assertRtolEqual(cpu_output, npu_output) + + def test_affine_grid_generator_backward_fp16(self, device): + shape_list = [[100, 2, 3], [10, 2, 3]] + shape_format = [ + [np.float16, -1, j] for j in shape_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 1) + cpu_input1 = cpu_input1.to(torch.float32) + npu_input1 = npu_input1.to(torch.float32) + size = torch.Size((item[2][0], 2, 28, 2)) + cpu_input1.requires_grad = True + cpu_output = self.cpu_op_exec(cpu_input1, size) + npu_input1.requires_grad = True + npu_output = self.npu_op_exec(npu_input1, size) + self.assertRtolEqual(cpu_output.astype(np.float16), npu_output.astype(np.float16)) + + def cpu_op_exec(self, input1, size): + out = F.affine_grid(input1, size, True) + input1.requires_grad = True + grad_output = torch.ones(out.size(), dtype=torch.float) + out.backward(gradient=grad_output) + output = input1.grad.numpy() + return output + + def npu_op_exec(self, input1, size): + input1.requires_grad = True + out = F.affine_grid(input1, size, True) + grad_output = torch.ones(out.size(), dtype=torch.float).npu() + out.backward(gradient=grad_output) + output = input1.grad.to("cpu").numpy() + return output + +instantiate_device_type_tests(TestAffineGridGeneratorBackward, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_network_ops/test_linspace.py b/test/test_network_ops/test_linspace.py new file mode 100644 index 0000000000000000000000000000000000000000..00f4d44e2d78578040cd3cfb81c97d87d2c12f0e --- /dev/null +++ b/test/test_network_ops/test_linspace.py @@ -0,0 +1,63 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +import numpy as np +from torch.nn import functional as F + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, test_2args_broadcast, create_dtype_tensor, UT_FAST_MODE + +class TestLinspace(TestCase): + def test_linspace(self, device): + shape_format = [ + [0, 100, 10, torch.float32, + torch.tensor([0.,11.111111, 22.222221, 33.333332, 44.444443, + 55.555557, 66.666664, 77.77778, 88.888885, 100.])], + [1, 100, 20, torch.int32, + torch.tensor([1, 6, 11, 16, 21, 27, 32, 37, 42, + 47, 53, 58, 63, 68, 73, 79, 84, 89, 94, 100], dtype=torch.int32)], + ] + + for item in shape_format: + cpu_output = torch.linspace(item[0], item[1], item[2], dtype=item[3], + device="cpu") + npu_output = torch.linspace(item[0], item[1], item[2], dtype=item[3], + device="npu").cpu() + benchmark15 = item[4] + self.assertRtolEqual(benchmark15, npu_output) + + def test_linspace_out(self, device): + shape_format = [ + [0, 100, 10, torch.float32, [np.float32, 0, [10]], + torch.tensor([0.,11.111111, 22.222221, 33.333332, 44.444443, + 55.555557, 66.666664, 77.77778, 88.888885, 100.])], + [1, 100, 20, torch.int32, [np.int32, 0, [20]], + torch.tensor([1, 6, 11, 16, 21, 27, 32, 37, 42, + 47, 53, 58, 63, 68, 73, 79, 84, 89, 94, 100], dtype=torch.int32)], + ] + + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item[4], 0, 10) + cpu_output = torch.linspace(item[0], item[1], item[2], out=cpu_input1, + dtype=item[3], device="cpu") + npu_output = torch.linspace(item[0], item[1], item[2], out=npu_input1, + dtype=item[3], device="npu").cpu() + benchmark15 = item[5] + self.assertRtolEqual(benchmark15, npu_output) + +instantiate_device_type_tests(TestLinspace, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/AffineGridGeneratorBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/AffineGridGeneratorBackwardKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2afe7b14737cba479ab5aca66ad4adee33f2a3ec --- /dev/null +++ b/torch_npu/csrc/aten/ops/AffineGridGeneratorBackwardKernelNpu.cpp @@ -0,0 +1,83 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +namespace{ +at::Tensor _linspace_from_neg_one(const at::Tensor& grid, int64_t num_steps, bool align_corners) { + if (num_steps <= 1) { + return at::tensor(0, grid.options()); + } + auto range = at::linspace(-1, 1, num_steps, grid.options()); + if (!align_corners && num_steps != 0) { + range = range * (num_steps - 1) / num_steps; + } + return range; +} + +at::Tensor& affine_grid_generator_backward_nocheck( + at::Tensor& result, + const at::Tensor& grad, + at::IntArrayRef size, + bool align_corners) { + at::Tensor assist = OpPreparation::ApplyTensor(grad, {size[0], size[2], size[3], 3}); + assist.select(-1, 0).copy_(_linspace_from_neg_one(grad, size[3], align_corners)); + assist.select(-1, 1).copy_(_linspace_from_neg_one(grad, size[2], align_corners).unsqueeze_(-1)); + assist.select(-1, 2).fill_(1); + AT_ASSERT(grad.sizes() == at::IntArrayRef({size[0], size[2], size[3], 2})); + + auto reassist = assist.view({size[0], size[2]*size[3], 3}).transpose(1, 2); + auto grid = grad.view({size[0], size[2]*size[3], 2}); + + OpCommand cmd; + cmd.Name("BatchMatMul") + .Input(reassist) + .Input(grid) + .Output(result) + .Attr("bias", (int64_t)0) + .Attr("adj_x1", (bool)false) + .Attr("adj_x2", (bool)false) + .Run(); + + return result; +} +} // namespace + +at::Tensor NPUNativeFunctions::affine_grid_generator_backward( + const at::Tensor& grad, + at::IntArrayRef size, + bool align_corners) { + TORCH_CHECK(size.size() == 4, "AffineGridGeneratorBackward needs 4d (spatial) input.") + + // calculate the output size + c10::SmallVector outputSize = {size[0], 3, 2}; + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat(grad, outputSize, ACL_FORMAT_ND); + + // calculate the output result of the NPU + affine_grid_generator_backward_nocheck( + result, + grad, + size, + align_corners); + auto fresult = result.transpose(1, 2); + + return fresult; +} +} // namespace native +} // namespace at_npu diff --git a/torch_npu/csrc/aten/ops/LinspaceKernelNpu.cpp b/torch_npu/csrc/aten/ops/LinspaceKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9bbe4f603a13c51754b4d9a37e98b3e443c40ed6 --- /dev/null +++ b/torch_npu/csrc/aten/ops/LinspaceKernelNpu.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { +at::Tensor linspace_assist(int64_t steps) { + c10::SmallVector assist; + assist.resize(steps); + + for (int64_t i = 0; i < steps; i++) { + assist[i] = (float)(i); + } + at::Tensor assistTensor = + at::from_blob(assist.data(), {steps}, dtype(at::ScalarType::Float)); + return CalcuOpUtil::copy_tensor_host_to_device(assistTensor); +} + +at::Tensor& NPUNativeFunctions::linspace_out(at::Scalar start, at::Scalar end, c10::optional step, at::Tensor& result) { + int64_t steps = step.has_value()? step.value():-65530; + TORCH_CHECK(steps >= 0, "number of steps must be non-negative"); + + if (result.numel() != steps) { + result.resize_({steps}); + } + at::Tensor r = result.is_contiguous() ? result : result.contiguous(); + r = r.to(at::kFloat); + if(steps == 0){ + // skip + } else if (steps == 1) { + r.fill_(start); + } else { + c10::SmallVector sizeVec = {steps}; + OpCommand cmd; + cmd.Name("LinSpace") + .Input(start, at::ScalarType::Float) + .Input(end, at::ScalarType::Float) + .Input(sizeVec, at::ScalarType::Int) + .Output(r) + .Run(); + } + + if(r.dtype() != result.dtype()) { + r = r.to(result.dtype()); + } + + return result.copy_(r); +} + +at::Tensor NPUNativeFunctions::linspace(at::Scalar start, at::Scalar end, + c10::optional step, + c10::optional dtype_opt, + c10::optional layout_opt, + c10::optional device_opt, + c10::optional pin_memory_opt +) { + int64_t steps = step.has_value()? step.value():-65530; + TORCH_CHECK(steps >= 0, "number of steps must be non-negative"); + + auto device = device_or_default(device_opt); + at::TensorOptions option; + option = option.dtype(dtype_opt) + .layout(layout_opt) + .device(device) + .pinned_memory(pin_memory_opt); + + // construct the output tensor of the NPU + at::Tensor result = OpPreparation::ApplyTensorWithFormat({steps}, option, ACL_FORMAT_ND); + at::Tensor resultCast = result.to(at::kFloat); + + // calculate the output result of the NPU + NPUNativeFunctions::linspace_out(start, end, steps, resultCast); + + if(option.dtype() != resultCast.dtype()) { + resultCast = resultCast.to(option.dtype()); + } + + return resultCast; +} +} // namespace native +} // namespace at_npu