diff --git a/.jenkins/test/config/dependent_packages.yaml b/.jenkins/test/config/dependent_packages.yaml index 4966b89a21dbb9291a118ed2012619ed8efe43ac..b2823dcacb5fff230aa015decbfa82b55a26ac7e 100644 --- a/.jenkins/test/config/dependent_packages.yaml +++ b/.jenkins/test/config/dependent_packages.yaml @@ -1,11 +1,11 @@ mindspore: - 'https://repo.mindspore.cn/mindspore/mindspore/version/202506/20250605/master_20250605212230_aac98ab9732926f6abd4c3d73be47d5be6c93ead_newest/' + 'https://repo.mindspore.cn/mindspore/mindspore/version/202507/20250715/br_infer_iter_20250715031508_4a1c05507ed221ad7557813816990c8e20d7d090_newest/' mindspore_gs: - 'https://repo.mindspore.cn/mindspore/golden-stick/version/202507/20250709/master_20250709010018_5f01a0211ca36690a577d3d456c5ba194c88771d_newest/' + 'https://repo.mindspore.cn/mindspore/golden-stick/version/202507/20250714/develop_20250714153506_7026d5afce6b611d3ec2653bee26a263dead90b8_newest/' msadapter: 'https://repo.mindspore.cn/mindspore/msadapter/version/202505/20250526/master_20250526120007_b76cb7804d1c9555e32a57439c1d412ff86293d1_newest/' vllm: - 'https://repo.mindspore.cn/mirrors/vllm/version/202505/20250514/v0.8.4.dev0_newest/' + 'https://repo.mindspore.cn/mirrors/vllm/version/202505/20250514/v0.8.4.dev0_newest/' \ No newline at end of file diff --git a/tests/mindformers b/tests/mindformers index 3e257a44384b927bc0fe26348047d7fe44a954db..e7fafe66254d03a88f7414ceb736443773c4c525 160000 --- a/tests/mindformers +++ b/tests/mindformers @@ -1 +1 @@ -Subproject commit 3e257a44384b927bc0fe26348047d7fe44a954db +Subproject commit e7fafe66254d03a88f7414ceb736443773c4c525 diff --git a/tests/st/python/cases_parallel/vllm_deepseek_a16w4.py b/tests/st/python/cases_parallel/vllm_deepseek_a16w4.py new file mode 100644 index 0000000000000000000000000000000000000000..779f1c668c1c1deb751237c5add11003decfec54 --- /dev/null +++ b/tests/st/python/cases_parallel/vllm_deepseek_a16w4.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""test mf deepseek r1 gptq int4 quantization.""" + +# type: ignore +# isort: skip_file + +import os +import yaml +from tests.st.python import utils + + +def teardown_function(): + utils.cleanup_subprocesses() + + +env_manager = utils.EnvVarManager() +# def env +env_vars = { + "MINDFORMERS_MODEL_CONFIG": "./config/predict_deepseek_r1_671b_a16w4.yaml", + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "vLLM_MODEL_BACKEND": "MindFormers", + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", + "VLLM_USE_V1": "0" +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore # noqa: F401, E402 +from vllm import LLM, SamplingParams # noqa: E402 + + +def test_deepseek_r1_gptq_a16w4(): + """ + test case deepseek r1 a16w4 + """ + yaml_path = "./config/predict_deepseek_r1_671b.yaml" + a16w4_yaml = "./config/predict_deepseek_r1_671b_a16w4.yaml" + with open(yaml_path, encoding='utf-8') as file: + content = yaml.safe_load(file) + model_config = content["model"]["model_config"] + model_config["quantization_config"] = {"quant_method": "gptq-pergroup"} + content["model"]["model_config"] = model_config + + with open(a16w4_yaml, 'w', encoding='utf-8') as file: + yaml.dump(content, file, allow_unicode=True, sort_keys=False) + + # Sample prompts. + prompts = [ + "介绍下北京故宫", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=1024, top_k=1) + + # Create an LLM. + llm = LLM( + model= + "/home/workspace/mindspore_dataset/weight/DeepSeekR1_gptq-pergroup_safetensors", + trust_remote_code=True, + gpu_memory_utilization=0.9, + tensor_parallel_size=4, + max_model_len=4096) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert "博物院christianాలు sic辨" in generated_text + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/cases_parallel/vllm_deepseek_a8w4.py b/tests/st/python/cases_parallel/vllm_deepseek_a8w4.py new file mode 100644 index 0000000000000000000000000000000000000000..06bbd3377b55225ac5aa976d93e1f343954b5433 --- /dev/null +++ b/tests/st/python/cases_parallel/vllm_deepseek_a8w4.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""test mf deepseek r1 a8w4 quantization.""" + +# type: ignore +# isort: skip_file + +import os +import yaml +from tests.st.python import utils + + +def teardown_function(): + utils.cleanup_subprocesses() + + +env_manager = utils.EnvVarManager() +# def env +env_vars = { + "MINDFORMERS_MODEL_CONFIG": "./config/predict_deepseek_r1_671b_a8w4.yaml", + "ASCEND_CUSTOM_PATH": os.path.expandvars("$ASCEND_HOME_PATH/../"), + "vLLM_MODEL_BACKEND": "MindFormers", + "MS_ENABLE_LCCL": "off", + "HCCL_OP_EXPANSION_MODE": "AIV", + "MS_ALLOC_CONF": "enable_vmm:True", + "LCCL_DETERMINISTIC": "1", + "HCCL_DETERMINISTIC": "true", + "ATB_MATMUL_SHUFFLE_K_ENABLE": "0", + "ATB_LLM_LCOC_ENABLE": "0", + "VLLM_USE_V1": "0" +} +# set env +env_manager.setup_ai_environment(env_vars) +import vllm_mindspore # noqa: F401, E402 +from vllm import LLM, SamplingParams # noqa: E402 + + +def test_deepseek_r1_a8w4(): + """ + test case deepseek r1 a8w4 + """ + yaml_path = "./config/predict_deepseek_r1_671b.yaml" + a8w4_yaml = "./config/predict_deepseek_r1_671b_a8w4.yaml" + with open(yaml_path, encoding='utf-8') as file: + content = yaml.safe_load(file) + model_config = content["model"]["model_config"] + model_config["quantization_config"] = {"quant_method": "a8w4"} + content["model"]["model_config"] = model_config + + with open(a8w4_yaml, 'w', encoding='utf-8') as file: + yaml.dump(content, file, allow_unicode=True, sort_keys=False) + + # Sample prompts. + prompts = [ + "介绍下北京故宫", + ] + + # Create a sampling params object. + sampling_params = SamplingParams(temperature=0.0, max_tokens=1024, top_k=1) + + # Create an LLM. + llm = LLM( + model= + "/home/workspace/mindspore_dataset/weight/DeepSeekR1_A8W4_safetensors", + trust_remote_code=True, + gpu_memory_utilization=0.9, + tensor_parallel_size=4, + max_model_len=4096) + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. 
+ for i, output in enumerate(outputs): + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + assert "博物院ODాలు SER비스티rok等地" in generated_text + + # unset env + env_manager.unset_all() diff --git a/tests/st/python/test_cases_parallel.py b/tests/st/python/test_cases_parallel.py index 4bca8b6806f1d3e5642e4fa9782a5c148c8603cc..7540e35360cf01fa1f98d5ebf12c42eb0aec70ef 100644 --- a/tests/st/python/test_cases_parallel.py +++ b/tests/st/python/test_cases_parallel.py @@ -173,7 +173,9 @@ def test_cases_parallel_part5(): cases = [(2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v0", "vllm_mf_qwen3_8b_test_mf_qwen3.log"), (2, "cases_parallel/vllm_mf_qwen3_8b.py::test_mf_qwen3_v1", - "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log")] + "vllm_mf_qwen3_8b_v1_test_mf_qwen3.log"), + (2, "cases_parallel/vllm_deepseek_a8w4.py::test_deepseek_r1_a8w4", + "vllm_deepseek_a8w4_test_deepseek_r1_a8w4.log")] run_tasks(cases) diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py index fca2faca15390b306693a49486f201fae16b3f27..933cea12577e84a223f66b5a9bf5ae75bc072619 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseek_v3.py @@ -350,6 +350,51 @@ class DeepseekV3ForCausalLM(MfModelBase): act_quant_granularity=QuantGranularity.PER_TOKEN, opname_blacklist=['lm_head', 'lkv2kv']) layer_policies = OrderedDict() + elif quant_type.lower() == 'a8w4': + cfg = PTQConfig(mode=quant_mode, + backend=BackendTarget.ASCEND, + weight_quant_dtype=msdtype.int8, + act_quant_dtype=msdtype.int8, + outliers_suppression=OutliersSuppressionType. + OUTLIER_SUPPRESSION_LITE, + opname_blacklist=['lm_head', 'lkv2kv'], + weight_clip=True) + mlp_config = PTQConfig( + mode=quant_mode, + backend=BackendTarget.ASCEND, + weight_quant_dtype=msdtype.int8, + act_quant_dtype=msdtype.int8, + outliers_suppression=OutliersSuppressionType.NONE, + precision_recovery=PrecisionRecovery.NONE, + act_quant_granularity=QuantGranularity.PER_TOKEN, + weight_quant_granularity=QuantGranularity.PER_CHANNEL, + weight_clip=True) + gptq_config = GPTQQuantConfig(static_groups=True, desc_act=True) + moe_cfg = PTQConfig( + mode=quant_mode, + backend=BackendTarget.ASCEND, + weight_quant_dtype=msdtype.qint4x2, + act_quant_dtype=msdtype.int8, + act_quant_granularity=QuantGranularity.PER_TOKEN, + weight_quant_granularity=QuantGranularity.PER_GROUP, + group_size=256, + algo_args=gptq_config, + precision_recovery=PrecisionRecovery.GPTQ, + weight_clip=True) + layer_policies = OrderedDict({ + r'.*\.feed_forward\.w2.*': + mlp_config, + r'.*\.feed_forward\.w_gate_hidden.*': + mlp_config, + r'.*\.shared_experts\.w2.*': + mlp_config, + r'.*\.shared_experts\.w_gate_hidden.*': + mlp_config, + r'.*\.routed_experts\.ffn\.w_gate_hidden.*': + moe_cfg, + r'.*\.routed_experts\.ffn\.w2.*': + moe_cfg + }) else: logger.warning("Input unsupported quant type: %s.", quant_type) return None diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py index b344eefe58aae1c72af65de992fa490e0e447202..d67174108780bbbb14b79496a9827a46f8d30b87 100644 --- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py +++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_weight_processor.py @@ -53,6 +53,26 @@ def 
convert_np_to_ms_dtype(value):
     return value_dtype
 
 
