diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index 86f0a8f6c0e23560fcf0ae297ab14751cc378d83..de11500c1e1d5eb5830b83f89d55240046fc59b0 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -47,7 +47,6 @@ vllm.utils.current_platform = ascend_platform
 
 from vllm_mindspore.utils import (
     direct_register_custom_op,
-    memory_profiling,
     make_tensor_with_pad,
     async_tensor_h2d,
     get_dtype_size,
@@ -56,7 +55,6 @@ from vllm_mindspore.utils import (
 )
 
 vllm.utils.direct_register_custom_op = direct_register_custom_op
-vllm.utils.memory_profiling = memory_profiling
 vllm.utils.make_tensor_with_pad = make_tensor_with_pad
 vllm.utils.async_tensor_h2d = async_tensor_h2d
 vllm.utils.get_dtype_size = get_dtype_size
@@ -129,8 +127,7 @@ vllm.model_executor.model_loader.loader.safetensors_weights_iterator = (
 )
 
 from vllm_mindspore.worker.worker import (
-    _warm_up_model,
-    determine_num_available_blocks,
+    _warm_up_model
 )
 from vllm_mindspore.worker.profile import (
     wrapper_worker_init,
@@ -139,13 +136,13 @@ from vllm_mindspore.worker.profile import (
 from vllm.worker.worker import Worker
 
 Worker._warm_up_model = _warm_up_model
-Worker.determine_num_available_blocks = determine_num_available_blocks
 Worker.__init__ = wrapper_worker_init(Worker.__init__)
 Worker.init_device = wrapper_worker_init_device(Worker.init_device)
 
 from vllm_mindspore.worker.model_runner import (
     _get_cuda_graph_pad_size,
     profile_run,
+    _dummy_run,
     _get_supported_attention_backends
 )
 
@@ -153,6 +150,7 @@ vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = (
     _get_cuda_graph_pad_size
 )
 vllm.worker.model_runner.GPUModelRunnerBase.profile_run = profile_run
+vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run
 
 import vllm.worker.multi_step_model_runner
 vllm.worker.multi_step_model_runner._get_supported_attention_backends = _get_supported_attention_backends
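
Note on the `__init__.py` changes above: the adaptor integrates with vLLM purely by attribute assignment (monkey patching) on already-imported modules and classes, and this patch simply drops the `memory_profiling` / `determine_num_available_blocks` overrides while registering the new `_dummy_run`. A minimal, self-contained sketch of that patching mechanism follows; the module and names are illustrative stand-ins, not vLLM's real API. Because `from x import y` copies a reference at import time, such assignments must run before other code imports the originals, which is why they live in the package `__init__`.

```python
# Minimal sketch of attribute-level monkey patching, as used above.
# `fake_vllm_utils` and its members are illustrative stand-ins, not vLLM APIs.
import types

fake_vllm_utils = types.ModuleType("fake_vllm_utils")


class Worker:
    def _warm_up_model(self):
        return "original warm-up"


def _patched_warm_up_model(self):
    # Assigned onto the class below, so it becomes a bound method and
    # `self` is the Worker instance at call time.
    return "patched warm-up"


# Patch a class attribute and a module attribute before anything that
# wants the originals imports them with `from ... import ...`.
Worker._warm_up_model = _patched_warm_up_model
fake_vllm_utils.make_tensor_with_pad = lambda *a, **kw: ("patched", a, kw)

assert Worker()._warm_up_model() == "patched warm-up"
```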
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index 9b2a5a3846406ad1c0eb488e89b2e71c835a8b56..47dcce929a87abb4405d2d22ca8106d5507a41f2 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -44,8 +44,7 @@ from research.deepseek3.deepseek3 import (
 )
 
 from vllm_mindspore.model_executor.layers.sampler import get_sampler
-from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_Attention
-from vllm_mindspore.utils import calc_block_num
+from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_MLA
 from vllm_mindspore.model_executor.models.mf_models.deepseekv3_infer_parallelism import DeepseekInferParallelism
 from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask
 
@@ -63,11 +62,10 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.mf_config.load_checkpoint = self.get_model_path()
 
         self.mf_model_config = DeepseekV3Config_MF(**self.mf_config.model.model_config)
-        self.mf_model_config.num_blocks = calc_block_num(self.cache_config, self.model_config, self.parallel_config)
-        self.mf_model_config.block_size = self.cache_config.block_size
         if self.mf_config.moe_config:
             self.mf_model_config.moe_config = self.mf_config.moe_config
         self.mf_model_config.return_hidden_states = True
+        setattr(self.mf_model_config, 'npu_mem_size', -1)
 
         self.is_quant = bool(hasattr(self.mf_model_config, "quantization_config") and
                              self.mf_model_config.quantization_config)
@@ -89,7 +87,7 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
 
-        self.kv_caches = [Fake_Attention() for i in range(self.mf_model_config.num_layers)]
+        self.kv_caches = [Fake_MLA() for i in range(self.mf_model_config.num_layers)]
         compilation_config = get_current_vllm_config().compilation_config
 
         if prefix in compilation_config.static_forward_context:
@@ -100,33 +98,28 @@ class DeepseekV3ForCausalLM(MfModelBase):
         self.casual_mask = LowerTriangularMask(mf_model_config=self.mf_model_config)
         self.set_flags = False
 
-    def update_mf_kvcaches(self):
-        if self.mf_kvcaches_init:
-            return
-
+    def get_kvcache(self):
+        from mindspore import mutable
+        key_cache = []
         forward_context = get_forward_context()
         for i in range(self.mf_model_config.num_layers):
             k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
-            mf_k_cache, _ = self.network.kvcache(i)
-
-            mf_k_cache.set_device_address(
-                k_cache._data_ptr(), k_cache.shape, k_cache.dtype
-            )
-        self.mf_kvcaches_init = True
+            key_cache.append(k_cache)
+        return mutable(key_cache), None
 
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
-        if self.mf_config.load_ckpt_format == "ckpt":
-            model = Model(self.network)
-            batch_size = self.mf_config.model.model_config.batch_size
-            seq_length = self.mf_config.model.model_config.seq_length
-            input_ids = np.ones(shape=tuple([batch_size, seq_length]))
-            infer_data = self.network.prepare_inputs_for_predict_layout(input_ids)
-            transform_and_load_checkpoint(
-                self.mf_config, model, self.network, infer_data, do_predict=True
-            )
-        else:
-            model_parallelism = DeepseekInferParallelism(self.mf_config, self.network, self.is_quant)
-            model_parallelism.infer_convert_and_parallelism(self.mf_config.load_checkpoint)
+        # if self.mf_config.load_ckpt_format == "ckpt":
+        #     model = Model(self.network)
+        #     batch_size = self.mf_config.model.model_config.batch_size
+        #     seq_length = self.mf_config.model.model_config.seq_length
+        #     input_ids = np.ones(shape=tuple([batch_size, seq_length]))
+        #     infer_data = self.network.prepare_inputs_for_predict_layout(input_ids)
+        #     transform_and_load_checkpoint(
+        #         self.mf_config, model, self.network, infer_data, do_predict=True
+        #     )
+        # else:
+        #     model_parallelism = DeepseekInferParallelism(self.mf_config, self.network, self.is_quant)
+        #     model_parallelism.infer_convert_and_parallelism(self.mf_config.load_checkpoint)
         self.network.set_dynamic_inputs()
         return None
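
Note on the `get_kvcache` change above: instead of rebinding MindFormers' internal cache tensors onto vLLM's buffers via `set_device_address`, the per-layer cache tensors held by `self.kv_caches` are gathered into a `mindspore.mutable` list and fed to the network as ordinary inputs; for this MLA-style model only the key cache is collected, so the value side is `None`. A rough standalone sketch of that gathering step, with a hypothetical `FakeLayer` standing in for the real cache holders (shapes are illustrative):

```python
# Sketch of the cache-gathering pattern introduced above; FakeLayer stands
# in for the per-layer kv_cache holders created in __init__ (illustrative).
import mindspore as ms
from mindspore import mutable, ops


class FakeLayer:
    def __init__(self, num_virtual_engines=1):
        # kv_cache[virtual_engine] -> (key_cache,) for the MLA-style layer
        self.kv_cache = [(ops.zeros((1, 16, 1, 64), ms.float16),)
                         for _ in range(num_virtual_engines)]


def get_kvcache(layers, virtual_engine=0):
    key_cache = [layer.kv_cache[virtual_engine][0] for layer in layers]
    # mutable() tells graph compilation these are variable inputs,
    # not constants to be folded into the graph.
    return mutable(key_cache), None


keys, values = get_kvcache([FakeLayer() for _ in range(2)])
```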
diff --git a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
index 56f394c48b59358c95e2be6da0f813e442313045..979b2e8f6754d427bf03c5dbd24adc454ed88a1a 100644
--- a/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
+++ b/vllm_mindspore/model_executor/models/mf_models/mf_model_base.py
@@ -33,7 +33,7 @@ from vllm.logger import init_logger
 
 import torch
 import mindspore as ms
-from mindspore import Tensor
+from mindspore import Tensor, mutable
 
 from mindformers.tools.register.config import MindFormerConfig
 from mindformers.core.context import build_context
@@ -57,13 +57,33 @@ def _batch_seq(input_tokens, prefill):
 
 class Fake_Attention:
     def __init__(self):
+        vllm_config = get_current_vllm_config()
+        block_size = vllm_config.cache_config.block_size
+        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
+            vllm_config.parallel_config
+        )
+        head_size = vllm_config.model_config.get_head_size()
+        num_block = 0
+        self.kv_shape = [num_block, block_size, num_kv_heads, head_size]
         self.kv_cache = [
-            torch.tensor([]) for _ in range(get_current_vllm_config(
-            ).parallel_config.pipeline_parallel_size)
+            (
+                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
+                torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),
+            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]
         self.attn_type = AttentionType.DECODER
 
 
+class Fake_MLA(Fake_Attention):
+    def __init__(self):
+        super().__init__()
+        vllm_config = get_current_vllm_config()
+        self.kv_cache = [
+            (torch.zeros(self.kv_shape, dtype=ms.bfloat16, device="Ascend"),)
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+        ]
+
 
 class MfModelBase(MsModelBase):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         super(MfModelBase, self).__init__(
@@ -82,22 +102,17 @@ class MfModelBase(MsModelBase):
 
         self.mf_config.model.model_config.parallel_config.pipeline_stage = 1
 
-    def update_mf_kvcaches(self):
-        if self.mf_kvcaches_init:
-            return
-
+    def get_kvcache(self):
+        from mindspore import mutable
+        key_cache = []
+        value_cache = []
         forward_context = get_forward_context()
         for i in range(self.mf_model_config.num_layers):
             k_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][0]
             v_cache = self.kv_caches[i].kv_cache[forward_context.virtual_engine][1]
-            mf_k_cache, mf_v_cache = self.network.kvcache(i)
-            mf_k_cache.set_device_address(
-                k_cache._data_ptr(), k_cache.shape, k_cache.dtype
-            )
-            mf_v_cache.set_device_address(
-                v_cache._data_ptr(), v_cache.shape, v_cache.dtype
-            )
-        self.mf_kvcaches_init = True
+            key_cache.append(k_cache)
+            value_cache.append(v_cache)
+        return mutable(key_cache), mutable(value_cache)
 
     def forward(
@@ -109,7 +124,7 @@ class MfModelBase(MsModelBase):
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[Tensor] = None,
     ) -> Union[Tensor, IntermediateTensors]:
-        self.update_mf_kvcaches()
+        key_cache, value_cache = self.get_kvcache()
 
         seq_lens = attn_metadata.seq_lens
         max_query_len = attn_metadata.max_query_len
@@ -141,6 +156,8 @@ class MfModelBase(MsModelBase):
         model_inputs["position_ids"] = position_ids
         model_inputs["q_seq_lens"] = q_seq_lens
         model_inputs["attention_mask"] = attention_mask
+        model_inputs["key_cache"] = key_cache
+        model_inputs["value_cache"] = value_cache
 
         if is_prefill:
             self.network.phase = "prefill"
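
Note on `Fake_Attention`/`Fake_MLA` above: the placeholders now advertise the real cache layout `[num_block, block_size, num_kv_heads, head_size]`, but with `num_block = 0`, so vLLM's static forward context sees correct shape and dtype metadata while no NPU memory is reserved up front. The per-block memory that block sizing later works with follows directly from that layout; a small worked example with illustrative numbers:

```python
# Bytes consumed by one KV block under the layout advertised above
# (illustrative numbers; bfloat16 is 2 bytes per element).
block_size = 16           # tokens per block
num_kv_heads = 8
head_size = 128
dtype_bytes = 2
caches_per_layer = 2      # key + value (1 for the Fake_MLA variant)
num_layers = 32

per_layer = block_size * num_kv_heads * head_size * dtype_bytes
per_block = per_layer * caches_per_layer * num_layers
print(f"{per_block / 2**20:.2f} MiB per KV block")  # 2.00 MiB here
```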
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
index 78247b8e1a59f4941d818d062a7e51c1c48523c3..5bc536d418a34e078d28681af7e97e7aee5d91f1 100644
--- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py
@@ -31,7 +31,6 @@ from research.qwen2_5.infer.qwen2_5 import (
 
 from vllm_mindspore.model_executor.layers.sampler import get_sampler
 from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase, Fake_Attention
-from vllm_mindspore.utils import calc_block_num
 from vllm_mindspore.model_executor.models.mf_models.qwen2_infer_parallelism import Qwen2InferParallelism
 from vllm_mindspore.model_executor.models.mf_models.attention_mask import LowerTriangularMask
 
@@ -44,17 +43,13 @@ class Qwen2ForCausalLM(MfModelBase):
         super(Qwen2ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
 
         self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config)
-        # Cannot get num_gpu_blocks from cache config now, calculate one first.
-        self.mf_model_config.num_blocks = calc_block_num(
-            self.cache_config, self.model_config, self.parallel_config
-        )
-        self.mf_model_config.block_size = self.cache_config.block_size
         if self.mf_config.moe_config:
             self.mf_model_config.moe_config = self.mf_config.moe_config
         self.mf_model_config.return_hidden_states = True
 
         # qwen qkv concat will support in next version
         self.mf_model_config.qkv_concat = False
+        setattr(self.mf_model_config, 'npu_mem_size', -1)
         self.mf_config.model.model_config.qkv_concat = False
         # Initial network
         self.network = ParallelQwenForCausalLM_MF(self.mf_model_config)
@@ -81,8 +76,8 @@ class Qwen2ForCausalLM(MfModelBase):
         self.set_flags = False
 
     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
-        model_parallelism = Qwen2InferParallelism(self.mf_config, self.network, False)
-        model_parallelism.infer_convert_and_parallelism(self.mf_config.load_checkpoint)
+        # model_parallelism = Qwen2InferParallelism(self.mf_config, self.network, False)
+        # model_parallelism.infer_convert_and_parallelism(self.mf_config.load_checkpoint)
 
         self.network.set_dynamic_inputs()
diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py
index 31fe2c2e07cf1c8325edc3e875e1b077d7aedb41..b96403d4959fe14be3fcf3dc2f43d2b9bcc564e1 100644
--- a/vllm_mindspore/platforms/ascend.py
+++ b/vllm_mindspore/platforms/ascend.py
@@ -92,18 +92,6 @@ class AscendPlatform(Platform):
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
-        if os.getenv("ASCEND_TOTAL_MEMORY_GB"):
-            total_device_memory = int(os.environ["ASCEND_TOTAL_MEMORY_GB"])
-        else:
-            total_device_memory = 64
-            logger.warning(
-                "Total device memory should be set by environ 'ASCEND_TOTAL_MEMORY_GB', "
-                "please check size by cmd(npu-smi info). "
-                "For now, we will try default size(64GB) which might not be correct exactly."
-            )
-        max_device_memory_for_ms = str(total_device_memory * cache_config.gpu_memory_utilization) + "GB"
-        ms.set_context(max_device_memory=max_device_memory_for_ms)
-        logger.info("max_device_memory for mindspore is: ", max_device_memory_for_ms)
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla):
@@ -126,6 +114,7 @@ class AscendPlatform(Platform):
     @classmethod
     def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None) -> float:
         """Return the memory usage in bytes."""
+        torch.cuda.reset_peak_memory_stats()
        return torch.cuda.max_memory_allocated(device)
 
     @classmethod
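
Note on the `ascend.py` changes above: the static `ASCEND_TOTAL_MEMORY_GB`-based sizing is gone, and `get_current_memory_usage` now resets the allocator's peak counter immediately before reading it, so the value it returns reflects allocations from that moment rather than the historical process-wide peak. A short sketch of the same pattern against the torch-style allocator API (assuming, as in vllm-mindspore, that `torch.cuda` is routed to the Ascend/MindSpore backend):

```python
import torch


def get_current_memory_usage(device=None) -> float:
    # Resetting the peak counter first makes max_memory_allocated()
    # report the high-water mark from this point on, which at call time
    # is effectively the currently allocated memory.
    torch.cuda.reset_peak_memory_stats(device)
    return torch.cuda.max_memory_allocated(device)


# Typical use: sample before and after a phase to see net growth.
# used_before = get_current_memory_usage()
# ... run a forward pass ...
# growth = get_current_memory_usage() - used_before
```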
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index b501fd86fea495f3a2d3d733c22cde74fc46b229..ac44272f1f666383abe1cfd17c72df5693fbca92 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -79,89 +79,6 @@ def direct_register_custom_op(
 ):
     ...
 
-@contextlib.contextmanager
-def memory_profiling(
-        baseline_snapshot: "MemorySnapshot",
-        weights_memory: int) -> "Generator[MemoryProfilingResult, None, None]":
-    """Memory profiling context manager.
-    baseline_snapshot: the memory snapshot before the current vLLM instance.
-    weights_memory: memory used by PyTorch when loading the model weights.
-    Note that, before loading the model weights, we also initialize the device
-    and distributed environment, which may consume some memory. This part is not
-    included in the weights_memory because PyTorch does not control it.
-
-    The memory in one GPU can be classified into 3 categories:
-    1. memory used by anything other than the current vLLM instance.
-    2. memory used by torch in the current vLLM instance.
-    3. memory used in the current vLLM instance, but not by torch.
-
-    A quantitive example:
-
-    Before creating the current vLLM instance:
-        category 1: 1 GiB
-        category 2: 0 GiB
-        category 3: 0 GiB
-
-    After creating the current vLLM instance and loading the model,
-    (i.e. before profiling):
-        category 1: 1 GiB
-        category 2: 2 GiB (model weights take 2 GiB)
-        category 3: 0.5 GiB (memory used by NCCL)
-
-    During profiling (peak):
-        category 1: 1 GiB
-        category 2: 4 GiB (peak activation tensors take 2 GiB)
-        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
-
-    After profiling:
-        category 1: 1 GiB
-        category 2: 3 GiB (after garbage-collecting activation tensors)
-        category 3: 1 GiB (memory used by NCCL + buffers for some attention backends)
-
-    In this case, non-kv cache takes 5 GiB in total, including:
-    a. 2 GiB used by the model weights (category 2)
-    b. 2 GiB reserved for the peak activation tensors (category 2)
-    c. 1 GiB used by non-torch components (category 3)
-
-    The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
-
-    The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
-
-    The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
-    """ # noqa
-    from vllm.utils import MemoryProfilingResult
-
-    gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
-
-    result = MemoryProfilingResult()
-
-    result.before_create = baseline_snapshot
-    # the part of memory used for holding the model weights
-    result.weights_memory = weights_memory
-
-    result.before_profile.measure()
-
-    before_torch_memory_in_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"]
-
-    yield result
-
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    result.after_profile.measure()
-
-    after_torch_memory_in_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"]
-
-    diff_profile = result.after_profile - result.before_profile
-    diff_from_create = result.after_profile - result.before_create
-    result.torch_peak_increase = diff_profile.torch_peak
-    result.non_torch_increase = after_torch_memory_in_bytes - before_torch_memory_in_bytes
-    result.profile_time = diff_profile.timestamp
-    result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa
-
-
 def _create_empty_tensor(ms_type):
     init_func = Zero()
     init_func.__enable_zero_dim__ = True
@@ -307,7 +224,7 @@ def check_ready():
 
     if is_mindformers_model_backend():
         logger.info("Run with Mindformers backend!")
-        necessary_envs = ("vLLM_MODEL_MEMORY_USE_GB", "MINDFORMERS_MODEL_CONFIG")
+        necessary_envs = ("MINDFORMERS_MODEL_CONFIG", )
         lost_envs = [env_item for env_item in necessary_envs if not os.getenv(env_item)]
 
         if lost_envs:
@@ -325,26 +242,6 @@ def check_ready():
         env_setup({"MS_ALLOC_CONF": "enable_vmm:True", })
         logger.info("Run with native model backend!")
 
-
-def calc_block_num(cache_config, model_config, parallel_config):
-    from vllm.worker.cache_engine import CacheEngine
-
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
-
-    total_gpu_memory = int(os.environ["ASCEND_TOTAL_MEMORY_GB"]) if os.getenv("ASCEND_TOTAL_MEMORY_GB") else 64
-    total_gpu_memory = total_gpu_memory * 1024 * 1024 * 1024
-    memory_can_use = total_gpu_memory * cache_config.gpu_memory_utilization
-
-    model_use_memory_b = int(os.getenv("vLLM_MODEL_MEMORY_USE_GB")) * 1024 * 1024 * 1024
-    available_cache_memory = memory_can_use - model_use_memory_b
-    cache_block_size = CacheEngine.get_cache_block_size(
-        cache_config, model_config, parallel_config
-    )
-    num_gpu_blocks = int(available_cache_memory // cache_block_size)
-    return num_gpu_blocks
-
-
 def is_use_mla(model_config):
     if not is_mindformers_model_backend():
         return False
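
Note on the `utils.py` changes above: the local copy of `memory_profiling` and the env-var-driven `calc_block_num` are removed, so memory accounting falls back to the upstream `vllm.utils` context manager and `vLLM_MODEL_MEMORY_USE_GB` is no longer a required environment variable. For orientation, the usage shape of that context manager, as mirrored by the deleted code, looks roughly like this (a sketch under stated assumptions, not a drop-in for vLLM's implementation):

```python
# Rough usage pattern of a memory_profiling-style context manager
# (sketch; field names mirror the deleted code, values are placeholders).
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class ProfilingResult:
    weights_memory: int = 0
    torch_peak_increase: int = 0
    non_torch_increase: int = 0
    non_kv_cache_memory: int = 0


@contextmanager
def memory_profiling(weights_memory: int):
    result = ProfilingResult(weights_memory=weights_memory)
    # ... snapshot allocator statistics here ...
    yield result
    # ... measure again, fill in the increases, then total them up ...
    result.non_kv_cache_memory = (result.weights_memory +
                                  result.torch_peak_increase +
                                  result.non_torch_increase)


with memory_profiling(weights_memory=2 << 30) as result:
    pass  # the dummy forward pass (profile_run) would run here
print(result.non_kv_cache_memory)
```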
diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py
index 9a6df84634af2f36e62a3dd1eb3b4d80ecb12b50..767bde8d2254cefcc2afde5494bf10d1661780f7 100644
--- a/vllm_mindspore/worker/model_runner.py
+++ b/vllm_mindspore/worker/model_runner.py
@@ -41,119 +41,138 @@ def _get_cuda_graph_pad_size(
 
 def profile_run(self) -> None:
-    # Enable top-k sampling to reflect the accurate memory usage.
-    sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-    max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+    max_num_batched_tokens = \
+        self.scheduler_config.max_num_batched_tokens
     max_num_seqs = self.scheduler_config.max_num_seqs
-
-    # This represents the maximum number of different requests
-    # that will have unique loras, an therefore the max amount of memory
-    # consumption create dummy lora request copies from the lora request
-    # passed in, which contains a lora from the lora warmup path.
-    dummy_lora_requests: List[LoRARequest] = []
-    dummy_lora_requests_per_seq: List[LoRARequest] = []
-    if self.lora_config:
-        assert self.lora_manager is not None
-        with self.lora_manager.dummy_lora_cache():
-            for idx in range(self.lora_config.max_loras):
-                lora_id = idx + 1
-                dummy_lora_request = LoRARequest(
-                    lora_name=f"warmup_{lora_id}",
-                    lora_int_id=lora_id,
-                    lora_path="/not/a/real/path",
-                )
-                self.lora_manager.add_dummy_lora(dummy_lora_request,
-                                                 rank=LORA_WARMUP_RANK)
-                dummy_lora_requests.append(dummy_lora_request)
-            dummy_lora_requests_per_seq = [
-                dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
-            ]
-
-    # Profile memory usage with max_num_sequences sequences and the total
-    # number of tokens equal to max_num_batched_tokens.
-    seqs: List[SequenceGroupMetadata] = []
-    # Additional GPU memory may be needed for multi-modal encoding, which
-    # needs to be accounted for when calculating the GPU blocks for
-    # vLLM blocker manager.
-    # To exercise the worst scenario for GPU memory consumption,
-    # the number of seqs (batch_size) is chosen to maximize the number
-    # of images processed.
-
-    max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-        self.model_config)
-    if max_mm_tokens > 0:
-        max_num_seqs_orig = max_num_seqs
-        max_num_seqs = min(max_num_seqs,
-                           max_num_batched_tokens // max_mm_tokens)
-        if max_num_seqs < 1:
-            expr = (f"min({max_num_seqs_orig}, "
-                    f"{max_num_batched_tokens} // {max_mm_tokens})")
-            logger.warning(
-                "Computed max_num_seqs (%s) to be less than 1. "
-                "Setting it to the minimum value of 1.", expr)
-            max_num_seqs = 1
" - "Setting it to the minimum value of 1.", expr) - max_num_seqs = 1 - - batch_size = 0 - for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) - batch_size += seq_len - - dummy_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, - seq_len, - self.mm_registry) - - seq = SequenceGroupMetadata( - request_id=str(group_id), - is_prompt=True, - seq_data={group_id: dummy_data.seq_data}, - sampling_params=sampling_params, - block_tables=None, - lora_request=dummy_lora_requests_per_seq[group_id] - if dummy_lora_requests_per_seq else None, - multi_modal_data=dummy_data.multi_modal_data, - multi_modal_placeholders=dummy_data.multi_modal_placeholders, - ) - seqs.append(seq) - - # Run the model with the dummy inputs. - num_layers = self.model_config.get_num_layers(self.parallel_config) - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - # it is important to create tensors inside the loop, rather than - # multiplying the list, to avoid Dynamo from treating them as - # tensor aliasing. - - kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ - else self.cache_config.cache_dtype - kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype] - block_size = self.cache_config.block_size - num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - head_size = self.model_config.get_head_size() - kv_shape = [0, block_size, num_kv_heads, head_size] - kv_caches = mutable([ - mutable(( - mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), - mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), - )) - for _ in range(num_layers) - ]) - finished_requests_ids = [seq.request_id for seq in seqs] - model_input = self.prepare_model_input( - seqs, finished_requests_ids=finished_requests_ids) - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=batch_size, - dtype=self.model_config.dtype, - device=self.device) - - self.execute_model(model_input, kv_caches, intermediate_tensors) - torch.cuda.synchronize() - return + self._dummy_run(max_num_batched_tokens, max_num_seqs) + + +def _dummy_run(self, + max_num_batched_tokens: int, + max_num_seqs: int = 1) -> None: + with self.set_in_profile_run(): + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = \ + SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. 
+        dummy_lora_requests: List[LoRARequest] = []
+        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        if self.lora_config:
+            assert self.lora_manager is not None
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
+
+        # Profile memory usage with max_num_sequences sequences and the
+        # total number of tokens equal to max_num_batched_tokens.
+        seqs: List[SequenceGroupMetadata] = []
+        # Additional GPU memory may be needed for multi-modal encoding,
+        # which needs to be accounted for when calculating the GPU blocks
+        # for vLLM blocker manager.
+        # To exercise the worst scenario for GPU memory consumption,
+        # the number of seqs (batch_size) is chosen to maximize the number
+        # of images processed.
+
+        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
+            self.model_config)
+        if max_mm_tokens > 0:
+            max_num_seqs_orig = max_num_seqs
+            max_num_seqs = min(max_num_seqs,
+                               max_num_batched_tokens // max_mm_tokens)
+            if max_num_seqs < 1:
+                expr = (f"min({max_num_seqs_orig}, "
+                        f"{max_num_batched_tokens} // {max_mm_tokens})")
+                logger.warning(
+                    "Computed max_num_seqs (%s) to be less than 1. "
+                    "Setting it to the minimum value of 1.", expr)
+                max_num_seqs = 1
+
+        batch_size = 0
+        for group_id in range(max_num_seqs):
+            seq_len = (max_num_batched_tokens // max_num_seqs +
+                       (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len
+
+            dummy_data = self.input_registry \
+                .dummy_data_for_profiling(self.model_config,
+                                          seq_len,
+                                          self.mm_registry)
+
+            seq = SequenceGroupMetadata(
+                request_id=str(group_id),
+                is_prompt=True,
+                seq_data={group_id: dummy_data.seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
+                multi_modal_data=dummy_data.multi_modal_data,
+                multi_modal_placeholders=dummy_data.
+                multi_modal_placeholders,
+            )
+            seqs.append(seq)
+
+        # Run the model with the dummy inputs.
+        num_layers = self.model_config.get_num_layers(self.parallel_config)
+        # use an empty tensor instead of `None`` to force Dynamo to pass
+        # it by reference, rather by specializing on the value ``None``.
+        # the `dtype` argument does not matter, and we use `float32` as
+        # a placeholder (it has wide hardware support).
+        # it is important to create tensors inside the loop, rather than
+        # multiplying the list, to avoid Dynamo from treating them as
+        # tensor aliasing.
+        kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \
+            else self.cache_config.cache_dtype
+        kv_cache_dtype = STR_DTYPE_TO_TENSOR_DTYPE[kv_cache_dtype]
+        block_size = self.cache_config.block_size
+        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        head_size = self.model_config.get_head_size()
+        kv_shape = [0, block_size, num_kv_heads, head_size]
+        kv_caches = mutable([
+            mutable((
+                mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
+                mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)),
+            ))
+            for _ in range(num_layers)
+        ])
+        finished_requests_ids = [seq.request_id for seq in seqs]
+        model_input = self.prepare_model_input(
+            seqs, finished_requests_ids=finished_requests_ids)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = \
+                self.model.make_empty_intermediate_tensors(
+                    batch_size=batch_size,
+                    dtype=self.model_config.dtype,
+                    device=self.device)
+
+        # Disable KV Scale Calculation for dummy data during profile run
+        if model_input.attn_metadata is not None:
+            model_input.attn_metadata.enable_kv_scales_calculation = False
+
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
+        torch.cuda.synchronize()
+        if self.lora_config:
+            # Remove dummy loras.
+            assert self.lora_manager is not None
+            self.remove_all_loras()
+        return
 
 
 MULTI_STEP_ATTENTION_BACKENDS = [
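
Note on the `profile_run`/`_dummy_run` change above: `profile_run` is now a thin wrapper, and `_dummy_run` builds a worst-case dummy batch by spreading `max_num_batched_tokens` over `max_num_seqs` sequences, giving the first `remainder` sequences one extra token. The token-distribution arithmetic is easy to check in isolation:

```python
# How _dummy_run spreads max_num_batched_tokens across max_num_seqs
# dummy sequences (the first `remainder` sequences get one extra token).
max_num_batched_tokens = 10
max_num_seqs = 3

seq_lens = [
    max_num_batched_tokens // max_num_seqs
    + (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
assert seq_lens == [4, 3, 3]
assert sum(seq_lens) == max_num_batched_tokens
```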
diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index b3c87dc931abdc5f0e89c57fa4d911f2b0b2d022..f05eef3aa33d59195dc8a3f108fda696408b10a1 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -94,99 +94,3 @@ def _warm_up_model(self) -> None:
     # Reset the seed to ensure that the random state is not affected by
     # the model initialization and profiling.
     set_random_seed(self.model_config.seed)
-
-
-def determine_num_available_blocks(self) -> Tuple[int, int]:
-    """Profiles the peak memory usage of the model to determine how many
-    KV blocks may be allocated without OOMs.
-
-    The engine will first conduct a profiling of the existing memory usage.
-    Then, it calculate the maximum possible number of GPU and CPU blocks
-    that can be allocated with the remaining free memory.
-
-    .. tip::
-        You may limit the usage of GPU memory
-        by adjusting the `gpu_memory_utilization` parameter.
-    """
-    from vllm.utils import GiB_bytes, memory_profiling
-
-    # Profile the memory usage of the model and get the maximum number of
-    # cache blocks that can be allocated with the remaining free memory.
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
-
-    total_gpu_memory = int(os.environ["ASCEND_TOTAL_MEMORY_GB"]) if os.getenv("ASCEND_TOTAL_MEMORY_GB") else 64
-    total_gpu_memory = total_gpu_memory * 1024 * 1024 * 1024
-
-    if os.getenv("vLLM_MODEL_MEMORY_USE_GB"):
-        memory_use_for_model_run = int(os.environ["vLLM_MODEL_MEMORY_USE_GB"]) * 1024 * 1024 * 1024
-    else:
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        _, total_gpu_memory = torch.cuda.mem_get_info()
-        with memory_profiling(
-                self.baseline_snapshot,
-                weights_memory=self.model_runner.model_memory_usage,
-        ) as result:
-            self.model_runner.profile_run()
-            torch.cuda.synchronize()
-
-        self._assert_memory_footprint_increased_during_profiling()
-
-        memory_use_for_model_run = result.non_kv_cache_memory
-
-    memory_for_current_instance = (
-        total_gpu_memory * self.cache_config.gpu_memory_utilization
-    )
-    available_kv_cache_memory = memory_for_current_instance - memory_use_for_model_run
-
-    # Calculate the number of blocks that can be allocated with the
-    # profiled peak memory.
-    cache_block_size = self.get_cache_block_size_bytes()
-    if cache_block_size == 0:
-        num_gpu_blocks = 0
-        num_cpu_blocks = 0
-    else:
-        num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
-        num_cpu_blocks = int(self.cache_config.swap_space_bytes // cache_block_size)
-    num_gpu_blocks = max(num_gpu_blocks, 0)
-    num_cpu_blocks = max(num_cpu_blocks, 0)
-
-    if os.getenv("vLLM_MODEL_MEMORY_USE_GB"):
-        msg = (
-            f"The current vLLM instance can use "
-            "total_gpu_memory "
-            f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
-            " x gpu_memory_utilization "
-            f"({self.cache_config.gpu_memory_utilization:.2f})"
-            f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
-            "set model use memory "
-            f"{(memory_use_for_model_run):.2f}GiB;"
-            " the rest of the memory reserved for KV Cache is "
-            f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB."
-        )
-    else:
-        msg = (
-            f"Memory profiling takes {result.profile_time:.2f} seconds\n"
-            "the current vLLM instance can use "
-            "total_gpu_memory "
-            f"({(total_gpu_memory / GiB_bytes):.2f}GiB)"
-            " x gpu_memory_utilization "
-            f"({self.cache_config.gpu_memory_utilization:.2f})"
-            f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
-            "model weights take "
-            f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
-            " non_torch_memory takes "
-            f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
-            " PyTorch activation peak memory takes "
-            f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
-            " the rest of the memory reserved for KV Cache is "
-            f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB."
-        )
-
-    logger.info(msg)
-
-    # Final cleanup
-    gc.collect()
-
-    return num_gpu_blocks, num_cpu_blocks
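
Note on the `worker.py` change above: deleting the `determine_num_available_blocks` override hands block budgeting back to vLLM's stock worker path, which profiles a dummy forward pass (the `_dummy_run` above) and then divides the leftover memory by the per-block cache size. The arithmetic is the same one visible in the deleted body; a worked example with made-up numbers:

```python
# Worked example of the KV-block budget arithmetic from the deleted
# override (all numbers are made up for illustration).
GiB = 1024 ** 3

total_gpu_memory = 64 * GiB
gpu_memory_utilization = 0.9
non_kv_cache_memory = 18 * GiB      # weights + activation peak + non-torch
cache_block_size = 2 * 1024 * 1024  # bytes per KV block (example)

memory_for_instance = total_gpu_memory * gpu_memory_utilization
available_kv_cache_memory = memory_for_instance - non_kv_cache_memory
num_gpu_blocks = max(int(available_kv_cache_memory // cache_block_size), 0)

print(num_gpu_blocks)  # 20275 blocks with these example numbers
```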