diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 963316a79594fb622fd83fd26b08fb34aeefeac2..1a76027fbfca6d0209300bacce5fd98f93be6ec4 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -32,288 +32,280 @@ env_setup() # 2. update the log configuration ahead of other modifications. import vllm_mindspore.logger -from vllm_mindspore.platforms.ascend import AscendPlatform - -ascend_platform = AscendPlatform() - -import vllm.config - -vllm.config.current_platform = ascend_platform - -import vllm.platforms - -vllm.platforms.current_platform = ascend_platform - -import vllm.utils - -vllm.utils.current_platform = ascend_platform - -import vllm.attention.selector -vllm.attention.selector.current_platform = ascend_platform - -import vllm.engine.arg_utils -from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle -vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle - -import vllm.v1.engine.core -from vllm_mindspore.v1.engine.core import shutdown -vllm.v1.engine.core.DPEngineCoreProc.shutdown = shutdown - -from vllm_mindspore.utils import ( - make_tensor_with_pad, - async_tensor_h2d, - ascend_is_initialized, - ms_memory_profiling, -) - -vllm.utils.make_tensor_with_pad = make_tensor_with_pad -vllm.utils.async_tensor_h2d = async_tensor_h2d -vllm.utils.cuda_is_initialized = ascend_is_initialized -vllm.utils.memory_profiling = ms_memory_profiling - -import vllm.executor - -from vllm_mindspore.model_executor.models.registry import ( - MindSporeModelRegistry, - _SUBPROCESS_COMMAND, -) - -vllm.config.ModelRegistry = MindSporeModelRegistry - -import vllm.model_executor - -vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry -vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND +import importlib.util +if importlib.util.find_spec("vllm_ascend") is not None: + import vllm_mindspore.patch.patch_vllm_ascend +else: + warnings.warn( + f"vllm-ascend is not imported because vllm_ascend is not installed" + ) -from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture + from vllm_mindspore.platforms.ascend import AscendPlatform -# To patching the get_model_architecture, should import it first. 
-from vllm.model_executor.model_loader import get_model_architecture + ascend_platform = AscendPlatform() -vllm.model_executor.model_loader.get_model_architecture = get_ms_model_architecture -vllm.model_executor.model_loader.utils.get_model_architecture = ( - get_ms_model_architecture -) -vllm.model_executor.model_loader.loader.get_model_architecture = ( - get_ms_model_architecture -) + import vllm.config -from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors + vllm.config.current_platform = ascend_platform -vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d -vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists -from vllm_mindspore.worker.cache_engine import ( - ms_allocate_kv_cache, - ms_swap_in, - ms_swap_out, -) + import vllm.platforms -import vllm.worker.cache_engine + vllm.platforms.current_platform = ascend_platform -vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache -vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in -vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out + import vllm.utils -from vllm_mindspore.model_executor.model_loader.weight_utils import ( - safetensors_weights_iterator, -) + vllm.utils.current_platform = ascend_platform -vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( - safetensors_weights_iterator -) + import vllm.attention.selector + vllm.attention.selector.current_platform = ascend_platform -from vllm_mindspore.worker.worker import _warm_up_model -from vllm_mindspore.worker.profile import ( - wrapper_worker_init, - wrapper_worker_init_device, -) -from vllm.worker.worker import Worker + import vllm.engine.arg_utils + from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle + vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle -Worker._warm_up_model = _warm_up_model -Worker.__init__ = wrapper_worker_init(Worker.__init__) -Worker.init_device = wrapper_worker_init_device(Worker.init_device) + from vllm_mindspore.utils import ( + make_tensor_with_pad, + async_tensor_h2d, + ascend_is_initialized, + ms_memory_profiling, + ) -from vllm_mindspore.worker.model_runner import ( - _get_cuda_graph_pad_size, - _dummy_run, - _get_supported_attention_backends, -) + vllm.utils.make_tensor_with_pad = make_tensor_with_pad + vllm.utils.async_tensor_h2d = async_tensor_h2d + vllm.utils.cuda_is_initialized = ascend_is_initialized + vllm.utils.memory_profiling = ms_memory_profiling -vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( - _get_cuda_graph_pad_size -) -vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run + from vllm_mindspore.model_executor.models.registry import ( + MindSporeModelRegistry, + _SUBPROCESS_COMMAND, + ) -import vllm.worker.multi_step_model_runner + vllm.config.ModelRegistry = MindSporeModelRegistry -vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( - _get_supported_attention_backends -) + import vllm.model_executor -from vllm_mindspore.executor.multiproc_worker_utils import ( - get_mp_context as ms_get_mp_context, - terminate_worker as ms_terminate_worker, -) + vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry + vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND -# To patching the get_mp_context, should import it first. 
-from vllm.executor.multiproc_worker_utils import get_mp_context + from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture -vllm.executor.multiproc_worker_utils.get_mp_context = ms_get_mp_context + # To patching the get_model_architecture, should import it first. + from vllm.model_executor.model_loader import get_model_architecture -import vllm.executor.multiproc_worker_utils + from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors -vllm.executor.multiproc_worker_utils.ProcessWorkerWrapper.terminate_worker = ms_terminate_worker + vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d + vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists -import vllm.v1.executor.multiproc_executor -vllm.v1.executor.multiproc_executor.get_mp_context = ms_get_mp_context -import vllm.v1.utils -vllm.v1.utils.get_mp_context = ms_get_mp_context + from vllm_mindspore.worker.cache_engine import ( + ms_allocate_kv_cache, + ms_swap_in, + ms_swap_out, + ) -from vllm_mindspore.executor.ray_gpu_executor import ( - ms_init_workers_ray, - initialize_ray_cluster, -) + import vllm.worker.cache_engine -from vllm.executor.ray_distributed_executor import RayDistributedExecutor + vllm.worker.cache_engine.CacheEngine._allocate_kv_cache = ms_allocate_kv_cache + vllm.worker.cache_engine.CacheEngine.swap_in = ms_swap_in + vllm.worker.cache_engine.CacheEngine.swap_out = ms_swap_out -RayDistributedExecutor._init_workers_ray = ms_init_workers_ray + from vllm_mindspore.model_executor.model_loader.weight_utils import ( + safetensors_weights_iterator, + ) -vllm.executor.ray_distributed_executor.initialize_ray_cluster = initialize_ray_cluster -vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster + vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( + safetensors_weights_iterator + ) -import vllm.engine.llm_engine -import vllm.engine.async_llm_engine + from vllm_mindspore.worker.worker import _warm_up_model + from vllm_mindspore.worker.profile import ( + wrapper_worker_init, + wrapper_worker_init_device, + ) + from vllm.worker.worker import Worker -vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster -vllm.engine.async_llm_engine.initialize_ray_cluster = initialize_ray_cluster + Worker._warm_up_model = _warm_up_model + Worker.__init__ = wrapper_worker_init(Worker.__init__) + Worker.init_device = wrapper_worker_init_device(Worker.init_device) + from vllm_mindspore.worker.model_runner import ( + _get_cuda_graph_pad_size, + _dummy_run, + _get_supported_attention_backends, + ) -from .config import _verify_quantization, _verify_args, vllm_config_post_init, model_post_init, \ - _get_and_verify_dtype, stateless_init_dp_group, has_unfinished_dp - -vllm.config.ModelConfig._verify_quantization = _verify_quantization -vllm.config.VllmConfig.__post_init__ = vllm_config_post_init -vllm.config.SchedulerConfig._verify_args = _verify_args -vllm.config.CompilationConfig.model_post_init = model_post_init -vllm.config._get_and_verify_dtype = _get_and_verify_dtype -vllm.config.ParallelConfig.stateless_init_dp_group = stateless_init_dp_group -vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp + vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( + _get_cuda_graph_pad_size + ) + vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run -from .utils import update_modules -from vllm_mindspore.attention.backends import ms_attn 
-update_modules("vllm.attention.backends.flash_attn", ms_attn) + import vllm.worker.multi_step_model_runner -from vllm_mindspore.worker.spec_decode_worker import ( - spec_decode_worker_init, - _run_no_spec, - _verify_tokens, - _create_output, - _merge_outputs, -) -from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker -SpecDecodeWorker.__init__ = spec_decode_worker_init -SpecDecodeWorker._verify_tokens = _verify_tokens -SpecDecodeWorker._run_no_spec = _run_no_spec - -from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler -SpecDecodeBaseSampler._create_output = _create_output + vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( + _get_supported_attention_backends + ) -from vllm.spec_decode.top1_proposer import Top1Proposer -Top1Proposer._merge_outputs = _merge_outputs + from vllm_mindspore.executor.multiproc_worker_utils import ( + get_mp_context as ms_get_mp_context, + ) -from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -RejectionSampler._smallest_positive_value = _smallest_positive_value -RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value') -vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial + # To patching the get_mp_context, should import it first. + from vllm.executor.multiproc_worker_utils import get_mp_context -######### for multi-model -from vllm_mindspore.inputs.registry import call_hf_processor -from vllm.inputs.registry import InputProcessingContext -InputProcessingContext.call_hf_processor = call_hf_processor + vllm.executor.multiproc_worker_utils.get_mp_context = ms_get_mp_context -from vllm_mindspore.multimodal.inputs import as_kwargs -from vllm.multimodal.inputs import MultiModalKwargs -MultiModalKwargs.as_kwargs = as_kwargs + import vllm.v1.executor.multiproc_executor + vllm.v1.executor.multiproc_executor.get_mp_context = ms_get_mp_context + import vllm.v1.utils + vllm.v1.utils.get_mp_context = ms_get_mp_context -from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding -vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding + from vllm_mindspore.executor.ray_gpu_executor import ( + ms_init_workers_ray, + initialize_ray_cluster, + ) -from vllm_mindspore.v1.sample import rejection_sampler -update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + from vllm.executor.ray_distributed_executor import RayDistributedExecutor -from vllm_mindspore.v1.spec_decode import eagle -update_modules("vllm.v1.spec_decode.eagle", eagle) + RayDistributedExecutor._init_workers_ray = ms_init_workers_ray -from vllm_mindspore.v1.attention.backends import flash_attn -import vllm.v1.attention.backends -sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn -import vllm.v1.attention.backends.flash_attn + vllm.executor.ray_distributed_executor.initialize_ray_cluster = initialize_ray_cluster + vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster -import vllm.v1.worker.gpu_model_runner + import vllm.engine.llm_engine + import vllm.engine.async_llm_engine -from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs -vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs + vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster + vllm.engine.async_llm_engine.initialize_ray_cluster = 
initialize_ray_cluster -from vllm_mindspore.v1.worker.gpu_model_runner import _update_states -vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states -from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache -vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache + from .config import _verify_quantization, _verify_args, vllm_config_post_init, model_post_init, \ + _get_and_verify_dtype, stateless_init_dp_group, has_unfinished_dp -import vllm.v1.worker.block_table -from vllm_mindspore.v1.worker.block_table import BlockTable -vllm.v1.worker.block_table.BlockTable = BlockTable -vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable - -import vllm.v1.worker.gpu_input_batch -from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor -vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata -vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata -vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor -vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + vllm.config.ModelConfig._verify_quantization = _verify_quantization + vllm.config.VllmConfig.__post_init__ = vllm_config_post_init + vllm.config.SchedulerConfig._verify_args = _verify_args + vllm.config.CompilationConfig.model_post_init = model_post_init + vllm.config._get_and_verify_dtype = _get_and_verify_dtype + vllm.config.ParallelConfig.stateless_init_dp_group = stateless_init_dp_group + vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp -from vllm.v1.worker.gpu_worker import Worker -from vllm_mindspore.v1.worker.gpu_worker import init_device + from .utils import update_modules + from vllm_mindspore.attention.backends import ms_attn + update_modules("vllm.attention.backends.flash_attn", ms_attn) -Worker.__init__ = wrapper_worker_init(Worker.__init__) -Worker.init_device = wrapper_worker_init_device(init_device) - - -import vllm.v1.utils -from vllm_mindspore.v1.utils import copy_slice -vllm.v1.utils.copy_slice = copy_slice -vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice - -from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors -import vllm.v1.sample.ops.penalties -vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors -import vllm.model_executor.layers.utils -from vllm_mindspore.model_executor.layers.utils import apply_penalties -vllm.model_executor.layers.utils.apply_penalties = apply_penalties -vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties - - -from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, random_sample, \ - apply_top_k_only, topk_topp_sampler_forward_native - -import vllm.v1.sample.ops.topk_topp_sampler -from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler -TopKTopPSampler.forward_native = topk_topp_sampler_forward_native -vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p -vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample -vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only -from vllm_mindspore.v1.sample.sampler import apply_temperature -import vllm.v1.sample.sampler -vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature - -from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer -from vllm.distributed.device_communicators.shm_broadcast 
import ShmRingBuffer -ShmRingBuffer.__init__ = initialize_ShmRingBuffer - -from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model -from vllm.v1.worker.gpu_worker import Worker -Worker.compile_or_warm_up_model = compile_or_warm_up_model + from vllm_mindspore.worker.spec_decode_worker import ( + spec_decode_worker_init, + _run_no_spec, + _verify_tokens, + _create_output, + _merge_outputs, + ) + from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker + SpecDecodeWorker.__init__ = spec_decode_worker_init + SpecDecodeWorker._verify_tokens = _verify_tokens + SpecDecodeWorker._run_no_spec = _run_no_spec + + from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler + SpecDecodeBaseSampler._create_output = _create_output + + from vllm.spec_decode.top1_proposer import Top1Proposer + Top1Proposer._merge_outputs = _merge_outputs + + from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial + from vllm.model_executor.layers.rejection_sampler import RejectionSampler + RejectionSampler._smallest_positive_value = _smallest_positive_value + RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, '_smallest_positive_value') + vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial + + from vllm_mindspore.v1.sample import rejection_sampler + update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + + from vllm_mindspore.v1.spec_decode import eagle + update_modules("vllm.v1.spec_decode.eagle", eagle) + + from vllm_mindspore.v1.attention.backends import flash_attn + import vllm.v1.attention.backends + sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn + import vllm.v1.attention.backends.flash_attn + + import vllm.v1.worker.gpu_model_runner + + from vllm_mindspore.v1.worker.gpu_model_runner import _prepare_inputs + vllm.v1.worker.gpu_model_runner.GPUModelRunner._prepare_inputs = _prepare_inputs + + from vllm_mindspore.v1.worker.gpu_model_runner import _update_states + vllm.v1.worker.gpu_model_runner.GPUModelRunner._update_states = _update_states + + from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache + vllm.v1.worker.gpu_model_runner.GPUModelRunner.initialize_kv_cache = initialize_kv_cache + + import vllm.v1.worker.block_table + from vllm_mindspore.v1.worker.block_table import BlockTable + vllm.v1.worker.block_table.BlockTable = BlockTable + vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable + + import vllm.v1.worker.gpu_input_batch + from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor + vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata + vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata + vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + + from vllm.v1.worker.gpu_worker import Worker + from vllm_mindspore.v1.worker.gpu_worker import init_device + + Worker.__init__ = wrapper_worker_init(Worker.__init__) + Worker.init_device = wrapper_worker_init_device(init_device) + + ######### for multi-model + from vllm_mindspore.inputs.registry import call_hf_processor + from vllm.inputs.registry import InputProcessingContext + InputProcessingContext.call_hf_processor = call_hf_processor + + from vllm_mindspore.multimodal.inputs import 
as_kwargs + from vllm.multimodal.inputs import MultiModalKwargs + MultiModalKwargs.as_kwargs = as_kwargs + + from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding + vllm.model_executor.layers.rotary_embedding.MRotaryEmbedding = InferMRotaryEmbedding + + from vllm_mindspore.v1.sample import rejection_sampler + update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + + import vllm.v1.utils + from vllm_mindspore.v1.utils import copy_slice + vllm.v1.utils.copy_slice = copy_slice + vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice + + from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors + import vllm.v1.sample.ops.penalties + vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors + import vllm.model_executor.layers.utils + from vllm_mindspore.model_executor.layers.utils import apply_penalties + vllm.model_executor.layers.utils.apply_penalties = apply_penalties + vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties + + + from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, random_sample, \ + apply_top_k_only, topk_topp_sampler_forward_native + + import vllm.v1.sample.ops.topk_topp_sampler + from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler + TopKTopPSampler.forward_native = topk_topp_sampler_forward_native + vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p + vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample + vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only + from vllm_mindspore.v1.sample.sampler import apply_temperature + import vllm.v1.sample.sampler + vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature + + from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer + from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer + ShmRingBuffer.__init__ = initialize_ShmRingBuffer + + from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model + from vllm.v1.worker.gpu_worker import Worker + Worker.compile_or_warm_up_model = compile_or_warm_up_model from .utils import check_ready diff --git a/vllm_mindspore/attention/backends/ms_attn.py b/vllm_mindspore/attention/backends/ms_attn.py index d6123b0a89790ba630888066cb857d995f190c10..e76557df0dcf9d33b2b217f43d959f4913e80c97 100644 --- a/vllm_mindspore/attention/backends/ms_attn.py +++ b/vllm_mindspore/attention/backends/ms_attn.py @@ -501,17 +501,13 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, - batch_size: int, + graph_size: int = -1, ): """Build attention metadata with on-device tensors. Args: seq_lens: The maybe padded sequence lengths of the input sequences. query_lens: The query lengths of the input sequences. - cuda_graph_pad_size: The padding size for cuda graph. - -1 if cuda graph is not used. - batch_size: The maybe padded batch size. 
""" prefix_cache_hit = any( [ @@ -525,7 +521,6 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): ) device = self.runner.device - use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) decode_query_lens = query_lens[self.num_prefills:] @@ -539,15 +534,12 @@ class MsAttentionMetadataBuilder(AttentionMetadataBuilder[MSAttentionMetadata]): query_start_loc = list(accumulate(query_lens, initial=0)) seq_start_loc = list(accumulate(seq_lens, initial=0)) - if use_captured_graph: - raise RuntimeError("Doesnot support captured graph now!") - else: - block_tables = make_tensor_with_pad( - self.block_tables, - pad=-1, - dtype=torch.int, - device=device, - ) + block_tables = make_tensor_with_pad( + self.block_tables, + pad=-1, + dtype=torch.int, + device=device, + ) assert max_query_len > 0, "query_lens: {}".format(query_lens) context_lens_tensor = ms.Tensor(self.context_lens, dtype=ms.int32) @@ -595,6 +587,10 @@ class MsAttentionBackend(AttentionBackend): @staticmethod def get_builder_cls() -> Type["MsAttentionMetadataBuilder"]: return MsAttentionMetadataBuilder + + @classmethod + def make_metadata_builder(cls, *args, **kwargs) -> "MsAttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) @staticmethod def get_state_cls() -> Type["AttentionState"]: diff --git a/vllm_mindspore/patch/__init__.py b/vllm_mindspore/patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/vllm_mindspore/patch/patch_vllm_ascend.py b/vllm_mindspore/patch/patch_vllm_ascend.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a4fcb98f6c55f3ebc88492bebd3e84f10161d8 --- /dev/null +++ b/vllm_mindspore/patch/patch_vllm_ascend.py @@ -0,0 +1,278 @@ +import sys +# ================ For vllm ================ + +import vllm.utils + +import vllm.engine.arg_utils +from vllm_mindspore.engine.arg_utils import _is_v1_supported_oracle +vllm.engine.arg_utils.EngineArgs._is_v1_supported_oracle = _is_v1_supported_oracle + +from vllm_mindspore.utils import ( + make_tensor_with_pad, + async_tensor_h2d, + ascend_is_initialized, + ms_memory_profiling, +) + +vllm.utils.make_tensor_with_pad = make_tensor_with_pad +vllm.utils.async_tensor_h2d = async_tensor_h2d +vllm.utils.cuda_is_initialized = ascend_is_initialized +vllm.utils.memory_profiling = ms_memory_profiling + +from vllm_mindspore.model_executor.models.registry import ( + MindSporeModelRegistry, + _SUBPROCESS_COMMAND, +) + +vllm.config.ModelRegistry = MindSporeModelRegistry + +import vllm.model_executor + +vllm.model_executor.models.ModelRegistry = MindSporeModelRegistry +vllm.model_executor.models.registry._SUBPROCESS_COMMAND = _SUBPROCESS_COMMAND + +from vllm_mindspore.model_executor.model_loader.utils import get_ms_model_architecture + +# To patching the get_model_architecture, should import it first. 
+from vllm.model_executor.model_loader import get_model_architecture + +from vllm_mindspore.model_executor.sampling_metadata import SamplingTensors + +vllm.model_executor.sampling_metadata.async_tensor_h2d = async_tensor_h2d +vllm.model_executor.sampling_metadata.SamplingTensors.from_lists = SamplingTensors.from_lists + +from vllm_mindspore.model_executor.model_loader.weight_utils import ( + safetensors_weights_iterator, +) + +vllm.model_executor.model_loader.loader.safetensors_weights_iterator = ( + safetensors_weights_iterator +) + +from vllm_mindspore.worker.model_runner import ( + _get_cuda_graph_pad_size, + _dummy_run, + _get_supported_attention_backends, +) + +import vllm.worker.model_runner +vllm.worker.model_runner.ModelInputForGPUBuilder._get_cuda_graph_pad_size = ( + _get_cuda_graph_pad_size +) +vllm.worker.model_runner.GPUModelRunnerBase._dummy_run = _dummy_run + +import vllm.worker.multi_step_model_runner + +vllm.worker.multi_step_model_runner._get_supported_attention_backends = ( + _get_supported_attention_backends +) + +from vllm_mindspore.executor.multiproc_worker_utils import ( + get_mp_context as ms_get_mp_context, +) + +from vllm_mindspore.executor.ray_gpu_executor import ( + ms_init_workers_ray, + initialize_ray_cluster, +) + +from vllm.executor.ray_distributed_executor import RayDistributedExecutor + +RayDistributedExecutor._init_workers_ray = ms_init_workers_ray + +vllm.executor.ray_distributed_executor.initialize_ray_cluster = initialize_ray_cluster +vllm.executor.ray_utils.initialize_ray_cluster = initialize_ray_cluster + +import vllm.engine.llm_engine +import vllm.engine.async_llm_engine + +vllm.engine.llm_engine.initialize_ray_cluster = initialize_ray_cluster +vllm.engine.async_llm_engine.initialize_ray_cluster = initialize_ray_cluster + + +from vllm_mindspore.config import _verify_quantization, _verify_args, vllm_config_post_init, model_post_init, \ + _get_and_verify_dtype, stateless_init_dp_group, has_unfinished_dp + +vllm.config.ModelConfig._verify_quantization = _verify_quantization +vllm.config.VllmConfig.__post_init__ = vllm_config_post_init +vllm.config.SchedulerConfig._verify_args = _verify_args +vllm.config.CompilationConfig.model_post_init = model_post_init +vllm.config._get_and_verify_dtype = _get_and_verify_dtype +vllm.config.ParallelConfig.stateless_init_dp_group = stateless_init_dp_group +vllm.config.ParallelConfig.has_unfinished_dp = has_unfinished_dp + +from vllm_mindspore.utils import update_modules +from vllm_mindspore.attention.backends import ms_attn +update_modules("vllm.attention.backends.flash_attn", ms_attn) + +from vllm_mindspore.worker.spec_decode_worker import ( + spec_decode_worker_init, + _run_no_spec, + _verify_tokens, + _create_output, + _merge_outputs, +) +from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker +SpecDecodeWorker.__init__ = spec_decode_worker_init +SpecDecodeWorker._verify_tokens = _verify_tokens +SpecDecodeWorker._run_no_spec = _run_no_spec + +from vllm.model_executor.layers.spec_decode_base_sampler import SpecDecodeBaseSampler +SpecDecodeBaseSampler._create_output = _create_output + +from vllm.spec_decode.top1_proposer import Top1Proposer +Top1Proposer._merge_outputs = _merge_outputs + +from vllm_mindspore.model_executor.layers.rejection_sampler import _smallest_positive_value, _multinomial +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +RejectionSampler._smallest_positive_value = _smallest_positive_value +RejectionSampler._smallest_positive_value.__set_name__(RejectionSampler, 
'_smallest_positive_value') +vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial + +from vllm_mindspore.v1.sample import rejection_sampler +update_modules("vllm.v1.sample.rejection_sampler", rejection_sampler) + +from vllm_mindspore.v1.spec_decode import eagle +update_modules("vllm.v1.spec_decode.eagle", eagle) + +from vllm_mindspore.v1.attention.backends import flash_attn +import vllm.v1.attention.backends +sys.modules['vllm.v1.attention.backends.flash_attn'] = flash_attn +import vllm.v1.attention.backends.flash_attn + +import vllm.v1.worker.gpu_model_runner + +import vllm.v1.worker.block_table +from vllm_mindspore.v1.worker.block_table import BlockTable +vllm.v1.worker.block_table.BlockTable = BlockTable +vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable + +import vllm.v1.worker.gpu_input_batch +from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor +vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata +vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata +vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor +vllm.v1.worker.gpu_model_runner.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor + +import vllm.v1.utils +from vllm_mindspore.v1.utils import copy_slice +vllm.v1.utils.copy_slice = copy_slice +vllm.v1.worker.gpu_input_batch.copy_slice = copy_slice + +from vllm_mindspore.v1.sample.ops.penalties import _convert_to_tensors +import vllm.v1.sample.ops.penalties +vllm.v1.sample.ops.penalties._convert_to_tensors = _convert_to_tensors +import vllm.model_executor.layers.utils +from vllm_mindspore.model_executor.layers.utils import apply_penalties +vllm.model_executor.layers.utils.apply_penalties = apply_penalties +vllm.v1.sample.ops.penalties.apply_penalties = apply_penalties + + +from vllm_mindspore.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p, random_sample, \ + apply_top_k_only, topk_topp_sampler_forward_native + +import vllm.v1.sample.ops.topk_topp_sampler +from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler +TopKTopPSampler.forward_native = topk_topp_sampler_forward_native +vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_top_p = apply_top_k_top_p +vllm.v1.sample.ops.topk_topp_sampler.random_sample = random_sample +vllm.v1.sample.ops.topk_topp_sampler.apply_top_k_only = apply_top_k_only +from vllm_mindspore.v1.sample.sampler import apply_temperature +import vllm.v1.sample.sampler +vllm.v1.sample.sampler.Sampler.apply_temperature = apply_temperature + +from vllm_mindspore.distributed.shm_broadcast import initialize_ShmRingBuffer +from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer +ShmRingBuffer.__init__ = initialize_ShmRingBuffer + +from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model +from vllm.v1.worker.gpu_worker import Worker +Worker.compile_or_warm_up_model = compile_or_warm_up_model + +# ================ For vllm-ascend ================ +# ============ For 0.8.5 start =========== +import importlib +import types +memory_mod = importlib.import_module("torch.cuda.memory") +if not hasattr(memory_mod, "NPUPluggableAllocator"): + memory_mod.NPUPluggableAllocator = memory_mod.CUDAPluggableAllocator +sys.modules["torch_npu.op_plugin"] = types.ModuleType("torch_npu.op_plugin") +sys.modules["torch_npu.op_plugin.atb"] = types.ModuleType("torch_npu.op_plugin.atb") +fake_mod = 
types.ModuleType("torch_npu.op_plugin.atb._atb_ops") +fake_mod._register_atb_extensions = lambda *a, **kw: None +sys.modules["torch_npu.op_plugin.atb._atb_ops"] = fake_mod +fake_mod = types.ModuleType("torchair._contrib") +sys.modules["torchair._contrib"] = fake_mod +fake_mod = types.ModuleType("torchair._contrib.custom_torch_ops") +sys.modules["torchair._contrib.custom_torch_ops"] = fake_mod +import torch +if not hasattr(torch, "Tag"): + class _FakeTag: + needs_fixed_stride_order = "needs_fixed_stride_order" + torch.Tag = _FakeTag +fake_fused_moe = types.ModuleType("vllm.model_executor.layers.fused_moe.fused_moe") +fake_fused_moe.direct_register_custom_op = lambda *a, **kw: None +sys.modules["vllm.model_executor.layers.fused_moe.fused_moe"] = fake_fused_moe +import vllm_ascend.ops +vllm_ascend.ops.register_dummy_fusion_op = lambda *a, **kw: None +# ============ For 0.8.5 end =========== +fake_mod = types.ModuleType("vllm_ascend.vllm_ascend_C") +fake_mod.init_module = fake_mod.python_create_and_map = fake_mod.python_unmap_and_release = lambda *a, **kw: None +sys.modules.update({"vllm_ascend.vllm_ascend_C": fake_mod}) + +import vllm_ascend.utils +from vllm_mindspore.utils import vllm_version_is +vllm_ascend.utils.vllm_version_is = vllm_version_is + +from vllm_mindspore.platforms.ascend import get_attn_backend_cls +from vllm_ascend.platform import NPUPlatform +NPUPlatform.get_attn_backend_cls = get_attn_backend_cls + +from vllm_mindspore.worker.cache_engine import ( + ms_allocate_kv_cache, + ms_swap_in, + ms_swap_out, +) + +from vllm_ascend.worker.worker import CacheEngine + +CacheEngine._allocate_kv_cache = ms_allocate_kv_cache +CacheEngine.swap_in = ms_swap_in +CacheEngine.swap_out = ms_swap_out + +from vllm_mindspore.worker.worker import _warm_up_model + +from vllm_mindspore.worker.profile import ( + wrapper_worker_init, + wrapper_worker_init_device, +) + +from vllm_ascend.worker.worker import NPUWorker + +NPUWorker._warm_up_model = _warm_up_model +NPUWorker.__init__ = wrapper_worker_init(NPUWorker.__init__) +NPUWorker.init_device = wrapper_worker_init_device(NPUWorker.init_device) + +# ================ End ================ + +# ============ For v1 start =========== +from vllm_mindspore.config import _get_and_verify_dtype +vllm.config._get_and_verify_dtype = _get_and_verify_dtype + +from vllm_mindspore.worker.model_runner_v1 import _dummy_run, _process_reqs, wrapper_runner_init +from vllm_ascend.worker.model_runner_v1 import NPUModelRunner +NPUModelRunner._dummy_run = _dummy_run +NPUModelRunner._process_reqs = _process_reqs +NPUModelRunner.__init__ = wrapper_runner_init(NPUModelRunner.__init__) + +from vllm_mindspore.worker.worker_v1 import determine_available_memory +from vllm_ascend.worker.worker_v1 import NPUWorker +NPUWorker.determine_available_memory = determine_available_memory + +from vllm_mindspore.v1.worker.gpu_model_runner import initialize_kv_cache +NPUModelRunner.initialize_kv_cache = initialize_kv_cache + +from vllm_mindspore.v1.worker.gpu_worker import compile_or_warm_up_model +from vllm_ascend.worker.worker_v1 import NPUWorker +NPUWorker.compile_or_warm_up_model = compile_or_warm_up_model +# ============ For v1 end =========== diff --git a/vllm_mindspore/platforms/ascend.py b/vllm_mindspore/platforms/ascend.py index 356a33a040c050b0825a1c2fe5fea2179fbafa60..d4b71afccc165f4bdc0e3947cd6f3745eba170fe 100644 --- a/vllm_mindspore/platforms/ascend.py +++ b/vllm_mindspore/platforms/ascend.py @@ -109,7 +109,6 @@ class AscendPlatform(Platform): if use_mla: return 
"vllm_mindspore.v1.attention.backends.flash_attn.MLABackend" return "vllm_mindspore.v1.attention.backends.flash_attn.FlashAttentionBackend" - raise RuntimeError("vLLM-MindSpore do not support v1 egine now!") if use_mla: logger.info("Using MindSpore MLA backend.") return "vllm_mindspore.attention.backends.ms_attn.MLABackend" @@ -144,4 +143,25 @@ class AscendPlatform(Platform): @classmethod def supports_v1(cls, model_config: ModelConfig) -> bool: - return True \ No newline at end of file + return True + +def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla): + """Get the attention backend class of a device.""" + if use_v1: + if use_mla: + logger.info("Using MindSpore MLA backend for V1.") + return "vllm_mindspore.v1.attention.backends.flash_attn.MLABackend" + logger.info("Using MindSpore Attention backend for V1.") + return "vllm_mindspore.v1.attention.backends.flash_attn.FlashAttentionBackend" + if use_mla: + logger.info("Using MindSpore MLA backend.") + return "vllm_mindspore.attention.backends.ms_attn.MLABackend" + + if selected_backend == _Backend.FLASH_ATTN or selected_backend is None: + logger.info("Using MindSpore Attention backend.") + return "vllm_mindspore.attention.backends.ms_attn.MsAttentionBackend" + + raise ValueError( + "Invaild attention backend %s for vLLM-MindSpore with head_size: %s, dtype: %s, kv_cache_dtype: %s, block_size: %s." + % (str(selected_backend), str(head_size), str(dtype), str(kv_cache_dtype), str(block_size)) + ) diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 153589ed6dc3affb423fdcd0bc160019db621955..9d488d7c26a22fa2ff56fc2bee317541c2913314 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: else: Library = None +from packaging.version import Version + from vllm.logger import init_logger import mindspore as ms @@ -164,7 +166,6 @@ def check_ready(): import vllm.envs as envs from mindspore import set_context - # Common environment variables of predict. set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) default_env = { @@ -295,3 +296,10 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def vllm_version_is(version: str): + import vllm + if vllm.__version__ == '0.8.3': # since vllm-ascend support from 0.8.4, 0.8.3 should be supported too. + return True + return Version(vllm.__version__) == Version(version) diff --git a/vllm_mindspore/v1/worker/gpu_worker.py b/vllm_mindspore/v1/worker/gpu_worker.py index 0395c33928a2e7e42c4e3c36a12a20963a808133..e6543b1cc41cdd6f572fc2778f9a27a487906b38 100644 --- a/vllm_mindspore/v1/worker/gpu_worker.py +++ b/vllm_mindspore/v1/worker/gpu_worker.py @@ -5,6 +5,7 @@ import gc import torch from vllm.logger import init_logger from vllm.distributed.parallel_state import get_pp_group +from vllm.model_executor import set_random_seed logger = init_logger(__name__) @@ -48,5 +49,6 @@ def compile_or_warm_up_model(self) -> None: # Since prefill is done previously, we do decode here. default_max_num_reqs = 1 # For MindSpore, we only do one more decode here. 
if get_pp_group().is_last_rank: - self.model_runner._dummy_sampler_run(self.model_runner._dummy_run( - num_tokens=default_max_num_reqs)) + self.model_runner._dummy_run( + num_tokens=default_max_num_reqs) + set_random_seed(self.model_config.seed) diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 55bb26ec4ee65181cfc30425640149532c5b36bd..68a9360c1ff0da583282d3ef6ee819bbb8c9827a 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -179,4 +179,4 @@ def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ if chunked_prefill_enabled: return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS else: - return MULTI_STEP_ATTENTION_BACKENDS \ No newline at end of file + return MULTI_STEP_ATTENTION_BACKENDS diff --git a/vllm_mindspore/worker/model_runner_v1.py b/vllm_mindspore/worker/model_runner_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..10d7e2d2e20e7db93da2023c102e0dd815e935ee --- /dev/null +++ b/vllm_mindspore/worker/model_runner_v1.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from typing import List, Optional, Tuple +import numpy as np +import weakref + +import torch +import mindspore as ms +from mindspore import Tensor + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.distributed.parallel_state import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.sequence import IntermediateTensors +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.attention.layer import Attention + +from vllm_mindspore.v1.attention.backends.flash_attn import FlashAttentionMetadata + +#################33 +from mindspore import mutable +from vllm_mindspore.utils import get_valid_dtype +# from vllm_mindspore.utils import is_use_mla + +from vllm.attention import AttentionType +from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec +from vllm.v1.utils import bind_kv_cache +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.logger import logger +from vllm.distributed.parallel_state import get_pp_group +from vllm.utils import cdiv +from vllm.logger import init_logger +from vllm.v1.worker.gpu_input_batch import CachedRequestState +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.sampling_params import SamplingType +################### + + +logger = init_logger(__name__) + + +def wrapper_runner_init(func): + def wrapper(*args, **kwargs): + func(*args, **kwargs) + self = args[0] + self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=True) + self.query_start_loc_np = self.query_start_loc_cpu.numpy() + + import weakref + self.attn_metadata_builder = 
self.attn_backend.get_builder_cls()( + weakref.proxy(self)) + + return wrapper + + +@torch.inference_mode() +def _dummy_run( + self, + num_tokens: int = None, +) -> torch.Tensor: + if num_tokens is None: + num_tokens = self.max_num_tokens + model = self.model + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + + if self.uses_mrope: + positions = self.mrope_positions[:, :num_tokens] + else: + positions_np = self.positions_np[:num_tokens] + positions = torch.from_numpy(positions_np) + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if self.intermediate_tensors is None: + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + intermediate_tensors = IntermediateTensors({ + k: v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + + with set_forward_context(None, self.vllm_config): + hidden_states = model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + +def _process_reqs( + self, + scheduler_output: "SchedulerOutput", + intermediate_tensors: Optional[IntermediateTensors] = None, +) -> torch.Tensor: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + + modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + if modified_batch: + self.input_batch.refresh_sampling_metadata() + + # OPTIMIZATION: Start copying the block table first. + # This way, we can overlap the copy with the following CPU operations. + self.input_batch.block_table.commit(num_reqs) + + # Get the number of scheduled tokens for each request. + # TODO: The Python loop can be slow. Optimize. + num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32) + max_num_scheduled_tokens = 0 + for i, req_id in enumerate(self.input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens[i] = num_tokens + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + req_indices = np.repeat(self.arange_np[:num_reqs], + num_scheduled_tokens) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # Equivalent to but faster than: + # np.concatenate([np.arange(n) for n in num_scheduled_tokens]) + # Step 1. [2, 5, 3] -> [2, 7, 10] + cu_num_tokens = np.cumsum(num_scheduled_tokens) + # Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7] + cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens, + num_scheduled_tokens) + # Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets + + # Get positions. 
+ positions_np = self.positions_np[:total_num_scheduled_tokens] + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + if self.uses_mrope: + self._calc_mrope_positions(scheduler_output) + + if self.uses_mrope: + # Only relevant for models using M-RoPE (e.g, Qwen2-VL) + self.mrope_positions[:, :total_num_scheduled_tokens].copy_( + self.mrope_positions_cpu[:, :total_num_scheduled_tokens], + non_blocking=True) + positions = self.mrope_positions[:, :total_num_scheduled_tokens] + else: + self.positions[:total_num_scheduled_tokens] = torch.from_numpy(positions_np) + positions = self.positions[:total_num_scheduled_tokens] + + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = (positions_np + + req_indices * self.input_batch.token_ids_cpu.shape[1]) + + self.input_ids[:total_num_scheduled_tokens] = torch.from_numpy( + np.take(self.input_batch.token_ids_cpu.ravel(), + token_indices, + 0) + ) + + # Calculate the slot mapping. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` here + # because M (max_model_len) is not necessarily divisible by block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions_np // self.block_size) + + + block_numbers = self.input_batch.block_table.block_table_np.ravel()[block_table_indices] + block_offsets = positions_np % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:total_num_scheduled_tokens]) + + # # Prepare the attention metadata. + self.query_start_loc_np[0] = 0 + self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens + + self.seq_lens_np[:num_reqs] = ( + self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + common_prefix_len=0, + ) + + input_ids = self.input_ids[:total_num_scheduled_tokens] + attn_metadata.num_input_tokens = total_num_scheduled_tokens + + # Run forward pass + with set_forward_context(attn_metadata, self.vllm_config): + assert self.model is not None + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=None, + ) + + return hidden_states[cu_num_tokens - 1] diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..2dc69fcdce7847da57d7ce4ed1af141d061087b1 100644 --- a/vllm_mindspore/worker/worker.py +++ b/vllm_mindspore/worker/worker.py @@ -26,7 +26,6 @@ import torch from vllm.config import VllmConfig from vllm.distributed import ( - ensure_kv_transfer_initialized, ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce, diff --git a/vllm_mindspore/worker/worker_v1.py b/vllm_mindspore/worker/worker_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad5f0c97333c8796e2a2bf183273cf8db4917af --- /dev/null +++ b/vllm_mindspore/worker/worker_v1.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# encoding: utf-8 +# Copyright 2025 Huawei Technologies Co., Ltd +# Copyright 2024 The vLLM team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import gc +from typing import Dict, List + +import torch + +from vllm.logger import init_logger +from vllm.v1.utils import bind_kv_cache +from vllm.v1.kv_cache_interface import FullAttentionSpec + +from vllm_ascend.platform import NPUPlatform + +logger = init_logger(__name__) + +@torch.inference_mode() +def determine_available_memory(self) -> int: + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + NPUPlatform.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + peak_memory = self.init_npu_memory - free_npu_memory + assert peak_memory > 0, ( + "Error in memory profiling. " + f"Initial free memory {self.init_npu_memory}, current free memory" + f" {free_npu_memory}. This happens when the NPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + gc.collect() + # TODO: don`t need impl this func after empty_cache in + # Worker.determine_num_available_blocks() unified` + NPUPlatform.empty_cache() + usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory + npu_kv_cache_bytes = max(usable_memory_size, 0) + logger.info( + f"Available memory: {usable_memory_size}, total memory: {total_npu_memory}" + ) + return int(npu_kv_cache_bytes)
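
Note on the patching pattern (an illustrative sketch, not part of the diff): this change leans on three Python techniques — rebinding attributes on already-imported modules so every caller that resolves the symbol through the module picks up the replacement, pre-seeding sys.modules with stub modules so optional native dependencies (e.g. the torch_npu.op_plugin.atb stubs above) resolve without being installed, and gating a patch set on environment/version via importlib.util.find_spec and packaging.version, as vllm_version_is() does. The sketch below is self-contained and uses stand-in names only (json.dumps, fake_backend_ops, register_extensions, patched_dumps, version_is are chosen for the sketch and are not names from this patch).

    import importlib.util
    import json
    import sys
    import types
    import warnings

    from packaging.version import Version

    # --- Pattern 1: rebind a module attribute in place ----------------------
    # json.dumps stands in for a vllm.* symbol; callers that look the attribute
    # up through the module see the replacement without changing their code.
    _original_dumps = json.dumps

    def patched_dumps(obj, **kwargs):
        kwargs.setdefault("sort_keys", True)  # replacement keeps the original signature
        return _original_dumps(obj, **kwargs)

    json.dumps = patched_dumps

    # --- Pattern 2: stub an optional dependency via sys.modules -------------
    # fake_backend_ops stands in for modules that may be absent at runtime;
    # seeding sys.modules first makes a later "import fake_backend_ops" succeed
    # and the no-op hook keeps downstream registration calls harmless.
    stub = types.ModuleType("fake_backend_ops")
    stub.register_extensions = lambda *args, **kwargs: None
    sys.modules["fake_backend_ops"] = stub

    # --- Pattern 3: gate a patch set on environment and version -------------
    def version_is(installed: str, wanted: str) -> bool:
        # Same intent as vllm_version_is() in vllm_mindspore/utils.py: apply a
        # patch set only against the exact upstream release it was written for.
        return Version(installed) == Version(wanted)

    if importlib.util.find_spec("vllm_ascend") is None:
        warnings.warn("vllm_ascend is not installed; only the plain vLLM patches apply")

After the sys.modules assignment, "import fake_backend_ops" returns the stub and fake_backend_ops.register_extensions() is a no-op, which is the same effect the patch relies on for the torch_npu and torchair placeholders.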