diff --git a/test/test_network_ops/test_var.py b/test/test_network_ops/test_var.py new file mode 100644 index 0000000000000000000000000000000000000000..3eaa731092f95f189243e5fbda6be4eda04de77c --- /dev/null +++ b/test/test_network_ops/test_var.py @@ -0,0 +1,430 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch_npu +import numpy as np + +from torch_npu.testing.common_utils import TestCase, run_tests +from torch_npu.testing.common_device_type import instantiate_device_type_tests +from torch_npu.testing.util_test import create_common_tensor + +class TestVar(TestCase): + def cpu_op_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.numpy() + return output + + def npu_op_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.numpy() + return output + + def npu_op_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.numpy() + return output + + def npu_op_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_out_exec(self, input1, dim, output1, unbiased=True, keepdim=False): + torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim, out=output1) + output1 = output1.numpy() + return output1 + + def npu_op_out_exec(self, input1, dim, output1, unbiased=True, keepdim=False): + torch.var(input1, dim, unbiased=unbiased, keepdim=keepdim, out=output1) + output1 = output1.to("cpu") + output1 = output1.numpy() + return output1 + + def cpu_op_var_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.numpy() + return output + + def npu_op_var_exec(self, input1, unbiased=True): + output = torch.var(input1, unbiased=unbiased) + output = output.to("cpu") + output = output.numpy() + return output + + def cpu_op_mean_exec(self, input1, unbiased=True): + output = torch.var_mean(input1, unbiased=unbiased) + output1 = output[0] + output2 = output[1] + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def npu_op_mean_exec(self, input1, unbiased=True): + output = torch.var_mean(input1, unbiased=unbiased) + output1 = output[0].to("cpu") + output2 = output[1].to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def cpu_op_mean_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0] + output2 = output[1] + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def npu_op_mean_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0].to("cpu") + output2 = output[1].to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def cpu_op_mean_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0] + output2 = output[1] + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def npu_op_mean_names_dim_exec(self, input1, dim, unbiased=True, keepdim=False): + output = torch.var_mean(input1, dim, unbiased=unbiased, keepdim=keepdim) + output1 = output[0].to("cpu") + output2 = output[1].to("cpu") + output1 = output1.numpy() + output2 = output2.numpy() + return output1, output2 + + def output_shape(self, inputshape, dim, unbiased=True, keepdim=False): + shape = list(inputshape) + if dim < len(inputshape): + if keepdim: + shape[dim] = 1 + else: + shape.pop(dim) + return shape + + def create_output_tensor(self, minvalue, maxvalue, shape, npuformat, dtype): + input1 = np.random.uniform(minvalue, maxvalue, shape).astype(dtype) + cpu_input = torch.from_numpy(input1) + npu_input = torch.from_numpy(input1).npu() + if npuformat != -1: + npu_input = npu_input.npu_format_cast(npuformat) + return cpu_input, npu_input + + def test_var_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input, item[3]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_exec(cpu_input, item[3]) + npu_output = self.npu_op_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_dim_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_dim_exec(cpu_input, item[3], item[4], item[5]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_dim_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_dim_exec(cpu_input, item[3], item[4], item[5]) + npu_output = self.npu_op_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_names_dim_shape_format_fp16(self, device): + format_list = [-1] + shape_list1 = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list1 + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_names_dim_exec(cpu_input, item[3], item[4], item[5]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_names_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_names_dim_shape_format_fp32(self, device): + format_list = [-1] + shape_list1 = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list for j in shape_list1 + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_names_dim_exec(cpu_input, item[3], item[4], item[5]) + npu_output = self.npu_op_names_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output, npu_output) + + + def test_var_out_shape_format_fp16(self, device): + format_list1 = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list1 for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + outputshape = self.output_shape(item[2],item[3],item[4],item[5]) + cpu_output,npu_output = self.create_output_tensor(0,1,outputshape,item[1],item[0]) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output = cpu_output.to(torch.float32) + cpu_output1 = self.cpu_op_out_exec(cpu_input1, item[3], cpu_output, item[4], item[5]) + npu_output1 = self.npu_op_out_exec(npu_input1, item[3], npu_output, item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test_var_out_shape_format_fp32(self, device): + format_list1 = [-1] + shape_list = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float32, i, j, k, l, m] for i in format_list1 for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + outputshape = self.output_shape(item[2], item[3], item[4], item[5]) + cpu_output,npu_output = self.create_output_tensor(0, 1, outputshape, item[1], item[0]) + cpu_output1 = self.cpu_op_out_exec(cpu_input1, item[3], cpu_output, item[4], item[5]) + npu_output1 = self.npu_op_out_exec(npu_input1, item[3], npu_output, item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + + def test__var_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, l] for i in format_list for j in shape_list + for l in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_var_exec(cpu_input, item[3]) + cpu_output = cpu_output.astype(np.float16) + npu_output = self.npu_op_var_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test__var_shape_format_fp32(self,device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, l] for i in format_list for j in shape_list + for l in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output = self.cpu_op_var_exec(cpu_input, item[3]) + npu_output = self.npu_op_var_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output, npu_output) + + def test_var_mean_shape_format_fp16(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float16, i, j, k] for i in format_list for j in shape_list + for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output1, cpu_output2 = self.cpu_op_mean_exec(cpu_input, item[3]) + cpu_output1 = cpu_output1.astype(np.float16) + cpu_output2 = cpu_output2.astype(np.float16) + npu_output1, npu_output2 = self.npu_op_mean_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 24], [32, 8, 24]] + unbiased_list = [True, False] + shape_format = [ + [np.float32, i, j, k] for i in format_list for j in shape_list + for k in unbiased_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output1, cpu_output2 = self.cpu_op_mean_exec(cpu_input, item[3]) + npu_output1, npu_output2 = self.npu_op_mean_exec(npu_input, item[3]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_dim_shape_format_fp16(self, device): + format_list1 = [-1] + shape_list1 = [[32, 24], [32, 8, 24]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list1 for j in shape_list1 + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_input = cpu_input.to(torch.float32) + cpu_output1, cpu_output2 = self.cpu_op_mean_dim_exec(cpu_input, item[3], item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + cpu_output2 = cpu_output2.astype(np.float16) + npu_output1, npu_output2 = self.npu_op_mean_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_dim_shape_format_fp32(self, device): + format_list = [-1] + shape_list = [[32, 1024], [32, 8, 1024]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, 0, 100) + cpu_output1, cpu_output2 = self.cpu_op_mean_dim_exec(cpu_input, item[3], item[4], item[5]) + npu_output1, npu_output2 = self.npu_op_mean_dim_exec(npu_input, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_names_dim_shape_format_fp16(self, device): + shape = (1024, 8, 32) + dimlist = ['N', 'C', 'H'] + cpu_input = torch.rand(shape, dtype=torch.float32) + npu_input = cpu_input.npu() + dim = np.random.choice(dimlist) + npu_input = npu_input.to(torch.float16) + cpu_input.names = ['N', 'C', 'H'] + npu_input.names = ['N', 'C', 'H'] + cpu_output1, cpu_output2 = self.cpu_op_mean_names_dim_exec(cpu_input, dim=dim) + cpu_output1 = cpu_output1.astype(np.float16) + cpu_output2 = cpu_output2.astype(np.float16) + npu_output1, npu_output2 = self.npu_op_mean_names_dim_exec(npu_input, dim=dim) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_mean_names_dim_shape_format_fp32(self, device): + shape = (1024, 8, 32) + dimlist = ['N', 'C', 'H'] + cpu_input = torch.rand(shape, dtype=torch.float32, names=('N', 'C', 'H')) + npu_input = cpu_input.npu() + dim = np.random.choice(dimlist) + dims = dimlist.index(dim) + cpu_output1, cpu_output2 = self.cpu_op_mean_names_dim_exec(cpu_input, dim=dim) + npu_output1, npu_output2 = self.npu_op_mean_names_dim_exec(npu_input, dim=dim) + self.assertRtolEqual(cpu_output1, npu_output1) + self.assertRtolEqual(cpu_output2, npu_output2) + + def test_var_dim_shape_format_5d_fp16(self, device): + format_list = [-1] + shape_list = [[2, 94, 4, 52, 192]] + dim_list = [0] + unbiased_list = [True, False] + keepdim_list = [True, False] + shape_format = [ + [np.float16, i, j, k, l, m] for i in format_list for j in shape_list + for k in dim_list for l in unbiased_list for m in keepdim_list + ] + for item in shape_format: + cpu_input1, npu_input1 = create_common_tensor(item, 0, 100) + cpu_input1 = cpu_input1.to(torch.float32) + cpu_output1 = self.cpu_op_dim_exec(cpu_input1, item[3], item[4], item[5]) + cpu_output1 = cpu_output1.astype(np.float16) + npu_output1 = self.npu_op_dim_exec(npu_input1, item[3], item[4], item[5]) + self.assertRtolEqual(cpu_output1, npu_output1, prec16=0.004) + +instantiate_device_type_tests(TestVar, globals(), except_for="cpu") +if __name__ == "__main__": + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/aten/ops/VarKernelNpu.cpp b/torch_npu/csrc/aten/ops/VarKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c3302aa7ae3e6bfc444d8cc31766482fb62a654 --- /dev/null +++ b/torch_npu/csrc/aten/ops/VarKernelNpu.cpp @@ -0,0 +1,225 @@ +// Copyright (c) 2020, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/utils/OpAdapter.h" +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" + +namespace at_npu { +namespace native { + +auto check_and_trans_dim(const at::Tensor& self, at::IntArrayRef dim) { + int64_t dim_size = self.dim(); + int64_t ne_dim_size = dim_size * -1; + std::vector result_dim; + for(int64_t i = 0; i < dim.size(); i++) { + if(dim[i] >= ne_dim_size && dim[i] <= (dim_size - 1)) { + int64_t tmp_dim = CalcuOpUtil::make_wrap_dim(dim[i], self.dim()); + result_dim.emplace_back(tmp_dim); + } else { + AT_ERROR("dim value should be in the range of [-n, n-1], n is the dimension number of input tensor."); + } + } + std::sort(result_dim.begin(), result_dim.end()); + return result_dim; +} + +auto get_result_names(const at::Tensor& self, at::IntArrayRef dim, bool keepdim){ + auto names = self.names(); + std::vector result_names; + for(int64_t i = 0; i < names.size(); i++){ + result_names.emplace_back(names[i]); + } + if(!keepdim){ + for(int64_t i = dim.size() - 1; i >= 0; i--){ + int64_t need_remove_dim = dim[i]; + result_names.erase(result_names.begin() + need_remove_dim); + } + } + return result_names; +} + +at::Tensor& var_after_npu_nocheckout( + at::Tensor& var, + const at::Tensor& self, + const at::Tensor& mean_broadcast, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + bool if_std = false; + OpCommand cmd; + cmd.Name("ReduceStdV2Update") + .Input(self) + .Input(mean_broadcast) + .Output(var) + .Attr("dim", dim) + .Attr("if_std", if_std) + .Attr("unbiased", unbiased) + .Attr("keepdim", keepdim) + .Run(); + return var; +} + +tuple var_mean_compute( + at::Tensor& variance, + at::Tensor& mean, + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto meanOutputSizeKeepDim = var_npu_output_size(self, dim, true); + auto meanOutputSizeNotKeepDim = var_npu_output_size(self, dim, false); + mean = at::mean(self, dim, false); + mean.resize_(meanOutputSizeKeepDim); + at::Tensor mean_broadcast = NPUNativeFunctions::npu_broadcast(mean, self.sizes()); + if(!keepdim){ + mean.resize_(meanOutputSizeNotKeepDim); + } + var_after_npu_nocheckout(variance, self, mean_broadcast, dim, unbiased, keepdim); + return tuple(variance, mean); +} + +tuple var_mean_out_npu( + at::Tensor& variance, + at::Tensor& mean, + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto dim_now = check_and_trans_dim(self, dim); + auto meanOutputSizeKeepDim = var_npu_output_size(self, dim_now, true); + auto meanOutputSizeNotKeepDim = var_npu_output_size(self, dim_now, false); + auto ori_type = self.scalar_type(); + if(ori_type != c10::ScalarType::Half && ori_type != c10::ScalarType::Float) { + AT_ERROR("Var Mean only support float16 or float32 type."); + } + if(variance.scalar_type() != mean.scalar_type() || variance.scalar_type() != ori_type) { + AT_ERROR("mean's type and variance' type must be equal to input's type."); + } + var_mean_compute( + variance, + mean, + self, + dim_now, + unbiased, + keepdim); + + return tuple(variance, mean); +} + +at::Tensor& NPUNativeFunctions::var_out( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim, + at::Tensor& var) { + // check and trans dim + auto dim_now = check_and_trans_dim(self, dim); + auto outputSize = var_npu_output_size(self, dim_now, keepdim); + + // construct the output mean tensor of the NPU + at::Tensor mean = OpPreparation::ApplyTensor(self, outputSize); + at::Tensor var_ = OpPreparation::ApplyTensor(self, outputSize); + + var_mean_out_npu(var_, mean, self, dim, unbiased, keepdim); + OpPreparation::CheckOut( + {var_}, + var, + var_); + var.copy_(var_); + return var; +} + +at::Tensor& NPUNativeFunctions::var_out( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim, + at::Tensor& var) { + return NPUNativeFunctions::var_out( + self, dimnames_to_positions(self, dim), unbiased, keepdim, var); +} + +at::Tensor NPUNativeFunctions::var(const at::Tensor& self, bool unbiased) { + bool keepdim = false; + c10::SmallVector dim = CalcuOpUtil::get_dimlist_for_tensor(self); + + return NPUNativeFunctions::var(self, dim, unbiased, keepdim); +} + +at::Tensor NPUNativeFunctions::var( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto dim_now = check_and_trans_dim(self, dim); + // calculate the output size + auto outputSize = var_npu_output_size(self, dim_now, keepdim); + + // construct the output tensor of the NPU + at::Tensor variance = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + NPUNativeFunctions::var_out(self, dim, unbiased, keepdim, variance); + + return variance; +} + +at::Tensor NPUNativeFunctions::var( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim) { + return NPUNativeFunctions::var(self, dimnames_to_positions(self, dim), unbiased, keepdim); +} + +at::Tensor _var_npu(const at::Tensor& self, bool unbiased) { + return at::var(self, unbiased); +} + +tuple NPUNativeFunctions::var_mean( + const at::Tensor& self, + at::DimnameList dim, + bool unbiased, + bool keepdim) { + return NPUNativeFunctions::var_mean(self, dimnames_to_positions(self, dim), unbiased, keepdim); +} + +tuple NPUNativeFunctions::var_mean( + const at::Tensor& self, + at::IntArrayRef dim, + bool unbiased, + bool keepdim) { + auto dim_now = check_and_trans_dim(self, dim); + // calculate the output size + auto outputSize = var_npu_output_size(self, dim_now, keepdim); + + // construct the output tensor of the NPU + at::Tensor variance = OpPreparation::ApplyTensor(self, outputSize); + + at::Tensor mean = OpPreparation::ApplyTensor(self, outputSize); + + // calculate the output result of the NPU + var_mean_out_npu(variance, mean, self, dim, unbiased, keepdim); + + return tuple(variance, mean); +} + +tuple NPUNativeFunctions::var_mean(const at::Tensor& self, bool unbiased) { + c10::SmallVector dim = CalcuOpUtil::get_dimlist_for_tensor(self); + + return NPUNativeFunctions::var_mean(self, dim, unbiased, false); +} +} // namespace native +} // namespace at_npu \ No newline at end of file