diff --git a/test/npu/test_aclgraph_launch_callback.py b/test/npu/test_aclgraph_launch_callback.py
new file mode 100644
index 0000000000000000000000000000000000000000..f224a9ad797a05f7987d8cbc8b16e9293ae6233c
--- /dev/null
+++ b/test/npu/test_aclgraph_launch_callback.py
@@ -0,0 +1,60 @@
+import unittest
+from itertools import chain
+
+import torch
+from torch import nn
+import torch_npu
+from torch_npu.testing.common_utils import SupportedDevices
+from torch_npu.testing.testcase import TestCase, run_tests
+
+callback_stream = torch.npu.Stream()
+
+
+def callback_add(params):
+    global callback_stream
+    with torch.npu.stream(callback_stream):
+        x, y, result = params
+        result.copy_(x + y)
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super(MyModel, self).__init__()
+        self.result = torch.rand([5, 5]).npu()
+
+    def forward(self, graph, x, y):
+        call_params = [torch.matmul(x, y), torch.matmul(x, y), self.result]
+        for _ in range(10000):
+            torch_npu.npu.launch_host_func(torch.npu.current_stream(), callback_add, call_params)
+        return self.result
+
+
+class TestAclgraphLaunchCallback(TestCase):
+
+    @SupportedDevices(['Ascend910B'])
+    def test_launch_callback(self):
+        torch_npu.npu.set_compile_mode(jit_compile=False)
+        torch_npu.npu.set_device(0)
+
+        self.capture_stream = torch_npu.npu.Stream()
+        self.graph = torch_npu.npu.NPUGraph()
+
+        torch_npu.npu.subscribe_report(self.capture_stream)
+        a = torch.randn([5, 5]).npu()
+        b = torch.randn([5, 5]).npu()
+        model = MyModel()
+        with torch_npu.npu.stream(self.capture_stream):
+            with torch_npu.npu.graph(self.graph, stream=self.capture_stream):
+                self.res = model.forward(self.graph, a, b)
+
+        torch.npu.synchronize()
+        for _ in range(5):
+            self.graph.replay()
+        torch.npu.synchronize()
+        real = torch.matmul(a, b) + torch.matmul(a, b)
+        self.assertEqual(self.res.cpu(), real.cpu())
+        torch_npu.npu.unsubscribe_report(self.capture_stream)
+
+
+if __name__ == '__main__':
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/core/npu/NPUGraph.cpp b/torch_npu/csrc/core/npu/NPUGraph.cpp
index a00448bd1cf547cb2f641eb47ff940c902341610..8ec094fe1f26e8ca69095e53aec9d6c7af00ed46 100644
--- a/torch_npu/csrc/core/npu/NPUGraph.cpp
+++ b/torch_npu/csrc/core/npu/NPUGraph.cpp
@@ -5,7 +5,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 
@@ -47,6 +46,22 @@ void graph_task_update_end(c10_npu::NPUStream stream)
     NPU_CHECK_ERROR(c10_npu::acl::AclmdlRICaptureTaskUpdateEnd(stream));
 }
 
+void launch_callback(c10_npu::NPUStream stream, NPUCallbackFunc func, void *fnData)
+{
+    aclrtCallbackBlockType type = aclrtCallbackBlockType::ACL_CALLBACK_BLOCK;
+    NPU_CHECK_ERROR(c10_npu::acl::AclrtLaunchCallback(func, fnData, type, stream));
+}
+
+void subscribe_report(uint64_t threadId, c10_npu::NPUStream stream)
+{
+    NPU_CHECK_ERROR(c10_npu::acl::AclrtSubscribeReport(threadId, stream));
+}
+
+void unsubscribe_report(uint64_t threadId, c10_npu::NPUStream stream)
+{
+    NPU_CHECK_ERROR(c10_npu::acl::AclrtUnSubscribeReport(threadId, stream));
+}
+
 /**
  * Note [CUDA Graph Wrapper Class]
  * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/torch_npu/csrc/core/npu/NPUGraph.h b/torch_npu/csrc/core/npu/NPUGraph.h
index 442ae335ccae15506f4352ef5c204e185917672b..9fa56d244755a3321bcbed13c499f13132a15e33 100644
--- a/torch_npu/csrc/core/npu/NPUGraph.h
+++ b/torch_npu/csrc/core/npu/NPUGraph.h
@@ -4,6 +4,8 @@
 #include
 #include
+#include "third_party/acl/inc/acl/acl_base.h"
+#include "third_party/acl/inc/acl/acl_rt.h"
 #include "torch_npu/csrc/core/npu/NPUGraphsUtils.h"
 #include "torch_npu/csrc/core/npu/NPUMacros.h"
 #include "torch_npu/csrc/core/npu/NPUStream.h"
 
@@ -18,11 +20,18 @@ struct TORCH_NPU_API NPUTaskGroupHandle {
     aclrtTaskGrp task_group;
 };
 
+typedef TORCH_NPU_API void (*NPUCallbackFunc)(void *fnData);
+
 TORCH_NPU_API void graph_task_group_begin(c10_npu::NPUStream stream);
 TORCH_NPU_API NPUTaskGroupHandle graph_task_group_end(c10_npu::NPUStream stream);
 TORCH_NPU_API void graph_task_update_begin(c10_npu::NPUStream stream, NPUTaskGroupHandle handle);
 TORCH_NPU_API void graph_task_update_end(c10_npu::NPUStream stream);
+TORCH_NPU_API void launch_callback(c10_npu::NPUStream stream, NPUCallbackFunc func, void *fnData);
+TORCH_NPU_API void subscribe_report(uint64_t threadId, c10_npu::NPUStream stream);
+TORCH_NPU_API void unsubscribe_report(uint64_t threadId, c10_npu::NPUStream stream);
+
+
 
 struct TORCH_NPU_API NPUGraph {
     NPUGraph();
     ~NPUGraph();
@@ -75,6 +84,7 @@ protected:
     // not CUDA itself. We can straightforwardly modify CUDAGraph to support multi-device
     // captures if needed.
     int capture_dev_;
+
 };
 
 } // namespace c10_npu
diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.cpp b/torch_npu/csrc/core/npu/interface/AclInterface.cpp
index f277456dae414149dcf9103258e6392c46883e2d..37c45f17e7ee9d1f251d338ac7daf7d7eb58ec87 100644
--- a/torch_npu/csrc/core/npu/interface/AclInterface.cpp
+++ b/torch_npu/csrc/core/npu/interface/AclInterface.cpp
@@ -94,7 +94,9 @@ LOAD_FUNCTION(aclrtGetDeviceResLimit)
 LOAD_FUNCTION(aclrtSetDeviceResLimit)
 LOAD_FUNCTION(aclrtResetDeviceResLimit)
 LOAD_FUNCTION(aclrtStreamGetId)
-
+LOAD_FUNCTION(aclrtLaunchCallback)
+LOAD_FUNCTION(aclrtSubscribeReport)
+LOAD_FUNCTION(aclrtUnSubscribeReport)
 
 aclprofStepInfoPtr init_stepinfo() {
     typedef aclprofStepInfoPtr(*npdInitFunc)();
@@ -1085,5 +1087,41 @@ aclError AclrtStreamGetId(aclrtStream stream, int32_t* stream_id)
     return func(stream, stream_id);
 }
 
+aclError AclrtLaunchCallback(aclrtCallback fn, void *userData, aclrtCallbackBlockType blockType, aclrtStream stream)
+{
+    typedef aclError (*AclrtLaunchCallback)(aclrtCallback, void *, aclrtCallbackBlockType, aclrtStream);
+    static AclrtLaunchCallback func = nullptr;
+    if (func == nullptr) {
+        func = (AclrtLaunchCallback) GET_FUNC(aclrtLaunchCallback);
+    }
+
+    TORCH_CHECK(func, "Failed to find function aclrtLaunchCallback", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(fn, userData, blockType, stream);
+}
+
+aclError AclrtSubscribeReport(uint64_t threadId, aclrtStream stream)
+{
+    typedef aclError (*AclrtSubscribeReport)(uint64_t, aclrtStream);
+    static AclrtSubscribeReport func = nullptr;
+    if (func == nullptr) {
+        func = (AclrtSubscribeReport) GET_FUNC(aclrtSubscribeReport);
+    }
+
+    TORCH_CHECK(func, "Failed to find function aclrtSubscribeReport", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(threadId, stream);
+}
+
+aclError AclrtUnSubscribeReport(uint64_t threadId, aclrtStream stream)
+{
+    typedef aclError (*AclrtUnSubscribeReport)(uint64_t, aclrtStream);
+    static AclrtUnSubscribeReport func = nullptr;
+    if (func == nullptr) {
+        func = (AclrtUnSubscribeReport) GET_FUNC(aclrtUnSubscribeReport);
+    }
+
+    TORCH_CHECK(func, "Failed to find function aclrtUnSubscribeReport", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(threadId, stream);
+}
+
 } // namespace acl
 } // namespace c10
diff --git a/torch_npu/csrc/core/npu/interface/AclInterface.h b/torch_npu/csrc/core/npu/interface/AclInterface.h
index e917e0ab973e584de074f5fad4e8284b4af9ff2d..cde7e34870e27ed8d0bc744c09ccdc7327a2b05b 100644
--- a/torch_npu/csrc/core/npu/interface/AclInterface.h
+++ b/torch_npu/csrc/core/npu/interface/AclInterface.h
@@ -259,5 +259,10 @@ aclError AclrtResetDeviceResLimit(int32_t deviceId);
 
 aclError AclrtStreamGetId(aclrtStream stream, int32_t* stream_id);
 
+aclError AclrtLaunchCallback(aclrtCallback fn, void *userData, aclrtCallbackBlockType blockType, aclrtStream stream);
+
+aclError AclrtSubscribeReport(uint64_t threadId, aclrtStream stream);
+
+aclError AclrtUnSubscribeReport(uint64_t threadId, aclrtStream stream);
 } // namespace acl
 } // namespace c10_npu
diff --git a/torch_npu/csrc/npu/Graph.cpp b/torch_npu/csrc/npu/Graph.cpp
index c8d30cfa448b07e00d7671ff5e6aa7169686ee60..754fbdf8d4b445c1ddae7002b1c903a9fc5d838c 100644
--- a/torch_npu/csrc/npu/Graph.cpp
+++ b/torch_npu/csrc/npu/Graph.cpp
@@ -1,16 +1,53 @@
-#include
-
-#include
-
-#include
-#include
+#include
+#include
 
 #include "torch_npu/csrc/core/npu/NPUGraph.h"
 #include "torch_npu/csrc/core/npu/NPUGraphsUtils.h"
 #include "torch_npu/csrc/npu/Stream.h"
+#include "torch_npu/csrc/npu/Graph.h"
 
 template <typename T>
 using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
 
+static std::map<aclrtStream, std::vector<PyFuncStruct *>> callbacks = {};
+constexpr int processReportTimeout = 100;
+static ThreadArgs* threadArgs = nullptr;
+static uint64_t threadId = -1;
+
+void *process_callback(void *arg)
+{
+    ThreadArgs* args = static_cast<ThreadArgs *>(arg);
+    auto ret = aclrtSetCurrentContext(args->context);
+    while (!args->exitFlag) {
+        (void)aclrtProcessReport(processReportTimeout);
+    }
+    delete args;
+    args = nullptr;
+    return nullptr;
+}
+
+void LaunchCallFunc(void *userData)
+{
+    PyGILState_STATE state = PyGILState_Ensure();
+    if (userData == nullptr) {
+        return;
+    }
+    auto data = (PyFuncStruct *)(userData);
+    PyObject *argslist = Py_BuildValue("(O)", data->pyFuncArgs);
+    if (argslist == nullptr) {
+        return;
+    }
+    PyObject *result = PyObject_CallObject(data->pyFunc, argslist);
+    if (result == nullptr) {
+        return;
+    }
+    if (argslist != nullptr) {
+        Py_XDECREF(argslist);
+    }
+    if (result != nullptr) {
+        Py_XDECREF(result);
+    }
+    PyGILState_Release(state);
+}
 
 void TORCH_NPU_API THNPGraph_init(PyObject* module) {
     // Pybind11 patch notes say "py::module_" is more up-to-date syntax,
@@ -36,6 +73,42 @@ void TORCH_NPU_API THNPGraph_init(PyObject* module) {
         .def("_graph_task_update_end", [](py::object py_stream) {
             auto stream = (*py_stream).ptr();
             c10_npu::graph_task_update_end(THNPUtils_PyObject_to_NPUStream(stream));
+        })
+        .def("_launch_host_func", [](py::object py_stream, py::object py_func, py::object py_data) {
+            auto func = (*py_func).ptr();
+            auto userDataList = (*py_data).ptr();
+            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
+            PyFuncStruct *data = new(std::nothrow) PyFuncStruct(func, userDataList);
+            c10_npu::launch_callback(stream, LaunchCallFunc, data);
+            callbacks[stream].emplace_back(data);
+        })
+        .def("_subscribe_report", [](py::object py_stream) {
+            auto stream = (*py_stream).ptr();
+            aclrtContext context = aclrtContext();
+            NPU_CHECK_ERROR(aclrtGetCurrentContext(&context));
+            if ((threadArgs == nullptr) || (threadId == -1)) {
+                threadArgs = new ThreadArgs(context, false);
+                pthread_create(&threadId, nullptr, process_callback, threadArgs);
+            }
+            c10_npu::subscribe_report(threadId, THNPUtils_PyObject_to_NPUStream(stream));
+        })
+        .def("_unsubscribe_report", [](py::object py_stream) {
+            auto stream = THNPUtils_PyObject_to_NPUStream((*py_stream).ptr());
+            auto it = callbacks.find(stream);
+            if (it != callbacks.end()) {
+                std::vector<PyFuncStruct *>& funcs = it->second;
+                for (PyFuncStruct* func : funcs) {
+                    delete func;
+                    func = nullptr;
+                }
+                funcs.clear();
+                callbacks.erase(it);
+            }
+            if (callbacks.empty()) {
+                threadArgs->exitFlag = true;
+                threadId = -1;
+            }
+            c10_npu::unsubscribe_report(threadId, stream);
         });
 
     shared_ptr_class_<c10_npu::NPUGraph>(torch_N_m, "_NPUGraph")
diff --git a/torch_npu/csrc/npu/Graph.h b/torch_npu/csrc/npu/Graph.h
new file mode 100644
index 0000000000000000000000000000000000000000..eab1aea354b86476923dfd8badbb73f6d4e33f81
--- /dev/null
+++ b/torch_npu/csrc/npu/Graph.h
@@ -0,0 +1,33 @@
+#include
+#include
+#include
+#include
+
+#include "third_party/acl/inc/acl/acl_base.h"
+#include "third_party/acl/inc/acl/acl_rt.h"
+
+struct PyFuncStruct {
+    PyFuncStruct(PyObject *pyFunc, PyObject *pyFuncArgs)
+        : pyFunc(pyFunc), pyFuncArgs(pyFuncArgs)
+    {
+        Py_XINCREF(pyFunc);
+        Py_XINCREF(pyFuncArgs);
+    }
+
+    ~PyFuncStruct()
+    {
+        Py_XDECREF(pyFunc);
+        Py_XDECREF(pyFuncArgs);
+    }
+
+    PyObject* pyFunc = nullptr;
+    PyObject* pyFuncArgs = nullptr;
+};
+
+struct ThreadArgs {
+    ThreadArgs(aclrtContext context, bool exitFlag)
+        : context(context), exitFlag(exitFlag) {}
+
+    aclrtContext context;
+    bool exitFlag;
+};
diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py
index 182859d8a5aefc290a702d41fdd36cc33631c72c..e24d149084ebfc1ecd9e51601353e36a241e6a59 100644
--- a/torch_npu/npu/__init__.py
+++ b/torch_npu/npu/__init__.py
@@ -116,7 +116,10 @@ __all__ = [
     "graph_task_update_begin",
     "graph_task_update_end",
     "set_device_limit",
-    "get_device_limit"
+    "get_device_limit",
+    "launch_host_func",
+    "subscribe_report",
+    "unsubscribe_report"
 ]
 
 from typing import Tuple, Union, List, cast, Optional
@@ -156,6 +159,9 @@ from .graphs import (
     graph_task_group_end,
     graph_task_update_begin,
     graph_task_update_end,
+    launch_host_func,
+    subscribe_report,
+    unsubscribe_report,
 )
 
 # init profiler
diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py
index 7e21ce5ed9a78512d66e1bd915eb834732aa00fc..cb5dd784348b13f9db7b4e006e22d7e3287b7035 100644
--- a/torch_npu/npu/graphs.py
+++ b/torch_npu/npu/graphs.py
@@ -1,6 +1,7 @@
 __all__ = ["is_current_stream_capturing", "graph_pool_handle", "graph_task_group_begin",
            "graph_task_group_end", "graph_task_update_begin", "graph_task_update_end",
-           "NPUGraph", "graph", "make_graphed_callables"]
+           "NPUGraph", "graph", "make_graphed_callables", "launch_host_func", "subscribe_report",
+           "unsubscribe_report"]
 
 import gc
 import re
@@ -27,6 +28,9 @@ if not hasattr(torch_npu._C, "_NPUStreamBase"):
     torch_npu._C.__dict__["_graph_task_group_end"] = _dummy_type("_graph_task_group_end")
     torch_npu._C.__dict__["_graph_task_update_begin"] = _dummy_type("_graph_task_update_begin")
     torch_npu._C.__dict__["_graph_task_update_end"] = _dummy_type("_graph_task_update_end")
+    torch_npu._C.__dict__["_launch_host_func"] = _dummy_type("_launch_host_func")
+    torch_npu._C.__dict__["_subscribe_report"] = _dummy_type("_subscribe_report")
+    torch_npu._C.__dict__["_unsubscribe_report"] = _dummy_type("_unsubscribe_report")
 
 from torch_npu._C import (  # noqa: F401
     _npu_isCurrentStreamCapturing,
@@ -36,6 +40,9 @@ from torch_npu._C import (  # noqa: F401
     _graph_task_group_end,
     _graph_task_update_begin,
     _graph_task_update_end,
+    _launch_host_func,
+    _subscribe_report,
+    _unsubscribe_report,
 )
 
 
@@ -75,6 +82,18 @@ def graph_task_update_end(stream):
     _graph_task_update_end(stream)
 
 
+def launch_host_func(stream, fn, user_data):
+    _launch_host_func(stream, fn, user_data)
+
+
+def subscribe_report(stream):
+    _subscribe_report(stream)
+
+
+def unsubscribe_report(stream):
+    _unsubscribe_report(stream)
+
+
 @dataclass
 class _GraphDispatchRecord:
     """Stores the complete record of a single operation."""
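Usage note (reviewer sketch, not part of the patch): the three new Python entry points wrap the ACL host-callback flow one-to-one. subscribe_report(stream) spawns a host thread that polls aclrtProcessReport and subscribes the stream to it, launch_host_func(stream, fn, user_data) enqueues fn as a host callback on the stream (fn receives the user_data object as its single argument), and unsubscribe_report(stream) frees the callbacks registered for that stream and detaches it. Outside of graph capture the same flow would look roughly like the following; tensor shapes and the callback name are illustrative only.

    import torch
    import torch_npu

    stream = torch_npu.npu.Stream()
    out = torch.zeros([5, 5]).npu()

    def host_cb(params):
        # Runs on the host report thread; params is the user_data list.
        x, y, result = params
        result.copy_(x + y)

    torch_npu.npu.subscribe_report(stream)
    with torch_npu.npu.stream(stream):
        a = torch.randn([5, 5]).npu()
        b = torch.randn([5, 5]).npu()
        torch_npu.npu.launch_host_func(stream, host_cb, [a, b, out])
    torch.npu.synchronize()
    torch_npu.npu.unsubscribe_report(stream)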