diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3_moe.py b/vllm_mindspore/model_executor/models/mf_models/qwen3_moe.py
new file mode 100644
index 0000000000000000000000000000000000000000..abce44f6acb0fee25404af004a82d3944c27446a
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen3_moe.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+# Copyright 2025 Huawei Technologies Co., Ltd
+# Copyright 2024 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+from typing import Iterable, Set, Tuple
+
+from vllm.config import VllmConfig
+from vllm.config import get_current_vllm_config
+from vllm.logger import init_logger
+
+from mindspore import Tensor, JitConfig
+from mindspore.nn.utils import no_init_parameters
+
+from mindformers.models.llama import LlamaConfig as LlamaConfig_MF
+from research.qwen3.qwen3_moe import (
+    ParallelQwen3MoeForCausalLM as ParallelQwen3MoeForCausalLM_MF,
+)
+
+from vllm_mindspore.model_executor.layers.sampler import get_sampler
+from vllm_mindspore.model_executor.models.model_base import Fake_Attention
+from vllm_mindspore.model_executor.models.mf_models.mf_model_base import MfModelBase
+from vllm_mindspore.model_executor.models.mf_models.qwen3_moe_weight_processor import Qwen3MoeWeightProcessor
+
+
+logger = init_logger(__name__)
+
+
+class Qwen3MoeForCausalLM(MfModelBase):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+        super(Qwen3MoeForCausalLM, self).__init__(vllm_config=vllm_config, prefix=prefix)
+        self.mf_kvcaches_init = False
+
+        self.sampler = get_sampler()
+        self.set_modules({"model": self.network})
+
+        self.kv_caches = [Fake_Attention() for _ in range(self.mf_model_config.num_layers)]
+        compilation_config = get_current_vllm_config().compilation_config
+
+        if prefix in compilation_config.static_forward_context:
+            raise ValueError(f"Duplicate layer name: {prefix}")
+        for i in range(self.mf_model_config.num_layers):
+            compilation_config.static_forward_context[str(i)] = self.kv_caches[i]
+
+        self.set_flags = False
+
+    def _generate_model_config(self):
+        self.mf_config.load_checkpoint = self.get_model_path()
+        self.mf_model_config = LlamaConfig_MF(**self.mf_config.model.model_config)
+        if self.mf_config.moe_config:
+            self.mf_model_config.moe_config = self.mf_config.moe_config
+        self.mf_model_config.return_hidden_states = True
+
+        # Qwen QKV concat will be supported in the next version.
+        self.mf_model_config.qkv_concat = False
+        setattr(self.mf_model_config, 'npu_mem_size', -1)
+        self.mf_config.model.model_config.qkv_concat = False
+
+    def _create_network(self):
+        # Initialize the network
+        with no_init_parameters():  # Delay parameter initialization until weights are loaded
+            network = ParallelQwen3MoeForCausalLM_MF(self.mf_model_config)
+        return network, network.lm_head
+
+    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> Set[str]:
+        weight_processor = Qwen3MoeWeightProcessor(self.mf_config, self.network, False)
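+        # Load the sharded safetensors checkpoint from the HuggingFace-format directory.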
+        weight_processor.load_safetensors_shard(self.mf_config.load_checkpoint)
+
+        self.network.set_dynamic_inputs()
+        dynamic_hidden_states = Tensor(shape=[None, None], dtype=self.mf_model_config.compute_dtype)
+        self.lm_head.set_inputs(dynamic_hidden_states)
+        return None
diff --git a/vllm_mindspore/model_executor/models/mf_models/qwen3_moe_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/qwen3_moe_weight_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c689832b9e6f8c672ad91bc3082602798b25ca74
--- /dev/null
+++ b/vllm_mindspore/model_executor/models/mf_models/qwen3_moe_weight_processor.py
@@ -0,0 +1,280 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""
+Transform a HuggingFace safetensors checkpoint into MindSpore parameters.
+"""
+import os
+import json
+import gc
+import numpy as np
+from tqdm import tqdm
+from safetensors import safe_open
+import mindspore as ms
+from mindspore.communication.management import get_rank
+
+from vllm_mindspore.model_executor.models.mf_models.weight_processor import BaseWeightProcessor
+
+
+class Qwen3MoeWeightProcessor(BaseWeightProcessor):
+    r"""
+    Provide Qwen3 MoE model weight loading and sharding.
+    Args:
+        config: The MindFormers config of the Qwen3 MoE model.
+        network (ParallelQwen3MoeForCausalLM): The Qwen3 MoE network.
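+        is_quant (bool): Whether the checkpoint holds quantized weights.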
+
+    """
+
+    def __init__(self, config, network, is_quant):
+        super().__init__(config, network, is_quant)
+
+    def infer_convert_outer_weight(self, src_hf_dir, hf_weight_map):
+        """convert weights that live outside the decoder layers"""
+        embed_tokens_hf_name = "model.embed_tokens.weight"
+        embed_tokens_ms_name = self.convert_weight_name(embed_tokens_hf_name)
+        if self.config.parallel_config.vocab_emb_dp:
+            np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map)
+        else:
+            np_data, _ = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map,
+                                                       is_split_param=True, split_axis=0)
+        self.parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
+                                                                 name=embed_tokens_ms_name,
+                                                                 requires_grad=False)
+
+        norm_hf_name = "model.norm.weight"
+        norm_ms_name = self.convert_weight_name(norm_hf_name)
+        np_data, _ = self.get_safetensor_from_file(norm_hf_name, src_hf_dir, hf_weight_map)
+        self.parameter_dict[norm_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
+                                                         name=norm_ms_name,
+                                                         requires_grad=False)
+
+        lm_head_hf_name = "lm_head.weight"
+        lm_head_ms_name = self.convert_weight_name(lm_head_hf_name)
+        if not self.config.model.model_config.tie_word_embeddings:
+            if not self.config.parallel_config.vocab_emb_dp:
+                np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=0)
+            else:
+                np_data, _ = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map)
+            self.parameter_dict[lm_head_ms_name] = ms.Parameter(ms.from_numpy(np_data).astype(ms.bfloat16),
+                                                                name=lm_head_ms_name,
+                                                                requires_grad=False)
+
+    def convert_weight_name(self, weight_name: str):
+        """map HuggingFace weight names to MindSpore weight names"""
+        weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight')
+        weight_name = weight_name.replace('self_attn.q_proj.', 'attention.wq.')
+        weight_name = weight_name.replace('self_attn.k_proj.', 'attention.wk.')
+        weight_name = weight_name.replace('self_attn.v_proj.', 'attention.wv.')
+        weight_name = weight_name.replace('self_attn.o_proj.', 'attention.wo.')
+
+        weight_name = weight_name.replace('mlp.gate_proj.', 'feed_forward.w1.')
+        weight_name = weight_name.replace('mlp.down_proj.', 'feed_forward.w2.')
+        weight_name = weight_name.replace('mlp.up_proj.', 'feed_forward.w3.')
+        weight_name = weight_name.replace('.input_layernorm.', '.attention_norm.')
+        weight_name = weight_name.replace('.post_attention_layernorm.', '.ffn_norm.')
+        weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight')
+        weight_name = weight_name.replace('mlp.gate.e_score_correction_bias',
+                                          'feed_forward.routed_experts.router.e_score_correction_bias')
+        weight_name = weight_name.replace('self_attn.q_norm.', 'attention.q_norm.')
+        weight_name = weight_name.replace('self_attn.k_norm.', 'attention.k_norm.')
+        weight_name = weight_name.replace('mlp.gate.weight', 'feed_forward.routed_experts.router.dense.weight')
+        return weight_name
+
+    def infer_process_dense_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map):
+        """process MoE router and routed expert weights"""
+        ffn_concat = self.config.model.model_config.ffn_concat
+        num_router_experts = self.config.moe_config.expert_num
+
+        # router expert dense
+        router_dense_hf_name = f"model.layers.{layer_id}.mlp.gate.weight"
+        router_dense_ms_name = self.convert_weight_name(router_dense_hf_name)
+        router_dense_ms_param, _ = self.get_safetensor_from_file(router_dense_hf_name, src_hf_dir, hf_weight_map)
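+        # The router (gate) weight is small and is kept replicated on every rank.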
+        self.parameter_dict[router_dense_ms_name] = ms.Parameter(
+            ms.from_numpy(router_dense_ms_param).astype(ms.bfloat16),
+            name=router_dense_ms_name, requires_grad=False)
+
+        w1_list = []
+        w2_list = []
+        w3_list = []
+
+        w1_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w1.weight"
+        w2_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w2.weight"
+        w3_ms_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w3.weight"
+
+        for index in range(0, num_router_experts):
+            w1_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.gate_proj.weight"
+            w1_ms_param, _ = self.get_safetensor_from_file(w1_hf_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=0)
+
+            w2_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.down_proj.weight"
+            w2_ms_param, _ = self.get_safetensor_from_file(w2_hf_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=1)
+
+            w3_hf_name = f"model.layers.{layer_id}.mlp.experts.{index}.up_proj.weight"
+            w3_ms_param, _ = self.get_safetensor_from_file(w3_hf_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=0)
+
+            w1_list.append(w1_ms_param)
+            w2_list.append(w2_ms_param)
+            w3_list.append(w3_ms_param)
+
+        w1_ms_stack_param = np.stack(w1_list, axis=0)
+        w2_ms_stack_param = np.stack(w2_list, axis=0)
+        w3_ms_stack_param = np.stack(w3_list, axis=0)
+
+        if ffn_concat:
+            w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden.weight"
+            w_gate_hidden_name = w_gate_hidden_name if layer_id < self.num_layers else \
+                self.convert_mtp_weight_name(w_gate_hidden_name)
+            w_gate_hidden_np = np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=1)
+            w_gate_hidden_param = ms.from_numpy(w_gate_hidden_np).permute(0, 2, 1).astype(dtype=ms.bfloat16)
+            self.parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param,
+                                                                   name=w_gate_hidden_name,
+                                                                   requires_grad=False)
+        else:
+            w1_ms_stack_param = ms.from_numpy(w1_ms_stack_param).permute(0, 2, 1).astype(ms.bfloat16)
+            self.parameter_dict[w1_ms_name] = ms.Parameter(w1_ms_stack_param,
+                                                           name=w1_ms_name,
+                                                           requires_grad=False)
+
+            w3_ms_stack_param = ms.from_numpy(w3_ms_stack_param).permute(0, 2, 1).astype(ms.bfloat16)
+            self.parameter_dict[w3_ms_name] = ms.Parameter(w3_ms_stack_param,
+                                                           name=w3_ms_name,
+                                                           requires_grad=False)
+
+        w2_ms_stack_param = ms.from_numpy(w2_ms_stack_param).permute(0, 2, 1).astype(ms.bfloat16)
+        self.parameter_dict[w2_ms_name] = ms.Parameter(w2_ms_stack_param,
+                                                       name=w2_ms_name,
+                                                       requires_grad=False)
+
+    def infer_process_attention_weight(self, src_hf_dir, layer_id, hf_weight_map):
+        """infer process attention weight"""
+        qkv_concat = self.config.model.model_config.qkv_concat
+        # wq
+        wq_hf_name = f"model.layers.{layer_id}.self_attn.q_proj.weight"
+        wq_ms_name = self.convert_weight_name(wq_hf_name)
+        wq_ms_param, _ = self.get_safetensor_from_file(wq_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+                                                       split_axis=0)
+
+        # wq_norm
+        q_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_norm.weight"
+        q_norm_ms_name = self.convert_weight_name(q_norm_hf_name)
+        q_norm_ms_param, _ = self.get_safetensor_from_file(q_norm_hf_name, src_hf_dir, hf_weight_map)
+        self.parameter_dict[q_norm_ms_name] = ms.Parameter(ms.Tensor(q_norm_ms_param, ms.bfloat16),
+                                                           name=q_norm_ms_name,
+                                                           requires_grad=False)
+
+        # wk
+        wk_hf_name = f"model.layers.{layer_id}.self_attn.k_proj.weight"
+        wk_ms_name = self.convert_weight_name(wk_hf_name)
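+        # k_proj is split along axis 0 (the output dimension) for tensor parallelism.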
+        wk_ms_param, _ = self.get_safetensor_from_file(wk_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+                                                       split_axis=0)
+
+        # wk_norm
+        k_norm_hf_name = f"model.layers.{layer_id}.self_attn.k_norm.weight"
+        k_norm_ms_name = self.convert_weight_name(k_norm_hf_name)
+        k_norm_ms_param, _ = self.get_safetensor_from_file(k_norm_hf_name, src_hf_dir, hf_weight_map)
+        self.parameter_dict[k_norm_ms_name] = ms.Parameter(ms.Tensor(k_norm_ms_param, ms.bfloat16),
+                                                           name=k_norm_ms_name,
+                                                           requires_grad=False)
+
+        # wv
+        wv_hf_name = f"model.layers.{layer_id}.self_attn.v_proj.weight"
+        wv_ms_name = self.convert_weight_name(wv_hf_name)
+        wv_ms_param, _ = self.get_safetensor_from_file(wv_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+                                                       split_axis=0)
+
+        if qkv_concat:
+            w_qkv_name = f"model.layers.{layer_id}.attention.w_qkv.weight"
+            w_qkv_param = np.concatenate((wq_ms_param, wk_ms_param, wv_ms_param), axis=0)
+            w_qkv_param = ms.from_numpy(w_qkv_param).astype(ms.bfloat16)
+            self.parameter_dict[w_qkv_name] = ms.Parameter(w_qkv_param, name=w_qkv_name, requires_grad=False)
+        else:
+            self.parameter_dict[wq_ms_name] = ms.Parameter(ms.from_numpy(wq_ms_param).astype(ms.bfloat16),
+                                                           name=wq_ms_name,
+                                                           requires_grad=False)
+            self.parameter_dict[wk_ms_name] = ms.Parameter(ms.from_numpy(wk_ms_param).astype(ms.bfloat16),
+                                                           name=wk_ms_name,
+                                                           requires_grad=False)
+            self.parameter_dict[wv_ms_name] = ms.Parameter(ms.from_numpy(wv_ms_param).astype(ms.bfloat16),
+                                                           name=wv_ms_name,
+                                                           requires_grad=False)
+        # wo
+        wo_hf_name = f"model.layers.{layer_id}.self_attn.o_proj.weight"
+        wo_ms_name = self.convert_weight_name(wo_hf_name)
+        wo_ms_param, _ = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map, is_split_param=True,
+                                                       split_axis=1)
+        self.parameter_dict[wo_ms_name] = ms.Parameter(ms.from_numpy(wo_ms_param).astype(ms.bfloat16),
+                                                       name=wo_ms_name,
+                                                       requires_grad=False)
+
+    def infer_process_norm_weight(self, src_hf_dir, layer_id, hf_weight_map):
+        """infer process layer norm weight"""
+        # attention_norm
+        attention_norm_hf_name = f"model.layers.{layer_id}.input_layernorm.weight"
+        attention_norm_ms_name = self.convert_weight_name(attention_norm_hf_name)
+        attention_norm_ms_param, _ = self.get_safetensor_from_file(attention_norm_hf_name,
+                                                                   src_hf_dir,
+                                                                   hf_weight_map)
+        self.parameter_dict[attention_norm_ms_name] = ms.Parameter(
+            ms.from_numpy(attention_norm_ms_param).astype(ms.bfloat16),
+            name=attention_norm_ms_name,
+            requires_grad=False)
+
+        # ffn_norm
+        ffn_norm_hf_name = f"model.layers.{layer_id}.post_attention_layernorm.weight"
+        ffn_norm_ms_name = self.convert_weight_name(ffn_norm_hf_name)
+        ffn_norm_ms_param, _ = self.get_safetensor_from_file(ffn_norm_hf_name, src_hf_dir, hf_weight_map)
+        self.parameter_dict[ffn_norm_ms_name] = ms.Parameter(
+            ms.from_numpy(ffn_norm_ms_param).astype(ms.bfloat16),
+            name=ffn_norm_ms_name,
+            requires_grad=False)
+
+    def infer_convert_layer_weight(self, src_hf_dir, layer_id, hf_weight_map):
+        """infer convert layer weight"""
+        self.infer_process_attention_weight(src_hf_dir, layer_id, hf_weight_map)
+        self.infer_process_dense_ffn_weight(src_hf_dir, layer_id, hf_weight_map)
+        self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map)
+
+    def load_safetensors_shard(self, src_hf_dir):
+        """load Qwen3 MoE safetensors and shard them across ranks"""
+        rank_id = get_rank()
+        param_json_path = ""
+        for file in os.listdir(src_hf_dir):
+            if file.endswith('index.json'):
+                param_json_path = os.path.join(src_hf_dir, file)
+                break
+
+        hf_weight_map = {}
+        if os.path.exists(param_json_path):
+            with open(param_json_path, "r") as fp:
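+                # The index file maps each weight name to the safetensors shard that stores it.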
+                hf_weight_map = json.load(fp)['weight_map']
+        else:
+            # only one safetensors file, build hf_weight_map from its keys
+            safetensor_file = "model.safetensors"
+            with safe_open(f"{src_hf_dir}/{safetensor_file}", framework="np") as sf_file:
+                all_keys = sf_file.keys()
+                for key in all_keys:
+                    hf_weight_map[str(key).strip()] = safetensor_file
+
+        self.infer_convert_outer_weight(src_hf_dir, hf_weight_map)
+        num_layers = self.config.model.model_config.num_layers
+        enable_tqdm = rank_id == 0
+        for layer_id in tqdm(range(num_layers), desc="Weight loading", disable=not enable_tqdm):
+            self.infer_convert_layer_weight(src_hf_dir, layer_id, hf_weight_map)
+
+        param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, self.parameter_dict)
+        del self.parameter_dict
+        gc.collect()
+
diff --git a/vllm_mindspore/model_executor/models/registry.py b/vllm_mindspore/model_executor/models/registry.py
index 9138e725756631a2eac206ee089bc5021eff9dc9..3de45918f47e3eee0b738ed97ab52b0b1ceeaf37 100644
--- a/vllm_mindspore/model_executor/models/registry.py
+++ b/vllm_mindspore/model_executor/models/registry.py
@@ -39,6 +39,7 @@ _MINDFORMERS_MODELS = {
     "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepseekV3MTPForCausalLM"),
     "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
+    "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
 }
 
 MindSporeModelRegistry = _ModelRegistry(