diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 61b88be8dcd3493d168160c486320a271c3347b8..5c71b5ed80aedd9814c7a5b4376f099d0ea0eabb 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -1025,8 +1025,8 @@ def _validate_recompute_args(args):
 def _validate_instruction_finetune(args):
     if args.variable_seq_lengths:
-        if args.context_parallel_size > 1:
-            raise AssertionError('Context parallelism is forbidden when use variable seq lengths.')
+        if args.context_parallel_size > 1 and args.pad_to_multiple_of % (args.tensor_model_parallel_size * args.context_parallel_size) != 0:
+            raise AssertionError('pad_to_multiple_of must be divisible by (tp * cp) when context parallelism is used.')
         if args.num_experts is not None and args.moe_token_dispatcher_type == "allgather":
            raise AssertionError('moe_token_dispatcher_type "allgather" is forbidden when use variable seq lengths. you can choose "alltoall"')
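
Note: below is a minimal standalone sketch of the padding rule this hunk enforces, with hypothetical example values; it is illustrative only and not part of the patch. When variable sequence lengths are combined with context parallelism, pad_to_multiple_of must be divisible by tensor_model_parallel_size * context_parallel_size:

    # Illustrative check mirroring the new condition in _validate_instruction_finetune.
    # The helper name and the example values are assumptions for demonstration.
    def check_pad_to_multiple_of(pad_to_multiple_of, tensor_model_parallel_size, context_parallel_size):
        if context_parallel_size > 1 and pad_to_multiple_of % (tensor_model_parallel_size * context_parallel_size) != 0:
            raise AssertionError('pad_to_multiple_of must be divisible by (tp * cp) when context parallelism is used.')

    check_pad_to_multiple_of(16, tensor_model_parallel_size=2, context_parallel_size=4)  # passes: 16 % 8 == 0
    check_pad_to_multiple_of(12, tensor_model_parallel_size=2, context_parallel_size=4)  # raises: 12 % 8 != 0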