diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index de11500c1e1d5eb5830b83f89d55240046fc59b0..dda0ff3bde235989d884c0b59198f9b82a5aca8b 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -142,7 +142,6 @@ Worker.init_device = wrapper_worker_init_device(Worker.init_device)
 from vllm_mindspore.worker.model_runner import (
     _get_cuda_graph_pad_size,
     profile_run,
-    _dummy_run,
     _get_supported_attention_backends
 )
 
@@ -150,7 +149,6 @@ vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = (
     _get_cuda_graph_pad_size
 )
 vllm.worker.model_runner.GPUModelRunnerBase.profile_run = profile_run
-vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run
 
 import vllm.worker.multi_step_model_runner
 vllm.worker.multi_step_model_runner._get_supported_attention_backends = _get_supported_attention_backends
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index 49f9fc91c6b5362b804564963f7047bebebd1de6..14b7025b65571aa1e7d93a2f617d1717aa3fb5b6 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -46,7 +46,8 @@ from research.deepseek3.deepseek3 import (
 )
 
 from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_MLA
+from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
+from vllm_mindspore.model_executor.models.model_base import MLAWrapper
 from vllm_mindspore.model_executor.models.mf_models.deepseekv3_weight_processor import DeepseekV3WeightProcessor
 from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask
 
@@ -90,7 +91,7 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
 
-        self.kv_caches = [Fake_MLA() for i in range(self.mf_model_config.num_layers)]
+        self.kv_caches = [MLAWrapper() for i in range(self.mf_model_config.num_layers)]
         compilation_config = get_current_vllm_config().compilation_config
 
         if prefix in compilation_config.static_forward_context:
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 4ba57131f7b3be5396500a944fec65daff895727..931f5e7f7e80eccb1c9c90d51fed85d256345cdc 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -55,35 +55,6 @@ def _batch_seq(input_tokens, prefill):
     return ms.mint.reshape(input_tokens, (-1, 1)).to(ms.int32)
 
 
-class Fake_Attention:
-    def __init__(self):
-        vllm_config = get_current_vllm_config()
-        block_size = vllm_config.cache_config.block_size
-        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
-            vllm_config.parallel_config
-        )
-        head_size = vllm_config.model_config.get_head_size()
-        num_block = 0
-        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
-        self.kv_cache = [
-            (
-                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
-                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
-            )
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
-        self.attn_type = AttentionType.DECODER
-
-
-class Fake_MLA(Fake_Attention):
-    def __init__(self):
-        super().__init__()
-        vllm_config = get_current_vllm_config()
-        self.kv_cache = [
-            (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),)
-            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
-        ]
-
 class MfModelBase(MsModelBase):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(MfModelBase, self).__init__(
@@ -101,19 +72,6 @@ class MfModelBase(MsModelBase):
         )
         self.mf_config.model.model_config.parallel_config.pipeline_stage = 1
 
-
-    def get_kvcache(self):
-        key_cache = []
-        value_cache = []
-        forward_context = get_forward_context()
-        for i in range(self.mf_model_config.num_layers):
-            k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
-            v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
-            key_cache.append(k_cache)
-            value_cache.append(v_cache)
-        return mutable(key_cache), mutable(value_cache)
-
-
     def forward(
         self,
         input_ids: Tensor,
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
index 58df668e0edef52b05332d20c5817160230509d4..20a24ef31136d3e0c9f2512b0515f8fd4a4faf0a 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
@@ -18,23 +18,22 @@
 
 from typing import Iterable, Set, Tuple
 
-from vllm.config import VllmConfig
-from vllm.config import get_current_vllm_config
-from vllm.logger import init_logger
-
+from mindformers.models.llama import LlamaConfig as LlamaConfig_MF
 from mindspore import Tensor, JitConfig
 from mindspore.nn.utils import no_init_parameters
-
-from mindformers.models.llama import LlamaConfig as LlamaConfig_MF
 from research.qwen2_5.infer.qwen2_5 import (
     ParallelQwenForCausalLM as ParallelQwenForCausalLM_MF,
 )
+from vllm.config import VllmConfig
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
+from vllm_mindspore.model_executor.models.mf_models.qwen2_infer_parallelism import Qwen2InferParallelism
 from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_Attention
-from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
 from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask
-
+from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
+from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor
+from vllm_mindspore.model_executor.models.model_base import AttentionWrapper
 
 logger = init_logger(__name__)
 
@@ -66,7 +65,7 @@ class Qwen2ForCausalLM(MfModelBase):
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
 
-        self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+        self.kv_caches = [AttentionWrapper() for i in range(self.mf_model_config.num_layers)]
         compilation_config = get_current_vllm_config().compilation_config
 
         if prefix in compilation_config.static_forward_context:
diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py
index f1bb23615371f91672f403da251fea01515e6aae..da0cf508a30ddd3c56beb26df78aa508294c922a 100644
--- a/vllm_mindspore/model_executor/models/model_base.py
+++ b/vllm_mindspore/model_executor/models/model_base.py
@@ -18,21 +18,62 @@
 
 import os
 from abc import abstractmethod
-from typing import Iterable, List, Optional, Set, Tuple, Union, Dict
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+
+import mindspore as ms
+import numpy as np
+import torch
+from mindspore import Tensor, nn, mutable
+from mindspore import dtype as mstype
+
+from mindformers.core.context import build_context
+from mindformers.core.parallel_config import build_parallel_config
+from mindformers.tools.register.config import MindFormerConfig
 from vllm.attention import AttentionMetadata
-from vllm.config import VllmConfig
+from vllm.attention.backends.abstract import AttentionType
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from mindspore import Tensor, nn, mutable
-from mindspore import dtype as mstype
 
 from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE
 
 
+class AttentionWrapper:
+    def __init__(self):
+        vllm_config = get_current_vllm_config()
+        block_size = vllm_config.cache_config.block_size
+        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+            vllm_config.parallel_config
+        )
+        head_size = vllm_config.model_config.get_head_size()
+        num_block = 0
+        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
+        self.kv_cache = [
+            (
+                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
+                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
+            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+        ]
+        self.attn_type = AttentionType.DECODER
+
-class MsModelBase():
+class MLAWrapper(AttentionWrapper):
+    def __init__(self):
+        super().__init__()
+        vllm_config = get_current_vllm_config()
+        self.kv_cache = [
+            (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),)
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+        ]
+
+
+class MsModelBase:
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(MsModelBase, self).__init__()
         config = vllm_config.model_config.hf_config
