From 3a627fb70235778ec0bfce8be99602e1f18cd1c4 Mon Sep 17 00:00:00 2001
From: racesnail
Date: Tue, 18 Feb 2025 16:50:30 +0800
Subject: [PATCH 1/2] grin_moe model migration and adaptation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/checkpoint/model_cfg.json             |  14 ++
 convert_ckpt.py                               |   2 +-
 .../grin_moe/ckpt_convert_grin_hf2mcore.sh    |  19 +++
 .../grin_moe/generate_grin_16x3point8b_ptd.sh |  86 ++++++++++++
 .../mcore/grin_moe/pretrain_grin_moe_ptd.sh   | 123 ++++++++++++++++++
 .../tasks/models/spec/grin_moe_spec.py        |  45 +++++++
 6 files changed, 288 insertions(+), 1 deletion(-)
 create mode 100644 examples/mcore/grin_moe/ckpt_convert_grin_hf2mcore.sh
 create mode 100644 examples/mcore/grin_moe/generate_grin_16x3point8b_ptd.sh
 create mode 100644 examples/mcore/grin_moe/pretrain_grin_moe_ptd.sh
 create mode 100644 mindspeed_llm/tasks/models/spec/grin_moe_spec.py

diff --git a/configs/checkpoint/model_cfg.json b/configs/checkpoint/model_cfg.json
index 650944711..5c52cfbcb 100644
--- a/configs/checkpoint/model_cfg.json
+++ b/configs/checkpoint/model_cfg.json
@@ -440,6 +440,20 @@
                 "layers_mlp_experts_up_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w3",
                 "layers_mlp_experts_linear_fc2": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w2"
             }
+        },
+        "grin-moe": {
+            "__base__": "base",
+            "config_set_value": {
+                "normalization": "LayerNorm",
+                "moe_flag": true,
+                "add_output_layer_bias": true
+            },
+            "model_hf_key_mapping": {
+                "layers_mlp_router": "model.layers[layer_idx].block_sparse_moe.gate",
+                "layers_mlp_experts_gate_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w1",
+                "layers_mlp_experts_up_proj": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w3",
+                "layers_mlp_experts_linear_fc2": "model.layers[layer_idx].block_sparse_moe.experts[expert_idx].w2"
+            }
         }
     }
 }
diff --git a/convert_ckpt.py b/convert_ckpt.py
index 3741b0c0e..82d0e4fd4 100644
--- a/convert_ckpt.py
+++ b/convert_ckpt.py
@@ -63,7 +63,7 @@ def main():
     parser.add_argument('--model-type-hf', type=str, default="llama2",
                         choices=['baichuan', 'baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom',
                                  'bloom_3b', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm3', 'minicpm-moe',
-                                 'deepseek2-lite', 'qwen2-moe', 'phi3.5', 'phi3.5-moe'],
+                                 'deepseek2-lite', 'qwen2-moe', 'phi3.5', 'phi3.5-moe','grin-moe'],
                         help='model type of huggingface')
     parser.add_argument('--ckpt-cfg-path', type=str, default="configs/checkpoint/model_cfg.json",
                         help="Path to the config directory. If not specified, the default path in the repository will be used.")
diff --git a/examples/mcore/grin_moe/ckpt_convert_grin_hf2mcore.sh b/examples/mcore/grin_moe/ckpt_convert_grin_hf2mcore.sh
new file mode 100644
index 000000000..1e0d3d426
--- /dev/null
+++ b/examples/mcore/grin_moe/ckpt_convert_grin_hf2mcore.sh
@@ -0,0 +1,19 @@
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+# Set the parallel configuration as needed
+python convert_ckpt.py \
+    --model-type GPT \
+    --load-model-type hf \
+    --save-model-type mg \
+    --params-dtype bf16 \
+    --target-tensor-parallel-size 1 \
+    --target-pipeline-parallel-size 1 \
+    --target-expert-parallel-size 1 \
+    --load-dir /home/hf_weights/GRIN-MoE/ \
+    --save-dir /home/mytest/MindSpeed-LLM/model_weights/GRIN-mcore/ \
+    --tokenizer-model /home/hf_weights/GRIN-MoE/tokenizer.json \
+    --use-mcore-models \
+    --model-type-hf grin-moe \
+    --add-qkv-bias \
+    --add-dense-bias \
+    --spec mindspeed_llm.tasks.models.spec.grin_moe_spec layer_spec
\ No newline at end of file
diff --git a/examples/mcore/grin_moe/generate_grin_16x3point8b_ptd.sh b/examples/mcore/grin_moe/generate_grin_16x3point8b_ptd.sh
new file mode 100644
index 000000000..15825b4c2
--- /dev/null
+++ b/examples/mcore/grin_moe/generate_grin_16x3point8b_ptd.sh
@@ -0,0 +1,86 @@
+#!/bin/bash
+
+# The number of parameters is not aligned
+export HCCL_CONNECT_TIMEOUT=1200
+export COMBINED_ENABLE=1
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+# please fill these path configurations
+CHECKPOINT="your model ckpt path"
+TOKENIZER_PATH="your tokenizer path"
+
+# Change for multinode config
+MASTER_ADDR=localhost
+MASTER_PORT=6022
+NNODES=1
+NODE_RANK=0
+GPUS_PER_NODE=8
+TP=8
+PP=1
+SEQ_LEN=4096
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $GPUS_PER_NODE \
+    --nnodes $NNODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT
+"
+
+MOE_ARGS="
+    --num-experts 16 \
+    --moe-router-topk 2 \
+    --expert-model-parallel-size 1 \
+    --moe-router-load-balancing-type sparsemixer_topk \
+    --moe-aux-loss-coeff 0.0
+"
+
+GPT_ARGS="
+    --use-mcore-models \
+    --tensor-model-parallel-size ${TP} \
+    --transformer-impl local \
+    --pipeline-model-parallel-size ${PP} \
+    --max-new-tokens 256 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 6400 \
+    --num-attention-heads 32 \
+    --group-query-attention \
+    --num-query-groups 8 \
+    --tokenizer-type PretrainedFromHF \
+    --tokenizer-name-or-path ${TOKENIZER_PATH} \
+    --seq-length $SEQ_LEN \
+    --max-position-embeddings $SEQ_LEN \
+    --micro-batch-size 1 \
+    --make-vocab-size-divisible-by 1 \
+    --padded-vocab-size 32064 \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --position-embedding-type rope \
+    --normalization LayerNorm \
+    --rotary-base 10000 \
+    --swiglu \
+    --use-flash-attn \
+    --no-masked-softmax-fusion \
+    --no-gradient-accumulation-fusion \
+    --exit-on-missing-checkpoint \
+    --attention-softmax-in-fp32 \
+    --load ${CHECKPOINT} \
+    --no-load-optim \
+    --no-load-rng \
+    --attention-dropout 0.0 \
+    --init-method-std 0.01 \
+    --hidden-dropout 0.0 \
+    --sliding-window 2047 \
+    --seed 42 \
+    --add-qkv-bias \
+    --add-dense-bias \
+    --spec mindspeed_llm.tasks.models.spec.grin_moe_spec layer_spec \
+    --add-output-layer-bias \
+"
+
+torchrun $DISTRIBUTED_ARGS inference.py \
+       $GPT_ARGS \
+       $MOE_ARGS \
+       --distributed-backend nccl \
+       | tee logs/generate_mcore_grin_16x3point8b.log
diff --git a/examples/mcore/grin_moe/pretrain_grin_moe_ptd.sh b/examples/mcore/grin_moe/pretrain_grin_moe_ptd.sh
new file mode 100644
index 000000000..b1775a3fd
--- /dev/null
+++ b/examples/mcore/grin_moe/pretrain_grin_moe_ptd.sh
@@ -0,0 +1,123 @@
+#!/bin/bash
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+# Change for multinode config
+NPUS_PER_NODE=8
+MASTER_ADDR=localhost
+MASTER_PORT=6060
+NNODES=1
+NODE_RANK=0
+WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))
+
+CKPT_SAVE_DIR="your model save ckpt path"
+DATA_PATH="your data path"
+TOKENIZER_MODEL="your tokenizer path"
+CKPT_LOAD_DIR="your model ckpt path"
+
+TP=1
+PP=1
+EP=1
+MBS=1
+GBS=64
+SEQ_LEN=4096
+TRAIN_ITERS=2000
+
+DISTRIBUTED_ARGS="
+    --nproc_per_node $NPUS_PER_NODE \
+    --nnodes $NNODES \
+    --node_rank $NODE_RANK \
+    --master_addr $MASTER_ADDR \
+    --master_port $MASTER_PORT
+"
+
+echo "NODE_RANK ${NODE_RANK}"
+
+MOE_ARGS="
+    --expert-model-parallel-size ${EP} \
+    --num-experts 16 \
+    --moe-router-topk 2 \
+    --moe-router-load-balancing-type sparsemixer_topk \
+    --moe-aux-loss-coeff 0.0 \
+    --moe-token-dispatcher-type allgather
+"
+
+GPT_ARGS="
+    --use-mcore-models \
+    --transformer-impl local \
+    --spec mindspeed_llm.tasks.models.spec.grin_moe_spec layer_spec \
+    --tensor-model-parallel-size ${TP} \
+    --pipeline-model-parallel-size ${PP} \
+
+    --sequence-parallel \
+    --tokenizer-type PretrainedFromHF \
+    --tokenizer-name-or-path ${TOKENIZER_MODEL} \
+    --seq-length ${SEQ_LEN} \
+    --max-position-embeddings 4096 \
+    --num-layers 32 \
+    --hidden-size 4096 \
+    --ffn-hidden-size 6400 \
+    --num-attention-heads 32 \
+    --group-query-attention \
+    --num-query-groups 8 \
+    --padded-vocab-size 32064 \
+    --make-vocab-size-divisible-by 1 \
+    --rotary-base 10000 \
+    --train-iters ${TRAIN_ITERS} \
+    --add-qkv-bias \
+    --disable-bias-linear \
+    --add-dense-bias \
+    --add-output-layer-bias \
+    --untie-embeddings-and-output-weights \
+    --position-embedding-type rope \
+    --normalization LayerNorm \
+    --norm-epsilon 1e-5 \
+    --swiglu \
+    --seed 42
+"
+
+OPTIM_ARGS="
+    --attention-softmax-in-fp32 \
+    --no-masked-softmax-fusion \
+    --no-gradient-accumulation-fusion \
+    --no-load-optim \
+    --no-load-rng
+"
+
+TRAIN_ARGS="
+    --micro-batch-size ${MBS} \
+    --global-batch-size ${GBS} \
+    --lr 1.25e-6 \
+    --min-lr 1.25e-7 \
+    --lr-decay-style cosine \
+    --weight-decay 1e-1 \
+    --lr-warmup-fraction 0.01 \
+    --clip-grad 1.0 \
+    --init-method-std 0.01 \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+"
+
+DATA_ARGS="
+    --data-path $DATA_PATH \
+    --split 100,0,0
+"
+
+OUTPUT_ARGS="
+    --log-throughput \
+    --log-interval 1 \
+    --save-interval 2000 \
+    --eval-interval 2000 \
+    --eval-iters 0
+"
+
+torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
+    $MOE_ARGS \
+    $GPT_ARGS \
+    $OPTIM_ARGS \
+    $TRAIN_ARGS \
+    $DATA_ARGS \
+    $OUTPUT_ARGS \
+    --distributed-backend nccl \
+    --load ${CKPT_LOAD_DIR} \
+    --save ${CKPT_SAVE_DIR} \
+    | tee logs/pretrain_mcore_grin_16x3point8b.log
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/models/spec/grin_moe_spec.py b/mindspeed_llm/tasks/models/spec/grin_moe_spec.py
new file mode 100644
index 000000000..da93ee61e
--- /dev/null
+++ b/mindspeed_llm/tasks/models/spec/grin_moe_spec.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved.
+from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
+from megatron.training import get_args
+from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer import ModuleSpec, TransformerLayer, TransformerLayerSubmodules
+from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.attention import SelfAttentionSubmodules
+from megatron.core.transformer.dot_product_attention import DotProductAttention
+from mindspeed_llm.core import PTNorm
+from mindspeed_llm.tasks.models.transformer.attention import SelfAttentionWithDenseBias
+
+"""
+Layer Specification for GRIN-MoE
+"""
+
+args = get_args()
+num_experts, moe_grouped_gemm, qk_layernorm = args.num_experts, args.moe_grouped_gemm, args.qk_layernorm
+
+layer_spec = ModuleSpec(
+    module=TransformerLayer,
+    submodules=TransformerLayerSubmodules(
+        input_layernorm=PTNorm,
+        self_attention=ModuleSpec(
+            module=SelfAttentionWithDenseBias,
+            params={"attn_mask_type": AttnMaskType.causal},
+            submodules=SelfAttentionSubmodules(
+                linear_qkv=ColumnParallelLinear,
+                core_attention=DotProductAttention,
+                linear_proj=RowParallelLinear,
+                q_layernorm=PTNorm if qk_layernorm else IdentityOp,
+                k_layernorm=PTNorm if qk_layernorm else IdentityOp,
+            ),
+        ),
+        self_attn_bda=get_bias_dropout_add,
+        pre_mlp_layernorm=PTNorm,
+        mlp=_get_mlp_module_spec(use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm),
+        mlp_bda=get_bias_dropout_add,
+        sharded_state_dict_keys_map={
+            'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
+            'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
+        },
+    ),
+)
-- 
Gitee

From 597de7e7216f22a288bcfac76f7731b74c2fb931 Mon Sep 17 00:00:00 2001
From: racesnail
Date: Tue, 18 Feb 2025 17:40:55 +0800
Subject: [PATCH 2/2] grin_moe model migration and adaptation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 convert_ckpt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/convert_ckpt.py b/convert_ckpt.py
index 82d0e4fd4..e64823277 100644
--- a/convert_ckpt.py
+++ b/convert_ckpt.py
@@ -63,7 +63,7 @@ def main():
     parser.add_argument('--model-type-hf', type=str, default="llama2",
                         choices=['baichuan', 'baichuan2', 'llama2', 'mixtral', 'chatglm3', 'gemma', 'gemma2', 'bloom',
                                  'bloom_3b', 'qwen', 'internlm2', 'deepseek2', 'minicpm', 'minicpm3', 'minicpm-moe',
-                                 'deepseek2-lite', 'qwen2-moe', 'phi3.5', 'phi3.5-moe','grin-moe'],
+                                 'deepseek2-lite', 'qwen2-moe', 'phi3.5', 'phi3.5-moe', 'grin-moe'],
                         help='model type of huggingface')
     parser.add_argument('--ckpt-cfg-path', type=str, default="configs/checkpoint/model_cfg.json",
                         help="Path to the config directory. If not specified, the default path in the repository will be used.")
-- 
Gitee
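
The scripts above hand the new layer specification to the training and inference entry points as "--spec mindspeed_llm.tasks.models.spec.grin_moe_spec layer_spec". The sketch below shows roughly how such a module/attribute pair can be resolved into the spec object; load_layer_spec is an illustrative helper and not part of this patch, and because grin_moe_spec.py calls get_args() at import time, the import only succeeds inside an initialized Megatron/MindSpeed-LLM run (e.g. after torchrun has parsed the arguments).

# Illustrative sketch only, not part of the patch: resolve a "--spec <module> <attr>"
# pair into the spec object, assuming Megatron's global args are already initialized
# (grin_moe_spec.py calls get_args() at module import time).
import importlib


def load_layer_spec(module_path: str, attr_name: str):
    """Import `module_path` and return the attribute named `attr_name`."""
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


# Example (inside a running training or inference job):
# spec = load_layer_spec("mindspeed_llm.tasks.models.spec.grin_moe_spec", "layer_spec")
# print(spec.module)  # expected: the TransformerLayer class from megatron.core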