From acd1bec890b602c9ca46ae2c0e9c80154ccf5019 Mon Sep 17 00:00:00 2001
From: horcam
Date: Mon, 14 Jul 2025 17:51:13 +0800
Subject: [PATCH] [develop][codecheck] clean code

---
 codecheck_toolkits/pyproject.toml             |  2 +-
 .../mf_models/deepseekv3_infer_save_ckpt.py   | 76 +++++++++++--------
 2 files changed, 47 insertions(+), 31 deletions(-)

diff --git a/codecheck_toolkits/pyproject.toml b/codecheck_toolkits/pyproject.toml
index d9d08068..80523aeb 100644
--- a/codecheck_toolkits/pyproject.toml
+++ b/codecheck_toolkits/pyproject.toml
@@ -129,7 +129,7 @@ exclude = [
 ]
 
 [tool.codespell]
-ignore-words-list = "dout, te, indicies, subtile, ElementE"
+ignore-words-list = "dout, te, indicies, subtile, ElementE, CANN"
 skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora/data/*,build/*,vllm_mindspore/third_party/*"
 
 [tool.isort]
diff --git a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py
index 742a0988..0358451f 100644
--- a/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py
+++ b/vllm_mindspore/model_executor/models/mf_models/deepseekv3_infer_save_ckpt.py
@@ -18,18 +18,17 @@
 import argparse
 import os
 from collections import OrderedDict
 
-from vllm.logger import init_logger
-
 import mindspore as ms
+from mindformers import MindFormerConfig, build_context
+from mindformers.core.parallel_config import build_parallel_config
 from mindspore import dtype as msdtype
 from mindspore.communication.management import get_rank
-from mindformers.core.parallel_config import build_parallel_config
-from mindformers import MindFormerConfig
-from mindformers import build_context
-from research.deepseek3.deepseekv3_infer_parallelism import DeepseekInferParallelism
-
 from research.deepseek3.deepseek3_config import DeepseekV3Config
-from research.deepseek3.deepseek3_model_infer import InferenceDeepseekV3ForCausalLM
+from research.deepseek3.deepseek3_model_infer import (
+    InferenceDeepseekV3ForCausalLM)
+from research.deepseek3.deepseekv3_infer_parallelism import (
+    DeepseekInferParallelism)
+from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -37,26 +36,39 @@ logger = init_logger(__name__)
 # bash scripts/msrun_launcher.sh "python ./infer_save_ckpt_from_safetensor.py
 # --config /path/to/predict_deepseek_r1_671b.yaml
 # --save_ckpt_path /path/to/save_ckpt_path
-# --load_checkpoint /path/to/safetensor_path " 4 8555 "output/deepseek_msrun_log" "False" 7200
+# --load_checkpoint /path/to/safetensor_path "
+# 4 8555 "output/deepseek_msrun_log" "False" 7200
+
 
 def create_ptq():
     '''create_ptq'''
-    from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer
-    from mindspore_gs.ptq import PTQ
     from mindspore_gs.common import BackendTarget
-    from mindspore_gs.ptq import PTQConfig, PTQMode, OutliersSuppressionType, PrecisionRecovery, QuantGranularity
-    cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
-                    act_quant_dtype=msdtype.int8, outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS,
-                    opname_blacklist=['lkv2kv', 'lm_head'], precision_recovery=PrecisionRecovery.NONE,
-                    act_quant_granularity=QuantGranularity.PER_TENSOR,
-                    weight_quant_granularity=QuantGranularity.PER_CHANNEL)
-    ffn_config = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND, weight_quant_dtype=msdtype.int8,
-                           act_quant_dtype=msdtype.int8,
-                           outliers_suppression=OutliersSuppressionType.NONE,
-                           precision_recovery=PrecisionRecovery.NONE,
-                           act_quant_granularity=QuantGranularity.PER_TOKEN,
-                           weight_quant_granularity=QuantGranularity.PER_CHANNEL)
-    ptq = PTQ(config=cfg, layer_policies=OrderedDict({r'.*\.feed_forward\..*': ffn_config}))
+    from mindspore_gs.ptq import (PTQ, OutliersSuppressionType,
+                                  PrecisionRecovery, PTQConfig, PTQMode,
+                                  QuantGranularity)
+    from research.deepseek3.deepseek3_model_infer import DeepseekV3DecodeLayer
+    cfg = PTQConfig(
+        mode=PTQMode.DEPLOY,
+        backend=BackendTarget.ASCEND,
+        weight_quant_dtype=msdtype.int8,
+        act_quant_dtype=msdtype.int8,
+        outliers_suppression=OutliersSuppressionType.OUTLIER_SUPPRESSION_PLUS,
+        opname_blacklist=['lkv2kv', 'lm_head'],
+        precision_recovery=PrecisionRecovery.NONE,
+        act_quant_granularity=QuantGranularity.PER_TENSOR,
+        weight_quant_granularity=QuantGranularity.PER_CHANNEL)
+    ffn_config = PTQConfig(
+        mode=PTQMode.DEPLOY,
+        backend=BackendTarget.ASCEND,
+        weight_quant_dtype=msdtype.int8,
+        act_quant_dtype=msdtype.int8,
+        outliers_suppression=OutliersSuppressionType.NONE,
+        precision_recovery=PrecisionRecovery.NONE,
+        act_quant_granularity=QuantGranularity.PER_TOKEN,
+        weight_quant_granularity=QuantGranularity.PER_CHANNEL)
+    ptq = PTQ(config=cfg,
+              layer_policies=OrderedDict({r'.*\.feed_forward\..*':
+                                          ffn_config}))
     ptq.decoder_layers.append(DeepseekV3DecodeLayer)
     return ptq
 
@@ -85,24 +97,28 @@ def main(config_path, load_checkpoint, save_ckpt_dir):
         ptq.summary(network)
     # load checkpoint
     if config.load_checkpoint:
-        logger.info("----------------Transform and load checkpoint----------------")
+        logger.info(
+            "----------------Transform and load checkpoint----------------")
         model_parallelism = DeepseekInferParallelism(config, network, is_quant)
         model_parallelism.infer_convert_and_parallelism(config.load_checkpoint)
 
     rank_id = str(get_rank())
     os.makedirs(os.path.join(save_ckpt_dir, "rank_" + rank_id), exist_ok=True)
-    save_ckpt_path = os.path.join(save_ckpt_dir, "rank_" + rank_id, "checkpoint_" + rank_id + ".ckpt")
+    save_ckpt_path = os.path.join(save_ckpt_dir, "rank_" + rank_id,
+                                  "checkpoint_" + rank_id + ".ckpt")
     ms.save_checkpoint(network.parameters_dict(), save_ckpt_path)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--config_path', default='predict_llama2_7b.yaml', type=str,
+    parser.add_argument('--config_path',
+                        default='predict_llama2_7b.yaml',
+                        type=str,
                         help='model config file path.')
-    parser.add_argument('--load_checkpoint', type=str,
+    parser.add_argument('--load_checkpoint',
+                        type=str,
                         help='load model checkpoint path or directory.')
-    parser.add_argument('--save_ckpt_dir', type=str,
-                        help='save ckpt path.')
+    parser.add_argument('--save_ckpt_dir', type=str, help='save ckpt path.')
     args = parser.parse_args()
 
     main(args.config_path, args.load_checkpoint, args.save_ckpt_dir)
--
Gitee
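
A minimal usage sketch (not part of the patch): the script saves one checkpoint per rank under save_ckpt_dir/rank_<id>/checkpoint_<id>.ckpt, and such a file can be loaded back with the standard MindSpore APIs. Here `network` is assumed to be an InferenceDeepseekV3ForCausalLM built the same way as in main(), and save_ckpt_dir is whatever was passed as --save_ckpt_dir; both are illustrative assumptions, not something introduced by this change.

    import os
    import mindspore as ms
    from mindspore.communication.management import get_rank

    # Path layout mirrors what deepseekv3_infer_save_ckpt.py writes for this rank.
    save_ckpt_dir = "/path/to/save_ckpt_path"  # same value given to --save_ckpt_dir
    rank_id = str(get_rank())
    ckpt_path = os.path.join(save_ckpt_dir, "rank_" + rank_id,
                             "checkpoint_" + rank_id + ".ckpt")

    param_dict = ms.load_checkpoint(ckpt_path)   # dict of parameter name -> Parameter
    ms.load_param_into_net(network, param_dict)  # `network` built as in main() (assumed)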