diff --git a/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_A3_ptd.sh b/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_A3_ptd.sh index 91cccfe341df08c9a083dd0fee92913955e6f1dc..824231fbe3cbb90559ad3bde9b095fbf1c7a3728 100644 --- a/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_A3_ptd.sh +++ b/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_A3_ptd.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/bash # 需要切换MindSpeed版本 # git checkout f2acbe71a47b9307c2425bc51f9565802bd901cf # checkout commit from MindSpeed core_r0.8.0 in 2025.03.30 @@ -75,8 +75,8 @@ MOE_ARGS=" " MTP_ARGS=" - --num-nextn-predict-layers 1 \ - --share-mtp-embedding-and-output-weight \ + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 \ --mtp-mem-efficient-logits \ " @@ -189,5 +189,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \ $MOE_ARGS \ $MTP_ARGS \ --save $CKPT_SAVE_DIR \ - --load $CKPT_LOAD_DIR \ + --load $CKPT_LOAD_DIR \ --distributed-backend nccl | tee logs/pretrain_deepseek3_671b_4k_A3_ptd.log \ No newline at end of file diff --git a/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_ptd.sh b/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_ptd.sh index 305b53a486788d48213afa9a36cd83ac62a3dafc..b26f769bcd318497278a4fd31e2379c2cc83fe00 100644 --- a/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_ptd.sh +++ b/examples/mcore/deepseek3/pretrain_deepseek3_671b_4k_ptd.sh @@ -70,6 +70,7 @@ MOE_ARGS=" --routed-scaling-factor 2.5 \ --moe-aux-loss-coeff 0.0001 \ --seq-aux \ + --moe-aux-loss-coeff 0.001 \ --norm-topk-prob \ --moe-router-score-function sigmoid \ --moe-router-enable-expert-bias \ @@ -77,8 +78,8 @@ MOE_ARGS=" " MTP_ARGS=" - --num-nextn-predict-layers 1 \ - --share-mtp-embedding-and-output-weight \ + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 \ --recompute-mtp-norm \ --mtp-mem-efficient-logits \ " diff --git a/examples/mcore/deepseek3/tune_deepseek3_671b_4k_full_A3_ptd.sh b/examples/mcore/deepseek3/tune_deepseek3_671b_4k_full_A3_ptd.sh index 979c4ac8c440614b40c9f9c85d271f0d69029e12..10bd0ecb656a5c24de2cce6e3f243b3c69db1366 100644 --- a/examples/mcore/deepseek3/tune_deepseek3_671b_4k_full_A3_ptd.sh +++ b/examples/mcore/deepseek3/tune_deepseek3_671b_4k_full_A3_ptd.sh @@ -53,6 +53,7 @@ MOE_ARGS=" --topk-group 4 \ --routed-scaling-factor 2.5 \ --seq-aux \ + --moe-aux-loss-coeff 0.001 \ --norm-topk-prob \ --moe-router-score-function sigmoid \ --moe-router-enable-expert-bias \ diff --git a/mindspeed_llm/core/__init__.py b/mindspeed_llm/core/__init__.py index 1831c630adcd8f807a0d4aeda5ad0bccc88a2476..84430bc2b81e0aa3996e069f836ff8e4532e2cb6 100644 --- a/mindspeed_llm/core/__init__.py +++ b/mindspeed_llm/core/__init__.py @@ -13,7 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
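# Illustrative sketch (not part of the patch): how the new --mtp-loss-scaling-factor flag set in
# these scripts is typically folded into the training loss. The helper mirrors the arithmetic of
# the replaced implementation later in this diff (mtp_loss_scale / num_nextn_predict_layers per
# MTP head); the function and argument names are assumptions, not MindSpeed APIs.
def combine_mtp_loss(lm_loss, mtp_losses, mtp_loss_scaling_factor=0.3):
    # Each MTP head contributes an equal share of the configured scaling factor
    # on top of the main next-token-prediction loss.
    total = lm_loss
    for mtp_loss in mtp_losses:
        total = total + mtp_loss_scaling_factor / len(mtp_losses) * mtp_loss
    return total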
-from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_embedding_init_func, checkpoint_forward_wrapper, checkpoint_backward_wrapper +from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec_wrapper, build_layers_wrapper +from .tensor_parallel.layers import vocab_embedding_init_func, vocab_parallel_embedding_forward from .parallel_state import (initialize_model_parallel_decorator, destroy_model_parallel_decorator, get_expert_model_parallel_rank, get_expert_model_parallel_world_size, get_expert_parallel_group, @@ -29,8 +30,6 @@ from .transformer.moe.router import (topk_router_forward, topk_router_routing, t from .transformer.moe.moe_utils import z_loss_func, topk_softmax_with_capacity from .transformer.transformer_layer import TransformerLayer from .transformer.transformer_block import get_num_layers_to_build_wrapper, transformer_block_init_wrapper, transformer_block_forward -from .models.gpt.gpt_model import gpt_model_forward -from .models.gpt.gpt_layer_specs import get_gpt_layer_local_spec_wrapper, build_layers_wrapper from .distributed.param_and_grad_buffer import start_grad_sync_wrapper from .distributed.distributed_data_parallel import distributed_data_parallel_init_wrapper from .optimizer import get_megatron_optimizer_wrapper diff --git a/mindspeed_llm/core/distributed/finalize_model_grads.py b/mindspeed_llm/core/distributed/finalize_model_grads.py index c14e12bdf376247adc0db456d8b0ecd9415d5420..79942fa16214b464bf72152626934b3db22eb94c 100644 --- a/mindspeed_llm/core/distributed/finalize_model_grads.py +++ b/mindspeed_llm/core/distributed/finalize_model_grads.py @@ -15,6 +15,14 @@ from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm from mindspeed_llm.core.transformer.moe.moe_utils import get_updated_expert_bias +def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False): + if use_custom_fsdp: + return "fsdp_managed_main_grad" + if hasattr(param, "main_grad"): + return "main_grad" + return "grad" + + def allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerConfig): """ All-reduce layernorm grads (for sequence parallelism). @@ -67,37 +75,31 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf All-reduce word embedding grads. Reduce grads across first and last stages to ensure that word_embeddings parameters stay in - sync. This should only run for models that support pipelined model parallelism (BERT and GPT). + sync. """ if ( - parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and parallel_state.get_pipeline_model_parallel_world_size() > 1 + parallel_state.is_rank_in_embedding_group(ignore_virtual=True) + and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1 ): if parallel_state.is_pipeline_first_stage(ignore_virtual=True): model_module = model[0] elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): model_module = model[-1] - else: # We do not support the interleaved schedule for T5 yet. + else: # We do not support an interleaved schedule for models with encoders yet. model_module = model[0] - # Look for module with 'pre_process' attribute to get around the fact that DDP and - # other wrapper classes inherit from non-core MegatronModule that has - # 'share_embeddings_and_output_weights' and 'shared_embedding_or_output_weight' - # attributes already, causing get_attr_wrapped_model() to not unwrap anything here. 
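# Small usage check (assumption, not part of the patch) for the _get_main_grad_attr helper above:
# Megatron's DDP accumulates into param.main_grad, a bare parameter only carries param.grad, and
# the custom-FSDP path stores gradients under fsdp_managed_main_grad.
import torch

p = torch.nn.Parameter(torch.zeros(4))
p.grad = torch.ones(4)
assert _get_main_grad_attr(p) == "grad"
p.main_grad = torch.ones(4)  # attribute Megatron DDP attaches to its parameters
assert _get_main_grad_attr(p) == "main_grad"
assert _get_main_grad_attr(p, use_custom_fsdp=True) == "fsdp_managed_main_grad"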
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) - if model_module.share_embeddings_and_output_weights: + # If share_embeddings_and_output_weights is True, we need to maintain duplicated + # embedding weights in post processing stage. If use Multi-Token Prediction (MTP), + # we also need to maintain duplicated embedding weights in mtp process stage. + # So we need to allreduce grads of embedding in the embedding group in these cases. + if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0): weight = model_module.shared_embedding_or_output_weight() if not weight.requires_grad: return - grad = weight.main_grad - torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) - if hasattr(model_module, - "share_mtp_embedding_and_output_weight") and model_module.share_mtp_embedding_and_output_weight: - weight = model_module.shared_embedding_weight() - if not weight.requires_grad: - return - grad = weight.main_grad + grad_attr = _get_main_grad_attr(weight) + grad = getattr(weight, grad_attr) torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) diff --git a/mindspeed_llm/core/models/common/embeddings/language_model_embedding.py b/mindspeed_llm/core/models/common/embeddings/language_model_embedding.py index aae69cd6ee310906f3fc6687a35b0cb0cdcab321..290198d5f58e03624ea5b8ce92d7da0e5bf4dbc7 100644 --- a/mindspeed_llm/core/models/common/embeddings/language_model_embedding.py +++ b/mindspeed_llm/core/models/common/embeddings/language_model_embedding.py @@ -3,7 +3,6 @@ from typing import Literal import torch -from torch import Tensor from megatron.core import tensor_parallel from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding @@ -65,69 +64,3 @@ def language_model_embedding_init_func( # Embeddings dropout self.embedding_dropout = torch.nn.Dropout(self.config.hidden_dropout) - - -def language_model_embedding_forward(self, - input_ids: Tensor, - position_ids: Tensor, - tokentype_ids: int = None, - weight: Tensor = None) -> Tensor: - """Pacth forward pass of the embedding module. - - Args: - input_ids (Tensor): The input tokens - position_ids (Tensor): The position id's used to calculate position embeddings - tokentype_ids (int): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None - weight (Tensor): embedding weight - - Returns: - Tensor: The output embeddings - """ - if weight is None: - if self.word_embeddings.weight is None: - raise RuntimeError( - "weight was not supplied to VocabParallelEmbedding forward pass " - "and skip_weight_param_allocation is True." - ) - weight = self.word_embeddings.weight - - word_embeddings = self.word_embeddings(input_ids, weight) - if self.add_position_embedding: - position_embeddings = self.position_embeddings(position_ids) - embeddings = word_embeddings + position_embeddings - else: - embeddings = word_embeddings - - if not self.reduce_scatter_embeddings: - # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. 
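# Standalone restatement (assumption, not part of the patch) of the gradient-tying step above:
# when embedding and output weights are shared across pipeline stages, or MTP keeps a duplicated
# embedding, each stage accumulates its own gradient, and a sum all-reduce over the embedding
# group keeps every copy applying the identical weight update.
import torch

def allreduce_shared_embedding_grad(weight, embedding_group):
    grad_attr = _get_main_grad_attr(weight)  # "main_grad" under Megatron DDP, otherwise "grad"
    grad = getattr(weight, grad_attr)
    torch.distributed.all_reduce(grad, group=embedding_group)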
- embeddings = embeddings.transpose(0, 1).contiguous() - - if tokentype_ids is not None: - if self.tokentype_embeddings is None: - raise ValueError("tokentype_embeddings should not be None when tokentype_ids are provided.") - # [b s h] -> [s b h] (So that it can be added with embeddings) - tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2) - embeddings = embeddings + tokentype_embedding - else: - if self.tokentype_embeddings is not None: - raise ValueError("tokentype_embeddings should be None when tokentype_ids are not provided.") - - # If the input flag for fp32 residual connection is set, convert for float. - if self.config.fp32_residual_connection: - embeddings = embeddings.float() - - # Dropout. - if self.config.sequence_parallel: - if not self.reduce_scatter_embeddings: - embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) - # `scatter_to_sequence_parallel_region` returns a view, which prevents - # the original tensor from being garbage collected. Clone to facilitate GC. - # Has a small runtime cost (~0.5%). - if self.config.clone_scatter_output_in_embedding: - embeddings = embeddings.clone() - with tensor_parallel.get_cuda_rng_tracker().fork(): - embeddings = self.embedding_dropout(embeddings) - else: - embeddings = self.embedding_dropout(embeddings) - - return embeddings diff --git a/mindspeed_llm/core/models/common/language_module/__init__.py b/mindspeed_llm/core/models/common/language_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mindspeed_llm/core/models/common/language_module/language_module.py b/mindspeed_llm/core/models/common/language_module/language_module.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb287343fb7f9fb216a2e3945568e93b5fe53a8 --- /dev/null +++ b/mindspeed_llm/core/models/common/language_module/language_module.py @@ -0,0 +1,142 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +import logging + +import torch + +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint +from megatron.training import get_args + + +def setup_embeddings_and_output_layer(self) -> None: + """Sets up embedding layer in first stage and output layer in last stage. + + This function initalizes word embeddings in the final stage when we are + using pipeline parallelism and sharing word embeddings, and sets up param + attributes on the embedding and output layers. + """ + arguments = get_args() + # Set `is_embedding_or_output_parameter` attribute. + if self.pre_process: + self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True + if self.post_process and self.output_layer.weight is not None: + self.output_layer.weight.is_embedding_or_output_parameter = True + + # If share_embeddings_and_output_weights is True, we need to maintain duplicated + # embedding weights in post processing stage. If use Multi-Token Prediction (MTP), + # we also need to maintain duplicated embedding weights in mtp process stage. + # So we need to copy embedding weights from pre processing stage as initial parameters + # in these cases. 
+ if not self.share_embeddings_and_output_weights and \ + not getattr(self.config, 'mtp_num_layers', 0) or \ + arguments.schedules_method == 'dualpipev': + return + + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + # Zero out wgrad if sharing embeddings between two layers on same + # pipeline stage to make sure grad accumulation into main_grad is + # correct and does not include garbage values (e.g., from torch.empty). + self.shared_embedding_or_output_weight().zero_out_wgrad = True + return + + if parallel_state.is_pipeline_first_stage() and self.pre_process and not self.post_process: + self.shared_embedding_or_output_weight().shared_embedding = True + + if (self.post_process or getattr(self, 'mtp_process', False)) and not self.pre_process: + if parallel_state.is_pipeline_first_stage(): + raise AssertionError("Share embedding and output weight in pipeline first stage incorrectly.") + # set weights of the duplicated embedding to 0 here, + # then copy weights from pre processing stage using all_reduce below. + weight = self.shared_embedding_or_output_weight() + weight.data.fill_(0) + weight.shared = True + weight.shared_embedding = True + + # Parameters are shared between the word embeddings layers, and the + # heads at the end of the model. In a pipelined setup with more than + # one stage, the initial embedding layer and the head are on different + # workers, so we do the following: + # 1. Create a second copy of word_embeddings on the last stage, with + # initial parameters of 0.0. + # 2. Do an all-reduce between the first and last stage to ensure that + # the two copies of word_embeddings start off with the same + # parameter values. + # 3. In the training loop, before an all-reduce between the grads of + # the two word_embeddings layers to ensure that every applied weight + # update is the same on both stages. + + # Ensure that first and last stages have the same initial parameter + # values. + if torch.distributed.is_initialized(): + if parallel_state.is_rank_in_embedding_group(): + weight = self.shared_embedding_or_output_weight() + weight.data = weight.data.cuda() + torch.distributed.all_reduce( + weight.data, group=parallel_state.get_embedding_group() + ) + + elif not getattr(LanguageModule, "embedding_warning_printed", False): + logging.getLogger(__name__).warning( + "Distributed processes aren't initialized, so the output layer " + "is not initialized with weights from the word embeddings. " + "If you are just manipulating a model this is fine, but " + "this needs to be handled manually. If you are training " + "something is definitely wrong." + ) + LanguageModule.embedding_warning_printed = True + + +def tie_embeddings_and_output_weights_state_dict( + self, + sharded_state_dict: ShardedStateDict, + output_layer_weight_key: str, + first_stage_word_emb_key: str, +) -> None: + """Ties the embedding and output weights in a given sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + output_layer_weight_key (str): key of the output layer weight in the state dict. + This entry will be replaced with a tied version + first_stage_word_emb_key (str): this must be the same as the + ShardedTensor.key of the first stage word embeddings. 
+ + Returns: None, acts in-place + """ + if not self.post_process: + # No output layer + if output_layer_weight_key in sharded_state_dict or not sharded_state_dict.keys(): + raise AssertionError("Sharded state dict incorrectly initialized.") + return + + if self.pre_process: + # Output layer is equivalent to the embedding already + return + + # If use Multi-Token Prediction (MTP), we need maintain both embedding layer and output + # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True, + # the shared weights will be stored in embedding layer, and output layer will not have + # any weight. + if getattr(self, 'mtp_process', False): + # No output layer + if output_layer_weight_key in sharded_state_dict or not sharded_state_dict.keys(): + raise AssertionError("Sharded state dict incorrectly initialized.") + return + + # Replace the default output layer with a one sharing the weights with the embedding + del sharded_state_dict[output_layer_weight_key] + tensor = self.shared_embedding_or_output_weight() + last_stage_word_emb_replica_id = ( + 1, # copy of first stage embedding + 0, + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint( + tensor=tensor, + key=first_stage_word_emb_key, + replica_id=last_stage_word_emb_replica_id, + allow_shape_mismatch=True, + ) \ No newline at end of file diff --git a/mindspeed_llm/core/models/gpt/gpt_layer_specs.py b/mindspeed_llm/core/models/gpt/gpt_layer_specs.py index 6216b9b63692d959ab594b5758ba17b8e874b39f..895755aec4b8cd9db56f584506527613c30df9a4 100644 --- a/mindspeed_llm/core/models/gpt/gpt_layer_specs.py +++ b/mindspeed_llm/core/models/gpt/gpt_layer_specs.py @@ -15,11 +15,22 @@ import types from functools import wraps +from typing import Union +from megatron.core.transformer import ModuleSpec from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.training import get_args +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from mindspeed_llm.core.transformer.transformer_layer import TransformerLayer from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm +from mindspeed_llm.core.transformer.multi_token_prediction import ( + MultiTokenPredictionBlockSubmodules, + get_mtp_layer_offset, + get_mtp_layer_spec, + get_mtp_num_layers_to_build, +) def get_gpt_layer_local_spec_wrapper(fn): @@ -42,6 +53,7 @@ def build_layers_wrapper(fn, column_forward, row_forward): """ For MOE + Ascend MC2, we replace linear_fc1 and linear_fc2 with vanilla column_forward and row_forward in megatron. 
""" + @wraps(fn) def wrapper(self, *args, **kwargs): fn(self, *args, **kwargs) @@ -52,4 +64,42 @@ def build_layers_wrapper(fn, column_forward, row_forward): for local_expert in layer.mlp.experts.local_experts: local_expert.linear_fc1.forward = types.MethodType(column_forward, local_expert.linear_fc1) local_expert.linear_fc2.forward = types.MethodType(row_forward, local_expert.linear_fc2) + return wrapper + + +def get_gpt_mtp_block_spec( + config: TransformerConfig, + spec: Union[TransformerBlockSubmodules, ModuleSpec], + use_transformer_engine: bool, +) -> MultiTokenPredictionBlockSubmodules: + """GPT Multi-Token Prediction (MTP) block spec.""" + num_layers_to_build = get_mtp_num_layers_to_build(config) + if num_layers_to_build == 0: + return None + + if isinstance(spec, TransformerBlockSubmodules): + # get the spec for the last layer of decoder block + transformer_layer_spec = spec.layer_specs[-1] + elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer: + transformer_layer_spec = spec + else: + raise ValueError(f"Invalid spec: {spec}") + + mtp_layer_spec = get_mtp_layer_spec( + transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine + ) + mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0 + mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers + + offset = get_mtp_layer_offset(config) + # split the mtp layer specs to only include the layers that are built in this pipeline stage. + mtp_layer_specs = mtp_layer_specs[offset: offset + num_layers_to_build] + if len(mtp_layer_specs) > 0: + if len(mtp_layer_specs) != config.mtp_num_layers: + raise AssertionError(f"currently all of the mtp layers must stage in the same pipeline stage.") + mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs) + else: + mtp_block_spec = None + + return mtp_block_spec diff --git a/mindspeed_llm/core/models/gpt/gpt_model.py b/mindspeed_llm/core/models/gpt/gpt_model.py index 93f9829624bed382b7587faad99944be8a43bae8..664e53069165da99c1530605a32dc430064b42ad 100644 --- a/mindspeed_llm/core/models/gpt/gpt_model.py +++ b/mindspeed_llm/core/models/gpt/gpt_model.py @@ -1,5 +1,6 @@ # coding=utf-8 # Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,92 +13,117 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import logging -from functools import wraps -from typing import List +from typing import Literal, Optional, Dict import torch from torch import Tensor from megatron.core import InferenceParams, tensor_parallel, parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule +from megatron.core.models.gpt import GPTModel as MegatronCoreGPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer import build_module from megatron.core.transformer.custom_layers.transformer_engine import TENorm +from megatron.core.transformer import TransformerConfig, ModuleSpec +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.transformer_block import TransformerBlock from megatron.training import get_args -from mindspeed.utils import set_actual_seq_len, set_position_ids from mindspeed_llm.core.tensor_parallel.layers import SegmentedColumnParallelLinear -from mindspeed_llm.tasks.models.spec.mtp_spec import mtp_sepc -from mindspeed_llm.tasks.models.transformer.multi_token_predication import MultiTokenPredication -from mindspeed_llm.training.utils import tensor_slide, compute_actual_seq_len +from mindspeed_llm.training.utils import tensor_slide +from mindspeed_llm.core.transformer.multi_token_prediction import ( + MultiTokenPredictionBlock, + tie_output_layer_state_dict, + tie_word_embeddings_state_dict, +) -def gpt_model_init_wrapper(fn): - @wraps(fn) - def wrapper(self, *args, **kwargs): +class GPTModel(MegatronCoreGPTModel): + """ + patch megatron GPTModel + """ + + def __init__(self, + config: TransformerConfig, + transformer_layer_spec: ModuleSpec, + vocab_size: int, + max_sequence_length: int, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', + rotary_percent: float = 1.0, + rotary_base: int = 10000, + seq_len_interpolation_factor: Optional[float] = None, + mtp_block_spec: Optional[ModuleSpec] = None, + *args, + **kwargs, + ) -> None: + super(LanguageModule, self).__init__(config=config) + + global_args = get_args() post_layer_norm = kwargs.pop('post_layer_norm', True) - fn(self, *args, **kwargs) - config = args[1] if len(args) > 1 else kwargs['config'] - arguments = get_args() - if self.post_process and arguments.add_output_layer_bias: - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=True, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.pre_process - and self.share_embeddings_and_output_weights, - embedding_activation_buffer=self.embedding_activation_buffer, - grad_output_buffer=self.grad_output_buffer, - ) - if self.post_process and arguments.output_layer_slice_num > 1: - self.output_layer = SegmentedColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=False, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.pre_process - and 
self.share_embeddings_and_output_weights, - embedding_activation_buffer=self.embedding_activation_buffer, - grad_output_buffer=self.grad_output_buffer, + self.transformer_layer_spec: ModuleSpec = transformer_layer_spec + self.vocab_size = vocab_size + self.max_sequence_length = max_sequence_length + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + self.share_embeddings_and_output_weights = share_embeddings_and_output_weights + self.position_embedding_type = position_embedding_type + + # megatron core pipelining currently depends on model type + self.model_type = ModelType.encoder_or_decoder + + # These 2 attributes are needed for TensorRT-LLM export. + self.max_position_embeddings = max_sequence_length + self.rotary_percent = rotary_percent + self.mtp_block_spec = mtp_block_spec + self.mtp_process = mtp_block_spec is not None + + skip_embedding_allocation = self.mtp_process and global_args.schedules_method == 'dualpipev' + if self.pre_process or self.mtp_process: + self.embedding = LanguageModelEmbedding( + config=self.config, + vocab_size=self.vocab_size, + max_sequence_length=self.max_sequence_length, + position_embedding_type=position_embedding_type, + skip_weight_param_allocation=skip_embedding_allocation, ) - if not post_layer_norm: - self.decoder.post_layer_norm = False - self.num_nextn_predict_layers = arguments.num_nextn_predict_layers - self.share_mtp_embedding_and_output_weight = arguments.share_mtp_embedding_and_output_weight - if self.post_process and self.training and self.num_nextn_predict_layers: - self.mtp_layers = torch.nn.ModuleList( - [ - MultiTokenPredication( - config, - self.transformer_layer_spec, - mtp_sepc.submodules, - vocab_size=self.vocab_size, - max_sequence_length=self.max_sequence_length, - layer_number=i, - pre_process=self.pre_process, - post_process=self.post_process, - fp16_lm_cross_entropy=kwargs.get("fp16_lm_cross_entropy", False), - parallel_output=self.parallel_output, - position_embedding_type=self.position_embedding_type, - rotary_percent=kwargs.get("rotary_percent", 1.0), - seq_len_interpolation_factor=kwargs.get("rotary_seq_len_interpolation_factor", None), - share_mtp_embedding_and_output_weight=self.share_mtp_embedding_and_output_weight, - ) - for i in range(self.num_nextn_predict_layers) - ] + + if self.position_embedding_type == 'rope': + self.rotary_pos_emb = RotaryEmbedding( + kv_channels=self.config.kv_channels, + rotary_percent=rotary_percent, + rotary_interleaved=self.config.rotary_interleaved, + seq_len_interpolation_factor=seq_len_interpolation_factor, + rotary_base=rotary_base, + use_cpu_initialization=self.config.use_cpu_initialization, ) - if self.post_process and self.num_nextn_predict_layers: - # move block main model final norms here + # Cache for RoPE tensors which do not change between iterations. + self.rotary_pos_emb_cache = {} + + # Transformer. 
+ self.decoder = TransformerBlock( + config=self.config, + spec=transformer_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + ) + + if self.mtp_process: + self.mtp = MultiTokenPredictionBlock(config=self.config, spec=self.mtp_block_spec) + + if self.mtp_process: + # move block main model final norm here when mtp enable self.final_layernorm = build_module( TENorm, config=self.config, @@ -107,263 +133,218 @@ def gpt_model_init_wrapper(fn): else: self.final_layernorm = None - if not arguments.schedules_method == 'dualpipev' and (self.pre_process or self.post_process): - setup_mtp_embeddings_layer(self) - - return wrapper - - -def shared_embedding_weight(self) -> Tensor: - """Gets the emedding weight when share embedding and mtp embedding weights set to True. + # Output + if self.post_process or self.mtp_process: + + if self.config.defer_embedding_wgrad_compute: + # The embedding activation buffer preserves a reference to the input activations + # of the final embedding projection layer GEMM. It will hold the activations for + # all the micro-batches of a global batch for the last pipeline stage. Once we are + # done with all the back props for all the microbatches for the last pipeline stage, + # it will be in the pipeline flush stage. During this pipeline flush we use the + # input activations stored in embedding activation buffer and gradient outputs + # stored in gradient buffer to calculate the weight gradients for the embedding + # final linear layer. + self.embedding_activation_buffer = [] + self.grad_output_buffer = [] + else: + self.embedding_activation_buffer = None + self.grad_output_buffer = None + if global_args.output_layer_slice_num > 1: + self.output_layer = SegmentedColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=False, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + else: + self.output_layer = tensor_parallel.ColumnParallelLinear( + config.hidden_size, + self.vocab_size, + config=config, + init_method=config.init_method, + bias=global_args.add_output_layer_bias, + skip_bias_add=False, + gather_output=not self.parallel_output, + skip_weight_param_allocation=self.pre_process + and self.share_embeddings_and_output_weights, + embedding_activation_buffer=self.embedding_activation_buffer, + grad_output_buffer=self.grad_output_buffer, + ) + if not post_layer_norm: + self.decoder.post_layer_norm = False - Returns: - Tensor: During pre processing it returns the input embeddings weight while during post processing it returns - mtp embedding layers weight - """ - assert self.num_nextn_predict_layers > 0 - if self.pre_process: - return self.embedding.word_embeddings.weight - elif self.post_process: - if get_args().schedules_method == 'dualpipev': - from mindspeed.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_shared_embedding_from_dual_chunk - return get_shared_embedding_from_dual_chunk() + if self.pre_process or self.post_process: + self.setup_embeddings_and_output_layer() + + def forward(self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + 
loss_mask: Optional[Tensor] = None, + ) -> Tensor: + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. + args = get_args() + + if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): + args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb else: - return self.mtp_layers[0].embedding.word_embeddings.weight - return None - - -def setup_mtp_embeddings_layer(self): - """ - Share embedding layer in mtp layer. - """ - if self.pre_process: - self.embedding.word_embeddings.weight.is_embedding_or_output_parameter = True - # Set `is_embedding_or_output_parameter` attribute. - for i in range(self.num_nextn_predict_layers): - if self.post_process and self.mtp_layers[i].embedding.word_embeddings.weight is not None: - self.mtp_layers[i].embedding.word_embeddings.weight.is_embedding_or_output_parameter = True - - if not self.share_mtp_embedding_and_output_weight: - return - - if self.pre_process and self.post_process: - # Zero out wgrad if sharing embeddings between two layers on same - # pipeline stage to make sure grad accumulation into main_grad is - # correct and does not include garbage values (e.g., from torch.empty). - self.shared_embedding_weight().zero_out_wgrad = True - return - - if self.pre_process and not self.post_process: - assert parallel_state.is_pipeline_first_stage() - self.shared_embedding_weight().shared_embedding = True - - for i in range(self.num_nextn_predict_layers): - if self.post_process and not self.pre_process: - assert not parallel_state.is_pipeline_first_stage() - # set word_embeddings weights to 0 here, then copy first - # stage's weights using all_reduce below. - self.mtp_layers[i].embedding.word_embeddings.weight.data.fill_(0) - self.mtp_layers[i].embedding.word_embeddings.weight.shared = True - self.mtp_layers[i].embedding.word_embeddings.weight.shared_embedding = True - - # Parameters are shared between the word embeddings layers, and the - # heads at the end of the model. In a pipelined setup with more than - # one stage, the initial embedding layer and the head are on different - # workers, so we do the following: - # 1. Create a second copy of word_embeddings on the last stage, with - # initial parameters of 0.0. - # 2. Do an all-reduce between the first and last stage to ensure that - # the two copies of word_embeddings start off with the same - # parameter values. - # 3. In the training loop, before an all-reduce between the grads of - # the two word_embeddings layers to ensure that every applied weight - # update is the same on both stages. - - # Ensure that first and last stages have the same initial parameter - # values. 
- if torch.distributed.is_initialized(): - if parallel_state.is_rank_in_embedding_group(): - weight = self.shared_embedding_weight() - weight.data = weight.data.cuda() - torch.distributed.all_reduce( - weight.data, group=parallel_state.get_embedding_group() + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config ) - - elif not getattr(LanguageModule, "embedding_warning_printed", False): - logging.getLogger(__name__).warning( - "Distributed processes aren't initialized, so the output layer " - "is not initialized with weights from the word embeddings. " - "If you are just manipulating a model this is fine, but " - "this needs to be handled manually. If you are training " - "something is definitely wrong." + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), ) - LanguageModule.embedding_warning_printed = True - -def gpt_model_forward(self, input_ids: Tensor, - position_ids: Tensor, attention_mask: Tensor, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - tokentype_ids=None) -> Tensor: - """ - Forward function of the GPT Model This function passes the input tensors - through the embedding layer, and then the decoeder and finally into the post - processing layer (optional). - - It either returns the Loss values if labels are given or the final hidden units - add output_multiplier_scale to scale logits - """ - # If decoder_input is provided (not None), then input_ids and position_ids are ignored. - # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. - args = get_args() - # generate inputs for main and mtps - input_ids, labels, position_ids, attention_mask = inputs_slice( - args.num_nextn_predict_layers, - input_ids, - labels, - position_ids, - attention_mask) - if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): - args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings - # Decoder embedding. - if decoder_input is not None: - pass - elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0]) - if args.scale_emb is not None: - decoder_input = decoder_input * args.scale_emb - else: - # intermediate stage of pipeline - # decoder will get hidden_states from encoder.input_tensor - decoder_input = None - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( - inference_params, self.decoder, decoder_input, self.config - ) - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - - # Run decoder. 
- hidden_states = self.decoder( - hidden_states=decoder_input, - attention_mask=attention_mask[0], - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - if not self.post_process: - return hidden_states - - # logits and loss - output_weight = None - if self.share_embeddings_and_output_weights: - output_weight = self.shared_embedding_or_output_weight() - - loss = 0 - # Multi token predication module - if args.num_nextn_predict_layers and self.training: - if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight: - output_weight = self.output_layer.weight - output_weight.zero_out_wgrad = True - embedding_weight = self.shared_embedding_weight() if self.share_mtp_embedding_and_output_weight else None - for i in range(args.num_nextn_predict_layers): - if args.reset_position_ids: - set_position_ids(position_ids[i + 1].transpose(0, 1).contiguous()) - actual_seq_len = compute_actual_seq_len(position_ids[i + 1]) - set_actual_seq_len(actual_seq_len) - if i == 0: - mtp_hidden_states = hidden_states - mtp_hidden_states, mtp_loss = self.mtp_layers[i]( - mtp_hidden_states, # [s,b,h] - input_ids[i + 1], - position_ids[i + 1] if position_ids[0] is not None else None, - attention_mask[i + 1] if attention_mask[0] is not None else None, - decoder_input, - labels[i + 1] if labels[0] is not None else None, - inference_params, - packed_seq_params, - extra_block_kwargs, - embeding_weight=embedding_weight, + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + if self.mtp_process: + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + embedding=self.embedding, + output_layer=self.output_layer, output_weight=output_weight, + compute_language_model_loss=self.compute_language_model_loss, + **(extra_block_kwargs or {}), ) - loss += args.mtp_loss_scale / args.num_nextn_predict_layers * mtp_loss - - if args.num_nextn_predict_layers and self.final_layernorm is not None: - # move block main model final norms here - hidden_states = self.final_layernorm(hidden_states) - - if args.dim_model_base is not None: - hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) - if getattr(args, "task", False) and args.task[0] == 'needlebench': - hidden_states = hidden_states[-100:] - logits, _ = self.output_layer(hidden_states, weight=output_weight) - - # new add to scale logits - if args.output_multiplier_scale: - logits = logits * args.output_multiplier_scale - - if args.output_logit_softcapping: - logits = logits / args.output_logit_softcapping - logits = torch.tanh(logits) - logits = logits * args.output_logit_softcapping - - if labels[0] is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - if args.is_instruction_dataset: - label_length = len(labels) - for i in range(label_length): - labels[i] = labels[i][:, 1:].contiguous() - logits = logits[:-1, :, :].contiguous() - loss += self.compute_language_model_loss(labels[0], logits) - return loss - - -def inputs_slice(slice_num, input_ids, labels, position_ids, attention_mask): - if slice_num == 0: - return ( - [input_ids], - [labels], - [position_ids], - [attention_mask], - ) - - 
return ( - tensor_slide(input_ids, slice_num), - tensor_slide(labels, slice_num), - generate_nextn_position_ids(position_ids, slice_num), - # not compatible with ppo attn_mask - tensor_slide(attention_mask, slice_num, dims=[-2, -1]), - ) - -def generate_nextn_position_ids(tensor, slice_num): - slides = tensor_slide(tensor, slice_num) - if slides[0] is None: - return slides - - for idx in range(1, len(slides)): - slides[idx] = regenerate_position_ids(slides[idx], idx) - return slides + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) + + if not self.post_process: + return hidden_states + + if args.dim_model_base is not None: + hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + if getattr(args, "task", False) and args.task[0] == 'needlebench': + hidden_states = hidden_states[-100:] + logits, _ = self.output_layer(hidden_states, weight=output_weight) + + # new add to scale logits + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + loss = self.compute_language_model_loss(labels, logits) + return loss + + def shared_embedding_or_output_weight(self) -> Tensor: + """Gets the embedding weight or output logit weights when share input embedding and + output weights set to True or when use Multi-Token Prediction (MTP) feature. + + Returns: + Tensor: During pre processing or MTP process it returns the input embeddings weight. + Otherwise, during post processing it returns the final output layers weight. + """ + if not self.pre_process and self.post_process and get_args().schedules_method == 'dualpipev': + from mindspeed.core.pipeline_parallel.dualpipev.dualpipev_schedules import \ + get_shared_embedding_from_dual_chunk + return get_shared_embedding_from_dual_chunk() + if self.pre_process or self.mtp_process: + # Multi-Token Prediction (MTP) need both embedding layer and output layer. + # So there will be both embedding layer and output layer in the mtp process stage. + # In this case, if share_embeddings_and_output_weights is True, the shared weights + # will be stored in embedding layer, and output layer will not have any weight. + assert hasattr( + self, 'embedding' + ), f"embedding is needed in this pipeline stage, but it is not initialized." + return self.embedding.word_embeddings.weight + elif self.post_process: + return self.output_layer.weight + return None + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility. + Removing extra state. + Tie word embeddings and output layer in mtp process stage. + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + # Multi-Token Prediction (MTP) need both embedding layer and output layer in + # mtp process stage. 
+ # If MTP is not placed in the pre processing stage, we need to maintain a copy of + # embedding layer in the mtp process stage and tie it to the embedding in the pre + # processing stage. + # Also, if MTP is not placed in the post processing stage, we need to maintain a copy + # of output layer in the mtp process stage and tie it to the output layer in the post + # processing stage. + if self.mtp_process and not self.pre_process: + emb_weight_key = f'{prefix}embedding.word_embeddings.weight' + emb_weight = self.embedding.word_embeddings.weight + tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key) + if self.mtp_process and not self.post_process: + # We only need to tie the output layer weight if share_embeddings_and_output_weights + # is False. Because if share_embeddings_and_output_weights is True, the shared weight + # will be stored in embedding layer, and output layer will not have any weight. + if not self.share_embeddings_and_output_weights: + output_layer_weight_key = f'{prefix}output_layer.weight' + output_layer_weight = self.output_layer.weight + tie_output_layer_state_dict( + sharded_state_dict, output_layer_weight, output_layer_weight_key + ) -def regenerate_position_ids(tensor, offset): - if tensor is None: - return None - tensor = tensor.clone() - for i in range(tensor.size(0)): - row = tensor[i] - zero_mask = (row == 0) - if zero_mask.any(): - first_zero_idx = torch.argmax(zero_mask.int()).item() - tensor[i, :first_zero_idx] = torch.arange(first_zero_idx) - else: - tensor = tensor - offset - return tensor \ No newline at end of file + return sharded_state_dict diff --git a/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py b/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py index fdb7c5a786cb8972ec7459540befadacf8e17858..2b6e21fb01694b6f803ab5238e1fa0c2ba66f141 100644 --- a/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py +++ b/mindspeed_llm/core/pipeline_parallel/dualpipe/MTP_overlap.py @@ -1,7 +1,18 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
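# Quick illustration (assumption, not part of the patch) of where shared_embedding_or_output_weight()
# resolves for different pipeline roles under the branches implemented above, ignoring the dualpipev
# special case: MTP stages keep an embedding copy, so the shared weight lives on the embedding
# whenever pre_process or mtp_process is set.
def locate_shared_weight(pre_process, post_process, mtp_process):
    if pre_process or mtp_process:
        return "embedding.word_embeddings.weight"
    if post_process:
        return "output_layer.weight"
    return None  # intermediate pipeline stage without an embedding or output layer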
+from contextlib import nullcontext + import torch +from torch import Tensor from megatron.training import get_args +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor +from megatron.core.tensor_parallel import ( + all_gather_last_dim_from_tensor_parallel_region, + scatter_to_sequence_parallel_region, +) + from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import ( transformer_layer_forward, transformer_layer_forward_moe, @@ -9,7 +20,17 @@ from mindspeed.core.pipeline_parallel.fb_overlap.transformer_layer import ( ) from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput -from megatron.core import tensor_parallel + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDelayedScaling, + TENorm, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False class TransformerMTPoverlap(torch.autograd.Function): @@ -18,15 +39,17 @@ class TransformerMTPoverlap(torch.autograd.Function): layer, hidden_states, attention_mask, + context=None, + context_mask=None, rotary_pos_emb=None, inference_params=None, packed_seq_params=None, ): with torch.enable_grad(): - output, context, graph = transformer_layer_forward_moe(layer, + output, context_out, graph = transformer_layer_forward_moe(layer, hidden_states, attention_mask, - None, - None, + context, + context_mask, rotary_pos_emb, inference_params, packed_seq_params) @@ -34,12 +57,12 @@ class TransformerMTPoverlap(torch.autograd.Function): if args.recompute_mtp_layer: graph.deallocate_graph() graph.record_layer_inputs( - attention_mask, None, None, rotary_pos_emb, + attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params ) ctx.graph = graph - return output.detach(), context + return output.detach(), context_out @staticmethod def backward(ctx, *args): @@ -59,113 +82,113 @@ class TransformerMTPoverlap(torch.autograd.Function): def forward_overlap(self, - hidden_input_ids, - embed_input_ids, - position_ids, - attention_mask, - decoder_input=None, - labels=None, - inference_params=None, - packed_seq_params=None, - extra_block_kwargs: dict = None, - embeding_weight=None, - output_weight=None, ): - args = get_args() - if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): - args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings - # Decoder embedding. 
- decoder_input = self.embedding( - input_ids=embed_input_ids, - position_ids=position_ids, - weight=embeding_weight, - ) - if args.scale_emb is not None: - decoder_input = decoder_input * args.scale_emb - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - if inference_params is not None: - rotary_seq_len = inference_params.max_sequence_length - else: - rotary_seq_len = decoder_input.size(0) - - if self.config.sequence_parallel: - rotary_seq_len *= self.config.tensor_model_parallel_size - - rotary_seq_len *= self.config.context_parallel_size - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - if self.recompute_layer_norm: - self.enorm_ckpt = CheckpointWithoutOutput() - enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input) - self.hnorm_ckpt = CheckpointWithoutOutput() - hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids) + decoder_input: Tensor, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + attention_bias: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None,): + if context is not None: + raise NotImplementedError(f"multi token prediction + cross attention is not yet supported.") + + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() else: - enorm_output = self.enorm(decoder_input) - hnorm_output = self.hnorm(hidden_input_ids) - - # [s, b, h] -> [s, b, 2h] - hidden_states = torch.concat( - [hnorm_output, - enorm_output], - dim=-1 - ) + rng_context = nullcontext() - if self.recompute_layer_norm: - self.enorm_ckpt.discard_output() - self.hnorm_ckpt.discard_output() - hidden_states.register_hook(self.enorm_ckpt.recompute) - hidden_states.register_hook(self.hnorm_ckpt.recompute) - # hidden_states -> [s, b, h] - - hidden_states, _ = self.eh_proj(hidden_states) - if self.config.tensor_model_parallel_size > 1: - hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states) - if self.config.sequence_parallel: - hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states) - - trans = TransformerMTPoverlap.apply - hidden_states, _ = trans( - self.transformer_layer, - hidden_states, - attention_mask, - rotary_pos_emb, - inference_params, - packed_seq_params, - ) + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 - # Final layer norm. 
- if self.final_layernorm is not None: - if self.recompute_layer_norm: - self.finalnorm_ckpt = CheckpointWithoutOutput() - finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states) + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID else: - finalnorm_output = self.final_layernorm(hidden_states) + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = TEDelayedScaling( + config=self.config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group( + with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red + ) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) else: - finalnorm_output = hidden_states + fp8_context = nullcontext() - if args.dim_model_base is not None: - finalnorm_output = finalnorm_output / (args.hidden_size / args.dim_model_base) - logits, _ = self.output_layer(finalnorm_output, weight=output_weight) + with rng_context, fp8_context: - if self.recompute_layer_norm: - self.finalnorm_ckpt.discard_output() - logits.register_hook(self.finalnorm_ckpt.recompute) - if args.output_multiplier_scale: - logits = logits * args.output_multiplier_scale - - if args.output_logit_softcapping: - logits = logits / args.output_logit_softcapping - logits = torch.tanh(logits) - logits = logits * args.output_logit_softcapping + def enorm(tensor): + tensor = self.enorm(tensor) + tensor = make_viewless_tensor( + inp=tensor, requires_grad=True, keep_graph=True + ) + return tensor - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() + def hnorm(tensor): + tensor = self.hnorm(tensor) + tensor = make_viewless_tensor( + inp=tensor, requires_grad=True, keep_graph=True + ) + return tensor - if args.is_instruction_dataset: - labels = labels[:, 1:].contiguous() - logits = logits[:-1, :, :].contiguous() + if self.recompute_mtp_norm: + self.enorm_ckpt = CheckpointWithoutOutput() + enorm_output = self.enorm_ckpt.checkpoint(enorm, False, decoder_input) + self.hnorm_ckpt = CheckpointWithoutOutput() + hnorm_output = self.hnorm_ckpt.checkpoint(hnorm, False, hidden_states) + else: + enorm_output = enorm(decoder_input) + hnorm_output = hnorm(hidden_states) + # At the (k - 1)-th MTP module, concatenates the i-th tocken's hidden_states + # and the (i + K)-th tocken's embedding, and combine them with linear projection. + hidden_states = torch.cat((enorm_output, hnorm_output), -1) + if self.recompute_mtp_norm: + self.enorm_ckpt.discard_output() + self.hnorm_ckpt.discard_output() + hidden_states.register_hook(self.enorm_ckpt.recompute) + hidden_states.register_hook(self.hnorm_ckpt.recompute) + hidden_states, _ = self.eh_proj(hidden_states) + # For tensor parallel, all gather after linear_fc. + hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states) + # For sequence parallel, scatter after linear_fc and before transformer layer. 
+ if self.sequence_parallel: + hidden_states = scatter_to_sequence_parallel_region(hidden_states) + trans = TransformerMTPoverlap.apply + if self.recompute_mtp_layer: + hidden_states, _ = tensor_parallel.checkpoint( + trans, + self.config.distribute_saved_activations, + self.transformer_layer, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + inference_params, + packed_seq_params, + ) + else: + hidden_states, _ = trans( + self.transformer_layer, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + inference_params, + packed_seq_params, + ) - loss = self.compute_language_model_loss(labels, logits) - return hidden_states, loss \ No newline at end of file + return hidden_states \ No newline at end of file diff --git a/mindspeed_llm/core/pipeline_parallel/dualpipe/adaptor.py b/mindspeed_llm/core/pipeline_parallel/dualpipe/adaptor.py index 6c69ea3d3603ccad02d8573d56039e9a24c0f0f3..912632e624fbaff632edbaa3390c61510c9f51dc 100644 --- a/mindspeed_llm/core/pipeline_parallel/dualpipe/adaptor.py +++ b/mindspeed_llm/core/pipeline_parallel/dualpipe/adaptor.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. from megatron.training.utils import print_rank_0 + try: from mindspeed.core.pipeline_parallel.fb_overlap import ( linear_backward_wgrad_detach, @@ -14,7 +15,7 @@ try: except ImportError: pass -from mindspeed_llm.tasks.models.transformer.multi_token_predication import MultiTokenPredication +from mindspeed_llm.core.transformer.multi_token_prediction import MultiTokenPredictionLayer def dualpipe_register_patches(MegatronAdaptation): @@ -22,7 +23,7 @@ def dualpipe_register_patches(MegatronAdaptation): MegatronAdaptation.register('megatron.core.distributed.distributed_data_parallel.DistributedDataParallel._make_param_hook', _make_param_hook) - MultiTokenPredication.forward = forward_overlap + MultiTokenPredictionLayer.forward = forward_overlap MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward_backward_overlaping) MegatronAdaptation.register('megatron.core.transformer.transformer_layer.TransformerLayer.forward', diff --git a/mindspeed_llm/core/pipeline_parallel/dualpipe/gpt_model.py b/mindspeed_llm/core/pipeline_parallel/dualpipe/gpt_model.py index 47f22610612df50650157682f2dc1e0060cf72d7..e9fca41b7b718f73cd80938cdeb507ed36f47400 100644 --- a/mindspeed_llm/core/pipeline_parallel/dualpipe/gpt_model.py +++ b/mindspeed_llm/core/pipeline_parallel/dualpipe/gpt_model.py @@ -8,7 +8,6 @@ from torch import Tensor from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.packed_seq_params import PackedSeqParams from megatron.training import get_args -from mindspeed.utils import set_actual_seq_len, set_position_ids from mindspeed.core.pipeline_parallel.fb_overlap.transformer_block import ( transformer_block_forward, transformer_block_backward, transformer_block_forward_backward_overlaping, @@ -16,8 +15,7 @@ from mindspeed.core.pipeline_parallel.fb_overlap.transformer_block import ( from mindspeed.core.pipeline_parallel.fb_overlap.modules.utils import ( LayerGraph, detach_tensor, run_graph_backward ) -from mindspeed_llm.core.models.gpt.gpt_model import inputs_slice -from mindspeed_llm.training.utils import tensor_slide, compute_actual_seq_len +from mindspeed.core.pipeline_parallel.dualpipev.dualpipev_schedules import get_shared_embedding_from_dual_chunk class ModelGraph: @@ -43,6 +41,7 @@ def gpt_model_forward( 
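# Shape-level restatement (assumption, not part of the patch) of the MTP combine step in
# forward_overlap above: the normalized embedding of the shifted tokens (enorm) and the normalized
# main-model hidden states (hnorm) are concatenated along the hidden dimension and projected back
# to hidden_size, which is what eh_proj does with its 2h -> h weight.
import torch

def mtp_combine(enorm_output, hnorm_output, eh_proj_weight):
    # enorm_output, hnorm_output: [s, b, h]; eh_proj_weight: [h, 2h]
    combined = torch.cat((enorm_output, hnorm_output), dim=-1)  # [s, b, 2h]
    return combined @ eh_proj_weight.t()  # [s, b, h]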
inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, + loss_mask: Optional[Tensor] = None, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoeder and finally into the post @@ -53,19 +52,14 @@ def gpt_model_forward( # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. args = get_args() - input_ids, labels, position_ids, attention_mask = inputs_slice( - args.num_nextn_predict_layers, - input_ids, - labels, - position_ids, - attention_mask) + if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings # Decoder embedding. if decoder_input is not None: preprocess_graph = None elif self.pre_process: - decoder_input = self.embedding(input_ids=input_ids[0], position_ids=position_ids[0]) + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) if args.scale_emb is not None: decoder_input = decoder_input * args.scale_emb preprocess_graph = decoder_input @@ -86,11 +80,10 @@ def gpt_model_forward( detached_block_input = detach_tensor(decoder_input) # Run decoder. - hidden_states, layer_graphs = transformer_block_forward( self.decoder, hidden_states=detached_block_input, - attention_mask=attention_mask[0], + attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, packed_seq_params=packed_seq_params, @@ -105,48 +98,37 @@ def gpt_model_forward( if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - loss = 0 graph = ModelGraph( layer_graphs, hidden_states, preprocess_graph, detached_block_input ) - # Multi token predication module - if args.num_nextn_predict_layers and self.training: + if self.mtp_process: + self.embedding.word_embeddings.weight = get_shared_embedding_from_dual_chunk() detached_hidden_states = detach_tensor(hidden_states) graph.layer_graphs[-1].unperm2_graph = (graph.layer_graphs[-1].unperm2_graph[0], detached_hidden_states) - if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight: - output_weight = self.output_layer.weight - output_weight.zero_out_wgrad = True - embedding_weight = self.shared_embedding_weight() if self.share_mtp_embedding_and_output_weight else None - extra_block_kwargs = {'use_orig_layer_forward': True} - for i in range(args.num_nextn_predict_layers): - if args.reset_position_ids: - set_position_ids(position_ids[i + 1].transpose(0, 1).contiguous()) - actual_seq_len = compute_actual_seq_len(position_ids[i + 1]) - set_actual_seq_len(actual_seq_len) - if i == 0: - mtp_hidden_states = detached_hidden_states - mtp_hidden_states, mtp_loss = self.mtp_layers[i]( - mtp_hidden_states, # [s,b,h] - input_ids[i + 1], - position_ids[i + 1] if position_ids[0] is not None else None, - attention_mask[i + 1] if attention_mask[0] is not None else None, - decoder_input, - labels[i + 1] if labels[0] is not None else None, - inference_params, - packed_seq_params, - extra_block_kwargs, - embeding_weight=embedding_weight, - output_weight=output_weight, - ) - - loss += args.mtp_loss_scale / args.num_nextn_predict_layers * mtp_loss - - if args.num_nextn_predict_layers and self.final_layernorm is not None: - hidden_states = 
self.final_layernorm(detached_hidden_states) + hidden_states = self.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=detached_hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + embedding=self.embedding, + output_layer=self.output_layer, + output_weight=output_weight, + compute_language_model_loss=self.compute_language_model_loss, + ) + + if self.final_layernorm is not None: + hidden_states = self.final_layernorm(hidden_states) if args.dim_model_base is not None: hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + if getattr(args, "task", False) and args.task[0] == 'needlebench': + hidden_states = hidden_states[-100:] logits, _ = self.output_layer(hidden_states, weight=output_weight) # new add to scale logits if args.output_multiplier_scale: @@ -157,17 +139,15 @@ def gpt_model_forward( logits = torch.tanh(logits) logits = logits * args.output_logit_softcapping - if labels[0] is None: + if labels is None: # [s b h] => [b s h] logits = logits.transpose(0, 1).contiguous() return logits, graph - if isinstance(labels, List): - labels = labels[0] if args.is_instruction_dataset: labels = labels[:, 1:].contiguous() logits = logits[:-1, :, :].contiguous() - loss += self.compute_language_model_loss(labels, logits) + loss = self.compute_language_model_loss(labels, logits) return loss, graph @@ -194,28 +174,24 @@ def gpt_model_forward_backward_overlaping( inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, + loss_mask: Optional[Tensor] = None, ): if extra_block_kwargs is None or extra_block_kwargs['bwd_model_graph'] is None: return gpt_model_forward( fwd_model, input_ids, position_ids, attention_mask, decoder_input, labels, inference_params, - packed_seq_params, extra_block_kwargs + packed_seq_params, extra_block_kwargs, loss_mask ) bwd_model_grad, bwd_model_graph = extra_block_kwargs['bwd_model_grad'], extra_block_kwargs[ 'bwd_model_graph'] # Fwd Model Decoder embedding. 
args = get_args() - input_ids, labels, position_ids, attention_mask = inputs_slice( - args.num_nextn_predict_layers, - input_ids, - labels, - position_ids, - attention_mask) + if not fwd_model.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings if decoder_input is not None: preprocess_graph = None elif fwd_model.pre_process: - decoder_input = fwd_model.embedding(input_ids=input_ids[0], position_ids=position_ids[0]) + decoder_input = fwd_model.embedding(input_ids=input_ids, position_ids=position_ids) if args.scale_emb is not None: decoder_input = decoder_input * args.scale_emb preprocess_graph = decoder_input @@ -240,7 +216,7 @@ def gpt_model_forward_backward_overlaping( = transformer_block_forward_backward_overlaping( fwd_model.decoder, detached_block_input, - attention_mask[0], + attention_mask, bwd_model_grad, bwd_model_graph.layer_graphs, rotary_pos_emb=rotary_pos_emb, @@ -265,45 +241,33 @@ def gpt_model_forward_backward_overlaping( graph = ModelGraph( layer_graphs, hidden_states, preprocess_graph, detached_block_input ) - # Multi token predication module - loss = 0 - if args.num_nextn_predict_layers and fwd_model.training: + + if fwd_model.mtp_process: detached_hidden_states = detach_tensor(hidden_states) graph.layer_graphs[-1].unperm2_graph = (graph.layer_graphs[-1].unperm2_graph[0], detached_hidden_states) - if not fwd_model.share_embeddings_and_output_weights and fwd_model.share_mtp_embedding_and_output_weight: - output_weight = fwd_model.output_layer.weight - output_weight.zero_out_wgrad = True - embedding_weight = fwd_model.shared_embedding_weight() if fwd_model.share_mtp_embedding_and_output_weight else None - extra_block_kwargs = {'use_orig_layer_forward': True} - for i in range(args.num_nextn_predict_layers): - if args.reset_position_ids: - set_position_ids(position_ids[i + 1].transpose(0, 1).contiguous()) - actual_seq_len = compute_actual_seq_len(position_ids[i + 1]) - set_actual_seq_len(actual_seq_len) - if i == 0: - mtp_hidden_states = detached_hidden_states - mtp_hidden_states, mtp_loss = fwd_model.mtp_layers[i]( - mtp_hidden_states, # [s,b,h] - input_ids[i + 1], - position_ids[i + 1] if position_ids[0] is not None else None, - attention_mask[i + 1] if attention_mask[0] is not None else None, - decoder_input, - labels[i + 1] if labels[0] is not None else None, - inference_params, - packed_seq_params, - extra_block_kwargs, - embeding_weight=embedding_weight, - output_weight=output_weight, - ) - - loss += args.mtp_loss_scale / args.num_nextn_predict_layers * mtp_loss - - if args.num_nextn_predict_layers and fwd_model.final_layernorm is not None: - # move block main model final norms here - hidden_states = fwd_model.final_layernorm(detached_hidden_states) + hidden_states = fwd_model.mtp( + input_ids=input_ids, + position_ids=position_ids, + labels=labels, + loss_mask=loss_mask, + hidden_states=detached_hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + embedding=fwd_model.embedding, + output_layer=fwd_model.output_layer, + output_weight=output_weight, + compute_language_model_loss=fwd_model.compute_language_model_loss, + ) + + if fwd_model.final_layernorm is not None: + hidden_states = fwd_model.final_layernorm(hidden_states) if args.dim_model_base is not None: hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + if getattr(args, 
"task", False) and args.task[0] == 'needlebench': + hidden_states = hidden_states[-100:] logits, _ = fwd_model.output_layer(hidden_states, weight=output_weight) # new add to scale logits if args.output_multiplier_scale: @@ -314,18 +278,16 @@ def gpt_model_forward_backward_overlaping( logits = torch.tanh(logits) logits = logits * args.output_logit_softcapping - if labels[0] is None: + if labels is None: # [s b h] => [b s h] logits = logits.transpose(0, 1).contiguous() return logits, graph, pp_comm_output - if isinstance(labels, List): - labels = labels[0] if args.is_instruction_dataset: labels = labels[:, 1:].contiguous() logits = logits[:-1, :, :].contiguous() - loss += fwd_model.compute_language_model_loss(labels, logits) + loss = fwd_model.compute_language_model_loss(labels, logits) return loss, graph, pp_comm_output diff --git a/mindspeed_llm/core/pipeline_parallel/schedules.py b/mindspeed_llm/core/pipeline_parallel/schedules.py index 1a54655b655998cdbc6b0cc1451d6118e8a942b3..42db59a185f3ce0fc077c75efc653629820e8953 100644 --- a/mindspeed_llm/core/pipeline_parallel/schedules.py +++ b/mindspeed_llm/core/pipeline_parallel/schedules.py @@ -20,6 +20,7 @@ from functools import wraps import torch from megatron.training import get_args from mindspeed.core.pipeline_parallel.ripipe_schedules import forward_backward_ripipe_pipelining +from mindspeed_llm.core.transformer.multi_token_prediction import MTPLossAutoScaler def get_forward_backward_func_wrapper(get_forward_backward_func): @@ -60,3 +61,50 @@ def forward_backward_pipelining_with_interleaving_wrapper(fn): kwargs['micro_batch_size'] = args_.micro_batch_size * 2 return fn(*args, **kwargs) return wrapper + + +def forward_step_wrapper(fn): + @wraps(fn) + def wrapper( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + *args, + **kwargs): + output, num_tokens = fn( + forward_step_func, + data_iterator, + model, + num_microbatches, + input_tensor, + forward_data_store, + config, + *args, + **kwargs) + + if not isinstance(input_tensor, list): + # unwrap_output_tensor True + output_tensor = output + else: + output_tensor = output[0] + + # Set the loss scale for Multi-Token Prediction (MTP) loss. + if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None: + # Calculate the loss scale based on the grad_scale_func if available, else default to 1. 
+            loss_scale = (
+                config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+                if config.grad_scale_func is not None
+                else torch.ones(1, device=output_tensor.device)
+            )
+            # Set the loss scale
+            if config.calculate_per_token_loss:
+                MTPLossAutoScaler.set_loss_scale(loss_scale)
+            else:
+                MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+        return output, num_tokens
+
+    return wrapper
diff --git a/mindspeed_llm/core/tensor_parallel/cross_entropy.py b/mindspeed_llm/core/tensor_parallel/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca35f62f6b0e9dd93e6922015578b6d848c894f3
--- /dev/null
+++ b/mindspeed_llm/core/tensor_parallel/cross_entropy.py
@@ -0,0 +1,51 @@
+from typing import Tuple
+
+import torch
+
+
+def calculate_logits_max(
+    vocab_parallel_logits: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if vocab_parallel_logits.dtype != torch.float32:
+        vocab_parallel_logits_fp32 = vocab_parallel_logits.float()
+        vocab_parallel_logits.untyped_storage().resize_(0)
+    else:
+        vocab_parallel_logits_fp32 = vocab_parallel_logits.float()
+    # Maximum value along vocab dimension across all GPUs.
+    logits_max = torch.max(vocab_parallel_logits_fp32, dim=-1)[0]
+
+    return vocab_parallel_logits_fp32, logits_max
+
+
+def calculate_predicted_logits(
+    vocab_parallel_logits: torch.Tensor,
+    target: torch.Tensor,
+    logits_max: torch.Tensor,
+    vocab_start_index: int,
+    vocab_end_index: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    # Subtract the maximum value.
+    # Use in-place to reduce memory pressure.
+    vocab_parallel_logits -= logits_max.unsqueeze(dim=-1)
+
+    # Create a mask of valid vocab ids (1 means it needs to be masked).
+    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+    masked_target = target.clone() - vocab_start_index
+    masked_target *= ~target_mask
+
+    # Get predicted-logits = logits[target].
+    # For simplicity, we convert logits to a 2-D tensor with size
+    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
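
# A standalone illustration of the gather performed below (illustration only; toy shapes
# and vocab range): each token picks the logit at its partition-local target index, and
# positions whose target id lives in another rank's vocab shard are zeroed out.
import torch

logits = torch.randn(4, 8)                  # [num_tokens, partition_vocab_size]
target = torch.tensor([3, 9, 0, 7])         # global vocab ids
vocab_start_index, vocab_end_index = 0, 8   # this rank owns ids [0, 8)

target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target_1d = (target.clone() - vocab_start_index) * ~target_mask
arange_1d = torch.arange(logits.size(0))
predicted_logits = logits[arange_1d, masked_target_1d] * ~target_mask
print(predicted_logits)                     # entry 1 is zero: id 9 belongs to another rank
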
+ partition_vocab_size = vocab_parallel_logits.size()[-1] + logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] + predicted_logits_1d = predicted_logits_1d.clone().contiguous() + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits *= ~target_mask + + exp_logits = vocab_parallel_logits + torch.exp(vocab_parallel_logits, out=exp_logits) + sum_exp_logits = exp_logits.sum(dim=-1) + return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits diff --git a/mindspeed_llm/core/tensor_parallel/layers.py b/mindspeed_llm/core/tensor_parallel/layers.py index c03d2169b8d424a5dc0b4b54028695680d76c60e..ede4925e475bcd75a9664cdcd129f0a03e896809 100644 --- a/mindspeed_llm/core/tensor_parallel/layers.py +++ b/mindspeed_llm/core/tensor_parallel/layers.py @@ -18,18 +18,13 @@ from typing import Optional, Callable import torch from torch.nn.parameter import Parameter -import torch.nn.functional as F -from megatron.core.tensor_parallel.mappings import ( - reduce_scatter_to_sequence_parallel_region, - reduce_from_tensor_model_parallel_region, -) -from megatron.core.tensor_parallel.utils import VocabUtility -from megatron.training import get_args +from megatron.core import parallel_state, ModelParallelConfig from megatron.core.tensor_parallel import ( copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region ) +from megatron.core.tensor_parallel.utils import VocabUtility from megatron.core.tensor_parallel.layers import ( linear_with_frozen_weight, linear_with_grad_accumulation_and_async_allreduce, @@ -39,9 +34,13 @@ from megatron.core.tensor_parallel.layers import ( VocabParallelEmbedding, ) from megatron.legacy.model.fused_layer_norm import MixedFusedLayerNorm -from megatron.core import parallel_state, ModelParallelConfig +from megatron.training import get_args -from mindspeed.utils import get_actual_seq_len, set_actual_seq_len +import torch.nn.functional as F +from megatron.core.tensor_parallel.mappings import ( + reduce_scatter_to_sequence_parallel_region, + reduce_from_tensor_model_parallel_region, +) def vocab_embedding_init_func( @@ -238,19 +237,3 @@ class SegmentedColumnParallelLinear(ColumnParallelLinear): output = output_parallel output_bias = self.bias if self.skip_bias_add else None return output, output_bias - - -def checkpoint_forward_wrapper(fn): - def wrapper(ctx, run_function, distribute_saved_activations, *args): - ctx.actual_seq_len = get_actual_seq_len() - return fn(ctx, run_function, distribute_saved_activations, *args) - - return wrapper - - -def checkpoint_backward_wrapper(fn): - def wrapper(ctx, *args): - set_actual_seq_len(ctx.actual_seq_len) - return fn(ctx, *args) - - return wrapper diff --git a/mindspeed_llm/core/transformer/dot_product_attention.py b/mindspeed_llm/core/transformer/dot_product_attention.py index 9f814251063138ea14432f8593f5e3992f6fb992..e9e56329ac92e79ad401566db635938cfa5a8aa8 100644 --- a/mindspeed_llm/core/transformer/dot_product_attention.py +++ b/mindspeed_llm/core/transformer/dot_product_attention.py @@ -408,6 +408,8 @@ def flash_attention_forward( scale = 1.0 / math.sqrt(self.hidden_size_per_attention_head) \ if self.scale_mask_softmax.scale is None else self.softmax_scale actual_seq_len = get_actual_seq_len() + if actual_seq_len is not None and args.mtp_num_layers: + actual_seq_len = 
actual_seq_len[self.mtp_idx] if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo', 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']: diff --git a/mindspeed_llm/core/transformer/moe/router.py b/mindspeed_llm/core/transformer/moe/router.py index cb0be41703ee36416222876ab37ae0ab6eb7fca8..be87413165871594a49688daa8abe2f596e01e94 100644 --- a/mindspeed_llm/core/transformer/moe/router.py +++ b/mindspeed_llm/core/transformer/moe/router.py @@ -555,6 +555,7 @@ def topk_router_forward(self, input: torch.Tensor): """ args = get_args() self.hidden = input.shape[-1] + _maintain_float32_expert_bias(self) # add input_jitter to distinguish whether to use if args.input_jitter: @@ -565,3 +566,15 @@ def topk_router_forward(self, input: torch.Tensor): scores, indices = self.routing(logits) return scores, indices + + +def _maintain_float32_expert_bias(self): + """ + Maintain the expert bias in float32. + + When using bf16/fp16, the expert bias gets converted to lower precision in Float16Module. + We keep it in float32 to avoid routing errors when updating the expert_bias. + """ + if hasattr(self, 'expert_bias') and self.expert_bias is not None: + if self.expert_bias.dtype != torch.float32: + self.expert_bias.data = self.expert_bias.data.to(torch.float32) diff --git a/mindspeed_llm/core/transformer/multi_token_prediction.py b/mindspeed_llm/core/transformer/multi_token_prediction.py new file mode 100644 index 0000000000000000000000000000000000000000..a8c2abb9b3a86f187ad070d304058af7960a2399 --- /dev/null +++ b/mindspeed_llm/core/transformer/multi_token_prediction.py @@ -0,0 +1,778 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from contextlib import nullcontext +from dataclasses import dataclass +from typing import List, Optional, Union + +import torch +from torch import Tensor + +from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding +from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.tensor_parallel import ( + all_gather_last_dim_from_tensor_parallel_region, + scatter_to_sequence_parallel_region, +) +from megatron.core.tensor_parallel.layers import ColumnParallelLinear +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor +from megatron.training import get_args +from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput +from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm +from mindspeed_llm.training.utils import regenerate_position_ids + +SUPPORTED_ATTN_MASK = [ + AttnMaskType.padding, + AttnMaskType.causal, + AttnMaskType.no_mask, + AttnMaskType.padding_causal, +] + + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDelayedScaling, + TENorm, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + + +try: + import apex # pylint: disable=unused-import + + from 
megatron.core.fusions.fused_layer_norm import FusedLayerNorm
+
+    HAVE_APEX = True
+    LNImpl = FusedLayerNorm
+except ImportError:
+    import warnings
+
+    warnings.warn('Apex is not installed. Falling back to Torch Norm')
+    LNImpl = PTNorm
+
+
+def tie_word_embeddings_state_dict(
+    sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str
+) -> None:
+    """Tie the embedding of the mtp processing stage in a given sharded state dict.
+
+    Args:
+        sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
+        word_emb_weight (Tensor): weight of the word embedding.
+        word_emb_weight_key (str): key of the word embedding in the sharded state dict.
+
+    Returns: None, acts in-place
+    """
+    mtp_word_emb_replica_id = (
+        1,  # copy of embedding in pre processing stage
+        0,
+        parallel_state.get_data_parallel_rank(with_context_parallel=True),
+    )
+    if word_emb_weight_key not in sharded_state_dict:
+        raise AssertionError("Word embedding weight not found in sharded state dict.")
+    del sharded_state_dict[word_emb_weight_key]
+    sharded_state_dict[word_emb_weight_key] = make_tp_sharded_tensor_for_checkpoint(
+        tensor=word_emb_weight,
+        key=word_emb_weight_key,
+        replica_id=mtp_word_emb_replica_id,
+        allow_shape_mismatch=True,
+    )
+
+
+def tie_output_layer_state_dict(
+    sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str
+) -> None:
+    """Tie the output layer of the mtp processing stage in a given sharded state dict.
+
+    Args:
+        sharded_state_dict (ShardedStateDict): state dict with the weight to tie.
+        output_layer_weight (Tensor): weight of the output layer.
+        output_layer_weight_key (str): key of the output layer in the sharded state dict.
+
+    Returns: None, acts in-place
+    """
+    mtp_output_layer_replica_id = (
+        1,  # copy of output layer in post processing stage
+        0,
+        parallel_state.get_data_parallel_rank(with_context_parallel=True),
+    )
+    if output_layer_weight_key not in sharded_state_dict:
+        raise AssertionError("Output layer weight not found in sharded state dict.")
+    del sharded_state_dict[output_layer_weight_key]
+    sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint(
+        tensor=output_layer_weight,
+        key=output_layer_weight_key,
+        replica_id=mtp_output_layer_replica_id,
+        allow_shape_mismatch=True,
+    )
+
+
+def roll_tensor(tensor, shifts=-1, dims=-1):
+    """Roll the tensor input along the given dimension(s).
+    Inserted elements are set to be 0.0.
+    """
+    rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
+    rolled_tensor.select(dims, shifts).fill_(0)
+    return rolled_tensor, rolled_tensor.sum()
+
+
+class MTPLossLoggingHelper:
+    """Helper class for logging MTP losses."""
+
+    tracker = {}
+
+    @staticmethod
+    def save_loss_to_tracker(
+        loss: torch.Tensor,
+        layer_number: int,
+        num_layers: int,
+        reduce_group: torch.distributed.ProcessGroup = None,
+        avg_group: torch.distributed.ProcessGroup = None,
+    ):
+        """Save the mtp loss for logging.
+        Args:
+            loss (torch.Tensor): The loss tensor.
+            layer_number (int): Layer index of the loss.
+            num_layers (int): The number of total layers.
+            reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss.
+            avg_group (torch.distributed.ProcessGroup): The group for averaging the loss.
+        """
+        # Skip mtp loss logging if layer_number is None.
+ if layer_number is None: + return + + tracker = MTPLossLoggingHelper.tracker + if "values" not in tracker: + tracker["values"] = torch.zeros(num_layers, device=loss.device) + tracker["values"][layer_number] += loss.detach() + tracker["reduce_group"] = reduce_group + tracker["avg_group"] = avg_group + + @staticmethod + def clean_loss_in_tracker(): + """Clear the mtp losses.""" + tracker = MTPLossLoggingHelper.tracker + tracker["values"].zero_() + tracker["reduce_group"] = None + tracker["avg_group"] = None + + @staticmethod + def reduce_loss_in_tracker(): + """Collect and reduce the mtp losses across ranks.""" + tracker = MTPLossLoggingHelper.tracker + if "values" not in tracker: + return + values = tracker["values"] + # Reduce mtp losses across ranks. + if tracker.get('reduce_group') is not None: + torch.distributed.all_reduce(values, group=tracker.get('reduce_group')) + if tracker.get('avg_group') is not None: + torch.distributed.all_reduce( + values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.SUM + ) + tracker["values"] = values / tracker['avg_group'].size() + + + @staticmethod + def track_mtp_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None): + """Track the Multi-Token Prediction (MTP) metrics for logging.""" + MTPLossLoggingHelper.reduce_loss_in_tracker() + tracker = MTPLossLoggingHelper.tracker + if "values" not in tracker: + return + mtp_losses = tracker["values"] * loss_scale + mtp_num_layers = mtp_losses.shape[0] + for i in range(mtp_num_layers): + name = f"mtp_{i+1} loss" + loss = mtp_losses[i] + if total_loss_dict is not None: + total_loss_dict[name] = loss + if writer is not None: + writer.add_scalar(name, loss, iteration) + if wandb_writer is not None: + wandb_writer.log({f"{name}": loss}, iteration) + + MTPLossLoggingHelper.clean_loss_in_tracker() + + +@dataclass +class MultiTokenPredictionLayerSubmodules: + """ + Dataclass for specifying the submodules of a MultiTokenPrediction module. + + Args: + hnorm (Union[ModuleSpec, type]): Specification or instance of the + hidden states normalization to be applied. + enorm (Union[ModuleSpec, type]): Specification or instance of the + embedding normalization to be applied. + eh_proj (Union[ModuleSpec, type]): Specification or instance of the + linear projection to be applied. + transformer_layer (Union[ModuleSpec, type]): Specification + or instance of the transformer block to be applied. + """ + + enorm: Union[ModuleSpec, type] = None + hnorm: Union[ModuleSpec, type] = None + eh_proj: Union[ModuleSpec, type] = None + transformer_layer: Union[ModuleSpec, type] = None + layer_norm: Union[ModuleSpec, type] = None + + +def get_mtp_layer_spec( + transformer_layer_spec: ModuleSpec, use_transformer_engine: bool +) -> ModuleSpec: + """Get the MTP layer spec. 
+
+    Returns:
+        ModuleSpec: Module specification with TE modules
+    """
+    if use_transformer_engine:
+        if not HAVE_TE:
+            raise AssertionError("transformer_engine should be installed if use_transformer_engine is True")
+        layer_norm_impl = TENorm
+        column_parallel_linear_impl = TEColumnParallelLinear
+    else:
+        layer_norm_impl = PTNorm
+        column_parallel_linear_impl = ColumnParallelLinear
+
+    mtp_layer_spec = ModuleSpec(
+        module=MultiTokenPredictionLayer,
+        submodules=MultiTokenPredictionLayerSubmodules(
+            enorm=layer_norm_impl,
+            hnorm=layer_norm_impl,
+            eh_proj=column_parallel_linear_impl,
+            transformer_layer=transformer_layer_spec,
+            layer_norm=layer_norm_impl,
+        ),
+    )
+
+    return mtp_layer_spec
+
+
+def get_mtp_layer_offset(config: TransformerConfig) -> int:
+    """Get the offset of the MTP layer."""
+    # Currently, we only support putting all MTP layers on the last pipeline stage.
+    return 0
+
+
+def get_mtp_num_layers_to_build(config: TransformerConfig) -> int:
+    """Get the number of MTP layers to build."""
+    # Currently, we only support putting all MTP layers on the last pipeline stage.
+    args = get_args()
+    if mpu.is_pipeline_first_stage() and args.schedules_method == "dualpipev" and not args.dualpipev_first_chunk:
+        return config.mtp_num_layers if config.mtp_num_layers else 0
+    if mpu.is_pipeline_last_stage() and not args.schedules_method == "dualpipev":
+        return config.mtp_num_layers if config.mtp_num_layers else 0
+    else:
+        return 0
+
+
+class MTPLossAutoScaler(torch.autograd.Function):
+    """An AutoScaler that triggers the backward pass and scales the grad for mtp loss."""
+
+    main_loss_backward_scale: torch.Tensor = torch.tensor(1.0)
+
+    @staticmethod
+    def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor):
+        """Preserve the mtp loss by storing it in the context to avoid garbage collection.
+
+        Args:
+            output (torch.Tensor): The output tensor.
+            mtp_loss (torch.Tensor): The mtp loss tensor.
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
+        ctx.save_for_backward(mtp_loss)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: torch.Tensor):
+        """Compute and scale the gradient for mtp loss.
+
+        Args:
+            grad_output (torch.Tensor): The gradient of the output.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss
+                gradient.
+        """
+        (mtp_loss,) = ctx.saved_tensors
+        mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale
+        scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale
+        return grad_output, scaled_mtp_loss_grad
+
+    @staticmethod
+    def set_loss_scale(scale: torch.Tensor):
+        """Set the scale of the mtp loss.
+
+        Args:
+            scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in
+                matches the scale of the main_loss.
+        """
+        MTPLossAutoScaler.main_loss_backward_scale = scale
+
+
+class MultiTokenPredictionLayer(MegatronModule):
+    """The implementation for Multi-Token Prediction (MTP) which extends
+    the prediction scope to multiple future tokens at each position.
+
+    This MTP implementation sequentially predicts additional tokens and keeps the complete
+    causal chain at each prediction depth, by using D sequential modules to predict
+    D additional tokens.
+
+    The k-th MTP module consists of a shared embedding layer, a projection matrix,
+    a Transformer block, and a shared output head.
+
+    For the i-th input token at the (k - 1)-th prediction depth, we first combine
+    the representation of the i-th token and the embedding of the (i + K)-th token with
+    the linear projection. The combined representation serves as the input of the Transformer
+    block at the k-th depth to produce the output representation.
+
+    For more information, please refer to the DeepSeek-V3 Technical Report.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        submodules: MultiTokenPredictionLayerSubmodules,
+        layer_number: int = 1,
+    ):
+        super().__init__(config=config)
+
+        args = get_args()
+        self.sequence_parallel = config.sequence_parallel
+        self.submodules = submodules
+        self.layer_number = layer_number
+        self.recompute_mtp_norm = args.recompute_mtp_norm
+        self.recompute_mtp_layer = args.recompute_mtp_layer
+
+        self_attention_spec = self.submodules.transformer_layer.submodules.self_attention
+        attn_mask_type = self_attention_spec.params.get('attn_mask_type', '')
+        if attn_mask_type not in SUPPORTED_ATTN_MASK:
+            raise AssertionError(
+                f"Multi-Token Prediction (MTP) is not yet supported with "
+                + f"{attn_mask_type} attention mask type. "
+                + f"The supported attention mask types are {SUPPORTED_ATTN_MASK}."
+            )
+
+        self.enorm = build_module(
+            self.submodules.enorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
+        self.hnorm = build_module(
+            self.submodules.hnorm,
+            config=self.config,
+            hidden_size=self.config.hidden_size,
+            eps=self.config.layernorm_epsilon,
+        )
+
+        # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation
+        # of the i-th token's hidden states and the (i + K)-th token's decoder input,
+        # so the input's shape is [s, b, 2*h].
+        # The output will be sent to the following transformer layer,
+        # so the output's shape should be [s, b, h].
+        self.eh_proj = build_module(
+            self.submodules.eh_proj,
+            self.config.hidden_size * 2,
+            self.config.hidden_size,
+            config=self.config,
+            init_method=self.config.init_method,
+            gather_output=False,
+            bias=False,
+            skip_bias_add=False,
+            is_expert=False,
+        )
+        self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config)
+        # Set mtp_idx so attention can index the actual_seq_len of this MTP depth.
+        self.transformer_layer.mtp_idx = self.layer_number
+        self.transformer_layer.self_attention.core_attention.mtp_idx = self.layer_number
+
+    def forward(
+        self,
+        decoder_input: Tensor,
+        hidden_states: Tensor,
+        attention_mask: Tensor,
+        context: Tensor = None,
+        context_mask: Tensor = None,
+        rotary_pos_emb: Tensor = None,
+        attention_bias: Tensor = None,
+        inference_params: InferenceParams = None,
+        packed_seq_params: PackedSeqParams = None,
+    ):
+        """
+        Perform the forward pass through the MTP layer.
+
+        Args:
+            hidden_states (Tensor): hidden states tensor of shape [s, b, h] where s is the
+                sequence length, b is the batch size, and h is the hidden size.
+            decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the
+                sequence length, b is the batch size, and h is the hidden size.
+                At the (k - 1)-th MTP module, the i-th element of decoder input is
+                the embedding of the (i + K)-th token.
+            attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking
+                self-attention.
+            context (Tensor, optional): Context tensor for cross-attention.
+            context_mask (Tensor, optional): Mask for cross-attention context.
+            rotary_pos_emb (Tensor, optional): Rotary positional embeddings.
+            attention_bias (Tensor): Bias tensor for Q * K.T, in a shape broadcastable
+                to [b, num_head, sq, skv], e.g.
[1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_params (InferenceParams, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + + Returns: + Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + if context is not None: + raise NotImplementedError(f"multi token prediction + cross attention is not yet supported.") + + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 + + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID + else: + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + + fp8_recipe = TEDelayedScaling( + config=self.config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group( + with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red + ) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() + + with rng_context, fp8_context: + + def enorm(tensor): + tensor = self.enorm(tensor) + tensor = make_viewless_tensor( + inp=tensor, requires_grad=True, keep_graph=True + ) + return tensor + + def hnorm(tensor): + tensor = self.hnorm(tensor) + tensor = make_viewless_tensor( + inp=tensor, requires_grad=True, keep_graph=True + ) + return tensor + + if self.recompute_mtp_norm: + self.enorm_ckpt = CheckpointWithoutOutput() + enorm_output = self.enorm_ckpt.checkpoint(enorm, False, decoder_input) + self.hnorm_ckpt = CheckpointWithoutOutput() + hnorm_output = self.hnorm_ckpt.checkpoint(hnorm, False, hidden_states) + else: + enorm_output = enorm(decoder_input) + hnorm_output = hnorm(hidden_states) + # At the (k - 1)-th MTP module, concatenates the i-th tocken's hidden_states + # and the (i + K)-th tocken's embedding, and combine them with linear projection. + hidden_states = torch.cat((enorm_output, hnorm_output), -1) + if self.recompute_mtp_norm: + self.enorm_ckpt.discard_output() + self.hnorm_ckpt.discard_output() + hidden_states.register_hook(self.enorm_ckpt.recompute) + hidden_states.register_hook(self.hnorm_ckpt.recompute) + hidden_states, _ = self.eh_proj(hidden_states) + # For tensor parallel, all gather after linear_fc. + hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states) + # For sequence parallel, scatter after linear_fc and before transformer layer. 
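
# A single-process mimic of the collective bookkeeping above (illustration only; a
# tensor-parallel size of 2 and toy shapes are assumed): the column-parallel eh_proj
# output is sharded on the hidden dim, the all-gather restores the full hidden dim,
# and the sequence-parallel scatter re-shards along the sequence dim instead.
import torch

tp = 2
s, b, h = 4, 1, 8
local_out = torch.randn(s, b, h // tp)                                  # eh_proj output on this rank
gathered = torch.cat([local_out, torch.randn(s, b, h // tp)], dim=-1)   # "all-gather" on the last dim
sp_shard = gathered.chunk(tp, dim=0)[0]                                 # "scatter" along the sequence dim
print(local_out.shape, gathered.shape, sp_shard.shape)                  # [4,1,4] [4,1,8] [2,1,8]
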
+ if self.sequence_parallel: + hidden_states = scatter_to_sequence_parallel_region(hidden_states) + if self.recompute_mtp_layer: + hidden_states, _ = tensor_parallel.checkpoint( + self.transformer_layer, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + inference_params, + packed_seq_params, + ) + else: + hidden_states, _ = self.transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + ) + + return hidden_states + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None + ) -> ShardedStateDict: + """ + Generate a sharded state dictionary for the multi token prediction layer. + + Args: + prefix (str, optional): Prefix to be added to all keys in the state dict. + sharded_offsets (tuple, optional): Tuple of sharding offsets. + metadata (Optional[dict], optional): Additional metadata for sharding. + + Returns: + ShardedStateDict: A dictionary containing the sharded state of the multi + token prediction layer. + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + return sharded_state_dict + + +@dataclass +class MultiTokenPredictionBlockSubmodules: + """ + Dataclass for specifying the submodules of a multi token prediction block. + + This class defines the structure for configuring the layers, allowing for + flexible and customizable architecture designs. + + Args: + layer_specs (List[ModuleSpec], optional): A list of module specifications for + the layers within the multi token prediction block. Each specification typically + defines a complete multi token prediction layer (e.g., shared embedding, + projection matrix, transformer block, shared output head). + """ + + layer_specs: List[ModuleSpec] = None + + +def _get_mtp_block_submodules( + config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec] +) -> MultiTokenPredictionBlockSubmodules: + """ + Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification. + + Args: + config (TransformerConfig): Configuration object for the transformer model. + spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the + multi token prediction block submodules. + Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec. + + Returns: + MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block. + """ + + # Transformer block submodules. + if isinstance(spec, MultiTokenPredictionBlockSubmodules): + return spec + elif isinstance(spec, ModuleSpec): + if issubclass(spec.module, MultiTokenPredictionBlock): + return spec.submodules + else: + raise Exception(f"specialize for {spec.module.__name__}.") + else: + raise Exception(f"specialize for {type(spec).__name__}.") + + +class MultiTokenPredictionBlock(MegatronModule): + """The implementation for Multi-Token Prediction (MTP) which extends + the prediction scope to multiple future tokens at each position. + + This MTP implementation sequentially predict additional tokens and keep the complete + causal chain at each prediction depth, by using D sequential modules to predict + D additional tokens. + + The k-th MTP module consists of a shared embedding layer, a projection matrix, + a Transformer block, and a shared output head. 
+ + For the i-th input token at the (k - 1)-th prediction depth, we first combine + the representation of the i-th token and the embedding of the (i + K)-th token with + the linear projection. The combined serves as the input of the Transformer block at + the k-th depth to produce the output representation. + + for more information, please refer to DeepSeek-V3 Technical Report + """ + + def __init__( + self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec] + ): + super().__init__(config=config) + self.submodules = _get_mtp_block_submodules(config, spec) + self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor + self._build_layers() + if len(self.layers) == 0: + raise AssertionError("MultiTokenPredictionBlock must have at least one layer.") + + def _build_layers(self): + def build_layer(layer_spec, layer_number): + return build_module(layer_spec, config=self.config, layer_number=layer_number) + + self.layers = torch.nn.ModuleList( + [ + build_layer(layer_spec, i + 1) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + self.final_layernorms = torch.nn.ModuleList( + [ + build_module( + layer_spec.submodules.layer_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + for i, layer_spec in enumerate(self.submodules.layer_specs) + ] + ) + + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + hidden_states: Tensor, + attention_mask: Tensor, + labels: Tensor = None, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + loss_mask: Optional[Tensor] = None, + embedding=None, + output_layer=None, + output_weight: Optional[torch.Tensor] = None, + compute_language_model_loss=None, + ) -> Tensor: + """ + Perform the forward pass through all of the MTP modules. + + Args: + hidden_states (Tensor): Hidden states for input token with the shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + + Returns: + (Tensor): The mtp loss tensor of shape [b, s]. + """ + if labels is None: + raise AssertionError(f"labels should not be None for calculating multi token prediction loss.") + + args = get_args() + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(labels) + + hidden_states_main_model = hidden_states + for layer_number in range(len(self.layers)): + # Calc logits for the current Multi-Token Prediction (MTP) layers. + input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1) + if args.reset_position_ids: + position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1) + position_ids = regenerate_position_ids(position_ids, 1) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + # norm, linear projection and transformer + hidden_states = self.layers[layer_number]( + decoder_input=decoder_input, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + # Layer norm before shared head layer. 
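
# A tiny standalone check of the roll_tensor shifting used in this loop (illustration
# only): each MTP depth shifts ids/labels left by one along the sequence and zero-fills
# the vacated last position, so depth k is trained against tokens offset by k + 1.
import torch

def _roll(tensor, shifts=-1, dims=-1):
    rolled = torch.roll(tensor, shifts=shifts, dims=dims)
    rolled.select(dims, shifts).fill_(0)
    return rolled

ids = torch.tensor([[10, 11, 12, 13]])
depth1 = _roll(ids)       # tensor([[11, 12, 13,  0]])
depth2 = _roll(depth1)    # tensor([[12, 13,  0,  0]])
print(depth1, depth2)
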
+ hidden_states = self.final_layernorms[layer_number](hidden_states) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + # output + mtp_logits, _ = output_layer( + hidden_states, weight=output_weight + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. + labels, _ = roll_tensor(labels, shifts=-1, dims=-1) + loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1) + mtp_loss = compute_language_model_loss(labels, mtp_logits) + mtp_loss = loss_mask * mtp_loss + if self.training: + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_tensor_and_context_parallel_group(), + ) + mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states_main_model = MTPLossAutoScaler.apply( + hidden_states_main_model, mtp_loss_scale * mtp_loss + ) + else: + hidden_states_main_model = MTPLossAutoScaler.apply( + hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens + ) + + return hidden_states_main_model + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None + ) -> ShardedStateDict: + """ + Generate a sharded state dictionary for the multi token prediction module. + + Args: + prefix (str, optional): Prefix to be added to all keys in the state dict. + sharded_offsets (tuple, optional): Tuple of sharding offsets. + metadata (Optional[dict], optional): Additional metadata for sharding. + + Returns: + ShardedStateDict: A dictionary containing the sharded state of the multi + token prediction module. + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + layer_prefix = f'{prefix}layers.' + for layer in self.layers: + offset = get_mtp_layer_offset(self.config) + sharded_prefix = f'{layer_prefix}{layer.layer_number - 1 }.' + + state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.' 
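
# A toy check of the MTPLossAutoScaler pattern used in forward() above (illustration
# only, made-up scale and tensors): the main-path activation passes through unchanged,
# while backward feeds the attached auxiliary loss a constant gradient equal to the
# configured scale, so the MTP graph is backpropagated without altering the main loss.
import torch

class _AuxLossAttach(torch.autograd.Function):
    scale = 0.3  # stands in for mtp_loss_scaling_factor / mtp_num_layers

    @staticmethod
    def forward(ctx, output, aux_loss):
        ctx.save_for_backward(aux_loss)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (aux_loss,) = ctx.saved_tensors
        return grad_output, torch.ones_like(aux_loss) * _AuxLossAttach.scale

w = torch.tensor(2.0, requires_grad=True)
aux_loss = w * 5.0                          # pretend MTP loss that depends on w
main = torch.tensor(1.0, requires_grad=True)
out = _AuxLossAttach.apply(main, aux_loss)
out.backward()
print(main.grad, w.grad)                    # tensor(1.) and tensor(1.5000) == 0.3 * 5
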
+ sharded_pp_offset = [] + layer_sharded_state_dict = layer.sharded_state_dict( + state_dict_prefix, sharded_pp_offset, metadata + ) + replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) + sharded_state_dict.update(layer_sharded_state_dict) + return sharded_state_dict diff --git a/mindspeed_llm/core/transformer/transformer_block.py b/mindspeed_llm/core/transformer/transformer_block.py index 72023dfdf1c28190e37ea6fbb8545865cd290824..ccff401b127cf93a77a37c8a06a603ab5134ac67 100644 --- a/mindspeed_llm/core/transformer/transformer_block.py +++ b/mindspeed_llm/core/transformer/transformer_block.py @@ -103,8 +103,8 @@ def _transformer_block_build_layers(self): ) # mtp require seperate layernorms for main model and mtp modules, thus move finalnorm out of block - move_final_norm_out_of_block = args.num_nextn_predict_layers > 0 - if self.submodules.layer_norm and self.post_process and self.post_layer_norm and not move_final_norm_out_of_block: + init_block_fn_flag = self.post_layer_norm and not args.mtp_num_layers + if self.submodules.layer_norm and self.post_process and init_block_fn_flag: self.final_layernorm = build_module( self.submodules.layer_norm, config=self.config, diff --git a/mindspeed_llm/core/transformer/transformer_config.py b/mindspeed_llm/core/transformer/transformer_config.py new file mode 100644 index 0000000000000000000000000000000000000000..744eb7bb912a66293ae4b082c560100abd0fbf24 --- /dev/null +++ b/mindspeed_llm/core/transformer/transformer_config.py @@ -0,0 +1,16 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +from functools import wraps + +from megatron.training import get_args + + +def transformer_config_post_init_mtp_wrapper(fn): + @wraps(fn) + def wrapper(self): + fn(self) + args = get_args() + + self.mtp_num_layers = args.mtp_num_layers + self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor + + return wrapper diff --git a/mindspeed_llm/core/transformer/transformer_layer.py b/mindspeed_llm/core/transformer/transformer_layer.py index 62ae929870909a82905cbc1928a9322128237b2a..165779303c01d34f44b65f96bee4e22ea10e8996 100644 --- a/mindspeed_llm/core/transformer/transformer_layer.py +++ b/mindspeed_llm/core/transformer/transformer_layer.py @@ -49,6 +49,9 @@ class TransformerLayer(MegatronTransformerLayer): expert.layer_number = self.layer_number else: self.mlp.layer_number = self.layer_number + # set mtp_idx + self.mtp_idx = 0 + self.self_attention.core_attention.mtp_idx = 0 def forward(self, hidden_states, attention_mask, context=None, context_mask=None, diff --git a/mindspeed_llm/legacy/data/data_samplers.py b/mindspeed_llm/legacy/data/data_samplers.py index 48a9100c3be810c21b3e4a02c4248f88b76ff401..adf7cf13264826880fb4224ea51b76d1e6df1f76 100644 --- a/mindspeed_llm/legacy/data/data_samplers.py +++ b/mindspeed_llm/legacy/data/data_samplers.py @@ -59,7 +59,7 @@ def build_pretraining_data_loader(dataset, consumed_samples): tokenizer.padding_side = args.tokenizer_padding_side collator = PairwiseDataCollatorWithPadding( tokenizer, - pad_to_multiple_of=args.pad_to_multiple_of if args.variable_seq_lengths else args.seq_length + args.num_nextn_predict_layers, + pad_to_multiple_of=args.pad_to_multiple_of if args.variable_seq_lengths else args.seq_length, return_tensors='pt', padding=True ) @@ -67,7 +67,7 @@ def build_pretraining_data_loader(dataset, consumed_samples): tokenizer.padding_side = args.tokenizer_padding_side collator = DataCollatorForSeq2Seq( tokenizer, - pad_to_multiple_of=args.pad_to_multiple_of 
if args.variable_seq_lengths else args.seq_length + args.num_nextn_predict_layers, + pad_to_multiple_of=args.pad_to_multiple_of if args.variable_seq_lengths else args.seq_length, return_tensors='pt', padding=True ) diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 7c616b63cfb296fc58adbd4b61a5518081142b5d..bfdab17004d5512ad0224d1ebfe4fa7908785955 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -248,15 +248,14 @@ class CoreAdaptation(MegatronAdaptationABC): from mindspeed.core.models.common.embeddings.language_model_embedding import language_model_embedding_forward_wrapper from mindspeed.core.data_parallel.distributed_data_parallel import distributed_data_parallel_init_with_cp from mindspeed.core.transformer.attention import attention_init, self_attention_init_wrapper - from ..core.models.common.embeddings.language_model_embedding import ( - language_model_embedding_forward, language_model_embedding_init_func) from ..training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank, get_device_wrapper from ..core import rotary_embedding_forward, apply_rotary_pos_emb_bshd from ..core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec_wrapper from ..core.transformer.dot_product_attention import dot_product_attention_init, \ dot_product_attention_forward_wrapper, ulysses_context_parallel_forward_wrapper - from ..core.models.gpt.gpt_model import gpt_model_init_wrapper, shared_embedding_weight - from ..core import rotary_embedding_init_wrapper, gpt_model_forward + from ..core.models.gpt.gpt_model import GPTModel + from ..core import rotary_embedding_init_wrapper + args = MegatronAdaptation.get_args() # Embedding @@ -276,12 +275,18 @@ class CoreAdaptation(MegatronAdaptationABC): MegatronAdaptation.register( 'megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.get_rotary_seq_len', rotary_embedding_get_rotary_seq_len_wrapper) + from ..core.models.common.language_module.language_module import ( + setup_embeddings_and_output_layer, + tie_embeddings_and_output_weights_state_dict, + ) MegatronAdaptation.register( - 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__', - language_model_embedding_init_func) + 'megatron.core.models.common.language_module.language_module.LanguageModule' + '.setup_embeddings_and_output_layer', + setup_embeddings_and_output_layer) MegatronAdaptation.register( - 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward', - language_model_embedding_forward) + 'megatron.core.models.common.language_module.language_module.LanguageModule' + '.tie_embeddings_and_output_weights_state_dict', + tie_embeddings_and_output_weights_state_dict) MegatronAdaptation.register( 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward', language_model_embedding_forward_wrapper) @@ -324,13 +329,12 @@ class CoreAdaptation(MegatronAdaptationABC): MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank) MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank) MegatronAdaptation.register('megatron.training.dist_signal_handler.get_device', get_device_wrapper) - if not args.moe_fb_overlap: - MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_model_forward) - 
MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.__init__', gpt_model_init_wrapper) - - from megatron.core.models.gpt.gpt_model import GPTModel - setattr(GPTModel, 'shared_embedding_weight', shared_embedding_weight) - + # moe_fb_overlap will shadow this forward impl + MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel) + from ..core.models.common.embeddings.language_model_embedding import language_model_embedding_init_func + MegatronAdaptation.register( + 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__', + language_model_embedding_init_func) # For recomputation args = MegatronAdaptation.get_args() if args.share_kvstates: @@ -369,6 +373,10 @@ class CoreAdaptation(MegatronAdaptationABC): MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__', transformer_config_post_init_wrapper) + # for mtp + from ..core.transformer.transformer_config import transformer_config_post_init_mtp_wrapper + MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__', + transformer_config_post_init_mtp_wrapper) MegatronAdaptation.register('torch.cuda.get_device_capability', get_device_capability) megatron.core.transformer.transformer_block.LayerNormImpl = PTNorm MegatronAdaptation.register('megatron.core.transformer.transformer_block.TENorm', PTNorm) @@ -529,16 +537,20 @@ class CoreAdaptation(MegatronAdaptationABC): from ..core.pipeline_parallel.schedules import get_forward_backward_func_wrapper MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_forward_backward_func', get_forward_backward_func_wrapper) + # for mtp + from ..core.pipeline_parallel.schedules import forward_step_wrapper + MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step', forward_step_wrapper) + def patch_tensor_parallel(self): from mindspeed.core.tensor_parallel.random import _set_cuda_rng_state - from ..core import vocab_parallel_embedding_forward, vocab_embedding_init_func, checkpoint_forward_wrapper, checkpoint_backward_wrapper + from ..core import vocab_embedding_init_func, vocab_parallel_embedding_forward # default_generators need replace after set_device MegatronAdaptation.register('megatron.core.tensor_parallel.random._set_cuda_rng_state', _set_cuda_rng_state) - + # change masked_target for better performance if MegatronAdaptation.get_args().mtp_mem_efficient_logits: - from ..tasks.models.transformer.multi_token_predication import calculate_logits_max, calculate_predicted_logits + from ..core.tensor_parallel.cross_entropy import calculate_logits_max, calculate_predicted_logits MegatronAdaptation.register( 'megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_logits_max', calculate_logits_max) @@ -550,15 +562,11 @@ class CoreAdaptation(MegatronAdaptationABC): MegatronAdaptation.register( 'megatron.core.tensor_parallel.cross_entropy.VocabParallelCrossEntropy.calculate_predicted_logits', calculate_predicted_logits) - + MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.forward', vocab_parallel_embedding_forward) MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__', vocab_embedding_init_func) - MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.forward', - checkpoint_forward_wrapper) - 
MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.backward', - checkpoint_backward_wrapper) # For recompute-in-advance from mindspeed.core.tensor_parallel.random import checkpoint_wrapper MegatronAdaptation.register('megatron.core.tensor_parallel.random.checkpoint', checkpoint_wrapper) @@ -636,7 +644,6 @@ class CoreAdaptation(MegatronAdaptationABC): add_item_wrapper) MegatronAdaptation.register('megatron.core.datasets.indexed_dataset.IndexedDatasetBuilder.finalize', finalize_wrapper) - # MTP need extra token from ..core.datasets.gpt_dataset import ( gpt_dataset_getitem_wrapper, _get_ltor_masks_and_position_ids @@ -651,6 +658,8 @@ class CoreAdaptation(MegatronAdaptationABC): from mindspeed_llm.training.utils import unwrap_model_wrapper MegatronAdaptation.register('megatron.training.checkpointing.unwrap_model', unwrap_model_wrapper) MegatronAdaptation.register('megatron.training.training.unwrap_model', unwrap_model_wrapper) + from ..training.training import training_log + MegatronAdaptation.register('megatron.training.training.training_log', training_log) from mindspeed_llm.training.utils import generate_adaptive_cp_mask_list_by_user, generate_adaptive_cp_grid_mask_by_user MegatronAdaptation.register('mindspeed.core.context_parallel.utils.generate_adaptive_cp_mask_list_by_user', diff --git a/mindspeed_llm/tasks/models/spec/mtp_spec.py b/mindspeed_llm/tasks/models/spec/mtp_spec.py deleted file mode 100644 index ce34418947b700a1b7e8c870abd26904b8c0833f..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/models/spec/mtp_spec.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. - -""" -Multi Token Predication Layer Specification. -""" - -from megatron.core.tensor_parallel import ColumnParallelLinear -from megatron.core.transformer import ModuleSpec -from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm -from mindspeed_llm.tasks.models.transformer.multi_token_predication import MultiTokenPredicationSubmodules, \ - MultiTokenPredication - - -# Use this spec for multi token predication -mtp_sepc = ModuleSpec( - module=MultiTokenPredication, - submodules=MultiTokenPredicationSubmodules( - embedding=None, - enorm=PTNorm, - hnorm=PTNorm, - eh_proj=ColumnParallelLinear, - transformer_layer=None, - final_layernorm=PTNorm, - output_layer=None, - ) -) diff --git a/mindspeed_llm/tasks/models/transformer/multi_token_predication.py b/mindspeed_llm/tasks/models/transformer/multi_token_predication.py deleted file mode 100644 index eca6360249cf5352ac5f79eeacf0ec2204ffcba2..0000000000000000000000000000000000000000 --- a/mindspeed_llm/tasks/models/transformer/multi_token_predication.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. 
-import logging -from dataclasses import dataclass -from typing import Union, Optional, Literal, Tuple - -import torch -from torch import Tensor - -from megatron.core import tensor_parallel, InferenceParams -from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding -from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.transformer.module import MegatronModule -from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy - -from megatron.core.transformer import ModuleSpec, TransformerConfig, build_module -from megatron.training import get_args -from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput -from mindspeed_llm.core.tensor_parallel.layers import SegmentedColumnParallelLinear - - -@dataclass -class MultiTokenPredicationSubmodules: - embedding: Union[ModuleSpec, type] = None - output_layer: Union[ModuleSpec, type] = None - eh_proj: Union[ModuleSpec, type] = None - enorm: Union[ModuleSpec, type] = None - hnorm: Union[ModuleSpec, type] = None - transformer_layer: Union[ModuleSpec, type] = None - final_layernorm: Union[ModuleSpec, type] = None - - -class MultiTokenPredication(MegatronModule): - def __init__( - self, - config: TransformerConfig, - transformer_layer_spec: ModuleSpec, - submodules: MultiTokenPredicationSubmodules, - vocab_size: int, - max_sequence_length: int, - layer_number: int = 1, - hidden_dropout: float = None, - pre_process: bool = True, - post_process: bool = True, - fp16_lm_cross_entropy: bool = False, - parallel_output: bool = True, - position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', - rotary_percent: float = 1.0, - rotary_base: int = 10000, - seq_len_interpolation_factor: Optional[float] = None, - share_mtp_embedding_and_output_weight=True, - ): - super().__init__(config=config) - args = get_args() - - self.config = config - self.submodules = submodules - - if transformer_layer_spec is not None: - self.transformer_layer_spec = transformer_layer_spec - self.submodules.transformer_layer = self.transformer_layer_spec - self.layer_number = layer_number - self.hidden_dropout = hidden_dropout - self.hidden_size = args.hidden_size - self.ffn_hidden_size = args.ffn_hidden_size - self.vocab_size = vocab_size - self.max_sequence_length = max_sequence_length - self.pre_process = pre_process - self.post_process = post_process - self.fp16_lm_cross_entropy = fp16_lm_cross_entropy - self.parallel_output = parallel_output - self.position_embedding_type = position_embedding_type - self.num_nextn_predict_layers = args.num_nextn_predict_layers - # share with main model - self.share_mtp_embedding_and_output_weight = share_mtp_embedding_and_output_weight - self.recompute_layer_norm = args.recompute_mtp_norm - self.recompute_mtp_layer = args.recompute_mtp_layer - - skip_embedding_allocation = (args.schedules_method == 'dualpipev' and self.share_mtp_embedding_and_output_weight) or \ - (self.pre_process and self.share_mtp_embedding_and_output_weight) - self.embedding = LanguageModelEmbedding( - config=self.config, - vocab_size=self.vocab_size, - max_sequence_length=self.max_sequence_length, - position_embedding_type=self.position_embedding_type, - skip_weight_param_allocation=skip_embedding_allocation - ) - - if self.position_embedding_type == 'rope': - self.rotary_pos_emb = RotaryEmbedding( - kv_channels=self.config.kv_channels, - 
rotary_percent=rotary_percent, - rotary_interleaved=self.config.rotary_interleaved, - seq_len_interpolation_factor=seq_len_interpolation_factor, - rotary_base=rotary_base, - use_cpu_initialization=self.config.use_cpu_initialization, - ) - - self.enorm = build_module( - self.submodules.enorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - self.hnorm = build_module( - self.submodules.hnorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - self.eh_proj = build_module( - self.submodules.eh_proj, - self.hidden_size + self.hidden_size, - self.hidden_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=self.config.add_bias_linear, - skip_bias_add=True, - tp_comm_buffer_name='eh', - ) - - self.transformer_layer = build_module( - self.submodules.transformer_layer, - config=self.config, - ) - - if self.submodules.final_layernorm: - self.final_layernorm = build_module( - self.submodules.final_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - else: - self.final_layernorm = None - - if self.config.defer_embedding_wgrad_compute: - self.embedding_activation_buffer = [] - self.grad_output_buffer = [] - else: - self.embedding_activation_buffer = None - self.grad_output_buffer = None - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=False, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight, - embedding_activation_buffer=self.embedding_activation_buffer, - grad_output_buffer=self.grad_output_buffer, - ) - if args.add_output_layer_bias: - self.output_layer = tensor_parallel.ColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=True, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight, - embedding_activation_buffer=self.embedding_activation_buffer, - grad_output_buffer=self.grad_output_buffer, - ) - - if args.output_layer_slice_num > 1: - self.output_layer = SegmentedColumnParallelLinear( - config.hidden_size, - self.vocab_size, - config=config, - init_method=config.init_method, - bias=False, - skip_bias_add=False, - gather_output=not self.parallel_output, - skip_weight_param_allocation=self.share_mtp_embedding_and_output_weight, - embedding_activation_buffer=self.embedding_activation_buffer, - grad_output_buffer=self.grad_output_buffer, - ) - - def forward( - self, - hidden_input_ids: Tensor, - embed_input_ids: Tensor, - position_ids: Tensor, - attention_mask: Tensor, - decoder_input: Tensor = None, - labels: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - embeding_weight: Optional[torch.Tensor] = None, - output_weight: Optional[torch.Tensor] = None, - ): - """Forward function of the MTP module""" - args = get_args() - if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): - args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings - # Decoder embedding. 
- decoder_input = self.embedding( - input_ids=embed_input_ids, - position_ids=position_ids, - weight=embeding_weight, - ) - if args.scale_emb is not None: - decoder_input = decoder_input * args.scale_emb - - # Rotary positional embeddings (embedding is None for PP intermediate devices) - rotary_pos_emb = None - if self.position_embedding_type == 'rope': - if inference_params is not None: - rotary_seq_len = inference_params.max_sequence_length - else: - rotary_seq_len = decoder_input.size(0) - - if self.config.sequence_parallel: - rotary_seq_len *= self.config.tensor_model_parallel_size - - rotary_seq_len *= self.config.context_parallel_size - rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) - if self.recompute_layer_norm: - self.enorm_ckpt = CheckpointWithoutOutput() - enorm_output = self.enorm_ckpt.checkpoint(self.enorm, False, decoder_input) - self.hnorm_ckpt = CheckpointWithoutOutput() - hnorm_output = self.hnorm_ckpt.checkpoint(self.hnorm, False, hidden_input_ids) - else: - enorm_output = self.enorm(decoder_input) - hnorm_output = self.hnorm(hidden_input_ids) - - # [s, b, h] -> [s, b, 2h] - hidden_states = torch.concat( - [hnorm_output, - enorm_output], - dim=-1 - ) - - if self.recompute_layer_norm: - self.enorm_ckpt.discard_output() - self.hnorm_ckpt.discard_output() - hidden_states.register_hook(self.enorm_ckpt.recompute) - hidden_states.register_hook(self.hnorm_ckpt.recompute) - # hidden_states -> [s, b, h] - hidden_states, _ = self.eh_proj(hidden_states) - - if self.config.tensor_model_parallel_size > 1: - hidden_states = tensor_parallel.gather_from_tensor_model_parallel_region(hidden_states) - if self.config.sequence_parallel: - hidden_states = tensor_parallel.scatter_to_sequence_parallel_region(hidden_states) - if self.recompute_mtp_layer: - hidden_states, context = tensor_parallel.checkpoint( - self.transformer_layer, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - None, - None, - rotary_pos_emb, - inference_params, - packed_seq_params, - ) - else: - hidden_states, _ = self.transformer_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - rotary_pos_emb=rotary_pos_emb, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), - ) - - # Final layer norm. 
- if self.final_layernorm is not None: - if self.recompute_layer_norm: - self.finalnorm_ckpt = CheckpointWithoutOutput() - finalnorm_output = self.finalnorm_ckpt.checkpoint(self.final_layernorm, False, hidden_states) - else: - finalnorm_output = self.final_layernorm(hidden_states) - else: - finalnorm_output = hidden_states - - if args.dim_model_base is not None: - finalnorm_output = finalnorm_output / (args.hidden_size / args.dim_model_base) - logits, _ = self.output_layer(finalnorm_output, weight=output_weight) - - if self.recompute_layer_norm: - self.finalnorm_ckpt.discard_output() - logits.register_hook(self.finalnorm_ckpt.recompute) - if args.output_multiplier_scale: - logits = logits * args.output_multiplier_scale - - if args.output_logit_softcapping: - logits = logits / args.output_logit_softcapping - logits = torch.tanh(logits) - logits = logits * args.output_logit_softcapping - - if labels is None: - # [s b h] => [b s h] - return logits.transpose(0, 1).contiguous() - - if args.is_instruction_dataset: - labels = labels[:, 1:].contiguous() - logits = logits[:-1, :, :].contiguous() - - loss = self.compute_language_model_loss(labels, logits) - return hidden_states, loss - - def compute_language_model_loss(self, labels: Tensor, logits: Tensor) -> Tensor: - """Computes the language model loss (Cross entropy across vocabulary) - - Args: - labels (Tensor): The labels of dimension [batch size, seq length] - logits (Tensor): The final logits returned by the output layer of the transformer model - - Returns: - Tensor: Loss tensor of dimensions [batch size, sequence_length] - """ - # [b s] => [s b] - labels = labels.transpose(0, 1).contiguous() - if self.config.cross_entropy_loss_fusion: - loss = fused_vocab_parallel_cross_entropy(logits, labels) - else: - loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels) - - # [s b] => [b, s] - loss = loss.transpose(0, 1).contiguous() - return loss - - -def calculate_logits_max( - vocab_parallel_logits: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - if vocab_parallel_logits.dtype != torch.float32: - vocab_parallel_logits_fp32 = vocab_parallel_logits.float() - vocab_parallel_logits.untyped_storage().resize_(0) - else: - vocab_parallel_logits_fp32 = vocab_parallel_logits.float() - # Maximum value along vocab dimension across all GPUs. - logits_max = torch.max(vocab_parallel_logits_fp32, dim=-1)[0] - - return vocab_parallel_logits_fp32, logits_max - - -def calculate_predicted_logits( - vocab_parallel_logits: torch.Tensor, - target: torch.Tensor, - logits_max: torch.Tensor, - vocab_start_index: int, - vocab_end_index: int, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # subtraction the maximum value. - # Use in-place to reduce memory pressure. - vocab_parallel_logits -= logits_max.unsqueeze(dim=-1) - - # Create a mask of valid vocab ids (1 means it needs to be masked). - target_mask = (target < vocab_start_index) | (target >= vocab_end_index) - masked_target = target.clone() - vocab_start_index - masked_target *= ~target_mask - - # Get predicted-logits = logits[target]. - # For Simplicity, we convert logits to a 2-D tensor with size - # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. 
- partition_vocab_size = vocab_parallel_logits.size()[-1] - logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size) - masked_target_1d = masked_target.view(-1) - arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) - predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] - predicted_logits_1d = predicted_logits_1d.clone().contiguous() - predicted_logits = predicted_logits_1d.view_as(target) - predicted_logits *= ~target_mask - - exp_logits = vocab_parallel_logits - torch.exp(vocab_parallel_logits, out=exp_logits) - sum_exp_logits = exp_logits.sum(dim=-1) - return target_mask, masked_target_1d, predicted_logits, sum_exp_logits, exp_logits diff --git a/mindspeed_llm/tasks/posttrain/base/base_trainer.py b/mindspeed_llm/tasks/posttrain/base/base_trainer.py index 627d75c88f156cd3abeeca0d9fd6fd74012509fa..16b69da2b5705e78f4bc7f3cf3afac06adf418fb 100644 --- a/mindspeed_llm/tasks/posttrain/base/base_trainer.py +++ b/mindspeed_llm/tasks/posttrain/base/base_trainer.py @@ -20,6 +20,7 @@ from megatron.core.models.gpt.gpt_layer_specs import ( ) from megatron.core.models.gpt import GPTModel from megatron.training.checkpointing import save_checkpoint +from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec from mindspeed_llm.training import build_train_args from mindspeed_llm.training import train from mindspeed_llm.training.initialize import set_jit_fusion_options @@ -124,6 +125,9 @@ class BaseTrainer(ABC): args.moe_grouped_gemm) else: transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm) + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te) model = GPTModel( config=config, @@ -137,7 +141,8 @@ class BaseTrainer(ABC): share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, position_embedding_type=args.position_embedding_type, rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, ) else: if not args.context_parallel_size == 1: diff --git a/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py b/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py index f29bcd0d6819fcee2f51f0b2b05f231b2ecb2215..6e7f0b0ed44fb5f146af324eecd9bfc032f56c69 100644 --- a/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py +++ b/mindspeed_llm/tasks/posttrain/sft/sft_trainer.py @@ -55,15 +55,6 @@ class SFTTrainer(BaseTrainer): # ignored label -100 loss_mask = torch.where(labels == IGNORE_INDEX, 0, 1) - # Adapt to MTP - if args.variable_seq_lengths and args.num_nextn_predict_layers: - tokenizer = get_tokenizer().tokenizer - pad_tensor = torch.ones((labels.shape[0], args.num_nextn_predict_layers)).to(labels.device) - labels = torch.cat([labels, pad_tensor.to(labels.dtype) * IGNORE_INDEX], -1) - tokens = torch.cat([tokens, pad_tensor.to(tokens.dtype) * tokenizer.pad_token_id], -1) - attention_mask_1d = torch.cat([attention_mask_1d, pad_tensor.to(attention_mask_1d.dtype) * 0], -1) - loss_mask = torch.cat([loss_mask, pad_tensor.to(loss_mask.dtype) * 0], -1) - if get_args().spec is not None and args.spec[0] == "mindspeed_llm.tasks.models.spec.hunyuan_spec": input_ids = tokens pad_id = 127961 @@ -87,11 +78,6 @@ class SFTTrainer(BaseTrainer): position_ids = data_b.get('position_ids').long() generate_actual_seq_len(data_b) - # Adapt to MTP - if 
args.num_nextn_predict_layers: - pad_tensor = torch.zeros((labels.shape[0], args.num_nextn_predict_layers)).to(labels.device) - position_ids = torch.cat([position_ids, pad_tensor.to(position_ids.dtype)], -1) - batch = { 'tokens': tokens, 'labels': labels, @@ -151,6 +137,7 @@ class SFTTrainer(BaseTrainer): data_iterator : Input data iterator model (GPTModel): The GPT Model """ + args = get_args() timers = get_timers() # Get the batch. @@ -159,11 +146,11 @@ class SFTTrainer(BaseTrainer): data_iterator) timers('batch-generator').stop() - output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) - - if self.args.num_nextn_predict_layers and loss_mask is not None: - return output_tensor, partial(self.loss_func, - loss_mask[:, :loss_mask.shape[-1] - self.args.num_nextn_predict_layers]) + if args.use_legacy_models: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) else: - return output_tensor, partial(self.loss_func, loss_mask) \ No newline at end of file + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels, loss_mask=loss_mask) + + return output_tensor, partial(self.loss_func, loss_mask) \ No newline at end of file diff --git a/mindspeed_llm/tasks/posttrain/utils.py b/mindspeed_llm/tasks/posttrain/utils.py index efe6b1623072244120a48a6376c988626b180df3..5bcaab57f61e0f6e039b7026c20323b29b56ef48 100644 --- a/mindspeed_llm/tasks/posttrain/utils.py +++ b/mindspeed_llm/tasks/posttrain/utils.py @@ -63,7 +63,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): data_prefix=args.data_path, splits_string=args.split, train_valid_test_num_samples=train_val_test_num_samples, - seq_length=args.seq_length + args.num_nextn_predict_layers, + seq_length=args.seq_length, seed=args.seed) else: train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( diff --git a/mindspeed_llm/training/__init__.py b/mindspeed_llm/training/__init__.py index be11f573d1fbc605e8f9573f11b383a5d442ca04..6b4d29f75d56372cfa939f476cdf9eca1924d39b 100644 --- a/mindspeed_llm/training/__init__.py +++ b/mindspeed_llm/training/__init__.py @@ -13,5 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .training import (get_model_wrapper, is_profile_enabled, get_profiler, setup_model_and_optimizer_wrapper, +from mindspeed_llm.training.training import (get_model_wrapper, is_profile_enabled, get_profiler, setup_model_and_optimizer_wrapper, model_provider_func_wrapper, build_train_args, pretrain, train) diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index 161ffc841243fd88eed26e3e4ff3b04cbe9b23f2..ec5bc0b23925bc310af5113fd34fbdbf661dc1ea 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -159,14 +159,20 @@ def _add_deepseek_moe_args(parser): def _add_mtp_args(parser): group = parser.add_argument_group(title='multi token prediction') - group.add_argument('--num-nextn-predict-layers', type=int, default=0, help='Multi-Token prediction layer num') - group.add_argument('--mtp-loss-scale', type=float, default=0.3, help='Multi-Token prediction loss scale') + group.add_argument('--mtp-num-layers', type=int, default=None, + help='Number of Multi-Token Prediction (MTP) Layers.' + 'MTP extends the prediction scope to multiple future tokens at each position.' 
+                            'This MTP implementation sequentially predicts additional tokens '
+                            'by using D sequential modules to predict D additional tokens.')
+    group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.1,
+                       help='Scaling factor of Multi-Token Prediction (MTP) loss. '
+                            'We compute the average of the MTP losses across all depths, '
+                            'and multiply it by the scaling factor to obtain the overall MTP loss, '
+                            'which serves as an additional training objective.')
     group.add_argument('--recompute-mtp-norm', action='store_true', default=False,
                        help='Multi-Token prediction recompute norm')
     group.add_argument('--recompute-mtp-layer', action='store_true', default=False,
                        help='Multi-Token prediction recompute layer')
-    group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False,
-                       help='Main model share embedding and output weight with mtp layer.')
     group.add_argument('--mtp-mem-efficient-logits', action='store_true', default=False,
                        help='Optimize ce_loss memory when use mtp block.')
     return parser
@@ -1507,6 +1513,16 @@ def _validate_fused_opts(args):
             '--position-embedding-type=rope')
 
 
+def _validate_mtp_args(args):
+    if args.mtp_num_layers:
+        assert not args.use_legacy_models, "The legacy Megatron models do not support Multi-Token Prediction (MTP)."
+        assert args.context_parallel_size == 1, "Multi-Token Prediction (MTP) is not supported with Context Parallelism."
+        assert args.position_embedding_type == "rope" or args.position_embedding_type == "none", (
+            f"Multi-Token Prediction (MTP) is not supported with {args.position_embedding_type} position embedding type. "
+            + f"The supported position embedding types are rope and none."
+        )
+
+
 def validate_args_decorator(megatron_validate_args):
     @wraps(megatron_validate_args)
     def wrapper(args, defaults=None):
@@ -1550,6 +1566,9 @@ def validate_args_decorator(megatron_validate_args):
         _validate_noop_layer(args)
         _valid_tp_2d_args(args)
         _add_dummy_args(args)
+        # remove in future megatron version
+        _validate_mtp_args(args)
+
         from mindspeed_llm.training.utils import print_args
         print_args('MindSpeed-LLM Arguments', args)
 
diff --git a/mindspeed_llm/training/tokenizer/__init__.py b/mindspeed_llm/training/tokenizer/__init__.py
index 1be6cdf6232999b1d33f0e7fccb07c5cb0e45cd4..edb539feffe49fbe13d25ac94155082de87ab2b6 100644
--- a/mindspeed_llm/training/tokenizer/__init__.py
+++ b/mindspeed_llm/training/tokenizer/__init__.py
@@ -14,4 +14,4 @@
 # limitations under the License.
 
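As described in the --mtp-loss-scaling-factor help text above, the per-depth MTP losses are averaged and then scaled before being added to the main language-model loss. The following standalone sketch only illustrates that arithmetic; the function and variable names are hypothetical and this is not the MindSpeed-LLM implementation.

import torch

def overall_mtp_loss(per_depth_losses, scaling_factor=0.1):
    # Average the MTP losses across all depths, then apply the scaling factor;
    # the result serves as an additional training objective on top of the LM loss.
    return scaling_factor * torch.stack(per_depth_losses).mean()

# Two MTP depths (--mtp-num-layers 2) with a scaling factor of 0.3:
losses = [torch.tensor(2.1), torch.tensor(2.4)]
print(overall_mtp_loss(losses, scaling_factor=0.3))  # tensor(0.6750)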
-from .tokenizer import build_tokenizer +from mindspeed_llm.training.tokenizer.tokenizer import build_tokenizer diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py index d042c8f9c73ab9a84983fbe5c9b4e4c0c0182aa1..bc957b954d779e0a7be96218a6b624d2e33bb4c2 100644 --- a/mindspeed_llm/training/training.py +++ b/mindspeed_llm/training/training.py @@ -17,6 +17,7 @@ import gc import sys import json +from datetime import datetime from functools import wraps import time @@ -24,6 +25,7 @@ import time import torch import torch_npu +from megatron.core.transformer.moe.moe_utils import track_moe_metrics from megatron.training import get_args from megatron.training import get_timers from megatron.training import get_signal_handler @@ -38,9 +40,10 @@ from megatron.training.checkpointing import save_checkpoint from megatron.training.initialize import initialize_megatron from megatron.training.initialize import write_args_to_tensorboard from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.theoretical_memory_usage import report_theoretical_memory from megatron.training.training import ( train_step, calc_params_l2_norm, - training_log, evaluate_and_print_results, + evaluate_and_print_results, save_checkpoint_and_time, print_datetime, num_floating_point_operations, get_one_logger, append_to_progress_log, build_train_valid_test_data_iterators @@ -48,7 +51,9 @@ from megatron.training.training import ( import megatron.training.utils from megatron.training.utils import ( check_adlr_autoresume_termination, - print_rank_0 + print_rank_0, + print_rank_last, + report_memory, ) from megatron.core.distributed import DistributedDataParallel as DDP from megatron.core.distributed import finalize_model_grads @@ -697,4 +702,243 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, if exit: sys.exit() - return iteration, num_floating_point_operations_so_far \ No newline at end of file + return iteration, num_floating_point_operations_so_far + + +def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, + loss_scale, report_memory_flag, skipped_iter, + grad_norm, params_norm, num_zeros_in_grad): + """Log training information such as losses, timing, ....""" + args = get_args() + timers = get_timers() + writer = get_tensorboard_writer() + wandb_writer = get_wandb_writer() + one_logger = get_one_logger() + + # Advanced, skipped, and Nan iterations. + advanced_iters_key = 'advanced iterations' + skipped_iters_key = 'skipped iterations' + nan_iters_key = 'nan iterations' + # Advanced iterations. + if not skipped_iter: + total_loss_dict[advanced_iters_key] = total_loss_dict.get( + advanced_iters_key, 0) + 1 + else: + if advanced_iters_key not in total_loss_dict: + total_loss_dict[advanced_iters_key] = 0 + # Skipped iterations. + total_loss_dict[skipped_iters_key] = total_loss_dict.get( + skipped_iters_key, 0) + skipped_iter + # Update losses and set nan iterations + got_nan = False + for key in loss_dict: + if not skipped_iter: + total_loss_dict[key] = total_loss_dict.get( + key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key] + else: + value = loss_dict[key].float().sum().item() + is_nan = value == float('inf') or \ + value == -float('inf') or \ + value != value + got_nan = got_nan or is_nan + total_loss_dict[nan_iters_key] = total_loss_dict.get( + nan_iters_key, 0) + int(got_nan) + + # Logging. 
+ timers_to_log = [ + 'forward-backward', + 'forward-compute', + 'backward-compute', + 'batch-generator', + 'forward-recv', + 'forward-send', + 'backward-recv', + 'backward-send', + 'forward-send-forward-recv', + 'forward-send-backward-recv', + 'backward-send-forward-recv', + 'backward-send-backward-recv', + 'forward-backward-send-forward-backward-recv', + 'layernorm-grads-all-reduce', + 'embedding-grads-all-reduce', + 'all-grads-sync', + 'params-all-gather', + 'optimizer-copy-to-main-grad', + 'optimizer-unscale-and-check-inf', + 'optimizer-clip-main-grad', + 'optimizer-count-zeros', + 'optimizer-inner-step', + 'optimizer-copy-main-to-model-params', + 'optimizer'] + + # Calculate batch size. + batch_size = args.micro_batch_size * args.data_parallel_size * \ + get_num_microbatches() + + # Track app tag & app tag ID + one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length) + + total_iterations = total_loss_dict[advanced_iters_key] + \ + total_loss_dict[skipped_iters_key] + + # Tensorboard values. + # Timer requires all the ranks to call. + if args.log_timers_to_tensorboard and \ + (iteration % args.tensorboard_log_interval == 0): + timers.write(timers_to_log, writer, iteration, + normalizer=total_iterations) + if writer and (iteration % args.tensorboard_log_interval == 0): + if wandb_writer: + wandb_writer.log({'samples vs steps': args.consumed_train_samples}, + iteration) + if args.log_learning_rate_to_tensorboard: + writer.add_scalar('learning-rate', learning_rate, iteration) + if args.decoupled_lr is not None: + writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) + writer.add_scalar('learning-rate vs samples', learning_rate, + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'learning-rate': learning_rate}, iteration) + if args.log_batch_size_to_tensorboard: + writer.add_scalar('batch-size', batch_size, iteration) + writer.add_scalar('batch-size vs samples', batch_size, + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'batch-size': batch_size}, iteration) + for key in loss_dict: + writer.add_scalar(key, loss_dict[key], iteration) + writer.add_scalar(key + ' vs samples', loss_dict[key], + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({key: loss_dict[key]}, iteration) + if args.log_loss_scale_to_tensorboard: + writer.add_scalar('loss-scale', loss_scale, iteration) + writer.add_scalar('loss-scale vs samples', loss_scale, + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'loss-scale': loss_scale}, iteration) + if args.log_world_size_to_tensorboard: + writer.add_scalar('world-size', args.world_size, iteration) + writer.add_scalar('world-size vs samples', args.world_size, + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'world-size': args.world_size}, iteration) + if grad_norm is not None: + writer.add_scalar('grad-norm', grad_norm, iteration) + writer.add_scalar('grad-norm vs samples', grad_norm, + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'grad-norm': grad_norm}, iteration) + if num_zeros_in_grad is not None: + writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) + writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, + args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration) + if params_norm is not None: + writer.add_scalar('params-norm', params_norm, iteration) + writer.add_scalar('params-norm vs samples', params_norm, + 
args.consumed_train_samples) + if wandb_writer: + wandb_writer.log({'params-norm': params_norm}, iteration) + if args.log_memory_to_tensorboard: + mem_stats = torch.cuda.memory_stats() + writer.add_scalar( + "mem-reserved-bytes", + mem_stats["reserved_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-bytes", + mem_stats["allocated_bytes.all.current"], + iteration, + ) + writer.add_scalar( + "mem-allocated-count", + mem_stats["allocation.all.current"], + iteration, + ) + if args.num_experts is not None: + moe_loss_scale = 1 / get_num_microbatches() + track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging) + if args.mtp_num_layers: + from mindspeed_llm.core.transformer.multi_token_prediction import MTPLossLoggingHelper + + mtp_loss_scale = 1 / get_num_microbatches() + MTPLossLoggingHelper.track_mtp_metrics( + mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict + ) + + if iteration % args.log_interval == 0: + elapsed_time = timers('interval-time').elapsed(barrier=True) + elapsed_time_per_iteration = elapsed_time / total_iterations + + throughput = num_floating_point_operations(args, batch_size) / ( + elapsed_time_per_iteration * 10**12 * args.world_size) + + one_logger_utils.track_e2e_metrics(args.log_throughput, throughput) + + if args.log_timers_to_tensorboard: + if writer: + writer.add_scalar('iteration-time', + elapsed_time_per_iteration, iteration) + if wandb_writer: + wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, + iteration) + log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" + log_string += ' iteration {:8d}/{:8d} |'.format( + iteration, args.train_iters) + log_string += ' consumed samples: {:12d} |'.format( + args.consumed_train_samples) + log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( + elapsed_time_per_iteration * 1000.0) + if args.log_throughput: + log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' + if args.log_timers_to_tensorboard: + if writer: + writer.add_scalar('throughput', throughput, iteration) + if wandb_writer: + wandb_writer.log({'throughput': throughput}, iteration) + assert learning_rate is not None + # Decoupled_learning_rate should be not None only on first and last pipeline stage. 
+ log_string += ' learning rate: {:.6E} |'.format(learning_rate) + if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or + mpu.is_pipeline_last_stage(ignore_virtual=True)): + assert decoupled_learning_rate is not None + log_string += ' decoupled learning rate: {:.6E} |'.format(decoupled_learning_rate) + else: + assert decoupled_learning_rate is None + log_string += ' global batch size: {:5d} |'.format(batch_size) + for key in total_loss_dict: + if key not in [advanced_iters_key, skipped_iters_key, + nan_iters_key]: + avg = total_loss_dict[key].item() / \ + float(max(1, total_loss_dict[advanced_iters_key])) + if avg > 0.0: + log_string += ' {}: {:.6E} |'.format(key, avg) + total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda') + log_string += ' loss scale: {:.1f} |'.format(loss_scale) + if grad_norm is not None: + log_string += ' grad norm: {:.3f} |'.format(grad_norm) + if num_zeros_in_grad is not None: + log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) + if params_norm is not None: + log_string += ' params norm: {:.3f} |'.format(params_norm) + log_string += ' number of skipped iterations: {:3d} |'.format( + total_loss_dict[skipped_iters_key]) + log_string += ' number of nan iterations: {:3d} |'.format( + total_loss_dict[nan_iters_key]) + total_loss_dict[advanced_iters_key] = 0 + total_loss_dict[skipped_iters_key] = 0 + total_loss_dict[nan_iters_key] = 0 + print_rank_last(log_string) + if report_memory_flag and learning_rate > 0.: + # Report memory after optimizer state has been initialized. + if torch.distributed.get_rank() == 0: + num_microbatches = get_num_microbatches() + report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) + report_memory('(after {} iterations)'.format(iteration)) + report_memory_flag = False + timers.log(timers_to_log, normalizer=args.log_interval) + + return report_memory_flag \ No newline at end of file diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py index 70d1b915923919f0a9ec5ffa88ee5417bffe4e29..db7e8e21ac0a93512d9ba0f4f033adfe5e27a2ac 100644 --- a/mindspeed_llm/training/utils.py +++ b/mindspeed_llm/training/utils.py @@ -45,36 +45,49 @@ WRITE_FILE_DEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT WRITE_FILE_DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR -def compute_actual_seq_len(seq, stride=0): - """ - compute actual seq len. - MTP layer bring extra n tokens, which should be cut. 
- """ - res = list() - _args = get_args() - seq_length = _args.seq_length - batch_size = seq.shape[0] - for batch_idx in range(batch_size): - batch_seq = seq[batch_idx] - if batch_idx == 0: - zero_pos = (batch_seq == 0).nonzero()[1:].squeeze(dim=1) - else: - zero_pos = (batch_seq == 0).nonzero().squeeze(dim=1) - res.extend((zero_pos + (batch_idx * seq_length - stride)).tolist()) - batch_len = len(batch_seq) + batch_idx * seq_length - stride - if batch_len > seq_length * (batch_idx + 1): - batch_len = seq_length * (batch_idx + 1) - if batch_idx == batch_size - 1: - res.append(batch_len) +def compute_actual_seq_len(origin_seq): + seq = origin_seq.view(-1) + zero_pos = (seq == 0).nonzero()[1:].squeeze(dim=1) + res = zero_pos.tolist() + res.append(len(seq)) return res def generate_actual_seq_len(batch): + args = get_args() position_ids = batch.get('position_ids').transpose(0, 1).contiguous() set_position_ids(position_ids) position_ids = batch.get('position_ids') actual_seq_len = compute_actual_seq_len(position_ids) - set_actual_seq_len(actual_seq_len) + if args.mtp_num_layers: + seq_len = position_ids.shape[1] + mtp_res = [actual_seq_len] + for i in range(1, args.mtp_num_layers + 1): + next_actual_seq_len = [] + for j in actual_seq_len: + if j % seq_len == 0: + next_actual_seq_len.append(j) + else: + next_actual_seq_len.append(j - i) + mtp_res.append(next_actual_seq_len) + set_actual_seq_len(mtp_res) + else: + set_actual_seq_len(actual_seq_len) + + +def regenerate_position_ids(tensor, offset): + if tensor is None: + return None + tensor = tensor.clone() + for i in range(tensor.size(0)): + row = tensor[i] + zero_mask = (row == 0) + if zero_mask.any(): + first_zero_idx = torch.argmax(zero_mask.int()).item() + tensor[i, :first_zero_idx] = torch.arange(first_zero_idx) + else: + tensor = tensor - offset + return tensor def parse_args(): @@ -297,12 +310,15 @@ def get_batch_on_this_tp_rank(data_iterator): _broadcast(batch['labels']) elif mpu.is_pipeline_last_stage(): - if args.num_nextn_predict_layers or args.schedules_method == 'dualpipev': + # Multi-Token Prediction (MTP) layers need tokens and position_ids to calculate embedding. + # Currently the Multi-Token Prediction (MTP) layers is fixed on the last stage, so we need + # to broadcast tokens and position_ids to all of the tensor parallel ranks on the last stage. 
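The MTP branch added to generate_actual_seq_len above builds one actual_seq_len list per prediction depth: boundaries that fall inside a packed sample are shifted left by the depth index, while boundaries at a full sample length stay unchanged. A self-contained sketch of that adjustment with made-up numbers follows; the helper name is hypothetical and only mirrors the logic shown in the patch.

def shift_actual_seq_len(actual_seq_len, seq_len, mtp_num_layers):
    # Depth i moves every in-sample boundary left by i; boundaries that land
    # exactly on a multiple of seq_len (full samples) are kept as-is.
    res = [actual_seq_len]
    for i in range(1, mtp_num_layers + 1):
        res.append([j if j % seq_len == 0 else j - i for j in actual_seq_len])
    return res

# Two packed samples of length 8 with sub-sequence boundaries at 3 and 11:
print(shift_actual_seq_len([3, 8, 11, 16], seq_len=8, mtp_num_layers=1))
# [[3, 8, 11, 16], [2, 8, 10, 16]]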
+ if args.mtp_num_layers or args.schedules_method == 'dualpipev': _broadcast(batch['tokens']) _broadcast(batch['labels']) _broadcast(batch['loss_mask']) _broadcast(batch['attention_mask']) - if args.reset_position_ids or args.num_nextn_predict_layers or args.schedules_method == 'dualpipev': + if args.reset_position_ids or args.mtp_num_layers or args.schedules_method == 'dualpipev': _broadcast(batch['position_ids']) else: _broadcast(batch['attention_mask']) @@ -311,24 +327,24 @@ def get_batch_on_this_tp_rank(data_iterator): else: - tokens = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), + tokens = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device()) - labels = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), + labels = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device()) - loss_mask = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), + loss_mask = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.float32, device=torch.cuda.current_device()) if args.create_attention_mask_in_dataloader: attention_mask = torch.empty( - (args.micro_batch_size, 1, args.seq_length + args.num_nextn_predict_layers, - args.seq_length + args.num_nextn_predict_layers), dtype=torch.bool, + (args.micro_batch_size, 1, args.seq_length, + args.seq_length), dtype=torch.bool, device=torch.cuda.current_device() ) else: attention_mask = None - position_ids = torch.empty((args.micro_batch_size, args.seq_length + args.num_nextn_predict_layers), + position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64, device=torch.cuda.current_device()) @@ -351,14 +367,14 @@ def get_batch_on_this_tp_rank(data_iterator): loss_mask = None elif mpu.is_pipeline_last_stage(): - if args.num_nextn_predict_layers or args.schedules_method == 'dualpipev': + if args.mtp_num_layers or args.schedules_method == 'dualpipev': _broadcast(tokens) else: tokens = None _broadcast(labels) _broadcast(loss_mask) _broadcast(attention_mask) - if args.reset_position_ids or args.num_nextn_predict_layers or args.schedules_method == 'dualpipev': + if args.reset_position_ids or args.mtp_num_layers or args.schedules_method == 'dualpipev': _broadcast(position_ids) else: position_ids = None diff --git a/pretrain_gpt.py b/pretrain_gpt.py index da6067369131185ffccb0b949d0d227b6386cdbd..95c7790ecde044c436f30d78433cbcc29dd4bd58 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -19,6 +19,7 @@ from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset from megatron.core.datasets.utils import get_blend_from_list import megatron.legacy.model from megatron.core.models.gpt import GPTModel +from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec from mindspeed_llm.training import pretrain from megatron.core.transformer.spec_utils import import_module from megatron.training.utils import ( @@ -66,6 +67,9 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm) else: transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm) + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te) model = GPTModel( 
config=config, @@ -79,7 +83,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, position_embedding_type=args.position_embedding_type, rotary_percent=args.rotary_percent, - seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor, + mtp_block_spec=mtp_block_spec, ) else: if not args.context_parallel_size == 1: @@ -124,8 +129,6 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor): args = get_args() losses = output_tensor.float() - if args.num_nextn_predict_layers > 0: - loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0] loss_mask = loss_mask.view(-1).float() if args.context_parallel_size > 1: loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]) @@ -163,8 +166,12 @@ def forward_step(data_iterator, model: GPTModel): data_iterator) timers('batch-generator').stop() - output_tensor = model(tokens, position_ids, attention_mask, - labels=labels) + if args.use_legacy_models: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + else: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels, loss_mask=loss_mask) return output_tensor, partial(loss_func, loss_mask) @@ -178,7 +185,7 @@ def core_gpt_dataset_config_from_args(args): return GPTDatasetConfig( random_seed=args.seed, - sequence_length=args.seq_length + args.num_nextn_predict_layers, + sequence_length=args.seq_length, blend=get_blend_from_list(args.data_path), blend_per_split=[ get_blend_from_list(args.train_data_path), diff --git a/tests/poc/deepseek3/pretrain_deepseek3_60b_4k_128die_A3_ptd.sh b/tests/poc/deepseek3/pretrain_deepseek3_60b_4k_128die_A3_ptd.sh index 0b79e3f63eb34efd86b6fc0ac5630004cf466462..8c9303bb7d8d3074e893c6498fbbf7f872d868f0 100644 --- a/tests/poc/deepseek3/pretrain_deepseek3_60b_4k_128die_A3_ptd.sh +++ b/tests/poc/deepseek3/pretrain_deepseek3_60b_4k_128die_A3_ptd.sh @@ -70,8 +70,8 @@ MOE_ARGS=" " MTP_ARGS=" - --num-nextn-predict-layers 1 \ - --share-mtp-embedding-and-output-weight \ + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 \ --recompute-mtp-norm \ " diff --git a/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A2_ptd.sh b/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A2_ptd.sh index a38e1d56c9c1b5a8a4e792f98f0044f045867820..52513229c01c5ebda8fef49b82871f3560ca1477 100644 --- a/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A2_ptd.sh +++ b/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A2_ptd.sh @@ -78,8 +78,8 @@ MOE_ARGS=" " MTP_ARGS=" - --num-nextn-predict-layers 1 \ - --share-mtp-embedding-and-output-weight \ + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 \ --recompute-mtp-norm \ --mtp-mem-efficient-logits \ " diff --git a/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A3_ptd.sh b/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A3_ptd.sh index cf16d38c6002044f7eed704720904b7d3003b606..79396d5106770e1abba83175cbde048969175938 100644 --- a/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A3_ptd.sh +++ b/tests/poc/deepseek3/pretrain_deepseek3_671b_4k_512die_A3_ptd.sh @@ -75,8 +75,8 @@ MOE_ARGS=" " MTP_ARGS=" - --num-nextn-predict-layers 1 \ - --share-mtp-embedding-and-output-weight \ + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 \ --mtp-mem-efficient-logits \ " diff --git 
a/tests/st/baseline_results/deepseek_v3_mcore_tp1_pp2_ep4.json b/tests/st/baseline_results/deepseek_v3_mcore_tp1_pp2_ep4.json index 143752ba4d087315d061b72b6dbe08d802faee5e..58b89ffe8bbba535f56ac8ccd5ed11dfbcc5f64a 100644 --- a/tests/st/baseline_results/deepseek_v3_mcore_tp1_pp2_ep4.json +++ b/tests/st/baseline_results/deepseek_v3_mcore_tp1_pp2_ep4.json @@ -1,48 +1,48 @@ { "lm loss": [ - 2.003549E+01, - 1.849927E+01, - 1.612998E+01, - 1.494160E+01, - 1.283224E+01, - 1.213984E+01, - 1.136909E+01, - 1.092335E+01, - 1.052563E+01, - 9.903873E+00, - 9.934919E+00, - 9.675148E+00, - 9.970680E+00, - 9.665798E+00, - 9.784259E+00 + 1.005824E+01, + 9.262489E+00, + 9.022533E+00, + 8.649296E+00, + 8.244690E+00, + 8.017500E+00, + 7.664330E+00, + 7.517603E+00, + 7.426267E+00, + 6.847564E+00, + 7.110690E+00, + 6.881394E+00, + 6.999724E+00, + 6.782255E+00, + 6.936630E+00 ], "throughput": [ - 27.9, - 94.0, - 94.1, - 93.4, - 93.0, - 93.0, - 93.3, - 93.1, - 93.6, - 93.8, - 94.1, - 93.7, - 93.7, - 93.8, - 94.1 + 35.0, + 129.3, + 128.9, + 129.3, + 128.4, + 129.3, + 129.3, + 129.3, + 128.4, + 129.1, + 129.5, + 129.3, + 128.8, + 129.1, + 129.0 ], "memo info": [ { "rank": 0, - "allocated memory": 19498.04931640625, - "max allocated memory": 20394.0517578125 + "allocated memory": 16809.04541015625, + "max allocated memory": 17799.1171875 }, { "rank": 4, - "allocated memory": 37729.767578125, - "max allocated memory": 38625.77294921875 + "allocated memory": 20536.33447265625, + "max allocated memory": 22651.4599609375 } ] } \ No newline at end of file diff --git a/tests/st/shell_scripts/deepseek_v3_mcore_tp1_pp2_ep4.sh b/tests/st/shell_scripts/deepseek_v3_mcore_tp1_pp2_ep4.sh index ed009a646380495d68884a696f01fc7000e2b3bb..f96314649e73453b6168f83ba8f5dce4dd51e25c 100644 --- a/tests/st/shell_scripts/deepseek_v3_mcore_tp1_pp2_ep4.sh +++ b/tests/st/shell_scripts/deepseek_v3_mcore_tp1_pp2_ep4.sh @@ -13,8 +13,8 @@ WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) basepath=$(cd `dirname $0`; cd ../../../; pwd) DATA_PATH=/data/pretrain_dataset/alpaca_text_document -TOKENIZER_PATH=/data/deepseek-v3-mcore-tp1-pp2-ep4-32experts -CKPT_LOAD_DIR=/data/deepseek-v3-mcore-tp1-pp2-ep4-32experts +TOKENIZER_PATH=/data/deepseek-v3-mcore-tp1-pp2-ep4-16experts +CKPT_LOAD_DIR=/data/deepseek-v3-mcore-tp1-pp2-ep4-16experts TP=1 PP=2 @@ -41,12 +41,14 @@ MLA_ARGS=" --q-lora-rank 1536 \ --kv-lora-rank 512 \ --v-head-dim 128 \ - --qk-layernorm + --qk-layernorm \ + --mla-mm-split " MOE_ARGS=" --n-group 4 \ --seq-aux \ + --moe-aux-loss-coeff 0.0001 \ --moe-alltoall-overlap-comm \ --moe-grouped-gemm \ --moe-permutation-async-comm \ @@ -55,20 +57,19 @@ MOE_ARGS=" --first-k-dense-replace 1 \ --moe-layer-freq 1 \ --n-shared-experts 1 \ - --num-experts 32 \ + --num-experts 16 \ --moe-router-topk 8 \ --moe-intermediate-size 2048 \ --moe-router-load-balancing-type noaux_tc \ --topk-group 4 \ --routed-scaling-factor 2.5 \ - --norm-topk-prob \ --moe-router-score-function sigmoid \ --moe-router-enable-expert-bias " MTP_ARGS=" - --num-nextn-predict-layers 1 \ - --share-mtp-embedding-and-output-weight \ + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 \ --recompute-mtp-norm \ --recompute-mtp-layer " @@ -86,6 +87,7 @@ ROPE_ARGS=" GPT_ARGS=" --finetune \ --spec mindspeed_llm.tasks.models.spec.deepseek_spec layer_spec \ + --noop-layers 2,3 \ --recompute-granularity full \ --recompute-method uniform \ --recompute-num-layers 1 \ @@ -93,7 +95,6 @@ GPT_ARGS=" --use-distributed-optimizer \ --reuse-fp32-param \ --use-flash-attn \ - --shape-order BNSD \ --use-mcore-models 
\ --tensor-model-parallel-size ${TP} \ --pipeline-model-parallel-size ${PP} \