diff --git a/mindspeed_llm/features_manager/__init__.py b/mindspeed_llm/features_manager/__init__.py
index 0fba65ab5da3d0882e5c8f56c0b5827816cfb022..01419ce684191239acdb4b6775b7d62580f3af70 100644
--- a/mindspeed_llm/features_manager/__init__.py
+++ b/mindspeed_llm/features_manager/__init__.py
@@ -13,6 +13,7 @@ from mindspeed.features_manager import (
     MoEAllGatherOverLapFeature,
     MoEFwdBwdOverlapFeature,
     MoEGmmFeature,
+    ExpertsPlacementFeature,
     MoEZeroMemoryFeature,
     OptimizeSendRecvCommFeature,
     SwapOptimizerFeature,
@@ -210,6 +211,7 @@ def add_moe_features(features_list: List[MindSpeedFeature]):
         MoEFwdBwdOverlapFeature(),
         MoEAlltoAllOverLapFeature(),
         MoEZeroMemoryFeature(),
+        ExpertsPlacementFeature(),
     ])
 
 
diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py
index f2f55a599bc8e54630a68bc7903cf5c541ca8d70..ec7ab205fcf7345c26b1217f7dc2136c9d6af6bd 100644
--- a/mindspeed_llm/training/training.py
+++ b/mindspeed_llm/training/training.py
@@ -60,6 +60,11 @@ from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed import finalize_model_grads
 from mindspeed_llm.training.initialize import set_jit_fusion_options
 from mindspeed_llm.tasks.posttrain.lora.utils import is_enable_lora
+from mindspeed.core.transformer.moe.expert_placement.planner import print_expert_load
+from mindspeed.core.transformer.moe.expert_placement.executor import (
+    build_param_params_module_mlp_map,
+    expert_weight_and_optimizer_state_placement
+)
 
 # The earliest we can measure the start time.
 _TRAIN_START_TIME = time.time()
@@ -292,6 +297,14 @@ def build_train_args(*input_args):
     model_provider_func = model_provider
     model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
         model_provider_func, model_type)
+    # param mapping to mlp object
+    if args.enable_expert_placement:
+        params_module_mlp_map = build_param_params_module_mlp_map(model)
+        if hasattr(optimizer, "chained_optimizers"):
+            for optimizer_sub in optimizer.chained_optimizers:
+                optimizer_sub.params_module_mlp_map = params_module_mlp_map
+        else:
+            optimizer.params_module_mlp_map = params_module_mlp_map
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
@@ -604,6 +617,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
         update_num_microbatches(args.consumed_train_samples,
                                 consistency_check=True)
         args.curr_iteration = iteration
+        if args.enable_expert_placement:
+            expert_weight_and_optimizer_state_placement(args, model, optimizer)
+            if args.print_expert_load:
+                print_expert_load(args, model, iteration)
+
         loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
             train_step(forward_step_func, train_data_iterator,