diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml index 4966b89a21dbb9291a118ed2012619ed8efe43ac..e94af60369f7abe8207634308ce06cebe39fdf1b 100644 --- a/.jenkins/test/config/dependent_packages.yaml +++ b/.jenkins/test/config/dependent_packages.yaml @@ -1,5 +1,5 @@ mindspore: - 'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250605/master_20250605212230_aac98ab9732926f6abd4c3d73be47d5be6c93ead_newest/' + 'https://repo.mindspore.cn/mindspore/mindspore/version/202507/20250711/master_20250711010018_500a29c562b75cda313971360b3c4a6ab745f089_newest/' mindspore_gs: 'https://repo.mindspore.cn/mindspore/golden-stick/version/202507/20250709/master_20250709010018_5f01a0211ca36690a577d3d456c5ba194c88771d_newest/' diff --git a/vllm_mindspore/ops/CMakeLists.txt b/csrc/CMakeLists.txt similarity index 94% rename from vllm_mindspore/ops/CMakeLists.txt rename to csrc/CMakeLists.txt index 4c94b2c085b0be5ed4247e4c5829325531648ae9..86ae77716793e2965f1a91e74745a86b19eb1b8f 100644 --- a/vllm_mindspore/ops/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -14,7 +14,7 @@ endif() add_subdirectory(ascendc) # Collect source files -file(GLOB SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/module/*.cpp) +file(GLOB_RECURSE SRC_FILES ${CMAKE_CURRENT_SOURCE_DIR}/module/*.cpp) # Generate a temporary python script file to build custom ops with MindSpore's CustomOpBuilder set(PYTHON_SCRIPT_PATH "${CMAKE_BINARY_DIR}/build_custom_with_ms.py") diff --git a/vllm_mindspore/ops/ascendc/CMakeLists.txt b/csrc/ascendc/CMakeLists.txt similarity index 100% rename from vllm_mindspore/ops/ascendc/CMakeLists.txt rename to csrc/ascendc/CMakeLists.txt diff --git a/vllm_mindspore/ops/ascendc/adv_step_flash.c b/csrc/ascendc/adv_step_flash.c similarity index 100% rename from vllm_mindspore/ops/ascendc/adv_step_flash.c rename to csrc/ascendc/adv_step_flash.c diff --git a/vllm_mindspore/ops/ascendc/adv_step_flash.h b/csrc/ascendc/adv_step_flash.h similarity index 51% rename from vllm_mindspore/ops/ascendc/adv_step_flash.h rename to csrc/ascendc/adv_step_flash.h index 1dbd1bc63380364aba2cab971ae89e6176e27511..3c601e04a0a8aea2cecde695f200c5769ebf0467 100644 --- a/vllm_mindspore/ops/ascendc/adv_step_flash.h +++ b/csrc/ascendc/adv_step_flash.h @@ -13,12 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef VLLM_MINDSPORE_OPS_ASCENDC_ADV_STEP_FLASH_H -#define VLLM_MINDSPORE_OPS_ASCENDC_ADV_STEP_FLASH_H +#ifndef VLLM_MINDSPORE_CSRC_ASCENDC_ADV_STEP_FLASH_H +#define VLLM_MINDSPORE_CSRC_ASCENDC_ADV_STEP_FLASH_H -extern void AdvStepFlashKernelEntry(uint32_t blockDims, void *l2ctrl, void *aclStream, uint8_t *sampledTokenIds, - uint8_t *blockTables, uint8_t *seqLensInput, uint8_t *inputTokens, - uint8_t *inputPositions, uint8_t *seqLensOut, uint8_t *slotMapping, - int32_t num_seqs, int32_t block_size, int32_t block_tables_stride); +extern void AdvStepFlashKernelEntry( + uint32_t blockDims, void *l2ctrl, void *aclStream, uint8_t *sampledTokenIds, + uint8_t *blockTables, uint8_t *seqLensInput, uint8_t *inputTokens, + uint8_t *inputPositions, uint8_t *seqLensOut, uint8_t *slotMapping, + int32_t num_seqs, int32_t block_size, int32_t block_tables_stride); -#endif // VLLM_MINDSPORE_OPS_ASCENDC_ADV_STEP_FLASH_H +#endif // VLLM_MINDSPORE_CSRC_ASCENDC_ADV_STEP_FLASH_H diff --git a/csrc/module/adv_step_flash.cpp b/csrc/module/adv_step_flash.cpp new file mode 100644 index 0000000000000000000000000000000000000000..899de8bd2fa57c2126ff90685b599870fb6e00e8 --- /dev/null +++ b/csrc/module/adv_step_flash.cpp @@ -0,0 +1,125 @@ +/** + * Copyright 2025 Huawei Technologies Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "ms_extension/api.h" + +#include "ascendc/adv_step_flash.h" +#include "module/module.h" + +struct DtypeCaster { + ms::Tensor CheckAndCast(const ms::Tensor &t, const std::string &name = "") { + if (t.data_type() != ms::TypeId::kNumberTypeInt32) { + if (!name.empty()) { + tensor_map_[name] = t; + } + return t.cast(ms::TypeId::kNumberTypeInt32); + } + return t; + } + + ms::Tensor RecoveryTensorDtype(const ms::Tensor &t, const std::string &name) { + auto iter = tensor_map_.find(name); + if (iter == tensor_map_.end()) { + return t; + } + auto ori_tensor = iter->second; + auto ret = t.cast(ori_tensor.data_type()); + ori_tensor.AssignTensor(ret); + return ori_tensor; + } + std::map<std::string, ms::Tensor> tensor_map_; +}; + +class AdvStepFlashOp : public ms::pynative::PyboostRunner { +public: + using PyboostRunner::PyboostRunner; + void LaunchKernel() override { + uint8_t *sampledTokenIdsPtr = + static_cast<uint8_t *>(inputs()[0].GetDataPtr()); + uint8_t *seqLensPtr = static_cast<uint8_t *>(inputs()[1].GetDataPtr()); + uint8_t *blockTablesPtr = static_cast<uint8_t *>(inputs()[2].GetDataPtr()); + uint8_t *inputTokensPtr = static_cast<uint8_t *>(outputs()[0].GetDataPtr()); + uint8_t *inputPositionsPtr = + static_cast<uint8_t *>(outputs()[1].GetDataPtr()); + uint8_t *slotMappingPtr = static_cast<uint8_t *>(outputs()[3].GetDataPtr()); + auto stride = inputs()[2].stride(); + int32_t block_tables_stride = stride.empty() ?
1 : stride[0]; + + uint32_t blockDims = 1; + void *l2ctrl = nullptr; + AdvStepFlashKernelEntry(blockDims, l2ctrl, stream(), sampledTokenIdsPtr, + blockTablesPtr, seqLensPtr, inputTokensPtr, + inputPositionsPtr, seqLensPtr, slotMappingPtr, + num_seqs_, block_size_, block_tables_stride); + } + + static void Eval(int32_t num_seqs, int32_t num_queries, int32_t block_size, + ms::Tensor input_tokens, // output + ms::Tensor sampled_token_ids, // input + ms::Tensor input_positions, // output + ms::Tensor seq_lens, // input&output (inplace) + ms::Tensor slot_mapping, // output + ms::Tensor block_tables // input + ) { + // the AdvStepFlashKernelEntry only support int32 inputs. + DtypeCaster caster; + sampled_token_ids = caster.CheckAndCast(sampled_token_ids); + block_tables = caster.CheckAndCast(block_tables); + input_tokens = caster.CheckAndCast(input_tokens, "input_tokens"); + input_positions = caster.CheckAndCast(input_positions, "input_positions"); + slot_mapping = caster.CheckAndCast(slot_mapping, "slot_mapping"); + seq_lens = caster.CheckAndCast(seq_lens, "seq_lens"); + + auto runner = std::make_shared<AdvStepFlashOp>("AdvanceStepFlashattn"); + runner->num_seqs_ = num_seqs; + runner->num_queries_ = num_queries; + runner->block_size_ = block_size; + runner->Run({sampled_token_ids, seq_lens, block_tables}, + {input_tokens, input_positions, seq_lens, slot_mapping}); + + input_tokens = caster.RecoveryTensorDtype(input_tokens, "input_tokens"); + input_positions = + caster.RecoveryTensorDtype(input_positions, "input_positions"); + slot_mapping = caster.RecoveryTensorDtype(slot_mapping, "slot_mapping"); + seq_lens = caster.RecoveryTensorDtype(seq_lens, "seq_lens"); + } + int32_t num_seqs_{0}; + int32_t num_queries_{0}; + int32_t block_size_{0}; +}; + +auto pyboost_adv_step_flash(int32_t num_seqs, int32_t num_queries, + int32_t block_size, ms::Tensor input_tokens, + ms::Tensor sampled_token_ids, + ms::Tensor input_positions, ms::Tensor seq_lens, + ms::Tensor slot_mapping, ms::Tensor block_tables) { + return ms::pynative::PyboostRunner::Call<0>( + AdvStepFlashOp::Eval, num_seqs, num_queries, block_size, input_tokens, + sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables); +} + +VLLM_MS_EXTENSION_MODULE(m) { + m.def("advance_step_flashattn", &pyboost_adv_step_flash, + "advance_step_flashattn", pybind11::arg("num_seqs"), + pybind11::arg("num_queries"), pybind11::arg("block_size"), + pybind11::arg("input_tokens"), pybind11::arg("sampled_token_ids"), + pybind11::arg("input_positions"), pybind11::arg("seq_lens"), + pybind11::arg("slot_mapping"), pybind11::arg("block_tables")); +} diff --git a/vllm_mindspore/ops/module/module.cpp b/csrc/module/module.cpp similarity index 100% rename from vllm_mindspore/ops/module/module.cpp rename to csrc/module/module.cpp diff --git a/vllm_mindspore/ops/module/module.h b/csrc/module/module.h similarity index 58% rename from vllm_mindspore/ops/module/module.h rename to csrc/module/module.h index 47277ab7036c12ddaf41551079fe4c47c4ec6e78..acda4235d8f72ce549df1c91071dae87bebfc423 100644 --- a/vllm_mindspore/ops/module/module.h +++ b/csrc/module/module.h @@ -13,20 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License.
*/ -#ifndef VLLM_MINDSPORE_OPS_MODULE_MODULE_H -#define VLLM_MINDSPORE_OPS_MODULE_MODULE_H +#ifndef VLLM_MINDSPORE_CSRC_MODULE_MODULE_H +#define VLLM_MINDSPORE_CSRC_MODULE_MODULE_H -#include #include -#include +#include #include +#include // Define the type of module registration functions using ModuleRegisterFunction = std::function<void(pybind11::module_ &)>; // Module registry class class ModuleRegistry { - public: +public: // Get the singleton instance static ModuleRegistry &Instance() { static ModuleRegistry instance; @@ -34,7 +34,9 @@ class ModuleRegistry { } // Register a module function - void Register(const ModuleRegisterFunction &func) { functions_.push_back(func); } + void Register(const ModuleRegisterFunction &func) { + functions_.push_back(func); + } // Call all registered module functions void RegisterAll(pybind11::module_ &m) { @@ -43,7 +45,7 @@ class ModuleRegistry { } } - private: +private: ModuleRegistry() = default; ~ModuleRegistry() = default; @@ -55,15 +57,21 @@ class ModuleRegistry { std::vector<ModuleRegisterFunction> functions_; }; -// Define a macro to register module functions -#define MS_EXTENSION_MODULE(func) \ - static void func##_register(pybind11::module_ &); \ - namespace { \ - struct func##_registrar { \ - func##_registrar() { ModuleRegistry::Instance().Register(func##_register); } \ - }; \ - static func##_registrar registrar_instance; \ - } \ - static void func##_register(pybind11::module_ &m) +#define CONCATENATE_DETAIL(x, y) x##y +#define CONCATENATE(x, y) CONCATENATE_DETAIL(x, y) + +#define VLLM_MS_EXTENSION_MODULE(m) \ + static void CONCATENATE(func_register_, __LINE__)(pybind11::module_ &); \ + namespace { \ + struct CONCATENATE(func_registrar_, __LINE__) { \ + CONCATENATE(func_registrar_, __LINE__)() { \ + ModuleRegistry::Instance().Register( \ + CONCATENATE(func_register_, __LINE__)); \ + } \ + }; \ + static CONCATENATE(func_registrar_, __LINE__) \ + CONCATENATE(registrar_instance_, __LINE__); \ + } \ + static void CONCATENATE(func_register_, __LINE__)(pybind11::module_ & m) -#endif // VLLM_MINDSPORE_OPS_MODULE_MODULE_H +#endif // VLLM_MINDSPORE_CSRC_MODULE_MODULE_H diff --git a/setup.py b/setup.py index 57d80e39050553092471b38ba0ec7a1dbe764767..81baf524179458d1d34efb0d7c321a4830b0456d 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,6 @@ from setuptools import find_packages, setup from setuptools.command.build_ext import build_ext from setuptools import Extension import subprocess -import warnings def load_module_from_path(module_name, path): @@ -111,18 +110,20 @@ class CustomBuildExt(build_ext): ROOT_DIR = os.path.abspath(os.path.dirname(__file__)) def build_extension(self, ext): - if ext.name == "vllm_mindspore.npu_ops": - self.build_npu_ops(ext) + if ext.name == "vllm_mindspore._C_ops": + self.build_c_ops(ext) else: raise ValueError(f"Unknown extension name: {ext.name}") - def build_npu_ops(self, ext): - # "vllm_mindspore.npu_ops" --> "npu_ops" + def build_c_ops(self, ext): - # "vllm_mindspore._C_ops" --> "_C_ops" + ext_name = ext.name.split('.')[-1] so_name = ext_name + ".so" logger.info(f"Building {so_name} ...") - OPS_DIR = os.path.join(ROOT_DIR, "vllm_mindspore", "ops") - BUILD_OPS_DIR = os.path.join(ROOT_DIR, "build", "ops") + OPS_DIR = os.path.join(ROOT_DIR, "csrc") + BUILD_OPS_DIR = os.path.join(ROOT_DIR, "build", "csrc_ops") + if os.path.exists(BUILD_OPS_DIR): + shutil.rmtree(BUILD_OPS_DIR) os.makedirs(BUILD_OPS_DIR, exist_ok=True) ascend_home_path = _get_ascend_home_path() @@ -140,17 +141,20 @@ class CustomBuildExt(build_ext): f"cmake --build {BUILD_OPS_DIR} -j --verbose" ) - try: - #
Run the combined cmake command - logger.info(f"Running combined CMake commands:\n{cmake_cmd}") - result = subprocess.run(cmake_cmd, cwd=self.ROOT_DIR, text=True, shell=True, capture_output=True) - if result.returncode != 0: - logger.info("CMake commands failed:") - logger.info(result.stdout) # Print standard output - logger.info(result.stderr) # Print error output - raise RuntimeError(f"Combined CMake commands failed with exit code {result.returncode}") - except subprocess.CalledProcessError as e: - raise RuntimeError(f"Failed to build {so_name}: {e}") + # Run the combined cmake command + logger.info(f"Running commands:\n{cmake_cmd}") + build_log_file = os.path.join(BUILD_OPS_DIR, "build_log.txt") + with open(build_log_file, "w") as log_file: + result = subprocess.run( + ["bash", "-c", cmake_cmd], + cwd=self.ROOT_DIR, + text=True, + stdout=log_file, + stderr=log_file + ) + if result.returncode != 0: + logger.error(f"Command failed: '{cmake_cmd}' exited with code {result.returncode}") + raise RuntimeError(f"Failed to build {ext_name}, check the build log for details: {build_log_file}") # Copy the generated .so file to the target directory src_so_path = os.path.join(build_extension_dir, so_name) @@ -159,7 +163,7 @@ class CustomBuildExt(build_ext): if os.path.exists(dst_so_path): os.remove(dst_so_path) shutil.copy(src_so_path, dst_so_path) - logger.info(f"Copied {so_name} to {dst_so_path}") + logger.info(f"Build {dst_so_path} succeeded.") write_commit_id() @@ -176,7 +180,7 @@ def _get_ext_modules(): ext_modules = [] if os.path.exists(_get_ascend_home_path()): # sources are specified in CMakeLists.txt - ext_modules.append(Extension("vllm_mindspore.npu_ops", sources=[])) + ext_modules.append(Extension("vllm_mindspore._C_ops", sources=[])) return ext_modules setup( diff --git a/tests/st/python/test_custom_advstepflash.py b/tests/st/python/test_custom_advstepflash.py index 826f92cb65d66b4a02fe962d65f39302e13b8bb8..353f297438be5fcfda698d5491542e88e0a1e536 100644 --- a/tests/st/python/test_custom_advstepflash.py +++ b/tests/st/python/test_custom_advstepflash.py @@ -15,13 +15,12 @@ # limitations under the License. 
"""test case for custom op adv_step_flash""" import time - import mindspore as ms import numpy as np import pytest import torch -from vllm_mindspore import npu_ops +from vllm_mindspore import _custom_ops as custom_ops from .utils import cleanup_subprocesses @@ -75,12 +74,9 @@ def gendata(seed, num_seqs, block_size, block_num, make_tensor): dtype=np.int64) slot_mapping = np.random.randint(100, size=(num_seqs, ), dtype=np.int64) # out - return (make_tensor(sampled_token_ids), \ - make_tensor(input_tokens), \ - make_tensor(input_positions), \ - make_tensor(seq_lens_tensor), \ - make_tensor(block_tables), \ - make_tensor(slot_mapping)) + return (make_tensor(sampled_token_ids), make_tensor(input_tokens), + make_tensor(input_positions), make_tensor(seq_lens_tensor), + make_tensor(block_tables), make_tensor(slot_mapping)) @pytest.mark.level0 @@ -96,23 +92,25 @@ def test_advstepflash(): block_num = 4 num_queries = num_seqs # no padding print("test seed:", seed, flush=True) - sampled_token_ids1, input_tokens1, input_positions1, seq_lens_tensor1, block_tables1, slot_mapping1 = \ + sampled_token_ids1, input_tokens1, input_positions1, seq_lens_tensor1, \ + block_tables1, slot_mapping1 = \ gendata(seed, num_seqs, block_size, block_num, torch.Tensor) benchmark_advance_step_op(sampled_token_ids1, input_tokens1, input_positions1, seq_lens_tensor1, num_queries, block_size, block_tables1, slot_mapping1) - sampled_token_ids2, input_tokens2, input_positions2, seq_lens_tensor2, block_tables2, slot_mapping2 = \ + sampled_token_ids2, input_tokens2, input_positions2, seq_lens_tensor2, \ + block_tables2, slot_mapping2 = \ gendata(seed, num_seqs, block_size, block_num, ms.Tensor) - npu_ops.adv_step_flash(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=input_tokens2, - sampled_token_ids=sampled_token_ids2, - input_positions=input_positions2, - seq_lens=seq_lens_tensor2, - slot_mapping=slot_mapping2, - block_tables=block_tables2) + custom_ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens2, + sampled_token_ids=sampled_token_ids2, + input_positions=input_positions2, + seq_lens=seq_lens_tensor2, + slot_mapping=slot_mapping2, + block_tables=block_tables2) assert np.allclose(sampled_token_ids1, sampled_token_ids2.asnumpy()) assert np.allclose(input_tokens1, input_tokens2.asnumpy()) diff --git a/vllm_mindspore/model_executor/custom_op.py b/vllm_mindspore/_custom_ops.py similarity index 30% rename from vllm_mindspore/model_executor/custom_op.py rename to vllm_mindspore/_custom_ops.py index 585543606165f4668cad2331ed1525891fb754a1..1bf50507c608c1d1f20131e6076012436e4933b7 100644 --- a/vllm_mindspore/model_executor/custom_op.py +++ b/vllm_mindspore/_custom_ops.py @@ -1,43 +1,37 @@ +#!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 - -# Adapted from -# https://github.com/vllm-project/vllm/blob/v0.8.3/vllm/model_executor/custom_op.py -# -# Copyright 2025 Huawei Technologies Co., Ltd. -# Copyright 2025 The vLLM team. +# Copyright 2025 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ -from mindspore import nn - -class CustomOp(nn.Cell): - """ - Base class for custom ops. - Dispatches the forward method to the appropriate backend. - """ - - def __init__(self): - super().__init__() - self._forward_method = self.dispatch_forward() - - def construct(self, *args, **kwargs): - return self._forward_method(*args, **kwargs) +import mindspore as ms - def forward_native(self, *args, **kwargs): - raise NotImplementedError - def forward_cuda(self, *args, **kwargs): - raise NotImplementedError - - def dispatch_forward(self): - return self.forward_native \ No newline at end of file +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: ms.Tensor, + sampled_token_ids: ms.Tensor, + input_positions: ms.Tensor, seq_lens: ms.Tensor, + slot_mapping: ms.Tensor, + block_tables: ms.Tensor) -> None: + """Advance a step on Ascend for existing inputs for a multi-step runner""" + from vllm_mindspore import _C_ops as c_ops + c_ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=input_positions, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + block_tables=block_tables) diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py index 6a3ef5d69cd807f208eb852579f894ed63c2696e..29aa07473605f80e64ba5d925c905597b7590c8e 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -23,9 +23,6 @@ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -import os - -import numpy as np from vllm.attention.backends.abstract import ( AttentionBackend, diff --git a/vllm_mindspore/ops/module/adv_step_flash.cpp b/vllm_mindspore/ops/module/adv_step_flash.cpp deleted file mode 100644 index 513f09a61616fc9b4e1e3ffb5cf53cdcfcc7c2d6..0000000000000000000000000000000000000000 --- a/vllm_mindspore/ops/module/adv_step_flash.cpp +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Copyright 2025 Huawei Technologies Co., Ltd. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include -#include -#include - -#include "ms_extension.h" - -#include "ascendc/adv_step_flash.h" -#include "module/module.h" - -using BaseTensor = mindspore::tensor::BaseTensor; -using BaseTensorPtr = mindspore::tensor::BaseTensorPtr; -using PyBoostUtils = mindspore::kernel::pyboost::PyBoostUtils; - -uint8_t *GetDataPtr(const BaseTensorPtr &t) { - return static_cast<uint8_t *>(t->device_address()->GetMutablePtr()) + t->data().itemsize() * t->storage_offset(); -} - -struct DtypeCaster { - BaseTensorPtr CheckAndCast(const BaseTensorPtr &t, const std::string &name = "") { - mindspore::Int64ImmPtr dst_type = std::make_shared<mindspore::Int64Imm>(mindspore::TypeId::kNumberTypeInt32); - if (t->data_type() != mindspore::TypeId::kNumberTypeInt32) { - if (!name.empty()) { - tensor_map_[name] = t; - } - return mindspore::kernel::pyboost::cast(t, dst_type); - } - return t; - } - BaseTensorPtr RecoveryTensorDtype(const BaseTensorPtr &t, const std::string &name) { - auto iter = tensor_map_.find(name); - if (iter == tensor_map_.end()) { - return t; - } - auto ori_tensor = iter->second; - auto ori_dtype = std::make_shared<mindspore::Int64Imm>(ori_tensor->data_type()); - auto ret = mindspore::kernel::pyboost::cast(t, ori_dtype); - ori_tensor->AssignValue(*ret); - return ori_tensor; - } - std::map<std::string, BaseTensorPtr> tensor_map_; -}; - -void AdvStepFlashAscendC(int32_t num_seqs, int32_t num_queries, int32_t block_size, - BaseTensorPtr &input_tokens, // output - BaseTensorPtr sampled_token_ids, // input - BaseTensorPtr &input_positions, // output - BaseTensorPtr &seq_lens, // input&output (inplace) - BaseTensorPtr &slot_mapping, // output - BaseTensorPtr block_tables // input -) { - // the AdvStepFlashKernelEntry only support int32 inputs. - DtypeCaster caster; - sampled_token_ids = caster.CheckAndCast(sampled_token_ids); - block_tables = caster.CheckAndCast(block_tables); - input_tokens = caster.CheckAndCast(input_tokens, "input_tokens"); - input_positions = caster.CheckAndCast(input_positions, "input_positions"); - slot_mapping = caster.CheckAndCast(slot_mapping, "slot_mapping"); - seq_lens = caster.CheckAndCast(seq_lens, "seq_lens"); - - auto stream_id = PyBoostUtils::cur_stream_id(); - auto device_context = mindspore::runtime::OpRunner::GetDeviceContext("Ascend"); - PyBoostUtils::PrepareOpInputs(device_context, stream_id, input_tokens, sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); - // PyBoostUtils::PrepareOpOutputs(device_context, stream_id, outputs); - PyBoostUtils::DispatchRun(std::make_shared<mindspore::runtime::PyBoostDeviceTask>([=]() { - PyBoostUtils::MallocOpInputs(device_context, input_tokens, sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); - // PyBoostUtils::MallocOpOutputs(device_context, outputs); - - uint8_t *sampledTokenIdsPtr = GetDataPtr(sampled_token_ids); - uint8_t *blockTablesPtr = GetDataPtr(block_tables); - uint8_t *seqLensPtr = GetDataPtr(seq_lens); - uint8_t *inputTokensPtr = GetDataPtr(input_tokens); - uint8_t *inputPositionsPtr = GetDataPtr(input_positions); - uint8_t *slotMappingPtr = GetDataPtr(slot_mapping); - auto aclStream = device_context->device_res_manager_->GetStream(stream_id); - auto stride = block_tables->stride(); - int32_t block_tables_stride = stride.empty() ?
1 : stride[0]; - - mindspore::runtime::OpExecutor::DispatchLaunchTask([=]() { - uint32_t blockDims = 1; - void *l2ctrl = nullptr; - AdvStepFlashKernelEntry(blockDims, l2ctrl, aclStream, sampledTokenIdsPtr, blockTablesPtr, seqLensPtr, - inputTokensPtr, inputPositionsPtr, seqLensPtr, slotMappingPtr, num_seqs, block_size, - block_tables_stride); - }); - })); - - input_tokens = caster.RecoveryTensorDtype(input_tokens, "input_tokens"); - input_positions = caster.RecoveryTensorDtype(input_positions, "input_positions"); - slot_mapping = caster.RecoveryTensorDtype(slot_mapping, "slot_mapping"); - seq_lens = caster.RecoveryTensorDtype(seq_lens, "seq_lens"); -} - -MS_EXTENSION_MODULE(adv_step_flash) { - m.def("adv_step_flash", &AdvStepFlashAscendC, "adv_step_flash_ascendc", pybind11::arg("num_seqs"), - pybind11::arg("num_queries"), pybind11::arg("block_size"), pybind11::arg("input_tokens"), - pybind11::arg("sampled_token_ids"), pybind11::arg("input_positions"), pybind11::arg("seq_lens"), - pybind11::arg("slot_mapping"), pybind11::arg("block_tables")); -}
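
The renamed binding is consumed through the new Python wrapper vllm_mindspore._custom_ops.advance_step_flashattn introduced above, as exercised by tests/st/python/test_custom_advstepflash.py. A minimal call-site sketch follows; it assumes an Ascend environment with the vllm_mindspore._C_ops extension already built, and the shapes and random values are illustrative placeholders rather than values taken from this diff:

# Illustrative usage sketch only; sizes and values are placeholders.
import numpy as np
import mindspore as ms
from vllm_mindspore import _custom_ops as custom_ops

num_seqs, num_queries, block_size, block_num = 4, 4, 16, 4

# Outputs updated in place by the op; non-int32 buffers are cast to int32 inside
# adv_step_flash.cpp (DtypeCaster) and restored to their original dtype afterwards.
input_tokens = ms.Tensor(np.zeros(num_seqs, dtype=np.int64))
input_positions = ms.Tensor(np.zeros(num_seqs, dtype=np.int64))
slot_mapping = ms.Tensor(np.zeros(num_seqs, dtype=np.int64))
# In/out: sequence lengths advanced by the kernel.
seq_lens = ms.Tensor(np.random.randint(1, block_size * block_num, size=num_seqs, dtype=np.int64))
# Read-only inputs.
sampled_token_ids = ms.Tensor(np.random.randint(65536, size=(num_seqs, 1), dtype=np.int64))
block_tables = ms.Tensor(np.random.randint(1024, size=(num_seqs, block_num), dtype=np.int64))

custom_ops.advance_step_flashattn(num_seqs=num_seqs,
                                  num_queries=num_queries,
                                  block_size=block_size,
                                  input_tokens=input_tokens,
                                  sampled_token_ids=sampled_token_ids,
                                  input_positions=input_positions,
                                  seq_lens=seq_lens,
                                  slot_mapping=slot_mapping,
                                  block_tables=block_tables)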