diff --git a/README.en.md b/README.en.md
index 427b8a264563f09c3e4c770b69f4917039365457..a7519564f452c31ec6b985b3e5702c6294b21636 100644
--- a/README.en.md
+++ b/README.en.md
@@ -73,7 +73,6 @@ The following environment variables are function classes used in NPU scenarios o
 
 ```
 export TASK_QUEUE_ENABLE=1 # Dispatch tasks asynchronously and call the ACL interface asynchronously. You are advised to enable this variable by setting it to 1.
-export PTCOPY_ENABLE=1 # Use the PTCopy operator mode to accelerate trans-contiguous conversion and copy. You are advised to enable this variable by setting it to 1.
 ```
 
 The following are optional environment variables that may affect running models:
diff --git a/README.zh.md b/README.zh.md
index 262022df682d0c43b0d915e8997c3f646d7ad8c8..e9fcb196e595c1c03b35aea0b0fa9d26b20c2bc5 100644
--- a/README.zh.md
+++ b/README.zh.md
@@ -77,7 +77,6 @@ source pytorch/env.sh
 
 ```
 export TASK_QUEUE_ENABLE=1 # Dispatch tasks asynchronously and call the ACL interface asynchronously; recommended on by default (set to 1 to enable)
-export PTCOPY_ENABLE=1 # Use the PTCopy operator mode to accelerate trans-contiguous conversion, copy, and similar steps; recommended on by default (set to 1 to enable)
 ```
 
 The following optional environment variables may affect running models:
diff --git a/env.sh b/env.sh
index d35a02be766772bb21fa33fd4b3bd2fe4db367a8..d17e4862485d21a814361f1ac4ff67e10e9a60d5 100644
--- a/env.sh
+++ b/env.sh
@@ -96,7 +96,6 @@ export LD_LIBRARY_PATH=${path_lib}:$LD_LIBRARY_PATH
 
 # PyTorch custom environment variables
 export TASK_QUEUE_ENABLE=0 # Dispatch tasks asynchronously and call the ACL interface asynchronously; recommended on by default (set to 1 to enable)
-export PTCOPY_ENABLE=1 # Use the PTCopy operator mode to accelerate trans-contiguous conversion, copy, and similar steps; recommended on by default (set to 1 to enable)
 #export DYNAMIC_COMPILE_ENABLE=1 # Dynamic-shape feature for scenarios where tensor shapes change; optional (set to 1 to enable)
 
 # log
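Reviewer note: with `PTCOPY_ENABLE` removed, PTCopy becomes the unconditional fallback for device-to-device copies (see the CopyKernel.cpp hunks below), so `TASK_QUEUE_ENABLE` is the only switch left in this block. A minimal sketch of how such a switch is consumed, assuming — as the `OptionsManager` code below suggests — that the value is read from the environment once and then cached, so it should be set before the first NPU call:

```python
# Hedged sketch (not part of the patch): set the option before torch_npu
# does any work, because OptionsManager caches the parsed value on first use.
import os
os.environ["TASK_QUEUE_ENABLE"] = "1"  # enable asynchronous task dispatch

import torch
import torch_npu  # assumed to pick up TASK_QUEUE_ENABLE lazily on first check

x = torch.randn(64, 64).npu()
y = (x @ x).cpu()  # ACL calls are now dispatched through the async task queue
```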
diff --git a/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py b/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cda3e3ac4c53fa7dfc8993d9dd183a7e703f850
--- /dev/null
+++ b/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["COMBINED_ENABLE"] = "1"  # Enable combined-view copy optimization
+
+class TestAsStridedCopyToContiguous(TestCase):
+    def cpu_op_exec(self, input1, size, stride, storage_offset):
+        output = torch.as_strided(input1, size, stride, storage_offset).contiguous()
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, size, stride, storage_offset):
+        with torch.autograd.profiler.profile(use_npu=True) as prof:
+            output = torch.as_strided(input1, size, stride, storage_offset).contiguous()
+        self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Unexpected operators called!")
+        output = output.cpu().numpy()
+        return output
+
+    def test_as_strided(self, device):
+        dtype_list = [np.bool, np.int32, np.float16, np.float32, np.int8, np.uint8, np.int64]
+        format_list = [-1]
+        small_shape_list = [
+            [5, 5]
+        ]
+        small_shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in small_shape_list
+        ]
+
+        for item in small_shape_format:
+            cpu_input, npu_input = create_common_tensor(item, -100, 100)
+            cpu_output = self.cpu_op_exec(cpu_input, (3, 3), (1, 2), 1)
+            npu_output = self.npu_op_exec(npu_input, (3, 3), (1, 2), 1)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+        other_shape_format = [
+            [[np.float16, 0, [13, 23]], (10, 15), (1, 2), 1],
+            [[np.float16, 3, [2, 13, 23]], (10, 15), (1, 2), 2],
+            [[np.float32, 29, [6, 32, 8, 2]], (8, 6, 2), (5, 4, 1), 3],
+        ]
+
+        for item in other_shape_format:
+            cpu_input, npu_input = create_common_tensor(item[0], -100, 100)
+            cpu_output = self.cpu_op_exec(cpu_input, item[1], item[2], item[3])
+            npu_output = self.npu_op_exec(npu_input, item[1], item[2], item[3])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+instantiate_device_type_tests(TestAsStridedCopyToContiguous, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py b/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ea549d0b7578928eb579584dbbba5aaa26432e
--- /dev/null
+++ b/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py
@@ -0,0 +1,81 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof
+
+os.environ["COMBINED_ENABLE"] = "1"  # Enable combined-view copy optimization
+
+class TestTriCombinedViewsCopyToContiguous(TestCase):
+    def test_view_narrow_permute_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [
+            [200, 30, 40, 16],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # case 1: view + narrow + permute ==> cannot be optimized; falls back to npuAsStrided
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.view(npu_input.size(0) * npu_input.size(1), npu_input.size(2), npu_input.size(3))[:, 1:10].transpose(0, 1).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Unexpected operators called!")
+            cpu_out1 = cpu_input.view(cpu_input.size(0) * cpu_input.size(1), cpu_input.size(2), cpu_input.size(3))[:, 1:10].transpose(0, 1).contiguous()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            # case 2: permute + view + narrow ==> cannot be optimized; falls back to npuAsStrided
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input.permute(1, 0, 2, 3).view(npu_input.size(1), npu_input.size(0), npu_input.size(2)*npu_input.size(3))[:, :, 1:10].contiguous()
+            self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Unexpected operators called!")
+            cpu_out2 = cpu_input.permute(1, 0, 2, 3).view(cpu_input.size(1), cpu_input.size(0), cpu_input.size(2)*cpu_input.size(3))[:, :, 1:10].contiguous()
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+    def test_view_select_permute_copy_contiguous(self, device):
+        dtype_list = [np.float16, np.float32]
+        format_list = [-1]
+        shape_list = [
+            [200, 30, 40, 16],
+        ]
+        shape_format = [
+            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
+        ]
+
+        for item in shape_format:
+            cpu_input, npu_input = create_common_tensor(item, 0, 100)
+            # case 1: view + select + permute ==> cannot be optimized; falls back to npuAsStrided
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out1 = npu_input.view(npu_input.size(0) * npu_input.size(1), npu_input.size(2), npu_input.size(3))[:, 1].transpose(0, 1).contiguous()
+            self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Unexpected operators called!")
+            cpu_out1 = cpu_input.view(cpu_input.size(0) * cpu_input.size(1), cpu_input.size(2), cpu_input.size(3))[:, 1].transpose(0, 1).contiguous()
+            self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy())
+
+            # case 2: permute + view + select ==> cannot be optimized; falls back to npuAsStrided
+            with torch.autograd.profiler.profile(use_npu=True) as prof:
+                npu_out2 = npu_input.permute(1, 0, 2, 3).view(npu_input.size(1), npu_input.size(0), npu_input.size(2)*npu_input.size(3))[:, :, 2].contiguous()
+            self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Unexpected operators called!")
+            cpu_out2 = cpu_input.permute(1, 0, 2, 3).view(cpu_input.size(1), cpu_input.size(0), cpu_input.size(2)*cpu_input.size(3))[:, :, 2].contiguous()
+            self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy())
+
+instantiate_device_type_tests(TestTriCombinedViewsCopyToContiguous, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
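Both new tests assert the negative case: chains of three view ops are expected to fall back to the generic npuAsStrided kernel rather than a combined-view kernel. For contrast, here is a hedged sketch of the two-view positive case the test names imply; the operator keys and the claim that a two-view chain is matched by the combined optimizer are assumptions drawn from these tests, not verified against the optimizer itself:

```python
# Hedged contrast sketch: a narrow + permute chain (two views, not three)
# should be handled by the combined-view optimizer when COMBINED_ENABLE=1.
import os
os.environ["COMBINED_ENABLE"] = "1"

import torch
import torch_npu
from torch_npu.testing.util_test import check_operators_in_prof

x = torch.rand(200, 30, 40, 16).npu()
with torch.autograd.profiler.profile(use_npu=True) as prof:
    y = x[:, 1:10].transpose(0, 1).contiguous()

# expected: a combined-view kernel, not the AsStrided fallback (assumption)
print(check_operators_in_prof(['npuCombined'], prof, ['npuAsStrided']))
```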
diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp
index 6bc360ae5bc1360064c35195fda6f612c1116470..e48705bee02e3da6fe18806f0e7d08f547e5c773 100644
--- a/torch_npu/csrc/aten/common/CopyKernel.cpp
+++ b/torch_npu/csrc/aten/common/CopyKernel.cpp
@@ -33,75 +33,6 @@ namespace at_npu {
 namespace native {
 namespace {
-// src : host <-- device
-//               | copy src to dst on cpu
-// dst : host --> device
-void copy_d2d_via_host(at::Tensor& self, const at::Tensor& src, bool same_type) {
-  c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream();
-  aclError error = aclrtSynchronizeStream(copy_stream);
-  if (error != ACL_ERROR_NONE) {
-    AT_ERROR("ACL stream synchronize failed.");
-    return;
-  }
-
-  int64_t real_bytes =
-      StorageDescHelper::GetValidMemorySize(src) * src.element_size();
-  auto cpu_src = at::empty(
-      real_bytes / src.element_size(), src.options().device(at::kCPU));
-  cpu_src = cpu_src.as_strided(src.sizes(), src.strides());
-
-  error = aclrtMemcpy(
-      cpu_src.data_ptr(),
-      real_bytes,
-      src.data_ptr(),
-      real_bytes,
-      ACL_MEMCPY_DEVICE_TO_HOST);
-  if (error != ACL_ERROR_NONE) {
-    AT_ERROR("aclrtMemcpy device to cpu_src error.");
-    return;
-  }
-
-  real_bytes =
-      StorageDescHelper::GetValidMemorySize(self) * self.element_size();
-  auto cpu_dst = at::empty(
-      real_bytes / self.element_size(), self.options().device(at::kCPU));
-  cpu_dst = cpu_dst.as_strided(self.sizes(), self.strides());
-
-  if (!same_type) {
-    cpu_src = cpu_src.to(cpu_dst.dtype());
-  }
-
-  // sometimes npu_dst just need part of cpu_dst's elements, so we do memory
-  // copy from npu to cpu here, let npu_dst cover cpu_dst, to avoid unneeded
-  // cpu_dst's elements cover npu_dst's original elements
-  if ((!cpu_dst.is_contiguous()) && (self.defined())) {
-    error = aclrtMemcpy(
-        cpu_dst.data_ptr(),
-        real_bytes,
-        self.data_ptr(),
-        real_bytes,
-        ACL_MEMCPY_DEVICE_TO_HOST);
-    if (error != ACL_ERROR_NONE) {
-      AT_ERROR("ACL_Memcpy device to cpu_dst error.");
-      return;
-    }
-  }
-
-  cpu_dst.copy_(cpu_src);
-
-  error = aclrtMemcpy(
-      self.data_ptr(),
-      real_bytes,
-      cpu_dst.data_ptr(),
-      real_bytes,
-      ACL_MEMCPY_HOST_TO_DEVICE);
-  if (error != ACL_ERROR_NONE) {
-    AT_ERROR("aclrtMemcpy cpu_dst to device error.");
-    return;
-  }
-  NPU_LOGD("Src or dst is not contiguous when do device to device copy.");
-}
-
 // NOTE: helper function of copy; the input parameters are not checked. The caller
 // needs to ensure that the parameters are correct.
@@ -132,14 +63,8 @@ void copy_d2d_last_method(
     bool same_type,
     bool non_blocking) {
   // general copy method, but with low performance
-  if (torch_npu::option::OptionsManager::CheckPTcopy_Enable()) {
-    RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector<c10::IValue>({src}));
-    copy_kernel_npu(self, src, non_blocking);
-  } else {
-    RECORD_FUNCTION(
-        "d2dCopyWithStreamSynchronize", std::vector<c10::IValue>({src}));
-    copy_d2d_via_host(self, src, same_type);
-  }
+  RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector<c10::IValue>({src}));
+  copy_kernel_npu(self, src, non_blocking);
 }
 
 // dst and src are in the same format now
@@ -150,15 +75,22 @@ void copy_d2d_dtype_baseformat(
     const at::Tensor& src,
     bool non_blocking) {
   if (!self.is_contiguous()) {
+    // Copy from a contiguous or discontiguous source tensor to a discontiguous self tensor
     return copy_d2d_last_method(self, src, true, non_blocking);
   }
 
   if (!src.is_contiguous()) {
-    // discontiguous
+    // Copy from a discontiguous source tensor to a contiguous self tensor
     if (TransContiguous::ContiguousOptimizeWithBaseFormat(self, src)) {
+      // Optimized trans-contiguous method
+      return;
+    } else {
+      // General trans-contiguous method
+      NPUNativeFunctions::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self);
       return;
     }
+  } else {
+    // Copy from a contiguous source tensor to a contiguous self tensor
     int64_t numel = self.numel();
     if (numel == src.numel()) {
       RECORD_FUNCTION("d2dCopyAsync", std::vector<c10::IValue>({src}));
diff --git a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp
index a283c31d63de486cc0bcabdaefb75761a2109eac..cb669c4b2df742c7c2f0a3c1370f736c839bc198 100644
--- a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp
+++ b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp
@@ -28,7 +28,7 @@ c10::SmallVector<int64_t, N> get_view_value(
   static c10::SmallVector<int64_t, N> value;
   // It is determined by the definition of view attr
   value.resize(strides.size() + 3);
-  value[0] = t.numel(); // storageImpl numel
+  value[0] = t.storage().nbytes() / t.element_size(); // storageImpl numel
   value[1] = t.storage_offset(); // default to 0
   value[2] = strides.size();
   for (size_t i = 0; i < strides.size(); i++) {
diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index c2624df5c34a5feb2e077fe8689c5c5959bcb26e..5cbdf152bf29fcd090ba194fe942ce7f1e7322a2 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -1914,6 +1914,9 @@ custom:
     variants: function, method
   - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor
     variants: function, method
+  - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor
+    variants: function, method
+  - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!)
 custom_autograd:
   - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor
   - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor
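The yaml entries above go through torch_npu's custom-op codegen, which should surface `npu_stride_copy` as both a function and a method. A hedged usage sketch — the `torch_npu.npu_stride_copy` call path is an assumption based on the `custom:` section, and the semantics are expected to mirror `torch.as_strided(...).contiguous()`:

```python
# Hedged sketch: the binding name torch_npu.npu_stride_copy is assumed
# from the yaml entry; semantics should match as_strided + contiguous.
import torch
import torch_npu

src = torch.rand(5, 5).npu()
out = torch_npu.npu_stride_copy(src, (3, 3), (1, 2), 1)

# CPU reference performing the same strided read
ref = torch.as_strided(src.cpu(), (3, 3), (1, 2), 1).contiguous()
print(torch.allclose(out.cpu(), ref))
```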
diff --git a/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fbb3f1e0e76a60b76aef7a9400bb7679ca2d514c
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp
@@ -0,0 +1,63 @@
+// Copyright (c) 2020, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include <ATen/record_function.h>
+
+namespace at_npu {
+namespace native {
+
+at::Tensor& stride_copy_out_npu_nocheck(
+    at::Tensor& result,
+    const at::Tensor& self,
+    at::IntArrayRef shape,
+    at::IntArrayRef stride,
+    at::Scalar storage_offset) {
+  RECORD_FUNCTION("npuAsStrided", std::vector<c10::IValue>({self}));
+  OpCommand cmd;
+  cmd.Name("AsStrided")
+      .InputWithoutContiguous(self)
+      .Input(shape)
+      .Input(stride)
+      .Input(storage_offset, at::kLong)
+      .Output(result)
+      .Run();
+  return result;
+}
+
+at::Tensor& NPUNativeFunctions::npu_stride_copy_out(
+    const at::Tensor& self,
+    at::IntArrayRef shape,
+    at::IntArrayRef stride,
+    at::Scalar storage_offset,
+    at::Tensor& result) {
+  stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset);
+  return result;
+}
+
+at::Tensor NPUNativeFunctions::npu_stride_copy(
+    const at::Tensor& self,
+    at::IntArrayRef shape,
+    at::IntArrayRef stride,
+    at::Scalar storage_offset) {
+  // The AsStrided operator only supports ND input
+  at::Tensor result = OpPreparation::ApplyTensorWithFormat(
+      shape, self.options(), ACL_FORMAT_ND);
+  stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset);
+  return result;
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h
index 069c2e4e8211226a682f16b0f703cd73e1fd4c4e..bb27053170fe61b9266db18ce1b2774893d98268 100644
--- a/torch_npu/csrc/framework/OpCommandBase.h
+++ b/torch_npu/csrc/framework/OpCommandBase.h
@@ -72,6 +72,14 @@ namespace at_npu
       return static_cast<Derived &>(*this);
     }
 
+    Derived &InputWithoutContiguous(
+        const at::Tensor &input,
+        const string &descName = "",
+        const string &realData = "")
+    {
+      return AddTensorInput(const_cast<at::Tensor &>(input), at::ScalarType::Undefined, descName, realData);
+    }
+
     Derived &Input()
     {
       return AddNoneTensor();
diff --git a/torch_npu/csrc/framework/contiguous/slice_opt.cpp b/torch_npu/csrc/framework/contiguous/slice_opt.cpp
index 0eeea31d2ce6233d13c7e570814276e68877e005..c43cb5cf6fac49027ae121a8e2dcc191191ffecd 100644
--- a/torch_npu/csrc/framework/contiguous/slice_opt.cpp
+++ b/torch_npu/csrc/framework/contiguous/slice_opt.cpp
@@ -27,7 +27,7 @@ namespace at_npu
   public:
     bool Optimizer(const at::Tensor &src, at::Tensor &self) override
    {
-      // Pattern slice. Current pattern should be used before PTcopy process.
+      // Pattern slice.
       // Current pattern does not directly depend on other patterns.
       // The relative sequence of this pattern and other patterns is not important.
       c10::SmallVector<int64_t, MAX_DIM> offsets;
diff --git a/torch_npu/csrc/register/OptionsManager.cpp b/torch_npu/csrc/register/OptionsManager.cpp
index 0afa6ce4d49bdd48ee26026256829ea18559574f..cb93b59f9df31efe5906b2167567bdb06ae9bd7a 100644
--- a/torch_npu/csrc/register/OptionsManager.cpp
+++ b/torch_npu/csrc/register/OptionsManager.cpp
@@ -30,14 +30,6 @@ bool OptionsManager::CheckQueueEnable() {
   return (queue_enable == 1);
 }
 
-bool OptionsManager::CheckPTcopy_Enable() {
-  static int32_t PTcopy__enable = -1;
-  if (PTcopy__enable == -1) {
-    PTcopy__enable = GetBoolTypeOption("PTCOPY_ENABLE");
-  }
-  return (PTcopy__enable == 1);
-}
-
 bool OptionsManager::CheckCombinedOptimizerEnable() {
   static int32_t combined_optimize = -1;
   if (combined_optimize == -1) {
diff --git a/torch_npu/csrc/register/OptionsManager.h b/torch_npu/csrc/register/OptionsManager.h
index dd59ea6513cb6d4da2518f3c728e6ea30a39e598..0bae5ab2005ea57d95afc7a922ddca3f6ac8283c 100644
--- a/torch_npu/csrc/register/OptionsManager.h
+++ b/torch_npu/csrc/register/OptionsManager.h
@@ -27,7 +27,6 @@ namespace option {
 class OptionsManager {
 public:
   static bool CheckQueueEnable();
-  static bool CheckPTcopy_Enable();
   static bool CheckCombinedOptimizerEnable();
   static bool CheckTriCombinedOptimizerEnable();
   static bool CheckAclDumpDateEnable();
diff --git a/torch_npu/testing/util_test.py b/torch_npu/testing/util_test.py
index a460af4ae332f2fe2a229946dfc58da3c3e0483e..9d384a822fcf5500f5c544c5109eaf9209d4d4b5 100644
--- a/torch_npu/testing/util_test.py
+++ b/torch_npu/testing/util_test.py
@@ -126,4 +126,23 @@ def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5,
     npu_input = torch.from_numpy(x).to(npu_device)
     if npu_format != -1 and (dtype in [torch.float, torch.half]):
         npu_input = npu_input.npu_format_cast(npu_format)
-    return cpu_input, npu_input
\ No newline at end of file
+    return cpu_input, npu_input
+
+def check_operators_in_prof(expected_operators, prof, unexpected_operators=None):
+    unexpected_operators = unexpected_operators or []
+    prof_key_averages = prof.key_averages()
+    if not prof_key_averages:
+        print("torch profiling result is empty, please check it")
+        return None
+    for prof_item in prof_key_averages:
+        if prof_item.key in unexpected_operators:
+            # if an unexpected operator is called, pattern matching in trans-contiguous has failed
+            return False
+        elif prof_item.key in expected_operators:
+            # if an expected operator is called, remove it from the expected_operators list
+            expected_operators.remove(prof_item.key)
+
+    # if the expected_operators list is empty, all expected operators have been called
+    if not expected_operators:
+        return True
+    return False
\ No newline at end of file
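One caveat worth noting for future callers: `check_operators_in_prof` consumes `expected_operators` in place via `list.remove`, so a list shared across invocations would go stale. A short usage sketch, assuming an NPU device is available:

```python
# Hedged usage sketch: build the expected list fresh on every call, since
# check_operators_in_prof mutates it while matching profiler entries.
import torch
import torch_npu
from torch_npu.testing.util_test import check_operators_in_prof

x = torch.rand(4, 4).npu()
with torch.autograd.profiler.profile(use_npu=True) as prof:
    y = torch.as_strided(x, (2, 2), (1, 2), 0).contiguous()

ok = check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined'])
print(ok)  # True if only the expected fallback kernel ran
```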