diff --git a/.gitmodules b/.gitmodules
index ec95f133e4e555afa289119a89f145fa7f17942f..d057201a76ae4ad81c3deff14ede37dbad233438 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,4 +1,4 @@
[submodule "tests/mindformers"]
path = tests/mindformers
url = https://gitee.com/mindspore/mindformers.git
- branch = dev
+ branch = br_infer_deepseek_os
diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml
index f037eeb10fdbc2ad3cbcb524f63cdeac2e344ae2..2efd644ec233c9acca1e996d0d45c297810dc393 100644
--- a/.jenkins/test/config/dependent_packages.yaml
+++ b/.jenkins/test/config/dependent_packages.yaml
@@ -1,11 +1,10 @@
mindspore:
- 'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250605/master_20250605212230_aac98ab9732926f6abd4c3d73be47d5be6c93ead_newest/'
-
+ 'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250613/br_infer_iter_20250613031508_11bcfd2ff4dc201a1c07e5d525cbeff7ec7f9558_newest/'
mindspore_gs:
'https://repo.mindspore.cn/mindspore/golden-stick/version/202506/20250604/master_20250604160014_35fcbec4406d3b18faf02ef99fcbe2741e80348e_newest/'
msadapter:
- 'https://repo.mindspore.cn/mindspore/msadapter/version/202505/20250526/master_20250526120007_b76cb7804d1c9555e32a57439c1d412ff86293d1_newest/'
+ 'https://repo.mindspore.cn/mindspore/msadapter/version/202506/20250630/master_20250630031508_c03700434aafa08d6a7f4d384106a223fbed34b6_newest/'
vllm:
'https://repo.mindspore.cn/mirrors/vllm/version/202505/20250514/v0.8.4.dev0_newest/'
diff --git a/Dockerfile b/Dockerfile
deleted file mode 100644
index d174da7c2085a9c8173549d48cc92cce1cb813fb..0000000000000000000000000000000000000000
--- a/Dockerfile
+++ /dev/null
@@ -1,108 +0,0 @@
-FROM hub.oepkgs.net/openeuler/openeuler:22.03-lts-sp4
-
-####################### os #######################
-RUN yum clean all && \
- yum makecache && \
- yum install -y \
- kmod \
- sudo \
- wget \
- curl \
- cmake \
- make \
- git \
- vim \
- gcc && \
- yum clean all
-
-####################### python #######################
-WORKDIR /root
-RUN wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-py311_25.1.1-2-Linux-aarch64.sh && \
- bash /root/Miniconda3-py311_25.1.1-2-Linux-aarch64.sh -b && \
- rm /root/Miniconda3-py311_25.1.1-2-Linux-aarch64.sh
-ENV PATH="/root/miniconda3/bin:$PATH"
-ENV PYTHONPATH="/root/miniconda3/lib/python3.11/site-packages"
-RUN pip config set global.index-url 'https://pypi.tuna.tsinghua.edu.cn/simple' && \
- pip config set global.trusted-host pypi.tuna.tsinghua.edu.cn
-
-####################### CANN #######################
-WORKDIR /root
-RUN echo "UserName=HwHiAiUser" >> /etc/ascend_install.info && \
- echo "UserGroup=HwHiAiUser" >> /etc/ascend_install.info && \
- echo "Firmware_Install_Type=full" >> /etc/ascend_install.info && \
- echo "Firmware_Install_Path_Param=/usr/local/Ascend" >> /etc/ascend_install.info && \
- echo "Driver_Install_Type=full" >> /etc/ascend_install.info && \
- echo "Driver_Install_Path_Param=/usr/local/Ascend" >> /etc/ascend_install.info && \
- echo "Driver_Install_For_All=no" >> /etc/ascend_install.info && \
- echo "Driver_Install_Mode=normal" >> /etc/ascend_install.info && \
- echo "Driver_Install_Status=complete" >> /etc/ascend_install.info
-RUN curl -s "https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.0.0/Ascend-cann-toolkit_8.0.0_linux-aarch64.run" -o Ascend-cann-toolkit.run && \
- curl -s "https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.0.0/Ascend-cann-kernels-910b_8.0.0_linux-aarch64.run" -o Ascend-cann-kernels-910b.run && \
- curl -s "https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/CANN/CANN%208.0.0/Ascend-cann-nnrt_8.0.0_linux-aarch64.run" -o Ascend-cann-nnrt.run && \
- chmod a+x *.run && \
- bash /root/Ascend-cann-toolkit.run --install -q && \
- bash /root/Ascend-cann-kernels-910b.run --install -q && \
- bash /root/Ascend-cann-nnrt.run --install -q && \
- rm /root/*.run
-RUN echo "source /usr/local/Ascend/nnrt/set_env.sh" >> /root/.bashrc && \
- echo "source /usr/local/Ascend/ascend-toolkit/set_env.sh" >> /root/.bashrc
-
-####################### dev env #######################
-RUN pip install --no-cache-dir \
- cmake>=3.26 \
- decorator \
- ray==2.42.1 \
- protobuf==3.20.0 \
- ml_dtypes \
- wheel \
- setuptools \
- wrap \
- deprecated \
- packaging \
- ninja \
- "setuptools-scm>=8" \
- numpy \
- build
-
-WORKDIR /workspace
-
-RUN git clone -b br_infer_deepseek_os https://gitee.com/mindspore/mindformers.git /workspace/mindformers && \
- cd mindformers && \
- sed -i 's/-i https:\/\/pypi.tuna.tsinghua.edu.cn\/simple//g' build.sh && \
- bash build.sh && \
- PACKAGE_PATH=$(python3 -c "import site; print(site.getsitepackages()[0])") && \
- cp -a research "$PACKAGE_PATH" && \
- rm -rf /workspace/mindformers
-
-RUN git clone https://gitee.com/mindspore/golden-stick.git /workspace/golden-stick && \
- cd golden-stick && \
- bash build.sh && \
- pip install --no-cache-dir /workspace/golden-stick/output/*.whl && \
- rm -rf /workspace/golden-stick
-
-ENV USE_TORCH="FALSE"
-ENV USE_TF="FALSE"
-RUN git clone -b v0.6.6.post1 https://gitee.com/mirrors/vllm.git /workspace/vllm && \
- cd vllm && \
- VLLM_TARGET_DEVICE=empty pip install --no-cache-dir . && \
- rm -rf /workspace/vllm
-
-RUN git clone https://openi.pcl.ac.cn/OpenI/MSAdapter.git /workspace/msadapter && \
- cd /workspace/msadapter && \
- bash scripts/build_and_reinstall.sh && \
- rm -rf /workspace/msadapter
-
-ADD . /workspace/vllm_mindspore
-RUN cd /workspace/vllm_mindspore && \
- pip install --no-cache-dir -r requirements.txt && \
- pip install . && \
- rm -rf /workspace/vllm_mindspore
-
-RUN wget -O mindspore-2.5.0-cp311-cp311-linux_aarch64.whl \
-https://repo.mindspore.cn/mindspore/mindspore/version/202503/20250303/br_infer_deepseek_os_20250303004707_705727d59236c8c197b25ad0e464c4908434d42f_newest/unified/aarch64/mindspore-2.5.0-cp311-cp311-linux_aarch64.whl && \
-pip install --no-cache-dir mindspore-2.5.0-cp311-cp311-linux_aarch64.whl && \
-rm -f mindspore-2.5.0-cp311-cp311-linux_aarch64.whl
-
-RUN pip uninstall torch torch-npu torchvision -y
-
-CMD ["bash"]
\ No newline at end of file
diff --git a/README.md b/README.md
index 5ea56601b6088bb737fde97ca01261d5bce9f4ae..1ef98bed36235252d64ab18e1b0fb88a061d28b2 100644
--- a/README.md
+++ b/README.md
@@ -1,114 +1,69 @@
-# vllm-mindspore
+
+vLLM MindSpore
+
-## Overview
+
+| 关于MindSpore | vLLM MindSpore SIG | 问题反馈 |
+
-The `vllm-mindspore`is a integration for running vLLM on the MindSpore framework.
-
-This is the recommended solution for supporting the MindSpore within the vLLM community. It provides deep integration with the MindSpore framework, offering efficient computation and optimization support for vLLM, enabling seamless operation on MindSpore.
-
-By using the `vllm-mindspore`, popular open-source models, can run seamlessly for training and inference on the MindSpore framework.
+
+English | 中文
+
---
+*最新消息* 🔥
-## Prerequisites
-
-- Hardware: Atlas A2/A3
-- Software:
- - Python >= 3.9
- - CANN >= 8.0.0
- - MindSpore >=2.5.0
+- [2025/06] 适配vLLM [v0.8.3](https://github.com/vllm-project/vllm/releases/tag/v0.8.3),新增支持vLLM V1架构、Qwen3大模型。
+- [2025/04] 完成vLLM [v0.7.3](https://github.com/vllm-project/vllm/releases/tag/v0.7.3)适配,新增支持Automatic Prefix Caching、Chunked Prefill、Multi-step Scheduling、MTP等特性。联合openEuler社区和上海交通大学,实现DeepSeek全栈开源单机推理部署,你可以在[这里](https://www.openeuler.org/zh/news/openEuler/20240421-jd/20240421-jd.html)阅读详细报道。
+- [2025/03] 完成vLLM [v0.6.6.post1](https://github.com/vllm-project/vllm/releases/tag/v0.6.6.post1)适配,支持采用`vllm.entrypoints`部署基于MindSpore的DeepSeek-V3/R1、Qwen2.5等大模型推理服务。联合openEuler社区和北京大学,发布全栈开源DeepSeek推理方案,你可以在[这里](https://news.pku.edu.cn/xwzh/e13046c47d03471c8cebb950bd1f4598.htm)阅读详细报道。
+- [2025/02] MindSpore社区正式创建了[mindspore/vllm-mindspore](https://gitee.com/mindspore/vllm-mindspore)代码,旨在将MindSpore大模型推理能力接入vLLM。
---
-## Getting Started
-
-### Installation
-
-#### Installation from source code
-
-Install from source code. [Wiki Installation.](https://gitee.com/mindspore/vllm-mindspore/wikis/Getting%20Started/Installation)
-
-#### Set up using Docker
-
-##### Pre-built images
-
-```shell
-docker pull hub.oepkgs.net/oedeploy/openeuler/aarch64/mindspore:v1.0
-```
-
-##### Build image from source
-
-```shell
-docker build --network=host .
-```
-
-### Inference and Serving
-
-#### Offline Inference
-
-You can run vllm_mindspore in your own code on a list of prompts.
-
-```bash
-export ASCEND_TOTAL_MEMORY_GB=64 # Based on the ascend device.
-```
-
-```python
-
-import vllm_mindspore # Add this line on the top of script.
+# 简介
-from vllm import LLM, SamplingParams
+vLLM MindSpore插件(`vllm-mindspore`)是一个由[MindSpore社区](https://www.mindspore.cn/)孵化的vLLM后端插件。其将基于MindSpore构建的大模型推理能力接入[vLLM](https://github.com/vllm-project/vllm),从而有机整合MindSpore和vLLM的技术优势,提供全栈开源、高性能、易用的大模型推理解决方案。
-# Sample prompts.
-prompts = [
- "I am",
- "Today is",
- "What is"
-]
+vLLM MindSpore插件以将MindSpore大模型接入vLLM,并实现服务化部署为功能目标。其遵循以下设计原则:
-# Create a sampling params object.
-sampling_params = SamplingParams(temperature=0.0, top_p=0.95)
+- 接口兼容:支持vLLM原生的API和服务部署接口,避免新增配置文件或接口,降低用户学习成本和确保易用性。
+- 最小化侵入式修改:尽可能避免侵入式修改vLLM代码,以保障系统的可维护性和可演进性。
+- 组件解耦:最小化和规范化MindSpore大模型组件和vLLM服务组件的耦合面,以利于多种MindSpore大模型套件接入。
-# Create an LLM.
-llm = LLM(model="Qwen/Qwen2.5-32B-Instruct", tensor_parallel_size=8)
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-# Print the outputs.
-for output in outputs:
- prompt = output.prompt
- generated_text = output.outputs[0].text
- print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+基于上述设计原则,vLLM MindSpore采用如下图所示的系统架构,分组件类别实现vLLM与MindSpore的对接:
-```
+- 服务化组件:通过将LLM Engine、Scheduler等服务化组件中的PyTorch API调用映射至MindSpore能力调用,继承支持包括Continuous Batching、PagedAttention在内的服务化功能。
+- 大模型组件:通过注册或替换模型、网络层、自定义算子等组件,将MindSpore Transformers、MindSpore One等MindSpore大模型套件和自定义大模型接入vLLM。
-#### Serving(OpenAI-Compatible)
+
+

+
-You can start the server via the vllm_mindspore command:
+vLLM MindSpore采用vLLM社区推荐的插件机制,实现能力注册。未来期望遵循[RPC Multi-framework support for vllm](https://gitee.com/mindspore/vllm-mindspore/issues/IBTNRG)所述原则。
-`python3 -m vllm_mindspore.entrypoints vllm.entrypoints.openai.api_server --model "Qwen/Qwen2.5-32B-Instruct" --tensor_parallel_size=8`
+# 环境准备
-To call the server, you can use `curl` or any other HTTP client.
+- 硬件:Atlas 800I A2推理服务器,或Atlas 800T A2推理服务器,已安装必要的驱动程序,并可连接至互联网
+- 操作系统:openEuler或Ubuntu Linux
+- 软件:
+ - Python >= 3.9, < 3.12
+ - CANN >= 8.0.0.beta1
+ - MindSpore
+ - vLLM
-```shell
+注:请参考[版本配套](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_zh_cn/getting_started/installation/installation.md),以获取详细的软件版本配套信息。
-curl http://localhost:8000/v1/completions \
- -H "Content-Type: application/json" \
- -d '{
- "model": "Qwen/Qwen2.5-32B-Instruct",
- "prompt": "MindSpore is",
- "max_tokens": 120,
- "temperature": 0
- }'
+# 快速体验
-```
+请查看[快速体验](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_zh_cn/getting_started/quick_start/quick_start.md)和[安装指南](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_zh_cn/getting_started/installation/installation.md)了解更多。
-## Contributing
+# 贡献
-We welcome and value any contributions and collaborations:
+请参考 [CONTRIBUTING](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_zh_cn/developer_guide/contributing.md) 文档了解更多关于开发环境搭建、功能测试以及 PR 提交规范的信息。
-- Please feel free comments about your usage of vllm_mindspore.
-- Please let us know if you encounter a bug by filing an issue.
+我们欢迎并重视任何形式的贡献与合作,请通过[Issue](https://gitee.com/mindspore/vllm-mindspore/issues)来告知我们您遇到的任何Bug,或提交您的特性需求、改进建议、技术方案。
-## License
+# SIG组织
-Apache License 2.0, as found in the [LICENSE](https://gitee.com/mindspore/vllm_mindspore/blob/master/LICENSE) file.
+- 欢迎加入LLM Inference Serving,参与开源项目共建和产业合作:[https://www.mindspore.cn/community/SIG](https://www.mindspore.cn/community/SIG)
+- SIG例会,双周周三或周四下午,16:30 - 17:30 (UTC+8, [查看您的时区](https://dateful.com/convert/gmt8?t=15))
diff --git a/README_en.md b/README_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..e402fc66a4586956981ce404c90c456d8fc089fb
--- /dev/null
+++ b/README_en.md
@@ -0,0 +1,69 @@
+
+vLLM MindSpore
+
+
+
+| About MindSpore | vLLM MindSpore SIG | Issue Feedback |
+
+
+
+English | 中文
+
+
+---
+*Latest News* 🔥
+
+- [2025/06] Adaptation for vLLM [v0.8.3](https://github.com/vllm-project/vllm/releases/tag/v0.8.3), adding support for the vLLM V1 architecture and the Qwen3 large model.
+- [2025/04] Adaptation for vLLM [v0.7.3](https://github.com/vllm-project/vllm/releases/tag/v0.7.3), adding support for Automatic Prefix Caching, Chunked Prefill, Multi-step Scheduling, and MTP. In collaboration with the openEuler community and Shanghai Jiao Tong University, we achieved a full-stack open-source, single-machine inference deployment for DeepSeek. You can read the detailed report [here](https://www.openeuler.org/zh/news/openEuler/20240421-jd/20240421-jd.html).
+- [2025/03] Adaptation for vLLM [v0.6.6.post1](https://github.com/vllm-project/vllm/releases/tag/v0.6.6.post1), supporting the deployment of MindSpore-based inference services for large models such as DeepSeek-V3/R1 and Qwen2.5 via `vllm.entrypoints`. In collaboration with the openEuler community and Peking University, we released a full-stack open-source DeepSeek inference solution. You can read the detailed report [here](https://news.pku.edu.cn/xwzh/e13046c47d03471c8cebb950bd1f4598.htm).
+- [2025/02] The MindSpore community officially created the [mindspore/vllm-mindspore](https://gitee.com/mindspore/vllm-mindspore) repository, aiming to integrate MindSpore's large model inference capabilities into vLLM.
+
+---
+
+# Overview
+
+vLLM MindSpore (`vllm-mindspore`) is a plugin incubated by the [MindSpore community](https://www.mindspore.cn/en), which aims to integrate MindSpore LLM inference capabilities into [vLLM](https://github.com/vllm-project/vllm). With vLLM MindSpore, the technical strengths of MindSpore and vLLM are combined to provide a full-stack open-source, high-performance, easy-to-use LLM inference solution.
+
+The vLLM MindSpore plugin aims to integrate MindSpore large models into vLLM and enable the deployment of MindSpore-based LLM inference services. It follows these design principles:
+
+- Interface compatibility: support the native APIs and service deployment interfaces of vLLM to avoid adding new configuration files or interfaces, reducing user learning costs and ensuring ease of use.
+- Minimal invasive modifications: minimize invasive modifications to the vLLM code to ensure system maintainability and evolvability.
+- Component decoupling: minimize and standardize the coupling between MindSpore large model components and vLLM service components to facilitate the integration of various MindSpore large model suites.
+
+Based on the above design principles, vLLM MindSpore adopts the system architecture shown in the figure below and implements the integration of vLLM and MindSpore by component category:
+
+- Service components: vLLM MindSpore maps PyTorch API calls in service components including LLMEngine and Scheduler to MindSpore capabilities, inheriting support for service functions like Continuous Batching and PagedAttention.
+- Model components: vLLM MindSpore registers or replaces model components including models, network layers, and custom operators, and integrates MindSpore Transformers, MindSpore One, and other MindSpore large model suites, as well as custom large models, into vLLM.
+
+
+

+
+
+vLLM MindSpore uses the plugin mechanism recommended by the vLLM community to realize capability registration. In the future, we expect to follow principles described in [[RPC] Multi-framework support for vllm](https://gitee.com/mindspore/vllm-mindspore/issues/IBTNRG).
+
+# Prerequisites
+
+- Hardware: Atlas 800I A2 Inference series, or Atlas 800T A2 Training series, with necessary drivers installed and access to the Internet.
+- Operating System: openEuler or Ubuntu Linux.
+- Software:
+ - Python >= 3.9, < 3.12
+ - CANN >= 8.0.0.beta1
+ - MindSpore
+ - vLLM
+
+Note: Please refer to [Version Compatibility](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_en/getting_started/installation/installation.md) for detailed version compatibility information.
+
+# Getting Started
+
+Please refer to [Quick Start](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_en/getting_started/quick_start/quick_start.md) and [Installation](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_en/getting_started/installation/installation.md) for more details.
+
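+As a quick flavor of the workflow, below is a minimal offline-inference sketch (the model ID is a placeholder; note that `vllm_mindspore` must be imported before any vLLM module, following the pattern used in the tests under `tests/st/python`):
+
+```python
+import vllm_mindspore  # must come before any vLLM import
+from vllm import LLM, SamplingParams
+
+prompts = ["I am", "Today is", "What is"]
+sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=32)
+
+# Placeholder model; substitute a checkpoint available in your environment.
+llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", tensor_parallel_size=1)
+outputs = llm.generate(prompts, sampling_params)
+for output in outputs:
+    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
+```
+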
+# Contributing
+
+Please read [CONTRIBUTING](https://gitee.com/mindspore/docs/blob/master/docs/vllm_mindspore/docs/source_en/developer_guide/contributing.md) for details on setting up the development environment, functional testing, and submitting PRs.
+
+We welcome and value any form of contribution and cooperation. Please use [Issue](https://gitee.com/mindspore/vllm-mindspore/issues) to inform us of any bugs you encounter, or to submit your feature requests, improvement suggestions, and technical solutions.
+
+# SIG
+
+- You are welcome to join the vLLM MindSpore SIG to participate in open-source project co-construction and industrial cooperation: [https://www.mindspore.cn/community/SIG](https://www.mindspore.cn/community/SIG)
+- SIG meetings are held biweekly, on Wednesday or Thursday afternoon, 16:30 - 17:30 (UTC+8, [Convert to your timezone](https://dateful.com/convert/gmt8?t=15))
diff --git a/codecheck_toolkits/vllm_codecheck.sh b/codecheck_toolkits/vllm_codecheck.sh
index c91087fe19df09b701de0b66d52a64f008ccb037..ef9b4c4b8a0cd837fa73f4c1ada34ea0708e9691 100644
--- a/codecheck_toolkits/vllm_codecheck.sh
+++ b/codecheck_toolkits/vllm_codecheck.sh
@@ -6,7 +6,7 @@ RET_FLAG=0
# yapf check
-MERGEBASE="$(git merge-base origin/master HEAD)"
+MERGEBASE="$(git merge-base origin/develop HEAD)"
if ! git diff --cached --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &> /dev/null; then
git diff --cached --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
yapf --diff --recursive --parallel --exclude tests/
@@ -57,6 +57,15 @@ else
fi
# mypy check type
+# download vllm
+
+cd codecheck_toolkits
+git clone https://gitee.com/mirrors/vllm.git -b v0.8.3
+cd -
+
+export MYPYPATH=codecheck_toolkits/vllm
+
+# check
PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
@@ -73,4 +82,4 @@ fi
rm -f pyproject.toml
-exit 0
+exit ${RET_FLAG}
diff --git a/docs/arch.cn.png b/docs/arch.cn.png
new file mode 100644
index 0000000000000000000000000000000000000000..b2c2d0aedfbb3bad25e50071d8070e1f5c3f447d
Binary files /dev/null and b/docs/arch.cn.png differ
diff --git a/docs/arch.png b/docs/arch.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc3b524ca3487ae92431c58157175b4ddcb42725
Binary files /dev/null and b/docs/arch.png differ
diff --git a/examples/tool_chat_template_deepseekv3_zh_prompt.jinja b/examples/tool_chat_template_deepseekv3_zh_prompt.jinja
new file mode 100644
index 0000000000000000000000000000000000000000..69d542635eadad139447668b2640f7ed472ada58
--- /dev/null
+++ b/examples/tool_chat_template_deepseekv3_zh_prompt.jinja
@@ -0,0 +1,101 @@
+{% if not add_generation_prompt is defined %}
+ {% set add_generation_prompt = false %}
+{% endif %}
+{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %}
+{%- for message in messages %}
+ {%- if message['role'] == 'system' %}
+ {%- if ns.is_first_sp %}
+ {% set ns.system_prompt = ns.system_prompt + message['content'] %}
+ {% set ns.is_first_sp = false %}
+ {%- else %}
+ {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
+ {%- endif %}
+ {%- endif %}
+{%- endfor %}
+
+{#- Adapted from https://github.com/sgl-project/sglang/blob/main/examples/chat_template/tool_chat_template_deepseekr1.jinja #}
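+{#- Expected inputs (sketch): `messages` follows the OpenAI-style chat format with
+    roles system / user / assistant / tool; `tools` is a list of function schemas,
+    each rendered verbatim via `tojson`; an assistant message may carry `tool_calls`,
+    where each entry provides `type`, `function.name` and `function.arguments`. -#}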
+{% if tools is defined and tools is not none %}
+ {% set tool_ns = namespace(text='你可以调用工具函数。'
+ '当你需要调用工具时,你必须严格遵守下面的格式输出:'
+ '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>FUNCTION_NAME\n'
+ '```json\n{"param1": "value1", "param2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁end|>\n\n'
+ '不遵守上面的格式就不能成功调用工具,是错误答案。\n'
+ '错误答案举例1:function<|tool▁sep|>FUNCTION_NAME\n```json\n{"param1": "value1", "param2": "value2"}\n```'
+ '<|tool▁call▁end|><|tool▁calls▁end|>\n'
+ '错误1原因:没有使用<|tool▁calls▁begin|>、<|tool▁call▁begin|>,不符合格式。\n'
+ '错误答案举例2:<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>FUNCTION_NAME\n```json\n'
+ '{"param1": "value1", "param2": "value2"}\n```<|tool▁call▁end|>\n'
+ '错误2原因:没有使用<|tool▁calls▁end|>,不符合格式。\n'
+ '错误答案举例3:<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>FUNCTION_NAME\n```json\n'
+ '{"param1": "value1", "param2": "value2"}\n```<|tool▁call▁end|><|tool▁calls▁begin|>\n'
+ '错误3原因:最后一个<|tool▁calls▁begin|>应为<|tool▁calls▁end|>,不符合格式。'
+ '## Tools\n\n### Function\n\nYou have the following functions available:\n\n') %}
+ {% for tool in tools %}
+ {% set tool_ns.text = tool_ns.text + '\n```json\n' + (tool | tojson) + '\n```\n' %}
+ {% endfor %}
+ {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
+{% endif %}
+
+{{ bos_token }}
+{{ ns.system_prompt }}
+{%- for message in messages %}
+ {% set content = message['content'] %}
+ {%- if message['role'] == 'user' %}
+ {%- set ns.is_tool = false -%}
+ {%- set ns.is_first = false -%}
+ {%- set ns.is_last_user = true -%}
+ {{'<|User|>' + content + '<|Assistant|>'}}
+ {%- endif %}
+ {%- if message['role'] == 'assistant' %}
+        {% if '</think>' in content %}
+            {% set content = content.split('</think>')[-1] %}
+ {% endif %}
+ {% endif %}
+ {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
+ {%- set ns.is_last_user = false -%}
+ {%- if ns.is_tool %}
+ {{'<|tool▁outputs▁end|>'}}
+ {%- endif %}
+ {%- set ns.is_first = false %}
+ {%- set ns.is_tool = false -%}
+ {%- set ns.is_output_first = true %}
+ {%- for tool in message['tool_calls'] %}
+ {%- if not ns.is_first %}
+ {%- if content is none %}
+ {{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
+ {%- else %}
+ {{content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
+ {%- endif %}
+ {%- set ns.is_first = true -%}
+ {%- else %}
+ {{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}}
+ {%- endif %}
+ {%- endfor %}
+ {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}
+ {%- endif %}
+ {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%}
+ {%- set ns.is_last_user = false -%}
+ {%- if ns.is_tool %}
+ {{'<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}}
+ {%- set ns.is_tool = false -%}
+ {%- else %}
+ {{content + '<|end▁of▁sentence|>'}}
+ {%- endif %}
+ {%- endif %}
+ {%- if message['role'] == 'tool' %}
+ {%- set ns.is_last_user = false -%}
+ {%- set ns.is_tool = true -%}
+ {%- if ns.is_output_first %}
+ {{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
+ {%- set ns.is_output_first = false %}
+ {%- else %}
+ {{'\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}}
+ {%- endif %}
+ {%- endif %}
+{%- endfor -%}
+{% if ns.is_tool %}
+ {{'<|tool▁outputs▁end|>'}}
+{% endif %}
+{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %}
+ {{'<|Assistant|>'}}
+{% endif %}
diff --git a/install_depend_pkgs.sh b/install_depend_pkgs.sh
index ba0f798868f3ba51e946c93a4e0fea8f0d640748..302cad80a1776a5055bb8e5a64a907566e387cef 100644
--- a/install_depend_pkgs.sh
+++ b/install_depend_pkgs.sh
@@ -40,7 +40,7 @@ vllm_dir=vllm-v0.8.3
if [ ! -d "$vllm_dir" ]; then
git clone https://github.com/vllm-project/vllm.git -b v0.8.3 "$vllm_dir"
cd "$vllm_dir" || { echo "Failed to git clone vllm!"; exit 1; }
- git apply ../../vllm_dp/dp_scale_out.patch
+ git apply $script_dir/vllm_dp/dp_scale_out.patch
else
echo "The $vllm_dir folder already exists and will not be re-downloaded."
cd "$vllm_dir" || { echo "Failed to git clone vllm!"; exit 1; }
@@ -49,7 +49,7 @@ pip uninstall msadapter -y
pip uninstall vllm -y
pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
VLLM_TARGET_DEVICE=empty python setup.py install || { echo "Failed to install vllm"; exit 1; }
-pip uninstall torch torch-npu torchvision -y
+pip uninstall torch torch-npu torchvision torchaudio -y
cd ..
@@ -64,10 +64,10 @@ pip uninstall mindspore -y && pip install "$mindspore_name" || { echo "Failed to
echo "========= Installing mindformers"
-mf_dir=mindformers-dev
+mf_dir=mindformers-os
if [ ! -d "$mf_dir" ]; then
- git clone https://gitee.com/mindspore/mindformers.git -b dev "$mf_dir"
- git checkout dfb8aa3a59401495b2d8c8c107d46fe0d36c949a
+ git clone https://gitee.com/mindspore/mindformers.git -b br_infer_deepseek_os "$mf_dir"
+ git checkout 849e943230b7f30317654327109df1dd7acd4b4c
else
echo "The $mf_dir folder already exists and will not be re-downloaded."
fi
@@ -101,3 +101,4 @@ pip uninstall msadapter -y && pip install . || { echo "Failed to install msadap
cd ..
echo "========= All dependencies installed successfully!"
+echo -e "[\033[0;34mnotice\033[0m] Please run: export PYTHONPATH=$(pwd)/$mf_dir/:\$PYTHONPATH"
diff --git a/tests/mindformers b/tests/mindformers
index f046081e40be777eb799afee10495b51cdb2f3c1..849e943230b7f30317654327109df1dd7acd4b4c 160000
--- a/tests/mindformers
+++ b/tests/mindformers
@@ -1 +1 @@
-Subproject commit f046081e40be777eb799afee10495b51cdb2f3c1
+Subproject commit 849e943230b7f30317654327109df1dd7acd4b4c
diff --git a/tests/st/python/cases_parallel/multilora_inference.py b/tests/st/python/cases_parallel/multilora_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e2129a195dabdf6e6dba315571fcf4a04883d88
--- /dev/null
+++ b/tests/st/python/cases_parallel/multilora_inference.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+This example shows how to use the multi-LoRA functionality
+for offline inference.
+
+"""
+import pytest
+import os
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+ "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+ "MS_ENABLE_LCCL": "off",
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "MS_ALLOC_CONF": "enable_vmm:True",
+ "LCCL_DETERMINISTIC": "1",
+ "HCCL_DETERMINISTIC": "true",
+ "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+ "ATB_LLM_LCOC_ENABLE": "0",
+ "VLLM_USE_V1": "1",
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from typing import List, Optional, Tuple
+
+from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
+from vllm.lora.request import LoRARequest
+
+
+def create_test_prompts(
+ lora_path: str
+) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
+ """Create a list of test prompts with their sampling parameters.
+ """
+ return [
+ ("违章停车与违法停车是否有区别?",
+ SamplingParams(temperature=0.0, top_p=1, top_k=-1,
+ max_tokens=10), LoRARequest("sql-lora1", 1,
+ lora_path)),
+ ]
+
+
+def process_requests(engine: LLMEngine,
+ test_prompts: List[Tuple[str, SamplingParams,
+ Optional[LoRARequest]]]):
+ """Continuously process a list of prompts and handle the outputs."""
+ request_id = 0
+
+ while test_prompts or engine.has_unfinished_requests():
+ if test_prompts:
+ prompt, sampling_params, lora_request = test_prompts.pop(0)
+ engine.add_request(str(request_id),
+ prompt,
+ sampling_params,
+ lora_request=lora_request)
+ request_id += 1
+
+ request_outputs: List[RequestOutput] = engine.step()
+ for request_output in request_outputs:
+ if request_output.finished:
+ print(f'text is: {request_output.outputs[0].text}', flush=True)
+ assert " 从法律上来说,违章停车和违法" in request_output.outputs[0].text
+
+
+def initialize_engine() -> LLMEngine:
+ """Initialize the LLMEngine."""
+ # max_loras: controls the number of LoRAs that can be used in the same
+ # batch. Larger numbers will cause higher memory usage, as each LoRA
+ # slot requires its own preallocated tensor.
+ # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
+ # numbers will cause higher memory usage. If you know that all LoRAs will
+ # use the same rank, it is recommended to set this as low as possible.
+ # max_cpu_loras: controls the size of the CPU LoRA cache.
+ engine_args = EngineArgs(
+ model="/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Instruct",
+ enable_lora=True,
+ max_loras=1,
+ max_lora_rank=64,
+ max_cpu_loras=2,
+ max_num_seqs=256,
+ max_model_len=256,
+ max_num_batched_tokens=400)
+ return LLMEngine.from_engine_args(engine_args)
+
+
+def test_multilora_inference():
+ """test function that sets up and runs the prompt processing."""
+ engine = initialize_engine()
+ lora_path = "/home/workspace/mindspore_dataset/weight/Qwen2.5-7B-Lora-Law"
+ test_prompts = create_test_prompts(lora_path)
+ process_requests(engine, test_prompts)
+ env_manager.unset_all()
diff --git a/tests/st/python/cases_parallel/similarity.py b/tests/st/python/cases_parallel/similarity.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfdae0d90d39150efd4e650160921ecf663e7bf4
--- /dev/null
+++ b/tests/st/python/cases_parallel/similarity.py
@@ -0,0 +1,58 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import math
+
+import jieba
+import numpy as np
+
+
+def _get_all_words(standard_cut_infer_ret_list, test_cut_infer_ret_list):
+ all_words = []
+ for s_cut in standard_cut_infer_ret_list:
+ if s_cut not in all_words:
+ all_words.append(s_cut)
+ for t_cut in test_cut_infer_ret_list:
+ if t_cut not in all_words:
+ all_words.append(t_cut)
+ return all_words
+
+
+def _get_word_vector(standard_cut_infer_ret_list, test_cut_infer_ret_list,
+ all_words):
+ la_standard = []
+ lb_test = []
+ for word in all_words:
+ la_standard.append(standard_cut_infer_ret_list.count(word))
+ lb_test.append(test_cut_infer_ret_list.count(word))
+ return la_standard, lb_test
+
+
+def _get_calculate_cos(la_standard, lb_test):
+ laa = np.array(la_standard)
+ lbb = np.array(lb_test)
+ cos = (np.dot(laa, lbb.T)) / ((math.sqrt(np.dot(laa, laa.T))) *
+ (math.sqrt(np.dot(lbb, lbb.T))))
+ return np.round(cos, 2)
+
+
+def compare_distance(x1, x2, bench_sim=0.95):
+ """compare distance"""
+ y1 = list(jieba.cut(x1))
+ y2 = list(jieba.cut(x2))
+ all_words = _get_all_words(y1, y2)
+ laa, lbb = _get_word_vector(y1, y2, all_words)
+ sim = _get_calculate_cos(laa, lbb)
+ print("calculate sim is:{}".format(str(sim)))
+ assert sim >= bench_sim
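+
+
+if __name__ == "__main__":
+    # Standalone sanity check (a minimal sketch, not part of the test suite):
+    # identical texts yield a cosine similarity of 1.0, which satisfies the
+    # default benchmark of 0.95.
+    sample_text = "介绍下北京故宫"
+    compare_distance(sample_text, sample_text)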
diff --git a/tests/st/python/test_vllm_deepseek_bf16_part_v1.py b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part.py
similarity index 89%
rename from tests/st/python/test_vllm_deepseek_bf16_part_v1.py
rename to tests/st/python/cases_parallel/vllm_deepseek_bf16_part.py
index 7a88aa370bb6ce1dabf0b1c8a384e7abed484de7..6c29cc4c9fd50d8d91b20fe4af7bb1529c88a3ab 100644
--- a/tests/st/python/test_vllm_deepseek_bf16_part_v1.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part.py
@@ -17,7 +17,7 @@
"""test mf deepseek r1."""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -27,14 +27,12 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "on",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
"ATB_LLM_LCOC_ENABLE": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "VLLM_USE_V1": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -42,9 +40,6 @@ import vllm_mindspore
from vllm import LLM, SamplingParams
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1_bf16():
"""
test case deepseek r1 bf16
@@ -60,7 +55,7 @@ def test_deepseek_r1_bf16():
# Create an LLM.
llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-bf16",
- trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8, max_model_len=4096)
+ trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, max_model_len=4096)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/st/python/test_vllm_deepseek_bf16_part.py b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py
similarity index 83%
rename from tests/st/python/test_vllm_deepseek_bf16_part.py
rename to tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py
index c772e16da5a56f64bd48f60f1a991193adabe00e..4d4fb5c0f9782e296da5553f5bc3037ee67ed3dc 100644
--- a/tests/st/python/test_vllm_deepseek_bf16_part.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
+# isort:skip_file
# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
#
@@ -17,7 +18,7 @@
"""test mf deepseek r1."""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -27,15 +28,11 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "on",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
- "ATB_LLM_LCOC_ENABLE": "0",
- "VLLM_USE_V1": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "ATB_LLM_LCOC_ENABLE": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -43,9 +40,6 @@ import vllm_mindspore
from vllm import LLM, SamplingParams
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1_bf16():
"""
test case deepseek r1 bf16
@@ -60,8 +54,12 @@ def test_deepseek_r1_bf16():
sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
# Create an LLM.
- llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-bf16",
- trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8)
+ llm = LLM(
+ model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-bf16",
+ trust_remote_code=True,
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=2,
+ max_model_len=33 * 1024)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/st/python/cases_parallel/vllm_deepseek_gptq_a16w4.py b/tests/st/python/cases_parallel/vllm_deepseek_gptq_a16w4.py
new file mode 100644
index 0000000000000000000000000000000000000000..968f805ba77d022abf2aa547a98116d96e1dc9ad
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_deepseek_gptq_a16w4.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+# isort: skip_file
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf deepseek r1 gptq int4 quantization."""
+import os
+import yaml
+import pytest
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+ "MINDFORMERS_MODEL_CONFIG": "./config/predict_deepseek_r1_671b_a16w4.yaml",
+ "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+ "vLLM_MODEL_BACKEND": "MindFormers",
+ "MS_ENABLE_LCCL": "off",
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "MS_ALLOC_CONF": "enable_vmm:True",
+ "LCCL_DETERMINISTIC": "1",
+ "HCCL_DETERMINISTIC": "true",
+ "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+ "ATB_LLM_LCOC_ENABLE": "0",
+ "VLLM_USE_V1": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore # noqa: F401, E402
+from vllm import LLM, SamplingParams # noqa: E402
+
+
+def test_deepseek_r1_gptq_a16w4():
+ """
+ test case deepseek r1 a16w4
+ """
+ yaml_path = "./config/predict_deepseek_r1_671b.yaml"
+ a16w4_yaml = "./config/predict_deepseek_r1_671b_a16w4.yaml"
+ with open(yaml_path, 'r', encoding='utf-8') as file:
+ content = yaml.safe_load(file)
+ model_config = content["model"]["model_config"]
+ model_config["quantization_config"] = {"quant_method": "gptq-pergroup"}
+ content["model"]["model_config"] = model_config
+
+ with open(a16w4_yaml, 'w', encoding='utf-8') as file:
+ yaml.dump(content, file, allow_unicode=True, sort_keys=False)
+
+ # Sample prompts.
+ prompts = [
+ "介绍下北京故宫",
+ ]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=1024, top_k=1)
+
+ # Create an LLM.
+ llm = LLM(
+ model=
+ "/home/workspace/mindspore_dataset/weight/DeepSeekR1_gptq-pergroup_safetensors",
+ trust_remote_code=True,
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=4,
+ max_model_len=4096)
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+ # Print the outputs.
+ for i, output in enumerate(outputs):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+ assert "博物院christianాలు sic辨" in generated_text
+
+ # unset env
+ env_manager.unset_all()
diff --git a/tests/st/python/test_vllm_deepseek_osl.py b/tests/st/python/cases_parallel/vllm_deepseek_osl.py
similarity index 66%
rename from tests/st/python/test_vllm_deepseek_osl.py
rename to tests/st/python/cases_parallel/vllm_deepseek_osl.py
index 28b7a81734e8a1a02cc552feb3a827b69923b0a8..fc782b9e3169b0bd59c784c5a4cd1e31257847fa 100644
--- a/tests/st/python/test_vllm_deepseek_osl.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_osl.py
@@ -20,7 +20,7 @@ isort:skip_file
"""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -31,15 +31,12 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "off",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
"ATB_LLM_LCOC_ENABLE": "0",
- "VLLM_USE_V1": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "VLLM_USE_V1": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -47,9 +44,6 @@ import vllm_mindspore # noqa: F401, E402
from vllm import LLM, SamplingParams # noqa: E402
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1():
"""
test case deepseek r1 w8a8
@@ -71,7 +65,45 @@ def test_deepseek_r1():
"/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8-osl",
trust_remote_code=True,
gpu_memory_utilization=0.9,
- tensor_parallel_size=8,
+ tensor_parallel_size=2,
+ max_model_len=4096)
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+ # Print the outputs.
+ for i, output in enumerate(outputs):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+ assert "博物院" in generated_text
+
+ # unset env
+ env_manager.unset_all()
+
+
+def test_deepseek_r1_mss():
+ """
+ test case deepseek r1 w8a8 mss
+ """
+
+ # Sample prompts.
+ prompts = [
+ "介绍下北京故宫",
+ ]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0,
+ max_tokens=10,
+ top_k=1)
+
+ # Create an LLM.
+ llm = LLM(
+ model=
+ "/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8-osl",
+ trust_remote_code=True,
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=2,
+ num_scheduler_steps=8,
max_model_len=4096)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
diff --git a/tests/st/python/test_vllm_deepseek_part.py b/tests/st/python/cases_parallel/vllm_deepseek_part.py
similarity index 90%
rename from tests/st/python/test_vllm_deepseek_part.py
rename to tests/st/python/cases_parallel/vllm_deepseek_part.py
index 2e4cdec45d9c17a085bc8321aabdc588c26cf99a..7ef3e8901bca7157ff051bf94a764d4ee8a983ef 100644
--- a/tests/st/python/test_vllm_deepseek_part.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_part.py
@@ -17,7 +17,7 @@
"""test mf deepseek r1."""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -27,15 +27,12 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "on",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
"ATB_LLM_LCOC_ENABLE": "0",
- "VLLM_USE_V1": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "VLLM_USE_V1": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -43,9 +40,6 @@ import vllm_mindspore
from vllm import LLM, SamplingParams
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1():
"""
test case deepseek r1 w8a8
@@ -61,7 +55,7 @@ def test_deepseek_r1():
# Create an LLM.
llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8",
- trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8)
+ trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, max_model_len=4096)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
@@ -76,9 +70,7 @@ def test_deepseek_r1():
# unset env
env_manager.unset_all()
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
+
def test_deepseek_mtp():
"""
test case deepseek mtp with main model of r1-w8a8
@@ -94,7 +86,7 @@ def test_deepseek_mtp():
# Create an LLM.
llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-MTP",
- trust_remote_code=True, gpu_memory_utilization=0.7, tensor_parallel_size=8, max_model_len=4096,
+ trust_remote_code=True, gpu_memory_utilization=0.7, tensor_parallel_size=2, max_model_len=4096,
speculative_config={"num_speculative_tokens": 1})
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
diff --git a/tests/st/python/test_vllm_deepseek_part_v1.py b/tests/st/python/cases_parallel/vllm_deepseek_part_v1.py
similarity index 89%
rename from tests/st/python/test_vllm_deepseek_part_v1.py
rename to tests/st/python/cases_parallel/vllm_deepseek_part_v1.py
index 9f5ecd72c4d022cc43e86adfb91b390273235f6e..e5eb917a6a203ae81964f50da993c285ee2df2c5 100644
--- a/tests/st/python/test_vllm_deepseek_part_v1.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_part_v1.py
@@ -17,7 +17,7 @@
"""test mf deepseek r1."""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -27,14 +27,11 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "off",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
- "ATB_LLM_LCOC_ENABLE": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "ATB_LLM_LCOC_ENABLE": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -42,9 +39,6 @@ import vllm_mindspore
from vllm import LLM, SamplingParams
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1():
"""
test case deepseek r1 w8a8
@@ -60,7 +54,7 @@ def test_deepseek_r1():
# Create an LLM.
llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8",
- trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8, max_model_len=4096)
+ trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, max_model_len=4096)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/st/python/test_vllm_deepseek_smoothquant.py b/tests/st/python/cases_parallel/vllm_deepseek_smoothquant.py
similarity index 88%
rename from tests/st/python/test_vllm_deepseek_smoothquant.py
rename to tests/st/python/cases_parallel/vllm_deepseek_smoothquant.py
index 6937245f769675f62e1ee4d4da8037f1076b779e..48d2441adf2e5459ad80b95c518cf9529b58a122 100644
--- a/tests/st/python/test_vllm_deepseek_smoothquant.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_smoothquant.py
@@ -17,7 +17,7 @@
"""test mf deepseek r1 smoothquant."""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -27,15 +27,12 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "off",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
"ATB_LLM_LCOC_ENABLE": "0",
- "VLLM_USE_V1": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "VLLM_USE_V1": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -43,9 +40,6 @@ import vllm_mindspore
from vllm import LLM, SamplingParams
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1():
"""
test case deepseek r1 w8a8
@@ -61,7 +55,7 @@ def test_deepseek_r1():
# Create an LLM.
llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8-smoothquant-newconfig",
- trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8, max_model_len=4096)
+ trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, max_model_len=4096)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/st/python/test_vllm_deepseek_smoothquant_mss.py b/tests/st/python/cases_parallel/vllm_deepseek_smoothquant_mss.py
similarity index 88%
rename from tests/st/python/test_vllm_deepseek_smoothquant_mss.py
rename to tests/st/python/cases_parallel/vllm_deepseek_smoothquant_mss.py
index 3788bcb7dccd436517be5e00755b1ecbde106116..111c91e4bcdd4a6467ce0db0faec88599d6ee7f0 100644
--- a/tests/st/python/test_vllm_deepseek_smoothquant_mss.py
+++ b/tests/st/python/cases_parallel/vllm_deepseek_smoothquant_mss.py
@@ -17,7 +17,7 @@
"""test mf deepseek r1 smoothquant."""
import pytest
import os
-from . import set_env
+from tests.st.python import set_env
env_manager = set_env.EnvVarManager()
# def env
@@ -27,15 +27,12 @@ env_vars = {
"vLLM_MODEL_BACKEND": "MindFormers",
"MS_ENABLE_LCCL": "off",
"HCCL_OP_EXPANSION_MODE": "AIV",
- "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
"MS_ALLOC_CONF": "enable_vmm:True",
"LCCL_DETERMINISTIC": "1",
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
"ATB_LLM_LCOC_ENABLE": "0",
- "VLLM_USE_V1": "0",
- "HCCL_IF_BASE_PORT": "60000",
- "LCAL_COMM_ID": "127.0.0.1:10068"
+ "VLLM_USE_V1": "0"
}
# set env
env_manager.setup_ai_environment(env_vars)
@@ -43,9 +40,6 @@ import vllm_mindspore
from vllm import LLM, SamplingParams
-@pytest.mark.level0
-@pytest.mark.platform_arm_ascend910b_training
-@pytest.mark.env_single
def test_deepseek_r1_mss():
"""
test case deepseek r1 w8a8 mss
@@ -61,7 +55,8 @@ def test_deepseek_r1_mss():
# Create an LLM.
llm = LLM(model="/home/workspace/mindspore_dataset/weight/DeepSeek-R1-W8A8-smoothquant-newconfig",
- trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=8, num_scheduler_steps=8)
+ trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, num_scheduler_steps=8,
+ max_model_len=4096)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
diff --git a/tests/st/python/cases_parallel/vllm_llama3.py b/tests/st/python/cases_parallel/vllm_llama3.py
new file mode 100644
index 0000000000000000000000000000000000000000..656c744d960bbe1c497719de341f9ca7e4907db7
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_llama3.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# isort:skip_file
+"""test vllm llama3."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+ "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+ "MS_ENABLE_LCCL": "off",
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "MS_ALLOC_CONF": "enable_vmm:True",
+ "LCCL_DETERMINISTIC": "1",
+ "HCCL_DETERMINISTIC": "true",
+ "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+ "ATB_LLM_LCOC_ENABLE": "0",
+ "VLLM_USE_V1": "1",
+ "HCCL_IF_BASE_PORT": "60000"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+
+def test_vllm_llama3_8b():
+ """
+ test case llama3.1 8B
+ """
+
+ # Sample prompts.
+ prompts = [
+ "<|start_header_id|>user<|end_header_id|>\n\n将文本分类为中性、负面或正面。 "
+ "\n文本:我认为这次假期还可以。 \n情感:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+ ]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+ # Create an LLM.
+ llm = LLM(
+ model="/home/workspace/mindspore_dataset/weight/Llama-3.1-8B-Instruct",
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=1,
+ max_model_len=4096)
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+ except_list = ['中性']
+ # Print the outputs.
+ for i, output in enumerate(outputs):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+ assert generated_text == except_list[i]
+
+ # unset env
+ env_manager.unset_all()
+
+
+def test_vllm_llama3_1b():
+ """
+ test case llama3.2 1B
+ """
+
+ # Sample prompts.
+ prompts = [
+ "<|start_header_id|>user<|end_header_id|>\n\n将文本分类为中性、负面或正面。 "
+ "\n文本:我认为这次假期还可以。 \n情感:<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
+ ]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+ # Create an LLM.
+ llm = LLM(
+ model="/home/workspace/mindspore_dataset/weight/Llama-3.2-1B-Instruct",
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=1,
+ max_model_len=4096)
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+ except_list = ['中性']
+ # Print the outputs.
+ for i, output in enumerate(outputs):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+ assert generated_text == except_list[i]
+
+ # unset env
+ env_manager.unset_all()
diff --git a/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py
new file mode 100644
index 0000000000000000000000000000000000000000..48de1692134eff0f30e54de79fcabe8b3e4dc52d
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+ "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+ "vLLM_MODEL_BACKEND": "MindFormers",
+ "MS_ENABLE_LCCL": "off",
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "MS_ALLOC_CONF": "enable_vmm:True",
+ "LCCL_DETERMINISTIC": "1",
+ "HCCL_DETERMINISTIC": "true",
+ "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+ "ATB_LLM_LCOC_ENABLE": "0",
+ "VLLM_USE_V1": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+# isort: off
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+# isort: on
+
+
+def test_mf_qwen3():
+ """
+ test case qwen3 8B
+ """
+
+ # Sample prompts.
+ prompts = [
+ "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+ ]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+ # Create an LLM.
+ llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-8B",
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=2)
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+    expect_list = ['好的,我需要分析用户提供的文本“我认为']
+ # Print the outputs.
+ for i, output in enumerate(outputs):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == expect_list[i]
+
+ # unset env
+ env_manager.unset_all()
diff --git a/tests/st/python/cases_parallel/vllm_mf_qwen3_8b_v1.py b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb62ef7af753cda7509f7ef6b96da8c91d2379c
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b_v1.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+ "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+ "vLLM_MODEL_BACKEND": "MindFormers",
+ "MS_ENABLE_LCCL": "off",
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "MS_ALLOC_CONF": "enable_vmm:True",
+ "LCCL_DETERMINISTIC": "1",
+ "HCCL_DETERMINISTIC": "true",
+ "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+ "ATB_LLM_LCOC_ENABLE": "0",
+ "VLLM_USE_V1": "1"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+# isort: off
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+# isort: on
+
+
+def test_mf_qwen3():
+ """
+ test case qwen3 8B
+ """
+
+ # Sample prompts.
+ prompts = [
+ "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+ ]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+ # Create an LLM.
+ llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-8B",
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=2)
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(prompts, sampling_params)
+    expect_list = ['好的,我需要分析用户提供的文本“我认为']
+ # Print the outputs.
+ for i, output in enumerate(outputs):
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == expect_list[i]
+
+ # unset env
+ env_manager.unset_all()
diff --git a/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py b/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..d776c8d93dacc2fc6a3fc28453083c2de9ba320c
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen2.5 vl 7B."""
+import os
+
+from PIL import Image
+
+from tests.st.python import set_env
+from tests.st.python.cases_parallel.similarity import compare_distance
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+ "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+ "HCCL_OP_EXPANSION_MODE": "AIV",
+ "MS_ALLOC_CONF": "enable_vmm:True",
+ "LCCL_DETERMINISTIC": "1",
+ "HCCL_DETERMINISTIC": "true",
+ "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+ "ATB_LLM_LCOC_ENABLE": "0",
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+# isort: off
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+
+# isort: on
+
+PROMPT_TEMPLATE = (
+ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
+ "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+ "What is in the image?<|im_end|>\n"
+ "<|im_start|>assistant\n")
+
+
+def pil_image() -> Image.Image:
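+    # Load the local test image; the relative path resolves against the
+    # directory that pytest is invoked from.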
+ image_path = "images/1080p.jpeg"
+ return Image.open(image_path)
+
+
+def test_qwen2_5_vl_7b_v1():
+ """
+ test case qwen2.5 vl 7B
+ """
+ inputs = [{
+ "prompt": PROMPT_TEMPLATE,
+ "multi_modal_data": {
+ "image": pil_image()
+ },
+ }]
+
+ # Create a sampling params object.
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=128, top_k=1)
+
+ # Create an LLM.
+ llm = LLM(
+ model="/home/workspace/mindspore_dataset/weight/Qwen2.5-VL-7B-Instruct",
+ gpu_memory_utilization=0.9,
+ tensor_parallel_size=2,
+ max_model_len=4096,
+ max_num_seqs=32,
+ max_num_batched_tokens=32)
+    expect_list = [
+ 'The image depicts a serene and picturesque landscape. It features a lush green meadow with '
+ 'wildflowers in the foreground. In the middle ground, there are small wooden huts, possibly used for'
+ ' storage or as simple shelters. Beyond the meadow, there is a calm body of water, likely a lake,'
+ ' surrounded by dense forests. In the background, majestic mountains rise, their peaks partially '
+ 'covered with snow, suggesting a high-altitude location. The sky is partly cloudy, with soft '
+ 'lighting that enhances the tranquil and idyllic atmosphere of the scene. This type of landscape '
+ 'is often associated with alpine regions.'
+ ]
+
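+    # Run generation several times and compare each output against the
+    # expected description.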
+    for _ in range(3):
+ # Generate texts from the prompts. The output is a list of RequestOutput objects
+ # that contain the prompt, generated text, and other information.
+ outputs = llm.generate(inputs, sampling_params)
+ # Print the outputs.
+        for output in outputs:
+ generated_text = output.outputs[0].text
+ print(
+ f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}"
+ )
+            compare_distance(generated_text, expect_list[0], bench_sim=0.95)
+
+ # unset env
+ env_manager.unset_all()
diff --git a/tests/st/python/images/1080p.jpeg b/tests/st/python/images/1080p.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..0d298985cf4468902c27eaca2f23f74dae8c80ab
Binary files /dev/null and b/tests/st/python/images/1080p.jpeg differ
diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py
index 18c894f348349365d450b515d3d301f5e3186922..3724e45fd922aee494b1f6a74edaff4fc6087d03 100644
--- a/tests/st/python/test_cases_parallel.py
+++ b/tests/st/python/test_cases_parallel.py
@@ -50,21 +50,24 @@ def test_cases_parallel_part0():
"""
commands = [
("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
"pytest -s -v cases_parallel/vllm_mf_qwen_7b.py::test_mf_qwen > vllm_mf_qwen_7b_test_mf_qwen.log",
"vllm_mf_qwen_7b_test_mf_qwen.log"),
("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
"pytest -s -v cases_parallel/vllm_mf_qwen_7b_chunk_prefill.py::test_mf_qwen_7b_chunk_prefill "
"> vllm_mf_qwen_7b_chunk_prefill_test_mf_qwen_7b_chunk_prefill.log",
"vllm_mf_qwen_7b_chunk_prefill_test_mf_qwen_7b_chunk_prefill.log"),
("export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 &&"
"pytest -s -v cases_parallel/vllm_mf_qwen_7b_chunk_prefill_v1.py::test_mf_qwen_7b_chunk_prefill "
"> vllm_mf_qwen_7b_chunk_prefill_v1_test_mf_qwen_7b_chunk_prefill.log",
"vllm_mf_qwen_7b_chunk_prefill_v1_test_mf_qwen_7b_chunk_prefill.log"),
("export ASCEND_RT_VISIBLE_DEVICES=6,7 && export LCAL_COMM_ID=127.0.0.1:10071 && "
- "pytest -s -v cases_parallel/vllm_mf_qwen_7b_cp_pc_mss.py::test_mf_qwen_7b_cp_pc_mss "
- "> vllm_mf_qwen_7b_cp_pc_mss_test_mf_qwen_7b_cp_pc_mss.log",
- "vllm_mf_qwen_7b_cp_pc_mss_test_mf_qwen_7b_cp_pc_mss.log"),
-
+ "export HCCL_IF_BASE_PORT=61006 && "
+ "pytest -s -v cases_parallel/multilora_inference.py::test_multilora_inference "
+ "> multilora_inference_test_multilora_inference.log",
+ "multilora_inference_test_multilora_inference.log")
]
with Pool(len(commands)) as pool:
@@ -83,18 +86,23 @@ def test_cases_parallel_part1():
"""
commands = [
("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
"pytest -s -v cases_parallel/vllm_mf_qwen_7b_mss.py::test_mf_qwen_7b_mss "
"> vllm_mf_qwen_7b_mss_test_mf_qwen_7b_mss.log",
"vllm_mf_qwen_7b_mss_test_mf_qwen_7b_mss.log"),
("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
"pytest -s -v cases_parallel/vllm_mf_qwen_7b_prefix_caching.py::test_mf_qwen_7b_prefix_caching "
"> vllm_mf_qwen_7b_prefix_caching_test_mf_qwen_7b_prefix_caching.log",
"vllm_mf_qwen_7b_prefix_caching_test_mf_qwen_7b_prefix_caching.log"),
("export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 && "
"pytest -s -v cases_parallel/vllm_mf_qwen_7b_prefix_caching_v1.py::test_mf_qwen_7b_prefix_caching "
"> vllm_mf_qwen_7b_prefix_caching_v1_test_mf_qwen_7b_prefix_caching.log",
- "vllm_mf_qwen_7b_prefix_caching_v1_test_mf_qwen_7b_prefix_caching.log"),
+ "vllm_mf_qwen_7b_prefix_caching_v1_test_mf_qwen_7b_prefix_caching.log"
+ ),
("export ASCEND_RT_VISIBLE_DEVICES=6,7 && export LCAL_COMM_ID=127.0.0.1:10071 && "
+ "export HCCL_IF_BASE_PORT=61006 && "
"pytest -s -v cases_parallel/vllm_mf_qwen_7b_v1.py::test_mf_qwen > vllm_mf_qwen_7b_v1_test_mf_qwen.log",
"vllm_mf_qwen_7b_v1_test_mf_qwen.log")
]
@@ -103,6 +111,7 @@ def test_cases_parallel_part1():
results = list(pool.imap(run_command, commands))
check_results(commands, results)
+
@pytest.mark.level0
@pytest.mark.platform_arm_ascend910b_training
@pytest.mark.env_single
@@ -114,14 +123,17 @@ def test_cases_parallel_part2():
"""
commands = [
("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
"pytest -s -v cases_parallel/vllm_qwen_7b.py::test_vllm_qwen "
"> vllm_qwen_7b_test_vllm_qwen.log",
"vllm_qwen_7b_test_vllm_qwen.log"),
("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
"pytest -s -v cases_parallel/vllm_qwen_7b_v1.py::test_vllm_qwen "
"> vllm_qwen_7b_v1_test_vllm_qwen.log",
"vllm_qwen_7b_v1_test_vllm_qwen.log"),
("export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 && "
"pytest -s -v cases_parallel/shm_broadcast.py::test_shm_broadcast "
"> shm_broadcast_test_shm_broadcast.log",
"shm_broadcast_test_shm_broadcast.log")
@@ -130,3 +142,167 @@ def test_cases_parallel_part2():
with Pool(len(commands)) as pool:
results = list(pool.imap(run_command, commands))
check_results(commands, results)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_single
+def test_cases_parallel_part3():
+ """
+ Feature: test cases parallel.
+ Description: test cases parallel.
+ Expectation: Pass.
+ """
+ commands = [
+ ("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_bf16_part.py::test_deepseek_r1_bf16 "
+ "> vllm_deepseek_bf16_part_test_deepseek_r1_bf16.log",
+ "vllm_deepseek_bf16_part_test_deepseek_r1_bf16.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_bf16_part_v1.py::test_deepseek_r1_bf16 "
+ "> vllm_deepseek_bf16_part_v1_test_deepseek_r1_bf16.log",
+ "vllm_deepseek_bf16_part_v1_test_deepseek_r1_bf16.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_gptq_a16w4.py::test_deepseek_r1_gptq_a16w4 "
+ "> vllm_deepseek_gptq_a16w4_test_deepseek_r1_gptq_a16w4.log",
+ "vllm_deepseek_gptq_a16w4_test_deepseek_r1_gptq_a16w4.log")
+ ]
+
+ with Pool(len(commands)) as pool:
+ results = list(pool.imap(run_command, commands))
+ check_results(commands, results)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_single
+def test_cases_parallel_part4():
+ """
+ Feature: test cases parallel.
+ Description: test cases parallel.
+ Expectation: Pass.
+ """
+ commands = [
+ ("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_osl.py::test_deepseek_r1_mss "
+ "> vllm_deepseek_osl_test_deepseek_r1_mss.log",
+ "vllm_deepseek_osl_test_deepseek_r1_mss.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_part.py::test_deepseek_r1 "
+ "> vllm_deepseek_part_test_deepseek_r1.log",
+ "vllm_deepseek_part_test_deepseek_r1.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_part.py::test_deepseek_mtp "
+ "> vllm_deepseek_part_test_deepseek_mtp.log",
+ "vllm_deepseek_part_test_deepseek_mtp.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=6,7 && export LCAL_COMM_ID=127.0.0.1:10071 && "
+ "export HCCL_IF_BASE_PORT=61006 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_part_v1.py::test_deepseek_r1 "
+ "> vllm_deepseek_part_v1_test_deepseek_r1.log",
+ "vllm_deepseek_part_v1_test_deepseek_r1.log")
+ ]
+
+ with Pool(len(commands)) as pool:
+ results = list(pool.imap(run_command, commands))
+ check_results(commands, results)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_single
+def test_cases_parallel_part5():
+ """
+ Feature: test cases parallel.
+ Description: test cases parallel.
+ Expectation: Pass.
+ """
+ commands = [
+ ("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
+ "pytest -s -v cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3 "
+ "> vllm_mf_qwen3_8b_test_mf_qwen3.log",
+ "vllm_mf_qwen3_8b_test_mf_qwen3.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
+ "pytest -s -v cases_parallel/vllm_mf_qwen3_8b_v1.py::test_mf_qwen3 "
+ "> vllm_mf_qwen3_8b_v1_test_mf_qwen3.log",
+ "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=4 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 && "
+ "pytest -s -v cases_parallel/vllm_llama3.py::test_vllm_llama3_8b "
+ "> vllm_llama3_8b_test_vllm_llama3.log",
+ "vllm_llama3_8b_test_vllm_llama3.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=5 && export LCAL_COMM_ID=127.0.0.1:10071 && "
+ "export HCCL_IF_BASE_PORT=61006 && "
+ "pytest -s -v cases_parallel/vllm_llama3.py::test_vllm_llama3_1b "
+ "> vllm_llama3_1b_test_vllm_llama3.log",
+ "vllm_llama3_1b_test_vllm_llama3.log"),
+ ]
+
+ with Pool(len(commands)) as pool:
+ results = list(pool.imap(run_command, commands))
+ check_results(commands, results)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_single
+def test_cases_parallel_part6():
+ """
+ Feature: test cases parallel.
+ Description: test cases parallel.
+ Expectation: Pass.
+ """
+ commands = [
+ ("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
+ "pytest -s -v cases_parallel/vllm_qwen2_5_vl_7b_v1.py::test_qwen2_5_vl_7b_v1 "
+ "> vllm_qwen2_5_vl_7b_v1.log", "vllm_qwen2_5_vl_7b_v1.log"),
+ ]
+
+ with Pool(len(commands)) as pool:
+ results = list(pool.imap(run_command, commands))
+ check_results(commands, results)
+
+
+@pytest.mark.level1
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_single
+def test_cases_parallel_level1_part0():
+ """
+ Feature: test cases parallel.
+ Description: test cases parallel.
+ Expectation: Pass.
+ """
+ commands = [
+ ("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+ "export HCCL_IF_BASE_PORT=61000 && "
+ "pytest -s -v cases_parallel/vllm_mf_qwen_7b_cp_pc_mss.py::test_mf_qwen_7b_cp_pc_mss "
+ "> vllm_mf_qwen_7b_cp_pc_mss_test_mf_qwen_7b_cp_pc_mss.log",
+ "vllm_mf_qwen_7b_cp_pc_mss_test_mf_qwen_7b_cp_pc_mss.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+ "export HCCL_IF_BASE_PORT=61002 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_osl.py::test_deepseek_r1 "
+ "> vllm_deepseek_osl_test_deepseek_r1.log",
+ "vllm_deepseek_osl_test_deepseek_r1.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=4,5 && export LCAL_COMM_ID=127.0.0.1:10070 && "
+ "export HCCL_IF_BASE_PORT=61004 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_smoothquant.py::test_deepseek_r1 "
+ "> vllm_deepseek_smoothquant_test_deepseek_r1.log",
+ "vllm_deepseek_smoothquant_test_deepseek_r1.log"),
+ ("export ASCEND_RT_VISIBLE_DEVICES=6,7 && export LCAL_COMM_ID=127.0.0.1:10071 && "
+ "export HCCL_IF_BASE_PORT=61006 && "
+ "pytest -s -v cases_parallel/vllm_deepseek_smoothquant_mss.py::test_deepseek_r1_mss "
+ "> vllm_deepseek_smoothquant_mss_test_deepseek_r1_mss.log",
+ "vllm_deepseek_smoothquant_mss_test_deepseek_r1_mss.log")
+ ]
+
+ with Pool(len(commands)) as pool:
+ results = list(pool.imap(run_command, commands))
+ check_results(commands, results)
diff --git a/tests/st/python/test_sampler.py b/tests/st/python/test_sampler.py
index b554717c72e9db633d0adf4818194a5b546d9808..8066748f49f92ed27bd3c6b83ccbb4361be5ff57 100644
--- a/tests/st/python/test_sampler.py
+++ b/tests/st/python/test_sampler.py
@@ -29,7 +29,7 @@ from transformers import GenerationConfig, GenerationMixin
import vllm.envs as envs
from vllm_mindspore.model_executor.layers.sampler import Sampler
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm_mindspore.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import Counter, is_pin_memory_available
diff --git a/tests/st/python/test_vllm_deepseek_mix_parallel.py b/tests/st/python/test_vllm_deepseek_mix_parallel.py
index d23097c6abc653350c6fe1f0f2a642b8eda39ab3..eadecd8cc5b3573c14908b32d32ec22edb66c592 100644
--- a/tests/st/python/test_vllm_deepseek_mix_parallel.py
+++ b/tests/st/python/test_vllm_deepseek_mix_parallel.py
@@ -37,7 +37,7 @@ env_vars = {
"HCCL_DETERMINISTIC": "true",
"ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
"ATB_LLM_LCOC_ENABLE": "0",
- "HCCL_IF_BASE_PORT": "60000",
+ "HCCL_IF_BASE_PORT": "61000",
"LCAL_COMM_ID": "127.0.0.1:10068"
}
env_manager.setup_ai_environment(env_vars)
diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 5892937a8957829c8a47ae1df13ada581137f7e6..356973619c1ae5027e4046ec40e38e1d07a6049d 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -1,5 +1,3 @@
-#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -14,7 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-# ============================================================================
+"""Main entry point for monkey patching vllm."""
+
+# type: ignore
+# isort:skip_file
import sys
import warnings
@@ -27,6 +28,7 @@ if "vllm" in sys.modules:
# 1. set env before import mindspore.
from vllm_mindspore.scripts import env_setup
+
env_setup()
# 2. update the log configuration ahead of other modifications.
@@ -49,45 +51,67 @@ import vllm.utils
vllm.utils.current_platform = ascend_platform
import vllm.attention.selector
+
vllm.attention.selector.current_platform = ascend_platform
import vllm.engine.arg_utils
from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle
+
vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle
import vllm.v1.engine.core
from vllm_mindspore.v1.engine.core import shutdown
+
vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown
from vllm_mindspore.utils import (
- direct_register_custom_op,
make_tensor_with_pad,
async_tensor_h2d,
- get_dtype_size,
- ascend_device_count_stateless,
ascend_is_initialized,
ms_memory_profiling,
)
-vllm.utils.direct_register_custom_op = direct_register_custom_op
vllm.utils.make_tensor_with_pad = make_tensor_with_pad
vllm.utils.async_tensor_h2d = async_tensor_h2d
-vllm.utils.get_dtype_size = get_dtype_size
-vllm.utils.cuda_device_count_stateless = ascend_device_count_stateless
vllm.utils.cuda_is_initialized = ascend_is_initialized
vllm.utils.memory_profiling = ms_memory_profiling
-vllm.config.cuda_device_count_stateless = ascend_device_count_stateless
-import vllm.executor
+import vllm.lora.utils
+
+from vllm_mindspore.model_executor.layers.linear import LinearBase
+from vllm_mindspore.lora.utils import _all_lora_classes
+
+vllm.lora.utils._all_lora_classes = _all_lora_classes
+vllm.lora.utils.LinearBase = LinearBase
-vllm.executor.cuda_device_count_stateless = ascend_device_count_stateless
+import vllm.lora.models
+from vllm_mindspore.lora.models import register_module, from_local_checkpoint, from_lora_tensors
+
+vllm.lora.models.LoRAModelManager.register_module = register_module
+vllm.lora.models.LoRAModel.from_local_checkpoint = from_local_checkpoint
+vllm.lora.models.LoRAModel.from_lora_tensors = from_lora_tensors
+
+from vllm_mindspore.lora.layers import (ColumnParallelLinearWithLoRA,
+ MergedColumnParallelLinearWithLoRA,
+ MergedQKVParallelLinearWithLoRA,
+ QKVParallelLinearWithLoRA,
+ RowParallelLinearWithLoRA)
+
+import vllm.lora.layers
+
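+# Redirect vLLM's LoRA linear-layer classes to the MindSpore implementations.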
+vllm.lora.layers.ColumnParallelLinearWithLoRA = ColumnParallelLinearWithLoRA
+vllm.lora.layers.MergedColumnParallelLinearWithLoRA = MergedColumnParallelLinearWithLoRA
+vllm.lora.layers.MergedQKVParallelLinearWithLoRA = MergedQKVParallelLinearWithLoRA
+vllm.lora.layers.QKVParallelLinearWithLoRA = QKVParallelLinearWithLoRA
+vllm.lora.layers.RowParallelLinearWithLoRA = RowParallelLinearWithLoRA
+
+import vllm.executor
from vllm_mindspore.model_executor.models.registry import (
MindSporeModelRegistry,
_SUBPROCESS_COMMAND,
)
-
vllm.config.ModelRegistry = MindSporeModelRegistry
import vllm.model_executor
@@ -102,24 +126,14 @@ from vllm.model_executor.model_loader import get_model_architecture
vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture
vllm.model_executor.model_loader.utils.get_model_architecture = (
- get_ms_model_architecture
-)
+ get_ms_model_architecture)
vllm.model_executor.model_loader.loader.get_model_architecture = (
- get_ms_model_architecture
-)
-
-from vllm_mindspore.model_executor.sampling_metadata import (
- SequenceGroupToSample,
- SamplingMetadataCache,
- SamplingMetadata,
-)
+ get_ms_model_architecture)
-vllm.model_executor.SamplingMetadataCache = SamplingMetadataCache
-vllm.model_executor.SamplingMetadata = SamplingMetadata
-vllm.model_executor.sampling_metadata.SequenceGroupToSample = SequenceGroupToSample
-vllm.model_executor.sampling_metadata.SamplingMetadataCache = SamplingMetadataCache
-vllm.model_executor.sampling_metadata.SamplingMetadata = SamplingMetadata
+from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors
+vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d
+vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists
from vllm_mindspore.worker.cache_engine import (
ms_allocate_kv_cache,
ms_swap_in,
@@ -133,12 +147,10 @@ vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in
vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out
from vllm_mindspore.model_executor.model_loader.weight_utils import (
- safetensors_weights_iterator,
-)
+ safetensors_weights_iterator, )
vllm.model_executor.model_loader.loader.safetensors_weights_iterator = (
- safetensors_weights_iterator
-)
+ safetensors_weights_iterator)
from vllm_mindspore.worker.worker import _warm_up_model
from vllm_mindspore.worker.profile import (
@@ -158,15 +170,13 @@ from vllm_mindspore.worker.model_runner import (
)
vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = (
- _get_cuda_graph_pad_size
-)
+ _get_cuda_graph_pad_size)
vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run
import vllm.worker.multi_step_model_runner
vllm.worker.multi_step_model_runner._get_supported_attention_backends = (
- _get_supported_attention_backends
-)
+ _get_supported_attention_backends)
from vllm_mindspore.executor.multiproc_worker_utils import (
get_mp_context as ms_get_mp_context,
@@ -183,8 +193,10 @@ import vllm.executor.multiproc_worker_utils
vllm.executor.multiproc_worker_utils.ProcessWorkerWrapper.terminate_worker = ms_terminate_worker
import vllm.v1.executor.multiproc_executor
+
vllm.v1.executor.multiproc_executor.get_mp_context = ms_get_mp_context
import vllm.v1.utils
+
vllm.v1.utils.get_mp_context = ms_get_mp_context
from vllm_mindspore.executor.ray_gpu_executor import (
@@ -219,6 +231,7 @@ vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp
from .utils import update_modules
from vllm_mindspore.attention.backends import ms_attn
+
update_modules("vllm.attention.backends.flash_attn", ms_attn)
from vllm_mindspore.worker.spec_decode_worker import (
@@ -229,63 +242,91 @@ from vllm_mindspore.worker.spec_decode_worker import (
_merge_outputs,
)
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
+
SpecDecodeWorker.__init__ = spec_decode_worker_init
SpecDecodeWorker._verify_tokens = _verify_tokens
SpecDecodeWorker._run_no_spec = _run_no_spec
from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler
+
SpecDecodeBaseSampler._create_output = _create_output
from vllm.spec_decode.top1_proposer import Top1Proposer
+
Top1Proposer._merge_outputs = _merge_outputs
from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
+
RejectionSampler._smallest_positive_value = _smallest_positive_value
-RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value')
+RejectionSampler._smallest_positive_value.__set_name__(
+ RejectionSampler, '_smallest_positive_value')
vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial
######### for multi-model
from vllm_mindspore.inputs.registry import call_hf_processor
from vllm.inputs.registry import InputProcessingContext
+
InputProcessingContext.call_hf_processor = call_hf_processor
-from vllm_mindspore.multimodal.inputs import as_kwargs
+from vllm_mindspore.multimodal.inputs import as_kwargs, from_items, MultiModalFieldElem
from vllm.multimodal.inputs import MultiModalKwargs
+
MultiModalKwargs.as_kwargs = as_kwargs
+MultiModalKwargs.from_items = from_items
+
+vllm.multimodal.inputs.MultiModalFieldElem = MultiModalFieldElem
+
+from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding # type: ignore[attr-defined]
-from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding
vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding
+# patch for V1
from vllm_mindspore.v1.sample import rejection_sampler
+
update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler)
from vllm_mindspore.v1.spec_decode import eagle
+
update_modules("vllm.v1.spec_decode.eagle", eagle)
-from vllm_mindspore.v1.attention.backends import flash_attn
-import vllm.v1.attention.backends
-sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn
-import vllm.v1.attention.backends.flash_attn
+from vllm_mindspore.v1.attention.backends import ms_attn
+
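+# Replace vLLM's V1 flash-attention backend module with the MindSpore ms_attn module.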
+update_modules("vllm.v1.attention.backends.flash_attn", ms_attn)
import vllm.v1.worker.gpu_model_runner
from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs
+
vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs
+from vllm_mindspore.v1.worker.gpu_model_runner import _calc_mrope_positions
+
+vllm.v1.worker.gpu_model_runner.GPUModelRunner._calc_mrope_positions = _calc_mrope_positions
+
from vllm_mindspore.v1.worker.gpu_model_runner import _update_states
+
vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states
-from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache
+from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache, get_kv_cache_spec
+
vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache
+vllm.v1.worker.gpu_model_runner.GPUModelRunner.get_kv_cache_spec = get_kv_cache_spec
+
+from vllm_mindspore.v1.worker.gpu_model_runner import wrapper_gpu_model_runner_execute_model
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+vllm.v1.worker.gpu_model_runner.GPUModelRunner.execute_model = \
+ wrapper_gpu_model_runner_execute_model(GPUModelRunner.execute_model)
import vllm.v1.worker.block_table
from vllm_mindspore.v1.worker.block_table import BlockTable
+
vllm.v1.worker.block_table.BlockTable = BlockTable
vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable
import vllm.v1.worker.gpu_input_batch
from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor
+
vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata
vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata
vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor
@@ -297,17 +338,19 @@ from vllm_mindspore.v1.worker.gpu_worker import init_device
Worker.__init__ = wrapper_worker_init(Worker.__init__)
Worker.init_device = wrapper_worker_init_device(init_device)
-
import vllm.v1.utils
from vllm_mindspore.v1.utils import copy_slice
+
vllm.v1.utils.copy_slice = copy_slice
vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice
from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors
import vllm.v1.sample.ops.penalties
+
vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors
import vllm.model_executor.layers.utils
from vllm_mindspore.model_executor.layers.utils import apply_penalties
+
vllm.model_executor.layers.utils.apply_penalties = apply_penalties
vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties
@@ -317,26 +360,51 @@ from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, ra
import vllm.v1.sample.ops.topk_topp_sampler
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
+
TopKTopPSampler.forward_native = topk_topp_sampler_forward_native
vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p
vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample
vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only
from vllm_mindspore.v1.sample.sampler import apply_temperature
import vllm.v1.sample.sampler
+
vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature
from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer
from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer
+
ShmRingBuffer.__init__ = initialize_ShmRingBuffer
from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model
from vllm.v1.worker.gpu_worker import Worker
+
Worker.compile_or_warm_up_model = compile_or_warm_up_model
+from vllm_mindspore.v1.core.sched.scheduler import update_from_output
+from vllm.v1.core.sched.scheduler import Scheduler
+
+Scheduler.update_from_output = update_from_output
+
+from vllm_mindspore.v1.executor.multiproc_executor import executor_ensure_worker_termination
+from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+
+MultiprocExecutor._ensure_worker_termination = executor_ensure_worker_termination
+
from .utils import check_ready
from vllm_mindspore.engine.multiprocessing.engine import cleanup
import vllm.engine.multiprocessing.engine
+
vllm.engine.multiprocessing.engine.MQLLMEngine.cleanup = cleanup
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm_mindspore.entrypoints.openai.serving_chat import chat_completion_stream_generator
+
+OpenAIServingChat.chat_completion_stream_generator = chat_completion_stream_generator
+
+from vllm_mindspore.entrypoints.openai.tool_parsers import deepseekv3_tool_parser
+
+sys.modules[
+ 'vllm.entrypoints.openai.tool_parsers.deepseekv3_tool_parser'] = deepseekv3_tool_parser
+
check_ready()
diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py
index d6123b0a89790ba630888066cb857d995f190c10..bca31412a917f49ca8aee8fa1e3aee615cdb0b78 100644
--- a/vllm_mindspore/attention/backends/ms_attn.py
+++ b/vllm_mindspore/attention/backends/ms_attn.py
@@ -25,8 +25,6 @@ import os
import numpy as np
-import torch
-
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -95,50 +93,86 @@ def advance_step_op(sampled_token_ids,
@dataclass
-class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
- """Metadata for TorchSDPABackend."""
-
- # Currently, input sequences can only contain all prompts
- # or all decoding. True if all sequences are prompts.
- chunked_prefill: bool
- seq_lens: Optional[List[int]] = None # For non-chunked prefill
-
- # For chunked prefill only
+class MsAttentionMetadata(AttentionMetadata):
+ """Metadata for MsAttentionBackend.
+ """
+ # (batch_size,). The sequence length per sequence. Sequence length means
+ # the computed tokens + new tokens None if it is a decoding.
+ seq_lens: Optional[List[int]]
+ # seq_lens stored as a tensor.
+ seq_lens_tensor: Optional[ms.Tensor]
+
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
+ # |---------- N-1 iteration --------|
+ # |---------------- N iteration ---------------------|
+ # |- tokenA -|......................|-- newTokens ---|
+ # |---------- context_len ----------|
+ # |-------------------- seq_len ---------------------|
+ # |-- query_len ---|
+
+ # Maximum sequence length among prefill batch. 0 if there are decoding
+ # requests only.
+ max_prefill_seq_len: int
+ # Maximum sequence length among decode batch. 0 if there are prefill
+ # requests only.
+ max_decode_seq_len: int
+ # (batch_size,) A tensor of context lengths (tokens that are computed
+ # so far).
+ context_lens_tensor: Optional[ms.Tensor]
+
+ # (batch_size, max_blocks_per_seq).
+ # Block addresses per sequence. (Seq id -> list of physical block)
+ # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+ # in the kv cache. Each block can contain up to block_size tokens.
+ # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+ # captured.
+ block_tables: Optional[ms.Tensor]
+
+ # Whether or not if cuda graph is enabled.
+ # Cuda-graph is currently enabled for decoding only.
+ # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+
+ use_cuda_graph: bool
+
+ # Maximum query length in the batch.
max_query_len: Optional[int] = None
- max_prefill_seq_len: int = 0
- seq_start_loc: Optional[torch.Tensor] = None
- _cached_prefill_metadata: Optional["MSAttentionMetadata"] = None
- _cached_decode_metadata: Optional["MSAttentionMetadata"] = None
- context_lens_tensor: Optional[torch.Tensor] = None
- encoder_seq_start_loc: Optional[torch.Tensor] = None
+ # Max number of query tokens among request in the batch.
max_decode_query_len: Optional[int] = None
- max_kv_len: Optional[int] = None
- query_start_loc: Optional[torch.Tensor] = None
- kv_start_loc: Optional[torch.Tensor] = None
- prefill_block_tables: Optional[torch.Tensor] = None
- query_lens: Optional[List[int]] = None
+ # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+ # the batch, used to index into subquery. E.g., if the subquery length
+ # is [4, 6], it is [0, 4, 10].
+ query_start_loc: Optional[ms.Tensor] = None
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+ # the batch, used to index into sequence. E.g., if the sequence length is
+ # [4, 6], it is [0, 4, 10].
+ seq_start_loc: Optional[ms.Tensor] = None
+
+ _cached_prefill_metadata: Optional["MsAttentionMetadata"] = None
+ _cached_decode_metadata: Optional["MsAttentionMetadata"] = None
# Begin encoder attn & enc/dec cross-attn fields...
+
# Encoder sequence lengths representation
encoder_seq_lens: Optional[List[int]] = None
- encoder_seq_lens_tensor: Optional[torch.Tensor] = None
-
+ encoder_seq_lens_tensor: Optional[ms.Tensor] = None
+ # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+ # the batch, used to index into sequence. E.g., if the sequence length is
+ # [4, 6], it is [0, 4, 10].
+ encoder_seq_start_loc: Optional[ms.Tensor] = None
# Maximum sequence length among encoder sequences
max_encoder_seq_len: Optional[int] = None
-
# Number of tokens input to encoder
num_encoder_tokens: Optional[int] = None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
- cross_slot_mapping: Optional[torch.Tensor] = None
- cross_block_tables: Optional[torch.Tensor] = None
-
- use_cuda_graph: bool = False
- enable_kv_scales_calculation: bool
+ cross_slot_mapping: Optional[ms.Tensor] = None
+ cross_block_tables: Optional[ms.Tensor] = None
+    # added by vllm-mindspore
+ query_lens: Optional[List[int]] = None
@property
def prefill_metadata(self):
@@ -169,7 +203,7 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
- self._cached_prefill_metadata = MSAttentionMetadata(
+ self._cached_prefill_metadata = MsAttentionMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
@@ -193,7 +227,6 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
- chunked_prefill=self.chunked_prefill,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata
@@ -216,7 +249,7 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
- self._cached_decode_metadata = MSAttentionMetadata(
+ self._cached_decode_metadata = MsAttentionMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
@@ -245,14 +278,13 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
encoder_seq_start_loc=self.encoder_seq_start_loc,
max_encoder_seq_len=self.max_encoder_seq_len,
- chunked_prefill=self.chunked_prefill,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForNPUWithSamplingMetadata",
- sampled_token_ids: Optional[torch.Tensor],
+ sampled_token_ids: Optional[ms.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
@@ -394,7 +426,7 @@ class MSAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
raise AttributeError(f"Invalid attention type {str(attn_type)}")
-class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]):
+class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MsAttentionMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
@@ -545,7 +577,7 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]):
block_tables = make_tensor_with_pad(
self.block_tables,
pad=-1,
- dtype=torch.int,
+ dtype=ms.int32,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens)
@@ -557,13 +589,13 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]):
query_start_loc_tensor = ms.Tensor(query_start_loc, dtype=ms.int32)
seq_start_loc_tensor = ms.Tensor(seq_start_loc, dtype=ms.int32)
- return MSAttentionMetadata(
+ return MsAttentionMetadata(
slot_mapping=slot_mapping_tensor,
block_tables=block_tables,
seq_lens_tensor=seq_lens_tensor,
seq_lens=seq_lens,
+ max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
- chunked_prefill=self.input_builder.chunked_prefill_enabled,
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
@@ -574,6 +606,7 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]):
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
max_query_len=max_query_len,
+ use_cuda_graph=False,
)
@@ -590,7 +623,7 @@ class MsAttentionBackend(AttentionBackend):
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
- return MSAttentionMetadata
+ return MsAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]:
@@ -615,7 +648,7 @@ class MsAttentionBackend(AttentionBackend):
def swap_blocks(
src_kv_cache: MsKVCache,
dst_kv_cache: MsKVCache,
- src_to_dst: torch.Tensor,
+ src_to_dst: ms.Tensor,
swap_type: bool,
) -> None:
"""
@@ -637,7 +670,7 @@ class MsAttentionBackend(AttentionBackend):
@staticmethod
def copy_blocks(
kv_caches: List[MsKVCache],
- src_to_dists: torch.Tensor,
+ src_to_dists: ms.Tensor,
) -> None:
blocks_to_copy = src_to_dists.asnumpy().tolist()
for kv_cache in kv_caches:
@@ -691,14 +724,14 @@ class MsAttentionImpl(AttentionImpl):
def forward(
self,
layer: AttentionLayer,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: MSAttentionMetadata,
+ query: ms.Tensor,
+ key: ms.Tensor,
+ value: ms.Tensor,
+ kv_cache: ms.Tensor,
+ attn_metadata: MsAttentionMetadata,
attn_type: str = AttentionType.DECODER,
- output: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
+ output: Optional[ms.Tensor] = None,
+ ) -> ms.Tensor:
"""Forward pass with FlashAttention.
Args:
@@ -726,7 +759,7 @@ class MLABackend(AttentionBackend):
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
- return MSAttentionMetadata
+ return MsAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]:
@@ -747,9 +780,9 @@ class MLABackend(AttentionBackend):
@staticmethod
def swap_blocks(
- src_kv_cache: torch.Tensor,
- dst_kv_cache: torch.Tensor,
- src_to_dst: torch.Tensor,
+ src_kv_cache: ms.Tensor,
+ dst_kv_cache: ms.Tensor,
+ src_to_dst: ms.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]
@@ -758,8 +791,8 @@ class MLABackend(AttentionBackend):
@staticmethod
def copy_blocks(
- kv_caches: List[torch.Tensor],
- src_to_dists: torch.Tensor,
+ kv_caches: List[ms.Tensor],
+ src_to_dists: ms.Tensor,
) -> None:
blocks_to_copy = src_to_dists.asnumpy().tolist()
for kv_cache in kv_caches:
@@ -771,4 +804,4 @@ class MLABackend(AttentionBackend):
def get_supported_head_sizes() -> List[int]:
return [576]
-FlashAttentionMetadata = MSAttentionMetadata
+FlashAttentionMetadata = MsAttentionMetadata
diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py
index 89914e97ddce5578a46d69083e7499a20b2cfd6a..f4af1afba6e5e40691dce46d9a4a40eafe18ea0d 100644
--- a/vllm_mindspore/attention/layer.py
+++ b/vllm_mindspore/attention/layer.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -18,37 +17,31 @@
"""Common layer for LLM."""
from typing import Any, Dict, List, Optional, Tuple
-from mindspore import Tensor, mint, nn, ops, jit
+from mindspore import Tensor, mint, nn, ops
from mindspore.common import dtype as mstype
from mindspore.ops.auto_generate import PagedAttention, ReshapeAndCache
from mindspore.ops.operations.nn_ops import FlashAttentionScore
-
-from vllm.config import CacheConfig
from vllm.attention.backends.abstract import AttentionType
-from vllm.model_executor.layers.quantization.base_config import (
- QuantizationConfig)
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import \
+ QuantizationConfig
-def _pad_to_max_tensor(
- input_: Tensor,
- max_len: int,
- dim: int = 0,
- pad_value: int = -1
-) -> Tensor:
+def _pad_to_max_tensor(input_: Tensor,
+ max_len: int,
+ dim: int = 0,
+ pad_value: int = -1) -> Tensor:
"""Temporary function, will be deprecated in the future."""
if input_.shape[dim] == max_len:
return input_
- pad_shape = (input_.shape[0], max_len - input_.shape[dim], *input_.shape[dim + 1:])
+ pad_shape = (input_.shape[0], max_len - input_.shape[dim],
+ *input_.shape[dim + 1:])
pad_tensor = mint.ones(size=pad_shape, dtype=input_.dtype) * pad_value
output = mint.cat([input_, pad_tensor], dim=dim)
return output
-def _generate_attn_mask(
- query: Tensor,
- value: Tensor,
- flatten: bool
-) -> Tensor:
+def _generate_attn_mask(query: Tensor, value: Tensor, flatten: bool) -> Tensor:
"""Temporary function, will be deprecated in the future."""
if flatten:
return mint.triu(mint.ones(size=(128, 128), dtype=query.dtype), 1)
@@ -59,16 +52,14 @@ def _generate_attn_mask(
return mask
-def _hidden_states_th2bsh(
- input_: Tensor,
- batch_valid_length: Tensor
-) -> Tensor:
+def _hidden_states_th2bsh(input_: Tensor,
+ batch_valid_length: Tensor) -> Tensor:
"""Temporary function, will be deprecated in the future."""
max_seq_len = batch_valid_length.max().item()
start_pos = 0
padding_input_list = []
for valid_length in batch_valid_length:
- valid_input = input_[:, start_pos: start_pos + valid_length, :]
+ valid_input = input_[:, start_pos:start_pos + valid_length, :]
padded_input = _pad_to_max_tensor(valid_input, max_seq_len, 1)
padding_input_list.append(padded_input)
start_pos += valid_length
@@ -76,10 +67,8 @@ def _hidden_states_th2bsh(
return bsh_output
-def _hidden_states_bsh2th(
- input_: Tensor,
- batch_valid_length: Tensor
-) -> Tensor:
+def _hidden_states_bsh2th(input_: Tensor,
+ batch_valid_length: Tensor) -> Tensor:
"""Temporary function, will be deprecated in the future."""
unpadded_input_list = []
for batch_index, valid_length in enumerate(batch_valid_length):
@@ -128,11 +117,10 @@ class Attention(nn.Cell):
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.head_size = head_size
- self.hidden_size_per_partition = num_heads*head_size
- self.kv_hidden_size_per_partition = num_kv_heads*head_size
- self.flatten = True
+ self.hidden_size_per_partition = num_heads * head_size
+ self.kv_hidden_size_per_partition = num_kv_heads * head_size
- input_layout = "TH" if self.flatten else "BSH" # pynative 下不支持拉平操作。
+ input_layout = "TH"
scale = float(scale)
pre_tokens = 2147483647
next_tokens = 2147483647
@@ -147,7 +135,6 @@ class Attention(nn.Cell):
scale_value=scale,
kv_head_num=num_kv_heads)
- @jit
def construct(
self,
query: Tensor,
@@ -162,7 +149,7 @@ class Attention(nn.Cell):
q_seq_lens: Tensor,
block_tables: Tensor,
) -> Tensor:
- """Attention foward, support MHA and GQA.
+ """Attention forward, support MHA and GQA.
Args:
query: shape = [1, num_tokens, hidden_size]
@@ -174,12 +161,19 @@ class Attention(nn.Cell):
block_tables: shape = [block_size, num_block]
"""
output = query
- cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping)
+        # ensure that the input tensors of reshape_and_cache are contiguous
+ key = key.contiguous()
+ value = value.contiguous()
+ cache_out = self.reshape_and_cache(key, value, key_cache, value_cache,
+ slot_mapping)
query = ops.depend(query, cache_out)
if is_prefill:
- output = self._run_prefill_forward(query, key, value, attn_mask, batch_valid_length, batch_valid_length)
+ output = self._run_prefill_forward(query, key, value, attn_mask,
+ batch_valid_length,
+ batch_valid_length)
else:
- output = self._run_decode_forward(query, key_cache, value_cache, block_tables, batch_valid_length,
+ output = self._run_decode_forward(query, key_cache, value_cache,
+ block_tables, batch_valid_length,
attn_mask, q_seq_lens)
return output
@@ -202,9 +196,6 @@ class Attention(nn.Cell):
actual_seq_kvlen: shape = [batch_size, ]
NOTE: Currently `PyNative` mode does not support operations in "TH" form, so it will be converted to "BSH" form.
"""
- query = query.view(-1, self.hidden_size_per_partition)
- key = key.view(-1, self.kv_hidden_size_per_partition)
- value = value.view(-1, self.kv_hidden_size_per_partition)
_, _, _, output = self.flash_attention(
query,
key,
@@ -217,7 +208,6 @@ class Attention(nn.Cell):
actual_seq_qlen,
actual_seq_kvlen
)
- output = output.view(1, -1, self.hidden_size_per_partition)
return output
def _run_decode_forward(
@@ -239,15 +229,7 @@ class Attention(nn.Cell):
block_tables: shape = [block_size, num_block]
context_lens: shape = [batch_size, ]
"""
- output = self.paged_attention(
- query,
- key_cache,
- value_cache,
- block_tables,
- batch_valid_length,
- None,
- None,
- attn_mask,
- q_seq_lens
- )
+ output = self.paged_attention(query, key_cache, value_cache,
+ block_tables, batch_valid_length, None,
+ None, attn_mask, q_seq_lens)
return output
diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py
index 00447432e546516bf4d8629c374ac36e491041e8..a24d49595c7d7698331492562e14e7d9c65c08b6 100644
--- a/vllm_mindspore/distributed/communication_op.py
+++ b/vllm_mindspore/distributed/communication_op.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -16,20 +15,16 @@
# limitations under the License.
# ============================================================================
-
 # This file implements the low-level communication interfaces. They must behave the same in graph and PyNative modes so that they can eventually be captured into the network graph.
 # Do not copy the mindspeed implementation: training involves many extra features, while inference only needs very simple communication, which also improves performance.
from typing import Any, Dict, Optional, Union
-import mindspore as ms
from mindspore import Tensor, nn, ops
-from mindspore.communication.comm_func import (all_gather_into_tensor,
- all_reduce, broadcast,
- gather_into_tensor, recv, send)
+from mindspore.communication.comm_func import all_reduce, broadcast
from vllm.distributed.parallel_state import (
- get_pp_group, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size, get_tp_group, get_world_group)
+ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
+ get_tp_group, get_world_group)
def tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor:
@@ -40,47 +35,6 @@ def tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor:
return output
-def tensor_model_parallel_all_gather(input_: Tensor,
- dim: int = -1) -> Tensor:
- if get_tensor_model_parallel_world_size() == 1:
- return input_
- """All-gather the input tensor across model parallel group."""
- output, _ = all_gather_into_tensor(input_, group=get_tp_group())
- input_size = input_.shape
- if dim < 0:
- # Convert negative dim to positive.
- dim += len(input_size)
- # Reshape
- output_tensor = output_tensor.reshape((world_size, ) + input_size)
- output_tensor = output_tensor.movedim(0, dim)
- output_tensor = output_tensor.reshape(input_size[:dim] +
- (world_size *
- input_size[dim], ) +
- input_size[dim + 1:])
- return output
-
-
-def tensor_model_parallel_gather(input_: Tensor,
- dst: int = 0,
- dim: int = -1) -> Optional[Tensor]:
- if get_tensor_model_parallel_world_size() == 1:
- return input_
- """Gather the input tensor across model parallel group."""
- if dim < 0:
- # Convert negative dim to positive.
- dim += len(input_.shape)
- if dim != 0:
- input_ = input_.moveaxis(dim, 0)
- _dst = get_world_rank_from_tp_group_rank(dst)
- output = gather_into_tensor(input_, dst=_dst, group=get_tp_group())
- if get_tensor_model_parallel_rank() == dst:
- if dim != 0:
- output = output.moveaxis(0, dim)
- else:
- output = None
- return output
-
-
def broadcast_tensor(tensor, src: int = 0):
# broadcast tensor to the world group
return broadcast(tensor, src, group=get_world_group())
@@ -95,15 +49,6 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[Tensor,
# return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
-def send_to_next_pp_rank(tensor):
- send(tensor, next_pp_rank(), group=get_pp_group())
-
-
-def recv_from_prev_pp_rank(tensor):
- output = recv(tensor, prev_pp_rank(), group=get_pp_group())
- return output
-
-
class ReduceFromModelParallelRegion(nn.Cell):
"All reduce the input from the model parallel region."
@@ -122,7 +67,7 @@ class ReduceFromModelParallelRegion(nn.Cell):
class GatherFromModelParallelRegion(nn.Cell):
- "Gather the input from model parallel region and concatinate."
+ "Gather the input from model parallel region and concatenate."
def __init__(self):
super().__init__()
@@ -138,7 +83,32 @@ class GatherFromModelParallelRegion(nn.Cell):
# Size and dimension.
if self.world_size == 1:
return input_
- output = ops.CollectiveGather(dest_rank=dst, group=self.tp_group)(input_.transpose(2, 1, 0))
+ output = ops.CollectiveGather(dest_rank=dst,
+ group=self.tp_group)(input_.transpose(
+ 2, 1, 0))
if self.tp_rank != dst:
return ops.depend(ops.zeros_like(input_), output)
return output.transpose(2, 1, 0)
+
+
+class AllGatherFromModelParallelRegion(nn.Cell):
+ """
+    Gather the input from the model parallel region and concatenate along the
+    last dimension; the input is transposed before and after the all-gather.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.world_size = get_tensor_model_parallel_world_size()
+ if self.world_size > 1:
+ self.tp_group = get_tp_group().device_group._name
+ self.all_gather_into_tensor = ops.AllGather(group=self.tp_group)
+
+ def construct(self, input_):
+ # Size and dimension.
+ if self.world_size == 1:
+ return input_
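+        # ops.AllGather concatenates along dim 0, so move the last dim to the
+        # front, gather, then move it back.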
+ input_ = ops.swapaxes(input_, 0, -1)
+ output = self.all_gather_into_tensor(input_)
+ output = ops.swapaxes(output, 0, -1)
+ return output
diff --git a/vllm_mindspore/engine/arg_utils.py b/vllm_mindspore/engine/arg_utils.py
index ed74ba9e38d54f7e507951ea585106f833e83d6b..5460bbbefc6292f347618ba835e991a1a7508afd 100644
--- a/vllm_mindspore/engine/arg_utils.py
+++ b/vllm_mindspore/engine/arg_utils.py
@@ -51,11 +51,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=True)
return False
- if self.additional_config != EngineArgs.additional_config:
- _raise_or_fallback(feature_name="--additional-config",
- recommend_to_remove=False)
- return False
-
# Xgrammar and Guidance are supported.
SUPPORTED_GUIDED_DECODING = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
diff --git a/vllm_mindspore/entrypoints/__init__.py b/vllm_mindspore/entrypoints/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/entrypoints.py b/vllm_mindspore/entrypoints/__main__.py
similarity index 89%
rename from vllm_mindspore/entrypoints.py
rename to vllm_mindspore/entrypoints/__main__.py
index aa91f07aeae5036c651fc3a1b5a2b205b6a68203..5d27f240efc93ea97174e98d4da4fbb5c20679ea 100644
--- a/vllm_mindspore/entrypoints.py
+++ b/vllm_mindspore/entrypoints/__main__.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -32,8 +31,8 @@ if __name__ == "__main__":
module = importlib.import_module(module_name)
except Exception as e:
raise ValueError(
- "Invalid entrypoint(%s) for vllm, error: %s!" % (module_name, str(e))
- )
+ f"Invalid entrypoint({module_name}) for vllm, error: {str(e)}!"
+ ) from e
module_code = inspect.getsource(module)
vllm_mindspore_enable_line = "import vllm_mindspore\n"
@@ -44,4 +43,4 @@ if __name__ == "__main__":
with open(exec_file, "w") as f:
f.writelines(module_code)
- subprocess.run([sys.executable, exec_file] + sys.argv[2:])
+ subprocess.run([sys.executable, str(exec_file)] + sys.argv[2:])
diff --git a/vllm_mindspore/entrypoints/openai/__init__.py b/vllm_mindspore/entrypoints/openai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/entrypoints/openai/serving_chat.py b/vllm_mindspore/entrypoints/openai/serving_chat.py
new file mode 100644
index 0000000000000000000000000000000000000000..05713071e210c574a81f7cbf3d005aac72ae984c
--- /dev/null
+++ b/vllm_mindspore/entrypoints/openai/serving_chat.py
@@ -0,0 +1,484 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# SPDX-License-Identifier: Apache-2.0
+"""
+Adapted from
+https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/serving_chat.py
+"""
+import time
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Final, Optional, Union
+
+from vllm.entrypoints.chat_utils import ConversationMessage
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
+ ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
+ DeltaFunctionCall, DeltaMessage, DeltaToolCall, PromptTokenUsageInfo,
+ RequestResponseMetadata, UsageInfo)
+from vllm.entrypoints.openai.tool_parsers import ToolParser
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+async def chat_completion_stream_generator(
+ self,
+ request: ChatCompletionRequest,
+ result_generator: AsyncIterator[RequestOutput],
+ request_id: str,
+ model_name: str,
+ conversation: list[ConversationMessage],
+ tokenizer: AnyTokenizer,
+ request_metadata: RequestResponseMetadata,
+) -> AsyncGenerator[str, None]:
+ created_time = int(time.time())
+ chunk_object_type: Final = "chat.completion.chunk"
+ first_iteration = True
+
+ # Send response for each token for each request.n (index)
+ num_choices = 1 if request.n is None else request.n
+ previous_num_tokens = [0] * num_choices
+ finish_reason_sent = [False] * num_choices
+ num_prompt_tokens = 0
+ num_cached_tokens = None
+
+ if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
+ tool_choice_function_name = request.tool_choice.function.name
+ else:
+ tool_choice_function_name = None
+
+ # Determine whether tools are in use with "auto" tool choice
+ tool_choice_auto = (not tool_choice_function_name and
+ self._should_stream_with_auto_tool_parsing(request))
+
+ should_stream_with_reasoning_parsing = (
+ self._should_stream_with_reasoning_parsing(request))
+
+ all_previous_token_ids: Optional[list[list[int]]]
+ function_name_returned: Optional[list[bool]] = None
+
+ # Only one of these will be used, thus previous_texts and
+ # all_previous_token_ids will not be used twice in the same iteration.
+ if tool_choice_auto or should_stream_with_reasoning_parsing:
+ # These are only required in "auto" tool choice case
+ previous_texts = [""] * num_choices
+ all_previous_token_ids = [[]] * num_choices
+ # For reasoning parser and tool call all enabled
+ added_content_delta_arr = [False] * num_choices
+ reasoning_end_arr = [False] * num_choices
+ elif request.tool_choice == "required":
+ previous_texts = [""] * num_choices
+ function_name_returned = [False] * num_choices
+ all_previous_token_ids = None
+ else:
+ previous_texts, all_previous_token_ids = None, None
+
+ try:
+ # There is no need to check if the reasoning_parser is None
+ # because the should_stream_with_reasoning_parsing check
+        # already ensures that the reasoning_parser is not None,
+        # but the pre-commit hook still requires the explicit check.
+ if should_stream_with_reasoning_parsing and \
+ self.reasoning_parser is not None:
+ reasoning_parser = self.reasoning_parser(tokenizer)
+ except RuntimeError as e:
+ logger.exception("Error in reasoning parser creation.")
+ data = self.create_streaming_error_response(str(e))
+ yield f"data: {data}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
+ # Prepare the tool parser if it's needed
+ try:
+ if tool_choice_auto and self.tool_parser:
+ tool_parsers: list[Optional[ToolParser]] = [
+ self.tool_parser(tokenizer)
+ ] * num_choices
+ else:
+ tool_parsers = [None] * num_choices
+ except Exception as e:
+ logger.exception("Error in tool parser creation.")
+ data = self.create_streaming_error_response(str(e))
+ yield f"data: {data}\n\n"
+ yield "data: [DONE]\n\n"
+ return
+
+ stream_options = request.stream_options
+ if stream_options:
+ include_usage = stream_options.include_usage
+ include_continuous_usage = include_usage and \
+ stream_options.continuous_usage_stats
+ else:
+ include_usage, include_continuous_usage = False, False
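+    # include_usage: emit one final usage chunk; include_continuous_usage:
+    # additionally attach running usage stats to every streamed chunk.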
+
+ try:
+ async for res in result_generator:
+ if res.prompt_token_ids is not None:
+ num_prompt_tokens = len(res.prompt_token_ids)
+ if res.encoder_prompt_token_ids is not None:
+ num_prompt_tokens += len(res.encoder_prompt_token_ids)
+
+ # We need to do it here, because if there are exceptions in
+ # the result_generator, it needs to be sent as the FIRST
+ # response (by the try...catch).
+ if first_iteration:
+ num_cached_tokens = res.num_cached_tokens
+ # Send first response for each request.n (index) with
+ # the role
+ role = self.get_chat_request_role(request)
+
+ # NOTE num_choices defaults to 1 so this usually executes
+ # once per request
+ for i in range(num_choices):
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(
+ role=role,
+ content="",
+ ),
+ logprobs=None,
+ finish_reason=None)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+
+ # if continuous usage stats are requested, add it
+ if include_continuous_usage:
+ chunk.usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=0,
+ total_tokens=num_prompt_tokens)
+
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+
+ # Send response to echo the input portion of the
+ # last message
+ if request.echo:
+ last_msg_content: Union[str, list[dict[str, str]]] = ""
+ if conversation and "content" in conversation[
+ -1] and conversation[-1].get("role") == role:
+ last_msg_content = conversation[-1]["content"] or ""
+
+ if last_msg_content:
+ for i in range(num_choices):
+ choice_data = (ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=DeltaMessage(content=last_msg_content),
+ logprobs=None,
+ finish_reason=None))
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+ if include_continuous_usage:
+ chunk.usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=0,
+ total_tokens=num_prompt_tokens)
+
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ first_iteration = False
+
+ for output in res.outputs:
+ i = output.index
+ tool_parser = tool_parsers[i]
+
+ if finish_reason_sent[i]:
+ continue
+
+ if request.logprobs and request.top_logprobs is not None:
+ assert output.logprobs is not None, (
+ "Did not output logprobs")
+ logprobs = self._create_chat_logprobs(
+ token_ids=output.token_ids,
+ top_logprobs=output.logprobs,
+ tokenizer=tokenizer,
+ num_output_top_logprobs=request.top_logprobs,
+ return_as_token_id=request.return_tokens_as_token_ids,
+ )
+ else:
+ logprobs = None
+
+ delta_text = output.text
+
+ if not delta_text and not output.token_ids and \
+ not previous_num_tokens[i]:
+ # Chunked prefill case, don't return empty chunks
+ continue
+
+ delta_message: Optional[DeltaMessage]
+
+ # just update previous_texts and previous_token_ids
+ if tool_choice_auto or should_stream_with_reasoning_parsing:
+ assert previous_texts is not None
+ assert all_previous_token_ids is not None
+ previous_text = previous_texts[i]
+ previous_token_ids = all_previous_token_ids[i]
+ current_text = previous_text + delta_text
+ current_token_ids = previous_token_ids + list(
+ output.token_ids)
+
+ # handle streaming deltas for tools with named tool_choice
+ if tool_choice_function_name:
+ if (self.enable_reasoning
+ and not reasoning_parser.is_reasoning_end(
+ previous_token_ids)):
+ assert reasoning_parser is not None
+ delta_message = (reasoning_parser.
+ extract_reasoning_content_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ output.token_ids,
+ ))
+ # When encountering think end id in delta_token_ids,
+ # process the `content`. Only keep 'content',
+ # remove 'reasoning_content'
+ if reasoning_parser.is_reasoning_end(
+ list(output.token_ids)):
+ if delta_message and delta_message.content:
+                                # This needs to be added to the next `delta_text`
+ current_text = delta_message.content
+ delta_message.content = None
+ else:
+ current_text = ""
+ else:
+                        # Just add the remaining `content`
+ if self.enable_reasoning:
+ delta_text = previous_text + delta_text
+ current_text = ""
+
+ delta_message = DeltaMessage(tool_calls=[
+ DeltaToolCall(function=DeltaFunctionCall(
+ name=tool_choice_function_name,
+ arguments=delta_text),
+ index=i)
+ ])
+
+ elif request.tool_choice == "required":
+ assert previous_texts is not None
+ assert function_name_returned is not None
+ previous_text = previous_texts[i]
+ current_text = previous_text + delta_text
+ fn_name_returned = function_name_returned[i]
+
+ delta_message, function_name_returned[i] = (
+ self.extract_tool_call_required_streaming(
+ previous_text=previous_text,
+ current_text=current_text,
+ delta_text=delta_text,
+ function_name_returned=fn_name_returned))
+
+ # update the previous values for the next iteration
+ previous_texts[i] = current_text
+
+ # handle streaming deltas for tools with "auto" tool choice
+ # and reasoning parser
+ elif tool_choice_auto and self.enable_reasoning:
+ assert tool_parser is not None
+ assert reasoning_parser is not None
+ assert added_content_delta_arr is not None
+ assert reasoning_end_arr is not None
+ if not reasoning_end_arr[i]:
+ delta_message = (reasoning_parser.
+ extract_reasoning_content_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ output.token_ids,
+ ))
+
+ # When encountering think end id in delta_token_ids,
+ # set reasoning status to end.
+ # Remove the text and token ids related
+ # to 'reasoning_content'.
+ if reasoning_parser.is_reasoning_end(
+ list(output.token_ids)):
+ reasoning_end_arr[i] = True
+ current_token_ids = \
+ reasoning_parser.extract_content_ids(
+ list(output.token_ids))
+ if delta_message and delta_message.content:
+ current_text = delta_message.content
+ delta_message.content = None
+ else:
+ current_text = ""
+
+ # handle tool calls only after reasoning is done,
+ else:
+ delta_token_ids = list(output.token_ids)
+                        # On the first tool-call delta, fold the text and
+                        # token ids accumulated so far into this delta
+ if not added_content_delta_arr[i]:
+ added_content_delta_arr[i] = True
+ previous_text = ""
+ previous_token_ids = []
+ delta_text = current_text
+ delta_token_ids = current_token_ids
+
+ delta_message = (
+ tool_parser.extract_tool_calls_streaming(
+ previous_text=previous_text,
+ current_text=current_text,
+ delta_text=delta_text,
+ previous_token_ids=previous_token_ids,
+ current_token_ids=current_token_ids,
+ delta_token_ids=delta_token_ids,
+ request=request))
+ # when only tool calls
+ elif tool_choice_auto:
+ assert tool_parser is not None
+ delta_message = (tool_parser.extract_tool_calls_streaming(
+ previous_text=previous_text,
+ current_text=current_text,
+ delta_text=delta_text,
+ previous_token_ids=previous_token_ids,
+ current_token_ids=current_token_ids,
+ delta_token_ids=output.token_ids,
+ request=request))
+ # when only reasoning
+ elif self.enable_reasoning:
+ assert reasoning_parser is not None
+ delta_message = (
+ reasoning_parser.extract_reasoning_content_streaming(
+ previous_text,
+ current_text,
+ delta_text,
+ previous_token_ids,
+ current_token_ids,
+ output.token_ids,
+ ))
+ # handle streaming just a content delta
+ else:
+ delta_message = DeltaMessage(content=delta_text)
+
+ # update the previous values for the next iteration
+ if tool_choice_auto or should_stream_with_reasoning_parsing:
+ assert previous_texts is not None
+ assert all_previous_token_ids is not None
+ previous_texts[i] = current_text
+ all_previous_token_ids[i] = current_token_ids
+
+ # set the previous values for the next iteration
+ previous_num_tokens[i] += len(output.token_ids)
+
+                # if the message delta is None (e.g. because it was a
+                # "control token" for tool calls, or the parser otherwise
+                # wasn't ready to send a token), then
+                # get the next token without streaming a chunk
+ if delta_message is None:
+ continue
+
+ if output.finish_reason is None:
+ # Send token-by-token response for each request.n
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=delta_message,
+ logprobs=logprobs,
+ finish_reason=None)
+
+ # if the model is finished generating
+ else:
+ # check to make sure we haven't "forgotten" to stream
+ # any tokens that were generated but previously
+ # matched by partial json parsing
+ # only happens if we are NOT using guided decoding
+ auto_tools_called = False
+ if tool_parser:
+ auto_tools_called = len(
+ tool_parser.prev_tool_call_arr) > 0
+
+ # Send the finish response for each request.n only once
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=i,
+ delta=delta_message,
+ logprobs=logprobs,
+ finish_reason=output.finish_reason
+ if not auto_tools_called else "tool_calls",
+ stop_reason=output.stop_reason)
+
+ finish_reason_sent[i] = True
+
+ chunk = ChatCompletionStreamResponse(id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
+
+ # handle usage stats if requested & if continuous
+ if include_continuous_usage:
+ completion_tokens = previous_num_tokens[i]
+ chunk.usage = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=num_prompt_tokens + completion_tokens,
+ )
+
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+
+ # once the final token is handled, if stream_options.include_usage
+ # is sent, send the usage
+ if include_usage:
+ completion_tokens = sum(previous_num_tokens)
+ final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=num_prompt_tokens +
+ completion_tokens)
+ if self.enable_prompt_tokens_details and num_cached_tokens:
+ final_usage.prompt_tokens_details = PromptTokenUsageInfo(
+ cached_tokens=num_cached_tokens)
+
+ final_usage_chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[],
+ model=model_name,
+ usage=final_usage)
+ final_usage_data = (final_usage_chunk.model_dump_json(
+ exclude_unset=True, exclude_none=True))
+ yield f"data: {final_usage_data}\n\n"
+
+        # report aggregate usage across all choices to the FastAPI middleware
+ num_completion_tokens = sum(previous_num_tokens)
+ request_metadata.final_usage_info = UsageInfo(
+ prompt_tokens=num_prompt_tokens,
+ completion_tokens=num_completion_tokens,
+ total_tokens=num_prompt_tokens + num_completion_tokens)
+
+ except Exception as e:
+ # TODO: Use a vllm-specific Validation Error
+ logger.exception("Error in chat completion stream generator.")
+ data = self.create_streaming_error_response(str(e))
+ yield f"data: {data}\n\n"
+ # Send the final done message after all response.n are finished
+ yield "data: [DONE]\n\n"
diff --git a/vllm_mindspore/entrypoints/openai/tool_parsers/__init__.py b/vllm_mindspore/entrypoints/openai/tool_parsers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py b/vllm_mindspore/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..c672c2967bbcea093f79e590cb77a5c47c193f0d
--- /dev/null
+++ b/vllm_mindspore/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
@@ -0,0 +1,387 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# SPDX-License-Identifier: Apache-2.0
+"""
+Adapted from
+https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py
+"""
+import re
+from collections.abc import Sequence
+from typing import Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+ DeltaFunctionCall, DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+ ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("deepseek_v3")
+class DeepSeekV3ToolParser(ToolParser):
+
+ def __init__(self, tokenizer: AnyTokenizer):
+ super().__init__(tokenizer)
+
+ self.current_tool_name_sent: bool = False
+ self.prev_tool_call_arr: list[dict] = []
+ self.current_tool_id: int = -1
+ self.streamed_args_for_tool: list[str] = (
+ []) # map what has been streamed for each tool so far to a list
+
+ self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
+ self.tool_calls_end_token: str = "<|tool▁calls▁end|>"
+
+ self.tool_call_start_token: str = "<|tool▁call▁begin|>"
+ self.tool_call_end_token: str = "<|tool▁call▁end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<|tool▁call▁begin|>(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<|tool▁call▁end|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*[^\n`])"
+        )
+
+        self.stream_tool_call_name_regex = re.compile(
+            r"(?P<type>.*)<|tool▁sep|>(?P<function_name>.*)\n")
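+        # Rough sketch of the tool-call text these patterns expect (inferred
+        # from the regexes above, not an authoritative format spec):
+        #   <|tool▁call▁begin|>TYPE<|tool▁sep|>NAME
+        #   ```json
+        #   {"arg": "value"}
+        #   ```<|tool▁call▁end|>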
+
+ if not self.model_tokenizer:
+ raise ValueError(
+ "The model tokenizer must be passed to the ToolParser "
+ "constructor during construction.")
+ self.tool_calls_start_token_id = self.vocab.get(
+ self.tool_calls_start_token)
+ self.tool_calls_end_token_id = self.vocab.get(
+ self.tool_calls_end_token)
+
+ self.tool_call_start_token_id = self.vocab.get(
+ self.tool_call_start_token)
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
+ if (self.tool_calls_start_token_id is None
+ or self.tool_calls_end_token_id is None):
+ raise RuntimeError(
+ "DeepSeek-V3 Tool parser could not locate tool call start/end "
+ "tokens in the tokenizer!")
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+
+ # sanity check; avoid unnecessary processing
+ if self.tool_calls_start_token not in model_output:
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ else:
+ try:
+                # there are two possible captures - between tags, or between
+                # a tag and end-of-string - so the result of findall is an
+                # array of tuples where one element is a function call and
+                # the other is None
+ function_call_tuples = self.tool_call_regex.findall(
+ model_output)
+
+ tool_calls = []
+ for match in function_call_tuples:
+ tool_type, function_name, function_args = match
+ tool_calls.append(
+ ToolCall(
+ type=tool_type,
+ function=FunctionCall(name=function_name,
+ arguments=function_args),
+ ))
+
+ content = model_output[:model_output.
+ find(self.tool_calls_start_token)]
+ return ExtractedToolCallInformation(
+ tools_called=True,
+ tool_calls=tool_calls,
+ content=content if content else None,
+ )
+
+ except Exception:
+ logger.exception(
+ "Error in extracting tool call from response.")
+ return ExtractedToolCallInformation(tools_called=False,
+ tool_calls=[],
+ content=model_output)
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> Union[DeltaMessage, None]:
+
+ logger.debug("delta_text: %s", delta_text)
+ logger.debug("delta_token_ids: %s", delta_token_ids)
+        # check to see if we should be streaming a tool call - is there a
+        # tool-calls start token in the tokens generated so far?
+ if self.tool_calls_start_token_id not in current_token_ids:
+ logger.debug("No tool call tokens found!")
+ return DeltaMessage(content=delta_text)
+ delta_text = delta_text.replace(self.tool_calls_start_token,
+ "").replace(self.tool_calls_end_token,
+ "")
+ try:
+
+ # figure out where we are in the parsing by counting tool call
+ # start & end tags
+ prev_tool_start_count = previous_token_ids.count(
+ self.tool_call_start_token_id)
+ prev_tool_end_count = previous_token_ids.count(
+ self.tool_call_end_token_id)
+ cur_tool_start_count = current_token_ids.count(
+ self.tool_call_start_token_id)
+ cur_tool_end_count = current_token_ids.count(
+ self.tool_call_end_token_id)
+ tool_call_portion = None
+ text_portion = None
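+            # tool_call_portion: text of the tool call currently being
+            # streamed; text_portion: plain text that follows it, if any.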
+
+ # case: if we're generating text, OR rounding out a tool call
+ if (cur_tool_start_count == cur_tool_end_count
+ and prev_tool_end_count == cur_tool_end_count
+ and self.tool_call_end_token not in delta_text):
+ logger.debug("Generating text content! skipping tool parsing.")
+ return DeltaMessage(content=delta_text)
+
+ if self.tool_call_end_token in delta_text:
+ logger.debug("tool_call_end_token in delta_text")
+ full_text = current_text + delta_text
+ tool_call_portion = full_text.split(
+ self.tool_call_start_token)[-1].split(
+ self.tool_call_end_token)[0].rstrip()
+ delta_text = delta_text.split(
+ self.tool_call_end_token)[0].rstrip()
+ text_portion = delta_text.split(
+ self.tool_call_end_token)[-1].lstrip()
+
+ # case -- we're starting a new tool call
+ if (cur_tool_start_count > cur_tool_end_count
+ and cur_tool_start_count > prev_tool_start_count):
+ if len(delta_token_ids) > 1:
+ tool_call_portion = current_text.split(
+ self.tool_call_start_token)[-1]
+ else:
+ tool_call_portion = None
+ delta = None
+
+ text_portion = None
+
+ # set cursors and state appropriately
+ self.current_tool_id += 1
+ self.current_tool_name_sent = False
+ self.streamed_args_for_tool.append("")
+ logger.debug("Starting on a new tool %s", self.current_tool_id)
+
+ # case -- we're updating an existing tool call
+ elif (cur_tool_start_count > cur_tool_end_count
+ and cur_tool_start_count == prev_tool_start_count):
+
+ # get the portion of the text that's the tool call
+ tool_call_portion = current_text.split(
+ self.tool_call_start_token)[-1]
+ text_portion = None
+
+ # case -- the current tool call is being closed.
+ elif (cur_tool_start_count == cur_tool_end_count
+ and cur_tool_end_count >= prev_tool_end_count):
+ if self.prev_tool_call_arr is None or len(
+ self.prev_tool_call_arr) == 0:
+ logger.debug(
+ "attempting to close tool call, but no tool call")
+ return None
+ diff = self.prev_tool_call_arr[self.current_tool_id].get(
+ "arguments")
+ if diff:
+                    diff = (diff.encode("utf-8").decode("unicode_escape")
+                            if isinstance(diff, str) else diff)
+ if '}' not in delta_text:
+ return None
+ end_loc = delta_text.rindex('}')
+ diff = delta_text[:end_loc] + '}'
+ logger.debug(
+ "Finishing tool and found diff that had not "
+ "been streamed yet: %s",
+ diff,
+ )
+ self.streamed_args_for_tool[self.current_tool_id] += diff
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=diff).model_dump(exclude_none=True),
+ )
+ ])
+
+ # case -- otherwise we're just generating text
+ else:
+ text = delta_text.replace(self.tool_call_start_token, "")
+ text = text.replace(self.tool_call_end_token, "")
+ delta = DeltaMessage(tool_calls=[], content=text)
+ return delta
+
+ current_tool_call = dict()
+ if tool_call_portion:
+ current_tool_call_matches = (
+ self.stream_tool_call_portion_regex.match(
+ tool_call_portion))
+ if current_tool_call_matches:
+ tool_type, tool_name, tool_args = (
+ current_tool_call_matches.groups())
+ current_tool_call["name"] = tool_name
+ current_tool_call["arguments"] = tool_args
+ else:
+ current_tool_call_name_matches = (
+ self.stream_tool_call_name_regex.match(
+ tool_call_portion))
+ if current_tool_call_name_matches:
+ tool_type, tool_name = (
+ current_tool_call_name_matches.groups())
+ current_tool_call["name"] = tool_name
+ current_tool_call["arguments"] = ""
+ else:
+                        logger.debug("Not enough tokens yet")
+ return None
+
+            # case - we haven't sent the tool name yet. If it's available,
+            # send it; otherwise, wait until it is.
+ if not self.current_tool_name_sent:
+ if current_tool_call is None:
+ return None
+ function_name: Union[str, None] = current_tool_call.get("name")
+ if function_name:
+ self.current_tool_name_sent = True
+ return DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ type="function",
+ id=f"chatcmpl-tool-{random_uuid()}",
+ function=DeltaFunctionCall(
+ name=function_name).model_dump(
+ exclude_none=True),
+ )
+ ])
+ else:
+ return None
+
+ # case -- otherwise, send the tool call delta
+
+ # if the tool call portion is None, send the delta as text
+ if tool_call_portion is None:
+ # if there's text but not tool calls, send that -
+ # otherwise None to skip chunk
+ delta = (DeltaMessage(
+ content=delta_text) if text_portion is not None else None)
+ return delta
+
+ # now, the nitty-gritty of tool calls
+ # now we have the portion to parse as tool call.
+
+ logger.debug("Trying to parse current tool call with ID %s",
+ self.current_tool_id)
+
+ # if we're starting a new tool call, push an empty object in as
+ # a placeholder for the arguments
+ if len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+
+ # main logic for tool parsing here - compare prev. partially-parsed
+ # JSON to the current partially-parsed JSON
+ prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
+ "arguments")
+ cur_arguments = current_tool_call.get("arguments")
+
+ logger.debug("diffing old arguments: %s", prev_arguments)
+ logger.debug("against new ones: %s", cur_arguments)
+
+ # case -- no arguments have been created yet. skip sending a delta.
+ if not cur_arguments and not prev_arguments:
+ logger.debug("Skipping text %s - no arguments", delta_text)
+ delta = None
+
+            # case -- prev arguments are defined, but none are now.
+ # probably impossible, but not a fatal error - just keep going
+ elif not cur_arguments and prev_arguments:
+ logger.error("should be impossible to have arguments reset "
+ "mid-call. skipping streaming anything.")
+ delta = None
+
+ # case -- we now have the first info about arguments available from
+ # autocompleting the JSON
+ elif cur_arguments and not prev_arguments:
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=cur_arguments).model_dump(
+ exclude_none=True),
+ )
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] = cur_arguments
+
+ # last case -- we have an update to existing arguments.
+ elif cur_arguments and prev_arguments:
+ if (isinstance(delta_text, str)
+ and cur_arguments != prev_arguments
+ and len(cur_arguments) > len(prev_arguments)
+ and cur_arguments.startswith(prev_arguments)):
+ delta_arguments = cur_arguments[len(prev_arguments):]
+ logger.debug("got diff %s", delta_text)
+
+ delta = DeltaMessage(tool_calls=[
+ DeltaToolCall(
+ index=self.current_tool_id,
+ function=DeltaFunctionCall(
+ arguments=delta_arguments).model_dump(
+ exclude_none=True),
+ )
+ ])
+ self.streamed_args_for_tool[
+ self.current_tool_id] = cur_arguments
+ else:
+ delta = None
+
+ # handle saving the state for the current tool into
+ # the "prev" list for use in diffing for the next iteration
+ if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
+ self.prev_tool_call_arr[
+ self.current_tool_id] = current_tool_call
+ else:
+ self.prev_tool_call_arr.append(current_tool_call)
+
+ return delta
+
+ except Exception:
+ logger.exception("Error trying to handle streaming tool call.")
+ return None # do not stream a delta. skip this token ID.
diff --git a/vllm_mindspore/lora/__init__.py b/vllm_mindspore/lora/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/lora/layers.py b/vllm_mindspore/lora/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..19a132c03f193845fe4e4f991d638fb2615989b3
--- /dev/null
+++ b/vllm_mindspore/lora/layers.py
@@ -0,0 +1,1165 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+
+import mindspore as ms
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PretrainedConfig
+from vllm.adapter_commons.layers import AdapterMapping
+from vllm.config import LoRAConfig
+from vllm.distributed import (get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ split_tensor_along_last_dim,
+ tensor_model_parallel_all_gather,
+ tensor_model_parallel_all_reduce)
+from vllm.distributed.utils import divide
+# yapf: enable
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.rotary_embedding import (
+ LinearScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+ VocabParallelEmbedding
+
+# yapf: disable
+from vllm_mindspore.model_executor.layers.linear import (
+ ColumnParallelLinear, LinearBase, MergedColumnParallelLinear,
+ QKVParallelLinear, RowParallelLinear)
+
+if TYPE_CHECKING:
+ from vllm.lora.punica_wrapper import PunicaWrapperBase
+
+
+def _get_lora_device(base_layer: nn.Module) -> torch.device:
+ # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
+ """Returns the device for where to place the LoRA tensors."""
+ # unquantizedLinear
+ if hasattr(base_layer, "weight"):
+ return base_layer.weight.device
+ # Compressed Tensor
+ elif hasattr(base_layer, "weight_packed"):
+ return base_layer.weight_packed.device
+ # GPTQ/AWQ
+ elif hasattr(base_layer, "qweight"):
+ return base_layer.qweight.device
+ # marlin
+ elif hasattr(base_layer, "B"):
+ return base_layer.B.device
+ # HQQ marlin
+ elif hasattr(base_layer, "W_q"):
+ return base_layer.W_q.device
+ else:
+ raise ValueError(f"Unsupported base layer: {base_layer}")
+
+
+def _not_fully_sharded_can_replace(can_replace):
+ """
+    Decorator that adds the condition of not using fully sharded LoRAs;
+    intended to wrap can_replace_layer().
+ """
+
+ def dec(*args, **kwargs):
+ decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
+ condition = (not kwargs["lora_config"].fully_sharded_loras
+ if decorate else True)
+ return can_replace(*args, **kwargs) and condition
+
+ return dec
+
+
+@dataclass
+class LoRAMapping(AdapterMapping):
+ is_prefill: bool = False
+
+
+# vllm-mindspore: inherits from ms.nn.Cell instead of torch.nn.Module
+class BaseLayerWithLoRA(ms.nn.Cell):
+
+ def slice_lora_a(
+ self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+ ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
+ """Slice lora a if splitting for tensor parallelism."""
+ ...
+
+ def slice_lora_b(
+ self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+ ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
+ """Slice lora b if splitting with tensor parallelism."""
+ ...
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+ ) -> None:
+ """Initializes lora matrices."""
+ ...
+
+ def reset_lora(self, index: int):
+ """Resets the lora weights at index back to 0."""
+ ...
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ embeddings_tensor: Optional[torch.Tensor],
+ bias: Optional[torch.Tensor] = None,
+ ):
+ """Overwrites lora tensors at index."""
+ ...
+
+ def set_mapping(
+ self,
+ punica_wrapper,
+ ):
+ self.punica_wrapper: PunicaWrapperBase = punica_wrapper
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ raise NotImplementedError
+
+
+class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+
+ def __init__(self, base_layer: VocabParallelEmbedding) -> None:
+ super().__init__()
+ self.base_layer = base_layer
+ self.embeddings_slice: Optional[Tuple[int, int]]
+ self.embeddings_weights: Optional[torch.Tensor]
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None) -> None:
+
+ if self.base_layer.num_added_embeddings_per_partition > 0:
+ # We can start adding lora weights
+ self.embeddings_weights = self.base_layer.weight.data[
+ self.base_layer.num_org_embeddings_per_partition:self.
+ base_layer.num_org_embeddings_per_partition +
+ self.base_layer.num_added_embeddings_per_partition]
+ self.embeddings_slice = (
+ self.base_layer.shard_indices.added_vocab_start_index -
+ self.base_layer.org_vocab_size,
+ self.base_layer.shard_indices.added_vocab_end_index -
+ self.base_layer.org_vocab_size)
+ self.base_layer.weight.data[
+ self.base_layer.num_org_embeddings_per_partition:].fill_(0)
+ else:
+ self.embeddings_slice = None
+ self.embeddings_weights = None
+
+ self.embeddings_tensors = torch.zeros(
+ (
+ max_loras,
+ lora_config.lora_extra_vocab_size,
+ self.base_layer.embedding_dim,
+ ),
+ dtype=self.base_layer.weight.dtype,
+ device=self.base_layer.weight.device,
+ )
+ self.lora_a_stacked = torch.zeros(
+ (
+ max_loras,
+ self.base_layer.org_vocab_size +
+ lora_config.lora_extra_vocab_size,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.base_layer.weight.device,
+ )
+ self.lora_b_stacked = torch.zeros(
+ (
+ max_loras,
+ 1,
+ self.base_layer.embedding_dim,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.base_layer.weight.device,
+ )
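+        # Flatten (max_loras, vocab, rank) -> (max_loras * vocab, rank) so
+        # construct() can look up every adapter's LoRA-A rows with a single
+        # F.embedding call.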
+ self.lora_a_stacked_2d = self.lora_a_stacked.view(
+ self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
+ self.lora_a_stacked.shape[2],
+ )
+
+ def reset_lora(self, index: int):
+ self.lora_a_stacked[index] = 0
+ self.lora_b_stacked[index] = 0
+ self.embeddings_tensors[index] = 0
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ embeddings_tensor: Optional[torch.Tensor],
+ bias: Optional[torch.Tensor] = None,
+ ):
+ self.reset_lora(index)
+ self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
+ lora_a, non_blocking=True)
+ self.lora_b_stacked[index,
+ 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+ lora_b.T, non_blocking=True)
+ if embeddings_tensor is not None:
+ self.embeddings_tensors[
+ index, :embeddings_tensor.shape[0], :embeddings_tensor.
+ shape[1], ].copy_(embeddings_tensor, non_blocking=True)
+ if self.embeddings_slice is not None:
+ # TODO(yard1): Optimize this copy, we don't need to copy
+ # everything, just the modified part
+ embeddings = self.embeddings_tensors.view(
+ self.embeddings_tensors.shape[0] *
+ self.embeddings_tensors.shape[1],
+ self.embeddings_tensors.shape[2],
+ )[self.embeddings_slice[0]:self.embeddings_slice[1]]
+ assert self.embeddings_weights is not None
+ self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
+
+ def construct(self, x: torch.Tensor) -> torch.Tensor:
+ added_tokens_mask = x > self.base_layer.org_vocab_size - 1
+ embeddings_indices = self.punica_wrapper.embeddings_indices
+ indices = embeddings_indices[1].view_as(x)
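+        # The per-token offsets from the punica wrapper shift each token id
+        # into its adapter's block of the flattened lora_a_stacked_2d table.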
+ full_lora_a_embeddings = F.embedding(
+ x + indices,
+ self.lora_a_stacked_2d,
+ )
+ indices = embeddings_indices[0].view_as(x)
+ full_output = self.base_layer.forward(
+ x.add_(indices * added_tokens_mask))
+
+ full_output_org = full_output
+ if full_output.ndim == 3:
+ full_output = full_output.view(
+ full_output.shape[0] * full_output.shape[1], -1)
+ if full_lora_a_embeddings.ndim == 3:
+ full_lora_a_embeddings = full_lora_a_embeddings.view(
+ full_lora_a_embeddings.shape[0] *
+ full_lora_a_embeddings.shape[1],
+ -1,
+ )
+
+ full_output = self.punica_wrapper.add_lora_embedding(
+ full_output,
+ full_lora_a_embeddings,
+ self.lora_b_stacked,
+ add_input=True)
+ return full_output.view_as(full_output_org)
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return type(source_layer) is VocabParallelEmbedding
+
+
+class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
+
+ def __init__(self, base_layer: LinearBase):
+ super().__init__()
+ self.base_layer = base_layer
+ self.input_size = self.base_layer.input_size
+ self.device = _get_lora_device(self.base_layer)
+ self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None
+
+ self.output_slices: Tuple[int, ...]
+ self.tp_size: int
+ self.output_size: int
+ self.n_slices: int
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+ ) -> None:
+ self.lora_config = lora_config
+
+ if isinstance(self.base_layer, ColumnParallelLinear):
+ lora_a_out_size = (lora_config.max_lora_rank if
+ not lora_config.fully_sharded_loras else divide(
+ lora_config.max_lora_rank, self.tp_size))
+ lora_b_out_size = self.output_size
+
+ elif isinstance(self.base_layer, RowParallelLinear):
+ lora_a_out_size = lora_config.max_lora_rank
+ lora_b_out_size = (self.output_size if
+ not lora_config.fully_sharded_loras else divide(
+ self.output_size, self.tp_size))
+ else:
+ raise NotImplementedError
+
+ self.lora_a_stacked = tuple(
+ torch.zeros(
+ max_loras,
+ 1,
+ lora_a_out_size,
+ self.input_size,
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ) for _ in range(self.n_slices))
+ self.lora_b_stacked = tuple(
+ torch.zeros(
+ max_loras,
+ 1,
+ lora_b_out_size,
+ lora_config.max_lora_rank,
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ) for _ in range(self.n_slices))
+ if lora_config.bias_enabled:
+ lora_bias_out_size = lora_b_out_size
+ self.lora_bias_stacked = tuple(
+ torch.zeros(
+ max_loras,
+ 1,
+ lora_bias_out_size,
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ) for _ in range(self.n_slices))
+ self.output_slices = (self.lora_b_stacked[0].shape[2], )
+
+ def reset_lora(self, index: int):
+ for s_index in range(self.n_slices):
+ self.lora_a_stacked[s_index][index] = 0
+ self.lora_b_stacked[s_index][index] = 0
+ if self.lora_config.bias_enabled:
+ # Make mypy happy
+ self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+ self.lora_bias_stacked)
+ self.lora_bias_stacked[s_index][index] = 0
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ embeddings_tensor: Optional[torch.Tensor],
+ lora_bias: Optional[torch.Tensor] = None,
+ ):
+ # Except for QKVParallelLinearWithLora and
+ # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
+ # store weights in a tuple of size 1. These two layers will
+ # override this function.
+ assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
+ self.n_slices == 1)
+
+ self.reset_lora(index)
+ if self.tp_size > 1:
+ lora_a = self.slice_lora_a(lora_a)
+ lora_b = self.slice_lora_b(lora_b)
+ if lora_bias is not None:
+ lora_bias = self.slice_bias(lora_bias)
+
+ self.lora_a_stacked[0][index,
+ 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+ lora_a.T, non_blocking=True)
+ self.lora_b_stacked[0][index,
+ 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+ lora_b.T, non_blocking=True)
+ if lora_bias is not None:
+
+ self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+ self.lora_bias_stacked)
+ assert len(self.lora_bias_stacked)
+ self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
+ lora_bias.T, non_blocking=True)
+
+ def apply(self,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
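+        # add_lora_linear accumulates the LoRA contribution into `output`
+        # in place, on top of the base layer's result.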
+ self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
+ self.lora_b_stacked,
+ self.lora_bias_stacked, 1.0,
+ self.output_slices)
+ return output
+
+
+class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
+ """
+ LoRA on top of ColumnParallelLinear layer.
+ LoRA B is sliced for tensor parallelism.
+ There are two types for the `base_layer`:
+ 1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
+ 2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
+ """
+
+ def __init__(self, base_layer: ColumnParallelLinear) -> None:
+ super().__init__(base_layer)
+ # The base_layer type is ColumnParallelLinear or
+ # MergedColumnParallelLinear, their weight sharding logic is
+ # inconsistent when TP is greater than 1.
+ self.is_merged_col_linear = type(
+ base_layer) is MergedColumnParallelLinear
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.output_size = self.base_layer.output_size_per_partition
+ # There is only one LoRA layer
+ self.n_slices = 1
+
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+ return lora_a
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ # Applicable to cases where the base_layer is
+ # MergedColumnParallelLinear.
+ if self.is_merged_col_linear:
+ tp_rank = get_tensor_model_parallel_rank()
+ shard_size = self.output_size // 2
+ offset = lora_b.shape[-1] // 2
+
+ left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
+ shard_size]
+ right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
+ (tp_rank + 1) * shard_size]
+ lora_b = torch.cat([left_weight, right_weight], dim=1)
+ # Applicable to cases where the base_layer is
+ # ColumnParallelLinear.
+ else:
+ tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+ shard_size = self.output_size
+ start_idx = tensor_model_parallel_rank * shard_size
+ end_idx = (tensor_model_parallel_rank + 1) * shard_size
+ lora_b = lora_b[:, start_idx:end_idx]
+ return lora_b
+
+ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+ # TODO: Fix the slicing logic of bias.
+ if bias is None:
+ return bias
+ tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+ shard_size = self.output_size
+ start_idx = tensor_model_parallel_rank * shard_size
+ end_idx = (tensor_model_parallel_rank + 1) * shard_size
+ bias = bias[start_idx:end_idx]
+ return bias
+
+ def construct(
+ self, input_: torch.Tensor
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """Forward of ColumnParallelLinear
+
+ Args:
+ input_: Tensor whose last dimension is `input_size`.
+
+ Returns:
+ - output
+ - bias
+ """
+ bias = (self.base_layer.bias
+ if not self.base_layer.skip_bias_add else None)
+
+ # Matrix multiply.
+ output_parallel = self.apply(input_, bias)
+ if self.base_layer.gather_output:
+ # All-gather across the partitions.
+ output = tensor_model_parallel_all_gather(output_parallel)
+ else:
+ output = output_parallel
+ output_bias = (self.base_layer.bias
+ if self.base_layer.skip_bias_add else None)
+ return output, output_bias
+
+ @classmethod
+ @_not_fully_sharded_can_replace
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return type(source_layer) is ColumnParallelLinear or (
+ type(source_layer) is MergedColumnParallelLinear
+ and len(packed_modules_list) == 1)
+
+
+class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+ """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
+    packed together (e.g. gate_proj + up_proj -> gate_up_proj).
+
+ This means we have 2 LoRAs, each applied to one half of the layer.
+
+ Both slices must have the same size.
+ """
+
+ def __init__(
+ self, base_layer: Union[MergedColumnParallelLinear,
+ QKVParallelLinear]) -> None:
+ super().__init__(base_layer)
+ # There are two LoRA layers
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+        # the output_sizes in MergedColumnParallelLinear are not sharded by
+        # TP, so we divide them by tp_size to get the correct slice sizes
+ output_sizes = self.base_layer.output_sizes
+ self.output_slices = tuple(
+ divide(output_size, self.tp_size) for output_size in output_sizes)
+ self.n_slices = len(self.output_slices)
+ self.output_ids = (self.tp_rank, ) * self.n_slices
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+ ) -> None:
+ """
+ The main reason for overriding this function is to enhance code
+ maintainability.
+ """
+ self.lora_config = lora_config
+
+ lora_a_output_size_per_partition = (
+ lora_config.max_lora_rank if not lora_config.fully_sharded_loras
+ else divide(lora_config.max_lora_rank, self.tp_size))
+ self.lora_a_stacked = tuple(
+ torch.zeros(
+ max_loras,
+ 1,
+ lora_a_output_size_per_partition,
+ self.input_size,
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ) for _ in range(self.n_slices))
+ self.lora_b_stacked = tuple(
+ torch.zeros(
+ max_loras,
+ 1,
+ output_size,
+ lora_config.max_lora_rank,
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ) for output_size in self.output_slices)
+ if lora_config.bias_enabled:
+ self.lora_bias_stacked = tuple(
+ torch.zeros(
+ max_loras,
+ 1,
+ output_size,
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ ) for output_size in self.output_slices)
+
+ def slice_lora_a(
+ self, lora_a: List[Union[torch.Tensor, None]]
+ ) -> List[Union[torch.Tensor, None]]:
+ return lora_a
+
+ def slice_lora_b(
+ self, lora_b: List[Union[torch.Tensor, None]]
+ ) -> List[Union[torch.Tensor, None]]:
+ for i, (shard_id, shard_size) in enumerate(
+ zip(self.output_ids, self.output_slices)):
+ if (lora_b_i := lora_b[i]) is not None:
+ lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
+ (shard_id + 1)]
+ return lora_b
+
+ def slice_bias(
+ self, bias: List[Union[torch.Tensor,
+ None]]) -> List[Union[torch.Tensor, None]]:
+ for i, (shard_id, shard_size) in enumerate(
+ zip(self.output_ids, self.output_slices)):
+ if (bias_i := bias[i]) is not None:
+ bias[i] = bias_i[shard_size * shard_id:shard_size *
+ (shard_id + 1)]
+ return bias
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ embeddings_tensor: Optional[torch.Tensor],
+ lora_bias: Optional[torch.Tensor] = None,
+ ):
+ self.reset_lora(index)
+
+ if self.tp_size > 1:
+ lora_a = self.slice_lora_a(lora_a)
+ lora_b = self.slice_lora_b(lora_b)
+ if lora_bias is not None:
+ lora_bias = self.slice_bias(lora_bias)
+
+ for i in range(self.n_slices):
+ if (lora_a_i := lora_a[i]) is not None:
+ self.lora_a_stacked[i][
+ index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
+ lora_a_i.T, non_blocking=True)
+ if (lora_b_i := lora_b[i]) is not None:
+ self.lora_b_stacked[i][
+ index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
+ lora_b_i.T, non_blocking=True)
+
+ if lora_bias is not None:
+ self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+ self.lora_bias_stacked)
+ for i in range(self.n_slices):
+ if (lora_bias_i := lora_bias[i]) is not None:
+ self.lora_bias_stacked[i][index,
+ 0, :lora_bias_i.shape[0]].copy_(
+ lora_bias_i.T,
+ non_blocking=True)
+
+ @classmethod
+ @_not_fully_sharded_can_replace
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return (type(source_layer) is MergedColumnParallelLinear
+ and len(packed_modules_list) == 2)
+
+
+class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
+ """
+ ColumnParallelLinear layer that is specifically designed for
+ qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
+    only contain a single LoRA within their qkv_proj layer.
+
+ During inference with Tensor Parallel, the weights of lora_b
+ must be accurately partitioned according to the respective ranks.
+
+ Q slice may have different shape than K and V slices (which both have
+ the same shape).
+ """
+
+ def __init__(self, base_layer: QKVParallelLinear) -> None:
+ super().__init__(base_layer)
+ self.q_proj_total_size = (self.base_layer.total_num_heads *
+ self.base_layer.head_size)
+ self.q_proj_shard_size = (self.base_layer.num_heads *
+ self.base_layer.head_size)
+ self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
+ self.base_layer.head_size)
+ self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
+ self.base_layer.head_size)
+ # There is only one LoRA layer
+ self.n_slices = 1
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ tp_rank = get_tensor_model_parallel_rank()
+ self.q_shard_id = tp_rank
+ self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
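+        # lora_b columns are laid out as [q | k | v] over the full, unsharded
+        # output dimension; take this rank's shard of each block.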
+ lora_b_q = lora_b[:, self.q_proj_shard_size *
+ self.q_shard_id:self.q_proj_shard_size *
+ (self.q_shard_id + 1)]
+ k_offset = self.q_proj_total_size
+ lora_b_k = lora_b[:, k_offset +
+ self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+ v_offset = k_offset + self.kv_proj_total_size
+ lora_b_v = lora_b[:, v_offset +
+ self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+ lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+ return lora_b
+
+ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+ bias_q = bias[self.q_proj_shard_size *
+ self.q_shard_id:self.q_proj_shard_size *
+ (self.q_shard_id + 1)]
+ k_offset = self.q_proj_total_size
+ bias_k = bias[k_offset +
+ self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+ v_offset = k_offset + self.kv_proj_total_size
+ bias_v = bias[v_offset +
+ self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+ self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+ bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
+ return bias
+
+ @classmethod
+ @_not_fully_sharded_can_replace
+ def can_replace_layer(cls, source_layer: nn.Module,
+ lora_config: LoRAConfig, packed_modules_list: List,
+ model_config: Optional[PretrainedConfig]) -> bool:
+ return type(source_layer) is QKVParallelLinear and len(
+ packed_modules_list) == 1
+
+
+class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
+ """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
+ packed together in qkv proj fashion
+ (q_proj + k_proj + v_proj -> qkv_proj).
+
+ This means we have 3 LoRAs, each applied to one slice of the layer.
+
+ Q slice may have different shape than K and V slices (which both have
+ the same shape).
+ """
+
+ def __init__(self, base_layer: QKVParallelLinear) -> None:
+ super().__init__(base_layer)
+        # There are three LoRA layers.
+ self.n_slices = len(self.base_layer.output_sizes)
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+
+ self.q_proj_shard_size = (self.base_layer.num_heads *
+ self.base_layer.head_size)
+ self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
+ self.base_layer.head_size)
+ self.q_shard_id = self.tp_rank
+ self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
+
+ self.output_slices = (
+ self.q_proj_shard_size,
+ self.kv_proj_shard_size,
+ self.kv_proj_shard_size,
+ )
+ self.output_ids = (
+ self.q_shard_id,
+ self.kv_shard_id,
+ self.kv_shard_id,
+ )
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+ ) -> None:
+ """
+ The main reason for overloading this function is to handle inconsistent
+ weight dimensions in qkv lora.
+ """
+ super().create_lora_weights(max_loras, lora_config, model_config)
+
+ @classmethod
+ @_not_fully_sharded_can_replace
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return (type(source_layer) is QKVParallelLinear
+ and len(packed_modules_list) == 3)
+
+
+class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
+
+ def __init__(self, base_layer: RowParallelLinear) -> None:
+ super().__init__(base_layer)
+
+ self.tp_size = get_tensor_model_parallel_world_size()
+ # reset input_size
+ self.input_size = self.base_layer.input_size_per_partition
+ self.output_size = self.base_layer.output_size
+
+ self.tp_rank = get_tensor_model_parallel_rank()
+ # There is only one LoRA layer.
+ self.n_slices = 1
+
+ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+
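+        # The input of RowParallelLinear is sharded, so LoRA A is sliced
+        # along its input (row) dimension to match this rank's shard.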
+ shard_size = self.input_size
+ start_idx = self.tp_rank * shard_size
+ end_idx = (self.tp_rank + 1) * shard_size
+ lora_a = lora_a[start_idx:end_idx, :]
+ return lora_a
+
+ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+ return lora_b
+
+ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+ return bias
+
+ def construct(
+ self, input_: torch.Tensor
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """Forward of RowParallelLinear
+
+ Args:
+ input_: tensor whose last dimension is `input_size`. If
+ `input_is_parallel` is set, then the last dimension
+ is `input_size // tp_size`.
+
+ Returns:
+ - output
+ - bias
+ """
+ # Set up backprop all-reduce.
+ if self.base_layer.input_is_parallel:
+ input_parallel = input_
+ else:
+ # TODO: simplify code below
+ splitted_input = split_tensor_along_last_dim(
+ input_, num_partitions=self.base_layer.tp_size)
+ input_parallel = splitted_input[self.tp_rank].contiguous()
+
+ # Matrix multiply.
+ output_parallel = self.apply(input_parallel)
+ if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+ output_ = tensor_model_parallel_all_reduce(output_parallel)
+ else:
+ output_ = output_parallel
+
+ if not self.base_layer.skip_bias_add:
+ output = (output_ + self.base_layer.bias
+ if self.base_layer.bias is not None else output_)
+ output_bias = None
+ else:
+ output = output_
+ output_bias = self.base_layer.bias
+ return output, output_bias
+
+ @property
+ def weight(self):
+ return (self.base_layer.weight if hasattr(self.base_layer, "weight")
+ else self.base_layer.qweight)
+
+ @classmethod
+ @_not_fully_sharded_can_replace
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ return type(source_layer) is RowParallelLinear
+
+
+class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
+ """
+ LoRA wrapper for LogitsProcessor, with extra logic to handle the
+ application of the LoRA adapter and added LoRA vocabulary.
+
+ Args:
+ base_layer: LogitsProcessor layer
+ hidden_size: hidden size of the model
+ dtype: data type of the model
+ device: device of the model
+ sharded_to_full_mapping: index mapping from sharded vocab to full vocab
+ received from base_layer.get_sharded_to_full_mapping(). If None,
+ no reindexing will be done.
+ """
+
+ def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
+ dtype: torch.dtype, device: torch.device,
+ sharded_to_full_mapping: Optional[List[int]]) -> None:
+ super().__init__()
+ self.base_layer = base_layer
+ self.hidden_size = hidden_size
+ self.dtype = dtype
+ self.device = device
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.sharded_to_full_mapping = sharded_to_full_mapping
+
+ @property
+ def logits_as_input(self):
+ return self.base_layer.logits_as_input
+
+ @property
+ def vocab_size(self):
+ return self.base_layer.vocab_size
+
+ @property
+ def scale(self):
+ return self.base_layer.scale
+
+ @property
+ def soft_cap(self):
+ return self.base_layer.soft_cap
+
+ @property
+ def use_all_gather(self):
+ return self.base_layer.use_all_gather
+
+ @property
+ def org_vocab_size(self):
+ return self.base_layer.org_vocab_size
+
+ @property
+ def include_gpu_probs_tensor(self):
+ return self.base_layer.include_gpu_probs_tensor
+
+ @property
+ def should_modify_greedy_probs_inplace(self):
+ return self.base_layer.should_modify_greedy_probs_inplace
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+ ) -> None:
+ # TODO: Verify if this condition can be further relaxed
+ if not 32000 <= self.base_layer.vocab_size <= 257024:
+ raise ValueError("When using LoRA, vocab size must satisfy "
+ "32000 <= vocab_size <= 257024")
+ self.lora_a_stacked = torch.zeros(
+ (
+ max_loras,
+ 1,
+ lora_config.max_lora_rank,
+ self.hidden_size,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
+ self.lora_b_stacked = torch.zeros(
+ (
+ max_loras,
+ 1,
+ # Pad for kernel compatibility
+ math.ceil(self.base_layer.vocab_size /
+ lora_config.lora_vocab_padding_size) *
+ lora_config.lora_vocab_padding_size,
+ lora_config.max_lora_rank,
+ ),
+ dtype=lora_config.lora_dtype,
+ device=self.device,
+ )
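+ # Padding example for the dimension above (illustrative): with
+ # vocab_size=32100 and lora_vocab_padding_size=256, the padded
+ # dimension is ceil(32100 / 256) * 256 = 32256.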
+ self.embeddings_tensors = torch.full(
+ (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
+ fill_value=float("-inf"),
+ dtype=self.dtype,
+ device=self.device,
+ )
+ if self.sharded_to_full_mapping is not None:
+ self.sharded_to_full_mapping_gpu = torch.tensor(
+ self.sharded_to_full_mapping,
+ device=self.device,
+ dtype=torch.long)
+ else:
+ self.sharded_to_full_mapping_gpu = None
+
+ def reset_lora(self, index: int):
+ self.lora_a_stacked[index] = 0
+ self.lora_b_stacked[index] = 0
+ self.embeddings_tensors[index] = float("-inf")
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ embeddings_tensor: Optional[torch.Tensor],
+ bias: Optional[torch.Tensor] = None,
+ ):
+ self.reset_lora(index)
+ self.lora_a_stacked[index,
+ 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+ lora_a.T, non_blocking=True)
+ self.lora_b_stacked[index,
+ 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+ lora_b.T, non_blocking=True)
+ if embeddings_tensor is not None:
+ self.embeddings_tensors[
+ index, :embeddings_tensor.shape[0], :embeddings_tensor.
+ shape[1], ] = embeddings_tensor
+
+ def _get_logits(
+ self,
+ hidden_states: torch.Tensor,
+ lm_head: VocabParallelEmbedding,
+ embedding_bias: Optional[torch.Tensor] = None,
+ ) -> Optional[torch.Tensor]:
+ # Get the logits for the next tokens.
+ logits = lm_head.quant_method.apply(lm_head, hidden_states)
+ if embedding_bias is not None:
+ logits += embedding_bias
+
+ # Gather logits for TP
+ logits = self.base_layer._gather_logits(logits)
+
+ if logits is None:
+ return None
+
+ if self.sharded_to_full_mapping_gpu is not None:
+ # Reindex full logits tensor to ensure 1:1 mapping between
+ # index and token_id
+ # Example for:
+ # org_vocab_size = 4
+ # added_vocab_size = 2
+ # pad_to_size = 8
+ # tp_size = 2
+
+ # indices: [0, 1, 2, 3, 4, 5, 6, 7]
+ # token_id: [0, 1, 4, -1, 2, 3, 5, -1]
+
+ # Therefore, the mapping is expected to be:
+ # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
+ # we get:
+ # indices: [0, 1, 2, 3, 4, 5, 6, 7]
+ # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
+ logits = logits[:, self.sharded_to_full_mapping_gpu]
+
+ lora_logits = torch.empty(
+ self.embeddings_tensors.shape[0] + 1,
+ self.embeddings_tensors.shape[1],
+ hidden_states.shape[0],
+ dtype=self.embeddings_tensors.dtype,
+ device=self.embeddings_tensors.device,
+ )
+ torch.matmul(self.embeddings_tensors,
+ hidden_states.T,
+ out=lora_logits[:-1])
+ lora_logits[-1] = float("-inf")
+ lora_logits = lora_logits.mT
+ indices_padded = self.punica_wrapper.sampler_indices_padded
+ lora_logits = (lora_logits.reshape(
+ lora_logits.shape[0] * lora_logits.shape[1],
+ lora_logits.shape[2],
+ ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
+ posinf=float("inf"),
+ neginf=float("-inf")))
+
+ logits[:,
+ self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
+ lora_logits.shape[1]] = lora_logits
+
+ # LogitsProcessorWithLoRA always uses bgmv.
+ self.punica_wrapper.add_lora_logits(logits, hidden_states,
+ self.lora_a_stacked,
+ self.lora_b_stacked, 1.0)
+
+ # Remove paddings in vocab (if any).
+ logits = logits[:, :self.base_layer.vocab_size]
+ return logits
+
+ def construct(self, *args, **kwargs):
+ return type(self.base_layer).forward(self, *args, **kwargs)
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ # Special handling for the LogitsProcessor.
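+ # This wrapper is created explicitly by the LoRA model manager rather
+ # than via this generic replacement path (as in upstream vLLM), so the
+ # method always returns False.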
+ return False
+
+
+class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA):
+ """Implements RoPE-scaled embeddings with linear scaling for
+ multiple LoRA adapters with a specialized kernel.
+
+ Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding
+ which can handle multi lora adapters in a specialied kernel.
+ """
+
+ def __init__(self, base_layer: RotaryEmbedding) -> None:
+ super().__init__()
+ self.base_layer = base_layer
+
+ @property
+ def scaling_factors(self):
+ return self.base_layer.scaling_factors
+
+ @property
+ def rotary_dim(self):
+ return self.base_layer.rotary_dim
+
+ def create_lora_weights(
+ self,
+ max_loras: int,
+ lora_config: LoRAConfig,
+ model_config: Optional[PretrainedConfig] = None,
+ ) -> None:
+ scaling_factors = (list(lora_config.long_lora_scaling_factors)
+ if lora_config.long_lora_scaling_factors else [])
+ base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
+ self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
+ scaling_factors = sorted(
+ list(set([base_scaling_factor] + scaling_factors)))
+ self.base_layer = LinearScalingRotaryEmbedding(
+ self.base_layer.head_size,
+ self.base_layer.rotary_dim,
+ self.base_layer.max_position_embeddings,
+ self.base_layer.base,
+ self.base_layer.is_neox_style,
+ scaling_factors,
+ self.base_layer.dtype,
+ )
+
+ def reset_lora(self, index: int):
+ ...
+
+ def set_lora(
+ self,
+ index: int,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ embeddings_tensor: Optional[torch.Tensor],
+ bias: Optional[torch.Tensor] = None,
+ ):
+ ...
+
+ def construct(
+ self,
+ positions: torch.Tensor,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ ):
+ return self.base_layer(
+ positions,
+ query,
+ key,
+ offsets=self.punica_wrapper.long_lora_indices,
+ )
+
+ @property
+ def scaling_factor_to_offset(self) -> Dict[float, int]:
+ return self.base_layer.scaling_factor_to_offset
+
+ @classmethod
+ def can_replace_layer(
+ cls,
+ source_layer: nn.Module,
+ lora_config: LoRAConfig,
+ packed_modules_list: List,
+ model_config: Optional[PretrainedConfig],
+ ) -> bool:
+ """Returns True if the layer can be replaced by this LoRA layer."""
+ return (type(source_layer) is LinearScalingRotaryEmbedding
+ or type(source_layer) is RotaryEmbedding)
+
+ def extra_repr(self) -> str:
+ return self.base_layer.extra_repr()
diff --git a/vllm_mindspore/lora/models.py b/vllm_mindspore/lora/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..921978498d13cd5289e21e3c4dc60ff830ef248f
--- /dev/null
+++ b/vllm_mindspore/lora/models.py
@@ -0,0 +1,227 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+from typing import Dict, List, Optional, Union
+
+import safetensors.torch
+import torch
+from vllm.lora.lora import LoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
+from vllm.lora.utils import is_regex_target_modules, parse_fine_tuned_lora_name
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.utils import is_pin_memory_available
+
+from vllm_mindspore.lora.layers import BaseLayerWithLoRA
+
+_GLOBAL_LORA_ID = 0
+
+
+def get_lora_id():
+ global _GLOBAL_LORA_ID
+ _GLOBAL_LORA_ID += 1
+ return _GLOBAL_LORA_ID
+
+
+def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
+ assert isinstance(module, BaseLayerWithLoRA)
+ self.modules[module_name] = module
+
+
+@classmethod  # type: ignore
+def from_lora_tensors(
+ cls,
+ lora_model_id: int,
+ tensors: Dict[str, torch.Tensor],
+ peft_helper: PEFTHelper,
+ device: str = "cuda",
+ dtype: Optional[torch.dtype] = None,
+ embeddings: Optional[Dict[str, torch.Tensor]] = None,
+ target_embedding_padding: Optional[int] = None,
+ embedding_modules: Optional[Dict[str, str]] = None,
+ embedding_padding_modules: Optional[List[str]] = None,
+ weights_mapper: Optional[WeightsMapper] = None,
+):
+ """Create a LoRAModel from a dictionary of tensors."""
+ pin_memory = str(device) == "cpu" and is_pin_memory_available()
+ loras: Dict[str, LoRALayerWeights] = {}
+ for tensor_name, tensor in tensors.items():
+ module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
+ tensor_name, weights_mapper)
+ if module_name not in loras:
+ lora_embeddings_tensor = None
+ if embeddings:
+ assert embedding_modules is not None
+ embeddings_module = next(
+ (k for k in embedding_modules if k in module_name), None)
+ if embeddings_module:
+ lora_embeddings_tensor = embeddings[
+ embedding_modules[embeddings_module]]
+ if pin_memory:
+ lora_embeddings_tensor = (
+ lora_embeddings_tensor.pin_memory())
+ loras[module_name] = LoRALayerWeights.from_config(
+ module_name, peft_helper, lora_embeddings_tensor)
+
+ if is_bias:
+ # vllm-mindspore: drop the device argument when converting the tensor
+ bias = tensor.to(dtype=dtype).t()
+ if pin_memory:
+ bias = bias.pin_memory()
+ loras[module_name].bias = bias
+ elif is_lora_a:
+ loras[module_name].lora_a = tensor.to(dtype=dtype).t()
+ if pin_memory:
+ loras[module_name].lora_a = loras[
+ module_name].lora_a.pin_memory()
+ else:
+ loras[module_name].lora_b = tensor.to(dtype=dtype).t()
+ assert embedding_padding_modules is not None
+ if any(name in module_name for name in embedding_padding_modules
+ ) and target_embedding_padding is not None:
+ lora_b = loras[module_name].lora_b
+ assert target_embedding_padding >= lora_b.shape[1]
+ addition = target_embedding_padding - lora_b.shape[1]
+ loras[module_name].lora_b = torch.nn.functional.pad(
+ lora_b, (0, addition))
+ if pin_memory:
+ loras[module_name].lora_b = loras[
+ module_name].lora_b.pin_memory()
+
+ for lora in loras.values():
+ lora.optimize()
+
+ return cls(lora_model_id,
+ peft_helper.r,
+ loras,
+ scaling_factor=peft_helper.vllm_long_context_scaling_factor)
+
+
+@classmethod  # type: ignore
+def from_local_checkpoint(
+ cls,
+ lora_dir: str,
+ expected_lora_modules: List[str],
+ peft_helper: PEFTHelper,
+ *,
+ lora_model_id: Optional[int] = None,
+ device: str = "cuda",
+ dtype: Optional[torch.dtype] = None,
+ target_embedding_padding: Optional[int] = None,
+ embedding_modules: Optional[Dict[str, str]] = None,
+ embedding_padding_modules: Optional[List[str]] = None,
+ weights_mapper: Optional[WeightsMapper] = None,
+):
+ """Create a LoRAModel from a local checkpoint.
+
+ Args:
+ lora_dir: The local path that has lora data.
+ expected_lora_modules: Name of modules that are expected to be
+ replaced by lora.
+ peft_helper: Loaded lora configuration information.
+ lora_model_id: Lora model id. If not given, automatically set by
+ a global counter.
+ device: Device where the lora model is loaded.
+ dtype: dtype of the lora model weights.
+
+ Returns:
+ Loaded LoRA Model.
+ """
+ lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
+ lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
+ new_embeddings_tensor_path = os.path.join(lora_dir,
+ "new_embeddings.safetensors")
+ new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
+
+ unexpected_modules: List[Union[list[str], str]]
+ if os.path.isfile(lora_tensor_path):
+ tensors: Dict[str, torch.Tensor] = {}
+ # Find unexpected modules.
+ # Use safetensor key as a source of truth to find expected modules.
+ # in peft if you have target_modules A, B, C and C does not exist
+ # in the model it won’t error and model will be trained with A, B
+ # loraified. C won’t exist in the safetensor but it will exist in
+ # the target_modules of the adapter_config.json.
+ unexpected_modules = []
+ # vllm-mindspore safetensors open with np
+ with safetensors.safe_open(lora_tensor_path,
+ framework="np") as f: # type: ignore
+ for lora_module in f.keys(): # noqa
+ module_name, _, _ = parse_fine_tuned_lora_name(
+ lora_module, weights_mapper)
+ part_name = module_name.split(".")[-1]
+ if part_name not in expected_lora_modules:
+ unexpected_modules.append(module_name)
+ if unexpected_modules:
+ raise ValueError(
+ f"While loading {lora_dir}, expected"
+ f" target modules in {expected_lora_modules}"
+ f" but received {unexpected_modules}."
+ f" Please verify that the loaded LoRA module is correct")
+ # Load tensors if there are only expected modules.
+ for module in f.keys(): # noqa
+ # vllm-mindspore add numpy to tensor
+ tensors[module] = torch.Tensor(f.get_tensor(module))
+ elif os.path.isfile(lora_bin_file_path):
+ # When a bin file is provided, we rely on config to find unexpected
+ # modules.
+ unexpected_modules = []
+ target_modules = peft_helper.target_modules
+ if not isinstance(target_modules, list):
+ target_modules = [target_modules]
+ for module in target_modules:
+ # Compatible with more modules,
+ # such as: layers.11.self_attn.k_proj
+ part_name = module.split(".")[-1]
+ if part_name not in expected_lora_modules:
+ unexpected_modules.append(module)
+ # loaded lora's target modules must be a subset of
+ # expected_lora_modules. It is not reliable. See
+ # https://github.com/vllm-project/vllm/pull/5909. But there's no
+ # other better mechanism.
+ if unexpected_modules and not is_regex_target_modules(
+ peft_helper.target_modules, expected_lora_modules):
+ raise ValueError(
+ f"While loading {lora_dir}, expected"
+ f" target modules in {expected_lora_modules}"
+ f" but received {unexpected_modules}."
+ f" Please verify that the loaded LoRA module is correct")
+ tensors = torch.load(lora_bin_file_path, map_location=device)
+ else:
+ raise ValueError(f"{lora_dir} doesn't contain tensors")
+
+ embeddings = None
+ if os.path.isfile(new_embeddings_tensor_path):
+ embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
+ elif os.path.isfile(new_embeddings_bin_file_path):
+ embeddings = torch.load(new_embeddings_bin_file_path,
+ map_location=device,
+ weights_only=True)
+
+ return cls.from_lora_tensors(
+ lora_model_id=get_lora_id()
+ if lora_model_id is None else lora_model_id,
+ tensors=tensors,
+ peft_helper=peft_helper,
+ device=device,
+ dtype=dtype,
+ embeddings=embeddings,
+ target_embedding_padding=target_embedding_padding,
+ embedding_modules=embedding_modules,
+ embedding_padding_modules=embedding_padding_modules,
+ weights_mapper=weights_mapper)
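+
+
+# Illustrative usage (a sketch; assumes these functions are patched onto
+# vllm.lora.models.LoRAModel and that "/path/to/adapter" contains a PEFT-style
+# adapter_model.safetensors plus adapter_config.json):
+#
+#   lora = LoRAModel.from_local_checkpoint(
+#       "/path/to/adapter",
+#       expected_lora_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+#       peft_helper=peft_helper,  # loaded from the adapter config
+#       device="cpu")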
diff --git a/vllm_mindspore/lora/ops/__init__.py b/vllm_mindspore/lora/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/lora/ops/torch_ops/__init__.py b/vllm_mindspore/lora/ops/torch_ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/lora/ops/torch_ops/lora_ops.py b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..d085c34e264830fb9c087d9067a45d6624801102
--- /dev/null
+++ b/vllm_mindspore/lora/ops/torch_ops/lora_ops.py
@@ -0,0 +1,171 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+PyTorch-style LoRA ops (bgmv/sgmv shrink and expand) used by punica_npu.
+"""
+from mindspore import mint
+from mindspore.ops.auto_generate import grouped_matmul_v4
+
+def einsum_ms(inputs, selected_loras):
+ # mint.einsum("bi, boi -> bo", inputs, selected_loras)
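+ # Shape sketch: inputs is (b, i), selected_loras is (b, o, i); after the
+ # transpose we compute (b, 1, i) @ (b, i, o) -> (b, 1, o) and squeeze to
+ # (b, o), matching the einsum "bi, boi -> bo" above.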
+ selected_loras = mint.transpose(selected_loras, 1, 2)
+ outputs = mint.matmul(inputs.unsqueeze(1), selected_loras).squeeze(1)
+ return outputs
+
+def sort_lora_by_token_count(lora_indices_tensor, seq_len_tensor):
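+ # Example (illustrative): lora_indices_tensor=[0, 1, 0] with
+ # seq_len_tensor=[2, 4, 3] gives unique ids [0, 1] with token sums
+ # [5, 4], returned sorted by descending token count as ([0, 1], [5, 4]).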
+ unique_ids = mint.unique(lora_indices_tensor)
+ token_sums = []
+ for uid in unique_ids:
+ mask = (lora_indices_tensor == uid)
+ total_tokens = mint.sum(seq_len_tensor[mask])
+ token_sums.append(total_tokens)
+ token_sums_tensor = mint.stack(token_sums)
+ sorted_counts, sort_indices = mint.sort(token_sums_tensor, descending=True)
+ sorted_ids = unique_ids[sort_indices]
+ return sorted_ids, sorted_counts
+
+def sgmv_expand(inputs,
+ lora_b_weights,
+ output_tensor,
+ b_seq_start_loc,
+ seq_len_tensor,
+ lora_indices_tensor,
+ batches,
+ max_seq_length,
+ token_nums,
+ add_inputs = False):
+ exploded_indices = mint.repeat_interleave(lora_indices_tensor,
+ seq_len_tensor)
+
+ return bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
+ add_inputs)
+
+
+def bgmv_expand(inputs,
+ lora_b_weights,
+ output_tensor,
+ lora_indices_tensor,
+ add_inputs = True):
+ selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype)
+ inputs = inputs.astype(output_tensor.dtype)
+ if len(selected_loras.shape) == 4:
+ selected_loras = selected_loras.squeeze(1)
+ outputs = einsum_ms(inputs, selected_loras)
+ limit = output_tensor.shape[0]
+ if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
+ limit = 1
+ if add_inputs:
+ output_tensor[:, :outputs.shape[1]] += outputs[:limit, :]
+ else:
+ output_tensor[:, :outputs.shape[1]] = outputs[:limit, :]
+ return output_tensor
+
+
+def sgmv_shrink(
+ inputs,
+ lora_a_weights,
+ output_tensor,
+ b_seq_start_loc,
+ seq_len_tensor,
+ lora_indices_tensor,
+ batches,
+ max_seq_length,
+ token_nums,
+ scaling,
+):
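+ # Grouping sketch (an assumption about grouped_matmul_v4 semantics, not
+ # verified here): with group_list_type=1 each group_list entry is a
+ # per-group token count, e.g. group_list=[5, 3] means the first 5 rows of
+ # `inputs` use lora_a_weights[0] and the next 3 rows use lora_a_weights[1].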
+ group_list = seq_len_tensor
+ if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]):
+ sorted_ids, sorted_counts = sort_lora_by_token_count(lora_indices_tensor, seq_len_tensor)
+ group_list = sorted_counts
+ if lora_a_weights.shape[0] != group_list.shape[0]:
+ new_tensor = mint.zeros(lora_a_weights.shape[0], dtype=group_list.dtype)
+ new_tensor[:group_list.shape[0]] = group_list
+ group_list = new_tensor
+ if len(lora_a_weights.shape) == 4:
+ lora_a_weights = lora_a_weights.squeeze(1)
+ lora_a_weights = mint.transpose(lora_a_weights, 1, 2)
+ outputs = grouped_matmul_v4([inputs], [lora_a_weights], group_list=group_list, split_item=3, group_type=0, group_list_type=1)
+ outputs = outputs[0]
+ output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
+ return output_tensor
+
+
+def bgmv_shrink(inputs,
+ lora_b_weights,
+ output_tensor,
+ lora_indices_tensor,
+ scaling = 1.0):
+ selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype)
+ inputs = inputs.astype(output_tensor.dtype)
+ if len(selected_loras.shape) == 4:
+ selected_loras = selected_loras.squeeze(1)
+ outputs = einsum_ms(inputs, selected_loras)
+ output_tensor[:, :outputs.shape[1]] = scaling * outputs[:]
+ return output_tensor
+
+
+def sgmv_expand_slice(inputs,
+ lora_b_weights,
+ output_tensor,
+ b_seq_start_loc,
+ seq_len_tensor,
+ lora_indices_tensor,
+ batches,
+ max_seq_length,
+ token_nums,
+ slice_offset,
+ slice_size,
+ add_inputs = False):
+ group_list = seq_len_tensor
+ if (lora_indices_tensor.unique().shape[0] != lora_indices_tensor.shape[0]):
+ sorted_ids, sorted_counts = sort_lora_by_token_count(lora_indices_tensor, seq_len_tensor)
+ group_list = sorted_counts
+ if lora_b_weights.shape[0] != group_list.shape[0]:
+ new_tensor = mint.zeros(lora_b_weights.shape[0], dtype=group_list.dtype)
+ new_tensor[:group_list.shape[0]] = group_list
+ group_list = new_tensor
+ if len(lora_b_weights.shape) == 4:
+ lora_b_weights = lora_b_weights.squeeze(1)
+ lora_b_weights = mint.transpose(lora_b_weights, 1, 2)
+ inputs = inputs.astype(output_tensor.dtype)
+ outputs = grouped_matmul_v4([inputs], [lora_b_weights], group_list=group_list, split_item=3, group_type=0, group_list_type=1)
+ outputs = outputs[0]
+ if add_inputs:
+ output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:]
+ else:
+ output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:]
+ return output_tensor
+
+
+def bgmv_expand_slice(inputs,
+ lora_b_weights,
+ output_tensor,
+ lora_indices_tensor,
+ slice_offset,
+ slice_size,
+ add_inputs = True):
+ selected_loras = lora_b_weights[lora_indices_tensor].astype(output_tensor.dtype)
+ inputs = inputs.astype(output_tensor.dtype)
+ if len(selected_loras.shape) == 4:
+ selected_loras = selected_loras.squeeze(1)
+ outputs = einsum_ms(inputs, selected_loras)
+ if add_inputs:
+ output_tensor[:, slice_offset:slice_offset + slice_size] += outputs[:]
+ else:
+ output_tensor[:, slice_offset:slice_offset + slice_size] = outputs[:]
+ return output_tensor
\ No newline at end of file
diff --git a/vllm_mindspore/lora/punica_wrapper/__init__.py b/vllm_mindspore/lora/punica_wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/lora/punica_wrapper/punica_npu.py b/vllm_mindspore/lora/punica_wrapper/punica_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b41b15052904225f372e1c73503c898f765359
--- /dev/null
+++ b/vllm_mindspore/lora/punica_wrapper/punica_npu.py
@@ -0,0 +1,357 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Adapted from https://github.com/vllm-project/vllm-ascend/blob/v0.7.3/vllm_ascend/lora/punica_wrapper/punica_npu.py
+"""
+from typing import Callable
+
+from mindspore import mint
+from mindspore.common import dtype as mstype
+from vllm_mindspore.lora.ops.torch_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
+ bgmv_shrink, sgmv_expand,
+ sgmv_expand_slice, sgmv_shrink)
+from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
+
+
+# The platforms that are compatible with the PyTorch-native implementation can
+# inherit this class
+class PunicaWrapperNPU(PunicaWrapperBase):
+ """
+ PunicaWrapperNPU is designed to manage and provide metadata for the punica
+ kernel. Its main job is to maintain the state information for
+ multi-LoRA and to provide the interface to the PyTorch punica ops.
+ """
+
+ def __init__(self, max_num_batched_tokens, max_batches, device, **kwargs):
+ PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+ device)
+
+ def _shrink_prefill(
+ self,
+ y,
+ x,
+ w_t_all,
+ scale,
+ ):
+ sgmv_shrink(
+ x,
+ w_t_all,
+ y,
+ *self.prefill_metadata,
+ scale,
+ )
+
+ def _shrink_decode(
+ self,
+ y,
+ x,
+ w_t_all,
+ scale,
+ ):
+ bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
+
+ def _expand_prefill(
+ self,
+ y,
+ x,
+ w_t_all,
+ add_inputs,
+ ):
+ sgmv_expand(
+ x,
+ w_t_all,
+ y,
+ *self.prefill_metadata,
+ add_inputs,
+ )
+
+ def _expand_decode(
+ self,
+ y,
+ x,
+ w_t_all,
+ add_inputs,
+ ):
+ bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
+
+ def _expand_slice_prefill(
+ self,
+ y,
+ x,
+ w_t_all,
+ y_offset,
+ y_slice_size,
+ add_inputs,
+ ):
+ sgmv_expand_slice(
+ x,
+ w_t_all,
+ y,
+ *self.prefill_metadata,
+ y_offset,
+ y_slice_size,
+ add_inputs,
+ )
+
+ def _expand_slice_decode(
+ self,
+ y,
+ x,
+ w_t_all,
+ y_offset,
+ y_slice_size,
+ add_inputs,
+ ):
+ bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
+ y_slice_size, add_inputs)
+
+ def _apply_expand(
+ self,
+ y,
+ x,
+ w_t_all,
+ y_offset,
+ y_slice_size,
+ add_inputs,
+ ):
+ """
+ Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
+ computation, which is suitable for the GEMM of lora_b.
+ """
+
+ expand_slice_fun: Callable = (self._expand_slice_prefill
+ if self.is_prefill else
+ self._expand_slice_decode)
+ expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
+
+ def _apply_shrink(self, y, x, w_t_all, scale):
+ """
+ Perform the `y += x @ w_t_all` computation, which is suitable for the
+ GEMM of lora_a.
+ When `is_prefill` is True, we are in the prefill stage and
+ `_shrink_prefill` is called; otherwise we are in the decode stage and
+ `_shrink_decode` is called.
+ """
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ shrink_fun: Callable = (self._shrink_prefill
+ if self.is_prefill else self._shrink_decode)
+ shrink_fun(y, x, w_t_all, scale)
+ y.view_as(y_org)
+
+ def add_shrink(self, y, x, lora_a_stacked, scale, **kwargs):
+ """
+ Performs GEMM for multiple slices of lora_a.
+ When `is_prefill` is True, we are in the prefill stage and
+ `_shrink_prefill` is called; otherwise we are in the decode stage and
+ `_shrink_decode` is called.
+
+ Semantics:
+ for i in range(len(lora_a_stacked)):
+ y[i] += (x @ lora_a_stacked[i]) * scale
+
+ Args:
+ y (Union[Tuple[ms.Tensor, ...], ms.Tensor]): Output tensors
+ x (ms.Tensor): Input tensor
+ lora_a_stacked (Tuple[ms.Tensor, ...]): lora_a's weights
+ scale (float): Scaling factor for the operation
+ """
+
+ x = x.view(-1, x.shape[-1])
+ # TODO fuse these kernels
+ for slice_idx in range(len(lora_a_stacked)):
+ self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
+ scale)
+
+ def add_expand(self,
+ y,
+ x,
+ lora_b_stacked,
+ lora_bias_stacked,
+ output_slices,
+ offset_start=0,
+ add_inputs=True,
+ **kwargs) -> None:
+ """
+ Performs GEMM and bias addition for multiple slices of lora_b.
+
+ Semantics:
+ for i in range(len(lora_b_stacked)):
+ slice = output_slices[i]
+ y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
+ lora_bias_stacked[i]
+ offset += slice
+
+ Args:
+ y (ms.Tensor): Output tensor.
+ x (Union[Tuple[ms.Tensor, ...], ms.Tensor]): Input tensors
+ lora_b_stacked (Tuple[ms.Tensor, ...]): lora_b's weight
+ lora_bias_stacked (Optional[Tuple[ms.Tensor, ...]]):
+ bias's weight
+ output_slices (Tuple[int, ...]): Every slice's size
+ add_inputs (bool): Defaults to True.
+ """
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ offset_left = offset_start
+ if lora_bias_stacked is not None:
+ self._apply_bias(self.token_lora_indices, y, output_slices,
+ lora_bias_stacked)
+ for slice_idx in range(len(lora_b_stacked)):
+ self._apply_expand(
+ y,
+ x[slice_idx],
+ lora_b_stacked[slice_idx],
+ offset_left,
+ output_slices[slice_idx],
+ add_inputs=add_inputs,
+ )
+ offset_left += output_slices[slice_idx]
+ y.view_as(y_org)
+
+ def add_lora_embedding(self,
+ y,
+ x,
+ lora_b_stacked,
+ add_inputs=True,
+ **kwargs) -> None:
+ """
+ Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+ Semantics:
+ y += x @ lora_b_stacked
+
+ Args:
+ y (ms.Tensor): Output tensor.
+ x (ms.Tensor): Input tensor.
+ lora_b_stacked (ms.Tensor): lora_b's weights.
+ add_inputs (bool): Defaults to True.
+ """
+ # No LoRA request, so return directly.
+ if self.no_lora:
+ return
+ # The embedding layer only needs the expand op.
+ expand_fun: Callable = (self._expand_prefill
+ if self.is_prefill else self._expand_decode)
+ expand_fun(y, x, lora_b_stacked, add_inputs)
+
+ def add_lora_linear(self,
+ y,
+ x,
+ lora_a_stacked,
+ lora_b_stacked,
+ lora_bias_stacked,
+ scale,
+ output_slices,
+ *,
+ buffer=None,
+ **kwargs) -> None:
+ """
+ Applicable to linear-related lora.
+
+ Semantics:
+ for i in range(len(lora_a_stacked)):
+ y[i] += (
+ x[i].unsqueeze(0)
+ @ lora_a_stacked[indices[i], layer_idx, :, :]
+ @ lora_b_stacked[indices[i], layer_idx, :, :]
+ * scale
+ ).squeeze(0)+lora_bias_stacked[i]
+
+ Args:
+ y (ms.Tensor): Output tensor. Will be changed in-place.
+ x (ms.Tensor): Input tensor
+ lora_a_stacked (Tuple[ms.Tensor, ...]): lora_a's weight.
+ lora_b_stacked (Tuple[ms.Tensor, ...]): lora_b's weight.
+ lora_bias_stacked (Optional[Tuple[ms.Tensor, ...]]): lora's bias.
+ scale (float): Scaling factor.
+ output_slices (Tuple[int, ...]): Every slice's size.
+ buffer (Optional[Tuple[ms.Tensor, ...]]): Defaults to None.
+ """
+ # No LoRA request, so return directly.
+ if self.no_lora:
+ return
+ x = x.reshape(-1, x.shape[-1])
+ assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+ if lora_bias_stacked is not None:
+ assert len(lora_bias_stacked) == len(output_slices)
+ y = self._apply_bias(self.token_lora_indices, y, output_slices,
+ lora_bias_stacked)
+
+ if buffer is None:
+ r = lora_b_stacked[0].shape[-1]
+ # We set the buffer to be float32 by default, consistent with the
+ # triton op
+ buffer = tuple(
+ mint.zeros((x.shape[0], r), dtype=mstype.float32)
+ for _ in range(len(output_slices)))
+ self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
+ self.add_expand(y,
+ buffer,
+ lora_b_stacked,
+ None,
+ output_slices,
+ add_inputs=True,
+ **kwargs)
+
+ def add_lora_logits(self,
+ y,
+ x,
+ lora_a_stacked,
+ lora_b_stacked,
+ scale,
+ *,
+ buffer=None,
+ **kwargs) -> None:
+ """
+ Applies lora specifically for LogitsProcessorWithLoRA.
+
+ Semantics:
+ buffer = (x @ lora_a_stacked) * scale
+ y += buffer @ lora_b_stacked
+
+ Args:
+ y (ms.Tensor): Output tensor.
+ x (ms.Tensor): Input tensor.
+ lora_a_stacked (ms.Tensor): lora_a's weights.
+ lora_b_stacked (ms.Tensor): lora_b's weights.
+ scale (float): Scaling factor.
+ buffer (Optional[ms.Tensor]): Defaults to None.
+ """
+ # No LoRA request, so return directly.
+ if self.no_lora:
+ return
+ y_org = y
+ y = y.view(-1, y.shape[-1])
+ x = x.view(-1, x.shape[-1])
+ r = lora_b_stacked.shape[-1]
+ if buffer is None:
+ # We set the buffer to be float32 by default, consistent with the
+ # triton op
+ buffer = mint.zeros((x.shape[0], r), dtype=mstype.float32)
+ # LogitsProcessorWithLoRA always using bgmv.
+ bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+ bgmv_expand(buffer,
+ lora_b_stacked,
+ y,
+ self.sampler_indices,
+ add_inputs=True)
+ y.view_as(y_org)
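+
+
+# Illustrative wiring (a sketch; names and values are hypothetical): the LoRA
+# model manager is expected to create one wrapper per worker, e.g.
+#   punica = PunicaWrapperNPU(max_num_batched_tokens=8192, max_batches=256,
+#                             device="Ascend")
+# after which the *WithLoRA layers call punica.add_lora_linear(...),
+# punica.add_lora_embedding(...) or punica.add_lora_logits(...) from apply().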
diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0084e607b09c0f5801f9cbeee57f3beb11066442
--- /dev/null
+++ b/vllm_mindspore/lora/utils.py
@@ -0,0 +1,47 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from typing import Set, Type
+
+from vllm.lora.fully_sharded_layers import (
+ ColumnParallelLinearWithShardedLoRA,
+ MergedColumnParallelLinearWithShardedLoRA,
+ MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
+ RowParallelLinearWithShardedLoRA)
+
+from vllm_mindspore.lora.layers import (
+ BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+ LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA,
+ MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA,
+ QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA,
+ VocabParallelEmbeddingWithLoRA)
+
+_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
+ VocabParallelEmbeddingWithLoRA,
+ ColumnParallelLinearWithLoRA,
+ MergedColumnParallelLinearWithLoRA,
+ QKVParallelLinearWithLoRA,
+ MergedQKVParallelLinearWithLoRA,
+ RowParallelLinearWithLoRA,
+ LogitsProcessorWithLoRA,
+ ColumnParallelLinearWithShardedLoRA,
+ QKVParallelLinearWithShardedLoRA,
+ MergedColumnParallelLinearWithShardedLoRA,
+ MergedQKVParallelLinearWithShardedLoRA,
+ RowParallelLinearWithShardedLoRA,
+ LinearScalingRotaryEmbeddingWithLoRA,
+}
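+
+# Note: a from_layer()-style helper (as in upstream vLLM's lora/utils.py) is
+# expected to iterate this registry and instantiate the first class whose
+# can_replace_layer() returns True for a given source layer.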
diff --git a/vllm_mindspore/model_executor/layers/layernorm.py b/vllm_mindspore/model_executor/layers/layernorm.py
index db156c0cc0dd0605fef0e48d0d79a1a489b9a4ea..3e0251cbde68f1fa587414ef69577c79d6961db5 100644
--- a/vllm_mindspore/model_executor/layers/layernorm.py
+++ b/vllm_mindspore/model_executor/layers/layernorm.py
@@ -21,23 +21,26 @@ from typing import Optional, Tuple, Union, Any
from mindspore import Parameter, Tensor, mint, ops
from mindspore.common import dtype as mstype
from mindspore.common.dtype import typing
+from mindspore import nn
-from vllm_mindspore.model_executor.custom_op import CustomOp
+from vllm.config import get_current_vllm_config
-class RMSNorm(CustomOp):
+class RMSNorm(nn.Cell):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
- params_dtype: Optional[Any] = mstype.float16,
+ params_dtype: Optional[Any] = None,
) -> None:
super().__init__()
+ if params_dtype is None:
+ params_dtype = get_current_vllm_config().model_config.dtype
self.weight = Parameter(mint.ones(hidden_size, dtype=params_dtype))
self.rms_norm = ops.RmsNorm(eps)
- def forward_native(
+ def construct(
self,
x: Tensor,
residual: Optional[Tensor] = None
diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py
index 572f0e345920434a56d38a6540d339ead5aa5369..e0851149270d108293311f2103811fa6a558fbe3 100644
--- a/vllm_mindspore/model_executor/layers/linear.py
+++ b/vllm_mindspore/model_executor/layers/linear.py
@@ -32,6 +32,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
+from vllm.config import get_current_vllm_config
from vllm_mindspore.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -39,7 +40,6 @@ from vllm_mindspore.model_executor.layers.quantization.base_config import (
from vllm_mindspore.model_executor.utils import set_weight_attrs
from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion
-
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
@@ -170,8 +170,7 @@ class LinearBase(ms.nn.Cell):
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
- # params_dtype = torch.get_default_dtype()
- params_dtype = ms.float16
+ params_dtype = get_current_vllm_config().model_config.dtype
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
@@ -236,7 +235,7 @@ class ColumnParallelLinear(LinearBase):
)
if bias:
self.bias = Parameter(
- mint.zeros(self.output_size_per_partition, dtype=params_dtype)
+ mint.zeros(self.output_size_per_partition, dtype=self.params_dtype)
)
set_weight_attrs(
self.bias,
@@ -545,7 +544,7 @@ class RowParallelLinear(LinearBase):
)
if bias:
- self.bias = Parameter(mint.zeros(self.output_size, dtype=params_dtype))
+ self.bias = Parameter(mint.zeros(self.output_size, dtype=self.params_dtype))
set_weight_attrs(
self.bias,
{
diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py
index 32b02fb7e32fd5806b153ad5128d387f46b811f1..5d6036943d123aa952bfbf6673c75553e1427066 100644
--- a/vllm_mindspore/model_executor/layers/logits_processor.py
+++ b/vllm_mindspore/model_executor/layers/logits_processor.py
@@ -32,7 +32,7 @@ from vllm.distributed import (
from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py
index c9dfe254dc8d307664f4405fe48bf8fab0ecb33a..7470233475254ccccdf677d5627a6f3bd59f6408 100644
--- a/vllm_mindspore/model_executor/layers/rotary_embedding.py
+++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# encoding: utf-8
+# type: ignore
+# isort:skip_file
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -16,16 +17,18 @@
# limitations under the License.
# ============================================================================
+import math
+import numpy as np
+
from typing import Any, Dict, List, Optional, Tuple, Union
-import numpy as np
import mindspore
-from mindspore import Tensor, mint, ops
+from mindspore import Tensor, mint, nn, ops
from mindspore.common import dtype as mstype
+from mindspore.ops.auto_generate.gen_ops_prim import SliceExt
from transformers import PretrainedConfig
-
-from vllm_mindspore.model_executor.custom_op import CustomOp
+from vllm.config import get_current_vllm_config
def _apply_rotary_emb(
@@ -57,7 +60,8 @@ def _apply_rotary_emb(
return mint.stack((o1, o2), dim=-1).flatten(-2)
-class RotaryEmbedding(CustomOp):
+class RotaryEmbedding(nn.Cell):
+
def __init__(
self,
head_size: int,
@@ -86,10 +90,8 @@ class RotaryEmbedding(CustomOp):
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
- inv_freq = 1.0 / (
- base
- ** (mint.arange(0, self.rotary_dim, 2, dtype=mstype.float32) / self.rotary_dim)
- )
+ inv_freq = 1.0 / (base**(mint.arange(
+ 0, self.rotary_dim, 2, dtype=mstype.float32) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> Tensor:
@@ -104,7 +106,7 @@ class RotaryEmbedding(CustomOp):
cache = mint.cat((cos, sin), dim=-1)
return cache
- def forward_native(
+ def construct(
self,
positions: Tensor,
query: Tensor,
@@ -121,21 +123,22 @@ class RotaryEmbedding(CustomOp):
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
- query_rot = query[..., : self.rotary_dim]
+ query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = mint.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
- key_rot = key[..., : self.rotary_dim]
+ key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = mint.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
-class InferRotaryEmbedding(CustomOp):
+class InferRotaryEmbedding(nn.Cell):
+
def __init__(
self,
head_size: int,
@@ -145,23 +148,44 @@ class InferRotaryEmbedding(CustomOp):
is_neox_style: bool,
dtype,
) -> None:
+ if not is_neox_style:
+ raise NotImplementedError(
+ "InferRotaryEmbedding only support Neox-style rotary embeddings."
+ )
super().__init__()
- freqs_base = np.arange(0, rotary_dim, 2)[: (rotary_dim // 2)].astype(np.float32) # (head_dim // 2, )
- freqs = 1.0 / (base ** (freqs_base / rotary_dim)) # (head_dim // 2, )
- mscale = 1.0
- t = np.arange(0, max_position_embeddings, 1).astype(np.float32)
+ self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2)
+ self.gather = ops.Gather()
+ self.head_size = head_size
+ self.rotary_dim = rotary_dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.is_neox_style = is_neox_style
+ self.dtype = dtype
+ self.freqs_cos, self.freqs_sin = self._compute_cos_sin_cache()
- self.freqs = Tensor(freqs.reshape(1, 1, 1, -1), dtype=dtype)
+ def _compute_inv_freq(self, base: Union[int, float]) -> Tensor:
+ """
+ Compute the inverse frequency with numpy.
+ Numpy process is faster during initialization.
+ """
+ freqs_base = np.arange(0, self.rotary_dim,
+ 2).astype(np.float32) # (head_dim // 2, )
+ freqs = 1.0 / (base**(freqs_base / self.rotary_dim)
+ ) # (head_dim // 2, )
+ return freqs
+
+ def _compute_cos_sin_cache(self) -> Tuple[Tensor, Tensor]:
+ freqs = self._compute_inv_freq(self.base)
+ t = np.arange(0, self.max_position_embeddings, 1).astype(np.float32)
freqs = np.outer(t, freqs) # (max_position_embedding, head_dim // 2)
emb = np.concatenate((freqs, freqs), axis=-1)
- freqs_cos = np.cos(emb) * mscale # (seq_len, head_dim)
- freqs_sin = np.sin(emb) * mscale # (seq_len, head_dim)
- self.freqs_cos = Tensor(freqs_cos, dtype=dtype)
- self.freqs_sin = Tensor(freqs_sin, dtype=dtype)
- self.rotary_embedding_op = ops.ApplyRotaryPosEmb(2)
- self.gather = ops.Gather()
+ freqs_cos = np.cos(emb) # (seq_len, head_dim)
+ freqs_sin = np.sin(emb) # (seq_len, head_dim)
+ freqs_cos = Tensor(freqs_cos, dtype=self.dtype)
+ freqs_sin = Tensor(freqs_sin, dtype=self.dtype)
+ return freqs_cos, freqs_sin
- def forward_native(
+ def construct(
self,
positions: Tensor,
query: Tensor,
@@ -170,12 +194,62 @@ class InferRotaryEmbedding(CustomOp):
is_prefill: bool,
offsets: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
+ query = query.contiguous()
+ key = key.contiguous()
if is_prefill:
- return self.rotary_embedding_op(query, key, self.freqs_cos, self.freqs_sin, batch_valid_length)
+ return self.rotary_embedding_op(query, key, self.freqs_cos,
+ self.freqs_sin, batch_valid_length)
freqs_cos = self.gather(self.freqs_cos, positions, 0)
freqs_sin = self.gather(self.freqs_sin, positions, 0)
- return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length)
+ return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin,
+ batch_valid_length)
+
+
+class InferLlama3RotaryEmbedding(InferRotaryEmbedding):
+
+ def __init__(
+ self,
+ head_size: int,
+ rotary_dim: int,
+ max_position_embeddings: int,
+ base: int,
+ is_neox_style: bool,
+ dtype,
+ scaling_factor: float,
+ low_freq_factor: float,
+ high_freq_factor: float,
+ orig_max_position: int,
+ ) -> None:
+ self.scaling_factor = scaling_factor
+ self.low_freq_factor = low_freq_factor
+ self.high_freq_factor = high_freq_factor
+ self.orig_max_position = orig_max_position
+ super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+ is_neox_style, dtype)
+
+ def _compute_inv_freq(self, base: Union[int, float]) -> np.ndarray:
+ inv_freqs = super()._compute_inv_freq(base)
+ low_freq_wavelen = self.orig_max_position / self.low_freq_factor
+ high_freq_wavelen = self.orig_max_position / self.high_freq_factor
+
+ wave_len = 2 * math.pi / inv_freqs
+ if self.low_freq_factor != self.high_freq_factor:
+ smooth = (self.orig_max_position / wave_len - self.low_freq_factor
+ ) / (self.high_freq_factor - self.low_freq_factor)
+ else:
+ smooth = 0
+ new_freqs = np.where(
+ wave_len < high_freq_wavelen,
+ inv_freqs,
+ np.where(
+ wave_len > low_freq_wavelen,
+ inv_freqs / self.scaling_factor,
+ (1 - smooth) * inv_freqs / self.scaling_factor +
+ smooth * inv_freqs,
+ ),
+ )
+ return new_freqs
class MRotaryEmbedding(RotaryEmbedding):
@@ -202,7 +276,7 @@ class MRotaryEmbedding(RotaryEmbedding):
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
- def forward_native(
+ def construct(
self,
positions: mindspore.Tensor,
query: mindspore.Tensor,
@@ -232,9 +306,9 @@ class MRotaryEmbedding(RotaryEmbedding):
cos_l = ops.split(cos, self.mrope_section, axis=-1)
sin_l = ops.split(sin, self.mrope_section, axis=-1)
cos, sin = (), ()
- for i in range(len(self.mrope_section)):
- cos += (cos_l[i][i],)
- sin += (sin_l[i][i],)
+ for i in range(len(self.mrope_section)): # type: ignore[arg-type]
+ cos += (cos_l[i][i], )
+ sin += (sin_l[i][i], )
cos = ops.cat(cos, axis=-1)
sin = ops.cat(sin, axis=-1)
@@ -357,7 +431,8 @@ class MRotaryEmbedding(RotaryEmbedding):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(
- ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() + st_idx)
+ ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() +
+ st_idx)
t_index = (ops.arange(llm_grid_t).view(-1, 1).broadcast_to(
(-1, llm_grid_h * llm_grid_w)) * video_second_per_grid_t *
@@ -366,7 +441,7 @@ class MRotaryEmbedding(RotaryEmbedding):
(llm_grid_t, -1, llm_grid_w)).flatten().int()
w_index = ops.arange(llm_grid_w).view(1, 1, -1).broadcast_to(
(llm_grid_t, llm_grid_h, -1)).flatten().int()
-
+
llm_pos_ids_list.append(
ops.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
@@ -376,7 +451,8 @@ class MRotaryEmbedding(RotaryEmbedding):
llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
- ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() + st_idx)
+ ops.arange(text_len).view(1, -1).broadcast_to((3, -1)).int() +
+ st_idx)
llm_positions = ops.cat(llm_pos_ids_list, axis=1).view(3, -1)
mrope_position_delta = (llm_positions.max() + 1 -
@@ -403,9 +479,9 @@ class MRotaryEmbedding(RotaryEmbedding):
context_len: int,
seq_len: int,
) -> mindspore.Tensor:
- return ops.arange(
- mrope_position_delta + context_len,
- mrope_position_delta + seq_len,
+ return mint.arange(
+ int(mrope_position_delta + context_len),
+ int(mrope_position_delta + seq_len),
).broadcast_to((3, -1))
@@ -435,7 +511,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
- self.is_neox_style = is_neox_style
+ self.is_neox_style = is_neox_style # type: ignore[assignment]
self.dtype = dtype
super().__init__(head_size, rotary_dim, self.cache_max_position_num,
base, is_neox_style, dtype)
@@ -444,7 +520,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
if self.mrope_section:
assert sum(self.mrope_section) == rotary_dim // 2
- def forward_native(
+ def construct( # type: ignore[override]
self,
positions: mindspore.Tensor,
query: mindspore.Tensor,
@@ -460,47 +536,60 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
+ half_rotary_dim = self.rotary_dim // 2
# prefill
if is_prefill:
num_tokens = positions.shape[-1]
cos, sin = self.freqs_cos[positions], self.freqs_sin[positions]
- cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2]
+ cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1)
+ sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1)
if positions.ndim == 2:
- cos_l = ops.split(cos, self.mrope_section, axis=-1)
- sin_l = ops.split(sin, self.mrope_section, axis=-1)
+ cos_l = mint.split(cos, self.mrope_section, dim=-1)
+ sin_l = mint.split(sin, self.mrope_section, dim=-1)
cos, sin = (), ()
- for i in range(len(self.mrope_section)):
- cos += (cos_l[i][i],)
- sin += (sin_l[i][i],)
+ for i in range(len(
+ self.mrope_section)): # type: ignore[arg-type]
+ cos_l_select = mint.index_select(cos_l[i], 0,
+ Tensor([i])).squeeze(0)
+ cos += (cos_l_select, )
+ sin_l_select = mint.index_select(sin_l[i], 0,
+ Tensor([i])).squeeze(0)
+ sin += (sin_l_select, )
cos = ops.cat(cos, axis=-1)
sin = ops.cat(sin, axis=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
- query_rot = query[..., :self.rotary_dim]
- query_pass = query[..., self.rotary_dim:]
- query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+ query_rot = SliceExt()(query, -1, 0, self.rotary_dim, 1)
+ query_pass = SliceExt()(query, -1, self.rotary_dim,
+ query_shape[-1], 1)
+ query_rot = _apply_rotary_emb(query_rot, cos, sin,
+ self.is_neox_style)
query = ops.cat((query_rot, query_pass), axis=-1).view(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
- key_rot = key[..., :self.rotary_dim]
- key_pass = key[..., self.rotary_dim:]
+ key_rot = SliceExt()(key, -1, 0, self.rotary_dim, 1)
+ key_pass = SliceExt()(key, -1, self.rotary_dim, key_shape[-1], 1)
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = ops.cat((key_rot, key_pass), axis=-1).view(key_shape)
return query, key
# decode
- if positions.ndim == 2 and positions.shape[0] == len(self.mrope_section):
- num_tokens = positions.shape[-1]
+ if positions.ndim == 2:
cos, sin = self.freqs_cos[positions], self.freqs_sin[positions]
- cos, sin = cos[..., :self.rotary_dim//2], sin[..., :self.rotary_dim//2]
- cos_l = ops.split(cos, self.mrope_section, axis=-1)
- sin_l = ops.split(sin, self.mrope_section, axis=-1)
+ cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1)
+ sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1)
+ cos_l = mint.split(cos, self.mrope_section, dim=-1)
+ sin_l = mint.split(sin, self.mrope_section, dim=-1)
cos, sin = (), ()
- for i in range(len(self.mrope_section)):
- cos += (cos_l[i][i],)
- sin += (sin_l[i][i],)
+ for i in range(len(self.mrope_section)): # type: ignore[arg-type]
+ cos_l_select = mint.index_select(cos_l[i], 0,
+ Tensor([i])).squeeze(0)
+ cos += (cos_l_select, )
+ sin_l_select = mint.index_select(sin_l[i], 0,
+ Tensor([i])).squeeze(0)
+ sin += (sin_l_select, )
cos = ops.cat(cos, axis=-1)
sin = ops.cat(sin, axis=-1)
freqs_cos = ops.cat([cos, cos], axis=-1).squeeze(1)
@@ -510,10 +599,11 @@ class InferMRotaryEmbedding(InferRotaryEmbedding):
freqs_cos = self.freqs_cos.index_select(0, positions)
freqs_sin = self.freqs_sin.index_select(0, positions)
- return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, batch_valid_length)
+ return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin,
+ batch_valid_length)
-_ROPE_DICT: Dict[Tuple, InferRotaryEmbedding] = {}
+_ROPE_DICT: Dict[Tuple, Union[InferRotaryEmbedding, RotaryEmbedding]] = {}
def get_rope(
@@ -523,9 +613,12 @@ def get_rope(
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
- dtype: Optional[Any] = mstype.float16,
+ dtype: Optional[Any] = None,
partial_rotary_factor: float = 1.0,
-) -> InferRotaryEmbedding:
+):
+ if dtype is None:
+ dtype = get_current_vllm_config().model_config.dtype
+
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
@@ -543,7 +636,8 @@ def get_rope(
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
- rotary_emb = InferRotaryEmbedding(
+ cls = InferRotaryEmbedding if is_neox_style else RotaryEmbedding
+ rotary_emb = cls(
head_size,
rotary_dim,
max_position,
@@ -555,7 +649,15 @@ def get_rope(
scaling_type = rope_scaling["rope_type"]
if scaling_type == "llama3":
- raise NotImplementedError
+ scaling_factor = rope_scaling["factor"]
+ low_freq_factor = rope_scaling["low_freq_factor"]
+ high_freq_factor = rope_scaling["high_freq_factor"]
+ original_max_position = rope_scaling[
+ "original_max_position_embeddings"]
+ rotary_emb = InferLlama3RotaryEmbedding(
+ head_size, rotary_dim, max_position, base, is_neox_style,
+ dtype, scaling_factor, low_freq_factor, high_freq_factor,
+ original_max_position)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = InferMRotaryEmbedding(
@@ -572,5 +674,5 @@ def get_rope(
else:
raise NotImplementedError
- _ROPE_DICT[key] = rotary_emb
+ _ROPE_DICT[key] = rotary_emb # type: ignore[assignment]
return rotary_emb
diff --git a/vllm_mindspore/model_executor/layers/sampler.py b/vllm_mindspore/model_executor/layers/sampler.py
index edfe62526034bef3d8b60ba8488047628c11288c..db0ede03020bcd9d767add147010554f9c95ad47 100644
--- a/vllm_mindspore/model_executor/layers/sampler.py
+++ b/vllm_mindspore/model_executor/layers/sampler.py
@@ -37,11 +37,11 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm_mindspore.model_executor.layers.utils import apply_penalties
-from vllm_mindspore.model_executor.sampling_metadata import (
+from vllm.model_executor.sampling_metadata import (
SamplingMetadata,
- SamplingTensors,
- SequenceGroupToSample,
+ SamplingTensors
)
+from vllm.model_executor.sampling_metadata import SequenceGroupToSample
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
raise RuntimeError("Donot support for mindspore now.")
@@ -447,7 +447,8 @@ def _apply_min_p(
"""
probs = torch.softmax(logits, dim=-1)
top_probs, _ = probs.max(dim=-1, keepdim=True)
- scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs
+ # For MindSpore: unsqueeze_ will cause error, use unsqueeze instead
+ scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
tokens_to_remove = probs < scaled_min_p
logits = logits.masked_fill_(tokens_to_remove, -float("inf"))
diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
index e3407f516c2d3d7d34ff84c6bd3214262cb31182..768a8238f4d73bcecb2d7cd1a453d6de81e07e09 100644
--- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -20,19 +19,18 @@ from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
from mindspore import Parameter, Tensor, mint, nn, ops
-from mindspore.common import dtype as mstype
from mindspore.common.dtype import typing
+from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- tensor_model_parallel_all_reduce,)
+ get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
+from vllm_mindspore.distributed.communication_op import (
+ ReduceFromModelParallelRegion)
from vllm_mindspore.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, method_has_implemented_embedding)
from vllm_mindspore.model_executor.utils import set_weight_attrs
-from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion
-from mindspore import jit
DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -40,15 +38,13 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""
- def create_weights(self, layer: nn.Cell,
- input_size_per_partition: int,
+ def create_weights(self, layer: nn.Cell, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
- output_size: int, params_dtype,
- **extra_weight_attrs):
+ output_size: int, params_dtype, **extra_weight_attrs):
"""Create weights for embedding layer."""
- weight = Parameter(mint.zeros((sum(output_partition_sizes),
- input_size_per_partition),
- dtype=params_dtype),
+ weight = Parameter(mint.zeros(
+ (sum(output_partition_sizes), input_size_per_partition),
+ dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.insert_param_to_cell("weight", weight)
@@ -64,7 +60,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer: nn.Cell,
x: Tensor,
bias: Optional[Tensor] = None) -> Tensor:
- output_shape = x.shape[:-1] + (self.output_size_per_partition,)
+ output_shape = x.shape[:-1] + (self.output_size_per_partition, )
x = x.reshape(-1, self.input_size_per_partition)
x = self.matmul(x, layer.weight)
if bias is not None:
@@ -72,8 +68,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
x = x.reshape(output_shape)
return x
- def embedding(self, layer: nn.Cell,
- input_: Tensor) -> Tensor:
+ def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor:
return self.gather(layer.weight, input_, 0)
@@ -87,12 +82,15 @@ def get_masked_input_and_mask(
) -> Tuple[Tensor, Tensor]:
displaced_x = mint.sub(input_, org_vocab_start_index)
down_truncated_x = mint.nn.functional.relu(displaced_x)
- truncated_x = mint.minimum(down_truncated_x, (org_vocab_end_index - org_vocab_start_index - 1))
+ truncated_x = mint.minimum(
+ down_truncated_x, (org_vocab_end_index - org_vocab_start_index - 1))
org_vocab_mask = mint.eq(displaced_x, truncated_x)
displaced_x = mint.sub(input_, added_vocab_start_index)
down_truncated_x = mint.nn.functional.relu(displaced_x)
- truncated_x = mint.minimum(down_truncated_x, (added_vocab_end_index - added_vocab_start_index - 1))
+ truncated_x = mint.minimum(
+ down_truncated_x,
+ (added_vocab_end_index - added_vocab_start_index - 1))
added_vocab_mask = mint.eq(displaced_x, truncated_x)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
@@ -103,26 +101,29 @@ def get_masked_input_and_mask(
return input_, vocab_mask.expand_dims(-1)
-def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
+def pad_vocab_size(vocab_size: int,
+ pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(
- per_partition_vocab_size: int, rank: int, offset: int = 0
-) -> Sequence[int]:
+ per_partition_vocab_size: int,
+ rank: int,
+ offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f + offset, index_l + offset
-def vocab_range_from_global_vocab_size(
- global_vocab_size: int, rank: int, world_size: int, offset: int = 0
-) -> Sequence[int]:
+def vocab_range_from_global_vocab_size(global_vocab_size: int,
+ rank: int,
+ world_size: int,
+ offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
- return vocab_range_from_per_partition_vocab_size(
- per_partition_vocab_size, rank, offset=offset
- )
+ return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
+ rank,
+ offset=offset)
@dataclass
@@ -185,6 +186,7 @@ class VocabParallelEmbeddingShardIndices:
class VocabParallelEmbedding(nn.Cell):
+
def __init__(
self,
num_embeddings: int,
@@ -203,12 +205,11 @@ class VocabParallelEmbedding(nn.Cell):
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
num_added_embeddings = num_embeddings - self.org_vocab_size
- self.org_vocab_size_padded = pad_vocab_size(
- self.org_vocab_size, self.padding_size
- )
+ self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
+ self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
- self.org_vocab_size_padded + num_added_embeddings, self.padding_size
- )
+ self.org_vocab_size_padded + num_added_embeddings,
+ self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(
@@ -233,34 +234,28 @@ class VocabParallelEmbedding(nn.Cell):
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
- type(quant_method)
- )
+ type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
raise NotImplementedError(
f"The class {type(quant_method).__name__} must implement "
- "the 'embedding' method, see UnquantizedEmbeddingMethod."
- )
+ "the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.quant_method: QuantizeMethodBase = quant_method
if params_dtype is None:
- params_dtype = mstype.float16
+ params_dtype = get_current_vllm_config().model_config.dtype
# Divide the weight matrix along the vocaburaly dimension.
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
- self.num_embeddings_per_partition = divide(
- self.num_embeddings_padded, self.tp_size
- )
- assert (
- self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
- )
+ self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
+ self.tp_size)
+ assert (self.shard_indices.num_elements_padded ==
+ self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
- self.shard_indices.org_vocab_end_index
- - self.shard_indices.org_vocab_start_index
- )
+ self.shard_indices.org_vocab_end_index -
+ self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
- self.shard_indices.added_vocab_end_index
- - self.shard_indices.added_vocab_start_index
- )
+ self.shard_indices.added_vocab_end_index -
+ self.shard_indices.added_vocab_start_index)
self.quant_method.create_weights(
self,
@@ -288,17 +283,19 @@ class VocabParallelEmbedding(nn.Cell):
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
- vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
- )
+ vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
+ tp_size))
padded_added_vocab_start_index, padded_added_vocab_end_index = (
- vocab_range_from_global_vocab_size(
- num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
- )
- )
+ vocab_range_from_global_vocab_size(num_added_embeddings_padded,
+ tp_rank,
+ tp_size,
+ offset=org_vocab_size))
# remove padding
- org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
+ org_vocab_start_index = min(padded_org_vocab_start_index,
+ org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
- added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
+ added_vocab_start_index = min(padded_added_vocab_start_index,
+ vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index,
@@ -311,18 +308,15 @@ class VocabParallelEmbedding(nn.Cell):
added_vocab_end_index,
)
- @jit
def construct(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
- input_,
- self.shard_indices.org_vocab_start_index,
+ input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
- self.shard_indices.added_vocab_end_index
- )
+ self.shard_indices.added_vocab_end_index)
else:
masked_input, input_mask = input_, None
# Get the embeddings.
@@ -354,11 +348,13 @@ class VocabParallelEmbedding(nn.Cell):
if loaded_weight.shape[output_dim] != self.org_vocab_size:
raise ValueError(
f"'loaded_weight.shape[output_dim]' should be equal to 'org_vocab_size',"
- f" but got {loaded_weight.shape[output_dim]} and {self.org_vocab_size}")
+ f" but got {loaded_weight.shape[output_dim]} and {self.org_vocab_size}"
+ )
# Copy the data.
- loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size).contiguous()
- param[: loaded_weight.shape[0]] = loaded_weight
+ loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+ shard_size).contiguous()
+ param[:loaded_weight.shape[0]] = loaded_weight
param[loaded_weight.shape[0]:] = 0
@@ -401,8 +397,8 @@ class ParallelLMHead(VocabParallelEmbedding):
self.quant_config = quant_config
if bias:
self.bias = Parameter(
- mint.zeros(self.num_embeddings_per_partition, dtype=params_dtype)
- )
+ mint.zeros(self.num_embeddings_per_partition,
+ dtype=params_dtype))
set_weight_attrs(
self.bias,
{
@@ -411,7 +407,6 @@ class ParallelLMHead(VocabParallelEmbedding):
},
)
else:
- # self.register_parameter("bias", None)
self.bias = None
def tie_weights(self, embed_tokens: VocabParallelEmbedding):
@@ -420,8 +415,7 @@ class ParallelLMHead(VocabParallelEmbedding):
if self.quant_config and self.quant_config.get_name() == "gguf":
return embed_tokens
else:
- # self.weight = embed_tokens.weight
- self.weight.set_data(embed_tokens.weight)
+ self.weight = embed_tokens.weight
return self
def forward(self, input_):
diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py
index ccfcfdb3dbb728c5da3cc251501958a1fbf4a670..b6fff0d0a2f16046a6abffb09a2d1ee45b4992f8 100644
--- a/vllm_mindspore/model_executor/models/attention_mask.py
+++ b/vllm_mindspore/model_executor/models/attention_mask.py
@@ -1,3 +1,5 @@
+# type: ignore
+# isort:skip_file
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,15 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-
"""
infer attention mask.
"""
import numpy as np
-
+import mindspore as ms
from mindspore import Tensor, mint
from mindspore import dtype as mstype
-
r"""
PA:ASD-V2.1.5
1.MLA + Q_seqlen =1: no mask.(BF16 mask(0/-10000), FP16 mask(0/-10000)).
@@ -33,6 +33,8 @@ FA:ASD-V2.1.5
2.normal: mask BF16(0/1), FP16 mask(0/-10000);
"""
+MAX_MODEL_LEN_32K = 32 * 1024
+
class LowerTriangularMask:
r"""
@@ -45,23 +47,57 @@ class LowerTriangularMask:
def __init__(self, dtype, max_model_len):
self.dtype = dtype
self.max_model_len = max_model_len
-
prefill_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
- self.prefill_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) * prefill_mask_coeff,
- dtype=self.dtype)
-
- self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1),
- dtype=self.dtype) * -10000
+ self.prefill_mask = Tensor(
+ np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) *
+ prefill_mask_coeff,
+ dtype=self.dtype)
self.hard_mask = mint.zeros((1, 1), dtype=dtype)
+ decode_mask_coeff = -10000
+ self.decode_mask = self.init_decode_mask(decode_mask_coeff)
+
+ def init_decode_mask(self, decode_mask_coeff):
+        # 32K is our previously tested limit, chosen so that the original
+        # performance is unaffected. Up to 32K the mask is kept as a Tensor;
+        # beyond that it is kept as NumPy, which interrupts the stream and
+        # may not perform as well, so we rely on the PagedAttention operators
+        # to generate masks automatically in that case.
+ if self.max_model_len > MAX_MODEL_LEN_32K:
+ decode_mask = np.triu(np.ones(
+ shape=(self.max_model_len, self.max_model_len),
+ dtype=np.float16),
+ k=1) * decode_mask_coeff
+ else:
+ decode_mask = Tensor(np.triu(np.ones(
+ shape=(self.max_model_len, self.max_model_len), dtype=np.int8),
+ k=1),
+ dtype=self.dtype) * decode_mask_coeff
+ return decode_mask
+
+ def gen_attention_decode_mask(self, position_ids):
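+        # Gather the mask rows for the given positions: index_select for a
+        # Tensor mask, NumPy indexing plus conversion for a NumPy mask.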
+ if isinstance(self.decode_mask, ms.Tensor):
+ attention_mask = mint.index_select(self.decode_mask, 0,
+ position_ids)
+ elif isinstance(self.decode_mask, np.ndarray):
+ attention_mask = self.decode_mask[position_ids.asnumpy()]
+ attention_mask = ms.Tensor(attention_mask, dtype=self.dtype)
+ else:
+ raise ValueError(
+ f"Decode mask type:{type(self.decode_mask)} is not supported.")
- def gen_attention_mask(self, is_prefill, position_ids, query_lens):
+ return attention_mask
+
+ def gen_attention_mask(self,
+ is_prefill,
+ position_ids,
+ query_lens,
+ attn_metadata=None):
if is_prefill:
attention_mask = self.prefill_mask
else:
if max(query_lens) > 1:
- attention_mask = mint.index_select(self.decode_mask, 0, position_ids)
+ attention_mask = self.gen_attention_decode_mask(position_ids)
else:
attention_mask = self.hard_mask
return attention_mask
@@ -76,8 +112,43 @@ class MLALowerTriangularMask(LowerTriangularMask):
"""
def __init__(self, dtype, max_model_len):
-
super().__init__(dtype, max_model_len)
decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0
- self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1),
- dtype=self.dtype) * decode_mask_coeff
+ self.decode_mask = self.init_decode_mask(decode_mask_coeff)
+
+
+class MultiModalLowerTriangularMask(LowerTriangularMask):
+ r"""
+ Provide multi modal Infer model attention mask.
+ Args:
+ dtype (ms dtype): The compute type of Infer model.
+ max_model_len (int): The max model length of Infer model.
+ """
+
+ def __init__(self, dtype, max_model_len):
+
+ super().__init__(dtype, max_model_len)
+
+ def gen_attention_mask(self,
+ is_prefill,
+ position_ids,
+ query_lens,
+ attn_metadata=None):
+ if is_prefill:
+ attention_mask = self.prefill_mask
+ else:
+ if max(query_lens) > 1:
+ seq_lens_np = attn_metadata.seq_lens_np
+ context_lens_np = attn_metadata.context_lens.asnumpy()
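+                # Rebuild per-sequence token positions from context/sequence
+                # lengths and use them (instead of position_ids) to select
+                # mask rows.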
+ mm_position_ids_list = []
+ for i in range(len(seq_lens_np)):
+ mm_position_ids_list.append(
+ np.arange(context_lens_np[i], seq_lens_np[i]))
+ mm_position_ids = np.concatenate(mm_position_ids_list)
+ mm_position_ids = ms.Tensor(mm_position_ids,
+ dtype=position_ids.dtype)
+ attention_mask = mint.index_select(self.decode_mask, 0,
+ mm_position_ids)
+ else:
+ attention_mask = self.hard_mask
+ return attention_mask
diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py
index 3a18956b93e825851144f28ca4a261d46bcea9b7..954579f11f59f15ff46f15ad0a1a37049b2142b7 100644
--- a/vllm_mindspore/model_executor/models/llama.py
+++ b/vllm_mindspore/model_executor/models/llama.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -16,49 +15,36 @@
# limitations under the License.
# ============================================================================
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set,
+ Tuple, Type, Union)
if TYPE_CHECKING:
from transformers import LlamaConfig
else:
LlamaConfig = None
+from mindspore import Tensor, mint, nn
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
-from vllm_mindspore.model_executor.layers.linear import (
- MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear,
-)
-from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor
from vllm_mindspore.attention import Attention
from vllm_mindspore.model_executor.layers.activation import SiluAndMul
-from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
- DEFAULT_VOCAB_PADDING_SIZE,
- ParallelLMHead,
- VocabParallelEmbedding,
-)
-from vllm_mindspore.model_executor.models.utils import (
- PPMissingLayer,
- extract_layer_index,
- make_layers,
- maybe_prefix,
- make_empty_intermediate_tensors_factory,
-)
-from vllm_mindspore.model_executor.layers.sampler import get_sampler, SamplerOutput
from vllm_mindspore.model_executor.layers.layernorm import RMSNorm
+from vllm_mindspore.model_executor.layers.linear import (
+ MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear)
+from vllm_mindspore.model_executor.layers.logits_processor import (
+ LogitsProcessor)
from vllm_mindspore.model_executor.layers.rotary_embedding import get_rope
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
-
-from vllm_mindspore.model_executor.models.model_base import MsModelBase
-
-from vllm.sequence import IntermediateTensors
-from vllm.attention import AttentionMetadata
-from vllm.model_executor.models.interfaces import SupportsPP
-from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name
-
-from mindspore import Tensor, mint, jit, nn
-from mindspore import dtype as mstype
+from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput,
+ get_sampler)
+from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
+ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm_mindspore.model_executor.models.model_base import NativeModel
+from vllm_mindspore.model_executor.models.utils import (
+ PPMissingLayer, extract_layer_index,
+ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
def default_weight_loader(param, loaded_weight) -> None:
@@ -66,6 +52,7 @@ def default_weight_loader(param, loaded_weight) -> None:
class LlamaMLP(nn.Cell):
+
def __init__(
self,
hidden_size: int,
@@ -91,13 +78,10 @@ class LlamaMLP(nn.Cell):
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
- raise ValueError(
- f"Unsupported activation: {hidden_act}. "
- "Only silu is supported for now."
- )
+ raise ValueError(f"Unsupported activation: {hidden_act}. "
+ "Only silu is supported for now.")
self.act_fn = SiluAndMul()
- @jit
def construct(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
@@ -106,6 +90,7 @@ class LlamaMLP(nn.Cell):
class LlamaAttention(nn.Cell):
+
def __init__(
self,
config: LlamaConfig,
@@ -139,9 +124,8 @@ class LlamaAttention(nn.Cell):
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
- self.head_dim = getattr(
- config, "head_dim", self.hidden_size // self.total_num_heads
- )
+ self.head_dim = getattr(config, "head_dim",
+ self.hidden_size // self.total_num_heads)
# Phi models introduced a partial_rotary_factor parameter in the config
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
self.rotary_dim = int(partial_rotary_factor * self.head_dim)
@@ -177,7 +161,7 @@ class LlamaAttention(nn.Cell):
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
- base=rope_theta,
+ base=rope_theta, # type: ignore[arg-type]
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
@@ -190,7 +174,8 @@ class LlamaAttention(nn.Cell):
sw_idx = layer_idx % len(interleaved_sliding_window)
sliding_window = interleaved_sliding_window[sw_idx]
else:
- raise ValueError(f"{type(interleaved_sliding_window)} is not supported.")
+ raise ValueError(
+ f"{type(interleaved_sliding_window)} is not supported.")
else:
sliding_window = None
@@ -204,32 +189,33 @@ class LlamaAttention(nn.Cell):
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
)
- self.attn_mask = mint.triu(mint.ones(size=(128, 128), dtype=mstype.float16), 1) * -10000.0
- @jit
def construct(
self,
positions: Tensor,
hidden_states: Tensor,
- kv_cache: Tuple[Tensor, Tensor],
- # attn_metadata: AttentionMetadata,
- num_prefill_tokens: int,
- num_decode_tokens: int,
+ key_cache: Tensor,
+ value_cache: Tensor,
+ is_prefill: bool,
slot_mapping: Tensor,
- batch_valid_length: Tuple[int],
- context_lens: Tensor,
+ attn_mask: Tensor,
+ batch_valid_length: Tensor,
+ q_seq_lens: Tensor,
block_tables: Tensor,
) -> Tensor:
qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1)
- q, k = self.rotary_emb(positions, q, k, context_lens, num_prefill_tokens)
- attn_output = self.attn(q, k, v, kv_cache, num_prefill_tokens, num_decode_tokens,
- slot_mapping, batch_valid_length, context_lens, block_tables, self.attn_mask)
+ q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
+ -1)
+ q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
+ attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+ slot_mapping, attn_mask, batch_valid_length,
+ q_seq_lens, block_tables)
output, _ = self.o_proj(attn_output)
return output
class LlamaDecoderLayer(nn.Cell):
+
def __init__(
self,
config: LlamaConfig,
@@ -242,17 +228,15 @@ class LlamaDecoderLayer(nn.Cell):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
- config, "original_max_position_embeddings", None
- ):
+ config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
- config.original_max_position_embeddings
- )
- max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+ config.original_max_position_embeddings)
+ max_position_embeddings = getattr(config, "max_position_embeddings",
+ 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
- config, "bias", False
- )
+ config, "bias", False)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, 'qkv_bias'):
@@ -262,9 +246,8 @@ class LlamaDecoderLayer(nn.Cell):
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
- num_kv_heads=getattr(
- config, "num_key_value_heads", config.num_attention_heads
- ),
+ num_kv_heads=getattr(config, "num_key_value_heads",
+ config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
@@ -282,23 +265,22 @@ class LlamaDecoderLayer(nn.Cell):
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
- self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size, eps=config.rms_norm_eps
- )
+ self.input_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ eps=config.rms_norm_eps)
- @jit
def construct(
self,
positions: Tensor,
hidden_states: Tensor,
- kv_cache: Tuple[Tensor, Tensor],
- # attn_metadata: AttentionMetadata,
- num_prefill_tokens: int,
- num_decode_tokens: int,
+ key_cache: Tensor,
+ value_cache: Tensor,
+ is_prefill: bool,
slot_mapping: Tensor,
- batch_valid_length: Tuple[int],
- context_lens: Tensor,
+ attn_mask: Tensor,
+ batch_valid_length: Tensor,
+ q_seq_lens: Tensor,
block_tables: Tensor,
residual: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
@@ -307,22 +289,17 @@ class LlamaDecoderLayer(nn.Cell):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
-
- hidden_states = self.self_attn(
- positions,
- hidden_states,
- kv_cache,
- num_prefill_tokens,
- num_decode_tokens,
- slot_mapping,
- batch_valid_length,
- context_lens,
- block_tables
- )
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+
+ hidden_states = self.self_attn(positions, hidden_states, key_cache,
+ value_cache, is_prefill, slot_mapping,
+ attn_mask, batch_valid_length,
+ q_seq_lens, block_tables)
# Fully Connected
- hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -339,18 +316,18 @@ class LlamaModel(nn.Cell):
layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer,
):
super().__init__()
- config = vllm_config
+ config = vllm_config.model_config.hf_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.org_vocab_size = config.vocab_size
- # TODO: Support quant_config cache_config
- quant_config = None
- cache_config = None
+ quant_config = vllm_config.quant_config
+ self.quant_config = quant_config
+ cache_config = vllm_config.cache_config
+ lora_config = vllm_config.lora_config # noqa: F841
- if get_pp_group().is_first_rank or (
- config.tie_word_embeddings and get_pp_group().is_last_rank
- ):
+ if get_pp_group().is_first_rank or (config.tie_word_embeddings
+ and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
@@ -377,24 +354,22 @@ class LlamaModel(nn.Cell):
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
- ["hidden_states", "residual"], config.hidden_size
- )
+ ["hidden_states", "residual"], config.hidden_size)
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
return self.embed_tokens(input_ids)
- @jit
def construct(
self,
input_ids: Optional[Tensor],
positions: Tensor,
- kv_caches: List[Tuple[Tensor, Tensor]],
- # attn_metadata: AttentionMetadata,
- num_prefill_tokens: int,
- num_decode_tokens: int,
+ key_caches: List[Tensor],
+ value_caches: List[Tensor],
+ is_prefill: bool,
slot_mapping: Tensor,
- batch_valid_length: Tuple[int],
- context_lens: Tensor,
+ attn_mask: Tensor,
+ batch_valid_length: Tensor,
+ q_seq_lens: Tensor,
block_tables: Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[Tensor] = None,
@@ -410,25 +385,20 @@ class LlamaModel(nn.Cell):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
- for i in range(self.start_layer, self.end_layer): # PP 并行对层进行切分
+ for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
- hidden_states, residual = layer(
- positions,
- hidden_states,
- kv_caches[i - self.start_layer],
- num_prefill_tokens,
- num_decode_tokens,
- slot_mapping,
- batch_valid_length,
- context_lens,
- block_tables,
- residual
- )
+ hidden_states, residual = layer(positions, hidden_states,
+ key_caches[i - self.start_layer],
+ value_caches[i - self.start_layer],
+ is_prefill, slot_mapping,
+ attn_mask, batch_valid_length,
+ q_seq_lens, block_tables, residual)
if not get_pp_group().is_last_rank:
- return IntermediateTensors(
- {"hidden_states": hidden_states, "residual": residual}
- )
+ return IntermediateTensors({
+ "hidden_states": hidden_states,
+ "residual": residual
+ })
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -465,21 +435,21 @@ class LlamaModel(nn.Cell):
else:
if name in params_dict:
param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
-class LlamaForCausalLM(MsModelBase, SupportsPP):
+class LlamaForCausalLM(NativeModel, SupportsPP):
+
def __init__(self, vllm_config, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)
quant_config = vllm_config.quant_config
- self.model = LlamaModel(vllm_config=self.config)
+ self.model = LlamaModel(vllm_config=vllm_config)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = self.config.vocab_size
@@ -495,68 +465,47 @@ class LlamaForCausalLM(MsModelBase, SupportsPP):
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
- if not self.lora_config
- else self.lora_config.lora_vocab_padding_size
- ),
+ if not self.lora_config else
+ self.lora_config.lora_vocab_padding_size),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
- # if self.config.tie_word_embeddings:
- # self.lm_head = self.lm_head.tie_weights(
- # self.model.embed_tokens)
+ if self.config.tie_word_embeddings:
+ self.lm_head = self.lm_head.tie_weights(
+ self.model.embed_tokens)
logit_scale = getattr(self.config, "logit_scale", 1.0)
- self.logits_processor = LogitsProcessor(
- self.unpadded_vocab_size, self.config.vocab_size, logit_scale
- )
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+ self.config.vocab_size,
+ logit_scale)
self.sampler = get_sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
- self.model.make_empty_intermediate_tensors
- )
-
- self.set_modules({"model": self.model, "lm_head": self.lm_head})
-
- self.set_model_inputs()
+ self.model.make_empty_intermediate_tensors)
- def tie_lmhead_weights(self):
- self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
+ self.common_preprocess(vllm_config, prefix)
- def forward(
- self,
- input_ids,
- positions,
- kv_caches,
- attn_metadata,
- intermediate_tensors=None,
- inputs_embeds=None,
- **kwargs
- ):
- if attn_metadata.num_prefill_tokens > 0:
- input_ids = input_ids.expand_dims(0)
- if attn_metadata.num_decode_tokens > 0:
- input_ids = input_ids.expand_dims(1)
- model_output = self.model(input_ids,
- positions,
- kv_caches,
- **dict(attn_metadata),
- intermediate_tensors=intermediate_tensors,
- inputs_embeds=inputs_embeds)
- if attn_metadata.num_prefill_tokens > 0:
- model_output = model_output.squeeze(0)
- if attn_metadata.num_decode_tokens > 0:
- model_output = model_output.squeeze(1)
- return model_output
+ def forward(self,
+ input_ids,
+ positions,
+ intermediate_tensors=None,
+ inputs_embeds=None,
+ **kwargs):
+ hidden_states = self.exec_model(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
+ return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
params_dict = self.get_params_dict()
- self.model.load_weights(weights, params_dict)
+ load_params = self.model.load_weights(weights, params_dict)
+ if self.config.tie_word_embeddings:
+ load_params.add("lm_head.weight")
+ return load_params
- def sample(
- self, logits: Tensor, sampling_metadata: SamplingMetadata
- ) -> Optional[SamplerOutput]:
+ def sample(self, logits: Tensor,
+ sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
@@ -565,5 +514,6 @@ class LlamaForCausalLM(MsModelBase, SupportsPP):
hidden_states: Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]:
- logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
- return logits
\ No newline at end of file
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
diff --git a/vllm_mindspore/model_executor/models/mf_models/config.py b/vllm_mindspore/model_executor/models/mf_models/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c16e05f34595545cdef7901af32cd38c3883cd96
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/mf_models/config.py
@@ -0,0 +1,166 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import types
+
+from mindformers.models.configuration_utils import PretrainedConfig
+from mindformers.tools.register.config import MindFormerConfig
+from vllm.config import VllmConfig
+
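+# Each entry maps a key in the generated MindFormers config to a pair of
+# (source path in VllmConfig, transform function or default value); see
+# transform_config below.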
+MF_CTX_MAPPING = {
+ 'run_mode': (None, "predict"),
+ 'use_legacy': (None, False),
+ 'load_ckpt_format': (None, 'safetensors'),
+ 'auto_trans_ckpt': (None, True),
+}
+
+MF_PARALLEL_MAPPING = {
+ 'parallel_mode': (None, 'STAND_ALONE'),
+ 'parallel_config.model_parallel':
+ ('parallel_config.tensor_parallel_size', None),
+ 'parallel_config.pipeline_stage':
+ ('parallel_config.pipeline_parallel_size', None),
+ 'parallel_config.vocab_emb_dp': (None, False)
+}
+
+# Common model config
+MODEL_COMMON_MAPPING = {
+ 'seq_length': ('model_config.max_model_len', None),
+ 'use_flash_attention': (None, True),
+ "compute_dtype": ('model_config.hf_config.torch_dtype', 'bfloat16'),
+ 'architectures': ('model_config.hf_config.architectures', None),
+ 'bos_token_id': ('model_config.hf_config.bos_token_id', None),
+ 'eos_token_id': ('model_config.hf_config.eos_token_id', None),
+ 'model_type': ('model_config.hf_config.model_type', None),
+ # transformer_config
+ 'attention_dropout': ('model_config.hf_config.attention_dropout', None),
+ 'hidden_act': ('model_config.hf_config.hidden_act', None),
+ 'hidden_size': ('model_config.hf_config.hidden_size', None),
+ 'intermediate_size': ('model_config.hf_config.intermediate_size', None),
+ 'max_position_embeddings':
+ ('model_config.hf_config.max_position_embeddings', None),
+ 'num_attention_heads':
+ ('model_config.hf_config.num_attention_heads', None),
+ 'rms_norm_eps': ('model_config.hf_config.rms_norm_eps', None),
+ 'num_hidden_layers': ('model_config.hf_config.num_hidden_layers', None),
+ 'num_layers': ('model_config.hf_config.num_layers', None),
+ 'num_key_value_heads':
+ ('model_config.hf_config.num_key_value_heads', None),
+ 'n_kv_heads': ('model_config.hf_config.n_kv_heads', None),
+ 'head_dim': ('model_config.hf_config.head_dim', None),
+ 'rope_theta': ('model_config.hf_config.rope_theta', None),
+ 'tie_word_embeddings':
+ ('model_config.hf_config.tie_word_embeddings', None),
+ 'vocab_size': ('model_config.hf_config.vocab_size', None),
+}
+
+# model default config
+MODEL_RELATED_MAPPING = {
+ 'qwen2': {
+ "gated_linear_unit": True,
+ 'params_dtype': 'float32', # need an input
+ 'add_qkv_bias': True,
+ },
+ 'qwen3': {
+ "gated_linear_unit": True,
+ 'params_dtype': 'float32', # need an input
+ 'add_qkv_bias': False,
+ },
+ 'deepseek_v3': {
+ "gated_linear_unit": True,
+ 'params_dtype': 'bfloat16', # need an input
+ 'add_qkv_bias': False,
+ 'normalization': 'RMSNorm'
+ }
+    # Add another model type here as needed.
+}
+
+
+def get_nested_attr(obj, path: str, default=None):
+ """get nested attr from obj."""
+ current = obj
+ for attr in path.split('.'):
+ if not hasattr(current, attr):
+ return default
+ current = getattr(current, attr)
+ return current
+
+
+def set_nested_attr(obj, path: str, value):
+ """Set nested attr of MindFormerConfig."""
+ attrs = path.split('.')
+
+ current = obj
+ for attr in attrs[:-1]:
+ if not hasattr(current, attr) or getattr(current, attr) is None:
+ setattr(current, attr, MindFormerConfig())
+ current = getattr(current, attr)
+
+ setattr(current, attrs[-1], value)
+
+
+def transform_config(mapping_table: dict, vllm_config: VllmConfig,
+ target_config):
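+    """Fill target_config with values from vllm_config (or defaults)."""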
+ for target_path, mapping in mapping_table.items():
+ src_path, transform = mapping
+
+ src_value = get_nested_attr(vllm_config,
+ src_path) if src_path is not None else None
+
+ if src_value is not None:
+ transformed_value = src_value
+ elif transform and isinstance(
+ transform, (types.FunctionType, types.BuiltinFunctionType)):
+ transformed_value = transform(src_value)
+ else:
+ transformed_value = transform
+
+ if transformed_value is not None:
+ set_nested_attr(target_config, target_path, transformed_value)
+
+
+def gen_model_relatived_config(model_type):
+ return MODEL_RELATED_MAPPING.get(model_type)
+
+
+def gen_model_config_dict(vllm_config: VllmConfig):
+ target_config = MindFormerConfig()
+
+ transform_config(MODEL_COMMON_MAPPING, vllm_config, target_config)
+
+ model_type = vllm_config.model_config.hf_config.model_type
+ model_related_config = gen_model_relatived_config(model_type)
+ target_config.update(model_related_config)
+
+ return target_config
+
+
+def gen_mf_config(vllm_config: VllmConfig):
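+    """Build the MindFormers config from the vLLM config."""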
+ target_config = MindFormerConfig()
+ transform_config(MF_CTX_MAPPING, vllm_config, target_config)
+ transform_config(MF_PARALLEL_MAPPING, vllm_config, target_config)
+ target_config.set_value(
+ 'model.model_config',
+ MindFormerConfig(**gen_model_config_dict(vllm_config)))
+ return target_config
+
+
+def gen_model_config(mf_config: MindFormerConfig,
+ model_config_type: PretrainedConfig):
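+    """Instantiate the MindFormers model config class from mf_config."""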
+ model_config = model_config_type(**mf_config.model.model_config,
+ parallel_config=mf_config.parallel_config)
+ model_config.post_process = False
+ return model_config
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py
deleted file mode 100644
index c0b72f4df2843a09b804c88d42b771d62d29b489..0000000000000000000000000000000000000000
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_mtp.py
+++ /dev/null
@@ -1,115 +0,0 @@
-#!/usr/bin/env python3
-# encoding: utf-8
-# Copyright 2025 Huawei Technologies Co., Ltd
-# Copyright 2024 The vLLM team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-from typing import Iterable, Set, Tuple
-
-from vllm.config import VllmConfig
-from vllm.config import get_current_vllm_config
-from vllm.forward_context import get_forward_context
-from vllm.logger import init_logger
-
-from mindspore import Tensor, JitConfig, Model, mutable
-from mindspore.nn.utils import no_init_parameters
-
-from research.deepseek3.deepseek3_config import (
- DeepseekV3Config as DeepseekV3Config_MF,
-)
-from research.deepseek3.deepseek3 import (
- DeepseekV3ForCausalLM as DeepseekV3ForCausalLM_MF,
-)
-
-from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.model_base import Fake_MLA
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
-from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor
-
-logger = init_logger(__name__)
-
-class DeepseekV3MTPForCausalLM(MfModelBase):
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super(DeepseekV3MTPForCausalLM, self).__init__(
- vllm_config=vllm_config, prefix=prefix
- )
- self.mf_kvcaches_init = False
-
- self.sampler = get_sampler()
- self.set_modules({"model": self.network})
-
- self.kv_caches = [Fake_MLA() for i in range(self.mf_model_config.num_layers)]
- compilation_config = get_current_vllm_config().compilation_config
-
- if prefix in compilation_config.static_forward_context:
- raise ValueError(f"Duplicate layer name: {prefix}")
- for i in range(self.mf_model_config.num_nextn_predict_layers):
- compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
-
- self.set_flags = False
-
-
- def _generate_model_config(self):
- self.mf_config.load_checkpoint = self.get_model_path()
-
- self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config)
- if self.mf_config.moe_config:
- self.mf_model_config.moe_config = self.mf_config.moe_config
- self.mf_model_config.return_hidden_states = True
- setattr(self.mf_model_config, 'npu_mem_size', -1)
-
- self.mf_model_config.is_mtp_model = True
- self.mf_model_config.num_nextn_predict_layers = self.model_config.hf_config.num_nextn_predict_layers
- if self.mf_model_config.num_nextn_predict_layers != 1:
- raise NotImplementedError("Only support 1 MTP-layer now.")
-
- self.mf_config.model.model_config = self.mf_model_config
-
-
- def _create_network(self):
- # Initital network
- with no_init_parameters(): # Delay initialization
- network = DeepseekV3ForCausalLM_MF(self.mf_model_config)
-
- return network, network.mtp_model.head
-
-
- def get_kvcache(self):
- key_cache = []
- forward_context = get_forward_context()
- for i in range(self.mf_model_config.num_nextn_predict_layers):
- k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
- key_cache.append(k_cache)
- return mutable(key_cache), None
-
-
- def update_model_inputs(self, model_inputs, **kwargs):
- # ToDo: supports multi-mtpLayers with 'spec_step_idx' specifing the layer index.
- if kwargs.get("spec_step_idx", 0) != 0:
- raise NotImplementedError("Only support 1 MTP-layer now.")
- # model_inputs["index"] = ms.Tensor(kwargs.get("spec_step_idx", 0), ms.int32)
- hidden_states_shape = list(model_inputs["input_ids"].shape)
- hidden_states_shape.append(self.model_config.get_hidden_size())
- model_inputs["hidden_states"] = kwargs.get("previous_hidden_states").reshape(hidden_states_shape)
- return model_inputs
-
-
- def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
- weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, False)
- weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint, is_mtp_model=True)
- self.network.set_dynamic_inputs()
- dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
- self.lm_head.set_inputs(dynamic_hidden_states)
- return None
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index d09d8d265c2e9daa8114b6e6afa0da1e85cc99f4..5b4b5db96104a2fafc5f12324855d5ecdf7f3e5a 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -16,283 +15,219 @@
# limitations under the License.
# ============================================================================
-import os
-from typing import Iterable, Set, Tuple
-from collections import OrderedDict
+from typing import Iterable, Optional, Tuple, Union
-import numpy as np
-import vllm.envs as envs
import mindspore as ms
-
-from vllm.config import VllmConfig
-from vllm.config import get_current_vllm_config
-from vllm.distributed.parallel_state import get_dp_group, get_tensor_model_parallel_world_size
+import numpy as np
+from mindformers.core.context import build_mf_context
+from mindformers.core.parallel_config import build_parallel_config
+from mindformers.models.deepseek3.configuration_deepseek_v3 import (
+    DeepseekV3Config)
+from mindformers.models.deepseek3.modeling_deepseek_v3 import ( # noqa
+ DeepseekV3ForCausalLM as DeepseekV3ForCausalLM_MF)
+from mindformers.tools.utils import is_pynative
+from mindspore import Tensor, ops, mutable
+from mindspore.common.api import _pynative_executor
+from mindspore.nn.utils import no_init_parameters
+from vllm import envs
+from vllm.config import VllmConfig, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
-
-import mindspore as ms
-from mindspore import Tensor, JitConfig, Model, mutable
-from mindspore.common import dtype as msdtype
-from mindspore.nn.utils import no_init_parameters
-
-from mindspore_gs.ptq import PTQ
-from mindspore_gs.ptq import PTQMode, PTQConfig, OutliersSuppressionType, PrecisionRecovery, QuantGranularity, \
- GPTQQuantConfig
-from mindspore_gs.common import BackendTarget
-
-from mindformers.trainer.utils import transform_and_load_checkpoint
-from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer
-from research.deepseek3.deepseek3_config import (
- DeepseekV3Config as DeepseekV3Config_MF,
-)
-from research.deepseek3.deepseek3 import (
- DeepseekV3ForCausalLM as DeepseekV3ForCausalLM_MF,
-)
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.model_base import Fake_MLA, Fake_MLA_V1
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
-from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor
-from vllm_mindspore.model_executor.models.attention_mask import MLALowerTriangularMask
+from vllm_mindspore.model_executor.models.attention_mask import (
+ MLALowerTriangularMask)
+from vllm_mindspore.model_executor.models.mf_models.config import (
+ gen_mf_config, gen_model_config)
+from vllm_mindspore.model_executor.models.model_base import (
+    MLAAttentionWrapper, MsModelBase)
logger = init_logger(__name__)
-def set_runtime_kernel_launch_group():
- kernel_launch_group = {'thread_num': 2, 'kernel_group_num': 8}
- env_kernel_launch_group = os.getenv("EXPERIMENTAL_KERNEL_LAUNCH_GROUP", None)
- if env_kernel_launch_group is not None:
- pairs = env_kernel_launch_group.split(',')
- for pair in pairs:
- key, val = pair.split(':')
- kernel_launch_group[key] = val
- thread_num = int(kernel_launch_group.get('thread_num', 2))
- kernel_group_num = int(kernel_launch_group.get('kernel_group_num', 8))
- ms.runtime.set_kernel_launch_group(thread_num=thread_num, kernel_group_num=kernel_group_num)
-
-
-def _get_padding_index(q_seq_len):
- dp_size = get_dp_group().world_size
- tp_size = get_tensor_model_parallel_world_size()
- if dp_size == 1:
- return None, None, None, None
-
- tokens_len_per_dp = q_seq_len.sum().reshape(-1)
- tokens_len_per_dp = get_dp_group().all_gather(tokens_len_per_dp)
- tokens_len_per_dp = tokens_len_per_dp.asnumpy()
- padding_size = (tokens_len_per_dp.max() + tp_size - 1) // tp_size * tp_size
-
- dp_rank_id = get_dp_group().rank_in_group
- attn_padding_idx = None
- attn_unpadding_idx = None
- ffn_padding_idx = None
- ffn_unpadding_idx = None
- last_arange_index = 0
-
- for dp_rank, tokens_length in enumerate(tokens_len_per_dp):
- arange_data = np.arange(0, int(tokens_length), dtype=np.int32)
- if dp_rank == dp_rank_id:
- ffn_unpadding_idx = arange_data
- pad = np.zeros(padding_size - arange_data.shape[0], dtype=np.int32)
- attn_padding_idx = np.concatenate((arange_data, pad), axis=0)
- if dp_rank == 0:
- attn_unpadding_idx = arange_data
- last_arange_index = arange_data[-1]
- pad = np.zeros(padding_size - attn_unpadding_idx.shape[0], dtype=np.int32)
- ffn_padding_idx = np.concatenate((attn_unpadding_idx, pad), axis=0)
- else:
- attn_offset_idx = arange_data + padding_size * dp_rank
- attn_unpadding_idx = np.concatenate((attn_unpadding_idx, attn_offset_idx), axis=0)
- ffn_offset_idx = arange_data + last_arange_index + 1
- last_arange_index = ffn_offset_idx[-1]
- pad = np.zeros(padding_size - ffn_offset_idx.shape[0], dtype=np.int32)
- ffn_padding_idx = np.concatenate((ffn_padding_idx, ffn_offset_idx, pad), axis=0)
- return ms.from_numpy(attn_padding_idx), ms.from_numpy(attn_unpadding_idx), ms.from_numpy(ffn_padding_idx), \
- ms.from_numpy(ffn_unpadding_idx)
+class DeepseekV3ForCausalLM(MsModelBase):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ self.set_flags = False
+ mf_config = gen_mf_config(vllm_config)
+ mf_config.load_checkpoint = self.get_model_path()
+ self.mf_config = mf_config
+ build_mf_context(self.mf_config)
+ build_parallel_config(self.mf_config)
-class DeepseekV3ForCausalLM(MfModelBase):
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super(DeepseekV3ForCausalLM, self).__init__(
- vllm_config=vllm_config, prefix=prefix
- )
- self.is_quant = bool(hasattr(self.mf_model_config, "quantization_config") and
- self.mf_model_config.quantization_config)
+ self._generate_model_config()
+ self.network, self.lm_head = self._create_network()
+ self.casual_mask = MLALowerTriangularMask(
+ dtype=self.network.compute_dtype,
+ max_model_len=self.model_config.max_model_len)
- self.mf_kvcaches_init = False
+ affinity_config = self.mf_config.get('context',
+ {}).get('affinity_cpu_list', {})
+ if isinstance(affinity_config, dict):
+ ms.runtime.set_cpu_affinity(True, affinity_config)
+
+ self._set_dynamic_inputs()
self.sampler = get_sampler()
self.set_modules({"model": self.network})
- if envs.VLLM_USE_V1:
- self.kv_caches = [Fake_MLA_V1() for i in range(self.mf_model_config.num_layers)]
- else:
- self.kv_caches = [Fake_MLA() for i in range(self.mf_model_config.num_layers)]
+ self.kv_caches = [
+ MLAAttentionWrapper()
+ for _ in range(self.mf_model_config.num_hidden_layers)
+ ]
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
- for i in range(self.mf_model_config.num_layers):
- compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
+ for i in range(self.mf_model_config.num_hidden_layers):
+ compilation_config.static_forward_context[str(
+ i)] = self.kv_caches[i]
+
+ self.cast = ops.Cast()
+
+ def _set_dynamic_inputs(self):
+ self.network.set_dynamic_inputs()
+ dynamic_hidden_states = Tensor(
+ shape=[None, None], dtype=self.network.compute_dtype)
+ self.lm_head.set_inputs(dynamic_hidden_states)
+
+ def prepare_inputs(self, input_ids, positions):
+
+ attn_metadata = get_forward_context().attn_metadata
+ if attn_metadata is None:
+ attn_metadata = self._dummy_attention_metadata(
+ input_ids, positions)
+ key_cache, value_cache = self.get_kvcache()
+ if not envs.VLLM_USE_V1:
+ # V0
+ seq_lens = attn_metadata.seq_lens
+ max_query_len = attn_metadata.max_query_len
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+ # decodes are scheduled together. In the first step, all the
+ # prefills turn into decodes and max_query_len will be 1.
+ if self.is_multi_step_chunked_prefill and max_query_len == 1:
+ query_lens = [1] * len(seq_lens)
+ else:
+ query_lens = attn_metadata.query_lens
+
+ seq_lens_np = np.array(seq_lens, dtype=np.int32)
+ query_lens_np = np.array(query_lens, dtype=np.int32)
+ kv_cache_lens = seq_lens_np - query_lens_np
+ if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max(
+ ) == 0:
+ is_prefill = True
+ else:
+ is_prefill = False
+ context_lens_tensor = ms.from_numpy(kv_cache_lens)
+ else:
+ # V1
+ is_prefill = attn_metadata.max_context_lens == 0
+ query_lens_np = attn_metadata.q_seq_lens_np
+ seq_lens_np = attn_metadata.seq_lens_np
+ context_lens_tensor = attn_metadata.context_lens
+
+ q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
+ position_ids = ms.Tensor(positions, dtype=ms.int32)
+ attention_mask = self.casual_mask.gen_attention_mask(
+ is_prefill, positions, query_lens_np)
+
+ model_inputs = {}
+ model_inputs["input_ids"] = input_ids.astype(ms.int32)
+ model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
+ model_inputs["block_tables"] = attn_metadata.block_tables
+ model_inputs["slot_mapping"] = attn_metadata.slot_mapping
+ model_inputs["positions"] = position_ids
+ model_inputs["q_seq_lens"] = q_seq_lens
+ model_inputs["attention_mask"] = attention_mask
+ model_inputs["key_cache"] = key_cache
+ model_inputs["value_cache"] = value_cache
+ model_inputs["context_lens_tensor"] = context_lens_tensor
- self.set_flags = False
- set_runtime_kernel_launch_group()
- self.casual_mask = MLALowerTriangularMask(dtype=self.mf_model_config.compute_dtype,
- max_model_len=self.mf_model_config.seq_length)
+ return model_inputs, is_prefill
- def _generate_model_config(self):
- self.mf_config.load_checkpoint = self.get_model_path()
+ def forward(self,
+ input_ids: Tensor,
+ positions: Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ **kwargs) -> Union[Tensor, IntermediateTensors]:
+ model_inputs, is_prefill = self.prepare_inputs(input_ids, positions)
+ model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+
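+        # Prefill runs with the prefill graph; afterwards the network is
+        # switched back to the incremental (decode) phase for later steps.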
+ if is_prefill:
+ self.network.phase = "prefill"
+ if not self.set_flags or is_pynative():
+ self.network.add_flags_custom_mcore(is_prefill=True)
+ hidden_states = self.network(**model_inputs)
+ self.network.phase = "increment"
+ if not self.set_flags or is_pynative():
+ self.network.add_flags_custom_mcore(is_prefill=False)
+ self.set_flags = True
+ else:
+ hidden_states = self.network(**model_inputs)
+
+ return hidden_states
- self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config)
- if self.mf_config.moe_config:
- self.mf_model_config.moe_config = self.mf_config.moe_config
- # dispatch/combine in moe need max_num_seqs as global_max_bs
- if hasattr(self.mf_model_config.moe_config, "dispatch_global_max_bs"):
- self.mf_model_config.moe_config.dispatch_global_max_bs = self.scheduler_config.max_num_seqs
- self.mf_model_config.return_hidden_states = True
- setattr(self.mf_model_config, 'npu_mem_size', -1)
+ def _generate_model_config(self):
+ self.mf_model_config = gen_model_config(self.mf_config, DeepseekV3Config)
+ logger.debug("=====mf_model_config====\n", self.mf_model_config)
def _create_network(self):
- # Initital network
+        # Initialize the network
with no_init_parameters(): # Delay initialization
network = DeepseekV3ForCausalLM_MF(self.mf_model_config)
-
- # quant
- if hasattr(self.mf_model_config, "quantization_config") and hasattr(self.mf_model_config.quantization_config,
- "quant_method"):
- ptq = self.create_ptq(self.mf_model_config.quantization_config.quant_method, PTQMode.DEPLOY)
- if ptq is not None:
- ptq.apply(network)
- ptq.convert(network)
- return network, network.lm_head
+ return network, network.model.output_layer
def get_kvcache(self):
key_cache = []
forward_context = get_forward_context()
- for i in range(self.mf_model_config.num_layers):
- k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
+ for i in range(self.config.num_hidden_layers):
+ k_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined]
+ forward_context.virtual_engine][0]
key_cache.append(k_cache)
return mutable(key_cache), None
- def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
- if self.mf_config.load_ckpt_format == "ckpt":
- model = Model(self.network)
- batch_size = self.mf_config.model.model_config.batch_size
- seq_length = self.mf_config.model.model_config.seq_length
- input_ids = np.ones(shape=tuple([batch_size, seq_length]))
- infer_data = self.network.prepare_inputs_for_predict_layout(input_ids)
- transform_and_load_checkpoint(
- self.mf_config, model, self.network, infer_data, do_predict=True
- )
+ def update_model_inputs(self, model_inputs, **kwargs):
+ return model_inputs
+
+ def compute_logits(
+ self,
+ hidden_states: Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[Tensor]:
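+        # Only the hidden states selected for sampling go through lm_head;
+        # an empty selection yields an empty logits tensor.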
+ if sampling_metadata is not None:
+ selected_token_indices = sampling_metadata.selected_token_indices
+ if selected_token_indices is not None and selected_token_indices.numel(
+ ) <= 0:
+ logits = ms.mint.zeros(
+ (0, self.mf_model_config.vocab_size),
+ dtype=self.mf_model_config.compute_dtype)
+ else:
+ hidden_states = hidden_states.reshape(
+ (-1, hidden_states.shape[-1]))
+ hidden_states = hidden_states.index_select(
+ 0, selected_token_indices)
+ logits = self.lm_head(hidden_states)
+ logits = logits.view(-1, logits.shape[-1])
else:
- weight_processor = DeepseekV3WeightProcessor(self.mf_config, self.network, self.is_quant)
- weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint)
+ logits = self.lm_head(hidden_states)
+ logits = logits.view(-1, logits.shape[-1])
+ return logits
+
+ def sample(
+ self,
+ logits: Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
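+        # Synchronize the PyNative executor so sampled tokens are materialized.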
+ _pynative_executor.sync()
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[Tuple[str, Tensor]]):
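+        # Delegate weight loading to the network, then reset its dynamic inputs.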
+ self.network.load_weights(self.mf_config.load_checkpoint)
+ self.network.set_dynamic_inputs()
return None
-
- def prepare_inputs(self, input_ids, positions, attn_metadata):
- model_inputs, is_prefill = super().prepare_inputs(
- input_ids, positions, attn_metadata)
-
- attn_padding_idx, attn_unpadding_idx, ffn_padding_idx, ffn_unpadding_idx = _get_padding_index(
- model_inputs["q_seq_lens"])
- model_inputs["attn_padding_idx"] = attn_padding_idx
- model_inputs["attn_unpadding_idx"] = attn_unpadding_idx
- model_inputs["ffn_padding_idx"] = ffn_padding_idx
- model_inputs["ffn_unpadding_idx"] = ffn_unpadding_idx
-
- return model_inputs, is_prefill
-
- def get_model_path(self):
- model_name_or_path = self.model_config.model
- if os.path.isdir(model_name_or_path):
- return model_name_or_path
- else:
- raise ValueError("The 'model' in LLM should be the local path of the MindSpore checkpoint file.")
-
- def create_ptq(self, quant_type: str, quant_mode: PTQMode):
- """create_ptq"""
- if quant_type.lower() == 'ptq':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8,
- outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS,
- opname_blacklist=['lkv2kv', 'lm_head'], precision_recovery=PrecisionRecovery.NONE,
- act_quant_granularity=QuantGranularity.PER_TENSOR,
- weight_quant_granularity=QuantGranularity.PER_CHANNEL)
- ffn_config = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8,
- outliers_suppression=OutliersSuppressionType.NONE,
- precision_recovery=PrecisionRecovery.NONE,
- act_quant_granularity=QuantGranularity.PER_TOKEN,
- weight_quant_granularity=QuantGranularity.PER_CHANNEL)
- layer_policies = OrderedDict({r'.*\.feed_forward\..*': ffn_config})
- elif quant_type.lower() == 'awq-a16w4':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.qint4x2,
- act_quant_dtype=None, outliers_suppression=OutliersSuppressionType.AWQ,
- opname_blacklist=['lm_head', 'lkv2kv'], weight_quant_granularity=QuantGranularity.PER_GROUP,
- group_size=128)
- layer_policies = OrderedDict()
- elif quant_type.lower() == 'awq-a16w8':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=None, outliers_suppression=OutliersSuppressionType.AWQ,
- opname_blacklist=['lm_head', 'lkv2kv'])
- elif quant_type.lower() == 'gptq-perchannel':
- gptq_config = GPTQQuantConfig()
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.qint4x2,
- act_quant_dtype=None, precision_recovery=PrecisionRecovery.GPTQ, algo_args=gptq_config,
- opname_blacklist=['lm_head', 'lkv2kv'])
- layer_policies = OrderedDict()
- elif quant_type.lower() == 'gptq-pergroup':
- gptq_config = GPTQQuantConfig()
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.qint4x2,
- algo_args=gptq_config, act_quant_dtype=None, precision_recovery=PrecisionRecovery.GPTQ,
- weight_quant_granularity=QuantGranularity.PER_GROUP, opname_blacklist=['lm_head', 'lkv2kv'],
- group_size=64)
- w2_config = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8, outliers_suppression=OutliersSuppressionType.SMOOTH)
- layer_policies = OrderedDict({r'.*\.feed_forward\.w2.*': w2_config,
- r'.*\.shared_experts.w2.*': w2_config})
- elif quant_type.lower() == 'smoothquant':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8, outliers_suppression=OutliersSuppressionType.SMOOTH,
- opname_blacklist=['lm_head', 'lkv2kv'])
- ffn_config = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8,
- outliers_suppression=OutliersSuppressionType.NONE,
- precision_recovery=PrecisionRecovery.NONE,
- act_quant_granularity=QuantGranularity.PER_TOKEN,
- weight_quant_granularity=QuantGranularity.PER_CHANNEL)
- layer_policies = OrderedDict({r'.*\.feed_forward\..*': ffn_config})
- elif quant_type.lower() == 'osl':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8,
- outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_LITE,
- opname_blacklist=['lm_head', 'lkv2kv'])
- ffn_config = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8,
- outliers_suppression=OutliersSuppressionType.NONE,
- precision_recovery=PrecisionRecovery.NONE,
- act_quant_granularity=QuantGranularity.PER_TOKEN,
- weight_quant_granularity=QuantGranularity.PER_CHANNEL)
- layer_policies = OrderedDict({r'.*\.feed_forward\..*': ffn_config})
- elif quant_type.lower() == 'a16w8':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- opname_blacklist=['lm_head', 'lkv2kv'])
- layer_policies = OrderedDict()
- elif quant_type.lower() == 'a8dynw8':
- cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8, act_quant_granularity=QuantGranularity.PER_TOKEN,
- opname_blacklist=['lm_head', 'lkv2kv'])
- layer_policies = OrderedDict()
- else:
- logger.warning("Input unsupported quant type: %s.", quant_type)
- return None
- ptq = PTQ(config=cfg, layer_policies=layer_policies)
- if 'awq' in quant_type.lower():
- # pylint: disable=protected-access
- ptq._config.weight_symmetric = False
- if 'gptq-pergroup' in quant_type.lower():
- # pylint: disable=protected-access
- ptq.layer_policies[r'.*\.feed_forward\.w2.*'].aclnn_quant_list = ["w2"]
- ptq.layer_policies[r'.*\.shared_experts.w2.*'].aclnn_quant_list = ["w2"]
- ptq.decoder_layer_types.append(DeepseekV3DecodeLayer)
- return ptq
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py
deleted file mode 100644
index 81dd8ef34465cb1cb12283746fe105b9ef3b2dee..0000000000000000000000000000000000000000
--- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright 2025 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""Infer save ckpt by safetensor."""
-import argparse
-import os
-from collections import OrderedDict
-
-from vllm.logger import init_logger
-
-import mindspore as ms
-from mindspore import dtype as msdtype
-from mindspore.communication.management import get_rank
-from mindformers.core.parallel_config import build_parallel_config
-from mindformers import MindFormerConfig
-from mindformers import build_context
-from research.deepseek3.deepseekv3_infer_parallelism import DeepseekInferParallelism
-
-from research.deepseek3.deepseek3_config import DeepseekV3Config
-from research.deepseek3.deepseek3_model_infer import InferenceDeepseekV3ForCausalLM
-
-logger = init_logger(__name__)
-
-# for example
-# bash scripts/msrun_launcher.sh "python ./infer_save_ckpt_from_safetensor.py
-# --config /path/to/predict_deepseek_r1_671b.yaml
-# --save_ckpt_path /path/to/save_ckpt_path
-# --load_checkpoint /path/to/safetensor_path " 4 8555 "output/deepseek_msrun_log" "False" 7200
-
-def create_ptq():
- '''create_ptq'''
- from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer
- from mindspore_gs.ptq import PTQ
- from mindspore_gs.common import BackendTarget
- from mindspore_gs.ptq import PTQConfig, PTQMode, OutliersSuppressionType, PrecisionRecovery, QuantGranularity
- cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8, outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS,
- opname_blacklist=['lkv2kv', 'lm_head'], precision_recovery=PrecisionRecovery.NONE,
- act_quant_granularity=QuantGranularity.PER_TENSOR,
- weight_quant_granularity=QuantGranularity.PER_CHANNEL)
- ffn_config = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
- act_quant_dtype=msdtype.int8,
- outliers_suppression=OutliersSuppressionType.NONE,
- precision_recovery=PrecisionRecovery.NONE,
- act_quant_granularity=QuantGranularity.PER_TOKEN,
- weight_quant_granularity=QuantGranularity.PER_CHANNEL)
- ptq = PTQ(config=cfg, layer_policies=OrderedDict({r'.*\.feed_forward\..*': ffn_config}))
- ptq.decoder_layers.append(DeepseekV3DecodeLayer)
- return ptq
-
-
-def main(config_path, load_checkpoint, save_ckpt_dir):
- # set model config
- config = MindFormerConfig(config_path)
- config.load_checkpoint = load_checkpoint
-
- build_context(config)
- build_parallel_config(config)
- model_config = config.model.model_config
- model_config.parallel_config = config.parallel_config
- model_config.moe_config = config.moe_config
- model_config = DeepseekV3Config(**model_config)
-
- # build model from config
- network = InferenceDeepseekV3ForCausalLM(model_config)
-
- is_quant = hasattr(config.model.model_config, "quantization_config")
-
- if is_quant:
- ptq = create_ptq()
- ptq.apply(network)
- ptq.convert(network)
- ptq.summary(network)
- # load checkpoint
- if config.load_checkpoint:
- logger.info("----------------Transform and load checkpoint----------------")
- model_parallelism = DeepseekInferParallelism(config, network, is_quant)
- model_parallelism.infer_convert_and_parallelism(config.load_checkpoint)
-
- rank_id = str(get_rank())
- os.makedirs(os.path.join(save_ckpt_dir, "rank_" + rank_id), exist_ok=True)
-
- save_ckpt_path = os.path.join(save_ckpt_dir, "rank_" + rank_id, "checkpoint_" + rank_id + ".ckpt")
- ms.save_checkpoint(network.parameters_dict(), save_ckpt_path)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument('--config_path', default='predict_llama2_7b.yaml', type=str,
- help='model config file path.')
- parser.add_argument('--load_checkpoint', type=str,
- help='load model checkpoint path or directory.')
- parser.add_argument('--save_ckpt_dir', type=str,
- help='save ckpt path.')
- args = parser.parse_args()
- main(args.config_path, args.load_checkpoint, args.save_ckpt_dir)
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
deleted file mode 100644
index 2ee13c23816a3462331e19526a8904c580a86ec2..0000000000000000000000000000000000000000
--- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
+++ /dev/null
@@ -1,1643 +0,0 @@
-# Copyright 2025 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""
-transform huggingface model to mindspore safetensor.
-"""
-import os
-import json
-import gc
-import numpy as np
-from tqdm import tqdm
-
-import mindspore as ms
-from mindspore import dtype
-from mindspore.communication.management import get_rank
-from mindformers.parallel_core.inference.parallel_state import get_tensor_model_parallel_rank
-from vllm_mindspore.model_executor.models.mf_models.weight_processor import BaseWeightProcessor, EPMethod
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def convert_np_to_ms_dtype(value):
- """convert_np_to_ms_dtype"""
- if value.dtype == np.int8:
- value_dtype = ms.int8
- elif value.dtype == np.int32:
- value_dtype = ms.int32
- elif value.dtype == np.int64:
- value_dtype = ms.int64
- elif value.dtype == np.float64:
- value_dtype = ms.float64
- elif value.dtype == np.float32:
- value_dtype = ms.float32
- else:
- value_dtype = ms.bfloat16
- return value_dtype
-
-
-class DeepseekV3WeightProcessor(BaseWeightProcessor):
- r"""
- Provide DeepseekV3/R1 Model weight load and shards.
- Args:
- config (DeepseekV3/R1Config): The config of DeepseekV3/R1 model.
- network (InferenceDeepseekV3ForCausalLM): The network of DeepseekV3/R1.
-
- """
-
- def __init__(self, config, network, is_quant):
- super().__init__(config, network, is_quant)
- self.num_layers = self.config.model.model_config.num_layers
- self.expert_num = self.config.moe_config.expert_num
- self.moe_split_tp = self.moe_tp_size > 1
- self.moe_split_ep = self.moe_ep_size > 1
- logger.debug(f"Deepseekv3 weight split info:"
- f"global_rank_id: {self.global_rank_id} \n"
- f"tp_group_size: {self.tp_group_size} \n"
- f"dp_group_size: {self.dp_group_size} \n"
- f"tp_rank_id: {self.tp_rank_id} \n"
- f"ep_method: {self.ep_method.name} \n"
- f"num_router_experts: {self.num_router_experts} \n"
- f"ep_group_nums: {self.ep_group_nums} \n"
- f"moe_ep_rank_id: {self.moe_ep_rank_id} \n"
- f"moe_tp_rank_id: {self.moe_tp_rank_id} \n"
- f"moe_ep_size: {self.moe_ep_size} \n"
- f"moe_tp_size: {self.moe_tp_size}")
-
- def quant_convert_weight_name(self, weight_name: str):
- """replace quant net weight name"""
- weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight')
-
- weight_name = weight_name.replace('.self_attn.q_a_proj.weight', '.attention.q2l_proj._layer.weight')
- weight_name = weight_name.replace('.self_attn.q_a_proj.input_scale', '.attention.q2l_proj.quant_op.input_scale')
- weight_name = weight_name.replace('.self_attn.q_a_proj.input_offset', '.attention.q2l_proj.quant_op.input_zp')
- weight_name = weight_name.replace('.self_attn.q_a_proj.quant_bias',
- '.attention.q2l_proj._layer.matmul.quant_bias')
- weight_name = weight_name.replace('.self_attn.q_a_proj.deq_scale',
- '.attention.q2l_proj._layer.matmul.dequant_scale')
-
- weight_name = weight_name.replace('.self_attn.q_a_layernorm.weight', '.attention.lq_norm.weight')
- weight_name = weight_name.replace('.self_attn.kv_a_layernorm.weight', '.attention.lkv_norm.weight')
- weight_name = weight_name.replace('.self_attn.kv_b_proj.', '.attention.lkv2kv.')
-
- weight_name = weight_name.replace('.self_attn.q_b_proj.weight', '.attention.l2q_proj._layer.weight')
- weight_name = weight_name.replace('.self_attn.q_b_proj.input_scale', '.attention.l2q_proj.quant_op.input_scale')
- weight_name = weight_name.replace('.self_attn.q_b_proj.input_offset', '.attention.l2q_proj.quant_op.input_zp')
- weight_name = weight_name.replace('.self_attn.q_b_proj.quant_bias',
- '.attention.l2q_proj._layer.matmul.quant_bias')
- weight_name = weight_name.replace('.self_attn.q_b_proj.deq_scale',
- '.attention.l2q_proj._layer.matmul.dequant_scale')
-
- weight_name = weight_name.replace('.self_attn.kv_a_proj_with_mqa.weight', '.attention.kv2l._layer.weight')
- weight_name = weight_name.replace('.self_attn.kv_a_proj_with_mqa.input_scale',
- '.attention.kv2l.quant_op.input_scale')
- weight_name = weight_name.replace('.self_attn.kv_a_proj_with_mqa.input_offset',
- '.attention.kv2l.quant_op.input_zp')
- weight_name = weight_name.replace('.self_attn.kv_a_proj_with_mqa.quant_bias',
- '.attention.kv2l._layer.matmul.quant_bias')
- weight_name = weight_name.replace('.self_attn.kv_a_proj_with_mqa.deq_scale',
- '.attention.kv2l._layer.matmul.dequant_scale')
-
- weight_name = weight_name.replace('.self_attn.o_proj.weight', '.attention.wo._layer.weight')
- weight_name = weight_name.replace('.self_attn.o_proj.input_scale', '.attention.wo.quant_op.input_scale')
- weight_name = weight_name.replace('.self_attn.o_proj.input_offset', '.attention.wo.quant_op.input_zp')
- weight_name = weight_name.replace('.self_attn.o_proj.quant_bias', '.attention.wo._layer.matmul.quant_bias')
- weight_name = weight_name.replace('.self_attn.o_proj.deq_scale', '.attention.wo._layer.matmul.dequant_scale')
-
- weight_name = weight_name.replace('.self_attn.q_a_layernorm.bias', '.attention.l2q_proj.quant_op.beta')
- weight_name = weight_name.replace('.input_layernorm.bias', '.attention.q2l_proj.quant_op.beta')
-
- # mlp is pertoken quant
- weight_name = weight_name.replace('.weight_scale', '.matmul.weight_scale')
- weight_name = weight_name.replace('.weight_offset', '.matmul.weight_offset')
-
- weight_name = weight_name.replace('mlp.gate_proj.', 'feed_forward.w1._layer.')
- weight_name = weight_name.replace('mlp.down_proj.', 'feed_forward.w2._layer.')
- weight_name = weight_name.replace('mlp.up_proj.', 'feed_forward.w3._layer.')
- weight_name = weight_name.replace('mlp.experts.', 'feed_forward.routed_experts.ffn.')
- weight_name = weight_name.replace('mlp.shared_experts.gate_proj.', 'feed_forward.shared_experts.w1._layer.')
- weight_name = weight_name.replace('mlp.shared_experts.down_proj.', 'feed_forward.shared_experts.w2._layer.')
- weight_name = weight_name.replace('mlp.shared_experts.up_proj.', 'feed_forward.shared_experts.w3._layer.')
- weight_name = weight_name.replace('mlp.gate.weight', 'feed_forward.routed_experts.router.dense.weight')
- weight_name = weight_name.replace('mlp.gate.e_score_correction_bias',
- 'feed_forward.routed_experts.router.e_score_correction_bias')
- weight_name = weight_name.replace('.input_layernorm.weight', '.attention_norm.weight')
- weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
- weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight')
- return weight_name
-
- def infer_trans_rope_weight(self, weight, qk_rope_head_dim):
- """process rope router weight"""
- w1 = weight[..., -qk_rope_head_dim::2, :]
- w2 = weight[..., -qk_rope_head_dim + 1::2, :]
- weight[..., -qk_rope_head_dim:, :] = np.concatenate([w1, w2], axis=-2)
- return weight
-
- def infer_quant_process_moe_with_ep(self, src_hf_dir, hf_weight_map, layer_id):
- w1_list = []
- w2_list = []
- w3_list = []
-
- w1_scale_list = []
- w2_scale_list = []
- w3_scale_list = []
-
- for index in range(self.ep_start, self.ep_stop):
- w1_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight"
- w2_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight"
- w3_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight"
-
- w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map)
- w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map)
- w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map)
-
- w1_list.append(w1_ms_param)
- w2_list.append(w2_ms_param)
- w3_list.append(w3_ms_param)
-
- w1_scale_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight_scale"
- w2_scale_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight_scale"
- w3_scale_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight_scale"
-
- w1_scale_ms_param, _ = self.get_safetensor_from_file(w1_scale_hf_name, src_hf_dir, hf_weight_map)
- w2_scale_ms_param, _ = self.get_safetensor_from_file(w2_scale_hf_name, src_hf_dir, hf_weight_map)
- w3_scale_ms_param, _ = self.get_safetensor_from_file(w3_scale_hf_name, src_hf_dir, hf_weight_map)
-
- w1_scale_ms_param = w1_scale_ms_param.squeeze(axis=-1)
- w2_scale_ms_param = w2_scale_ms_param.squeeze(axis=-1)
- w3_scale_ms_param = w3_scale_ms_param.squeeze(axis=-1)
- w1_scale_list.append(w1_scale_ms_param)
- w2_scale_list.append(w2_scale_ms_param)
- w3_scale_list.append(w3_scale_ms_param)
-
- return w1_list, w2_list, w3_list, w1_scale_list, w2_scale_list, w3_scale_list
-
- def infer_quant_process_moe_with_ep_tp(self, src_hf_dir, hf_weight_map, layer_id):
- w1_list = []
- w2_list = []
- w3_list = []
-
- w1_scale_list = []
- w2_scale_list = []
- w3_scale_list = []
-
- for index in range(self.ep_start, self.ep_stop):
- w1_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight"
- w2_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight"
- w3_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight"
-
- w1_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w1_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w2_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w2_hf_name, src_hf_dir, hf_weight_map,
- split_axis=1)
- w3_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w3_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w1_list.append(w1_ms_param)
- w2_list.append(w2_ms_param)
- w3_list.append(w3_ms_param)
-
- w1_scale_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight_scale"
- w2_scale_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight_scale"
- w3_scale_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight_scale"
-
- w1_scale_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w1_scale_hf_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
- w2_scale_ms_param, _ = self.get_safetensor_from_file(w2_scale_hf_name, src_hf_dir,
- hf_weight_map)
- w3_scale_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w3_scale_hf_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
-
- w1_scale_ms_param = w1_scale_ms_param.squeeze(axis=-1)
- w2_scale_ms_param = w2_scale_ms_param.squeeze(axis=-1)
- w3_scale_ms_param = w3_scale_ms_param.squeeze(axis=-1)
- w1_scale_list.append(w1_scale_ms_param)
- w2_scale_list.append(w2_scale_ms_param)
- w3_scale_list.append(w3_scale_ms_param)
-
- return w1_list, w2_list, w3_list, w1_scale_list, w2_scale_list, w3_scale_list
-
- def infer_quant_process_moe(self, src_hf_dir, hf_weight_map, layer_id):
- if self.moe_tp_size > 1:
- return self.infer_quant_process_moe_with_ep_tp(src_hf_dir, hf_weight_map, layer_id)
- else:
- return self.infer_quant_process_moe_with_ep(src_hf_dir, hf_weight_map, layer_id)
-
- def infer_quant_process_moe_routed_expert_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """process moe router expert weight"""
- ffn_concat = self.config.model.model_config.ffn_concat
-
- # router expert dense
- router_dense_hf_name = f"model.layers.{layer_id}.mlp.gate.weight"
- router_dense_ms_name = self.quant_convert_weight_name(router_dense_hf_name)
- router_dense_ms_param, _ = self.get_safetensor_from_file(router_dense_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[router_dense_ms_name] = ms.Parameter(
- ms.from_numpy(router_dense_ms_param).astype(ms.bfloat16),
- name=router_dense_ms_name, requires_grad=False)
-
- # e_score_correction_bias
- e_score_correction_bias_hf_name = f"model.layers.{layer_id}.mlp.gate.e_score_correction_bias"
- e_score_correction_bias_ms_name = self.quant_convert_weight_name(e_score_correction_bias_hf_name)
- e_score_correction_bias_ms_param, _ = self.get_safetensor_from_file(e_score_correction_bias_hf_name, src_hf_dir,
- hf_weight_map)
- self.parameter_dict[e_score_correction_bias_ms_name] = ms.Parameter(
- ms.from_numpy(e_score_correction_bias_ms_param).astype(ms.float32),
- name=e_score_correction_bias_ms_name, requires_grad=False)
-
- w1_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1._layer.weight"
- w2_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2._layer.weight"
- w3_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w3._layer.weight"
-
- w1_scale_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1._layer.matmul.weight_scale"
- w2_scale_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2._layer.matmul.weight_scale"
- w3_scale_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w3._layer.matmul.weight_scale"
-
- w1_list, w2_list, w3_list, w1_scale_list, w2_scale_list, w3_scale_list = \
- self.infer_quant_process_moe(src_hf_dir, hf_weight_map, layer_id)
-
- w1_ms_stack_param = np.stack(w1_list, axis=0)
- w2_ms_stack_param = np.stack(w2_list, axis=0)
- w3_ms_stack_param = np.stack(w3_list, axis=0)
-
- w1_scale_ms_stack_param = np.stack(w1_scale_list, axis=0)
- w2_scale_ms_stack_param = np.stack(w2_scale_list, axis=0)
- w3_scale_ms_stack_param = np.stack(w3_scale_list, axis=0)
-
- if ffn_concat:
- # w_gate_hidden
- w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden._layer.weight"
- w_gate_hidden_np = np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=1)
- w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).permute(0, 2, 1).astype(ms.int8)
- self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name,
- requires_grad=False)
- # w_scale_gate_hidden
- w_scale_gate_hidden_name = \
- f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden._layer.matmul.weight_scale"
-
- w_scale_gate_hidden_np = np.concatenate([w1_scale_ms_stack_param, w3_scale_ms_stack_param], axis=1)
- w_scale_gate_hidden_param = ms.from_numpy(w_scale_gate_hidden_np).astype(ms.bfloat16)
- self.parameter_dict[w_scale_gate_hidden_name] = ms.Parameter(w_scale_gate_hidden_param,
- name=w_scale_gate_hidden_name,
- requires_grad=False)
- else:
- # w1 w3
- self.parameter_dict[w1_ms_name] = ms.Parameter(
- ms.from_numpy(w1_ms_stack_param).permute(0, 2, 1).astype(ms.int8),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_ms_name] = ms.Parameter(
- ms.from_numpy(w3_ms_stack_param).permute(0, 2, 1).astype(ms.int8),
- name=w3_ms_name,
- requires_grad=False)
-
- # w1_scale w3_scale
- self.parameter_dict[w1_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w1_scale_ms_stack_param).astype(ms.bfloat16),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w3_scale_ms_stack_param).astype(ms.bfloat16),
- name=w3_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_ms_name] = ms.Parameter(
- ms.from_numpy(w2_ms_stack_param).permute(0, 2, 1).astype(ms.int8),
- name=w2_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w2_scale_ms_stack_param).astype(ms.bfloat16),
- name=w2_scale_ms_name,
- requires_grad=False)
-
- def get_quant_moe_shared_expert_weight(self, w1_hf_name, w2_hf_name, w3_hf_name, w1_scale_hf_name, w2_scale_hf_name,
- w3_scale_hf_name, src_hf_dir, hf_weight_map):
- if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]:
- w1_ms_param, _ = self.get_safetensor_from_file_split_global_group(w1_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w2_ms_param, _ = self.get_safetensor_from_file_split_global_group(w2_hf_name, src_hf_dir, hf_weight_map,
- split_axis=1)
- w3_ms_param, _ = self.get_safetensor_from_file_split_global_group(w3_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w1_scale_ms_param, _ = self.get_safetensor_from_file_split_global_group(w1_scale_hf_name, src_hf_dir,
- hf_weight_map, split_axis=0)
- w2_scale_ms_param, _ = self.get_safetensor_from_file(w2_scale_hf_name, src_hf_dir, hf_weight_map)
-
- w3_scale_ms_param, _ = self.get_safetensor_from_file_split_global_group(w3_scale_hf_name, src_hf_dir,
- hf_weight_map, split_axis=0)
- elif self.ep_method == EPMethod.ALLTOALL:
- w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map)
- w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map)
- w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map)
-
- w1_scale_ms_param, _ = self.get_safetensor_from_file(w1_scale_hf_name, src_hf_dir, hf_weight_map)
- w2_scale_ms_param, _ = self.get_safetensor_from_file(w2_scale_hf_name, src_hf_dir, hf_weight_map)
- w3_scale_ms_param, _ = self.get_safetensor_from_file(w3_scale_hf_name, src_hf_dir, hf_weight_map)
- else:
- raise ValueError("Unsupported ep_method:{}".format(self.ep_method))
-
- return w1_ms_param, w2_ms_param, w3_ms_param, w1_scale_ms_param, w2_scale_ms_param, w3_scale_ms_param
-
- def infer_quant_process_moe_shared_expert_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer quant process moe shared expert ffn weight"""
- ffn_concat = self.config.model.model_config.ffn_concat
- w1_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.gate_proj.weight"
- w2_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.down_proj.weight"
- w3_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.up_proj.weight"
-
- w1_scale_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.gate_proj.weight_scale"
- w2_scale_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.down_proj.weight_scale"
- w3_scale_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.up_proj.weight_scale"
-
- w1_ms_name = self.quant_convert_weight_name(w1_hf_name)
- w2_ms_name = self.quant_convert_weight_name(w2_hf_name)
- w3_ms_name = self.quant_convert_weight_name(w3_hf_name)
-
- w1_scale_ms_name = self.quant_convert_weight_name(w1_scale_hf_name)
- w2_scale_ms_name = self.quant_convert_weight_name(w2_scale_hf_name)
- w3_scale_ms_name = self.quant_convert_weight_name(w3_scale_hf_name)
-
- w1_ms_param, w2_ms_param, w3_ms_param, w1_scale_ms_param, w2_scale_ms_param, w3_scale_ms_param = \
- self.get_quant_moe_shared_expert_weight(w1_hf_name, w2_hf_name, w3_hf_name, w1_scale_hf_name,
- w2_scale_hf_name,
- w3_scale_hf_name, src_hf_dir, hf_weight_map)
-
- w1_scale_ms_param = w1_scale_ms_param.squeeze(axis=-1)
- w2_scale_ms_param = w2_scale_ms_param.squeeze(axis=-1)
- w3_scale_ms_param = w3_scale_ms_param.squeeze(axis=-1)
-
- if ffn_concat:
- w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden._layer.weight"
- w_gate_hidden_np = np.concatenate([w1_ms_param, w3_ms_param], axis=0)
- w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).astype(ms.int8)
- self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name,
- requires_grad=False)
-
- w_scale_gate_hidden_name = \
- f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden._layer.matmul.weight_scale"
- w_scale_gate_hidden_np = np.concatenate([w1_scale_ms_param, w3_scale_ms_param], axis=0)
- w_scale_gate_hidden_param = ms.from_numpy(w_scale_gate_hidden_np).astype(ms.bfloat16)
- self.parameter_dict[w_scale_gate_hidden_name] = ms.Parameter(w_scale_gate_hidden_param,
- name=w_scale_gate_hidden_name,
- requires_grad=False)
-
- else:
- self.parameter_dict[w1_ms_name] = ms.Parameter(ms.from_numpy(w1_ms_param).astype(ms.int8),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_ms_name] = ms.Parameter(ms.from_numpy(w3_ms_param).astype(ms.int8),
- name=w3_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w1_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w1_scale_ms_param).astype(ms.bfloat16),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w3_scale_ms_param).astype(ms.bfloat16),
- name=w3_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_ms_name] = ms.Parameter(ms.from_numpy(w2_ms_param).astype(ms.int8),
- name=w2_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w2_scale_ms_param).astype(ms.bfloat16),
- name=w2_ms_name,
- requires_grad=False)
-
- def infer_quant_process_dense_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer process dense ffn weight"""
-
- ffn_concat = self.config.model.model_config.ffn_concat
- w1_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight"
- w1_ms_name = self.quant_convert_weight_name(w1_hf_name)
- w1_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w1_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w1_scale_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight_scale"
- w1_scale_ms_name = self.quant_convert_weight_name(w1_scale_hf_name)
- w1_scale_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w1_scale_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w2_hf_name = f"model.layers.{layer_id}.mlp.down_proj.weight"
- w2_ms_name = self.quant_convert_weight_name(w2_hf_name)
- w2_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w2_hf_name, src_hf_dir, hf_weight_map,
- split_axis=1)
- w2_scale_hf_name = f"model.layers.{layer_id}.mlp.down_proj.weight_scale"
- w2_scale_ms_name = self.quant_convert_weight_name(w2_scale_hf_name)
- # shape:[7168,1]
- w2_scale_ms_param, _ = self.get_safetensor_from_file(w2_scale_hf_name, src_hf_dir, hf_weight_map)
-
- w3_hf_name = f"model.layers.{layer_id}.mlp.up_proj.weight"
- w3_ms_name = self.quant_convert_weight_name(w3_hf_name)
- w3_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w3_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w3_scale_hf_name = f"model.layers.{layer_id}.mlp.up_proj.weight_scale"
- w3_scale_ms_name = self.quant_convert_weight_name(w3_scale_hf_name)
- w3_scale_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w3_scale_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w1_scale_ms_param = w1_scale_ms_param.squeeze(axis=-1)
- w2_scale_ms_param = w2_scale_ms_param.squeeze(axis=-1)
- w3_scale_ms_param = w3_scale_ms_param.squeeze(axis=-1)
-
- if ffn_concat:
- w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden._layer.weight"
- w_gate_hidden_np = np.concatenate([w1_ms_param, w3_ms_param], axis=0)
- w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).astype(dtype=ms.int8)
- self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name,
- requires_grad=False)
-
- w_scale_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden._layer.matmul.weight_scale"
- w_scale_gate_hidden_param = ms.from_numpy(
- np.concatenate([w1_scale_ms_param, w3_scale_ms_param], axis=0)).astype(dtype=ms.bfloat16)
- self.parameter_dict[w_scale_gate_hidden_name] = ms.Parameter(w_scale_gate_hidden_param,
- name=w_scale_gate_hidden_name,
- requires_grad=False)
-
- else:
- self.parameter_dict[w1_ms_name] = ms.Parameter(ms.from_numpy(w1_ms_param).astype(ms.int8),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_ms_name] = ms.Parameter(ms.from_numpy(w3_ms_param).astype(ms.int8),
- name=w3_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w1_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w1_scale_ms_param).astype(ms.bfloat16),
- name=w1_scale_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w3_scale_ms_param).astype(ms.bfloat16),
- name=w3_scale_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_ms_name] = ms.Parameter(ms.from_numpy(w2_ms_param).astype(ms.int8),
- name=w2_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_scale_ms_name] = ms.Parameter(
- ms.from_numpy(w2_scale_ms_param).astype(ms.bfloat16),
- name=w2_ms_name,
- requires_grad=False)
-
- def infer_convert_outer_weight(self, src_hf_dir, hf_weight_map):
- """convert weight not in model"""
- embed_tokens_hf_name = "model.embed_tokens.weight"
- embed_tokens_ms_name = self.quant_convert_weight_name(embed_tokens_hf_name)
- np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
- name=embed_tokens_ms_name,
- requires_grad=False)
-
- norm_hf_name = "model.norm.weight"
- norm_ms_name = self.quant_convert_weight_name(norm_hf_name)
- np_data, _ = self.get_safetensor_from_file(norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[norm_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
- name=norm_ms_name,
- requires_grad=False)
-
- lm_head_hf_name = "lm_head.weight"
- lm_head_ms_name = self.quant_convert_weight_name(lm_head_hf_name)
- if not self.config.parallel_config.vocab_emb_dp:
- np_data, _ = self.get_safetensor_from_file_split_tp_group(lm_head_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- else:
- np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
- name=lm_head_ms_name,
- requires_grad=False)
-
- def quant_special_attention_weight(self, layer_id, src_hf_dir, hf_weight_map, name, is_trans_rope_weigh=False,
- is_split_param=False):
- # q_a_proj->q2l_proj
- # kv_a_proj_with_mqa->kv2l
- # q_a_layernorm->lq_norm
- # o_proj->wo
-
- # input_scale, input_zp no split
- input_scale_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".input_scale"
- input_scale_ms_name = self.quant_convert_weight_name(input_scale_hf_name)
- input_scale_ms_param, _ = self.get_safetensor_from_file(input_scale_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[input_scale_ms_name] = ms.Parameter(
- ms.from_numpy(input_scale_ms_param).astype(ms.bfloat16),
- name=input_scale_ms_name, requires_grad=False)
-
- input_zp_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".input_offset"
- input_zp_ms_name = self.quant_convert_weight_name(input_zp_hf_name)
- input_zp_ms_param, _ = self.get_safetensor_from_file(input_zp_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[input_zp_ms_name] = ms.Parameter(ms.from_numpy(input_zp_ms_param).astype(ms.int8),
- name=input_zp_ms_name,
- requires_grad=False)
-
- if not is_trans_rope_weigh:
- quant_bias_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".quant_bias"
- quant_bias_ms_name = self.quant_convert_weight_name(quant_bias_hf_name)
- quant_bias_ms_param, _ = self.get_safetensor_from_file(quant_bias_hf_name, src_hf_dir, hf_weight_map)
- if name == "o_proj" and self.tp_rank_id != 0:
- quant_bias_ms_param.fill(0)
-
- dequant_scale_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".deq_scale"
- dequant_scale_ms_name = self.quant_convert_weight_name(dequant_scale_hf_name)
- dequant_scale_ms_param, _ = self.get_safetensor_from_file(dequant_scale_hf_name, src_hf_dir, hf_weight_map)
- else:
- kv_lora_rank = self.config.model.model_config.kv_lora_rank
- qk_rope_head_dim = self.config.model.model_config.qk_rope_head_dim
- qk_nope_head_dim = self.config.model.model_config.qk_nope_head_dim
-
- num_heads = self.config.model.model_config.num_heads
- rope_dim = qk_rope_head_dim + qk_nope_head_dim
- kv_head_dim = kv_lora_rank + qk_rope_head_dim
-
- quant_bias_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".quant_bias"
- quant_bias_ms_name = self.quant_convert_weight_name(quant_bias_hf_name)
- quant_bias_ms_param, _ = self.get_safetensor_from_file(quant_bias_hf_name, src_hf_dir, hf_weight_map)
-
- dequant_scale_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".deq_scale"
- dequant_scale_ms_name = self.quant_convert_weight_name(dequant_scale_hf_name)
- dequant_scale_ms_param, _ = self.get_safetensor_from_file(dequant_scale_hf_name, src_hf_dir, hf_weight_map)
-
- if name == "q_b_proj":
- quant_bias_ms_param = quant_bias_ms_param.reshape(num_heads, rope_dim, -1)
- quant_bias_ms_param = self.infer_trans_rope_weight(quant_bias_ms_param, qk_rope_head_dim)
- quant_bias_ms_param = quant_bias_ms_param.reshape(num_heads * rope_dim, -1).reshape(-1)
-
- dequant_scale_ms_param = dequant_scale_ms_param.reshape(num_heads, rope_dim, -1)
- dequant_scale_ms_param = self.infer_trans_rope_weight(dequant_scale_ms_param, qk_rope_head_dim)
- dequant_scale_ms_param = dequant_scale_ms_param.reshape(num_heads * rope_dim, -1).reshape(-1)
-
- elif name == "kv_a_proj_with_mqa":
- quant_bias_ms_param = quant_bias_ms_param.reshape(kv_head_dim, -1)
- quant_bias_ms_param = self.infer_trans_rope_weight(quant_bias_ms_param, qk_rope_head_dim).reshape(-1)
-
- dequant_scale_ms_param = dequant_scale_ms_param.reshape(kv_head_dim, -1)
- dequant_scale_ms_param = self.infer_trans_rope_weight(dequant_scale_ms_param, qk_rope_head_dim).reshape(
- -1)
-
- if is_split_param:
- quant_bias_ms_param = self.split_weight_by_rank(quant_bias_ms_param, split_axis=0)
- dequant_scale_ms_param = self.split_weight_by_rank(dequant_scale_ms_param, split_axis=0)
-
- self.parameter_dict[quant_bias_ms_name] = ms.Parameter(
- ms.from_numpy(quant_bias_ms_param).astype(ms.int32),
- name=quant_bias_ms_name, requires_grad=False)
- self.parameter_dict[dequant_scale_ms_name] = ms.Parameter(
- ms.from_numpy(dequant_scale_ms_param).astype(ms.float32),
- name=dequant_scale_ms_name,
- requires_grad=False)
-
- def infer_quant_bias_weight(self, src_hf_dir, layer_id, hf_weight_map):
- # quant_op.beta
- q2l_proj_bias_hf_name = f"model.layers.{layer_id}.input_layernorm.bias"
- q2l_proj_bias_ms_name = self.quant_convert_weight_name(q2l_proj_bias_hf_name)
- q2l_proj_bias_ms_param, _ = self.get_safetensor_from_file(q2l_proj_bias_hf_name, src_hf_dir, hf_weight_map)
-
- kv2l_bias_ms_name = f"model.layers.{layer_id}.attention.kv2l.quant_op.beta"
- kv2l_bias_ms_param = q2l_proj_bias_ms_param.copy()
-
- l2q_proj_bias_hf_name = f"model.layers.{layer_id}.self_attn.q_a_layernorm.bias"
- l2q_proj_bias_ms_name = self.quant_convert_weight_name(l2q_proj_bias_hf_name)
- l2q_proj_bias_ms_param, _ = self.get_safetensor_from_file(l2q_proj_bias_hf_name, src_hf_dir, hf_weight_map)
-
- self.parameter_dict[q2l_proj_bias_ms_name] = ms.Parameter(
- ms.from_numpy(q2l_proj_bias_ms_param).astype(ms.bfloat16),
- name=q2l_proj_bias_ms_name,
- requires_grad=False)
- self.parameter_dict[kv2l_bias_ms_name] = ms.Parameter(
- ms.from_numpy(kv2l_bias_ms_param).astype(ms.bfloat16),
- name=kv2l_bias_ms_name,
- requires_grad=False)
- self.parameter_dict[l2q_proj_bias_ms_name] = ms.Parameter(
- ms.from_numpy(l2q_proj_bias_ms_param).astype(ms.bfloat16),
- name=l2q_proj_bias_ms_name,
- requires_grad=False)
-
- def infer_quant_process_attention_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer quant process attention weight"""
- num_heads = self.config.model.model_config.num_heads
- qk_rope_head_dim = self.config.model.model_config.qk_rope_head_dim
- v_head_dim = self.config.model.model_config.v_head_dim
- qk_nope_head_dim = self.config.model.model_config.qk_nope_head_dim
-
- rope_dim = qk_rope_head_dim + qk_nope_head_dim
-
- # q_a_layernorm->lq_norm
- lq_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_a_layernorm.weight"
- lq_norm_ms_name = self.quant_convert_weight_name(lq_norm_hf_name)
- lq_norm_ms_param, _ = self.get_safetensor_from_file(lq_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[lq_norm_ms_name] = ms.Parameter(ms.from_numpy(lq_norm_ms_param).astype(ms.bfloat16),
- name=lq_norm_ms_name,
- requires_grad=False)
-
- # q_b_proj->l2q_proj
- l2q_proj_hf_name = f"model.layers.{layer_id}.self_attn.q_b_proj.weight"
- l2q_proj_ms_name = self.quant_convert_weight_name(l2q_proj_hf_name)
- l2q_proj_ms_param, _ = self.get_safetensor_from_file(l2q_proj_hf_name, src_hf_dir, hf_weight_map)
- l2q_proj_ms_param = l2q_proj_ms_param.reshape(num_heads, rope_dim, -1)
- l2q_proj_ms_param = self.infer_trans_rope_weight(l2q_proj_ms_param, qk_rope_head_dim)
- l2q_proj_ms_param = l2q_proj_ms_param.reshape(num_heads * rope_dim, -1)
- l2q_proj_ms_param = self.split_weight_by_rank(l2q_proj_ms_param, split_axis=0)
- self.parameter_dict[l2q_proj_ms_name] = ms.Parameter(
- ms.from_numpy(l2q_proj_ms_param).astype(ms.int8),
- name=l2q_proj_ms_name,
- requires_grad=False)
- self.quant_special_attention_weight(layer_id, src_hf_dir, hf_weight_map, "q_b_proj", is_trans_rope_weigh=True,
- is_split_param=True)
-
- # kv_a_layernorm->lkv_norm
- lkv_norm_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_layernorm.weight"
- lkv_norm_ms_name = self.quant_convert_weight_name(lkv_norm_hf_name)
- lkv_norm_ms_param, _ = self.get_safetensor_from_file(lkv_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[lkv_norm_ms_name] = ms.Parameter(
- ms.from_numpy(lkv_norm_ms_param).astype(ms.bfloat16),
- name=lkv_norm_ms_name,
- requires_grad=False)
-
- # kv_b_proj->lkv2kv
- lkv2kv_hf_name = f"model.layers.{layer_id}.self_attn.kv_b_proj.weight"
- lkv2kv_ms_name = self.quant_convert_weight_name(lkv2kv_hf_name)
- lkv2kv_ms_param, _ = self.get_safetensor_from_file(lkv2kv_hf_name, src_hf_dir, hf_weight_map)
- lkv2kv_head = qk_nope_head_dim + v_head_dim
- lkv2kv_ms_param = lkv2kv_ms_param.reshape(num_heads, lkv2kv_head, -1)
- value_k_nope, value_v = lkv2kv_ms_param[:, :qk_nope_head_dim, :], lkv2kv_ms_param[:, qk_nope_head_dim:, :]
-
- # value_k_nope
- value_k_nope = value_k_nope.reshape(-1, value_k_nope.shape[-1])
- value_k_nope = self.split_weight_by_rank(value_k_nope, split_axis=0)
- name_k_nope = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_k_nope.")
- self.parameter_dict[name_k_nope] = ms.Parameter(ms.from_numpy(value_k_nope).astype(ms.bfloat16),
- name=name_k_nope,
- requires_grad=False)
- # value_v
- value_v = value_v.reshape(-1, value_v.shape[-1])
- value_v = self.split_weight_by_rank(value_v, split_axis=0)
- name_v = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_v.")
- self.parameter_dict[name_v] = ms.Parameter(ms.from_numpy(value_v).astype(ms.bfloat16),
- name=name_v,
- requires_grad=False)
-
- # o_proj->wo
- wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
- wo_ms_name = self.quant_convert_weight_name(wo_hf_name)
- wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map)
- wo_ms_param = self.split_weight_by_rank(wo_ms_param, split_axis=1)
- self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.int8),
- name=wo_ms_name,
- requires_grad=False)
- self.quant_special_attention_weight(layer_id, src_hf_dir, hf_weight_map, "o_proj")
-
- def infer_quant_process_dense_qkv_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer_quant_process_dense_qkv_weight"""
- parameter_dict = {}
- kv_lora_rank = self.config.model.model_config.kv_lora_rank
- qk_rope_head_dim = self.config.model.model_config.qk_rope_head_dim
- kv_head_dim = kv_lora_rank + qk_rope_head_dim
-
- qkv_concat = self.config.model.model_config.qkv_concat
- # q2l
- q2l_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.weight"
- q2l_ms_name = self.quant_convert_weight_name(q2l_hf_name)
- q2l_ms_param, _ = self.get_safetensor_from_file(q2l_hf_name, src_hf_dir, hf_weight_map)
-
- q2l_input_scale_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.input_scale"
- q2l_input_scale_ms_name = self.quant_convert_weight_name(q2l_input_scale_hf_name)
- q2l_input_scale_ms_param, _ = self.get_safetensor_from_file(q2l_input_scale_hf_name, src_hf_dir,
- hf_weight_map)
-
- q2l_input_zp_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.input_offset"
- q2l_input_zp_ms_name = self.quant_convert_weight_name(q2l_input_zp_hf_name)
- q2l_input_zp_ms_param, _ = self.get_safetensor_from_file(q2l_input_zp_hf_name, src_hf_dir, hf_weight_map)
-
- q2l_quant_bias_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.quant_bias"
- q2l_quant_bias_ms_name = self.quant_convert_weight_name(q2l_quant_bias_hf_name)
- q2l_quant_bias_ms_param, _ = self.get_safetensor_from_file(q2l_quant_bias_hf_name, src_hf_dir,
- hf_weight_map)
-
- q2l_dequant_scale_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.deq_scale"
- q2l_dequant_scale_ms_name = self.quant_convert_weight_name(q2l_dequant_scale_hf_name)
- q2l_dequant_scale_ms_param, _ = self.get_safetensor_from_file(q2l_dequant_scale_hf_name, src_hf_dir,
- hf_weight_map)
- # kv2l
- kv2l_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_proj_with_mqa.weight"
- kv2l_ms_name = self.quant_convert_weight_name(kv2l_hf_name)
- kv2l_ms_param, _ = self.get_safetensor_from_file(kv2l_hf_name, src_hf_dir, hf_weight_map)
- kv2l_ms_param = kv2l_ms_param.reshape(kv_head_dim, -1)
- kv2l_ms_param = self.infer_trans_rope_weight(kv2l_ms_param, qk_rope_head_dim)
-
- kv2l_input_scale_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_proj_with_mqa.input_scale"
- kv2l_input_scale_ms_name = self.quant_convert_weight_name(kv2l_input_scale_hf_name)
- kv2l_input_scale_ms_param, _ = self.get_safetensor_from_file(kv2l_input_scale_hf_name, src_hf_dir,
- hf_weight_map)
-
- kv2l_input_zp_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_proj_with_mqa.input_offset"
- kv2l_input_zp_ms_name = self.quant_convert_weight_name(kv2l_input_zp_hf_name)
- kv2l_input_zp_ms_param, _ = self.get_safetensor_from_file(kv2l_input_zp_hf_name, src_hf_dir, hf_weight_map)
-
- kv2l_quant_bias_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_proj_with_mqa.quant_bias"
- kv2l_quant_bias_ms_name = self.quant_convert_weight_name(kv2l_quant_bias_hf_name)
- kv2l_quant_bias_ms_param, _ = self.get_safetensor_from_file(kv2l_quant_bias_hf_name, src_hf_dir,
- hf_weight_map)
- kv2l_quant_bias_ms_param = kv2l_quant_bias_ms_param.reshape(kv_head_dim, -1)
- kv2l_quant_bias_ms_param = self.infer_trans_rope_weight(kv2l_quant_bias_ms_param,
- qk_rope_head_dim).reshape(-1)
-
- kv2l_dequant_scale_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_proj_with_mqa.deq_scale"
- kv2l_dequant_scale_ms_name = self.quant_convert_weight_name(kv2l_dequant_scale_hf_name)
- kv2l_dequant_scale_ms_param, _ = self.get_safetensor_from_file(kv2l_dequant_scale_hf_name, src_hf_dir,
- hf_weight_map)
- kv2l_dequant_scale_ms_param = kv2l_dequant_scale_ms_param.reshape(kv_head_dim, -1)
- kv2l_dequant_scale_ms_param = self.infer_trans_rope_weight(kv2l_dequant_scale_ms_param,
- qk_rope_head_dim).reshape(-1)
-
- attn_rmsnorm_beta_hf_name = f"model.layers.{layer_id}.input_layernorm.bias"
- attn_rmsnorm_beta_ms_name = self.quant_convert_weight_name(attn_rmsnorm_beta_hf_name)
- attn_rmsnorm_beta_ms_param, _ = self.get_safetensor_from_file(attn_rmsnorm_beta_hf_name, src_hf_dir, hf_weight_map)
-
- if qkv_concat:
- qkv2l_weight_name = f"model.layers.{layer_id}.attention.qkv2l._layer.weight"
- qkv2l_bias_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.quant_bias"
- qkv2l_scale_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.dequant_scale"
- qkv2l_quant_zp_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_zp"
- qkv2l_quant_scale_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_scale"
- qkv2l_rmsnorm_beta_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.beta"
-
- qkv2l_weight = np.concatenate((q2l_ms_param, kv2l_ms_param), 0)
- parameter_dict[qkv2l_weight_name] = ms.Parameter(ms.Tensor(qkv2l_weight, ms.int8), name=qkv2l_weight_name, requires_grad=False)
- qkv2l_bias = np.concatenate((q2l_quant_bias_ms_param, kv2l_quant_bias_ms_param), 0)
- parameter_dict[qkv2l_bias_name] = ms.Parameter(ms.Tensor(qkv2l_bias, ms.int32), name=qkv2l_bias_name,requires_grad=False)
- qkv2l_scale = np.concatenate((q2l_dequant_scale_ms_param, kv2l_dequant_scale_ms_param), 0)
- parameter_dict[qkv2l_scale_name] = ms.Parameter(ms.Tensor(qkv2l_scale, ms.float32), name=qkv2l_scale_name, requires_grad=False)
- parameter_dict[qkv2l_quant_zp_name] = ms.Parameter(ms.Tensor(q2l_input_zp_ms_param, ms.int8),requires_grad=False)
- parameter_dict[qkv2l_quant_scale_name] = ms.Parameter(ms.Tensor(q2l_input_scale_ms_param, ms.bfloat16), requires_grad=False)
- parameter_dict[qkv2l_rmsnorm_beta_name] = ms.Parameter(ms.Tensor(attn_rmsnorm_beta_ms_param, ms.float32), requires_grad=False)
- else:
- parameter_dict[q2l_ms_name] = ms.Parameter(ms.Tensor(q2l_ms_param, ms.int8), name=q2l_ms_name,requires_grad=False)
- parameter_dict[kv2l_ms_name] = ms.Parameter(ms.Tensor(kv2l_ms_param, ms.int8),requires_grad=False)
- parameter_dict[q2l_quant_bias_ms_name] = ms.Parameter(ms.Tensor(q2l_quant_bias_ms_param, ms.int32),name=q2l_quant_bias_ms_name,requires_grad = False)
- parameter_dict[kv2l_quant_bias_ms_name] = ms.Parameter(ms.Tensor(kv2l_quant_bias_ms_param, ms.int32),name=kv2l_quant_bias_ms_name,requires_grad = False)
- parameter_dict[q2l_dequant_scale_ms_name] = ms.Parameter(ms.Tensor(q2l_dequant_scale_ms_param, ms.float32), name=q2l_dequant_scale_ms_name, requires_grad = False)
- parameter_dict[kv2l_dequant_scale_ms_name] = ms.Parameter(ms.Tensor(kv2l_dequant_scale_ms_param, ms.float32),name = kv2l_dequant_scale_ms_name, requires_grad = False)
- parameter_dict[q2l_input_zp_ms_name] = ms.Parameter(ms.Tensor(q2l_input_zp_ms_param, ms.int8),name=q2l_input_zp_ms_name, requires_grad = False)
- parameter_dict[kv2l_input_zp_ms_name] = ms.Parameter(ms.Tensor(kv2l_input_zp_ms_param, ms.int8), name=kv2l_input_zp_ms_name, requires_grad = False)
- parameter_dict[q2l_input_scale_ms_name] = ms.Parameter(ms.Tensor(q2l_input_scale_ms_param, ms.bfloat16), name = q2l_input_scale_ms_name, requires_grad = False)
- parameter_dict[kv2l_input_scale_ms_name] = ms.Parameter(ms.Tensor(kv2l_input_scale_ms_param, ms.bfloat16), name = kv2l_input_scale_ms_name, requires_grad = False)
- parameter_dict[attn_rmsnorm_beta_ms_name] = ms.Parameter(ms.Tensor(attn_rmsnorm_beta_ms_param, ms.float32), name=attn_rmsnorm_beta_ms_name, requires_grad=False)
- _, _ = ms.load_param_into_net(self.network, parameter_dict)
- del parameter_dict
- gc.collect()
-
- def infer_quant_net_convert_layer_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer quant net convert layer weight"""
-
- if layer_id >= 3:
- self.infer_quant_process_moe_routed_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
- self.infer_quant_process_moe_shared_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
- else:
- self.infer_quant_process_dense_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
-
- self.infer_quant_process_dense_qkv_weight(src_hf_dir, layer_id, hf_weight_map)
- self.infer_quant_process_attention_weight(src_hf_dir, layer_id, hf_weight_map)
- self.infer_quant_bias_weight(src_hf_dir, layer_id, hf_weight_map)
- self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map)
-
- def convert_weight_name(self, weight_name: str):
- """replace weight name"""
- weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight')
- weight_name = weight_name.replace('.self_attn.q_a_proj.', '.attention.q2l_proj.')
- weight_name = weight_name.replace('.self_attn.q_a_layernorm.', '.attention.lq_norm.')
- weight_name = weight_name.replace('.self_attn.q_b_proj.', '.attention.l2q_proj.')
- weight_name = weight_name.replace('.self_attn.kv_a_proj_with_mqa.', '.attention.kv2l.')
- weight_name = weight_name.replace('.self_attn.kv_a_layernorm.', '.attention.lkv_norm.')
- weight_name = weight_name.replace('.self_attn.kv_b_proj.', '.attention.lkv2kv.')
- weight_name = weight_name.replace('.self_attn.o_proj.', '.attention.wo.')
- weight_name = weight_name.replace('mlp.gate_proj.', 'feed_forward.w1.')
- weight_name = weight_name.replace('mlp.down_proj.', 'feed_forward.w2.')
- weight_name = weight_name.replace('mlp.up_proj.', 'feed_forward.w3.')
- weight_name = weight_name.replace('mlp.experts.', 'feed_forward.routed_experts.ffn.')
- weight_name = weight_name.replace('mlp.shared_experts.gate_proj.', 'feed_forward.shared_experts.w1.')
- weight_name = weight_name.replace('mlp.shared_experts.down_proj.', 'feed_forward.shared_experts.w2.')
- weight_name = weight_name.replace('mlp.shared_experts.up_proj.', 'feed_forward.shared_experts.w3.')
- weight_name = weight_name.replace('mlp.gate.weight', 'feed_forward.routed_experts.router.dense.weight')
- weight_name = weight_name.replace('mlp.gate.e_score_correction_bias',
- 'feed_forward.routed_experts.router.e_score_correction_bias')
- weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
- weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
- weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight')
-
- weight_name = self.convert_mtp_weight_name(weight_name)
- return weight_name
-
- def convert_mtp_weight_name(self, weight_name: str):
- layer = 0 if 'layers.' not in weight_name else int(weight_name[weight_name.find('layers.'):].split('.')[1])
- if layer < self.num_layers:
- return weight_name
- mtp_prefix = f'mtp_model'
- is_mtp_layer = 'tok_embeddings' not in weight_name and 'shared_head.' not in weight_name
- mtp_prefix = mtp_prefix if not is_mtp_layer else f'{mtp_prefix}.layer'
- is_decode_layer = "ffn" in weight_name or "attention" in weight_name or "feed_forward" in weight_name
- mtp_prefix = mtp_prefix if not is_decode_layer else f'{mtp_prefix}.decode_layer'
-
- weight_name = weight_name.replace(f'model.layers.{layer}', mtp_prefix)
- if "tok_embeddings" in weight_name:
- weight_name = weight_name.replace(f'.weight', f'.embedding_weight')
- if "shared_head." in weight_name:
- weight_name = weight_name.replace(f'shared_head.', f'')
- return weight_name
-
- def infer_process_moe_routed_expert_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """process moe router expert weight"""
- ffn_concat = self.config.model.model_config.ffn_concat
-
- # router expert dense
- router_dense_hf_name = f"model.layers.{layer_id}.mlp.gate.weight"
- router_dense_ms_name = self.convert_weight_name(router_dense_hf_name)
- router_dense_ms_param, _ = self.get_safetensor_from_file(router_dense_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[router_dense_ms_name] = ms.Parameter(
- ms.from_numpy(router_dense_ms_param).astype(ms.bfloat16),
- name=router_dense_ms_name, requires_grad=False)
-
- # e_score_correction_bias
- e_score_correction_bias_hf_name = f"model.layers.{layer_id}.mlp.gate.e_score_correction_bias"
- e_score_correction_bias_ms_name = self.convert_weight_name(e_score_correction_bias_hf_name)
- e_score_correction_bias_ms_param, _ = self.get_safetensor_from_file(e_score_correction_bias_hf_name, src_hf_dir,
- hf_weight_map)
- self.parameter_dict[e_score_correction_bias_ms_name] = ms.Parameter(
- ms.from_numpy(e_score_correction_bias_ms_param).astype(ms.float32),
- name=e_score_correction_bias_ms_name, requires_grad=False)
-
- w1_list = []
- w2_list = []
- w3_list = []
-
- w1_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1.weight"
- w1_ms_name = w1_ms_name if layer_id < self.num_layers else self.convert_mtp_weight_name(w1_ms_name)
- w2_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2.weight"
- w2_ms_name = w2_ms_name if layer_id < self.num_layers else self.convert_mtp_weight_name(w2_ms_name)
- w3_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w3.weight"
- w3_ms_name = w3_ms_name if layer_id < self.num_layers else self.convert_mtp_weight_name(w3_ms_name)
-
- for index in range(0, self.num_router_experts):
- w1_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight"
- w1_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w1_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w2_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight"
- w2_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w2_hf_name, src_hf_dir, hf_weight_map,
- split_axis=1)
-
- w3_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight"
- w3_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w3_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w1_list.append(w1_ms_param)
- w2_list.append(w2_ms_param)
- w3_list.append(w3_ms_param)
-
- w1_ms_stack_param = np.stack(w1_list, axis=0)
- w2_ms_stack_param = np.stack(w2_list, axis=0)
- w3_ms_stack_param = np.stack(w3_list, axis=0)
-
- if ffn_concat:
- w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden.weight"
- w_gate_hidden_name = w_gate_hidden_name if layer_id < self.num_layers else \
- self.convert_mtp_weight_name(w_gate_hidden_name)
- w_gate_hidden_np = np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=1)
- w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).permute(0, 2, 1).astype(dtype=ms.bfloat16)
- self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param,
- name=w_gate_hidden_name,
- requires_grad=False)
- else:
- w1_ms_stack_param = ms.from_numpy(w1_ms_stack_param).permute(0, 2, 1).astype(ms.bfloat16)
- self.parameter_dict[w1_ms_name] = ms.Parameter(w1_ms_stack_param,
- name=w1_ms_name,
- requires_grad=False)
-
- w3_ms_stack_param = ms.from_numpy(w3_ms_stack_param).permute(0, 2, 1).astype(ms.bfloat16)
- self.parameter_dict[w3_ms_name] = ms.Parameter(w3_ms_stack_param,
- name=w3_ms_name,
- requires_grad=False)
-
- w2_ms_stack_param = ms.from_numpy(w2_ms_stack_param).permute(0, 2, 1).astype(ms.bfloat16)
- self.parameter_dict[w2_ms_name] = ms.Parameter(w2_ms_stack_param,
- name=w2_ms_name,
- requires_grad=False)
-
- def get_moe_shared_expert_weight(self, w1_hf_name, w2_hf_name, w3_hf_name, src_hf_dir, hf_weight_map):
- if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]:
- w1_ms_param, _ = self.get_safetensor_from_file_split_global_group(w1_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w2_ms_param, _ = self.get_safetensor_from_file_split_global_group(w2_hf_name, src_hf_dir, hf_weight_map,
- split_axis=1)
- w3_ms_param, _ = self.get_safetensor_from_file_split_global_group(w3_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- elif self.ep_method == EPMethod.ALLTOALL:
- w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map)
- w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map)
- w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map)
-
- else:
- raise ValueError("Unsupported ep_method:{}".format(self.ep_method))
-
- return w1_ms_param, w2_ms_param, w3_ms_param
-
- def infer_process_moe_shared_expert_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer process moe shared expert ffn weight"""
- ffn_concat = self.config.model.model_config.ffn_concat
- w1_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.gate_proj.weight"
- w2_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.down_proj.weight"
- w3_hf_name = f"model.layers.{layer_id}.mlp.shared_experts.up_proj.weight"
-
- w1_ms_name = self.convert_weight_name(w1_hf_name)
- w2_ms_name = self.convert_weight_name(w2_hf_name)
- w3_ms_name = self.convert_weight_name(w3_hf_name)
-
- w1_ms_param, w2_ms_param, w3_ms_param = self.get_moe_shared_expert_weight(w1_hf_name, w2_hf_name, w3_hf_name,
- src_hf_dir, hf_weight_map)
-
- if ffn_concat:
- w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden.weight"
- w_gate_hidden_name = w_gate_hidden_name if layer_id < self.num_layers else \
- self.convert_mtp_weight_name(w_gate_hidden_name)
- w_gate_hidden_np = np.concatenate([w1_ms_param, w3_ms_param], axis=0)
- w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).astype(ms.bfloat16)
- self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param,
- name=w_gate_hidden_name,
- requires_grad=False)
- else:
- self.parameter_dict[w1_ms_name] = ms.Parameter(ms.from_numpy(w1_ms_param).astype(ms.bfloat16),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_ms_name] = ms.Parameter(ms.from_numpy(w3_ms_param).astype(ms.bfloat16),
- name=w3_ms_name,
- requires_grad=False)
- self.parameter_dict[w2_ms_name] = ms.Parameter(ms.from_numpy(w2_ms_param).astype(ms.bfloat16),
- name=w2_ms_name,
- requires_grad=False)
-
- def infer_process_dense_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer process dense ffn weight"""
-
- ffn_concat = self.config.model.model_config.ffn_concat
-
- w1_hf_name = f"model.layers.{layer_id}.mlp.gate_proj.weight"
- w1_ms_name = self.convert_weight_name(w1_hf_name)
- w1_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w1_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w2_hf_name = f"model.layers.{layer_id}.mlp.down_proj.weight"
- w2_ms_name = self.convert_weight_name(w2_hf_name)
- w2_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w2_hf_name, src_hf_dir, hf_weight_map,
- split_axis=1)
-
- w3_hf_name = f"model.layers.{layer_id}.mlp.up_proj.weight"
- w3_ms_name = self.convert_weight_name(w3_hf_name)
- w3_ms_param, _ = self.get_safetensor_from_file_split_tp_group(w3_hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- if ffn_concat:
- w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden.weight"
- w_gate_hidden_np = np.concatenate([w1_ms_param, w3_ms_param], axis=0)
- w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).astype(ms.bfloat16)
- self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param,
- name=w_gate_hidden_name,
- requires_grad=False)
- else:
- self.parameter_dict[w1_ms_name] = ms.Parameter(ms.from_numpy(w1_ms_param).astype(ms.bfloat16),
- name=w1_ms_name,
- requires_grad=False)
- self.parameter_dict[w3_ms_name] = ms.Parameter(ms.from_numpy(w3_ms_param).astype(ms.bfloat16),
- name=w3_ms_name,
- requires_grad=False)
-
- self.parameter_dict[w2_ms_name] = ms.Parameter(ms.from_numpy(w2_ms_param).astype(ms.bfloat16),
- name=w2_ms_name,
- requires_grad=False)
-
- def infer_process_attention_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer process attention weight"""
- num_heads = self.config.model.model_config.num_heads
- kv_lora_rank = self.config.model.model_config.kv_lora_rank
- qk_rope_head_dim = self.config.model.model_config.qk_rope_head_dim
- v_head_dim = self.config.model.model_config.v_head_dim
- qk_nope_head_dim = self.config.model.model_config.qk_nope_head_dim
-
- rope_dim = qk_rope_head_dim + qk_nope_head_dim
- kv_head_dim = kv_lora_rank + qk_rope_head_dim
-
- qkv_concat = self.config.model.model_config.qkv_concat
- # q2l_proj
- q2l_proj_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.weight"
- q2l_proj_ms_name = self.convert_weight_name(q2l_proj_hf_name)
- q_a_proj_ms_param, _ = self.get_safetensor_from_file(q2l_proj_hf_name, src_hf_dir, hf_weight_map)
-
- # kv2l
- kv2l_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_proj_with_mqa.weight"
- kv2l_ms_name = self.convert_weight_name(kv2l_hf_name)
- kv2l_ms_param, _ = self.get_safetensor_from_file(kv2l_hf_name, src_hf_dir, hf_weight_map)
- kv2l_ms_param = kv2l_ms_param.reshape(kv_head_dim, -1)
- kv2l_ms_param = self.infer_trans_rope_weight(kv2l_ms_param, qk_rope_head_dim)
- if qkv_concat:
- wqkv2l_weight = np.concatenate((q_a_proj_ms_param, kv2l_ms_param), 0)
- wqkv2l_weight_name = f"model.layers.{layer_id}.attention.qkv2l.weight"
- self.parameter_dict[wqkv2l_weight_name] = ms.Parameter(ms.from_numpy(wqkv2l_weight).astype(ms.bfloat16),
- name=wqkv2l_weight_name,
- requires_grad=False)
- else:
- self.parameter_dict[q2l_proj_ms_name] = ms.Parameter(ms.from_numpy(q_a_proj_ms_param).astype(ms.bfloat16),
- name=q2l_proj_ms_name,
- requires_grad=False)
- self.parameter_dict[kv2l_ms_name] = ms.Parameter(ms.from_numpy(kv2l_ms_param).astype(ms.bfloat16),
- name=kv2l_ms_name,
- requires_grad=False)
- # lq_norm
- lq_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_a_layernorm.weight"
- lq_norm_ms_name = self.convert_weight_name(lq_norm_hf_name)
- lq_norm_ms_param, _ = self.get_safetensor_from_file(lq_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[lq_norm_ms_name] = ms.Parameter(ms.from_numpy(lq_norm_ms_param).astype(ms.bfloat16),
- name=lq_norm_ms_name,
- requires_grad=False)
-
- # l2q_proj
- l2q_proj_hf_name = f"model.layers.{layer_id}.self_attn.q_b_proj.weight"
- l2q_proj_ms_name = self.convert_weight_name(l2q_proj_hf_name)
- l2q_proj_ms_param, _ = self.get_safetensor_from_file(l2q_proj_hf_name, src_hf_dir, hf_weight_map)
- l2q_proj_ms_param = l2q_proj_ms_param.reshape(num_heads, rope_dim, -1)
- l2q_proj_ms_param = self.infer_trans_rope_weight(l2q_proj_ms_param, qk_rope_head_dim)
- l2q_proj_ms_param = l2q_proj_ms_param.reshape(num_heads * rope_dim, -1)
- l2q_proj_ms_param = self.split_weight_by_rank(l2q_proj_ms_param, split_axis=0)
- self.parameter_dict[l2q_proj_ms_name] = ms.Parameter(
- ms.from_numpy(l2q_proj_ms_param).astype(ms.bfloat16),
- name=l2q_proj_ms_name,
- requires_grad=False)
-
- # lkv_norm
- lkv_norm_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_layernorm.weight"
- lkv_norm_ms_name = self.convert_weight_name(lkv_norm_hf_name)
- lkv_norm_ms_param, _ = self.get_safetensor_from_file(lkv_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[lkv_norm_ms_name] = ms.Parameter(
- ms.from_numpy(lkv_norm_ms_param).astype(ms.bfloat16),
- name=lkv_norm_ms_name,
- requires_grad=False)
-
- # lkv2kv
- lkv2kv_hf_name = f"model.layers.{layer_id}.self_attn.kv_b_proj.weight"
- lkv2kv_ms_name = self.convert_weight_name(lkv2kv_hf_name)
- lkv2kv_ms_param, _ = self.get_safetensor_from_file(lkv2kv_hf_name, src_hf_dir, hf_weight_map)
- lkv2kv_head = qk_nope_head_dim + v_head_dim
- lkv2kv_ms_param = lkv2kv_ms_param.reshape(num_heads, lkv2kv_head, -1)
- value_k_nope, value_v = lkv2kv_ms_param[:, :qk_nope_head_dim, :], lkv2kv_ms_param[:, qk_nope_head_dim:, :]
-
- # value_k_nope
- value_k_nope = value_k_nope.reshape(-1, value_k_nope.shape[-1])
- value_k_nope = self.split_weight_by_rank(value_k_nope, split_axis=0)
- name_k_nope = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_k_nope.")
- self.parameter_dict[name_k_nope] = ms.Parameter(ms.from_numpy(value_k_nope).astype(ms.bfloat16),
- name=name_k_nope,
- requires_grad=False)
- # value_v
- value_v = value_v.reshape(-1, value_v.shape[-1])
- value_v = self.split_weight_by_rank(value_v, split_axis=0)
- name_v = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_v.")
- self.parameter_dict[name_v] = ms.Parameter(ms.from_numpy(value_v).astype(ms.bfloat16),
- name=name_v,
- requires_grad=False)
-
- # wo
- wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
- wo_ms_name = self.convert_weight_name(wo_hf_name)
- wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map)
- wo_ms_param = self.split_weight_by_rank(wo_ms_param, split_axis=1)
- self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
- name=wo_ms_name,
- requires_grad=False)
-
- def infer_process_norm_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer process attention weight"""
- # attention_norm
- attention_norm_hf_name = f"model.layers.{layer_id}.input_layernorm.weight"
- attention_norm_ms_name = self.convert_weight_name(attention_norm_hf_name)
- attention_norm_ms_param, _ = self.get_safetensor_from_file(attention_norm_hf_name,
- src_hf_dir,
- hf_weight_map)
- self.parameter_dict[attention_norm_ms_name] = ms.Parameter(
- ms.from_numpy(attention_norm_ms_param).astype(ms.bfloat16),
- name=attention_norm_ms_name,
- requires_grad=False)
-
- # ffn_norm
- ffn_norm_hf_name = f"model.layers.{layer_id}.post_attention_layernorm.weight"
- ffn_norm_ms_name = self.convert_weight_name(ffn_norm_hf_name)
- ffn_norm_ms_param, _ = self.get_safetensor_from_file(ffn_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[ffn_norm_ms_name] = ms.Parameter(
- ms.from_numpy(ffn_norm_ms_param).astype(ms.bfloat16),
- name=ffn_norm_ms_name,
- requires_grad=False)
-
- def infer_process_mtp_layer_weight(self, src_hf_dir, layer_id, hf_weight_map):
- parameter_dict = {}
- mtp_layer_names = ["embed_tokens.weight", "enorm.weight", "hnorm.weight", "eh_proj.weight",
- "shared_head.norm.weight", "shared_head.head.weight"]
- head_names = ["eh_proj.weight", "shared_head.head.weight"]
- for prefix_name in mtp_layer_names:
- hf_name = f"model.layers.{layer_id}.{prefix_name}"
- ms_name = self.convert_weight_name(hf_name)
- if prefix_name in head_names and not self.config.parallel_config.vocab_emb_dp:
- ms_param, _ = self.get_safetensor_from_file_split_tp_group(hf_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- else:
- ms_param, _ = self.get_safetensor_from_file(hf_name, src_hf_dir, hf_weight_map)
- parameter_dict[ms_name] = ms.Parameter(ms.Tensor(ms_param, ms.bfloat16),
- name=ms_name,
- requires_grad=False)
-
- _, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict)
-
- def infer_convert_layer_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer convert layer weight"""
- if layer_id >= 3:
- self.infer_process_moe_routed_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
- self.infer_process_moe_shared_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
- else:
- self.infer_process_dense_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
-
- self.infer_process_attention_weight(src_hf_dir, layer_id, hf_weight_map)
- self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map)
-
- # convert mtp shared weights.
- if layer_id >= self.num_layers:
- self.infer_process_mtp_layer_weight(src_hf_dir, layer_id, hf_weight_map)
-
- def smooth_quant_process_route_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type):
- """smooth_quant_process_route_ffn_weight"""
-
- ffn_concat = self.config.model.model_config.ffn_concat
- w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight"
- w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale"
- w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight"
- w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale"
- w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight"
- w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale"
- w1_weight_param, _ = self.get_routed_safetensor_3_dim(w1_weight_name, src_hf_dir, hf_weight_map, tp_axis=2,
- split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
-
- w1_scale_param, _ = self.get_routed_safetensor_2_dim(w1_scale_name, src_hf_dir, hf_weight_map, tp_axis=1,
- split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
-
- w3_weight_param, _ = self.get_routed_safetensor_3_dim(w3_weight_name, src_hf_dir, hf_weight_map, tp_axis=2,
- split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
-
- w3_scale_param, _ = self.get_routed_safetensor_2_dim(w3_scale_name, src_hf_dir, hf_weight_map, tp_axis=1,
- split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
-
- w2_weight_param, _ = self.get_routed_safetensor_3_dim(w2_weight_name, src_hf_dir, hf_weight_map, tp_axis=1,
- split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
- w2_scale_param, _ = self.get_routed_safetensor_2_dim(w2_scale_name, src_hf_dir, hf_weight_map,
- split_ep=self.moe_split_ep, split_tp=False)
- if ffn_concat:
- concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight"
- concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=2), dtype=ms.int8)
- parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name,
- requires_grad=False)
-
- concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale"
- concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=1), dtype=ms.bfloat16)
- parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name,
- requires_grad=False)
- else:
- # w1 w3
- parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name,
- requires_grad=False)
- parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name,
- requires_grad=False)
-
- parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, ms.bfloat16),
- name=w1_scale_name, requires_grad=False)
- parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, ms.bfloat16),
- name=w3_scale_name, requires_grad=False)
-
- parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name,
- requires_grad=False)
- parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, ms.bfloat16),
- name=w2_scale_name, requires_grad=False)
-
- def get_smooth_quant_moe_shared_expert_weight(self, w1_weight_name, w1_scale_name, w3_weight_name,w3_scale_name,
- w2_weight_name, src_hf_dir, hf_weight_map):
- '''get_smooth_quant_moe_shared_expert_weight'''
-
- if self.ep_method in [EPMethod.DEFAULT, EPMethod.ALLGATHER]:
- w1_weight_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w1_weight_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
-
- w1_scale_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w1_scale_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
-
- w3_weight_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w3_weight_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
- w3_scale_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w3_scale_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
-
- w2_weight_param, _ = self.get_safetensor_from_file_split_moe_tp_group(w2_weight_name, src_hf_dir,
- hf_weight_map,
- split_axis=1)
- elif self.ep_method == EPMethod.ALLTOALL:
- w1_weight_param, _ = self.get_safetensor_from_file(w1_weight_name, src_hf_dir, hf_weight_map)
- w1_scale_param, _ = self.get_safetensor_from_file(w1_scale_name, src_hf_dir, hf_weight_map)
-
- w3_weight_param, _ = self.get_safetensor_from_file(w3_weight_name, src_hf_dir, hf_weight_map)
- w3_scale_param, _ = self.get_safetensor_from_file(w3_scale_name, src_hf_dir, hf_weight_map)
-
- w2_weight_param, _ = self.get_safetensor_from_file(w2_weight_name, src_hf_dir, hf_weight_map)
- else:
- raise ValueError("Unsupported ep_method:{}".format(self.ep_method))
-
- return w1_weight_param, w1_scale_param, w3_weight_param, w3_scale_param, w2_weight_param
-
- def smooth_quant_process_shared_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type):
- """smooth_quant_process_shared_ffn_weight"""
-
- ffn_concat = self.config.model.model_config.ffn_concat
- w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight"
- w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight"
- w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight"
-
- w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale"
- w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale"
- w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale"
-
- w1_weight_param, w1_scale_param, w3_weight_param, w3_scale_param, w2_weight_param = \
- self.get_smooth_quant_moe_shared_expert_weight(w1_weight_name, w1_scale_name, w3_weight_name, w3_scale_name,
- w2_weight_name, src_hf_dir, hf_weight_map)
- w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map)
-
- if ffn_concat:
- concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight"
- concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8)
- parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name,
- requires_grad=False)
-
- concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale"
- concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), dtype=ms.bfloat16)
- parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name,
- requires_grad=False)
-
- else:
- # w1 w3
- parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name,
- requires_grad=False)
- parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name,
- requires_grad=False)
-
- parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, ms.bfloat16),
- name=w1_scale_name, requires_grad=False)
- parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, ms.bfloat16),
- name=w3_scale_name, requires_grad=False)
-
- parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name,
- requires_grad=False)
- parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, ms.bfloat16),
- name=w2_scale_name, requires_grad=False)
-
- def smooth_quant_process_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type):
- """smooth_quant_process_ffn_weight"""
-
- ffn_concat = self.config.model.model_config.ffn_concat
- w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight"
- w1_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w1_weight_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale"
- w1_scale_param, _ = self.get_safetensor_from_file_split_tp_group(w1_scale_name, src_hf_dir, hf_weight_map,
- split_axis=0)
-
- w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight"
-
- w3_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w3_weight_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale"
- w3_scale_param, _ = self.get_safetensor_from_file_split_tp_group(w3_scale_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight"
- w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale"
- w2_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w2_weight_name, src_hf_dir, hf_weight_map,
- split_axis=1)
- w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map)
-
- if ffn_concat:
- concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight"
- concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8)
- parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name,
- requires_grad=False)
-
- concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale"
- concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), dtype=ms.bfloat16)
- parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name,
- requires_grad=False)
- else:
- # w1 w3
- parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name,
- requires_grad=False)
- parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name,
- requires_grad=False)
-
- parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, ms.bfloat16),
- name=w1_scale_name, requires_grad=False)
- parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, ms.bfloat16),
- name=w3_scale_name, requires_grad=False)
-
- parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name,
- requires_grad=False)
- parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, ms.bfloat16),
- name=w2_scale_name, requires_grad=False)
-
- def smooth_quant_process_qkv_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict):
- '''smooth_quant_process_qkv_weight'''
- qkv_concat = self.config.model.model_config.qkv_concat
- # q2l_proj
- q2l_weight_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.weight"
- q2l_weight_param, _ = self.get_safetensor_from_file(q2l_weight_name, src_hf_dir, hf_weight_map)
- q2l_bias_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.quant_bias"
- q2l_bias_param, _ = self.get_safetensor_from_file(q2l_bias_name, src_hf_dir, hf_weight_map)
- q2l_scale_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.dequant_scale"
- q2l_scale_param, _ = self.get_safetensor_from_file(q2l_scale_name, src_hf_dir, hf_weight_map)
-
- q2l_quant_zp = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_zp"
- q2l_quant_scale = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_scale"
- q2l_quant_zp_param, _ = self.get_safetensor_from_file(q2l_quant_zp, src_hf_dir, hf_weight_map)
- q2l_quant_scale_param, _ = self.get_safetensor_from_file(q2l_quant_scale, src_hf_dir, hf_weight_map)
-
- kv2l_weight_name = f"model.layers.{layer_id}.attention.kv2l._layer.weight"
- kv2l_weight_param, _ = self.get_safetensor_from_file(kv2l_weight_name, src_hf_dir, hf_weight_map)
- kv2l_bias_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.quant_bias"
- kv2l_bias_param, _ = self.get_safetensor_from_file(kv2l_bias_name, src_hf_dir, hf_weight_map)
- kv2l_scale_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.dequant_scale"
- kv2l_scale_param, _ = self.get_safetensor_from_file(kv2l_scale_name, src_hf_dir, hf_weight_map)
-
- kv2l_quant_zp = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_zp"
- kv2l_quant_scale = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_scale"
- kv2l_quant_zp_param, _ = self.get_safetensor_from_file(kv2l_quant_zp, src_hf_dir, hf_weight_map)
- kv2l_quant_scale_param, _ = self.get_safetensor_from_file(kv2l_quant_scale, src_hf_dir, hf_weight_map)
-
- if qkv_concat:
- qkv2l_weight_name = f"model.layers.{layer_id}.attention.qkv2l._layer.weight"
- qkv2l_bias_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.quant_bias"
- qkv2l_scale_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.dequant_scale"
- qkv2l_quant_zp_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_zp"
- qkv2l_quant_scale_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_scale"
-
- qkv2l_weight = np.concatenate((q2l_weight_param, kv2l_weight_param), 0)
- parameter_dict[qkv2l_weight_name] = ms.Parameter(ms.Tensor(qkv2l_weight, ms.int8), name=qkv2l_weight_name,
- requires_grad=False)
- qkv2l_bias = np.concatenate((q2l_bias_param, kv2l_bias_param), 0)
- parameter_dict[qkv2l_bias_name] = ms.Parameter(ms.Tensor(qkv2l_bias, ms.int32), name=qkv2l_bias_name,
- requires_grad=False)
- qkv2l_scale = np.concatenate((q2l_scale_param, kv2l_scale_param), 0)
- parameter_dict[qkv2l_scale_name] = ms.Parameter(ms.Tensor(qkv2l_scale, ms.float32), name=qkv2l_scale_name,
- requires_grad=False)
- parameter_dict[qkv2l_quant_zp_name] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8),
- name=qkv2l_quant_zp_name, requires_grad=False)
- parameter_dict[qkv2l_quant_scale_name] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, ms.bfloat16),
- name=qkv2l_quant_scale_name, requires_grad=False)
- else:
- parameter_dict[q2l_weight_name] = ms.Parameter(ms.Tensor(q2l_weight_param, ms.int8), name=q2l_weight_name,
- requires_grad=False)
- parameter_dict[kv2l_weight_name] = ms.Parameter(ms.Tensor(kv2l_weight_param, ms.int8),
- name=kv2l_weight_name, requires_grad=False)
- parameter_dict[q2l_bias_name] = ms.Parameter(ms.Tensor(q2l_bias_param, ms.int32), name=q2l_bias_name,
- requires_grad=False)
- parameter_dict[kv2l_bias_name] = ms.Parameter(ms.Tensor(kv2l_bias_param, ms.int32), name=kv2l_bias_name,
- requires_grad=False)
- parameter_dict[q2l_scale_name] = ms.Parameter(ms.Tensor(q2l_scale_param, ms.float32), name=q2l_scale_name,
- requires_grad=False)
- parameter_dict[kv2l_scale_name] = ms.Parameter(ms.Tensor(kv2l_scale_param, ms.float32),
- name=kv2l_scale_name, requires_grad=False)
- parameter_dict[q2l_quant_zp] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8), name=q2l_quant_zp,
- requires_grad=False)
- parameter_dict[kv2l_quant_zp] = ms.Parameter(ms.Tensor(kv2l_quant_zp_param, ms.int8), name=kv2l_quant_zp,
- requires_grad=False)
- parameter_dict[q2l_quant_scale] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, ms.bfloat16),
- name=q2l_quant_scale, requires_grad=False)
- parameter_dict[kv2l_quant_scale] = ms.Parameter(ms.Tensor(kv2l_quant_scale_param, ms.bfloat16),
- name=kv2l_quant_scale, requires_grad=False)
-
- def infer_smooth_quant_row_linear_split(self, param_name, src_hf_dir, hf_weight_map):
- '''infer_smooth_quant_row_linear_split'''
- if param_name.endswith(".weight"):
- value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir,
- hf_weight_map,
- split_axis=1)
- elif "quant_op" in param_name:
- value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
- else:
- value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
- hf_weight_map)
- quant_bias_set_zero = ["wo._layer.matmul.quant_bias", "w2._layer.matmul.quant_bias"]
- if any([name in param_name for name in quant_bias_set_zero]) and \
- get_tensor_model_parallel_rank() != 0:
- value.fill(0)
-
- return value
-
- def infer_smooth_quant_get_value(self, param_name, src_hf_dir, hf_weight_map, no_need_split_layer):
- '''infer_smooth_quant_get_value'''
-
- if any([name in param_name for name in no_need_split_layer]):
- value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
- hf_weight_map)
- elif any([name in param_name for name in [".l2q_proj."]]):
- if param_name.endswith(".weight") or "matmul" in param_name:
- value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir,
- hf_weight_map,
- split_axis=0)
- else:
- value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
- hf_weight_map)
- elif any([name in param_name for name in [".wo."]]):
- value = self.infer_smooth_quant_row_linear_split(param_name, src_hf_dir, hf_weight_map)
- elif any([name in param_name for name in ["lkv2kv_k_nope", "lkv2kv_v"]]):
- value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- elif "lm_head" in param_name:
- if not self.config.parallel_config.vocab_emb_dp:
- value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- else:
- value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
- else:
- raise ValueError(f"not found layer {param_name}, please check safetensors file.")
- return value
-
- def infer_smooth_quant_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map):
- '''infer_smooth_quant_net_ms_convert_layer_weight'''
- parameter_dict = {}
-
- no_need_split_layer = ["tok_embeddings", "norm", "routed_experts.router.dense",
- "routed_experts.router.e_score_correction_bias",
- "topk_bias"]
- for layer_id in tqdm(range(num_layers), desc="qkv/ffn params load"):
- if layer_id >= 3:
- self.smooth_quant_process_route_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict,
- "feed_forward.routed_experts.ffn")
- self.smooth_quant_process_shared_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict,
- "feed_forward.shared_experts")
-
- else:
- self.smooth_quant_process_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict,
- "feed_forward")
- self.smooth_quant_process_qkv_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict)
-
- skip_layer = ["feed_forward.routed_experts.ffn", "feed_forward.shared_experts", "feed_forward.w",
- "attention.kv2l", "attention.q"]
-
- for param_name, _ in tqdm(hf_weight_map.items(), desc="remaining params load"):
- if "model.layers" in param_name and int(param_name.split('.')[2]) >= num_layers:
- continue
-
- if any([name in param_name for name in skip_layer]):
- continue
-
- value = self.infer_smooth_quant_get_value(param_name, src_hf_dir, hf_weight_map, no_need_split_layer)
- dst_dtype = convert_np_to_ms_dtype(value)
-
- parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dst_dtype),
- name=param_name, requires_grad=False)
-
- param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict)
- logger.info(f"smoothquant param_not_load:{param_not_load}")
- logger.info(f"smoothquant ckpt_not_load:{ckpt_not_load}")
-
- def infer_gptq_quant_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map):
- """infer_gptq_quant_net_ms_convert_layer_weight"""
- parameter_dict = {}
-
- no_need_split_layer = ["tok_embeddings", "norm", "q2l_proj",
- "kv2l", "routed_experts.router.dense",
- "routed_experts.router.e_score_correction_bias",
- "topk_bias"]
-
- for param_name, _ in tqdm(hf_weight_map.items(), desc="split safetensors"):
- if "model.layers" in param_name and int(param_name.split('.')[2]) >= num_layers:
- continue
-
- if any([name in param_name for name in no_need_split_layer]):
- value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir,
- hf_weight_map)
- elif any([name in param_name for name in [".l2q_proj.", ".feed_forward.w_gate_hidden.",
- "shared_experts.w_gate_hidden"]]):
- value, is_int4 = self.get_safetensor_from_file_split_tp_group(
- param_name, src_hf_dir, hf_weight_map, split_axis=1)
- elif any([name in param_name for name in [".wo."]]):
- value, is_int4 = self.get_safetensor_from_file_split_tp_group(
- param_name, src_hf_dir, hf_weight_map, split_axis=0)
- elif any([name in param_name for name in [".feed_forward.w2.","shared_experts.w2"]]):
- value = self.infer_smooth_quant_row_linear_split(param_name, src_hf_dir, hf_weight_map)
- is_int4 = False
- elif ".routed_experts.ffn.w_gate_hidden." in param_name:
- value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
- value_list = []
- for experts_id in range(value.shape[0]):
- value_list.append(self.split_weight_by_rank(value[experts_id, :, :], split_axis=1))
- value = np.stack(value_list, axis=0)
- elif ".routed_experts.ffn.w2" in param_name:
- value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
- value_list = []
- for experts_id in range(value.shape[0]):
- value_list.append(self.split_weight_by_rank(value[experts_id, :, :], split_axis=0))
- value = np.stack(value_list, axis=0)
- elif any([name in param_name for name in ["lkv2kv_k_nope", "lkv2kv_v"]]):
- value, is_int4 = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- elif "lm_head" in param_name:
- if not self.config.parallel_config.vocab_emb_dp:
- value, is_int4 = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map,
- split_axis=0)
- else:
- value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
- else:
- raise ValueError(f"not found layer {param_name}, please check safetensors file.")
-
- dst_dtype = convert_np_to_ms_dtype(value)
- if is_int4:
- parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dtype.qint4x2),
- name=param_name, requires_grad=False)
- else:
- parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dst_dtype),
- name=param_name, requires_grad=False)
- _, _ = ms.load_param_into_net(self.network, parameter_dict)
-
- def load_safetensors_shard(self, src_hf_dir, is_mtp_model=False):
- """deepseek load safetensors and shard """
- rank_id = get_rank()
- param_json_path = ""
-
- for file in os.listdir(src_hf_dir):
- if file.endswith('index.json'):
- # mtp model do not support quantization, needs to load bf16 weight.
- if ('quant' in file and self.is_quant) or \
- ('quant' not in file and (not self.is_quant or is_mtp_model)):
- param_json_path = os.path.join(src_hf_dir, file)
- with open(param_json_path, "r") as fp:
- hf_weight_map = json.load(fp)['weight_map']
- break
- elif file.endswith('_name_map.json'):
- param_json_path = os.path.join(src_hf_dir, file)
- with open(param_json_path, "r") as fp:
- hf_weight_map = json.load(fp)
- if hf_weight_map.get('weight_map'):
- hf_weight_map = hf_weight_map['weight_map']
- break
-
- if not param_json_path:
- raise ValueError(f"Not found param_json_path in {src_hf_dir}")
-
- quantization_config = self.config.model.model_config.quantization_config
- quant_method = quantization_config.quant_method if quantization_config else None
- support_quant_method = ["gptq-pergroup", "smoothquant", "osl"]
- if not quant_method or (quant_method not in support_quant_method) and \
- not is_mtp_model:
- self.infer_convert_outer_weight(src_hf_dir, hf_weight_map)
-
- if quant_method and quant_method == "gptq-pergroup":
- self.infer_gptq_quant_net_ms_convert_layer_weight(src_hf_dir, self.num_layers, hf_weight_map)
- return
- if quant_method and quant_method == "smoothquant":
- self.infer_smooth_quant_net_ms_convert_layer_weight(src_hf_dir, self.num_layers, hf_weight_map)
- return
- if quant_method and quant_method == "osl":
- self.infer_smooth_quant_net_ms_convert_layer_weight(src_hf_dir, self.num_layers, hf_weight_map)
- return
-
- enable_tqdm = rank_id == 0
- mtp_layers = self.config.model.model_config.num_nextn_predict_layers
- start_layer = 0 if not is_mtp_model else self.num_layers
- end_layer = self.num_layers if not is_mtp_model else self.num_layers + mtp_layers
- for layer_id in tqdm(range(start_layer, end_layer), desc="Weight loading", disable=not enable_tqdm):
- if self.is_quant:
- self.infer_quant_net_convert_layer_weight(src_hf_dir, layer_id, hf_weight_map)
- else:
- self.infer_convert_layer_weight(src_hf_dir, layer_id, hf_weight_map)
-
- param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, self.parameter_dict)
- logger.info("param_not_load: %s, ckpt_not_load: %s" % (str(param_not_load), str(ckpt_not_load)))
- del self.parameter_dict
- gc.collect()
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 8a5a07778cd2095b24289db1ac309da8c5ad28e4..56cabb1f7b433eebadbec05f4e8a2892611d8f2a 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -17,58 +16,64 @@
# ============================================================================
import os
-from types import MethodType
-from typing import Iterable, List, Optional, Set, Tuple, Union
from abc import abstractmethod
-import numpy as np
-import math
-
-from vllm.config import VllmConfig
-from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.logger import init_logger
-from vllm.forward_context import get_forward_context
-import vllm.envs as envs
+from typing import Iterable, Optional, Set, Tuple, Union
import mindspore as ms
-from mindspore import Tensor
-from mindspore.common.api import _pynative_executor
-
-from mindformers.tools.register.config import MindFormerConfig
from mindformers.core.context import build_mf_context
from mindformers.core.parallel_config import build_parallel_config
+from mindformers.tools.register.config import MindFormerConfig
from mindformers.tools.utils import is_pynative
+from mindspore import Tensor, mint
+from mindspore.common.api import _pynative_executor
+from mindspore.communication import get_rank
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import get_dp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+from vllm_mindspore.model_executor.models.attention_mask import (
+ LowerTriangularMask)
from vllm_mindspore.model_executor.models.model_base import MsModelBase
-from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask
-from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
logger = init_logger(__name__)
+
class MfModelBase(MsModelBase):
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super(MfModelBase, self).__init__(
- vllm_config=vllm_config, prefix=prefix
- )
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+
+ self.set_flags = False
+
+ model_config_path = os.getenv("MINDFORMERS_MODEL_CONFIG")
+ if model_config_path is None:
+ raise RuntimeError(
+ 'For "MindFormers" model backend, environments MINDFORMERS_MODEL_CONFIG should be set!'
+ )
+
+ self.mf_config = MindFormerConfig(model_config_path)
+ self.rank_id = get_rank()
+ self.dp_size = get_dp_group()
- self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG"))
build_mf_context(self.mf_config)
build_parallel_config(self.mf_config)
self.mf_config.model.model_config.parallel_config = (
- self.mf_config.parallel_config
- )
+ self.mf_config.parallel_config)
self.mf_config.model.model_config.parallel_config.model_parallel = (
- get_tensor_model_parallel_world_size()
- )
+ get_tensor_model_parallel_world_size())
self.mf_config.model.model_config.parallel_config.pipeline_stage = 1
self._generate_model_config()
- self.casual_mask = LowerTriangularMask(dtype=self.mf_model_config.compute_dtype,
- max_model_len=self.mf_model_config.seq_length)
+ self.casual_mask = LowerTriangularMask(
+ dtype=self.mf_model_config.compute_dtype,
+ max_model_len=self.model_config.max_model_len)
self.network, self.lm_head = self._create_network()
- affinity_config = self.mf_config.get('context', {}).get('affinity_cpu_list', {})
+ affinity_config = self.mf_config.get('context',
+ {}).get('affinity_cpu_list', {})
if isinstance(affinity_config, dict):
ms.runtime.set_cpu_affinity(True, affinity_config)
@@ -76,131 +81,51 @@ class MfModelBase(MsModelBase):
@abstractmethod
def _generate_model_config(self):
- raise NotImplementedError("Function _generate_model_config should be Implemented!")
+ raise NotImplementedError(
+ "Function _generate_model_config should be Implemented!")
@abstractmethod
def _create_network(self):
- raise NotImplementedError("Function _create_network should be Implemented!")
+ raise NotImplementedError(
+ "Function _create_network should be Implemented!")
def _set_dynamic_inputs(self):
self.network.set_dynamic_inputs()
- dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
+ dynamic_hidden_states = Tensor(
+ shape=[None, None], dtype=self.mf_model_config.compute_dtype)
self.lm_head.set_inputs(dynamic_hidden_states)
- def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata:
- input_len = input_ids.shape[0]
- max_seq_len = ms.Tensor(input_len, dtype=ms.int32)
- seq_lengths = ms.Tensor([input_len], dtype=ms.int32)
- q_seq_lens = ms.Tensor([input_len], dtype=ms.int32)
- q_seq_lens_np = np.array([input_len], dtype=np.int32)
- seq_lens_np = np.array([input_len], dtype=np.int32)
-
- block_tables = ms.Tensor([[0]], dtype=ms.int32)
- slot_mapping = [-1 for _ in range(input_len)]
- slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32)
- return FlashAttentionMetadata(
- max_seq_len=max_seq_len,
- seq_lens=seq_lengths,
- seq_lens_np=seq_lens_np,
- block_tables=block_tables,
- slot_mapping=slot_mapping,
- q_seq_lens=q_seq_lens,
- q_seq_lens_np=q_seq_lens_np,
- context_lens=0,
- # To enforce prefill and decode are both complied in warmup process.
- # So set max_context_lens to 0 for prefill and 1 for decode.
- max_context_lens=0 if not self.set_flags else 1,
- query_start_loc = None
- )
-
- def prepare_inputs(self, input_ids, positions, attn_metadata):
- key_cache, value_cache = self.get_kvcache()
- if not envs.VLLM_USE_V1:
- seq_lens = attn_metadata.seq_lens
- max_query_len = attn_metadata.max_query_len
- # When Mutli-Step is enabled with Chunked-Prefill, prefills and
- # decodes are scheduled together. In the first step, all the
- # prefills turn into decodes and max_query_len will be 1.
- if self.is_multi_step_chunked_prefill and max_query_len == 1:
- query_lens = [1] * len(seq_lens)
- else:
- query_lens = attn_metadata.query_lens
-
- seq_lens = attn_metadata.seq_lens
- max_query_len = attn_metadata.max_query_len
- # When Mutli-Step is enabled with Chunked-Prefill, prefills and
- # decodes are scheduled together. In the first step, all the
- # prefills turn into decodes and max_query_len will be 1.
- if self.is_multi_step_chunked_prefill and max_query_len == 1:
- query_lens = [1] * len(seq_lens)
- else:
- query_lens = attn_metadata.query_lens
-
- seq_lens_np = np.array(seq_lens, dtype=np.int32)
- query_lens_np = np.array(query_lens, dtype=np.int32)
- kv_cache_lens = seq_lens_np - query_lens_np
- if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0:
- is_prefill = True
- else:
- is_prefill = False
-
- q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
- position_ids = ms.Tensor(positions, dtype=ms.int32)
- attention_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens)
-
- model_inputs = {}
- model_inputs["input_ids"] = input_ids.astype(ms.int32)
- model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
- model_inputs["block_tables"] = attn_metadata.block_tables
- model_inputs["slot_mapping"] = attn_metadata.slot_mapping
- model_inputs["position_ids"] = position_ids
- model_inputs["q_seq_lens"] = q_seq_lens
- model_inputs["attention_mask"] = attention_mask
- model_inputs["key_cache"] = key_cache
- model_inputs["value_cache"] = value_cache
- else:
- if attn_metadata.max_context_lens == 0:
- is_prefill = True
- else:
- is_prefill = False
- q_seq_lens = attn_metadata.q_seq_lens
- query_lens_np = attn_metadata.q_seq_lens_np
- attention_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np)
-
- model_inputs = {}
- model_inputs["input_ids"] = input_ids.astype(ms.int32)
- model_inputs["batch_valid_length"] = ms.from_numpy(attn_metadata.seq_lens_np)
- model_inputs["block_tables"] = attn_metadata.block_tables
- model_inputs["slot_mapping"] = attn_metadata.slot_mapping
- model_inputs["position_ids"] = positions.to(ms.int32)
- model_inputs["q_seq_lens"] = q_seq_lens
- model_inputs["attention_mask"] = attention_mask
- model_inputs["key_cache"] = key_cache
- model_inputs["value_cache"] = value_cache
-
- return model_inputs, is_prefill
+ def prepare_inputs(self, input_ids, positions):
+ return self.prepare_base_inputs(input_ids, positions)
def update_model_inputs(self, model_inputs, **kwargs):
return model_inputs
- def forward(
- self,
- input_ids: Tensor,
- positions: Tensor,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[Tensor] = None,
- **kwargs
- ) -> Union[Tensor, IntermediateTensors]:
- attn_metadata = get_forward_context().attn_metadata
- if attn_metadata is None:
- attn_metadata = self._dummy_attention_metadata(input_ids, positions)
- model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, attn_metadata)
+ def forward(self,
+ input_ids: Tensor,
+ positions: Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ **kwargs) -> Union[Tensor, IntermediateTensors]:
+ model_inputs, is_prefill = self.prepare_inputs(input_ids, positions)
model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+        # enable_mb_split is True when large-scale EP enables micro-batching and the per-DP batch size > 1
+ enable_mb_split = self.is_enable_micro_batch_split(
+ is_prefill, model_inputs["q_seq_lens"])
+
if is_prefill:
- self.network.phase = "prefill"
- if not self.set_flags or is_pynative():
- self.network.add_flags_custom(is_first_iteration=True)
+ if self.enable_micro_batch:
+ self.network.phase = "prefill" if not enable_mb_split else "prefill_micro_batch"
+ if not self.set_flags or is_pynative() or enable_mb_split:
+ self.network.add_flags_custom(is_first_iteration=True)
+ self.network.add_flags_enable_micro_batch(
+ enable_micro_batch=enable_mb_split)
+ else:
+ self.network.phase = "prefill"
+ if not self.set_flags or is_pynative():
+ self.network.add_flags_custom(is_first_iteration=True)
+
hidden_states = self.network(**model_inputs)
self.network.phase = "increment"
if not self.set_flags or is_pynative():
@@ -218,11 +143,14 @@ class MfModelBase(MsModelBase):
) -> Optional[Tensor]:
if sampling_metadata is not None:
selected_token_indices = sampling_metadata.selected_token_indices
- if selected_token_indices is not None and selected_token_indices.numel() <= 0:
- logits = ms.mint.zeros((0, self.mf_model_config.vocab_size),
- dtype=self.mf_model_config.compute_dtype)
+ if selected_token_indices is not None and selected_token_indices.numel(
+ ) <= 0:
+ logits = ms.mint.zeros(
+ (0, self.mf_model_config.vocab_size),
+ dtype=self.mf_model_config.compute_dtype)
else:
- hidden_states = hidden_states.index_select(0, selected_token_indices)
+ hidden_states = hidden_states.index_select(
+ 0, selected_token_indices)
logits = self.lm_head(hidden_states)
logits = logits.view(-1, logits.shape[-1])
else:
@@ -241,3 +169,15 @@ class MfModelBase(MsModelBase):
def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
raise NotImplementedError("load_weight not implemented.")
+
+ def is_enable_micro_batch_split(self, is_prefill, q_seq_lens):
+ """Judge enable micro batch """
+ if self.enable_micro_batch:
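+            # Encode whether this DP rank is prefilling as an int8 flag, then
+            # all-gather the flags across the data-parallel group.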
+ is_prefill_cur_dp = mint.ones(
+ (1), dtype=ms.int8) if is_prefill else mint.zeros(
+ (1), dtype=ms.int8)
+ is_prefill_all_dp = get_dp_group().all_gather(is_prefill_cur_dp)
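+            # Split only when every DP rank is prefilling and this rank holds
+            # more than one sequence.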
+ return is_prefill_all_dp.sum(
+ ) == self.dp_size and q_seq_lens.shape[0] > 1
+ else:
+ return False
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
index d871be483b191995863112ee839aa5f2c7656765..fdb987daf0b3923f2906f0052a4e2f7a1b0bed8c 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
@@ -33,9 +33,8 @@ from research.qwen2_5.infer.qwen2_5 import (
)
from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.model_base import Fake_Attention, Fake_Attention_V1
+from vllm_mindspore.model_executor.models.model_base import AttentionWrapper
from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
-
from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
@@ -49,10 +48,7 @@ class Qwen2ForCausalLM(MfModelBase):
self.sampler = get_sampler()
self.set_modules({"model": self.network})
- if envs.VLLM_USE_V1:
- self.kv_caches = [Fake_Attention_V1() for i in range(self.mf_model_config.num_layers)]
- else:
- self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+ self.kv_caches = [AttentionWrapper() for i in range(self.mf_model_config.num_layers)]
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3.py b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
index a5a8b01d6e906f2c9b8e51c7f3d0af288f05137b..a11a93faaab3d4bb978b40eb9b0372d72ef7b2e1 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen3.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -16,70 +15,210 @@
# limitations under the License.
# ============================================================================
-from typing import Iterable, Set, Tuple
-
-from vllm.config import VllmConfig
-from vllm.config import get_current_vllm_config
-from vllm.logger import init_logger
-
-from mindspore import Tensor, JitConfig
+from typing import Iterable, Optional, Tuple, Union
+
+import mindspore as ms
+import numpy as np
+from mindformers.core.context import build_mf_context
+from mindformers.core.parallel_config import build_parallel_config
+from mindformers.models.qwen3.configuration_qwen3 import Qwen3Config
+from mindformers.models.qwen3.modeling_qwen3 import ( # noqa
+ Qwen3ForCausalLM as Qwen3ForCausalLM_MF)
+from mindformers.tools.utils import is_pynative
+from mindspore import Tensor, ops
+from mindspore.common.api import _pynative_executor
from mindspore.nn.utils import no_init_parameters
-
-from mindformers.models.llama import LlamaConfig as LlamaConfig_MF
-from research.qwen3.qwen3 import (
- ParallelQwen3ForCausalLM as ParallelQwenForCausalLM_MF,
-)
+from vllm import envs
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.model_base import Fake_Attention
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
-from vllm_mindspore.model_executor.models.mf_models.qwen3_weight_processor import Qwen3WeightProcessor
-
+from vllm_mindspore.model_executor.models.attention_mask import (
+ LowerTriangularMask)
+from vllm_mindspore.model_executor.models.mf_models.config import (
+ gen_mf_config, gen_model_config)
+from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper,
+ MsModelBase)
logger = init_logger(__name__)
-class Qwen3ForCausalLM(MfModelBase):
+class Qwen3ForCausalLM(MsModelBase):
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super(Qwen3ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
- self.mf_kvcaches_init = False
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ self.set_flags = False
+
+ mf_config = gen_mf_config(vllm_config)
+ mf_config.load_checkpoint = self.get_model_path()
+ self.mf_config = mf_config
+
+ build_mf_context(self.mf_config)
+ build_parallel_config(self.mf_config)
+
+ self._generate_model_config()
+ self.casual_mask = LowerTriangularMask(
+ dtype=self.mf_model_config.compute_dtype,
+ max_model_len=self.mf_model_config.seq_length)
+ self.network, self.lm_head = self._create_network()
+
+ affinity_config = self.mf_config.get('context',
+ {}).get('affinity_cpu_list', {})
+ if isinstance(affinity_config, dict):
+ ms.runtime.set_cpu_affinity(True, affinity_config)
+
+ self._set_dynamic_inputs()
self.sampler = get_sampler()
self.set_modules({"model": self.network})
-
- self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+ self.kv_caches = [
+ AttentionWrapper()
+ for _ in range(self.mf_model_config.num_hidden_layers)
+ ]
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
- for i in range(self.mf_model_config.num_layers):
- compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
+ for i in range(self.mf_model_config.num_hidden_layers):
+ compilation_config.static_forward_context[str(
+ i)] = self.kv_caches[i]
- self.set_flags = False
+ self.cast = ops.Cast()
- def _generate_model_config(self):
- self.mf_config.load_checkpoint = self.get_model_path()
- self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config)
- if self.mf_config.moe_config:
- self.mf_model_config.moe_config = self.mf_config.moe_config
- self.mf_model_config.return_hidden_states = True
+ def _set_dynamic_inputs(self):
+ self.network.set_dynamic_inputs()
+ dynamic_hidden_states = Tensor(
+ shape=[None, None], dtype=self.mf_model_config.compute_dtype)
+ self.lm_head.set_inputs(dynamic_hidden_states)
- # qwen qkv concat will support in next version
- self.mf_model_config.qkv_concat = False
- setattr(self.mf_model_config, 'npu_mem_size', -1)
- self.mf_config.model.model_config.qkv_concat = False
+ def prepare_inputs(self, input_ids, positions):
+
+ attn_metadata = get_forward_context().attn_metadata
+ if attn_metadata is None:
+ attn_metadata = self._dummy_attention_metadata(
+ input_ids, positions)
+ key_cache, value_cache = self.get_kvcache()
+ if not envs.VLLM_USE_V1:
+ # V0
+ seq_lens = attn_metadata.seq_lens
+ max_query_len = attn_metadata.max_query_len
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+ # decodes are scheduled together. In the first step, all the
+ # prefills turn into decodes and max_query_len will be 1.
+ if self.is_multi_step_chunked_prefill and max_query_len == 1:
+ query_lens = [1] * len(seq_lens)
+ else:
+ query_lens = attn_metadata.query_lens
+
+ seq_lens_np = np.array(seq_lens, dtype=np.int32)
+ query_lens_np = np.array(query_lens, dtype=np.int32)
+ kv_cache_lens = seq_lens_np - query_lens_np
+ if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max(
+ ) == 0:
+ is_prefill = True
+ else:
+ is_prefill = False
+ context_lens_tensor = ms.from_numpy(kv_cache_lens)
+ else:
+ # V1
+ is_prefill = attn_metadata.max_context_lens == 0
+ query_lens_np = attn_metadata.q_seq_lens_np
+ seq_lens_np = attn_metadata.seq_lens_np
+ context_lens_tensor = attn_metadata.context_lens
+
+ q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
+ position_ids = ms.Tensor(positions, dtype=ms.int32)
+ attention_mask = self.casual_mask.gen_attention_mask(
+ is_prefill, positions, query_lens_np)
+
+ model_inputs = {}
+ model_inputs["input_ids"] = input_ids.astype(ms.int32)
+ model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
+ model_inputs["block_tables"] = attn_metadata.block_tables
+ model_inputs["slot_mapping"] = attn_metadata.slot_mapping
+ model_inputs["positions"] = position_ids
+ model_inputs["q_seq_lens"] = q_seq_lens
+ model_inputs["attention_mask"] = attention_mask
+ model_inputs["key_cache"] = key_cache
+ model_inputs["value_cache"] = value_cache
+ model_inputs["context_lens_tensor"] = context_lens_tensor
+
+ return model_inputs, is_prefill
+
+ def forward(self,
+ input_ids: Tensor,
+ positions: Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ **kwargs) -> Union[Tensor, IntermediateTensors]:
+ model_inputs, is_prefill = self.prepare_inputs(input_ids, positions)
+ model_inputs = self.update_model_inputs(model_inputs, **kwargs)
+
+ if is_prefill:
+ self.network.phase = "prefill"
+ if not self.set_flags or is_pynative():
+ self.network.add_flags_custom_mcore(is_prefill=True)
+ hidden_states = self.network(**model_inputs)
+ self.network.phase = "increment"
+ if not self.set_flags or is_pynative():
+ self.network.add_flags_custom_mcore(is_prefill=False)
+ self.set_flags = True
+ else:
+ hidden_states = self.network(**model_inputs)
+
+ return hidden_states
+
+ def _generate_model_config(self):
+ self.mf_model_config = gen_model_config(self.mf_config, Qwen3Config)
+        logger.debug("=====mf_model_config=====\n%s", self.mf_model_config)
def _create_network(self):
# Initial network
with no_init_parameters(): # Delay initialization
- network = ParallelQwenForCausalLM_MF(self.mf_model_config)
- return network, network.lm_head
-
- def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
- weight_processor = Qwen3WeightProcessor(self.mf_config, self.network, False)
- weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint)
-
+ network = Qwen3ForCausalLM_MF(self.mf_model_config)
+ return network, network.model.output_layer
+
+ def update_model_inputs(self, model_inputs, **kwargs):
+ return model_inputs
+
+ def compute_logits(
+ self,
+ hidden_states: Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[Tensor]:
+ if sampling_metadata is not None:
+ selected_token_indices = sampling_metadata.selected_token_indices
+ if selected_token_indices is not None and selected_token_indices.numel(
+ ) <= 0:
+ logits = ms.mint.zeros(
+ (0, self.mf_model_config.vocab_size),
+ dtype=self.mf_model_config.compute_dtype)
+ else:
+ hidden_states = hidden_states.reshape(
+ (-1, hidden_states.shape[-1]))
+ hidden_states = hidden_states.index_select(
+ 0, selected_token_indices)
+ logits = self.lm_head(hidden_states)
+ logits = logits.view(-1, logits.shape[-1])
+ else:
+ logits = self.lm_head(hidden_states)
+ logits = logits.view(-1, logits.shape[-1])
+ return logits
+
+ def sample(
+ self,
+ logits: Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ _pynative_executor.sync()
+ return next_tokens
+
+ def load_weights(self, weights: Iterable[Tuple[str, Tensor]]):
+ self.network.load_weights(self.mf_config.load_checkpoint)
self.network.set_dynamic_inputs()
- dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
- self.lm_head.set_inputs(dynamic_hidden_states)
return None
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py
deleted file mode 100644
index 338616cafda4f4864c26b58530f7db8d11481d9e..0000000000000000000000000000000000000000
--- a/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2025 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-
-"""
-transform huggingface model to mindspore safetensor.
-"""
-import numpy as np
-
-import mindspore as ms
-
-from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
-
-
-class Qwen3WeightProcessor(Qwen2WeightProcessor):
- r"""
- Provide Qwen3 Model weight load and shards.
- Args:
- config (Qwen3Config): The config of Qwen3 model.
- network (InferenceQwen3ForCausalLM): The network of Qwen3.
-
- """
-
- def __init__(self, config, network, is_quant):
- super().__init__(config, network, is_quant)
-
- def convert_weight_name(self, weight_name: str):
- """replace weight name"""
- weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight')
- weight_name = weight_name.replace('self_attn.q_proj.', 'attention.wq.')
- weight_name = weight_name.replace('self_attn.k_proj.', 'attention.wk.')
- weight_name = weight_name.replace('self_attn.v_proj.', 'attention.wv.')
- weight_name = weight_name.replace('self_attn.o_proj.', 'attention.wo.')
- weight_name = weight_name.replace('self_attn.q_norm.', 'attention.q_norm.')
- weight_name = weight_name.replace('self_attn.k_norm.', 'attention.k_norm.')
-
- weight_name = weight_name.replace('mlp.gate_proj.', 'feed_forward.w1.')
- weight_name = weight_name.replace('mlp.down_proj.', 'feed_forward.w2.')
- weight_name = weight_name.replace('mlp.up_proj.', 'feed_forward.w3.')
- weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
- weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
- weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight')
- return weight_name
-
- def infer_process_attention_weight(self, src_hf_dir, layer_id, hf_weight_map):
- """infer process attention weight"""
- qkv_concat = self.config.model.model_config.qkv_concat
- # wq
- wq_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.weight"
- wq_ms_name = self.convert_weight_name(wq_hf_name)
- wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
- split_axis=0)
-
- # wk
- wk_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.weight"
- wk_ms_name = self.convert_weight_name(wk_hf_name)
- wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
- split_axis=0)
-
- # wv
- wv_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.weight"
- wv_ms_name = self.convert_weight_name(wv_hf_name)
- wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
- split_axis=0)
-
- # wq_norm
- q_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_norm.weight"
- q_norm_ms_name = self.convert_weight_name(q_norm_hf_name)
- q_norm_ms_param, _ = self.get_safetensor_from_file(q_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[q_norm_ms_name] = ms.Parameter(ms.Tensor(q_norm_ms_param, ms.bfloat16), name=q_norm_ms_name,
- requires_grad=False)
-
- #wk_norm
- k_norm_hf_name = f"model.layers.{layer_id}.self_attn.k_norm.weight"
- k_norm_ms_name = self.convert_weight_name(k_norm_hf_name)
- k_norm_ms_param, _ = self.get_safetensor_from_file(k_norm_hf_name, src_hf_dir, hf_weight_map)
- self.parameter_dict[k_norm_ms_name] = ms.Parameter(ms.Tensor(k_norm_ms_param, ms.bfloat16), name=k_norm_ms_name,
- requires_grad=False)
-
- if qkv_concat:
- w_qkv_name = f"model.layers.{layer_id}.attention.w_qkv.weight"
- w_qkv_param = np.concatenate((wq_ms_param, wk_ms_param, wv_ms_param), axis=0)
- w_qkv_param = ms.from_numpy(w_qkv_param).astype(ms.bfloat16)
- self.parameter_dict[w_qkv_name] = ms.Parameter(w_qkv_param, name=w_qkv_name, requires_grad=False)
-
- else:
- self.parameter_dict[wq_ms_name] = ms.Parameter(ms.from_numpy(wq_ms_param).astype(ms.bfloat16),
- name=wq_ms_name,
- requires_grad=False)
- self.parameter_dict[wk_ms_name] = ms.Parameter(ms.from_numpy(wk_ms_param).astype(ms.bfloat16),
- name=wk_ms_name,
- requires_grad=False)
- self.parameter_dict[wv_ms_name] = ms.Parameter(ms.from_numpy(wv_ms_param).astype(ms.bfloat16),
- name=wv_ms_name,
- requires_grad=False)
-
- # wo
- wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
- wo_ms_name = self.convert_weight_name(wo_hf_name)
- wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
- split_axis=1)
- self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
- name=wo_ms_name,
- requires_grad=False)
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index 0d933a2db438919cf68388833e1f95c572436c81..1bd492330267f15a666cbdf6b02571f323d55df4 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -18,70 +17,43 @@
import os
from abc import abstractmethod
-from typing import Iterable, List, Optional, Set, Tuple, Union, Dict
+from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
+import mindspore as ms
+import numpy as np
+import vllm.envs as envs
+from mindspore import Tensor, mutable, nn
+from mindspore.common import dtype as mstype
+from vllm.attention.backends.abstract import AttentionType
from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
-from vllm.attention.backends.abstract import AttentionType
-from vllm.forward_context import get_forward_context
-from vllm.attention.layer import Attention
-import torch
+from vllm_mindspore.model_executor.models.attention_mask import (
+ LowerTriangularMask)
+from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
+from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
-from mindspore import Tensor, nn, mutable
+class AttentionWrapper:
-class Fake_Attention:
def __init__(self):
vllm_config = get_current_vllm_config()
block_size = vllm_config.cache_config.block_size
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
- vllm_config.parallel_config
- )
+ vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
num_block = 0
self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
- self.kv_cache = [
- (
- torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
- torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
- )
- for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
- ]
+ self.kv_cache = [(
+ ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype),
+ ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype),
+ ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)]
self.attn_type = AttentionType.DECODER
-
-class Fake_MLA(Fake_Attention):
- def __init__(self):
- super().__init__()
- vllm_config = get_current_vllm_config()
- self.kv_cache = [
- (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
- for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
- ]
-
-
-class Fake_Attention_V1(Attention):
- def __init__(self):
- vllm_config = get_current_vllm_config()
- block_size = vllm_config.cache_config.block_size
- num_kv_heads = vllm_config.model_config.get_num_kv_heads(
- vllm_config.parallel_config
- )
- head_size = vllm_config.model_config.get_head_size()
- num_block = 0
- self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
- self.kv_cache = [
- (
- torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
- torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),
- )
- for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
- ]
- self.attn_type = AttentionType.DECODER
- self.num_block = num_block
+        # Extra fields required by the v1 attention path.
self.num_kv_heads = num_kv_heads
self.head_size = head_size
self.dtype = vllm_config.model_config.dtype
@@ -89,19 +61,24 @@ class Fake_Attention_V1(Attention):
self.sliding_window = None
-class Fake_MLA_V1(Fake_Attention_V1):
+class MLAAttentionWrapper(AttentionWrapper):
+
def __init__(self):
super().__init__()
vllm_config = get_current_vllm_config()
self.kv_cache = [
- (torch.zeros(self.kv_shape, dtype=torch.bfloat16, device="Ascend"),)
+ (
+ ms.mint.zeros(
+ self.kv_shape, # type: ignore[misc]
+ dtype=vllm_config.model_config.dtype), )
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
-class MsModelBase():
+class MsModelBase:
+
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
- super(MsModelBase, self).__init__()
+ super().__init__()
config = vllm_config.model_config.hf_config
lora_config = vllm_config.lora_config
@@ -112,20 +89,26 @@ class MsModelBase():
self.parallel_config = vllm_config.parallel_config
self.load_config = vllm_config.load_config
self.scheduler_config = vllm_config.scheduler_config
+ self.enable_micro_batch = \
+ vllm_config.additional_config.get('enable_micro_batch', 0) == 1 \
+ if vllm_config.additional_config is not None else False
- self.modules_dict = None
+ self.modules_dict: Any = None
self.enable_chunked_prefill = vllm_config.scheduler_config.enable_chunked_prefill
self.enable_prefix_caching = vllm_config.cache_config.enable_prefix_caching
self.is_multi_step = vllm_config.scheduler_config.is_multi_step
self.is_multi_step_chunked_prefill = self.is_multi_step and self.enable_chunked_prefill
+ self.set_flags = False
+
def get_model_path(self):
model_name_or_path = self.model_config.model
if os.path.isdir(model_name_or_path):
return model_name_or_path
else:
- from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf
+ from vllm.model_executor.model_loader.weight_utils import (
+ download_weights_from_hf)
allow_patterns = ["*.safetensors"]
revision = self.model_config.revision
return download_weights_from_hf(
@@ -137,7 +120,7 @@ class MsModelBase():
)
def set_modules(self, model_dicts: Dict[str, nn.Cell]):
- self.modules_dict = model_dicts
+ self.modules_dict = model_dicts # type: ignore[assignment]
def _check_modules_valid(self):
if self.modules_dict is None:
@@ -171,15 +154,24 @@ class MsModelBase():
def named_modules(self, remove_duplicate: bool = True):
self._check_modules_valid()
- res_modules = set()
for name, module in self.modules_dict.items():
for module_name, sub_module in module.cells_and_names():
if name != "self":
module_name = name + "." + module_name
yield module_name, sub_module
- def get_submodule(self):
- raise RuntimeError("Cannot get submodule for mindspore model now!")
+ def get_submodule(self, target: str):
+ parts = target.split(".")
+ if target == "":
+ return self
+ for part in parts:
+ if not part:
+ raise ValueError(
+ f"Invalid submodule path: empty part in '{target}'")
+ current = self
+ for part in parts:
+ current = getattr(current, part)
+ return current
def eval(self):
self._check_modules_valid()
@@ -197,75 +189,33 @@ class MsModelBase():
inputs_embeds: Optional[Tensor] = None,
previous_hidden_states: Optional[Tensor] = None,
spec_step_idx: int = 0,
+ **kwargs,
) -> Union[Tensor, IntermediateTensors]:
- return self.forward(
- input_ids,
- positions,
- intermediate_tensors,
- inputs_embeds,
- previous_hidden_states=previous_hidden_states,
- spec_step_idx=spec_step_idx
- )
-
- def forward(
- self,
- input_ids: Tensor,
- positions: Tensor,
- intermediate_tensors: Optional[IntermediateTensors] = None,
- inputs_embeds: Optional[Tensor] = None,
- **kwargs
- ) -> Union[Tensor, IntermediateTensors]:
+ return self.forward(input_ids,
+ positions,
+ intermediate_tensors,
+ inputs_embeds,
+ previous_hidden_states=previous_hidden_states,
+ spec_step_idx=spec_step_idx,
+ **kwargs)
+
+ def forward(self,
+ input_ids: Tensor,
+ positions: Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[Tensor] = None,
+ **kwargs) -> Union[Tensor, IntermediateTensors]:
raise NotImplementedError
- def set_model_inputs(self, is_prefill):
- dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64)
- dyn_position_ids = Tensor(shape=[None], dtype=mstype.int64)
-
- block_size = self.cache_config.block_size
- num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
- head_size = self.model_config.get_head_size()
- kv_cache_shape = (None, block_size, num_kv_heads, head_size)
-
- kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \
- else self.cache_config.cache_dtype
- if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE:
- kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype]
-
- num_layers = self.model_config.get_num_layers(self.parallel_config)
-
- dyn_key_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
- dyn_value_cache = mutable(Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype))
- dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)])
- dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)])
-
- dyn_batch_valid_length = Tensor(shape=[None, ], dtype=mstype.int32)
- dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32)
- dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32)
- dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
- dyn_intermediate_tensors = None
- dyn_inputs_embeds = None
-
- self.model.set_inputs(
- dyn_input_ids,
- dyn_position_ids,
- dyn_key_caches,
- dyn_value_caches,
- is_prefill,
- dyn_slot_mapping,
- dyn_batch_valid_length,
- dyn_q_seq_lens,
- dyn_block_tables,
- dyn_intermediate_tensors,
- dyn_inputs_embeds
- )
-
def get_kvcache(self):
key_cache = []
value_cache = []
forward_context = get_forward_context()
for i in range(self.config.num_hidden_layers):
- k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
- v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
+ k_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined]
+ forward_context.virtual_engine][0]
+ v_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined]
+ forward_context.virtual_engine][1]
key_cache.append(k_cache)
value_cache.append(v_cache)
return mutable(key_cache), mutable(value_cache)
@@ -276,7 +226,8 @@ class MsModelBase():
hidden_states: Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]:
- raise NotImplementedError("Function compute_logits should be Implemented!")
+ raise NotImplementedError(
+ "Function compute_logits should be Implemented!")
@abstractmethod
def sample(
@@ -288,4 +239,231 @@ class MsModelBase():
@abstractmethod
def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
- raise NotImplementedError("Function load_weights should be Implemented!")
+ raise NotImplementedError(
+ "Function load_weights should be Implemented!")
+
+ def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor):
+ if input_ids is not None:
+ input_len = input_ids.shape[0]
+ elif positions is not None:
+            # input_ids is None for multimodal models with the v1 arch
+ input_len = positions.shape[-1]
+
+ max_seq_len = ms.Tensor(input_len, dtype=ms.int32)
+ seq_lengths = ms.Tensor([input_len], dtype=ms.int32)
+ q_seq_lens_np = np.array([input_len], dtype=np.int32)
+ seq_lens_np = np.array([input_len], dtype=np.int32)
+ context_lens_tensor = ms.Tensor([0], dtype=ms.int32)
+
+ block_tables = ms.Tensor([[0]], dtype=ms.int32)
+ slot_mapping = [-1 for _ in range(input_len)]
+ slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32)
+ return MsAttentionMetadata(
+ max_seq_len=max_seq_len,
+ seq_lens=seq_lengths,
+ seq_lens_np=seq_lens_np,
+ block_tables=block_tables,
+ slot_mapping=slot_mapping,
+ q_seq_lens_np=q_seq_lens_np,
+ context_lens=context_lens_tensor,
+            # To ensure both prefill and decode are compiled in the warmup
+            # process, set max_context_lens to 0 for prefill and 1 for decode.
+ max_context_lens=0 if not self.set_flags else 1,
+ query_start_loc=None)
+
+ def prepare_base_inputs(self, input_ids, positions):
+ attn_metadata = get_forward_context().attn_metadata
+ if attn_metadata is None:
+ attn_metadata = self._dummy_attention_metadata(
+ input_ids, positions)
+ key_cache, value_cache = self.get_kvcache()
+ if not envs.VLLM_USE_V1:
+ # V0
+ seq_lens = attn_metadata.seq_lens
+ max_query_len = attn_metadata.max_query_len
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+ # decodes are scheduled together. In the first step, all the
+ # prefills turn into decodes and max_query_len will be 1.
+ if self.is_multi_step_chunked_prefill and max_query_len == 1:
+ query_lens = [1] * len(seq_lens)
+ else:
+ query_lens = attn_metadata.query_lens
+
+ seq_lens_np = np.array(seq_lens, dtype=np.int32)
+ query_lens_np = np.array(query_lens, dtype=np.int32)
+ kv_cache_lens = seq_lens_np - query_lens_np
+ if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max(
+ ) == 0:
+ is_prefill = True
+ else:
+ is_prefill = False
+ else:
+ # V1
+ is_prefill = attn_metadata.max_context_lens == 0
+ query_lens_np = attn_metadata.q_seq_lens_np
+ seq_lens_np = attn_metadata.seq_lens_np
+
+ if input_ids is not None:
+ input_ids = input_ids.astype(ms.int32)
+ q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
+ position_ids = ms.Tensor(positions, dtype=ms.int32)
+ attention_mask = self.casual_mask.gen_attention_mask( # type: ignore[attr-defined]
+ is_prefill, positions, query_lens_np, attn_metadata)
+
+ model_inputs = {}
+ model_inputs["input_ids"] = input_ids
+ model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np)
+ model_inputs["block_tables"] = attn_metadata.block_tables
+ model_inputs["slot_mapping"] = attn_metadata.slot_mapping
+ model_inputs["position_ids"] = position_ids
+ model_inputs["q_seq_lens"] = q_seq_lens
+ model_inputs["attention_mask"] = attention_mask
+ model_inputs["key_cache"] = key_cache
+ model_inputs["value_cache"] = value_cache
+
+ return model_inputs, is_prefill
+
+
+class NativeModel(MsModelBase):
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ self.quant_config = vllm_config.quant_config
+ if vllm_config.lora_config is not None:
+            # Native-model LoRA only supports PyNative mode for now.
+ vllm_config.model_config.enforce_eager = True
+ self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager)
+ self.prev_prefill = False
+ self.run_model = None
+ self.model = None
+ self.lm_head = None
+
+ def common_preprocess(self, vllm_config, prefix=""):
+ self.set_modules({"model": self.model, "lm_head": self.lm_head})
+
+ self.casual_mask = LowerTriangularMask(
+ dtype=self.model_config.dtype,
+ max_model_len=self.model_config.max_model_len)
+ self.kv_caches = [
+ AttentionWrapper() for i in range(self.config.num_hidden_layers)
+ ]
+
+ compilation_config = vllm_config.compilation_config
+ if prefix in compilation_config.static_forward_context:
+ raise ValueError(f"Duplicate layer name: {prefix}")
+ for i in range(self.config.num_hidden_layers):
+ compilation_config.static_forward_context[str(
+ i)] = self.kv_caches[i]
+
+ def set_model_inputs(self, input_ids, position_ids, intermediate_tensors,
+ inputs_embeds, is_prefill):
+ if input_ids is None:
+ dyn_input_ids = None
+ else:
+ dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim,
+ dtype=mstype.int32)
+
+ if position_ids is None:
+ dyn_position_ids = None
+ else:
+ dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim,
+ dtype=mstype.int32)
+
+ if inputs_embeds is None:
+ dyn_inputs_embeds = None
+ else:
+ dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim,
+ dtype=inputs_embeds.dtype)
+
+ if intermediate_tensors is None:
+ dyn_intermediate_tensors = None
+ else:
+ dyn_intermediate_tensors = ms.Tensor(
+ shape=[None] * intermediate_tensors.ndim,
+ dtype=intermediate_tensors.dtype)
+
+ block_size = self.cache_config.block_size
+ num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+ head_size = self.model_config.get_head_size()
+ kv_cache_shape = (None, block_size, num_kv_heads, head_size)
+
+ kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \
+ else self.cache_config.cache_dtype
+ if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE:
+ kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype]
+
+ num_layers = self.model_config.get_num_layers(self.parallel_config)
+
+ dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)
+ dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)
+ dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)])
+ dyn_value_caches = mutable(
+ [dyn_value_cache for _ in range(num_layers)])
+
+ dyn_slot_mapping = Tensor(shape=[None], dtype=mstype.int32)
+ dynamic_attention_mask = Tensor(shape=[None, None],
+ dtype=self.model_config.dtype)
+ dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32)
+ dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32)
+ dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
+ self.model.set_inputs( # type: ignore[attr-defined]
+ dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches,
+ is_prefill, dyn_slot_mapping, dynamic_attention_mask,
+ dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables,
+ dyn_intermediate_tensors, dyn_inputs_embeds)
+
+ dynamic_hidden_states = Tensor(shape=[None, None],
+ dtype=self.model_config.dtype)
+ self.lm_head.set_inputs( # type: ignore[attr-defined]
+ dynamic_hidden_states)
+
+ def prepare_inputs(self, input_ids, positions, intermediate_tensors,
+ inputs_embeds):
+ model_inputs, is_prefill = self.prepare_base_inputs(
+ input_ids, positions)
+
+ # for multimodal model
+ model_inputs["intermediate_tensors"] = intermediate_tensors
+ model_inputs["inputs_embeds"] = inputs_embeds
+
+ return model_inputs, is_prefill
+
+ def exec_model(self,
+ input_ids: Tensor,
+ positions: Tensor,
+ intermediate_tensors: IntermediateTensors = None,
+ inputs_embeds: Tensor = None,
+ **kwargs):
+ model_inputs, is_prefill = self.prepare_inputs(input_ids, positions,
+ intermediate_tensors,
+ inputs_embeds)
+
+ if self.prev_prefill != is_prefill and self.is_graph_mode:
+ self.set_model_inputs(input_ids, positions, intermediate_tensors,
+ inputs_embeds, is_prefill)
+ self.prev_prefill = is_prefill
+
+        # Mark warmup done so _dummy_attention_metadata builds decode-style metadata next.
+ if is_prefill and not self.set_flags:
+ self.set_flags = True
+
+ if self.run_model is None:
+ self.run_model = ms.jit(
+ function=self.model,
+ jit_level='O0') if self.is_graph_mode else self.model
+ model_output = self.run_model( # type: ignore[misc]
+ input_ids=model_inputs["input_ids"],
+ positions=model_inputs["position_ids"],
+ key_caches=model_inputs["key_cache"],
+ value_caches=model_inputs["value_cache"],
+ is_prefill=is_prefill,
+ slot_mapping=model_inputs["slot_mapping"],
+ attn_mask=model_inputs["attention_mask"],
+ batch_valid_length=model_inputs["batch_valid_length"],
+ q_seq_lens=model_inputs["q_seq_lens"],
+ block_tables=model_inputs["block_tables"],
+ intermediate_tensors=model_inputs["intermediate_tensors"],
+ inputs_embeds=model_inputs["inputs_embeds"],
+ )
+
+ return model_output
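The get_submodule added to MsModelBase above replaces the old RuntimeError stub with a plain dotted-path attribute walk. A standalone sketch of the same traversal, using hypothetical ToyModel/ToyHead stand-ins rather than real nn.Cell modules:

class ToyHead:
    pass

class ToyModel:
    def __init__(self):
        self.lm_head = ToyHead()

def get_submodule(root, target: str):
    # An empty target means "the module itself".
    if target == "":
        return root
    parts = target.split(".")
    if any(not part for part in parts):
        raise ValueError(f"Invalid submodule path: empty part in '{target}'")
    current = root
    for part in parts:
        current = getattr(current, part)  # AttributeError if the path is wrong
    return current

model = ToyModel()
assert get_submodule(model, "") is model
assert isinstance(get_submodule(model, "lm_head"), ToyHead)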
diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py
index 444ddc5a010322a912f6fe4f99db855a5385b1d7..87c54c2126c2456cf7ec73ec123f8b3050386570 100644
--- a/vllm_mindspore/model_executor/models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/qwen2.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# encoding: utf-8
+# type: ignore
+# isort:skip_file
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -15,21 +16,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Iterable
+from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple,
+ Union)
if TYPE_CHECKING:
from transformers import Qwen2Config
else:
Qwen2Config = None
-import numpy as np
-
-from mindspore import Parameter, Tensor, mint, nn, jit, ops, mutable
-from mindspore.common import dtype as mstype
+from mindspore import Parameter, Tensor, mint, nn
+from vllm.attention.backends.abstract import AttentionType
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.interfaces import SupportsLoRA
+from vllm.sequence import IntermediateTensors
from vllm_mindspore.attention import Attention
-
from vllm_mindspore.model_executor.layers.activation import SwiGLU
from vllm_mindspore.model_executor.layers.layernorm import RMSNorm
from vllm_mindspore.model_executor.layers.linear import (
@@ -43,27 +47,15 @@ from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm_mindspore.model_executor.model_loader.weight_utils import \
default_weight_loader
+from vllm_mindspore.model_executor.models.model_base import NativeModel
from vllm_mindspore.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
-from vllm_mindspore.model_executor.sampling_metadata import SamplingMetadata
-from vllm_mindspore.model_executor.models.model_base import MsModelBase, Fake_Attention, Fake_Attention_V1
-from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask
-from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
+from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.config import CacheConfig, VllmConfig
-import vllm.envs as envs
-from vllm.model_executor.layers.quantization import \
- QuantizationConfig
-from vllm.sequence import IntermediateTensors
-from vllm.attention.backends.abstract import AttentionType
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.forward_context import get_forward_context
-from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
-import mindspore as ms
-
class Qwen2MLP(nn.Cell):
+
def __init__(
self,
hidden_size: int,
@@ -79,23 +71,17 @@ class Qwen2MLP(nn.Cell):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
- prefix=f"{prefix}.gate_up_proj",
- params_dtype=mstype.bfloat16
- )
- self.down_proj = RowParallelLinear(
- input_size=intermediate_size,
- output_size=hidden_size,
- bias=bias,
- quant_config=quant_config,
- prefix=f"{prefix}.down_proj",
- params_dtype=mstype.bfloat16
- )
+ prefix=f"{prefix}.gate_up_proj")
+ self.down_proj = RowParallelLinear(input_size=intermediate_size,
+ output_size=hidden_size,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SwiGLU()
- @jit
def construct(self, x):
x, _ = self.gate_up_proj(x)
x = self.act_fn(x)
@@ -104,19 +90,18 @@ class Qwen2MLP(nn.Cell):
class Qwen2Attention(nn.Cell):
- def __init__(
- self,
- hidden_size: int,
- num_heads: int,
- num_kv_heads: int,
- max_position: int = 4096 * 32,
- rope_theta: float = 10000,
- cache_config: Optional[CacheConfig] = None,
- quant_config: Optional[QuantizationConfig] = None,
- rope_scaling: Optional[Tuple] = None,
- prefix: str = "",
- attn_type: str = AttentionType.DECODER
- ) -> None:
+
+ def __init__(self,
+ hidden_size: int,
+ num_heads: int,
+ num_kv_heads: int,
+ max_position: int = 4096 * 32,
+ rope_theta: float = 10000,
+ cache_config: Optional[CacheConfig] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ rope_scaling: Optional[Tuple] = None,
+ prefix: str = "",
+ attn_type: str = AttentionType.DECODER) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -147,7 +132,6 @@ class Qwen2Attention(nn.Cell):
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
- params_dtype=mstype.bfloat16,
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
@@ -155,7 +139,6 @@ class Qwen2Attention(nn.Cell):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
- params_dtype=mstype.bfloat16,
)
self.rotary_emb = get_rope(
@@ -164,20 +147,16 @@ class Qwen2Attention(nn.Cell):
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
- dtype=mstype.bfloat16,
- )
- self.attn = Attention(
- self.num_heads,
- self.head_dim,
- self.scaling,
- num_kv_heads=self.num_kv_heads,
- cache_config=cache_config,
- quant_config=quant_config,
- prefix=f"{prefix}.attn",
- attn_type=attn_type
)
+ self.attn = Attention(self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_kv_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ attn_type=attn_type)
- @jit
def construct(
self,
positions: Tensor,
@@ -192,10 +171,12 @@ class Qwen2Attention(nn.Cell):
block_tables: Tensor,
) -> Tensor:
qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1)
+ q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size),
+ -1)
q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill)
- attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, slot_mapping, attn_mask,
- batch_valid_length, q_seq_lens, block_tables)
+ attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill,
+ slot_mapping, attn_mask, batch_valid_length,
+ q_seq_lens, block_tables)
output, _ = self.o_proj(attn_output)
return output
@@ -243,14 +224,15 @@ class Qwen2DecoderLayer(nn.Cell):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
- self.input_layernorm = RMSNorm(config.hidden_size,
- eps=config.rms_norm_eps,
- params_dtype=mstype.bfloat16,)
- self.post_attention_layernorm = RMSNorm(config.hidden_size,
- eps=config.rms_norm_eps,
- params_dtype=mstype.bfloat16,)
-
- @jit
+ self.input_layernorm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size,
+ eps=config.rms_norm_eps,
+ )
+
def construct(
self,
positions: Tensor,
@@ -270,22 +252,16 @@ class Qwen2DecoderLayer(nn.Cell):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
- hidden_states = self.self_attn(
- positions,
- hidden_states,
- key_cache,
- value_cache,
- is_prefill,
- slot_mapping,
- attn_mask,
- batch_valid_length,
- q_seq_lens,
- block_tables
- )
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual)
+ hidden_states = self.self_attn(positions, hidden_states, key_cache,
+ value_cache, is_prefill, slot_mapping,
+ attn_mask, batch_valid_length,
+ q_seq_lens, block_tables)
# Fully Connected
- hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states, residual = self.post_attention_layernorm(
+ hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@@ -308,7 +284,6 @@ class Qwen2Model(nn.Cell):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
- params_dtype=mstype.bfloat16,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
@@ -328,15 +303,13 @@ class Qwen2Model(nn.Cell):
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps,
- params_dtype=mstype.bfloat16,)
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: Tensor) -> Tensor:
return self.embed_tokens(input_ids)
- @jit
def construct(
self,
input_ids: Optional[Tensor],
@@ -364,19 +337,12 @@ class Qwen2Model(nn.Cell):
         for i in range(self.start_layer, self.end_layer):  # layers are partitioned across pipeline-parallel (PP) stages
layer = self.layers[i]
- hidden_states, residual = layer(
- positions,
- hidden_states,
- key_caches[i - self.start_layer],
- value_caches[i - self.start_layer],
- is_prefill,
- slot_mapping,
- attn_mask,
- batch_valid_length,
- q_seq_lens,
- block_tables,
- residual
- )
+ hidden_states, residual = layer(positions, hidden_states,
+ key_caches[i - self.start_layer],
+ value_caches[i - self.start_layer],
+ is_prefill, slot_mapping,
+ attn_mask, batch_valid_length,
+ q_seq_lens, block_tables, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
@@ -385,7 +351,8 @@ class Qwen2Model(nn.Cell):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
- def load_weights(self, weights: Iterable[Tuple[str, Tensor]], params_dict: Dict[str, Parameter]):
+ def load_weights(self, weights: Iterable[Tuple[str, Tensor]],
+ params_dict: Dict[str, Parameter]):
loaded_params: Set[str] = set()
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -405,7 +372,7 @@ class Qwen2Model(nn.Cell):
# the checkpoint. Skip them.
continue
if (self.quant_config is not None and
- (scale_name := self.quant_config.get_cache_scale(name))):
+ (scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
@@ -436,7 +403,7 @@ class Qwen2Model(nn.Cell):
return loaded_params
-class Qwen2ForCausalLM(MsModelBase):
+class Qwen2ForCausalLM(NativeModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -478,184 +445,36 @@ class Qwen2ForCausalLM(MsModelBase):
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
- params_dtype=mstype.bfloat16,
quant_config=quant_config,
- prefix=maybe_prefix(prefix, "lm_head"))
- self.logits_processor = LogitsProcessor(config.vocab_size)
- self.sampler = get_sampler()
+ prefix=maybe_prefix(
+ prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
+ self.logits_processor = LogitsProcessor(self.config.vocab_size)
+ self.sampler = get_sampler()
+
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
- self.set_modules({"model": self.model, "lm_head": self.lm_head})
-
- self.prefill = True
- self.mstype = STR_DTYPE_TO_MS_DTYPE.get(self.model_config.dtype, self.model_config.dtype)
- self.casual_mask = LowerTriangularMask(dtype=self.mstype,
- max_model_len=self.model_config.max_model_len)
- self.set_model_inputs(self.prefill)
- if envs.VLLM_USE_V1:
- self.kv_caches = [Fake_Attention_V1() for i in range(config.num_hidden_layers)]
- else:
- self.kv_caches = [Fake_Attention() for i in range(config.num_hidden_layers)]
- compilation_config = vllm_config.compilation_config
-
- if prefix in compilation_config.static_forward_context:
- raise ValueError(f"Duplicate layer name: {prefix}")
- for i in range(config.num_hidden_layers):
- compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
-
- def set_model_inputs(self, is_prefill):
- dyn_input_ids = Tensor(shape=[None, None], dtype=mstype.int64)
- dyn_position_ids = Tensor(shape=[None], dtype=mstype.int64)
-
- block_size = self.cache_config.block_size
- num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
- head_size = self.model_config.get_head_size()
- kv_cache_shape = (None, block_size, num_kv_heads, head_size)
-
- kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \
- else self.cache_config.cache_dtype
- if kv_cache_dtype in STR_DTYPE_TO_MS_DTYPE:
- kv_cache_dtype = STR_DTYPE_TO_MS_DTYPE[kv_cache_dtype]
-
- num_layers = self.model_config.get_num_layers(self.parallel_config)
-
- dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)
- dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype)
- dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)])
- dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)])
-
- dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32)
- dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.mstype)
- dyn_batch_valid_length = Tensor(shape=[None,], dtype=mstype.int32)
- dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32)
- dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32)
- dyn_intermediate_tensors = None
- dyn_inputs_embeds = None
- self.model.set_inputs(
- dyn_input_ids,
- dyn_position_ids,
- dyn_key_caches,
- dyn_value_caches,
- is_prefill,
- dyn_slot_mapping,
- dynamic_attention_mask,
- dyn_batch_valid_length,
- dyn_q_seq_lens,
- dyn_block_tables,
- dyn_intermediate_tensors,
- dyn_inputs_embeds
- )
- def forward(
- self,
- input_ids: Tensor,
- positions: Tensor,
- intermediate_tensors: IntermediateTensors = None,
- inputs_embeds: Tensor = None,
- **kwargs
- ) -> Union[Tensor, IntermediateTensors]:
- key_cache, value_cache = self.get_kvcache()
- attn_metadata = get_forward_context().attn_metadata
- input_ids = input_ids.to(ms.int64)
- if attn_metadata is None:
- attn_metadata = self._dummy_attention_metadata(input_ids, positions)
- if not envs.VLLM_USE_V1:
- seq_lens = attn_metadata.seq_lens
- max_query_len = attn_metadata.max_query_len
- # When Mutli-Step is enabled with Chunked-Prefill, prefills and
- # decodes are scheduled together. In the first step, all the
- # prefills turn into decodes and max_query_len will be 1.
- if self.is_multi_step_chunked_prefill and max_query_len == 1:
- query_lens = [1] * len(seq_lens)
- else:
- query_lens = attn_metadata.query_lens
-
- seq_lens_np = np.array(seq_lens, dtype=np.int32)
- query_lens_np = np.array(query_lens, dtype=np.int32)
- kv_cache_lens = seq_lens_np - query_lens_np
- is_prefill = attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0
- slot_mapping = attn_metadata.slot_mapping
- batch_valid_length = Tensor.from_numpy(np.array(attn_metadata.seq_lens, dtype=np.int32))
- q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32)
- block_tables = attn_metadata.block_tables
- position_ids = ms.Tensor(positions, dtype=ms.int32)
- attn_mask = self.casual_mask.gen_attention_mask(is_prefill, position_ids, query_lens)
- else:
- if attn_metadata.max_context_lens == 0:
- is_prefill = True
- else:
- is_prefill = False
- slot_mapping = attn_metadata.slot_mapping
- batch_valid_length = Tensor.from_numpy(attn_metadata.seq_lens_np)
- q_seq_lens = attn_metadata.q_seq_lens
- block_tables = attn_metadata.block_tables
- query_lens_np = attn_metadata.q_seq_lens_np
- attn_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np)
- positions = positions.to(ms.int64)
- if is_prefill:
- input_ids = ops.expand_dims(input_ids, 0)
- if not self.prefill:
- self.prefill = True
- self.set_model_inputs(self.prefill)
- else:
- input_ids = ops.expand_dims(input_ids, 1)
- if self.prefill:
- self.prefill = False
- self.set_model_inputs(self.prefill)
- model_output = self.model(input_ids,
- positions,
- key_cache,
- value_cache,
- is_prefill,
- slot_mapping,
- attn_mask,
- batch_valid_length,
- q_seq_lens,
- block_tables,
- intermediate_tensors,
- inputs_embeds)
- if is_prefill:
- model_output = ops.squeeze(model_output, 0)
- else:
- model_output = ops.squeeze(model_output, 1)
- return model_output
-
- def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor) -> FlashAttentionMetadata:
- input_len = input_ids.shape[0]
- max_seq_len = ms.Tensor(input_len, dtype=ms.int32)
- seq_lengths = ms.Tensor([input_len], dtype=ms.int32)
- q_seq_lens = ms.Tensor([input_len], dtype=ms.int32)
- q_seq_lens_np = np.array([input_len], dtype=np.int32)
- seq_lens_np = np.array([input_len], dtype=np.int32)
-
- block_tables = ms.Tensor([[0]], dtype=ms.int32)
- slot_mapping = [-1 for _ in range(input_len)]
- slot_mapping = ms.Tensor(slot_mapping, dtype=ms.int32)
- return FlashAttentionMetadata(
- max_seq_len=max_seq_len,
- seq_lens=seq_lengths,
- seq_lens_np=seq_lens_np,
- block_tables=block_tables,
- slot_mapping=slot_mapping,
- q_seq_lens=q_seq_lens,
- q_seq_lens_np=q_seq_lens_np,
- context_lens=0,
- # To enforce prefill and decode are both complied in warmup process.
- # So set max_context_lens to 0 for prefill and 1 for decode.
- max_context_lens=0 if self.prefill else 1,
- query_start_loc = None
- )
+ self.common_preprocess(vllm_config, prefix)
+
+ def forward(self,
+ input_ids: Tensor,
+ positions: Tensor,
+ intermediate_tensors: IntermediateTensors = None,
+ inputs_embeds: Tensor = None,
+ **kwargs) -> Union[Tensor, IntermediateTensors]:
+ hidden_states = self.exec_model(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
+ return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
params_dict = self.get_params_dict()
self.model.load_weights(weights, params_dict)
- def sample(
- self, logits: Tensor, sampling_metadata: SamplingMetadata
- ) -> Optional[SamplerOutput]:
+ def sample(self, logits: Tensor,
+ sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
@@ -664,5 +483,6 @@ class Qwen2ForCausalLM(MsModelBase):
hidden_states: Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[Tensor]:
- logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
return logits
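As context for the mint.split call in Qwen2Attention.construct above, the fused QKV projection output is sliced into q/k/v blocks whose widths are num_heads * head_dim and num_kv_heads * head_dim. A NumPy sketch with hypothetical head counts (not taken from any particular Qwen2 checkpoint):

import numpy as np

num_heads, num_kv_heads, head_dim = 8, 2, 64
q_size = num_heads * head_dim        # 512 columns for the queries
kv_size = num_kv_heads * head_dim    # 128 columns each for keys and values

tokens = 4
qkv = np.zeros((tokens, q_size + 2 * kv_size), dtype=np.float32)

# Same slicing as mint.split(qkv, (q_size, kv_size, kv_size), -1).
q, k, v = np.split(qkv, [q_size, q_size + kv_size], axis=-1)
assert q.shape == (tokens, q_size)
assert k.shape == (tokens, kv_size)
assert v.shape == (tokens, kv_size)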
diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5a6bb44e8c52b293168675d356b6e256e81e60b
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py
@@ -0,0 +1,1078 @@
+# SPDX-License-Identifier: Apache-2.0
+# type: ignore
+# isort:skip_file
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_5_vl.py
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2025 The vLLM team.
+# Copyright 2025 The Qwen Team.
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
+import os
+from functools import partial
+from typing import Callable, Iterable, Mapping, Optional, Set, Tuple, Union, Dict, Any
+
+import math
+import mindspore as ms
+import mindspore.nn as nn
+import mindspore.mint as mint
+import mindspore.ops as ops
+import mindspore.mint.nn.functional as F
+from mindspore import dtype as mstype
+
+from vllm_mindspore.model_executor.layers.layernorm import RMSNorm
+from vllm_mindspore.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor
+from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm_mindspore.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm_mindspore.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm_mindspore.model_executor.models.model_base import NativeModel, AttentionWrapper
+from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal
+from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model # type: ignore[attr-defined]
+from vllm_mindspore.model_executor.models.utils import PPMissingLayer, WeightsMapper, maybe_prefix, \
+ merge_multimodal_embeddings
+from vllm_mindspore.model_executor.models.attention_mask import MultiModalLowerTriangularMask
+from vllm_mindspore.distributed.communication_op import AllGatherFromModelParallelRegion
+
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
+from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
+from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs, \
+ Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLVideoPixelInputs, \
+ Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLProcessingInfo
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.processing import PromptReplacement
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
+from vllm.distributed import utils as dist_utils
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import uses_mrope
+
+logger = init_logger(__name__)
+
+_ACTIVATION_REGISTRY = {"silu": F.silu}
+
+# === Vision Inputs === #
+
+
+class _Qwen2VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
+
+ def _get_prompt_replacements(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, Any],
+ out_mm_kwargs: MultiModalKwargs,
+ ) -> list[PromptReplacement]:
+ hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ image_processor = self.info.get_image_processor(
+ **hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+
+ placeholder = {
+ "image": vocab[hf_processor.image_token],
+ "video": vocab[hf_processor.video_token],
+ }
+
+ merge_length = image_processor.merge_size**2
+
+ def get_replacement_qwen2vl(item_idx: int, modality: str):
+ grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
+ assert isinstance(grid_thw, ms.Tensor)
+
+ num_tokens = int(grid_thw.prod()) // merge_length
+ return [placeholder[modality]] * num_tokens
+
+ return [
+ PromptReplacement(
+ modality=modality,
+ target=[placeholder[modality]],
+ replacement=partial(get_replacement_qwen2vl,
+ modality=modality),
+ ) for modality in ("image", "video")
+ ]
+
+
+# === Vision Encoder === #
+
+
+class Qwen2_5_VisionMLP(nn.Cell):
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = False,
+ act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.gate_proj = ColumnParallelLinear(in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_proj",
+ params_dtype=ms.bfloat16)
+ self.up_proj = ColumnParallelLinear(in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.up_proj",
+ params_dtype=ms.bfloat16)
+ self.down_proj = RowParallelLinear(hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj",
+ params_dtype=ms.bfloat16)
+ self.act_fn = act_fn
+
+ def construct(self, x: ms.Tensor):
+ x_gate, _ = self.gate_proj(x)
+ x_gate = self.act_fn(x_gate)
+ x_up, _ = self.up_proj(x)
+ x_down, _ = self.down_proj(x_gate * x_up)
+ return x_down
+
+
+def apply_rotary_pos_emb_flashatt(
+ q: ms.Tensor, k: ms.Tensor, cos: ms.Tensor,
+ sin: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor]:
+ q_embed = ops.rotary_position_embedding(q.float(), cos, sin).type_as(q)
+ k_embed = ops.rotary_position_embedding(k.float(), cos, sin).type_as(k)
+ return q_embed, k_embed
+
+
+class Qwen2_5_VisionAttention(nn.Cell):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads)
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, self.tp_size)
+ self.num_heads = num_heads
+ self.head_dim = self.hidden_size_per_attention_head
+
+ self.qkv = ColumnParallelLinear(input_size=embed_dim,
+ output_size=3 * projection_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv",
+ params_dtype=ms.bfloat16)
+ self.proj = RowParallelLinear(input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.proj",
+ params_dtype=ms.bfloat16)
+ self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion(
+ )
+
+ def split_tensor_along_last_dim(
+ self,
+ tensor: ms.Tensor,
+ num_partitions: int,
+ contiguous_split_chunks: bool = False,
+ ):
+ """ Split a tensor along its last dimension.
+
+ Arguments:
+ tensor: input tensor.
+ num_partitions: number of partitions to split the tensor
+ contiguous_split_chunks: If True, make each chunk contiguous
+ in memory.
+
+ Returns:
+ A list of Tensors
+ """
+ # Get the size and dimension.
+ last_dim = tensor.dim() - 1
+ last_dim_size = dist_utils.divide(tensor.shape[last_dim],
+ num_partitions)
+ # Split.
+ tensor_list = mint.split(tensor, last_dim_size, dim=last_dim)
+        # NOTE: mint.split, like torch.split, does not create contiguous
+        # tensors by default.
+
+ return tensor_list
+
+ def split_qkv(self, qkv: ms.Tensor) -> tuple[ms.Tensor, ...]:
+ # [s, 3 * head * head_dim]
+ seq_len, _ = qkv.shape
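+        # Under tensor parallelism the qkv projection is sharded along the
+        # output dim: gather the full projection, chunk it into q/k/v, then
+        # re-split so each rank keeps only its own heads.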
+ if self.tp_size > 1:
+ qkv = self.tensor_model_parallel_all_gather(qkv)
+
+ # [s, 3 * head * head_dim] -> 3 * [s, head * head_dim]
+ q, k, v = mint.chunk(qkv, 3, dim=-1)
+
+ # 3 * [s, head * head_dim]
+ if self.tp_size > 1:
+ splitter = partial(self.split_tensor_along_last_dim,
+ num_partitions=self.tp_size)
+ q = splitter(q)[self.tp_rank]
+ k = splitter(k)[self.tp_rank]
+ v = splitter(v)[self.tp_rank]
+
+ # 3 * [s, head * head_dim] -> 3 * [s, head, head_dim]
+ new_shape = (seq_len, self.num_attention_heads_per_partition,
+ self.hidden_size_per_attention_head)
+ q, k, v = (x.view(*new_shape) for x in (q, k, v))
+ return q, k, v
+
+ def construct(
+ self,
+ x: ms.Tensor,
+ cu_seqlens: ms.Tensor,
+ position_embeddings: Tuple[ms.Tensor, ms.Tensor],
+ ) -> ms.Tensor:
+ seq_length = x.shape[0]
+ x, _ = self.qkv(x)
+ q, k, v = self.split_qkv(x)
+
+ cos, sin = position_embeddings
+ q, k = apply_rotary_pos_emb_flashatt(mint.unsqueeze(q, 0),
+ mint.unsqueeze(k, 0), cos, sin)
+
+ q = mint.squeeze(q, 0)
+ k = mint.squeeze(k, 0)
+
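+        # Varlen flash attention over packed sequences: the "TND" layout is
+        # (total_tokens, num_heads, head_dim), with per-sample boundaries
+        # given by cu_seqlens.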
+ context_layer = ops.flash_attention_score(
+ q,
+ k,
+ v,
+ self.num_heads // self.tp_size,
+ actual_seq_qlen=cu_seqlens,
+ actual_seq_kvlen=cu_seqlens,
+ scalar_value=1 / math.sqrt(q.shape[-1]),
+ input_layout="TND",
+ ).reshape(seq_length, -1)
+ output, _ = self.proj(context_layer)
+ return output
+
+
+class Qwen2_5_VisionBlock(nn.Cell):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu,
+ norm_layer: Optional[Callable[[int], nn.Cell]] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(mint.nn.LayerNorm,
+ eps=1e-6,
+ dtype=ms.bfloat16)
+ self.norm1 = norm_layer(dim)
+ self.norm2 = norm_layer(dim)
+ self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+ self.mlp = Qwen2_5_VisionMLP(dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ def construct(
+ self, x: ms.Tensor, cu_seqlens: ms.Tensor,
+ position_embeddings: Tuple[ms.Tensor, ms.Tensor]) -> ms.Tensor:
+ x = x + self.attn(self.norm1(x),
+ cu_seqlens=cu_seqlens,
+ position_embeddings=position_embeddings)
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class Qwen2_5_VisionPatchEmbed(nn.Cell):
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ hidden_size: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.hidden_size = hidden_size
+ self.dtype = ms.bfloat16
+
+ self.proj = nn.Dense(temporal_patch_size * patch_size * patch_size *
+ in_channels,
+ self.hidden_size,
+ has_bias=False,
+ dtype=self.dtype)
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
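+        # x: (num_patches, temporal_patch_size * patch_size**2 * in_channels),
+        # i.e. patches already flattened by the image processor.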
+ x = self.proj(x) # B Ph*Pw C_out
+ return x
+
+
+class Qwen2_5_VisionPatchMerger(nn.Cell):
+
+ def __init__(
+ self,
+ d_model: int,
+ context_dim: int,
+ norm_layer: Optional[Callable[[int], nn.Cell]] = None,
+ spatial_merge_size: int = 2,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ if norm_layer is None:
+ norm_layer = partial(mint.nn.LayerNorm,
+ eps=1e-6,
+ dtype=ms.bfloat16)
+ self.ln_q = norm_layer(context_dim)
+ self.mlp = nn.CellList([
+ ColumnParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp.0",
+ params_dtype=ms.bfloat16),
+ nn.GELU(),
+ RowParallelLinear(self.hidden_size,
+ d_model,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp.2",
+ params_dtype=ms.bfloat16),
+ ])
+
+ def construct(self, x: ms.Tensor) -> ms.Tensor:
+ x = self.ln_q(x)
+ x = x.view(-1, self.hidden_size)
+
+ mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+ x_parallel, _ = mlp_fc1(x)
+ x_parallel = mlp_act(x_parallel)
+ out, _ = mlp_fc2(x_parallel)
+ return out
+
+
+class Qwen2_5_VisionRotaryEmbedding(nn.Cell):
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.inv_freq = 1.0 / (theta**(
+ mint.arange(0, dim, 2, dtype=ms.float32) / dim))
+ self._seq_len_cached = 0
+ self._freqs_cached = None
+
+ def update_freqs_cache(self, seqlen: int) -> None:
+ if seqlen > self._seq_len_cached:
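+            # Grow the cache geometrically so the freqs table is not rebuilt
+            # for every slightly longer sequence.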
+ seqlen *= 2
+ self._seq_len_cached = seqlen
+ self.inv_freq = 1.0 / (self.theta**(
+ mint.arange(0, self.dim, 2, dtype=ms.float32) / self.dim))
+ seq = mint.arange(seqlen, dtype=self.inv_freq.dtype)
+ freqs = mint.outer(seq, self.inv_freq)
+ self._freqs_cached = freqs
+
+ def construct(self, seqlen: int) -> ms.Tensor:
+ self.update_freqs_cache(seqlen)
+ return self._freqs_cached[:seqlen] # type: ignore[index]
+
+
+class Qwen2_5_VisionTransformer(nn.Cell):
+
+ def __init__(
+ self,
+ vision_config,
+ norm_eps: float = 1e-6,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ patch_size = vision_config.patch_size
+ temporal_patch_size = vision_config.temporal_patch_size
+ in_channels = vision_config.in_channels
+ depth = vision_config.depth
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_heads
+
+ # args for get_window_index
+ self.window_size = vision_config.window_size
+ self.patch_size = vision_config.patch_size
+ self.spatial_merge_size = vision_config.spatial_merge_size
+ self.fullatt_block_indexes = vision_config.fullatt_block_indexes
+ self.spatial_merge_unit = self.spatial_merge_size**2
+
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ in_channels=in_channels,
+ hidden_size=self.hidden_size,
+ )
+
+ norm_layer = partial(RMSNorm, eps=norm_eps, params_dtype=ms.bfloat16)
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.CellList([
+ Qwen2_5_VisionBlock(
+ dim=self.hidden_size,
+ num_heads=self.num_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.blocks.{layer_idx}")
+ for layer_idx in range(depth)
+ ])
+ self.merger = Qwen2_5_VisionPatchMerger(
+ d_model=vision_config.out_hidden_size,
+ context_dim=self.hidden_size,
+ norm_layer=norm_layer,
+ spatial_merge_size=self.spatial_merge_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.merger",
+ )
+ from mindspore.communication.management import get_rank
+ self.rank_id = get_rank()
+
+ def set_model_inputs(self):
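+        # Declare dynamic (None) shapes so the graph-mode compiled vision
+        # tower accepts variable-length inputs without recompilation.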
+ dyn_x = ms.Tensor(shape=[None, None], dtype=self.dtype)
+ dyn_rotary_pos_emb = ms.Tensor(shape=[None, None],
+ dtype=mstype.float32)
+ dyn_window_index = ms.Tensor(shape=[None], dtype=mstype.int64)
+ dyn_cu_window_seqlens = ms.Tensor(shape=[None], dtype=mstype.int64)
+ dyn_grid_thw = ms.Tensor(shape=[None, None], dtype=mstype.int64)
+
+ self.set_inputs(
+ dyn_x,
+ dyn_rotary_pos_emb,
+ dyn_window_index,
+ dyn_cu_window_seqlens,
+ dyn_grid_thw,
+ )
+
+ @property
+ def dtype(self) -> ms.Type:
+ return self.patch_embed.dtype
+
+ def construct(
+ self,
+ x: ms.Tensor,
+ rotary_pos_emb: ms.Tensor,
+ window_index: ms.Tensor,
+ cu_window_seqlens: ms.Tensor,
+ grid_thw: ms.Tensor,
+ ) -> ms.Tensor:
+ hidden_states = x.to(dtype=self.dtype)
+ hidden_states = self.patch_embed(hidden_states)
+
+ cu_window_seqlens = cu_window_seqlens.astype(ms.int32)
+ cu_window_seqlens = mint.unique_consecutive(cu_window_seqlens)
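+        # Reorder patch tokens (and their rotary embeddings) into window-major
+        # order so windowed attention can use contiguous cu_window_seqlens.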
+ seq_len, _ = hidden_states.shape
+ hidden_states = hidden_states.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index]
+ rotary_pos_emb = rotary_pos_emb.reshape(1, seq_len, 1, -1)
+ emb = mint.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+ position_embeddings = (mint.cos(emb), mint.sin(emb))
+
+ grid_thw_1 = grid_thw.index_select(1, ms.Tensor([1])).reshape(-1)
+ grid_thw_2 = grid_thw.index_select(1, ms.Tensor([2])).reshape(-1)
+ grid_thw_0 = grid_thw.index_select(1, ms.Tensor([0])).reshape(-1)
+ cu_seqlens = mint.cumsum(mint.repeat_interleave(
+ grid_thw_1 * grid_thw_2, grid_thw_0),
+ dim=0,
+ dtype=ms.int32)
+
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        # transformer blocks: full attention for fullatt_block_indexes,
+        # windowed attention otherwise
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+ hidden_states = blk(hidden_states,
+ cu_seqlens=cu_seqlens_now,
+ position_embeddings=position_embeddings)
+
+        # adapter: merge spatial patches and project to the output hidden size
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = mint.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices]
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[Tuple[str, ms.Tensor]],
+ params_dict: Dict[str, ms.Parameter]) -> Set[str]:
+ loaded_params: Set[str] = set()
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ]
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ if name == "visual.patch_embed.proj.weight":
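+                    # The checkpoint stores the patch embedding as a conv
+                    # weight; flatten it to match the Dense projection here.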
+ loaded_weight = loaded_weight.reshape(
+ loaded_weight.shape[0], -1)
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Qwen2_5_VLMultiModalProcessor(_Qwen2VLMultiModalProcessor):
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
+ second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+ )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Qwen2_5_VLMultiModalProcessor,
+ info=Qwen2_5_VLProcessingInfo,
+ dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
+class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ],
+ "gate_up_proj": [
+ "gate_proj",
+ "up_proj",
+ ],
+ }
+ # LoRA specific attributes
+ supported_lora_modules = [
+ # language model
+ "qkv_proj",
+ "o_proj",
+ "gate_up_proj",
+ "down_proj", # Same name with vision encoder
+ # vision tower
+ "qkv",
+ "gate_proj",
+ "up_proj",
+ "attn.proj", # Distinguish patch_embed.proj
+ "fc1",
+ "fc2",
+ # projector
+ "mlp.0",
+ "mlp.2"
+ ]
+
+ embedding_modules = {} # type: ignore[var-annotated]
+ embedding_padding_modules = [] # type: ignore[var-annotated]
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+ "lm_head.": "language_model.lm_head.",
+ "model.": "language_model.model.",
+ })
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ self.visual = Qwen2_5_VisionTransformer(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self._maybe_ignore_quant_config(quant_config),
+ prefix=maybe_prefix(prefix, "visual"),
+ )
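+        # In graph mode, compile the vision tower with JIT (level O0);
+        # otherwise run it eagerly.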
+ self.visual = ms.jit(
+ function=self.visual,
+ jit_level='O0') if self.is_graph_mode else self.visual
+
+ self.model = Qwen2Model(vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"))
+
+ if get_pp_group().is_last_rank:
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.lm_head = ParallelLMHead(config.vocab_size,
+ config.hidden_size,
+ params_dtype=ms.bfloat16,
+ quant_config=quant_config,
+ prefix=maybe_prefix(
+ prefix, "lm_head"))
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.sampler = get_sampler()
+ else:
+ self.lm_head = PPMissingLayer()
+
+ self.common_preprocess(vllm_config, prefix)
+ self.spatial_merge_size = config.vision_config.spatial_merge_size
+
+ self.window_size = config.vision_config.window_size
+ self.patch_size = config.vision_config.patch_size
+ self.spatial_merge_unit = self.spatial_merge_size**2
+ self.hidden_size = config.vision_config.hidden_size
+ self.num_heads = config.vision_config.num_heads
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+ if self.is_graph_mode:
+ self.visual.set_model_inputs()
+
+ def common_preprocess(self, vllm_config, prefix=""):
+ self.set_modules({
+ "visual": self.visual,
+ "model": self.model,
+ "lm_head": self.lm_head
+ })
+ self.casual_mask = MultiModalLowerTriangularMask(
+ dtype=self.model_config.dtype,
+ max_model_len=self.model_config.max_model_len)
+ self.kv_caches = [
+ AttentionWrapper() for i in range(self.config.num_hidden_layers)
+ ]
+
+ compilation_config = vllm_config.compilation_config
+ if prefix in compilation_config.static_forward_context:
+ raise ValueError(f"Duplicate layer name: {prefix}")
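+        # Register one attention wrapper per decoder layer so KV caches can be
+        # bound to layers by index at runtime.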
+ for i in range(self.config.num_hidden_layers):
+ compilation_config.static_forward_context[str(
+ i)] = self.kv_caches[i]
+
+ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules; however,
+        # AutoGPTQ seems to skip the vision encoder for some models.
+ # if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+ # return None
+ return quant_config
+
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
+ name: str) -> ms.Tensor:
+ if not isinstance(mm_input, (ms.Tensor, list)):
+ raise ValueError(f"Incorrect type of {name}. "
+ f"Got type: {type(mm_input)}")
+ if isinstance(mm_input, ms.Tensor):
+ if mm_input.ndim == 2:
+ return mm_input
+ if mm_input.ndim != 3:
+ raise ValueError(f"{name} should be 2D or batched 3D tensor. "
+ f"Got ndim: {mm_input.ndim} "
+ f"(shape={mm_input.shape})")
+ return mint.concat(list(mm_input))
+ else:
+ return mint.concat(mm_input)
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ if pixel_values is not None:
+ pixel_values = self._validate_and_reshape_mm_tensor(
+ pixel_values, "image pixel values")
+ image_grid_thw = self._validate_and_reshape_mm_tensor(
+ image_grid_thw, "image grid_thw")
+
+ if not isinstance(pixel_values, (ms.Tensor, list)):
+ raise ValueError("Incorrect type of image pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ return Qwen2_5_VLImagePixelInputs(type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw)
+
+ if image_embeds is not None:
+ image_embeds = self._validate_and_reshape_mm_tensor(
+ image_embeds, "image embeds")
+ image_grid_thw = self._validate_and_reshape_mm_tensor(
+ image_grid_thw, "image grid_thw")
+
+ if not isinstance(image_embeds, ms.Tensor):
+ raise ValueError("Incorrect type of image embeddings. "
+ f"Got type: {type(image_embeds)}")
+ return Qwen2_5_VLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw)
+
+ return None
+
+ def _parse_and_validate_video_input(
+ self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]:
+ pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+ video_embeds = kwargs.pop("video_embeds", None)
+ video_grid_thw = kwargs.pop("video_grid_thw", None)
+ second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
+
+ if pixel_values_videos is None and video_embeds is None:
+ return None
+
+ if pixel_values_videos is not None:
+ pixel_values_videos = self._validate_and_reshape_mm_tensor(
+ pixel_values_videos, "video pixel values")
+ video_grid_thw = self._validate_and_reshape_mm_tensor(
+ video_grid_thw, "video grid_thw")
+
+ return Qwen2_5_VLVideoPixelInputs(
+ type="pixel_values_videos",
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ )
+
+ if video_embeds is not None:
+ video_embeds = self._validate_and_reshape_mm_tensor(
+ video_embeds, "video embeds")
+ video_grid_thw = self._validate_and_reshape_mm_tensor(
+ video_grid_thw, "video grid_thw")
+
+ if not isinstance(video_embeds, ms.Tensor):
+ raise ValueError("Incorrect type of video embeddings. "
+ f"Got type: {type(video_embeds)}")
+ return Qwen2_5_VLVideoEmbeddingInputs(
+ type="video_embeds",
+ video_embeds=video_embeds,
+ video_grid_thw=video_grid_thw)
+
+ return None
+
+ def rot_pos_emb(self, grid_thw: ms.Tensor) -> ms.Tensor:
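+        # Build 2D (h, w) rotary position ids per patch, grouped by
+        # spatial-merge blocks, then look up the precomputed freqs table.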
+ pos_ids = []
+ for t, h, w in grid_thw:
+ t, h, w = t.item(), h.item(), w.item()
+ hpos_ids = mint.arange(h).unsqueeze(1).expand((-1, w))
+ wpos_ids = mint.arange(w).unsqueeze(0).expand((h, -1))
+
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ ).permute(0, 2, 1, 3).flatten()
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ ).permute(0, 2, 1, 3).flatten()
+ pos_ids.append(
+ mint.tile(mint.stack([hpos_ids, wpos_ids], dim=-1), (t, 1)))
+ pos_ids = mint.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max().item()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
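+        """Compute the window-major permutation of merged patches and the
+        cumulative window sequence lengths used by windowed attention."""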
+ window_index = []
+ cu_window_seqlens = [ms.Tensor([0])]
+ window_index_id = 0
+ vit_merger_window_size = (self.window_size //
+ self.spatial_merge_size // self.patch_size)
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ grid_t, grid_h, grid_w = grid_t.item(), grid_h.item(), grid_w.item(
+ )
+ llm_grid_h = grid_h // self.spatial_merge_size
+ llm_grid_w = grid_w // self.spatial_merge_size
+ index = mint.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+ grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+ index_padded = index_padded.reshape(grid_t, num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size)
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+ vit_merger_window_size)
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = mint.cumsum(
+ seqlens,
+ 0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1]
+ cu_window_seqlens.append(cu_seqlens_tmp)
+ window_index_id += grid_t * llm_grid_h * llm_grid_w
+ window_index = mint.cat(window_index, dim=0)
+ cu_window_seqlens = mint.cat(cu_window_seqlens, dim=0)
+ return window_index, cu_window_seqlens
+
+ def _process_image_input(
+ self, image_input: Qwen2_5_VLImageInputs) -> tuple[ms.Tensor, ...]:
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"].type(self.visual.dtype)
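+            # Temporarily disable the internal FlashAttentionScore kernel
+            # while running the vision tower (restored to "" below).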
+ os.environ[
+ "MS_DISABLE_INTERNAL_KERNELS_LIST"] = "FlashAttentionScore"
+ # compute position embedding
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ # windows attention
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ image_embeds = self.visual(pixel_values, rotary_pos_emb,
+ window_index, cu_window_seqlens,
+ grid_thw)
+ os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = ""
+
+ # Split concatenated embeddings for each image item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return image_embeds.split(sizes.tolist())
+
+ def _process_video_input(
+ self, video_input: Qwen2_5_VLVideoInputs) -> tuple[ms.Tensor, ...]:
+
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ if video_input["type"] == "video_embeds":
+ video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.visual.dtype)
+ os.environ[
+ "MS_DISABLE_INTERNAL_KERNELS_LIST"] = "FlashAttentionScore"
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+ # windows attention
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ video_embeds = self.visual(pixel_values_videos, rotary_pos_emb,
+ window_index, cu_window_seqlens,
+ grid_thw)
+ os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = ""
+
+ # Split concatenated embeddings for each video item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return video_embeds.split(sizes.tolist())
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+        # If there are multiple modalities, preserve their order based on the
+        # order of the kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
+ modalities["videos"] = self._parse_and_validate_video_input(
+ **kwargs)
+ return modalities
+
+ def get_multimodal_embeddings(self,
+ **kwargs) -> Optional[tuple[ms.Tensor, ...]]:
+
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
+ return None
+
+        # The resulting multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[ms.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "videos":
+ video_input = modalities["videos"]
+ video_embeddings = self._process_video_input(video_input)
+ multimodal_embeddings += video_embeddings
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: ms.Tensor,
+ multimodal_embeddings: Optional[tuple[ms.Tensor, ...]] = None,
+ ) -> ms.Tensor:
+ # input_ids = input_ids.to(mstype.int64)
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ [self.config.image_token_id, self.config.video_token_id])
+ os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = ""
+ return inputs_embeds
+
+ def get_input_embeddings_v0(
+ self,
+ input_ids: ms.Tensor,
+ image_input: Optional[tuple[ms.Tensor, ...]] = None,
+ video_input: Optional[tuple[ms.Tensor, ...]] = None,
+ ) -> ms.Tensor:
+ inputs_embeds = self.get_input_embeddings(input_ids)
+ if image_input is not None:
+ image_embeds = self._process_image_input(image_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ image_embeds,
+ placeholder_token_id=self.config.image_token_id,
+ )
+
+ if video_input is not None:
+ video_embeds = self._process_video_input(video_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ video_embeds,
+ placeholder_token_id=self.config.video_token_id,
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: ms.Tensor,
+ positions: ms.Tensor,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[ms.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[ms.Tensor, IntermediateTensors]:
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated by the model runner
+        # from `get_multimodal_embeddings` and `get_input_embeddings`, so this
+        # condition exists only for v0 compatibility.
+ elif inputs_embeds is None:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ video_input = self._parse_and_validate_video_input(**kwargs)
+
+ if image_input is None and video_input is None:
+ inputs_embeds = None
+ else:
+ if uses_mrope(self.config):
+ assert positions.ndim == 2 and positions.shape[0] == 3, (
+ "multimodal section rotary embedding requires "
+ f"(3, seq_len) positions, but got {positions.shape}")
+ inputs_embeds = self.get_input_embeddings_v0(
+ input_ids,
+ image_input=image_input,
+ video_input=video_input)
+ input_ids = None
+ hidden_states = self.exec_model(input_ids, positions,
+ intermediate_tensors, inputs_embeds)
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: ms.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[ms.Tensor]:
+ logits = self.logits_processor(self.lm_head, hidden_states,
+ sampling_metadata)
+ return logits
+
+ def sample(self, logits: ms.Tensor,
+ sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(
+ self, weights: Iterable[Tuple[str, ms.Tensor]]
+ ) -> None: # type: ignore[override]
+ params_dict = self.get_params_dict()
+ for name, weight in weights:
+ if "visual." in name:
+ self.visual.load_weights([(name, weight)], params_dict)
+ else:
+ self.model.load_weights([(name, weight)], params_dict)
+
+ return None
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="visual.",
+ tower_model="visual.merger.")
diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py
index bdc43d8bfe04029188dfb0723ea8df50dab2b3e2..009d84a06f124d270b88f26e98697ae4285cd7f5 100644
--- a/vllm_mindspore/model_executor/models/registry.py
+++ b/vllm_mindspore/model_executor/models/registry.py
@@ -25,16 +25,18 @@ from vllm.model_executor.models.registry import (_LazyRegisteredModel,
from vllm_mindspore.utils import (is_mindformers_model_backend,
is_mindone_model_backend)
-_MINDSPORE_MODELS = {
+_NATIVE_MODELS = {
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+ "Qwen2_5_VLForConditionalGeneration":
+ ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"),
}
_MINDFORMERS_MODELS = {
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
+ "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), # MCore
"DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepseekV3MTPForCausalLM"),
- "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
}
_MINDONE_MODELS = {
@@ -69,7 +71,7 @@ else:
module_name=f"vllm_mindspore.model_executor.models.{mod_relname}",
class_name=cls_name,
)
- for model_arch, (mod_relname, cls_name) in _MINDSPORE_MODELS.items()
+ for model_arch, (mod_relname, cls_name) in _NATIVE_MODELS.items()
}
MindSporeModelRegistry = _ModelRegistry(_registry_dict)
diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py
index 4bb7831c584c03b61bda9e0d751d32be934db19b..493664cda0f899646e8e2872c4cba3f05bba83a7 100644
--- a/vllm_mindspore/model_executor/models/utils.py
+++ b/vllm_mindspore/model_executor/models/utils.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# encoding: utf-8
+# type: ignore
+# isort:skip_file
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -17,18 +18,15 @@
# ============================================================================
from dataclasses import dataclass, field
-from typing import List, Tuple, Union, Mapping, Optional, Iterable
+from typing import Iterable, List, Mapping, Optional, Tuple, Union
+import mindspore as ms
+from mindspore import mint, ops
from vllm.sequence import IntermediateTensors
-from vllm_mindspore.multimodal.inputs import NestedTensors
+from vllm_mindspore.multimodal.inputs import NestedTensors # type: ignore[attr-defined]
from vllm_mindspore.utils import get_valid_dtype
-import mindspore as ms
-from mindspore import mint
-from mindspore import ops
-
-
WeightsMapping = Mapping[str, Optional[str]]
"""If a key maps to a value of `None`, the corresponding weight is ignored."""
@@ -72,6 +70,9 @@ class WeightsMapper:
if (out_name := self._map_name(name)) is not None)
+enforce_eager = False
+
+
class PPMissingLayer(ms.nn.Cell):
"""
A placeholder layer for missing layers in a pipeline parallel model.
@@ -118,9 +119,8 @@ def extract_layer_index(layer_name: str) -> int:
int_vals.append(int(subname))
except ValueError:
continue
- assert len(int_vals) == 1, (
- f"layer name {layer_name} should" " only contain one integer"
- )
+ assert len(int_vals) == 1, (f"layer name {layer_name} should"
+ " only contain one integer")
return int_vals[0]
@@ -135,17 +135,13 @@ def make_layers(
from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.utils import get_pp_indices
- start_layer, end_layer = get_pp_indices(
- num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size
- )
- modules = ms.nn.CellList(
- [PPMissingLayer() for _ in range(start_layer)]
- + [
- maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
- for idx in range(start_layer, end_layer)
- ]
- + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
- )
+ start_layer, end_layer = get_pp_indices(num_hidden_layers,
+ get_pp_group().rank_in_group,
+ get_pp_group().world_size)
+ modules = ms.nn.CellList([PPMissingLayer() for _ in range(start_layer)] + [
+ maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
+ for idx in range(start_layer, end_layer)
+ ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules
@@ -157,15 +153,17 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
device,
) -> IntermediateTensors:
dtype = get_valid_dtype(dtype)
- return IntermediateTensors(
- {key: mint.zeros((batch_size, hidden_size), dtype=dtype) for key in keys}
- )
+ return IntermediateTensors({
+ key: mint.zeros((batch_size, hidden_size), dtype=dtype)
+ for key in keys
+ })
return make_empty_intermediate_tensors
########################### for multi model ###########################
+
def _flatten_embeddings(embeddings: NestedTensors) -> ms.Tensor:
"""
Recursively flattens and concatenates NestedTensors on all but the last
@@ -251,8 +249,7 @@ def merge_multimodal_embeddings(
This updates ``inputs_embeds`` in place.
"""
if isinstance(placeholder_token_id, list):
- placeholder_token_id = ms.Tensor(placeholder_token_id,
- device=input_ids.device)
+ placeholder_token_id = ms.Tensor(placeholder_token_id)
return _merge_multimodal_embeddings(
inputs_embeds,
ms.numpy.isin(input_ids, placeholder_token_id),
@@ -263,4 +260,4 @@ def merge_multimodal_embeddings(
inputs_embeds,
(input_ids == placeholder_token_id),
multimodal_embeddings,
- )
\ No newline at end of file
+ )
diff --git a/vllm_mindspore/model_executor/sampling_metadata.py b/vllm_mindspore/model_executor/sampling_metadata.py
index a016dd8e14450b2dea8ca2a138e41578d4a40d08..416f4b7fe357c7e75c5abadfefefe364022e548e 100644
--- a/vllm_mindspore/model_executor/sampling_metadata.py
+++ b/vllm_mindspore/model_executor/sampling_metadata.py
@@ -18,14 +18,9 @@
from array import array
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List
-
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, SequenceGroupMetadata
from vllm.utils import (
- PyObjectCache,
- async_tensor_h2d,
is_pin_memory_available,
make_tensor_with_pad,
)
@@ -36,368 +31,6 @@ from mindspore import Tensor
import mindspore as ms
-@dataclass
-class SequenceGroupToSample:
- # |---------- N-1 iteration --------|
- # |---------------- N iteration ---------------------|
- # |- tokenA -|......................|-- newTokens ---|
- # |---------- context_len ----------|
- # |-------------------- seq_len ----------------------|
- # |-- query_len ---|
-
- # Sequence ids for the sequence group in a previous step.
- seq_ids: List[int]
- sampling_params: SamplingParams
- # seq_id -> sequence data.
- seq_data: Dict[int, SequenceData]
- # The length of the sequence (all tokens seen in the past + new token to
- # compute attention) of the sequence group. None if it is in a decode
- # stage.
- seq_len: Optional[int]
- # The length of new query tokens to compute in the current step. None if it
- # is in a decode stage. The length of query_len <= seq_len if chunked
- # prefill is enabled.
- query_len: Optional[int]
- # A random number generator for sampling.
- generator: Optional[ms.Generator]
- # True if the sequence group is in prefill stage. False if it is in a
- # decode stage.
- is_prompt: bool
- # Query token indices from logits. to compute prompt logprob. Empty if
- # prompt logprob is not required.
- prompt_logprob_indices: List[int]
- # Sample token indices from logits. Empty if sampling is not required.
- sample_indices: List[int]
-
- @property
- def do_sample(self):
- return len(self.sample_indices) > 0
-
- def __post_init__(self):
- if len(self.prompt_logprob_indices) > 0:
- assert self.sampling_params.prompt_logprobs is not None
- if self.is_prompt:
- assert self.seq_len is not None
- assert self.query_len is not None
-
-
-def gen_seq_group_to_sample_builder(num_seqs: int):
- return lambda: SequenceGroupToSample(
- seq_ids=[0] * num_seqs,
- sampling_params=None,
- seq_data=None, # type: ignore
- seq_len=0,
- query_len=0,
- generator=None,
- is_prompt=True,
- prompt_logprob_indices=[],
- sample_indices=[],
- )
-
-
-class SamplingMetadataCache:
- """Used to cache SamplingMetadata objects between scheduler iterations"""
-
- def __init__(self):
- self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
-
- def get_cached_seq_group_to_sample(self, num_seqs):
- if num_seqs not in self._seq_group_to_sample_cache:
- self._seq_group_to_sample_cache[num_seqs] = PyObjectCache(
- gen_seq_group_to_sample_builder(num_seqs)
- )
-
- obj = self._seq_group_to_sample_cache[num_seqs].get_object()
- return obj
-
- def reset(self):
- for cache in self._seq_group_to_sample_cache.values():
- cache.reset()
-
-
-class SamplingMetadata:
- """Metadata for input sequences. Used in sampler.
-
- The usage is as follow;
- ```
- hidden_states = execute_model(...)
- logits = hidden_states[sampling_metadata.selected_token_indices]
- sample(logits)
-
- def sample(logits):
- # Use categorized_sample_indices for sampling....
- ```
-
- Args:
- seq_groups: List of batched sequence groups.
- selected_token_indices: (num_query_tokens_to_logprob). Indices to find
- logits from the initial model output hidden states.
- categorized_sample_indices: SamplingType -> token indices to sample.
- Each token indices is 2D tensor of (num_indices, num_indices) where
- the first item means the sample index within the returned logit
- (before pruning padding), and the second item means the sample
- index after pruning using selected_token_indices.
- For example, if the returned logit is [1, 2, 3], and we select
- [1, 2] for sampling, the pruned logit will be [2, 3]. In this case,
- The first tuple is [1, 2] (sampled index within original logit),
- and the second tuple is [0, 1] (sampled index within pruned logit).
- num_prompts: Number of prompt sequence groups in seq_groups.
- skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
- serialization of token outputs.
- reuse_sampling_tensors: Indicates if we want to reuse sampling
- tensors that are part of the sampler forward pass. Currently,
- it is mainly used for multi-step decode.
-
- """
-
- def __init__(
- self,
- seq_groups: List[SequenceGroupToSample],
- selected_token_indices: Tensor,
- categorized_sample_indices: Dict[SamplingType, Tensor],
- num_prompts: int,
- skip_sampler_cpu_output: bool = False,
- reuse_sampling_tensors: bool = False,
- ) -> None:
- self.seq_groups = seq_groups
- self.selected_token_indices = selected_token_indices
- self.categorized_sample_indices = categorized_sample_indices
- self.num_prompts = num_prompts
- self.skip_sampler_cpu_output = skip_sampler_cpu_output
- self.reuse_sampling_tensors = reuse_sampling_tensors
-
- @staticmethod
- def prepare(
- seq_group_metadata_list: List[SequenceGroupMetadata],
- seq_lens: List[int],
- query_lens: List[int],
- device: str,
- pin_memory: bool,
- generators: Optional[Dict[str, ms.Generator]] = None,
- cache: Optional[SamplingMetadataCache] = None,
- ) -> "SamplingMetadata":
- (
- seq_groups,
- selected_token_indices,
- categorized_sample_indices,
- num_prompts,
- ) = _prepare_seq_groups(
- seq_group_metadata_list, seq_lens, query_lens, device, generators, cache
- )
- selected_token_indices = async_tensor_h2d(
- selected_token_indices,
- dtype=ms.int64,
- target_device=device,
- pin_memory=pin_memory,
- )
- categorized_sample_indices = {
- t: async_tensor_h2d(
- seq_ids,
- dtype=ms.int64,
- target_device=device,
- pin_memory=pin_memory,
- )
- for t, seq_ids in categorized_sample_indices.items()
- }
-
- sampling_metadata = SamplingMetadata(
- seq_groups=seq_groups,
- selected_token_indices=selected_token_indices,
- categorized_sample_indices=categorized_sample_indices,
- num_prompts=num_prompts,
- )
- return sampling_metadata
-
- def __repr__(self) -> str:
- return (
- "SamplingMetadata("
- f"seq_groups={self.seq_groups}, "
- f"selected_token_indices={self.selected_token_indices}, "
- f"categorized_sample_indices={self.categorized_sample_indices}), "
- )
-
-
-def _prepare_seq_groups(
- seq_group_metadata_list, #: List[SequenceGroupMetadata],
- seq_lens: List[int],
- query_lens: List[int],
- device: str,
- generators: Optional[Dict[str, ms.Generator]] = None,
- cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[
- List[SequenceGroupToSample],
- List[int],
- Dict[SamplingType, List[int]],
- int,
-]:
- """Prepare sequence groups and indices for sampling.
-
- Args:
- seq_group_metadata_list: A list of sequence group to batch.
- seq_lens: A list of sequence lens per sequence group.
- Index of prompt len should match with seq_group_metadata_list.
- query_lens: A list of query lengths. Prompt lens include the length
- of entire prompt tokens, and it could be shorter.
- device: A device to use for random number generators,
- `SequenceGroupToSample.generator`.
- generators: A store of per-request random number generators used
- for seeded requests.
-
- Returns:
- seq_groups: A list of sequence group to sample.
- selected_token_indices: See the definition from `SamplingMetadata`.
- categorized_sample_indices: See the definition from `SamplingMetadata`.
- num_prompts: Total number of prompts from `seq_group_metadata_list`.
- """
- # Batched sequence groups for the current model forward stsep.
- seq_groups: List[SequenceGroupToSample] = []
- # A list of token indices to sample/compute logprob. It is used to
- # prune the outcome logits from the model for the performance.
- selected_token_indices: List[int] = []
- # Used for selected_token_indices.
- model_output_idx = 0
-
- # Sampling type -> (
- # indices to sample/prompt logprob within pruned output logits,
- # indices to sample within pruned logits)
- categorized_sample_indices: Dict[SamplingType, List[int]] = {
- t: [] for t in SamplingType
- }
- # Index of logits to compute logprob. Logits include both prompt logprob
- # and sample logprob indices.
- logit_idx = 0
- # Total number of prompts from given sequence groups.
- num_prompts = 0
-
- for i, seq_group_metadata in enumerate(seq_group_metadata_list):
- seq_ids = seq_group_metadata.seq_data.keys()
-
- if cache is not None:
- sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids))
-
- for j, seq_id in enumerate(seq_ids):
- sample_obj.seq_ids[j] = seq_id
-
- sample_obj.prompt_logprob_indices.clear()
- sample_obj.sample_indices.clear()
-
- sampling_params = seq_group_metadata.sampling_params
- is_prompt = seq_group_metadata.is_prompt
- generator: Optional[ms.Generator] = None
- # If the current seq group is in decode stage, it is None.
- seq_len: Optional[int] = None
- query_len: Optional[int] = None
- prompt_logprob_indices: List[int] = (
- sample_obj.prompt_logprob_indices if cache is not None else []
- )
- sample_indices: List[int] = (
- sample_obj.sample_indices if cache is not None else []
- )
- do_sample = seq_group_metadata.do_sample
-
- if seq_group_metadata.is_prompt:
- if sampling_params.seed is not None:
- generator = ms.Generator().manual_seed(
- sampling_params.seed
- )
- if generators is not None:
- generators[seq_group_metadata.request_id] = generator
-
- num_prompts += 1
- num_prefill_sample = len(seq_ids)
- assert num_prefill_sample == 1
- assert query_lens is not None and seq_lens is not None
- query_len, seq_len = query_lens[i], seq_lens[i]
- # If we need sampling, exclude num_prefill_sample tokens from
- # prompt logprob.
- prompt_logprob_len = (
- query_len - num_prefill_sample if do_sample else query_len
- )
- sample_len = num_prefill_sample if do_sample else 0
- else:
- # Decode
- prompt_logprob_len = 0
- query_len = (
- query_lens[i] if query_lens is not None and len(query_lens) > 0 else 1
- )
- sample_len = len(seq_ids) * query_len if do_sample else 0
-
- if sampling_params.seed is not None and generators is not None:
- generator = generators.get(seq_group_metadata.request_id)
-
- # Update indices to select from the model output.
- """
- This blocks computes selected_token_indices which is used in the
- following way.
-
- hidden_states = model(...)
- logits = hidden_states[selected_token_indices]
- """
-
- if sampling_params.prompt_logprobs is not None:
- selected_token_indices.extend(
- range(model_output_idx, model_output_idx + prompt_logprob_len)
- )
- model_output_idx += prompt_logprob_len
- if do_sample:
- selected_token_indices.extend(
- range(model_output_idx, model_output_idx + sample_len)
- )
- model_output_idx += sample_len
-
- # We now find indices for logprob computation and sampling.
- """
- This block computes categorized_sample_indices which is used in the
- following way.
-
- hidden_states = model(...)
- logits = hidden_states[selected_token_indices]
- def sample(logits):
- # Use categorized_sample_indices for sampling.
- # prompt_logprob_indices to find prompt logprob indices.
- # sample_indices to find sample indices.
- """
-
- if sampling_params.prompt_logprobs is not None:
- prompt_logprob_indices.extend(
- range(logit_idx, logit_idx + prompt_logprob_len)
- )
- logit_idx += prompt_logprob_len
- if do_sample:
- sample_indices.extend(range(logit_idx, logit_idx + sample_len))
- categorized_sample_indices[sampling_params.sampling_type].extend(
- list(range(logit_idx, logit_idx + sample_len))
- )
- logit_idx += sample_len
-
- if cache is not None:
- sample_obj.sampling_params = sampling_params
- sample_obj.seq_data = seq_group_metadata.seq_data
- sample_obj.seq_len = seq_len
- sample_obj.query_len = query_len
- sample_obj.generator = generator
- sample_obj.is_prompt = is_prompt
- else:
- sample_obj = SequenceGroupToSample(
- seq_ids=list(seq_ids),
- sampling_params=sampling_params,
- seq_data=seq_group_metadata.seq_data,
- seq_len=seq_len,
- query_len=query_len,
- generator=generator,
- is_prompt=is_prompt,
- prompt_logprob_indices=list(prompt_logprob_indices),
- sample_indices=list(sample_indices),
- )
-
- seq_groups.append(sample_obj)
-
- if cache is not None:
- cache.reset()
-
- return (seq_groups, selected_token_indices, categorized_sample_indices, num_prompts)
-
-
@dataclass
class SamplingTensors:
"""Tensors for sampling."""
@@ -412,119 +45,6 @@ class SamplingTensors:
prompt_tokens: Tensor
output_tokens: Tensor
- @classmethod
- def from_sampling_metadata(
- cls,
- sampling_metadata: "SamplingMetadata",
- vocab_size: int,
- device, #: torch.device,
- dtype, #: torch.dtype,
- ) -> Tuple["SamplingTensors", bool, bool, bool]:
- prompt_tokens: List[array] = []
- output_tokens: List[array] = []
- top_ks: List[int] = []
- temperatures: List[float] = []
- top_ps: List[float] = []
- min_ps: List[float] = []
- presence_penalties: List[float] = []
- frequency_penalties: List[float] = []
- repetition_penalties: List[float] = []
- do_penalties = False
- do_top_p_top_k = False
- do_min_p = False
-
- assert sampling_metadata.seq_groups is not None
- for seq_group in sampling_metadata.seq_groups:
- seq_ids = seq_group.seq_ids
- sampling_params = seq_group.sampling_params
- temperature = sampling_params.temperature
- p = sampling_params.presence_penalty
- f = sampling_params.frequency_penalty
- r = sampling_params.repetition_penalty
- top_p = sampling_params.top_p
- min_p = sampling_params.min_p
-
- # k should not be greater than the vocab size.
- top_k = min(sampling_params.top_k, vocab_size)
- top_k = vocab_size if top_k == -1 else top_k
- if temperature < _SAMPLING_EPS:
- # NOTE: Zero temperature means deterministic sampling
- # (i.e., greedy sampling or beam search).
- # Set the temperature to 1 to avoid division by zero.
- temperature = 1.0
- if not do_top_p_top_k and (
- top_p < 1.0 - _SAMPLING_EPS or top_k != vocab_size
- ):
- do_top_p_top_k = True
- if not do_min_p and min_p > _SAMPLING_EPS:
- do_min_p = True
- if not do_penalties and (
- abs(p) >= _SAMPLING_EPS
- or abs(f) >= _SAMPLING_EPS
- or abs(r - 1.0) >= _SAMPLING_EPS
- ):
- do_penalties = True
-
- is_prompt = seq_group.is_prompt
- if is_prompt and sampling_params.prompt_logprobs is not None:
- # For tokens in the prompt that we only need to get
- # their logprobs
- query_len = seq_group.query_len
- assert query_len is not None
- prefill_len = len(seq_group.prompt_logprob_indices)
- temperatures += [temperature] * prefill_len
- top_ps += [top_p] * prefill_len
- top_ks += [top_k] * prefill_len
- min_ps += [min_p] * prefill_len
- presence_penalties += [0] * prefill_len
- frequency_penalties += [0] * prefill_len
- repetition_penalties += [1] * prefill_len
-
- if seq_group.do_sample:
- sample_lens = len(seq_group.sample_indices)
- assert sample_lens >= len(seq_ids)
- temperatures += [temperature] * sample_lens
- top_ps += [top_p] * sample_lens
- top_ks += [top_k] * sample_lens
- min_ps += [min_p] * sample_lens
- presence_penalties += [p] * sample_lens
- frequency_penalties += [f] * sample_lens
- repetition_penalties += [r] * sample_lens
-
- if do_penalties:
- for seq_group in sampling_metadata.seq_groups:
- seq_ids = seq_group.seq_ids
- sampling_params = seq_group.sampling_params
- if seq_group.is_prompt and sampling_params.prompt_logprobs is not None:
- prefill_len = len(seq_group.prompt_logprob_indices)
- prompt_tokens.extend(
- array(VLLM_TOKEN_ID_ARRAY_TYPE) for _ in range(prefill_len)
- )
- output_tokens.extend(
- array(VLLM_TOKEN_ID_ARRAY_TYPE) for _ in range(prefill_len)
- )
- if seq_group.do_sample:
- for seq_id in seq_ids:
- seq_data = seq_group.seq_data[seq_id]
- prompt_tokens.append(seq_data.prompt_token_ids_array)
- output_tokens.append(seq_data.output_token_ids_array)
-
- sampling_tensors = SamplingTensors.from_lists(
- temperatures,
- top_ps,
- top_ks,
- min_ps,
- presence_penalties,
- frequency_penalties,
- repetition_penalties,
- prompt_tokens,
- output_tokens,
- vocab_size,
- device,
- dtype,
- )
- return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
-
@classmethod
def from_lists(
cls,
@@ -600,6 +120,7 @@ class SamplingTensors:
# Because the memory is pinned, we can do non-blocking
# transfer to device.
+        # For MindSpore: transferring tensors to a target device is not
+        # supported yet.
return cls(
temperatures=temperatures_t,
top_ps=top_ps_t,
diff --git a/vllm_mindspore/multimodal/inputs.py b/vllm_mindspore/multimodal/inputs.py
index 2673ce6ea26646f88e7e2da957dc46074160a946..8bc9388545c2c1f36d9346992788a9c749b2573e 100644
--- a/vllm_mindspore/multimodal/inputs.py
+++ b/vllm_mindspore/multimodal/inputs.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# encoding: utf-8
+# type: ignore
+# isort:skip_file
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -15,22 +16,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
+from collections import defaultdict
+from dataclasses import dataclass
from typing import Union, cast
-
import mindspore
+from vllm.multimodal.inputs import BaseMultiModalField, BatchedTensorInputs, JSONTree, json_map_leaves,\
+ nested_tensors_equal
+from vllm.multimodal import MultiModalKwargs
+
+NestedTensors = Union[list["NestedTensors"], list[mindspore.Tensor],
+ mindspore.Tensor, tuple[mindspore.Tensor, ...]]
+
+
+@dataclass
+class MultiModalFieldElem:
+ """
+ Represents a keyword argument corresponding to a multi-modal item
+ in :class:`MultiModalKwargs`.
+ """
+
+ modality: str
+ """
+ The modality of the corresponding multi-modal item.
+ Each multi-modal item can consist of multiple keyword arguments.
+ """
-from vllm.multimodal.inputs import BatchedTensorInputs, JSONTree, json_map_leaves
+ key: str
+ """
+ The key of this field in :class:`MultiModalKwargs`,
+ i.e. the name of the keyword argument to be passed to the model.
+ """
+ data: NestedTensors
+ """
+ The tensor data of this field in :class:`MultiModalKwargs`,
+ i.e. the value of the keyword argument to be passed to the model.
+ """
-NestedTensors = Union[list["NestedTensors"], list[mindspore.Tensor], mindspore.Tensor,
- tuple[mindspore.Tensor, ...]]
+ field: "BaseMultiModalField"
+ """
+ Defines how to combine the tensor data of this field with others
+ in order to batch multi-modal items together for model inference.
+ """
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, self.__class__):
+ return False
+
+ return ((self.modality, self.key) == (other.modality, other.key)
+ and nested_tensors_equal(self.data, other.data)
+ and type(self.field) == type(other.field)) # noqa: E721
-@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
- device = None,
+ device=None,
) -> BatchedTensorInputs:
# replace as_kwargs of vLLM for multi-model
json_inputs = cast(JSONTree[mindspore.Tensor], batched_inputs)
@@ -40,4 +81,20 @@ def as_kwargs(
json_inputs,
)
- return cast(BatchedTensorInputs, json_mapped)
\ No newline at end of file
+ return cast(BatchedTensorInputs, json_mapped)
+
+
+def from_items(items):
+ """Construct a new :class:`MultiModalKwargs` from multiple items."""
+ elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+ for item in items:
+ for key, elem in item.items():
+            # Convert elem.data to a MindSpore tensor (on the GPU path it is
+            # already a tensor).
+ elem.data = mindspore.Tensor(elem.data)
+ elems_by_key[key].append(elem)
+ data = {
+ key: elems[0].field.reduce_data(elems)
+ for key, elems in elems_by_key.items() if len(elems) > 0
+ }
+
+ return MultiModalKwargs(data, items=items)
diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py
index 356a33a040c050b0825a1c2fe5fea2179fbafa60..43d5d1773a6312f757f196e4e2e9974c0bff3a32 100644
--- a/vllm_mindspore/platforms/ascend.py
+++ b/vllm_mindspore/platforms/ascend.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python3
-# encoding: utf-8
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -17,15 +16,12 @@
# ============================================================================
"""Ascend platform."""
-import os
-from typing import (TYPE_CHECKING, Optional, Union, Tuple)
+from typing import TYPE_CHECKING, Optional, Tuple, Union
import torch
-import mindspore as ms
-
-from vllm.platforms.interface import DeviceCapability, Platform, PlatformEnum, _Backend
-from vllm.logger import init_logger
import vllm.envs as envs
+from vllm.logger import init_logger
+from vllm.platforms.interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
@@ -40,7 +36,7 @@ class AscendPlatform(Platform):
_enum = PlatformEnum.OOT
device_name: str = "npu"
- device_type: str = "cuda" # To use cuda worker, executor...
+ device_type: str = "cuda" # To use cuda worker, executor...
simple_compile_backend: str = "npu"
ray_device_key: str = "NPU"
device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
@@ -69,29 +65,34 @@ class AscendPlatform(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
- """
- Check and update the configuration for the current platform.
-
- It can raise an exception if the configuration is not compatible with
- the current platform, or it can update the configuration to make it
- compatible with the current platform.
-
- The config is passed by reference, so it can be modified in place.
- """
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
+ compilation_config = vllm_config.compilation_config
+ model_config = vllm_config.model_config
- import vllm.envs as envs
- if envs.VLLM_USE_V1:
- parallel_config.worker_cls = \
- "vllm.v1.worker.gpu_worker.Worker"
- else:
- if parallel_config.worker_cls == "auto":
- if scheduler_config.is_multi_step:
- parallel_config.worker_cls = "vllm.worker.multi_step_worker.MultiStepWorker"
- elif vllm_config.speculative_config:
- parallel_config.worker_cls = "vllm.spec_decode.spec_decode_worker.create_spec_worker"
- parallel_config.sd_worker_cls = "vllm.worker.worker.Worker"
+ if parallel_config.worker_cls == "auto":
+ if scheduler_config.is_multi_step:
+ if envs.VLLM_USE_V1:
+ raise NotImplementedError(
+ "Multi-step scheduling is not supported (and not "
+ "needed) on vLLM V1. Please launch without "
+ "--num-scheduler-steps.")
+ else:
+ parallel_config.worker_cls = \
+ "vllm.worker.multi_step_worker.MultiStepWorker"
+ elif vllm_config.speculative_config:
+ if envs.VLLM_USE_V1:
+ parallel_config.worker_cls = \
+ "vllm.v1.worker.gpu_worker.Worker"
+ else:
+ parallel_config.worker_cls = \
+ "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+ parallel_config.sd_worker_cls = \
+ "vllm.worker.worker.Worker"
+ else:
+ if envs.VLLM_USE_V1:
+ parallel_config.worker_cls = \
+ "vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = "vllm.worker.worker.Worker"
@@ -103,12 +104,13 @@ class AscendPlatform(Platform):
model_config.disable_cascade_attn = True
@classmethod
- def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla):
+ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
+ kv_cache_dtype, block_size, use_v1, use_mla):
"""Get the attention backend class of a device."""
if use_v1:
if use_mla:
- return "vllm_mindspore.v1.attention.backends.flash_attn.MLABackend"
- return "vllm_mindspore.v1.attention.backends.flash_attn.FlashAttentionBackend"
+ return "vllm_mindspore.v1.attention.backends.ms_attn.MLABackend"
+ return "vllm_mindspore.v1.attention.backends.ms_attn.MsAttentionBackend"
raise RuntimeError("vLLM-MindSpore do not support v1 egine now!")
if use_mla:
logger.info("Using MindSpore MLA backend.")
@@ -119,12 +121,13 @@ class AscendPlatform(Platform):
return "vllm_mindspore.attention.backends.ms_attn.MsAttentionBackend"
raise ValueError(
- "Invaild attention backend %s for vLLM-MindSpore with head_size: %s, dtype: %s, kv_cache_dtype: %s, block_size: %s."
- % (str(selected_backend), str(head_size), str(dtype), str(kv_cache_dtype), str(block_size))
+ f"Invalid attention backend {str(selected_backend)} for vLLM-MindSpore with head_size: {str(head_size)}, dtype: {str(dtype)}, kv_cache_dtype: {str(kv_cache_dtype)}, block_size: {str(block_size)}."
)
@classmethod
- def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None) -> float:
+ def get_current_memory_usage(cls,
+ device: Optional[torch.types.Device] = None
+ ) -> float:
"""Return the memory usage in bytes."""
torch.cuda.reset_peak_memory_stats()
return torch.cuda.max_memory_allocated(device)
@@ -144,4 +147,7 @@ class AscendPlatform(Platform):
@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
- return True
\ No newline at end of file
+ return True
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm_mindspore.lora.punica_wrapper.punica_npu.PunicaWrapperNPU"
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index 60cd4af040fc9e9bda617eb8b6cd5d5130ef765f..920bb23066966c9de2ca8058f69567fe61f9aad5 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -19,8 +19,8 @@ import contextlib
import gc
import os
import sys
-from typing import (TYPE_CHECKING, Callable, Generator, List, Optional, Tuple,
- Union)
+from enum import Enum
+from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -30,11 +30,10 @@ if TYPE_CHECKING:
else:
Library = None
-from vllm.logger import init_logger
-
import mindspore as ms
from mindspore import dtype as mstype
from mindspore.common.initializer import Zero
+from vllm.logger import init_logger
from vllm.utils import (TORCH_DTYPE_TO_NUMPY_DTYPE, MemoryProfilingResult,
MemorySnapshot, T, make_ndarray_with_pad)
@@ -61,17 +60,6 @@ def get_valid_dtype(dtype):
return dtype
-def direct_register_custom_op(
- op_name: str,
- op_func: Callable,
- mutates_args: List[str],
- fake_impl: Optional[Callable] = None,
- target_lib: Optional[Library] = None,
- dispatch_key: str = "CUDA",
-):
- ...
-
-
def _create_empty_tensor(ms_type):
init_func = Zero()
init_func.__enable_zero_dim__ = True
@@ -153,48 +141,10 @@ STR_DTYPE_TO_MS_DTYPE = {
}
-def get_dtype_size(dtype: torch.dtype) -> int:
- """Get the size of the data type in bytes."""
- if isinstance(dtype, str):
- dtype = STR_DTYPE_TO_TENSOR_DTYPE[dtype]
- return torch.tensor([1], dtype=dtype).itemsize
-
-
-def ascend_device_count_stateless() -> int:
- visible_device_str = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", None)
- if visible_device_str:
- try:
- res = visible_device_str.split(",")
- except Exception as e:
- logger.error('Cannot parse "ASCEND_RT_VISIBLE_DEVICES" for: %s!',
- str(e))
- raise ValueError(
- f'Error argument({visible_device_str}) of environ "ASCEND_RT_VISIBLE_DEVICES"!'
- ) from e
-
- return len(res)
-
- import re
- import subprocess
-
- output = subprocess.check_output(["npu-smi", "info"], encoding="utf-8")
- res = re.findall(
- r"\|\s+\d+\s+\w+\s+\|\s+(\w+)\s+\|\s+(?:[0-9\.]+|-)\s+[0-9\.]+\s+\d+\s+\/\s+\d+\s+\|",
- output,
- )
-
- avl_devices = []
- for i, stat in enumerate(res):
- if stat != "OK":
- logger.warning("Device %d is not ok, status is %s!", i, stat)
- else:
- avl_devices.append(str(i))
- visible_device_str = ",".join(avl_devices)
- os.environ["ASCEND_RT_VISIBLE_DEVICES"] = visible_device_str
- logger.info('Set environ "ASCEND_RT_VISIBLE_DEVICES" as %s',
- visible_device_str)
-
- return len(avl_devices)
+class vllmModelBackendEnum(str, Enum):
+ """Define the variable Enum of vLLM_MODEL_BACKEND"""
+ MF = 'MindFormers'
+ MIND_ONE = 'MindONE'
def ascend_is_initialized():
@@ -203,23 +153,29 @@ def ascend_is_initialized():
def is_mindformers_model_backend():
- return (os.getenv("vLLM_MODEL_BACKEND") # noqa: SIM112
- and
- os.environ["vLLM_MODEL_BACKEND"] == "MindFormers" # noqa: SIM112
- )
+ vllm_model_backend = os.getenv("vLLM_MODEL_BACKEND") # noqa: SIM112
+ if vllm_model_backend:
+ try:
+ vllmModelBackendEnum(vllm_model_backend)
+ return vllm_model_backend == vllmModelBackendEnum.MF
+ except ValueError as exc:
+ allowed_values = [member.value for member in vllmModelBackendEnum]
+ raise ValueError(
+ f"Illegal value of vLLM_MODEL_BACKEND '{vllm_model_backend}',"
+ f" allowed_values: {', '.join(allowed_values)}") from exc
+ else:
+ return False
def is_mindone_model_backend():
return (os.getenv("vLLM_MODEL_BACKEND") # noqa: SIM112
- and os.environ["vLLM_MODEL_BACKEND"] == "MindONE" # noqa: SIM112
- )
+ and os.environ["vLLM_MODEL_BACKEND"] # noqa: SIM112
+ == vllmModelBackendEnum.MIND_ONE)
def check_ready():
- import vllm.envs as envs
from mindspore import set_context
-
# Common environment variables of predict.
set_context(jit_config={"jit_level": "O0", "infer_boost": "on"})
default_env = {
@@ -234,15 +190,6 @@ def check_ready():
if is_mindformers_model_backend():
logger.info("Run with Mindformers backend!")
- necessary_envs = ("MINDFORMERS_MODEL_CONFIG", )
- lost_envs = [
- env_item for env_item in necessary_envs if not os.getenv(env_item)
- ]
-
- if lost_envs:
- raise RuntimeError(
- f'For "MindFormers" model backend, environments {str(lost_envs)} should be set!'
- )
elif is_mindone_model_backend():
logger.info("Run with MindONE backend!")
else:
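The `vllmModelBackendEnum`-based check above turns an unknown `vLLM_MODEL_BACKEND` value into an explicit error instead of silently treating it as "not MindFormers". A standalone sketch of that validation, duplicating the enum locally so the snippet runs without vllm_mindspore installed:

```python
# Standalone sketch of the vLLM_MODEL_BACKEND validation shown above.
import os
from enum import Enum


class vllmModelBackendEnum(str, Enum):
    MF = 'MindFormers'
    MIND_ONE = 'MindONE'


def is_mindformers_backend() -> bool:
    value = os.getenv("vLLM_MODEL_BACKEND")
    if not value:
        return False
    try:
        return vllmModelBackendEnum(value) is vllmModelBackendEnum.MF
    except ValueError as exc:
        allowed = ", ".join(m.value for m in vllmModelBackendEnum)
        raise ValueError(
            f"Illegal value of vLLM_MODEL_BACKEND '{value}', "
            f"allowed_values: {allowed}") from exc


os.environ["vLLM_MODEL_BACKEND"] = "MindFormers"
assert is_mindformers_backend()

os.environ["vLLM_MODEL_BACKEND"] = "SomethingElse"   # unsupported value
try:
    is_mindformers_backend()
except ValueError as err:
    print(err)
```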
diff --git a/vllm_mindspore/v1/attention/backends/flash_attn.py b/vllm_mindspore/v1/attention/backends/ms_attn.py
similarity index 64%
rename from vllm_mindspore/v1/attention/backends/flash_attn.py
rename to vllm_mindspore/v1/attention/backends/ms_attn.py
index b5c5629ee51fc7faf969f18a7b596e60d939387f..e8a94035863bc7a872d1245193a948761a6bd6d7 100644
--- a/vllm_mindspore/v1/attention/backends/flash_attn.py
+++ b/vllm_mindspore/v1/attention/backends/ms_attn.py
@@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import numpy as np
-import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
@@ -21,7 +20,7 @@ from mindspore._c_expression import swap_cache
logger = init_logger(__name__)
-class FlashAttentionBackend(AttentionBackend):
+class MsAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -39,11 +38,11 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
- return FlashAttentionMetadata
+ return MsAttentionMetadata
@staticmethod
- def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
- return FlashAttentionMetadataBuilder
+ def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]:
+ return MsAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
@@ -72,11 +71,11 @@ class MLABackend(AttentionBackend):
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
- return FlashAttentionMetadata
+ return MsAttentionMetadata
@staticmethod
- def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
- return FlashAttentionMetadataBuilder
+ def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]:
+ return MsAttentionMetadataBuilder
@staticmethod
def get_kv_cache_shape(
@@ -98,8 +97,12 @@ class MLABackend(AttentionBackend):
return [576]
+
@dataclass
-class FlashAttentionMetadata:
+class MsAttentionMetadata:
+ """
+ AttentionMetadata for vllm-mindspore V1
+ """
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
@@ -108,47 +111,36 @@ class FlashAttentionMetadata:
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
- max_seq_len: int
- seq_lens: torch.Tensor
+ # add by vllm-mindspore begin
seq_lens_np: np.ndarray
- block_tables: torch.Tensor
- slot_mapping: torch.Tensor
- q_seq_lens: torch.Tensor
+ block_tables: ms.Tensor
q_seq_lens_np: np.ndarray
- context_lens: torch.Tensor
+ context_lens: ms.Tensor
max_context_lens: int
- query_start_loc: torch.Tensor
-
- def __getitem__(self, key):
- if key == "batch_valid_length":
- key = "seq_lens"
- return getattr(self, key)
-
-
-class MsAttentionImpl(AttentionImpl):
- """
- If the input tensors contain prompt tokens, the layout is as follows:
- |<--------------- num_prefill_tokens ----------------->|
- |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
+ # add by vllm-mindspore end
- Otherwise, the layout is as follows:
- |<----------------- num_decode_tokens ------------------>|
- |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
-
- Generation tokens can contain padding when cuda-graph is used.
- Currently, prompt tokens don't contain any padding.
+ #num_actual_tokens: int = None # Number of tokens excluding padding.
+ #max_query_len: int
+ query_start_loc: ms.Tensor
+ max_seq_len: int
+ seq_lens: ms.Tensor
+ #block_table: torch.Tensor
+ slot_mapping: ms.Tensor
- The prompts might have different lengths, while the generation tokens
- always have length 1.
+ # For cascade attention.
+ #use_cascade: bool
+ #common_prefix_len: int
+ #cu_prefix_query_lens: Optional[torch.Tensor]
+ #prefix_kv_lens: Optional[torch.Tensor]
+ #suffix_kv_lens: Optional[torch.Tensor]
- If chunked prefill is enabled, prefill tokens and decode tokens can be
- batched together in a flattened 1D query.
+ # For logging.
+ num_input_tokens: int = 0 # Number of tokens including padding.
- |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
- |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
- Currently, cuda graph is disabled for chunked prefill, meaning there's no
- padding between prefill and decode tokens.
+class MsAttentionImpl(AttentionImpl):
+ """
+ AttentionImpl for vllm-mindspore V1
"""
def __init__(
@@ -168,31 +160,20 @@ class MsAttentionImpl(AttentionImpl):
def forward(
self,
- layer: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata: FlashAttentionMetadata,
- output: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- """Forward pass with FlashAttention.
-
- Args:
- query: shape = [num_tokens, num_heads, head_size]
- key: shape = [num_tokens, num_kv_heads, head_size]
- value: shape = [num_tokens, num_kv_heads, head_size]
- output: shape = [num_tokens, num_heads, head_size]
- kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
- NOTE: kv_cache will be an empty tensor with shape [0]
- for profiling run.
- attn_metadata: Metadata for attention.
- NOTE: It in-place updates the output tensor.
+ layer: ms.nn.Cell,
+ query: ms.Tensor,
+ key: ms.Tensor,
+ value: ms.Tensor,
+ kv_cache: ms.Tensor,
+ attn_metadata: MsAttentionMetadata,
+ output: Optional[ms.Tensor] = None,
+ ) -> ms.Tensor:
+ """Forward pass with MsAttention.
"""
pass
-class FlashAttentionMetadataBuilder:
+class MsAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner"):
self.runner = runner
@@ -213,14 +194,12 @@ class FlashAttentionMetadataBuilder:
context_lens = ms.from_numpy(self.runner.input_batch.num_computed_tokens_cpu[:num_reqs])
q_seq_lens_np = np.diff(self.runner.query_start_loc_np[:num_reqs + 1])
- q_seq_lens = ms.from_numpy(q_seq_lens_np)
- attn_metadata = FlashAttentionMetadata(
+ attn_metadata = MsAttentionMetadata(
seq_lens=seq_lens,
seq_lens_np=seq_lens_np,
block_tables=(self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
- q_seq_lens=q_seq_lens,
q_seq_lens_np=q_seq_lens_np,
max_seq_len=max_seq_len,
context_lens=context_lens,
@@ -228,3 +207,5 @@ class FlashAttentionMetadataBuilder:
query_start_loc = query_start_loc
)
return attn_metadata
+
+FlashAttentionMetadata = MsAttentionMetadata
\ No newline at end of file
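In `MsAttentionMetadataBuilder.build` above, the per-request query lengths are derived as the first difference of the cumulative `query_start_loc` array. A tiny numpy illustration of that step (the example values are illustrative):

```python
# Sketch of how q_seq_lens_np is derived from query_start_loc_np,
# as in MsAttentionMetadataBuilder.build above.
import numpy as np

# query_start_loc for 3 requests with 2, 5 and 3 scheduled tokens.
query_start_loc_np = np.array([0, 2, 7, 10])
q_seq_lens_np = np.diff(query_start_loc_np)
print(q_seq_lens_np)  # [2 5 3]
```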
diff --git a/vllm_mindspore/v1/core/__init__.py b/vllm_mindspore/v1/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/v1/core/sched/__init__.py b/vllm_mindspore/v1/core/sched/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/v1/core/sched/scheduler.py b/vllm_mindspore/v1/core/sched/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f7a6b90df6ba454e69f5e2b85e06c9c5a5674d
--- /dev/null
+++ b/vllm_mindspore/v1/core/sched/scheduler.py
@@ -0,0 +1,177 @@
+# ruff: noqa: G004
+
+from typing import Optional
+
+from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.core.sched.utils import check_stop
+from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs, FinishReason
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.spec_decode.metrics import SpecDecodingStats
+
+logger = init_logger(__name__)
+
+
+def update_from_output(
+ self,
+ scheduler_output: SchedulerOutput,
+ model_runner_output: ModelRunnerOutput,
+) -> EngineCoreOutputs:
+ sampled_token_ids = model_runner_output.sampled_token_ids
+ spec_token_ids = model_runner_output.spec_token_ids
+ logprobs = model_runner_output.logprobs
+ prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+ num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+
+ new_running: list[Request] = []
+ outputs: list[EngineCoreOutput] = []
+ spec_decoding_stats: Optional[SpecDecodingStats] = None
+
+ # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
+ # loop can be a performance bottleneck. We should do our best to avoid
+ # expensive operations inside the loop.
+
+ # Add by vllm-mindspore begin:
+ running_req_ids = [req.request_id for req in self.running]
+    # abort_req_ids tracks requests that failed due to a model execution exception.
+ abort_req_ids: list[str] = []
+ # Add by vllm-mindspore end.
+
+ for request in self.running:
+ req_id = request.request_id
+
+ # Add by vllm-mindspore begin:
+        # A None sampled_token_ids means model execution raised an exception;
+        # move such requests to the abort list so the scheduler keeps running.
+ if sampled_token_ids is None:
+ self.scheduled_req_ids.remove(req_id)
+ logger.warning(
+ f'Process aborted request {req_id} from running requests {running_req_ids}'
+ )
+ outputs.append(
+ EngineCoreOutput(request_id=req_id,
+ new_token_ids=[],
+ finish_reason=FinishReason.ABORT,
+ new_logprobs=None,
+ new_prompt_logprobs_tensors=None,
+ stop_reason=request.stop_reason,
+ events=request.take_events()))
+ abort_req_ids.append(req_id)
+ continue
+ # Add by vllm-mindspore end.
+
+ num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
+ if num_tokens_scheduled == 0:
+ # The request was not scheduled in this step.
+ new_running.append(request)
+ continue
+
+ req_index = model_runner_output.req_id_to_index[req_id]
+ generated_token_ids = sampled_token_ids[req_index]
+
+ scheduled_spec_token_ids = (
+ scheduler_output.scheduled_spec_decode_tokens.get(req_id))
+ if scheduled_spec_token_ids:
+ # num_computed_tokens represents the number of tokens
+ # processed in the current step, considering scheduled
+ # tokens and rejections. If some tokens are rejected,
+ # num_computed_tokens is decreased by the number of rejected
+            # tokens, which is given by:
+ # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
+ num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
+ len(generated_token_ids))
+ request.num_computed_tokens -= num_tokens_rejected
+ spec_decoding_stats = self.make_spec_decoding_stats(
+ spec_decoding_stats,
+ num_draft_tokens=len(scheduled_spec_token_ids),
+ num_accepted_tokens=len(generated_token_ids) - 1)
+
+ cached_encoder_input_ids = (
+ self.encoder_cache_manager.get_cached_input_ids(request))
+ # OPTIMIZATION: Avoid list(set) if the set is empty.
+ if cached_encoder_input_ids:
+ for input_id in list(cached_encoder_input_ids):
+ mm_positions = request.mm_positions[input_id]
+ start_pos = mm_positions["offset"]
+ num_tokens = mm_positions["length"]
+ if start_pos + num_tokens <= request.num_computed_tokens:
+ # The encoder output is already processed and stored
+ # in the decoder's KV cache.
+ self.encoder_cache_manager.free_encoder_input(
+ request, input_id)
+
+ # Add newly generated spec token ids to the request.
+ if spec_token_ids is not None:
+ request.spec_token_ids = spec_token_ids[req_index]
+
+ stopped = False
+ new_logprobs = None
+ new_token_ids = generated_token_ids
+
+ # Append generated tokens and check for stop. Note that if
+ # a request is still being prefilled, we expect the model runner
+ # to return empty token ids for the request.
+ for num_new, output_token_id in enumerate(new_token_ids, 1):
+ request.append_output_token_ids(output_token_id)
+
+ # Check for stop and update request state.
+ # This must be called before we make the EngineCoreOutput.
+ stopped = check_stop(request, self.max_model_len)
+ if stopped:
+ self._free_request(request)
+ del new_token_ids[num_new:] # Trim new tokens if needed.
+ break
+
+ # Extract sample logprobs if needed.
+ if request.sampling_params.logprobs is not None and logprobs:
+ # NOTE: once we support N tokens per step (spec decode),
+ # the outer lists can be of length > 1.
+ new_logprobs = logprobs.slice(req_index, req_index + 1)
+
+ if new_token_ids and request.use_structured_output:
+ # NOTE: structured_output_request
+ # should not be None if use_structured_output, we have
+ # check above, so safe to ignore type warning
+ request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
+ req_id, new_token_ids)
+
+ # Get prompt logprobs for this request.
+ prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+ if new_token_ids:
+ # Add EngineCoreOutput for this Request.
+ outputs.append(
+ EngineCoreOutput(
+ request_id=req_id,
+ new_token_ids=new_token_ids,
+ finish_reason=request.get_finished_reason(),
+ new_logprobs=new_logprobs,
+ new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+ stop_reason=request.stop_reason,
+ events=request.take_events()))
+ else:
+ # Invariant: EngineCore returns no partial prefill outputs.
+ assert not prompt_logprobs_tensors
+
+ self.scheduled_req_ids.remove(req_id)
+ if not stopped:
+ new_running.append(request)
+
+ # Add by vllm-mindspore begin:
+    # Mark failed requests as finished so the server can continue to process new requests.
+ if len(abort_req_ids) > 0:
+ logger.warning(f'Aborted requests are {abort_req_ids}')
+ self.finish_requests(abort_req_ids, RequestStatus.FINISHED_ABORTED)
+ # Add by vllm-mindspore end.
+
+ self.running = new_running
+ engine_core_outputs = EngineCoreOutputs(
+ outputs=outputs,
+ scheduler_stats=self.make_stats(spec_decoding_stats),
+ )
+ if self.include_finished_set:
+ #TODO currently sending duplicates here, improve this
+ engine_core_outputs.finished_requests = (
+ scheduler_output.finished_req_ids | self.finished_req_ids)
+
+ return engine_core_outputs
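The module above is written as a drop-in replacement for `Scheduler.update_from_output`; the hunk itself does not show where it gets bound. One plausible way to apply it as a monkey patch (the patch site is an assumption, not part of this diff) would be:

```python
# Hypothetical binding of the replacement onto vLLM's V1 scheduler;
# the exact patch site is an assumption, not shown in this diff.
from vllm.v1.core.sched.scheduler import Scheduler

from vllm_mindspore.v1.core.sched.scheduler import update_from_output

Scheduler.update_from_output = update_from_output
```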
diff --git a/vllm_mindspore/v1/executor/__init__.py b/vllm_mindspore/v1/executor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/vllm_mindspore/v1/executor/multiproc_executor.py b/vllm_mindspore/v1/executor/multiproc_executor.py
new file mode 100644
index 0000000000000000000000000000000000000000..04c9190c4aa5ef0aaa1bf4e925caa2b8fae377ea
--- /dev/null
+++ b/vllm_mindspore/v1/executor/multiproc_executor.py
@@ -0,0 +1,64 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Functions are adapted from vllm-project/vllm/v1/executor/multiproc_executor.py
+#
+# Copyright 2025 Huawei Technologies Co., Ltd.
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Monkey Patch functions for v1 executor mp distributed backend."""
+import os
+import signal
+import time
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def executor_ensure_worker_termination(self):
+ """Ensure that all worker processes are terminated. Assumes workers have
+ received termination requests. Waits for processing, then sends
+ termination and kill signals if needed."""
+
+ def wait_for_termination(procs, timeout):
+ if not time:
+ # If we are in late stage shutdown, the interpreter may replace
+ # `time` with `None`.
+ return all(not proc.is_alive() for proc in procs)
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ if all(not proc.is_alive() for proc in procs):
+ return True
+ time.sleep(0.1)
+ return False
+
+ # Send SIGTERM if still running
+ active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
+ for p in active_procs:
+ p.terminate()
+ if not wait_for_termination(active_procs, 4):
+ # Send SIGKILL if still running
+ active_procs = [p for p in active_procs if p.is_alive()]
+ for p in active_procs:
+            # vllm-mindspore begin: kill all the processes in the process group
+            # (including the scheduler process, kernel processes and so on)
+            # instead of calling p.kill.
+ pid = p.pid
+ try:
+ os.killpg(pid, signal.SIGKILL)
+ except Exception as e:
+ logger.debug("Kill process %d error: %s!", pid, str(e))
+ # vllm-mindspore end.
+
+ self._cleanup_sockets()
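`os.killpg` above only reaches the whole group if the worker process was started as a process-group (session) leader. A minimal sketch of that pairing, with `start_new_session=True` as an assumed spawn setting rather than something configured in this diff:

```python
# Sketch: killing a whole process group. start_new_session=True makes
# the child a group leader (so pgid == pid); this is an assumption for
# illustration, not something this diff configures.
import os
import signal
import subprocess
import time

proc = subprocess.Popen(["sleep", "60"], start_new_session=True)
time.sleep(0.1)
try:
    os.killpg(proc.pid, signal.SIGKILL)
except ProcessLookupError:
    pass
proc.wait()
print("exited with", proc.returncode)  # -9 on Linux
```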
diff --git a/vllm_mindspore/v1/spec_decode/eagle.py b/vllm_mindspore/v1/spec_decode/eagle.py
index 7279bcaf595526df7dc859836f03fc2aeba7e994..0e252f88bd160b5bae4f4f5ad2a5b3a638d5dd81 100644
--- a/vllm_mindspore/v1/spec_decode/eagle.py
+++ b/vllm_mindspore/v1/spec_decode/eagle.py
@@ -4,7 +4,7 @@ import torch.nn as nn
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
-from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
@@ -56,7 +56,8 @@ class EagleProposer:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
- attn_metadata = FlashAttentionMetadata(
+ # TODO: new members need to be added to the MsAttentionMetadata for Eagle feature
+ attn_metadata = MsAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py
index a21a2f73e889169e6d30ca2d2bdd23bb03bcc29b..7f4e3fe162150ed54482fd21e24806dd9d81d018 100644
--- a/vllm_mindspore/v1/worker/gpu_model_runner.py
+++ b/vllm_mindspore/v1/worker/gpu_model_runner.py
@@ -1,32 +1,47 @@
+#!/usr/bin/env python3
+# type: ignore
+# isort:skip_file
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
from typing import Dict, Tuple, List
-import gc
import numpy as np
import torch
from mindspore import mutable
-import mindspore as ms
-from vllm_mindspore.v1.attention.backends.flash_attn import (FlashAttentionMetadata,
- FlashAttentionBackend,
- MLABackend)
+from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata
from vllm_mindspore.utils import get_valid_dtype
+from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding as MRotaryEmbedding # type: ignore[attr-defined]
-from vllm.v1.kv_cache_interface import FullAttentionSpec
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.attention import AttentionType
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec, SlidingWindowSpec
from vllm.v1.utils import bind_kv_cache
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.distributed.parallel_state import get_pp_group
-from vllm.utils import cdiv
from vllm.logger import init_logger
from vllm.v1.worker.gpu_input_batch import CachedRequestState
-from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.v1.core.sched.output import SchedulerOutput
from vllm.sampling_params import SamplingType
-
logger = init_logger(__name__)
+
+
def _prepare_inputs(
- self,
- scheduler_output: "SchedulerOutput",
-) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
+ self,
+ scheduler_output: "SchedulerOutput", # type: ignore[name-defined]
+) -> Tuple[MsAttentionMetadata, torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@@ -48,13 +63,11 @@ def _prepare_inputs(
for i, req_id in enumerate(self.input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens[i] = num_tokens
- max_num_scheduled_tokens = max(max_num_scheduled_tokens,
- num_tokens)
+ max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens)
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
- req_indices = np.repeat(self.arange_np[:num_reqs],
- num_scheduled_tokens)
+ req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens)
# Get batched arange.
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -71,20 +84,20 @@ def _prepare_inputs(
# Get positions.
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
- arange,
- out=positions_np)
+ arange,
+ out=positions_np)
if self.uses_mrope:
self._calc_mrope_positions(scheduler_output)
if self.uses_mrope:
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
- self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
- self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
- non_blocking=True)
+ self.mrope_positions[:, :
+ total_num_scheduled_tokens] = self.mrope_positions_cpu[:, :
+ total_num_scheduled_tokens]
else:
- self.positions[:total_num_scheduled_tokens] = torch.from_numpy(positions_np)
-
+ self.positions[:total_num_scheduled_tokens] = torch.from_numpy(
+ positions_np)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -94,10 +107,7 @@ def _prepare_inputs(
req_indices * self.input_batch.token_ids_cpu.shape[1])
self.input_ids[:total_num_scheduled_tokens] = torch.from_numpy(
- np.take(self.input_batch.token_ids_cpu.ravel(),
- token_indices,
- 0)
- )
+ np.take(self.input_batch.token_ids_cpu.ravel(), token_indices, 0))
# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -108,12 +118,12 @@ def _prepare_inputs(
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
-
- block_numbers = self.input_batch.block_table.block_table_np.ravel()[block_table_indices]
+ block_numbers = self.input_batch.block_table.block_table_np.ravel(
+ )[block_table_indices]
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
- block_offsets,
- out=self.slot_mapping_np[:total_num_scheduled_tokens])
+ block_offsets,
+ out=self.slot_mapping_np[:total_num_scheduled_tokens])
# # Prepare the attention metadata.
self.query_start_loc_np[0] = 0
@@ -124,11 +134,8 @@ def _prepare_inputs(
num_scheduled_tokens)
common_prefix_len = 0
- if self.cascade_attn_enabled:
- common_prefix_len = self._compute_cascade_attn_prefix_len(
- num_scheduled_tokens,
- scheduler_output.num_common_prefix_blocks,
- )
+    # Cascade attention is only used when common_prefix_len > 0; computing it
+    # relies on device_properties.multi_processor_count (CUDA-specific), so it
+    # is left disabled here.
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_reqs,
@@ -137,8 +144,7 @@ def _prepare_inputs(
common_prefix_len=common_prefix_len,
)
- use_spec_decode = len(
- scheduler_output.scheduled_spec_decode_tokens) > 0
+ use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
if not use_spec_decode:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
@@ -165,7 +171,7 @@ def _prepare_inputs(
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
- return attn_metadata, logits_indices, spec_decode_metadata
+ return attn_metadata, logits_indices, spec_decode_metadata # type: ignore[return-value]
def create_block(shape, dtype, name=None, device=None):
@@ -173,6 +179,7 @@ def create_block(shape, dtype, name=None, device=None):
blocks = mint.empty(shape, dtype=dtype, device=device)
return blocks
+
def initialize_kv_cache(self, kv_cache_config) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@@ -203,28 +210,29 @@ def initialize_kv_cache(self, kv_cache_config) -> None:
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, FullAttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
- num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads,
- kv_cache_spec.head_size)
+ num_blocks, kv_cache_spec.block_size,
+ kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
dtype = get_valid_dtype(dtype)
current_cache = []
device_type = "CPU" if self.device.type == "cpu" else "Ascend"
for i in range(kv_cache_shape[0]):
- cache_blocks = create_block(
- kv_cache_shape[1:], dtype, device=device_type
- )
+ cache_blocks = create_block(kv_cache_shape[1:],
+ dtype,
+ device=device_type)
current_cache.append(mutable(cache_blocks))
kv_caches[layer_name] = mutable(tuple(current_cache))
else:
raise NotImplementedError
- bind_kv_cache(
- kv_caches,
- self.vllm_config.compilation_config.static_forward_context,
- self.kv_caches)
+ bind_kv_cache(kv_caches,
+ self.vllm_config.compilation_config.static_forward_context,
+ self.kv_caches)
-def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+def _update_states(
+ self, scheduler_output: "SchedulerOutput"
+) -> None: # type: ignore[name-defined]
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -307,14 +315,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
second_per_grid_ts = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
- image_grid_thw.extend(
- mm_input["image_grid_thw"].tolist())
- if mm_input.get("video_grid_thw") is not None:
- video_grid_thw.extend(
- mm_input["video_grid_thw"].tolist())
+ image_grid_thw.extend(mm_input["image_grid_thw"].tolist())
+ if mm_input.get("video_grid_thw") is not None:
+ video_grid_thw.extend(
+ mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
- second_per_grid_ts.extend(
- mm_input["second_per_grid_ts"])
+ second_per_grid_ts.extend(mm_input["second_per_grid_ts"])
hf_config = self.model_config.hf_config
@@ -340,9 +346,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
- num_new_tokens = (num_computed_tokens +
- len(req_data.new_token_ids) -
- req_state.num_tokens)
+ num_new_tokens = (num_computed_tokens + len(req_data.new_token_ids) -
+ req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(req_data.new_token_ids[-1])
@@ -369,8 +374,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
- start_index = (len(req_state.block_ids) -
- len(req_data.new_block_ids))
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Add new_token_ids to token_ids_cpu.
@@ -392,7 +395,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
-
# self.input_batch.token_ids_cpu_tensor.copy_(torch.from_numpy(self.input_batch.token_ids_cpu))
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
@@ -403,12 +405,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
removed_req_indices = sorted(removed_req_indices, reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
- if removed_req_indices:
- # Fill the empty index.
- req_index = removed_req_indices.pop()
- else:
- # Append to the end.
- req_index = None
+ req_index = removed_req_indices.pop() if removed_req_indices else None
self.input_batch.add_request(req_state, req_index)
# Condense the batched states if there are empty indices.
@@ -417,3 +414,118 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
if batch_changed:
self.input_batch.refresh_sampling_metadata()
+
+
+def wrapper_gpu_model_runner_execute_model(func):
+
+ def new_func(*args, **kwargs):
+ self = args[0]
+ try:
+ output = func(*args, **kwargs)
+ return output
+ except Exception as e:
+ logger.warning(
+ f"Caught exception {str(e)} when processing req_ids {self.input_batch.req_ids}" # noqa: G004
+ )
+ return ModelRunnerOutput(
+ req_ids=self.input_batch.req_ids,
+ req_id_to_index=self.input_batch.req_id_to_index,
+ sampled_token_ids=None,
+ spec_token_ids=None,
+ logprobs=None,
+ prompt_logprobs_dict={},
+ )
+
+ return new_func
+
+
+def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
+ forward_ctx = self.vllm_config.compilation_config.static_forward_context
+ block_size = self.vllm_config.cache_config.block_size
+ use_mla = self.vllm_config.model_config.use_mla
+ kv_cache_spec: dict[str, KVCacheSpec] = {}
+ for layer_name, attn_module in forward_ctx.items():
+        # vllm-mindspore's AttentionWrapper is not an instance of Attention, so skip:
+ # assert isinstance(attn_module, Attention)
+ if attn_module.attn_type == AttentionType.DECODER:
+ if attn_module.sliding_window is not None:
+ kv_cache_spec[layer_name] = SlidingWindowSpec(
+ block_size=block_size,
+ num_kv_heads=attn_module.num_kv_heads,
+ head_size=attn_module.head_size,
+ dtype=self.kv_cache_dtype,
+ sliding_window=attn_module.sliding_window,
+ use_mla=use_mla)
+ else:
+ kv_cache_spec[layer_name] = FullAttentionSpec(
+ block_size=block_size,
+ num_kv_heads=attn_module.num_kv_heads,
+ head_size=attn_module.head_size,
+ dtype=self.kv_cache_dtype,
+ use_mla=use_mla)
+ elif attn_module.attn_type in (AttentionType.ENCODER,
+ AttentionType.ENCODER_ONLY):
+ # encoder-only attention does not need KV cache.
+ continue
+ elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+ raise NotImplementedError
+ else:
+ raise ValueError(
+ f"Unknown attention type: {attn_module.attn_type}")
+
+ return kv_cache_spec
+
+
+def _calc_mrope_positions(
+ self,
+ scheduler_output: "SchedulerOutput"): # type: ignore[name-defined]
+ mrope_pos_ptr = 0
+ for index, req_id in enumerate(self.input_batch.req_ids):
+ req = self.requests[req_id]
+ assert req.mrope_positions is not None
+
+ num_computed_tokens = \
+ self.input_batch.num_computed_tokens_cpu[index]
+ num_scheduled_tokens = \
+ scheduler_output.num_scheduled_tokens[req_id]
+ num_prompt_tokens = len(req.prompt_token_ids)
+
+ if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+ prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
+ completion_part_len = max(0,
+ num_scheduled_tokens - prompt_part_len)
+ else:
+ prompt_part_len = num_scheduled_tokens
+ completion_part_len = 0
+
+ assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+ if prompt_part_len > 0:
+            # The prompt's mrope_positions are pre-computed. Upstream (GPU) works
+            # with numbers/tensors here; we index numpy arrays, so cast to int.
+ dst_start = int(mrope_pos_ptr)
+ dst_end = int(mrope_pos_ptr + prompt_part_len)
+ src_start = int(num_computed_tokens)
+ src_end = int(num_computed_tokens + prompt_part_len)
+
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
+ req.mrope_positions[:,src_start:src_end]
+
+ mrope_pos_ptr += prompt_part_len
+
+ if completion_part_len > 0:
+ # compute completion's mrope_positions on-the-fly
+ dst_start = mrope_pos_ptr
+ dst_end = mrope_pos_ptr + completion_part_len
+
+ self.mrope_positions_cpu[:, dst_start:dst_end] = \
+ MRotaryEmbedding.get_next_input_positions_tensor(
+ req.mrope_position_delta,
+ context_len=num_computed_tokens +
+ prompt_part_len,
+ seq_len=num_computed_tokens +
+ prompt_part_len +
+ completion_part_len,
+ )
+
+ mrope_pos_ptr += completion_part_len
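`wrapper_gpu_model_runner_execute_model` above converts a model-execution exception into a `ModelRunnerOutput` whose `sampled_token_ids` is `None`, which is exactly the sentinel the patched `update_from_output` aborts on. A toy, framework-free sketch of the same wrap-and-signal pattern (`FakeRunner` and `FakeOutput` are stand-ins, not vLLM classes):

```python
# Toy illustration of the wrap-and-signal pattern used by
# wrapper_gpu_model_runner_execute_model; FakeRunner/FakeOutput are
# stand-ins, not vLLM classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeOutput:
    req_ids: list
    sampled_token_ids: Optional[list]


class FakeRunner:
    def __init__(self):
        self.req_ids = ["req-0", "req-1"]

    def execute_model(self):
        raise RuntimeError("device error")   # simulate a failed step


def wrap_execute_model(func):
    def new_func(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            print(f"Caught exception {e} when processing req_ids {self.req_ids}")
            # None signals the scheduler to abort these requests.
            return FakeOutput(req_ids=self.req_ids, sampled_token_ids=None)
    return new_func


FakeRunner.execute_model = wrap_execute_model(FakeRunner.execute_model)
out = FakeRunner().execute_model()
assert out.sampled_token_ids is None
```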
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..0978ed4c58c3777ae859a3103350c98ab4320945 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3
-# encoding: utf-8
+# type: ignore
+# isort:skip_file
# Copyright 2025 Huawei Technologies Co., Ltd
# Copyright 2024 The vLLM team.
#
@@ -15,23 +16,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-
"""Worker functions"""
-import gc
-import os
import math
-from typing import Tuple, Optional
-
import torch
-from vllm.config import VllmConfig
-from vllm.distributed import (
- ensure_kv_transfer_initialized,
- ensure_model_parallel_initialized,
- init_distributed_environment,
- set_custom_all_reduce,
-)
-
from vllm.logger import init_logger
from vllm_mindspore.utils import get_valid_dtype
@@ -39,34 +27,41 @@ from vllm.model_executor import set_random_seed
from vllm.sequence import SequenceGroupMetadata
from vllm.sampling_params import SamplingParams
-
logger = init_logger(__name__)
-def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefill, is_mtp_model=False):
+def _prepare_input_for_warmup(model_config,
+ model_runner,
+ cache_engine,
+ is_prefill,
+ is_mtp_model=False):
bs = 1
seq_len = model_runner.scheduler_config.max_num_batched_tokens if is_prefill else 1
- dummy_data = model_runner.input_registry.dummy_data_for_profiling(model_config, seq_len, model_runner.mm_registry)
- block_tables = [i for i in range(math.ceil(seq_len / cache_engine.block_size))]
+ dummy_data = model_runner.input_registry.dummy_data_for_profiling(
+ model_config, seq_len, model_runner.mm_registry)
+ block_tables = [
+ i for i in range(math.ceil(seq_len / cache_engine.block_size))
+ ]
+
+    # Adapt the dummy sequence data for multi-modal warm-up.
+ seq_data = dummy_data.seq_data
+ if seq_len == 1:
+ seq_data = dummy_data.seq_data.from_prompt_token_counts((0, seq_len))
+
seqs = [
SequenceGroupMetadata(
request_id=str(idx),
is_prompt=is_prefill,
- seq_data={idx: dummy_data.seq_data},
+ seq_data={idx: seq_data},
sampling_params=SamplingParams(),
block_tables={idx: block_tables},
lora_request=None,
multi_modal_data=None,
multi_modal_placeholders=None,
- )
- for idx in range(bs)
+ ) for idx in range(bs)
]
model_input = model_runner.prepare_model_input(seqs)
- block_tables = model_input.attn_metadata.block_tables
- if block_tables is not None and block_tables.numel() <= 0:
- model_input.attn_metadata.block_tables = torch.zeros((1, 1), dtype=torch.int32)
-
previous_hidden_states = None if not is_mtp_model else \
torch.ones([bs, seq_len, model_config.get_hidden_size()], dtype=get_valid_dtype(model_config.dtype))
return model_input, previous_hidden_states
@@ -78,19 +73,31 @@ def _warm_up_model(self) -> None:
is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp"
if is_mtp_model:
# prefill mtp model
- model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
- self.cache_engine[0], True, is_mtp_model)
- self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+ model_input, previous_hidden_states = _prepare_input_for_warmup(
+ self.model_config, self.model_runner, self.cache_engine[0], True,
+ is_mtp_model)
+ self.model_runner.execute_model(
+ model_input,
+ kv_cache,
+ None,
+ previous_hidden_states=previous_hidden_states)
# warmup for decode
if self.vllm_config.scheduler_config.is_multi_step:
- model_input, _ = _prepare_input_for_warmup(self.model_config, self.model_runner._base_model_runner,
- self.cache_engine[0], False)
- self.model_runner._base_model_runner.execute_model(model_input, kv_cache, None)
+ model_input, _ = _prepare_input_for_warmup(
+ self.model_config, self.model_runner._base_model_runner,
+ self.cache_engine[0], False)
+ self.model_runner._base_model_runner.execute_model(
+ model_input, kv_cache, None)
else:
- model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner,
- self.cache_engine[0], False, is_mtp_model)
- self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states)
+ model_input, previous_hidden_states = _prepare_input_for_warmup(
+ self.model_config, self.model_runner, self.cache_engine[0], False,
+ is_mtp_model)
+ self.model_runner.execute_model(
+ model_input,
+ kv_cache,
+ None,
+ previous_hidden_states=previous_hidden_states)
torch.cuda.synchronize()
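For reference, the warm-up above builds a single dummy request whose length is `max_num_batched_tokens` for prefill and 1 for decode, and sizes its block table accordingly. A tiny sketch of that sizing (the concrete `block_size` and `max_num_batched_tokens` values are illustrative):

```python
# Illustrative sizing of the warm-up dummy request, mirroring
# _prepare_input_for_warmup above; the concrete values are assumptions.
import math

block_size = 16
max_num_batched_tokens = 4096

for is_prefill in (True, False):
    seq_len = max_num_batched_tokens if is_prefill else 1
    block_tables = list(range(math.ceil(seq_len / block_size)))
    print(is_prefill, seq_len, len(block_tables))
# True 4096 256
# False 1 1
```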