diff --git a/test/test_network_ops/test_full.py b/test/test_network_ops/test_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..c779acf22a95f33d2e3c1a020b15746edde4b198
--- /dev/null
+++ b/test/test_network_ops/test_full.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor
+
+class TestFull(TestCase):
+    def test_full_shape_format_fp16(self, device):
+        format_list = [0, 3]
+        dtype_list = [torch.float32, torch.float16, torch.int32]
+        shape_list = [[5, 8], [2, 4, 1, 1], [16]]
+        shape_format = [[[np.float16, i, j], k]
+                        for i in format_list for j in shape_list for k in dtype_list]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 0, 100)
+            cpu_output = torch.full(cpu_input.size(), 6, dtype=item[1], device="cpu")
+            cpu_output = cpu_output.numpy()
+            npu_output = torch.full(npu_input.size(), 6, dtype=item[1], device="npu")
+            npu_output = npu_output.to("cpu")
+            npu_output = npu_output.numpy()
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_full_shape_format_fp32(self, device):
+        format_list = [0, 3]
+        dtype_list = [torch.float32, torch.float16, torch.int32]
+        shape_list = [[5, 8], [2, 4, 1, 1], [16]]
+        shape_format = [[[np.float32, i, j], k]
+                        for i in format_list for j in shape_list for k in dtype_list]
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], 0, 100)
+            cpu_output = torch.full(cpu_input.size(), 6, dtype=item[1], device="cpu")
+            cpu_output = cpu_output.numpy()
+            npu_output = torch.full(npu_input.size(), 6, dtype=item[1], device="npu")
+            npu_output = npu_output.to("cpu")
+            npu_output = npu_output.numpy()
+            self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_full_out(self, device):
+
+        shape_format = [[[np.float32, 0, [5, 8]], torch.float32]]
+        for item in shape_format:
+            cpu_input1, npu_input1 = create_common_tensor(item[0], 0, 100)
+            cpu_input2, npu_input2 = create_common_tensor(item[0], 0, 100)
+            cpu_output = torch.full(cpu_input1.size(), 6, dtype=item[1], device="cpu")
+            cpu_output = cpu_output.numpy()
+            npu_output = torch.full(npu_input1.size(), 6, dtype=item[1], out=npu_input2, device="npu")
+            npu_output = npu_output.to("cpu")
+            npu_output = npu_output.numpy()
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestFull, globals(), except_for="cpu")
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp
index 98d9adbfc13135c284737c65374773158981a610..1b84630f90e9868a40152252d80b30d232bfd961 100644
--- a/torch_npu/csrc/aten/common/TensorFactories.cpp
+++ b/torch_npu/csrc/aten/common/TensorFactories.cpp
@@ -698,11 +698,20 @@ namespace at_npu
       }
     }
 
-    at::Tensor full_npu(
+    at::Tensor NPUNativeFunctions::full(
         c10::IntArrayRef size,
         c10::Scalar fill_value,
-        const c10::TensorOptions &options)
+        c10::optional<at::ScalarType> dtype_opt,
+        c10::optional<c10::Layout> layout_opt,
+        c10::optional<c10::Device> device_opt,
+        c10::optional<bool> pin_memory_opt)
     {
+      c10::TensorOptions options;
+      auto device = device_or_default(device_opt);
+      options = options.dtype(dtype_opt)
+                    .device(device)
+                    .layout(layout_opt)
+                    .pinned_memory(pin_memory_opt);
       TORCH_CHECK(
           options.layout() != at::kSparse,
           "full(...) is not implemented for sparse layout");
diff --git a/torch_npu/csrc/aten/ops/FullKernelNpu.cpp b/torch_npu/csrc/aten/ops/FullKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..8f644cbc0f9e59dd62593602c1b3505e305dfc07
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/FullKernelNpu.cpp
@@ -0,0 +1,51 @@
+// Copyright (c) 2020, Huawei Technologies.All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& NPUNativeFunctions::full_out(at::IntArrayRef size, at::Scalar fill_value, at::Tensor& out) {
+  OpPreparation::CheckOut(
+      {},
+      out,
+      out,
+      size);
+  fill_(out, fill_value);
+  return out;
+}
+
+at::Tensor NPUNativeFunctions::full(
+    at::IntArrayRef size,
+    at::Scalar fill_value,
+    c10::optional<at::DimnameList> names,
+    c10::optional<at::ScalarType> dtype_opt,
+    c10::optional<c10::Layout> layout_opt,
+    c10::optional<c10::Device> device_opt,
+    c10::optional<bool> pin_memory_opt) {
+  c10::TensorOptions option;
+  auto device = device_or_default(device_opt);
+  option = option.dtype(dtype_opt)
+      .device(device)
+      .layout(layout_opt)
+      .pinned_memory(pin_memory_opt);
+  at::Tensor result = OpPreparation::ApplyTensorWithSizes(size, option);
+  return result.fill_(fill_value);
+}
+
+}
+}