From 0f6a5e30e4ed1aedd4870d7c2b462931a4848eaa Mon Sep 17 00:00:00 2001
From: xinyuan
Date: Tue, 17 Jun 2025 11:22:06 +0800
Subject: [PATCH] fix_forward_step_patch

---
 .../core/pipeline_parallel/schedules.py      | 193 ++++++++++++++++++
 mindspeed_llm/mindspore/mindspore_adaptor.py |   4 +-
 2 files changed, 196 insertions(+), 1 deletion(-)
 create mode 100644 mindspeed_llm/mindspore/core/pipeline_parallel/schedules.py

diff --git a/mindspeed_llm/mindspore/core/pipeline_parallel/schedules.py b/mindspeed_llm/mindspore/core/pipeline_parallel/schedules.py
new file mode 100644
index 000000000..e264f6dc8
--- /dev/null
+++ b/mindspeed_llm/mindspore/core/pipeline_parallel/schedules.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2025, Huawei Technologies Co., Ltd. All rights reserved.
+
+# NOTE patch import begin
+# NOTE patch import end
+
+import contextlib
+import torch
+
+from mindspore.common.api import _pynative_executor
+from megatron.core.pipeline_parallel.schedules import set_current_microbatch
+from megatron.core import parallel_state
+from megatron.core.utils import get_attr_wrapped_model
+from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
+from megatron.core.enums import ModelType
+from megatron.core.utils import (
+    get_attr_wrapped_model,
+    get_model_type,
+)
+from mindspeed_llm.core.transformer.multi_token_prediction import MTPLossAutoScaler
+
+
+def forward_step(
+    forward_step_func,
+    data_iterator,
+    model,
+    num_microbatches,
+    input_tensor,
+    forward_data_store,
+    config,
+    collect_non_loss_data=False,
+    checkpoint_activations_microbatch=None,
+    is_first_microbatch=False,
+    current_microbatch=None,
+    forward_only=False,
+):
+    """Forward step for passed-in model.
+
+    If it is the first stage, the input tensor is obtained from the data_iterator.
+    Otherwise, the passed-in input_tensor is used.
+
+    Args:
+        forward_step_func (callable): The forward step function for the model that takes the
+            data iterator as the first argument, and model as the second.
+            This user's forward step is expected to output a tuple of two elements:
+
+                1. The output object from the forward step. This output object needs to be a
+                    tensor or some kind of collection of tensors. The only hard requirement
+                    for this object is that it needs to be acceptable as input into the second
+                    function.
+                2. A function to reduce (optionally) the output from the forward step. This
+                    could be a reduction over the loss from the model, it could be a function that
+                    grabs the output from the model and reformats, it could be a function that just
+                    passes through the model output. This function must have one of the following
+                    patterns, and depending on the pattern different things happen internally:
+
+                        a. A tuple of reduced loss and some other data. Note that in this case
+                            the first argument is divided by the number of global microbatches,
+                            assuming it is a loss, so that the loss is stable as a function of
+                            the number of devices the step is split across.
+                        b. A triple of reduced loss, number of tokens, and some other data. This
+                            is similar to case (a), but the loss is further averaged across the
+                            number of tokens in the batch. If the user is not already averaging
+                            across the number of tokens, this pattern is useful to use.
+                        c. Any arbitrary data the user wants (eg a dictionary of tensors, a list
+                            of tensors, etc in the case of inference). To trigger case (c) you need
+                            to specify `collect_non_loss_data=True` and you may also want to
+                            specify `forward_only=True` in the call to the parent forward_backward
+                            function.
+
+        data_iterator (iterator): The data iterator.
+        model (nn.Module): The model to perform the forward step on.
+        num_microbatches (int): The number of microbatches.
+        input_tensor (Tensor or list[Tensor]): The input tensor(s) for the forward step.
+        forward_data_store (list): The list to store the forward data. If you go down path 2.a or
+            2.b for the return of your forward reduction function then this will store only the
+            final dimension of the output, for example the metadata output by the loss function.
+            If you go down the path of 2.c then this will store the entire output of the forward
+            reduction function applied to the model output.
+        config (object): The configuration object.
+        collect_non_loss_data (bool, optional): Whether to collect non-loss data. This is the
+            path to use if you want to collect arbitrary output from the model forward, such as
+            with inference use cases. Defaults to False.
+        checkpoint_activations_microbatch (int, optional): The microbatch to checkpoint activations.
+            Defaults to None.
+        is_first_microbatch (bool, optional): Whether it is the first microbatch. Defaults to False.
+        current_microbatch (int, optional): The current microbatch. Defaults to None.
+
+    Returns:
+        Tensor or list[Tensor]: The output object(s) from the forward step.
+        Tensor: The number of tokens.
+    """
+    if config.timers is not None:
+        config.timers('forward-compute', log_level=2).start()
+
+    if is_first_microbatch and hasattr(model, 'set_is_first_microbatch'):
+        model.set_is_first_microbatch()
+    if current_microbatch is not None:
+        set_current_microbatch(model, current_microbatch)
+
+    unwrap_output_tensor = False
+    if not isinstance(input_tensor, list):
+        input_tensor = [input_tensor]
+        unwrap_output_tensor = True
+
+    set_input_tensor = get_attr_wrapped_model(model, "set_input_tensor")
+    set_input_tensor(input_tensor)
+
+    if not parallel_state.is_pipeline_first_stage() and input_tensor is not None:
+        input_tensor[0].retain_grad()
+
+    # run forward
+    num_tokens = torch.tensor(0, dtype=torch.int)
+    if input_tensor[0] is None:
+        input_tensor[0] = num_tokens
+
+    if config.enable_autocast:
+        context_manager = torch.autocast("cuda", dtype=config.autocast_dtype)
+    else:
+        context_manager = contextlib.nullcontext()
+    if not forward_only:
+        _pynative_executor.set_grad_flag(True)
+        _pynative_executor.new_graph(forward_step_func, input_tensor[0])
+    with context_manager:
+        if checkpoint_activations_microbatch is None:
+            output_tensor, loss_func = forward_step_func(data_iterator, model)
+        else:
+            output_tensor, loss_func = forward_step_func(
+                data_iterator, model, checkpoint_activations_microbatch
+            )
+
+    num_tokens = torch.tensor(0, dtype=torch.int)
+    if parallel_state.is_pipeline_last_stage():
+        if not collect_non_loss_data:
+            outputs = loss_func(output_tensor)
+            if len(outputs) == 3:
+                output_tensor, num_tokens, loss_reduced = outputs
+                if not config.calculate_per_token_loss:
+                    output_tensor /= num_tokens
+                    output_tensor /= num_microbatches
+            else:
+                # preserve legacy loss averaging behavior (ie, over the number of microbatches)
+                assert len(outputs) == 2
+                output_tensor, loss_reduced = outputs
+                output_tensor /= num_microbatches
+            forward_data_store.append(loss_reduced)
+        else:
+            data = loss_func(output_tensor, non_loss_data=True)
+            forward_data_store.append(data)
+            output_tensor = None
+    if not forward_only:
+        _pynative_executor.end_graph(forward_step_func, output_tensor, input_tensor[0])
+
+    if config.timers is not None:
+        config.timers('forward-compute').stop()
+
+    # Set the loss scale for the auxiliary loss of the MoE layer.
+    # Since we use a trick to do backward on the auxiliary loss, we need to set the scale explicitly.
+    if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None:
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+        loss_scale = (
+            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+            if config.grad_scale_func is not None
+            else torch.tensor(1.0)
+        )
+        # Set the loss scale
+        MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+
+    # Set the loss scale for Multi-Token Prediction (MTP) loss.
+    if hasattr(config, 'mtp_num_layers') and config.mtp_num_layers is not None:
+        # Calculate the loss scale based on the grad_scale_func if available, else default to 1.
+        loss_scale = (
+            config.grad_scale_func(torch.ones(1, device=output_tensor.device))
+            if config.grad_scale_func is not None
+            else torch.ones(1, device=output_tensor.device)
+        )
+        # Set the loss scale
+        if config.calculate_per_token_loss:
+            MTPLossAutoScaler.set_loss_scale(loss_scale)
+        else:
+            MTPLossAutoScaler.set_loss_scale(loss_scale / num_microbatches)
+
+    # If T5 model (or other model with encoder and decoder)
+    # and in decoder stack, then send encoder_hidden_state
+    # downstream as well.
+    model_type = get_model_type(model)
+    if (
+        parallel_state.is_pipeline_stage_after_split()
+        and model_type == ModelType.encoder_and_decoder
+    ):
+        return [output_tensor, input_tensor[-1]], num_tokens
+
+    if unwrap_output_tensor:
+        return output_tensor, num_tokens
+    return [output_tensor], num_tokens
diff --git a/mindspeed_llm/mindspore/mindspore_adaptor.py b/mindspeed_llm/mindspore/mindspore_adaptor.py
index cd3992a2b..7b7f0dbdf 100644
--- a/mindspeed_llm/mindspore/mindspore_adaptor.py
+++ b/mindspeed_llm/mindspore/mindspore_adaptor.py
@@ -95,8 +95,10 @@ class MindSporeAdaptation(MegatronAdaptationABC):
         from mindspeed.mindspore.core.optimizer.optimizer import megatron_optimizer_init
         MindSporeAdaptation.register('megatron.core.optimizer.optimizer.MegatronOptimizer.__init__',
                                      megatron_optimizer_init)
-        from mindspeed.mindspore.core.pipeline_parallel.schedules import forward_step, backward_step, forward_backward_no_pipelining
+        from mindspeed.mindspore.core.pipeline_parallel.schedules import backward_step, forward_backward_no_pipelining
+        from mindspeed_llm.mindspore.core.pipeline_parallel.schedules import forward_step
         MindSporeAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_step', forward_step)
+        MindSporeAdaptation.register('mindspeed.mindspore.core.pipeline_parallel.schedules.forward_step', forward_step)
         MindSporeAdaptation.register('megatron.core.pipeline_parallel.schedules.backward_step', backward_step)
         MindSporeAdaptation.register('megatron.core.pipeline_parallel.schedules.forward_backward_no_pipelining',
                                      forward_backward_no_pipelining)
-- 
Gitee
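
Reviewer note: for context on the loss-function contract the patched forward_step unpacks, below is a minimal, illustrative sketch (not part of the patch) of a forward_step_func / loss_func pair following pattern (b) from the docstring, i.e. the branch taken when len(outputs) == 3. The batch keys (tokens, position_ids, attention_mask, labels, loss_mask) and the 'lm loss' key are assumptions for the example, not names introduced by this patch.

    from functools import partial

    import torch


    def loss_func(loss_mask, output_tensor):
        """Pattern (b): return (summed loss, num_tokens, metadata dict)."""
        losses = output_tensor.float().view(-1)
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses * loss_mask)        # sum of per-token losses; forward_step divides by num_tokens
        num_tokens = loss_mask.sum().to(torch.int)  # tokens that contributed to the loss
        return loss, num_tokens, {'lm loss': loss.detach()}


    def forward_step_func(data_iterator, model):
        """Return the model output and a partially applied loss function, as forward_step expects."""
        batch = next(data_iterator)
        output_tensor = model(batch['tokens'], batch['position_ids'],
                              batch['attention_mask'], labels=batch['labels'])
        return output_tensor, partial(loss_func, batch['loss_mask'])

Because loss_func returns three values and the summed (not averaged) loss, the patched forward_step applies the per-token and per-microbatch normalization itself when config.calculate_per_token_loss is False.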