diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py
index f7213c6736fb9f2b60b34f1298c983e942700a65..c12dd9e7f9e7de642279a3e8e31152075db5cbdb 100644
--- a/mindspeed_llm/training/training.py
+++ b/mindspeed_llm/training/training.py
@@ -527,8 +527,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
             if len(model) == 1:
                 config.grad_sync_func = config.grad_sync_func[0]
     if args.overlap_param_gather and args.align_param_gather:
-        config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x)
-                                  for model_index in range(len(model))]
+        config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
         if len(model) == 1:
             config.param_sync_func = config.param_sync_func[0]
     config.finalize_model_grads_func = finalize_model_grads
diff --git a/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh b/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh
index 65baa1f158a5ea9dc1321a1d9458aa06339dccb6..a13c1cb6083c74a3afbff3c33fd41d44fb11937f 100644
--- a/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh
+++ b/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh
@@ -95,6 +95,7 @@ TRAINING_ARGS=(
     --use-fused-swiglu
     --use-fused-rotary-pos-emb
     --overlap-grad-reduce
+    --overlap-param-gather
     --bf16
     --use-distributed-optimizer
 )
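
Note on the training.py change (a minimal sketch, assuming Megatron-style
context around the hunk; lines outside the hunk are reconstructed, not
verbatim MindSpeed-LLM source). Each DDP model chunk exposes
start_param_sync, which launches the all-gather for that chunk's parameters
so it can overlap with the next forward pass:

    # Sketch of the wiring after this change; surrounding context is assumed.
    if args.overlap_param_gather and args.align_param_gather:
        # Bind one start_param_sync per virtual-pipeline model chunk; the
        # pipeline schedule invokes it just before that chunk's forward compute.
        config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model]
        if len(model) == 1:
            # Single chunk: pass the bare callable instead of a list.
            config.param_sync_func = config.param_sync_func[0]

The removed comprehension built closures over the loop variable
(lambda x: optimizer.finish_param_sync(model_index, x)); Python binds
model_index late, so once the comprehension finished every lambda would see
its final value. Binding the bound method model_chunk.start_param_sync
directly avoids that pitfall and follows the hook used by newer upstream
Megatron-LM. The --overlap-param-gather flag added to the swap test script
exercises this path under the tp2/pp4/vpp2 configuration.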