From 8df12c9aeee64e9b1746c359a8784b400990e8dd Mon Sep 17 00:00:00 2001
From: ccsszz
Date: Mon, 28 Apr 2025 10:47:20 +0800
Subject: [PATCH] smoothquant: support expert parallelism (EP)
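
Load smoothquant (W8A8) DeepSeek-V3 weights under expert parallelism:

- add get_routed_safetensor_2_dim/3_dim helpers to BaseWeightProcessor
  that slice routed-expert tensors to this rank's expert range
  (ep_start/ep_stop) and, optionally, to its MoE tensor-parallel slot;
- shard shared-expert and dense FFN weights according to the EP method
  (ALLGATHER: split across the global group; ALLTOALL: keep whole);
- rework smoothquant loading into a per-layer qkv/ffn pass plus a second
  pass for the remaining parameters, honoring the qkv_concat and
  ffn_concat fusion options.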
f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_bias_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.quant_bias" + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.dequant_scale" + w3_quant_zp = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_zp" + w3_quant_scale = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_scale" + w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight" + w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale" + w1_weight_param, _ = self.get_routed_safetensor_3_dim(w1_weight_name, src_hf_dir, hf_weight_map, tp_axis=2, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + + w1_bias_param, _ = self.get_routed_safetensor_2_dim(w1_bias_name, src_hf_dir, hf_weight_map, tp_axis=1, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + + w1_scale_param, _ = self.get_routed_safetensor_2_dim(w1_scale_name, src_hf_dir, hf_weight_map, tp_axis=1, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + + w1_quant_zp_param, _ = self.get_safetensor_from_file(w1_quant_zp, src_hf_dir, hf_weight_map) + w1_quant_scale_param, _ = self.get_safetensor_from_file(w1_quant_scale, src_hf_dir, hf_weight_map) + + + w3_weight_param, _ = self.get_routed_safetensor_3_dim(w3_weight_name, src_hf_dir, hf_weight_map, tp_axis=2, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + + w3_bias_param, _ = self.get_routed_safetensor_2_dim(w3_bias_name, src_hf_dir, hf_weight_map, tp_axis=1, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + + w3_scale_param, _ = self.get_routed_safetensor_2_dim(w3_scale_name, src_hf_dir, hf_weight_map, tp_axis=1, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + + w3_quant_zp_param, _ = self.get_safetensor_from_file(w3_quant_zp, src_hf_dir, hf_weight_map) + w3_quant_scale_param, _ = self.get_safetensor_from_file(w3_quant_scale, src_hf_dir, hf_weight_map) + + w2_weight_param, _ = self.get_routed_safetensor_3_dim(w2_weight_name, src_hf_dir, hf_weight_map, tp_axis=1, + split_ep=self.moe_split_ep, split_tp=self.moe_split_tp) + w2_scale_param, _ = self.get_routed_safetensor_2_dim(w2_scale_name, src_hf_dir, hf_weight_map, + split_ep=self.moe_split_ep, split_tp=False) + + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=2), dtype=ms.int8) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_bias_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.quant_bias" + concat_bias_param = ms.Tensor(np.concatenate([w1_bias_param, w3_bias_param], axis=1), dtype=ms.int32) + parameter_dict[concat_bias_name] = ms.Parameter(concat_bias_param, name=concat_bias_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.dequant_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=1), dtype=ms.bfloat16) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + + concat_quant_zp_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_zp" + concat_quant_zp_param = ms.Tensor(w1_quant_zp_param, dtype=ms.bfloat16) + parameter_dict[concat_quant_zp_name] = ms.Parameter(concat_quant_zp_param, name=concat_quant_zp_name, + 
+        ffn_concat = self.config.model.model_config.ffn_concat
+        w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight"
+        w1_bias_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.quant_bias"
+        w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.dequant_scale"
+        w1_quant_zp = f"model.layers.{layer_id}.{layer_type}.w1.quant_op.input_zp"
+        w1_quant_scale = f"model.layers.{layer_id}.{layer_type}.w1.quant_op.input_scale"
+        w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight"
+        w3_bias_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.quant_bias"
+        w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.dequant_scale"
+        w3_quant_zp = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_zp"
+        w3_quant_scale = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_scale"
+        w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight"
+        w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale"
+
+        w1_weight_param, _ = self.get_routed_safetensor_3_dim(w1_weight_name, src_hf_dir, hf_weight_map, tp_axis=2,
+                                                              split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w1_bias_param, _ = self.get_routed_safetensor_2_dim(w1_bias_name, src_hf_dir, hf_weight_map, tp_axis=1,
+                                                            split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w1_scale_param, _ = self.get_routed_safetensor_2_dim(w1_scale_name, src_hf_dir, hf_weight_map, tp_axis=1,
+                                                             split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w1_quant_zp_param, _ = self.get_safetensor_from_file(w1_quant_zp, src_hf_dir, hf_weight_map)
+        w1_quant_scale_param, _ = self.get_safetensor_from_file(w1_quant_scale, src_hf_dir, hf_weight_map)
+
+        w3_weight_param, _ = self.get_routed_safetensor_3_dim(w3_weight_name, src_hf_dir, hf_weight_map, tp_axis=2,
+                                                              split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w3_bias_param, _ = self.get_routed_safetensor_2_dim(w3_bias_name, src_hf_dir, hf_weight_map, tp_axis=1,
+                                                            split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w3_scale_param, _ = self.get_routed_safetensor_2_dim(w3_scale_name, src_hf_dir, hf_weight_map, tp_axis=1,
+                                                             split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w3_quant_zp_param, _ = self.get_safetensor_from_file(w3_quant_zp, src_hf_dir, hf_weight_map)
+        w3_quant_scale_param, _ = self.get_safetensor_from_file(w3_quant_scale, src_hf_dir, hf_weight_map)
+
+        w2_weight_param, _ = self.get_routed_safetensor_3_dim(w2_weight_name, src_hf_dir, hf_weight_map, tp_axis=1,
+                                                              split_ep=self.moe_split_ep, split_tp=self.moe_split_tp)
+        w2_scale_param, _ = self.get_routed_safetensor_2_dim(w2_scale_name, src_hf_dir, hf_weight_map,
+                                                             split_ep=self.moe_split_ep, split_tp=False)
+
+        if ffn_concat:
+            concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight"
+            concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=2), dtype=ms.int8)
+            parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name,
+                                                              requires_grad=False)
+
+            concat_bias_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.quant_bias"
+            concat_bias_param = ms.Tensor(np.concatenate([w1_bias_param, w3_bias_param], axis=1), dtype=ms.int32)
+            parameter_dict[concat_bias_name] = ms.Parameter(concat_bias_param, name=concat_bias_name,
+                                                            requires_grad=False)
+
+            concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.dequant_scale"
+            concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=1), dtype=ms.bfloat16)
+            parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name,
+                                                             requires_grad=False)
+
+            concat_quant_zp_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_zp"
+            concat_quant_zp_param = ms.Tensor(w1_quant_zp_param, dtype=ms.bfloat16)
+            parameter_dict[concat_quant_zp_name] = ms.Parameter(concat_quant_zp_param, name=concat_quant_zp_name,
+                                                                requires_grad=False)
+
+            concat_quant_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_scale"
+            concat_quant_scale_param = ms.Tensor(w1_quant_scale_param, dtype=ms.bfloat16)
+            parameter_dict[concat_quant_scale_name] = ms.Parameter(concat_quant_scale_param,
+                                                                   name=concat_quant_scale_name,
+                                                                   requires_grad=False)
+        else:
+            # w1 w3
+            parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name,
+                                                          requires_grad=False)
+            parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name,
+                                                          requires_grad=False)
+            parameter_dict[w1_bias_name] = ms.Parameter(ms.Tensor(w1_bias_param, ms.int32),
+                                                        name=w1_bias_name, requires_grad=False)
+            parameter_dict[w3_bias_name] = ms.Parameter(ms.Tensor(w3_bias_param, ms.int32),
+                                                        name=w3_bias_name, requires_grad=False)
+            parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, ms.bfloat16),
+                                                         name=w1_scale_name, requires_grad=False)
+            parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, ms.bfloat16),
+                                                         name=w3_scale_name, requires_grad=False)
+            parameter_dict[w1_quant_zp] = ms.Parameter(ms.Tensor(w1_quant_zp_param, ms.bfloat16),
+                                                       name=w1_quant_zp, requires_grad=False)
+            parameter_dict[w3_quant_zp] = ms.Parameter(ms.Tensor(w3_quant_zp_param, ms.bfloat16),
+                                                       name=w3_quant_zp, requires_grad=False)
+            parameter_dict[w1_quant_scale] = ms.Parameter(ms.Tensor(w1_quant_scale_param, ms.bfloat16),
+                                                          name=w1_quant_scale, requires_grad=False)
+            parameter_dict[w3_quant_scale] = ms.Parameter(ms.Tensor(w3_quant_scale_param, ms.bfloat16),
+                                                          name=w3_quant_scale, requires_grad=False)
+        # w2 is registered in both branches
+        parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name,
+                                                      requires_grad=False)
+        parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, ms.bfloat16),
+                                                     name=w2_scale_name, requires_grad=False)
+
+    def smooth_quant_process_shared_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type):
+        """smooth_quant_process_shared_ffn_weight"""
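+        # Shared experts are not routed, so their shards follow the EP method:
+        # under ALLGATHER the tensors are split across the whole global group
+        # (split_num=global_group_size), while under ALLTOALL every rank keeps
+        # the full tensors (split_num=1, rank_id=0).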
f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_weight_param, _ = self.get_safetensor_from_file(w3_weight_name, src_hf_dir, hf_weight_map, + is_split_param=True, + split_axis=0, split_num=split_num, rank_id=rank_id) + w3_bias_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.quant_bias" + w3_bias_param, _ = self.get_safetensor_from_file(w3_bias_name, src_hf_dir, hf_weight_map, + is_split_param=True, + split_axis=0, split_num=split_num, rank_id=rank_id) + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.dequant_scale" + w3_scale_param, _ = self.get_safetensor_from_file(w3_scale_name, src_hf_dir, hf_weight_map, + is_split_param=True, + split_axis=0, split_num=split_num, rank_id=rank_id) + w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight" + w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale" + w2_weight_param, _ = self.get_safetensor_from_file(w2_weight_name, src_hf_dir, hf_weight_map, + is_split_param=True, + split_axis=1, split_num=split_num, rank_id=rank_id) + w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map) + + w3_quant_zp = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_zp" + w3_quant_scale = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_scale" + w3_quant_zp_param, _ = self.get_safetensor_from_file(w3_quant_zp, src_hf_dir, hf_weight_map) + w3_quant_scale_param, _ = self.get_safetensor_from_file(w3_quant_scale, src_hf_dir, hf_weight_map) + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_bias_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.quant_bias" + concat_bias_param = ms.Tensor(np.concatenate([w1_bias_param, w3_bias_param], axis=0), dtype=ms.int32) + parameter_dict[concat_bias_name] = ms.Parameter(concat_bias_param, name=concat_bias_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.dequant_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), dtype=ms.float32) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + + concat_quant_zp_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_zp" + concat_quant_zp_param = ms.Tensor(w1_quant_zp_param, dtype=ms.int8) + parameter_dict[concat_quant_zp_name] = ms.Parameter(concat_quant_zp_param, name=concat_quant_zp_name, + requires_grad=False) + + concat_quant_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_scale" + concat_quant_scale_param = ms.Tensor(w1_quant_scale_param, dtype=ms.bfloat16) + parameter_dict[concat_quant_scale_name] = ms.Parameter(concat_quant_scale_param, + name=concat_quant_scale_name, + requires_grad=False) + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name, + requires_grad=False) + + parameter_dict[w1_bias_name] = ms.Parameter(ms.Tensor(w1_bias_param, ms.int32), + name=w1_bias_name, requires_grad=False) + 
+        ffn_concat = self.config.model.model_config.ffn_concat
+        w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight"
+        w1_weight_param, _ = self.get_safetensor_from_file(w1_weight_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=0)
+        w1_bias_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.quant_bias"
+        w1_bias_param, _ = self.get_safetensor_from_file(w1_bias_name, src_hf_dir, hf_weight_map,
+                                                         is_split_param=True, split_axis=0)
+        w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.dequant_scale"
+        w1_scale_param, _ = self.get_safetensor_from_file(w1_scale_name, src_hf_dir, hf_weight_map,
+                                                          is_split_param=True, split_axis=0)
+
+        w1_quant_zp = f"model.layers.{layer_id}.{layer_type}.w1.quant_op.input_zp"
+        w1_quant_scale = f"model.layers.{layer_id}.{layer_type}.w1.quant_op.input_scale"
+        w1_quant_zp_param, _ = self.get_safetensor_from_file(w1_quant_zp, src_hf_dir, hf_weight_map)
+        w1_quant_scale_param, _ = self.get_safetensor_from_file(w1_quant_scale, src_hf_dir, hf_weight_map)
+
+        w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight"
+        w3_weight_param, _ = self.get_safetensor_from_file(w3_weight_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=0)
+        w3_bias_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.quant_bias"
+        w3_bias_param, _ = self.get_safetensor_from_file(w3_bias_name, src_hf_dir, hf_weight_map,
+                                                         is_split_param=True, split_axis=0)
+        w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.dequant_scale"
+        w3_scale_param, _ = self.get_safetensor_from_file(w3_scale_name, src_hf_dir, hf_weight_map,
+                                                          is_split_param=True, split_axis=0)
+        w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight"
+        w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale"
+        w2_weight_param, _ = self.get_safetensor_from_file(w2_weight_name, src_hf_dir, hf_weight_map,
+                                                           is_split_param=True, split_axis=1)
+        w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map)
+
+        w3_quant_zp = f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_zp"
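+        # w3 mirrors w1: its input_zp/input_scale live under the same
+        # quant_op naming scheme and are loaded unsplit below.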
f"model.layers.{layer_id}.{layer_type}.w3.quant_op.input_scale" + w3_quant_zp_param, _ = self.get_safetensor_from_file(w3_quant_zp, src_hf_dir, hf_weight_map) + w3_quant_scale_param, _ = self.get_safetensor_from_file(w3_quant_scale, src_hf_dir, hf_weight_map) + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_bias_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.quant_bias" + concat_bias_param = ms.Tensor(np.concatenate([w1_bias_param, w3_bias_param], axis=0), dtype=ms.int32) + parameter_dict[concat_bias_name] = ms.Parameter(concat_bias_param, name=concat_bias_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.dequant_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), dtype=ms.float32) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + + concat_quant_zp_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_zp" + concat_quant_zp_param = ms.Tensor(w1_quant_zp_param, dtype=ms.int8) + parameter_dict[concat_quant_zp_name] = ms.Parameter(concat_quant_zp_param, name=concat_quant_zp_name, + requires_grad=False) + + concat_quant_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden.quant_op.input_scale" + concat_quant_scale_param = ms.Tensor(w1_quant_scale_param, dtype=ms.bfloat16) + parameter_dict[concat_quant_scale_name] = ms.Parameter(concat_quant_scale_param, + name=concat_quant_scale_name, + requires_grad=False) + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name, + requires_grad=False) + + parameter_dict[w1_bias_name] = ms.Parameter(ms.Tensor(w1_bias_param, ms.int32), + name=w1_bias_name, requires_grad=False) + parameter_dict[w3_bias_name] = ms.Parameter(ms.Tensor(w3_bias_param, ms.int32), + name=w3_bias_name, requires_grad=False) + + parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, ms.float32), + name=w1_scale_name, requires_grad=False) + parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, ms.float32), + name=w3_scale_name, requires_grad=False) + + parameter_dict[w1_quant_zp] = ms.Parameter(ms.Tensor(w1_quant_zp_param, ms.int8), + name=w1_quant_zp, requires_grad=False) + parameter_dict[w3_quant_zp] = ms.Parameter(ms.Tensor(w3_quant_zp_param, ms.int8), + name=w3_quant_zp, requires_grad=False) + + parameter_dict[w1_quant_scale] = ms.Parameter(ms.Tensor(w1_quant_scale_param, ms.bfloat16), + name=w1_quant_scale, requires_grad=False) + parameter_dict[w3_quant_scale] = ms.Parameter(ms.Tensor(w3_quant_scale_param, ms.bfloat16), + name=w3_quant_scale, requires_grad=False) + parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name, + requires_grad=False) + parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, ms.bfloat16), + name=w2_scale_name, requires_grad=False) + + def smooth_quant_process_qkv_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict): + 
+        qkv_concat = self.config.model.model_config.qkv_concat
+        # q2l_proj
+        q2l_weight_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.weight"
+        q2l_weight_param, _ = self.get_safetensor_from_file(q2l_weight_name, src_hf_dir, hf_weight_map)
+        q2l_bias_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.quant_bias"
+        q2l_bias_param, _ = self.get_safetensor_from_file(q2l_bias_name, src_hf_dir, hf_weight_map)
+        q2l_scale_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.dequant_scale"
+        q2l_scale_param, _ = self.get_safetensor_from_file(q2l_scale_name, src_hf_dir, hf_weight_map)
+
+        q2l_quant_zp = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_zp"
+        q2l_quant_scale = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_scale"
+        q2l_quant_zp_param, _ = self.get_safetensor_from_file(q2l_quant_zp, src_hf_dir, hf_weight_map)
+        q2l_quant_scale_param, _ = self.get_safetensor_from_file(q2l_quant_scale, src_hf_dir, hf_weight_map)
+
+        kv2l_weight_name = f"model.layers.{layer_id}.attention.kv2l._layer.weight"
+        kv2l_weight_param, _ = self.get_safetensor_from_file(kv2l_weight_name, src_hf_dir, hf_weight_map)
+        kv2l_bias_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.quant_bias"
+        kv2l_bias_param, _ = self.get_safetensor_from_file(kv2l_bias_name, src_hf_dir, hf_weight_map)
+        kv2l_scale_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.dequant_scale"
+        kv2l_scale_param, _ = self.get_safetensor_from_file(kv2l_scale_name, src_hf_dir, hf_weight_map)
+
+        kv2l_quant_zp = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_zp"
+        kv2l_quant_scale = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_scale"
+        kv2l_quant_zp_param, _ = self.get_safetensor_from_file(kv2l_quant_zp, src_hf_dir, hf_weight_map)
+        kv2l_quant_scale_param, _ = self.get_safetensor_from_file(kv2l_quant_scale, src_hf_dir, hf_weight_map)
+
+        if qkv_concat:
+            qkv2l_weight_name = f"model.layers.{layer_id}.attention.qkv2l._layer.weight"
+            qkv2l_bias_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.quant_bias"
+            qkv2l_scale_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.dequant_scale"
+            qkv2l_quant_zp_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_zp"
+            qkv2l_quant_scale_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_scale"
+
+            qkv2l_weight = np.concatenate((q2l_weight_param, kv2l_weight_param), 0)
+            parameter_dict[qkv2l_weight_name] = ms.Parameter(ms.Tensor(qkv2l_weight, ms.int8), name=qkv2l_weight_name,
+                                                             requires_grad=False)
+            qkv2l_bias = np.concatenate((q2l_bias_param, kv2l_bias_param), 0)
+            parameter_dict[qkv2l_bias_name] = ms.Parameter(ms.Tensor(qkv2l_bias, ms.int32), name=qkv2l_bias_name,
+                                                           requires_grad=False)
+            qkv2l_scale = np.concatenate((q2l_scale_param, kv2l_scale_param), 0)
+            parameter_dict[qkv2l_scale_name] = ms.Parameter(ms.Tensor(qkv2l_scale, ms.float32), name=qkv2l_scale_name,
+                                                            requires_grad=False)
+            parameter_dict[qkv2l_quant_zp_name] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8),
+                                                               name=qkv2l_quant_zp_name, requires_grad=False)
+            parameter_dict[qkv2l_quant_scale_name] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, ms.bfloat16),
+                                                                  name=qkv2l_quant_scale_name, requires_grad=False)
+        else:
+            parameter_dict[q2l_weight_name] = ms.Parameter(ms.Tensor(q2l_weight_param, ms.int8), name=q2l_weight_name,
+                                                           requires_grad=False)
+            parameter_dict[kv2l_weight_name] = ms.Parameter(ms.Tensor(kv2l_weight_param, ms.int8),
+                                                            name=kv2l_weight_name, requires_grad=False)
+            parameter_dict[q2l_bias_name] = ms.Parameter(ms.Tensor(q2l_bias_param, ms.int32), name=q2l_bias_name,
+                                                         requires_grad=False)
+            parameter_dict[kv2l_bias_name] = ms.Parameter(ms.Tensor(kv2l_bias_param, ms.int32), name=kv2l_bias_name,
+                                                          requires_grad=False)
+            parameter_dict[q2l_scale_name] = ms.Parameter(ms.Tensor(q2l_scale_param, ms.float32), name=q2l_scale_name,
+                                                          requires_grad=False)
+            parameter_dict[kv2l_scale_name] = ms.Parameter(ms.Tensor(kv2l_scale_param, ms.float32),
+                                                           name=kv2l_scale_name, requires_grad=False)
+            parameter_dict[q2l_quant_zp] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8), name=q2l_quant_zp,
+                                                        requires_grad=False)
+            parameter_dict[kv2l_quant_zp] = ms.Parameter(ms.Tensor(kv2l_quant_zp_param, ms.int8), name=kv2l_quant_zp,
+                                                         requires_grad=False)
+            parameter_dict[q2l_quant_scale] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, ms.bfloat16),
+                                                           name=q2l_quant_scale, requires_grad=False)
+            parameter_dict[kv2l_quant_scale] = ms.Parameter(ms.Tensor(kv2l_quant_scale_param, ms.bfloat16),
+                                                            name=kv2l_quant_scale, requires_grad=False)
+
+    def infer_smooth_quant_row_linear_split(self, param_name, src_hf_dir, hf_weight_map):
+        '''infer_smooth_quant_row_linear_split'''
+        if param_name.endswith(".weight"):
+            value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
+                                                     hf_weight_map, is_split_param=True,
+                                                     split_axis=1)
+        elif "quant_op" in param_name:
+            value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
+                                                     hf_weight_map, is_split_param=True,
+                                                     split_axis=0)
+        else:
+            value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
+                                                     hf_weight_map)
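+        # wo is row-parallel: each TP rank produces a partial output that is
+        # reduced across the group, so the dequant bias must be applied only
+        # once. Ranks other than TP rank 0 therefore zero their quant_bias.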
+        if "wo._layer.matmul.quant_bias" in param_name and get_tensor_model_parallel_rank() != 0:
+            value.fill(0)
+        return value
+
+    def infer_smooth_quant_get_value(self, param_name, src_hf_dir, hf_weight_map, no_need_split_layer):
+        '''infer_smooth_quant_get_value'''
+
+        if any([name in param_name for name in no_need_split_layer]):
+            value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
+                                                     hf_weight_map)
+        elif any([name in param_name for name in [".l2q_proj."]]):
+            if param_name.endswith(".weight") or "matmul" in param_name:
+                value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
+                                                         hf_weight_map, is_split_param=True,
+                                                         split_axis=0)
+            else:
+                value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
+                                                         hf_weight_map)
+        elif any([name in param_name for name in [".wo."]]):
+            value = self.infer_smooth_quant_row_linear_split(param_name, src_hf_dir, hf_weight_map)
+        elif any([name in param_name for name in ["lkv2kv_k_nope", "lkv2kv_v"]]):
+            value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map,
+                                                     is_split_param=True, split_axis=0)
+        elif "lm_head" in param_name:
+            if not self.config.parallel_config.vocab_emb_dp:
+                value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map,
+                                                         is_split_param=True, split_axis=0)
+            else:
+                value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
+        else:
+            raise ValueError(f"not found layer {param_name}, please check safetensors file.")
+        return value
+
     def infer_smooth_quant_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map):
-        """infer_smooth_quant_net_ms_convert_layer_weight"""
+        '''infer_smooth_quant_net_ms_convert_layer_weight'''
         parameter_dict = {}
-        no_need_split_layer = ["tok_embeddings", "norm", "q2l_proj",
-                               "kv2l", "routed_experts.router.dense",
+        no_need_split_layer = ["tok_embeddings", "norm", "routed_experts.router.dense",
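+        # NOTE: layers below 3 are dense FFN layers (DeepSeek-V3's
+        # first_k_dense_replace = 3 is assumed by the hard-coded bound in the
+        # loop below); from layer 3 on, each layer carries routed experts
+        # plus shared experts.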
"routed_experts.router.e_score_correction_bias", "topk_bias"] - for param_name, _ in tqdm(hf_weight_map.items(), desc="split safetensors"): + for layer_id in tqdm(range(num_layers), desc="qkv/ffn params load"): + if layer_id >= 3: + self.smooth_quant_process_route_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward.routed_experts.ffn") + self.smooth_quant_process_shared_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward.shared_experts") + + else: + self.smooth_quant_process_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward") + self.smooth_quant_process_qkv_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict) + + skip_layer = ["feed_forward.routed_experts.ffn", "feed_forward.shared_experts", "feed_forward.w", + "attention.kv2l", "attention.q"] + + for param_name, _ in tqdm(hf_weight_map.items(), desc="remaining params load"): if "model.layers" in param_name and int(param_name.split('.')[2]) >= num_layers: continue - if any([name in param_name for name in no_need_split_layer]): - value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, - hf_weight_map) - elif any([name in param_name for name in [".l2q_proj.", ".feed_forward.w_gate_hidden.", - "shared_experts.w_gate_hidden"]]): - if param_name.endswith(".weight") or "matmul" in param_name: - value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, - hf_weight_map, is_split_param=True, - split_axis=0) - else: - value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, - hf_weight_map) - elif any([name in param_name for name in [".feed_forward.w2.", ".wo.", "shared_experts.w2"]]): - if param_name.endswith(".weight"): - value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, - hf_weight_map, is_split_param=True, - split_axis=1) - elif "quant_op" in param_name: - value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, - hf_weight_map, is_split_param=True, - split_axis=0) - else: - value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, - hf_weight_map) - elif ".routed_experts.ffn.w_gate_hidden." 
-                if param_name.endswith(".weight"):
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
-                    value_list = []
-                    for experts_id in range(value.shape[0]):
-                        value_list.append(self.split_weight_by_rank(value[experts_id, :, :], split_axis=1))
-                    value = np.stack(value_list, axis=0)
-                elif "matmul" in param_name:
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
-                    value_list = []
-                    for experts_id in range(value.shape[0]):
-                        value_list.append(self.split_weight_by_rank(value[experts_id, :], split_axis=0))
-                    value = np.stack(value_list, axis=0)
-                else:
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
-                                                             hf_weight_map)
-            elif ".routed_experts.ffn.w2" in param_name:
-                if param_name.endswith(".weight"):
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
-                    value_list = []
-                    for experts_id in range(value.shape[0]):
-                        value_list.append(self.split_weight_by_rank(value[experts_id, :, :], split_axis=0))
-                    value = np.stack(value_list, axis=0)
-                else:
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir,
-                                                             hf_weight_map)
-            elif any([name in param_name for name in ["lkv2kv_k_nope", "lkv2kv_v"]]):
-                value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map,
-                                                         is_split_param=True, split_axis=0)
-            elif "lm_head" in param_name:
-                if not self.config.parallel_config.vocab_emb_dp:
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map,
-                                                             is_split_param=True, split_axis=0)
-                else:
-                    value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map)
-            else:
-                raise ValueError(f"not found layer {param_name}, please check safetensors file.")
+            if any([name in param_name for name in skip_layer]):
+                continue
+            value = self.infer_smooth_quant_get_value(param_name, src_hf_dir, hf_weight_map, no_need_split_layer)
             dst_dtype = convert_np_to_ms_dtype(value)
             parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dst_dtype),
                                                       name=param_name, requires_grad=False)
         param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict)
-        logger.info("smoothquant param_not_load: %s" % str(param_not_load))
-        logger.info("smoothquant ckpt_not_load: %s" % str(ckpt_not_load))
+        print(f"smoothquant param_not_load:{param_not_load}")
+        print(f"smoothquant ckpt_not_load:{ckpt_not_load}")
 
     def infer_gptq_quant_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map):
         """infer_gptq_quant_net_ms_convert_layer_weight"""
diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
index faa9e892..b33aa247 100644
--- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
@@ -58,6 +58,8 @@ class BaseWeightProcessor:
         self.ep_group_nums = num_router_experts // self.moe_ep_size
         self.moe_ep_rank_id = self.global_rank_id // self.moe_tp_size
         self.moe_tp_rank_id = self.global_rank_id % self.moe_tp_size
+        self.ep_start = self.moe_ep_rank_id * self.ep_group_nums
+        self.ep_stop = (self.moe_ep_rank_id + 1) * self.ep_group_nums
 
         print(f"global_rank_id: {self.global_rank_id} \n"
               f"tp_group_size: {self.tp_group_size} \n"
@@ -109,6 +111,65 @@ class BaseWeightProcessor:
             raise ValueError("split_axis:{} is not supported.".format(split_axis))
         return split_data, qint4
 
+    def get_routed_safetensor_3_dim(self, hf_param_name, src_hf_dir, hf_weight_map, split_ep=False,
+                                    split_tp=False, tp_axis=-1):
+        '''get_routed_safetensor_3_dim'''
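+        # get_slice reads lazily from the safetensors file, so only the
+        # EP/TP sub-block selected below is materialized in memory.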
+        safetensor_file = hf_weight_map[hf_param_name]
+        filename = os.path.join(src_hf_dir, safetensor_file)
+        sf_file = self.get_file_handles(filename)
+        qint4 = False
+        if sf_file.metadata() is not None and hf_param_name in sf_file.metadata().keys():
+            qint4 = True
+
+        if not split_tp and not split_ep:
+            np_data = sf_file.get_tensor(hf_param_name)
+            return np_data, qint4
+
+        np_data = sf_file.get_slice(hf_param_name)
+        if not split_tp and split_ep:
+            split_data = np_data[self.ep_start:self.ep_stop, :, :]
+            return split_data, qint4
+
+        shape = np_data.get_shape()
+        if tp_axis == 1:
+            split_size = shape[1] // self.moe_tp_size
+            start = self.moe_tp_rank_id * split_size
+            stop = (self.moe_tp_rank_id + 1) * split_size
+            split_data = np_data[self.ep_start:self.ep_stop, start:stop, :] if split_ep else np_data[:, start:stop, :]
+        elif tp_axis == 2:
+            split_size = shape[2] // self.moe_tp_size
+            start = self.moe_tp_rank_id * split_size
+            stop = (self.moe_tp_rank_id + 1) * split_size
+            split_data = np_data[self.ep_start:self.ep_stop, :, start:stop] if split_ep else np_data[:, :, start:stop]
+        else:
+            raise ValueError("split_tp is True but tp_axis:{} is not supported.".format(tp_axis))
+        return split_data, qint4
+
+    def get_routed_safetensor_2_dim(self, hf_param_name, src_hf_dir, hf_weight_map, split_ep=False,
+                                    split_tp=False, tp_axis=-1):
+        '''get_routed_safetensor_2_dim'''
+        # 2-D variant for per-expert vectors (quant_bias, dequant_scale):
+        # axis 0 is the expert axis (EP range), axis 1 the feature axis
+        # (TP split when tp_axis=1).
+        safetensor_file = hf_weight_map[hf_param_name]
+        filename = os.path.join(src_hf_dir, safetensor_file)
+        sf_file = self.get_file_handles(filename)
+        qint4 = False
+        if sf_file.metadata() is not None and hf_param_name in sf_file.metadata().keys():
+            qint4 = True
+
+        if not split_tp and not split_ep:
+            np_data = sf_file.get_tensor(hf_param_name)
+            return np_data, qint4
+
+        np_data = sf_file.get_slice(hf_param_name)
+        if not split_tp and split_ep:
+            split_data = np_data[self.ep_start:self.ep_stop, :]
+            return split_data, qint4
+
+        shape = np_data.get_shape()
+        if tp_axis == 1:
+            split_size = shape[1] // self.moe_tp_size
+            start = self.moe_tp_rank_id * split_size
+            stop = (self.moe_tp_rank_id + 1) * split_size
+            split_data = np_data[self.ep_start:self.ep_stop, start:stop] if split_ep else np_data[:, start:stop]
+        else:
+            raise ValueError("split_tp is True but tp_axis:{} is not supported.".format(tp_axis))
+        return split_data, qint4
+
     def get_safetensor_from_file(self, hf_param_name, src_hf_dir, hf_weight_map, is_split_param=False,
                                  split_axis=0, split_num=-1, rank_id=-1):
         rank_id = rank_id if rank_id != -1 else self.tp_rank_id
@@ -135,6 +196,11 @@ class BaseWeightProcessor:
             start = rank_id * split_size
             stop = (rank_id + 1) * split_size
             split_data = np_data[:, start:stop]
+        elif split_axis == 2:
+            split_size = shape[2] // split_num
+            start = rank_id * split_size
+            stop = (rank_id + 1) * split_size
+            split_data = np_data[:, :, start:stop]
         else:
             raise ValueError("split_axis:{} is not supported.".format(split_axis))
         return split_data, qint4
-- 
Gitee