From 838cf6b7f54956e5c20185c6653ad3458823adab Mon Sep 17 00:00:00 2001 From: mueuler Date: Sat, 26 Jul 2025 10:30:46 +0800 Subject: [PATCH] add launch_callback --- test/npu/test_aclgraph_launch_callback.py | 60 +++++++++++++++++++ torch_npu/csrc/core/npu/NPUGraph.cpp | 54 +++++++++++++++++ torch_npu/csrc/core/npu/NPUGraph.h | 36 +++++++++++ .../csrc/core/npu/interface/AclInterface.cpp | 39 ++++++++++++ .../csrc/core/npu/interface/AclInterface.h | 5 ++ torch_npu/csrc/npu/Graph.cpp | 39 ++++++++++++ torch_npu/npu/__init__.py | 8 ++- torch_npu/npu/graphs.py | 21 ++++++- 8 files changed, 260 insertions(+), 2 deletions(-) create mode 100644 test/npu/test_aclgraph_launch_callback.py diff --git a/test/npu/test_aclgraph_launch_callback.py b/test/npu/test_aclgraph_launch_callback.py new file mode 100644 index 0000000000..f224a9ad79 --- /dev/null +++ b/test/npu/test_aclgraph_launch_callback.py @@ -0,0 +1,60 @@ +import unittest +from itertools import chain + +import torch +from torch import nn +import torch_npu +from torch_npu.testing.common_utils import SupportedDevices +from torch_npu.testing.testcase import TestCase, run_tests + +callback_stream = torch.npu.Stream() + + +def callback_add(params): + global callback_stream + with torch.npu.stream(callback_stream): + x, y, result = params + result.copy_(x + y) + + +class MyModel(nn.Module): + def __init__(self): + super(MyModel, self).__init__() + self.result = torch.rand([5, 5]).npu() + + def forward(self, graph, x, y): + call_params = [torch.matmul(x, y), torch.matmul(x, y), self.result] + for _ in range(10000): + torch_npu.npu.launch_host_func(torch.npu.current_stream(), callback_add, call_params) + return self.result + + +class TestAclgraphLaunchCallback(TestCase): + + @SupportedDevices(['Ascend910B']) + def test_launch_callback(self): + torch_npu.npu.set_compile_mode(jit_compile=False) + torch_npu.npu.set_device(0) + + self.capture_stream = torch_npu.npu.Stream() + self.graph = torch_npu.npu.NPUGraph() + + torch_npu.npu.subscribe_report(self.capture_stream) + a = torch.randn([5, 5]).npu() + b = torch.randn([5, 5]).npu() + model = MyModel() + with torch_npu.npu.stream(self.capture_stream): + with torch_npu.npu.graph(self.graph, stream=self.capture_stream): + self.res = model.forward(self.graph, a, b) + + torch.npu.synchronize() + for _ in range(5): + self.graph.replay() + torch.npu.synchronize() + real = torch.matmul(a, b) + torch.matmul(a, b) + self.assertEqual(self.res.cpu(), real.cpu()) + torch_npu.npu.unsubscribe_report(self.capture_stream) + + +if __name__ == '__main__': + run_tests() \ No newline at end of file diff --git a/torch_npu/csrc/core/npu/NPUGraph.cpp b/torch_npu/csrc/core/npu/NPUGraph.cpp index a00448bd1c..c843371c17 100644 --- a/torch_npu/csrc/core/npu/NPUGraph.cpp +++ b/torch_npu/csrc/core/npu/NPUGraph.cpp @@ -14,6 +14,10 @@ namespace c10_npu { static bool _npu_graphs_debug = false; constexpr int kSynchronizeBusyWaitMillis = 10; +constexpr int processReportTimeout = 100; +static std::map> callbacks = {}; +static ThreadArgs* threadArgs = nullptr; +static pthread_t threadId = -1; MempoolId_t graph_pool_handle() { @@ -47,6 +51,56 @@ void graph_task_update_end(c10_npu::NPUStream stream) NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskUpdateEnd(stream)); } +void launch_callback(c10_npu::NPUStream stream, NPUCallbackFunc func, PyFuncStruct *fnData) +{ + aclrtCallbackBlockType type = aclrtCallbackBlockType::ACL_CALLBACK_BLOCK; + NPU_CHECK_ERROR(c10_npu::acl::AclrtLaunchCallback(func, reinterpret_cast(fnData), type, stream)); + callbacks[stream].emplace_back(fnData); +} + +void *process_callback(void *arg) +{ + ThreadArgs* args = static_cast(arg); + auto ret = aclrtSetCurrentContext(args->context); + while (!args->exitFlag) { + (void)aclrtProcessReport(processReportTimeout); + } + delete args; + args = nullptr; + return nullptr; +} + +void subscribe_report(c10_npu::NPUStream stream) +{ + aclrtContext context = aclrtContext(); + NPU_CHECK_ERROR(aclrtGetCurrentContext(&context)); + if ((threadArgs == nullptr) || (threadId == -1)) { + threadArgs = new ThreadArgs(context, false); + pthread_create(&threadId, nullptr, process_callback, threadArgs); + } + NPU_CHECK_ERROR(c10_npu::acl::AclrtSubscribeReport(threadId, stream)); +} + +void unsubscribe_report(c10_npu::NPUStream stream) +{ + NPU_CHECK_ERROR(c10_npu::acl::AclrtUnSubscribeReport(threadId, stream)); + auto it = callbacks.find(stream); + if (it == callbacks.end()) { + return; + } + + std::vector& funcs = it->second; + for (PyFuncStruct* func : funcs) { + delete func; + func = nullptr; + } + funcs.clear(); + callbacks.erase(it); + if (callbacks.empty()) { + threadArgs->exitFlag = true; + } +} + /** * Note [CUDA Graph Wrapper Class] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/torch_npu/csrc/core/npu/NPUGraph.h b/torch_npu/csrc/core/npu/NPUGraph.h index 442ae335cc..23696a3a14 100644 --- a/torch_npu/csrc/core/npu/NPUGraph.h +++ b/torch_npu/csrc/core/npu/NPUGraph.h @@ -3,7 +3,10 @@ #include #include #include +#include +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" #include "torch_npu/csrc/core/npu/NPUGraphsUtils.h" #include "torch_npu/csrc/core/npu/NPUMacros.h" #include "torch_npu/csrc/core/npu/NPUStream.h" @@ -18,11 +21,43 @@ struct TORCH_NPU_API NPUTaskGroupHandle { aclrtTaskGrp task_group; }; +typedef TORCH_NPU_API void (*NPUCallbackFunc)(void *fnData); + TORCH_NPU_API void graph_task_group_begin(c10_npu::NPUStream stream); TORCH_NPU_API NPUTaskGroupHandle graph_task_group_end(c10_npu::NPUStream stream); TORCH_NPU_API void graph_task_update_begin(c10_npu::NPUStream stream, NPUTaskGroupHandle handle); TORCH_NPU_API void graph_task_update_end(c10_npu::NPUStream stream); +struct TORCH_NPU_API ThreadArgs { + ThreadArgs(aclrtContext context, bool exitFlag) + : context(context), exitFlag(exitFlag) {} + + aclrtContext context; + bool exitFlag; +}; + +struct TORCH_NPU_API PyFuncStruct { + PyFuncStruct(PyObject *pyFunc, PyObject *pyFuncArgs) + : pyFunc(pyFunc), pyFuncArgs(pyFuncArgs) + { + Py_XINCREF(pyFunc); + Py_XINCREF(pyFuncArgs); + } + + ~PyFuncStruct() { + Py_XDECREF(pyFunc); + Py_XDECREF(pyFuncArgs); + } + + PyObject* pyFunc = nullptr; + PyObject* pyFuncArgs = nullptr; +}; + +TORCH_NPU_API void launch_callback(c10_npu::NPUStream stream, NPUCallbackFunc func, PyFuncStruct *fnData); +TORCH_NPU_API void subscribe_report(c10_npu::NPUStream stream); +TORCH_NPU_API void unsubscribe_report(c10_npu::NPUStream stream); + + struct TORCH_NPU_API NPUGraph { NPUGraph(); ~NPUGraph(); @@ -75,6 +110,7 @@ protected: // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device // captures if needed. int capture_dev_; + }; } // namespace c10_npu diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp index 6b8053b9c3..acf97e0300 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp +++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp @@ -94,6 +94,9 @@ LOAD_FUNCTION(aclrtGetDeviceResLimit) LOAD_FUNCTION(aclrtSetDeviceResLimit) LOAD_FUNCTION(aclrtResetDeviceResLimit) LOAD_FUNCTION(aclrtStreamGetId) +LOAD_FUNCTION(aclrtLaunchCallback) +LOAD_FUNCTION(aclrtSubscribeReport) +LOAD_FUNCTION(aclrtUnSubscribeReport) aclprofStepInfoPtr init_stepinfo() { typedef aclprofStepInfoPtr(*npdInitFunc)(); @@ -1084,5 +1087,41 @@ aclError AclrtStreamGetId(aclrtStream stream, int32_t* stream_id) return func(stream, stream_id); } +aclError AclrtLaunchCallback(aclrtCallback fn, void *userData, aclrtCallbackBlockType blockType, aclrtStream stream) +{ + typedef aclError (*AclrtLaunchCallback)(aclrtCallback, void *, aclrtCallbackBlockType, aclrtStream); + static AclrtLaunchCallback func = nullptr; + if (func == nullptr) { + func = (AclrtLaunchCallback) GET_FUNC(aclrtLaunchCallback); + } + + TORCH_CHECK(func, "Failed to find function aclrtLaunchCallback", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(fn, userData, blockType, stream); +} + +aclError AclrtSubscribeReport(uint64_t threadId, aclrtStream stream) +{ + typedef aclError (*AclrtSubscribeReport)(uint64_t, aclrtStream); + static AclrtSubscribeReport func = nullptr; + if (func == nullptr) { + func = (AclrtSubscribeReport) GET_FUNC(aclrtSubscribeReport); + } + + TORCH_CHECK(func, "Failed to find function aclrtSubscribeReport", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(threadId, stream); +} + +aclError AclrtUnSubscribeReport(uint64_t theadId, aclrtStream stream) +{ + typedef aclError (*AclrtUnSubscribeReport)(uint64_t, aclrtStream); + static AclrtUnSubscribeReport func = nullptr; + if (func == nullptr) { + func = (AclrtUnSubscribeReport) GET_FUNC(aclrtUnSubscribeReport); + } + + TORCH_CHECK(func, "Failed to find function aclrtUnSubscribeReport", PTA_ERROR(ErrCode::NOT_FOUND)); + return func(theadId, stream); +} + } // namespace acl } // namespace c10 diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h index 367963c070..e4904380b0 100644 --- a/torch_npu/csrc/core/npu/interface/AclInterface.h +++ b/torch_npu/csrc/core/npu/interface/AclInterface.h @@ -259,5 +259,10 @@ aclError AclrtResetDeviceResLimit(int32_t deviceId); aclError AclrtStreamGetId(aclrtStream stream, int32_t* stream_id); +aclError AclrtLaunchCallback(aclrtCallback fn, void *userData, aclrtCallbackBlockType blockType, aclrtStream stream); + +aclError AclrtSubscribeReport(uint64_t theadId, aclrtStream stream); + +aclError AclrtUnSubscribeReport(uint64_t theadId, aclrtStream stream); } // namespace acl } // namespace c10_npu diff --git a/torch_npu/csrc/npu/Graph.cpp b/torch_npu/csrc/npu/Graph.cpp index c8d30cfa44..ed6a91a2d1 100644 --- a/torch_npu/csrc/npu/Graph.cpp +++ b/torch_npu/csrc/npu/Graph.cpp @@ -12,6 +12,30 @@ template using shared_ptr_class_ = py::class_>; +void LaunchCallFunc(void *userData) +{ + PyGILState_STATE state = PyGILState_Ensure(); + if (userData == nullptr) { + return; + } + auto data = (c10_npu::PyFuncStruct *)(userData); + PyObject *argslist = Py_BuildValue("(O)", data->pyFuncArgs); + if (argslist == nullptr) { + return; + } + PyObject *result = PyObject_CallObject(data->pyFunc, argslist); + if (result == nullptr) { + return; + } + if (argslist != nullptr) { + Py_XDECREF(argslist); + } + if (result != nullptr) { + Py_XDECREF(result); + } + PyGILState_Release(state); +} + void TORCH_NPU_API THNPGraph_init(PyObject* module) { // Pybind11 patch notes say "py::module_" is more up-to-date syntax, // but CI linter and some builds prefer "module". @@ -36,6 +60,21 @@ void TORCH_NPU_API THNPGraph_init(PyObject* module) { .def("_graph_task_update_end", [](py::object py_stream) { auto stream = (*py_stream).ptr(); c10_npu::graph_task_update_end(THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_launch_host_func", [](py::object py_stream, py::object py_func, py::object py_data) { + auto func = (*py_func).ptr(); + auto userDataList = (*py_data).ptr(); + auto stream = (*py_stream).ptr(); + c10_npu::PyFuncStruct *data = new(std::nothrow) c10_npu::PyFuncStruct(func, userDataList); + c10_npu::launch_callback(THNPUtils_PyObject_to_NPUStream (stream), LaunchCallFunc, data); + }) + .def("_subscribe_report", [](py::object py_stream) { + auto stream = (*py_stream).ptr(); + c10_npu::subscribe_report(THNPUtils_PyObject_to_NPUStream(stream)); + }) + .def("_unsubscribe_report", [](py::object py_stream) { + auto stream = (*py_stream).ptr(); + c10_npu::unsubscribe_report(THNPUtils_PyObject_to_NPUStream(stream)); }); shared_ptr_class_(torch_N_m, "_NPUGraph") diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 7210d6e431..eedfe1d389 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -116,7 +116,10 @@ __all__ = [ "graph_task_update_begin", "graph_task_update_end", "set_device_limit", - "get_device_limit" + "get_device_limit", + "launch_host_func", + "subscribe_report", + "unsubscribe_report" ] from typing import Tuple, Union, List, cast, Optional @@ -155,6 +158,9 @@ from .graphs import ( graph_task_group_end, graph_task_update_begin, graph_task_update_end, + launch_host_func, + subscribe_report, + unsubscribe_report, ) # init profiler diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py index 7e21ce5ed9..cb5dd78434 100644 --- a/torch_npu/npu/graphs.py +++ b/torch_npu/npu/graphs.py @@ -1,6 +1,7 @@ __all__ = ["is_current_stream_capturing", "graph_pool_handle", "graph_task_group_begin", "graph_task_group_end", "graph_task_update_begin", "graph_task_update_end", - "NPUGraph", "graph", "make_graphed_callables"] + "NPUGraph", "graph", "make_graphed_callables", "launch_host_func", "subscribe_report", + "unsubscribe_report"] import gc import re @@ -27,6 +28,9 @@ if not hasattr(torch_npu._C, "_NPUStreamBase"): torch_npu._C.__dict__["_graph_task_group_end"] = _dummy_type("_graph_task_group_end") torch_npu._C.__dict__["_graph_task_update_begin"] = _dummy_type("_graph_task_update_begin") torch_npu._C.__dict__["_graph_task_update_end"] = _dummy_type("_graph_task_update_end") + torch_npu._C.__dict__["_launch_host_func"] = _dummy_type("_launch_host_func") + torch_npu._C.__dict__["_subscribe_report"] = _dummy_type("_subscribe_report") + torch_npu._C.__dict__["_unsubscribe_report"] = _dummy_type("_unsubscribe_report") from torch_npu._C import ( # noqa: F401 _npu_isCurrentStreamCapturing, @@ -36,6 +40,9 @@ from torch_npu._C import ( # noqa: F401 _graph_task_group_end, _graph_task_update_begin, _graph_task_update_end, + _launch_host_func, + _subscribe_report, + _unsubscribe_report, ) @@ -75,6 +82,18 @@ def graph_task_update_end(stream): _graph_task_update_end(stream) +def launch_host_func(stream, fn, user_data): + _launch_host_func(stream, fn, user_data) + + +def subscribe_report(stream): + _subscribe_report(stream) + + +def unsubscribe_report(stream): + _unsubscribe_report(stream) + + @dataclass class _GraphDispatchRecord: """存储单次操作的完整记录""" -- Gitee