From ad8282798d82bfde1bd5dcc8668e2ae7013a7f89 Mon Sep 17 00:00:00 2001 From: linux Date: Thu, 28 Aug 2025 17:33:38 +0800 Subject: [PATCH] op base op register test aten operator --- inferrt/src/ir/graph.h | 14 +- inferrt/src/ir/tensor/storage.cc | 10 +- inferrt/src/ir/tensor/tensor.cc | 58 ++++-- inferrt/src/ir/tensor/tensor.h | 10 ++ inferrt/src/ir/value/value.cc | 50 ++++-- inferrt/src/ir/value/value.h | 5 +- inferrt/src/ops/CMakeLists.txt | 3 + inferrt/src/ops/cpu/aten/CMakeLists.txt | 4 +- inferrt/src/ops/cpu/aten/aten_matmul.cc | 54 ++++++ inferrt/src/ops/cpu/aten/aten_matmul.h | 39 ++++ inferrt/src/ops/cpu/aten/test_aten.cc | 61 +++++++ inferrt/src/ops/cpu/aten/test_aten.h | 52 ++++++ inferrt/src/ops/cpu/aten/utils/aten_convert.h | 56 ++++++ inferrt/src/ops/op_base/CMakeLists.txt | 5 + inferrt/src/ops/op_base/op_matmul.cc | 28 +++ inferrt/src/ops/op_base/op_matmul.h | 36 ++++ inferrt/src/ops/operator.h | 12 +- inferrt/src/ops/utils/op_constants.h | 105 +++++++++++ inferrt/src/ops/utils/op_register.h | 170 ++++++++++++++++++ 19 files changed, 732 insertions(+), 40 deletions(-) create mode 100644 inferrt/src/ops/cpu/aten/aten_matmul.cc create mode 100644 inferrt/src/ops/cpu/aten/aten_matmul.h create mode 100644 inferrt/src/ops/cpu/aten/test_aten.cc create mode 100644 inferrt/src/ops/cpu/aten/test_aten.h create mode 100644 inferrt/src/ops/cpu/aten/utils/aten_convert.h create mode 100644 inferrt/src/ops/op_base/CMakeLists.txt create mode 100644 inferrt/src/ops/op_base/op_matmul.cc create mode 100644 inferrt/src/ops/op_base/op_matmul.h create mode 100644 inferrt/src/ops/utils/op_constants.h create mode 100644 inferrt/src/ops/utils/op_register.h diff --git a/inferrt/src/ir/graph.h b/inferrt/src/ir/graph.h index b8174b26..a3262b26 100644 --- a/inferrt/src/ir/graph.h +++ b/inferrt/src/ir/graph.h @@ -49,8 +49,18 @@ struct Graph { using NodePtr = std::shared_ptr; using GraphPtr = std::shared_ptr; -inline std::ostream &operator<<(std::ostream &os, const NodePtr node) { - os << "Node(" << "op=" << ops::ToStr(node->op) << ", output=" << node->output << ")"; +inline std::ostream &operator<<(std::ostream &os, const Node &node) { + os << "Node(" + << "op=" << ops::ToStr(node.op) << ", value=" << node.output << ")"; + return os; +} + +inline std::ostream &operator<<(std::ostream &os, const NodePtr &node) { + if (node == nullptr) { + os << "Null"; + } else { + os << *node; + } return os; } diff --git a/inferrt/src/ir/tensor/storage.cc b/inferrt/src/ir/tensor/storage.cc index 13d01f93..f5183f8f 100644 --- a/inferrt/src/ir/tensor/storage.cc +++ b/inferrt/src/ir/tensor/storage.cc @@ -31,9 +31,13 @@ namespace ir { */ Storage::Storage(size_t sizeBytes, hardware::Device device) : sizeBytes_(sizeBytes), device_(device), ownsData_(true) { if (device.type == hardware::DeviceType::CPU) { - data_ = malloc(sizeBytes); - if (!data_) { - throw std::bad_alloc(); + if (sizeBytes == 0) { + data_ = nullptr; + } else { + data_ = malloc(sizeBytes); + if (!data_) { + throw std::bad_alloc(); + } } } else { // Handle other devices like GPU (e.g., cudaMalloc) diff --git a/inferrt/src/ir/tensor/tensor.cc b/inferrt/src/ir/tensor/tensor.cc index 1dbc501c..0f76ba68 100644 --- a/inferrt/src/ir/tensor/tensor.cc +++ b/inferrt/src/ir/tensor/tensor.cc @@ -18,6 +18,7 @@ #include #include +#include "common/common.h" #include "ir/tensor/tensor.h" namespace mrt { @@ -75,6 +76,16 @@ Tensor::Tensor(const std::vector &shape, DataType dtype, hardware::Devi storage_ = MakeIntrusive(sizeBytes, device); } +void Tensor::ResizeStorage() { + CHECK_IF_NULL(storage_); + size_t sizeBytes = 0; + if (!HasDynamicShape()) { + sizeBytes = numel_ * dtype_.GetSize(); + } + + storage_ = MakeIntrusive(sizeBytes, storage_->GetDevice()); +} + Tensor::Tensor(StoragePtr storage, DataType dtype, const std::vector &shape) : dtype_(dtype), shape_(shape), storage_(storage) { ComputeStrides(); @@ -107,26 +118,52 @@ void Tensor::SetShape(const std::vector &&shape) { numel_ = CalculateNumel(shape_, true); } +std::ostream &operator<<(std::ostream &os, Tensor *tensor) { + if (tensor == nullptr) { + os << "Null"; + } else { + os << *tensor; + } + return os; +} + std::ostream &operator<<(std::ostream &os, const Tensor *tensor) { + if (tensor == nullptr) { + os << "Null"; + } else { + os << *tensor; + } + return os; +} + +std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { + constexpr size_t numelLimit = 30; os << "Tensor(shape=["; - const auto &shape = tensor->Shape(); + const auto &shape = tensor.Shape(); for (size_t i = 0; i < shape.size(); ++i) { os << shape[i]; if (i < shape.size() - 1) { os << ", "; } } - os << "], dtype=" << tensor->Dtype().ToString(); + os << "], dtype=" << tensor.Dtype().ToString(); os << ", data=["; - if (tensor->DataPtr()) { - if (tensor->Dtype() == DataType::Float32) { // TODO: support other dtypes - const auto data = static_cast(tensor->DataPtr()); - const size_t numel = tensor->Numel(); - for (size_t i = 0; i < numel; ++i) { - os << data[i]; - if (i < numel - 1) { - os << ", "; + if (tensor.DataPtr()) { + if (tensor.Dtype() == DataType::Float32) { // TODO: support other dtypes + const auto data = static_cast(tensor.DataPtr()); + const size_t numel = tensor.Numel(); + if (numel <= numelLimit) { + for (size_t i = 0; i < numel; ++i) { + os << data[i]; + if (i < numel - 1) { + os << ", "; + } } + } else { + for (size_t i = 0; i < numelLimit; ++i) { + os << data[i] << ", "; + } + os << "..."; } } else { os << "..."; @@ -137,6 +174,5 @@ std::ostream &operator<<(std::ostream &os, const Tensor *tensor) { os << "])"; return os; } - } // namespace ir } // namespace mrt \ No newline at end of file diff --git a/inferrt/src/ir/tensor/tensor.h b/inferrt/src/ir/tensor/tensor.h index a9fec5a2..cf397856 100644 --- a/inferrt/src/ir/tensor/tensor.h +++ b/inferrt/src/ir/tensor/tensor.h @@ -17,6 +17,7 @@ #ifndef __IR_TENSOR_TENSOR_H__ #define __IR_TENSOR_TENSOR_H__ +#include #include #include @@ -125,6 +126,11 @@ class Tensor { * @return The storage. */ StoragePtr GetStorage() const { return storage_; } + /** + * @brief Construct a new `storage_` with the same device as the old. Note: Need set shape and dtype first, the device + * memory of the old `storage_` might be released. + */ + void ResizeStorage(); /** * @brief Gets a raw const pointer to the tensor's data. * This pointer takes into account the storage offset. @@ -169,6 +175,10 @@ class Tensor { int64_t storageOffset_ = 0; ///< The offset in the storage, in number of elements. }; +std::ostream &operator<<(std::ostream &os, Tensor *tensor); +std::ostream &operator<<(std::ostream &os, const Tensor *tensor); +std::ostream &operator<<(std::ostream &os, const Tensor &tensor); + } // namespace ir } // namespace mrt diff --git a/inferrt/src/ir/value/value.cc b/inferrt/src/ir/value/value.cc index efb5b1a5..2c9cc8a1 100644 --- a/inferrt/src/ir/value/value.cc +++ b/inferrt/src/ir/value/value.cc @@ -14,8 +14,9 @@ * limitations under the License. */ -#include "ir/value/value.h" #include +#include +#include "ir/value/value.h" namespace mrt { namespace ir { @@ -98,32 +99,60 @@ std::ostream &operator<<(std::ostream &os, const Tuple *tuple) { return os; } -std::ostream &operator<<(std::ostream &os, const ValuePtr &value) { - return operator<<(os, value.get()); +std::ostream &operator<<(std::ostream &os, const ValuePtr &value) { return operator<<(os, value.get()); } + +std::ostream &operator<<(std::ostream &os, Value *value) { + if (value == nullptr) { + os << "Null"; + } else { + os << *value; + } + return os; } std::ostream &operator<<(std::ostream &os, const Value *value) { - switch (value->tag_) { + if (value == nullptr) { + os << "Null"; + } else { + os << *value; + } + return os; +} + +std::ostream &operator<<(std::ostream &os, const std::vector &values) { + os << "std::vector{"; + for (size_t i = 0; i < values.size(); ++i) { + os << values[i]; + if (i < values.size() - 1) { + os << ", "; + } + } + os << "}"; + return os; +} + +std::ostream &operator<<(std::ostream &os, const Value &value) { + switch (value.tag_) { case Value::Tag::None: os << "None"; break; case Value::Tag::Tensor: - os << value->ToTensor(); + os << value.ToTensor(); break; case Value::Tag::Double: - os << value->ToDouble(); + os << value.ToDouble(); break; case Value::Tag::Int: - os << value->ToInt(); + os << value.ToInt(); break; case Value::Tag::Bool: - os << (value->ToBool() ? "true" : "false"); + os << (value.ToBool() ? "true" : "false"); break; case Value::Tag::String: - os << "\"" << value->ToString() << "\""; + os << "\"" << value.ToString() << "\""; break; case Value::Tag::Tuple: - os << value->ToTuple(); + os << value.ToTuple(); break; } return os; @@ -131,4 +160,3 @@ std::ostream &operator<<(std::ostream &os, const Value *value) { } // namespace ir } // namespace mrt - diff --git a/inferrt/src/ir/value/value.h b/inferrt/src/ir/value/value.h index 5ad8e31e..3d46b982 100644 --- a/inferrt/src/ir/value/value.h +++ b/inferrt/src/ir/value/value.h @@ -158,7 +158,7 @@ class Value : public RefCounted { * @param value The ValuePtr to output. * @return The output stream. */ - friend std::ostream &operator<<(std::ostream &os, const Value *value); + friend std::ostream &operator<<(std::ostream &os, const Value &value); private: /** @@ -178,7 +178,10 @@ class Value : public RefCounted { }; std::ostream &operator<<(std::ostream &os, const ValuePtr &value); +std::ostream &operator<<(std::ostream &os, const Value &value); +std::ostream &operator<<(std::ostream &os, Value *value); std::ostream &operator<<(std::ostream &os, const Value *value); +std::ostream &operator<<(std::ostream &os, const std::vector &values); } // namespace ir } // namespace mrt diff --git a/inferrt/src/ops/CMakeLists.txt b/inferrt/src/ops/CMakeLists.txt index 7bad5cef..872fe4f4 100644 --- a/inferrt/src/ops/CMakeLists.txt +++ b/inferrt/src/ops/CMakeLists.txt @@ -1,9 +1,12 @@ +check_debug_log_out() + add_subdirectory(op_def) add_library(kernel SHARED kernel_lib.cc) target_link_libraries(kernel PRIVATE mrt_ir_obj ${CMAKE_DL_LIBS}) add_subdirectory(dummy) +add_subdirectory(op_base) if(ENABLE_KERNEL_ATEN) add_subdirectory(cpu/aten) diff --git a/inferrt/src/ops/cpu/aten/CMakeLists.txt b/inferrt/src/ops/cpu/aten/CMakeLists.txt index dfbbc832..4588798f 100644 --- a/inferrt/src/ops/cpu/aten/CMakeLists.txt +++ b/inferrt/src/ops/cpu/aten/CMakeLists.txt @@ -1,5 +1,3 @@ -check_debug_log_out() - file(GLOB_RECURSE ATEN_KERNEL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") # Link against torch libraries. @@ -9,4 +7,4 @@ find_package(Torch REQUIRED) add_library(kernel_aten SHARED ${ATEN_KERNEL_SRC_FILES}) target_link_libraries(kernel_aten PRIVATE ${TORCH_LIBRARIES}) -target_link_libraries(kernel_aten PRIVATE kernel ops_obj mrt_ir_obj) +target_link_libraries(kernel_aten PRIVATE ops_obj ops_base_obj kernel) diff --git a/inferrt/src/ops/cpu/aten/aten_matmul.cc b/inferrt/src/ops/cpu/aten/aten_matmul.cc new file mode 100644 index 00000000..dae2902f --- /dev/null +++ b/inferrt/src/ops/cpu/aten/aten_matmul.cc @@ -0,0 +1,54 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "ops/cpu/aten/aten_matmul.h" +#include "ops/cpu/aten/utils/aten_convert.h" +#include "ops/utils/op_register.h" + +namespace mrt { +namespace ops { +OpsErrorCode AtenMatMul::InferShape(const std::vector &input, ir::Value *output) { + if (input.size() != kInputSize2) { + LOG_ERROR << "Expect input size is 2 for AtenMatMul, but got: " << input.size(); + return INVALID_INPUT_NUM; + } + auto &input0Shape = input[kIndex0]->ToTensor()->Shape(); + auto &input1Shape = input[kIndex1]->ToTensor()->Shape(); + std::vector outputShape = at::infer_size(input0Shape, input1Shape); + CHECK_IF_NULL(output); + auto outputTensor = output->ToTensor(); + CHECK_IF_NULL(outputTensor); + outputTensor->SetShape(outputShape); + auto outputDtype = input[kIndex0]->ToTensor()->Dtype(); + outputTensor->SetDtype(outputDtype); + outputTensor->ResizeStorage(); + return SUCCESS; +} + +OpsErrorCode AtenMatMul::Launch(const std::vector &input, const std::vector &workspace, + ir::Value *output, void *stream) { + auto atenInput0 = ToAtenTensor(input[kIndex0]); + auto atenInput1 = ToAtenTensor(input[kIndex1]); + auto atenOutput = ToAtenTensor(output); + at::matmul_out(atenOutput, atenInput0, atenInput1); + return SUCCESS; +} + +MRT_REG_OP(matmul, AtenMatMul, CPU); +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/cpu/aten/aten_matmul.h b/inferrt/src/ops/cpu/aten/aten_matmul.h new file mode 100644 index 00000000..6b97fb0b --- /dev/null +++ b/inferrt/src/ops/cpu/aten/aten_matmul.h @@ -0,0 +1,39 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPS_CPU_ATEN_ATEN_MATMUL_H__ +#define __OPS_CPU_ATEN_ATEN_MATMUL_H__ + +#include +#include + +#include "ops/op_base/op_matmul.h" + +namespace mrt { +namespace ops { +class AtenMatMul : public OpMatMul { + public: + AtenMatMul() = default; + ~AtenMatMul() override = default; + + OpsErrorCode InferShape(const std::vector &input, ir::Value *output) override; + OpsErrorCode Launch(const std::vector &input, const std::vector &workspace, + ir::Value *output, void *stream) override; +}; +} // namespace ops +} // namespace mrt + +#endif // __OPS_CPU_ATEN_ATEN_MATMUL_H__ diff --git a/inferrt/src/ops/cpu/aten/test_aten.cc b/inferrt/src/ops/cpu/aten/test_aten.cc new file mode 100644 index 00000000..7efcebeb --- /dev/null +++ b/inferrt/src/ops/cpu/aten/test_aten.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "ops/cpu/aten/test_aten.h" +// This file need to be deleted in the future. +namespace mrt { +namespace ops { +void TestAtenKernel::Init() { + CHECK_IF_NULL(node_); + input_.clear(); + for (auto &input : node_->inputs) { + CHECK_IF_NULL(input->output); + input_.emplace_back(input->output.get()); + } + CHECK_IF_NULL(node_->output); + output_ = node_->output.get(); +} + +void TestAtenKernel::InferShape() { + CHECK_IF_NULL(operator_); + node_->output = ir::MakeIntrusive( + ir::Tensor({-1}, ir::DataType::Type::Float32, hardware::Device(hardware::DeviceType::CPU, 0))); + Init(); + LOG_OUT << "Begin InferShape for operator [" << ToStr(node_->op) << "], input=" << input_ << ", output=" << output_; + if (operator_->InferShape(input_, output_) != SUCCESS) { + LOG_EXCEPTION << "Infer shape failed for operator " << ToStr(node_->op); + } +} + +void TestAtenKernel::Resize() { + // null +} + +void TestAtenKernel::Launch() { + CHECK_IF_NULL(operator_); + Init(); + LOG_OUT << "Begin Launch for operator [" << ToStr(node_->op) << "], input=" << input_ << ", output=" << output_; + if (operator_->Launch(input_, {}, output_, nullptr) != SUCCESS) { + LOG_EXCEPTION << "Launch operator " << ToStr(node_->op) << " failed"; + } +} + +DART_REGISTER_KERNEL_LIB("TestAten", TestAtenKernelLib); + +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/cpu/aten/test_aten.h b/inferrt/src/ops/cpu/aten/test_aten.h new file mode 100644 index 00000000..ba739b2e --- /dev/null +++ b/inferrt/src/ops/cpu/aten/test_aten.h @@ -0,0 +1,52 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPS_CPU_ATEN_TEST_ATEN_MATMUL_H__ +#define __OPS_CPU_ATEN_TEST_ATEN_MATMUL_H__ + +#include +#include + +#include "ops/op_def/ops_name.h" +#include "ops/operator.h" +#include "ops/kernel_lib.h" +#include "ops/utils/op_register.h" +// This file need to be deleted in the future. +namespace mrt { +namespace ops { +class TestAtenKernel : public DAKernel { + public: + explicit TestAtenKernel(ir::NodePtr node) : DAKernel(node) { operator_ = CreateOperator(ToStr(node->op)); } + void Init() override; + void InferShape() override; + void Resize() override; + void Launch() override; + + private: + std::unique_ptr operator_; + std::vector input_; + ir::Value *output_; +}; + +class DA_API TestAtenKernelLib : public KernelLib { + public: + TestAtenKernelLib() : KernelLib("TestAten") {} + ~TestAtenKernelLib() = default; + DAKernel *CreateKernel(ir::NodePtr node) const override { return new TestAtenKernel(node); } +}; +} // namespace ops +} // namespace mrt +#endif // __OPS_CPU_ATEN_TEST_ATEN_MATMUL_H__ diff --git a/inferrt/src/ops/cpu/aten/utils/aten_convert.h b/inferrt/src/ops/cpu/aten/utils/aten_convert.h new file mode 100644 index 00000000..96d69910 --- /dev/null +++ b/inferrt/src/ops/cpu/aten/utils/aten_convert.h @@ -0,0 +1,56 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPS_CPU_ATEN_UTILS_ATEN_CONVERTER_H__ +#define __OPS_CPU_ATEN_UTILS_ATEN_CONVERTER_H__ + +#include + +#include "common/logger.h" +#include "ir/common/dtype.h" +#include "ir/value/value.h" + +namespace mrt { +namespace ops { +inline at::ScalarType ToAtenDType(ir::DataType type) { + switch (type) { + case ir::DataType::Bool: + return at::kBool; + case ir::DataType::Float32: + return at::kFloat; + case ir::DataType::Float64: + return at::kDouble; + case ir::DataType::Int16: + return at::kShort; + case ir::DataType::Int32: + return at::kInt; + case ir::DataType::Int64: + return at::kLong; + default: + LOG_ERROR << "Unsupported DataType for Aten conversion."; + exit(EXIT_FAILURE); + } +} + +inline at::Tensor ToAtenTensor(const ir::Value *value) { + auto tensor = value->ToTensor(); + auto options = at::TensorOptions().dtype(ToAtenDType(tensor->Dtype())); + return at::from_blob(const_cast(tensor->DataPtr()), tensor->Shape(), options); +} + +} // namespace ops +} // namespace mrt +#endif // __OPS_CPU_ATEN_UTILS_ATEN_CONVERTER_H__ diff --git a/inferrt/src/ops/op_base/CMakeLists.txt b/inferrt/src/ops/op_base/CMakeLists.txt new file mode 100644 index 00000000..987a6db5 --- /dev/null +++ b/inferrt/src/ops/op_base/CMakeLists.txt @@ -0,0 +1,5 @@ +check_debug_log_out() + +file(GLOB_RECURSE OPS_BASE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +add_library(ops_base_obj STATIC ${OPS_BASE_SRC_FILES}) +target_link_libraries(ops_base_obj PRIVATE mrt_ir_obj) diff --git a/inferrt/src/ops/op_base/op_matmul.cc b/inferrt/src/ops/op_base/op_matmul.cc new file mode 100644 index 00000000..163063cd --- /dev/null +++ b/inferrt/src/ops/op_base/op_matmul.cc @@ -0,0 +1,28 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "ops/op_base/op_matmul.h" + +namespace mrt { +namespace ops { +OpsErrorCode OpMatMul::InferShape(const std::vector &input, ir::Value *output) { + // TODO + return SUCCESS; +} +} // namespace ops +} // namespace mrt diff --git a/inferrt/src/ops/op_base/op_matmul.h b/inferrt/src/ops/op_base/op_matmul.h new file mode 100644 index 00000000..a63f9597 --- /dev/null +++ b/inferrt/src/ops/op_base/op_matmul.h @@ -0,0 +1,36 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPS_OP_BASE_OP_MATMUL_H__ +#define __OPS_OP_BASE_OP_MATMUL_H__ + +#include +#include + +#include "ops/operator.h" + +namespace mrt { +namespace ops { +class OpMatMul : public Operator { + public: + OpMatMul() = default; + ~OpMatMul() override = default; + + OpsErrorCode InferShape(const std::vector &input, ir::Value *output) override; +}; +} // namespace ops +} // namespace mrt +#endif // __OPS_OP_BASE_OP_MATMUL_H__ diff --git a/inferrt/src/ops/operator.h b/inferrt/src/ops/operator.h index 710475c5..d797eb93 100644 --- a/inferrt/src/ops/operator.h +++ b/inferrt/src/ops/operator.h @@ -21,7 +21,7 @@ #include #include -#include "ops/op_def/ops_name.h" +#include "ops/utils/op_constants.h" #include "common/visible.h" #include "ir/graph.h" @@ -48,20 +48,17 @@ enum OpsErrorCode { SUCCESS = 0, INVALID_PARAM, INVALID_SHAPE, + INVALID_INPUT_NUM, INVALID_DEVICE_ADDR, UNKNOWN_ERROR = 1000 - }; -// Need to be deleted in the future. -using OpName = Op; - // @brief Abstract base class representing a computational kernel. A Operator encapsulates the core computation logic // for a specific operator. Derived classes must implement shape inference and launch operations. Kernels of different // device types share the InferShape function, but need to implement their respective Launch functions. class Operator { public: - Operator(const OpName &op) : op_(op) {} + Operator() = default; virtual ~Operator() = default; /** @@ -117,9 +114,6 @@ class Operator { * otherwise returns false. */ virtual bool NeedUpdateOutputShapeAfterLaunch() const { return false; } - - protected: - OpName op_{Op_End}; }; } // namespace ops } // namespace mrt diff --git a/inferrt/src/ops/utils/op_constants.h b/inferrt/src/ops/utils/op_constants.h new file mode 100644 index 00000000..19d4cd0f --- /dev/null +++ b/inferrt/src/ops/utils/op_constants.h @@ -0,0 +1,105 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPS_UTILS_OP_CONSTANTS_H__ +#define __OPS_UTILS_OP_CONSTANTS_H__ + +#include + +namespace mrt { +namespace ops { +// index of input or output +inline constexpr size_t kIndex0 = 0; +inline constexpr size_t kIndex1 = 1; +inline constexpr size_t kIndex2 = 2; +inline constexpr size_t kIndex3 = 3; +inline constexpr size_t kIndex4 = 4; +inline constexpr size_t kIndex5 = 5; +inline constexpr size_t kIndex6 = 6; +inline constexpr size_t kIndex7 = 7; +inline constexpr size_t kIndex8 = 8; +inline constexpr size_t kIndex9 = 9; +inline constexpr size_t kIndex10 = 10; +inline constexpr size_t kIndex11 = 11; +inline constexpr size_t kIndex12 = 12; +inline constexpr size_t kIndex13 = 13; +inline constexpr size_t kIndex14 = 14; +inline constexpr size_t kIndex15 = 15; +inline constexpr size_t kIndex16 = 16; +inline constexpr size_t kIndex17 = 17; +inline constexpr size_t kIndex18 = 18; +inline constexpr size_t kIndex19 = 19; +inline constexpr size_t kIndex20 = 20; +inline constexpr size_t kIndex21 = 21; +inline constexpr size_t kIndex22 = 22; +inline constexpr size_t kIndex23 = 23; +inline constexpr size_t kIndex24 = 24; +inline constexpr size_t kIndex25 = 25; +inline constexpr size_t kIndex26 = 26; +inline constexpr size_t kIndex27 = 27; +inline constexpr size_t kIndex28 = 28; +inline constexpr size_t kIndex29 = 29; +inline constexpr size_t kIndex30 = 30; +inline constexpr size_t kIndex31 = 31; +inline constexpr size_t kIndex32 = 32; +inline constexpr size_t kIndex33 = 33; +inline constexpr size_t kIndex34 = 34; +inline constexpr size_t kIndex35 = 35; +inline constexpr size_t kIndex36 = 36; +inline constexpr size_t kIndex37 = 37; + +// dim of shape +inline constexpr size_t kDim0 = 0; +inline constexpr size_t kDim1 = 1; +inline constexpr size_t kDim2 = 2; +inline constexpr size_t kDim3 = 3; +inline constexpr size_t kDim4 = 4; +inline constexpr size_t kDim5 = 5; +inline constexpr size_t kDim6 = 6; + +// output size of op +inline constexpr size_t kOutputSize1 = 1; +inline constexpr size_t kOutputSize2 = 2; +inline constexpr size_t kOutputSize3 = 3; +inline constexpr size_t kOutputSize4 = 4; +inline constexpr size_t kOutputSize5 = 5; +inline constexpr size_t kOutputSize6 = 6; + +// input size of op +inline constexpr size_t kInputSize0 = 0; +inline constexpr size_t kInputSize1 = 1; +inline constexpr size_t kInputSize2 = 2; +inline constexpr size_t kInputSize3 = 3; +inline constexpr size_t kInputSize4 = 4; +inline constexpr size_t kInputSize5 = 5; +inline constexpr size_t kInputSize6 = 6; +inline constexpr size_t kInputSize7 = 7; +inline constexpr size_t kInputSize8 = 8; +inline constexpr size_t kInputSize9 = 9; +inline constexpr size_t kInputSize10 = 10; +inline constexpr size_t kInputSize11 = 11; +inline constexpr size_t kInputSize12 = 12; +inline constexpr size_t kInputSize13 = 13; +inline constexpr size_t kInputSize14 = 14; +inline constexpr size_t kInputSize15 = 15; +inline constexpr size_t kInputSize16 = 16; +inline constexpr size_t kInputSize17 = 17; +inline constexpr size_t kInputSize18 = 18; +inline constexpr size_t kInputSize19 = 19; +inline constexpr size_t kInputSize20 = 20; +} // namespace ops +} // namespace mrt +#endif // diff --git a/inferrt/src/ops/utils/op_register.h b/inferrt/src/ops/utils/op_register.h new file mode 100644 index 00000000..53327139 --- /dev/null +++ b/inferrt/src/ops/utils/op_register.h @@ -0,0 +1,170 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __OPS_UTILS_OP_REGISTER_H__ +#define __OPS_UTILS_OP_REGISTER_H__ + +#include +#include +#include +#include +#include +#include +#include + +#include "common/logger.h" +#include "common/common.h" +#include "ops/operator.h" + +namespace mrt { +namespace ops { +inline constexpr std::string_view kUnknownOpFactory = "Unknown"; +inline constexpr std::string_view kAscendOpFactory = "Ascend"; +inline constexpr std::string_view kCPUOpFactory = "CPU"; + +struct UnknownOpFactory {}; +struct AscendOpFactory {}; +struct CPUOpFactory {}; + +template +struct OpFactoryTraits; + +template <> +struct OpFactoryTraits { + static constexpr std::string_view value = kUnknownOpFactory; +}; + +template <> +struct OpFactoryTraits { + static constexpr std::string_view value = kAscendOpFactory; +}; + +template <> +struct OpFactoryTraits { + static constexpr std::string_view value = kCPUOpFactory; +}; + +class OpFactoryBase { + using OpFactoryMapType = std::unordered_map>; + + public: + OpFactoryBase() = default; + virtual ~OpFactoryBase() = default; + + protected: + static OpFactoryBase *GetOpFactory(const std::string_view &name) { + auto iter = OpFactoryMap().find(name); + if (iter == OpFactoryMap().end()) { + return nullptr; + } + return iter->second.get(); + } + + static OpFactoryBase *CreateOpFactory(const std::string_view &name, std::unique_ptr &&factory) { + if (OpFactoryMap().find(name) != OpFactoryMap().end()) { + LOG_EXCEPTION << name << " already has an OpFactory, please check!"; + } + (void)OpFactoryMap().emplace(name, std::move(factory)); + return GetOpFactory(name); + } + + private: + static OpFactoryMapType &OpFactoryMap() { + static OpFactoryMapType factoryMap; + return factoryMap; + } +}; + +template +class OpFactory : public OpFactoryBase { + using CreatorFunc = std::function()>; + + public: + OpFactory() = default; + ~OpFactory() = default; + OpFactory(const OpFactory &) = delete; + void operator=(const OpFactory &) = delete; + + static OpFactory &GetInstance() { + auto factoryBase = OpFactoryBase::GetOpFactory(OpFactoryTraits::value); + if (factoryBase == nullptr) { + factoryBase = OpFactoryBase::CreateOpFactory(OpFactoryTraits::value, + std::make_unique>()); + } + return *static_cast *>(factoryBase); + } + + void Register(const std::string &opName, CreatorFunc &&creator) { + if (IsRegistered(opName)) { + LOG_EXCEPTION << "Repeat register for op " << opName; + } + (void)opCreatorsMap_.emplace(opName, std::move(creator)); + } + + void UnRegister(const std::string &opName) { + auto iter = opCreatorsMap_.find(opName); + if (iter != opCreatorsMap_.end()) { + opCreatorsMap_.erase(iter); + } + } + + bool IsRegistered(const std::string &opName) const { return opCreatorsMap_.find(opName) != opCreatorsMap_.end(); } + + std::unique_ptr Create(const std::string &opName) const { + typename std::unordered_map::const_iterator iter = opCreatorsMap_.find(opName); + if (iter != opCreatorsMap_.cend()) { + return (iter->second)(); + } + return nullptr; + } + + private: + std::unordered_map opCreatorsMap_; +}; + +template +class OpRegistrar { + public: + explicit OpRegistrar(const std::string &opName, std::function()> creator) { + OpFactory::GetInstance().Register(opName, std::move(creator)); + } + ~OpRegistrar() = default; +}; + +#define MRT_REG_OP(OP_NAME, OP_CLASS, DEVICE_NAME) \ + static_assert(std::is_base_of::value, #OP_CLASS " must be derived from class Operator"); \ + static const ops::OpRegistrar \ + g_##OP_NAME##_##OP_CLASS##_##DEVICE_NAME##_reg(#OP_NAME, []() { return std::make_unique(); }) + +#define MRT_REG_OP_WITH_CREATOR(OP_NAME, OP_CLASS, DEVICE_NAME, CREATOR) \ + static_assert(std::is_base_of::value, #OP_CLASS " must be derived from class Operator"); \ + static const ops::OpRegistrar \ + g_##OP_NAME##_##OP_CLASS##_##DEVICE_NAME##_reg(#OP_NAME, CREATOR) + +inline std::unique_ptr CreateOperator(const std::string &name) { + auto op = OpFactory::GetInstance().Create(name); + if (op == nullptr) { + op = OpFactory::GetInstance().Create(name); + } + if (op == nullptr) { + LOG_EXCEPTION << "Failed to create operator [" << name << "], maybe it has not been registered"; + } + return op; +} + +} // namespace ops +} // namespace mrt +#endif // __OPS_UTILS_OP_REGISTER_H__ -- Gitee