From 9c759d8f3eb48003ec26fae6fb0eb5f9ce451f18 Mon Sep 17 00:00:00 2001 From: HighCloud Date: Fri, 4 Jul 2025 15:04:51 +0800 Subject: [PATCH 1/2] support native qwq --- vllm_mindspore/__init__.py | 15 +++ .../distributed/communication_op.py | 8 ++ vllm_mindspore/distributed/parallel_state.py | 93 ++++++++++++++++++ .../model_executor/layers/linear.py | 24 +++-- .../layers/vocab_parallel_embedding.py | 3 +- .../model_loader/weight_utils.py | 8 +- .../model_executor/models/model_base.py | 61 +++++++++--- vllm_mindspore/model_executor/models/qwen2.py | 29 +++++- vllm_mindspore/utils.py | 97 ++++++++++++++----- vllm_mindspore/v1/worker/gpu_model_runner.py | 25 +++-- vllm_mindspore/worker/cache_engine.py | 19 +++- vllm_mindspore/worker/model_runner.py | 5 +- 12 files changed, 327 insertions(+), 60 deletions(-) create mode 100644 vllm_mindspore/distributed/parallel_state.py diff --git a/vllm_mindspore/__init__.py b/vllm_mindspore/__init__.py index 35697361..cec9bea9 100644 --- a/vllm_mindspore/__init__.py +++ b/vllm_mindspore/__init__.py @@ -263,6 +263,16 @@ RejectionSampler._smallest_positive_value.__set_name__( RejectionSampler, '_smallest_positive_value') vllm.model_executor.layers.rejection_sampler._multinomial = _multinomial +import vllm.distributed.communication_op +import vllm.worker.worker_base +from vllm_mindspore.distributed.communication_op import cpu_broadcast_tensor_dict +vllm.distributed.communication_op.broadcast_tensor_dict = cpu_broadcast_tensor_dict +vllm.worker.worker_base.broadcast_tensor_dict = cpu_broadcast_tensor_dict + +import vllm.distributed.parallel_state +from vllm_mindspore.distributed.parallel_state import gc_broadcast_tensor_dict +vllm.distributed.parallel_state.GroupCoordinator.broadcast_tensor_dict = gc_broadcast_tensor_dict + ######### for multi-model from vllm_mindspore.inputs.registry import call_hf_processor from vllm.inputs.registry import InputProcessingContext @@ -327,6 +337,11 @@ vllm.v1.worker.gpu_input_batch.BlockTable = BlockTable import vllm.v1.worker.gpu_input_batch from vllm_mindspore.v1.worker.gpu_input_batch import _make_sampling_metadata, _make_prompt_token_ids_tensor +# TODO: need this? 
+# from vllm_mindspore.model_executor.model_loader.loader import _process_weights_after_loading + +# vllm.model_executor.model_loader.loader._process_weights_after_loading = _process_weights_after_loading + vllm.v1.worker.gpu_input_batch.InputBatch._make_sampling_metadata = _make_sampling_metadata vllm.v1.worker.gpu_model_runner.InputBatch._make_sampling_metadata = _make_sampling_metadata vllm.v1.worker.gpu_input_batch.InputBatch._make_prompt_token_ids_tensor = _make_prompt_token_ids_tensor diff --git a/vllm_mindspore/distributed/communication_op.py b/vllm_mindspore/distributed/communication_op.py index a24d4959..31e8d892 100644 --- a/vllm_mindspore/distributed/communication_op.py +++ b/vllm_mindspore/distributed/communication_op.py @@ -19,6 +19,7 @@ # 不要去照搬mindspeed的, 因为训练当中包含太多的特性, 推理只需要非常简单的通信,可以提升性能。 from typing import Any, Dict, Optional, Union +import torch from mindspore import Tensor, nn, ops from mindspore.communication.comm_func import all_reduce, broadcast @@ -48,6 +49,13 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[Tensor, # return tensor_dict # return get_tp_group().broadcast_tensor_dict(tensor_dict, src) +def cpu_broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src, group=get_tp_group().cpu_group) + class ReduceFromModelParallelRegion(nn.Cell): "All reduce the input from the model parallel region." diff --git a/vllm_mindspore/distributed/parallel_state.py b/vllm_mindspore/distributed/parallel_state.py new file mode 100644 index 00000000..697196fa --- /dev/null +++ b/vllm_mindspore/distributed/parallel_state.py @@ -0,0 +1,93 @@ +import torch +import torch.distributed +from torch.distributed import ProcessGroup + +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union) +from vllm.distributed.parallel_state import _split_tensor_dict, TensorMetadata +from vllm_mindspore.utils import atlas_inference + +def gc_broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + if not atlas_inference(): + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm_mindspore/model_executor/layers/linear.py b/vllm_mindspore/model_executor/layers/linear.py index e0851149..53ebc22a 100644 --- a/vllm_mindspore/model_executor/layers/linear.py +++ b/vllm_mindspore/model_executor/layers/linear.py @@ -388,9 +388,15 @@ class MergedColumnParallelLinear(ColumnParallelLinear): if not use_bitsandbytes_4bit: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size).contiguous() assert param_data.shape == loaded_weight.shape - # param_data.copy_(loaded_weight) - # param_data.set_data(loaded_weight) - param[shard_offset: shard_offset + shard_size, :] = loaded_weight + if len(loaded_weight.shape) == 2: + param[shard_offset: shard_offset + shard_size, :] = loaded_weight + else: + param[shard_offset: shard_offset + shard_size] = loaded_weight + else: + assert param.shape == loaded_weight.shape + if loaded_weight.dtype == ms.float32 and param.dtype == ms.float16: + loaded_weight = loaded_weight.astype(ms.float16) + param.set_data(loaded_weight.contiguous()) class QKVParallelLinear(ColumnParallelLinear): @@ -474,10 +480,10 @@ class QKVParallelLinear(ColumnParallelLinear): if not use_bitsandbytes_4bit: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size).contiguous() assert param_data.shape == loaded_weight.shape - if param.name.endswith("weight"): - self.weight[shard_offset: shard_offset + shard_size, :] = loaded_weight - if param.name.endswith("bias"): - self.bias[shard_offset: shard_offset + shard_size] = loaded_weight + if len(loaded_weight.shape) == 2: + param[shard_offset: shard_offset + shard_size, :] = loaded_weight + else: + param[shard_offset: shard_offset + shard_size] = loaded_weight # tp_rank = get_tensor_model_parallel_rank() # if shard_id is "q": # start_index = self.num_heads * tp_rank * self.head_size @@ -586,6 +592,7 @@ class RowParallelLinear(LinearBase): def weight_loader(self, param, loaded_weight): tp_rank = get_tensor_model_parallel_rank() + param_data = param.data input_dim = getattr(param, "input_dim", None) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) @@ -606,6 +613,5 @@ class RowParallelLinear(LinearBase): loaded_weight = loaded_weight.reshape(1) assert param.shape == 
loaded_weight.shape - # param_data.copy_(loaded_weight) + param_data.copy_(loaded_weight) # self.weight[:, start_idx : start_idx + shard_size] = loaded_weight - param.set_data(loaded_weight.contiguous()) diff --git a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py index 768a8238..3af4878d 100644 --- a/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm_mindspore/model_executor/layers/vocab_parallel_embedding.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from typing import List, Optional, Sequence, Tuple -from mindspore import Parameter, Tensor, mint, nn, ops +from mindspore import Parameter, Tensor, mint, nn, ops, jit from mindspore.common.dtype import typing from vllm.config import get_current_vllm_config from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -56,6 +56,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): self.gather = ops.Gather() self.bias_add = ops.Add() + # @jit def apply(self, layer: nn.Cell, x: Tensor, diff --git a/vllm_mindspore/model_executor/model_loader/weight_utils.py b/vllm_mindspore/model_executor/model_loader/weight_utils.py index 0fc4d3d2..c8edd319 100644 --- a/vllm_mindspore/model_executor/model_loader/weight_utils.py +++ b/vllm_mindspore/model_executor/model_loader/weight_utils.py @@ -23,6 +23,8 @@ import torch import mindspore as ms from mindspore import Parameter, Tensor +from vllm_mindspore.utils import atlas_inference +import numpy as np def safetensors_weights_iterator( @@ -41,8 +43,10 @@ def safetensors_weights_iterator( ): with safe_open(st_file, framework="np") as f: for name in f.keys(): - param = f.get_tensor(name) - yield name, ms.tensor(param) + x = f.get_tensor(name) + x = x.astype(np.float16) \ + if (str(x.dtype) == 'bfloat16' and atlas_inference()) else x + yield name, ms.tensor(x) def default_weight_loader(param: Parameter, diff --git a/vllm_mindspore/model_executor/models/model_base.py b/vllm_mindspore/model_executor/models/model_base.py index d2db9794..1e8685a3 100644 --- a/vllm_mindspore/model_executor/models/model_base.py +++ b/vllm_mindspore/model_executor/models/model_base.py @@ -31,13 +31,13 @@ from vllm.sequence import IntermediateTensors import vllm.envs as envs import mindspore as ms -from mindspore import Tensor, nn, mutable +from mindspore import Tensor, nn, mutable, ops from mindspore.common import dtype as mstype from vllm_mindspore.model_executor.models.attention_mask import LowerTriangularMask from vllm_mindspore.utils import STR_DTYPE_TO_MS_DTYPE from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata - +from vllm_mindspore.utils import atlas_inference class AttentionWrapper: @@ -48,11 +48,32 @@ class AttentionWrapper: vllm_config.parallel_config) head_size = vllm_config.model_config.get_head_size() num_block = 0 - self.kv_shape = [num_block, block_size, num_kv_heads, head_size] - self.kv_cache = [( - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), - ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] + if atlas_inference(): + self.kv_shape = [num_block, block_size, num_kv_heads * head_size] + self.kv_cache = [ + ( + ops.auto_generate.format_cast( + ms.mint.zeros( + self.kv_shape, dtype=vllm_config.model_config.dtype + ), + 29, + ), + ops.auto_generate.format_cast( + ms.mint.zeros( + self.kv_shape, dtype=vllm_config.model_config.dtype + ), + 29, + ), + ) + 
for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + else: + self.kv_shape = [num_block, block_size, num_kv_heads, head_size] + self.kv_cache = [( + ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), + ms.mint.zeros(self.kv_shape, dtype=vllm_config.model_config.dtype), + ) for _ in range(vllm_config.parallel_config.pipeline_parallel_size)] + self.attn_type = AttentionType.DECODER # add for v1 @@ -68,11 +89,24 @@ class MLAAttentionWrapper(AttentionWrapper): def __init__(self): super().__init__() vllm_config = get_current_vllm_config() - self.kv_cache = [ - (ms.mint.zeros(self.kv_shape, - dtype=vllm_config.model_config.dtype), ) - for _ in range(vllm_config.parallel_config.pipeline_parallel_size) - ] + if atlas_inference(): + self.kv_cache = [ + ( + ops.auto_generate.format_cast( + ms.mint.zeros( + self.kv_shape, dtype=vllm_config.model_config.dtype + ), + 29, + ), + ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + else: + self.kv_cache = [ + (ms.mint.zeros(self.kv_shape, + dtype=vllm_config.model_config.dtype), ) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] class MsModelBase: @@ -390,7 +424,8 @@ class NativeModel(MsModelBase): block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_cache_shape = (None, block_size, num_kv_heads, head_size) + kv_cache_shape = (None, block_size, num_kv_heads * head_size) if atlas_inference() \ + else (None, block_size, num_kv_heads, head_size) kv_cache_dtype = self.model_config.dtype if self.cache_config.cache_dtype == "auto" \ else self.cache_config.cache_dtype diff --git a/vllm_mindspore/model_executor/models/qwen2.py b/vllm_mindspore/model_executor/models/qwen2.py index 87c54c21..3b62385f 100644 --- a/vllm_mindspore/model_executor/models/qwen2.py +++ b/vllm_mindspore/model_executor/models/qwen2.py @@ -24,7 +24,8 @@ if TYPE_CHECKING: else: Qwen2Config = None -from mindspore import Parameter, Tensor, mint, nn +from mindspore import Parameter, Tensor, mint, nn, ops +import mindspore as ms from vllm.attention.backends.abstract import AttentionType from vllm.config import CacheConfig, VllmConfig @@ -33,6 +34,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.sequence import IntermediateTensors +from vllm_mindspore.utils import atlas_inference from vllm_mindspore.attention import Attention from vllm_mindspore.model_executor.layers.activation import SwiGLU from vllm_mindspore.model_executor.layers.layernorm import RMSNorm @@ -397,9 +399,34 @@ class Qwen2Model(nn.Cell): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) + # Norm type in weights may be f32 + if(loaded_weight.dtype != param.dtype): + loaded_weight = loaded_weight.to(dtype=param.dtype) weight_loader(param, loaded_weight) loaded_params.add(name) + def adjust_weight(params_dict): + if not atlas_inference(): + return + + target_keywords = [ + "qkv_proj.weight", + "o_proj.weight", + "gate_up_proj.weight", + "down_proj.weight", + # "lm_head.weight", + ] + + for name, param in params_dict.items(): + if any(name.endswith(keyword) for keyword in target_keywords): + cast_weight = ops.auto_generate.format_cast(param, 29) + ms.runtime.synchronize() + param.set_data(cast_weight) + + ms.runtime.synchronize() + adjust_weight(params_dict) + ms.runtime.synchronize() + return 
loaded_params diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py index 920bb230..8bb93a33 100644 --- a/vllm_mindspore/utils.py +++ b/vllm_mindspore/utils.py @@ -173,29 +173,6 @@ def is_mindone_model_backend(): == vllmModelBackendEnum.MIND_ONE) -def check_ready(): - from mindspore import set_context - - # Common environment variables of predict. - set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) - default_env = { - "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": - "FlashAttentionScore,PagedAttention", - } - env_setup(default_env) - - if os.getenv("MS_MEMPOOL_BLOCK_SIZE"): - set_context( - mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB") - - if is_mindformers_model_backend(): - logger.info("Run with Mindformers backend!") - elif is_mindone_model_backend(): - logger.info("Run with MindONE backend!") - else: - logger.info("Run with native model backend!") - - def convert_np_to_ms_dtype(value): """convert_np_to_ms_dtype""" if value.dtype == np.int8: @@ -297,3 +274,77 @@ def ms_memory_profiling( result.non_torch_increase = diff_from_create.non_torch_memory result.profile_time = diff_profile.timestamp result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa + + +def is_version_ge(current_version, base_version): + """ + return current_version >= base_version. + Check whether the current version is higher than or equal to the base version. + for current_version: 1.8.1, base_version: 1.11.0, it return False. + """ + version_split_char = '.' + if version_split_char not in base_version or version_split_char not in current_version: + raise ValueError("The version string will contain the `.`." + "For example, current_version 1.8.1, base_version: 1.11.0.") + for x, y in zip(current_version.split(version_split_char), base_version.split(version_split_char)): + if not x.isdigit() or not y.isdigit(): + continue + if int(x) != int(y): + return int(x) >= int(y) + return True + +def get_ascend_soc_version(): + """Get ascend soc version.""" + if is_version_ge(ms.__version__, "2.2.0"): + from mindspore._c_expression import MSContext + return MSContext.get_instance().get_ascend_soc_version() + ascend_chip_type = os.getenv("ASCEND_CHIP_TYPE", "UNSET") + if ascend_chip_type not in ["910a", "910b", "UNSET"]: + raise EnvironmentError(f"ASCEND_CHIP_TYPE should be in ['910a', '910b'],but get {ascend_chip_type}") + if ascend_chip_type == "UNSET": + logger.info("Environment variables need to be set manually to obtain the chip type," + "which can be set as follows: \n" + "For Atlas 800, run 'export ASCEND_CHIP_TYPE=910a' before the program runs.\n" + "For Atlas 800T A2, run 'export ASCEND_CHIP_TYPE=910b' before the program runs.\n" + "If you need to get chip information automatically, MindSpore 2.2 and above is recommended") + return ascend_chip_type + +def atlas_inference(): + device = get_ascend_soc_version() + return device in ['310p', 'ascend310p'] + +def check_ready(): + import vllm.envs as envs + from mindspore import set_context + + # Common environment variables of predict. 
+ set_context(jit_config={"jit_level": "O0", "infer_boost": "on"}) + custom_kernels = "FlashAttentionScore,PagedAttention" + if atlas_inference(): + set_context(graph_kernel_flags="--disable_pass=add_rms_norm_fusion") + custom_kernels = "InferenceMatmulSplit," + custom_kernels + ",AddRmsNorm" + + default_env = { + "MS_INTERNAL_DISABLE_CUSTOM_KERNEL_LIST": custom_kernels + } + env_setup(default_env) + + if os.getenv("MS_MEMPOOL_BLOCK_SIZE"): + set_context( + mempool_block_size=f"{os.environ['MS_MEMPOOL_BLOCK_SIZE']}GB") + + if is_mindformers_model_backend(): + logger.info("Run with Mindformers backend!") + necessary_envs = ("MINDFORMERS_MODEL_CONFIG", ) + lost_envs = [ + env_item for env_item in necessary_envs if not os.getenv(env_item) + ] + + if lost_envs: + raise RuntimeError( + f'For "MindFormers" model backend, environments {str(lost_envs)} should be set!' + ) + elif is_mindone_model_backend(): + logger.info("Run with MindONE backend!") + else: + logger.info("Run with native model backend!") diff --git a/vllm_mindspore/v1/worker/gpu_model_runner.py b/vllm_mindspore/v1/worker/gpu_model_runner.py index 7f4e3fe1..107d032b 100644 --- a/vllm_mindspore/v1/worker/gpu_model_runner.py +++ b/vllm_mindspore/v1/worker/gpu_model_runner.py @@ -21,9 +21,10 @@ from typing import Dict, Tuple, List import numpy as np import torch -from mindspore import mutable +from mindspore import mutable, ops +import mindspore as ms from vllm_mindspore.v1.attention.backends.ms_attn import MsAttentionMetadata -from vllm_mindspore.utils import get_valid_dtype +from vllm_mindspore.utils import get_valid_dtype, atlas_inference from vllm_mindspore.model_executor.layers.rotary_embedding import InferMRotaryEmbedding as MRotaryEmbedding # type: ignore[attr-defined] from vllm.v1.outputs import ModelRunnerOutput @@ -175,8 +176,17 @@ def _prepare_inputs( def create_block(shape, dtype, name=None, device=None): - from mindspore import mint - blocks = mint.empty(shape, dtype=dtype, device=device) + from mindspore.mint import empty as empty_tensor + from mindspore.common.api import _pynative_executor + blocks = empty_tensor(*shape, dtype=dtype, device=device) + if device == "Ascend" and atlas_inference(): + blocks_nz = ops.auto_generate.format_cast(blocks, 29) + _pynative_executor.sync() + import gc + del blocks + gc.collect() + ms.hal.empty_cache() + return blocks_nz return blocks @@ -210,8 +220,11 @@ def initialize_kv_cache(self, kv_cache_config) -> None: assert num_blocks >= kv_cache_config.num_blocks if isinstance(kv_cache_spec, FullAttentionSpec): kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) + num_blocks, kv_cache_spec.block_size, kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size) + if atlas_inference(): + *dims, second_last, last = kv_cache_shape + kv_cache_shape = (*dims, second_last * last) dtype = kv_cache_spec.dtype dtype = get_valid_dtype(dtype) current_cache = [] diff --git a/vllm_mindspore/worker/cache_engine.py b/vllm_mindspore/worker/cache_engine.py index 2df44ee5..8190e03b 100644 --- a/vllm_mindspore/worker/cache_engine.py +++ b/vllm_mindspore/worker/cache_engine.py @@ -18,16 +18,26 @@ """CacheEngine class for managing the KV cache.""" import mindspore as ms -from mindspore import mutable, mint +from mindspore import mutable, mint, ops from typing import List from vllm.logger import init_logger -from vllm_mindspore.utils import MsKVCache, get_valid_dtype +from vllm_mindspore.utils import MsKVCache, 
get_valid_dtype, atlas_inference logger = init_logger(__name__) def create_block(shape, dtype, name=None, device=None): - blocks = mint.empty(shape, dtype=dtype, device=device) + from mindspore.ops.function.array_func import empty as empty_tensor + from mindspore.common.api import _pynative_executor + blocks = empty_tensor(*shape, dtype=dtype, device=device) + if device == "Ascend" and atlas_inference(): + blocks_nz = ops.auto_generate.format_cast(blocks, 29) + _pynative_executor.sync() + import gc + del blocks + gc.collect() + ms.hal.empty_cache() + return blocks_nz return blocks @@ -39,6 +49,9 @@ def ms_allocate_kv_cache( """Allocates KV cache on the specified device.""" kv_cache_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size) + if atlas_inference(): + *dims, second_last, last = kv_cache_shape + kv_cache_shape = (*dims, second_last * last) kv_cache: List[MsKVCache] = [] self.dtype = get_valid_dtype(self.dtype) diff --git a/vllm_mindspore/worker/model_runner.py b/vllm_mindspore/worker/model_runner.py index 55bb26ec..706a2058 100644 --- a/vllm_mindspore/worker/model_runner.py +++ b/vllm_mindspore/worker/model_runner.py @@ -24,7 +24,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import SequenceGroupMetadata -from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE +from vllm_mindspore.utils import STR_DTYPE_TO_TENSOR_DTYPE, atlas_inference from mindspore import mutable @@ -137,7 +137,8 @@ def _dummy_run(self, block_size = self.cache_config.block_size num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) head_size = self.model_config.get_head_size() - kv_shape = [0, block_size, num_kv_heads, head_size] + kv_shape = [0, block_size, num_kv_heads * head_size] if atlas_inference() else \ + [0, block_size, num_kv_heads, head_size] kv_caches = mutable([ mutable(( mutable(torch.tensor([], dtype=kv_cache_dtype, device=self.device).reshape(kv_shape)), -- Gitee From d2d4574b1fe88346b723a0251e8d5f2d1302cec2 Mon Sep 17 00:00:00 2001 From: HighCloud Date: Wed, 9 Jul 2025 11:36:46 +0800 Subject: [PATCH 2/2] support ds weight process --- .../models/mf_models/deepseek_v3.py | 2 +- .../mf_models/deepseekv3_weight_processor.py | 265 +++++++++++++++++- .../models/mf_models/weight_processor.py | 22 +- 3 files changed, 286 insertions(+), 3 deletions(-) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index deb68eec..d4bef6c9 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -209,7 +209,7 @@ class DeepseekV3ForCausalLM(MfModelBase): def create_ptq(self, quant_type: str, quant_mode: PTQMode): """create_ptq""" - if quant_type.lower() == 'ptq': + if quant_type.lower() in ['ptq', 'ptq-duo']: cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8, act_quant_dtype=msdtype.int8, outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS, diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index 58ac64dc..707cf554 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ 
-1675,6 +1675,266 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): return w1_weight_param, w1_scale_param, w3_weight_param, w3_scale_param, w2_weight_param + def dynamic_quant_process_qkv_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict): + '''dynamic_quant_process_qkv_weight''' + qkv_concat = self.config.model.model_config.qkv_concat + # q2l_proj + q2l_weight_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.weight" + q2l_weight_param, _ = self.get_safetensor_from_file(q2l_weight_name, src_hf_dir, hf_weight_map) + q2l_bias_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.quant_bias" + q2l_bias_param, _ = self.get_safetensor_from_file(q2l_bias_name, src_hf_dir, hf_weight_map) + q2l_scale_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.dequant_scale" + q2l_scale_param, _ = self.get_safetensor_from_file(q2l_scale_name, src_hf_dir, hf_weight_map) + + q2l_quant_zp = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_zp" + q2l_quant_scale = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_scale" + q2l_quant_beta= f"model.layers.{layer_id}.attention.q2l_proj.quant_op.beta" + q2l_quant_zp_param, _ = self.get_safetensor_from_file(q2l_quant_zp, src_hf_dir, hf_weight_map) + q2l_quant_scale_param, _ = self.get_safetensor_from_file(q2l_quant_scale, src_hf_dir, hf_weight_map) + q2l_quant_beta_param, _ = self.get_safetensor_from_file(q2l_quant_beta, src_hf_dir, hf_weight_map) + + kv2l_weight_name = f"model.layers.{layer_id}.attention.kv2l._layer.weight" + kv2l_weight_param, _ = self.get_safetensor_from_file(kv2l_weight_name, src_hf_dir, hf_weight_map) + kv2l_bias_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.quant_bias" + kv2l_bias_param, _ = self.get_safetensor_from_file(kv2l_bias_name, src_hf_dir, hf_weight_map) + kv2l_scale_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.dequant_scale" + kv2l_scale_param, _ = self.get_safetensor_from_file(kv2l_scale_name, src_hf_dir, hf_weight_map) + + kv2l_quant_zp = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_zp" + kv2l_quant_scale = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_scale" + kv2l_quant_beta = f"model.layers.{layer_id}.attention.kv2l.quant_op.beta" + kv2l_quant_zp_param, _ = self.get_safetensor_from_file(kv2l_quant_zp, src_hf_dir, hf_weight_map) + kv2l_quant_scale_param, _ = self.get_safetensor_from_file(kv2l_quant_scale, src_hf_dir, hf_weight_map) + kv2l_quant_beta_param, _ = self.get_safetensor_from_file(kv2l_quant_beta, src_hf_dir, hf_weight_map) + + if qkv_concat: + qkv2l_weight_name = f"model.layers.{layer_id}.attention.qkv2l._layer.weight" + qkv2l_bias_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.quant_bias" + qkv2l_scale_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.dequant_scale" + qkv2l_quant_zp_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_zp" + qkv2l_quant_scale_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_scale" + qkv2l_quant_beta_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.beta" + + qkv2l_weight = np.concatenate((q2l_weight_param, kv2l_weight_param), 0) + parameter_dict[qkv2l_weight_name] = ms.Parameter(ms.Tensor(qkv2l_weight, ms.int8), name=qkv2l_weight_name, + requires_grad=False) + qkv2l_bias = np.concatenate((q2l_bias_param, kv2l_bias_param), 0) + parameter_dict[qkv2l_bias_name] = ms.Parameter(ms.Tensor(qkv2l_bias, ms.int32), name=qkv2l_bias_name, + requires_grad=False) + qkv2l_scale = 
np.concatenate((q2l_scale_param, kv2l_scale_param), 0) + parameter_dict[qkv2l_scale_name] = ms.Parameter(ms.Tensor(qkv2l_scale, ms.int64), name=qkv2l_scale_name, + requires_grad=False) + parameter_dict[qkv2l_quant_zp_name] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8), + name=qkv2l_quant_zp_name, requires_grad=False) + parameter_dict[qkv2l_quant_scale_name] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, ms.float16), + name=qkv2l_quant_scale_name, requires_grad=False) + parameter_dict[qkv2l_quant_beta_name] = ms.Parameter(ms.Tensor(q2l_quant_beta_param, ms.float16), + name=qkv2l_quant_beta_name, requires_grad=False) + else: + parameter_dict[q2l_weight_name] = ms.Parameter(ms.Tensor(q2l_weight_param, ms.int8), name=q2l_weight_name, + requires_grad=False) + parameter_dict[kv2l_weight_name] = ms.Parameter(ms.Tensor(kv2l_weight_param, ms.int8), + name=kv2l_weight_name, requires_grad=False) + parameter_dict[q2l_bias_name] = ms.Parameter(ms.Tensor(q2l_bias_param, ms.int32), name=q2l_bias_name, + requires_grad=False) + parameter_dict[kv2l_bias_name] = ms.Parameter(ms.Tensor(kv2l_bias_param, ms.int32), name=kv2l_bias_name, + requires_grad=False) + parameter_dict[q2l_scale_name] = ms.Parameter(ms.Tensor(q2l_scale_param, ms.int64), name=q2l_scale_name, + requires_grad=False) + parameter_dict[kv2l_scale_name] = ms.Parameter(ms.Tensor(kv2l_scale_param, ms.int64), + name=kv2l_scale_name, requires_grad=False) + parameter_dict[q2l_quant_zp] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8), name=q2l_quant_zp, + requires_grad=False) + parameter_dict[kv2l_quant_zp] = ms.Parameter(ms.Tensor(kv2l_quant_zp_param, ms.int8), name=kv2l_quant_zp, + requires_grad=False) + parameter_dict[q2l_quant_scale] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, ms.float16), + name=q2l_quant_scale, requires_grad=False) + parameter_dict[q2l_quant_beta] = ms.Parameter(ms.Tensor(q2l_quant_beta_param, ms.float16), + name=q2l_quant_beta, requires_grad=False) + parameter_dict[kv2l_quant_scale] = ms.Parameter(ms.Tensor(kv2l_quant_scale_param, ms.float16), + name=kv2l_quant_scale, requires_grad=False) + parameter_dict[kv2l_quant_beta] = ms.Parameter(ms.Tensor(kv2l_quant_beta_param, ms.float16), + name=kv2l_quant_beta, requires_grad=False) + + def dynamic_quant_process_route_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type): + """dynamic_quant_process_route_ffn_weight""" + ffn_concat = self.config.model.model_config.ffn_concat + w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight" + w1_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w1_weight_name, src_hf_dir, hf_weight_map, + split_axis=1) + + w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale" + w1_scale_param, _ = self.get_safetensor_from_file_split_tp_group(w1_scale_name, src_hf_dir, hf_weight_map, + split_axis=1) + + w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w3_weight_name, src_hf_dir, hf_weight_map, + split_axis=1) + + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale" + w3_scale_param, _ = self.get_safetensor_from_file_split_tp_group(w3_scale_name, src_hf_dir, hf_weight_map, + split_axis=1) + + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=1), dtype=ms.int8) + 
parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=1), dtype=ms.float32) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name, + requires_grad=False) + + parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, ms.float32), + name=w1_scale_name, requires_grad=False) + parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, ms.float32), + name=w3_scale_name, requires_grad=False) + + def dynamic_quant_process_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type): + """dynamic_quant_process_ffn_weight""" + + ffn_concat = self.config.model.model_config.ffn_concat + w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight" + w1_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w1_weight_name, src_hf_dir, hf_weight_map, + split_axis=0) + w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale" + w1_scale_param, _ = self.get_safetensor_from_file_split_tp_group(w1_scale_name, src_hf_dir, hf_weight_map, + split_axis=0) + + w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w3_weight_name, src_hf_dir, hf_weight_map, + split_axis=0) + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale" + w3_scale_param, _ = self.get_safetensor_from_file_split_tp_group(w3_scale_name, src_hf_dir, hf_weight_map, + split_axis=0) + w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight" + w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale" + w2_weight_param, _ = self.get_safetensor_from_file_split_tp_group(w2_weight_name, src_hf_dir, hf_weight_map, + split_axis=1) + w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map) + + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale" + concat_scale_type = convert_np_to_ms_dtype(w1_scale_param) + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), dtype=concat_scale_type) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name, + requires_grad=False) + w1_scale_type = convert_np_to_ms_dtype(w1_scale_param) + parameter_dict[w1_scale_name] = 
ms.Parameter(ms.Tensor(w1_scale_param, w1_scale_type), + name=w1_scale_name, requires_grad=False) + parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, w1_scale_type), + name=w3_scale_name, requires_grad=False) + + parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name, + requires_grad=False) + w2_scale_type = convert_np_to_ms_dtype(w2_scale_param) + parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, w2_scale_type), + name=w2_scale_name, requires_grad=False) + + def infer_dynamic_quant_get_value(self, param_name, src_hf_dir, hf_weight_map, no_need_split_layer): + '''infer_dynamic_quant_get_value''' + + if any([name in param_name for name in no_need_split_layer]): + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, + hf_weight_map) + elif any([name in param_name for name in [".l2q_proj."]]): + if param_name.endswith(".weight") or "matmul" in param_name: + value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, + hf_weight_map, + split_axis=0) + else: + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, + hf_weight_map) + elif any([name in param_name for name in [".wo.", "feed_forward.w2", "shared_experts.w2"]]): + if param_name.endswith(".weight"): + value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, + hf_weight_map, + split_axis=1) + else: + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, + hf_weight_map) + elif ".routed_experts.ffn.w2" in param_name: + if param_name.endswith(".weight"): + value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map, + split_axis=2) + else: + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map) + elif any([name in param_name for name in ["lkv2kv_k_nope", "absorb", "lkv2kv_v"]]): + value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map, + split_axis=0) + elif "lm_head" in param_name: + if not self.config.parallel_config.vocab_emb_dp: + value, _ = self.get_safetensor_from_file_split_tp_group(param_name, src_hf_dir, hf_weight_map, + split_axis=0) + else: + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map) + else: + raise ValueError(f"not found layer {param_name}, please check safetensors file.") + return value + + def infer_dynamic_quant_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map): + '''infer_dynamic_quant_net_ms_convert_layer_weight''' + parameter_dict = {} + start_layer_index, end_layer_index = self.get_layer_index(num_layers) + + no_need_split_layer = ["tok_embeddings", "norm", "routed_experts.router.dense", + "routed_experts.router.e_score_correction_bias", + "topk_bias"] + network_names = [] + for m in self.network.parameters_and_names(): + network_names.append(m[0]) + for layer_id in tqdm(range(start_layer_index, end_layer_index), desc="qkv/ffn params load"): + if layer_id >= 3: + self.dynamic_quant_process_route_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward.routed_experts.ffn") + self.dynamic_quant_process_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward.shared_experts") + + else: + self.dynamic_quant_process_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward") + self.dynamic_quant_process_qkv_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict) + + skip_layer = ["feed_forward.routed_experts.ffn.w1", 
"feed_forward.shared_experts.w1", "feed_forward.w1", + "feed_forward.routed_experts.ffn.w3", "feed_forward.shared_experts.w3", "feed_forward.w3", + "feed_forward.routed_experts.ffn.w_gate_hidden", "feed_forward.shared_experts.w_gate_hidden", + "feed_forward.w_gate_hidden", "attention.kv2l", "attention.q2l_proj", "attention.qkv2l"] + + for param_name, _ in tqdm(hf_weight_map.items(), desc="remaining params load"): + if param_name not in network_names: + continue + + if any([name in param_name for name in skip_layer]): + continue + + value = self.infer_dynamic_quant_get_value(param_name, src_hf_dir, hf_weight_map, no_need_split_layer) + dst_dtype = convert_np_to_ms_dtype(value) + + parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dst_dtype), + name=param_name, requires_grad=False) + + param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) + print(f"dsquant param_not_load:{param_not_load}") + print(f"dsquant ckpt_not_load:{ckpt_not_load}") + def smooth_quant_process_shared_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type): @@ -2155,7 +2415,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): quantization_config = self.config.model.model_config.quantization_config quant_method = quantization_config.quant_method if quantization_config else None - support_quant_method = ["gptq-pergroup", "smoothquant", "osl"] + support_quant_method = ["gptq-pergroup", "smoothquant", "osl", 'ptq-duo'] if not quant_method or (quant_method not in support_quant_method) and \ not is_mtp_model: self.infer_convert_outer_weight(src_hf_dir, hf_weight_map) @@ -2172,6 +2432,9 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): self.infer_smooth_quant_net_ms_convert_layer_weight( src_hf_dir, self.num_layers, hf_weight_map) return + if quant_method and quant_method == "ptq-duo": + self.infer_dynamic_quant_net_ms_convert_layer_weight(src_hf_dir, self.num_layers, hf_weight_map) + return enable_tqdm = rank_id == 0 mtp_layers = self.config.model.model_config.num_nextn_predict_layers diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py index 89d786eb..6c8612eb 100644 --- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py @@ -22,7 +22,7 @@ from enum import Enum from safetensors import safe_open from mindspore.communication.management import get_rank, get_group_size from mindformers.parallel_core.inference.utils import get_tp_world_size -from mindformers.parallel_core.inference.parallel_state import get_data_parallel_world_size +from mindformers.parallel_core.inference.parallel_state import get_data_parallel_world_size, get_pp_world_size class EPMethod(Enum): @@ -71,6 +71,26 @@ class BaseWeightProcessor: self.parameter_dict = {} self.file_handles = {} + def get_layer_index(self, num_layers): + pp_nums = get_pp_world_size() + tp_nums = self.tp_group_size + offset = self.config.model.model_config.offset + offset_index = self.global_rank_id // tp_nums + stage_layers = num_layers // pp_nums + start_layer_index = offset_index * stage_layers + end_layer_index = start_layer_index + stage_layers + + if pp_nums > 1 and num_layers % pp_nums != 0: + if isinstance(offset, list): + raise ValueError(f"The parameter 'offset' is expected to be a list, but got {offset} instead." 
+ f" Please check whether your offset parameter is set correctly!") + for num in range(0, offset_index): + start_layer_index += offset[num] + end_layer_index += offset[num] + end_layer_index += offset[offset_index] + + return start_layer_index, end_layer_index + def get_file_handles(self, filename): if filename not in self.file_handles: fp = safe_open(filename, framework="np") -- Gitee
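
A minimal standalone sketch (illustrative only, not part of the patch) of the layer partitioning that the new BaseWeightProcessor.get_layer_index() performs: a uniform base split of num_layers across pipeline stages, plus a per-stage `offset` list of extra layers when the split is uneven. The helper name `split_layers` and the example numbers are hypothetical, chosen only to show the intended behaviour.

    from typing import List, Tuple

    def split_layers(num_layers: int, pp_size: int, stage_index: int,
                     offset: List[int]) -> Tuple[int, int]:
        """Return the [start, end) layer range owned by `stage_index`."""
        base = num_layers // pp_size
        start = stage_index * base
        end = start + base
        if pp_size > 1 and num_layers % pp_size != 0:
            # Shift this stage's window by the extra layers assigned to
            # earlier stages, then widen it by this stage's own extra count.
            start += sum(offset[:stage_index])
            end += sum(offset[:stage_index]) + offset[stage_index]
        return start, end

    # 61 layers over 4 stages, one extra layer on stage 0:
    # stages own [0, 16), [16, 31), [31, 46), [46, 61).
    assert split_layers(61, 4, 0, [1, 0, 0, 0]) == (0, 16)
    assert split_layers(61, 4, 3, [1, 0, 0, 0]) == (46, 61)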