From b40c78fdc21dc7fa03055701bff728a8cca49494 Mon Sep 17 00:00:00 2001 From: zhoufan37 Date: Tue, 22 Mar 2022 17:23:35 +0800 Subject: [PATCH] modified take Operator --- torch_npu/csrc/aten/ops/TakeKernelNpu.cpp | 65 +++++++++++------------ 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp b/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp index f00ecc6d04..f088f621d8 100644 --- a/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp +++ b/torch_npu/csrc/aten/ops/TakeKernelNpu.cpp @@ -19,42 +19,40 @@ namespace at_npu { namespace native { -c10::SmallVector take_npu_input( - const c10::SmallVector& inputTensor) { - at::Tensor contiguousTensor; - c10::SmallVector inputs; - - for (int i = 0; i < inputTensor.size(); i++) { - if (i == 0) { - int64_t input_size = 1; - at::Tensor input_tensor = inputTensor[i].reshape(-1); - contiguousTensor = NpuUtils::format_contiguous(input_tensor); - } else { - contiguousTensor = NpuUtils::format_contiguous(inputTensor[i]); - } - inputs.emplace_back(NPUTensorDesc(contiguousTensor)); - } - return inputs; -} - -c10::SmallVector take_npu_output(const c10::SmallVector& outputTensor) { - return CalcuOpUtil::create_npu_output_tensor_desc(outputTensor); -} +at::Tensor& take_out_nocheck(const at::Tensor& self, const at::Tensor& index, at::Tensor& result) { + at::Tensor input_tensor = self.reshape(-1); + at::Tensor contiguousSelf = NpuUtils::format_contiguous(input_tensor); + at::Tensor contiguousIndex = NpuUtils::format_contiguous(index); -c10::SmallVector take_npu_attr(const at::Tensor& self) { - NPUAttrDesc npuAttrValidateIndices = NPUAttrDesc("validate_indices", false); - c10::SmallVector attrs = {npuAttrValidateIndices}; - return attrs; + OpCommand cmd; + cmd.Name("Gather") + .Input(contiguousSelf) + .Input(contiguousIndex) + .Output(result) + .Attr("validate_indices", false) + .Run(); + + return result; } at::Tensor& NPUNativeFunctions::take_out(const at::Tensor& self, const at::Tensor& index, at::Tensor& result) { - // constructs the input and output NPUTensorDesc - auto inputs = take_npu_input({self,index}); - auto outputs = take_npu_output({result}); - // constructs the attr of the NPUAttrDesc - auto attrs = take_npu_attr(self); - // executing the NPU operator - CalcuOpUtil::execute_npu_operate("Gather", inputs, outputs, attrs); + // calculate the output size + auto outputSize = input_same_output_size(index); + + OpPreparation::CheckOut( + {self, index}, + result, + self, + outputSize); + + if (!NpuUtils::check_match(&result)) { + at::Tensor contiguousResult = NpuUtils::format_contiguous(result); + take_out_nocheck(self, index, contiguousResult); + NpuUtils::format_fresh_view(result, contiguousResult); + } else { + take_out_nocheck(self, index, result); + } + return result; } @@ -67,7 +65,8 @@ at::Tensor NPUNativeFunctions::take(const at::Tensor& self, const at::Tensor& in outputSize, self.options()); - NPUNativeFunctions::take_out(self, index, result); + take_out_nocheck(self, index, result); + return result; } } // namespace native -- Gitee