diff --git a/tests/st/python/cases_parallel/similarity.py b/tests/st/python/cases_parallel/similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..bfdae0d90d39150efd4e650160921ecf663e7bf4 --- /dev/null +++ b/tests/st/python/cases_parallel/similarity.py @@ -0,0 +1,58 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import math + +import jieba +import numpy as np + + +def _get_all_words(standard_cut_infer_ret_list, test_cut_infer_ret_list): + all_words = [] + for s_cut in standard_cut_infer_ret_list: + if s_cut not in all_words: + all_words.append(s_cut) + for t_cut in test_cut_infer_ret_list: + if t_cut not in all_words: + all_words.append(t_cut) + return all_words + + +def _get_word_vector(standard_cut_infer_ret_list, test_cut_infer_ret_list, + all_words): + la_standard = [] + lb_test = [] + for word in all_words: + la_standard.append(standard_cut_infer_ret_list.count(word)) + lb_test.append(test_cut_infer_ret_list.count(word)) + return la_standard, lb_test + + +def _get_calculate_cos(la_standard, lb_test): + laa = np.array(la_standard) + lbb = np.array(lb_test) + cos = (np.dot(laa, lbb.T)) / ((math.sqrt(np.dot(laa, laa.T))) * + (math.sqrt(np.dot(lbb, lbb.T)))) + return np.round(cos, 2) + + +def compare_distance(x1, x2, bench_sim=0.95): + """compare distance""" + y1 = list(jieba.cut(x1)) + y2 = list(jieba.cut(x2)) + all_words = _get_all_words(y1, y2) + laa, lbb = _get_word_vector(y1, y2, all_words) + sim = _get_calculate_cos(laa, lbb) + print("calculate sim is:{}".format(str(sim))) + assert sim >= bench_sim diff --git a/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py index b514211345fc1790893d9423982864c8ba0d202e..5414578d5f5a2d92913cce22546b51eee067f9fd 100644 --- a/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py +++ b/tests/st/python/cases_parallel/vllm_deepseek_bf16_part_v1.py @@ -63,7 +63,7 @@ def test_deepseek_r1_bf16(): trust_remote_code=True, gpu_memory_utilization=0.9, tensor_parallel_size=2, - max_model_len=4096) + max_model_len=33 * 1024) # Generate texts from the prompts. The output is a list of RequestOutput # objects that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py b/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..d776c8d93dacc2fc6a3fc28453083c2de9ba320c --- /dev/null +++ b/tests/st/python/cases_parallel/vllm_qwen2_5_vl_7b_v1.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
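The `compare_distance` helper in `similarity.py` above scores two generations by the cosine similarity of their jieba bag-of-words vectors and asserts the score against `bench_sim`. A minimal sketch of the same computation on plain whitespace-split tokens (no jieba; the sentences are made up for illustration):

```python
import math

import numpy as np


def cosine_similarity(tokens_a, tokens_b):
    """Bag-of-words cosine similarity, mirroring similarity.py."""
    vocab = list(dict.fromkeys(tokens_a + tokens_b))  # ordered union of words
    va = np.array([tokens_a.count(w) for w in vocab])
    vb = np.array([tokens_b.count(w) for w in vocab])
    return np.round(
        np.dot(va, vb) /
        (math.sqrt(np.dot(va, va)) * math.sqrt(np.dot(vb, vb))), 2)


# Identical token lists score 1.0; the test above asserts similarity >= 0.95.
print(cosine_similarity("a calm lake near mountains".split(),
                        "a calm lake near mountains".split()))
```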
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""test mf qwen2.5 vl 7B.""" +import os + +from PIL import Image + +from tests.st.python import set_env +from tests.st.python.cases_parallel.similarity import compare_distance + +env_manager = set_env.EnvVarManager() +# def env +env_vars = { + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "HCCL_OP_EXPANSION_MODE": "AIV", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", +} +# set env +env_manager.setup_ai_environment(env_vars) +# isort: off +import vllm_mindspore +from vllm import LLM, SamplingParams + +# isort: on + +PROMPT_TEMPLATE = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>" + "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + "What is in the image?<|im_end|>\n" + "<|im_start|>assistant\n") + + +def pil_image() -> Image.Image: + image_path = "images/1080p.jpeg" + return Image.open(image_path) + + +def test_qwen2_5_vl_7b_v1(): + """ + test case qwen2.5 vl 7B + """ + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": pil_image() + }, + }] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=128, top_k=1) + + # Create an LLM. + llm = LLM( + model="/home/workspace/mindspore_dataset/weight/Qwen2.5-VL-7B-Instruct", + gpu_memory_utilization=0.9, + tensor_parallel_size=2, + max_model_len=4096, + max_num_seqs=32, + max_num_batched_tokens=32) + except_list = [ + 'The image depicts a serene and picturesque landscape. It features a lush green meadow with ' + 'wildflowers in the foreground. In the middle ground, there are small wooden huts, possibly used for' + ' storage or as simple shelters. Beyond the meadow, there is a calm body of water, likely a lake,' + ' surrounded by dense forests. In the background, majestic mountains rise, their peaks partially ' + 'covered with snow, suggesting a high-altitude location. The sky is partly cloudy, with soft ' + 'lighting that enhances the tranquil and idyllic atmosphere of the scene. This type of landscape ' + 'is often associated with alpine regions.' + ] + + for i in range(3): + # Generate texts from the prompts. The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(inputs, sampling_params) + # Print the outputs. 
+ for i, output in enumerate(outputs): + generated_text = output.outputs[0].text + print( + f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}" + ) + compare_distance(generated_text, except_list[0], bench_sim=0.95) + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/images/1080p.jpeg b/tests/st/python/images/1080p.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..0d298985cf4468902c27eaca2f23f74dae8c80ab Binary files /dev/null and b/tests/st/python/images/1080p.jpeg differ diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index bd13f26a067e9b0ca6dce58ecaa7497e40ca34d7..258b03d9c0496205070db7f7c8137e0213b66c0a 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -174,7 +174,9 @@ def test_cases_parallel_part5(): cases = [(2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v0", "vllm_mf_qwen3_8b_test_mf_qwen3.log"), (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v1", - "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log")] + "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log"), + (2, "cases_parallel/vllm_qwen2_5_vl_7b_v1.py::test_qwen2_5_vl_7b_v1", + "vllm_qwen2_5_vl_7b_v1.log")] run_tasks(cases) diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index dbcb0ffe95fa224347c1e1fa793ca92bb6c68d81..d82cb9434781df2144caf6544d58a8ec382267e2 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -41,3 +41,26 @@ class ReduceFromModelParallelRegion(nn.Cell): return input_ output = self.all_reduce(input_) return output + + +class AllGatherFromModelParallelRegion(nn.Cell): + """ + Gather the input from world parallel region and concatenate, simultaneously perform + transpose operation on input. + """ + + def __init__(self): + super().__init__() + self.world_size = get_tensor_model_parallel_world_size() + if self.world_size > 1: + self.tp_group = get_tp_group().device_group._name + self.all_gather_into_tensor = ops.AllGather(group=self.tp_group) + + def construct(self, input_): + # Size and dimension. + if self.world_size == 1: + return input_ + input_ = ops.swapaxes(input_, 0, -1) + output = self.all_gather_into_tensor(input_) + output = ops.swapaxes(output, 0, -1) + return output diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index adfbc0861f9a5e0bc2d5b261954734ba387f0336..adc3cb4be559ee48185468947f55caf5a86e6516 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -22,6 +22,7 @@ from abc import abstractmethod from typing import Optional, Union +import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops from mindspore._c_expression.typing import Type as MSDtype from vllm.config import get_current_vllm_config @@ -35,6 +36,8 @@ from vllm_mindspore.distributed.communication_op import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + split_loaded_weight) WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", @@ -180,7 +183,7 @@ class ColumnParallelLinear(LinearBase): output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. 
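The `AllGatherFromModelParallelRegion` cell added above gathers tensor-parallel shards along the last dimension by working around `ops.AllGather`, which only concatenates along dim 0: it swaps the last axis to the front, gathers, and swaps back. A single-process numpy sketch of that layout trick, with two made-up per-rank shards standing in for the real communication:

```python
import numpy as np

# Two tensor-parallel ranks each hold a [num_tokens, hidden/2] shard.
shard0 = np.arange(6).reshape(3, 2)       # rank 0 columns
shard1 = np.arange(6, 12).reshape(3, 2)   # rank 1 columns

# AllGather concatenates along axis 0, so move the hidden axis to the front first.
gathered = np.concatenate([np.swapaxes(shard0, 0, -1),
                           np.swapaxes(shard1, 0, -1)], axis=0)
full = np.swapaxes(gathered, 0, -1)       # back to [num_tokens, hidden]

assert full.shape == (3, 4)               # shards end up side by side on the last dim
print(full)
```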
prefix: The name of the layer in the state dict, including all parents - (e.g. model.layers.0.qkv_proj) + (e.g. model.layers.0.qkv_proj) """ def __init__( @@ -263,21 +266,21 @@ class ColumnParallelLinear(LinearBase): return output return output, output_bias - def weight_loader(self, param: Parameter, loaded_weight: Tensor): + def weight_loader(self, param, loaded_weight): tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) + shard_size = self.output_size_per_partition + start_idx = tp_rank * shard_size + loaded_weight = split_loaded_weight(loaded_weight, output_dim, + start_idx, shard_size) - if output_dim is not None: - shard_size = param.shape[output_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size).contiguous() - + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) assert param.shape == loaded_weight.shape - param.set_data(loaded_weight) + param.set_data(ms.from_numpy(loaded_weight)) class MergedColumnParallelLinear(ColumnParallelLinear): @@ -327,29 +330,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear): prefix=prefix, return_bias=return_bias) - -# type: ignore[override] - def weight_loader(self, - param: Parameter, - loaded_weight: Tensor, + param, + loaded_weight, loaded_shard_id: Optional[int] = None): - param_data = param.data output_dim = getattr(param, "output_dim", None) tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - if output_dim is not None and loaded_shard_id is not None: + shard_size = 0 + shard_offset = 0 + if loaded_shard_id is not None: assert loaded_shard_id < len(self.output_sizes) shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size - param_data = param.data - param_data = param_data.narrow(output_dim, shard_offset, - shard_size) - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size).contiguous() - assert param_data.shape == loaded_weight.shape - param[shard_offset:shard_offset + shard_size, :] = loaded_weight + + start_idx = tp_rank * shard_size + loaded_weight = split_loaded_weight(loaded_weight, output_dim, + start_idx, shard_size) + + assert loaded_weight.shape == (shard_size, param.shape[1]) + param[shard_offset:shard_offset + + shard_size, :] = ms.from_numpy(loaded_weight) class QKVParallelLinear(ColumnParallelLinear): @@ -427,19 +428,13 @@ class QKVParallelLinear(ColumnParallelLinear): prefix=prefix, return_bias=return_bias) - -# type: ignore[override] - def weight_loader(self, - param: Parameter, - loaded_weight: Tensor, + param, + loaded_weight, loaded_shard_id: Optional[str] = None): output_dim = getattr(param, "output_dim", None) tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] - # If output dim is defined, use the default loading process. 
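`ColumnParallelLinear.weight_loader` above now computes the shard itself (`start_idx = tp_rank * output_size_per_partition`) and lets `split_loaded_weight` slice the lazy safetensors handle along `output_dim`. A small numpy sketch of that offset arithmetic, with made-up sizes:

```python
import numpy as np

# Hypothetical column-parallel layer: full output dim 8 split across tp_size = 2.
full_weight = np.arange(8 * 4).reshape(8, 4)   # [output_size, input_size]
tp_rank, output_size_per_partition = 1, 4

start_idx = tp_rank * output_size_per_partition
shard = full_weight[start_idx:start_idx + output_size_per_partition]  # rows 4..7

assert shard.shape == (output_size_per_partition, full_weight.shape[1])
print(shard)
```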
- # if output_dim is not None: - param_data = param.data if loaded_shard_id == "q": shard_offset = 0 shard_size = self.num_heads * self.head_size @@ -451,21 +446,20 @@ class QKVParallelLinear(ColumnParallelLinear): self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size - param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": shard_id = tp_rank else: shard_id = tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size + loaded_weight = split_loaded_weight(loaded_weight, output_dim, + start_idx, shard_size) + loaded_weight = ms.from_numpy(loaded_weight) - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size).contiguous() - assert param_data.shape == loaded_weight.shape if param.name.endswith("weight"): - self.weight[shard_offset:shard_offset + - shard_size, :] = loaded_weight + assert loaded_weight.shape == (shard_size, param.shape[1]) if param.name.endswith("bias"): - self.bias[shard_offset:shard_offset + shard_size] = loaded_weight + assert loaded_weight.shape == (shard_size,) + param[shard_offset:shard_offset + shard_size] = loaded_weight class RowParallelLinear(LinearBase): @@ -587,14 +581,15 @@ class RowParallelLinear(LinearBase): def weight_loader(self, param, loaded_weight): tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) - is_sharded_weight = getattr(param, "is_sharded_weight", False) - is_sharded_weight = is_sharded_weight - if input_dim is not None and not is_sharded_weight: - shard_size = param.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size).contiguous() + shard_size = self.input_size_per_partition + start_idx = tp_rank * shard_size + loaded_weight = split_loaded_weight(loaded_weight, input_dim, + start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) + assert param.shape == loaded_weight.shape - param.set_data(loaded_weight.contiguous()) + param.set_data(ms.from_numpy(loaded_weight)) diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index 81dfecf997c97e46e61c0376346f56fd76ff753a..ab3f95f3b463b38c67353a67868f2c4f6b75f583 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -24,12 +24,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
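`QKVParallelLinear.weight_loader` above packs the q, k and v shards into one parameter: q starts at offset 0, k after the query heads, v after the key heads, and k/v reuse `tp_rank // num_kv_head_replicas` as their shard index when KV heads are replicated. A sketch of that arithmetic with made-up head counts:

```python
# Hypothetical per-rank head counts for a GQA model after tensor-parallel split.
num_heads, num_kv_heads, head_size = 8, 2, 128
tp_rank, num_kv_head_replicas = 3, 1

offsets = {
    "q": (0, num_heads * head_size),
    "k": (num_heads * head_size, num_kv_heads * head_size),
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),
}

for shard_key, (shard_offset, shard_size) in offsets.items():
    # q shards follow tp_rank directly; k/v collapse replicas onto the same shard.
    shard_id = tp_rank if shard_key == "q" else tp_rank // num_kv_head_replicas
    start_idx = shard_id * shard_size
    print(shard_key, shard_offset, shard_size, start_idx)
```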
-from typing import Any, Optional, Union +from typing import Any, Optional, Tuple, Union import mindspore import numpy as np from mindspore import Tensor, mint, nn, ops from mindspore.common import dtype as mstype +from mindspore.ops.auto_generate.gen_ops_prim import SliceExt + from transformers import PretrainedConfig from vllm.config import get_current_vllm_config @@ -435,9 +437,9 @@ class MRotaryEmbedding(RotaryEmbedding): context_len: int, seq_len: int, ) -> mindspore.Tensor: - return ops.arange( - mrope_position_delta + context_len, - mrope_position_delta + seq_len, + return mint.arange( + int(mrope_position_delta + context_len), + int(mrope_position_delta + seq_len), ).broadcast_to((3, -1)) @@ -484,7 +486,7 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): key: mindspore.Tensor, batch_valid_length: Tensor = None, is_prefill: bool = False, - ) -> tuple[mindspore.Tensor, mindspore.Tensor]: + ) -> Tuple[mindspore.Tensor, mindspore.Tensor]: """ Args: positions: @@ -493,52 +495,60 @@ class InferMRotaryEmbedding(InferRotaryEmbedding): query: [num_tokens, num_heads * head_size] key: [num_tokens, num_kv_heads * head_size] """ + half_rotary_dim = self.rotary_dim // 2 # prefill if is_prefill: num_tokens = positions.shape[-1] cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim // - 2], sin[..., :self.rotary_dim // 2] + cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1) + sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1) if positions.ndim == 2: - cos_l = ops.split(cos, self.mrope_section, axis=-1) - sin_l = ops.split(sin, self.mrope_section, axis=-1) + cos_l = mint.split(cos, self.mrope_section, dim=-1) + sin_l = mint.split(sin, self.mrope_section, dim=-1) cos, sin = (), () for i in range(len( self.mrope_section)): # type: ignore[arg-type] - cos += (cos_l[i][i], ) - sin += (sin_l[i][i], ) + cos_l_select = mint.index_select(cos_l[i], 0, + Tensor([i])).squeeze(0) + cos += (cos_l_select, ) + sin_l_select = mint.index_select(sin_l[i], 0, + Tensor([i])).squeeze(0) + sin += (sin_l_select, ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] + query_rot = SliceExt()(query, -1, 0, self.rotary_dim, 1) + query_pass = SliceExt()(query, -1, self.rotary_dim, + query_shape[-1], 1) query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = ops.cat((query_rot, query_pass), axis=-1).view(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] + key_rot = SliceExt()(key, -1, 0, self.rotary_dim, 1) + key_pass = SliceExt()(key, -1, self.rotary_dim, key_shape[-1], 1) key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key = ops.cat((key_rot, key_pass), axis=-1).view(key_shape) return query, key # decode - if positions.ndim == 2 and positions.shape[0] == len( - self.mrope_section): # type: ignore[arg-type] - num_tokens = positions.shape[-1] + if positions.ndim == 2: cos, sin = self.freqs_cos[positions], self.freqs_sin[positions] - cos, sin = cos[..., :self.rotary_dim // - 2], sin[..., :self.rotary_dim // 2] - cos_l = ops.split(cos, self.mrope_section, axis=-1) - sin_l = ops.split(sin, self.mrope_section, axis=-1) + cos = SliceExt()(cos, -1, 0, half_rotary_dim, 1) + sin = SliceExt()(sin, -1, 0, half_rotary_dim, 1) + cos_l = mint.split(cos, self.mrope_section, 
dim=-1) + sin_l = mint.split(sin, self.mrope_section, dim=-1) cos, sin = (), () for i in range(len(self.mrope_section)): # type: ignore[arg-type] - cos += (cos_l[i][i], ) - sin += (sin_l[i][i], ) + cos_l_select = mint.index_select(cos_l[i], 0, + Tensor([i])).squeeze(0) + cos += (cos_l_select, ) + sin_l_select = mint.index_select(sin_l[i], 0, + Tensor([i])).squeeze(0) + sin += (sin_l_select, ) cos = ops.cat(cos, axis=-1) sin = ops.cat(sin, axis=-1) freqs_cos = ops.cat([cos, cos], axis=-1).squeeze(1) diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 18530805915eac370f7040206e99cba2aaf1433e..d4c3a25074be28aa28fdce9e9913aa71aef645e5 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -22,6 +22,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Optional +import mindspore as ms from mindspore import Parameter, Tensor, mint, nn, ops from mindspore.common.dtype import typing from vllm.config import get_current_vllm_config @@ -35,7 +36,8 @@ from vllm_mindspore.distributed.communication_op import ( from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.utils import set_weight_attrs - +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + split_loaded_weight) DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -339,15 +341,14 @@ class VocabParallelEmbedding(nn.Cell): def weight_loader(self, param: Parameter, loaded_weight: Tensor): output_dim = getattr(param, "output_dim", None) - + get_tensor_model_parallel_rank() # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). if output_dim is None: assert param.data.shape == loaded_weight.shape if param.data.shape != loaded_weight.shape: raise ValueError( - f"'param.data.shape' should be equal " - f"to 'loaded_weight.shape'," + f"'param.data.shape' should be equal to 'loaded_weight.shape'," f" but got {param.data.shape} and {loaded_weight.shape}") param.set_data(loaded_weight) return @@ -355,16 +356,16 @@ class VocabParallelEmbedding(nn.Cell): # Shard indexes for loading the weight start_idx = self.shard_indices.org_vocab_start_index shard_size = self.shard_indices.org_vocab_end_index - start_idx - if loaded_weight.shape[output_dim] != self.org_vocab_size: - raise ValueError(f"'loaded_weight.shape[output_dim]' should " - f"be equal to 'org_vocab_size'," - f" but got {loaded_weight.shape[output_dim]} " - f"and {self.org_vocab_size}") - - # Copy the data. 
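The decode branch above rebuilds multimodal RoPE tables by splitting `cos`/`sin` into `mrope_section` frequency bands and taking band *i* from position stream *i* (temporal, height, width) before re-concatenating; that is what the `mint.index_select(cos_l[i], 0, Tensor([i]))` calls do. A numpy sketch of the same selection with made-up shapes:

```python
import numpy as np

# Hypothetical multimodal RoPE: 3 position streams (t, h, w), rotary_dim // 2 == 8,
# partitioned by mrope_section into bands of width [2, 3, 3].
mrope_section = [2, 3, 3]
cos = np.arange(3 * 5 * 8, dtype=np.float32).reshape(3, 5, 8)  # [stream, tokens, dim/2]

chunks = np.split(cos, np.cumsum(mrope_section)[:-1], axis=-1)
# Band i takes its values from position stream i, then the bands are re-joined.
merged = np.concatenate([chunk[i] for i, chunk in enumerate(chunks)], axis=-1)

assert merged.shape == (5, 8)  # [num_tokens, rotary_dim // 2]
```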
- loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size).contiguous() - param[:loaded_weight.shape[0]] = loaded_weight + loaded_weight = split_loaded_weight(loaded_weight, output_dim, + start_idx, shard_size) + org_vocab_size_per_rank = self.org_vocab_size // self.tp_size + if loaded_weight.shape[output_dim] != org_vocab_size_per_rank: + raise ValueError( + f"'loaded_weight.shape[output_dim]' should be equal to 'org_vocab_size'," + f" but got {loaded_weight.shape[output_dim]} and {self.org_vocab_size}" + ) + + param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight) param[loaded_weight.shape[0]:] = 0 diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 4ec3a2ded040013c01f480eab7fae8eb078072de..8c17f0ae1eb894ed75a896bfc14449ab06635066 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -18,23 +18,45 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Generator +from typing import Any, Generator, List, Tuple import mindspore as ms -import torch -from mindspore import Parameter, Tensor +from mindspore import Parameter +from safetensors import safe_open from tqdm.auto import tqdm +from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, + enable_tqdm) + + +def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size): + """ + Read numpy slice data based on axis and slice range. + :loaded_weight: PySafeSlice object + :shard_dim: axis of weight slice + :start_idx: start slice index + :shard_size: end slice index + """ + if shard_dim is None: + loaded_weight = loaded_weight[:] + return loaded_weight + + end_idx = start_idx + shard_size + if shard_dim == 0: + loaded_weight = loaded_weight[start_idx:end_idx] + elif shard_dim == 1: + loaded_weight = loaded_weight[:, start_idx:end_idx] + elif shard_dim == 2: + loaded_weight = loaded_weight[:, :, start_idx:end_idx] + else: + raise ValueError("shard_dim:{} is not supported.".format(shard_dim)) + return loaded_weight def safetensors_weights_iterator( - hf_weights_files: list[str], + hf_weights_files: List[str], use_tqdm_on_load: bool, -) -> Generator[tuple[str, torch.Tensor], None, None]: +) -> Generator[Tuple[str, Any], None, None]: """Iterate over the weights in the model safetensor files.""" - from safetensors import safe_open - from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT, - enable_tqdm) - for st_file in tqdm( hf_weights_files, desc="Loading safetensors checkpoint shards", @@ -43,10 +65,14 @@ def safetensors_weights_iterator( ): with safe_open(st_file, framework="np") as f: for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, ms.tensor(param) + # Return a lightweight PySafeSlice object that uses file pointer offset internally to read Safetensor + # on demand, avoiding memory explosion. 
Actual data can be obtained through slicing operation + # like param[start:end] + param = f.get_slice(name) + yield name, param -def default_weight_loader(param: Parameter, loaded_weight: Tensor) -> None: +def default_weight_loader(param: Parameter, loaded_weight: Any) -> None: """Default weight loader.""" - param.set_data(loaded_weight) + loaded_weight = loaded_weight[:] + param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) diff --git a/vllm_mindspore/model_executor/models/attention_mask.py b/vllm_mindspore/model_executor/models/attention_mask.py index 50dbc8cd38ae2d3e6caac72516c413812b564ff3..859e2506d27021a6994df8d16078ebcb147c0c3a 100644 --- a/vllm_mindspore/model_executor/models/attention_mask.py +++ b/vllm_mindspore/model_executor/models/attention_mask.py @@ -17,6 +17,7 @@ infer attention mask. """ import numpy as np +import mindspore as ms from mindspore import Tensor from mindspore import dtype as mstype from mindspore import mint @@ -39,6 +40,9 @@ FA:ASD-V2.1.5 # yapf: enable +MAX_MODEL_LEN_32K = 32 * 1024 + + class LowerTriangularMask: r""" Provide Infer model attention mask. @@ -50,7 +54,6 @@ class LowerTriangularMask: def __init__(self, dtype, max_model_len): self.dtype = dtype self.max_model_len = max_model_len - prefill_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0 self.prefill_mask = Tensor( @@ -58,20 +61,50 @@ class LowerTriangularMask: prefill_mask_coeff, dtype=self.dtype) - self.decode_mask = Tensor(np.triu(np.ones( - shape=(self.max_model_len, self.max_model_len), dtype=np.int8), - k=1), - dtype=self.dtype) * -10000 - self.hard_mask = mint.zeros((1, 1), dtype=dtype) + decode_mask_coeff = -10000 + self.decode_mask = self.init_decode_mask(decode_mask_coeff) + + def init_decode_mask(self, decode_mask_coeff): + # Our previous test limit was 32K, in order not to affect the original performance. + # We define 32K as the basic mask to distinguish tensor and numpy, numpy mask will cause interruption of stream + # and performance may not be satisfactory. Relying on PagedAttention operators to automatically generate masks + # to solve the problem. 
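With the iterator above yielding lazy `f.get_slice(name)` handles, `split_loaded_weight` reads only the rows or columns a rank actually owns instead of materializing the full tensor. A sketch of that pattern against a hypothetical checkpoint file and tensor name (both placeholders):

```python
import numpy as np
from safetensors import safe_open

# Hypothetical shard geometry: this rank owns 1024 rows along dim 0.
tp_rank, shard_size, shard_dim = 1, 1024, 0

with safe_open("model-00001-of-00009.safetensors", framework="np") as f:
    weight_slice = f.get_slice("model.embed_tokens.weight")  # lazy handle, no full read
    start_idx = tp_rank * shard_size
    if shard_dim == 0:
        shard = weight_slice[start_idx:start_idx + shard_size]
    else:
        shard = weight_slice[:, start_idx:start_idx + shard_size]
    assert isinstance(shard, np.ndarray)  # only the requested slice is loaded
```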
+ if self.max_model_len > MAX_MODEL_LEN_32K: + decode_mask = np.triu(np.ones( + shape=(self.max_model_len, self.max_model_len), + dtype=np.float16), + k=1) * decode_mask_coeff + else: + decode_mask = Tensor(np.triu(np.ones( + shape=(self.max_model_len, self.max_model_len), dtype=np.int8), + k=1), + dtype=self.dtype) * decode_mask_coeff + return decode_mask + + def gen_attention_decode_mask(self, position_ids): + if isinstance(self.decode_mask, ms.Tensor): + attention_mask = mint.index_select(self.decode_mask, 0, + position_ids) + elif isinstance(self.decode_mask, np.ndarray): + attention_mask = self.decode_mask[position_ids.asnumpy()] + attention_mask = ms.Tensor(attention_mask, dtype=self.dtype) + else: + raise ValueError( + f"Decode mask type:{type(self.decode_mask)} is not supported.") + + return attention_mask - def gen_attention_mask(self, is_prefill, position_ids, query_lens): + def gen_attention_mask(self, + is_prefill, + position_ids, + query_lens, + attn_metadata=None): if is_prefill: attention_mask = self.prefill_mask else: if max(query_lens) > 1: - attention_mask = mint.index_select(self.decode_mask, 0, - position_ids) + attention_mask = self.gen_attention_decode_mask(position_ids) else: attention_mask = self.hard_mask return attention_mask @@ -86,10 +119,43 @@ class MLALowerTriangularMask(LowerTriangularMask): """ def __init__(self, dtype, max_model_len): - super().__init__(dtype, max_model_len) decode_mask_coeff = 1.0 if self.dtype == mstype.bfloat16 else -10000.0 - self.decode_mask = Tensor(np.triu(np.ones( - shape=(self.max_model_len, self.max_model_len), dtype=np.int8), - k=1), - dtype=self.dtype) * decode_mask_coeff + self.decode_mask = self.init_decode_mask(decode_mask_coeff) + + +class MultiModalLowerTriangularMask(LowerTriangularMask): + r""" + Provide multi modal Infer model attention mask. + Args: + dtype (ms dtype): The compute type of Infer model. + max_model_len (int): The max model length of Infer model. 
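`init_decode_mask` above keeps the decode mask as a host-side numpy array once `max_model_len` exceeds 32K, and `gen_attention_decode_mask` then gathers only the rows for the current positions before wrapping them in a Tensor. A small numpy sketch of that row lookup with a toy mask size:

```python
import numpy as np

# Toy sizes: a [max_model_len, max_model_len] upper-triangular decode mask kept in
# host memory; -10000 marks positions that must not be attended to.
max_model_len, decode_mask_coeff = 8, -10000.0
decode_mask = np.triu(np.ones((max_model_len, max_model_len), dtype=np.float16),
                      k=1) * decode_mask_coeff

# At decode time only the rows for the current positions are materialized.
position_ids = np.array([3, 6])
attention_mask = decode_mask[position_ids]  # later wrapped into an ms.Tensor

assert attention_mask.shape == (2, max_model_len)
```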
+ """ + + def __init__(self, dtype, max_model_len): + + super().__init__(dtype, max_model_len) + + def gen_attention_mask(self, + is_prefill, + position_ids, + query_lens, + attn_metadata=None): + if is_prefill: + attention_mask = self.prefill_mask + else: + if max(query_lens) > 1: + seq_lens_np = attn_metadata.seq_lens_np + context_lens_np = attn_metadata.context_lens.asnumpy() + mm_position_ids_list = [] + for i in range(len(seq_lens_np)): + mm_position_ids_list.append( + np.arange(context_lens_np[i], seq_lens_np[i])) + mm_position_ids = np.concatenate(mm_position_ids_list) + mm_position_ids = ms.Tensor(mm_position_ids, + dtype=position_ids.dtype) + attention_mask = mint.index_select(self.decode_mask, 0, + mm_position_ids) + else: + attention_mask = self.hard_mask + return attention_mask \ No newline at end of file diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py index aabda0d7fae50308de50ac9b9bbae8b0e407cdc2..e226e9b57fc9be58956c2b96a67536ac578d6b3a 100644 --- a/vllm_mindspore/model_executor/models/llama.py +++ b/vllm_mindspore/model_executor/models/llama.py @@ -51,16 +51,14 @@ from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput, get_sampler) from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + default_weight_loader) from vllm_mindspore.model_executor.models.model_base import MsModelBase from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, extract_layer_index, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -def default_weight_loader(param, loaded_weight) -> None: - param.set_data(loaded_weight) - - class LlamaMLP(nn.Cell): def __init__( diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index b344eefe58aae1c72af65de992fa490e0e447202..f7de55b05f429e9422bab0b9927ff9518abba61a 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -1153,6 +1153,60 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): weight_name = weight_name.replace('shared_head.', '') return weight_name + def infer_process_moe_with_ep_tp(self, src_hf_dir, hf_weight_map, + layer_id): + w1_list = [] + w2_list = [] + w3_list = [] + + for index in range(self.ep_start, self.ep_stop): + w1_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight" + w2_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight" + w3_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight" + + w1_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group( + w1_hf_name, src_hf_dir, hf_weight_map, split_axis=0) + w2_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group( + w2_hf_name, src_hf_dir, hf_weight_map, split_axis=1) + w3_ms_param, _ = self.get_safetensor_from_file_split_moe_tp_group( + w3_hf_name, src_hf_dir, hf_weight_map, split_axis=0) + + w1_list.append(w1_ms_param) + w2_list.append(w2_ms_param) + w3_list.append(w3_ms_param) + + return w1_list, w2_list, w3_list + + def infer_process_moe_with_ep(self, src_hf_dir, hf_weight_map, layer_id): + w1_list = [] + w2_list = [] + w3_list = [] + + for index in range(self.ep_start, self.ep_stop): + w1_hf_name = 
f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight" + w2_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight" + w3_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight" + + w1_ms_param, _ = self.get_safetensor_from_file( + w1_hf_name, src_hf_dir, hf_weight_map) + w2_ms_param, _ = self.get_safetensor_from_file( + w2_hf_name, src_hf_dir, hf_weight_map) + w3_ms_param, _ = self.get_safetensor_from_file( + w3_hf_name, src_hf_dir, hf_weight_map) + + w1_list.append(w1_ms_param) + w2_list.append(w2_ms_param) + w3_list.append(w3_ms_param) + + return w1_list, w2_list, w3_list + + def infer_process_moe(self, src_hf_dir, hf_weight_map, layer_id): + if self.moe_tp_size > 1: + return self.infer_process_moe_with_ep_tp(src_hf_dir, hf_weight_map, + layer_id) + return self.infer_process_moe_with_ep(src_hf_dir, hf_weight_map, + layer_id) + def infer_process_moe_routed_expert_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map): """process moe router expert weight""" @@ -1180,9 +1234,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): name=e_score_correction_bias_ms_name, requires_grad=False) - w1_list = [] - w2_list = [] - w3_list = [] + w1_list, w2_list, w3_list = \ + self.infer_process_moe(src_hf_dir, hf_weight_map, layer_id) base_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts" w1_ms_name = f"{base_ms_name}.ffn.w1.weight" diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 38199ae5535b4c6130c857158fbc596f38b8e623..1cb3bf2bd6a2d61ff5d35bf561029c53dc8c8aa1 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -196,13 +196,15 @@ class MsModelBase: inputs_embeds: Optional[Tensor] = None, previous_hidden_states: Optional[Tensor] = None, spec_step_idx: int = 0, + **kwargs, ) -> Union[Tensor, IntermediateTensors]: return self.forward(input_ids, positions, intermediate_tensors, inputs_embeds, previous_hidden_states=previous_hidden_states, - spec_step_idx=spec_step_idx) + spec_step_idx=spec_step_idx, + **kwargs) def forward(self, input_ids: Tensor, @@ -248,7 +250,12 @@ class MsModelBase: "Function load_weights should be Implemented!") def _dummy_attention_metadata(self, input_ids: Tensor, positions: Tensor): - input_len = input_ids.shape[0] + if input_ids is not None: + input_len = input_ids.shape[0] + elif positions is not None: + # input_ids is None in multi modal model with v1 arch + input_len = positions.shape[-1] + max_seq_len = ms.Tensor(input_len, dtype=ms.int32) seq_lengths = ms.Tensor([input_len], dtype=ms.int32) q_seq_lens_np = np.array([input_len], dtype=np.int32) @@ -307,13 +314,15 @@ class MsModelBase: query_lens_np = attn_metadata.q_seq_lens_np seq_lens_np = attn_metadata.seq_lens_np + if input_ids is not None: + input_ids = input_ids.astype(ms.int32) q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32) position_ids = ms.Tensor(positions, dtype=ms.int32) - attention_mask = self.casual_mask.gen_attention_mask( - is_prefill, positions, query_lens_np) + attention_mask = self.casual_mask.gen_attention_mask( # type: ignore[attr-defined] + is_prefill, positions, query_lens_np, attn_metadata) model_inputs = {} - model_inputs["input_ids"] = input_ids.astype(ms.int32) + model_inputs["input_ids"] = input_ids model_inputs["batch_valid_length"] = ms.from_numpy(seq_lens_np) model_inputs["block_tables"] = attn_metadata.block_tables model_inputs["slot_mapping"] = attn_metadata.slot_mapping @@ -361,9 
+370,32 @@ class NativeModel(MsModelBase): compilation_config.static_forward_context[str( i)] = self.kv_caches[i] - def set_model_inputs(self, is_prefill): - dyn_input_ids = Tensor(shape=[None], dtype=mstype.int32) - dyn_position_ids = Tensor(shape=[None], dtype=mstype.int32) + def set_model_inputs(self, input_ids, position_ids, intermediate_tensors, + inputs_embeds, is_prefill): + if input_ids is None: + dyn_input_ids = None + else: + dyn_input_ids = ms.Tensor(shape=[None] * input_ids.ndim, + dtype=mstype.int32) + + if position_ids is None: + dyn_position_ids = None + else: + dyn_position_ids = ms.Tensor(shape=[None] * position_ids.ndim, + dtype=mstype.int32) + + if inputs_embeds is None: + dyn_inputs_embeds = None + else: + dyn_inputs_embeds = ms.Tensor(shape=[None] * inputs_embeds.ndim, + dtype=inputs_embeds.dtype) + + if intermediate_tensors is None: + dyn_intermediate_tensors = None + else: + dyn_intermediate_tensors = ms.Tensor( + shape=[None] * intermediate_tensors.ndim, + dtype=intermediate_tensors.dtype) block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) @@ -384,25 +416,30 @@ class NativeModel(MsModelBase): dyn_value_caches = mutable( [dyn_value_cache for _ in range(num_layers)]) - dyn_slot_mapping = Tensor(shape=[ - None, - ], dtype=mstype.int32) + dyn_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.model_config.dtype) - dyn_batch_valid_length = Tensor(shape=[ - None, - ], dtype=mstype.int32) - dyn_q_seq_lens = Tensor(shape=[ - None, - ], dtype=mstype.int32) + dyn_batch_valid_length = Tensor(shape=[None], dtype=mstype.int32) + dyn_q_seq_lens = Tensor(shape=[None], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dyn_intermediate_tensors = None - dyn_inputs_embeds = None self.ready_model.set_inputs( - dyn_input_ids, dyn_position_ids, dyn_key_caches, dyn_value_caches, - is_prefill, dyn_slot_mapping, dynamic_attention_mask, - dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables, - dyn_intermediate_tensors, dyn_inputs_embeds) + dyn_input_ids, + dyn_position_ids, + dyn_key_caches, # type: ignore[attr-defined] + dyn_value_caches, + is_prefill, + dyn_slot_mapping, + dynamic_attention_mask, + dyn_batch_valid_length, + dyn_q_seq_lens, + dyn_block_tables, + dyn_intermediate_tensors, + dyn_inputs_embeds) + + dynamic_hidden_states = Tensor(shape=[None, None], + dtype=self.model_config.dtype) + self.lm_head.set_inputs( + dynamic_hidden_states) def prepare_inputs(self, input_ids, positions, intermediate_tensors, inputs_embeds): @@ -426,7 +463,8 @@ class NativeModel(MsModelBase): inputs_embeds) if self.prev_prefill != is_prefill and self.is_graph_mode: - self.set_model_inputs(is_prefill) + self.set_model_inputs(input_ids, positions, intermediate_tensors, + inputs_embeds, is_prefill) self.prev_prefill = is_prefill # for dummy_attention_metadata diff --git a/vllm_mindspore/model_executor/models/qwen2_5_vl.py b/vllm_mindspore/model_executor/models/qwen2_5_vl.py new file mode 100644 index 0000000000000000000000000000000000000000..201a08d7a6e92e42cf31962c79d293c6a11a6529 --- /dev/null +++ b/vllm_mindspore/model_executor/models/qwen2_5_vl.py @@ -0,0 +1,1080 @@ +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +# isort:skip_file +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_5_vl.py +# Copyright 2025 Huawei Technologites Co., Ltd +# Copyright 2025 The vLLM team. 
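`set_model_inputs` above registers placeholder Tensors whose axes are `None` so the compiled graph accepts any token count without retracing. A minimal standalone sketch of MindSpore's dynamic-shape `set_inputs` pattern, assuming a build with dynamic-shape graph support (the `TinyHead` cell is invented for illustration):

```python
import numpy as np

import mindspore as ms
from mindspore import Tensor, nn


class TinyHead(nn.Cell):
    """Stand-in cell used only to illustrate dynamic-shape compilation."""

    def construct(self, hidden_states):
        return hidden_states.sum(axis=-1)


ms.set_context(mode=ms.GRAPH_MODE)
net = TinyHead()
# shape=[None, None] marks both axes as dynamic, so the graph compiled on the first
# call is reused for any [num_tokens, hidden_size] input instead of recompiling.
net.set_inputs(Tensor(shape=[None, None], dtype=ms.float32))
out = net(Tensor(np.ones((4, 8), np.float32)))
print(out.shape)  # (4,)
```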
+# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" +import os +from functools import partial +from typing import Callable, Iterable, Mapping, Optional, Set, Tuple, Union, Dict, Any + +import math +import mindspore as ms +import mindspore.nn as nn +import mindspore.mint as mint +import mindspore.ops as ops +import mindspore.mint.nn.functional as F +from mindspore import dtype as mstype + +from vllm_mindspore.model_executor.layers.layernorm import RMSNorm +from vllm_mindspore.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear +from vllm_mindspore.model_executor.layers.logits_processor import LogitsProcessor +from vllm_mindspore.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm_mindspore.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm_mindspore.model_executor.model_loader.weight_utils import default_weight_loader +from vllm_mindspore.model_executor.models.model_base import NativeModel, AttentionWrapper +from vllm_mindspore.model_executor.models.interfaces import SupportsMultiModal +from vllm_mindspore.model_executor.models.qwen2 import Qwen2Model # type: ignore[attr-defined] +from vllm_mindspore.model_executor.models.utils import PPMissingLayer, WeightsMapper, maybe_prefix, \ + merge_multimodal_embeddings +from vllm_mindspore.model_executor.models.attention_mask import MultiModalLowerTriangularMask +from vllm_mindspore.distributed.communication_op import AllGatherFromModelParallelRegion + +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder +from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor +from vllm.model_executor.models.qwen2_5_vl import Qwen2_5_VLImageInputs, Qwen2_5_VLVideoInputs, \ + Qwen2_5_VLImagePixelInputs, Qwen2_5_VLImageEmbeddingInputs, Qwen2_5_VLVideoPixelInputs, \ + Qwen2_5_VLVideoEmbeddingInputs, Qwen2_5_VLProcessingInfo + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs +from vllm.multimodal.processing import PromptReplacement +from vllm.multimodal.parse import MultiModalDataItems +from vllm.distributed import 
get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from vllm.distributed import utils as dist_utils +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope + +logger = init_logger(__name__) + +_ACTIVATION_REGISTRY = {"silu": F.silu} + +# === Vision Inputs === # + + +class _Qwen2VLMultiModalProcessor(Qwen2VLMultiModalProcessor): + + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + placeholder = { + "image": vocab[hf_processor.image_token], + "video": vocab[hf_processor.video_token], + } + + merge_length = image_processor.merge_size**2 + + def get_replacement_qwen2vl(item_idx: int, modality: str): + grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx] + assert isinstance(grid_thw, ms.Tensor) + + num_tokens = int(grid_thw.prod()) // merge_length + return [placeholder[modality]] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_qwen2vl, + modality=modality), + ) for modality in ("image", "video") + ] + + +# === Vision Encoder === # + + +class Qwen2_5_VisionMLP(nn.Cell): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.gate_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + params_dtype=ms.bfloat16) + self.up_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + params_dtype=ms.bfloat16) + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + params_dtype=ms.bfloat16) + self.act_fn = act_fn + + def construct(self, x: ms.Tensor): + x_gate, _ = self.gate_proj(x) + x_gate = self.act_fn(x_gate) + x_up, _ = self.up_proj(x) + x_down, _ = self.down_proj(x_gate * x_up) + return x_down + + +def apply_rotary_pos_emb_flashatt( + q: ms.Tensor, k: ms.Tensor, cos: ms.Tensor, + sin: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor]: + q_embed = ops.rotary_position_embedding(q.float(), cos, sin).type_as(q) + k_embed = ops.rotary_position_embedding(k.float(), cos, sin).type_as(k) + return q_embed, k_embed + + +class Qwen2_5_VisionAttention(nn.Cell): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. 
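`_get_prompt_replacements` above expands each image/video placeholder to one token per merged vision patch: the patch grid `grid_thw` is multiplied out and divided by `merge_size**2`. A worked example with made-up grid values:

```python
# Hypothetical image: 1 temporal frame, a 36 x 52 patch grid, spatial merge size 2.
grid_thw = (1, 36, 52)
merge_size = 2

num_patches = grid_thw[0] * grid_thw[1] * grid_thw[2]  # 1872 vision patches
num_tokens = num_patches // (merge_size ** 2)          # 468 image tokens in the prompt

print(num_patches, num_tokens)
```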
+ self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + self.num_heads = num_heads + self.head_dim = self.hidden_size_per_attention_head + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv", + params_dtype=ms.bfloat16) + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj", + params_dtype=ms.bfloat16) + self.tensor_model_parallel_all_gather = AllGatherFromModelParallelRegion( + ) + + def split_tensor_along_last_dim( + self, + tensor: ms.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, + ): + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = dist_utils.divide(tensor.shape[last_dim], + num_partitions) + # Split. + tensor_list = mint.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + + return tensor_list + + def split_qkv(self, qkv: ms.Tensor) -> tuple[ms.Tensor, ...]: + # [s, 3 * head * head_dim] + seq_len, _ = qkv.shape + if self.tp_size > 1: + qkv = self.tensor_model_parallel_all_gather(qkv) + + # [s, 3 * head * head_dim] -> 3 * [s, head * head_dim] + q, k, v = mint.chunk(qkv, 3, dim=-1) + + # 3 * [s, head * head_dim] + if self.tp_size > 1: + splitter = partial(self.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, head * head_dim] -> 3 * [s, head, head_dim] + new_shape = (seq_len, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def construct( + self, + x: ms.Tensor, + cu_seqlens: ms.Tensor, + position_embeddings: Tuple[ms.Tensor, ms.Tensor], + ) -> ms.Tensor: + seq_length = x.shape[0] + x, _ = self.qkv(x) + q, k, v = self.split_qkv(x) + + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb_flashatt(mint.unsqueeze(q, 0), + mint.unsqueeze(k, 0), cos, sin) + + q = mint.squeeze(q, 0) + k = mint.squeeze(k, 0) + + context_layer = ops.flash_attention_score( + q, + k, + v, + self.num_heads // self.tp_size, + actual_seq_qlen=cu_seqlens, + actual_seq_kvlen=cu_seqlens, + scalar_value=1 / math.sqrt(q.shape[-1]), + input_layout="TND", + ).reshape(seq_length, -1) + output, _ = self.proj(context_layer) + return output + + +class Qwen2_5_VisionBlock(nn.Cell): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[ms.Tensor], ms.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Cell]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, + eps=1e-6, + dtype=ms.bfloat16) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + 
quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2_5_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def construct( + self, x: ms.Tensor, cu_seqlens: ms.Tensor, + position_embeddings: Tuple[ms.Tensor, ms.Tensor]) -> ms.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + position_embeddings=position_embeddings) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2_5_VisionPatchEmbed(nn.Cell): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + self.dtype = ms.bfloat16 + + self.proj = nn.Dense(temporal_patch_size * patch_size * patch_size * + in_channels, + self.hidden_size, + has_bias=False, + dtype=self.dtype) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + x = self.proj(x) # B Ph*Pw C_out + return x + + +class Qwen2_5_VisionPatchMerger(nn.Cell): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Cell]] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(mint.nn.LayerNorm, + eps=1e-6, + dtype=ms.bfloat16) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.CellList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0", + params_dtype=ms.bfloat16), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2", + params_dtype=ms.bfloat16), + ]) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2_5_VisionRotaryEmbedding(nn.Cell): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + self.inv_freq = 1.0 / (theta**( + mint.arange(0, dim, 2, dtype=ms.float32) / dim)) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**( + mint.arange(0, self.dim, 2, dtype=ms.float32) / self.dim)) + seq = mint.arange(seqlen, dtype=self.inv_freq.dtype) + freqs = mint.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def construct(self, seqlen: int) -> ms.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] # type: ignore[index] + + +class Qwen2_5_VisionTransformer(nn.Cell): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + + # args for get_window_index + self.window_size = 
vision_config.window_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.fullatt_block_indexes = vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + norm_layer = partial(RMSNorm, eps=norm_eps, params_dtype=ms.bfloat16) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.CellList([ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + from mindspore.communication.management import get_rank + self.rank_id = get_rank() + + def set_model_inputs(self): + dyn_x = ms.Tensor(shape=[None, None], dtype=self.dtype) + dyn_rotary_pos_emb = ms.Tensor(shape=[None, None], + dtype=mstype.float32) + dyn_window_index = ms.Tensor(shape=[None], dtype=mstype.int64) + dyn_cu_window_seqlens = ms.Tensor(shape=[None], dtype=mstype.int64) + dyn_grid_thw = ms.Tensor(shape=[None, None], dtype=mstype.int64) + + self.set_inputs( + dyn_x, + dyn_rotary_pos_emb, + dyn_window_index, + dyn_cu_window_seqlens, + dyn_grid_thw, + ) + + @property + def dtype(self) -> ms.Type: + return self.patch_embed.dtype + + def construct( + self, + x: ms.Tensor, + rotary_pos_emb: ms.Tensor, + window_index: ms.Tensor, + cu_window_seqlens: ms.Tensor, + grid_thw: ms.Tensor, + ) -> ms.Tensor: + hidden_states = x.to(dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + cu_window_seqlens = cu_window_seqlens.astype(ms.int32) + cu_window_seqlens = mint.unique_consecutive(cu_window_seqlens) + seq_len, _ = hidden_states.shape + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index] + rotary_pos_emb = rotary_pos_emb.reshape(1, seq_len, 1, -1) + emb = mint.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (mint.cos(emb), mint.sin(emb)) + + grid_thw_1 = grid_thw.index_select(1, ms.Tensor([1])).reshape(-1) + grid_thw_2 = grid_thw.index_select(1, ms.Tensor([2])).reshape(-1) + grid_thw_0 = grid_thw.index_select(1, ms.Tensor([0])).reshape(-1) + cu_seqlens = mint.cumsum(mint.repeat_interleave( + grid_thw_1 * grid_thw_2, grid_thw_0), + dim=0, + dtype=ms.int32) + + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + # transformers + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens_now, + position_embeddings=position_embeddings) + + # adapter + hidden_states = 
self.merger(hidden_states) + reverse_indices = mint.argsort(window_index) + hidden_states = hidden_states[reverse_indices] + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, ms.Tensor]], + params_dict: Dict[str, ms.Parameter]) -> Set[str]: + loaded_params: Set[str] = set() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + if name == "visual.patch_embed.proj.weight": + loaded_weight = loaded_weight[:] + loaded_weight = loaded_weight.reshape( + loaded_weight.shape[0], -1) + param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype)) + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2_5_VLMultiModalProcessor(_Qwen2VLMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class Qwen2_5_VLForConditionalGeneration(NativeModel, SupportsMultiModal): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # language model + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", # Same name with vision encoder + # vision tower + "qkv", + "gate_proj", + "up_proj", + "attn.proj", # Distinguish patch_embed.proj + "fc1", + "fc2", + # projector + "mlp.0", + "mlp.2" + ] + + embedding_modules = {} # type: ignore[var-annotated] + embedding_padding_modules = [] # type: ignore[var-annotated] + + # To ensure correct weight loading and mapping. 
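+    # HF checkpoint prefixes "lm_head." and "model." are remapped onto the wrapped language_model submodule (see hf_to_vllm_mapper below).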
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + if self.is_graph_mode: + self.visual.construct = ms.jit(function=self.visual, + jit_level='O0') + self.visual.set_model_inputs() + + self.model = Qwen2Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + params_dtype=ms.bfloat16, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + else: + self.lm_head = PPMissingLayer() + + self.common_preprocess(vllm_config, prefix) + self.spatial_merge_size = config.vision_config.spatial_merge_size + + self.window_size = config.vision_config.window_size + self.patch_size = config.vision_config.patch_size + self.spatial_merge_unit = self.spatial_merge_size**2 + self.hidden_size = config.vision_config.hidden_size + self.num_heads = config.vision_config.num_heads + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + def common_preprocess(self, vllm_config, prefix=""): + self.set_modules({ + "visual": self.visual, + "model": self.model, + "lm_head": self.lm_head + }) + self.casual_mask = MultiModalLowerTriangularMask( + dtype=self.model_config.dtype, + max_model_len=self.model_config.max_model_len) + self.kv_caches = [ + AttentionWrapper() for i in range(self.config.num_hidden_layers) + ] + + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + for i in range(self.config.num_hidden_layers): + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + # if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + # return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> ms.Tensor: + if not isinstance(mm_input, (ms.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, ms.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return mint.concat(list(mm_input)) + else: + return mint.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (ms.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, ms.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + return None + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, ms.Tensor): + raise ValueError("Incorrect type of video embeddings. 
" + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + return None + + def rot_pos_emb(self, grid_thw: ms.Tensor) -> ms.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + t, h, w = t.item(), h.item(), w.item() + hpos_ids = mint.arange(h).unsqueeze(1).expand((-1, w)) + wpos_ids = mint.arange(w).unsqueeze(0).expand((h, -1)) + + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + mint.tile(mint.stack([hpos_ids, wpos_ids], dim=-1), (t, 1))) + pos_ids = mint.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max().item() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index = [] + cu_window_seqlens = [ms.Tensor([0])] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + grid_t, grid_h, grid_w = grid_t.item(), grid_h.item(), grid_w.item( + ) + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = mint.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = mint.cumsum( + seqlens, + 0) * self.spatial_merge_unit + cu_window_seqlens[-1][-1] + cu_window_seqlens.append(cu_seqlens_tmp) + window_index_id += grid_t * llm_grid_h * llm_grid_w + window_index = mint.cat(window_index, dim=0) + cu_window_seqlens = mint.cat(cu_window_seqlens, dim=0) + return window_index, cu_window_seqlens + + def _process_image_input( + self, image_input: Qwen2_5_VLImageInputs) -> tuple[ms.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + os.environ[ + "MS_DISABLE_INTERNAL_KERNELS_LIST"] = "FlashAttentionScore" + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + image_embeds = self.visual(pixel_values, rotary_pos_emb, + window_index, cu_window_seqlens, + grid_thw) + 
os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = "" + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, video_input: Qwen2_5_VLVideoInputs) -> tuple[ms.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + os.environ[ + "MS_DISABLE_INTERNAL_KERNELS_LIST"] = "FlashAttentionScore" + rotary_pos_emb = self.rot_pos_emb(grid_thw) + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + video_embeds = self.visual(pixel_values_videos, rotary_pos_emb, + window_index, cu_window_seqlens, + grid_thw) + os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = "" + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + return modalities + + def get_multimodal_embeddings(self, + **kwargs) -> Optional[tuple[ms.Tensor, ...]]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[ms.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: ms.Tensor, + multimodal_embeddings: Optional[tuple[ms.Tensor, ...]] = None, + ) -> ms.Tensor: + # input_ids = input_ids.to(mstype.int64) + inputs_embeds = self.model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + os.environ["MS_DISABLE_INTERNAL_KERNELS_LIST"] = "" + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: ms.Tensor, + image_input: Optional[tuple[ms.Tensor, ...]] = None, + video_input: Optional[tuple[ms.Tensor, ...]] = None, + ) -> ms.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: ms.Tensor, + positions: ms.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[ms.Tensor] = None, + **kwargs: object, + ) -> Union[ms.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.shape[0] == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.shape}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + hidden_states = self.exec_model(input_ids, positions, + intermediate_tensors, inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: ms.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[ms.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: ms.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, weights: Iterable[Tuple[str, ms.Tensor]] + ) -> None: # type: ignore[override] + params_dict = self.get_params_dict() + for name, weight in weights: + if "visual." 
in name: + self.visual.load_weights([(name, weight)], params_dict) + else: + self.model.load_weights([(name, weight)], params_dict) + + return None + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.", + tower_model="visual.merger.") diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py index 6df6fc8beafc74fd87608aeba6b62a30b3a6353d..e5165de487bf7a5ceac2f927536863192459f014 100644 --- a/vllm_mindspore/model_executor/models/registry.py +++ b/vllm_mindspore/model_executor/models/registry.py @@ -27,6 +27,8 @@ from vllm_mindspore.utils import (is_mindformers_model_backend, _NATIVE_MODELS = { "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2_5_VLForConditionalGeneration": + ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), } _MINDFORMERS_MODELS = { diff --git a/vllm_mindspore/model_executor/models/utils.py b/vllm_mindspore/model_executor/models/utils.py index 66792cc040397d2310428e73ee08e21a44a41dc4..bcf5f6a4452b66b87db04ac83262845c495424a4 100644 --- a/vllm_mindspore/model_executor/models/utils.py +++ b/vllm_mindspore/model_executor/models/utils.py @@ -252,8 +252,7 @@ def merge_multimodal_embeddings( This updates ``inputs_embeds`` in place. """ if isinstance(placeholder_token_id, list): - placeholder_token_id = ms.Tensor(placeholder_token_id, - device=input_ids.device) + placeholder_token_id = ms.Tensor(placeholder_token_id) return _merge_multimodal_embeddings( inputs_embeds, ms.numpy.isin(input_ids, placeholder_token_id), diff --git a/vllm_mindspore/multimodal/inputs.py b/vllm_mindspore/multimodal/inputs.py index 65b1c9eea1c50664dbb297c0db5afdd042c8dfb6..2d3236d388b639a4725c220e0e5c99e2f36adcd4 100644 --- a/vllm_mindspore/multimodal/inputs.py +++ b/vllm_mindspore/multimodal/inputs.py @@ -18,12 +18,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Adaption for mindspore.""" - +from dataclasses import dataclass +from collections import defaultdict from typing import Union, cast import mindspore -from vllm.multimodal.inputs import (BatchedTensorInputs, JSONTree, - json_map_leaves) +from vllm.multimodal.inputs import (BaseMultiModalField, BatchedTensorInputs, JSONTree, + json_map_leaves, nested_tensors_equal) +from vllm.multimodal import MultiModalKwargs + NestedTensors = Union[ list["NestedTensors"], @@ -33,6 +36,46 @@ NestedTensors = Union[ ] +@dataclass +class MultiModalFieldElem: + """ + Represents a keyword argument corresponding to a multi-modal item + in :class:`MultiModalKwargs`. + """ + + modality: str + """ + The modality of the corresponding multi-modal item. + Each multi-modal item can consist of multiple keyword arguments. + """ + + key: str + """ + The key of this field in :class:`MultiModalKwargs`, + i.e. the name of the keyword argument to be passed to the model. + """ + + data: NestedTensors + """ + The tensor data of this field in :class:`MultiModalKwargs`, + i.e. the value of the keyword argument to be passed to the model. + """ + + field: "BaseMultiModalField" + """ + Defines how to combine the tensor data of this field with others + in order to batch multi-modal items together for model inference. 
+ """ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, self.__class__): + return False + + return ((self.modality, self.key) == (other.modality, other.key) + and nested_tensors_equal(self.data, other.data) + and type(self.field) == type(other.field)) # noqa: E721 + + @staticmethod # type: ignore def as_kwargs( batched_inputs: BatchedTensorInputs, @@ -48,3 +91,18 @@ def as_kwargs( ) return cast(BatchedTensorInputs, json_mapped) + +def from_items(items): + """Construct a new :class:`MultiModalKwargs` from multiple items.""" + elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) + for item in items: + for key, elem in item.items(): + # transform elem.data to tensor, gpu is tensor. + elem.data = mindspore.Tensor(elem.data) + elems_by_key[key].append(elem) + data = { + key: elems[0].field.reduce_data(elems) + for key, elems in elems_by_key.items() if len(elems) > 0 + } + + return MultiModalKwargs(data, items=items) \ No newline at end of file diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 90fc6696dec65453cbf759d605c5b562c2e14700..94b20cc473816ca889a9963f7155e263f9e29f40 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState from vllm_mindspore.utils import get_dtype_size, get_valid_dtype +from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding as MRotaryEmbedding logger = init_logger(__name__) @@ -454,7 +455,6 @@ def _update_states(self, scheduler_output) -> None: # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) - start_index = (len(req_state.block_ids) - len(req_data.new_block_ids)) self.input_batch.block_table.append_row(req_data.new_block_ids, req_index) # Add new_token_ids to token_ids_cpu. @@ -616,3 +616,57 @@ def get_dp_padding(self, num_tokens: int): # padded based on `num_tokens_across_dp`, while the model only accepts # inputs with actual shape. 
return 0, None + +def _calc_mrope_positions( + self, + scheduler_output: "SchedulerOutput"): # type: ignore[name-defined] + mrope_pos_ptr = 0 + for index, req_id in enumerate(self.input_batch.req_ids): + req = self.requests[req_id] + assert req.mrope_positions is not None + + num_computed_tokens = \ + self.input_batch.num_computed_tokens_cpu[index] + num_scheduled_tokens = \ + scheduler_output.num_scheduled_tokens[req_id] + num_prompt_tokens = len(req.prompt_token_ids) + + if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens: + prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens) + completion_part_len = max(0, + num_scheduled_tokens - prompt_part_len) + else: + prompt_part_len = num_scheduled_tokens + completion_part_len = 0 + + assert num_scheduled_tokens == prompt_part_len + completion_part_len + + if prompt_part_len > 0: + # prompt's mrope_positions are pre-computed + # gpu is number or tensor, but we are numpy, so we transform to int + dst_start = int(mrope_pos_ptr) + dst_end = int(mrope_pos_ptr + prompt_part_len) + src_start = int(num_computed_tokens) + src_end = int(num_computed_tokens + prompt_part_len) + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + req.mrope_positions[:,src_start:src_end] + + mrope_pos_ptr += prompt_part_len + + if completion_part_len > 0: + # compute completion's mrope_positions on-the-fly + dst_start = mrope_pos_ptr + dst_end = mrope_pos_ptr + completion_part_len + + self.mrope_positions_cpu[:, dst_start:dst_end] = \ + MRotaryEmbedding.get_next_input_positions_tensor( + req.mrope_position_delta, + context_len=num_computed_tokens + + prompt_part_len, + seq_len=num_computed_tokens + + prompt_part_len + + completion_part_len, + ) + + mrope_pos_ptr += completion_part_len diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index ed2b636af37a2c717cfce219c33956d77659512d..c9c71ea5158d1786e706b47c3b8c839e1decfec2 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -41,11 +41,17 @@ def _prepare_input_for_warmup(model_config, block_tables_num = [ i for i in range(math.ceil(seq_len / cache_engine.block_size)) ] + + # adapter multi modal warm up + seq_data = dummy_data.seq_data + if seq_len == 1: + seq_data = dummy_data.seq_data.from_prompt_token_counts((0, seq_len)) + seqs = [ SequenceGroupMetadata( request_id=str(idx), is_prompt=is_prefill, - seq_data={idx: dummy_data.seq_data}, + seq_data={idx: seq_data}, sampling_params=SamplingParams(), block_tables={idx: block_tables_num}, lora_request=None, @@ -55,11 +61,6 @@ def _prepare_input_for_warmup(model_config, ] model_input = model_runner.prepare_model_input(seqs) - block_tables = model_input.attn_metadata.block_tables - if block_tables is not None and block_tables.numel() <= 0: - model_input.attn_metadata.block_tables = torch.zeros((1, 1), - dtype=torch.int32) - previous_hidden_states = None if not is_mtp_model else torch.ones( [bs, seq_len, model_config.get_hidden_size()], dtype=get_valid_dtype(model_config.dtype))