diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py
index ff9d0745d9741037276f304a0db385e37fa96a80..91c4042f77ab4d2b4f2fa7176f335d23463c9164 100644
--- a/vllm_mindspore/__init__.py
+++ b/vllm_mindspore/__init__.py
@@ -325,6 +325,10 @@ from vllm_mindspore.engine.multiprocessing.engine import cleanup
 import vllm.engine.multiprocessing.engine
 vllm.engine.multiprocessing.engine.MQLLMEngine.cleanup = cleanup
 
+from vllm_mindspore.adaptive_chunk_pp.llm_engine import _process_model_outputs
+import vllm.engine.llm_engine
+vllm.engine.llm_engine.LLMEngine._process_model_outputs = _process_model_outputs
+
 from vllm_mindspore.adaptive_chunk_pp.scheduler import apply_scheduler_patch
 from vllm_mindspore.adaptive_chunk_pp.sequence import apply_sequence_patch
 apply_scheduler_patch()
diff --git a/vllm_mindspore/adaptive_chunk_pp/llm_engine.py b/vllm_mindspore/adaptive_chunk_pp/llm_engine.py
new file mode 100644
index 0000000000000000000000000000000000000000..0de02268ab1ff8fb9c63853136e44318be433c12
--- /dev/null
+++ b/vllm_mindspore/adaptive_chunk_pp/llm_engine.py
@@ -0,0 +1,245 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import time
+from typing import Optional, List
+
+from vllm.engine.llm_engine import SchedulerContext
+from vllm.engine.output_processor.util import create_output_by_sequence_group
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.outputs import RequestOutputFactory
+from vllm.sampling_params import RequestOutputKind
+from vllm.sequence import (SequenceGroup, SequenceGroupOutput)
+
+
+def _process_model_outputs(self,
+                           ctx: SchedulerContext,
+                           request_id: Optional[str] = None) -> None:
+    """Apply the model output to the sequences in the scheduled seq groups
+    and return responses.
+
+    ctx: The virtual engine context to work on
+    request_id: If provided, then only this request is going to be processed
+    """
+
+    now = time.time()
+
+    if len(ctx.output_queue) == 0:
+        return None
+
+    # Get pending async postprocessor
+    if request_id:
+        # When we process only one request, no pop is required
+        # (since later we will process all of the rest)
+        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+         is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
+    else:
+        (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
+         is_last_step, is_first_step_output,
+         skip) = ctx.output_queue.popleft()
+
+    # Sanity check
+    assert len(seq_group_metadata_list) == len(
+        scheduler_outputs.scheduled_seq_groups)
+
+    has_multiple_outputs: bool = len(outputs) > 1
+    outputs_by_sequence_group: List[List[SequenceGroupOutput]]
+    if has_multiple_outputs:
+        assert self.scheduler_config.is_multi_step or \
+            self.speculative_config
+        # Organize outputs by [step][sequence group] instead of
+        # [sequence group][step].
+        if self.scheduler_config.is_multi_step:
+            outputs_by_sequence_group = create_output_by_sequence_group(
+                outputs, len(seq_group_metadata_list))
+        elif self.speculative_config:
+            # Decodes are multi-steps while prefills are not, outputting at
+            # most 1 token. Separate them so that we can trigger chunk
+            # processing without having to pad or copy over prompts K times
+            # to match decodes structure (costly with prompt_logprobs).
+            num_prefills = sum(sg.is_prompt
+                               for sg in seq_group_metadata_list)
+            prefills, decodes = outputs[:num_prefills], outputs[
+                num_prefills:]
+            outputs_by_sequence_group = create_output_by_sequence_group(
+                decodes,
+                num_seq_groups=len(seq_group_metadata_list) - num_prefills)
+            outputs_by_sequence_group = [p.outputs for p in prefills
+                                         ] + outputs_by_sequence_group
+            # We have outputs for multiple steps submitted in a single burst,
+            # so invalidate is_first_step_output.
+            is_first_step_output = None
+    else:
+        outputs_by_sequence_group = outputs
+
+    # Determine the requests we need to operate on
+    if request_id:
+        indices = []
+        for i, seq_group_meta in enumerate(seq_group_metadata_list):
+            if seq_group_meta.request_id == request_id:
+                assert i not in skip  # Cannot be called twice
+                indices.append(i)
+                break
+
+        # If the request_id was not found, then it means that
+        # this is a new request that has no pending async
+        # postprocessor
+        if not indices:
+            return
+    else:
+        indices = range(len(outputs[0])) if len(outputs) else range(0)  # type: ignore
+
+    finished_before: List[int] = []
+    finished_now: List[int] = []
+    for i in indices:
+        if i in skip:
+            continue
+
+        seq_group_meta = seq_group_metadata_list[i]
+        scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
+
+        seq_group: SequenceGroup = scheduled_seq_group.seq_group
+
+        if seq_group.is_finished():
+            finished_before.append(i)
+            continue
+
+        output: List[SequenceGroupOutput]
+        if has_multiple_outputs:
+            output = outputs_by_sequence_group[i]
+        else:
+            output = [outputs_by_sequence_group[0][i]]
+
+        if not is_async:
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_meta, is_first_step_output)
+            else:
+                seq_group.update_num_computed_tokens(
+                    seq_group_meta.token_chunk_size or 0)
+
+        if outputs:
+            for o in outputs:
+                if (isinstance(o, SamplerOutput)
+                        and seq_group.metrics is not None):
+                    if seq_group.metrics.model_forward_time is not None:
+                        seq_group.metrics.model_forward_time += (
+                            o.model_forward_time or 0)
+                    else:
+                        seq_group.metrics.model_forward_time = (
+                            o.model_forward_time)
+                    if seq_group.metrics.model_execute_time is not None:
+                        seq_group.metrics.model_execute_time += (
+                            o.model_execute_time or 0)
+                    else:
+                        seq_group.metrics.model_execute_time = (
+                            o.model_execute_time)
+
+        if self.model_config.runner_type == "pooling":
+            self._process_sequence_group_outputs(seq_group, output)
+        else:
+            self.output_processor.process_prompt_logprob(seq_group, output)
+            if seq_group_meta.do_sample:
+                self.output_processor.process_outputs(
+                    seq_group, output, is_async)
+
+        if seq_group.is_finished():
+            finished_now.append(i)
+
+    # Generate outputs for the requests that finished this iteration
+    for i in finished_now:
+        scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
+
+        seq_group = scheduled_seq_group.seq_group
+        seq_group.maybe_set_first_token_time(now)
+        if not seq_group.is_prefill():
+            seq_group.set_last_token_time(now)
+        request_output = RequestOutputFactory.create(
+            seq_group,
+            self.seq_id_to_seq_group,
+            use_cache=self.use_cached_outputs)
+        if request_output:
+            ctx.request_outputs.append(request_output)
+
+    # When we process a single request, we skip it for the next time,
+    # and invoke the request output callback (if there was final output)
+    if request_id:
+        assert len(indices) == 1
+        skip.append(indices[0])
+
+        if (finished_now
+                and self.process_request_outputs_callback is not None):
+            self.process_request_outputs_callback(ctx.request_outputs)
+            ctx.request_outputs.clear()
+        return
+
+    # Free currently finished requests
+    if finished_now:
+        for scheduler in self.scheduler:
+            scheduler.free_finished_seq_groups()
+
+    # For multi-step without streaming, don't create outputs each iteration
+    if not is_last_step and not ctx.multi_step_stream_outputs:
+        # Immediately process request outputs here (if callback is given)
+        if (finished_now
+                and self.process_request_outputs_callback is not None):
+            self.process_request_outputs_callback(ctx.request_outputs)
+            ctx.request_outputs.clear()
+        return
+
+    # Create the outputs
+    for i in indices:
+        if i in skip or i in finished_before or i in finished_now:
+            continue  # Avoids double processing
+
+        scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
+
+        seq_group = scheduled_seq_group.seq_group
+        seq_group.maybe_set_first_token_time(now)
+        if not seq_group.is_prefill():
+            seq_group.set_last_token_time(now)
+        request_output = RequestOutputFactory.create(
+            seq_group,
+            self.seq_id_to_seq_group,
+            use_cache=self.use_cached_outputs)
+        if request_output:
+            ctx.request_outputs.append(request_output)
+
+    # For multi-step with streaming, create outputs each iteration
+    if not is_last_step and ctx.multi_step_stream_outputs:
+        # Immediately process request outputs here (if callback is given)
+        if self.process_request_outputs_callback is not None:
+            self.process_request_outputs_callback(ctx.request_outputs)
+            ctx.request_outputs.clear()
+        return
+
+    for seq_group in scheduler_outputs.ignored_seq_groups:
+        params = seq_group.sampling_params
+        if params is not None and params.output_kind == (
+                RequestOutputKind.DELTA) and not seq_group.is_finished():
+            continue
+
+        request_output = RequestOutputFactory.create(
+            seq_group,
+            self.seq_id_to_seq_group,
+            use_cache=self.use_cached_outputs,
+        )
+        if request_output:
+            ctx.request_outputs.append(request_output)
+
+    # Immediately process request outputs here (if callback is given)
+    if (ctx.request_outputs
+            and self.process_request_outputs_callback is not None):
+        self.process_request_outputs_callback(ctx.request_outputs)
+        ctx.request_outputs.clear()
+
+    # For async case, we need to record the stats here.
+    # For non-async case, the stats are done in the
+    # LLMEngine/AsyncLLMEngine directly
+    if is_async:
+        # Log stats.
+        self.do_log_stats(scheduler_outputs, outputs, finished_before,
+                          skip)
+
+        # Tracing
+        self.do_tracing(scheduler_outputs, finished_before)
+
+    return None
diff --git a/vllm_mindspore/adaptive_chunk_pp/scheduler.py b/vllm_mindspore/adaptive_chunk_pp/scheduler.py
index 9975b7a801961e7490797c6527d26b78db87fdf4..33a84852b106acd139dd85d0815bd1f2eaf04599 100644
--- a/vllm_mindspore/adaptive_chunk_pp/scheduler.py
+++ b/vllm_mindspore/adaptive_chunk_pp/scheduler.py
@@ -163,9 +163,9 @@ def add_num_batched_tokens_sub_rmd_prefill(self,
     # Only 1 long sequence allowed per batch
     # These can be added without breaking batch compute balance:
     # 1. Decode-phase requests
-    # 2. Beam search (width<=64)
-    # 3. Short prefills (len<=64)
-    if num_batched_tokens > 64:
+    # 2. Beam search
+    # 3. Short prefills (len<=1024)
+    if num_batched_tokens > 1024:
         # After scheduling the current chunk, decrement: reminding_num_prefill_seqs -= 1
         self.reminding_num_prefill_seqs -= 1
     # vllm-mindspore end.
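
Note: the `__init__.py` hunk above installs this copy of `_process_model_outputs` onto `vllm.engine.llm_engine.LLMEngine` at import time, alongside the existing scheduler and sequence patches. A minimal usage sketch follows, assuming the usual vllm-mindspore convention of importing the plugin before vLLM constructs the engine; the model id and prompt are placeholders and are not part of this change:

```python
# Sketch only: importing vllm_mindspore runs the monkey patches from
# __init__.py, including the LLMEngine._process_model_outputs override above.
import vllm_mindspore  # noqa: F401

from vllm import LLM, SamplingParams

llm = LLM(model="some/model")  # placeholder model id
params = SamplingParams(max_tokens=16)
for out in llm.generate(["hello"], params):
    print(out.outputs[0].text)
```
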
diff --git a/vllm_mindspore/adaptive_chunk_pp/sequence.py b/vllm_mindspore/adaptive_chunk_pp/sequence.py
index 18ec9e232f4bc053cfb543c0ccfe8875a8f36676..8a230491dc75b66727897a3af99d20f31845cbe0 100644
--- a/vllm_mindspore/adaptive_chunk_pp/sequence.py
+++ b/vllm_mindspore/adaptive_chunk_pp/sequence.py
@@ -95,7 +95,6 @@ def patched_init(
     self.chunk_sizes = get_optimize_chunks(len(self.first_seq.prompt_token_ids), self.params)
     self.chunk_index = 0
     logger.info(f'Prompt length:{len(self.first_seq.prompt_token_ids)}')
-    logger.info(f'Optimized chunk sizes:{self.chunk_sizes}')
     # vllm-mindspore end.
 
 def calculate_layer_time(params, q, pre_kv_len):
@@ -143,6 +142,7 @@ def optimize_chunks(params, seq_len, chunk_num):
         new_chunks[min_idx] += delta
         if sum(new_chunks) == seq_len and all(q > 0 for q in new_chunks):
             chunks = new_chunks
+    logger.info(f'Optimized chunk sizes:{chunks}')
     return chunks
 
 def solve_for_q( params, target_time, pre_kv_len):
@@ -162,6 +162,9 @@ def get_optimize_chunks(length, params):
     step_size = int(os.environ.get('CHUNK_STEP_SIZE', '2048'))
     # get chunk num from chunk step
     chunk_num = (length-1) // step_size + 1
+    prefill_use_pa = bool(os.environ.get('PREFILL_BACKEND_PA', False))
+    if prefill_use_pa:
+        return [1, length-1]
     return [length] if chunk_num == 1 else optimize_chunks(params, length, chunk_num)
 
 def get_next_chunk_size(self) -> Optional[int]:
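
For reference, a standalone sketch of the environment-variable control flow that the last hunk adds to `get_optimize_chunks`. Only the branching is reproduced; the real `optimize_chunks` cost model defined elsewhere in `sequence.py` is replaced here by a naive even split, so the final branch is illustrative, not the project's algorithm:

```python
import os


def sketch_get_optimize_chunks(length: int) -> list:
    """Illustrative re-implementation of the env-var handling above."""
    step_size = int(os.environ.get('CHUNK_STEP_SIZE', '2048'))
    chunk_num = (length - 1) // step_size + 1
    # bool() on the raw string means any non-empty PREFILL_BACKEND_PA value
    # (including '0') takes this path: a 1-token chunk plus the remainder.
    if bool(os.environ.get('PREFILL_BACKEND_PA', False)):
        return [1, length - 1]
    if chunk_num == 1:
        return [length]
    # Stand-in for optimize_chunks(params, length, chunk_num): split the
    # prompt into chunk_num pieces that sum to length.
    base, rem = divmod(length, chunk_num)
    return [base + (1 if i < rem else 0) for i in range(chunk_num)]


# Example: with the default CHUNK_STEP_SIZE of 2048, a 5000-token prompt
# yields (5000 - 1) // 2048 + 1 = 3 chunks.
print(sketch_get_optimize_chunks(5000))
```
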