From f2b6ec68009b88ebe52688d3c52e2e24e19a5e97 Mon Sep 17 00:00:00 2001
From: cuishiang
Date: Thu, 10 Feb 2022 11:19:21 +0800
Subject: [PATCH 1/2] 1. Add test cases for the testing framework
 2. Fix calls to the wrong TestCase assertion methods of the testing framework
 3. Add a setUpClass method to TestCase
 4. Create the set_npu_device() and get_npu_device() helpers
 5. Add a device parameter (device_id) to some tensor-creation functions in
 util_test.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_testing.py                    | 48 +++++++++++++++++++++
 torch_npu/testing/common_device_type.py | 11 +----
 torch_npu/testing/common_utils.py       |  7 +++-
 torch_npu/testing/util_test.py          | 56 +++++++++++++++----------
 4 files changed, 90 insertions(+), 32 deletions(-)
 create mode 100644 test/test_testing.py

diff --git a/test/test_testing.py b/test/test_testing.py
new file mode 100644
index 00000000000..3ddfb5337a0
--- /dev/null
+++ b/test/test_testing.py
@@ -0,0 +1,48 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests, Dtypes, Formats +from torch_npu.testing.util_test import create_dtype_tensor + + +# For testing TestCase methods and torch_npu.testing functions +class TestTesting(TestCase): + # Ensure that assertEqual handles cpu arrays properly + @Dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, + torch.bool, + torch.complex64, torch.complex128) + @Formats(0, 3, 4, 29) + def test_assert_equal_cpu(self, device, dtype, npu_format): + S = 10 + test_sizes = [ + (), + (0,), + (S,), + (S, S), + (0, S), + (S, 0) + ] + for test_size in test_sizes: + a_cpu, a_npu = create_dtype_tensor(test_size, dtype, npu_format, device=device) + msg = f'Device: {device} Size: {test_size} Dtype: {dtype} Npu_format: {npu_format}' + self.assertEqual(a_cpu, a_npu, message=msg) + + +instantiate_device_type_tests(TestTesting, globals(), except_for="cpu") + +if __name__ == '__main__': + run_tests() \ No newline at end of file diff --git a/torch_npu/testing/common_device_type.py b/torch_npu/testing/common_device_type.py index b85f001bc0d..37e74ff3fd9 100644 --- a/torch_npu/testing/common_device_type.py +++ b/torch_npu/testing/common_device_type.py @@ -19,7 +19,8 @@ import threading from functools import wraps import unittest import torch -from torch.testing._internal.common_utils import TestCase, TEST_MKL +from torch.testing._internal.common_utils import TEST_MKL +from torch_npu.testing.common_utils import TestCase # Note: Generic Device-Type Testing # @@ -87,14 +88,6 @@ from torch.testing._internal.common_utils import TestCase, TEST_MKL # becomes test_erfinv_cpu_float, test_erfinv_cpu_double, test_erfinv_npu_half, # ... # -# In short, if you write a test signature like -# def textX(self, device) -# You are effectively writing -# def testX_cpu(self, device='cpu') -# def textX_npu(self, device='npu') -# def testX_xla(self, device='xla') -# ... 
-# # These tests can be run directly like normal tests: # "python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu" # diff --git a/torch_npu/testing/common_utils.py b/torch_npu/testing/common_utils.py index 5b2143b8c72..34004a84a36 100644 --- a/torch_npu/testing/common_utils.py +++ b/torch_npu/testing/common_utils.py @@ -64,6 +64,7 @@ import torch.backends.mkl from enum import Enum from torch.autograd import gradcheck from torch.autograd.gradcheck import gradgradcheck +from torch_npu.testing.util_test import set_npu_device torch.backends.disable_global_flags() @@ -587,9 +588,11 @@ class TestCase(expecttest.TestCase): def __init__(self, method_name='runTest'): super(TestCase, self).__init__(method_name) - def setUp(self): - + @classmethod + def setUpClass(self): + self.npu_device = set_npu_device() + def setUp(self): if TEST_SKIP_FAST: if not getattr(self, self._testMethodName).__dict__.get('slow_test', False): raise unittest.SkipTest("test is fast; we disabled it with PYTORCH_TEST_SKIP_FAST") diff --git a/torch_npu/testing/util_test.py b/torch_npu/testing/util_test.py index 835814c30d4..40c00f3d3bb 100644 --- a/torch_npu/testing/util_test.py +++ b/torch_npu/testing/util_test.py @@ -19,29 +19,36 @@ import torch_npu import numpy as np import os -threshold = 1.e-4 -threshold2 = 1.e-3 -npu_device = os.environ.get('SET_NPU_DEVICE') -if npu_device is None: - npu_device = "npu:0" -else: - npu_device = f"npu:{npu_device}" -torch.npu.set_device(npu_device) -print(f"Your device is {npu_device}") +UT_FAST_MODE = os.getenv('UT_FAST_MODE') == '1' -threshold = 1.e-4 -threshold2 = 1.e-3 -UT_FAST_MODE = os.getenv('UT_FAST_MODE') == '1' +def set_npu_device(): + npu_device = get_npu_device() + torch.npu.set_device(npu_device) + print(f"Your device is {npu_device}") + return npu_device -def create_common_tensor(item, minValue, maxValue): + +def get_npu_device(): + npu_device = os.environ.get('SET_NPU_DEVICE') + if npu_device is None: + npu_device = "npu:0" + else: + npu_device = f"npu:{npu_device}" + return npu_device + + +def create_common_tensor(item, minValue, maxValue, device=None): + if device is None: + device = get_npu_device() + dtype = item[0] npu_format = item[1] shape = item[2] input1 = np.random.uniform(minValue, maxValue, shape).astype(dtype) cpu_input = torch.from_numpy(input1) - npu_input = torch.from_numpy(input1).to(npu_device) + npu_input = torch.from_numpy(input1).to(device) if npu_format != -1: npu_input = torch_npu.npu_format_cast(npu_input, npu_format) return cpu_input, npu_input @@ -72,16 +79,19 @@ def compare_res_new(cpu_output, npu_output, testcase_name): print('testcase_name={0}, datatype={1} shape={2} pass!'.format(testcase_name, cpu_output.dtype, cpu_output.shape)) -def __generate_2args_broadcast_cases(): +def __generate_2args_broadcast_cases(device=None): + if device is None: + device = get_npu_device() + # Set broadcast and no axis, i.e. broadcasting 1. 
X = np.random.rand(2, 3, 4, 5).astype(np.float32) Y = np.random.rand(1, 1, 1).astype(np.float32) cpu_x = torch.from_numpy(X) - npu_x = torch.from_numpy(X).to(npu_device) + npu_x = torch.from_numpy(X).to(device) cpu_y = torch.from_numpy(Y) - npu_y = torch.from_numpy(Y).to(npu_device) + npu_y = torch.from_numpy(Y).to(device) yield cpu_x, cpu_y, npu_x, npu_y @@ -90,10 +100,10 @@ def __generate_2args_broadcast_cases(): Y = np.random.rand(4, 5).astype(np.float32) cpu_x = torch.from_numpy(X) - npu_x = torch.from_numpy(X).to(npu_device) + npu_x = torch.from_numpy(X).to(device) cpu_y = torch.from_numpy(Y) - npu_y = torch.from_numpy(Y).to(npu_device) + npu_y = torch.from_numpy(Y).to(device) yield cpu_x, cpu_y, npu_x, npu_y @@ -106,7 +116,11 @@ def test_2args_broadcast(fn): return output_list -def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5, no_zero=False): + +def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5, no_zero=False, device=None): + if device is None: + device = get_npu_device() + if dtype == torch.bool: x = np.random.randint(0, 2, size=shape).astype(np.bool) @@ -124,7 +138,7 @@ def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5, x = np.where(x != 0, x, ones) cpu_input = torch.from_numpy(x) - npu_input = torch.from_numpy(x).to(npu_device) + npu_input = torch.from_numpy(x).to(device) if npu_format != -1 and (dtype in [torch.float, torch.half]): npu_input = torch_npu.npu_format_cast(npu_input, npu_format) return cpu_input, npu_input \ No newline at end of file -- Gitee From 2847551692cddf44591773636eca0335dbe07971 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E8=B6=85?= Date: Mon, 14 Feb 2022 18:31:37 +0800 Subject: [PATCH 2/2] replace ptcopy with asstride in trans-contiguous Add Var Operator --- README.en.md | 1 - README.zh.md | 1 - env.sh | 1 - test/test_network_ops/test_var.py | 430 ++++++++++++++++++ .../test_as_strided_copy_to_contiguous.py | 70 +++ ...t_tri_combined_views_copy_to_contiguous.py | 97 ++++ torch_npu/csrc/aten/common/CopyKernel.cpp | 88 +--- torch_npu/csrc/aten/common/CopyKernelNpu.cpp | 2 +- torch_npu/csrc/aten/npu_native_functions.yaml | 2 + .../csrc/aten/ops/AsStridedKernelNpu.cpp | 63 +++ torch_npu/csrc/aten/ops/VarKernelNpu.cpp | 225 +++++++++ torch_npu/csrc/framework/OpCommandBase.h | 8 + .../csrc/framework/contiguous/slice_opt.cpp | 2 +- torch_npu/csrc/register/OptionsManager.cpp | 8 - torch_npu/csrc/register/OptionsManager.h | 1 - torch_npu/testing/util_test.py | 20 +- 16 files changed, 926 insertions(+), 93 deletions(-) create mode 100644 test/test_network_ops/test_var.py create mode 100644 test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py create mode 100644 test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py create mode 100644 torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp create mode 100644 torch_npu/csrc/aten/ops/VarKernelNpu.cpp diff --git a/README.en.md b/README.en.md index 427b8a26456..a7519564f45 100644 --- a/README.en.md +++ b/README.en.md @@ -73,7 +73,6 @@ The following environment variables are function classes used in NPU scenarios o ``` export TASK_QUEUE_ENABLE=1 # Delivered by an asynchronous task to asynchronously call the ACL interface. You are advised to enable this environment variable and set its value to 1. -export PTCOPY_ENABLE=1 # Use the PTCopy operator mode to accelerate continuous rotation and copy. You are advised to enable this environment variable and set its value to 1. 
``` The following are optional environment variables that may affect running models: diff --git a/README.zh.md b/README.zh.md index 262022df682..e9fcb196e59 100644 --- a/README.zh.md +++ b/README.zh.md @@ -77,7 +77,6 @@ source pytorch/env.sh ``` export TASK_QUEUE_ENABLE=1 # 使用异步任务下发,异步调用acl接口,建议默认开启,开启设置为1 -export PTCOPY_ENABLE=1 # 使用PTCopy算子模式,加速转连续及copy等过程,建议默认开启,开启设置为1 ``` 可选的环境变量可能会对运行的模型产生影响: diff --git a/env.sh b/env.sh index d35a02be766..d17e4862485 100644 --- a/env.sh +++ b/env.sh @@ -96,7 +96,6 @@ export LD_LIBRARY_PATH=${path_lib}:$LD_LIBRARY_PATH # pytorch 自定义环境变量 export TASK_QUEUE_ENABLE=0 # 使用异步任务下发,异步调用acl接口,建议默认开启,开启设置为1 -export PTCOPY_ENABLE=1 # 使用PTCopy算子模式,加速转连续及copy等过程,建议默认开启,开启设置为1 #export DYNAMIC_COMPILE_ENABLE=1 # 动态shape特性功能,针对shape变化场景,可选 开启设置为1 # log diff --git a/test/test_network_ops/test_var.py b/test/test_network_ops/test_var.py new file mode 100644 index 00000000000..7b07deed853 --- /dev/null +++ b/test/test_network_ops/test_var.py @@ -0,0 +1,430 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestVar(TestCase): + def cpu_op_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.numpy() + return output + + def npu_op_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.numpy() + return output + + def npu_op_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.numpy() + return output + + def npu_op_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_out_exec(self, input1, dim, output1, unbiased=True, keepdim=False): + torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim, out=output1) + output1 = output1.numpy() + return output1 + + def npu_op_out_exec(self, input1, dim, output1, unbiased=True, keepdim=False): + torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim, out=output1) + output1 = output1.to("cpu") + output1 = output1.numpy() + return output1 + + def cpu_op_var_exec(self, input1, 
unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.numpy() + return output + + def npu_op_var_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_mean_exec(self, input1, unbiased=True): + output = torch.var_mean(input1, unbiased=unbiased) + output1 = output[0] + output2 = output[1] + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def npu_op_mean_exec(self, input1, unbiased=True): + output = torch.var_mean(input1, unbiased=unbiased) + output1 = output[0].to("cpu") + output2 = output[1].to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def cpu_op_mean_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0] + output2 = output[1] + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def npu_op_mean_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0].to("cpu") + output2 = output[1].to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def cpu_op_mean_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0] + output2 = output[1] + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def npu_op_mean_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0].to("cpu") + output2 = output[1].to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def output_shape(self, inputshape, dim, unbiased=True, keepdim=False): + shape = list(inputshape) + len1 = len(inputshape) + if dim < len1 and keepdim == True: + shape[dim] = 1 + elif dim < len1 and keepdim == False: + shape.pop(dim) + return shape + + def create_output_tensor(self, minvalue, maxvalue, shape, npuformat, dtype): + input1 = np.random.uniform(minvalue, maxvalue, shape).astype(dtype) + cpu_input = torch.from_numpy(input1) + npu_input = torch.from_numpy(input1).npu() + if npuformat != -1: + npu_input = npu_input.npu_format_cast(npuformat) + return cpu_input, npu_input + + def test_var_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input, item[3]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_exec(cpu_input, item[3]) + npu_output = self.npu_op_exec(npu_input, item[3]) + 
self.assertRtolEqual(cpu_output, npu_output) + + def test_var_dim_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_dim_exec(cpu_input, item[3], item[4], item[5]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_dim_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_dim_exec(cpu_input, item[3], item[4], item[5]) + npu_output = self.npu_op_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_names_dim_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_names_dim_exec(cpu_input, item[3], item[4], item[5]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_names_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_names_dim_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_names_dim_exec(cpu_input, item[3], item[4], item[5]) + npu_output = self.npu_op_names_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + + def test_var_out_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + outputshape = self.output_shape(item[2],item[3],item[4],item[5]) + cpu_output,npu_output = self.create_output_tensor(0,1,outputshape,item[1],item[0]) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = cpu_output.to(torch.float32) + cpu_output1 = self.cpu_op_out_exec(cpu_input1, item[3], cpu_output, 
item[4], item[5]) + npu_output1 = self.npu_op_out_exec(npu_input1, item[3], npu_output, item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_var_out_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + outputshape = self.output_shape(item[2], item[3], item[4], item[5]) + cpu_output,npu_output = self.create_output_tensor(0, 1, outputshape, item[1], item[0]) + cpu_output1 = self.cpu_op_out_exec(cpu_input1, item[3], cpu_output, item[4], item[5]) + npu_output1 = self.npu_op_out_exec(npu_input1, item[3], npu_output, item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test__var_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, l] for i in format_list for j in shape_list + for l in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_var_exec(cpu_input, item[3]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_var_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test__var_shape_format_fp32(self,device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, l] for i in format_list for j in shape_list + for l in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_var_exec(cpu_input, item[3]) + npu_output = self.npu_op_var_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_mean_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, k] for i in format_list for j in shape_list + for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output1, cpu_output2 = self.cpu_op_mean_exec(cpu_input, item[3]) + cpu_output1 = cpu_output1.astype(np.float16) + cpu_output2 = cpu_output2.astype(np.float16) + npu_output1, npu_output2 = self.npu_op_mean_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list + for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output1, cpu_output2 = self.cpu_op_mean_exec(cpu_input, item[3]) + npu_output1, npu_output2 = self.npu_op_mean_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_dim_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + 
dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output1, cpu_output2 = self.cpu_op_mean_dim_exec(cpu_input, item[3], item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + cpu_output2 = cpu_output2.astype(np.float16) + npu_output1, npu_output2 = self.npu_op_mean_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_dim_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 1024], [32, 8, 1024]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output1, cpu_output2 = self.cpu_op_mean_dim_exec(cpu_input, item[3], item[4], item[5]) + npu_output1, npu_output2 = self.npu_op_mean_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_names_dim_shape_format_fp16(self, device): + shape = (1024, 8, 32) + dimlist = ['N', 'C', 'H'] + cpu_input = torch.rand(shape, dtype=torch.float32) + npu_input = cpu_input.npu() + dim = np.random.choice(dimlist) + npu_input = npu_input.to(torch.float16) + cpu_input.names = ['N', 'C', 'H'] + npu_input.names = ['N', 'C', 'H'] + cpu_output1, cpu_output2 = self.cpu_op_mean_names_dim_exec(cpu_input, dim=dim) + cpu_output1 = cpu_output1.astype(np.float16) + cpu_output2 = cpu_output2.astype(np.float16) + npu_output1, npu_output2 = self.npu_op_mean_names_dim_exec(npu_input, dim=dim) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_names_dim_shape_format_fp32(self, device): + shape = (1024, 8, 32) + dimlist = ['N', 'C', 'H'] + cpu_input = torch.rand(shape, dtype=torch.float32, names=('N', 'C', 'H')) + npu_input = cpu_input.npu() + dim = np.random.choice(dimlist) + dims = dimlist.index(dim) + cpu_output1, cpu_output2 = self.cpu_op_mean_names_dim_exec(cpu_input, dim=dim) + npu_output1, npu_output2 = self.npu_op_mean_names_dim_exec(npu_input, dim=dim) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_dim_shape_format_5d_fp16(self, device): + format_list = [-1] + shape_list = [[2, 94, 4, 52, 192]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output1 = self.cpu_op_dim_exec(cpu_input1, item[3], item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + npu_output1 = self.npu_op_dim_exec(npu_input1, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1, prec16=0.004) + +instantiate_device_type_tests(TestVar, globals(), except_for="cpu") +if 
__name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py b/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py new file mode 100644 index 00000000000..78d05c1681d --- /dev/null +++ b/test/test_trans_contiguous/test_as_strided_copy_to_contiguous.py @@ -0,0 +1,70 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof + +os.environ["COMBINED_ENABLE"] = "1" # Open combined-view cases optimization + +class TestAsStridedCopyToContiguous(TestCase): + def cpu_op_exec(self, input1, size, stride, storage_offset): + output = torch.as_strided(input1, size, stride, storage_offset).contiguous() + output = output.numpy() + return output + + def npu_op_exec(self,input1, size, stride, storage_offset): + with torch.autograd.profiler.profile(use_npu=True) as prof: + output = torch.as_strided(input1, size, stride, storage_offset).contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']) \ + , True, "Error operators called!") + output = output.cpu().numpy() + return output + + def test_as_strided(self, device): + dtype_list = [np.bool, np.int32, np.float16, np.float32, np.int8, np.uint8, np.int64] + format_list = [-1] + small_shape_list = [ + [5, 5] + ] + small_shape_format = [ + [i, j, k] for i in dtype_list for j in format_list for k in small_shape_list + ] + + for item in small_shape_format: + cpu_input, npu_input = create_common_tensor(item, -100, 100) + cpu_output = self.cpu_op_exec(cpu_input, (3, 3), (1, 2), 1) + npu_output = self.npu_op_exec(npu_input, (3, 3), (1, 2), 1) + self.assertRtolEqual(cpu_output, npu_output) + + other_shape_format = [ + [[np.float16, 0, [13, 23]], (10, 15), (1, 2), 1], + [[np.float16, 3, [2, 13, 23]], (10, 15), (1, 2), 2], + [[np.float32, 29, [6, 32, 8, 2]], (8, 6, 2), (5, 4, 1), 3], + ] + + for item in other_shape_format: + cpu_input, npu_input = create_common_tensor(item[0], -100, 100) + cpu_output = self.cpu_op_exec(cpu_input, item[1], item[2], item[3]) + npu_output = self.npu_op_exec(npu_input, item[1], item[2], item[3]) + self.assertRtolEqual(cpu_output, npu_output) + +instantiate_device_type_tests(TestAsStridedCopyToContiguous, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py b/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py new file mode 100644 index 00000000000..115d6059c9f --- /dev/null +++ b/test/test_trans_contiguous/test_tri_combined_views_copy_to_contiguous.py @@ -0,0 +1,97 @@ +# Copyright (c) 2020, Huawei Technologies.All 
rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof + +os.environ["COMBINED_ENABLE"] = "1" # Open combined-view cases optimization + +class TestTriCombinedViewsCopyToContiguous(TestCase): + def test_view_narrow_permute_copy_contiguous(self, device): + dtype_list1 = [np.float16, np.float32] + format_list1 = [-1] + shape_list1 = [ + [200, 30, 40, 16], + ] + shape_format = [ + [i, j, k] for i in dtype_list1 for j in format_list1 for k in shape_list1 + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + # case 1: view+narrow+permute ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out1 = npu_input.view(npu_input.size(0) * npu_input.size(1), npu_input.size(2), npu_input.size(3)) \ + [:,1:10].transpose(0, 1).contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), \ + True, "Error operators called!") + cpu_out1 = cpu_input.view(cpu_input.size(0) * cpu_input.size(1), cpu_input.size(2), cpu_input.size(3)) \ + [:,1:10].transpose(0, 1).contiguous() + self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy()) + + # case 2: permute+view+narrow ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out2 = npu_input.permute(1, 0, 2, 3). \ + view(npu_input.size(1), npu_input.size(0), npu_input.size(2)*npu_input.size(3)) \ + [:,:,1:10].contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), \ + True, "Error operators called!") + cpu_out2 = cpu_input.permute(1, 0, 2, 3). 
\ + view(cpu_input.size(1), cpu_input.size(0), cpu_input.size(2)*cpu_input.size(3)) \ + [:,:,1:10].contiguous() + self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy()) + + def test_view_select_permute_copy_contiguous(self, device): + dtype_list2 = [np.float16, np.float32] + format_list2 = [-1] + shape_list2 = [ + [200, 30, 40, 16], + ] + shape_format = [ + [i, j, k] for i in dtype_list2 for j in format_list2 for k in shape_list2 + ] + + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + # case 1: view+select+permute ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out1 = npu_input.view(npu_input.size(0) * npu_input.size(1), npu_input.size(2), npu_input.size(3)) \ + [:,1].transpose(0, 1).contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), \ + True, "Error operators called!") + cpu_out1 = cpu_input.view(cpu_input.size(0) * cpu_input.size(1), cpu_input.size(2), cpu_input.size(3)) \ + [:,1].transpose(0, 1).contiguous() + self.assertRtolEqual(npu_out1.to("cpu").numpy(), cpu_out1.numpy()) + + # case 2: permute+view+select ==> cannot be optimized + with torch.autograd.profiler.profile(use_npu=True) as prof: + npu_out2 = npu_input.permute(1, 0, 2, 3). \ + view(npu_input.size(1), npu_input.size(0), npu_input.size(2)*npu_input.size(3)) \ + [:,:,2].contiguous() + self.assertEqual(check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined']), \ + True, "Error operators called!") + cpu_out2 = cpu_input.permute(1, 0, 2, 3). \ + view(cpu_input.size(1), cpu_input.size(0), cpu_input.size(2)*cpu_input.size(3)) \ + [:,:,2].contiguous() + self.assertRtolEqual(npu_out2.to("cpu").numpy(), cpu_out2.numpy()) + +instantiate_device_type_tests(TestTriCombinedViewsCopyToContiguous, globals(), except_for='cpu') +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/common/CopyKernel.cpp b/torch_npu/csrc/aten/common/CopyKernel.cpp index e511fb30dd2..935b225506c 100644 --- a/torch_npu/csrc/aten/common/CopyKernel.cpp +++ b/torch_npu/csrc/aten/common/CopyKernel.cpp @@ -33,75 +33,6 @@ namespace at_npu { namespace native { namespace { -// src : host <-- device -// | copy src to dst on cpu -// dst : host --> device -void copy_d2d_via_host(at::Tensor& self, const at::Tensor& src, bool same_type) { - c10::npu::NPUStream copy_stream = c10::npu::getCurrentNPUStream(); - aclError error = aclrtSynchronizeStream(copy_stream); - if (error != ACL_ERROR_NONE) { - AT_ERROR("ACL stream synchronize failed."); - return; - } - - int64_t real_bytes = - StorageDescHelper::GetValidMemorySize(src) * src.element_size(); - auto cpu_src = at::empty( - real_bytes / src.element_size(), src.options().device(at::kCPU)); - cpu_src = cpu_src.as_strided(src.sizes(), src.strides()); - - error = aclrtMemcpy( - cpu_src.data_ptr(), - real_bytes, - src.data_ptr(), - real_bytes, - ACL_MEMCPY_DEVICE_TO_HOST); - if (error != ACL_ERROR_NONE) { - AT_ERROR("aclrtMemcpy device to cpu_src error."); - return; - } - - real_bytes = - StorageDescHelper::GetValidMemorySize(self) * self.element_size(); - auto cpu_dst = at::empty( - real_bytes / self.element_size(), self.options().device(at::kCPU)); - cpu_dst = cpu_dst.as_strided(self.sizes(), self.strides()); - - if (!same_type) { - cpu_src = cpu_src.to(cpu_dst.dtype()); - } - - // sometimes npu_dst just need part of cpu_dst's elements, so we do memory - // copy from npu to cpu here, let npu_dst cover cpu_dst, to avoid unneeded - // 
cpu_dst's elements cover npu_dst's original elements - if ((!cpu_dst.is_contiguous()) && (self.defined())) { - error = aclrtMemcpy( - cpu_dst.data_ptr(), - real_bytes, - self.data_ptr(), - real_bytes, - ACL_MEMCPY_DEVICE_TO_HOST); - if (error != ACL_ERROR_NONE) { - AT_ERROR("ACL_Memcpy device to cpu_dst error."); - return; - } - } - - cpu_dst.copy_(cpu_src); - - error = aclrtMemcpy( - self.data_ptr(), - real_bytes, - cpu_dst.data_ptr(), - real_bytes, - ACL_MEMCPY_HOST_TO_DEVICE); - if (error != ACL_ERROR_NONE) { - AT_ERROR("aclrtMemcpy cpu_dst to device error."); - return; - } - NPU_LOGD("Src or dst is not contiguous when do device to device copy."); -} - // NOTE: helper function of copy, the input parameter is not checked, The caller // needs to ensure that the parameters are correct. @@ -132,14 +63,8 @@ void copy_d2d_last_method( bool same_type, bool non_blocking) { // general copy method but Low performance - if (torch_npu::option::OptionsManager::CheckPTcopy_Enable()) { - RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector({src})); - copy_kernel_npu(self, src, non_blocking); - } else { - RECORD_FUNCTION( - "d2dCopyWithStreamSynchronize", std::vector({src})); - copy_d2d_via_host(self, src, same_type); - } + RECORD_FUNCTION("d2dCopyWithPTCopy", std::vector({src})); + copy_kernel_npu(self, src, non_blocking); } // the dst and src are same format now @@ -150,15 +75,22 @@ void copy_d2d_dtype_baseformat( const at::Tensor& src, bool non_blocking) { if (!self.is_contiguous()) { + // Contiguous/discontiguous source tensor copy to discontiguous self tensor return copy_d2d_last_method(self, src, true, non_blocking); } if (!src.is_contiguous()) { - // discontiguous + // Discontiguous source tensor copy to contiguous self tensor if (TransContiguous::ContiguousOptimizeWithBaseFormat(self, src)) { + // Optimized trans-contiguous method + return; + } else { + // General trans-contiguous method + NPUNativeFunctions::npu_stride_copy_out(src, src.sizes(), src.strides(), src.storage_offset(), self); return; } } else { + // Contiguous source tensor copy to contiguous self tensor int64_t numel = self.numel(); if (numel == src.numel()) { RECORD_FUNCTION("d2dCopyAsync", std::vector({src})); diff --git a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp index a283c31d63d..cb669c4b2df 100644 --- a/torch_npu/csrc/aten/common/CopyKernelNpu.cpp +++ b/torch_npu/csrc/aten/common/CopyKernelNpu.cpp @@ -28,7 +28,7 @@ c10::SmallVector get_view_value( static c10::SmallVector value; // It is determined by the definition of view attr value.resize(strides.size() + 3); - value[0] = t.numel(); // storageImpl numel + value[0] = t.storage().nbytes() / t.element_size(); // storageImpl numel value[1] = t.storage_offset(); // default to 0 value[2] = strides.size(); for (size_t i = 0; i < strides.size(); i++) { diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 167c1963242..0ab0fa4062e 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -1916,6 +1916,8 @@ custom: variants: function, method - func: npu_softmax_cross_entropy_with_logits_backward(Tensor grad, Tensor self, Tensor labels) -> Tensor variants: function, method + - func: npu_stride_copy(Tensor self, int[] shape, int[] stride, Scalar storage_offset) -> Tensor + - func: npu_stride_copy.out(Tensor self, int[] shape, int[] stride, Scalar storage_offset, *, Tensor(a!) out) -> Tensor(a!) 
custom_autograd: - func: npu_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor - func: npu_convolution_transpose(Tensor input, Tensor weight, Tensor? bias, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups) -> Tensor diff --git a/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp new file mode 100644 index 00000000000..3fb30d4b22d --- /dev/null +++ b/torch_npu/csrc/aten/ops/AsStridedKernelNpu.cpp @@ -0,0 +1,63 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include + +namespace at_npu { +namespace native { + +at::Tensor& stride_copy_out_npu_nocheck( + at::Tensor& result, + const at::Tensor& self, + at::IntArrayRef shape, + at::IntArrayRef stride, + at::Scalar storage_offset) { + RECORD_FUNCTION("npuAsStrided", std::vector({self})); + OpCommand cmd; + cmd.Name("AsStrided") + .InputWithoutContiguous(self) + .Input(shape) + .Input(stride) + .Input(storage_offset, at::kLong) + .Output(result) + .Run(); + return result; +} + +at::Tensor& NPUNativeFunctions::npu_stride_copy_out( + const at::Tensor& self, + c10::IntArrayRef shape, + c10::IntArrayRef stride, + c10::Scalar storage_offset, + at::Tensor& result) { + stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset); + return result; +} + +at::Tensor NPUNativeFunctions::npu_stride_copy( + const at::Tensor& self, + c10::IntArrayRef shape, + c10::IntArrayRef stride, + c10::Scalar storage_offset) { + // AsStrided OP only supports ND input + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + shape, self.options(), ACL_FORMAT_ND); + stride_copy_out_npu_nocheck(result, self, shape, stride, storage_offset); + return result; +} + +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/VarKernelNpu.cpp b/torch_npu/csrc/aten/ops/VarKernelNpu.cpp new file mode 100644 index 00000000000..2c3302aa7ae --- /dev/null +++ b/torch_npu/csrc/aten/ops/VarKernelNpu.cpp @@ -0,0 +1,225 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +auto check_and_trans_dim(const at::Tensor& self, at::IntArrayRef dim) { + int64_t dim_size = self.dim(); + int64_t ne_dim_size = dim_size * -1; + std::vector result_dim; + for(int64_t i = 0; i < dim.size(); i++) { + if(dim[i] >= ne_dim_size && dim[i] <= (dim_size - 1)) { + int64_t tmp_dim = CalcuOpUtil::make_wrap_dim(dim[i], self.dim()); + result_dim.emplace_back(tmp_dim); + } else { + AT_ERROR("dim value should be in the range of [-n, n-1], n is the dimension number of input tensor."); + } + } + std::sort(result_dim.begin(), result_dim.end()); + return result_dim; +} + +auto get_result_names(const at::Tensor& self, at::IntArrayRef dim, bool keepdim){ + auto names = self.names(); + std::vector result_names; + for(int64_t i = 0; i < names.size(); i++){ + result_names.emplace_back(names[i]); + } + if(!keepdim){ + for(int64_t i = dim.size() - 1; i >= 0; i--){ + int64_t need_remove_dim = dim[i]; + result_names.erase(result_names.begin() + need_remove_dim); + } + } + return result_names; +} + +at::Tensor& var_after_npu_nocheckout( + at::Tensor& var, + const at::Tensor& self, + const at::Tensor& mean_broadcast, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + bool if_std = false; + OpCommand cmd; + cmd.Name("ReduceStdV2Update") + .Input(self) + .Input(mean_broadcast) + .Output(var) + .Attr("dim", dim) + .Attr("if_std", if_std) + .Attr("unbiased", unbiased) + .Attr("keepdim", keepdim) + .Run(); + return var; +} + +tuple var_mean_compute( + at::Tensor& variance, + at::Tensor& mean, + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto meanOutputSizeKeepDim = var_npu_output_size(self, dim, true); + auto meanOutputSizeNotKeepDim = var_npu_output_size(self, dim, false); + mean = at::mean(self, dim, false); + mean.resize_(meanOutputSizeKeepDim); + at::Tensor mean_broadcast = NPUNativeFunctions::npu_broadcast(mean, self.sizes()); + if(!keepdim){ + mean.resize_(meanOutputSizeNotKeepDim); + } + var_after_npu_nocheckout(variance, self, mean_broadcast, dim, unbiased, keepdim); + return tuple(variance, mean); +} + +tuple var_mean_out_npu( + at::Tensor& variance, + at::Tensor& mean, + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto dim_now = check_and_trans_dim(self, dim); + auto meanOutputSizeKeepDim = var_npu_output_size(self, dim_now, true); + auto meanOutputSizeNotKeepDim = var_npu_output_size(self, dim_now, false); + auto ori_type = self.scalar_type(); + if(ori_type != c10::ScalarType::Half && ori_type != c10::ScalarType::Float) { + AT_ERROR("Var Mean only support float16 or float32 type."); + } + if(variance.scalar_type() != mean.scalar_type() || variance.scalar_type() != ori_type) { + AT_ERROR("mean's type and variance' type must be equal to input's type."); + } + var_mean_compute( + variance, + mean, + self, + dim_now, + unbiased, + keepdim); + + return tuple(variance, mean); +} + +at::Tensor& NPUNativeFunctions::var_out( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim, + at::Tensor& var) { + // check and trans dim + auto dim_now = check_and_trans_dim(self, dim); + auto outputSize = var_npu_output_size(self, dim_now, keepdim); + + // construct the output mean tensor of the NPU + at::Tensor mean = OpPreparation::ApplyTensor(self, outputSize); + at::Tensor var_ = 
OpPreparation::ApplyTensor(self, outputSize); + + var_mean_out_npu(var_, mean, self, dim, unbiased, keepdim); + OpPreparation::CheckOut( + {var_}, + var, + var_); + var.copy_(var_); + return var; +} + +at::Tensor& NPUNativeFunctions::var_out( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim, + at::Tensor& var) { + return NPUNativeFunctions::var_out( + self, dimnames_to_positions(self, dim), unbiased, keepdim, var); +} + +at::Tensor NPUNativeFunctions::var(const at::Tensor& self, bool unbiased) { + bool keepdim = false; + c10::SmallVector dim = CalcuOpUtil::get_dimlist_for_tensor(self); + + return NPUNativeFunctions::var(self, dim, unbiased, keepdim); +} + +at::Tensor NPUNativeFunctions::var( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto dim_now = check_and_trans_dim(self, dim); + // calculate the output size + auto outputSize = var_npu_output_size(self, dim_now, keepdim); + + // construct the output tensor of the NPU + at::Tensor variance = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + NPUNativeFunctions::var_out(self, dim, unbiased, keepdim, variance); + + return variance; +} + +at::Tensor NPUNativeFunctions::var( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim) { + return NPUNativeFunctions::var(self, dimnames_to_positions(self, dim), unbiased, keepdim); +} + +at::Tensor _var_npu(const at::Tensor& self, bool unbiased) { + return at::var(self, unbiased); +} + +tuple NPUNativeFunctions::var_mean( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim) { + return NPUNativeFunctions::var_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim); +} + +tuple NPUNativeFunctions::var_mean( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto dim_now = check_and_trans_dim(self, dim); + // calculate the output size + auto outputSize = var_npu_output_size(self, dim_now, keepdim); + + // construct the output tensor of the NPU + at::Tensor variance = OpPreparation::ApplyTensor(self, outputSize); + + at::Tensor mean = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + var_mean_out_npu(variance, mean, self, dim, unbiased, keepdim); + + return tuple(variance, mean); +} + +tuple NPUNativeFunctions::var_mean(const at::Tensor& self, bool unbiased) { + c10::SmallVector dim = CalcuOpUtil::get_dimlist_for_tensor(self); + + return NPUNativeFunctions::var_mean(self, dim, unbiased, false); +} +} // namespace native +} // namespace at_npu \ No newline at end of file diff --git a/torch_npu/csrc/framework/OpCommandBase.h b/torch_npu/csrc/framework/OpCommandBase.h index 069c2e4e821..bb27053170f 100644 --- a/torch_npu/csrc/framework/OpCommandBase.h +++ b/torch_npu/csrc/framework/OpCommandBase.h @@ -72,6 +72,14 @@ namespace at_npu return static_cast(*this); } + Derived &InputWithoutContiguous( + const at::Tensor &input, + const string &descName = "", + const string &realData = "") + { + return AddTensorInput(const_cast(input), at::ScalarType::Undefined, descName, realData); + } + Derived &Input() { return AddNoneTensor(); diff --git a/torch_npu/csrc/framework/contiguous/slice_opt.cpp b/torch_npu/csrc/framework/contiguous/slice_opt.cpp index 0eeea31d2ce..c43cb5cf6fa 100644 --- a/torch_npu/csrc/framework/contiguous/slice_opt.cpp +++ b/torch_npu/csrc/framework/contiguous/slice_opt.cpp @@ -27,7 +27,7 @@ namespace at_npu public: bool Optimizer(const 
at::Tensor &src, at::Tensor &self) override { - // Pattern slice. Current pattern should be used before PTcopy process. + // Pattern slice. // Current pattern does not directly depend on other patterns. // The relative sequence of this pattern and other patterns is not important. c10::SmallVector offsets; diff --git a/torch_npu/csrc/register/OptionsManager.cpp b/torch_npu/csrc/register/OptionsManager.cpp index 0afa6ce4d49..cb93b59f9df 100644 --- a/torch_npu/csrc/register/OptionsManager.cpp +++ b/torch_npu/csrc/register/OptionsManager.cpp @@ -30,14 +30,6 @@ bool OptionsManager::CheckQueueEnable() { return (queue_enable == 1); } -bool OptionsManager::CheckPTcopy_Enable() { - static int32_t PTcopy__enable = -1; - if (PTcopy__enable == -1) { - PTcopy__enable = GetBoolTypeOption("PTCOPY_ENABLE"); - } - return (PTcopy__enable == 1); -} - bool OptionsManager::CheckCombinedOptimizerEnable() { static int32_t combined_optimize = -1; if (combined_optimize == -1) { diff --git a/torch_npu/csrc/register/OptionsManager.h b/torch_npu/csrc/register/OptionsManager.h index dd59ea6513c..0bae5ab2005 100644 --- a/torch_npu/csrc/register/OptionsManager.h +++ b/torch_npu/csrc/register/OptionsManager.h @@ -27,7 +27,6 @@ namespace option { class OptionsManager { public: static bool CheckQueueEnable(); - static bool CheckPTcopy_Enable(); static bool CheckCombinedOptimizerEnable(); static bool CheckTriCombinedOptimizerEnable(); static bool CheckAclDumpDateEnable(); diff --git a/torch_npu/testing/util_test.py b/torch_npu/testing/util_test.py index 40c00f3d3bb..6294c267847 100644 --- a/torch_npu/testing/util_test.py +++ b/torch_npu/testing/util_test.py @@ -141,4 +141,22 @@ def create_dtype_tensor(shape, dtype, npu_format=-1, min_value=-5, max_value=5, npu_input = torch.from_numpy(x).to(device) if npu_format != -1 and (dtype in [torch.float, torch.half]): npu_input = torch_npu.npu_format_cast(npu_input, npu_format) - return cpu_input, npu_input \ No newline at end of file + return cpu_input, npu_input + +def check_operators_in_prof(expected_operators, prof, unexpected_operators=None): + unexpected_operators = unexpected_operators or [] + prof_key_averages = prof.key_averages() + if not prof_key_averages: + return print("torch profiling is empty, please check it") + for prof_item in prof_key_averages: + if prof_item.key in unexpected_operators: + # if unexpected oprators are called, pattern inferring in trans-contiguous is failed + return False + elif prof_item.key in expected_operators: + # if expected oprator is called, empty it in expected_operators list + expected_operators.remove(prof_item.key) + + # if expected_operators list is empty, all oprators have been called + if not expected_operators: + return True + return False \ No newline at end of file -- Gitee
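
Note: a minimal sketch of how the device-aware helpers added to torch_npu/testing/util_test.py in PATCH 1/2 are meant to be used, assuming torch_npu is installed and an NPU device is available (SET_NPU_DEVICE selects the device index and defaults to npu:0):

```
import torch
from torch_npu.testing.util_test import set_npu_device, create_dtype_tensor

# Pick the device from the SET_NPU_DEVICE environment variable (defaults to
# "npu:0") and make it the current NPU device.
npu_device = set_npu_device()

# Build a CPU/NPU tensor pair of the requested dtype; passing device explicitly
# overrides the internal get_npu_device() lookup.
cpu_t, npu_t = create_dtype_tensor((4, 4), torch.float32, device=npu_device)
assert torch.allclose(cpu_t, npu_t.cpu())
```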
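Note: a sketch of how check_operators_in_prof is driven by the new trans-contiguous tests in PATCH 2/2: the profiled as_strided copy should hit the AsStrided kernel rather than the combined-view optimization. Operator names and the profiler flag follow the tests above; again this assumes an available NPU device.

```
import numpy as np
import torch
import torch_npu

from torch_npu.testing.util_test import create_common_tensor, check_operators_in_prof

# CPU/NPU pair: float32, no forced NPU format (-1), shape 5x5.
cpu_input, npu_input = create_common_tensor([np.float32, -1, [5, 5]], -100, 100)

# Force a discontiguous view through the trans-contiguous copy path and record
# the NPU operators that get called.
with torch.autograd.profiler.profile(use_npu=True) as prof:
    npu_out = torch.as_strided(npu_input, (3, 3), (1, 2), 1).contiguous()

# The AsStrided kernel should appear in the profile, and the combined-view
# optimization should not.
assert check_operators_in_prof(['npuAsStrided'], prof, ['npuCombined'])

cpu_out = torch.as_strided(cpu_input, (3, 3), (1, 2), 1).contiguous()
assert np.allclose(cpu_out.numpy(), npu_out.cpu().numpy())
```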