From 226c7aecc35237220a0db47cf41ba1fde6f00232 Mon Sep 17 00:00:00 2001
From: twc
Date: Mon, 21 Jul 2025 15:11:33 +0800
Subject: [PATCH] optimize native model parameter segmentation

---
 .../model_executor/layers/linear.py           | 89 +++++++++----------
 .../layers/vocab_parallel_embedding.py        | 31 ++++---
 .../model_loader/weight_utils.py              | 50 ++++++++---
 vllm_mindspore/model_executor/models/llama.py |  6 +-
 4 files changed, 100 insertions(+), 76 deletions(-)

diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py
index adfbc086..d09e59a6 100644
--- a/vllm_mindspore/model_executor/layers/linear.py
+++ b/vllm_mindspore/model_executor/layers/linear.py
@@ -22,6 +22,7 @@
 from abc import abstractmethod
 from typing import Optional, Union
 
+import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore._c_expression.typing import Type as MSDtype
 from vllm.config import get_current_vllm_config
@@ -34,6 +35,8 @@ from vllm_mindspore.distributed.communication_op import (
     ReduceFromModelParallelRegion)
 from vllm_mindspore.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm_mindspore.model_executor.model_loader.weight_utils import (
+    split_loaded_weight)
 from vllm_mindspore.model_executor.utils import set_weight_attrs
 
 WEIGHT_LOADER_V2_SUPPORTED = [
@@ -180,7 +183,7 @@ class ColumnParallelLinear(LinearBase):
         output_sizes: list of output sizes packed into one output,
                       like for QKV the list would be size 3.
         prefix: The name of the layer in the state dict, including all parents
-            (e.g. model.layers.0.qkv_proj)
+                (e.g. model.layers.0.qkv_proj)
     """
 
     def __init__(
@@ -263,21 +266,21 @@ class ColumnParallelLinear(LinearBase):
             return output
         return output, output_bias
 
-    def weight_loader(self, param: Parameter, loaded_weight: Tensor):
+    def weight_loader(self, param, loaded_weight):
         tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
+        shard_size = self.output_size_per_partition
+        start_idx = tp_rank * shard_size
+        loaded_weight = split_loaded_weight(loaded_weight, output_dim,
+                                            start_idx, shard_size)
 
-        if output_dim is not None:
-            shard_size = param.shape[output_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size).contiguous()
-
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
         assert param.shape == loaded_weight.shape
-        param.set_data(loaded_weight)
+        param.set_data(ms.from_numpy(loaded_weight))
 
 
 class MergedColumnParallelLinear(ColumnParallelLinear):
@@ -327,29 +330,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          prefix=prefix,
                          return_bias=return_bias)
 
-
-# type: ignore[override]
-
     def weight_loader(self,
-                      param: Parameter,
-                      loaded_weight: Tensor,
+                      param,
+                      loaded_weight,
                       loaded_shard_id: Optional[int] = None):
-        param_data = param.data
         output_dim = getattr(param, "output_dim", None)
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
-        if output_dim is not None and loaded_shard_id is not None:
+        shard_size = 0
+        shard_offset = 0
+        if loaded_shard_id is not None:
             assert loaded_shard_id < len(self.output_sizes)
             shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
             shard_size = self.output_sizes[loaded_shard_id] // tp_size
-            param_data = param.data
-            param_data = param_data.narrow(output_dim, shard_offset,
-                                           shard_size)
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size).contiguous()
-            assert param_data.shape == loaded_weight.shape
-            param[shard_offset:shard_offset + shard_size, :] = loaded_weight
+
+        start_idx = tp_rank * shard_size
+        loaded_weight = split_loaded_weight(loaded_weight, output_dim,
+                                            start_idx, shard_size)
+
+        assert loaded_weight.shape == (shard_size, param.shape[1])
+        param[shard_offset:shard_offset +
+              shard_size, :] = ms.from_numpy(loaded_weight)
 
 
 class QKVParallelLinear(ColumnParallelLinear):
@@ -427,19 +428,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                          prefix=prefix,
                          return_bias=return_bias)
 
-
-# type: ignore[override]
-
     def weight_loader(self,
-                      param: Parameter,
-                      loaded_weight: Tensor,
+                      param,
+                      loaded_weight,
                       loaded_shard_id: Optional[str] = None):
         output_dim = getattr(param, "output_dim", None)
         tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
-        # If output dim is defined, use the default loading process.
-        # if output_dim is not None:
-        param_data = param.data
         if loaded_shard_id == "q":
             shard_offset = 0
             shard_size = self.num_heads * self.head_size
@@ -451,21 +446,20 @@ class QKVParallelLinear(ColumnParallelLinear):
                             self.num_kv_heads) * self.head_size
             shard_size = self.num_kv_heads * self.head_size
 
-        param_data = param_data.narrow(output_dim, shard_offset, shard_size)
         if loaded_shard_id == "q":
             shard_id = tp_rank
         else:
             shard_id = tp_rank // self.num_kv_head_replicas
         start_idx = shard_id * shard_size
+        loaded_weight = split_loaded_weight(loaded_weight, output_dim,
+                                            start_idx, shard_size)
+        loaded_weight = ms.from_numpy(loaded_weight)
 
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                             shard_size).contiguous()
-        assert param_data.shape == loaded_weight.shape
         if param.name.endswith("weight"):
-            self.weight[shard_offset:shard_offset +
-                        shard_size, :] = loaded_weight
+            assert loaded_weight.shape == (shard_size, param.shape[1])
         if param.name.endswith("bias"):
-            self.bias[shard_offset:shard_offset + shard_size] = loaded_weight
+            assert loaded_weight.shape == (shard_size, )
+        param[shard_offset:shard_offset + shard_size] = loaded_weight
 
 
 class RowParallelLinear(LinearBase):
@@ -587,14 +581,15 @@ class RowParallelLinear(LinearBase):
     def weight_loader(self, param, loaded_weight):
         tp_rank = get_tensor_model_parallel_rank()
         input_dim = getattr(param, "input_dim", None)
-        is_sharded_weight = getattr(param, "is_sharded_weight", False)
-        is_sharded_weight = is_sharded_weight
-        if input_dim is not None and not is_sharded_weight:
-            shard_size = param.shape[input_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
-                                                 shard_size).contiguous()
+        shard_size = self.input_size_per_partition
+        start_idx = tp_rank * shard_size
+        loaded_weight = split_loaded_weight(loaded_weight, input_dim,
+                                            start_idx, shard_size)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
+
         assert param.shape == loaded_weight.shape
-        param.set_data(loaded_weight.contiguous())
+        param.set_data(ms.from_numpy(loaded_weight))
diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
index 18530805..c657bfdf 100644
--- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py
@@ -22,6 +22,7 @@ from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Optional
 
+import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.common.dtype import typing
 from vllm.config import get_current_vllm_config
@@ -34,6 +35,8 @@ from vllm_mindspore.distributed.communication_op import (
     ReduceFromModelParallelRegion)
 from vllm_mindspore.model_executor.layers.quantization.base_config import (
     QuantizeMethodBase, method_has_implemented_embedding)
+from vllm_mindspore.model_executor.model_loader.weight_utils import (
+    split_loaded_weight)
 from vllm_mindspore.model_executor.utils import set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -339,32 +342,32 @@ class VocabParallelEmbedding(nn.Cell):
 
     def weight_loader(self, param: Parameter, loaded_weight: Tensor):
         output_dim = getattr(param, "output_dim", None)
-
+        get_tensor_model_parallel_rank()
         # If parameter does not have output dim, then it should
         # be copied onto all gpus (e.g. g_idx for act_order gptq).
         if output_dim is None:
             assert param.data.shape == loaded_weight.shape
             if param.data.shape != loaded_weight.shape:
                 raise ValueError(
-                    f"'param.data.shape' should be equal "
-                    f"to 'loaded_weight.shape',"
-                    f" but got {param.data.shape} and {loaded_weight.shape}")
+                    f"'param.data.shape' should be equal to "
+                    f"'loaded_weight.shape', but got {param.data.shape} "
+                    f"and {loaded_weight.shape}")
             param.set_data(loaded_weight)
             return
 
         # Shard indexes for loading the weight
         start_idx = self.shard_indices.org_vocab_start_index
         shard_size = self.shard_indices.org_vocab_end_index - start_idx
 
-        if loaded_weight.shape[output_dim] != self.org_vocab_size:
-            raise ValueError(f"'loaded_weight.shape[output_dim]' should "
-                             f"be equal to 'org_vocab_size',"
-                             f" but got {loaded_weight.shape[output_dim]} "
-                             f"and {self.org_vocab_size}")
-
-        # Copy the data.
-        loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                             shard_size).contiguous()
-        param[:loaded_weight.shape[0]] = loaded_weight
+        loaded_weight = split_loaded_weight(loaded_weight, output_dim,
+                                            start_idx, shard_size)
+        org_vocab_size_per_rank = self.org_vocab_size // self.tp_size
+        if loaded_weight.shape[output_dim] != org_vocab_size_per_rank:
+            raise ValueError(
+                f"'loaded_weight.shape[output_dim]' should be equal to "
+                f"'org_vocab_size', but got {loaded_weight.shape[output_dim]} "
+                f"and {self.org_vocab_size}")
+
+        param[:loaded_weight.shape[0]] = ms.from_numpy(loaded_weight)
         param[loaded_weight.shape[0]:] = 0
diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py
index 4ec3a2de..6bf2dd4c 100644
--- a/vllm_mindspore/model_executor/model_loader/weight_utils.py
+++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py
@@ -19,22 +19,45 @@
 # limitations under the License.
 
 from collections.abc import Generator
+from typing import Any
 
 import mindspore as ms
-import torch
-from mindspore import Parameter, Tensor
+from mindspore import Parameter
+from safetensors import safe_open
 from tqdm.auto import tqdm
+from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT,
+                                                           enable_tqdm)
+
+
+def split_loaded_weight(loaded_weight, shard_dim, start_idx, shard_size):
+    """
+    Read numpy slice data based on axis and slice range.
+    :loaded_weight: PySafeSlice object
+    :shard_dim: axis of weight slice
+    :start_idx: start index of the slice
+    :shard_size: length of the slice along shard_dim
+    """
+    if shard_dim is None:
+        loaded_weight = loaded_weight[:]
+        return loaded_weight
+
+    end_idx = start_idx + shard_size
+    if shard_dim == 0:
+        loaded_weight = loaded_weight[start_idx:end_idx]
+    elif shard_dim == 1:
+        loaded_weight = loaded_weight[:, start_idx:end_idx]
+    elif shard_dim == 2:
+        loaded_weight = loaded_weight[:, :, start_idx:end_idx]
+    else:
+        raise ValueError("shard_dim:{} is not supported.".format(shard_dim))
+    return loaded_weight
 
 
 def safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
-) -> Generator[tuple[str, torch.Tensor], None, None]:
+) -> Generator[tuple[str, Any], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    from safetensors import safe_open
-    from vllm.model_executor.model_loader.weight_utils import (_BAR_FORMAT,
-                                                               enable_tqdm)
-
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
@@ -43,10 +66,15 @@ def safetensors_weights_iterator(
     ):
         with safe_open(st_file, framework="np") as f:
             for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
-                yield name, ms.tensor(param)
+                # Return a lightweight PySafeSlice object that uses file
+                # pointer offsets internally to read the safetensor on
+                # demand, avoiding memory explosion. Actual data can be
+                # obtained through a slicing operation like param[start:end].
+                param = f.get_slice(name)
+                yield name, param
 
 
-def default_weight_loader(param: Parameter, loaded_weight: Tensor) -> None:
+def default_weight_loader(param: Parameter, loaded_weight: Any) -> None:
     """Default weight loader."""
-    param.set_data(loaded_weight)
+    loaded_weight = loaded_weight[:]
+    param.set_data(ms.Tensor(loaded_weight, dtype=param.dtype))
diff --git a/vllm_mindspore/model_executor/models/llama.py b/vllm_mindspore/model_executor/models/llama.py
index aabda0d7..e226e9b5 100644
--- a/vllm_mindspore/model_executor/models/llama.py
+++ b/vllm_mindspore/model_executor/models/llama.py
@@ -51,16 +51,14 @@ from vllm_mindspore.model_executor.layers.sampler import (SamplerOutput,
                                                           get_sampler)
 from vllm_mindspore.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from vllm_mindspore.model_executor.model_loader.weight_utils import (
+    default_weight_loader)
 from vllm_mindspore.model_executor.models.model_base import MsModelBase
 from vllm_mindspore.model_executor.models.utils import (
     PPMissingLayer, extract_layer_index,
     make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 
 
-def default_weight_loader(param, loaded_weight) -> None:
-    param.set_data(loaded_weight)
-
-
 class LlamaMLP(nn.Cell):
 
     def __init__(
-- 
Gitee
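
The standalone sketch below is not part of the patch; it only illustrates the loading path the patch introduces: safe_open(..., framework="np") plus get_slice() returns a lazy PySafeSlice, and slicing it reads just the tensor-parallel shard that split_loaded_weight would select. The checkpoint file name, weight name, and tp_rank/tp_size values are hypothetical, and a 2-D weight sharded along dim 0 is assumed.

# Standalone usage sketch (assumptions: hypothetical file and weight names,
# a 2-D weight sharded on dim 0, tensor-parallel size 2).
import numpy as np
from safetensors import safe_open

CKPT = "model-00001-of-00002.safetensors"     # hypothetical checkpoint file
NAME = "model.layers.0.mlp.gate_proj.weight"  # hypothetical weight name
tp_rank, tp_size = 0, 2                       # example tensor-parallel layout

with safe_open(CKPT, framework="np") as f:
    weight_slice = f.get_slice(NAME)       # lazy PySafeSlice, no full read yet
    rows, cols = weight_slice.get_shape()  # shape metadata only
    shard_size = rows // tp_size           # shard along output_dim == 0
    start_idx = tp_rank * shard_size
    # Same slice that split_loaded_weight(weight_slice, 0, start_idx,
    # shard_size) performs: only these rows are read from disk.
    shard = weight_slice[start_idx:start_idx + shard_size]

assert isinstance(shard, np.ndarray)
assert shard.shape == (shard_size, cols)
print(shard.shape, shard.dtype)

Reading only each rank's rows this way is what lets the iterator replace f.get_tensor() with f.get_slice() without materializing the full weight on every device.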