diff --git a/src/torch_atb/bindings.cpp b/src/torch_atb/bindings.cpp index 063a27b7e590fca45ad0434ff219769fc6036946..ba046de9b5d15cdc9dedeae6ff3a29f9cecb9851 100644 --- a/src/torch_atb/bindings.cpp +++ b/src/torch_atb/bindings.cpp @@ -82,6 +82,7 @@ PYBIND11_MODULE(_C, m) .def_property_readonly("input_num", &TorchAtb::OperationWrapper::GetInputNum) .def_property_readonly("output_num", &TorchAtb::OperationWrapper::GetOutputNum) .def("forward", &TorchAtb::OperationWrapper::Forward) + .def("set_buffer_size", &TorchAtb::OperationWrapper::SetBufferSize) .def("__repr__", [](const TorchAtb::OperationWrapper &opWrapper) { std::stringstream ss; ss << "op name: " << opWrapper.GetName() << ", input_num: " << opWrapper.GetInputNum() diff --git a/src/torch_atb/operation_wrapper.cpp b/src/torch_atb/operation_wrapper.cpp index 7f3a1602fdc7c10b55dc56442a001a0e8cc45803..ef28bd5a85f4933a083f9ab2c9425a7802fc9cb3 100644 --- a/src/torch_atb/operation_wrapper.cpp +++ b/src/torch_atb/operation_wrapper.cpp @@ -317,4 +317,9 @@ void OperationWrapper::BuildInTensorVariantPack(std::vector &inTe variantPack_.inTensors.at(i) = Utils::ConvertToAtbTensor(inTensors.at(i)); } } + +void OperationWrapper::SetBufferSize(uint64_t size) +{ + MemoryManager::SetBufferSize(size); +} } // namespace TorchAtb diff --git a/src/torch_atb/operation_wrapper.h b/src/torch_atb/operation_wrapper.h index 4291c5fa42df7cdc25e5f75c305d6566d5de1de5..ad219fb19647dccee30157192698e965eb38a978 100644 --- a/src/torch_atb/operation_wrapper.h +++ b/src/torch_atb/operation_wrapper.h @@ -69,6 +69,7 @@ public: uint32_t GetInputNum() const; uint32_t GetOutputNum() const; std::vector Forward(std::vector &inTensors); + void OperationWrapper::SetBufferSize(uint64_t size); private: template void CreateOpUniquePtr(const OpParam ¶m);