From 3df9635485988610f1d04be3fef662adcae14c08 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?=
Date: Wed, 9 Feb 2022 14:15:26 +0800
Subject: [PATCH 1/3] add UTs for trans-contiguous-single_view

---
 ...est_single_broadcast_copy_to_contiguous.py |  56 +++++
 .../test_single_permute_copy_to_contiguous.py |  54 +++++
 .../test_single_reshape_copy_to_contiguous.py | 198 ++++++++++++++++++
 .../test_single_slice_copy_to_contiguous.py   | 151 +++++++++++++
 .../framework/contiguous/indexing_opt.cpp     |   4 +-
 .../csrc/framework/contiguous/permute_opt.cpp |   2 +-
 .../csrc/framework/contiguous/slice_opt.cpp   |   2 +-
 torch_npu/testing/util_test.py                |  33 ++-
 8 files changed, 494 insertions(+), 6 deletions(-)
 create mode 100644 test/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py
 create mode 100644 test/test_trans_contiguous/test_single_permute_copy_to_contiguous.py
 create mode 100644 test/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py
 create mode 100644 test/test_trans_contiguous/test_single_slice_copy_to_contiguous.py

diff --git a/test/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py b/test/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py
new file mode 100644
index 0000000000..454263b0ae
--- /dev/null
+++ b/test/test_trans_contiguous/test_single_broadcast_copy_to_contiguous.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor_for_broadcast, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_broadcast_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32, np.int32, np.int8, np.uint8]
+        format_list = [-1]
+        shape_list = [
+            [[1], [5]],
+            [[1, 2], [3, 2]],
+            [[1, 2, 1], [1, 2, 3]],
+            [[1, 2, 1, 3], [4, 2, 5, 3]],
+            [[1, 3], [1, 1, 4, 3]],
+            [[1, 3], [2, 1, 4, 3]],
+            [[1, 3], [1, 2, 4, 3]],
+            [[3, 1], [2, 1, 3, 1]],
+            [[3, 1], [1, 2, 3, 1]],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor_for_broadcast(item, 0, 100)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.expand(item[2][1]).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuBroadcast'], prof), True, "npuBroadcast is not called!")
+            cpu_out1 = cpu_input.expand(item[2][1]).contiguous()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_trans_contiguous/test_single_permute_copy_to_contiguous.py b/test/test_trans_contiguous/test_single_permute_copy_to_contiguous.py
new file mode 100644
index 0000000000..da422836cb
--- /dev/null
+++ b/test/test_trans_contiguous/test_single_permute_copy_to_contiguous.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_permute_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [[2, 6, 9, 4]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.permute(1, 0, 2, 3).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuTranspose'], prof), True, "npuTranspose is not called!")
+
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input.permute(2, 3, 0, 1).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuTranspose'], prof), True, "npuTranspose is not called!")
+
+            cpu_out1 = cpu_input.permute(1, 0, 2, 3).contiguous()
+            cpu_out2 = cpu_input.permute(2, 3, 0, 1).contiguous()
+
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py b/test/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py
new file mode 100644
index 0000000000..2a9322e9c8
--- /dev/null
+++ b/test/test_trans_contiguous/test_single_reshape_copy_to_contiguous.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_view_copy(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [0, 3, 29]
+        shape_list = [
+            # No padding for NZ format
+            [2, 3, 16, 16],
+            # Padding for NZ format
+            [2, 3, 15, 16],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # Direct d2d copy without transdata (View_d2dCopyAsync)
+            # case 1: base format
+            match_case1 = (item[1] == 0)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.view(1, 6, npu_input.size(2), npu_input.size(3)).clone()
+            # case 2: the key axis remains unchanged for NZ format
+            match_case2 = (item[1] == 29)
+            if match_case1 or match_case2:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+            cpu_out1 = cpu_input.view(1, 6, cpu_input.size(2), cpu_input.size(3)).clone()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            # The key axis changes for NZ format
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input.view(1, 6, npu_input.size(2)*npu_input.size(3), 1).clone()
+            if match_case1:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+            cpu_out2 = cpu_input.view(1, 6, cpu_input.size(2)*cpu_input.size(3), 1).clone()
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+    def test_unsqueeze_copy(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [2, 3, 29]
+        shape_list = [
+            [3, 16, 16],
+            [3, 15, 16],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for i in range(3):
+            for item in shape_format:
+                cpu_input, npu_input = create_common_tensor(item, 0, 100)
+                # Direct d2d copy without transdata (View_d2dCopyAsync)
+                # case 1: base format
+                match_case1 = (item[1] == 2)
+                # case 2: the key axis remains unchanged for NZ format
+                match_case2 = (item[1] == 29 and i < 2)
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out = npu_input.unsqueeze(i).clone()
+                if match_case1 or match_case2:
+                    self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+                else:
+                    self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+                cpu_out = cpu_input.unsqueeze(i).clone()
+                self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+    def test_flatten_copy(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [0, 3, 29]
+        shape_list = [
+            [2, 3, 16, 16],
+            [2, 3, 16, 15],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out = torch.flatten(npu_input, 0, 1).clone()
+            if item[1] == 3:
+                # Using d2d copy with transdata (d2dCopyAsync)
+                self.assertEqual(check_operators_in_prof(['d2dCopyAsync'], prof), True, "d2dCopyAsync is not called!")
+            else:
+                # Direct d2d copy without transdata (View_d2dCopyAsync)
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+
+            cpu_out = torch.flatten(cpu_input, 0, 1).clone()
+            self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+    def test_narrow_at_first_axis_copy(self, device):
+        # In this case, slicing at the first dim keeps the tensor contiguous even with a storage offset.
+        dtype_list = [np.float16, np.float32]
+        format_list = [2, 3, 29]
+        shape_list = [
+            [20, 16, 16],
+            [20, 16, 15],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # Direct d2d copy without transdata (View_d2dCopyAsync)
+            # The key axis remains unchanged for NZ format in all cases.
+            # case 1: base format
+            match_case1 = (item[1] == 2)
+            # case 2: NZ format but no padding
+            match_case2 = (item[1] == 29 and item[2] == [20, 16, 16])
+
+            # contiguous and no offset
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input[:10, :, :].clone()
+            # case 3: NZ format with padding but no offset
+            match_case3 = (item[1] == 29)
+            if match_case1 or match_case2 or match_case3:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            cpu_out1 = cpu_input[:10, :, :].clone()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            # contiguous but has offset
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input[1:10, :, :].clone()
+            # case 3 no longer holds once there is a storage offset
+            match_case3 = False
+            if match_case1 or match_case2 or match_case3:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            cpu_out2 = cpu_input[1:10, :, :].clone()
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+    def test_select_at_first_axis_to_single_element_tensor_copy(self, device):
+        dtype_list = [torch.float32]
+        format_list = [2, 3, 29]
+        shape_format = [
+            [i, j] for i in dtype_list for j in format_list
+        ]
+
+        for item in shape_format:
+            cpu_input = torch.tensor([1.0]).to(item[0])
+            npu_input = cpu_input.npu().npu_format_cast(item[1])
+
+            match_case = (item[1] == 2)
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input[0].clone()
+            if match_case:
+                self.assertEqual(check_operators_in_prof(['View_d2dCopyAsync'], prof), True, "View_d2dCopyAsync is not called!")
+            else:
+                self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            cpu_out1 = cpu_input[0].clone()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input[0] + 1
+            if match_case:
+                self.assertEqual(check_operators_in_prof(['memory_repoint'], prof), True, "memory_repoint is not called!")
+            else:
+                # refresh storage desc after transdata
+                self.assertEqual(check_operators_in_prof(['Identity'], prof), True, "Identity is not called!")
+            cpu_out2 = cpu_input[0] + 1
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_trans_contiguous/test_single_slice_copy_to_contiguous.py b/test/test_trans_contiguous/test_single_slice_copy_to_contiguous.py
new file mode 100644
index 0000000000..624e04cf56
--- /dev/null
+++ b/test/test_trans_contiguous/test_single_slice_copy_to_contiguous.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["PTCOPY_ENABLE"] = "1"
+
+# Optimized view ops include transpose, permute, narrow, strideslice, select and unfold.
+class SingleViewCopyToContiguous(TestCase):
+    def test_narrow_copy_contiguous(self, device):
+        # AssertionError: required dtype in [np.bool, np.int32, np.float16, np.float32, np.int8, np.uint8, np.int64]
+        # However, considering the dtypes that Transdata supports, only np.float16 and np.float32 are tested.
+        dtype_list = [np.float16, np.float32]
+        format_list_4D = [0, 3, 29, 4]
+        shape_list_4D = [[2, 32, 16, 20]]
+        format_list_5D = [30, 32, 33]
+        shape_list_5D = [[2, 32, 16, 20, 15]]
+        shape_format_4D = [
+            [i, j, k] for i in dtype_list for j in format_list_4D for k in shape_list_4D
+        ]
+        shape_format_5D = [
+            [i, j, k] for i in dtype_list for j in format_list_5D for k in shape_list_5D
+        ]
+        shape_format = shape_format_4D + shape_format_5D
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # For narrow with step=1: narrowing at the first axis already yields a contiguous
+            # tensor, so the non-contiguous cases below narrow at the later axes.
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input[:, :16, :, :].contiguous()
+            self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input[:, :, 1:16, :].contiguous()
+            self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out3 = npu_input[:, :, :, 2:16].contiguous()
+            self.assertEqual(check_operators_in_prof(['narrow_npuSlice'], prof), True, "narrow_npuSlice is not called!")
+
+            cpu_out1 = cpu_input[:, :16, :, :].contiguous()
+            cpu_out2 = cpu_input[:, :, 1:16, :].contiguous()
+            cpu_out3 = cpu_input[:, :, :, 2:16].contiguous()
+            if npu_input.dim() == 5:
+                cpu_out4 = cpu_input[:, :, :, :, 3:10].contiguous()
+                npu_out4 = npu_input[:, :, :, :, 3:10].contiguous()
+                self.assertRtolEqual(npu_out4.to("cpu").numpy(), cpu_out4.numpy())
+
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+            self.assertRtolEqual(npu_out3.to("cpu").numpy(), cpu_out3.numpy())
+
+    def test_strideslice_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32, np.int8, np.int32, np.uint8, np.bool]
+        format_list = [-1]
+        shape_list = [[10, 32, 16, 9], [10, 32, 16, 9, 10]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # indexing with step > 1 maps to stridedSlice
+            if cpu_input.dim() == 4:
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out1 = npu_input[::2].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out2 = npu_input[:, 1:17:4].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out3 = npu_input[:, :, 2:16:5].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    # stridedSlice does not support slicing at the last dim
+                    npu_out4 = npu_input[:, :, :, 3:9:2].contiguous()
+                self.assertEqual(check_operators_in_prof(['d2dCopyWithPTCopy'], prof), True, "d2dCopyWithPTCopy is not called!")
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out5 = npu_input[::2, 1:17:4, 2:16:5, :].contiguous()
+                self.assertEqual(check_operators_in_prof(['npuStridedSlice'], prof), True, "npuStridedSlice is not called!")
+
+                cpu_out1 = cpu_input[::2].contiguous()
+                cpu_out2 = cpu_input[:, 1:17:4].contiguous()
+                cpu_out3 = cpu_input[:, :, 2:16:5].contiguous()
+                cpu_out4 = cpu_input[:, :, :, 3:9:2].contiguous()
+                # strideslice at several axes simultaneously
+                cpu_out5 = cpu_input[::2, 1:17:4, 2:16:5, :].contiguous()
+
+                self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+                self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+                self.assertRtolEqual(npu_out3.to("cpu").numpy(), cpu_out3.numpy())
+                self.assertRtolEqual(npu_out4.to("cpu").numpy(), cpu_out4.numpy())
+                self.assertRtolEqual(npu_out5.to("cpu").numpy(), cpu_out5.numpy())
+            if cpu_input.dim() == 5:
+                cpu_out6 = cpu_input[:, :, :, :, 1:7:3].contiguous()
+                npu_out6 = npu_input[:, :, :, :, 1:7:3].contiguous()
+                self.assertRtolEqual(npu_out6.to("cpu").numpy(), cpu_out6.numpy())
+
+    def test_select_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [[2, 32, 16, 9], [2, 32, 16, 9, 10]]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            for dim in range(1, len(item[2])):
+                with torch.autograd.profiler.profile(use_npu=True) as prof:
+                    npu_out = npu_input.select(dim, 1).contiguous()
+                self.assertEqual(check_operators_in_prof(['select_npuStridedSlice'], prof), True, "select_npuStridedSlice is not called!")
+                cpu_out = cpu_input.select(dim, 1).contiguous()
+                self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+    def test_span_axis_strideslice_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        # shape_list[0] is the input shape; shape_list[1] holds the as_strided size, stride and offset
+        shape_list = [[32, 8, 2], [(8, 6, 2), (5, 4, 1), 1]]
+        shape_format = [
+            [i, j, shape_list[0]] for i in dtype_list for j in format_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # npuStridedSlice does not support span-axis strideslice, so this case cannot be optimized
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out = torch.as_strided(npu_input, shape_list[1][0], shape_list[1][1], shape_list[1][2]).contiguous()
+            self.assertEqual(check_operators_in_prof(['d2dCopyWithPTCopy'], prof, ['npuStridedSlice']), True, "d2dCopyWithPTCopy is not called or npuStridedSlice is wrongly called!")
+            cpu_out = torch.as_strided(cpu_input, shape_list[1][0], shape_list[1][1], shape_list[1][2]).contiguous()
+            self.assertRtolEqual(npu_out.to("cpu").numpy(), cpu_out.numpy())
+
+instantiate_device_type_tests(SingleViewCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/framework/contiguous/indexing_opt.cpp b/torch_npu/csrc/framework/contiguous/indexing_opt.cpp
index ff4dc0683a..345087a534 100644
--- a/torch_npu/csrc/framework/contiguous/indexing_opt.cpp
+++ b/torch_npu/csrc/framework/contiguous/indexing_opt.cpp
@@ -32,7 +32,7 @@ namespace at_npu
 
     if (can_use_indexing(src, start, end, step))
     {
-      RECORD_FUNCTION("npuStridedSliceD", std::vector<c10::IValue>({src}));
+      RECORD_FUNCTION("npuStridedSlice", std::vector<c10::IValue>({src}));
       indexing_to_contiguous(src, self, start, end, step);
       return true;
     }
@@ -86,7 +86,7 @@ namespace at_npu
     for (int64_t i = 0; i < src.dim(); i++)
     {
       int64_t calculate_end = start[i] + src.size(i) * step[i];
-      if (calculate_end > src.size(i))
+      if (calculate_end - step[i] > src_desc.base_sizes_[i])
       {
         // Op StrideSlice(Slice) don't support span-axis indexing(slice).
         return false;
diff --git a/torch_npu/csrc/framework/contiguous/permute_opt.cpp b/torch_npu/csrc/framework/contiguous/permute_opt.cpp
index 238017c21d..a649576cb6 100644
--- a/torch_npu/csrc/framework/contiguous/permute_opt.cpp
+++ b/torch_npu/csrc/framework/contiguous/permute_opt.cpp
@@ -33,7 +33,7 @@ namespace at_npu
     if (can_use_permute(src, perm, sizes))
     {
       // delete call and implementation, after more test
-      RECORD_FUNCTION("npuTransposeD", std::vector<c10::IValue>({src}));
+      RECORD_FUNCTION("npuTranspose", std::vector<c10::IValue>({src}));
       // create contiguous tensor for npu transpose
       at::Tensor temp_src = at::empty(sizes, src.options());
       temp_src.set_(src.storage(), temp_src.storage_offset(), temp_src.sizes(), temp_src.strides());
diff --git a/torch_npu/csrc/framework/contiguous/slice_opt.cpp b/torch_npu/csrc/framework/contiguous/slice_opt.cpp
index 0eeea31d2c..2c9d96397d 100644
--- a/torch_npu/csrc/framework/contiguous/slice_opt.cpp
+++ b/torch_npu/csrc/framework/contiguous/slice_opt.cpp
@@ -34,7 +34,7 @@ namespace at_npu
     c10::SmallVector size;
     if (can_use_slice(src, offsets, size))
     {
-      RECORD_FUNCTION("narrow_npuSliceD", std::vector<c10::IValue>({src}));
+      RECORD_FUNCTION("narrow_npuSlice", std::vector<c10::IValue>({src}));
       slice_to_contiguous(src, self, offsets, size);
       return true;
     }
diff --git a/torch_npu/testing/util_test.py b/torch_npu/testing/util_test.py
index a460af4ae3..ccb935c378 100644
--- a/torch_npu/testing/util_test.py
+++ b/torch_npu/testing/util_test.py
@@ -14,9 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import torch
 import numpy as np
-import os
 
 threshold = 1.e-4
 threshold2 = 1.e-3
@@ -126,4 +126,33 @@ def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5,
     npu_input = torch.from_numpy(x).to(npu_device)
     if npu_format != -1 and (dtype in [torch.float, torch.half]):
         npu_input = npu_input.npu_format_cast(npu_format)
-    return cpu_input, npu_input
\ No newline at end of file
+    return cpu_input, npu_input
+
+def create_common_tensor_for_broadcast(item, minValue, maxValue):
+    dtype = item[0]
+    npu_format = item[1]
+    shape = item[2]
+    input1 = np.random.uniform(minValue, maxValue, shape[0]).astype(dtype)
+    cpu_input = torch.from_numpy(input1)
+    npu_input = torch.from_numpy(input1).to("npu")
+    if npu_format != -1:
+        npu_input = npu_input.npu_format_cast(npu_format)
+    return cpu_input, npu_input
+
+def check_operators_in_prof(expected_operators, prof, unexpected_operators=None):
+    unexpected_operators = unexpected_operators or []
+    prof_key_averages = prof.key_averages()
+    if not prof_key_averages:
+        print("torch profiling is empty, please check it")
+        return False
+    for prof_item in prof_key_averages:
+        if prof_item.key in unexpected_operators:
+            # if an unexpected operator is called, pattern matching in trans-contiguous has failed
+            return False
+        elif prof_item.key in expected_operators:
+            # if an expected operator is called, remove it from the expected_operators list
+            expected_operators.remove(prof_item.key)
+
+    # if the expected_operators list is empty, all expected operators have been called
+    if not expected_operators:
+        return True
+    return False
\ No newline at end of file
-- Gitee

From a19754bd592ac492514898aa5fee0a3162e0a2b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?=
Date: Thu, 10 Feb 2022 12:02:28 +0800
Subject: [PATCH 2/3] trans-contiguous optimization-format_contiguous and clone

---
 .../csrc/aten/common/TensorFactories.cpp      | 22 +++++++++----------
 .../csrc/framework/contiguous/ReshapeOpt.cpp  |  8 +++++++
 torch_npu/csrc/framework/utils/NpuUtils.cpp   | 16 ++------------
 3 files changed, 20 insertions(+), 26 deletions(-)

diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp
index 3d95b195a1..fde258e93e 100644
--- a/torch_npu/csrc/aten/common/TensorFactories.cpp
+++ b/torch_npu/csrc/aten/common/TensorFactories.cpp
@@ -38,6 +38,7 @@
 #include "torch_npu/csrc/framework/utils/OpAdapter.h"
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
 #include "torch_npu/csrc/core/tensor_impl.h"
+#include "torch_npu/csrc/framework/contiguous/ContiguousOpt.h"
 
 namespace at_npu
 {
@@ -656,22 +657,19 @@ namespace at_npu
 
   at::Tensor NPUNativeFunctions::clone(const at::Tensor &src, c10::optional<c10::MemoryFormat> format)
   {
-    auto desc = src.storage().unsafeGetStorageImpl()->npu_desc_;
-    auto formatSelf = OpPreparation::ApplyTensorWithFormat(
-        src.sizes(), src.options(), desc.npu_format_);
-    if (try_to_optimize_copy_with_any_format(formatSelf, src))
+    std::vector<std::string> opt_patterns{"reshape", "slice"};
+    if (TransContiguous::CanOptimize(src, opt_patterns))
     {
-      return formatSelf;
-    }
-    else if (can_use_memcpy(formatSelf, src))
-    {
-      RECORD_FUNCTION("d2dCopyAsync with format", std::vector<c10::IValue>({src}));
-      copy_d2d_by_memcpy(formatSelf, src);
-      return formatSelf;
+      if (try_to_optimize_copy_with_any_format(formatSelf, src))
+      {
+        auto formatTempTensor =
+            TransContiguous::ContiguousOptimizeWithAnyFormat(src, opt_patterns);
+        return formatTempTensor.value();
+      }
     }
     else
     {
-      auto baseSelf = OpPreparation::ApplyTensor(src);
+      auto baseSelf = OpPreparation::ApplyTensorWithSizes(src.sizes(), src.options());
       copy_d2d_dtype(baseSelf, src, false);
       return baseSelf;
     }
diff --git a/torch_npu/csrc/framework/contiguous/ReshapeOpt.cpp b/torch_npu/csrc/framework/contiguous/ReshapeOpt.cpp
index db562ee712..854ef8a562 100644
--- a/torch_npu/csrc/framework/contiguous/ReshapeOpt.cpp
+++ b/torch_npu/csrc/framework/contiguous/ReshapeOpt.cpp
@@ -23,6 +23,14 @@ namespace at_npu
   bool can_use_memecpy_for_NZ_format(const at::Tensor &tensor)
   {
     auto base_size = tensor.storage().get_npu_desc().base_sizes_;
+
+    // Reject padding && offset != 0 at the same time, e.g. x(3, 15, 16)[1:] with format=29
+    if (((tensor.size(-1) % 16 != 0) || (tensor.size(-2) % 16 != 0)) &&
+        tensor.storage_offset() != 0)
+    {
+      return false;
+    }
+
     // Make sure that sizes of last 2 dims don't change
     if (tensor.size(-1) != base_size[base_size.size() - 1] ||
         tensor.size(-2) != base_size[base_size.size() - 2])
diff --git a/torch_npu/csrc/framework/utils/NpuUtils.cpp b/torch_npu/csrc/framework/utils/NpuUtils.cpp
index a0efe852f7..bc88a06640 100644
--- a/torch_npu/csrc/framework/utils/NpuUtils.cpp
+++ b/torch_npu/csrc/framework/utils/NpuUtils.cpp
@@ -231,18 +231,6 @@ namespace at_npu
     return ret_tmp;
   }
 
-  at::Tensor tensors_convert_contiguous(const at::Tensor &src)
-  {
-    std::vector<std::string> optimizations{"slice"};
-    auto formatTempTensor =
-        TransContiguous::ContiguousOptimizeWithAnyFormat(src, optimizations);
-    if (formatTempTensor.has_value())
-    {
-      return formatTempTensor.value();
-    }
-    return src.contiguous();
-  }
-
   at::Tensor metadata_convert_match(const at::Tensor &src)
   {
     auto &src_desc = src.storage().unsafeGetStorageImpl()->npu_desc_;
@@ -284,7 +272,7 @@ namespace at_npu
     if (!src.is_contiguous())
     {
       RECORD_FUNCTION("format_contiguous", vector<c10::IValue>({src}));
-      return tensors_convert_contiguous(src);
+      return src.contiguous();
     }
     // case2:meta data not match, sizes or strides of presentation
     // layer is different from that of storage layer
@@ -313,7 +301,7 @@ namespace at_npu
     if (!src.is_contiguous())
    {
       RECORD_FUNCTION("format_contiguousV2", vector<c10::IValue>({src}));
-      return tensors_convert_contiguous(src);
+      return src.contiguous();
     }
     // case2:meta data not match, sizes or strides of presentation
    // layer is different from that of storage layer
-- Gitee

From 765d1893692d3b6d0239e613b0a08ca83819a9e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?=
Date: Thu, 10 Feb 2022 12:18:01 +0800
Subject: [PATCH 3/3] trans-contiguous optimization-format_contiguous and clone-1

---
 torch_npu/csrc/aten/common/TensorFactories.cpp | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/torch_npu/csrc/aten/common/TensorFactories.cpp b/torch_npu/csrc/aten/common/TensorFactories.cpp
index fde258e93e..828a0b337f 100644
--- a/torch_npu/csrc/aten/common/TensorFactories.cpp
+++ b/torch_npu/csrc/aten/common/TensorFactories.cpp
@@ -660,12 +660,9 @@ namespace at_npu
     std::vector<std::string> opt_patterns{"reshape", "slice"};
     if (TransContiguous::CanOptimize(src, opt_patterns))
     {
-      if (try_to_optimize_copy_with_any_format(formatSelf, src))
-      {
-        auto formatTempTensor =
-            TransContiguous::ContiguousOptimizeWithAnyFormat(src, opt_patterns);
-        return formatTempTensor.value();
-      }
+      auto formatTempTensor =
+          TransContiguous::ContiguousOptimizeWithAnyFormat(src, opt_patterns);
+      return formatTempTensor.value();
     }
     else
     {
-- Gitee
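
For reference, below is a minimal, self-contained sketch of the contract of the
check_operators_in_prof helper added to torch_npu/testing/util_test.py in patch 1.
FakeItem and FakeProf are hypothetical stand-ins, not part of torch_npu or the
PyTorch profiler API; they only mimic the one interface the helper relies on,
namely prof.key_averages() returning items that expose a .key attribute.

    from collections import namedtuple

    from torch_npu.testing.util_test import check_operators_in_prof

    # Hypothetical stand-ins for profiler output, for illustration only.
    FakeItem = namedtuple('FakeItem', ['key'])

    class FakeProf:
        def __init__(self, keys):
            self._keys = keys

        def key_averages(self):
            return [FakeItem(k) for k in self._keys]

    # Every expected operator appears and no unexpected one does -> True.
    assert check_operators_in_prof(
        ['npuTranspose'], FakeProf(['npuTranspose', 'aclopCompile']))

    # An unexpected operator appears -> False, i.e. the trans-contiguous
    # pattern matching fell back to an unoptimized copy path.
    assert not check_operators_in_prof(
        ['d2dCopyWithPTCopy'], FakeProf(['npuStridedSlice']),
        unexpected_operators=['npuStridedSlice'])

This is how the new tests assert that a given view-to-contiguous copy was served
by the intended NPU operator (e.g. npuTranspose, narrow_npuSlice, npuStridedSlice)
rather than a generic fallback such as d2dCopyWithPTCopy.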