From 862cd4cd60ab9f21ef7b2fe53b4e3ecac716fe03 Mon Sep 17 00:00:00 2001 From: guozhihua Date: Sat, 9 Aug 2025 10:15:47 +0800 Subject: [PATCH] fix train in the scene of overlap_param_gather and align_param_gather --- mindspeed_llm/training/training.py | 3 +-- tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspeed_llm/training/training.py b/mindspeed_llm/training/training.py index f7213c673..c12dd9e7f 100644 --- a/mindspeed_llm/training/training.py +++ b/mindspeed_llm/training/training.py @@ -527,8 +527,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, if len(model) == 1: config.grad_sync_func = config.grad_sync_func[0] if args.overlap_param_gather and args.align_param_gather: - config.param_sync_func = [lambda x: optimizer.finish_param_sync(model_index, x) - for model_index in range(len(model))] + config.param_sync_func = [model_chunk.start_param_sync for model_chunk in model] if len(model) == 1: config.param_sync_func = config.param_sync_func[0] config.finalize_model_grads_func = finalize_model_grads diff --git a/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh b/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh index 65baa1f15..a13c1cb60 100644 --- a/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh +++ b/tests/st/shell_scripts/llama2_tp2_pp4_vpp2_swap.sh @@ -95,6 +95,7 @@ TRAINING_ARGS=( --use-fused-swiglu --use-fused-rotary-pos-emb --overlap-grad-reduce + --overlap-param-gather --bf16 --use-distributed-optimizer ) -- Gitee