diff --git a/mindspeed_llm/core/distributed/finalize_model_grads.py b/mindspeed_llm/core/distributed/finalize_model_grads.py index 622e1dd28fec1ff6a1c6d00fbaf76ce8387d8a12..0a0c6f444e7f7d331aa6da8b1c1bbe94bd08a1cc 100644 --- a/mindspeed_llm/core/distributed/finalize_model_grads.py +++ b/mindspeed_llm/core/distributed/finalize_model_grads.py @@ -7,12 +7,9 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from megatron.core import parallel_state -from megatron.core.distributed.finalize_model_grads import _allreduce_layernorm_grads, _allreduce_embedding_grads from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import get_attr_wrapped_model, get_model_config -from megatron.training import get_args +from megatron.core.utils import get_attr_wrapped_model from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm -from mindspeed_llm.core.transformer.moe.moe_utils import get_updated_expert_bias def _get_main_grad_attr(param: torch.nn.Parameter, use_custom_fsdp: bool = False): @@ -68,122 +65,3 @@ def allreduce_layernorm_grads(model: List[torch.nn.Module], config: TransformerC layer_norm_2d_grads, _unflatten_dense_tensors(coalesced, layer_norm_2d_grads) ): buf.copy_(synced) - - -def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: TransformerConfig): - """ - All-reduce word embedding grads. - - Reduce grads across first and last stages to ensure that word_embeddings parameters stay in - sync. - """ - - if ( - parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1 - ): - if parallel_state.is_pipeline_first_stage(ignore_virtual=True): - model_module = model[0] - elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): - model_module = model[-1] - else: # We do not support an interleaved schedule for models with encoders yet. - model_module = model[0] - - model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) - # If share_embeddings_and_output_weights is True, we need to maintain duplicated - # embedding weights in post processing stage. If use Multi-Token Prediction (MTP), - # we also need to maintain duplicated embedding weights in mtp process stage. - # So we need to allreduce grads of embedding in the embedding group in these cases. - if model_module.share_embeddings_and_output_weights or getattr(config, 'mtp_num_layers', 0): - weight = model_module.shared_embedding_or_output_weight() - if not weight.requires_grad: - return - grad_attr = _get_main_grad_attr(weight) - grad = getattr(weight, grad_attr) - torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) - - -def _update_router_expert_bias(model: List[torch.nn.Module], config: TransformerConfig): - """ - Update the expert bias of the router for a global batch. - This requires all-reduce of local_tokens_per_expert across TPxCPxDP ranks - """ - args = get_args() - tokens_per_expert_list = [] - expert_bias_list = [] - for model_chunk in model: - for module in get_attr_wrapped_model(model_chunk, 'modules')(): - if hasattr(module, 'expert_bias'): - tokens_per_expert_list.append(module.local_tokens_per_expert) - expert_bias_list.append(module.expert_bias) - # For hybrid models with both MoE and Dense layers, this list can be empty. 
- if len(expert_bias_list) == 0: - return - stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0) - stacked_expert_bias = torch.stack(expert_bias_list, dim=0) - stacked_updated_expert_bias = get_updated_expert_bias( - stacked_tokens_per_expert, stacked_expert_bias, args.moe_router_bias_update_rate - ) - - for tokens_per_expert, expert_bias, updated_expert_bias in zip( - tokens_per_expert_list, expert_bias_list, stacked_updated_expert_bias - ): - tokens_per_expert.zero_() - expert_bias.copy_(updated_expert_bias) - - -def finalize_model_grads(model: List[torch.nn.Module], num_tokens: Optional[torch.Tensor] = None): - """ - All-reduce all model grads across DP replicas, layernorm grads for sequence parallelism, - embedding grads across first and last pipeline stages (if not tied), - scale gradients by `num_tokens`. - """ - - config = get_model_config(model[0]) - - # All-reduce / reduce-scatter across DP replicas. - if config.timers is not None: - config.timers('all-grads-sync', log_level=1).start(barrier=config.barrier_with_L1_time) - for model_chunk in model: - model_chunk.finish_grad_sync() - if config.timers is not None: - config.timers('all-grads-sync').stop() - - # All-reduce layer-norm grads (for sequence parallelism). - if config.timers is not None: - config.timers('layernorm-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_layernorm_grads(model, config) - if config.timers is not None: - config.timers('layernorm-grads-all-reduce').stop() - - # All-reduce embedding grads (for pipeline parallelism). - if config.timers is not None: - config.timers('embedding-grads-all-reduce', log_level=1).start( - barrier=config.barrier_with_L1_time - ) - _allreduce_embedding_grads(model, config) - if config.timers is not None: - config.timers('embedding-grads-all-reduce').stop() - args = get_args() - if args.moe_router_enable_expert_bias: - _update_router_expert_bias(model, config) - - # normalize gradients for per-token loss normalization. - # if we are using by the number of tokens, then we use that as a divisor. this number - # will be the total number of non-padded tokens in the global batch. - if num_tokens is not None: - # the number of tokens is only present on the last stage, so broadcast it - # to the other ranks in the pipeline parallel group. - torch.distributed.broadcast( - num_tokens, - src=parallel_state.get_pipeline_model_parallel_last_rank(), - group=parallel_state.get_pipeline_model_parallel_group(), - ) - # all-reduce across DP ranks. - torch.distributed.all_reduce(num_tokens, group=parallel_state.get_data_parallel_group()) - for model_chunk in model: - if num_tokens > 0: - scaling = 1.0 / num_tokens - model_chunk.scale_gradients(scaling) diff --git a/mindspeed_llm/core/models/common/language_module/language_module.py b/mindspeed_llm/core/models/common/language_module/language_module.py index 3eb287343fb7f9fb216a2e3945568e93b5fe53a8..2101ec6d85fa56cfbc547fc52ce60800c8672325 100644 --- a/mindspeed_llm/core/models/common/language_module/language_module.py +++ b/mindspeed_llm/core/models/common/language_module/language_module.py @@ -1,4 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import logging import torch @@ -85,58 +86,4 @@ def setup_embeddings_and_output_layer(self) -> None: "this needs to be handled manually. If you are training " "something is definitely wrong." 
) - LanguageModule.embedding_warning_printed = True - - -def tie_embeddings_and_output_weights_state_dict( - self, - sharded_state_dict: ShardedStateDict, - output_layer_weight_key: str, - first_stage_word_emb_key: str, -) -> None: - """Ties the embedding and output weights in a given sharded state dict. - - Args: - sharded_state_dict (ShardedStateDict): state dict with the weight to tie - output_layer_weight_key (str): key of the output layer weight in the state dict. - This entry will be replaced with a tied version - first_stage_word_emb_key (str): this must be the same as the - ShardedTensor.key of the first stage word embeddings. - - Returns: None, acts in-place - """ - if not self.post_process: - # No output layer - if output_layer_weight_key in sharded_state_dict or not sharded_state_dict.keys(): - raise AssertionError("Sharded state dict incorrectly initialized.") - return - - if self.pre_process: - # Output layer is equivalent to the embedding already - return - - # If use Multi-Token Prediction (MTP), we need maintain both embedding layer and output - # layer in mtp process stage. In this case, if share_embeddings_and_output_weights is True, - # the shared weights will be stored in embedding layer, and output layer will not have - # any weight. - if getattr(self, 'mtp_process', False): - # No output layer - if output_layer_weight_key in sharded_state_dict or not sharded_state_dict.keys(): - raise AssertionError("Sharded state dict incorrectly initialized.") - return - - # Replace the default output layer with a one sharing the weights with the embedding - del sharded_state_dict[output_layer_weight_key] - tensor = self.shared_embedding_or_output_weight() - last_stage_word_emb_replica_id = ( - 1, # copy of first stage embedding - 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) - - sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint( - tensor=tensor, - key=first_stage_word_emb_key, - replica_id=last_stage_word_emb_replica_id, - allow_shape_mismatch=True, - ) \ No newline at end of file + LanguageModule.embedding_warning_printed = True \ No newline at end of file diff --git a/mindspeed_llm/core/models/gpt/gpt_layer_specs.py b/mindspeed_llm/core/models/gpt/gpt_layer_specs.py index 895755aec4b8cd9db56f584506527613c30df9a4..b9bf781b31e6380f68b15f03c7627cab6dfea354 100644 --- a/mindspeed_llm/core/models/gpt/gpt_layer_specs.py +++ b/mindspeed_llm/core/models/gpt/gpt_layer_specs.py @@ -15,22 +15,11 @@ import types from functools import wraps -from typing import Union -from megatron.core.transformer import ModuleSpec from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.training import get_args -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules -from mindspeed_llm.core.transformer.transformer_layer import TransformerLayer from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm -from mindspeed_llm.core.transformer.multi_token_prediction import ( - MultiTokenPredictionBlockSubmodules, - get_mtp_layer_offset, - get_mtp_layer_spec, - get_mtp_num_layers_to_build, -) def get_gpt_layer_local_spec_wrapper(fn): @@ -65,41 +54,4 @@ def build_layers_wrapper(fn, column_forward, row_forward): local_expert.linear_fc1.forward = types.MethodType(column_forward, local_expert.linear_fc1) local_expert.linear_fc2.forward = types.MethodType(row_forward, local_expert.linear_fc2) - return wrapper - - 
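Note on the wrapper kept above: `build_layers_wrapper` patches each MoE expert after the block is built by rebinding the `forward` of its linear submodules with `types.MethodType`, so only those specific instances are affected. A minimal, runnable sketch of that rebinding pattern, using hypothetical stand-in classes and replacement forwards, looks like this:

```python
import types


class Linear:
    """Stand-in for a parallel linear module."""

    def forward(self, x):
        return f"default({x})"


def column_forward(self, x):
    # hypothetical replacement forward for linear_fc1
    return f"column({x})"


def row_forward(self, x):
    # hypothetical replacement forward for linear_fc2
    return f"row({x})"


class Expert:
    def __init__(self):
        self.linear_fc1 = Linear()
        self.linear_fc2 = Linear()


experts = [Expert(), Expert()]
for local_expert in experts:
    # Same rebinding idiom as in build_layers_wrapper: the new function becomes
    # a bound method of this particular module instance only.
    local_expert.linear_fc1.forward = types.MethodType(column_forward, local_expert.linear_fc1)
    local_expert.linear_fc2.forward = types.MethodType(row_forward, local_expert.linear_fc2)

print(experts[0].linear_fc1.forward("x"))  # column(x)
print(experts[0].linear_fc2.forward("x"))  # row(x)
```

Rebinding per instance, rather than patching the class, leaves non-expert linear layers untouched.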
-def get_gpt_mtp_block_spec( - config: TransformerConfig, - spec: Union[TransformerBlockSubmodules, ModuleSpec], - use_transformer_engine: bool, -) -> MultiTokenPredictionBlockSubmodules: - """GPT Multi-Token Prediction (MTP) block spec.""" - num_layers_to_build = get_mtp_num_layers_to_build(config) - if num_layers_to_build == 0: - return None - - if isinstance(spec, TransformerBlockSubmodules): - # get the spec for the last layer of decoder block - transformer_layer_spec = spec.layer_specs[-1] - elif isinstance(spec, ModuleSpec) and spec.module == TransformerLayer: - transformer_layer_spec = spec - else: - raise ValueError(f"Invalid spec: {spec}") - - mtp_layer_spec = get_mtp_layer_spec( - transformer_layer_spec=transformer_layer_spec, use_transformer_engine=use_transformer_engine - ) - mtp_num_layers = config.mtp_num_layers if config.mtp_num_layers else 0 - mtp_layer_specs = [mtp_layer_spec] * mtp_num_layers - - offset = get_mtp_layer_offset(config) - # split the mtp layer specs to only include the layers that are built in this pipeline stage. - mtp_layer_specs = mtp_layer_specs[offset: offset + num_layers_to_build] - if len(mtp_layer_specs) > 0: - if len(mtp_layer_specs) != config.mtp_num_layers: - raise AssertionError(f"currently all of the mtp layers must stage in the same pipeline stage.") - mtp_block_spec = MultiTokenPredictionBlockSubmodules(layer_specs=mtp_layer_specs) - else: - mtp_block_spec = None - - return mtp_block_spec + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/core/models/gpt/gpt_model.py b/mindspeed_llm/core/models/gpt/gpt_model.py index 1696967fe496756d90b75be48d4698b2b21c9803..558c51573735b4b567aa6ba355ca837264de05d9 100644 --- a/mindspeed_llm/core/models/gpt/gpt_model.py +++ b/mindspeed_llm/core/models/gpt/gpt_model.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
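For reference, the `get_gpt_mtp_block_spec` removed above (its upstream Megatron counterpart is used instead) reused the spec of the last decoder layer for every MTP layer and kept all MTP layers on a single pipeline stage, i.e. offset 0. A simplified, runnable sketch of that selection logic, with plain dataclasses standing in for Megatron's `ModuleSpec`/`TransformerBlockSubmodules` types, is:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeLayerSpec:          # stand-in for the ModuleSpec of a TransformerLayer
    name: str


@dataclass
class FakeBlockSubmodules:    # stand-in for TransformerBlockSubmodules
    layer_specs: List[FakeLayerSpec] = field(default_factory=list)


def build_mtp_block_spec(decoder_spec: FakeBlockSubmodules,
                         mtp_num_layers: int) -> Optional[FakeBlockSubmodules]:
    if mtp_num_layers == 0:
        return None
    # Reuse the spec of the *last* decoder layer for every MTP layer.
    last_layer_spec = decoder_spec.layer_specs[-1]
    mtp_layer_specs = [last_layer_spec] * mtp_num_layers
    # All MTP layers live on the last pipeline stage, so the offset is 0.
    offset = 0
    mtp_layer_specs = mtp_layer_specs[offset: offset + mtp_num_layers]
    return FakeBlockSubmodules(layer_specs=mtp_layer_specs)


decoder = FakeBlockSubmodules([FakeLayerSpec("dense"), FakeLayerSpec("moe")])
print(build_mtp_block_spec(decoder, mtp_num_layers=1))  # reuses the "moe" layer spec
```

In the removed code, the check that `len(mtp_layer_specs)` equals `config.mtp_num_layers` enforced the same single-stage constraint.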
-from typing import Literal, Optional, Dict +from typing import Literal, Optional from functools import wraps import torch from torch import Tensor -from megatron.core import InferenceParams, tensor_parallel, parallel_state -from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core import InferenceParams, tensor_parallel from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.language_module.language_module import LanguageModule @@ -30,18 +29,14 @@ from megatron.core.transformer import build_module from megatron.core.transformer.custom_layers.transformer_engine import TENorm from megatron.core.transformer import TransformerConfig, ModuleSpec from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock from megatron.core.transformer.transformer_block import TransformerBlock -from megatron.core.utils import WrappedTensor, deprecate_inference_params +from megatron.core.utils import deprecate_inference_params from megatron.core.inference.contexts import BaseInferenceContext from megatron.training import get_args from mindspeed_llm.core.tensor_parallel.layers import SegmentedColumnParallelLinear -from mindspeed_llm.training.utils import tensor_slide -from mindspeed_llm.core.transformer.multi_token_prediction import ( - MultiTokenPredictionBlock, - tie_output_layer_state_dict, - tie_word_embeddings_state_dict, -) + from mindspeed.utils import get_actual_seq_len, compute_qkv_index, get_position_ids @@ -296,66 +291,13 @@ class GPTModel(MegatronCoreGPTModel): output weights set to True or when use Multi-Token Prediction (MTP) feature. Returns: - Tensor: During pre processing or MTP process it returns the input embeddings weight. - Otherwise, during post processing it returns the final output layers weight. + Tensor: When dualpipe is enabled, return the weights from dual_chunk, otherwise follow the original logic. """ if not self.pre_process and self.post_process and get_args().schedules_method == 'dualpipev': from mindspeed.core.pipeline_parallel.dualpipev.dualpipev_schedules import \ get_shared_embedding_from_dual_chunk return get_shared_embedding_from_dual_chunk() - if self.pre_process or self.mtp_process: - # Multi-Token Prediction (MTP) need both embedding layer and output layer. - # So there will be both embedding layer and output layer in the mtp process stage. - # In this case, if share_embeddings_and_output_weights is True, the shared weights - # will be stored in embedding layer, and output layer will not have any weight. - assert hasattr( - self, 'embedding' - ), f"embedding is needed in this pipeline stage, but it is not initialized." - return self.embedding.word_embeddings.weight - elif self.post_process: - return self.output_layer.weight - return None - - def sharded_state_dict( - self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[Dict] = None - ) -> ShardedStateDict: - """Sharded state dict implementation for GPTModel backward-compatibility. - Removing extra state. - Tie word embeddings and output layer in mtp process stage. - - Args: - prefix (str): Module name prefix. - sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. - metadata (Optional[Dict]): metadata controlling sharded state dict creation. 
- - Returns: - ShardedStateDict: sharded state dict for the GPTModel - """ - sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - # Multi-Token Prediction (MTP) need both embedding layer and output layer in - # mtp process stage. - # If MTP is not placed in the pre processing stage, we need to maintain a copy of - # embedding layer in the mtp process stage and tie it to the embedding in the pre - # processing stage. - # Also, if MTP is not placed in the post processing stage, we need to maintain a copy - # of output layer in the mtp process stage and tie it to the output layer in the post - # processing stage. - if self.mtp_process and not self.pre_process: - emb_weight_key = f'{prefix}embedding.word_embeddings.weight' - emb_weight = self.embedding.word_embeddings.weight - tie_word_embeddings_state_dict(sharded_state_dict, emb_weight, emb_weight_key) - if self.mtp_process and not self.post_process: - # We only need to tie the output layer weight if share_embeddings_and_output_weights - # is False. Because if share_embeddings_and_output_weights is True, the shared weight - # will be stored in embedding layer, and output layer will not have any weight. - if not self.share_embeddings_and_output_weights: - output_layer_weight_key = f'{prefix}output_layer.weight' - output_layer_weight = self.output_layer.weight - tie_output_layer_state_dict( - sharded_state_dict, output_layer_weight, output_layer_weight_key - ) - - return sharded_state_dict + return super().shared_embedding_or_output_weight() def gpt_forward_wrapper(fn): diff --git a/mindspeed_llm/core/pipeline_parallel/schedules.py b/mindspeed_llm/core/pipeline_parallel/schedules.py index cbc173b2c105df06674333628c0a4f2dd2152d33..97a2eede53d94cb35a46d87c686b9f2957e125d5 100644 --- a/mindspeed_llm/core/pipeline_parallel/schedules.py +++ b/mindspeed_llm/core/pipeline_parallel/schedules.py @@ -20,7 +20,6 @@ from functools import wraps import torch from megatron.training import get_args from mindspeed.core.pipeline_parallel.ripipe_schedules import forward_backward_ripipe_pipelining -from mindspeed_llm.core.transformer.multi_token_prediction import MTPLossAutoScaler def get_forward_backward_func_wrapper(get_forward_backward_func): @@ -60,33 +59,4 @@ def forward_backward_pipelining_with_interleaving_wrapper(fn): if args_.virtual_pipeline_model_parallel_size is not None and args_.stage == "orm": kwargs['micro_batch_size'] = args_.micro_batch_size * 2 return fn(*args, **kwargs) - return wrapper - - -def forward_step_wrapper(fn): - @wraps(fn) - def wrapper(forward_step_func, data_iterator, model, num_microbatches, input_tensor, forward_data_store, config, *args, **kwargs): - output, num_tokens = fn(forward_step_func, data_iterator, model, num_microbatches, input_tensor, forward_data_store, config, *args, **kwargs) - - if not isinstance(input_tensor, list): - # unwrap_output_tensor True - output_tensor = output - else: - output_tensor = output[0] - - # Set the loss scale for Multi-Token Prediction (MTP) loss. - if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None: - # Calculate the loss scale based on the grad_scale_func if available, else default to 1. 
- loss_scale = ( - config.grad_scale_func(torch.ones(1, device=output_tensor.device)) - if config.grad_scale_func is not None - else torch.ones(1, device=output_tensor.device) - ) - # Set the loss scale - if config.calculate_per_token_loss: - MTPLossAutoScaler.set_loss_scale(loss_scale) - else: - MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches) - return output, num_tokens - - return wrapper + return wrapper \ No newline at end of file diff --git a/mindspeed_llm/core/transformer/moe/moe_utils.py b/mindspeed_llm/core/transformer/moe/moe_utils.py index b50b7c34ad8fccfa0bfa76b60c338e32c91b7ea3..aaac7c31464b100aafd08fe1ef6e56356da79c3f 100644 --- a/mindspeed_llm/core/transformer/moe/moe_utils.py +++ b/mindspeed_llm/core/transformer/moe/moe_utils.py @@ -216,11 +216,11 @@ def topk_softmax_with_capacity( def track_moe_metrics_wrapper(fn): @wraps(fn) - def wrapper(self, *args, **kwargs): + def wrapper(*args, **kwargs): _args = get_args() if _args.moe_router_load_balancing_type in ["none", "noaux_tc"] and not _args.seq_aux: return - fn(self, *args, **kwargs) + fn(*args, **kwargs) return wrapper diff --git a/mindspeed_llm/core/transformer/multi_token_prediction.py b/mindspeed_llm/core/transformer/multi_token_prediction.py index e7193144b6d4aa8a1b71404007de965b1076d5c3..dbd01659155510bc25fac7b3340fec14f5667159 100644 --- a/mindspeed_llm/core/transformer/multi_token_prediction.py +++ b/mindspeed_llm/core/transformer/multi_token_prediction.py @@ -1,270 +1,44 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, HUAWEI CORPORATION. All rights reserved. +from functools import wraps from contextlib import nullcontext -from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Optional import torch from torch import Tensor import acl from megatron.core import InferenceParams, mpu, parallel_state, tensor_parallel -from megatron.core.dist_checkpointing.mapping import ShardedStateDict -from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding -from megatron.core.fusions.fused_layer_norm import FusedLayerNorm +from megatron.core.extensions.transformer_engine import TEDelayedScaling from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel import ( all_gather_last_dim_from_tensor_parallel_region, scatter_to_sequence_parallel_region, ) -from megatron.core.tensor_parallel.layers import ColumnParallelLinear -from megatron.core.transformer.enums import AttnMaskType -from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.spec_utils import ModuleSpec, build_module -from megatron.core.transformer.transformer_block import TransformerBlockSubmodules +from megatron.core.transformer.spec_utils import build_module +from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper, roll_tensor, MTPLossAutoScaler from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint, make_viewless_tensor +from megatron.core.utils import make_viewless_tensor from megatron.training import get_args from mindspeed.core.tensor_parallel.random import CheckpointWithoutOutput -from mindspeed_llm.core.transformer.custom_layers.transformer_engine import PTNorm -from mindspeed_llm.training.utils import regenerate_position_ids, get_mtp_position_ids - -SUPPORTED_ATTN_MASK = [ - AttnMaskType.padding, - 
AttnMaskType.causal, - AttnMaskType.no_mask, - AttnMaskType.padding_causal, -] - - -try: - from megatron.core.extensions.transformer_engine import ( - TEColumnParallelLinear, - TEDelayedScaling, - TENorm, - ) - - HAVE_TE = True -except ImportError: - HAVE_TE = False - - -try: - import apex # pylint: disable=unused-import - - from megatron.core.fusions.fused_layer_norm import FusedLayerNorm - - HAVE_APEX = True - LNImpl = FusedLayerNorm -except ImportError: - import warnings - - warnings.warn('Apex is not installed. Falling back to Torch Norm') - LNImpl = PTNorm - - -def tie_word_embeddings_state_dict( - sharded_state_dict: ShardedStateDict, word_emb_weight: Tensor, word_emb_weight_key: str -) -> None: - """tie the embedding of the mtp processing stage in a given sharded state dict. - - Args: - sharded_state_dict (ShardedStateDict): state dict with the weight to tie. - word_emb_weight (Tensor): weight of the word embedding. - word_emb_weight_key (str): key of the word embedding in the sharded state dict. - - Returns: None, acts in-place - """ - mtp_word_emb_replica_id = ( - 1, # copy of embedding in pre processing stage - 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) - if word_emb_weight_key not in sharded_state_dict: - raise AssertionError("Word emb weight in sharded state dict.") - del sharded_state_dict[word_emb_weight_key] - sharded_state_dict[word_emb_weight_key] = make_tp_sharded_tensor_for_checkpoint( - tensor=word_emb_weight, - key=word_emb_weight_key, - replica_id=mtp_word_emb_replica_id, - allow_shape_mismatch=True, - ) - - -def tie_output_layer_state_dict( - sharded_state_dict: ShardedStateDict, output_layer_weight: Tensor, output_layer_weight_key: str -) -> None: - """tie the output layer of the mtp processing stage in a given sharded state dict. - - Args: - sharded_state_dict (ShardedStateDict): state dict with the weight to tie. - output_layer_weight (Tensor): weight of the output layer. - output_layer_weight_key (str): key of the output layer in the sharded state dict. - - Returns: None, acts in-place - """ - mtp_output_layer_replica_id = ( - 1, # copy of output layer in post processing stage - 0, - parallel_state.get_data_parallel_rank(with_context_parallel=True), - ) - if output_layer_weight_key not in sharded_state_dict: - raise AssertionError("output layer weight in sharded state dict.") - del sharded_state_dict[output_layer_weight_key] - sharded_state_dict[output_layer_weight_key] = make_tp_sharded_tensor_for_checkpoint( - tensor=output_layer_weight, - key=output_layer_weight_key, - replica_id=mtp_output_layer_replica_id, - allow_shape_mismatch=True, - ) - - -def roll_tensor(tensor, shifts=-1, dims=-1): - """Roll the tensor input along the given dimension(s). - Inserted elements are set to be 0.0. - """ - rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims) - rolled_tensor.select(dims, shifts).fill_(0) - return rolled_tensor, rolled_tensor.sum() - - -class MTPLossLoggingHelper: - """Helper class for logging MTP losses.""" - - tracker = {} - - @staticmethod - def save_loss_to_tracker( - loss: torch.Tensor, - layer_number: int, - num_layers: int, - reduce_group: torch.distributed.ProcessGroup = None, - avg_group: torch.distributed.ProcessGroup = None, - ): - """Save the mtp loss for logging. - Args: - loss (torch.Tensor): The loss tensor. - layer_number (int): Layer index of the loss. - num_layers (int): The number of total layers. - reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss. 
- avg_group (torch.distributed.ProcessGroup): The group for averaging the loss. - """ - # Skip mtp loss logging if layer_number is None. - if layer_number is None: - return - - tracker = MTPLossLoggingHelper.tracker - if "values" not in tracker: - tracker["values"] = torch.zeros(num_layers, device=loss.device) - tracker["values"][layer_number] += loss.detach() - tracker["reduce_group"] = reduce_group - tracker["avg_group"] = avg_group - - @staticmethod - def clean_loss_in_tracker(): - """Clear the mtp losses.""" - tracker = MTPLossLoggingHelper.tracker - tracker["values"].zero_() - tracker["reduce_group"] = None - tracker["avg_group"] = None - - @staticmethod - def reduce_loss_in_tracker(): - """Collect and reduce the mtp losses across ranks.""" - tracker = MTPLossLoggingHelper.tracker - if "values" not in tracker: - return - values = tracker["values"] - # Reduce mtp losses across ranks. - if tracker.get('reduce_group') is not None: - torch.distributed.all_reduce(values, group=tracker.get('reduce_group')) - if tracker.get('avg_group') is not None: - torch.distributed.all_reduce( - values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.SUM - ) - tracker["values"] = values / tracker['avg_group'].size() - - - @staticmethod - def track_mtp_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None): - """Track the Multi-Token Prediction (MTP) metrics for logging.""" - MTPLossLoggingHelper.reduce_loss_in_tracker() - tracker = MTPLossLoggingHelper.tracker - if "values" not in tracker: - return - mtp_losses = tracker["values"] * loss_scale - mtp_num_layers = mtp_losses.shape[0] - for i in range(mtp_num_layers): - name = f"mtp_{i+1} loss" - loss = mtp_losses[i] - if total_loss_dict is not None: - total_loss_dict[name] = loss - if writer is not None: - writer.add_scalar(name, loss, iteration) - if wandb_writer is not None: - wandb_writer.log({f"{name}": loss}, iteration) - - MTPLossLoggingHelper.clean_loss_in_tracker() - - -@dataclass -class MultiTokenPredictionLayerSubmodules: - """ - Dataclass for specifying the submodules of a MultiTokenPrediction module. - - Args: - hnorm (Union[ModuleSpec, type]): Specification or instance of the - hidden states normalization to be applied. - enorm (Union[ModuleSpec, type]): Specification or instance of the - embedding normalization to be applied. - eh_proj (Union[ModuleSpec, type]): Specification or instance of the - linear projection to be applied. - transformer_layer (Union[ModuleSpec, type]): Specification - or instance of the transformer block to be applied. - """ - enorm: Union[ModuleSpec, type] = None - hnorm: Union[ModuleSpec, type] = None - eh_proj: Union[ModuleSpec, type] = None - transformer_layer: Union[ModuleSpec, type] = None - layer_norm: Union[ModuleSpec, type] = None - - -def get_mtp_layer_spec( - transformer_layer_spec: ModuleSpec, use_transformer_engine: bool -) -> ModuleSpec: - """Get the MTP layer spec. 
- - Returns: - ModuleSpec: Module specification with TE modules - """ - if use_transformer_engine: - if not HAVE_TE: - raise AssertionError("transformer_engine should be installed if use_transformer_engine is True") - layer_norm_impl = TENorm - column_parallel_linear_impl = TEColumnParallelLinear - else: - layer_norm_impl = PTNorm - column_parallel_linear_impl = ColumnParallelLinear - - mtp_layer_spec = ModuleSpec( - module=MultiTokenPredictionLayer, - submodules=MultiTokenPredictionLayerSubmodules( - enorm=layer_norm_impl, - hnorm=layer_norm_impl, - eh_proj=column_parallel_linear_impl, - transformer_layer=transformer_layer_spec, - layer_norm=layer_norm_impl, - ), - ) - - return mtp_layer_spec +from mindspeed_llm.training.utils import regenerate_position_ids, get_mtp_position_ids -def get_mtp_layer_offset(config: TransformerConfig) -> int: - """Get the offset of the MTP layer.""" - # Currently, we only support put all of MTP layers on the last pipeline stage. - return 0 +def mtp_reduce_loss_in_tracker(): + """Collect and reduce the mtp losses across ranks.""" + tracker = MTPLossLoggingHelper.tracker + if "values" not in tracker: + return + values = tracker["values"] + # Reduce mtp losses across ranks. + if tracker.get('reduce_group') is not None: + torch.distributed.all_reduce(values, group=tracker.get('reduce_group')) + if tracker.get('avg_group') is not None: + torch.distributed.all_reduce( + values, group=tracker['avg_group'], op=torch.distributed.ReduceOp.SUM + ) + tracker["values"] = values / tracker['avg_group'].size() def get_mtp_num_layers_to_build(config: TransformerConfig) -> int: @@ -279,373 +53,160 @@ def get_mtp_num_layers_to_build(config: TransformerConfig) -> int: return 0 -class MTPLossAutoScaler(torch.autograd.Function): - """An AutoScaler that triggers the backward pass and scales the grad for mtp loss.""" - - main_loss_backward_scale: torch.Tensor = torch.tensor(1.0) - - @staticmethod - def forward(ctx, output: torch.Tensor, mtp_loss: torch.Tensor): - """Preserve the mtp by storing it in the context to avoid garbage collection. - - Args: - output (torch.Tensor): The output tensor. - mtp_loss (torch.Tensor): The mtp loss tensor. - - Returns: - torch.Tensor: The output tensor. - """ - ctx.save_for_backward(mtp_loss) - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - """Compute and scale the gradient for mtp loss. - - Args: - grad_output (torch.Tensor): The gradient of the output. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled mtp loss - gradient. - """ - (mtp_loss,) = ctx.saved_tensors - mtp_loss_backward_scale = MTPLossAutoScaler.main_loss_backward_scale - scaled_mtp_loss_grad = torch.ones_like(mtp_loss) * mtp_loss_backward_scale - return grad_output, scaled_mtp_loss_grad - - @staticmethod - def set_loss_scale(scale: torch.Tensor): - """set the scale of the mtp loss. - - Args: - scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in - matches the scale of the main_loss. - """ - MTPLossAutoScaler.main_loss_backward_scale = scale - - -class MultiTokenPredictionLayer(MegatronModule): - """The implementation for Multi-Token Prediction (MTP) which extends - the prediction scope to multiple future tokens at each position. - - This MTP implementation sequentially predict additional tokens and keep the complete - causal chain at each prediction depth, by using D sequential modules to predict - D additional tokens. 
- - The k-th MTP module consists of a shared embedding layer, a projection matrix, - a Transformer block, and a shared output head. - - For the i-th input token at the (k - 1)-th prediction depth, we first combine - the representation of the i-th token and the embedding of the (i + K)-th token with - the linear projection. The combined serves as the input of the Transformer block at - the k-th depth to produce the output representation. - - for more information, please refer to DeepSeek-V3 Technical Report - """ - - def __init__( - self, - config: TransformerConfig, - submodules: MultiTokenPredictionLayerSubmodules, - layer_number: int = 1, +def mtp_layer_init_wrapper(fn): + @wraps(fn) + def wrapper( + self, + config, + submodules, + layer_number, ): - super().__init__(config=config) + fn( + self, + config, + submodules, + layer_number, + ) + # fn move out of layer + self.final_layernorm = None args = get_args() - self.sequence_parallel = config.sequence_parallel - self.submodules = submodules - self.layer_number = layer_number self.recompute_mtp_norm = args.recompute_mtp_norm self.recompute_mtp_layer = args.recompute_mtp_layer - - self_attention_spec = self.submodules.transformer_layer.submodules.self_attention - attn_mask_type = self_attention_spec.params.get('attn_mask_type', '') - if attn_mask_type not in SUPPORTED_ATTN_MASK: - raise AssertionError( - f"Multi-Token Prediction (MTP) is not jet supported with " - + f"{attn_mask_type} attention mask type." - + f"The supported attention mask types are {SUPPORTED_ATTN_MASK}." - ) - - self.enorm = build_module( - self.submodules.enorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - self.hnorm = build_module( - self.submodules.hnorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - # For the linear projection at the (k - 1)-th MTP layer, the input is the concatenation - # of the i-th tocken's hidden states and the (i + K)-th tocken's decoder input, - # so the input's shape is [s, b, 2*h]. - # The output will be send to the following transformer layer, - # so the output's shape should be [s, b, h]. - self.eh_proj = build_module( - self.submodules.eh_proj, - self.config.hidden_size * 2, - self.config.hidden_size, - config=self.config, - init_method=self.config.init_method, - gather_output=False, - bias=False, - skip_bias_add=False, - is_expert=False, - ) - self.transformer_layer = build_module(self.submodules.transformer_layer, config=self.config) - # set mtp_idx for + # set mtp_idx for tnd self.transformer_layer.mtp_idx = self.layer_number self.transformer_layer.self_attention.core_attention.mtp_idx = self.layer_number - def forward( - self, + return wrapper + + +def mtp_layer_forward(self, decoder_input: Tensor, hidden_states: Tensor, attention_mask: Tensor, context: Tensor = None, context_mask: Tensor = None, rotary_pos_emb: Tensor = None, + rotary_pos_cos: Tensor = None, + rotary_pos_sin: Tensor = None, attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, - ): - """ - Perform the forward pass through the MTP layer. - - Args: - hidden_states (Tensor): hidden states tensor of shape [s, b, h] where s is the - sequence length, b is the batch size, and h is the hidden size. - decoder_input (Tensor): Input tensor of shape [s, b, h] where s is the - sequence length, b is the batch size, and h is the hidden size. 
- At the (k - 1)-th MTP module, the i-th element of decoder input is - the embedding of (i + K)-th tocken. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - context (Tensor, optional): Context tensor for cross-attention. - context_mask (Tensor, optional): Mask for cross-attention context - rotary_pos_emb (Tensor, optional): Rotary positional embeddings. - attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable - to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. - Used as an alternative to apply attention mask for TE cuDNN attention. - inference_params (InferenceParams, optional): Parameters for inference-time - optimizations. - packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence - processing. - - Returns: - Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ - if context is not None: - raise NotImplementedError(f"multi token prediction + cross attention is not yet supported.") - - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - if self.config.fp8: - import transformer_engine # To keep out TE dependency when not training in fp8 - - if self.config.fp8 == "e4m3": - fp8_format = transformer_engine.common.recipe.Format.E4M3 - elif self.config.fp8 == "hybrid": - fp8_format = transformer_engine.common.recipe.Format.HYBRID - else: - raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") - - fp8_recipe = TEDelayedScaling( - config=self.config, - fp8_format=fp8_format, - override_linear_precision=(False, False, not self.config.fp8_wgrad), - ) - fp8_group = None - if parallel_state.model_parallel_is_initialized(): - fp8_group = parallel_state.get_amax_reduction_group( - with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red - ) - fp8_context = transformer_engine.pytorch.fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group - ) - else: - fp8_context = nullcontext() + sequence_len_offset: Tensor = None,): + if context is not None: + raise NotImplementedError(f"multi token prediction + cross attention is not yet supported.") - with rng_context, fp8_context: + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - def enorm(tensor): - tensor = self.enorm(tensor) - tensor = make_viewless_tensor( - inp=tensor, requires_grad=True, keep_graph=True - ) - return tensor - - def hnorm(tensor): - tensor = self.hnorm(tensor) - tensor = make_viewless_tensor( - inp=tensor, requires_grad=True, keep_graph=True - ) - return tensor - - if self.recompute_mtp_norm: - decoder_input.requires_grad_(True) - self.enorm_ckpt = CheckpointWithoutOutput() - enorm_output = self.enorm_ckpt.checkpoint(enorm, False, decoder_input) - self.hnorm_ckpt = CheckpointWithoutOutput() - hnorm_output = self.hnorm_ckpt.checkpoint(hnorm, False, hidden_states) - else: - enorm_output = enorm(decoder_input) - hnorm_output = hnorm(hidden_states) - # At the (k - 1)-th MTP module, concatenates the i-th tocken's hidden_states - # and the (i + K)-th tocken's embedding, and combine them with linear projection. 
- hidden_states = torch.cat((enorm_output, hnorm_output), -1) - if self.recompute_mtp_norm: - self.enorm_ckpt.discard_output() - self.hnorm_ckpt.discard_output() - hidden_states.register_hook(self.enorm_ckpt.recompute) - hidden_states.register_hook(self.hnorm_ckpt.recompute) - hidden_states, _ = self.eh_proj(hidden_states) - # For tensor parallel, all gather after linear_fc. - hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states) - # For sequence parallel, scatter after linear_fc and before transformer layer. - if self.sequence_parallel: - hidden_states = scatter_to_sequence_parallel_region(hidden_states) - if self.recompute_mtp_layer: - hidden_states, _ = tensor_parallel.checkpoint( - self.transformer_layer, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - inference_params, - packed_seq_params, - ) - else: - hidden_states, _ = self.transformer_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - ) - - return hidden_states - - def sharded_state_dict( - self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None - ) -> ShardedStateDict: - """ - Generate a sharded state dictionary for the multi token prediction layer. - - Args: - prefix (str, optional): Prefix to be added to all keys in the state dict. - sharded_offsets (tuple, optional): Tuple of sharding offsets. - metadata (Optional[dict], optional): Additional metadata for sharding. - - Returns: - ShardedStateDict: A dictionary containing the sharded state of the multi - token prediction layer. - """ - sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - return sharded_state_dict - - -@dataclass -class MultiTokenPredictionBlockSubmodules: - """ - Dataclass for specifying the submodules of a multi token prediction block. - - This class defines the structure for configuring the layers, allowing for - flexible and customizable architecture designs. - - Args: - layer_specs (List[ModuleSpec], optional): A list of module specifications for - the layers within the multi token prediction block. Each specification typically - defines a complete multi token prediction layer (e.g., shared embedding, - projection matrix, transformer block, shared output head). - """ - - layer_specs: List[ModuleSpec] = None - - -def _get_mtp_block_submodules( - config: TransformerConfig, spec: Union[MultiTokenPredictionBlockSubmodules, ModuleSpec] -) -> MultiTokenPredictionBlockSubmodules: - """ - Retrieve or construct MultiTokenPredictionBlockSubmodules based on the provided specification. - - Args: - config (TransformerConfig): Configuration object for the transformer model. - spec (Union[MultiTokenPredictionBlockSubmodules, ModuleSpec]): Specification for the - multi token prediction block submodules. - Can be either a MultiTokenPredictionBlockSubmodules instance or a ModuleSpec. + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() - Returns: - MultiTokenPredictionBlockSubmodules: The submodules for the multi token prediction block. - """ + if self.config.fp8: + import transformer_engine # To keep out TE dependency when not training in fp8 - # Transformer block submodules. 
- if isinstance(spec, MultiTokenPredictionBlockSubmodules): - return spec - elif isinstance(spec, ModuleSpec): - if issubclass(spec.module, MultiTokenPredictionBlock): - return spec.submodules + if self.config.fp8 == "e4m3": + fp8_format = transformer_engine.common.recipe.Format.E4M3 + elif self.config.fp8 == "hybrid": + fp8_format = transformer_engine.common.recipe.Format.HYBRID else: - raise Exception(f"specialize for {spec.module.__name__}.") - else: - raise Exception(f"specialize for {type(spec).__name__}.") + raise ValueError("E4M3 and HYBRID are the only supported FP8 formats.") + fp8_recipe = TEDelayedScaling( + config=self.config, + fp8_format=fp8_format, + override_linear_precision=(False, False, not self.config.fp8_wgrad), + ) + fp8_group = None + if parallel_state.model_parallel_is_initialized(): + fp8_group = parallel_state.get_amax_reduction_group( + with_context_parallel=True, tp_only_amax_red=self.tp_only_amax_red + ) + fp8_context = transformer_engine.pytorch.fp8_autocast( + enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group + ) + else: + fp8_context = nullcontext() -class MultiTokenPredictionBlock(MegatronModule): - """The implementation for Multi-Token Prediction (MTP) which extends - the prediction scope to multiple future tokens at each position. + with rng_context, fp8_context: - This MTP implementation sequentially predict additional tokens and keep the complete - causal chain at each prediction depth, by using D sequential modules to predict - D additional tokens. + def enorm(tensor): + tensor = self.enorm(tensor) + tensor = make_viewless_tensor( + inp=tensor, requires_grad=True, keep_graph=True + ) + return tensor - The k-th MTP module consists of a shared embedding layer, a projection matrix, - a Transformer block, and a shared output head. + def hnorm(tensor): + tensor = self.hnorm(tensor) + tensor = make_viewless_tensor( + inp=tensor, requires_grad=True, keep_graph=True + ) + return tensor + + if self.recompute_mtp_norm: + decoder_input.requires_grad_(True) + self.enorm_ckpt = CheckpointWithoutOutput() + enorm_output = self.enorm_ckpt.checkpoint(enorm, False, decoder_input) + self.hnorm_ckpt = CheckpointWithoutOutput() + hnorm_output = self.hnorm_ckpt.checkpoint(hnorm, False, hidden_states) + else: + enorm_output = enorm(decoder_input) + hnorm_output = hnorm(hidden_states) + # At the (k - 1)-th MTP module, concatenates the i-th tocken's hidden_states + # and the (i + K)-th tocken's embedding, and combine them with linear projection. + hidden_states = torch.cat((enorm_output, hnorm_output), -1) + if self.recompute_mtp_norm: + self.enorm_ckpt.discard_output() + self.hnorm_ckpt.discard_output() + hidden_states.register_hook(self.enorm_ckpt.recompute) + hidden_states.register_hook(self.hnorm_ckpt.recompute) + hidden_states, _ = self.eh_proj(hidden_states) + # For tensor parallel, all gather after linear_fc. + hidden_states = all_gather_last_dim_from_tensor_parallel_region(hidden_states) + # For sequence parallel, scatter after linear_fc and before transformer layer. 
+ if self.sequence_parallel: + hidden_states = scatter_to_sequence_parallel_region(hidden_states) + if self.recompute_mtp_layer: + hidden_states, _ = tensor_parallel.checkpoint( + self.transformer_layer, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + attention_bias, + inference_params, + packed_seq_params, + sequence_len_offset, + ) + else: + hidden_states, _ = self.transformer_layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) - For the i-th input token at the (k - 1)-th prediction depth, we first combine - the representation of the i-th token and the embedding of the (i + K)-th token with - the linear projection. The combined serves as the input of the Transformer block at - the k-th depth to produce the output representation. + return hidden_states - for more information, please refer to DeepSeek-V3 Technical Report - """ - def __init__( - self, config: TransformerConfig, spec: Union[TransformerBlockSubmodules, ModuleSpec] - ): - super().__init__(config=config) - self.submodules = _get_mtp_block_submodules(config, spec) - self.mtp_loss_scaling_factor = config.mtp_loss_scaling_factor - self._build_layers() - if len(self.layers) == 0: - raise AssertionError("MultiTokenPredictionBlock must have at least one layer.") - - def _build_layers(self): - def build_layer(layer_spec, layer_number): - return build_module(layer_spec, config=self.config, layer_number=layer_number) - - self.layers = torch.nn.ModuleList( - [ - build_layer(layer_spec, i + 1) - for i, layer_spec in enumerate(self.submodules.layer_specs) - ] - ) +def mtp_block_build_layers_wrapper(fn): + @wraps(fn) + def wrapper(self): + fn(self) + # fn move to block self.final_layernorms = torch.nn.ModuleList( [ build_module( @@ -658,144 +219,121 @@ class MultiTokenPredictionBlock(MegatronModule): ] ) - def forward( - self, - input_ids: Tensor, - position_ids: Tensor, - hidden_states: Tensor, - attention_mask: Tensor, - labels: Tensor = None, - context: Tensor = None, - context_mask: Tensor = None, - rotary_pos_emb: Tensor = None, - inference_params: InferenceParams = None, - packed_seq_params: PackedSeqParams = None, - extra_block_kwargs: dict = None, - loss_mask: Optional[Tensor] = None, - embedding=None, - output_layer=None, - output_weight: Optional[torch.Tensor] = None, - compute_language_model_loss=None, - ) -> Tensor: - """ - Perform the forward pass through all of the MTP modules. - - Args: - hidden_states (Tensor): Hidden states for input token with the shape [s, b, h] - where s is the sequence length, b is the batch size, and h is the hidden size. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - - Returns: - (Tensor): The mtp loss tensor of shape [b, s]. 
- """ - if labels is None: - raise AssertionError(f"labels should not be None for calculating multi token prediction loss.") + return wrapper + + +def mtp_block_forward( + self, + input_ids: Tensor, + position_ids: Tensor, + hidden_states: Tensor, + attention_mask: Tensor, + labels: Tensor = None, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + rotary_pos_cos: Tensor = None, + rotary_pos_sin: Tensor = None, + attention_bias: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + sequence_len_offset: Tensor = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + loss_mask: Optional[Tensor] = None, + embedding=None, + output_layer=None, + output_weight: Optional[torch.Tensor] = None, + compute_language_model_loss=None, +) -> Tensor: + """ + Perform the forward pass through all of the MTP modules. - args = get_args() - if loss_mask is None: - # if loss_mask is not provided, use all ones as loss_mask - loss_mask = torch.ones_like(labels) - - hidden_states_main_model = hidden_states - for layer_number in range(len(self.layers)): - # Calc logits for the current Multi-Token Prediction (MTP) layers. - input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1) - if args.reset_position_ids: - if '910B' not in acl.get_soc_name() and args.enable_share_memory: - position_ids, shm_manager = get_mtp_position_ids() - if mpu.get_tensor_model_parallel_rank() == 0: - shm_manager.write(position_ids) - else: - position_ids = shm_manager.read() + Args: + hidden_states (Tensor): Hidden states for input token with the shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + + Returns: + (Tensor): The mtp loss tensor of shape [b, s]. + """ + if labels is None: + raise AssertionError(f"labels should not be None for calculating multi token prediction loss.") + + args = get_args() + if loss_mask is None: + # if loss_mask is not provided, use all ones as loss_mask + loss_mask = torch.ones_like(labels) + + hidden_states_main_model = hidden_states + for layer_number in range(len(self.layers)): + # Calc logits for the current Multi-Token Prediction (MTP) layers. 
+ input_ids, _ = roll_tensor(input_ids, shifts=-1, dims=-1) + if args.reset_position_ids: + if '910B' not in acl.get_soc_name() and args.enable_share_memory: + position_ids, shm_manager = get_mtp_position_ids() + if mpu.get_tensor_model_parallel_rank() == 0: + shm_manager.write(position_ids) else: - position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1) - position_ids = regenerate_position_ids(position_ids, 1) - # embedding - decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) - # norm, linear projection and transformer - hidden_states = self.layers[layer_number]( - decoder_input=decoder_input, - hidden_states=hidden_states, - attention_mask=attention_mask, - inference_params=inference_params, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=packed_seq_params, - **(extra_block_kwargs or {}), + position_ids = shm_manager.read() + else: + position_ids, _ = roll_tensor(position_ids, shifts=-1, dims=-1) + position_ids = regenerate_position_ids(position_ids, 1) + # embedding + decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) + # norm, linear projection and transformer + hidden_states = self.layers[layer_number]( + decoder_input=decoder_input, + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + **(extra_block_kwargs or {}), + ) + # Layer norm before shared head layer. + hidden_states = self.final_layernorms[layer_number](hidden_states) + hidden_states = make_viewless_tensor( + inp=hidden_states, requires_grad=True, keep_graph=True + ) + # output + mtp_logits, _ = output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + ) + # Calc loss for the current Multi-Token Prediction (MTP) layers. + labels, _ = roll_tensor(labels, shifts=-1, dims=-1) + loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1) + + if args.is_instruction_dataset: + mtp_labels = labels[:, 1:].contiguous() + mtp_logits = mtp_logits[:-1, :, :].contiguous() + mtp_loss_mask = loss_mask[..., 1:].view(-1).float() + num_tokens = torch.sum(mtp_loss_mask) + else: + mtp_labels = labels + mtp_loss_mask = loss_mask + + mtp_loss = compute_language_model_loss(mtp_labels, mtp_logits) + mtp_loss = mtp_loss_mask * mtp_loss + if self.training: + MTPLossLoggingHelper.save_loss_to_tracker( + torch.sum(mtp_loss) / num_tokens, + layer_number, + self.config.mtp_num_layers, + avg_group=parallel_state.get_tensor_and_context_parallel_group(), ) - # Layer norm before shared head layer. - hidden_states = self.final_layernorms[layer_number](hidden_states) - hidden_states = make_viewless_tensor( - inp=hidden_states, requires_grad=True, keep_graph=True + mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers + if self.config.calculate_per_token_loss: + hidden_states_main_model = MTPLossAutoScaler.apply( + hidden_states_main_model, mtp_loss_scale * mtp_loss ) - # output - mtp_logits, _ = output_layer( - hidden_states, weight=output_weight + else: + hidden_states_main_model = MTPLossAutoScaler.apply( + hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens ) - # Calc loss for the current Multi-Token Prediction (MTP) layers. 
- labels, _ = roll_tensor(labels, shifts=-1, dims=-1) - loss_mask, num_tokens = roll_tensor(loss_mask, shifts=-1, dims=-1) - - if args.is_instruction_dataset: - mtp_labels = labels[:, 1:].contiguous() - mtp_logits = mtp_logits[:-1, :, :].contiguous() - mtp_loss_mask = loss_mask[..., 1:].view(-1).float() - num_tokens = torch.sum(mtp_loss_mask) - else: - mtp_labels = labels - - mtp_loss = compute_language_model_loss(mtp_labels, mtp_logits) - - if args.is_instruction_dataset: - mtp_loss = mtp_loss_mask * mtp_loss - else: - mtp_loss = loss_mask * mtp_loss - - if self.training: - MTPLossLoggingHelper.save_loss_to_tracker( - torch.sum(mtp_loss) / num_tokens, - layer_number, - self.config.mtp_num_layers, - avg_group=parallel_state.get_tensor_and_context_parallel_group(), - ) - mtp_loss_scale = self.mtp_loss_scaling_factor / self.config.mtp_num_layers - if self.config.calculate_per_token_loss: - hidden_states_main_model = MTPLossAutoScaler.apply( - hidden_states_main_model, mtp_loss_scale * mtp_loss - ) - else: - hidden_states_main_model = MTPLossAutoScaler.apply( - hidden_states_main_model, mtp_loss_scale * mtp_loss / num_tokens - ) - return hidden_states_main_model - - def sharded_state_dict( - self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None - ) -> ShardedStateDict: - """ - Generate a sharded state dictionary for the multi token prediction module. - - Args: - prefix (str, optional): Prefix to be added to all keys in the state dict. - sharded_offsets (tuple, optional): Tuple of sharding offsets. - metadata (Optional[dict], optional): Additional metadata for sharding. - - Returns: - ShardedStateDict: A dictionary containing the sharded state of the multi - token prediction module. - """ - sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - layer_prefix = f'{prefix}layers.' - for layer in self.layers: - offset = get_mtp_layer_offset(self.config) - sharded_prefix = f'{layer_prefix}{layer.layer_number - 1 }.' - - state_dict_prefix = f'{layer_prefix}{layer.layer_number - 1 - offset}.' 
- sharded_pp_offset = [] - layer_sharded_state_dict = layer.sharded_state_dict( - state_dict_prefix, sharded_pp_offset, metadata - ) - replace_prefix_for_sharding(layer_sharded_state_dict, state_dict_prefix, sharded_prefix) - sharded_state_dict.update(layer_sharded_state_dict) - return sharded_state_dict + return hidden_states_main_model \ No newline at end of file diff --git a/mindspeed_llm/core/transformer/transformer_config.py b/mindspeed_llm/core/transformer/transformer_config.py index cd7420c29c98493e5962c42d64894078a8ea9e7b..4605142efe1b4775ebfe78b84e993d1d419680e5 100644 --- a/mindspeed_llm/core/transformer/transformer_config.py +++ b/mindspeed_llm/core/transformer/transformer_config.py @@ -7,18 +7,6 @@ import torch.nn.functional as F from megatron.training import get_args -def transformer_config_post_init_mtp_wrapper(fn): - @wraps(fn) - def wrapper(self): - fn(self) - args = get_args() - - self.mtp_num_layers = args.mtp_num_layers - self.mtp_loss_scaling_factor = args.mtp_loss_scaling_factor - - return wrapper - - def transformer_config_post_init_wrapper(fn): @wraps(fn) def wrapper(self): diff --git a/mindspeed_llm/core/transformer/transformer_layer.py b/mindspeed_llm/core/transformer/transformer_layer.py index 190da41c840a45ffc9d489eb1d9a999261fbfb1c..a768e13d573d28bd5ef9f2e3b596bd8453af6a73 100644 --- a/mindspeed_llm/core/transformer/transformer_layer.py +++ b/mindspeed_llm/core/transformer/transformer_layer.py @@ -14,6 +14,8 @@ # limitations under the License. import math + +from megatron.core import tensor_parallel from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules from megatron.core.utils import WrappedTensor, deprecate_inference_params from megatron.core.transformer.transformer_layer import TransformerLayer as MegatronTransformerLayer @@ -56,79 +58,23 @@ class TransformerLayer(MegatronTransformerLayer): self.mtp_idx = 0 self.self_attention.core_attention.mtp_idx = 0 - def forward(self, hidden_states, attention_mask, context=None, - context_mask=None, - rotary_pos_emb=None, - inference_params=None, - attention_bias=None, - inference_context=None, - packed_seq_params=None): - - inference_context = deprecate_inference_params(inference_context, inference_params) - - # hidden_states: [s, b, h] + def _forward_mlp(self, pre_mlp_layernorm_output, residual): args = get_args() - # Residual connection. - residual = hidden_states - - # Optional Input Layer norm - input_layernorm_output = self.input_layernorm(hidden_states) - - if args.input_layernorm_in_fp32: - input_layernorm_output = input_layernorm_output.float() - - # Self attention. - attention_output_with_bias = self.self_attention( - input_layernorm_output, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - packed_seq_params=packed_seq_params, - ) - - if args.scale_depth is not None: - attention_output, attention_bias = attention_output_with_bias - attention_output = attention_output * (args.scale_depth / math.sqrt(args.num_layers)) - attention_output_with_bias = (attention_output, attention_bias) - - # inside the module provided in the `bias_dropout_add_spec` module? - with self.bias_dropout_add_exec_handler(): - hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( - attention_output_with_bias, residual, self.hidden_dropout + # MLP. 
+ if self.recompute_mlp: + mlp_output_with_bias = tensor_parallel.checkpoint( + self.mlp, False, pre_mlp_layernorm_output ) + else: + mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - # Residual connection. - residual = hidden_states - - # Optional Layer norm after self-attention - pre_cross_attn_layernorm_output = self.pre_cross_attn_layernorm(hidden_states) - - # Cross attention. - attention_output_with_bias = self.cross_attention( - pre_cross_attn_layernorm_output, - attention_mask=context_mask, - key_value_states=context, - inference_context=inference_context, - ) - - if isinstance(attention_output_with_bias, dict) and "context" in attention_output_with_bias: - context = attention_output_with_bias["context"] - - # inside the module provided in the `bias_dropout_add_spec` module? - with self.bias_dropout_add_exec_handler(): - hidden_states = self.cross_attn_bda(self.training, self.config.bias_dropout_fusion)( - attention_output_with_bias, residual, self.hidden_dropout + if self.recompute_pre_mlp_layernorm: + # discard the output of the pre-mlp layernorm and register the recompute + # as a gradient hook of mlp_output_with_bias[0] + self.pre_mlp_norm_checkpoint.discard_output_and_register_recompute( + mlp_output_with_bias[0] ) - # Residual connection. - residual = hidden_states - - # Optional Layer norm post the cross-attention. - pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) - - # MLP. - mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) - if args.scale_depth is not None: mlp_output, mlp_bias = mlp_output_with_bias mlp_output = mlp_output * (args.scale_depth / math.sqrt(args.num_layers)) @@ -150,4 +96,4 @@ class TransformerLayer(MegatronTransformerLayer): inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) - return output, context + return output diff --git a/mindspeed_llm/features_manager/common/embedding.py b/mindspeed_llm/features_manager/common/embedding.py index d8ef34bc7e4a0400aeed30ad7f23beaba42c0048..179c22201a5763c245da903dc1a492958794aa98 100644 --- a/mindspeed_llm/features_manager/common/embedding.py +++ b/mindspeed_llm/features_manager/common/embedding.py @@ -8,19 +8,6 @@ class LanguageModelEmbeddingFeature(MindSpeedFeature): def register_patches(self, patch_manager, args): from mindspeed.core.models.common.embeddings.language_model_embedding import language_model_embedding_forward_wrapper - from mindspeed_llm.core.models.common.language_module.language_module import ( - setup_embeddings_and_output_layer, - tie_embeddings_and_output_weights_state_dict, - ) - - patch_manager.register_patch( - 'megatron.core.models.common.language_module.language_module.LanguageModule' - '.setup_embeddings_and_output_layer', - setup_embeddings_and_output_layer) - patch_manager.register_patch( - 'megatron.core.models.common.language_module.language_module.LanguageModule' - '.tie_embeddings_and_output_weights_state_dict', - tie_embeddings_and_output_weights_state_dict) patch_manager.register_patch( 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.forward', language_model_embedding_forward_wrapper) \ No newline at end of file diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 3b2513b1d642991adbd84a31ceb3a19a8425cff1..aac72a016c69469331b86c5c9cc89abb11e21490 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -187,6 +187,7 @@ class CoreAdaptation(MegatronAdaptationABC): 
self.patch_pipeline_parallel_schedules() self.patch_swap_optimizer() self.patch_sft() + self.patch_mtp() def patch_core_distributed(self): import megatron.core @@ -195,15 +196,6 @@ class CoreAdaptation(MegatronAdaptationABC): MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_layernorm_grads', allreduce_layernorm_grads) - # Mtp share embedding - from mindspeed_llm.core.distributed.finalize_model_grads import _allreduce_word_embedding_grads - MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads._allreduce_word_embedding_grads', - _allreduce_word_embedding_grads) - # expert bias - from mindspeed_llm.core.distributed.finalize_model_grads import finalize_model_grads - MegatronAdaptation.register('megatron.core.distributed.finalize_model_grads.finalize_model_grads', - finalize_model_grads) - # optim relative. from mindspeed.core.distributed.param_and_grad_buffer import reuse_fp32_param_param_and_grad_buffer_init_wrapper MegatronAdaptation.register('megatron.core.distributed.param_and_grad_buffer._ParamAndGradBuffer.__init__', @@ -256,7 +248,6 @@ class CoreAdaptation(MegatronAdaptationABC): dot_product_attention_forward_wrapper, ulysses_context_parallel_forward_wrapper from ..core.models.gpt.gpt_model import GPTModel - args = MegatronAdaptation.get_args() # Embedding MegatronAdaptation.register( @@ -303,11 +294,6 @@ class CoreAdaptation(MegatronAdaptationABC): # moe_fb_overlap will shadow this forward impl MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel', GPTModel) - - from ..core.models.common.embeddings.language_model_embedding import language_model_embedding_init_func - MegatronAdaptation.register( - 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__', - language_model_embedding_init_func) # For recomputation if args.share_kvstates: from mindspeed_llm.core.transformer.transformer_block import share_kvstates_checkpointed_forward_func @@ -341,10 +327,6 @@ class CoreAdaptation(MegatronAdaptationABC): from ..core.transformer.transformer_config import transformer_config_post_init_wrapper MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__', transformer_config_post_init_wrapper) - # for mtp - from ..core.transformer.transformer_config import transformer_config_post_init_mtp_wrapper - MegatronAdaptation.register('megatron.core.transformer.transformer_config.TransformerConfig.__post_init__', - transformer_config_post_init_mtp_wrapper) MegatronAdaptation.register('torch.cuda.get_device_capability', get_device_capability) megatron.core.transformer.transformer_block.LayerNormImpl = PTNorm MegatronAdaptation.register('megatron.core.transformer.transformer_block.TENorm', PTNorm) @@ -538,10 +520,6 @@ class CoreAdaptation(MegatronAdaptationABC): from ..core.pipeline_parallel.schedules import get_forward_backward_func_wrapper MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.get_forward_backward_func', get_forward_backward_func_wrapper) - # for mtp - from ..core.pipeline_parallel.schedules import forward_step_wrapper - MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step', forward_step_wrapper) - def patch_tensor_parallel(self): from mindspeed.core.megatron_basic.megatron_basic import _set_cuda_rng_state from ..core import vocab_parallel_embedding_forward, vocab_embedding_init_func, checkpoint_forward_wrapper, \ @@ -682,8 +660,6 @@ class CoreAdaptation(MegatronAdaptationABC): 
from mindspeed_llm.training.utils import unwrap_model_wrapper MegatronAdaptation.register('megatron.training.checkpointing.unwrap_model', unwrap_model_wrapper) MegatronAdaptation.register('megatron.training.training.unwrap_model', unwrap_model_wrapper) - from ..training.training import training_log - MegatronAdaptation.register('megatron.training.training.training_log', training_log) from mindspeed_llm.training.utils import generate_adaptive_cp_mask_list_by_user, generate_adaptive_cp_grid_mask_by_user MegatronAdaptation.register('mindspeed.core.context_parallel.utils.generate_adaptive_cp_mask_list_by_user', @@ -743,6 +719,50 @@ class CoreAdaptation(MegatronAdaptationABC): set_forward_func(forward_step_in_sft_with_dualpipe) + def patch_mtp(self): + import megatron + from ..core import PTNorm + from ..core.transformer.multi_token_prediction import get_mtp_num_layers_to_build, \ + mtp_reduce_loss_in_tracker + from ..core.models.common.language_module.language_module import setup_embeddings_and_output_layer + + # dualpipe do not need to init embedding weight + from ..core.models.common.embeddings.language_model_embedding import language_model_embedding_init_func + MegatronAdaptation.register( + 'megatron.core.models.common.embeddings.language_model_embedding.LanguageModelEmbedding.__init__', + language_model_embedding_init_func) + # mtp compatibility + megatron.core.transformer.multi_token_prediction.LNImpl = PTNorm + MegatronAdaptation.register( + 'megatron.core.transformer.multi_token_prediction.MTPLossLoggingHelper.reduce_loss_in_tracker', + mtp_reduce_loss_in_tracker) + MegatronAdaptation.register( + 'megatron.core.transformer.multi_token_prediction.get_mtp_num_layers_to_build', + get_mtp_num_layers_to_build) + MegatronAdaptation.register( + 'megatron.core.models.common.language_module.language_module.LanguageModule' + '.setup_embeddings_and_output_layer', + setup_embeddings_and_output_layer) + # patch for mtp + from ..core.transformer.multi_token_prediction import ( + mtp_layer_init_wrapper, + mtp_layer_forward, + mtp_block_build_layers_wrapper, + mtp_block_forward, + ) + MegatronAdaptation.register( + 'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionLayer.__init__', + mtp_layer_init_wrapper) + MegatronAdaptation.register( + 'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionLayer.forward', + mtp_layer_forward) + MegatronAdaptation.register( + 'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionBlock._build_layers', + mtp_block_build_layers_wrapper) + MegatronAdaptation.register( + 'megatron.core.transformer.multi_token_prediction.MultiTokenPredictionBlock.forward', + mtp_block_forward) + class LegacyAdaptation(MegatronAdaptationABC): """ diff --git a/mindspeed_llm/tasks/models/transformer/multi_head_latent_attention.py b/mindspeed_llm/tasks/models/transformer/multi_head_latent_attention.py index 73167991f85ae929a692e20a08376e45aef3044b..815a0d79051ecb0162f6a5dfcb87df0b4c183a61 100644 --- a/mindspeed_llm/tasks/models/transformer/multi_head_latent_attention.py +++ b/mindspeed_llm/tasks/models/transformer/multi_head_latent_attention.py @@ -233,13 +233,15 @@ class MultiHeadLatentAttention(SelfAttention): hidden_states, attention_mask, key_value_states=None, - inference_params=None, inference_context=None, rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, attention_bias=None, packed_seq_params=None, + sequence_len_offset=None, + *, + inference_params=None, ): """ Do patch for repeating KV so that GQA+Ulysses is better 
supported. diff --git a/mindspeed_llm/tasks/posttrain/base/base_trainer.py b/mindspeed_llm/tasks/posttrain/base/base_trainer.py index 16b69da2b5705e78f4bc7f3cf3afac06adf418fb..620f3f7197a831d3296d9e946815bf5e1b942a73 100644 --- a/mindspeed_llm/tasks/posttrain/base/base_trainer.py +++ b/mindspeed_llm/tasks/posttrain/base/base_trainer.py @@ -17,10 +17,10 @@ from megatron.core.transformer.spec_utils import import_module from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, ) from megatron.core.models.gpt import GPTModel from megatron.training.checkpointing import save_checkpoint -from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec from mindspeed_llm.training import build_train_args from mindspeed_llm.training import train from mindspeed_llm.training.initialize import set_jit_fusion_options diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index faa582ea9b2ea40abe01e9023aea313698f51a30..d59d68d5adf03cb00fe0ad9ecdc3273f3f101028 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -167,16 +167,6 @@ def _add_deepseek_moe_args(parser): def _add_mtp_args(parser): group = parser.add_argument_group(title='multi token prediction') - group.add_argument('--mtp-num-layers', type=int, default=None, - help='Number of Multi-Token Prediction (MTP) Layers.' - 'MTP extends the prediction scope to multiple future tokens at each position.' - 'This MTP implementation sequentially predict additional tokens ' - 'by using D sequential modules to predict D additional tokens.') - group.add_argument('--mtp-loss-scaling-factor', type=float, default=0.1, - help='Scaling factor of Multi-Token Prediction (MTP) loss. ' - 'We compute the average of the MTP losses across all depths, ' - 'and multiply it the scaling factor to obtain the overall MTP loss, ' - 'which serves as an additional training objective.') group.add_argument('--recompute-mtp-norm', action='store_true', default=False, help='Multi-Token prediction recompute norm') group.add_argument('--recompute-mtp-layer', action='store_true', default=False, @@ -1454,16 +1444,6 @@ def _validate_fused_opts(args): '--position-embedding-type=rope') -def _validate_mtp_args(args): - if args.mtp_num_layers: - assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." - assert args.context_parallel_size == 1, "Multi-Token Prediction (MTP) is not supported with Context Parallelism." - assert args.position_embedding_type == "rope" or args.position_embedding_type == "none", ( - f"Multi-Token Prediction (MTP) is not supported with {args.position_embedding_type} position embedding type." - + f"The supported position embedding types are rope and none." 
- ) - - def validate_args_decorator(megatron_validate_args): @wraps(megatron_validate_args) def wrapper(args, defaults=None): @@ -1510,9 +1490,6 @@ def validate_args_decorator(megatron_validate_args): _valid_tp_2d_args(args) _valid_fa_div_args(args) _add_dummy_args(args) - # remove in future megatron version - _validate_mtp_args(args) - _add_dummy_args_v2(args) for feature in FEATURES_LIST: diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py index 69b678929dd2d8a97ef56d45e4f5f66603b05596..950423f59ae5b8d70754baf24cd1123f5b6d0508 100644 --- a/mindspeed_llm/training/training.py +++ b/mindspeed_llm/training/training.py @@ -40,8 +40,7 @@ from megatron.training.checkpointing import save_checkpoint from megatron.training.initialize import initialize_megatron from megatron.training.initialize import write_args_to_tensorboard from megatron.training.arguments import core_transformer_config_from_args -from megatron.training.theoretical_memory_usage import report_theoretical_memory -from megatron.training.training import disable_forward_pre_hook, enable_forward_pre_hook +from megatron.training.training import disable_forward_pre_hook, enable_forward_pre_hook, training_log from megatron.training.training import ( train_step, calc_params_l2_norm, evaluate_and_print_results, @@ -740,245 +739,6 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, return iteration, num_floating_point_operations_so_far -def training_log(loss_dict, total_loss_dict, learning_rate, decoupled_learning_rate, iteration, - loss_scale, report_memory_flag, skipped_iter, - grad_norm, params_norm, num_zeros_in_grad): - """Log training information such as losses, timing, ....""" - args = get_args() - timers = get_timers() - writer = get_tensorboard_writer() - wandb_writer = get_wandb_writer() - one_logger = get_one_logger() - - # Advanced, skipped, and Nan iterations. - advanced_iters_key = 'advanced iterations' - skipped_iters_key = 'skipped iterations' - nan_iters_key = 'nan iterations' - # Advanced iterations. - if not skipped_iter: - total_loss_dict[advanced_iters_key] = total_loss_dict.get( - advanced_iters_key, 0) + 1 - else: - if advanced_iters_key not in total_loss_dict: - total_loss_dict[advanced_iters_key] = 0 - # Skipped iterations. - total_loss_dict[skipped_iters_key] = total_loss_dict.get( - skipped_iters_key, 0) + skipped_iter - # Update losses and set nan iterations - got_nan = False - for key in loss_dict: - if not skipped_iter: - total_loss_dict[key] = total_loss_dict.get( - key, torch.tensor([0.0], dtype=torch.float, device='cuda')) + loss_dict[key] - else: - value = loss_dict[key].float().sum().item() - is_nan = value == float('inf') or \ - value == -float('inf') or \ - value != value - got_nan = got_nan or is_nan - total_loss_dict[nan_iters_key] = total_loss_dict.get( - nan_iters_key, 0) + int(got_nan) - - # Logging. 
- timers_to_log = [ - 'forward-backward', - 'forward-compute', - 'backward-compute', - 'batch-generator', - 'forward-recv', - 'forward-send', - 'backward-recv', - 'backward-send', - 'forward-send-forward-recv', - 'forward-send-backward-recv', - 'backward-send-forward-recv', - 'backward-send-backward-recv', - 'forward-backward-send-forward-backward-recv', - 'layernorm-grads-all-reduce', - 'embedding-grads-all-reduce', - 'all-grads-sync', - 'params-all-gather', - 'optimizer-copy-to-main-grad', - 'optimizer-unscale-and-check-inf', - 'optimizer-clip-main-grad', - 'optimizer-count-zeros', - 'optimizer-inner-step', - 'optimizer-copy-main-to-model-params', - 'optimizer'] - - # Calculate batch size. - batch_size = args.micro_batch_size * args.data_parallel_size * \ - get_num_microbatches() - - # Track app tag & app tag ID - one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length) - - total_iterations = total_loss_dict[advanced_iters_key] + \ - total_loss_dict[skipped_iters_key] - - # Tensorboard values. - # Timer requires all the ranks to call. - if args.log_timers_to_tensorboard and \ - (iteration % args.tensorboard_log_interval == 0): - timers.write(timers_to_log, writer, iteration, - normalizer=total_iterations) - if writer and (iteration % args.tensorboard_log_interval == 0): - if wandb_writer: - wandb_writer.log({'samples vs steps': args.consumed_train_samples}, - iteration) - if args.log_learning_rate_to_tensorboard: - writer.add_scalar('learning-rate', learning_rate, iteration) - if args.decoupled_lr is not None: - writer.add_scalar('decoupled-learning-rate', decoupled_learning_rate, iteration) - writer.add_scalar('learning-rate vs samples', learning_rate, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'learning-rate': learning_rate}, iteration) - if args.log_batch_size_to_tensorboard: - writer.add_scalar('batch-size', batch_size, iteration) - writer.add_scalar('batch-size vs samples', batch_size, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'batch-size': batch_size}, iteration) - for key in loss_dict: - writer.add_scalar(key, loss_dict[key], iteration) - writer.add_scalar(key + ' vs samples', loss_dict[key], - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({key: loss_dict[key]}, iteration) - if args.log_loss_scale_to_tensorboard: - writer.add_scalar('loss-scale', loss_scale, iteration) - writer.add_scalar('loss-scale vs samples', loss_scale, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'loss-scale': loss_scale}, iteration) - if args.log_world_size_to_tensorboard: - writer.add_scalar('world-size', args.world_size, iteration) - writer.add_scalar('world-size vs samples', args.world_size, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'world-size': args.world_size}, iteration) - if grad_norm is not None: - writer.add_scalar('grad-norm', grad_norm, iteration) - writer.add_scalar('grad-norm vs samples', grad_norm, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'grad-norm': grad_norm}, iteration) - if num_zeros_in_grad is not None: - writer.add_scalar('num-zeros', num_zeros_in_grad, iteration) - writer.add_scalar('num-zeros vs samples', num_zeros_in_grad, - args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'num-zeros': num_zeros_in_grad}, iteration) - if params_norm is not None: - writer.add_scalar('params-norm', params_norm, iteration) - writer.add_scalar('params-norm vs samples', params_norm, - 
args.consumed_train_samples) - if wandb_writer: - wandb_writer.log({'params-norm': params_norm}, iteration) - if args.log_memory_to_tensorboard: - mem_stats = torch.cuda.memory_stats() - writer.add_scalar( - "mem-reserved-bytes", - mem_stats["reserved_bytes.all.current"], - iteration, - ) - writer.add_scalar( - "mem-allocated-bytes", - mem_stats["allocated_bytes.all.current"], - iteration, - ) - writer.add_scalar( - "mem-allocated-count", - mem_stats["allocation.all.current"], - iteration, - ) - if args.num_experts is not None: - moe_loss_scale = 1 / get_num_microbatches() - track_moe_metrics(moe_loss_scale, iteration, writer, wandb_writer, total_loss_dict, args.moe_per_layer_logging) - if args.mtp_num_layers: - from mindspeed_llm.core.transformer.multi_token_prediction import MTPLossLoggingHelper - - mtp_loss_scale = 1 / get_num_microbatches() - MTPLossLoggingHelper.track_mtp_metrics( - mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict - ) - - if iteration % args.log_interval == 0: - elapsed_time = timers('interval-time').elapsed(barrier=True) - elapsed_time_per_iteration = elapsed_time / total_iterations - - throughput = num_floating_point_operations(args, batch_size) / ( - elapsed_time_per_iteration * 10**12 * args.world_size) - - one_logger_utils.track_e2e_metrics(args.log_throughput, throughput) - - if args.log_timers_to_tensorboard: - if writer: - writer.add_scalar('iteration-time', - elapsed_time_per_iteration, iteration) - if wandb_writer: - wandb_writer.log({'iteration-time': elapsed_time_per_iteration}, - iteration) - log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]" - log_string += ' iteration {:8d}/{:8d} |'.format( - iteration, args.train_iters) - log_string += ' consumed samples: {:12d} |'.format( - args.consumed_train_samples) - log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( - elapsed_time_per_iteration * 1000.0) - if args.log_throughput: - log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' - if args.log_timers_to_tensorboard: - if writer: - writer.add_scalar('throughput', throughput, iteration) - if wandb_writer: - wandb_writer.log({'throughput': throughput}, iteration) - assert learning_rate is not None - # Decoupled_learning_rate should be not None only on first and last pipeline stage. 
- log_string += ' learning rate: {:.6E} |'.format(learning_rate) - if args.decoupled_lr is not None and (mpu.is_pipeline_first_stage(ignore_virtual=True) or - mpu.is_pipeline_last_stage(ignore_virtual=True)): - assert decoupled_learning_rate is not None - log_string += ' decoupled learning rate: {:.6E} |'.format(decoupled_learning_rate) - else: - assert decoupled_learning_rate is None - log_string += ' global batch size: {:5d} |'.format(batch_size) - for key in total_loss_dict: - if key not in [advanced_iters_key, skipped_iters_key, - nan_iters_key]: - avg = total_loss_dict[key].item() / \ - float(max(1, total_loss_dict[advanced_iters_key])) - if avg > 0.0: - log_string += ' {}: {:.6E} |'.format(key, avg) - total_loss_dict[key] = torch.tensor([0.0], dtype=torch.float, device='cuda') - log_string += ' loss scale: {:.1f} |'.format(loss_scale) - if grad_norm is not None: - log_string += ' grad norm: {:.3f} |'.format(grad_norm) - if num_zeros_in_grad is not None: - log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) - if params_norm is not None: - log_string += ' params norm: {:.3f} |'.format(params_norm) - log_string += ' number of skipped iterations: {:3d} |'.format( - total_loss_dict[skipped_iters_key]) - log_string += ' number of nan iterations: {:3d} |'.format( - total_loss_dict[nan_iters_key]) - total_loss_dict[advanced_iters_key] = 0 - total_loss_dict[skipped_iters_key] = 0 - total_loss_dict[nan_iters_key] = 0 - print_rank_last(log_string) - if report_memory_flag and learning_rate > 0.: - # Report memory after optimizer state has been initialized. - if torch.distributed.get_rank() == 0: - num_microbatches = get_num_microbatches() - report_theoretical_memory(args, num_microbatches=num_microbatches, verbose=True) - report_memory('(after {} iterations)'.format(iteration)) - report_memory_flag = False - timers.log(timers_to_log, normalizer=args.log_interval) - - return report_memory_flag - - def should_disable_forward_pre_hook(args): """Block forward pre-hook for certain configurations.""" return not args.use_custom_fsdp and args.use_distributed_optimizer and args.overlap_param_gather \ No newline at end of file diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py index a5c8678fb421bd7cdf33f0dcd151ab148515fc72..294d70f49dd93b173d8132693cfdf179a0ac95d7 100644 --- a/mindspeed_llm/training/utils.py +++ b/mindspeed_llm/training/utils.py @@ -828,32 +828,4 @@ def _get_batch_on_this_cp_rank_in_megatron_cp_general(batch): val = val.chunk(cp_size, dim=seq_dim)[cp_rank].contiguous() batch[key] = val - return batch - - -def tensor_slide( - tensor: Optional[torch.Tensor], - slice_num: int, - dims: Union[int, List[int]] = -1, - step: int = 1, - return_first=False, -) -> List[Union[torch.Tensor, None]]: - """通用滑动窗口函数,支持任意维度""" - if tensor is None: - # return `List[None]` to avoid NoneType Error - return [None] * (slice_num + 1) - if slice_num == 0: - return [tensor] - window_size = tensor.shape[-1] - slice_num - dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True) - - # 连续多维度滑动 - slices = [] - for i in range(0, tensor.size(dims[-1]) - window_size + 1, step): - slice_obj = [slice(None)] * tensor.dim() - for dim in dims: - slice_obj[dim] = slice(i, i + window_size) - slices.append(tensor[tuple(slice_obj)]) - if return_first: - return slices - return slices + return batch \ No newline at end of file diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 0a4c4f48bf41b1b91b0325597ea22eb1600023d0..fb6195b633e49b53c6c45286815e1351d35f1c6c 100644 --- 
a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -19,7 +19,6 @@ from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset from megatron.core.datasets.utils import get_blend_from_list import megatron.legacy.model from megatron.core.models.gpt import GPTModel -from mindspeed_llm.core.models.gpt.gpt_layer_specs import get_gpt_mtp_block_spec from mindspeed_llm.training import pretrain from megatron.core.transformer.spec_utils import import_module from megatron.training.utils import ( @@ -32,6 +31,7 @@ from megatron.training.yaml_arguments import core_transformer_config_from_yaml from megatron.core.models.gpt.gpt_layer_specs import ( get_gpt_layer_local_spec, get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, ) from mindspeed_llm.training.utils import generate_actual_seq_len
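
Editor's note (appended, not part of the patch): this change drops the local MindSpeed copies of the MTP helpers (finalize-grads hooks, `get_gpt_mtp_block_spec`, `--mtp-num-layers` arguments, the custom `training_log`) and instead relies on Megatron core's own `multi_token_prediction` support, with the remaining adaptations consolidated under `CoreAdaptation.patch_mtp()`. The sketch below illustrates how the relocated `get_gpt_mtp_block_spec` import is typically consumed in a model provider after this refactor. It is a minimal sketch under assumptions: the exact `get_gpt_mtp_block_spec` signature and the `mtp_block_spec` keyword of `GPTModel` are taken from recent Megatron-LM core releases and should be verified against the Megatron version pinned by this repository; `build_gpt_model` is a hypothetical helper name.

    # Minimal sketch, assuming the upstream Megatron core MTP API described above.
    from megatron.core.models.gpt import GPTModel
    from megatron.core.models.gpt.gpt_layer_specs import (
        get_gpt_layer_with_transformer_engine_spec,
        get_gpt_mtp_block_spec,
    )

    def build_gpt_model(config, vocab_size, max_sequence_length, use_te=True):
        """Hypothetical model provider showing where the MTP block spec is wired in."""
        layer_spec = get_gpt_layer_with_transformer_engine_spec()
        # Only build an MTP block spec when mtp_num_layers is configured
        # (now parsed by Megatron core itself, e.g. via --mtp-num-layers).
        mtp_block_spec = None
        if getattr(config, "mtp_num_layers", None):
            mtp_block_spec = get_gpt_mtp_block_spec(
                config, layer_spec, use_transformer_engine=use_te
            )
        return GPTModel(
            config=config,
            transformer_layer_spec=layer_spec,
            vocab_size=vocab_size,
            max_sequence_length=max_sequence_length,
            mtp_block_spec=mtp_block_spec,  # None when MTP is disabled
        )

With the MTP layers built from this spec, the patches registered in `patch_mtp()` (layer init/forward, block build/forward, loss-tracker reduction, and the embedding/output-layer setup) take effect through `MegatronAdaptation` without any of the previously duplicated grad-finalization or logging code.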