diff --git a/tests/st/python/cases_parallel/similarity.py b/tests/st/python/cases_parallel/similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdae0d90d39150efd4e650160921ecf663e7bf4 --- /dev/null +++ b/tests/st/python/cases_parallel/similarity.py @@ -0,0 +1,58 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import math + +import jieba +import numpy as np + + +def _get_all_words(standard_cut_infer_ret_list, test_cut_infer_ret_list): + all_words = [] + for s_cut in standard_cut_infer_ret_list: + if s_cut not in all_words: + all_words.append(s_cut) + for t_cut in test_cut_infer_ret_list: + if t_cut not in all_words: + all_words.append(t_cut) + return all_words + + +def _get_word_vector(standard_cut_infer_ret_list, test_cut_infer_ret_list, + all_words): + la_standard = [] + lb_test = [] + for word in all_words: + la_standard.append(standard_cut_infer_ret_list.count(word)) + lb_test.append(test_cut_infer_ret_list.count(word)) + return la_standard, lb_test + + +def _get_calculate_cos(la_standard, lb_test): + laa = np.array(la_standard) + lbb = np.array(lb_test) + cos = (np.dot(laa, lbb.T)) / ((math.sqrt(np.dot(laa, laa.T))) * + (math.sqrt(np.dot(lbb, lbb.T)))) + return np.round(cos, 2) + + +def compare_distance(x1, x2, bench_sim=0.95): + """compare distance""" + y1 = list(jieba.cut(x1)) + y2 = list(jieba.cut(x2)) + all_words = _get_all_words(y1, y2) + laa, lbb = _get_word_vector(y1, y2, all_words) + sim = _get_calculate_cos(laa, lbb) + print("calculate sim is:{}".format(str(sim))) + assert sim >= bench_sim diff --git a/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py b/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..d776c8d93dacc2fc6a3fc28453083c2de9ba320c --- /dev/null +++ b/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""test mf qwen2.5 vl 7B.""" +import os + +from PIL import Image + +from tests.st.python import set_env +from tests.st.python.cases_parallel.similarity import compare_distance + +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "HCCL_OP_EXPANSION_MODE": "AIV", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", +} +# set env +env_manager.setup_ai_environment(env_vars) +# isort: off +import vllm_mindspore +from vllm import LLM, SamplingParams + +# isort: on + +PROMPT_TEMPLATE = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" + "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + "What is in the image?<|im_end|>\n" + "<|im_start|>assistant\n") + + +def pil_image() -> Image.Image: + image_path = "images/1080p.jpeg" + return Image.open(image_path) + + +def test_qwen2_5_vl_7b_v1(): + """ + test case qwen2.5 vl 7B + """ + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": pil_image() + }, + }] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=128, top_k=1) + + # Create an LLM. + llm = LLM( + model="/home/workspace/mindspore_dataset/weight/Qwen2.5-VL-7B-Instruct", + gpu_memory_utilization=0.9, + tensor_parallel_size=2, + max_model_len=4096, + max_num_seqs=32, + max_num_batched_tokens=32) + except_list = [ + 'The image depicts a serene and picturesque landscape. It features a lush green meadow with ' + 'wildflowers in the foreground. In the middle ground, there are small wooden huts, possibly used for' + ' storage or as simple shelters. Beyond the meadow, there is a calm body of water, likely a lake,' + ' surrounded by dense forests. In the background, majestic mountains rise, their peaks partially ' + 'covered with snow, suggesting a high-altitude location. The sky is partly cloudy, with soft ' + 'lighting that enhances the tranquil and idyllic atmosphere of the scene. This type of landscape ' + 'is often associated with alpine regions.' + ] + + for i in range(3): + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(inputs, sampling_params) + # Print the outputs. + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text + print( + f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}" + ) + compare_distance(generated_text, except_list[0], bench_sim=0.95) + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/images/1080p.jpeg b/tests/st/python/images/1080p.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..0d298985cf4468902c27eaca2f23f74dae8c80ab Binary files /dev/null and b/tests/st/python/images/1080p.jpeg differ diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index c5e2deb5cd94932137659264b324f11cffb3b38d..daeeb2a65d6af063da61434ae8757e42fee84b54 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# type: ignore # isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. 
@@ -265,13 +266,19 @@ vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial ######### for multi-model from vllm_mindspore.inputs.registry import call_hf_processor from vllm.inputs.registry import InputProcessingContext + InputProcessingContext.call_hf_processor = call_hf_processor -from vllm_mindspore.multimodal.inputs import as_kwargs +from vllm_mindspore.multimodal.inputs import as_kwargs, from_items, MultiModalFieldElem from vllm.multimodal.inputs import MultiModalKwargs + MultiModalKwargs.as_kwargs = as_kwargs +MultiModalKwargs.from_items = from_items + +vllm.multimodal.inputs.MultiModalFieldElem = MultiModalFieldElem + +from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding # type: ignore[attr-defined] -from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding # patch for V1 @@ -284,6 +291,7 @@ from vllm_mindspore.v1.spec_decode import eagle update_modules("vllm.v1.spec_decode.eagle", eagle) from vllm_mindspore.v1.attention.backends import ms_attn + update_modules("vllm.v1.attention.backends.flash_attn", ms_attn) import vllm.v1.worker.gpu_model_runner @@ -292,11 +300,16 @@ from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs +from vllm_mindspore.v1.worker.gpu_model_runner import _calc_mrope_positions + +vllm.v1.worker.gpu_model_runner.GPUModelRunner._calc_mrope_positions = _calc_mrope_positions + from vllm_mindspore.v1.worker.gpu_model_runner import _update_states vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache, get_kv_cache_spec + vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache vllm.v1.worker.gpu_model_runner.GPUModelRunner.get_kv_cache_spec = get_kv_cache_spec @@ -369,6 +382,7 @@ Worker.compile_or_warm_up_model = compile_or_warm_up_model from vllm_mindspore.v1.core.sched.scheduler import update_from_output from vllm.v1.core.sched.scheduler import Scheduler + Scheduler.update_from_output = update_from_output from .utils import check_ready diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index 00447432e546516bf4d8629c374ac36e491041e8..a24d49595c7d7698331492562e14e7d9c65c08b6 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# encoding: utf-8 # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -16,20 +15,16 @@ # limitations under the License. 
# ============================================================================ - # 该文件实现底层通信接口, 要求动静统一, 最后才可以在网络中入图。 # 不要去照搬mindspeed的, 因为训练当中包含太多的特性, 推理只需要非常简单的通信,可以提升性能。 from typing import Any, Dict, Optional, Union -import mindspore as ms from mindspore import Tensor, nn, ops -from mindspore.communication.comm_func import (all_gather_into_tensor, - all_reduce, broadcast, - gather_into_tensor, recv, send) +from mindspore.communication.comm_func import all_reduce, broadcast from vllm.distributed.parallel_state import ( - get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, get_tp_group, get_world_group) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group, get_world_group) def tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor: @@ -40,47 +35,6 @@ def tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor: return output -def tensor_model_parallel_all_gather(input_: Tensor, - dim: int = -1) -> Tensor: - if get_tensor_model_parallel_world_size() == 1: - return input_ - """All-gather the input tensor across model parallel group.""" - output, _ = all_gather_into_tensor(input_, group=get_tp_group()) - input_size = input_.shape - if dim < 0: - # Convert negative dim to positive. - dim += len(input_size) - # Reshape - output_tensor = output_tensor.reshape((world_size, ) + input_size) - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) - return output - - -def tensor_model_parallel_gather(input_: Tensor, - dst: int = 0, - dim: int = -1) -> Optional[Tensor]: - if get_tensor_model_parallel_world_size() == 1: - return input_ - """Gather the input tensor across model parallel group.""" - if dim < 0: - # Convert negative dim to positive. - dim += len(input_.shape) - if dim != 0: - input_ = input_.moveaxis(dim, 0) - _dst = get_world_rank_from_tp_group_rank(dst) - output = gather_into_tensor(input_, dst=_dst, group=get_tp_group()) - if get_tensor_model_parallel_rank() == dst: - if dim != 0: - output = output.moveaxis(0, dim) - else: - output = None - return output - - def broadcast_tensor(tensor, src: int = 0): # broadcast tensor to the world group return broadcast(tensor, src, group=get_world_group()) @@ -95,15 +49,6 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[Tensor, # return get_tp_group().broadcast_tensor_dict(tensor_dict, src) -def send_to_next_pp_rank(tensor): - send(tensor, next_pp_rank(), group=get_pp_group()) - - -def recv_from_prev_pp_rank(tensor): - output = recv(tensor, prev_pp_rank(), group=get_pp_group()) - return output - - class ReduceFromModelParallelRegion(nn.Cell): "All reduce the input from the model parallel region." @@ -122,7 +67,7 @@ class ReduceFromModelParallelRegion(nn.Cell): class GatherFromModelParallelRegion(nn.Cell): - "Gather the input from model parallel region and concatinate." + "Gather the input from model parallel region and concatenate." def __init__(self): super().__init__() @@ -138,7 +83,32 @@ class GatherFromModelParallelRegion(nn.Cell): # Size and dimension. 
if self.world_size == 1: return input_ - output = ops.CollectiveGather(dest_rank=dst, group=self.tp_group)(input_.transpose(2, 1, 0)) + output = ops.CollectiveGather(dest_rank=dst, + group=self.tp_group)(input_.transpose( + 2, 1, 0)) if self.tp_rank != dst: return ops.depend(ops.zeros_like(input_), output) return output.transpose(2, 1, 0) + + +class AllGatherFromModelParallelRegion(nn.Cell): + """ + Gather the input from world parallel region and concatenate, simultaneously perform + transpose operation on input. + """ + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + self.all_gather_into_tensor = ops.AllGather(group=self.tp_group) + + def construct(self, input_): + # Size and dimension. + if self.world_size == 1: + return input_ + input_ = ops.swapaxes(input_, 0, -1) + output = self.all_gather_into_tensor(input_) + output = ops.swapaxes(output, 0, -1) + return output diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index ff6ea4da22a78e6db7d7f5a31f18b1ea990f9357..7470233475254ccccdf677d5627a6f3bd59f6408 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -16,12 +18,15 @@ # ============================================================================ import math +import numpy as np + from typing import Any, Dict, List, Optional, Tuple, Union import mindspore -import numpy as np from mindspore import Tensor, mint, nn, ops from mindspore.common import dtype as mstype +from mindspore.ops.auto_generate.gen_ops_prim import SliceExt + from transformers import PretrainedConfig from vllm.config import get_current_vllm_config @@ -474,9 +479,9 @@ class MRotaryEmbedding(RotaryEmbedding): context_len: int, seq_len: int, ) -> mindspore.Tensor: - return ops.arange( - mrope_position_delta + context_len, - mrope_position_delta + seq_len, + return mint.arange( + int(mrope_position_delta + context_len), + int(mrope_position_delta + seq_len), ).broadcast_to((3, -1)) @@ -531,52 +536,60 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): query: [num_tokens, num_heads * head_size] key: [num_tokens, num_kv_heads * head_size] """ + half_rotary_dim = self.rotary_dim // 2 # prefill if is_prefill: num_tokens = positions.shape[-1] cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim // - 2], sin[..., :self.rotary_dim // 2] + cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1) + sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1) if positions.ndim == 2: - cos_l = ops.split(cos, self.mrope_section, axis=-1) - sin_l = ops.split(sin, self.mrope_section, axis=-1) + cos_l = mint.split(cos, self.mrope_section, dim=-1) + sin_l = mint.split(sin, self.mrope_section, dim=-1) cos, sin = (), () for i in range(len( self.mrope_section)): # type: ignore[arg-type] - cos += (cos_l[i][i], ) - sin += (sin_l[i][i], ) + cos_l_select = mint.index_select(cos_l[i], 0, + Tensor([i])).squeeze(0) + cos += (cos_l_select, ) + sin_l_select = mint.index_select(sin_l[i], 0, + Tensor([i])).squeeze(0) + sin += (sin_l_select, ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - 
query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] + query_rot = SliceExt()(query, -1, 0, self.rotary_dim, 1) + query_pass = SliceExt()(query, -1, self.rotary_dim, + query_shape[-1], 1) query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = ops.cat((query_rot, query_pass), axis=-1).view(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + key_rot = SliceExt()(key, -1, 0, self.rotary_dim, 1) + key_pass = SliceExt()(key, -1, self.rotary_dim, key_shape[-1], 1) key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key = ops.cat((key_rot, key_pass), axis=-1).view(key_shape) return query, key # decode - if positions.ndim == 2 and positions.shape[0] == len( - self.mrope_section): # type: ignore[arg-type] - num_tokens = positions.shape[-1] + if positions.ndim == 2: cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim // - 2], sin[..., :self.rotary_dim // 2] - cos_l = ops.split(cos, self.mrope_section, axis=-1) - sin_l = ops.split(sin, self.mrope_section, axis=-1) + cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1) + sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1) + cos_l = mint.split(cos, self.mrope_section, dim=-1) + sin_l = mint.split(sin, self.mrope_section, dim=-1) cos, sin = (), () for i in range(len(self.mrope_section)): # type: ignore[arg-type] - cos += (cos_l[i][i], ) - sin += (sin_l[i][i], ) + cos_l_select = mint.index_select(cos_l[i], 0, + Tensor([i])).squeeze(0) + cos += (cos_l_select, ) + sin_l_select = mint.index_select(sin_l[i], 0, + Tensor([i])).squeeze(0) + sin += (sin_l_select, ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) freqs_cos = ops.cat([cos, cos], axis=-1).squeeze(1) diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py index ccfcfdb3dbb728c5da3cc251501958a1fbf4a670..0df3c30ab3caa290aae2b93501c4d8415d84819a 100644 --- a/vllm_mindspore/model_executor/models/attention_mask.py +++ b/vllm_mindspore/model_executor/models/attention_mask.py @@ -1,3 +1,5 @@ +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,15 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - """ infer attention mask. """ import numpy as np - +import mindspore as ms from mindspore import Tensor, mint from mindspore import dtype as mstype - r""" PA:ASD-V2.1.5 1.MLA + Q_seqlen =1: no mask.(BF16 mask(0/-10000), FP16 mask(0/-10000)). 
@@ -48,20 +48,29 @@ class LowerTriangularMask: prefill_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0 - self.prefill_mask = Tensor(np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) * prefill_mask_coeff, - dtype=self.dtype) + self.prefill_mask = Tensor( + np.triu(np.ones(shape=(128, 128), dtype=np.float16), k=1) * + prefill_mask_coeff, + dtype=self.dtype) - self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1), + self.decode_mask = Tensor(np.triu(np.ones( + shape=(self.max_model_len, self.max_model_len), dtype=np.int8), + k=1), dtype=self.dtype) * -10000 self.hard_mask = mint.zeros((1, 1), dtype=dtype) - def gen_attention_mask(self, is_prefill, position_ids, query_lens): + def gen_attention_mask(self, + is_prefill, + position_ids, + query_lens, + attn_metadata=None): if is_prefill: attention_mask = self.prefill_mask else: if max(query_lens) > 1: - attention_mask = mint.index_select(self.decode_mask, 0, position_ids) + attention_mask = mint.index_select(self.decode_mask, 0, + position_ids) else: attention_mask = self.hard_mask return attention_mask @@ -79,5 +88,44 @@ class MLALowerTriangularMask(LowerTriangularMask): super().__init__(dtype, max_model_len) decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0 - self.decode_mask = Tensor(np.triu(np.ones(shape=(self.max_model_len, self.max_model_len), dtype=np.int8), k=1), + self.decode_mask = Tensor(np.triu(np.ones( + shape=(self.max_model_len, self.max_model_len), dtype=np.int8), + k=1), dtype=self.dtype) * decode_mask_coeff + + +class MultiModalLowerTriangularMask(LowerTriangularMask): + r""" + Provide multi modal Infer model attention mask. + Args: + dtype (ms dtype): The compute type of Infer model. + max_model_len (int): The max model length of Infer model. + """ + + def __init__(self, dtype, max_model_len): + + super().__init__(dtype, max_model_len) + + def gen_attention_mask(self, + is_prefill, + position_ids, + query_lens, + attn_metadata=None): + if is_prefill: + attention_mask = self.prefill_mask + else: + if max(query_lens) > 1: + seq_lens_np = attn_metadata.seq_lens_np + context_lens_np = attn_metadata.context_lens.asnumpy() + mm_position_ids_list = [] + for i in range(len(seq_lens_np)): + mm_position_ids_list.append( + np.arange(context_lens_np[i], seq_lens_np[i])) + mm_position_ids = np.concatenate(mm_position_ids_list) + mm_position_ids = ms.Tensor(mm_position_ids, + dtype=position_ids.dtype) + attention_mask = mint.index_select(self.decode_mask, 0, + mm_position_ids) + else: + attention_mask = self.hard_mask + return attention_mask diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4a960845109b2c2c8fe60597e11c435b03486135..d2db9794d2d4fbcef588ead23287a0135bc51181 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. 
# @@ -17,12 +19,9 @@ import os from abc import abstractmethod -from typing import Iterable, List, Optional, Set, Tuple, Union, Dict +from typing import Iterable, Optional, Set, Tuple, Union, Dict import numpy as np -import mindspore as ms -from mindspore import Tensor, mutable, nn - from vllm.attention.backends.abstract import AttentionType from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import get_forward_context @@ -41,6 +40,7 @@ from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata class AttentionWrapper: + def __init__(self): vllm_config = get_current_vllm_config() block_size = vllm_config.cache_config.block_size @@ -49,13 +49,10 @@ class AttentionWrapper: head_size = vllm_config.model_config.get_head_size() num_block = 0 self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [ - ( - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] + self.kv_cache = [( + ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), + ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), + ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] self.attn_type = AttentionType.DECODER # add for v1 @@ -67,11 +64,13 @@ class AttentionWrapper: class MLAAttentionWrapper(AttentionWrapper): + def __init__(self): super().__init__() vllm_config = get_current_vllm_config() self.kv_cache = [ - (ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype),) + (ms.mint.zeros(self.kv_shape, + dtype=vllm_config.model_config.dtype), ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] @@ -121,7 +120,7 @@ class MsModelBase: ) def set_modules(self, model_dicts: Dict[str, nn.Cell]): - self.modules_dict = model_dicts + self.modules_dict = model_dicts # type: ignore[assignment] def _check_modules_valid(self): if self.modules_dict is None: @@ -130,7 +129,8 @@ class MsModelBase: def named_parameters(self): self._check_modules_valid() - for cell_name, module in self.modules_dict.items(): + for cell_name, module in self.modules_dict.items( + ): # type: ignore[attr-defined] for par_name, par in module.parameters_and_names(): if cell_name != "self": par_name = cell_name + "." + par_name @@ -141,7 +141,8 @@ class MsModelBase: self._check_modules_valid() params_dict = dict() - for name, module in self.modules_dict.items(): + for name, module in self.modules_dict.items( + ): # type: ignore[attr-defined] module_params = module.parameters_dict() if name != "self": new_module_params = dict() @@ -155,7 +156,8 @@ class MsModelBase: def named_modules(self, remove_duplicate: bool = True): self._check_modules_valid() - for name, module in self.modules_dict.items(): + for name, module in self.modules_dict.items( + ): # type: ignore[attr-defined] for module_name, sub_module in module.cells_and_names(): if name != "self": module_name = name + "." 
+ module_name @@ -177,7 +179,8 @@ class MsModelBase: def eval(self): self._check_modules_valid() - for _, module in self.modules_dict.items(): + for _, module in self.modules_dict.items( + ): # type: ignore[attr-defined] module.set_train(False) return self @@ -190,13 +193,15 @@ class MsModelBase: inputs_embeds: Optional[Tensor] = None, previous_hidden_states: Optional[Tensor] = None, spec_step_idx: int = 0, + **kwargs, ) -> Union[Tensor, IntermediateTensors]: return self.forward(input_ids, positions, intermediate_tensors, inputs_embeds, previous_hidden_states=previous_hidden_states, - spec_step_idx=spec_step_idx) + spec_step_idx=spec_step_idx, + **kwargs) def forward(self, input_ids: Tensor, @@ -211,9 +216,9 @@ class MsModelBase: value_cache = [] forward_context = get_forward_context() for i in range(self.config.num_hidden_layers): - k_cache = self.kv_caches[i].kv_cache[ + k_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined] forward_context.virtual_engine][0] - v_cache = self.kv_caches[i].kv_cache[ + v_cache = self.kv_caches[i].kv_cache[ # type: ignore[attr-defined] forward_context.virtual_engine][1] key_cache.append(k_cache) value_cache.append(v_cache) @@ -238,11 +243,16 @@ class MsModelBase: @abstractmethod def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - raise NotImplementedError("Function load_weights should be Implemented!") - + raise NotImplementedError( + "Function load_weights should be Implemented!") def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor): - input_len = input_ids.shape[0] + if input_ids is not None: + input_len = input_ids.shape[0] + elif positions is not None: + # input_ids is None in multi modal model with v1 arch + input_len = positions.shape[-1] + max_seq_len = ms.Tensor(input_len, dtype=ms.int32) seq_lengths = ms.Tensor([input_len], dtype=ms.int32) q_seq_lens_np = np.array([input_len], dtype=np.int32) @@ -263,14 +273,13 @@ class MsModelBase: # To enforce prefill and decode are both complied in warmup process. # So set max_context_lens to 0 for prefill and 1 for decode. 
max_context_lens=0 if not self.set_flags else 1, - query_start_loc = None - ) - + query_start_loc=None) def prepare_base_inputs(self, input_ids, positions): attn_metadata = get_forward_context().attn_metadata if attn_metadata is None: - attn_metadata = self._dummy_attention_metadata(input_ids, positions) + attn_metadata = self._dummy_attention_metadata( + input_ids, positions) key_cache, value_cache = self.get_kvcache() if not envs.VLLM_USE_V1: # V0 @@ -287,7 +296,8 @@ class MsModelBase: seq_lens_np = np.array(seq_lens, dtype=np.int32) query_lens_np = np.array(query_lens, dtype=np.int32) kv_cache_lens = seq_lens_np - query_lens_np - if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max() == 0: + if attn_metadata.num_decode_tokens == 0 and kv_cache_lens.max( + ) == 0: is_prefill = True else: is_prefill = False @@ -296,13 +306,16 @@ class MsModelBase: is_prefill = attn_metadata.max_context_lens == 0 query_lens_np = attn_metadata.q_seq_lens_np seq_lens_np = attn_metadata.seq_lens_np - + + if input_ids is not None: + input_ids = input_ids.astype(ms.int32) q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32) position_ids = ms.Tensor(positions, dtype=ms.int32) - attention_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np) + attention_mask = self.casual_mask.gen_attention_mask( # type: ignore[attr-defined] + is_prefill, positions, query_lens_np, attn_metadata) model_inputs = {} - model_inputs["input_ids"] = input_ids.astype(ms.int32) + model_inputs["input_ids"] = input_ids model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) model_inputs["block_tables"] = attn_metadata.block_tables model_inputs["slot_mapping"] = attn_metadata.slot_mapping @@ -316,33 +329,63 @@ class MsModelBase: class NativeModel(MsModelBase): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__(vllm_config=vllm_config, prefix=prefix) self.quant_config = vllm_config.quant_config if vllm_config.lora_config is not None: # native model lora only support pynative mode now vllm_config.model_config.enforce_eager = True - self.is_graph_mode = False if vllm_config.model_config.enforce_eager else True + self.is_graph_mode = bool(not vllm_config.model_config.enforce_eager) self.prev_prefill = False self.run_model = None - def common_preprocess(self, vllm_config, prefix = ""): - self.set_modules({"model": self.model, "lm_head": self.lm_head}) - - self.casual_mask = LowerTriangularMask(dtype=self.model_config.dtype, - max_model_len=self.model_config.max_model_len) - self.kv_caches = [AttentionWrapper() for i in range(self.config.num_hidden_layers)] + def common_preprocess(self, vllm_config, prefix=""): + self.set_modules({ + "model": self.model, + "lm_head": self.lm_head + }) # type: ignore[attr-defined] + + self.casual_mask = LowerTriangularMask( + dtype=self.model_config.dtype, + max_model_len=self.model_config.max_model_len) + self.kv_caches = [ + AttentionWrapper() for i in range(self.config.num_hidden_layers) + ] compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") for i in range(self.config.num_hidden_layers): - compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] + + def set_model_inputs(self, input_ids, position_ids, intermediate_tensors, + inputs_embeds, is_prefill): + if input_ids is None: + dyn_input_ids = None + else: + dyn_input_ids = 
ms.Tensor(shape=[None] * input_ids.ndim, + dtype=mstype.int32) + + if position_ids is None: + dyn_position_ids = None + else: + dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim, + dtype=mstype.int32) + if inputs_embeds is None: + dyn_inputs_embeds = None + else: + dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim, + dtype=inputs_embeds.dtype) - def set_model_inputs(self, is_prefill): - dyn_input_ids = Tensor(shape=[None], dtype=mstype.int32) - dyn_position_ids = Tensor(shape=[None], dtype=mstype.int32) + if intermediate_tensors is None: + dyn_intermediate_tensors = None + else: + dyn_intermediate_tensors = ms.Tensor( + shape=[None] * intermediate_tensors.ndim, + dtype=intermediate_tensors.dtype) block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) @@ -359,19 +402,19 @@ class NativeModel(MsModelBase): dyn_key_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) dyn_value_cache = Tensor(shape=kv_cache_shape, dtype=kv_cache_dtype) dyn_key_caches = mutable([dyn_key_cache for _ in range(num_layers)]) - dyn_value_caches = mutable([dyn_value_cache for _ in range(num_layers)]) - - dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) - dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.model_config.dtype) - dyn_batch_valid_length = Tensor(shape=[None,], dtype=mstype.int32) - dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) + dyn_value_caches = mutable( + [dyn_value_cache for _ in range(num_layers)]) + + dyn_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) + dynamic_attention_mask = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dyn_intermediate_tensors = None - dyn_inputs_embeds = None self.model.set_inputs( dyn_input_ids, dyn_position_ids, - dyn_key_caches, + dyn_key_caches, # type: ignore[attr-defined] dyn_value_caches, is_prefill, dyn_slot_mapping, @@ -380,11 +423,17 @@ class NativeModel(MsModelBase): dyn_q_seq_lens, dyn_block_tables, dyn_intermediate_tensors, - dyn_inputs_embeds - ) + dyn_inputs_embeds) + + dynamic_hidden_states = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + self.lm_head.set_inputs( + dynamic_hidden_states) # type: ignore[attr-defined] - def prepare_inputs(self, input_ids, positions, intermediate_tensors, inputs_embeds): - model_inputs, is_prefill = self.prepare_base_inputs(input_ids, positions) + def prepare_inputs(self, input_ids, positions, intermediate_tensors, + inputs_embeds): + model_inputs, is_prefill = self.prepare_base_inputs( + input_ids, positions) # for multimodal model model_inputs["intermediate_tensors"] = intermediate_tensors @@ -392,27 +441,31 @@ class NativeModel(MsModelBase): return model_inputs, is_prefill - def exec_model( - self, - input_ids: Tensor, - positions: Tensor, - intermediate_tensors: IntermediateTensors = None, - inputs_embeds: Tensor = None, - **kwargs - ): - model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, intermediate_tensors, inputs_embeds) + def exec_model(self, + input_ids: Tensor, + positions: Tensor, + intermediate_tensors: IntermediateTensors = None, + inputs_embeds: Tensor = None, + **kwargs): + model_inputs, is_prefill = self.prepare_inputs(input_ids, positions, + intermediate_tensors, + inputs_embeds) if self.prev_prefill != is_prefill and self.is_graph_mode: - 
self.set_model_inputs(is_prefill) + self.set_model_inputs(input_ids, positions, intermediate_tensors, + inputs_embeds, is_prefill) self.prev_prefill = is_prefill - # for dummy_attention_metadata + # for dummy_attention_metadata if is_prefill and not self.set_flags: self.set_flags = True if self.run_model is None: - self.run_model = ms.jit(function=self.model, jit_level='O0') if self.is_graph_mode else self.model - model_output = self.run_model( + self.run_model = ms.jit( + function=self.model, # type: ignore[attr-defined] + jit_level='O0' + ) if self.is_graph_mode else self.model # type: ignore[attr-defined] + model_output = self.run_model( # type: ignore[misc] input_ids=model_inputs["input_ids"], positions=model_inputs["position_ids"], key_caches=model_inputs["key_cache"], @@ -425,6 +478,6 @@ class NativeModel(MsModelBase): block_tables=model_inputs["block_tables"], intermediate_tensors=model_inputs["intermediate_tensors"], inputs_embeds=model_inputs["inputs_embeds"], - ) + ) return model_output diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 27cf2b234152f994245f066edb77a4cccc2f29fa..87c54c2126c2456cf7ec73ec123f8b3050386570 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -16,7 +16,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import numpy as np from typing import (TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union) @@ -25,19 +24,14 @@ if TYPE_CHECKING: else: Qwen2Config = None -import mindspore as ms from mindspore import Parameter, Tensor, mint, nn -from mindspore.common import dtype as mstype -import vllm.envs as envs from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.forward_context import get_forward_context from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.sequence import IntermediateTensors -from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SwiGLU @@ -53,15 +47,11 @@ from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm_mindspore.model_executor.model_loader.weight_utils import \ default_weight_loader -from vllm_mindspore.model_executor.models.attention_mask import \ - LowerTriangularMask -from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper, - NativeModel) +from vllm_mindspore.model_executor.models.model_base import NativeModel from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE class Qwen2MLP(nn.Cell): @@ -469,14 +459,12 @@ class Qwen2ForCausalLM(NativeModel, SupportsLoRA): self.common_preprocess(vllm_config, prefix) - def forward( - self, - input_ids: Tensor, - positions: Tensor, - intermediate_tensors: IntermediateTensors = None, - inputs_embeds: Tensor = None, - **kwargs - ) -> Union[Tensor, IntermediateTensors]: + def forward(self, + 
input_ids: Tensor, + positions: Tensor, + intermediate_tensors: IntermediateTensors = None, + inputs_embeds: Tensor = None, + **kwargs) -> Union[Tensor, IntermediateTensors]: hidden_states = self.exec_model(input_ids, positions, intermediate_tensors, inputs_embeds) return hidden_states diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..e7fc1be50aa19b984319ffbfc03f89061b2f502d --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -0,0 +1,1079 @@ +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# isort:skip_file +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_5_vl.py +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ +"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" +import os +from functools import partial +from typing import Callable, Iterable, Mapping, Optional, Set, Tuple, Union, Dict, Any + +import math +import mindspore as ms +import mindspore.nn as nn +import mindspore.mint as mint +import mindspore.ops as ops +import mindspore.mint.nn.functional as F +from mindspore import dtype as mstype + +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor +from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm_mindspore.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm_mindspore.model_executor.model_loader.weight_utils import default_weight_loader +from vllm_mindspore.model_executor.models.model_base import NativeModel, AttentionWrapper +from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal +from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model # type: ignore[attr-defined] +from vllm_mindspore.model_executor.models.utils import PPMissingLayer, WeightsMapper, maybe_prefix, \ + merge_multimodal_embeddings +from vllm_mindspore.model_executor.models.attention_mask import MultiModalLowerTriangularMask +from vllm_mindspore.distributed.communication_op import AllGatherFromModelParallelRegion + +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs, \ + Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLVideoPixelInputs, \ + Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLProcessingInfo + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.parse import MultiModalDataItems +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from vllm.distributed import utils as dist_utils +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope + +logger = init_logger(__name__) + +_ACTIVATION_REGISTRY = {"silu": F.silu} + +# === Vision Inputs === # + + +class _Qwen2VLMultiModalProcessor(Qwen2VLMultiModalProcessor): + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + placeholder = { + "image": vocab[hf_processor.image_token], + "video": vocab[hf_processor.video_token], + } + + merge_length = 
image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, ms.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_qwen2vl, + modality=modality), + ) for modality in ("image", "video") + ] + + +# === Vision Encoder === # + + +class Qwen2_5_VisionMLP(nn.Cell): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.gate_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + params_dtype=ms.bfloat16) + self.up_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + params_dtype=ms.bfloat16) + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + params_dtype=ms.bfloat16) + self.act_fn = act_fn + + def construct(self, x: ms.Tensor): + x_gate, _ = self.gate_proj(x) + x_gate = self.act_fn(x_gate) + x_up, _ = self.up_proj(x) + x_down, _ = self.down_proj(x_gate * x_up) + return x_down + + +def apply_rotary_pos_emb_flashatt( + q: ms.Tensor, k: ms.Tensor, cos: ms.Tensor, + sin: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor]: + q_embed = ops.rotary_position_embedding(q.float(), cos, sin).type_as(q) + k_embed = ops.rotary_position_embedding(k.float(), cos, sin).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VisionAttention(nn.Cell): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + self.num_heads = num_heads + self.head_dim = self.hidden_size_per_attention_head + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + params_dtype=ms.bfloat16) + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + params_dtype=ms.bfloat16) + self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion( + ) + + def split_tensor_along_last_dim( + self, + tensor: ms.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, + ): + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = dist_utils.divide(tensor.shape[last_dim], + num_partitions) + # Split. 
+ tensor_list = mint.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + + return tensor_list + + def split_qkv(self, qkv: ms.Tensor) -> tuple[ms.Tensor, ...]: + # [s, 3 * head * head_dim] + seq_len, _ = qkv.shape + if self.tp_size > 1: + qkv = self.tensor_model_parallel_all_gather(qkv) + + # [s, 3 * head * head_dim] -> 3 * [s, head * head_dim] + q, k, v = mint.chunk(qkv, 3, dim=-1) + + # 3 * [s, head * head_dim] + if self.tp_size > 1: + splitter = partial(self.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, head * head_dim] -> 3 * [s, head, head_dim] + new_shape = (seq_len, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def construct( + self, + x: ms.Tensor, + cu_seqlens: ms.Tensor, + position_embeddings: Tuple[ms.Tensor, ms.Tensor], + ) -> ms.Tensor: + seq_length = x.shape[0] + x, _ = self.qkv(x) + q, k, v = self.split_qkv(x) + + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(mint.unsqueeze(q, 0), + mint.unsqueeze(k, 0), cos, sin) + + q = mint.squeeze(q, 0) + k = mint.squeeze(k, 0) + + context_layer = ops.flash_attention_score( + q, + k, + v, + self.num_heads // self.tp_size, + actual_seq_qlen=cu_seqlens, + actual_seq_kvlen=cu_seqlens, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="TND", + ).reshape(seq_length, -1) + output, _ = self.proj(context_layer) + return output + + +class Qwen2_5_VisionBlock(nn.Cell): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Cell]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, + eps=1e-6, + dtype=ms.bfloat16) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2_5_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def construct( + self, x: ms.Tensor, cu_seqlens: ms.Tensor, + position_embeddings: Tuple[ms.Tensor, ms.Tensor]) -> ms.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2_5_VisionPatchEmbed(nn.Cell): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = mint.nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False, + dtype=ms.bfloat16) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen2_5_VisionPatchMerger(nn.Cell): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: 
Optional[Callable[[int], nn.Cell]] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, + eps=1e-6, + dtype=ms.bfloat16) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.CellList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + params_dtype=ms.bfloat16), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + params_dtype=ms.bfloat16), + ]) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2_5_VisionRotaryEmbedding(nn.Cell): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta**( + mint.arange(0, dim, 2, dtype=ms.float32) / dim)) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**( + mint.arange(0, self.dim, 2, dtype=ms.float32) / self.dim)) + seq = mint.arange(seqlen, dtype=self.inv_freq.dtype) + freqs = mint.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def construct(self, seqlen: int) -> ms.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] # type: ignore[index] + + +class Qwen2_5_VisionTransformer(nn.Cell): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + + # args for get_window_index + self.window_size = vision_config.window_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.fullatt_block_indexes = vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + norm_layer = partial(RMSNorm, eps=norm_eps, params_dtype=ms.bfloat16) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.CellList([ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + 
prefix=f"{prefix}.merger", + ) + from mindspore.communication.management import get_rank + self.rank_id = get_rank() + + def set_model_inputs(self): + dyn_x = ms.Tensor(shape=[None, None], dtype=self.dtype) + dyn_rotary_pos_emb = ms.Tensor(shape=[None, None], + dtype=mstype.float32) + dyn_window_index = ms.Tensor(shape=[None], dtype=mstype.int64) + dyn_cu_window_seqlens = ms.Tensor(shape=[None], dtype=mstype.int64) + dyn_grid_thw = ms.Tensor(shape=[None, None], dtype=mstype.int64) + + self.set_inputs( + dyn_x, + dyn_rotary_pos_emb, + dyn_window_index, + dyn_cu_window_seqlens, + dyn_grid_thw, + ) + + @property + def dtype(self) -> ms.Type: + return self.patch_embed.proj.weight.dtype + + def construct( + self, + x: ms.Tensor, + rotary_pos_emb: ms.Tensor, + window_index: ms.Tensor, + cu_window_seqlens: ms.Tensor, + grid_thw: ms.Tensor, + ) -> ms.Tensor: + hidden_states = x.to(dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + cu_window_seqlens = cu_window_seqlens.astype(ms.int32) + cu_window_seqlens = mint.unique_consecutive(cu_window_seqlens) + seq_len, _ = hidden_states.shape + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index] + rotary_pos_emb = rotary_pos_emb.reshape(1, seq_len, 1, -1) + emb = mint.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (mint.cos(emb), mint.sin(emb)) + + grid_thw_1 = grid_thw.index_select(1, ms.Tensor([1])).reshape(-1) + grid_thw_2 = grid_thw.index_select(1, ms.Tensor([2])).reshape(-1) + grid_thw_0 = grid_thw.index_select(1, ms.Tensor([0])).reshape(-1) + cu_seqlens = mint.cumsum(mint.repeat_interleave( + grid_thw_1 * grid_thw_2, grid_thw_0), + dim=0, + dtype=ms.int32) + + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + # transformers + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings) + + # adapter + hidden_states = self.merger(hidden_states) + reverse_indices = mint.argsort(window_index) + hidden_states = hidden_states[reverse_indices] + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, ms.Tensor]], + params_dict: Dict[str, ms.Parameter]) -> Set[str]: + loaded_params: Set[str] = set() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2_5_VLMultiModalProcessor(_Qwen2VLMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, 
MultiModalFieldConfig]: + return dict( + **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # language model + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", # Same name with vision encoder + # vision tower + "qkv", + "gate_proj", + "up_proj", + "attn.proj", # Distinguish patch_embed.proj + "fc1", + "fc2", + # projector + "mlp.0", + "mlp.2" + ] + + embedding_modules = {} # type: ignore[var-annotated] + embedding_padding_modules = [] # type: ignore[var-annotated] + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + self.visual = ms.jit( + function=self.visual, + jit_level='O0') if self.is_graph_mode else self.visual + + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + params_dtype=ms.bfloat16, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.common_preprocess(vllm_config, prefix) + self.spatial_merge_size = config.vision_config.spatial_merge_size + + self.window_size = config.vision_config.window_size + self.patch_size = config.vision_config.patch_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.hidden_size = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_heads + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + if self.is_graph_mode: + self.visual.set_model_inputs() + + def common_preprocess(self, vllm_config, prefix=""): + self.set_modules({ + "visual": self.visual, + "model": self.model, + "lm_head": self.lm_head + }) + self.casual_mask = MultiModalLowerTriangularMask( + dtype=self.model_config.dtype, + max_model_len=self.model_config.max_model_len) + self.kv_caches = [ + AttentionWrapper() for i in range(self.config.num_hidden_layers) + ] + + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(self.config.num_hidden_layers): + 
compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + # if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + # return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> ms.Tensor: + if not isinstance(mm_input, (ms.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, ms.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mint.concat(list(mm_input)) + else: + return mint.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (ms.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, ms.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + return None + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, ms.Tensor): + raise ValueError("Incorrect type of video embeddings. 
" + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + return None + + def rot_pos_emb(self, grid_thw: ms.Tensor) -> ms.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + t, h, w = t.item(), h.item(), w.item() + hpos_ids = mint.arange(h).unsqueeze(1).expand((-1, w)) + wpos_ids = mint.arange(w).unsqueeze(0).expand((h, -1)) + + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + mint.tile(mint.stack([hpos_ids, wpos_ids], dim=-1), (t, 1))) + pos_ids = mint.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max().item() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index = [] + cu_window_seqlens = [ms.Tensor([0])] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + grid_t, grid_h, grid_w = grid_t.item(), grid_h.item(), grid_w.item( + ) + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = mint.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = mint.cumsum( + seqlens, + 0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1] + cu_window_seqlens.append(cu_seqlens_tmp) + window_index_id += grid_t * llm_grid_h * llm_grid_w + window_index = mint.cat(window_index, dim=0) + cu_window_seqlens = mint.cat(cu_window_seqlens, dim=0) + return window_index, cu_window_seqlens + + def _process_image_input( + self, image_input: Qwen2_5_VLImageInputs) -> tuple[ms.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + os.environ[ + "MS_DISABLE_INTERNAL_KERNELS_LIST"] = "FlashAttentionScore" + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + image_embeds = self.visual(pixel_values, rotary_pos_emb, + window_index, cu_window_seqlens, + grid_thw) + 
os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = "" + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Qwen2_5_VLVideoInputs) -> tuple[ms.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + os.environ[ + "MS_DISABLE_INTERNAL_KERNELS_LIST"] = "FlashAttentionScore" + rotary_pos_emb = self.rot_pos_emb(grid_thw) + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + video_embeds = self.visual(pixel_values_videos, rotary_pos_emb, + window_index, cu_window_seqlens, + grid_thw) + os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = "" + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + return modalities + + def get_multimodal_embeddings(self, + **kwargs) -> Optional[tuple[ms.Tensor, ...]]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[ms.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: ms.Tensor, + multimodal_embeddings: Optional[tuple[ms.Tensor, ...]] = None, + ) -> ms.Tensor: + # input_ids = input_ids.to(mstype.int64) + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = "" + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: ms.Tensor, + image_input: Optional[tuple[ms.Tensor, ...]] = None, + video_input: Optional[tuple[ms.Tensor, ...]] = None, + ) -> ms.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: ms.Tensor, + positions: ms.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[ms.Tensor] = None, + **kwargs: object, + ) -> Union[ms.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.shape[0] == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.shape}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: ms.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[ms.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: ms.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, weights: Iterable[Tuple[str, ms.Tensor]] + ) -> None: # type: ignore[override] + params_dict = self.get_params_dict() + for name, weight in weights: + if "visual." 
in name: + self.visual.load_weights([(name, weight)], params_dict) + else: + self.model.load_weights([(name, weight)], params_dict) + + return None + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.", + tower_model="visual.merger.") diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index 5846f21ae7d817d4366422f16c1dfe9010109868..009d84a06f124d270b88f26e98697ae4285cd7f5 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -28,6 +28,8 @@ from vllm_mindspore.utils import (is_mindformers_model_backend, _NATIVE_MODELS = { "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2_5_VLForConditionalGeneration": + ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), } _MINDFORMERS_MODELS = { diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index bf40c1fdbbdd09f1eda214359ebf7b3b401b2b96..493664cda0f899646e8e2872c4cba3f05bba83a7 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -22,7 +24,7 @@ import mindspore as ms from mindspore import mint, ops from vllm.sequence import IntermediateTensors -from vllm_mindspore.multimodal.inputs import NestedTensors +from vllm_mindspore.multimodal.inputs import NestedTensors # type: ignore[attr-defined] from vllm_mindspore.utils import get_valid_dtype WeightsMapping = Mapping[str, Optional[str]] @@ -247,8 +249,7 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. """ if isinstance(placeholder_token_id, list): - placeholder_token_id = ms.Tensor(placeholder_token_id, - device=input_ids.device) + placeholder_token_id = ms.Tensor(placeholder_token_id) return _merge_multimodal_embeddings( inputs_embeds, ms.numpy.isin(input_ids, placeholder_token_id), diff --git a/vllm_mindspore/multimodal/inputs.py b/vllm_mindspore/multimodal/inputs.py index 2673ce6ea26646f88e7e2da957dc46074160a946..8bc9388545c2c1f36d9346992788a9c749b2573e 100644 --- a/vllm_mindspore/multimodal/inputs.py +++ b/vllm_mindspore/multimodal/inputs.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# encoding: utf-8 +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -15,22 +16,62 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +from collections import defaultdict +from dataclasses import dataclass from typing import Union, cast - import mindspore +from vllm.multimodal.inputs import BaseMultiModalField, BatchedTensorInputs, JSONTree, json_map_leaves,\ + nested_tensors_equal +from vllm.multimodal import MultiModalKwargs + +NestedTensors = Union[list["NestedTensors"], list[mindspore.Tensor], + mindspore.Tensor, tuple[mindspore.Tensor, ...]] + + +@dataclass +class MultiModalFieldElem: + """ + Represents a keyword argument corresponding to a multi-modal item + in :class:`MultiModalKwargs`. + """ + + modality: str + """ + The modality of the corresponding multi-modal item. 
+ Each multi-modal item can consist of multiple keyword arguments. + """ -from vllm.multimodal.inputs import BatchedTensorInputs, JSONTree, json_map_leaves + key: str + """ + The key of this field in :class:`MultiModalKwargs`, + i.e. the name of the keyword argument to be passed to the model. + """ + data: NestedTensors + """ + The tensor data of this field in :class:`MultiModalKwargs`, + i.e. the value of the keyword argument to be passed to the model. + """ -NestedTensors = Union[list["NestedTensors"], list[mindspore.Tensor], mindspore.Tensor, - tuple[mindspore.Tensor, ...]] + field: "BaseMultiModalField" + """ + Defines how to combine the tensor data of this field with others + in order to batch multi-modal items together for model inference. + """ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + + return ((self.modality, self.key) == (other.modality, other.key) + and nested_tensors_equal(self.data, other.data) + and type(self.field) == type(other.field)) # noqa: E721 -@staticmethod def as_kwargs( batched_inputs: BatchedTensorInputs, *, - device = None, + device=None, ) -> BatchedTensorInputs: # replace as_kwargs of vLLM for multi-model json_inputs = cast(JSONTree[mindspore.Tensor], batched_inputs) @@ -40,4 +81,20 @@ def as_kwargs( json_inputs, ) - return cast(BatchedTensorInputs, json_mapped) \ No newline at end of file + return cast(BatchedTensorInputs, json_mapped) + + +def from_items(items): + """Construct a new :class:`MultiModalKwargs` from multiple items.""" + elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) + for item in items: + for key, elem in item.items(): + # transform elem.data to tensor, gpu is tensor. + elem.data = mindspore.Tensor(elem.data) + elems_by_key[key].append(elem) + data = { + key: elems[0].field.reduce_data(elems) + for key, elems in elems_by_key.items() if len(elems) > 0 + } + + return MultiModalKwargs(data, items=items) diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index f53d49d4d9c896e42f3e46155a81b31cdc2ae248..7f4e3fe162150ed54482fd21e24806dd9d81d018 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -1,33 +1,46 @@ +#!/usr/bin/env python3 +# type: ignore +# isort:skip_file +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ from typing import Dict, Tuple, List -import gc import numpy as np import torch from mindspore import mutable -import mindspore as ms -from vllm_mindspore.v1.attention.backends.ms_attn import (MsAttentionMetadata, - MsAttentionBackend, - MLABackend) +from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata from vllm_mindspore.utils import get_valid_dtype +from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding as MRotaryEmbedding # type: ignore[attr-defined] from vllm.v1.outputs import ModelRunnerOutput from vllm.attention import AttentionType from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec, SlidingWindowSpec from vllm.v1.utils import bind_kv_cache -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs -from vllm.distributed.parallel_state import get_pp_group -from vllm.utils import cdiv from vllm.logger import init_logger from vllm.v1.worker.gpu_input_batch import CachedRequestState -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.v1.core.sched.output import SchedulerOutput from vllm.sampling_params import SamplingType - logger = init_logger(__name__) + + def _prepare_inputs( - self, - scheduler_output: "SchedulerOutput", + self, + scheduler_output: "SchedulerOutput", # type: ignore[name-defined] ) -> Tuple[MsAttentionMetadata, torch.Tensor]: total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 @@ -50,13 +63,11 @@ def _prepare_inputs( for i, req_id in enumerate(self.input_batch.req_ids): num_tokens = scheduler_output.num_scheduled_tokens[req_id] num_scheduled_tokens[i] = num_tokens - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - req_indices = np.repeat(self.arange_np[:num_reqs], - num_scheduled_tokens) + req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) # Get batched arange. # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -73,20 +84,20 @@ def _prepare_inputs( # Get positions. positions_np = self.positions_np[:total_num_scheduled_tokens] np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) + arange, + out=positions_np) if self.uses_mrope: self._calc_mrope_positions(scheduler_output) if self.uses_mrope: # Only relevant for models using M-RoPE (e.g, Qwen2-VL) - self.mrope_positions[:, :total_num_scheduled_tokens].copy_( - self.mrope_positions_cpu[:, :total_num_scheduled_tokens], - non_blocking=True) + self.mrope_positions[:, : + total_num_scheduled_tokens] = self.mrope_positions_cpu[:, : + total_num_scheduled_tokens] else: - self.positions[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) - + self.positions[:total_num_scheduled_tokens] = torch.from_numpy( + positions_np) # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -96,10 +107,7 @@ def _prepare_inputs( req_indices * self.input_batch.token_ids_cpu.shape[1]) self.input_ids[:total_num_scheduled_tokens] = torch.from_numpy( - np.take(self.input_batch.token_ids_cpu.ravel(), - token_indices, - 0) - ) + np.take(self.input_batch.token_ids_cpu.ravel(), token_indices, 0)) # Calculate the slot mapping. 
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -110,12 +118,12 @@ def _prepare_inputs( block_table_indices = (req_indices * self.max_num_blocks_per_req + positions_np // self.block_size) - - block_numbers = self.input_batch.block_table.block_table_np.ravel()[block_table_indices] + block_numbers = self.input_batch.block_table.block_table_np.ravel( + )[block_table_indices] block_offsets = positions_np % self.block_size np.add(block_numbers * self.block_size, - block_offsets, - out=self.slot_mapping_np[:total_num_scheduled_tokens]) + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) # # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -136,8 +144,7 @@ def _prepare_inputs( common_prefix_len=common_prefix_len, ) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 + use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: # NOTE(woosuk): Due to chunked prefills, the batch may contain # partial requests. While we should not sample any token @@ -164,7 +171,7 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) - return attn_metadata, logits_indices, spec_decode_metadata + return attn_metadata, logits_indices, spec_decode_metadata # type: ignore[return-value] def create_block(shape, dtype, name=None, device=None): @@ -172,6 +179,7 @@ def create_block(shape, dtype, name=None, device=None): blocks = mint.empty(shape, dtype=dtype, device=device) return blocks + def initialize_kv_cache(self, kv_cache_config) -> None: """ Initialize KV cache based on `kv_cache_config`. @@ -202,28 +210,29 @@ def initialize_kv_cache(self, kv_cache_config) -> None: assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, FullAttentionSpec): kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size) + num_blocks, kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) dtype = kv_cache_spec.dtype dtype = get_valid_dtype(dtype) current_cache = [] device_type = "CPU" if self.device.type == "cpu" else "Ascend" for i in range(kv_cache_shape[0]): - cache_blocks = create_block( - kv_cache_shape[1:], dtype, device=device_type - ) + cache_blocks = create_block(kv_cache_shape[1:], + dtype, + device=device_type) current_cache.append(mutable(cache_blocks)) kv_caches[layer_name] = mutable(tuple(current_cache)) else: raise NotImplementedError - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches) + bind_kv_cache(kv_caches, + self.vllm_config.compilation_config.static_forward_context, + self.kv_caches) -def _update_states(self, scheduler_output: "SchedulerOutput") -> None: +def _update_states( + self, scheduler_output: "SchedulerOutput" +) -> None: # type: ignore[name-defined] """Update the cached states and the persistent batch with the scheduler output. 
@@ -306,14 +315,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: second_per_grid_ts = [] for mm_input in self.requests[req_id].mm_inputs: if mm_input.get("image_grid_thw") is not None: - image_grid_thw.extend( - mm_input["image_grid_thw"].tolist()) - if mm_input.get("video_grid_thw") is not None: - video_grid_thw.extend( - mm_input["video_grid_thw"].tolist()) + image_grid_thw.extend(mm_input["image_grid_thw"].tolist()) + if mm_input.get("video_grid_thw") is not None: + video_grid_thw.extend( + mm_input["video_grid_thw"].tolist()) if mm_input.get("second_per_grid_ts") is not None: - second_per_grid_ts.extend( - mm_input["second_per_grid_ts"]) + second_per_grid_ts.extend(mm_input["second_per_grid_ts"]) hf_config = self.model_config.hf_config @@ -339,9 +346,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_state.num_computed_tokens = num_computed_tokens # Add the sampled token(s) from the previous step (if any). # This doesn't include "unverified" tokens like spec decode tokens. - num_new_tokens = (num_computed_tokens + - len(req_data.new_token_ids) - - req_state.num_tokens) + num_new_tokens = (num_computed_tokens + len(req_data.new_token_ids) - + req_state.num_tokens) if num_new_tokens == 1: # Avoid slicing list in most common case. req_state.output_token_ids.append(req_data.new_token_ids[-1]) @@ -368,8 +374,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - start_index = (len(req_state.block_ids) - - len(req_data.new_block_ids)) self.input_batch.block_table.append_row(req_data.new_block_ids, req_index) # Add new_token_ids to token_ids_cpu. @@ -391,7 +395,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # NOTE(woosuk): `num_tokens` here may include spec decode tokens. self.input_batch.num_tokens[req_index] = end_token_index - # self.input_batch.token_ids_cpu_tensor.copy_(torch.from_numpy(self.input_batch.token_ids_cpu)) # Check if the batch has changed. If not, we can skip copying the # sampling metadata from CPU to GPU. @@ -402,12 +405,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: removed_req_indices = sorted(removed_req_indices, reverse=True) for req_id in req_ids_to_add: req_state = self.requests[req_id] - if removed_req_indices: - # Fill the empty index. - req_index = removed_req_indices.pop() - else: - # Append to the end. - req_index = None + req_index = removed_req_indices.pop() if removed_req_indices else None self.input_batch.add_request(req_state, req_index) # Condense the batched states if there are empty indices. @@ -427,7 +425,7 @@ def wrapper_gpu_model_runner_execute_model(func): return output except Exception as e: logger.warning( - f"Caught exception {str(e)} when processing req_ids {self.input_batch.req_ids}" + f"Caught exception {str(e)} when processing req_ids {self.input_batch.req_ids}" # noqa: G004 ) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -466,7 +464,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, use_mla=use_mla) elif attn_module.attn_type in (AttentionType.ENCODER, - AttentionType.ENCODER_ONLY): + AttentionType.ENCODER_ONLY): # encoder-only attention does not need KV cache. 
continue elif attn_module.attn_type == AttentionType.ENCODER_DECODER: @@ -476,3 +474,58 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec + + +def _calc_mrope_positions( + self, + scheduler_output: "SchedulerOutput"): # type: ignore[name-defined] + mrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, + num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + # gpu is number or tensor, but we are numpy, so we transform to int + dst_start = int(mrope_pos_ptr) + dst_end = int(mrope_pos_ptr + prompt_part_len) + src_start = int(num_computed_tokens) + src_end = int(num_computed_tokens + prompt_part_len) + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + MRotaryEmbedding.get_next_input_positions_tensor( + req.mrope_position_delta, + context_len=num_computed_tokens + + prompt_part_len, + seq_len=num_computed_tokens + + prompt_part_len + + completion_part_len, + ) + + mrope_pos_ptr += completion_part_len diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..0978ed4c58c3777ae859a3103350c98ab4320945 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# encoding: utf-8 +# type: ignore +# isort:skip_file # Copyright 2025 Huawei Technologies Co., Ltd # Copyright 2024 The vLLM team. # @@ -15,23 +16,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ - """Worker functions""" -import gc -import os import math -from typing import Tuple, Optional - import torch -from vllm.config import VllmConfig -from vllm.distributed import ( - ensure_kv_transfer_initialized, - ensure_model_parallel_initialized, - init_distributed_environment, - set_custom_all_reduce, -) - from vllm.logger import init_logger from vllm_mindspore.utils import get_valid_dtype @@ -39,34 +27,41 @@ from vllm.model_executor import set_random_seed from vllm.sequence import SequenceGroupMetadata from vllm.sampling_params import SamplingParams - logger = init_logger(__name__) -def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefill, is_mtp_model=False): +def _prepare_input_for_warmup(model_config, + model_runner, + cache_engine, + is_prefill, + is_mtp_model=False): bs = 1 seq_len = model_runner.scheduler_config.max_num_batched_tokens if is_prefill else 1 - dummy_data = model_runner.input_registry.dummy_data_for_profiling(model_config, seq_len, model_runner.mm_registry) - block_tables = [i for i in range(math.ceil(seq_len / cache_engine.block_size))] + dummy_data = model_runner.input_registry.dummy_data_for_profiling( + model_config, seq_len, model_runner.mm_registry) + block_tables = [ + i for i in range(math.ceil(seq_len / cache_engine.block_size)) + ] + + # adapter multi modal warm up + seq_data = dummy_data.seq_data + if seq_len == 1: + seq_data = dummy_data.seq_data.from_prompt_token_counts((0, seq_len)) + seqs = [ SequenceGroupMetadata( request_id=str(idx), is_prompt=is_prefill, - seq_data={idx: dummy_data.seq_data}, + seq_data={idx: seq_data}, sampling_params=SamplingParams(), block_tables={idx: block_tables}, lora_request=None, multi_modal_data=None, multi_modal_placeholders=None, - ) - for idx in range(bs) + ) for idx in range(bs) ] model_input = model_runner.prepare_model_input(seqs) - block_tables = model_input.attn_metadata.block_tables - if block_tables is not None and block_tables.numel() <= 0: - model_input.attn_metadata.block_tables = torch.zeros((1, 1), dtype=torch.int32) - previous_hidden_states = None if not is_mtp_model else \ torch.ones([bs, seq_len, model_config.get_hidden_size()], dtype=get_valid_dtype(model_config.dtype)) return model_input, previous_hidden_states @@ -78,19 +73,31 @@ def _warm_up_model(self) -> None: is_mtp_model = self.speculative_config is not None and self.model_config.hf_config.model_type == "deepseek_mtp" if is_mtp_model: # prefill mtp model - model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner, - self.cache_engine[0], True, is_mtp_model) - self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states) + model_input, previous_hidden_states = _prepare_input_for_warmup( + self.model_config, self.model_runner, self.cache_engine[0], True, + is_mtp_model) + self.model_runner.execute_model( + model_input, + kv_cache, + None, + previous_hidden_states=previous_hidden_states) # warmup for decode if self.vllm_config.scheduler_config.is_multi_step: - model_input, _ = _prepare_input_for_warmup(self.model_config, self.model_runner._base_model_runner, - self.cache_engine[0], False) - self.model_runner._base_model_runner.execute_model(model_input, kv_cache, None) + model_input, _ = _prepare_input_for_warmup( + self.model_config, self.model_runner._base_model_runner, + self.cache_engine[0], False) + self.model_runner._base_model_runner.execute_model( + 
model_input, kv_cache, None) else: - model_input, previous_hidden_states = _prepare_input_for_warmup(self.model_config, self.model_runner, - self.cache_engine[0], False, is_mtp_model) - self.model_runner.execute_model(model_input, kv_cache, None, previous_hidden_states=previous_hidden_states) + model_input, previous_hidden_states = _prepare_input_for_warmup( + self.model_config, self.model_runner, self.cache_engine[0], False, + is_mtp_model) + self.model_runner.execute_model( + model_input, + kv_cache, + None, + previous_hidden_states=previous_hidden_states) torch.cuda.synchronize()
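
A note on the index arithmetic in the `_prepare_inputs` changes to `gpu_model_runner.py` above: the slot mapping is computed entirely with flattened NumPy indexing (one `np.repeat`, one gather on the raveled block table) instead of a per-request Python loop. The sketch below is a minimal, standalone recomputation of that arithmetic under made-up values (`block_size`, `max_num_blocks_per_req`, the block table, and the token counts are illustrative, not taken from the patch); it mirrors the logic of the diff rather than reproducing the runner's actual implementation.

```python
# Standalone sketch of the slot-mapping arithmetic in _prepare_inputs.
# All concrete numbers below are illustrative assumptions.
import numpy as np

block_size = 4                # tokens per KV-cache block (illustrative)
max_num_blocks_per_req = 3    # block-table columns per request (illustrative)

# Two requests scheduling 2 and 3 tokens, resuming at these computed offsets.
num_scheduled_tokens = np.array([2, 3])
num_computed_tokens = np.array([5, 0])

# Per-token request indices and positions, as in the runner's np.repeat / np.add.
req_indices = np.repeat(np.arange(len(num_scheduled_tokens)),
                        num_scheduled_tokens)            # [0 0 1 1 1]
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])
positions = num_computed_tokens[req_indices] + arange    # [5 6 0 1 2]

# Block table: row r holds the physical block ids owned by request r.
block_table = np.array([[7, 8, 9],
                        [1, 2, 0]])

# Pick each token's logical block, look up its physical block id on the
# raveled table, then add the in-block offset to get the KV-cache slot.
block_table_indices = (req_indices * max_num_blocks_per_req
                       + positions // block_size)
block_numbers = block_table.ravel()[block_table_indices]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)  # [33 34  4  5  6]
```

Keeping this as flat array arithmetic lets a single vectorized pass cover all scheduled tokens of a mixed prefill/decode batch, which is why the diff replaces the earlier multi-line indexing with the raveled lookup.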