From a347182ba2d9be34e214d5e28064dd45ab20eed0 Mon Sep 17 00:00:00 2001 From: zhanzhan1 Date: Thu, 19 Jun 2025 17:26:18 +0800 Subject: [PATCH] optimizer for qwen --- vllm_mindspore/attention/layer.py | 22 ++++++--------- .../distributed/communication_op.py | 2 +- .../model_executor/layers/activation.py | 18 +++++++++--- .../model_executor/layers/layernorm.py | 14 ++++++---- .../model_executor/layers/linear.py | 12 ++++---- .../model_executor/layers/logits_processor.py | 5 ++-- .../model_executor/layers/rotary_embedding.py | 24 +++++++--------- .../layers/vocab_parallel_embedding.py | 15 ++++++---- .../model_executor/models/model_base.py | 17 +++++++++++ vllm_mindspore/model_executor/models/qwen2.py | 28 +++++++++++++------ 10 files changed, 97 insertions(+), 60 deletions(-) diff --git a/vllm_mindspore/attention/layer.py b/vllm_mindspore/attention/layer.py index f4af1afb..6b9a10fb 100644 --- a/vllm_mindspore/attention/layer.py +++ b/vllm_mindspore/attention/layer.py @@ -26,6 +26,8 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig +from vllm_mindspore.model_executor.models.model_base import NativeCell + def _pad_to_max_tensor(input_: Tensor, max_len: int, @@ -79,7 +81,7 @@ def _hidden_states_bsh2th(input_: Tensor, return th_output -class Attention(nn.Cell): +class Attention(NativeCell): """Attention layer. This class takes query, key, and value tensors as input. The input tensors @@ -160,22 +162,16 @@ class Attention(nn.Cell): batch_valid_length: shape = [batch_size, ] block_tables: shape = [block_size, num_block] """ - output = query - # ensure that the input tensors of reshape_and_cache is continuous - key = key.contiguous() - value = value.contiguous() cache_out = self.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) query = ops.depend(query, cache_out) if is_prefill: - output = self._run_prefill_forward(query, key, value, attn_mask, - batch_valid_length, - batch_valid_length) - else: - output = self._run_decode_forward(query, key_cache, value_cache, - block_tables, batch_valid_length, - attn_mask, q_seq_lens) - return output + return self._run_prefill_forward(query, key, value, attn_mask, + batch_valid_length, + batch_valid_length) + return self._run_decode_forward(query, key_cache, value_cache, + block_tables, batch_valid_length, + attn_mask, q_seq_lens) def _run_prefill_forward( self, diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index 00447432..9789794a 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -36,7 +36,7 @@ def tensor_model_parallel_all_reduce(input_: Tensor) -> Tensor: if get_tensor_model_parallel_world_size() == 1: return input_ """All-reduce the input tensor across model parallel group.""" - output, _ = all_reduce(input_, group=get_tp_group()) + output, _ = all_reduce(input_, group=get_tp_group().device_group._name) return output diff --git a/vllm_mindspore/model_executor/layers/activation.py b/vllm_mindspore/model_executor/layers/activation.py index a1d94eca..bf80c9e7 100644 --- a/vllm_mindspore/model_executor/layers/activation.py +++ b/vllm_mindspore/model_executor/layers/activation.py @@ -17,6 +17,11 @@ # ============================================================================ from mindspore import Tensor, mint, nn, ops +from mindspore.ops.auto_generate.gen_ops_prim import Swiglu + +from vllm.config import get_current_vllm_config + +from 
vllm_mindspore.model_executor.models.model_base import NativeCell class SiluAndMul(nn.Cell): @@ -34,7 +39,7 @@ class SiluAndMul(nn.Cell): return mint.nn.functional.silu(x[..., :d]) * x[..., d:] -class SwiGLU(nn.Cell): +class SwiGLU(NativeCell): """An activation function for SwiGLU. Shapes: @@ -44,14 +49,19 @@ class SwiGLU(nn.Cell): def __init__(self): super().__init__() - self.silu = nn.SiLU() self.split = ops.auto_generate.SplitWithSize() self.mul = ops.Mul() + self.swiglu = Swiglu() + enforce_eager = get_current_vllm_config().model_config.enforce_eager + self.construct = self.forward_eager if enforce_eager else self.forward_graph - def construct(self, x: Tensor) -> Tensor: + def forward_graph(self, x: Tensor) -> Tensor: hidden_size = x.shape[-1] // 2 size = [hidden_size, hidden_size] gate, hidden = self.split(x, size, dim=-1) - gate = self.silu(gate) + gate = mint.nn.functional.silu(gate) hidden = self.mul(hidden, gate) return hidden + + def forward_eager(self, x: Tensor) -> Tensor: + return self.swiglu(x) diff --git a/vllm_mindspore/model_executor/layers/layernorm.py b/vllm_mindspore/model_executor/layers/layernorm.py index 3e0251cb..48008347 100644 --- a/vllm_mindspore/model_executor/layers/layernorm.py +++ b/vllm_mindspore/model_executor/layers/layernorm.py @@ -22,11 +22,13 @@ from mindspore import Parameter, Tensor, mint, ops from mindspore.common import dtype as mstype from mindspore.common.dtype import typing from mindspore import nn +from mindspore.ops.auto_generate.gen_ops_prim import AddRmsNorm from vllm.config import get_current_vllm_config +from vllm_mindspore.model_executor.models.model_base import NativeCell -class RMSNorm(nn.Cell): +class RMSNorm(NativeCell): def __init__( self, hidden_size: int, @@ -38,7 +40,9 @@ class RMSNorm(nn.Cell): if params_dtype is None: params_dtype = get_current_vllm_config().model_config.dtype self.weight = Parameter(mint.ones(hidden_size, dtype=params_dtype)) + self.eps = eps self.rms_norm = ops.RmsNorm(eps) + self.add_rms_norm = AddRmsNorm() def construct( self, @@ -46,9 +50,7 @@ class RMSNorm(nn.Cell): residual: Optional[Tensor] = None ) -> Union[Tensor, Tuple[Tensor, Tensor]]: if residual is not None: - x = x + residual - residual = x + output, _, residual = self.add_rms_norm(x, residual, self.weight, self.eps) + return output, residual output = self.rms_norm(x, self.weight)[0] - if residual is None: - return output - return output, residual + return output diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index e0851149..9bcfe270 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -30,15 +30,15 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, ) from vllm.config import get_current_vllm_config from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from vllm_mindspore.model_executor.models.model_base import NativeCell from vllm_mindspore.model_executor.utils import set_weight_attrs -from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion +from vllm_mindspore.distributed.communication_op import ReduceFromModelParallelRegion, tensor_model_parallel_all_reduce WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", @@ -140,7 +140,7 @@ class UnquantizedLinearMethod(LinearMethodBase): return x -class 
LinearBase(ms.nn.Cell): +class LinearBase(NativeCell): """Base linear layer. Args: @@ -169,8 +169,9 @@ class LinearBase(ms.nn.Cell): self.input_size = input_size self.output_size = output_size self.skip_bias_add = skip_bias_add + self.model_config = get_current_vllm_config().model_config if params_dtype is None: - params_dtype = get_current_vllm_config().model_config.dtype + params_dtype = self.model_config.dtype self.params_dtype = params_dtype if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() @@ -556,7 +557,8 @@ class RowParallelLinear(LinearBase): # self.register_parameter("bias", None) self.bias = None - self.tensor_model_parallel_all_reduce = ReduceFromModelParallelRegion() + self.tensor_model_parallel_all_reduce = tensor_model_parallel_all_reduce \ + if self.model_config.enforce_eager else ReduceFromModelParallelRegion() def construct(self, input_): if self.input_is_parallel: diff --git a/vllm_mindspore/model_executor/layers/logits_processor.py b/vllm_mindspore/model_executor/layers/logits_processor.py index 5d603694..01ee3e5a 100644 --- a/vllm_mindspore/model_executor/layers/logits_processor.py +++ b/vllm_mindspore/model_executor/layers/logits_processor.py @@ -32,6 +32,7 @@ from vllm.distributed import ( from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) +from vllm_mindspore.model_executor.models.model_base import NativeCell from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.platforms import current_platform @@ -41,7 +42,7 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: _logits_processor_threadpool = ThreadPoolExecutor( envs.VLLM_LOGITS_PROCESSOR_THREADS) -class LogitsProcessor(nn.Cell): +class LogitsProcessor(NativeCell): """Process logits and apply logits processors from sampling metadata. This layer does the following: @@ -74,7 +75,7 @@ class LogitsProcessor(nn.Cell): # Whether to use gather or all-gather to gather the logits. 
parallel_config = get_current_vllm_config().parallel_config self.use_all_gather = envs.VLLM_USE_V1 \ - or parallel_config.distributed_executor_backend == "external_launcher" + or parallel_config.distributed_executor_backend == "external_launcher" def construct( self, diff --git a/vllm_mindspore/model_executor/layers/rotary_embedding.py b/vllm_mindspore/model_executor/layers/rotary_embedding.py index ff6ea4da..1a14e04c 100644 --- a/vllm_mindspore/model_executor/layers/rotary_embedding.py +++ b/vllm_mindspore/model_executor/layers/rotary_embedding.py @@ -25,6 +25,8 @@ from mindspore.common import dtype as mstype from transformers import PretrainedConfig from vllm.config import get_current_vllm_config +from vllm_mindspore.model_executor.models.model_base import NativeCell + def _apply_rotary_emb( x: Tensor, @@ -132,7 +134,7 @@ class RotaryEmbedding(nn.Cell): return query, key -class InferRotaryEmbedding(nn.Cell): +class InferRotaryEmbedding(NativeCell): def __init__( self, @@ -182,23 +184,17 @@ class InferRotaryEmbedding(nn.Cell): def construct( self, - positions: Tensor, query: Tensor, key: Tensor, + freqs_cos: Tensor, + freqs_sin: Tensor, batch_valid_length: Tensor, - is_prefill: bool, - offsets: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: - query = query.contiguous() - key = key.contiguous() - if is_prefill: - return self.rotary_embedding_op(query, key, self.freqs_cos, - self.freqs_sin, batch_valid_length) + return self.rotary_embedding_op(query, key, freqs_cos, + freqs_sin, batch_valid_length) - freqs_cos = self.gather(self.freqs_cos, positions, 0) - freqs_sin = self.gather(self.freqs_sin, positions, 0) - return self.rotary_embedding_op(query, key, freqs_cos, freqs_sin, - batch_valid_length) + def get_freqs(self, positions: Tensor) -> Tuple[Tensor, Tensor]: + return mint.index_select(self.freqs_cos, 0, positions), mint.index_select(self.freqs_sin, 0, positions) class InferLlama3RotaryEmbedding(InferRotaryEmbedding): @@ -431,7 +427,7 @@ class MRotaryEmbedding(RotaryEmbedding): t_index = (ops.arange(llm_grid_t).view(-1, 1).broadcast_to( (-1, llm_grid_h * llm_grid_w)) * video_second_per_grid_t * - tokens_per_second).int().flatten() + tokens_per_second).int().flatten() h_index = ops.arange(llm_grid_h).view(1, -1, 1).broadcast_to( (llm_grid_t, -1, llm_grid_w)).flatten().int() w_index = ops.arange(llm_grid_w).view(1, 1, -1).broadcast_to( diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 768a8238..51eff520 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -27,10 +27,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm_mindspore.distributed.communication_op import ( - ReduceFromModelParallelRegion) + ReduceFromModelParallelRegion, tensor_model_parallel_all_reduce) from vllm_mindspore.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding) from vllm_mindspore.model_executor.utils import set_weight_attrs +from vllm_mindspore.model_executor.models.model_base import NativeCell DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -45,7 +46,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): weight = Parameter(mint.zeros( (sum(output_partition_sizes), input_size_per_partition), dtype=params_dtype), - requires_grad=False) + requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) 
layer.insert_param_to_cell("weight", weight) set_weight_attrs(weight, extra_weight_attrs) @@ -69,7 +70,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): return x def embedding(self, layer: nn.Cell, input_: Tensor) -> Tensor: - return self.gather(layer.weight, input_, 0) + return mint.nn.functional.embedding(input_, layer.weight) def get_masked_input_and_mask( @@ -185,7 +186,7 @@ class VocabParallelEmbeddingShardIndices: assert self.num_added_elements <= self.num_added_elements_padded -class VocabParallelEmbedding(nn.Cell): +class VocabParallelEmbedding(NativeCell): def __init__( self, @@ -241,9 +242,10 @@ class VocabParallelEmbedding(nn.Cell): "the 'embedding' method, see UnquantizedEmbeddingMethod.") self.quant_method: QuantizeMethodBase = quant_method + self.model_config = get_current_vllm_config().model_config if params_dtype is None: - params_dtype = get_current_vllm_config().model_config.dtype + params_dtype = self.model_config.dtype # Divide the weight matrix along the vocaburaly dimension. self.num_added_embeddings = self.num_embeddings - self.org_vocab_size self.num_embeddings_per_partition = divide(self.num_embeddings_padded, @@ -266,7 +268,8 @@ class VocabParallelEmbedding(nn.Cell): params_dtype=params_dtype, weight_loader=self.weight_loader, ) - self.tensor_model_parallel_all_reduce = ReduceFromModelParallelRegion() + self.tensor_model_parallel_all_reduce = tensor_model_parallel_all_reduce \ + if self.model_config.enforce_eager else ReduceFromModelParallelRegion() @classmethod def _get_indices( diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index 4a960845..47077561 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -300,6 +300,7 @@ class MsModelBase: q_seq_lens = ms.Tensor(query_lens_np, dtype=ms.int32) position_ids = ms.Tensor(positions, dtype=ms.int32) attention_mask = self.casual_mask.gen_attention_mask(is_prefill, positions, query_lens_np) + freqs_cos, freqs_sin = self.model.layers[0].self_attn.rotary_emb.get_freqs(positions) model_inputs = {} model_inputs["input_ids"] = input_ids.astype(ms.int32) @@ -311,6 +312,8 @@ class MsModelBase: model_inputs["attention_mask"] = attention_mask model_inputs["key_cache"] = key_cache model_inputs["value_cache"] = value_cache + model_inputs["freqs_cos"] = freqs_cos + model_inputs["freqs_sin"] = freqs_sin return model_inputs, is_prefill @@ -363,6 +366,8 @@ class NativeModel(MsModelBase): dyn_slot_mapping = Tensor(shape=[None, ], dtype=mstype.int32) dynamic_attention_mask = Tensor(shape=[None, None], dtype=self.model_config.dtype) + dynamic_freqs_cos = Tensor(shape=[None, None], dtype=self.model_config.dtype) + dynamic_freqs_sin = Tensor(shape=[None, None], dtype=self.model_config.dtype) dyn_batch_valid_length = Tensor(shape=[None,], dtype=mstype.int32) dyn_q_seq_lens = Tensor(shape=[None, ], dtype=mstype.int32) dyn_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) @@ -376,6 +381,8 @@ class NativeModel(MsModelBase): is_prefill, dyn_slot_mapping, dynamic_attention_mask, + dynamic_freqs_cos, + dynamic_freqs_sin, dyn_batch_valid_length, dyn_q_seq_lens, dyn_block_tables, @@ -420,6 +427,8 @@ class NativeModel(MsModelBase): is_prefill=is_prefill, slot_mapping=model_inputs["slot_mapping"], attn_mask=model_inputs["attention_mask"], + freqs_cos=model_inputs["freqs_cos"], + freqs_sin=model_inputs["freqs_sin"], batch_valid_length=model_inputs["batch_valid_length"], 
q_seq_lens=model_inputs["q_seq_lens"], block_tables=model_inputs["block_tables"], @@ -428,3 +437,11 @@ class NativeModel(MsModelBase): ) return model_output + + +class NativeCell(nn.Cell): + def __init__(self): + super().__init__() + + def __call__(self, *args, **kwargs): + return self.construct(*args, **kwargs) diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 27cf2b23..bdeefc58 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -55,7 +55,7 @@ from vllm_mindspore.model_executor.model_loader.weight_utils import \ default_weight_loader from vllm_mindspore.model_executor.models.attention_mask import \ LowerTriangularMask -from vllm_mindspore.model_executor.models.model_base import (AttentionWrapper, +from vllm_mindspore.model_executor.models.model_base import (NativeCell, NativeModel) from vllm_mindspore.model_executor.models.utils import ( PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers, @@ -64,7 +64,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE -class Qwen2MLP(nn.Cell): +class Qwen2MLP(NativeCell): def __init__( self, @@ -99,7 +99,7 @@ class Qwen2MLP(nn.Cell): return x -class Qwen2Attention(nn.Cell): +class Qwen2Attention(NativeCell): def __init__(self, hidden_size: int, @@ -176,6 +176,8 @@ class Qwen2Attention(nn.Cell): is_prefill: bool, slot_mapping: Tensor, attn_mask: Tensor, + freqs_cos: Tensor, + freqs_sin: Tensor, batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, @@ -183,7 +185,10 @@ class Qwen2Attention(nn.Cell): qkv, _ = self.qkv_proj(hidden_states) q, k, v = mint.split(qkv, (self.q_size, self.kv_size, self.kv_size), -1) - q, k = self.rotary_emb(positions, q, k, batch_valid_length, is_prefill) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + q, k = self.rotary_emb(q, k, freqs_cos, freqs_sin, batch_valid_length) attn_output = self.attn(q, k, v, key_cache, value_cache, is_prefill, slot_mapping, attn_mask, batch_valid_length, q_seq_lens, block_tables) @@ -191,7 +196,7 @@ class Qwen2Attention(nn.Cell): return output -class Qwen2DecoderLayer(nn.Cell): +class Qwen2DecoderLayer(NativeCell): def __init__( self, @@ -252,6 +257,8 @@ class Qwen2DecoderLayer(nn.Cell): is_prefill: bool, slot_mapping: Tensor, attn_mask: Tensor, + freqs_cos: Tensor, + freqs_sin: Tensor, batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, @@ -266,7 +273,8 @@ class Qwen2DecoderLayer(nn.Cell): hidden_states, residual) hidden_states = self.self_attn(positions, hidden_states, key_cache, value_cache, is_prefill, slot_mapping, - attn_mask, batch_valid_length, + attn_mask, freqs_cos, freqs_sin, + batch_valid_length, q_seq_lens, block_tables) # Fully Connected @@ -276,7 +284,7 @@ class Qwen2DecoderLayer(nn.Cell): return hidden_states, residual -class Qwen2Model(nn.Cell): +class Qwen2Model(NativeCell): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -329,6 +337,8 @@ class Qwen2Model(nn.Cell): is_prefill: bool, slot_mapping: Tensor, attn_mask: Tensor, + freqs_cos: Tensor, + freqs_sin: Tensor, batch_valid_length: Tensor, q_seq_lens: Tensor, block_tables: Tensor, @@ -351,7 +361,7 @@ class Qwen2Model(nn.Cell): key_caches[i - self.start_layer], value_caches[i - self.start_layer], is_prefill, slot_mapping, - attn_mask, batch_valid_length, + attn_mask, freqs_cos, freqs_sin, batch_valid_length, q_seq_lens, 
block_tables, residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -382,7 +392,7 @@ class Qwen2Model(nn.Cell): # the checkpoint. Skip them. continue if (self.quant_config is not None and - (scale_name := self.quant_config.get_cache_scale(name))): + (scale_name := self.quant_config.get_cache_scale(name))): # Loading kv cache quantization scales param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", -- Gitee
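
For reference, a minimal sketch of the NativeCell shim this patch adds in model_base.py and threads through the layer classes (Attention, SwiGLU, RMSNorm, the linear and embedding layers, and the Qwen2 modules): calling such a layer dispatches straight to construct(), presumably to skip the per-call bookkeeping that mindspore.nn.Cell.__call__ performs during eager (PyNative) execution, while graph capture still traces construct() as before. ScaleByTwo below is a hypothetical toy layer used only to show the call path; it is not part of the patch.

from mindspore import Tensor, mint, nn


class NativeCell(nn.Cell):
    """Same shim as in model_base.py: route calls straight to construct()."""

    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        return self.construct(*args, **kwargs)


class ScaleByTwo(NativeCell):
    """Hypothetical toy layer, only here to illustrate the dispatch."""

    def construct(self, x: Tensor) -> Tensor:
        return mint.mul(x, 2)


layer = ScaleByTwo()
# Goes directly to ScaleByTwo.construct(), bypassing nn.Cell.__call__ hooks.
out = layer(Tensor([1.0, 2.0, 3.0]))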