@@ -212,3 +253,14 @@ class MsModelBase():
     @abstractmethod
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
         raise NotImplementedError("Function load_weights should be Implemented!")
+
+    def get_kvcache(self):
+        key_cache = []
+        value_cache = []
+        forward_context = get_forward_context()
+        for i in range(self.mf_model_config.num_layers):
+            k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
+            v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
+            key_cache.append(k_cache)
+            value_cache.append(v_cache)
+        return mutable(key_cache), mutable(value_cache)
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 767bde8d2254cefcc2afde5494bf10d1661780f7..e63a28ab72e1b22acf5784ed65d243e8127c6722 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -44,135 +44,6 @@ def profile_run(self) -> None:
     max_num_batched_tokens = \
         self.scheduler_config.max_num_batched_tokens
     max_num_seqs = self.scheduler_config.max_num_seqs
-    self._dummy_run(max_num_batched_tokens, max_num_seqs)
-
-
-def _dummy_run(self,
-               max_num_batched_tokens: int,
-               max_num_seqs: int = 1) -> None:
-    with self.set_in_profile_run():
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = \
-            SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-
-        # This represents the maximum number of different requests
-        # that will have unique loras, an therefore the max amount of memory
-        # consumption create dummy lora request copies from the lora request
-        # passed in, which contains a lora from the lora warmup path.
-        dummy_lora_requests: List[LoRARequest] = []
-        dummy_lora_requests_per_seq: List[LoRARequest] = []
-        if self.lora_config:
-            assert self.lora_manager is not None
-            with self.lora_manager.dummy_lora_cache():
-                for idx in range(self.lora_config.max_loras):
-                    lora_id = idx + 1
-                    dummy_lora_request = LoRARequest(
-                        lora_name=f"warmup_{lora_id}",
-                        lora_int_id=lora_id,
-                        lora_path="/not/a/real/path",
-                    )
-                    self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                     rank=LORA_WARMUP_RANK)
-                    dummy_lora_requests.append(dummy_lora_request)
-                dummy_lora_requests_per_seq = [
-                    dummy_lora_requests[idx % len(dummy_lora_requests)]
-                    for idx in range(max_num_seqs)
-                ]
-
-        # Profile memory usage with max_num_sequences sequences and the
-        # total number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-        # Additional GPU memory may be needed for multi-modal encoding,
-        # which needs to be accounted for when calculating the GPU blocks
-        # for vLLM blocker manager.
-        # To exercise the worst scenario for GPU memory consumption,
-        # the number of seqs (batch_size) is chosen to maximize the number
-        # of images processed.
-
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
-                        f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    "Computed max_num_seqs (%s) to be less than 1. "
-                    "Setting it to the minimum value of 1.", expr)
-                max_num_seqs = 1
-
-        batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
-
-            dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
-
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: dummy_data.seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-                lora_request=dummy_lora_requests_per_seq[group_id]
-                if dummy_lora_requests_per_seq else None,
-                multi_modal_data=dummy_data.multi_modal_data,
-                multi_modal_placeholders=dummy_data.
-                multi_modal_placeholders,
-            )
-            seqs.append(seq)
-
-        # Run the model with the dummy inputs.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        # use an empty tensor instead of `None`` to force Dynamo to pass
-        # it by reference, rather by specializing on the value ``None``.
-        # the `dtype` argument does not matter, and we use `float32` as
-        # a placeholder (it has wide hardware support).
-        # it is important to create tensors inside the loop, rather than
-        # multiplying the list, to avoid Dynamo from treating them as
-        # tensor aliasing.
-        kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \
-            else self.cache_config.cache_dtype
-        kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype]
-        block_size = self.cache_config.block_size
-        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
-        head_size = self.model_config.get_head_size()
-        kv_shape = [0, block_size, num_kv_heads, head_size]
-        kv_caches = mutable([
-            mutable((
-                mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-                mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
-            ))
-            for _ in range(num_layers)
-        ])
-        finished_requests_ids = [seq.request_id for seq in seqs]
-        model_input = self.prepare_model_input(
-            seqs, finished_requests_ids=finished_requests_ids)
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = \
-                self.model.make_empty_intermediate_tensors(
-                    batch_size=batch_size,
-                    dtype=self.model_config.dtype,
-                    device=self.device)
-
-        # Disable KV Scale Calculation for dummy data during profile run
-        if model_input.attn_metadata is not None:
-            model_input.attn_metadata.enable_kv_scales_calculation = False
-
-        self.execute_model(model_input, kv_caches, intermediate_tensors)
-        torch.cuda.synchronize()
-        if self.lora_config:
-            # Remove dummy loras.
-            assert self.lora_manager is not None
-            self.remove_all_loras()
-        return
 
 
 MULTI_STEP_ATTENTION_BACKENDS = [