diff --git a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
index 079bf27cb286c500a8bb3bd1c66be69630fd7943..fc69c32882ba9502dccbeb9c6fc853fda8d2a39f 100644
--- a/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
+++ b/mindspeed_llm/core/pipeline_parallel/p2p_communication.py
@@ -19,7 +19,10 @@ from megatron.core.parallel_state import (
     get_pipeline_model_parallel_group,
     get_pipeline_model_parallel_next_rank,
     get_pipeline_model_parallel_prev_rank,
+    get_pipeline_model_parallel_rank
 )
+from megatron.training import get_args
+from mindspeed.utils import get_actual_seq_len, set_actual_seq_len, get_position_ids, set_position_ids


 def _batched_p2p_ops(
@@ -68,3 +71,188 @@ def _batched_p2p_ops(
     else:
         reqs = []
     return reqs
+
+
+def _p2p_ops_eod_variable_seq_lengths(
+    *,
+    tensor_send_prev: Optional[torch.Tensor],
+    tensor_recv_prev: Optional[torch.Tensor],
+    tensor_send_next: Optional[torch.Tensor],
+    tensor_recv_next: Optional[torch.Tensor],
+    group: torch.distributed.ProcessGroup,
+):
+    reqs = []
+    rank = get_pipeline_model_parallel_rank()
+    prev_actual_seq_len = get_actual_seq_len()
+    prev_position_ids = get_position_ids()
+
+    tensor_length = None
+    length_buffer = None
+    seq_len = None
+    seq_len_buffer = None
+
+    args = get_args()
+    bsz = args.micro_batch_size
+
+    if tensor_send_next is not None:
+        tensor_length = torch.tensor(prev_actual_seq_len.numel()).npu()
+        seq_len = torch.tensor(prev_position_ids.shape[0]).npu()
+
+    if tensor_recv_prev is not None:
+        length_buffer = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+        seq_len_buffer = torch.empty((), dtype=torch.int64, device=torch.cuda.current_device())
+
+    if rank % 2 == 0:
+        if tensor_length is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_length, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if length_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=length_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if seq_len is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=seq_len, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if seq_len_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+    else:
+        if length_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=length_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if tensor_length is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_length, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if seq_len_buffer is not None:
+            recv_prev_req = torch.distributed.irecv(
+                tensor=seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if seq_len is not None:
+            send_next_req = torch.distributed.isend(
+                tensor=seq_len, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+    for req in reqs:
+        req.wait()
+
+    reqs = []
+
+    if get_pipeline_model_parallel_rank() % 2 == 0:
+        if tensor_send_next is not None:
+            req = torch.distributed.isend(
+                tensor=prev_actual_seq_len, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            req = torch.distributed.isend(
+                tensor=prev_position_ids, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if tensor_recv_prev is not None:
+            actual_seq_len_buffer = torch.empty([length_buffer.item()], dtype=torch.int64, device=torch.cuda.current_device())
+
+            req = torch.distributed.irecv(
+                tensor=actual_seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(req)
+            set_actual_seq_len(actual_seq_len_buffer)
+
+            position_ids_buffer = torch.empty((seq_len_buffer.item(), bsz), dtype=torch.int64, device=torch.cuda.current_device())
+            req = torch.distributed.irecv(
+                tensor=position_ids_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            set_position_ids(position_ids_buffer)
+            reqs.append(req)
+
+            recv_prev_req = torch.distributed.irecv(
+                tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if tensor_send_prev is not None:
+            send_prev_req = torch.distributed.isend(
+                tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(send_prev_req)
+
+        if tensor_recv_next is not None:
+            recv_next_req = torch.distributed.irecv(
+                tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(recv_next_req)
+
+    else:
+        if tensor_recv_prev is not None:
+            actual_seq_len_buffer = torch.empty([length_buffer.item()], dtype=torch.int64, device=torch.cuda.current_device())
+
+            req = torch.distributed.irecv(
+                tensor=actual_seq_len_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(req)
+            set_actual_seq_len(actual_seq_len_buffer)
+
+            position_ids_buffer = torch.empty((seq_len_buffer.item(), bsz), dtype=torch.int64, device=torch.cuda.current_device())
+            req = torch.distributed.irecv(
+                tensor=position_ids_buffer, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            set_position_ids(position_ids_buffer)
+            reqs.append(req)
+
+            recv_prev_req = torch.distributed.irecv(
+                tensor=tensor_recv_prev, src=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(recv_prev_req)
+
+        if tensor_send_next is not None:
+            req = torch.distributed.isend(
+                tensor=prev_actual_seq_len, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            req = torch.distributed.isend(
+                tensor=prev_position_ids, dst=get_pipeline_model_parallel_next_rank(), group=get_pipeline_model_parallel_group(),
+            )
+            reqs.append(req)
+
+            send_next_req = torch.distributed.isend(
+                tensor=tensor_send_next, dst=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(send_next_req)
+
+        if tensor_recv_next is not None:
+            recv_next_req = torch.distributed.irecv(
+                tensor=tensor_recv_next, src=get_pipeline_model_parallel_next_rank(), group=group,
+            )
+            reqs.append(recv_next_req)
+
+        if tensor_send_prev is not None:
+            send_prev_req = torch.distributed.isend(
+                tensor=tensor_send_prev, dst=get_pipeline_model_parallel_prev_rank(), group=group,
+            )
+            reqs.append(send_prev_req)
+    return reqs
\ No newline at end of file
diff --git a/mindspeed_llm/tasks/megatron_adaptor.py b/mindspeed_llm/tasks/megatron_adaptor.py
index b105bdadb857f2338ee63ccfb606366c8a049ecd..3be47f09acb04df6216323fd8782f4f957a00653 100644
--- a/mindspeed_llm/tasks/megatron_adaptor.py
+++ b/mindspeed_llm/tasks/megatron_adaptor.py
@@ -299,7 +299,7 @@ class CoreAdaptation(MegatronAdaptationABC):
                                     get_gpt_layer_local_spec)
         MegatronAdaptation.register('megatron.core.models.gpt.gpt_layer_specs.get_gpt_layer_local_spec',
                                     get_gpt_layer_local_spec_wrapper)
-        if not args.reset_attention_mask:
+        if not args.reset_attention_mask and args.context_parallel_size > 1:
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
         MegatronAdaptation.register('megatron.training.dist_signal_handler.get_device', get_device_wrapper)
         # moe_fb_overlap will shadow this forward impl
@@ -388,19 +388,25 @@ class CoreAdaptation(MegatronAdaptationABC):

            from mindspeed.core.training import train_step
            MegatronAdaptation.register('megatron.training.training.train_step', train_step)
-        if getattr(args, 'reset_attention_mask', False):
+        if getattr(args, 'reset_attention_mask', False) and args.context_parallel_size > 1:
+            from ..core.pipeline_parallel.p2p_communication import _p2p_ops_eod_variable_seq_lengths
+            from ..core.models.gpt.gpt_model import gpt_forward_wrapper
+            from ..training.utils import get_batch_on_this_tp_rank_reset_attn_mask
             from mindspeed.core.datasets.gpt_dataset import _get_ltor_masks_and_position_ids, collate_wrapper
             from mindspeed.utils import get_batch_on_this_cp_rank, get_batch_on_this_cp_rank_wrapper
+            from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
+
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank)
-            from mindspeed_llm.training.utils import get_batch_on_this_tp_rank_reset_attn_mask
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_tp_rank', get_batch_on_this_tp_rank_reset_attn_mask)
             MegatronAdaptation.register('megatron.core.datasets.gpt_dataset._get_ltor_masks_and_position_ids',
                                         _get_ltor_masks_and_position_ids)
             MegatronAdaptation.register('torch.utils.data._utils.collate.default_collate', collate_wrapper)
             MegatronAdaptation.register('megatron.training.utils.get_batch_on_this_cp_rank', get_batch_on_this_cp_rank_wrapper)
-            from mindspeed.core.pipeline_parallel.p2p_communication import _p2p_ops_eod
-            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod)
-            from mindspeed_llm.core.models.gpt.gpt_model import gpt_forward_wrapper
+            if args.variable_seq_lengths:
+                MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod_variable_seq_lengths)
+            else:
+                MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._p2p_ops', _p2p_ops_eod)
+
             MegatronAdaptation.register('megatron.core.models.gpt.gpt_model.GPTModel.forward', gpt_forward_wrapper)
             from mindspeed.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb_thd
             MegatronAdaptation.register('megatron.core.models.common.embeddings.rotary_pos_embedding.apply_rotary_pos_emb_thd', apply_rotary_pos_emb_thd)
@@ -410,6 +416,10 @@ class CoreAdaptation(MegatronAdaptationABC):
             from mindspeed.core.models.common.embeddings.rotary_pos_embedding import rotary_forward
             MegatronAdaptation.register('megatron.core.models.common.embeddings.rotary_pos_embedding.RotaryEmbedding.forward',
                                         rotary_forward)
+            from mindspeed.core.pipeline_parallel.p2p_communication import _communicate, _communicate_shapes
+            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate', _communicate)
+            MegatronAdaptation.register('megatron.core.pipeline_parallel.p2p_communication._communicate_shapes', _communicate_shapes)
+
         # For Dualpipe
         args = MegatronAdaptation.get_args()
@@ -687,7 +697,7 @@ class CoreAdaptation(MegatronAdaptationABC):
         MegatronAdaptation.register('megatron.core.datasets.gpt_dataset.GPTDataset.__getitem__', gpt_dataset_getitem_wrapper)

         args = MegatronAdaptation.get_args()
-        if not args.reset_attention_mask:
+        if not args.reset_attention_mask and args.context_parallel_size > 1:
             MegatronAdaptation.register('megatron.core.datasets.gpt_dataset._get_ltor_masks_and_position_ids',
                                         _get_ltor_masks_and_position_ids)

diff --git a/mindspeed_llm/training/arguments.py b/mindspeed_llm/training/arguments.py
index df9b10a929bf105b4bb7f2fbb3624d6ded6fd842..8ad5362a1e63d864150fc316058819ba54d1ac7d 100644
--- a/mindspeed_llm/training/arguments.py
+++ b/mindspeed_llm/training/arguments.py
@@ -206,6 +206,8 @@ def _add_profile_args(parser):

 def _add_cp_args(parser):
     group = parser.add_argument_group(title='cp parallel')
+    group.add_argument('--context-parallel-size', type=int, default=1,
+                       help='Degree of context parallelism.')
     group.add_argument('--context-parallel-algo', type=str, default='ulysses_cp_algo',
                        choices=['ulysses_cp_algo', 'megatron_cp_algo', 'hybrid_cp_algo', 'adaptive_cp_algo', 'hybrid_adaptive_cp_algo'],
                        help='context parallel algorithm')
@@ -1025,8 +1027,6 @@ def _validate_recompute_args(args):

 def _validate_instruction_finetune(args):
     if args.variable_seq_lengths:
-        if args.context_parallel_size > 1:
-            raise AssertionError('Context parallelism is forbidden when use variable seq lengths.')
         if args.num_experts is not None and args.moe_token_dispatcher_type == "allgather":
             raise AssertionError('moe_token_dispatcher_type "allgather" is forbidden when use variable seq lengths. you can choose "alltoall"')

diff --git a/mindspeed_llm/training/utils.py b/mindspeed_llm/training/utils.py
index f37f83be0dea3e5d3d8d73b2ac1c354fab5713d7..776c322d1701e3d713e92f7d710be7f536d3385f 100644
--- a/mindspeed_llm/training/utils.py
+++ b/mindspeed_llm/training/utils.py
@@ -15,9 +15,11 @@
 """General utilities."""

 import os
+import math
 import stat
 import random
 import warnings
+import itertools
 import logging
 from functools import wraps
 from typing import Optional, Union, List
@@ -594,6 +596,69 @@ def get_batch_on_this_tp_rank(data_iterator):

     return batch, actual_seq_len
+
+
+def do_padding(actual_seq_len, batch):
+    args = get_args()
+
+    if not need_padding(args, actual_seq_len):
+        return torch.tensor(actual_seq_len)
+
+    if not hasattr(args, 'original_s'):
+        args.original_s = args.seq_length
+    batch_seq_len = batch_index(actual_seq_len, args.original_s)
+
+    res = []
+    for seq_len in batch_seq_len:
+        differences = [seq_len[0]] + [y - x for x, y in zip(seq_len, seq_len[1:])]
+        res.append(differences)
+
+    batch_seq_len = res
+
+    pad_to = args.context_parallel_size * (abs(args.tensor_model_parallel_size * 2) // math.gcd(args.tensor_model_parallel_size, 2))
+
+    def pad_length(lst):
+        return [round_up(elem, pad_to) for elem in lst]
+
+    pad_actual_seq_len = [pad_length(lst) for lst in batch_seq_len]
+    total_length_per_seq = [sum(lst) for lst in pad_actual_seq_len]
+    total_length = max(total_length_per_seq)
+
+    for i, lst in enumerate(pad_actual_seq_len):
+        lst[-1] += total_length - total_length_per_seq[i]
+
+    args.seq_length = total_length
+
+    # batch_seq_len -> total_length_per_seq
+    scatter_index = []
+    accumulate_pad_seq = [[0] + list(itertools.accumulate(s)) for s in pad_actual_seq_len]
+
+    for b, lst in enumerate(accumulate_pad_seq):
+        orig_seq = batch_seq_len[b]
+        indexes = [torch.arange(elem, elem + orig_seq[i]) for i, elem in enumerate(lst[:-1])]
+        index = torch.cat(indexes)
+        scatter_index.append(index)
+
+    scatter_index = torch.stack(scatter_index)
+    bsz = scatter_index.shape[0]
+
+    batch['tokens'] = padding(batch['tokens'], bsz, total_length, scatter_index)
+    batch['labels'] = padding(batch['labels'], bsz, total_length, scatter_index)
+    batch['loss_mask'] = padding(batch['loss_mask'], bsz, total_length, scatter_index)
+    batch['position_ids'] = padding(batch['position_ids'], bsz, total_length, scatter_index)
+
+    actual_seq_len = [torch.tensor(elem[1:]) + i * total_length for i, elem in enumerate(accumulate_pad_seq)]
+
+    return torch.cat(actual_seq_len)
+
+
+def padding(x, bsz, total_length, scatter_index):
+    if x is None:
+        return None
+
+    buffer = torch.zeros((bsz, total_length), device='npu', dtype=x.dtype)
+    buffer.scatter_(dim=1, index=scatter_index.npu(), src=x)
+    return buffer


 def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):
     args = get_args()
@@ -638,8 +703,13 @@ def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):

         if args.reset_attention_mask:
             actual_seq_len = broadcast_dynamic(data['actual_seq_len'])
+            if args.variable_seq_lengths and args.context_parallel_size > 1:
+                actual_seq_len = do_padding(actual_seq_len.tolist(), batch)
             if args.attention_mask_type == 'causal':
-                actual_seq_len /= get_ring_degree()
+                if args.variable_seq_lengths and args.context_parallel_size > 1:
+                    actual_seq_len = torch.floor_divide(actual_seq_len.npu(), get_ring_degree())
+                else:
+                    actual_seq_len /= get_ring_degree()

             set_actual_seq_len(actual_seq_len)
         else:
@@ -696,8 +766,13 @@ def get_batch_on_this_tp_rank_reset_attn_mask(data_iterator):

         if args.reset_attention_mask:
             actual_seq_len = broadcast_dynamic(None)
+            if args.variable_seq_lengths and args.context_parallel_size > 1:
+                actual_seq_len = do_padding(actual_seq_len.tolist(), batch)
             if args.attention_mask_type == 'causal':
-                actual_seq_len /= get_ring_degree()
+                if args.variable_seq_lengths and args.context_parallel_size > 1:
+                    actual_seq_len = torch.floor_divide(actual_seq_len.npu(), get_ring_degree())
+                else:
+                    actual_seq_len /= get_ring_degree()

             set_actual_seq_len(actual_seq_len)

     return batch, actual_seq_len
@@ -871,3 +946,23 @@ def tensor_slide(
         if return_first:
             return slices
     return slices
+
+
+def batch_index(seq1d, seq_len):
+    from bisect import bisect_right
+    end_points = list(range(seq_len, seq1d[-1] + 1, seq_len))
+    indexes = [0] + [bisect_right(seq1d, p) for p in end_points]
+    seq_batch = [seq1d[indexes[i]:indexes[i + 1]] for i in range(len(indexes) - 1)]
+    return [[elem - i * seq_len for elem in seq] for i, seq in enumerate(seq_batch)]
+
+
+def round_up(n, m):
+    return (n + m - 1) // m * m
+
+
+def need_padding(args, actual_seq_len):
+    target_length = args.context_parallel_size * (abs(args.tensor_model_parallel_size * 2) // math.gcd(args.tensor_model_parallel_size, 2))
+    for _, seq in enumerate(actual_seq_len):
+        if seq % target_length != 0:
+            return True
+    return False
\ No newline at end of file