From e3dee7a99e81fa4c81ef1df28b7d75d66143c852 Mon Sep 17 00:00:00 2001 From: mueuler Date: Tue, 22 Jul 2025 14:29:33 +0800 Subject: [PATCH] add launch_callback --- torch_npu/csrc/core/npu/NPUGraph.cpp | 29 +++++++++++ torch_npu/csrc/core/npu/NPUGraph.h | 7 +++ .../csrc/core/npu/interface/AclInterface.cpp | 52 +++++++++++++++++++ .../csrc/core/npu/interface/AclInterface.h | 7 +++ torch_npu/csrc/npu/Graph.cpp | 45 +++++++++++++++- torch_npu/csrc/npu/Graph.h | 18 +++++++ torch_npu/npu/__init__.py | 10 +++- torch_npu/npu/graphs.py | 19 +++++++ 8 files changed, 185 insertions(+), 2 deletions(-) create mode 100644 torch_npu/csrc/npu/Graph.h diff --git a/torch_npu/csrc/core/npu/NPUGraph.cpp b/torch_npu/csrc/core/npu/NPUGraph.cpp index a00448bd1c..0c520661c5 100644 --- a/torch_npu/csrc/core/npu/NPUGraph.cpp +++ b/torch_npu/csrc/core/npu/NPUGraph.cpp @@ -47,6 +47,35 @@ void graph_task_update_end(c10_npu::NPUStream stream) NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskUpdateEnd(stream)); } +void launch_callback(NPUCallbackFunc func, void *fnData, uint32_t blockType, c10_npu::NPUStream stream) +{ + aclrtCallbackBlockType type; + switch (blockType) + { + case 0: + type = aclrtCallbackBlockType::ACL_CALLBACK_NO_BLOCK; + break; + default: + type = aclrtCallbackBlockType::ACL_CALLBACK_BLOCK; + } + NPU_CHECK_ERROR(c10_npu::acl::AclrtLaunchCallback(func, fnData, type, stream)); +} + +void subscribe_report(uint64_t threadId, c10_npu::NPUStream stream) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclrtSubscribeReport(threadId, stream)); +} + +void unsubscribe_report(uint64_t threadId, c10_npu::NPUStream stream) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclrtUnSubscribeReport(threadId, stream)); +} + +void process_report(int32_t timeout) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclrtProcessReport(timeout)); +} + /** * Note [CUDA Graph Wrapper Class] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch_npu/csrc/core/npu/NPUGraph.h b/torch_npu/csrc/core/npu/NPUGraph.h index 442ae335cc..06d6d77404 100644 
--- a/torch_npu/csrc/core/npu/NPUGraph.h +++ b/torch_npu/csrc/core/npu/NPUGraph.h @@ -18,11 +18,18 @@ struct TORCH_NPU_API NPUTaskGroupHandle { aclrtTaskGrp task_group; }; +typedef TORCH_NPU_API void (*NPUCallbackFunc)(void *fnData); + TORCH_NPU_API void graph_task_group_begin(c10_npu::NPUStream stream); TORCH_NPU_API NPUTaskGroupHandle graph_task_group_end(c10_npu::NPUStream stream); TORCH_NPU_API void graph_task_update_begin(c10_npu::NPUStream stream, NPUTaskGroupHandle handle); TORCH_NPU_API void graph_task_update_end(c10_npu::NPUStream stream); +TORCH_NPU_API void launch_callback(NPUCallbackFunc func, void *fnData, uint32_t blockType, c10_npu::NPUStream stream); +TORCH_NPU_API void subscribe_report(uint64_t threadId, c10_npu::NPUStream stream); +TORCH_NPU_API void unsubscribe_report(uint64_t threadId, c10_npu::NPUStream stream); +TORCH_NPU_API void process_report(int32_t timeout); + struct TORCH_NPU_API NPUGraph { NPUGraph(); ~NPUGraph(); diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index b97a8d4c39..e065a6cda1 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -93,6 +93,10 @@ LOAD_FUNCTION(aclrtDeviceGetBareTgid) LOAD_FUNCTION(aclrtGetDeviceResLimit) LOAD_FUNCTION(aclrtSetDeviceResLimit) LOAD_FUNCTION(aclrtResetDeviceResLimit) +LOAD_FUNCTION(aclrtLaunchCallback) +LOAD_FUNCTION(aclrtSubscribeReport) +LOAD_FUNCTION(aclrtUnSubscribeReport) +LOAD_FUNCTION(aclrtProcessReport) aclprofStepInfoPtr init_stepinfo() { @@ -1073,5 +1077,53 @@ aclError AclrtResetDeviceResLimit(int32_t deviceId) return func(deviceId); } +aclError AclrtLaunchCallback(aclrtCallback fn, void *userData, aclrtCallbackBlockType blockType, aclrtStream stream) +{ + typedef aclError (*AclrtLaunchCallback)(aclrtCallback, void *, aclrtCallbackBlockType, aclrtStream); + static AclrtLaunchCallback func = nullptr; + if (func == nullptr) { + func = 
(AclrtLaunchCallback) GET_FUNC(aclrtLaunchCallback); } + + TORCH_CHECK(func, "Failed to find function aclrtLaunchCallback", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(fn, userData, blockType, stream); +} + +aclError AclrtSubscribeReport(uint64_t threadId, aclrtStream stream) +{ + typedef aclError (*AclrtSubscribeReport)(uint64_t, aclrtStream); + static AclrtSubscribeReport func = nullptr; + if (func == nullptr) { + func = (AclrtSubscribeReport) GET_FUNC(aclrtSubscribeReport); + } + + TORCH_CHECK(func, "Failed to find function aclrtSubscribeReport", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(threadId, stream); +} + +aclError AclrtUnSubscribeReport(uint64_t threadId, aclrtStream stream) +{ + typedef aclError (*AclrtUnSubscribeReport)(uint64_t, aclrtStream); + static AclrtUnSubscribeReport func = nullptr; + if (func == nullptr) { + func = (AclrtUnSubscribeReport) GET_FUNC(aclrtUnSubscribeReport); + } + + TORCH_CHECK(func, "Failed to find function aclrtUnSubscribeReport", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(threadId, stream); +} + +aclError AclrtProcessReport(int32_t timeout) +{ + typedef aclError (*AclrtProcessReport)(int32_t); + static AclrtProcessReport func = nullptr; + if (func == nullptr) { + func = (AclrtProcessReport) GET_FUNC(aclrtProcessReport); + } + + TORCH_CHECK(func, "Failed to find function aclrtProcessReport", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(timeout); +} + } // namespace acl } // namespace c10 diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h index 3b6d47cf4a..73b8852f2c 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.h +++ b/torch_npu/csrc/core/npu/interface/AclInterface.h @@ -257,5 +257,12 @@ aclError AclrtSetDeviceResLimit(int32_t deviceId, aclrtDevResModelType type, uin aclError AclrtResetDeviceResLimit(int32_t deviceId); +aclError AclrtLaunchCallback(aclrtCallback fn, void *userData, aclrtCallbackBlockType blockType, aclrtStream stream); + 
+aclError AclrtSubscribeReport(uint64_t threadId, aclrtStream stream); + +aclError AclrtUnSubscribeReport(uint64_t threadId, aclrtStream stream); + +aclError AclrtProcessReport(int32_t timeout); } // namespace acl } // namespace c10_npu diff --git a/torch_npu/csrc/npu/Graph.cpp b/torch_npu/csrc/npu/Graph.cpp index c8d30cfa44..cc77e6c872 100644 --- a/torch_npu/csrc/npu/Graph.cpp +++ b/torch_npu/csrc/npu/Graph.cpp @@ -1,4 +1,4 @@ -#include +#include "torch_npu/csrc/npu/Graph.h" #include @@ -12,6 +12,30 @@ template using shared_ptr_class_ = py::class_>; +void LaunchCallFunc(void *userData) +{ + if (userData == nullptr) { + return; + } + auto data = (PyFuncStruct *)(userData); + // NOTE(review): 'data' is never freed; one PyFuncStruct leaks per launched callback. + // Hold the GIL for every Python C-API call below. + PyGILState_STATE state = PyGILState_Ensure(); + PyObject *argslist = Py_BuildValue("(O)", data->pyFuncArgs); + PyObject *result = nullptr; + if (argslist != nullptr) { + result = PyObject_CallObject(data->pyFunc, argslist); + } + if (result == nullptr) { + // Surface the Python error instead of swallowing it silently. + PyErr_Print(); + } + Py_XDECREF(result); + Py_XDECREF(argslist); + // Always release the GIL, including on the error paths above. + PyGILState_Release(state); +} + void TORCH_NPU_API THNPGraph_init(PyObject* module) { // Pybind11 patch notes say "py::module_" is more up-to-date syntax, // but CI linter and some builds prefer "module". 
@@ -36,6 +60,25 @@ void TORCH_NPU_API THNPGraph_init(PyObject* module) { .def("_graph_task_update_end", [](py::object py_stream) { auto stream = (*py_stream).ptr(); c10_npu::graph_task_update_end(THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_launch_callback", [](py::object py_func, py::object py_data, uint32_t block_type, py::object py_stream) { + auto func = (*py_func).ptr(); + auto userDataList = (*py_data).ptr(); + auto stream = (*py_stream).ptr(); + PyFuncStruct *data = new PyFuncStruct(func, userDataList); + c10_npu::launch_callback(LaunchCallFunc, reinterpret_cast<void *>(data), block_type, + THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_subscribe_report", [](uint64_t threadId, py::object py_stream) { + auto stream = (*py_stream).ptr(); + c10_npu::subscribe_report(threadId, THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_unsubscribe_report", [](uint64_t threadId, py::object py_stream) { + auto stream = (*py_stream).ptr(); + c10_npu::unsubscribe_report(threadId, THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_process_report", [](int32_t timeout) { + c10_npu::process_report(timeout); }); shared_ptr_class_(torch_N_m, "_NPUGraph") diff --git a/torch_npu/csrc/npu/Graph.h b/torch_npu/csrc/npu/Graph.h new file mode 100644 index 0000000000..bbd403b412 --- /dev/null +++ b/torch_npu/csrc/npu/Graph.h @@ -0,0 +1,18 @@ +#ifndef THNP_GRAPH_INC +#define THNP_GRAPH_INC + +#include + +struct PyFuncStruct { + PyFuncStruct(PyObject *pyFunc, PyObject *pyFuncArgs) + : pyFunc(pyFunc), pyFuncArgs(pyFuncArgs) + { + Py_XINCREF(pyFunc); + Py_XINCREF(pyFuncArgs); + } + + PyObject* pyFunc = nullptr; + PyObject* pyFuncArgs = nullptr; +}; + +#endif // THNP_GRAPH_INC \ No newline at end of file diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 182859d8a5..560466e915 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -116,7 +116,11 @@ __all__ = [ "graph_task_update_begin", "graph_task_update_end", "set_device_limit", - 
"get_device_limit" + "get_device_limit", + "launch_callback", + "subscribe_report", + "unsubscribe_report", + "process_report" ] from typing import Tuple, Union, List, cast, Optional @@ -156,6 +160,10 @@ from .graphs import ( graph_task_group_end, graph_task_update_begin, graph_task_update_end, + launch_callback, + subscribe_report, + unsubscribe_report, + process_report, ) # init profiler diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py index 7e21ce5ed9..abf1026ce2 100644 --- a/torch_npu/npu/graphs.py +++ b/torch_npu/npu/graphs.py @@ -27,6 +27,10 @@ if not hasattr(torch_npu._C, "_NPUStreamBase"): torch_npu._C.__dict__["_graph_task_group_end"] = _dummy_type("_graph_task_group_end") torch_npu._C.__dict__["_graph_task_update_begin"] = _dummy_type("_graph_task_update_begin") torch_npu._C.__dict__["_graph_task_update_end"] = _dummy_type("_graph_task_update_end") + torch_npu._C.__dict__["_launch_callback"] = _dummy_type("_launch_callback") + torch_npu._C.__dict__["_subscribe_report"] = _dummy_type("_subscribe_report") + torch_npu._C.__dict__["_unsubscribe_report"] = _dummy_type("_unsubscribe_report") + torch_npu._C.__dict__["_process_report"] = _dummy_type("_process_report") from torch_npu._C import ( # noqa: F401 _npu_isCurrentStreamCapturing, @@ -36,6 +40,10 @@ from torch_npu._C import ( # noqa: F401 _graph_task_group_end, _graph_task_update_begin, _graph_task_update_end, + _launch_callback, + _subscribe_report, + _unsubscribe_report, + _process_report, ) @@ -74,6 +82,17 @@ def graph_task_update_begin(stream, handle): def graph_task_update_end(stream): _graph_task_update_end(stream) +def launch_callback(fn, user_data_list, block_type, stream): + _launch_callback(fn, user_data_list, block_type, stream) + +def subscribe_report(thread_id, stream): + _subscribe_report(thread_id, stream) + +def unsubscribe_report(thread_id, stream): + _unsubscribe_report(thread_id, stream) + +def process_report(timeout): + _process_report(timeout) @dataclass class 
_GraphDispatchRecord: -- Gitee