From 633cf76f25919269fe17e986b9ef36fe5e24e824 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 11:43:52 +0800 Subject: [PATCH 1/8] replace ptcopy with asstride in trans-contiguous --- README.en.md | 1 - README.zh.md | 1 - env.sh | 1 - .../test_as_strided_copy_to_contiguous.py | 69 +++++++++++++++ ...t_tri_combined_views_copy_to_contiguous.py | 81 +++++++++++++++++ torch_npu/csrc/aten/common/CopyKernel.cpp | 88 +++---------------- torch_npu/csrc/aten/common/CopyKernelNpu.cpp | 2 +- torch_npu/csrc/aten/npu_native_functions.yaml | 3 + .../csrc/aten/ops/AsStridedKernelNpu.cpp | 63 +++++++++++++ torch_npu/csrc/framework/OpCommandBase.h | 9 ++ .../csrc/framework/contiguous/slice_opt.cpp | 2 +- torch_npu/csrc/register/OptionsManager.cpp | 8 -- torch_npu/csrc/register/OptionsManager.h | 1 - torch_npu/testing/util_test.py | 20 ++++- 14 files changed, 256 insertions(+), 93 deletions(-) create mode 100644 test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py create mode 100644 test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py create mode 100644 torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp diff --git a/README.en.md b/README.en.md index 427b8a26456..a7519564f45 100644 --- a/README.en.md +++ b/README.en.md @@ -73,7 +73,6 @@ The following environment variables are function classes used in NPU scenarios o ``` export TASK_QUEUE_ENABLE=1 # Delivered by an asynchronous task to asynchronously call the ACL interface. You are advised to enable this environment variable and set its value to 1. -export PTCOPY_ENABLE=1 # Use the PTCopy operator mode to accelerate continuous rotation and copy. You are advised to enable this environment variable and set its value to 1. ``` The following are optional environment variables that may affect running models: diff --git a/README.zh.md b/README.zh.md index 262022df682..e9fcb196e59 100644 --- a/README.zh.md +++ b/README.zh.md @@ -77,7 +77,6 @@ source pytorch/env.sh ``` export TASK_QUEUE_ENABLE=1 # 使用异步任务下发,异步调用acl接口,建议默认开启,开启设置为1 -export PTCOPY_ENABLE=1 # 使用PTCopy算子模式,加速转连续及copy等过程,建议默认开启,开启设置为1 ``` 可选的环境变量可能会对运行的模型产生影响: diff --git a/env.sh b/env.sh index d35a02be766..d17e4862485 100644 --- a/env.sh +++ b/env.sh @@ -96,7 +96,6 @@ export LD_LIBRARY_PATH=${path_lib}:$LD_LIBRARY_PATH # pytorch 自定义环境变量 export TASK_QUEUE_ENABLE=0 # 使用异步任务下发,异步调用acl接口,建议默认开启,开启设置为1 -export PTCOPY_ENABLE=1 # 使用PTCopy算子模式,加速转连续及copy等过程,建议默认开启,开启设置为1 #export DYNAMIC_COMPILE_ENABLE=1 # 动态shape特性功能,针对shape变化场景,可选 开启设置为1 # log diff --git a/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py b/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py new file mode 100644 index 00000000000..4cda3e3ac4c --- /dev/null +++ b/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof + +os.environ["COMBINED_ENABLE"] = "1" # Open combined-view cases optimization + +class TestAsStridedCopyToContiguous(TestCase): + def cpu_op_exec(self, input1, size, stride, storage_offset): + output = torch.as_strided(input1, size, stride, storage_offset).contiguous() + output = output.numpy() + return output + + def npu_op_exec(self,input1, size, stride, storage_offset): + with torch.autograd.profiler.profile(use_npu=True) as prof: + output = torch.as_strided(input1, size, stride, storage_offset).contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Error operators called!") + output = output.cpu().numpy() + return output + + def test_as_strided(self, device): + dtype_list = [np.bool, np.int32, np.float16, np.float32, np.int8, np.uint8, np.int64] + format_list = [-1] + small_shape_list = [ + [5, 5] + ] + small_shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in small_shape_list + ] + + for item in small_shape_format: + cpu_input, npu_input = create_common_tensor(item, -100, 100) + cpu_output = self.cpu_op_exec(cpu_input, (3, 3), (1, 2), 1) + npu_output = self.npu_op_exec(npu_input, (3, 3), (1, 2), 1) + self.assertRtolEqual(cpu_output, npu_output) + + other_shape_format = [ + [[np.float16, 0, [13, 23]], (10, 15), (1, 2), 1], + [[np.float16, 3, [2, 13, 23]], (10, 15), (1, 2), 2], + [[np.float32, 29, [6, 32, 8, 2]], (8, 6, 2), (5, 4, 1), 3], + ] + + for item in other_shape_format: + cpu_input, npu_input = create_common_tensor(item[0], -100, 100) + cpu_output = self.cpu_op_exec(cpu_input, item[1], item[2], item[3]) + npu_output = self.npu_op_exec(npu_input, item[1], item[2], item[3]) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestAsStridedCopyToContiguous, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py b/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py new file mode 100644 index 00000000000..e2ea549d0b7 --- /dev/null +++ b/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py @@ -0,0 +1,81 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof + +os.environ["COMBINED_ENABLE"] = "1" # Open combined-view cases optimization + +class TestTriCombinedViewsCopyToContiguous(TestCase): + def test_view_narrow_permute_copy_contiguous(self, device): + dtype_list = [np.float16, np.float32] + format_list = [-1] + shape_list = [ + [200, 30, 40, 16], + ] + shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in shape_list + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + # case 1: view+narrow+permute ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out1 = npu_input.view(npu_input.size(0) * npu_input.size(1), npu_input.size(2), npu_input.size(3))[:,1:10].transpose(0, 1).contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Error operators called!") + cpu_out1 = cpu_input.view(cpu_input.size(0) * cpu_input.size(1), cpu_input.size(2), cpu_input.size(3))[:,1:10].transpose(0, 1).contiguous() + self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy()) + + # case 2: permute+view+narrow ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out2 = npu_input.permute(1, 0, 2, 3).view(npu_input.size(1), npu_input.size(0), npu_input.size(2)*npu_input.size(3))[:,:,1:10].contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Error operators called!") + cpu_out2 = cpu_input.permute(1, 0, 2, 3).view(cpu_input.size(1), cpu_input.size(0), cpu_input.size(2)*cpu_input.size(3))[:,:,1:10].contiguous() + self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy()) + + def test_view_select_permute_copy_contiguous(self, device): + dtype_list = [np.float16, np.float32] + format_list = [-1] + shape_list = [ + [200, 30, 40, 16], + ] + shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in shape_list + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + # case 1: view+select+permute ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out1 = npu_input.view(npu_input.size(0) * npu_input.size(1), npu_input.size(2), npu_input.size(3))[:,1].transpose(0, 1).contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Error operators called!") + cpu_out1 = cpu_input.view(cpu_input.size(0) * cpu_input.size(1), cpu_input.size(2), cpu_input.size(3))[:,1].transpose(0, 1).contiguous() + self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy()) + + # case 2: permute+view+select ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out2 = npu_input.permute(1, 0, 2, 3).view(npu_input.size(1), npu_input.size(0), npu_input.size(2)*npu_input.size(3))[:,:,2].contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), True, "Error operators called!") + cpu_out2 = cpu_input.permute(1, 0, 2, 3).view(cpu_input.size(1), cpu_input.size(0), cpu_input.size(2)*cpu_input.size(3))[:,:,2].contiguous() + self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy()) + 
+instantiate_device_type_tests(TestTriCombinedViewsCopyToContiguous, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index 6bc360ae5bc..a2473cc62e4 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -33,75 +33,6 @@ namespace at_npu { namespace native { namespace { -// src : host <-- device -// | copy src to dst on cpu -// dst : host --> device -void copy_d2d_via_host(at::Tensor& self, const at::Tensor& src, bool same_type) { - c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream(); - aclError error = aclrtSynchronizeStream(copy_stream); - if (error != ACL_ERROR_NONE) { - AT_ERROR("ACL stream synchronize failed."); - return; - } - - int64_t real_bytes = - StorageDescHelper::GetValidMemorySize(src) * src.element_size(); - auto cpu_src = at::empty( - real_bytes / src.element_size(), src.options().device(at::kCPU)); - cpu_src = cpu_src.as_strided(src.sizes(), src.strides()); - - error = aclrtMemcpy( - cpu_src.data_ptr(), - real_bytes, - src.data_ptr(), - real_bytes, - ACL_MEMCPY_DEVICE_TO_HOST); - if (error != ACL_ERROR_NONE) { - AT_ERROR("aclrtMemcpy device to cpu_src error."); - return; - } - - real_bytes = - StorageDescHelper::GetValidMemorySize(self) * self.element_size(); - auto cpu_dst = at::empty( - real_bytes / self.element_size(), self.options().device(at::kCPU)); - cpu_dst = cpu_dst.as_strided(self.sizes(), self.strides()); - - if (!same_type) { - cpu_src = cpu_src.to(cpu_dst.dtype()); - } - - // sometimes npu_dst just need part of cpu_dst's elements, so we do memory - // copy from npu to cpu here, let npu_dst cover cpu_dst, to avoid unneeded - // cpu_dst's elements cover npu_dst's original elements - if ((!cpu_dst.is_contiguous()) && (self.defined())) { - error = aclrtMemcpy( - cpu_dst.data_ptr(), - real_bytes, - self.data_ptr(), - real_bytes, - ACL_MEMCPY_DEVICE_TO_HOST); - if (error != ACL_ERROR_NONE) { - AT_ERROR("ACL_Memcpy device to cpu_dst error."); - return; - } - } - - cpu_dst.copy_(cpu_src); - - error = aclrtMemcpy( - self.data_ptr(), - real_bytes, - cpu_dst.data_ptr(), - real_bytes, - ACL_MEMCPY_HOST_TO_DEVICE); - if (error != ACL_ERROR_NONE) { - AT_ERROR("aclrtMemcpy cpu_dst to device error."); - return; - } - NPU_LOGD("Src or dst is not contiguous when do device to device copy."); -} - // NOTE: helper function of copy, the input parameter is not checked, The caller // needs to ensure that the parameters are correct. 
@@ -132,14 +63,8 @@ void copy_d2d_last_method( bool same_type, bool non_blocking) { // general copy method but Low performance - if (torch_npu::option::OptionsManager::CheckPTcopy_Enable()) { - RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector({src})); - copy_kernel_npu(self, src, non_blocking); - } else { - RECORD_FUNCTION( - "d2dCopyWithStreamSynchronize", std::vector({src})); - copy_d2d_via_host(self, src, same_type); - } + RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector({src})); + copy_kernel_npu(self, src, non_blocking); } // the dst and src are same format now @@ -150,15 +75,22 @@ void copy_d2d_dtype_baseformat( const at::Tensor& src, bool non_blocking) { if (!self.is_contiguous()) { + // Contiguous/discontiguous source tensor copy to discontiguous self tensor return copy_d2d_last_method(self, src, true, non_blocking); } if (!src.is_contiguous()) { - // discontiguous + // Discontiguous source tensor copy to contiguous self tensor if (TransContiguous::ContiguousOptimizeWithBaseFormat(self, src)) { + // Optimized trans-contiguous method + return; + } else { + // General trans-contiguous method + NPUNativeFunctions::npu_stride_copy_out(self, src, src.sizes(), src.strides(), src.storage_offset()); return; } } else { + // Contiguous source tensor copy to contiguous self tensor int64_t numel = self.numel(); if (numel == src.numel()) { RECORD_FUNCTION("d2dCopyAsync", std::vector({src})); diff --git a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp index a283c31d63d..cb669c4b2df 100644 --- a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp @@ -28,7 +28,7 @@ c10::SmallVector get_view_value( static c10::SmallVector value; // It is determined by the definition of view attr value.resize(strides.size() + 3); - value[0] = t.numel(); // storageImpl numel + value[0] = t.storage().nbytes() / t.element_size(); // storageImpl numel value[1] = t.storage_offset(); // default to 0 value[2] = strides.size(); for (size_t i = 0; i < strides.size(); i++) { diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index c2624df5c34..5cbdf152bf2 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1914,6 +1914,9 @@ custom: variants: function, method - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor variants: function, method + - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor + variants: function, method + - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!) custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp new file mode 100644 index 00000000000..827974f013a --- /dev/null +++ b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp @@ -0,0 +1,63 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include + +namespace at_npu { +namespace native { + +at::Tensor& stride_copy_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + at::IntArrayRef shape, + at::IntArrayRef stride, + at::Scalar storage_offset) { + RECORD_FUNCTION("npuAsStrided", std::vector({self})); + OpCommand cmd; + cmd.Name("AsStrided") + .InputWithoutContiguous(self) + .Input(shape) + .Input(stride) + .Input(storage_offset, at::kLong) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::npu_stride_copy_out( + at::Tensor& result, + const at::Tensor& self, + at::IntArrayRef shape, + at::IntArrayRef stride, + at::Scalar storage_offset) { + stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset); + return result; +} + +at::Tensor NPUNativeFunctions:: npu_stride_copy( + const at::Tensor& self, + at::IntArrayRef shape, + at::IntArrayRef stride, + at::Scalar storage_offset) { + // AsStrided OP only supports ND input + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + shape, self.options(), ACL_FORMAT_ND); + stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset); + return result; +} + +} // namespace native +} // namespace at \ No newline at end of file diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 069c2e4e821..7f566693488 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -72,6 +72,15 @@ namespace at_npu return static_cast(*this); } + Derived &InputWithoutContiguous( + const Tensor &input, + const string &descName = "", + const optional &sensitive_format = nullopt, + const string &realData = "") + { + return AddTensorInput(const_cast(input), ScalarType::Undefined, descName, realData); + } + Derived &Input() { return AddNoneTensor(); diff --git a/torch_npu/csrc/framework/contiguous/slice_opt.cpp b/torch_npu/csrc/framework/contiguous/slice_opt.cpp index 0eeea31d2ce..c43cb5cf6fa 100644 --- a/torch_npu/csrc/framework/contiguous/slice_opt.cpp +++ b/torch_npu/csrc/framework/contiguous/slice_opt.cpp @@ -27,7 +27,7 @@ namespace at_npu public: bool Optimizer(const at::Tensor &src, at::Tensor &self) override { - // Pattern slice. Current pattern should be used before PTcopy process. + // Pattern slice. // Current pattern does not directly depend on other patterns. // The relative sequence of this pattern and other patterns is not important. 
c10::SmallVector offsets; diff --git a/torch_npu/csrc/register/OptionsManager.cpp b/torch_npu/csrc/register/OptionsManager.cpp index 0afa6ce4d49..cb93b59f9df 100644 --- a/torch_npu/csrc/register/OptionsManager.cpp +++ b/torch_npu/csrc/register/OptionsManager.cpp @@ -30,14 +30,6 @@ bool OptionsManager::CheckQueueEnable() { return (queue_enable == 1); } -bool OptionsManager::CheckPTcopy_Enable() { - static int32_t PTcopy__enable = -1; - if (PTcopy__enable == -1) { - PTcopy__enable = GetBoolTypeOption("PTCOPY_ENABLE"); - } - return (PTcopy__enable == 1); -} - bool OptionsManager::CheckCombinedOptimizerEnable() { static int32_t combined_optimize = -1; if (combined_optimize == -1) { diff --git a/torch_npu/csrc/register/OptionsManager.h b/torch_npu/csrc/register/OptionsManager.h index dd59ea6513c..0bae5ab2005 100644 --- a/torch_npu/csrc/register/OptionsManager.h +++ b/torch_npu/csrc/register/OptionsManager.h @@ -27,7 +27,6 @@ namespace option { class OptionsManager { public: static bool CheckQueueEnable(); - static bool CheckPTcopy_Enable(); static bool CheckCombinedOptimizerEnable(); static bool CheckTriCombinedOptimizerEnable(); static bool CheckAclDumpDateEnable(); diff --git a/torch_npu/testing/util_test.py b/torch_npu/testing/util_test.py index a460af4ae33..9d384a822fc 100644 --- a/torch_npu/testing/util_test.py +++ b/torch_npu/testing/util_test.py @@ -126,4 +126,22 @@ def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5, npu_input = torch.from_numpy(x).to(npu_device) if npu_format != -1 and (dtype in [torch.float, torch.half]): npu_input = npu_input.npu_format_cast(npu_format) - return cpu_input, npu_input \ No newline at end of file + return cpu_input, npu_input + +def check_operators_in_prof(expected_operators, prof, unexpected_operators=None): + unexpected_operators = unexpected_operators or [] + prof_key_averages = prof.key_averages() + if not prof_key_averages: + return print("torch profiling is empty, please check it") + for prof_item in prof_key_averages: + if prof_item.key in unexpected_operators: + # if unexpected oprators are called, pattern inferring in trans-contiguous is failed + return False + elif prof_item.key in expected_operators: + # if expected oprator is called, empty it in expected_operators list + expected_operators.remove(prof_item.key) + + # if expected_operators list is empty, all oprators have been called + if not expected_operators: + return True + return False \ No newline at end of file -- Gitee From 81961271222dc87472d28b44cabc46e83f7933a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 14:35:09 +0800 Subject: [PATCH 2/8] replace ptcopy with asstride in trans-contiguous-1 --- torch_npu/csrc/framework/OpCommandBase.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 7f566693488..6e420484e13 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -73,7 +73,7 @@ namespace at_npu } Derived &InputWithoutContiguous( - const Tensor &input, + const at::Tensor &input, const string &descName = "", const optional &sensitive_format = nullopt, const string &realData = "") -- Gitee From 1953b8c2585c9f3b16c919d2f0ef1dbb699ac9d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 14:36:15 +0800 Subject: [PATCH 3/8] replace ptcopy with asstride in trans-contiguous-2 --- torch_npu/csrc/framework/OpCommandBase.h | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 6e420484e13..aace3820f5a 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -78,7 +78,7 @@ namespace at_npu const optional &sensitive_format = nullopt, const string &realData = "") { - return AddTensorInput(const_cast(input), ScalarType::Undefined, descName, realData); + return AddTensorInput(const_cast(input), at::ScalarType::Undefined, descName, realData); } Derived &Input() -- Gitee From edd3f167e623c8239239c3dbebdc8475f102ffb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 14:37:31 +0800 Subject: [PATCH 4/8] replace ptcopy with asstride in trans-contiguous-3 --- torch_npu/csrc/framework/OpCommandBase.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index aace3820f5a..abe6ed6c6b0 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -75,7 +75,7 @@ namespace at_npu Derived &InputWithoutContiguous( const at::Tensor &input, const string &descName = "", - const optional &sensitive_format = nullopt, + const c10::optional &sensitive_format = nullopt, const string &realData = "") { return AddTensorInput(const_cast(input), at::ScalarType::Undefined, descName, realData); -- Gitee From b35a803a920e412c4149638ecf73ead6ceb08f19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 14:38:12 +0800 Subject: [PATCH 5/8] replace ptcopy with asstride in trans-contiguous-4 --- torch_npu/csrc/framework/OpCommandBase.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index abe6ed6c6b0..61e3692e77d 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -75,7 +75,7 @@ namespace at_npu Derived &InputWithoutContiguous( const at::Tensor &input, const string &descName = "", - const c10::optional &sensitive_format = nullopt, + const c10::optional &sensitive_format = c10::nullopt, const string &realData = "") { return AddTensorInput(const_cast(input), at::ScalarType::Undefined, descName, realData); -- Gitee From df4ed058a23b7603d16b18941728ba147230be34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 16:12:40 +0800 Subject: [PATCH 6/8] replace ptcopy with asstride in trans-contiguous-5 --- torch_npu/csrc/aten/npu_native_functions.yaml | 1 - torch_npu/csrc/framework/OpCommandBase.h | 1 - 2 files changed, 2 deletions(-) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 5cbdf152bf2..9a647a8db46 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1915,7 +1915,6 @@ custom: - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor variants: function, method - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor - variants: function, method - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!) custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? 
bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 61e3692e77d..bb27053170f 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -75,7 +75,6 @@ namespace at_npu Derived &InputWithoutContiguous( const at::Tensor &input, const string &descName = "", - const c10::optional &sensitive_format = c10::nullopt, const string &realData = "") { return AddTensorInput(const_cast(input), at::ScalarType::Undefined, descName, realData); -- Gitee From 6426d3661ba8a591940a5e0d97fb1f9daa7f223a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 16:24:08 +0800 Subject: [PATCH 7/8] replace ptcopy with asstride in trans-contiguous-6 --- torch_npu/csrc/aten/common/CopyKernel.cpp | 2 +- torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index a2473cc62e4..e48705bee02 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -86,7 +86,7 @@ void copy_d2d_dtype_baseformat( return; } else { // General trans-contiguous method - NPUNativeFunctions::npu_stride_copy_out(self, src, src.sizes(), src.strides(), src.storage_offset()); + NPUNativeFunctions::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self); return; } } else { diff --git a/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp index 827974f013a..fbb3f1e0e76 100644 --- a/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp @@ -38,11 +38,11 @@ at::Tensor& stride_copy_out_npu_nocheck( } at::Tensor& NPUNativeFunctions::npu_stride_copy_out( - at::Tensor& result, const at::Tensor& self, at::IntArrayRef shape, at::IntArrayRef stride, - at::Scalar storage_offset) { + at::Scalar storage_offset, + at::Tensor& result) { stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset); return result; } -- Gitee From 6613ef181592a8f8fd8f2ea8d2a8b2f3a7724d50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 16:30:28 +0800 Subject: [PATCH 8/8] replace ptcopy with asstride in trans-contiguous-7 --- torch_npu/csrc/aten/npu_native_functions.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 9a647a8db46..5cbdf152bf2 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1915,6 +1915,7 @@ custom: - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor variants: function, method - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor + variants: function, method - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!) custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor -- Gitee
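
Illustration (not part of the patch series): the series reroutes the general discontiguous-to-contiguous fallback from PTCopy / the host round-trip copy onto a single on-device AsStrided op, which is presumably why copy_d2d_via_host and the PTCOPY_ENABLE switch can be dropped. A minimal sketch of how the new path can be exercised outside the bundled tests follows. It assumes an NPU build where the custom op is exposed as torch_npu.npu_stride_copy (per the npu_native_functions.yaml registration above) and that the profiler key "npuAsStrided" matches the RECORD_FUNCTION name added in AsStridedKernelNpu.cpp; adjust for your build.

```python
import torch
import torch_npu

# 5x5 device tensor; the strided view taken below is deliberately discontiguous.
x = torch.arange(25, dtype=torch.float32).reshape(5, 5).npu()

# .contiguous() on a strided view should now dispatch to the AsStrided kernel
# (same pattern as test_as_strided_copy_to_contiguous.py) rather than PTCopy.
with torch.autograd.profiler.profile(use_npu=True) as prof:
    y = torch.as_strided(x, (3, 3), (1, 2), 1).contiguous()
assert any(item.key == "npuAsStrided" for item in prof.key_averages())

# The same copy expressed through the new custom op directly (function variant);
# a method variant is registered in the yaml as well.
z = torch_npu.npu_stride_copy(x, (3, 3), (1, 2), 1)
print(torch.equal(y.cpu(), z.cpu()))  # expected: True
```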