diff --git a/mindspeed_llm/core/pipeline_parallel/schedules.py b/mindspeed_llm/core/pipeline_parallel/schedules.py
index 1a54655b655998cdbc6b0cc1451d6118e8a942b3..320c6d2e9d1f89994250e403d6af1c4ccba5a247 100644
--- a/mindspeed_llm/core/pipeline_parallel/schedules.py
+++ b/mindspeed_llm/core/pipeline_parallel/schedules.py
@@ -33,6 +33,13 @@ def get_forward_backward_func_wrapper(get_forward_backward_func):
         if arguments.enable_high_availability:
             forward_backward_func = forward_backward_func_wrapper(forward_backward_func)
 
+        if arguments.enable_energy_tune:
+            from mindspeed.core.pipeline_parallel.envpipe import energy_schedules
+            if arguments.num_layers_per_virtual_pipeline_stage is None:
+                forward_backward_func = energy_schedules.forward_backward_pipelining_without_interleaving
+            else:
+                forward_backward_func = energy_schedules.forward_backward_pipelining_with_interleaving
+
         return forward_backward_func
 
     return wrapper
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 1c2fa1d22d10978f2e823772183706559bc41e93..ad932f1deb16a61321acd00fa30ebc698fd21e9f 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -831,6 +831,11 @@ class LegacyAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.training.train', train)
         MegatronAdaptation.register('megatron.training.training.load_checkpoint', load_checkpoint_wrapper)
 
+        args = MegatronAdaptation.get_args()
+        if args.enable_energy_tune:
+            from mindspeed.core.pipeline_parallel.envpipe.energy_megatron_adaptor import train_step_wrapper
+            MegatronAdaptation.register('megatron.training.training.train_step', train_step_wrapper)
+
     def patch_inference(self):
         from ..inference.text_generation.tokenization import tokenize_prompts, _tokenize_prompts_and_batch
         from ..inference.text_generation.forward_step import inference_forward_step_init_wrapper, _forward_step_helper, _allocate_recv_buffer, \
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index b7da641162053eecc7b332668c1f09cf8f66e0f5..217b70954743bd3d9b67504c42028cb40dbdd813 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -73,6 +73,7 @@ def process_args(parser):
     parser = _add_default_model_args(parser)
     parser = _add_megatron2_args(parser)
     parser = _add_inference_args(parser)
+    parser = _add_energy_tune_args(parser)
 
     return parser
 
@@ -983,6 +984,17 @@ def _add_hccl_group_buffer_args(parser):
     return parser
 
 
+def _add_energy_tune_args(parser):
+    group = parser.add_argument_group(title='energy')
+    group.add_argument("--enable-energy-tune", action='store_true', default=False,
+                       help="Enable energy tuning on pipeline parallel training.")
+    group.add_argument("--energy-tune-step-start", type=int, default=5,
+                       help="Number of warmup steps to run before energy tuning starts.")
+    group.add_argument("--energy-tune-save-path", type=str, default="./",
+                       help="Path to save energy tuning results.")
+    return parser
+
+
 def _validate_create_attention_mask_in_dataloader(args):
     args.create_attention_mask_in_dataloader = False
     reset_data = args.reset_attention_mask
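
Usage note (illustrative, not part of the patch): with this change applied, energy tuning is opted into via the new flags on a pipeline-parallel training launch. A minimal sketch, in which the launcher, script name, and parallelism arguments are placeholders and only the three energy-tune flags come from this change:

    # hypothetical launch; torchrun args and pretrain_gpt.py are placeholders
    torchrun $DISTRIBUTED_ARGS pretrain_gpt.py \
        --pipeline-model-parallel-size 4 \
        --enable-energy-tune \
        --energy-tune-step-start 5 \
        --energy-tune-save-path ./energy_tune_results

Per the argument definitions above, --energy-tune-step-start defaults to 5 warmup steps and --energy-tune-save-path defaults to "./". Setting --enable-energy-tune also swaps in the envpipe energy schedules in place of the stock pipeline schedules, choosing the interleaved variant when --num-layers-per-virtual-pipeline-stage is set and the non-interleaved variant otherwise.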