From c4de50cde99e12643310960bcf6870983faa868c Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Thu, 14 Apr 2022 15:43:20 +0800 Subject: [PATCH] modified applytensor Operator --- torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp | 10 ++++++---- torch_npu/csrc/aten/ops/MultinomialKernelNpu.cpp | 4 ++-- torch_npu/csrc/aten/ops/RepeatKernelNpu.cpp | 3 ++- torch_npu/csrc/aten/ops/RollKernelNpu.cpp | 5 +++-- torch_npu/csrc/aten/ops/SinKernelNpu.cpp | 3 ++- .../csrc/aten/ops/SortWithoutIndicesKernelNpu.cpp | 4 ++-- .../aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp | 6 ++++-- torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp | 6 ++++-- torch_npu/csrc/aten/ops/_Unique2KernelNpu.cpp | 6 +++--- 9 files changed, 28 insertions(+), 19 deletions(-) diff --git a/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp b/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp index 7d7fda95ff..cdb30e30cf 100644 --- a/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/CtcLossKernelNpu.cpp @@ -62,13 +62,15 @@ std::tuple NPUNativeFunctions::_ctc_loss( auto outputSizes = ctc_loss_npu_output_size(logProbs, targetsCast, targetLengths, maxLength); // construct the output tensor of the NPU - at::Tensor negLogLikelihood = OpPreparation::ApplyTensorWithSizes( + at::Tensor negLogLikelihood = OpPreparation::ApplyTensorWithFormat( std::get<0>(outputSizes), - logProbsNeed.options()); + logProbsNeed.options(), + CalcuOpUtil::get_tensor_npu_format(logProbsNeed)); - at::Tensor logAlpha = OpPreparation::ApplyTensorWithSizes( + at::Tensor logAlpha = OpPreparation::ApplyTensorWithFormat( std::get<1>(outputSizes), - logProbsNeed.options()); + logProbsNeed.options(), + CalcuOpUtil::get_tensor_npu_format(logProbsNeed)); // calculate the output result of the NPU OpCommand cmd; diff --git a/torch_npu/csrc/aten/ops/MultinomialKernelNpu.cpp b/torch_npu/csrc/aten/ops/MultinomialKernelNpu.cpp index 1b4ae1e888..0e5bc0bd10 100644 --- a/torch_npu/csrc/aten/ops/MultinomialKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/MultinomialKernelNpu.cpp @@ -70,8 +70,8 @@ at::Tensor NPUNativeFunctions::multinomial( auto shape = array_to_small_vector(self.sizes()); shape[dim-1] = num_samples; - at::Tensor result = OpPreparation::ApplyTensorWithSizes( - shape, self.options().dtype(at::kLong)); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + shape, self.options().dtype(at::kLong), CalcuOpUtil::get_tensor_npu_format(self)); multinomial_out_npu_nocheck(result, self, num_samples, replacement, gen); return result; } diff --git a/torch_npu/csrc/aten/ops/RepeatKernelNpu.cpp b/torch_npu/csrc/aten/ops/RepeatKernelNpu.cpp index 4ff2c3237e..6518f34c75 100644 --- a/torch_npu/csrc/aten/ops/RepeatKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/RepeatKernelNpu.cpp @@ -50,7 +50,8 @@ at::Tensor NPUNativeFunctions::repeat(const at::Tensor& self, at::IntArrayRef re auto outputSize = repeat_npu_output_size(selfCp, repeats); // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, selfCp.options()); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, selfCp.options(), CalcuOpUtil::get_tensor_npu_format(selfCp)); // calculate the output result of the NPU repeat_out_npu_nocheck(result, selfCp, repeats); diff --git a/torch_npu/csrc/aten/ops/RollKernelNpu.cpp b/torch_npu/csrc/aten/ops/RollKernelNpu.cpp index 9dafaff230..d49e772de4 100644 --- a/torch_npu/csrc/aten/ops/RollKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/RollKernelNpu.cpp @@ -52,9 +52,10 @@ at::Tensor& roll_transpose( std::swap(perm[axis], perm[firstDim]); at::Tensor transposeSelf = NPUNativeFunctions::npu_transpose(self, perm); auto outputSize = transpose_npu_output_size(result, perm); - at::Tensor transposeResult = OpPreparation::ApplyTensorWithSizes( + at::Tensor transposeResult = OpPreparation::ApplyTensorWithFormat( outputSize, - self.options()); + self.options(), + CalcuOpUtil::get_tensor_npu_format(self)); c10::SmallVector dim = {firstDim}; c10::SmallVector shift_bak = {shifts[id]}; at::IntArrayRef dim_now = at::IntArrayRef(dim); diff --git a/torch_npu/csrc/aten/ops/SinKernelNpu.cpp b/torch_npu/csrc/aten/ops/SinKernelNpu.cpp index b9b6b9a87d..0b21d6f00e 100644 --- a/torch_npu/csrc/aten/ops/SinKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/SinKernelNpu.cpp @@ -46,7 +46,8 @@ at::Tensor& NPUNativeFunctions::sin_out(const at::Tensor& self, at::Tensor& resu at::Tensor NPUNativeFunctions::sin(const at::Tensor& self) { // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensorWithSizes(self.sizes(), self.options()); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + self.sizes(), self.options(), CalcuOpUtil::get_tensor_npu_format(self)); // calculate the output result of the NPU sin_out_npu_nocheck(result, self); diff --git a/torch_npu/csrc/aten/ops/SortWithoutIndicesKernelNpu.cpp b/torch_npu/csrc/aten/ops/SortWithoutIndicesKernelNpu.cpp index d3ca035042..46d52e6ad7 100644 --- a/torch_npu/csrc/aten/ops/SortWithoutIndicesKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/SortWithoutIndicesKernelNpu.cpp @@ -61,7 +61,7 @@ at::Tensor& NPUNativeFunctions::npu_sort_v2_out( at::Tensor transposeSelf = NPUNativeFunctions::npu_transpose(self, perm); auto outputSize = transpose_npu_output_size(result, perm); - at::Tensor transposeResult = OpPreparation::ApplyTensorWithSizes(outputSize, result.options()); + at::Tensor transposeResult = OpPreparation::ApplyTensor(result, outputSize); sort_without_indices_no_transpose(transposeResult, transposeSelf, lastDim, descending); NPUNativeFunctions::npu_transpose_out(transposeResult, perm, result); @@ -97,7 +97,7 @@ at::Tensor NPUNativeFunctions::npu_sort_v2( at::Tensor transposeSelf = NPUNativeFunctions::npu_transpose(self, perm); auto outputSize = transpose_npu_output_size(result, perm); - at::Tensor transposeResult = OpPreparation::ApplyTensorWithSizes(outputSize, result.options()); + at::Tensor transposeResult = OpPreparation::ApplyTensor(result, outputSize); sort_without_indices_no_transpose(transposeResult, transposeSelf, lastDim, descending); NPUNativeFunctions::npu_transpose_out(transposeResult, perm, result); diff --git a/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp index 89835b9bab..b8af96d8bb 100644 --- a/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/UpSampleBicubic2dBackwardKernelNpu.cpp @@ -115,7 +115,8 @@ at::Tensor NPUNativeFunctions::upsample_bicubic2d_backward( c10::optional scales_w) { // construct the output tensor of the NPU auto outputSize = upsample_bicubic2d_backward_npu_output_size(input_size); - at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, grad_output.options()); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, grad_output.options(), CalcuOpUtil::get_tensor_npu_format(grad_output)); // calculate the output result of the NPU return upsample_bicubic2d_backward_out_nocheck(grad_output, output_size, input_size, align_corners, scales_h, scales_w, result); } @@ -131,7 +132,8 @@ at::Tensor NPUNativeFunctions::upsample_bicubic2d_backward( auto scales_w = CalcuOpUtil::get_scale_value(scale_factors, 1); // construct the output tensor of the NPU auto outputSize = upsample_bicubic2d_backward_npu_output_size(input_size); - at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, grad_output.options()); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, grad_output.options(), CalcuOpUtil::get_tensor_npu_format(grad_output)); // calculate the output result of the NPU return upsample_bicubic2d_backward_out_nocheck(grad_output, osize, input_size, align_corners, scales_h, scales_w, result); } diff --git a/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp b/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp index b21790e4bf..9693d981fc 100644 --- a/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/UpsampleLinear1dKernelNpu.cpp @@ -124,7 +124,8 @@ at::Tensor NPUNativeFunctions::upsample_linear1d( self, output_size, align_corners, scales); // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options()); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, self.options(), CalcuOpUtil::get_tensor_npu_format(self)); // calculate the output result of the NPU upsample_linear1d_out_nocheck(self, output_size, align_corners, scales, result); @@ -144,7 +145,8 @@ at::Tensor NPUNativeFunctions::upsample_linear1d( self, osize, align_corners, scales_w); // construct the output tensor of the NPU - at::Tensor result = OpPreparation::ApplyTensorWithSizes(outputSize, self.options()); + at::Tensor result = OpPreparation::ApplyTensorWithFormat( + outputSize, self.options(), CalcuOpUtil::get_tensor_npu_format(self)); // calculate the output result of the NPU upsample_linear1d_out_nocheck(self, osize, align_corners, scales_w, result); diff --git a/torch_npu/csrc/aten/ops/_Unique2KernelNpu.cpp b/torch_npu/csrc/aten/ops/_Unique2KernelNpu.cpp index 44d25ff270..e8dedb6a52 100644 --- a/torch_npu/csrc/aten/ops/_Unique2KernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/_Unique2KernelNpu.cpp @@ -64,9 +64,9 @@ tuple NPUNativeFunctions::_unique2( } at::Tensor y = OpPreparation::ApplyTensor(selfCopy, std::get<0>(outputSizes)); - at::Tensor yOutputSize = OpPreparation::ApplyTensorWithSizes(std::get<1>(outputSizes), self.options().dtype(at::kLong)); - at::Tensor yInverse = OpPreparation::ApplyTensorWithSizes(std::get<2>(outputSizes), self.options().dtype(at::kLong)); - at::Tensor yCounts = OpPreparation::ApplyTensorWithSizes(std::get<0>(outputSizes), self.options().dtype(at::kLong)); + at::Tensor yOutputSize = OpPreparation::ApplyTensorWithFormat(std::get<1>(outputSizes), self.options().dtype(at::kLong), ACL_FORMAT_ND); + at::Tensor yInverse = OpPreparation::ApplyTensorWithFormat(std::get<2>(outputSizes), self.options().dtype(at::kLong), ACL_FORMAT_ND); + at::Tensor yCounts = OpPreparation::ApplyTensorWithFormat(std::get<0>(outputSizes), self.options().dtype(at::kLong), ACL_FORMAT_ND); _unique2_out_npu(y, yOutputSize, yInverse, yCounts, selfCopy, sorted, return_inverse, return_counts); -- Gitee