diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml
index 531bee81da75c583c5d5b03ee38a1cb9c2eaa2e3..16ca50fdb4cd81468ec557e2ed777a41fd74b8c5 100644
--- a/.jenkins/test/config/dependent_packages.yaml
+++ b/.jenkins/test/config/dependent_packages.yaml
@@ -1,5 +1,5 @@
 mindspore:
-  'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250608/br_infer_iter_20250608031509_f31d63401e48787a7677f6e5c61745dd44304240_newest/'
+  'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250613/br_infer_iter_20250613031508_11bcfd2ff4dc201a1c07e5d525cbeff7ec7f9558_newest/'
 mindspore_gs:
   'https://repo.mindspore.cn/mindspore/golden-stick/version/202506/20250604/master_20250604160014_35fcbec4406d3b18faf02ef99fcbe2741e80348e_newest/'
diff --git a/install_depend_pkgs.sh b/install_depend_pkgs.sh
index b3d8306e2a066057347daced6415b3d924380994..97da181da862bff0e17c422b01af853a32f6b680 100644
--- a/install_depend_pkgs.sh
+++ b/install_depend_pkgs.sh
@@ -67,7 +67,7 @@ echo "========= Installing mindformers"
 mf_dir=mindformers-dev
 if [ ! -d "$mf_dir" ]; then
     git clone https://gitee.com/mindspore/mindformers.git -b dev "$mf_dir"
-    git checkout dfb8aa3a59401495b2d8c8c107d46fe0d36c949a
+    git checkout 13adb2201abe8979b679a98566495a8642d7ec0d
 else
     echo "The $mf_dir folder already exists and will not be re-downloaded."
 fi
