diff --git a/test/test_network_ops/test_unique_consecutive.py b/test/test_network_ops/test_unique_consecutive.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f4e04077926972c5e3a9102d8b844d86d7b3ab
--- /dev/null
+++ b/test/test_network_ops/test_unique_consecutive.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import numpy as np
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+
+class TestUniqueConsecutive(TestCase):
+    def test_unique_consecutive(self, device="npu"):
+        shape_format = [
+            [[torch.int32, (2, 3)], 0],
+            [[torch.long, (2, 3)], 1],
+            [[torch.float32, (2, 3)], 0],
+            [[torch.float16, (2, 3)], 1],
+            [[torch.int32, (2, 3)], None],
+            [[torch.long, (2, 3)], None],
+            [[torch.float32, (2, 3)], None],
+            [[torch.float16, (2, 3)], None]
+        ]
+
+        for item in shape_format:
+            cpu_input = torch.rand(item[0][1]).random_(0, 3).to(item[0][0])
+            npu_input = cpu_input.npu()
+            if item[0][0] == torch.float16:
+                cpu_input = cpu_input.float()
+
+            cpu_output, cpu_idx, cpu_counts = torch.unique_consecutive(cpu_input, return_inverse=True,
+                                                                       return_counts=True, dim=item[1])
+            npu_output, npu_idx, npu_counts = torch.unique_consecutive(npu_input, return_inverse=True,
+                                                                       return_counts=True, dim=item[1])
+
+            if item[0][0] == torch.float16:
+                cpu_output = cpu_output.half()
+            self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy())
+            self.assertRtolEqual(cpu_idx.numpy(), npu_idx.cpu().numpy())
+            self.assertRtolEqual(cpu_counts.numpy(), npu_counts.cpu().numpy())
+
+    def test_unique_consecutive_case_in_dino(self, device="npu"):
+        input_list = [
+            torch.tensor([224, 224, 96, 96, 96, 96, 96, 96, 96, 96]),
+            torch.tensor([224, 224])
+        ]
+        for i in input_list:
+            cpu_output, cpu_counts = torch.unique_consecutive(i, return_counts=True)
+            npu_output, npu_counts = torch.unique_consecutive(i.npu(), return_counts=True)
+            self.assertRtolEqual(cpu_output.numpy(), npu_output.cpu().numpy())
+            self.assertRtolEqual(cpu_counts.numpy(), npu_counts.cpu().numpy())
+
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/third_party/acl/inc/acl/acl_op_compiler.h b/third_party/acl/inc/acl/acl_op_compiler.h
index 0807efbd8ed85ed77531ee78904e6ba9515f4103..338ec40e4c8bf0c71cb49b018eb1e95bf8e0fc82 100644
--- a/third_party/acl/inc/acl/acl_op_compiler.h
+++ b/third_party/acl/inc/acl/acl_op_compiler.h
@@ -93,6 +93,12 @@ ACL_FUNC_VISIBILITY aclError aclopCompileAndExecute(const char *opType,
     const aclopAttr *attr, aclopEngineType engineType, aclopCompileType compileFlag,
     const char *opPath, aclrtStream stream);
 
+ACL_FUNC_VISIBILITY aclError aclopCompileAndExecuteV2(const char *opType,
+    int numInputs, aclTensorDesc *inputDesc[], aclDataBuffer *inputs[],
+    int numOutputs, aclTensorDesc *outputDesc[], aclDataBuffer *outputs[],
+    aclopAttr *attr, aclopEngineType engineType, aclopCompileType compileFlag,
+    const char *opPath, aclrtStream stream);
+
 /**
  * @ingroup AscendCL
  * @brief set compile option
diff --git a/third_party/acl/libs/acl_op_compiler.cpp b/third_party/acl/libs/acl_op_compiler.cpp
index ea5feee96822a19c6482e48e9a57c5a0d05a2385..aaa4886cd7194625e27fb10f96f694e6e8f24516 100644
--- a/third_party/acl/libs/acl_op_compiler.cpp
+++ b/third_party/acl/libs/acl_op_compiler.cpp
@@ -43,6 +43,22 @@ aclError aclopCompileAndExecute(
     return 0;
 }
 
+aclError aclopCompileAndExecuteV2(
+    const char* opType,
+    int numInputs,
+    aclTensorDesc *inputDesc[],
+    aclDataBuffer *inputs[],
+    int numOutputs,
+    aclTensorDesc *outputDesc[],
+    aclDataBuffer *outputs[],
+    aclopAttr* attr,
+    aclopEngineType engineType,
+    aclopCompileType compileFlag,
+    const char* opPath,
+    aclrtStream stream) {
+  return 0;
+}
+
 void GeGeneratorFinalize() {
     return;
 }
diff --git a/torch_npu/csrc/aten/ops/UniqueConsecutiveKernelNpu.cpp b/torch_npu/csrc/aten/ops/UniqueConsecutiveKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..de1eba533a5d441b983de4f6eb61409db2057c71
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/UniqueConsecutiveKernelNpu.cpp
@@ -0,0 +1,74 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor&, at::Tensor&, at::Tensor&> unique_consecutive_out_npu(
+    at::Tensor& output,
+    at::Tensor& inverse_indices,
+    at::Tensor& counts,
+    const at::Tensor& self,
+    const bool return_inverse,
+    const bool return_counts,
+    c10::optional<int64_t> dim) {
+  at::Tensor selfCopy = self;
+  if (self.scalar_type() == at::ScalarType::Half) {
+    selfCopy = NPUNativeFunctions::npu_dtype_cast(self, at::ScalarType::Float);
+    output = NPUNativeFunctions::npu_dtype_cast(output, at::ScalarType::Float);
+  }
+  c10::SmallVector<int64_t, N_SIZE> output_sync_idx = {0, 2};
+  OpCommand cmd;
+  cmd.Sync(output_sync_idx)
+      .Name("UniqueConsecutive")
+      .Input(selfCopy)
+      .Output(output)
+      .Output(inverse_indices)
+      .Output(counts)
+      .Attr("return_idx", return_inverse)
+      .Attr("return_counts", return_counts);
+  if (dim.has_value()) {
+    cmd.Attr("axis", dim.value());
+  }
+  cmd.Run();
+  if (self.scalar_type() == at::ScalarType::Half) {
+    output = NPUNativeFunctions::npu_dtype_cast(output, at::ScalarType::Half);
+  }
+  return std::tie(output, inverse_indices, counts);
+}
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor> NPUNativeFunctions::unique_consecutive(
+    const at::Tensor& self,
+    const bool return_inverse,
+    const bool return_counts,
+    c10::optional<int64_t> dim) {
+  at::Tensor output = (dim.has_value()) ?
+      OpPreparation::ApplyTensor(self) : OpPreparation::ApplyTensor(self, {self.numel()});
+  at::Tensor inverse_indices = (dim.has_value()) ?
+      OpPreparation::ApplyTensorWithFormat(self.size(dim.value()), self.options().dtype(at::kLong), ACL_FORMAT_ND) :
+      OpPreparation::ApplyTensorWithFormat(self.sizes(), self.options().dtype(at::kLong), ACL_FORMAT_ND);
+  at::Tensor counts = (dim.has_value()) ?
+      OpPreparation::ApplyTensorWithFormat(self.size(dim.value()), self.options().dtype(at::kLong), ACL_FORMAT_ND) :
+      OpPreparation::ApplyTensorWithFormat({self.numel()}, self.options().dtype(at::kLong), ACL_FORMAT_ND);
+  unique_consecutive_out_npu(output, inverse_indices, counts, self, return_inverse, return_counts, dim);
+  return std::tie(output, inverse_indices, counts);
+}
+
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/framework/OpCommand.cpp b/torch_npu/csrc/framework/OpCommand.cpp
index 9077b7465ed70ac258219c74523efe3dcab54fea..f818a73c6be69c42754fab7b13453eb172f97498 100644
--- a/torch_npu/csrc/framework/OpCommand.cpp
+++ b/torch_npu/csrc/framework/OpCommand.cpp
@@ -12,6 +12,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include <Python.h>
 #include "torch_npu/csrc/framework/OpCommand.h"
 #include "torch_npu/csrc/core/npu/register/OptionsManager.h"
@@ -127,26 +128,35 @@ OpCommand& OpCommand::Output(
       if (!resultTypeDefined && commonType.has_value() &&
           output.scalar_type() != commonType.value()) {
         output = NPUNativeFunctions::npu_dtype_cast(output, commonType.value());
-      }
+      }
   )
+  outputTensor.emplace_back(output);
   return AddOutput(output, realType);
 }
 
 void OpCommand::Run() {
   IF_GRAPH_MODE_THEN_RUN(return;)
-  if (c10_npu::option::OptionsManager::CheckQueueEnable()) {
+  if (c10_npu::option::OptionsManager::CheckQueueEnable() && !sync) {
     ExecuteParas execParams;
     aclCmd->ExportParams(execParams);
     c10_npu::queue::QueueParas params(c10_npu::queue::COMPILE_AND_EXECUTE, sizeof(ExecuteParas), &execParams);
     c10_npu::enCurrentNPUStream(&params);
     aclCmd->releaseSource(false);
   } else {
-    aclCmd->Run();
+    aclCmd->Run(sync, sync_index, outputTensor);
     aclCmd->releaseSource();
-  }
+  }
   aclCmds->Pop();
 }
 
+OpCommand& OpCommand::Sync(c10::SmallVector<int64_t, N_SIZE> &index) {
+  sync_index = index;
+  if (!index.empty()) {
+    sync = true;
+  }
+  return *this;
+}
+
 OpCommand& OpCommand::AddTensorInput(at::Tensor &tensor,
                                      at::ScalarType forceScaleType,
                                      const string &descName,
diff --git a/torch_npu/csrc/framework/OpCommand.h b/torch_npu/csrc/framework/OpCommand.h
index 2ec73f62cb6fc7fa8d8b3ee30b4a5c05b4cf55d2..a7e6eaa76bf4444ec6671dafb1cc38033fabab9b 100644
--- a/torch_npu/csrc/framework/OpCommand.h
+++ b/torch_npu/csrc/framework/OpCommand.h
@@ -117,6 +117,8 @@ public:
   // Run a single op
   void Run();
+  OpCommand& Sync(c10::SmallVector<int64_t, N_SIZE> &sync_index);
+
 private:
   OpCommand& AddTensorInput(at::Tensor &tensor,
                             at::ScalarType forceScaleType = at::ScalarType::Undefined,
                             const string &descName,
@@ -150,6 +152,10 @@ private:
   c10::optional<at::ScalarType> commonType = c10::nullopt;
   c10::optional<c10::IntArrayRef> commonShape = c10::nullopt;
   bool resultTypeDefined = false;
+  bool sync = false;
+  c10::SmallVector<int64_t, N_SIZE> sync_index;
+  c10::SmallVector<at::Tensor, N_SIZE> outputTensor;
+
 }; // class OpCommand
 } // namespace native
 } // namespace at_npu
diff --git a/torch_npu/csrc/framework/OpParamMaker.cpp b/torch_npu/csrc/framework/OpParamMaker.cpp
index 5d20035d48d4d26e45261211ec6cbc72e46e2939..710b6d0508f8d344c336d8e1215695e67f2ccba4 100644
--- a/torch_npu/csrc/framework/OpParamMaker.cpp
+++ b/torch_npu/csrc/framework/OpParamMaker.cpp
@@ -107,22 +107,28 @@ namespace at_npu
         attrValue.data());
   }
 
-  void OpCommandImpl::Run()
-  {
+  void OpCommandImpl::Run(
+      bool sync,
+      c10::SmallVector<int64_t, N_SIZE> &sync_index,
+      c10::SmallVector<at::Tensor, N_SIZE> &outputTensor) {
     NPU_LOGD("Op %s Run.", opName.c_str());
     RECORD_FUNCTION(opName, std::vector<c10::IValue>({}));
     if (PyGILState_Check()) {
       // we need to release GIL for NPU to compile op.
       Py_BEGIN_ALLOW_THREADS
-      ACL_REQUIRE_OK_OP(InnerRun(opName, execParam), opName.c_str());
+      ACL_REQUIRE_OK_OP(InnerRun(opName, execParam, sync, sync_index, outputTensor), opName.c_str());
       Py_END_ALLOW_THREADS
     } else {
-      ACL_REQUIRE_OK_OP(InnerRun(opName, execParam), opName.c_str());
+      ACL_REQUIRE_OK_OP(InnerRun(opName, execParam, sync, sync_index, outputTensor), opName.c_str());
     }
   }
 
-  aclError OpCommandImpl::InnerRun(string name, AclExecParam &params)
-  {
+  aclError OpCommandImpl::InnerRun(
+      string name,
+      AclExecParam &params,
+      bool sync,
+      c10::SmallVector<int64_t, N_SIZE> &sync_index,
+      c10::SmallVector<at::Tensor, N_SIZE> &outputTensor) {
     auto stream = c10_npu::getCurrentNPUStream();
     auto inputSize = params.inBuffer.size();
     auto outputSize = params.outBuffer.size();
@@ -155,19 +161,45 @@ namespace at_npu
           TORCH_CHECK(false, "In aoe mode, AclGenGraphAndDumpForOp failed!");
         }
       }
-      ret = aclopCompileAndExecute(
-        name.c_str(),
-        inputSize,
-        params.inDesc.data(),
-        params.inBuffer.data(),
-        outputSize,
-        params.outDesc.data(),
-        params.outBuffer.data(),
-        params.attr,
-        ACL_ENGINE_SYS,
-        ACL_COMPILE_SYS,
-        NULL,
-        stream);
+      if (!sync) {
+        ret = aclopCompileAndExecute(
+          name.c_str(),
+          inputSize,
+          params.inDesc.data(),
+          params.inBuffer.data(),
+          outputSize,
+          params.outDesc.data(),
+          params.outBuffer.data(),
+          params.attr,
+          ACL_ENGINE_SYS,
+          ACL_COMPILE_SYS,
+          NULL,
+          stream);
+      } else {
+        int64_t dimSize;
+        ret = aclopCompileAndExecuteV2(
+          name.c_str(),
+          inputSize,
+          const_cast<aclTensorDesc **>(params.inDesc.data()),
+          const_cast<aclDataBuffer **>(params.inBuffer.data()),
+          outputSize,
+          const_cast<aclTensorDesc **>(params.outDesc.data()),
+          params.outBuffer.data(),
+          params.attr,
+          ACL_ENGINE_SYS,
+          ACL_COMPILE_SYS,
+          NULL,
+          stream);
+
+        for (size_t i = 0; i < sync_index.size(); i++) {
+          c10::SmallVector<int64_t, N_SIZE> real_shape;
+          for (int64_t j = 0; j < outputTensor[sync_index[i]].dim(); j++) {
+            C10_NPU_CHECK(aclGetTensorDescDimV2(params.outDesc[sync_index[i]], j, &dimSize));
+            real_shape.emplace_back(dimSize);
+          }
+          outputTensor[sync_index[i]].resize_(real_shape);
+        }
+      }
       ++index;
     } while (NpuUtils::IsOomError(ret, index) && (index < NPU_MAX_OP_EXEC_TRY_NUM));
     if (reset_flag)
diff --git a/torch_npu/csrc/framework/OpParamMaker.h b/torch_npu/csrc/framework/OpParamMaker.h
index d1b02a0dea19c20a27abbea53681926494553ace..58b12eba0a9eda15c99136281830b8614aef1920 100644
--- a/torch_npu/csrc/framework/OpParamMaker.h
+++ b/torch_npu/csrc/framework/OpParamMaker.h
@@ -303,7 +303,7 @@ namespace at_npu
       }
     }
 
-    void Run();
+    void Run(bool sync, c10::SmallVector<int64_t, N_SIZE> &sync_index, c10::SmallVector<at::Tensor, N_SIZE> &outputTensor);
 
     void releaseSource(bool no_blocking = true)
     {
@@ -364,7 +364,13 @@ namespace at_npu
       }
     }
 
-    aclError InnerRun(string name, AclExecParam &params);
+    aclError InnerRun(
+        string name,
+        AclExecParam &params,
+        bool sync,
+        c10::SmallVector<int64_t, N_SIZE> &sync_index,
+        c10::SmallVector<at::Tensor, N_SIZE> &outputTensor
+    );
 
   private:
     string opName;