From a09c24c1d912eefa518df299dbd80e8cd38e7178 Mon Sep 17 00:00:00 2001 From: shengjiayi Date: Thu, 13 Mar 2025 20:02:11 +0800 Subject: [PATCH] mtp in decoder --- mindspeed_llm/core/__init__.py | 2 +- .../core/distributed/finalize_model_grads.py | 5 +- mindspeed_llm/core/models/gpt/gpt_model.py | 26 +- .../core/pipeline_parallel/schedules.py | 636 ++++++++++++++++++ mindspeed_llm/core/tensor_parallel/layers.py | 54 +- .../core/transformer/dot_product_attention.py | 3 + .../core/transformer/transformer_block.py | 35 +- mindspeed_llm/tasks/megatron_adaptor.py | 10 +- mindspeed_llm/tasks/models/mtp_modules.py | 359 ++++++++++ mindspeed_llm/tasks/posttrain/utils.py | 5 +- mindspeed_llm/training/arguments.py | 28 +- mindspeed_llm/training/utils.py | 38 +- pretrain_gpt.py | 83 ++- 13 files changed, 1230 insertions(+), 54 deletions(-) create mode 100644 mindspeed_llm/tasks/models/mtp_modules.py diff --git a/mindspeed_llm/core/__init__.py b/mindspeed_llm/core/__init__.py index 1831c630a..c80da6663 100644 --- a/mindspeed_llm/core/__init__.py +++ b/mindspeed_llm/core/__init__.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_embedding_init_func, checkpoint_forward_wrapper, checkpoint_backward_wrapper +from .tensor_parallel.layers import vocab_parallel_embedding_forward, vocab_embedding_init_func from .parallel_state import (initialize_model_parallel_decorator, destroy_model_parallel_decorator, get_expert_model_parallel_rank, get_expert_model_parallel_world_size, get_expert_parallel_group, diff --git a/mindspeed_llm/core/distributed/finalize_model_grads.py b/mindspeed_llm/core/distributed/finalize_model_grads.py index b00d19690..430ba15f4 100644 --- a/mindspeed_llm/core/distributed/finalize_model_grads.py +++ b/mindspeed_llm/core/distributed/finalize_model_grads.py @@ -41,8 +41,9 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf weight = model_module.shared_embedding_or_output_weight() grad = weight.main_grad torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) - if hasattr(model_module, - "share_mtp_embedding_and_output_weight") and model_module.share_mtp_embedding_and_output_weight: + if hasattr(model_module, "share_mtp_embedding_and_output_weight") \ + and model_module.share_mtp_embedding_and_output_weight \ + and not (hasattr(model_module, "mtp_in_decoder") and model_module.mtp_in_decoder): weight = model_module.shared_embedding_weight() grad = weight.main_grad torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) diff --git a/mindspeed_llm/core/models/gpt/gpt_model.py b/mindspeed_llm/core/models/gpt/gpt_model.py index 0f5c0570b..7707fa145 100644 --- a/mindspeed_llm/core/models/gpt/gpt_model.py +++ b/mindspeed_llm/core/models/gpt/gpt_model.py @@ -38,7 +38,7 @@ def gpt_model_init_wrapper(fn): def wrapper(self, *args, **kwargs): post_layer_norm = kwargs.pop('post_layer_norm', True) fn(self, *args, **kwargs) - config = args[1] if len(args) > 1 else kwargs['config'] + config = args[0] if len(args) >= 1 else kwargs['config'] arguments = get_args() if self.post_process and arguments.add_output_layer_bias: self.output_layer = tensor_parallel.ColumnParallelLinear( @@ -73,7 +73,10 @@ def gpt_model_init_wrapper(fn): self.decoder.post_layer_norm = False self.num_nextn_predict_layers = arguments.num_nextn_predict_layers self.share_mtp_embedding_and_output_weight = 
arguments.share_mtp_embedding_and_output_weight - if self.post_process and self.training and self.num_nextn_predict_layers: + self.mtp_use_new_token = arguments.mtp_use_new_token + self.mtp_in_decoder = arguments.mtp_in_decoder + + if self.post_process and self.training and self.num_nextn_predict_layers and not self.mtp_in_decoder: self.mtp_layers = torch.nn.ModuleList( [ MultiTokenPredication( @@ -96,7 +99,7 @@ def gpt_model_init_wrapper(fn): ] ) - if self.post_process and self.num_nextn_predict_layers: + if self.post_process and self.num_nextn_predict_layers and not self.mtp_in_decoder: # move block main model final norms here self.final_layernorm = build_module( TENorm, @@ -107,7 +110,7 @@ def gpt_model_init_wrapper(fn): else: self.final_layernorm = None - if self.pre_process or self.post_process: + if (self.pre_process or self.post_process) and not self.mtp_in_decoder: setup_mtp_embeddings_layer(self) return wrapper @@ -216,12 +219,13 @@ def gpt_model_forward(self, input_ids: Tensor, # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. args = get_args() # generate inputs for main and mtps - input_ids, labels, position_ids, attention_mask = inputs_slice( - args.num_nextn_predict_layers, - input_ids, - labels, - position_ids, - attention_mask) + if self.mtp_use_new_token: + input_ids, labels, position_ids, attention_mask = inputs_slice( + args.num_nextn_predict_layers, + input_ids, + labels, + position_ids, + attention_mask) if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings # Decoder embedding. @@ -264,7 +268,7 @@ def gpt_model_forward(self, input_ids: Tensor, loss = 0 # Multi token predication module - if args.num_nextn_predict_layers and self.training: + if args.num_nextn_predict_layers and self.training and not self.mtp_in_decoder: if not self.share_embeddings_and_output_weights and self.share_mtp_embedding_and_output_weight: output_weight = self.output_layer.weight output_weight.zero_out_wgrad = True diff --git a/mindspeed_llm/core/pipeline_parallel/schedules.py b/mindspeed_llm/core/pipeline_parallel/schedules.py index 1a54655b6..af524a734 100644 --- a/mindspeed_llm/core/pipeline_parallel/schedules.py +++ b/mindspeed_llm/core/pipeline_parallel/schedules.py @@ -17,7 +17,18 @@ import contextlib from functools import wraps +from typing import Iterator, List, Union import torch + +from megatron.core import parallel_state +from megatron.core.enums import ModelType +from megatron.core.pipeline_parallel import p2p_communication +from megatron.core.pipeline_parallel.schedules import clear_embedding_activation_buffer, forward_step, \ + check_first_val_step, backward_step, deallocate_output_tensor, finish_embedding_wgrad_compute +from megatron.core.utils import ( + get_model_config, + get_model_type, +) from megatron.training import get_args from mindspeed.core.pipeline_parallel.ripipe_schedules import forward_backward_ripipe_pipelining @@ -48,6 +59,631 @@ def forward_backward_func_wrapper(fn): return wrapper +def forward_backward_pipelining_with_interleaving_and_mtp( + *, + forward_step_func, + data_iterator: Union[Iterator, List[Iterator]], + model: Union[torch.nn.Module, List[torch.nn.Module]], + num_microbatches: int, + seq_length: int, + micro_batch_size: int, + decoder_seq_length: int = None, + forward_only: bool = False, + collect_non_loss_data: bool = False, + first_val_step: bool = None, +): + """Run interleaved 1F1B 
schedule (model split into model chunks), with + communication between pipeline stages as needed. + + Returns dictionary with losses if the last stage, empty dict otherwise.""" + assert isinstance(model, list), "interleaved pipeline parallelism expected model chunking" + assert all(isinstance(chunk, torch.nn.Module) for chunk in model), "invalid model chunking" + assert isinstance( + data_iterator, list + ), "interleaved pipeline parallelism expected each model chunk to have a data iterator" + + config = get_model_config(model[0]) + if config.overlap_p2p_comm and config.batch_p2p_comm: + raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") + + # Needed only when gradients are finalized in M-Core + if config.finalize_model_grads_func is not None and not forward_only: + embedding_module = clear_embedding_activation_buffer(config, model) + + if config.timers is not None: + config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + + # Disable async grad reductions + no_sync_func = config.no_sync_func + if isinstance(no_sync_func, list): + + def multi_no_sync(): + stack = contextlib.ExitStack() + for model_chunk_no_sync_func in config.no_sync_func: + stack.enter_context(model_chunk_no_sync_func()) + return stack + + no_sync_func = multi_no_sync + if no_sync_func is None: + no_sync_func = contextlib.nullcontext + no_sync_context = None + + if config.grad_sync_func is not None and not isinstance(config.grad_sync_func, list): + config.grad_sync_func = [config.grad_sync_func for _ in model] + + if config.param_sync_func is not None and not isinstance(config.param_sync_func, list): + config.param_sync_func = [config.param_sync_func for _ in model] + + def disable_grad_sync(): + """Disable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is None: + no_sync_context = no_sync_func() + no_sync_context.__enter__() + + def enable_grad_sync(): + """Enable asynchronous grad reductions""" + nonlocal no_sync_context + if no_sync_context is not None: + no_sync_context.__exit__(None, None, None) + no_sync_context = None + + disable_grad_sync() + + # Model chunk IDs with synchronized grads + synchronized_model_chunks = set() + + input_tensors = [[] for _ in range(len(model))] + output_tensors = [[] for _ in range(len(model))] + total_num_tokens = torch.tensor(0, dtype=torch.int).cuda() + + forward_data_store = [] + if not forward_only: + output_tensor_grads = [[] for _ in range(len(model))] + + pipeline_parallel_size = parallel_state.get_pipeline_model_parallel_world_size() + pipeline_parallel_rank = parallel_state.get_pipeline_model_parallel_rank() + + if num_microbatches % pipeline_parallel_size != 0: + msg = f'number of microbatches ({num_microbatches}) is not divisible by ' + msg += f'pipeline-model-parallel-size ({pipeline_parallel_size}) ' + msg += 'when using interleaved schedule' + raise RuntimeError(msg) + + model_type = get_model_type(model[0]) + if model_type == ModelType.encoder_and_decoder: + raise RuntimeError("Interleaving is not supported with an encoder and decoder model.") + + if decoder_seq_length is not None and decoder_seq_length != seq_length: + raise RuntimeError( + "Interleaving is not supported with a different decoder sequence length." 
+ ) + + tensor_shape = [seq_length, micro_batch_size, config.hidden_size] + tensor_shape[0] = tensor_shape[0] // parallel_state.get_context_parallel_world_size() + if config.sequence_parallel: + tensor_shape[0] = tensor_shape[0] // parallel_state.get_tensor_model_parallel_world_size() + + # In MTP models, all predictions are sent through PP + # we need to extend the shape by its batch dimension to cooperate with output layers + args = get_args() + if args.num_nextn_predict_layers and args.mtp_in_decoder: + tensor_shape[1] *= args.num_nextn_predict_layers + 1 + + # Compute number of warmup and remaining microbatches. + num_model_chunks = len(model) + total_num_microbatches = num_microbatches * num_model_chunks + all_warmup_microbatches = False + if forward_only: + num_warmup_microbatches = total_num_microbatches + else: + # Run all forward passes and then all backward passes if number of + # microbatches is just the number of pipeline stages. + # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on + # all workers, followed by more microbatches after depending on + # stage ID (more forward passes for earlier stages, later stages can + # immediately start with 1F1B). + if num_microbatches == pipeline_parallel_size: + num_warmup_microbatches = total_num_microbatches + all_warmup_microbatches = True + else: + num_warmup_microbatches = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2 + num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size + num_warmup_microbatches = min(num_warmup_microbatches, total_num_microbatches) + num_microbatches_remaining = total_num_microbatches - num_warmup_microbatches + + # Checkpoint the activations of partial Transformer layers in a number of micro-batches + # within the maximum outstanding micro-batch backpropagations. + # Micro-batches with the ids less than 'num_microbatches_with_partial_activation_checkpoints' + # checkpoint partial Transformer layers (or skip checkpointing) and + # the rest of micro-batches within a window of micro-batches checkpoint + # all Transformer layers. The window of micro-batches is set by the maximum + # outstanding backpropagations and becomes smaller at later pipeline stages. 
+ # Please refer the appendix C in https://arxiv.org/pdf/2205.05198.pdf + max_outstanding_backprops = None + if config.num_microbatches_with_partial_activation_checkpoints is not None: + max_outstanding_backprops = num_warmup_microbatches + 1 + + # Synchronize params for first two model chunks + if config.param_sync_func is not None: + config.param_sync_func[0](model[0].parameters()) + config.param_sync_func[1](model[1].parameters()) + + def get_model_chunk_id(microbatch_id, forward): + """Helper method to get the model chunk ID given the iteration number.""" + microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks) + model_chunk_id = microbatch_id_in_group // pipeline_parallel_size + if not forward: + model_chunk_id = num_model_chunks - model_chunk_id - 1 + return model_chunk_id + + def get_microbatch_id_in_model_chunk(iteration_id, forward): + """Helper method to get the microbatch_id within model chunk given the iteration number.""" + assert forward + iteration_group_id = iteration_id // (pipeline_parallel_size * num_model_chunks) + microbatch_id_in_model_chunk = (iteration_group_id * pipeline_parallel_size) + ( + iteration_id % pipeline_parallel_size + ) + return microbatch_id_in_model_chunk + + def is_first_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the first for a model chunk.""" + microbatch_group_size = pipeline_parallel_size * num_model_chunks + num_microbatch_groups = total_num_microbatches // microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == 0: + return microbatch_id_in_group % pipeline_parallel_size == 0 + else: + return False + + def is_last_microbatch_for_model_chunk(microbatch_id: int) -> bool: + """Check if an iteration is the last for a model chunk.""" + microbatch_group_size = pipeline_parallel_size * num_model_chunks + num_microbatch_groups = total_num_microbatches // microbatch_group_size + microbatch_group_id = microbatch_id // microbatch_group_size + microbatch_id_in_group = microbatch_id % microbatch_group_size + if microbatch_group_id == num_microbatch_groups - 1: + return microbatch_id_in_group % pipeline_parallel_size == pipeline_parallel_size - 1 + else: + return False + + def forward_step_helper(microbatch_id, current_microbatch, checkpoint_activations_microbatch): + """Helper method to run forward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + forward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + # launch param synchronization for next model chunk + # Note: Asynchronous communication tends to slow down compute. + # To reduce idling from mismatched microbatch times, we launch + # asynchronous communication at the same time across the + # pipeline-parallel group. 
+ if config.param_sync_func is not None: + param_sync_microbatch_id = microbatch_id + pipeline_parallel_rank + if ( + param_sync_microbatch_id < total_num_microbatches + and is_first_microbatch_for_model_chunk(param_sync_microbatch_id) + ): + param_sync_chunk_id = get_model_chunk_id(param_sync_microbatch_id, forward=True) + 1 + if 1 < param_sync_chunk_id < num_model_chunks: + config.param_sync_func[param_sync_chunk_id]( + model[param_sync_chunk_id].parameters() + ) + + # forward step + if parallel_state.is_pipeline_first_stage(): + if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]): + input_tensors[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id][-1] + + output_tensor, num_tokens = forward_step( + forward_step_func, + data_iterator[model_chunk_id], + model[model_chunk_id], + num_microbatches, + input_tensor, + forward_data_store, + config, + collect_non_loss_data, + checkpoint_activations_microbatch, + check_first_val_step( + first_val_step, + forward_only, + is_first_microbatch_for_model_chunk(microbatch_id), + ), + current_microbatch=current_microbatch, + ) + output_tensors[model_chunk_id].append(output_tensor) + + nonlocal total_num_tokens + total_num_tokens += num_tokens.item() + + # if forward-only, no need to save tensors for a backward pass + if forward_only: + input_tensors[model_chunk_id].pop() + output_tensors[model_chunk_id].pop() + + return output_tensor + + def backward_step_helper(microbatch_id): + """Helper method to run backward step with model split into chunks + (run set_virtual_pipeline_model_parallel_rank() before calling + backward_step()).""" + model_chunk_id = get_model_chunk_id(microbatch_id, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(model_chunk_id) + + # launch grad synchronization (default) + if config.grad_sync_func is None and is_last_microbatch_for_model_chunk(microbatch_id): + enable_grad_sync() + synchronized_model_chunks.add(model_chunk_id) + + if parallel_state.is_pipeline_last_stage(): + if len(output_tensor_grads[model_chunk_id]) == 0: + output_tensor_grads[model_chunk_id].append(None) + input_tensor = input_tensors[model_chunk_id].pop(0) + output_tensor = output_tensors[model_chunk_id].pop(0) + output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) + input_tensor_grad = backward_step( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + # launch grad synchronization (custom grad sync) + # Note: Asynchronous communication tends to slow down compute. + # To reduce idling from mismatched microbatch times, we launch + # asynchronous communication at the same time across the + # pipeline-parallel group. + if config.grad_sync_func is not None: + grad_sync_microbatch_id = microbatch_id - pipeline_parallel_rank + if grad_sync_microbatch_id >= 0 and is_last_microbatch_for_model_chunk( + grad_sync_microbatch_id + ): + grad_sync_chunk_id = get_model_chunk_id(grad_sync_microbatch_id, forward=False) + enable_grad_sync() + config.grad_sync_func[grad_sync_chunk_id](model[grad_sync_chunk_id].parameters()) + synchronized_model_chunks.add(grad_sync_chunk_id) + disable_grad_sync() + + return input_tensor_grad + + # Run warmup forward passes. 
+ parallel_state.set_virtual_pipeline_model_parallel_rank(0) + input_tensors[0].append(p2p_communication.recv_forward(tensor_shape, config)) + + fwd_wait_handles = None + bwd_wait_handles = None + + for k in range(num_warmup_microbatches): + + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + cur_model_chunk_id = get_model_chunk_id(k, forward=True) + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + k % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + current_microbatch = get_microbatch_id_in_model_chunk(k, forward=True) + output_tensor = forward_step_helper( + k, current_microbatch, checkpoint_activations_microbatch + ) + + # Determine if tensor should be received from previous stage. + next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + if next_forward_model_chunk_id == 0: + recv_prev = False + if k == (total_num_microbatches - 1): + recv_prev = False + + # Don't send tensor downstream if on last stage. + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + # Send and receive tensors as appropriate (send tensors computed + # in this iteration; receive tensors for next iteration). + if not config.overlap_p2p_comm: + if ( + k == (num_warmup_microbatches - 1) + and not forward_only + and not all_warmup_microbatches + ): + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + ( + input_tensor, + output_tensor_grad, + ) = p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + else: + input_tensor = p2p_communication.send_forward_recv_forward( + output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, config=config + ) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + else: + input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + + if ( + k == (num_warmup_microbatches - 1) + and not forward_only + and not all_warmup_microbatches + ): + input_tensor_grad = None + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + recv_next = False + + ( + output_tensor_grad, + bwd_wait_handles, + ) = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + + output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + input_tensors[next_forward_model_chunk_id].append(input_tensor) + + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + # Run 1F1B in steady state. + for k in range(num_microbatches_remaining): + # Forward pass. 
+ forward_k = k + num_warmup_microbatches + + # Decide to checkpoint all layers' activations of the current micro-batch + if max_outstanding_backprops is not None: + checkpoint_activations_microbatch = ( + forward_k % max_outstanding_backprops + >= config.num_microbatches_with_partial_activation_checkpoints + ) + else: + checkpoint_activations_microbatch = None + + cur_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + current_microbatch = get_microbatch_id_in_model_chunk(forward_k, forward=True) + if config.overlap_p2p_comm: + if fwd_wait_handles is not None: + for req in fwd_wait_handles: + req.wait() + + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + output_tensor = forward_step_helper( + forward_k, current_microbatch, checkpoint_activations_microbatch + ) + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + + # Last virtual stage no activation tensor to send + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True + ) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Send activation tensor to the next stage and receive activation tensor from the + # previous stage + input_tensor, fwd_wait_handles = p2p_communication.send_forward_recv_forward( + output_tensor, + recv_prev=recv_prev, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + # assert fwd_wait_handles is not None + + if bwd_wait_handles is not None: + for req in bwd_wait_handles: + req.wait() + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + + # First virtual stage no activation gradient tensor to send + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if the current virtual stage has an activation gradient tensor to receive + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). 
+ next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + output_tensor_grad, bwd_wait_handles = p2p_communication.send_backward_recv_backward( + input_tensor_grad, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + overlap_p2p_comm=True, + ) + + else: # no p2p overlap + output_tensor = forward_step_helper( + forward_k, current_microbatch, checkpoint_activations_microbatch + ) + + # Backward pass. + backward_k = k + input_tensor_grad = backward_step_helper(backward_k) + + # Send output_tensor and input_tensor_grad, receive input_tensor + # and output_tensor_grad. + + # Determine if current stage has anything to send in either direction, + # otherwise set tensor to None. + forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) + parallel_state.set_virtual_pipeline_model_parallel_rank(forward_model_chunk_id) + if parallel_state.is_pipeline_last_stage(): + output_tensor = None + + backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) + parallel_state.set_virtual_pipeline_model_parallel_rank(backward_model_chunk_id) + if parallel_state.is_pipeline_first_stage(): + input_tensor_grad = None + + # Determine if peers are sending, and where in data structure to put + # received tensors. + recv_prev = True + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + # First stage is ahead of last stage by (pipeline_parallel_size - 1). + next_forward_model_chunk_id = get_model_chunk_id( + forward_k - (pipeline_parallel_size - 1), forward=True + ) + if next_forward_model_chunk_id == (num_model_chunks - 1): + recv_prev = False + next_forward_model_chunk_id += 1 + else: + next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True) + + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Last stage is ahead of first stage by (pipeline_parallel_size - 1). + next_backward_model_chunk_id = get_model_chunk_id( + backward_k - (pipeline_parallel_size - 1), forward=False + ) + if next_backward_model_chunk_id == 0: + recv_next = False + next_backward_model_chunk_id -= 1 + else: + next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False) + + # If last iteration, don't receive; we already received one extra + # before the start of the for loop. + if k == (num_microbatches_remaining - 1): + recv_prev = False + + # Communicate tensors. + ( + input_tensor, + output_tensor_grad, + ) = p2p_communication.send_forward_backward_recv_forward_backward( + output_tensor, + input_tensor_grad, + recv_prev=recv_prev, + recv_next=recv_next, + tensor_shape=tensor_shape, + config=config, + ) + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + # Put input_tensor and output_tensor_grad in data structures in the + # right location. + if recv_prev: + input_tensors[next_forward_model_chunk_id].append(input_tensor) + if recv_next: + output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) + + deallocate_output_tensor(output_tensor, config.deallocate_pipeline_outputs) + + # Run cooldown backward passes (flush out pipeline). 
+ if not forward_only: + if config.overlap_p2p_comm and bwd_wait_handles is not None: + for wait_handle in bwd_wait_handles: + wait_handle.wait() + + if all_warmup_microbatches: + output_tensor_grads[num_model_chunks - 1].append( + p2p_communication.recv_backward(tensor_shape, config=config) + ) + for k in range(num_microbatches_remaining, total_num_microbatches): + input_tensor_grad = backward_step_helper(k) + next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) + recv_next = True + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + if next_backward_model_chunk_id == (num_model_chunks - 1): + recv_next = False + if k == (total_num_microbatches - 1): + recv_next = False + output_tensor_grads[next_backward_model_chunk_id].append( + p2p_communication.send_backward_recv_backward( + input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, config=config + ) + ) + + # Launch any remaining grad reductions. + enable_grad_sync() + if config.grad_sync_func is not None: + for model_chunk_id in range(num_model_chunks): + if model_chunk_id not in synchronized_model_chunks: + config.grad_sync_func[model_chunk_id](model[model_chunk_id].parameters()) + synchronized_model_chunks.add(model_chunk_id) + + if config.finalize_model_grads_func is not None and not forward_only: + + # If defer_embedding_wgrad_compute is enabled we need to do the + # weight gradient GEMM's here. + finish_embedding_wgrad_compute(config, embedding_module) + + # Finalize model grads (perform full grad all-reduce / reduce-scatter for + # data parallelism, layernorm all-reduce for sequence parallelism, and + # embedding all-reduce for pipeline parallelism). + config.finalize_model_grads_func( + model, total_num_tokens if config.calculate_per_token_loss else None + ) + + if config.timers is not None: + config.timers('forward-backward').stop() + + return forward_data_store + + def forward_backward_pipelining_with_interleaving_wrapper(fn): @wraps(fn) def wrapper(*args, **kwargs): diff --git a/mindspeed_llm/core/tensor_parallel/layers.py b/mindspeed_llm/core/tensor_parallel/layers.py index d59b71c88..45bb55139 100644 --- a/mindspeed_llm/core/tensor_parallel/layers.py +++ b/mindspeed_llm/core/tensor_parallel/layers.py @@ -25,6 +25,7 @@ from megatron.core.tensor_parallel.mappings import ( reduce_from_tensor_model_parallel_region, ) from megatron.core.tensor_parallel.utils import VocabUtility +from megatron.core.transformer import TransformerConfig, MegatronModule from megatron.training import get_args from megatron.core.tensor_parallel import ( copy_to_tensor_model_parallel_region, @@ -38,6 +39,10 @@ from megatron.core.tensor_parallel.layers import ( _initialize_affine_weight_gpu, VocabParallelEmbedding, ) +from megatron.core.tensor_parallel.random import ( + get_cuda_rng_tracker, + get_data_parallel_rng_tracker_name, +) from megatron.legacy.model.fused_layer_norm import MixedFusedLayerNorm from megatron.core import parallel_state, ModelParallelConfig @@ -241,17 +246,46 @@ class SegmentedColumnParallelLinear(ColumnParallelLinear): return output, output_bias -def checkpoint_forward_wrapper(fn): - def wrapper(ctx, run_function, distribute_saved_activations, *args): - ctx.actual_seq_len = get_actual_seq_len() - return fn(ctx, run_function, distribute_saved_activations, *args) +class SequenceParallelLinear(MegatronModule): + def __init__( + self, + input_size, + output_size, + *, + config: TransformerConfig, + init_method: Callable, + bias=True, + skip_bias_add=True, + **_, + ): + 
super().__init__(config) - return wrapper + assert not config.use_cpu_initialization + self.skip_bias_add = skip_bias_add -def checkpoint_backward_wrapper(fn): - def wrapper(ctx, *args): - set_actual_seq_len(ctx.actual_seq_len) - return fn(ctx, *args) + device = torch.cuda.current_device() + dtype = config.params_dtype + factory_kwargs = {'device': device, 'dtype': dtype} + self.weight = Parameter(torch.empty((output_size, input_size), **factory_kwargs)) - return wrapper + if bias: + self.bias = Parameter(torch.empty(output_size, **factory_kwargs)) + if config.perform_initialization: + # Always initialize bias to zero. + with torch.no_grad(): + self.bias.zero_() + setattr(self.bias, "sequence_parallel", config.sequence_parallel) + else: + self.register_parameter('bias', None) + + if config.perform_initialization: + with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()): + init_method(self.weight) + setattr(self.weight, "sequence_parallel", config.sequence_parallel) + + def forward(self, hidden_states): + if self.skip_bias_add: + return F.linear(hidden_states, self.weight), self.bias + else: + return F.linear(hidden_states, self.weight, self.bias), None \ No newline at end of file diff --git a/mindspeed_llm/core/transformer/dot_product_attention.py b/mindspeed_llm/core/transformer/dot_product_attention.py index 18156fa51..8f2bb63c7 100644 --- a/mindspeed_llm/core/transformer/dot_product_attention.py +++ b/mindspeed_llm/core/transformer/dot_product_attention.py @@ -409,6 +409,9 @@ def flash_attention_forward( if self.scale_mask_softmax.scale is None else self.softmax_scale actual_seq_len = get_actual_seq_len() + if actual_seq_len is not None and args.mtp_use_new_token and args.mtp_in_decoder: + actual_seq_len = actual_seq_len[self.mtp_idx] + if args.context_parallel_size > 1 and args.context_parallel_algo in ['megatron_cp_algo', 'hybrid_cp_algo', 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo']: query, key, value = [rearrange(x, 's b h d -> s b (h d)') for x in [query, key, value]] diff --git a/mindspeed_llm/core/transformer/transformer_block.py b/mindspeed_llm/core/transformer/transformer_block.py index 39527d137..77bb0e4dc 100644 --- a/mindspeed_llm/core/transformer/transformer_block.py +++ b/mindspeed_llm/core/transformer/transformer_block.py @@ -23,7 +23,7 @@ from megatron.core import InferenceParams, tensor_parallel, parallel_state, mpu from megatron.core.packed_seq_params import PackedSeqParams from megatron.training import get_args from megatron.core.models.gpt.gpt_layer_specs import _get_mlp_module_spec -from megatron.core.transformer import build_module +from megatron.core.transformer import build_module, ModuleSpec from megatron.core.transformer.custom_layers.transformer_engine import TENorm from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_viewless_tensor from mindspeed.core.transformer.transformer_block import NoopTransformerLayer, _get_layer_offset @@ -32,6 +32,8 @@ from mindspeed.core.tensor_parallel.comm_autograd_function import auto_grad_sync from mindspeed.core.tensor_parallel.comm_group_api import TPXCollectiveComm, TPYCollectiveComm from mindspeed.core.transformer.transformer import norm_recompute_forward from mindspeed.model.transformer import should_recompute_norm +from mindspeed_llm.core.tensor_parallel.layers import SequenceParallelLinear +from mindspeed_llm.tasks.models.mtp_modules import TransformerLayerAsMTP, MTPLayer, MTPLayerSubmodules def get_num_layers_to_build_wrapper(fn): @@ -84,10 +86,33 @@ def 
_transformer_block_build_layers(self): (global_layer_number - 1) >= args.first_k_dense_replace and (global_layer_number - 1) % args.moe_layer_freq == 0 ): - layer_spec.submodules.mlp = _get_mlp_module_spec(use_te=use_te, num_experts=args.num_experts, - moe_grouped_gemm=args.moe_grouped_gemm) + mlp_module = _get_mlp_module_spec(use_te=use_te, num_experts=args.num_experts, + moe_grouped_gemm=args.moe_grouped_gemm) else: - layer_spec.submodules.mlp = _get_mlp_module_spec(use_te=use_te, moe_grouped_gemm=args.moe_grouped_gemm) + mlp_module = _get_mlp_module_spec(use_te=use_te, moe_grouped_gemm=args.moe_grouped_gemm) + if hasattr(layer_spec.submodules, "mlp"): + layer_spec.submodules.mlp = mlp_module + if args.num_nextn_predict_layers and args.mtp_in_decoder: + if global_layer_number <= args.num_layers - args.num_nextn_predict_layers: + layer_spec = ModuleSpec(module=TransformerLayerAsMTP, + submodules=layer_spec.submodules + ) + layer_spec.submodules.final_norm = layer_spec.submodules.input_layernorm + print( + f"====rank{torch.distributed.get_rank()}===global_layer{global_layer_number}/{layer_number} got TransformerLayerAsMTP") + else: + layer_spec = ModuleSpec(module=MTPLayer, + submodules=MTPLayerSubmodules( + prev_norm=TENorm, + emb_norm=TENorm, + prev_proj=SequenceParallelLinear, + transformer_layer_spec=layer_spec, + final_norm=TENorm, + ) + ) + + print( + f"====rank{torch.distributed.get_rank()}===global_layer{global_layer_number}/{layer_number} got MTPLayer") # For noop layer if args.noop_layers and isinstance(args.noop_layers, set) and global_layer_number - 1 in args.noop_layers: @@ -113,7 +138,7 @@ def _transformer_block_build_layers(self): ) else: self.final_layernorm = None # Either this or nn.Identity - + # For recompute norm if args.recompute_norm: for layer in self.layers: diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py index 18257128a..9c1124f8d 100644 --- a/mindspeed_llm/tasks/megatron_adaptor.py +++ b/mindspeed_llm/tasks/megatron_adaptor.py @@ -518,7 +518,7 @@ class CoreAdaptation(MegatronAdaptationABC): def patch_tensor_parallel(self): from mindspeed.core.tensor_parallel.random import _set_cuda_rng_state from mindspeed.core.tensor_parallel.cross_entropy import calculate_predicted_logits - from ..core import vocab_parallel_embedding_forward, vocab_embedding_init_func, checkpoint_forward_wrapper, checkpoint_backward_wrapper + from ..core import vocab_parallel_embedding_forward, vocab_embedding_init_func # default_generators need replace after set_device MegatronAdaptation.register('megatron.core.tensor_parallel.random._set_cuda_rng_state', _set_cuda_rng_state) @@ -530,10 +530,6 @@ class CoreAdaptation(MegatronAdaptationABC): vocab_parallel_embedding_forward) MegatronAdaptation.register('megatron.core.tensor_parallel.layers.VocabParallelEmbedding.__init__', vocab_embedding_init_func) - MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.forward', - checkpoint_forward_wrapper) - MegatronAdaptation.register('megatron.core.tensor_parallel.random.CheckpointFunction.backward', - checkpoint_backward_wrapper) # For recompute-in-advance from mindspeed.core.tensor_parallel.random import checkpoint_wrapper MegatronAdaptation.register('megatron.core.tensor_parallel.random.checkpoint', checkpoint_wrapper) @@ -650,6 +646,10 @@ class CoreAdaptation(MegatronAdaptationABC): def patch_pipeline_parallel_schedules(self): from ..core import forward_backward_pipelining_with_interleaving_wrapper args = 
MegatronAdaptation.get_args() + from mindspeed_llm.core.pipeline_parallel.schedules import forward_backward_pipelining_with_interleaving_and_mtp + MegatronAdaptation.register( + 'megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving', + forward_backward_pipelining_with_interleaving_and_mtp) MegatronAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_pipelining_with_interleaving', forward_backward_pipelining_with_interleaving_wrapper) diff --git a/mindspeed_llm/tasks/models/mtp_modules.py b/mindspeed_llm/tasks/models/mtp_modules.py new file mode 100644 index 000000000..d861088f3 --- /dev/null +++ b/mindspeed_llm/tasks/models/mtp_modules.py @@ -0,0 +1,359 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +from dataclasses import dataclass +from typing import Dict, Literal, Optional, List, Union + +import torch +from torch import Tensor +from torch import nn +import torch.distributed +from torch.nn import ModuleList + +from megatron.core import InferenceParams, parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.gpt.gpt_model import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.custom_layers.transformer_engine import TENorm +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_layer import TransformerLayerSubmodules +from megatron.training import get_args +from mindspeed.core.transformer.transformer_block import _get_layer_offset +from mindspeed_llm.core import TransformerLayer +from mindspeed_llm.core.models.gpt.gpt_model import generate_nextn_position_ids +from mindspeed_llm.training.utils import tensor_slide + + +@dataclass +class MTPLayerSubmodules: + prev_norm: Union[ModuleSpec, type] = IdentityOp + emb_norm: Union[ModuleSpec, type] = IdentityOp + prev_proj: Union[ModuleSpec, type] = IdentityOp + transformer_layer_spec: ModuleSpec = None + final_norm: Union[ModuleSpec, type] = IdentityOp + + +class TransformerLayerAsMTP(TransformerLayer): + def __init__(self, config: TransformerConfig, submodules: TransformerLayerSubmodules, layer_number: int = 1, + hidden_dropout: float = None): + super().__init__(config, submodules, layer_number, hidden_dropout) + args = get_args() + self.num_nextn_predict_layers = args.num_nextn_predict_layers + self.mtp_use_new_token = args.mtp_use_new_token + + # set mtp_idx + self.self_attention.mtp_idx = 0 + self.self_attention.core_attention.mtp_idx = 0 + + def forward(self, hidden_states, attention_mask, *args, **kwargs): + # Rename it. TransformerBlock requires the name to be hidden states, but it actually holds more than that. 
+ mtp_hidden_states = hidden_states + mtp_attention_mask = None + if self.mtp_use_new_token and attention_mask is not None: + mtp_attention_mask = tensor_slide(attention_mask, + self.num_nextn_predict_layers, + dims=[-2, -1], + return_id=0) + if isinstance(mtp_hidden_states, Tensor): + # First layer (1-based by TransformerBlock), indicating original input from embedding + # Construct MTP hidden states that takes embedding of different past tokens + # if is_first_layer_in_stage(self.config,self.layer_number): + if self.layer_number == 1: + if self.mtp_use_new_token: + mtp_hidden_states = mtp_hidden_states.chunk(self.num_nextn_predict_layers + 1, dim=1) + else: + emb = mtp_hidden_states + mtp_hidden_states = tuple( + torch.cat( + [ + # Shift left + emb[i:], + emb.new_zeros(i, *emb.shape[1:]), + ], + dim=0, + ) + for i in range(self.num_nextn_predict_layers + 1) + ) + # Otherwise it must be the first layer in PP/VPP stage. Unpack hidden states into tuple + else: + mtp_hidden_states = mtp_hidden_states.chunk(self.num_nextn_predict_layers + 1, dim=1) + + output, context = super().forward(mtp_hidden_states[0].clone(), + mtp_attention_mask if mtp_attention_mask else attention_mask, + *args, + **kwargs) + + output_mtp_hidden_states = output, *mtp_hidden_states[1:] + + # Must not be dimension 0 since it will be considered as sequence length by rotary_pos_emb. + output_mtp_hidden_states = torch.concat(output_mtp_hidden_states, dim=1) + + return output_mtp_hidden_states, context + + +class MTPLayer(MegatronModule): + def __init__(self, config: TransformerConfig, submodules: MTPLayerSubmodules, layer_number: int = 1, + hidden_dropout: float = None): + super().__init__(config) + self.config = config + + args = get_args() + self.num_nextn_predict_layers = args.num_nextn_predict_layers + self.mtp_use_new_token = args.mtp_use_new_token + self.layer_number = _get_layer_offset(args) + layer_number + + self.prev_norm = build_module( + submodules.prev_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + self.emb_norm = build_module( + submodules.emb_norm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.prev_proj = build_module( + submodules.prev_proj, + 2 * self.config.hidden_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.init_method, + bias=self.config.add_bias_linear, + ) + + self.transformer_layer: TransformerLayer = build_module( + submodules.transformer_layer_spec, config=self.config, layer_number=layer_number + ) + + # set mtp_dix + self.transformer_layer.self_attention.mtp_idx = self.mtp_stride + 1 + self.transformer_layer.self_attention.core_attention.mtp_idx = self.mtp_stride + 1 + + def forward(self, hidden_states, attention_mask, *args, **kwargs): + mtp_attention_mask = None + if self.mtp_use_new_token: + + if attention_mask is not None: + mtp_attention_mask = tensor_slide(attention_mask, + self.num_nextn_predict_layers, + dims=[-2, -1], + return_id=self.mtp_stride + 1) + + mtp_hidden_states = hidden_states + + if isinstance(mtp_hidden_states, Tensor): + assert ( + self.layer_number > 1 + ), "MTPLayer should not be the global first layer" + # it must be the first layer in PP/VPP stage. 
Unpack hidden states into tuple + mtp_hidden_states = mtp_hidden_states.chunk(self.num_nextn_predict_layers + 1, dim=1) + + dim_ = self.prev_proj( + torch.cat( + [ + # Last token prediction + self.prev_norm(mtp_hidden_states[self.mtp_stride]), + # Current token embedding + self.emb_norm(mtp_hidden_states[self.mtp_stride + 1]), + ], + dim=-1, + ) + )[0] + # [s/t,b,h] + output, context = self.transformer_layer( + dim_, + # select [0] to throw the bias=None away + mtp_attention_mask if mtp_attention_mask else attention_mask, + *args, + **kwargs, + ) + + output_mtp_hidden_states = ( + # All previous tokens prediction + *mtp_hidden_states[: self.mtp_stride + 1], + output, + # All later unmodified embedding + *mtp_hidden_states[self.mtp_stride + 2:], + ) + + # If being last layer in stage, stack them as a new dimension + output_mtp_hidden_states = torch.concat(output_mtp_hidden_states, dim=1) + + return output_mtp_hidden_states, context + + @property + def mtp_stride(self): + return ( + self.num_nextn_predict_layers + - (self.config.num_layers - self.layer_number + 1) + ) + + +class GPTModelWithMTP(GPTModel): + """GPT Transformer language model with Multi-Token Prediction.""" + + config: TransformerConfig + + def __init__(self, config: TransformerConfig, transformer_layer_spec: ModuleSpec, vocab_size: int, + max_sequence_length: int, pre_process: bool = True, post_process: bool = True, + fp16_lm_cross_entropy: bool = False, parallel_output: bool = True, + share_embeddings_and_output_weights: bool = False, + position_embedding_type: Literal['learned_absolute', 'rope', 'none'] = 'learned_absolute', + rotary_percent: float = 1.0, rotary_base: int = 10000, + seq_len_interpolation_factor: Optional[float] = None) -> None: + super().__init__(config, transformer_layer_spec, vocab_size, max_sequence_length, pre_process, post_process, + fp16_lm_cross_entropy, parallel_output, share_embeddings_and_output_weights, + position_embedding_type, rotary_percent, rotary_base, seq_len_interpolation_factor) + global_args = get_args() + self.num_nextn_predict_layers = global_args.num_nextn_predict_layers + self.mtp_use_new_token = global_args.mtp_use_new_token + self.mtp_in_decoder = global_args.mtp_in_decoder + + # mtp require separate layernorms for main model and mtp modules + if self.post_process and global_args.num_nextn_predict_layers: + def build_layernorm(): + return build_module( + TENorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.final_norms = ModuleList() + for i in range(global_args.num_nextn_predict_layers + 1): + self.final_norms.append(build_layernorm()) + else: + self.final_norms = None + + def forward(self, input_ids: Tensor, + position_ids: Tensor, attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + tokentype_ids=None) -> Tensor: + """ + Forward function of the GPT Model This function passes the input tensors + through the embedding layer, and then the decoeder and finally into the post + processing layer (optional). + + It either returns the Loss values if labels are given or the final hidden units + add output_multiplier_scale to scale logits + """ + # If decoder_input is provided (not None), then input_ids and position_ids are ignored. + # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. 
+ args = get_args() + + if not self.training and (hasattr(args, "rope_scaling_type") and args.rope_scaling_type == "longrope"): + args.rope_scaling_original_max_position_embeddings = args.max_position_embeddings + # Decoder embedding. + if decoder_input is not None: + pass + elif self.pre_process: + if self.mtp_use_new_token: + input_ids = torch.concat( + tensor_slide(input_ids, + slice_num=self.num_nextn_predict_layers, + ), + dim=0 + ) + position_ids = torch.concat( + generate_nextn_position_ids(position_ids, + self.num_nextn_predict_layers, + ), + dim=0 + ) + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + if args.scale_emb is not None: + decoder_input = decoder_input * args.scale_emb + + else: + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run decoder. + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + packed_seq_params=packed_seq_params, + **(extra_block_kwargs or {}), + ) + + if not self.post_process: + return hidden_states + + if self.final_norms is not None: + hidden_states = hidden_states.chunk(self.num_nextn_predict_layers + 1, dim=1) + hidden_states = torch.concat([layernorm(hidden_part) \ + for layernorm, hidden_part in zip(self.final_norms, hidden_states)], dim=1) + + # logits and loss + output_weight = None + if self.share_embeddings_and_output_weights: + output_weight = self.shared_embedding_or_output_weight() + + if args.dim_model_base is not None: + hidden_states = hidden_states / (args.hidden_size / args.dim_model_base) + logits, _ = self.output_layer(hidden_states, weight=output_weight) + # new add to scale logits + if args.output_multiplier_scale: + logits = logits * args.output_multiplier_scale + + if args.output_logit_softcapping: + logits = logits / args.output_logit_softcapping + logits = torch.tanh(logits) + logits = logits * args.output_logit_softcapping + + if labels is None: + # [s b h] => [b s h] + return logits.transpose(0, 1).contiguous() + + if args.is_instruction_dataset: + labels = labels[:, 1:].contiguous() + logits = logits[:-1, :, :].contiguous() + + loss = self.compute_language_model_loss(labels, logits) + return loss + + def compute_language_model_loss(self, labels: Tensor, mtp_logits: Tensor) -> Tensor: + M = self.num_nextn_predict_layers + 1 + B, S = labels.shape + seq_len = S + + # shape is [s m*b v] + mtp_logits = mtp_logits.float() + # construct concatenated labels, just like mtp_logits; shape is [s m*b] + if self.mtp_use_new_token: + seq_len -= self.num_nextn_predict_layers + mtp_labels = torch.cat(tensor_slide(labels, self.num_nextn_predict_layers), dim=0).T + else: + # [b s] => [s b] + labels = labels.T + mtp_labels = torch.cat( + [torch.cat([labels[i:], labels.new_zeros(i, B)]) for i in range(M)], dim=1 + ) + mtp_loss: Tensor = tensor_parallel.vocab_parallel_cross_entropy(mtp_logits, mtp_labels) + + # [s*m b] => [m b s] + return mtp_loss.reshape(seq_len, M, B).permute(1, 2, 0).contiguous() diff --git a/mindspeed_llm/tasks/posttrain/utils.py b/mindspeed_llm/tasks/posttrain/utils.py index 
1b66a0670..20078b916 100644 --- a/mindspeed_llm/tasks/posttrain/utils.py +++ b/mindspeed_llm/tasks/posttrain/utils.py @@ -119,7 +119,10 @@ def get_tensor_shapes_decorator(get_tensor_shapes): if args.tp_2d: tensor_shape = [[tensor_shape[0] // args.tp_x, tensor_shape[1], tensor_shape[2] // args.tp_y] for tensor_shape in tensor_shape] - + # In MTP models, all predictions are sent through PP + # we need to extend the shape by its batch dimension to cooperate with output layers + if args.num_nextn_predict_layers and args.mtp_in_decoder: + tensor_shape = [(s[0], (args.num_nextn_predict_layers + 1) * s[1], *s[2:]) for s in tensor_shape] return tensor_shape return wrapper diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py index b7da64116..68ffb7e1f 100644 --- a/mindspeed_llm/training/arguments.py +++ b/mindspeed_llm/training/arguments.py @@ -166,6 +166,10 @@ def _add_mtp_args(parser): help='Multi-Token prediction recompute layer') group.add_argument('--share-mtp-embedding-and-output-weight', action='store_true', default=False, help='Main model share embedding and output weight with mtp layer.') + group.add_argument('--no-mtp-use-new-token', action='store_false', default=True, dest='mtp_use_new_token', + help='Mtp layer use new token or padded zero.') + group.add_argument('--mtp-in-decoder', action='store_true', default=False, dest='mtp_in_decoder', + help='Put mtp layers in decoder.') return parser @@ -366,7 +370,7 @@ def _add_moe_args(parser): group = parser.add_argument_group(title='moe') group.add_argument('--moe-router-load-balancing-type', type=str, choices=['aux_loss', "group_limited_greedy", "softmax_topk", "pai_megatron_aux_loss", - "sparsemixer_topk", "noaux_tc"], + "sparsemixer_topk", "noaux_tc","none"], default='aux_loss', help='Determines the load balancing strategy for the router. "aux_loss" corresponds ' 'to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds ' @@ -1093,8 +1097,8 @@ def _validate_moe_args(args): if args.moe_expert_capacity_factor < 0: args.moe_expert_capacity_factor = None print_rank0_by_args(f'When moe_expert_capacity_factor < 0, no token would be drop, so moe_expert_capacity_factor should be set to false.') - if args.moe_router_load_balancing_type not in ["aux_loss", "none"]: - raise ValueError(f'moe_expert_capacity_factor only works with aux_loss or none load balancing') + # if args.moe_router_load_balancing_type not in ["aux_loss", "none"]: + # raise ValueError(f'moe_expert_capacity_factor only works with aux_loss or none load balancing') if args.moe_expert_capacity_factor is None and args.moe_pad_expert_input_to_capacity: raise ValueError(f'moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity') if args.shared_expert_gate_output_dimension != 1 and args.shared_expert_gate_output_dimension != args.hidden_size: @@ -1188,6 +1192,23 @@ def _validate_aux_loss_free(args): ) +def _validate_mtp_args(args): + if not args.mtp_in_decoder: + if not args.mtp_use_new_token: + raise ValueError( + "Mtp in last stage support new token only." + "Please do not use --no-mtp-use-new-token." + ) + if args.mtp_in_decoder: + noop_layers = set([int(idx) for idx in args.noop_layers.split(",")]) if isinstance(args.noop_layers, str) \ + else {} + for i in range(args.num_nextn_predict_layers): + if args.num_layers - 1 - i in noop_layers: + raise ValueError( + "Mtp layers can not be noop layer." 
+ ) + + def _validate_rl_training(args): return @@ -1508,6 +1529,7 @@ def validate_args_decorator(megatron_validate_args): _validate_long_rope(args) _validate_mlp_fusion(args) _validate_fused_opts(args) + _validate_mtp_args(args) _validate_noop_layer(args) _valid_tp_2d_args(args) diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py index da9698a9c..98fe63595 100644 --- a/mindspeed_llm/training/utils.py +++ b/mindspeed_llm/training/utils.py @@ -45,14 +45,14 @@ WRITE_FILE_DEFAULT_FLAGS = os.O_WRONLY | os.O_CREAT WRITE_FILE_DEFAULT_MODES = stat.S_IWUSR | stat.S_IRUSR -def compute_actual_seq_len(seq, stride=0): +def compute_actual_seq_len(seq): """ compute actual seq len. MTP layer bring extra n tokens, which should be cut. """ res = list() _args = get_args() - seq_length = _args.seq_length + seq_length = seq.shape[1] batch_size = seq.shape[0] for batch_idx in range(batch_size): batch_seq = seq[batch_idx] @@ -60,8 +60,8 @@ def compute_actual_seq_len(seq, stride=0): zero_pos = (batch_seq == 0).nonzero()[1:].squeeze(dim=1) else: zero_pos = (batch_seq == 0).nonzero().squeeze(dim=1) - res.extend((zero_pos + (batch_idx * seq_length - stride)).tolist()) - batch_len = len(batch_seq) + batch_idx * seq_length - stride + res.extend((zero_pos + (batch_idx * seq_length)).tolist()) + batch_len = len(batch_seq) + batch_idx * seq_length if batch_len > seq_length * (batch_idx + 1): batch_len = seq_length * (batch_idx + 1) if batch_idx == batch_size - 1: @@ -70,11 +70,20 @@ def compute_actual_seq_len(seq, stride=0): def generate_actual_seq_len(batch): - position_ids = batch.get('position_ids').transpose(0, 1).contiguous() - set_position_ids(position_ids) - position_ids = batch.get('position_ids') - actual_seq_len = compute_actual_seq_len(position_ids) - set_actual_seq_len(actual_seq_len) + args = get_args() + if args.mtp_use_new_token and args.num_nextn_predict_layers and args.mtp_in_decoder: + position_ids = batch.get('position_ids').transpose(0, 1).contiguous() + set_position_ids(position_ids) + from mindspeed_llm.core.models.gpt.gpt_model import generate_nextn_position_ids + total_position_ids = generate_nextn_position_ids(batch.get('position_ids'), args.num_nextn_predict_layers) + actual_seq_len = [compute_actual_seq_len(pi) for pi in total_position_ids] + set_actual_seq_len(actual_seq_len) + else: + origin_position_ids = batch.get('position_ids') + position_ids = origin_position_ids[:, :origin_position_ids.shape[-1] - args.num_nextn_predict_layers] + set_position_ids(position_ids.transpose(0, 1).contiguous()) + actual_seq_len = compute_actual_seq_len(position_ids) + set_actual_seq_len(actual_seq_len) def parse_args(): @@ -525,6 +534,7 @@ def tensor_slide( dims: Union[int, List[int]] = -1, step: int = 1, return_first=False, + return_id=None, ) -> List[Union[torch.Tensor, None]]: """通用滑动窗口函数,支持任意维度""" if tensor is None: @@ -532,14 +542,20 @@ def tensor_slide( return [None] * (slice_num + 1) if slice_num == 0: return [tensor] - window_size = tensor.shape[-1] - slice_num dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True) + start_id = 0 + end_id = slice_num + 1 + if return_id is not None: + start_id = return_id + end_id = return_id + 1 + # 连续多维度滑动 slices = [] - for i in range(0, tensor.size(dims[-1]) - window_size + 1, step): + for i in range(start_id, end_id, step): slice_obj = [slice(None)] * tensor.dim() for dim in dims: + window_size = tensor.shape[dim] - slice_num slice_obj[dim] = slice(i, i + window_size) slices.append(tensor[tuple(slice_obj)]) if 
diff --git a/pretrain_gpt.py b/pretrain_gpt.py
index da6067369..d1d7289cd 100644
--- a/pretrain_gpt.py
+++ b/pretrain_gpt.py
@@ -19,6 +19,7 @@ from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
 from megatron.core.datasets.utils import get_blend_from_list
 import megatron.legacy.model
 from megatron.core.models.gpt import GPTModel
+from mindspeed_llm.tasks.models.mtp_modules import GPTModelWithMTP
 from mindspeed_llm.training import pretrain
 from megatron.core.transformer.spec_utils import import_module
 from megatron.training.utils import (
@@ -67,7 +68,12 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
     else:
         transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm)
 
-    model = GPTModel(
+    if args.num_nextn_predict_layers > 0 and args.mtp_in_decoder:
+        model_cls = GPTModelWithMTP
+    else:
+        model_cls = GPTModel
+
+    model = model_cls(
         config=config,
         transformer_layer_spec=transformer_layer_spec,
         vocab_size=args.padded_vocab_size,
@@ -124,7 +130,7 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     args = get_args()
 
     losses = output_tensor.float()
-    if args.num_nextn_predict_layers > 0:
+    if args.mtp_use_new_token and args.num_nextn_predict_layers > 0:
         loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
     loss_mask = loss_mask.view(-1).float()
     if args.context_parallel_size > 1:
@@ -147,6 +153,67 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]}
 
 
+def mtp_loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+    """Loss function for multi-token prediction outputs.
+
+    Args:
+        loss_mask (torch.Tensor): Used to mask out some portions of the loss
+        output_tensor (torch.Tensor): The tensor with the losses
+    """
+    args = get_args()
+    if len(output_tensor.shape) != 3:
+        raise RuntimeError(f"Unrecognized {output_tensor.shape=}.")
+    if args.context_parallel_size > 1:
+        raise NotImplementedError("MTP with context parallel is not implemented.")
+
+    losses = output_tensor.float()  # [num_nextn_predict_layers + 1, b, s]
+    loss_mask = loss_mask.float()  # [b, s]
+    mtp_losses = []
+    if args.num_nextn_predict_layers and args.mtp_use_new_token:
+        loss_masks = tensor_slide(loss_mask, args.num_nextn_predict_layers)
+        for i in range(losses.shape[0]):
+            current_loss = losses[i, :, :]
+            current_loss_mask = loss_masks[i]
+            mtp_losses.append(torch.sum(current_loss * current_loss_mask) / current_loss_mask.sum())
+    else:
+        for i in range(losses.shape[0]):
+            current_loss = losses[i, :, : loss_mask.shape[1] - i]
+            current_loss_mask = loss_mask[:, : loss_mask.shape[1] - i]
+            loss_mask_sum = current_loss_mask.sum()
+            loss_mask_sum = torch.max(loss_mask_sum, torch.ones_like(loss_mask_sum))
+            mtp_losses.append(torch.sum(current_loss * current_loss_mask) / loss_mask_sum)
+    mtp_losses = torch.stack(mtp_losses)
+
+    # MTP weighted loss
+    mtp_weighted_loss = mtp_losses[0] + mtp_losses[1:].mean() * args.mtp_loss_scale
+
+    # Check individual rank losses are not NaN prior to DP all-reduce.
+    if args.check_for_nan_in_loss_and_grad:
+        global_rank = torch.distributed.get_rank()
+        if mtp_losses.sum().isnan():
+            raise ValueError(
+                f'Rank {global_rank}: found NaN in local forward loss calculation. '
+                f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
+            )
+
+    # Reduce loss for logging.
+    # averaged_loss = [
+    #     average_losses_across_data_parallel_group([mtp_loss])[0] for mtp_loss in mtp_losses
+    # ]
+    #
+    # result = {f'lm mtp head {i} loss': l for i, l in enumerate(averaged_loss)}
+    # for i, l in enumerate(mtp_losses):
+    #     result[f'local mtp head {i} loss'] = l
+    #
+    # return mtp_weighted_loss, result
+    averaged_loss = average_losses_across_data_parallel_group([mtp_weighted_loss])
+
+    return mtp_weighted_loss, {'lm loss': averaged_loss[0]}
+
+
 def forward_step(data_iterator, model: GPTModel):
     """Forward training step.
 
     Args:
         data_iterator : Input data iterator
         model (GPTModel): The GPT Model
     """
@@ -166,7 +233,11 @@ def forward_step(data_iterator, model: GPTModel):
     output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
 
-    return output_tensor, partial(loss_func, loss_mask)
+    if not args.mtp_in_decoder or args.num_nextn_predict_layers == 0:
+        f_loss = loss_func
+    else:
+        f_loss = mtp_loss_func
+    return output_tensor, partial(f_loss, loss_mask)
 
 
 def is_dataset_built_on_rank():
@@ -175,10 +246,12 @@ def core_gpt_dataset_config_from_args(args):
     tokenizer = get_tokenizer()
-
+    total_seq_len = args.seq_length
+    if args.mtp_use_new_token:
+        total_seq_len += args.num_nextn_predict_layers
     return GPTDatasetConfig(
         random_seed=args.seed,
-        sequence_length=args.seq_length + args.num_nextn_predict_layers,
+        sequence_length=total_seq_len,
         blend=get_blend_from_list(args.data_path),
         blend_per_split=[
            get_blend_from_list(args.train_data_path),
-- 
Gitee
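For reviewers, a minimal standalone sketch of the window slicing and loss weighting that the patched tensor_slide and mtp_loss_func perform in new-token mode. It is not part of the patch, assumes only PyTorch, and uses made-up values for --num-nextn-predict-layers, the batch shape, and args.mtp_loss_scale:

    import torch

    def slide(t: torch.Tensor, n: int):
        # The windows tensor_slide yields for a 2-D [batch, seq] tensor:
        # n + 1 views of width t.shape[-1] - n, shifted by 0..n along the last dim.
        width = t.shape[-1] - n
        return [t[..., i:i + width] for i in range(n + 1)]

    n = 2                               # stand-in for args.num_nextn_predict_layers
    b, s = 4, 8                         # micro-batch size and model seq_length
    loss_mask = torch.ones(b, s + n)    # dataset emits seq_length + n tokens in new-token mode
    masks = slide(loss_mask, n)         # masks[0] -> main head, masks[1:] -> MTP heads
    assert all(m.shape == (b, s) for m in masks)

    # Per-head scalar losses, e.g. reduced from a stacked [n + 1, b, s] output tensor.
    per_head = torch.rand(n + 1)
    mtp_loss_scale = 0.3                # stand-in for args.mtp_loss_scale
    weighted = per_head[0] + per_head[1:].mean() * mtp_loss_scale
    print(weighted)

The main head keeps full weight while the extra heads contribute only through their mean scaled by mtp_loss_scale, mirroring the mtp_weighted_loss computation above.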