From eebae6df83eaa118a9e4350b30ea3e7b363d5397 Mon Sep 17 00:00:00 2001
From: shenjiarun
Date: Sat, 28 Jun 2025 16:26:19 +0800
Subject: [PATCH 1/2] reset attention mask padding

---
 .../pipeline_parallel/p2p_communication.py   | 188 ++++++++++++++++++
 mindspeed_llm/tasks/megatron_adaptor.py      |  19 +-
 mindspeed_llm/training/arguments.py          |   4 +-
 mindspeed_llm/training/utils.py              |  85 +++++++-
 4 files changed, 287 insertions(+), 9 deletions(-)

diff --git a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
index 079bf27cb..fc69c3288 100644
--- a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
+++ b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
@@ -19,7 +19,10 @@ from megatron.core.parallel_state import (
     get_pipeline_model_parallel_group,
     get_pipeline_model_parallel_next_rank,
     get_pipeline_model_parallel_prev_rank,
+    get_pipeline_model_parallel_rank
 )
+from megatron.training import get_args
+from mindspeed.utils import get_actual_seq_len, set_actual_seq_len, get_position_ids, set_position_ids
 
 
 def _batched_p2p_ops(
@@ -68,3 +71,188 @@ def _batched_p2p_ops(
     else:
         reqs = []
     return reqs
+
+
+def _p2p_ops_eod_variable_seq_lengths(
+    *,
+    tensor_send_prev: Optional[torch.Tensor],
+    tensor_recv_prev: Optional[torch.Tensor],
+    tensor_send_next: Optional[torch.Tensor],
+    tensor_recv_next: Optional[torch.Tensor],
+    group: torch.distributed.ProcessGroup,
+):
+    reqs = []
+    rank = get_pipeline_model_parallel_rank()
+    prev_actual_seq_len = get_actual_seq_len()
+    prev_position_ids = get_position_ids()
+
+    tensor_length = None
+    length_buffer = None
+    seq_len = None
+    seq_len_buffer = None
+
+    args = get_args()
+    bsz = args.micro_batch_size
+
+    if tensor_send_next is not None:
+        tensor_length = torch.tensor(prev_actual_seq_len.numel()).npu()
+        seq_len = torch.tensor(prev_position_ids.shape[0]).npu()
+
+    if tensor_recv_prev is not None:
+        length_buffer = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+        seq_len_buffer = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+
+    if rank % 2 == 0:
+        if tensor_length is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_length, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if length_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=length_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if seq_len is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=seq_len, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if seq_len_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+    else:
+        if length_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=length_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if tensor_length is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_length, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if seq_len_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if seq_len is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=seq_len, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+    for req in reqs:
+        req.wait()
+
+    reqs = []
+
+    if get_pipeline_model_parallel_rank() % 2 == 0:
+        if tensor_send_next is not None:
+            req = torch.distributed.isend(
+                tensor=prev_actual_seq_len, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            req = torch.distributed.isend(
+                tensor=prev_position_ids, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if tensor_recv_prev is not None:
+            actual_seq_len_buffer = torch.empty([length_buffer.item()], dtype=torch.int64, device=torch.cuda.current_device())
+
+            req = torch.distributed.irecv(
+                tensor=actual_seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(req)
+            set_actual_seq_len(actual_seq_len_buffer)
+
+            position_ids_buffer = torch.empty((seq_len_buffer.item(), bsz), dtype=torch.int64, device=torch.cuda.current_device())
+            req = torch.distributed.irecv(
+                tensor=position_ids_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            set_position_ids(position_ids_buffer)
+            reqs.append(req)
+
+            recv_prev_req = torch.distributed.irecv(
+                tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if tensor_send_prev is not None:
+            send_prev_req = torch.distributed.isend(
+                tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(send_prev_req)
+
+        if tensor_recv_next is not None:
+            recv_next_req = torch.distributed.irecv(
+                tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(recv_next_req)
+
+    else:
+        if tensor_recv_prev is not None:
+            actual_seq_len_buffer = torch.empty([length_buffer.item()], dtype=torch.int64, device=torch.cuda.current_device())
+
+            req = torch.distributed.irecv(
+                tensor=actual_seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(req)
+            set_actual_seq_len(actual_seq_len_buffer)
+
+            position_ids_buffer = torch.empty((seq_len_buffer.item(), bsz), dtype=torch.int64, device=torch.cuda.current_device())
+            req = torch.distributed.irecv(
+                tensor=position_ids_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            set_position_ids(position_ids_buffer)
+            reqs.append(req)
+
+            recv_prev_req = torch.distributed.irecv(
+                tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if tensor_send_next is not None:
+            req = torch.distributed.isend(
+                tensor=prev_actual_seq_len, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            req = torch.distributed.isend(
+                tensor=prev_position_ids, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if tensor_recv_next is not None:
+            recv_next_req = torch.distributed.irecv(
+                tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(recv_next_req)
+
+        if tensor_send_prev is not None:
+            send_prev_req = torch.distributed.isend(
+                tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(send_prev_req)
+    return reqs
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index 3048a4693..a5c9e693c 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -299,7 +299,7 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     get_gpt_layer_local_spec)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec',
                                     get_gpt_layer_local_spec_wrapper)
-        if not args.reset_attention_mask:
+        if not args.reset_attention_mask and args.context_parallel_size > 1:
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
         MegatronAdaptation.register('megatron.training.dist_signal_handler.get_device', get_device_wrapper)
         # moe_fb_overlap will shadow this forward impl
@@ -388,7 +388,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         from mindspeed.core.training import train_step
         MegatronAdaptation.register('megatron.training.training.train_step', train_step)
 
-        if getattr(args, 'reset_attention_mask', False):
+        if getattr(args, 'reset_attention_mask', False) and args.context_parallel_size > 1:
             from mindspeed.core.datasets.gpt_dataset import _get_ltor_masks_and_position_ids, collate_wrapper
             from mindspeed.utils import get_batch_on_this_cp_rank, get_batch_on_this_cp_rank_wrapper
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
@@ -398,8 +398,13 @@ class CoreAdaptation(MegatronAdaptationABC):
             MegatronAdaptation.register('torch.utils.data._utils.collate.default_collate', collate_wrapper)
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank_wrapper)
 
-            from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
-            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod)
+            if args.variable_seq_lengths:
+                from mindspeed_llm.core.pipeline_parallel.p2p_communication import _p2p_ops_eod_variable_seq_lengths
+                MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod_variable_seq_lengths)
+            else:
+                from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
+                MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod)
+
             from mindspeed_llm.core.models.gpt.gpt_model import gpt_forward_wrapper
             MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_forward_wrapper)
             from mindspeed.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb_thd
@@ -410,6 +415,10 @@ class CoreAdaptation(MegatronAdaptationABC):
             from mindspeed.core.models.common.embeddings.rotary_pos_embedding import rotary_forward
             MegatronAdaptation.register('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.forward', rotary_forward)
 
+            from mindspeed.core.pipeline_parallel.p2p_communication import _communicate, _communicate_shapes
+            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate', _communicate)
+            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes',_communicate_shapes)
+
 
         # For Dualpipe
         args = MegatronAdaptation.get_args()
@@ -684,7 +693,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.datasets.gpt_dataset.GPTDataset.__getitem__',
                                     gpt_dataset_getitem_wrapper)
 
         args = MegatronAdaptation.get_args()
-        if not args.reset_attention_mask:
+        if not args.reset_attention_mask and args.context_parallel_size > 1:
             MegatronAdaptation.register('megatron.core.datasets.gpt_dataset._get_ltor_masks_and_position_ids',
                                         _get_ltor_masks_and_position_ids)
diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index 61b88be8d..ed8e6b108 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -206,6 +206,8 @@ def _add_profile_args(parser):
 
 def _add_cp_args(parser):
     group = parser.add_argument_group(title='cp parallel')
+    group.add_argument('--context-parallel-size', type=int, default=1,
+                       help='Degree of context parallelism.')
     group.add_argument('--context-parallel-algo', type=str, default='ulysses_cp_algo',
                        choices=['ulysses_cp_algo', 'megatron_cp_algo', 'hybrid_cp_algo', 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo'],
                        help='context parallel algorithm')
@@ -1025,8 +1027,6 @@ def _validate_recompute_args(args):
 
 
 def _validate_instruction_finetune(args):
    if args.variable_seq_lengths:
-        if args.context_parallel_size > 1:
-            raise AssertionError('Context parallelism is forbidden when use variable seq lengths.')
         if args.num_experts is not None and args.moe_token_dispatcher_type == "allgather":
             raise AssertionError('moe_token_dispatcher_type "allgather" is forbidden when use variable seq lengths. you can choose "alltoall"')
diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index a5c8678fb..a22f5cf62 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -15,6 +15,7 @@
 """General utilities."""
 
 import os
+import math
 import stat
 import random
 import warnings
@@ -580,6 +581,68 @@ def get_batch_on_this_tp_rank(data_iterator):
     return batch, actual_seq_len
 
 
+def round_up(n, m):
+    return (n+m-1) // m * m
+
+
+def do_padding(actual_seq_len, batch):
+    args = get_args()
+    if not hasattr(args, 'original_s'):
+        args.original_s = args.seq_length
+    batch_seq_len = batch_index(actual_seq_len, args.original_s)
+
+    res = []
+    for seq_len in batch_seq_len:
+        differences = [seq_len[0]] + [y - x for x, y in zip(seq_len, seq_len[1:])]
+        res.append(differences)
+
+    batch_seq_len = res
+
+    pad_to = args.context_parallel_size * (abs(args.tensor_model_parallel_size * 2) // math.gcd(args.tensor_model_parallel_size, 2))
+
+    def pad_length(lst):
+        return [round_up(elem, pad_to) for elem in lst]
+
+    pad_actual_seq_len = [pad_length(lst) for lst in batch_seq_len]
+    total_length_per_seq = [sum(lst) for lst in pad_actual_seq_len]
+    total_length = max(total_length_per_seq)
+
+    for i, lst in enumerate(pad_actual_seq_len):
+        lst[-1] += total_length - total_length_per_seq[i]
+
+    args.seq_length = total_length
+
+    # batch_seq_len -> total_length_per_seq
+    scatter_index = []
+    import itertools
+    accumulate_pad_seq = [[0] + list(itertools.accumulate(s)) for s in pad_actual_seq_len]
+
+    for b, lst in enumerate(accumulate_pad_seq):
+        orig_seq = batch_seq_len[b]
+        indexes = [torch.arange(elem, elem + orig_seq[i]) for i, elem in enumerate(lst[:-1])]
+        index = torch.cat(indexes)
+        scatter_index.append(index)
+
+    scatter_index = torch.stack(scatter_index)
+    bsz = scatter_index.shape[0]
+
+    def padding(x):
+        if x is None:
+            return
+        buffer = torch.zeros((bsz, total_length), device='npu', dtype=x.dtype)
+        buffer.scatter_(dim=1, index=scatter_index.npu(), src=x)
+        return buffer
+
+    batch['tokens'] = padding(batch['tokens'])
+    batch['labels'] = padding(batch['labels'])
+    batch['loss_mask'] = padding(batch['loss_mask'])
+    batch['position_ids'] = padding(batch['position_ids'])
+
+    actual_seq_len = [torch.tensor(elem[1:]) + i * total_length for i, elem in enumerate(accumulate_pad_seq)]
+
+    return torch.cat(actual_seq_len)
+
+
 def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):
 
     args = get_args()
@@ -624,8 +687,13 @@ def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):
 
         if args.reset_attention_mask:
             actual_seq_len = broadcast_dynamic(data['actual_seq_len'])
+            if args.variable_seq_lengths and args.context_parallel_size > 1:
+                actual_seq_len = do_padding(actual_seq_len.tolist(), batch)
             if args.attention_mask_type == 'causal':
-                actual_seq_len /= get_ring_degree()
+                if args.variable_seq_lengths and args.context_parallel_size > 1:
+                    actual_seq_len = torch.floor_divide(actual_seq_len.npu(), get_ring_degree())
+                else:
+                    actual_seq_len /= get_ring_degree()
             set_actual_seq_len(actual_seq_len)
 
     else:
@@ -682,8 +750,13 @@ def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):
 
         if args.reset_attention_mask:
             actual_seq_len = broadcast_dynamic(None)
+            if args.variable_seq_lengths and args.context_parallel_size > 1:
+                actual_seq_len = do_padding(actual_seq_len.tolist(), batch)
             if args.attention_mask_type == 'causal':
-                actual_seq_len /= get_ring_degree()
+                if args.variable_seq_lengths and args.context_parallel_size > 1:
+                    actual_seq_len = torch.floor_divide(actual_seq_len.npu(), get_ring_degree())
+                else:
+                    actual_seq_len /= get_ring_degree()
             set_actual_seq_len(actual_seq_len)
 
     return batch, actual_seq_len
@@ -857,3 +930,11 @@ def tensor_slide(
     if return_first:
         return slices
     return slices
+
+
+def batch_index(seq1d, seq_len):
+    from bisect import bisect_right
+    end_points = list(range(seq_len, seq1d[-1] + 1, seq_len))
+    indexes = [0] + [bisect_right(seq1d, p) for p in end_points]
+    seq_batch = [seq1d[indexes[i]:indexes[i + 1]] for i in range(len(indexes) - 1)]
+    return [[elem - i * seq_len for elem in seq] for i, seq in enumerate(seq_batch)]
-- Gitee

From 07a31be3db6d0a576339cc3dce69361beb952204 Mon Sep 17 00:00:00 2001
From: shenjiarun
Date: Thu, 3 Jul 2025 16:59:57 +0800
Subject: [PATCH 2/2] refine the condition for whether we do padding or not

---
 mindspeed_llm/tasks/megatron_adaptor.py | 11 +++---
 mindspeed_llm/training/utils.py         | 46 ++++++++++++++++---------
 2 files changed, 36 insertions(+), 21 deletions(-)

diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index a5c9e693c..268b808ea 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -389,23 +389,24 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.training.training.train_step', train_step)
 
         if getattr(args, 'reset_attention_mask', False) and args.context_parallel_size > 1:
+            from ..core.pipeline_parallel.p2p_communication import _p2p_ops_eod_variable_seq_lengths
+            from ..core.models.gpt.gpt_model import gpt_forward_wrapper
+            from ..training.utils import get_batch_on_this_tp_rank_reset_attn_mask
             from mindspeed.core.datasets.gpt_dataset import _get_ltor_masks_and_position_ids, collate_wrapper
             from mindspeed.utils import get_batch_on_this_cp_rank, get_batch_on_this_cp_rank_wrapper
+            from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
+
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
-            from mindspeed_llm.training.utils import get_batch_on_this_tp_rank_reset_attn_mask
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank_reset_attn_mask)
             MegatronAdaptation.register('megatron.core.datasets.gpt_dataset._get_ltor_masks_and_position_ids', _get_ltor_masks_and_position_ids)
             MegatronAdaptation.register('torch.utils.data._utils.collate.default_collate', collate_wrapper)
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank_wrapper)
 
             if args.variable_seq_lengths:
-                from mindspeed_llm.core.pipeline_parallel.p2p_communication import _p2p_ops_eod_variable_seq_lengths
                 MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod_variable_seq_lengths)
             else:
-                from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
                 MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod)
 
-            from mindspeed_llm.core.models.gpt.gpt_model import gpt_forward_wrapper
             MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_forward_wrapper)
             from mindspeed.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb_thd
             MegatronAdaptation.register('megatron.core.models.common.embeddings.rotary_pos_embedding.apply_rotary_pos_emb_thd', apply_rotary_pos_emb_thd)
@@ -417,7 +418,7 @@ class CoreAdaptation(MegatronAdaptationABC):
             from mindspeed.core.pipeline_parallel.p2p_communication import _communicate, _communicate_shapes
             MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate', _communicate)
-            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes',_communicate_shapes)
+            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes', _communicate_shapes)
 
 
         # For Dualpipe
         args = MegatronAdaptation.get_args()
diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index a22f5cf62..8ccb49736 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -19,6 +19,7 @@ import math
 import stat
 import random
 import warnings
+import itertools
 from functools import wraps
 from typing import Optional, Union, List
 from itertools import takewhile
@@ -581,12 +582,12 @@ def get_batch_on_this_tp_rank(data_iterator):
     return batch, actual_seq_len
 
 
-def round_up(n, m):
-    return (n+m-1) // m * m
-
-
 def do_padding(actual_seq_len, batch):
     args = get_args()
+
+    if not need_padding(args, actual_seq_len):
+        return torch.tensor(actual_seq_len)
+
     if not hasattr(args, 'original_s'):
         args.original_s = args.seq_length
     batch_seq_len = batch_index(actual_seq_len, args.original_s)
@@ -614,7 +615,6 @@ def do_padding(actual_seq_len, batch):
 
     # batch_seq_len -> total_length_per_seq
     scatter_index = []
-    import itertools
     accumulate_pad_seq = [[0] + list(itertools.accumulate(s)) for s in pad_actual_seq_len]
 
     for b, lst in enumerate(accumulate_pad_seq):
@@ -626,23 +626,25 @@ def do_padding(actual_seq_len, batch):
     scatter_index = torch.stack(scatter_index)
     bsz = scatter_index.shape[0]
 
-    def padding(x):
-        if x is None:
-            return
-        buffer = torch.zeros((bsz, total_length), device='npu', dtype=x.dtype)
-        buffer.scatter_(dim=1, index=scatter_index.npu(), src=x)
-        return buffer
-
-    batch['tokens'] = padding(batch['tokens'])
-    batch['labels'] = padding(batch['labels'])
-    batch['loss_mask'] = padding(batch['loss_mask'])
-    batch['position_ids'] = padding(batch['position_ids'])
+    batch['tokens'] = padding(batch['tokens'], bsz, total_length, scatter_index)
+    batch['labels'] = padding(batch['labels'], bsz, total_length, scatter_index)
+    batch['loss_mask'] = padding(batch['loss_mask'], bsz, total_length, scatter_index)
+    batch['position_ids'] = padding(batch['position_ids'], bsz, total_length, scatter_index)
 
     actual_seq_len = [torch.tensor(elem[1:]) + i * total_length for i, elem in enumerate(accumulate_pad_seq)]
 
     return torch.cat(actual_seq_len)
 
 
+def padding(x, bsz, total_length, scatter_index):
+    if x is None:
+        return None
+
+    buffer = torch.zeros((bsz, total_length), device='npu', dtype=x.dtype)
+    buffer.scatter_(dim=1, index=scatter_index.npu(), src=x)
+    return buffer
+
+
 def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):
 
     args = get_args()
@@ -938,3 +940,15 @@ def batch_index(seq1d, seq_len):
     indexes = [0] + [bisect_right(seq1d, p) for p in end_points]
     seq_batch = [seq1d[indexes[i]:indexes[i + 1]] for i in range(len(indexes) - 1)]
     return [[elem - i * seq_len for elem in seq] for i, seq in enumerate(seq_batch)]
+
+
+def round_up(n, m):
+    return (n + m - 1) // m * m
+
+
+def need_padding(args, actual_seq_len):
+    target_length = args.context_parallel_size * (abs(args.tensor_model_parallel_size * 2) // math.gcd(args.tensor_model_parallel_size, 2))
+    for _, seq in enumerate(actual_seq_len):
+        if seq % target_length != 0:
+            return True
+    return False
\ No newline at end of file
-- Gitee
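
For reference, the padding rule that the new do_padding()/need_padding() helpers apply can be exercised in isolation. The sketch below is a minimal, hypothetical illustration (the pad_cu_seqlens helper and the sample numbers are invented here, not part of the patches): it mirrors only the per-sequence rounding of packed sub-sequence lengths up to a multiple of context_parallel_size * lcm(tensor_model_parallel_size, 2), not the batch-level length equalization or the scatter-based token padding.

    import math
    from itertools import accumulate

    def round_up(n, m):
        # round n up to the next multiple of m
        return (n + m - 1) // m * m

    def pad_cu_seqlens(cu_seqlens, tp_size, cp_size):
        # cu_seqlens: cumulative end offsets of the packed sub-sequences in one sample
        pad_to = cp_size * (tp_size * 2 // math.gcd(tp_size, 2))  # cp * lcm(tp, 2)
        lengths = [b - a for a, b in zip([0] + cu_seqlens, cu_seqlens)]
        padded = [round_up(length, pad_to) for length in lengths]
        return list(accumulate(padded))

    # Hypothetical sample: three packed sequences of lengths 5, 6 and 5 (offsets [5, 11, 16]);
    # with tp_size=2 and cp_size=2 every length is rounded up to a multiple of 4.
    print(pad_cu_seqlens([5, 11, 16], tp_size=2, cp_size=2))  # [8, 16, 24]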