diff --git a/test/test_network_ops/test_format_cast.py b/test/test_network_ops/test_format_cast.py
new file mode 100644
index 0000000000000000000000000000000000000000..d87d6311037a6f92142c3ad849e97b61945c2614
--- /dev/null
+++ b/test/test_network_ops/test_format_cast.py
@@ -0,0 +1,228 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+class TestFormatCast(TestCase):
+    def create_single_npu_tensor(self, item, minvalue, maxvalue):
+        dtype = item[0]
+        format = item[1]
+        shape = item[2]
+        input1 = np.random.uniform(minvalue, maxvalue, shape).astype(dtype)
+        npu_input = torch.from_numpy(input1).to("npu")
+        if format != -1:
+            npu_input = torch_npu.npu_format_cast(npu_input, format)
+        return npu_input
+
+    def check_result(self, expectValue, retTensor):
+        if retTensor.storage().npu_format() != expectValue:
+            print("expectValue: ", expectValue, " resultValue: ", retTensor.storage().npu_format())
+            sys.exit(-1)
+
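+    # The integer format IDs used throughout these tests are assumed to follow
+    # ACL's aclFormat enum: 0=NCHW, 2=ND, 3=NC1HWC0, 4=FRACTAL_Z, 29=FRACTAL_NZ,
+    # 30=NCDHW, 32=NDC1HWC0, 33=FRACTAL_Z_3D; -1 keeps the default format.
+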
+    def test_format_cast_tensor(self, device):
+        src_shape_format = [
+            [np.float16, 0, (2, 2, 4, 4)],
+            [np.float16, 2, (2, 2, 4, 4)]
+        ]
+        dst_shape_format = [
+            [np.float16, 3, (2, 2, 4, 4)],
+            [np.float16, 4, (2, 2, 4, 4)],
+            [np.float16, 29, (2, 2, 4, 4)],
+            [np.float16, 30, (2, 2, 2, 4, 4)],
+        ]
+
+        for i in src_shape_format:
+            src_tensor = self.create_single_npu_tensor(i, 1, 5)
+            for j in dst_shape_format:
+                dst_tensor = self.create_single_npu_tensor(j, 3, 6)
+                result_tensor = torch_npu.npu_format_cast(src_tensor, dst_tensor)
+                self.check_result(dst_tensor.storage().npu_format(), result_tensor)
+
+    def test_format_cast(self, device):
+        shape_format = [np.float16, -1, (2, 2, 4, 4)]
+        npu_tensor = self.create_single_npu_tensor(shape_format, 1, 5)
+        # print("org format: ", npu_tensor.storage().npu_format())
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 3)
+        self.check_result(3, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 3)
+        self.check_result(3, npu_tensor)
+        # note: casting from a private format back to ND yields the base format NCHW (0) here
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 4)
+        self.check_result(4, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 29)
+        self.check_result(29, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 4)
+        self.check_result(4, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 29)
+        self.check_result(29, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+
+        npu_tensor = npu_tensor.view(2, 2, 2, 2, 4).clone()
+        # print("five org format: ", npu_tensor.storage().npu_format())
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 33)
+        self.check_result(33, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 33)
+        self.check_result(33, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 32)
+        self.check_result(32, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 32)
+        self.check_result(32, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+
+    def test_format_cast_tensor(self, device):
+        shape_format = [
+            [np.float16, -1, (2, 2, 4, 4)],
+            [np.float16, 0, (2, 2, 4, 4)],
+            [np.float16, 2, (2, 2, 4, 4)],
+            [np.float16, 3, (2, 2, 4, 4)],
+            [np.float16, 4, (2, 2, 4, 4)],
+            [np.float16, 29, (2, 2, 4, 4)],
+            [np.float16, 30, (2, 2, 2, 4, 4)],
+            [np.float16, 32, (2, 2, 2, 4, 4)]
+        ]
+
+        for i in shape_format:
+            src_tensor = self.create_single_npu_tensor(i, 1, 5)
+            for j in shape_format:
+                dst_tensor = self.create_single_npu_tensor(j, 3, 6)
+                result_tensor = torch_npu.npu_format_cast(src_tensor, dst_tensor)
+                self.check_result(dst_tensor.storage().npu_format(), result_tensor)
+
+    def test_format_cast_(self, device):
+        shape_format = [np.float16, -1, (2, 2, 4, 4)]
+        npu_tensor = self.create_single_npu_tensor(shape_format, 1, 5)
+        # print("org format: ", npu_tensor.storage().npu_format())
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 3)
+        self.check_result(3, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 3)
+        self.check_result(3, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 4)
+        self.check_result(4, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 29)
+        self.check_result(29, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 4)
+        self.check_result(4, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 29)
+        self.check_result(29, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 0)
+        self.check_result(0, npu_tensor)
+
+        npu_tensor = npu_tensor.view(2, 2, 2, 2, 4).clone()
+        # print("five org format: ", npu_tensor.storage().npu_format())
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 33)
+        self.check_result(33, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 33)
+        self.check_result(33, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 32)
+        self.check_result(32, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 30)
+        self.check_result(30, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 32)
+        self.check_result(32, npu_tensor)
+        npu_tensor = torch_npu.npu_format_cast_(npu_tensor, 2)
+        self.check_result(2, npu_tensor)
+
+    # UT for view + transdata scene
+    def test_format_cast_val(self, device):
+        shape_format = [np.float32, -1, (10, 4)]
+        npu_tensor = self.create_single_npu_tensor(shape_format, 1, 5)
+        npu_tensor = torch_npu.npu_format_cast(npu_tensor, 3)
+        a = torch_npu.npu_format_cast(npu_tensor[1], 0).contiguous()
+        b = torch_npu.npu_format_cast(npu_tensor, 0)[1].contiguous()
+        a = a.to("cpu")
+        b = b.to("cpu")
+        self.assertRtolEqual(a, b)
+
+instantiate_device_type_tests(TestFormatCast, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp
index c518156b0a4c281420fab2e229102037e8292ad6..082f6656b04c0a35247c0da1ab5756114773ea74 100644
--- a/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp
+++ b/torch_npu/csrc/aten/common/FormatCastKernelNpu.cpp
@@ -88,6 +88,15 @@ at::Tensor NPUNativeFunctions::npu_format_cast(
   return dst;
 }
 
+// convert self to dst's format, write the result into a new result tensor
+at::Tensor NPUNativeFunctions::npu_format_cast(
+    const at::Tensor& src,
+    const at::Tensor& dst) {
+  c10::NPUStorageDesc dst_desc = dst.storage().unsafeGetStorageImpl()->npu_desc_;
+  int64_t dst_format = dst_desc.npu_format_;
+  return NPUNativeFunctions::npu_format_cast(src, dst_format);
+}
+
 // conver self to acl_format, write the result into self
 at::Tensor& NPUNativeFunctions::npu_format_cast_(
     at::Tensor& src,
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 890b95c7915f06485ab2656d24ca4253b72848ed..b7283ff6e27ca3710ca76febb564c6cb98da4007 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1900,8 +1900,9 @@ custom:
  - func: npu_conv3d.out(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, *, Tensor(a!) out) -> Tensor(a!)
  - func: npu_conv3d_backward(Tensor input, Tensor grad, Tensor weight, int[] stride, int[] padding, int[] dilation, int groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)
  - func: npu_format_cast(Tensor self, int acl_format) -> Tensor
+ - func: npu_format_cast.Tensor(Tensor self, Tensor dst) -> Tensor
  - func: npu_format_cast_.acl_format(Tensor(a!) self, int acl_format) -> Tensor(a!)
- - func: npu_format_cast_(Tensor(a!) self, Tensor src) -> Tensor(a!)
+ - func: npu_format_cast_(Tensor(a!) self, Tensor dst) -> Tensor(a!)
  - func: empty_with_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int acl_format=2) -> Tensor
  - func: empty_with_format.names(int[] size, Dimname[]? names, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, int acl_format=2) -> Tensor
  - func: copy_memory_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)
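A minimal usage sketch of the npu_format_cast Tensor overload added above (assumptions: a torch_npu build with an NPU device available, and format ID 29 corresponding to FRACTAL_NZ; shapes are illustrative):

    import torch
    import torch_npu
    import numpy as np

    # dst carries the target format in its NPU storage descriptor
    dst = torch.from_numpy(np.ones((2, 2, 4, 4)).astype(np.float16)).to("npu")
    dst = torch_npu.npu_format_cast(dst, 29)

    # src is converted to dst's npu_format and returned as a new tensor
    src = torch.from_numpy(np.ones((2, 2, 4, 4)).astype(np.float16)).to("npu")
    out = torch_npu.npu_format_cast(src, dst)
    assert out.storage().npu_format() == dst.storage().npu_format()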