diff --git a/tests/mindformers b/tests/mindformers
index f046081e40be777eb799afee10495b51cdb2f3c1..13adb2201abe8979b679a98566495a8642d7ec0d 160000
--- a/tests/mindformers
+++ b/tests/mindformers
@@ -1 +1 @@
-Subproject commit f046081e40be777eb799afee10495b51cdb2f3c1
+Subproject commit 13adb2201abe8979b679a98566495a8642d7ec0d
diff --git a/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py
new file mode 100644
index 0000000000000000000000000000000000000000..48de1692134eff0f30e54de79fcabe8b3e4dc52d
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen3."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "vLLM_MODEL_BACKEND": "MindFormers",
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0",
+    "VLLM_USE_V1": "0"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+# isort: off
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+# isort: on
+
+
+def test_mf_qwen3():
+    """
+    test case qwen3 8B
+    """
+
+    # Sample prompts.
+    prompts = [
+        "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+    # Create an LLM.
+    llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-8B",
+              gpu_memory_utilization=0.9,
+              tensor_parallel_size=2)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    except_list = ['好的,我需要分析用户提供的文本“我认为']
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == except_list[i]
+
+    # unset env
+    env_manager.unset_all()
diff --git a/tests/st/python/cases_parallel/vllm_mf_qwen3_8b_v1.py b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..aeb62ef7af753cda7509f7ef6b96da8c91d2379c
--- /dev/null
+++ b/tests/st/python/cases_parallel/vllm_mf_qwen3_8b_v1.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""test mf qwen3."""
+import os
+
+import pytest
+
+from tests.st.python import set_env
+
+env_manager = set_env.EnvVarManager()
+# def env
+env_vars = {
+    "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"),
+    "vLLM_MODEL_BACKEND": "MindFormers",
+    "MS_ENABLE_LCCL": "off",
+    "HCCL_OP_EXPANSION_MODE": "AIV",
+    "MS_ALLOC_CONF": "enable_vmm:True",
+    "LCCL_DETERMINISTIC": "1",
+    "HCCL_DETERMINISTIC": "true",
+    "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
+    "ATB_LLM_LCOC_ENABLE": "0",
+    "VLLM_USE_V1": "1"
+}
+# set env
+env_manager.setup_ai_environment(env_vars)
+# isort: off
+import vllm_mindspore
+from vllm import LLM, SamplingParams
+# isort: on
+
+
+def test_mf_qwen3():
+    """
+    test case qwen3 8B
+    """
+
+    # Sample prompts.
+    prompts = [
+        "You are a helpful assistant.<|User|>将文本分类为中性、负面或正面。 \n文本:我认为这次假期还可以。 \n情感:<|Assistant|>\n",
+    ]
+
+    # Create a sampling params object.
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=10, top_k=1)
+
+    # Create an LLM.
+    llm = LLM(model="/home/workspace/mindspore_dataset/weight/Qwen3-8B",
+              gpu_memory_utilization=0.9,
+              tensor_parallel_size=2)
+    # Generate texts from the prompts. The output is a list of RequestOutput objects
+    # that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+    except_list = ['好的,我需要分析用户提供的文本“我认为']
+    # Print the outputs.
+    for i, output in enumerate(outputs):
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        assert generated_text == except_list[i]
+
+    # unset env
+    env_manager.unset_all()
diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py
index 35d31ea8cea26f0340fd074dec98f28a296e4b97..f0ef9f1bcdac854443b553307a445981bbeaaf0a 100644
--- a/tests/st/python/test_cases_parallel.py
+++ b/tests/st/python/test_cases_parallel.py
@@ -99,7 +99,8 @@ def test_cases_parallel_part1():
          "export HCCL_IF_BASE_PORT=61004 && "
          "pytest -s -v cases_parallel/vllm_mf_qwen_7b_prefix_caching_v1.py::test_mf_qwen_7b_prefix_caching "
          "> vllm_mf_qwen_7b_prefix_caching_v1_test_mf_qwen_7b_prefix_caching.log",
-         "vllm_mf_qwen_7b_prefix_caching_v1_test_mf_qwen_7b_prefix_caching.log"),
+         "vllm_mf_qwen_7b_prefix_caching_v1_test_mf_qwen_7b_prefix_caching.log"
+         ),
         ("export ASCEND_RT_VISIBLE_DEVICES=6,7 && export LCAL_COMM_ID=127.0.0.1:10071 && "
          "export HCCL_IF_BASE_PORT=61006 && "
          "pytest -s -v cases_parallel/vllm_mf_qwen_7b_v1.py::test_mf_qwen > vllm_mf_qwen_7b_v1_test_mf_qwen.log",
         "vllm_mf_qwen_7b_v1_test_mf_qwen.log"),
@@ -212,6 +213,33 @@ def test_cases_parallel_part4():
     check_results(commands, results)
 
 
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_single
+def test_cases_parallel_part5():
+    """
+    Feature: test cases parallel.
+    Description: test cases parallel.
+    Expectation: Pass.
+    """
+    commands = [
+        ("export ASCEND_RT_VISIBLE_DEVICES=0,1 && export LCAL_COMM_ID=127.0.0.1:10068 && "
+         "export HCCL_IF_BASE_PORT=61000 && "
+         "pytest -s -v cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3 "
+         "> vllm_mf_qwen3_8b_test_mf_qwen3.log",
+         "vllm_mf_qwen3_8b_test_mf_qwen3.log"),
+        ("export ASCEND_RT_VISIBLE_DEVICES=2,3 && export LCAL_COMM_ID=127.0.0.1:10069 && "
+         "export HCCL_IF_BASE_PORT=61002 && "
+         "pytest -s -v cases_parallel/vllm_mf_qwen3_8b_v1.py::test_mf_qwen3 "
+         "> vllm_mf_qwen3_8b_v1_test_mf_qwen3.log",
+         "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log")
+    ]
+
+    with Pool(len(commands)) as pool:
+        results = list(pool.imap(run_command, commands))
+    check_results(commands, results)
+
+
 @pytest.mark.level1
 @pytest.mark.platform_arm_ascend910b_training
 @pytest.mark.env_single
diff --git a/vllm_mindspore/model_executor/models/mf_models/config.py b/vllm_mindspore/model_executor/models/mf_models/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff6003741ca1cfb46b27e8c22ef1ff10f0d07e72
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/mf_models/config.py
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2025 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import types
+
+from mindformers.models.configuration_utils import PretrainedConfig
+from mindformers.tools.register.config import MindFormerConfig
+from vllm.config import VllmConfig
+
+MF_CTX_MAPPING = {
+    'run_mode': (None, "predict"),
+    'use_legacy': (None, False),
+    'load_ckpt_format': (None, 'safetensors'),
+    'auto_trans_ckpt': (None, True),
+}
+
+MF_PARALLEL_MAPPING = {
+    'parallel_mode': (None, 'STAND_ALONE'),
+    'parallel_config.model_parallel':
+    ('parallel_config.tensor_parallel_size', None),
+    'parallel_config.pipeline_stage':
+    ('parallel_config.pipeline_parallel_size', None),
+    'parallel_config.vocab_emb_dp': (None, False)
+}
+
+# Common model config
+MODEL_COMMON_MAPPING = {
+    'seq_length': ('model_config.max_model_len', None),
+    'use_flash_attention': (None, True),
+    "compute_dtype": ('model_config.hf_config.torch_dtype', 'bfloat16'),
+    'architectures': ('model_config.hf_config.architectures', None),
+    'bos_token_id': ('model_config.hf_config.bos_token_id', None),
+    'eos_token_id': ('model_config.hf_config.eos_token_id', None),
+    'model_type': ('model_config.hf_config.model_type', None),
+    # transformer_config
+    'attention_dropout': ('model_config.hf_config.attention_dropout', None),
+    'hidden_act': ('model_config.hf_config.hidden_act', None),
+    'hidden_size': ('model_config.hf_config.hidden_size', None),
+    'intermediate_size': ('model_config.hf_config.intermediate_size', None),
+    'max_position_embeddings':
+    ('model_config.hf_config.max_position_embeddings', None),
+    'num_attention_heads':
+    ('model_config.hf_config.num_attention_heads', None),
+    'rms_norm_eps': ('model_config.hf_config.rms_norm_eps', None),
+    'num_hidden_layers': ('model_config.hf_config.num_hidden_layers', None),
+    'num_layers': ('model_config.hf_config.num_layers', None),
+    'num_key_value_heads':
+    ('model_config.hf_config.num_key_value_heads', None),
+    'n_kv_heads': ('model_config.hf_config.n_kv_heads', None),
+    'head_dim': ('model_config.hf_config.head_dim', None),
+    'rope_theta': ('model_config.hf_config.rope_theta', None),
+    'tie_word_embeddings':
+    ('model_config.hf_config.tie_word_embeddings', None),
+    'vocab_size': ('model_config.hf_config.vocab_size', None),
+}
+
+# model default config
+MODEL_RELATED_MAPPING = {
+    'qwen2': {
+        "gated_linear_unit": True,
+        'params_dtype': 'float32',  # need an input
+        'add_qkv_bias': True,
+    },
+    'qwen3': {
+        "gated_linear_unit": True,
+        'params_dtype': 'float32',  # need an input
+        'add_qkv_bias': False,
+    }
+    # Add another model type...
+}
+
+
+def get_nested_attr(obj, path: str, default=None):
+    """get nested attr from obj."""
+    current = obj
+    for attr in path.split('.'):
+        if not hasattr(current, attr):
+            return default
+        current = getattr(current, attr)
+    return current
+
+
+def set_nested_attr(obj, path: str, value):
+    """Set nested attr of MindFormerConfig."""
+    attrs = path.split('.')
+
+    current = obj
+    for attr in attrs[:-1]:
+        if not hasattr(current, attr) or getattr(current, attr) is None:
+            setattr(current, attr, MindFormerConfig())
+        current = getattr(current, attr)
+
+    setattr(current, attrs[-1], value)
+
+
+def transform_config(mapping_table: dict, vllm_config: VllmConfig,
+                     target_config):
+    for target_path, mapping in mapping_table.items():
+        src_path, transform = mapping
+
+        src_value = get_nested_attr(vllm_config,
+                                    src_path) if src_path is not None else None
+
+        if src_value is not None:
+            transformed_value = src_value
+        elif transform and isinstance(
+                transform, (types.FunctionType, types.BuiltinFunctionType)):
+            transformed_value = transform(src_value)
+        else:
+            transformed_value = transform
+
+        if transformed_value is not None:
+            set_nested_attr(target_config, target_path, transformed_value)
+
+
+def gen_model_related_config(model_type):
+    return MODEL_RELATED_MAPPING.get(model_type)
+
+
+def gen_model_config_dict(vllm_config: VllmConfig):
+    target_config = MindFormerConfig()
+
+    transform_config(MODEL_COMMON_MAPPING, vllm_config, target_config)
+
+    model_type = vllm_config.model_config.hf_config.model_type
+    model_related_config = gen_model_related_config(model_type)
+    target_config.update(model_related_config)
+
+    return target_config
+
+
+def gen_mf_config(vllm_config: VllmConfig):
+    target_config = MindFormerConfig()
+    transform_config(MF_CTX_MAPPING, vllm_config, target_config)
+    transform_config(MF_PARALLEL_MAPPING, vllm_config, target_config)
+    target_config.set_value(
+        'model.model_config',
+        MindFormerConfig(**gen_model_config_dict(vllm_config)))
+    return target_config
+
+
+def gen_model_config(mf_config: MindFormerConfig,
+                     model_config_type: PretrainedConfig):
+    model_config = model_config_type(**mf_config.model.model_config,
+                                     parallel_config=mf_config.parallel_config)
+    model_config.post_process = False
+    return model_config
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 5e3dfc62e79f3a8f88360052120f0d1817f1d414..56cabb1f7b433eebadbec05f4e8a2892611d8f2a 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# encoding: utf-8
 # Copyright 2025 Huawei Technologies Co., Ltd
 # Copyright 2024 The vLLM team.
# @@ -17,62 +16,64 @@ # ============================================================================ import os -from types import MethodType -from typing import Iterable, List, Optional, Set, Tuple, Union from abc import abstractmethod -import numpy as np -import math - -from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.parallel_state import get_dp_group -from vllm.logger import init_logger -from vllm.forward_context import get_forward_context -import vllm.envs as envs +from typing import Iterable, Optional, Set, Tuple, Union import mindspore as ms -from mindspore import Tensor, mint -from mindspore.common.api import _pynative_executor -from mindspore.communication import get_rank - -from mindformers.tools.register.config import MindFormerConfig from mindformers.core.context import build_mf_context from mindformers.core.parallel_config import build_parallel_config +from mindformers.tools.register.config import MindFormerConfig from mindformers.tools.utils import is_pynative +from mindspore import Tensor, mint +from mindspore.common.api import _pynative_executor +from mindspore.communication import get_rank +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.parallel_state import get_dp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm_mindspore.model_executor.models.attention_mask import ( + LowerTriangularMask) from vllm_mindspore.model_executor.models.model_base import MsModelBase -from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask - logger = init_logger(__name__) + class MfModelBase(MsModelBase): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super(MfModelBase, self).__init__( - vllm_config=vllm_config, prefix=prefix - ) + super().__init__(vllm_config=vllm_config, prefix=prefix) + + self.set_flags = False - self.mf_config = MindFormerConfig(os.getenv("MINDFORMERS_MODEL_CONFIG")) + model_config_path = os.getenv("MINDFORMERS_MODEL_CONFIG") + if model_config_path is None: + raise RuntimeError( + 'For "MindFormers" model backend, environments MINDFORMERS_MODEL_CONFIG should be set!' 
+ ) + + self.mf_config = MindFormerConfig(model_config_path) self.rank_id = get_rank() self.dp_size = get_dp_group() + build_mf_context(self.mf_config) build_parallel_config(self.mf_config) self.mf_config.model.model_config.parallel_config = ( - self.mf_config.parallel_config - ) + self.mf_config.parallel_config) self.mf_config.model.model_config.parallel_config.model_parallel = ( - get_tensor_model_parallel_world_size() - ) + get_tensor_model_parallel_world_size()) self.mf_config.model.model_config.parallel_config.pipeline_stage = 1 self._generate_model_config() - self.casual_mask = LowerTriangularMask(dtype=self.mf_model_config.compute_dtype, - max_model_len=self.model_config.max_model_len) + self.casual_mask = LowerTriangularMask( + dtype=self.mf_model_config.compute_dtype, + max_model_len=self.model_config.max_model_len) self.network, self.lm_head = self._create_network() - affinity_config = self.mf_config.get('context', {}).get('affinity_cpu_list', {}) + affinity_config = self.mf_config.get('context', + {}).get('affinity_cpu_list', {}) if isinstance(affinity_config, dict): ms.runtime.set_cpu_affinity(True, affinity_config) @@ -80,15 +81,18 @@ class MfModelBase(MsModelBase): @abstractmethod def _generate_model_config(self): - raise NotImplementedError("Function _generate_model_config should be Implemented!") + raise NotImplementedError( + "Function _generate_model_config should be Implemented!") @abstractmethod def _create_network(self): - raise NotImplementedError("Function _create_network should be Implemented!") + raise NotImplementedError( + "Function _create_network should be Implemented!") def _set_dynamic_inputs(self): self.network.set_dynamic_inputs() - dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) + dynamic_hidden_states = Tensor( + shape=[None, None], dtype=self.mf_model_config.compute_dtype) self.lm_head.set_inputs(dynamic_hidden_states) def prepare_inputs(self, input_ids, positions): @@ -97,26 +101,26 @@ class MfModelBase(MsModelBase): def update_model_inputs(self, model_inputs, **kwargs): return model_inputs - def forward( - self, - input_ids: Tensor, - positions: Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[Tensor] = None, - **kwargs - ) -> Union[Tensor, IntermediateTensors]: + def forward(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: model_inputs, is_prefill = self.prepare_inputs(input_ids, positions) model_inputs = self.update_model_inputs(model_inputs, **kwargs) - + # enable_mb_split is True in lager EP enable micro-batch and per-dp-bs > 1 - enable_mb_split = self.is_enable_micro_batch_split(is_prefill, model_inputs["q_seq_lens"]) + enable_mb_split = self.is_enable_micro_batch_split( + is_prefill, model_inputs["q_seq_lens"]) if is_prefill: if self.enable_micro_batch: self.network.phase = "prefill" if not enable_mb_split else "prefill_micro_batch" if not self.set_flags or is_pynative() or enable_mb_split: - self.network.add_flags_custom(is_first_iteration=is_first_iteration) - self.network.add_flags_enable_micro_batch(enable_micro_batch=enable_mb_split) + self.network.add_flags_custom(is_first_iteration=True) + self.network.add_flags_enable_micro_batch( + enable_micro_batch=enable_mb_split) else: self.network.phase = "prefill" if not self.set_flags or is_pynative(): @@ -139,11 +143,14 @@ class MfModelBase(MsModelBase): 
) -> Optional[Tensor]: if sampling_metadata is not None: selected_token_indices = sampling_metadata.selected_token_indices - if selected_token_indices is not None and selected_token_indices.numel() <= 0: - logits = ms.mint.zeros((0, self.mf_model_config.vocab_size), - dtype=self.mf_model_config.compute_dtype) + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = ms.mint.zeros( + (0, self.mf_model_config.vocab_size), + dtype=self.mf_model_config.compute_dtype) else: - hidden_states = hidden_states.index_select(0, selected_token_indices) + hidden_states = hidden_states.index_select( + 0, selected_token_indices) logits = self.lm_head(hidden_states) logits = logits.view(-1, logits.shape[-1]) else: @@ -162,12 +169,15 @@ class MfModelBase(MsModelBase): def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: raise NotImplementedError("load_weight not implemented.") - + def is_enable_micro_batch_split(self, is_prefill, q_seq_lens): """Judge enable micro batch """ if self.enable_micro_batch: - is_prefill_cur_dp = mint.ones((1), dtype=ms.int8) if is_prefill else mint.zeros((1), dtype=ms.int8) + is_prefill_cur_dp = mint.ones( + (1), dtype=ms.int8) if is_prefill else mint.zeros( + (1), dtype=ms.int8) is_prefill_all_dp = get_dp_group().all_gather(is_prefill_cur_dp) - return is_prefill_all_dp.sum() == self.dp_size and q_seq_lens.shape[0] > 1 + return is_prefill_all_dp.sum( + ) == self.dp_size and q_seq_lens.shape[0] > 1 else: return False diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3.py b/vllm_mindspore/model_executor/models/mf_models/qwen3.py index a5a8b01d6e906f2c9b8e51c7f3d0af288f05137b..a11a93faaab3d4bb978b40eb9b0372d72ef7b2e1 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen3.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen3.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -16,70 +15,210 @@ # limitations under the License. 
# ============================================================================ -from typing import Iterable, Set, Tuple - -from vllm.config import VllmConfig -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger - -from mindspore import Tensor, JitConfig +from typing import Iterable, Optional, Tuple, Union + +import mindspore as ms +import numpy as np +from mindformers.core.context import build_mf_context +from mindformers.core.parallel_config import build_parallel_config +from mindformers.models.qwen3.configuration_qwen3 import Qwen3Config +from mindformers.models.qwen3.modeling_qwen3 import ( # noqa + Qwen3ForCausalLM as Qwen3ForCausalLM_MF) +from mindformers.tools.utils import is_pynative +from mindspore import Tensor, ops +from mindspore.common.api import _pynative_executor from mindspore.nn.utils import no_init_parameters - -from mindformers.models.llama import LlamaConfig as LlamaConfig_MF -from research.qwen3.qwen3 import ( - ParallelQwen3ForCausalLM as ParallelQwenForCausalLM_MF, -) +from vllm import envs +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors from vllm_mindspore.model_executor.layers.sampler import get_sampler -from vllm_mindspore.model_executor.models.model_base import Fake_Attention -from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase -from vllm_mindspore.model_executor.models.mf_models.qwen3_weight_processor import Qwen3WeightProcessor - +from vllm_mindspore.model_executor.models.attention_mask import ( + LowerTriangularMask) +from vllm_mindspore.model_executor.models.mf_models.config import ( + gen_mf_config, gen_model_config) +from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper, + MsModelBase) logger = init_logger(__name__) -class Qwen3ForCausalLM(MfModelBase): +class Qwen3ForCausalLM(MsModelBase): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super(Qwen3ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix) - self.mf_kvcaches_init = False + super().__init__(vllm_config=vllm_config, prefix=prefix) + self.set_flags = False + + mf_config = gen_mf_config(vllm_config) + mf_config.load_checkpoint = self.get_model_path() + self.mf_config = mf_config + + build_mf_context(self.mf_config) + build_parallel_config(self.mf_config) + + self._generate_model_config() + self.casual_mask = LowerTriangularMask( + dtype=self.mf_model_config.compute_dtype, + max_model_len=self.mf_model_config.seq_length) + self.network, self.lm_head = self._create_network() + + affinity_config = self.mf_config.get('context', + {}).get('affinity_cpu_list', {}) + if isinstance(affinity_config, dict): + ms.runtime.set_cpu_affinity(True, affinity_config) + + self._set_dynamic_inputs() self.sampler = get_sampler() self.set_modules({"model": self.network}) - - self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)] + self.kv_caches = [ + AttentionWrapper() + for _ in range(self.mf_model_config.num_hidden_layers) + ] compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") - for i in range(self.mf_model_config.num_layers): - 
compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + for i in range(self.mf_model_config.num_hidden_layers): + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] - self.set_flags = False + self.cast = ops.Cast() - def _generate_model_config(self): - self.mf_config.load_checkpoint = self.get_model_path() - self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config) - if self.mf_config.moe_config: - self.mf_model_config.moe_config = self.mf_config.moe_config - self.mf_model_config.return_hidden_states = True + def _set_dynamic_inputs(self): + self.network.set_dynamic_inputs() + dynamic_hidden_states = Tensor( + shape=[None, None], dtype=self.mf_model_config.compute_dtype) + self.lm_head.set_inputs(dynamic_hidden_states) - # qwen qkv concat will support in next version - self.mf_model_config.qkv_concat = False - setattr(self.mf_model_config, 'npu_mem_size', -1) - self.mf_config.model.model_config.qkv_concat = False + def prepare_inputs(self, input_ids, positions): + + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is None: + attn_metadata = self._dummy_attention_metadata( + input_ids, positions) + key_cache, value_cache = self.get_kvcache() + if not envs.VLLM_USE_V1: + # V0 + seq_lens = attn_metadata.seq_lens + max_query_len = attn_metadata.max_query_len + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes and max_query_len will be 1. + if self.is_multi_step_chunked_prefill and max_query_len == 1: + query_lens = [1] * len(seq_lens) + else: + query_lens = attn_metadata.query_lens + + seq_lens_np = np.array(seq_lens, dtype=np.int32) + query_lens_np = np.array(query_lens, dtype=np.int32) + kv_cache_lens = seq_lens_np - query_lens_np + if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max( + ) == 0: + is_prefill = True + else: + is_prefill = False + context_lens_tensor = ms.from_numpy(kv_cache_lens) + else: + # V1 + is_prefill = attn_metadata.max_context_lens == 0 + query_lens_np = attn_metadata.q_seq_lens_np + seq_lens_np = attn_metadata.seq_lens_np + context_lens_tensor = attn_metadata.context_lens + + q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32) + position_ids = ms.Tensor(positions, dtype=ms.int32) + attention_mask = self.casual_mask.gen_attention_mask( + is_prefill, positions, query_lens_np) + + model_inputs = {} + model_inputs["input_ids"] = input_ids.astype(ms.int32) + model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) + model_inputs["block_tables"] = attn_metadata.block_tables + model_inputs["slot_mapping"] = attn_metadata.slot_mapping + model_inputs["positions"] = position_ids + model_inputs["q_seq_lens"] = q_seq_lens + model_inputs["attention_mask"] = attention_mask + model_inputs["key_cache"] = key_cache + model_inputs["value_cache"] = value_cache + model_inputs["context_lens_tensor"] = context_lens_tensor + + return model_inputs, is_prefill + + def forward(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[Tensor] = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: + model_inputs, is_prefill = self.prepare_inputs(input_ids, positions) + model_inputs = self.update_model_inputs(model_inputs, **kwargs) + + if is_prefill: + self.network.phase = "prefill" + if not self.set_flags or is_pynative(): + self.network.add_flags_custom_mcore(is_prefill=True) + hidden_states = 
self.network(**model_inputs) + self.network.phase = "increment" + if not self.set_flags or is_pynative(): + self.network.add_flags_custom_mcore(is_prefill=False) + self.set_flags = True + else: + hidden_states = self.network(**model_inputs) + + return hidden_states + + def _generate_model_config(self): + self.mf_model_config = gen_model_config(self.mf_config, Qwen3Config) + logger.debug("=====mf_model_config====\n", self.mf_model_config) def _create_network(self): # Initial network with no_init_parameters(): # Delay initialization - network = ParallelQwenForCausalLM_MF(self.mf_model_config) - return network, network.lm_head - - def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - weight_processor = Qwen3WeightProcessor(self.mf_config, self.network, False) - weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) - + network = Qwen3ForCausalLM_MF(self.mf_model_config) + return network, network.model.output_layer + + def update_model_inputs(self, model_inputs, **kwargs): + return model_inputs + + def compute_logits( + self, + hidden_states: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[Tensor]: + if sampling_metadata is not None: + selected_token_indices = sampling_metadata.selected_token_indices + if selected_token_indices is not None and selected_token_indices.numel( + ) <= 0: + logits = ms.mint.zeros( + (0, self.mf_model_config.vocab_size), + dtype=self.mf_model_config.compute_dtype) + else: + hidden_states = hidden_states.reshape( + (-1, hidden_states.shape[-1])) + hidden_states = hidden_states.index_select( + 0, selected_token_indices) + logits = self.lm_head(hidden_states) + logits = logits.view(-1, logits.shape[-1]) + else: + logits = self.lm_head(hidden_states) + logits = logits.view(-1, logits.shape[-1]) + return logits + + def sample( + self, + logits: Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + _pynative_executor.sync() + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, Tensor]]): + self.network.load_weights(self.mf_config.load_checkpoint) self.network.set_dynamic_inputs() - dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype) - self.lm_head.set_inputs(dynamic_hidden_states) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py deleted file mode 100644 index 338616cafda4f4864c26b58530f7db8d11481d9e..0000000000000000000000000000000000000000 --- a/vllm_mindspore/model_executor/models/mf_models/qwen3_weight_processor.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ - -""" -transform huggingface model to mindspore safetensor. 
-""" -import numpy as np - -import mindspore as ms - -from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor - - -class Qwen3WeightProcessor(Qwen2WeightProcessor): - r""" - Provide Qwen3 Model weight load and shards. - Args: - config (Qwen3Config): The config of Qwen3 model. - network (InferenceQwen3ForCausalLM): The network of Qwen3. - - """ - - def __init__(self, config, network, is_quant): - super().__init__(config, network, is_quant) - - def convert_weight_name(self, weight_name: str): - """replace weight name""" - weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight') - weight_name = weight_name.replace('self_attn.q_proj.', 'attention.wq.') - weight_name = weight_name.replace('self_attn.k_proj.', 'attention.wk.') - weight_name = weight_name.replace('self_attn.v_proj.', 'attention.wv.') - weight_name = weight_name.replace('self_attn.o_proj.', 'attention.wo.') - weight_name = weight_name.replace('self_attn.q_norm.', 'attention.q_norm.') - weight_name = weight_name.replace('self_attn.k_norm.', 'attention.k_norm.') - - weight_name = weight_name.replace('mlp.gate_proj.', 'feed_forward.w1.') - weight_name = weight_name.replace('mlp.down_proj.', 'feed_forward.w2.') - weight_name = weight_name.replace('mlp.up_proj.', 'feed_forward.w3.') - weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.') - weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.') - weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight') - return weight_name - - def infer_process_attention_weight(self, src_hf_dir, layer_id, hf_weight_map): - """infer process attention weight""" - qkv_concat = self.config.model.model_config.qkv_concat - # wq - wq_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.weight" - wq_ms_name = self.convert_weight_name(wq_hf_name) - wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=True, - split_axis=0) - - # wk - wk_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.weight" - wk_ms_name = self.convert_weight_name(wk_hf_name) - wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=True, - split_axis=0) - - # wv - wv_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.weight" - wv_ms_name = self.convert_weight_name(wv_hf_name) - wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=True, - split_axis=0) - - # wq_norm - q_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_norm.weight" - q_norm_ms_name = self.convert_weight_name(q_norm_hf_name) - q_norm_ms_param, _ = self.get_safetensor_from_file(q_norm_hf_name, src_hf_dir, hf_weight_map) - self.parameter_dict[q_norm_ms_name] = ms.Parameter(ms.Tensor(q_norm_ms_param, ms.bfloat16), name=q_norm_ms_name, - requires_grad=False) - - #wk_norm - k_norm_hf_name = f"model.layers.{layer_id}.self_attn.k_norm.weight" - k_norm_ms_name = self.convert_weight_name(k_norm_hf_name) - k_norm_ms_param, _ = self.get_safetensor_from_file(k_norm_hf_name, src_hf_dir, hf_weight_map) - self.parameter_dict[k_norm_ms_name] = ms.Parameter(ms.Tensor(k_norm_ms_param, ms.bfloat16), name=k_norm_ms_name, - requires_grad=False) - - if qkv_concat: - w_qkv_name = f"model.layers.{layer_id}.attention.w_qkv.weight" - w_qkv_param = np.concatenate((wq_ms_param, wk_ms_param, wv_ms_param), axis=0) - w_qkv_param = ms.from_numpy(w_qkv_param).astype(ms.bfloat16) - self.parameter_dict[w_qkv_name] = 
ms.Parameter(w_qkv_param, name=w_qkv_name, requires_grad=False) - - else: - self.parameter_dict[wq_ms_name] = ms.Parameter(ms.from_numpy(wq_ms_param).astype(ms.bfloat16), - name=wq_ms_name, - requires_grad=False) - self.parameter_dict[wk_ms_name] = ms.Parameter(ms.from_numpy(wk_ms_param).astype(ms.bfloat16), - name=wk_ms_name, - requires_grad=False) - self.parameter_dict[wv_ms_name] = ms.Parameter(ms.from_numpy(wv_ms_param).astype(ms.bfloat16), - name=wv_ms_name, - requires_grad=False) - - # wo - wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight" - wo_ms_name = self.convert_weight_name(wo_hf_name) - wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=True, - split_axis=1) - self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16), - name=wo_ms_name, - requires_grad=False) diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index a9c2b9a3e4b3e9bd50509e38fc7b1cfbceecb905..5846f21ae7d817d4366422f16c1dfe9010109868 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -32,9 +32,9 @@ _NATIVE_MODELS = { _MINDFORMERS_MODELS = { "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), # MCore "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepseekV3MTPForCausalLM"), - "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), } _MINDONE_MODELS = { diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 279af7ad104387807a8e676ac96f94ae3454ec54..bf40c1fdbbdd09f1eda214359ebf7b3b401b2b96 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -16,22 +16,15 @@ # ============================================================================ from dataclasses import dataclass, field -from typing import List, Tuple, Union, Mapping, Optional, Iterable -from functools import wraps -from typing import List, Tuple +from typing import Iterable, List, Mapping, Optional, Tuple, Union import mindspore as ms -from mindspore import jit, mint +from mindspore import mint, ops from vllm.sequence import IntermediateTensors from vllm_mindspore.multimodal.inputs import NestedTensors from vllm_mindspore.utils import get_valid_dtype -import mindspore as ms -from mindspore import mint -from mindspore import ops - - WeightsMapping = Mapping[str, Optional[str]] """If a key maps to a value of `None`, the corresponding weight is ignored.""" @@ -73,6 +66,8 @@ class WeightsMapper: ) -> Iterable[Tuple[str, ms.Tensor]]: return ((out_name, data) for name, data in weights if (out_name := self._map_name(name)) is not None) + + enforce_eager = False @@ -166,6 +161,7 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): ########################### for multi model ########################### + def _flatten_embeddings(embeddings: NestedTensors) -> ms.Tensor: """ Recursively flattens and concatenates NestedTensors on all but the last @@ -252,7 +248,7 @@ def merge_multimodal_embeddings( """ if isinstance(placeholder_token_id, list): placeholder_token_id = ms.Tensor(placeholder_token_id, - device=input_ids.device) + device=input_ids.device) return _merge_multimodal_embeddings( inputs_embeds, ms.numpy.isin(input_ids, placeholder_token_id), diff --git a/vllm_mindspore/utils.py 
b/vllm_mindspore/utils.py index 153589ed6dc3affb423fdcd0bc160019db621955..920bb23066966c9de2ca8058f69567fe61f9aad5 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -19,8 +19,8 @@ import contextlib import gc import os import sys -from typing import (TYPE_CHECKING, Callable, Generator, List, Optional, Tuple, - Union) +from enum import Enum +from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union import numpy as np import torch @@ -30,11 +30,10 @@ if TYPE_CHECKING: else: Library = None -from vllm.logger import init_logger - import mindspore as ms from mindspore import dtype as mstype from mindspore.common.initializer import Zero +from vllm.logger import init_logger from vllm.utils import (TORCH_DTYPE_TO_NUMPY_DTYPE, MemoryProfilingResult, MemorySnapshot, T, make_ndarray_with_pad) @@ -142,29 +141,41 @@ STR_DTYPE_TO_MS_DTYPE = { } +class vllmModelBackendEnum(str, Enum): + """Define the variable Enum of vLLM_MODEL_BACKEND""" + MF = 'MindFormers' + MIND_ONE = 'MindONE' + + def ascend_is_initialized(): # Just return true for check. return True def is_mindformers_model_backend(): - return (os.getenv("vLLM_MODEL_BACKEND") # noqa: SIM112 - and - os.environ["vLLM_MODEL_BACKEND"] == "MindFormers" # noqa: SIM112 - ) + vllm_model_backend = os.getenv("vLLM_MODEL_BACKEND") # noqa: SIM112 + if vllm_model_backend: + try: + vllmModelBackendEnum(vllm_model_backend) + return vllm_model_backend == vllmModelBackendEnum.MF + except ValueError as exc: + allowed_values = [member.value for member in vllmModelBackendEnum] + raise ValueError( + f"Illegal value of vLLM_MODEL_BACKEND '{vllm_model_backend}'," + f" allowed_values: {', '.join(allowed_values)}") from exc + else: + return False def is_mindone_model_backend(): return (os.getenv("vLLM_MODEL_BACKEND") # noqa: SIM112 - and os.environ["vLLM_MODEL_BACKEND"] == "MindONE" # noqa: SIM112 - ) + and os.environ["vLLM_MODEL_BACKEND"] # noqa: SIM112 + == vllmModelBackendEnum.MIND_ONE) def check_ready(): - import vllm.envs as envs from mindspore import set_context - # Common environment variables of predict. set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) default_env = { @@ -179,15 +190,6 @@ def check_ready(): if is_mindformers_model_backend(): logger.info("Run with Mindformers backend!") - necessary_envs = ("MINDFORMERS_MODEL_CONFIG", ) - lost_envs = [ - env_item for env_item in necessary_envs if not os.getenv(env_item) - ] - - if lost_envs: - raise RuntimeError( - f'For "MindFormers" model backend, environments {str(lost_envs)} should be set!' - ) elif is_mindone_model_backend(): logger.info("Run with MindONE backend!") else: