diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f7590adfc1119ebf4e07541c0183814eab599ff..09b2be67f4967a72b811b12b7d9a1031801fc67c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,14 +26,16 @@ function(check_debug_log_out) endif() endfunction() -# Fix undefined symbol error caused by PyTorch abi flag -execute_process(COMMAND python -c "import torch; print(torch.compiled_with_cxx11_abi())" OUTPUT_VARIABLE PYTORCH_CXX11_ABI_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE) -if("${PYTORCH_CXX11_ABI_VERSION}" STREQUAL "True") - message("-- Enable _GLIBCXX_USE_CXX11_ABI") - add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) -else() - message("-- Disable _GLIBCXX_USE_CXX11_ABI") - add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +if(ENABLE_KERNEL_ATEN) + execute_process(COMMAND python -c "import torch; print(torch.compiled_with_cxx11_abi())" OUTPUT_VARIABLE PYTORCH_CXX11_ABI_VERSION OUTPUT_STRIP_TRAILING_WHITESPACE) + + if("${PYTORCH_CXX11_ABI_VERSION}" STREQUAL "True") + message("-- Enable _GLIBCXX_USE_CXX11_ABI") + add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) + else() + message("-- Disable _GLIBCXX_USE_CXX11_ABI") + add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) + endif() endif() add_subdirectory(${PROJECT_SOURCE_DIR}/inferrt/src) \ No newline at end of file diff --git a/build.sh b/build.sh index 2f8dba99f9f41a30a5f7fac13fdb38030de77815..11b7e967b59e4220f4fb6d55c6481ca3afb86562 100755 --- a/build.sh +++ b/build.sh @@ -122,7 +122,7 @@ echo "==============================" echo "Run python test case:" echo "python check_api.py" echo "==============================" -export PYTHONPATH=$BUILD_DIR/inferrt/src/pybind:$INFERRT_PATH/inferrt/python +export PYTHONPATH=$BUILD_DIR/inferrt/src/pybind/mrt:$INFERRT_PATH/inferrt/python echo "PYTHONPATH=$PYTHONPATH" python $INFERRT_PATH/inferrt/python/check_api.py @@ -132,7 +132,7 @@ if [[ $TEST_TORCH == 1 ]]; then echo "Run pytorch backend test case:" echo "python check_backend.py" echo "==============================" - export PYTHONPATH=$BUILD_DIR/inferrt/src/pybind:$INFERRT_PATH/inferrt/python + export PYTHONPATH=$BUILD_DIR/inferrt/src/pybind/mrt:$BUILD_DIR/inferrt/src/pybind/mrt_torch:$INFERRT_PATH/inferrt/python export DART_KERNEL_LIB_PATH=$BUILD_DIR/inferrt/src/ops/cpu/aten/libkernel_aten.so export DART_KERNEL_LIB_NAME=Aten python $INFERRT_PATH/inferrt/python/check_backend.py diff --git a/inferrt/python/check_api.py b/inferrt/python/check_api.py index 4d1a6ed6fa4e58531fc392299b98a676c0f6f07a..62af17003347fb71d2e0f5e09d37190deaacd3d8 100644 --- a/inferrt/python/check_api.py +++ b/inferrt/python/check_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dapy import jit +from mrt import jit import argparse _arg_parser = argparse.ArgumentParser() @@ -29,4 +29,4 @@ def run_check(x, y): return z assert run_check(12, 6) == 13 -print("The result is correct. 'dapy' module has been installed successfully.") \ No newline at end of file +print("The result is correct. 'mrt' module has been installed successfully.") \ No newline at end of file diff --git a/inferrt/python/check_backend.py b/inferrt/python/check_backend.py index 53ccd6d5e421e13bb8209fff04987b421b349fd4..90bda090759a28abd551486a294d5dd6c6c0f5d0 100644 --- a/inferrt/python/check_backend.py +++ b/inferrt/python/check_backend.py @@ -1,5 +1,5 @@ import torch -from dapy import backend +from mrt.torch import backend def foo(x, y): @@ -16,4 +16,4 @@ bar = foo(x, y) opt_bar = opt_foo(x, y) assert torch.equal(opt_bar, bar), f"\nopt_bar={opt_bar}\nbar={bar}" -print("The result is correct. 'dapy' backend has been installed successfully.") +print("The result is correct. 'mrt' backend has been installed successfully.") diff --git a/inferrt/python/dapy/__init__.py b/inferrt/python/mrt/__init__.py similarity index 85% rename from inferrt/python/dapy/__init__.py rename to inferrt/python/mrt/__init__.py index 6c789ff2d06dc0cd8ff24dfb9e53c2bf4e617391..e1c7af4fd612775b917ece8fd805e3e275b179e4 100644 --- a/inferrt/python/dapy/__init__.py +++ b/inferrt/python/mrt/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dapy.api import dag, jit -from dapy.fx_backend import backend +from mrt.api import dag, jit -__all__ = ['dag', 'jit', 'backend'] \ No newline at end of file +__all__ = ['dag', 'jit'] diff --git a/inferrt/python/dapy/api.py b/inferrt/python/mrt/api.py similarity index 98% rename from inferrt/python/dapy/api.py rename to inferrt/python/mrt/api.py index e780558c18e0188dccf3daee883515cd36a238c8..913b72d637954fe73ab016793d5a2b8145850172 100644 --- a/inferrt/python/dapy/api.py +++ b/inferrt/python/mrt/api.py @@ -15,7 +15,7 @@ import inspect import types from functools import wraps -from _dapy import DALangPy_ +from _mrt_api import DALangPy_ def _get_source(func): diff --git a/inferrt/python/dapy/ir.py b/inferrt/python/mrt/ir.py similarity index 63% rename from inferrt/python/dapy/ir.py rename to inferrt/python/mrt/ir.py index 5270e3f95aaeb9ec89062c8cadafe9cd087f440e..83aceaebbd6488ff8feec2d835176cd18fe669e2 100644 --- a/inferrt/python/dapy/ir.py +++ b/inferrt/python/mrt/ir.py @@ -1,5 +1,35 @@ from typing import List, Any -from _dairpy import GraphExecutor as _GraphExecutor, Node, Op, to_python, from_python +from _mrt_ir import GraphExecutor as _GraphExecutor, Node, Op, Tensor, Value, Tuple + + +def _from_python(obj: Any) -> Value: + if isinstance(obj, Value): + return obj + if isinstance(obj, (list, tuple)): + return Value(Tuple([_from_python(e) for e in obj])) + if isinstance(obj, (int, float, bool, str, Tensor)): + return Value(obj) + if obj is None: + return Value() + raise TypeError(f"Unsupported python type for conversion to mrt.ir.Value: {type(obj)}") + + +def _to_python(value: Value) -> Any: + if value.is_none(): + return None + if value.is_tensor(): + return value.to_tensor() + if value.is_tuple(): + return tuple(_to_python(item) for item in value.to_tuple()) + if value.is_int(): + return value.to_int() + if value.is_double(): + return value.to_double() + if value.is_bool(): + return value.to_bool() + if value.is_string(): + return value.to_string() + raise TypeError(f"Unsupported ir.Value for conversion to python object: {value}") class GraphExecutor: @@ -41,7 +71,7 @@ class GraphExecutor: def add_value_node(self, value: Any) -> Node: """Add a constant node to the graph from a python object (e.g. torch.Tensor, scalar).""" - return self._executor.add_value_node(from_python(value)) + return self._executor.add_value_node(_from_python(value)) def set_return(self) -> Node: """Add a return node to the graph. The last added node will be the return value.""" @@ -53,7 +83,7 @@ class GraphExecutor: self._executor.run_graph(is_dynamic) if self._return_node is None: raise RuntimeError("Return node not set. Call set_return() before running.") - return to_python(self._return_node.output) + return _to_python(self._return_node.output) def build(self): """Optimize the graph and build kernels.""" @@ -67,4 +97,4 @@ class GraphExecutor: # Re-export for convenience -__all__ = ["GraphExecutor", "Node", "Op"] \ No newline at end of file +__all__ = ["GraphExecutor", "Node", "Op", "Tensor", "Value", "Tuple"] \ No newline at end of file diff --git a/inferrt/python/mrt/torch/__init__.py b/inferrt/python/mrt/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..955a4f80bcba527f7bd2a214650c3c17f5e8fca2 --- /dev/null +++ b/inferrt/python/mrt/torch/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fx_backend import backend + +__all__ = ['backend'] diff --git a/inferrt/python/dapy/fx_backend.py b/inferrt/python/mrt/torch/fx_backend.py similarity index 88% rename from inferrt/python/dapy/fx_backend.py rename to inferrt/python/mrt/torch/fx_backend.py index 2cb9e0ad05ff4df8b8c49613a3aed709d4cb7e72..c5b8248af58fc4a1772c61e11c953504f26b9703 100644 --- a/inferrt/python/dapy/fx_backend.py +++ b/inferrt/python/mrt/torch/fx_backend.py @@ -4,7 +4,8 @@ from typing import List, Dict, Any from torch.fx.node import Node, map_arg from torch.fx.graph_module import GraphModule -from dapy.ir import GraphExecutor, Op +from mrt.ir import GraphExecutor, Op +from mrt.torch.utils import from_torch, to_torch, update_tensor_data _GLOBAL_GRAPH_ID = 0 @@ -98,8 +99,10 @@ def backend(gm: GraphModule, example_inputs: List[torch.Tensor]): for node in gm.graph.nodes: if node.op == "placeholder": - const_tensor = executor.add_value_node(next(input_iterator)) - env[node] = const_tensor + input = next(input_iterator) + if isinstance(input, torch.Tensor): + input = from_torch(input) + env[node] = executor.add_value_node(input) with executor: for node in gm.graph.nodes: @@ -114,10 +117,10 @@ def backend(gm: GraphModule, example_inputs: List[torch.Tensor]): for part in target.split("."): attr_val = getattr(attr_val, part) - if isinstance(attr_val, (torch.Tensor, torch.nn.Parameter)): - env[node] = executor.add_value_node(attr_val) - else: - env[node] = attr_val + if isinstance(attr_val, torch.Tensor): + attr_val = from_torch(attr_val) + + env[node] = executor.add_value_node(attr_val) elif node.op in ("call_function", "call_method"): op = _get_op(node.target) @@ -155,9 +158,10 @@ def backend(gm: GraphModule, example_inputs: List[torch.Tensor]): ) for i, p_node in enumerate(param_nodes): - p_node.output.update_tensor_data(new_inputs[i]) + update_tensor_data(p_node.output.to_tensor(), new_inputs[i]) result = executor.run() - return result + + return tuple(to_torch(r) for r in result) return compiled_callable diff --git a/inferrt/python/mrt/torch/utils.py b/inferrt/python/mrt/torch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1751d8ceae99e9950eb01e4b6ccdd352e3e8268c --- /dev/null +++ b/inferrt/python/mrt/torch/utils.py @@ -0,0 +1,13 @@ +import torch +from mrt.ir import Tensor + +import _mrt_torch + +def from_torch(torch_tensor: torch.Tensor) -> Tensor: + return _mrt_torch.from_torch(torch_tensor) + +def to_torch(tensor: Tensor) -> torch.Tensor: + return _mrt_torch.to_torch(tensor) + +def update_tensor_data(tensor: Tensor, torch_tensor: torch.Tensor): + _mrt_torch.update_tensor_data(tensor, torch_tensor) diff --git a/inferrt/src/ir/tensor/tensor.cc b/inferrt/src/ir/tensor/tensor.cc index 51464f4a17e8099da214676b8b531f77d152849b..f24e5c3556aef9806d029c202a1d1c3a54797e76 100644 --- a/inferrt/src/ir/tensor/tensor.cc +++ b/inferrt/src/ir/tensor/tensor.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include "common/common.h" #include "ir/tensor/tensor.h" @@ -39,6 +40,22 @@ int64_t CalculateNumel(const std::vector &shape, bool allow_dynamic) { } return numel; } + +template +void PrintData(std::ostream &os, const void *data, size_t numel, size_t limit) { + const auto *d = static_cast(data); + for (size_t i = 0; i < std::min(numel, limit); ++i) { + // Promote char types to int for printing + os << +d[i]; + if (i < std::min(numel, limit) - 1) { + os << ", "; + } + } + if (numel > limit) { + os << ", ..."; + } +} + } // namespace /** @@ -119,17 +136,8 @@ void Tensor::SetShape(const std::vector &&shape) { numel_ = CalculateNumel(shape_, true); } -std::ostream &operator<<(std::ostream &os, Tensor *tensor) { - if (tensor == nullptr) { - os << "Null"; - } else { - os << *tensor; - } - return os; -} - -std::ostream &operator<<(std::ostream &os, const Tensor *tensor) { - if (tensor == nullptr) { +std::ostream &operator<<(std::ostream &os, const TensorPtr &tensor) { + if (!tensor) { os << "Null"; } else { os << *tensor; @@ -150,24 +158,39 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { os << "], dtype=" << tensor.Dtype().ToString(); os << ", data=["; if (tensor.DataPtr()) { - if (tensor.Dtype() == DataType::Float32) { // TODO: support other dtypes - const auto data = static_cast(tensor.DataPtr()); - const size_t numel = tensor.Numel(); - if (numel <= numelLimit) { - for (size_t i = 0; i < numel; ++i) { - os << data[i]; - if (i < numel - 1) { - os << ", "; - } - } - } else { - for (size_t i = 0; i < numelLimit; ++i) { - os << data[i] << ", "; - } - os << "..."; + if (tensor.HasDynamicShape()) { + os << "dynamic shape, not materialized"; + } else if (tensor.Numel() > 0) { + switch (tensor.Dtype()) { + case DataType::Float32: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::Float64: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::Int8: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::Int16: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::Int32: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::Int64: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::UInt8: + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + case DataType::Bool: + os << std::boolalpha; + PrintData(os, tensor.DataPtr(), tensor.Numel(), numelLimit); + break; + default: + os << "..."; + break; } - } else { - os << "..."; } } else { os << "null"; @@ -176,4 +199,4 @@ std::ostream &operator<<(std::ostream &os, const Tensor &tensor) { return os; } } // namespace ir -} // namespace mrt \ No newline at end of file +} // namespace mrt diff --git a/inferrt/src/ir/tensor/tensor.h b/inferrt/src/ir/tensor/tensor.h index 918782c91808e257c7e7a042d9a0a46c37962c83..45cf611843bee0b9bba11fe64a04ff7528a39f84 100644 --- a/inferrt/src/ir/tensor/tensor.h +++ b/inferrt/src/ir/tensor/tensor.h @@ -23,6 +23,7 @@ #include "hardware/device.h" #include "ir/common/dtype.h" +#include "ir/common/intrusive_ptr.h" #include "ir/tensor/storage.h" namespace mrt { @@ -34,7 +35,7 @@ namespace ir { * This class holds the metadata of a tensor, such as its dimensions, data type, * and a reference to the underlying storage. */ -class Tensor { +class Tensor : public RefCounted { public: /** * @brief Constructs an empty Tensor with uninitialized data. @@ -63,28 +64,10 @@ class Tensor { */ Tensor(void *data, const std::vector &shape, DataType dtype, hardware::Device device); - /** - * @brief Move constructor. - */ - Tensor(Tensor &&other) noexcept { - dtype_ = other.dtype_; - numel_ = other.numel_; - storageOffset_ = other.storageOffset_; - shape_ = std::move(other.shape_); - strides_ = std::move(other.strides_); - storage_ = std::move(other.storage_); - - // Invalidate the moved-from tensor - other.dtype_ = DataType::Unknown; - other.numel_ = 0; - other.storageOffset_ = 0; - } - - /** - * @brief Deleted copy constructor. - */ Tensor(const Tensor &) = delete; Tensor &operator=(const Tensor &) = delete; + Tensor(Tensor &&) = delete; + Tensor &operator=(Tensor &&) = delete; /** * @brief Gets the data type of the tensor. @@ -180,9 +163,10 @@ class Tensor { int64_t storageOffset_ = 0; ///< The offset in the storage, in number of elements. }; -std::ostream &operator<<(std::ostream &os, Tensor *tensor); -std::ostream &operator<<(std::ostream &os, const Tensor *tensor); +using TensorPtr = IntrusivePtr; + std::ostream &operator<<(std::ostream &os, const Tensor &tensor); +std::ostream &operator<<(std::ostream &os, const TensorPtr &tensor); } // namespace ir } // namespace mrt diff --git a/inferrt/src/ir/value/value.cc b/inferrt/src/ir/value/value.cc index f274b3fbb0914039901f25c98a34219d8770d432..aacf95cedccb908cab2dc583cbfce073a3c94cf2 100644 --- a/inferrt/src/ir/value/value.cc +++ b/inferrt/src/ir/value/value.cc @@ -21,23 +21,23 @@ namespace mrt { namespace ir { -Value::Value(Tensor &&v) : tag_(Tag::Tensor) { new (&tensor_) Tensor(std::move(v)); } +Value::Value(TensorPtr v) : tag_(Tag::Tensor), tensor_(v) {} Value::Value(double v) : tag_(Tag::Double), double_(v) {} Value::Value(int64_t v) : tag_(Tag::Int), int_(v) {} Value::Value(bool v) : tag_(Tag::Bool), bool_(v) {} Value::Value(std::string &&v) : tag_(Tag::String) { new (&string_) std::string(std::move(v)); } -Value::Value(Tuple &&v) : tag_(Tag::Tuple) { new (&tuple_) Tuple(std::move(v)); } +Value::Value(TuplePtr v) : tag_(Tag::Tuple), tuple_(v) {} Value::~Value() { switch (tag_) { case Tag::Tensor: - tensor_.~Tensor(); + tensor_.~IntrusivePtr(); break; case Tag::String: string_.~basic_string(); break; case Tag::Tuple: - tuple_.~Tuple(); + tuple_.~IntrusivePtr(); break; default: break; @@ -49,13 +49,9 @@ Value::~Value() { LOG_EXCEPTION << "Bad Value access"; \ } -const Tensor *Value::ToTensor() const { +TensorPtr Value::ToTensor() const { CHECK_TAG(Tag::Tensor); - return &tensor_; -} -Tensor *Value::ToTensor() { - CHECK_TAG(Tag::Tensor); - return &tensor_; + return tensor_; } double Value::ToDouble() const { CHECK_TAG(Tag::Double); @@ -69,48 +65,34 @@ bool Value::ToBool() const { CHECK_TAG(Tag::Bool); return bool_; } -const std::string *Value::ToString() const { +const std::string &Value::ToString() const { CHECK_TAG(Tag::String); - return &string_; + return string_; } -std::string *Value::ToString() { - CHECK_TAG(Tag::String); - return &string_; -} -const Tuple *Value::ToTuple() const { +TuplePtr Value::ToTuple() const { CHECK_TAG(Tag::Tuple); - return &tuple_; -} -Tuple *Value::ToTuple() { - CHECK_TAG(Tag::Tuple); - return &tuple_; + return tuple_; } -std::ostream &operator<<(std::ostream &os, const Tuple *tuple) { - os << "Tuple("; - const auto tupleSize = tuple->Size(); - for (size_t i = 0; i < tupleSize; ++i) { - os << (*tuple)[i]; - if (i < tupleSize - 1) { - os << ", "; - } - } - os << ")"; - return os; -} - -std::ostream &operator<<(std::ostream &os, const ValuePtr &value) { return operator<<(os, value.get()); } - -std::ostream &operator<<(std::ostream &os, Value *value) { - if (value == nullptr) { - os << "Null"; +std::ostream &operator<<(std::ostream &os, const TuplePtr &tuple) { + if (tuple == nullptr) { + os << "Tuple(Null)"; } else { - os << *value; + os << "Tuple("; + bool first = true; + for (const auto &item : *tuple) { + if (!first) { + os << ", "; + } + os << item; + first = false; + } + os << ")"; } return os; } -std::ostream &operator<<(std::ostream &os, const Value *value) { +std::ostream &operator<<(std::ostream &os, const ValuePtr &value) { if (value == nullptr) { os << "Null"; } else { diff --git a/inferrt/src/ir/value/value.h b/inferrt/src/ir/value/value.h index 3d46b9823cd416644e3d479c3618c7984a13ceb0..68653c258c72c63bf86364036691d84299513671 100644 --- a/inferrt/src/ir/value/value.h +++ b/inferrt/src/ir/value/value.h @@ -37,7 +37,7 @@ using ValuePtr = IntrusivePtr; * * This class provides holds a vector of Value objects. */ -class Tuple { +class Tuple : public RefCounted { public: /** * @brief Default constructor. Creates an empty Tuple. @@ -51,11 +51,10 @@ class Tuple { explicit Tuple(const std::vector &elements) : elements_(elements) {} explicit Tuple(std::vector &&elements) : elements_(std::move(elements)) {} - /** - * @brief Move Constructor. - * @param other The Tuple to move from. - */ - Tuple(Tuple &&other) noexcept : elements_(std::move(other.elements_)) {} + Tuple(const Tuple &) = delete; + Tuple &operator=(const Tuple &) = delete; + Tuple(Tuple &&) = delete; + Tuple &operator=(Tuple &&) = delete; /** * @brief Get the size of the tuple. @@ -64,19 +63,29 @@ class Tuple { size_t Size() const { return elements_.size(); } /** - * @brief Retrieves the raw pointer of an element by index. + * @brief Retrieves an element by index. * @param index The index of the element to retrieve. - * @return The element as Value*, or nullptr if the index is out of bounds. + * @return The element as ValuePtr. */ - Value *operator[](size_t index) const { + ValuePtr operator[](size_t index) const { CHECK_IF_FAIL(index < elements_.size()); - return elements_[index].get(); + return elements_[index]; } + auto begin() const { return elements_.cbegin(); } + auto end() const { return elements_.cend(); } + auto begin() { return elements_.begin(); } + auto end() { return elements_.end(); } + private: std::vector elements_; }; +/** + * @brief A smart pointer for Tuple. + */ +using TuplePtr = IntrusivePtr; + /** * @brief A generic container for different types of values. * @@ -90,10 +99,10 @@ class Value : public RefCounted { */ Value() : tag_(Tag::None) {} /** - * @brief Constructs a Value from a Tensor by moving. - * @param v The Tensor value. + * @brief Constructs a Value from a TensorPtr. + * @param v The TensorPtr value. */ - Value(Tensor &&v); + Value(TensorPtr v); /** * @brief Constructs a Value from a double. * @param v The double value. @@ -115,10 +124,10 @@ class Value : public RefCounted { */ Value(std::string &&v); /** - * @brief Constructs a Value from a Tuple by moving. - * @param v The Tuple value. + * @brief Constructs a Value from a TuplePtr. + * @param v The TuplePtr value. */ - Value(Tuple &&v); + Value(TuplePtr v); /** * @brief Destructor. @@ -141,15 +150,12 @@ class Value : public RefCounted { * std::runtime_error if the type does not match. */ ///@{ - const Tensor *ToTensor() const; - Tensor *ToTensor(); + TensorPtr ToTensor() const; double ToDouble() const; int64_t ToInt() const; bool ToBool() const; - const std::string *ToString() const; - std::string *ToString(); - const Tuple *ToTuple() const; - Tuple *ToTuple(); + const std::string &ToString() const; + TuplePtr ToTuple() const; ///@} /** @@ -168,20 +174,19 @@ class Value : public RefCounted { const Tag tag_; ///< The tag indicating the type of the value. union { - Tensor tensor_; + TensorPtr tensor_; double double_; int64_t int_; bool bool_; std::string string_; - Tuple tuple_; + TuplePtr tuple_; }; }; -std::ostream &operator<<(std::ostream &os, const ValuePtr &value); std::ostream &operator<<(std::ostream &os, const Value &value); -std::ostream &operator<<(std::ostream &os, Value *value); -std::ostream &operator<<(std::ostream &os, const Value *value); +std::ostream &operator<<(std::ostream &os, const ValuePtr &value); std::ostream &operator<<(std::ostream &os, const std::vector &values); +std::ostream &operator<<(std::ostream &os, const TuplePtr &tuple); } // namespace ir } // namespace mrt diff --git a/inferrt/src/ops/cpu/aten/aten_kernel.cc b/inferrt/src/ops/cpu/aten/aten_kernel.cc index 05c595893430809ae0b877f9c378de4bf9247641..f8b80328b28c6f085abae04a7bda95541a8aa1ac 100644 --- a/inferrt/src/ops/cpu/aten/aten_kernel.cc +++ b/inferrt/src/ops/cpu/aten/aten_kernel.cc @@ -66,7 +66,7 @@ void AtenKernel::InferShape() { auto in1Dims = node_->inputs[1]->output->ToTensor()->Shape(); dims = at::infer_size(in0Dims, in1Dims); } - node_->output = ir::MakeIntrusive(ir::Tensor(dims, dtype, device)); + node_->output = ir::MakeIntrusive(ir::MakeIntrusive(dims, dtype, device)); } void AtenKernel::Resize() {} diff --git a/inferrt/src/ops/cpu/aten/test_aten.cc b/inferrt/src/ops/cpu/aten/test_aten.cc index 7efcebeb2766ff482867284f34549b740bf7fb0b..e99ae02b889d91623048ed438b355e8da88c5e45 100644 --- a/inferrt/src/ops/cpu/aten/test_aten.cc +++ b/inferrt/src/ops/cpu/aten/test_aten.cc @@ -33,8 +33,8 @@ void TestAtenKernel::Init() { void TestAtenKernel::InferShape() { CHECK_IF_NULL(operator_); - node_->output = ir::MakeIntrusive( - ir::Tensor({-1}, ir::DataType::Type::Float32, hardware::Device(hardware::DeviceType::CPU, 0))); + node_->output = ir::MakeIntrusive(ir::MakeIntrusive( + std::vector{-1}, ir::DataType::Type::Float32, hardware::Device(hardware::DeviceType::CPU, 0))); Init(); LOG_OUT << "Begin InferShape for operator [" << ToStr(node_->op) << "], input=" << input_ << ", output=" << output_; if (operator_->InferShape(input_, output_) != SUCCESS) { diff --git a/inferrt/src/pybind/CMakeLists.txt b/inferrt/src/pybind/CMakeLists.txt index 67ec1a0424f58c8f4d92f93e0eea042fca5b179d..3cf6c15ba932d72a08f47ae55f042cfd5fc7f96c 100644 --- a/inferrt/src/pybind/CMakeLists.txt +++ b/inferrt/src/pybind/CMakeLists.txt @@ -58,16 +58,10 @@ else() add_subdirectory(${PYBIND11_PATH}) endif() -# Add dapy pybind11 sub module -pybind11_add_module(_dapy NO_EXTRAS dalang_py/pybind11_api.cc) -target_link_libraries(_dapy PUBLIC inferrt) +# Add mrt pybind11 sub module +add_subdirectory(mrt) -# Link against torch libraries for _dairpy -execute_process(COMMAND python -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'share/cmake'))" OUTPUT_VARIABLE PYTORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) -set(CMAKE_PREFIX_PATH "${PYTORCH_CMAKE_PATH}") -find_package(Torch REQUIRED) -find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") - -# Add dairpy pybind11 sub module -pybind11_add_module(_dairpy NO_EXTRAS dalang_py/pybind11_ir.cc) -target_link_libraries(_dairpy PUBLIC inferrt ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) +# Add mrt_torch pybind11 sub module +if(ENABLE_KERNEL_ATEN) + add_subdirectory(mrt_torch) +endif() diff --git a/inferrt/src/pybind/mrt/CMakeLists.txt b/inferrt/src/pybind/mrt/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..451f1eb1bf1b2de00a22159a8bd80f2378759e8a --- /dev/null +++ b/inferrt/src/pybind/mrt/CMakeLists.txt @@ -0,0 +1,9 @@ +check_debug_log_out() + +# Add mrt_api pybind11 sub module +pybind11_add_module(_mrt_api NO_EXTRAS pybind11_api.cc) +target_link_libraries(_mrt_api PUBLIC inferrt) + +# Add mrt_ir pybind11 sub module +pybind11_add_module(_mrt_ir NO_EXTRAS pybind11_ir.cc) +target_link_libraries(_mrt_ir PUBLIC inferrt) diff --git a/inferrt/src/pybind/dalang_py/pybind11_api.cc b/inferrt/src/pybind/mrt/pybind11_api.cc similarity index 98% rename from inferrt/src/pybind/dalang_py/pybind11_api.cc rename to inferrt/src/pybind/mrt/pybind11_api.cc index 0dc17d50cf18ee355dae7772ee107fb42b2e7f8b..3fa98d65183050cd9219e64796c6ed96875dd67a 100644 --- a/inferrt/src/pybind/dalang_py/pybind11_api.cc +++ b/inferrt/src/pybind/mrt/pybind11_api.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "pybind/dalang_py/pybind11_api.h" +#include "pybind/mrt/pybind11_api.h" DALangPy::~DALangPy() { if (callable_ != nullptr) { @@ -98,7 +98,7 @@ py::object DALangPy::Run(const py::tuple &args) { } // Interface with python -PYBIND11_MODULE(_dapy, mod) { +PYBIND11_MODULE(_mrt_api, mod) { (void)py::class_>(mod, "DALangPy_") .def_static("get_instance", &DALangPy::GetInstance, "DALangPy single instance.") .def("__call__", &DALangPy::Run, py::arg("args") = py::list(), "Run with arguments.") diff --git a/inferrt/src/pybind/dalang_py/pybind11_api.h b/inferrt/src/pybind/mrt/pybind11_api.h similarity index 100% rename from inferrt/src/pybind/dalang_py/pybind11_api.h rename to inferrt/src/pybind/mrt/pybind11_api.h diff --git a/inferrt/src/pybind/mrt/pybind11_ir.cc b/inferrt/src/pybind/mrt/pybind11_ir.cc new file mode 100644 index 0000000000000000000000000000000000000000..38afb9108ebdf703ea1e345d279dccb4473fb12f --- /dev/null +++ b/inferrt/src/pybind/mrt/pybind11_ir.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "runtime/executor/executor.h" +#include "ir/graph.h" +#include "ir/value/value.h" +#include "ops/op_def/ops_name.h" + +namespace py = pybind11; +using namespace mrt; +using namespace mrt::runtime; + +PYBIND11_DECLARE_HOLDER_TYPE(T, ir::IntrusivePtr, true); + +PYBIND11_MODULE(_mrt_ir, m) { + m.doc() = "Python binding for DA IR"; + + py::enum_(m, "Op") +#define OP(O) .value(#O, ops::Op_##O) +#include "ops/op_def/ops.list" +#undef OP + .export_values(); + + py::class_(m, "Tensor") + .def_property_readonly("shape", &ir::Tensor::Shape) + .def_property_readonly("dtype", &ir::Tensor::Dtype) + .def("__repr__", [](const ir::Tensor &t) { + std::stringstream ss; + ss << t; + return ss.str(); + }); + + py::class_(m, "Tuple") + .def(py::init>()) + .def("__len__", &ir::Tuple::Size) + .def("__getitem__", &ir::Tuple::operator[], py::return_value_policy::reference) + .def("__iter__", [](const ir::Tuple &t) { return py::make_iterator(t.begin(), t.end()); }, py::keep_alive<0, 1>()) + .def("__repr__", [](const ir::TuplePtr &t) { + std::stringstream ss; + ss << t; + return ss.str(); + }); + + py::class_(m, "Value") + .def(py::init<>()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def(py::init()) + .def("is_tensor", &ir::Value::IsTensor) + .def("is_tuple", &ir::Value::IsTuple) + .def("is_double", &ir::Value::IsDouble) + .def("is_int", &ir::Value::IsInt) + .def("is_bool", &ir::Value::IsBool) + .def("is_string", &ir::Value::IsString) + .def("is_none", &ir::Value::IsNone) + .def("to_tensor", &ir::Value::ToTensor) + .def("to_tuple", &ir::Value::ToTuple) + .def("to_double", &ir::Value::ToDouble) + .def("to_int", &ir::Value::ToInt) + .def("to_bool", &ir::Value::ToBool) + .def("to_string", &ir::Value::ToString) + .def("__repr__", [](const ir::Value &v) { + std::stringstream ss; + ss << v; + return ss.str(); + }); + + py::class_(m, "Node").def_property_readonly( + "output", [](const ir::NodePtr &node) { return node->output; }); + + py::class_(m, "GraphExecutor") + .def(py::init<>()) + .def("begin_graph", &GraphExecutor::BeginGraph, py::arg("name")) + .def("end_graph", &GraphExecutor::EndGraph) + .def("opt_graph", &GraphExecutor::OptGraph) + .def("build_kernels", &GraphExecutor::BuildKernels) + .def("run_graph", &GraphExecutor::RunGraph, py::arg("is_dynamic") = false) + .def("dump_graph", &GraphExecutor::DumpGraph) + .def("record_tensor_ref_count", &GraphExecutor::RecordTensorRefCount) + .def("add_return", &GraphExecutor::AddReturn, py::return_value_policy::reference) + .def("add_parameter", &GraphExecutor::AddParameter, py::arg("param")) + .def("add_op_node", &GraphExecutor::AddOpNode, py::arg("op"), py::arg("inputs"), py::return_value_policy::reference) + .def("add_value_node", &GraphExecutor::AddValueNode, py::arg("value"), py::return_value_policy::reference); +} diff --git a/inferrt/src/pybind/mrt_torch/CMakeLists.txt b/inferrt/src/pybind/mrt_torch/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..65c80d33e3fb3e9b269527352caad790fc6e88e4 --- /dev/null +++ b/inferrt/src/pybind/mrt_torch/CMakeLists.txt @@ -0,0 +1,11 @@ +check_debug_log_out() + +# Link against torch libraries for _mrt_torch +execute_process(COMMAND python -c "import os; import torch; print(os.path.join(os.path.dirname(torch.__file__), 'share/cmake'))" OUTPUT_VARIABLE PYTORCH_CMAKE_PATH OUTPUT_STRIP_TRAILING_WHITESPACE) +set(CMAKE_PREFIX_PATH "${PYTORCH_CMAKE_PATH}") +find_package(Torch REQUIRED) +find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib") + +# Add mrt_torch pybind11 sub module +pybind11_add_module(_mrt_torch NO_EXTRAS pybind11_torch.cc) +target_link_libraries(_mrt_torch PUBLIC inferrt ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) diff --git a/inferrt/src/pybind/dalang_py/pybind11_ir.cc b/inferrt/src/pybind/mrt_torch/pybind11_torch.cc similarity index 39% rename from inferrt/src/pybind/dalang_py/pybind11_ir.cc rename to inferrt/src/pybind/mrt_torch/pybind11_torch.cc index 4e20c210a7b37c83db65ff3ac61966f32746621d..a5ac08fb7560e185958f6ad431d5eaa8d964be30 100644 --- a/inferrt/src/pybind/dalang_py/pybind11_ir.cc +++ b/inferrt/src/pybind/mrt_torch/pybind11_torch.cc @@ -15,19 +15,14 @@ */ #include -#include -#include - #include -#include "runtime/executor/executor.h" -#include "ir/graph.h" -#include "ir/value/value.h" -#include "ops/op_def/ops_name.h" +#include "ir/common/intrusive_ptr.h" +#include "ir/tensor/tensor.h" +#include "common/logger.h" namespace py = pybind11; using namespace mrt; -using namespace mrt::runtime; namespace { // DataType conversion utilities @@ -96,141 +91,45 @@ at::Device ToTorchDevice(const hardware::Device device) { return at::Device(deviceType, device.index); } -// Forward declaration for recursive conversion -py::object to_python(const ir::ValuePtr &value); -ir::ValuePtr from_python(const py::handle &obj); - // Tensor conversion utilities -ir::ValuePtr from_torch(const at::Tensor &atTensor) { +ir::TensorPtr FromTorchTensor(const at::Tensor &atTensor) { ir::DataType type = FromTorchDType(atTensor.scalar_type()); std::vector shape(atTensor.sizes().begin(), atTensor.sizes().end()); void *data = atTensor.data_ptr(); auto device = FromTorchDevice(atTensor.device()); - auto tensor = ir::Tensor(data, shape, type, device); - return ir::MakeIntrusive(std::move(tensor)); + return ir::MakeIntrusive(data, shape, type, device); } -at::Tensor to_torch(const ir::ValuePtr &value) { - auto tensor = value->ToTensor(); +at::Tensor ToTorchTensor(const ir::TensorPtr &tensor) { CHECK_IF_NULL(tensor); auto options = at::TensorOptions().dtype(ToTorchDType(tensor->Dtype())).device(ToTorchDevice(tensor->GetDevice())); - value->AddRef(); + tensor->AddRef(); return at::from_blob( - tensor->DataPtr(), tensor->Shape(), tensor->Strides(), [ptr = value.get()](void *) { ptr->DecRef(); }, options); + const_cast(tensor->DataPtr()), tensor->Shape(), tensor->Strides(), + [ptr = tensor.get()](void *) { ptr->DecRef(); }, options); } -// New conversion functions -ir::ValuePtr from_python(const py::handle &obj) { - if (THPVariable_Check(obj.ptr())) { - return from_torch(obj.cast()); - } - if (py::isinstance(obj) || py::isinstance(obj)) { - auto py_tuple = obj.cast(); - std::vector elements; - elements.reserve(py_tuple.size()); - for (const auto &elem : py_tuple) { - (void)elements.emplace_back(from_python(elem)); - } - return ir::MakeIntrusive(ir::Tuple(std::move(elements))); - } - if (py::isinstance(obj)) { - return ir::MakeIntrusive(obj.cast()); - } - if (py::isinstance(obj)) { - return ir::MakeIntrusive(obj.cast()); - } - if (py::isinstance(obj)) { - return ir::MakeIntrusive(obj.cast()); - } - if (py::isinstance(obj)) { - return ir::MakeIntrusive(obj.cast()); - } - if (obj.is_none()) { - return ir::MakeIntrusive(); - } - LOG_EXCEPTION << "Unsupported python type for conversion to ir::Value: " << py::str(obj); - return nullptr; -} +void UpdateTensorData(ir::Tensor &self, const at::Tensor &atTensor) { + ir::DataType type = FromTorchDType(atTensor.scalar_type()); + std::vector shape(atTensor.sizes().begin(), atTensor.sizes().end()); + void *data = atTensor.data_ptr(); -py::object to_python(const ir::ValuePtr &value) { - if (!value) { - return py::none(); - } - if (value->IsTensor()) { - return py::cast(to_torch(value)); + if (self.GetDevice() != FromTorchDevice(atTensor.device())) { + LOG_EXCEPTION << "Device mismatch in update_tensor_data"; } - if (value->IsTuple()) { - const auto *tuple = value->ToTuple(); - py::tuple py_tuple(tuple->Size()); - for (size_t i = 0; i < tuple->Size(); ++i) { - py_tuple[i] = to_python((*tuple)[i]); - } - return py_tuple; - } - if (value->IsInt()) { - return py::cast(value->ToInt()); - } - if (value->IsDouble()) { - return py::cast(value->ToDouble()); - } - if (value->IsBool()) { - return py::cast(value->ToBool()); - } - if (value->IsString()) { - return py::cast(*value->ToString()); - } - if (value->IsNone()) { - return py::none(); - } - LOG_EXCEPTION << "Unsupported ir::Value for conversion to python object: " << value; - return py::none(); + + self.SetDtype(type); + self.SetShape(std::move(shape)); + self.ResizeStorage(); + self.UpdateData(data); } } // namespace PYBIND11_DECLARE_HOLDER_TYPE(T, ir::IntrusivePtr, true); -PYBIND11_MODULE(_dairpy, m) { - m.doc() = "Python binding for DA IR"; - - py::enum_(m, "Op") -#define OP(O) .value(#O, ops::Op_##O) -#include "ops/op_def/ops.list" -#undef OP - .export_values(); - - py::class_(m, "Value") - .def("update_tensor_data", [](ir::Value &self, const at::Tensor &atTensor) { - ir::DataType type = FromTorchDType(atTensor.scalar_type()); - std::vector shape(atTensor.sizes().begin(), atTensor.sizes().end()); - void *data = atTensor.data_ptr(); - - auto tensor = self.ToTensor(); - if (tensor->GetDevice() != FromTorchDevice(atTensor.device())) { - LOG_EXCEPTION << "Device mismatch in update_tensor_data"; - } - - tensor->SetDtype(type); - tensor->SetShape(std::move(shape)); - tensor->ResizeStorage(); - tensor->UpdateData(data); - }); - - py::class_(m, "Node").def_property_readonly( - "output", [](const ir::NodePtr &node) { return node->output; }); - - m.def("from_python", &from_python, py::arg("obj")); - m.def("to_python", &to_python, py::arg("value")); - py::class_(m, "GraphExecutor") - .def(py::init<>()) - .def("begin_graph", &GraphExecutor::BeginGraph, py::arg("name")) - .def("end_graph", &GraphExecutor::EndGraph) - .def("opt_graph", &GraphExecutor::OptGraph) - .def("build_kernels", &GraphExecutor::BuildKernels) - .def("run_graph", &GraphExecutor::RunGraph, py::arg("is_dynamic") = false) - .def("dump_graph", &GraphExecutor::DumpGraph) - .def("record_tensor_ref_count", &GraphExecutor::RecordTensorRefCount) - .def("add_return", &GraphExecutor::AddReturn, py::return_value_policy::reference) - .def("add_parameter", &GraphExecutor::AddParameter, py::arg("param")) - .def("add_op_node", &GraphExecutor::AddOpNode, py::arg("op"), py::arg("inputs"), py::return_value_policy::reference) - .def("add_value_node", &GraphExecutor::AddValueNode, py::arg("value"), py::return_value_policy::reference); +PYBIND11_MODULE(_mrt_torch, m) { + m.doc() = "PyTorch extension for DA IR"; + m.def("from_torch", &FromTorchTensor); + m.def("to_torch", &ToTorchTensor, py::return_value_policy::reference); + m.def("update_tensor_data", &UpdateTensorData); } diff --git a/inferrt/src/runtime/executor/executor.cc b/inferrt/src/runtime/executor/executor.cc index ab4150d9d0a873b8f9cab884a9bd667bfb0655d4..606398bbf6186d42192ec7a7ab439e2bf360f343 100644 --- a/inferrt/src/runtime/executor/executor.cc +++ b/inferrt/src/runtime/executor/executor.cc @@ -70,7 +70,7 @@ void ProcessMakeTuple(ir::NodePtr node) { for (auto &input : node->inputs) { (void)elements.emplace_back(input->output); } - node->output = ir::MakeIntrusive(ir::Tuple(std::move(elements))); + node->output = ir::MakeIntrusive(ir::MakeIntrusive(std::move(elements))); } void ProcessTupleGetItem(ir::NodePtr node) { @@ -228,7 +228,8 @@ void GraphExecutor::RunNode(ir::NodePtr node) { LOG_OUT << "Skip launch kernel for node" << node; auto outputTensor = node->output->ToTensor(); auto inputStorage = node->inputs[it->second]->output->ToTensor()->GetStorage(); - node->output = ir::MakeIntrusive(ir::Tensor(inputStorage, outputTensor->Shape(), outputTensor->Dtype())); + node->output = ir::MakeIntrusive( + ir::MakeIntrusive(inputStorage, outputTensor->Shape(), outputTensor->Dtype())); } else { kernel->Launch(); }