From 49f431ad52538f446f036b270c7691918176e501 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E5=AE=81?= Date: Tue, 19 Aug 2025 11:06:00 +0800 Subject: [PATCH] add module afd --- CMakeLists.txt | 4 +- test/_afd/test_schedule_contest.py | 124 +++++++ torch_npu/__init__.py | 1 + torch_npu/_afd/__init__.py | 53 +++ torch_npu/_afd/_schedule_context.py | 139 +++++++ torch_npu/csrc/InitNpuBindings.cpp | 2 + torch_npu/csrc/afd/CMakeLists.txt | 6 + torch_npu/csrc/afd/Init.cpp | 44 +++ torch_npu/csrc/afd/Init.h | 12 + torch_npu/csrc/afd/ScheduleContext.cpp | 493 +++++++++++++++++++++++++ torch_npu/csrc/afd/ScheduleContext.h | 122 ++++++ 11 files changed, 999 insertions(+), 1 deletion(-) create mode 100644 test/_afd/test_schedule_contest.py create mode 100644 torch_npu/_afd/__init__.py create mode 100644 torch_npu/_afd/_schedule_context.py create mode 100644 torch_npu/csrc/afd/CMakeLists.txt create mode 100644 torch_npu/csrc/afd/Init.cpp create mode 100644 torch_npu/csrc/afd/Init.h create mode 100644 torch_npu/csrc/afd/ScheduleContext.cpp create mode 100644 torch_npu/csrc/afd/ScheduleContext.h diff --git a/CMakeLists.txt b/CMakeLists.txt index df717371804..49aece97906 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -241,6 +241,7 @@ if (NOT DEFINED BUILD_LIBTORCH) set(IPC_SRCS) set(UTILS_SRCS) set(SAN_SRCS) + set(AFD_SRCS) endif() if (DEFINED BUILD_LIBTORCH) @@ -262,6 +263,7 @@ if (NOT DEFINED BUILD_LIBTORCH) add_subdirectory(${TORCHNPU_ROOT}/ipc) add_subdirectory(${TORCHNPU_ROOT}/utils) add_subdirectory(${TORCHNPU_ROOT}/sanitizer) + add_subdirectory(${TORCHNPU_ROOT}/afd) endif() if (DEFINED BUILD_LIBTORCH) @@ -284,7 +286,7 @@ if (DEFINED BUILD_LIBTORCH) set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) else() # Compile code with pybind11 - set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS} ${AFD_SRCS}) endif() add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS}) diff --git a/test/_afd/test_schedule_contest.py b/test/_afd/test_schedule_contest.py new file mode 100644 index 00000000000..fa03ce57ca6 --- /dev/null +++ b/test/_afd/test_schedule_contest.py @@ -0,0 +1,124 @@ +import torch +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestScheduleContext(TestCase): + def setUp(self): + npu_device = torch._C._get_privateuse1_backend_name() + self.window_tensor = torch.ones([1 * 1024 * 1024 * 1024], dtype=torch.int8).to( + npu_device + ) + self.default_params = { + "schedule_mode": 0, + "session_num": 288, + "micro_batch_num": 3, + "micro_batch_size": 30, + "selected_expert_num": 9, + "expert_num": 288, + "attn_to_ffn_token_size": 7168 + 512, + "ffn_to_attn_token_size": 7168 * 2, + "attention_window": self.window_tensor.data_ptr(), + "attention_window_size": 1 * 1024 * 1024 * 1024, + "ffn_window": self.window_tensor.data_ptr(), + "ffn_window_size": 1 * 1024 * 1024 * 1024, + } + + def tearDown(self): + return super().tearDown() + + def test_init_with_invalid_params(self): + """测试参数校验""" + invalid_params = [ + ({"session_num": 0}, "session_num=0 should fail"), + 
({"micro_batch_num": 0}, "micro_batch_num=0 should fail"), + ( + {"session_num": 1 << 31, "micro_batch_num": 1 << 31}, + "micro_batch_num mul overflow", + ), + ({"micro_batch_size": 0}, "micro_batch_size=0 should fail"), + ( + {"micro_batch_num": 1 << 31, "micro_batch_size": 1 << 31}, + "micro_batch_size mul overflow", + ), + ( + { + "schedule_mode": 1, + "micro_batch_num": 1 << 31, + "micro_batch_size": 1 << 31, + }, + "attention micro_batch_size mul overflow", + ), + ({"selected_expert_num": 0}, "selected_expert_num=0 should fail"), + ( + {"micro_batch_size": 1 << 31, "selected_expert_num": 1 << 31}, + "selected_expert_num mul overflow", + ), + ( + { + "schedule_mode": 1, + "micro_batch_size": 1 << 31, + "selected_expert_num": 1 << 31, + }, + "attention selected_expert_num mul overflow", + ), + ({"ffn_window": 0}, "ffn_window=0 should fail"), + ({"ffn_window_size": 0}, "ffn_window_size can not be 0"), + ({"ffn_window_size": 511}, "ffn_window_size is not enough should fail"), + ({"schedule_mode": 1, "attention_window": 0}, "ffn_window is null"), + ( + {"schedule_mode": 1, "attention_window_size": 0}, + "attention_window_size can not be 0", + ), + ( + {"schedule_mode": 1, "attention_window_size": 511}, + "attention_window_size is not enough should fail", + ), + ({"schedule_mode": 2}, "schedule_mode 2 is not supportted"), + ( + {"attn_to_ffn_token_size": 1023}, + "attn_to_ffn_token_size must be aligned by 512", + ), + ( + {"ffn_to_attn_token_size": 400}, + "ffn_to_attn_token_size must be aligned by 512", + ), + ] + for params, msg in invalid_params: + with self.subTest(msg=msg): + test_params = self.default_params.copy() + test_params.update(params) + with self.assertRaises(RuntimeError): + torch_npu._afd.create_schedule_context_holder(**test_params) + + def test_schedule_ffn(self): + """测试用有效参数初始化""" + holder = torch_npu._afd.create_schedule_context_holder( + **self.default_params + ) + self.assertIsInstance(holder, torch_npu._afd.ScheduleContextHolder) + + # 获取tensor + tensor = holder.get_schedule_context_tensor() + self.assertIsInstance(tensor, torch.Tensor) + + context_info = holder.get_schedule_context_info() + self.assertIn("ffn info:", context_info) + + holder.stop_schedule() + + def test_schedule_attn(self): + """测试用有效参数初始化""" + test_params = self.default_params.copy() + test_params["schedule_mode"] = 1 + holder = torch_npu._afd.create_schedule_context_holder(**test_params) + self.assertIsInstance(holder, torch_npu._afd.ScheduleContextHolder) + + # 获取tensor + tensor = holder.get_schedule_context_tensor() + self.assertIsInstance(tensor, torch.Tensor) + holder.stop_schedule() + + +if __name__ == "__main__": + run_tests() diff --git a/torch_npu/__init__.py b/torch_npu/__init__.py index d88b47cbc88..2681bea8dd1 100644 --- a/torch_npu/__init__.py +++ b/torch_npu/__init__.py @@ -65,6 +65,7 @@ import torch_npu.optim import torch_npu.dynamo import torch_npu._C import torch_npu._logging +import torch_npu._afd from torch_npu import profiler from torch_npu.npu.amp.sharded_grad_scaler import _ShardedGradScaler from torch_npu.contrib.function import npu_functional diff --git a/torch_npu/_afd/__init__.py b/torch_npu/_afd/__init__.py new file mode 100644 index 00000000000..428252bd0ba --- /dev/null +++ b/torch_npu/_afd/__init__.py @@ -0,0 +1,53 @@ +__all__ = ["create_schedule_context_holder"] + +from ._schedule_context import _create_schedule_context_holder, ScheduleContextHolder + + +def create_schedule_context_holder( + schedule_mode: int, + session_num: int, + micro_batch_num: int, + 
micro_batch_size: int, + selected_expert_num: int, + expert_num: int, + attn_to_ffn_token_size: int, + ffn_to_attn_token_size: int, + ffn_window: int = 0, + ffn_window_size: int = 0, + attention_window: int = 0, + attention_window_size: int = 0, +) -> ScheduleContextHolder: + """ + A holder class for managing scheduling context in distributed inference. + + Args: + schedule_mode: Scheduling mode identifier, 0:schedule ffn, 1:shcedule attention + session_num: Number of sessions + micro_batch_num: Number of micro batches + micro_batch_size: micro batch size + selected_expert_num: selected experts num + expert_num: Total number of experts + attn_to_ffn_token_size: Token size from attention to FFN + ffn_to_attn_token_size: Token size from FFN to attention + ffn_window: FFN window addr (default: 0), must assign value when schedule_mode=0 + ffn_window_size: FFN window size (default: 0), must assign value when schedule_mode=0 + attention_window: Attention window addr (default: 0), must assign value when schedule_mode=1 + attention_window_size: Attention window size (default: 0), must assign value when schedule_mode=1 + + Returns: + ScheduleContextHolder: the schedule context holder + """ + return _create_schedule_context_holder( + schedule_mode, + session_num, + micro_batch_num, + micro_batch_size, + selected_expert_num, + expert_num, + attn_to_ffn_token_size, + ffn_to_attn_token_size, + ffn_window, + ffn_window_size, + attention_window, + attention_window_size, + ) diff --git a/torch_npu/_afd/_schedule_context.py b/torch_npu/_afd/_schedule_context.py new file mode 100644 index 00000000000..897520c2e0b --- /dev/null +++ b/torch_npu/_afd/_schedule_context.py @@ -0,0 +1,139 @@ +__all__ = ["ScheduleContextHolder"] + +import torch +import torch_npu + + +class ScheduleContextHolder: + """ + A holder class for managing scheduling context in distributed inference. 
+
+    Args:
+        schedule_mode: Scheduling mode identifier, 0: schedule FFN, 1: schedule attention
+        session_num: Number of sessions
+        micro_batch_num: Number of micro batches
+        micro_batch_size: Micro batch size
+        selected_expert_num: Number of selected experts
+        expert_num: Total number of experts
+        attn_to_ffn_token_size: Token size from attention to FFN
+        ffn_to_attn_token_size: Token size from FFN to attention
+        ffn_window: FFN window address (default: 0); must be set when schedule_mode=0
+        ffn_window_size: FFN window size (default: 0); must be set when schedule_mode=0
+        attention_window: Attention window address (default: 0); must be set when schedule_mode=1
+        attention_window_size: Attention window size (default: 0); must be set when schedule_mode=1
+    """
+    # Records whether the afd module has been initialized
+    _afd_initialized = False
+
+    def __init__(
+        self,
+        schedule_mode: int,
+        session_num: int,
+        micro_batch_num: int,
+        micro_batch_size: int,
+        selected_expert_num: int,
+        expert_num: int,
+        attn_to_ffn_token_size: int,
+        ffn_to_attn_token_size: int,
+        ffn_window: int = 0,
+        ffn_window_size: int = 0,
+        attention_window: int = 0,
+        attention_window_size: int = 0,
+    ) -> None:
+        if not ScheduleContextHolder._afd_initialized:
+            ScheduleContextHolder._init_afd_module()
+
+        self._impl = torch_npu._C._afd.ScheduleContextHolder(schedule_mode,
+                                                             session_num,
+                                                             micro_batch_num,
+                                                             micro_batch_size,
+                                                             selected_expert_num,
+                                                             expert_num,
+                                                             attn_to_ffn_token_size,
+                                                             ffn_to_attn_token_size,
+                                                             ffn_window,
+                                                             ffn_window_size,
+                                                             attention_window,
+                                                             attention_window_size)
+
+    @classmethod
+    def _init_afd_module(cls):
+        if not hasattr(torch_npu._C, "_afd_init"):
+            raise RuntimeError("Failed to init _afd module as _afd_init is not found in torch_npu._C")
+
+        if not torch_npu._C._afd_init():
+            raise RuntimeError("Failed to init _afd module")
+
+        cls._afd_initialized = True
+
+    def init(self) -> None:
+        ret = self._impl.init()
+        if ret != 0:
+            raise RuntimeError(f'ScheduleContextHolder init returned {ret}')
+
+    def get_schedule_context_tensor(self) -> torch.Tensor:
+        """
+        Get the scheduling context tensor.
+
+        Returns:
+            torch.Tensor: The context tensor
+
+        Raises:
+            RuntimeError: If tensor retrieval fails
+        """
+        ret, tensor = self._impl.get_context_tensor()
+        if ret != 0:
+            raise RuntimeError(f'get_context_tensor returned {ret}')
+        return tensor
+
+    def stop_schedule(self) -> None:
+        """
+        Stop scheduling.
+
+        Raises:
+            RuntimeError: If setting the stop flag on the context fails
+        """
+        ret = self._impl.stop_schedule()
+        if ret != 0:
+            raise RuntimeError(f'stop returned {ret}')
+
+    def get_schedule_context_info(self) -> str:
+        """
+        Get schedule context info.
+ + Returns: + str: the schedule context string info + """ + return self._impl.get_schedule_context_info() + + +def _create_schedule_context_holder( + schedule_mode: int, + session_num: int, + micro_batch_num: int, + micro_batch_size: int, + selected_expert_num: int, + expert_num: int, + attn_to_ffn_token_size: int, + ffn_to_attn_token_size: int, + ffn_window: int = 0, + ffn_window_size: int = 0, + attention_window: int = 0, + attention_window_size: int = 0, +) -> ScheduleContextHolder: + holder = ScheduleContextHolder( + schedule_mode, + session_num, + micro_batch_num, + micro_batch_size, + selected_expert_num, + expert_num, + attn_to_ffn_token_size, + ffn_to_attn_token_size, + ffn_window, + ffn_window_size, + attention_window, + attention_window_size, + ) + holder.init() + return holder diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index b30b086e5ce..e18a6607d74 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -12,6 +12,7 @@ #include "torch_npu/csrc/core/npu/npu_log.h" #include "torch_npu/csrc/core/npu/CachingHostAllocator.h" #include "torch_npu/csrc/distributed/Init.h" +#include "torch_npu/csrc/afd/Init.h" #include "torch_npu/csrc/profiler/init.h" #include "torch_npu/csrc/flopcount/Init.h" #include "torch_npu/csrc/logging/Init.h" @@ -173,6 +174,7 @@ PyObject* initModule() AddPyMethodDefs(methods, torch_npu::logging::logging_functions()); AddPyMethodDefs(methods, torch_npu::reductions::reductions_functions()); AddPyMethodDefs(methods, c10_npu::custom_dtype_functions()); + AddPyMethodDefs(methods, torch_npu::afd::python_functions()); static struct PyModuleDef torchnpu_module = { PyModuleDef_HEAD_INIT, "torch_npu._C", diff --git a/torch_npu/csrc/afd/CMakeLists.txt b/torch_npu/csrc/afd/CMakeLists.txt new file mode 100644 index 00000000000..5a120afbb67 --- /dev/null +++ b/torch_npu/csrc/afd/CMakeLists.txt @@ -0,0 +1,6 @@ +FILE(GLOB _AFD_SRCS *.cpp) + +LIST(APPEND AFD_SRCS ${_AFD_SRCS}) + +# Pass to parent +set(AFD_SRCS ${AFD_SRCS} PARENT_SCOPE) \ No newline at end of file diff --git a/torch_npu/csrc/afd/Init.cpp b/torch_npu/csrc/afd/Init.cpp new file mode 100644 index 00000000000..ca3b87d4138 --- /dev/null +++ b/torch_npu/csrc/afd/Init.cpp @@ -0,0 +1,44 @@ +#include "torch_npu/csrc/afd/Init.h" +#include +#include +#include +#include +#include "torch_npu/csrc/afd/ScheduleContext.h" + +namespace torch_npu { +namespace afd { + +PyObject *afd_init(PyObject * _unused, PyObject * noargs) +{ + auto torch_npu_C_module = THPObjectPtr(PyImport_ImportModule("torch_npu._C")); + if (!torch_npu_C_module) { + throw python_error(); + } + auto torch_npu_C_m = py::handle(torch_npu_C_module).cast(); + + auto m = torch_npu_C_m.def_submodule("_afd", "Attention-FFN Disaggregation"); + auto module = py::handle(m).cast(); + + py::class_(module, "ScheduleContextHolder") + .def(py::init()) + .def("init", &ScheduleContextHolder::Init) + .def("get_context_tensor", &ScheduleContextHolder::GetContextTensor) + .def("stop_schedule", &ScheduleContextHolder::StopSchedule) + .def("get_schedule_context_info", &ScheduleContextHolder::GetScheduleContextInfo); + Py_RETURN_TRUE; +} + +// methods on torch._C +PyMethodDef methods[] = { + {"_afd_init", afd_init, METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr} +}; + +PyMethodDef *python_functions() +{ + return methods; +} + +} // namespace afd +} // namespace torch_npu diff --git a/torch_npu/csrc/afd/Init.h b/torch_npu/csrc/afd/Init.h new file mode 100644 index 00000000000..d5973fd78f0 --- 
/dev/null
+++ b/torch_npu/csrc/afd/Init.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <torch/csrc/python_headers.h>
+#include "torch_npu/csrc/core/npu/NPUMacros.h"
+
+namespace torch_npu {
+namespace afd {
+
+TORCH_NPU_API PyMethodDef *python_functions();
+
+} // namespace afd
+} // namespace torch_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/afd/ScheduleContext.cpp b/torch_npu/csrc/afd/ScheduleContext.cpp
new file mode 100644
index 00000000000..f6e245b3462
--- /dev/null
+++ b/torch_npu/csrc/afd/ScheduleContext.cpp
@@ -0,0 +1,493 @@
+#include "ScheduleContext.h"
+#include <limits>
+#include <memory>
+#include <sstream>
+#include "torch_npu/csrc/core/npu/npu_log.h"
+
+namespace torch_npu {
+namespace afd {
+namespace {
+constexpr uint32_t kSuccess = 0;
+constexpr uint32_t kFailure = 1;
+
+constexpr int32_t kRunFlagStop = 0;
+constexpr int32_t kRunFlagRunning = 1;
+
+constexpr int32_t kScheduleModeFfn = 0;
+constexpr int32_t kScheduleModeAttention = 1;
+constexpr uint64_t kBufAlignSize = 512;
+
+inline uint64_t AlignUp(uint64_t num, uint64_t align)
+{
+    return ((num + align - 1) / align) * align;
+}
+
+template <typename T>
+class IntegerChecker {
+public:
+    template <typename T1>
+    static bool Compat(const T1 v)
+    {
+        static_assert(((sizeof(T) <= sizeof(uint64_t)) && (sizeof(T1) <= sizeof(uint64_t))),
+                      "IntegerChecker can only check integers less than 64 bits");
+        if (v >= static_cast<T1>(0)) {
+            return static_cast<uint64_t>(v) <= static_cast<uint64_t>(std::numeric_limits<T>::max());
+        }
+        return static_cast<int64_t>(v) >= static_cast<int64_t>(std::numeric_limits<T>::min());
+    }
+};
+
+template <typename TLhs, typename TRhs, typename TRet>
+bool MulOverflow(TLhs lhs, TRhs rhs, TRet &ret)
+{
+#if __GNUC__ >= 5
+    return __builtin_mul_overflow(lhs, rhs, &ret);
+#else
+    if ((!IntegerChecker<TRet>::Compat(lhs)) || (!IntegerChecker<TRet>::Compat(rhs))) {
+        return true;
+    }
+    if ((lhs == 0) || (rhs == 0)) {
+        ret = 0;
+        return false;
+    }
+    TRet reminder = std::numeric_limits<TRet>::max() / static_cast<TRet>(rhs);
+    const TRet lhs_ret_type = static_cast<TRet>(lhs);
+    if (lhs_ret_type < 0) {
+        if (reminder > 0) {
+            reminder *= static_cast<TRet>(-1);
+        }
+        if (lhs_ret_type < reminder) {
+            return true;
+        }
+    } else {
+        if (reminder < 0) {
+            reminder *= static_cast<TRet>(-1);
+        }
+        if (lhs_ret_type > reminder) {
+            return true;
+        }
+    }
+    ret = static_cast<TRet>(lhs) * static_cast<TRet>(rhs);
+    return false;
+#endif
+}
+
+template <typename TLhs, typename TRhs, typename TRet>
+bool AddOverflow(TLhs lhs, TRhs rhs, TRet &ret)
+{
+#if __GNUC__ >= 5
+    return __builtin_add_overflow(lhs, rhs, &ret);
+#else
+    if ((!IntegerChecker<TRet>::Compat(lhs)) || (!IntegerChecker<TRet>::Compat(rhs))) {
+        return true;
+    }
+    if (rhs >= 0) {
+        if (static_cast<TRet>(lhs) > std::numeric_limits<TRet>::max() - static_cast<TRet>(rhs)) {
+            return true;
+        }
+    } else {
+        if (static_cast<TRet>(lhs) < std::numeric_limits<TRet>::min() - static_cast<TRet>(rhs)) {
+            return true;
+        }
+    }
+    ret = static_cast<TRet>(lhs) + static_cast<TRet>(rhs);
+    return false;
+#endif
+}
+} // namespace
+ScheduleContextHolder::ScheduleContextHolder(int32_t schedule_mode, uint32_t session_num, uint32_t micro_batch_num,
+                                             uint32_t micro_batch_size, uint32_t selected_expert_num,
+                                             uint32_t expert_num, uint32_t attn_to_ffn_token_size,
+                                             uint32_t ffn_to_attn_token_size, uint64_t ffn_window,
+                                             uint64_t ffn_window_size, uint64_t attention_window,
+                                             uint64_t attention_window_size)
+{
+    context_.common.schedule_mode = schedule_mode;
+    context_.common.session_num = session_num;
+    context_.common.micro_batch_num = micro_batch_num;
+    context_.common.micro_batch_size = micro_batch_size;
+    context_.common.selected_expert_num = selected_expert_num;
+    context_.common.expert_num = expert_num;
+    context_.common.attn_to_ffn_token_size = attn_to_ffn_token_size;
+    context_.common.ffn_to_attn_token_size = ffn_to_attn_token_size;
+    ffn_window_ = ffn_window;
+    ffn_window_size_ = ffn_window_size;
+    attention_window_ = attention_window;
+    attention_window_size_ = attention_window_size;
+}
+
+uint64_t ScheduleContextHolder::CalcFfnTokenInfoSize() const
+{
+    uint64_t token_info_size = sizeof(int32_t) * static_cast<uint64_t>(context_.common.selected_expert_num);
+    if (MulOverflow(token_info_size, static_cast<uint64_t>(context_.common.micro_batch_size), token_info_size)) {
+        ASCEND_LOGE("check mul with micro_batch_size overflow failed.");
+        return 0UL;
+    }
+    uint64_t flag_and_layer_id_size = sizeof(int32_t) * 2;
+    if (AddOverflow(token_info_size, flag_and_layer_id_size, token_info_size)) {
+        ASCEND_LOGE("check add flag and layer id overflow failed.");
+        return 0UL;
+    }
+
+    if (MulOverflow(token_info_size, static_cast<uint64_t>(context_.common.micro_batch_num), token_info_size)) {
+        ASCEND_LOGE("check mul with micro_batch_num overflow failed.");
+        return 0UL;
+    }
+    if (MulOverflow(token_info_size, static_cast<uint64_t>(context_.common.session_num), token_info_size)) {
+        ASCEND_LOGE("check mul with session_num overflow failed.");
+        return 0UL;
+    }
+    return token_info_size;
+}
+
+uint32_t ScheduleContextHolder::InitFfnTokenInfoBuf() const
+{
+    std::unique_ptr<uint8_t[]> tmp_buf(new(std::nothrow) uint8_t[context_.ffn.token_info_buf_size]);
+    if (tmp_buf == nullptr) {
+        ASCEND_LOGE("alloc token info host tmp buf failed, buf_size=%lu", context_.ffn.token_info_buf_size);
+        return kFailure;
+    }
+    auto tmp_buf_int = reinterpret_cast<int32_t *>(tmp_buf.get());
+    for (uint32_t session_id = 0; session_id < context_.common.session_num; ++session_id) {
+        for (uint32_t micro_batch_id = 0; micro_batch_id < context_.common.micro_batch_num; ++micro_batch_id) {
+            // flag
+            *tmp_buf_int++ = 0;
+            // layer_id
+            *tmp_buf_int++ = 0;
+            for (uint32_t idx = 0;
+                 idx < context_.common.micro_batch_size * context_.common.selected_expert_num; ++idx) {
+                // expert_id
+                *tmp_buf_int++ = INT32_MAX;
+            }
+        }
+    }
+    auto token_info_buf = reinterpret_cast<void *>(static_cast<uintptr_t>(context_.ffn.token_info_buf));
+    auto ret = aclrtMemcpy(token_info_buf, context_.ffn.token_info_buf_size, tmp_buf.get(),
+                           context_.ffn.token_info_buf_size, ACL_MEMCPY_HOST_TO_DEVICE);
+    if (ret != ACL_ERROR_NONE) {
+        ASCEND_LOGE("ACL memory copy token info buf failed, size_=%lu, token_info_buf ptr=%lu.",
+                    context_.ffn.token_info_buf_size, token_info_buf);
+        return kFailure;
+    }
+    return kSuccess;
+}
+
+uint32_t ScheduleContextHolder::InitFfn()
+{
+    uint64_t token_info_size = CalcFfnTokenInfoSize();
+    if (token_info_size == 0U) {
+        return kFailure;
+    }
+
+    uint64_t token_info_aligned_size = AlignUp(token_info_size, kBufAlignSize);
+    if (token_info_aligned_size < token_info_size) {
+        ASCEND_LOGE("token_info_size=%lu overflow after align with %lu.", token_info_size, kBufAlignSize);
+        return kFailure;
+    }
+    if (ffn_window_size_ <= token_info_aligned_size) {
+        ASCEND_LOGE("ffn_window_size=%lu must be > token_info_aligned_size=%lu.",
+                    ffn_window_size_, token_info_aligned_size);
+        return kFailure;
+    }
+
+    context_.ffn.token_info_buf = ffn_window_;
+    context_.ffn.token_info_buf_size = token_info_size;
+    auto ret = InitFfnTokenInfoBuf();
+    if (ret != kSuccess) {
+        return ret;
+    }
+
+    if (AddOverflow(ffn_window_, token_info_aligned_size, context_.ffn.token_data_buf)) {
+        ASCEND_LOGE("check ffn_window add token_info_size overflow failed.");
+        return kFailure;
+    }
+
+    // can't calc token_data_buf_size as the data type is unknown.
+
+    context_.ffn.token_data_buf_size = ffn_window_size_ - token_info_aligned_size;
+
+    // calc output size
+    context_.ffn.layer_ids_buf_size = sizeof(int32_t) * context_.common.session_num;
+    context_.ffn.session_ids_buf_size = sizeof(int32_t) * context_.common.session_num;
+    context_.ffn.micro_batch_ids_buf_size = sizeof(int32_t) * context_.common.session_num;
+    context_.ffn.expert_ids_buf_size = sizeof(int32_t) * context_.common.session_num *
+                                       context_.common.micro_batch_size * context_.common.selected_expert_num;
+
+    ASCEND_LOGI("Init ffn success, token_info_buf=%lu, token_info_buf_size=%lu, "
+                "token_data_buf=%lu, token_data_buf_size=%lu.",
+                context_.ffn.token_info_buf, context_.ffn.token_info_buf_size,
+                context_.ffn.token_data_buf, context_.ffn.token_data_buf_size);
+    return kSuccess;
+}
+
+uint64_t ScheduleContextHolder::CalcAttentionTokenInfoSize() const
+{
+    uint64_t token_info_size = sizeof(int32_t) * static_cast<uint64_t>(context_.common.selected_expert_num);
+    if (MulOverflow(token_info_size, static_cast<uint64_t>(context_.common.micro_batch_size), token_info_size)) {
+        ASCEND_LOGE("check mul with micro_batch_size overflow failed.");
+        return 0UL;
+    }
+
+    if (MulOverflow(token_info_size, static_cast<uint64_t>(context_.common.micro_batch_num), token_info_size)) {
+        ASCEND_LOGE("check mul with micro_batch_num overflow failed.");
+        return 0UL;
+    }
+    return token_info_size;
+}
+
+uint32_t ScheduleContextHolder::InitAttention()
+{
+    uint64_t token_info_size = CalcAttentionTokenInfoSize();
+    if (token_info_size == 0U) {
+        return kFailure;
+    }
+
+    uint64_t token_info_aligned_size = AlignUp(token_info_size, kBufAlignSize);
+    if (token_info_aligned_size < token_info_size) {
+        ASCEND_LOGE("token_info_size=%lu overflow after align with %lu.", token_info_size, kBufAlignSize);
+        return kFailure;
+    }
+    if (attention_window_size_ <= token_info_aligned_size) {
+        ASCEND_LOGE("attention_window_size=%lu must be > token_info_aligned_size=%lu.",
+                    attention_window_size_, token_info_aligned_size);
+        return kFailure;
+    }
+
+    context_.attention.token_info_buf = attention_window_;
+    context_.attention.token_info_buf_size = token_info_size;
+    auto ret = aclrtMemset(reinterpret_cast<void *>(static_cast<uintptr_t>(context_.attention.token_info_buf)),
+                           token_info_size, '\0', token_info_size);
+    if (ret != ACL_ERROR_NONE) {
+        ASCEND_LOGE("ACL memset attention context to 0 failed, addr=%lu, size=%zu.",
+                    context_.attention.token_info_buf, token_info_size);
+        return kFailure;
+    }
+    if (AddOverflow(attention_window_, token_info_aligned_size, context_.attention.token_data_buf)) {
+        ASCEND_LOGE("check attention_window add token_info_size overflow failed.");
+        return kFailure;
+    }
+    // can't calc token_data_buf_size as the data type is unknown.
+    context_.attention.token_data_buf_size = attention_window_size_ - token_info_aligned_size;
+    // init to micro_batch_num - 1, scheduler will scan from (micro_batch_id + 1) % micro_batch_num.
+ context_.attention.micro_batch_id = context_.common.micro_batch_num - 1U; + ASCEND_LOGI("Init attention success, token_info_buf=%lu, token_info_buf_size=%lu, token_data_buf=%lu, " + "token_data_buf_size=%lu, micro_batch_id=%u.", + context_.attention.token_info_buf, context_.attention.token_info_buf_size, + context_.attention.token_data_buf, context_.attention.token_data_buf_size, + context_.attention.micro_batch_id); + return kSuccess; +} + +uint32_t ScheduleContextHolder::Init() +{ + if (init_flag_) { + ASCEND_LOGI("Already been initialized, does not need to be initialized again."); + return kSuccess; + } + ASCEND_LOGI("Init begin, schedule_mode=%d, session_num=%u, micro_batch_num=%u, micro_batch_size=%u, " + "selected_expert_num=%u, ffn_window=%lu, ffn_window_size=%lu, " + "attention_window=%lu, attention_window_size_=%lu.", + context_.common.schedule_mode, context_.common.session_num, context_.common.micro_batch_num, + context_.common.micro_batch_size, context_.common.selected_expert_num, + ffn_window_, ffn_window_size_, attention_window_, attention_window_size_); + if (!CheckParams()) { + return kFailure; + } + uint32_t ret = kSuccess; + if (context_.common.schedule_mode == kScheduleModeFfn) { + ret = InitFfn(); + } else if (context_.common.schedule_mode == kScheduleModeAttention) { + ret = InitAttention(); + } + if (ret != kSuccess) { + return ret; + } + context_.control.run_flag = kRunFlagRunning; + ret = AllocAndAssignDevMem(); + if (ret != kSuccess) { + return ret; + } + init_flag_ = true; + ASCEND_LOGI("init success."); + return kSuccess; +} + +uint32_t ScheduleContextHolder::AllocAndAssignDevMem() +{ + auto dev_tensor_options = at::TensorOptions(c10::DeviceType::PrivateUse1).dtype(torch::kInt8); + if (context_.common.schedule_mode == kScheduleModeFfn) { + uint64_t layer_id_buf_size_align_up = AlignUp(context_.ffn.layer_ids_buf_size, kBufAlignSize); + uint64_t expert_buf_size_align_up = AlignUp(context_.ffn.expert_ids_buf_size, kBufAlignSize); + // session_ids_buf and micro_batch_ids_buf size is same as layer_id_buf, so multiply 3. 
+        workspace_size_ = layer_id_buf_size_align_up * 3UL + expert_buf_size_align_up;
+        workspace_tensor_ = at::empty(std::vector<int64_t>({static_cast<int64_t>(workspace_size_)}), dev_tensor_options);
+        auto workspace_addr = reinterpret_cast<uint64_t>(workspace_tensor_.data_ptr());
+        if (workspace_addr == 0UL) {
+            ASCEND_LOGE("alloc workspace failed, workspace_size_=%lu.", workspace_size_);
+            return kFailure;
+        }
+        context_.ffn.layer_ids_buf = workspace_addr;
+        context_.ffn.session_ids_buf = context_.ffn.layer_ids_buf + layer_id_buf_size_align_up;
+        context_.ffn.micro_batch_ids_buf = context_.ffn.session_ids_buf + layer_id_buf_size_align_up;
+        context_.ffn.expert_ids_buf = context_.ffn.micro_batch_ids_buf + layer_id_buf_size_align_up;
+
+        ASCEND_LOGI("alloc and assign ffn dev mem success, layer_ids_buf=%lu, layer_ids_buf_size=%lu, "
+                    "session_ids_buf=%lu, session_ids_buf_size=%lu, "
+                    "micro_batch_ids_buf=%lu, micro_batch_ids_buf_size=%lu, "
+                    "expert_ids_buf=%lu, expert_ids_buf_size=%lu.",
+                    context_.ffn.layer_ids_buf, context_.ffn.layer_ids_buf_size,
+                    context_.ffn.session_ids_buf, context_.ffn.session_ids_buf_size,
+                    context_.ffn.micro_batch_ids_buf, context_.ffn.micro_batch_ids_buf_size,
+                    context_.ffn.expert_ids_buf, context_.ffn.expert_ids_buf_size);
+    }
+    std::vector<int64_t> context_shape = {sizeof(ScheduleContext)};
+    // Wrap the host-side ScheduleContext memory in a tensor (zero copy)
+    at::Tensor host_tensor = at::from_blob(&context_,
+                                           context_shape,
+                                           at::TensorOptions().dtype(torch::kInt8));
+    context_tensor_ = at::empty(context_shape, dev_tensor_options);
+    context_tensor_.copy_(host_tensor);
+    return kSuccess;
+}
+
+bool ScheduleContextHolder::CheckFfnParams() const
+{
+    if (ffn_window_ == 0UL) {
+        ASCEND_LOGE("check ffn param failed, ffn_window can't be 0.");
+        return false;
+    }
+    if (ffn_window_size_ == 0UL) {
+        ASCEND_LOGE("check ffn param failed, ffn_window_size can't be 0.");
+        return false;
+    }
+    return true;
+}
+
+bool ScheduleContextHolder::CheckAttentionParams() const
+{
+    if (attention_window_ == 0UL) {
+        ASCEND_LOGE("check attention param failed, attention_window can't be 0.");
+        return false;
+    }
+    if (attention_window_size_ == 0UL) {
+        ASCEND_LOGE("check attention param failed, attention_window_size can't be 0.");
+        return false;
+    }
+    return true;
+}
+
+bool ScheduleContextHolder::CheckParams() const
+{
+    if ((context_.common.session_num == 0U) || (context_.common.micro_batch_num == 0U) ||
+        (context_.common.micro_batch_size == 0U) || (context_.common.selected_expert_num == 0U) ||
+        (context_.common.expert_num == 0U)) {
+        ASCEND_LOGE("session_num[%u], micro_batch_num[%u], micro_batch_size[%u], selected_expert_num[%u], "
+                    "expert_num[%u] can't be 0.", context_.common.session_num, context_.common.micro_batch_num,
+                    context_.common.micro_batch_size, context_.common.selected_expert_num, context_.common.expert_num);
+        return false;
+    }
+
+    if ((context_.common.attn_to_ffn_token_size % kBufAlignSize) != 0U) {
+        ASCEND_LOGE("attn_to_ffn_token_size[%u] must be aligned with %lu.", context_.common.attn_to_ffn_token_size,
+                    kBufAlignSize);
+        return false;
+    }
+    if ((context_.common.ffn_to_attn_token_size % kBufAlignSize) != 0U) {
+        ASCEND_LOGE("ffn_to_attn_token_size[%u] must be aligned with %lu.", context_.common.ffn_to_attn_token_size,
+                    kBufAlignSize);
+        return false;
+    }
+
+    if (context_.common.schedule_mode == kScheduleModeFfn) {
+        return CheckFfnParams();
+    } else if (context_.common.schedule_mode == kScheduleModeAttention) {
+        return CheckAttentionParams();
+    } else {
+        ASCEND_LOGE("check schedule_mode=%d failed, only support [%d, %d] now.",
context_.common.schedule_mode, + kScheduleModeFfn, kScheduleModeAttention); + return false; + } +} + +uint32_t ScheduleContextHolder::StopSchedule() +{ + context_.control.run_flag = kRunFlagStop; + at::Tensor run_flag_host_tensor = at::from_blob( + &context_.control.run_flag, + {static_cast(sizeof(context_.control.run_flag))}, + at::TensorOptions().dtype(torch::kInt8) + ); + size_t offset = offsetof(ScheduleContext::ControlArea, run_flag) + offsetof(ScheduleContext, control); + auto run_flag_view = context_tensor_.slice(0, offset, offset + sizeof(context_.control.run_flag)); + run_flag_view.copy_(run_flag_host_tensor); + return kSuccess; +} + +std::pair ScheduleContextHolder::GetContextTensor() const +{ + if (!init_flag_) { + return std::make_pair(kFailure, at::Tensor()); + } + + return std::make_pair(kSuccess, context_tensor_); +} + +uint32_t ScheduleContextHolder::GetScheduleContextFromDev(ScheduleContext &context) const +{ + // 将 ScheduleContext 的内存包装成 Tensor(零拷贝) + at::Tensor host_tensor = at::from_blob(&context, + {static_cast(sizeof(context))}, + at::TensorOptions().dtype(torch::kInt8) + ); + host_tensor.copy_(context_tensor_); + return kSuccess; +} + +std::string ScheduleContextHolder::GetScheduleContextInfo() const +{ + if (!init_flag_) { + return "Error: schedule context is not inited!"; + } + ScheduleContext tmp_context{}; + auto ret = GetScheduleContextFromDev(tmp_context); + if (ret != kSuccess) { + ASCEND_LOGE("Get schedule context from device failed."); + return "Error: copy schedule context to host failed, error=" + std::to_string(ret); + } + return ToString(tmp_context); +} + +std::string ScheduleContextHolder::ToString(const ScheduleContext &context) +{ + std::stringstream ss; + ss << "schedule context: schedule_mode=" << context.common.schedule_mode + << ", session_num=" << context.common.session_num << ", micro_batch_num=" << context.common.micro_batch_num + << ", micro_batch_size=" << context.common.micro_batch_size + << ", selected_expert_num=" << context.common.selected_expert_num << ", expert_num=" << context.common.expert_num + << ", attn_to_ffn_token_size=" << context.common.attn_to_ffn_token_size + << ", ffn_to_attn_token_size=" << context.common.ffn_to_attn_token_size + << ", run_flag=" << context.control.run_flag; + if (context.common.schedule_mode == kScheduleModeFfn) { + ss << ", ffn info: token_info_buf=" << context.ffn.token_info_buf + << ", token_info_buf_size=" << context.ffn.token_info_buf_size + << ", token_data_buf=" << context.ffn.token_data_buf + << ", token_data_buf_size=" << context.ffn.token_data_buf_size << ", polling_index=" + << context.ffn.polling_index + << ", layer_ids_buf=" << context.ffn.layer_ids_buf << ", layer_ids_buf_size=" + << context.ffn.layer_ids_buf_size + << ", session_ids_buf=" << context.ffn.session_ids_buf + << ", session_ids_buf_size=" << context.ffn.session_ids_buf_size + << ", micro_batch_ids_buf=" << context.ffn.micro_batch_ids_buf + << ", micro_batch_ids_buf_size=" << context.ffn.micro_batch_ids_buf_size + << ", expert_ids_buf=" << context.ffn.expert_ids_buf + << ", expert_ids_buf_size=" << context.ffn.expert_ids_buf_size << ", out_num=" << context.ffn.out_num << ";"; + } else if (context.common.schedule_mode == kScheduleModeAttention) { + ss << ", attention info: token_info_buf=" << context.attention.token_info_buf + << ", token_info_buf_size=" << context.attention.token_info_buf_size + << ", token_data_buf=" << context.attention.token_data_buf + << ", token_data_buf_size=" << context.attention.token_data_buf_size + << ", 
micro_batch_id=" << context.attention.micro_batch_id << ";"; + } + return ss.str(); +} +} // namespace afd +} // namespace torch_npu \ No newline at end of file diff --git a/torch_npu/csrc/afd/ScheduleContext.h b/torch_npu/csrc/afd/ScheduleContext.h new file mode 100644 index 00000000000..55f58815310 --- /dev/null +++ b/torch_npu/csrc/afd/ScheduleContext.h @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include + +namespace torch_npu { +namespace afd { +#pragma pack(push, 1) +struct ScheduleContext { + struct CommonArea { + uint32_t session_num; + uint32_t micro_batch_num; + uint32_t micro_batch_size; + uint32_t selected_expert_num; + uint32_t expert_num; // expert num per layer, include route expert and share expert + uint32_t attn_to_ffn_token_size; // each token in ffn window data area space size, align to 512. + uint32_t ffn_to_attn_token_size; // each token in attention window data area space size, align to 512. + int32_t schedule_mode; // 0:just ffn, 1:just attention, 2:ffn+attention + int8_t reserve0[96]; // padding to 128 bytes + }; + struct ControlArea { + int32_t run_flag; // 0: exit 1: running + int8_t reserve2[124]; // padding to 128 bytes + }; + struct FfnArea { + // ffn area + uint64_t token_info_buf; + uint64_t token_info_buf_size; + uint64_t token_data_buf; + uint64_t token_data_buf_size; + uint64_t polling_index; + int8_t reserve3[88]; + + // ffn out area + uint64_t layer_ids_buf; + uint64_t layer_ids_buf_size; + uint64_t session_ids_buf; + uint64_t session_ids_buf_size; + uint64_t micro_batch_ids_buf; + uint64_t micro_batch_ids_buf_size; + uint64_t expert_ids_buf; + uint64_t expert_ids_buf_size; + uint32_t out_num; + int8_t reserve4[60]; + }; + + struct AttentionArea { + // attention area + uint64_t token_info_buf; // point to a int64 dev mem + uint64_t token_info_buf_size; + uint64_t token_data_buf; // point to a int64 dev mem + uint64_t token_data_buf_size; + uint32_t micro_batch_id; + int8_t reserve5[92]; + }; + + // common area + CommonArea common; + ControlArea control; + AttentionArea attention; + FfnArea ffn; + // reserve area + int8_t reserve6[384]; // padding to 1024 bytes +}; +static_assert(sizeof(ScheduleContext) == 1024, "ScheduleContext size must be 1024 bytes"); +#pragma pack(pop) + +class ScheduleContextHolder { +public: + ScheduleContextHolder(int32_t schedule_mode, uint32_t session_num, uint32_t micro_batch_num, + uint32_t micro_batch_size, uint32_t selected_expert_num, uint32_t expert_num, + uint32_t attn_to_ffn_token_size, uint32_t ffn_to_attn_token_size, uint64_t ffn_window, + uint64_t ffn_window_size, uint64_t attention_window, uint64_t attention_window_size); + + ~ScheduleContextHolder() = default; + + uint32_t Init(); + + std::pair GetContextTensor() const; + + uint32_t StopSchedule(); + + std::string GetScheduleContextInfo() const; + +private: + bool CheckParams() const; + + bool CheckFfnParams() const; + + bool CheckAttentionParams() const; + + uint32_t InitFfn(); + + uint32_t InitAttention(); + + uint32_t InitFfnTokenInfoBuf() const; + + uint64_t CalcFfnTokenInfoSize() const; + + uint64_t CalcAttentionTokenInfoSize() const; + + uint32_t AllocAndAssignDevMem(); + + uint32_t GetScheduleContextFromDev(ScheduleContext &context) const; + + static std::string ToString(const ScheduleContext &context); + + uint64_t ffn_window_ = 0; + uint64_t ffn_window_size_ = 0; + uint64_t attention_window_ = 0; + uint64_t attention_window_size_ = 0; + bool init_flag_ = false; + ScheduleContext context_{}; // host ptr + + at::Tensor context_tensor_; + 
at::Tensor workspace_tensor_; + + uint64_t workspace_size_ = 0; +}; +} +} \ No newline at end of file -- Gitee
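
Reviewer note (not part of the patch): a minimal usage sketch of the new module, assuming an NPU device is available. The window size and scheduling parameters simply reuse the values from test_schedule_contest.py; they are illustrative, not recommended defaults.

    import torch
    import torch_npu

    npu_device = torch._C._get_privateuse1_backend_name()
    # One device buffer backs both windows here (1 GiB, as in the test).
    window = torch.ones([1 * 1024 * 1024 * 1024], dtype=torch.int8).to(npu_device)

    holder = torch_npu._afd.create_schedule_context_holder(
        schedule_mode=0,                    # 0: schedule FFN, 1: schedule attention
        session_num=288,
        micro_batch_num=3,
        micro_batch_size=30,
        selected_expert_num=9,
        expert_num=288,
        attn_to_ffn_token_size=7168 + 512,  # must be 512-byte aligned
        ffn_to_attn_token_size=7168 * 2,    # must be 512-byte aligned
        ffn_window=window.data_ptr(),
        ffn_window_size=window.numel(),
        attention_window=window.data_ptr(),
        attention_window_size=window.numel(),
    )

    # The holder wraps the device-side ScheduleContext.
    context_tensor = holder.get_schedule_context_tensor()  # int8 device tensor
    print(holder.get_schedule_context_info())              # human-readable dump
    holder.stop_schedule()                                  # sets run_flag to stop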