+def np_int8data_unpack_to_int4_3d(np_data):
+    """unpack int8 data to int4 in 3dim"""
+    np_data = np_data.astype(np.uint8)
+    np_data_low = ((np_data & 0x0F) << 4).astype(np.int8) >> 4
+    np_data_high = ((np_data >> 4) << 4).astype(np.int8) >> 4
+
+    np_int4_data = np.zeros(
+        (np_data.shape[0], np_data.shape[1], np_data.shape[2] * 2),
+        dtype=np.int8)
+    np_int4_data[:, :, ::2] = np_data_low
+    np_int4_data[:, :, 1::2] = np_data_high
+    return np_int4_data
+
+
+def convert_uint64_to_fp32(uint64_data: np.ndarray) -> np.ndarray:
+    """Convert uint64 data to float32"""
+    uint32_data = uint64_data.astype(np.uint32)
+    return uint32_data.view(np.float32)
+
+
 class DeepseekV3WeightProcessor(BaseWeightProcessor):
     r"""
     Provide DeepseekV3/R1 Model weight load and shards.
@@ -2173,6 +2193,263 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
                                                   requires_grad=False)
         _, _ = ms.load_param_into_net(self.network, parameter_dict)
 
+    def process_route_ffn_weight_a8w4(self, src_hf_dir, layer_id,
+                                      hf_weight_map, parameter_dict,
+                                      layer_type):
+        """process a8w4 routed experts ffn weight"""
+        base_param = f"model.layers.{layer_id}.{layer_type}"
+        ffn_concat = self.config.model.model_config.ffn_concat
+        w1_weight_name = f"{base_param}.w1._layer.weight"
+        w1_scale_name = f"{base_param}.w1._layer.matmul.weight_scale"
+        w1_bias_name = f"{base_param}.w1._layer.matmul.gmm_bias"
+
+        w3_weight_name = f"{base_param}.w3._layer.weight"
+        w3_scale_name = f"{base_param}.w3._layer.matmul.weight_scale"
+        w3_bias_name = f"{base_param}.w3._layer.matmul.gmm_bias"
+
+        w2_weight_name = f"{base_param}.w2._layer.weight"
+        w2_scale_name = f"{base_param}.w2._layer.matmul.weight_scale"
+        w2_bias_name = f"{base_param}.w2._layer.matmul.gmm_bias"
+
+        w1_weight_param, _ = self.get_routed_safetensor_3_dim(
+            w1_weight_name,
+            src_hf_dir,
+            hf_weight_map,
+            tp_axis=2,
+            split_ep=self.moe_split_ep,
+            split_tp=self.moe_split_tp)
+        w1_scale_param, _ = self.get_routed_safetensor_3_dim(
+            w1_scale_name,
+            src_hf_dir,
+            hf_weight_map,
+            tp_axis=2,
+            split_ep=self.moe_split_ep,
+            split_tp=self.moe_split_tp)
+        w1_scale_repeat = convert_uint64_to_fp32(
+            np.repeat(w1_scale_param,
+                      w1_weight_param.shape[1] // w1_scale_param.shape[1],
+                      axis=1))
+        w1_weight_unpack = np_int8data_unpack_to_int4_3d(w1_weight_param)
+        w1_bias_param = 8 * np.sum(
+            w1_weight_unpack.astype(np.float32) * w1_scale_repeat, axis=1)
+
+        w3_weight_param, _ = self.get_routed_safetensor_3_dim(
+            w3_weight_name,
+            src_hf_dir,
+            hf_weight_map,
+            tp_axis=2,
+            split_ep=self.moe_split_ep,
+            split_tp=self.moe_split_tp)
+        w3_scale_param, _ = self.get_routed_safetensor_3_dim(
+            w3_scale_name,
+            src_hf_dir,
+            hf_weight_map,
+            tp_axis=2,
+            split_ep=self.moe_split_ep,
+            split_tp=self.moe_split_tp)
+        w3_scale_repeat = convert_uint64_to_fp32(
+            np.repeat(w3_scale_param,
+                      w3_weight_param.shape[1] // w3_scale_param.shape[1],
+                      axis=1))
+        w3_weight_unpack = np_int8data_unpack_to_int4_3d(w3_weight_param)
+        w3_bias_param = 8 * np.sum(
+            w3_weight_unpack.astype(np.float32) * w3_scale_repeat, axis=1)
+
+        w2_weight_param, _ = self.get_routed_safetensor_3_dim(
+            w2_weight_name,
+            src_hf_dir,
+            hf_weight_map,
+            tp_axis=1,
+            split_ep=self.moe_split_ep,
+            split_tp=self.moe_split_tp)
+        w2_scale_param, _ = self.get_routed_safetensor_3_dim(
+            w2_scale_name,
+            src_hf_dir,
+            hf_weight_map,
+            tp_axis=1,
+            split_ep=self.moe_split_ep,
+            split_tp=self.moe_split_tp)
+        w2_scale_repeat = convert_uint64_to_fp32(
+            np.repeat(w2_scale_param,
+                      w2_weight_param.shape[1] // w2_scale_param.shape[1],
+                      axis=1))
+        w2_weight_unpack = np_int8data_unpack_to_int4_3d(w2_weight_param)
+        w2_bias_param = 8 * np.sum(
+            w2_weight_unpack.astype(np.float32) * w2_scale_repeat, axis=1)
+
+        if ffn_concat:
+            concat_weight_name = f"{base_param}.w_gate_hidden._layer.weight"
+            concat_weight_param = ms.Tensor(np.concatenate(
+                [w1_weight_param, w3_weight_param], axis=2),
+                                            dtype=ms.qint4x2)
+            parameter_dict[concat_weight_name] = ms.Parameter(
+                concat_weight_param,
+                name=concat_weight_name,
+                requires_grad=False)
+
+            concat_scale_name = \
+                f"{base_param}.w_gate_hidden._layer.matmul.weight_scale"
+            concat_scale_param = ms.Tensor(np.concatenate(
+                [w1_scale_param, w3_scale_param], axis=2),
+                                           dtype=ms.uint64)
+            parameter_dict[concat_scale_name] = ms.Parameter(
+                concat_scale_param,
+                name=concat_scale_name,
+                requires_grad=False)
+
+            concat_bias_name = \
+                f"{base_param}.w_gate_hidden._layer.matmul.gmm_bias"
+            concat_bias_param = ms.Tensor(np.concatenate(
+                [w1_bias_param, w3_bias_param], axis=1),
+                                          dtype=ms.float32)
+            parameter_dict[concat_bias_name] = ms.Parameter(
+                concat_bias_param,
+                name=concat_bias_name,
+                requires_grad=False)
+        else:
+            # w1 w3
+            parameter_dict[w1_weight_name] = ms.Parameter(ms.Tensor(
+                w1_weight_param, ms.qint4x2),
+                                                          name=w1_weight_name,
+                                                          requires_grad=False)
+            parameter_dict[w3_weight_name] = ms.Parameter(ms.Tensor(
+                w3_weight_param, ms.qint4x2),
+                                                          name=w3_weight_name,
+                                                          requires_grad=False)
+
+            parameter_dict[w1_scale_name] = ms.Parameter(ms.Tensor(
+                w1_scale_param, ms.uint64),
+                                                         name=w1_scale_name,
+                                                         requires_grad=False)
+            parameter_dict[w3_scale_name] = ms.Parameter(ms.Tensor(
+                w3_scale_param, ms.uint64),
+                                                         name=w3_scale_name,
+                                                         requires_grad=False)
+
+            parameter_dict[w1_bias_name] = ms.Parameter(ms.Tensor(
+                w1_bias_param, ms.float32),
+                                                        name=w1_bias_name,
+                                                        requires_grad=False)
+            parameter_dict[w3_bias_name] = ms.Parameter(ms.Tensor(
+                w3_bias_param, ms.float32),
+                                                        name=w3_bias_name,
+                                                        requires_grad=False)
+
+        parameter_dict[w2_weight_name] = ms.Parameter(ms.Tensor(
+            w2_weight_param, ms.qint4x2),
+                                                      name=w2_weight_name,
+                                                      requires_grad=False)
+        parameter_dict[w2_scale_name] = ms.Parameter(ms.Tensor(
+            w2_scale_param, ms.uint64),
+                                                     name=w2_scale_name,
+                                                     requires_grad=False)
+        parameter_dict[w2_bias_name] = ms.Parameter(ms.Tensor(
+            w2_bias_param, ms.float32),
+                                                    name=w2_bias_name,
+                                                    requires_grad=False)
+
+    def infer_a8w4_get_value(self, param_name, src_hf_dir, hf_weight_map,
+                             no_need_split_layer):
+        '''infer_a8w4_get_value'''
+
+        if any([name in param_name for name in no_need_split_layer]):
+            value, is_int4 = self.get_safetensor_from_file(
+                param_name, src_hf_dir, hf_weight_map)
+        elif any([name in param_name for name in [".l2q_proj."]]):
+            if param_name.endswith(".weight") or "matmul" in param_name:
+                value, is_int4 = self.get_safetensor_from_file_split_tp_group(
+                    param_name, src_hf_dir, hf_weight_map, split_axis=0)
+            else:
+                value, is_int4 = self.get_safetensor_from_file(
+                    param_name, src_hf_dir, hf_weight_map)
+        elif any([
+                name in param_name
+                for name in [".feed_forward.w2.", ".wo.", "shared_experts.w2"]
+        ]):
+            value = self.infer_smooth_quant_row_linear_split(
+                param_name, src_hf_dir, hf_weight_map)
+            is_int4 = False
+        elif ".routed_experts.ffn.w2" in param_name:
+            value, is_int4 = self.get_safetensor_from_file_split_tp_group(
+                param_name, src_hf_dir, hf_weight_map, split_axis=1)
+        elif any(
+            [name in param_name for name in ["lkv2kv_k_nope", "lkv2kv_v"]]):
+            value, is_int4 = self.get_safetensor_from_file_split_tp_group(
+                param_name, src_hf_dir, hf_weight_map, split_axis=0)
+        elif "lm_head" in param_name:
+            if not self.config.parallel_config.vocab_emb_dp:
+                value, is_int4 = self.get_safetensor_from_file_split_tp_group(
+                    param_name, src_hf_dir, hf_weight_map, split_axis=0)
+            else:
+                value, is_int4 = self.get_safetensor_from_file(
+                    param_name, src_hf_dir, hf_weight_map)
+        else:
+            raise ValueError(
+                f"not found layer {param_name}, please check safetensors file."
+            )
+        return value, is_int4
+
+    def infer_a8w4_net_ms_convert_layer_weight(self, src_hf_dir, num_layers,
+                                               hf_weight_map):
+        '''infer_a8w4_net_ms_convert_layer_weight'''
+        parameter_dict = {}  # type: ignore[var-annotated]
+
+        no_need_split_layer = [
+            "tok_embeddings", "norm", "routed_experts.router.dense",
+            "routed_experts.router.e_score_correction_bias", "topk_bias"
+        ]
+        for layer_id in tqdm(range(num_layers), desc="qkv/ffn params load"):
+            if layer_id >= 3:
+                self.process_route_ffn_weight_a8w4(
+                    src_hf_dir, layer_id, hf_weight_map, parameter_dict,
+                    "feed_forward.routed_experts.ffn")
+                self.smooth_quant_process_shared_ffn_weight(
+                    src_hf_dir, layer_id, hf_weight_map, parameter_dict,
+                    "feed_forward.shared_experts")
+
+            else:
+                self.smooth_quant_process_ffn_weight(src_hf_dir, layer_id,
+                                                     hf_weight_map,
+                                                     parameter_dict,
+                                                     "feed_forward")
+            self.smooth_quant_process_qkv_weight(src_hf_dir, layer_id,
+                                                 hf_weight_map, parameter_dict)
+
+        skip_layer = [
+            "feed_forward.routed_experts.ffn", "feed_forward.shared_experts",
+            "feed_forward.w", "attention.kv2l", "attention.q"
+        ]
+
+        for param_name, _ in tqdm(hf_weight_map.items(),
+                                  desc="remaining params load"):
+            if "model.layers" in param_name and int(
+                    param_name.split('.')[2]) >= num_layers:
+                continue
+
+            if any([name in param_name for name in skip_layer]):
+                continue
+
+            value, is_int4 = self.infer_a8w4_get_value(param_name, src_hf_dir,
+                                                       hf_weight_map,
+                                                       no_need_split_layer)
+            dst_dtype = convert_np_to_ms_dtype(value)
+
+            if is_int4:
+                parameter_dict[param_name] = ms.Parameter(ms.Tensor(
+                    value, dtype=dtype.qint4x2),
+                                                          name=param_name,
+                                                          requires_grad=False)
+            else:
+                parameter_dict[param_name] = ms.Parameter(ms.Tensor(
+                    value, dtype=dst_dtype),
+                                                          name=param_name,
+                                                          requires_grad=False)
+
+        param_not_load, ckpt_not_load = ms.load_param_into_net(
+            self.network, parameter_dict)
+        print(f"a8w4 param_not_load:{param_not_load}")
+        print(f"a8w4 ckpt_not_load:{ckpt_not_load}")
+
     def load_safetensors_shard(self, src_hf_dir, is_mtp_model=False):
         """deepseek load safetensors and shard """
         rank_id = get_rank()
@@ -2203,7 +2480,7 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
         quantization_config = self.config.model.model_config.quantization_config
         quant_method = (quantization_config.quant_method
                         if quantization_config else None)
-        support_quant_method = ["gptq-pergroup", "smoothquant", "osl"]
+        support_quant_method = ["gptq-pergroup", "smoothquant", "osl", "a8w4"]
         if not quant_method or (quant_method not in support_quant_method) and \
             not is_mtp_model:
             self.infer_convert_outer_weight(src_hf_dir, hf_weight_map)
@@ -2220,6 +2497,10 @@ class DeepseekV3WeightProcessor(BaseWeightProcessor):
             self.infer_smooth_quant_net_ms_convert_layer_weight(
                 src_hf_dir, self.num_layers, hf_weight_map)
             return
+        if quant_method and quant_method == "a8w4":
+            self.infer_a8w4_net_ms_convert_layer_weight(
+                src_hf_dir, self.num_layers, hf_weight_map)
+            return
 
         enable_tqdm = rank_id == 0
         mtp_layers = self.config.model.model_config.num_nextn_predict_layers