From 4d936e918c46fef948135080a49105fb5fd2fde4 Mon Sep 17 00:00:00 2001
From: lvhaoyu
Date: Mon, 28 Jul 2025 21:02:15 +0800
Subject: [PATCH] update layernorm and embedding

---
 vllm_mindspore/model_executor/layers/layernorm.py | 12 +++++++-----
 .../layers/vocab_parallel_embedding.py            |  3 ++-
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/layernorm.py b/vllm_mindspore/model_executor/layers/layernorm.py
index bc6dfe11..761bbd38 100644
--- a/vllm_mindspore/model_executor/layers/layernorm.py
+++ b/vllm_mindspore/model_executor/layers/layernorm.py
@@ -23,6 +23,7 @@ from typing import Optional, Union
 
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore._c_expression import typing
+from mindspore.ops.auto_generate.gen_ops_prim import AddRmsNorm
 from vllm.config import get_current_vllm_config
 
 
@@ -45,6 +46,8 @@ class RMSNorm(nn.Cell):
             params_dtype = get_current_vllm_config().model_config.dtype
         self.weight = Parameter(mint.ones(hidden_size, dtype=params_dtype))
         self.rms_norm = ops.RmsNorm(eps)
+        self.eps = eps
+        self.add_rms_norm = AddRmsNorm()
 
     def construct(
         self,
@@ -52,9 +55,8 @@ class RMSNorm(nn.Cell):
         residual: Optional[Tensor] = None
     ) -> Union[Tensor, tuple[Tensor, Tensor]]:
         if residual is not None:
-            x = x + residual
-            residual = x
+            output, _, residual = self.add_rms_norm(x, residual, self.weight,
+                                                    self.eps)
+            return output, residual
         output = self.rms_norm(x, self.weight)[0]
-        if residual is None:
-            return output
-        return output, residual
+        return output
diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
index ee725f9a..556afe46 100644
--- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
@@ -346,13 +346,14 @@ class VocabParallelEmbedding(nn.Cell):
         # If parameter does not have output dim, then it should
         # be copied onto all gpus (e.g. g_idx for act_order gptq).
         if output_dim is None:
+            loaded_weight = loaded_weight[:]
             assert param.data.shape == loaded_weight.shape
             if param.data.shape != loaded_weight.shape:
                 raise ValueError(
                     f"'param.data.shape' should be equal to "
                     f"'loaded_weight.shape', but got {param.data.shape} "
                     f"and {loaded_weight.shape}")
-            param.set_data(loaded_weight)
+            param.set_data(ms.from_numpy(loaded_weight))
             return
 
         # Shard indexes for loading the weight
-- 
Gitee
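
Below is a minimal, self-contained sketch (illustrative only, not part of the patch) of the equivalence the layernorm change relies on: a single fused AddRmsNorm call is expected to return the same normalized output and the same updated residual (x + residual) as the previous add-then-RmsNorm sequence. The tensor shapes, dtype, and tolerance are assumptions, and running it requires a MindSpore backend on which AddRmsNorm is supported.

# Illustrative equivalence check; shapes/dtype are assumptions.
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore.ops.auto_generate.gen_ops_prim import AddRmsNorm

eps = 1e-6
x = ms.Tensor(np.random.randn(2, 8).astype(np.float16))
residual = ms.Tensor(np.random.randn(2, 8).astype(np.float16))
weight = ms.Tensor(np.ones(8, dtype=np.float16))

# Old (unfused) path: add the residual first, then normalize.
hidden = x + residual
old_out = ops.RmsNorm(eps)(hidden, weight)[0]

# New (fused) path used by the patch: one call returns the normalized
# output, the rstd statistics (ignored here), and the updated residual.
new_out, _, new_residual = AddRmsNorm()(x, residual, weight, eps)

print(np.allclose(old_out.asnumpy(), new_out.asnumpy(), atol=1e-2))
print(np.allclose(hidden.asnumpy(), new_residual.asnumpy(), atol=1e-2))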