From 1560c2fbf972ad4d5d1a5859e529c564b4880eca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=83=E5=90=89=E6=96=8C?= Date: Mon, 17 Mar 2025 20:11:26 +0800 Subject: [PATCH] optimize weights load --- .../mf_models/deepseekv3_infer_parallelism.py | 179 +++++++++--------- .../models/mf_models/model_parallelism.py | 70 ++++--- 2 files changed, 128 insertions(+), 121 deletions(-) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_parallelism.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_parallelism.py index 6a713f83..24a41a40 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_parallelism.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_parallelism.py @@ -20,7 +20,7 @@ import os import time import json import numpy as np - +from tqdm import tqdm import mindspore as ms from vllm_mindspore.model_executor.models.mf_models.model_parallelism import BaseModelParallelism @@ -116,7 +116,7 @@ class DeepseekInferParallelism(BaseModelParallelism): router_dense_hf_name = f"model.layers.{layer_id}.mlp.gate.weight" router_dense_ms_name = self.quant_convert_weight_name(router_dense_hf_name) router_dense_ms_param = self.get_safetensor_from_file(router_dense_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[router_dense_ms_name] = ms.Parameter(ms.Tensor(router_dense_ms_param, ms.bfloat16), + parameter_dict[router_dense_ms_name] = ms.Parameter(ms.Tensor.from_numpy(router_dense_ms_param).astype(ms.bfloat16), name=router_dense_ms_name, requires_grad=False) # e_score_correction_bias @@ -125,7 +125,7 @@ class DeepseekInferParallelism(BaseModelParallelism): e_score_correction_bias_ms_param = self.get_safetensor_from_file(e_score_correction_bias_hf_name, src_hf_dir, hf_weight_map) parameter_dict[e_score_correction_bias_ms_name] = ms.Parameter( - ms.Tensor(e_score_correction_bias_ms_param, ms.float32), + ms.Tensor.from_numpy(e_score_correction_bias_ms_param).astype(ms.float32), name=e_score_correction_bias_ms_name, requires_grad=False) w1_list = [] @@ -180,9 +180,9 @@ class DeepseekInferParallelism(BaseModelParallelism): w2_scale_list.append(w2_scale_ms_param) w3_scale_list.append(w3_scale_ms_param) - w1_ms_stack_param = np.stack(w1_list, axis=0).transpose(0, 2, 1) - w2_ms_stack_param = np.stack(w2_list, axis=0).transpose(0, 2, 1) - w3_ms_stack_param = np.stack(w3_list, axis=0).transpose(0, 2, 1) + w1_ms_stack_param = np.stack(w1_list, axis=0) + w2_ms_stack_param = np.stack(w2_list, axis=0) + w3_ms_stack_param = np.stack(w3_list, axis=0) w1_scale_ms_stack_param = np.stack(w1_scale_list, axis=0) w2_scale_ms_stack_param = np.stack(w2_scale_list, axis=0) @@ -191,36 +191,39 @@ class DeepseekInferParallelism(BaseModelParallelism): if ffn_concat: # w_gate_hidden w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden._layer.weight" - w_gate_hidden_param = ms.Tensor(np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=2), - dtype=ms.int8) + w_gate_hidden_param = ms.Tensor.from_numpy( + np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=1)).transpose(0, 2, 1).astype(ms.int8) parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name, requires_grad=False) # w_scale_gate_hidden w_scale_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden._layer.matmul.weight_scale" - w_scale_gate_hidden_param = ms.Tensor( - np.concatenate([w1_scale_ms_stack_param, w3_scale_ms_stack_param], axis=1), dtype=ms.bfloat16) + w_scale_gate_hidden_param = ms.Tensor.from_numpy( + np.concatenate([w1_scale_ms_stack_param, w3_scale_ms_stack_param], axis=1)).astype(ms.bfloat16) parameter_dict[w_scale_gate_hidden_name] = ms.Parameter(w_scale_gate_hidden_param, name=w_scale_gate_hidden_name, requires_grad=False) else: # w1 w3 - parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor(w1_ms_stack_param, ms.int8), name=w1_ms_name, + parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_ms_stack_param).transpose(0, 2, 1).astype(ms.int8), + name=w1_ms_name, requires_grad=False) - parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor(w3_ms_stack_param, ms.int8), name=w3_ms_name, + parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_ms_stack_param).transpose(0, 2, 1).astype(ms.int8), + name=w3_ms_name, requires_grad=False) # w1_scale w3_scale - parameter_dict[w1_scale_ms_name] = ms.Parameter(ms.Tensor(w1_scale_ms_stack_param, ms.bfloat16), + parameter_dict[w1_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_scale_ms_stack_param).astype(ms.bfloat16), name=w1_ms_name, requires_grad=False) - parameter_dict[w3_scale_ms_name] = ms.Parameter(ms.Tensor(w3_scale_ms_stack_param, ms.bfloat16), + parameter_dict[w3_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_scale_ms_stack_param).astype(ms.bfloat16), name=w3_ms_name, requires_grad=False) - parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor(w2_ms_stack_param, ms.int8), name=w2_ms_name, + parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_ms_stack_param).transpose(0, 2, 1).astype(ms.int8), + name=w2_ms_name, requires_grad=False) - parameter_dict[w2_scale_ms_name] = ms.Parameter(ms.Tensor(w2_scale_ms_stack_param, ms.bfloat16), + parameter_dict[w2_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_scale_ms_stack_param).astype(ms.bfloat16), name=w2_scale_ms_name, requires_grad=False) param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -260,39 +263,38 @@ class DeepseekInferParallelism(BaseModelParallelism): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden._layer.weight" - w_gate_hidden_param = ms.Tensor(np.concatenate([w1_ms_param, w3_ms_param], axis=0), dtype=ms.int8) + w_gate_hidden_param = ms.Tensor.from_numpy(np.concatenate([w1_ms_param, w3_ms_param], axis=0)).astype(ms.int8) parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name, requires_grad=False) w_scale_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden._layer.matmul.weight_scale" - w_scale_gate_hidden_param = ms.Tensor( - np.concatenate([w1_scale_ms_param, w3_scale_ms_param], axis=0), - dtype=ms.bfloat16) + w_scale_gate_hidden_param = ms.Tensor.from_numpy( + np.concatenate([w1_scale_ms_param, w3_scale_ms_param], axis=0)).astype(ms.bfloat16) parameter_dict[w_scale_gate_hidden_name] = ms.Parameter(w_scale_gate_hidden_param, name=w_scale_gate_hidden_name, requires_grad=False) else: - parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor(w1_ms_param, ms.int8), + parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_ms_param).astype(ms.int8), name=w1_ms_name, requires_grad=False) - parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor(w3_ms_param, ms.int8), + parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_ms_param).astype(ms.int8), name=w3_ms_name, requires_grad=False) - parameter_dict[w1_scale_ms_name] = ms.Parameter(ms.Tensor(w1_scale_ms_param, ms.bfloat16), + parameter_dict[w1_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_scale_ms_param).astype(ms.bfloat16), name=w1_ms_name, requires_grad=False) - parameter_dict[w3_scale_ms_name] = ms.Parameter(ms.Tensor(w3_scale_ms_param, ms.bfloat16), + parameter_dict[w3_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_scale_ms_param).astype(ms.bfloat16), name=w3_ms_name, requires_grad=False) - parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor(w2_ms_param, ms.int8), + parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_ms_param).astype(ms.int8), name=w2_ms_name, requires_grad=False) - parameter_dict[w2_scale_ms_name] = ms.Parameter(ms.Tensor(w2_scale_ms_param, ms.bfloat16), + parameter_dict[w2_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_scale_ms_param).astype(ms.bfloat16), name=w2_ms_name, requires_grad=False) param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -342,39 +344,37 @@ class DeepseekInferParallelism(BaseModelParallelism): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden._layer.weight" - w_gate_hidden_param = ms.Tensor(np.concatenate([w1_ms_param, w3_ms_param], axis=0), - dtype=ms.int8) + w_gate_hidden_param = ms.Tensor.from_numpy(np.concatenate([w1_ms_param, w3_ms_param], axis=0)).astype(ms.int8) parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name, requires_grad=False) w_scale_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden._layer.matmul.weight_scale" - w_scale_gate_hidden_param = ms.Tensor( - np.concatenate([w1_scale_ms_param, w3_scale_ms_param], axis=0), - dtype=ms.bfloat16) + w_scale_gate_hidden_param = ms.Tensor.from_numpy( + np.concatenate([w1_scale_ms_param, w3_scale_ms_param], axis=0)).astype(ms.bfloat16) parameter_dict[w_scale_gate_hidden_name] = ms.Parameter(w_scale_gate_hidden_param, name=w_scale_gate_hidden_name, requires_grad=False) else: - parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor(w1_ms_param, ms.int8), + parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_ms_param).astype(ms.int8), name=w1_ms_name, requires_grad=False) - parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor(w3_ms_param, ms.int8), + parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_ms_param).astype(ms.int8), name=w3_ms_name, requires_grad=False) - parameter_dict[w1_scale_ms_name] = ms.Parameter(ms.Tensor(w1_scale_ms_param, ms.bfloat16), + parameter_dict[w1_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_scale_ms_param).astype(ms.bfloat16), name=w1_scale_ms_name, requires_grad=False) - parameter_dict[w3_scale_ms_name] = ms.Parameter(ms.Tensor(w3_scale_ms_param, ms.bfloat16), + parameter_dict[w3_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_scale_ms_param).astype(ms.bfloat16), name=w3_scale_ms_name, requires_grad=False) - parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor(w2_ms_param, ms.int8), + parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_ms_param).astype(ms.int8), name=w2_ms_name, requires_grad=False) - parameter_dict[w2_scale_ms_name] = ms.Parameter(ms.Tensor(w2_scale_ms_param, ms.bfloat16), + parameter_dict[w2_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_scale_ms_param).astype(ms.bfloat16), name=w2_ms_name, requires_grad=False) param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -385,14 +385,14 @@ class DeepseekInferParallelism(BaseModelParallelism): embed_tokens_hf_name = "model.embed_tokens.weight" embed_tokens_ms_name = self.quant_convert_weight_name(embed_tokens_hf_name) np_data = self.get_safetensor_from_file(embed_tokens_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.Tensor(np_data, ms.bfloat16), + parameter_dict[embed_tokens_ms_name] = ms.Parameter(ms.Tensor.from_numpy(np_data).astype(ms.bfloat16), name=embed_tokens_ms_name, requires_grad=False) norm_hf_name = "model.norm.weight" norm_ms_name = self.quant_convert_weight_name(norm_hf_name) np_data = self.get_safetensor_from_file(norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[norm_ms_name] = ms.Parameter(ms.Tensor(np_data, ms.bfloat16), name=norm_ms_name, + parameter_dict[norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(np_data).astype(ms.bfloat16), name=norm_ms_name, requires_grad=False) lm_head_hf_name = "lm_head.weight" @@ -402,7 +402,7 @@ class DeepseekInferParallelism(BaseModelParallelism): is_split_param=True, split_axis=0) else: np_data = self.get_safetensor_from_file(lm_head_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[lm_head_ms_name] = ms.Parameter(ms.Tensor(np_data, ms.bfloat16), name=lm_head_ms_name, + parameter_dict[lm_head_ms_name] = ms.Parameter(ms.Tensor.from_numpy(np_data).astype(ms.bfloat16), name=lm_head_ms_name, requires_grad=False) param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -418,13 +418,13 @@ class DeepseekInferParallelism(BaseModelParallelism): input_scale_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".input_scale" input_scale_ms_name = self.quant_convert_weight_name(input_scale_hf_name) input_scale_ms_param = self.get_safetensor_from_file(input_scale_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[input_scale_ms_name] = ms.Parameter(ms.Tensor(input_scale_ms_param, ms.bfloat16), + parameter_dict[input_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(input_scale_ms_param).astype(ms.bfloat16), name=input_scale_ms_name, requires_grad=False) input_zp_hf_name = f"model.layers.{layer_id}.self_attn." + name + ".input_offset" input_zp_ms_name = self.quant_convert_weight_name(input_zp_hf_name) input_zp_ms_param = self.get_safetensor_from_file(input_zp_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[input_zp_ms_name] = ms.Parameter(ms.Tensor(input_zp_ms_param, ms.int8), + parameter_dict[input_zp_ms_name] = ms.Parameter(ms.Tensor.from_numpy(input_zp_ms_param).astype(ms.int8), name=input_zp_ms_name, requires_grad=False) @@ -474,9 +474,9 @@ class DeepseekInferParallelism(BaseModelParallelism): quant_bias_ms_param = self.split_weight_by_rank(quant_bias_ms_param, split_axis=0) dequant_scale_ms_param = self.split_weight_by_rank(dequant_scale_ms_param, split_axis=0) - parameter_dict[quant_bias_ms_name] = ms.Parameter(ms.Tensor(quant_bias_ms_param, ms.int32), + parameter_dict[quant_bias_ms_name] = ms.Parameter(ms.Tensor.from_numpy(quant_bias_ms_param).astype(ms.int32), name=quant_bias_ms_name, requires_grad=False) - parameter_dict[dequant_scale_ms_name] = ms.Parameter(ms.Tensor(dequant_scale_ms_param, ms.float32), + parameter_dict[dequant_scale_ms_name] = ms.Parameter(ms.Tensor.from_numpy(dequant_scale_ms_param).astype(ms.float32), name=dequant_scale_ms_name, requires_grad=False) param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -494,12 +494,12 @@ class DeepseekInferParallelism(BaseModelParallelism): l2q_proj_bias_ms_name = self.quant_convert_weight_name(l2q_proj_bias_hf_name) l2q_proj_bias_ms_param = self.get_safetensor_from_file(l2q_proj_bias_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[q2l_proj_bias_ms_name] = ms.Parameter(ms.Tensor(q2l_proj_bias_ms_param, ms.bfloat16), + parameter_dict[q2l_proj_bias_ms_name] = ms.Parameter(ms.Tensor.from_numpy(q2l_proj_bias_ms_param).astype(ms.bfloat16), name=q2l_proj_bias_ms_name, requires_grad=False) - parameter_dict[kv2l_bias_ms_name] = ms.Parameter(ms.Tensor(kv2l_bias_ms_param, ms.bfloat16), + parameter_dict[kv2l_bias_ms_name] = ms.Parameter(ms.Tensor.from_numpy(kv2l_bias_ms_param).astype(ms.bfloat16), name=kv2l_bias_ms_name, requires_grad=False) - parameter_dict[l2q_proj_bias_ms_name] = ms.Parameter(ms.Tensor(l2q_proj_bias_ms_param, ms.bfloat16), + parameter_dict[l2q_proj_bias_ms_name] = ms.Parameter(ms.Tensor.from_numpy(l2q_proj_bias_ms_param).astype(ms.bfloat16), name=l2q_proj_bias_ms_name, requires_grad=False) param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -521,7 +521,7 @@ class DeepseekInferParallelism(BaseModelParallelism): q2l_proj_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.weight" q2l_proj_ms_name = self.quant_convert_weight_name(q2l_proj_hf_name) q2l_proj_ms_param = self.get_safetensor_from_file(q2l_proj_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[q2l_proj_ms_name] = ms.Parameter(ms.Tensor(q2l_proj_ms_param, ms.int8), + parameter_dict[q2l_proj_ms_name] = ms.Parameter(ms.Tensor.from_numpy(q2l_proj_ms_param).astype(ms.int8), name=q2l_proj_ms_name, requires_grad=False) self.quant_special_attention_weight(layer_id, src_hf_dir, hf_weight_map, "q_a_proj") @@ -532,7 +532,7 @@ class DeepseekInferParallelism(BaseModelParallelism): kv2l_ms_param = self.get_safetensor_from_file(kv2l_hf_name, src_hf_dir, hf_weight_map) kv2l_ms_param = kv2l_ms_param.reshape(kv_head_dim, -1) kv2l_ms_param = self.infer_trans_rope_weight(kv2l_ms_param, qk_rope_head_dim) - parameter_dict[kv2l_ms_name] = ms.Parameter(ms.Tensor(kv2l_ms_param, ms.int8), name=kv2l_ms_name, + parameter_dict[kv2l_ms_name] = ms.Parameter(ms.Tensor.from_numpy(kv2l_ms_param).astype(ms.int8), name=kv2l_ms_name, requires_grad=False) self.quant_special_attention_weight(layer_id, src_hf_dir, hf_weight_map, "kv_a_proj_with_mqa", is_trans_rope_weigh=True) @@ -541,7 +541,7 @@ class DeepseekInferParallelism(BaseModelParallelism): lq_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_a_layernorm.weight" lq_norm_ms_name = self.quant_convert_weight_name(lq_norm_hf_name) lq_norm_ms_param = self.get_safetensor_from_file(lq_norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[lq_norm_ms_name] = ms.Parameter(ms.Tensor(lq_norm_ms_param, ms.bfloat16), + parameter_dict[lq_norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(lq_norm_ms_param).astype(ms.bfloat16), name=lq_norm_ms_name, requires_grad=False) @@ -553,7 +553,7 @@ class DeepseekInferParallelism(BaseModelParallelism): l2q_proj_ms_param = self.infer_trans_rope_weight(l2q_proj_ms_param, qk_rope_head_dim) l2q_proj_ms_param = l2q_proj_ms_param.reshape(num_heads * rope_dim, -1) l2q_proj_ms_param = self.split_weight_by_rank(l2q_proj_ms_param, split_axis=0) - parameter_dict[l2q_proj_ms_name] = ms.Parameter(ms.Tensor(l2q_proj_ms_param, ms.int8), + parameter_dict[l2q_proj_ms_name] = ms.Parameter(ms.Tensor.from_numpy(l2q_proj_ms_param).astype(ms.int8), name=l2q_proj_ms_name, requires_grad=False) self.quant_special_attention_weight(layer_id, src_hf_dir, hf_weight_map, "q_b_proj", is_trans_rope_weigh=True, @@ -563,7 +563,7 @@ class DeepseekInferParallelism(BaseModelParallelism): lkv_norm_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_layernorm.weight" lkv_norm_ms_name = self.quant_convert_weight_name(lkv_norm_hf_name) lkv_norm_ms_param = self.get_safetensor_from_file(lkv_norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[lkv_norm_ms_name] = ms.Parameter(ms.Tensor(lkv_norm_ms_param, ms.bfloat16), + parameter_dict[lkv_norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(lkv_norm_ms_param).astype(ms.bfloat16), name=lkv_norm_ms_name, requires_grad=False) @@ -579,13 +579,13 @@ class DeepseekInferParallelism(BaseModelParallelism): value_k_nope = value_k_nope.reshape(-1, value_k_nope.shape[-1]) value_k_nope = self.split_weight_by_rank(value_k_nope, split_axis=0) name_k_nope = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_k_nope.") - parameter_dict[name_k_nope] = ms.Parameter(ms.Tensor(value_k_nope, ms.bfloat16), name=name_k_nope, + parameter_dict[name_k_nope] = ms.Parameter(ms.Tensor.from_numpy(value_k_nope).astype(ms.bfloat16), name=name_k_nope, requires_grad=False) # value_v value_v = value_v.reshape(-1, value_v.shape[-1]) value_v = self.split_weight_by_rank(value_v, split_axis=0) name_v = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_v.") - parameter_dict[name_v] = ms.Parameter(ms.Tensor(value_v, ms.bfloat16), name=name_v, + parameter_dict[name_v] = ms.Parameter(ms.Tensor.from_numpy(value_v).astype(ms.bfloat16), name=name_v, requires_grad=False) # o_proj->wo @@ -593,7 +593,7 @@ class DeepseekInferParallelism(BaseModelParallelism): wo_ms_name = self.quant_convert_weight_name(wo_hf_name) wo_ms_param = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map) wo_ms_param = self.split_weight_by_rank(wo_ms_param, split_axis=1) - parameter_dict[wo_ms_name] = ms.Parameter(ms.Tensor(wo_ms_param, ms.int8), name=wo_ms_name, + parameter_dict[wo_ms_name] = ms.Parameter(ms.Tensor.from_numpy(wo_ms_param).astype(ms.int8), name=wo_ms_name, requires_grad=False) self.quant_special_attention_weight(layer_id, src_hf_dir, hf_weight_map, "o_proj") @@ -601,8 +601,6 @@ class DeepseekInferParallelism(BaseModelParallelism): def infer_quant_net_convert_layer_weight(self, src_hf_dir, layer_id, hf_weight_map): """infer quant net convert layer weight""" - print(f"..... start convert layer {layer_id} .......", flush=True) - if layer_id >= 3: self.infer_quant_process_moe_routed_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map) self.infer_quant_process_moe_shared_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map) @@ -613,8 +611,6 @@ class DeepseekInferParallelism(BaseModelParallelism): self.infer_quant_bias_weight(src_hf_dir, layer_id, hf_weight_map) self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map) - print(f"..... end convert layer {layer_id} .......", flush=True) - def convert_weight_name(self, weight_name: str): """replace weight name""" weight_name = weight_name.replace('embed_tokens.weight', 'tok_embeddings.embedding_weight') @@ -650,7 +646,7 @@ class DeepseekInferParallelism(BaseModelParallelism): router_dense_hf_name = f"model.layers.{layer_id}.mlp.gate.weight" router_dense_ms_name = self.convert_weight_name(router_dense_hf_name) router_dense_ms_param = self.get_safetensor_from_file(router_dense_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[router_dense_ms_name] = ms.Parameter(ms.Tensor(router_dense_ms_param, ms.bfloat16), + parameter_dict[router_dense_ms_name] = ms.Parameter(ms.Tensor.from_numpy(router_dense_ms_param).astype(ms.bfloat16), name=router_dense_ms_name, requires_grad=False) # e_score_correction_bias @@ -659,7 +655,7 @@ class DeepseekInferParallelism(BaseModelParallelism): e_score_correction_bias_ms_param = self.get_safetensor_from_file(e_score_correction_bias_hf_name, src_hf_dir, hf_weight_map) parameter_dict[e_score_correction_bias_ms_name] = ms.Parameter( - ms.Tensor(e_score_correction_bias_ms_param, ms.float32), + ms.Tensor.from_numpy(e_score_correction_bias_ms_param).astype(ms.float32), name=e_score_correction_bias_ms_name, requires_grad=False) w1_list = [] @@ -686,23 +682,26 @@ class DeepseekInferParallelism(BaseModelParallelism): w2_list.append(w2_ms_param) w3_list.append(w3_ms_param) - w1_ms_stack_param = np.stack(w1_list, axis=0).transpose(0, 2, 1) - w2_ms_stack_param = np.stack(w2_list, axis=0).transpose(0, 2, 1) - w3_ms_stack_param = np.stack(w3_list, axis=0).transpose(0, 2, 1) + w1_ms_stack_param = np.stack(w1_list, axis=0) + w2_ms_stack_param = np.stack(w2_list, axis=0) + w3_ms_stack_param = np.stack(w3_list, axis=0) if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.routed_experts.ffn.w_gate_hidden.weight" - w_gate_hidden_param = ms.Tensor(np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=2), - dtype=ms.bfloat16) + w_gate_hidden_param = ms.Tensor.from_numpy( + np.concatenate([w1_ms_stack_param, w3_ms_stack_param], axis=1)).transpose(0, 2, 1).astype(ms.bfloat16) parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name, requires_grad=False) else: - parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor(w1_ms_stack_param, ms.bfloat16), name=w1_ms_name, + parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_ms_stack_param).transpose(0, 2, 1).astype(ms.bfloat16), + name=w1_ms_name, requires_grad=False) - parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor(w3_ms_stack_param, ms.bfloat16), name=w3_ms_name, + parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_ms_stack_param).transpose(0, 2, 1).astype(ms.bfloat16), + name=w3_ms_name, requires_grad=False) - parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor(w2_ms_stack_param, ms.bfloat16), name=w2_ms_name, + parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_ms_stack_param).transpose(0, 2, 1).astype(ms.bfloat16), + name=w2_ms_name, requires_grad=False) _, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -724,15 +723,15 @@ class DeepseekInferParallelism(BaseModelParallelism): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.shared_experts.w_gate_hidden.weight" - w_gate_hidden_param = ms.Tensor(np.concatenate([w1_ms_param, w3_ms_param], axis=0), dtype=ms.bfloat16) + w_gate_hidden_param = ms.Tensor.from_numpy(np.concatenate([w1_ms_param, w3_ms_param], axis=0)).astype(ms.bfloat16) parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name, requires_grad=False) else: - parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor(w1_ms_param, ms.bfloat16), name=w1_ms_name, + parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_ms_param).astype(ms.bfloat16), name=w1_ms_name, requires_grad=False) - parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor(w3_ms_param, ms.bfloat16), name=w3_ms_name, + parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_ms_param).astype(ms.bfloat16), name=w3_ms_name, requires_grad=False) - parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor(w2_ms_param, ms.bfloat16), name=w2_ms_name, + parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_ms_param).astype(ms.bfloat16), name=w2_ms_name, requires_grad=False) _, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -759,16 +758,16 @@ class DeepseekInferParallelism(BaseModelParallelism): if ffn_concat: w_gate_hidden_name = f"model.layers.{layer_id}.feed_forward.w_gate_hidden.weight" - w_gate_hidden_param = ms.Tensor(np.concatenate([w1_ms_param, w3_ms_param], axis=0), dtype=ms.bfloat16) + w_gate_hidden_param = ms.Tensor.from_numpy(np.concatenate([w1_ms_param, w3_ms_param], axis=0)).astype(ms.bfloat16) parameter_dict[w_gate_hidden_name] = ms.Parameter(w_gate_hidden_param, name=w_gate_hidden_name, requires_grad=False) else: - parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor(w1_ms_param, ms.bfloat16), name=w1_ms_name, + parameter_dict[w1_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w1_ms_param).astype(ms.bfloat16), name=w1_ms_name, requires_grad=False) - parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor(w3_ms_param, ms.bfloat16), name=w3_ms_name, + parameter_dict[w3_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w3_ms_param).astype(ms.bfloat16), name=w3_ms_name, requires_grad=False) - parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor(w2_ms_param, ms.bfloat16), name=w2_ms_name, + parameter_dict[w2_ms_name] = ms.Parameter(ms.Tensor.from_numpy(w2_ms_param).astype(ms.bfloat16), name=w2_ms_name, requires_grad=False) _, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -788,7 +787,7 @@ class DeepseekInferParallelism(BaseModelParallelism): q2l_proj_hf_name = f"model.layers.{layer_id}.self_attn.q_a_proj.weight" q2l_proj_ms_name = self.convert_weight_name(q2l_proj_hf_name) q_a_proj_ms_param = self.get_safetensor_from_file(q2l_proj_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[q2l_proj_ms_name] = ms.Parameter(ms.Tensor(q_a_proj_ms_param, ms.bfloat16), + parameter_dict[q2l_proj_ms_name] = ms.Parameter(ms.Tensor.from_numpy(q_a_proj_ms_param).astype(ms.bfloat16), name=q2l_proj_ms_name, requires_grad=False) @@ -798,14 +797,14 @@ class DeepseekInferParallelism(BaseModelParallelism): kv2l_ms_param = self.get_safetensor_from_file(kv2l_hf_name, src_hf_dir, hf_weight_map) kv2l_ms_param = kv2l_ms_param.reshape(kv_head_dim, -1) kv2l_ms_param = self.infer_trans_rope_weight(kv2l_ms_param, qk_rope_head_dim) - parameter_dict[kv2l_ms_name] = ms.Parameter(ms.Tensor(kv2l_ms_param, ms.bfloat16), name=kv2l_ms_name, + parameter_dict[kv2l_ms_name] = ms.Parameter(ms.Tensor.from_numpy(kv2l_ms_param).astype(ms.bfloat16), name=kv2l_ms_name, requires_grad=False) # lq_norm lq_norm_hf_name = f"model.layers.{layer_id}.self_attn.q_a_layernorm.weight" lq_norm_ms_name = self.convert_weight_name(lq_norm_hf_name) lq_norm_ms_param = self.get_safetensor_from_file(lq_norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[lq_norm_ms_name] = ms.Parameter(ms.Tensor(lq_norm_ms_param, ms.bfloat16), name=lq_norm_ms_name, + parameter_dict[lq_norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(lq_norm_ms_param).astype(ms.bfloat16), name=lq_norm_ms_name, requires_grad=False) # l2q_proj @@ -816,7 +815,7 @@ class DeepseekInferParallelism(BaseModelParallelism): l2q_proj_ms_param = self.infer_trans_rope_weight(l2q_proj_ms_param, qk_rope_head_dim) l2q_proj_ms_param = l2q_proj_ms_param.reshape(num_heads * rope_dim, -1) l2q_proj_ms_param = self.split_weight_by_rank(l2q_proj_ms_param, split_axis=0) - parameter_dict[l2q_proj_ms_name] = ms.Parameter(ms.Tensor(l2q_proj_ms_param, ms.bfloat16), + parameter_dict[l2q_proj_ms_name] = ms.Parameter(ms.Tensor.from_numpy(l2q_proj_ms_param).astype(ms.bfloat16), name=l2q_proj_ms_name, requires_grad=False) @@ -824,7 +823,7 @@ class DeepseekInferParallelism(BaseModelParallelism): lkv_norm_hf_name = f"model.layers.{layer_id}.self_attn.kv_a_layernorm.weight" lkv_norm_ms_name = self.convert_weight_name(lkv_norm_hf_name) lkv_norm_ms_param = self.get_safetensor_from_file(lkv_norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[lkv_norm_ms_name] = ms.Parameter(ms.Tensor(lkv_norm_ms_param, ms.bfloat16), + parameter_dict[lkv_norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(lkv_norm_ms_param).astype(ms.bfloat16), name=lkv_norm_ms_name, requires_grad=False) @@ -840,13 +839,13 @@ class DeepseekInferParallelism(BaseModelParallelism): value_k_nope = value_k_nope.reshape(-1, value_k_nope.shape[-1]) value_k_nope = self.split_weight_by_rank(value_k_nope, split_axis=0) name_k_nope = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_k_nope.") - parameter_dict[name_k_nope] = ms.Parameter(ms.Tensor(value_k_nope, ms.bfloat16), name=name_k_nope, + parameter_dict[name_k_nope] = ms.Parameter(ms.Tensor.from_numpy(value_k_nope).astype(ms.bfloat16), name=name_k_nope, requires_grad=False) # value_v value_v = value_v.reshape(-1, value_v.shape[-1]) value_v = self.split_weight_by_rank(value_v, split_axis=0) name_v = lkv2kv_ms_name.replace(".attention.lkv2kv.", ".attention.lkv2kv_v.") - parameter_dict[name_v] = ms.Parameter(ms.Tensor(value_v, ms.bfloat16), name=name_v, + parameter_dict[name_v] = ms.Parameter(ms.Tensor.from_numpy(value_v).astype(ms.bfloat16), name=name_v, requires_grad=False) # wo @@ -854,7 +853,7 @@ class DeepseekInferParallelism(BaseModelParallelism): wo_ms_name = self.convert_weight_name(wo_hf_name) wo_ms_param = self.get_safetensor_from_file(wo_hf_name, src_hf_dir, hf_weight_map) wo_ms_param = self.split_weight_by_rank(wo_ms_param, split_axis=1) - parameter_dict[wo_ms_name] = ms.Parameter(ms.Tensor(wo_ms_param, ms.bfloat16), name=wo_ms_name, + parameter_dict[wo_ms_name] = ms.Parameter(ms.Tensor.from_numpy(wo_ms_param).astype(ms.bfloat16), name=wo_ms_name, requires_grad=False) _, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) @@ -867,7 +866,7 @@ class DeepseekInferParallelism(BaseModelParallelism): attention_norm_ms_param = self.get_safetensor_from_file(attention_norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[attention_norm_ms_name] = ms.Parameter(ms.Tensor(attention_norm_ms_param, ms.bfloat16), + parameter_dict[attention_norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(attention_norm_ms_param).astype(ms.bfloat16), name=attention_norm_ms_name, requires_grad=False) @@ -875,7 +874,7 @@ class DeepseekInferParallelism(BaseModelParallelism): ffn_norm_hf_name = f"model.layers.{layer_id}.post_attention_layernorm.weight" ffn_norm_ms_name = self.convert_weight_name(ffn_norm_hf_name) ffn_norm_ms_param = self.get_safetensor_from_file(ffn_norm_hf_name, src_hf_dir, hf_weight_map) - parameter_dict[ffn_norm_ms_name] = ms.Parameter(ms.Tensor(ffn_norm_ms_param, ms.bfloat16), + parameter_dict[ffn_norm_ms_name] = ms.Parameter(ms.Tensor.from_numpy(ffn_norm_ms_param).astype(ms.bfloat16), name=ffn_norm_ms_name, requires_grad=False) @@ -883,8 +882,6 @@ class DeepseekInferParallelism(BaseModelParallelism): def infer_convert_layer_weight(self, src_hf_dir, layer_id, hf_weight_map): """infer convert layer weight""" - print(f"..... start convert layer {layer_id} .......", flush=True) - if layer_id >= 3: self.infer_process_moe_routed_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map) self.infer_process_moe_shared_expert_ffn_weight(src_hf_dir, layer_id, hf_weight_map) @@ -894,8 +891,6 @@ class DeepseekInferParallelism(BaseModelParallelism): self.infer_process_attention_weight(src_hf_dir, layer_id, hf_weight_map) self.infer_process_norm_weight(src_hf_dir, layer_id, hf_weight_map) - print(f"..... end convert layer {layer_id} .......", flush=True) - def infer_convert_and_parallelism(self, src_hf_dir): """convert inference model weight """ param_json_path = "" @@ -912,7 +907,7 @@ class DeepseekInferParallelism(BaseModelParallelism): self.infer_convert_outer_weight(src_hf_dir, hf_weight_map) num_layers = self.config.model.model_config.num_layers - for layer_id in range(num_layers): + for layer_id in tqdm(range(num_layers), desc="Weight loading"): if self.is_quant: self.infer_quant_net_convert_layer_weight(src_hf_dir, layer_id, hf_weight_map) else: diff --git a/vllm_mindspore/model_executor/models/mf_models/model_parallelism.py b/vllm_mindspore/model_executor/models/mf_models/model_parallelism.py index a063cab9..3059c6d0 100644 --- a/vllm_mindspore/model_executor/models/mf_models/model_parallelism.py +++ b/vllm_mindspore/model_executor/models/mf_models/model_parallelism.py @@ -16,7 +16,9 @@ """ transform huggingface safetensor. """ +import os from safetensors import safe_open +import mindspore as ms from mindspore.communication.management import get_rank, get_group_size @@ -33,45 +35,55 @@ class BaseModelParallelism: self.config = config self.network = network self.is_quant = is_quant + self.tp_group_size = get_group_size() + self.rank_id = get_rank() + self.file_handles = {} + + def get_file_handles(self, filename): + if filename not in self.file_handles: + fp = safe_open(filename, framework="np") + self.file_handles[filename] = fp + return self.file_handles[filename] + + def release_file_handles(self): + del self.file_handles + self.file_handles = {} def get_safetensor_from_file(self, hf_param_name, src_hf_dir, hf_weight_map, is_split_param=False, split_axis=0): - tp_group_size = get_group_size() - rank_id = get_rank() safetensor_file = hf_weight_map[hf_param_name] - with safe_open(f"{src_hf_dir}/{safetensor_file}", framework="np") as sf_file: - if not is_split_param: - np_data = sf_file.get_tensor(hf_param_name) - return np_data + filename = os.path.join(src_hf_dir, safetensor_file) + sf_file = self.get_file_handles(filename) + if not is_split_param: + np_data = sf_file.get_tensor(hf_param_name) + return np_data - np_data = sf_file.get_slice(hf_param_name) - shape = np_data.get_shape() - if split_axis == 0: - split_size = shape[0] // tp_group_size - start = rank_id * split_size - stop = (rank_id + 1) * split_size - split_data = np_data[start:stop] - elif split_axis == 1: - split_size = shape[1] // tp_group_size - start = rank_id * split_size - stop = (rank_id + 1) * split_size - split_data = np_data[:, start:stop] - else: - raise ValueError("split_axis:{} is not supported.".format(split_axis)) - return split_data + np_data = sf_file.get_slice(hf_param_name) + shape = np_data.get_shape() + if split_axis == 0: + split_size = shape[0] // self.tp_group_size + start = self.rank_id * split_size + stop = (self.rank_id + 1) * split_size + split_data = np_data[start:stop] + elif split_axis == 1: + split_size = shape[1] // self.tp_group_size + start = self.rank_id * split_size + stop = (self.rank_id + 1) * split_size + split_data = np_data[:, start:stop] + else: + raise ValueError("split_axis:{} is not supported.".format(split_axis)) + return split_data def split_weight_by_rank(self, weight, split_axis=0): - tp_group_size = get_group_size() - rank_id = get_rank() shape = weight.shape if split_axis == 0: - split_size = shape[0] // tp_group_size - start = rank_id * split_size - stop = (rank_id + 1) * split_size + split_size = shape[0] // self.tp_group_size + start = self.rank_id * split_size + stop = (self.rank_id + 1) * split_size split_data = weight[start:stop] elif split_axis == 1: - split_size = shape[1] // tp_group_size - start = rank_id * split_size - stop = (rank_id + 1) * split_size + split_size = shape[1] // self.tp_group_size + start = self.rank_id * split_size + stop = (self.rank_id + 1) * split_size split_data = weight[:, start:stop] else: raise ValueError("split_axis:{} is not supported.".format(split_axis)) -- Gitee