From 6806e231e2dba9755be47dd26652c33c32ac2dbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AD=90=E8=B1=AA?= Date: Sun, 30 Mar 2025 02:33:15 +0000 Subject: [PATCH 1/3] =?UTF-8?q?!2466=20=E3=80=90=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E6=80=A7=E6=96=B9=E6=A1=88-=E9=80=82=E9=85=8D=E6=94=AF?= =?UTF-8?q?=E6=8C=81mindspore=E6=A1=86=E6=9E=B6=E3=80=91=20Merge=20pull=20?= =?UTF-8?q?request=20!2466=20from=20=E9=99=86=E6=98=9F=E5=AE=87/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pr2466+adapter_addition delete duplicate examples back_to_mindspeed final --- .../deepseek3/finetune_deepseek3_ms.sh | 185 ++++++++++ .../deepseek3/pretrain_deepseek3_ms.sh | 204 +++++++++++ .../mindspore/core/models/gpt/gpt_model.py | 129 +++++++ .../core/transformer/moe/moe_layer.py | 124 +++++++ mindspeed_llm/mindspore/mindspore_adaptor.py | 82 +++++ .../transformer/multi_token_predication.py | 327 ++++++++++++++++++ mindspeed_llm/tasks/megatron_adaptor.py | 3 +- mindspeed_llm/training/arguments.py | 8 + 8 files changed, 1061 insertions(+), 1 deletion(-) create mode 100644 examples/mindspore/deepseek3/finetune_deepseek3_ms.sh create mode 100644 examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh create mode 100644 mindspeed_llm/mindspore/core/models/gpt/gpt_model.py create mode 100644 mindspeed_llm/mindspore/core/transformer/moe/moe_layer.py create mode 100644 mindspeed_llm/mindspore/mindspore_adaptor.py create mode 100644 mindspeed_llm/mindspore/tasks/models/transformer/multi_token_predication.py diff --git a/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh b/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh new file mode 100644 index 000000000..293fefc4e --- /dev/null +++ b/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh @@ -0,0 +1,185 @@ +#!/bin/bash +#多机脚本 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_CONNECT_TIMEOUT=3600 +export GLOO_SOCKET_IFNAME="enp67s0f0np0" # "192.168.122.1" +export HCCL_SOCKET_IFNAME="enp67s0f0np0" + +basepath=$(cd `dirname $0`; cd ../../../; pwd) + +GPUS_PER_NODE=8 +#MASTER_ADDR=localhost #主节点IP +MASTER_ADDR=70.166.30.72 +MASTER_PORT=8892 +NNODES=4 +NODE_RANK=2 +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR="your model save ckpt path" # /***/**${NODE_RANK} +DATA_PATH="your data path" +TOKENIZER_PATH="your tokenizer path" +CKPT_LOAD_DIR="your model ckpt path" # /***/**${NODE_RANK} + +TP=1 +PP=4 +EP=8 +CP=1 +CP_TYPE='ulysses_cp_algo' + +NUM_LAYERS=32 +SEQ_LEN=4096 +MBS=1 +GBS=64 +DISTRIBUTED_ARGS=" + --master_addr $MASTER_ADDR \ + --node_rank $NODE_RANK \ + --worker_num $WORLD_SIZE \ + --local_worker_num $GPUS_PER_NODE \ + --master_port $MASTER_PORT \ + --log_dir=test_log \ + --join=False \ + --cluster_time_out=300 \ + --bind_core=True \ +" + +MLA_ARGS=" + --multi-head-latent-attention \ + --qk-rope-head-dim 64 \ + --qk-nope-head-dim 128 \ + --q-lora-rank 1536 \ + --kv-lora-rank 512 \ + --v-head-dim 128 \ + --qk-layernorm \ +" + +MOE_ARGS=" + --moe-grouped-gemm \ + --moe-permutation-async-comm \ + --use-fused-moe-token-permute-and-unpermute \ + --moe-token-dispatcher-type alltoall \ + --first-k-dense-replace 3 \ + --moe-layer-freq 1 \ + --n-shared-experts 1 \ + --num-experts 32 \ + --moe-router-topk 8 \ + --moe-intermediate-size 2048 \ + --moe-router-load-balancing-type noaux_tc \ + --topk-group 4 \ + --routed-scaling-factor 2.5 \ + --seq-aux \ + --norm-topk-prob \ + --moe-router-score-function sigmoid \ + --moe-router-enable-expert-bias \ + 
--no-gradient-accumulation-fusion \ +" + +ROPE_ARGS=" + --rope-scaling-beta-fast 32 \ + --rope-scaling-beta-slow 1 \ + --rope-scaling-factor 40 \ + --rope-scaling-mscale 1.0 \ + --rope-scaling-mscale-all-dim 1.0 \ + --rope-scaling-original-max-position-embeddings 4096 \ + --rope-scaling-type yarn +" + +GPT_ARGS=" + --spec mindspeed_llm.tasks.models.spec.deepseek_spec layer_spec \ + --num-layer-list 8,8,8,8 \ + --recompute-granularity full \ + --recompute-method block \ + --recompute-num-layers 8 \ + --use-distributed-optimizer \ + --reuse-fp32-param \ + --use-flash-attn \ + --shape-order BNSD \ + --use-mcore-models \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --expert-model-parallel-size ${EP} \ + --sequence-parallel \ + --context-parallel-size ${CP} \ + --context-parallel-algo ${CP_TYPE} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size 7168 \ + --ffn-hidden-size 18432 \ + --num-attention-heads 128 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_PATH} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings 163840 \ + --micro-batch-size ${MBS} \ + --global-batch-size ${GBS} \ + --make-vocab-size-divisible-by 1 \ + --lr 1.0e-5 \ + --train-iters 200 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.02 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rotary-pos-emb \ + --use-rotary-position-embeddings \ + --use-fused-swiglu \ + --use-fused-rmsnorm \ + --swiglu \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.0e-7 \ + --weight-decay 1e-2 \ + --lr-warmup-iters 0 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --initial-loss-scale 65536 \ + --vocab-size 129280 \ + --padded-vocab-size 129280 \ + --rotary-base 10000 \ + --norm-epsilon 1e-6 \ + --no-load-optim \ + --no-load-rng \ + --bf16 \ + --distributed-timeout-minutes 120 \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --split 100,0,0 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 61 \ + --eval-interval 2000 \ + --eval-iters 0 \ + --no-save-optim \ + --no-save-rng +" + +# --load ${CKPT_LOAD_DIR} \ +# --save $CKPT_SAVE_DIR \ + +FINETUNE_ARGS=" + --finetune \ + --stage sft \ + --is-instruction-dataset \ + --variable-seq-lengths \ + --prompt-type deepseek3 \ +" + +msrun $DISTRIBUTED_ARGS $basepath/posttrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $MLA_ARGS \ + $ROPE_ARGS \ + $MOE_ARGS \ + $FINETUNE_ARGS \ + --load ${CKPT_LOAD_DIR} \ + --distributed-backend nccl \ + --ai-framework mindspore \ + 2>&1 | tee logs/ms_sft_deepseek3_671b_4k_ptd.log \ No newline at end of file diff --git a/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh b/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh new file mode 100644 index 000000000..690a6a819 --- /dev/null +++ b/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh @@ -0,0 +1,204 @@ +#!/bin/bash +#多机脚本 +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True +export HCCL_CONNECT_TIMEOUT=3600 +export HCCL_ALGO="alltoall=level0:NA;level1:pipeline" +export HCCL_BUFFSIZE=400 +export GLOO_SOCKET_IFNAME="enp67s0f0np0" # "192.168.122.1" +export HCCL_SOCKET_IFNAME="enp67s0f0np0" + +basepath=$(cd `dirname $0`; cd ../../../; pwd) + +GPUS_PER_NODE=8 +#MASTER_ADDR=localhost #主节点IP +MASTER_ADDR=70.166.30.52 #主节点IP +MASTER_PORT=9110 +NNODES=4 +NODE_RANK=2 
+WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +CKPT_SAVE_DIR="your model save ckpt path" # /***/**${NODE_RANK} +DATA_PATH="your data path" +TOKENIZER_PATH="your tokenizer path" +CKPT_LOAD_DIR="your model ckpt path" # /***/**${NODE_RANK} + +TP=1 +PP=2 +EP=4 +CP=1 +VPP=1 +CP_TYPE='ulysses_cp_algo' +NUM_LAYERS=8 +SEQ_LEN=4096 +MBS=1 +GBS=64 + +DISTRIBUTED_ARGS=" + --master_addr $MASTER_ADDR \ + --node_rank $NODE_RANK \ + --worker_num $WORLD_SIZE \ + --local_worker_num $GPUS_PER_NODE \ + --master_port $MASTER_PORT \ + --log_dir=msrun_log_pretrain \ + --join=False \ + --cluster_time_out=300 \ + --bind_core=True \ +" + +MLA_ARGS=" + --multi-head-latent-attention \ + --qk-rope-head-dim 64 \ + --qk-nope-head-dim 128 \ + --q-lora-rank 1536 \ + --kv-lora-rank 512 \ + --v-head-dim 128 \ + --qk-layernorm \ +" + +MOE_ARGS=" + --moe-grouped-gemm \ + --moe-permutation-async-comm \ + --use-fused-moe-token-permute-and-unpermute \ + --moe-token-dispatcher-type alltoall \ + --n-shared-experts 1 \ + --num-experts 128 \ + --moe-router-topk 8 \ + --moe-layer-freq 1 \ + --n-group 8 \ + --first-k-dense-replace 1 \ + --moe-intermediate-size 2048 \ + --moe-router-load-balancing-type noaux_tc \ + --topk-group 4 \ + --routed-scaling-factor 2.5 \ + --seq-aux \ + --norm-topk-prob \ + --moe-router-score-function sigmoid \ + --moe-router-enable-expert-bias \ + --moe-tp-extend-ep \ + --moe-alltoall-overlap-comm \ +" + +#--num-experts 64 \ +#--moe-layer-freq 1 \ +#--first-k-dense-replace 1 \ +# --moe-tp-extend-ep \ +# --moe-alltoall-overlap-comm \ +#--n-group 8 \ +#--topk-group 4 \ +#--moe-router-topk 8 \ + +MTP_ARGS=" + --num-nextn-predict-layers 1 \ + --share-mtp-embedding-and-output-weight \ + --recompute-mtp-norm \ +" + +# --recompute-mtp-norm \ + +ROPE_ARGS=" + --rope-scaling-beta-fast 32 \ + --rope-scaling-beta-slow 1 \ + --rope-scaling-factor 40 \ + --rope-scaling-mscale 1.0 \ + --rope-scaling-mscale-all-dim 1.0 \ + --rope-scaling-original-max-position-embeddings 4096 \ + --rope-scaling-type yarn +" + +GPT_ARGS="\ + --no-check-for-nan-in-loss-and-grad \ + --spec mindspeed_llm.tasks.models.spec.deepseek_spec layer_spec \ + --no-gradient-accumulation-fusion \ + --reset-position-ids \ + --recompute-granularity full \ + --recompute-method block \ + --recompute-num-layers 8 \ + --noop-layers 7 \ + --no-shared-storage \ + --reuse-fp32-param \ + --use-flash-attn \ + --use-mcore-models \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --num-layers-per-virtual-pipeline-stage ${VPP} \ + --expert-model-parallel-size ${EP} \ + --sequence-parallel \ + --context-parallel-size ${CP} \ + --context-parallel-algo ${CP_TYPE} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size 7168 \ + --ffn-hidden-size 18432 \ + --num-attention-heads 128 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_MODEL} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings 163840 \ + --micro-batch-size ${MBS} \ + --global-batch-size ${GBS} \ + --make-vocab-size-divisible-by 1 \ + --lr 1.0e-5 \ + --train-iters 50 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.02 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rotary-pos-emb \ + --use-rotary-position-embeddings \ + --use-fused-swiglu \ + --use-fused-rmsnorm \ + --swiglu \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.0e-7 \ + --weight-decay 1e-2 \ + --lr-warmup-iters 0 \ 
+ --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --initial-loss-scale 65536 \ + --vocab-size 129280 \ + --padded-vocab-size 129280 \ + --rotary-base 10000 \ + --norm-epsilon 1e-6 \ + --no-load-optim \ + --no-load-rng \ + --bf16 \ + --distributed-timeout-minutes 120 +" + +# --num-layers-per-virtual-pipeline-stage ${VPP} \ +# --recompute-granularity full \ +# --recompute-method uniform \ +# --recompute-num-layers 1 \ +# --noop-layers 3 \ +DATA_ARGS=" + --data-path $DATA_PATH \ + --split 100,0,0 \ +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 1 \ + --eval-interval 2000 \ + --eval-iters 0 \ + --no-save-optim \ + --no-save-rng +" + + +msrun $DISTRIBUTED_ARGS $basepath/pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $MLA_ARGS \ + $ROPE_ARGS \ + $MOE_ARGS \ + $OUTPUT_ARGS \ + $MTP_ARGS \ + --distributed-backend nccl \ + --ai-framework mindspore \ + 2>&1 | tee logs/ms_pretrain_deepseek3_685b_4k_ptd.log diff --git a/mindspeed_llm/mindspore/core/models/gpt/gpt_model.py b/mindspeed_llm/mindspore/core/models/gpt/gpt_model.py new file mode 100644 index 000000000..bfbf0d96c --- /dev/null +++ b/mindspeed_llm/mindspore/core/models/gpt/gpt_model.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +from functools import wraps +from typing import List + +import torch +from torch import Tensor + +from megatron.core import InferenceParams, tensor_parallel, parallel_state +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer import build_module +from megatron.core.transformer.custom_layers.transformer_engine import TENorm +from megatron.training import get_args +from megatron.core.tensor_parallel import ColumnParallelLinear +from megatron.core.transformer import ModuleSpec +from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm + +from mindspeed_llm.core.tensor_parallel.layers import SegmentedColumnParallelLinear +from mindspeed_llm.mindspore.tasks.models.transformer.multi_token_predication import MultiTokenPredication, MultiTokenPredicationSubmodules +from mindspeed_llm.core.models.gpt.gpt_model import setup_mtp_embeddings_layer + +# Use this spec for multi token predication +mtp_sepc = ModuleSpec( + module=MultiTokenPredication, + submodules=MultiTokenPredicationSubmodules( + embedding=None, + enorm=PTNorm, + hnorm=PTNorm, + eh_proj=ColumnParallelLinear, + transformer_layer=None, + final_layernorm=PTNorm, + output_layer=None, + ) +) + + + +def gpt_model_init_wrapper(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + post_layer_norm = kwargs.pop('post_layer_norm', True) + fn(self, *args, **kwargs) + config = args[1] if len(args) > 1 else kwargs['config'] + arguments = get_args() + if self.post_process and arguments.add_output_layer_bias: + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=True, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + + if self.post_process and arguments.output_layer_slice_num > 1: + self.output_layer = SegmentedColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + if not post_layer_norm: + self.decoder.post_layer_norm = False + self.num_nextn_predict_layers = arguments.num_nextn_predict_layers + self.share_mtp_embedding_and_output_weight = arguments.share_mtp_embedding_and_output_weight + if self.post_process and self.training and self.num_nextn_predict_layers: + self.mtp_layers = torch.nn.ModuleList( + [ + MultiTokenPredication( + config, + self.transformer_layer_spec, + mtp_sepc.submodules, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + layer_number=i, + pre_process=self.pre_process, + post_process=self.post_process, + fp16_lm_cross_entropy=kwargs.get("fp16_lm_cross_entropy", False), + parallel_output=self.parallel_output, + position_embedding_type=self.position_embedding_type, + rotary_percent=kwargs.get("rotary_percent", 1.0), + seq_len_interpolation_factor=kwargs.get("rotary_seq_len_interpolation_factor", None), + 
share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight, + ) + for i in range(self.num_nextn_predict_layers) + ] + ) + + if self.post_process and self.num_nextn_predict_layers: + # move block main model final norms here + self.final_layernorm = build_module( + TENorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None + + if self.pre_process or self.post_process: + setup_mtp_embeddings_layer(self) + + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/mindspore/core/transformer/moe/moe_layer.py b/mindspeed_llm/mindspore/core/transformer/moe/moe_layer.py new file mode 100644 index 000000000..a3f99dead --- /dev/null +++ b/mindspeed_llm/mindspore/core/transformer/moe/moe_layer.py @@ -0,0 +1,124 @@ +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved. +import types +from copy import deepcopy +from functools import wraps +import torch +import torch.nn.functional as F +from mindspeed.moe.utils import MoEAuxLossAutoScaler + +from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer import build_module +from megatron.core.transformer.mlp import MLPSubmodules, MLP +from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP +from megatron.core.transformer.moe.moe_utils import save_to_aux_losses_tracker +from megatron.training import get_args +from mindspeed.mindspore.core.transformer.moe.moe_layer_overlap_all2all import MoELayerOverlapAll2All +from mindspeed.core.transformer.moe.moe_layer_overlap_allgather import MoELayerOverlapAllGather + + +def moe_layer_init_wrapper(init_func): + @wraps(init_func) + def moe_layer_init(*args, **kwargs): + moe_config = deepcopy(kwargs["config"]) + global_args = get_args() + if global_args.moe_intermediate_size: + moe_config.ffn_hidden_size = global_args.moe_intermediate_size + kwargs["config"] = moe_config + + init_func(*args, **kwargs) + self = args[0] + + if moe_config.moe_grouped_gemm: + self.experts = GroupedMLP(self.num_local_experts, moe_config) + else: + self.experts = SequentialMLP(self.num_local_experts, moe_config, self.submodules) + + if global_args.n_shared_experts: + shared_expert_config = deepcopy(moe_config) + shared_expert_config.ffn_hidden_size = global_args.n_shared_experts * moe_config.ffn_hidden_size + + if global_args.moe_allgather_overlap_comm or global_args.moe_alltoall_overlap_comm: + from mindspeed.core.transformer.moe.layers import ColumnParallelLinear, RowParallelLinear + self.shared_experts = MLP(shared_expert_config, MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear), shared_expert=True) + else: + from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear + self.shared_experts = MLP(shared_expert_config, MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear)) + + # For using layer_number when recompute activation function is enabled. 
+ self.shared_experts.layer_number = self.layer_number + if global_args.shared_expert_gate: + self.shared_expert_gate = build_module( + torch.nn.Linear, + shared_expert_config.hidden_size, + global_args.shared_expert_gate_output_dimension, + bias=False + ) + return moe_layer_init + + +def moe_layer_forward(self, hidden_states: torch.Tensor): + global_args = get_args() + if global_args.moe_token_dispatcher_type == 'alltoall' and global_args.moe_alltoall_overlap_comm: + return MoELayerOverlapAll2All.apply(hidden_states, self) + if global_args.moe_token_dispatcher_type == 'allgather' and global_args.moe_allgather_overlap_comm: + return MoELayerOverlapAllGather.apply(hidden_states, self) + + # process MoE + scores, indices = self.router(hidden_states) + + if global_args.moe_revert_type_after_topk: + (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + hidden_states, scores.type_as(hidden_states), indices + ) + else: + (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + hidden_states, scores, indices + ) + + router_expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert) + + output, mlp_bias = self.token_dispatcher.token_unpermutation(router_expert_output, mlp_bias) + + args = get_args() + if args.moe_router_load_balancing_type == "group_limited_greedy": + # forward only need no loss track + if hasattr(args, "do_train") and args.do_train: + save_to_aux_losses_tracker( + "load_balancing_loss", + self.router.l_aux, + self.layer_number, + self.config.num_layers, + ) + save_to_aux_losses_tracker( + "load_balancing_expert_level_loss", + self.router.l_expert_aux / args.moe_aux_loss_coeff, + self.layer_number, + self.config.num_layers, + ) + if hasattr(self.router, 'l_device_aux'): + save_to_aux_losses_tracker( + "load_balancing_device_level_loss", + self.router.l_device_aux / args.moe_device_level_aux_loss_coeff, + self.layer_number, + self.config.num_layers, + ) + if hasattr(self.router, 'l_comm_aux'): + save_to_aux_losses_tracker( + "load_balancing_comm_level_loss", + self.router.l_comm_aux / args.moe_comm_aux_loss_coeff, + self.layer_number, + self.config.num_layers, + ) + output = MoEAuxLossAutoScaler.apply(output, self.router.l_aux) + + if args.n_shared_experts: + share_experts_output, share_experts_bias = self.shared_experts(hidden_states) + if args.shared_expert_gate: + share_experts_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * share_experts_output + output = output + share_experts_output + + if self.token_dispatcher.add_bias: + mlp_bias = mlp_bias + share_experts_bias + + return output, mlp_bias \ No newline at end of file diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py new file mode 100644 index 000000000..22da56679 --- /dev/null +++ b/mindspeed_llm/mindspore/mindspore_adaptor.py @@ -0,0 +1,82 @@ +from mindspeed_llm.tasks.megatron_adaptor import MegatronAdaptation, MegatronAdaptationABC + + +class MindSporeAdaptation(MegatronAdaptationABC): + """ + Adaptations for models in Megatron-LM Core structure. 
+ """ + def execute(self): + args = MegatronAdaptation.get_args() + if not hasattr(args, "ai_framework") or args.ai_framework != "mindspore": + return + from ..mindspore.core.models.gpt.gpt_model import gpt_model_init_wrapper + from ..mindspore.core.transformer.moe.moe_layer import moe_layer_init_wrapper, moe_layer_forward + from mindspeed.mindspore.core.data_parallel.distributed_data_parallel import distributed_data_parallel_init_with_cp + from mindspeed.mindspore.core.transformer.moe.experts import groupedmlp_init_wrapper, groupedmlp_forward + + MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init_wrapper) + MegatronAdaptation.register('megatron.core.distributed.distributed_data_parallel.DistributedDataParallel.__init__', + distributed_data_parallel_init_with_cp, force_patch=True) + MegatronAdaptation.register('megatron.core.transformer.moe.moe_layer.MoELayer.__init__', + moe_layer_init_wrapper) + MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.__init__', + groupedmlp_init_wrapper) + MegatronAdaptation.register('megatron.core.transformer.moe.moe_layer.MoELayer.forward', moe_layer_forward, force_patch=True) + + if args.moe_permutation_async_comm: + if args.moe_token_dispatcher_type == 'alltoall': + if args.moe_alltoall_overlap_comm: + from mindspeed.mindspore.core.transformer.moe.legacy_a2a_token_dispatcher import alltoall_token_permutation_new, \ + alltoall_token_unpermutation_new + from mindspeed.mindspore.core.transformer.moe.experts import group_mlp_forward + MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward', group_mlp_forward, force_patch=True) + MegatronAdaptation.register( + 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_permutation', + alltoall_token_permutation_new, force_patch=True) + MegatronAdaptation.register( + 'megatron.core.transformer.moe.token_dispatcher.MoEAlltoAllTokenDispatcher.token_unpermutation', + alltoall_token_unpermutation_new, force_patch=True) + + if hasattr(args, 'use_fused_moe_token_permute_and_unpermute') and args.use_fused_moe_token_permute_and_unpermute and not args.moe_expert_capacity_factor: + from mindspeed.mindspore.core.fusions.npu_moe_token_permute import permute_wrapper + from mindspeed.mindspore.core.fusions.npu_moe_token_unpermute import unpermute_wrapper + MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.permute', permute_wrapper) + MegatronAdaptation.register('megatron.core.transformer.moe.moe_utils.unpermute', unpermute_wrapper) + + if not args.moe_alltoall_overlap_comm: + MegatronAdaptation.register('megatron.core.transformer.moe.experts.GroupedMLP.forward', + groupedmlp_forward, force_patch=True) + + from mindspeed.mindspore.core.distributed.distributed_data_parallel import local_make_param_hook + MegatronAdaptation.register('megatron.core.distributed.distributed_data_parallel.DistributedDataParallel._make_param_hook', local_make_param_hook) + + from mindspeed.mindspore.core.distributed.param_and_grad_buffer import register_grad_ready + MegatronAdaptation.register('megatron.core.distributed.param_and_grad_buffer.register_grad_ready', register_grad_ready) + + from mindspeed.mindspore.core.models.common.embeddings.rotary_pos_embedding import get_rotary_seq_len, local_rotate_half + MegatronAdaptation.register('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.get_rotary_seq_len', get_rotary_seq_len) + 
MegatronAdaptation.register('megatron.core.models.common.embeddings._rotate_half', local_rotate_half) + + from mindspeed.mindspore.core.optimizer import get_megatron_optimizer + MegatronAdaptation.register('megatron.core.optimizer.get_megatron_optimizer', get_megatron_optimizer) + from mindspeed.mindspore.core.optimizer.optimizer import megatron_optimizer_init + MegatronAdaptation.register('megatron.core.optimizer.optimizer.MegatronOptimizer.__init__', megatron_optimizer_init) + + from mindspeed.mindspore.core.pipeline_parallel.schedules import forward_step, backward_step, forward_backward_no_pipelining + MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step', forward_step) + MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.backward_step', backward_step) + MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_no_pipelining', forward_backward_no_pipelining) + from mindspeed.mindspore.core.pipeline_parallel.schedules import forward_backward_pipelining_with_interleaving, forward_backward_pipelining_without_interleaving + MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving', forward_backward_pipelining_with_interleaving) + MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_without_interleaving', forward_backward_pipelining_without_interleaving) + + from mindspeed.mindspore.core.tensor_parallel.data import local_build_key_size_numel_dictionaries + MegatronAdaptation.register('megatron.core.tensor_parallel.data._build_key_size_numel_dictionaries', local_build_key_size_numel_dictionaries) # 1097 + + from mindspeed.mindspore.core.tensor_parallel.mappings import all_to_all_forward + MegatronAdaptation.register('megatron.core.tensor_parallel.mappings._AllToAll.forward', all_to_all_forward) + + from mindspeed.mindspore.core.tensor_parallel.random import local_set_cuda_rng_state, checkpoint_function_forward, checkpoint_function_backward + MegatronAdaptation.register('megatron.core.tensor_parallel.random._set_cuda_rng_state', local_set_cuda_rng_state, force_patch=True) + MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.forward', checkpoint_function_forward) + MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.backward', checkpoint_function_backward) diff --git a/mindspeed_llm/mindspore/tasks/models/transformer/multi_token_predication.py b/mindspeed_llm/mindspore/tasks/models/transformer/multi_token_predication.py new file mode 100644 index 000000000..a5d6e6d6b --- /dev/null +++ b/mindspeed_llm/mindspore/tasks/models/transformer/multi_token_predication.py @@ -0,0 +1,327 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
+import logging +from dataclasses import dataclass +from typing import Union, Optional, Literal + +import torch +from torch import Tensor + +from megatron.core import tensor_parallel, InferenceParams +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.module import MegatronModule +from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy + +from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module +from megatron.training import get_args +from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput +from mindspeed_llm.core.tensor_parallel.layers import SegmentedColumnParallelLinear + + +@dataclass +class MultiTokenPredicationSubmodules: + embedding: Union[ModuleSpec, type] = None + output_layer: Union[ModuleSpec, type] = None + eh_proj: Union[ModuleSpec, type] = None + enorm: Union[ModuleSpec, type] = None + hnorm: Union[ModuleSpec, type] = None + transformer_layer: Union[ModuleSpec, type] = None + final_layernorm: Union[ModuleSpec, type] = None + + +class MultiTokenPredication(MegatronModule): + def __init__( + self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + submodules: MultiTokenPredicationSubmodules, + vocab_size: int, + max_sequence_length: int, + layer_number: int = 1, + hidden_dropout: float = None, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + seq_len_interpolation_factor: Optional[float] = None, + share_mtp_embedding_and_output_weight=True, + ): + super().__init__(config=config) + args = get_args() + + self.config = config + self.submodules = submodules + + if transformer_layer_spec is not None: + self.transformer_layer_spec = transformer_layer_spec + self.submodules.transformer_layer = self.transformer_layer_spec + self.layer_number = layer_number + self.hidden_dropout = hidden_dropout + self.hidden_size = args.hidden_size + self.ffn_hidden_size = args.ffn_hidden_size + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.position_embedding_type = position_embedding_type + self.num_nextn_predict_layers = args.num_nextn_predict_layers + # share with main model + self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight + self.recompute_layer_norm = args.recompute_mtp_norm + self.recompute_mtp_layer = args.recompute_mtp_layer + + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=self.position_embedding_type, + skip_weight_param_allocation=self.pre_process and self.share_mtp_embedding_and_output_weight + ) + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + 
use_cpu_initialization=self.config.use_cpu_initialization, + ) + + self.enorm = build_module( + self.submodules.enorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.hnorm = build_module( + self.submodules.hnorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.eh_proj = build_module( + self.submodules.eh_proj, + self.hidden_size + self.hidden_size, + self.hidden_size, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear, + skip_bias_add=True, + tp_comm_buffer_name='eh', + ) + + self.transformer_layer = build_module( + self.submodules.transformer_layer, + config=self.config, + ) + + if self.submodules.final_layernorm: + self.final_layernorm = build_module( + self.submodules.final_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + else: + self.final_layernorm = None + + if self.config.defer_embedding_wgrad_compute: + self.embedding_activation_buffer = [] + self.grad_output_buffer = [] + else: + self.embedding_activation_buffer = None + self.grad_output_buffer = None + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + if args.add_output_layer_bias: + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=True, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + + if args.output_layer_slice_num > 1: + self.output_layer = SegmentedColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + + def forward( + self, + hidden_input_ids: Tensor, + embed_input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + embeding_weight: Optional[torch.Tensor] = None, + output_weight: Optional[torch.Tensor] = None, + ): + """Forward function of the MTP module""" + args = get_args() + if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): + args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings + # Decoder embedding. 
+ decoder_input = self.embedding( + input_ids=embed_input_ids, + position_ids=position_ids, + weight=embeding_weight, + ) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + if inference_params is not None: + rotary_seq_len = inference_params.max_sequence_length + else: + rotary_seq_len = decoder_input.size(0) + + if self.config.sequence_parallel: + rotary_seq_len *= self.config.tensor_model_parallel_size + + rotary_seq_len *= self.config.context_parallel_size + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + def self_enorm(decoder_input): + return self.enorm(decoder_input) + + def self_hnorm(hidden_input_ids): + return self.hnorm(hidden_input_ids) + + if self.recompute_layer_norm: + enorm_output = tensor_parallel.random.checkpoint(self_enorm, False, decoder_input) + hnorm_output = tensor_parallel.random.checkpoint(self_hnorm, False, hidden_input_ids) + else: + enorm_output = self.enorm(decoder_input) + hnorm_output = self.hnorm(hidden_input_ids) + + # [s, b, h] -> [s, b, 2h] + hidden_states = torch.concat( + [hnorm_output, + enorm_output], + dim=-1 + ) + + # hidden_states -> [s, b, h] + hidden_states, _ = self.eh_proj(hidden_states) + + if self.config.tensor_model_parallel_size > 1: + hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states) + if self.config.sequence_parallel: + hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states) + if self.recompute_mtp_layer: + hidden_states, context = tensor_parallel.checkpoint( + self.transformer_layer, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + None, + None, + rotary_pos_emb, + inference_params, + packed_seq_params, + ) + else: + hidden_states, _ = self.transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + def self_final_layernorm(hidden_states): + return self.final_layernorm(hidden_states) + # Final layer norm. 
+ if self.final_layernorm is not None: + if self.recompute_layer_norm: + finalnorm_output = tensor_parallel.random.checkpoint(self_final_layernorm, False, hidden_states) + else: + finalnorm_output = self.final_layernorm(hidden_states) + else: + finalnorm_output = hidden_states + + if args.dim_model_base is not None: + finalnorm_output = finalnorm_output / (args.hidden_size / args.dim_model_base) + logits, _ = self.output_layer(finalnorm_output, weight=output_weight) + + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + + loss = self.compute_language_model_loss(labels, logits) + return hidden_states, loss + + def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: + """Computes the language model loss (Cross entropy across vocabulary) + + Args: + labels (Tensor): The labels of dimension [batch size, seq length] + logits (Tensor): The final logits returned by the output layer of the transformer model + + Returns: + Tensor: Loss tensor of dimensions [batch size, sequence_length] + """ + # [b s] => [s b] + labels = labels.transpose(0, 1).contiguous() + if self.config.cross_entropy_loss_fusion: + loss = fused_vocab_parallel_cross_entropy(logits, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels) + + # [s b] => [b, s] + loss = loss.transpose(0, 1).contiguous() + return loss \ No newline at end of file diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 2fcabb54d..5a591d38b 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -41,8 +41,9 @@ class MegatronAdaptation: """ Execute adaptations. 
""" + from mindspeed_llm.mindspore.mindspore_adaptor import MindSporeAdaptation MegatronAdaptation.pre_execute() - for adaptation in [CoreAdaptation(), LegacyAdaptation()]: + for adaptation in [CoreAdaptation(), LegacyAdaptation(), MindSporeAdaptation()]: adaptation.execute() MegatronAdaptation.apply() MegatronAdaptation.post_execute() diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 161ffc841..6a1aab843 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -74,10 +74,18 @@ def process_args(parser): parser = _add_megatron2_args(parser) parser = _add_inference_args(parser) parser = _add_dualpipe_args(parser) + parser = _add_ai_framework_args(parser) return parser +def _add_ai_framework_args(parser): + group = parser.add_argument_group(title='ai framework') + + group.add_argument('--ai-framework', type=str, choices=['pytorch', 'mindspore'], default='pytorch', help='support pytorch and mindspore') + return parser + + def _add_default_model_args(parser): group = parser.add_argument_group(title='default model mode') -- Gitee From 7b8bb8ddca86579940b3529fc6c06d5b5ddd8076 Mon Sep 17 00:00:00 2001 From: zihao Date: Mon, 21 Apr 2025 20:37:48 +0800 Subject: [PATCH 2/3] ok --- examples/mindspore/deepseek3/finetune_deepseek3_ms.sh | 7 +++---- examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh | 7 +++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh b/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh index 293fefc4e..b47f2587d 100644 --- a/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh +++ b/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh @@ -3,17 +3,16 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True export HCCL_CONNECT_TIMEOUT=3600 -export GLOO_SOCKET_IFNAME="enp67s0f0np0" # "192.168.122.1" +export GLOO_SOCKET_IFNAME="enp67s0f0np0" # 网卡名称 export HCCL_SOCKET_IFNAME="enp67s0f0np0" basepath=$(cd `dirname $0`; cd ../../../; pwd) GPUS_PER_NODE=8 -#MASTER_ADDR=localhost #主节点IP -MASTER_ADDR=70.166.30.72 +MASTER_ADDR=localhost #主节点IP MASTER_PORT=8892 NNODES=4 -NODE_RANK=2 +NODE_RANK=0 WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) CKPT_SAVE_DIR="your model save ckpt path" # /***/**${NODE_RANK} diff --git a/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh b/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh index 690a6a819..04bb3707f 100644 --- a/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh +++ b/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh @@ -5,17 +5,16 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True export HCCL_CONNECT_TIMEOUT=3600 export HCCL_ALGO="alltoall=level0:NA;level1:pipeline" export HCCL_BUFFSIZE=400 -export GLOO_SOCKET_IFNAME="enp67s0f0np0" # "192.168.122.1" +export GLOO_SOCKET_IFNAME="enp67s0f0np0" # 网卡名称 export HCCL_SOCKET_IFNAME="enp67s0f0np0" basepath=$(cd `dirname $0`; cd ../../../; pwd) GPUS_PER_NODE=8 -#MASTER_ADDR=localhost #主节点IP -MASTER_ADDR=70.166.30.52 #主节点IP +MASTER_ADDR=localhost #主节点IP MASTER_PORT=9110 NNODES=4 -NODE_RANK=2 +NODE_RANK=0 WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) CKPT_SAVE_DIR="your model save ckpt path" # /***/**${NODE_RANK} -- Gitee From ea17331e0e6b80a5fb568927f4269f7a913968f6 Mon Sep 17 00:00:00 2001 From: zihao Date: Tue, 22 Apr 2025 09:20:17 +0800 Subject: [PATCH 3/3] fix --- examples/mindspore/deepseek3/finetune_deepseek3_ms.sh | 2 -- examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh | 2 -- 2 files changed, 4 
deletions(-) diff --git a/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh b/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh index b47f2587d..ac2d68671 100644 --- a/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh +++ b/examples/mindspore/deepseek3/finetune_deepseek3_ms.sh @@ -3,8 +3,6 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1 export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True export HCCL_CONNECT_TIMEOUT=3600 -export GLOO_SOCKET_IFNAME="enp67s0f0np0" # 网卡名称 -export HCCL_SOCKET_IFNAME="enp67s0f0np0" basepath=$(cd `dirname $0`; cd ../../../; pwd) diff --git a/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh b/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh index 04bb3707f..42973ec76 100644 --- a/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh +++ b/examples/mindspore/deepseek3/pretrain_deepseek3_ms.sh @@ -5,8 +5,6 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True export HCCL_CONNECT_TIMEOUT=3600 export HCCL_ALGO="alltoall=level0:NA;level1:pipeline" export HCCL_BUFFSIZE=400 -export GLOO_SOCKET_IFNAME="enp67s0f0np0" # 网卡名称 -export HCCL_SOCKET_IFNAME="enp67s0f0np0" basepath=$(cd `dirname $0`; cd ../../../; pwd) -- Gitee
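
The functional core of this series is the new `--ai-framework` switch added in mindspeed_llm/training/arguments.py together with the MindSporeAdaptation pass that MegatronAdaptation.execute() now runs after CoreAdaptation and LegacyAdaptation. Below is a minimal, self-contained sketch of that gating using plain argparse; the helper name `should_apply_mindspore_patches` is illustrative only (not part of the MindSpeed API) and simply mirrors the early return at the top of MindSporeAdaptation.execute().

    import argparse

    def _add_ai_framework_args(parser):
        # Mirrors the argument group added by this patch series: the flag
        # defaults to 'pytorch', so existing PyTorch launch scripts keep
        # working without modification.
        group = parser.add_argument_group(title='ai framework')
        group.add_argument('--ai-framework', type=str,
                           choices=['pytorch', 'mindspore'], default='pytorch',
                           help='support pytorch and mindspore')
        return parser

    def should_apply_mindspore_patches(args):
        # MindSporeAdaptation.execute() returns immediately unless the parsed
        # args carry ai_framework == 'mindspore'; only then are the MindSpore
        # replacements registered via MegatronAdaptation.register(...).
        return getattr(args, 'ai_framework', None) == 'mindspore'

    if __name__ == '__main__':
        parser = _add_ai_framework_args(argparse.ArgumentParser())
        print(should_apply_mindspore_patches(parser.parse_args([])))
        # -> False: default 'pytorch', adaptation pass is a no-op
        print(should_apply_mindspore_patches(
            parser.parse_args(['--ai-framework', 'mindspore'])))
        # -> True: MindSpore patches (GPTModel.__init__, MoELayer, schedules,
        #          optimizer, tensor-parallel random state, ...) get registered

This is also why the example launch scripts pass `--ai-framework mindspore` to msrun: without that flag the adaptation pass does nothing and the stock PyTorch code paths are used.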