diff --git a/vllm_mindspore/worker/worker.py b/vllm_mindspore/worker/worker.py
index 8ce1bc91d511a43a83fd3c8b0e70d228b98b951b..57d0cef74d9001c7bbcbab6b2a77803fe915a25e 100644
--- a/vllm_mindspore/worker/worker.py
+++ b/vllm_mindspore/worker/worker.py
@@ -48,11 +48,16 @@ def _prepare_input_for_warmup(model_config, model_runner, cache_engine, is_prefi
     seq_len = model_runner.scheduler_config.max_num_batched_tokens if is_prefill else 1
     dummy_data = model_runner.input_registry.dummy_data_for_profiling(model_config, seq_len, model_runner.mm_registry)
     block_tables = [i for i in range(math.ceil(seq_len / cache_engine.block_size))]
+
+    seq_data = dummy_data.seq_data
+    if seq_len == 1:
+        seq_data = dummy_data.seq_data.from_prompt_token_counts((0, seq_len))
+
     seqs = [
         SequenceGroupMetadata(
             request_id=str(idx),
             is_prompt=is_prefill,
-            seq_data={idx: dummy_data.seq_data},
+            seq_data={idx: seq_data},
             sampling_params=SamplingParams(),
             block_tables={idx: block_tables},
             lora_request=None,
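
For reference, the added branch only affects the decode warmup path (`is_prefill` is False, so `seq_len == 1`): instead of reusing the profiling dummy `seq_data`, it builds a fresh one sized to a single token. Below is a minimal sketch of what that construction yields, assuming vLLM's `SequenceData.from_prompt_token_counts` classmethod; the assertions are illustrative and not part of the patch.

```python
# Sketch only: illustrates the intent of the new seq_len == 1 branch,
# assuming vLLM's SequenceData.from_prompt_token_counts classmethod.
from vllm.sequence import SequenceData

seq_len = 1  # decode warmup: one token per sequence

# Build a dummy sequence holding exactly `seq_len` placeholder tokens
# (token id 0) rather than reusing the profiling dummy data, which may
# carry a longer prompt than the single decode token needs.
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))

assert seq_data.get_len() == seq_len
assert seq_data.get_output_len() == 0
```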