From bf29b01c0977e68087b01f1b5d2b2c8686d4a4f1 Mon Sep 17 00:00:00 2001
From: huangzhuo
Date: Thu, 10 Jul 2025 10:25:52 +0800
Subject: [PATCH] vllm_mindspore add a8w4

---
 .../model_executor/model_loader/utils.py      |  26 +-
 .../models/mf_models/deepseek_v3.py           |  25 +-
 .../mf_models/deepseekv3_weight_processor.py  | 464 +++++++++++++++++-
 .../models/mf_models/weight_processor.py      |   1 +
 vllm_mindspore/utils.py                       |   2 +
 5 files changed, 514 insertions(+), 4 deletions(-)

diff --git a/vllm_mindspore/model_executor/model_loader/utils.py b/vllm_mindspore/model_executor/model_loader/utils.py
index 66295a32c..0f4238450 100644
--- a/vllm_mindspore/model_executor/model_loader/utils.py
+++ b/vllm_mindspore/model_executor/model_loader/utils.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 # ============================================================================
 
+import numpy as np
 from typing import Tuple, Type
 
 from torch import nn
@@ -24,7 +25,7 @@ from vllm.config import ModelConfig, ModelImpl
 from vllm.model_executor.models import ModelRegistry
 
 from vllm_mindspore.model_executor.models.registry import MindSporeModelRegistry
-from vllm.model_executor.model_loader.utils import resolve_transformers_fallback
+# from vllm.model_executor.model_loader.utils import resolve_transformers_fallback
 
 def get_ms_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
     architectures = getattr(model_config.hf_config, "architectures", [])
@@ -43,3 +44,26 @@ def get_ms_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module
         raise RecursionError("MindSpore unsupport reward model task now!")
 
     return model_cls, arch
+
+def convert_uint64_to_fp32(arr: np.ndarray):
+    arr_fp32 = arr.view(np.float32)
+    output = arr_fp32[:, :, 0::2]
+    return output
+
+def np_int4data_pack_to_int8_3d(np_data):
+    np_data = np_data.astype(np.int8)
+    np_data &= 0x000F
+    np_data[::, ::, 0::2] <<= 0
+    np_data[::, ::, 1::2] <<= 4
+    np_int4_data = np_data[::, ::, 0::2] | np_data[::, ::, 1::2]
+    return np_int4_data
+
+def unpack_int8_to_int4_3d(packed_data):
+    low_nibbles = (packed_data & 0x0F).astype(np.uint8)
+    high_nibbles = ((packed_data >> 4) & 0x0F).astype(np.uint8)
+
+    unpacked = np.empty((*packed_data.shape[:2], packed_data.shape[2] * 2), dtype=np.uint8)
+    unpacked[..., 0::2] = low_nibbles
+    unpacked[..., 1::2] = high_nibbles
+
+    return unpacked
\ No newline at end of file
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
index f014cff1c..eb1610928 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py
@@ -84,7 +84,7 @@ class DeepseekV3ForCausalLM(MfModelBase, SupportsPP):
         self.sampler = get_sampler()
         self.set_modules({"model": self.network})
         self.num_layers = self.model_config.get_num_layers(self.parallel_config)
-
+
         self.kv_caches = [Fake_MLA() for i in range(self.num_layers)]
 
         compilation_config = get_current_vllm_config().compilation_config
@@ -217,6 +217,29 @@ class DeepseekV3ForCausalLM(MfModelBase, SupportsPP):
                             act_quant_dtype=msdtype.int8, act_quant_granularity=QuantGranularity.PER_TOKEN,
                             opname_blacklist=['lm_head', 'lkv2kv'])
             layer_policies = OrderedDict()
+        elif quant_type.lower() == 'a8w4':
+            cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
+                            act_quant_dtype=msdtype.int8,
+                            outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_LITE,
+                            opname_blacklist=['lm_head', 'lkv2kv'], weight_clip=True)
+            mlp_config = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
+                                   act_quant_dtype=msdtype.int8,
+                                   outliers_suppression=OutliersSuppressionType.NONE,
+                                   precision_recovery=PrecisionRecovery.NONE,
+                                   act_quant_granularity=QuantGranularity.PER_TOKEN,
+                                   weight_quant_granularity=QuantGranularity.PER_CHANNEL,
+                                   weight_clip=True)
+            gptq_config = GPTQQuantConfig(static_groups=True, desc_act=True)
+            moe_cfg = PTQConfig(mode=quant_mode, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.qint4x2,
+                                act_quant_dtype=msdtype.int8, act_quant_granularity=QuantGranularity.PER_TOKEN,
+                                weight_quant_granularity=QuantGranularity.PER_GROUP, group_size=256,
+                                algo_args=gptq_config, precision_recovery=PrecisionRecovery.GPTQ, weight_clip=True)
+            layer_policies = OrderedDict({r'.*\.feed_forward\.w2.*': mlp_config,
+                                          r'.*\.feed_forward\.w_gate_hidden.*': mlp_config,
+                                          r'.*\.shared_experts\.w2.*': mlp_config,
+                                          r'.*\.shared_experts\.w_gate_hidden.*': mlp_config,
+                                          r'.*\.routed_experts\.ffn\.w_gate_hidden.*': moe_cfg,
+                                          r'.*\.routed_experts\.ffn\.w2.*': moe_cfg})
         else:
             logger.warning("Input unsupported quant type: %s.", quant_type)
             return None
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
index bb7766ee9..8c43bbe2e 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py
@@ -25,9 +25,29 @@ from tqdm import tqdm
 import mindspore as ms
 from mindspore import dtype
 from mindspore.communication.management import get_rank
+from vllm_mindspore.model_executor.model_loader.utils import (
+    convert_uint64_to_fp32, unpack_int8_to_int4_3d, np_int4data_pack_to_int8_3d)
 from vllm_mindspore.model_executor.models.mf_models.weight_processor import BaseWeightProcessor
 from vllm_mindspore.utils import convert_np_to_ms_dtype
-
+from vllm.distributed import get_tensor_model_parallel_rank
+
+def np_int8data_unpack_to_int4_3d(np_data):
+    """unpack int8 data to int4 in 3dim"""
+    np_data = np_data.astype(np.uint8)
+    np_data_low = ((np_data & 0x0F) << 4).astype(np.int8) >> 4
+    np_data_high = ((np_data >> 4) << 4).astype(np.int8) >> 4
+
+    np_int4_data = np.zeros(
+        (np_data.shape[0], np_data.shape[1], np_data.shape[2] * 2),
+        dtype=np.int8)
+    np_int4_data[:, :, ::2] = np_data_low
+    np_int4_data[:, :, 1::2] = np_data_high
+    return np_int4_data
+
+def convert_uint64_to_fp32_gmm(uint64_data: np.ndarray) -> np.ndarray:
+    """Convert uint64 data to float32"""
+    uint32_data = uint64_data.astype(np.uint32)
+    return uint32_data.view(np.float32)
 
 class DeepseekV3WeightProcessor(BaseWeightProcessor):
     r"""
@@ -41,6 +61,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
     def __init__(self, config, network, is_quant):
         super().__init__(config, network, is_quant)
         self.num_layers = self.config.model.model_config.num_layers
+        self.col_moe_split_axis = 2 if self.is_310 else 1
+        self.row_moe_split_axis = 1 if self.is_310 else 2
 
     def quant_convert_weight_name(self, weight_name: str):
         """replace quant net weight name"""
@@ -104,6 +126,23 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
             weight_name = weight_name.replace('model.norm.weight', 'model.norm_out.weight')
         return weight_name
 
+    def small_cache_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict):
+        k_nope_weight_name = 
f"model.layers.{layer_id}.attention.lkv2kv_k_nope.weight" + k_nope_weight_param, _ = self.get_safetensor_from_file(k_nope_weight_name, src_hf_dir,hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + qabsorb_matmul_name = f"model.layers.{layer_id}.attention.qabsorb_matmul.weight" + qabsorb_param = k_nope_weight_param.reshape(-1, 128, 512) + parameter_dict[qabsorb_matmul_name] = ms.Parameter(ms.Tensor(qabsorb_param, dtype=ms.float16), + name=qabsorb_matmul_name, requires_grad=False) + + v_weight_name = f"model.layers.{layer_id}.attention.lkv2kv_v.weight" + v_weight_param, _ = self.get_safetensor_from_file(v_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + outabsorb_matmul_name = f"model.layers.{layer_id}.attention.outabsorb_matmul.weight" + outabsorb_param = v_weight_param.reshape(-1, 128, 512) + parameter_dict[outabsorb_matmul_name] = ms.Parameter(ms.Tensor(outabsorb_param, dtype=ms.float16), + name=outabsorb_matmul_name, requires_grad=False) + def infer_trans_rope_weight(self, weight, qk_rope_head_dim): """process rope router weight""" w1 = weight[..., -qk_rope_head_dim::2, :] @@ -1086,6 +1125,321 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): if layer_id >= self.num_layers: self.infer_process_mtp_layer_weight(src_hf_dir, layer_id, hf_weight_map) + def smooth_quant_process_shared_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type): + """smooth_quant_process_shared_ffn_weight""" + + ffn_concat = self.config.model.model_config.ffn_concat + w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight" + w1_weight_param, _ = self.get_safetensor_from_file(w1_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + + w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale" + w1_scale_param, _ = self.get_safetensor_from_file(w1_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + + w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_weight_param, _ = self.get_safetensor_from_file(w3_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale" + w3_scale_param, _ = self.get_safetensor_from_file(w3_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + + w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight" + w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale" + w2_weight_param, _ = self.get_safetensor_from_file(w2_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=1) + w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map) + + if self.is_310: + w_scale_dtype = ms.float32 + else: + w_scale_dtype = ms.bfloat16 + + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), + dtype=w_scale_dtype) + 
parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name, + requires_grad=False) + + parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, w_scale_dtype), + name=w1_scale_name, requires_grad=False) + parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, w_scale_dtype), + name=w3_scale_name, requires_grad=False) + + parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name, + requires_grad=False) + parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, w_scale_dtype), + name=w2_scale_name, requires_grad=False) + + def smooth_quant_process_ffn_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type): + """smooth_quant_process_ffn_weight""" + + ffn_concat = self.config.model.model_config.ffn_concat + w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight" + w1_weight_param, _ = self.get_safetensor_from_file(w1_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale" + w1_scale_param, _ = self.get_safetensor_from_file(w1_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + + w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_weight_param, _ = self.get_safetensor_from_file(w3_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale" + w3_scale_param, _ = self.get_safetensor_from_file(w3_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight" + w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale" + w2_weight_param, _ = self.get_safetensor_from_file(w2_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=1) + w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map) + + if self.is_310: + w_scale_dtype = ms.float32 + else: + w_scale_dtype = ms.bfloat16 + + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], axis=0), dtype=ms.int8) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=0), + dtype=w_scale_dtype) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.int8), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.int8), name=w3_weight_name, + requires_grad=False) + parameter_dict[w1_scale_name] = 
ms.Parameter(ms.Tensor(w1_scale_param, w_scale_dtype), + name=w1_scale_name, requires_grad=False) + parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, w_scale_dtype), + name=w3_scale_name, requires_grad=False) + + parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.int8), name=w2_weight_name, + requires_grad=False) + parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, w_scale_dtype), + name=w2_scale_name, requires_grad=False) + + def smooth_quant_process_qkv_weight(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict): + '''smooth_quant_process_qkv_weight''' + qkv_concat = self.config.model.model_config.qkv_concat + # q2l_proj + q2l_weight_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.weight" + q2l_weight_param, _ = self.get_safetensor_from_file(q2l_weight_name, src_hf_dir, hf_weight_map) + q2l_bias_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.quant_bias" + q2l_bias_param, _ = self.get_safetensor_from_file(q2l_bias_name, src_hf_dir, hf_weight_map) + q2l_scale_name = f"model.layers.{layer_id}.attention.q2l_proj._layer.matmul.dequant_scale" + q2l_scale_param, _ = self.get_safetensor_from_file(q2l_scale_name, src_hf_dir, hf_weight_map) + + q2l_quant_zp = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_zp" + q2l_quant_scale = f"model.layers.{layer_id}.attention.q2l_proj.quant_op.input_scale" + q2l_quant_zp_param, _ = self.get_safetensor_from_file(q2l_quant_zp, src_hf_dir, hf_weight_map) + q2l_quant_scale_param, _ = self.get_safetensor_from_file(q2l_quant_scale, src_hf_dir, hf_weight_map) + + kv2l_weight_name = f"model.layers.{layer_id}.attention.kv2l._layer.weight" + kv2l_weight_param, _ = self.get_safetensor_from_file(kv2l_weight_name, src_hf_dir, hf_weight_map) + kv2l_bias_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.quant_bias" + kv2l_bias_param, _ = self.get_safetensor_from_file(kv2l_bias_name, src_hf_dir, hf_weight_map) + kv2l_scale_name = f"model.layers.{layer_id}.attention.kv2l._layer.matmul.dequant_scale" + kv2l_scale_param, _ = self.get_safetensor_from_file(kv2l_scale_name, src_hf_dir, hf_weight_map) + + kv2l_quant_zp = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_zp" + kv2l_quant_scale = f"model.layers.{layer_id}.attention.kv2l.quant_op.input_scale" + kv2l_quant_zp_param, _ = self.get_safetensor_from_file(kv2l_quant_zp, src_hf_dir, hf_weight_map) + kv2l_quant_scale_param, _ = self.get_safetensor_from_file(kv2l_quant_scale, src_hf_dir, hf_weight_map) + + if self.is_310: + q2l_scale_param = q2l_scale_param.astype(np.float32).view(np.int32).astype(np.int64) + kv2l_scale_param = kv2l_scale_param.astype(np.float32).view(np.int32).astype(np.int64) + deq_scale_dtype = ms.int64 + quant_scale_dtype = ms.float16 + else: + deq_scale_dtype = ms.float32 + quant_scale_dtype = ms.bfloat16 + + if qkv_concat: + qkv2l_weight_name = f"model.layers.{layer_id}.attention.qkv2l._layer.weight" + qkv2l_bias_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.quant_bias" + qkv2l_scale_name = f"model.layers.{layer_id}.attention.qkv2l._layer.matmul.dequant_scale" + qkv2l_quant_zp_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_zp" + qkv2l_quant_scale_name = f"model.layers.{layer_id}.attention.qkv2l.quant_op.input_scale" + + qkv2l_weight = np.concatenate((q2l_weight_param, kv2l_weight_param), 0) + parameter_dict[qkv2l_weight_name] = ms.Parameter(ms.Tensor(qkv2l_weight, ms.int8), name=qkv2l_weight_name, + requires_grad=False) + qkv2l_bias = 
np.concatenate((q2l_bias_param, kv2l_bias_param), 0) + parameter_dict[qkv2l_bias_name] = ms.Parameter(ms.Tensor(qkv2l_bias, ms.int32), name=qkv2l_bias_name, + requires_grad=False) + qkv2l_scale = np.concatenate((q2l_scale_param, kv2l_scale_param), 0) + parameter_dict[qkv2l_scale_name] = ms.Parameter(ms.Tensor(qkv2l_scale, deq_scale_dtype), name=qkv2l_scale_name, + requires_grad=False) + parameter_dict[qkv2l_quant_zp_name] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8), + name=qkv2l_quant_zp_name, requires_grad=False) + parameter_dict[qkv2l_quant_scale_name] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, quant_scale_dtype), + name=qkv2l_quant_scale_name, requires_grad=False) + else: + parameter_dict[q2l_weight_name] = ms.Parameter(ms.Tensor(q2l_weight_param, ms.int8), name=q2l_weight_name, + requires_grad=False) + parameter_dict[kv2l_weight_name] = ms.Parameter(ms.Tensor(kv2l_weight_param, ms.int8), + name=kv2l_weight_name, requires_grad=False) + parameter_dict[q2l_bias_name] = ms.Parameter(ms.Tensor(q2l_bias_param, ms.int32), name=q2l_bias_name, + requires_grad=False) + parameter_dict[kv2l_bias_name] = ms.Parameter(ms.Tensor(kv2l_bias_param, ms.int32), name=kv2l_bias_name, + requires_grad=False) + parameter_dict[q2l_scale_name] = ms.Parameter(ms.Tensor(q2l_scale_param, deq_scale_dtype), + name=q2l_scale_name, requires_grad=False) + parameter_dict[kv2l_scale_name] = ms.Parameter(ms.Tensor(kv2l_scale_param, deq_scale_dtype), + name=kv2l_scale_name, requires_grad=False) + parameter_dict[q2l_quant_zp] = ms.Parameter(ms.Tensor(q2l_quant_zp_param, ms.int8), name=q2l_quant_zp, + requires_grad=False) + parameter_dict[kv2l_quant_zp] = ms.Parameter(ms.Tensor(kv2l_quant_zp_param, ms.int8), name=kv2l_quant_zp, + requires_grad=False) + parameter_dict[q2l_quant_scale] = ms.Parameter(ms.Tensor(q2l_quant_scale_param, quant_scale_dtype), + name=q2l_quant_scale, requires_grad=False) + parameter_dict[kv2l_quant_scale] = ms.Parameter(ms.Tensor(kv2l_quant_scale_param, quant_scale_dtype), + name=kv2l_quant_scale, requires_grad=False) + + def process_route_ffn_weight_a8w4(self, src_hf_dir, layer_id, hf_weight_map, parameter_dict, layer_type): + """qqq_process_route_ffn_weight""" + + ffn_concat = self.config.model.model_config.ffn_concat + w1_weight_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.weight" + w1_scale_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.weight_scale" + w1_bias_name = f"model.layers.{layer_id}.{layer_type}.w1._layer.matmul.gmm_bias" + + w3_weight_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.weight" + w3_scale_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.weight_scale" + w3_bias_name = f"model.layers.{layer_id}.{layer_type}.w3._layer.matmul.gmm_bias" + + w2_weight_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.weight" + w2_scale_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.weight_scale" + w2_bias_name = f"model.layers.{layer_id}.{layer_type}.w2._layer.matmul.gmm_bias" + + w1_weight_param, _ = self.get_safetensor_from_file(w1_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, + split_axis=self.col_moe_split_axis) + w1_scale_param, _ = self.get_safetensor_from_file(w1_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=2) + w1_scale_repeat = convert_uint64_to_fp32_gmm( + np.repeat(w1_scale_param, + w1_weight_param.shape[1] // w1_scale_param.shape[1], + axis=1)) + w1_weight_unpack = np_int8data_unpack_to_int4_3d(w1_weight_param) + w1_bias_param = 
8 * np.sum( + w1_weight_unpack.astype(np.float32) * w1_scale_repeat, axis=1) + # w1_bias_param, _ = self.get_safetensor_from_file(w1_bias_name, src_hf_dir, hf_weight_map, + # is_split_param=self.is_split_param, split_axis=1) + + w3_weight_param, _ = self.get_safetensor_from_file(w3_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, + split_axis=self.col_moe_split_axis) + w3_scale_param, _ = self.get_safetensor_from_file(w3_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=2) + w3_scale_repeat = convert_uint64_to_fp32_gmm( + np.repeat(w3_scale_param, + w3_weight_param.shape[1] // w3_scale_param.shape[1], + axis=1)) + w3_weight_unpack = np_int8data_unpack_to_int4_3d(w3_weight_param) + w3_bias_param = 8 * np.sum( + w3_weight_unpack.astype(np.float32) * w3_scale_repeat, axis=1) + # w3_bias_param, _ = self.get_safetensor_from_file(w3_bias_name, src_hf_dir, hf_weight_map, + # is_split_param=self.is_split_param, split_axis=1) + + w2_weight_param, _ = self.get_safetensor_from_file(w2_weight_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, + split_axis=self.row_moe_split_axis) + w2_scale_param, _ = self.get_safetensor_from_file(w2_scale_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=1) + w2_scale_repeat = convert_uint64_to_fp32_gmm( + np.repeat(w2_scale_param, + w2_weight_param.shape[1] // w2_scale_param.shape[1], + axis=1)) + w2_weight_unpack = np_int8data_unpack_to_int4_3d(w2_weight_param) + w2_bias_param = 8 * np.sum( + w2_weight_unpack.astype(np.float32) * w2_scale_repeat, axis=1) + # w2_bias_param, _ = self.get_safetensor_from_file(w2_bias_name, src_hf_dir, hf_weight_map, + # is_split_param=self.is_split_param, split_axis=1) + + if self.is_310: + w1_weight_param = w1_weight_param.astype(np.int8) + w1_weight_param = unpack_int8_to_int4_3d(w1_weight_param) + w1_weight_param = np_int4data_pack_to_int8_3d(w1_weight_param.transpose(0, 2, 1)) + + w2_weight_param = w2_weight_param.astype(np.int8) + w2_weight_param = unpack_int8_to_int4_3d(w2_weight_param) + w2_weight_param = np_int4data_pack_to_int8_3d(w2_weight_param.transpose(0, 2, 1)) + + w3_weight_param = w3_weight_param.astype(np.int8) + w3_weight_param = unpack_int8_to_int4_3d(w3_weight_param) + w3_weight_param = np_int4data_pack_to_int8_3d(w3_weight_param.transpose(0, 2, 1)) + + w1_scale_param = convert_uint64_to_fp32(w1_scale_param) + w2_scale_param = convert_uint64_to_fp32(w2_scale_param) + w3_scale_param = convert_uint64_to_fp32(w3_scale_param) + w_scale_dtype = ms.float32 + else: + w_scale_dtype = ms.uint64 + + if ffn_concat: + concat_weight_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.weight" + concat_weight_param = ms.Tensor(np.concatenate([w1_weight_param, w3_weight_param], + axis=1), dtype=ms.qint4x2) + parameter_dict[concat_weight_name] = ms.Parameter(concat_weight_param, name=concat_weight_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.weight_scale" + concat_scale_param = ms.Tensor(np.concatenate([w1_scale_param, w3_scale_param], axis=2), dtype=w_scale_dtype) + parameter_dict[concat_scale_name] = ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + + concat_scale_name = f"model.layers.{layer_id}.{layer_type}.w_gate_hidden._layer.matmul.gmm_bias" + concat_scale_param = ms.Tensor(np.concatenate([w1_bias_param, w3_bias_param], axis=1), dtype=ms.float32) + parameter_dict[concat_scale_name] = 
ms.Parameter(concat_scale_param, name=concat_scale_name, + requires_grad=False) + else: + # w1 w3 + parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(w1_weight_param, ms.qint4x2), name=w1_weight_name, + requires_grad=False) + parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(w3_weight_param, ms.qint4x2), name=w3_weight_name, + requires_grad=False) + + parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(w1_scale_param, w_scale_dtype), + name=w1_scale_name, requires_grad=False) + parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(w3_scale_param, w_scale_dtype), + name=w3_scale_name, requires_grad=False) + + parameter_dict[w1_bias_name] = ms.Parameter(ms.Tensor(w1_bias_param, ms.float32), + name=w1_bias_name, requires_grad=False) + parameter_dict[w3_bias_name] = ms.Parameter(ms.Tensor(w3_bias_param, ms.float32), + name=w3_bias_name, requires_grad=False) + + parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(w2_weight_param, ms.qint4x2), name=w2_weight_name, + requires_grad=False) + parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(w2_scale_param, w_scale_dtype), + name=w2_scale_name, requires_grad=False) + parameter_dict[w2_bias_name] = ms.Parameter(ms.Tensor(w2_bias_param, ms.float32), + name=w2_bias_name, requires_grad=False) + def infer_smooth_quant_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map): """infer_smooth_quant_net_ms_convert_layer_weight""" parameter_dict = {} @@ -1225,6 +1579,108 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): name=param_name, requires_grad=False) _, _ = ms.load_param_into_net(self.network, parameter_dict) + def infer_smooth_quant_row_linear_split(self, param_name, src_hf_dir, hf_weight_map): + '''infer_smooth_quant_row_linear_split''' + if param_name.endswith(".weight"): + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=1) + elif "quant_op" in param_name: + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + else: + value, _ = self.get_safetensor_from_file(param_name, src_hf_dir, + hf_weight_map) + if "wo._layer.matmul.quant_bias" in param_name and get_tensor_model_parallel_rank() != 0: + value.fill(0) + return value + + def infer_a8w4_get_value(self, param_name, src_hf_dir, hf_weight_map, no_need_split_layer): + '''infer_smooth_quant_get_value''' + + if any([name in param_name for name in no_need_split_layer]): + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, + hf_weight_map) + elif any([name in param_name for name in [".l2q_proj."]]): + if param_name.endswith(".weight") or "matmul" in param_name: + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + if self.is_310 and "dequant_scale" in param_name: + value = value.view(np.int32).astype(np.int64) + else: + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, + hf_weight_map) + elif any([name in param_name for name in [".feed_forward.w2.", ".wo.", "shared_experts.w2"]]): + value = self.infer_smooth_quant_row_linear_split(param_name, src_hf_dir, hf_weight_map) + if self.is_310 and "dequant_scale" in param_name: + value = value.view(np.int32).astype(np.int64) + is_int4 = False + elif ".routed_experts.ffn.w2" in param_name: + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=1) + 
elif any([name in param_name for name in ["lkv2kv_k_nope", "lkv2kv_v", "absorb"]]): + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + elif "lm_head" in param_name: + if not self.config.parallel_config.vocab_emb_dp: + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map, + is_split_param=self.is_split_param, split_axis=0) + else: + value, is_int4 = self.get_safetensor_from_file(param_name, src_hf_dir, hf_weight_map) + else: + raise ValueError(f"not found layer {param_name}, please check safetensors file.") + return value, is_int4 + + def infer_a8w4_net_ms_convert_layer_weight(self, src_hf_dir, num_layers, hf_weight_map): + '''infer_qqq_net_ms_convert_layer_weight''' + parameter_dict = {} + start_layer_index, end_layer_index = self.get_layer_index(num_layers) + network_names = [] + for m in self.network.parameters_and_names(): + network_names.append(m[0]) + no_need_split_layer = ["tok_embeddings", "norm", "routed_experts.router.dense", + "routed_experts.router.e_score_correction_bias", + "topk_bias"] + for layer_id in tqdm(range(start_layer_index, end_layer_index), desc="qkv/ffn params load"): + if layer_id >= 3: + self.process_route_ffn_weight_a8w4(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward.routed_experts.ffn") + self.smooth_quant_process_shared_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward.shared_experts") + else: + self.smooth_quant_process_ffn_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict, + "feed_forward") + self.smooth_quant_process_qkv_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict) + + if self.is_310: + self.small_cache_weight(src_hf_dir, layer_id, hf_weight_map, parameter_dict) + + skip_layer = ["feed_forward.routed_experts.ffn", "feed_forward.shared_experts", "feed_forward.w", + "attention.kv2l", "attention.q2l", "attention.qkv"] + + for param_name, _ in tqdm(hf_weight_map.items(), desc="remaining params load"): + if param_name not in network_names: + continue + + if "model.layers" in param_name and int(param_name.split('.')[2]) >= num_layers: + continue + + if any([name in param_name for name in skip_layer]): + continue + + value, is_int4 = self.infer_a8w4_get_value(param_name, src_hf_dir, hf_weight_map, no_need_split_layer) + dst_dtype = convert_np_to_ms_dtype(value) + + if is_int4: + parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dtype.qint4x2), + name=param_name, requires_grad=False) + else: + parameter_dict[param_name] = ms.Parameter(ms.Tensor(value, dtype=dst_dtype), + name=param_name, requires_grad=False) + + param_not_load, ckpt_not_load = ms.load_param_into_net(self.network, parameter_dict) + print(f"a8w4 param_not_load:{param_not_load}") + print(f"a8w4 ckpt_not_load:{ckpt_not_load}") + def load_safetensors_shard(self, src_hf_dir, is_mtp_model=False): """deepseek load safetensors and shard """ rank_id = get_rank() @@ -1256,7 +1712,8 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor): quantization_config = self.config.model.model_config.quantization_config quant_method = quantization_config.quant_method if quantization_config else None - if not quant_method or (quant_method != "gptq-pergroup" and quant_method != "smoothquant") and \ + support_quant_method = ["gptq-pergroup", "smoothquant", "a8w4"] + if not quant_method or (quant_method not in support_quant_method) and \ not is_mtp_model: self.infer_convert_outer_weight(src_hf_dir, hf_weight_map) 
@@ -1266,6 +1723,9 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
         if quant_method and quant_method == "smoothquant":
             self.infer_smooth_quant_net_ms_convert_layer_weight(src_hf_dir, self.num_layers, hf_weight_map)
             return
+        if quant_method and quant_method == "a8w4":
+            self.infer_a8w4_net_ms_convert_layer_weight(src_hf_dir, self.num_layers, hf_weight_map)
+            return
 
         enable_tqdm = rank_id == 0
         mtp_layers = self.config.model.model_config.num_nextn_predict_layers
diff --git a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
index 5459a6033..28b4e5e08 100644
--- a/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
+++ b/vllm_mindspore/model_executor/models/mf_models/weight_processor.py
@@ -43,6 +43,7 @@ class BaseWeightProcessor:
         self.is_310 = is_310p()
         self.pp_group_size = get_pp_world_size()
         # self.tp_group_size = get_group_size()
+        self.pp_group_size = get_pp_world_size()
         self.tp_group_size = get_tensor_model_parallel_world_size()
         self.global_rank_id = get_rank()
         self.rank_id = get_tensor_model_parallel_rank()
diff --git a/vllm_mindspore/utils.py b/vllm_mindspore/utils.py
index f82867c1f..014990fa8 100644
--- a/vllm_mindspore/utils.py
+++ b/vllm_mindspore/utils.py
@@ -256,6 +256,8 @@ def convert_np_to_ms_dtype(value):
         value_dtype = ms.int32
     elif value.dtype == np.int64:
         value_dtype = ms.int64
+    elif value.dtype == np.uint64:
+        value_dtype = ms.uint64
     elif value.dtype == np.float64:
         value_dtype = ms.float64
     elif value.dtype == np.float32:
--
Gitee
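
Reviewer note on the int4 packing helpers (np_int4data_pack_to_int8_3d, unpack_int8_to_int4_3d, np_int8data_unpack_to_int4_3d): the sketch below is a minimal, numpy-only illustration of the nibble layout these helpers assume -- even-indexed int4 values in the low nibble, odd-indexed in the high nibble, sign-extended on unpack. It mirrors the helpers rather than importing them so it runs without MindSpore or vllm_mindspore installed; the toy tensor shape is illustrative only.

    import numpy as np

    # Toy routed-expert weight of signed int4 values in [-8, 7], shaped
    # (num_experts, in_dim, out_dim) like the MoE tensors handled by the patch.
    int4_vals = np.array([[[-8, -3, 0, 7],
                           [ 1, -1, 6, -5]]], dtype=np.int8)

    # Pack two int4 values per int8 byte (mirrors np_int4data_pack_to_int8_3d):
    # even-indexed elements fill the low nibble, odd-indexed the high nibble.
    nibbles = int4_vals & 0x0F
    packed = (nibbles[..., 0::2] | (nibbles[..., 1::2] << 4)).astype(np.int8)
    assert packed.shape == (1, 2, 2)

    # Sign-extending unpack (mirrors np_int8data_unpack_to_int4_3d): shift each
    # nibble to the top of an int8, then arithmetic-shift it back down.
    u = packed.view(np.uint8)
    low = ((u & 0x0F) << 4).astype(np.int8) >> 4
    high = ((u >> 4) << 4).astype(np.int8) >> 4
    unpacked = np.empty_like(int4_vals)
    unpacked[..., 0::2] = low
    unpacked[..., 1::2] = high

    assert np.array_equal(unpacked, int4_vals)   # lossless round trip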
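
Reviewer note on the a8w4 gmm_bias term in process_route_ffn_weight_a8w4: the per-group weight scales arrive as uint64 whose low 32 bits hold the fp32 bit pattern, are expanded with np.repeat over the reduction dimension, and are then folded into a bias of the form 8 * sum(w_int4 * scale) per (expert, output channel). The sketch below reproduces that arithmetic on dummy data; the reading that the factor 8 compensates a fixed activation zero-point consumed by the grouped-matmul kernel is an assumption, and fp32_to_uint64 is a hypothetical test helper, not part of the patch.

    import numpy as np

    def fp32_to_uint64(x):
        # Hypothetical test helper: place the fp32 bit pattern in the low 32
        # bits of a uint64, which is what convert_uint64_to_fp32_gmm reads back.
        return np.asarray(x, dtype=np.float32).view(np.uint32).astype(np.uint64)

    E, K, N, GROUP = 1, 4, 2, 2          # experts, reduction dim, out dim, group size
    rng = np.random.default_rng(0)
    w_int4 = rng.integers(-8, 8, size=(E, K, N)).astype(np.float32)  # already-unpacked int4 weights
    scales = rng.uniform(0.01, 0.1, size=(E, K // GROUP, N)).astype(np.float32)
    scales_u64 = fp32_to_uint64(scales)  # layout as stored in the safetensors file

    # Mirror of process_route_ffn_weight_a8w4: broadcast each per-group scale
    # over its GROUP rows, reinterpret the low 32 bits as fp32, and fold the
    # constant offset (8) into a per-(expert, out_channel) gmm_bias term.
    scale_rep = np.repeat(scales_u64, w_int4.shape[1] // scales_u64.shape[1], axis=1)
    scale_fp32 = scale_rep.astype(np.uint32).view(np.float32)
    gmm_bias = 8 * np.sum(w_int4 * scale_fp32, axis=1)               # shape (E, N)

    # Cross-check against the same reduction done directly on the fp32 scales.
    ref = 8 * np.sum(w_int4 * np.repeat(scales, GROUP, axis=1), axis=1)
    assert np.allclose(gmm_bias, ref)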