diff --git a/codecheck_toolkits/vllm_codecheck.sh b/codecheck_toolkits/vllm_codecheck.sh index 150b62a55b51d4ecd1bc6163b65772664383b05e..4787260928bf44433c94dfc0771203bb848e5584 100644 --- a/codecheck_toolkits/vllm_codecheck.sh +++ b/codecheck_toolkits/vllm_codecheck.sh @@ -24,7 +24,12 @@ git add .pre-commit-config.yaml pip install -r codecheck_toolkits/requirements-lint.txt -pre-commit run --from-ref origin/develop --to-ref HEAD +RUN_GATE=${1:-0} +if [ "$RUN_GATE" -eq 0 ]; then + pre-commit run --from-ref origin/develop --to-ref HEAD +else + pre-commit run --all-files +fi RET_FLAG=$? exit ${RET_FLAG} diff --git a/docs/model_cards/qwen3/qwen3_0_6b.md b/docs/model_cards/qwen3/qwen3_0_6b.md index 7bec5567b4ea40e30e0d745b3587ce3d495add45..86b5a9d0c46cd4db24bfaabab323ba7db3b079e0 100644 --- a/docs/model_cards/qwen3/qwen3_0_6b.md +++ b/docs/model_cards/qwen3/qwen3_0_6b.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -131,4 +129,4 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |-------------------| --- |----------| -| Qwen3-0.6B | bf16 | 32.99 | \ No newline at end of file +| Qwen3-0.6B | bf16 | 32.99 | diff --git a/docs/model_cards/qwen3/qwen3_0_6b_base.md b/docs/model_cards/qwen3/qwen3_0_6b_base.md index e78cc9ccf6b97a3fe7b7b807bef899ea362207f3..c48b6a278b6d847dd26f79c840e3e6a42c97100a 100644 --- a/docs/model_cards/qwen3/qwen3_0_6b_base.md +++ b/docs/model_cards/qwen3/qwen3_0_6b_base.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -79,6 +76,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -133,10 +131,3 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |-------------------| --- |----------| | Qwen3-0.6B-Base | bf16 | 35.13 | - - - - - - - diff --git a/docs/model_cards/qwen3/qwen3_1_7b.md b/docs/model_cards/qwen3/qwen3_1_7b.md index 305dbe878000fa9fbfc43decbc12adb85480f3f4..fdce759dcc44423aa6b9cb2dfe2136186b72f138 100644 --- a/docs/model_cards/qwen3/qwen3_1_7b.md +++ b/docs/model_cards/qwen3/qwen3_1_7b.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -131,4 +129,4 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |-------------------| --- |----------| -| Qwen3-1.7B | bf16 | 33.61 | \ No newline at end of file +| Qwen3-1.7B | bf16 | 33.61 | diff --git a/docs/model_cards/qwen3/qwen3_1_7b_base.md b/docs/model_cards/qwen3/qwen3_1_7b_base.md index 2a228eb93cda7faeba87f8b608ccc059e48dbcce..7c4e4936c0b324c71715672dcfe254b332b5dbb3 100644 --- a/docs/model_cards/qwen3/qwen3_1_7b_base.md 
+++ b/docs/model_cards/qwen3/qwen3_1_7b_base.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -132,12 +130,3 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |-------------------| --- |----------| | Qwen3-1.7B-Base | bf16 | 24.23 | - - - - - - - - - diff --git a/docs/model_cards/qwen3/qwen3_32b.md b/docs/model_cards/qwen3/qwen3_32b.md index 9519ebd50ace5a76abff253c11ab07a36c6a1f28..62baf8275cbc2573120308e12637f57fb4b1aa75 100644 --- a/docs/model_cards/qwen3/qwen3_32b.md +++ b/docs/model_cards/qwen3/qwen3_32b.md @@ -7,14 +7,12 @@ ## 模型介绍 - ### 下载链接 | 社区 | 下载地址 | |:----:|:----------------------------------------------------------| | 魔乐社区 | https://modelers.cn/models/MindSpore-Lab/Qwen3-32B | - ## 快速开始 Qwen3-32B推理至少需要1台(2卡)Atlas 800T A2(64G)服务器服务器(基于BF16权重)。昇思MindSpore提供了Qwen3-32B推理可用的Docker容器镜像,供开发者快速体验。 @@ -149,4 +147,4 @@ curl -w "\ntime_total=%{time_total}\n" -H "Accept: application/json" -H "Content ## 声明 -本文档提供的模型代码、权重文件和部署镜像,当前仅限于基于昇思MindSpore AI框架体验Qwen3-32B的部署效果,不支持生产环境部署。相关使用问题请反馈至[Issue](https://gitee.com/mindspore/mindformers/issues/new)。 \ No newline at end of file +本文档提供的模型代码、权重文件和部署镜像,当前仅限于基于昇思MindSpore AI框架体验Qwen3-32B的部署效果,不支持生产环境部署。相关使用问题请反馈至[Issue](https://gitee.com/mindspore/mindformers/issues/new)。 diff --git a/docs/model_cards/qwen3/qwen3_4b.md b/docs/model_cards/qwen3/qwen3_4b.md index 45000b3f355c5c888644b97a9ffc720f93b89147..63bb4852ace23e4eeeb5bfbab636d939923756e2 100644 --- a/docs/model_cards/qwen3/qwen3_4b.md +++ b/docs/model_cards/qwen3/qwen3_4b.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -132,11 +130,3 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |-------------------| --- |----------| | Qwen3-4B | bf16 | 27.86 | - - - - - - - - diff --git a/docs/model_cards/qwen3/qwen3_4b_base.md b/docs/model_cards/qwen3/qwen3_4b_base.md index 7a46684f784d9f7615e8798df04bd29eef4d8c9c..ce6385a7efdb1cece5d1207f4f6437b21c3bd0a9 100644 --- a/docs/model_cards/qwen3/qwen3_4b_base.md +++ b/docs/model_cards/qwen3/qwen3_4b_base.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -127,9 +125,8 @@ if __name__ == "__main__": main(args) ``` - ### 性能如下: | model_name |precision | tokens/s | |-------------------| --- |----------| -| Qwen3-4B-Base | bf16 | 26.48 | \ No newline at end of file +| Qwen3-4B-Base | bf16 | 26.48 | diff --git a/docs/model_cards/qwen3/qwen3_8b.md b/docs/model_cards/qwen3/qwen3_8b.md index 
93ec883cdb492c20d7c0b3e8973d6a8e580ad60d..b0eeee372c9dac944cec516d0730816435fabf73 100644 --- a/docs/model_cards/qwen3/qwen3_8b.md +++ b/docs/model_cards/qwen3/qwen3_8b.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -131,4 +129,4 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |-------------------| --- |----------| -| Qwen3-8B | bf16 | 26.08 | \ No newline at end of file +| Qwen3-8B | bf16 | 26.08 | diff --git a/docs/model_cards/qwen3/qwen3_8b_base.md b/docs/model_cards/qwen3/qwen3_8b_base.md index 50d5386b680b9944e80cecd1df6969eae0d04ecd..3524fbdc064c9789cca0b75ad5ad1849aee76efe 100644 --- a/docs/model_cards/qwen3/qwen3_8b_base.md +++ b/docs/model_cards/qwen3/qwen3_8b_base.md @@ -5,17 +5,14 @@ ## 目录 - - [模型介绍](#模型介绍) - [快速开始](#快速开始) - [声明](#声明) - ## 模型介绍 Qwen 大模型系列的新一代版本 —— Qwen3,在自然语言处理和多模态能力方面实现了全新突破。继承前代模型的成功经验,Qwen3 系列采用了更大规模的数据集、改进的模型架构以及更优的微调技术,使其能够应对更加复杂的推理、语言理解和生成任务。这一代模型还扩展了支持的最大 token 数量,能够生成更长、更连贯的回答,并更好地处理复杂的对话流程。 - ## 快速开始 当前支持的硬件为Atlas 800T A2服务器 @@ -78,6 +75,7 @@ docker run -itd --privileged --name=qwen3 --net=host \ ``` 进入容器,后续所有操作均在容器内操作 + ``` docker exec -it qwen3 bash ``` @@ -132,12 +130,3 @@ if __name__ == "__main__": | model_name |precision | tokens/s | |:--- |:--- |:--- | |Qwen3-8B-Base| bf16 | 24.27. | - - - - - - - - - diff --git a/vllm_mindspore/lora/utils.py b/vllm_mindspore/lora/utils.py index fc10df8cf4928fa43a9b7d2c28dadafd9b994ca1..0a96b5559c0400de2d5a638695207dc77497ecbd 100644 --- a/vllm_mindspore/lora/utils.py +++ b/vllm_mindspore/lora/utils.py @@ -15,22 +15,27 @@ # limitations under the License. 
"""Unified interface for LoRA layers in vllm-mindspore.""" -from typing import Set, Type - from vllm.lora.fully_sharded_layers import ( ColumnParallelLinearWithShardedLoRA, MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) -from vllm_mindspore.lora.layers import ( - BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, - MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA, - QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA, - VocabParallelEmbeddingWithLoRA) +# yapf conflicts with isort for this block +# yapf: disable +from vllm_mindspore.lora.layers import (BaseLayerWithLoRA, + ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLoRA, + QKVParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) + +# yapf: enable -_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { +_all_lora_classes: set[type[BaseLayerWithLoRA]] = { VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, diff --git a/vllm_mindspore/model_executor/layers/sampler.py b/vllm_mindspore/model_executor/layers/sampler.py index 46aae6cbb443841927bc7a0a2381061ca7da8b30..d4f7830ec6d4cc7cf47c3bf2ec7c84389d75f776 100644 --- a/vllm_mindspore/model_executor/layers/sampler.py +++ b/vllm_mindspore/model_executor/layers/sampler.py @@ -20,37 +20,34 @@ """A layer that samples the next tokens from the model's outputs.""" import itertools import warnings -import mindspore as ms -import numpy as np +from collections.abc import Iterator from dataclasses import dataclass from importlib.util import find_spec from math import inf -from typing import Dict, Iterator, List, Optional, Tuple, Union +from typing import Optional, Union +import mindspore as ms import msgspec import torch import torch.nn as nn - import vllm.envs as envs +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingTensors, + SequenceGroupToSample) from vllm.sampling_params import SamplingType from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + from vllm_mindspore.model_executor.layers.utils import apply_penalties -from vllm.model_executor.sampling_metadata import ( - SamplingMetadata, - SamplingTensors -) -from vllm.model_executor.sampling_metadata import SequenceGroupToSample if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): - raise RuntimeError("Donot support for mindspore now.") + raise RuntimeError("Do not support for mindspore now.") else: flashinfer_top_k_top_p_sampling = None - def get_sampler() -> torch.nn.Module: if envs.VLLM_USE_V1: # Lazy import: the v1 package isn't distributed @@ -60,14 +57,14 @@ def get_sampler() -> torch.nn.Module: # (num_token_ids, num_parent_ids) per sequence group. 
-SampleResultType = List[Tuple[List[int], List[int]]] +SampleResultType = list[tuple[list[int], list[int]]] # Types of temporary data structures used for # computing sample_result -SampleMetadataType = Dict[SamplingType, Tuple[List[int], - List[SequenceGroupToSample]]] -MultinomialSamplesType = Dict[SamplingType, torch.Tensor] -SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] +SampleMetadataType = dict[SamplingType, tuple[list[int], + list[SequenceGroupToSample]]] +MultinomialSamplesType = dict[SamplingType, torch.Tensor] +SampleResultsDictType = dict[int, tuple[list[int], list[int]]] # Encapsulates temporary data structures for computing @@ -94,7 +91,7 @@ class SampleResultArgsType: MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] # Abbreviation of the _sample() return type -SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] +SampleReturnType = tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] class SamplerOutput( @@ -108,7 +105,7 @@ class SamplerOutput( also has optional fields for device tensors. """ - outputs: List[CompletionSequenceGroupOutput] + outputs: list[CompletionSequenceGroupOutput] # On-device tensor containing probabilities of each token. sampled_token_probs: Optional[torch.Tensor] = None @@ -370,7 +367,7 @@ def _apply_min_tokens_penalty( have not been generated yet """ # list of indices in logits that will be set to -inf - logits_to_penalize: List[Tuple[int, int]] = [] + logits_to_penalize: list[tuple[int, int]] = [] logits_applied = 0 for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -386,7 +383,7 @@ def _apply_min_tokens_penalty( min_tokens = sampling_params.min_tokens token_ids_to_penalize = sampling_params.all_stop_token_ids if min_tokens > 0 and token_ids_to_penalize: - seqs_to_penalize: List[int] = [] + seqs_to_penalize: list[int] = [] for j, seq_id in enumerate(seq_ids): seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids_array) < min_tokens: @@ -457,7 +454,7 @@ def _apply_min_p( def _greedy_sample( - selected_seq_groups: List[SequenceGroupToSample], + selected_seq_groups: list[SequenceGroupToSample], samples: torch.Tensor, ) -> SampleResultType: """Run greedy sampling on a given samples. @@ -492,7 +489,7 @@ def _greedy_sample( def _random_sample( - selected_seq_groups: List[SequenceGroupToSample], + selected_seq_groups: list[SequenceGroupToSample], random_samples: torch.Tensor, ) -> SampleResultType: """Run random sampling on a given samples. @@ -536,7 +533,7 @@ def _random_sample( def _beam_search_sample( - selected_seq_groups: List[SequenceGroupToSample], + selected_seq_groups: list[SequenceGroupToSample], logprobs: torch.Tensor, ) -> SampleResultType: """Run beam sampling on a given samples. @@ -581,7 +578,7 @@ def _beam_search_sample( next_token_ids = next_token_ids.tolist() else: # Generation phase. 
- cumulative_logprobs: List[float] = [ + cumulative_logprobs: list[float] = [ seq_group.seq_data[seq_id].cumulative_logprob for seq_id in seq_ids ] @@ -611,7 +608,7 @@ def _beam_search_sample( def _multinomial( probs: torch.Tensor, num_samples: int, - seq_groups: Optional[List[SequenceGroupToSample]] = None, + seq_groups: Optional[list[SequenceGroupToSample]] = None, ) -> torch.Tensor: if num_samples > 1: probs = probs.repeat_interleave(num_samples, dim=0) @@ -624,7 +621,7 @@ def _multinomial( seq_ids = seq_group.seq_ids stride = len(seq_ids) * num_samples assert seq_group.generator is not None - q[sample_idx : sample_idx + + q[sample_idx:sample_idx + stride].exponential_(generator=seq_group.generator) sample_idx += stride return probs.div_(q).argmax(dim=1).view(-1, num_samples) @@ -632,7 +629,7 @@ def _multinomial( def _top_k_top_p_multinomial_with_flashinfer( probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, - num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]): + num_samples: int, seq_groups: Optional[list[SequenceGroupToSample]]): max_top_k_round = 32 if num_samples > 1: probs = probs.repeat_interleave(num_samples, dim=0) @@ -661,9 +658,11 @@ def _top_k_top_p_multinomial_with_flashinfer( if not success.all(): warnings.warn("FlashInfer rejection sampling failed, fallback.", stacklevel=1) - probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks) - probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps) - batch_next_token_ids = flashinfer.sampling.sampling_from_probs( + probs = flashinfer.sampling.top_k_renorm_prob( # noqa: F821 + probs, top_ks) + probs = flashinfer.sampling.top_p_renorm_prob( # noqa: F821 + probs, top_ps) + batch_next_token_ids = flashinfer.sampling.sampling_from_probs( # noqa: F821 probs, uniform_samples[0]) return batch_next_token_ids.view(-1, num_samples) @@ -737,9 +736,10 @@ def _sample_with_torch( tensors required for Pythonization ''' - categorized_seq_group_ids: Dict[SamplingType, - List[int]] = {t: [] - for t in SamplingType} + categorized_seq_group_ids: dict[SamplingType, list[int]] = { + t: [] + for t in SamplingType + } categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -900,7 +900,7 @@ def get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: SampleResultType, -) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: +) -> tuple[list[Optional[PromptLogprobs]], list[SampleLogprobs]]: """Return sample logprobs and prompt logprobs. The logic consists of 3 parts. @@ -929,9 +929,9 @@ def get_logprobs( """ # The index of query token to calculate logprobs. It includes both # prompt and sample logprob indices. - query_indices: List[int] = [] + query_indices: list[int] = [] # The next token ids to get the logprob value from. - next_token_ids: List[int] = [] + next_token_ids: list[int] = [] # The largest requested number of logprobs. We find logprobs as many as the # largest num logprobs in this API. If every logprobs is None, it will be # set to -1. @@ -1011,8 +1011,8 @@ def get_logprobs( ranks = ranks.to('cpu') # Find prompt/sample logprobs. 
- prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] - sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + prompt_logprobs_per_seq_group: list[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: list[SampleLogprobs] = [] top_logprob_idx = 0 selected_logprobs_idx = 0 @@ -1063,7 +1063,7 @@ def _get_prompt_logprob_if_needed( for idx, token_id in enumerate(next_prompt_tokens): # Calculate the prompt logprob of the real prompt tokens. # {token_id: (logprob, rank_from_vocab)} - prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + prompt_logprobs_dict: dict[int, tuple[float, int]] = { token_id: (selected_logprob_items[idx], rank_items[idx]) } @@ -1095,7 +1095,7 @@ def _get_prompt_logprob_if_needed( def _get_sampled_logprob_if_needed( seq_group: SequenceGroupToSample, - sample_result: Tuple[List[int], List[int]], + sample_result: tuple[list[int], list[int]], selected_logprobs: torch.Tensor, ranks: torch.Tensor, top_token_ids: torch.Tensor, @@ -1216,9 +1216,9 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( maybe_deferred_sample_results: MaybeDeferredSampleResultType, sampling_metadata: SamplingMetadata, - prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], - sample_logprobs: Optional[List[SampleLogprobs]], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + prompt_logprobs: Optional[list[Optional[PromptLogprobs]]], + sample_logprobs: Optional[list[SampleLogprobs]], + on_device_tensors: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]], skip_sampler_cpu_output: bool = False, ) -> SamplerOutput: @@ -1230,7 +1230,7 @@ def _build_sampler_output( allows post-processing without copies to CPU/serialization, e.g. in speculative decoding rejection sampling. """ - sampler_output: List[CompletionSequenceGroupOutput] = [] + sampler_output: list[CompletionSequenceGroupOutput] = [] if skip_sampler_cpu_output: assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) @@ -1248,7 +1248,7 @@ def _build_sampler_output( prompt_logprobs, sample_logprobs): seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result - seq_outputs: List[SequenceOutput] = [] + seq_outputs: list[SequenceOutput] = [] for parent_id, next_token_id, logprobs in zip( parent_ids, next_token_ids, group_sample_logprobs): seq_outputs.append( @@ -1274,7 +1274,7 @@ def _build_sampler_output( deferred_sample_results_args=deferred_sample_results_args) -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> list[int]: """Get a list of next prompt tokens to compute logprob from a given sequence group. @@ -1305,4 +1305,3 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: next_prompt_tokens = prompt_tokens[ next_token_index_start:next_token_index_end] return next_prompt_tokens - diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen2.py b/vllm_mindspore/model_executor/models/mf_models/qwen2.py index a24a0f1b98ed603b8f84f424a1b712c0a01f7a83..55992c753fca7254ef6c8ded2dc9b67f1cee4b09 100644 --- a/vllm_mindspore/model_executor/models/mf_models/qwen2.py +++ b/vllm_mindspore/model_executor/models/mf_models/qwen2.py @@ -14,58 +14,62 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Iterable, Set, Tuple - -from vllm.config import VllmConfig -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -import vllm.envs as envs - - -from mindspore import Tensor, JitConfig -from mindspore.nn.utils import no_init_parameters +from collections.abc import Iterable from mindformers.models.llama import LlamaConfig as LlamaConfig_MF +from mindspore import Tensor +from mindspore.nn.utils import no_init_parameters +# yapf conflict with isort +# yapf: disable from research.qwen2_5.infer.qwen2_5 import ( - ParallelQwenForCausalLM as ParallelQwenForCausalLM_MF, -) + ParallelQwenForCausalLM as ParallelQwenForCausalLM_MF) +# yapf: enable +from vllm.config import VllmConfig, get_current_vllm_config +from vllm.logger import init_logger from vllm_mindspore.model_executor.layers.sampler import get_sampler +from vllm_mindspore.model_executor.models.mf_models.mf_model_base import ( + MfModelBase) +from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import ( # noqa: E501 + Qwen2WeightProcessor) from vllm_mindspore.model_executor.models.model_base import AttentionWrapper -from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase -from vllm_mindspore.model_executor.models.mf_models.qwen2_weight_processor import Qwen2WeightProcessor - logger = init_logger(__name__) + class Qwen2ForCausalLM(MfModelBase): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: - super(Qwen2ForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix) + super().__init__(vllm_config=vllm_config, prefix=prefix) self.mf_kvcaches_init = False self.sampler = get_sampler() self.set_modules({"model": self.network}) - self.kv_caches = [AttentionWrapper() for i in range(self.mf_model_config.num_layers)] + self.kv_caches = [ + AttentionWrapper() for i in range(self.mf_model_config.num_layers) + ] compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") for i in range(self.mf_model_config.num_layers): - compilation_config.static_forward_context[str(i)] = self.kv_caches[i] + compilation_config.static_forward_context[str( + i)] = self.kv_caches[i] self.set_flags = False def _generate_model_config(self): self.mf_config.load_checkpoint = self.get_model_path() - self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config) + self.mf_model_config = LlamaConfig_MF( + **self.mf_config.model.model_config) if self.mf_config.moe_config: self.mf_model_config.moe_config = self.mf_config.moe_config self.mf_model_config.return_hidden_states = True # qwen qkv concat will support in next version self.mf_model_config.qkv_concat = False - setattr(self.mf_model_config, 'npu_mem_size', -1) + self.mf_model_config.npu_mem_size = -1 self.mf_config.model.model_config.qkv_concat = False def _create_network(self): @@ -74,8 +78,9 @@ class Qwen2ForCausalLM(MfModelBase): network = ParallelQwenForCausalLM_MF(self.mf_model_config) return network, network.lm_head - def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]: - weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, False) + def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> set[str]: + weight_processor = Qwen2WeightProcessor(self.mf_config, self.network, + False) weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint) return None diff --git 
a/vllm_mindspore/model_executor/sampling_metadata.py b/vllm_mindspore/model_executor/sampling_metadata.py index 8505c4595a9bdeefa334c9e83c1ead69b75cc9b8..483b1e4692d22a3a010fabbe2508c88c117f3397 100644 --- a/vllm_mindspore/model_executor/sampling_metadata.py +++ b/vllm_mindspore/model_executor/sampling_metadata.py @@ -20,18 +20,13 @@ from array import array from dataclasses import dataclass -from typing import List -from vllm.utils import ( - is_pin_memory_available, - make_tensor_with_pad, -) +import mindspore as ms +from mindspore import Tensor +from vllm.utils import is_pin_memory_available, make_tensor_with_pad _SAMPLING_EPS = 1e-5 -from mindspore import Tensor -import mindspore as ms - @dataclass class SamplingTensors: @@ -50,15 +45,15 @@ class SamplingTensors: @classmethod def from_lists( cls, - temperatures: List[float], - top_ps: List[float], - top_ks: List[int], - min_ps: List[float], - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], - prompt_tokens: List[array], - output_tokens: List[array], + temperatures: list[float], + top_ps: list[float], + top_ks: list[int], + min_ps: list[float], + presence_penalties: list[float], + frequency_penalties: list[float], + repetition_penalties: list[float], + prompt_tokens: list[array], + output_tokens: list[array], vocab_size: int, device, dtype, diff --git a/vllm_mindspore/v1/sample/ops/penalties.py b/vllm_mindspore/v1/sample/ops/penalties.py index 3ae468906c9f0f21d13cd9dadb17926db2eab280..8c68f8af13c9c21c5febd141b1e57c5d5129c449 100644 --- a/vllm_mindspore/v1/sample/ops/penalties.py +++ b/vllm_mindspore/v1/sample/ops/penalties.py @@ -18,13 +18,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List - import torch from vllm.utils import is_pin_memory_available, make_tensor_with_pad -def _convert_to_tensors(output_token_ids: List[List[int]], vocab_size: int, +def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int, device: torch.device) -> torch.Tensor: """ Convert the different list data structures to tensors. diff --git a/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py b/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py index bfe2e74f0bbe00bd76648707379d38a6e82df394..694ce11a0bb81a6c282d2ed6a63e2f54e1b0359e 100644 --- a/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py +++ b/vllm_mindspore/v1/sample/ops/topk_topp_sampler.py @@ -19,6 +19,7 @@ # limitations under the License. from typing import Optional + import torch from mindspore import mint
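
The `RUN_GATE` switch added to `codecheck_toolkits/vllm_codecheck.sh` above selects between incremental and full-tree linting. A minimal usage sketch follows; it assumes the script is run from the repository root and that `origin/develop` exists as a fetchable ref (both assumptions, not shown in the patch):

```bash
# Default (no argument, RUN_GATE=0): lint only the commits between
# origin/develop and HEAD, i.e. the incremental pre-commit run.
bash codecheck_toolkits/vllm_codecheck.sh

# Any non-zero first argument switches to `pre-commit run --all-files`,
# checking the whole working tree (intended for gate/CI runs).
bash codecheck_toolkits/vllm_codecheck.sh 1

# The script exits with pre-commit's return code, so it can gate CI directly.
bash codecheck_toolkits/vllm_codecheck.sh 1 || echo "lint check failed"
```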